Commit 7633129b authored by Myle Ott, committed by Facebook Github Bot

Merge internal changes (#283)

Summary:
Pull Request resolved: https://github.com/pytorch/translate/pull/283

Pull Request resolved: https://github.com/pytorch/fairseq/pull/428

Differential Revision: D13564190

Pulled By: myleott

fbshipit-source-id: 3b62282d7069c288f5bdd1dd2c120788cee4abb5
parent 0cb87130
@@ -22,6 +22,7 @@ class SequenceGenerator(object):
        match_source_len=False, no_repeat_ngram_size=0
    ):
        """Generates translations of a given source sentence.
+
        Args:
            beam_size (int, optional): beam width (default: 1)
            min/maxlen (int, optional): the length of the generated output will
@@ -90,11 +91,14 @@ class SequenceGenerator(object):
        cuda=False, timer=None, prefix_size=0,
    ):
        """Iterate over a batched dataset and yield individual translations.
+
        Args:
-           maxlen_a/b: generate sequences of maximum length ax + b,
-               where x is the source sentence length.
-           cuda: use GPU for generation
-           timer: StopwatchMeter for timing generations.
+           maxlen_a/b (int, optional): generate sequences of maximum length
+               ``ax + b``, where ``x`` is the source sentence length.
+           cuda (bool, optional): use GPU for generation
+           timer (StopwatchMeter, optional): time generations
+           prefix_size (int, optional): prefill the generation with the gold
+               prefix up to this length.
        """
        if maxlen_b is None:
            maxlen_b = self.maxlen
@@ -132,12 +136,13 @@ class SequenceGenerator(object):
        """Generate a batch of translations.
        Args:
-           encoder_input: dictionary containing the inputs to
-               model.encoder.forward
-           beam_size: int overriding the beam size. defaults to
-               self.beam_size
-           max_len: maximum length of the generated sequence
-           prefix_tokens: force decoder to begin with these tokens
+           encoder_input (dict): dictionary containing the inputs to
+               *model.encoder.forward*.
+           beam_size (int, optional): overriding the beam size
+               (default: *self.beam_size*).
+           max_len (int, optional): maximum length of the generated sequence
+           prefix_tokens (LongTensor, optional): force decoder to begin with
+               these tokens
        """
        with torch.no_grad():
            return self._generate(encoder_input, beam_size, maxlen, prefix_tokens)
...
@@ -87,4 +87,3 @@ class SequenceScorer(object):
            index=sample['target'].data.unsqueeze(-1),
        )
        return avg_probs.squeeze(2), avg_attn

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import argparse
...
@@ -61,29 +61,32 @@ class FairseqTask(object):
    def get_batch_iterator(
        self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
        ignore_invalid_inputs=False, required_batch_size_multiple=1,
-       seed=1, num_shards=1, shard_id=0,
+       seed=1, num_shards=1, shard_id=0, num_workers=0,
    ):
        """
        Get an iterator that yields batches of data from the given dataset.

        Args:
            dataset (~fairseq.data.FairseqDataset): dataset to batch
-           max_tokens (int, optional): max number of tokens in each batch.
-               Default: ``None``
-           max_sentences (int, optional): max number of sentences in each
-               batch. Default: ``None``
-           max_positions (optional): max sentence length supported by the
-               model. Default: ``None``
-           ignore_invalid_inputs (bool, optional): don't raise Exception for
-               sentences that are too long. Default: ``False``
-           required_batch_size_multiple (int, optional): require batch size to
-               be a multiple of N. Default: ``1``
-           seed (int, optional): seed for random number generator for
-               reproducibility. Default: ``1``
-           num_shards (int, optional): shard the data iterator into N
-               shards. Default: ``1``
-           shard_id (int, optional): which shard of the data iterator to
-               return. Default: ``0``
+           max_tokens (int, optional): max number of tokens in each batch
+               (default: None).
+           max_sentences (int, optional): max number of sentences in each
+               batch (default: None).
+           max_positions (optional): max sentence length supported by the
+               model (default: None).
+           ignore_invalid_inputs (bool, optional): don't raise Exception for
+               sentences that are too long (default: False).
+           required_batch_size_multiple (int, optional): require batch size to
+               be a multiple of N (default: 1).
+           seed (int, optional): seed for random number generator for
+               reproducibility (default: 1).
+           num_shards (int, optional): shard the data iterator into N
+               shards (default: 1).
+           shard_id (int, optional): which shard of the data iterator to
+               return (default: 0).
+           num_workers (int, optional): how many subprocesses to use for data
+               loading. 0 means the data will be loaded in the main process
+               (default: 0).

        Returns:
            ~fairseq.iterators.EpochBatchIterator: a batched iterator over the
@@ -114,6 +117,7 @@ class FairseqTask(object):
            seed=seed,
            num_shards=num_shards,
            shard_id=shard_id,
+           num_workers=num_workers,
        )

    def build_model(self, args):
...
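A quick usage sketch for the num_workers plumbing above. This is hedged: it simply mirrors how train.py later in this diff calls the API at this revision, with the split name hard-coded and all other values taken from fairseq's standard training parser.

from fairseq import options, tasks

parser = options.get_training_parser()
args = options.parse_args_and_arch(parser)

task = tasks.setup_task(args)
task.load_dataset('train')

epoch_itr = task.get_batch_iterator(
    dataset=task.dataset('train'),
    max_tokens=args.max_tokens,
    max_sentences=args.max_sentences,
    seed=args.seed,
    num_shards=args.distributed_world_size,
    shard_id=args.distributed_rank,
    num_workers=args.num_workers,  # new: subprocess data loading (0 = load in the main process)
)
itr = epoch_itr.next_epoch_itr(shuffle=True)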
@@ -10,9 +10,15 @@ import numpy as np
import os

from fairseq.data import (
-   ConcatDataset, Dictionary, IndexedInMemoryDataset, IndexedRawTextDataset,
-   MonolingualDataset, TokenBlockDataset, TruncatedDictionary,
-   IndexedCachedDataset, IndexedDataset)
+   ConcatDataset,
+   Dictionary,
+   IndexedCachedDataset,
+   IndexedDataset,
+   IndexedRawTextDataset,
+   MonolingualDataset,
+   TokenBlockDataset,
+   TruncatedDictionary,
+)

from . import FairseqTask, register_task
@@ -60,6 +66,8 @@ class LanguageModelingTask(FairseqTask):
                                 'If set to "eos", includes only one sentence per sample.')
        parser.add_argument('--tokens-per-sample', default=1024, type=int,
                            help='max number of tokens per sample for LM dataset')
+       parser.add_argument('--lazy-load', action='store_true',
+                           help='load the dataset lazily')
        parser.add_argument('--raw-text', default=False, action='store_true',
                            help='load raw text dataset')
        parser.add_argument('--output-dictionary-size', default=-1, type=int,
@@ -139,7 +147,10 @@ class LanguageModelingTask(FairseqTask):
            if self.args.raw_text and IndexedRawTextDataset.exists(path):
                ds = IndexedRawTextDataset(path, self.dictionary)
            elif not self.args.raw_text and IndexedDataset.exists(path):
-               ds = IndexedDataset(path, fix_lua_indexing=True)
+               if self.args.lazy_load:
+                   ds = IndexedDataset(path, fix_lua_indexing=True)
+               else:
+                   ds = IndexedCachedDataset(path, fix_lua_indexing=True)
            else:
                if k > 0:
                    break
@@ -148,9 +159,11 @@ class LanguageModelingTask(FairseqTask):
            loaded_datasets.append(
                TokenBlockDataset(
-                   ds, self.args.tokens_per_sample, pad=self.dictionary.pad(), eos=self.dictionary.eos(),
+                   ds, ds.sizes, self.args.tokens_per_sample,
+                   pad=self.dictionary.pad(), eos=self.dictionary.eos(),
                    break_mode=self.args.sample_break_mode, include_targets=True,
-               ))
+               )
+           )

            print('| {} {} {} examples'.format(self.args.data, split_k, len(loaded_datasets[-1])))
...
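For reference, a minimal sketch of the updated TokenBlockDataset signature, which now takes the underlying sizes explicitly. This mirrors the call in tests/test_train.py further down in this diff; the tensor contents and block size are placeholder values.

import torch
from fairseq.data import TokenBlockDataset

tokens = torch.LongTensor(list(range(100))).view(1, -1)  # one long stream of token ids
blocks = TokenBlockDataset(
    tokens, sizes=[tokens.size(-1)], block_size=16,
    pad=0, eos=1, include_targets=False,
)
print(len(blocks))  # roughly 100 / 16 fixed-size blocks with the default break mode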
@@ -12,8 +12,12 @@ import torch
from fairseq import options
from fairseq.data import (
-   Dictionary, LanguagePairDataset, IndexedInMemoryDataset,
-   IndexedRawTextDataset, RoundRobinZipDatasets,
+   Dictionary,
+   IndexedCachedDataset,
+   IndexedDataset,
+   IndexedRawTextDataset,
+   LanguagePairDataset,
+   RoundRobinZipDatasets,
)

from fairseq.models import FairseqMultiModel
@@ -55,6 +59,8 @@ class MultilingualTranslationTask(FairseqTask):
                            help='source language (only needed for inference)')
        parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
                            help='target language (only needed for inference)')
+       parser.add_argument('--lazy-load', action='store_true',
+                           help='load the dataset lazily')
        parser.add_argument('--raw-text', action='store_true',
                            help='load raw text dataset')
        parser.add_argument('--left-pad-source', default='True', type=str, metavar='BOOL',
@@ -112,15 +118,18 @@ class MultilingualTranslationTask(FairseqTask):
            filename = os.path.join(self.args.data, '{}.{}-{}.{}'.format(split, src, tgt, lang))
            if self.args.raw_text and IndexedRawTextDataset.exists(filename):
                return True
-           elif not self.args.raw_text and IndexedInMemoryDataset.exists(filename):
+           elif not self.args.raw_text and IndexedDataset.exists(filename):
                return True
            return False

        def indexed_dataset(path, dictionary):
            if self.args.raw_text:
                return IndexedRawTextDataset(path, dictionary)
-           elif IndexedInMemoryDataset.exists(path):
-               return IndexedInMemoryDataset(path, fix_lua_indexing=True)
+           elif IndexedDataset.exists(path):
+               if self.args.lazy_load:
+                   return IndexedDataset(path, fix_lua_indexing=True)
+               else:
+                   return IndexedCachedDataset(path, fix_lua_indexing=True)
            return None

        def sort_lang_pair(lang_pair):
...
@@ -6,13 +6,17 @@
# can be found in the PATENTS file in the same directory.

import itertools
-import numpy as np
import os

from fairseq import options, utils
from fairseq.data import (
-   data_utils, Dictionary, LanguagePairDataset, ConcatDataset,
-   IndexedRawTextDataset, IndexedCachedDataset, IndexedDataset
+   ConcatDataset,
+   data_utils,
+   Dictionary,
+   IndexedCachedDataset,
+   IndexedDataset,
+   IndexedRawTextDataset,
+   LanguagePairDataset,
)

from . import FairseqTask, register_task
@@ -49,6 +53,8 @@ class TranslationTask(FairseqTask):
                            help='source language')
        parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
                            help='target language')
+       parser.add_argument('--lazy-load', action='store_true',
+                           help='load the dataset lazily')
        parser.add_argument('--raw-text', action='store_true',
                            help='load raw text dataset')
        parser.add_argument('--left-pad-source', default='True', type=str, metavar='BOOL',
@@ -132,6 +138,9 @@ class TranslationTask(FairseqTask):
            if self.args.raw_text:
                return IndexedRawTextDataset(path, dictionary)
            elif IndexedDataset.exists(path):
+               if self.args.lazy_load:
+                   return IndexedDataset(path, fix_lua_indexing=True)
+               else:
                    return IndexedCachedDataset(path, fix_lua_indexing=True)
            return None
...
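The same --lazy-load selection logic now appears in the language modeling, translation, and multilingual translation tasks. A standalone sketch of the pattern follows; the helper name is illustrative and not part of fairseq.

from fairseq.data import IndexedCachedDataset, IndexedDataset, IndexedRawTextDataset

def open_split(path, dictionary, raw_text=False, lazy_load=False):
    """Pick the indexed-dataset backend for a binarized split."""
    if raw_text:
        return IndexedRawTextDataset(path, dictionary)
    elif IndexedDataset.exists(path):
        if lazy_load:
            # read items from disk on demand
            return IndexedDataset(path, fix_lua_indexing=True)
        # default: cache the whole split in memory up front
        return IndexedCachedDataset(path, fix_lua_indexing=True)
    return None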
@@ -6,10 +6,11 @@
# can be found in the PATENTS file in the same directory.

from collections import Counter
-import os, re
+from multiprocessing import Pool
+import os
+import re

import torch
-from multiprocessing import Pool

SPACE_NORMALIZER = re.compile(r"\s+")
@@ -29,6 +30,7 @@ def safe_readline(f):
        pos -= 1
        f.seek(pos)  # search where this character begins

+
class Tokenizer:

    @staticmethod
@@ -73,14 +75,17 @@ class Tokenizer:
        merge_result(Tokenizer.add_file_to_dictionary_single_worker(filename, tokenize, dict.eos_word))

    @staticmethod
-   def binarize(filename, dict, consumer, tokenize=tokenize_line,
-                append_eos=True, reverse_order=False,
-                offset=0, end=-1):
+   def binarize(
+       filename, dict, consumer, tokenize=tokenize_line, append_eos=True,
+       reverse_order=False, offset=0, end=-1,
+   ):
        nseq, ntok = 0, 0
        replaced = Counter()

        def replaced_consumer(word, idx):
            if idx == dict.unk_index and word != dict.unk_word:
                replaced.update([word])

        with open(filename, 'r') as f:
            f.seek(offset)
            # next(f) breaks f.tell(), hence readline() must be used
...
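A hedged sketch of driving the reworked Tokenizer.binarize() signature directly. The dictionary and corpus paths are placeholders; offset/end let several workers binarize disjoint byte ranges of the same file, which is how preprocess.py uses it with a multiprocessing Pool.

from fairseq.data import dictionary
from fairseq.tokenizer import Tokenizer

vocab = dictionary.Dictionary.load('dict.en.txt')   # placeholder path
tensors = []

stats = Tokenizer.binarize(
    'train.en', vocab, consumer=tensors.append,      # placeholder corpus file
    append_eos=True, reverse_order=False,
    offset=0, end=-1,  # whole file; parallel workers would each get a different range
)
print(stats)  # per-chunk statistics (sequence/token counts, replaced words)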
@@ -30,22 +30,23 @@ class Trainer(object):
    """

    def __init__(self, args, task, model, criterion, dummy_batch, oom_batch=None):
-       if not torch.cuda.is_available():
-           raise NotImplementedError('Training on CPU is not supported')
-
        self.args = args
        self.task = task

        # copy model and criterion to current device
-       self.criterion = criterion.cuda()
+       self.criterion = criterion
+       self._model = model
+       self.cuda = torch.cuda.is_available() and not args.cpu
        if args.fp16:
-           self._model = model.half().cuda()
-       else:
-           self._model = model.cuda()
+           self._model = self._model.half()
+       if self.cuda:
+           self.criterion = self.criterion.cuda()
+           self._model = self._model.cuda()

        self._dummy_batch = dummy_batch
        self._oom_batch = oom_batch

+       self._lr_scheduler = None
        self._num_updates = 0
        self._optim_history = None
        self._optimizer = None
@@ -71,7 +72,6 @@ class Trainer(object):
        self.meters['wall'] = TimeMeter()       # wall time in seconds
        self.meters['train_wall'] = StopwatchMeter()  # train wall time in seconds

    @property
    def model(self):
        if self._wrapped_model is None:
@@ -89,19 +89,26 @@ class Trainer(object):
            self._build_optimizer()
        return self._optimizer

+   @property
+   def lr_scheduler(self):
+       if self._lr_scheduler is None:
+           self._lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)
+       return self._lr_scheduler
+
    def _build_optimizer(self):
+       params = list(filter(lambda p: p.requires_grad, self.model.parameters()))
        if self.args.fp16:
-           if torch.cuda.get_device_capability(0)[0] < 7:
+           if self.cuda and torch.cuda.get_device_capability(0)[0] < 7:
                print('| WARNING: your device does NOT support faster training with --fp16, '
                      'please switch to FP32 which is likely to be faster')
-           params = list(filter(lambda p: p.requires_grad, self.model.parameters()))
-           self._optimizer = optim.FP16Optimizer.build_optimizer(self.args, params)
+           if self.args.memory_efficient_fp16:
+               self._optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer(self.args, params)
+           else:
+               self._optimizer = optim.FP16Optimizer.build_optimizer(self.args, params)
        else:
-           if torch.cuda.get_device_capability(0)[0] >= 7:
+           if self.cuda and torch.cuda.get_device_capability(0)[0] >= 7:
                print('| NOTICE: your device may support faster training with --fp16')
-           self._optimizer = optim.build_optimizer(self.args, self.model.parameters())
-       self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self._optimizer)
+           self._optimizer = optim.build_optimizer(self.args, params)

    def save_checkpoint(self, filename, extra_state):
        """Save all training state in a checkpoint file."""
@@ -151,6 +158,7 @@ class Trainer(object):
        # reproducible results when resuming from checkpoints
        seed = self.args.seed + self.get_num_updates()
        torch.manual_seed(seed)
+       if self.cuda:
            torch.cuda.manual_seed(seed)

        self.model.train()
@@ -296,6 +304,7 @@ class Trainer(object):
                for p in self.model.parameters():
                    if p.grad is not None:
                        del p.grad  # free some memory
+               if self.cuda:
                    torch.cuda.empty_cache()
                return self.valid_step(sample, raise_oom=True)
            else:
@@ -377,4 +386,6 @@ class Trainer(object):
    def _prepare_sample(self, sample):
        if sample is None or len(sample) == 0:
            return None
-       return utils.move_to_cuda(sample)
+       if self.cuda:
+           sample = utils.move_to_cuda(sample)
+       return sample
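The recurring pattern in the Trainer changes above is "guard CUDA-only calls behind a flag instead of requiring a GPU", which is what enables CPU runs (e.g. with the new --cpu option). A minimal, self-contained illustration; the class below is a toy, not fairseq's Trainer.

import torch

class ToyTrainer:
    def __init__(self, model, use_cpu=False):
        # only touch CUDA when it is both available and requested
        self.cuda = torch.cuda.is_available() and not use_cpu
        self.model = model.cuda() if self.cuda else model

    def prepare_sample(self, sample):
        if sample is None or len(sample) == 0:
            return None
        if self.cuda:
            sample = {k: v.cuda() for k, v in sample.items()}
        return sample

trainer = ToyTrainer(torch.nn.Linear(4, 4), use_cpu=True)
print(trainer.prepare_sample({'x': torch.ones(2, 4)}))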
@@ -378,6 +378,14 @@ def item(tensor):
    return tensor


+def clip_grad_norm_(tensor, max_norm):
+    grad_norm = item(torch.norm(tensor))
+    if grad_norm > max_norm > 0:
+        clip_coef = max_norm / (grad_norm + 1e-6)
+        tensor.mul_(clip_coef)
+    return grad_norm
+
+
def fill_with_neg_inf(t):
    """FP16-compatible function that fills a tensor with -inf."""
    return t.float().fill_(float('-inf')).type_as(t)
...
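Usage sketch for the new utils.clip_grad_norm_ helper: unlike torch.nn.utils.clip_grad_norm_, it operates on a single (flattened) gradient tensor, rescales it in place when its norm exceeds max_norm, and returns the pre-clip norm.

import torch
from fairseq import utils

flat_grads = torch.randn(1024)              # stand-in for a flattened gradient buffer
pre_clip = utils.clip_grad_norm_(flat_grads, 0.1)
print(pre_clip, flat_grads.norm().item())   # post-clip norm is roughly 0.1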
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch.fb.rendezvous.zeus # noqa: F401
from fairseq import options
from train import main
if __name__ == '__main__':
parser = options.get_training_parser()
args = options.parse_args_and_arch(parser)
main(args)
@@ -11,7 +11,7 @@ Translate pre-processed data with a trained model.
import torch

-from fairseq import bleu, data, options, progress_bar, tasks, tokenizer, utils
+from fairseq import bleu, options, progress_bar, tasks, tokenizer, utils
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.sequence_generator import SequenceGenerator
from fairseq.sequence_scorer import SequenceScorer
@@ -41,7 +41,9 @@ def main(args):
    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
-   models, _ = utils.load_ensemble_for_inference(args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides))
+   models, _model_args = utils.load_ensemble_for_inference(
+       args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides),
+   )

    # Optimize ensemble for generation
    for model in models:
@@ -69,6 +71,7 @@ def main(args):
        required_batch_size_multiple=8,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
+       num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
...
@@ -75,8 +75,9 @@ def main(args):
    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
-   model_paths = args.path.split(':')
-   models, model_args = utils.load_ensemble_for_inference(model_paths, task, model_arg_overrides=eval(args.model_overrides))
+   models, _model_args = utils.load_ensemble_for_inference(
+       args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides),
+   )

    # Set dictionaries
    tgt_dict = task.target_dictionary
...
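For reference, the load_ensemble_for_inference call now shared by generate.py and interactive.py, shown in isolation. This assumes fairseq's generation parser at this revision; the checkpoint path is whatever --path points at.

from fairseq import options, tasks, utils

parser = options.get_generation_parser()
args = options.parse_args_and_arch(parser)

task = tasks.setup_task(args)
models, _model_args = utils.load_ensemble_for_inference(
    args.path.split(':'),   # colon-separated list of checkpoints forms an ensemble
    task,
    model_arg_overrides=eval(args.model_overrides),
)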
#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import os
import signal
import torch
from fairseq import distributed_utils, options
from train import main as single_process_main
def main(args):
if max(args.update_freq) > 1 and args.ddp_backend != 'no_c10d':
print('| WARNING: when using --update-freq on a single machine, you '
'will get better performance with --ddp-backend=no_c10d')
mp = torch.multiprocessing.get_context('spawn')
# Create a thread to listen for errors in the child processes.
error_queue = mp.SimpleQueue()
error_handler = ErrorHandler(error_queue)
# Train with multiprocessing.
procs = []
base_rank = args.distributed_rank
for i in range(torch.cuda.device_count()):
args.distributed_rank = base_rank + i
args.device_id = i
procs.append(mp.Process(target=run, args=(args, error_queue, ), daemon=True))
procs[i].start()
error_handler.add_child(procs[i].pid)
for p in procs:
p.join()
def run(args, error_queue):
try:
args.distributed_rank = distributed_utils.distributed_init(args)
single_process_main(args)
except KeyboardInterrupt:
pass # killed by parent, do nothing
except Exception:
# propagate exception to parent process, keeping original traceback
import traceback
error_queue.put((args.distributed_rank, traceback.format_exc()))
class ErrorHandler(object):
"""A class that listens for exceptions in children processes and propagates
the tracebacks to the parent process."""
def __init__(self, error_queue):
import signal
import threading
self.error_queue = error_queue
self.children_pids = []
self.error_thread = threading.Thread(target=self.error_listener, daemon=True)
self.error_thread.start()
signal.signal(signal.SIGUSR1, self.signal_handler)
def add_child(self, pid):
self.children_pids.append(pid)
def error_listener(self):
(rank, original_trace) = self.error_queue.get()
self.error_queue.put((rank, original_trace))
os.kill(os.getpid(), signal.SIGUSR1)
def signal_handler(self, signalnum, stackframe):
for pid in self.children_pids:
os.kill(pid, signal.SIGINT) # kill children processes
(rank, original_trace) = self.error_queue.get()
msg = "\n\n-- Tracebacks above this line can probably be ignored --\n\n"
msg += original_trace
raise Exception(msg)
if __name__ == '__main__':
parser = options.get_training_parser()
args = options.parse_args_and_arch(parser)
main(args)
@@ -18,7 +18,7 @@ import shutil
from fairseq.data import indexed_dataset, dictionary
from fairseq.tokenizer import Tokenizer, tokenize_line
-from multiprocessing import Pool, Manager, Process
+from multiprocessing import Pool


def get_parser():
...
@@ -50,6 +50,14 @@ class TestTranslation(unittest.TestCase):
                train_translation_model(data_dir, 'fconv_iwslt_de_en', ['--fp16'])
                generate_main(data_dir)

+   def test_memory_efficient_fp16(self):
+       with contextlib.redirect_stdout(StringIO()):
+           with tempfile.TemporaryDirectory('test_memory_efficient_fp16') as data_dir:
+               create_dummy_data(data_dir)
+               preprocess_translation_data(data_dir)
+               train_translation_model(data_dir, 'fconv_iwslt_de_en', ['--memory-efficient-fp16'])
+               generate_main(data_dir)
+
    def test_update_freq(self):
        with contextlib.redirect_stdout(StringIO()):
            with tempfile.TemporaryDirectory('test_update_freq') as data_dir:
@@ -68,8 +76,7 @@ class TestTranslation(unittest.TestCase):
                        data_dir, 'fconv_iwslt_de_en', ['--max-target-positions', '5'],
                    )
                self.assertTrue(
-                   'skip this example with --skip-invalid-size-inputs-valid-test' \
-                   in str(context.exception)
+                   'skip this example with --skip-invalid-size-inputs-valid-test' in str(context.exception)
                )
                train_translation_model(
                    data_dir, 'fconv_iwslt_de_en',
...
@@ -12,10 +12,6 @@ import os
import tempfile
import unittest

-import torch
-
-from fairseq import options
-
from . import test_binaries
@@ -79,6 +75,12 @@ class TestReproducibility(unittest.TestCase):
            '--fp16-init-scale', '4096',
        ])

+   def test_reproducibility_memory_efficient_fp16(self):
+       self._test_reproducibility('test_reproducibility_memory_efficient_fp16', [
+           '--memory-efficient-fp16',
+           '--fp16-init-scale', '4096',
+       ])
+

if __name__ == '__main__':
    unittest.main()
@@ -39,8 +39,10 @@ def mock_dict():

def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoch):
-   tokens = torch.LongTensor(list(range(epoch_size)))
-   tokens_ds = data.TokenBlockDataset(tokens, sizes=[len(tokens)], block_size=1, pad=0, eos=1, include_targets=False)
+   tokens = torch.LongTensor(list(range(epoch_size))).view(1, -1)
+   tokens_ds = data.TokenBlockDataset(
+       tokens, sizes=[tokens.size(-1)], block_size=1, pad=0, eos=1, include_targets=False,
+   )
    trainer = mock_trainer(epoch, num_updates, iterations_in_epoch)
    dataset = data.LanguagePairDataset(tokens_ds, tokens_ds.sizes, mock_dict(), shuffle=False)
    epoch_itr = data.EpochBatchIterator(
@@ -64,7 +66,6 @@ class TestLoadCheckpoint(unittest.TestCase):
        self.applied_patches = [patch(p, d) for p, d in self.patches.items()]
        [p.start() for p in self.applied_patches]

    def test_load_partial_checkpoint(self):
        with contextlib.redirect_stdout(StringIO()):
            trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 200, 50)
...
@@ -28,8 +28,7 @@ def main(args):
        args.max_tokens = 6000
    print(args)

-   if not torch.cuda.is_available():
-       raise NotImplementedError('Training on CPU is not supported')
+   if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    torch.manual_seed(args.seed)
@@ -74,6 +73,7 @@ def main(args):
        seed=args.seed,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
+       num_workers=args.num_workers,
    )

    # Load the latest checkpoint if one is available
@@ -211,6 +211,7 @@ def validate(args, trainer, task, epoch_itr, subsets):
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
+           num_workers=args.num_workers,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.build_progress_bar(
            args, itr, epoch_itr.epoch,
@@ -306,6 +307,14 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
        # remove old checkpoints; checkpoints are sorted in descending order
        checkpoints = utils.checkpoint_paths(args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt')
        for old_chk in checkpoints[args.keep_interval_updates:]:
+           if os.path.lexists(old_chk):
+               os.remove(old_chk)
+
+   if args.keep_last_epochs > 0:
+       # remove old epoch checkpoints; checkpoints are sorted in descending order
+       checkpoints = utils.checkpoint_paths(args.save_dir, pattern=r'checkpoint\d+\.pt')
+       for old_chk in checkpoints[args.keep_last_epochs:]:
+           if os.path.lexists(old_chk):
                os.remove(old_chk)
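A standalone sketch of the new --keep-last-epochs pruning; the wrapper function is illustrative. utils.checkpoint_paths returns matching checkpoints sorted newest first, so everything past the last N is deleted.

import os
from fairseq import utils

def prune_epoch_checkpoints(save_dir, keep_last_epochs):
    """Delete epoch checkpoints beyond the most recent `keep_last_epochs`."""
    if keep_last_epochs <= 0:
        return
    checkpoints = utils.checkpoint_paths(save_dir, pattern=r'checkpoint\d+\.pt')
    for old_chk in checkpoints[keep_last_epochs:]:
        if os.path.lexists(old_chk):
            os.remove(old_chk)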
@@ -346,23 +355,50 @@ def load_dataset_splits(task, splits):
            raise e


+def distributed_main(i, args):
+    import socket
+    args.device_id = i
+    if args.distributed_rank is None:  # torch.multiprocessing.spawn
+        args.distributed_rank = i
+    args.distributed_rank = distributed_utils.distributed_init(args)
+    print('| initialized host {} as rank {}'.format(socket.gethostname(), args.distributed_rank))
+    main(args)
+
+
if __name__ == '__main__':
    parser = options.get_training_parser()
    args = options.parse_args_and_arch(parser)

-   if args.distributed_port > 0 or args.distributed_init_method is not None:
-       from distributed_train import main as distributed_main
-       distributed_main(args)
+   if args.distributed_init_method is None:
+       distributed_utils.infer_init_method(args)
+
+   if args.distributed_init_method is not None:
+       # distributed training
+       distributed_main(args.device_id, args)
+       args.distributed_rank = distributed_utils.distributed_init(args)
+       main(args)
    elif args.distributed_world_size > 1:
-       from multiprocessing_train import main as multiprocessing_main
-       # Set distributed training parameters for a single node.
-       args.distributed_world_size = torch.cuda.device_count()
+       # fallback for single node with multiple GPUs
        port = random.randint(10000, 20000)
        args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port)
-       args.distributed_port = port + 1
-       multiprocessing_main(args)
+       args.distributed_rank = None  # set based on device id
+       print(
+           '''| NOTE: you may get better performance with:
+
+    python -m torch.distributed.launch --nproc_per_node {ngpu} train.py {no_c10d}(...)
+           '''.format(
+               ngpu=args.distributed_world_size,
+               no_c10d=(
+                   '--ddp-backend=no_c10d ' if max(args.update_freq) > 1 and args.ddp_backend != 'no_c10d'
+                   else ''
+               ),
+           )
+       )
+       torch.multiprocessing.spawn(
+           fn=distributed_main,
+           args=(args, ),
+           nprocs=args.distributed_world_size,
+       )
    else:
+       # single GPU training
        main(args)
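The single-node fallback above replaces multiprocessing_train.py with torch.multiprocessing.spawn. A generic, self-contained sketch of that launch pattern follows; the worker is a placeholder and the example assumes at least one visible GPU.

import torch
import torch.multiprocessing as mp

def worker(device_id, world_size):
    # spawn passes the process index as the first argument, like distributed_main(i, args)
    torch.cuda.set_device(device_id)
    print('| worker {} of {} ready on cuda:{}'.format(device_id, world_size, device_id))
    # ... initialize the process group and run training here ...

if __name__ == '__main__':
    n = torch.cuda.device_count()
    mp.spawn(fn=worker, args=(n,), nprocs=n)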