"third_party/HugeCTR/gpu_cache/src/static_table.hip" did not exist on "69a532c1aba4bac714cce6746b610ba2b20b835c"
Unverified commit d3795d6c, authored by Myle Ott and committed by GitHub

Merge internal changes (#136)

Changes:
- 7d19e36: Add `--sampling` flag to generate.py to sample instead of doing beam search
- c777340: Add `scripts/average_checkpoints.py` to average multiple checkpoints into a combined model
- 3ea882c: Add `--max-update` option to train.py to stop training after a given number of updates
- Small bugfixes for distributed training, LSTM, and the inverse square root LR scheduler
parent 48836525
......@@ -60,7 +60,7 @@ class Trainer(object):
def save_checkpoint(self, filename, extra_state):
"""Save all training state in a checkpoint file."""
if self.args.distributed_rank == 0: # only save one checkpoint
if distributed_utils.is_master(self.args): # only save one checkpoint
utils.save_state(filename, self.args, self.model, self.criterion, self.optimizer,
self.lr_scheduler, self._num_updates, self._optim_history, extra_state)
......
......@@ -15,8 +15,6 @@ import traceback
from torch.autograd import Variable
from torch.serialization import default_restore_location
from fairseq import tokenizer
def torch_persistent_save(*args, **kwargs):
for i in range(3):
......@@ -116,11 +114,16 @@ def _upgrade_state_dict(state):
return state
def load_ensemble_for_inference(filenames, src_dict=None, dst_dict=None, data_dir=None):
def load_ensemble_for_inference(filenames, src_dict=None, dst_dict=None,
data_dir=None, model_arg_overrides=None):
"""Load an ensemble of models for inference.
The source and target dictionaries can be given explicitly, or loaded from
the `data_dir` directory.
model_arg_overrides allows you to pass a dictionary of the form
{'arg_name': arg} to override model args that were used during training.
"""
from fairseq import data, models
......@@ -133,7 +136,8 @@ def load_ensemble_for_inference(filenames, src_dict=None, dst_dict=None, data_di
torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))
)
args = states[0]['args']
args = _upgrade_args(args)
if model_arg_overrides is not None:
args = _override_model_args(args, model_arg_overrides)
if src_dict is None or dst_dict is None:
assert data_dir is not None
......@@ -148,12 +152,10 @@ def load_ensemble_for_inference(filenames, src_dict=None, dst_dict=None, data_di
return ensemble, args
def _upgrade_args(args):
if not hasattr(args, 'max_source_positions'):
args.max_source_positions = args.max_positions
args.max_target_positions = args.max_positions
if not hasattr(args, 'share_input_output_embed'):
args.share_input_output_embed = False
def _override_model_args(args, model_arg_overrides):
# Uses model_arg_overrides {'arg_name': arg} to override model args
for arg_name, arg_val in model_arg_overrides.items():
setattr(args, arg_name, arg_val)
return args
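A minimal usage sketch of the new model_arg_overrides hook (the checkpoint path, data directory, and overridden argument name below are placeholders, not values from this change):

from fairseq import utils

# Hypothetical paths and arg name: each entry in model_arg_overrides replaces
# the value stored in the checkpoint's args before the models are rebuilt.
models, args = utils.load_ensemble_for_inference(
    ['checkpoints/checkpoint_best.pt'],
    data_dir='data-bin/example',
    model_arg_overrides={'dropout': 0.0},
)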
......@@ -247,6 +249,7 @@ def load_align_dict(replace_unk):
def replace_unk(hypo_str, src_str, alignment, align_dict, unk):
from fairseq import tokenizer
# Tokens are strings here
hypo_tokens = tokenizer.tokenize_line(hypo_str)
# TODO: Very rare cases where the replacement is '<eos>' should be handled gracefully
......@@ -260,6 +263,7 @@ def replace_unk(hypo_str, src_str, alignment, align_dict, unk):
def post_process_prediction(hypo_tokens, src_str, alignment, align_dict, dst_dict, remove_bpe):
from fairseq import tokenizer
hypo_str = dst_dict.string(hypo_tokens, remove_bpe)
if align_dict is not None:
hypo_str = replace_unk(hypo_str, src_str, alignment, align_dict, dst_dict.unk_string())
......@@ -270,6 +274,27 @@ def post_process_prediction(hypo_tokens, src_str, alignment, align_dict, dst_dic
return hypo_tokens, hypo_str, alignment
def make_positions(tensor, padding_idx, left_pad):
"""Replace non-padding symbols with their position numbers.
Position numbers begin at padding_idx+1.
Padding symbols are ignored, but it is necessary to specify whether padding
is added on the left side (left_pad=True) or right side (left_pad=False).
"""
max_pos = padding_idx + 1 + tensor.size(1)
if not hasattr(make_positions, 'range_buf'):
make_positions.range_buf = tensor.new()
make_positions.range_buf = make_positions.range_buf.type_as(tensor)
if make_positions.range_buf.numel() < max_pos:
torch.arange(padding_idx + 1, max_pos, out=make_positions.range_buf)
mask = tensor.ne(padding_idx)
positions = make_positions.range_buf[:tensor.size(1)].expand_as(tensor)
if left_pad:
positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1)
return tensor.clone().masked_scatter_(mask, positions[mask])
def strip_pad(tensor, pad):
return tensor[tensor.ne(pad)]
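A small worked example of make_positions; the non-pad input values are arbitrary, and the expected output follows from the unit test added further down:

import torch
from fairseq import utils

# pad index 1; non-pad symbols are numbered from padding_idx + 1 = 2 onwards,
# starting at the first non-pad token when padding is on the left.
x = torch.LongTensor([[1, 1, 9, 9, 9]])
positions = utils.make_positions(x, padding_idx=1, left_pad=True)
# expected: [[1, 1, 2, 3, 4]] -- padded slots keep the pad index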
......@@ -303,6 +328,7 @@ def convert_padding_direction(
index = torch.remainder(range + num_pads, max_len)
return src_tokens.gather(1, index)
def item(tensor):
if hasattr(tensor, 'item'):
return tensor.item()
......
......@@ -15,7 +15,10 @@ from fairseq.sequence_scorer import SequenceScorer
def main(args):
assert args.path is not None, '--path required for generation!'
print(args)
assert not args.sampling or args.nbest == args.beam, \
'--sampling requires --nbest to be equal to --beam'
use_cuda = torch.cuda.is_available() and not args.cpu
......@@ -77,7 +80,7 @@ def main(args):
translator = SequenceGenerator(
models, beam_size=args.beam, stop_early=(not args.no_early_stop),
normalize_scores=(not args.unnormalized), len_penalty=args.lenpen,
unk_penalty=args.unkpen)
unk_penalty=args.unkpen, sampling=args.sampling)
if use_cuda:
translator.cuda()
......
......@@ -16,6 +16,8 @@ from fairseq.sequence_generator import SequenceGenerator
def main(args):
print(args)
assert not args.sampling or args.nbest == args.beam, \
'--sampling requires --nbest to be equal to --beam'
use_cuda = torch.cuda.is_available() and not args.cpu
......
#!/usr/bin/env python3
import argparse
import collections
import torch
def average_checkpoints(inputs):
"""Loads checkpoints from inputs and returns a model with averaged weights.
Args:
inputs: An iterable of string paths of checkpoints to load from.
Returns:
A dict of string keys mapping to various values. The 'model' key
from the returned dict should correspond to an OrderedDict mapping
string parameter names to torch Tensors.
"""
params_dict = collections.OrderedDict()
params_keys = None
new_state = None
for f in inputs:
state = torch.load(
f,
map_location=(
lambda s, _: torch.serialization.default_restore_location(s, 'cpu')
),
)
# Copies over the settings from the first checkpoint
if new_state is None:
new_state = state
model_params = state['model']
model_params_keys = list(model_params.keys())
if params_keys is None:
params_keys = model_params_keys
elif params_keys != model_params_keys:
raise KeyError(
'For checkpoint {}, expected list of params: {}, '
'but found: {}'.format(f, params_keys, model_params_keys)
)
for k in params_keys:
if k not in params_dict:
params_dict[k] = []
params_dict[k].append(model_params[k])
averaged_params = collections.OrderedDict()
# v should be a list of torch Tensors.
for k, v in params_dict.items():
summed_v = None
for x in v:
summed_v = summed_v + x if summed_v is not None else x
averaged_params[k] = summed_v / len(v)
new_state['model'] = averaged_params
return new_state
def main():
parser = argparse.ArgumentParser(
description='Tool to average the params of input checkpoints to '
'produce a new checkpoint',
)
parser.add_argument(
'--inputs',
required=True,
nargs='+',
help='Input checkpoint file paths.',
)
parser.add_argument(
'--output',
required=True,
metavar='FILE',
help='Write the new checkpoint containing the averaged weights to this '
'path.',
)
args = parser.parse_args()
print(args)
new_state = average_checkpoints(args.inputs)
torch.save(new_state, args.output)
print('Finished writing averaged checkpoint to {}.'.format(args.output))
if __name__ == '__main__':
main()
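The script is normally run as `python scripts/average_checkpoints.py --inputs ... --output ...`, but it can also be called from Python. A short sketch with placeholder checkpoint paths:

import torch
from scripts.average_checkpoints import average_checkpoints

# Hypothetical paths: average the parameters of the last three epoch
# checkpoints and save the result as a single combined checkpoint.
new_state = average_checkpoints([
    'checkpoints/checkpoint28.pt',
    'checkpoints/checkpoint29.pt',
    'checkpoints/checkpoint30.pt',
])
torch.save(new_state, 'checkpoints/checkpoint.avg.pt')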
......@@ -71,6 +71,7 @@ def main(args):
# Train until the learning rate gets too small
max_epoch = args.max_epoch or math.inf
max_update = args.max_update or math.inf
lr = trainer.get_lr()
train_meter = StopwatchMeter()
train_meter.start()
......@@ -79,18 +80,24 @@ def main(args):
train(args, trainer, dataset, epoch, batch_offset)
# evaluate on validation set
for k, subset in enumerate(args.valid_subset.split(',')):
val_loss = validate(args, trainer, dataset, subset, epoch)
if k == 0:
# only use first validation loss to update the learning schedule
lr = trainer.lr_step(epoch, val_loss)
# save checkpoint
if not args.no_save:
save_checkpoint(trainer, args, epoch, 0, val_loss)
if epoch % args.validate_interval == 0:
for k, subset in enumerate(args.valid_subset.split(',')):
val_loss = validate(args, trainer, dataset, subset, epoch)
if k == 0:
# only use first validation loss to update the learning schedule
lr = trainer.lr_step(epoch, val_loss)
# save checkpoint
if not args.no_save:
save_checkpoint(trainer, args, epoch, 0, val_loss)
else:
lr = trainer.lr_step(epoch)
epoch += 1
batch_offset = 0
if trainer.get_num_updates() >= max_update:
break
train_meter.stop()
print('| done training in {:.1f} seconds'.format(train_meter.sum))
......@@ -134,6 +141,7 @@ def train(args, trainer, dataset, epoch, batch_offset):
meter.reset()
extra_meters = collections.defaultdict(lambda: AverageMeter())
max_update = args.max_update or math.inf
for i, sample in enumerate(itr, start=batch_offset):
log_output = trainer.train_step(sample)
......@@ -142,7 +150,10 @@ def train(args, trainer, dataset, epoch, batch_offset):
for k, v in log_output.items():
if k in ['loss', 'nll_loss']:
continue # these are already logged above
extra_meters[k].update(v)
if 'loss' in k:
extra_meters[k].update(v, log_output['sample_size'])
else:
extra_meters[k].update(v)
stats[k] = extra_meters[k].avg
progress.log(stats)
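Why the loss meters are now updated with sample_size as a weight: the logged values are per-batch averages, so combining them across batches of different sizes requires a size-weighted mean. A small arithmetic sketch (the AverageMeter class is not part of this diff; the update(val, n) signature is assumed from how it is called above):

# Two batches: average loss 2.0 over a batch of size 100, 4.0 over a batch of size 300.
losses = [(2.0, 100), (4.0, 300)]
unweighted = sum(v for v, _ in losses) / len(losses)                   # 3.0
weighted = sum(v * n for v, n in losses) / sum(n for _, n in losses)   # 3.5, the true overall average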
......@@ -150,9 +161,15 @@ def train(args, trainer, dataset, epoch, batch_offset):
if i == batch_offset:
# ignore the first mini-batch in words-per-second calculation
trainer.get_meter('wps').reset()
if args.save_interval > 0 and trainer.get_num_updates() % args.save_interval == 0:
# save mid-epoch checkpoints
num_updates = trainer.get_num_updates()
if args.save_interval > 0 and num_updates > 0 and num_updates % args.save_interval == 0:
save_checkpoint(trainer, args, epoch, i + 1)
if num_updates >= max_update:
break
# log end-of-epoch stats
stats = get_training_stats(trainer)
for k, meter in extra_meters.items():
......
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import collections
import os
import tempfile
import unittest
import numpy as np
import torch
from scripts.average_checkpoints import average_checkpoints
class TestAverageCheckpoints(unittest.TestCase):
def test_average_checkpoints(self):
params_0 = collections.OrderedDict(
[
('a', torch.DoubleTensor([100.0])),
('b', torch.FloatTensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])),
('c', torch.IntTensor([7, 8, 9])),
]
)
params_1 = collections.OrderedDict(
[
('a', torch.DoubleTensor([1.0])),
('b', torch.FloatTensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])),
('c', torch.IntTensor([2, 2, 2])),
]
)
params_avg = collections.OrderedDict(
[
('a', torch.DoubleTensor([50.5])),
('b', torch.FloatTensor([[1.0, 1.5, 2.0], [2.5, 3.0, 3.5]])),
# We expect truncation for integer division
('c', torch.IntTensor([4, 5, 5])),
]
)
fd_0, path_0 = tempfile.mkstemp()
fd_1, path_1 = tempfile.mkstemp()
torch.save(collections.OrderedDict([('model', params_0)]), path_0)
torch.save(collections.OrderedDict([('model', params_1)]), path_1)
output = average_checkpoints([path_0, path_1])['model']
os.close(fd_0)
os.remove(path_0)
os.close(fd_1)
os.remove(path_1)
for (k_expected, v_expected), (k_out, v_out) in zip(
params_avg.items(), output.items()):
self.assertEqual(
k_expected, k_out, 'Key mismatch - expected {} but found {}. '
'(Expected list of keys: {} vs actual list of keys: {})'.format(
k_expected, k_out, params_avg.keys(), output.keys()
)
)
np.testing.assert_allclose(
v_expected.numpy(),
v_out.numpy(),
err_msg='Tensor value mismatch for key {}'.format(k_expected)
)
if __name__ == '__main__':
unittest.main()
......@@ -43,7 +43,7 @@ class TestBinaries(unittest.TestCase):
data = 97 + torch.floor(26 * data).int()
with open(os.path.join(data_dir, filename), 'w') as h:
offset = 0
for i in range(num_examples):
for _ in range(num_examples):
ex_len = random.randint(1, maxlen)
ex_str = ' '.join(map(chr, data[offset:offset+ex_len]))
print(ex_str, file=h)
......
......@@ -4,7 +4,6 @@
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import torch
import unittest
......
......@@ -10,9 +10,7 @@ import copy
import unittest
import torch
from torch.autograd import Variable
from fairseq import utils
from fairseq.criterions.cross_entropy import CrossEntropyCriterion
from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion
......@@ -29,7 +27,7 @@ class TestLabelSmoothing(unittest.TestCase):
self.assertEqual(self.d.pad(), 1)
self.assertEqual(self.d.eos(), 2)
self.assertEqual(self.d.unk(), 3)
pad, eos, unk, w1, w2, w3 = 1, 2, 3, 4, 5, 6
pad, eos, unk, w1, w2, w3 = 1, 2, 3, 4, 5, 6 # noqa: F841
# build dataset
self.data = [
......
......@@ -4,7 +4,6 @@
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import argparse
import unittest
......
......@@ -4,7 +4,6 @@
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import argparse
import unittest
......@@ -86,7 +85,7 @@ class TestSequenceScorer(unittest.TestCase):
model = test_utils.TestModel.build_model(args, d, d)
scorer = SequenceScorer([model])
for id, src, ref, hypos in scorer.score_batched_itr(data_itr):
for id, _src, _ref, hypos in scorer.score_batched_itr(data_itr):
self.assertHypoTokens(hypos[0], data[id]['target'])
self.assertHypoScore(hypos[0], expected_scores[id])
......
......@@ -4,7 +4,6 @@
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import unittest
......@@ -49,6 +48,38 @@ class TestUtils(unittest.TestCase):
),
)
def test_make_positions(self):
pad = 1
left_pad_input = torch.LongTensor([
[9, 9, 9, 9, 9],
[1, 9, 9, 9, 9],
[1, 1, 1, 9, 9],
])
left_pad_output = torch.LongTensor([
[2, 3, 4, 5, 6],
[1, 2, 3, 4, 5],
[1, 1, 1, 2, 3],
])
right_pad_input = torch.LongTensor([
[9, 9, 9, 9, 9],
[9, 9, 9, 9, 1],
[9, 9, 1, 1, 1],
])
right_pad_output = torch.LongTensor([
[2, 3, 4, 5, 6],
[2, 3, 4, 5, 1],
[2, 3, 1, 1, 1],
])
self.assertAlmostEqual(
left_pad_output,
utils.make_positions(left_pad_input, pad, left_pad=True),
)
self.assertAlmostEqual(
right_pad_output,
utils.make_positions(right_pad_input, pad, left_pad=False),
)
def test_make_variable(self):
t = [{'k': torch.rand(5, 5)}]
......
......@@ -4,7 +4,6 @@
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import torch
from torch.autograd import Variable
......@@ -137,10 +136,11 @@ class TestIncrementalDecoder(FairseqIncrementalDecoder):
def get_normalized_probs(self, net_output, log_probs):
# the decoder returns probabilities directly
probs = net_output[0]
if log_probs:
return net_output.log()
return probs.log()
else:
return net_output
return probs
def max_positions(self):
return self.args.max_decoder_positions