# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import unittest
from typing import Dict, List

import tests.utils as test_utils
import torch
from fairseq import utils
from fairseq.data import (
    AppendEosDataset,
    Dictionary,
    LanguagePairDataset,
    data_utils,
    noising,
)


class TestDataNoising(unittest.TestCase):
    def _get_test_data(self, append_eos=True, bpe=True):
        """Builds a small test vocabulary (BPE or word-level) and a padded
        batch of two sentences, returned as (vocab, x, src_len) with x in
        (T x B) layout."""
        vocab = Dictionary()
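        # fairseq's Dictionary reserves the lowest ids for its special symbols
        # (pad, eos, unk), so the symbols added below should get ids starting
        # at 4; the expected tensors in the dataset tests assume this numbering.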
        if bpe:
            vocab.add_symbol("he@@")
            vocab.add_symbol("llo")
            vocab.add_symbol("how")
            vocab.add_symbol("are")
            vocab.add_symbol("y@@")
            vocab.add_symbol("ou")
            vocab.add_symbol("n@@")
            vocab.add_symbol("ew")
            vocab.add_symbol("or@@")
            vocab.add_symbol("k")

            src_tokens = [
                ["he@@", "llo", "n@@", "ew", "y@@", "or@@", "k"],
                ["how", "are", "y@@", "ou"],
            ]
        else:
            vocab.add_symbol("hello")
            vocab.add_symbol("how")
            vocab.add_symbol("are")
            vocab.add_symbol("you")
            vocab.add_symbol("new")
            vocab.add_symbol("york")
            src_tokens = [
                ["hello", "new", "york", "you"],
                ["how", "are", "you", "new", "york"],
            ]

        src_len = [len(x) for x in src_tokens]
        # If we have to append EOS, we include EOS in counting src length
        if append_eos:
            src_len = [length + 1 for length in src_len]

        x = torch.LongTensor(len(src_tokens), max(src_len)).fill_(vocab.pad())
        for i in range(len(src_tokens)):
            for j in range(len(src_tokens[i])):
                x[i][j] = vocab.index(src_tokens[i][j])
            if append_eos:
                x[i][j + 1] = vocab.eos()

        # Transpose from (B x T) to (T x B): the noising helpers and assertions
        # below index tokens as x[time_step][batch_index]
        x = x.transpose(1, 0)
        return vocab, x, torch.LongTensor(src_len)

    def assert_eos_at_end(self, x, x_len, eos):
        """Asserts last token of every sentence in x is EOS """
        for i in range(len(x_len)):
            self.assertEqual(
                x[x_len[i] - 1][i],
                eos,
                (
                    "Expected eos (token id {eos}) at the end of sentence {i} but "
                    "got {other} instead"
                ).format(i=i, eos=eos, other=x[x_len[i] - 1][i]),
            )

    def assert_word_dropout_correct(self, x, x_noised, x_len, l_noised):
        # Expect only the first word of the first example (2 BPE tokens:
        # "he@@", "llo") to have been dropped out
        self.assertEqual(x_len[0] - 2, l_noised[0])
        for i in range(l_noised[0]):
            self.assertEqual(x_noised[i][0], x[i + 2][0])

    def test_word_dropout_with_eos(self):
        vocab, x, x_len = self._get_test_data(append_eos=True)

        with data_utils.numpy_seed(1234):
            noising_gen = noising.WordDropout(vocab)
            x_noised, l_noised = noising_gen.noising(x, x_len, 0.2)
            self.assert_word_dropout_correct(
                x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised
            )
            self.assert_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())

    def assert_word_blanking_correct(self, x, x_noised, x_len, l_noised, unk):
        # Expect only the first word of the first example (2 BPE tokens:
        # "he@@", "llo") to have been blanked out
        self.assertEqual(x_len[0], l_noised[0])
        for i in range(l_noised[0]):
            if i < 2:
                self.assertEqual(x_noised[i][0], unk)
            else:
                self.assertEqual(x_noised[i][0], x[i][0])

    def test_word_blank_with_eos(self):
        vocab, x, x_len = self._get_test_data(append_eos=True)

        with data_utils.numpy_seed(1234):
            noising_gen = noising.WordDropout(vocab)
            x_noised, l_noised = noising_gen.noising(x, x_len, 0.2, vocab.unk())
            self.assert_word_blanking_correct(
                x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised, unk=vocab.unk()
            )
            self.assert_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())

    def generate_unchanged_shuffle_map(self, length):
        """Identity shuffle map, e.g. length=3 -> {0: 0, 1: 1, 2: 2}."""
        return {i: i for i in range(length)}

    def assert_word_shuffle_matches_expected(
        self,
        x,
        x_len,
        max_shuffle_distance: int,
        vocab: Dictionary,
        expected_shuffle_maps: List[Dict[int, int]],
        expect_eos_at_end: bool,
    ):
        """
        This verifies that with a given x, x_len, max_shuffle_distance, and
        vocab, we get the expected shuffle result.

        Args:
            x: Tensor of shape (T x B) = (sequence_length, batch_size)
            x_len: Tensor of length B = batch_size
            max_shuffle_distance: arg to pass to noising
            expected_shuffle_maps: List[mapping] where mapping is a
                Dict[old_index, new_index], mapping x's elements from their
                old positions in x to their new positions in x.
            expect_eos_at_end: if True, check the output to make sure there is
                an EOS at the end.
        """
        with data_utils.numpy_seed(1234):
            word_shuffle = noising.WordShuffle(vocab)
            x_noised, l_noised = word_shuffle.noising(
                x, x_len, max_shuffle_distance=max_shuffle_distance
            )

        # For every example, we have a different expected shuffle map. We check
        # that each example is shuffled as expected according to each
        # corresponding shuffle map.
        for i in range(len(expected_shuffle_maps)):
            shuffle_map = expected_shuffle_maps[i]
            for k, v in shuffle_map.items():
                self.assertEqual(x[k][i], x_noised[v][i])

        # Shuffling should not affect the length of each example
        for pre_shuffle_length, post_shuffle_length in zip(x_len, l_noised):
            self.assertEqual(pre_shuffle_length, post_shuffle_length)
        if expect_eos_at_end:
            self.assert_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())

    def test_word_shuffle_with_eos(self):
        vocab, x, x_len = self._get_test_data(append_eos=True)

        # Assert word shuffle with max shuffle distance 0 causes input to be
        # unchanged
        self.assert_word_shuffle_matches_expected(
            x=x,
            x_len=x_len,
            max_shuffle_distance=0,
            vocab=vocab,
            expected_shuffle_maps=[
                self.generate_unchanged_shuffle_map(example_len)
                for example_len in x_len
            ],
            expect_eos_at_end=True,
        )

        # Assert word shuffle with max shuffle distance 3 matches our expected
        # shuffle order
        self.assert_word_shuffle_matches_expected(
            x=x,
            x_len=x_len,
            vocab=vocab,
            max_shuffle_distance=3,
            expected_shuffle_maps=[
                self.generate_unchanged_shuffle_map(x_len[0]),
                {0: 0, 1: 3, 2: 1, 3: 2},
            ],
            expect_eos_at_end=True,
        )

    def test_word_shuffle_with_eos_nonbpe(self):
        vocab, x, x_len = self._get_test_data(append_eos=True, bpe=False)

        # Assert word shuffle with max shuffle distance 0 causes input to be
        # unchanged
        self.assert_word_shuffle_matches_expected(
            x=x,
            x_len=x_len,
            max_shuffle_distance=0,
            vocab=vocab,
            expected_shuffle_maps=[
                self.generate_unchanged_shuffle_map(example_len)
                for example_len in x_len
            ],
            expect_eos_at_end=True,
        )

        # Assert word shuffle with max shuffle distance 3 matches our expected
        # shuffle order
        self.assert_word_shuffle_matches_expected(
            x=x,
            x_len=x_len,
            vocab=vocab,
            max_shuffle_distance=3,
            expected_shuffle_maps=[
                {0: 0, 1: 1, 2: 3, 3: 2},
                {0: 0, 1: 2, 2: 1, 3: 3, 4: 4},
            ],
            expect_eos_at_end=True,
        )

    def test_word_shuffle_without_eos(self):
        """Same result as word shuffle with eos except no EOS at end"""
        vocab, x, x_len = self._get_test_data(append_eos=False)

        # Assert word shuffle with max shuffle distance 0 causes input to be
        # unchanged
        self.assert_word_shuffle_matches_expected(
            x=x,
            x_len=x_len,
            max_shuffle_distance=0,
            vocab=vocab,
            expected_shuffle_maps=[
                self.generate_unchanged_shuffle_map(example_len)
                for example_len in x_len
            ],
            expect_eos_at_end=False,
        )

        # Assert word shuffle with max shuffle distance 3 matches our expected
        # shuffle order
        self.assert_word_shuffle_matches_expected(
            x=x,
            x_len=x_len,
            vocab=vocab,
            max_shuffle_distance=3,
            expected_shuffle_maps=[
                self.generate_unchanged_shuffle_map(x_len[0]),
                {0: 0, 1: 3, 2: 1, 3: 2},
            ],
            expect_eos_at_end=False,
        )

    def assert_no_eos_at_end(self, x, x_len, eos):
        """Asserts that the last token of each sentence in x is not EOS """
        for i in range(len(x_len)):
            self.assertNotEqual(
                x[x_len[i] - 1][i],
                eos,
                "Expected no eos (token id {eos}) at the end of sentence {i}.".format(
                    eos=eos, i=i
                ),
            )

    def test_word_dropout_without_eos(self):
        """Same result as word dropout with eos except no EOS at end"""
        vocab, x, x_len = self._get_test_data(append_eos=False)

        with data_utils.numpy_seed(1234):
            noising_gen = noising.WordDropout(vocab)
            x_noised, l_noised = noising_gen.noising(x, x_len, 0.2)
            self.assert_word_dropout_correct(
                x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised
            )
            self.assert_no_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())

    def test_word_blank_without_eos(self):
        """Same result as word blank with eos except no EOS at end"""
        vocab, x, x_len = self._get_test_data(append_eos=False)

        with data_utils.numpy_seed(1234):
            noising_gen = noising.WordDropout(vocab)
            x_noised, l_noised = noising_gen.noising(x, x_len, 0.2, vocab.unk())
            self.assert_word_blanking_correct(
                x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised, unk=vocab.unk()
            )
            self.assert_no_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())

    def _get_noising_dataset_batch(
        self, src_tokens_no_pad, src_dict, use_append_eos_dataset=False
    ):
        """
        Constructs a NoisingDataset and the corresponding
        LanguagePairDataset(NoisingDataset(src), src). If use_append_eos_dataset
        is True, the clean source dataset is wrapped in AppendEosDataset so that
        EOS is appended to it when it is used as the target. In practice we
        should use AppendEosDataset, because our models usually have a source
        without EOS but a target with EOS.
        """
        src_dataset = test_utils.TestDataset(data=src_tokens_no_pad)

        noising_dataset = noising.NoisingDataset(
            src_dataset=src_dataset,
            src_dict=src_dict,
            seed=1234,
            max_word_shuffle_distance=3,
            word_dropout_prob=0.2,
            word_blanking_prob=0.2,
            noising_class=noising.UnsupervisedMTNoising,
        )
        tgt = src_dataset
        if use_append_eos_dataset:
            tgt = AppendEosDataset(src_dataset, src_dict.eos())
        language_pair_dataset = LanguagePairDataset(
            src=noising_dataset, tgt=tgt, src_sizes=None, src_dict=src_dict
        )

        dataloader = torch.utils.data.DataLoader(
            dataset=language_pair_dataset,
            batch_size=2,
            collate_fn=language_pair_dataset.collater,
        )
        denoising_batch_result = next(iter(dataloader))
        return denoising_batch_result

    def test_noising_dataset_with_eos(self):
        src_dict, src_tokens, _ = self._get_test_data(append_eos=True)

        # Format data for src_dataset
        src_tokens = torch.t(src_tokens)
        src_tokens_no_pad = []
        for src_sentence in src_tokens:
            src_tokens_no_pad.append(
                utils.strip_pad(tensor=src_sentence, pad=src_dict.pad())
            )
        denoising_batch_result = self._get_noising_dataset_batch(
            src_tokens_no_pad=src_tokens_no_pad, src_dict=src_dict
        )

        eos, pad = src_dict.eos(), src_dict.pad()

        # Generated noisy source as source (left-padded)
        expected_src = torch.LongTensor(
            [[4, 5, 10, 11, 8, 12, 13, eos], [pad, pad, pad, 6, 8, 9, 7, eos]]
        )
        # Original clean source as target (right-padded)
        expected_tgt = torch.LongTensor(
            [[4, 5, 10, 11, 8, 12, 13, eos], [6, 7, 8, 9, eos, pad, pad, pad]]
        )
        generated_src = denoising_batch_result["net_input"]["src_tokens"]
        tgt_tokens = denoising_batch_result["target"]

        self.assertTensorEqual(expected_src, generated_src)
        self.assertTensorEqual(expected_tgt, tgt_tokens)

    def test_noising_dataset_without_eos(self):
        """
        Similar to test_noising_dataset_with_eos, except that we have to set
        use_append_eos_dataset=True so that the source dataset is wrapped in
        AppendEosDataset when used as the target in LanguagePairDataset.
        """

        src_dict, src_tokens, _ = self._get_test_data(append_eos=False)

        # Format data for src_dataset
        src_tokens = torch.t(src_tokens)
        src_tokens_no_pad = []
        for src_sentence in src_tokens:
            src_tokens_no_pad.append(
                utils.strip_pad(tensor=src_sentence, pad=src_dict.pad())
            )
        denoising_batch_result = self._get_noising_dataset_batch(
            src_tokens_no_pad=src_tokens_no_pad,
            src_dict=src_dict,
            use_append_eos_dataset=True,
        )

        eos, pad = src_dict.eos(), src_dict.pad()

        # Generated noisy source as source (left-padded)
        expected_src = torch.LongTensor(
            [[4, 5, 10, 11, 8, 12, 13], [pad, pad, pad, 6, 8, 9, 7]]
        )
        # Original clean source as target (right-padded)
        expected_tgt = torch.LongTensor(
            [[4, 5, 10, 11, 8, 12, 13, eos], [6, 7, 8, 9, eos, pad, pad, pad]]
        )

        generated_src = denoising_batch_result["net_input"]["src_tokens"]
        tgt_tokens = denoising_batch_result["target"]

        self.assertTensorEqual(expected_src, generated_src)
        self.assertTensorEqual(expected_tgt, tgt_tokens)

    def assertTensorEqual(self, t1, t2):
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        self.assertEqual(t1.ne(t2).long().sum(), 0)


if __name__ == "__main__":
    unittest.main()