"docs/README_Windows_CUDA_Acceleration_zh_CN.md" did not exist on "c479245e7a2e9dbed0654d6c60b49a1fc6464199"
test_adam.py 7.2 KB
Newer Older
1
2
3
4
5
6
7
import unittest
import os
import random

import torch
import apex

lcskrishna's avatar
lcskrishna committed
8
9
from apex.testing.common_utils import skipIfRocm

10
11
12
13
14
15
16
17
18
19
class TestFusedAdam(unittest.TestCase):
    def setUp(self, max_abs_diff=1e-3, max_rel_diff=1, iters=7):
        self.max_abs_diff = max_abs_diff
        self.max_rel_diff = max_rel_diff
        self.iters = iters
        torch.cuda.manual_seed(9876)

    def tearDown(self):
        """Intentionally empty: the tests need no per-test cleanup."""

rohithkrn's avatar
rohithkrn committed
20
    def gen_param_optim(self, tensors, adam_option, apex_only=False):
        """Create reference/test parameter lists and an optimizer for each.

        Args:
            tensors: Source tensors; each one is cloned into a fresh
                ``torch.nn.Parameter`` for the reference and for the test side.
            adam_option: Keyword arguments forwarded to both optimizers.
            apex_only: When true, the reference parameters are float32 copies
                optimized by ``apex.optimizers.FusedAdam``; otherwise they keep
                the source dtype and are optimized by ``torch.optim.Adam``.

        Returns:
            The tuple ``(ref_param, tst_param, ref_optim, tst_optim)``.
        """
        ref_param = [
            torch.nn.Parameter(t.clone().float() if apex_only else t.clone())
            for t in tensors
        ]
        tst_param = [torch.nn.Parameter(t.clone()) for t in tensors]

        # Only the reference optimizer depends on apex_only; the optimizer
        # under test is always apex's FusedAdam.
        ref_optim_cls = apex.optimizers.FusedAdam if apex_only else torch.optim.Adam
        ref_optim = ref_optim_cls(ref_param, **adam_option)
        tst_optim = apex.optimizers.FusedAdam(tst_param, **adam_option)

        return (ref_param, tst_param, ref_optim, tst_optim)

rohithkrn's avatar
rohithkrn committed
38
    def gen_grad(self, ref_param, tst_param, apex_only=False):
        """Assign the same fresh random gradient to each reference/test pair.

        Args:
            ref_param: Reference parameters; each receives a copy of the
                gradient generated for its test counterpart.
            tst_param: Test parameters; each receives a gradient drawn with
                ``torch.rand_like``.
            apex_only: When true, the reference gradient is converted to
                float32.
        """
        for p_ref, p_tst in zip(ref_param, tst_param):
            p_tst.grad = torch.rand_like(p_tst)
            # Give the reference its own storage. Sharing one tensor object
            # between both sides would let an in-place gradient update made by
            # one optimizer leak into the other and skew the comparison.
            ref_grad = p_tst.grad.detach().clone()
            p_ref.grad = ref_grad.float() if apex_only else ref_grad
42
43
44
45
46
47
48
49

    def gen_mixed_grad(self, ref_param, tst_param, scale=1.0):
        """Generate half-precision gradients and set scaled float32 reference grads.

        Args:
            ref_param: Reference parameters; each gets ``.grad`` set to the
                float32 version of its half gradient divided by ``scale``.
            tst_param: Test parameters; only used to pair up with
                ``ref_param`` (their ``.grad`` is not touched here).
            scale: Divisor applied to the reference gradients.

        Returns:
            List of the generated float16 gradient tensors, one per pair.
        """
        half_grads = []
        for p_ref, _ in zip(ref_param, tst_param):
            grad_fp16 = torch.rand_like(p_ref).half()
            p_ref.grad = grad_fp16.float() / scale
            half_grads.append(grad_fp16)
        return half_grads

rohithkrn's avatar
rohithkrn committed
50
    def get_max_diff(self, ref_param, tst_param, apex_only=False):
        """Return the largest absolute and relative differences over all pairs.

        The relative difference is ``|ref - tst| / |ref|`` element-wise.
        Elements that are exactly equal count as a relative difference of 0,
        even when the reference element is 0. A NaN difference is propagated
        to the result rather than dropped, so that a subsequent
        ``assertLessEqual`` on it fails.

        Args:
            ref_param: Reference parameters.
            tst_param: Parameters under test, paired with ``ref_param``.
            apex_only: When true, each test parameter is converted to float32
                before it is compared.

        Returns:
            Tuple ``(max_abs_diff, max_rel_diff)``; both are 0 when there are
            no parameter pairs.
        """
        import math  # function-scope import: math is not imported at file level

        max_abs_diff = max_rel_diff = 0
        for p_ref, p_tst in zip(ref_param, tst_param):
            if apex_only:
                p_tst = p_tst.float()

            delta = p_ref - p_tst
            rel = (delta / p_ref).abs()
            # 0/0 (both sides exactly 0) would be NaN, and one NaN makes
            # .max() NaN for the whole tensor; equal elements differ by 0.
            rel = torch.where(delta == 0, torch.zeros_like(rel), rel)

            max_abs_diff_p = delta.abs().max().item()
            max_rel_diff_p = rel.max().item()

            # NaN compares false against everything, so a plain ">" would
            # silently ignore it; record it explicitly. Once a running maximum
            # is NaN it stays NaN, because later ">" comparisons are false.
            if math.isnan(max_abs_diff_p) or max_abs_diff_p > max_abs_diff:
                max_abs_diff = max_abs_diff_p
            if math.isnan(max_rel_diff_p) or max_rel_diff_p > max_rel_diff:
                max_rel_diff = max_rel_diff_p

        return max_abs_diff, max_rel_diff

rohithkrn's avatar
rohithkrn committed
63
    def gen_single_type_test(self, param_type=torch.float, apex_only=False):
        """Step both optimizers on one CUDA tensor and check they stay close.

        Args:
            param_type: dtype of the random tensor the parameters are cloned
                from.
            apex_only: Forwarded to ``gen_param_optim``, ``gen_grad`` and
                ``get_max_diff``; when true, only the absolute difference is
                asserted.
        """
        nelem = 278011
        adam_option = {
            'lr': 5e-4,
            'betas': (0.9, 0.999),
            'eps': 1e-08,
            'weight_decay': 0,
            'amsgrad': False,
        }

        source = torch.rand(nelem, dtype=param_type, device='cuda')
        ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(
            [source], adam_option, apex_only=apex_only)

        for _ in range(self.iters):
            self.gen_grad(ref_param, tst_param, apex_only=apex_only)
            ref_optim.step()
            tst_optim.step()
            abs_err, rel_err = self.get_max_diff(
                ref_param, tst_param, apex_only=apex_only)

            self.assertLessEqual(abs_err, self.max_abs_diff)
            if apex_only:
                # The relative check is skipped in apex-only mode.
                continue
            self.assertLessEqual(rel_err, self.max_rel_diff)
81

lcskrishna's avatar
lcskrishna committed
82
    @skipIfRocm
    def test_float(self):
        # float32 parameters, with torch.optim.Adam as the reference.
        self.gen_single_type_test(torch.float)

    def test_half(self):
        # float16 parameters, with torch.optim.Adam as the reference.
        self.gen_single_type_test(torch.float16)

rohithkrn's avatar
rohithkrn committed
89
90
91
92
    # Compares bfloat16 computation against float32 as the gold standard.
    # Uses apex optimizers (controlled by the apex_only flag) for both types.
    # Doesn't use the upstream optimizer like the other tests do, because it
    # seems to be numerically unstable for half types.
lcskrishna's avatar
lcskrishna committed
93
    @skipIfRocm
    def test_bfloat16(self):
        # Use a looser absolute tolerance than the setUp default for bfloat16.
        self.max_abs_diff = 1e-2
        self.gen_single_type_test(torch.bfloat16, apex_only=True)

98
    @unittest.skip('Disable until 8/1/2019 adam/adamw upstream picked')
    def test_multi_params(self):
        # Same comparison as the single-tensor tests, but over several
        # parameters of different shapes handled by one optimizer.
        sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
        adam_option = {
            'lr': 5e-4,
            'betas': (0.9, 0.999),
            'eps': 1e-08,
            'weight_decay': 0,
            'amsgrad': False,
        }

        tensors = [
            torch.rand(size, dtype=torch.float, device='cuda') for size in sizes
        ]
        ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(
            tensors, adam_option)

        for _ in range(self.iters):
            self.gen_grad(ref_param, tst_param)
            ref_optim.step()
            tst_optim.step()
            abs_err, rel_err = self.get_max_diff(ref_param, tst_param)
            self.assertLessEqual(abs_err, self.max_abs_diff)
            self.assertLessEqual(rel_err, self.max_rel_diff)

    @unittest.skip('No longer support fuse scaling')
    def test_scale(self):
        # The reference side steps on float32 grads already divided by `scale`;
        # the optimizer under test is handed the half grads plus the scale.
        nelem = 278011
        adam_option = {
            'lr': 5e-4,
            'betas': (0.9, 0.999),
            'eps': 1e-08,
            'weight_decay': 0,
            'amsgrad': False,
        }

        source = torch.rand(nelem, dtype=torch.float, device='cuda')
        ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(
            [source], adam_option)

        for _ in range(self.iters):
            scale = random.random() * 1000
            half_grads = self.gen_mixed_grad(ref_param, tst_param, scale)
            ref_optim.step()
            tst_optim.step(grads=half_grads, scale=scale)
            abs_err, rel_err = self.get_max_diff(ref_param, tst_param)

            self.assertLessEqual(abs_err, self.max_abs_diff)
            self.assertLessEqual(rel_err, self.max_rel_diff)

    @unittest.skip('No longer support output fp16 param')
    def test_fp16_output(self):
        nelem = 278011
        adam_option = {
            'lr': 5e-4,
            'betas': (0.9, 0.999),
            'eps': 1e-08,
            'weight_decay': 0,
            'amsgrad': False,
        }

        source = torch.rand(nelem, dtype=torch.float, device='cuda')
        ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(
            [source], adam_option)

        # Half-precision parameter passed to the step as `output_params`.
        fp16_param = torch.nn.Parameter(source.clone().half())

        for _ in range(self.iters):
            half_grads = self.gen_mixed_grad(ref_param, tst_param)
            ref_optim.step()
            tst_optim.step(grads=half_grads, output_params=[fp16_param])

            # First check: test parameters against the reference parameters.
            abs_err, rel_err = self.get_max_diff(ref_param, tst_param)
            self.assertLessEqual(abs_err, self.max_abs_diff)
            self.assertLessEqual(rel_err, self.max_rel_diff)

            # Second check: fp16 output against the test parameters.
            fp16_as_float = [fp16_param.float()]
            abs_err, rel_err = self.get_max_diff(tst_param, fp16_as_float)
            self.assertLessEqual(abs_err, self.max_abs_diff)
            self.assertLessEqual(rel_err, self.max_rel_diff)

    def test_adam_option(self):
        # Single-element parameter with non-default lr, betas and eps.
        nelem = 1
        adam_option = {
            'lr': 0.01,
            'betas': (0.6, 0.9),
            'eps': 3e-06,
            'weight_decay': 0,
            'amsgrad': False,
        }

        source = torch.rand(nelem, dtype=torch.float, device='cuda')
        ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(
            [source], adam_option)

        for _ in range(self.iters):
            self.gen_grad(ref_param, tst_param)
            ref_optim.step()
            tst_optim.step()
            abs_err, rel_err = self.get_max_diff(ref_param, tst_param)

            self.assertLessEqual(abs_err, self.max_abs_diff)
            self.assertLessEqual(rel_err, self.max_rel_diff)


if __name__ == '__main__':
    # Run this module's tests when the file is executed as a script.
    unittest.main()