Commit 86e93f2b authored by Liezl Puzon's avatar Liezl Puzon Committed by Facebook Github Bot

Explicitly list out generation args for backtranslation dataset

Summary:
Using an argparse Namespace hides which args are actually expected and makes the code harder to read.

Note the difference in style for the args list:

    def __init__(
        self,
        tgt_dataset,
        tgt_dict,
        backtranslation_model,
        unkpen,
        sampling,
        beam,
        max_len_a,
        max_len_b,
    ):

instead of

    def __init__(
        self, tgt_dataset, tgt_dict, backtranslation_model, unkpen, sampling,
        beam,  max_len_a, max_len_b,
    ):
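The failure mode differs too: with a Namespace, a forgotten arg only surfaces as an AttributeError when the field is first read, while explicit parameters make Python reject the bad call up front. A minimal standalone sketch (illustrative only, not part of this diff):

    import argparse

    def generate_with_namespace(args):
        # The signature says nothing about which fields are required; a
        # forgotten field only fails at use time, deep inside the call.
        return args.beam, args.max_len_a, args.max_len_b

    def generate_explicit(beam, max_len_a, max_len_b):
        # Required args are visible in the signature, and a bad call fails
        # immediately with a TypeError naming the missing parameter.
        return beam, max_len_a, max_len_b

    opts = argparse.Namespace(beam=2, max_len_a=0)  # forgot max_len_b
    try:
        generate_with_namespace(opts)
    except AttributeError as err:
        print(err)  # 'Namespace' object has no attribute 'max_len_b'

    try:
        generate_explicit(beam=2, max_len_a=0)
    except TypeError as err:
        print(err)  # missing 1 required positional argument: 'max_len_b'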

Reviewed By: dpacgopinath

Differential Revision: D10152331

fbshipit-source-id: 6539ccba09d48acf23759996b7e32fb329b3e3f6
parent 22e535e2
--- a/fairseq/data/backtranslation_dataset.py
+++ b/fairseq/data/backtranslation_dataset.py
@@ -12,18 +12,22 @@ from . import FairseqDataset, language_pair_dataset
 
 class BacktranslationDataset(FairseqDataset):
-    def __init__(self, args, tgt_dataset, tgt_dict, backtranslation_model):
+    def __init__(
+        self,
+        tgt_dataset,
+        tgt_dict,
+        backtranslation_model,
+        unkpen,
+        sampling,
+        beam,
+        max_len_a,
+        max_len_b,
+    ):
         """
         Sets up a backtranslation dataset which takes a tgt batch, generates
         a src using a tgt-src backtranslation_model, and returns the
         corresponding {generated src, input tgt} batch
 
         Args:
-            args: generation args for the backtranslation SequenceGenerator.
-                Note that there is no equivalent argparse code for these args
-                anywhere in our top level train scripts yet. Integration is
-                still in progress. You can still, however, test out this dataset
-                functionality with the appropriate args as in the corresponding
-                unittest: test_backtranslation_dataset.
             tgt_dataset: dataset which will be used to build self.tgt_dataset --
                 a LanguagePairDataset with tgt dataset as the source dataset and
                 None as the target dataset.
@@ -33,6 +37,8 @@ class BacktranslationDataset(FairseqDataset):
             tgt_dict: tgt dictionary (typically a joint src/tgt BPE dictionary)
             backtranslation_model: tgt-src model to use in the SequenceGenerator
                 to generate backtranslations from tgt batches
+            unkpen, sampling, beam, max_len_a, max_len_b: generation args for
+                the backtranslation SequenceGenerator
         """
         self.tgt_dataset = language_pair_dataset.LanguagePairDataset(
             src=tgt_dataset,
@@ -45,13 +51,13 @@ class BacktranslationDataset(FairseqDataset):
         self.backtranslation_generator = sequence_generator.SequenceGenerator(
             [backtranslation_model],
             tgt_dict,
-            unk_penalty=args.backtranslation_unkpen,
-            sampling=args.backtranslation_sampling,
-            beam_size=args.backtranslation_beam,
+            unk_penalty=unkpen,
+            sampling=sampling,
+            beam_size=beam,
         )
-        self.backtranslation_max_len_a = args.backtranslation_max_len_a
-        self.backtranslation_max_len_b = args.backtranslation_max_len_b
-        self.backtranslation_beam = args.backtranslation_beam
+        self.max_len_a = max_len_a
+        self.max_len_b = max_len_b
+        self.beam = beam
 
     def __getitem__(self, index):
         """
@@ -75,8 +81,10 @@ class BacktranslationDataset(FairseqDataset):
         feed to the backtranslation model. Then take the generated translation
         with best score as the source and the original net input as the target.
         """
-        collated_tgt_only_sample = self.tgt_dataset.collater(samples)
-        backtranslation_hypos = self._generate_hypotheses(collated_tgt_only_sample)
+        collated_tgt_only_sample = self.tgt_dataset.collater(samples=samples)
+        backtranslation_hypos = self._generate_hypotheses(
+            sample=collated_tgt_only_sample
+        )
 
         # Go through each tgt sentence in batch and its corresponding best
         # generated hypothesis and create a backtranslation data pair
@@ -125,7 +133,7 @@ class BacktranslationDataset(FairseqDataset):
         hypos = self.backtranslation_generator.generate(
             input,
             maxlen=int(
-                self.backtranslation_max_len_a * srclen + self.backtranslation_max_len_b
+                self.max_len_a * srclen + self.max_len_b
             ),
         )
         return hypos
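For context on the renamed fields: the generator caps each backtranslation at maxlen = int(max_len_a * srclen + max_len_b) tokens. A quick standalone illustration; the first pair of values matches the unit test below, the second pair is hypothetical:

    def maxlen(max_len_a, max_len_b, srclen):
        # Mirrors the cap computed in _generate_hypotheses above.
        return int(max_len_a * srclen + max_len_b)

    # Test defaults (max_len_a=0, max_len_b=200): a constant 200-token cap,
    # independent of the source length.
    print(maxlen(0, 200, srclen=10))   # 200
    print(maxlen(0, 200, srclen=500))  # 200

    # Hypothetical scaling config: the cap grows with the source length.
    print(maxlen(1.5, 10, srclen=20))  # 40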
--- a/tests/test_backtranslation_dataset.py
+++ b/tests/test_backtranslation_dataset.py
@@ -5,7 +5,6 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.
 
-import argparse
 import unittest
 
 import tests.utils as test_utils
@@ -18,18 +17,6 @@ class TestBacktranslationDataset(unittest.TestCase):
         self.tgt_dict, self.w1, self.w2, self.src_tokens, self.src_lengths, self.model = (
             test_utils.sequence_generator_setup()
         )
-        backtranslation_args = argparse.Namespace()
-        """
-        Same as defaults from fairseq/options.py
-        """
-        backtranslation_args.backtranslation_unkpen = 0
-        backtranslation_args.backtranslation_sampling = False
-        backtranslation_args.backtranslation_max_len_a = 0
-        backtranslation_args.backtranslation_max_len_b = 200
-        backtranslation_args.backtranslation_beam = 2
-        self.backtranslation_args = backtranslation_args
 
         dummy_src_samples = self.src_tokens
@@ -37,10 +24,14 @@ class TestBacktranslationDataset(unittest.TestCase):
     def test_backtranslation_dataset(self):
         backtranslation_dataset = BacktranslationDataset(
-            args=self.backtranslation_args,
             tgt_dataset=self.tgt_dataset,
             tgt_dict=self.tgt_dict,
             backtranslation_model=self.model,
+            unkpen=0,
+            sampling=False,
+            max_len_a=0,
+            max_len_b=200,
+            beam=2,
         )
         dataloader = torch.utils.data.DataLoader(
             backtranslation_dataset,
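For completeness, the test body truncated above feeds the dataset into a PyTorch DataLoader. A hedged sketch of how that consumption typically looks; the collate_fn hookup and the sample-dict layout are assumptions based on the usual FairseqDataset interface, not shown in this diff:

    # Sketch only: assumes `backtranslation_dataset` constructed as in the test.
    import torch

    dataloader = torch.utils.data.DataLoader(
        backtranslation_dataset,
        batch_size=2,
        collate_fn=backtranslation_dataset.collater,  # assumed collater hookup
    )
    for sample in dataloader:
        # Each batch should pair a generated src with its original input tgt.
        src_tokens = sample["net_input"]["src_tokens"]  # assumed sample layout
        break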