optimizer.py 30.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Megatron optimizer."""
mohammad's avatar
mohammad committed
17
18
19
20
21

from abc import ABC
from abc import abstractmethod
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
22
23
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
Lawrence McAfee's avatar
Lawrence McAfee committed
24
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
mohammad's avatar
mohammad committed
25

mohammad's avatar
mohammad committed
26
27
from megatron import get_timers
from megatron import mpu
mohammad's avatar
mohammad committed
28
from megatron import print_rank_0
29
30
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
31
32
from megatron.model.module import param_is_not_shared
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
33
from megatron.utils import unwrap_model
34

35
from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32
36

Lawrence McAfee's avatar
Lawrence McAfee committed
37

mohammad's avatar
mohammad committed
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
def _zero_grad_group_helper(group, set_to_none):
    """Zero out the gradient for a group of parameters.
    Note: copied from torch.optim.optimizer."""
    for param in group:
        if param.grad is not None:
            if set_to_none:
                param.grad = None
            else:
                if param.grad.grad_fn is not None:
                    param.grad.detach_()
                else:
                    param.grad.requires_grad_(False)
                param.grad.zero_()


53
def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
54
55
56
57
    """Use multi-tensor-applier to copy values from one list to another.
    We don't have a blfoat16 implementation so for now if the overflow_buf
    is not provided, we default back to simple loop copy to be compatible
    with bfloat16."""
58
59
    if overflow_buf:
        overflow_buf.fill_(0)
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
60
61
62
63
64
        # Scaling with factor `1.0` is equivalent to copy.
        multi_tensor_applier(amp_C.multi_tensor_scale,
                             overflow_buf,
                             [this, that],
                             1.0)
65
    else:
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
66
67
68
        for this_, that_ in zip(this, that):
            that_.copy_(this_)

69

mohammad's avatar
mohammad committed
70
71
72

class MegatronOptimizer(ABC):
    """Abstract wrapper around a base torch optimizer (e.g. Adam).

    Subclasses implement zero_grad, get_loss_scale, reload_model_params,
    state_dict, load_state_dict and step. This base class provides
    gradient-norm clipping, zero counting, promotion of the wrapped
    optimizer's `state` / `param_groups` attributes, and the gradient
    all-reduce helpers driven by `reduce_model_grads`.
    """

    def __init__(self, optimizer, clip_grad,
                 log_num_zeros_in_grad,
                 params_have_main_grad,
                 use_contiguous_buffers_in_local_ddp,
                 models):
        """Input optimizer is the base optimizer for example Adam.

        Arguments:
            optimizer: base optimizer such as Adam or SGD; must be truthy.
            clip_grad: global L2 norm passed to `clip_grad_norm`.
            log_num_zeros_in_grad: whether zeros in the grads are counted.
            params_have_main_grad: whether model params carry a
                `main_grad` field.
            use_contiguous_buffers_in_local_ddp: whether the local DDP model
                keeps its grads in a contiguous buffer. Requires
                `params_have_main_grad`.
            models: list of model chunks, retained for access to the
                contiguous grad buffers (see distributed optimizer).
        """
        self.optimizer = optimizer
        assert self.optimizer, 'no optimizer is provided.'

        # Set gradient clipping and logging params.
        self.clip_grad = clip_grad
        self.log_num_zeros_in_grad = log_num_zeros_in_grad
        self.params_have_main_grad = params_have_main_grad
        self.use_contiguous_buffers_in_local_ddp = use_contiguous_buffers_in_local_ddp

        # 'models' are retained for access to the contiguous grad buffers.
        # (see distributed optimizer)
        self.models = models

        if self.use_contiguous_buffers_in_local_ddp:
            assert self.params_have_main_grad, \
                "use of contiguous buffer requires that params have main grad"

    def get_parameters(self):
        """Return the params of every param group, in group order."""
        return [param
                for param_group in self.optimizer.param_groups
                for param in param_group['params']]

    def get_main_grads_for_grad_norm(self):
        """Collect the grads that contribute to the global grad norm.

        A param's `.grad` is kept only if it is not None, the param is not
        shared (`param_is_not_shared`), and it is not a replica due to
        tensor model parallelism (`param_is_not_tensor_parallel_duplicate`).
        """
        params = self.get_parameters()
        grads_for_norm = []
        for param in params:
            grad = param.grad
            grad_not_none = grad is not None
            is_not_shared = param_is_not_shared(param)
            is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
            if grad_not_none and is_not_shared and is_not_tp_duplicate:
                grads_for_norm.append(grad)

        return grads_for_norm

    def get_model_parallel_group(self):
        """Default returned here, but the distributed optimizer overrides this."""
        return mpu.get_model_parallel_group()

    def clip_grad_norm(self, clip_grad):
        """Clip grads to `clip_grad` and return the value from clip_grad_norm_fp32."""
        params = self.get_parameters()
        grads_for_norm = self.get_main_grads_for_grad_norm()
        return clip_grad_norm_fp32(
            params, grads_for_norm, clip_grad,
            model_parallel_group=self.get_model_parallel_group())

    def count_zeros(self):
        """Return the zero count computed by count_zeros_fp32 over all params."""
        params = self.get_parameters()
        return count_zeros_fp32(params,
                                model_parallel_group=self.get_model_parallel_group())

    @abstractmethod
    def zero_grad(self, set_to_none=True):
        """Reset the gradients of the optimized parameters."""
        pass

    @abstractmethod
    def get_loss_scale(self):
        """The output should be a cuda tensor of size 1."""
        pass

    def scale_loss(self, loss):
        """Simple scaling."""
        return self.get_loss_scale() * loss

    @abstractmethod
    def reload_model_params(self):
        """Refreshes any internal state from the current model parameters.
        Call whenever the parameters are changed outside of the optimizer.
        For example, when we load a model from a checkpoint  without loading
        the optimizer, the model parameters are updated but for fp16 optimizer
        with main parameters, the main parameters need to also be updated."""
        pass

    @abstractmethod
    def state_dict(self):
        """Return the optimizer state to be checkpointed."""
        pass

    @abstractmethod
    def load_state_dict(self, state_dict):
        """Restore optimizer state produced by `state_dict`."""
        pass

    # Promote state so it can be retrieved or set via
    # "optimizer_instance.state"
    def _get_state(self):
        return self.optimizer.state

    def _set_state(self, value):
        self.optimizer.state = value

    state = property(_get_state, _set_state)

    # Promote param_groups so it can be retrieved or set via
    # "optimizer_instance.param_groups"
    # (for example, to adjust the learning rate)
    def _get_param_groups(self):
        return self.optimizer.param_groups

    def _set_param_groups(self, value):
        self.optimizer.param_groups = value

    param_groups = property(_get_param_groups, _set_param_groups)

    @abstractmethod
    def step(self, args, timers):
        """Run one optimizer step."""
        pass

    def gather_model_params(self, args, timers):
        """
        For the case of a non-distributed-optimizer, there is nothing to
        do here.
        """
        pass

    def allreduce_word_embedding_grads(self, args):
        """
        All-reduce word embedding grads.

        Reduce grads across first and last stages to ensure that word_embeddings
        parameters stay in sync. This should only run for models that support
        pipelined model parallelism (BERT and GPT-2).
        """
        if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
                mpu.get_pipeline_model_parallel_world_size() > 1:
            # Pick the model chunk that owns the word embeddings on this stage.
            if mpu.is_pipeline_first_stage(ignore_virtual=True):
                unwrapped_model = self.models[0]
            elif mpu.is_pipeline_last_stage(ignore_virtual=True):
                unwrapped_model = self.models[-1]
            else:  # We do not support the interleaved schedule for T5 yet.
                unwrapped_model = self.models[0]
            unwrapped_model = unwrap_model(
                unwrapped_model, (torchDDP, LocalDDP, Float16Module))

            if unwrapped_model.share_word_embeddings:
                word_embeddings_weight = unwrapped_model.word_embeddings_weight()
                # Local DDP accumulates into `main_grad`; torch DDP uses `grad`.
                if args.DDP_impl == 'local':
                    grad = word_embeddings_weight.main_grad
                else:
                    grad = word_embeddings_weight.grad
                torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())

    def allreduce_position_embedding_grads(self, args):
        """
        All-reduce position_embeddings grad across first (encoder) and
        split (decoder) stages to ensure that position embeddings parameters
        stay in sync. This should only run for T5 models with pipeline
        parallelism.
        """
        if mpu.is_rank_in_position_embedding_group() and \
                mpu.get_pipeline_model_parallel_world_size() > 1 and \
                args.pipeline_model_parallel_split_rank is not None:
            unwrapped_model = self.models[0]
            unwrapped_model = unwrap_model(
                unwrapped_model, (torchDDP, LocalDDP, Float16Module))
            assert args.DDP_impl == 'local', \
                'T5 model is only supported with local DDP mode'
            grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
            torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())

    def allreduce_embedding_grads(self, args):
        """All-reduce both word and position embeddings."""
        self.allreduce_word_embedding_grads(args)
        self.allreduce_position_embedding_grads(args)

    def allreduce_layernorm_grads(self, args):
        """All-reduce layernorm grads (for sequence parallelism).

        Only runs when the tensor-model-parallel world size is > 1 and
        `args.sequence_parallel` is set. Grads of every parameter flagged
        `sequence_parallel` are flattened into one buffer, all-reduced over
        the tensor-model-parallel group, and copied back in place.
        """
        # All-reduce layernorm parameters across model parallel nodes
        # when sequence parallelism is used
        if mpu.get_tensor_model_parallel_world_size() > 1 and \
                args.sequence_parallel:
            grads = []
            for model_module in self.models:
                unwrapped_model = unwrap_model(
                    model_module, (torchDDP, LocalDDP, Float16Module))
                for param in unwrapped_model.parameters():
                    if getattr(param, 'sequence_parallel', False):
                        grad = param.main_grad if args.DDP_impl == 'local' else param.grad
                        grads.append(grad.data)

            # Flattening an empty list fails (it concatenates zero tensors),
            # so skip when no sequence-parallel params exist. Assumes every
            # tensor-parallel rank holds the same set of such params, so all
            # ranks skip the collective together.
            if not grads:
                return

            coalesced = _flatten_dense_tensors(grads)
            torch.distributed.all_reduce(
                coalesced, group=mpu.get_tensor_model_parallel_group())
            for buf, synced in zip(grads, _unflatten_dense_tensors(
                    coalesced, grads)):
                buf.copy_(synced)

    def reduce_model_grads(self, args, timers):
        """All-reduce all grads, and all-reduce embeddings."""

        # All-reduce layer-norm grads (for sequence parallelism).
        timers('layernorm-grads-all-reduce', log_level=1).start(
            barrier=args.barrier_with_L1_time)
        self.allreduce_layernorm_grads(args)
        timers('layernorm-grads-all-reduce').stop()

        # All-reduce if needed.
        if args.DDP_impl == 'local':
            timers('grads-all-reduce', log_level=1).start(
                barrier=args.barrier_with_L1_time)
            for model in self.models:
                model.allreduce_gradients()
            timers('grads-all-reduce').stop()

        # All-reduce embedding grads.
        timers('embedding-grads-all-reduce', log_level=1).start(
            barrier=args.barrier_with_L1_time)
        self.allreduce_embedding_grads(args)
        timers('embedding-grads-all-reduce').stop()
315

316

317
class MixedPrecisionOptimizer(MegatronOptimizer):
    """Base class for both the float-16 and the distributed optimizer.

    Arguments:
        optimizer: base optimizer such as Adam or SGD
        clip_grad: clip gradients with this global L2 norm. Note
            that clipping is ignored if clip_grad == 0
        log_num_zeros_in_grad: return number of zeros in the gradients.
        params_have_main_grad: flag indicating if parameters have
            a `main_grad` field. If this is set, we are assuming
            that the model gradients are stored in the `main_grad`
            field instead of the typical `grad` field. This happens
            for the DDP cases where there is a contiguous buffer
            holding the gradients. For example for bfloat16, we want
            to do gradient accumulation and all-reduces in float32
            and as a result we store those gradients in the main_grad.
            Note that main grad is not necessarily in float32.
        use_contiguous_buffers_in_local_ddp: if true, the local DDP model
            is using a contiguous buffer to hold the model grads.
        fp16: if true, the model is running in fp16.
        bf16: if true, the model is running in bfloat16.
        grad_scaler: used for scaling gradients. Note that this can be
            None. This case happens when `bf16 = True` and we don't
            use any loss scale. Note that for `bf16 = True`, we can have
            a constant gradient scaler. Also for `bf16 = False`, we
            always require a grad scaler.
        models: list of models (i.e., the virtual pipelining models). This
            is used by the distributed optimizer for mapping parameters.

    Subclasses provide `_copy_model_grads_to_main_grads`,
    `_collect_main_grad_data_for_unscaling`,
    `_copy_main_params_to_model_params` and
    `_copy_model_params_to_main_params`, which `step`,
    `_unscale_main_grads_and_check_for_nan` and `reload_model_params` call.
    """

    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                 fp16, bf16, grad_scaler,
                 models):

        super().__init__(
            optimizer, clip_grad, log_num_zeros_in_grad,
            params_have_main_grad, use_contiguous_buffers_in_local_ddp,
            models)

        self.fp16 = fp16
        self.bf16 = bf16
        self.grad_scaler = grad_scaler

        # None grad scaler is only supported for bf16.
        if self.grad_scaler is None:
            assert not self.fp16, 'fp16 expects a grad scaler.'

        # Tensor used to determine if a nan/inf has happened.
        # Any non-zero value indicates inf/nan.
        # It is only allocated (and only read) when a grad scaler is
        # present; this includes bfloat16 runs that use a grad scaler.
        if self.grad_scaler is not None:
            self.found_inf = torch.cuda.FloatTensor([0.0])

        # Dummy tensor needed for apex multi-apply tensor.
        # For bfloat, we don't have multi-tensor apply and for now
        # we set it to none so the multi-tensor apply gets ignored.
        if self.bf16:
            self._dummy_overflow_buf = None
        else:
            self._dummy_overflow_buf = torch.cuda.IntTensor([0])

        # In case grad scaler is not passed, define the unity scale.
        if self.grad_scaler is None:
            self._scale_one = torch.cuda.FloatTensor([1.0])

    def get_loss_scale(self):
        """Return the grad scaler's scale, or a unity tensor without one."""
        if self.grad_scaler is None:
            return self._scale_one
        return self.grad_scaler.scale

    def reload_model_params(self):
        """Re-copy the model params into the main params."""
        self._copy_model_params_to_main_params()

    def _unscale_main_grads_and_check_for_nan(self):
        """Unscale the main grads in place and report whether any is inf/nan.

        Only valid when a grad scaler is set (uses `self.found_inf` and
        `self.grad_scaler.inv_scale`). The flag is MAX-reduced across the
        model-parallel group so every rank returns the same answer.

        Returns:
            True if any main grad on any model-parallel rank is inf/nan.
        """
        # Collect main grads.
        main_grads = self._collect_main_grad_data_for_unscaling()

        # Reset found inf.
        self.found_inf.fill_(0.0)

        # Unscale and set found inf/nan
        torch._amp_foreach_non_finite_check_and_unscale_(
            main_grads, self.found_inf, self.grad_scaler.inv_scale)

        # Update across all model parallel instances.
        torch.distributed.all_reduce(self.found_inf,
                                     op=torch.distributed.ReduceOp.MAX,
                                     group=self.get_model_parallel_group())

        # Check for nan.
        found_inf_flag = (self.found_inf.item() > 0)

        return found_inf_flag

    @torch.no_grad()
    def step(self, args, timers):
        """Run one mixed-precision optimizer step.

        Copies model grads to main grads, unscales them and checks for
        inf/nan (only with a grad scaler), clips, optionally counts zeros,
        steps the base optimizer, and copies main params back to the model.

        Returns:
            (update_successful, grad_norm, num_zeros_in_grad). When inf/nan
            is found the update is skipped and (False, None, None) is
            returned; grad_norm is None when clip_grad <= 0 and
            num_zeros_in_grad is None unless log_num_zeros_in_grad is set.
        """
        # Copy gradients from model params to main params.
        timers('optimizer-copy-to-main-grad', log_level=1).start(
            barrier=args.barrier_with_L1_time)
        self._copy_model_grads_to_main_grads()
        timers('optimizer-copy-to-main-grad').stop()

        # Do unscale, check for inf, and update grad scaler only for
        # the case that grad scaler is provided.
        if self.grad_scaler is not None:

            # Unscale and check for inf/nan.
            timers('optimizer-unscale-and-check-inf', log_level=1).start(
                barrier=args.barrier_with_L1_time)
            found_inf_flag = self._unscale_main_grads_and_check_for_nan()
            timers('optimizer-unscale-and-check-inf').stop()

            # We are done with scaling gradients
            # so we can update the loss scale.
            self.grad_scaler.update(found_inf_flag)

            # If we found inf/nan, skip the update.
            if found_inf_flag:
                return False, None, None

        # Clip the main gradients.
        timers('optimizer-clip-main-grad', log_level=1).start(
            barrier=args.barrier_with_L1_time)
        grad_norm = None
        if self.clip_grad > 0.0:
            grad_norm = self.clip_grad_norm(self.clip_grad)
        timers('optimizer-clip-main-grad').stop()

        # Count the zeros in the grads.
        timers('optimizer-count-zeros', log_level=1).start(
            barrier=args.barrier_with_L1_time)
        num_zeros_in_grad = self.count_zeros() if \
                            self.log_num_zeros_in_grad else None
        timers('optimizer-count-zeros').stop()

        # Step the optimizer.
        timers('optimizer-inner-step', log_level=1).start(
            barrier=args.barrier_with_L1_time)
        self.optimizer.step()
        timers('optimizer-inner-step').stop()

        # Update params from main params.
        timers('optimizer-copy-main-to-model-params', log_level=1).start(
            barrier=args.barrier_with_L1_time)
        self._copy_main_params_to_model_params()
        timers('optimizer-copy-main-to-model-params').stop()

        # Successful update.
        return True, grad_norm, num_zeros_in_grad


476
class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
Lawrence McAfee's avatar
Lawrence McAfee committed
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
    """Float16 optimizer for fp16 and bf16 data types.

    Arguments:
        optimizer: base optimizer such as Adam or SGD
        clip_grad: clip gradeints with this global L2 norm. Note
            that clipping is ignored if clip_grad == 0
        log_num_zeros_in_grad: return number of zeros in the gradients.
        params_have_main_grad: flag indicating if parameters have
            a `main_grad` field. If this is set, we are assuming
            that the model parameters are store in the `main_grad`
            field instead of the typical `grad` field. This happens
            for the DDP cases where there is a continuous buffer
            holding the gradients. For example for bfloat16, we want
            to do gradient accumulation and all-reduces in float32
            and as a result we store those gradients in the main_grad.
            Note that main grad is not necessarily in float32.
Lawrence McAfee's avatar
Lawrence McAfee committed
493
494
495
        use_contiguous_buffers_in_local_ddp: if true, the local DDP model
            is using a contiguous buffer to hold the model grads.
        fp16: if true, the model is running in fp16.
Lawrence McAfee's avatar
Lawrence McAfee committed
496
497
498
499
500
501
        bf16: if true, the model is running in bfloat16.
        grad_scaler: used for scaling gradients. Note that this can be
            None. This case happens when `bf16 = True` and we don't
            use any loss scale. Note that for `bf16 = True`, we can have
            a constnat gradient scaler. Also for `bf16 = False`, we
            always require a grad scaler.
Lawrence McAfee's avatar
Lawrence McAfee committed
502
503
        models: list of models (i.e., the virtual pipelining models). This
            is used by the distributed optimizer for mapping parameters.
Lawrence McAfee's avatar
Lawrence McAfee committed
504
505
506
507
    """

    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                 fp16, bf16, grad_scaler, models):

        super().__init__(
            optimizer, clip_grad, log_num_zeros_in_grad,
            params_have_main_grad, use_contiguous_buffers_in_local_ddp,
            fp16, bf16, grad_scaler, models)

        # Per-param-group bookkeeping; each list gets one entry per param
        # group of the base optimizer:
        #   float16_groups: original float16 (half / bfloat16) model params
        #   fp32_from_float16_groups: fp32 copies of those params, which
        #       take their place inside the base optimizer
        #   fp32_from_fp32_groups: params that were already fp32
        self.float16_groups = []
        self.fp32_from_float16_groups = []
        self.fp32_from_fp32_groups = []

        float16_type_names = ('torch.cuda.HalfTensor',
                              'torch.cuda.BFloat16Tensor')
        optimizer_state = self.optimizer.state

        for param_group in self.optimizer.param_groups:
            group_params = param_group['params']
            float16_params = []
            main_params = []
            fp32_params = []

            for index, param in enumerate(group_params):
                # Params that do not require grad stay in the group as-is
                # and are not tracked by this optimizer.
                if not param.requires_grad:
                    continue

                type_name = param.type()

                # fp32 params are optimized directly.
                if type_name == 'torch.cuda.FloatTensor':
                    fp32_params.append(param)
                    continue

                if type_name not in float16_type_names:
                    raise TypeError('Wrapped parameters must be one of '
                                    'torch.cuda.FloatTensor,  '
                                    'torch.cuda.HalfTensor, or '
                                    'torch.cuda.BFloat16Tensor. '
                                    'Received {}'.format(type_name))

                # float16 param: build an fp32 copy that carries the same
                # tensor-model-parallel and sharing attributes.
                main_param = param.detach().clone().float()
                mpu.copy_tensor_model_parallel_attributes(main_param, param)
                if hasattr(param, 'shared'):
                    main_param.shared = param.shared

                # The fp32 copy replaces the float16 param in the optimizer.
                group_params[index] = main_param
                float16_params.append(param)
                main_params.append(main_param)

                # Re-key any existing optimizer state to the new main param.
                if param in optimizer_state:
                    optimizer_state[main_param] = optimizer_state.pop(param)

            self.float16_groups.append(float16_params)
            self.fp32_from_float16_groups.append(main_params)
            self.fp32_from_fp32_groups.append(fp32_params)


Lawrence McAfee's avatar
Lawrence McAfee committed
573
574
575
576
577
578
579
580
581
582
583
584
    def zero_grad(self, set_to_none=True):
        """Reset gradients of the tracked parameter groups.

        Strictly only the model-facing groups (float16_groups and
        fp32_from_fp32_groups) need zeroing. fp32_from_float16_groups is
        zeroed as well as a memory optimization to reduce fragmentation:
        with set_to_none=True its gradient storage can be released here.
        """
        all_groups = (self.float16_groups
                      + self.fp32_from_float16_groups
                      + self.fp32_from_fp32_groups)
        for group in all_groups:
            _zero_grad_group_helper(group, set_to_none)
mohammad's avatar
mohammad committed
585
586


587
    def _collect_main_grad_data_for_unscaling(self):
        """Return the `.grad.data` of every main param that has a grad.

        Order: fp32 copies of the float16 params first, then the original
        fp32 params, each in group order.
        """
        main_groups = self.fp32_from_float16_groups + self.fp32_from_fp32_groups
        return [
            main_param.grad.data
            for main_group in main_groups
            for main_param in main_group
            if main_param.grad is not None
        ]
604
605


606
607
608
609
610
611
612
613
614
    def _get_model_and_main_params_data_float16(self):
        """Return parallel lists of `.data` for float16 model params and their
        fp32 main copies, paired in group order."""
        param_pairs = [
            (model_param, main_param)
            for model_group, main_group in zip(self.float16_groups,
                                               self.fp32_from_float16_groups)
            for model_param, main_param in zip(model_group, main_group)
        ]
        model_data = [model_param.data for model_param, _ in param_pairs]
        main_data = [main_param.data for _, main_param in param_pairs]
        return model_data, main_data
615

Lawrence McAfee's avatar
Lawrence McAfee committed
616

617
    def _copy_model_grads_to_main_grads(self):
        """Move model gradients onto the params the inner optimizer steps.

        float16 params:
            The model param's `main_grad` is used when main grads are
            enabled and the param has that attribute; otherwise its `.grad`
            is used if present. Either one is upcast with `.float()` into the
            matching fp32 main param. The model-side references are then
            dropped.
        fp32 params (only when main grads are enabled):
            `.grad` is pointed at `.main_grad`.

        When contiguous buffers are used in local DDP, `main_grad` is never
        cleared, because that memory must persist.
        """
        use_main_grad = self.params_have_main_grad
        drop_float16_main_grad = use_main_grad and \
            not self.use_contiguous_buffers_in_local_ddp

        # Only the float16 group needs an actual dtype conversion.
        for model_group, main_group in zip(self.float16_groups,
                                           self.fp32_from_float16_groups):
            for model_param, main_param in zip(model_group, main_group):
                if use_main_grad and hasattr(model_param, 'main_grad'):
                    main_param.grad = model_param.main_grad.float()
                elif model_param.grad is not None:
                    main_param.grad = model_param.grad.float()

                # The fp32 copy now holds the gradient, so the model-side
                # references can go.
                model_param.grad = None
                if drop_float16_main_grad:
                    model_param.main_grad = None

        if not use_main_grad:
            return

        # fp32 params are stepped directly: alias grad to main_grad.
        for model_group in self.fp32_from_fp32_groups:
            for model_param in model_group:
                model_param.grad = model_param.main_grad
                # Contiguous-buffer main_grads must persist; otherwise the
                # reference is dropped.
                if not self.use_contiguous_buffers_in_local_ddp:
                    model_param.main_grad = None
mohammad's avatar
mohammad committed
647

648

649
    def _copy_main_params_to_model_params(self):
        """Write the fp32 main weights back into the float16 model weights.

        Only the float16 params have a separate fp32 main copy to sync.
        """
        model_tensors, main_tensors = \
            self._get_model_and_main_params_data_float16()
        _multi_tensor_copy_this_to_that(
            that=model_tensors,
            this=main_tensors,
            overflow_buf=self._dummy_overflow_buf)


    def _copy_model_params_to_main_params(self):
        """Refresh the fp32 main weights from the float16 model weights.

        Only the float16 params have a separate fp32 main copy to sync.
        """
        model_tensors, main_tensors = \
            self._get_model_and_main_params_data_float16()
        _multi_tensor_copy_this_to_that(
            that=main_tensors,
            this=model_tensors,
            overflow_buf=self._dummy_overflow_buf)
661
662


mohammad's avatar
mohammad committed
663
664
665
    def state_dict(self):
        """Build a checkpointable dict for this optimizer.

        Keys, in insertion order:
            'optimizer': the wrapped optimizer's state_dict().
            'grad_scaler': the grad scaler's state_dict(); present only
                when a grad scaler is set.
            'fp32_from_fp16_params': the fp32 main-param groups. This is the
                live list object, not a copy.
        """
        checkpoint = {'optimizer': self.optimizer.state_dict()}
        if self.grad_scaler:
            checkpoint['grad_scaler'] = self.grad_scaler.state_dict()
        checkpoint['fp32_from_fp16_params'] = self.fp32_from_float16_groups
        return checkpoint


    def load_state_dict(self, state_dict):
        """Restore optimizer, grad-scaler and fp32 main-param state.

        Both current and legacy key names are accepted:
            'optimizer' (legacy: 'optimizer_state_dict')
            'fp32_from_fp16_params' (legacy: 'fp32_from_fp16')
        A missing 'grad_scaler' entry is tolerated. A warning is printed
        when self.fp16 is set.

        Args:
            state_dict: checkpoint dict, as produced by `state_dict()` or by
                an older checkpoint using the legacy keys.
        """
        # Optimizer.
        optimizer_key = 'optimizer'
        if optimizer_key not in state_dict:
            optimizer_key = 'optimizer_state_dict'
            print_rank_0('***WARNING*** loading optimizer from '
                         'an old checkpoint ...')
        self.optimizer.load_state_dict(state_dict[optimizer_key])

        # Grad scaler.
        if 'grad_scaler' not in state_dict:
            if self.fp16:
                print_rank_0('***WARNING*** found an old checkpoint, will not '
                             'load grad scaler ...')
        elif self.grad_scaler:
            self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
        else:
            print_rank_0('***WARNING*** found the grad scaler in the '
                         'checkpoint but it is None in the class. '
                         'Skipping loading grad scaler ...')

        # Copy data for the main params.
        fp32_from_float16_params_key = 'fp32_from_fp16_params'
        if fp32_from_float16_params_key not in state_dict:
            fp32_from_float16_params_key = 'fp32_from_fp16'
        for current_group, saved_group in zip(
                self.fp32_from_float16_groups,
                state_dict[fp32_from_float16_params_key]):
            for current_param, saved_param in zip(current_group, saved_group):
                # In-place copy: the existing tensor objects are updated
                # rather than replaced.
                current_param.data.copy_(saved_param.data)


mohammad's avatar
mohammad committed
705
706
class FP32Optimizer(MegatronOptimizer):
    """Megatron optimizer wrapper used when training in plain fp32.

    It does no loss scaling and keeps no separate main-param copy: the
    wrapped optimizer steps the params in its own param_groups directly.
    """

    def __init__(self, optimizer, clip_grad,
                 log_num_zeros_in_grad,
                 params_have_main_grad,
                 use_contiguous_buffers_in_local_ddp,
                 models):

        super().__init__(
            optimizer, clip_grad, log_num_zeros_in_grad,
            params_have_main_grad, use_contiguous_buffers_in_local_ddp,
            models)

        # Constant loss scale of 1.0, reported by get_loss_scale().
        self._scale = torch.cuda.FloatTensor([1.0])


    def zero_grad(self, set_to_none=True):
        """Clear the grads of every param the wrapped optimizer owns.

        Copied from torch.optim.optimizer."""
        param_lists = (group['params']
                       for group in self.optimizer.param_groups)
        for params in param_lists:
            _zero_grad_group_helper(params, set_to_none)


    def get_loss_scale(self):
        """Return the constant 1.0 scale; the FP32 optimizer never scales."""
        return self._scale


    @torch.no_grad()
    def step(self, args, timers):
        """Clip gradients (if requested) and step the wrapped optimizer.

        Returns:
            (success, grad_norm, num_zeros_in_grad). success is always True
            because fp32 training has no overflow to detect. grad_norm is
            None unless clip_grad > 0. num_zeros_in_grad is None unless
            log_num_zeros_in_grad is set.
        """
        # Copy main_grads to grads.
        timers('optimizer-copy-to-main-grad', log_level=1).start(
            barrier=args.barrier_with_L1_time)
        if self.params_have_main_grad:
            for group in self.optimizer.param_groups:
                for param in group['params']:
                    param.grad = param.main_grad
                    # Contiguous-buffer main_grads must persist; otherwise
                    # the reference is dropped after copying.
                    if not self.use_contiguous_buffers_in_local_ddp:
                        param.main_grad = None
        timers('optimizer-copy-to-main-grad').stop()

        # Clip gradients.
        timers('optimizer-clip-main-grad', log_level=1).start(
            barrier=args.barrier_with_L1_time)
        grad_norm = (self.clip_grad_norm(self.clip_grad)
                     if self.clip_grad > 0.0 else None)
        timers('optimizer-clip-main-grad').stop()

        # Count the zeros in the grads.
        timers('optimizer-count-zeros', log_level=1).start(
            barrier=args.barrier_with_L1_time)
        if self.log_num_zeros_in_grad:
            num_zeros_in_grad = self.count_zeros()
        else:
            num_zeros_in_grad = None
        timers('optimizer-count-zeros').stop()

        # Update parameters.
        timers('optimizer-inner-step', log_level=1).start(
            barrier=args.barrier_with_L1_time)
        self.optimizer.step()
        timers('optimizer-inner-step').stop()

        # No overflow is possible for the FP32 optimizer.
        return True, grad_norm, num_zeros_in_grad


    def reload_model_params(self):
        """No-op: there is no separate main copy of the params to refresh."""


    def state_dict(self):
        """Return the wrapped optimizer's state_dict unchanged."""
        return self.optimizer.state_dict()


    def load_state_dict(self, state_dict):
        """Load `state_dict` straight into the wrapped optimizer."""
        self.optimizer.load_state_dict(state_dict)