# test_fused_optimizer.py
1
from itertools import product
2
import random
3
import unittest
4
5

import torch
6

7
import apex
8

9

10
class TestFusedOptimizer(unittest.TestCase):
    """Shared harness that compares a fused optimizer with a reference one.

    The helpers below read the following attributes from ``self``; subclasses
    must set them before calling any helper:

    * ``ref_optim``   -- reference optimizer class.
    * ``fused_optim`` -- optimizer class under test.
    * ``options``     -- keyword arguments for the reference optimizer.
    * ``tst_options`` -- optional keyword arguments for the optimizer under
      test; ``gen_single_type_test`` falls back to ``options`` when absent.
    """

    def setUp(self, max_abs_diff=1e-3, max_rel_diff=1, iters=7):
        """Store the comparison tolerances and step count, then seed torch."""
        self.max_abs_diff = max_abs_diff
        self.max_rel_diff = max_rel_diff
        self.iters = iters
        torch.manual_seed(9876)

    def tearDown(self):
        """Nothing to clean up."""
        pass

    def gen_param_optim(self, tensors, options, tst_options=None):
        """Build matching parameter lists and one optimizer for each list.

        Args:
            tensors: iterable of tensors; every tensor is cloned twice so the
                reference and test parameters start from equal values without
                sharing storage.
            options: keyword arguments for ``self.ref_optim``.
            tst_options: keyword arguments for ``self.fused_optim``. When
                ``None`` the reference ``options`` are reused, which keeps
                older call sites that pass only ``options`` working.

        Returns:
            Tuple ``(ref_param, tst_param, ref_optim, tst_optim)``.
        """
        if tst_options is None:
            tst_options = options

        ref_param = [torch.nn.Parameter(tensor.clone()) for tensor in tensors]
        tst_param = [torch.nn.Parameter(tensor.clone()) for tensor in tensors]

        ref_optim = self.ref_optim(ref_param, **options)
        tst_optim = self.fused_optim(tst_param, **tst_options)

        return (ref_param, tst_param, ref_optim, tst_optim)

    def gen_grad(self, ref_param, tst_param):
        """Assign one fresh random gradient to each reference/test pair.

        The very same tensor object is installed as ``.grad`` on both
        parameters of a pair.
        """
        for p_ref, p_tst in zip(ref_param, tst_param):
            p_ref.grad = torch.rand_like(p_ref)
            p_tst.grad = p_ref.grad

    def gen_mixed_grad(self, ref_param, tst_param, scale=1.0):
        """Create half-precision gradients and install scaled copies.

        For every parameter pair a random half-precision tensor is generated.
        The reference parameter receives that tensor converted to float and
        divided by ``scale``; the test parameter's ``.grad`` is left
        untouched.

        Returns:
            The list of generated half-precision gradient tensors.
        """
        half_grads = []
        # zip() is kept so the number of gradients is bounded by the shorter
        # of the two parameter lists, exactly as before.
        for p_ref, _ in zip(ref_param, tst_param):
            half_grad = torch.rand_like(p_ref).half()
            half_grads.append(half_grad)
            p_ref.grad = half_grad.float() / scale
        return half_grads

    def get_max_diff(self, ref_param, tst_param):
        """Return the largest absolute and relative element-wise differences.

        The relative difference is ``(ref - tst) / ref``.

        Returns:
            Tuple ``(max_abs_diff, max_rel_diff)`` taken over all pairs; both
            values stay at ``0`` when no pair yields a larger value.
        """
        max_abs_diff = max_rel_diff = 0
        for p_ref, p_tst in zip(ref_param, tst_param):
            delta = p_ref - p_tst
            # max(current, candidate) only replaces the running value when
            # the candidate compares strictly greater, so a NaN candidate is
            # ignored just like in an explicit ``if candidate > current``.
            max_abs_diff = max(max_abs_diff, delta.abs().max().item())
            max_rel_diff = max(max_rel_diff, (delta / p_ref).abs().max().item())

        return max_abs_diff, max_rel_diff

    def gen_single_type_test(self, param_type=torch.float, device='cuda', *, skip_assert: bool = False):
        """Step both optimizers on one random tensor and compare the results.

        Args:
            param_type: dtype of the generated parameter tensor.
            device: device on which the tensor is created.
            skip_assert: when true, the method returns right after the first
                pair of optimizer steps without comparing anything.
        """
        nelem = 278011

        # Some reference and test optimizers need different option sets. If
        # no "tst_options" attribute was provided, initialise the optimizer
        # under test with the reference optimizer's options.
        if not hasattr(self, 'tst_options'):
            self.tst_options = self.options

        tensor = torch.rand(nelem, dtype=param_type, device=device)

        ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(
            [tensor], self.options, self.tst_options
        )

        for _ in range(self.iters):
            self.gen_grad(ref_param, tst_param)
            ref_optim.step()
            tst_optim.step()
            if skip_assert:
                return
            max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)
            self.assertLessEqual(max_abs_diff, self.max_abs_diff)
            self.assertLessEqual(max_rel_diff, self.max_rel_diff)
87

88
89
90

class TestFusedAdam(TestFusedOptimizer):
    """Compare ``apex.optimizers.FusedAdam`` against ``torch.optim.Adam``."""

    # Skip reason shared by every test disabled for the same upstream issue.
    _UPSTREAM_REGRESSION_REASON = (
        "Skipped the test since a regression introduced from PyTorch upstream: "
        "due to https://github.com/pytorch/pytorch/issues/80809#issuecomment-1175211598. "
        "Please also refer to https://github.com/ROCmSoftwarePlatform/apex/issues/82"
    )

    def setUp(self):
        """Install the Adam options and the optimizer classes to compare."""
        super().setUp()
        self.options = {'lr': 5e-4, 'betas': (0.9, 0.999), 'eps': 1e-08,
                        'weight_decay': 0, 'amsgrad': False}
        self.ref_optim = torch.optim.Adam
        self.fused_optim = apex.optimizers.FusedAdam

    @unittest.skip(_UPSTREAM_REGRESSION_REASON)
    def test_float(self):
        self.gen_single_type_test(param_type=torch.float)

    # NOTE(mkozuki): Current threshold values look too small for BFloat16.
    # TODO(mkozuki): Refactor `TestFusedOptimizer`
    @unittest.skip("NaN issue observed on ROCm as of 12/1/2021. The failing unit test is introduced by a PyTorch commit sometime in between rocm/pytorch:rocm4.3.1_ubuntu18.04_py3.6_pytorch_1.9.0 and 2021/12/01. Please refer to https://github.com/ROCmSoftwarePlatform/apex/issues/63")
    def test_half(self):
        self.gen_single_type_test(param_type=torch.float16, skip_assert=True)

    @unittest.skip(_UPSTREAM_REGRESSION_REASON)
    def test_bfloat16(self):
        self.gen_single_type_test(param_type=torch.bfloat16, skip_assert=True)

    @unittest.skip(_UPSTREAM_REGRESSION_REASON)
    @unittest.skipIf(torch.cuda.device_count()<2, "more than 1 GPU required")
    def test_multi_device(self):
        # Try every combination of current device and tensor device.
        devices = ("cuda:0", "cuda:1")
        for current_dev, tensor_dev in product(devices, devices):
            with torch.cuda.device(current_dev):
                self.gen_single_type_test(param_type=torch.float, device=tensor_dev)

    @unittest.skip('Disable until 8/1/2019 adam/adamw upstream picked')
    def test_multi_params(self):
        # Several parameters of different shapes handled by one optimizer.
        sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]

        tensors = [torch.rand(size, dtype=torch.float, device='cuda')
                   for size in sizes]
        ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(
            tensors, self.options
        )

        for _ in range(self.iters):
            self.gen_grad(ref_param, tst_param)
            ref_optim.step()
            tst_optim.step()
            max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)
            self.assertLessEqual(max_abs_diff, self.max_abs_diff)
            self.assertLessEqual(max_rel_diff, self.max_rel_diff)

    @unittest.skip('No longer support fuse scaling')
    def test_scale(self):
        # Half-precision gradients plus a random scale passed to step().
        nelem = 278011
        tensor = torch.rand(nelem, dtype=torch.float, device='cuda')
        ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(
            [tensor], self.options
        )

        for _ in range(self.iters):
            scale = random.random() * 1000
            half_grads = self.gen_mixed_grad(ref_param, tst_param, scale)
            ref_optim.step()
            tst_optim.step(grads=half_grads, scale=scale)
            max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)

            self.assertLessEqual(max_abs_diff, self.max_abs_diff)
            self.assertLessEqual(max_rel_diff, self.max_rel_diff)

    @unittest.skip('No longer support output fp16 param')
    def test_fp16_output(self):
        # step() is additionally given an fp16 parameter as output_params.
        nelem = 278011

        tensor = torch.rand(nelem, dtype=torch.float, device='cuda')
        ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(
            [tensor], self.options
        )

        fp16_param = torch.nn.Parameter(tensor.clone().half())

        for _ in range(self.iters):
            half_grads = self.gen_mixed_grad(ref_param, tst_param)
            ref_optim.step()
            tst_optim.step(grads=half_grads, output_params=[fp16_param])

            # Reference parameters versus test parameters.
            max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)
            self.assertLessEqual(max_abs_diff, self.max_abs_diff)
            self.assertLessEqual(max_rel_diff, self.max_rel_diff)

            # Test parameters versus the fp16 output converted to float.
            max_abs_diff, max_rel_diff = self.get_max_diff(
                tst_param, [fp16_param.float()]
            )
            self.assertLessEqual(max_abs_diff, self.max_abs_diff)
            self.assertLessEqual(max_rel_diff, self.max_rel_diff)

    @unittest.skip(_UPSTREAM_REGRESSION_REASON)
    def test_adam_option(self):
        # Single-element tensor with option values that differ from setUp().
        nelem = 1
        adam_option = {'lr': 0.01, 'betas': (0.6, 0.9), 'eps': 3e-06,
                       'weight_decay': 0, 'amsgrad': False}

        tensor = torch.rand(nelem, dtype=torch.float, device='cuda')
        ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(
            [tensor], adam_option
        )

        for _ in range(self.iters):
            self.gen_grad(ref_param, tst_param)
            ref_optim.step()
            tst_optim.step()
            max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)

            self.assertLessEqual(max_abs_diff, self.max_abs_diff)
            self.assertLessEqual(max_rel_diff, self.max_rel_diff)


199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
class TestFusedAdagrad(TestFusedOptimizer):
    """Compare ``apex.optimizers.FusedAdagrad`` against ``torch.optim.Adagrad``."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.ref_optim = torch.optim.Adagrad
        self.fused_optim = apex.optimizers.FusedAdagrad
        self.options = {"lr": 5e-4, "eps": 1e-08, "weight_decay": 1.0e-5}

    def test_float(self):
        self.gen_single_type_test(param_type=torch.float)

    @unittest.skip("PyTorch optimizer is not numerically correct for fp16")
    def test_half(self):
        self.gen_single_type_test(param_type=torch.float16)

    @unittest.skipIf(torch.cuda.device_count()<2, "more than 1 GPU required")
    def test_multi_device(self):
        # Every (current device, tensor device) combination of the two GPUs.
        gpus = ("cuda:0", "cuda:1")
        for active_gpu, data_gpu in product(gpus, gpus):
            with torch.cuda.device(active_gpu):
                self.gen_single_type_test(param_type=torch.float, device=data_gpu)

    def test_multi_params(self):
        # Several differently shaped parameters, weight decay disabled.
        shapes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
        adagrad_option = {"lr": 5e-4, "eps": 1e-08, "weight_decay": 0}

        tensors = [
            torch.rand(shape, dtype=torch.float, device="cuda") for shape in shapes
        ]
        ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(
            tensors, adagrad_option
        )

        for _ in range(self.iters):
            self.gen_grad(ref_param, tst_param)
            ref_optim.step()
            tst_optim.step()
            abs_err, rel_err = self.get_max_diff(ref_param, tst_param)
            self.assertLessEqual(abs_err, self.max_abs_diff)
            self.assertLessEqual(rel_err, self.max_rel_diff)

    @unittest.skipIf(torch.cuda.device_count()<2, "more than 1 GPU required")
    def test_multi_params_different_devices_throws(self):
        # Parameters alternate between cuda:0 and cuda:1; the fused step is
        # expected to raise.
        shapes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
        adagrad_option = {"lr": 5e-4, "eps": 1e-08, "weight_decay": 0}

        tensors = [
            torch.rand(shape, dtype=torch.float, device=f"cuda:{index % 2}")
            for index, shape in enumerate(shapes)
        ]
        ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(
            tensors, adagrad_option
        )
        self.gen_grad(ref_param, tst_param)
        with self.assertRaisesRegex(RuntimeError, "not on the same device"):
            tst_optim.step()

    def test_adagrad_option(self):
        # Single-element tensor with non-default lr / eps values.
        adagrad_option = {"lr": 0.01, "eps": 3e-06, "weight_decay": 0}
        nelem = 1

        tensor = torch.rand(nelem, dtype=torch.float, device="cuda")
        ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(
            [tensor], adagrad_option
        )

        for _ in range(self.iters):
            self.gen_grad(ref_param, tst_param)
            ref_optim.step()
            tst_optim.step()
            abs_err, rel_err = self.get_max_diff(ref_param, tst_param)

            self.assertLessEqual(abs_err, self.max_abs_diff)
            self.assertLessEqual(rel_err, self.max_rel_diff)


class TestFusedSGD(TestFusedOptimizer):
    """Compare ``apex.optimizers.FusedSGD`` against ``torch.optim.SGD``."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.ref_optim = torch.optim.SGD
        self.fused_optim = apex.optimizers.FusedSGD
        self.options = {"lr": .25, "momentum": .125}

    def test_float(self):
        self.gen_single_type_test(param_type=torch.float)

    def test_half(self):
        self.gen_single_type_test(param_type=torch.float16)

    @unittest.skipIf(torch.cuda.device_count()<2, "more than 1 GPU required")
    def test_multi_device(self):
        # Every (current device, tensor device) combination of the two GPUs.
        gpus = ("cuda:0", "cuda:1")
        for active_gpu, data_gpu in product(gpus, gpus):
            with torch.cuda.device(active_gpu):
                self.gen_single_type_test(param_type=torch.float, device=data_gpu)

294
295
# Run every test case in this module when the file is executed directly.
if __name__ == '__main__':
    unittest.main()