Unverified Commit e5b3c1f4 authored by Myle Ott, committed by GitHub

Merge pull request #54: Version 0.1.0 -> 0.2.0

Release notes:
- 5c7f4954: Added simple LSTM model with input feeding and attention
- 6e4b7e22: Refactored model definitions and incremental generation to be cleaner
- 7ae79c12: Split interactive generation out of generate.py and into a new binary: interactive.py
- 19a3865d: Subtle correctness fix in beam search decoder. Previously, for a beam size of k, we might emit a hypothesis
           if the <eos> was among the top 2*k candidates. Now we only emit hypotheses for which the <eos> is among the
           top-k candidates. This may subtly change generation results, and in the case of k=1 we will now produce
           strictly greedy outputs.
- 97d7fcb9: Fixed bug in padding direction, where previously we right-padded the source and left-padded the target. We
           now left-pad the source and right-pad the target. This should not affect existing trained models, but may
           change (usually improve) the quality of new models (see the padding sketch after this list).
- f442f896: Add support for batching based on the number of sentences (`--max-sentences`) in addition to the number of
           tokens (`--max-tokens`). When batching by the number of sentences, one can optionally normalize the gradients
           by the number of sentences with `--sentence-avg` (the default is to normalize by the number of tokens).
- c6d6256b: Add `--log-format` option and JSON logger
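As a rough illustration of the padding convention introduced in 97d7fcb9, the sketch below left-pads a batch of source sequences and right-pads the targets. It is a minimal stand-in, not the repository's collate code; the `collate_tokens` helper, the pad index of 1, and the example tensors are assumptions.

```python
import torch

def collate_tokens(sequences, pad_idx, left_pad):
    """Pad 1D LongTensors to a common length, on the left or on the right."""
    max_len = max(seq.size(0) for seq in sequences)
    out = sequences[0].new_full((len(sequences), max_len), pad_idx)
    for i, seq in enumerate(sequences):
        if left_pad:
            out[i, max_len - seq.size(0):] = seq   # padding goes in front
        else:
            out[i, :seq.size(0)] = seq             # padding goes at the end
    return out

src_seqs = [torch.tensor([4, 5, 6]), torch.tensor([7, 8])]
tgt_seqs = [torch.tensor([9, 10]), torch.tensor([11, 12, 13])]
src_batch = collate_tokens(src_seqs, pad_idx=1, left_pad=True)    # sources: left-padded
tgt_batch = collate_tokens(tgt_seqs, pad_idx=1, left_pad=False)   # targets: right-padded
```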
parents ba5d7dcd 13a3c811
@@ -18,6 +18,8 @@ def get_parser(desc):
    parser.add_argument('--no-progress-bar', action='store_true', help='disable progress bar')
    parser.add_argument('--log-interval', type=int, default=1000, metavar='N',
                        help='log progress every N updates (when progress bar is disabled)')
+   parser.add_argument('--log-format', default='tqdm', help='log format to use',
+                       choices=['json', 'none', 'simple', 'tqdm'])
    parser.add_argument('--seed', default=1, type=int, metavar='N',
                        help='pseudo random number generator seed')
    return parser
@@ -33,8 +35,10 @@ def add_dataset_args(parser):
                       help='target language')
    group.add_argument('-j', '--workers', default=1, type=int, metavar='N',
                       help='number of data loading workers (default: 1)')
-   group.add_argument('--max-positions', default=1024, type=int, metavar='N',
-                      help='max number of tokens in the sequence')
+   group.add_argument('--max-source-positions', default=1024, type=int, metavar='N',
+                      help='max number of tokens in the source sequence')
+   group.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
+                      help='max number of tokens in the target sequence')
    group.add_argument('--skip-invalid-size-inputs-valid-test', action='store_true',
                       help='Ignore too long or too short lines in valid and test set')
    return group
@@ -65,8 +69,13 @@ def add_optimization_args(parser):
                       help='weight decay')
    group.add_argument('--sample-without-replacement', default=0, type=int, metavar='N',
                       help='If bigger than 0, use that number of mini-batches for each epoch,'
-                           ' where each sample is drawn randomly with replacement from the'
+                           ' where each sample is drawn randomly without replacement from the'
                            ' dataset')
+   group.add_argument('--curriculum', default=0, type=int, metavar='N',
+                      help='sort batches by source length for first N epochs')
+   group.add_argument('--sentence-avg', action='store_true',
+                      help='normalize gradients by the number of sentences in a batch'
+                           ' (default is to normalize by number of tokens)')
    return group
@@ -110,8 +119,10 @@ def add_generation_args(parser):
                       help='don\'t use BeamableMM in attention layers')
    group.add_argument('--lenpen', default=1, type=float,
                       help='length penalty: <1.0 favors shorter, >1.0 favors longer sentences')
-   group.add_argument('--unk-replace-dict', default='', type=str,
-                      help='performs unk word replacement')
+   group.add_argument('--unkpen', default=0, type=float,
+                      help='unknown word penalty: <0 produces more unks, >0 produces fewer')
+   group.add_argument('--replace-unk', nargs='?', const=True, default=None,
+                      help='perform unknown replacement (optionally with alignment dictionary)')
    group.add_argument('--quiet', action='store_true',
                       help='Only print final scores')
@@ -147,6 +158,16 @@ def add_model_args(parser):
    group.add_argument('--decoder-attention', type=str, metavar='EXPR',
                       help='decoder attention [True, ...]')
+   # Granular dropout settings for models that support them (e.g., LSTM):
+   group.add_argument('--encoder-dropout-in', type=float, metavar='D',
+                      help='dropout probability for encoder input embedding')
+   group.add_argument('--encoder-dropout-out', type=float, metavar='D',
+                      help='dropout probability for encoder output')
+   group.add_argument('--decoder-dropout-in', type=float, metavar='D',
+                      help='dropout probability for decoder input embedding')
+   group.add_argument('--decoder-dropout-out', type=float, metavar='D',
+                      help='dropout probability for decoder output')
    # These arguments have default values independent of the model:
    group.add_argument('--dropout', default=0.1, type=float, metavar='D',
                       help='dropout probability')
......
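For reference, a small argparse sketch of how the new `--replace-unk` option behaves (the parser below is illustrative, not fairseq's): a bare flag stores `True` (copy the aligned source word), a value stores the path to an alignment dictionary, and omitting the flag leaves `None` (no replacement).

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--replace-unk', nargs='?', const=True, default=None)

print(parser.parse_args([]).replace_unk)                              # None
print(parser.parse_args(['--replace-unk']).replace_unk)               # True
print(parser.parse_args(['--replace-unk', 'align.txt']).replace_unk)  # 'align.txt'
```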
@@ -7,35 +7,29 @@
#
"""
-Progress bar wrapper around tqdm which handles non-TTY outputs.
+Wrapper around various loggers and progress bars (e.g., tqdm).
"""
from collections import OrderedDict
+import json
from numbers import Number
import sys

from tqdm import tqdm

+from fairseq.meters import AverageMeter

-class progress_bar(tqdm):
-    enabled = sys.stderr.isatty()
-    print_interval = 1000
-
-    def __new__(cls, *args, **kwargs):
-        if cls.enabled:
-            return tqdm(*args, **kwargs)
-        else:
-            return simple_progress_bar(cls.print_interval, *args, **kwargs)
-
-class simple_progress_bar(object):
-    """A minimal replacement for tqdm in non-TTY environments."""
-
-    def __init__(self, print_interval, iterable, desc=None, *_args, **_kwargs):
-        super().__init__()
-        self.print_interval = print_interval
+class progress_bar(object):
+    """Abstract class for progress bars."""
+
+    def __init__(self, iterable, epoch=None, prefix=None):
        self.iterable = iterable
-        self.desc = desc
+        self.epoch = epoch
+        self.prefix = ''
+        if epoch is not None:
+            self.prefix += f'| epoch {epoch:03d}'
+        if prefix is not None:
+            self.prefix += f' | {prefix}'

    def __enter__(self):
        return self
@@ -44,36 +38,149 @@ class simple_progress_bar(object):
        return False

    def __iter__(self):
-        size = len(self.iterable)
-        for i, obj in enumerate(self.iterable):
-            yield obj
-            if i > 0 and i % self.print_interval == 0:
-                desc = '' if self.desc is None else '{}: '.format(self.desc)
-                msg = '{}{:5d} / {:d} {}\n'.format(desc, i, size, self.postfix)
-                sys.stdout.write(msg)
-                sys.stdout.flush()
+        raise NotImplementedError
+
+    def log(self, stats):
+        """Log intermediate stats according to log_interval."""
+        raise NotImplementedError
+
+    def print(self, stats):
+        """Print end-of-epoch stats."""
+        raise NotImplementedError

-    def set_postfix(self, ordered_dict=None, refresh=True, **kwargs):
-        # Sort in alphabetical order to be more deterministic
-        postfix = OrderedDict([] if ordered_dict is None else ordered_dict)
-        for key in sorted(kwargs.keys()):
-            postfix[key] = kwargs[key]
+    def _str_commas(self, stats):
+        return ', '.join(key + '=' + stats[key].strip()
+                         for key in stats.keys())
+
+    def _str_pipes(self, stats):
+        return ' | '.join(key + ' ' + stats[key].strip()
+                          for key in stats.keys())
+
+    def _format_stats(self, stats):
+        postfix = OrderedDict(stats)
        # Preprocess stats according to datatype
        for key in postfix.keys():
            # Number: limit the length of the string
            if isinstance(postfix[key], Number):
-                postfix[key] = '{0:2.3g}'.format(postfix[key])
+                postfix[key] = '{:g}'.format(postfix[key])
+            # Meter: display both current and average value
+            elif isinstance(postfix[key], AverageMeter):
+                postfix[key] = '{:.2f} ({:.2f})'.format(
+                    postfix[key].val, postfix[key].avg)
            # Else for any other type, try to get the string conversion
            elif not isinstance(postfix[key], str):
                postfix[key] = str(postfix[key])
            # Else if it's a string, don't need to preprocess anything
-        # Stitch together to get the final postfix
-        self.postfix = ', '.join(key + '=' + postfix[key].strip()
-                                 for key in postfix.keys())
+        return postfix

-    @classmethod
-    def write(cls, s, file=None, end="\n"):
-        fp = file if file is not None else sys.stdout
-        fp.write(s)
-        fp.write(end)
-        fp.flush()
+class json_progress_bar(progress_bar):
+    """Log output in JSON format."""
+
+    def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000):
+        super().__init__(iterable, epoch, prefix)
+        self.log_interval = log_interval
+        self.stats = None
+
+    def __iter__(self):
+        size = float(len(self.iterable))
+        for i, obj in enumerate(self.iterable):
+            yield obj
+            if self.stats is not None and i > 0 and \
+                    self.log_interval is not None and i % self.log_interval == 0:
+                update = self.epoch + float(i / size) if self.epoch is not None else None
+                stats = self._format_stats(self.stats, epoch=self.epoch, update=update)
+                print("sweep_log: " + json.dumps(stats))
+
+    def log(self, stats):
+        """Log intermediate stats according to log_interval."""
+        self.stats = stats
+
+    def print(self, stats):
+        """Print end-of-epoch stats."""
+        stats = self._format_stats(self.stats, epoch=self.epoch)
+        print("sweep_log: " + json.dumps(stats))
+
+    def _format_stats(self, stats, epoch=None, update=None):
+        postfix = OrderedDict()
+        if epoch is not None:
+            postfix['epoch'] = epoch
+        if update is not None:
+            postfix['update'] = update
+        # Preprocess stats according to datatype
+        for key in stats.keys():
+            # Meter: display both current and average value
+            if isinstance(stats[key], AverageMeter):
+                postfix[key] = stats[key].val
+                postfix[key + '_avg'] = stats[key].avg
+            else:
+                postfix[key] = stats[key]
+        return postfix
+
+class noop_progress_bar(progress_bar):
+    """No logging."""
+
+    def __init__(self, iterable, epoch=None, prefix=None):
+        super().__init__(iterable, epoch, prefix)
+
+    def __iter__(self):
+        for obj in self.iterable:
+            yield obj
+
+    def log(self, stats):
+        """Log intermediate stats according to log_interval."""
+        pass
+
+    def print(self, stats):
+        """Print end-of-epoch stats."""
+        pass
+
+class simple_progress_bar(progress_bar):
+    """A minimal logger for non-TTY environments."""
+
+    def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000):
+        super().__init__(iterable, epoch, prefix)
+        self.log_interval = log_interval
+        self.stats = None
+
+    def __iter__(self):
+        size = len(self.iterable)
+        for i, obj in enumerate(self.iterable):
+            yield obj
+            if self.stats is not None and i > 0 and \
+                    self.log_interval is not None and i % self.log_interval == 0:
+                postfix = self._str_commas(self.stats)
+                print(f'{self.prefix}: {i:5d} / {size:d} {postfix}')
+                sys.stdout.flush()
+
+    def log(self, stats):
+        """Log intermediate stats according to log_interval."""
+        self.stats = self._format_stats(stats)
+
+    def print(self, stats):
+        """Print end-of-epoch stats."""
+        postfix = self._str_pipes(self._format_stats(stats))
+        print(f'{self.prefix} | {postfix}')
+        sys.stdout.flush()
+
+class tqdm_progress_bar(progress_bar):
+    """Log to tqdm."""
+
+    def __init__(self, iterable, epoch=None, prefix=None):
+        super().__init__(iterable, epoch, prefix)
+        self.tqdm = tqdm(iterable, self.prefix, leave=False)
+
+    def __iter__(self):
+        return iter(self.tqdm)
+
+    def log(self, stats):
+        """Log intermediate stats according to log_interval."""
+        self.tqdm.set_postfix(self._format_stats(stats), refresh=False)
+
+    def print(self, stats):
+        """Print end-of-epoch stats."""
+        postfix = self._str_pipes(self._format_stats(stats))
+        self.tqdm.write(f'{self.tqdm.desc} | {postfix}')
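A hypothetical usage sketch of the new progress bar interface (the iterable, stats values, and epoch number are made up; real callers construct these objects through `utils.build_progress_bar`, shown further below):

```python
from collections import OrderedDict
from fairseq import progress_bar

with progress_bar.simple_progress_bar(range(5000), epoch=1, log_interval=1000) as t:
    for i in t:
        # intermediate stats, printed every log_interval steps
        t.log(OrderedDict([('loss', 3.21), ('wps', 15000)]))
    # end-of-epoch summary
    t.print(OrderedDict([('train loss', 3.10), ('train ppl', 8.57)]))
```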
@@ -13,11 +13,13 @@ import torch.nn.functional as F
from torch.autograd import Variable

from fairseq import utils
+from fairseq.models import FairseqIncrementalDecoder

class SequenceGenerator(object):
    def __init__(self, models, beam_size=1, minlen=1, maxlen=200,
-                 stop_early=True, normalize_scores=True, len_penalty=1):
+                 stop_early=True, normalize_scores=True, len_penalty=1,
+                 unk_penalty=0):
        """Generates translations of a given source sentence.

        Args:
@@ -30,26 +32,26 @@ class SequenceGenerator(object):
        """
        self.models = models
        self.pad = models[0].dst_dict.pad()
+        self.unk = models[0].dst_dict.unk()
        self.eos = models[0].dst_dict.eos()
        assert all(m.dst_dict.pad() == self.pad for m in self.models[1:])
+        assert all(m.dst_dict.unk() == self.unk for m in self.models[1:])
        assert all(m.dst_dict.eos() == self.eos for m in self.models[1:])
        self.vocab_size = len(models[0].dst_dict)
        self.beam_size = beam_size
        self.minlen = minlen
-        self.maxlen = min(maxlen, *[m.decoder.max_positions() - self.pad - 2 for m in self.models])
-        self.positions = torch.LongTensor(range(self.pad + 1, self.pad + self.maxlen + 2))
-        self.decoder_context = models[0].decoder.context_size()
+        self.maxlen = min(maxlen, *[m.max_decoder_positions() for m in self.models])
        self.stop_early = stop_early
        self.normalize_scores = normalize_scores
        self.len_penalty = len_penalty
+        self.unk_penalty = unk_penalty

    def cuda(self):
        for model in self.models:
            model.cuda()
-        self.positions = self.positions.cuda()
        return self

-    def generate_batched_itr(self, data_itr, maxlen_a=0.0, maxlen_b=200,
+    def generate_batched_itr(self, data_itr, beam_size=None, maxlen_a=0.0, maxlen_b=None,
                             cuda_device=None, timer=None):
        """Iterate over a batched dataset and yield individual translations.
@@ -59,9 +61,8 @@ class SequenceGenerator(object):
            cuda_device: GPU on which to do generation.
            timer: StopwatchMeter for timing generations.
        """
-        def lstrip_pad(tensor):
-            return tensor[tensor.eq(self.pad).sum():]
+        if maxlen_b is None:
+            maxlen_b = self.maxlen

        for sample in data_itr:
            s = utils.prepare_sample(sample, volatile=True, cuda_device=cuda_device)
@@ -69,25 +70,26 @@ class SequenceGenerator(object):
            srclen = input['src_tokens'].size(1)
            if timer is not None:
                timer.start()
-            hypos = self.generate(input['src_tokens'], input['src_positions'],
+            hypos = self.generate(input['src_tokens'], beam_size=beam_size,
                                  maxlen=int(maxlen_a*srclen + maxlen_b))
            if timer is not None:
                timer.stop(s['ntokens'])
            for i, id in enumerate(s['id']):
                src = input['src_tokens'].data[i, :]
-                # remove padding from ref, which appears at the beginning
-                ref = lstrip_pad(s['target'].data[i, :])
+                # remove padding from ref
+                ref = utils.rstrip_pad(s['target'].data[i, :], self.pad)
                yield id, src, ref, hypos[i]

-    def generate(self, src_tokens, src_positions, beam_size=None, maxlen=None):
+    def generate(self, src_tokens, beam_size=None, maxlen=None):
        """Generate a batch of translations."""
        with ExitStack() as stack:
            for model in self.models:
+                if isinstance(model.decoder, FairseqIncrementalDecoder):
                    stack.enter_context(model.decoder.incremental_inference())
-            return self._generate(src_tokens, src_positions, beam_size, maxlen)
+            return self._generate(src_tokens, beam_size, maxlen)

-    def _generate(self, src_tokens, src_positions, beam_size=None, maxlen=None):
-        bsz = src_tokens.size(0)
+    def _generate(self, src_tokens, beam_size=None, maxlen=None):
+        bsz, srclen = src_tokens.size()
        maxlen = min(maxlen, self.maxlen) if maxlen is not None else self.maxlen

        # the max beam size is the dictionary size - 1, since we never select pad
@@ -97,11 +99,11 @@ class SequenceGenerator(object):
        encoder_outs = []
        for model in self.models:
            model.eval()
-            model.decoder.start_fresh_sequence(beam_size)  # start a fresh sequence
+            if isinstance(model.decoder, FairseqIncrementalDecoder):
+                model.decoder.set_beam_size(beam_size)

-            # compute the encoder output and expand to beam size
-            encoder_out = model.encoder(src_tokens, src_positions)
-            encoder_out = self._expand_encoder_out(encoder_out, beam_size)
+            # compute the encoder output for each beam
+            encoder_out = model.encoder(src_tokens.repeat(1, beam_size).view(-1, srclen))
            encoder_outs.append(encoder_out)

        # initialize buffers
@@ -215,6 +217,7 @@ class SequenceGenerator(object):
            # reorder decoder internal states based on the prev choice of beams
            if reorder_state is not None:
                for model in self.models:
+                    if isinstance(model.decoder, FairseqIncrementalDecoder):
                        model.decoder.reorder_incremental_state(reorder_state)

            probs, avg_attn_scores = self._decode(tokens[:, :step+1], encoder_outs)
@@ -226,6 +229,7 @@ class SequenceGenerator(object):
            # make probs contain cumulative scores for each hypothesis
            probs.add_(scores.view(-1, 1))
            probs[:, self.pad] = -math.inf  # never select pad
+            probs[:, self.unk] -= self.unk_penalty  # apply unk penalty

            # Record attention scores
            attn[:, :, step+1].copy_(avg_attn_scores)
@@ -250,10 +254,11 @@ class SequenceGenerator(object):
            eos_mask = cand_indices.eq(self.eos)
            if step >= self.minlen:
                eos_bbsz_idx = buffer('eos_bbsz_idx')
-                cand_bbsz_idx.masked_select(eos_mask, out=eos_bbsz_idx)
+                # only consider eos when it's among the top beam_size indices
+                cand_bbsz_idx[:, :beam_size].masked_select(eos_mask[:, :beam_size], out=eos_bbsz_idx)
                if eos_bbsz_idx.numel() > 0:
                    eos_scores = buffer('eos_scores', type_of=scores)
-                    cand_scores.masked_select(eos_mask, out=eos_scores)
+                    cand_scores[:, :beam_size].masked_select(eos_mask[:, :beam_size], out=eos_scores)
                    num_remaining_sent -= finalize_hypos(step, eos_bbsz_idx, eos_scores)

            assert num_remaining_sent >= 0
@@ -314,19 +319,13 @@ class SequenceGenerator(object):
        return finalized

    def _decode(self, tokens, encoder_outs):
-        length = tokens.size(1)
-        # repeat the first length positions to fill batch
-        positions = self.positions[:length].view(1, length)
-        # wrap in Variables
+        # wrap in Variable
        tokens = Variable(tokens, volatile=True)
-        positions = Variable(positions, volatile=True)

        avg_probs = None
        avg_attn = None
        for model, encoder_out in zip(self.models, encoder_outs):
-            decoder_out, attn = model.decoder(tokens, positions, encoder_out)
+            decoder_out, attn = model.decoder(tokens, encoder_out)
            probs = F.softmax(decoder_out[:, -1, :]).data
            attn = attn[:, -1, :].data
            if avg_probs is None or avg_attn is None:
@@ -340,14 +339,3 @@ class SequenceGenerator(object):
        avg_attn.div_(len(self.models))

        return avg_probs, avg_attn
-
-    def _expand_encoder_out(self, encoder_out, beam_size):
-        res = []
-        for tensor in encoder_out:
-            res.append(
-                # repeat beam_size times along second dimension
-                tensor.repeat(1, beam_size, *[1 for i in range(tensor.dim()-2)]) \
-                # then collapse into [bsz*beam, ...original dims...]
-                .view(-1, *tensor.size()[1:])
-            )
-        return tuple(res)
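A toy illustration (not the decoder's actual code) of the <eos> restriction referenced in 19a3865d and applied above: although 2*beam_size candidates are scored at each step, only candidates ranked in the top beam_size may finalize a hypothesis. The scores and mask below are invented.

```python
import torch

beam_size = 2
# scores for 2*beam_size candidates of one sentence, best first
cand_scores = torch.tensor([[0.9, 0.7, 0.6, 0.2]])
# candidates that predicted <eos>
eos_mask = torch.tensor([[False, True, False, True]])

# only the first beam_size columns are allowed to finalize a hypothesis
eos_scores = cand_scores[:, :beam_size].masked_select(eos_mask[:, :beam_size])
print(eos_scores)  # tensor([0.7000]); the <eos> ranked 4th is ignored
```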
@@ -14,7 +14,7 @@ import traceback
from torch.autograd import Variable
from torch.serialization import default_restore_location

-from fairseq import criterions, models
+from fairseq import criterions, data, models, progress_bar, tokenizer

def parse_args_and_arch(parser):
@@ -30,11 +30,22 @@ def build_model(args, src_dict, dst_dict):

def build_criterion(args, src_dict, dst_dict):
-    padding_idx = dst_dict.pad()
    if args.label_smoothing > 0:
-        return criterions.LabelSmoothedCrossEntropyCriterion(args.label_smoothing, padding_idx)
+        return criterions.LabelSmoothedCrossEntropyCriterion(args, dst_dict)
    else:
-        return criterions.CrossEntropyCriterion(padding_idx)
+        return criterions.CrossEntropyCriterion(args, dst_dict)
+
+def build_progress_bar(args, iterator, epoch=None, prefix=None):
+    if args.log_format == 'json':
+        bar = progress_bar.json_progress_bar(iterator, epoch, prefix, args.log_interval)
+    elif args.log_format == 'none':
+        bar = progress_bar.noop_progress_bar(iterator, epoch, prefix)
+    elif args.log_format == 'tqdm':
+        bar = progress_bar.tqdm_progress_bar(iterator, epoch, prefix)
+    else:
+        bar = progress_bar.simple_progress_bar(iterator, epoch, prefix, args.log_interval)
+    return bar

def torch_persistent_save(*args, **kwargs):
@@ -122,7 +133,12 @@ def _upgrade_state_dict(state):
    return state

-def load_ensemble_for_inference(filenames, src_dict, dst_dict):
+def load_ensemble_for_inference(filenames, src_dict=None, dst_dict=None, data_dir=None):
+    """Load an ensemble of models for inference.
+
+    The source and target dictionaries can be given explicitly, or loaded from
+    the `data_dir` directory.
+    """
    # load model architectures and weights
    states = []
    for filename in filenames:
@@ -132,6 +148,11 @@ def load_ensemble_for_inference(filenames, src_dict, dst_dict):
            torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))
        )
    args = states[0]['args']
+    args = _upgrade_args(args)
+
+    if src_dict is None or dst_dict is None:
+        assert data_dir is not None
+        src_dict, dst_dict = data.load_dictionaries(data_dir, args.source_lang, args.target_lang)

    # build ensemble
    ensemble = []
@@ -139,7 +160,14 @@ def load_ensemble_for_inference(filenames, src_dict, dst_dict):
        model = build_model(args, src_dict, dst_dict)
        model.load_state_dict(state['model'])
        ensemble.append(model)
-    return ensemble
+    return ensemble, args
+
+def _upgrade_args(args):
+    if not hasattr(args, 'max_source_positions'):
+        args.max_source_positions = args.max_positions
+        args.max_target_positions = args.max_positions
+    return args

def prepare_sample(sample, volatile=False, cuda_device=None):
@@ -156,6 +184,58 @@ def prepare_sample(sample, volatile=False, cuda_device=None):
        'target': make_variable(sample['target']),
        'net_input': {
            key: make_variable(sample[key])
-            for key in ['src_tokens', 'src_positions', 'input_tokens', 'input_positions']
+            for key in ['src_tokens', 'input_tokens']
        },
    }
+
+def load_align_dict(replace_unk):
+    if replace_unk is None:
+        align_dict = None
+    elif isinstance(replace_unk, str):
+        # Load alignment dictionary for unknown word replacement if it was passed as an argument.
+        align_dict = {}
+        with open(replace_unk, 'r') as f:
+            for line in f:
+                l = line.split()
+                align_dict[l[0]] = l[1]
+    else:
+        # No alignment dictionary provided but we still want to perform unknown word replacement by copying the
+        # original source word.
+        align_dict = {}
+    return align_dict
+
+def replace_unk(hypo_str, src_str, alignment, align_dict, unk):
+    # Tokens are strings here
+    hypo_tokens = tokenizer.tokenize_line(hypo_str)
+    # TODO: Very rare cases where the replacement is '<eos>' should be handled gracefully
+    src_tokens = tokenizer.tokenize_line(src_str) + ['<eos>']
+    for i, ht in enumerate(hypo_tokens):
+        if ht == unk:
+            src_token = src_tokens[alignment[i]]
+            # Either take the corresponding value in the aligned dictionary or just copy the original value.
+            hypo_tokens[i] = align_dict.get(src_token, src_token)
+    return ' '.join(hypo_tokens)
+
+def post_process_prediction(hypo_tokens, src_str, alignment, align_dict, dst_dict, remove_bpe):
+    hypo_str = dst_dict.string(hypo_tokens, remove_bpe)
+    if align_dict is not None:
+        hypo_str = replace_unk(hypo_str, src_str, alignment, align_dict, dst_dict.unk_string())
+    if align_dict is not None or remove_bpe is not None:
+        # Convert back to tokens for evaluating with unk replacement or without BPE
+        # Note that the dictionary can be modified inside the method.
+        hypo_tokens = tokenizer.Tokenizer.tokenize(hypo_str, dst_dict, add_if_not_exist=True)
+    return hypo_tokens, hypo_str, alignment
+
+def lstrip_pad(tensor, pad):
+    return tensor[tensor.eq(pad).sum():]
+
+def rstrip_pad(tensor, pad):
+    strip = tensor.eq(pad).sum()
+    if strip > 0:
+        return tensor[:-strip]
+    return tensor
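A self-contained toy (not the repository's code) of the unknown-word replacement logic that `load_align_dict` and `replace_unk` implement above: each `<unk>` in the hypothesis is replaced by its aligned source word, mapped through the alignment dictionary when an entry exists. The words and alignment below are invented.

```python
align_dict = {'Haus': 'house'}           # entries as loaded by load_align_dict
src_tokens = ['das', 'Haus', '<eos>']    # tokenized source plus trailing '<eos>'
hypo_tokens = ['the', '<unk>']
alignment = [0, 1]                       # hypothesis position -> source position

out = [
    align_dict.get(src_tokens[alignment[i]], src_tokens[alignment[i]])
    if tok == '<unk>' else tok
    for i, tok in enumerate(hypo_tokens)
]
print(' '.join(out))  # "the house"
```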
@@ -7,13 +7,10 @@
# can be found in the PATENTS file in the same directory.
#

-import sys
import torch
-from torch.autograd import Variable

from fairseq import bleu, data, options, tokenizer, utils
from fairseq.meters import StopwatchMeter, TimeMeter
-from fairseq.progress_bar import progress_bar
from fairseq.sequence_generator import SequenceGenerator
@@ -22,8 +19,6 @@ def main():
    parser.add_argument('--path', metavar='FILE', required=True, action='append',
                        help='path(s) to model file(s)')
    dataset_args = options.add_dataset_args(parser)
-    dataset_args.add_argument('-i', '--interactive', action='store_true',
-                              help='generate translations in interactive mode')
    dataset_args.add_argument('--batch-size', default=32, type=int, metavar='N',
                              help='batch size')
    dataset_args.add_argument('--gen-subset', default='test', metavar='SPLIT',
@@ -31,125 +26,99 @@ def main():
    options.add_generation_args(parser)

    args = parser.parse_args()
+    if args.no_progress_bar:
+        args.log_format = 'none'
    print(args)

-    if args.no_progress_bar:
-        progress_bar.enabled = False
    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset
-    dataset = data.load_with_check(args.data, [args.gen_subset], args.source_lang, args.target_lang)
+    if args.replace_unk is None:
+        dataset = data.load_dataset(args.data, [args.gen_subset], args.source_lang, args.target_lang)
+    else:
+        dataset = data.load_raw_text_dataset(args.data, [args.gen_subset], args.source_lang, args.target_lang)
    if args.source_lang is None or args.target_lang is None:
        # record inferred languages in args
        args.source_lang, args.target_lang = dataset.src, dataset.dst

    # Load ensemble
    print('| loading model(s) from {}'.format(', '.join(args.path)))
-    models = utils.load_ensemble_for_inference(args.path, dataset.src_dict, dataset.dst_dict)
+    models, _ = utils.load_ensemble_for_inference(args.path, dataset.src_dict, dataset.dst_dict)

    print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
-    if not args.interactive:
-        print('| {} {} {} examples'.format(args.data, args.gen_subset, len(dataset.splits[args.gen_subset])))
-
-    # Max positions is the model property but it is needed in data reader to be able to
-    # ignore too long sentences
-    args.max_positions = min(args.max_positions, *(m.decoder.max_positions() for m in models))
+    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(dataset.splits[args.gen_subset])))

    # Optimize ensemble for generation
    for model in models:
-        model.make_generation_fast_(not args.no_beamable_mm)
+        model.make_generation_fast_(
+            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam)

    # Initialize generator
    translator = SequenceGenerator(
        models, beam_size=args.beam, stop_early=(not args.no_early_stop),
-        normalize_scores=(not args.unnormalized), len_penalty=args.lenpen
-    )
+        normalize_scores=(not args.unnormalized), len_penalty=args.lenpen,
+        unk_penalty=args.unkpen)
    if use_cuda:
        translator.cuda()

    # Load alignment dictionary for unknown word replacement
-    align_dict = {}
-    if args.unk_replace_dict != '':
-        assert args.interactive, \
-            'Unknown word replacement requires access to original source and is only supported in interactive mode'
-        with open(args.unk_replace_dict, 'r') as f:
-            for line in f:
-                l = line.split()
-                align_dict[l[0]] = l[1]
-
-    def replace_unk(hypo_str, align_str, src, unk):
-        hypo_tokens = hypo_str.split()
-        src_tokens = tokenizer.tokenize_line(src)
-        align_idx = [int(i) for i in align_str.split()]
-        for i, ht in enumerate(hypo_tokens):
-            if ht == unk:
-                src_token = src_tokens[align_idx[i]]
-                if src_token in align_dict:
-                    hypo_tokens[i] = align_dict[src_token]
-                else:
-                    hypo_tokens[i] = src_token
-        return ' '.join(hypo_tokens)
-
-    def display_hypotheses(id, src, orig, ref, hypos):
-        if args.quiet:
-            return
-        id_str = '' if id is None else '-{}'.format(id)
-        src_str = dataset.src_dict.string(src, args.remove_bpe)
-        print('S{}\t{}'.format(id_str, src_str))
-        if orig is not None:
-            print('O{}\t{}'.format(id_str, orig.strip()))
-        if ref is not None:
-            print('T{}\t{}'.format(id_str, dataset.dst_dict.string(ref, args.remove_bpe, escape_unk=True)))
-        for hypo in hypos:
-            hypo_str = dataset.dst_dict.string(hypo['tokens'], args.remove_bpe)
-            align_str = ' '.join(map(str, hypo['alignment']))
-            if args.unk_replace_dict != '':
-                hypo_str = replace_unk(hypo_str, align_str, orig, dataset.dst_dict.unk_string())
-            print('H{}\t{}\t{}'.format(id_str, hypo['score'], hypo_str))
-            print('A{}\t{}'.format(id_str, align_str))
-
-    if args.interactive:
-        for line in sys.stdin:
-            tokens = tokenizer.Tokenizer.tokenize(line, dataset.src_dict, add_if_not_exist=False).long()
-            start = dataset.src_dict.pad() + 1
-            positions = torch.arange(start, start + len(tokens)).type_as(tokens)
-            if use_cuda:
-                positions = positions.cuda()
-                tokens = tokens.cuda()
-            translations = translator.generate(Variable(tokens.view(1, -1)), Variable(positions.view(1, -1)))
-            hypos = translations[0]
-            display_hypotheses(None, tokens, line, None, hypos[:min(len(hypos), args.nbest)])
-
-    else:
-        def maybe_remove_bpe(tokens, escape_unk=False):
-            """Helper for removing BPE symbols from a hypothesis."""
-            if args.remove_bpe is None:
-                return tokens
-            assert (tokens == dataset.dst_dict.pad()).sum() == 0
-            hypo_minus_bpe = dataset.dst_dict.string(tokens, args.remove_bpe, escape_unk)
-            return tokenizer.Tokenizer.tokenize(hypo_minus_bpe, dataset.dst_dict, add_if_not_exist=True)
+    # (None if no unknown word replacement, empty if no path to align dictionary)
+    align_dict = utils.load_align_dict(args.replace_unk)

    # Generate and compute BLEU score
    scorer = bleu.Scorer(dataset.dst_dict.pad(), dataset.dst_dict.eos(), dataset.dst_dict.unk())
-    itr = dataset.dataloader(args.gen_subset, batch_size=args.batch_size,
-                             max_positions=args.max_positions,
-                             skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
+    max_positions = min(model.max_encoder_positions() for model in models)
+    itr = dataset.eval_dataloader(
+        args.gen_subset, max_sentences=args.batch_size, max_positions=max_positions,
+        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
    num_sentences = 0
-    with progress_bar(itr, smoothing=0, leave=False) as t:
+    with utils.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        gen_timer = StopwatchMeter()
        translations = translator.generate_batched_itr(
            t, maxlen_a=args.max_len_a, maxlen_b=args.max_len_b,
            cuda_device=0 if use_cuda else None, timer=gen_timer)
-        for id, src, ref, hypos in translations:
-            ref = ref.int().cpu()
-            top_hypo = hypos[0]['tokens'].int().cpu()
-            scorer.add(maybe_remove_bpe(ref, escape_unk=True), maybe_remove_bpe(top_hypo))
-            display_hypotheses(id, src, None, ref, hypos[:min(len(hypos), args.nbest)])
-
-            wps_meter.update(src.size(0))
-            t.set_postfix(wps='{:5d}'.format(round(wps_meter.avg)), refresh=False)
+        for sample_id, src_tokens, target_tokens, hypos in translations:
+            # Process input and ground truth
+            target_tokens = target_tokens.int().cpu()
+            # Either retrieve the original sentences or regenerate them from tokens.
+            if align_dict is not None:
+                src_str = dataset.splits[args.gen_subset].src.get_original_text(sample_id)
+                target_str = dataset.splits[args.gen_subset].dst.get_original_text(sample_id)
+            else:
+                src_str = dataset.src_dict.string(src_tokens, args.remove_bpe)
+                target_str = dataset.dst_dict.string(target_tokens, args.remove_bpe, escape_unk=True)
+
+            if not args.quiet:
+                print('S-{}\t{}'.format(sample_id, src_str))
+                print('T-{}\t{}'.format(sample_id, target_str))
+
+            # Process top predictions
+            for i, hypo in enumerate(hypos[:min(len(hypos), args.nbest)]):
+                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
+                    hypo_tokens=hypo['tokens'].int().cpu(),
+                    src_str=src_str,
+                    alignment=hypo['alignment'].int().cpu(),
+                    align_dict=align_dict,
+                    dst_dict=dataset.dst_dict,
+                    remove_bpe=args.remove_bpe)
+
+                if not args.quiet:
+                    print('H-{}\t{}\t{}'.format(sample_id, hypo['score'], hypo_str))
+                    print('A-{}\t{}'.format(sample_id, ' '.join(map(str, alignment))))
+
+                # Score only the top hypothesis
+                if i == 0:
+                    if align_dict is not None or args.remove_bpe is not None:
+                        # Convert back to tokens for evaluation with unk replacement and/or without BPE
+                        target_tokens = tokenizer.Tokenizer.tokenize(target_str,
                                                                      dataset.dst_dict,
                                                                      add_if_not_exist=True)
+                    scorer.add(target_tokens, hypo_tokens)
+
+            wps_meter.update(src_tokens.size(0))
+            t.log({'wps': round(wps_meter.avg)})
            num_sentences += 1

    print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} tokens/s)'.format(
......
#!/usr/bin/env python3
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import sys
import torch
from torch.autograd import Variable
from fairseq import options, tokenizer, utils
from fairseq.sequence_generator import SequenceGenerator
def main():
parser = options.get_parser('Generation')
parser.add_argument('--path', metavar='FILE', required=True, action='append',
help='path(s) to model file(s)')
options.add_dataset_args(parser)
options.add_generation_args(parser)
args = parser.parse_args()
print(args)
use_cuda = torch.cuda.is_available() and not args.cpu
# Load ensemble
print('| loading model(s) from {}'.format(', '.join(args.path)))
models, model_args = utils.load_ensemble_for_inference(args.path, data_dir=args.data)
src_dict, dst_dict = models[0].src_dict, models[0].dst_dict
print('| [{}] dictionary: {} types'.format(model_args.source_lang, len(src_dict)))
print('| [{}] dictionary: {} types'.format(model_args.target_lang, len(dst_dict)))
# Optimize ensemble for generation
for model in models:
model.make_generation_fast_(
beamable_mm_beam_size=None if args.no_beamable_mm else args.beam)
# Initialize generator
translator = SequenceGenerator(
models, beam_size=args.beam, stop_early=(not args.no_early_stop),
normalize_scores=(not args.unnormalized), len_penalty=args.lenpen,
unk_penalty=args.unkpen)
if use_cuda:
translator.cuda()
# Load alignment dictionary for unknown word replacement
# (None if no unknown word replacement, empty if no path to align dictionary)
align_dict = utils.load_align_dict(args.replace_unk)
print('| Type the input sentence and press return:')
for src_str in sys.stdin:
src_str = src_str.strip()
src_tokens = tokenizer.Tokenizer.tokenize(src_str, src_dict, add_if_not_exist=False).long()
if use_cuda:
src_tokens = src_tokens.cuda()
translations = translator.generate(Variable(src_tokens.view(1, -1)))
hypos = translations[0]
print('O\t{}'.format(src_str))
# Process top predictions
for hypo in hypos[:min(len(hypos), args.nbest)]:
hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
hypo_tokens=hypo['tokens'].int().cpu(),
src_str=src_str,
alignment=hypo['alignment'].int().cpu(),
align_dict=align_dict,
dst_dict=dst_dict,
remove_bpe=args.remove_bpe)
print('H\t{}\t{}'.format(hypo['score'], hypo_str))
print('A\t{}'.format(' '.join(map(str, alignment))))
if __name__ == '__main__':
main()
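For context, a typical invocation of the new binary might look like `python interactive.py data-bin/iwslt14.de-en --path checkpoints/checkpoint_best.pt --beam 5 --replace-unk` (the data directory and checkpoint path are placeholder names): the script reads one source sentence per line from stdin and prints the O/H/A lines shown in the code above.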
@@ -8,8 +8,9 @@
#

import argparse
-import os
from itertools import zip_longest
+import os
+import shutil

from fairseq import dictionary, indexed_dataset
from fairseq.tokenizer import Tokenizer
@@ -28,23 +29,33 @@ def main():
                        help='map words appearing less than threshold times to unknown')
    parser.add_argument('--thresholdsrc', metavar='N', default=0, type=int,
                        help='map words appearing less than threshold times to unknown')
+    parser.add_argument('--tgtdict', metavar='FP', help='reuse given target dictionary')
+    parser.add_argument('--srcdict', metavar='FP', help='reuse given source dictionary')
    parser.add_argument('--nwordstgt', metavar='N', default=-1, type=int, help='number of target words to retain')
    parser.add_argument('--nwordssrc', metavar='N', default=-1, type=int, help='number of source words to retain')
    parser.add_argument('--alignfile', metavar='ALIGN', default=None, help='an alignment file (optional)')
+    parser.add_argument('--output-format', metavar='FORMAT', default='binary', choices=['binary', 'raw'],
+                        help='output format (optional)')
    args = parser.parse_args()
    print(args)
    os.makedirs(args.destdir, exist_ok=True)

+    if args.srcdict:
+        src_dict = dictionary.Dictionary.load(args.srcdict)
+    else:
-    src_dict = Tokenizer.build_dictionary(filename='{}.{}'.format(args.trainpref, args.source_lang))
+        src_dict = Tokenizer.build_dictionary(filename='{}.{}'.format(args.trainpref, args.source_lang))
    src_dict.save(os.path.join(args.destdir, 'dict.{}.txt'.format(args.source_lang)),
                  threshold=args.thresholdsrc, nwords=args.nwordssrc)
+
+    if args.tgtdict:
+        tgt_dict = dictionary.Dictionary.load(args.tgtdict)
+    else:
-    tgt_dict = Tokenizer.build_dictionary(filename='{}.{}'.format(args.trainpref, args.target_lang))
+        tgt_dict = Tokenizer.build_dictionary(filename='{}.{}'.format(args.trainpref, args.target_lang))
    tgt_dict.save(os.path.join(args.destdir, 'dict.{}.txt'.format(args.target_lang)),
                  threshold=args.thresholdtgt, nwords=args.nwordstgt)

-    def make_dataset(input_prefix, output_prefix, lang):
+    def make_binary_dataset(input_prefix, output_prefix, lang):
        dict = dictionary.Dictionary.load(os.path.join(args.destdir, 'dict.{}.txt'.format(lang)))
        print('| [{}] Dictionary: {} types'.format(lang, len(dict) - 1))
@@ -65,16 +76,24 @@ def main():
            args.destdir, output_prefix,
            args.source_lang, args.target_lang, lang))

-    make_dataset(args.trainpref, 'train', args.source_lang)
-    make_dataset(args.trainpref, 'train', args.target_lang)
+    def make_dataset(input_prefix, output_prefix, lang, output_format='binary'):
+        if output_format == 'binary':
+            make_binary_dataset(input_prefix, output_prefix, lang)
+        elif output_format == 'raw':
+            # Copy original text file to destination folder
+            output_text_file = os.path.join(args.destdir, f'{output_prefix}.{lang}')
+            shutil.copyfile('{}.{}'.format(input_prefix, lang), output_text_file)
+
+    make_dataset(args.trainpref, 'train', args.source_lang, args.output_format)
+    make_dataset(args.trainpref, 'train', args.target_lang, args.output_format)
    for k, validpref in enumerate(args.validpref.split(',')):
        outprefix = 'valid{}'.format(k) if k > 0 else 'valid'
-        make_dataset(validpref, outprefix, args.source_lang)
-        make_dataset(validpref, outprefix, args.target_lang)
+        make_dataset(validpref, outprefix, args.source_lang, args.output_format)
+        make_dataset(validpref, outprefix, args.target_lang, args.output_format)
    for k, testpref in enumerate(args.testpref.split(',')):
        outprefix = 'test{}'.format(k) if k > 0 else 'test'
-        make_dataset(testpref, outprefix, args.source_lang)
-        make_dataset(testpref, outprefix, args.target_lang)
+        make_dataset(testpref, outprefix, args.source_lang, args.output_format)
+        make_dataset(testpref, outprefix, args.target_lang, args.output_format)
    print('| Wrote preprocessed data to {}'.format(args.destdir))

    if args.alignfile:
......
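A rough preprocessing invocation (flag names inferred from the code above, file and directory names are placeholders) would be `python preprocess.py --source-lang de --target-lang en --trainpref train --validpref valid --testpref test --destdir data-raw --output-format raw`, which copies tokenized text instead of writing binary datasets; `--srcdict` and `--tgtdict` reuse existing dictionary files rather than rebuilding them.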
@@ -54,7 +54,7 @@ class build_py_hook(build_py):

setup(
    name='fairseq',
-    version='0.1.0',
+    version='0.2.0',
    description='Facebook AI Research Sequence-to-Sequence Toolkit',
    long_description=readme,
    license=license,
......
@@ -10,7 +10,7 @@ import torch
import unittest

from fairseq.modules import ConvTBC
import torch.nn as nn
-from torch.autograd import Variable, gradcheck
+from torch.autograd import Variable

class TestConvTBC(unittest.TestCase):
......
@@ -8,7 +8,7 @@
import torch
import unittest

-from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedCrossEntropy
+from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedNLLLoss
from torch.autograd import Variable, gradcheck
@@ -21,7 +21,7 @@ class TestLabelSmoothing(unittest.TestCase):
        input = Variable(torch.randn(3, 5), requires_grad=True)
        idx = torch.rand(3) * 4
        target = Variable(idx.long())
-        criterion = LabelSmoothedCrossEntropy()
+        criterion = LabelSmoothedNLLLoss()
        self.assertTrue(gradcheck(
            lambda x, y: criterion.apply(x, y, 0.1, 2, None), (input, target)
        ))
......
@@ -15,7 +15,6 @@ import math
from fairseq import data, options, utils
from fairseq.meters import AverageMeter, StopwatchMeter, TimeMeter
from fairseq.multiprocessing_trainer import MultiprocessingTrainer
-from fairseq.progress_bar import progress_bar

def main():
@@ -23,6 +22,8 @@ def main():
    dataset_args = options.add_dataset_args(parser)
    dataset_args.add_argument('--max-tokens', default=6000, type=int, metavar='N',
                              help='maximum number of tokens in a batch')
+    dataset_args.add_argument('--max-sentences', type=int, metavar='N',
+                              help='maximum number of sentences in a batch')
    dataset_args.add_argument('--train-subset', default='train', metavar='SPLIT',
                              choices=['train', 'valid', 'test'],
                              help='data subset to use for training (train, valid, test)')
@@ -34,38 +35,50 @@ def main():
    options.add_model_args(parser)

    args = utils.parse_args_and_arch(parser)
-    print(args)

-    if args.no_progress_bar:
-        progress_bar.enabled = False
-        progress_bar.print_interval = args.log_interval
+    if args.no_progress_bar and args.log_format == 'tqdm':
+        args.log_format = 'simple'

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
    torch.manual_seed(args.seed)

    # Load dataset
-    dataset = data.load_with_check(args.data, ['train', 'valid'], args.source_lang, args.target_lang)
+    splits = ['train', 'valid']
+    if data.has_binary_files(args.data, splits):
+        dataset = data.load_dataset(args.data, splits, args.source_lang, args.target_lang)
+    else:
+        dataset = data.load_raw_text_dataset(args.data, splits, args.source_lang, args.target_lang)
    if args.source_lang is None or args.target_lang is None:
        # record inferred languages in args, so that it's saved in checkpoints
        args.source_lang, args.target_lang = dataset.src, dataset.dst
+
+    print(args)
    print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
-    for split in ['train', 'valid']:
+    for split in splits:
        print('| {} {} {} examples'.format(args.data, split, len(dataset.splits[split])))

    if not torch.cuda.is_available():
        raise NotImplementedError('Training on CPU is not supported')
    num_gpus = torch.cuda.device_count()

-    print('| using {} GPUs (with max tokens per GPU = {})'.format(num_gpus, args.max_tokens))
+    print('| using {} GPUs (with max tokens per GPU = {} and max sentences per GPU = {})'.format(
+        num_gpus, args.max_tokens, args.max_sentences))

    # Build model and criterion
    model = utils.build_model(args, dataset.src_dict, dataset.dst_dict)
    criterion = utils.build_criterion(args, dataset.src_dict, dataset.dst_dict)
    print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))

+    # The max number of positions can be different for train and valid
+    # e.g., RNNs may support more positions at test time than seen in training
+    max_positions_train = (args.max_source_positions, args.max_target_positions)
+    max_positions_valid = (
+        min(args.max_source_positions, model.max_encoder_positions()),
+        min(args.max_target_positions, model.max_decoder_positions())
+    )
+
    # Start multiprocessing
    trainer = MultiprocessingTrainer(args, model, criterion)
@@ -89,11 +102,11 @@ def main():
    train_meter.start()
    while lr > args.min_lr and epoch <= max_epoch:
        # train for one epoch
-        train(args, epoch, batch_offset, trainer, dataset, num_gpus)
+        train(args, epoch, batch_offset, trainer, dataset, max_positions_train, num_gpus)

        # evaluate on validate set
        for k, subset in enumerate(args.valid_subset.split(',')):
-            val_loss = validate(args, epoch, trainer, dataset, subset, num_gpus)
+            val_loss = validate(args, epoch, trainer, dataset, max_positions_valid, subset, num_gpus)
            if k == 0:
                if not args.no_save:
                    # save checkpoint
@@ -112,19 +125,24 @@ def main():

def get_perplexity(loss):
    try:
-        return math.pow(2, loss)
+        return round(math.pow(2, loss), 2)
    except OverflowError:
        return float('inf')

-def train(args, epoch, batch_offset, trainer, dataset, num_gpus):
+def train(args, epoch, batch_offset, trainer, dataset, max_positions, num_gpus):
    """Train the model for one epoch."""
-    itr = dataset.dataloader(args.train_subset, num_workers=args.workers,
-                             max_tokens=args.max_tokens, seed=args.seed, epoch=epoch,
-                             max_positions=args.max_positions,
-                             sample_without_replacement=args.sample_without_replacement,
-                             skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
+    seed = args.seed + epoch
+    torch.manual_seed(seed)
+    trainer.set_seed(seed)
+
+    itr = dataset.train_dataloader(
+        args.train_subset, num_workers=args.workers,
+        max_tokens=args.max_tokens, max_sentences=args.max_sentences,
+        max_positions=max_positions, seed=seed, epoch=epoch,
+        sample_without_replacement=args.sample_without_replacement,
+        sort_by_source_size=(epoch <= args.curriculum))
    loss_meter = AverageMeter()
    bsz_meter = AverageMeter()    # sentences per batch
    wpb_meter = AverageMeter()    # words per batch
@@ -132,19 +150,17 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions, num_gpus):
    clip_meter = AverageMeter()   # % of updates clipped
    extra_meters = collections.defaultdict(lambda: AverageMeter())

-    desc = '| epoch {:03d}'.format(epoch)
-    trainer.set_seed(args.seed + epoch)
    lr = trainer.get_lr()
-    with progress_bar(itr, desc, leave=False) as t:
+    with utils.build_progress_bar(args, itr, epoch) as t:
        for i, sample in data.skip_group_enumerator(t, num_gpus, batch_offset):
            loss_dict = trainer.train_step(sample)
            loss = loss_dict['loss']
            del loss_dict['loss']  # don't include in extra_meters or extra_postfix

            ntokens = sum(s['ntokens'] for s in sample)
-            src_size = sum(s['src_tokens'].size(0) for s in sample)
-            loss_meter.update(loss, ntokens)
-            bsz_meter.update(src_size)
+            nsentences = sum(s['src_tokens'].size(0) for s in sample)
+            loss_meter.update(loss, nsentences if args.sentence_avg else ntokens)
+            bsz_meter.update(nsentences)
            wpb_meter.update(ntokens)
            wps_meter.update(ntokens)
            clip_meter.update(1 if loss_dict['gnorm'] > args.clip_norm else 0)
@@ -152,16 +168,16 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions, num_gpus):
            extra_postfix = []
            for k, v in loss_dict.items():
                extra_meters[k].update(v)
-                extra_postfix.append((k, '{:.4f}'.format(extra_meters[k].avg)))
+                extra_postfix.append((k, extra_meters[k].avg))

-            t.set_postfix(collections.OrderedDict([
-                ('loss', '{:.2f} ({:.2f})'.format(loss, loss_meter.avg)),
-                ('wps', '{:5d}'.format(round(wps_meter.avg))),
-                ('wpb', '{:5d}'.format(round(wpb_meter.avg))),
-                ('bsz', '{:5d}'.format(round(bsz_meter.avg))),
+            t.log(collections.OrderedDict([
+                ('loss', loss_meter),
+                ('wps', round(wps_meter.avg)),
+                ('wpb', round(wpb_meter.avg)),
+                ('bsz', round(bsz_meter.avg)),
                ('lr', lr),
-                ('clip', '{:3.0f}%'.format(clip_meter.avg * 100)),
-            ] + extra_postfix), refresh=False)
+                ('clip', '{:.0%}'.format(clip_meter.avg)),
+            ] + extra_postfix))

            if i == 0:
                # ignore the first mini-batch in words-per-second calculation
@@ -169,17 +185,19 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions, num_gpus):
            if args.save_interval > 0 and (i + 1) % args.save_interval == 0:
                save_checkpoint(trainer, args, epoch, i + 1)

-        fmt = desc + ' | train loss {:2.2f} | train ppl {:3.2f}'.format(
-            loss_meter.avg, get_perplexity(loss_meter.avg))
-        fmt += ' | s/checkpoint {:7d} | words/s {:6d} | words/batch {:6d}'.format(
-            round(wps_meter.elapsed_time), round(wps_meter.avg), round(wpb_meter.avg))
-        fmt += ' | bsz {:5d} | lr {:0.6f} | clip {:3.0f}%'.format(
-            round(bsz_meter.avg), lr, clip_meter.avg * 100)
-        fmt += ''.join(
-            ' | {} {:.4f}'.format(k, meter.avg)
-            for k, meter in extra_meters.items()
-        )
-        t.write(fmt)
+        t.print(collections.OrderedDict([
+            ('train loss', round(loss_meter.avg, 2)),
+            ('train ppl', get_perplexity(loss_meter.avg)),
+            ('s/checkpoint', round(wps_meter.elapsed_time)),
+            ('words/s', round(wps_meter.avg)),
+            ('words/batch', round(wpb_meter.avg)),
+            ('bsz', round(bsz_meter.avg)),
+            ('lr', lr),
+            ('clip', '{:3.0f}%'.format(clip_meter.avg * 100)),
+        ] + [
+            (k, meter.avg)
+            for k, meter in extra_meters.items()
+        ]))

def save_checkpoint(trainer, args, epoch, batch_offset, val_loss):
@@ -204,18 +222,20 @@ def save_checkpoint(trainer, args, epoch, batch_offset, val_loss):
    trainer.save_checkpoint(last_filename, extra_state)

-def validate(args, epoch, trainer, dataset, subset, ngpus):
+def validate(args, epoch, trainer, dataset, max_positions, subset, ngpus):
    """Evaluate the model on the validation set and return the average loss."""
-    itr = dataset.dataloader(subset, batch_size=None,
-                             max_tokens=args.max_tokens,
-                             max_positions=args.max_positions,
-                             skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
+    itr = dataset.eval_dataloader(
+        subset, max_tokens=args.max_tokens, max_sentences=args.max_sentences,
+        max_positions=max_positions,
+        skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
+        descending=True,  # largest batch first to warm the caching allocator
+    )
    loss_meter = AverageMeter()
    extra_meters = collections.defaultdict(lambda: AverageMeter())

-    desc = '| epoch {:03d} | valid on \'{}\' subset'.format(epoch, subset)
-    with progress_bar(itr, desc, leave=False) as t:
+    prefix = 'valid on \'{}\' subset'.format(subset)
+    with utils.build_progress_bar(args, itr, epoch, prefix) as t:
        for _, sample in data.skip_group_enumerator(t, ngpus):
            loss_dict = trainer.valid_step(sample)
            loss = loss_dict['loss']
@@ -227,23 +247,22 @@ def validate(args, epoch, trainer, dataset, max_positions, subset, ngpus):
            extra_postfix = []
            for k, v in loss_dict.items():
                extra_meters[k].update(v)
-                extra_postfix.append((k, '{:.4f}'.format(extra_meters[k].avg)))
+                extra_postfix.append((k, extra_meters[k].avg))

-            t.set_postfix(collections.OrderedDict([
-                ('loss', '{:.2f}'.format(loss_meter.avg)),
-            ] + extra_postfix), refresh=False)
+            t.log(collections.OrderedDict([
+                ('valid loss', round(loss_meter.avg, 2)),
+            ] + extra_postfix))

-        val_loss = loss_meter.avg
-        fmt = desc + ' | valid loss {:2.2f} | valid ppl {:3.2f}'.format(
-            val_loss, get_perplexity(val_loss))
-        fmt += ''.join(
-            ' | {} {:.4f}'.format(k, meter.avg)
-            for k, meter in extra_meters.items()
-        )
-        t.write(fmt)
+        t.print(collections.OrderedDict([
+            ('valid loss', round(loss_meter.avg, 2)),
+            ('valid ppl', get_perplexity(loss_meter.avg)),
+        ] + [
+            (k, meter.avg)
+            for k, meter in extra_meters.items()
+        ]))

    # update and return the learning rate
-    return val_loss
+    return loss_meter.avg

if __name__ == '__main__':
......
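A toy illustration (not fairseq code) of the normalization choice that `--sentence-avg` toggles in the training loop above: the same summed loss is averaged either over tokens (the default) or over sentences in the batch. The numbers below are invented.

```python
summed_loss = 240.0
ntokens, nsentences = 120, 8

loss_per_token = summed_loss / ntokens        # 2.0   (default: normalize by tokens)
loss_per_sentence = summed_loss / nsentences  # 30.0  (with --sentence-avg)
```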