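"""Tests for apex.amp's basic casting behavior.

Verifies that patched modules, functionals, and tensor methods produce
outputs of the expected dtype (half, bfloat16, or float) under amp, and
that gradients from the backward pass match the input dtype.
"""
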
import unittest

import functools as ft
import itertools as it

from apex import amp
import torch
from torch import nn
import torch.nn.functional as F

from utils import (common_init, HALF, FLOAT,
                   ALWAYS_HALF, ALWAYS_BFLOAT16, ALWAYS_FLOAT, MATCH_INPUT)

from apex.testing.common_utils import skipIfRocm

def run_layer_test(test_case, fns, expected, input_shape, test_backward=True):
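    """Run every fn on inputs of each dtype in `expected`, asserting the
    output type is `expected[dtype]` and, when test_backward is True, that
    the gradient type matches the input dtype (MATCH_INPUT[dtype])."""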
    for fn, typ in it.product(fns, expected.keys()):
        x = torch.randn(input_shape, dtype=typ).requires_grad_()
        y = fn(x)
        test_case.assertEqual(y.type(), expected[typ])
        if test_backward:
            y.float().sum().backward()
            test_case.assertEqual(x.grad.type(), MATCH_INPUT[typ])

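# Base class: each _test_* helper builds a layer and checks it against an
# expected dtype mapping. Concrete subclasses choose the amp patch_type.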
class _TestBasicCasts(unittest.TestCase):
    def _test_linear(self, expected):
        m = nn.Linear(self.h, self.h)
        f = ft.partial(F.linear, weight=m.weight, bias=m.bias)
        run_layer_test(self, [m, f], expected, (self.b, self.h))

    def _test_conv2d(self, expected):
        m = nn.Conv2d(self.c, self.c, self.k)
        f = ft.partial(F.conv2d, weight=m.weight, bias=m.bias)
        run_layer_test(self, [m, f], expected, (self.b, self.c, self.h, self.h))

    def _test_softmax(self, expected):
        m = nn.Softmax(dim=1)
        f = ft.partial(F.softmax, dim=1)
        run_layer_test(self, [m, f], expected, (self.b, self.h))

    def _test_group_norm(self, expected):
        m = nn.GroupNorm(num_groups=4, num_channels=self.c)
        run_layer_test(self, [m], expected, (self.b, self.c, self.h, self.h))

    def _test_mse_loss(self, expected):
        shape = (self.b, self.h)
        target = torch.randn(shape)
        mod = nn.MSELoss()
        m = lambda x: mod(x, target)
        f = ft.partial(F.mse_loss, target=target)
        run_layer_test(self, [m, f], expected, shape)

    def _test_relu(self, expected):
        run_layer_test(self, [nn.ReLU(), F.relu], expected, (self.b, self.h))

    def _test_batch_norm(self, expected):
        m = nn.BatchNorm2d(num_features=self.c)
        f = ft.partial(F.batch_norm, running_mean=m.running_mean, running_var=m.running_var,
                       weight=m.weight, bias=m.bias, training=True)
        run_layer_test(self, [m, f], expected, (self.b, self.c, self.h, self.h))

        # Test forward-only for BN inference
        m.eval()
        f = ft.partial(F.batch_norm, running_mean=m.running_mean, running_var=m.running_var,
                       weight=m.weight, bias=m.bias, training=False)
        run_layer_test(self, [m, f], expected, (self.b, self.c, self.h, self.h),
                       test_backward=False)

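# torch.half patching: linear/conv outputs cast to half, softmax/group
# norm/mse run in float, and relu/batch norm match their input dtype.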
class TestBasicCastsHalf(_TestBasicCasts):
    def setUp(self):
        self.handle = amp.init(enabled=True, patch_type=torch.half)
        common_init(self)

    def tearDown(self):
        self.handle._deactivate()
    
    @unittest.skip("The failing unit test is introduced by a PyTorch commit sometime in between rocm/pytorch:rocm4.3.1_ubuntu18.04_py3.6_pytorch_1.9.0 and 2021/12/01. Same error is also observed on CUDA. Please refer to https://github.com/ROCmSoftwarePlatform/apex/issues/62")
    def test_linear_is_half(self):
        self._test_linear(ALWAYS_HALF)

    @unittest.skip("The failing unit test is introduced by a PyTorch commit sometime in between rocm/pytorch:rocm4.3.1_ubuntu18.04_py3.6_pytorch_1.9.0 and 2021/12/01. Same error is also observed on CUDA. Please refer to https://github.com/ROCmSoftwarePlatform/apex/issues/62")
    def test_conv2d_is_half(self):
        self._test_conv2d(ALWAYS_HALF)

    def test_softmax_is_float(self):
        self._test_softmax(ALWAYS_FLOAT)

    def test_group_norm_is_float(self):
        self._test_group_norm(ALWAYS_FLOAT)

    def test_mse_loss_is_float(self):
        self._test_mse_loss(ALWAYS_FLOAT)

    def test_relu_is_match(self):
        self._test_relu(MATCH_INPUT)

    def test_batch_norm_is_match(self):
        self._test_batch_norm(MATCH_INPUT)

class TestBasicCastsBFloat16(_TestBasicCasts):
    def setUp(self):
        self.handle = amp.init(enabled=True, patch_type=torch.bfloat16)
        common_init(self)

    def tearDown(self):
        self.handle._deactivate()

    @skipIfRocm
    def test_linear_is_bfloat16(self):
        self._test_linear(ALWAYS_BFLOAT16)

    @skipIfRocm
    def test_conv2d_is_bfloat16(self):
        self._test_conv2d(ALWAYS_BFLOAT16)

    def test_softmax_is_float(self):
        self._test_softmax(ALWAYS_FLOAT)

    def test_group_norm_is_float(self):
        self._test_group_norm(ALWAYS_FLOAT)

    def test_mse_loss_is_float(self):
        self._test_mse_loss(ALWAYS_FLOAT)

    def test_relu_is_match(self):
        self._test_relu(MATCH_INPUT)

    def test_batch_norm_is_match(self):
        self._test_batch_norm(MATCH_INPUT)

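# binary_cross_entropy is on amp's banned list: calling it raises
# NotImplementedError by default, and it runs in fp32 only when amp.init
# is passed allow_banned=True.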
class TestBannedMethods(unittest.TestCase):
    def setUp(self):
        self.handle = amp.init(enabled=True, patch_type=torch.half)
        common_init(self)

    def tearDown(self):
        self.handle._deactivate()

    def bce_common(self, assertion, dtype=torch.half):
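        """Apply BCE through both nn.BCELoss and F.binary_cross_entropy to a
        `dtype` input and hand each callable to `assertion`."""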
        shape = (self.b, self.h)
        target = torch.rand(shape)
        mod = nn.BCELoss()
        m = lambda x: mod(x, target)
        f = ft.partial(F.binary_cross_entropy, target=target)
        for fn in [m, f]:
            x = torch.rand(shape, dtype=dtype)
            assertion(fn, x)

    def test_bce_raises_by_default(self):
        assertion = lambda fn, x: self.assertRaises(NotImplementedError, fn, x)
        self.bce_common(assertion, dtype=torch.half)

        # Re-initialize the handle with bfloat16 as the patch_type
        self.handle._deactivate()
        self.handle = amp.init(enabled=True, patch_type=torch.bfloat16)
        self.bce_common(assertion, dtype=torch.bfloat16)

    def test_bce_is_float_with_allow_banned(self):
        self.handle._deactivate()
        self.handle = amp.init(enabled=True, allow_banned=True, patch_type=torch.half)
        assertion = lambda fn, x: self.assertEqual(fn(x).type(), FLOAT)
        self.bce_common(assertion, dtype=torch.half)

        # Re-initialize the handle with bfloat16 as the patch_type
        self.handle._deactivate()
        self.handle = amp.init(enabled=True, allow_banned=True, patch_type=torch.bfloat16)
        self.bce_common(assertion, dtype=torch.bfloat16)

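# Casts for patched torch.Tensor methods and operators (matmul/@, pow/**,
# .cpu(), .sum()) rather than nn modules and functionals.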
class _TestTensorCasts(unittest.TestCase):
    def _test_matmul_method(self, expected):
        other = torch.randn(self.h, self.h)
        lhs = lambda x: x.matmul(other)
        rhs = lambda x: other.matmul(x)
        run_layer_test(self, [lhs, rhs], expected, (self.h, self.h))

    def _test_matmul_op(self, expected):
        other = torch.randn(self.h, self.h)
        lhs = lambda x: x @ other
        rhs = lambda x: other @ x
        run_layer_test(self, [lhs, rhs], expected, (self.h, self.h))

    def _test_pow_method(self, expected):
        fn = lambda x: x.pow(2.)
        run_layer_test(self, [fn], expected, (self.b, self.h))

    def _test_pow_op(self, expected):
        fn = lambda x: x ** 2.
        run_layer_test(self, [fn], expected, (self.b, self.h))

    def _test_cpu(self, expected):
        fn = lambda x: x.cpu()
        run_layer_test(self, [fn], expected, (self.b, self.h))

    def _test_sum(self, expected):
        fn = lambda x: x.sum()
        run_layer_test(self, [fn], expected, (self.b, self.h))

    # TODO: maybe more tests on disabled casting?

class TestTensorCastsHalf(_TestTensorCasts):
    def setUp(self):
        self.handle = amp.init(enabled=True, patch_type=torch.half)
        common_init(self)

    def tearDown(self):
        self.handle._deactivate()

    def test_matmul_method_is_half(self):
        self._test_matmul_method(ALWAYS_HALF)

    def test_matmul_op_is_half(self):
        self._test_matmul_op(ALWAYS_HALF)

    def test_pow_method_is_float(self):
        self._test_pow_method(ALWAYS_FLOAT)

    def test_pow_op_is_float(self):
        self._test_pow_op(ALWAYS_FLOAT)

    def test_cpu_is_float(self):
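        # .cpu() is expected to yield a FloatTensor for both float and half inputs.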
        always_cpu_float = {torch.float: 'torch.FloatTensor',
                            torch.half: 'torch.FloatTensor'}
        self._test_cpu(always_cpu_float)

    def test_sum_is_float(self):
        self._test_sum(ALWAYS_FLOAT)

class TestTensorCastsBFloat16(_TestTensorCasts):
    def setUp(self):
        self.handle = amp.init(enabled=True, patch_type=torch.bfloat16)
        common_init(self)

    def tearDown(self):
        self.handle._deactivate()

    @skipIfRocm
    def test_matmul_method_is_bfloat16(self):
        self._test_matmul_method(ALWAYS_BFLOAT16)

    @skipIfRocm
    def test_matmul_op_is_bfloat16(self):
        self._test_matmul_op(ALWAYS_BFLOAT16)

    def test_pow_method_is_float(self):
        self._test_pow_method(ALWAYS_FLOAT)

    def test_pow_op_is_float(self):
        self._test_pow_op(ALWAYS_FLOAT)

    def test_cpu_is_float(self):
        always_cpu_float = {torch.float: 'torch.FloatTensor',
                            torch.bfloat16: 'torch.FloatTensor'}
        self._test_cpu(always_cpu_float)

    def test_sum_is_float(self):
        self._test_sum(ALWAYS_FLOAT)


if __name__ == '__main__':
    unittest.main()