Commit d9f46c54 authored by Sergey Edunov

Merge branch 'master' of github.com:facebookresearch/fairseq-py into prepare_wmt

parents 4185d3ed ee36a6f3
@@ -173,4 +173,8 @@ def add_model_args(parser):
                        help='dropout probability')
     group.add_argument('--label-smoothing', default=0, type=float, metavar='D',
                        help='epsilon for label smoothing, 0 means no label smoothing')
+    group.add_argument('--share-input-output-embed', action='store_true',
+                       help='share input and output embeddings (requires '
+                            '--decoder-out-embed-dim and --decoder-embed-dim to be equal)')
     return group
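The new --share-input-output-embed flag ties the decoder's input embedding matrix to its output projection, which is why the two dimensions must be equal. A minimal sketch of the tying in generic PyTorch (a hypothetical module, not the fairseq decoder itself):

import torch.nn as nn

class TiedOutputDecoder(nn.Module):
    # Input embedding and output projection share one weight matrix,
    # so embed_dim must equal out_embed_dim -- hence the constraint
    # stated in the --share-input-output-embed help text.
    def __init__(self, num_embeddings, embed_dim):
        super().__init__()
        self.embed_tokens = nn.Embedding(num_embeddings, embed_dim)
        self.fc_out = nn.Linear(embed_dim, num_embeddings, bias=False)
        self.fc_out.weight = self.embed_tokens.weight  # the actual tying

    def forward(self, hidden):
        # hidden: (batch, embed_dim) -> logits: (batch, num_embeddings)
        return self.fc_out(hidden)

Tying removes one vocabulary-sized parameter matrix and, in the literature on weight tying, often helps rather than hurts quality.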
@@ -19,7 +19,7 @@ from fairseq.models import FairseqIncrementalDecoder
 class SequenceGenerator(object):
     def __init__(self, models, beam_size=1, minlen=1, maxlen=200,
                  stop_early=True, normalize_scores=True, len_penalty=1,
-                 unk_penalty=0):
+                 unk_penalty=0, retain_dropout=False):
         """Generates translations of a given source sentence.
         Args:
@@ -45,6 +45,7 @@ class SequenceGenerator(object):
         self.normalize_scores = normalize_scores
         self.len_penalty = len_penalty
         self.unk_penalty = unk_penalty
+        self.retain_dropout = retain_dropout
     def cuda(self):
         for model in self.models:
@@ -65,19 +66,20 @@ class SequenceGenerator(object):
         maxlen_b = self.maxlen
         for sample in data_itr:
-            s = utils.prepare_sample(sample, volatile=True, cuda_device=cuda_device)
+            s = utils.make_variable(sample, volatile=True, cuda_device=cuda_device)
             input = s['net_input']
             srclen = input['src_tokens'].size(1)
             if timer is not None:
                 timer.start()
-            hypos = self.generate(input['src_tokens'], beam_size=beam_size,
-                                  maxlen=int(maxlen_a*srclen + maxlen_b))
+            with utils.maybe_no_grad():
+                hypos = self.generate(input['src_tokens'], beam_size=beam_size,
+                                      maxlen=int(maxlen_a*srclen + maxlen_b))
             if timer is not None:
                 timer.stop(s['ntokens'])
-            for i, id in enumerate(s['id']):
+            for i, id in enumerate(s['id'].data):
                 src = input['src_tokens'].data[i, :]
                 # remove padding from ref
-                ref = utils.rstrip_pad(s['target'].data[i, :], self.pad)
+                ref = utils.strip_pad(s['target'].data[i, :], self.pad)
                 yield id, src, ref, hypos[i]
     def generate(self, src_tokens, beam_size=None, maxlen=None):
@@ -98,7 +100,8 @@ class SequenceGenerator(object):
         encoder_outs = []
         for model in self.models:
-            model.eval()
+            if not self.retain_dropout:
+                model.eval()
             if isinstance(model.decoder, FairseqIncrementalDecoder):
                 model.decoder.set_beam_size(beam_size)
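retain_dropout skips the model.eval() call, so dropout stays stochastic during generation (useful e.g. for sampling diverse hypotheses from the same source). The difference in isolation, in plain PyTorch (modern API shown; the 0.3-era code would wrap x in a Variable):

import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
x = torch.ones(6)

drop.train()    # what retain_dropout=True preserves at inference
print(drop(x))  # stochastic: roughly half the entries zeroed, rest scaled by 2

drop.eval()     # what model.eval() normally switches on
print(drop(x))  # deterministic identity: all ones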
@@ -269,7 +272,7 @@ class SequenceGenerator(object):
             # and values < cand_size indicate candidate active hypos.
             # After, the min values per row are the top candidate active hypos
             active_mask = buffer('active_mask')
-            torch.add((eos_mask*cand_size).type_as(cand_offsets), cand_offsets[:eos_mask.size(1)],
+            torch.add(eos_mask.type_as(cand_offsets)*cand_size, cand_offsets[:eos_mask.size(1)],
                       out=active_mask)
             # get the top beam_size active hypotheses, which are just the hypos
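The one-line change above reorders a cast: eos_mask is a byte tensor, so multiplying by cand_size before casting wraps around at 256 once 2 * beam_size exceeds 255; casting to the offsets' dtype first keeps the values intact. A small illustration with assumed shapes (not the real buffers):

import torch

eos_mask = torch.ByteTensor([1, 0, 1])   # 1 where a candidate just produced EOS
cand_size = 512                          # i.e. 2 * beam_size
cand_offsets = torch.arange(0, 3).long()

# Fixed order: widen first, then multiply. Finished candidates get an
# offset >= cand_size, so topk(largest=False) picks active hypos first.
active_mask = eos_mask.long() * cand_size + cand_offsets
print(active_mask)  # 512, 1, 514
# The old order, (eos_mask * cand_size).type_as(...), multiplied in uint8,
# where 512 wraps to 0 and finished candidates looked maximally active.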
@@ -320,22 +323,27 @@ class SequenceGenerator(object):
     def _decode(self, tokens, encoder_outs):
         # wrap in Variable
-        tokens = Variable(tokens, volatile=True)
+        tokens = utils.volatile_variable(tokens)
         avg_probs = None
         avg_attn = None
         for model, encoder_out in zip(self.models, encoder_outs):
-            decoder_out, attn = model.decoder(tokens, encoder_out)
-            probs = F.softmax(decoder_out[:, -1, :]).data
-            attn = attn[:, -1, :].data
-            if avg_probs is None or avg_attn is None:
+            with utils.maybe_no_grad():
+                decoder_out, attn = model.decoder(tokens, encoder_out)
+            probs = model.get_normalized_probs(decoder_out[:, -1, :], log_probs=False).data
+            if avg_probs is None:
                 avg_probs = probs
-                avg_attn = attn
             else:
                 avg_probs.add_(probs)
-                avg_attn.add_(attn)
+            if attn is not None:
+                attn = attn[:, -1, :].data
+                if avg_attn is None:
+                    avg_attn = attn
+                else:
+                    avg_attn.add_(attn)
         avg_probs.div_(len(self.models))
         avg_probs.log_()
-        avg_attn.div_(len(self.models))
+        if avg_attn is not None:
+            avg_attn.div_(len(self.models))
         return avg_probs, avg_attn
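The reworked loop averages the models' probabilities and defers the log to a single call at the end, and attention is now averaged only over models that actually return it (decoders without attention yield attn=None). The ensemble rule on a toy distribution:

import torch

# Arithmetic mean of the models' next-token distributions, then one log.
model_probs = [
    torch.Tensor([0.7, 0.2, 0.1]),   # model 1
    torch.Tensor([0.5, 0.4, 0.1]),   # model 2
]
avg_probs = None
for probs in model_probs:
    if avg_probs is None:
        avg_probs = probs.clone()
    else:
        avg_probs.add_(probs)
avg_probs.div_(len(model_probs)).log_()
print(avg_probs)  # log of [0.6, 0.3, 0.1]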
@@ -6,6 +6,7 @@
 # can be found in the PATENTS file in the same directory.
 #
+import contextlib
 import logging
 import os
 import torch
@@ -15,10 +16,11 @@ import sys
 from torch.autograd import Variable
 from torch.serialization import default_restore_location
-from fairseq import criterions, data, models, progress_bar, tokenizer
+from fairseq import criterions, progress_bar, tokenizer
 def parse_args_and_arch(parser):
+    from fairseq import models
     args = parser.parse_args()
     args.model = models.arch_model_map[args.arch]
     args = getattr(models, args.model).parse_arch(args)
@@ -26,6 +28,7 @@ def parse_args_and_arch(parser):
 def build_model(args, src_dict, dst_dict):
+    from fairseq import models
     assert hasattr(models, args.model), 'Missing model type'
     return getattr(models, args.model).build_model(args, src_dict, dst_dict)
@@ -143,6 +146,8 @@ def load_ensemble_for_inference(filenames, src_dict=None, dst_dict=None, data_di
     The source and target dictionaries can be given explicitly, or loaded from
     the `data_dir` directory.
     """
+    from fairseq import data
     # load model architectures and weights
     states = []
     for filename in filenames:
@@ -172,26 +177,48 @@ def _upgrade_args(args):
     if not hasattr(args, 'max_source_positions'):
         args.max_source_positions = args.max_positions
         args.max_target_positions = args.max_positions
+    if not hasattr(args, 'share_input_output_embed'):
+        args.share_input_output_embed = False
     return args
-def prepare_sample(sample, volatile=False, cuda_device=None):
+def maybe_no_grad(condition=True):
+    if hasattr(torch, 'no_grad') and condition:
+        return torch.no_grad()
+    # no-op context manager
+    return contextlib.ExitStack()
+
+
+def volatile_variable(*args, **kwargs):
+    if hasattr(torch, 'no_grad'):
+        # volatile has been deprecated, use the no_grad context manager instead
+        return Variable(*args, **kwargs)
+    else:
+        return Variable(*args, **kwargs, volatile=True)
+
+
+def make_variable(sample, volatile=False, cuda_device=None):
     """Wrap input tensors in Variable class."""
-    def make_variable(tensor):
-        if cuda_device is not None and torch.cuda.is_available():
-            tensor = tensor.cuda(async=True, device=cuda_device)
-        return Variable(tensor, volatile=volatile)
-
-    return {
-        'id': sample['id'],
-        'ntokens': sample['ntokens'],
-        'target': make_variable(sample['target']),
-        'net_input': {
-            key: make_variable(sample[key])
-            for key in ['src_tokens', 'input_tokens']
-        },
-    }
+    def _make_variable(maybe_tensor):
+        if torch.is_tensor(maybe_tensor):
+            if cuda_device is not None and torch.cuda.is_available():
+                maybe_tensor = maybe_tensor.cuda(async=True, device=cuda_device)
+            if volatile:
+                return volatile_variable(maybe_tensor)
+            else:
+                return Variable(maybe_tensor)
+        elif isinstance(maybe_tensor, dict):
+            return {
+                key: _make_variable(value)
+                for key, value in maybe_tensor.items()
+            }
+        elif isinstance(maybe_tensor, list):
+            return [_make_variable(x) for x in maybe_tensor]
+        else:
+            return maybe_tensor
+    return _make_variable(sample)

 def load_align_dict(replace_unk):
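Together these helpers give call sites one idiom that spans the PyTorch 0.4 Variable/Tensor merge: on new releases maybe_no_grad() enters torch.no_grad() and volatile is ignored, while on old releases it is a no-op and the volatile=True Variables suppress graph construction instead. A usage sketch (the sample dict here is hypothetical):

import torch
from fairseq import utils

sample = {
    'id': torch.LongTensor([0]),
    'ntokens': 3,
    'net_input': {'src_tokens': torch.LongTensor([[4, 5, 6]])},
}
# make_variable recurses into dicts/lists and leaves plain ints alone
s = utils.make_variable(sample, volatile=True)
with utils.maybe_no_grad():
    pass  # forward passes here build no autograd graph on PyTorch >= 0.4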
@@ -236,11 +263,19 @@ def post_process_prediction(hypo_tokens, src_str, alignment, align_dict, dst_dic
 def lstrip_pad(tensor, pad):
-    return tensor[tensor.eq(pad).sum():]
+    return tensor[tensor.eq(pad).long().sum():]

 def rstrip_pad(tensor, pad):
-    strip = tensor.eq(pad).sum()
+    strip = tensor.eq(pad).long().sum()
     if strip > 0:
         return tensor[:-strip]
     return tensor
+
+def strip_pad(tensor, pad):
+    if tensor[0] == pad:
+        tensor = lstrip_pad(tensor, pad)
+    if tensor[-1] == pad:
+        tensor = rstrip_pad(tensor, pad)
+    return tensor
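Note that lstrip_pad and rstrip_pad count every pad in the tensor via eq(pad).sum(), so they assume padding forms one contiguous run at a single end; strip_pad just dispatches on which end that is. The added .long() casts presumably guard the ByteTensor sum against wrapping at 256 pads. Usage on the two padding conventions (pad index 1, fairseq's usual choice):

import torch
from fairseq import utils

pad = 1
right_padded = torch.LongTensor([5, 9, 2, 1, 1])   # e.g. a target row
left_padded = torch.LongTensor([1, 1, 5, 9, 2])    # e.g. a source row
print(utils.strip_pad(right_padded, pad))  # [5, 9, 2]
print(utils.strip_pad(left_padded, pad))   # [5, 9, 2]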
@@ -23,6 +23,10 @@ def main():
                               help='batch size')
     dataset_args.add_argument('--gen-subset', default='test', metavar='SPLIT',
                               help='data subset to generate (train, valid, test)')
+    dataset_args.add_argument('--num-shards', default=1, type=int, metavar='N',
+                              help='shard generation over N shards')
+    dataset_args.add_argument('--shard-id', default=0, type=int, metavar='ID',
+                              help='id of the shard to generate (id < num_shards)')
     options.add_generation_args(parser)
     args = parser.parse_args()
@@ -31,6 +35,8 @@ def main():
     print(args)
     use_cuda = torch.cuda.is_available() and not args.cpu
+    if hasattr(torch, 'set_grad_enabled'):
+        torch.set_grad_enabled(False)
     # Load dataset
     if args.replace_unk is None:
@@ -72,6 +78,10 @@ def main():
     itr = dataset.eval_dataloader(
         args.gen_subset, max_sentences=args.batch_size, max_positions=max_positions,
         skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
+    if args.num_shards > 1:
+        if args.shard_id < 0 or args.shard_id >= args.num_shards:
+            raise ValueError('--shard-id must be between 0 and num_shards')
+        itr = data.sharded_iterator(itr, args.num_shards, args.shard_id)
     num_sentences = 0
     with utils.build_progress_bar(args, itr) as t:
         wps_meter = TimeMeter()
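data.sharded_iterator itself is outside this diff; a minimal round-robin implementation consistent with the flags above (an assumption, not the committed code) is:

def sharded_iterator(itr, num_shards, shard_id):
    # Worker `shard_id` keeps every num_shards-th batch, so N generate.py
    # processes cover the dataset exactly once with no coordination.
    for i, v in enumerate(itr):
        if i % num_shards == shard_id:
            yield v

Launching e.g. eight generate.py processes with --num-shards 8 and --shard-id 0 through 7 then partitions the test set across workers.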
@@ -54,7 +54,7 @@ class build_py_hook(build_py):
 setup(
     name='fairseq',
-    version='0.2.0',
+    version='0.3.0',
     description='Facebook AI Research Sequence-to-Sequence Toolkit',
     long_description=readme,
     license=license,
@@ -30,6 +30,8 @@ def main():
     dataset_args.add_argument('--valid-subset', default='valid', metavar='SPLIT',
                               help='comma separated list of data subsets '
                                    'to use for validation (train, valid, valid1, test, test1)')
+    dataset_args.add_argument('--max-sentences-valid', type=int, metavar='N',
+                              help='maximum number of sentences in a validation batch')
     options.add_optimization_args(parser)
     options.add_checkpoint_args(parser)
     options.add_model_args(parser)
@@ -39,6 +41,9 @@ def main():
     if args.no_progress_bar and args.log_format is None:
         args.log_format = 'simple'
+    if args.max_sentences_valid is None:
+        args.max_sentences_valid = args.max_sentences
     if not os.path.exists(args.save_dir):
         os.makedirs(args.save_dir)
     torch.manual_seed(args.seed)
@@ -70,14 +75,15 @@ def main():
     model = utils.build_model(args, dataset.src_dict, dataset.dst_dict)
     criterion = utils.build_criterion(args, dataset.src_dict, dataset.dst_dict)
     print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
+    print('| num. model params: {}'.format(sum(p.data.numel() for p in model.parameters())))
     # The max number of positions can be different for train and valid
     # e.g., RNNs may support more positions at test time than seen in training
-    max_positions_train = (args.max_source_positions, args.max_target_positions)
-    max_positions_valid = (
+    max_positions_train = (
         min(args.max_source_positions, model.max_encoder_positions()),
         min(args.max_target_positions, model.max_decoder_positions())
     )
+    max_positions_valid = (model.max_encoder_positions(), model.max_decoder_positions())
     # Start multiprocessing
     trainer = MultiprocessingTrainer(args, model, criterion)
@@ -144,6 +150,7 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions):
         sample_without_replacement=args.sample_without_replacement,
         sort_by_source_size=(epoch <= args.curriculum))
     loss_meter = AverageMeter()
+    nll_loss_meter = AverageMeter()
     bsz_meter = AverageMeter()  # sentences per batch
     wpb_meter = AverageMeter()  # words per batch
     wps_meter = TimeMeter()  # words per second
@@ -158,7 +165,12 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions):
         del loss_dict['loss']  # don't include in extra_meters or extra_postfix
         ntokens = sum(s['ntokens'] for s in sample)
-        nsentences = sum(s['src_tokens'].size(0) for s in sample)
+        if 'nll_loss' in loss_dict:
+            nll_loss = loss_dict['nll_loss']
+            nll_loss_meter.update(nll_loss, ntokens)
+        nsentences = sum(s['net_input']['src_tokens'].size(0) for s in sample)
         loss_meter.update(loss, nsentences if args.sentence_avg else ntokens)
         bsz_meter.update(nsentences)
         wpb_meter.update(ntokens)
@@ -187,7 +199,9 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions):
         t.print(collections.OrderedDict([
             ('train loss', round(loss_meter.avg, 2)),
-            ('train ppl', get_perplexity(loss_meter.avg)),
+            ('train ppl', get_perplexity(nll_loss_meter.avg
+                                         if nll_loss_meter.count > 0
+                                         else loss_meter.avg)),
             ('s/checkpoint', round(wps_meter.elapsed_time)),
             ('words/s', round(wps_meter.avg)),
             ('words/batch', round(wpb_meter.avg)),
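Label smoothing inflates the training objective, so the smoothed loss is no longer a true negative log-likelihood; when the criterion reports a separate nll_loss, perplexity is computed from that meter and falls back to the raw loss only otherwise. For reference, a get_perplexity consistent with these prints (the helper is outside this diff; a base-2 per-token loss is assumed):

import math

def get_perplexity(loss):
    # Perplexity of an average per-token negative log-likelihood in bits.
    try:
        return round(math.pow(2, loss), 2)
    except OverflowError:
        return float('inf')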
@@ -217,6 +231,10 @@ def save_checkpoint(trainer, args, epoch, batch_offset, val_loss):
         save_checkpoint.best = val_loss
         best_filename = os.path.join(args.save_dir, 'checkpoint_best.pt')
         trainer.save_checkpoint(best_filename, extra_state)
+    elif not args.no_epoch_checkpoints:
+        epoch_filename = os.path.join(
+            args.save_dir, 'checkpoint{}_{}.pt'.format(epoch, batch_offset))
+        trainer.save_checkpoint(epoch_filename, extra_state)
     last_filename = os.path.join(args.save_dir, 'checkpoint_last.pt')
     trainer.save_checkpoint(last_filename, extra_state)
@@ -226,22 +244,27 @@ def validate(args, epoch, trainer, dataset, max_positions, subset):
     """Evaluate the model on the validation set and return the average loss."""
     itr = dataset.eval_dataloader(
-        subset, max_tokens=args.max_tokens, max_sentences=args.max_sentences,
+        subset, max_tokens=args.max_tokens, max_sentences=args.max_sentences_valid,
         max_positions=max_positions,
         skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test,
         descending=True,  # largest batch first to warm the caching allocator
     )
     loss_meter = AverageMeter()
+    nll_loss_meter = AverageMeter()
     extra_meters = collections.defaultdict(lambda: AverageMeter())
     prefix = 'valid on \'{}\' subset'.format(subset)
     with utils.build_progress_bar(args, itr, epoch, prefix) as t:
         for _, sample in data.skip_group_enumerator(t, args.num_gpus):
             loss_dict = trainer.valid_step(sample)
+            ntokens = sum(s['ntokens'] for s in sample)
             loss = loss_dict['loss']
             del loss_dict['loss']  # don't include in extra_meters or extra_postfix
-            ntokens = sum(s['ntokens'] for s in sample)
+            if 'nll_loss' in loss_dict:
+                nll_loss = loss_dict['nll_loss']
+                nll_loss_meter.update(nll_loss, ntokens)
             loss_meter.update(loss, ntokens)
             extra_postfix = []
@@ -255,7 +278,9 @@ def validate(args, epoch, trainer, dataset, max_positions, subset):
         t.print(collections.OrderedDict([
             ('valid loss', round(loss_meter.avg, 2)),
-            ('valid ppl', get_perplexity(loss_meter.avg)),
+            ('valid ppl', get_perplexity(nll_loss_meter.avg
+                                         if nll_loss_meter.count > 0
+                                         else loss_meter.avg)),
         ] + [
             (k, meter.avg)
             for k, meter in extra_meters.items()