Unverified Commit 66415206 authored by Myle Ott, committed by GitHub

fairseq-py goes distributed (#106)

This PR includes breaking API changes to modularize fairseq-py and adds support for distributed training across multiple nodes.

Changes:
- c7033ef: add support for distributed training! See updated README for usage.
- e016299: modularize fairseq-py, adding support for register_model, register_criterion, register_optimizer, etc. (see the registry sketch below)
- 154e440: update the LSTM implementation to use PackedSequence objects in the encoder, better following PyTorch best practices and improving performance (see the packing sketch below)
- 90c2973 and 1da6265: improve unit test coverage
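
To give a flavor of the new `register_*` extension points, here is a minimal sketch of how such a name-based registry decorator typically works. This is illustrative only: the names, signatures, and error handling below are assumptions, not the exact fairseq-py API.

```python
# minimal model registry sketch (hypothetical; for illustration only)
MODEL_REGISTRY = {}


def register_model(name):
    """Decorator that records a model class under a string name."""
    def register_model_cls(cls):
        if name in MODEL_REGISTRY:
            raise ValueError('Cannot register duplicate model ({})'.format(name))
        MODEL_REGISTRY[name] = cls
        return cls
    return register_model_cls


@register_model('dummy_lstm')
class DummyLSTMModel(object):
    pass


# later, e.g. when resolving --arch, the class is looked up by name
model_cls = MODEL_REGISTRY['dummy_lstm']
```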
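The PackedSequence change (154e440) follows the standard PyTorch packing pattern, roughly as below. This is a generic sketch against the public `torch.nn.utils.rnn` API, not the actual fairseq encoder code.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
x = torch.randn(3, 5, 8)  # 3 padded sequences, max length 5
lengths = [5, 3, 2]       # true lengths, sorted longest-first

# pack so the LSTM skips padding timesteps instead of processing them
packed = pack_padded_sequence(x, lengths, batch_first=True)
packed_out, (h, c) = lstm(packed)

# unpack back to a padded tensor for downstream use
out, _ = pad_packed_sequence(packed_out, batch_first=True)
```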
parent 7e86e30c
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import torch
from torch.autograd import Variable
from fairseq import data, dictionary
from fairseq.models import (
    FairseqEncoder,
    FairseqIncrementalDecoder,
    FairseqModel,
)

def dummy_dictionary(vocab_size, prefix='token_'):
    d = dictionary.Dictionary()
    for i in range(vocab_size):
        token = prefix + str(i)
        d.add_symbol(token)
    d.finalize()
    return d

def dummy_dataloader(
    samples,
    padding_idx=1,
    eos_idx=2,
    batch_size=None,
):
    if batch_size is None:
        batch_size = len(samples)

    # add any missing data to samples
    for i, sample in enumerate(samples):
        if 'id' not in sample:
            sample['id'] = i

    # create dataloader
    dataset = TestDataset(samples)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=(
            lambda samples: data.LanguagePairDataset.collate(
                samples,
                padding_idx,
                eos_idx,
            )
        ),
    )
    return iter(dataloader)

class TestDataset(torch.utils.data.Dataset):
    def __init__(self, data):
        super().__init__()
        self.data = data

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)

class TestModel(FairseqModel):
    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    @classmethod
    def build_model(cls, args, src_dict, dst_dict):
        # construction hook used by the modularized model API
        encoder = TestEncoder(args, src_dict)
        decoder = TestIncrementalDecoder(args, dst_dict)
        return cls(encoder, decoder)

class TestEncoder(FairseqEncoder):
    def __init__(self, args, dictionary):
        super().__init__(dictionary)
        self.args = args

    def forward(self, src_tokens, src_lengths):
        # the dummy encoder simply returns the source tokens as encoder output
        return src_tokens

class TestIncrementalDecoder(FairseqIncrementalDecoder):
    def __init__(self, args, dictionary):
        super().__init__(dictionary)
        assert hasattr(args, 'beam_probs')
        args.max_decoder_positions = getattr(args, 'max_decoder_positions', 100)
        self.args = args

    def forward(self, prev_output_tokens, encoder_out):
        if self._is_incremental_eval:
            # during incremental evaluation, only feed the most recent token
            prev_output_tokens = prev_output_tokens[:, -1:]
        return self._forward(prev_output_tokens, encoder_out)

    def _forward(self, prev_output_tokens, encoder_out):
        bbsz = prev_output_tokens.size(0)
        vocab = len(self.dictionary)
        src_len = encoder_out.size(1)
        tgt_len = prev_output_tokens.size(1)

        # determine number of steps
        if self._is_incremental_eval:
            # cache step number
            step = self.get_incremental_state('step')
            if step is None:
                step = 0
            self.set_incremental_state('step', step + 1)
            steps = [step]
        else:
            steps = list(range(tgt_len))

        # define output in terms of raw probs
        probs = torch.FloatTensor(bbsz, len(steps), vocab).zero_()
        for i, step in enumerate(steps):
            # args.beam_probs gives the probability for every vocab element,
            # starting with eos, then unknown, and then the rest of the vocab
            if step < len(self.args.beam_probs):
                probs[:, i, self.dictionary.eos():] = self.args.beam_probs[step]
            else:
                probs[:, i, self.dictionary.eos()] = 1.0

        # random attention
        attn = torch.rand(bbsz, src_len, tgt_len)
        return Variable(probs), Variable(attn)

    def get_normalized_probs(self, net_output, log_probs):
        # the decoder returns probabilities directly
        if log_probs:
            return net_output.log()
        else:
            return net_output

    def max_positions(self):
        return self.args.max_decoder_positions
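
# Illustrative (hypothetical) usage of the helpers above; the sample keys
# ('source', 'target', 'ntokens') are assumptions about the collate format:
#
#     d = dummy_dictionary(vocab_size=10)
#     itr = dummy_dataloader([
#         {'source': torch.LongTensor([4, 5, 2]),
#          'target': torch.LongTensor([6, 7, 2]),
#          'ntokens': 3},
#     ])
#     batch = next(itr)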
#!/usr/bin/env python3
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import collections
import os
import torch
import math
from fairseq import data, options, utils
from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter
from fairseq.multiprocessing_trainer import MultiprocessingTrainer
def main():
    parser = options.get_parser('Trainer')
    dataset_args = options.add_dataset_args(parser)
    dataset_args.add_argument('--max-tokens', default=6000, type=int, metavar='N',
                              help='maximum number of tokens in a batch')
    dataset_args.add_argument('--max-sentences', type=int, metavar='N',
                              help='maximum number of sentences in a batch')
    dataset_args.add_argument('--train-subset', default='train', metavar='SPLIT',
                              choices=['train', 'valid', 'test'],
                              help='data subset to use for training (train, valid, test)')
    dataset_args.add_argument('--valid-subset', default='valid', metavar='SPLIT',
                              help='comma separated list of data subsets to use '
                                   'for validation (train, valid, valid1, test, test1)')
    dataset_args.add_argument('--max-sentences-valid', type=int, metavar='N',
                              help='maximum number of sentences in a validation batch')
    options.add_optimization_args(parser)
    options.add_checkpoint_args(parser)
    options.add_model_args(parser)

    args = utils.parse_args_and_arch(parser)
    if args.no_progress_bar and args.log_format is None:
        args.log_format = 'simple'
    if args.max_sentences_valid is None:
        args.max_sentences_valid = args.max_sentences

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
    torch.manual_seed(args.seed)

    # Load dataset
    splits = ['train', 'valid']
    if data.has_binary_files(args.data, splits):
        dataset = data.load_dataset(args.data, splits, args.source_lang, args.target_lang)
    else:
        dataset = data.load_raw_text_dataset(args.data, splits, args.source_lang, args.target_lang)
    if args.source_lang is None or args.target_lang is None:
        # record inferred languages in args, so that they're saved in checkpoints
        args.source_lang, args.target_lang = dataset.src, dataset.dst

    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')
    args.num_gpus = torch.cuda.device_count()

    print(args)
    print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
    for split in splits:
        print('| {} {} {} examples'.format(args.data, split, len(dataset.splits[split])))
    print('| using {} GPUs (with max tokens per GPU = {} and max sentences per GPU = {})'.format(
        args.num_gpus, args.max_tokens, args.max_sentences))
    # Build model and criterion
    model = utils.build_model(args, dataset.src_dict, dataset.dst_dict)
    criterion = utils.build_criterion(args, dataset.src_dict, dataset.dst_dict)
    print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
    print('| num. model params: {}'.format(sum(p.data.numel() for p in model.parameters())))
    # The max number of positions can be different for train and valid
    # e.g., RNNs may support more positions at test time than seen in training
    max_positions_train = (
        min(args.max_source_positions, model.max_encoder_positions()),
        min(args.max_target_positions, model.max_decoder_positions())
    )
    max_positions_valid = (model.max_encoder_positions(), model.max_decoder_positions())

    # Start multiprocessing
    trainer = MultiprocessingTrainer(args, model, criterion)

    # Load the latest checkpoint if one is available
    checkpoint_path = os.path.join(args.save_dir, args.restore_file)
    extra_state = trainer.load_checkpoint(checkpoint_path)
    if extra_state is not None:
        epoch = extra_state['epoch']
        batch_offset = extra_state['batch_offset']
        print('| loaded checkpoint {} (epoch {})'.format(checkpoint_path, epoch))
        if batch_offset == 0:
            epoch += 1
    else:
        epoch, batch_offset = 1, 0
    # Train until the learning rate gets too small
    val_loss = None
    max_epoch = args.max_epoch or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    while lr > args.min_lr and epoch <= max_epoch:
        # train for one epoch
        train(args, epoch, batch_offset, trainer, dataset, max_positions_train)

        # evaluate on validation set
        for k, subset in enumerate(args.valid_subset.split(',')):
            val_loss = validate(args, epoch, trainer, dataset, max_positions_valid, subset)
            if k == 0:
                if not args.no_save:
                    # save checkpoint
                    save_checkpoint(trainer, args, epoch, 0, val_loss)
                # only use the first validation loss to update the learning rate schedule
                lr = trainer.lr_step(val_loss, epoch)

        epoch += 1
        batch_offset = 0
    train_meter.stop()
    print('| done training in {:.1f} seconds'.format(train_meter.sum))

    # Stop multiprocessing
    trainer.stop()
def get_perplexity(loss):
    try:
        # the trainer reports loss in base 2, so perplexity is 2^loss
        return round(math.pow(2, loss), 2)
    except OverflowError:
        return float('inf')
def train(args, epoch, batch_offset, trainer, dataset, max_positions):
    """Train the model for one epoch."""

    seed = args.seed + epoch
    torch.manual_seed(seed)
    trainer.set_seed(seed)

    itr = dataset.train_dataloader(
        args.train_subset, num_workers=args.workers,
        max_tokens=args.max_tokens, max_sentences=args.max_sentences,
        max_positions=max_positions, seed=seed, epoch=epoch,
        sample_without_replacement=args.sample_without_replacement,
        sort_by_source_size=(epoch <= args.curriculum))
    loss_meter = AverageMeter()
    nll_loss_meter = AverageMeter()
    bsz_meter = AverageMeter()   # sentences per batch
    wpb_meter = AverageMeter()   # words per batch
    wps_meter = TimeMeter()      # words per second
    clip_meter = AverageMeter()  # % of updates clipped
    extra_meters = collections.defaultdict(lambda: AverageMeter())

    lr = trainer.get_lr()
    with utils.build_progress_bar(args, itr, epoch) as t:
        for i, sample in data.skip_group_enumerator(t, args.num_gpus, batch_offset):
            loss_dict = trainer.train_step(sample)
            loss = loss_dict['loss']
            del loss_dict['loss']  # don't include in extra_meters or extra_postfix

            ntokens = sum(s['ntokens'] for s in sample)
            if 'nll_loss' in loss_dict:
                nll_loss = loss_dict['nll_loss']
                nll_loss_meter.update(nll_loss, ntokens)
            nsentences = sum(s['net_input']['src_tokens'].size(0) for s in sample)
            loss_meter.update(loss, nsentences if args.sentence_avg else ntokens)
            bsz_meter.update(nsentences)
            wpb_meter.update(ntokens)
            wps_meter.update(ntokens)
            clip_meter.update(1 if loss_dict['gnorm'] > args.clip_norm else 0)

            extra_postfix = []
            for k, v in loss_dict.items():
                extra_meters[k].update(v)
                extra_postfix.append((k, extra_meters[k].avg))

            t.log(collections.OrderedDict([
                ('loss', loss_meter),
                ('wps', round(wps_meter.avg)),
                ('wpb', round(wpb_meter.avg)),
                ('bsz', round(bsz_meter.avg)),
                ('lr', lr),
                ('clip', '{:.0%}'.format(clip_meter.avg)),
            ] + extra_postfix))

            if i == 0:
                # ignore the first mini-batch in words-per-second calculation
                wps_meter.reset()
            if args.save_interval > 0 and (i + 1) % args.save_interval == 0:
                save_checkpoint(trainer, args, epoch, i + 1)

        t.print(collections.OrderedDict([
            ('train loss', round(loss_meter.avg, 2)),
            ('train ppl', get_perplexity(nll_loss_meter.avg
                                         if nll_loss_meter.count > 0
                                         else loss_meter.avg)),
            ('s/checkpoint', round(wps_meter.elapsed_time)),
            ('words/s', round(wps_meter.avg)),
            ('words/batch', round(wpb_meter.avg)),
            ('bsz', round(bsz_meter.avg)),
            ('lr', lr),
            ('clip', '{:3.0f}%'.format(clip_meter.avg * 100)),
        ] + [
            (k, meter.avg)
            for k, meter in extra_meters.items()
        ]))
def save_checkpoint(trainer, args, epoch, batch_offset, val_loss=None):
    # val_loss defaults to None since mid-epoch saves don't pass it
    extra_state = {
        'epoch': epoch,
        'batch_offset': batch_offset,
        'val_loss': val_loss,
    }

    if batch_offset == 0:
        if not args.no_epoch_checkpoints:
            epoch_filename = os.path.join(args.save_dir, 'checkpoint{}.pt'.format(epoch))
            trainer.save_checkpoint(epoch_filename, extra_state)

        assert val_loss is not None
        # track the best validation loss seen so far as a function attribute
        if not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best:
            save_checkpoint.best = val_loss
            best_filename = os.path.join(args.save_dir, 'checkpoint_best.pt')
            trainer.save_checkpoint(best_filename, extra_state)
    elif not args.no_epoch_checkpoints:
        epoch_filename = os.path.join(
            args.save_dir, 'checkpoint{}_{}.pt'.format(epoch, batch_offset))
        trainer.save_checkpoint(epoch_filename, extra_state)

    last_filename = os.path.join(args.save_dir, 'checkpoint_last.pt')
    trainer.save_checkpoint(last_filename, extra_state)
def validate(args, epoch, trainer, dataset, max_positions, subset):
    """Evaluate the model on the validation set and return the average loss."""

    itr = dataset.eval_dataloader(
        subset, max_tokens=args.max_tokens, max_sentences=args.max_sentences_valid,
        max_positions=max_positions,
        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
        descending=True,  # largest batch first to warm the caching allocator
    )
    loss_meter = AverageMeter()
    nll_loss_meter = AverageMeter()
    extra_meters = collections.defaultdict(lambda: AverageMeter())

    prefix = 'valid on \'{}\' subset'.format(subset)
    with utils.build_progress_bar(args, itr, epoch, prefix) as t:
        for _, sample in data.skip_group_enumerator(t, args.num_gpus):
            loss_dict = trainer.valid_step(sample)
            ntokens = sum(s['ntokens'] for s in sample)
            loss = loss_dict['loss']
            del loss_dict['loss']  # don't include in extra_meters or extra_postfix

            if 'nll_loss' in loss_dict:
                nll_loss = loss_dict['nll_loss']
                nll_loss_meter.update(nll_loss, ntokens)
            loss_meter.update(loss, ntokens)

            extra_postfix = []
            for k, v in loss_dict.items():
                extra_meters[k].update(v)
                extra_postfix.append((k, extra_meters[k].avg))

            t.log(collections.OrderedDict([
                ('valid loss', round(loss_meter.avg, 2)),
            ] + extra_postfix))

        t.print(collections.OrderedDict([
            ('valid loss', round(loss_meter.avg, 2)),
            ('valid ppl', get_perplexity(nll_loss_meter.avg
                                         if nll_loss_meter.count > 0
                                         else loss_meter.avg)),
        ] + [
            (k, meter.avg)
            for k, meter in extra_meters.items()
        ]))

    # return the average validation loss
    return loss_meter.avg
if __name__ == '__main__':
    main()


#!/usr/bin/env python3 -u
# dispatcher script added by this PR: picks a training entry point
# based on the parsed arguments
import torch

from fairseq import options
from distributed_train import main as distributed_main
from multiprocessing_train import main as multiprocessing_main
from singleprocess_train import main as singleprocess_main


def main(args):
    if args.distributed_port > 0 \
            or args.distributed_init_method is not None:
        # multi-node distributed training
        distributed_main(args)
    elif torch.cuda.device_count() > 1:
        # single node with multiple GPUs: use multiprocessing
        multiprocessing_main(args)
    else:
        # single GPU training
        singleprocess_main(args)


if __name__ == '__main__':
    parser = options.get_training_parser()
    args = options.parse_args_and_arch(parser)
    main(args)
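
# Example invocations (hypothetical paths and values; see the updated README
# for the actual distributed-training usage added by this PR):
#
#     python train.py data-bin/wmt14.en-fr --arch fconv --max-tokens 6000
#
#     python train.py data-bin/wmt14.en-fr \
#         --distributed-init-method 'tcp://host0:12345'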