"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "37a5f1b3b69ed284086fb31fb1b49668cba6c365"
Commit d9f46c54 authored by Sergey Edunov

Merge branch 'master' of github.com:facebookresearch/fairseq-py into prepare_wmt

parents 4185d3ed ee36a6f3
@@ -24,8 +24,8 @@ If you use the code in your paper, then please cite it as:
 * Python version 3.6
 * A [PyTorch installation](http://pytorch.org/)
-Currently fairseq-py requires installing PyTorch from source.
-Please follow the instructions here: https://github.com/pytorch/pytorch#from-source.
+Currently fairseq-py requires PyTorch version >= 0.3.0.
+Please follow the instructions here: https://github.com/pytorch/pytorch#installation.
 If you use Docker make sure to increase the shared memory size either with `--ipc=host` or `--shm-size` as command line
 options to `nvidia-docker run`.
...
@@ -126,5 +126,5 @@ void TemporalConvolutionTBC_backward(
   }
   auto tmp = dOutput.sum(0, false);
-  dBias.assign_(tmp.sum(0));
+  dBias.copy_(tmp.sum(0));
 }
@@ -17,7 +17,7 @@ class CrossEntropyCriterion(FairseqCriterion):
     def __init__(self, args, dst_dict):
         super().__init__(args, dst_dict)

-    def forward(self, model, sample):
+    def forward(self, model, sample, reduce=True):
         """Compute the loss for the given sample.

         Returns a tuple with three elements:
@@ -26,12 +26,14 @@ class CrossEntropyCriterion(FairseqCriterion):
         3) logging outputs to display while training
         """
         net_output = model(**sample['net_input'])
-        input = net_output.view(-1, net_output.size(-1))
+        lprobs = model.get_normalized_probs(net_output, log_probs=True)
         target = sample['target'].view(-1)
-        loss = F.cross_entropy(input, target, size_average=False, ignore_index=self.padding_idx)
+        loss = F.nll_loss(lprobs, target, size_average=False, ignore_index=self.padding_idx,
+                          reduce=reduce)
         sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
         logging_output = {
-            'loss': loss.data[0],
+            'loss': loss.data[0] if reduce else loss.data,
+            'ntokens': sample['ntokens'],
             'sample_size': sample_size,
         }
         return loss, sample_size, logging_output
@@ -39,7 +41,12 @@ class CrossEntropyCriterion(FairseqCriterion):
     @staticmethod
     def aggregate_logging_outputs(logging_outputs):
         """Aggregate logging outputs from data parallel training."""
+        loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
+        ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
         sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
-        return {
-            'loss': sum(log.get('loss', 0) for log in logging_outputs) / sample_size / math.log(2),
+        agg_output = {
+            'loss': loss_sum / sample_size / math.log(2),
         }
+        if sample_size != ntokens:
+            agg_output['nll_loss'] = loss_sum / ntokens / math.log(2)
+        return agg_output
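For orientation, the aggregation added above normalizes the summed loss by `sample_size` and, when sentence-level averaging is in effect (`sample_size != ntokens`), also reports a per-token `nll_loss`. A small self-contained illustration of that arithmetic with made-up numbers (not fairseq code):

```python
import math

# Each data-parallel worker returns a logging_output with a summed
# (un-averaged) loss, its token count, and its sample_size.
logging_outputs = [
    {'loss': 120.0, 'ntokens': 50, 'sample_size': 50},
    {'loss': 100.0, 'ntokens': 40, 'sample_size': 40},
]
loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)

# Dividing by math.log(2) converts the natural-log NLL into bits; with
# sample_size == ntokens this is bits per target token.
print(loss_sum / sample_size / math.log(2))  # ~3.53
```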
@@ -16,7 +16,7 @@ class FairseqCriterion(_Loss):
         self.args = args
         self.padding_idx = dst_dict.pad()

-    def forward(self, model, sample):
+    def forward(self, model, sample, reduce=True):
         """Compute the loss for the given sample.

         Returns a tuple with three elements:
...
@@ -11,13 +11,15 @@ import torch
 from torch.autograd.variable import Variable
 import torch.nn.functional as F

+from fairseq import utils
 from .fairseq_criterion import FairseqCriterion


 class LabelSmoothedNLLLoss(torch.autograd.Function):

     @staticmethod
-    def forward(ctx, input, target, eps, padding_idx, weights):
+    def forward(ctx, input, target, eps, padding_idx, weights, reduce=True):
         grad_input = input.new(input.size()).zero_()
         target = target.view(target.size(0), 1)
         grad_input = grad_input.scatter_(grad_input.dim() - 1, target, eps - 1)
@@ -34,11 +36,14 @@ class LabelSmoothedNLLLoss(torch.autograd.Function):
         grad_input = grad_input.add(-eps / norm)
         ctx.grad_input = grad_input
-        return input.new([grad_input.view(-1).dot(input.view(-1))])
+        if reduce:
+            return input.new([grad_input.view(-1).dot(input.view(-1))])
+        else:
+            return grad_input * input

     @staticmethod
     def backward(ctx, grad):
-        return Variable(ctx.grad_input, volatile=True) * grad, None, None, None, None
+        return utils.volatile_variable(ctx.grad_input) * grad, None, None, None, None, None


 class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
@@ -48,7 +53,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
         self.eps = args.label_smoothing
         self.weights = weights

-    def forward(self, model, sample):
+    def forward(self, model, sample, reduce=True):
         """Compute the loss for the given sample.

         Returns a tuple with three elements:
@@ -57,12 +62,15 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
         3) logging outputs to display while training
         """
         net_output = model(**sample['net_input'])
-        input = F.log_softmax(net_output.view(-1, net_output.size(-1)))
+        lprobs = model.get_normalized_probs(net_output, log_probs=True)
         target = sample['target'].view(-1)
-        loss = LabelSmoothedNLLLoss.apply(input, target, self.eps, self.padding_idx, self.weights)
+        loss = LabelSmoothedNLLLoss.apply(lprobs, target, self.eps, self.padding_idx, self.weights, reduce)
+        nll_loss = F.nll_loss(lprobs, target, size_average=False, ignore_index=self.padding_idx, reduce=reduce)
         sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
         logging_output = {
-            'loss': loss.data[0],
+            'loss': loss.data[0] if reduce else loss.data,
+            'nll_loss': nll_loss.data[0] if reduce else loss.data,
+            'ntokens': sample['ntokens'],
             'sample_size': sample_size,
         }
         return loss, sample_size, logging_output
@@ -70,7 +78,9 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
     @staticmethod
     def aggregate_logging_outputs(logging_outputs):
         """Aggregate logging outputs from data parallel training."""
+        ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
         sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
         return {
             'loss': sum(log.get('loss', 0) for log in logging_outputs) / sample_size / math.log(2),
+            'nll_loss': sum(log.get('nll_loss', 0) for log in logging_outputs) / ntokens / math.log(2),
         }
@@ -175,6 +175,23 @@ def skip_group_enumerator(it, ngpus, offset=0):
         yield (idx, res)


+class sharded_iterator(object):
+
+    def __init__(self, itr, num_shards, shard_id):
+        assert shard_id >= 0 and shard_id < num_shards
+        self.itr = itr
+        self.num_shards = num_shards
+        self.shard_id = shard_id
+
+    def __len__(self):
+        return len(self.itr)
+
+    def __iter__(self):
+        for i, v in enumerate(self.itr):
+            if i % self.num_shards == self.shard_id:
+                yield v
+
+
 class LanguagePairDataset(object):

     # padding constants
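To make the round-robin behaviour of the new `sharded_iterator` concrete, here is a self-contained sketch (the class is re-declared locally so the snippet runs on its own; in the tree it sits next to `LanguagePairDataset` as shown above):

```python
class sharded_iterator(object):
    """Worker `shard_id` of `num_shards` sees every num_shards-th element."""

    def __init__(self, itr, num_shards, shard_id):
        assert 0 <= shard_id < num_shards
        self.itr = itr
        self.num_shards = num_shards
        self.shard_id = shard_id

    def __len__(self):
        return len(self.itr)

    def __iter__(self):
        for i, v in enumerate(self.itr):
            if i % self.num_shards == self.shard_id:
                yield v


# Two shards over the same batch stream: no batch is processed twice.
print(list(sharded_iterator(range(10), num_shards=2, shard_id=0)))  # [0, 2, 4, 6, 8]
print(list(sharded_iterator(range(10), num_shards=2, shard_id=1)))  # [1, 3, 5, 7, 9]
```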
@@ -212,13 +229,15 @@ class LanguagePairDataset(object):
         return {
             'id': torch.LongTensor([s['id'].item() for s in samples]),
-            'src_tokens': merge('source', left_pad=LanguagePairDataset.LEFT_PAD_SOURCE),
-            # we create a shifted version of targets for feeding the previous
-            # output token(s) into the next decoder step
-            'input_tokens': merge('target', left_pad=LanguagePairDataset.LEFT_PAD_TARGET,
-                                  move_eos_to_beginning=True),
-            'target': merge('target', left_pad=LanguagePairDataset.LEFT_PAD_TARGET),
             'ntokens': sum(len(s['target']) for s in samples),
+            'net_input': {
+                'src_tokens': merge('source', left_pad=LanguagePairDataset.LEFT_PAD_SOURCE),
+                # we create a shifted version of targets for feeding the
+                # previous output token(s) into the next decoder step
+                'input_tokens': merge('target', left_pad=LanguagePairDataset.LEFT_PAD_TARGET,
+                                      move_eos_to_beginning=True),
+            },
+            'target': merge('target', left_pad=LanguagePairDataset.LEFT_PAD_TARGET),
         }

     @staticmethod
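For readers of the criterion changes above: after this collate change a sample is shaped roughly as follows (a schematic only, tensor values elided), which is why criteria can now call `model(**sample['net_input'])` directly:

```python
sample = {
    'id': ...,        # LongTensor of example ids
    'ntokens': ...,   # total number of target tokens in the batch
    'net_input': {
        'src_tokens': ...,    # padded source tokens
        'input_tokens': ...,  # targets shifted so the previous output
                              # token feeds the next decoder step
    },
    'target': ...,    # padded target tokens
}
```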
@@ -278,9 +297,10 @@ def _make_batches(src, dst, indices, max_tokens, max_sentences, max_positions,
             if ignore_invalid_inputs:
                 ignored.append(idx)
                 continue
-            raise Exception(
-                "Unable to handle input id {} of size {} / {}.".format(
-                    idx, src.sizes[idx], dst.sizes[idx]))
+            raise Exception((
+                "Sample #{} has size (src={}, dst={}) but max size is {}."
+                " Skip this example with --skip-invalid-size-inputs-valid-test"
+            ).format(idx, src.sizes[idx], dst.sizes[idx], max_positions))
         sample_len = max(sample_len, src.sizes[idx], dst.sizes[idx])
         num_tokens = (len(batch) + 1) * sample_len
...
@@ -113,6 +113,8 @@ class Dictionary(object):
         try:
             with open(f, 'r', encoding='utf-8') as fd:
                 return Dictionary.load(fd)
+        except FileNotFoundError as fnfe:
+            raise fnfe
         except:
             raise Exception("Incorrect encoding detected in {}, please "
                             "rebuild the dataset".format(f))
...
@@ -7,13 +7,24 @@
 #

 import torch.nn as nn
+import torch.nn.functional as F


 class FairseqDecoder(nn.Module):
     """Base class for decoders."""

-    def __init__(self):
+    def __init__(self, dictionary):
         super().__init__()
+        self.dictionary = dictionary
+
+    def get_normalized_probs(self, net_output, log_probs):
+        """Get normalized probabilities (or log probs) from a net's output."""
+        vocab = net_output.size(-1)
+        net_output1 = net_output.view(-1, vocab)
+        if log_probs:
+            return F.log_softmax(net_output1, dim=1).view_as(net_output)
+        else:
+            return F.softmax(net_output1, dim=1).view_as(net_output)

     def max_positions(self):
         """Maximum input length supported by the decoder."""
...
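A short standalone sketch of what the new helper computes, mirroring the method body above on a dummy tensor rather than importing the fairseq class (runs as-is on a recent PyTorch):

```python
import torch
import torch.nn.functional as F

# Dummy decoder output: batch of 2, sequence of 3, vocabulary of 5.
net_output = torch.randn(2, 3, 5)

vocab = net_output.size(-1)
flat = net_output.view(-1, vocab)
lprobs = F.log_softmax(flat, dim=1).view_as(net_output)  # log-probabilities
probs = F.softmax(flat, dim=1).view_as(net_output)       # probabilities

# Each position's probabilities sum to 1.
print(probs.sum(dim=-1))
```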
@@ -12,8 +12,9 @@ import torch.nn as nn
 class FairseqEncoder(nn.Module):
     """Base class for encoders."""

-    def __init__(self):
+    def __init__(self, dictionary):
         super().__init__()
+        self.dictionary = dictionary

     def max_positions(self):
         """Maximum input length supported by the encoder."""
...
@@ -12,8 +12,8 @@ from . import FairseqDecoder
 class FairseqIncrementalDecoder(FairseqDecoder):
     """Base class for incremental decoders."""

-    def __init__(self):
-        super().__init__()
+    def __init__(self, dictionary):
+        super().__init__(dictionary)
         self._is_incremental_eval = False
         self._incremental_state = {}
@@ -37,7 +37,7 @@ class FairseqIncrementalDecoder(FairseqDecoder):
         with model.decoder.incremental_inference():
             for step in range(maxlen):
                 out, _ = model.decoder(tokens[:, :step], encoder_out)
-                probs = torch.nn.functional.log_softmax(out[:, -1, :])
+                probs = model.get_normalized_probs(out[:, -1, :], log_probs=False)
         ```
     """

     class IncrementalInference(object):
@@ -86,6 +86,7 @@ class FairseqIncrementalDecoder(FairseqDecoder):
         beam_size is required if using BeamableMM.
         """
         if self._is_incremental_eval:
+            del self._incremental_state
             self._incremental_state = {}

             def apply_clear_incremental_state(module):
@@ -110,7 +111,9 @@ class FairseqIncrementalDecoder(FairseqDecoder):
     def set_beam_size(self, beam_size):
         """Sets the beam size in the decoder and all children."""
-        def apply_set_beam_size(module):
-            if module != self and hasattr(module, 'set_beam_size'):
-                module.set_beam_size(beam_size)
-        self.apply(apply_set_beam_size)
+        if getattr(self, '_beam_size', -1) != beam_size:
+            def apply_set_beam_size(module):
+                if module != self and hasattr(module, 'set_beam_size'):
+                    module.set_beam_size(beam_size)
+            self.apply(apply_set_beam_size)
+            self._beam_size = beam_size
@@ -35,6 +35,10 @@ class FairseqModel(nn.Module):
         decoder_out, _ = self.decoder(input_tokens, encoder_out)
         return decoder_out.view(-1, decoder_out.size(-1))

+    def get_normalized_probs(self, net_output, log_probs):
+        """Get normalized probabilities (or log probs) from a net's output."""
+        return self.decoder.get_normalized_probs(net_output, log_probs)
+
     def max_encoder_positions(self):
         """Maximum input length supported by the encoder."""
         return self.encoder.max_positions()
@@ -62,6 +66,11 @@ class FairseqModel(nn.Module):
                 return
         self.apply(apply_remove_weight_norm)

+        def apply_make_generation_fast_(module):
+            if module != self and hasattr(module, 'make_generation_fast_'):
+                module.make_generation_fast_(**kwargs)
+        self.apply(apply_make_generation_fast_)
+
         def train(mode):
             if mode:
                 raise RuntimeError('cannot train after make_generation_fast')
@@ -69,8 +78,3 @@ class FairseqModel(nn.Module):
         # this model should no longer be used for training
         self.eval()
         self.train = train
-
-        def apply_make_generation_fast_(module):
-            if module != self and hasattr(module, 'make_generation_fast_'):
-                module.make_generation_fast_(**kwargs)
-        self.apply(apply_make_generation_fast_)
@@ -13,26 +13,11 @@ import torch.nn as nn
 import torch.nn.functional as F

 from fairseq.data import LanguagePairDataset
-from fairseq.modules import BeamableMM, GradMultiply, LinearizedConvolution
+from fairseq.modules import BeamableMM, GradMultiply, LearnedPositionalEmbedding, LinearizedConvolution
 from . import FairseqEncoder, FairseqIncrementalDecoder, FairseqModel


-def make_positions(tokens, padding_idx, left_pad, offset=0):
-    seqlen = tokens.size(1)
-    if not hasattr(make_positions, 'range'):
-        make_positions.range = tokens.new()
-    if make_positions.range.numel() < offset + seqlen:
-        # offset positions by the padding index
-        torch.arange(padding_idx + 1, padding_idx + 1 + offset + seqlen,
-                     out=make_positions.range)
-    mask = tokens.ne(padding_idx)
-    positions = make_positions.range[offset:offset+seqlen].expand_as(tokens)
-    if left_pad:
-        positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1)
-    return tokens.clone().masked_scatter_(mask, positions[mask])
-
-
 class FConvModel(FairseqModel):
     def __init__(self, encoder, decoder):
         super().__init__(encoder, decoder)
@@ -43,15 +28,15 @@ class FConvEncoder(FairseqEncoder):
     """Convolutional encoder"""
     def __init__(self, dictionary, embed_dim=512, max_positions=1024,
                  convolutions=((512, 3),) * 20, dropout=0.1):
-        super().__init__()
-        self.dictionary = dictionary
+        super().__init__(dictionary)
         self.dropout = dropout
         self.num_attention_layers = None

         num_embeddings = len(dictionary)
         padding_idx = dictionary.pad()
         self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
-        self.embed_positions = Embedding(max_positions, embed_dim, padding_idx)
+        self.embed_positions = PositionalEmbedding(max_positions, embed_dim, padding_idx,
+                                                   left_pad=LanguagePairDataset.LEFT_PAD_SOURCE)

         in_channels = convolutions[0][0]
         self.fc1 = Linear(embed_dim, in_channels, dropout=dropout)
@@ -68,11 +53,8 @@ class FConvEncoder(FairseqEncoder):
         self.fc2 = Linear(in_channels, embed_dim)

     def forward(self, src_tokens):
-        positions = Variable(make_positions(src_tokens.data, self.dictionary.pad(),
-                                            left_pad=LanguagePairDataset.LEFT_PAD_SOURCE))
-
         # embed tokens and positions
-        x = self.embed_tokens(src_tokens) + self.embed_positions(positions)
+        x = self.embed_tokens(src_tokens) + self.embed_positions(src_tokens)
         x = F.dropout(x, p=self.dropout, training=self.training)
         input_embedding = x
@@ -87,7 +69,7 @@ class FConvEncoder(FairseqEncoder):
             residual = x if proj is None else proj(x)
             x = F.dropout(x, p=self.dropout, training=self.training)
             x = conv(x)
-            x = F.glu(x, dim=-1)
+            x = F.glu(x, dim=2)
             x = (x + residual) * math.sqrt(0.5)

         # T x B x C -> B x T x C
@@ -106,7 +88,7 @@ class FConvEncoder(FairseqEncoder):
     def max_positions(self):
         """Maximum input length supported by the encoder."""
-        return self.embed_positions.num_embeddings - self.dictionary.pad() - 1
+        return self.embed_positions.max_positions()


 class AttentionLayer(nn.Module):
@@ -128,7 +110,7 @@ class AttentionLayer(nn.Module):
         # softmax over last dim
         sz = x.size()
-        x = F.softmax(x.view(sz[0] * sz[1], sz[2]))
+        x = F.softmax(x.view(sz[0] * sz[1], sz[2]), dim=1)
         x = x.view(sz)
         attn_scores = x
@@ -145,28 +127,32 @@ class AttentionLayer(nn.Module):
     def make_generation_fast_(self, beamable_mm_beam_size=None, **kwargs):
         """Replace torch.bmm with BeamableMM."""
         if beamable_mm_beam_size is not None:
-            self.bmm = BeamableMM(beamable_mm_beam_size)
+            del self.bmm
+            self.add_module('bmm', BeamableMM(beamable_mm_beam_size))


 class FConvDecoder(FairseqIncrementalDecoder):
     """Convolutional decoder"""
     def __init__(self, dictionary, embed_dim=512, out_embed_dim=256,
                  max_positions=1024, convolutions=((512, 3),) * 20,
-                 attention=True, dropout=0.1):
-        super().__init__()
+                 attention=True, dropout=0.1, share_embed=False):
+        super().__init__(dictionary)
         self.register_buffer('version', torch.Tensor([2]))
-        self.dictionary = dictionary
         self.dropout = dropout

         in_channels = convolutions[0][0]
         if isinstance(attention, bool):
             # expand True into [True, True, ...] and do the same with False
             attention = [attention] * len(convolutions)
+        if not isinstance(attention, list) or len(attention) != len(convolutions):
+            raise ValueError('Attention is expected to be a list of booleans of '
+                             'length equal to the number of layers.')

         num_embeddings = len(dictionary)
         padding_idx = dictionary.pad()
         self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
-        self.embed_positions = Embedding(max_positions, embed_dim, padding_idx)
+        self.embed_positions = PositionalEmbedding(max_positions, embed_dim, padding_idx,
+                                                   left_pad=LanguagePairDataset.LEFT_PAD_TARGET)

         self.fc1 = Linear(embed_dim, in_channels, dropout=dropout)
         self.projections = nn.ModuleList()
@@ -183,35 +169,28 @@ class FConvDecoder(FairseqIncrementalDecoder):
                                if attention[i] else None)
             in_channels = out_channels
         self.fc2 = Linear(in_channels, out_embed_dim)
-        self.fc3 = Linear(out_embed_dim, num_embeddings, dropout=dropout)
-
-    def forward(self, input_tokens, encoder_out):
-        if self._is_incremental_eval:
-            return self.incremental_forward(input_tokens, encoder_out)
-        else:
-            return self.batch_forward(input_tokens, encoder_out)
-
-    def batch_forward(self, input_tokens, encoder_out):
-        """Forward pass for decoding multiple time steps in batch mode."""
-        positions = Variable(make_positions(input_tokens.data, self.dictionary.pad(),
-                                            left_pad=LanguagePairDataset.LEFT_PAD_TARGET))
-        return self._forward(input_tokens, positions, encoder_out)
-
-    def incremental_forward(self, input_tokens, encoder_out):
-        """Forward pass for one time step."""
-        # positions is the same for every token when decoding a single step
-        positions = Variable(input_tokens.data.new(1, 1).fill_(
-            self.dictionary.pad() + input_tokens.size(1)))
-        # keep only the last token for incremental forward pass
-        return self._forward(input_tokens[:, -1:], positions, encoder_out)
-
-    def _forward(self, input_tokens, positions, encoder_out):
+        if share_embed:
+            assert out_embed_dim == embed_dim, \
+                "Shared embed weights implies same dimensions " \
+                " out_embed_dim={} vs embed_dim={}".format(out_embed_dim, embed_dim)
+            self.fc3 = nn.Linear(out_embed_dim, num_embeddings)
+            self.fc3.weight = self.embed_tokens.weight
+        else:
+            self.fc3 = Linear(out_embed_dim, num_embeddings, dropout=dropout)
+
+    def forward(self, input_tokens, encoder_out):
         # split and transpose encoder outputs
         encoder_a, encoder_b = self._split_encoder_out(encoder_out)

+        # embed positions
+        positions = self.embed_positions(input_tokens)
+
+        if self._is_incremental_eval:
+            # keep only the last token for incremental forward pass
+            input_tokens = input_tokens[:, -1:]
+
         # embed tokens and positions
-        x = self.embed_tokens(input_tokens) + self.embed_positions(positions)
+        x = self.embed_tokens(input_tokens) + positions
         x = F.dropout(x, p=self.dropout, training=self.training)
         target_embedding = x
@@ -230,7 +209,7 @@ class FConvDecoder(FairseqIncrementalDecoder):
             x = F.dropout(x, p=self.dropout, training=self.training)
             x = conv(x)
             x = conv.remove_future_timesteps(x)
-            x = F.glu(x)
+            x = F.glu(x, dim=2)

             # attention
             if attention is not None:
@@ -264,7 +243,7 @@ class FConvDecoder(FairseqIncrementalDecoder):
     def max_positions(self):
         """Maximum output length supported by the decoder."""
-        return self.embed_positions.num_embeddings - self.dictionary.pad() - 1
+        return self.embed_positions.max_positions()

     def upgrade_state_dict(self, state_dict):
         if state_dict.get('decoder.version', torch.Tensor([1]))[0] < 2:
@@ -304,6 +283,12 @@ def Embedding(num_embeddings, embedding_dim, padding_idx):
     return m


+def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad):
+    m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad)
+    m.weight.data.normal_(0, 0.1)
+    return m
+
+
 def Linear(in_features, out_features, dropout=0):
     """Weight-normalized Linear layer (input: N x T x C)"""
     m = nn.Linear(in_features, out_features)
@@ -394,6 +379,7 @@ def parse_arch(args):
     args.decoder_layers = getattr(args, 'decoder_layers', '[(512, 3)] * 20')
     args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 256)
     args.decoder_attention = getattr(args, 'decoder_attention', 'True')
+    args.share_input_output_embed = getattr(args, 'share_input_output_embed', False)
     return args
@@ -413,5 +399,6 @@ def build_model(args, src_dict, dst_dict):
         attention=eval(args.decoder_attention),
         dropout=args.dropout,
         max_positions=args.max_target_positions,
+        share_embed=args.share_input_output_embed
     )
     return FConvModel(encoder, decoder)
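The new `share_embed` option ties the decoder's output projection to its input embedding; a generic weight-tying check with plain `nn` modules (not the fairseq classes) shows why `out_embed_dim` must equal `embed_dim`:

```python
import torch.nn as nn

embed_dim, num_embeddings = 8, 100
embed_tokens = nn.Embedding(num_embeddings, embed_dim, padding_idx=0)
fc3 = nn.Linear(embed_dim, num_embeddings)

# Tie the output projection to the input embedding, as FConvDecoder does
# when share_embed=True; both views of the (num_embeddings x embed_dim)
# matrix are now the same Parameter.
fc3.weight = embed_tokens.weight
print(fc3.weight is embed_tokens.weight)  # True
```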
@@ -23,8 +23,7 @@ class LSTMEncoder(FairseqEncoder):
     """LSTM encoder."""
     def __init__(self, dictionary, embed_dim=512, num_layers=1, dropout_in=0.1,
                  dropout_out=0.1):
-        super().__init__()
-        self.dictionary = dictionary
+        super().__init__(dictionary)
         self.dropout_in = dropout_in
         self.dropout_out = dropout_out
@@ -94,7 +93,7 @@ class AttentionLayer(nn.Module):
         # compute attention
         attn_scores = (source_hids * x.unsqueeze(0)).sum(dim=2)
-        attn_scores = F.softmax(attn_scores.t()).t()  # srclen x bsz
+        attn_scores = F.softmax(attn_scores.t(), dim=1).t()  # srclen x bsz

         # sum weighted sources
         x = (attn_scores.unsqueeze(2) * source_hids).sum(dim=0)
@@ -108,8 +107,7 @@ class LSTMDecoder(FairseqIncrementalDecoder):
     def __init__(self, dictionary, encoder_embed_dim=512, embed_dim=512,
                  out_embed_dim=512, num_layers=1, dropout_in=0.1,
                  dropout_out=0.1, attention=True):
-        super().__init__()
-        self.dictionary = dictionary
+        super().__init__(dictionary)
         self.dropout_in = dropout_in
         self.dropout_out = dropout_out
...
@@ -9,11 +9,13 @@
 from .beamable_mm import BeamableMM
 from .conv_tbc import ConvTBC
 from .grad_multiply import GradMultiply
+from .learned_positional_embedding import LearnedPositionalEmbedding
 from .linearized_convolution import LinearizedConvolution

 __all__ = [
     'BeamableMM',
     'ConvTBC',
     'GradMultiply',
+    'LearnedPositionalEmbedding',
     'LinearizedConvolution',
 ]
@@ -7,9 +7,11 @@
 #

 import torch
-from torch.autograd import Variable, Function
+from torch.autograd import Function
 from torch.nn.modules.utils import _single

+from fairseq import utils
+
 try:
     from fairseq import temporal_convolution_tbc
 except ImportError as e:
@@ -93,9 +95,9 @@ class ConvTBCFunction(Function):
             input,
             weight)

-        grad_input = Variable(grad_input, volatile=True)
-        grad_weight = Variable(grad_weight, volatile=True)
-        grad_bias = Variable(grad_bias, volatile=True)
+        grad_input = utils.volatile_variable(grad_input)
+        grad_weight = utils.volatile_variable(grad_weight)
+        grad_bias = utils.volatile_variable(grad_bias)

         return grad_input, grad_weight, grad_bias, None
...
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
class LearnedPositionalEmbedding(nn.Embedding):
    """This module learns positional embeddings up to a fixed maximum size.

    Padding symbols are ignored, but it is necessary to specify whether padding
    is added on the left side (left_pad=True) or right side (left_pad=False).
    """

    def __init__(self, num_embeddings, embedding_dim, padding_idx, left_pad):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.left_pad = left_pad
        self._is_incremental_eval = False

    def incremental_eval(self, mode=True):
        self._is_incremental_eval = mode

    def forward(self, input):
        """Input is expected to be of size [bsz x seqlen]."""
        if self._is_incremental_eval:
            # positions is the same for every token when decoding a single step
            positions = Variable(
                input.data.new(1, 1).fill_(self.padding_idx + input.size(1)))
        else:
            positions = Variable(self.make_positions(input.data))
        return super().forward(positions)

    def max_positions(self):
        """Maximum number of supported positions."""
        return self.num_embeddings - self.padding_idx - 1

    def make_positions(self, input):
        """Replace non-padding symbols with their position numbers."""
        if not hasattr(self, 'range_buf'):
            self.range_buf = input.new()
        seqlen = input.size(1)
        if self.range_buf.numel() < seqlen:
            # offset positions by the padding index
            torch.arange(self.padding_idx + 1, self.padding_idx + 1 + seqlen,
                         out=self.range_buf)
        mask = input.ne(self.padding_idx)
        positions = self.range_buf[:seqlen].expand_as(input)
        if self.left_pad:
            positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1)
        return input.clone().masked_scatter_(mask, positions[mask])
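To make the position-numbering convention concrete, here is a standalone re-implementation of the masking logic for the right-padded case (left_pad=False), mirroring make_positions above rather than importing it; real tokens get positions padding_idx + 1, padding_idx + 2, ..., while pad positions keep their value:

```python
import torch

padding_idx = 1
# Two right-padded sentences of lengths 3 and 2 (1 is the pad symbol).
tokens = torch.LongTensor([[5, 6, 7, 1],
                           [8, 9, 1, 1]])

mask = tokens.ne(padding_idx)
positions = torch.arange(padding_idx + 1, padding_idx + 1 + tokens.size(1)).long()
positions = positions.unsqueeze(0).expand_as(tokens)

# Non-padding entries are replaced by their position numbers.
print(tokens.clone().masked_scatter_(mask, positions[mask]))
# [[2, 3, 4, 1],
#  [2, 3, 1, 1]]
```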
@@ -8,6 +8,9 @@
 import torch
 import torch.nn.functional as F

+from fairseq import utils
+
 from .conv_tbc import ConvTBC
@@ -65,8 +68,9 @@ class LinearizedConvolution(ConvTBC):
             self.input_buffer[:, :-1, :] = self.input_buffer[:, 1:, :].clone()
             # append next input
             self.input_buffer[:, -1, :] = input[:, -1, :]
-            input = torch.autograd.Variable(self.input_buffer, volatile=True)
-        output = F.linear(input.view(bsz, -1), weight, self.bias)
+            input = utils.volatile_variable(self.input_buffer)
+        with utils.maybe_no_grad():
+            output = F.linear(input.view(bsz, -1), weight, self.bias)
         return output.view(bsz, 1, -1)

     def clear_incremental_state(self):
...
@@ -15,9 +15,10 @@ import math
 import torch
 from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau

-from fairseq import meters, nccl, utils
+from fairseq import nccl, utils
 from fairseq.multiprocessing_event_loop import MultiprocessingEventLoop, Future
-from fairseq.nag import NAG
+from fairseq.optim.nag import NAG
+from fairseq.optim.adam import Adam


 class MultiprocessingTrainer(MultiprocessingEventLoop):
@@ -95,7 +96,7 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
                 'betas': eval(self.args.adam_betas),
                 'weight_decay': self.args.weight_decay,
             }
-            return torch.optim.Adam(self.model.parameters(), **self._override_optim_state)
+            return Adam(self.model.parameters(), **self._override_optim_state)
         elif self.args.optimizer == 'nag':
             self._override_optim_state = {
                 'lr': self.args.lr[0],
@@ -116,6 +117,7 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
     def _build_lr_scheduler(self):
         if len(self.args.lr) > 1 or self.args.force_anneal > 0:
             lrs = self.args.lr
+
             def anneal(e):
                 if e < self.args.force_anneal:
                     # use fixed LR schedule
@@ -123,6 +125,7 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
                 else:
                     next_lr = lrs[-1] * self.args.lrshrink ** (e + 1 - self.args.force_anneal)
                 return next_lr / lrs[0]  # correct for scaling from LambdaLR
+
             lr_scheduler = LambdaLR(self.optimizer, anneal)
             lr_scheduler.best = None
         else:
@@ -225,20 +228,21 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         self.model.train()
         self.optimizer.zero_grad()

-        sample_size, logging_output, oom = 0, {}, False
-        if self._sample is not None:
-            try:
-                # calculate loss and sample size
-                self.loss, sample_size, logging_output = self.criterion(self.model, self._sample)
-            except RuntimeError as e:
-                if not eval and 'out of memory' in str(e):
-                    print('| WARNING: ran out of memory on GPU #{}, skipping batch'.format(device_id))
-                    oom = True
-                    self.loss = None
-                    if hasattr(torch.cuda, 'empty_cache'):
-                        torch.cuda.empty_cache()
-                else:
-                    raise e
+        with utils.maybe_no_grad(eval):
+            sample_size, logging_output, oom = 0, {}, False
+            if self._sample is not None:
+                try:
+                    # calculate loss and sample size
+                    self.loss, sample_size, logging_output = self.criterion(self.model, self._sample)
+                except RuntimeError as e:
+                    if not eval and 'out of memory' in str(e):
+                        print('| WARNING: ran out of memory on GPU #{}, skipping batch'.format(device_id))
+                        oom = True
+                        self.loss = None
+                        if hasattr(torch.cuda, 'empty_cache'):
+                            torch.cuda.empty_cache()
+                    else:
+                        raise e

         return sample_size, logging_output, oom
@@ -262,7 +266,10 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         self._all_reduce_and_rescale_grads(grad_denom)

         # clip grads
-        grad_norm = torch.nn.utils.clip_grad_norm(self.model.parameters(), self.args.clip_norm)
+        if self.args.clip_norm > 0:
+            grad_norm = torch.nn.utils.clip_grad_norm(self.model.parameters(), self.args.clip_norm)
+        else:
+            grad_norm = math.sqrt(sum([p.grad.data.norm()**2 for p in self.model.parameters()]))

         # take an optimization step
         self.optimizer.step()
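The new else-branch only reports the gradient norm when clipping is disabled; it computes the same total L2 norm that clip_grad_norm would return, just without rescaling. A tiny numeric check with made-up gradients (not fairseq code):

```python
import math
import torch

grads = [torch.FloatTensor([3.0, 4.0]), torch.FloatTensor([12.0])]
# Total norm = sqrt of the sum of squared per-parameter norms.
total_norm = math.sqrt(sum(g.norm() ** 2 for g in grads))
print(total_norm)  # 13.0
```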
@@ -378,4 +385,4 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
                 self._max_bsz_seen = sample['target'].size(0)
                 torch.cuda.empty_cache()

-        self._sample = utils.prepare_sample(sample, volatile=volatile, cuda_device=device_id)
+        self._sample = utils.make_variable(sample, volatile=volatile, cuda_device=device_id)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import math
import torch
from torch.optim.optimizer import Optimizer
class Adam(Optimizer):
    """Implements Adam algorithm.

    It has been proposed in `Adam: A Method for Stochastic Optimization`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, amsgrad=False):
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad)
        super(Adam, self).__init__(params, defaults)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
                amsgrad = group['amsgrad']

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                if amsgrad:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(1 - beta1, grad)
                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
                if amsgrad:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])
                else:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])

                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

                if group['weight_decay'] != 0:
                    p.data.add_(-group['weight_decay'], p.data)

                p.data.addcdiv_(-step_size, exp_avg, denom)

        return loss
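A minimal usage sketch of the vendored optimizer, using the import path the trainer above switches to; this runs as-is on a recent PyTorch (under the 0.3-era API the inputs would be wrapped in Variable first):

```python
import torch
import torch.nn as nn
from fairseq.optim.adam import Adam

model = nn.Linear(10, 2)
optimizer = Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999),
                 weight_decay=0.0, amsgrad=False)

x = torch.randn(4, 10)
loss = model(x).sum()

optimizer.zero_grad()
loss.backward()
optimizer.step()
```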
@@ -11,7 +11,7 @@ from torch.optim.optimizer import Optimizer, required

 class NAG(Optimizer):
     def __init__(self, params, lr=required, momentum=0, weight_decay=0):
-        defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay)
+        defaults = dict(lr=lr, lr_old=lr, momentum=momentum, weight_decay=weight_decay)
         super(NAG, self).__init__(params, defaults)

     def step(self, closure=None):
@@ -29,6 +29,8 @@ class NAG(Optimizer):
             weight_decay = group['weight_decay']
             momentum = group['momentum']
             lr = group['lr']
+            lr_old = group.get('lr_old', lr)
+            lr_correct = lr / lr_old

             for p in group['params']:
                 if p.grad is None:
@@ -43,9 +45,11 @@ class NAG(Optimizer):
                 if weight_decay != 0:
                     p.data.mul_(1 - weight_decay)

-                p.data.add_(momentum * momentum, buf)
+                p.data.add_(momentum * momentum * lr_correct, buf)
                 p.data.add_(-(1 + momentum) * lr, d_p)

-                buf.mul_(momentum).add_(-lr, d_p)
+                buf.mul_(momentum * lr_correct).add_(-lr, d_p)
+
+            group['lr_old'] = lr

         return loss