# test_sequence_scorer.py
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import argparse
import unittest

import torch

from fairseq.sequence_scorer import SequenceScorer

import tests.utils as test_utils


class TestSequenceScorer(unittest.TestCase):
    """Check the per-token scores that ``SequenceScorer`` reports for known targets."""

    def test_sequence_scorer(self):
        """Score three reference targets against fixed per-step probabilities."""
        # construct dummy dictionary
        d = test_utils.dummy_dictionary(vocab_size=2)
        self.assertEqual(d.pad(), 1)
        self.assertEqual(d.eos(), 2)
        self.assertEqual(d.unk(), 3)
        eos = d.eos()
        w1 = 4
        w2 = 5

        # construct dataloader
        data = [
            {
                'source': torch.LongTensor([w1, w2, eos]),
                'target': torch.LongTensor([w1, w2, w1, eos]),
            },
            {
                'source': torch.LongTensor([w2, eos]),
                'target': torch.LongTensor([w2, w1, eos]),
            },
            {
                'source': torch.LongTensor([w2, eos]),
                'target': torch.LongTensor([w2, eos]),
            },
        ]
        data_itr = test_utils.dummy_dataloader(data)

        # specify expected output probabilities
        args = argparse.Namespace()
        unk = 0.
        args.beam_probs = [
            # step 0:
            torch.FloatTensor([
                # eos  unk  w1   w2
                [0.0, unk, 0.6, 0.4],  # sentence 1
                [0.0, unk, 0.4, 0.6],  # sentence 2
                [0.0, unk, 0.7, 0.3],  # sentence 3
            ]),
            # step 1:
            torch.FloatTensor([
                # eos  unk  w1   w2
                [0.0, unk, 0.2, 0.7],  # sentence 1
                [0.0, unk, 0.8, 0.2],  # sentence 2
                [0.7, unk, 0.1, 0.2],  # sentence 3
            ]),
            # step 2:
            torch.FloatTensor([
                # eos   unk  w1    w2
                [0.10, unk, 0.50, 0.4],  # sentence 1
                [0.15, unk, 0.15, 0.7],  # sentence 2
                [0.00, unk, 0.00, 0.0],  # sentence 3
            ]),
            # step 3:
            torch.FloatTensor([
                # eos  unk  w1    w2
                [0.9, unk, 0.05, 0.05],  # sentence 1
                [0.0, unk, 0.00, 0.0],  # sentence 2
                [0.0, unk, 0.00, 0.0],  # sentence 3
            ]),
        ]
        # Probability of each target token at its own step, read off the
        # tables above (e.g. sentence 1 is w1, w2, w1, eos).
        expected_scores = [
            [0.6, 0.7, 0.5, 0.9],  # sentence 1
            [0.6, 0.8, 0.15],  # sentence 2
            [0.3, 0.7],  # sentence 3
        ]

        task = test_utils.TestTranslationTask.setup_task(args, d, d)
        model = task.build_model(args)
        scorer = SequenceScorer([model], task.target_dictionary)

        num_scored = 0
        for sample_id, _src, _ref, hypos in scorer.score_batched_itr(data_itr):
            self.assertHypoTokens(hypos[0], data[sample_id]['target'])
            self.assertHypoScore(hypos[0], expected_scores[sample_id])
            num_scored += 1
        # All assertions live inside the loop, so an empty iterator would
        # otherwise let the test pass without checking anything.
        self.assertGreater(num_scored, 0, "scorer produced no hypotheses")

    def assertHypoTokens(self, hypo, tokens):
        """Assert that ``hypo['tokens']`` matches ``tokens`` element-wise."""
        self.assertTensorEqual(hypo['tokens'], torch.LongTensor(tokens))

    def assertHypoScore(self, hypo, pos_probs, normalized=True, lenpen=1.):
        """Assert that a hypothesis carries the scores implied by ``pos_probs``.

        ``pos_probs`` holds one probability per token. Their logs must match
        ``hypo['positional_scores']``, and their sum must match
        ``hypo['score']``; when ``normalized`` is true the sum is first
        divided by ``length ** lenpen``.
        """
        pos_scores = torch.FloatTensor(pos_probs).log()
        self.assertAlmostEqual(hypo['positional_scores'], pos_scores)
        self.assertEqual(pos_scores.numel(), hypo['tokens'].numel())
        score = pos_scores.sum()
        if normalized:
            score = score / pos_scores.numel()**lenpen
        self.assertLess(abs(score - hypo['score']), 1e-6)

    def assertAlmostEqual(self, t1, t2, *args, **kwargs):
        """Assert that two tensors have the same size and differ by < 1e-4.

        This overrides ``unittest.TestCase.assertAlmostEqual``. Arguments
        that do not expose ``size`` (plain numbers, for instance) are passed
        on to the stock implementation together with any extra arguments, so
        the inherited contract keeps working on this class.
        """
        if not (hasattr(t1, 'size') and hasattr(t2, 'size')):
            return super().assertAlmostEqual(t1, t2, *args, **kwargs)
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        # max() is undefined on an empty tensor; equal sizes are enough then.
        if t1.numel() > 0:
            self.assertLess((t1 - t2).abs().max(), 1e-4)

    def assertTensorEqual(self, t1, t2):
        """Assert that two tensors have the same size and identical elements."""
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        # Convert the mismatch count to int so a failure prints a plain number.
        self.assertEqual(int(t1.ne(t2).long().sum()), 0)


# Run the tests in this module when the file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()