test_backtranslation_dataset.py
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import argparse
import unittest

import tests.utils as test_utils
import torch
from fairseq.data.backtranslation_dataset import BacktranslationDataset


class TestBacktranslationDataset(unittest.TestCase):
    def setUp(self):
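        # test_utils.sequence_generator_setup() provides a toy target
        # dictionary, two word ids (w1, w2), a dummy source batch with
        # lengths, and a small model for the backtranslation generator.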
        self.tgt_dict, self.w1, self.w2, self.src_tokens, self.src_lengths, self.model = (
            test_utils.sequence_generator_setup()
        )
        backtranslation_args = argparse.Namespace()

        """
        Same as defaults from fairseq/options.py
        """
        backtranslation_args.backtranslation_unkpen = 0
        backtranslation_args.backtranslation_sampling = False
        backtranslation_args.backtranslation_max_len_a = 0
        backtranslation_args.backtranslation_max_len_b = 200
        backtranslation_args.backtranslation_beam = 2

        self.backtranslation_args = backtranslation_args

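        # Reuse the dummy source tokens as a monolingual "target" dataset;
        # BacktranslationDataset will synthesize the source side from it.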
        dummy_src_samples = self.src_tokens

        self.tgt_dataset = test_utils.TestDataset(data=dummy_src_samples)

    def test_backtranslation_dataset(self):
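        # BacktranslationDataset wraps the monolingual target dataset; when a
        # batch is collated, it back-translates each target sentence with
        # `backtranslation_model` to generate the source side of the pair.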
        backtranslation_dataset = BacktranslationDataset(
            args=self.backtranslation_args,
            tgt_dataset=self.tgt_dataset,
            tgt_dict=self.tgt_dict,
            backtranslation_model=self.model,
        )
        dataloader = torch.utils.data.DataLoader(
            backtranslation_dataset,
            batch_size=2,
            collate_fn=backtranslation_dataset.collater,
        )
        backtranslation_batch_result = next(iter(dataloader))
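        # The collated batch follows the standard fairseq format: generated
        # source tokens under net_input/src_tokens, original sentences under
        # target.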

        eos, pad, w1, w2 = self.tgt_dict.eos(), self.tgt_dict.pad(), self.w1, self.w2

        # Samples are sorted by src_lengths (longest first) and left-padded,
        # so the ids in the batch come out as [1, 0].
        expected_src = torch.LongTensor([[w1, w2, w1, eos], [pad, pad, w1, eos]])
        expected_tgt = torch.LongTensor([[w1, w2, eos], [w1, w2, eos]])
        generated_src = backtranslation_batch_result["net_input"]["src_tokens"]
        tgt_tokens = backtranslation_batch_result["target"]

        self.assertTensorEqual(expected_src, generated_src)
        self.assertTensorEqual(expected_tgt, tgt_tokens)

    def assertTensorEqual(self, t1, t2):
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        self.assertEqual(t1.ne(t2).long().sum().item(), 0, "value mismatch")


if __name__ == "__main__":
    unittest.main()