# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Megatron optimizer."""
mohammad's avatar
mohammad committed
4
5
6
7
8

from abc import ABC
from abc import abstractmethod
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
9
10
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
Lawrence McAfee's avatar
Lawrence McAfee committed
11
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
mohammad's avatar
mohammad committed
12

mohammad's avatar
mohammad committed
13
14
from megatron import get_timers
from megatron import mpu
mohammad's avatar
mohammad committed
15
from megatron import print_rank_0
16
17
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
18
19
from megatron.model.module import param_is_not_shared
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
20
from megatron.utils import unwrap_model
21

22
from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32
23

Lawrence McAfee's avatar
Lawrence McAfee committed
24

mohammad's avatar
mohammad committed
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
def _zero_grad_group_helper(group, set_to_none):
    """Zero out the gradient for a group of parameters.
    Note: copied from torch.optim.optimizer."""
    for param in group:
        if param.grad is not None:
            if set_to_none:
                param.grad = None
            else:
                if param.grad.grad_fn is not None:
                    param.grad.detach_()
                else:
                    param.grad.requires_grad_(False)
                param.grad.zero_()


40
def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
41
42
43
44
    """Use multi-tensor-applier to copy values from one list to another.
    We don't have a blfoat16 implementation so for now if the overflow_buf
    is not provided, we default back to simple loop copy to be compatible
    with bfloat16."""
45
46
    if overflow_buf:
        overflow_buf.fill_(0)
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
47
48
49
50
51
        # Scaling with factor `1.0` is equivalent to copy.
        multi_tensor_applier(amp_C.multi_tensor_scale,
                             overflow_buf,
                             [this, that],
                             1.0)
52
    else:
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
53
54
55
        for this_, that_ in zip(this, that):
            that_.copy_(this_)

56

mohammad's avatar
mohammad committed
57
58
59

class MegatronOptimizer(ABC):
    """Abstract wrapper around a base torch optimizer (e.g. Adam).

    Stores gradient clipping / logging settings, promotes the wrapped
    optimizer's ``state`` and ``param_groups`` attributes, and provides the
    gradient all-reduce helpers that run before an optimizer step.
    """

    def __init__(self, optimizer, clip_grad,
                 log_num_zeros_in_grad,
                 params_have_main_grad,
                 use_contiguous_buffers_in_local_ddp,
                 models):
        """Input optimizer is the base optimizer for example Adam.

        Args:
            optimizer: wrapped base optimizer.
            clip_grad: global L2 norm used for gradient clipping.
            log_num_zeros_in_grad: whether ``count_zeros`` is used when
                stepping.
            params_have_main_grad: whether params carry a ``main_grad``.
            use_contiguous_buffers_in_local_ddp: whether local DDP keeps
                grads in contiguous buffers; requires
                ``params_have_main_grad``.
            models: list of model chunks.
        """
        self.optimizer = optimizer
        assert self.optimizer, 'no optimizer is provided.'

        # Set gradient clipping and logging params.
        self.clip_grad = clip_grad
        self.log_num_zeros_in_grad = log_num_zeros_in_grad
        self.params_have_main_grad = params_have_main_grad
        self.use_contiguous_buffers_in_local_ddp = use_contiguous_buffers_in_local_ddp

        # 'models' are retained for access to the contiguous grad buffers.
        # (see distributed optimizer)
        self.models = models

        if self.use_contiguous_buffers_in_local_ddp:
            assert self.params_have_main_grad, \
                "use of contiguous buffer requires that params have main grad"

    def get_parameters(self):
        """Return a flat list of every parameter in every param group."""
        return [param
                for param_group in self.optimizer.param_groups
                for param in param_group['params']]

    def get_main_grads_for_grad_norm(self):
        """Return the grads that contribute to the global grad norm."""
        # Filter parameters based on:
        #   - grad should not be none
        #   - parameter should not be shared
        #   - should not be a replica due to tensor model parallelism
        params = self.get_parameters()
        grads_for_norm = []
        for param in params:
            grad = param.grad
            grad_not_none = grad is not None
            is_not_shared = param_is_not_shared(param)
            is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
            if grad_not_none and is_not_shared and is_not_tp_duplicate:
                grads_for_norm.append(grad)

        return grads_for_norm

    def get_model_parallel_group(self):
        """Default returned here, but the distributed optimizer overrides this."""
        return mpu.get_model_parallel_group()

    def clip_grad_norm(self, clip_grad):
        """Clip grads to ``clip_grad`` and return what clip_grad_norm_fp32 returns."""
        params = self.get_parameters()
        grads_for_norm = self.get_main_grads_for_grad_norm()
        return clip_grad_norm_fp32(
            params, grads_for_norm, clip_grad,
            model_parallel_group=self.get_model_parallel_group())

    def count_zeros(self):
        """Count zero grad entries across the model-parallel group."""
        params = self.get_parameters()
        return count_zeros_fp32(params,
                                model_parallel_group=self.get_model_parallel_group())

    @abstractmethod
    def zero_grad(self, set_to_none=True):
        pass

    @abstractmethod
    def get_loss_scale(self):
        """The output should be a cuda tensor of size 1."""
        pass

    def scale_loss(self, loss):
        """Simple scaling."""
        return self.get_loss_scale() * loss

    @abstractmethod
    def reload_model_params(self):
        """Refreshes any internal state from the current model parameters.
        Call whenever the parameters are changed outside of the optimizer.
        For example, when we load a model from a checkpoint  without loading
        the optimizer, the model parameters are updated but for fp16 optimizer
        with main parameters, the main parameters need to also be updated."""
        pass

    @abstractmethod
    def state_dict(self):
        pass

    @abstractmethod
    def load_state_dict(self, state_dict):
        pass

    # Promote state so it can be retrieved or set via
    # "optimizer_instance.state"
    def _get_state(self):
        return self.optimizer.state

    def _set_state(self, value):
        self.optimizer.state = value

    state = property(_get_state, _set_state)

    # Promote param_groups so it can be retrieved or set via
    # "optimizer_instance.param_groups"
    # (for example, to adjust the learning rate)
    def _get_param_groups(self):
        return self.optimizer.param_groups

    def _set_param_groups(self, value):
        self.optimizer.param_groups = value

    param_groups = property(_get_param_groups, _set_param_groups)

    @abstractmethod
    def step(self, args, timers):
        pass

    def gather_model_params(self, args, timers):
        """
        For the case of a non-distributed-optimizer, there is nothing to
        do here.
        """
        pass

    def allreduce_word_embedding_grads(self, args):
        """
        All-reduce word embedding grads.

        Reduce grads across first and last stages to ensure that word_embeddings
        parameters stay in sync. This should only run for models that support
        pipelined model parallelism (BERT and GPT-2).
        """

        if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
                mpu.get_pipeline_model_parallel_world_size() > 1:
            if mpu.is_pipeline_first_stage(ignore_virtual=True):
                unwrapped_model = self.models[0]
            elif mpu.is_pipeline_last_stage(ignore_virtual=True):
                unwrapped_model = self.models[-1]
            else:  # We do not support the interleaved schedule for T5 yet.
                unwrapped_model = self.models[0]
            unwrapped_model = unwrap_model(
                unwrapped_model, (torchDDP, LocalDDP, Float16Module))

            if unwrapped_model.share_word_embeddings:
                word_embeddings_weight = unwrapped_model.word_embeddings_weight()
                # Local DDP accumulates into `main_grad`; torch DDP uses `grad`.
                if args.DDP_impl == 'local':
                    grad = word_embeddings_weight.main_grad
                else:
                    grad = word_embeddings_weight.grad
                torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())

    def allreduce_position_embedding_grads(self, args):
        """
        All-reduce position_embeddings grad across first (encoder) and
        split (decoder) stages to ensure that position embeddings parameters
        stay in sync. This should only run for T5 models with pipeline
        parallelism.
        """
        if mpu.is_rank_in_position_embedding_group() and \
                mpu.get_pipeline_model_parallel_world_size() > 1 and \
                args.pipeline_model_parallel_split_rank is not None:
            unwrapped_model = self.models[0]
            unwrapped_model = unwrap_model(
                unwrapped_model, (torchDDP, LocalDDP, Float16Module))
            assert args.DDP_impl == 'local', \
                'T5 model is only supported with local DDP mode'
            grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
            torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())

    def allreduce_embedding_grads(self, args):
        """All-reduce both word and position embeddings."""
        self.allreduce_word_embedding_grads(args)
        self.allreduce_position_embedding_grads(args)

    def allreduce_layernorm_grads(self, args):
        """All-reduce layernorm grads (for sequence parallelism).

        Only parameters tagged with a truthy ``sequence_parallel`` attribute
        that also require grad are reduced, over the tensor-model-parallel
        group, as one coalesced buffer.
        """

        # All-reduce layernorm parameters across model parallel nodes
        # when sequence parallelism is used
        if mpu.get_tensor_model_parallel_world_size() > 1 and \
                args.sequence_parallel:
            grads = []
            for model_module in self.models:
                unwrapped_model = unwrap_model(
                    model_module, (torchDDP, LocalDDP, Float16Module))
                for param in unwrapped_model.parameters():
                    # Frozen params have no grad to reduce.
                    if param.requires_grad and \
                            getattr(param, 'sequence_parallel', False):
                        grad = param.main_grad if args.DDP_impl == 'local' else param.grad
                        grads.append(grad.data)
            # Flattening an empty list fails inside torch.cat, so bail out
            # when there is nothing to reduce.
            if not grads:
                return
            coalesced = _flatten_dense_tensors(grads)
            torch.distributed.all_reduce(
                coalesced, group=mpu.get_tensor_model_parallel_group())
            for buf, synced in zip(grads, _unflatten_dense_tensors(
                    coalesced, grads)):
                buf.copy_(synced)

    def reduce_model_grads(self, args, timers):
        """All-reduce all grads, and all-reduce embeddings."""

        # All-reduce layer-norm grads (for sequence parallelism).
        timers('layernorm-grads-all-reduce', log_level=1).start(
            barrier=args.barrier_with_L1_time)
        self.allreduce_layernorm_grads(args)
        timers('layernorm-grads-all-reduce').stop()

        # All-reduce if needed.
        if args.DDP_impl == 'local':
            timers('grads-all-reduce', log_level=1).start(
                barrier=args.barrier_with_L1_time)
            for model in self.models:
                model.allreduce_gradients()
            timers('grads-all-reduce').stop()

        # All-reduce embedding grads.
        timers('embedding-grads-all-reduce', log_level=1).start(
            barrier=args.barrier_with_L1_time)
        self.allreduce_embedding_grads(args)
        timers('embedding-grads-all-reduce').stop()
302

303

304
class MixedPrecisionOptimizer(MegatronOptimizer):
    """Base class for both the float-16 and the distributed optimizer.

    Subclasses provide the main/model parameter plumbing used here:
    ``_copy_model_params_to_main_params``, ``_copy_model_grads_to_main_grads``,
    ``_collect_main_grad_data_for_unscaling`` and
    ``_copy_main_params_to_model_params``.

    Arguments:
        optimizer: base optimizer such as Adam or SGD
        clip_grad: clip gradients with this global L2 norm. Note
            that clipping is ignored if clip_grad == 0
        log_num_zeros_in_grad: return number of zeros in the gradients.
        params_have_main_grad: flag indicating if parameters have
            a `main_grad` field. If this is set, we are assuming
            that the model gradients are stored in the `main_grad`
            field instead of the typical `grad` field. This happens
            for the DDP cases where there is a contiguous buffer
            holding the gradients. For example for bfloat16, we want
            to do gradient accumulation and all-reduces in float32
            and as a result we store those gradients in the main_grad.
            Note that main grad is not necessarily in float32.
        use_contiguous_buffers_in_local_ddp: if true, the local DDP model
            is using a contiguous buffer to hold the model grads.
        fp16: if true, the model is running in fp16.
        bf16: if true, the model is running in bfloat16.
        params_dtype: used by distributed optimizer.
        grad_scaler: used for scaling gradients. Note that this can be
            None. This case happens when `bf16 = True` and we don't
            use any loss scale. Note that for `bf16 = True`, we can have
            a constant gradient scaler. Also for `bf16 = False`, we
            always require a grad scaler.
        models: list of models (i.e., the virtual pipelining models). This
            is used by the distributed optimizer for mapping parameters.
    """

    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                 fp16, bf16, params_dtype, grad_scaler,
                 models):

        super().__init__(
            optimizer, clip_grad, log_num_zeros_in_grad,
            params_have_main_grad, use_contiguous_buffers_in_local_ddp,
            models)

        self.fp16 = fp16
        self.bf16 = bf16
        self.params_dtype = params_dtype
        self.grad_scaler = grad_scaler

        # None grad scaler is only supported for bf16.
        if self.grad_scaler is None:
            assert not self.fp16, 'fp16 expects a grad scaler.'

        # Tensor used to determine if a nan/inf has happened.
        # Any non-zero value indicates inf/nan. It is only allocated when a
        # grad scaler is present (fp16, or bf16 with a scaler), since only
        # then do we unscale and check for inf/nan.
        if self.grad_scaler is not None:
            self.found_inf = torch.tensor([0.0], dtype=torch.float,
                                          device='cuda')

        # Dummy tensor needed for apex multi-apply tensor.
        # For bfloat, we don't have multi-tensor apply and for now
        # we set it to none so the multi-tensor apply gets ignored.
        if bf16:
            self._dummy_overflow_buf = None
        else:
            self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int,
                                                    device='cuda')

        # In case grad scaler is not passed, define the unity scale.
        if self.grad_scaler is None:
            self._scale_one = torch.tensor([1.0], dtype=torch.float,
                                           device='cuda')

    def get_loss_scale(self):
        """Return the grad scaler's current scale, or a unity tensor if none."""
        if self.grad_scaler is None:
            return self._scale_one
        return self.grad_scaler.scale

    def reload_model_params(self):
        """Re-copy model params into the main params (subclass hook)."""
        self._copy_model_params_to_main_params()

    def _unscale_main_grads_and_check_for_nan(self):
        """Unscale main grads in place and report whether any inf/nan exists.

        Grads are multiplied by the scaler's ``inv_scale``. The inf/nan flag
        is MAX-reduced over the model-parallel group, so every rank in the
        group returns the same answer.

        Returns:
            bool: True if any rank in the group saw an inf/nan.
        """

        # Collect main grads.
        main_grads = self._collect_main_grad_data_for_unscaling()

        # Reset found inf.
        self.found_inf.fill_(0.0)

        # Unscale and set found inf/nan
        torch._amp_foreach_non_finite_check_and_unscale_(
            main_grads, self.found_inf, self.grad_scaler.inv_scale)

        # Update across all model parallel instances.
        torch.distributed.all_reduce(self.found_inf,
                                     op=torch.distributed.ReduceOp.MAX,
                                     group=self.get_model_parallel_group())

        # Check for nan.
        found_inf_flag = (self.found_inf.item() > 0)

        return found_inf_flag

    @torch.no_grad()
    def step(self, args, timers):
        """Run one mixed-precision optimizer step.

        Returns:
            (False, None, None) when the step is skipped because an inf/nan
            was found; otherwise (True, grad_norm, num_zeros_in_grad), where
            grad_norm is None unless clip_grad > 0 and num_zeros_in_grad is
            None unless log_num_zeros_in_grad is set.
        """

        # Copy gradients from model params to main params.
        timers('optimizer-copy-to-main-grad', log_level=1).start(
            barrier=args.barrier_with_L1_time)
        self._copy_model_grads_to_main_grads()
        timers('optimizer-copy-to-main-grad').stop()

        # Do unscale, check for inf, and update grad scaler only for
        # the case that grad scaler is provided.
        if self.grad_scaler is not None:

            # Unscale and check for inf/nan.
            timers('optimizer-unscale-and-check-inf', log_level=1).start(
                barrier=args.barrier_with_L1_time)
            found_inf_flag = self._unscale_main_grads_and_check_for_nan()
            timers('optimizer-unscale-and-check-inf').stop()

            # We are done with scaling gradients
            # so we can update the loss scale.
            self.grad_scaler.update(found_inf_flag)

            # If we found inf/nan, skip the update.
            if found_inf_flag:
                return False, None, None

        # Clip the main gradients.
        timers('optimizer-clip-main-grad', log_level=1).start(
            barrier=args.barrier_with_L1_time)
        grad_norm = None
        if self.clip_grad > 0.0:
            grad_norm = self.clip_grad_norm(self.clip_grad)
        timers('optimizer-clip-main-grad').stop()

        # Count the zeros in the grads.
        timers('optimizer-count-zeros', log_level=1).start(
            barrier=args.barrier_with_L1_time)
        num_zeros_in_grad = self.count_zeros() if \
                            self.log_num_zeros_in_grad else None
        timers('optimizer-count-zeros').stop()

        # Step the optimizer.
        timers('optimizer-inner-step', log_level=1).start(
            barrier=args.barrier_with_L1_time)
        self.optimizer.step()
        timers('optimizer-inner-step').stop()

        # Update params from main params.
        timers('optimizer-copy-main-to-model-params', log_level=1).start(
            barrier=args.barrier_with_L1_time)
        self._copy_main_params_to_model_params()
        timers('optimizer-copy-main-to-model-params').stop()

        # Successful update.
        return True, grad_norm, num_zeros_in_grad


465
class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
Lawrence McAfee's avatar
Lawrence McAfee committed
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
    """Float16 optimizer for fp16 and bf16 data types.

    Arguments:
        optimizer: base optimizer such as Adam or SGD
        clip_grad: clip gradients with this global L2 norm. Note
            that clipping is ignored if clip_grad == 0
        log_num_zeros_in_grad: return number of zeros in the gradients.
        params_have_main_grad: flag indicating if parameters have
            a `main_grad` field. If this is set, we are assuming
            that the model gradients are stored in the `main_grad`
            field instead of the typical `grad` field. This happens
            for the DDP cases where there is a contiguous buffer
            holding the gradients. For example for bfloat16, we want
            to do gradient accumulation and all-reduces in float32
            and as a result we store those gradients in the main_grad.
            Note that main grad is not necessarily in float32.
        use_contiguous_buffers_in_local_ddp: if true, the local DDP model
            is using a contiguous buffer to hold the model grads.
        fp16: if true, the model is running in fp16.
        bf16: if true, the model is running in bfloat16.
        params_dtype: used by distributed optimizer.
        grad_scaler: used for scaling gradients. Note that this can be
            None. This case happens when `bf16 = True` and we don't
            use any loss scale. Note that for `bf16 = True`, we can have
            a constant gradient scaler. Also for `bf16 = False`, we
            always require a grad scaler.
        models: list of models (i.e., the virtual pipelining models). This
            is used by the distributed optimizer for mapping parameters.
    """

    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                 fp16, bf16, params_dtype, grad_scaler, models):
        """Create fp32 main copies of float16 params and swap them into the
        wrapped optimizer's param groups.

        After construction, for each optimizer param group (same order):
          - ``float16_groups``: the original fp16/bf16 model params,
          - ``fp32_from_float16_groups``: fp32 clones of those params, which
            replace them inside the optimizer's param group,
          - ``fp32_from_fp32_groups``: the original fp32 params, kept as-is.
        Params with ``requires_grad`` False go into none of these lists.

        Raises:
            TypeError: if a trainable param is not a CUDA float, half or
                bfloat16 tensor.
        """
        super().__init__(
            optimizer, clip_grad, log_num_zeros_in_grad,
            params_have_main_grad, use_contiguous_buffers_in_local_ddp,
            fp16, bf16, params_dtype, grad_scaler, models)

        float16_types = ('torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor')

        self.float16_groups = []
        self.fp32_from_float16_groups = []
        self.fp32_from_fp32_groups = []

        for param_group in self.optimizer.param_groups:
            group_params = param_group['params']
            model_float16 = []
            main_from_float16 = []
            model_fp32 = []

            for idx, param in enumerate(group_params):
                # Frozen params are neither tracked nor replaced.
                if not param.requires_grad:
                    continue

                param_type = param.type()
                if param_type in float16_types:
                    model_float16.append(param)

                    # fp32 clone carrying the model param's tensor-parallel
                    # attributes and, when present, its `shared` flag.
                    main_param = param.detach().clone().float()
                    mpu.copy_tensor_model_parallel_attributes(main_param,
                                                              param)
                    if hasattr(param, 'shared'):
                        main_param.shared = param.shared

                    # From now on the optimizer updates the fp32 copy.
                    group_params[idx] = main_param
                    main_from_float16.append(main_param)

                    # Move any existing optimizer state onto the new key.
                    if param in self.optimizer.state:
                        self.optimizer.state[main_param] = \
                            self.optimizer.state.pop(param)
                elif param_type == 'torch.cuda.FloatTensor':
                    # fp32 params already serve as their own main params.
                    model_fp32.append(param)
                else:
                    raise TypeError('Wrapped parameters must be one of '
                                    'torch.cuda.FloatTensor,  '
                                    'torch.cuda.HalfTensor, or '
                                    'torch.cuda.BFloat16Tensor. '
                                    'Received {}'.format(param.type()))

            self.float16_groups.append(model_float16)
            self.fp32_from_float16_groups.append(main_from_float16)
            self.fp32_from_fp32_groups.append(model_fp32)


Lawrence McAfee's avatar
Lawrence McAfee committed
562
563
564
565
566
567
568
569
570
571
572
573
    def zero_grad(self, set_to_none=True):
        """Zero (or drop, when ``set_to_none``) grads of all tracked groups.

        Strictly only the model-facing groups (float16_groups and
        fp32_from_fp32_groups) need clearing. fp32_from_float16_groups is
        cleared too as a memory optimization to reduce fragmentation: with
        ``set_to_none=True`` its grad memory can be released at this point.
        """
        all_groups = (self.float16_groups
                      + self.fp32_from_float16_groups
                      + self.fp32_from_fp32_groups)
        for group in all_groups:
            _zero_grad_group_helper(group, set_to_none)
mohammad's avatar
mohammad committed
574
575


576
    def _collect_main_grad_data_for_unscaling(self):
        """Return ``.grad.data`` of every main param that has a gradient.

        The fp32 copies of float16 params come first, followed by the
        original fp32 params. Params whose grad is None are skipped.
        """
        return [
            main_param.grad.data
            for main_groups in (self.fp32_from_float16_groups,
                                self.fp32_from_fp32_groups)
            for main_group in main_groups
            for main_param in main_group
            if main_param.grad is not None
        ]
593
594


595
596
597
598
599
600
601
602
603
    def _get_model_and_main_params_data_float16(self):
        """Return matching ``.data`` lists for float16 model params and
        their fp32 main copies.

        Returns:
            (model_data, main_data): entry k of both lists comes from the
            same (model param, main param) pair.
        """
        pairs = [
            (model_param.data, main_param.data)
            for model_group, main_group in zip(self.float16_groups,
                                               self.fp32_from_float16_groups)
            for model_param, main_param in zip(model_group, main_group)
        ]
        model_data = [model for model, _ in pairs]
        main_data = [main for _, main in pairs]
        return model_data, main_data
604

Lawrence McAfee's avatar
Lawrence McAfee committed
605

606
    def _copy_model_grads_to_main_grads(self):
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
607
608
609
        # This only needs to be done for the float16 group.
        for model_group, main_group in zip(self.float16_groups,
                                           self.fp32_from_float16_groups):
610
            for model_param, main_param in zip(model_group, main_group):
611
                if self.params_have_main_grad and hasattr(model_param, 'main_grad'):
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
612
613
614
615
                    main_param.grad = model_param.main_grad.float()
                else:
                    if model_param.grad is not None:
                        main_param.grad = model_param.grad.float()
616
617
618
619
620

                # Safe to deallocate model's grad/main_grad after copying.
                # (If using contiguous buffers, main_grad's memory should
                # persist and therefore should not be deallocated.)
                model_param.grad = None
621
                if self.params_have_main_grad and \
622
                   not self.use_contiguous_buffers_in_local_ddp:
623
624
                    model_param.main_grad = None

Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
625
626
627
628
629
        # For fp32 grads, we need to reset the grads to main grad.
        if self.params_have_main_grad:
            for model_group in self.fp32_from_fp32_groups:
                for model_param in model_group:
                    model_param.grad = model_param.main_grad
mohammad's avatar
mohammad committed
630

631
632
633
                    # Safe to de-reference model's main_grad after copying.
                    # (If using contiguous buffers, main_grad's memory should
                    # persist and therefore should not be deallocated.)
634
                    if not self.use_contiguous_buffers_in_local_ddp:
635
                        model_param.main_grad = None
mohammad's avatar
mohammad committed
636

637

638
    def _copy_main_params_to_model_params(self):
        """Write the main-param values back into the float16 model params.

        Only the float16 params need this. The whole list is copied in a
        single call to ``_multi_tensor_copy_this_to_that``.
        """
        model_tensors, main_tensors = \
            self._get_model_and_main_params_data_float16()
        _multi_tensor_copy_this_to_that(
            this=main_tensors, that=model_tensors,
            overflow_buf=self._dummy_overflow_buf)


    def _copy_model_params_to_main_params(self):
        """Refresh the main params from the current float16 model-param values.

        Only the float16 params need this. The whole list is copied in a
        single call to ``_multi_tensor_copy_this_to_that``.
        """
        model_tensors, main_tensors = \
            self._get_model_and_main_params_data_float16()
        _multi_tensor_copy_this_to_that(
            this=model_tensors, that=main_tensors,
            overflow_buf=self._dummy_overflow_buf)
650
651


mohammad's avatar
mohammad committed
652
653
654
    def state_dict(self):
        """Build a checkpointable dict of this optimizer's state.

        Returns:
            dict with these entries:

            * ``'optimizer'``: the wrapped optimizer's state dict.
            * ``'grad_scaler'``: only when ``self.grad_scaler`` is truthy.
            * ``'fp32_from_fp16_params'``: the ``fp32_from_float16_groups``
              list itself (not a copy).
        """
        checkpoint = {'optimizer': self.optimizer.state_dict()}
        if self.grad_scaler:
            checkpoint['grad_scaler'] = self.grad_scaler.state_dict()
        checkpoint['fp32_from_fp16_params'] = self.fp32_from_float16_groups
        return checkpoint


    def load_state_dict(self, state_dict):
        """Restore optimizer, grad-scaler and fp32 main-param state from a checkpoint.

        Older checkpoint layouts are accepted:

        * the optimizer state may be stored under ``'optimizer_state_dict'``
          instead of ``'optimizer'``;
        * the fp32 main params may be stored under ``'fp32_from_fp16'``
          instead of ``'fp32_from_fp16_params'``.

        A missing ``'grad_scaler'`` entry is tolerated. A warning is printed
        when ``self.fp16`` is set.

        Args:
            state_dict: mapping in the layout produced by ``state_dict()``.
        """
        # Optimizer.
        optimizer_key = 'optimizer'
        if optimizer_key not in state_dict:
            optimizer_key = 'optimizer_state_dict'
            print_rank_0('***WARNING*** loading optimizer from '
                         'an old checkpoint ...')
        self.optimizer.load_state_dict(state_dict[optimizer_key])

        # Grad scaler.
        if 'grad_scaler' not in state_dict:
            if self.fp16:
                print_rank_0('***WARNING*** found an old checkpoint, will not '
                             'load grad scaler ...')
        elif self.grad_scaler:
            self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
        else:
            # The original message misspelled "found" as "fould".
            print_rank_0('***WARNING*** found the grad scaler in the '
                         'checkpoint but it is None in the class. '
                         'Skipping loading grad scaler ...')

        # Copy data for the main params (in place, so existing references to
        # these tensors stay valid).
        fp32_from_float16_params_key = 'fp32_from_fp16_params'
        if fp32_from_float16_params_key not in state_dict:
            fp32_from_float16_params_key = 'fp32_from_fp16'
        for current_group, saved_group in zip(
                self.fp32_from_float16_groups,
                state_dict[fp32_from_float16_params_key]):
            for current_param, saved_param in zip(current_group, saved_group):
                current_param.data.copy_(saved_param.data)


mohammad's avatar
mohammad committed
694
695
class FP32Optimizer(MegatronOptimizer):
    """Optimizer wrapper for parameters that are stepped directly in fp32.

    No loss scaling is applied: ``get_loss_scale`` returns a constant 1.0
    tensor, and ``step`` always reports success (no overflow check).
    """

    def __init__(self, optimizer, clip_grad,
                 log_num_zeros_in_grad,
                 params_have_main_grad,
                 use_contiguous_buffers_in_local_ddp,
                 models):

        super(FP32Optimizer, self).__init__(
            optimizer, clip_grad, log_num_zeros_in_grad,
            params_have_main_grad, use_contiguous_buffers_in_local_ddp,
            models)

        # Constant "loss scale" reported by get_loss_scale().
        self._scale = torch.cuda.FloatTensor([1.0])


    def zero_grad(self, set_to_none=True):
        """Copied from torch.optim.optimizer

        Zero (or, when ``set_to_none`` is True, drop) the gradients of every
        param in the wrapped optimizer's param groups.
        """
        for group in self.optimizer.param_groups:
            _zero_grad_group_helper(group['params'], set_to_none)


    def get_loss_scale(self):
        """FP32 optimizer does not do any scaling."""
        return self._scale


    @torch.no_grad()
    def step(self, args, timers):
        """Clip gradients (if needed) and step the base optimizer.

        Always returns success, since there is no overflow to detect.

        Args:
            args: run configuration; only ``args.barrier_with_L1_time`` is
                read here.
            timers: timer factory called as ``timers(name, log_level=1)`` to
                start a timer and ``timers(name)`` to stop it.

        Returns:
            tuple ``(True, grad_norm, num_zeros_in_grad)``, where:

            * ``grad_norm`` is None unless ``self.clip_grad > 0``;
            * ``num_zeros_in_grad`` is None unless
              ``self.log_num_zeros_in_grad`` is set.
        """

        # Copy main_grads to grads.
        timers('optimizer-copy-to-main-grad', log_level=1).start(
            barrier=args.barrier_with_L1_time)
        if self.params_have_main_grad:
            for param_group in self.optimizer.param_groups:
                for param in param_group['params']:
                    # Same guard as the float16 optimizer's grad copy: a param
                    # lacking a ``main_grad`` attribute keeps its current
                    # ``.grad`` instead of raising AttributeError.
                    if not hasattr(param, 'main_grad'):
                        continue
                    param.grad = param.main_grad

                    # Safe to de-reference model's main_grad after copying.
                    # (If using contiguous buffers, main_grad's memory should
                    # persist and therefore should not be deallocated.)
                    if not self.use_contiguous_buffers_in_local_ddp:
                        param.main_grad = None
        timers('optimizer-copy-to-main-grad').stop()

        # Clip gradients.
        timers('optimizer-clip-main-grad', log_level=1).start(
            barrier=args.barrier_with_L1_time)
        grad_norm = None
        if self.clip_grad > 0.0:
            grad_norm = self.clip_grad_norm(self.clip_grad)
        timers('optimizer-clip-main-grad').stop()

        # count the zeros in the grads
        timers('optimizer-count-zeros', log_level=1).start(
            barrier=args.barrier_with_L1_time)
        num_zeros_in_grad = self.count_zeros() if \
                            self.log_num_zeros_in_grad else None
        timers('optimizer-count-zeros').stop()

        # Update parameters.
        timers('optimizer-inner-step', log_level=1).start(
            barrier=args.barrier_with_L1_time)
        self.optimizer.step()
        timers('optimizer-inner-step').stop()

        # No overflow for FP32 optimizer.
        return True, grad_norm, num_zeros_in_grad


    def reload_model_params(self):
        """No-op: this class keeps no separate main-param copies to refresh."""
        pass


    def state_dict(self):
        """Return the wrapped optimizer's state dict, with no extra entries."""
        return self.optimizer.state_dict()


    def load_state_dict(self, state_dict):
        """Load ``state_dict`` straight into the wrapped optimizer."""
        self.optimizer.load_state_dict(state_dict)