Unverified Commit cbaf59d4 authored by Sergey Edunov, committed by GitHub

Merge pull request #116 from facebookresearch/oss-merge-internal

Oss merge internal
parents 56f9ec3c b03b53b4
......@@ -7,7 +7,6 @@
import math
import torch
from torch.autograd import Variable
import torch.nn.functional as F
from fairseq import utils
......@@ -15,41 +14,6 @@ from fairseq import utils
from . import FairseqCriterion, register_criterion
class LabelSmoothedNLLLoss(torch.autograd.Function):
@staticmethod
def forward(ctx, input, target, eps, padding_idx, weights, reduce=True):
grad_input = input.new(input.size()).zero_()
target = target.view(target.size(0), 1)
grad_input = grad_input.scatter_(grad_input.dim() - 1, target, eps - 1)
norm = grad_input.size(-1)
if weights is not None:
if isinstance(grad_input, Variable) and not isinstance(weights, Variable):
weights = Variable(weights, requires_grad=False)
norm = weights.sum()
grad_input.mul(weights.view(1, weights.size(0)).expand_as(grad_input))
if padding_idx is not None:
norm -= 1 if weights is None else weights[padding_idx]
grad_input.select(grad_input.dim() - 1, padding_idx).fill_(0)
grad_input = grad_input.add(-eps / norm)
ctx.grad_input = grad_input
if reduce:
return input.new([grad_input.view(-1).dot(input.view(-1))])
else:
return grad_input * input
@staticmethod
def backward(ctx, grad):
grad_input = ctx.grad_input
if not isinstance(grad_input, torch.autograd.Variable):
grad_input = utils.volatile_variable(grad_input)
return grad_input * grad, None, None, None, None, None
@register_criterion('label_smoothed_cross_entropy')
class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
......@@ -73,10 +37,16 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
"""
net_output = model(**sample['net_input'])
lprobs = model.get_normalized_probs(net_output, log_probs=True)
lprobs = lprobs.view(-1, lprobs.size(-1))
target = sample['target'].view(-1)
loss = LabelSmoothedNLLLoss.apply(lprobs, target, self.eps, self.padding_idx, None, reduce)
nll_loss = F.nll_loss(lprobs, target, size_average=False, ignore_index=self.padding_idx, reduce=reduce)
target = sample['target'].unsqueeze(-1)
non_pad_mask = target.ne(self.padding_idx)
nll_loss = -lprobs.gather(dim=-1, index=target)[non_pad_mask]
smooth_loss = -lprobs.sum(dim=-1, keepdim=True)[non_pad_mask]
if reduce:
nll_loss = nll_loss.sum()
smooth_loss = smooth_loss.sum()
eps_i = self.eps / lprobs.size(-1)
loss = (1. - self.eps) * nll_loss + eps_i * smooth_loss
sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
logging_output = {
'loss': utils.item(loss.data) if reduce else loss.data,
......
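For reference, a minimal standalone sketch of the smoothed loss computed by the new criterion above, using toy tensors (the shapes, eps value, and padding index here are illustrative, not taken from the diff): with smoothing eps and vocabulary size V, each non-padding token contributes (1 - eps) * nll + (eps / V) * smooth, where nll is the negative log-probability of the target token and smooth is the negative sum of log-probabilities over the whole vocabulary. The nll term is also reported separately (the tests further down read an 'nll_loss' entry from the logging output).

import torch
import torch.nn.functional as F

eps, padding_idx = 0.1, 1                              # illustrative values
lprobs = F.log_softmax(torch.randn(3, 5), dim=-1)      # (batch*time) x vocab toy log-probs
target = torch.LongTensor([[3], [4], [padding_idx]])   # last position is padding
non_pad_mask = target.ne(padding_idx)

nll_loss = -lprobs.gather(dim=-1, index=target)[non_pad_mask]
smooth_loss = -lprobs.sum(dim=-1, keepdim=True)[non_pad_mask]
eps_i = eps / lprobs.size(-1)
loss = (1. - eps) * nll_loss.sum() + eps_i * smooth_loss.sum()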
......@@ -57,17 +57,17 @@ def load_dataset(path, load_splits, src=None, dst=None):
dataset = LanguageDatasets(src, dst, src_dict, dst_dict)
# Load dataset from binary files
def all_splits_exist(src, dst):
def all_splits_exist(src, dst, lang):
for split in load_splits:
filename = '{0}.{1}-{2}.{1}.idx'.format(split, src, dst)
filename = '{0}.{1}-{2}.{3}.idx'.format(split, src, dst, lang)
if not os.path.exists(os.path.join(path, filename)):
return False
return True
# infer langcode
if all_splits_exist(src, dst):
if all_splits_exist(src, dst, src):
langcode = '{}-{}'.format(src, dst)
elif all_splits_exist(dst, src):
elif all_splits_exist(dst, src, src):
langcode = '{}-{}'.format(dst, src)
else:
raise Exception('Dataset cannot be loaded from path: ' + path)
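As a worked example of the filename fix above (the language codes en/de are hypothetical): the old format string reused placeholder {1} for the file's language suffix, so it could only ever name the first language's side, while the new one takes the suffix from the explicit lang argument.

# old pattern: the suffix repeats the first language code
'{0}.{1}-{2}.{1}.idx'.format('train', 'en', 'de')        # -> 'train.en-de.en.idx'
# new pattern: the suffix comes from the lang argument
'{0}.{1}-{2}.{3}.idx'.format('train', 'en', 'de', 'de')  # -> 'train.en-de.de.idx'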
......@@ -84,9 +84,13 @@ def load_dataset(path, load_splits, src=None, dst=None):
if not IndexedInMemoryDataset.exists(src_path):
break
target_dataset = None
if IndexedInMemoryDataset.exists(dst_path):
target_dataset = IndexedInMemoryDataset(dst_path)
dataset.splits[prefix] = LanguagePairDataset(
IndexedInMemoryDataset(src_path),
IndexedInMemoryDataset(dst_path),
target_dataset,
pad_idx=dataset.src_dict.pad(),
eos_idx=dataset.src_dict.eos(),
)
......@@ -194,21 +198,20 @@ class LanguagePairDataset(torch.utils.data.Dataset):
def __getitem__(self, i):
# subtract 1 for 0-based indexing
source = self.src[i].long() - 1
target = self.dst[i].long() - 1
return {
'id': i,
'source': source,
'target': target,
}
res = { 'id': i, 'source': source }
if self.dst:
res['target'] = self.dst[i].long() - 1
return res
def __len__(self):
return len(self.src)
def collater(self, samples):
return LanguagePairDataset.collate(samples, self.pad_idx, self.eos_idx)
return LanguagePairDataset.collate(samples, self.pad_idx, self.eos_idx, self.dst is not None)
@staticmethod
def collate(samples, pad_idx, eos_idx):
def collate(samples, pad_idx, eos_idx, has_target=True):
if len(samples) == 0:
return {}
......@@ -220,26 +223,31 @@ class LanguagePairDataset(torch.utils.data.Dataset):
id = torch.LongTensor([s['id'] for s in samples])
src_tokens = merge('source', left_pad=LanguagePairDataset.LEFT_PAD_SOURCE)
target = merge('target', left_pad=LanguagePairDataset.LEFT_PAD_TARGET)
# we create a shifted version of targets for feeding the
# previous output token(s) into the next decoder step
prev_output_tokens = merge(
'target',
left_pad=LanguagePairDataset.LEFT_PAD_TARGET,
move_eos_to_beginning=True,
)
# sort by descending source length
src_lengths = torch.LongTensor([s['source'].numel() for s in samples])
src_lengths, sort_order = src_lengths.sort(descending=True)
id = id.index_select(0, sort_order)
src_tokens = src_tokens.index_select(0, sort_order)
prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
target = target.index_select(0, sort_order)
prev_output_tokens = None
target = None
ntokens = None
if has_target:
target = merge('target', left_pad=LanguagePairDataset.LEFT_PAD_TARGET)
# we create a shifted version of targets for feeding the
# previous output token(s) into the next decoder step
prev_output_tokens = merge(
'target',
left_pad=LanguagePairDataset.LEFT_PAD_TARGET,
move_eos_to_beginning=True,
)
prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
target = target.index_select(0, sort_order)
ntokens = sum(len(s['target']) for s in samples)
return {
'id': id,
'ntokens': sum(len(s['target']) for s in samples),
'ntokens': ntokens,
'net_input': {
'src_tokens': src_tokens,
'src_lengths': src_lengths,
......@@ -301,21 +309,23 @@ def _make_batches(src, dst, indices, max_tokens, max_sentences, max_positions,
sample_len = 0
ignored = []
for idx in map(int, indices):
if not _valid_size(src.sizes[idx], dst.sizes[idx], max_positions):
src_size = src.sizes[idx]
dst_size = dst.sizes[idx] if dst else src_size
if not _valid_size(src_size, dst_size, max_positions):
if ignore_invalid_inputs:
ignored.append(idx)
continue
raise Exception((
"Sample #{} has size (src={}, dst={}) but max size is {}."
" Skip this example with --skip-invalid-size-inputs-valid-test"
).format(idx, src.sizes[idx], dst.sizes[idx], max_positions))
).format(idx, src_size, dst_size, max_positions))
sample_len = max(sample_len, src.sizes[idx], dst.sizes[idx])
sample_len = max(sample_len, src_size, dst_size)
num_tokens = (len(batch) + 1) * sample_len
if yield_batch(idx, num_tokens):
yield batch
batch = []
sample_len = max(src.sizes[idx], dst.sizes[idx])
sample_len = max(src_size, dst_size)
batch.append(idx)
......@@ -332,7 +342,7 @@ def batches_by_size(src, dst, max_tokens=None, max_sentences=None,
descending=False):
"""Returns batches of indices sorted by size. Sequences with different
source lengths are not allowed in the same batch."""
assert isinstance(src, IndexedDataset) and isinstance(dst, IndexedDataset)
assert isinstance(src, IndexedDataset) and (dst is None or isinstance(dst, IndexedDataset))
if max_tokens is None:
max_tokens = float('Inf')
if max_sentences is None:
......
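To make the shifted-targets comment above concrete, here is a toy sketch (pad=1, eos=2, and the shift_right helper are illustrative, not fairseq's merge function): collate builds prev_output_tokens by rotating each target one position to the right so that eos becomes the first token fed to the decoder, and when the dataset has no target side it now simply returns None for target, prev_output_tokens, and ntokens.

import torch

pad, eos = 1, 2                                   # illustrative indices

def shift_right(tokens, lengths):
    # emulates merge(..., move_eos_to_beginning=True) for a right-padded batch:
    # the eos that ends each target becomes the first decoder input token
    prev = tokens.clone().fill_(pad)
    prev[:, 0] = eos
    for i, n in enumerate(lengths):
        prev[i, 1:n] = tokens[i, :n - 1]
    return prev

target = torch.LongTensor([[4, 5, 6, eos],
                           [4, 5, eos, pad]])
prev_output_tokens = shift_right(target, [4, 3])
# tensor([[2, 4, 5, 6],
#         [2, 4, 5, 1]])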
......@@ -77,11 +77,11 @@ class SequenceGenerator(object):
prefix_tokens=s['target'][:, :prefix_size] if prefix_size > 0 else None,
)
if timer is not None:
timer.stop(s['ntokens'])
timer.stop(sum([len(h[0]['tokens']) for h in hypos]))
for i, id in enumerate(s['id'].data):
src = input['src_tokens'].data[i, :]
# remove padding from ref
ref = utils.strip_pad(s['target'].data[i, :], self.pad)
ref = utils.strip_pad(s['target'].data[i, :], self.pad) if s['target'] is not None else None
yield id, src, ref, hypos[i]
def generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):
......
......@@ -46,7 +46,7 @@ class SequenceScorer(object):
'alignment': alignment,
'positional_scores': pos_scores_i,
}]
# return results in the same format as SequenceGenenerator
# return results in the same format as SequenceGenerator
yield id, src, ref, hypos
def score(self, sample):
......
......@@ -84,6 +84,7 @@ def main(args):
# Generate and compute BLEU score
scorer = bleu.Scorer(dataset.dst_dict.pad(), dataset.dst_dict.eos(), dataset.dst_dict.unk())
num_sentences = 0
has_target = True
with progress_bar.build_progress_bar(args, itr) as t:
if args.score_reference:
translations = translator.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
......@@ -94,18 +95,22 @@ def main(args):
wps_meter = TimeMeter()
for sample_id, src_tokens, target_tokens, hypos in translations:
# Process input and ground truth
target_tokens = target_tokens.int().cpu()
has_target = target_tokens is not None
target_tokens = target_tokens.int().cpu() if has_target else None
# Either retrieve the original sentences or regenerate them from tokens.
if align_dict is not None:
src_str = dataset.splits[args.gen_subset].src.get_original_text(sample_id)
target_str = dataset.splits[args.gen_subset].dst.get_original_text(sample_id)
else:
src_str = dataset.src_dict.string(src_tokens, args.remove_bpe)
target_str = dataset.dst_dict.string(target_tokens, args.remove_bpe, escape_unk=True)
target_str = dataset.dst_dict.string(target_tokens,
args.remove_bpe,
escape_unk=True) if has_target else ''
if not args.quiet:
print('S-{}\t{}'.format(sample_id, src_str))
print('T-{}\t{}'.format(sample_id, target_str))
if has_target:
print('T-{}\t{}'.format(sample_id, target_str))
# Process top predictions
for i, hypo in enumerate(hypos[:min(len(hypos), args.nbest)]):
......@@ -133,7 +138,7 @@ def main(args):
))
# Score only the top hypothesis
if i == 0:
if has_target and i == 0:
if align_dict is not None or args.remove_bpe is not None:
# Convert back to tokens for evaluation with unk replacement and/or without BPE
target_tokens = tokenizer.Tokenizer.tokenize(
......@@ -146,7 +151,8 @@ def main(args):
print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} tokens/s)'.format(
num_sentences, gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))
if has_target:
print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))
if __name__ == '__main__':
......
......@@ -21,9 +21,9 @@ def get_parser():
description='Data pre-processing: Create dictionary and store data in binary format')
parser.add_argument('-s', '--source-lang', default=None, metavar='SRC', help='source language')
parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET', help='target language')
parser.add_argument('--trainpref', metavar='FP', default='train', help='target language')
parser.add_argument('--validpref', metavar='FP', default='valid', help='comma separated, valid language prefixes')
parser.add_argument('--testpref', metavar='FP', default='test', help='comma separated, test language prefixes')
parser.add_argument('--trainpref', metavar='FP', default=None, help='target language')
parser.add_argument('--validpref', metavar='FP', default=None, help='comma separated, valid language prefixes')
parser.add_argument('--testpref', metavar='FP', default=None, help='comma separated, test language prefixes')
parser.add_argument('--destdir', metavar='DIR', default='data-bin', help='destination dir')
parser.add_argument('--thresholdtgt', metavar='N', default=0, type=int,
help='map words appearing less than threshold times to unknown')
......@@ -37,12 +37,14 @@ def get_parser():
parser.add_argument('--output-format', metavar='FORMAT', default='binary', choices=['binary', 'raw'],
help='output format (optional)')
parser.add_argument('--joined-dictionary', action='store_true', help='Generate joined dictionary')
parser.add_argument('--only-source', action='store_true', help='Only process the source language')
return parser
def main(args):
print(args)
os.makedirs(args.destdir, exist_ok=True)
target = not args.only_source
if args.joined_dictionary:
assert not args.srcdict, 'cannot combine --srcdict and --joined-dictionary'
......@@ -60,16 +62,20 @@ def main(args):
if args.srcdict:
src_dict = dictionary.Dictionary.load(args.srcdict)
else:
assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
src_dict = Tokenizer.build_dictionary(filename='{}.{}'.format(args.trainpref, args.source_lang))
if args.tgtdict:
tgt_dict = dictionary.Dictionary.load(args.tgtdict)
else:
tgt_dict = Tokenizer.build_dictionary(filename='{}.{}'.format(args.trainpref, args.target_lang))
if target:
if args.tgtdict:
tgt_dict = dictionary.Dictionary.load(args.tgtdict)
else:
assert args.trainpref, "--trainpref must be set if --tgtdict is not specified"
tgt_dict = Tokenizer.build_dictionary(filename='{}.{}'.format(args.trainpref, args.target_lang))
src_dict.save(os.path.join(args.destdir, 'dict.{}.txt'.format(args.source_lang)),
threshold=args.thresholdsrc, nwords=args.nwordssrc)
tgt_dict.save(os.path.join(args.destdir, 'dict.{}.txt'.format(args.target_lang)),
threshold=args.thresholdtgt, nwords=args.nwordstgt)
if target:
tgt_dict.save(os.path.join(args.destdir, 'dict.{}.txt'.format(args.target_lang)),
threshold=args.thresholdtgt, nwords=args.nwordstgt)
def make_binary_dataset(input_prefix, output_prefix, lang):
dict = dictionary.Dictionary.load(os.path.join(args.destdir, 'dict.{}.txt'.format(lang)))
......@@ -100,19 +106,26 @@ def main(args):
output_text_file = os.path.join(args.destdir, '{}.{}'.format(output_prefix, lang))
shutil.copyfile('{}.{}'.format(input_prefix, lang), output_text_file)
make_dataset(args.trainpref, 'train', args.source_lang, args.output_format)
make_dataset(args.trainpref, 'train', args.target_lang, args.output_format)
for k, validpref in enumerate(args.validpref.split(',')):
outprefix = 'valid{}'.format(k) if k > 0 else 'valid'
make_dataset(validpref, outprefix, args.source_lang, args.output_format)
make_dataset(validpref, outprefix, args.target_lang, args.output_format)
for k, testpref in enumerate(args.testpref.split(',')):
outprefix = 'test{}'.format(k) if k > 0 else 'test'
make_dataset(testpref, outprefix, args.source_lang, args.output_format)
make_dataset(testpref, outprefix, args.target_lang, args.output_format)
def make_all(args, make_dataset, lang):
if args.trainpref:
make_dataset(args.trainpref, 'train', lang, args.output_format)
if args.validpref:
for k, validpref in enumerate(args.validpref.split(',')):
outprefix = 'valid{}'.format(k) if k > 0 else 'valid'
make_dataset(validpref, outprefix, lang, args.output_format)
if args.testpref:
for k, testpref in enumerate(args.testpref.split(',')):
outprefix = 'test{}'.format(k) if k > 0 else 'test'
make_dataset(testpref, outprefix, lang, args.output_format)
make_all(args, make_dataset, args.source_lang)
if target:
make_all(args, make_dataset, args.target_lang)
print('| Wrote preprocessed data to {}'.format(args.destdir))
if args.alignfile:
assert args.trainpref, "--trainpref must be set if --alignfile is specified"
src_file_name = '{}.{}'.format(args.trainpref, args.source_lang)
tgt_file_name = '{}.{}'.format(args.trainpref, args.target_lang)
src_dict = dictionary.Dictionary.load(os.path.join(args.destdir, 'dict.{}.txt'.format(args.source_lang)))
......
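A small worked example of the output-prefix naming used by make_all above (the prefix strings are invented for illustration): every comma separated entry in --validpref or --testpref after the first gets a numeric suffix.

validpref = 'valid-newstest2013,valid-newstest2014'   # hypothetical --validpref value
for k, pref in enumerate(validpref.split(',')):
    outprefix = 'valid{}'.format(k) if k > 0 else 'valid'
    print(outprefix, '<-', pref)
# valid <- valid-newstest2013
# valid1 <- valid-newstest2014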
......@@ -4,31 +4,98 @@
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import torch
import argparse
import copy
import unittest
from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedNLLLoss
from torch.autograd import Variable, gradcheck
import torch
from torch.autograd import Variable
from fairseq import utils
from fairseq.criterions.cross_entropy import CrossEntropyCriterion
from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion
torch.set_default_tensor_type('torch.DoubleTensor')
import tests.utils as test_utils
class TestLabelSmoothing(unittest.TestCase):
def test_label_smoothing(self):
input = Variable(torch.randn(3, 5), requires_grad=True)
idx = torch.rand(3) * 4
target = Variable(idx.long())
criterion = LabelSmoothedNLLLoss()
self.assertTrue(gradcheck(
lambda x, y: criterion.apply(x, y, 0.1, 2, None), (input, target)
))
weights = torch.ones(5)
weights[2] = 0
self.assertTrue(gradcheck(lambda x, y: criterion.apply(x, y, 0.1, None, weights), (input, target)))
self.assertTrue(gradcheck(lambda x, y: criterion.apply(x, y, 0.1, None, None), (input, target)))
def setUp(self):
# build dictionary
self.d = test_utils.dummy_dictionary(3)
vocab = len(self.d)
self.assertEqual(vocab, 4 + 3) # 4 special + 3 tokens
self.assertEqual(self.d.pad(), 1)
self.assertEqual(self.d.eos(), 2)
self.assertEqual(self.d.unk(), 3)
pad, eos, unk, w1, w2, w3 = 1, 2, 3, 4, 5, 6
# build dataset
self.data = [
# the first batch item has padding
{'source': torch.LongTensor([w1, eos]), 'target': torch.LongTensor([w1, eos])},
{'source': torch.LongTensor([w1, eos]), 'target': torch.LongTensor([w1, w1, eos])},
]
self.sample = next(test_utils.dummy_dataloader(self.data))
# build model
self.args = argparse.Namespace()
self.args.sentence_avg = False
self.args.probs = torch.FloatTensor([
# pad eos unk w1 w2 w3
[0.05, 0.05, 0.1, 0.05, 0.3, 0.4, 0.05],
[0.05, 0.10, 0.2, 0.05, 0.2, 0.3, 0.10],
[0.05, 0.15, 0.3, 0.05, 0.1, 0.2, 0.15],
]).unsqueeze(0).expand(2, 3, 7) # add batch dimension
self.model = test_utils.TestModel.build_model(self.args, self.d, self.d)
def test_nll_loss(self):
self.args.label_smoothing = 0.1
nll_crit = CrossEntropyCriterion(self.args, self.d, self.d)
smooth_crit = LabelSmoothedCrossEntropyCriterion(self.args, self.d, self.d)
nll_loss, nll_sample_size, nll_logging_output = nll_crit(self.model, self.sample)
smooth_loss, smooth_sample_size, smooth_logging_output = smooth_crit(self.model, self.sample)
self.assertLess(abs(nll_loss - nll_logging_output['loss']), 1e-6)
self.assertLess(abs(nll_loss - smooth_logging_output['nll_loss']), 1e-6)
def test_padding(self):
self.args.label_smoothing = 0.1
crit = LabelSmoothedCrossEntropyCriterion(self.args, self.d, self.d)
loss, _, logging_output = crit(self.model, self.sample)
def get_one_no_padding(idx):
# create a new sample with just a single batch item so that there's
# no padding
sample1 = next(test_utils.dummy_dataloader([self.data[idx]]))
args1 = copy.copy(self.args)
args1.probs = args1.probs[idx, :, :].unsqueeze(0)
model1 = test_utils.TestModel.build_model(args1, self.d, self.d)
loss1, _, _ = crit(model1, sample1)
return loss1
loss1 = get_one_no_padding(0)
loss2 = get_one_no_padding(1)
self.assertAlmostEqual(loss, loss1 + loss2)
def test_reduction(self):
self.args.label_smoothing = 0.1
crit = LabelSmoothedCrossEntropyCriterion(self.args, self.d, self.d)
loss, _, logging_output = crit(self.model, self.sample, reduce=True)
unreduced_loss, _, _ = crit(self.model, self.sample, reduce=False)
self.assertAlmostEqual(loss, unreduced_loss.sum())
def test_zero_eps(self):
self.args.label_smoothing = 0.0
nll_crit = CrossEntropyCriterion(self.args, self.d, self.d)
smooth_crit = LabelSmoothedCrossEntropyCriterion(self.args, self.d, self.d)
nll_loss, nll_sample_size, nll_logging_output = nll_crit(self.model, self.sample)
smooth_loss, smooth_sample_size, smooth_logging_output = smooth_crit(self.model, self.sample)
self.assertAlmostEqual(nll_loss, smooth_loss)
def assertAlmostEqual(self, t1, t2):
self.assertEqual(t1.size(), t2.size(), "size mismatch")
self.assertLess((t1 - t2).abs().max(), 1e-6)
if __name__ == '__main__':
......
......@@ -92,7 +92,7 @@ class TestEncoder(FairseqEncoder):
class TestIncrementalDecoder(FairseqIncrementalDecoder):
def __init__(self, args, dictionary):
super().__init__(dictionary)
assert hasattr(args, 'beam_probs')
assert hasattr(args, 'beam_probs') or hasattr(args, 'probs')
args.max_decoder_positions = getattr(args, 'max_decoder_positions', 100)
self.args = args
......@@ -116,14 +116,19 @@ class TestIncrementalDecoder(FairseqIncrementalDecoder):
steps = list(range(tgt_len))
# define output in terms of raw probs
probs = torch.FloatTensor(bbsz, len(steps), vocab).zero_()
for i, step in enumerate(steps):
# args.beam_probs gives the probability for every vocab element,
# starting with eos, then unknown, and then the rest of the vocab
if step < len(self.args.beam_probs):
probs[:, i, self.dictionary.eos():] = self.args.beam_probs[step]
else:
probs[:, i, self.dictionary.eos()] = 1.0
if hasattr(self.args, 'probs'):
assert self.args.probs.dim() == 3, \
'expected probs to have size bsz*steps*vocab'
probs = self.args.probs.index_select(1, torch.LongTensor(steps))
else:
probs = torch.FloatTensor(bbsz, len(steps), vocab).zero_()
for i, step in enumerate(steps):
# args.beam_probs gives the probability for every vocab element,
# starting with eos, then unknown, and then the rest of the vocab
if step < len(self.args.beam_probs):
probs[:, i, self.dictionary.eos():] = self.args.beam_probs[step]
else:
probs[:, i, self.dictionary.eos()] = 1.0
# random attention
attn = torch.rand(bbsz, src_len, tgt_len)
......
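A toy illustration of the new args.probs path above (the tensor sizes are made up): probs is expected to be bsz x steps x vocab, and index_select along dimension 1 picks out the rows for the decoder steps currently being scored.

import torch

probs = torch.rand(2, 3, 7)              # toy bsz=2, steps=3, vocab=7
steps = torch.LongTensor([0, 1])         # e.g. the first two decoder steps
picked = probs.index_select(1, steps)    # size 2 x 2 x 7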