optimizer.py 30.3 KB
Newer Older
Jared Casper's avatar
Jared Casper committed
1
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2
3

"""Megatron optimizer."""
mohammad's avatar
mohammad committed
4
5
6
7
8

from abc import ABC
from abc import abstractmethod
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
9
10
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
Lawrence McAfee's avatar
Lawrence McAfee committed
11
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
mohammad's avatar
mohammad committed
12

mohammad's avatar
mohammad committed
13
from megatron import get_timers
mohammad's avatar
mohammad committed
14
from megatron import print_rank_0
15
from megatron.core import mpu, tensor_parallel
16
17
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
18
from megatron.model.module import param_is_not_shared
19
from megatron.utils import unwrap_model
20

21
from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32
22

Lawrence McAfee's avatar
Lawrence McAfee committed
23

mohammad's avatar
mohammad committed
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
def _zero_grad_group_helper(group, set_to_none):
    """Zero out the gradient for a group of parameters.
    Note: copied from torch.optim.optimizer."""
    for param in group:
        if param.grad is not None:
            if set_to_none:
                param.grad = None
            else:
                if param.grad.grad_fn is not None:
                    param.grad.detach_()
                else:
                    param.grad.requires_grad_(False)
                param.grad.zero_()


39
def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
40
41
42
43
    """Use multi-tensor-applier to copy values from one list to another.
    We don't have a blfoat16 implementation so for now if the overflow_buf
    is not provided, we default back to simple loop copy to be compatible
    with bfloat16."""
44
45
    if overflow_buf:
        overflow_buf.fill_(0)
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
46
47
48
49
50
        # Scaling with factor `1.0` is equivalent to copy.
        multi_tensor_applier(amp_C.multi_tensor_scale,
                             overflow_buf,
                             [this, that],
                             1.0)
51
    else:
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
52
53
54
        for this_, that_ in zip(this, that):
            that_.copy_(this_)

55

mohammad's avatar
mohammad committed
56
57
58

class MegatronOptimizer(ABC):
    """Abstract base class for Megatron optimizer wrappers.

    Wraps a base ``torch.optim`` optimizer and provides the pieces shared by
    all Megatron optimizers:
    - gradient-norm clipping and zero counting over the model-parallel group;
    - all-reduces for layernorm (sequence parallelism) and embedding grads;
    - pass-through access to the wrapped optimizer's ``state`` and
      ``param_groups``.

    Subclasses implement loss scaling, ``step`` and state (de)serialization.
    """

    def __init__(self, optimizer, clip_grad,
                 log_num_zeros_in_grad,
                 params_have_main_grad,
                 use_contiguous_buffers_in_local_ddp,
                 models):
        """Input optimizer is the base optimizer, for example Adam.

        Args:
            optimizer: the wrapped base optimizer; must be truthy.
            clip_grad: gradient-clipping threshold, stored for subclasses.
            log_num_zeros_in_grad: whether zeros in the gradients should be
                counted, stored for subclasses.
            params_have_main_grad: whether model params carry their gradient
                in a ``main_grad`` attribute.
            use_contiguous_buffers_in_local_ddp: whether local DDP keeps the
                gradients in a contiguous buffer. Requires
                ``params_have_main_grad``.
            models: list of model chunks, used by the all-reduce helpers and
                for access to the contiguous grad buffers.
        """
        self.optimizer = optimizer
        assert self.optimizer, 'no optimizer is provided.'

        # Set gradient clipping and logging params.
        self.clip_grad = clip_grad
        self.log_num_zeros_in_grad = log_num_zeros_in_grad
        self.params_have_main_grad = params_have_main_grad
        self.use_contiguous_buffers_in_local_ddp = use_contiguous_buffers_in_local_ddp

        # 'models' are retained for access to the contiguous grad buffers.
        # (see distributed optimizer)
        self.models = models

        if self.use_contiguous_buffers_in_local_ddp:
            assert self.params_have_main_grad, \
                "use of contiguous buffer requires that params have main grad"

    def get_parameters(self):
        """Return a flat list of all params in the wrapped optimizer's groups."""
        return [param
                for param_group in self.optimizer.param_groups
                for param in param_group['params']]

    def get_main_grads_for_grad_norm(self):
        """Collect the gradients that should contribute to the grad norm.

        A gradient is kept only if:
        - it is not None;
        - its parameter is not shared;
        - its parameter is not a replica due to tensor model parallelism.
        """
        grads_for_norm = []
        for param in self.get_parameters():
            grad = param.grad
            if grad is None:
                continue
            if not param_is_not_shared(param):
                continue
            if not tensor_parallel.param_is_not_tensor_parallel_duplicate(param):
                continue
            grads_for_norm.append(grad)
        return grads_for_norm

    def get_model_parallel_group(self):
        """Default returned here, but the distributed optimizer overrides this."""
        return mpu.get_model_parallel_group()

    def clip_grad_norm(self, clip_grad):
        """Clip the optimizer's params' grads to ``clip_grad``.

        Returns whatever ``clip_grad_norm_fp32`` returns; ``step``
        implementations report it as the gradient norm.
        """
        params = self.get_parameters()
        grads_for_norm = self.get_main_grads_for_grad_norm()
        return clip_grad_norm_fp32(
            params, grads_for_norm, clip_grad,
            model_parallel_group=self.get_model_parallel_group())

    def count_zeros(self):
        """Count zeros in the params' grads via ``count_zeros_fp32``."""
        params = self.get_parameters()
        return count_zeros_fp32(params,
                                model_parallel_group=self.get_model_parallel_group())

    @abstractmethod
    def zero_grad(self, set_to_none=True):
        pass

    @abstractmethod
    def get_loss_scale(self):
        """The output should be a cuda tensor of size 1."""
        pass

    def scale_loss(self, loss):
        """Simple scaling: multiply ``loss`` by the current loss scale."""
        return self.get_loss_scale() * loss

    @abstractmethod
    def reload_model_params(self):
        """Refreshes any internal state from the current model parameters.
        Call whenever the parameters are changed outside of the optimizer.
        For example, when we load a model from a checkpoint without loading
        the optimizer, the model parameters are updated but for fp16 optimizer
        with main parameters, the main parameters need to also be updated."""
        pass

    @abstractmethod
    def state_dict(self):
        pass

    @abstractmethod
    def load_state_dict(self, state_dict):
        pass

    # Promote state so it can be retrieved or set via
    # "optimizer_instance.state"
    def _get_state(self):
        return self.optimizer.state

    def _set_state(self, value):
        self.optimizer.state = value

    state = property(_get_state, _set_state)

    # Promote param_groups so it can be retrieved or set via
    # "optimizer_instance.param_groups"
    # (for example, to adjust the learning rate)
    def _get_param_groups(self):
        return self.optimizer.param_groups

    def _set_param_groups(self, value):
        self.optimizer.param_groups = value

    param_groups = property(_get_param_groups, _set_param_groups)

    @abstractmethod
    def step(self, args, timers):
        pass

    def gather_model_params(self, args, timers):
        """
        For the case of a non-distributed-optimizer, there is nothing to
        do here.
        """
        pass

    def allreduce_word_embedding_grads(self, args):
        """
        All-reduce word embedding grads.

        Reduce grads across first and last stages to ensure that word_embeddings
        parameters stay in sync. This should only run for models that support
        pipelined model parallelism (BERT and GPT-2).
        """
        if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
                mpu.get_pipeline_model_parallel_world_size() > 1:
            if mpu.is_pipeline_first_stage(ignore_virtual=True):
                unwrapped_model = self.models[0]
            elif mpu.is_pipeline_last_stage(ignore_virtual=True):
                unwrapped_model = self.models[-1]
            else:  # We do not support the interleaved schedule for T5 yet.
                unwrapped_model = self.models[0]
            unwrapped_model = unwrap_model(
                unwrapped_model, (torchDDP, LocalDDP, Float16Module))

            if unwrapped_model.share_word_embeddings:
                word_embeddings_weight = unwrapped_model.word_embeddings_weight()
                if args.DDP_impl == 'local':
                    grad = word_embeddings_weight.main_grad
                else:
                    grad = word_embeddings_weight.grad
                torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())

    def allreduce_position_embedding_grads(self, args):
        """
        All-reduce position_embeddings grad across first (encoder) and
        split (decoder) stages to ensure that position embeddings parameters
        stay in sync. This should only run for T5 models with pipeline
        parallelism.
        """
        if mpu.is_rank_in_position_embedding_group() and \
                mpu.get_pipeline_model_parallel_world_size() > 1 and \
                args.pipeline_model_parallel_split_rank is not None:
            unwrapped_model = self.models[0]
            unwrapped_model = unwrap_model(
                unwrapped_model, (torchDDP, LocalDDP, Float16Module))
            assert args.DDP_impl == 'local', \
                'T5 model is only supported with local DDP mode'
            grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
            torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())

    def allreduce_embedding_grads(self, args):
        """All-reduce both word and position embeddings."""
        self.allreduce_word_embedding_grads(args)
        self.allreduce_position_embedding_grads(args)

    def allreduce_layernorm_grads(self, args):
        """All-reduce layernorm grads (for sequence parallelism).

        Gathers the gradients of every param flagged ``sequence_parallel``
        across all model chunks. It flattens them into one buffer and
        all-reduces that over the tensor-model-parallel group. The reduced
        values are then copied back in place.
        """
        # Only needed when sequence parallelism spans more than one rank.
        if mpu.get_tensor_model_parallel_world_size() <= 1 or \
                not args.sequence_parallel:
            return

        grads = []
        for model_module in self.models:
            unwrapped_model = unwrap_model(
                model_module, (torchDDP, LocalDDP, Float16Module))
            for param in unwrapped_model.parameters():
                if getattr(param, 'sequence_parallel', False):
                    grad = param.main_grad if args.DDP_impl == 'local' else param.grad
                    grads.append(grad.data)

        # Nothing flagged as sequence-parallel: flattening an empty list would
        # raise. Presumably every rank in the tensor-parallel group has the
        # same structure and takes this branch together — confirm if models
        # can differ per rank.
        if not grads:
            return

        coalesced = _flatten_dense_tensors(grads)
        torch.distributed.all_reduce(
            coalesced, group=mpu.get_tensor_model_parallel_group())
        for buf, synced in zip(grads, _unflatten_dense_tensors(
                coalesced, grads)):
            buf.copy_(synced)

    def reduce_model_grads(self, args, timers):
        """All-reduce all grads, and all-reduce embeddings."""

        # All-reduce layer-norm grads (for sequence parallelism).
        timers('layernorm-grads-all-reduce', log_level=1).start(
            barrier=args.barrier_with_L1_time)
        self.allreduce_layernorm_grads(args)
        timers('layernorm-grads-all-reduce').stop()

        # All-reduce if needed.
        if args.DDP_impl == 'local':
            timers('grads-all-reduce', log_level=1).start(
                barrier=args.barrier_with_L1_time)
            for model in self.models:
                model.allreduce_gradients()
            timers('grads-all-reduce').stop()

        # All-reduce embedding grads.
        timers('embedding-grads-all-reduce', log_level=1).start(
            barrier=args.barrier_with_L1_time)
        self.allreduce_embedding_grads(args)
        timers('embedding-grads-all-reduce').stop()
301

302

303
class MixedPrecisionOptimizer(MegatronOptimizer):
    """Base class for both the float-16 and the distributed optimizer.

    Arguments:
        optimizer: base optimizer such as Adam or SGD
        clip_grad: clip gradients with this global L2 norm. Note
            that clipping is ignored if clip_grad == 0
        log_num_zeros_in_grad: return number of zeros in the gradients.
        params_have_main_grad: flag indicating if parameters have
            a `main_grad` field. If this is set, we are assuming
            that the model gradients are stored in the `main_grad`
            field instead of the typical `grad` field. This happens
            for the DDP cases where there is a contiguous buffer
            holding the gradients. For example for bfloat16, we want
            to do gradient accumulation and all-reduces in float32
            and as a result we store those gradients in the main_grad.
            Note that main grad is not necessarily in float32.
        use_contiguous_buffers_in_local_ddp: if true, the local DDP model
            is using a contiguous buffer to hold the model grads.
        fp16: if true, the model is running in fp16.
        bf16: if true, the model is running in bfloat16.
        grad_scaler: used for scaling gradients. Note that this can be
            None. This case happens when `bf16 = True` and we don't
            use any loss scale. Note that for `bf16 = True`, we can have
            a constant gradient scaler. Also for `bf16 = False`, we
            always require a grad scaler.
        models: list of models (i.e., the virtual pipelining models). This
            is used by the distributed optimizer for mapping parameters.

    This class calls the following hooks, which subclasses must define:
    - ``_copy_model_params_to_main_params``
    - ``_collect_main_grad_data_for_unscaling``
    - ``_copy_model_grads_to_main_grads``
    - ``_copy_main_params_to_model_params``
    """

    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                 fp16, bf16, grad_scaler,
                 models):

        super().__init__(
            optimizer, clip_grad, log_num_zeros_in_grad,
            params_have_main_grad, use_contiguous_buffers_in_local_ddp,
            models)

        self.fp16 = fp16
        self.bf16 = bf16
        self.grad_scaler = grad_scaler

        # None grad scaler is only supported for bf16.
        if self.grad_scaler is None:
            assert not self.fp16, 'fp16 expects a grad scaler.'

        # Tensor used to determine if a nan/inf has happened.
        # Any non-zero value indicates inf/nan.
        # Note that we keep this for the cases that grad scaler is none.
        # We still record nan/inf if we have a bfloat16 with a grad scaler.
        if self.grad_scaler is not None:
            self.found_inf = torch.cuda.FloatTensor([0.0])

        # Dummy tensor needed for apex multi-apply tensor.
        # For bfloat, we don't have multi-tensor apply and for now
        # we set it to none so the multi-tensor apply gets ignored.
        if bf16:
            self._dummy_overflow_buf = None
        else:
            self._dummy_overflow_buf = torch.cuda.IntTensor([0])

        # In case grad scaler is not passed, define the unity scale.
        if self.grad_scaler is None:
            self._scale_one = torch.cuda.FloatTensor([1.0])

    def get_loss_scale(self):
        """Return the grad scaler's scale, or a unit tensor without a scaler."""
        if self.grad_scaler is None:
            return self._scale_one
        return self.grad_scaler.scale

    def reload_model_params(self):
        """Re-sync the main params from the (externally changed) model params."""
        self._copy_model_params_to_main_params()

    def _unscale_main_grads_and_check_for_nan(self):
        """Unscale the main grads in place and report whether any were non-finite.

        Only valid when a grad scaler is present (``self.found_inf`` is
        created only in that case). The inf/nan flag is MAX-reduced across
        the model-parallel group, so every rank returns the same answer.

        Returns:
            bool: True if any rank saw an inf/nan gradient.
        """
        # Collect main grads.
        main_grads = self._collect_main_grad_data_for_unscaling()

        # Reset found inf.
        self.found_inf.fill_(0.0)

        # Unscale and set found inf/nan
        torch._amp_foreach_non_finite_check_and_unscale_(
            main_grads, self.found_inf, self.grad_scaler.inv_scale)

        # Update across all model parallel instances.
        torch.distributed.all_reduce(self.found_inf,
                                     op=torch.distributed.ReduceOp.MAX,
                                     group=self.get_model_parallel_group())

        # Check for nan.
        found_inf_flag = (self.found_inf.item() > 0)

        return found_inf_flag

    @torch.no_grad()
    def step(self, args, timers):
        """Run one optimizer update.

        The update proceeds in order:
        1. copy model grads to main grads;
        2. (with a scaler) unscale them, check for inf/nan and update the
           scale;
        3. clip the grads and optionally count zeros;
        4. step the inner optimizer;
        5. copy the main params back into the model params.

        Returns:
            (update_successful, grad_norm, num_zeros_in_grad). The result
            is ``(False, None, None)`` when inf/nan was found and the update
            was skipped. ``grad_norm`` is None when clipping is disabled,
            and ``num_zeros_in_grad`` is None when zero logging is off.
        """
        # Copy gradients from model params to main params.
        timers('optimizer-copy-to-main-grad', log_level=1).start(
            barrier=args.barrier_with_L1_time)
        self._copy_model_grads_to_main_grads()
        timers('optimizer-copy-to-main-grad').stop()

        # Do unscale, check for inf, and update grad scaler only for
        # the case that grad scaler is provided.
        if self.grad_scaler is not None:

            # Unscale and check for inf/nan.
            timers('optimizer-unscale-and-check-inf', log_level=1).start(
                barrier=args.barrier_with_L1_time)
            found_inf_flag = self._unscale_main_grads_and_check_for_nan()
            timers('optimizer-unscale-and-check-inf').stop()

            # We are done with scaling gradients
            # so we can update the loss scale.
            self.grad_scaler.update(found_inf_flag)

            # If we found inf/nan, skip the update.
            if found_inf_flag:
                return False, None, None

        # Clip the main gradients.
        timers('optimizer-clip-main-grad', log_level=1).start(
            barrier=args.barrier_with_L1_time)
        grad_norm = None
        if self.clip_grad > 0.0:
            grad_norm = self.clip_grad_norm(self.clip_grad)
        timers('optimizer-clip-main-grad').stop()

        # Count the zeros in the grads.
        timers('optimizer-count-zeros', log_level=1).start(
            barrier=args.barrier_with_L1_time)
        num_zeros_in_grad = self.count_zeros() if \
                            self.log_num_zeros_in_grad else None
        timers('optimizer-count-zeros').stop()

        # Step the optimizer.
        timers('optimizer-inner-step', log_level=1).start(
            barrier=args.barrier_with_L1_time)
        self.optimizer.step()
        timers('optimizer-inner-step').stop()

        # Update params from main params.
        timers('optimizer-copy-main-to-model-params', log_level=1).start(
            barrier=args.barrier_with_L1_time)
        self._copy_main_params_to_model_params()
        timers('optimizer-copy-main-to-model-params').stop()

        # Successful update.
        return True, grad_norm, num_zeros_in_grad


462
class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
Lawrence McAfee's avatar
Lawrence McAfee committed
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
    """Float16 optimizer for fp16 and bf16 data types.

    Arguments:
        optimizer: base optimizer such as Adam or SGD
        clip_grad: clip gradients with this global L2 norm. Note
            that clipping is ignored if clip_grad == 0
        log_num_zeros_in_grad: return number of zeros in the gradients.
        params_have_main_grad: flag indicating if parameters have
            a `main_grad` field. If this is set, we are assuming
            that the model gradients are stored in the `main_grad`
            field instead of the typical `grad` field. This happens
            for the DDP cases where there is a contiguous buffer
            holding the gradients. For example for bfloat16, we want
            to do gradient accumulation and all-reduces in float32
            and as a result we store those gradients in the main_grad.
            Note that main grad is not necessarily in float32.
        use_contiguous_buffers_in_local_ddp: if true, the local DDP model
            is using a contiguous buffer to hold the model grads.
        fp16: if true, the model is running in fp16.
        bf16: if true, the model is running in bfloat16.
        grad_scaler: used for scaling gradients. Note that this can be
            None. This case happens when `bf16 = True` and we don't
            use any loss scale. Note that for `bf16 = True`, we can have
            a constant gradient scaler. Also for `bf16 = False`, we
            always require a grad scaler.
        models: list of models (i.e., the virtual pipelining models). This
            is used by the distributed optimizer for mapping parameters.
    """

    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                 fp16, bf16, grad_scaler, models):
        """Create fp32 main copies of the float16 parameters.

        Each trainable fp16/bf16 parameter is replaced inside the wrapped
        optimizer's param groups by a detached fp32 clone. Its tensor-parallel
        attributes, its ``shared`` flag and any existing optimizer state move
        to that clone. Trainable fp32 parameters stay in the optimizer as-is.
        Parameters with ``requires_grad=False`` are left alone and are not
        tracked. The constructor arguments are forwarded unchanged to the
        parent class.

        Raises:
            TypeError: if a trainable parameter is not a CUDA float32,
                float16 or bfloat16 tensor.
        """
        super().__init__(
            optimizer, clip_grad, log_num_zeros_in_grad,
            params_have_main_grad, use_contiguous_buffers_in_local_ddp,
            fp16, bf16, grad_scaler, models)

        # Per-param-group bookkeeping, kept in parallel:
        #   float16_groups:           original float16 model parameters
        #   fp32_from_float16_groups: fp32 main copies of those parameters
        #   fp32_from_fp32_groups:    original fp32 model parameters
        self.float16_groups = []
        self.fp32_from_float16_groups = []
        self.fp32_from_fp32_groups = []

        float16_types = ('torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor')
        optimizer_state = self.optimizer.state

        for param_group in self.optimizer.param_groups:
            group_params = param_group['params']
            model_float16 = []
            main_from_float16 = []
            model_fp32 = []

            for idx, param in enumerate(group_params):
                # Frozen parameters are neither converted nor tracked.
                if not param.requires_grad:
                    continue

                param_type = param.type()
                if param_type in float16_types:
                    model_float16.append(param)
                    # fp32 main copy that the inner optimizer will update.
                    main_param = param.detach().clone().float()
                    tensor_parallel.copy_tensor_model_parallel_attributes(
                        main_param, param)
                    if hasattr(param, 'shared'):
                        main_param.shared = param.shared
                    # Swap the copy into the optimizer's param list.
                    group_params[idx] = main_param
                    main_from_float16.append(main_param)
                    # Re-key any existing optimizer state to the main param.
                    if param in optimizer_state:
                        optimizer_state[main_param] = optimizer_state.pop(param)
                elif param_type == 'torch.cuda.FloatTensor':
                    model_fp32.append(param)
                else:
                    raise TypeError('Wrapped parameters must be one of '
                                    'torch.cuda.FloatTensor,  '
                                    'torch.cuda.HalfTensor, or '
                                    'torch.cuda.BFloat16Tensor. '
                                    'Received {}'.format(param_type))

            self.float16_groups.append(model_float16)
            self.fp32_from_float16_groups.append(main_from_float16)
            self.fp32_from_fp32_groups.append(model_fp32)


Lawrence McAfee's avatar
Lawrence McAfee committed
559
560
561
562
563
564
565
566
567
568
569
570
    def zero_grad(self, set_to_none=True):
        """Reset gradients of the float16, fp32-main and fp32 param groups.

        Only the model-facing groups (``float16_groups`` and
        ``fp32_from_fp32_groups``) strictly need zeroing. The fp32 main
        copies in ``fp32_from_float16_groups`` are cleared as well to limit
        memory fragmentation. With ``set_to_none=True`` their gradient
        memory can be released at this point.
        """
        all_groups = (self.float16_groups,
                      self.fp32_from_float16_groups,
                      self.fp32_from_fp32_groups)
        for groups in all_groups:
            for group in groups:
                _zero_grad_group_helper(group, set_to_none)
mohammad's avatar
mohammad committed
571
572


573
    def _collect_main_grad_data_for_unscaling(self):
        """Return the ``.grad.data`` of every main param that has a gradient.

        The fp32 copies of float16 params come first, followed by the
        original fp32 params. Params whose ``.grad`` is None are skipped.
        """
        return [
            main_param.grad.data
            for main_groups in (self.fp32_from_float16_groups,
                                self.fp32_from_fp32_groups)
            for main_group in main_groups
            for main_param in main_group
            if main_param.grad is not None
        ]
590
591


592
593
594
595
596
597
598
599
600
    def _get_model_and_main_params_data_float16(self):
        """Pair the ``.data`` of each float16 model param with its fp32 copy.

        Returns:
            (model_data, main_data): two flat lists with one entry per
            float16 param across all groups. ``main_data[k]`` is the data of
            the fp32 main copy of ``model_data[k]``.
        """
        pairs = [
            (model_param.data, main_param.data)
            for model_group, main_group in zip(self.float16_groups,
                                               self.fp32_from_float16_groups)
            for model_param, main_param in zip(model_group, main_group)
        ]
        model_data = [model for model, _ in pairs]
        main_data = [main for _, main in pairs]
        return model_data, main_data
601

Lawrence McAfee's avatar
Lawrence McAfee committed
602

603
    def _copy_model_grads_to_main_grads(self):
        """Hand the model gradients to the params the inner optimizer updates.

        float16 params have their gradient (``main_grad`` when available,
        otherwise ``grad``) cast to fp32 onto their main copy. fp32 params
        get ``.grad`` pointed at ``main_grad`` when params carry a
        ``main_grad``. Model-side references are then dropped, except for
        ``main_grad`` when it lives in a contiguous buffer that must persist.
        """
        have_main_grad = self.params_have_main_grad
        keep_main_grad = self.use_contiguous_buffers_in_local_ddp

        # float16 params: cast the gradient to fp32 onto the main copy.
        for model_group, main_group in zip(self.float16_groups,
                                           self.fp32_from_float16_groups):
            for model_param, main_param in zip(model_group, main_group):
                if have_main_grad and hasattr(model_param, 'main_grad'):
                    main_param.grad = model_param.main_grad.float()
                elif model_param.grad is not None:
                    main_param.grad = model_param.grad.float()

                # The fp32 copy now holds the gradient, so the model-side
                # references can be released; a contiguous-buffer main_grad
                # must persist and is therefore kept.
                model_param.grad = None
                if have_main_grad and not keep_main_grad:
                    model_param.main_grad = None

        # fp32 params: the inner optimizer reads `.grad`, so alias main_grad.
        if not have_main_grad:
            return
        for model_group in self.fp32_from_fp32_groups:
            for model_param in model_group:
                model_param.grad = model_param.main_grad
                # Drop the extra reference unless the buffer must persist.
                if not keep_main_grad:
                    model_param.main_grad = None
mohammad's avatar
mohammad committed
633

634

635
    def _copy_main_params_to_model_params(self):
        """Write the fp32 main param values back into the float16 model params.

        Only float16 params have separate main copies, so fp32 params need
        no work here.
        """
        model_tensors, main_tensors = \
            self._get_model_and_main_params_data_float16()
        _multi_tensor_copy_this_to_that(
            that=model_tensors,
            this=main_tensors,
            overflow_buf=self._dummy_overflow_buf,
        )


    def _copy_model_params_to_main_params(self):
        """Refresh the fp32 main params from the current float16 model params.

        This is the inverse direction of ``_copy_main_params_to_model_params``;
        only float16 params have separate main copies.
        """
        model_tensors, main_tensors = \
            self._get_model_and_main_params_data_float16()
        _multi_tensor_copy_this_to_that(
            that=main_tensors,
            this=model_tensors,
            overflow_buf=self._dummy_overflow_buf,
        )
647
648


mohammad's avatar
mohammad committed
649
650
651
    def state_dict(self):
        """Build a checkpointable dict for this optimizer.

        Keys, in insertion order:
            'optimizer': the wrapped optimizer's ``state_dict()``.
            'grad_scaler': the grad scaler's ``state_dict()``, present only
                when ``self.grad_scaler`` is truthy.
            'fp32_from_fp16_params': ``self.fp32_from_float16_groups`` itself
                (stored by reference, not copied).
        """
        checkpoint = {'optimizer': self.optimizer.state_dict()}
        if self.grad_scaler:
            checkpoint['grad_scaler'] = self.grad_scaler.state_dict()
        checkpoint['fp32_from_fp16_params'] = self.fp32_from_float16_groups
        return checkpoint


    def load_state_dict(self, state_dict):
        """Restore optimizer, grad-scaler and fp32 main-param state.

        Accepts the current checkpoint keys as well as the older
        ``'optimizer_state_dict'`` / ``'fp32_from_fp16'`` names.

        Args:
            state_dict: mapping produced by ``state_dict()`` (or an older
                checkpoint format). Its main-param entry is a list of groups
                of tensors, matched positionally against
                ``self.fp32_from_float16_groups`` and copied in place.
        """
        # Optimizer (older checkpoints stored it under a different key).
        optimizer_key = 'optimizer'
        if optimizer_key not in state_dict:
            optimizer_key = 'optimizer_state_dict'
            print_rank_0('***WARNING*** loading optimizer from '
                         'an old checkpoint ...')
        self.optimizer.load_state_dict(state_dict[optimizer_key])

        # Grad scaler.
        if 'grad_scaler' not in state_dict:
            # Warn only when self.fp16 is set; presumably other precisions
            # have no scaler state worth restoring.
            if self.fp16:
                print_rank_0('***WARNING*** found an old checkpoint, will not '
                             'load grad scaler ...')
        else:
            if self.grad_scaler:
                self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
            else:
                print_rank_0('***WARNING*** found the grad scaler in the '
                             'checkpoint but it is None in the class. '
                             'Skipping loading grad scaler ...')

        # Copy data for the main params (fall back to the legacy key name).
        fp32_from_float16_params_key = 'fp32_from_fp16_params'
        if fp32_from_float16_params_key not in state_dict:
            fp32_from_float16_params_key = 'fp32_from_fp16'
        # NOTE(review): zip() stops at the shorter side, so a checkpoint whose
        # group/param layout differs from the current one is only partially
        # restored without any error -- confirm layouts always match.
        for current_group, saved_group in zip(
                self.fp32_from_float16_groups,
                state_dict[fp32_from_float16_params_key]):
            for current_param, saved_param in zip(current_group, saved_group):
                current_param.data.copy_(saved_param.data)


mohammad's avatar
mohammad committed
691
692
class FP32Optimizer(MegatronOptimizer):
    """Megatron optimizer wrapper for plain fp32 training.

    There is no loss scaling and no separate main-parameter copy: the wrapped
    optimizer updates the parameters in ``param_groups`` directly, and
    ``step`` always reports success because no overflow check is done.
    """

    def __init__(self, optimizer, clip_grad,
                 log_num_zeros_in_grad,
                 params_have_main_grad,
                 use_contiguous_buffers_in_local_ddp,
                 models):
        super().__init__(optimizer, clip_grad, log_num_zeros_in_grad,
                         params_have_main_grad,
                         use_contiguous_buffers_in_local_ddp, models)

        # Constant loss scale of 1.0 handed out by get_loss_scale(), stored
        # as a CUDA float tensor.
        self._scale = torch.cuda.FloatTensor([1.0])

    def zero_grad(self, set_to_none=True):
        """Clear gradients of every param the wrapped optimizer owns.

        Mirrors ``torch.optim.Optimizer.zero_grad``: grads are set to
        ``None`` when ``set_to_none`` is true, otherwise zeroed in place.
        """
        for param_group in self.optimizer.param_groups:
            _zero_grad_group_helper(param_group['params'], set_to_none)

    def get_loss_scale(self):
        """Return the fixed unit scale; the fp32 path never scales the loss."""
        return self._scale

    @torch.no_grad()
    def step(self, args, timers):
        """Optionally clip gradients, then step the wrapped optimizer.

        Args:
            args: run arguments; ``args.barrier_with_L1_time`` is forwarded
                to each timer's ``start``.
            timers: callable returning named timers.

        Returns:
            tuple: ``(True, grad_norm, num_zeros_in_grad)``. ``grad_norm`` is
            ``None`` unless ``clip_grad > 0``; ``num_zeros_in_grad`` is
            ``None`` unless ``log_num_zeros_in_grad`` is set. The success
            flag is always ``True`` because fp32 training has no overflow
            handling.
        """
        barrier = args.barrier_with_L1_time

        # Point each param's .grad at its accumulated main_grad.
        timers('optimizer-copy-to-main-grad', log_level=1).start(
            barrier=barrier)
        if self.params_have_main_grad:
            # With contiguous buffers the main_grad memory must persist, so
            # the reference is only dropped when buffers are not shared.
            drop_main_grad = not self.use_contiguous_buffers_in_local_ddp
            for param_group in self.optimizer.param_groups:
                for param in param_group['params']:
                    param.grad = param.main_grad
                    if drop_main_grad:
                        param.main_grad = None
        timers('optimizer-copy-to-main-grad').stop()

        # Gradient clipping (disabled when clip_grad <= 0).
        timers('optimizer-clip-main-grad', log_level=1).start(
            barrier=barrier)
        grad_norm = (self.clip_grad_norm(self.clip_grad)
                     if self.clip_grad > 0.0 else None)
        timers('optimizer-clip-main-grad').stop()

        # Optional zero counting for logging.
        timers('optimizer-count-zeros', log_level=1).start(
            barrier=barrier)
        if self.log_num_zeros_in_grad:
            num_zeros_in_grad = self.count_zeros()
        else:
            num_zeros_in_grad = None
        timers('optimizer-count-zeros').stop()

        # Parameter update.
        timers('optimizer-inner-step', log_level=1).start(
            barrier=barrier)
        self.optimizer.step()
        timers('optimizer-inner-step').stop()

        return True, grad_norm, num_zeros_in_grad

    def reload_model_params(self):
        """No-op: there are no main-param copies to resync in fp32 mode."""
        pass

    def state_dict(self):
        """Return the wrapped optimizer's state dict unchanged."""
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict):
        """Load ``state_dict`` straight into the wrapped optimizer."""
        self.optimizer.load_state_dict(state_dict)