Commit d7e19573 authored by Peng-Jen Chen, committed by Facebook Github Bot

Back translation + denoising in MultilingualTranslation task (#620)

Summary:
- Add language token to MultilingualTranslation task
- Add back translation and denoising loss to MultilingualTranslation task
Pull Request resolved: https://github.com/pytorch/fairseq/pull/620

Reviewed By: liezl200

Differential Revision: D14756873

Pulled By: pipibjc

fbshipit-source-id: 89d668db26848fd95f446edf5923bab2113636f7
parent c2820af0
@@ -13,9 +13,11 @@ from .indexed_dataset import IndexedCachedDataset, IndexedDataset, IndexedRawTex
 from .language_pair_dataset import LanguagePairDataset
 from .lm_context_window_dataset import LMContextWindowDataset
 from .monolingual_dataset import MonolingualDataset
+from .noising import NoisingDataset
 from .round_robin_zip_datasets import RoundRobinZipDatasets
 from .token_block_dataset import TokenBlockDataset
 from .transform_eos_dataset import TransformEosDataset
+from .transform_eos_lang_pair_dataset import TransformEosLangPairDataset

 from .iterators import (
     CountingIterator,
@@ -38,8 +40,10 @@ __all__ = [
     'LanguagePairDataset',
     'LMContextWindowDataset',
     'MonolingualDataset',
+    'NoisingDataset',
     'RoundRobinZipDatasets',
     'ShardedIterator',
     'TokenBlockDataset',
     'TransformEosDataset',
+    'TransformEosLangPairDataset',
 ]
@@ -10,6 +10,7 @@ import torch
 from fairseq import utils

 from . import FairseqDataset
+from .language_pair_dataset import collate as language_pair_collate, generate_dummy_batch


 def backtranslate_samples(samples, collate_fn, generate_fn, cuda=True):
@@ -36,22 +37,18 @@ def backtranslate_samples(samples, collate_fn, generate_fn, cuda=True):
     """
     collated_samples = collate_fn(samples)
     s = utils.move_to_cuda(collated_samples) if cuda else collated_samples
-    generated_sources = generate_fn(s['net_input'])
+    generated_sources = generate_fn(s)

-    def update_sample(sample, generated_source):
-        sample['target'] = sample['source']  # the original source becomes the target
-        sample['source'] = generated_source
-        return sample
+    id_to_src = {
+        sample['id']: sample['source'] for sample in samples
+    }

     # Go through each tgt sentence in batch and its corresponding best
     # generated hypothesis and create a backtranslation data pair
     # {id: id, source: generated backtranslation, target: original tgt}
     return [
-        update_sample(
-            sample=input_sample,
-            generated_source=hypos[0]['tokens'].cpu(),  # highest scoring hypo is first
-        )
-        for input_sample, hypos in zip(samples, generated_sources)
+        {'id': id.item(), 'target': id_to_src[id.item()], 'source': hypos[0]['tokens'].cpu()}
+        for id, hypos in zip(collated_samples['id'], generated_sources)
     ]
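
# Illustration (not part of the diff): backtranslate_samples() swaps roles, so a
# monolingual target sample comes back as a (generated source, original target)
# pair, e.g. with hypothetical token tensors:
#   in:  {'id': 3, 'source': tgt_tokens}
#   out: {'id': 3, 'source': generated_hypo_tokens, 'target': tgt_tokens}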
@@ -66,9 +63,15 @@ class BacktranslationDataset(FairseqDataset):
             backtranslated. Only the source side of this dataset will be used.
             After backtranslation, the source sentences in this dataset will be
             returned as the targets.
-        backtranslation_fn (callable): function to call to generate
+        src_dict (~fairseq.data.Dictionary): the dictionary of backtranslated
+            sentences.
+        tgt_dict (~fairseq.data.Dictionary, optional): the dictionary of
+            sentences to be backtranslated.
+        backtranslation_fn (callable, optional): function to call to generate
             backtranslations. This is typically the `generate` method of a
             :class:`~fairseq.sequence_generator.SequenceGenerator` object.
+            Pass in None when it is not available at initialization time, and
+            use set_backtranslation_fn() to set it when available.
         output_collater (callable, optional): function to call on the
             backtranslated samples to create the final batch
             (default: ``tgt_dataset.collater``).
@@ -78,7 +81,9 @@ class BacktranslationDataset(FairseqDataset):
     def __init__(
         self,
         tgt_dataset,
-        backtranslation_fn,
+        src_dict,
+        tgt_dict=None,
+        backtranslation_fn=None,
         output_collater=None,
         cuda=True,
         **kwargs
@@ -88,6 +93,8 @@ class BacktranslationDataset(FairseqDataset):
         self.output_collater = output_collater if output_collater is not None \
             else tgt_dataset.collater
         self.cuda = cuda if torch.cuda.is_available() else False
+        self.src_dict = src_dict
+        self.tgt_dict = tgt_dict

     def __getitem__(self, index):
         """
@@ -100,6 +107,9 @@ class BacktranslationDataset(FairseqDataset):
     def __len__(self):
         return len(self.tgt_dataset)

+    def set_backtranslation_fn(self, backtranslation_fn):
+        self.backtranslation_fn = backtranslation_fn
+
     def collater(self, samples):
         """Merge and backtranslate a list of samples to form a mini-batch.
@@ -119,6 +129,8 @@ class BacktranslationDataset(FairseqDataset):
         Returns:
             dict: a mini-batch with keys coming from *output_collater*
         """
+        if samples[0].get('is_dummy', False):
+            return samples
         samples = backtranslate_samples(
             samples=samples,
             collate_fn=self.tgt_dataset.collater,
@@ -131,7 +143,16 @@ class BacktranslationDataset(FairseqDataset):
     def get_dummy_batch(self, num_tokens, max_positions):
         """Just use the tgt dataset get_dummy_batch"""
-        return self.tgt_dataset.get_dummy_batch(num_tokens, max_positions)
+        def collate_fn(samples):
+            return language_pair_collate(
+                samples, pad_idx=self.src_dict.pad(), eos_idx=self.src_dict.eos(),
+                input_feeding=True,
+            )
+        dummy_batch = generate_dummy_batch(
+            num_tokens, collate_fn,
+            self.src_dict, tgt_dict=self.tgt_dict)
+        dummy_batch['is_dummy'] = True
+        return dummy_batch

     def num_tokens(self, index):
         """Just use the tgt dataset num_tokens"""
@@ -68,6 +68,19 @@ def collate(
     return batch


+def generate_dummy_batch(num_tokens, collate_fn, src_dict, src_len=128, tgt_dict=None, tgt_len=128):
+    """Return a dummy batch with a given number of tokens."""
+    bsz = num_tokens // max(src_len, tgt_len)
+    return collate_fn([
+        {
+            'id': i,
+            'source': src_dict.dummy_sentence(src_len),
+            'target': tgt_dict.dummy_sentence(tgt_len) if tgt_dict is not None else None,
+        }
+        for i in range(bsz)
+    ])
+
+
 class LanguagePairDataset(FairseqDataset):
     """
     A pair of torch.utils.data.Datasets.
@@ -192,15 +205,7 @@ class LanguagePairDataset(FairseqDataset):
             max_positions,
             (self.max_source_positions, self.max_target_positions),
         )
-        bsz = max(num_tokens // max(src_len, tgt_len), 1)
-        return self.collater([
-            {
-                'id': i,
-                'source': self.src_dict.dummy_sentence(src_len),
-                'target': self.tgt_dict.dummy_sentence(tgt_len) if self.tgt_dict is not None else None,
-            }
-            for i in range(bsz)
-        ])
+        return generate_dummy_batch(num_tokens, self.collater, self.src_dict, src_len, self.tgt_dict, tgt_len)

     def num_tokens(self, index):
         """Return the number of tokens in a sample. This value is used to
@@ -227,9 +232,10 @@ class LanguagePairDataset(FairseqDataset):
     def supports_prefetch(self):
         return (
             getattr(self.src, 'supports_prefetch', False)
-            and getattr(self.tgt, 'supports_prefetch', False)
+            and (getattr(self.tgt, 'supports_prefetch', False) or self.tgt is None)
         )

     def prefetch(self, indices):
         self.src.prefetch(indices)
-        self.tgt.prefetch(indices)
+        if self.tgt is not None:
+            self.tgt.prefetch(indices)
@@ -301,3 +301,11 @@ class NoisingDataset(torch.utils.data.Dataset):
         The length of the noising dataset is the length of src.
         """
         return len(self.src_dataset)
+
+    @property
+    def supports_prefetch(self):
+        return self.src_dataset.supports_prefetch
+
+    def prefetch(self, indices):
+        if self.src_dataset.supports_prefetch:
+            self.src_dataset.prefetch(indices)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from . import FairseqDataset
from typing import Optional
class TransformEosLangPairDataset(FairseqDataset):
"""A :class:`~fairseq.data.FairseqDataset` wrapper that transform bos on
collated samples of language pair dataset.
Note that the transformation is applied in :func:`collater`.
Args:
dataset (~fairseq.data.FairseqDataset): dataset that collates sample into
LanguagePairDataset schema
src_eos (int): original source end-of-sentence symbol index to be replaced
new_src_eos (int, optional): new end-of-sentence symbol index to replace source eos symbol
tgt_bos (int, optional): original target beginning-of-sentence symbol index to be replaced
new_tgt_bos (int, optional): new beginning-of-sentence symbol index to replace at the
beginning of 'prev_output_tokens'
"""
def __init__(
self,
dataset: FairseqDataset,
src_eos: int,
new_src_eos: Optional[int] = None,
tgt_bos: Optional[int] = None,
new_tgt_bos: Optional[int] = None,
):
self.dataset = dataset
self.src_eos = src_eos
self.new_src_eos = new_src_eos
self.tgt_bos = tgt_bos
self.new_tgt_bos = new_tgt_bos
def __getitem__(self, index):
return self.dataset[index]
def __len__(self):
return len(self.dataset)
def collater(self, samples):
samples = self.dataset.collater(samples)
# TODO: support different padding direction
if self.new_src_eos is not None:
            assert (samples['net_input']['src_tokens'][:, -1] != self.src_eos).sum() == 0
samples['net_input']['src_tokens'][:, -1] = self.new_src_eos
if self.new_tgt_bos is not None:
assert (samples['net_input']['prev_output_tokens'][:, 0] != self.tgt_bos).sum() == 0
samples['net_input']['prev_output_tokens'][:, 0] = self.new_tgt_bos
return samples
def get_dummy_batch(self, *args, **kwargs):
return self.dataset.get_dummy_batch(*args, **kwargs)
def num_tokens(self, index):
return self.dataset.num_tokens(index)
def size(self, index):
return self.dataset.size(index)
def ordered_indices(self):
return self.dataset.ordered_indices()
@property
def supports_prefetch(self):
return getattr(self.dataset, 'supports_prefetch', False)
def prefetch(self, indices):
return self.dataset.prefetch(indices)
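
# Illustration (not part of the diff): a minimal sketch of what collater() does
# to an already-collated batch, assuming left-padded sources (fairseq's
# default), eos index 2, and a hypothetical language-token index 100.
#
#   import torch
#   src_tokens = torch.tensor([[1, 5, 7, 2], [4, 6, 8, 2]])    # rows end in eos
#   prev_output_tokens = torch.tensor([[2, 9, 4], [2, 9, 5]])  # rows start with eos
#   src_tokens[:, -1] = 100         # new_src_eos: source now ends in the lang token
#   prev_output_tokens[:, 0] = 100  # new_tgt_bos: decoder is primed with the lang token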
@@ -67,8 +67,8 @@ class MultilingualTransformerModel(FairseqMultiModel):
         if not hasattr(args, 'max_target_positions'):
             args.max_target_positions = 1024

-        src_langs = [lang_pair.split('-')[0] for lang_pair in task.lang_pairs]
-        tgt_langs = [lang_pair.split('-')[1] for lang_pair in task.lang_pairs]
+        src_langs = [lang_pair.split('-')[0] for lang_pair in task.model_lang_pairs]
+        tgt_langs = [lang_pair.split('-')[1] for lang_pair in task.model_lang_pairs]

         if args.share_encoders:
             args.share_encoder_embeddings = True
@@ -158,7 +158,7 @@ class MultilingualTransformerModel(FairseqMultiModel):
             shared_decoder = get_decoder(tgt_langs[0])

         encoders, decoders = OrderedDict(), OrderedDict()
-        for lang_pair, src, tgt in zip(task.lang_pairs, src_langs, tgt_langs):
+        for lang_pair, src, tgt in zip(task.model_lang_pairs, src_langs, tgt_langs):
             encoders[lang_pair] = shared_encoder if shared_encoder is not None else get_encoder(src)
             decoders[lang_pair] = shared_decoder if shared_decoder is not None else get_decoder(tgt)
@@ -166,7 +166,7 @@ class MultilingualTransformerModel(FairseqMultiModel):
     def load_state_dict(self, state_dict, strict=True):
         state_dict_subset = state_dict.copy()
-        for k, v in state_dict.items():
+        for k, _ in state_dict.items():
             assert k.startswith('models.')
             lang_pair = k.split('.')[1]
             if lang_pair not in self.models:
@@ -241,6 +241,11 @@ class FairseqTask(object):
         with torch.no_grad():
             return generator.generate(models, sample, prefix_tokens=prefix_tokens)

+    def update_step(self, num_updates):
+        """Task-level update when the number of updates increases. This is called
+        after the optimization step and learning rate update of each step."""
+        pass
+
     def grad_denom(self, sample_sizes, criterion):
         return criterion.__class__.grad_denom(sample_sizes)
@@ -6,24 +6,41 @@
 # can be found in the PATENTS file in the same directory.

 from collections import OrderedDict
+import copy
 import os

 import torch

 from fairseq import options
 from fairseq.data import (
+    BacktranslationDataset,
     Dictionary,
     IndexedCachedDataset,
     IndexedDataset,
     IndexedRawTextDataset,
     LanguagePairDataset,
+    NoisingDataset,
     RoundRobinZipDatasets,
+    TransformEosLangPairDataset,
 )
 from fairseq.models import FairseqMultiModel

 from . import FairseqTask, register_task


+def _lang_token(lang: str):
+    return f'__{lang}__'
+
+
+def _lang_token_index(dic: Dictionary, lang: str):
+    """Return language token index."""
+    idx = dic.index(_lang_token(lang))
+    assert idx != dic.unk_index, \
+        f'cannot find language token for lang {lang}'
+    return idx
+
+
 @register_task('multilingual_translation')
 class MultilingualTranslationTask(FairseqTask):
     """A task for training multiple translation models simultaneously.
@@ -71,6 +88,12 @@ class MultilingualTranslationTask(FairseqTask):
                             help='max number of tokens in the source sequence')
         parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
                             help='max number of tokens in the target sequence')
+        parser.add_argument('--encoder-langtok', default=None, type=str, choices=['src', 'tgt'],
+                            metavar='SRCTGT',
+                            help='replace beginning-of-sentence in source sentence with source or target '
+                                 'language token. (src/tgt)')
+        parser.add_argument('--decoder-langtok', action='store_true',
+                            help='replace beginning-of-sentence in target sentence with target language token')
         # fmt: on

     def __init__(self, args, dicts, training):
@@ -83,40 +106,84 @@ class MultilingualTranslationTask(FairseqTask):
         # the eval_lang_pairs class variable is provided for classes that extend
         # this class.
         self.eval_lang_pairs = args.lang_pairs
+        # model_lang_pairs will be used to build encoder-decoder model pairs in
+        # models.build_model(). This allows multitask type of sub-class can
+        # build models other than the input lang_pairs
+        self.model_lang_pairs = copy.copy(args.lang_pairs)
         self.langs = list(dicts.keys())
         self.training = training

     @classmethod
     def setup_task(cls, args, **kwargs):
+        dicts, training = cls.prepare(args, **kwargs)
+        return cls(args, dicts, training)
+
+    @classmethod
+    def prepare(cls, args, **kargs):
         args.left_pad_source = options.eval_bool(args.left_pad_source)
         args.left_pad_target = options.eval_bool(args.left_pad_target)
+        args.lang_pairs = args.lang_pairs.split(',')
+        sorted_langs = sorted(list({x for lang_pair in args.lang_pairs for x in lang_pair.split('-')}))
         if args.source_lang is not None or args.target_lang is not None:
-            if args.lang_pairs is not None:
-                raise ValueError(
-                    '--source-lang/--target-lang implies generation, which is '
-                    'incompatible with --lang-pairs'
-                )
             training = False
             args.lang_pairs = ['{}-{}'.format(args.source_lang, args.target_lang)]
         else:
             training = True
-            args.lang_pairs = args.lang_pairs.split(',')
             args.source_lang, args.target_lang = args.lang_pairs[0].split('-')
-        langs = list({x for lang_pair in args.lang_pairs for x in lang_pair.split('-')})

         # load dictionaries
         dicts = OrderedDict()
-        for lang in langs:
+        for lang in sorted_langs:
             dicts[lang] = Dictionary.load(os.path.join(args.data, 'dict.{}.txt'.format(lang)))
             if len(dicts) > 0:
-                assert dicts[lang].pad() == dicts[langs[0]].pad()
-                assert dicts[lang].eos() == dicts[langs[0]].eos()
-                assert dicts[lang].unk() == dicts[langs[0]].unk()
+                assert dicts[lang].pad() == dicts[sorted_langs[0]].pad()
+                assert dicts[lang].eos() == dicts[sorted_langs[0]].eos()
+                assert dicts[lang].unk() == dicts[sorted_langs[0]].unk()
+            if args.encoder_langtok is not None or args.decoder_langtok:
+                for lang_to_add in sorted_langs:
+                    dicts[lang].add_symbol(_lang_token(lang_to_add))
             print('| [{}] dictionary: {} types'.format(lang, len(dicts[lang])))
+        return dicts, training

-        return cls(args, dicts, training)
+    def get_encoder_langtok(self, src_lang, tgt_lang):
+        if self.args.encoder_langtok is None:
+            return self.dicts[src_lang].eos()
+        if self.args.encoder_langtok == 'src':
+            return _lang_token_index(self.dicts[src_lang], src_lang)
+        else:
+            return _lang_token_index(self.dicts[src_lang], tgt_lang)
+
+    def get_decoder_langtok(self, tgt_lang):
+        if not self.args.decoder_langtok:
+            return self.dicts[tgt_lang].eos()
+        return _lang_token_index(self.dicts[tgt_lang], tgt_lang)
+
+    def alter_dataset_langtok(self, lang_pair_dataset,
+                              src_eos=None, src_lang=None, tgt_eos=None, tgt_lang=None):
+        if self.args.encoder_langtok is None and not self.args.decoder_langtok:
+            return lang_pair_dataset
+
+        new_src_eos = None
+        if self.args.encoder_langtok is not None and src_eos is not None \
+                and src_lang is not None and tgt_lang is not None:
+            new_src_eos = self.get_encoder_langtok(src_lang, tgt_lang)
+        else:
+            src_eos = None
+
+        new_tgt_bos = None
+        if self.args.decoder_langtok and tgt_eos is not None and tgt_lang is not None:
+            new_tgt_bos = self.get_decoder_langtok(tgt_lang)
+        else:
+            tgt_eos = None
+
+        return TransformEosLangPairDataset(
+            lang_pair_dataset,
+            src_eos=src_eos,
+            new_src_eos=new_src_eos,
+            tgt_bos=tgt_eos,
+            new_tgt_bos=new_tgt_bos,
+        )

     def load_dataset(self, split, **kwargs):
         """Load a dataset split."""
@@ -158,13 +225,18 @@ class MultilingualTranslationTask(FairseqTask):
         def language_pair_dataset(lang_pair):
             src, tgt = lang_pair.split('-')
             src_dataset, tgt_dataset = src_datasets[lang_pair], tgt_datasets[lang_pair]
-            return LanguagePairDataset(
-                src_dataset, src_dataset.sizes, self.dicts[src],
-                tgt_dataset, tgt_dataset.sizes, self.dicts[tgt],
-                left_pad_source=self.args.left_pad_source,
-                left_pad_target=self.args.left_pad_target,
-                max_source_positions=self.args.max_source_positions,
-                max_target_positions=self.args.max_target_positions,
+            return self.alter_dataset_langtok(
+                LanguagePairDataset(
+                    src_dataset, src_dataset.sizes, self.dicts[src],
+                    tgt_dataset, tgt_dataset.sizes, self.dicts[tgt],
+                    left_pad_source=self.args.left_pad_source,
+                    left_pad_target=self.args.left_pad_target,
+                    max_source_positions=self.args.max_source_positions,
+                    max_target_positions=self.args.max_target_positions,
+                ),
+                src_eos=self.dicts[tgt].eos(),
+                src_lang=src,
+                tgt_lang=tgt,
             )

         self.datasets[split] = RoundRobinZipDatasets(
@@ -178,9 +250,18 @@ class MultilingualTranslationTask(FairseqTask):
     def build_dataset_for_inference(self, src_tokens, src_lengths):
         lang_pair = "%s-%s" % (self.args.source_lang, self.args.target_lang)
         return RoundRobinZipDatasets(
-            OrderedDict([
-                (lang_pair, LanguagePairDataset(src_tokens, src_lengths, self.source_dictionary))
-            ]),
+            OrderedDict([(
+                lang_pair,
+                self.alter_dataset_langtok(
+                    LanguagePairDataset(
+                        src_tokens, src_lengths,
+                        self.source_dictionary
+                    ),
+                    src_eos=self.source_dictionary.eos(),
+                    src_lang=self.args.source_lang,
+                    tgt_lang=self.args.target_lang,
+                ),
+            )]),
             eval_key=lang_pair,
         )
@@ -212,7 +293,7 @@ class MultilingualTranslationTask(FairseqTask):
         with torch.no_grad():
             agg_loss, agg_sample_size, agg_logging_output = 0., 0., {}
             for lang_pair in self.eval_lang_pairs:
-                if sample[lang_pair] is None or len(sample[lang_pair]) == 0:
+                if lang_pair not in sample or sample[lang_pair] is None or len(sample[lang_pair]) == 0:
                     continue
                 loss, sample_size, logging_output = criterion(model.models[lang_pair], sample[lang_pair])
                 agg_loss += loss.data.item()
@@ -221,6 +302,16 @@ class MultilingualTranslationTask(FairseqTask):
                 agg_logging_output[lang_pair] = logging_output
         return agg_loss, agg_sample_size, agg_logging_output

+    def inference_step(self, generator, models, sample, prefix_tokens=None):
+        with torch.no_grad():
+            return generator.generate(
+                models,
+                sample,
+                prefix_tokens=prefix_tokens,
+                bos_token=_lang_token_index(self.target_dictionary, self.args.target_lang)
+                if self.args.decoder_langtok else self.target_dictionary.eos(),
+            )
+
     def init_logging_output(self, sample):
         return {
             'ntokens': sum(
@@ -236,13 +327,14 @@ class MultilingualTranslationTask(FairseqTask):
     def grad_denom(self, sample_sizes, criterion):
         return criterion.__class__.grad_denom(sample_sizes)

-    def aggregate_logging_outputs(self, logging_outputs, criterion):
+    def aggregate_logging_outputs(self, logging_outputs, criterion, logging_output_keys=None):
+        logging_output_keys = logging_output_keys or self.eval_lang_pairs
         # aggregate logging outputs for each language pair
         agg_logging_outputs = {
-            lang_pair: criterion.__class__.aggregate_logging_outputs([
-                logging_output.get(lang_pair, {}) for logging_output in logging_outputs
+            key: criterion.__class__.aggregate_logging_outputs([
+                logging_output.get(key, {}) for logging_output in logging_outputs
             ])
-            for lang_pair in self.eval_lang_pairs
+            for key in logging_output_keys
         }

         def sum_over_languages(key):
@@ -269,3 +361,13 @@ class MultilingualTranslationTask(FairseqTask):
     @property
     def target_dictionary(self):
         return self.dicts[self.args.target_lang]
+
+    def max_positions(self):
+        """Return the max sentence length allowed by the task."""
+        if len(self.datasets.values()) == 0:
+            return {'%s-%s' % (self.args.source_lang, self.args.target_lang):
+                    (self.args.max_source_positions, self.args.max_target_positions)}
+        return OrderedDict([
+            (key, (self.args.max_source_positions, self.args.max_target_positions))
+            for key in next(iter(self.datasets.values())).datasets.keys()
+        ])
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from collections import OrderedDict
import os
from fairseq.data import (
BacktranslationDataset,
IndexedCachedDataset,
IndexedDataset,
IndexedRawTextDataset,
LanguagePairDataset,
NoisingDataset,
RoundRobinZipDatasets,
)
from fairseq.models import FairseqMultiModel
from fairseq.sequence_generator import SequenceGenerator
from .multilingual_translation import MultilingualTranslationTask
from . import register_task
def _get_bt_dataset_key(lang_pair):
return "bt:" + lang_pair
def _get_denoising_dataset_key(lang_pair):
return "denoising:" + lang_pair
# ported from UnsupervisedMT
def parse_lambda_config(x):
"""
Parse the configuration of lambda coefficient (for scheduling).
x = "3" # lambda will be a constant equal to x
x = "0:1,1000:0" # lambda will start from 1 and linearly decrease
# to 0 during the first 1000 iterations
x = "0:0,1000:0,2000:1" # lambda will be equal to 0 for the first 1000
# iterations, then will linearly increase to 1 until iteration 2000
"""
split = x.split(',')
if len(split) == 1:
return float(x), None
else:
split = [s.split(':') for s in split]
assert all(len(s) == 2 for s in split)
assert all(k.isdigit() for k, _ in split)
assert all(int(split[i][0]) < int(split[i + 1][0]) for i in range(len(split) - 1))
return float(split[0][1]), [(int(k), float(v)) for k, v in split]
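
# Illustration (not part of the diff): expected parses, per the docstring above.
#   parse_lambda_config("3")          -> (3.0, None)
#   parse_lambda_config("0:1,1000:0") -> (1.0, [(0, 1.0), (1000, 0.0)])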
@register_task('semisupervised_translation')
class SemisupervisedTranslationTask(MultilingualTranslationTask):
"""A task for training multiple translation models simultaneously.
We iterate round-robin over batches from multiple language pairs, ordered
according to the `--lang-pairs` argument.
The training loop is roughly:
for i in range(len(epoch)):
for lang_pair in args.lang_pairs:
batch = next_batch_for_lang_pair(lang_pair)
loss = criterion(model_for_lang_pair(lang_pair), batch)
loss.backward()
optimizer.step()
In practice, `next_batch_for_lang_pair` is abstracted in a FairseqDataset
(e.g., `RoundRobinZipDatasets`) and `model_for_lang_pair` is a model that
implements the `FairseqMultiModel` interface.
During inference it is required to specify a single `--source-lang` and
`--target-lang`, instead of `--lang-pairs`.
"""
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
# fmt: off
MultilingualTranslationTask.add_args(parser)
parser.add_argument('--lambda-parallel-config', default="1.0", type=str, metavar='CONFIG',
help='cross-entropy reconstruction coefficient (parallel data). '
'use fixed weight during training if set to floating point number. '
'use piecewise linear function over number of updates to schedule the '
'weight with the format: w0:step0,w1:step1,...')
parser.add_argument('--lambda-denoising-config', default="0.0", type=str, metavar='CONFIG',
help='Cross-entropy reconstruction coefficient (denoising autoencoding)'
'use fixed weight during training if set to floating point number. '
'use piecewise linear function over number of updates to schedule the '
'weight with the format: w0:step0,w1:step1,...')
parser.add_argument('--lambda-otf-bt-config', default="0.0", type=str, metavar='CONFIG',
help='cross-entropy reconstruction coefficient (on-the-fly back-translation parallel data)'
'use fixed weight during training if set to floating point number. '
'use piecewise linear function over number of updates to schedule the '
'weight with the format: w0:step0,w1:step1,...')
parser.add_argument('--bt-max-len-a', default=1.1, type=float, metavar='N',
help='generate back-translated sequences of maximum length ax + b, where x is the '
'source length')
parser.add_argument('--bt-max-len-b', default=10.0, type=float, metavar='N',
help='generate back-translated sequences of maximum length ax + b, where x is the '
'source length')
parser.add_argument('--bt-beam-size', default=1, type=int, metavar='N',
help='beam size used in beam search of online back-translation')
parser.add_argument('--max-word-shuffle-distance', default=3.0, type=float, metavar='N',
help='maximum word shuffle distance for denoising autoencoding data generation')
parser.add_argument('--word-dropout-prob', default=0.1, type=float, metavar='N',
help='word dropout probability for denoising autoencoding data generation')
parser.add_argument('--word-blanking-prob', default=0.2, type=float, metavar='N',
help='word blanking probability for denoising autoencoding data generation')
# fmt: on
def __init__(self, args, dicts, training):
super().__init__(args, dicts, training)
self.lambda_parallel, self.lambda_parallel_steps = parse_lambda_config(args.lambda_parallel_config)
self.lambda_otf_bt, self.lambda_otf_bt_steps = parse_lambda_config(args.lambda_otf_bt_config)
self.lambda_denoising, self.lambda_denoising_steps = parse_lambda_config(args.lambda_denoising_config)
if (self.lambda_denoising > 0.0 or self.lambda_denoising_steps is not None):
denoising_lang_pairs = [
"%s-%s" % (tgt, tgt)
for tgt in {lang_pair.split('-')[1] for lang_pair in args.lang_pairs}
]
self.model_lang_pairs += denoising_lang_pairs
self.backtranslate_datasets = {}
@classmethod
def setup_task(cls, args, **kwargs):
dicts, training = MultilingualTranslationTask.prepare(args, **kwargs)
return cls(args, dicts, training)
def load_dataset(self, split, **kwargs):
"""Load a dataset split."""
def split_exists(split, src, tgt, lang):
if src is not None:
filename = os.path.join(self.args.data, '{}.{}-{}.{}'.format(split, src, tgt, lang))
else:
filename = os.path.join(self.args.data, '{}.{}-None.{}'.format(split, src, tgt))
if self.args.raw_text and IndexedRawTextDataset.exists(filename):
return True
elif not self.args.raw_text and IndexedDataset.exists(filename):
return True
return False
def indexed_dataset(path, dictionary):
if self.args.raw_text:
return IndexedRawTextDataset(path, dictionary)
elif IndexedDataset.exists(path):
if self.args.lazy_load:
return IndexedDataset(path, fix_lua_indexing=True)
else:
return IndexedCachedDataset(path, fix_lua_indexing=True)
return None
# load parallel datasets
src_datasets, tgt_datasets = {}, {}
if (self.lambda_parallel > 0.0 or self.lambda_parallel_steps is not None or not split.startswith("train")):
for lang_pair in self.args.lang_pairs:
src, tgt = lang_pair.split('-')
if split_exists(split, src, tgt, src):
prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split, src, tgt))
elif split_exists(split, tgt, src, src):
prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split, tgt, src))
else:
continue
src_datasets[lang_pair] = indexed_dataset(prefix + src, self.dicts[src])
tgt_datasets[lang_pair] = indexed_dataset(prefix + tgt, self.dicts[tgt])
print('| parallel-{} {} {} examples'.format(self.args.data, split, len(src_datasets[lang_pair])))
if len(src_datasets) == 0:
raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))
# back translation datasets
backtranslate_datasets = {}
if (self.lambda_otf_bt > 0.0 or self.lambda_otf_bt_steps is not None) and split.startswith("train"):
for lang_pair in self.args.lang_pairs:
src, tgt = lang_pair.split('-')
if not split_exists(split, tgt, None, tgt):
raise FileNotFoundError('Dataset not found: backtranslation {} ({})'.format(split, self.args.data))
filename = os.path.join(self.args.data, '{}.{}-None.{}'.format(split, tgt, tgt))
dataset = indexed_dataset(filename, self.dicts[tgt])
lang_pair_dataset_tgt = LanguagePairDataset(
dataset,
dataset.sizes,
self.dicts[tgt],
left_pad_source=self.args.left_pad_source,
left_pad_target=self.args.left_pad_target,
)
lang_pair_dataset = LanguagePairDataset(
dataset,
dataset.sizes,
src_dict=self.dicts[src],
tgt=dataset,
tgt_sizes=dataset.sizes,
tgt_dict=self.dicts[tgt],
left_pad_source=self.args.left_pad_source,
left_pad_target=self.args.left_pad_target,
)
backtranslate_datasets[lang_pair] = BacktranslationDataset(
tgt_dataset=self.alter_dataset_langtok(
lang_pair_dataset_tgt,
src_eos=self.dicts[tgt].eos(),
src_lang=tgt,
tgt_lang=src,
),
src_dict=self.dicts[src], tgt_dict=self.dicts[tgt],
output_collater=self.alter_dataset_langtok(
lang_pair_dataset=lang_pair_dataset,
src_eos=self.dicts[src].eos(),
src_lang=src,
tgt_eos=self.dicts[tgt].eos(),
tgt_lang=tgt,
).collater,
)
print('| backtranslate-{}: {} {} {} examples'.format(
tgt, self.args.data, split, len(backtranslate_datasets[lang_pair]),
))
self.backtranslate_datasets[lang_pair] = backtranslate_datasets[lang_pair]
# denoising autoencoder
noising_datasets = {}
if (self.lambda_denoising > 0.0 or self.lambda_denoising_steps is not None) and split.startswith("train"):
for lang_pair in self.args.lang_pairs:
_, tgt = lang_pair.split('-')
if not split_exists(split, tgt, None, tgt):
continue
filename = os.path.join(self.args.data, '{}.{}-None.{}'.format(split, tgt, tgt))
tgt_dataset1 = indexed_dataset(filename, self.dicts[tgt])
tgt_dataset2 = indexed_dataset(filename, self.dicts[tgt])
noising_dataset = NoisingDataset(
tgt_dataset1,
self.dicts[tgt],
seed=1,
max_word_shuffle_distance=self.args.max_word_shuffle_distance,
word_dropout_prob=self.args.word_dropout_prob,
word_blanking_prob=self.args.word_blanking_prob,
)
noising_datasets[lang_pair] = self.alter_dataset_langtok(
LanguagePairDataset(
noising_dataset,
tgt_dataset1.sizes,
self.dicts[tgt],
tgt_dataset2,
tgt_dataset2.sizes,
self.dicts[tgt],
left_pad_source=self.args.left_pad_source,
left_pad_target=self.args.left_pad_target,
),
src_eos=self.dicts[tgt].eos(),
src_lang=tgt,
tgt_eos=self.dicts[tgt].eos(),
tgt_lang=tgt,
)
print('| denoising-{}: {} {} {} examples'.format(
tgt, self.args.data, split, len(noising_datasets[lang_pair]),
))
def language_pair_dataset(lang_pair):
src, tgt = lang_pair.split('-')
src_dataset, tgt_dataset = src_datasets[lang_pair], tgt_datasets[lang_pair]
return self.alter_dataset_langtok(
LanguagePairDataset(
src_dataset, src_dataset.sizes, self.dicts[src],
tgt_dataset, tgt_dataset.sizes, self.dicts[tgt],
left_pad_source=self.args.left_pad_source,
left_pad_target=self.args.left_pad_target,
max_source_positions=self.args.max_source_positions,
max_target_positions=self.args.max_target_positions,
),
self.dicts[src].eos(),
src,
self.dicts[tgt].eos(),
tgt,
)
self.datasets[split] = RoundRobinZipDatasets(
OrderedDict([
(lang_pair, language_pair_dataset(lang_pair))
for lang_pair in src_datasets.keys()
] + [
(_get_bt_dataset_key(lang_pair), dataset)
for lang_pair, dataset in backtranslate_datasets.items()
] + [
(_get_denoising_dataset_key(lang_pair), dataset)
for lang_pair, dataset in noising_datasets.items()
]),
eval_key=None if self.training else "%s-%s" % (self.args.source_lang, self.args.target_lang),
)
def build_model(self, args):
from fairseq import models
model = models.build_model(args, self)
if not isinstance(model, FairseqMultiModel):
raise ValueError('SemisupervisedTranslationTask requires a FairseqMultiModel architecture')
# create SequenceGenerator for each model that has backtranslation dependency on it
self.sequence_generators = {}
if (self.lambda_otf_bt > 0.0 or self.lambda_otf_bt_steps is not None) and self.training:
for lang_pair in self.args.lang_pairs:
src, tgt = lang_pair.split('-')
key = '{}-{}'.format(tgt, src)
self.sequence_generators[key] = SequenceGenerator(
tgt_dict=self.dicts[src],
beam_size=args.bt_beam_size,
max_len_a=args.bt_max_len_a,
max_len_b=args.bt_max_len_b,
)
decoder_lang_tok_idx = self.get_decoder_langtok(src)
def backtranslate_fn(
sample, model=model.models[key],
bos_token=decoder_lang_tok_idx,
sequence_generator=self.sequence_generators[key],
):
return sequence_generator.generate(
[model],
sample,
bos_token=bos_token,
)
self.backtranslate_datasets[lang_pair].set_backtranslation_fn(backtranslate_fn)
return model
def train_step(self, sample, model, criterion, optimizer, ignore_grad=False):
model.train()
agg_loss, agg_sample_size, agg_logging_output = 0., 0., {}
def forward_backward(model, samples, logging_output_key, weight):
nonlocal agg_loss, agg_sample_size, agg_logging_output
if samples is None or len(samples) == 0:
return
loss, sample_size, logging_output = criterion(model, samples)
if ignore_grad:
loss *= 0
else:
loss *= weight
optimizer.backward(loss)
agg_loss += loss.detach().item()
# TODO make summing of the sample sizes configurable
agg_sample_size += sample_size
agg_logging_output[logging_output_key] = logging_output
if self.lambda_parallel > 0.0:
for lang_pair in self.args.lang_pairs:
                forward_backward(model.models[lang_pair], sample[lang_pair], lang_pair, self.lambda_parallel)
if self.lambda_otf_bt > 0.0:
for lang_pair in self.args.lang_pairs:
sample_key = _get_bt_dataset_key(lang_pair)
forward_backward(model.models[lang_pair], sample[sample_key], sample_key, self.lambda_otf_bt)
if self.lambda_denoising > 0.0:
for lang_pair in self.args.lang_pairs:
_, tgt = lang_pair.split('-')
sample_key = _get_denoising_dataset_key(lang_pair)
forward_backward(model.models[f'{tgt}-{tgt}'], sample[sample_key], sample_key, self.lambda_denoising)
return agg_loss, agg_sample_size, agg_logging_output
def update_step(self, num_updates):
def lambda_step_func(config, n_iter):
"""
Update a lambda value according to its schedule configuration.
"""
ranges = [i for i in range(len(config) - 1) if config[i][0] <= n_iter < config[i + 1][0]]
if len(ranges) == 0:
assert n_iter >= config[-1][0]
return config[-1][1]
assert len(ranges) == 1
i = ranges[0]
x_a, y_a = config[i]
x_b, y_b = config[i + 1]
return y_a + (n_iter - x_a) * float(y_b - y_a) / float(x_b - x_a)
if self.lambda_parallel_steps is not None:
self.lambda_parallel = lambda_step_func(self.lambda_parallel_steps, num_updates)
if self.lambda_denoising_steps is not None:
self.lambda_denoising = lambda_step_func(self.lambda_denoising_steps, num_updates)
if self.lambda_otf_bt_steps is not None:
self.lambda_otf_bt = lambda_step_func(self.lambda_otf_bt_steps, num_updates)
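
    # Illustration (not part of the diff): with the schedule parsed from
    # "0:1,1000:0", i.e. config = [(0, 1.0), (1000, 0.0)]:
    #   lambda_step_func(config, 250)  -> 1.0 + 250 * (0.0 - 1.0) / 1000 = 0.75
    #   lambda_step_func(config, 1000) -> 0.0 (past the last knot, value is held)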
def aggregate_logging_outputs(self, logging_outputs, criterion):
# aggregate logging outputs for each language pair
logging_output_keys = {
key
for logging_output in logging_outputs
for key in logging_output
}
lang_pair_keys = set(self.args.lang_pairs + [
_get_bt_dataset_key(lang_pair)
for lang_pair in self.args.lang_pairs
] + [
_get_denoising_dataset_key(lang_pair)
for lang_pair in self.args.lang_pairs
])
logging_output_keys = logging_output_keys.intersection(lang_pair_keys)
return super().aggregate_logging_outputs(logging_outputs, criterion, logging_output_keys)
@@ -259,6 +259,9 @@ class Trainer(object):
             # update learning rate
             self.lr_scheduler.step_update(self._num_updates)

+            # task specific update per step
+            self.task.update_step(self._num_updates)
+
             # update meters
             ntokens = logging_output.get('ntokens', 0)
             nsentences = logging_output.get('nsentences', 0)
@@ -6,6 +6,7 @@
 # can be found in the PATENTS file in the same directory.

 from collections import defaultdict, OrderedDict
+import copy
 import importlib.util
 import logging
 import os
@@ -417,6 +418,15 @@ def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'):
 def resolve_max_positions(*args):
     """Resolve max position constraints from multiple sources."""

+    def map_value_update(d1, d2):
+        updated_value = copy.deepcopy(d1)
+        for key in d2:
+            if key not in updated_value:
+                updated_value[key] = d2[key]
+            else:
+                updated_value[key] = min(d1[key], d2[key])
+        return updated_value
+
     def nullsafe_min(l):
         minim = None
         for item in l:
@@ -433,10 +443,13 @@ def resolve_max_positions(*args):
         elif arg is not None:
             if isinstance(arg, float) or isinstance(arg, int):
                 max_positions = min(max_positions, arg)
+            elif isinstance(arg, dict):
+                max_positions = map_value_update(max_positions, arg)
             else:
                 max_positions = tuple(
                     map(nullsafe_min, zip(max_positions, arg))
                 )

     return max_positions
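
# Illustration (not part of the diff): merging dict-valued limits, e.g. the
# per-lang-pair max_positions returned by MultilingualTranslationTask:
#   resolve_max_positions({'de-en': (1024, 1024)},
#                         {'de-en': (512, 512), 'en-de': (512, 512)})
#   -> {'de-en': (512, 512), 'en-de': (512, 512)}
# (note that min() on tuples compares lexicographically, not element-wise)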
@@ -59,8 +59,9 @@ class TestBacktranslationDataset(unittest.TestCase):
                 # remove eos from the input src
                 remove_eos_from_src=remove_eos_from_input_src,
             ),
+            src_dict=self.tgt_dict,
             backtranslation_fn=(
-                lambda net_input: generator.generate([self.model], {'net_input': net_input})
+                lambda sample: generator.generate([self.model], sample)
            ),
             output_collater=TransformEosDataset(
                 dataset=tgt_dataset,