# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import torch
import unittest

from fairseq.data import Dictionary, data_utils, noising


class TestDataNoising(unittest.TestCase):
    def _get_test_data(self, append_eos=True):
        """
        Build a small BPE vocabulary and a padded batch of two sentences.

        Args:
            append_eos: if True, an EOS token is written right after each
                sentence and is counted in that sentence's length.

        Returns:
            A tuple ``(vocab, x, x_len)``:
              - ``vocab``: the Dictionary the tokens were added to.
              - ``x``: LongTensor of shape (max_len, batch), i.e. time-major
                (T x B), padded with ``vocab.pad()`` past each sentence.
              - ``x_len``: LongTensor with the length of each sentence.
        """
        vocab = Dictionary()
        for symbol in (
            "he@@", "llo", "how", "are", "y@@",
            "ou", "n@@", "ew", "or@@", "k",
        ):
            vocab.add_symbol(symbol)

        src_tokens = [
            ["he@@", "llo", "n@@", "ew", "y@@", "or@@", "k"],
            ["how", "are", "y@@", "ou"],
        ]
        src_len = [len(tokens) for tokens in src_tokens]
        # If we have to append EOS, we include EOS in counting src length
        if append_eos:
            src_len = [length + 1 for length in src_len]

        x = torch.LongTensor(len(src_tokens), max(src_len)).fill_(vocab.pad())
        for i, sentence in enumerate(src_tokens):
            for j, token in enumerate(sentence):
                x[i][j] = vocab.index(token)
            if append_eos:
                # Place EOS by sentence length instead of reusing the inner
                # loop variable, which would be unbound or stale for an
                # empty sentence.
                x[i][len(sentence)] = vocab.eos()

        # Convert from (B x T) to time-major (T x B); the assertion helpers
        # below index tensors as x[t][b].
        x = x.transpose(1, 0)
        return vocab, x, torch.LongTensor(src_len)

    def assert_eos_at_end(self, x, x_len, eos):
        """
        Asserts the last token of every sentence in x is EOS.

        Args:
            x: time-major (T x B) token tensor.
            x_len: per-sentence lengths.
            eos: the EOS token id.
        """
        for i in range(len(x_len)):
            # Report the element actually being checked.
            last_token = x[x_len[i] - 1][i]
            self.assertEqual(
                last_token,
                eos,
                f"Expected eos (token id {eos}) at the end of sentence {i} but "
                f"got {last_token} instead"
            )

    def assert_word_dropout_correct(self, x, x_noised, x_len, l_noised):
        """Checks the seeded WordDropout result on the test data."""
        # Expect only the first word (2 bpe tokens) of the first example
        # was dropped out
        self.assertEqual(x_len[0] - 2, l_noised[0])
        for i in range(l_noised[0]):
            self.assertEqual(x_noised[i][0], x[i + 2][0])

    def test_word_dropout_with_eos(self):
        vocab, x, x_len = self._get_test_data(append_eos=True)

        with data_utils.numpy_seed(1234):
            noising_gen = noising.WordDropout(vocab)
            x_noised, l_noised = noising_gen.noising(x, x_len, 0.2)
            self.assert_word_dropout_correct(
                x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised
            )
            self.assert_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())

    def assert_word_blanking_correct(self, x, x_noised, x_len, l_noised, unk):
        """Checks the seeded WordDropout-with-blank result on the test data."""
        # Expect only the first word (2 bpe tokens) of the first example
        # was blanked out
        self.assertEqual(x_len[0], l_noised[0])
        for i in range(l_noised[0]):
            if i < 2:
                self.assertEqual(x_noised[i][0], unk)
            else:
                self.assertEqual(x_noised[i][0], x[i][0])

    def test_word_blank_with_eos(self):
        vocab, x, x_len = self._get_test_data(append_eos=True)

        with data_utils.numpy_seed(1234):
            noising_gen = noising.WordDropout(vocab)
            x_noised, l_noised = noising_gen.noising(x, x_len, 0.2, vocab.unk())
            self.assert_word_blanking_correct(
                x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised, unk=vocab.unk()
            )
            self.assert_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())

    def assert_no_shuffle_with_0_distance(self, x, x_noised, x_len, l_noised):
        """
        Applies word shuffle with 0 max_shuffle_distance and asserts that no
        shuffling happened: every token and every sentence length is unchanged.
        """
        for i in range(len(x_len)):
            for j in range(x_len[i]):
                self.assertEqual(x[j][i], x_noised[j][i])
            # Shuffling never changes lengths, so check every sentence, not
            # just the first one.
            self.assertEqual(x_len[i], l_noised[i])

    def assert_word_shuffle_with_distance_3(self, x, x_noised, x_len, l_noised):
        """
        Applies word shuffle with max_shuffle_distance = 3 and asserts that the
        shuffling result is as expected. If test data changes, update this func
        """
        # Expect the second example has the last three tokens shuffled
        # 6, 7, 8, 9 => 6, 8, 9, 7, where (8, 9) is a word
        for i in range(x_len[0]):
            self.assertEqual(x[i][0], x_noised[i][0])
        shuffle_map = {0: 0, 1: 3, 2: 1, 3: 2}
        for k, v in shuffle_map.items():
            self.assertEqual(x[k][1], x_noised[v][1])
        self.assertEqual(x_len[0], l_noised[0])
        self.assertEqual(x_len[1], l_noised[1])

    def test_word_shuffle_with_eos(self):
        vocab, x, x_len = self._get_test_data(append_eos=True)

        with data_utils.numpy_seed(1234):
            word_shuffle = noising.WordShuffle(vocab)

            x_noised, l_noised = word_shuffle.noising(x, x_len, 0)
            self.assert_no_shuffle_with_0_distance(
                x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised
            )
            self.assert_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())

            x_noised, l_noised = word_shuffle.noising(x, x_len, 3)
            self.assert_word_shuffle_with_distance_3(
                x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised
            )
            self.assert_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())

    def assert_no_eos_at_end(self, x, x_len, eos):
        """
        Asserts that the last token of each sentence in x is not EOS.

        Args:
            x: time-major (T x B) token tensor.
            x_len: per-sentence lengths.
            eos: the EOS token id.
        """
        for i in range(len(x_len)):
            self.assertNotEqual(
                x[x_len[i] - 1][i],
                eos,
                f"Expected no eos (token id {eos}) at the end of sentence {i}."
            )

    def test_word_dropout_without_eos(self):
        """ Same result as word dropout with eos except no EOS at end"""
        vocab, x, x_len = self._get_test_data(append_eos=False)

        with data_utils.numpy_seed(1234):
            noising_gen = noising.WordDropout(vocab)
            x_noised, l_noised = noising_gen.noising(x, x_len, 0.2)
            self.assert_word_dropout_correct(
                x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised
            )
            self.assert_no_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())

    def test_word_blank_without_eos(self):
        """ Same result as word blank with eos except no EOS at end"""
        vocab, x, x_len = self._get_test_data(append_eos=False)

        with data_utils.numpy_seed(1234):
            noising_gen = noising.WordDropout(vocab)
            x_noised, l_noised = noising_gen.noising(x, x_len, 0.2, vocab.unk())
            self.assert_word_blanking_correct(
                x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised, unk=vocab.unk()
            )
            self.assert_no_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())

    def test_word_shuffle_without_eos(self):
        """ Same result as word shuffle with eos except no EOS at end """
        vocab, x, x_len = self._get_test_data(append_eos=False)

        with data_utils.numpy_seed(1234):
            word_shuffle = noising.WordShuffle(vocab)

            x_noised, l_noised = word_shuffle.noising(x, x_len, 0)
            self.assert_no_shuffle_with_0_distance(
                x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised
            )
            self.assert_no_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())

            x_noised, l_noised = word_shuffle.noising(x, x_len, 3)
            self.assert_word_shuffle_with_distance_3(
                x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised
            )
            self.assert_no_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())


if __name__ == '__main__':
    # Run the TestCase classes in this module when it is executed as a script.
    unittest.main()