"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "4ff8a691bcef296aa976e19d0ba9c7b74ae9f27c"
Commit 40ac340b authored by Liezl Puzon's avatar Liezl Puzon Committed by Facebook Github Bot
Browse files

Eval and log on a subset of directions for multimodel training (#605)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/605

Eval and log on a subset of directions for multimodel training

This reduces code duplication in PyTorch Translate's semi_supervised task and will enable clean multitask setups in the future.

Reviewed By: pipibjc, dpacgopinath

Differential Revision: D14672779

fbshipit-source-id: 1342c71781f0824cc56a38ad1c1822e34eaef337
parent f492db25
...@@ -77,6 +77,12 @@ class MultilingualTranslationTask(FairseqTask): ...@@ -77,6 +77,12 @@ class MultilingualTranslationTask(FairseqTask):
super().__init__(args) super().__init__(args)
self.dicts = dicts self.dicts = dicts
self.lang_pairs = args.lang_pairs self.lang_pairs = args.lang_pairs
# eval_lang_pairs for multilingual translation is usually all of the
# lang_pairs. However for other multitask settings or when we want to
# optimize for certain languages we want to use a different subset. Thus
# the eval_lang_pairs class variable is provided for classes that extend
# this class.
self.eval_lang_pairs = args.lang_pairs
self.langs = list(dicts.keys()) self.langs = list(dicts.keys())
self.training = training self.training = training
...@@ -205,7 +211,7 @@ class MultilingualTranslationTask(FairseqTask): ...@@ -205,7 +211,7 @@ class MultilingualTranslationTask(FairseqTask):
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
agg_loss, agg_sample_size, agg_logging_output = 0., 0., {} agg_loss, agg_sample_size, agg_logging_output = 0., 0., {}
for lang_pair in self.args.lang_pairs: for lang_pair in self.eval_lang_pairs:
if sample[lang_pair] is None or len(sample[lang_pair]) == 0: if sample[lang_pair] is None or len(sample[lang_pair]) == 0:
continue continue
loss, sample_size, logging_output = criterion(model.models[lang_pair], sample[lang_pair]) loss, sample_size, logging_output = criterion(model.models[lang_pair], sample[lang_pair])
...@@ -236,7 +242,7 @@ class MultilingualTranslationTask(FairseqTask): ...@@ -236,7 +242,7 @@ class MultilingualTranslationTask(FairseqTask):
lang_pair: criterion.__class__.aggregate_logging_outputs([ lang_pair: criterion.__class__.aggregate_logging_outputs([
logging_output.get(lang_pair, {}) for logging_output in logging_outputs logging_output.get(lang_pair, {}) for logging_output in logging_outputs
]) ])
for lang_pair in self.args.lang_pairs for lang_pair in self.eval_lang_pairs
} }
def sum_over_languages(key): def sum_over_languages(key):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment