Commit f766c9a0 authored by Liezl Puzon's avatar Liezl Puzon Committed by Facebook Github Bot
Browse files

Pass in kwargs and SequenceGenerator class to init BacktranslationDataset

Summary: This generalizes BacktranslationDataset to allow us to use any SequenceGenerator class. For example, if we want to use this model in PyTorch Translate, we can pass the following to BacktranslationDataset init: (1) a PyTorch Translate SequenceGenerator class as generator_class and (2) the appropriate args for initializing that class as kwargs.

Reviewed By: xianxl

Differential Revision: D10156552

fbshipit-source-id: 0495d825bf4727da96d0d9a40dc434135ff3486c
parent df88ba95
......@@ -17,11 +17,10 @@ class BacktranslationDataset(FairseqDataset):
tgt_dataset,
tgt_dict,
backtranslation_model,
unkpen,
sampling,
beam,
max_len_a,
max_len_b,
generator_class=sequence_generator.SequenceGenerator,
**kwargs,
):
"""
Sets up a backtranslation dataset which takes a tgt batch, generates
......@@ -37,8 +36,13 @@ class BacktranslationDataset(FairseqDataset):
tgt_dict: tgt dictionary (typically a joint src/tgt BPE dictionary)
backtranslation_model: tgt-src model to use in the SequenceGenerator
to generate backtranslations from tgt batches
unkpen, sampling, beam, max_len_a, max_len_b: generation args for
max_len_a, max_len_b: args passed into generate() function of
the backtranslation SequenceGenerator
generator_class: which SequenceGenerator class to use for
backtranslation. Output of generate() should be the same format
as fairseq's SequenceGenerator
kwargs: generation args to init the backtranslation
SequenceGenerator
"""
self.tgt_dataset = language_pair_dataset.LanguagePairDataset(
src=tgt_dataset,
......@@ -48,16 +52,14 @@ class BacktranslationDataset(FairseqDataset):
tgt_sizes=None,
tgt_dict=None,
)
self.backtranslation_generator = sequence_generator.SequenceGenerator(
[backtranslation_model],
tgt_dict,
unk_penalty=unkpen,
sampling=sampling,
beam_size=beam,
)
self.max_len_a = max_len_a
self.max_len_b = max_len_b
self.beam = beam
self.backtranslation_generator = generator_class(
models=[backtranslation_model],
tgt_dict=tgt_dict,
**kwargs,
)
def __getitem__(self, index):
"""
......
......@@ -10,6 +10,7 @@ import unittest
import tests.utils as test_utils
import torch
from fairseq.data.backtranslation_dataset import BacktranslationDataset
from fairseq import sequence_generator
class TestBacktranslationDataset(unittest.TestCase):
......@@ -23,15 +24,19 @@ class TestBacktranslationDataset(unittest.TestCase):
self.tgt_dataset = test_utils.TestDataset(data=dummy_src_samples)
def test_backtranslation_dataset(self):
"""
SequenceGenerator kwargs are same as defaults from fairseq/options.py
"""
backtranslation_dataset = BacktranslationDataset(
tgt_dataset=self.tgt_dataset,
tgt_dict=self.tgt_dict,
backtranslation_model=self.model,
unkpen=0,
sampling=False,
max_len_a=0,
max_len_b=200,
beam=2,
beam_size=2,
unk_penalty=0,
sampling=False,
generator_class=sequence_generator.SequenceGenerator,
)
dataloader = torch.utils.data.DataLoader(
backtranslation_dataset,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment