"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "6620eda357132bcd034c8b5c239fa4527e150c35"
Commit 8eb232ce authored by Myle Ott, committed by Facebook Github Bot

Merge internal changes

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/352

Differential Revision: D12956930

Pulled By: myleott

fbshipit-source-id: 39334a79544bac570feb04be9103269d7c1563f9
parent 2b13f3c0
eval_lm.py

@@ -5,6 +5,7 @@
 # This source code is licensed under the license found in the LICENSE file in
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.
+
 """
 Evaluate the perplexity of a trained language model.
 """
@@ -12,7 +13,7 @@ Evaluate the perplexity of a trained language model.
 import numpy as np
 import torch

-from fairseq import data, options, progress_bar, tasks, utils
+from fairseq import options, progress_bar, tasks, utils
 from fairseq.meters import StopwatchMeter, TimeMeter
 from fairseq.sequence_scorer import SequenceScorer
@@ -22,14 +23,25 @@ class WordStat(object):
         self.word = word
         self.is_bpe = is_bpe
         self.log_prob = 0
+        self.next_word_prob = 0
         self.count = 0
+        self.missing_next_words = 0

-    def add(self, log_prob):
+    def add(self, log_prob, next_word_prob):
+        """Increments counters for the sum of log probs of the current word and of the next
+        word (given context ending at the current word). Since the next word might be at the
+        end of the example, or it might not be counted because it is not an ending subword
+        unit, also keeps track of how many of those we have seen."""
+        if next_word_prob is not None:
+            self.next_word_prob += next_word_prob
+        else:
+            self.missing_next_words += 1
         self.log_prob += log_prob
         self.count += 1

     def __str__(self):
-        return '{}\t{}\t{}\t{}'.format(self.word, self.count, self.log_prob / self.count, self.is_bpe)
+        return '{}\t{}\t{}\t{}\t{}\t{}'.format(self.word, self.count, self.log_prob, self.is_bpe,
+                                               self.next_word_prob, self.count - self.missing_next_words)


 def main(parsed_args):
@@ -62,6 +74,8 @@ def main(parsed_args):
     assert len(models) > 0

+    print('num. model params: {}'.format(sum(p.numel() for p in models[0].parameters())))
+
     itr = task.get_batch_iterator(
         dataset=task.dataset(args.gen_subset),
         max_tokens=args.max_tokens or 36000,
@@ -112,7 +126,7 @@ def main(parsed_args):
                 print('| Skipping tokens with inf scores:',
                       task.target_dictionary.string(hypo['tokens'][inf_scores.nonzero()]))
                 pos_scores = pos_scores[(~inf_scores).nonzero()]
-            score_sum += utils.item(pos_scores.sum())
+            score_sum += pos_scores.sum().cpu()
             count += pos_scores.numel() - skipped_toks

             if args.output_word_probs or args.output_word_stats:
@@ -127,7 +141,16 @@ def main(parsed_args):
                         is_bpe = True
                     else:
                         word_prob.append((w, pos_scores[i].item()))
-                        word_stats.setdefault(w, WordStat(w, is_bpe)).add(pos_scores[i].item())
+
+                        next_prob = None
+                        ind = i + 1
+                        while ind < len(hypo['tokens']):
+                            if pos_scores[ind].item() != 0:
+                                next_prob = pos_scores[ind]
+                                break
+                            ind += 1
+
+                        word_stats.setdefault(w, WordStat(w, is_bpe)).add(pos_scores[i].item(), next_prob)
                         is_bpe = False
                         w = ''
             if args.output_word_probs:
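To make the new WordStat output concrete, here is a minimal standalone sketch (the log-prob values are made up for illustration, and it assumes fairseq is installed so that WordStat can be imported from eval_lm):

    from eval_lm import WordStat

    word_stats = {}
    for word, log_prob, next_prob in [('the', -1.2, -3.4), ('the', -0.8, None)]:
        word_stats.setdefault(word, WordStat(word, is_bpe=False)).add(log_prob, next_prob)
    for ws in word_stats.values():
        # word, count, summed log prob, is_bpe, next-word prob sum, count of observed next words
        print(ws)

Note that __str__ now reports the summed log probability plus the two next-word counters, rather than the per-occurrence average it printed before.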
fairseq/__init__.py

@@ -8,6 +8,7 @@
 from .multiprocessing_pdb import pdb

 __all__ = ['pdb']
+__version__ = '0.6.0'

 import fairseq.criterions
 import fairseq.models
fairseq/criterions/composite_loss.py (new file)

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

from torch import nn

from fairseq import utils

from . import FairseqCriterion, register_criterion


@register_criterion('composite_loss')
class CompositeLoss(FairseqCriterion):
    """This is a composite loss that, given a list of model outputs and a list of targets,
    computes an average of losses for each output-target pair"""

    @staticmethod
    def add_args(parser):
        """Add criterion-specific arguments to the parser."""
        parser.add_argument('--underlying-criterion', type=str, metavar='VAL', required=True,
                            help='underlying criterion to use for the composite loss')

    def __init__(self, args, task):
        super().__init__(args, task)
        saved_criterion = args.criterion
        args.criterion = args.underlying_criterion
        assert saved_criterion != args.underlying_criterion
        self.underlying_criterion = task.build_criterion(args)
        args.criterion = saved_criterion

    class FakeModel(nn.Module):
        def __init__(self, model, net_out, target):
            super(CompositeLoss.FakeModel, self).__init__()
            self.model = model
            self.net_out = net_out
            self.target = target

        def forward(self, **unused):
            return self.net_out

        def get_targets(self, *unused):
            return self.target

        @property
        def decoder(self):
            return self.model.decoder

    def forward(self, model, sample, reduce=True):
        net_outputs = model(**sample['net_input'])
        targets = sample['target']

        bsz = targets[0].size(0)
        loss = net_outputs[0][0].new(1 if reduce else bsz).zero_()

        sample_size = 0
        logging_output = {}
        for o, t in zip(net_outputs[0], targets):
            m = CompositeLoss.FakeModel(model, (o, net_outputs[1]), t)
            l, ss, logging_output = self.underlying_criterion(m, sample, reduce)
            loss += l
            sample_size += ss

        loss.div_(len(targets))
        sample_size /= len(targets)

        logging_output['loss'] = utils.item(loss.data) if reduce else loss.data
        return loss, sample_size, logging_output

    def _aggregate_logging_outputs(self, logging_outputs):
        return self.underlying_criterion._aggregate_logging_outputs(logging_outputs)
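A sketch of the FakeModel shim in isolation (the tensors are hypothetical, and the import assumes the module path fairseq/criterions/composite_loss.py used above): it replays one stored (output, attention) pair and its target, so a criterion written for single-output models can score one member of the output list at a time.

    import torch
    from torch import nn
    from fairseq.criterions.composite_loss import CompositeLoss

    class ToyModel(nn.Module):
        decoder = None  # FakeModel proxies .decoder through to this

    net_out = (torch.randn(2, 5), None)   # one (output, attention) pair
    target = torch.zeros(2, dtype=torch.long)
    m = CompositeLoss.FakeModel(ToyModel(), net_out, target)
    assert m(any_unused_kwarg=0) is net_out   # forward ignores its inputs
    assert m.get_targets() is target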
fairseq/criterions/fairseq_criterion.py

@@ -35,6 +35,15 @@ class FairseqCriterion(_Loss):
         """Aggregate logging outputs from data parallel training."""
         raise NotImplementedError

+    def _aggregate_logging_outputs(self, logging_outputs):
+        """An instance-method version of :func:`aggregate_logging_outputs`.
+
+        This can be overridden if needed, but please be careful not to rely
+        on shared state when aggregating logging outputs, otherwise you may
+        get incorrect results.
+        """
+        return self.__class__.aggregate_logging_outputs(logging_outputs)
+
     @staticmethod
     def grad_denom(sample_sizes):
         """Compute the gradient denominator for a set of sample sizes."""
fairseq/data/concat_dataset.py

 import bisect

+import numpy as np
+
 from . import FairseqDataset


 class ConcatDataset(FairseqDataset):

     @staticmethod
-    def cumsum(sequence):
+    def cumsum(sequence, sample_ratios):
         r, s = [], 0
-        for e in sequence:
-            l = len(e)
+        for e, ratio in zip(sequence, sample_ratios):
+            l = ratio * len(e)
             r.append(l + s)
             s += l
         return r

-    def __init__(self, datasets):
+    def __init__(self, datasets, sample_ratios=1):
         super(ConcatDataset, self).__init__()
         assert len(datasets) > 0, 'datasets should not be an empty iterable'
         self.datasets = list(datasets)
-        self.cummulative_sizes = self.cumsum(self.datasets)
+        if isinstance(sample_ratios, int):
+            sample_ratios = [sample_ratios] * len(self.datasets)
+        self.sample_ratios = sample_ratios
+        self.cummulative_sizes = self.cumsum(self.datasets, sample_ratios)
+        self.real_sizes = [len(d) for d in self.datasets]

     def __len__(self):
         return self.cummulative_sizes[-1]

@@ -29,8 +34,13 @@ class ConcatDataset(FairseqDataset):
             sample_idx = idx
         else:
             sample_idx = idx - self.cummulative_sizes[dataset_idx - 1]
+        sample_idx = sample_idx % self.real_sizes[dataset_idx]
         return self.datasets[dataset_idx][sample_idx]

+    @property
+    def sizes(self):
+        return np.concatenate([np.tile(ds.sizes, sr) for ds, sr in zip(self.datasets, self.sample_ratios)])
+
     @property
     def supports_prefetch(self):
         return all([d.supports_prefetch for d in self.datasets])

@@ -38,5 +48,6 @@ class ConcatDataset(FairseqDataset):
     def prefetch(self, indices):
         frm = 0
         for to, ds in zip(self.cummulative_sizes, self.datasets):
-            ds.prefetch([i - frm for i in indices if frm <= i < to])
+            real_size = len(ds)
+            ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to])
             frm = to
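A sketch of the upsampling arithmetic, with plain Python lists standing in for datasets (the sizes are hypothetical): lengths 3 and 5 with sample_ratios=[2, 1] give cumsum [6, 11], so the concatenated length is 11 and indices into the upsampled first dataset wrap around via sample_idx % real_sizes[dataset_idx].

    from fairseq.data.concat_dataset import ConcatDataset

    ds = ConcatDataset([list('abc'), list('vwxyz')], sample_ratios=[2, 1])
    assert len(ds) == 11
    assert [ds[i] for i in range(6)] == ['a', 'b', 'c', 'a', 'b', 'c']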
fairseq/data/indexed_dataset.py

@@ -120,7 +120,7 @@ class IndexedCachedDataset(IndexedDataset):
     def prefetch(self, indices):
         if all(i in self.cache_index for i in indices):
             return
-        indices.sort()
+        indices = sorted(set(indices))
         total_size = 0
         for i in indices:
             total_size += self.data_offsets[i + 1] - self.data_offsets[i]
fairseq/models/fairseq_model.py

@@ -245,3 +245,7 @@ class FairseqLanguageModel(BaseFairseqModel):
     @property
     def supported_targets(self):
         return {'future'}
+
+    def remove_head(self):
+        """Removes the head of the model (e.g. the softmax layer) to conserve space when it is not needed"""
+        raise NotImplementedError()
fairseq/models/transformer.py

@@ -15,7 +15,7 @@
 from fairseq import options
 from fairseq import utils
 from fairseq.modules import (
-    AdaptiveSoftmax, CharacterTokenEmbedder, LearnedPositionalEmbedding, MultiheadAttention,
+    AdaptiveInput, AdaptiveSoftmax, CharacterTokenEmbedder, LearnedPositionalEmbedding, MultiheadAttention,
     SinusoidalPositionalEmbedding
 )
@@ -178,6 +178,8 @@ class TransformerLanguageModel(FairseqLanguageModel):
                                 'Must be used with adaptive_loss criterion')
         parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
                             help='sets adaptive softmax dropout for the tail projections')
+        parser.add_argument('--adaptive-softmax-factor', type=float, metavar='N',
+                            help='adaptive softmax factor')
         parser.add_argument('--no-token-positional-embeddings', default=False, action='store_true',
                             help='if set, disables positional embeddings (outside self attention)')
         parser.add_argument('--share-decoder-input-output-embed', default=False, action='store_true',
@@ -191,6 +193,18 @@ class TransformerLanguageModel(FairseqLanguageModel):
                             help='size of character embeddings')
         parser.add_argument('--char-embedder-highway-layers', type=int, metavar='N', default=2,
                             help='number of highway layers for character token embedder')
+        parser.add_argument('--adaptive-input', default=False, action='store_true',
+                            help='if set, uses adaptive input')
+        parser.add_argument('--adaptive-input-factor', type=float, metavar='N',
+                            help='adaptive input factor')
+        parser.add_argument('--adaptive-input-cutoff', metavar='EXPR',
+                            help='comma separated list of adaptive input cutoff points.')
+        parser.add_argument('--tie-adaptive-weights', action='store_true',
+                            help='if set, ties the weights of adaptive softmax and adaptive input')
+        parser.add_argument('--tie-adaptive-proj', action='store_true',
+                            help='if set, ties the projection weights of adaptive softmax and adaptive input')
+        parser.add_argument('--decoder-learned-pos', action='store_true',
+                            help='use learned positional embeddings in the decoder')

     @classmethod
     def build_model(cls, args, task):
@@ -199,6 +213,10 @@ class TransformerLanguageModel(FairseqLanguageModel):
         # make sure all arguments are present in older models
         base_lm_architecture(args)

+        if hasattr(args, 'no_tie_adaptive_proj') and args.no_tie_adaptive_proj == False:
+            # backward compatibility
+            args.tie_adaptive_proj = True
+
         if not hasattr(args, 'max_source_positions'):
             args.max_source_positions = args.tokens_per_sample
         if not hasattr(args, 'max_target_positions'):
@@ -210,9 +228,20 @@ class TransformerLanguageModel(FairseqLanguageModel):
                 args.decoder_embed_dim,
                 args.char_embedder_highway_layers,
             )
+        elif args.adaptive_input:
+            embed_tokens = AdaptiveInput(len(task.dictionary), task.dictionary.pad(), args.decoder_input_dim,
+                                         args.adaptive_input_factor, args.decoder_embed_dim,
+                                         options.eval_str_list(args.adaptive_input_cutoff, type=int))
         else:
             embed_tokens = Embedding(len(task.dictionary), args.decoder_input_dim, task.dictionary.pad())

+        if args.tie_adaptive_weights:
+            assert args.adaptive_input
+            assert args.adaptive_input_factor == args.adaptive_softmax_factor
+            assert args.adaptive_softmax_cutoff == args.adaptive_input_cutoff, '{} != {}'.format(
+                args.adaptive_softmax_cutoff, args.adaptive_input_cutoff)
+            assert args.decoder_input_dim == args.decoder_output_dim
+
         decoder = TransformerDecoder(args, task.output_dictionary, embed_tokens, no_encoder_attn=True, final_norm=False)
         return TransformerLanguageModel(decoder)
@@ -254,7 +283,7 @@ class TransformerEncoder(FairseqEncoder):
         self.register_buffer('version', torch.Tensor([2]))
         self.normalize = args.encoder_normalize_before
         if self.normalize:
             self.layer_norm = LayerNorm(embed_dim)

     def forward(self, src_tokens, src_lengths):
         """
@@ -366,10 +395,9 @@ class TransformerDecoder(FairseqIncrementalDecoder):
         self.max_target_positions = args.max_target_positions

         self.embed_tokens = embed_tokens
         self.embed_scale = math.sqrt(embed_dim)  # todo: try with input_embed_dim

-        self.project_in_dim = Linear(input_embed_dim, embed_dim, bias=False,
-                                     uniform=False) if embed_dim != input_embed_dim else None
+        self.project_in_dim = Linear(input_embed_dim, embed_dim, bias=False) if embed_dim != input_embed_dim else None

         self.embed_positions = PositionalEmbedding(
             args.max_target_positions, embed_dim, padding_idx,
@@ -385,14 +413,18 @@ class TransformerDecoder(FairseqIncrementalDecoder):
         self.adaptive_softmax = None

-        self.project_out_dim = Linear(embed_dim, output_embed_dim,
-                                      bias=False, uniform=False) if embed_dim != output_embed_dim else None
+        self.project_out_dim = Linear(embed_dim, output_embed_dim, bias=False) \
+            if embed_dim != output_embed_dim and not args.tie_adaptive_weights else None

         if args.adaptive_softmax_cutoff is not None:
             self.adaptive_softmax = AdaptiveSoftmax(
-                len(dictionary), output_embed_dim,
+                len(dictionary),
+                output_embed_dim,
                 options.eval_str_list(args.adaptive_softmax_cutoff, type=int),
                 dropout=args.adaptive_softmax_dropout,
+                adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None,
+                factor=args.adaptive_softmax_factor,
+                tie_proj=args.tie_adaptive_proj,
             )
         elif not self.share_input_output_embed:
             self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), output_embed_dim))
@@ -400,7 +432,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
         self.register_buffer('version', torch.Tensor([2]))
         self.normalize = args.decoder_normalize_before and final_norm
         if self.normalize:
             self.layer_norm = LayerNorm(embed_dim)

     def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
         """
@@ -516,7 +548,6 @@ class TransformerDecoder(FairseqIncrementalDecoder):
             self.normalize = False
             state_dict['{}.version'.format(name)] = torch.Tensor([1])
-
         return state_dict
@@ -728,12 +759,9 @@ def LayerNorm(embedding_dim):
     return m


-def Linear(in_features, out_features, bias=True, uniform=True):
+def Linear(in_features, out_features, bias=True):
     m = nn.Linear(in_features, out_features, bias)
-    if uniform:
-        nn.init.xavier_uniform_(m.weight)
-    else:
-        nn.init.xavier_normal_(m.weight)
+    nn.init.xavier_uniform_(m.weight)
     if bias:
         nn.init.constant_(m.bias, 0.)
     return m
@@ -757,6 +785,7 @@ def base_lm_architecture(args):
     args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 8)
     args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None)
     args.adaptive_softmax_dropout = getattr(args, 'adaptive_softmax_dropout', 0)
+    args.adaptive_softmax_factor = getattr(args, 'adaptive_softmax_factor', 4)
     args.decoder_learned_pos = getattr(args, 'decoder_learned_pos', False)
     args.character_embeddings = getattr(args, 'character_embeddings', False)
@@ -767,6 +796,14 @@ def base_lm_architecture(args):
     # The model training is not stable without this
     args.decoder_normalize_before = True

+    args.adaptive_input = getattr(args, 'adaptive_input', False)
+    args.adaptive_input_factor = getattr(args, 'adaptive_input_factor', 4)
+    args.adaptive_input_cutoff = getattr(args, 'adaptive_input_cutoff', None)
+
+    args.tie_adaptive_weights = getattr(args, 'tie_adaptive_weights', False)
+    args.tie_adaptive_proj = getattr(args, 'tie_adaptive_proj', False)
+

 @register_model_architecture('transformer_lm', 'transformer_lm_big')
 def transformer_lm_big(args):
     args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 1024)
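Restating the factor arithmetic that the tie_adaptive_weights assertions above keep aligned (the settings are hypothetical): with decoder_input_dim 1024, factor 4 and two cutoff points, both the adaptive input and the adaptive softmax tail work with per-band dimensions of 1024, 256 and 64, so each band's weights can be shared one for one.

    initial_dim, factor, n_bands = 1024, 4, 3  # 2 cutoff points => 3 vocabulary bands
    print([int(initial_dim // factor ** i) for i in range(n_bands)])  # [1024, 256, 64]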
fairseq/modules/__init__.py

@@ -5,6 +5,7 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.

+from .adaptive_input import AdaptiveInput
 from .adaptive_softmax import AdaptiveSoftmax
 from .beamable_mm import BeamableMM
 from .character_token_embedder import CharacterTokenEmbedder
@@ -19,6 +20,7 @@
 from .scalar_bias import ScalarBias
 from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding

 __all__ = [
+    'AdaptiveInput',
     'AdaptiveSoftmax',
     'BeamableMM',
     'CharacterTokenEmbedder',
fairseq/modules/adaptive_input.py (new file)

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import torch
import torch.nn.functional as F
from torch import nn
from typing import List


class AdaptiveInput(nn.Module):

    def __init__(
        self,
        vocab_size: int,
        padding_idx: int,
        initial_dim: int,
        factor: float,
        output_dim: int,
        cutoff: List[int],
    ):
        super().__init__()

        if vocab_size > cutoff[-1]:
            cutoff = cutoff + [vocab_size]
        else:
            assert vocab_size == cutoff[-1], 'cannot specify cutoff larger than vocab size'

        self.cutoff = cutoff
        self.embedding_dim = output_dim
        self.padding_idx = padding_idx

        self.embeddings = nn.ModuleList()
        for i in range(len(self.cutoff)):
            prev = self.cutoff[i - 1] if i > 0 else 0
            size = self.cutoff[i] - prev
            dim = int(initial_dim // (factor ** i))
            seq = nn.Sequential(
                nn.Embedding(size, dim, padding_idx),
                nn.Linear(dim, output_dim, bias=False)
            )
            self.embeddings.append(seq)

        def init_weights(m):
            if isinstance(m, nn.Embedding):
                nn.init.normal_(m.weight, mean=0, std=m.weight.shape[1] ** -0.5)
                nn.init.constant_(m.weight[padding_idx], 0)
            elif hasattr(m, 'weight'):
                nn.init.xavier_uniform_(m.weight)

        self.apply(init_weights)

        self.register_buffer('_float_tensor', torch.FloatTensor(1))

    def weights_for_band(self, band: int):
        return self.embeddings[band][0].weight, self.embeddings[band][1].weight

    def forward(self, input: torch.Tensor):
        result = self._float_tensor.new(input.shape + (self.embedding_dim,))
        for i in range(len(self.cutoff)):
            mask = input.lt(self.cutoff[i])
            if i > 0:
                mask.mul_(input.ge(self.cutoff[i - 1]))
                chunk_input = input[mask] - self.cutoff[i - 1]
            else:
                chunk_input = input[mask]
            if mask.any():
                result[mask] = self.embeddings[i](chunk_input)
        return result
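A standalone usage sketch (the sizes are hypothetical): a 100k-word vocabulary split into three bands, embedded at 512, 128 and 32 dimensions respectively, and projected back to a common 512-dim output.

    import torch
    from fairseq.modules import AdaptiveInput

    emb = AdaptiveInput(vocab_size=100000, padding_idx=0, initial_dim=512,
                        factor=4., output_dim=512, cutoff=[20000, 60000])
    tokens = torch.tensor([[5, 25000, 99999]])  # one token from each band
    print(emb(tokens).shape)  # torch.Size([1, 3, 512])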
fairseq/modules/adaptive_softmax.py

@@ -5,12 +5,50 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.

+import operator
+import functools
+
 import torch
 import torch.nn.functional as F

 from torch import nn


+class TiedLinear(nn.Module):
+    def __init__(self, weight, transpose):
+        super().__init__()
+        self.weight = weight
+        self.transpose = transpose
+
+    def forward(self, input):
+        return F.linear(input, self.weight.t() if self.transpose else self.weight)
+
+
+class TiedHeadModule(nn.Module):
+    def __init__(self, weights, input_dim, num_classes):
+        super().__init__()
+        tied_emb, _ = weights
+        self.num_words, emb_dim = tied_emb.size()
+
+        self.word_proj = TiedLinear(tied_emb, transpose=False)
+        if input_dim != emb_dim:
+            self.word_proj = nn.Sequential(
+                nn.Linear(input_dim, emb_dim, bias=False),
+                self.word_proj,
+            )
+
+        self.class_proj = nn.Linear(input_dim, num_classes, bias=False)
+        self.out_dim = self.num_words + num_classes
+
+        self.register_buffer('_float_tensor', torch.FloatTensor(1))
+
+    def forward(self, input):
+        inp_sz = functools.reduce(operator.mul, input.shape[:-1], 1)
+        out = self._float_tensor.new(inp_sz, self.out_dim)
+        out[:, :self.num_words] = self.word_proj(input.view(inp_sz, -1))
+        out[:, self.num_words:] = self.class_proj(input.view(inp_sz, -1))
+        return out
+
+
 class AdaptiveSoftmax(nn.Module):
     """
     This is an implementation of the efficient softmax approximation for
@@ -18,7 +56,7 @@ class AdaptiveSoftmax(nn.Module):
     approximation for GPUs" (http://arxiv.org/abs/1609.04309).
     """

-    def __init__(self, vocab_size, input_dim, cutoff, dropout):
+    def __init__(self, vocab_size, input_dim, cutoff, dropout, factor=4., adaptive_inputs=None, tie_proj=False):
         super().__init__()

         if vocab_size > cutoff[-1]:
@@ -33,13 +71,19 @@ class AdaptiveSoftmax(nn.Module):
         self.cutoff = cutoff
         self.dropout = dropout
         self.input_dim = input_dim
+        self.factor = factor

         self.lsm = nn.LogSoftmax(dim=1)
-        self.head = nn.Linear(input_dim, output_dim, bias=False)
-        self._make_tail(True)
+
+        if adaptive_inputs is not None:
+            self.head = TiedHeadModule(adaptive_inputs.weights_for_band(0), input_dim, len(cutoff) - 1)
+        else:
+            self.head = nn.Linear(input_dim, output_dim, bias=False)
+
+        self._make_tail(True, adaptive_inputs, tie_proj)

         def init_weights(m):
-            if hasattr(m, 'weight'):
+            if hasattr(m, 'weight') and not isinstance(m, TiedLinear) and not isinstance(m, TiedHeadModule):
                 nn.init.xavier_uniform_(m.weight)

         self.apply(init_weights)
@@ -48,19 +92,33 @@ class AdaptiveSoftmax(nn.Module):
         # versions prior to 1 had a bug that offset indices on the head by 1
         self.buggy_offset = 0

-    def _make_tail(self, fix_exponent):
+    def _make_tail(self, fix_exponent, adaptive_inputs=None, tie_proj=False):
         extra_denom = 1 if fix_exponent else 0

         self.tail = nn.ModuleList()
         for i in range(len(self.cutoff) - 1):
-            self.tail.append(
-                nn.Sequential(
-                    nn.Linear(self.input_dim, self.input_dim // 4 ** (i + extra_denom), bias=False),
-                    nn.Dropout(self.dropout),
-                    nn.Linear(self.input_dim // 4 ** (i + extra_denom), self.cutoff[i + 1] - self.cutoff[i], bias=False)
-                )
-            )
+            dim = int(self.input_dim // self.factor ** (i + extra_denom))
+
+            tied_emb, tied_proj = adaptive_inputs.weights_for_band(i + 1) \
+                if adaptive_inputs is not None else (None, None)
+
+            if tied_proj is not None:
+                if tie_proj:
+                    proj = TiedLinear(tied_proj, transpose=True)
+                else:
+                    proj = nn.Linear(tied_proj.size(0), tied_proj.size(1), bias=False)
+            else:
+                proj = nn.Linear(self.input_dim, dim, bias=False)
+
+            m = nn.Sequential(
+                proj,
+                nn.Dropout(self.dropout),
+                nn.Linear(dim, self.cutoff[i + 1] - self.cutoff[i], bias=False) \
+                    if tied_emb is None else TiedLinear(tied_emb, transpose=False)
+            )
+
+            self.tail.append(m)

     def upgrade_state_dict_named(self, state_dict, name):
         version_name = name + '.version'
         if version_name not in state_dict:
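How the tying lines up, as a sketch (the dims are hypothetical, and this relies on the full fairseq classes since parts of AdaptiveSoftmax.__init__ are elided above): band 0 of the adaptive input backs the TiedHeadModule, and each tail band i+1 reuses that band's embedding weight directly and its projection weight transposed, so no new output matrices are allocated for tied bands.

    from fairseq.modules import AdaptiveInput, AdaptiveSoftmax

    inp = AdaptiveInput(100000, 0, 512, 4., 512, [20000, 60000])
    sm = AdaptiveSoftmax(100000, 512, [20000, 60000], dropout=0.1,
                         factor=4., adaptive_inputs=inp, tie_proj=True)
    emb_w, proj_w = inp.weights_for_band(1)  # shapes (40000, 128) and (512, 128)
    # sm.tail[0] is Sequential(TiedLinear(proj_w, transpose=True), Dropout, TiedLinear(emb_w))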
fairseq/modules/character_token_embedder.py

@@ -5,12 +5,10 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.

-import numpy as np
 import torch
 import torch.nn.functional as F

 from torch import nn
-from torch.nn.utils.rnn import pad_sequence
 from typing import List, Tuple
@@ -47,10 +45,11 @@ class CharacterTokenEmbedder(torch.nn.Module):
                 nn.Conv1d(char_embed_dim, out_c, kernel_size=width)
             )

-        final_dim = sum(f[1] for f in filters)
+        last_dim = sum(f[1] for f in filters)
+        self.highway = Highway(last_dim, highway_layers) if highway_layers > 0 else None

-        self.highway = Highway(final_dim, highway_layers)
-        self.projection = nn.Linear(final_dim, word_embed_dim)
+        self.projection = nn.Linear(last_dim, word_embed_dim)

         self.set_vocab(vocab, max_char_len)
         self.reset_parameters()
@@ -84,7 +83,8 @@ class CharacterTokenEmbedder(torch.nn.Module):
     def reset_parameters(self):
         nn.init.xavier_normal_(self.char_embeddings.weight)
         nn.init.xavier_normal_(self.symbol_embeddings)
-        nn.init.xavier_normal_(self.projection.weight)
+        nn.init.xavier_uniform_(self.projection.weight)
+
         nn.init.constant_(self.char_embeddings.weight[self.char_embeddings.padding_idx], 0.)
         nn.init.constant_(self.projection.bias, 0.)
@@ -100,9 +100,8 @@ class CharacterTokenEmbedder(torch.nn.Module):
             chars[eos] = 0
             unk = None
         else:
-            self.word_to_char = self.word_to_char.type_as(input)
             flat_words = input.view(-1)
-            chars = self.word_to_char[flat_words]
+            chars = self.word_to_char[flat_words.type_as(self.word_to_char)].type_as(input)
             pads = flat_words.eq(self.vocab.pad())
             eos = flat_words.eq(self.vocab.eos())
             unk = flat_words.eq(self.vocab.unk())
@@ -134,7 +133,10 @@ class CharacterTokenEmbedder(torch.nn.Module):
             x = F.relu(x)
             conv_result.append(x)

-        conv_result = torch.cat(conv_result, dim=-1)
-        conv_result = self.highway(conv_result)
+        x = torch.cat(conv_result, dim=-1)
+
+        if self.highway is not None:
+            x = self.highway(x)
+        x = self.projection(x)

-        return self.projection(conv_result)
+        return x
fairseq/optim/fp16_optimizer.py

@@ -41,9 +41,20 @@ class FP16Optimizer(optim.FairseqOptimizer):
         super().__init__(args, params)
         self.fp32_optimizer = fp32_optimizer
         self.fp32_params = fp32_params
+
+        if getattr(args, 'fp16_scale_window', None) is None:
+            if len(args.update_freq) > 1:
+                raise ValueError(
+                    '--fp16-scale-window must be given explicitly when using a '
+                    'custom --update-freq schedule'
+                )
+            scale_window = 2**14 / args.distributed_world_size / args.update_freq[0]
+        else:
+            scale_window = args.fp16_scale_window
+
         self.scaler = DynamicLossScaler(
             init_scale=args.fp16_init_scale,
-            scale_window=(2**14 / args.distributed_world_size),
+            scale_window=scale_window,
         )

     @staticmethod
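The default scale window, restated as arithmetic (the setup is hypothetical): with 8 workers and --update-freq 4, the loss scale is grown again after 2**14 / 8 / 4 = 512 updates.

    world_size, update_freq = 8, 4
    scale_window = 2**14 / world_size / update_freq
    print(scale_window)  # 512.0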
fairseq/options.py

@@ -130,6 +130,8 @@ def get_parser(desc, default_task='translation'):
     parser.add_argument('--fp16', action='store_true', help='use FP16')
     parser.add_argument('--fp16-init-scale', default=2**7, type=int,
                         help='default FP16 loss scale')
+    parser.add_argument('--fp16-scale-window', type=int,
+                        help='number of updates before increasing loss scale')

     # Task definitions can be found under fairseq/tasks/
     parser.add_argument(
fairseq/sequence_scorer.py

@@ -66,7 +66,7 @@ class SequenceScorer(object):
                 decoder_out = model.forward(**net_input)
                 attn = decoder_out[1]

-            probs = model.get_normalized_probs(decoder_out, log_probs=False, sample=sample).data
+            probs = model.get_normalized_probs(decoder_out, log_probs=len(self.models) == 1, sample=sample).data
             if avg_probs is None:
                 avg_probs = probs
             else:
@@ -77,12 +77,14 @@ class SequenceScorer(object):
                     avg_attn = attn
                 else:
                     avg_attn.add_(attn)
-        avg_probs.div_(len(self.models))
-        avg_probs.log_()
-        if avg_attn is not None:
-            avg_attn.div_(len(self.models))
+        if len(self.models) > 1:
+            avg_probs.div_(len(self.models))
+            avg_probs.log_()
+            if avg_attn is not None:
+                avg_attn.div_(len(self.models))

         avg_probs = avg_probs.gather(
             dim=2,
             index=sample['target'].data.unsqueeze(-1),
         )
         return avg_probs.squeeze(2), avg_attn
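Why the single-model branch is safe to skip, as a sketch: with one model, averaging one set of probabilities and then taking the log gives exactly the log probabilities the model can now return directly via log_probs=True.

    import torch

    probs = torch.softmax(torch.randn(3), dim=-1)
    old_path = probs.clone().div_(1).log_()  # average over a 1-model "ensemble", then log
    new_path = probs.log()                   # what log_probs=True yields directly
    assert torch.allclose(old_path, new_path)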
fairseq/tasks/fairseq_task.py

@@ -188,7 +188,7 @@ class FairseqTask(object):
         return criterion.__class__.grad_denom(sample_sizes)

     def aggregate_logging_outputs(self, logging_outputs, criterion):
-        return criterion.__class__.aggregate_logging_outputs(logging_outputs)
+        return criterion._aggregate_logging_outputs(logging_outputs)

     def max_positions(self):
         """Return the max input length allowed by the task."""
fairseq/tasks/language_modeling.py

@@ -98,11 +98,11 @@ class LanguageModelingTask(FairseqTask):
             args.self_target = not args.exclude_self_target

         targets = []
-        if args.self_target:
+        if getattr(args, 'self_target', False):
             targets.append('self')
-        if args.future_target:
+        if getattr(args, 'future_target', False):
             targets.append('future')
-        if args.past_target:
+        if getattr(args, 'past_target', False):
             targets.append('past')
         if len(targets) == 0:
             # standard language modeling
@@ -166,7 +166,7 @@ class LanguageModelingTask(FairseqTask):
         self.datasets[split] = MonolingualDataset(
             dataset, sizes, self.dictionary, self.output_dictionary,
-            add_eos_for_other_targets=add_eos_for_other_targets, shuffle=False,
+            add_eos_for_other_targets=add_eos_for_other_targets, shuffle=True,
             targets=self.targets,
         )
fairseq/tasks/translation.py

@@ -148,20 +148,15 @@ class TranslationTask(FairseqTask):

         if len(src_datasets) == 1:
             src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0]
-            src_sizes = src_dataset.sizes
-            tgt_sizes = tgt_dataset.sizes
         else:
-            if self.args.upsample_primary > 1:
-                src_datasets.extend([src_datasets[0]] * (self.args.upsample_primary - 1))
-                tgt_datasets.extend([tgt_datasets[0]] * (self.args.upsample_primary - 1))
-            src_dataset = ConcatDataset(src_datasets)
-            tgt_dataset = ConcatDataset(tgt_datasets)
-            src_sizes = np.concatenate([ds.sizes for ds in src_datasets])
-            tgt_sizes = np.concatenate([ds.sizes for ds in tgt_datasets])
+            sample_ratios = [1] * len(src_datasets)
+            sample_ratios[0] = self.args.upsample_primary
+            src_dataset = ConcatDataset(src_datasets, sample_ratios)
+            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)

         self.datasets[split] = LanguagePairDataset(
-            src_dataset, src_sizes, self.src_dict,
-            tgt_dataset, tgt_sizes, self.tgt_dict,
+            src_dataset, src_dataset.sizes, self.src_dict,
+            tgt_dataset, tgt_dataset.sizes, self.tgt_dict,
             left_pad_source=self.args.left_pad_source,
             left_pad_target=self.args.left_pad_target,
             max_source_positions=self.args.max_source_positions,
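The effect of the new sample_ratios path, in numbers (the sizes are hypothetical): with --upsample-primary 3, a primary pair dataset of 1000 sentences and a secondary of 4000, the ConcatDataset exposes 3 * 1000 + 4000 = 7000 examples, and its sizes property tiles the primary sizes three times to match.

    sample_ratios, lengths = [3, 1], [1000, 4000]
    print(sum(r * n for r, n in zip(sample_ratios, lengths)))  # 7000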
fairseq/trainer.py

@@ -211,10 +211,10 @@ class Trainer(object):
             return None

         # aggregate logging outputs and sample sizes
-        sample_size = self.task.grad_denom(sample_sizes, self.criterion)
         logging_output = self.task.aggregate_logging_outputs(
             logging_outputs, self.criterion
         )
+        sample_size = self.task.grad_denom(sample_sizes, self.criterion)

         if not all(k in logging_output for k in ['ntokens', 'nsentences']):
             raise Exception((