# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Megatron optimizer."""

from abc import ABC
from abc import abstractmethod

import torch
from torch._six import inf

from apex.multi_tensor_apply import multi_tensor_applier
import amp_C

from megatron import get_timers
from megatron import mpu


def _zero_grad_group_helper(group, set_to_none):
    """Clear the ``.grad`` of every parameter in ``group``.

    Mirrors the per-parameter logic of ``torch.optim.Optimizer.zero_grad``.

    Arguments:
        group: iterable of parameters whose gradients should be cleared.
        set_to_none (bool): if True, drop each gradient tensor
            (``param.grad = None``); otherwise detach it from autograd and
            fill it with zeros in place.
    """
    for p in group:
        g = p.grad
        if g is None:
            continue
        if set_to_none:
            p.grad = None
            continue
        # A gradient that belongs to an autograd graph (non-None grad_fn) is
        # detached in place; a leaf gradient just has requires_grad cleared.
        if g.grad_fn is None:
            g.requires_grad_(False)
        else:
            g.detach_()
        g.zero_()


def _clip_grad_norm(parameters, max_norm, norm_type=2):
    """Clips gradient norm of an iterable of parameters.

    This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
    added functionality to handle model parallel parameters. Note that
    the gradients are modified in place.

    Parameters whose ``.grad`` is None are skipped. Every remaining gradient
    is scaled, but only gradients of parameters that are not shared and
    are not duplicated across tensor model parallel ranks contribute to
    the norm. The norm is then reduced across the model parallel group.

    Arguments:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        max_norm (float or int): max norm of the gradients
        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.

    Returns:
        Total norm of the parameters (viewed as a single vector), as a
        Python float.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]

    # Filter parameters based on:
    #   - grad should not be none
    #   - parameter should not be shared
    #   - should not be a replica due to tensor model parallelism
    grads = []
    grads_for_norm = []
    for param in parameters:
        # Check for a missing grad *before* touching ``param.grad``;
        # otherwise a parameter that received no gradient raises here.
        if param.grad is None:
            continue
        # Make sure the grads are in fp32
        assert param.grad.type() == 'torch.cuda.FloatTensor'
        # ``detach`` shares storage, so scaling ``grad`` below updates
        # ``param.grad`` in place.
        grad = param.grad.detach()
        grads.append(grad)
        is_not_shared = not getattr(param, 'shared', False)
        # Missing attribute is treated as "not tensor-parallel", so only
        # tensor-parallel rank 0 counts the (replicated) gradient.
        is_not_tp_duplicate = getattr(param, 'tensor_model_parallel', False) \
            or mpu.get_tensor_model_parallel_rank() == 0
        if is_not_shared and is_not_tp_duplicate:
            grads_for_norm.append(grad)

    # Norm parameters.
    max_norm = float(max_norm)
    norm_type = float(norm_type)

    # Calculate norm.
    if norm_type == inf:
        # Local max; 0.0 when this rank contributes no gradients
        # (``max`` over an empty sequence would raise ValueError).
        total_norm = 0.0
        if grads_for_norm:
            total_norm = max(grad.abs().max() for grad in grads_for_norm)
        total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
        # Take max across all model-parallel GPUs.
        torch.distributed.all_reduce(total_norm_cuda,
                                     op=torch.distributed.ReduceOp.MAX,
                                     group=mpu.get_model_parallel_group())
        total_norm = total_norm_cuda[0].item()
    else:
        # Accumulate sum(||g||_p ** p) on the device instead of calling
        # ``.item()`` per gradient, which forces a host-device sync each time.
        total_norm_cuda = torch.cuda.FloatTensor([0.0])
        for grad in grads_for_norm:
            grad_norm = torch.norm(grad, norm_type)
            total_norm_cuda += grad_norm ** norm_type
        # Sum across all model-parallel GPUs.
        torch.distributed.all_reduce(total_norm_cuda,
                                     op=torch.distributed.ReduceOp.SUM,
                                     group=mpu.get_model_parallel_group())
        total_norm = total_norm_cuda[0].item() ** (1. / norm_type)

    # Scale.
    clip_coef = max_norm / (total_norm + 1.0e-6)
    if clip_coef < 1.0:
        for grad in grads:
            grad.mul_(clip_coef)

    return total_norm


class MegatronOptimizer(ABC):
    """Abstract wrapper around a base optimizer (e.g. ``torch.optim.Adam``).

    Subclasses implement gradient zeroing, loss scaling, stepping and
    checkpointing. This base class provides gradient-norm clipping over all
    parameters of the wrapped optimizer and exposes the wrapped optimizer's
    ``state`` and ``param_groups`` as read/write properties.
    """

    def __init__(self, optimizer):
        """Input optimizer is the base optimizer for example Adam."""
        self.optimizer = optimizer
        assert self.optimizer, 'no optimizer is provided.'

    def clip_grad_norm(self, clip_grad):
        """Clip the gradients of every parameter in the wrapped optimizer.

        Arguments:
            clip_grad (float): maximum allowed total gradient norm.

        Returns:
            The total gradient norm computed by ``_clip_grad_norm``.
        """
        params = [param
                  for param_group in self.optimizer.param_groups
                  for param in param_group['params']]
        return _clip_grad_norm(params, clip_grad)

    @abstractmethod
    def zero_grad(self, set_to_none=True):
        """Zero (or drop, when ``set_to_none``) parameter gradients."""
        pass

    @abstractmethod
    def get_loss_scale(self):
        """Return the current loss scale."""
        pass

    def scale_loss(self, loss):
        """Simple scaling."""
        return self.get_loss_scale() * loss

    @abstractmethod
    def step(self):
        """Perform one optimization step."""
        pass

    @abstractmethod
    def state_dict(self):
        """Return a checkpointable state dict."""
        pass

    @abstractmethod
    def load_state_dict(self, state_dict):
        """Restore state from a dict produced by ``state_dict``."""
        pass

    # Promote state so it can be retrieved or set via
    # "optimizer_instance.state"
    def _get_state(self):
        return self.optimizer.state

    def _set_state(self, value):
        self.optimizer.state = value

    state = property(_get_state, _set_state)

    # Promote param_groups so it can be retrieved or set via
    # "optimizer_instance.param_groups"
    # (for example, to adjust the learning rate)
    def _get_param_groups(self):
        return self.optimizer.param_groups

    def _set_param_groups(self, value):
        self.optimizer.param_groups = value

    param_groups = property(_get_param_groups, _set_param_groups)



class FP16OptimizerWithFP16Params(MegatronOptimizer):
    """Mixed-precision wrapper that keeps fp32 master copies of fp16 params.

    Inside the wrapped optimizer, every fp16 parameter is replaced by an fp32
    "master" copy. fp32 parameters are optimized directly. Each ``step``
    does the following:
      1. copies fp16 model grads into the master grads;
      2. unscales all master grads and checks them for inf/nan;
      3. updates the grad scaler;
      4. optionally clips;
      5. steps the base optimizer;
      6. copies the updated master values back into the fp16 params.

    Arguments:
        optimizer: base optimizer (e.g. Adam) whose param groups hold cuda
            fp16 and/or fp32 parameters.
        grad_scaler: loss scaler. This class uses its ``scale``,
            ``inv_scale``, ``update``, ``state_dict`` and
            ``load_state_dict`` members.
        clip_grad (float): max gradient norm. Clipping is skipped when it
            is not greater than 0, matching ``FP32Optimizer``.
    """

    def __init__(self, optimizer, grad_scaler, clip_grad):
        super(FP16OptimizerWithFP16Params, self).__init__(optimizer)

        self.grad_scaler = grad_scaler
        self.clip_grad = clip_grad

        # Tensor used to determine if a nan/inf has happened.
        # Any non-zero value indicates inf/nan.
        self.found_inf = torch.cuda.FloatTensor([0.0])

        # Dummy tensor needed for apex multi-apply tensor.
        self._dummy_overflow_buf = torch.cuda.IntTensor([0])

        # ======================
        # master parameter stuff
        # ======================

        # Three groups of parameters:
        #   fp16_groups: original fp16 parameters
        #   fp32_from_fp16_groups: fp32 copy of fp16 parameters
        #   fp32_from_fp32_groups: original fp32 parameters
        self.fp16_groups = []
        self.fp32_from_fp16_groups = []
        self.fp32_from_fp32_groups = []

        # For all the groups in the original optimizer:
        for param_group in self.optimizer.param_groups:
            fp16_params_this_group = []
            fp32_params_this_group = []
            fp32_from_fp16_params_this_group = []
            # For all the parameters in this group:
            for i, param in enumerate(param_group['params']):
                if not param.requires_grad:
                    continue

                # fp16 params:
                if param.type() == 'torch.cuda.HalfTensor':
                    fp16_params_this_group.append(param)
                    # Create an fp32 copy for the base optimizer to update.
                    master_param = param.detach().clone().float()
                    # Store grads
                    master_param.requires_grad = True
                    # Copy tensor model parallel attributes.
                    mpu.copy_tensor_model_parallel_attributes(master_param,
                                                              param)
                    if hasattr(param, 'shared'):
                        master_param.shared = param.shared
                    # Replace the optimizer params with the new fp32 copy.
                    param_group['params'][i] = master_param
                    fp32_from_fp16_params_this_group.append(master_param)
                    # Reset existing state dict key to the new master param.
                    if param in self.optimizer.state:
                        self.optimizer.state[master_param] \
                            = self.optimizer.state.pop(param)

                # fp32 params are optimized in place; nothing to replace.
                elif param.type() == 'torch.cuda.FloatTensor':
                    fp32_params_this_group.append(param)

                else:
                    raise TypeError("Wrapped parameters must be either "
                                    "torch.cuda.FloatTensor or "
                                    "torch.cuda.HalfTensor. "
                                    "Received {}".format(param.type()))

            self.fp16_groups.append(fp16_params_this_group)
            self.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group)
            self.fp32_from_fp32_groups.append(fp32_params_this_group)

        # Leverage state_dict() and load_state_dict() to
        # recast preexisting per-param state tensors
        self.optimizer.load_state_dict(self.optimizer.state_dict())

    def zero_grad(self, set_to_none=True):
        """Zero the model grads (fp16_groups & fp32_from_fp32_groups) and the
        fp32 master grads (fp32_from_fp16_groups).

        The master grads are cleared too. Otherwise a master param whose fp16
        counterpart gets no gradient next iteration would keep a stale
        gradient and be stepped with it again. With ``set_to_none=True``,
        clearing them also releases their memory."""
        for groups in (self.fp16_groups,
                       self.fp32_from_fp16_groups,
                       self.fp32_from_fp32_groups):
            for group in groups:
                _zero_grad_group_helper(group, set_to_none)

    def get_loss_scale(self):
        """Return the grad scaler's current loss scale."""
        return self.grad_scaler.scale

    @torch.no_grad()
    def step(self):
        """Run one mixed-precision optimizer step.

        Returns:
            False if an inf/nan was found in the gradients. In that case
            the parameter update is skipped, but the grad scaler is still
            updated. Returns True otherwise.
        """
        timers = get_timers()

        # Copy gradients from model params to master params.
        timers('optimizer-copy-to-master-grad').start()
        self._copy_model_grads_to_master_grads()
        timers('optimizer-copy-to-master-grad').stop()

        # Unscale and check for inf/nan.
        timers('optimizer-unscale-and-check-inf').start()
        self._unscale_master_grads_and_check_for_nan()
        timers('optimizer-unscale-and-check-inf').stop()

        # We are done with scaling gradients so we can update the loss scale.
        found_inf_flag = (self.found_inf.item() > 0)
        self.grad_scaler.update(found_inf_flag)

        # If we found inf/nan, skip the update.
        if found_inf_flag:
            return False

        # Clip the master gradients. A non-positive clip_grad disables
        # clipping: passing 0 to the clipper would give a clip coefficient
        # of 0 and wipe out every gradient.
        timers('optimizer-clip-master-grad').start()
        if self.clip_grad > 0.0:
            self.clip_grad_norm(self.clip_grad)
        timers('optimizer-clip-master-grad').stop()

        # Step the optimizer.
        self.optimizer.step()

        # Update params from master params.
        timers('optimizer-copy-master-to-model-params').start()
        self._copy_master_params_to_model_params()
        timers('optimizer-copy-master-to-model-params').stop()

        return True

    def _copy_model_grads_to_master_grads(self):
        """Copy fp16 model grads into the matching fp32 master grads.

        Only the fp16 groups need this; fp32 params are their own masters.
        A master param whose model param has no grad gets its grad cleared,
        so a stale value from an earlier step is not applied."""
        model_grads = []
        master_grads = []
        for model_group, master_group in zip(self.fp16_groups,
                                             self.fp32_from_fp16_groups):
            for model_param, master_param in zip(model_group, master_group):
                if model_param.grad is None:
                    master_param.grad = None
                    continue
                if master_param.grad is None:
                    master_param.grad = torch.empty_like(master_param)
                model_grads.append(model_param.grad)
                master_grads.append(master_param.grad)
        self._multi_tensor_copy(model_grads, master_grads)

    def _unscale_master_grads_and_check_for_nan(self):
        """Unscale every master grad in place and record inf/nan.

        Sets ``self.found_inf`` to a non-zero value if any grad on any
        model-parallel rank is non-finite."""
        master_grads = [param.grad
                        for groups in (self.fp32_from_fp16_groups,
                                       self.fp32_from_fp32_groups)
                        for group in groups
                        for param in group
                        if param.grad is not None]
        # Reset found inf.
        self.found_inf.fill_(0.0)
        # Unscale and set found inf/nan
        if master_grads:
            torch._amp_foreach_non_finite_check_and_unscale_(
                master_grads, self.found_inf, self.grad_scaler.inv_scale)
        # Update across all model parallel instances. Runs unconditionally
        # so the collective stays matched on ranks that have no grads.
        torch.distributed.all_reduce(self.found_inf,
                                     op=torch.distributed.ReduceOp.MAX,
                                     group=mpu.get_model_parallel_group())

    def _copy_master_params_to_model_params(self):
        """Copy the updated fp32 master values back into the fp16 params."""
        model_data = []
        master_data = []
        for model_group, master_group in zip(self.fp16_groups,
                                             self.fp32_from_fp16_groups):
            for model_param, master_param in zip(model_group, master_group):
                model_data.append(model_param.data)
                master_data.append(master_param.data)
        self._multi_tensor_copy(master_data, model_data)

    def _multi_tensor_copy(self, src, dst):
        """Copy each tensor in ``src`` into the matching tensor of ``dst``
        with one fused apex kernel launch."""
        # Nothing to copy, e.g. a model with no fp16 parameters. Skip the
        # fused kernel, which expects non-empty tensor lists.
        if not src:
            return
        self._dummy_overflow_buf.fill_(0)
        # Scaling with factor `1.0` is equivalent to copy.
        multi_tensor_applier(amp_C.multi_tensor_scale,
                             self._dummy_overflow_buf,
                             [src, dst],
                             1.0)

    def state_dict(self):
        """Return the base optimizer state, the grad scaler state and the
        fp32 master params (under ``'fp32_from_fp16_params'``)."""
        state_dict = {}
        state_dict['optimizer'] = self.optimizer.state_dict()
        state_dict['grad_scaler'] = self.grad_scaler.state_dict()
        state_dict['fp32_from_fp16_params'] = self.fp32_from_fp16_groups
        return state_dict

    def load_state_dict(self, state_dict):
        """Restore from a dict produced by ``state_dict``.

        Master param values are copied into the existing master tensors, so
        the group/param order must match the checkpoint. ``zip`` silently
        stops at the shorter list."""
        # Defer to the class to load.
        self.optimizer.load_state_dict(state_dict['optimizer'])
        self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
        # Copy data for the master params.
        for current_group, saved_group in zip(
                self.fp32_from_fp16_groups,
                state_dict['fp32_from_fp16_params']):
            for current_param, saved_param in zip(current_group, saved_group):
                current_param.data.copy_(saved_param.data)



class FP32Optimizer(MegatronOptimizer):
    """Wrapper for a base optimizer whose parameters are kept in fp32.

    No loss scaling is applied: ``get_loss_scale`` returns a constant 1.0
    tensor. No inf/nan check is made either, so ``step`` always reports
    success.

    Arguments:
        optimizer: base optimizer, e.g. Adam.
        clip_grad (float): max gradient norm. Clipping is skipped when it is
            not greater than 0.
    """

    def __init__(self, optimizer, clip_grad):
        super(FP32Optimizer, self).__init__(optimizer)
        self.clip_grad = clip_grad
        # Constant unit loss scale.
        self._scale = torch.cuda.FloatTensor([1.0])

    def zero_grad(self, set_to_none=True):
        """Zero the gradients of all parameters in the base optimizer.

        Same per-parameter logic as torch.optim.optimizer."""
        for group in self.optimizer.param_groups:
            _zero_grad_group_helper(group['params'], set_to_none)

    def get_loss_scale(self):
        """FP32 optimizer does not do any scaling."""
        return self._scale

    @torch.no_grad()
    def step(self):
        """Clip gradients (if needed) and step the base optimizer.
        Always return successful since there is no overflow."""

        # Clip gradients.
        if self.clip_grad > 0.0:
            self.clip_grad_norm(self.clip_grad)

        # Update parameters.
        self.optimizer.step()

        # No overflow for FP32 optimizer.
        return True

    def state_dict(self):
        """Return the base optimizer's state dict unchanged."""
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict):
        """Load a dict produced by ``state_dict`` into the base optimizer."""
        self.optimizer.load_state_dict(state_dict)