Unverified commit d3795d6c, authored by Myle Ott and committed by GitHub

Merge internal changes (#136)

Changes:
- 7d19e36: Add `--sampling` flag to generate.py to sample instead of doing beam search
- c777340: Add `scripts/average_checkpoints.py` to average multiple checkpoints into a combined model
- 3ea882c: Add `--max-update` option to train.py to stop training after a given number of updates
- Small bugfixes for distributed training, the LSTM model, and the inverse square root LR scheduler
parent 48836525
@@ -60,7 +60,7 @@ class Trainer(object):

     def save_checkpoint(self, filename, extra_state):
         """Save all training state in a checkpoint file."""
-        if self.args.distributed_rank == 0:  # only save one checkpoint
+        if distributed_utils.is_master(self.args):  # only save one checkpoint
             utils.save_state(filename, self.args, self.model, self.criterion, self.optimizer,
                              self.lr_scheduler, self._num_updates, self._optim_history, extra_state)
......
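The `distributed_utils.is_master` helper itself is not shown in this diff; judging purely from the condition it replaces, a minimal sketch would be:

```python
# Sketch only: inferred from the replaced condition, not taken from this diff.
def is_master(args):
    return args.distributed_rank == 0
```

Centralizing the rank-0 check in one helper keeps other call sites from drifting out of sync with the checkpointing logic.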
@@ -15,8 +15,6 @@ import traceback

 from torch.autograd import Variable
 from torch.serialization import default_restore_location

-from fairseq import tokenizer

 def torch_persistent_save(*args, **kwargs):
     for i in range(3):

@@ -116,11 +114,16 @@ def _upgrade_state_dict(state):
     return state

-def load_ensemble_for_inference(filenames, src_dict=None, dst_dict=None, data_dir=None):
+def load_ensemble_for_inference(filenames, src_dict=None, dst_dict=None,
+                                data_dir=None, model_arg_overrides=None):
     """Load an ensemble of models for inference.

     The source and target dictionaries can be given explicitly, or loaded from
     the `data_dir` directory.
+
+    model_arg_overrides is an optional dict {'arg_name': arg} used to override
+    model args that were used during model training.
     """
     from fairseq import data, models

@@ -133,7 +136,8 @@ def load_ensemble_for_inference(filenames, src_dict=None, dst_dict=None, data_di
             torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))
         )
     args = states[0]['args']
-    args = _upgrade_args(args)
+    if model_arg_overrides is not None:
+        args = _override_model_args(args, model_arg_overrides)

     if src_dict is None or dst_dict is None:
         assert data_dir is not None

@@ -148,12 +152,10 @@ def load_ensemble_for_inference(filenames, src_dict=None, dst_dict=None, data_di
     return ensemble, args

-def _upgrade_args(args):
-    if not hasattr(args, 'max_source_positions'):
-        args.max_source_positions = args.max_positions
-        args.max_target_positions = args.max_positions
-    if not hasattr(args, 'share_input_output_embed'):
-        args.share_input_output_embed = False
+def _override_model_args(args, model_arg_overrides):
+    # Uses model_arg_overrides {'arg_name': arg} to override model args
+    for arg_name, arg_val in model_arg_overrides.items():
+        setattr(args, arg_name, arg_val)
     return args
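A hypothetical use of the new `model_arg_overrides` parameter; the checkpoint path, data directory, and override key below are placeholders for illustration, not values from this diff:

```python
from fairseq import utils

# Load an ensemble but override one of the args the model was trained with.
ensemble, args = utils.load_ensemble_for_inference(
    ['checkpoints/checkpoint_best.pt'],        # hypothetical checkpoint path
    data_dir='data-bin/example',               # hypothetical data directory
    model_arg_overrides={'max_source_positions': 2048},
)
```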
@@ -247,6 +249,7 @@ def load_align_dict(replace_unk):

 def replace_unk(hypo_str, src_str, alignment, align_dict, unk):
+    from fairseq import tokenizer
     # Tokens are strings here
     hypo_tokens = tokenizer.tokenize_line(hypo_str)
     # TODO: Very rare cases where the replacement is '<eos>' should be handled gracefully

@@ -260,6 +263,7 @@ def replace_unk(hypo_str, src_str, alignment, align_dict, unk):

 def post_process_prediction(hypo_tokens, src_str, alignment, align_dict, dst_dict, remove_bpe):
+    from fairseq import tokenizer
     hypo_str = dst_dict.string(hypo_tokens, remove_bpe)
     if align_dict is not None:
         hypo_str = replace_unk(hypo_str, src_str, alignment, align_dict, dst_dict.unk_string())

@@ -270,6 +274,27 @@ def post_process_prediction(hypo_tokens, src_str, alignment, align_dict, dst_dic
     return hypo_tokens, hypo_str, alignment

+def make_positions(tensor, padding_idx, left_pad):
+    """Replace non-padding symbols with their position numbers.
+
+    Position numbers begin at padding_idx+1.
+
+    Padding symbols are ignored, but it is necessary to specify whether padding
+    is added on the left side (left_pad=True) or right side (left_pad=False).
+    """
+    max_pos = padding_idx + 1 + tensor.size(1)
+    if not hasattr(make_positions, 'range_buf'):
+        make_positions.range_buf = tensor.new()
+    make_positions.range_buf = make_positions.range_buf.type_as(tensor)
+    if make_positions.range_buf.numel() < max_pos:
+        torch.arange(padding_idx + 1, max_pos, out=make_positions.range_buf)
+    mask = tensor.ne(padding_idx)
+    positions = make_positions.range_buf[:tensor.size(1)].expand_as(tensor)
+    if left_pad:
+        positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1)
+    return tensor.clone().masked_scatter_(mask, positions[mask])

 def strip_pad(tensor, pad):
     return tensor[tensor.ne(pad)]

@@ -303,6 +328,7 @@ def convert_padding_direction(
         index = torch.remainder(range + num_pads, max_len)
         return src_tokens.gather(1, index)

 def item(tensor):
     if hasattr(tensor, 'item'):
         return tensor.item()
......
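A small worked example of the new `make_positions`, mirroring the unit test added further below: with `padding_idx=1`, non-pad symbols are numbered from `padding_idx + 1 = 2` onward, and pad positions keep the pad index itself.

```python
import torch
from fairseq import utils

x = torch.LongTensor([[1, 9, 9],    # one pad token on the left
                      [9, 9, 9]])   # no padding
print(utils.make_positions(x, padding_idx=1, left_pad=True))
# -> [[1, 2, 3],
#     [2, 3, 4]]
```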
@@ -15,7 +15,10 @@ from fairseq.sequence_scorer import SequenceScorer

 def main(args):
+    assert args.path is not None, '--path required for generation!'
     print(args)
+    assert not args.sampling or args.nbest == args.beam, \
+        '--sampling requires --nbest to be equal to --beam'

     use_cuda = torch.cuda.is_available() and not args.cpu

@@ -77,7 +80,7 @@ def main(args):
     translator = SequenceGenerator(
         models, beam_size=args.beam, stop_early=(not args.no_early_stop),
         normalize_scores=(not args.unnormalized), len_penalty=args.lenpen,
-        unk_penalty=args.unkpen)
+        unk_penalty=args.unkpen, sampling=args.sampling)

     if use_cuda:
         translator.cuda()
......
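The new `sampling` flag switches `SequenceGenerator` from beam search's top-k selection to drawing candidates from the next-token distribution. A minimal sketch of the distinction, not fairseq's actual implementation; `lprobs` is assumed to be a `[batch, vocab]` tensor of log-probabilities:

```python
import torch

def select_next_tokens(lprobs, beam_size, sampling=False):
    # lprobs: [batch, vocab] log-probabilities over the next token
    if sampling:
        # Draw beam_size candidates per hypothesis from the distribution.
        indices = torch.multinomial(lprobs.exp(), beam_size, replacement=True)
        scores = lprobs.gather(1, indices)
    else:
        # Beam search deterministically keeps the beam_size highest-scoring tokens.
        scores, indices = lprobs.topk(beam_size, dim=1)
    return scores, indices
```

Under this reading the `--nbest == --beam` assertion makes sense: with sampling there is no score-ordered candidate list to truncate, so every sampled hypothesis is returned.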
@@ -16,6 +16,8 @@ from fairseq.sequence_generator import SequenceGenerator

 def main(args):
     print(args)
+    assert not args.sampling or args.nbest == args.beam, \
+        '--sampling requires --nbest to be equal to --beam'

     use_cuda = torch.cuda.is_available() and not args.cpu
......
#!/usr/bin/env python3

import argparse
import collections

import torch


def average_checkpoints(inputs):
    """Loads checkpoints from inputs and returns a model with averaged weights.

    Args:
      inputs: An iterable of string paths of checkpoints to load from.

    Returns:
      A dict of string keys mapping to various values. The 'model' key
      from the returned dict should correspond to an OrderedDict mapping
      string parameter names to torch Tensors.
    """
    params_dict = collections.OrderedDict()
    params_keys = None
    new_state = None
    for f in inputs:
        state = torch.load(
            f,
            map_location=(
                lambda s, _: torch.serialization.default_restore_location(s, 'cpu')
            ),
        )
        # Copies over the settings from the first checkpoint
        if new_state is None:
            new_state = state

        model_params = state['model']

        model_params_keys = list(model_params.keys())
        if params_keys is None:
            params_keys = model_params_keys
        elif params_keys != model_params_keys:
            raise KeyError(
                'For checkpoint {}, expected list of params: {}, '
                'but found: {}'.format(f, params_keys, model_params_keys)
            )

        for k in params_keys:
            if k not in params_dict:
                params_dict[k] = []
            params_dict[k].append(model_params[k])

    averaged_params = collections.OrderedDict()
    # v is a list of torch Tensors.
    for k, v in params_dict.items():
        summed_v = None
        for x in v:
            summed_v = summed_v + x if summed_v is not None else x
        averaged_params[k] = summed_v / len(v)
    new_state['model'] = averaged_params
    return new_state


def main():
    parser = argparse.ArgumentParser(
        description='Tool to average the params of input checkpoints to '
                    'produce a new checkpoint',
    )
    parser.add_argument(
        '--inputs',
        required=True,
        nargs='+',
        help='Input checkpoint file paths.',
    )
    parser.add_argument(
        '--output',
        required=True,
        metavar='FILE',
        help='Write the new checkpoint containing the averaged weights to '
             'this path.',
    )
    args = parser.parse_args()
    print(args)
    new_state = average_checkpoints(args.inputs)
    torch.save(new_state, args.output)
    print('Finished writing averaged checkpoint to {}.'.format(args.output))


if __name__ == '__main__':
    main()
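Hypothetical programmatic use of the new script; the checkpoint paths are placeholders, not values from this diff:

```python
import torch
from scripts.average_checkpoints import average_checkpoints

# Average the parameters of several saved checkpoints and write the result.
new_state = average_checkpoints([
    'checkpoints/checkpoint31.pt',
    'checkpoints/checkpoint32.pt',
    'checkpoints/checkpoint33.pt',
])
torch.save(new_state, 'checkpoints/averaged.pt')
```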
@@ -71,6 +71,7 @@ def main(args):

     # Train until the learning rate gets too small
     max_epoch = args.max_epoch or math.inf
+    max_update = args.max_update or math.inf
     lr = trainer.get_lr()
     train_meter = StopwatchMeter()
     train_meter.start()

@@ -79,6 +80,7 @@ def main(args):
         train(args, trainer, dataset, epoch, batch_offset)

         # evaluate on validate set
+        if epoch % args.validate_interval == 0:
             for k, subset in enumerate(args.valid_subset.split(',')):
                 val_loss = validate(args, trainer, dataset, subset, epoch)
                 if k == 0:

@@ -88,9 +90,14 @@ def main(args):
                     # save checkpoint
                     if not args.no_save:
                         save_checkpoint(trainer, args, epoch, 0, val_loss)
+        else:
+            lr = trainer.lr_step(epoch)

         epoch += 1
         batch_offset = 0

+        if trainer.get_num_updates() >= max_update:
+            break
+
     train_meter.stop()
     print('| done training in {:.1f} seconds'.format(train_meter.sum))
@@ -134,6 +141,7 @@ def train(args, trainer, dataset, epoch, batch_offset):
         meter.reset()

     extra_meters = collections.defaultdict(lambda: AverageMeter())
+    max_update = args.max_update or math.inf
     for i, sample in enumerate(itr, start=batch_offset):
         log_output = trainer.train_step(sample)

@@ -142,6 +150,9 @@ def train(args, trainer, dataset, epoch, batch_offset):
         for k, v in log_output.items():
             if k in ['loss', 'nll_loss']:
                 continue  # these are already logged above
-            extra_meters[k].update(v)
+            if 'loss' in k:
+                extra_meters[k].update(v, log_output['sample_size'])
+            else:
+                extra_meters[k].update(v)
             stats[k] = extra_meters[k].avg
         progress.log(stats)

@@ -150,9 +161,15 @@ def train(args, trainer, dataset, epoch, batch_offset):
         if i == batch_offset:
             # ignore the first mini-batch in words-per-second calculation
             trainer.get_meter('wps').reset()

-        if args.save_interval > 0 and trainer.get_num_updates() % args.save_interval == 0:
+        # save mid-epoch checkpoints
+        num_updates = trainer.get_num_updates()
+        if args.save_interval > 0 and num_updates > 0 and num_updates % args.save_interval == 0:
             save_checkpoint(trainer, args, epoch, i + 1)

+        if num_updates >= max_update:
+            break
+
     # log end-of-epoch stats
     stats = get_training_stats(trainer)
     for k, meter in extra_meters.items():
......
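One of the train.py fixes above adds a `num_updates > 0` guard to the mid-epoch checkpoint condition; without it, `0 % save_interval == 0` would trigger a save before any update has happened. A self-contained illustration of the corrected predicate (the helper name is ours, not fairseq's):

```python
def should_save_mid_epoch(num_updates, save_interval):
    # Mirrors the corrected condition in train.py's inner loop.
    return save_interval > 0 and num_updates > 0 and num_updates % save_interval == 0

assert not should_save_mid_epoch(0, 1000)     # no save before the first update
assert should_save_mid_epoch(1000, 1000)      # save exactly at the interval
assert not should_save_mid_epoch(1500, 1000)  # no save between intervals
```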
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import collections
import os
import tempfile
import unittest

import numpy as np
import torch

from scripts.average_checkpoints import average_checkpoints

class TestAverageCheckpoints(unittest.TestCase):

    def test_average_checkpoints(self):
        params_0 = collections.OrderedDict(
            [
                ('a', torch.DoubleTensor([100.0])),
                ('b', torch.FloatTensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])),
                ('c', torch.IntTensor([7, 8, 9])),
            ]
        )
        params_1 = collections.OrderedDict(
            [
                ('a', torch.DoubleTensor([1.0])),
                ('b', torch.FloatTensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])),
                ('c', torch.IntTensor([2, 2, 2])),
            ]
        )
        params_avg = collections.OrderedDict(
            [
                ('a', torch.DoubleTensor([50.5])),
                ('b', torch.FloatTensor([[1.0, 1.5, 2.0], [2.5, 3.0, 3.5]])),
                # We expect truncation for integer division
                ('c', torch.IntTensor([4, 5, 5])),
            ]
        )

        fd_0, path_0 = tempfile.mkstemp()
        fd_1, path_1 = tempfile.mkstemp()
        torch.save(collections.OrderedDict([('model', params_0)]), path_0)
        torch.save(collections.OrderedDict([('model', params_1)]), path_1)

        output = average_checkpoints([path_0, path_1])['model']

        os.close(fd_0)
        os.remove(path_0)
        os.close(fd_1)
        os.remove(path_1)

        for (k_expected, v_expected), (k_out, v_out) in zip(
                params_avg.items(), output.items()):
            self.assertEqual(
                k_expected, k_out, 'Key mismatch - expected {} but found {}. '
                '(Expected list of keys: {} vs actual list of keys: {})'.format(
                    k_expected, k_out, params_avg.keys(), output.keys()
                )
            )
            np.testing.assert_allclose(
                v_expected.numpy(),
                v_out.numpy(),
                err_msg='Tensor value mismatch for key {}'.format(k_expected)
            )


if __name__ == '__main__':
    unittest.main()
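The expected values for key 'c' rely on integer tensor division truncating, which the averaging code hits when parameters are stored as an IntTensor. The same truncation can be reproduced with floor division:

```python
import torch

# (7+2)//2 = 4, (8+2)//2 = 5, (9+2)//2 = 5 -- matching the test's expectation.
print((torch.IntTensor([7, 8, 9]) + torch.IntTensor([2, 2, 2])) // 2)
```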
@@ -43,7 +43,7 @@ class TestBinaries(unittest.TestCase):
         data = 97 + torch.floor(26 * data).int()
         with open(os.path.join(data_dir, filename), 'w') as h:
             offset = 0
-            for i in range(num_examples):
+            for _ in range(num_examples):
                 ex_len = random.randint(1, maxlen)
                 ex_str = ' '.join(map(chr, data[offset:offset+ex_len]))
                 print(ex_str, file=h)
......
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
# This source code is licensed under the license found in the LICENSE file in # This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights # the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory. # can be found in the PATENTS file in the same directory.
#
import torch import torch
import unittest import unittest
......
...@@ -10,9 +10,7 @@ import copy ...@@ -10,9 +10,7 @@ import copy
import unittest import unittest
import torch import torch
from torch.autograd import Variable
from fairseq import utils
from fairseq.criterions.cross_entropy import CrossEntropyCriterion from fairseq.criterions.cross_entropy import CrossEntropyCriterion
from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion
...@@ -29,7 +27,7 @@ class TestLabelSmoothing(unittest.TestCase): ...@@ -29,7 +27,7 @@ class TestLabelSmoothing(unittest.TestCase):
self.assertEqual(self.d.pad(), 1) self.assertEqual(self.d.pad(), 1)
self.assertEqual(self.d.eos(), 2) self.assertEqual(self.d.eos(), 2)
self.assertEqual(self.d.unk(), 3) self.assertEqual(self.d.unk(), 3)
pad, eos, unk, w1, w2, w3 = 1, 2, 3, 4, 5, 6 pad, eos, unk, w1, w2, w3 = 1, 2, 3, 4, 5, 6 # noqa: F841
# build dataset # build dataset
self.data = [ self.data = [
......
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
# This source code is licensed under the license found in the LICENSE file in # This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights # the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory. # can be found in the PATENTS file in the same directory.
#
import argparse import argparse
import unittest import unittest
......
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
# This source code is licensed under the license found in the LICENSE file in # This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights # the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory. # can be found in the PATENTS file in the same directory.
#
import argparse import argparse
import unittest import unittest
...@@ -86,7 +85,7 @@ class TestSequenceScorer(unittest.TestCase): ...@@ -86,7 +85,7 @@ class TestSequenceScorer(unittest.TestCase):
model = test_utils.TestModel.build_model(args, d, d) model = test_utils.TestModel.build_model(args, d, d)
scorer = SequenceScorer([model]) scorer = SequenceScorer([model])
for id, src, ref, hypos in scorer.score_batched_itr(data_itr): for id, _src, _ref, hypos in scorer.score_batched_itr(data_itr):
self.assertHypoTokens(hypos[0], data[id]['target']) self.assertHypoTokens(hypos[0], data[id]['target'])
self.assertHypoScore(hypos[0], expected_scores[id]) self.assertHypoScore(hypos[0], expected_scores[id])
......
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
# This source code is licensed under the license found in the LICENSE file in # This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights # the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory. # can be found in the PATENTS file in the same directory.
#
import unittest import unittest
...@@ -49,6 +48,38 @@ class TestUtils(unittest.TestCase): ...@@ -49,6 +48,38 @@ class TestUtils(unittest.TestCase):
), ),
) )
def test_make_positions(self):
pad = 1
left_pad_input = torch.LongTensor([
[9, 9, 9, 9, 9],
[1, 9, 9, 9, 9],
[1, 1, 1, 9, 9],
])
left_pad_output = torch.LongTensor([
[2, 3, 4, 5, 6],
[1, 2, 3, 4, 5],
[1, 1, 1, 2, 3],
])
right_pad_input = torch.LongTensor([
[9, 9, 9, 9, 9],
[9, 9, 9, 9, 1],
[9, 9, 1, 1, 1],
])
right_pad_output = torch.LongTensor([
[2, 3, 4, 5, 6],
[2, 3, 4, 5, 1],
[2, 3, 1, 1, 1],
])
self.assertAlmostEqual(
left_pad_output,
utils.make_positions(left_pad_input, pad, left_pad=True),
)
self.assertAlmostEqual(
right_pad_output,
utils.make_positions(right_pad_input, pad, left_pad=False),
)
def test_make_variable(self): def test_make_variable(self):
t = [{'k': torch.rand(5, 5)}] t = [{'k': torch.rand(5, 5)}]
......
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
# This source code is licensed under the license found in the LICENSE file in # This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights # the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory. # can be found in the PATENTS file in the same directory.
#
import torch import torch
from torch.autograd import Variable from torch.autograd import Variable
...@@ -137,10 +136,11 @@ class TestIncrementalDecoder(FairseqIncrementalDecoder): ...@@ -137,10 +136,11 @@ class TestIncrementalDecoder(FairseqIncrementalDecoder):
def get_normalized_probs(self, net_output, log_probs): def get_normalized_probs(self, net_output, log_probs):
# the decoder returns probabilities directly # the decoder returns probabilities directly
probs = net_output[0]
if log_probs: if log_probs:
return net_output.log() return probs.log()
else: else:
return net_output return probs
def max_positions(self): def max_positions(self):
return self.args.max_decoder_positions return self.args.max_decoder_positions
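The `get_normalized_probs` fix reflects that the test decoder's output is a tuple whose first element holds the probabilities (the second element is typically attention). A tiny illustration with made-up values:

```python
import torch

net_output = (torch.tensor([[0.25, 0.75]]), None)  # (probs, attn); values illustrative
probs = net_output[0]
print(probs.log())  # log-probs are taken from the first element only
```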