Commit e4335001 authored by Peng-Jen Chen, committed by Facebook Github Bot

Fix multilingual evaluation bug (#592)

Summary:
Pull Request resolved: https://github.com/pytorch/translate/pull/592

Fix bug reported at
https://github.com/pytorch/fairseq/commit/9c3bb5c6d6c7d6442a28ccb8a81b2fc4e5782ace#r34181600

D15682169 breaks multilingual translation generation.

Reviewed By: dpacgopinath

Differential Revision: D16147454

fbshipit-source-id: e0cf4d32f362190a0542fa0160f65a2a207ca3fa
parent 6d2e0831
@@ -100,19 +100,23 @@ class MultilingualTranslationTask(FairseqTask):
     def __init__(self, args, dicts, training):
         super().__init__(args)
         self.dicts = dicts
-        self.lang_pairs = args.lang_pairs
+        self.training = training
+        if training:
+            self.lang_pairs = args.lang_pairs
+            args.source_lang, args.target_lang = args.lang_pairs[0].split('-')
+        else:
+            self.lang_pairs = ['{}-{}'.format(args.source_lang, args.target_lang)]
         # eval_lang_pairs for multilingual translation is usually all of the
         # lang_pairs. However for other multitask settings or when we want to
         # optimize for certain languages we want to use a different subset. Thus
         # the eval_lang_pairs class variable is provided for classes that extend
         # this class.
-        self.eval_lang_pairs = args.lang_pairs
+        self.eval_lang_pairs = self.lang_pairs
         # model_lang_pairs will be used to build encoder-decoder model pairs in
         # models.build_model(). This allows multitask type of sub-class can
         # build models other than the input lang_pairs
-        self.model_lang_pairs = copy.copy(args.lang_pairs)
+        self.model_lang_pairs = self.lang_pairs
         self.langs = list(dicts.keys())
-        self.training = training

     @classmethod
     def setup_task(cls, args, **kwargs):
@@ -136,10 +140,8 @@ class MultilingualTranslationTask(FairseqTask):
         sorted_langs = sorted(list({x for lang_pair in args.lang_pairs for x in lang_pair.split('-')}))
         if args.source_lang is not None or args.target_lang is not None:
             training = False
-            args.lang_pairs = ['{}-{}'.format(args.source_lang, args.target_lang)]
         else:
             training = True
-            args.source_lang, args.target_lang = args.lang_pairs[0].split('-')

         # load dictionaries
         dicts = OrderedDict()
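
Note: moving the pair selection out of setup_task and into __init__ means that any code path constructing the task directly now derives a consistent self.lang_pairs as well. A minimal standalone distillation of the new logic, with a hypothetical helper name and argument values:

# Hypothetical distillation of the reworked __init__ logic above.
def derive_lang_pairs(training, lang_pairs=None, source_lang=None, target_lang=None):
    if training:
        # training: all pairs come from --lang-pairs
        return list(lang_pairs)
    # generation: a single pair built from --source-lang/--target-lang
    return ['{}-{}'.format(source_lang, target_lang)]

assert derive_lang_pairs(True, lang_pairs=['de-en', 'fr-en']) == ['de-en', 'fr-en']
assert derive_lang_pairs(False, source_lang='de', target_lang='en') == ['de-en']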
@@ -224,7 +226,7 @@ class MultilingualTranslationTask(FairseqTask):
         self.datasets[split] = RoundRobinZipDatasets(
             OrderedDict([
                 (lang_pair, language_pair_dataset(lang_pair))
-                for lang_pair in self.args.lang_pairs
+                for lang_pair in self.lang_pairs
             ]),
             eval_key=None if self.training else "%s-%s" % (self.args.source_lang, self.args.target_lang),
         )
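
Note: with a non-None eval_key, RoundRobinZipDatasets passes batches through for that single pair instead of returning a per-pair dictionary, so the standard generation path sees an ordinary dataset. A toy illustration of the idea (not the fairseq API):

from collections import OrderedDict

# Toy stand-in for the eval_key behavior described above.
class ToyRoundRobinZip:
    def __init__(self, datasets, eval_key=None):
        self.datasets = datasets
        self.eval_key = eval_key

    def __getitem__(self, index):
        if self.eval_key is None:
            # training: one sample per language pair, keyed by pair name
            return OrderedDict((k, d[index]) for k, d in self.datasets.items())
        # evaluation: pass through to the single selected pair
        return self.datasets[self.eval_key][index]

pairs = OrderedDict([('de-en', ['a', 'b']), ('fr-en', ['c', 'd'])])
assert ToyRoundRobinZip(pairs)[0] == OrderedDict([('de-en', 'a'), ('fr-en', 'c')])
assert ToyRoundRobinZip(pairs, eval_key='de-en')[0] == 'a'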
@@ -272,7 +274,7 @@ class MultilingualTranslationTask(FairseqTask):
     def train_step(self, sample, model, criterion, optimizer, ignore_grad=False):
         model.train()
         agg_loss, agg_sample_size, agg_logging_output = 0., 0., {}
-        for lang_pair in self.args.lang_pairs:
+        for lang_pair in self.model_lang_pairs:
             if sample[lang_pair] is None or len(sample[lang_pair]) == 0:
                 continue
             loss, sample_size, logging_output = criterion(model.models[lang_pair], sample[lang_pair])
@@ -124,7 +124,7 @@ class SemisupervisedTranslationTask(MultilingualTranslationTask):
             "%s-%s" % (tgt, tgt)
             for tgt in {lang_pair.split('-')[1] for lang_pair in args.lang_pairs}
         ]
-        self.model_lang_pairs += denoising_lang_pairs
+        self.model_lang_pairs = self.model_lang_pairs + denoising_lang_pairs
         self.backtranslate_datasets = {}
         self.backtranslators = {}
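
Note: the change from += to rebinding matters because self.model_lang_pairs now aliases self.lang_pairs in the parent __init__ (first hunk above); an in-place += would silently extend the shared list. A minimal demonstration with hypothetical pairs:

lang_pairs = ['de-en', 'fr-en']                   # stands in for self.lang_pairs
model_lang_pairs = lang_pairs                     # alias, as in the new __init__
model_lang_pairs += ['en-en']                     # in-place: mutates the shared list
assert lang_pairs == ['de-en', 'fr-en', 'en-en']  # unintended side effect

lang_pairs = ['de-en', 'fr-en']
model_lang_pairs = lang_pairs
model_lang_pairs = model_lang_pairs + ['en-en']   # rebinding: builds a new list
assert lang_pairs == ['de-en', 'fr-en']           # shared list left untouched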
@@ -164,7 +164,7 @@ class SemisupervisedTranslationTask(MultilingualTranslationTask):
         # load parallel datasets
         src_datasets, tgt_datasets = {}, {}
         if (self.lambda_parallel > 0.0 or self.lambda_parallel_steps is not None or not split.startswith("train")):
-            for lang_pair in self.args.lang_pairs:
+            for lang_pair in self.lang_pairs:
                 src, tgt = lang_pair.split('-')
                 if split_exists(split, src, tgt, src):
                     prefix = os.path.join(data_path, '{}.{}-{}.'.format(split, src, tgt))
@@ -181,7 +181,7 @@ class SemisupervisedTranslationTask(MultilingualTranslationTask):
         # back translation datasets
         backtranslate_datasets = {}
         if (self.lambda_otf_bt > 0.0 or self.lambda_otf_bt_steps is not None) and split.startswith("train"):
-            for lang_pair in self.args.lang_pairs:
+            for lang_pair in self.lang_pairs:
                 src, tgt = lang_pair.split('-')
                 if not split_exists(split, tgt, None, tgt):
                     raise FileNotFoundError('Dataset not found: backtranslation {} ({})'.format(split, data_path))
@@ -229,7 +229,7 @@ class SemisupervisedTranslationTask(MultilingualTranslationTask):
         # denoising autoencoder
         noising_datasets = {}
         if (self.lambda_denoising > 0.0 or self.lambda_denoising_steps is not None) and split.startswith("train"):
-            for lang_pair in self.args.lang_pairs:
+            for lang_pair in self.lang_pairs:
                 _, tgt = lang_pair.split('-')
                 if not split_exists(split, tgt, None, tgt):
                     continue
@@ -305,7 +305,7 @@ class SemisupervisedTranslationTask(MultilingualTranslationTask):
         # create SequenceGenerator for each model that has backtranslation dependency on it
         self.sequence_generators = {}
         if (self.lambda_otf_bt > 0.0 or self.lambda_otf_bt_steps is not None) and self.training:
-            for lang_pair in self.args.lang_pairs:
+            for lang_pair in self.lang_pairs:
                 src, tgt = lang_pair.split('-')
                 key = '{}-{}'.format(tgt, src)
                 self.sequence_generators[key] = SequenceGenerator(
@@ -350,16 +350,16 @@ class SemisupervisedTranslationTask(MultilingualTranslationTask):
             agg_logging_output[logging_output_key] = logging_output

         if self.lambda_parallel > 0.0:
-            for lang_pair in self.args.lang_pairs:
+            for lang_pair in self.lang_pairs:
                 forward_backward(model.models[lang_pair], sample[lang_pair], lang_pair, self.lambda_parallel)

         if self.lambda_otf_bt > 0.0:
-            for lang_pair in self.args.lang_pairs:
+            for lang_pair in self.lang_pairs:
                 sample_key = _get_bt_dataset_key(lang_pair)
                 forward_backward(model.models[lang_pair], sample[sample_key], sample_key, self.lambda_otf_bt)

         if self.lambda_denoising > 0.0:
-            for lang_pair in self.args.lang_pairs:
+            for lang_pair in self.lang_pairs:
                 _, tgt = lang_pair.split('-')
                 sample_key = _get_denoising_dataset_key(lang_pair)
                 forward_backward(model.models['{0}-{0}'.format(tgt)], sample[sample_key], sample_key, self.lambda_denoising)
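
Note: forward_backward is the task's local helper that scales each objective's loss by its lambda before backpropagating, so the parallel, back-translation, and denoising objectives accumulate into one set of gradients per step. A schematic sketch of what such a helper does (its body is not shown in this diff; details assumed):

# Schematic sketch of the forward_backward pattern used above (simplified).
def forward_backward(model, samples, logging_output_key, weight,
                     criterion, optimizer, agg_logging_output):
    if samples is None or len(samples) == 0:
        return
    loss, sample_size, logging_output = criterion(model, samples)
    loss = loss * weight      # lambda_parallel / lambda_otf_bt / lambda_denoising
    optimizer.backward(loss)  # gradients accumulate; one optimizer step later
    agg_logging_output[logging_output_key] = logging_output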
@@ -395,12 +395,12 @@ class SemisupervisedTranslationTask(MultilingualTranslationTask):
             for logging_output in logging_outputs
             for key in logging_output
         }
-        lang_pair_keys = set(self.args.lang_pairs + [
+        lang_pair_keys = set(self.lang_pairs + [
             _get_bt_dataset_key(lang_pair)
-            for lang_pair in self.args.lang_pairs
+            for lang_pair in self.lang_pairs
         ] + [
             _get_denoising_dataset_key(lang_pair)
-            for lang_pair in self.args.lang_pairs
+            for lang_pair in self.lang_pairs
         ])
         logging_output_keys = logging_output_keys.intersection(lang_pair_keys)
         return super().aggregate_logging_outputs(logging_outputs, criterion, logging_output_keys)
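
Note: the key helpers used above namespace the pair name so that plain translation pairs, back-translation keys, and denoising keys can coexist in one logging dict. A sketch of what they do (their definitions are not shown in this diff; the prefixes are assumptions):

# Assumed shape of the dataset-key helpers referenced in the hunk above.
def _get_bt_dataset_key(lang_pair):
    return "bt:" + lang_pair

def _get_denoising_dataset_key(lang_pair):
    return "denoising:" + lang_pair

assert _get_bt_dataset_key('de-en') == 'bt:de-en'
assert _get_denoising_dataset_key('de-en') == 'denoising:de-en'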