Unverified commit cbaf59d4, authored by Sergey Edunov, committed by GitHub

Merge pull request #116 from facebookresearch/oss-merge-internal

Oss merge internal
parents 56f9ec3c b03b53b4
@@ -7,7 +7,6 @@
 import math
 import torch
-from torch.autograd import Variable
 import torch.nn.functional as F

 from fairseq import utils
@@ -15,41 +14,6 @@ from fairseq import utils
 from . import FairseqCriterion, register_criterion


-class LabelSmoothedNLLLoss(torch.autograd.Function):
-
-    @staticmethod
-    def forward(ctx, input, target, eps, padding_idx, weights, reduce=True):
-        grad_input = input.new(input.size()).zero_()
-        target = target.view(target.size(0), 1)
-        grad_input = grad_input.scatter_(grad_input.dim() - 1, target, eps - 1)
-
-        norm = grad_input.size(-1)
-        if weights is not None:
-            if isinstance(grad_input, Variable) and not isinstance(weights, Variable):
-                weights = Variable(weights, requires_grad=False)
-            norm = weights.sum()
-            grad_input.mul(weights.view(1, weights.size(0)).expand_as(grad_input))
-
-        if padding_idx is not None:
-            norm -= 1 if weights is None else weights[padding_idx]
-            grad_input.select(grad_input.dim() - 1, padding_idx).fill_(0)
-
-        grad_input = grad_input.add(-eps / norm)
-
-        ctx.grad_input = grad_input
-        if reduce:
-            return input.new([grad_input.view(-1).dot(input.view(-1))])
-        else:
-            return grad_input * input
-
-    @staticmethod
-    def backward(ctx, grad):
-        grad_input = ctx.grad_input
-        if not isinstance(grad_input, torch.autograd.Variable):
-            grad_input = utils.volatile_variable(grad_input)
-        return grad_input * grad, None, None, None, None, None
-
-
 @register_criterion('label_smoothed_cross_entropy')
 class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
@@ -73,10 +37,16 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
         """
         net_output = model(**sample['net_input'])
         lprobs = model.get_normalized_probs(net_output, log_probs=True)
-        lprobs = lprobs.view(-1, lprobs.size(-1))
-        target = sample['target'].view(-1)
-        loss = LabelSmoothedNLLLoss.apply(lprobs, target, self.eps, self.padding_idx, None, reduce)
-        nll_loss = F.nll_loss(lprobs, target, size_average=False, ignore_index=self.padding_idx, reduce=reduce)
+        target = sample['target'].unsqueeze(-1)
+        non_pad_mask = target.ne(self.padding_idx)
+        nll_loss = -lprobs.gather(dim=-1, index=target)[non_pad_mask]
+        smooth_loss = -lprobs.sum(dim=-1, keepdim=True)[non_pad_mask]
+        if reduce:
+            nll_loss = nll_loss.sum()
+            smooth_loss = smooth_loss.sum()
+        eps_i = self.eps / lprobs.size(-1)
+        loss = (1. - self.eps) * nll_loss + eps_i * smooth_loss
         sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
         logging_output = {
             'loss': utils.item(loss.data) if reduce else loss.data,
...
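The replacement drops the custom autograd Function and computes label-smoothed cross entropy directly from the log-probabilities: the true-class NLL term and a uniform-smoothing term (the sum of all log-probs) are mixed with weights (1 - eps) and eps / vocab_size, with padded positions masked out. A minimal standalone sketch of the same arithmetic, written against current PyTorch; the tensor shapes and the padding index below are illustrative, not taken from the diff:

import torch
import torch.nn.functional as F

def label_smoothed_nll(lprobs, target, eps, padding_idx):
    # lprobs: (tokens, vocab) log-probabilities; target: (tokens,) gold indices
    target = target.unsqueeze(-1)
    non_pad_mask = target.ne(padding_idx)
    nll_loss = -lprobs.gather(dim=-1, index=target)[non_pad_mask].sum()
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True)[non_pad_mask].sum()
    eps_i = eps / lprobs.size(-1)
    return (1. - eps) * nll_loss + eps_i * smooth_loss

# sanity check: with eps=0 this reduces to ordinary NLL loss over non-pad tokens
lprobs = F.log_softmax(torch.randn(4, 7), dim=-1)
target = torch.tensor([3, 1, 5, 2])  # index 1 plays the role of padding here
expected = F.nll_loss(lprobs, target, ignore_index=1, reduction='sum')
assert torch.allclose(label_smoothed_nll(lprobs, target, 0.0, padding_idx=1), expected)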
@@ -57,17 +57,17 @@ def load_dataset(path, load_splits, src=None, dst=None):
     dataset = LanguageDatasets(src, dst, src_dict, dst_dict)

     # Load dataset from binary files
-    def all_splits_exist(src, dst):
+    def all_splits_exist(src, dst, lang):
         for split in load_splits:
-            filename = '{0}.{1}-{2}.{1}.idx'.format(split, src, dst)
+            filename = '{0}.{1}-{2}.{3}.idx'.format(split, src, dst, lang)
             if not os.path.exists(os.path.join(path, filename)):
                 return False
         return True

     # infer langcode
-    if all_splits_exist(src, dst):
+    if all_splits_exist(src, dst, src):
         langcode = '{}-{}'.format(src, dst)
-    elif all_splits_exist(dst, src):
+    elif all_splits_exist(dst, src, src):
         langcode = '{}-{}'.format(dst, src)
     else:
         raise Exception('Dataset cannot be loaded from path: ' + path)
@@ -84,9 +84,13 @@ def load_dataset(path, load_splits, src=None, dst=None):
         if not IndexedInMemoryDataset.exists(src_path):
             break
+        target_dataset = None
+        if IndexedInMemoryDataset.exists(dst_path):
+            target_dataset = IndexedInMemoryDataset(dst_path)
         dataset.splits[prefix] = LanguagePairDataset(
             IndexedInMemoryDataset(src_path),
-            IndexedInMemoryDataset(dst_path),
+            target_dataset,
             pad_idx=dataset.src_dict.pad(),
             eos_idx=dataset.src_dict.eos(),
         )
@@ -194,21 +198,20 @@ class LanguagePairDataset(torch.utils.data.Dataset):
     def __getitem__(self, i):
         # subtract 1 for 0-based indexing
         source = self.src[i].long() - 1
-        target = self.dst[i].long() - 1
-        return {
-            'id': i,
-            'source': source,
-            'target': target,
-        }
+        res = {'id': i, 'source': source}
+        if self.dst:
+            res['target'] = self.dst[i].long() - 1
+        return res

     def __len__(self):
         return len(self.src)

     def collater(self, samples):
-        return LanguagePairDataset.collate(samples, self.pad_idx, self.eos_idx)
+        return LanguagePairDataset.collate(samples, self.pad_idx, self.eos_idx, self.dst is not None)

     @staticmethod
-    def collate(samples, pad_idx, eos_idx):
+    def collate(samples, pad_idx, eos_idx, has_target=True):
         if len(samples) == 0:
             return {}
@@ -220,6 +223,16 @@ class LanguagePairDataset(torch.utils.data.Dataset):
         id = torch.LongTensor([s['id'] for s in samples])
         src_tokens = merge('source', left_pad=LanguagePairDataset.LEFT_PAD_SOURCE)
+        # sort by descending source length
+        src_lengths = torch.LongTensor([s['source'].numel() for s in samples])
+        src_lengths, sort_order = src_lengths.sort(descending=True)
+        id = id.index_select(0, sort_order)
+        src_tokens = src_tokens.index_select(0, sort_order)
+
+        prev_output_tokens = None
+        target = None
+        ntokens = None
+        if has_target:
             target = merge('target', left_pad=LanguagePairDataset.LEFT_PAD_TARGET)
             # we create a shifted version of targets for feeding the
             # previous output token(s) into the next decoder step
@@ -228,18 +241,13 @@ class LanguagePairDataset(torch.utils.data.Dataset):
                 left_pad=LanguagePairDataset.LEFT_PAD_TARGET,
                 move_eos_to_beginning=True,
             )
-        # sort by descending source length
-        src_lengths = torch.LongTensor([s['source'].numel() for s in samples])
-        src_lengths, sort_order = src_lengths.sort(descending=True)
-        id = id.index_select(0, sort_order)
-        src_tokens = src_tokens.index_select(0, sort_order)
-        prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
-        target = target.index_select(0, sort_order)
+            prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
+            target = target.index_select(0, sort_order)
+            ntokens = sum(len(s['target']) for s in samples)

         return {
             'id': id,
-            'ntokens': sum(len(s['target']) for s in samples),
+            'ntokens': ntokens,
             'net_input': {
                 'src_tokens': src_tokens,
                 'src_lengths': src_lengths,
@@ -301,21 +309,23 @@ def _make_batches(src, dst, indices, max_tokens, max_sentences, max_positions,
     sample_len = 0
     ignored = []
     for idx in map(int, indices):
-        if not _valid_size(src.sizes[idx], dst.sizes[idx], max_positions):
+        src_size = src.sizes[idx]
+        dst_size = dst.sizes[idx] if dst else src_size
+        if not _valid_size(src_size, dst_size, max_positions):
            if ignore_invalid_inputs:
                ignored.append(idx)
                continue
            raise Exception((
                "Sample #{} has size (src={}, dst={}) but max size is {}."
                " Skip this example with --skip-invalid-size-inputs-valid-test"
-            ).format(idx, src.sizes[idx], dst.sizes[idx], max_positions))
+            ).format(idx, src_size, dst_size, max_positions))

-        sample_len = max(sample_len, src.sizes[idx], dst.sizes[idx])
+        sample_len = max(sample_len, src_size, dst_size)
         num_tokens = (len(batch) + 1) * sample_len
         if yield_batch(idx, num_tokens):
             yield batch
             batch = []
-            sample_len = max(src.sizes[idx], dst.sizes[idx])
+            sample_len = max(src_size, dst_size)

         batch.append(idx)
@@ -332,7 +342,7 @@ def batches_by_size(src, dst, max_tokens=None, max_sentences=None,
                     descending=False):
     """Returns batches of indices sorted by size. Sequences with different
     source lengths are not allowed in the same batch."""
-    assert isinstance(src, IndexedDataset) and isinstance(dst, IndexedDataset)
+    assert isinstance(src, IndexedDataset) and (dst is None or isinstance(dst, IndexedDataset))
     if max_tokens is None:
         max_tokens = float('Inf')
     if max_sentences is None:
...
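LanguagePairDataset may now be constructed without a target side, so collate sorts by source length first and only builds target, prev_output_tokens and ntokens when a target exists. A condensed, self-contained sketch of that batching logic under those assumptions; the merge helper below is a simplified right-padding stand-in for the library's own, and left-padding and the shifted prev_output_tokens are omitted:

import torch

def collate(samples, pad_idx, has_target=True):
    # simplified right-padding merge; the real helper also supports left-padding
    def merge(key):
        sizes = [len(s[key]) for s in samples]
        res = samples[0][key].new_full((len(samples), max(sizes)), pad_idx)
        for i, s in enumerate(samples):
            res[i, :len(s[key])].copy_(s[key])
        return res

    id = torch.LongTensor([s['id'] for s in samples])
    src_tokens = merge('source')

    # sort by descending source length (done before the optional target merge)
    src_lengths = torch.LongTensor([s['source'].numel() for s in samples])
    src_lengths, sort_order = src_lengths.sort(descending=True)
    id = id.index_select(0, sort_order)
    src_tokens = src_tokens.index_select(0, sort_order)

    target = ntokens = None
    if has_target:
        target = merge('target').index_select(0, sort_order)
        ntokens = sum(len(s['target']) for s in samples)

    return {
        'id': id,
        'ntokens': ntokens,
        'net_input': {'src_tokens': src_tokens, 'src_lengths': src_lengths},
        'target': target,
    }

# monolingual batch: no 'target' key in any sample
batch = collate(
    [{'id': 0, 'source': torch.LongTensor([4, 2])},
     {'id': 1, 'source': torch.LongTensor([5, 6, 2])}],
    pad_idx=1, has_target=False)
assert batch['target'] is None and batch['ntokens'] is None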
@@ -77,11 +77,11 @@ class SequenceGenerator(object):
                 prefix_tokens=s['target'][:, :prefix_size] if prefix_size > 0 else None,
             )
             if timer is not None:
-                timer.stop(s['ntokens'])
+                timer.stop(sum([len(h[0]['tokens']) for h in hypos]))
             for i, id in enumerate(s['id'].data):
                 src = input['src_tokens'].data[i, :]
                 # remove padding from ref
-                ref = utils.strip_pad(s['target'].data[i, :], self.pad)
+                ref = utils.strip_pad(s['target'].data[i, :], self.pad) if s['target'] is not None else None
                 yield id, src, ref, hypos[i]

     def generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):
...
@@ -46,7 +46,7 @@ class SequenceScorer(object):
                 'alignment': alignment,
                 'positional_scores': pos_scores_i,
             }]
-            # return results in the same format as SequenceGenenerator
+            # return results in the same format as SequenceGenerator
             yield id, src, ref, hypos

     def score(self, sample):
...
@@ -84,6 +84,7 @@ def main(args):
     # Generate and compute BLEU score
     scorer = bleu.Scorer(dataset.dst_dict.pad(), dataset.dst_dict.eos(), dataset.dst_dict.unk())
     num_sentences = 0
+    has_target = True
     with progress_bar.build_progress_bar(args, itr) as t:
         if args.score_reference:
             translations = translator.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
@@ -94,17 +95,21 @@ def main(args):
         wps_meter = TimeMeter()
         for sample_id, src_tokens, target_tokens, hypos in translations:
             # Process input and ground truth
-            target_tokens = target_tokens.int().cpu()
+            has_target = target_tokens is not None
+            target_tokens = target_tokens.int().cpu() if has_target else None

             # Either retrieve the original sentences or regenerate them from tokens.
             if align_dict is not None:
                 src_str = dataset.splits[args.gen_subset].src.get_original_text(sample_id)
                 target_str = dataset.splits[args.gen_subset].dst.get_original_text(sample_id)
             else:
                 src_str = dataset.src_dict.string(src_tokens, args.remove_bpe)
-                target_str = dataset.dst_dict.string(target_tokens, args.remove_bpe, escape_unk=True)
+                target_str = dataset.dst_dict.string(target_tokens,
+                                                     args.remove_bpe,
+                                                     escape_unk=True) if has_target else ''

             if not args.quiet:
                 print('S-{}\t{}'.format(sample_id, src_str))
+                if has_target:
                     print('T-{}\t{}'.format(sample_id, target_str))

             # Process top predictions
@@ -133,7 +138,7 @@ def main(args):
                 ))

                 # Score only the top hypothesis
-                if i == 0:
+                if has_target and i == 0:
                     if align_dict is not None or args.remove_bpe is not None:
                         # Convert back to tokens for evaluation with unk replacement and/or without BPE
                         target_tokens = tokenizer.Tokenizer.tokenize(
@@ -146,6 +151,7 @@ def main(args):
     print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} tokens/s)'.format(
         num_sentences, gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
+    if has_target:
         print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))
...
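The generation script now derives has_target from the batch and only prints T-lines and accumulates BLEU when a reference is present. A small sketch of that control flow; the function name and the list used in place of the BLEU scorer are hypothetical stand-ins, not the script's actual API:

def report(sample_id, src_str, hypo_str, target_str=None, scored_pairs=None):
    has_target = target_str is not None
    print('S-{}\t{}'.format(sample_id, src_str))
    if has_target:
        print('T-{}\t{}'.format(sample_id, target_str))
    print('H-{}\t{}'.format(sample_id, hypo_str))
    # only collect (reference, hypothesis) pairs for scoring when a reference exists
    if has_target and scored_pairs is not None:
        scored_pairs.append((target_str, hypo_str))

pairs = []
report(0, 'ein test .', 'a test .', target_str='a test .', scored_pairs=pairs)
report(1, 'nur quelle .', 'source only .', scored_pairs=pairs)  # no reference: no T-line, no score
assert len(pairs) == 1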
@@ -21,9 +21,9 @@ def get_parser():
         description='Data pre-processing: Create dictionary and store data in binary format')
     parser.add_argument('-s', '--source-lang', default=None, metavar='SRC', help='source language')
     parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET', help='target language')
-    parser.add_argument('--trainpref', metavar='FP', default='train', help='target language')
-    parser.add_argument('--validpref', metavar='FP', default='valid', help='comma separated, valid language prefixes')
-    parser.add_argument('--testpref', metavar='FP', default='test', help='comma separated, test language prefixes')
+    parser.add_argument('--trainpref', metavar='FP', default=None, help='target language')
+    parser.add_argument('--validpref', metavar='FP', default=None, help='comma separated, valid language prefixes')
+    parser.add_argument('--testpref', metavar='FP', default=None, help='comma separated, test language prefixes')
     parser.add_argument('--destdir', metavar='DIR', default='data-bin', help='destination dir')
     parser.add_argument('--thresholdtgt', metavar='N', default=0, type=int,
                         help='map words appearing less than threshold times to unknown')
@@ -37,12 +37,14 @@ def get_parser():
     parser.add_argument('--output-format', metavar='FORMAT', default='binary', choices=['binary', 'raw'],
                         help='output format (optional)')
     parser.add_argument('--joined-dictionary', action='store_true', help='Generate joined dictionary')
+    parser.add_argument('--only-source', action='store_true', help='Only process the source language')
     return parser


 def main(args):
     print(args)
     os.makedirs(args.destdir, exist_ok=True)
+    target = not args.only_source

     if args.joined_dictionary:
         assert not args.srcdict, 'cannot combine --srcdict and --joined-dictionary'
@@ -60,14 +62,18 @@ def main(args):
         if args.srcdict:
             src_dict = dictionary.Dictionary.load(args.srcdict)
         else:
+            assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
             src_dict = Tokenizer.build_dictionary(filename='{}.{}'.format(args.trainpref, args.source_lang))
+        if target:
             if args.tgtdict:
                 tgt_dict = dictionary.Dictionary.load(args.tgtdict)
             else:
+                assert args.trainpref, "--trainpref must be set if --tgtdict is not specified"
                 tgt_dict = Tokenizer.build_dictionary(filename='{}.{}'.format(args.trainpref, args.target_lang))

     src_dict.save(os.path.join(args.destdir, 'dict.{}.txt'.format(args.source_lang)),
                   threshold=args.thresholdsrc, nwords=args.nwordssrc)
+    if target:
         tgt_dict.save(os.path.join(args.destdir, 'dict.{}.txt'.format(args.target_lang)),
                       threshold=args.thresholdtgt, nwords=args.nwordstgt)
@@ -100,19 +106,26 @@ def main(args):
             output_text_file = os.path.join(args.destdir, '{}.{}'.format(output_prefix, lang))
             shutil.copyfile('{}.{}'.format(input_prefix, lang), output_text_file)

-    make_dataset(args.trainpref, 'train', args.source_lang, args.output_format)
-    make_dataset(args.trainpref, 'train', args.target_lang, args.output_format)
-    for k, validpref in enumerate(args.validpref.split(',')):
-        outprefix = 'valid{}'.format(k) if k > 0 else 'valid'
-        make_dataset(validpref, outprefix, args.source_lang, args.output_format)
-        make_dataset(validpref, outprefix, args.target_lang, args.output_format)
-    for k, testpref in enumerate(args.testpref.split(',')):
-        outprefix = 'test{}'.format(k) if k > 0 else 'test'
-        make_dataset(testpref, outprefix, args.source_lang, args.output_format)
-        make_dataset(testpref, outprefix, args.target_lang, args.output_format)
+    def make_all(args, make_dataset, lang):
+        if args.trainpref:
+            make_dataset(args.trainpref, 'train', lang, args.output_format)
+        if args.validpref:
+            for k, validpref in enumerate(args.validpref.split(',')):
+                outprefix = 'valid{}'.format(k) if k > 0 else 'valid'
+                make_dataset(validpref, outprefix, lang, args.output_format)
+        if args.testpref:
+            for k, testpref in enumerate(args.testpref.split(',')):
+                outprefix = 'test{}'.format(k) if k > 0 else 'test'
+                make_dataset(testpref, outprefix, lang, args.output_format)
+
+    make_all(args, make_dataset, args.source_lang)
+    if target:
+        make_all(args, make_dataset, args.target_lang)
+
     print('| Wrote preprocessed data to {}'.format(args.destdir))

     if args.alignfile:
+        assert args.trainpref, "--trainpref must be set if --alignfile is specified"
         src_file_name = '{}.{}'.format(args.trainpref, args.source_lang)
         tgt_file_name = '{}.{}'.format(args.trainpref, args.target_lang)
         src_dict = dictionary.Dictionary.load(os.path.join(args.destdir, 'dict.{}.txt'.format(args.source_lang)))
...
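Because --trainpref/--validpref/--testpref now default to None, make_all only binarizes the splits that were actually requested, and the target side is skipped entirely under --only-source. A sketch of that dispatch with make_dataset stubbed out and the output-format argument dropped for brevity; the argument values below are made up:

import argparse

def make_all(args, make_dataset, lang):
    # binarize whichever splits were requested on the command line
    if args.trainpref:
        make_dataset(args.trainpref, 'train', lang)
    if args.validpref:
        for k, validpref in enumerate(args.validpref.split(',')):
            make_dataset(validpref, 'valid{}'.format(k) if k > 0 else 'valid', lang)
    if args.testpref:
        for k, testpref in enumerate(args.testpref.split(',')):
            make_dataset(testpref, 'test{}'.format(k) if k > 0 else 'test', lang)

args = argparse.Namespace(trainpref=None, validpref=None, testpref='test.de-en',
                          source_lang='de', target_lang='en', only_source=True)
calls = []
record = lambda pref, out, lang: calls.append((pref, out, lang))
make_all(args, record, args.source_lang)
if not args.only_source:
    make_all(args, record, args.target_lang)
assert calls == [('test.de-en', 'test', 'de')]  # only the source side of the test split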
@@ -4,31 +4,98 @@
 # This source code is licensed under the license found in the LICENSE file in
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.
-#

-import torch
+import argparse
+import copy
 import unittest

-from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedNLLLoss
-from torch.autograd import Variable, gradcheck
+import torch
+from torch.autograd import Variable

-torch.set_default_tensor_type('torch.DoubleTensor')
+from fairseq import utils
+from fairseq.criterions.cross_entropy import CrossEntropyCriterion
+from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion
+
+import tests.utils as test_utils


 class TestLabelSmoothing(unittest.TestCase):

-    def test_label_smoothing(self):
-        input = Variable(torch.randn(3, 5), requires_grad=True)
-        idx = torch.rand(3) * 4
-        target = Variable(idx.long())
-        criterion = LabelSmoothedNLLLoss()
-        self.assertTrue(gradcheck(
-            lambda x, y: criterion.apply(x, y, 0.1, 2, None), (input, target)
-        ))
-        weights = torch.ones(5)
-        weights[2] = 0
-        self.assertTrue(gradcheck(lambda x, y: criterion.apply(x, y, 0.1, None, weights), (input, target)))
-        self.assertTrue(gradcheck(lambda x, y: criterion.apply(x, y, 0.1, None, None), (input, target)))
+    def setUp(self):
+        # build dictionary
+        self.d = test_utils.dummy_dictionary(3)
+        vocab = len(self.d)
+        self.assertEqual(vocab, 4 + 3)  # 4 special + 3 tokens
+        self.assertEqual(self.d.pad(), 1)
+        self.assertEqual(self.d.eos(), 2)
+        self.assertEqual(self.d.unk(), 3)
+        pad, eos, unk, w1, w2, w3 = 1, 2, 3, 4, 5, 6
+
+        # build dataset
+        self.data = [
+            # the first batch item has padding
+            {'source': torch.LongTensor([w1, eos]), 'target': torch.LongTensor([w1, eos])},
+            {'source': torch.LongTensor([w1, eos]), 'target': torch.LongTensor([w1, w1, eos])},
+        ]
+        self.sample = next(test_utils.dummy_dataloader(self.data))
+
+        # build model
+        self.args = argparse.Namespace()
+        self.args.sentence_avg = False
+        self.args.probs = torch.FloatTensor([
+            #      pad   eos  unk   w1   w2   w3
+            [0.05, 0.05, 0.1, 0.05, 0.3, 0.4, 0.05],
+            [0.05, 0.10, 0.2, 0.05, 0.2, 0.3, 0.10],
+            [0.05, 0.15, 0.3, 0.05, 0.1, 0.2, 0.15],
+        ]).unsqueeze(0).expand(2, 3, 7)  # add batch dimension
+        self.model = test_utils.TestModel.build_model(self.args, self.d, self.d)
+
+    def test_nll_loss(self):
+        self.args.label_smoothing = 0.1
+        nll_crit = CrossEntropyCriterion(self.args, self.d, self.d)
+        smooth_crit = LabelSmoothedCrossEntropyCriterion(self.args, self.d, self.d)
+        nll_loss, nll_sample_size, nll_logging_output = nll_crit(self.model, self.sample)
+        smooth_loss, smooth_sample_size, smooth_logging_output = smooth_crit(self.model, self.sample)
+        self.assertLess(abs(nll_loss - nll_logging_output['loss']), 1e-6)
+        self.assertLess(abs(nll_loss - smooth_logging_output['nll_loss']), 1e-6)
+
+    def test_padding(self):
+        self.args.label_smoothing = 0.1
+        crit = LabelSmoothedCrossEntropyCriterion(self.args, self.d, self.d)
+        loss, _, logging_output = crit(self.model, self.sample)
+
+        def get_one_no_padding(idx):
+            # create a new sample with just a single batch item so that there's
+            # no padding
+            sample1 = next(test_utils.dummy_dataloader([self.data[idx]]))
+            args1 = copy.copy(self.args)
+            args1.probs = args1.probs[idx, :, :].unsqueeze(0)
+            model1 = test_utils.TestModel.build_model(args1, self.d, self.d)
+            loss1, _, _ = crit(model1, sample1)
+            return loss1
+
+        loss1 = get_one_no_padding(0)
+        loss2 = get_one_no_padding(1)
+        self.assertAlmostEqual(loss, loss1 + loss2)
+
+    def test_reduction(self):
+        self.args.label_smoothing = 0.1
+        crit = LabelSmoothedCrossEntropyCriterion(self.args, self.d, self.d)
+        loss, _, logging_output = crit(self.model, self.sample, reduce=True)
+        unreduced_loss, _, _ = crit(self.model, self.sample, reduce=False)
+        self.assertAlmostEqual(loss, unreduced_loss.sum())
+
+    def test_zero_eps(self):
+        self.args.label_smoothing = 0.0
+        nll_crit = CrossEntropyCriterion(self.args, self.d, self.d)
+        smooth_crit = LabelSmoothedCrossEntropyCriterion(self.args, self.d, self.d)
+        nll_loss, nll_sample_size, nll_logging_output = nll_crit(self.model, self.sample)
+        smooth_loss, smooth_sample_size, smooth_logging_output = smooth_crit(self.model, self.sample)
+        self.assertAlmostEqual(nll_loss, smooth_loss)
+
+    def assertAlmostEqual(self, t1, t2):
+        self.assertEqual(t1.size(), t2.size(), "size mismatch")
+        self.assertLess((t1 - t2).abs().max(), 1e-6)


 if __name__ == '__main__':
...
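test_padding pins down the property the masking is meant to guarantee: padded positions contribute nothing, so the loss of a padded two-sentence batch equals the sum of the two per-sentence losses. A toy re-check of that invariant outside the fairseq test harness; the shapes, padding index and eps below are illustrative:

import torch
import torch.nn.functional as F

def smoothed_sum(lprobs, target, eps=0.1, pad=1):
    # lprobs: (tokens, vocab); target: (tokens,); pad positions are masked out
    target = target.unsqueeze(-1)
    mask = target.ne(pad)
    nll = -lprobs.gather(-1, target)[mask].sum()
    smooth = -lprobs.sum(-1, keepdim=True)[mask].sum()
    return (1. - eps) * nll + (eps / lprobs.size(-1)) * smooth

lp1 = F.log_softmax(torch.randn(2, 7), dim=-1)   # sentence of length 2
lp2 = F.log_softmax(torch.randn(3, 7), dim=-1)   # sentence of length 3
t1, t2 = torch.tensor([4, 2]), torch.tensor([4, 4, 2])

# "batched" version: pad the shorter target with the pad index (1) and give the
# pad position an arbitrary log-prob row; the mask must hide it from the loss
lp_batch = torch.cat([lp1, F.log_softmax(torch.randn(1, 7), dim=-1), lp2])
t_batch = torch.cat([t1, torch.tensor([1]), t2])
assert torch.allclose(smoothed_sum(lp_batch, t_batch),
                      smoothed_sum(lp1, t1) + smoothed_sum(lp2, t2))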
@@ -92,7 +92,7 @@ class TestEncoder(FairseqEncoder):
 class TestIncrementalDecoder(FairseqIncrementalDecoder):
     def __init__(self, args, dictionary):
         super().__init__(dictionary)
-        assert hasattr(args, 'beam_probs')
+        assert hasattr(args, 'beam_probs') or hasattr(args, 'probs')
         args.max_decoder_positions = getattr(args, 'max_decoder_positions', 100)
         self.args = args
@@ -116,6 +116,11 @@ class TestIncrementalDecoder(FairseqIncrementalDecoder):
         steps = list(range(tgt_len))

         # define output in terms of raw probs
+        if hasattr(self.args, 'probs'):
+            assert self.args.probs.dim() == 3, \
+                'expected probs to have size bsz*steps*vocab'
+            probs = self.args.probs.index_select(1, torch.LongTensor(steps))
+        else:
             probs = torch.FloatTensor(bbsz, len(steps), vocab).zero_()
             for i, step in enumerate(steps):
                 # args.beam_probs gives the probability for every vocab element,
...
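The probs hook gives TestIncrementalDecoder fully specified per-step output distributions, which is what makes the new criterion tests deterministic. A hedged stand-in illustrating how a (bsz, steps, vocab) probs tensor is sliced per decoding step; the class below is illustrative, not the repository's test decoder:

import torch

class FixedProbsDecoder:
    """Returns pre-specified per-step output distributions (a test double)."""
    def __init__(self, probs):
        assert probs.dim() == 3, 'expected probs to have size bsz x steps x vocab'
        self.probs = probs

    def forward(self, tgt_len):
        steps = torch.arange(tgt_len, dtype=torch.long)
        return self.probs.index_select(1, steps)     # (bsz, tgt_len, vocab)

probs = torch.full((2, 3, 7), 1.0 / 7)               # uniform dummy distributions
out = FixedProbsDecoder(probs).forward(tgt_len=2)
assert out.shape == (2, 2, 7) and torch.allclose(out.sum(-1), torch.ones(2, 2))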