Commit d7e19573 authored by Peng-Jen Chen, committed by Facebook GitHub Bot

Back translation + denoising in MultilingualTranslation task (#620)

Summary:
- Add language token to MultilingualTranslation task
- Add back translation and denoising loss to MultilingualTranslation task
Pull Request resolved: https://github.com/pytorch/fairseq/pull/620

Reviewed By: liezl200

Differential Revision: D14756873

Pulled By: pipibjc

fbshipit-source-id: 89d668db26848fd95f446edf5923bab2113636f7
parent c2820af0
......@@ -13,9 +13,11 @@ from .indexed_dataset import IndexedCachedDataset, IndexedDataset, IndexedRawTex
from .language_pair_dataset import LanguagePairDataset
from .lm_context_window_dataset import LMContextWindowDataset
from .monolingual_dataset import MonolingualDataset
from .noising import NoisingDataset
from .round_robin_zip_datasets import RoundRobinZipDatasets
from .token_block_dataset import TokenBlockDataset
from .transform_eos_dataset import TransformEosDataset
from .transform_eos_lang_pair_dataset import TransformEosLangPairDataset
from .iterators import (
CountingIterator,
......@@ -38,8 +40,10 @@ __all__ = [
'LanguagePairDataset',
'LMContextWindowDataset',
'MonolingualDataset',
'NoisingDataset',
'RoundRobinZipDatasets',
'ShardedIterator',
'TokenBlockDataset',
'TransformEosDataset',
'TransformEosLangPairDataset',
]
......@@ -10,6 +10,7 @@ import torch
from fairseq import utils
from . import FairseqDataset
from .language_pair_dataset import collate as language_pair_collate, generate_dummy_batch
def backtranslate_samples(samples, collate_fn, generate_fn, cuda=True):
......@@ -36,22 +37,18 @@ def backtranslate_samples(samples, collate_fn, generate_fn, cuda=True):
"""
collated_samples = collate_fn(samples)
s = utils.move_to_cuda(collated_samples) if cuda else collated_samples
generated_sources = generate_fn(s['net_input'])
generated_sources = generate_fn(s)
def update_sample(sample, generated_source):
sample['target'] = sample['source'] # the original source becomes the target
sample['source'] = generated_source
return sample
id_to_src = {
sample['id']: sample['source'] for sample in samples
}
# Go through each tgt sentence in batch and its corresponding best
# generated hypothesis and create a backtranslation data pair
# {id: id, source: generated backtranslation, target: original tgt}
return [
update_sample(
sample=input_sample,
generated_source=hypos[0]['tokens'].cpu(), # highest scoring hypo is first
)
for input_sample, hypos in zip(samples, generated_sources)
{'id': id.item(), 'target': id_to_src[id.item()], 'source': hypos[0]['tokens'].cpu()}
for id, hypos in zip(collated_samples['id'], generated_sources)
]
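A minimal sketch (not part of this commit) of the contract above: the collated ids line up with the generator output, the original monolingual sentence becomes the target, and the best hypothesis becomes the back-translated source. The toy collate/generate functions below are hypothetical stand-ins; the module path is assumed from this tree.
import torch
from fairseq.data.backtranslation_dataset import backtranslate_samples

samples = [{'id': 0, 'source': torch.tensor([4, 5, 2])}]  # 2 = eos

def toy_collate(ss):
    # just enough of the LanguagePairDataset collate schema for this sketch
    return {
        'id': torch.tensor([s['id'] for s in ss]),
        'net_input': {'src_tokens': torch.stack([s['source'] for s in ss])},
    }

def toy_generate(collated_sample):
    # one hypothesis list per sentence, best hypothesis first
    return [[{'tokens': torch.tensor([7, 8, 2])}]]

pairs = backtranslate_samples(samples, toy_collate, toy_generate, cuda=False)
# -> [{'id': 0, 'target': tensor([4, 5, 2]), 'source': tensor([7, 8, 2])}]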
......@@ -66,9 +63,15 @@ class BacktranslationDataset(FairseqDataset):
backtranslated. Only the source side of this dataset will be used.
After backtranslation, the source sentences in this dataset will be
returned as the targets.
backtranslation_fn (callable): function to call to generate
src_dict (~fairseq.data.Dictionary): the dictionary of backtranslated
sentences.
tgt_dict (~fairseq.data.Dictionary, optional): the dictionary of
sentences to be backtranslated.
backtranslation_fn (callable, optional): function to call to generate
backtranslations. This is typically the `generate` method of a
:class:`~fairseq.sequence_generator.SequenceGenerator` object.
Pass in None when it is not available at initialization time, and
use the set_backtranslation_fn method to set it once it becomes available.
output_collater (callable, optional): function to call on the
backtranslated samples to create the final batch
(default: ``tgt_dataset.collater``).
......@@ -78,7 +81,9 @@ class BacktranslationDataset(FairseqDataset):
def __init__(
self,
tgt_dataset,
backtranslation_fn,
src_dict,
tgt_dict=None,
backtranslation_fn=None,
output_collater=None,
cuda=True,
**kwargs
......@@ -88,6 +93,8 @@ class BacktranslationDataset(FairseqDataset):
self.output_collater = output_collater if output_collater is not None \
else tgt_dataset.collater
self.cuda = cuda if torch.cuda.is_available() else False
self.src_dict = src_dict
self.tgt_dict = tgt_dict
def __getitem__(self, index):
"""
......@@ -100,6 +107,9 @@ class BacktranslationDataset(FairseqDataset):
def __len__(self):
return len(self.tgt_dataset)
def set_backtranslation_fn(self, backtranslation_fn):
self.backtranslation_fn = backtranslation_fn
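A hedged sketch of the deferred-generator pattern this method enables (mirroring how SemisupervisedTranslationTask.build_model wires it up later in this diff); the dataset, dictionary, model, and generator names below are placeholders, not runnable as-is.
# Sketch only: the SequenceGenerator usually does not exist yet when datasets
# are loaded, so the dataset is built without a backtranslation_fn ...
bt_dataset = BacktranslationDataset(
    tgt_dataset=mono_dataset,   # placeholder monolingual dataset
    src_dict=src_dict,          # placeholder dictionaries
    tgt_dict=tgt_dict,
    backtranslation_fn=None,
)
# ... and wired up once the model and generator have been built:
bt_dataset.set_backtranslation_fn(lambda sample: generator.generate([model], sample))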
def collater(self, samples):
"""Merge and backtranslate a list of samples to form a mini-batch.
......@@ -119,6 +129,8 @@ class BacktranslationDataset(FairseqDataset):
Returns:
dict: a mini-batch with keys coming from *output_collater*
"""
if samples[0].get('is_dummy', False):
return samples
samples = backtranslate_samples(
samples=samples,
collate_fn=self.tgt_dataset.collater,
......@@ -131,7 +143,16 @@ class BacktranslationDataset(FairseqDataset):
def get_dummy_batch(self, num_tokens, max_positions):
"""Just use the tgt dataset get_dummy_batch"""
return self.tgt_dataset.get_dummy_batch(num_tokens, max_positions)
def collate_fn(samples):
return language_pair_collate(
samples, pad_idx=self.src_dict.pad(), eos_idx=self.src_dict.eos(),
input_feeding=True,
)
dummy_batch = generate_dummy_batch(
num_tokens, collate_fn,
self.src_dict, tgt_dict=self.tgt_dict)
dummy_batch['is_dummy'] = True
return dummy_batch
def num_tokens(self, index):
"""Just use the tgt dataset num_tokens"""
......
......@@ -68,6 +68,19 @@ def collate(
return batch
def generate_dummy_batch(num_tokens, collate_fn, src_dict, src_len=128, tgt_dict=None, tgt_len=128):
"""Return a dummy batch with a given number of tokens."""
bsz = num_tokens // max(src_len, tgt_len)
return collate_fn([
{
'id': i,
'source': src_dict.dummy_sentence(src_len),
'target': tgt_dict.dummy_sentence(tgt_len) if tgt_dict is not None else None,
}
for i in range(bsz)
])
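For example (illustrative numbers, toy dictionary): with the default lengths of 128 and a 4096-token budget, the helper builds 4096 // 128 = 32 dummy sentence pairs. The module path and collate signature below are assumed from this tree.
from fairseq.data import Dictionary
from fairseq.data.language_pair_dataset import collate, generate_dummy_batch

d = Dictionary()
for w in ('ein', 'zwei', 'drei', 'vier'):  # a few dummy vocabulary entries
    d.add_symbol(w)

dummy = generate_dummy_batch(
    num_tokens=4096,
    collate_fn=lambda samples: collate(samples, pad_idx=d.pad(), eos_idx=d.eos()),
    src_dict=d,
    tgt_dict=d,
)
# bsz = 4096 // max(128, 128) = 32 dummy source/target pairs in the batch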
class LanguagePairDataset(FairseqDataset):
"""
A pair of torch.utils.data.Datasets.
......@@ -192,15 +205,7 @@ class LanguagePairDataset(FairseqDataset):
max_positions,
(self.max_source_positions, self.max_target_positions),
)
bsz = max(num_tokens // max(src_len, tgt_len), 1)
return self.collater([
{
'id': i,
'source': self.src_dict.dummy_sentence(src_len),
'target': self.tgt_dict.dummy_sentence(tgt_len) if self.tgt_dict is not None else None,
}
for i in range(bsz)
])
return generate_dummy_batch(num_tokens, self.collater, self.src_dict, src_len, self.tgt_dict, tgt_len)
def num_tokens(self, index):
"""Return the number of tokens in a sample. This value is used to
......@@ -227,9 +232,10 @@ class LanguagePairDataset(FairseqDataset):
def supports_prefetch(self):
return (
getattr(self.src, 'supports_prefetch', False)
and getattr(self.tgt, 'supports_prefetch', False)
and (getattr(self.tgt, 'supports_prefetch', False) or self.tgt is None)
)
def prefetch(self, indices):
self.src.prefetch(indices)
self.tgt.prefetch(indices)
if self.tgt is not None:
self.tgt.prefetch(indices)
......@@ -301,3 +301,11 @@ class NoisingDataset(torch.utils.data.Dataset):
The length of the noising dataset is the length of src.
"""
return len(self.src_dataset)
@property
def supports_prefetch(self):
return self.src_dataset.supports_prefetch
def prefetch(self, indices):
if self.src_dataset.supports_prefetch:
self.src_dataset.prefetch(indices)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from . import FairseqDataset
from typing import Optional
class TransformEosLangPairDataset(FairseqDataset):
"""A :class:`~fairseq.data.FairseqDataset` wrapper that transform bos on
collated samples of language pair dataset.
Note that the transformation is applied in :func:`collater`.
Args:
dataset (~fairseq.data.FairseqDataset): dataset that collates sample into
LanguagePairDataset schema
src_eos (int): original source end-of-sentence symbol index to be replaced
new_src_eos (int, optional): new end-of-sentence symbol index to replace source eos symbol
tgt_bos (int, optional): original target beginning-of-sentence symbol index to be replaced
new_tgt_bos (int, optional): new beginning-of-sentence symbol index to replace at the
beginning of 'prev_output_tokens'
"""
def __init__(
self,
dataset: FairseqDataset,
src_eos: int,
new_src_eos: Optional[int] = None,
tgt_bos: Optional[int] = None,
new_tgt_bos: Optional[int] = None,
):
self.dataset = dataset
self.src_eos = src_eos
self.new_src_eos = new_src_eos
self.tgt_bos = tgt_bos
self.new_tgt_bos = new_tgt_bos
def __getitem__(self, index):
return self.dataset[index]
def __len__(self):
return len(self.dataset)
def collater(self, samples):
samples = self.dataset.collater(samples)
# TODO: support different padding direction
if self.new_src_eos is not None:
assert (samples['net_input']['src_tokens'][:, -1] != self.src_eos).sum() == 0
samples['net_input']['src_tokens'][:, -1] = self.new_src_eos
if self.new_tgt_bos is not None:
assert (samples['net_input']['prev_output_tokens'][:, 0] != self.tgt_bos).sum() == 0
samples['net_input']['prev_output_tokens'][:, 0] = self.new_tgt_bos
return samples
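A small illustrative sketch (not from the commit) of what collater does, with made-up token indices; the toy dataset below simply returns a pre-collated batch.
import torch
from fairseq.data import TransformEosLangPairDataset

class _ToyDataset:
    # minimal stand-in whose collater returns an already-collated batch
    def collater(self, samples):
        return {'net_input': {
            'src_tokens': torch.tensor([[13, 14, 2]]),          # 2 = source eos
            'prev_output_tokens': torch.tensor([[2, 21, 22]]),  # decoder input starts with eos
        }}

ds = TransformEosLangPairDataset(_ToyDataset(), src_eos=2, new_src_eos=7,
                                 tgt_bos=2, new_tgt_bos=7)  # 7 = e.g. index of a language token
out = ds.collater([None])
# out['net_input']['src_tokens']         -> tensor([[13, 14, 7]])
# out['net_input']['prev_output_tokens'] -> tensor([[ 7, 21, 22]])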
def get_dummy_batch(self, *args, **kwargs):
return self.dataset.get_dummy_batch(*args, **kwargs)
def num_tokens(self, index):
return self.dataset.num_tokens(index)
def size(self, index):
return self.dataset.size(index)
def ordered_indices(self):
return self.dataset.ordered_indices()
@property
def supports_prefetch(self):
return getattr(self.dataset, 'supports_prefetch', False)
def prefetch(self, indices):
return self.dataset.prefetch(indices)
......@@ -67,8 +67,8 @@ class MultilingualTransformerModel(FairseqMultiModel):
if not hasattr(args, 'max_target_positions'):
args.max_target_positions = 1024
src_langs = [lang_pair.split('-')[0] for lang_pair in task.lang_pairs]
tgt_langs = [lang_pair.split('-')[1] for lang_pair in task.lang_pairs]
src_langs = [lang_pair.split('-')[0] for lang_pair in task.model_lang_pairs]
tgt_langs = [lang_pair.split('-')[1] for lang_pair in task.model_lang_pairs]
if args.share_encoders:
args.share_encoder_embeddings = True
......@@ -158,7 +158,7 @@ class MultilingualTransformerModel(FairseqMultiModel):
shared_decoder = get_decoder(tgt_langs[0])
encoders, decoders = OrderedDict(), OrderedDict()
for lang_pair, src, tgt in zip(task.lang_pairs, src_langs, tgt_langs):
for lang_pair, src, tgt in zip(task.model_lang_pairs, src_langs, tgt_langs):
encoders[lang_pair] = shared_encoder if shared_encoder is not None else get_encoder(src)
decoders[lang_pair] = shared_decoder if shared_decoder is not None else get_decoder(tgt)
......@@ -166,7 +166,7 @@ class MultilingualTransformerModel(FairseqMultiModel):
def load_state_dict(self, state_dict, strict=True):
state_dict_subset = state_dict.copy()
for k, v in state_dict.items():
for k, _ in state_dict.items():
assert k.startswith('models.')
lang_pair = k.split('.')[1]
if lang_pair not in self.models:
......
......@@ -241,6 +241,11 @@ class FairseqTask(object):
with torch.no_grad():
return generator.generate(models, sample, prefix_tokens=prefix_tokens)
def update_step(self, num_updates):
"""Task level update when number of update increases. This is called after optimization step and
learning rate update of each step"""
pass
def grad_denom(self, sample_sizes, criterion):
return criterion.__class__.grad_denom(sample_sizes)
......
......@@ -6,24 +6,41 @@
# can be found in the PATENTS file in the same directory.
from collections import OrderedDict
import copy
import os
import torch
from fairseq import options
from fairseq.data import (
BacktranslationDataset,
Dictionary,
IndexedCachedDataset,
IndexedDataset,
IndexedRawTextDataset,
LanguagePairDataset,
NoisingDataset,
RoundRobinZipDatasets,
TransformEosLangPairDataset,
)
from fairseq.models import FairseqMultiModel
from . import FairseqTask, register_task
def _lang_token(lang: str):
return f'__{lang}__'
def _lang_token_index(dic: Dictionary, lang: str):
"""Return language token index."""
idx = dic.index(_lang_token(lang))
assert idx != dic.unk_index, \
f'cannot find language token for lang {lang}'
return idx
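A small sketch (not part of the commit) of how these helpers behave with a toy Dictionary:
from fairseq.data import Dictionary

d = Dictionary()                   # toy dictionary with only the special symbols
d.add_symbol(_lang_token('de'))    # registers '__de__', as prepare() does below
assert _lang_token('de') == '__de__'
idx = _lang_token_index(d, 'de')   # index of '__de__', guaranteed != unk
# _lang_token_index(d, 'fr') would raise AssertionError, since '__fr__' was never added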
@register_task('multilingual_translation')
class MultilingualTranslationTask(FairseqTask):
"""A task for training multiple translation models simultaneously.
......@@ -71,6 +88,12 @@ class MultilingualTranslationTask(FairseqTask):
help='max number of tokens in the source sequence')
parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
help='max number of tokens in the target sequence')
parser.add_argument('--encoder-langtok', default=None, type=str, choices=['src', 'tgt'],
metavar='SRCTGT',
help='replace beginning-of-sentence in source sentence with source or target '
'language token. (src/tgt)')
parser.add_argument('--decoder-langtok', action='store_true',
help='replace beginning-of-sentence in target sentence with target language token')
# fmt: on
def __init__(self, args, dicts, training):
......@@ -83,40 +106,84 @@ class MultilingualTranslationTask(FairseqTask):
# the eval_lang_pairs class variable is provided for classes that extend
# this class.
self.eval_lang_pairs = args.lang_pairs
# model_lang_pairs will be used to build encoder-decoder model pairs in
# models.build_model(). This allows multitask sub-classes to build models
# for language pairs other than the input lang_pairs.
self.model_lang_pairs = copy.copy(args.lang_pairs)
self.langs = list(dicts.keys())
self.training = training
@classmethod
def setup_task(cls, args, **kwargs):
dicts, training = cls.prepare(args, **kwargs)
return cls(args, dicts, training)
@classmethod
def prepare(cls, args, **kargs):
args.left_pad_source = options.eval_bool(args.left_pad_source)
args.left_pad_target = options.eval_bool(args.left_pad_target)
args.lang_pairs = args.lang_pairs.split(',')
sorted_langs = sorted(list({x for lang_pair in args.lang_pairs for x in lang_pair.split('-')}))
if args.source_lang is not None or args.target_lang is not None:
if args.lang_pairs is not None:
raise ValueError(
'--source-lang/--target-lang implies generation, which is '
'incompatible with --lang-pairs'
)
training = False
args.lang_pairs = ['{}-{}'.format(args.source_lang, args.target_lang)]
else:
training = True
args.lang_pairs = args.lang_pairs.split(',')
args.source_lang, args.target_lang = args.lang_pairs[0].split('-')
langs = list({x for lang_pair in args.lang_pairs for x in lang_pair.split('-')})
# load dictionaries
dicts = OrderedDict()
for lang in langs:
for lang in sorted_langs:
dicts[lang] = Dictionary.load(os.path.join(args.data, 'dict.{}.txt'.format(lang)))
if len(dicts) > 0:
assert dicts[lang].pad() == dicts[langs[0]].pad()
assert dicts[lang].eos() == dicts[langs[0]].eos()
assert dicts[lang].unk() == dicts[langs[0]].unk()
assert dicts[lang].pad() == dicts[sorted_langs[0]].pad()
assert dicts[lang].eos() == dicts[sorted_langs[0]].eos()
assert dicts[lang].unk() == dicts[sorted_langs[0]].unk()
if args.encoder_langtok is not None or args.decoder_langtok:
for lang_to_add in sorted_langs:
dicts[lang].add_symbol(_lang_token(lang_to_add))
print('| [{}] dictionary: {} types'.format(lang, len(dicts[lang])))
return dicts, training
return cls(args, dicts, training)
def get_encoder_langtok(self, src_lang, tgt_lang):
if self.args.encoder_langtok is None:
return self.dicts[src_lang].eos()
if self.args.encoder_langtok == 'src':
return _lang_token_index(self.dicts[src_lang], src_lang)
else:
return _lang_token_index(self.dicts[src_lang], tgt_lang)
def get_decoder_langtok(self, tgt_lang):
if not self.args.decoder_langtok:
return self.dicts[tgt_lang].eos()
return _lang_token_index(self.dicts[tgt_lang], tgt_lang)
def alter_dataset_langtok(self, lang_pair_dataset,
src_eos=None, src_lang=None, tgt_eos=None, tgt_lang=None):
if self.args.encoder_langtok is None and not self.args.decoder_langtok:
return lang_pair_dataset
new_src_eos = None
if self.args.encoder_langtok is not None and src_eos is not None \
and src_lang is not None and tgt_lang is not None:
new_src_eos = self.get_encoder_langtok(src_lang, tgt_lang)
else:
src_eos = None
new_tgt_bos = None
if self.args.decoder_langtok and tgt_eos is not None and tgt_lang is not None:
new_tgt_bos = self.get_decoder_langtok(tgt_lang)
else:
tgt_eos = None
return TransformEosLangPairDataset(
lang_pair_dataset,
src_eos=src_eos,
new_src_eos=new_src_eos,
tgt_bos=tgt_eos,
new_tgt_bos=new_tgt_bos,
)
def load_dataset(self, split, **kwargs):
"""Load a dataset split."""
......@@ -158,13 +225,18 @@ class MultilingualTranslationTask(FairseqTask):
def language_pair_dataset(lang_pair):
src, tgt = lang_pair.split('-')
src_dataset, tgt_dataset = src_datasets[lang_pair], tgt_datasets[lang_pair]
return LanguagePairDataset(
src_dataset, src_dataset.sizes, self.dicts[src],
tgt_dataset, tgt_dataset.sizes, self.dicts[tgt],
left_pad_source=self.args.left_pad_source,
left_pad_target=self.args.left_pad_target,
max_source_positions=self.args.max_source_positions,
max_target_positions=self.args.max_target_positions,
return self.alter_dataset_langtok(
LanguagePairDataset(
src_dataset, src_dataset.sizes, self.dicts[src],
tgt_dataset, tgt_dataset.sizes, self.dicts[tgt],
left_pad_source=self.args.left_pad_source,
left_pad_target=self.args.left_pad_target,
max_source_positions=self.args.max_source_positions,
max_target_positions=self.args.max_target_positions,
),
src_eos=self.dicts[tgt].eos(),
src_lang=src,
tgt_lang=tgt,
)
self.datasets[split] = RoundRobinZipDatasets(
......@@ -178,9 +250,18 @@ class MultilingualTranslationTask(FairseqTask):
def build_dataset_for_inference(self, src_tokens, src_lengths):
lang_pair = "%s-%s" % (self.args.source_lang, self.args.target_lang)
return RoundRobinZipDatasets(
OrderedDict([
(lang_pair, LanguagePairDataset(src_tokens, src_lengths, self.source_dictionary))
]),
OrderedDict([(
lang_pair,
self.alter_dataset_langtok(
LanguagePairDataset(
src_tokens, src_lengths,
self.source_dictionary
),
src_eos=self.source_dictionary.eos(),
src_lang=self.args.source_lang,
tgt_lang=self.args.target_lang,
),
)]),
eval_key=lang_pair,
)
......@@ -212,7 +293,7 @@ class MultilingualTranslationTask(FairseqTask):
with torch.no_grad():
agg_loss, agg_sample_size, agg_logging_output = 0., 0., {}
for lang_pair in self.eval_lang_pairs:
if sample[lang_pair] is None or len(sample[lang_pair]) == 0:
if lang_pair not in sample or sample[lang_pair] is None or len(sample[lang_pair]) == 0:
continue
loss, sample_size, logging_output = criterion(model.models[lang_pair], sample[lang_pair])
agg_loss += loss.data.item()
......@@ -221,6 +302,16 @@ class MultilingualTranslationTask(FairseqTask):
agg_logging_output[lang_pair] = logging_output
return agg_loss, agg_sample_size, agg_logging_output
def inference_step(self, generator, models, sample, prefix_tokens=None):
with torch.no_grad():
return generator.generate(
models,
sample,
prefix_tokens=prefix_tokens,
bos_token=_lang_token_index(self.target_dictionary, self.args.target_lang)
if self.args.decoder_langtok else self.target_dictionary.eos(),
)
def init_logging_output(self, sample):
return {
'ntokens': sum(
......@@ -236,13 +327,14 @@ class MultilingualTranslationTask(FairseqTask):
def grad_denom(self, sample_sizes, criterion):
return criterion.__class__.grad_denom(sample_sizes)
def aggregate_logging_outputs(self, logging_outputs, criterion):
def aggregate_logging_outputs(self, logging_outputs, criterion, logging_output_keys=None):
logging_output_keys = logging_output_keys or self.eval_lang_pairs
# aggregate logging outputs for each language pair
agg_logging_outputs = {
lang_pair: criterion.__class__.aggregate_logging_outputs([
logging_output.get(lang_pair, {}) for logging_output in logging_outputs
key: criterion.__class__.aggregate_logging_outputs([
logging_output.get(key, {}) for logging_output in logging_outputs
])
for lang_pair in self.eval_lang_pairs
for key in logging_output_keys
}
def sum_over_languages(key):
......@@ -269,3 +361,13 @@ class MultilingualTranslationTask(FairseqTask):
@property
def target_dictionary(self):
return self.dicts[self.args.target_lang]
def max_positions(self):
"""Return the max sentence length allowed by the task."""
if len(self.datasets.values()) == 0:
return {'%s-%s' % (self.args.source_lang, self.args.target_lang):
(self.args.max_source_positions, self.args.max_target_positions)}
return OrderedDict([
(key, (self.args.max_source_positions, self.args.max_target_positions))
for key in next(iter(self.datasets.values())).datasets.keys()
])
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from collections import OrderedDict
import os
from fairseq.data import (
BacktranslationDataset,
IndexedCachedDataset,
IndexedDataset,
IndexedRawTextDataset,
LanguagePairDataset,
NoisingDataset,
RoundRobinZipDatasets,
)
from fairseq.models import FairseqMultiModel
from fairseq.sequence_generator import SequenceGenerator
from .multilingual_translation import MultilingualTranslationTask
from . import register_task
def _get_bt_dataset_key(lang_pair):
return "bt:" + lang_pair
def _get_denoising_dataset_key(lang_pair):
return "denoising:" + lang_pair
# ported from UnsupervisedMT
def parse_lambda_config(x):
"""
Parse the configuration of lambda coefficient (for scheduling).
x = "3" # lambda will be a constant equal to x
x = "0:1,1000:0" # lambda will start from 1 and linearly decrease
# to 0 during the first 1000 iterations
x = "0:0,1000:0,2000:1" # lambda will be equal to 0 for the first 1000
# iterations, then will linearly increase to 1 until iteration 2000
"""
split = x.split(',')
if len(split) == 1:
return float(x), None
else:
split = [s.split(':') for s in split]
assert all(len(s) == 2 for s in split)
assert all(k.isdigit() for k, _ in split)
assert all(int(split[i][0]) < int(split[i + 1][0]) for i in range(len(split) - 1))
return float(split[0][1]), [(int(k), float(v)) for k, v in split]
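Worked examples (illustrative) of the accepted syntax:
assert parse_lambda_config("3") == (3.0, None)                              # constant coefficient
assert parse_lambda_config("0:1,1000:0") == (1.0, [(0, 1.0), (1000, 0.0)])
# i.e. start at 1.0 and linearly decay to 0.0 over the first 1000 updates (see update_step)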
@register_task('semisupervised_translation')
class SemisupervisedTranslationTask(MultilingualTranslationTask):
"""A task for training multiple translation models simultaneously.
We iterate round-robin over batches from multiple language pairs, ordered
according to the `--lang-pairs` argument.
The training loop is roughly:
for i in range(len(epoch)):
for lang_pair in args.lang_pairs:
batch = next_batch_for_lang_pair(lang_pair)
loss = criterion(model_for_lang_pair(lang_pair), batch)
loss.backward()
optimizer.step()
In practice, `next_batch_for_lang_pair` is abstracted in a FairseqDataset
(e.g., `RoundRobinZipDatasets`) and `model_for_lang_pair` is a model that
implements the `FairseqMultiModel` interface.
During inference it is required to specify a single `--source-lang` and
`--target-lang`, instead of `--lang-pairs`.
"""
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
# fmt: off
MultilingualTranslationTask.add_args(parser)
parser.add_argument('--lambda-parallel-config', default="1.0", type=str, metavar='CONFIG',
help='cross-entropy reconstruction coefficient (parallel data). '
'use fixed weight during training if set to floating point number. '
'use piecewise linear function over number of updates to schedule the '
'weight with the format: w0:step0,w1:step1,...')
parser.add_argument('--lambda-denoising-config', default="0.0", type=str, metavar='CONFIG',
help='cross-entropy reconstruction coefficient (denoising autoencoding). '
'use fixed weight during training if set to floating point number. '
'use piecewise linear function over number of updates to schedule the '
'weight with the format: w0:step0,w1:step1,...')
parser.add_argument('--lambda-otf-bt-config', default="0.0", type=str, metavar='CONFIG',
help='cross-entropy reconstruction coefficient (on-the-fly back-translation parallel data). '
'use fixed weight during training if set to floating point number. '
'use piecewise linear function over number of updates to schedule the '
'weight with the format: w0:step0,w1:step1,...')
parser.add_argument('--bt-max-len-a', default=1.1, type=float, metavar='N',
help='generate back-translated sequences of maximum length ax + b, where x is the '
'source length')
parser.add_argument('--bt-max-len-b', default=10.0, type=float, metavar='N',
help='generate back-translated sequences of maximum length ax + b, where x is the '
'source length')
parser.add_argument('--bt-beam-size', default=1, type=int, metavar='N',
help='beam size used in beam search of online back-translation')
parser.add_argument('--max-word-shuffle-distance', default=3.0, type=float, metavar='N',
help='maximum word shuffle distance for denoising autoencoding data generation')
parser.add_argument('--word-dropout-prob', default=0.1, type=float, metavar='N',
help='word dropout probability for denoising autoencoding data generation')
parser.add_argument('--word-blanking-prob', default=0.2, type=float, metavar='N',
help='word blanking probability for denoising autoencoding data generation')
# fmt: on
def __init__(self, args, dicts, training):
super().__init__(args, dicts, training)
self.lambda_parallel, self.lambda_parallel_steps = parse_lambda_config(args.lambda_parallel_config)
self.lambda_otf_bt, self.lambda_otf_bt_steps = parse_lambda_config(args.lambda_otf_bt_config)
self.lambda_denoising, self.lambda_denoising_steps = parse_lambda_config(args.lambda_denoising_config)
if (self.lambda_denoising > 0.0 or self.lambda_denoising_steps is not None):
denoising_lang_pairs = [
"%s-%s" % (tgt, tgt)
for tgt in {lang_pair.split('-')[1] for lang_pair in args.lang_pairs}
]
self.model_lang_pairs += denoising_lang_pairs
self.backtranslate_datasets = {}
@classmethod
def setup_task(cls, args, **kwargs):
dicts, training = MultilingualTranslationTask.prepare(args, **kwargs)
return cls(args, dicts, training)
def load_dataset(self, split, **kwargs):
"""Load a dataset split."""
def split_exists(split, src, tgt, lang):
if src is not None:
filename = os.path.join(self.args.data, '{}.{}-{}.{}'.format(split, src, tgt, lang))
else:
filename = os.path.join(self.args.data, '{}.{}-None.{}'.format(split, src, tgt))
if self.args.raw_text and IndexedRawTextDataset.exists(filename):
return True
elif not self.args.raw_text and IndexedDataset.exists(filename):
return True
return False
def indexed_dataset(path, dictionary):
if self.args.raw_text:
return IndexedRawTextDataset(path, dictionary)
elif IndexedDataset.exists(path):
if self.args.lazy_load:
return IndexedDataset(path, fix_lua_indexing=True)
else:
return IndexedCachedDataset(path, fix_lua_indexing=True)
return None
# load parallel datasets
src_datasets, tgt_datasets = {}, {}
if (self.lambda_parallel > 0.0 or self.lambda_parallel_steps is not None or not split.startswith("train")):
for lang_pair in self.args.lang_pairs:
src, tgt = lang_pair.split('-')
if split_exists(split, src, tgt, src):
prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split, src, tgt))
elif split_exists(split, tgt, src, src):
prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split, tgt, src))
else:
continue
src_datasets[lang_pair] = indexed_dataset(prefix + src, self.dicts[src])
tgt_datasets[lang_pair] = indexed_dataset(prefix + tgt, self.dicts[tgt])
print('| parallel-{} {} {} examples'.format(self.args.data, split, len(src_datasets[lang_pair])))
if len(src_datasets) == 0:
raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))
# back translation datasets
backtranslate_datasets = {}
if (self.lambda_otf_bt > 0.0 or self.lambda_otf_bt_steps is not None) and split.startswith("train"):
for lang_pair in self.args.lang_pairs:
src, tgt = lang_pair.split('-')
if not split_exists(split, tgt, None, tgt):
raise FileNotFoundError('Dataset not found: backtranslation {} ({})'.format(split, self.args.data))
filename = os.path.join(self.args.data, '{}.{}-None.{}'.format(split, tgt, tgt))
dataset = indexed_dataset(filename, self.dicts[tgt])
lang_pair_dataset_tgt = LanguagePairDataset(
dataset,
dataset.sizes,
self.dicts[tgt],
left_pad_source=self.args.left_pad_source,
left_pad_target=self.args.left_pad_target,
)
lang_pair_dataset = LanguagePairDataset(
dataset,
dataset.sizes,
src_dict=self.dicts[src],
tgt=dataset,
tgt_sizes=dataset.sizes,
tgt_dict=self.dicts[tgt],
left_pad_source=self.args.left_pad_source,
left_pad_target=self.args.left_pad_target,
)
backtranslate_datasets[lang_pair] = BacktranslationDataset(
tgt_dataset=self.alter_dataset_langtok(
lang_pair_dataset_tgt,
src_eos=self.dicts[tgt].eos(),
src_lang=tgt,
tgt_lang=src,
),
src_dict=self.dicts[src], tgt_dict=self.dicts[tgt],
output_collater=self.alter_dataset_langtok(
lang_pair_dataset=lang_pair_dataset,
src_eos=self.dicts[src].eos(),
src_lang=src,
tgt_eos=self.dicts[tgt].eos(),
tgt_lang=tgt,
).collater,
)
print('| backtranslate-{}: {} {} {} examples'.format(
tgt, self.args.data, split, len(backtranslate_datasets[lang_pair]),
))
self.backtranslate_datasets[lang_pair] = backtranslate_datasets[lang_pair]
# denoising autoencoder
noising_datasets = {}
if (self.lambda_denoising > 0.0 or self.lambda_denoising_steps is not None) and split.startswith("train"):
for lang_pair in self.args.lang_pairs:
_, tgt = lang_pair.split('-')
if not split_exists(split, tgt, None, tgt):
continue
filename = os.path.join(self.args.data, '{}.{}-None.{}'.format(split, tgt, tgt))
tgt_dataset1 = indexed_dataset(filename, self.dicts[tgt])
tgt_dataset2 = indexed_dataset(filename, self.dicts[tgt])
noising_dataset = NoisingDataset(
tgt_dataset1,
self.dicts[tgt],
seed=1,
max_word_shuffle_distance=self.args.max_word_shuffle_distance,
word_dropout_prob=self.args.word_dropout_prob,
word_blanking_prob=self.args.word_blanking_prob,
)
noising_datasets[lang_pair] = self.alter_dataset_langtok(
LanguagePairDataset(
noising_dataset,
tgt_dataset1.sizes,
self.dicts[tgt],
tgt_dataset2,
tgt_dataset2.sizes,
self.dicts[tgt],
left_pad_source=self.args.left_pad_source,
left_pad_target=self.args.left_pad_target,
),
src_eos=self.dicts[tgt].eos(),
src_lang=tgt,
tgt_eos=self.dicts[tgt].eos(),
tgt_lang=tgt,
)
print('| denoising-{}: {} {} {} examples'.format(
tgt, self.args.data, split, len(noising_datasets[lang_pair]),
))
def language_pair_dataset(lang_pair):
src, tgt = lang_pair.split('-')
src_dataset, tgt_dataset = src_datasets[lang_pair], tgt_datasets[lang_pair]
return self.alter_dataset_langtok(
LanguagePairDataset(
src_dataset, src_dataset.sizes, self.dicts[src],
tgt_dataset, tgt_dataset.sizes, self.dicts[tgt],
left_pad_source=self.args.left_pad_source,
left_pad_target=self.args.left_pad_target,
max_source_positions=self.args.max_source_positions,
max_target_positions=self.args.max_target_positions,
),
self.dicts[src].eos(),
src,
self.dicts[tgt].eos(),
tgt,
)
self.datasets[split] = RoundRobinZipDatasets(
OrderedDict([
(lang_pair, language_pair_dataset(lang_pair))
for lang_pair in src_datasets.keys()
] + [
(_get_bt_dataset_key(lang_pair), dataset)
for lang_pair, dataset in backtranslate_datasets.items()
] + [
(_get_denoising_dataset_key(lang_pair), dataset)
for lang_pair, dataset in noising_datasets.items()
]),
eval_key=None if self.training else "%s-%s" % (self.args.source_lang, self.args.target_lang),
)
def build_model(self, args):
from fairseq import models
model = models.build_model(args, self)
if not isinstance(model, FairseqMultiModel):
raise ValueError('SemisupervisedTranslationTask requires a FairseqMultiModel architecture')
# create SequenceGenerator for each model that has backtranslation dependency on it
self.sequence_generators = {}
if (self.lambda_otf_bt > 0.0 or self.lambda_otf_bt_steps is not None) and self.training:
for lang_pair in self.args.lang_pairs:
src, tgt = lang_pair.split('-')
key = '{}-{}'.format(tgt, src)
self.sequence_generators[key] = SequenceGenerator(
tgt_dict=self.dicts[src],
beam_size=args.bt_beam_size,
max_len_a=args.bt_max_len_a,
max_len_b=args.bt_max_len_b,
)
decoder_lang_tok_idx = self.get_decoder_langtok(src)
def backtranslate_fn(
sample, model=model.models[key],
bos_token=decoder_lang_tok_idx,
sequence_generator=self.sequence_generators[key],
):
return sequence_generator.generate(
[model],
sample,
bos_token=bos_token,
)
self.backtranslate_datasets[lang_pair].set_backtranslation_fn(backtranslate_fn)
return model
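A side note (illustration, not commit code) on why backtranslate_fn above binds model, bos_token, and sequence_generator through default arguments:
# Default arguments freeze the per-language-pair objects at definition time. A plain
# closure would late-bind the loop variables, so every pair would end up using the
# objects from the last iteration:
fns_late = [lambda: i for i in range(3)]
fns_bound = [lambda i=i: i for i in range(3)]
assert [f() for f in fns_late] == [2, 2, 2]
assert [f() for f in fns_bound] == [0, 1, 2]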
def train_step(self, sample, model, criterion, optimizer, ignore_grad=False):
model.train()
agg_loss, agg_sample_size, agg_logging_output = 0., 0., {}
def forward_backward(model, samples, logging_output_key, weight):
nonlocal agg_loss, agg_sample_size, agg_logging_output
if samples is None or len(samples) == 0:
return
loss, sample_size, logging_output = criterion(model, samples)
if ignore_grad:
loss *= 0
else:
loss *= weight
optimizer.backward(loss)
agg_loss += loss.detach().item()
# TODO make summing of the sample sizes configurable
agg_sample_size += sample_size
agg_logging_output[logging_output_key] = logging_output
if self.lambda_parallel > 0.0:
for lang_pair in self.args.lang_pairs:
forward_backward(model.models[lang_pair], sample[lang_pair], lang_pair, self.lambda_parallel)
if self.lambda_otf_bt > 0.0:
for lang_pair in self.args.lang_pairs:
sample_key = _get_bt_dataset_key(lang_pair)
forward_backward(model.models[lang_pair], sample[sample_key], sample_key, self.lambda_otf_bt)
if self.lambda_denoising > 0.0:
for lang_pair in self.args.lang_pairs:
_, tgt = lang_pair.split('-')
sample_key = _get_denoising_dataset_key(lang_pair)
forward_backward(model.models[f'{tgt}-{tgt}'], sample[sample_key], sample_key, self.lambda_denoising)
return agg_loss, agg_sample_size, agg_logging_output
def update_step(self, num_updates):
def lambda_step_func(config, n_iter):
"""
Update a lambda value according to its schedule configuration.
"""
ranges = [i for i in range(len(config) - 1) if config[i][0] <= n_iter < config[i + 1][0]]
if len(ranges) == 0:
assert n_iter >= config[-1][0]
return config[-1][1]
assert len(ranges) == 1
i = ranges[0]
x_a, y_a = config[i]
x_b, y_b = config[i + 1]
return y_a + (n_iter - x_a) * float(y_b - y_a) / float(x_b - x_a)
if self.lambda_parallel_steps is not None:
self.lambda_parallel = lambda_step_func(self.lambda_parallel_steps, num_updates)
if self.lambda_denoising_steps is not None:
self.lambda_denoising = lambda_step_func(self.lambda_denoising_steps, num_updates)
if self.lambda_otf_bt_steps is not None:
self.lambda_otf_bt = lambda_step_func(self.lambda_otf_bt_steps, num_updates)
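An illustrative trace (recomputed standalone, not commit code) of the piecewise-linear schedule, e.g. for --lambda-otf-bt-config "0:0,1000:1":
config = [(0, 0.0), (1000, 1.0)]   # the schedule part returned by parse_lambda_config("0:0,1000:1")
for n_iter in (0, 250, 1000):
    (x_a, y_a), (x_b, y_b) = config
    val = y_b if n_iter >= x_b else y_a + (n_iter - x_a) * (y_b - y_a) / (x_b - x_a)
    print(n_iter, val)             # 0 -> 0.0, 250 -> 0.25, 1000 -> 1.0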
def aggregate_logging_outputs(self, logging_outputs, criterion):
# aggregate logging outputs for each language pair
logging_output_keys = {
key
for logging_output in logging_outputs
for key in logging_output
}
lang_pair_keys = set(self.args.lang_pairs + [
_get_bt_dataset_key(lang_pair)
for lang_pair in self.args.lang_pairs
] + [
_get_denoising_dataset_key(lang_pair)
for lang_pair in self.args.lang_pairs
])
logging_output_keys = logging_output_keys.intersection(lang_pair_keys)
return super().aggregate_logging_outputs(logging_outputs, criterion, logging_output_keys)
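For example (illustrative), with --lang-pairs de-en,en-de the candidate keys are:
lang_pairs = ['de-en', 'en-de']
keys = set(lang_pairs
           + [_get_bt_dataset_key(p) for p in lang_pairs]
           + [_get_denoising_dataset_key(p) for p in lang_pairs])
# {'de-en', 'en-de', 'bt:de-en', 'bt:en-de', 'denoising:de-en', 'denoising:en-de'};
# only the subset actually present in logging_outputs is forwarded to the parent task.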
......@@ -259,6 +259,9 @@ class Trainer(object):
# update learning rate
self.lr_scheduler.step_update(self._num_updates)
# task specific update per step
self.task.update_step(self._num_updates)
# update meters
ntokens = logging_output.get('ntokens', 0)
nsentences = logging_output.get('nsentences', 0)
......
......@@ -6,6 +6,7 @@
# can be found in the PATENTS file in the same directory.
from collections import defaultdict, OrderedDict
import copy
import importlib.util
import logging
import os
......@@ -417,6 +418,15 @@ def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'):
def resolve_max_positions(*args):
"""Resolve max position constraints from multiple sources."""
def map_value_update(d1, d2):
updated_value = copy.deepcopy(d1)
for key in d2:
if key not in updated_value:
updated_value[key] = d2[key]
else:
updated_value[key] = min(d1[key], d2[key])
return updated_value
def nullsafe_min(l):
minim = None
for item in l:
......@@ -433,10 +443,13 @@ def resolve_max_positions(*args):
elif arg is not None:
if isinstance(arg, float) or isinstance(arg, int):
max_positions = min(max_positions, arg)
elif isinstance(arg, dict):
max_positions = map_value_update(max_positions, arg)
else:
max_positions = tuple(
map(nullsafe_min, zip(max_positions, arg))
)
return max_positions
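A small sketch (not commit code) of how the new dict branch merges per-language-pair limits; the example dicts are hypothetical.
from fairseq.utils import resolve_max_positions

task_limits = {'de-en': (1024, 1024), 'bt:de-en': (1024, 1024)}  # e.g. from task.max_positions()
model_limits = {'de-en': (512, 512)}                             # hypothetical model constraint
resolve_max_positions(task_limits, model_limits)
# -> {'de-en': (512, 512), 'bt:de-en': (1024, 1024)}  (per-key minimum, union of keys)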
......
......@@ -59,8 +59,9 @@ class TestBacktranslationDataset(unittest.TestCase):
# remove eos from the input src
remove_eos_from_src=remove_eos_from_input_src,
),
src_dict=self.tgt_dict,
backtranslation_fn=(
lambda net_input: generator.generate([self.model], {'net_input': net_input})
lambda sample: generator.generate([self.model], sample)
),
output_collater=TransformEosDataset(
dataset=tgt_dataset,
......