Unverified commit e5b3c1f4 authored by Myle Ott, committed by GitHub

Merge pull request #54: Version 0.1.0 -> 0.2.0

Release notes:
- 5c7f4954: Added simple LSTM model with input feeding and attention
- 6e4b7e22: Refactored model definitions and incremental generation to be cleaner
- 7ae79c12: Split interactive generation out of generate.py and into a new binary: interactive.py
- 19a3865d: Subtle correctness fix in the beam search decoder. Previously, for a beam size of k, we might emit a hypothesis
           if the <eos> was among the top 2*k candidates. Now we only emit hypotheses for which the <eos> is among the
           top-k candidates. This may subtly change generation results, and in the case of k=1 we now produce
           strictly greedy outputs.
- 97d7fcb9: Fixed a bug in the padding direction: previously we right-padded the source and left-padded the target. We
           now left-pad the source and right-pad the target. This should not affect existing trained models, but may
           change (usually improving) the quality of new models. A minimal padding sketch follows these release notes.
- f442f896: Add support for batching based on the number of sentences (`--max-sentences`) in addition to the number of
           tokens (`--max-tokens`). When batching by the number of sentences, one can optionally normalize the gradients
           by the number of sentences with `--sentence-avg` (the default is to normalize by the number of tokens).
- c6d6256b: Add `--log-format` option and JSON logger
parents ba5d7dcd 13a3c811
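For reference, here is a minimal sketch of the padding convention described in 97d7fcb9 (not the actual fairseq collate code): sources are left-padded and targets are right-padded before batching. The `collate_tokens` helper name below is hypothetical.

```python
import torch

def collate_tokens(sequences, pad_idx, left_pad):
    """Pad a list of 1D LongTensors to a common length (illustrative sketch)."""
    max_len = max(seq.size(0) for seq in sequences)
    batch = sequences[0].new(len(sequences), max_len).fill_(pad_idx)
    for i, seq in enumerate(sequences):
        if left_pad:
            batch[i, max_len - seq.size(0):] = seq   # pad on the left (source)
        else:
            batch[i, :seq.size(0)] = seq             # pad on the right (target)
    return batch

# src_batch = collate_tokens(src_seqs, pad_idx, left_pad=True)   # left-pad source
# tgt_batch = collate_tokens(tgt_seqs, pad_idx, left_pad=False)  # right-pad target
```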
......@@ -18,6 +18,8 @@ def get_parser(desc):
parser.add_argument('--no-progress-bar', action='store_true', help='disable progress bar')
parser.add_argument('--log-interval', type=int, default=1000, metavar='N',
help='log progress every N updates (when progress bar is disabled)')
parser.add_argument('--log-format', default='tqdm', help='log format to use',
choices=['json', 'none', 'simple', 'tqdm'])
parser.add_argument('--seed', default=1, type=int, metavar='N',
help='pseudo random number generator seed')
return parser
......@@ -33,8 +35,10 @@ def add_dataset_args(parser):
help='target language')
group.add_argument('-j', '--workers', default=1, type=int, metavar='N',
help='number of data loading workers (default: 1)')
group.add_argument('--max-positions', default=1024, type=int, metavar='N',
help='max number of tokens in the sequence')
group.add_argument('--max-source-positions', default=1024, type=int, metavar='N',
help='max number of tokens in the source sequence')
group.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
help='max number of tokens in the target sequence')
group.add_argument('--skip-invalid-size-inputs-valid-test', action='store_true',
help='Ignore too long or too short lines in valid and test set')
return group
......@@ -65,8 +69,13 @@ def add_optimization_args(parser):
help='weight decay')
group.add_argument('--sample-without-replacement', default=0, type=int, metavar='N',
help='If bigger than 0, use that number of mini-batches for each epoch,'
' where each sample is drawn randomly with replacement from the'
' where each sample is drawn randomly without replacement from the'
' dataset')
group.add_argument('--curriculum', default=0, type=int, metavar='N',
help='sort batches by source length for first N epochs')
group.add_argument('--sentence-avg', action='store_true',
help='normalize gradients by the number of sentences in a batch'
' (default is to normalize by number of tokens)')
return group
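As a rough illustration (not the trainer code itself), `--sentence-avg` only changes the denominator used to normalize the summed per-token loss, mirroring the meter update that appears in train.py further down.

```python
# Hypothetical helper; `loss` is assumed to be the summed loss over the batch.
def normalized_loss(loss, ntokens, nsentences, sentence_avg):
    denom = nsentences if sentence_avg else ntokens  # --sentence-avg vs. default
    return loss / denom
```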
......@@ -110,8 +119,10 @@ def add_generation_args(parser):
help='don\'t use BeamableMM in attention layers')
group.add_argument('--lenpen', default=1, type=float,
help='length penalty: <1.0 favors shorter, >1.0 favors longer sentences')
group.add_argument('--unk-replace-dict', default='', type=str,
help='performs unk word replacement')
group.add_argument('--unkpen', default=0, type=float,
help='unknown word penalty: <0 produces more unks, >0 produces fewer')
group.add_argument('--replace-unk', nargs='?', const=True, default=None,
help='perform unknown replacement (optionally with alignment dictionary)')
group.add_argument('--quiet', action='store_true',
help='Only print final scores')
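A hedged sketch of how `--unkpen` takes effect (the corresponding change appears below in sequence_generator.py): the penalty is subtracted from the `<unk>` score before candidates are chosen, so positive values produce fewer unknown words.

```python
import math
import torch

# Illustrative values only; `probs` stands in for per-hypothesis vocabulary scores.
probs = torch.randn(2, 10)           # (beam, vocab)
pad_idx, unk_idx, unk_penalty = 0, 3, 0.5
probs[:, pad_idx] = -math.inf        # never select pad
probs[:, unk_idx] -= unk_penalty     # apply unk penalty (>0 -> fewer unks)
```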
......@@ -147,6 +158,16 @@ def add_model_args(parser):
group.add_argument('--decoder-attention', type=str, metavar='EXPR',
help='decoder attention [True, ...]')
# Granular dropout settings for models that support them (e.g., LSTM):
group.add_argument('--encoder-dropout-in', type=float, metavar='D',
help='dropout probability for encoder input embedding')
group.add_argument('--encoder-dropout-out', type=float, metavar='D',
help='dropout probability for encoder output')
group.add_argument('--decoder-dropout-in', type=float, metavar='D',
help='dropout probability for decoder input embedding')
group.add_argument('--decoder-dropout-out', type=float, metavar='D',
help='dropout probability for decoder output')
# These arguments have default values independent of the model:
group.add_argument('--dropout', default=0.1, type=float, metavar='D',
help='dropout probability')
......
......@@ -7,35 +7,29 @@
#
"""
Progress bar wrapper around tqdm which handles non-TTY outputs.
Wrapper around various loggers and progress bars (e.g., tqdm).
"""
from collections import OrderedDict
import json
from numbers import Number
import sys
from tqdm import tqdm
from fairseq.meters import AverageMeter
class progress_bar(tqdm):
enabled = sys.stderr.isatty()
print_interval = 1000
def __new__(cls, *args, **kwargs):
if cls.enabled:
return tqdm(*args, **kwargs)
else:
return simple_progress_bar(cls.print_interval, *args, **kwargs)
class simple_progress_bar(object):
"""A minimal replacement for tqdm in non-TTY environments."""
def __init__(self, print_interval, iterable, desc=None, *_args, **_kwargs):
super().__init__()
self.print_interval = print_interval
class progress_bar(object):
"""Abstract class for progress bars."""
def __init__(self, iterable, epoch=None, prefix=None):
self.iterable = iterable
self.desc = desc
self.epoch = epoch
self.prefix = ''
if epoch is not None:
self.prefix += f'| epoch {epoch:03d}'
if prefix is not None:
self.prefix += f' | {prefix}'
def __enter__(self):
return self
......@@ -44,36 +38,149 @@ class simple_progress_bar(object):
return False
def __iter__(self):
size = len(self.iterable)
for i, obj in enumerate(self.iterable):
yield obj
if i > 0 and i % self.print_interval == 0:
desc = '' if self.desc is None else '{}: '.format(self.desc)
msg = '{}{:5d} / {:d} {}\n'.format(desc, i, size, self.postfix)
sys.stdout.write(msg)
sys.stdout.flush()
raise NotImplementedError
def log(self, stats):
"""Log intermediate stats according to log_interval."""
raise NotImplementedError
def print(self, stats):
"""Print end-of-epoch stats."""
raise NotImplementedError
def set_postfix(self, ordered_dict=None, refresh=True, **kwargs):
# Sort in alphabetical order to be more deterministic
postfix = OrderedDict([] if ordered_dict is None else ordered_dict)
for key in sorted(kwargs.keys()):
postfix[key] = kwargs[key]
def _str_commas(self, stats):
return ', '.join(key + '=' + stats[key].strip()
for key in stats.keys())
def _str_pipes(self, stats):
return ' | '.join(key + ' ' + stats[key].strip()
for key in stats.keys())
def _format_stats(self, stats):
postfix = OrderedDict(stats)
# Preprocess stats according to datatype
for key in postfix.keys():
# Number: limit the length of the string
if isinstance(postfix[key], Number):
postfix[key] = '{0:2.3g}'.format(postfix[key])
postfix[key] = '{:g}'.format(postfix[key])
# Meter: display both current and average value
elif isinstance(postfix[key], AverageMeter):
postfix[key] = '{:.2f} ({:.2f})'.format(
postfix[key].val, postfix[key].avg)
# Else for any other type, try to get the string conversion
elif not isinstance(postfix[key], str):
postfix[key] = str(postfix[key])
# Else if it's a string, don't need to preprocess anything
# Stitch together to get the final postfix
self.postfix = ', '.join(key + '=' + postfix[key].strip()
for key in postfix.keys())
@classmethod
def write(cls, s, file=None, end="\n"):
fp = file if file is not None else sys.stdout
fp.write(s)
fp.write(end)
fp.flush()
return postfix
class json_progress_bar(progress_bar):
"""Log output in JSON format."""
def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000):
super().__init__(iterable, epoch, prefix)
self.log_interval = log_interval
self.stats = None
def __iter__(self):
size = float(len(self.iterable))
for i, obj in enumerate(self.iterable):
yield obj
if self.stats is not None and i > 0 and \
self.log_interval is not None and i % self.log_interval == 0:
update = self.epoch + float(i / size) if self.epoch is not None else None
stats = self._format_stats(self.stats, epoch=self.epoch, update=update)
print("sweep_log: " + json.dumps(stats))
def log(self, stats):
"""Log intermediate stats according to log_interval."""
self.stats = stats
def print(self, stats):
"""Print end-of-epoch stats."""
stats = self._format_stats(self.stats, epoch=self.epoch)
print("sweep_log: " + json.dumps(stats))
def _format_stats(self, stats, epoch=None, update=None):
postfix = OrderedDict()
if epoch is not None:
postfix['epoch'] = epoch
if update is not None:
postfix['update'] = update
# Preprocess stats according to datatype
for key in stats.keys():
# Meter: display both current and average value
if isinstance(stats[key], AverageMeter):
postfix[key] = stats[key].val
postfix[key + '_avg'] = stats[key].avg
else:
postfix[key] = stats[key]
return postfix
class noop_progress_bar(progress_bar):
"""No logging."""
def __init__(self, iterable, epoch=None, prefix=None):
super().__init__(iterable, epoch, prefix)
def __iter__(self):
for obj in self.iterable:
yield obj
def log(self, stats):
"""Log intermediate stats according to log_interval."""
pass
def print(self, stats):
"""Print end-of-epoch stats."""
pass
class simple_progress_bar(progress_bar):
"""A minimal logger for non-TTY environments."""
def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000):
super().__init__(iterable, epoch, prefix)
self.log_interval = log_interval
self.stats = None
def __iter__(self):
size = len(self.iterable)
for i, obj in enumerate(self.iterable):
yield obj
if self.stats is not None and i > 0 and \
self.log_interval is not None and i % self.log_interval == 0:
postfix = self._str_commas(self.stats)
print(f'{self.prefix}: {i:5d} / {size:d} {postfix}')
sys.stdout.flush()
def log(self, stats):
"""Log intermediate stats according to log_interval."""
self.stats = self._format_stats(stats)
def print(self, stats):
"""Print end-of-epoch stats."""
postfix = self._str_pipes(self._format_stats(stats))
print(f'{self.prefix} | {postfix}')
sys.stdout.flush()
class tqdm_progress_bar(progress_bar):
"""Log to tqdm."""
def __init__(self, iterable, epoch=None, prefix=None):
super().__init__(iterable, epoch, prefix)
self.tqdm = tqdm(iterable, self.prefix, leave=False)
def __iter__(self):
return iter(self.tqdm)
def log(self, stats):
"""Log intermediate stats according to log_interval."""
self.tqdm.set_postfix(self._format_stats(stats), refresh=False)
def print(self, stats):
"""Print end-of-epoch stats."""
postfix = self._str_pipes(self._format_stats(stats))
self.tqdm.write(f'{self.tqdm.desc} | {postfix}')
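A small, hypothetical usage sketch of the new logger interface (assuming fairseq is importable): each variant exposes `log()` for intermediate stats and `print()` for end-of-epoch stats, so callers no longer depend on tqdm directly.

```python
from collections import OrderedDict
from fairseq import progress_bar

epoch_itr = range(5000)  # stand-in for an iterator over batches
bar = progress_bar.json_progress_bar(epoch_itr, epoch=1, log_interval=1000)
# bar = progress_bar.simple_progress_bar(epoch_itr, epoch=1)  # non-TTY alternative
stats = OrderedDict(loss=2.5, wps=12000)
for batch in bar:
    bar.log(stats)   # buffered; emitted every `log_interval` steps
bar.print(stats)     # end-of-epoch summary
```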
......@@ -13,11 +13,13 @@ import torch.nn.functional as F
from torch.autograd import Variable
from fairseq import utils
from fairseq.models import FairseqIncrementalDecoder
class SequenceGenerator(object):
def __init__(self, models, beam_size=1, minlen=1, maxlen=200,
stop_early=True, normalize_scores=True, len_penalty=1):
stop_early=True, normalize_scores=True, len_penalty=1,
unk_penalty=0):
"""Generates translations of a given source sentence.
Args:
......@@ -30,26 +32,26 @@ class SequenceGenerator(object):
"""
self.models = models
self.pad = models[0].dst_dict.pad()
self.unk = models[0].dst_dict.unk()
self.eos = models[0].dst_dict.eos()
assert all(m.dst_dict.pad() == self.pad for m in self.models[1:])
assert all(m.dst_dict.unk() == self.unk for m in self.models[1:])
assert all(m.dst_dict.eos() == self.eos for m in self.models[1:])
self.vocab_size = len(models[0].dst_dict)
self.beam_size = beam_size
self.minlen = minlen
self.maxlen = min(maxlen, *[m.decoder.max_positions() - self.pad - 2 for m in self.models])
self.positions = torch.LongTensor(range(self.pad + 1, self.pad + self.maxlen + 2))
self.decoder_context = models[0].decoder.context_size()
self.maxlen = min(maxlen, *[m.max_decoder_positions() for m in self.models])
self.stop_early = stop_early
self.normalize_scores = normalize_scores
self.len_penalty = len_penalty
self.unk_penalty = unk_penalty
def cuda(self):
for model in self.models:
model.cuda()
self.positions = self.positions.cuda()
return self
def generate_batched_itr(self, data_itr, maxlen_a=0.0, maxlen_b=200,
def generate_batched_itr(self, data_itr, beam_size=None, maxlen_a=0.0, maxlen_b=None,
cuda_device=None, timer=None):
"""Iterate over a batched dataset and yield individual translations.
......@@ -59,9 +61,8 @@ class SequenceGenerator(object):
cuda_device: GPU on which to do generation.
timer: StopwatchMeter for timing generations.
"""
def lstrip_pad(tensor):
return tensor[tensor.eq(self.pad).sum():]
if maxlen_b is None:
maxlen_b = self.maxlen
for sample in data_itr:
s = utils.prepare_sample(sample, volatile=True, cuda_device=cuda_device)
......@@ -69,25 +70,26 @@ class SequenceGenerator(object):
srclen = input['src_tokens'].size(1)
if timer is not None:
timer.start()
hypos = self.generate(input['src_tokens'], input['src_positions'],
hypos = self.generate(input['src_tokens'], beam_size=beam_size,
maxlen=int(maxlen_a*srclen + maxlen_b))
if timer is not None:
timer.stop(s['ntokens'])
for i, id in enumerate(s['id']):
src = input['src_tokens'].data[i, :]
# remove padding from ref, which appears at the beginning
ref = lstrip_pad(s['target'].data[i, :])
# remove padding from ref
ref = utils.rstrip_pad(s['target'].data[i, :], self.pad)
yield id, src, ref, hypos[i]
def generate(self, src_tokens, src_positions, beam_size=None, maxlen=None):
def generate(self, src_tokens, beam_size=None, maxlen=None):
"""Generate a batch of translations."""
with ExitStack() as stack:
for model in self.models:
stack.enter_context(model.decoder.incremental_inference())
return self._generate(src_tokens, src_positions, beam_size, maxlen)
if isinstance(model.decoder, FairseqIncrementalDecoder):
stack.enter_context(model.decoder.incremental_inference())
return self._generate(src_tokens, beam_size, maxlen)
def _generate(self, src_tokens, src_positions, beam_size=None, maxlen=None):
bsz = src_tokens.size(0)
def _generate(self, src_tokens, beam_size=None, maxlen=None):
bsz, srclen = src_tokens.size()
maxlen = min(maxlen, self.maxlen) if maxlen is not None else self.maxlen
# the max beam size is the dictionary size - 1, since we never select pad
......@@ -97,11 +99,11 @@ class SequenceGenerator(object):
encoder_outs = []
for model in self.models:
model.eval()
model.decoder.start_fresh_sequence(beam_size) # start a fresh sequence
if isinstance(model.decoder, FairseqIncrementalDecoder):
model.decoder.set_beam_size(beam_size)
# compute the encoder output and expand to beam size
encoder_out = model.encoder(src_tokens, src_positions)
encoder_out = self._expand_encoder_out(encoder_out, beam_size)
# compute the encoder output for each beam
encoder_out = model.encoder(src_tokens.repeat(1, beam_size).view(-1, srclen))
encoder_outs.append(encoder_out)
# initialize buffers
......@@ -215,7 +217,8 @@ class SequenceGenerator(object):
# reorder decoder internal states based on the prev choice of beams
if reorder_state is not None:
for model in self.models:
model.decoder.reorder_incremental_state(reorder_state)
if isinstance(model.decoder, FairseqIncrementalDecoder):
model.decoder.reorder_incremental_state(reorder_state)
probs, avg_attn_scores = self._decode(tokens[:, :step+1], encoder_outs)
if step == 0:
......@@ -226,6 +229,7 @@ class SequenceGenerator(object):
# make probs contain cumulative scores for each hypothesis
probs.add_(scores.view(-1, 1))
probs[:, self.pad] = -math.inf # never select pad
probs[:, self.unk] -= self.unk_penalty # apply unk penalty
# Record attention scores
attn[:, :, step+1].copy_(avg_attn_scores)
......@@ -250,10 +254,11 @@ class SequenceGenerator(object):
eos_mask = cand_indices.eq(self.eos)
if step >= self.minlen:
eos_bbsz_idx = buffer('eos_bbsz_idx')
cand_bbsz_idx.masked_select(eos_mask, out=eos_bbsz_idx)
# only consider eos when it's among the top beam_size indices
cand_bbsz_idx[:, :beam_size].masked_select(eos_mask[:, :beam_size], out=eos_bbsz_idx)
if eos_bbsz_idx.numel() > 0:
eos_scores = buffer('eos_scores', type_of=scores)
cand_scores.masked_select(eos_mask, out=eos_scores)
cand_scores[:, :beam_size].masked_select(eos_mask[:, :beam_size], out=eos_scores)
num_remaining_sent -= finalize_hypos(step, eos_bbsz_idx, eos_scores)
assert num_remaining_sent >= 0
......@@ -314,19 +319,13 @@ class SequenceGenerator(object):
return finalized
def _decode(self, tokens, encoder_outs):
length = tokens.size(1)
# repeat the first length positions to fill batch
positions = self.positions[:length].view(1, length)
# wrap in Variables
# wrap in Variable
tokens = Variable(tokens, volatile=True)
positions = Variable(positions, volatile=True)
avg_probs = None
avg_attn = None
for model, encoder_out in zip(self.models, encoder_outs):
decoder_out, attn = model.decoder(tokens, positions, encoder_out)
decoder_out, attn = model.decoder(tokens, encoder_out)
probs = F.softmax(decoder_out[:, -1, :]).data
attn = attn[:, -1, :].data
if avg_probs is None or avg_attn is None:
......@@ -340,14 +339,3 @@ class SequenceGenerator(object):
avg_attn.div_(len(self.models))
return avg_probs, avg_attn
def _expand_encoder_out(self, encoder_out, beam_size):
res = []
for tensor in encoder_out:
res.append(
# repeat beam_size times along second dimension
tensor.repeat(1, beam_size, *[1 for i in range(tensor.dim()-2)]) \
# then collapse into [bsz*beam, ...original dims...]
.view(-1, *tensor.size()[1:])
)
return tuple(res)
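The removed `_expand_encoder_out` helper is replaced by repeating the source tokens before the encoder runs. A tiny standalone sketch (plain PyTorch, nothing fairseq-specific) of the `repeat(...).view(...)` pattern, showing that each source row ends up duplicated `beam_size` times in a contiguous block:

```python
import torch

src_tokens = torch.LongTensor([[11, 12, 13],
                               [21, 22, 23]])   # (bsz=2, srclen=3)
beam_size = 2
bsz, srclen = src_tokens.size()

expanded = src_tokens.repeat(1, beam_size).view(-1, srclen)
# expanded has shape (bsz * beam_size, srclen):
# [[11, 12, 13],
#  [11, 12, 13],
#  [21, 22, 23],
#  [21, 22, 23]]
```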
......@@ -14,7 +14,7 @@ import traceback
from torch.autograd import Variable
from torch.serialization import default_restore_location
from fairseq import criterions, models
from fairseq import criterions, data, models, progress_bar, tokenizer
def parse_args_and_arch(parser):
......@@ -30,11 +30,22 @@ def build_model(args, src_dict, dst_dict):
def build_criterion(args, src_dict, dst_dict):
padding_idx = dst_dict.pad()
if args.label_smoothing > 0:
return criterions.LabelSmoothedCrossEntropyCriterion(args.label_smoothing, padding_idx)
return criterions.LabelSmoothedCrossEntropyCriterion(args, dst_dict)
else:
return criterions.CrossEntropyCriterion(padding_idx)
return criterions.CrossEntropyCriterion(args, dst_dict)
def build_progress_bar(args, iterator, epoch=None, prefix=None):
if args.log_format == 'json':
bar = progress_bar.json_progress_bar(iterator, epoch, prefix, args.log_interval)
elif args.log_format == 'none':
bar = progress_bar.noop_progress_bar(iterator, epoch, prefix)
elif args.log_format == 'tqdm':
bar = progress_bar.tqdm_progress_bar(iterator, epoch, prefix)
else:
bar = progress_bar.simple_progress_bar(iterator, epoch, prefix, args.log_interval)
return bar
def torch_persistent_save(*args, **kwargs):
......@@ -122,7 +133,12 @@ def _upgrade_state_dict(state):
return state
def load_ensemble_for_inference(filenames, src_dict, dst_dict):
def load_ensemble_for_inference(filenames, src_dict=None, dst_dict=None, data_dir=None):
"""Load an ensemble of models for inference.
The source and target dictionaries can be given explicitly, or loaded from
the `data_dir` directory.
"""
# load model architectures and weights
states = []
for filename in filenames:
......@@ -132,6 +148,11 @@ def load_ensemble_for_inference(filenames, src_dict, dst_dict):
torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))
)
args = states[0]['args']
args = _upgrade_args(args)
if src_dict is None or dst_dict is None:
assert data_dir is not None
src_dict, dst_dict = data.load_dictionaries(data_dir, args.source_lang, args.target_lang)
# build ensemble
ensemble = []
......@@ -139,7 +160,14 @@ def load_ensemble_for_inference(filenames, src_dict, dst_dict):
model = build_model(args, src_dict, dst_dict)
model.load_state_dict(state['model'])
ensemble.append(model)
return ensemble
return ensemble, args
def _upgrade_args(args):
if not hasattr(args, 'max_source_positions'):
args.max_source_positions = args.max_positions
args.max_target_positions = args.max_positions
return args
def prepare_sample(sample, volatile=False, cuda_device=None):
......@@ -156,6 +184,58 @@ def prepare_sample(sample, volatile=False, cuda_device=None):
'target': make_variable(sample['target']),
'net_input': {
key: make_variable(sample[key])
for key in ['src_tokens', 'src_positions', 'input_tokens', 'input_positions']
for key in ['src_tokens', 'input_tokens']
},
}
def load_align_dict(replace_unk):
if replace_unk is None:
align_dict = None
elif isinstance(replace_unk, str):
# Load alignment dictionary for unknown word replacement if it was passed as an argument.
align_dict = {}
with open(replace_unk, 'r') as f:
for line in f:
l = line.split()
align_dict[l[0]] = l[1]
else:
# No alignment dictionary provided but we still want to perform unknown word replacement by copying the
# original source word.
align_dict = {}
return align_dict
def replace_unk(hypo_str, src_str, alignment, align_dict, unk):
# Tokens are strings here
hypo_tokens = tokenizer.tokenize_line(hypo_str)
# TODO: Very rare cases where the replacement is '<eos>' should be handled gracefully
src_tokens = tokenizer.tokenize_line(src_str) + ['<eos>']
for i, ht in enumerate(hypo_tokens):
if ht == unk:
src_token = src_tokens[alignment[i]]
# Either take the corresponding value in the aligned dictionary or just copy the original value.
hypo_tokens[i] = align_dict.get(src_token, src_token)
return ' '.join(hypo_tokens)
def post_process_prediction(hypo_tokens, src_str, alignment, align_dict, dst_dict, remove_bpe):
hypo_str = dst_dict.string(hypo_tokens, remove_bpe)
if align_dict is not None:
hypo_str = replace_unk(hypo_str, src_str, alignment, align_dict, dst_dict.unk_string())
if align_dict is not None or remove_bpe is not None:
# Convert back to tokens for evaluating with unk replacement or without BPE
# Note that the dictionary can be modified inside the method.
hypo_tokens = tokenizer.Tokenizer.tokenize(hypo_str, dst_dict, add_if_not_exist=True)
return hypo_tokens, hypo_str, alignment
def lstrip_pad(tensor, pad):
return tensor[tensor.eq(pad).sum():]
def rstrip_pad(tensor, pad):
strip = tensor.eq(pad).sum()
if strip > 0:
return tensor[:-strip]
return tensor
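A quick, hedged illustration of the new pad-stripping helpers added to utils.py; both assume that padding occurs only at the stripped end of the tensor.

```python
import torch
from fairseq import utils

pad = 1
left_padded = torch.LongTensor([1, 1, 5, 6, 7])    # source-style (left-padded)
right_padded = torch.LongTensor([5, 6, 7, 1, 1])   # target-style (right-padded)

print(utils.lstrip_pad(left_padded, pad))    # keeps [5, 6, 7]
print(utils.rstrip_pad(right_padded, pad))   # keeps [5, 6, 7]
```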
......@@ -7,13 +7,10 @@
# can be found in the PATENTS file in the same directory.
#
import sys
import torch
from torch.autograd import Variable
from fairseq import bleu, data, options, tokenizer, utils
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.progress_bar import progress_bar
from fairseq.sequence_generator import SequenceGenerator
......@@ -22,8 +19,6 @@ def main():
parser.add_argument('--path', metavar='FILE', required=True, action='append',
help='path(s) to model file(s)')
dataset_args = options.add_dataset_args(parser)
dataset_args.add_argument('-i', '--interactive', action='store_true',
help='generate translations in interactive mode')
dataset_args.add_argument('--batch-size', default=32, type=int, metavar='N',
help='batch size')
dataset_args.add_argument('--gen-subset', default='test', metavar='SPLIT',
......@@ -31,130 +26,104 @@ def main():
options.add_generation_args(parser)
args = parser.parse_args()
if args.no_progress_bar:
args.log_format = 'none'
print(args)
if args.no_progress_bar:
progress_bar.enabled = False
use_cuda = torch.cuda.is_available() and not args.cpu
# Load dataset
dataset = data.load_with_check(args.data, [args.gen_subset], args.source_lang, args.target_lang)
if args.replace_unk is None:
dataset = data.load_dataset(args.data, [args.gen_subset], args.source_lang, args.target_lang)
else:
dataset = data.load_raw_text_dataset(args.data, [args.gen_subset], args.source_lang, args.target_lang)
if args.source_lang is None or args.target_lang is None:
# record inferred languages in args
args.source_lang, args.target_lang = dataset.src, dataset.dst
# Load ensemble
print('| loading model(s) from {}'.format(', '.join(args.path)))
models = utils.load_ensemble_for_inference(args.path, dataset.src_dict, dataset.dst_dict)
models, _ = utils.load_ensemble_for_inference(args.path, dataset.src_dict, dataset.dst_dict)
print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
if not args.interactive:
print('| {} {} {} examples'.format(args.data, args.gen_subset, len(dataset.splits[args.gen_subset])))
# Max positions is the model property but it is needed in data reader to be able to
# ignore too long sentences
args.max_positions = min(args.max_positions, *(m.decoder.max_positions() for m in models))
print('| {} {} {} examples'.format(args.data, args.gen_subset, len(dataset.splits[args.gen_subset])))
# Optimize ensemble for generation
for model in models:
model.make_generation_fast_(not args.no_beamable_mm)
model.make_generation_fast_(
beamable_mm_beam_size=None if args.no_beamable_mm else args.beam)
# Initialize generator
translator = SequenceGenerator(
models, beam_size=args.beam, stop_early=(not args.no_early_stop),
normalize_scores=(not args.unnormalized), len_penalty=args.lenpen
)
normalize_scores=(not args.unnormalized), len_penalty=args.lenpen,
unk_penalty=args.unkpen)
if use_cuda:
translator.cuda()
# Load alignment dictionary for unknown word replacement
align_dict = {}
if args.unk_replace_dict != '':
assert args.interactive, \
'Unknown word replacement requires access to original source and is only supported in interactive mode'
with open(args.unk_replace_dict, 'r') as f:
for line in f:
l = line.split()
align_dict[l[0]] = l[1]
def replace_unk(hypo_str, align_str, src, unk):
hypo_tokens = hypo_str.split()
src_tokens = tokenizer.tokenize_line(src)
align_idx = [int(i) for i in align_str.split()]
for i, ht in enumerate(hypo_tokens):
if ht == unk:
src_token = src_tokens[align_idx[i]]
if src_token in align_dict:
hypo_tokens[i] = align_dict[src_token]
else:
hypo_tokens[i] = src_token
return ' '.join(hypo_tokens)
def display_hypotheses(id, src, orig, ref, hypos):
if args.quiet:
return
id_str = '' if id is None else '-{}'.format(id)
src_str = dataset.src_dict.string(src, args.remove_bpe)
print('S{}\t{}'.format(id_str, src_str))
if orig is not None:
print('O{}\t{}'.format(id_str, orig.strip()))
if ref is not None:
print('T{}\t{}'.format(id_str, dataset.dst_dict.string(ref, args.remove_bpe, escape_unk=True)))
for hypo in hypos:
hypo_str = dataset.dst_dict.string(hypo['tokens'], args.remove_bpe)
align_str = ' '.join(map(str, hypo['alignment']))
if args.unk_replace_dict != '':
hypo_str = replace_unk(hypo_str, align_str, orig, dataset.dst_dict.unk_string())
print('H{}\t{}\t{}'.format(id_str, hypo['score'], hypo_str))
print('A{}\t{}'.format(id_str, align_str))
if args.interactive:
for line in sys.stdin:
tokens = tokenizer.Tokenizer.tokenize(line, dataset.src_dict, add_if_not_exist=False).long()
start = dataset.src_dict.pad() + 1
positions = torch.arange(start, start + len(tokens)).type_as(tokens)
if use_cuda:
positions = positions.cuda()
tokens = tokens.cuda()
translations = translator.generate(Variable(tokens.view(1, -1)), Variable(positions.view(1, -1)))
hypos = translations[0]
display_hypotheses(None, tokens, line, None, hypos[:min(len(hypos), args.nbest)])
else:
def maybe_remove_bpe(tokens, escape_unk=False):
"""Helper for removing BPE symbols from a hypothesis."""
if args.remove_bpe is None:
return tokens
assert (tokens == dataset.dst_dict.pad()).sum() == 0
hypo_minus_bpe = dataset.dst_dict.string(tokens, args.remove_bpe, escape_unk)
return tokenizer.Tokenizer.tokenize(hypo_minus_bpe, dataset.dst_dict, add_if_not_exist=True)
# Generate and compute BLEU score
scorer = bleu.Scorer(dataset.dst_dict.pad(), dataset.dst_dict.eos(), dataset.dst_dict.unk())
itr = dataset.dataloader(args.gen_subset, batch_size=args.batch_size,
max_positions=args.max_positions,
skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
num_sentences = 0
with progress_bar(itr, smoothing=0, leave=False) as t:
wps_meter = TimeMeter()
gen_timer = StopwatchMeter()
translations = translator.generate_batched_itr(
t, maxlen_a=args.max_len_a, maxlen_b=args.max_len_b,
cuda_device=0 if use_cuda else None, timer=gen_timer)
for id, src, ref, hypos in translations:
ref = ref.int().cpu()
top_hypo = hypos[0]['tokens'].int().cpu()
scorer.add(maybe_remove_bpe(ref, escape_unk=True), maybe_remove_bpe(top_hypo))
display_hypotheses(id, src, None, ref, hypos[:min(len(hypos), args.nbest)])
wps_meter.update(src.size(0))
t.set_postfix(wps='{:5d}'.format(round(wps_meter.avg)), refresh=False)
num_sentences += 1
print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} tokens/s)'.format(
num_sentences, gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))
# (None if no unknown word replacement, empty if no path to align dictionary)
align_dict = utils.load_align_dict(args.replace_unk)
# Generate and compute BLEU score
scorer = bleu.Scorer(dataset.dst_dict.pad(), dataset.dst_dict.eos(), dataset.dst_dict.unk())
max_positions = min(model.max_encoder_positions() for model in models)
itr = dataset.eval_dataloader(
args.gen_subset, max_sentences=args.batch_size, max_positions=max_positions,
skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
num_sentences = 0
with utils.build_progress_bar(args, itr) as t:
wps_meter = TimeMeter()
gen_timer = StopwatchMeter()
translations = translator.generate_batched_itr(
t, maxlen_a=args.max_len_a, maxlen_b=args.max_len_b,
cuda_device=0 if use_cuda else None, timer=gen_timer)
for sample_id, src_tokens, target_tokens, hypos in translations:
# Process input and ground truth
target_tokens = target_tokens.int().cpu()
# Either retrieve the original sentences or regenerate them from tokens.
if align_dict is not None:
src_str = dataset.splits[args.gen_subset].src.get_original_text(sample_id)
target_str = dataset.splits[args.gen_subset].dst.get_original_text(sample_id)
else:
src_str = dataset.src_dict.string(src_tokens, args.remove_bpe)
target_str = dataset.dst_dict.string(target_tokens, args.remove_bpe, escape_unk=True)
if not args.quiet:
print('S-{}\t{}'.format(sample_id, src_str))
print('T-{}\t{}'.format(sample_id, target_str))
# Process top predictions
for i, hypo in enumerate(hypos[:min(len(hypos), args.nbest)]):
hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
hypo_tokens=hypo['tokens'].int().cpu(),
src_str=src_str,
alignment=hypo['alignment'].int().cpu(),
align_dict=align_dict,
dst_dict=dataset.dst_dict,
remove_bpe=args.remove_bpe)
if not args.quiet:
print('H-{}\t{}\t{}'.format(sample_id, hypo['score'], hypo_str))
print('A-{}\t{}'.format(sample_id, ' '.join(map(str, alignment))))
# Score only the top hypothesis
if i == 0:
if align_dict is not None or args.remove_bpe is not None:
# Convert back to tokens for evaluation with unk replacement and/or without BPE
target_tokens = tokenizer.Tokenizer.tokenize(target_str,
dataset.dst_dict,
add_if_not_exist=True)
scorer.add(target_tokens, hypo_tokens)
wps_meter.update(src_tokens.size(0))
t.log({'wps': round(wps_meter.avg)})
num_sentences += 1
print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} tokens/s)'.format(
num_sentences, gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))
if __name__ == '__main__':
......
#!/usr/bin/env python3
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import sys
import torch
from torch.autograd import Variable
from fairseq import options, tokenizer, utils
from fairseq.sequence_generator import SequenceGenerator
def main():
parser = options.get_parser('Generation')
parser.add_argument('--path', metavar='FILE', required=True, action='append',
help='path(s) to model file(s)')
options.add_dataset_args(parser)
options.add_generation_args(parser)
args = parser.parse_args()
print(args)
use_cuda = torch.cuda.is_available() and not args.cpu
# Load ensemble
print('| loading model(s) from {}'.format(', '.join(args.path)))
models, model_args = utils.load_ensemble_for_inference(args.path, data_dir=args.data)
src_dict, dst_dict = models[0].src_dict, models[0].dst_dict
print('| [{}] dictionary: {} types'.format(model_args.source_lang, len(src_dict)))
print('| [{}] dictionary: {} types'.format(model_args.target_lang, len(dst_dict)))
# Optimize ensemble for generation
for model in models:
model.make_generation_fast_(
beamable_mm_beam_size=None if args.no_beamable_mm else args.beam)
# Initialize generator
translator = SequenceGenerator(
models, beam_size=args.beam, stop_early=(not args.no_early_stop),
normalize_scores=(not args.unnormalized), len_penalty=args.lenpen,
unk_penalty=args.unkpen)
if use_cuda:
translator.cuda()
# Load alignment dictionary for unknown word replacement
# (None if no unknown word replacement, empty if no path to align dictionary)
align_dict = utils.load_align_dict(args.replace_unk)
print('| Type the input sentence and press return:')
for src_str in sys.stdin:
src_str = src_str.strip()
src_tokens = tokenizer.Tokenizer.tokenize(src_str, src_dict, add_if_not_exist=False).long()
if use_cuda:
src_tokens = src_tokens.cuda()
translations = translator.generate(Variable(src_tokens.view(1, -1)))
hypos = translations[0]
print('O\t{}'.format(src_str))
# Process top predictions
for hypo in hypos[:min(len(hypos), args.nbest)]:
hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
hypo_tokens=hypo['tokens'].int().cpu(),
src_str=src_str,
alignment=hypo['alignment'].int().cpu(),
align_dict=align_dict,
dst_dict=dst_dict,
remove_bpe=args.remove_bpe)
print('H\t{}\t{}'.format(hypo['score'], hypo_str))
print('A\t{}'.format(' '.join(map(str, alignment))))
if __name__ == '__main__':
main()
......@@ -8,8 +8,9 @@
#
import argparse
import os
from itertools import zip_longest
import os
import shutil
from fairseq import dictionary, indexed_dataset
from fairseq.tokenizer import Tokenizer
......@@ -28,23 +29,33 @@ def main():
help='map words appearing less than threshold times to unknown')
parser.add_argument('--thresholdsrc', metavar='N', default=0, type=int,
help='map words appearing less than threshold times to unknown')
parser.add_argument('--tgtdict', metavar='FP', help='reuse given target dictionary')
parser.add_argument('--srcdict', metavar='FP', help='reuse given source dictionary')
parser.add_argument('--nwordstgt', metavar='N', default=-1, type=int, help='number of target words to retain')
parser.add_argument('--nwordssrc', metavar='N', default=-1, type=int, help='number of source words to retain')
parser.add_argument('--alignfile', metavar='ALIGN', default=None, help='an alignment file (optional)')
parser.add_argument('--output-format', metavar='FORMAT', default='binary', choices=['binary', 'raw'],
help='output format (optional)')
args = parser.parse_args()
print(args)
os.makedirs(args.destdir, exist_ok=True)
src_dict = Tokenizer.build_dictionary(filename='{}.{}'.format(args.trainpref, args.source_lang))
if args.srcdict:
src_dict = dictionary.Dictionary.load(args.srcdict)
else:
src_dict = Tokenizer.build_dictionary(filename='{}.{}'.format(args.trainpref, args.source_lang))
src_dict.save(os.path.join(args.destdir, 'dict.{}.txt'.format(args.source_lang)),
threshold=args.thresholdsrc, nwords=args.nwordssrc)
tgt_dict = Tokenizer.build_dictionary(filename='{}.{}'.format(args.trainpref, args.target_lang))
if args.tgtdict:
tgt_dict = dictionary.Dictionary.load(args.tgtdict)
else:
tgt_dict = Tokenizer.build_dictionary(filename='{}.{}'.format(args.trainpref, args.target_lang))
tgt_dict.save(os.path.join(args.destdir, 'dict.{}.txt'.format(args.target_lang)),
threshold=args.thresholdtgt, nwords=args.nwordstgt)
def make_dataset(input_prefix, output_prefix, lang):
def make_binary_dataset(input_prefix, output_prefix, lang):
dict = dictionary.Dictionary.load(os.path.join(args.destdir, 'dict.{}.txt'.format(lang)))
print('| [{}] Dictionary: {} types'.format(lang, len(dict) - 1))
......@@ -65,16 +76,24 @@ def main():
args.destdir, output_prefix,
args.source_lang, args.target_lang, lang))
make_dataset(args.trainpref, 'train', args.source_lang)
make_dataset(args.trainpref, 'train', args.target_lang)
def make_dataset(input_prefix, output_prefix, lang, output_format='binary'):
if output_format == 'binary':
make_binary_dataset(input_prefix, output_prefix, lang)
elif output_format == 'raw':
# Copy original text file to destination folder
output_text_file = os.path.join(args.destdir, f'{output_prefix}.{lang}')
shutil.copyfile('{}.{}'.format(input_prefix, lang), output_text_file)
make_dataset(args.trainpref, 'train', args.source_lang, args.output_format)
make_dataset(args.trainpref, 'train', args.target_lang, args.output_format)
for k, validpref in enumerate(args.validpref.split(',')):
outprefix = 'valid{}'.format(k) if k > 0 else 'valid'
make_dataset(validpref, outprefix, args.source_lang)
make_dataset(validpref, outprefix, args.target_lang)
make_dataset(validpref, outprefix, args.source_lang, args.output_format)
make_dataset(validpref, outprefix, args.target_lang, args.output_format)
for k, testpref in enumerate(args.testpref.split(',')):
outprefix = 'test{}'.format(k) if k > 0 else 'test'
make_dataset(testpref, outprefix, args.source_lang)
make_dataset(testpref, outprefix, args.target_lang)
make_dataset(testpref, outprefix, args.source_lang, args.output_format)
make_dataset(testpref, outprefix, args.target_lang, args.output_format)
print('| Wrote preprocessed data to {}'.format(args.destdir))
if args.alignfile:
......
......@@ -54,7 +54,7 @@ class build_py_hook(build_py):
setup(
name='fairseq',
version='0.1.0',
version='0.2.0',
description='Facebook AI Research Sequence-to-Sequence Toolkit',
long_description=readme,
license=license,
......
......@@ -10,7 +10,7 @@ import torch
import unittest
from fairseq.modules import ConvTBC
import torch.nn as nn
from torch.autograd import Variable, gradcheck
from torch.autograd import Variable
class TestConvTBC(unittest.TestCase):
......@@ -31,7 +31,7 @@ class TestConvTBC(unittest.TestCase):
output1d = conv1d(input1d)
self.assertAlmostEqual(output_tbc.data.transpose(0, 1).transpose(1, 2), output1d.data)
grad_tbc = torch.randn(output_tbc.size())
grad1d = grad_tbc.transpose(0, 1).transpose(1, 2).contiguous()
......
......@@ -8,7 +8,7 @@
import torch
import unittest
from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedCrossEntropy
from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedNLLLoss
from torch.autograd import Variable, gradcheck
......@@ -21,7 +21,7 @@ class TestLabelSmoothing(unittest.TestCase):
input = Variable(torch.randn(3, 5), requires_grad=True)
idx = torch.rand(3) * 4
target = Variable(idx.long())
criterion = LabelSmoothedCrossEntropy()
criterion = LabelSmoothedNLLLoss()
self.assertTrue(gradcheck(
lambda x, y: criterion.apply(x, y, 0.1, 2, None), (input, target)
))
......
......@@ -15,7 +15,6 @@ import math
from fairseq import data, options, utils
from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter
from fairseq.multiprocessing_trainer import MultiprocessingTrainer
from fairseq.progress_bar import progress_bar
def main():
......@@ -23,6 +22,8 @@ def main():
dataset_args = options.add_dataset_args(parser)
dataset_args.add_argument('--max-tokens', default=6000, type=int, metavar='N',
help='maximum number of tokens in a batch')
dataset_args.add_argument('--max-sentences', type=int, metavar='N',
help='maximum number of sentences in a batch')
dataset_args.add_argument('--train-subset', default='train', metavar='SPLIT',
choices=['train', 'valid', 'test'],
help='data subset to use for training (train, valid, test)')
......@@ -34,38 +35,50 @@ def main():
options.add_model_args(parser)
args = utils.parse_args_and_arch(parser)
print(args)
if args.no_progress_bar:
progress_bar.enabled = False
progress_bar.print_interval = args.log_interval
if args.no_progress_bar and args.log_format == 'tqdm':
args.log_format = 'simple'
if not os.path.exists(args.save_dir):
os.makedirs(args.save_dir)
torch.manual_seed(args.seed)
# Load dataset
dataset = data.load_with_check(args.data, ['train', 'valid'], args.source_lang, args.target_lang)
splits = ['train', 'valid']
if data.has_binary_files(args.data, splits):
dataset = data.load_dataset(args.data, splits, args.source_lang, args.target_lang)
else:
dataset = data.load_raw_text_dataset(args.data, splits, args.source_lang, args.target_lang)
if args.source_lang is None or args.target_lang is None:
# record inferred languages in args, so that it's saved in checkpoints
args.source_lang, args.target_lang = dataset.src, dataset.dst
print(args)
print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
for split in ['train', 'valid']:
for split in splits:
print('| {} {} {} examples'.format(args.data, split, len(dataset.splits[split])))
if not torch.cuda.is_available():
raise NotImplementedError('Training on CPU is not supported')
num_gpus = torch.cuda.device_count()
print('| using {} GPUs (with max tokens per GPU = {})'.format(num_gpus, args.max_tokens))
print('| using {} GPUs (with max tokens per GPU = {} and max sentences per GPU = {})'.format(
num_gpus, args.max_tokens, args.max_sentences))
# Build model and criterion
model = utils.build_model(args, dataset.src_dict, dataset.dst_dict)
criterion = utils.build_criterion(args, dataset.src_dict, dataset.dst_dict)
print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
# The max number of positions can be different for train and valid
# e.g., RNNs may support more positions at test time than seen in training
max_positions_train = (args.max_source_positions, args.max_target_positions)
max_positions_valid = (
min(args.max_source_positions, model.max_encoder_positions()),
min(args.max_target_positions, model.max_decoder_positions())
)
# Start multiprocessing
trainer = MultiprocessingTrainer(args, model, criterion)
......@@ -89,11 +102,11 @@ def main():
train_meter.start()
while lr > args.min_lr and epoch <= max_epoch:
# train for one epoch
train(args, epoch, batch_offset, trainer, dataset, num_gpus)
train(args, epoch, batch_offset, trainer, dataset, max_positions_train, num_gpus)
# evaluate on validate set
for k, subset in enumerate(args.valid_subset.split(',')):
val_loss = validate(args, epoch, trainer, dataset, subset, num_gpus)
val_loss = validate(args, epoch, trainer, dataset, max_positions_valid, subset, num_gpus)
if k == 0:
if not args.no_save:
# save checkpoint
......@@ -112,19 +125,24 @@ def main():
def get_perplexity(loss):
try:
return math.pow(2, loss)
return round(math.pow(2, loss), 2)
except OverflowError:
return float('inf')
def train(args, epoch, batch_offset, trainer, dataset, num_gpus):
def train(args, epoch, batch_offset, trainer, dataset, max_positions, num_gpus):
"""Train the model for one epoch."""
itr = dataset.dataloader(args.train_subset, num_workers=args.workers,
max_tokens=args.max_tokens, seed=args.seed, epoch=epoch,
max_positions=args.max_positions,
sample_without_replacement=args.sample_without_replacement,
skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
seed = args.seed + epoch
torch.manual_seed(seed)
trainer.set_seed(seed)
itr = dataset.train_dataloader(
args.train_subset, num_workers=args.workers,
max_tokens=args.max_tokens, max_sentences=args.max_sentences,
max_positions=max_positions, seed=seed, epoch=epoch,
sample_without_replacement=args.sample_without_replacement,
sort_by_source_size=(epoch <= args.curriculum))
loss_meter = AverageMeter()
bsz_meter = AverageMeter() # sentences per batch
wpb_meter = AverageMeter() # words per batch
......@@ -132,19 +150,17 @@ def train(args, epoch, batch_offset, trainer, dataset, num_gpus):
clip_meter = AverageMeter() # % of updates clipped
extra_meters = collections.defaultdict(lambda: AverageMeter())
desc = '| epoch {:03d}'.format(epoch)
trainer.set_seed(args.seed + epoch)
lr = trainer.get_lr()
with progress_bar(itr, desc, leave=False) as t:
with utils.build_progress_bar(args, itr, epoch) as t:
for i, sample in data.skip_group_enumerator(t, num_gpus, batch_offset):
loss_dict = trainer.train_step(sample)
loss = loss_dict['loss']
del loss_dict['loss'] # don't include in extra_meters or extra_postfix
ntokens = sum(s['ntokens'] for s in sample)
src_size = sum(s['src_tokens'].size(0) for s in sample)
loss_meter.update(loss, ntokens)
bsz_meter.update(src_size)
nsentences = sum(s['src_tokens'].size(0) for s in sample)
loss_meter.update(loss, nsentences if args.sentence_avg else ntokens)
bsz_meter.update(nsentences)
wpb_meter.update(ntokens)
wps_meter.update(ntokens)
clip_meter.update(1 if loss_dict['gnorm'] > args.clip_norm else 0)
......@@ -152,16 +168,16 @@ def train(args, epoch, batch_offset, trainer, dataset, num_gpus):
extra_postfix = []
for k, v in loss_dict.items():
extra_meters[k].update(v)
extra_postfix.append((k, '{:.4f}'.format(extra_meters[k].avg)))
extra_postfix.append((k, extra_meters[k].avg))
t.set_postfix(collections.OrderedDict([
('loss', '{:.2f} ({:.2f})'.format(loss, loss_meter.avg)),
('wps', '{:5d}'.format(round(wps_meter.avg))),
('wpb', '{:5d}'.format(round(wpb_meter.avg))),
('bsz', '{:5d}'.format(round(bsz_meter.avg))),
t.log(collections.OrderedDict([
('loss', loss_meter),
('wps', round(wps_meter.avg)),
('wpb', round(wpb_meter.avg)),
('bsz', round(bsz_meter.avg)),
('lr', lr),
('clip', '{:3.0f}%'.format(clip_meter.avg * 100)),
] + extra_postfix), refresh=False)
('clip', '{:.0%}'.format(clip_meter.avg)),
] + extra_postfix))
if i == 0:
# ignore the first mini-batch in words-per-second calculation
......@@ -169,17 +185,19 @@ def train(args, epoch, batch_offset, trainer, dataset, num_gpus):
if args.save_interval > 0 and (i + 1) % args.save_interval == 0:
save_checkpoint(trainer, args, epoch, i + 1)
fmt = desc + ' | train loss {:2.2f} | train ppl {:3.2f}'.format(
loss_meter.avg, get_perplexity(loss_meter.avg))
fmt += ' | s/checkpoint {:7d} | words/s {:6d} | words/batch {:6d}'.format(
round(wps_meter.elapsed_time), round(wps_meter.avg), round(wpb_meter.avg))
fmt += ' | bsz {:5d} | lr {:0.6f} | clip {:3.0f}%'.format(
round(bsz_meter.avg), lr, clip_meter.avg * 100)
fmt += ''.join(
' | {} {:.4f}'.format(k, meter.avg)
t.print(collections.OrderedDict([
('train loss', round(loss_meter.avg, 2)),
('train ppl', get_perplexity(loss_meter.avg)),
('s/checkpoint', round(wps_meter.elapsed_time)),
('words/s', round(wps_meter.avg)),
('words/batch', round(wpb_meter.avg)),
('bsz', round(bsz_meter.avg)),
('lr', lr),
('clip', '{:3.0f}%'.format(clip_meter.avg * 100)),
] + [
(k, meter.avg)
for k, meter in extra_meters.items()
)
t.write(fmt)
]))
def save_checkpoint(trainer, args, epoch, batch_offset, val_loss):
......@@ -204,18 +222,20 @@ def save_checkpoint(trainer, args, epoch, batch_offset, val_loss):
trainer.save_checkpoint(last_filename, extra_state)
def validate(args, epoch, trainer, dataset, subset, ngpus):
def validate(args, epoch, trainer, dataset, max_positions, subset, ngpus):
"""Evaluate the model on the validation set and return the average loss."""
itr = dataset.dataloader(subset, batch_size=None,
max_tokens=args.max_tokens,
max_positions=args.max_positions,
skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
itr = dataset.eval_dataloader(
subset, max_tokens=args.max_tokens, max_sentences=args.max_sentences,
max_positions=max_positions,
skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
descending=True, # largest batch first to warm the caching allocator
)
loss_meter = AverageMeter()
extra_meters = collections.defaultdict(lambda: AverageMeter())
desc = '| epoch {:03d} | valid on \'{}\' subset'.format(epoch, subset)
with progress_bar(itr, desc, leave=False) as t:
prefix = 'valid on \'{}\' subset'.format(subset)
with utils.build_progress_bar(args, itr, epoch, prefix) as t:
for _, sample in data.skip_group_enumerator(t, ngpus):
loss_dict = trainer.valid_step(sample)
loss = loss_dict['loss']
......@@ -227,23 +247,22 @@ def validate(args, epoch, trainer, dataset, subset, ngpus):
extra_postfix = []
for k, v in loss_dict.items():
extra_meters[k].update(v)
extra_postfix.append((k, '{:.4f}'.format(extra_meters[k].avg)))
extra_postfix.append((k, extra_meters[k].avg))
t.set_postfix(collections.OrderedDict([
('loss', '{:.2f}'.format(loss_meter.avg)),
] + extra_postfix), refresh=False)
t.log(collections.OrderedDict([
('valid loss', round(loss_meter.avg, 2)),
] + extra_postfix))
val_loss = loss_meter.avg
fmt = desc + ' | valid loss {:2.2f} | valid ppl {:3.2f}'.format(
val_loss, get_perplexity(val_loss))
fmt += ''.join(
' | {} {:.4f}'.format(k, meter.avg)
t.print(collections.OrderedDict([
('valid loss', round(loss_meter.avg, 2)),
('valid ppl', get_perplexity(loss_meter.avg)),
] + [
(k, meter.avg)
for k, meter in extra_meters.items()
)
t.write(fmt)
]))
# update and return the learning rate
return val_loss
return loss_meter.avg
if __name__ == '__main__':
......
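For completeness, the updated `get_perplexity` helper from train.py can be exercised on its own; very large losses overflow `math.pow` and fall back to infinity.

```python
import math

def get_perplexity(loss):
    try:
        return round(math.pow(2, loss), 2)
    except OverflowError:
        return float('inf')

print(get_perplexity(2.3))     # ~4.92
print(get_perplexity(10000))   # inf
```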