Commit 613ffeea, authored by Deepak Gopinath, committed by Facebook Github Bot
Browse files

Add size method to BacktranslationDataset + misc fixes (#325)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/325

RoundRobinZipDataset requires every dataset it is given to implement a size(index) method. This change also adds the missing return statements to a few methods.

Reviewed By: liezl200

Differential Revision: D10457159

fbshipit-source-id: 01856eb455f2f3a21e7fb723129ff35fbe29e0ae
parent 1aae5f6a
......@@ -55,7 +55,7 @@ class BacktranslationDataset(FairseqDataset):
"""
self.tgt_dataset = language_pair_dataset.LanguagePairDataset(
src=tgt_dataset,
src_sizes=None,
src_sizes=tgt_dataset.sizes,
src_dict=tgt_dict,
tgt=None,
tgt_sizes=None,
......@@ -141,19 +141,19 @@ class BacktranslationDataset(FairseqDataset):
def get_dummy_batch(self, num_tokens, max_positions):
    """Return a dummy batch produced by the wrapped target dataset.

    Both arguments are forwarded unchanged to
    ``self.tgt_dataset.get_dummy_batch`` and its result is returned as-is.
    """
    # Delegate exactly once; the result must be returned, not discarded.
    return self.tgt_dataset.get_dummy_batch(num_tokens, max_positions)
def num_tokens(self, index):
    """Return the token count the wrapped target dataset reports for ``index``.

    The call is forwarded to ``self.tgt_dataset.num_tokens`` and its result
    is returned unchanged.
    """
    # Delegate exactly once; the result must be returned, not discarded.
    return self.tgt_dataset.num_tokens(index)
def ordered_indices(self):
    """Return the index ordering computed by the wrapped target dataset.

    The call is forwarded to ``self.tgt_dataset.ordered_indices()`` and its
    result is returned unchanged.
    """
    # The method has to be *called* and its result returned; merely naming
    # the attribute is a no-op.
    return self.tgt_dataset.ordered_indices()
def valid_size(self, index, max_positions):
    """Return the wrapped target dataset's ``valid_size`` result for ``index``.

    Both arguments are forwarded unchanged to
    ``self.tgt_dataset.valid_size`` and its result is returned as-is.
    """
    # Delegate exactly once; the result must be returned, not discarded.
    return self.tgt_dataset.valid_size(index, max_positions)
def _generate_hypotheses(self, sample):
"""
......@@ -171,3 +171,11 @@ class BacktranslationDataset(FairseqDataset):
),
)
return hypos
def size(self, index):
    """Return the size of example ``index`` as a ``(src, tgt)`` tuple.

    This value is used when filtering a dataset with ``--max-positions``.

    The source size is not known until the target sentence has been
    backtranslated into a source sentence, so the target dataset's size is
    reported in both positions as an approximation.
    """
    # Look the size up once and reuse it for both tuple slots.
    tgt_size = self.tgt_dataset.size(index)
    return (tgt_size, tgt_size)
......@@ -121,6 +121,7 @@ class TestDataset(torch.utils.data.Dataset):
def __init__(self, data):
    """Store ``data`` as the backing container for this test dataset."""
    super().__init__()
    # No per-example sizes are precomputed; the attribute exists so that
    # code reading ``.sizes`` finds it.
    self.sizes = None
    self.data = data
def __getitem__(self, index):
    """Return the element of the backing ``data`` container at ``index``."""
    item = self.data[index]
    return item
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment