"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "365e8461ac9045da3210b79522302d7706d943ad"
Commit 8441cbf3 authored by Peng-Jen Chen, committed by Facebook GitHub Bot

Manually port pull request 385

Summary:
Manually port fairinternal fairseq-py pull request #385 [1] to fbcode.

Resolve the merge conflict from the removal of fp16_trainer, per offline discussion with Myle. Also update the code so that generate.py works.

[1] https://github.com/fairinternal/fairseq-py/pull/385/commits/18fa6e154781cf0c4b1596429dba7e753a545069

Reviewed By: liezl200

Differential Revision: D10052908

fbshipit-source-id: c3c378d78dc1e9ac087c815f359e78c0048ff2f5
parent 0a628401
@@ -12,6 +12,7 @@ from .indexed_dataset import IndexedDataset, IndexedCachedDataset, IndexedInMemo
 from .append_eos_dataset import AppendEosDataset
 from .language_pair_dataset import LanguagePairDataset
 from .monolingual_dataset import MonolingualDataset
+from .round_robin_zip_datasets import RoundRobinZipDatasets
 from .token_block_dataset import TokenBlockDataset
 from .iterators import (
@@ -35,6 +36,7 @@ __all__ = [
     'IndexedRawTextDataset',
     'LanguagePairDataset',
     'MonolingualDataset',
+    'RoundRobinZipDatasets',
     'ShardedIterator',
     'TokenBlockDataset',
 ]
@@ -87,6 +87,13 @@ def filter_by_size(indices, size_fn, max_positions, raise_exception=False):
     def check_size(idx):
         if isinstance(max_positions, float) or isinstance(max_positions, int):
             return size_fn(idx) <= max_positions
+        elif isinstance(max_positions, dict):
+            idx_size = size_fn(idx)
+            assert isinstance(idx_size, dict)
+            intersect_keys = set(max_positions.keys()) & set(idx_size.keys())
+            return all(
+                idx_size[key] <= max_positions[key] for key in intersect_keys
+            )
         else:
             return all(a is None or b is None or a <= b
                        for a, b in zip(size_fn(idx), max_positions))
...
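For reference, a standalone sketch (not part of this diff) of what the new dict branch in `check_size` does. The keys and sizes below are invented, and sizes are shown as plain ints for brevity:

```python
# Only keys present in both dicts are compared; keys without a cap are ignored.
max_positions = {'en-de': 1024}                  # hypothetical per-pair cap
idx_size = {'en-de': 17, 'en-fr': 2048}          # hypothetical per-pair size of one example
intersect_keys = set(max_positions.keys()) & set(idx_size.keys())
print(all(idx_size[key] <= max_positions[key] for key in intersect_keys))
# True: 'en-de' is within its cap, and 'en-fr' has no cap so it is not checked
```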
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from collections import OrderedDict
import numpy as np
from . import FairseqDataset
class RoundRobinZipDatasets(FairseqDataset):
"""Zip multiple FairseqDatasets together, repeating shorter datasets in a
round-robin fashion to match the length of the longest one.
Args:
datasets: a dictionary of FairseqDatasets
eval_key: an optional key used at evaluation time that causes this
instance to pass-through batches from `datasets[eval_key]`.
"""
def __init__(self, datasets, eval_key=None):
super().__init__()
assert isinstance(datasets, OrderedDict)
self.datasets = datasets
self.eval_key = eval_key
self.longest_dataset = None
self.longest_dataset_key = None
for key, dataset in datasets.items():
assert isinstance(dataset, FairseqDataset)
if self.longest_dataset is None or len(dataset) > len(self.longest_dataset):
self.longest_dataset = dataset
self.longest_dataset_key = key
self._ordered_indices = OrderedDict([
(key, dataset.ordered_indices())
for key, dataset in datasets.items()
])
def _map_index(self, key, index):
return self._ordered_indices[key][index % len(self.datasets[key])]
def __getitem__(self, index):
if self.eval_key is None:
return OrderedDict([
(key, dataset[self._map_index(key, index)])
for key, dataset in self.datasets.items()
])
else:
# at evaluation time it's useful to pass-through batches from a single key
return self.datasets[self.eval_key][self._map_index(self.eval_key, index)]
def __len__(self):
return len(self.longest_dataset)
def collater(self, samples):
"""Merge a list of samples to form a mini-batch."""
if self.eval_key is None:
return OrderedDict([
(key, dataset.collater([sample[key] for sample in samples]))
for key, dataset in self.datasets.items()
])
else:
# at evaluation time it's useful to pass-through batches from a single key
return self.datasets[self.eval_key].collater(samples)
def get_dummy_batch(self, max_tokens, max_positions):
if self.eval_key is None:
# TODO should max_tokens be used independently for each batch like this?
return OrderedDict([
(key, dataset.get_dummy_batch(max_tokens, max_positions[key]))
for key, dataset in self.datasets.items()
])
else:
# at evaluation time it's useful to return a single batch directly
return self.datasets[self.eval_key].get_dummy_batch(max_tokens, max_positions[self.eval_key])
def num_tokens(self, index):
"""Return an example's length (number of tokens), used for batching."""
# TODO make it configurable whether to use max() or sum() here
return max(
dataset.num_tokens(self._map_index(key, index))
for key, dataset in self.datasets.items()
)
def size(self, index):
"""Return an example's size as a float or tuple. This value is used when
filtering a dataset with ``--max-positions``."""
return {
key: dataset.size(self._map_index(key, index))
for key, dataset in self.datasets.items()
}
def ordered_indices(self):
"""Ordered indices for batching."""
return np.arange(len(self))
def valid_size(self, index, max_positions):
"""Check if an example's size is valid according to max_positions."""
return all(
dataset.valid_size(self._map_index(key, index), max_positions[key])
for key, dataset in self.datasets.items()
)
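A minimal usage sketch (not from this commit) of the new RoundRobinZipDatasets. `ToyDataset` is a made-up stand-in for real FairseqDatasets such as LanguagePairDataset, implementing only the methods this example touches:

```python
from collections import OrderedDict

import numpy as np

from fairseq.data import FairseqDataset, RoundRobinZipDatasets


class ToyDataset(FairseqDataset):
    """Tiny in-memory dataset; just enough of the API for this illustration."""

    def __init__(self, items):
        self.items = list(items)

    def __getitem__(self, index):
        return self.items[index]

    def __len__(self):
        return len(self.items)

    def ordered_indices(self):
        return np.arange(len(self))

    def collater(self, samples):
        return samples


zipped = RoundRobinZipDatasets(OrderedDict([
    ('en-de', ToyDataset(['a', 'b', 'c'])),  # the longest dataset defines len(zipped)
    ('en-fr', ToyDataset(['x', 'y'])),       # the shorter one repeats round-robin
]))
print(len(zipped))  # 3
print(zipped[2])    # OrderedDict([('en-de', 'c'), ('en-fr', 'x')])
```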
@@ -12,7 +12,12 @@ import os
 from .fairseq_decoder import FairseqDecoder  # noqa: F401
 from .fairseq_encoder import FairseqEncoder  # noqa: F401
 from .fairseq_incremental_decoder import FairseqIncrementalDecoder  # noqa: F401
-from .fairseq_model import BaseFairseqModel, FairseqModel, FairseqLanguageModel  # noqa: F401
+from .fairseq_model import (
+    BaseFairseqModel,
+    FairseqModel,  # noqa: F401
+    FairseqMultiModel,  # noqa: F401
+    FairseqLanguageModel,  # noqa: F401
+)
 from .composite_encoder import CompositeEncoder  # noqa: F401
 from .distributed_fairseq_model import DistributedFairseqModel  # noqa: F401
...
@@ -168,6 +168,48 @@ class FairseqModel(BaseFairseqModel):
         return (self.encoder.max_positions(), self.decoder.max_positions())
+class FairseqMultiModel(BaseFairseqModel):
+    """Base class for combining multiple encoder-decoder models."""
+    def __init__(self, encoders, decoders):
+        super().__init__()
+        assert encoders.keys() == decoders.keys()
+        self.keys = list(encoders.keys())
+        for key in self.keys:
+            assert isinstance(encoders[key], FairseqEncoder)
+            assert isinstance(decoders[key], FairseqDecoder)
+        self.models = nn.ModuleDict({
+            key: FairseqModel(encoders[key], decoders[key])
+            for key in self.keys
+        })
+    def forward(self, src_tokens, src_lengths, prev_output_tokens):
+        decoder_outs = {}
+        for key in self.keys:
+            encoder_out = self.models[key].encoder(src_tokens, src_lengths)
+            decoder_outs[key] = self.models[key].decoder(prev_output_tokens, encoder_out)
+        return decoder_outs
+    def max_positions(self):
+        """Maximum length supported by the model."""
+        return {
+            key: (self.models[key].encoder.max_positions(), self.models[key].decoder.max_positions())
+            for key in self.keys
+        }
+    def max_decoder_positions(self):
+        """Maximum length supported by the decoder."""
+        return min(model.decoder.max_positions() for model in self.models.values())
+    @property
+    def encoder(self):
+        return self.models[self.keys[0]].encoder
+    @property
+    def decoder(self):
+        return self.models[self.keys[0]].decoder
 class FairseqLanguageModel(BaseFairseqModel):
     """Base class for decoder-only models.
...
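A hedged sketch (not from this commit) of the FairseqMultiModel contract: one encoder/decoder pair per key, with `forward` returning a dict of decoder outputs. `ToyEncoder`/`ToyDecoder` are invented and exist only to keep the example self-contained; the constructor usage assumes this revision's FairseqEncoder/FairseqDecoder (which take a dictionary argument) and a PyTorch build that provides `nn.ModuleDict`:

```python
from collections import OrderedDict

import torch
import torch.nn as nn

from fairseq.models import FairseqDecoder, FairseqEncoder, FairseqMultiModel


class ToyEncoder(FairseqEncoder):
    def __init__(self):
        super().__init__(dictionary=None)
        self.proj = nn.Linear(4, 4)

    def forward(self, src_tokens, src_lengths):
        return self.proj(src_tokens)

    def max_positions(self):
        return 1024


class ToyDecoder(FairseqDecoder):
    def __init__(self):
        super().__init__(dictionary=None)
        self.proj = nn.Linear(4, 4)

    def forward(self, prev_output_tokens, encoder_out):
        return self.proj(encoder_out)

    def max_positions(self):
        return 1024


keys = ['en-de', 'en-fr']
model = FairseqMultiModel(
    encoders=OrderedDict((k, ToyEncoder()) for k in keys),
    decoders=OrderedDict((k, ToyDecoder()) for k in keys),
)
src_tokens = torch.randn(2, 3, 4)  # toy "batch"; real models take token indices
outs = model(src_tokens, src_lengths=None, prev_output_tokens=None)
print(list(outs.keys()))              # ['en-de', 'en-fr'], one decoder output per pair
print(model.max_decoder_positions())  # 1024, the min over all decoders
```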
@@ -95,14 +95,14 @@ class LSTMModel(FairseqModel):
         if args.share_all_embeddings:
             # double check all parameters combinations are valid
             if task.source_dictionary != task.target_dictionary:
-                raise RuntimeError('--share-all-embeddings requires a joint dictionary')
+                raise ValueError('--share-all-embeddings requires a joint dictionary')
             if args.decoder_embed_path and (
                     args.decoder_embed_path != args.encoder_embed_path):
-                raise RuntimeError(
+                raise ValueError(
                     '--share-all-embed not compatible with --decoder-embed-path'
                 )
             if args.encoder_embed_dim != args.decoder_embed_dim:
-                raise RuntimeError(
+                raise ValueError(
                     '--share-all-embeddings requires --encoder-embed-dim to '
                     'match --decoder-embed-dim'
                 )
@@ -120,7 +120,7 @@ class LSTMModel(FairseqModel):
         # one last double check of parameter combinations
         if args.share_decoder_input_output_embed and (
                 args.decoder_embed_dim != args.decoder_out_embed_dim):
-            raise RuntimeError(
+            raise ValueError(
                 '--share-decoder-input-output-embeddings requires '
                 '--decoder-embed-dim to match --decoder-out-embed-dim'
            )
...
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from collections import OrderedDict
from fairseq import utils
from fairseq.tasks.multilingual_translation import MultilingualTranslationTask
from . import FairseqMultiModel, register_model, register_model_architecture
from .transformer import (
base_architecture,
Embedding,
TransformerModel,
TransformerEncoder,
TransformerDecoder,
)
@register_model('multilingual_transformer')
class MultilingualTransformerModel(FairseqMultiModel):
"""Train Transformer models for multiple language pairs simultaneously.
Requires `--task multilingual_translation`.
We inherit all arguments from TransformerModel and assume that all language
pairs use a single Transformer architecture. In addition, we provide several
options that are specific to the multilingual setting.
Args:
--share-encoder-embeddings: share encoder embeddings across all source languages
--share-decoder-embeddings: share decoder embeddings across all target languages
--share-encoders: share all encoder params (incl. embeddings) across all source languages
--share-decoders: share all decoder params (incl. embeddings) across all target languages
"""
def __init__(self, encoders, decoders):
super().__init__(encoders, decoders)
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
TransformerModel.add_args(parser)
parser.add_argument('--share-encoder-embeddings', action='store_true',
help='share encoder embeddings across languages')
parser.add_argument('--share-decoder-embeddings', action='store_true',
help='share decoder embeddings across languages')
parser.add_argument('--share-encoders', action='store_true',
help='share encoders across languages')
parser.add_argument('--share-decoders', action='store_true',
help='share decoders across languages')
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
assert isinstance(task, MultilingualTranslationTask)
# make sure all arguments are present in older models
base_multilingual_architecture(args)
if not hasattr(args, 'max_source_positions'):
args.max_source_positions = 1024
if not hasattr(args, 'max_target_positions'):
args.max_target_positions = 1024
src_langs = [lang_pair.split('-')[0] for lang_pair in args.lang_pairs]
tgt_langs = [lang_pair.split('-')[1] for lang_pair in args.lang_pairs]
if args.share_encoders:
args.share_encoder_embeddings = True
if args.share_decoders:
args.share_decoder_embeddings = True
def build_embedding(dictionary, embed_dim, path=None):
num_embeddings = len(dictionary)
padding_idx = dictionary.pad()
emb = Embedding(num_embeddings, embed_dim, padding_idx)
# if provided, load from preloaded dictionaries
if path:
embed_dict = utils.parse_embedding(path)
utils.load_embedding(embed_dict, dictionary, emb)
return emb
# build shared embeddings (if applicable)
shared_encoder_embed_tokens, shared_decoder_embed_tokens = None, None
if args.share_all_embeddings:
shared_dict = task.dicts[task.langs[0]]
if any(dict != shared_dict for dict in task.dicts.values()):
raise ValueError('--share-all-embeddings requires a joined dictionary')
if args.encoder_embed_dim != args.decoder_embed_dim:
raise ValueError(
'--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim')
if args.decoder_embed_path and (
args.decoder_embed_path != args.encoder_embed_path):
raise ValueError('--share-all-embeddings not compatible with --decoder-embed-path')
shared_encoder_embed_tokens = build_embedding(
shared_dict, args.encoder_embed_dim, args.encoder_embed_path
)
shared_decoder_embed_tokens = shared_encoder_embed_tokens
args.share_decoder_input_output_embed = True
else:
if args.share_encoder_embeddings:
shared_dict = task.dicts[src_langs[0]]
if any(task.dicts[src_lang] != shared_dict for src_lang in src_langs):
raise ValueError('--share-encoder-embeddings requires a joined source dictionary')
shared_encoder_embed_tokens = build_embedding(
shared_dict, args.encoder_embed_dim, args.encoder_embed_path
)
if args.share_decoder_embeddings:
shared_dict = task.dicts[tgt_langs[0]]
if any(task.dicts[tgt_lang] != shared_dict for tgt_lang in src_langs):
raise ValueError('--share-decoder-embeddings requires a joined target dictionary')
shared_decoder_embed_tokens = build_embedding(
shared_dict, args.decoder_embed_dim, args.decoder_embed_path
)
# encoders/decoders for each language
lang_encoders, lang_decoders = {}, {}
def get_encoder(lang):
if lang not in lang_encoders:
if shared_encoder_embed_tokens is not None:
encoder_embed_tokens = shared_encoder_embed_tokens
else:
encoder_embed_tokens = build_embedding(
task.dicts[lang], args.encoder_embed_dim, args.encoder_embed_path
)
lang_encoders[lang] = TransformerEncoder(args, task.dicts[lang], encoder_embed_tokens)
return lang_encoders[lang]
def get_decoder(lang):
if lang not in lang_decoders:
if shared_decoder_embed_tokens is not None:
decoder_embed_tokens = shared_decoder_embed_tokens
else:
decoder_embed_tokens = build_embedding(
task.dicts[lang], args.decoder_embed_dim, args.decoder_embed_path
)
lang_decoders[lang] = TransformerDecoder(args, task.dicts[lang], decoder_embed_tokens)
return lang_decoders[lang]
# shared encoders/decoders (if applicable)
shared_encoder, shared_decoder = None, None
if args.share_encoders:
shared_encoder = get_encoder(src_langs[0])
if args.share_decoders:
shared_decoder = get_decoder(tgt_langs[0])
encoders, decoders = OrderedDict(), OrderedDict()
for lang_pair, src, tgt in zip(args.lang_pairs, src_langs, tgt_langs):
encoders[lang_pair] = shared_encoder if shared_encoder is not None else get_encoder(src)
decoders[lang_pair] = shared_decoder if shared_decoder is not None else get_decoder(tgt)
return MultilingualTransformerModel(encoders, decoders)
@register_model_architecture('multilingual_transformer', 'multilingual_transformer')
def base_multilingual_architecture(args):
base_architecture(args)
args.share_encoder_embeddings = getattr(args, 'share_encoder_embeddings', False)
args.share_decoder_embeddings = getattr(args, 'share_decoder_embeddings', False)
args.share_encoders = getattr(args, 'share_encoders', False)
args.share_decoders = getattr(args, 'share_decoders', False)
@register_model_architecture('multilingual_transformer', 'multilingual_transformer_iwslt_de_en')
def multilingual_transformer_iwslt_de_en(args):
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 1024)
args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 4)
args.encoder_layers = getattr(args, 'encoder_layers', 6)
args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 1024)
args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 4)
args.decoder_layers = getattr(args, 'decoder_layers', 6)
base_multilingual_architecture(args)
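As a quick illustration (not part of the diff) of how `build_model` above derives per-pair encoders and decoders from `--lang-pairs`; the language pairs are invented:

```python
# Mirrors the splitting done at the top of build_model.
lang_pairs = ['en-de', 'en-fr', 'de-fr']
src_langs = [lang_pair.split('-')[0] for lang_pair in lang_pairs]  # ['en', 'en', 'de']
tgt_langs = [lang_pair.split('-')[1] for lang_pair in lang_pairs]  # ['de', 'fr', 'fr']
print(src_langs, tgt_langs)
# Without --share-encoders, get_encoder() builds one encoder per distinct source
# language ('en' and 'de') and caches it in lang_encoders; with --share-encoders,
# every pair reuses the encoder built for src_langs[0] ('en').
```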
@@ -120,13 +120,13 @@ class TransformerModel(FairseqModel):
         if args.share_all_embeddings:
             if src_dict != tgt_dict:
-                raise RuntimeError('--share-all-embeddings requires a joined dictionary')
+                raise ValueError('--share-all-embeddings requires a joined dictionary')
             if args.encoder_embed_dim != args.decoder_embed_dim:
-                raise RuntimeError(
+                raise ValueError(
                     '--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim')
             if args.decoder_embed_path and (
                     args.decoder_embed_path != args.encoder_embed_path):
-                raise RuntimeError('--share-all-embeddings not compatible with --decoder-embed-path')
+                raise ValueError('--share-all-embeddings not compatible with --decoder-embed-path')
             encoder_embed_tokens = build_embedding(
                 src_dict, args.encoder_embed_dim, args.encoder_embed_path
             )
...
@@ -6,6 +6,7 @@
 # can be found in the PATENTS file in the same directory.
 from fairseq.data import data_utils, FairseqDataset, iterators
+import torch
 class FairseqTask(object):
@@ -32,7 +33,7 @@ class FairseqTask(object):
         """
         return cls(args)
-    def load_dataset(self, split, combine=False):
+    def load_dataset(self, split, combine=False, **kwargs):
         """Load a given dataset split.
         Args:
@@ -143,18 +144,51 @@ class FairseqTask(object):
         from fairseq import criterions
         return criterions.build_criterion(args, self)
-    def get_loss(self, model, criterion, sample):
+    def train_step(self, sample, model, criterion, optimizer, ignore_grad=False):
         """
-        Return the loss as computed by *criterion* for the given *model* and
-        *sample*.
+        Do forward and backward, and return the loss as computed by *criterion*
+        for the given *model* and *sample*.
         Args:
-            model (~fairseq.models.BaseFairseqModel): the model
-            criterion (~fairseq.criterions.FairseqCriterion): the criterion
             sample (dict): the mini-batch. The format is defined by the
                 :class:`~fairseq.data.FairseqDataset`.
+            model (~fairseq.models.BaseFairseqModel): the model
+            criterion (~fairseq.criterions.FairseqCriterion): the criterion
+            optimizer (~fairseq.optim.FairseqOptimizer): the optimizer
+            ignore_grad (bool): multiply loss by 0 if this is set to True
+        Returns:
+            tuple:
+                - the loss
+                - the sample size, which is used as the denominator for the
+                  gradient
+                - logging outputs to display while training
         """
-        return criterion(model, sample)
+        model.train()
+        loss, sample_size, logging_output = criterion(model, sample)
+        if ignore_grad:
+            loss *= 0
+        optimizer.backward(loss)
+        return loss, sample_size, logging_output
+    def valid_step(self, sample, model, criterion):
+        model.eval()
+        with torch.no_grad():
+            loss, sample_size, logging_output = criterion(model, sample)
+        return loss, sample_size, logging_output
+    def init_logging_output(self, sample):
+        return {
+            'ntokens': sample['ntokens'] if sample is not None else 0,
+            'nsentences': sample['target'].size(0) if sample is not None else 0,
+        }
+    def grad_denom(self, sample_sizes, criterion):
+        return criterion.__class__.grad_denom(sample_sizes)
+    def aggregate_logging_outputs(self, logging_outputs, criterion):
+        return criterion.__class__.aggregate_logging_outputs(logging_outputs)
     def max_positions(self):
         """Return the max input length allowed by the task."""
...
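To make the new split of responsibilities concrete, here is a hedged sketch (not code from this commit) of the inner loop a trainer can now write purely against the task API; the function and variable names are illustrative:

```python
def accumulate_and_aggregate(task, model, criterion, optimizer, samples):
    """Run task.train_step over a list of samples and aggregate the results."""
    logging_outputs, sample_sizes = [], []
    for sample in samples:
        # forward + backward now live inside the task; ignore_grad=True would
        # zero out the loss for dummy batches, as the Trainer does on OOM/padding.
        loss, sample_size, logging_output = task.train_step(
            sample, model, criterion, optimizer, ignore_grad=False,
        )
        logging_outputs.append(logging_output)
        sample_sizes.append(sample_size)
    # aggregation is also routed through the task, which by default defers to the
    # criterion's classmethods (see grad_denom/aggregate_logging_outputs above)
    sample_size = task.grad_denom(sample_sizes, criterion)
    logging_output = task.aggregate_logging_outputs(logging_outputs, criterion)
    return sample_size, logging_output
```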
@@ -119,8 +119,7 @@ class LanguageModelingTask(FairseqTask):
         return model
-    def load_dataset(self, split, combine=False):
+    def load_dataset(self, split, combine=False, **kwargs):
         """Load a given dataset split.
         Args:
...
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from collections import OrderedDict
import os
import torch
from fairseq import options
from fairseq.data import (
Dictionary, LanguagePairDataset, IndexedInMemoryDataset,
IndexedRawTextDataset, RoundRobinZipDatasets,
)
from fairseq.models import FairseqMultiModel
from . import FairseqTask, register_task
@register_task('multilingual_translation')
class MultilingualTranslationTask(FairseqTask):
"""A task for training multiple translation models simultaneously.
We iterate round-robin over batches from multiple language pairs, ordered
according to the `--lang-pairs` argument.
The training loop is roughly:
for i in range(len(epoch)):
for lang_pair in args.lang_pairs:
batch = next_batch_for_lang_pair(lang_pair)
loss = criterion(model_for_lang_pair(lang_pair), batch)
loss.backward()
optimizer.step()
In practice, `next_batch_for_lang_pair` is abstracted in a FairseqDataset
(e.g., `RoundRobinZipDatasets`) and `model_for_lang_pair` is a model that
implements the `FairseqMultiModel` interface.
During inference it is required to specify a single `--source-lang` and
`--target-lang`, instead of `--lang-pairs`.
"""
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
parser.add_argument('data', metavar='DIR', help='path to data directory')
parser.add_argument('--lang-pairs', default=None, metavar='PAIRS',
help='comma-separated list of language pairs (in training order): en-de,en-fr,de-fr')
parser.add_argument('-s', '--source-lang', default=None, metavar='SRC',
help='source language (only needed for inference)')
parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
help='target language (only needed for inference)')
parser.add_argument('--raw-text', action='store_true',
help='load raw text dataset')
parser.add_argument('--left-pad-source', default='True', type=str, metavar='BOOL',
help='pad the source on the left (default: True)')
parser.add_argument('--left-pad-target', default='False', type=str, metavar='BOOL',
help='pad the target on the left (default: False)')
parser.add_argument('--max-source-positions', default=1024, type=int, metavar='N',
help='max number of tokens in the source sequence')
parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
help='max number of tokens in the target sequence')
def __init__(self, args, dicts, training):
super().__init__(args)
self.dicts = dicts
self.langs = list(dicts.keys())
self.training = training
@classmethod
def setup_task(cls, args, **kwargs):
args.left_pad_source = options.eval_bool(args.left_pad_source)
args.left_pad_target = options.eval_bool(args.left_pad_target)
if args.source_lang is not None or args.target_lang is not None:
if args.lang_pairs is not None:
raise ValueError(
'--source-lang/--target-lang implies generation, which is '
'incompatible with --lang-pairs'
)
training = False
args.lang_pairs = ['{}-{}'.format(args.source_lang, args.target_lang)]
else:
training = True
args.lang_pairs = args.lang_pairs.split(',')
args.source_lang, args.target_lang = args.lang_pairs[0].split('-')
langs = list({x for lang_pair in args.lang_pairs for x in lang_pair.split('-')})
# load dictionaries
dicts = OrderedDict()
for lang in langs:
dicts[lang] = Dictionary.load(os.path.join(args.data, 'dict.{}.txt'.format(lang)))
if len(dicts) > 0:
assert dicts[lang].pad() == dicts[langs[0]].pad()
assert dicts[lang].eos() == dicts[langs[0]].eos()
assert dicts[lang].unk() == dicts[langs[0]].unk()
print('| [{}] dictionary: {} types'.format(lang, len(dicts[lang])))
return cls(args, dicts, training)
def load_dataset(self, split, **kwargs):
"""Load a dataset split."""
def split_exists(split, src, tgt, lang):
filename = os.path.join(self.args.data, '{}.{}-{}.{}'.format(split, src, tgt, lang))
if self.args.raw_text and IndexedRawTextDataset.exists(filename):
return True
elif not self.args.raw_text and IndexedInMemoryDataset.exists(filename):
return True
return False
def indexed_dataset(path, dictionary):
if self.args.raw_text:
return IndexedRawTextDataset(path, dictionary)
elif IndexedInMemoryDataset.exists(path):
return IndexedInMemoryDataset(path, fix_lua_indexing=True)
return None
def sort_lang_pair(lang_pair):
return '-'.join(sorted(lang_pair.split('-')))
src_datasets, tgt_datasets = {}, {}
for lang_pair in set(map(sort_lang_pair, self.args.lang_pairs)):
src, tgt = lang_pair.split('-')
if split_exists(split, src, tgt, src):
prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split, src, tgt))
elif split_exists(split, tgt, src, src):
prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split, tgt, src))
else:
continue
src_datasets[lang_pair] = indexed_dataset(prefix + src, self.dicts[src])
tgt_datasets[lang_pair] = indexed_dataset(prefix + tgt, self.dicts[tgt])
print('| {} {} {} examples'.format(self.args.data, split, len(src_datasets[lang_pair])))
if len(src_datasets) == 0:
raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))
def language_pair_dataset(lang_pair):
src, tgt = lang_pair.split('-')
if lang_pair in src_datasets:
src_dataset, tgt_dataset = src_datasets[lang_pair], tgt_datasets[lang_pair]
else:
lang_pair = sort_lang_pair(lang_pair)
tgt_dataset, src_dataset = src_datasets[lang_pair], tgt_datasets[lang_pair]
return LanguagePairDataset(
src_dataset, src_dataset.sizes, self.dicts[src],
tgt_dataset, tgt_dataset.sizes, self.dicts[tgt],
left_pad_source=self.args.left_pad_source,
left_pad_target=self.args.left_pad_target,
max_source_positions=self.args.max_source_positions,
max_target_positions=self.args.max_target_positions,
)
self.datasets[split] = RoundRobinZipDatasets(
OrderedDict([
(lang_pair, language_pair_dataset(lang_pair))
for lang_pair in self.args.lang_pairs
]),
eval_key=None if self.training else self.args.lang_pairs[0],
)
def build_model(self, args):
from fairseq import models
model = models.build_model(args, self)
if not isinstance(model, FairseqMultiModel):
raise ValueError('MultilingualTranslationTask requires a FairseqMultiModel architecture')
return model
def train_step(self, sample, model, criterion, optimizer, ignore_grad=False):
model.train()
agg_loss, agg_sample_size, agg_logging_output = 0., 0., {}
for lang_pair in self.args.lang_pairs:
if sample[lang_pair] is None or len(sample[lang_pair]) == 0:
continue
loss, sample_size, logging_output = criterion(model.models[lang_pair], sample[lang_pair])
if ignore_grad:
loss *= 0
optimizer.backward(loss)
agg_loss += loss.detach().item()
# TODO make summing of the sample sizes configurable
agg_sample_size += sample_size
agg_logging_output[lang_pair] = logging_output
return agg_loss, agg_sample_size, agg_logging_output
def valid_step(self, sample, model, criterion):
model.eval()
with torch.no_grad():
agg_loss, agg_sample_size, agg_logging_output = 0., 0., {}
for lang_pair in self.args.lang_pairs:
if sample[lang_pair] is None or len(sample[lang_pair]) == 0:
continue
loss, sample_size, logging_output = criterion(model.models[lang_pair], sample[lang_pair])
agg_loss += loss.data.item()
# TODO make summing of the sample sizes configurable
agg_sample_size += sample_size
agg_logging_output[lang_pair] = logging_output
return agg_loss, agg_sample_size, agg_logging_output
def init_logging_output(self, sample):
return {
'ntokens': sum(
sample_lang.get('ntokens', 0)
for sample_lang in sample.values()
) if sample is not None else 0,
'nsentences': sum(
sample_lang['target'].size(0) if 'target' in sample_lang else 0
for sample_lang in sample.values()
) if sample is not None else 0,
}
def grad_denom(self, sample_sizes, criterion):
return criterion.__class__.grad_denom(sample_sizes)
def aggregate_logging_outputs(self, logging_outputs, criterion):
# aggregate logging outputs for each language pair
agg_logging_outputs = {
lang_pair: criterion.__class__.aggregate_logging_outputs([
logging_output.get(lang_pair, {}) for logging_output in logging_outputs
])
for lang_pair in self.args.lang_pairs
}
def sum_over_languages(key):
return sum(logging_output[key] for logging_output in agg_logging_outputs.values())
# flatten logging outputs
flat_logging_output = {
'{}:{}'.format(lang_pair, k): v
for lang_pair, agg_logging_output in agg_logging_outputs.items()
for k, v in agg_logging_output.items()
}
flat_logging_output['loss'] = sum_over_languages('loss')
flat_logging_output['nll_loss'] = sum_over_languages('nll_loss')
flat_logging_output['sample_size'] = sum_over_languages('sample_size')
flat_logging_output['nsentences'] = sum_over_languages('nsentences')
flat_logging_output['ntokens'] = sum_over_languages('ntokens')
return flat_logging_output
@property
def source_dictionary(self):
return self.dicts[self.args.source_lang]
@property
def target_dictionary(self):
return self.dicts[self.args.target_lang]
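A small sketch (not from this commit) of the flattened logging output produced by `aggregate_logging_outputs` above; the numbers are invented:

```python
# Per-language-pair outputs as a criterion would aggregate them.
per_pair = {
    'en-de': {'loss': 2.5, 'nll_loss': 2.0, 'sample_size': 320, 'ntokens': 320, 'nsentences': 16},
    'en-fr': {'loss': 1.5, 'nll_loss': 1.0, 'sample_size': 280, 'ntokens': 280, 'nsentences': 16},
}
# Flatten to 'lang_pair:key' entries, then add summed totals, as the task does.
flat = {
    '{}:{}'.format(lang_pair, k): v
    for lang_pair, out in per_pair.items()
    for k, v in out.items()
}
for key in ('loss', 'nll_loss', 'sample_size', 'nsentences', 'ntokens'):
    flat[key] = sum(out[key] for out in per_pair.values())
print(flat['en-de:loss'], flat['loss'])  # 2.5 4.0
```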
@@ -93,7 +93,7 @@ class TranslationTask(FairseqTask):
         return cls(args, src_dict, tgt_dict)
-    def load_dataset(self, split, combine=False):
+    def load_dataset(self, split, combine=False, **kwargs):
         """Load a given dataset split.
         Args:
@@ -144,9 +144,6 @@ class TranslationTask(FairseqTask):
             if not combine:
                 break
         assert len(src_datasets) == len(tgt_datasets)
         if len(src_datasets) == 1:
...
@@ -9,8 +9,7 @@
 Train a network across multiple GPUs.
 """
-from collections import defaultdict, OrderedDict
-import contextlib
+from collections import OrderedDict
 from itertools import chain
 import torch
@@ -171,13 +170,11 @@ class Trainer(object):
                 ignore_grad = False
             try:
-                # forward
-                loss, sample_size, logging_output = self.task.get_loss(
-                    self.model, self.criterion, sample,
+                # forward and backward
+                loss, sample_size, logging_output = self.task.train_step(
+                    sample, self.model, self.criterion, self.optimizer,
+                    ignore_grad
                 )
-                if ignore_grad:
-                    loss *= 0
                 if self.args.distributed_world_size > 1:
                     # only all-reduce gradients in the last backwards pass
                     if i < len(samples) - 1:
@@ -185,9 +182,6 @@ class Trainer(object):
                     else:
                         self.model.need_reduction = True
-                # backward
-                self.optimizer.backward(loss)
                 if not ignore_grad:
                     logging_outputs.append(logging_output)
                     sample_sizes.append(sample_size)
@@ -217,14 +211,16 @@ class Trainer(object):
             return None
         # aggregate logging outputs and sample sizes
-        logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_outputs)
-        sample_size = self.criterion.__class__.grad_denom(sample_sizes)
+        sample_size = self.task.grad_denom(sample_sizes, self.criterion)
+        logging_output = self.task.aggregate_logging_outputs(
+            logging_outputs, self.criterion
+        )
         if not all(k in logging_output for k in ['ntokens', 'nsentences']):
             raise Exception((
                 'Please update the {}.aggregate_logging_outputs() method to '
                 'return ntokens and nsentences'
-            ).format(self.criterion.__class__.__name__))
+            ).format(self.task.__class__.__name__))
         try:
             # normalize grads by sample size
@@ -281,8 +277,8 @@ class Trainer(object):
             ignore_results = False
             try:
-                _loss, sample_size, logging_output = self.task.get_loss(
-                    self.model, self.criterion, sample,
+                _loss, sample_size, logging_output = self.task.valid_step(
+                    sample, self.model, self.criterion
                 )
             except RuntimeError as e:
                 if 'out of memory' in str(e) and not raise_oom:
@@ -310,8 +306,12 @@ class Trainer(object):
             sample_size = [sample_size]
         # aggregate logging outputs and sample sizes
-        logging_output = self.criterion.__class__.aggregate_logging_outputs(logging_output)
-        sample_size = self.criterion.__class__.grad_denom(sample_size)
+        logging_output = self.task.aggregate_logging_outputs(
+            logging_output, self.criterion
+        )
+        sample_size = self.task.grad_denom(
+            sample_size, self.criterion
+        )
         # update meters for validation
         ntokens = logging_output.get('ntokens', 0)
...
@@ -121,7 +121,6 @@ def train(args, trainer, task, epoch_itr):
     extra_meters = collections.defaultdict(lambda: AverageMeter())
     first_valid = args.valid_subset.split(',')[0]
     max_update = args.max_update or math.inf
-    num_batches = len(epoch_itr)
     for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
         log_output = trainer.train_step(samples)
         if log_output is None:
...