"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "1b0bd0e32cef4d4c52abf599dd25583ca7859997"
Commit 3c19878f authored by Myle Ott, committed by Facebook GitHub Bot

Refactor BacktranslationDataset to be more reusable (#354)

Summary:
- generalize AppendEosDataset -> TransformEosDataset
- remove EOS logic from BacktranslationDataset (use TransformEosDataset instead)
- BacktranslationDataset takes a backtranslation_fn instead of building the SequenceGenerator itself
Pull Request resolved: https://github.com/pytorch/fairseq/pull/354

Reviewed By: liezl200

Differential Revision: D12970233

Pulled By: myleott

fbshipit-source-id: d5c5b0e0a75eca1bd3a50382ac24621f35c32f36
parent a442244d
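
Before the diff, here is a minimal sketch of the refactored API, distilled from the updated unit test below; `model`, `tgt_dict`, and `tgt_dataset` are placeholders for a trained tgt-src model, its dictionary, and a monolingual target-side dataset (e.g. a LanguagePairDataset with no target side):

import torch

from fairseq.data import BacktranslationDataset, TransformEosDataset
from fairseq.sequence_generator import SequenceGenerator

# The caller now owns the generator; BacktranslationDataset only calls it.
generator = SequenceGenerator(models=[model], tgt_dict=tgt_dict, beam_size=2)

dataset = BacktranslationDataset(
    tgt_dataset=TransformEosDataset(
        dataset=tgt_dataset,
        eos=tgt_dict.eos(),
        remove_eos_from_src=True,  # if the tgt-src model expects inputs without EOS
    ),
    backtranslation_fn=generator.generate,  # any callable with the same signature works
    max_len_a=0,
    max_len_b=200,
)
# Backtranslation happens at batch time, inside dataset.collater:
loader = torch.utils.data.DataLoader(dataset, batch_size=2, collate_fn=dataset.collater)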
fairseq/data/__init__.py

@@ -7,7 +7,6 @@
 from .dictionary import Dictionary, TruncatedDictionary
 from .fairseq_dataset import FairseqDataset
-from .append_eos_dataset import AppendEosDataset
 from .backtranslation_dataset import BacktranslationDataset
 from .concat_dataset import ConcatDataset
 from .indexed_dataset import IndexedDataset, IndexedCachedDataset, IndexedInMemoryDataset, IndexedRawTextDataset
@@ -15,6 +14,7 @@ from .language_pair_dataset import LanguagePairDataset
 from .monolingual_dataset import MonolingualDataset
 from .round_robin_zip_datasets import RoundRobinZipDatasets
 from .token_block_dataset import TokenBlockDataset
+from .transform_eos_dataset import TransformEosDataset
 
 from .iterators import (
     CountingIterator,
@@ -24,7 +24,6 @@ from .iterators import (
 )
 
 __all__ = [
-    'AppendEosDataset',
    'BacktranslationDataset',
    'ConcatDataset',
    'CountingIterator',
@@ -41,4 +40,5 @@ __all__ = [
    'RoundRobinZipDatasets',
    'ShardedIterator',
    'TokenBlockDataset',
+    'TransformEosDataset',
 ]
fairseq/data/append_eos_dataset.py (deleted)

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import torch


class AppendEosDataset(torch.utils.data.Dataset):
    """A dataset wrapper that appends EOS to each item."""

    def __init__(self, dataset, eos):
        self.dataset = dataset
        self.eos = eos

    def __getitem__(self, index):
        item = torch.cat([self.dataset[index], torch.LongTensor([self.eos])])
        return item

    def __len__(self):
        return len(self.dataset)
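
Under the refactor, the deleted wrapper's behavior is covered by TransformEosDataset. A rough equivalent (a sketch, not a drop-in: the new wrapper requires a FairseqDataset yielding dict-style samples with 'source'/'target' keys, and applies the edit in collater() rather than in __getitem__):

from fairseq.data import TransformEosDataset

# Roughly AppendEosDataset(dataset, eos), but applied at collate time:
wrapped = TransformEosDataset(dataset, eos, append_eos_to_src=True)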
fairseq/data/backtranslation_dataset.py

@@ -7,182 +7,162 @@
 import torch
 
-from fairseq import sequence_generator
 from fairseq import utils
 
-from . import FairseqDataset, language_pair_dataset
+from . import FairseqDataset
+
+
+def backtranslate_samples(samples, collate_fn, generate_fn, cuda=True):
+    """Backtranslate a list of samples.
+
+    Given an input (*samples*) of the form:
+
+        [{'id': 1, 'source': 'hallo welt'}]
+
+    this will return:
+
+        [{'id': 1, 'source': 'hello world', 'target': 'hallo welt'}]
+
+    Args:
+        samples (List[dict]): samples to backtranslate. Individual samples are
+            expected to have a 'source' key, which will become the 'target'
+            after backtranslation.
+        collate_fn (callable): function to collate samples into a mini-batch
+        generate_fn (callable): function to generate backtranslations
+        cuda (bool): use GPU for generation (default: ``True``)
+
+    Returns:
+        List[dict]: an updated list of samples with a backtranslated source
+    """
+    collated_samples = collate_fn(samples)
+    s = utils.move_to_cuda(collated_samples) if cuda else collated_samples
+    generated_sources = generate_fn(s['net_input'])
+
+    def update_sample(sample, generated_source):
+        sample['target'] = sample['source']  # the original source becomes the target
+        sample['source'] = generated_source
+        return sample
+
+    # Go through each tgt sentence in batch and its corresponding best
+    # generated hypothesis and create a backtranslation data pair
+    # {id: id, source: generated backtranslation, target: original tgt}
+    return [
+        update_sample(
+            sample=input_sample,
+            generated_source=hypos[0]['tokens'].cpu(),  # highest scoring hypo is first
+        )
+        for input_sample, hypos in zip(samples, generated_sources)
+    ]
 
 
 class BacktranslationDataset(FairseqDataset):
     def __init__(
         self,
         tgt_dataset,
-        tgt_dict,
-        backtranslation_model,
+        backtranslation_fn,
         max_len_a,
         max_len_b,
-        remove_eos_at_src=False,
-        generator_class=sequence_generator.SequenceGenerator,
+        output_collater=None,
         cuda=True,
         **kwargs
     ):
         """
         Sets up a backtranslation dataset which takes a tgt batch, generates
-        a src using a tgt-src backtranslation_model, and returns the
-        corresponding {generated src, input tgt} batch
+        a src using a tgt-src backtranslation function (*backtranslation_fn*),
+        and returns the corresponding `{generated src, input tgt}` batch.
 
         Args:
-            tgt_dataset: dataset which will be used to build self.tgt_dataset --
-                a LanguagePairDataset with tgt dataset as the source dataset and
-                None as the target dataset. Should NOT have padding so that
-                src_lengths are accurately calculated by language_pair_dataset
-                collate function.
-                We use language_pair_dataset here to encapsulate the tgt_dataset
-                so we can re-use the LanguagePairDataset collater to format the
-                batches in the structure that SequenceGenerator expects.
-                Note: tgt_dataset samples should not have EOS at end if
-                the tgt-src model expects an input without EOS. This dataset
-                does not enforce this, you should enforce that in preprocessing.
-            tgt_dict: tgt dictionary (typically a joint src/tgt BPE dictionary)
-            backtranslation_model: tgt-src model to use in the SequenceGenerator
-                to generate backtranslations from tgt batches
-            max_len_a, max_len_b: args passed into generate() function of
-                the backtranslation SequenceGenerator
-            remove_eos_at_src: whether we should remove EOS from the source
-                dialect text generated by the backtranslation model.
-            generator_class: which SequenceGenerator class to use for
-                backtranslation. Output of generate() should be the same format
-                as fairseq's SequenceGenerator
+            tgt_dataset (~fairseq.data.FairseqDataset): the dataset to be
+                backtranslated. Only the source side of this dataset will be
+                used. After backtranslation, the source sentences in this
+                dataset will be returned as the targets.
+            backtranslation_fn (callable): function to call to generate
+                backtranslations. This is typically the `generate` method of a
+                :class:`~fairseq.sequence_generator.SequenceGenerator` object.
+            max_len_a, max_len_b (int, int): will be used to compute
+                `maxlen = max_len_a * src_len + max_len_b`, which will be
+                passed into *backtranslation_fn*.
+            output_collater (callable, optional): function to call on the
+                backtranslated samples to create the final batch (default:
+                ``tgt_dataset.collater``)
             cuda: use GPU for generation
-            kwargs: generation args to init the backtranslation
-                SequenceGenerator
         """
-        self.tgt_dataset = language_pair_dataset.LanguagePairDataset(
-            src=tgt_dataset,
-            src_sizes=tgt_dataset.sizes,
-            src_dict=tgt_dict,
-            tgt=None,
-            tgt_sizes=None,
-            tgt_dict=None,
-        )
+        self.tgt_dataset = tgt_dataset
+        self.backtranslation_fn = backtranslation_fn
         self.max_len_a = max_len_a
         self.max_len_b = max_len_b
-        self.remove_eos_at_src = remove_eos_at_src
-        self.backtranslation_generator = generator_class(
-            models=[backtranslation_model],
-            tgt_dict=tgt_dict,
-            **kwargs,
-        )
+        self.output_collater = output_collater if output_collater is not None \
+            else tgt_dataset.collater
         self.cuda = cuda if torch.cuda.is_available() else False
-        if self.cuda:
-            self.backtranslation_generator.cuda()
 
     def __getitem__(self, index):
         """
-        Returns a single sample. Multiple samples are fed to the collater to
-        create a backtranslation batch. Note you should always use collate_fn
-        BacktranslationDataset.collater() below if given the option to
-        specify which collate_fn to use (e.g. in a dataloader which uses this
-        BacktranslationDataset -- see corresponding unittest for an example).
+        Returns a single sample from *tgt_dataset*. Note that backtranslation is
+        not applied in this step; use :func:`collater` instead to backtranslate
+        a batch of samples.
         """
         return self.tgt_dataset[index]
 
     def __len__(self):
-        """
-        The length of the backtranslation dataset is the length of tgt.
-        """
         return len(self.tgt_dataset)
 
     def collater(self, samples):
-        """
-        Using the samples from the tgt dataset, load a collated tgt sample to
-        feed to the backtranslation model. Then take the generated translation
-        with best score as the source and the orignal net input as the target.
-        """
-        collated_tgt_only_sample = self.tgt_dataset.collater(samples=samples)
-        backtranslation_hypos = self._generate_hypotheses(
-            sample=collated_tgt_only_sample
-        )
-
-        # Go through each tgt sentence in batch and its corresponding best
-        # generated hypothesis and create a backtranslation data pair
-        # {id: id, source: generated backtranslation, target: original tgt}
-        generated_samples = []
-        for input_sample, hypos in zip(samples, backtranslation_hypos):
-            original_tgt = input_sample["source"].cpu()
-            generated_source = hypos[0]["tokens"].cpu()  # first hypo is best hypo
-
-            # Append EOS to the tgt sentence if it does not have an EOS
-            # This is the case if the samples in monolingual tgt_dataset don't
-            # have an EOS appended to the end of each sentence.
-            eos = self.tgt_dataset.src_dict.eos()
-            if original_tgt[-1] != eos:
-                original_tgt = torch.cat([original_tgt, torch.LongTensor([eos])])
-
-            # The generated source dialect backtranslation will have an EOS.
-            # If we want our parallel data source to not have an EOS, we will
-            # have to remove it.
-            if self.remove_eos_at_src:
-                assert generated_source[-1] == eos, (
-                    "Expected generated backtranslation to have eos (id: "
-                    "{eos}) at end, but instead found token id "
-                    "{generated_source[-1]} at end."
-                ).format(eos=eos, generated_source=generated_source)
-                generated_source = generated_source[:-1]
-
-            generated_samples.append(
-                {
-                    "id": input_sample["id"],
-                    "source": generated_source,
-                    "target": original_tgt,
-                }
-            )
-
-        return language_pair_dataset.collate(
-            samples=generated_samples,
-            pad_idx=self.tgt_dataset.src_dict.pad(),
-            eos_idx=self.tgt_dataset.src_dict.eos(),
+        """Merge and backtranslate a list of samples to form a mini-batch.
+
+        Using the samples from *tgt_dataset*, load a collated target sample to
+        feed to the backtranslation model. Then take the backtranslation with
+        the best score as the source and the original input as the target.
+
+        Note: we expect *tgt_dataset* to provide a function `collater()` that
+        will collate samples into the format expected by *backtranslation_fn*.
+        After backtranslation, we will feed the new list of samples (i.e., the
+        `(backtranslated source, original source)` pairs) to *output_collater*
+        and return the result.
+
+        Args:
+            samples (List[dict]): samples to backtranslate and collate
+
+        Returns:
+            dict: a mini-batch with keys coming from *output_collater*
+        """
+        samples = backtranslate_samples(
+            samples=samples,
+            collate_fn=self.tgt_dataset.collater,
+            generate_fn=(
+                lambda net_input: self.backtranslation_fn(
+                    net_input,
+                    maxlen=int(
+                        self.max_len_a * net_input['src_tokens'].size(1) + self.max_len_b
+                    ),
+                )
+            ),
+            cuda=self.cuda,
         )
+        return self.output_collater(samples)
 
     def get_dummy_batch(self, num_tokens, max_positions):
-        """ Just use the tgt dataset get_dummy_batch """
+        """Just use the tgt dataset get_dummy_batch"""
         return self.tgt_dataset.get_dummy_batch(num_tokens, max_positions)
 
     def num_tokens(self, index):
-        """ Just use the tgt dataset num_tokens """
+        """Just use the tgt dataset num_tokens"""
         return self.tgt_dataset.num_tokens(index)
 
     def ordered_indices(self):
-        """ Just use the tgt dataset ordered_indices """
+        """Just use the tgt dataset ordered_indices"""
         return self.tgt_dataset.ordered_indices()
 
     def valid_size(self, index, max_positions):
-        """ Just use the tgt dataset size """
+        """Just use the tgt dataset size"""
         return self.tgt_dataset.valid_size(index, max_positions)
 
-    def _generate_hypotheses(self, sample):
-        """
-        Generates hypotheses from a LanguagePairDataset collated / batched
-        sample. Note in this case, sample["target"] is None, and
-        sample["net_input"]["src_tokens"] is really in tgt language.
-        """
-        s = utils.move_to_cuda(sample) if self.cuda else sample
-        input = s["net_input"]
-        srclen = input["src_tokens"].size(1)
-        hypos = self.backtranslation_generator.generate(
-            input,
-            maxlen=int(
-                self.max_len_a * srclen + self.max_len_b
-            ),
-        )
-        return hypos
-
     def size(self, index):
-        """Return an example's size as a float or tuple. This value is used when
-        filtering a dataset with ``--max-positions``.
-        Here, we return src dataset size as tgt dataset size as an approximation.
-        We do not know src size until we backtranslate and generate src sentences.
+        """Return an example's size as a float or tuple. This value is used
+        when filtering a dataset with ``--max-positions``.
+
+        Note: we use *tgt_dataset* to approximate the length of the source
+        sentence, since we do not know the actual length until after
+        backtranslation.
         """
-        return (self.tgt_dataset.size(index)[0], self.tgt_dataset.size(index)[0])
+        tgt_size = self.tgt_dataset.size(index)[0]
+        return (tgt_size, tgt_size)
fairseq/data/concat_dataset.py

+# Copyright (c) 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the LICENSE file in
+# the root directory of this source tree. An additional grant of patent rights
+# can be found in the PATENTS file in the same directory.
+
 import bisect
 
 import numpy as np
 ...
fairseq/data/language_pair_dataset.py

@@ -55,13 +55,13 @@ def collate(
     batch = {
         'id': id,
+        'nsentences': len(samples),
         'ntokens': ntokens,
         'net_input': {
             'src_tokens': src_tokens,
             'src_lengths': src_lengths,
         },
         'target': target,
-        'nsentences': samples[0]['source'].size(0),
     }
     if prev_output_tokens is not None:
         batch['net_input']['prev_output_tokens'] = prev_output_tokens
 ...
fairseq/data/monolingual_dataset.py

@@ -33,6 +33,7 @@ def collate(samples, pad_idx, eos_idx):
     return {
         'id': torch.LongTensor([s['id'] for s in samples]),
+        'nsentences': len(samples),
         'ntokens': sum(len(s['source']) for s in samples),
         'net_input': {
             'src_tokens': merge('source'),
@@ -41,7 +42,6 @@ def collate(samples, pad_idx, eos_idx):
             ]),
         },
         'target': merge('target', is_target_list),
-        'nsentences': samples[0]['source'].size(0),
     }
 ...
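
The two collate hunks above fix the same bug: 'nsentences' was previously set to samples[0]['source'].size(0), the token length of the first sentence, rather than the batch size. A toy illustration:

import torch

samples = [
    {'id': 0, 'source': torch.LongTensor([4, 5, 6, 2])},  # 4 tokens
    {'id': 1, 'source': torch.LongTensor([7, 2])},        # 2 tokens
]
assert samples[0]['source'].size(0) == 4  # old value: length of the first sentence
assert len(samples) == 2                  # new value: the actual batch size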
fairseq/data/transform_eos_dataset.py (new file)

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import torch

from . import FairseqDataset


class TransformEosDataset(FairseqDataset):
    """A dataset wrapper that appends/prepends/strips EOS.

    Note that the transformation is applied in :func:`collater`.

    Args:
        dataset (~fairseq.data.FairseqDataset): dataset to wrap
        eos (int): index of the end-of-sentence symbol
        append_eos_to_src (bool, optional): append EOS to the end of src
        remove_eos_from_src (bool, optional): remove EOS from the end of src
        append_eos_to_tgt (bool, optional): append EOS to the end of tgt
        remove_eos_from_tgt (bool, optional): remove EOS from the end of tgt
    """

    def __init__(
        self,
        dataset,
        eos,
        append_eos_to_src=False,
        remove_eos_from_src=False,
        append_eos_to_tgt=False,
        remove_eos_from_tgt=False,
    ):
        if not isinstance(dataset, FairseqDataset):
            raise ValueError('dataset must be an instance of FairseqDataset')
        if append_eos_to_src and remove_eos_from_src:
            raise ValueError('cannot combine append_eos_to_src and remove_eos_from_src')
        if append_eos_to_tgt and remove_eos_from_tgt:
            raise ValueError('cannot combine append_eos_to_tgt and remove_eos_from_tgt')
        self.dataset = dataset
        self.eos = torch.LongTensor([eos])
        self.append_eos_to_src = append_eos_to_src
        self.remove_eos_from_src = remove_eos_from_src
        self.append_eos_to_tgt = append_eos_to_tgt
        self.remove_eos_from_tgt = remove_eos_from_tgt

        # precompute how we should adjust the reported sizes
        self._src_delta = 0
        self._src_delta += 1 if append_eos_to_src else 0
        self._src_delta -= 1 if remove_eos_from_src else 0
        self._tgt_delta = 0
        self._tgt_delta += 1 if append_eos_to_tgt else 0
        self._tgt_delta -= 1 if remove_eos_from_tgt else 0

        self._checked_src = False
        self._checked_tgt = False

    def _check_src(self, src, expect_eos):
        if not self._checked_src:
            assert (src[-1] == self.eos[0]) == expect_eos
            self._checked_src = True

    def _check_tgt(self, tgt, expect_eos):
        if not self._checked_tgt:
            assert (tgt[-1] == self.eos[0]) == expect_eos
            self._checked_tgt = True

    def __getitem__(self, index):
        return self.dataset[index]

    def __len__(self):
        return len(self.dataset)

    def collater(self, samples):

        def transform(item):
            if self.append_eos_to_src:
                self._check_src(item['source'], expect_eos=False)
                item['source'] = torch.cat([item['source'], self.eos])
            if self.remove_eos_from_src:
                self._check_src(item['source'], expect_eos=True)
                item['source'] = item['source'][:-1]
            if self.append_eos_to_tgt:
                self._check_tgt(item['target'], expect_eos=False)
                item['target'] = torch.cat([item['target'], self.eos])
            if self.remove_eos_from_tgt:
                self._check_tgt(item['target'], expect_eos=True)
                item['target'] = item['target'][:-1]
            return item

        samples = list(map(transform, samples))
        return self.dataset.collater(samples)

    def get_dummy_batch(self, *args, **kwargs):
        return self.dataset.get_dummy_batch(*args, **kwargs)

    def num_tokens(self, index):
        return self.dataset.num_tokens(index)

    def size(self, index):
        src_len, tgt_len = self.dataset.size(index)
        return (src_len + self._src_delta, tgt_len + self._tgt_delta)

    def ordered_indices(self):
        # NOTE: we assume that the ordering does not change based on the
        # addition or removal of eos
        return self.dataset.ordered_indices()

    @property
    def supports_prefetch(self):
        return self.dataset.supports_prefetch()

    def prefetch(self, indices):
        return self.dataset.prefetch(indices)
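
Because the EOS edit is deferred to collater(), __getitem__ stays a cheap passthrough and the wrapper composes with any FairseqDataset whose samples are dicts with 'source'/'target' keys. A hedged usage sketch (`lang_pair_dataset` and `src_dict` are placeholders):

from fairseq.data import TransformEosDataset

wrapped = TransformEosDataset(
    lang_pair_dataset,
    src_dict.eos(),
    append_eos_to_tgt=True,  # e.g., the model wants EOS on targets but not sources
)
# Items come back untouched; EOS is appended at batch time, then the
# inner dataset's collater runs:
batch = wrapped.collater([wrapped[0], wrapped[1]])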
fairseq/tasks/language_modeling.py

@@ -25,14 +25,15 @@ class LanguageModelingTask(FairseqTask):
     Train a language model.
 
     Args:
-        dictionary (Dictionary): the dictionary for the input of the language model
-        output_dictionary (Dictionary): the dictionary for the output of the language model.
-            In most cases it will be the same as dictionary, but could possibly be a more limited
-            version of the dictionary (if --output-dictionary-size is used).
-        targets (List[str]): list of the target types that the language model should predict.
-            Can be one of "self", "future", and "past". Defaults to "future".
+        dictionary (~fairseq.data.Dictionary): the dictionary for the input of
+            the language model
+        output_dictionary (~fairseq.data.Dictionary): the dictionary for the
+            output of the language model. In most cases it will be the same as
+            *dictionary*, but could possibly be a more limited version of the
+            dictionary (if ``--output-dictionary-size`` is used).
+        targets (List[str]): list of the target types that the language model
+            should predict. Can be one of "self", "future", and "past".
+            Defaults to "future".
 
     .. note::
 ...
tests/test_backtranslation_dataset.py

@@ -7,13 +7,20 @@
 import unittest
 
-import tests.utils as test_utils
 import torch
-from fairseq.data.backtranslation_dataset import BacktranslationDataset
-from fairseq import sequence_generator
+
+from fairseq.data import (
+    BacktranslationDataset,
+    LanguagePairDataset,
+    TransformEosDataset,
+)
+from fairseq.sequence_generator import SequenceGenerator
+
+import tests.utils as test_utils
 
 
 class TestBacktranslationDataset(unittest.TestCase):
     def setUp(self):
         self.tgt_dict, self.w1, self.w2, self.src_tokens, self.src_lengths, self.model = (
             test_utils.sequence_generator_setup()
@@ -22,22 +29,49 @@ class TestBacktranslationDataset(unittest.TestCase):
         dummy_src_samples = self.src_tokens
 
         self.tgt_dataset = test_utils.TestDataset(data=dummy_src_samples)
+        self.cuda = torch.cuda.is_available()
 
-    def _backtranslation_dataset_helper(self, remove_eos_at_src):
-        """
-        SequenceGenerator kwargs are same as defaults from fairseq/options.py
-        """
-        backtranslation_dataset = BacktranslationDataset(
-            tgt_dataset=self.tgt_dataset,
+    def _backtranslation_dataset_helper(
+        self, remove_eos_from_input_src, remove_eos_from_output_src,
+    ):
+        tgt_dataset = LanguagePairDataset(
+            src=self.tgt_dataset,
+            src_sizes=self.tgt_dataset.sizes,
+            src_dict=self.tgt_dict,
+            tgt=None,
+            tgt_sizes=None,
+            tgt_dict=None,
+        )
+
+        generator = SequenceGenerator(
+            models=[self.model],
             tgt_dict=self.tgt_dict,
-            backtranslation_model=self.model,
-            max_len_a=0,
-            max_len_b=200,
             beam_size=2,
             unk_penalty=0,
             sampling=False,
-            remove_eos_at_src=remove_eos_at_src,
-            generator_class=sequence_generator.SequenceGenerator,
+        )
+        if self.cuda:
+            generator.cuda()
+
+        backtranslation_dataset = BacktranslationDataset(
+            tgt_dataset=TransformEosDataset(
+                dataset=tgt_dataset,
+                eos=self.tgt_dict.eos(),
+                # remove eos from the input src
+                remove_eos_from_src=remove_eos_from_input_src,
+            ),
+            backtranslation_fn=generator.generate,
+            max_len_a=0,
+            max_len_b=200,
+            output_collater=TransformEosDataset(
+                dataset=tgt_dataset,
+                eos=self.tgt_dict.eos(),
+                # if we remove eos from the input src, then we need to add it
+                # back to the output tgt
+                append_eos_to_tgt=remove_eos_from_input_src,
+                remove_eos_from_src=remove_eos_from_output_src,
+            ).collater,
+            cuda=self.cuda,
         )
         dataloader = torch.utils.data.DataLoader(
             backtranslation_dataset,
@@ -51,7 +85,7 @@ class TestBacktranslationDataset(unittest.TestCase):
         # Note that we sort by src_lengths and add left padding, so actually
         # ids will look like: [1, 0]
         expected_src = torch.LongTensor([[w1, w2, w1, eos], [pad, pad, w1, eos]])
-        if remove_eos_at_src:
+        if remove_eos_from_output_src:
             expected_src = expected_src[:, :-1]
         expected_tgt = torch.LongTensor([[w1, w2, eos], [w1, w2, eos]])
         generated_src = backtranslation_batch_result["net_input"]["src_tokens"]
@@ -60,11 +94,20 @@ class TestBacktranslationDataset(unittest.TestCase):
         self.assertTensorEqual(expected_src, generated_src)
         self.assertTensorEqual(expected_tgt, tgt_tokens)
 
-    def test_backtranslation_dataset_no_eos_at_src(self):
-        self._backtranslation_dataset_helper(remove_eos_at_src=True)
+    def test_backtranslation_dataset_no_eos_in_output_src(self):
+        self._backtranslation_dataset_helper(
+            remove_eos_from_input_src=False, remove_eos_from_output_src=True,
+        )
+
+    def test_backtranslation_dataset_with_eos_in_output_src(self):
+        self._backtranslation_dataset_helper(
+            remove_eos_from_input_src=False, remove_eos_from_output_src=False,
+        )
 
-    def test_backtranslation_dataset_with_eos_at_src(self):
-        self._backtranslation_dataset_helper(remove_eos_at_src=False)
+    def test_backtranslation_dataset_no_eos_in_input_src(self):
+        self._backtranslation_dataset_helper(
+            remove_eos_from_input_src=True, remove_eos_from_output_src=False,
+        )
 
     def assertTensorEqual(self, t1, t2):
         self.assertEqual(t1.size(), t2.size(), "size mismatch")
 ...
tests/test_noising.py

@@ -12,9 +12,9 @@ import tests.utils as test_utils
 import torch
 from fairseq import utils
 from fairseq.data import (
-    AppendEosDataset,
     Dictionary,
     LanguagePairDataset,
+    TransformEosDataset,
     data_utils,
     noising,
 )
@@ -410,15 +410,14 @@ class TestDataNoising(unittest.TestCase):
         self.assert_no_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())
 
     def _get_noising_dataset_batch(
-        self, src_tokens_no_pad, src_dict, use_append_eos_dataset=False
+        self, src_tokens_no_pad, src_dict, append_eos_to_tgt=False,
     ):
         """
         Constructs a NoisingDataset and the corresponding
-        LanguagePairDataset(NoisingDataset(src), src). If we set
-        use_append_eos_dataset to True, wrap the source dataset in
-        AppendEosDataset to append EOS to the clean source when using it as the
-        target. In practice, we should use AppendEosDataset because our models
-        usually have source without EOS but target with EOS.
+        ``LanguagePairDataset(NoisingDataset(src), src)``. If
+        *append_eos_to_tgt* is True, wrap the source dataset in
+        :class:`TransformEosDataset` to append EOS to the clean source when
+        using it as the target.
         """
         src_dataset = test_utils.TestDataset(data=src_tokens_no_pad)
@@ -432,11 +431,13 @@ class TestDataNoising(unittest.TestCase):
             noising_class=noising.UnsupervisedMTNoising,
         )
         tgt = src_dataset
-        if use_append_eos_dataset:
-            tgt = AppendEosDataset(src_dataset, src_dict.eos())
         language_pair_dataset = LanguagePairDataset(
             src=noising_dataset, tgt=tgt, src_sizes=None, src_dict=src_dict
         )
+        language_pair_dataset = TransformEosDataset(
+            language_pair_dataset, src_dict.eos(),
+            append_eos_to_tgt=append_eos_to_tgt,
+        )
 
         dataloader = torch.utils.data.DataLoader(
             dataset=language_pair_dataset,
@@ -481,8 +482,7 @@ class TestDataNoising(unittest.TestCase):
     def test_noising_dataset_without_eos(self):
         """
         Similar to test noising dataset with eos except that we have to set
-        use_append_eos_dataset=True so that we wrap the source dataset in the
-        AppendEosDataset when using it as the target in LanguagePairDataset.
+        *append_eos_to_tgt* to ``True``.
         """
         src_dict, src_tokens, _ = self._get_test_data_with_bpe_cont_marker(
@@ -499,7 +499,7 @@ class TestDataNoising(unittest.TestCase):
         denoising_batch_result = self._get_noising_dataset_batch(
             src_tokens_no_pad=src_tokens_no_pad,
             src_dict=src_dict,
-            use_append_eos_dataset=True,
+            append_eos_to_tgt=True,
         )
         eos, pad = src_dict.eos(), src_dict.pad()
 ...