Commit c83efd21 authored by Myle Ott, committed by GitHub

Merge pull request #33 from facebookresearch/oss-merge-internal

Changes:
Add support for NCCL v2
Add support for additional optimizers
SequenceGenerator returns attention matrix
Misc bugfixes (e.g., fixes #32) and cleanup
parents af86c1ac 104cead1
+#!/usr/bin/env python3
 # Copyright (c) 2017-present, Facebook, Inc.
 # All rights reserved.
 #
@@ -11,11 +12,10 @@ import os
 import torch
 import math
-from fairseq import bleu, data, options, utils
+from fairseq import data, options, utils
 from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter
 from fairseq.multiprocessing_trainer import MultiprocessingTrainer
 from fairseq.progress_bar import progress_bar
-from fairseq.sequence_generator import SequenceGenerator
 
 
 def main():
@@ -52,7 +52,7 @@ def main():
     print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
     print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
-    for split in dataset.splits:
+    for split in ['train', 'valid']:
         print('| {} {} {} examples'.format(args.data, split, len(dataset.splits[split])))
 
     if not torch.cuda.is_available():
@@ -61,16 +61,25 @@ def main():
     print('| using {} GPUs (with max tokens per GPU = {})'.format(num_gpus, args.max_tokens))
 
-    # Build model
-    print('| model {}'.format(args.arch))
-    model = utils.build_model(args, dataset)
-    criterion = utils.build_criterion(args, dataset)
+    # Build model and criterion
+    model = utils.build_model(args, dataset.src_dict, dataset.dst_dict)
+    criterion = utils.build_criterion(args, dataset.src_dict, dataset.dst_dict)
+    print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
 
     # Start multiprocessing
-    trainer = MultiprocessingTrainer(args, model)
+    trainer = MultiprocessingTrainer(args, model, criterion)
 
     # Load the latest checkpoint if one is available
-    epoch, batch_offset = trainer.load_checkpoint(os.path.join(args.save_dir, args.restore_file))
+    checkpoint_path = os.path.join(args.save_dir, args.restore_file)
+    extra_state = trainer.load_checkpoint(checkpoint_path)
+    if extra_state is not None:
+        epoch = extra_state['epoch']
+        batch_offset = extra_state['batch_offset']
+        print('| loaded checkpoint {} (epoch {})'.format(checkpoint_path, epoch))
+        if batch_offset == 0:
+            epoch += 1
+    else:
+        epoch, batch_offset = 1, 0
 
     # Train until the learning rate gets too small
     val_loss = None
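Note: checkpoints now round-trip a small extra_state dict (epoch, batch offset, validation loss) instead of a bare (epoch, batch_offset) tuple, and a missing checkpoint is signalled by None rather than sentinel values. A minimal sketch of how such a round-trip can be implemented with plain torch.save/torch.load; the actual MultiprocessingTrainer internals are not part of this diff, and the helper names below are illustrative:

    import os
    import torch

    def save_checkpoint_sketch(path, model, optimizer, extra_state):
        # Bundle weights, optimizer state and bookkeeping into one file.
        torch.save({
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'extra_state': extra_state,  # e.g. {'epoch': 3, 'batch_offset': 0, 'val_loss': 4.2}
        }, path)

    def load_checkpoint_sketch(path, model, optimizer):
        if not os.path.exists(path):
            return None  # mirrors the fresh-start branch above (epoch, batch_offset = 1, 0)
        state = torch.load(path)
        model.load_state_dict(state['model'])
        optimizer.load_state_dict(state['optimizer'])
        return state['extra_state']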
@@ -80,15 +89,15 @@ def main():
     train_meter.start()
     while lr > args.min_lr and epoch <= max_epoch:
         # train for one epoch
-        train(args, epoch, batch_offset, trainer, criterion, dataset, num_gpus)
+        train(args, epoch, batch_offset, trainer, dataset, num_gpus)
 
         # evaluate on validate set
         for k, subset in enumerate(args.valid_subset.split(',')):
-            val_loss = validate(args, epoch, trainer, criterion, dataset, subset, num_gpus)
+            val_loss = validate(args, epoch, trainer, dataset, subset, num_gpus)
             if k == 0:
                 if not args.no_save:
                     # save checkpoint
-                    trainer.save_checkpoint(args, epoch, 0, val_loss)
+                    save_checkpoint(trainer, args, epoch, 0, val_loss)
                 # only use first validation loss to update the learning schedule
                 lr = trainer.lr_step(val_loss, epoch)
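Note: the outer loop keeps training until trainer.lr_step(val_loss, epoch) anneals the learning rate below args.min_lr, and only the first validation subset's loss drives the schedule. The diff does not show how lr_step is implemented, but the same stop-when-LR-bottoms-out behaviour can be sketched with torch's stock ReduceLROnPlateau scheduler (all names below are stand-ins):

    import torch
    from torch.optim.lr_scheduler import ReduceLROnPlateau

    model = torch.nn.Linear(10, 10)  # stand-in model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.25)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=0)
    min_lr = 1e-5

    lr = optimizer.param_groups[0]['lr']
    while lr > min_lr:
        val_loss = 1.0                # placeholder: run validate() here
        scheduler.step(val_loss)      # shrinks the LR when val_loss stops improving
        lr = optimizer.param_groups[0]['lr']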
@@ -101,7 +110,14 @@ def main():
     trainer.stop()
 
 
-def train(args, epoch, batch_offset, trainer, criterion, dataset, num_gpus):
+def get_perplexity(loss):
+    try:
+        return math.pow(2, loss)
+    except OverflowError:
+        return float('inf')
+
+
+def train(args, epoch, batch_offset, trainer, dataset, num_gpus):
     """Train the model for one epoch."""
     itr = dataset.dataloader(args.train_subset, num_workers=args.workers,
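Note: get_perplexity reads the loss as a base-2 cross-entropy per token, so perplexity is simply 2**loss, with an explicit guard because math.pow raises OverflowError once the result exceeds float range (around 2**1024). For example:

    import math

    def get_perplexity(loss):
        try:
            return math.pow(2, loss)
        except OverflowError:
            return float('inf')

    print(get_perplexity(2.0))     # 4.0 -- a loss of 2 bits/token is perplexity 4
    print(get_perplexity(2000.0))  # inf -- 2**2000 overflows a double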
@@ -114,13 +130,16 @@ def train(args, epoch, batch_offset, trainer, criterion, dataset, num_gpus):
     wpb_meter = AverageMeter()   # words per batch
     wps_meter = TimeMeter()      # words per second
     clip_meter = AverageMeter()  # % of updates clipped
-    gnorm_meter = AverageMeter()  # gradient norm
+    extra_meters = collections.defaultdict(lambda: AverageMeter())
 
     desc = '| epoch {:03d}'.format(epoch)
+    trainer.set_seed(args.seed + epoch)
     lr = trainer.get_lr()
     with progress_bar(itr, desc, leave=False) as t:
         for i, sample in data.skip_group_enumerator(t, num_gpus, batch_offset):
-            loss, grad_norm = trainer.train_step(sample, criterion)
+            loss_dict = trainer.train_step(sample)
+            loss = loss_dict['loss']
+            del loss_dict['loss']  # don't include in extra_meters or extra_postfix
 
             ntokens = sum(s['ntokens'] for s in sample)
             src_size = sum(s['src_tokens'].size(0) for s in sample)
@@ -128,8 +147,12 @@ def train(args, epoch, batch_offset, trainer, criterion, dataset, num_gpus):
             bsz_meter.update(src_size)
             wpb_meter.update(ntokens)
             wps_meter.update(ntokens)
-            clip_meter.update(1 if grad_norm > args.clip_norm else 0)
-            gnorm_meter.update(grad_norm)
+            clip_meter.update(1 if loss_dict['gnorm'] > args.clip_norm else 0)
+
+            extra_postfix = []
+            for k, v in loss_dict.items():
+                extra_meters[k].update(v)
+                extra_postfix.append((k, '{:.4f}'.format(extra_meters[k].avg)))
 
             t.set_postfix(collections.OrderedDict([
                 ('loss', '{:.2f} ({:.2f})'.format(loss, loss_meter.avg)),
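Note: train_step now returns a dict with a mandatory 'loss' entry plus whatever extra metrics the trainer reports (the diff shows 'gnorm'); each extra key lazily gets its own running average via defaultdict. A self-contained sketch of the pattern, with a minimal AverageMeter written to match the update(val, n)/.avg usage above (the real class lives in fairseq.meters, and the 'nll_loss' key is made up for the demo):

    import collections

    class AverageMeter(object):
        """Running (optionally weighted) average, mirroring the usage above."""
        def __init__(self):
            self.sum, self.count, self.avg = 0.0, 0, 0.0

        def update(self, val, n=1):
            self.sum += val * n
            self.count += n
            self.avg = self.sum / self.count

    extra_meters = collections.defaultdict(lambda: AverageMeter())

    # Each batch's loss_dict can carry arbitrary extra metrics:
    for loss_dict in [{'gnorm': 0.5, 'nll_loss': 3.1}, {'gnorm': 0.7, 'nll_loss': 2.9}]:
        for k, v in loss_dict.items():
            extra_meters[k].update(v)

    print({k: m.avg for k, m in extra_meters.items()})
    # approximately {'gnorm': 0.6, 'nll_loss': 3.0}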
@@ -138,28 +161,50 @@ def train(args, epoch, batch_offset, trainer, criterion, dataset, num_gpus):
                 ('bsz', '{:5d}'.format(round(bsz_meter.avg))),
                 ('lr', lr),
                 ('clip', '{:3.0f}%'.format(clip_meter.avg * 100)),
-                ('gnorm', '{:.4f}'.format(gnorm_meter.avg)),
-            ]), refresh=False)
+            ] + extra_postfix), refresh=False)
 
             if i == 0:
                 # ignore the first mini-batch in words-per-second calculation
                 wps_meter.reset()
             if args.save_interval > 0 and (i + 1) % args.save_interval == 0:
-                trainer.save_checkpoint(args, epoch, i + 1)
+                save_checkpoint(trainer, args, epoch, i + 1)
 
-        fmt = desc + ' | train loss {:2.2f} | train ppl {:3.2f}'
-        fmt += ' | s/checkpoint {:7d} | words/s {:6d} | words/batch {:6d}'
-        fmt += ' | bsz {:5d} | lr {:0.6f} | clip {:3.0f}% | gnorm {:.4f}'
-        t.write(fmt.format(loss_meter.avg, math.pow(2, loss_meter.avg),
-                           round(wps_meter.elapsed_time),
-                           round(wps_meter.avg),
-                           round(wpb_meter.avg),
-                           round(bsz_meter.avg),
-                           lr, clip_meter.avg * 100,
-                           gnorm_meter.avg))
+        fmt = desc + ' | train loss {:2.2f} | train ppl {:3.2f}'.format(
+            loss_meter.avg, get_perplexity(loss_meter.avg))
+        fmt += ' | s/checkpoint {:7d} | words/s {:6d} | words/batch {:6d}'.format(
+            round(wps_meter.elapsed_time), round(wps_meter.avg), round(wpb_meter.avg))
+        fmt += ' | bsz {:5d} | lr {:0.6f} | clip {:3.0f}%'.format(
+            round(bsz_meter.avg), lr, clip_meter.avg * 100)
+        fmt += ''.join(
+            ' | {} {:.4f}'.format(k, meter.avg)
+            for k, meter in extra_meters.items()
+        )
+        t.write(fmt)
 
 
-def validate(args, epoch, trainer, criterion, dataset, subset, ngpus):
+def save_checkpoint(trainer, args, epoch, batch_offset, val_loss=None):
+    extra_state = {
+        'epoch': epoch,
+        'batch_offset': batch_offset,
+        'val_loss': val_loss,
+    }
+
+    if batch_offset == 0:
+        if not args.no_epoch_checkpoints:
+            epoch_filename = os.path.join(args.save_dir, 'checkpoint{}.pt'.format(epoch))
+            trainer.save_checkpoint(epoch_filename, extra_state)
+
+        assert val_loss is not None
+        if not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best:
+            save_checkpoint.best = val_loss
+            best_filename = os.path.join(args.save_dir, 'checkpoint_best.pt')
+            trainer.save_checkpoint(best_filename, extra_state)
+
+    last_filename = os.path.join(args.save_dir, 'checkpoint_last.pt')
+    trainer.save_checkpoint(last_filename, extra_state)
+
+
+def validate(args, epoch, trainer, dataset, subset, ngpus):
     """Evaluate the model on the validation set and return the average loss."""
     itr = dataset.dataloader(subset, batch_size=None,
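Note: val_loss is given a default of None above because the mid-epoch call site passes only four arguments; the assert only fires on the end-of-epoch path where batch_offset == 0. save_checkpoint also tracks the best validation loss across calls by stashing it on the function object itself (save_checkpoint.best), a lightweight alternative to a module-level global. The pattern in isolation:

    def save_best(val_loss):
        # Function attributes persist between calls, acting as cheap static state.
        if not hasattr(save_best, 'best') or val_loss < save_best.best:
            save_best.best = val_loss
            return True   # caller would write checkpoint_best.pt here
        return False

    print(save_best(4.0))  # True  -- first call always wins
    print(save_best(5.0))  # False -- not an improvement
    print(save_best(3.5))  # True  -- new best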
@@ -167,18 +212,35 @@ def validate(args, epoch, trainer, criterion, dataset, subset, ngpus):
                             max_positions=args.max_positions,
                             skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
     loss_meter = AverageMeter()
+    extra_meters = collections.defaultdict(lambda: AverageMeter())
 
     desc = '| epoch {:03d} | valid on \'{}\' subset'.format(epoch, subset)
     with progress_bar(itr, desc, leave=False) as t:
         for _, sample in data.skip_group_enumerator(t, ngpus):
+            loss_dict = trainer.valid_step(sample)
+            loss = loss_dict['loss']
+            del loss_dict['loss']  # don't include in extra_meters or extra_postfix
+
             ntokens = sum(s['ntokens'] for s in sample)
-            loss = trainer.valid_step(sample, criterion)
             loss_meter.update(loss, ntokens)
-            t.set_postfix(loss='{:.2f}'.format(loss_meter.avg), refresh=False)
+
+            extra_postfix = []
+            for k, v in loss_dict.items():
+                extra_meters[k].update(v)
+                extra_postfix.append((k, '{:.4f}'.format(extra_meters[k].avg)))
+
+            t.set_postfix(collections.OrderedDict([
+                ('loss', '{:.2f}'.format(loss_meter.avg)),
+            ] + extra_postfix), refresh=False)
 
         val_loss = loss_meter.avg
-        t.write(desc + ' | valid loss {:2.2f} | valid ppl {:3.2f}'
-                .format(val_loss, math.pow(2, val_loss)))
+        fmt = desc + ' | valid loss {:2.2f} | valid ppl {:3.2f}'.format(
+            val_loss, get_perplexity(val_loss))
+        fmt += ''.join(
+            ' | {} {:.4f}'.format(k, meter.avg)
+            for k, meter in extra_meters.items()
+        )
+        t.write(fmt)
 
     # update and return the learning rate
     return val_loss
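Note: loss_meter.update(loss, ntokens) weights each batch by its token count, so val_loss (and hence the reported perplexity) is a per-token average rather than a per-batch one. A quick worked example of the difference, assuming AverageMeter keeps a weighted running mean as its usage implies:

    # Two batches: 10 tokens at loss 2.0, 1000 tokens at loss 4.0.
    # Unweighted mean:  (2.0 + 4.0) / 2 = 3.0
    # Token-weighted:   (2.0*10 + 4.0*1000) / 1010 ~= 3.98  <- what loss_meter.avg reports
    losses, ntokens = [2.0, 4.0], [10, 1000]
    weighted = sum(l * n for l, n in zip(losses, ntokens)) / sum(ntokens)
    print(round(weighted, 2))  # 3.98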
...