"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "c6714fc3bfc4b8ccba08ea68cebb095f2af1d75e"
Commit f766c9a0 authored by Liezl Puzon's avatar Liezl Puzon Committed by Facebook Github Bot
Browse files

Pass in kwargs and SequenceGenerator class to init BacktranslationDataset

Summary: This generalizes BacktranslationDataset to allow us to use any SequenceGenerator class. For example, if we want to use this model in PyTorch Translate, we can pass the following to BacktranslationDataset init: (1) a PyTorch Translate SequenceGenerator class as generator_class and (2) the appropriate args for initializing that class as kwargs.

Reviewed By: xianxl

Differential Revision: D10156552

fbshipit-source-id: 0495d825bf4727da96d0d9a40dc434135ff3486c
parent df88ba95
...@@ -17,11 +17,10 @@ class BacktranslationDataset(FairseqDataset): ...@@ -17,11 +17,10 @@ class BacktranslationDataset(FairseqDataset):
tgt_dataset, tgt_dataset,
tgt_dict, tgt_dict,
backtranslation_model, backtranslation_model,
unkpen,
sampling,
beam,
max_len_a, max_len_a,
max_len_b, max_len_b,
generator_class=sequence_generator.SequenceGenerator,
**kwargs,
): ):
""" """
Sets up a backtranslation dataset which takes a tgt batch, generates Sets up a backtranslation dataset which takes a tgt batch, generates
...@@ -37,8 +36,13 @@ class BacktranslationDataset(FairseqDataset): ...@@ -37,8 +36,13 @@ class BacktranslationDataset(FairseqDataset):
tgt_dict: tgt dictionary (typically a joint src/tgt BPE dictionary) tgt_dict: tgt dictionary (typically a joint src/tgt BPE dictionary)
backtranslation_model: tgt-src model to use in the SequenceGenerator backtranslation_model: tgt-src model to use in the SequenceGenerator
to generate backtranslations from tgt batches to generate backtranslations from tgt batches
unkpen, sampling, beam, max_len_a, max_len_b: generation args for max_len_a, max_len_b: args passed into generate() function of
the backtranslation SequenceGenerator the backtranslation SequenceGenerator
generator_class: which SequenceGenerator class to use for
backtranslation. Output of generate() should be the same format
as fairseq's SequenceGenerator
kwargs: generation args to init the backtranslation
SequenceGenerator
""" """
self.tgt_dataset = language_pair_dataset.LanguagePairDataset( self.tgt_dataset = language_pair_dataset.LanguagePairDataset(
src=tgt_dataset, src=tgt_dataset,
...@@ -48,16 +52,14 @@ class BacktranslationDataset(FairseqDataset): ...@@ -48,16 +52,14 @@ class BacktranslationDataset(FairseqDataset):
tgt_sizes=None, tgt_sizes=None,
tgt_dict=None, tgt_dict=None,
) )
self.backtranslation_generator = sequence_generator.SequenceGenerator(
[backtranslation_model],
tgt_dict,
unk_penalty=unkpen,
sampling=sampling,
beam_size=beam,
)
self.max_len_a = max_len_a self.max_len_a = max_len_a
self.max_len_b = max_len_b self.max_len_b = max_len_b
self.beam = beam self.backtranslation_generator = generator_class(
models=[backtranslation_model],
tgt_dict=tgt_dict,
**kwargs,
)
def __getitem__(self, index): def __getitem__(self, index):
""" """
......
...@@ -10,6 +10,7 @@ import unittest ...@@ -10,6 +10,7 @@ import unittest
import tests.utils as test_utils import tests.utils as test_utils
import torch import torch
from fairseq.data.backtranslation_dataset import BacktranslationDataset from fairseq.data.backtranslation_dataset import BacktranslationDataset
from fairseq import sequence_generator
class TestBacktranslationDataset(unittest.TestCase): class TestBacktranslationDataset(unittest.TestCase):
...@@ -23,15 +24,19 @@ class TestBacktranslationDataset(unittest.TestCase): ...@@ -23,15 +24,19 @@ class TestBacktranslationDataset(unittest.TestCase):
self.tgt_dataset = test_utils.TestDataset(data=dummy_src_samples) self.tgt_dataset = test_utils.TestDataset(data=dummy_src_samples)
def test_backtranslation_dataset(self): def test_backtranslation_dataset(self):
"""
SequenceGenerator kwargs are same as defaults from fairseq/options.py
"""
backtranslation_dataset = BacktranslationDataset( backtranslation_dataset = BacktranslationDataset(
tgt_dataset=self.tgt_dataset, tgt_dataset=self.tgt_dataset,
tgt_dict=self.tgt_dict, tgt_dict=self.tgt_dict,
backtranslation_model=self.model, backtranslation_model=self.model,
unkpen=0,
sampling=False,
max_len_a=0, max_len_a=0,
max_len_b=200, max_len_b=200,
beam=2, beam_size=2,
unk_penalty=0,
sampling=False,
generator_class=sequence_generator.SequenceGenerator,
) )
dataloader = torch.utils.data.DataLoader( dataloader = torch.utils.data.DataLoader(
backtranslation_dataset, backtranslation_dataset,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment