import unittest

import functools as ft
import itertools as it

from apex import amp
import torch
from torch import nn
import torch.nn.functional as F

# Expected-type tables and common test fixtures shared by the amp test suite.
# Parenthesized continuation instead of a trailing backslash (PEP 8).
from utils import (common_init, HALF, FLOAT,
                   ALWAYS_HALF, ALWAYS_BFLOAT16, ALWAYS_FLOAT, MATCH_INPUT)

from apex.testing.common_utils import skipIfRocm
def run_layer_test(test_case, fns, expected, input_shape, test_backward=True):
    """Run each callable in `fns` over inputs of every dtype in `expected`.

    For each (fn, dtype) pair, builds a random input of shape `input_shape`,
    applies the callable, and checks that the output tensor type equals
    `expected[dtype]`.  When `test_backward` is set, also backprops a scalar
    through the output and checks the input gradient type against
    MATCH_INPUT[dtype].
    """
    for layer_fn in fns:
        for in_dtype in expected:
            inp = torch.randn(input_shape, dtype=in_dtype).requires_grad_()
            out = layer_fn(inp)
            test_case.assertEqual(out.type(), expected[in_dtype])
            if test_backward:
                # Reduce to a scalar in fp32 before backward so the loss is
                # well-defined for low-precision outputs.
                out.float().sum().backward()
                test_case.assertEqual(inp.grad.type(), MATCH_INPUT[in_dtype])
class _TestBasicCasts(unittest.TestCase):
    """Layer-level casting checks shared by the half/bfloat16 subclasses.

    Subclasses provide the amp handle in setUp and pick the expected
    output-type table (ALWAYS_HALF / ALWAYS_BFLOAT16 / ALWAYS_FLOAT /
    MATCH_INPUT) per test.  Sizes self.b / self.c / self.h / self.k come
    from common_init.
    """

    def _test_linear(self, expected):
        module = nn.Linear(self.h, self.h)
        functional = ft.partial(F.linear, weight=module.weight, bias=module.bias)
        run_layer_test(self, [module, functional], expected, (self.b, self.h))

    def _test_conv2d(self, expected):
        module = nn.Conv2d(self.c, self.c, self.k)
        functional = ft.partial(F.conv2d, weight=module.weight, bias=module.bias)
        run_layer_test(self, [module, functional], expected,
                       (self.b, self.c, self.h, self.h))

    def _test_softmax(self, expected):
        module = nn.Softmax(dim=1)
        functional = ft.partial(F.softmax, dim=1)
        run_layer_test(self, [module, functional], expected, (self.b, self.h))

    def _test_group_norm(self, expected):
        module = nn.GroupNorm(num_groups=4, num_channels=self.c)
        run_layer_test(self, [module], expected, (self.b, self.c, self.h, self.h))

    def _test_mse_loss(self, expected):
        shape = (self.b, self.h)
        target = torch.randn(shape)
        criterion = nn.MSELoss()

        def module_call(x):
            return criterion(x, target)

        functional = ft.partial(F.mse_loss, target=target)
        # NOTE(review): only the module form is exercised here; `functional`
        # is constructed but never passed to run_layer_test — confirm intent.
        run_layer_test(self, [module_call], expected, shape)

    def _test_relu(self, expected):
        run_layer_test(self, [nn.ReLU(), F.relu], expected, (self.b, self.h))

    def _test_batch_norm(self, expected):
        module = nn.BatchNorm2d(num_features=self.c)
        functional = ft.partial(
            F.batch_norm, running_mean=module.running_mean,
            running_var=module.running_var, weight=module.weight,
            bias=module.bias, training=True)
        # Training mode: only the module form is run through the layer test.
        run_layer_test(self, [module], expected, (self.b, self.c, self.h, self.h))

        # Test forward-only for BN inference.
        module.eval()
        functional = ft.partial(
            F.batch_norm, running_mean=module.running_mean,
            running_var=module.running_var, weight=module.weight,
            bias=module.bias, training=False)
        run_layer_test(self, [module, functional], expected,
                       (self.b, self.c, self.h, self.h), test_backward=False)
class TestBasicCastsHalf(_TestBasicCasts):
    """Layer casting behavior with amp patching tensors to torch.half."""

    def setUp(self):
        self.handle = amp.init(enabled=True, patch_type=torch.half)
        common_init(self)

    def tearDown(self):
        # Unpatch so later test classes start from a clean state.
        self.handle._deactivate()

    def test_linear_is_half(self):
        self._test_linear(ALWAYS_HALF)

    def test_conv2d_is_half(self):
        self._test_conv2d(ALWAYS_HALF)

    def test_softmax_is_float(self):
        self._test_softmax(ALWAYS_FLOAT)

    def test_group_norm_is_float(self):
        self._test_group_norm(ALWAYS_FLOAT)

    def test_mse_loss_is_float(self):
        self._test_mse_loss(ALWAYS_FLOAT)

    def test_relu_is_match(self):
        self._test_relu(MATCH_INPUT)

    def test_batch_norm_is_match(self):
        self._test_batch_norm(MATCH_INPUT)

class TestBasicCastsBFloat16(_TestBasicCasts):
    """Layer casting behavior with amp patching tensors to torch.bfloat16."""

    def setUp(self):
        self.handle = amp.init(enabled=True, patch_type=torch.bfloat16)
        common_init(self)

    def tearDown(self):
        # Unpatch so later test classes start from a clean state.
        self.handle._deactivate()

    @skipIfRocm
    def test_linear_is_bfloat16(self):
        self._test_linear(ALWAYS_BFLOAT16)

    @skipIfRocm
    def test_conv2d_is_bfloat16(self):
        self._test_conv2d(ALWAYS_BFLOAT16)

    def test_softmax_is_float(self):
        self._test_softmax(ALWAYS_FLOAT)

    def test_group_norm_is_float(self):
        self._test_group_norm(ALWAYS_FLOAT)

    def test_mse_loss_is_float(self):
        self._test_mse_loss(ALWAYS_FLOAT)

    def test_relu_is_match(self):
        self._test_relu(MATCH_INPUT)

    def test_batch_norm_is_match(self):
        self._test_batch_norm(MATCH_INPUT)

class TestBannedMethods(unittest.TestCase):
    def setUp(self):
131
        self.handle = amp.init(enabled=True, patch_type=torch.half)
Carl Case's avatar
Carl Case committed
132
133
134
135
136
        common_init(self)

    def tearDown(self):
        self.handle._deactivate()

137
    def bce_common(self, assertion, dtype=torch.half):
Carl Case's avatar
Carl Case committed
138
        shape = (self.b, self.h)
Carl Case's avatar
Carl Case committed
139
        target = torch.rand(shape)
Carl Case's avatar
Carl Case committed
140
141
142
143
        mod = nn.BCELoss()
        m = lambda x: mod(x, target)
        f = ft.partial(F.binary_cross_entropy, target=target)
        for fn in [m, f]:
144
            x = torch.rand(shape, dtype=dtype)
Carl Case's avatar
Carl Case committed
145
146
147
148
            assertion(fn, x)

    def test_bce_raises_by_default(self):
        assertion = lambda fn, x: self.assertRaises(NotImplementedError, fn, x)
149
150
151
152
153
154
        self.bce_common(assertion, dtype=torch.half)

        # handle with bfloat16 as patch_type
        self.handle._deactivate()
        self.handle = amp.init(enabled=True, patch_type=torch.bfloat16)
        self.bce_common(assertion, dtype=torch.bfloat16)
Carl Case's avatar
Carl Case committed
155
156
157

    def test_bce_is_float_with_allow_banned(self):
        self.handle._deactivate()
158
        self.handle = amp.init(enabled=True, allow_banned=True, patch_type=torch.half)
Carl Case's avatar
Carl Case committed
159
        assertion = lambda fn, x: self.assertEqual(fn(x).type(), FLOAT)
160
        self.bce_common(assertion, dtype=torch.half)
Carl Case's avatar
Carl Case committed
161

162
        # handle with bfloat16 as patch_type
Carl Case's avatar
Carl Case committed
163
        self.handle._deactivate()
164
165
        self.handle = amp.init(enabled=True, allow_banned=True, patch_type=torch.bfloat16)
        self.bce_common(assertion, dtype=torch.bfloat16)
Carl Case's avatar
Carl Case committed
166

167
168
class _TestTensorCasts(unittest.TestCase):
    """Tensor-method/operator casting checks shared by the half/bfloat16
    subclasses; each helper feeds callables into run_layer_test."""

    def _test_matmul_method(self, expected):
        other = torch.randn(self.h, self.h)

        def left(x):
            return x.matmul(other)

        def right(x):
            return other.matmul(x)

        run_layer_test(self, [left, right], expected, (self.h, self.h))

    def _test_matmul_op(self, expected):
        other = torch.randn(self.h, self.h)

        def left(x):
            return x @ other

        def right(x):
            return other @ x

        run_layer_test(self, [left, right], expected, (self.h, self.h))

    def _test_pow_method(self, expected):
        run_layer_test(self, [lambda x: x.pow(2.)], expected, (self.b, self.h))

    def _test_pow_op(self, expected):
        run_layer_test(self, [lambda x: x ** 2.], expected, (self.b, self.h))

    def _test_cpu(self, expected):
        run_layer_test(self, [lambda x: x.cpu()], expected, (self.b, self.h))

    def _test_sum(self, expected):
        run_layer_test(self, [lambda x: x.sum()], expected, (self.b, self.h))

    # TODO: maybe more tests on disabled casting?

class TestTensorCastsHalf(_TestTensorCasts):
    """Tensor-level casting with amp patching to torch.half."""

    def setUp(self):
        self.handle = amp.init(enabled=True, patch_type=torch.half)
        common_init(self)

    def tearDown(self):
        self.handle._deactivate()

    def test_matmul_method_is_half(self):
        self._test_matmul_method(ALWAYS_HALF)

    def test_matmul_op_is_half(self):
        self._test_matmul_op(ALWAYS_HALF)

    def test_pow_method_is_float(self):
        self._test_pow_method(ALWAYS_FLOAT)

    def test_pow_op_is_float(self):
        self._test_pow_op(ALWAYS_FLOAT)

    def test_cpu_is_float(self):
        # .cpu() lands in float32 whether the source was float or half.
        always_cpu_float = {torch.float: 'torch.FloatTensor',
                            torch.half: 'torch.FloatTensor'}
        self._test_cpu(always_cpu_float)

    def test_sum_is_float(self):
        self._test_sum(ALWAYS_FLOAT)

class TestTensorCastsBFloat16(_TestTensorCasts):
    """Tensor-level casting with amp patching to torch.bfloat16."""

    def setUp(self):
        self.handle = amp.init(enabled=True, patch_type=torch.bfloat16)
        common_init(self)

    def tearDown(self):
        self.handle._deactivate()

    @skipIfRocm
    def test_matmul_method_is_bfloat16(self):
        self._test_matmul_method(ALWAYS_BFLOAT16)

    @skipIfRocm
    def test_matmul_op_is_bfloat16(self):
        self._test_matmul_op(ALWAYS_BFLOAT16)

    def test_pow_method_is_float(self):
        self._test_pow_method(ALWAYS_FLOAT)

    def test_pow_op_is_float(self):
        self._test_pow_op(ALWAYS_FLOAT)

    def test_cpu_is_float(self):
        # .cpu() lands in float32 whether the source was float or bfloat16.
        always_cpu_float = {torch.float: 'torch.FloatTensor',
                            torch.bfloat16: 'torch.FloatTensor'}
        self._test_cpu(always_cpu_float)

    def test_sum_is_float(self):
        self._test_sum(ALWAYS_FLOAT)

# Allow running this file directly as a script.
if __name__ == "__main__":
    unittest.main()