Commit c83efd21, authored by Myle Ott, committed by GitHub
Browse files

Merge pull request #33 from facebookresearch/oss-merge-internal

Changes:
Add support for NCCL v2
Add support for additional optimizers
SequenceGenerator returns attention matrix
Misc bugfixes (e.g., fixes #32) and cleanup
parents af86c1ac 104cead1
#!/usr/bin/env python3
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
......@@ -11,11 +12,10 @@ import os
import torch
import math
from fairseq import bleu, data, options, utils
from fairseq import data, options, utils
from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter
from fairseq.multiprocessing_trainer import MultiprocessingTrainer
from fairseq.progress_bar import progress_bar
from fairseq.sequence_generator import SequenceGenerator
def main():
......@@ -52,7 +52,7 @@ def main():
print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
for split in dataset.splits:
for split in ['train', 'valid']:
print('| {} {} {} examples'.format(args.data, split, len(dataset.splits[split])))
if not torch.cuda.is_available():
......@@ -61,16 +61,25 @@ def main():
print('| using {} GPUs (with max tokens per GPU = {})'.format(num_gpus, args.max_tokens))
# Build model
print('| model {}'.format(args.arch))
model = utils.build_model(args, dataset)
criterion = utils.build_criterion(args, dataset)
# Build model and criterion
model = utils.build_model(args, dataset.src_dict, dataset.dst_dict)
criterion = utils.build_criterion(args, dataset.src_dict, dataset.dst_dict)
print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
# Start multiprocessing
trainer = MultiprocessingTrainer(args, model)
trainer = MultiprocessingTrainer(args, model, criterion)
# Load the latest checkpoint if one is available
epoch, batch_offset = trainer.load_checkpoint(os.path.join(args.save_dir, args.restore_file))
checkpoint_path = os.path.join(args.save_dir, args.restore_file)
extra_state = trainer.load_checkpoint(checkpoint_path)
if extra_state is not None:
epoch = extra_state['epoch']
batch_offset = extra_state['batch_offset']
print('| loaded checkpoint {} (epoch {})'.format(checkpoint_path, epoch))
if batch_offset == 0:
epoch += 1
else:
epoch, batch_offset = 1, 0
# Train until the learning rate gets too small
val_loss = None
......@@ -80,15 +89,15 @@ def main():
train_meter.start()
while lr > args.min_lr and epoch <= max_epoch:
# train for one epoch
train(args, epoch, batch_offset, trainer, criterion, dataset, num_gpus)
train(args, epoch, batch_offset, trainer, dataset, num_gpus)
# evaluate on validate set
for k, subset in enumerate(args.valid_subset.split(',')):
val_loss = validate(args, epoch, trainer, criterion, dataset, subset, num_gpus)
val_loss = validate(args, epoch, trainer, dataset, subset, num_gpus)
if k == 0:
if not args.no_save:
# save checkpoint
trainer.save_checkpoint(args, epoch, 0, val_loss)
save_checkpoint(trainer, args, epoch, 0, val_loss)
# only use first validation loss to update the learning schedule
lr = trainer.lr_step(val_loss, epoch)
......@@ -101,7 +110,14 @@ def main():
trainer.stop()
def train(args, epoch, batch_offset, trainer, criterion, dataset, num_gpus):
def get_perplexity(loss):
    """Return ``2 ** loss`` as a float, saturating to infinity on overflow.

    Callers in this file print the result as the "ppl" column next to the
    averaged loss.

    Args:
        loss: numeric exponent (the averaged loss value).

    Returns:
        ``math.pow(2, loss)``, or ``float('inf')`` when the result does not
        fit in a float.
    """
    try:
        return math.pow(2, loss)
    except OverflowError:
        # math.pow raises instead of returning inf when the result is out of
        # float range; a diverged loss should not crash the progress report.
        return float('inf')
def train(args, epoch, batch_offset, trainer, dataset, num_gpus):
"""Train the model for one epoch."""
itr = dataset.dataloader(args.train_subset, num_workers=args.workers,
......@@ -114,13 +130,16 @@ def train(args, epoch, batch_offset, trainer, criterion, dataset, num_gpus):
wpb_meter = AverageMeter() # words per batch
wps_meter = TimeMeter() # words per second
clip_meter = AverageMeter() # % of updates clipped
gnorm_meter = AverageMeter() # gradient norm
extra_meters = collections.defaultdict(lambda: AverageMeter())
desc = '| epoch {:03d}'.format(epoch)
trainer.set_seed(args.seed + epoch)
lr = trainer.get_lr()
with progress_bar(itr, desc, leave=False) as t:
for i, sample in data.skip_group_enumerator(t, num_gpus, batch_offset):
loss, grad_norm = trainer.train_step(sample, criterion)
loss_dict = trainer.train_step(sample)
loss = loss_dict['loss']
del loss_dict['loss'] # don't include in extra_meters or extra_postfix
ntokens = sum(s['ntokens'] for s in sample)
src_size = sum(s['src_tokens'].size(0) for s in sample)
......@@ -128,8 +147,12 @@ def train(args, epoch, batch_offset, trainer, criterion, dataset, num_gpus):
bsz_meter.update(src_size)
wpb_meter.update(ntokens)
wps_meter.update(ntokens)
clip_meter.update(1 if grad_norm > args.clip_norm else 0)
gnorm_meter.update(grad_norm)
clip_meter.update(1 if loss_dict['gnorm'] > args.clip_norm else 0)
extra_postfix = []
for k, v in loss_dict.items():
extra_meters[k].update(v)
extra_postfix.append((k, '{:.4f}'.format(extra_meters[k].avg)))
t.set_postfix(collections.OrderedDict([
('loss', '{:.2f} ({:.2f})'.format(loss, loss_meter.avg)),
......@@ -138,28 +161,50 @@ def train(args, epoch, batch_offset, trainer, criterion, dataset, num_gpus):
('bsz', '{:5d}'.format(round(bsz_meter.avg))),
('lr', lr),
('clip', '{:3.0f}%'.format(clip_meter.avg * 100)),
('gnorm', '{:.4f}'.format(gnorm_meter.avg)),
]), refresh=False)
] + extra_postfix), refresh=False)
if i == 0:
# ignore the first mini-batch in words-per-second calculation
wps_meter.reset()
if args.save_interval > 0 and (i + 1) % args.save_interval == 0:
trainer.save_checkpoint(args, epoch, i + 1)
fmt = desc + ' | train loss {:2.2f} | train ppl {:3.2f}'
fmt += ' | s/checkpoint {:7d} | words/s {:6d} | words/batch {:6d}'
fmt += ' | bsz {:5d} | lr {:0.6f} | clip {:3.0f}% | gnorm {:.4f}'
t.write(fmt.format(loss_meter.avg, math.pow(2, loss_meter.avg),
round(wps_meter.elapsed_time),
round(wps_meter.avg),
round(wpb_meter.avg),
round(bsz_meter.avg),
lr, clip_meter.avg * 100,
gnorm_meter.avg))
def validate(args, epoch, trainer, criterion, dataset, subset, ngpus):
save_checkpoint(trainer, args, epoch, i + 1)
fmt = desc + ' | train loss {:2.2f} | train ppl {:3.2f}'.format(
loss_meter.avg, get_perplexity(loss_meter.avg))
fmt += ' | s/checkpoint {:7d} | words/s {:6d} | words/batch {:6d}'.format(
round(wps_meter.elapsed_time), round(wps_meter.avg), round(wpb_meter.avg))
fmt += ' | bsz {:5d} | lr {:0.6f} | clip {:3.0f}%'.format(
round(bsz_meter.avg), lr, clip_meter.avg * 100)
fmt += ''.join(
' | {} {:.4f}'.format(k, meter.avg)
for k, meter in extra_meters.items()
)
t.write(fmt)
def save_checkpoint(trainer, args, epoch, batch_offset, val_loss=None):
    """Write checkpoint file(s) for the current training position.

    ``checkpoint_last.pt`` is always written. When ``batch_offset == 0``
    (the call made after validation at the end of an epoch) this also writes
    ``checkpoint<epoch>.pt`` unless ``args.no_epoch_checkpoints`` is set, and
    ``checkpoint_best.pt`` whenever ``val_loss`` is lower than every value
    seen by earlier calls in this process.

    Args:
        trainer: object exposing ``save_checkpoint(filename, extra_state)``.
        args: namespace providing ``save_dir`` and ``no_epoch_checkpoints``.
        epoch: epoch number stored in the checkpoint's extra state.
        batch_offset: number of batches already consumed in ``epoch``;
            0 marks an end-of-epoch checkpoint.
        val_loss: validation loss for end-of-epoch checkpoints. Defaults to
            ``None`` so that mid-epoch callers, which have no validation
            loss yet, can omit it.

    Raises:
        AssertionError: if ``batch_offset == 0`` and ``val_loss`` is ``None``.
    """
    extra_state = {
        'epoch': epoch,
        'batch_offset': batch_offset,
        'val_loss': val_loss,
    }

    if batch_offset == 0:
        if not args.no_epoch_checkpoints:
            epoch_filename = os.path.join(args.save_dir, 'checkpoint{}.pt'.format(epoch))
            trainer.save_checkpoint(epoch_filename, extra_state)

        assert val_loss is not None
        # The best loss so far is kept as an attribute on this function, so it
        # only covers calls made in the current process.
        if not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best:
            save_checkpoint.best = val_loss
            best_filename = os.path.join(args.save_dir, 'checkpoint_best.pt')
            trainer.save_checkpoint(best_filename, extra_state)

    last_filename = os.path.join(args.save_dir, 'checkpoint_last.pt')
    trainer.save_checkpoint(last_filename, extra_state)
def validate(args, epoch, trainer, dataset, subset, ngpus):
"""Evaluate the model on the validation set and return the average loss."""
itr = dataset.dataloader(subset, batch_size=None,
......@@ -167,18 +212,35 @@ def validate(args, epoch, trainer, criterion, dataset, subset, ngpus):
max_positions=args.max_positions,
skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
loss_meter = AverageMeter()
extra_meters = collections.defaultdict(lambda: AverageMeter())
desc = '| epoch {:03d} | valid on \'{}\' subset'.format(epoch, subset)
with progress_bar(itr, desc, leave=False) as t:
for _, sample in data.skip_group_enumerator(t, ngpus):
loss_dict = trainer.valid_step(sample)
loss = loss_dict['loss']
del loss_dict['loss'] # don't include in extra_meters or extra_postfix
ntokens = sum(s['ntokens'] for s in sample)
loss = trainer.valid_step(sample, criterion)
loss_meter.update(loss, ntokens)
t.set_postfix(loss='{:.2f}'.format(loss_meter.avg), refresh=False)
extra_postfix = []
for k, v in loss_dict.items():
extra_meters[k].update(v)
extra_postfix.append((k, '{:.4f}'.format(extra_meters[k].avg)))
t.set_postfix(collections.OrderedDict([
('loss', '{:.2f}'.format(loss_meter.avg)),
] + extra_postfix), refresh=False)
val_loss = loss_meter.avg
t.write(desc + ' | valid loss {:2.2f} | valid ppl {:3.2f}'
.format(val_loss, math.pow(2, val_loss)))
fmt = desc + ' | valid loss {:2.2f} | valid ppl {:3.2f}'.format(
val_loss, get_perplexity(val_loss))
fmt += ''.join(
' | {} {:.4f}'.format(k, meter.avg)
for k, meter in extra_meters.items()
)
t.write(fmt)
# update and return the learning rate
return val_loss
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment