optimizer.py 28.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Megatron optimizer."""
mohammad's avatar
mohammad committed
17
18
19
20
21

from abc import ABC
from abc import abstractmethod
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
22
23
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
mohammad's avatar
mohammad committed
24

mohammad's avatar
mohammad committed
25
26
from megatron import get_timers
from megatron import mpu
mohammad's avatar
mohammad committed
27
from megatron import print_rank_0
28
29
30
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
from megatron.utils import unwrap_model
mohammad's avatar
mohammad committed
31

Rewon Child's avatar
Rewon Child committed
32
from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32
mohammad's avatar
mohammad committed
33

34
35
36
37
38
39
40
41
42
43
44
45
46
47
# >>>
from megatron.model.module import param_is_not_shared
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate

from lutil import pax
        
def get_clippy(params):
    """Debug helper: describe each param's gradient-clipping eligibility.

    Returns one string per param of the form
    "<has grad>, <not shared>, <not TP duplicate> ... <shape tuple>",
    where the three flags are formatted with ``%d``.
    """
    return ["%d, %d, %d ... %s" % (
        p.grad is not None,
        param_is_not_shared(p),
        param_is_not_tensor_parallel_duplicate(p),
        str(tuple(p.shape)),
    ) for p in params]
# <<<

Lawrence McAfee's avatar
Lawrence McAfee committed
48

mohammad's avatar
mohammad committed
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
def _zero_grad_group_helper(group, set_to_none):
    """Zero out the gradient for a group of parameters.
    Note: copied from torch.optim.optimizer."""
    for param in group:
        if param.grad is not None:
            if set_to_none:
                param.grad = None
            else:
                if param.grad.grad_fn is not None:
                    param.grad.detach_()
                else:
                    param.grad.requires_grad_(False)
                param.grad.zero_()


64
def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
65
66
67
68
    """Use multi-tensor-applier to copy values from one list to another.
    We don't have a blfoat16 implementation so for now if the overflow_buf
    is not provided, we default back to simple loop copy to be compatible
    with bfloat16."""
69
70
    if overflow_buf:
        overflow_buf.fill_(0)
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
71
72
73
74
75
        # Scaling with factor `1.0` is equivalent to copy.
        multi_tensor_applier(amp_C.multi_tensor_scale,
                             overflow_buf,
                             [this, that],
                             1.0)
76
    else:
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
77
78
79
        for this_, that_ in zip(this, that):
            that_.copy_(this_)

80

mohammad's avatar
mohammad committed
81
82
83

class MegatronOptimizer(ABC):
    """Abstract wrapper around a base torch optimizer (e.g. Adam).

    Adds gradient-norm clipping, zero counting, loss-scaling hooks and
    data/pipeline-parallel gradient all-reduce helpers. Subclasses implement
    the abstract methods (zero_grad, get_loss_scale, reload_model_params,
    state_dict, load_state_dict, step).
    """

    def __init__(self, optimizer, clip_grad,
                 log_num_zeros_in_grad,
                 params_have_main_grad,
                 use_contiguous_buffers_in_local_ddp,
                 models):
        """Input optimizer is the base optimizer for example Adam.

        Arguments:
            optimizer: base optimizer; its ``state`` and ``param_groups``
                are exposed through this wrapper.
            clip_grad: clipping threshold stored for subclasses.
            log_num_zeros_in_grad: whether subclasses should count zeros
                in the gradients.
            params_have_main_grad: whether model params carry a
                ``main_grad`` attribute.
            use_contiguous_buffers_in_local_ddp: whether local DDP keeps
                gradients in contiguous buffers; requires
                ``params_have_main_grad``.
            models: the model modules being optimized; used for the
                embedding all-reduce and ``allreduce_gradients`` calls.
        """
        self.optimizer = optimizer
        assert self.optimizer, 'no optimizer is provided.'

        # Set gradient clipping and logging params.
        self.clip_grad = clip_grad
        self.log_num_zeros_in_grad = log_num_zeros_in_grad
        self.params_have_main_grad = params_have_main_grad
        self.use_contiguous_buffers_in_local_ddp = use_contiguous_buffers_in_local_ddp

        # 'models' are retained for access to the contiguous grad buffers.
        # (see distributed optimizer)
        self.models = models

        if self.use_contiguous_buffers_in_local_ddp:
            assert self.params_have_main_grad, \
                "use of contiguous buffer requires that params have main grad"

    def get_parameters(self):
        """Return all params of the base optimizer, flattened across groups."""
        return [param
                for param_group in self.optimizer.param_groups
                for param in param_group['params']]

    def get_model_parallel_group(self):
        '''Default returned here, but the distributed optimizer overrides this.'''
        return mpu.get_model_parallel_group()

    def clip_grad_norm(self, clip_grad):
        """Clip the optimizer params' gradients to global norm ``clip_grad``.

        Returns whatever ``clip_grad_norm_fp32`` returns.
        """
        params = self.get_parameters()
        return clip_grad_norm_fp32(
            params, clip_grad,
            model_parallel_group=self.get_model_parallel_group())

    def count_zeros(self):
        """Count zero entries in the optimizer params' gradients."""
        params = self.get_parameters()
        return count_zeros_fp32(params,
                                model_parallel_group=self.get_model_parallel_group())

    @abstractmethod
    def zero_grad(self, set_to_none=True):
        """Clear gradients of the params managed by this optimizer."""
        pass

    @abstractmethod
    def get_loss_scale(self):
        """The output should be a cuda tensor of size 1."""
        pass

    def scale_loss(self, loss):
        """Simple scaling."""
        return self.get_loss_scale() * loss

    @abstractmethod
    def reload_model_params(self):
        """Refreshes any internal state from the current model parameters.
        Call whenever the parameters are changed outside of the optimizer.
        For example, when we load a model from a checkpoint  without loading
        the optimizer, the model parameters are updated but for fp16 optimizer
        with main parameters, the main parameters need to also be updated."""
        pass

    @abstractmethod
    def state_dict(self):
        """Return a checkpointable state dict."""
        pass

    @abstractmethod
    def load_state_dict(self, state_dict):
        """Restore state produced by ``state_dict``."""
        pass

    # Promote state so it can be retrieved or set via
    # "optimizer_instance.state"
    def _get_state(self):
        return self.optimizer.state

    def _set_state(self, value):
        self.optimizer.state = value

    state = property(_get_state, _set_state)

    # Promote param_groups so it can be retrieved or set via
    # "optimizer_instance.param_groups"
    # (for example, to adjust the learning rate)
    def _get_param_groups(self):
        return self.optimizer.param_groups

    def _set_param_groups(self, value):
        self.optimizer.param_groups = value

    param_groups = property(_get_param_groups, _set_param_groups)

    @abstractmethod
    def step(self, args, timers):
        """Perform one optimizer update."""
        pass

    def gather_model_params(self, args, timers):
        '''For the case of a non-distributed-optimizer, there is nothing to
        do here.'''
        pass

    def allreduce_word_embedding_grads(self, args):
        '''
        All-reduce word embedding grads.

        Reduce grads across first and last stages to ensure that word_embeddings
        parameters stay in sync. This should only run for models that support
        pipelined model parallelism (BERT and GPT-2).
        '''

        if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
                mpu.get_pipeline_model_parallel_world_size() > 1:
            if mpu.is_pipeline_first_stage(ignore_virtual=True):
                unwrapped_model = self.models[0]
            elif mpu.is_pipeline_last_stage(ignore_virtual=True):
                unwrapped_model = self.models[-1]
            else:  # We do not support the interleaved schedule for T5 yet.
                unwrapped_model = self.models[0]
            unwrapped_model = unwrap_model(
                unwrapped_model, (torchDDP, LocalDDP, Float16Module))

            if unwrapped_model.share_word_embeddings:
                word_embeddings_weight = unwrapped_model.word_embeddings_weight()
                # Local DDP accumulates into main_grad; torch DDP uses grad.
                if args.DDP_impl == 'local':
                    grad = word_embeddings_weight.main_grad
                else:
                    grad = word_embeddings_weight.grad
                torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())

    def allreduce_position_embedding_grads(self, args):
        '''
        All-reduce position_embeddings grad across first (encoder) and
        split (decoder) stages to ensure that position embeddings parameters
        stay in sync. This should only run for T5 models with pipeline
        parallelism.
        '''
        if mpu.is_rank_in_position_embedding_group() and \
                mpu.get_pipeline_model_parallel_world_size() > 1 and \
                args.pipeline_model_parallel_split_rank is not None:
            unwrapped_model = self.models[0]
            unwrapped_model = unwrap_model(
                unwrapped_model, (torchDDP, LocalDDP, Float16Module))
            assert args.DDP_impl == 'local', \
                'T5 model is only supported with local DDP mode'
            grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
            torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())

    def allreduce_embedding_grads(self, args):
        """All-reduce both word and position embedding grads."""
        self.allreduce_word_embedding_grads(args)
        self.allreduce_position_embedding_grads(args)

    def reduce_model_grads(self, args, timers):
        """All-reduce data-parallel grads (local DDP only), then embedding grads."""

        # All-reduce if needed.
        if args.DDP_impl == 'local':
            timers('backward-params-all-reduce').start()
            for model in self.models:
                model.allreduce_gradients()
            timers('backward-params-all-reduce').stop()

        # All-reduce embedding grads.
        timers('backward-embedding-all-reduce').start()
        self.allreduce_embedding_grads(args)
        timers('backward-embedding-all-reduce').stop()

277

278
class MixedPrecisionOptimizer(MegatronOptimizer):
    """Base class for optimizers that keep separate model and main params.

    ``step`` drives hooks that subclasses provide:
    ``_collect_main_grad_data_for_unscaling``,
    ``_copy_model_grads_to_main_grads``,
    ``_copy_main_params_to_model_params`` and (via ``reload_model_params``)
    ``_copy_model_params_to_main_params``.
    """

    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                 bf16, grad_scaler,
                 models):

        super().__init__(
            optimizer, clip_grad, log_num_zeros_in_grad,
            params_have_main_grad, use_contiguous_buffers_in_local_ddp,
            models)

        self.bf16 = bf16
        self.grad_scaler = grad_scaler

        # None grad scaler is only supported for bf16.
        if self.grad_scaler is None:
            assert self.bf16, 'fp16 expects a grad scaler.'

        # Tensor used to determine if a nan/inf has happened.
        # Any non-zero value indicates inf/nan.
        # It is only allocated when a grad scaler is present; a bf16 run
        # that does use a grad scaler still records nan/inf here.
        if self.grad_scaler is not None:
            self.found_inf = torch.cuda.FloatTensor([0.0])

        # Dummy tensor needed for apex multi-apply tensor.
        # For bfloat, we don't have multi-tensor apply and for now
        # we set it to none so the multi-tensor apply gets ignored.
        if bf16:
            self._dummy_overflow_buf = None
        else:
            self._dummy_overflow_buf = torch.cuda.IntTensor([0])

        # In case grad scaler is not passed, define the unity scale.
        if self.grad_scaler is None:
            self._scale_one = torch.cuda.FloatTensor([1.0])

    def get_loss_scale(self):
        """Return the grad scaler's scale, or a unit tensor without a scaler."""
        if self.grad_scaler is None:
            return self._scale_one
        return self.grad_scaler.scale

    def reload_model_params(self):
        """Re-sync main params from the (externally updated) model params."""
        self._copy_model_params_to_main_params()

    def _unscale_main_grads_and_check_for_nan(self):
        """Unscale main grads in place and report whether any is inf/nan.

        The inf/nan flag is max-reduced across the model-parallel group so
        all ranks agree on whether to skip the step.
        """

        # Collect main grads.
        main_grads = self._collect_main_grad_data_for_unscaling()

        # Reset found inf.
        self.found_inf.fill_(0.0)

        # Unscale and set found inf/nan
        torch._amp_foreach_non_finite_check_and_unscale_(
            main_grads, self.found_inf, self.grad_scaler.inv_scale)

        # Update across all model parallel instances.
        torch.distributed.all_reduce(self.found_inf,
                                     op=torch.distributed.ReduceOp.MAX,
                                     group=self.get_model_parallel_group())

        # Check for nan.
        found_inf_flag = (self.found_inf.item() > 0)

        return found_inf_flag

    @torch.no_grad()
    def step(self, args, timers):
        """Run one mixed-precision update.

        Returns:
            ``(False, None, None)`` if an inf/nan gradient was found and the
            update was skipped; otherwise
            ``(True, grad_norm, num_zeros_in_grad)``, where ``grad_norm`` is
            None when clipping is disabled and ``num_zeros_in_grad`` is None
            when zero logging is disabled.
        """

        # Copy gradients from model params to main params.
        timers('optimizer-copy-to-main-grad').start()
        self._copy_model_grads_to_main_grads()
        timers('optimizer-copy-to-main-grad').stop()

        # Do unscale, check for inf, and update grad scaler only for
        # the case that grad scaler is provided.
        if self.grad_scaler is not None:

            # Unscale and check for inf/nan.
            timers('optimizer-unscale-and-check-inf').start()
            found_inf_flag = self._unscale_main_grads_and_check_for_nan()
            timers('optimizer-unscale-and-check-inf').stop()

            # We are done with scaling gradients
            # so we can update the loss scale.
            self.grad_scaler.update(found_inf_flag)

            # If we found inf/nan, skip the update.
            if found_inf_flag:
                return False, None, None

        # Clip the main gradients.
        timers('optimizer-clip-main-grad').start()
        grad_norm = None
        if self.clip_grad > 0.0:
            grad_norm = self.clip_grad_norm(self.clip_grad)
        timers('optimizer-clip-main-grad').stop()

        # count the zeros in the grads
        num_zeros_in_grad = self.count_zeros() if \
                            self.log_num_zeros_in_grad else None

        # Step the optimizer.
        self.optimizer.step()

        # Update params from main params.
        timers('optimizer-copy-main-to-model-params').start()
        self._copy_main_params_to_model_params()
        timers('optimizer-copy-main-to-model-params').stop()

        # Successful update.
        return True, grad_norm, num_zeros_in_grad


455
class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
    """Float16 optimizer for fp16 and bf16 data types.

    Arguments:
        optimizer: base optimizer such as Adam or SGD
        clip_grad: clip gradients with this global L2 norm. Note
            that clipping is ignored if clip_grad == 0
        log_num_zeros_in_grad: return number of zeros in the gradients.
        params_have_main_grad: flag indicating if parameters have
            a `main_grad` field. If this is set, we are assuming
            that the model gradients are stored in the `main_grad`
            field instead of the typical `grad` field. This happens
            for the DDP cases where there is a contiguous buffer
            holding the gradients. For example for bfloat16, we want
            to do gradient accumulation and all-reduces in float32
            and as a result we store those gradients in the main_grad.
            Note that main grad is not necessarily in float32.
        use_contiguous_buffers_in_local_ddp: if true, `main_grad` memory
            is owned by a persistent buffer and is not released after
            being copied.
        bf16: if true, the model is running in bfloat16.
        grad_scaler: used for scaling gradients. Note that this can be
            None. This case happens when `bf16 = True` and we don't
            use any loss scale. Note that for `bf16 = True`, we can have
            a constant gradient scaler. Also for `bf16 = False`, we
            always require a grad scaler.
        models: the model modules being optimized (passed to the base
            class).
    """

    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                 bf16, grad_scaler, models):

        super().__init__(
            optimizer, clip_grad, log_num_zeros_in_grad,
            params_have_main_grad, use_contiguous_buffers_in_local_ddp,
            bf16, grad_scaler, models)

        # ======================
        # main parameter stuff
        # ======================

        # Three groups of parameters:
        #   float16_groups: original float16 parameters
        #   fp32_from_float16_groups: fp32 copy of float16 parameters
        #   fp32_from_fp32_groups: original fp32 parameters
        self.float16_groups = []
        self.fp32_from_float16_groups = []
        self.fp32_from_fp32_groups = []

        # For all the groups in the original optimizer:
        for param_group in self.optimizer.param_groups:
            float16_params_this_group = []
            fp32_params_this_group = []
            fp32_from_float16_params_this_group = []
            # For all the parameters in this group:
            for i, param in enumerate(param_group['params']):
                if not param.requires_grad:
                    continue

                # float16 params:
                if param.type() in ['torch.cuda.HalfTensor',
                                    'torch.cuda.BFloat16Tensor']:
                    float16_params_this_group.append(param)
                    # Create a copy
                    main_param = param.detach().clone().float()
                    # Copy tensor model parallel attributes.
                    mpu.copy_tensor_model_parallel_attributes(main_param,
                                                              param)
                    if hasattr(param, 'shared'):
                        main_param.shared = param.shared
                    # Replace the optimizer params with the new fp32 copy.
                    param_group['params'][i] = main_param

                    fp32_from_float16_params_this_group.append(main_param)
                    # Reset existing state dict key to the new main param.
                    if param in self.optimizer.state:
                        self.optimizer.state[main_param] \
                            = self.optimizer.state.pop(param)

                # fp32 params: the optimizer already owns them directly.
                elif param.type() == 'torch.cuda.FloatTensor':
                    fp32_params_this_group.append(param)

                else:
                    raise TypeError('Wrapped parameters must be one of '
                                    'torch.cuda.FloatTensor,  '
                                    'torch.cuda.HalfTensor, or '
                                    'torch.cuda.BFloat16Tensor. '
                                    'Received {}'.format(param.type()))

            self.float16_groups.append(float16_params_this_group)
            self.fp32_from_float16_groups.append(
                fp32_from_float16_params_this_group)
            self.fp32_from_fp32_groups.append(fp32_params_this_group)

        # Leverage state_dict() and load_state_dict() to
        # recast preexisting per-param state tensors
        self.optimizer.load_state_dict(self.optimizer.state_dict())

    def zero_grad(self, set_to_none=True):
        """We only need to zero the model related parameters, i.e.,
        float16_groups & fp32_from_fp32_groups. We additionally zero
        fp32_from_float16_groups as a memory optimization to reduce
        fragmentation; in the case of set_to_none==True, the space
        used by this field can be safely deallocated at this point."""
        for group in self.float16_groups:
            _zero_grad_group_helper(group, set_to_none)
        for group in self.fp32_from_float16_groups:
            _zero_grad_group_helper(group, set_to_none)
        for group in self.fp32_from_fp32_groups:
            _zero_grad_group_helper(group, set_to_none)

    def _collect_main_grad_data_for_unscaling(self):
        """Return the ``.grad.data`` of every main param that has a grad."""

        main_grads = []

        # fp32 params from float16 ones.
        for main_group in self.fp32_from_float16_groups:
            for main_param in main_group:
                if main_param.grad is not None:
                    main_grads.append(main_param.grad.data)

        # Append fp32 parameters.
        for main_group in self.fp32_from_fp32_groups:
            for main_param in main_group:
                if main_param.grad is not None:
                    main_grads.append(main_param.grad.data)

        return main_grads

    def _get_model_and_main_params_data_float16(self):
        """Return paired ``.data`` lists for float16 model params and their fp32 copies."""
        model_data = []
        main_data = []
        for model_group, main_group in zip(self.float16_groups,
                                           self.fp32_from_float16_groups):
            for model_param, main_param in zip(model_group, main_group):
                model_data.append(model_param.data)
                main_data.append(main_param.data)
        return model_data, main_data

    def _copy_model_grads_to_main_grads(self):
        """Move model gradients onto the params the base optimizer steps."""
        # This only needs to be done for the float16 group.
        for model_group, main_group in zip(self.float16_groups,
                                           self.fp32_from_float16_groups):
            for model_param, main_param in zip(model_group, main_group):
                # Prefer the accumulation buffer (main_grad). Fall back to
                # .grad when main_grad is absent or was already released
                # (it is set to None below when buffers are not contiguous).
                model_grad = None
                if self.params_have_main_grad:
                    model_grad = getattr(model_param, 'main_grad', None)
                if model_grad is None:
                    model_grad = model_param.grad
                if model_grad is not None:
                    main_param.grad = model_grad.float()

                # Safe to deallocate model's grad/main_grad after copying.
                # (If using contiguous buffers, main_grad's memory should
                # persist and therefore should not be deallocated.)
                model_param.grad = None
                if self.params_have_main_grad and \
                   not self.use_contiguous_buffers_in_local_ddp:
                    model_param.main_grad = None

        # For fp32 grads, we need to reset the grads to main grad.
        if self.params_have_main_grad:
            for model_group in self.fp32_from_fp32_groups:
                for model_param in model_group:
                    model_param.grad = model_param.main_grad

                    # Safe to de-reference model's main_grad after copying.
                    # (If using contiguous buffers, main_grad's memory should
                    # persist and therefore should not be deallocated.)
                    if not self.use_contiguous_buffers_in_local_ddp:
                        model_param.main_grad = None

    def _copy_main_params_to_model_params(self):
        """Write updated fp32 main params back into the float16 model params."""
        # Only needed for the float16 params.
        model_data, main_data = self._get_model_and_main_params_data_float16()
        _multi_tensor_copy_this_to_that(this=main_data, that=model_data,
                                        overflow_buf=self._dummy_overflow_buf)

    def _copy_model_params_to_main_params(self):
        """Refresh fp32 main params from the float16 model params."""
        # Only needed for the float16 params.
        model_data, main_data = self._get_model_and_main_params_data_float16()
        _multi_tensor_copy_this_to_that(this=model_data, that=main_data,
                                        overflow_buf=self._dummy_overflow_buf)

    def state_dict(self):
        """Return optimizer, grad-scaler (if any) and fp32 main-param state."""
        state_dict = {}
        state_dict['optimizer'] = self.optimizer.state_dict()
        if self.grad_scaler is not None:
            state_dict['grad_scaler'] = self.grad_scaler.state_dict()
        state_dict['fp32_from_fp16_params'] = self.fp32_from_float16_groups
        return state_dict

    def load_state_dict(self, state_dict):
        """Restore state saved by ``state_dict``; also accepts older key names."""
        # Optimizer.
        optimizer_key = 'optimizer'
        if optimizer_key not in state_dict:
            optimizer_key = 'optimizer_state_dict'
            print_rank_0('***WARNING*** loading optimizer from '
                         'an old checkpoint ...')
        self.optimizer.load_state_dict(state_dict[optimizer_key])

        # Grad scaler.
        if 'grad_scaler' not in state_dict:
            print_rank_0('***WARNING*** found an old checkpoint, will not '
                         'load grad scaler ...')
        else:
            if self.grad_scaler is not None:
                self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
            else:
                print_rank_0('***WARNING*** found the grad scaler in the '
                             'checkpoint but it is None in the class. '
                             'Skipping loading grad scaler ...')

        # Copy data for the main params.
        fp32_from_float16_params_key = 'fp32_from_fp16_params'
        if fp32_from_float16_params_key not in state_dict:
            fp32_from_float16_params_key = 'fp32_from_fp16'
        for current_group, saved_group in zip(
                self.fp32_from_float16_groups,
                state_dict[fp32_from_float16_params_key]):
            for current_param, saved_param in zip(current_group, saved_group):
                current_param.data.copy_(saved_param.data)


mohammad's avatar
mohammad committed
694
695
class FP32Optimizer(MegatronOptimizer):
    """Megatron optimizer wrapper for fp32 training.

    No loss scaling is performed: the reported loss scale is a constant 1.0
    tensor, and ``step`` never reports an overflow.
    """

    def __init__(self, optimizer, clip_grad,
                 log_num_zeros_in_grad,
                 params_have_main_grad,
                 use_contiguous_buffers_in_local_ddp,
                 models):

        super().__init__(
            optimizer, clip_grad, log_num_zeros_in_grad,
            params_have_main_grad, use_contiguous_buffers_in_local_ddp,
            models)

        # Constant unit scale handed out by get_loss_scale().
        self._scale = torch.cuda.FloatTensor([1.0])


    def zero_grad(self, set_to_none=True):
        """Zero (or, with set_to_none, drop) the grads of every param group.

        Copied from torch.optim.optimizer.
        """
        for param_group in self.optimizer.param_groups:
            group_params = param_group['params']
            _zero_grad_group_helper(group_params, set_to_none)


    def get_loss_scale(self):
        """Return the constant 1.0 scale; fp32 training applies no scaling."""
        return self._scale


    @torch.no_grad()
    def step(self, args, timers):
        """Optionally clip gradients, then step the wrapped optimizer.

        ``args`` and ``timers`` are not used by this implementation.

        Returns:
            ``(True, grad_norm, num_zeros_in_grad)``. The first element is
            always True because there is no overflow to detect in fp32.
            ``grad_norm`` is None unless ``clip_grad > 0.0``;
            ``num_zeros_in_grad`` is None unless zero counting is enabled.
        """
        if self.params_have_main_grad:
            for param_group in self.optimizer.param_groups:
                for param in param_group['params']:
                    # Hand the accumulated main_grad to the base optimizer.
                    param.grad = param.main_grad
                    # With contiguous buffers the main_grad memory must
                    # persist, so only drop the reference when they are off.
                    if self.use_contiguous_buffers_in_local_ddp:
                        continue
                    param.main_grad = None

        grad_norm = (self.clip_grad_norm(self.clip_grad)
                     if self.clip_grad > 0.0 else None)

        num_zeros_in_grad = None
        if self.log_num_zeros_in_grad:
            num_zeros_in_grad = self.count_zeros()

        self.optimizer.step()

        return True, grad_norm, num_zeros_in_grad


    def reload_model_params(self):
        """No-op for this optimizer."""


    def state_dict(self):
        """Return the wrapped optimizer's state dict unchanged."""
        return self.optimizer.state_dict()


    def load_state_dict(self, state_dict):
        """Load ``state_dict`` directly into the wrapped optimizer."""
        self.optimizer.load_state_dict(state_dict)