Commit 5bbd148e authored by Myle Ott, committed by Facebook GitHub Bot

Fix tests + style nits + Python 3.5 compat

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/336

Differential Revision: D12876709

Pulled By: myleott

fbshipit-source-id: a31536e2eb93f752600b9940c28e9b9fcefc8b86
parent f3a0939e
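Note: the "Python 3.5 compat" part of this change replaces f-strings, which are a SyntaxError before Python 3.6 (PEP 498), with str.format(). A minimal illustration of the rewrite pattern applied throughout the diff below (the message text is taken from the diff; the eos value is an arbitrary example):

    eos = 2  # example token id

    # Python >= 3.6 only (rejected by the 3.5 parser):
    #   msg = f"Expected eos (token id {eos}) at end"

    # Python 3.5-compatible equivalent used in this commit:
    msg = "Expected eos (token id {eos}) at end".format(eos=eos)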
fairseq/data/__init__.py
@@ -7,10 +7,10 @@
 from .dictionary import Dictionary, TruncatedDictionary
 from .fairseq_dataset import FairseqDataset
-from .concat_dataset import ConcatDataset
-from .indexed_dataset import IndexedDataset, IndexedCachedDataset, IndexedInMemoryDataset, IndexedRawTextDataset
 from .append_eos_dataset import AppendEosDataset
 from .backtranslation_dataset import BacktranslationDataset
+from .concat_dataset import ConcatDataset
+from .indexed_dataset import IndexedDataset, IndexedCachedDataset, IndexedInMemoryDataset, IndexedRawTextDataset
 from .language_pair_dataset import LanguagePairDataset
 from .monolingual_dataset import MonolingualDataset
 from .round_robin_zip_datasets import RoundRobinZipDatasets
@@ -25,6 +25,7 @@ from .iterators import (
 __all__ = [
     'AppendEosDataset',
+    'BacktranslationDataset',
     'ConcatDataset',
     'CountingIterator',
     'Dictionary',
@@ -40,5 +41,4 @@ __all__ = [
     'RoundRobinZipDatasets',
     'ShardedIterator',
     'TokenBlockDataset',
-    'BacktranslationDataset',
 ]
fairseq/data/append_eos_dataset.py
@@ -17,7 +17,6 @@ class AppendEosDataset(torch.utils.data.Dataset):
     def __getitem__(self, index):
         item = torch.cat([self.dataset[index], torch.LongTensor([self.eos])])
-        print(item)
         return item
 
     def __len__(self):
fairseq/data/backtranslation_dataset.py
@@ -23,6 +23,7 @@ class BacktranslationDataset(FairseqDataset):
         max_len_b,
         remove_eos_at_src=False,
         generator_class=sequence_generator.SequenceGenerator,
+        cuda=True,
         **kwargs,
     ):
         """
@@ -51,6 +52,7 @@ class BacktranslationDataset(FairseqDataset):
             generator_class: which SequenceGenerator class to use for
                 backtranslation. Output of generate() should be the same format
                 as fairseq's SequenceGenerator
+            cuda: use GPU for generation
             kwargs: generation args to init the backtranslation
                 SequenceGenerator
         """
@@ -73,6 +75,10 @@ class BacktranslationDataset(FairseqDataset):
             **kwargs,
         )
 
+        self.cuda = cuda if torch.cuda.is_available() else False
+        if self.cuda:
+            self.backtranslation_generator.cuda()
+
     def __getitem__(self, index):
         """
         Returns a single sample. Multiple samples are fed to the collater to
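The new cuda flag requests GPU generation but degrades gracefully when no GPU is present. A minimal sketch of the same gating pattern, using an illustrative wrapper class that is not part of fairseq:

    import torch

    class GeneratorWrapper:
        # Illustrative only: honor a cuda=True request only when a GPU
        # actually exists, so the same code runs on CPU-only machines.
        def __init__(self, model, cuda=True):
            self.cuda = cuda if torch.cuda.is_available() else False
            self.model = model.cuda() if self.cuda else model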
@@ -105,32 +111,32 @@ class BacktranslationDataset(FairseqDataset):
         # {id: id, source: generated backtranslation, target: original tgt}
         generated_samples = []
         for input_sample, hypos in zip(samples, backtranslation_hypos):
-            eos = self.tgt_dataset.src_dict.eos()
+            original_tgt = input_sample["source"].cpu()
+            generated_source = hypos[0]["tokens"].cpu()  # first hypo is best hypo
 
             # Append EOS to the tgt sentence if it does not have an EOS
             # This is the case if the samples in monolingual tgt_dataset don't
             # have an EOS appended to the end of each sentence.
-            original_tgt = input_sample["source"]
+            eos = self.tgt_dataset.src_dict.eos()
             if original_tgt[-1] != eos:
                 original_tgt = torch.cat([original_tgt, torch.LongTensor([eos])])
 
             # The generated source dialect backtranslation will have an EOS.
             # If we want our parallel data source to not have an EOS, we will
             # have to remove it.
-            generated_source = hypos[0]["tokens"]  # first hypo is best hypo
             if self.remove_eos_at_src:
                 assert generated_source[-1] == eos, (
-                    f"Expected generated backtranslation to have eos (id: "
-                    f"{eos}) at end, but instead found token id "
-                    f"{generated_source[-1]} at end."
-                )
+                    "Expected generated backtranslation to have eos (id: "
+                    "{eos}) at end, but instead found token id "
+                    "{generated_source[-1]} at end."
+                ).format(eos=eos, generated_source=generated_source)
                 generated_source = generated_source[:-1]
 
             generated_samples.append(
                 {
                     "id": input_sample["id"],
-                    "source": generated_source.cpu(),
-                    "target": original_tgt.cpu(),
+                    "source": generated_source,
+                    "target": original_tgt,
                 }
             )
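The reordered block above normalizes EOS handling: the original target gains a trailing EOS if the monolingual data lacked one, and the generated source can have its EOS stripped. A self-contained sketch of that logic (normalize_eos is a hypothetical helper for illustration, not fairseq API):

    import torch

    def normalize_eos(original_tgt, generated_source, eos, remove_eos_at_src=False):
        # Append EOS to the target if the monolingual data lacks it.
        if original_tgt[-1] != eos:
            original_tgt = torch.cat([original_tgt, torch.LongTensor([eos])])
        # Generation always emits a trailing EOS; optionally strip it so
        # the synthetic source matches EOS-free parallel data.
        if remove_eos_at_src:
            assert generated_source[-1] == eos
            generated_source = generated_source[:-1]
        return original_tgt, generated_source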
@@ -162,11 +168,7 @@ class BacktranslationDataset(FairseqDataset):
             sample. Note in this case, sample["target"] is None, and
             sample["net_input"]["src_tokens"] is really in tgt language.
         """
-        if torch.cuda.is_available():
-            s = utils.move_to_cuda(sample)
-        else:
-            s = sample
-        self.backtranslation_generator.cuda()
+        s = utils.move_to_cuda(sample) if self.cuda else sample
         input = s["net_input"]
         srclen = input["src_tokens"].size(1)
         hypos = self.backtranslation_generator.generate(
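utils.move_to_cuda moves every tensor inside a nested sample onto the GPU. A simplified sketch of what such a helper does (fairseq's actual implementation may differ in details):

    import torch

    def move_to_cuda(sample):
        # Recursively walk dicts/lists and move tensors to the GPU;
        # non-tensor values pass through unchanged.
        if torch.is_tensor(sample):
            return sample.cuda()
        if isinstance(sample, dict):
            return {k: move_to_cuda(v) for k, v in sample.items()}
        if isinstance(sample, list):
            return [move_to_cuda(v) for v in sample]
        return sample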
tests/test_noising.py
@@ -66,13 +66,15 @@ class TestDataNoising(unittest.TestCase):
         return vocab, x, torch.LongTensor(src_len)
 
     def assert_eos_at_end(self, x, x_len, eos):
-        """ Asserts last token of every sentence in x is EOS """
+        """Asserts last token of every sentence in x is EOS """
         for i in range(len(x_len)):
             self.assertEqual(
                 x[x_len[i]-1][i],
                 eos,
-                f"Expected eos (token id {eos}) at the end of sentence {i} but "
-                f"got {x[i][-1]} instead"
+                (
+                    "Expected eos (token id {eos}) at the end of sentence {i} but "
+                    "got {other} instead"
+                ).format(i=i, eos=eos, other=x[i][-1])
             )
 
     def assert_word_dropout_correct(self, x, x_noised, x_len, l_noised):
@@ -192,16 +194,18 @@ class TestDataNoising(unittest.TestCase):
         self.assert_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())
 
     def assert_no_eos_at_end(self, x, x_len, eos):
-        """ Asserts that the last token of each sentence in x is not EOS """
+        """Asserts that the last token of each sentence in x is not EOS """
         for i in range(len(x_len)):
             self.assertNotEqual(
                 x[x_len[i]-1][i],
                 eos,
-                f"Expected no eos (token id {eos}) at the end of sentence {i}."
+                "Expected no eos (token id {eos}) at the end of sentence {i}.".format(
+                    eos=eos, i=i,
+                )
             )
 
     def test_word_dropout_without_eos(self):
-        """ Same result as word dropout with eos except no EOS at end"""
+        """Same result as word dropout with eos except no EOS at end"""
         vocab, x, x_len = self._get_test_data(append_eos=False)
 
         with data_utils.numpy_seed(1234):
@@ -213,7 +217,7 @@ class TestDataNoising(unittest.TestCase):
         self.assert_no_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())
 
     def test_word_blank_without_eos(self):
-        """ Same result as word blank with eos except no EOS at end"""
+        """Same result as word blank with eos except no EOS at end"""
         vocab, x, x_len = self._get_test_data(append_eos=False)
 
         with data_utils.numpy_seed(1234):
@@ -225,7 +229,7 @@ class TestDataNoising(unittest.TestCase):
         self.assert_no_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())
 
     def test_word_shuffle_without_eos(self):
-        """ Same result as word shuffle with eos except no EOS at end """
+        """Same result as word shuffle with eos except no EOS at end"""
         vocab, x, x_len = self._get_test_data(append_eos=False)
 
         with data_utils.numpy_seed(1234):
tests/utils.py
@@ -221,7 +221,8 @@ class TestIncrementalDecoder(FairseqIncrementalDecoder):
         # random attention
         attn = torch.rand(bbsz, tgt_len, src_len)
-        return probs, attn
+        dev = prev_output_tokens.device
+        return probs.to(dev), attn.to(dev)
 
     def get_normalized_probs(self, net_output, log_probs, _):
         # the decoder returns probabilities directly
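The test-decoder fix above makes the stub return its outputs on the same device as its inputs, so the test suite also passes on GPU. The pattern in isolation (shapes here are arbitrary, for illustration only):

    import torch

    def decode_stub(prev_output_tokens):
        # Build outputs on CPU, then move them to the input's device so
        # callers never mix CPU and GPU tensors downstream.
        probs = torch.rand(2, 3, 10)  # (batch, tgt_len, vocab)
        attn = torch.rand(2, 3, 5)    # (batch, tgt_len, src_len)
        dev = prev_output_tokens.device
        return probs.to(dev), attn.to(dev)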