optimizer.py 21.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Megatron optimizer."""
mohammad's avatar
mohammad committed
17
18
19
20
21
22
23
24
25

from abc import ABC
from abc import abstractmethod

import torch

from apex.multi_tensor_apply import multi_tensor_applier
import amp_C

mohammad's avatar
mohammad committed
26
27
from megatron import get_timers
from megatron import mpu
mohammad's avatar
mohammad committed
28
29
from megatron import print_rank_0

Rewon Child's avatar
Rewon Child committed
30
from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32
mohammad's avatar
mohammad committed
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47


def _zero_grad_group_helper(group, set_to_none):
    """Clear the gradients of every parameter in `group`.

    Parameters without a gradient are left untouched. When `set_to_none`
    is true the `grad` attribute is dropped (set to None); otherwise the
    existing gradient tensor is taken out of the autograd graph and
    zeroed in place.
    Note: copied from torch.optim.optimizer.

    Arguments:
        group: iterable of parameters whose gradients should be cleared.
        set_to_none: if true, release the gradients instead of zeroing.
    """
    for p in group:
        grad = p.grad
        if grad is None:
            # Nothing to clear for this parameter.
            continue
        if set_to_none:
            p.grad = None
            continue
        # Make sure the gradient is not tracked by autograd before the
        # in-place zeroing.
        if grad.grad_fn is None:
            grad.requires_grad_(False)
        else:
            grad.detach_()
        grad.zero_()


48
def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
    """Copy the values of the tensors in `this` into the tensors in `that`.

    When `overflow_buf` is provided, the copy is done in one shot with the
    multi-tensor-applier. We don't have a bfloat16 implementation of that
    path, so for now if `overflow_buf` is not provided (None), we default
    back to a simple loop copy to be compatible with bfloat16.

    Arguments:
        this: list of source tensors.
        that: list of destination tensors, paired with `this` by position
            and updated in place.
        overflow_buf: buffer tensor handed to the multi-tensor-applier
            (it is reset to zero before the call), or None to select the
            loop copy.
    """
    # Compare against None explicitly. The truth value of a one-element
    # tensor is the truth value of its content, so a zero-filled buffer
    # would be treated as "not provided" and the multi-tensor path would
    # never be taken.
    if overflow_buf is not None:
        overflow_buf.fill_(0)
        # Scaling with factor `1.0` is equivalent to copy.
        multi_tensor_applier(amp_C.multi_tensor_scale,
                             overflow_buf,
                             [this, that],
                             1.0)
    else:
        for this_, that_ in zip(this, that):
            that_.copy_(this_)

64

mohammad's avatar
mohammad committed
65
66
67

class MegatronOptimizer(ABC):
    """Abstract base class for the optimizers defined in this file.

    It wraps a base optimizer (for example Adam), stores the gradient
    clipping / logging options, and exposes the base optimizer's `state`
    and `param_groups` as attributes of the wrapper. Subclasses implement
    the abstract methods below.

    Arguments:
        optimizer: base optimizer such as Adam or SGD.
        clip_grad: value handed to `clip_grad_norm` by the subclasses.
        log_num_zeros_in_grad: flag used by the subclasses to decide
            whether to count the zeros in the gradients.
        params_have_main_grad: flag indicating if parameters have
            a `main_grad` field.
        use_contiguous_buffers_in_local_ddp: flag that requires
            `params_have_main_grad` to be set as well.
    """

    def __init__(self, optimizer, clip_grad,
                 log_num_zeros_in_grad,
                 params_have_main_grad,
                 use_contiguous_buffers_in_local_ddp):
        """Input optimizer is the base optimizer for example Adam."""
        self.optimizer = optimizer
        assert self.optimizer, 'no optimizer is provided.'
        # Set gradient clipping and logging params.
        self.clip_grad = clip_grad
        self.log_num_zeros_in_grad = log_num_zeros_in_grad
        self.params_have_main_grad = params_have_main_grad
        self.use_contiguous_buffers_in_local_ddp = \
            use_contiguous_buffers_in_local_ddp

        if self.use_contiguous_buffers_in_local_ddp:
            assert self.params_have_main_grad, \
                "use of contiguous buffer requires that params have main grad"

    def get_parameters(self):
        """Return a flat list with the parameters of all the param groups
        of the base optimizer, in group order."""
        return [param
                for param_group in self.optimizer.param_groups
                for param in param_group['params']]

    def clip_grad_norm(self, clip_grad):
        """Run `clip_grad_norm_fp32` on all the optimizer parameters with
        `clip_grad` and return its result."""
        params = self.get_parameters()
        return clip_grad_norm_fp32(params, clip_grad)

    def count_zeros(self):
        """Run `count_zeros_fp32` on all the optimizer parameters and
        return its result."""
        params = self.get_parameters()
        return count_zeros_fp32(params)

    @abstractmethod
    def zero_grad(self, set_to_none=True):
        """Clear the gradients handled by the optimizer."""
        pass

    @abstractmethod
    def get_loss_scale(self):
        """The output should be a cuda tensor of size 1."""
        pass

    def scale_loss(self, loss):
        """Simple scaling: multiply `loss` by the current loss scale."""
        return self.get_loss_scale() * loss

    @abstractmethod
    def step(self):
        """Perform one optimization step."""
        pass

    @abstractmethod
    def reload_model_params(self):
        """Refreshes any internal state from the current model parameters.
        Call whenever the parameters are changed outside of the optimizer.
        For example, when we load a model from a checkpoint without loading
        the optimizer, the model parameters are updated but for fp16 optimizer
        with main parameters, the main parameters need to also be updated."""
        pass

    @abstractmethod
    def state_dict(self):
        """Return the state of the optimizer."""
        pass

    @abstractmethod
    def load_state_dict(self, state_dict):
        """Restore the state of the optimizer from `state_dict`."""
        pass

    # Promote state so it can be retrieved or set via
    # "optimizer_instance.state"
    def _get_state(self):
        return self.optimizer.state

    def _set_state(self, value):
        self.optimizer.state = value

    state = property(_get_state, _set_state)

    # Promote param_groups so it can be retrieved or set via
    # "optimizer_instance.param_groups"
    # (for example, to adjust the learning rate)
    def _get_param_groups(self):
        return self.optimizer.param_groups

    def _set_param_groups(self, value):
        self.optimizer.param_groups = value

    param_groups = property(_get_param_groups, _set_param_groups)



Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
170
171
172
173
174
175
176
177
178
179
180
181
class Float16OptimizerWithFloat16Params(MegatronOptimizer):
    """Float16 optimizer for fp16 and bf16 data types.

    The float16 model parameters handed to the base optimizer are replaced
    by fp32 copies ("main" parameters). Every step copies the model
    gradients to the main parameters, steps the base optimizer on the
    main parameters, and copies the result back to the model parameters.

    Arguments:
        optimizer: base optimizer such as Adam or SGD
        clip_grad: clip gradients with this global L2 norm. Note
            that clipping is ignored if clip_grad == 0
        log_num_zeros_in_grad: return number of zeros in the gradients.
        params_have_main_grad: flag indicating if parameters have
            a `main_grad` field. If this is set, we are assuming
            that the model gradients are stored in the `main_grad`
            field instead of the typical `grad` field. This happens
            for the DDP cases where there is a contiguous buffer
            holding the gradients. For example for bfloat16, we want
            to do gradient accumulation and all-reduces in float32
            and as a result we store those gradients in the main_grad.
            Note that main grad is not necessarily in float32.
        use_contiguous_buffers_in_local_ddp: if set, the `main_grad`
            field of the model parameters is kept (not reset to None)
            after its content has been handed to the main parameters.
        bf16: if true, the model is running in bfloat16.
        grad_scaler: used for scaling gradients. Note that this can be
            None. This case happens when `bf16 = True` and we don't
            use any loss scale. Note that for `bf16 = True`, we can have
            a constant gradient scaler. Also for `bf16 = False`, we
            always require a grad scaler.
    """

    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                 bf16, grad_scaler):

        super(Float16OptimizerWithFloat16Params, self).__init__(
            optimizer, clip_grad, log_num_zeros_in_grad,
            params_have_main_grad, use_contiguous_buffers_in_local_ddp)

        self.bf16 = bf16
        self.grad_scaler = grad_scaler
        # None grad scaler is only supported for bf16.
        if self.grad_scaler is None:
            assert self.bf16, 'fp16 expects a grad scaler.'

        # Tensor used to determine if a nan/inf has happened.
        # Any non-zero value indicates inf/nan.
        # It is only allocated when a grad scaler is provided; this
        # includes the bfloat16 case with a grad scaler.
        if self.grad_scaler is not None:
            self.found_inf = torch.cuda.FloatTensor([0.0])

        # Dummy tensor needed for apex multi-apply tensor.
        # For bfloat, we don't have multi-tensor apply and for now
        # we set it to none so the multi-tensor apply gets ignored.
        if bf16:
            self._dummy_overflow_buf = None
        else:
            self._dummy_overflow_buf = torch.cuda.IntTensor([0])

        # In case grad scaler is not passed, define the unity scale.
        if self.grad_scaler is None:
            self._scale_one = torch.cuda.FloatTensor([1.0])

        # ======================
        # main parameter stuff
        # ======================

        # Three groups of parameters:
        #   float16_groups: original float16 parameters
        #   fp32_from_float16_groups: fp32 copy of float16 parameters
        #   fp32_from_fp32_groups: original fp32 parameters
        self.float16_groups = []
        self.fp32_from_float16_groups = []
        self.fp32_from_fp32_groups = []

        # For all the groups in the original optimizer:
        for param_group in self.optimizer.param_groups:
            float16_params_this_group = []
            fp32_params_this_group = []
            fp32_from_float16_params_this_group = []
            # For all the parameters in this group:
            for i, param in enumerate(param_group['params']):
                # Parameters that do not require grad are left as they are.
                if not param.requires_grad:
                    continue

                # float16 params:
                if param.type() in ['torch.cuda.HalfTensor',
                                    'torch.cuda.BFloat16Tensor']:
                    float16_params_this_group.append(param)
                    # Create a copy
                    main_param = param.detach().clone().float()
                    # Copy tensor model parallel attributes.
                    mpu.copy_tensor_model_parallel_attributes(main_param,
                                                              param)
                    if hasattr(param, 'shared'):
                        main_param.shared = param.shared
                    # Replace the optimizer params with the new fp32 copy.
                    param_group['params'][i] = main_param
                    fp32_from_float16_params_this_group.append(main_param)
                    # Reset existing state dict key to the new main param.
                    if param in self.optimizer.state:
                        self.optimizer.state[main_param] \
                            = self.optimizer.state.pop(param)

                # fp32 params.
                elif param.type() == 'torch.cuda.FloatTensor':
                    fp32_params_this_group.append(param)
                    param_group['params'][i] = param

                else:
                    raise TypeError('Wrapped parameters must be one of '
                                    'torch.cuda.FloatTensor,  '
                                    'torch.cuda.HalfTensor, or '
                                    'torch.cuda.BFloat16Tensor. '
                                    'Received {}'.format(param.type()))

            self.float16_groups.append(float16_params_this_group)
            self.fp32_from_float16_groups.append(
                fp32_from_float16_params_this_group)
            self.fp32_from_fp32_groups.append(fp32_params_this_group)

        # NOTE: preexisting per-param state tensors are not recast here;
        # the state_dict() / load_state_dict() round trip that would do
        # it is not executed.

    def zero_grad(self, set_to_none=True):
        """We only need to zero the model related parameters, i.e.,
        float16_groups & fp32_from_fp32_groups. We additionally zero
        fp32_from_float16_groups as a memory optimization to reduce
        fragmentation; in the case of set_to_none==True, the space
        used by this field can be safely deallocated at this point."""
        for group in self.float16_groups:
            _zero_grad_group_helper(group, set_to_none)
        for group in self.fp32_from_float16_groups:
            _zero_grad_group_helper(group, set_to_none)
        for group in self.fp32_from_fp32_groups:
            _zero_grad_group_helper(group, set_to_none)

    def get_loss_scale(self):
        """Return the scale of the grad scaler, or the unity scale tensor
        when no grad scaler is used."""
        if self.grad_scaler is None:
            return self._scale_one
        return self.grad_scaler.scale

    def _copy_model_grads_to_main_grads(self):
        """Hand the gradients of the model parameters to the parameters
        the base optimizer works on."""
        # This only needs to be done for the float16 group.
        for model_group, main_group in zip(self.float16_groups,
                                           self.fp32_from_float16_groups):
            for model_param, main_param in zip(model_group, main_group):
                if self.params_have_main_grad and \
                   hasattr(model_param, 'main_grad'):
                    main_param.grad = model_param.main_grad.float()
                elif model_param.grad is not None:
                    main_param.grad = model_param.grad.float()

                # Safe to deallocate model's grad/main_grad after copying.
                # (If using contiguous buffers, main_grad's memory should
                # persist and therefore should not be deallocated.)
                model_param.grad = None
                if self.params_have_main_grad and \
                   not self.use_contiguous_buffers_in_local_ddp:
                    model_param.main_grad = None

        # For fp32 grads, we need to reset the grads to main grad.
        if self.params_have_main_grad:
            for model_group in self.fp32_from_fp32_groups:
                for model_param in model_group:
                    model_param.grad = model_param.main_grad

                    # Safe to de-reference model's main_grad after copying.
                    # (If using contiguous buffers, main_grad's memory should
                    # persist and therefore should not be deallocated.)
                    if not self.use_contiguous_buffers_in_local_ddp:
                        model_param.main_grad = None

    def _unscale_main_grads_and_check_for_nan(self):
        """Unscale the main gradients in place with the inverse scale of
        the grad scaler and return True if an inf/nan was found on any
        rank of the model parallel group."""
        main_grads = []
        # fp32 params from float16 ones first, then the fp32 parameters.
        for groups in (self.fp32_from_float16_groups,
                       self.fp32_from_fp32_groups):
            for main_group in groups:
                for main_param in main_group:
                    if main_param.grad is not None:
                        main_grads.append(main_param.grad.data)
        # Reset found inf.
        self.found_inf.fill_(0.0)
        # Unscale and set found inf/nan
        torch._amp_foreach_non_finite_check_and_unscale_(
            main_grads, self.found_inf, self.grad_scaler.inv_scale)
        # Update across all model parallel instances.
        torch.distributed.all_reduce(self.found_inf,
                                     op=torch.distributed.ReduceOp.MAX,
                                     group=mpu.get_model_parallel_group())

        # Check for nan.
        found_inf_flag = (self.found_inf.item() > 0)
        return found_inf_flag

    def _get_model_and_main_params_data_float16(self):
        """Return two position-aligned lists with the data of the float16
        model parameters and of their fp32 main copies."""
        model_data = []
        main_data = []
        for model_group, main_group in zip(self.float16_groups,
                                           self.fp32_from_float16_groups):
            for model_param, main_param in zip(model_group, main_group):
                model_data.append(model_param.data)
                main_data.append(main_param.data)
        return model_data, main_data

    def _copy_main_params_to_model_params(self):
        """Copy the fp32 main parameters into the float16 model ones."""
        # Only needed for the float16 params.
        model_data, main_data = self._get_model_and_main_params_data_float16()
        _multi_tensor_copy_this_to_that(this=main_data, that=model_data,
                                        overflow_buf=self._dummy_overflow_buf)

    def _copy_model_params_to_main_params(self):
        """Copy the float16 model parameters into the fp32 main ones."""
        # Only needed for the float16 params.
        model_data, main_data = self._get_model_and_main_params_data_float16()
        _multi_tensor_copy_this_to_that(this=model_data, that=main_data,
                                        overflow_buf=self._dummy_overflow_buf)

    def reload_model_params(self):
        """Refresh the fp32 main parameters from the current float16
        model parameters."""
        self._copy_model_params_to_main_params()

    @torch.no_grad()
    def step(self):
        """Run one optimization step.

        Returns a tuple `(update_successful, grad_norm, num_zeros_in_grad)`.
        When a grad scaler is used and an inf/nan is found in the
        gradients, the update is skipped and `(False, None, None)` is
        returned. `grad_norm` is None when clipping is disabled and
        `num_zeros_in_grad` is None when `log_num_zeros_in_grad` is off.
        """

        timers = get_timers()

        # Copy gradients from model params to main params.
        timers('optimizer-copy-to-main-grad').start()
        self._copy_model_grads_to_main_grads()
        timers('optimizer-copy-to-main-grad').stop()

        # Do unscale, check for inf, and update grad scaler only for
        # the case that grad scaler is provided.
        if self.grad_scaler is not None:

            # Unscale and check for inf/nan.
            timers('optimizer-unscale-and-check-inf').start()
            found_inf_flag = self._unscale_main_grads_and_check_for_nan()
            timers('optimizer-unscale-and-check-inf').stop()

            # We are done with scaling gradients
            # so we can update the loss scale.
            self.grad_scaler.update(found_inf_flag)

            # If we found inf/nan, skip the update.
            if found_inf_flag:
                return False, None, None

        # Clip the main gradients.
        timers('optimizer-clip-main-grad').start()
        grad_norm = None
        if self.clip_grad > 0.0:
            grad_norm = self.clip_grad_norm(self.clip_grad)
        timers('optimizer-clip-main-grad').stop()

        # count the zeros in the grads
        num_zeros_in_grad = self.count_zeros() if \
                            self.log_num_zeros_in_grad else None

        # Step the optimizer.
        self.optimizer.step()

        # Update params from main params.
        timers('optimizer-copy-main-to-model-params').start()
        self._copy_main_params_to_model_params()
        timers('optimizer-copy-main-to-model-params').stop()

        # Successful update.
        return True, grad_norm, num_zeros_in_grad

    def state_dict(self):
        """Return a dict with the base optimizer state ('optimizer'), the
        grad scaler state ('grad_scaler', only when a grad scaler is
        used) and the fp32 main parameter groups
        ('fp32_from_fp16_params')."""
        state_dict = {}
        state_dict['optimizer'] = self.optimizer.state_dict()
        if self.grad_scaler is not None:
            state_dict['grad_scaler'] = self.grad_scaler.state_dict()
        state_dict['fp32_from_fp16_params'] = self.fp32_from_float16_groups
        return state_dict

    def load_state_dict(self, state_dict):
        """Restore the base optimizer, the grad scaler (when both the
        checkpoint and this instance have one) and the fp32 main
        parameters from `state_dict`. The older key names
        'optimizer_state_dict' and 'fp32_from_fp16' are accepted as
        fallbacks."""
        # Optimizer.
        optimizer_key = 'optimizer'
        if optimizer_key not in state_dict:
            optimizer_key = 'optimizer_state_dict'
            print_rank_0('***WARNING*** loading optimizer from '
                         'an old checkpoint ...')
        self.optimizer.load_state_dict(state_dict[optimizer_key])

        # Grad scaler.
        if 'grad_scaler' not in state_dict:
            print_rank_0('***WARNING*** found an old checkpoint, will not '
                         'load grad scaler ...')
        else:
            if self.grad_scaler is not None:
                self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
            else:
                print_rank_0('***WARNING*** fould the grad scaler in the '
                             'checkpoint but it is None in the class. '
                             'Skipping loading grad scaler ...')

        # Copy data for the main params.
        fp32_from_float16_params_key = 'fp32_from_fp16_params'
        if fp32_from_float16_params_key not in state_dict:
            fp32_from_float16_params_key = 'fp32_from_fp16'
        for current_group, saved_group in zip(
                self.fp32_from_float16_groups,
                state_dict[fp32_from_float16_params_key]):
            for current_param, saved_param in zip(current_group, saved_group):
                current_param.data.copy_(saved_param.data)



mohammad's avatar
mohammad committed
496
497
class FP32Optimizer(MegatronOptimizer):
    """Optimizer wrapper for the case where no loss scaling and no
    separate main parameters are used: the base optimizer is stepped
    directly on its own parameters.

    Arguments:
        optimizer: base optimizer such as Adam or SGD
        clip_grad: clip gradients with this global L2 norm. Clipping is
            only done if clip_grad > 0
        log_num_zeros_in_grad: return number of zeros in the gradients.
        params_have_main_grad: if set, the `main_grad` field of each
            parameter is used as its `grad` at every step.
        use_contiguous_buffers_in_local_ddp: if set, the `main_grad`
            field is kept (not reset to None) after it has been handed
            to `grad`.
    """

    def __init__(self, optimizer, clip_grad,
                 log_num_zeros_in_grad,
                 params_have_main_grad,
                 use_contiguous_buffers_in_local_ddp):

        super(FP32Optimizer, self).__init__(
            optimizer, clip_grad, log_num_zeros_in_grad,
            params_have_main_grad, use_contiguous_buffers_in_local_ddp)

        # Constant unity loss scale returned by `get_loss_scale`.
        self._scale = torch.cuda.FloatTensor([1.0])

    def zero_grad(self, set_to_none=True):
        """Copied from torch.optim.optimizer"""
        for group in self.optimizer.param_groups:
            _zero_grad_group_helper(group['params'], set_to_none)

    def get_loss_scale(self):
        """FP32 optimizer does not do any scaling."""
        return self._scale

    @torch.no_grad()
    def step(self):
        """Clip gradients (if needed) and step the base optimizer.
        Always return successful since there is no overflow.

        Returns a tuple `(True, grad_norm, num_zeros_in_grad)` where
        `grad_norm` is None when clipping is disabled and
        `num_zeros_in_grad` is None when `log_num_zeros_in_grad` is off.
        """

        # Copy main_grads to grads.
        if self.params_have_main_grad:
            for param_group in self.optimizer.param_groups:
                for param in param_group['params']:
                    param.grad = param.main_grad

                    # Safe to de-reference model's main_grad after copying.
                    # (If using contiguous buffers, main_grad's memory should
                    # persist and therefore should not be deallocated.)
                    if not self.use_contiguous_buffers_in_local_ddp:
                        param.main_grad = None

        # Clip gradients.
        grad_norm = None
        if self.clip_grad > 0.0:
            grad_norm = self.clip_grad_norm(self.clip_grad)

        # count the zeros in the grads
        num_zeros_in_grad = self.count_zeros() if \
                            self.log_num_zeros_in_grad else None

        # Update parameters.
        self.optimizer.step()

        # No overflow for FP32 optimizer.
        return True, grad_norm, num_zeros_in_grad

    def reload_model_params(self):
        """Nothing to refresh: there is no separate copy of the model
        parameters in this optimizer."""
        pass

    def state_dict(self):
        """Return the state dict of the base optimizer."""
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict):
        """Load `state_dict` into the base optimizer."""
        self.optimizer.load_state_dict(state_dict)