optimizer.py 29.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Megatron optimizer."""
mohammad's avatar
mohammad committed
17
18
19
20
21

from abc import ABC
from abc import abstractmethod
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
22
23
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
mohammad's avatar
mohammad committed
24

mohammad's avatar
mohammad committed
25
26
from megatron import get_timers
from megatron import mpu
mohammad's avatar
mohammad committed
27
from megatron import print_rank_0
28
29
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
30
31
from megatron.model.module import param_is_not_shared
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
32
from megatron.utils import unwrap_model
33

34
from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32
35

Lawrence McAfee's avatar
Lawrence McAfee committed
36
37
38
39
# >>>
from lutil import pax, tp, print_seq
# <<<

Lawrence McAfee's avatar
Lawrence McAfee committed
40

mohammad's avatar
mohammad committed
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
def _zero_grad_group_helper(group, set_to_none):
    """Zero out the gradient for a group of parameters.
    Note: copied from torch.optim.optimizer."""
    for param in group:
        if param.grad is not None:
            if set_to_none:
                param.grad = None
            else:
                if param.grad.grad_fn is not None:
                    param.grad.detach_()
                else:
                    param.grad.requires_grad_(False)
                param.grad.zero_()


56
def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
57
58
59
60
    """Use multi-tensor-applier to copy values from one list to another.
    We don't have a blfoat16 implementation so for now if the overflow_buf
    is not provided, we default back to simple loop copy to be compatible
    with bfloat16."""
61
62
    if overflow_buf:
        overflow_buf.fill_(0)
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
63
64
65
66
67
        # Scaling with factor `1.0` is equivalent to copy.
        multi_tensor_applier(amp_C.multi_tensor_scale,
                             overflow_buf,
                             [this, that],
                             1.0)
68
    else:
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
69
70
71
        for this_, that_ in zip(this, that):
            that_.copy_(this_)

72

mohammad's avatar
mohammad committed
73
74
75

class MegatronOptimizer(ABC):
    """Abstract base for Megatron optimizers that wrap a torch optimizer.

    Stores the wrapped ``optimizer`` together with gradient clipping/logging
    settings and the list of model chunks (``models``), and implements the
    shared logic for gradient-norm clipping, zero counting and gradient
    all-reduces. Subclasses implement ``zero_grad``, ``get_loss_scale``,
    ``reload_model_params``, ``state_dict``, ``load_state_dict`` and ``step``.
    """

    def __init__(self, optimizer, clip_grad,
                 log_num_zeros_in_grad,
                 params_have_main_grad,
                 use_contiguous_buffers_in_local_ddp,
                 models):
        """Input optimizer is the base optimizer for example Adam.

        Args:
            optimizer: wrapped base optimizer; must be truthy.
            clip_grad: gradient clipping threshold (used by subclasses'
                ``step``).
            log_num_zeros_in_grad: whether ``step`` should count zeros in
                the gradients.
            params_have_main_grad: whether params carry a ``main_grad``
                attribute.
            use_contiguous_buffers_in_local_ddp: whether ``main_grad`` lives
                in contiguous buffers; requires ``params_have_main_grad``.
            models: list of model chunks, kept for gradient all-reduces.

        Raises:
            AssertionError: if ``optimizer`` is falsy, or contiguous buffers
                are requested without ``params_have_main_grad``.
        """
        self.optimizer = optimizer
        assert self.optimizer, 'no optimizer is provided.'

        # Set gradient clipping and logging params.
        self.clip_grad = clip_grad
        self.log_num_zeros_in_grad = log_num_zeros_in_grad
        self.params_have_main_grad = params_have_main_grad
        self.use_contiguous_buffers_in_local_ddp = use_contiguous_buffers_in_local_ddp

        # 'models' are retained for access to the contiguous grad buffers.
        # (see distributed optimizer)
        self.models = models

        if self.use_contiguous_buffers_in_local_ddp:
            assert self.params_have_main_grad, \
                "use of contiguous buffer requires that params have main grad"

    def get_parameters(self):
        """Return a flat list of all params across the optimizer's groups."""
        return [param
                for param_group in self.optimizer.param_groups
                for param in param_group['params']]

    def get_main_grads_for_grad_norm(self):
        """Return the ``.grad`` tensors that should enter the grad norm.

        A parameter's grad is kept only if:
          - the grad is not None,
          - the parameter is not shared,
          - the parameter is not a tensor-model-parallel duplicate.
        """
        grads_for_norm = []
        for param in self.get_parameters():
            grad = param.grad
            grad_not_none = grad is not None
            is_not_shared = param_is_not_shared(param)
            is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
            if grad_not_none and is_not_shared and is_not_tp_duplicate:
                grads_for_norm.append(grad)
        return grads_for_norm

    def get_model_parallel_group(self):
        '''Default returned here, but the distributed optimizer overrides this.'''
        return mpu.get_model_parallel_group()

    def clip_grad_norm(self, clip_grad):
        """Clip gradients using ``clip_grad_norm_fp32`` with max norm
        ``clip_grad`` and return its result (stored by ``step`` as the
        grad norm)."""
        params = self.get_parameters()
        grads_for_norm = self.get_main_grads_for_grad_norm()
        return clip_grad_norm_fp32(
            params, grads_for_norm, clip_grad,
            model_parallel_group=self.get_model_parallel_group())

    def count_zeros(self):
        """Return ``count_zeros_fp32`` over all params (presumably the
        number of zero gradient entries) across the model-parallel group."""
        params = self.get_parameters()
        return count_zeros_fp32(params,
                                model_parallel_group=self.get_model_parallel_group())

    @abstractmethod
    def zero_grad(self, set_to_none=True):
        pass

    @abstractmethod
    def get_loss_scale(self):
        """The output should be a cuda tensor of size 1."""
        pass

    def scale_loss(self, loss):
        """Simple scaling."""
        return self.get_loss_scale() * loss

    @abstractmethod
    def reload_model_params(self):
        """Refreshes any internal state from the current model parameters.
        Call whenever the parameters are changed outside of the optimizer.
        For example, when we load a model from a checkpoint  without loading
        the optimizer, the model parameters are updated but for fp16 optimizer
        with main parameters, the main parameters need to also be updated."""
        pass

    @abstractmethod
    def state_dict(self):
        pass

    @abstractmethod
    def load_state_dict(self, state_dict):
        pass

    # Promote state so it can be retrieved or set via
    # "optimizer_instance.state"
    def _get_state(self):
        return self.optimizer.state

    def _set_state(self, value):
        self.optimizer.state = value

    state = property(_get_state, _set_state)

    # Promote param_groups so it can be retrieved or set via
    # "optimizer_instance.param_groups"
    # (for example, to adjust the learning rate)
    def _get_param_groups(self):
        return self.optimizer.param_groups

    def _set_param_groups(self, value):
        self.optimizer.param_groups = value

    param_groups = property(_get_param_groups, _set_param_groups)

    @abstractmethod
    def step(self, args, timers):
        pass

    def gather_model_params(self, args, timers):
        '''For the case of a non-distributed-optimizer, there is nothing to
        do here.'''
        pass

    def allreduce_word_embedding_grads(self, args):
        '''
        All-reduce word embedding grads.

        Reduce grads across first and last stages to ensure that word_embeddings
        parameters stay in sync. This should only run for models that support
        pipelined model parallelism (BERT and GPT-2).
        '''

        if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
                mpu.get_pipeline_model_parallel_world_size() > 1:
            if mpu.is_pipeline_first_stage(ignore_virtual=True):
                unwrapped_model = self.models[0]
            elif mpu.is_pipeline_last_stage(ignore_virtual=True):
                unwrapped_model = self.models[-1]
            else:  # We do not support the interleaved schedule for T5 yet.
                unwrapped_model = self.models[0]
            unwrapped_model = unwrap_model(
                unwrapped_model, (torchDDP, LocalDDP, Float16Module))

            if unwrapped_model.share_word_embeddings:
                word_embeddings_weight = unwrapped_model.word_embeddings_weight()
                # Local DDP accumulates into `main_grad`; torch DDP uses `grad`.
                if args.DDP_impl == 'local':
                    grad = word_embeddings_weight.main_grad
                else:
                    grad = word_embeddings_weight.grad
                torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())

    def allreduce_position_embedding_grads(self, args):
        '''
        All-reduce position_embeddings grad across first (encoder) and
        split (decoder) stages to ensure that position embeddings parameters
        stay in sync. This should only run for T5 models with pipeline
        parallelism.
        '''
        if mpu.is_rank_in_position_embedding_group() and \
                mpu.get_pipeline_model_parallel_world_size() > 1 and \
                args.pipeline_model_parallel_split_rank is not None:
            unwrapped_model = self.models[0]
            unwrapped_model = unwrap_model(
                unwrapped_model, (torchDDP, LocalDDP, Float16Module))
            assert args.DDP_impl == 'local', \
                'T5 model is only supported with local DDP mode'
            grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
            torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())

    def allreduce_embedding_grads(self, args):
        """All-reduce word and position embedding grads (each only when
        applicable to the current rank/model)."""
        self.allreduce_word_embedding_grads(args)
        self.allreduce_position_embedding_grads(args)

    def reduce_model_grads(self, args, timers):
        """All-reduce model grads (local DDP only), then embedding grads."""

        # All-reduce if needed.
        if args.DDP_impl == 'local':
            timers('backward-params-all-reduce').start()
            for model in self.models:
                model.allreduce_gradients()
            timers('backward-params-all-reduce').stop()

        # All-reduce embedding grads.
        timers('backward-embedding-all-reduce').start()
        self.allreduce_embedding_grads(args)
        timers('backward-embedding-all-reduce').stop()

297

298
class MixedPrecisionOptimizer(MegatronOptimizer):
    """Base for fp16/bf16 optimizers that step separate "main" params/grads.

    Adds optional loss scaling through ``grad_scaler`` and implements the
    shared ``step`` pipeline. Subclasses must provide the hooks
    ``_copy_model_grads_to_main_grads``,
    ``_collect_main_grad_data_for_unscaling``,
    ``_copy_main_params_to_model_params`` and
    ``_copy_model_params_to_main_params``.
    """

    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                 fp16, bf16, grad_scaler,
                 models):

        super().__init__(
            optimizer, clip_grad, log_num_zeros_in_grad,
            params_have_main_grad, use_contiguous_buffers_in_local_ddp,
            models)

        self.fp16 = fp16
        self.bf16 = bf16
        self.grad_scaler = grad_scaler

        # None grad scaler is only supported for bf16.
        if self.grad_scaler is None:
            assert not self.fp16, 'fp16 expects a grad scaler.'

        # Tensor used to determine if a nan/inf has happened.
        # Any non-zero value indicates inf/nan.
        # It is only allocated when a grad scaler is present, because that
        # is the only case in which step() unscales and checks for inf/nan
        # (this includes bfloat16 with a grad scaler).
        if self.grad_scaler:
            self.found_inf = torch.cuda.FloatTensor([0.0])

        # Dummy tensor needed for apex multi-apply tensor.
        # For bfloat, we don't have multi-tensor apply and for now
        # we set it to none so the multi-tensor apply gets ignored.
        if bf16:
            self._dummy_overflow_buf = None
        else:
            self._dummy_overflow_buf = torch.cuda.IntTensor([0])

        # In case grad scaler is not passed, define the unity scale.
        if self.grad_scaler is None:
            self._scale_one = torch.cuda.FloatTensor([1.0])

    def get_loss_scale(self):
        """Return the grad scaler's current scale, or a unit tensor when no
        grad scaler is configured."""
        if self.grad_scaler is None:
            return self._scale_one
        return self.grad_scaler.scale

    def reload_model_params(self):
        """Re-sync the main params from the current model params."""
        self._copy_model_params_to_main_params()

    def _unscale_main_grads_and_check_for_nan(self):
        """Unscale main grads in place and report whether any were inf/nan.

        The inf/nan flag is max-reduced across the model-parallel group so
        that all ranks agree on whether to skip the step.
        """

        # Collect main grads.
        main_grads = self._collect_main_grad_data_for_unscaling()

        # Reset found inf.
        self.found_inf.fill_(0.0)

        # Unscale and set found inf/nan
        torch._amp_foreach_non_finite_check_and_unscale_(
            main_grads, self.found_inf, self.grad_scaler.inv_scale)

        # Update across all model parallel instances.
        torch.distributed.all_reduce(self.found_inf,
                                     op=torch.distributed.ReduceOp.MAX,
                                     group=self.get_model_parallel_group())

        # Check for nan.
        found_inf_flag = (self.found_inf.item() > 0)

        return found_inf_flag

    @torch.no_grad()
    def step(self, args, timers):
        """Run one mixed-precision optimizer step.

        Returns:
            ``(True, grad_norm, num_zeros_in_grad)`` on a successful update,
            or ``(False, None, None)`` when inf/nan grads caused the update
            to be skipped. ``grad_norm`` is None when clipping is disabled and
            ``num_zeros_in_grad`` is None unless zero logging is enabled.
        """

        # Copy gradients from model params to main params.
        timers('optimizer-copy-to-main-grad').start()
        self._copy_model_grads_to_main_grads()
        timers('optimizer-copy-to-main-grad').stop()

        # Do unscale, check for inf, and update grad scaler only for
        # the case that grad scaler is provided.
        if self.grad_scaler:

            # Unscale and check for inf/nan.
            timers('optimizer-unscale-and-check-inf').start()
            found_inf_flag = self._unscale_main_grads_and_check_for_nan()
            timers('optimizer-unscale-and-check-inf').stop()

            # We are done with scaling gradients
            # so we can update the loss scale.
            self.grad_scaler.update(found_inf_flag)

            # If we found inf/nan, skip the update.
            if found_inf_flag:
                return False, None, None

        # Clip the main gradients.
        timers('optimizer-clip-main-grad').start()
        grad_norm = None
        if self.clip_grad > 0.0:
            grad_norm = self.clip_grad_norm(self.clip_grad)
        timers('optimizer-clip-main-grad').stop()

        # count the zeros in the grads
        timers('optimizer-count-zeros').start()
        num_zeros_in_grad = self.count_zeros() if \
                            self.log_num_zeros_in_grad else None
        timers('optimizer-count-zeros').stop()

        # Step the optimizer.
        timers('optimizer-inner-step').start()
        self.optimizer.step()
        timers('optimizer-inner-step').stop()

        # Update params from main params.
        timers('optimizer-copy-main-to-model-params').start()
        self._copy_main_params_to_model_params()
        timers('optimizer-copy-main-to-model-params').stop()

        # Successful update.
        return True, grad_norm, num_zeros_in_grad


480
class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
    """Float16 optimizer for fp16 and bf16 data types.

    Arguments:
        optimizer: base optimizer such as Adam or SGD
        clip_grad: clip gradients with this global L2 norm. Note
            that clipping is ignored if clip_grad == 0
        log_num_zeros_in_grad: return number of zeros in the gradients.
        params_have_main_grad: flag indicating if parameters have
            a `main_grad` field. If this is set, we are assuming
            that the model gradients are stored in the `main_grad`
            field instead of the typical `grad` field. This happens
            for the DDP cases where there is a continuous buffer
            holding the gradients. For example for bfloat16, we want
            to do gradient accumulation and all-reduces in float32
            and as a result we store those gradients in the main_grad.
            Note that main grad is not necessarily in float32.
        use_contiguous_buffers_in_local_ddp: if true, `main_grad` lives in
            a persistent contiguous buffer, so it is not released after its
            contents are copied into the main grads.
        fp16: if true, the model is running in float16 (requires a grad
            scaler).
        bf16: if true, the model is running in bfloat16.
        grad_scaler: used for scaling gradients. Note that this can be
            None. This case happens when `bf16 = True` and we don't
            use any loss scale. Note that for `bf16 = True`, we can have
            a constant gradient scaler. Also for `bf16 = False`, we
            always require a grad scaler.
        models: list of model chunks, forwarded to the base class (used for
            gradient all-reduces).
    """

    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                 fp16, bf16, grad_scaler, models):

        super().__init__(
            optimizer, clip_grad, log_num_zeros_in_grad,
            params_have_main_grad, use_contiguous_buffers_in_local_ddp,
            fp16, bf16, grad_scaler, models)

        # ======================
        # main parameter stuff
        # ======================

        # Three groups of parameters:
        #   float16_groups: original float16 parameters
        #   fp32_from_float16_groups: fp32 copy of float16 parameters
        #   fp32_from_fp32_groups: original fp32 parameters
        self.float16_groups = []
        self.fp32_from_float16_groups = []
        self.fp32_from_fp32_groups = []

        # For all the groups in the original optimizer:
        for param_group in self.optimizer.param_groups:
            float16_params_this_group = []
            fp32_params_this_group = []
            fp32_from_float16_params_this_group = []
            # For all the parameters in this group:
            for i, param in enumerate(param_group['params']):
                if param.requires_grad:

                    # float16 params:
                    if param.type() in ['torch.cuda.HalfTensor',
                                        'torch.cuda.BFloat16Tensor']:
                        float16_params_this_group.append(param)
                        # Create a copy
                        main_param = param.detach().clone().float()
                        # Copy tensor model parallel attributes.
                        mpu.copy_tensor_model_parallel_attributes(main_param,
                                                                  param)
                        if hasattr(param, 'shared'):
                            main_param.shared = param.shared
                        # Replace the optimizer params with the new fp32 copy.
                        param_group['params'][i] = main_param

                        fp32_from_float16_params_this_group.append(main_param)
                        # Reset existing state dict key to the new main param.
                        if param in self.optimizer.state:
                            self.optimizer.state[main_param] \
                                = self.optimizer.state.pop(param)

                    # fp32 params: the optimizer keeps stepping them directly.
                    elif param.type() == 'torch.cuda.FloatTensor':
                        fp32_params_this_group.append(param)

                    else:
                        raise TypeError('Wrapped parameters must be one of '
                                        'torch.cuda.FloatTensor,  '
                                        'torch.cuda.HalfTensor, or '
                                        'torch.cuda.BFloat16Tensor. '
                                        'Received {}'.format(param.type()))

            self.float16_groups.append(float16_params_this_group)
            self.fp32_from_float16_groups.append(
                fp32_from_float16_params_this_group)
            self.fp32_from_fp32_groups.append(fp32_params_this_group)

        # Leverage state_dict() and load_state_dict() to
        # recast preexisting per-param state tensors
        self.optimizer.load_state_dict(self.optimizer.state_dict())

    def zero_grad(self, set_to_none=True):
        """We only need to zero the model related parameters, i.e.,
        float16_groups & fp32_from_fp32_groups. We additionally zero
        fp32_from_float16_groups as a memory optimization to reduce
        fragmentation; in the case of set_to_none==True, the space
        used by this field can be safely deallocated at this point."""
        for group in self.float16_groups:
            _zero_grad_group_helper(group, set_to_none)
        for group in self.fp32_from_float16_groups:
            _zero_grad_group_helper(group, set_to_none)
        for group in self.fp32_from_fp32_groups:
            _zero_grad_group_helper(group, set_to_none)

    def _collect_main_grad_data_for_unscaling(self):
        """Return the ``.grad.data`` of every main param that has a grad:
        first the fp32 copies of float16 params, then the original fp32
        params."""

        main_grads = []

        # fp32 params from float16 ones.
        for main_group in self.fp32_from_float16_groups:
            for main_param in main_group:
                if main_param.grad is not None:
                    main_grads.append(main_param.grad.data)

        # Append fp32 parameters.
        for main_group in self.fp32_from_fp32_groups:
            for main_param in main_group:
                if main_param.grad is not None:
                    main_grads.append(main_param.grad.data)

        return main_grads

    def _get_model_and_main_params_data_float16(self):
        """Return parallel lists of ``.data`` for the float16 model params and
        their fp32 main copies."""
        model_data = []
        main_data = []
        for model_group, main_group in zip(self.float16_groups,
                                           self.fp32_from_float16_groups):
            for model_param, main_param in zip(model_group, main_group):
                model_data.append(model_param.data)
                main_data.append(main_param.data)
        return model_data, main_data

    def _copy_model_grads_to_main_grads(self):
        """Move model grads into the main params' ``.grad`` fields."""
        # This only needs to be done for the float16 group.
        for model_group, main_group in zip(self.float16_groups,
                                           self.fp32_from_float16_groups):
            for model_param, main_param in zip(model_group, main_group):
                if self.params_have_main_grad and hasattr(model_param, 'main_grad'):
                    main_param.grad = model_param.main_grad.float()
                else:
                    if model_param.grad is not None:
                        main_param.grad = model_param.grad.float()

                # Safe to deallocate model's grad/main_grad after copying.
                # (If using contiguous buffers, main_grad's memory should
                # persist and therefore should not be deallocated.)
                model_param.grad = None
                if self.params_have_main_grad and \
                   not self.use_contiguous_buffers_in_local_ddp:
                    model_param.main_grad = None

        # For fp32 grads, we need to reset the grads to main grad.
        if self.params_have_main_grad:
            for model_group in self.fp32_from_fp32_groups:
                for model_param in model_group:
                    model_param.grad = model_param.main_grad

                    # Safe to de-reference model's main_grad after copying.
                    # (If using contiguous buffers, main_grad's memory should
                    # persist and therefore should not be deallocated.)
                    if not self.use_contiguous_buffers_in_local_ddp:
                        model_param.main_grad = None

    def _copy_main_params_to_model_params(self):
        """Write the fp32 main params back into the float16 model params."""
        # Only needed for the float16 params.
        model_data, main_data = self._get_model_and_main_params_data_float16()
        _multi_tensor_copy_this_to_that(this=main_data, that=model_data,
                                        overflow_buf=self._dummy_overflow_buf)

    def _copy_model_params_to_main_params(self):
        """Refresh the fp32 main params from the float16 model params."""
        # Only needed for the float16 params.
        model_data, main_data = self._get_model_and_main_params_data_float16()
        _multi_tensor_copy_this_to_that(this=model_data, that=main_data,
                                        overflow_buf=self._dummy_overflow_buf)

    def state_dict(self):
        """Return the wrapped optimizer state, the grad scaler state (if any)
        and the fp32 main copies of the float16 params."""
        state_dict = {}
        state_dict['optimizer'] = self.optimizer.state_dict()
        if self.grad_scaler:
            state_dict['grad_scaler'] = self.grad_scaler.state_dict()
        state_dict['fp32_from_fp16_params'] = self.fp32_from_float16_groups
        return state_dict

    def load_state_dict(self, state_dict):
        """Restore state saved by ``state_dict``, accepting the older
        'optimizer_state_dict' / 'fp32_from_fp16' key names."""
        # Optimizer.
        optimizer_key = 'optimizer'
        if optimizer_key not in state_dict:
            optimizer_key = 'optimizer_state_dict'
            print_rank_0('***WARNING*** loading optimizer from '
                         'an old checkpoint ...')
        self.optimizer.load_state_dict(state_dict[optimizer_key])

        # Grad scaler.
        if 'grad_scaler' not in state_dict:
            print_rank_0('***WARNING*** found an old checkpoint, will not '
                         'load grad scaler ...')
        else:
            if self.grad_scaler:
                self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
            else:
                print_rank_0('***WARNING*** found the grad scaler in the '
                             'checkpoint but it is None in the class. '
                             'Skipping loading grad scaler ...')

        # Copy data for the main params.
        fp32_from_float16_params_key = 'fp32_from_fp16_params'
        if fp32_from_float16_params_key not in state_dict:
            fp32_from_float16_params_key = 'fp32_from_fp16'
        for current_group, saved_group in zip(
                self.fp32_from_float16_groups,
                state_dict[fp32_from_float16_params_key]):
            for current_param, saved_param in zip(current_group, saved_group):
                current_param.data.copy_(saved_param.data)


mohammad's avatar
mohammad committed
708
709
class FP32Optimizer(MegatronOptimizer):
    """Optimizer wrapper used when no loss scaling is needed.

    ``get_loss_scale`` always returns a constant 1.0 tensor, and ``step``
    always reports success since there is no overflow to detect.
    """

    def __init__(self, optimizer, clip_grad,
                 log_num_zeros_in_grad,
                 params_have_main_grad,
                 use_contiguous_buffers_in_local_ddp,
                 models):

        super(FP32Optimizer, self).__init__(
            optimizer, clip_grad, log_num_zeros_in_grad,
            params_have_main_grad, use_contiguous_buffers_in_local_ddp,
            models)

        # Constant unit loss scale on the current CUDA device. Uses the
        # factory function rather than the legacy ``torch.cuda.FloatTensor``
        # constructor; dtype (float32), value and device are unchanged.
        self._scale = torch.tensor([1.0], dtype=torch.float, device='cuda')


    def zero_grad(self, set_to_none=True):
        """Zero (or drop, if ``set_to_none``) grads of every param group.

        Mirrors torch.optim.Optimizer.zero_grad via _zero_grad_group_helper.
        """
        for group in self.optimizer.param_groups:
            _zero_grad_group_helper(group['params'], set_to_none)


    def get_loss_scale(self):
        """FP32 optimizer does not do any scaling."""
        return self._scale


    @torch.no_grad()
    def step(self, args, timers):
        """Clip gradients (if needed) and step the base optimizer.

        Args:
            args: unused here; kept for interface compatibility.
            timers: callable returning named timers with start()/stop().

        Returns:
            (True, grad_norm, num_zeros_in_grad). The first element is
            always True because there is no overflow for this optimizer;
            grad_norm is None unless clipping is enabled, and
            num_zeros_in_grad is None unless zero-counting is enabled.
        """

        # Copy main_grads to grads.
        timers('optimizer-copy-to-main-grad').start()
        if self.params_have_main_grad:
            for param_group in self.optimizer.param_groups:
                for param in param_group['params']:
                    param.grad = param.main_grad

                    # Safe to de-reference model's main_grad after copying.
                    # (If using contiguous buffers, main_grad's memory should
                    # persist and therefore should not be deallocated.)
                    if not self.use_contiguous_buffers_in_local_ddp:
                        param.main_grad = None
        timers('optimizer-copy-to-main-grad').stop()

        # Clip gradients. clip_grad_norm is presumably provided by the
        # MegatronOptimizer base class (not visible here).
        timers('optimizer-clip-main-grad').start()
        grad_norm = None
        if self.clip_grad > 0.0:
            grad_norm = self.clip_grad_norm(self.clip_grad)
        timers('optimizer-clip-main-grad').stop()

        # Count the zeros in the grads (only when logging is requested).
        timers('optimizer-count-zeros').start()
        num_zeros_in_grad = self.count_zeros() if \
                            self.log_num_zeros_in_grad else None
        timers('optimizer-count-zeros').stop()

        # Update parameters.
        timers('optimizer-inner-step').start()
        self.optimizer.step()
        timers('optimizer-inner-step').stop()

        # No overflow for FP32 optimizer.
        return True, grad_norm, num_zeros_in_grad


    def reload_model_params(self):
        """No-op for this optimizer."""
        pass


    def state_dict(self):
        """Return the wrapped optimizer's state dict unchanged."""
        return self.optimizer.state_dict()


    def load_state_dict(self, state_dict):
        """Load ``state_dict`` directly into the wrapped optimizer."""
        self.optimizer.load_state_dict(state_dict)