test_fused_optimizer.py 11.3 KB
Newer Older
1
2
3
4
5
6
import unittest
import os
import random

import torch
import apex
7
from itertools import product
8

lcskrishna's avatar
lcskrishna committed
9
10
from apex.testing.common_utils import skipIfRocm

11
class TestFusedOptimizer(unittest.TestCase):
    """Shared harness comparing a fused optimizer against a reference one.

    Subclasses are expected to set ``self.options`` (keyword arguments passed
    to both optimizers), ``self.ref_optim`` (reference optimizer class) and
    ``self.fused_optim`` (fused optimizer class under test).
    """

    def setUp(self, max_abs_diff=1e-3, max_rel_diff=1, iters=7):
        self.max_abs_diff = max_abs_diff
        self.max_rel_diff = max_rel_diff
        self.iters = iters
        # Seed every CUDA device (not only the current one) so tests that
        # place tensors on cuda:1 draw reproducible random values too.
        torch.cuda.manual_seed_all(9876)

    def tearDown(self):
        pass

    def gen_param_optim(self, tensors, options, apex_only=False):
        """Build paired reference/test parameters and their optimizers.

        Every tensor is cloned into a reference ``Parameter`` and a test
        ``Parameter``. With ``apex_only`` the reference copy is upcast to
        float32 and is driven by the fused optimizer as well, so the float32
        run acts as the gold standard for the lower-precision test copy.

        Returns:
            ``(ref_param, tst_param, ref_optim, tst_optim)``.
        """
        ref_param = []
        tst_param = []
        for tensor in tensors:
            if apex_only:
                ref_param.append(torch.nn.Parameter(tensor.clone().float()))
            else:
                ref_param.append(torch.nn.Parameter(tensor.clone()))
            tst_param.append(torch.nn.Parameter(tensor.clone()))

        if apex_only:
            ref_optim = self.fused_optim(ref_param, **options)
        else:
            ref_optim = self.ref_optim(ref_param, **options)
        tst_optim = self.fused_optim(tst_param, **options)

        return (ref_param, tst_param, ref_optim, tst_optim)

    def gen_grad(self, ref_param, tst_param, apex_only=False):
        """Give each parameter pair the same random gradient.

        Without ``apex_only`` both parameters share one gradient tensor; with
        it the reference gets a detached float32 copy of the test gradient.
        """
        for p_ref, p_tst in zip(ref_param, tst_param):
            p_tst.grad = torch.rand_like(p_tst)
            p_ref.grad = p_tst.grad.detach().float() if apex_only else p_tst.grad

    def gen_mixed_grad(self, ref_param, tst_param, scale=1.0):
        """Create fp16 gradients; give the reference their fp32 form / ``scale``.

        ``tst_param`` is only iterated in lockstep with ``ref_param`` and is not
        modified. The returned half-precision gradients are intended to be
        handed to the fused optimizer's ``step(grads=...)`` explicitly.
        """
        half_grads = []
        for p_ref, _ in zip(ref_param, tst_param):
            half_grads.append(torch.rand_like(p_ref).half())
            p_ref.grad = half_grads[-1].float() / scale
        return half_grads

    @staticmethod
    def _nan_aware_max(current, candidate):
        """Return the larger of two floats, letting NaN win over any number.

        A plain ``candidate > current`` is always False for NaN, which would
        silently discard a NaN diff and let a diverged optimizer pass.
        """
        import math

        if math.isnan(current) or math.isnan(candidate):
            return float("nan")
        return max(current, candidate)

    def get_max_diff(self, ref_param, tst_param, apex_only=False):
        """Return the largest absolute and relative differences over all pairs.

        The relative difference is ``|ref - tst| / |ref|``. Elements where the
        two values agree exactly contribute 0, so ``ref == tst == 0`` does not
        become 0/0 = NaN. Any other NaN (e.g. a NaN parameter on either side)
        is propagated to the result so the caller's ``assertLessEqual`` fails.

        With ``apex_only`` the test parameter is upcast to float32 first.
        """
        max_abs_diff = max_rel_diff = 0
        # Diffs are only inspected, so skip building an autograd graph.
        with torch.no_grad():
            for p_ref, p_tst in zip(ref_param, tst_param):
                if apex_only:
                    p_tst = p_tst.float()

                diff = (p_ref - p_tst).abs()
                rel = torch.where(diff == 0, torch.zeros_like(diff), diff / p_ref.abs())

                max_abs_diff = self._nan_aware_max(max_abs_diff, diff.max().item())
                max_rel_diff = self._nan_aware_max(max_rel_diff, rel.max().item())

        return max_abs_diff, max_rel_diff

    def gen_single_type_test(self, param_type=torch.float, apex_only=False, device='cuda',
                             nelem=278011):
        """Step one random tensor ``self.iters`` times and compare each step.

        Args:
            param_type: dtype of the generated parameter tensor.
            apex_only: compare the fused optimizer against itself on a float32
                copy; the relative-error check is skipped in this mode.
            device: device on which the parameter tensor is created.
            nelem: number of elements in the parameter tensor.
        """
        tensor = torch.rand(nelem, dtype=param_type, device=device)
        ref_param, tst_param, ref_optim, tst_optim = \
            self.gen_param_optim([tensor], self.options, apex_only=apex_only)

        for _ in range(self.iters):
            self.gen_grad(ref_param, tst_param, apex_only=apex_only)
            ref_optim.step()
            tst_optim.step()
            max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param, apex_only=apex_only)
            self.assertLessEqual(max_abs_diff, self.max_abs_diff)
            if not apex_only:
                self.assertLessEqual(max_rel_diff, self.max_rel_diff)

80
81
82
83
84
85
86
87
88
89

class TestFusedAdam(TestFusedOptimizer):
    """Compare ``apex.optimizers.FusedAdam`` against ``torch.optim.Adam``."""

    def __init__(self, *args, **kwargs):
        super(TestFusedAdam, self).__init__(*args, **kwargs)
        self.options = {'lr': 5e-4, 'betas': (0.9, 0.999), 'eps': 1e-08,
                        'weight_decay': 0, 'amsgrad': False}
        self.ref_optim = torch.optim.Adam
        self.fused_optim = apex.optimizers.FusedAdam

    def _step_and_compare(self, ref_param, tst_param, ref_optim, tst_optim):
        """Step both optimizers ``self.iters`` times on identical random grads.

        After every step the max absolute and relative parameter differences
        must stay within ``self.max_abs_diff`` and ``self.max_rel_diff``.
        """
        for _ in range(self.iters):
            self.gen_grad(ref_param, tst_param)
            ref_optim.step()
            tst_optim.step()
            max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)
            self.assertLessEqual(max_abs_diff, self.max_abs_diff)
            self.assertLessEqual(max_rel_diff, self.max_rel_diff)

    @skipIfRocm
    def test_float(self):
        self.gen_single_type_test(param_type=torch.float)

    def test_half(self):
        self.gen_single_type_test(param_type=torch.float16)

    # Compares bfloat16 computation against float32 as the gold standard.
    # Uses apex optimizers (selected by the apex_only flag) for both types
    # instead of the upstream optimizer used by the other tests, since the
    # upstream one seems to be numerically unstable for half types.
    @skipIfRocm
    def test_bfloat16(self):
        self.max_abs_diff = 1e-2
        self.gen_single_type_test(param_type=torch.bfloat16, apex_only=True)

    @skipIfRocm
    @unittest.skipIf(torch.cuda.device_count() < 2, "more than 1 GPU required")
    def test_multi_device(self):
        devices = ("cuda:0", "cuda:1")
        for current_dev, tensor_dev in product(devices, devices):
            with torch.cuda.device(current_dev):
                self.gen_single_type_test(param_type=torch.float, device=tensor_dev)

    @unittest.skip('Disable until 8/1/2019 adam/adamw upstream picked')
    def test_multi_params(self):
        sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]

        tensors = [torch.rand(size, dtype=torch.float, device='cuda') for size in sizes]
        ref_param, tst_param, ref_optim, tst_optim = \
            self.gen_param_optim(tensors, self.options)

        self._step_and_compare(ref_param, tst_param, ref_optim, tst_optim)

    @unittest.skip('No longer support fuse scaling')
    def test_scale(self):
        nelem = 278011
        tensor = torch.rand(nelem, dtype=torch.float, device='cuda')
        ref_param, tst_param, ref_optim, tst_optim = \
            self.gen_param_optim([tensor], self.options)

        for _ in range(self.iters):
            scale = random.random() * 1000
            half_grads = self.gen_mixed_grad(ref_param, tst_param, scale)
            ref_optim.step()
            tst_optim.step(grads=half_grads, scale=scale)
            max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)

            self.assertLessEqual(max_abs_diff, self.max_abs_diff)
            self.assertLessEqual(max_rel_diff, self.max_rel_diff)

    @unittest.skip('No longer support output fp16 param')
    def test_fp16_output(self):
        nelem = 278011

        tensor = torch.rand(nelem, dtype=torch.float, device='cuda')
        ref_param, tst_param, ref_optim, tst_optim = \
            self.gen_param_optim([tensor], self.options)

        fp16_param = torch.nn.Parameter(tensor.clone().half())

        for _ in range(self.iters):
            half_grads = self.gen_mixed_grad(ref_param, tst_param)
            ref_optim.step()
            tst_optim.step(grads=half_grads, output_params=[fp16_param])

            max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)
            self.assertLessEqual(max_abs_diff, self.max_abs_diff)
            self.assertLessEqual(max_rel_diff, self.max_rel_diff)

            # The fp16 output copy must also track the fp32 test parameters.
            max_abs_diff, max_rel_diff = self.get_max_diff(tst_param, [fp16_param.float()])
            self.assertLessEqual(max_abs_diff, self.max_abs_diff)
            self.assertLessEqual(max_rel_diff, self.max_rel_diff)

    def test_adam_option(self):
        nelem = 1
        adam_option = {'lr': 0.01, 'betas': (0.6, 0.9), 'eps': 3e-06,
                       'weight_decay': 0, 'amsgrad': False}

        tensor = torch.rand(nelem, dtype=torch.float, device='cuda')
        ref_param, tst_param, ref_optim, tst_optim = \
            self.gen_param_optim([tensor], adam_option)

        self._step_and_compare(ref_param, tst_param, ref_optim, tst_optim)

192
193
194
195
196
197
198
class TestFusedAdagrad(TestFusedOptimizer):
    """Compare ``apex.optimizers.FusedAdagrad`` against ``torch.optim.Adagrad``."""

    def __init__(self, *args, **kwargs):
        super(TestFusedAdagrad, self).__init__(*args, **kwargs)
        self.options = {"lr": 5e-4, "eps": 1e-08, "weight_decay": 1.0e-5}
        self.ref_optim = torch.optim.Adagrad
        self.fused_optim = apex.optimizers.FusedAdagrad

    def _step_and_compare(self, ref_param, tst_param, ref_optim, tst_optim):
        """Step both optimizers ``self.iters`` times on identical random grads.

        After every step the max absolute and relative parameter differences
        must stay within ``self.max_abs_diff`` and ``self.max_rel_diff``.
        """
        for _ in range(self.iters):
            self.gen_grad(ref_param, tst_param)
            ref_optim.step()
            tst_optim.step()
            max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)
            self.assertLessEqual(max_abs_diff, self.max_abs_diff)
            self.assertLessEqual(max_rel_diff, self.max_rel_diff)

    @skipIfRocm
    def test_float(self):
        self.gen_single_type_test(param_type=torch.float)

    @unittest.skip("PyTorch optimizer is not numerically correct for fp16")
    def test_half(self):
        self.gen_single_type_test(param_type=torch.float16)

    @skipIfRocm
    @unittest.skipIf(torch.cuda.device_count() < 2, "more than 1 GPU required")
    def test_multi_device(self):
        devices = ("cuda:0", "cuda:1")
        for current_dev, tensor_dev in product(devices, devices):
            with torch.cuda.device(current_dev):
                self.gen_single_type_test(param_type=torch.float, device=tensor_dev)

    @skipIfRocm
    def test_multi_params(self):
        sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
        adagrad_option = {"lr": 5e-4, "eps": 1e-08, "weight_decay": 0}

        tensors = [torch.rand(size, dtype=torch.float, device="cuda") for size in sizes]
        ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(
            tensors, adagrad_option
        )

        self._step_and_compare(ref_param, tst_param, ref_optim, tst_optim)

    @unittest.skipIf(torch.cuda.device_count() < 2, "more than 1 GPU required")
    def test_multi_params_different_devices_throws(self):
        sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
        adagrad_option = {"lr": 5e-4, "eps": 1e-08, "weight_decay": 0}

        # Alternate parameters between cuda:0 and cuda:1.
        tensors = [
            torch.rand(size, dtype=torch.float, device="cuda:" + str(i % 2))
            for i, size in enumerate(sizes)
        ]
        ref_param, tst_param, _, tst_optim = self.gen_param_optim(
            tensors, adagrad_option
        )
        self.gen_grad(ref_param, tst_param)
        with self.assertRaisesRegex(RuntimeError, "not on the same device"):
            tst_optim.step()

    def test_adagrad_option(self):
        nelem = 1
        adagrad_option = {"lr": 0.01, "eps": 3e-06, "weight_decay": 0}

        tensor = torch.rand(nelem, dtype=torch.float, device="cuda")
        ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(
            [tensor], adagrad_option
        )

        self._step_and_compare(ref_param, tst_param, ref_optim, tst_optim)


class TestFusedSGD(TestFusedOptimizer):
    """Check ``apex.optimizers.FusedSGD`` against ``torch.optim.SGD``."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.ref_optim = torch.optim.SGD
        self.fused_optim = apex.optimizers.FusedSGD
        self.options = {"lr": 0.25, "momentum": 0.125}

    def test_float(self):
        self.gen_single_type_test(param_type=torch.float)

    def test_half(self):
        self.gen_single_type_test(param_type=torch.float16)

    @unittest.skipIf(torch.cuda.device_count() < 2, "more than 1 GPU required")
    def test_multi_device(self):
        # Every (current device, tensor device) combination, outer loop first.
        gpus = ("cuda:0", "cuda:1")
        for active in gpus:
            for placement in gpus:
                with torch.cuda.device(active):
                    self.gen_single_type_test(param_type=torch.float, device=placement)




293
294
if "__main__" == __name__:
    # Run every TestCase defined in this module when executed as a script.
    unittest.main()