Commit 613ffeea authored by Deepak Gopinath, committed by Facebook GitHub Bot

Add size method to BacktranslationDataset + misc fixes (#325)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/325

RoundRobinZipDataset requires a size(index) method to be implemented in every dataset it wraps. This change also adds missing return statements in a few methods.
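For context, a minimal sketch of the size(index) contract that a round-robin wrapper depends on. ToyDataset and filter_indices below are hypothetical names, not fairseq code; only the size(index) signature mirrors the FairseqDataset convention.

# Hypothetical sketch, not fairseq code: shows why every wrapped dataset
# needs size(index). ToyDataset and filter_indices are invented names.
class ToyDataset:
    def __init__(self, sizes):
        self.sizes = sizes  # per-example token counts, known up front

    def size(self, index):
        # (src_len, tgt_len) tuple, following the FairseqDataset convention
        return (self.sizes[index], self.sizes[index])

def filter_indices(datasets, indices, max_src, max_tgt):
    # A wrapper like RoundRobinZipDataset can only drop over-long examples
    # across all of its datasets if each one exposes size(index).
    return [
        i for i in indices
        if all(d.size(i)[0] <= max_src and d.size(i)[1] <= max_tgt
               for d in datasets)
    ]

# Example: with max positions (10, 10), the 12-token example is dropped.
ds = ToyDataset([5, 12, 8])
print(filter_indices([ds], range(3), 10, 10))  # -> [0, 2]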

Reviewed By: liezl200

Differential Revision: D10457159

fbshipit-source-id: 01856eb455f2f3a21e7fb723129ff35fbe29e0ae
parent 1aae5f6a
@@ -55,7 +55,7 @@ class BacktranslationDataset(FairseqDataset):
         """
         self.tgt_dataset = language_pair_dataset.LanguagePairDataset(
             src=tgt_dataset,
-            src_sizes=None,
+            src_sizes=tgt_dataset.sizes,
             src_dict=tgt_dict,
             tgt=None,
             tgt_sizes=None,
@@ -141,19 +141,19 @@ class BacktranslationDataset(FairseqDataset):

     def get_dummy_batch(self, num_tokens, max_positions):
         """ Just use the tgt dataset get_dummy_batch """
-        self.tgt_dataset.get_dummy_batch(num_tokens, max_positions)
+        return self.tgt_dataset.get_dummy_batch(num_tokens, max_positions)

     def num_tokens(self, index):
         """ Just use the tgt dataset num_tokens """
-        self.tgt_dataset.num_tokens(index)
+        return self.tgt_dataset.num_tokens(index)

     def ordered_indices(self):
         """ Just use the tgt dataset ordered_indices """
-        self.tgt_dataset.ordered_indices
+        return self.tgt_dataset.ordered_indices()

     def valid_size(self, index, max_positions):
         """ Just use the tgt dataset size """
-        self.tgt_dataset.valid_size(index, max_positions)
+        return self.tgt_dataset.valid_size(index, max_positions)

     def _generate_hypotheses(self, sample):
         """
@@ -171,3 +171,11 @@ class BacktranslationDataset(FairseqDataset):
             ),
         )
         return hypos
+
+    def size(self, index):
+        """Return an example's size as a float or tuple. This value is used when
+        filtering a dataset with ``--max-positions``.
+        Here, we return the tgt dataset size for both src and tgt as an approximation,
+        since we do not know the src size until we backtranslate and generate src sentences.
+        """
+        return (self.tgt_dataset.size(index), self.tgt_dataset.size(index))
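To make the approximation concrete (the index and lengths below are invented for illustration; dataset stands for a BacktranslationDataset instance):

# Hypothetical example: if the tgt sentence at index 3 has 12 tokens,
# size(3) reports 12 for both sides, since the src sentence does not
# exist until backtranslation runs.
# dataset.size(3)  ->  (12, 12)
# Such an example survives a (14, 14) positions limit but is filtered
# out under a (10, 10) limit.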
@@ -121,6 +121,7 @@ class TestDataset(torch.utils.data.Dataset):
     def __init__(self, data):
         super().__init__()
         self.data = data
+        self.sizes = None

     def __getitem__(self, index):
         return self.data[index]