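"""Tests for the automatic type-casting behavior installed by apex.amp.

Each test runs a layer in module and/or functional form on fp16 and fp32
inputs and checks that amp casts outputs (and gradients) to the expected
tensor types.
"""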
import unittest

import functools as ft
import itertools as it

from apex import amp
import torch
from torch import nn
import torch.nn.functional as F

from utils import common_init, HALF, FLOAT,\
    ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT

def run_layer_test(test_case, fns, expected, input_shape, test_backward=True):
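    """Call each fn on a fresh input of every dtype in `expected` and check
    that the output type matches; optionally also check the gradient type."""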
    for fn, typ in it.product(fns, expected.keys()):
        x = torch.randn(input_shape, dtype=typ).requires_grad_()
        y = fn(x)
        test_case.assertEqual(y.type(), expected[typ])
        if test_backward:
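            # Reduce in fp32 so .backward() starts from a well-defined scalar
            # loss, then confirm the gradient dtype matches the input dtype.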
            y.float().sum().backward()
            test_case.assertEqual(x.grad.type(), MATCH_INPUT[typ])

class TestBasicCasts(unittest.TestCase):
    def setUp(self):
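        # amp.init() patches torch functions so casts are inserted
        # automatically; keep the handle so tearDown can undo the patching.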
        self.handle = amp.init(enabled=True)
        common_init(self)

    def tearDown(self):
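        # Undo amp's patching so subsequent tests start from a clean state.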
        self.handle._deactivate()

    def test_linear_is_half(self):
        m = nn.Linear(self.h, self.h)
        f = ft.partial(F.linear, weight=m.weight, bias=m.bias)
        run_layer_test(self, [m, f], ALWAYS_HALF, (self.b, self.h))

    def test_conv2d_is_half(self):
        m = nn.Conv2d(self.c, self.c, self.k)
        f = ft.partial(F.conv2d, weight=m.weight, bias=m.bias)
        run_layer_test(self, [m, f], ALWAYS_HALF, (self.b, self.c, self.h, self.h))

    def test_softmax_is_float(self):
        m = nn.Softmax(dim=1)
        f = ft.partial(F.softmax, dim=1)
        run_layer_test(self, [m, f], ALWAYS_FLOAT, (self.b, self.h))

    def test_group_norm_is_float(self):
        m = nn.GroupNorm(num_groups=4, num_channels=self.c)
        run_layer_test(self, [m], ALWAYS_FLOAT, (self.b, self.c, self.h, self.h))

    def test_mse_loss_is_float(self):
        shape = (self.b, self.h)
        target = torch.randn(shape)
        mod = nn.MSELoss()
        m = lambda x: mod(x, target)
        f = ft.partial(F.mse_loss, target=target)
        run_layer_test(self, [m, f], ALWAYS_FLOAT, shape)

    def test_relu_is_match(self):
        run_layer_test(self, [nn.ReLU(), F.relu], MATCH_INPUT, (self.b, self.h))

    def test_batch_norm_is_match(self):
        m = nn.BatchNorm2d(num_features=self.c)
        # Training-mode BN is exercised through the module only; the
        # functional form is tested forward-only in eval mode below.
        run_layer_test(self, [m], MATCH_INPUT, (self.b, self.c, self.h, self.h))

        # Test forward-only for BN inference
        m.eval()
        f = ft.partial(F.batch_norm, running_mean=m.running_mean, running_var=m.running_var,
                       weight=m.weight, bias=m.bias, training=False)
        run_layer_test(self, [m, f], MATCH_INPUT, (self.b, self.c, self.h, self.h),
                       test_backward=False)

class TestBannedMethods(unittest.TestCase):
    def setUp(self):
        self.handle = amp.init(enabled=True)
        common_init(self)

    def tearDown(self):
        self.handle._deactivate()

    def bce_common(self, assertion):
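        """Apply `assertion` to BCE (module and functional form) called on a
        half-precision input; amp treats BCE as a banned fp16 op."""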
        shape = (self.b, self.h)
        target = torch.rand(shape)
        mod = nn.BCELoss()
        m = lambda x: mod(x, target)
        f = ft.partial(F.binary_cross_entropy, target=target)
        for fn in [m, f]:
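            # BCE expects probabilities in [0, 1], hence rand rather than randn.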
            x = torch.rand(shape, dtype=torch.half)
            assertion(fn, x)

    def test_bce_raises_by_default(self):
        assertion = lambda fn, x: self.assertRaises(NotImplementedError, fn, x)
        self.bce_common(assertion)

    def test_bce_is_float_with_allow_banned(self):
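        # Re-init amp with allow_banned=True: BCE should then run in fp32
        # instead of raising.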
        self.handle._deactivate()
        self.handle = amp.init(enabled=True, allow_banned=True)
        assertion = lambda fn, x: self.assertEqual(fn(x).type(), FLOAT)
        self.bce_common(assertion)

class TestTensorCasts(unittest.TestCase):
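    """amp also patches Tensor methods and operators (matmul, @, pow, **,
    cpu, sum); these tests check their cast behavior."""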
    def setUp(self):
        self.handle = amp.init(enabled=True)
        common_init(self)

    def tearDown(self):
        self.handle._deactivate()

    def test_matmul_method_is_half(self):
        other = torch.randn(self.h, self.h)
        lhs = lambda x: x.matmul(other)
        rhs = lambda x: other.matmul(x)
        run_layer_test(self, [lhs, rhs], ALWAYS_HALF, (self.h, self.h))

    def test_matmul_op_is_half(self):
        other = torch.randn(self.h, self.h)
        lhs = lambda x: x @ other
        rhs = lambda x: other @ x
        run_layer_test(self, [lhs, rhs], ALWAYS_HALF, (self.h, self.h))

    def test_pow_method_is_float(self):
        fn = lambda x: x.pow(2.)
        run_layer_test(self, [fn], ALWAYS_FLOAT, (self.b, self.h))

    def test_pow_op_is_float(self):
        fn = lambda x: x ** 2.
        run_layer_test(self, [fn], ALWAYS_FLOAT, (self.b, self.h))

    def test_cpu_is_float(self):
        fn = lambda x: x.cpu()
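        # .cpu() should always produce an fp32 CPU tensor, even from a half
        # input.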
        always_cpu_float = {torch.float: 'torch.FloatTensor',
                            torch.half: 'torch.FloatTensor'}
        run_layer_test(self, [fn], always_cpu_float, (self.b, self.h))

    def test_sum_is_float(self):
        fn = lambda x: x.sum()
        run_layer_test(self, [fn], ALWAYS_FLOAT, (self.b, self.h))

    # TODO: maybe more tests on disabled casting?

if __name__ == '__main__':
    unittest.main()