test_label_smoothing.py 1.18 KB
Newer Older
Sergey Edunov's avatar
Sergey Edunov committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#

import torch
import unittest
from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedCrossEntropy
from torch.autograd import Variable, gradcheck


# gradcheck's default tolerances are designed for double-precision inputs, so
# make newly created float tensors 64-bit for the gradient checks below.
# NOTE: this is a process-wide default, so it also affects any other test
# module imported into the same interpreter.
torch.set_default_tensor_type('torch.DoubleTensor')


class TestLabelSmoothing(unittest.TestCase):
    """Gradient checks for the ``LabelSmoothedCrossEntropy`` criterion."""

    def test_label_smoothing(self):
        """Compare analytical and numerical gradients of the criterion.

        ``gradcheck`` is run for three argument combinations: with index 2
        passed in the fourth position and no weights, with a per-class weight
        tensor (class 2 zeroed) and ``None`` in the fourth position, and with
        ``None`` for both.
        """
        # Seed the RNG so random inputs, and thus any gradcheck failure,
        # are reproducible from run to run.
        torch.manual_seed(0)

        # 3 rows of scores over 5 classes; gradcheck perturbs inputs that
        # require grad.
        logits = Variable(torch.randn(3, 5), requires_grad=True)
        # Random class indices in [0, 3]: rand() * 4 truncated by .long().
        target = Variable((torch.rand(3) * 4).long())
        criterion = LabelSmoothedCrossEntropy()

        # Per-class weight vector with class 2 zeroed out.
        weights = torch.ones(5)
        weights[2] = 0

        def check(padding_idx, class_weights):
            # Run gradcheck on the criterion with a fixed 0.1 third argument.
            # NOTE(review): the trailing positional arguments appear to be
            # (smoothing eps, padding index, class weights) -- confirm
            # against LabelSmoothedCrossEntropy.
            return gradcheck(
                lambda x, y: criterion.apply(x, y, 0.1, padding_idx, class_weights),
                (logits, target),
            )

        self.assertTrue(check(2, None))
        self.assertTrue(check(None, weights))
        self.assertTrue(check(None, None))


if __name__ == '__main__':
    # Script entry point: discover and run the TestCase classes defined in
    # this module (module='__main__' is unittest.main's default, spelled out).
    unittest.main(module='__main__')