# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from typing import Dict, List

import tests.utils as test_utils
import torch
from fairseq import utils
from fairseq.data import (
    Dictionary,
    LanguagePairDataset,
    TransformEosDataset,
    data_utils,
    noising,
)


class TestDataNoising(unittest.TestCase):
    def _get_test_data_with_bpe_cont_marker(self, append_eos=True):
        """
        Args:
            append_eos: if True, each input sentence in the source tokens tensor
                will have an EOS appended to the end.

        Returns:
            vocab: BPE vocab with continuation markers as suffixes to denote
                non-end of word tokens. This is the standard BPE format used in
                fairseq's preprocessing.
            x: input tensor containing numberized source tokens, with EOS at the
                end if append_eos is true
            src_lengths: source sentence lengths (including EOS when appended)
        """
        vocab = Dictionary()
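        # A fresh Dictionary already contains the special symbols (bos/pad/eos/unk),
        # so the symbols added below get consecutive ids starting at 4 ("he@@"=4 ...
        # "k"=13); the expected tensors in the dataset tests below rely on this order.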
        vocab.add_symbol("he@@")
        vocab.add_symbol("llo")
        vocab.add_symbol("how")
        vocab.add_symbol("are")
        vocab.add_symbol("y@@")
        vocab.add_symbol("ou")
        vocab.add_symbol("n@@")
        vocab.add_symbol("ew")
        vocab.add_symbol("or@@")
        vocab.add_symbol("k")

        src_tokens = [
            ["he@@", "llo", "n@@", "ew", "y@@", "or@@", "k"],
            ["how", "are", "y@@", "ou"],
        ]
        x, src_lengths = self._convert_src_tokens_to_tensor(
            vocab=vocab, src_tokens=src_tokens, append_eos=append_eos
        )
        return vocab, x, src_lengths

    def _get_test_data_with_bpe_end_marker(self, append_eos=True):
        """
        Args:
            append_eos: if True, each input sentence in the source tokens tensor
                will have an EOS appended to the end.

        Returns:
            vocab: BPE vocab with end-of-word markers as suffixes to denote
                tokens at the end of a word. This is an alternative to fairseq's
                standard preprocessing framework and is not generally supported
                within fairseq.
            x: input tensor containing numberized source tokens, with EOS at the
                end if append_eos is true
            src_lengths: source sentence lengths (including EOS when appended)
        """
        vocab = Dictionary()
        vocab.add_symbol("he")
        vocab.add_symbol("llo_EOW")
        vocab.add_symbol("how_EOW")
        vocab.add_symbol("are_EOW")
        vocab.add_symbol("y")
        vocab.add_symbol("ou_EOW")
        vocab.add_symbol("n")
        vocab.add_symbol("ew_EOW")
        vocab.add_symbol("or")
        vocab.add_symbol("k_EOW")

        src_tokens = [
            ["he", "llo_EOW", "n", "ew_EOW", "y", "or", "k_EOW"],
            ["how_EOW", "are_EOW", "y", "ou_EOW"],
        ]
        x, src_lengths = self._convert_src_tokens_to_tensor(
            vocab=vocab, src_tokens=src_tokens, append_eos=append_eos
        )
        return vocab, x, src_lengths

    def _get_test_data_with_word_vocab(self, append_eos=True):
        """
        Args:
            append_eos: if True, each input sentence in the source tokens tensor
                will have an EOS appended to the end.

        Returns:
            vocab: word vocab
            x: input tensor containing numberized source tokens, with EOS at the
                end if append_eos is true
            src_lengths: source sentence lengths (including EOS when appended)
        """
        vocab = Dictionary()
        vocab.add_symbol("hello")
        vocab.add_symbol("how")
        vocab.add_symbol("are")
        vocab.add_symbol("you")
        vocab.add_symbol("new")
        vocab.add_symbol("york")
        src_tokens = [
            ["hello", "new", "york", "you"],
            ["how", "are", "you", "new", "york"],
        ]
        x, src_lengths = self._convert_src_tokens_to_tensor(
            vocab=vocab, src_tokens=src_tokens, append_eos=append_eos
        )
        return vocab, x, src_lengths

    def _convert_src_tokens_to_tensor(
        self, vocab: Dictionary, src_tokens: List[List[str]], append_eos: bool
    ):
        src_len = [len(x) for x in src_tokens]
        # If we have to append EOS, we include EOS in counting src length
        if append_eos:
            src_len = [length + 1 for length in src_len]

        x = torch.LongTensor(len(src_tokens), max(src_len)).fill_(vocab.pad())
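        # x is built batch-major (B x T) and right-padded; it is transposed to
        # time-major (T x B) at the end, the layout the noising modules expect.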
        for i in range(len(src_tokens)):
            for j in range(len(src_tokens[i])):
                x[i][j] = vocab.index(src_tokens[i][j])
            if append_eos:
                x[i][j + 1] = vocab.eos()

        x = x.transpose(1, 0)
        return x, torch.LongTensor(src_len)

    def assert_eos_at_end(self, x, x_len, eos):
        """Asserts that the last token of every sentence in x is EOS."""
        for i in range(len(x_len)):
            self.assertEqual(
                x[x_len[i] - 1][i],
                eos,
                (
                    "Expected eos (token id {eos}) at the end of sentence {i} "
                    "but got {other} instead"
                ).format(i=i, eos=eos, other=x[x_len[i] - 1][i]),
            )

    def assert_word_dropout_correct(self, x, x_noised, x_len, l_noised):
        # Expect only the first word (2 bpe tokens) of the first example
        # was dropped out
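        # ("he@@" + "llo" form the first word under the "@@" continuation convention.)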
        self.assertEqual(x_len[0] - 2, l_noised[0])
        for i in range(l_noised[0]):
            self.assertEqual(x_noised[i][0], x[i + 2][0])

    def test_word_dropout_with_eos(self):
        vocab, x, x_len = self._get_test_data_with_bpe_cont_marker(append_eos=True)
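        # append_eos=True: every example ends in an EOS token, which word dropout
        # must leave in place (checked by assert_eos_at_end below).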

        with data_utils.numpy_seed(1234):
            noising_gen = noising.WordDropout(vocab)
            x_noised, l_noised = noising_gen.noising(x, x_len, 0.2)
            self.assert_word_dropout_correct(
                x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised
            )
            self.assert_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())

    def assert_word_blanking_correct(self, x, x_noised, x_len, l_noised, unk):
        # Expect only the first word (2 bpe tokens) of the first example
        # was blanked out
        self.assertEqual(x_len[0], l_noised[0])
        for i in range(l_noised[0]):
            if i < 2:
                self.assertEqual(x_noised[i][0], unk)
            else:
                self.assertEqual(x_noised[i][0], x[i][0])

    def test_word_blank_with_eos(self):
        vocab, x, x_len = self._get_test_data_with_bpe_cont_marker(append_eos=True)

        with data_utils.numpy_seed(1234):
            noising_gen = noising.WordDropout(vocab)
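            # Passing vocab.unk() as the blank index makes WordDropout replace dropped
            # words with <unk> instead of removing them (word blanking).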
            x_noised, l_noised = noising_gen.noising(x, x_len, 0.2, vocab.unk())
            self.assert_word_blanking_correct(
                x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised, unk=vocab.unk()
            )
            self.assert_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())

    def generate_unchanged_shuffle_map(self, length):
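        """Identity map: every position keeps its index, i.e. no shuffling expected."""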
        return {i: i for i in range(length)}

    def assert_word_shuffle_matches_expected(
        self,
        x,
        x_len,
        max_shuffle_distance: int,
        vocab: Dictionary,
        expected_shufle_maps: List[Dict[int, int]],
        expect_eos_at_end: bool,
        bpe_end_marker=None,
    ):
        """
        This verifies that with a given x, x_len, max_shuffle_distance, and
        vocab, we get the expected shuffle result.

        Args:
            x: Tensor of shape (T x B) = (sequence_length, batch_size)
            x_len: Tensor of length B = batch_size
            max_shuffle_distance: arg to pass to noising
            expected_shuffle_maps: List[mapping] where mapping is a
                Dict[old_index, new_index], mapping x's elements from their
                old positions in x to their new positions in x.
            expect_eos_at_end: if True, check the output to make sure there is
                an EOS at the end.
            bpe_end_marker: str denoting the BPE end token. If this is not None, we
                set the BPE cont token to None in the noising classes.
        """
        bpe_cont_marker = None
        if bpe_end_marker is None:
            bpe_cont_marker = "@@"

        with data_utils.numpy_seed(1234):
            word_shuffle = noising.WordShuffle(
                vocab, bpe_cont_marker=bpe_cont_marker, bpe_end_marker=bpe_end_marker
            )
            x_noised, l_noised = word_shuffle.noising(
                x, x_len, max_shuffle_distance=max_shuffle_distance
            )

        # For every example, we have a different expected shuffle map. We check
        # that each example is shuffled as expected according to each
        # corresponding shuffle map.
        for i in range(len(expected_shufle_maps)):
            shuffle_map = expected_shufle_maps[i]
            for k, v in shuffle_map.items():
                self.assertEqual(x[k][i], x_noised[v][i])

        # Shuffling should not affect the length of each example
        for pre_shuffle_length, post_shuffle_length in zip(x_len, l_noised):
            self.assertEqual(pre_shuffle_length, post_shuffle_length)
        if expect_eos_at_end:
            self.assert_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())

    def test_word_shuffle_with_eos(self):
        vocab, x, x_len = self._get_test_data_with_bpe_cont_marker(append_eos=True)

        # Assert word shuffle with max shuffle distance 0 causes input to be
        # unchanged
        self.assert_word_shuffle_matches_expected(
            x=x,
            x_len=x_len,
            max_shuffle_distance=0,
            vocab=vocab,
            expected_shufle_maps=[
                self.generate_unchanged_shuffle_map(example_len)
                for example_len in x_len
            ],
            expect_eos_at_end=True,
        )

        # Assert word shuffle with max shuffle distance 3 matches our expected
        # shuffle order
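        # (At this seed the first example is unchanged, while the second example
        # "how are y@@ ou" is reordered to "how y@@ ou are" at the word level.)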
        self.assert_word_shuffle_matches_expected(
            x=x,
            x_len=x_len,
            vocab=vocab,
            max_shuffle_distance=3,
            expected_shufle_maps=[
                self.generate_unchanged_shuffle_map(x_len[0]),
                {0: 0, 1: 3, 2: 1, 3: 2},
            ],
            expect_eos_at_end=True,
        )

    def test_word_shuffle_with_eos_nonbpe(self):
        """The purpose of this is to test shuffling logic with word vocabs"""
        vocab, x, x_len = self._get_test_data_with_word_vocab(append_eos=True)

        # Assert word shuffle with max shuffle distance 0 causes input to be
        # unchanged
        self.assert_word_shuffle_matches_expected(
            x=x,
            x_len=x_len,
            max_shuffle_distance=0,
            vocab=vocab,
            expected_shufle_maps=[
                self.generate_unchanged_shuffle_map(example_len)
                for example_len in x_len
            ],
            expect_eos_at_end=True,
        )

        # Assert word shuffle with max shuffle distance 3 matches our expected
        # shuffle order
        self.assert_word_shuffle_matches_expected(
            x=x,
            x_len=x_len,
            vocab=vocab,
            max_shuffle_distance=3,
            expected_shufle_maps=[
                {0: 0, 1: 1, 2: 3, 3: 2},
                {0: 0, 1: 2, 2: 1, 3: 3, 4: 4},
            ],
            expect_eos_at_end=True,
        )

    def test_word_shuffle_without_eos(self):
        """Same result as word shuffle with eos except no EOS at end"""
        vocab, x, x_len = self._get_test_data_with_bpe_cont_marker(append_eos=False)

        # Assert word shuffle with max shuffle distance 0 causes input to be
        # unchanged
        self.assert_word_shuffle_matches_expected(
            x=x,
            x_len=x_len,
            max_shuffle_distance=0,
            vocab=vocab,
            expected_shufle_maps=[
                self.generate_unchanged_shuffle_map(example_len)
                for example_len in x_len
            ],
            expect_eos_at_end=False,
        )

        # Assert word shuffle with max shuffle distance 3 matches our expected
        # shuffle order
        self.assert_word_shuffle_matches_expected(
            x=x,
            x_len=x_len,
            vocab=vocab,
            max_shuffle_distance=3,
            expected_shufle_maps=[
                self.generate_unchanged_shuffle_map(x_len[0]),
                {0: 0, 1: 3, 2: 1, 3: 2},
            ],
            expect_eos_at_end=False,
        )

    def test_word_shuffle_without_eos_with_bpe_end_marker(self):
        """Same result as word shuffle without eos except using BPE end token"""
        vocab, x, x_len = self._get_test_data_with_bpe_end_marker(append_eos=False)

        # Assert word shuffle with max shuffle distance 0 causes input to be
        # unchanged
        self.assert_word_shuffle_matches_expected(
            x=x,
            x_len=x_len,
            max_shuffle_distance=0,
            vocab=vocab,
            expected_shufle_maps=[
                self.generate_unchanged_shuffle_map(example_len)
                for example_len in x_len
            ],
            expect_eos_at_end=False,
            bpe_end_marker="_EOW",
        )

        # Assert word shuffle with max shuffle distance 3 matches our expected
        # shuffle order
        self.assert_word_shuffle_matches_expected(
            x=x,
            x_len=x_len,
            vocab=vocab,
            max_shuffle_distance=3,
            expected_shufle_maps=[
                self.generate_unchanged_shuffle_map(x_len[0]),
                {0: 0, 1: 3, 2: 1, 3: 2},
            ],
            expect_eos_at_end=False,
            bpe_end_marker="_EOW",
        )

    def assert_no_eos_at_end(self, x, x_len, eos):
        """Asserts that the last token of each sentence in x is not EOS."""
        for i in range(len(x_len)):
            self.assertNotEqual(
                x[x_len[i] - 1][i],
                eos,
                "Expected no eos (token id {eos}) at the end of sentence {i}.".format(
                    eos=eos, i=i
                ),
            )

    def test_word_dropout_without_eos(self):
        """Same result as word dropout with eos except no EOS at end"""
        vocab, x, x_len = self._get_test_data_with_bpe_cont_marker(append_eos=False)

        with data_utils.numpy_seed(1234):
            noising_gen = noising.WordDropout(vocab)
            x_noised, l_noised = noising_gen.noising(x, x_len, 0.2)
            self.assert_word_dropout_correct(
                x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised
            )
            self.assert_no_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())

    def test_word_blank_without_eos(self):
        """Same result as word blank with eos except no EOS at end"""
        vocab, x, x_len = self._get_test_data_with_bpe_cont_marker(append_eos=False)

        with data_utils.numpy_seed(1234):
            noising_gen = noising.WordDropout(vocab)
            x_noised, l_noised = noising_gen.noising(x, x_len, 0.2, vocab.unk())
            self.assert_word_blanking_correct(
                x=x, x_noised=x_noised, x_len=x_len, l_noised=l_noised, unk=vocab.unk()
            )
            self.assert_no_eos_at_end(x=x_noised, x_len=l_noised, eos=vocab.eos())

    def _get_noising_dataset_batch(
        self, src_tokens_no_pad, src_dict, append_eos_to_tgt=False,
    ):
        """
        Constructs a NoisingDataset and the corresponding
        ``LanguagePairDataset(NoisingDataset(src), src)``. If
        *append_eos_to_tgt* is True, wrap the source dataset in
        :class:`TransformEosDataset` to append EOS to the clean source when
        using it as the target.
        """
        src_dataset = test_utils.TestDataset(data=src_tokens_no_pad)

        noising_dataset = noising.NoisingDataset(
            src_dataset=src_dataset,
            src_dict=src_dict,
            seed=1234,
            max_word_shuffle_distance=3,
            word_dropout_prob=0.2,
            word_blanking_prob=0.2,
            noising_class=noising.UnsupervisedMTNoising,
        )
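        # The clean (un-noised) source doubles as the target for denoising.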
        tgt = src_dataset
        language_pair_dataset = LanguagePairDataset(
            src=noising_dataset, tgt=tgt, src_sizes=None, src_dict=src_dict
        )
        language_pair_dataset = TransformEosDataset(
            language_pair_dataset, src_dict.eos(),
            append_eos_to_tgt=append_eos_to_tgt,
        )

        dataloader = torch.utils.data.DataLoader(
            dataset=language_pair_dataset,
            batch_size=2,
            collate_fn=language_pair_dataset.collater,
        )
        denoising_batch_result = next(iter(dataloader))
        return denoising_batch_result

    def test_noising_dataset_with_eos(self):
        src_dict, src_tokens, _ = self._get_test_data_with_bpe_cont_marker(
            append_eos=True
        )

        # Format data for src_dataset
        src_tokens = torch.t(src_tokens)
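        # torch.t gives batch-major rows (one sentence per row); strip the padding
        # from each sentence before handing the examples to TestDataset.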
        src_tokens_no_pad = []
        for src_sentence in src_tokens:
            src_tokens_no_pad.append(
                utils.strip_pad(tensor=src_sentence, pad=src_dict.pad())
            )
        denoising_batch_result = self._get_noising_dataset_batch(
            src_tokens_no_pad=src_tokens_no_pad, src_dict=src_dict
        )

        eos, pad = src_dict.eos(), src_dict.pad()

        # Generated noisy source as source
        expected_src = torch.LongTensor(
            [[4, 5, 10, 11, 8, 12, 13, eos], [pad, pad, pad, 6, 8, 9, 7, eos]]
        )
        # Original clean source as target (right-padded)
        expected_tgt = torch.LongTensor(
            [[4, 5, 10, 11, 8, 12, 13, eos], [6, 7, 8, 9, eos, pad, pad, pad]]
        )
        generated_src = denoising_batch_result["net_input"]["src_tokens"]
        tgt_tokens = denoising_batch_result["target"]

        self.assertTensorEqual(expected_src, generated_src)
        self.assertTensorEqual(expected_tgt, tgt_tokens)

    def test_noising_dataset_without_eos(self):
        """
        Similar to test noising dataset with eos except that we have to set
483
        *append_eos_to_tgt* to ``True``.
484
485
        """

        src_dict, src_tokens, _ = self._get_test_data_with_bpe_cont_marker(
            append_eos=False
        )

        # Format data for src_dataset
        src_tokens = torch.t(src_tokens)
        src_tokens_no_pad = []
        for src_sentence in src_tokens:
            src_tokens_no_pad.append(
                utils.strip_pad(tensor=src_sentence, pad=src_dict.pad())
            )
        denoising_batch_result = self._get_noising_dataset_batch(
            src_tokens_no_pad=src_tokens_no_pad,
            src_dict=src_dict,
            append_eos_to_tgt=True,
        )

        eos, pad = src_dict.eos(), src_dict.pad()

        # Generated noisy source as source
        expected_src = torch.LongTensor(
            [[4, 5, 10, 11, 8, 12, 13], [pad, pad, pad, 6, 8, 9, 7]]
        )
        # Original clean source as target (right-padded)
        expected_tgt = torch.LongTensor(
            [[4, 5, 10, 11, 8, 12, 13, eos], [6, 7, 8, 9, eos, pad, pad, pad]]
        )

        generated_src = denoising_batch_result["net_input"]["src_tokens"]
        tgt_tokens = denoising_batch_result["target"]

        self.assertTensorEqual(expected_src, generated_src)
        self.assertTensorEqual(expected_tgt, tgt_tokens)

    def assertTensorEqual(self, t1, t2):
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        self.assertEqual(t1.ne(t2).long().sum(), 0)


if __name__ == "__main__":
    unittest.main()