# test_utils.py
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#

import unittest

import torch
from torch.autograd import Variable

from fairseq import utils


class TestUtils(unittest.TestCase):
    """Unit tests for helper functions in ``fairseq.utils``."""

    def test_convert_padding_direction(self):
        """Round-trip a batch between left-padded and right-padded layouts.

        ``left_pad`` and ``right_pad`` hold the same three sequences
        (non-pad lengths 5, 4 and 2) padded with ``pad`` on opposite sides;
        converting one with ``utils.convert_padding_direction`` must yield
        the other.
        """
        pad = 1
        left_pad = torch.LongTensor([
            [2, 3, 4, 5, 6],
            [1, 7, 8, 9, 10],
            [1, 1, 1, 11, 12],
        ])
        right_pad = torch.LongTensor([
            [2, 3, 4, 5, 6],
            [7, 8, 9, 10, 1],
            [11, 12, 1, 1, 1],
        ])
        lengths = torch.LongTensor([5, 4, 2])

        # Left-padded input -> right-padded output.
        self.assertAlmostEqual(
            right_pad,
            utils.convert_padding_direction(
                left_pad,
                lengths,
                pad,
                left_to_right=True,
            ),
        )
        # Right-padded input -> left-padded output.
        self.assertAlmostEqual(
            left_pad,
            utils.convert_padding_direction(
                right_pad,
                lengths,
                pad,
                right_to_left=True,
            ),
        )

    def test_make_variable(self):
        """Check that ``utils.make_variable`` wraps tensors nested in containers.

        The tensor inside a list-of-dict sample must come back as a
        ``Variable`` on the CPU by default; with ``cuda=True`` it must be on
        the GPU exactly when CUDA is available on this machine.
        """
        t = [{'k': torch.rand(5, 5)}]

        v = utils.make_variable(t)[0]['k']
        self.assertIsInstance(v, Variable)
        self.assertFalse(v.data.is_cuda)

        v = utils.make_variable(t, cuda=True)[0]['k']
        self.assertEqual(v.data.is_cuda, torch.cuda.is_available())

    def assertAlmostEqual(self, t1, t2):
        """Assert that two tensors have equal size and near-equal values.

        Fails when the sizes differ or when the largest absolute
        element-wise difference is not below ``1e-4``.

        NOTE: this intentionally shadows ``unittest.TestCase.assertAlmostEqual``
        with a tensor-specific, two-argument signature.
        """
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        # utils.item presumably turns the one-element result of max() into a
        # plain Python number -- confirm against fairseq.utils.
        self.assertLess(utils.item((t1 - t2).abs().max()), 1e-4)


# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()