Commit c11aaf14 authored by Matt Le's avatar Matt Le Committed by Facebook Github Bot
Browse files

Fix semisupervised translation

Summary: Fixes semisupervised translation task to deal with change in order of data loading and model creation (D15428242).  When we build the model, we create the backtranslation function, which we can then pass in to the constructor of BacktranslationDataset

Reviewed By: myleott

Differential Revision: D15455420

fbshipit-source-id: 95101ca92f8af33702be3416147edd98da135a20
parent 886ef6bc
...@@ -347,5 +347,6 @@ class MultilingualTranslationTask(FairseqTask): ...@@ -347,5 +347,6 @@ class MultilingualTranslationTask(FairseqTask):
(self.args.max_source_positions, self.args.max_target_positions)} (self.args.max_source_positions, self.args.max_target_positions)}
return OrderedDict([ return OrderedDict([
(key, (self.args.max_source_positions, self.args.max_target_positions)) (key, (self.args.max_source_positions, self.args.max_target_positions))
for key in next(iter(self.datasets.values())).datasets.keys() for split in self.datasets.keys()
for key in self.datasets[split].datasets.keys()
]) ])
...@@ -126,6 +126,7 @@ class SemisupervisedTranslationTask(MultilingualTranslationTask): ...@@ -126,6 +126,7 @@ class SemisupervisedTranslationTask(MultilingualTranslationTask):
] ]
self.model_lang_pairs += denoising_lang_pairs self.model_lang_pairs += denoising_lang_pairs
self.backtranslate_datasets = {} self.backtranslate_datasets = {}
self.backtranslators = {}
@classmethod @classmethod
def setup_task(cls, args, **kwargs): def setup_task(cls, args, **kwargs):
...@@ -210,6 +211,7 @@ class SemisupervisedTranslationTask(MultilingualTranslationTask): ...@@ -210,6 +211,7 @@ class SemisupervisedTranslationTask(MultilingualTranslationTask):
src_lang=tgt, src_lang=tgt,
tgt_lang=src, tgt_lang=src,
), ),
backtranslation_fn=self.backtranslators[lang_pair],
src_dict=self.dicts[src], tgt_dict=self.dicts[tgt], src_dict=self.dicts[src], tgt_dict=self.dicts[tgt],
output_collater=self.alter_dataset_langtok( output_collater=self.alter_dataset_langtok(
lang_pair_dataset=lang_pair_dataset, lang_pair_dataset=lang_pair_dataset,
...@@ -324,7 +326,7 @@ class SemisupervisedTranslationTask(MultilingualTranslationTask): ...@@ -324,7 +326,7 @@ class SemisupervisedTranslationTask(MultilingualTranslationTask):
sample, sample,
bos_token=bos_token, bos_token=bos_token,
) )
self.backtranslate_datasets[lang_pair].set_backtranslation_fn(backtranslate_fn) self.backtranslators[lang_pair] = backtranslate_fn
return model return model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment