Unverified Commit 388c520b authored by Myle Ott, committed by GitHub

0.4.0 -> 0.5.0

Changelog:
- 97b58b46: add Transformer model from Vaswani et al. (2017)
- b2374e52: faster Transformer inference with improved caching
- 2d27ae08: simulate large mini-batch training with delayed updates (`--update-freq`)
- 7ee1d284: add FP16 training support (`--fp16`)
- 2a84f46b: faster inference by removing completed sentences from the batch
- 663fd806: batched interactive generation
- 4c2ef2de: add language modeling / gated convolutional model from Dauphin et al. (2017)
- b59815bc: add Hierarchical Neural Story Generation model from Fan et al. (2018)
- ff68a9ef: add FairseqTask to modularize task definitions (e.g., translation, language modeling)
parents ec0031df 5383b5db
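The new 0.5.0 features listed above can be exercised end to end from Python, the same way the updated tests in this commit drive them. A minimal sketch (not part of the commit) follows; the data directory paths and the flag values are placeholders:

    # Sketch only: programmatic training invocation mirroring tests/test_binaries.py.
    # 'data-bin/placeholder' stands in for a directory produced by preprocess.py.
    import train
    from fairseq import options

    parser = options.get_training_parser()
    train_args = options.parse_args_and_arch(parser, [
        '--task', 'translation',
        'data-bin/placeholder',
        '--arch', 'transformer_iwslt_de_en',  # new Transformer architecture
        '--optimizer', 'nag', '--lr', '0.05',
        '--max-tokens', '500',
        '--max-epoch', '1',
        '--fp16',                 # new: mixed-precision training
        '--update-freq', '3',     # new: accumulate gradients over 3 mini-batches
        '--save-dir', 'checkpoints/placeholder',
        '--no-progress-bar',
        '--distributed-world-size', '1',
    ])
    train.main(train_args)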
@@ -12,7 +12,7 @@ from itertools import zip_longest
 import os
 import shutil

-from fairseq import dictionary, indexed_dataset
+from fairseq.data import indexed_dataset, dictionary
 from fairseq.tokenizer import Tokenizer, tokenize_line
@@ -38,6 +38,8 @@ def get_parser():
                         help='output format (optional)')
     parser.add_argument('--joined-dictionary', action='store_true', help='Generate joined dictionary')
     parser.add_argument('--only-source', action='store_true', help='Only process the source language')
+    parser.add_argument('--padding-factor', metavar='N', default=8, type=int,
+                        help='Pad dictionary size to be multiple of N')
     return parser
@@ -46,65 +48,91 @@ def main(args):
     os.makedirs(args.destdir, exist_ok=True)
     target = not args.only_source

+    def build_dictionary(filenames):
+        d = dictionary.Dictionary()
+        for filename in filenames:
+            Tokenizer.add_file_to_dictionary(filename, d, tokenize_line)
+        return d
+
+    def train_path(lang):
+        return '{}{}'.format(args.trainpref, ('.' + lang) if lang else '')
+
+    def file_name(prefix, lang):
+        fname = prefix
+        if lang is not None:
+            fname += f'.{lang}'
+        return fname
+
+    def dest_path(prefix, lang):
+        return os.path.join(args.destdir, file_name(prefix, lang))
+
+    def dict_path(lang):
+        return dest_path('dict', lang) + '.txt'
+
+    def dataset_dest_path(output_prefix, lang, extension):
+        base = f'{args.destdir}/{output_prefix}'
+        lang_part = f'.{args.source_lang}-{args.target_lang}.{lang}' if lang is not None else ''
+        return f'{base}{lang_part}.{extension}'
+
     if args.joined_dictionary:
         assert not args.srcdict, 'cannot combine --srcdict and --joined-dictionary'
         assert not args.tgtdict, 'cannot combine --tgtdict and --joined-dictionary'
-        src_dict = dictionary.Dictionary()
-        for lang in [args.source_lang, args.target_lang]:
-            Tokenizer.add_file_to_dictionary(
-                filename='{}.{}'.format(args.trainpref, lang),
-                dict=src_dict,
-                tokenize=tokenize_line,
-            )
-        src_dict.finalize()
+        src_dict = build_dictionary(set([
+            train_path(lang)
+            for lang in [args.source_lang, args.target_lang]
+        ]))
         tgt_dict = src_dict
     else:
         if args.srcdict:
             src_dict = dictionary.Dictionary.load(args.srcdict)
         else:
             assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
-            src_dict = Tokenizer.build_dictionary(filename='{}.{}'.format(args.trainpref, args.source_lang))
+            src_dict = build_dictionary([train_path(args.source_lang)])
         if target:
             if args.tgtdict:
                 tgt_dict = dictionary.Dictionary.load(args.tgtdict)
             else:
                 assert args.trainpref, "--trainpref must be set if --tgtdict is not specified"
-                tgt_dict = Tokenizer.build_dictionary(filename='{}.{}'.format(args.trainpref, args.target_lang))
+                tgt_dict = build_dictionary([train_path(args.target_lang)])

-    src_dict.save(os.path.join(args.destdir, 'dict.{}.txt'.format(args.source_lang)),
-                  threshold=args.thresholdsrc, nwords=args.nwordssrc)
+    src_dict.finalize(
+        threshold=args.thresholdsrc,
+        nwords=args.nwordssrc,
+        padding_factor=args.padding_factor,
+    )
+    src_dict.save(dict_path(args.source_lang))
     if target:
-        tgt_dict.save(os.path.join(args.destdir, 'dict.{}.txt'.format(args.target_lang)),
-                      threshold=args.thresholdtgt, nwords=args.nwordstgt)
+        if not args.joined_dictionary:
+            tgt_dict.finalize(
+                threshold=args.thresholdtgt,
+                nwords=args.nwordstgt,
+                padding_factor=args.padding_factor,
+            )
+        tgt_dict.save(dict_path(args.target_lang))

     def make_binary_dataset(input_prefix, output_prefix, lang):
-        dict = dictionary.Dictionary.load(os.path.join(args.destdir, 'dict.{}.txt'.format(lang)))
+        dict = dictionary.Dictionary.load(dict_path(lang))
         print('| [{}] Dictionary: {} types'.format(lang, len(dict) - 1))

-        ds = indexed_dataset.IndexedDatasetBuilder(
-            '{}/{}.{}-{}.{}.bin'.format(args.destdir, output_prefix, args.source_lang,
-                                        args.target_lang, lang)
-        )
+        ds = indexed_dataset.IndexedDatasetBuilder(dataset_dest_path(output_prefix, lang, 'bin'))

         def consumer(tensor):
             ds.add_item(tensor)

-        input_file = '{}.{}'.format(input_prefix, lang)
+        input_file = '{}{}'.format(input_prefix, ('.' + lang) if lang is not None else '')
         res = Tokenizer.binarize(input_file, dict, consumer)
         print('| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}'.format(
             lang, input_file, res['nseq'], res['ntok'],
             100 * res['nunk'] / res['ntok'], dict.unk_word))
-        ds.finalize('{}/{}.{}-{}.{}.idx'.format(
-            args.destdir, output_prefix,
-            args.source_lang, args.target_lang, lang))
+        ds.finalize(dataset_dest_path(output_prefix, lang, 'idx'))

     def make_dataset(input_prefix, output_prefix, lang, output_format='binary'):
         if output_format == 'binary':
             make_binary_dataset(input_prefix, output_prefix, lang)
         elif output_format == 'raw':
             # Copy original text file to destination folder
-            output_text_file = os.path.join(args.destdir, '{}.{}'.format(output_prefix, lang))
-            shutil.copyfile('{}.{}'.format(input_prefix, lang), output_text_file)
+            output_text_file = dest_path(output_prefix, lang)
+            shutil.copyfile(file_name(input_prefix, lang), output_text_file)

     def make_all(args, make_dataset, lang):
         if args.trainpref:
@@ -126,10 +154,10 @@ def main(args):
     if args.alignfile:
         assert args.trainpref, "--trainpref must be set if --alignfile is specified"
-        src_file_name = '{}.{}'.format(args.trainpref, args.source_lang)
-        tgt_file_name = '{}.{}'.format(args.trainpref, args.target_lang)
-        src_dict = dictionary.Dictionary.load(os.path.join(args.destdir, 'dict.{}.txt'.format(args.source_lang)))
-        tgt_dict = dictionary.Dictionary.load(os.path.join(args.destdir, 'dict.{}.txt'.format(args.target_lang)))
+        src_file_name = train_path(args.source_lang)
+        tgt_file_name = train_path(args.target_lang)
+        src_dict = dictionary.Dictionary.load(dict_path(args.source_lang))
+        tgt_dict = dictionary.Dictionary.load(dict_path(args.target_lang))
         freq_map = {}
         with open(args.alignfile, 'r') as align_file:
             with open(src_file_name, 'r') as src_file:
...
@@ -11,7 +11,8 @@ import argparse
 import os
 import sys

-from fairseq import bleu, dictionary, tokenizer
+from fairseq import bleu, tokenizer
+from fairseq.data import dictionary


 def main():
...
@@ -3,6 +3,8 @@
 import argparse
 import collections
 import torch
+import os
+import re


 def average_checkpoints(inputs):
@@ -44,7 +46,10 @@ def average_checkpoints(inputs):
         for k in params_keys:
             if k not in params_dict:
                 params_dict[k] = []
-            params_dict[k].append(model_params[k])
+            p = model_params[k]
+            if isinstance(p, torch.HalfTensor):
+                p = p.float()
+            params_dict[k].append(p)

     averaged_params = collections.OrderedDict()
     # v should be a list of torch Tensor.
@@ -57,10 +62,29 @@ def average_checkpoints(inputs):
     return new_state


+def last_n_checkpoints(paths, n, update_based):
+    assert len(paths) == 1
+    path = paths[0]
+    if update_based:
+        pt_regexp = re.compile(r'checkpoint_\d+_(\d+)\.pt')
+    else:
+        pt_regexp = re.compile(r'checkpoint(\d+)\.pt')
+    files = os.listdir(path)
+
+    entries = []
+    for f in files:
+        m = pt_regexp.fullmatch(f)
+        if m is not None:
+            entries.append((int(m.group(1)), m.group(0)))
+    if len(entries) < n:
+        raise Exception('Found {} checkpoint files but need at least {}'.format(len(entries), n))
+    return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)[:n]]
+
+
 def main():
     parser = argparse.ArgumentParser(
         description='Tool to average the params of input checkpoints to '
         'produce a new checkpoint',
     )
     parser.add_argument(
@@ -74,11 +98,36 @@ def main():
         required=True,
         metavar='FILE',
         help='Write the new checkpoint containing the averaged weights to this '
        'path.',
     )
+    num_group = parser.add_mutually_exclusive_group()
+    num_group.add_argument(
+        '--num-epoch-checkpoints',
+        type=int,
+        help='if set, will try to find checkpoints with names checkpoint_xx.pt in the path specified by input, '
+        'and average last this many of them.',
+    )
+    num_group.add_argument(
+        '--num-update-checkpoints',
+        type=int,
+        help='if set, will try to find checkpoints with names checkpoint_ee_xx.pt in the path specified by input, '
+        'and average last this many of them.',
+    )
     args = parser.parse_args()
     print(args)
+
+    num = None
+    is_update_based = False
+    if args.num_update_checkpoints is not None:
+        num = args.num_update_checkpoints
+        is_update_based = True
+    elif args.num_epoch_checkpoints is not None:
+        num = args.num_epoch_checkpoints
+    if num is not None:
+        args.inputs = last_n_checkpoints(args.inputs, num, is_update_based)
+        print('averaging checkpoints: ', args.inputs)
+
     new_state = average_checkpoints(args.inputs)
     torch.save(new_state, args.output)
     print('Finished writing averaged checkpoint to {}.'.format(args.output))
...
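As a usage note (not part of the commit), the new helpers above can also be driven directly from Python. This is a sketch only: it assumes the script is importable as a module named average_checkpoints, and the checkpoint directory is a placeholder.

    # Average the 5 most recent epoch checkpoints (checkpointN.pt) from a
    # placeholder save directory and write the result to a new file.
    import torch
    from average_checkpoints import average_checkpoints, last_n_checkpoints

    inputs = last_n_checkpoints(['checkpoints/placeholder'], n=5, update_based=False)
    new_state = average_checkpoints(inputs)
    torch.save(new_state, 'checkpoints/placeholder/averaged.pt')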
@@ -35,7 +35,7 @@ bleu = Extension(
 setup(
     name='fairseq',
-    version='0.4.0',
+    version='0.5.0',
     description='Facebook AI Research Sequence-to-Sequence Toolkit',
     long_description=readme,
     license=license,
...
#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import collections
import itertools
import os
import math
import torch
from fairseq import criterions, data, models, options, progress_bar
from fairseq.meters import AverageMeter, StopwatchMeter
from fairseq.trainer import Trainer
def main(args):
print(args)
if not torch.cuda.is_available():
raise NotImplementedError('Training on CPU is not supported')
torch.cuda.set_device(args.device_id)
torch.manual_seed(args.seed)
# Load dataset
splits = ['train', 'valid']
if data.has_binary_files(args.data, splits):
dataset = data.load_dataset(
args.data, splits, args.source_lang, args.target_lang)
else:
dataset = data.load_raw_text_dataset(
args.data, splits, args.source_lang, args.target_lang)
if args.source_lang is None or args.target_lang is None:
# record inferred languages in args, so that it's saved in checkpoints
args.source_lang, args.target_lang = dataset.src, dataset.dst
print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
for split in splits:
print('| {} {} {} examples'.format(args.data, split, len(dataset.splits[split])))
# Build model and criterion
model = models.build_model(args, dataset.src_dict, dataset.dst_dict)
criterion = criterions.build_criterion(args, dataset.src_dict, dataset.dst_dict)
print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
print('| num. model params: {}'.format(sum(p.data.numel() for p in model.parameters())))
# Build trainer
trainer = Trainer(args, model, criterion)
print('| training on {} GPUs'.format(args.distributed_world_size))
print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
args.max_tokens,
args.max_sentences,
))
# Load the latest checkpoint if one is available
os.makedirs(args.save_dir, exist_ok=True)
checkpoint_path = os.path.join(args.save_dir, args.restore_file)
extra_state = trainer.load_checkpoint(checkpoint_path)
if extra_state is not None:
epoch = extra_state['epoch']
batch_offset = extra_state['batch_offset']
print('| loaded checkpoint {} (epoch {})'.format(checkpoint_path, epoch))
if batch_offset == 0:
trainer.lr_step(epoch)
epoch += 1
else:
epoch, batch_offset = 1, 0
# Train until the learning rate gets too small
max_epoch = args.max_epoch or math.inf
max_update = args.max_update or math.inf
lr = trainer.get_lr()
train_meter = StopwatchMeter()
train_meter.start()
while lr > args.min_lr and epoch <= max_epoch:
# train for one epoch
train(args, trainer, dataset, epoch, batch_offset)
# evaluate on validate set
if epoch % args.validate_interval == 0:
for k, subset in enumerate(args.valid_subset.split(',')):
val_loss = validate(args, trainer, dataset, subset, epoch)
if k == 0:
# only use first validation loss to update the learning schedule
lr = trainer.lr_step(epoch, val_loss)
# save checkpoint
if not args.no_save:
save_checkpoint(trainer, args, epoch, 0, val_loss)
else:
lr = trainer.lr_step(epoch)
epoch += 1
batch_offset = 0
if trainer.get_num_updates() >= max_update:
break
train_meter.stop()
print('| done training in {:.1f} seconds'.format(train_meter.sum))
def train(args, trainer, dataset, epoch, batch_offset):
"""Train the model for one epoch."""
# Set seed based on args.seed and the epoch number so that we get
# reproducible results when resuming from checkpoints
seed = args.seed + epoch
torch.manual_seed(seed)
# The max number of positions can be different for train and valid
# e.g., RNNs may support more positions at test time than seen in training
max_positions_train = (
min(args.max_source_positions, trainer.get_model().max_encoder_positions()),
min(args.max_target_positions, trainer.get_model().max_decoder_positions())
)
# Initialize dataloader, starting at batch_offset
itr = dataset.train_dataloader(
args.train_subset,
max_tokens=args.max_tokens,
max_sentences=args.max_sentences,
max_positions=max_positions_train,
seed=seed,
epoch=epoch,
sample_without_replacement=args.sample_without_replacement,
sort_by_source_size=(epoch <= args.curriculum),
shard_id=args.distributed_rank,
num_shards=args.distributed_world_size,
)
progress = progress_bar.build_progress_bar(args, itr, epoch, no_progress_bar='simple')
itr = itertools.islice(progress, batch_offset, None)
# reset training meters
for k in ['train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'clip']:
meter = trainer.get_meter(k)
if meter is not None:
meter.reset()
extra_meters = collections.defaultdict(lambda: AverageMeter())
max_update = args.max_update or math.inf
for i, sample in enumerate(itr, start=batch_offset):
log_output = trainer.train_step(sample)
# log mid-epoch stats
stats = get_training_stats(trainer)
for k, v in log_output.items():
if k in ['loss', 'nll_loss']:
continue # these are already logged above
if 'loss' in k:
extra_meters[k].update(v, log_output['sample_size'])
else:
extra_meters[k].update(v)
stats[k] = extra_meters[k].avg
progress.log(stats)
# save mid-epoch checkpoints
if i == batch_offset:
# ignore the first mini-batch in words-per-second calculation
trainer.get_meter('wps').reset()
# save mid-epoch checkpoints
num_updates = trainer.get_num_updates()
if args.save_interval > 0 and num_updates > 0 and num_updates % args.save_interval == 0:
save_checkpoint(trainer, args, epoch, i + 1)
if num_updates >= max_update:
break
# log end-of-epoch stats
stats = get_training_stats(trainer)
for k, meter in extra_meters.items():
stats[k] = meter.avg
progress.print(stats)
def get_training_stats(trainer):
stats = collections.OrderedDict()
stats['loss'] = '{:.3f}'.format(trainer.get_meter('train_loss').avg)
if trainer.get_meter('train_nll_loss').count > 0:
nll_loss = trainer.get_meter('train_nll_loss').avg
stats['nll_loss'] = '{:.3f}'.format(nll_loss)
else:
nll_loss = trainer.get_meter('train_loss').avg
stats['ppl'] = get_perplexity(nll_loss)
stats['wps'] = round(trainer.get_meter('wps').avg)
stats['ups'] = '{:.1f}'.format(trainer.get_meter('ups').avg)
stats['wpb'] = round(trainer.get_meter('wpb').avg)
stats['bsz'] = round(trainer.get_meter('bsz').avg)
stats['num_updates'] = trainer.get_num_updates()
stats['lr'] = trainer.get_lr()
stats['gnorm'] = '{:.3f}'.format(trainer.get_meter('gnorm').avg)
stats['clip'] = '{:.0%}'.format(trainer.get_meter('clip').avg)
stats['oom'] = trainer.get_meter('oom').avg
return stats
def validate(args, trainer, dataset, subset, epoch):
"""Evaluate the model on the validation set and return the average loss."""
# Initialize dataloader
max_positions_valid = (
trainer.get_model().max_encoder_positions(),
trainer.get_model().max_decoder_positions(),
)
itr = dataset.eval_dataloader(
subset,
max_tokens=args.max_tokens,
max_sentences=args.max_sentences_valid,
max_positions=max_positions_valid,
skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
descending=True, # largest batch first to warm the caching allocator
shard_id=args.distributed_rank,
num_shards=args.distributed_world_size,
)
progress = progress_bar.build_progress_bar(
args, itr, epoch,
prefix='valid on \'{}\' subset'.format(subset),
no_progress_bar='simple'
)
# reset validation loss meters
for k in ['valid_loss', 'valid_nll_loss']:
meter = trainer.get_meter(k)
if meter is not None:
meter.reset()
extra_meters = collections.defaultdict(lambda: AverageMeter())
for sample in progress:
log_output = trainer.valid_step(sample)
# log mid-validation stats
stats = get_valid_stats(trainer)
for k, v in log_output.items():
if k in ['loss', 'nll_loss']:
continue
extra_meters[k].update(v)
stats[k] = extra_meters[k].avg
progress.log(stats)
# log validation stats
stats = get_valid_stats(trainer)
for k, meter in extra_meters.items():
stats[k] = meter.avg
progress.print(stats)
return stats['valid_loss']
def get_valid_stats(trainer):
stats = collections.OrderedDict()
stats['valid_loss'] = trainer.get_meter('valid_loss').avg
if trainer.get_meter('valid_nll_loss').count > 0:
nll_loss = trainer.get_meter('valid_nll_loss').avg
stats['valid_nll_loss'] = nll_loss
else:
nll_loss = trainer.get_meter('valid_loss').avg
stats['valid_ppl'] = get_perplexity(nll_loss)
return stats
def get_perplexity(loss):
try:
return '{:.2f}'.format(math.pow(2, loss))
except OverflowError:
return float('inf')
def save_checkpoint(trainer, args, epoch, batch_offset, val_loss=None):
extra_state = {
'epoch': epoch,
'batch_offset': batch_offset,
'val_loss': val_loss,
}
if batch_offset == 0:
if not args.no_epoch_checkpoints:
epoch_filename = os.path.join(args.save_dir, 'checkpoint{}.pt'.format(epoch))
trainer.save_checkpoint(epoch_filename, extra_state)
assert val_loss is not None
if not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best:
save_checkpoint.best = val_loss
best_filename = os.path.join(args.save_dir, 'checkpoint_best.pt')
trainer.save_checkpoint(best_filename, extra_state)
elif not args.no_epoch_checkpoints:
epoch_filename = os.path.join(
args.save_dir, 'checkpoint{}_{}.pt'.format(epoch, batch_offset))
trainer.save_checkpoint(epoch_filename, extra_state)
last_filename = os.path.join(args.save_dir, 'checkpoint_last.pt')
trainer.save_checkpoint(last_filename, extra_state)
if __name__ == '__main__':
parser = options.get_training_parser()
args = options.parse_args_and_arch(parser)
main(args)
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
# the root directory of this source tree. An additional grant of patent rights # the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory. # can be found in the PATENTS file in the same directory.
import contextlib
from io import StringIO from io import StringIO
import os import os
import random import random
...@@ -20,102 +21,241 @@ import preprocess ...@@ -20,102 +21,241 @@ import preprocess
import train import train
import generate import generate
import interactive import interactive
import eval_lm
class TestBinaries(unittest.TestCase): class TestTranslation(unittest.TestCase):
def test_binaries(self): def test_fconv(self):
# comment this out to debug the unittest if it's failing with contextlib.redirect_stdout(StringIO()):
self.mock_stdout() with tempfile.TemporaryDirectory('test_fconv') as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(data_dir)
train_translation_model(data_dir, 'fconv_iwslt_de_en')
generate_main(data_dir)
with tempfile.TemporaryDirectory() as data_dir: def test_fp16(self):
self.create_dummy_data(data_dir) with contextlib.redirect_stdout(StringIO()):
self.preprocess_data(data_dir) with tempfile.TemporaryDirectory('test_fp16') as data_dir:
self.train_model(data_dir) create_dummy_data(data_dir)
self.generate(data_dir) preprocess_translation_data(data_dir)
train_translation_model(data_dir, 'fconv_iwslt_de_en', ['--fp16'])
generate_main(data_dir)
self.unmock_stdout() def test_update_freq(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_update_freq') as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(data_dir)
train_translation_model(data_dir, 'fconv_iwslt_de_en', ['--update-freq', '3'])
generate_main(data_dir)
def create_dummy_data(self, data_dir, num_examples=1000, maxlen=20): def test_lstm(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_lstm') as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(data_dir)
train_translation_model(data_dir, 'lstm_wiseman_iwslt_de_en', [
'--encoder-layers', '2',
'--decoder-layers', '2',
])
generate_main(data_dir)
def _create_dummy_data(filename): def test_lstm_bidirectional(self):
data = torch.rand(num_examples * maxlen) with contextlib.redirect_stdout(StringIO()):
data = 97 + torch.floor(26 * data).int() with tempfile.TemporaryDirectory('test_lstm_bidirectional') as data_dir:
with open(os.path.join(data_dir, filename), 'w') as h: create_dummy_data(data_dir)
offset = 0 preprocess_translation_data(data_dir)
for _ in range(num_examples): train_translation_model(data_dir, 'lstm', [
ex_len = random.randint(1, maxlen) '--encoder-layers', '2',
ex_str = ' '.join(map(chr, data[offset:offset+ex_len])) '--encoder-bidirectional',
print(ex_str, file=h) '--encoder-hidden-size', '256',
offset += ex_len '--decoder-layers', '2',
])
generate_main(data_dir)
_create_dummy_data('train.in') def test_transformer(self):
_create_dummy_data('train.out') with contextlib.redirect_stdout(StringIO()):
_create_dummy_data('valid.in') with tempfile.TemporaryDirectory('test_transformer') as data_dir:
_create_dummy_data('valid.out') create_dummy_data(data_dir)
_create_dummy_data('test.in') preprocess_translation_data(data_dir)
_create_dummy_data('test.out') train_translation_model(data_dir, 'transformer_iwslt_de_en')
generate_main(data_dir)
def preprocess_data(self, data_dir):
preprocess_parser = preprocess.get_parser() class TestStories(unittest.TestCase):
preprocess_args = preprocess_parser.parse_args([
def test_fconv_self_att_wp(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_fconv_self_att_wp') as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(data_dir)
config = [
'--encoder-layers', '[(512, 3)] * 2',
'--decoder-layers', '[(512, 3)] * 2',
'--decoder-attention', 'True',
'--encoder-attention', 'False',
'--gated-attention', 'True',
'--self-attention', 'True',
'--project-input', 'True',
]
train_translation_model(data_dir, 'fconv_self_att_wp', config)
generate_main(data_dir)
# fusion model
os.rename(os.path.join(data_dir, 'checkpoint_last.pt'), os.path.join(data_dir, 'pretrained.pt'))
config.extend([
'--pretrained', 'True',
'--pretrained-checkpoint', os.path.join(data_dir, 'pretrained.pt'),
'--save-dir', os.path.join(data_dir, 'fusion_model'),
])
train_translation_model(data_dir, 'fconv_self_att_wp', config)
class TestLanguageModeling(unittest.TestCase):
def test_fconv_lm(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_fconv_lm') as data_dir:
create_dummy_data(data_dir)
preprocess_lm_data(data_dir)
train_language_model(data_dir, 'fconv_lm')
eval_lm_main(data_dir)
def create_dummy_data(data_dir, num_examples=1000, maxlen=20):
def _create_dummy_data(filename):
data = torch.rand(num_examples * maxlen)
data = 97 + torch.floor(26 * data).int()
with open(os.path.join(data_dir, filename), 'w') as h:
offset = 0
for _ in range(num_examples):
ex_len = random.randint(1, maxlen)
ex_str = ' '.join(map(chr, data[offset:offset+ex_len]))
print(ex_str, file=h)
offset += ex_len
_create_dummy_data('train.in')
_create_dummy_data('train.out')
_create_dummy_data('valid.in')
_create_dummy_data('valid.out')
_create_dummy_data('test.in')
_create_dummy_data('test.out')
def preprocess_translation_data(data_dir):
preprocess_parser = preprocess.get_parser()
preprocess_args = preprocess_parser.parse_args([
'--source-lang', 'in',
'--target-lang', 'out',
'--trainpref', os.path.join(data_dir, 'train'),
'--validpref', os.path.join(data_dir, 'valid'),
'--testpref', os.path.join(data_dir, 'test'),
'--thresholdtgt', '0',
'--thresholdsrc', '0',
'--destdir', data_dir,
])
preprocess.main(preprocess_args)
def train_translation_model(data_dir, arch, extra_flags=None):
train_parser = options.get_training_parser()
train_args = options.parse_args_and_arch(
train_parser,
[
'--task', 'translation',
data_dir,
'--save-dir', data_dir,
'--arch', arch,
'--optimizer', 'nag',
'--lr', '0.05',
'--max-tokens', '500',
'--max-epoch', '1',
'--no-progress-bar',
'--distributed-world-size', '1',
'--source-lang', 'in', '--source-lang', 'in',
'--target-lang', 'out', '--target-lang', 'out',
'--trainpref', os.path.join(data_dir, 'train'), ] + (extra_flags or []),
'--validpref', os.path.join(data_dir, 'valid'), )
'--testpref', os.path.join(data_dir, 'test'), train.main(train_args)
'--thresholdtgt', '0',
'--thresholdsrc', '0',
'--destdir', data_dir, def generate_main(data_dir):
]) generate_parser = options.get_generation_parser()
preprocess.main(preprocess_args) generate_args = options.parse_args_and_arch(
generate_parser,
def train_model(self, data_dir): [
train_parser = options.get_training_parser()
train_args = options.parse_args_and_arch(
train_parser,
[
data_dir,
'--arch', 'fconv_iwslt_de_en',
'--optimizer', 'nag',
'--lr', '0.05',
'--max-tokens', '500',
'--save-dir', data_dir,
'--max-epoch', '1',
'--no-progress-bar',
'--distributed-world-size', '1',
],
)
train.main(train_args)
def generate(self, data_dir):
generate_parser = options.get_generation_parser()
generate_args = generate_parser.parse_args([
data_dir, data_dir,
'--path', os.path.join(data_dir, 'checkpoint_best.pt'), '--path', os.path.join(data_dir, 'checkpoint_last.pt'),
'--beam', '5', '--beam', '3',
'--batch-size', '32', '--batch-size', '64',
'--max-len-b', '5',
'--gen-subset', 'valid', '--gen-subset', 'valid',
'--no-progress-bar', '--no-progress-bar',
]) ],
)
# evaluate model in batch mode
generate.main(generate_args)
# evaluate model interactively
generate_args.buffer_size = 0
generate_args.max_sentences = None
orig_stdin = sys.stdin
sys.stdin = StringIO('h e l l o\n')
interactive.main(generate_args)
sys.stdin = orig_stdin
# evaluate model in batch mode
generate.main(generate_args)
# evaluate model interactively def preprocess_lm_data(data_dir):
generate_args.max_sentences = None preprocess_parser = preprocess.get_parser()
orig_stdin = sys.stdin preprocess_args = preprocess_parser.parse_args([
sys.stdin = StringIO('h e l l o\n') '--only-source',
interactive.main(generate_args) '--trainpref', os.path.join(data_dir, 'train.out'),
sys.stdin = orig_stdin '--validpref', os.path.join(data_dir, 'valid.out'),
'--testpref', os.path.join(data_dir, 'test.out'),
'--destdir', data_dir,
])
preprocess.main(preprocess_args)
def mock_stdout(self):
self._orig_stdout = sys.stdout
sys.stdout = StringIO()
def unmock_stdout(self): def train_language_model(data_dir, arch):
if hasattr(self, '_orig_stdout'): train_parser = options.get_training_parser()
sys.stdout = self._orig_stdout train_args = options.parse_args_and_arch(
train_parser,
[
'--task', 'language_modeling',
data_dir,
'--arch', arch,
'--optimizer', 'nag',
'--lr', '1.0',
'--criterion', 'adaptive_loss',
'--adaptive-softmax-cutoff', '5,10,15',
'--decoder-layers', '[(850, 3)] * 2 + [(1024,4)]',
'--decoder-embed-dim', '280',
'--max-tokens', '500',
'--tokens-per-sample', '500',
'--save-dir', data_dir,
'--max-epoch', '1',
'--no-progress-bar',
'--distributed-world-size', '1',
],
)
train.main(train_args)
def eval_lm_main(data_dir):
eval_lm_parser = options.get_eval_lm_parser()
eval_lm_args = options.parse_args_and_arch(
eval_lm_parser,
[
data_dir,
'--path', os.path.join(data_dir, 'checkpoint_last.pt'),
'--no-progress-bar',
],
)
eval_lm.main(eval_lm_args)
if __name__ == '__main__': if __name__ == '__main__':
......
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import unittest
from fairseq.data import data_utils
class TestDataUtils(unittest.TestCase):
def test_counting_iterator(self):
x = list(range(10))
itr = data_utils.CountingIterator(x)
self.assertTrue(itr.has_next())
self.assertEqual(next(itr), 0)
self.assertEqual(next(itr), 1)
itr.skip(3)
self.assertEqual(next(itr), 5)
itr.skip(3)
self.assertEqual(next(itr), 9)
self.assertFalse(itr.has_next())
if __name__ == '__main__':
unittest.main()
@@ -46,12 +46,13 @@ class TestLabelSmoothing(unittest.TestCase):
             [0.05, 0.10, 0.2, 0.05, 0.2, 0.3, 0.10],
             [0.05, 0.15, 0.3, 0.05, 0.1, 0.2, 0.15],
         ]).unsqueeze(0).expand(2, 3, 7)  # add batch dimension
-        self.model = test_utils.TestModel.build_model(self.args, self.d, self.d)
+        self.task = test_utils.TestTranslationTask.setup_task(self.args, self.d, self.d)
+        self.model = self.task.build_model(self.args)

     def test_nll_loss(self):
         self.args.label_smoothing = 0.1
-        nll_crit = CrossEntropyCriterion(self.args, self.d, self.d)
-        smooth_crit = LabelSmoothedCrossEntropyCriterion(self.args, self.d, self.d)
+        nll_crit = CrossEntropyCriterion(self.args, self.task)
+        smooth_crit = LabelSmoothedCrossEntropyCriterion(self.args, self.task)
         nll_loss, nll_sample_size, nll_logging_output = nll_crit(self.model, self.sample)
         smooth_loss, smooth_sample_size, smooth_logging_output = smooth_crit(self.model, self.sample)
         self.assertLess(abs(nll_loss - nll_logging_output['loss']), 1e-6)
@@ -59,7 +60,7 @@ class TestLabelSmoothing(unittest.TestCase):
     def test_padding(self):
         self.args.label_smoothing = 0.1
-        crit = LabelSmoothedCrossEntropyCriterion(self.args, self.d, self.d)
+        crit = LabelSmoothedCrossEntropyCriterion(self.args, self.task)
         loss, _, logging_output = crit(self.model, self.sample)

         def get_one_no_padding(idx):
@@ -68,7 +69,7 @@ class TestLabelSmoothing(unittest.TestCase):
             sample1 = next(test_utils.dummy_dataloader([self.data[idx]]))
             args1 = copy.copy(self.args)
             args1.probs = args1.probs[idx, :, :].unsqueeze(0)
-            model1 = test_utils.TestModel.build_model(args1, self.d, self.d)
+            model1 = self.task.build_model(args1)
             loss1, _, _ = crit(model1, sample1)
             return loss1
@@ -78,15 +79,15 @@ class TestLabelSmoothing(unittest.TestCase):
     def test_reduction(self):
         self.args.label_smoothing = 0.1
-        crit = LabelSmoothedCrossEntropyCriterion(self.args, self.d, self.d)
+        crit = LabelSmoothedCrossEntropyCriterion(self.args, self.task)
         loss, _, logging_output = crit(self.model, self.sample, reduce=True)
         unreduced_loss, _, _ = crit(self.model, self.sample, reduce=False)
         self.assertAlmostEqual(loss, unreduced_loss.sum())

     def test_zero_eps(self):
         self.args.label_smoothing = 0.0
-        nll_crit = CrossEntropyCriterion(self.args, self.d, self.d)
-        smooth_crit = LabelSmoothedCrossEntropyCriterion(self.args, self.d, self.d)
+        nll_crit = CrossEntropyCriterion(self.args, self.task)
+        smooth_crit = LabelSmoothedCrossEntropyCriterion(self.args, self.task)
         nll_loss, nll_sample_size, nll_logging_output = nll_crit(self.model, self.sample)
         smooth_loss, smooth_sample_size, smooth_logging_output = smooth_crit(self.model, self.sample)
         self.assertAlmostEqual(nll_loss, smooth_loss)
...
@@ -80,10 +80,12 @@ class TestSequenceGenerator(unittest.TestCase):
             ]),
         ]

-        self.model = test_utils.TestModel.build_model(args, d, d)
+        task = test_utils.TestTranslationTask.setup_task(args, d, d)
+        self.model = task.build_model(args)
+        self.tgt_dict = task.target_dictionary

     def test_with_normalization(self):
-        generator = SequenceGenerator([self.model])
+        generator = SequenceGenerator([self.model], self.tgt_dict)
         hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
         eos, w1, w2 = self.eos, self.w1, self.w2
         # sentence 1, beam 1
@@ -102,7 +104,7 @@ class TestSequenceGenerator(unittest.TestCase):
     def test_without_normalization(self):
         # Sentence 1: unchanged from the normalized case
         # Sentence 2: beams swap order
-        generator = SequenceGenerator([self.model], normalize_scores=False)
+        generator = SequenceGenerator([self.model], self.tgt_dict, normalize_scores=False)
         hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
         eos, w1, w2 = self.eos, self.w1, self.w2
         # sentence 1, beam 1
@@ -120,7 +122,7 @@ class TestSequenceGenerator(unittest.TestCase):
     def test_with_lenpen_favoring_short_hypos(self):
         lenpen = 0.6
-        generator = SequenceGenerator([self.model], len_penalty=lenpen)
+        generator = SequenceGenerator([self.model], self.tgt_dict, len_penalty=lenpen)
         hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
         eos, w1, w2 = self.eos, self.w1, self.w2
         # sentence 1, beam 1
@@ -138,7 +140,7 @@ class TestSequenceGenerator(unittest.TestCase):
     def test_with_lenpen_favoring_long_hypos(self):
         lenpen = 5.0
-        generator = SequenceGenerator([self.model], len_penalty=lenpen)
+        generator = SequenceGenerator([self.model], self.tgt_dict, len_penalty=lenpen)
         hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
         eos, w1, w2 = self.eos, self.w1, self.w2
         # sentence 1, beam 1
@@ -155,7 +157,7 @@ class TestSequenceGenerator(unittest.TestCase):
         self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.6], lenpen=lenpen)

     def test_maxlen(self):
-        generator = SequenceGenerator([self.model], maxlen=2)
+        generator = SequenceGenerator([self.model], self.tgt_dict, maxlen=2)
         hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
         eos, w1, w2 = self.eos, self.w1, self.w2
         # sentence 1, beam 1
@@ -172,7 +174,7 @@ class TestSequenceGenerator(unittest.TestCase):
         self.assertHypoScore(hypos[1][1], [0.3, 0.9, 0.01])

     def test_no_stop_early(self):
-        generator = SequenceGenerator([self.model], stop_early=False)
+        generator = SequenceGenerator([self.model], self.tgt_dict, stop_early=False)
         hypos = generator.generate(self.src_tokens, self.src_lengths, beam_size=2)
         eos, w1, w2 = self.eos, self.w1, self.w2
         # sentence 1, beam 1
...
@@ -83,8 +83,9 @@ class TestSequenceScorer(unittest.TestCase):
             [0.3, 0.7],  # sentence 3
         ]

-        model = test_utils.TestModel.build_model(args, d, d)
-        scorer = SequenceScorer([model])
+        task = test_utils.TestTranslationTask.setup_task(args, d, d)
+        model = task.build_model(args)
+        scorer = SequenceScorer([model], task.target_dictionary)
         for id, _src, _ref, hypos in scorer.score_batched_itr(data_itr):
             self.assertHypoTokens(hypos[0], data[id]['target'])
             self.assertHypoScore(hypos[0], expected_scores[id])
...
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import contextlib
from io import StringIO
import unittest
from unittest.mock import MagicMock, patch
import torch
from fairseq import data
import train
def mock_trainer(epoch, num_updates, iterations_in_epoch):
trainer = MagicMock()
trainer.load_checkpoint.return_value = {
'train_iterator': {
'epoch': epoch,
'iterations_in_epoch': iterations_in_epoch,
'shuffle': False,
},
}
trainer.get_num_updates.return_value = num_updates
return trainer
def mock_dict():
d = MagicMock()
d.pad.return_value = 1
d.eos.return_value = 2
d.unk.return_value = 3
return d
def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoch):
tokens = torch.LongTensor(list(range(epoch_size)))
tokens_ds = data.TokenBlockDataset(tokens, [len(tokens)], 1, include_targets=False)
trainer = mock_trainer(epoch, num_updates, iterations_in_epoch)
epoch_itr = data.EpochBatchIterator(
dataset=data.LanguagePairDataset(tokens_ds, tokens_ds.sizes, mock_dict(), shuffle=False),
max_tokens=1,
)
return trainer, epoch_itr
class TestLoadCheckpoint(unittest.TestCase):
def setUp(self):
self.patches = {
'os.makedirs': MagicMock(),
'os.path.join': MagicMock(),
'os.path.isfile': MagicMock(return_value=True),
}
self.applied_patches = [patch(p, d) for p, d in self.patches.items()]
[p.start() for p in self.applied_patches]
def test_load_partial_checkpoint(self):
with contextlib.redirect_stdout(StringIO()):
trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 200, 50)
train.load_checkpoint(MagicMock(), trainer, epoch_itr)
self.assertEqual(epoch_itr.epoch, 2)
self.assertEqual(epoch_itr.iterations_in_epoch, 50)
itr = epoch_itr.next_epoch_itr(shuffle=False)
self.assertEqual(epoch_itr.epoch, 2)
self.assertEqual(epoch_itr.iterations_in_epoch, 50)
self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 50)
self.assertEqual(epoch_itr.iterations_in_epoch, 51)
def test_load_full_checkpoint(self):
with contextlib.redirect_stdout(StringIO()):
trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 300, 150)
train.load_checkpoint(MagicMock(), trainer, epoch_itr)
itr = epoch_itr.next_epoch_itr(shuffle=False)
self.assertEqual(epoch_itr.epoch, 3)
self.assertEqual(epoch_itr.iterations_in_epoch, 0)
self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 0)
def test_load_no_checkpoint(self):
with contextlib.redirect_stdout(StringIO()):
trainer, epoch_itr = get_trainer_and_epoch_itr(0, 150, 0, 0)
self.patches['os.path.isfile'].return_value = False
train.load_checkpoint(MagicMock(), trainer, epoch_itr)
itr = epoch_itr.next_epoch_itr(shuffle=False)
self.assertEqual(epoch_itr.epoch, 1)
self.assertEqual(epoch_itr.iterations_in_epoch, 0)
self.assertEqual(next(itr)['net_input']['src_tokens'][0].item(), 0)
def tearDown(self):
patch.stopall()
if __name__ == '__main__':
unittest.main()
@@ -77,16 +77,6 @@ class TestUtils(unittest.TestCase):
             utils.make_positions(right_pad_input, pad, left_pad=False),
         )

-    def test_make_variable(self):
-        t = [{'k': torch.rand(5, 5)}]
-        v = utils.make_variable(t)[0]['k']
-        self.assertTrue(isinstance(v, Variable))
-        self.assertFalse(v.data.is_cuda)
-        v = utils.make_variable(t, cuda=True)[0]['k']
-        self.assertEqual(v.data.is_cuda, torch.cuda.is_available())

     def assertAlmostEqual(self, t1, t2):
         self.assertEqual(t1.size(), t2.size(), "size mismatch")
         self.assertLess(utils.item((t1 - t2).abs().max()), 1e-4)
...
@@ -8,20 +8,23 @@
 import torch
 from torch.autograd import Variable

-from fairseq import data, dictionary, utils
+from fairseq import utils
+from fairseq.data import Dictionary
+from fairseq.data.language_pair_dataset import collate
 from fairseq.models import (
     FairseqEncoder,
     FairseqIncrementalDecoder,
     FairseqModel,
 )
+from fairseq.tasks import FairseqTask


 def dummy_dictionary(vocab_size, prefix='token_'):
-    d = dictionary.Dictionary()
+    d = Dictionary()
     for i in range(vocab_size):
         token = prefix + str(i)
         d.add_symbol(token)
-    d.finalize()
+    d.finalize(padding_factor=1)  # don't add extra padding symbols
     return d
@@ -44,13 +47,7 @@ def dummy_dataloader(
     dataloader = torch.utils.data.DataLoader(
         dataset,
         batch_size=batch_size,
-        collate_fn=(
-            lambda samples: data.LanguagePairDataset.collate(
-                samples,
-                padding_idx,
-                eos_idx,
-            )
-        ),
+        collate_fn=(lambda samples: collate(samples, padding_idx, eos_idx)),
     )
     return iter(dataloader)
@@ -68,14 +65,38 @@ class TestDataset(torch.utils.data.Dataset):
         return len(self.data)


+class TestTranslationTask(FairseqTask):
+
+    def __init__(self, args, src_dict, tgt_dict, model):
+        super().__init__(args)
+        self.src_dict = src_dict
+        self.tgt_dict = tgt_dict
+        self.model = model
+
+    @classmethod
+    def setup_task(cls, args, src_dict=None, tgt_dict=None, model=None):
+        return cls(args, src_dict, tgt_dict, model)
+
+    def build_model(self, args):
+        return TestModel.build_model(args, self)
+
+    @property
+    def source_dictionary(self):
+        return self.src_dict
+
+    @property
+    def target_dictionary(self):
+        return self.tgt_dict
+
+
 class TestModel(FairseqModel):
     def __init__(self, encoder, decoder):
         super().__init__(encoder, decoder)

     @classmethod
-    def build_model(cls, args, src_dict, dst_dict):
-        encoder = TestEncoder(args, src_dict)
-        decoder = TestIncrementalDecoder(args, dst_dict)
+    def build_model(cls, args, task):
+        encoder = TestEncoder(args, task.source_dictionary)
+        decoder = TestIncrementalDecoder(args, task.target_dictionary)
         return cls(encoder, decoder)
@@ -134,7 +155,7 @@ class TestIncrementalDecoder(FairseqIncrementalDecoder):
         return Variable(probs), Variable(attn)

-    def get_normalized_probs(self, net_output, log_probs):
+    def get_normalized_probs(self, net_output, log_probs, _):
         # the decoder returns probabilities directly
         probs = net_output[0]
         if log_probs:
...
@@ -6,24 +6,348 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.

-from fairseq import options
-from distributed_train import main as distributed_main
-from multiprocessing_train import main as multiprocessing_main
-from singleprocess_train import main as singleprocess_main
+import collections
+import itertools
+import os
+import math
+import torch
+
+from fairseq import data, distributed_utils, options, progress_bar, tasks, utils
+from fairseq.fp16_trainer import FP16Trainer
+from fairseq.trainer import Trainer
+from fairseq.meters import AverageMeter, StopwatchMeter


 def main(args):
-    if args.distributed_port > 0 \
-            or args.distributed_init_method is not None:
-        distributed_main(args)
-    elif args.distributed_world_size > 1:
-        multiprocessing_main(args)
-    else:
-        singleprocess_main(args)
+    if args.max_tokens is None:
+        args.max_tokens = 6000
+    print(args)
+
+    if not torch.cuda.is_available():
+        raise NotImplementedError('Training on CPU is not supported')
+    torch.cuda.set_device(args.device_id)
+    torch.manual_seed(args.seed)
+
+    # Setup task, e.g., translation, language modeling, etc.
+    task = tasks.setup_task(args)
+
+    # Load dataset splits
+    load_dataset_splits(args, task, ['train', 'valid'])
+
+    # Build model and criterion
+    model = task.build_model(args)
+    criterion = task.build_criterion(args)
+    print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
+    print('| num. model params: {}'.format(sum(p.numel() for p in model.parameters())))
+
+    # Build trainer
+    if args.fp16:
+        trainer = FP16Trainer(args, task, model, criterion)
+    else:
+        if torch.cuda.get_device_capability(0)[0] >= 7:
+            print('| NOTICE: your device may support faster training with --fp16')
+        trainer = Trainer(args, task, model, criterion)
print('| training on {} GPUs'.format(args.distributed_world_size))
print('| max tokens per GPU = {} and max sentences per GPU = {}'.format(
args.max_tokens,
args.max_sentences,
))
# Initialize dataloader
max_positions = trainer.get_model().max_positions()
epoch_itr = data.EpochBatchIterator(
dataset=task.dataset(args.train_subset),
max_tokens=args.max_tokens,
max_sentences=args.max_sentences_valid,
max_positions=max_positions,
ignore_invalid_inputs=True,
required_batch_size_multiple=8,
seed=args.seed,
num_shards=args.distributed_world_size,
shard_id=args.distributed_rank,
)
# Load the latest checkpoint if one is available
load_checkpoint(args, trainer, epoch_itr)
# Send a dummy batch to warm the caching allocator
dummy_batch = task.dataset('train').get_dummy_batch(args.max_tokens, max_positions)
trainer.dummy_train_step(dummy_batch)
# Train until the learning rate gets too small
max_epoch = args.max_epoch or math.inf
max_update = args.max_update or math.inf
lr = trainer.get_lr()
train_meter = StopwatchMeter()
train_meter.start()
valid_losses = [None]
valid_subsets = args.valid_subset.split(',')
while lr > args.min_lr and epoch_itr.epoch <= max_epoch and trainer.get_num_updates() < max_update:
# train for one epoch
train(args, trainer, task, epoch_itr)
if epoch_itr.epoch % args.validate_interval == 0:
valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
# only use first validation loss to update the learning rate
lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])
# save checkpoint
if epoch_itr.epoch % args.save_interval == 0:
save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
train_meter.stop()
print('| done training in {:.1f} seconds'.format(train_meter.sum))
def train(args, trainer, task, epoch_itr):
"""Train the model for one epoch."""
# Initialize data iterator
itr = epoch_itr.next_epoch_itr()
progress = progress_bar.build_progress_bar(args, itr, epoch_itr.epoch, no_progress_bar='simple')
# update parameters every N batches
if epoch_itr.epoch <= len(args.update_freq):
update_freq = args.update_freq[epoch_itr.epoch - 1]
else:
update_freq = args.update_freq[-1]
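# Clarifying note (not part of the original commit): with --update-freq 3, the
# loop below calls trainer.train_step(sample, update_params=False) for the
# first two mini-batches of each group, which only accumulates gradients, and
# update_params=True on the third, so a single optimizer step sees an
# effective batch roughly 3x larger, as described in the 0.5.0 changelog.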
extra_meters = collections.defaultdict(lambda: AverageMeter())
first_valid = args.valid_subset.split(',')[0]
max_update = args.max_update or math.inf
num_batches = len(epoch_itr)
for i, sample in enumerate(progress, start=epoch_itr.iterations_in_epoch):
if i < num_batches - 1 and (i + 1) % update_freq > 0:
# buffer updates according to --update-freq
trainer.train_step(sample, update_params=False)
continue
else:
log_output = trainer.train_step(sample, update_params=True)
# log mid-epoch stats
stats = get_training_stats(trainer)
for k, v in log_output.items():
if k in ['loss', 'nll_loss', 'sample_size']:
continue # these are already logged above
if 'loss' in k:
extra_meters[k].update(v, log_output['sample_size'])
else:
extra_meters[k].update(v)
stats[k] = extra_meters[k].avg
progress.log(stats)
# ignore the first mini-batch in words-per-second calculation
if i == 0:
trainer.get_meter('wps').reset()
num_updates = trainer.get_num_updates()
if args.save_interval_updates > 0 and num_updates % args.save_interval_updates == 0:
valid_losses = validate(args, trainer, task, epoch_itr, [first_valid])
save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
if num_updates >= max_update:
break
# log end-of-epoch stats
stats = get_training_stats(trainer)
for k, meter in extra_meters.items():
stats[k] = meter.avg
progress.print(stats)
# reset training meters
for k in ['train_loss', 'train_nll_loss', 'wps', 'ups', 'wpb', 'bsz', 'clip']:
meter = trainer.get_meter(k)
if meter is not None:
meter.reset()
def get_training_stats(trainer):
stats = collections.OrderedDict()
stats['loss'] = '{:.3f}'.format(trainer.get_meter('train_loss').avg)
if trainer.get_meter('train_nll_loss').count > 0:
nll_loss = trainer.get_meter('train_nll_loss').avg
stats['nll_loss'] = '{:.3f}'.format(nll_loss)
else:
nll_loss = trainer.get_meter('train_loss').avg
stats['ppl'] = get_perplexity(nll_loss)
stats['wps'] = round(trainer.get_meter('wps').avg)
stats['ups'] = '{:.1f}'.format(trainer.get_meter('ups').avg)
stats['wpb'] = round(trainer.get_meter('wpb').avg)
stats['bsz'] = round(trainer.get_meter('bsz').avg)
stats['num_updates'] = trainer.get_num_updates()
stats['lr'] = trainer.get_lr()
stats['gnorm'] = '{:.3f}'.format(trainer.get_meter('gnorm').avg)
stats['clip'] = '{:.0%}'.format(trainer.get_meter('clip').avg)
stats['oom'] = trainer.get_meter('oom').avg
if trainer.get_meter('loss_scale') is not None:
stats['loss_scale'] = '{:.3f}'.format(trainer.get_meter('loss_scale').avg)
stats['wall'] = round(trainer.get_meter('wall').elapsed_time)
return stats
def validate(args, trainer, task, epoch_itr, subsets):
"""Evaluate the model on the validation set(s) and return the losses."""
valid_losses = []
for subset in subsets:
# Initialize data iterator
itr = data.EpochBatchIterator(
dataset=task.dataset(subset),
max_tokens=args.max_tokens,
max_sentences=args.max_sentences_valid,
max_positions=trainer.get_model().max_positions(),
ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
required_batch_size_multiple=8,
seed=args.seed,
num_shards=args.distributed_world_size,
shard_id=args.distributed_rank,
).next_epoch_itr(shuffle=False)
progress = progress_bar.build_progress_bar(
args, itr, epoch_itr.epoch,
prefix='valid on \'{}\' subset'.format(subset),
no_progress_bar='simple'
)
# reset validation loss meters
for k in ['valid_loss', 'valid_nll_loss']:
meter = trainer.get_meter(k)
if meter is not None:
meter.reset()
extra_meters = collections.defaultdict(lambda: AverageMeter())
for sample in progress:
log_output = trainer.valid_step(sample)
for k, v in log_output.items():
if k in ['loss', 'nll_loss', 'sample_size']:
continue
extra_meters[k].update(v)
# log validation stats
stats = get_valid_stats(trainer)
for k, meter in extra_meters.items():
stats[k] = meter.avg
progress.print(stats)
valid_losses.append(stats['valid_loss'])
return valid_losses
def get_valid_stats(trainer):
stats = collections.OrderedDict()
stats['valid_loss'] = trainer.get_meter('valid_loss').avg
if trainer.get_meter('valid_nll_loss').count > 0:
nll_loss = trainer.get_meter('valid_nll_loss').avg
stats['valid_nll_loss'] = nll_loss
else:
nll_loss = trainer.get_meter('valid_loss').avg
stats['valid_ppl'] = get_perplexity(nll_loss)
stats['num_updates'] = trainer.get_num_updates()
if hasattr(save_checkpoint, 'best'):
stats['best'] = min(save_checkpoint.best, stats['valid_loss'])
return stats
def get_perplexity(loss):
try:
return '{:.2f}'.format(math.pow(2, loss))
except OverflowError:
return float('inf')
def save_checkpoint(args, trainer, epoch_itr, val_loss):
if args.no_save or not distributed_utils.is_master(args):
return
epoch = epoch_itr.epoch
end_of_epoch = epoch_itr.end_of_epoch()
updates = trainer.get_num_updates()
checkpoint_conds = collections.OrderedDict()
checkpoint_conds['checkpoint{}.pt'.format(epoch)] = (
end_of_epoch and not args.no_epoch_checkpoints and
epoch % args.save_interval == 0
)
checkpoint_conds['checkpoint_{}_{}.pt'.format(epoch, updates)] = (
not end_of_epoch and args.save_interval_updates > 0 and
updates % args.save_interval_updates == 0
)
checkpoint_conds['checkpoint_best.pt'] = (
val_loss is not None and
(not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best)
)
checkpoint_conds['checkpoint_last.pt'] = True # keep this last so that it's a symlink
prev_best = getattr(save_checkpoint, 'best', val_loss)
if val_loss is not None:
save_checkpoint.best = min(val_loss, prev_best)
extra_state = {
'best': save_checkpoint.best,
'train_iterator': epoch_itr.state_dict(),
'val_loss': val_loss,
}
checkpoints = [os.path.join(args.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond]
if len(checkpoints) > 0:
for fn in checkpoints:
if os.path.exists(fn):
os.remove(fn)
if not end_of_epoch and args.keep_interval_updates > 0:
for cp in checkpoints:
trainer.save_checkpoint(cp, extra_state)
else:
trainer.save_checkpoint(checkpoints[0], extra_state)
for fn in checkpoints[1:]:
os.symlink(os.path.basename(checkpoints[0]), fn)
if not end_of_epoch and args.keep_interval_updates > 0:
# remove old checkpoints; checkpoints are sorted in descending order
checkpoints = utils.checkpoint_paths(args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt')
for old_chk in checkpoints[args.keep_interval_updates:]:
os.remove(old_chk)
def load_checkpoint(args, trainer, epoch_itr):
"""Load a checkpoint and replay dataloader to match."""
os.makedirs(args.save_dir, exist_ok=True)
checkpoint_path = os.path.join(args.save_dir, args.restore_file)
if os.path.isfile(checkpoint_path):
extra_state = trainer.load_checkpoint(checkpoint_path)
if extra_state is not None:
# replay train iterator to match checkpoint
epoch_itr.load_state_dict(extra_state['train_iterator'])
print('| loaded checkpoint {} (epoch {} @ {} updates)'.format(
checkpoint_path, epoch_itr.epoch, trainer.get_num_updates()))
trainer.lr_step(epoch_itr.epoch)
trainer.lr_step_update(trainer.get_num_updates())
if 'best' in extra_state:
save_checkpoint.best = extra_state['best']
def load_dataset_splits(args, task, splits):
for split in splits:
for k in itertools.count():
split_k = split + (str(k) if k > 0 else '')
try:
task.load_dataset(split_k)
print('| {} {} {} examples'.format(args.data, split_k, len(task.dataset(split_k))))
except FileNotFoundError as e:
if k > 0:
break
raise e
 if __name__ == '__main__':
     parser = options.get_training_parser()
     args = options.parse_args_and_arch(parser)
-    main(args)
+
+    if args.distributed_port > 0 or args.distributed_init_method is not None:
+        from distributed_train import main as distributed_main
+        distributed_main(args)
+    elif args.distributed_world_size > 1:
+        from multiprocessing_train import main as multiprocessing_main
+        multiprocessing_main(args)
+    else:
+        main(args)