Commit 9e8a8c05 authored by jerrrrry's avatar jerrrrry
Browse files

Initial commit

parents
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch.nn.functional as F
from torch import nn
class AdaptiveSoftmax(nn.Module):
"""
This is an implementation of the efficient softmax approximation for
graphical processing units (GPU), described in the paper "Efficient softmax
approximation for GPUs" (http://arxiv.org/abs/1609.04309).
"""
def __init__(self, vocab_size, input_dim, cutoff, dropout):
super().__init__()
if vocab_size > cutoff[-1]:
cutoff = cutoff + [vocab_size]
else:
assert vocab_size == cutoff[
-1], 'cannot specify cutoff smaller than vocab size'
output_dim = cutoff[0] + len(cutoff) - 1
self.vocab_size = vocab_size
self.cutoff = cutoff
self.dropout = dropout
self.lsm = nn.LogSoftmax(dim=1)
self.head = nn.Linear(input_dim, output_dim, bias=False)
self.tail = nn.ModuleList()
for i in range(len(cutoff) - 1):
self.tail.append(
nn.Sequential(
nn.Linear(input_dim, input_dim // 4 ** i, bias=False),
nn.Dropout(dropout),
nn.Linear(input_dim // 4 ** i, cutoff[i + 1] - cutoff[i], bias=False)
)
)
def init_weights(m):
if hasattr(m, 'weight'):
nn.init.xavier_uniform_(m.weight)
self.apply(init_weights)
def adapt_target(self, target):
"""
In order to be efficient, the AdaptiveSoftMax does not compute the
scores for all the word of the vocabulary for all the examples. It is
thus necessary to call the method adapt_target of the AdaptiveSoftMax
layer inside each forward pass.
"""
target = target.view(-1)
new_target = [target.clone()]
target_idxs = []
for i in range(len(self.cutoff) - 1):
mask = target.ge(self.cutoff[i]).mul(target.lt(self.cutoff[i + 1]))
new_target[0][mask] = self.cutoff[0] + i - 1
if mask.any():
target_idxs.append(mask.nonzero().squeeze(1))
new_target.append(target[mask].add(-self.cutoff[i]))
else:
target_idxs.append(None)
new_target.append(None)
return new_target, target_idxs
def forward(self, input, target):
"""
Args:
input: (b x t x d)
target: (b x t)
Returns:
2 lists: output for each cutoff section and new targets by cut off
"""
input = input.contiguous().view(-1, input.size(-1))
input = F.dropout(input, p=self.dropout, training=self.training)
new_target, target_idxs = self.adapt_target(target)
output = [self.head(input)]
for i in range(len(target_idxs)):
if target_idxs[i] is not None:
output.append(self.tail[i](input.index_select(0, target_idxs[i])))
else:
output.append(None)
return output, new_target
def get_log_prob(self, input, target):
"""
Computes the log probabilities for all the words of the vocabulary,
given a 2D tensor of hidden vectors.
"""
bsz, length, dim = input.size()
input = input.contiguous().view(-1, dim)
if target is not None:
_, target_idxs = self.adapt_target(target)
else:
target_idxs = None
head_y = self.head(input)
log_probs = head_y.new_zeros(input.size(0), self.vocab_size)
head_sz = self.cutoff[0] + len(self.tail)
log_probs[:, :head_sz] = self.lsm(head_y)
tail_priors = log_probs[:, self.cutoff[0] - 1: head_sz - 1].clone()
for i in range(len(self.tail)):
start = self.cutoff[i]
end = self.cutoff[i + 1]
if target_idxs is None:
tail_out = log_probs[:, start:end]
tail_out.copy_(self.tail[i](input))
log_probs[:, start:end] = self.lsm(tail_out).add_(tail_priors[:, i, None])
elif target_idxs[i] is not None:
idxs = target_idxs[i]
tail_out = log_probs[idxs, start:end]
tail_out.copy_(self.tail[i](input[idxs]))
log_probs[idxs, start:end] = self.lsm(tail_out).add_(tail_priors[idxs, i, None])
log_probs = log_probs.view(bsz, length, -1)
return log_probs
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch
import torch.nn as nn
class BeamableMM(nn.Module):
"""This module provides an optimized MM for beam decoding with attention.
It leverage the fact that the source-side of the input is replicated beam
times and the target-side of the input is of width one. This layer speeds up
inference by replacing the inputs {(bsz x 1 x nhu), (bsz x sz2 x nhu)}
with smaller inputs {(bsz/beam x beam x nhu), (bsz/beam x sz2 x nhu)}.
"""
def __init__(self, beam_size=None):
super(BeamableMM, self).__init__()
self.beam_size = beam_size
def forward(self, input1, input2):
if (
not self.training and # test mode
self.beam_size is not None and # beam size is set
input1.dim() == 3 and # only support batched input
input1.size(1) == 1 # single time step update
):
bsz, beam = input1.size(0), self.beam_size
# bsz x 1 x nhu --> bsz/beam x beam x nhu
input1 = input1[:, 0, :].unfold(0, beam, beam).transpose(2, 1)
# bsz x sz2 x nhu --> bsz/beam x sz2 x nhu
input2 = input2.unfold(0, beam, beam)[:, :, :, 0]
# use non batched operation if bsz = beam
if input1.size(0) == 1:
output = torch.mm(input1[0, :, :], input2[0, :, :])
else:
output = input1.bmm(input2)
return output.view(bsz, 1, -1)
else:
return input1.bmm(input2)
def set_beam_size(self, beam_size):
self.beam_size = beam_size
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch
from torch.nn.modules.utils import _single
class ConvTBC(torch.nn.Module):
"""1D convolution over an input of shape (time x batch x channel)
The implementation uses gemm to perform the convolution. This implementation
is faster than cuDNN for small kernel sizes.
"""
def __init__(self, in_channels, out_channels, kernel_size, padding=0):
super(ConvTBC, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = _single(kernel_size)
self.padding = _single(padding)
self.weight = torch.nn.Parameter(torch.Tensor(
self.kernel_size[0], in_channels, out_channels))
self.bias = torch.nn.Parameter(torch.Tensor(out_channels))
def forward(self, input):
return input.contiguous().conv_tbc(self.weight, self.bias, self.padding[0])
def __repr__(self):
s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}'
', padding={padding}')
if self.bias is None:
s += ', bias=False'
s += ')'
return s.format(name=self.__class__.__name__, **self.__dict__)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq.modules.scalar_bias import scalar_bias
class SingleHeadAttention(nn.Module):
"""
Single-head attention that supports Gating and Downsampling
"""
def __init__(
self, out_channels, embed_dim, head_dim, head_index, dropout=0.,
bias=True, project_input=True, gated=False, downsample=False,
num_heads=1,
):
super().__init__()
self.embed_dim = embed_dim
self.dropout = dropout
self.head_index = head_index
self.head_dim = head_dim
self.project_input = project_input
self.gated = gated
self.downsample = downsample
self.num_heads = num_heads
self.projection = None
k_layers = []
v_layers = []
if self.downsample:
k_layers.append(Downsample(self.head_index))
v_layers.append(Downsample(self.head_index))
out_proj_size = self.head_dim
else:
out_proj_size = self.head_dim * self.num_heads
if self.gated:
k_layers.append(GatedLinear(self.embed_dim, out_proj_size, bias=bias))
self.in_proj_q = GatedLinear(self.embed_dim, out_proj_size, bias=bias)
v_layers.append(GatedLinear(self.embed_dim, out_proj_size, bias=bias))
else:
k_layers.append(Linear(self.embed_dim, out_proj_size, bias=bias))
self.in_proj_q = Linear(self.embed_dim, out_proj_size, bias=bias)
v_layers.append(Linear(self.embed_dim, out_proj_size, bias=bias))
self.in_proj_k = nn.Sequential(*k_layers)
self.in_proj_v = nn.Sequential(*v_layers)
if self.downsample:
self.out_proj = Linear(out_proj_size, self.head_dim, bias=bias)
else:
self.out_proj = Linear(out_proj_size, out_channels, bias=bias)
self.scaling = self.head_dim**-0.5
def forward(
self, query, key, value, mask_future_timesteps=False,
key_padding_mask=None, use_scalar_bias=False,
):
"""Input shape: Time x Batch x Channel
Self-attention can be implemented by passing in the same arguments for
query, key and value. Future timesteps can be masked with the
`mask_future_timesteps` argument. Padding elements can be excluded from
the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
batch x src_len, where padding elements are indicated by 1s.
"""
src_len, bsz, out_channels = key.size()
tgt_len = query.size(0)
assert list(query.size()) == [tgt_len, bsz, out_channels]
assert key.size() == value.size()
if key_padding_mask is not None:
assert key_padding_mask.size(0) == bsz
assert key_padding_mask.size(1) == src_len
if self.downsample:
size = bsz
else:
size = bsz * self.num_heads
k = key
v = value
q = query
if self.project_input:
q = self.in_proj_q(q)
k = self.in_proj_k(k)
v = self.in_proj_v(v)
src_len = k.size()[0]
q *= self.scaling
if not self.downsample:
q = q.view(tgt_len, size, self.head_dim)
k = k.view(src_len, size, self.head_dim)
v = v.view(src_len, size, self.head_dim)
q = q.transpose(0, 1)
k = k.transpose(0, 1)
v = v.transpose(0, 1)
attn_weights = torch.bmm(q, k.transpose(1, 2))
if mask_future_timesteps:
assert query.size() == key.size(), \
'mask_future_timesteps only applies to self-attention'
attn_weights *= torch.tril(
attn_weights.data.new([1]).expand(tgt_len, tgt_len).clone(),
diagonal=-1,
)[:, ::self.head_index + 1 if self.downsample else 1].unsqueeze(0)
attn_weights += torch.triu(
attn_weights.data.new([-math.inf]).expand(tgt_len, tgt_len).clone(),
diagonal=0
)[:, ::self.head_index + 1 if self.downsample else 1].unsqueeze(0)
tgt_size = tgt_len
if use_scalar_bias:
attn_weights = scalar_bias(attn_weights, 2)
v = scalar_bias(v, 1)
tgt_size += 1
if key_padding_mask is not None:
# don't attend to padding symbols
if key_padding_mask.max() > 0:
if self.downsample:
attn_weights = attn_weights.view(bsz, 1, tgt_len, src_len)
else:
attn_weights = attn_weights.view(size, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2),
-math.inf,
)
attn_weights = attn_weights.view(size, tgt_len, src_len)
attn_weights = F.softmax(attn_weights, dim=-1)
attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)
attn = torch.bmm(attn_weights, v)
if self.downsample:
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.head_dim)
else:
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim)
attn = self.out_proj(attn)
return attn, attn_weights
class DownsampledMultiHeadAttention(nn.ModuleList):
"""
Multi-headed attention with Gating and Downsampling
"""
def __init__(
self, out_channels, embed_dim, num_heads, dropout=0., bias=True,
project_input=True, gated=False, downsample=False,
):
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
self.downsample = downsample
self.gated = gated
self.project_input = project_input
assert self.head_dim * num_heads == embed_dim
if self.downsample:
attention_heads = []
for index in range(self.num_heads):
attention_heads.append(
SingleHeadAttention(
out_channels, self.embed_dim, self.head_dim, index,
self.dropout, bias, self.project_input, self.gated,
self.downsample, self.num_heads,
)
)
super().__init__(modules=attention_heads)
self.out_proj = Linear(embed_dim, out_channels, bias=bias)
else:
# either we have a list of attention heads, or just one attention head
# if not being downsampled, we can do the heads with one linear layer instead of separate ones
super().__init__()
self.attention_module = SingleHeadAttention(
out_channels, self.embed_dim, self.head_dim, 1, self.dropout,
bias, self.project_input, self.gated, self.downsample, self.num_heads,
)
def forward(
self, query, key, value, mask_future_timesteps=False,
key_padding_mask=None, use_scalar_bias=False,
):
src_len, bsz, embed_dim = key.size()
tgt_len = query.size(0)
assert embed_dim == self.embed_dim
assert list(query.size()) == [tgt_len, bsz, embed_dim]
assert key.size() == value.size()
tgt_size = tgt_len
if use_scalar_bias:
tgt_size += 1
attn = []
attn_weights = []
if self.downsample:
for attention_head_number in range(self.num_heads):
# call the forward of each attention head
_attn, _attn_weight = self[attention_head_number](
query, key, value, mask_future_timesteps, key_padding_mask, use_scalar_bias,
)
attn.append(_attn)
attn_weights.append(_attn_weight)
full_attn = torch.cat(attn, dim=2)
full_attn = self.out_proj(full_attn)
return full_attn, attn_weights[0].clone()
else:
_attn, _attn_weight = self.attention_module(
query, key, value, mask_future_timesteps, key_padding_mask, use_scalar_bias,
)
attn.append(_attn)
attn_weights.append(_attn_weight)
full_attn = torch.cat(attn, dim=2)
full_attn_weights = torch.cat(attn_weights)
full_attn_weights = full_attn_weights.view(bsz, self.num_heads, tgt_size, src_len)
full_attn_weights = full_attn_weights.sum(dim=1) / self.num_heads
return full_attn, full_attn_weights
class Downsample(nn.Module):
"""
Selects every nth element, where n is the index
"""
def __init__(self, index):
super().__init__()
self.index = index
def forward(self, x):
return x[::self.index+1]
def Linear(in_features, out_features, dropout=0., bias=True):
"""Weight-normalized Linear layer (input: B x T x C)"""
m = nn.Linear(in_features, out_features, bias=bias)
m.weight.data.normal_(mean=0, std=math.sqrt((1 - dropout) / in_features))
m.bias.data.zero_()
return nn.utils.weight_norm(m)
def GatedLinear(in_features, out_features, dropout=0., bias=True):
"""Weight-normalized Linear layer (input: B x T x C) with interspersed GLU units"""
return nn.Sequential(
Linear(in_features, out_features*4, dropout, bias),
nn.GLU(),
Linear(out_features*2, out_features*2, dropout, bias),
nn.GLU(),
Linear(out_features, out_features, dropout, bias)
)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch
class GradMultiply(torch.autograd.Function):
@staticmethod
def forward(ctx, x, scale):
ctx.scale = scale
res = x.new(x)
return res
@staticmethod
def backward(ctx, grad):
return grad * ctx.scale, None
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch.nn as nn
from fairseq import utils
class LearnedPositionalEmbedding(nn.Embedding):
"""This module learns positional embeddings up to a fixed maximum size.
Padding symbols are ignored, but it is necessary to specify whether padding
is added on the left side (left_pad=True) or right side (left_pad=False).
"""
def __init__(self, num_embeddings, embedding_dim, padding_idx, left_pad):
super().__init__(num_embeddings, embedding_dim, padding_idx)
self.left_pad = left_pad
def forward(self, input, incremental_state=None):
"""Input is expected to be of size [bsz x seqlen]."""
if incremental_state is not None:
# positions is the same for every token when decoding a single step
positions = input.data.new(1, 1).fill_(self.padding_idx + input.size(1))
else:
positions = utils.make_positions(input.data, self.padding_idx, self.left_pad)
return super().forward(positions)
def max_positions(self):
"""Maximum number of supported positions."""
return self.num_embeddings - self.padding_idx - 1
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch
import torch.nn.functional as F
from fairseq import utils
from .conv_tbc import ConvTBC
class LinearizedConvolution(ConvTBC):
"""An optimized version of nn.Conv1d.
At training time, this module uses ConvTBC, which is an optimized version
of Conv1d. At inference time, it optimizes incremental generation (i.e.,
one time step at a time) by replacing the convolutions with linear layers.
Note that the input order changes from training to inference.
"""
def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
super().__init__(in_channels, out_channels, kernel_size, **kwargs)
self._linearized_weight = None
self.register_backward_hook(self._clear_linearized_weight)
def forward(self, input, incremental_state=None):
"""
Input:
Time x Batch x Channel during training
Batch x Time x Channel during inference
Args:
incremental_state: Used to buffer signal; if not None, then input is
expected to contain a single frame. If the input order changes
between time steps, call reorder_incremental_state.
"""
if incremental_state is None:
output = super().forward(input)
if self.kernel_size[0] > 1 and self.padding[0] > 0:
# remove future timesteps added by padding
output = output[:-self.padding[0], :, :]
return output
# reshape weight
weight = self._get_linearized_weight()
kw = self.kernel_size[0]
bsz = input.size(0) # input: bsz x len x dim
if kw > 1:
input = input.data
input_buffer = self._get_input_buffer(incremental_state)
if input_buffer is None:
input_buffer = input.new(bsz, kw, input.size(2)).zero_()
self._set_input_buffer(incremental_state, input_buffer)
else:
# shift buffer
input_buffer[:, :-1, :] = input_buffer[:, 1:, :].clone()
# append next input
input_buffer[:, -1, :] = input[:, -1, :]
input = input_buffer
with torch.no_grad():
output = F.linear(input.view(bsz, -1), weight, self.bias)
return output.view(bsz, 1, -1)
def reorder_incremental_state(self, incremental_state, new_order):
input_buffer = self._get_input_buffer(incremental_state)
if input_buffer is not None:
input_buffer = input_buffer.index_select(0, new_order)
self._set_input_buffer(incremental_state, input_buffer)
def _get_input_buffer(self, incremental_state):
return utils.get_incremental_state(self, incremental_state, 'input_buffer')
def _set_input_buffer(self, incremental_state, new_buffer):
return utils.set_incremental_state(self, incremental_state, 'input_buffer', new_buffer)
def _get_linearized_weight(self):
if self._linearized_weight is None:
kw = self.kernel_size[0]
weight = self.weight.transpose(2, 1).transpose(1, 0).contiguous()
assert weight.size() == (self.out_channels, kw, self.in_channels)
self._linearized_weight = weight.view(self.out_channels, -1)
return self._linearized_weight
def _clear_linearized_weight(self, *args):
self._linearized_weight = None
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import torch
class ScalarBias(torch.autograd.Function):
"""
Adds a vector of scalars, used in self-attention mechanism to allow
the model to optionally attend to this vector instead of the past
"""
@staticmethod
def forward(ctx, input, dim, bias_init):
size = list(input.size())
size[dim] += 1
output = input.new(*size).fill_(bias_init)
output.narrow(dim, 1, size[dim] - 1).copy_(input)
ctx.dim = dim
return output
@staticmethod
def backward(ctx, grad):
return grad.narrow(ctx.dim, 1, grad.size(ctx.dim) - 1), None, None
def scalar_bias(input, dim, bias_init=0):
return ScalarBias.apply(input, dim, bias_init)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import math
import torch
import torch.nn as nn
from fairseq import utils
class SinusoidalPositionalEmbedding(nn.Module):
"""This module produces sinusoidal positional embeddings of any length.
Padding symbols are ignored, but it is necessary to specify whether padding
is added on the left side (left_pad=True) or right side (left_pad=False).
"""
def __init__(self, embedding_dim, padding_idx, left_pad, init_size=1024):
super().__init__()
self.embedding_dim = embedding_dim
self.padding_idx = padding_idx
self.left_pad = left_pad
self.weights = SinusoidalPositionalEmbedding.get_embedding(
init_size,
embedding_dim,
padding_idx,
)
self.register_buffer('_float_tensor', torch.FloatTensor(1))
@staticmethod
def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
"""Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly
from the description in Section 3.5 of "Attention Is All You Need".
"""
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
if embedding_dim % 2 == 1:
# zero pad
emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
if padding_idx is not None:
emb[padding_idx, :] = 0
return emb
def forward(self, input, incremental_state=None):
"""Input is expected to be of size [bsz x seqlen]."""
# recompute/expand embeddings if needed
bsz, seq_len = input.size()
max_pos = self.padding_idx + 1 + seq_len
if self.weights is None or max_pos > self.weights.size(0):
self.weights = SinusoidalPositionalEmbedding.get_embedding(
max_pos,
self.embedding_dim,
self.padding_idx,
)
#self.weights = self.weights.type_as(self._float_tensor)
self.weights = self.weights.to(self._float_tensor, non_blocking=True)
if incremental_state is not None:
# positions is the same for every token when decoding a single step
return self.weights[self.padding_idx + seq_len, :].expand(bsz, 1, -1)
positions = utils.make_positions(input.data, self.padding_idx, self.left_pad)
return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()
def max_positions(self):
"""Maximum number of supported positions."""
return int(1e5) # an arbitrary large number
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import multiprocessing
import os
import pdb
import sys
class MultiprocessingPdb(pdb.Pdb):
"""A Pdb wrapper that works in a multiprocessing environment.
Usage: `from fairseq import pdb; pdb.set_trace()`
"""
_stdin_fd = sys.stdin.fileno()
_stdin = None
_stdin_lock = multiprocessing.Lock()
def __init__(self):
pdb.Pdb.__init__(self, nosigint=True)
def _cmdloop(self):
stdin_bak = sys.stdin
with self._stdin_lock:
try:
if not self._stdin:
self._stdin = os.fdopen(self._stdin_fd)
sys.stdin = self._stdin
self.cmdloop()
finally:
sys.stdin = stdin_bak
pdb = MultiprocessingPdb()
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import importlib
import os
from .fairseq_optimizer import FairseqOptimizer
OPTIMIZER_REGISTRY = {}
OPTIMIZER_CLASS_NAMES = set()
def build_optimizer(args, params):
params = filter(lambda p: p.requires_grad, params)
return OPTIMIZER_REGISTRY[args.optimizer](args, params)
def register_optimizer(name):
"""Decorator to register a new optimizer."""
def register_optimizer_cls(cls):
if name in OPTIMIZER_REGISTRY:
raise ValueError('Cannot register duplicate optimizer ({})'.format(name))
if not issubclass(cls, FairseqOptimizer):
raise ValueError('Optimizer ({}: {}) must extend FairseqOptimizer'.format(name, cls.__name__))
if cls.__name__ in OPTIMIZER_CLASS_NAMES:
# We use the optimizer class name as a unique identifier in
# checkpoints, so all optimizer must have unique class names.
raise ValueError('Cannot register optimizer with duplicate class name ({})'.format(cls.__name__))
OPTIMIZER_REGISTRY[name] = cls
OPTIMIZER_CLASS_NAMES.add(cls.__name__)
return cls
return register_optimizer_cls
# automatically import any Python files in the optim/ directory
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith('.py') and not file.startswith('_'):
module = file[:file.find('.py')]
importlib.import_module('fairseq.optim.' + module)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch.optim
from . import FairseqOptimizer, register_optimizer
@register_optimizer('adagrad')
class Adagrad(FairseqOptimizer):
def __init__(self, args, params):
super().__init__(args, params)
self._optimizer = torch.optim.Adagrad(params, **self.optimizer_config)
@property
def optimizer_config(self):
"""
Return a kwarg dictionary that will be used to override optimizer
args stored in checkpoints. This allows us to load a checkpoint and
resume training using a different set of optimizer args, e.g., with a
different learning rate.
"""
return {
'lr': self.args.lr[0],
'weight_decay': self.args.weight_decay,
}
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import math
import torch
import torch.optim
from . import FairseqOptimizer, register_optimizer
from apex.contrib.optimizers.fused_adam import FusedAdam
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam
from apex.contrib.optimizers.distributed_fused_adam_v2 import DistributedFusedAdamV2
from apex.contrib.optimizers.distributed_fused_adam_v3 import DistributedFusedAdamV3
@register_optimizer('adam')
class FairseqAdam(FairseqOptimizer):
def __init__(self, args, params):
super().__init__(args, params)
if self.args.distributed_weight_update == 2:
dwu_args = self.distributed_weight_update_config
print("DistributedFusedAdam",dwu_args)
self._optimizer = DistributedFusedAdam(params, **dwu_args, **self.optimizer_config)
elif self.args.distributed_weight_update == 3:
dwu_args = self.distributed_weight_update_config
print("DistributedFusedAdamV2",dwu_args)
self._optimizer = DistributedFusedAdamV2(params, **dwu_args, **self.optimizer_config)
elif self.args.distributed_weight_update == 4:
dwu_args = self.distributed_weight_update_config
print("DistributedFusedAdamV3",dwu_args)
self._optimizer = DistributedFusedAdamV3(params, **dwu_args, **self.optimizer_config)
else:
assert (self.args.distributed_weight_update == 0), "Vanilla optimizer not supported anymore"
self._optimizer = FusedAdam(params, **self.optimizer_config)
@staticmethod
def add_args(parser):
"""Add optimizer-specific arguments to the parser."""
parser.add_argument('--adam-betas', default='(0.9, 0.999)', metavar='B',
help='betas for Adam optimizer')
parser.add_argument('--adam-eps', type=float, default=1e-8, metavar='D',
help='epsilon for Adam optimizer')
@property
def optimizer_config(self):
"""
Return a kwarg dictionary that will be used to override optimizer
args stored in checkpoints. This allows us to load a checkpoint and
resume training using a different set of optimizer args, e.g., with a
different learning rate.
"""
return {
'lr': self.args.lr[0],
'betas': eval(self.args.adam_betas),
'eps': self.args.adam_eps,
'weight_decay': self.args.weight_decay,
}
@property
def distributed_weight_update_config(self):
"""
Return a kwarg dictionary that provides arguments for the distributed
weight update feature.
"""
return {
'dwu_group_size': self.args.dwu_group_size,
'dwu_num_blocks': self.args.dwu_num_blocks,
'dwu_num_chunks': self.args.dwu_num_chunks,
'dwu_num_rs_pg': self.args.dwu_num_rs_pg,
'dwu_num_ar_pg': self.args.dwu_num_ar_pg,
'dwu_num_ag_pg': self.args.dwu_num_ag_pg,
'overlap_reductions': self.args.dwu_overlap_reductions,
'full_pipeline': self.args.dwu_full_pipeline,
'compute_L2_grad_norm': self.args.dwu_compute_L2_grad_norm,
'flat_mt': self.args.dwu_flat_mt,
'e5m2_allgather': self.args.dwu_e5m2_allgather,
'do_not_flatten_model': self.args.dwu_do_not_flatten_model,
}
class Adam(torch.optim.Optimizer):
"""Implements Adam algorithm.
This implementation is modified from torch.optim.Adam based on:
`Fixed Weight Decay Regularization in Adam`
(see https://arxiv.org/abs/1711.05101)
It has been proposed in `Adam: A Method for Stochastic Optimization`_.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
weight_decay=0, amsgrad=False):
defaults = dict(lr=lr, betas=betas, eps=eps,
weight_decay=weight_decay, amsgrad=amsgrad)
super(Adam, self).__init__(params, defaults)
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
amsgrad = group['amsgrad']
state = self.state[p]
# State initialization
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p.data)
if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values
state['max_exp_avg_sq'] = torch.zeros_like(p.data)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
if amsgrad:
max_exp_avg_sq = state['max_exp_avg_sq']
beta1, beta2 = group['betas']
state['step'] += 1
# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(1 - beta1, grad)
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
if amsgrad:
# Maintains the maximum of all 2nd moment running avg. till now
torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
# Use the max. for normalizing running avg. of gradient
denom = max_exp_avg_sq.sqrt().add_(group['eps'])
else:
denom = exp_avg_sq.sqrt().add_(group['eps'])
bias_correction1 = 1 - beta1 ** state['step']
bias_correction2 = 1 - beta2 ** state['step']
step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
if group['weight_decay'] != 0:
p.data.add_(-group['weight_decay'] * group['lr'], p.data)
p.data.addcdiv_(-step_size, exp_avg, denom)
return loss
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment