"examples/vscode:/vscode.git/clone" did not exist on "8b451eb63b0f101e7fcc72365fe0d683808b22cd"
Commit 613ffeea authored by Deepak Gopinath, committed by Facebook Github Bot

Add size method to BacktranslationDataset + misc fixes (#325)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/325

RoundRobinZipDataset requires a size(index) method implemented in every dataset used. Also added missing return statements in a few methods.

Reviewed By: liezl200

Differential Revision: D10457159

fbshipit-source-id: 01856eb455f2f3a21e7fb723129ff35fbe29e0ae
parent 1aae5f6a
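
Why the new method is needed: a round-robin/zip-style wrapper builds batches across several datasets at once, so it has to be able to ask each constituent dataset for a per-example size. The sketch below only illustrates that contract; it is not fairseq's RoundRobinZipDataset, and the class name and structure are made up for the example.

# Illustrative sketch of the contract the commit message refers to:
# every wrapped dataset must implement size(index). This is NOT
# fairseq's RoundRobinZipDataset.
class ZipOfDatasets:
    def __init__(self, datasets):
        # `datasets` maps a key (e.g. a language pair) to a dataset object.
        self.datasets = datasets

    def size(self, index):
        # The wrapper can only report a combined size if every
        # constituent dataset implements size(index) itself.
        return {key: dataset.size(index) for key, dataset in self.datasets.items()}
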
@@ -55,7 +55,7 @@ class BacktranslationDataset(FairseqDataset):
         """
         self.tgt_dataset = language_pair_dataset.LanguagePairDataset(
             src=tgt_dataset,
-            src_sizes=None,
+            src_sizes=tgt_dataset.sizes,
             src_dict=tgt_dict,
             tgt=None,
             tgt_sizes=None,
@@ -141,19 +141,19 @@ class BacktranslationDataset(FairseqDataset):

     def get_dummy_batch(self, num_tokens, max_positions):
         """ Just use the tgt dataset get_dummy_batch """
-        self.tgt_dataset.get_dummy_batch(num_tokens, max_positions)
+        return self.tgt_dataset.get_dummy_batch(num_tokens, max_positions)

     def num_tokens(self, index):
         """ Just use the tgt dataset num_tokens """
-        self.tgt_dataset.num_tokens(index)
+        return self.tgt_dataset.num_tokens(index)

     def ordered_indices(self):
         """ Just use the tgt dataset ordered_indices """
-        self.tgt_dataset.ordered_indices
+        return self.tgt_dataset.ordered_indices()

     def valid_size(self, index, max_positions):
         """ Just use the tgt dataset size """
-        self.tgt_dataset.valid_size(index, max_positions)
+        return self.tgt_dataset.valid_size(index, max_positions)

     def _generate_hypotheses(self, sample):
         """
@@ -171,3 +171,11 @@ class BacktranslationDataset(FairseqDataset):
             ),
         )
         return hypos
+
+    def size(self, index):
+        """Return an example's size as a float or tuple. This value is used when
+        filtering a dataset with ``--max-positions``.
+        Here, we return the tgt dataset size as an approximation of the src size,
+        since we do not know the src size until we backtranslate and generate src sentences.
+        """
+        return (self.tgt_dataset.size(index), self.tgt_dataset.size(index))
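
How the returned tuple might be consumed: size(index) is documented above as feeding ``--max-positions`` filtering, so a filter would compare the per-example (src_len, tgt_len) against the configured limits. The helper below is a hypothetical sketch under that assumption, not fairseq's actual filtering code.

# Hypothetical sketch: check a (src_len, tgt_len) example size against
# (max_src_positions, max_tgt_positions). Not fairseq's implementation.
def fits_max_positions(example_size, max_positions):
    return all(size <= limit for size, limit in zip(example_size, max_positions))

# With the new BacktranslationDataset.size(), both tuple entries are the
# tgt length, since the src length is unknown before backtranslation:
# fits_max_positions((27, 27), (1024, 1024))  -> True
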
@@ -121,6 +121,7 @@ class TestDataset(torch.utils.data.Dataset):
     def __init__(self, data):
         super().__init__()
         self.data = data
+        self.sizes = None

     def __getitem__(self, index):
         return self.data[index]
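
The test-only line pairs with the first hunk above: LanguagePairDataset is now constructed with src_sizes=tgt_dataset.sizes, so any dataset wrapped by BacktranslationDataset must at least expose a sizes attribute; setting self.sizes = None in TestDataset appears to exist to satisfy that attribute lookup in the unit tests.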