import unittest
import os

import torch
from torch.optim import Optimizer
import apex
from apex.multi_tensor_apply import multi_tensor_applier

class RefLAMB(Optimizer):
    r"""Implements Lamb algorithm.

    It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-6)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0.01)

    .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:
        https://arxiv.org/abs/1904.00962
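
    Example (a minimal usage sketch; assumes apex is built with its CUDA
    extensions and a CUDA device is available; the model below is purely
    illustrative)::

        model = torch.nn.Linear(10, 10).cuda()
        optimizer = RefLAMB(model.parameters(), lr=1e-3, weight_decay=0.01)
        loss = model(torch.randn(4, 10, device='cuda')).sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()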
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super(RefLAMB, self).__init__(params, defaults)
        if multi_tensor_applier.available:
            import amp_C
            self.multi_tensor_l2norm=amp_C.multi_tensor_l2norm
            # Skip buffer
            self._dummy_overflow_buf = torch.cuda.IntTensor([0])
            self.multi_tensor_lamb = amp_C.multi_tensor_lamb
        else:
            raise RuntimeError('RefLAMB requires apex to be built with CUDA extensions (amp_C)')

    def step(self, closure=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        # create separate grad lists for fp32 and fp16 params
        g_all_32, g_all_16 = [], []
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                if p.dtype == torch.float32:
                    g_all_32.append(p.grad.data)
                elif p.dtype == torch.float16:
                    g_all_16.append(p.grad.data)
                else:
                    raise RuntimeError('FusedLAMB only supports fp16 and fp32.')

        g_norm_32, g_norm_16 = torch.zeros(1, device='cuda'), torch.zeros(1, device='cuda')
        # compute grad norm for two lists
        if len(g_all_32) > 0:
            g_norm_32 = multi_tensor_applier(self.multi_tensor_l2norm,
                                             self._dummy_overflow_buf,
                                             [g_all_32], False)[0]
        if len(g_all_16) > 0:
            g_norm_16 = multi_tensor_applier(self.multi_tensor_l2norm,
                                             self._dummy_overflow_buf,
                                             [g_all_16], False)[0]

        # blend two grad norms to get global grad norm
        global_grad_norm = multi_tensor_applier(self.multi_tensor_l2norm,
                                                self._dummy_overflow_buf,
                                                [[g_norm_32, g_norm_16]],
                                                False)[0]
        
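        # Global gradient clipping: rescale every gradient so that the global
        # L2 norm does not exceed max_grad_norm (a no-op when the norm is
        # already below that threshold).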
        max_grad_norm = 1.0
        clipped_ratio = max_grad_norm / max(global_grad_norm, max_grad_norm)

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                p.grad.data *= clipped_ratio
                grad = p.grad.data 
                if grad.is_sparse:
                    raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['m'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['v'] = torch.zeros_like(p.data)

                m_t, v_t = state['m'], state['v']
                beta1, beta2 = group['betas']

                state['step'] += 1

                # m_t = beta1 * m + (1 - beta1) * g_t
                m_t.mul_(beta1).add_(grad, alpha=1-beta1)
                # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
                v_t.mul_(beta2).addcmul_(grad, grad, value=1-beta2)

                # Debiasing
                m_t_hat = m_t / (1.0 - beta1 ** state['step'])
                v_t_hat = v_t / (1.0 - beta2 ** state['step'])

                update = m_t_hat / v_t_hat.sqrt().add(group['eps'])

                if group['weight_decay'] != 0:
                    update.add_(p.data, alpha=group['weight_decay'])

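                # LAMB trust ratio: scale the step for each parameter tensor by
                # ||w|| / ||update|| so that each layer's effective step size is
                # adapted to its own weight norm.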
                trust_ratio = 1.0
                w_norm = p.data.pow(2).sum().sqrt()
                g_norm = update.pow(2).sum().sqrt()
                if w_norm > 0 and g_norm > 0:
                    trust_ratio = w_norm / g_norm

                state['w_norm'] = w_norm
                state['g_norm'] = g_norm
                state['trust_ratio'] = trust_ratio

                step_size = group['lr'] 

                p.data.add_(update, alpha=-step_size*trust_ratio)

        return loss


class TestFusedLAMB(unittest.TestCase):
    def setUp(self, max_abs_diff=1e-3, max_rel_diff=1, iters=7):
        self.max_abs_diff = max_abs_diff
        self.max_rel_diff = max_rel_diff
        self.iters = iters
        torch.cuda.manual_seed(9876)

    def tearDown(self):
        pass

    def gen_param_optim(self, tensors, lamb_option):
        ref_param = []
        tst_param = []
        for tensor in tensors:
            ref_param.append(torch.nn.Parameter(tensor.clone()))
            tst_param.append(torch.nn.Parameter(tensor.clone()))

        ref_optim = RefLAMB(ref_param, **lamb_option)
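        # use_nvlamb=True so the fused optimizer also applies the adaptive trust
        # ratio to parameters with zero weight decay, matching RefLAMB above.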
        tst_optim = apex.optimizers.FusedLAMB(tst_param, use_nvlamb=True, **lamb_option)

        return (ref_param, tst_param, ref_optim, tst_optim)

    def gen_grad(self, ref_param, tst_param):
        for p_ref, p_tst in zip(ref_param, tst_param):
            p_ref.grad = torch.rand_like(p_ref)
            p_tst.grad = p_ref.grad

    def gen_mixed_grad(self, ref_param, tst_param, scale=1.0):
        half_grads = []
        for p_ref, _ in zip(ref_param, tst_param):
            half_grads.append(torch.rand_like(p_ref).half())
            p_ref.grad = half_grads[-1].float() / scale
        return half_grads

    def get_max_diff(self, ref_param, tst_param):
        max_abs_diff = max_rel_diff = 0
        for p_ref, p_tst in zip(ref_param, tst_param):
            max_abs_diff_p = (p_ref - p_tst).abs().max().item()
            max_rel_diff_p = ((p_ref - p_tst) / p_ref).abs().max().item()

            if max_abs_diff_p > max_abs_diff:  max_abs_diff = max_abs_diff_p
            if max_rel_diff_p > max_rel_diff:  max_rel_diff = max_rel_diff_p

        return max_abs_diff, max_rel_diff

    def gen_single_type_test(self, param_type=torch.float):
        nelem = 278011
        tensor = torch.rand(nelem, dtype=param_type, device='cuda')
        weight_decay = [0, 0.01]
        
        for wd in weight_decay:
            lamb_option = {'lr':5e-4, 'betas':(0.9, 0.999), 'eps':1e-08, 'weight_decay':wd}
            ref_param, tst_param, ref_optim, tst_optim = \
                self.gen_param_optim([tensor], lamb_option)

            for i in range(self.iters):
                self.gen_grad(ref_param, tst_param)
                ref_optim.step()
                tst_optim.step()
                max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)

                self.assertLessEqual(max_abs_diff, self.max_abs_diff)
                self.assertLessEqual(max_rel_diff, self.max_rel_diff)

    def test_float(self):
        self.gen_single_type_test(param_type=torch.float)

    @unittest.skip("PyTorch optimizer is not numerically correct for fp16")
    def test_half(self):
        self.gen_single_type_test(param_type=torch.float16)

    def test_multi_params(self):
        sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
        weight_decay = [0, 0.01]
        
        for wd in weight_decay:
            lamb_option = {'lr':5e-4, 'betas':(0.9, 0.999), 'eps':1e-08, 'weight_decay':wd}
            tensors = []
            for size in sizes:
                tensors.append(torch.rand(size, dtype=torch.float, device='cuda'))
            ref_param, tst_param, ref_optim, tst_optim = \
                self.gen_param_optim(tensors, lamb_option)

            for i in range(self.iters):
                self.gen_grad(ref_param, tst_param)
                ref_optim.step()
                tst_optim.step()
                max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)
                self.assertLessEqual(max_abs_diff, self.max_abs_diff)
                self.assertLessEqual(max_rel_diff, self.max_rel_diff)

    def test_lamb_option(self):
        nelem = 1
        tensor = torch.rand(nelem, dtype=torch.float, device='cuda')
        weight_decay = [0, 0.01]
        
        for wd in weight_decay:
            lamb_option = {'lr':0.01, 'betas':(0.6, 0.9), 'eps':3e-06, 'weight_decay':wd}
            ref_param, tst_param, ref_optim, tst_optim = \
                self.gen_param_optim([tensor], lamb_option)

            for i in range(self.iters):
                self.gen_grad(ref_param, tst_param)
                ref_optim.step()
                tst_optim.step()
                max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)

                self.assertLessEqual(max_abs_diff, self.max_abs_diff)
                self.assertLessEqual(max_rel_diff, self.max_rel_diff)


if __name__ == '__main__':
    script_path = os.path.dirname(os.path.realpath(__file__))
    unittest.main()