Commit 864b89d0 authored by Myle Ott
Browse files

Online backtranslation module


Co-authored-by: liezl200 <lie@fb.com>
parent a4fe8c99
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from fairseq import sequence_generator
from . import FairseqDataset, language_pair_dataset
class BacktranslationDataset(FairseqDataset):
    """Dataset that generates source sentences by backtranslation.

    Wraps a target-side dataset: each batch of tgt sentences is run through a
    tgt->src ``backtranslation_model`` and the best generated hypothesis
    becomes the source, paired with the original tgt as the target.
    """

    def __init__(self, args, tgt_dataset, tgt_dict, backtranslation_model):
        """
        Sets up a backtranslation dataset which takes a tgt batch, generates
        a src using a tgt-src backtranslation_model, and returns the
        corresponding {generated src, input tgt} batch.

        Args:
            args: generation args for the backtranslation SequenceGenerator.
                Note that there is no equivalent argparse code for these args
                anywhere in our top level train scripts yet. Integration is
                still in progress. You can still, however, test out this
                dataset functionality with the appropriate args as in the
                corresponding unittest: test_backtranslation_dataset.
            tgt_dataset: dataset which will be used to build self.tgt_dataset --
                a LanguagePairDataset with tgt dataset as the source dataset
                and None as the target dataset. We use language_pair_dataset
                here to encapsulate the tgt_dataset so we can re-use the
                LanguagePairDataset collater to format the batches in the
                structure that SequenceGenerator expects.
            tgt_dict: tgt dictionary (typically a joint src/tgt BPE dictionary)
            backtranslation_model: tgt-src model to use in the
                SequenceGenerator to generate backtranslations from tgt
                batches
        """
        self.tgt_dataset = language_pair_dataset.LanguagePairDataset(
            src=tgt_dataset,
            src_sizes=None,
            src_dict=tgt_dict,
            tgt=None,
            tgt_sizes=None,
            tgt_dict=None,
        )
        self.backtranslation_generator = sequence_generator.SequenceGenerator(
            [backtranslation_model],
            tgt_dict,
            unk_penalty=args.backtranslation_unkpen,
            sampling=args.backtranslation_sampling,
            beam_size=args.backtranslation_beam,
        )
        self.backtranslation_max_len_a = args.backtranslation_max_len_a
        self.backtranslation_max_len_b = args.backtranslation_max_len_b
        self.backtranslation_beam = args.backtranslation_beam

    def __getitem__(self, index):
        """
        Returns a single sample. Multiple samples are fed to the collater to
        create a backtranslation batch. Note you should always use collate_fn
        BacktranslationDataset.collater() below if given the option to
        specify which collate_fn to use (e.g. in a dataloader which uses this
        BacktranslationDataset -- see corresponding unittest for an example).
        """
        return self.tgt_dataset[index]

    def __len__(self):
        """The length of the backtranslation dataset is the length of tgt."""
        return len(self.tgt_dataset)

    def collater(self, samples):
        """
        Using the samples from the tgt dataset, load a collated tgt sample to
        feed to the backtranslation model. Then take the generated translation
        with best score as the source and the original net input as the
        target.
        """
        collated_tgt_only_sample = self.tgt_dataset.collater(samples)
        backtranslation_hypos = self._generate_hypotheses(
            collated_tgt_only_sample
        )

        # Go through each tgt sentence in batch and its corresponding best
        # generated hypothesis and create a backtranslation data pair:
        # {id: id, source: generated backtranslation, target: original tgt}
        generated_samples = []
        for input_sample, hypos in zip(samples, backtranslation_hypos):
            generated_samples.append(
                {
                    "id": input_sample["id"],
                    "source": hypos[0]["tokens"],  # first hypo is best hypo
                    "target": input_sample["source"],
                }
            )

        return language_pair_dataset.collate(
            samples=generated_samples,
            pad_idx=self.tgt_dataset.src_dict.pad(),
            eos_idx=self.tgt_dataset.src_dict.eos(),
        )

    def get_dummy_batch(self, num_tokens, max_positions):
        """Just use the tgt dataset get_dummy_batch."""
        # BUG FIX: the delegated result was previously dropped (no return),
        # so callers always received None.
        return self.tgt_dataset.get_dummy_batch(num_tokens, max_positions)

    def num_tokens(self, index):
        """Just use the tgt dataset num_tokens."""
        # BUG FIX: previously missing the return statement.
        return self.tgt_dataset.num_tokens(index)

    def ordered_indices(self):
        """Just use the tgt dataset ordered_indices."""
        # BUG FIX: previously referenced the bound method without calling it
        # and returned None; now call it and return the indices.
        return self.tgt_dataset.ordered_indices()

    def valid_size(self, index, max_positions):
        """Just use the tgt dataset size."""
        # BUG FIX: previously missing the return statement.
        return self.tgt_dataset.valid_size(index, max_positions)

    def _generate_hypotheses(self, sample):
        """
        Generates hypotheses from a LanguagePairDataset collated / batched
        sample. Note in this case, sample["target"] is None, and
        sample["net_input"]["src_tokens"] is really in tgt language.
        """
        # NOTE(review): unconditionally moves the generator to GPU, so this
        # dataset currently requires CUDA -- confirm intended.
        self.backtranslation_generator.cuda()
        input = sample["net_input"]
        srclen = input["src_tokens"].size(1)
        hypos = self.backtranslation_generator.generate(
            input,
            # Cap generation length as a linear function of the input length.
            maxlen=int(
                self.backtranslation_max_len_a * srclen
                + self.backtranslation_max_len_b
            ),
        )
        return hypos
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import argparse
import unittest
import tests.utils as test_utils
import torch
from fairseq.data.backtranslation_dataset import BacktranslationDataset
class TestBacktranslationDataset(unittest.TestCase):
def setUp(self):
self.tgt_dict, self.w1, self.w2, self.src_tokens, self.src_lengths, self.model = (
test_utils.sequence_generator_setup()
)
backtranslation_args = argparse.Namespace()
"""
Same as defaults from fairseq/options.py
"""
backtranslation_args.backtranslation_unkpen = 0
backtranslation_args.backtranslation_sampling = False
backtranslation_args.backtranslation_max_len_a = 0
backtranslation_args.backtranslation_max_len_b = 200
backtranslation_args.backtranslation_beam = 2
self.backtranslation_args = backtranslation_args
dummy_src_samples = self.src_tokens
self.tgt_dataset = test_utils.TestDataset(data=dummy_src_samples)
def test_backtranslation_dataset(self):
backtranslation_dataset = BacktranslationDataset(
args=self.backtranslation_args,
tgt_dataset=self.tgt_dataset,
tgt_dict=self.tgt_dict,
backtranslation_model=self.model,
)
dataloader = torch.utils.data.DataLoader(
backtranslation_dataset,
batch_size=2,
collate_fn=backtranslation_dataset.collater,
)
backtranslation_batch_result = next(iter(dataloader))
eos, pad, w1, w2 = self.tgt_dict.eos(), self.tgt_dict.pad(), self.w1, self.w2
# Note that we sort by src_lengths and add left padding, so actually
# ids will look like: [1, 0]
expected_src = torch.LongTensor([[w1, w2, w1, eos], [pad, pad, w1, eos]])
expected_tgt = torch.LongTensor([[w1, w2, eos], [w1, w2, eos]])
generated_src = backtranslation_batch_result["net_input"]["src_tokens"]
tgt_tokens = backtranslation_batch_result["target"]
self.assertTensorEqual(expected_src, generated_src)
self.assertTensorEqual(expected_tgt, tgt_tokens)
def assertTensorEqual(self, t1, t2):
self.assertEqual(t1.size(), t2.size(), "size mismatch")
self.assertEqual(t1.ne(t2).long().sum(), 0)
if __name__ == "__main__":
unittest.main()
...@@ -18,79 +18,17 @@ import tests.utils as test_utils ...@@ -18,79 +18,17 @@ import tests.utils as test_utils
class TestSequenceGenerator(unittest.TestCase): class TestSequenceGenerator(unittest.TestCase):
def setUp(self): def setUp(self):
# construct dummy dictionary self.tgt_dict, self.w1, self.w2, src_tokens, src_lengths, self.model = (
d = test_utils.dummy_dictionary(vocab_size=2) test_utils.sequence_generator_setup()
self.assertEqual(d.pad(), 1) )
self.assertEqual(d.eos(), 2)
self.assertEqual(d.unk(), 3)
self.eos = d.eos()
self.w1 = 4
self.w2 = 5
# construct source data
self.src_tokens = torch.LongTensor([
[self.w1, self.w2, self.eos],
[self.w1, self.w2, self.eos],
])
self.src_lengths = torch.LongTensor([2, 2])
self.encoder_input = { self.encoder_input = {
'src_tokens': self.src_tokens, 'src_tokens': src_tokens, 'src_lengths': src_lengths,
'src_lengths': self.src_lengths,
} }
args = argparse.Namespace()
unk = 0.
args.beam_probs = [
# step 0:
torch.FloatTensor([
# eos w1 w2
# sentence 1:
[0.0, unk, 0.9, 0.1], # beam 1
[0.0, unk, 0.9, 0.1], # beam 2
# sentence 2:
[0.0, unk, 0.7, 0.3],
[0.0, unk, 0.7, 0.3],
]),
# step 1:
torch.FloatTensor([
# eos w1 w2 prefix
# sentence 1:
[1.0, unk, 0.0, 0.0], # w1: 0.9 (emit: w1 <eos>: 0.9*1.0)
[0.0, unk, 0.9, 0.1], # w2: 0.1
# sentence 2:
[0.25, unk, 0.35, 0.4], # w1: 0.7 (don't emit: w1 <eos>: 0.7*0.25)
[0.00, unk, 0.10, 0.9], # w2: 0.3
]),
# step 2:
torch.FloatTensor([
# eos w1 w2 prefix
# sentence 1:
[0.0, unk, 0.1, 0.9], # w2 w1: 0.1*0.9
[0.6, unk, 0.2, 0.2], # w2 w2: 0.1*0.1 (emit: w2 w2 <eos>: 0.1*0.1*0.6)
# sentence 2:
[0.60, unk, 0.4, 0.00], # w1 w2: 0.7*0.4 (emit: w1 w2 <eos>: 0.7*0.4*0.6)
[0.01, unk, 0.0, 0.99], # w2 w2: 0.3*0.9
]),
# step 3:
torch.FloatTensor([
# eos w1 w2 prefix
# sentence 1:
[1.0, unk, 0.0, 0.0], # w2 w1 w2: 0.1*0.9*0.9 (emit: w2 w1 w2 <eos>: 0.1*0.9*0.9*1.0)
[1.0, unk, 0.0, 0.0], # w2 w1 w1: 0.1*0.9*0.1 (emit: w2 w1 w1 <eos>: 0.1*0.9*0.1*1.0)
# sentence 2:
[0.1, unk, 0.5, 0.4], # w2 w2 w2: 0.3*0.9*0.99 (emit: w2 w2 w2 <eos>: 0.3*0.9*0.99*0.1)
[1.0, unk, 0.0, 0.0], # w1 w2 w1: 0.7*0.4*0.4 (emit: w1 w2 w1 <eos>: 0.7*0.4*0.4*1.0)
]),
]
task = test_utils.TestTranslationTask.setup_task(args, d, d)
self.model = task.build_model(args)
self.tgt_dict = task.target_dictionary
def test_with_normalization(self): def test_with_normalization(self):
generator = SequenceGenerator([self.model], self.tgt_dict) generator = SequenceGenerator([self.model], self.tgt_dict)
hypos = generator.generate(self.encoder_input, beam_size=2) hypos = generator.generate(self.encoder_input, beam_size=2)
eos, w1, w2 = self.eos, self.w1, self.w2 eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
# sentence 1, beam 1 # sentence 1, beam 1
self.assertHypoTokens(hypos[0][0], [w1, eos]) self.assertHypoTokens(hypos[0][0], [w1, eos])
self.assertHypoScore(hypos[0][0], [0.9, 1.0]) self.assertHypoScore(hypos[0][0], [0.9, 1.0])
...@@ -109,7 +47,7 @@ class TestSequenceGenerator(unittest.TestCase): ...@@ -109,7 +47,7 @@ class TestSequenceGenerator(unittest.TestCase):
# Sentence 2: beams swap order # Sentence 2: beams swap order
generator = SequenceGenerator([self.model], self.tgt_dict, normalize_scores=False) generator = SequenceGenerator([self.model], self.tgt_dict, normalize_scores=False)
hypos = generator.generate(self.encoder_input, beam_size=2) hypos = generator.generate(self.encoder_input, beam_size=2)
eos, w1, w2 = self.eos, self.w1, self.w2 eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
# sentence 1, beam 1 # sentence 1, beam 1
self.assertHypoTokens(hypos[0][0], [w1, eos]) self.assertHypoTokens(hypos[0][0], [w1, eos])
self.assertHypoScore(hypos[0][0], [0.9, 1.0], normalized=False) self.assertHypoScore(hypos[0][0], [0.9, 1.0], normalized=False)
...@@ -127,7 +65,7 @@ class TestSequenceGenerator(unittest.TestCase): ...@@ -127,7 +65,7 @@ class TestSequenceGenerator(unittest.TestCase):
lenpen = 0.6 lenpen = 0.6
generator = SequenceGenerator([self.model], self.tgt_dict, len_penalty=lenpen) generator = SequenceGenerator([self.model], self.tgt_dict, len_penalty=lenpen)
hypos = generator.generate(self.encoder_input, beam_size=2) hypos = generator.generate(self.encoder_input, beam_size=2)
eos, w1, w2 = self.eos, self.w1, self.w2 eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
# sentence 1, beam 1 # sentence 1, beam 1
self.assertHypoTokens(hypos[0][0], [w1, eos]) self.assertHypoTokens(hypos[0][0], [w1, eos])
self.assertHypoScore(hypos[0][0], [0.9, 1.0], lenpen=lenpen) self.assertHypoScore(hypos[0][0], [0.9, 1.0], lenpen=lenpen)
...@@ -145,7 +83,7 @@ class TestSequenceGenerator(unittest.TestCase): ...@@ -145,7 +83,7 @@ class TestSequenceGenerator(unittest.TestCase):
lenpen = 5.0 lenpen = 5.0
generator = SequenceGenerator([self.model], self.tgt_dict, len_penalty=lenpen) generator = SequenceGenerator([self.model], self.tgt_dict, len_penalty=lenpen)
hypos = generator.generate(self.encoder_input, beam_size=2) hypos = generator.generate(self.encoder_input, beam_size=2)
eos, w1, w2 = self.eos, self.w1, self.w2 eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
# sentence 1, beam 1 # sentence 1, beam 1
self.assertHypoTokens(hypos[0][0], [w2, w1, w2, eos]) self.assertHypoTokens(hypos[0][0], [w2, w1, w2, eos])
self.assertHypoScore(hypos[0][0], [0.1, 0.9, 0.9, 1.0], lenpen=lenpen) self.assertHypoScore(hypos[0][0], [0.1, 0.9, 0.9, 1.0], lenpen=lenpen)
...@@ -162,7 +100,7 @@ class TestSequenceGenerator(unittest.TestCase): ...@@ -162,7 +100,7 @@ class TestSequenceGenerator(unittest.TestCase):
def test_maxlen(self): def test_maxlen(self):
generator = SequenceGenerator([self.model], self.tgt_dict, maxlen=2) generator = SequenceGenerator([self.model], self.tgt_dict, maxlen=2)
hypos = generator.generate(self.encoder_input, beam_size=2) hypos = generator.generate(self.encoder_input, beam_size=2)
eos, w1, w2 = self.eos, self.w1, self.w2 eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
# sentence 1, beam 1 # sentence 1, beam 1
self.assertHypoTokens(hypos[0][0], [w1, eos]) self.assertHypoTokens(hypos[0][0], [w1, eos])
self.assertHypoScore(hypos[0][0], [0.9, 1.0]) self.assertHypoScore(hypos[0][0], [0.9, 1.0])
...@@ -179,7 +117,7 @@ class TestSequenceGenerator(unittest.TestCase): ...@@ -179,7 +117,7 @@ class TestSequenceGenerator(unittest.TestCase):
def test_no_stop_early(self): def test_no_stop_early(self):
generator = SequenceGenerator([self.model], self.tgt_dict, stop_early=False) generator = SequenceGenerator([self.model], self.tgt_dict, stop_early=False)
hypos = generator.generate(self.encoder_input, beam_size=2) hypos = generator.generate(self.encoder_input, beam_size=2)
eos, w1, w2 = self.eos, self.w1, self.w2 eos, w1, w2 = self.tgt_dict.eos(), self.w1, self.w2
# sentence 1, beam 1 # sentence 1, beam 1
self.assertHypoTokens(hypos[0][0], [w1, eos]) self.assertHypoTokens(hypos[0][0], [w1, eos])
self.assertHypoScore(hypos[0][0], [0.9, 1.0]) self.assertHypoScore(hypos[0][0], [0.9, 1.0])
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
# the root directory of this source tree. An additional grant of patent rights # the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory. # can be found in the PATENTS file in the same directory.
import argparse
import torch import torch
from fairseq import utils from fairseq import utils
...@@ -51,6 +52,70 @@ def dummy_dataloader( ...@@ -51,6 +52,70 @@ def dummy_dataloader(
return iter(dataloader) return iter(dataloader)
def sequence_generator_setup():
    """Build a small deterministic fixture for SequenceGenerator tests.

    Returns:
        Tuple ``(tgt_dict, w1, w2, src_tokens, src_lengths, model)``: a dummy
        dictionary, two word ids, a two-sentence source batch, and a test
        model built from ``args`` (``args.beam_probs`` fixes the per-step
        output distributions -- presumably consumed by the model returned by
        TestTranslationTask; verify against tests.utils).
    """
    # construct dummy dictionary
    d = dummy_dictionary(vocab_size=2)
    eos = d.eos()
    w1 = 4
    w2 = 5
    # construct source data
    src_tokens = torch.LongTensor([[w1, w2, eos], [w1, w2, eos]])
    src_lengths = torch.LongTensor([2, 2])

    args = argparse.Namespace()
    unk = 0.
    # Per-step probabilities over (eos, unk, w1, w2) for each of the 4 beams
    # (2 sentences x beam_size 2); inline comments trace the emitted hypos.
    args.beam_probs = [
        # step 0:
        torch.FloatTensor([
            # eos w1 w2
            # sentence 1:
            [0.0, unk, 0.9, 0.1],  # beam 1
            [0.0, unk, 0.9, 0.1],  # beam 2
            # sentence 2:
            [0.0, unk, 0.7, 0.3],
            [0.0, unk, 0.7, 0.3],
        ]),
        # step 1:
        torch.FloatTensor([
            # eos w1 w2 prefix
            # sentence 1:
            [1.0, unk, 0.0, 0.0],  # w1: 0.9 (emit: w1 <eos>: 0.9*1.0)
            [0.0, unk, 0.9, 0.1],  # w2: 0.1
            # sentence 2:
            [0.25, unk, 0.35, 0.4],  # w1: 0.7 (don't emit: w1 <eos>: 0.7*0.25)
            [0.00, unk, 0.10, 0.9],  # w2: 0.3
        ]),
        # step 2:
        torch.FloatTensor([
            # eos w1 w2 prefix
            # sentence 1:
            [0.0, unk, 0.1, 0.9],  # w2 w1: 0.1*0.9
            [0.6, unk, 0.2, 0.2],  # w2 w2: 0.1*0.1 (emit: w2 w2 <eos>: 0.1*0.1*0.6)
            # sentence 2:
            [0.60, unk, 0.4, 0.00],  # w1 w2: 0.7*0.4 (emit: w1 w2 <eos>: 0.7*0.4*0.6)
            [0.01, unk, 0.0, 0.99],  # w2 w2: 0.3*0.9
        ]),
        # step 3:
        torch.FloatTensor([
            # eos w1 w2 prefix
            # sentence 1:
            [1.0, unk, 0.0, 0.0],  # w2 w1 w2: 0.1*0.9*0.9 (emit: w2 w1 w2 <eos>: 0.1*0.9*0.9*1.0)
            [1.0, unk, 0.0, 0.0],  # w2 w1 w1: 0.1*0.9*0.1 (emit: w2 w1 w1 <eos>: 0.1*0.9*0.1*1.0)
            # sentence 2:
            [0.1, unk, 0.5, 0.4],  # w2 w2 w2: 0.3*0.9*0.99 (emit: w2 w2 w2 <eos>: 0.3*0.9*0.99*0.1)
            [1.0, unk, 0.0, 0.0],  # w1 w2 w1: 0.7*0.4*0.4 (emit: w1 w2 w1 <eos>: 0.7*0.4*0.4*1.0)
        ]),
    ]

    task = TestTranslationTask.setup_task(args, d, d)
    model = task.build_model(args)
    tgt_dict = task.target_dictionary

    return tgt_dict, w1, w2, src_tokens, src_lengths, model
class TestDataset(torch.utils.data.Dataset): class TestDataset(torch.utils.data.Dataset):
def __init__(self, data): def __init__(self, data):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment