optimizer.py 30.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Megatron optimizer."""
mohammad's avatar
mohammad committed
17
18
19
20
21

from abc import ABC
from abc import abstractmethod
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
22
23
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
Lawrence McAfee's avatar
Lawrence McAfee committed
24
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
mohammad's avatar
mohammad committed
25

mohammad's avatar
mohammad committed
26
27
from megatron import get_timers
from megatron import mpu
mohammad's avatar
mohammad committed
28
from megatron import print_rank_0
29
30
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
31
32
from megatron.model.module import param_is_not_shared
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
33
from megatron.utils import unwrap_model
34

35
from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32
36

Lawrence McAfee's avatar
Lawrence McAfee committed
37

mohammad's avatar
mohammad committed
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
def _zero_grad_group_helper(group, set_to_none):
    """Zero out the gradient for a group of parameters.
    Note: copied from torch.optim.optimizer."""
    for param in group:
        if param.grad is not None:
            if set_to_none:
                param.grad = None
            else:
                if param.grad.grad_fn is not None:
                    param.grad.detach_()
                else:
                    param.grad.requires_grad_(False)
                param.grad.zero_()


53
def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
54
55
56
57
    """Use multi-tensor-applier to copy values from one list to another.
    We don't have a blfoat16 implementation so for now if the overflow_buf
    is not provided, we default back to simple loop copy to be compatible
    with bfloat16."""
58
59
    if overflow_buf:
        overflow_buf.fill_(0)
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
60
61
62
63
64
        # Scaling with factor `1.0` is equivalent to copy.
        multi_tensor_applier(amp_C.multi_tensor_scale,
                             overflow_buf,
                             [this, that],
                             1.0)
65
    else:
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
66
67
68
        for this_, that_ in zip(this, that):
            that_.copy_(this_)

69

mohammad's avatar
mohammad committed
70
71
72

class MegatronOptimizer(ABC):
    """Abstract base class wrapping a torch optimizer for Megatron training.

    Holds the wrapped `optimizer` together with the gradient-clipping,
    logging and grad-buffer settings shared by Megatron optimizers, and
    implements the gradient all-reduce helpers run before an optimizer step.
    Subclasses implement `zero_grad`, `get_loss_scale`, `reload_model_params`,
    `state_dict`, `load_state_dict` and `step`.
    """

    def __init__(self, optimizer, clip_grad,
                 log_num_zeros_in_grad,
                 params_have_main_grad,
                 use_contiguous_buffers_in_local_ddp,
                 models):
        """Input optimizer is the base optimizer, for example Adam.

        Arguments:
            optimizer: base torch optimizer; must not be falsy.
            clip_grad: global L2 norm used for gradient clipping.
            log_num_zeros_in_grad: whether to count zeros in the grads.
            params_have_main_grad: whether params carry a `main_grad` field.
            use_contiguous_buffers_in_local_ddp: whether the local DDP keeps
                grads in a contiguous buffer; requires params_have_main_grad.
            models: list of model chunks, retained for access to the
                contiguous grad buffers and the embedding weights.
        """
        self.optimizer = optimizer
        assert self.optimizer, 'no optimizer is provided.'

        # Set gradient clipping and logging params.
        self.clip_grad = clip_grad
        self.log_num_zeros_in_grad = log_num_zeros_in_grad
        self.params_have_main_grad = params_have_main_grad
        self.use_contiguous_buffers_in_local_ddp = use_contiguous_buffers_in_local_ddp

        # 'models' are retained for access to the contiguous grad buffers.
        # (see distributed optimizer)
        self.models = models

        if self.use_contiguous_buffers_in_local_ddp:
            assert self.params_have_main_grad, \
                "use of contiguous buffer requires that params have main grad"


    def get_parameters(self):
        """Return a flat list of every param in the wrapped optimizer's groups."""
        params = []
        for param_group in self.optimizer.param_groups:
            for param in param_group['params']:
                params.append(param)
        return params


    def get_main_grads_for_grad_norm(self):
        """Return the grads that contribute to the global gradient norm.

        A param's grad is kept only if the grad exists, the param is not
        shared (`param_is_not_shared`), and the param is not a replica
        created by tensor model parallelism
        (`param_is_not_tensor_parallel_duplicate`).
        """
        params = self.get_parameters()
        grads_for_norm = []
        for param in params:
            grad = param.grad
            grad_not_none = grad is not None
            is_not_shared = param_is_not_shared(param)
            is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
            if grad_not_none and is_not_shared and is_not_tp_duplicate:
                grads_for_norm.append(grad)

        return grads_for_norm


    def get_model_parallel_group(self):
        """Default returned here, but the distributed optimizer overrides this."""
        return mpu.get_model_parallel_group()


    def clip_grad_norm(self, clip_grad):
        """Clip the grads of all params to the global L2 norm `clip_grad`.

        Returns the value produced by `clip_grad_norm_fp32`, computed over the
        grads selected by `get_main_grads_for_grad_norm` and reduced across
        `get_model_parallel_group()`.
        """
        params = self.get_parameters()
        grads_for_norm = self.get_main_grads_for_grad_norm()
        return clip_grad_norm_fp32(
            params, grads_for_norm, clip_grad,
            model_parallel_group=self.get_model_parallel_group())


    def count_zeros(self):
        """Count zero grad entries across all params via `count_zeros_fp32`."""
        params = self.get_parameters()
        return count_zeros_fp32(params,
                                model_parallel_group=self.get_model_parallel_group())


    @abstractmethod
    def zero_grad(self, set_to_none=True):
        """Zero (or drop, when `set_to_none`) the gradients this optimizer manages."""
        pass


    @abstractmethod
    def get_loss_scale(self):
        """The output should be a cuda tensor of size 1."""
        pass


    def scale_loss(self, loss):
        """Simple scaling: multiply `loss` by the current loss scale."""
        return self.get_loss_scale() * loss


    @abstractmethod
    def reload_model_params(self):
        """Refreshes any internal state from the current model parameters.
        Call whenever the parameters are changed outside of the optimizer.
        For example, when we load a model from a checkpoint without loading
        the optimizer, the model parameters are updated but for fp16 optimizer
        with main parameters, the main parameters need to also be updated."""
        pass


    @abstractmethod
    def state_dict(self):
        """Return the optimizer state for checkpointing."""
        pass


    @abstractmethod
    def load_state_dict(self, state_dict):
        """Restore optimizer state produced by `state_dict`."""
        pass


    # Promote state so it can be retrieved or set via
    # "optimizer_instance.state"
    def _get_state(self):
        return self.optimizer.state

    def _set_state(self, value):
        self.optimizer.state = value

    state = property(_get_state, _set_state)


    # Promote param_groups so it can be retrieved or set via
    # "optimizer_instance.param_groups"
    # (for example, to adjust the learning rate)
    def _get_param_groups(self):
        return self.optimizer.param_groups

    def _set_param_groups(self, value):
        self.optimizer.param_groups = value

    param_groups = property(_get_param_groups, _set_param_groups)


    @abstractmethod
    def step(self, args, timers):
        """Perform one optimizer update."""
        pass


    def gather_model_params(self, args, timers):
        """
        For the case of a non-distributed-optimizer, there is nothing to
        do here.
        """
        pass


    def allreduce_word_embedding_grads(self, args):
        """
        All-reduce word embedding grads.

        Reduce grads across first and last stages to ensure that word_embeddings
        parameters stay in sync. This should only run for models that support
        pipelined model parallelism (BERT and GPT-2).
        """

        if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
                mpu.get_pipeline_model_parallel_world_size() > 1:
            if mpu.is_pipeline_first_stage(ignore_virtual=True):
                unwrapped_model = self.models[0]
            elif mpu.is_pipeline_last_stage(ignore_virtual=True):
                unwrapped_model = self.models[-1]
            else:  # We do not support the interleaved schedule for T5 yet.
                unwrapped_model = self.models[0]
            unwrapped_model = unwrap_model(
                unwrapped_model, (torchDDP, LocalDDP, Float16Module))

            if unwrapped_model.share_word_embeddings:
                word_embeddings_weight = unwrapped_model.word_embeddings_weight()
                # Local DDP accumulates grads in `main_grad`.
                if args.DDP_impl == 'local':
                    grad = word_embeddings_weight.main_grad
                else:
                    grad = word_embeddings_weight.grad
                torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())


    def allreduce_position_embedding_grads(self, args):
        """
        All-reduce position_embeddings grad across first (encoder) and
        split (decoder) stages to ensure that position embeddings parameters
        stay in sync. This should only run for T5 models with pipeline
        parallelism.
        """
        if mpu.is_rank_in_position_embedding_group() and \
                mpu.get_pipeline_model_parallel_world_size() > 1 and \
                args.pipeline_model_parallel_split_rank is not None:
            unwrapped_model = self.models[0]
            unwrapped_model = unwrap_model(
                unwrapped_model, (torchDDP, LocalDDP, Float16Module))
            assert args.DDP_impl == 'local', \
                'T5 model is only supported with local DDP mode'
            grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
            torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())


    def allreduce_embedding_grads(self, args):
        """All-reduce both word and position embeddings."""
        self.allreduce_word_embedding_grads(args)
        self.allreduce_position_embedding_grads(args)


    def allreduce_layernorm_grads(self, args):
        """All-reduce layernorm grads (for sequence parallelism).

        When tensor model parallelism is active and `args.sequence_parallel`
        is set, the grads of params flagged `sequence_parallel` are flattened
        into one buffer, all-reduced across the tensor-model-parallel group
        with a single collective, and copied back in place.
        """

        # All-reduce layernorm parameters across model parallel nodes
        # when sequence parallelism is used
        if mpu.get_tensor_model_parallel_world_size() > 1 and \
                args.sequence_parallel:
            grads = []
            for model_module in self.models:
                unwrapped_model = unwrap_model(
                    model_module, (torchDDP, LocalDDP, Float16Module))
                for param in unwrapped_model.parameters():
                    if getattr(param, 'sequence_parallel', False):
                        grad = param.main_grad if args.DDP_impl == 'local' else param.grad
                        grads.append(grad.data)
            # Nothing to reduce: flattening an empty list of tensors fails,
            # so skip the collective entirely.
            if not grads:
                return
            coalesced = _flatten_dense_tensors(grads)
            torch.distributed.all_reduce(
                coalesced, group=mpu.get_tensor_model_parallel_group())
            for buf, synced in zip(grads, _unflatten_dense_tensors(
                    coalesced, grads)):
                buf.copy_(synced)


    def reduce_model_grads(self, args, timers):
        """All-reduce all grads, and all-reduce embeddings.

        Order: layernorm grads (sequence parallelism), then data-parallel
        grads via each model's `allreduce_gradients` (local DDP only), then
        the shared word/position embedding grads.
        """

        # All-reduce layer-norm grads (for sequence parallelism).
        timers('backward-layernorm-all-reduce').start()
        self.allreduce_layernorm_grads(args)
        timers('backward-layernorm-all-reduce').stop()

        # All-reduce if needed.
        if args.DDP_impl == 'local':
            timers('backward-params-all-reduce').start()
            for model in self.models:
                model.allreduce_gradients()
            timers('backward-params-all-reduce').stop()

        # All-reduce embedding grads.
        timers('backward-embedding-all-reduce').start()
        self.allreduce_embedding_grads(args)
        timers('backward-embedding-all-reduce').stop()

313

314
class MixedPrecisionOptimizer(MegatronOptimizer):
    """Base class for both the float-16 and the distributed optimizer.

    Subclasses provide `_copy_model_grads_to_main_grads`,
    `_collect_main_grad_data_for_unscaling`,
    `_copy_main_params_to_model_params` and
    `_copy_model_params_to_main_params`, which `step`,
    `_unscale_main_grads_and_check_for_nan` and `reload_model_params` call.

    Arguments:
        optimizer: base optimizer such as Adam or SGD
        clip_grad: clip gradients with this global L2 norm. Note
            that clipping is ignored if clip_grad == 0
        log_num_zeros_in_grad: return number of zeros in the gradients.
        params_have_main_grad: flag indicating if parameters have
            a `main_grad` field. If this is set, we are assuming
            that the model gradients are stored in the `main_grad`
            field instead of the typical `grad` field. This happens
            for the DDP cases where there is a contiguous buffer
            holding the gradients. For example for bfloat16, we want
            to do gradient accumulation and all-reduces in float32
            and as a result we store those gradients in the main_grad.
            Note that main grad is not necessarily in float32.
        use_contiguous_buffers_in_local_ddp: if true, the local DDP model
            is using a contiguous buffer to hold the model grads.
        fp16: if true, the model is running in fp16.
        bf16: if true, the model is running in bfloat16.
        grad_scaler: used for scaling gradients. Note that this can be
            None. This case happens when `bf16 = True` and we don't
            use any loss scale. Note that for `bf16 = True`, we can have
            a constant gradient scaler. Also for `bf16 = False`, we
            always require a grad scaler.
        models: list of models (i.e., the virtual pipelining models). This
            is used by the distributed optimizer for mapping parameters.
    """

    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                 fp16, bf16, grad_scaler,
                 models):

        super().__init__(
            optimizer, clip_grad, log_num_zeros_in_grad,
            params_have_main_grad, use_contiguous_buffers_in_local_ddp,
            models)

        self.fp16 = fp16
        self.bf16 = bf16
        self.grad_scaler = grad_scaler

        # None grad scaler is only supported for bf16.
        if self.grad_scaler is None:
            assert not self.fp16, 'fp16 expects a grad scaler.'

        # Tensor used to record whether an inf/nan has happened; any
        # non-zero value indicates inf/nan. It is only needed when a grad
        # scaler is present (fp16, or bf16 with a scaler): without a scaler,
        # step() skips the inf/nan check entirely.
        if self.grad_scaler is not None:
            self.found_inf = torch.cuda.FloatTensor([0.0])

        # Dummy tensor needed for apex multi-apply tensor.
        # For bfloat16 there is no multi-tensor apply, so the buffer is None
        # and `_multi_tensor_copy_this_to_that` falls back to a loop copy.
        if bf16:
            self._dummy_overflow_buf = None
        else:
            self._dummy_overflow_buf = torch.cuda.IntTensor([0])

        # In case grad scaler is not passed, define the unity scale.
        if self.grad_scaler is None:
            self._scale_one = torch.cuda.FloatTensor([1.0])


    def get_loss_scale(self):
        """Return the grad scaler's current scale, or a unit tensor if none."""
        if self.grad_scaler is None:
            return self._scale_one
        return self.grad_scaler.scale


    def reload_model_params(self):
        """Refresh the main params from the (externally updated) model params."""
        self._copy_model_params_to_main_params()


    def _unscale_main_grads_and_check_for_nan(self):
        """Unscale the main grads in place and report whether any is inf/nan.

        The grads are multiplied by the grad scaler's inverse scale; the
        inf/nan flag is MAX-reduced across the model-parallel group so every
        rank in it reaches the same decision.

        Returns:
            bool: True if an inf/nan was found on any rank in the group.
        """

        # Collect main grads.
        main_grads = self._collect_main_grad_data_for_unscaling()

        # Reset found inf.
        self.found_inf.fill_(0.0)

        # Unscale and set found inf/nan
        torch._amp_foreach_non_finite_check_and_unscale_(
            main_grads, self.found_inf, self.grad_scaler.inv_scale)

        # Update across all model parallel instances.
        torch.distributed.all_reduce(self.found_inf,
                                     op=torch.distributed.ReduceOp.MAX,
                                     group=self.get_model_parallel_group())

        # Check for nan.
        found_inf_flag = (self.found_inf.item() > 0)

        return found_inf_flag


    @torch.no_grad()
    def step(self, args, timers):
        """Run one mixed-precision optimizer step.

        Copies model grads to main grads; when a grad scaler is present,
        unscales them, checks for inf/nan and updates the scaler; then clips,
        optionally counts zeros, steps the wrapped optimizer and copies the
        main params back into the model params.

        Returns:
            (update_successful, grad_norm, num_zeros_in_grad). If inf/nan is
            found the update is skipped and (False, None, None) is returned.
            grad_norm is None unless clip_grad > 0; num_zeros_in_grad is None
            unless log_num_zeros_in_grad is set.
        """

        # Copy gradients from model params to main params.
        timers('optimizer-copy-to-main-grad').start()
        self._copy_model_grads_to_main_grads()
        timers('optimizer-copy-to-main-grad').stop()

        # Do unscale, check for inf, and update grad scaler only for
        # the case that grad scaler is provided.
        if self.grad_scaler is not None:

            # Unscale and check for inf/nan.
            timers('optimizer-unscale-and-check-inf').start()
            found_inf_flag = self._unscale_main_grads_and_check_for_nan()
            timers('optimizer-unscale-and-check-inf').stop()

            # We are done with scaling gradients
            # so we can update the loss scale.
            self.grad_scaler.update(found_inf_flag)

            # If we found inf/nan, skip the update.
            if found_inf_flag:
                return False, None, None

        # Clip the main gradients.
        timers('optimizer-clip-main-grad').start()
        grad_norm = None
        if self.clip_grad > 0.0:
            grad_norm = self.clip_grad_norm(self.clip_grad)
        timers('optimizer-clip-main-grad').stop()

        # Count the zeros in the grads.
        timers('optimizer-count-zeros').start()
        num_zeros_in_grad = self.count_zeros() if \
                            self.log_num_zeros_in_grad else None
        timers('optimizer-count-zeros').stop()

        # Step the optimizer.
        timers('optimizer-inner-step').start()
        self.optimizer.step()
        timers('optimizer-inner-step').stop()

        # Update params from main params.
        timers('optimizer-copy-main-to-model-params').start()
        self._copy_main_params_to_model_params()
        timers('optimizer-copy-main-to-model-params').stop()

        # Successful update.
        return True, grad_norm, num_zeros_in_grad


467
class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
Lawrence McAfee's avatar
Lawrence McAfee committed
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
    """Float16 optimizer for fp16 and bf16 data types.

    Arguments:
        optimizer: base optimizer such as Adam or SGD
        clip_grad: clip gradients with this global L2 norm. Note
            that clipping is ignored if clip_grad == 0
        log_num_zeros_in_grad: return number of zeros in the gradients.
        params_have_main_grad: flag indicating if parameters have
            a `main_grad` field. If this is set, we are assuming
            that the model gradients are stored in the `main_grad`
            field instead of the typical `grad` field. This happens
            for the DDP cases where there is a contiguous buffer
            holding the gradients. For example for bfloat16, we want
            to do gradient accumulation and all-reduces in float32
            and as a result we store those gradients in the main_grad.
            Note that main grad is not necessarily in float32.
        use_contiguous_buffers_in_local_ddp: if true, the local DDP model
            is using a contiguous buffer to hold the model grads.
        fp16: if true, the model is running in fp16.
        bf16: if true, the model is running in bfloat16.
        grad_scaler: used for scaling gradients. Note that this can be
            None. This case happens when `bf16 = True` and we don't
            use any loss scale. Note that for `bf16 = True`, we can have
            a constant gradient scaler. Also for `bf16 = False`, we
            always require a grad scaler.
        models: list of models (i.e., the virtual pipelining models). This
            is used by the distributed optimizer for mapping parameters.
    """

    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                 fp16, bf16, grad_scaler, models):
        """Split the optimizer's params by dtype and give float16 params an
        fp32 "main" copy.

        Arguments are as for MixedPrecisionOptimizer. For each param group of
        the wrapped optimizer, one entry is appended to each of:
          * float16_groups: the original float16 (fp16/bf16) params;
          * fp32_from_float16_groups: their fp32 clones, which replace them
            inside the optimizer's param group and take over any existing
            optimizer state;
          * fp32_from_fp32_groups: the params that were already fp32.
        Params with requires_grad=False are left in place and not tracked.

        Raises:
            TypeError: if a trainable param is not a CUDA float, half or
                bfloat16 tensor.
        """

        super().__init__(
            optimizer, clip_grad, log_num_zeros_in_grad,
            params_have_main_grad, use_contiguous_buffers_in_local_ddp,
            fp16, bf16, grad_scaler, models)

        float16_types = ('torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor')

        self.float16_groups = []
        self.fp32_from_float16_groups = []
        self.fp32_from_fp32_groups = []

        for param_group in self.optimizer.param_groups:
            group_params = param_group['params']
            float16_params = []
            fp32_params = []
            main_params = []

            for idx, param in enumerate(group_params):
                # Frozen params are neither tracked nor replaced.
                if not param.requires_grad:
                    continue

                param_type = param.type()
                if param_type in float16_types:
                    float16_params.append(param)
                    # fp32 master copy that the optimizer will update.
                    main_param = param.detach().clone().float()
                    mpu.copy_tensor_model_parallel_attributes(main_param, param)
                    if hasattr(param, 'shared'):
                        main_param.shared = param.shared
                    # The optimizer now owns the fp32 copy instead of param.
                    group_params[idx] = main_param
                    main_params.append(main_param)
                    # Re-key any existing optimizer state onto the copy.
                    opt_state = self.optimizer.state
                    if param in opt_state:
                        opt_state[main_param] = opt_state.pop(param)
                elif param_type == 'torch.cuda.FloatTensor':
                    # fp32 params stay in the optimizer as they are.
                    fp32_params.append(param)
                else:
                    raise TypeError('Wrapped parameters must be one of '
                                    'torch.cuda.FloatTensor,  '
                                    'torch.cuda.HalfTensor, or '
                                    'torch.cuda.BFloat16Tensor. '
                                    'Received {}'.format(param_type))

            self.float16_groups.append(float16_params)
            self.fp32_from_float16_groups.append(main_params)
            self.fp32_from_fp32_groups.append(fp32_params)


Lawrence McAfee's avatar
Lawrence McAfee committed
564
565
566
567
568
569
570
571
572
573
574
575
    def zero_grad(self, set_to_none=True):
        """Zero (or drop) the gradients of every group this optimizer tracks.

        Only float16_groups and fp32_from_fp32_groups hold model params and
        strictly need zeroing. fp32_from_float16_groups is zeroed as well as
        a memory optimization: with set_to_none=True the space used by those
        main grads can be released now, reducing fragmentation.
        """
        all_groups = (self.float16_groups,
                      self.fp32_from_float16_groups,
                      self.fp32_from_fp32_groups)
        for groups in all_groups:
            for group in groups:
                _zero_grad_group_helper(group, set_to_none)
mohammad's avatar
mohammad committed
576
577


578
    def _collect_main_grad_data_for_unscaling(self):
        """Return the `.data` of every existing main-param gradient.

        Grads of the fp32 copies of float16 params come first, followed by
        grads of the original fp32 params; params without a grad are skipped.
        """
        main_groups = self.fp32_from_float16_groups + self.fp32_from_fp32_groups
        return [main_param.grad.data
                for main_group in main_groups
                for main_param in main_group
                if main_param.grad is not None]
595
596


597
598
599
600
601
602
603
604
605
    def _get_model_and_main_params_data_float16(self):
        """Return parallel lists of `.data` for the float16 model params and
        their fp32 main copies, in matching order."""
        pairs = [
            (model_param.data, main_param.data)
            for model_group, main_group in zip(self.float16_groups,
                                               self.fp32_from_float16_groups)
            for model_param, main_param in zip(model_group, main_group)
        ]
        model_data = [model for model, _ in pairs]
        main_data = [main for _, main in pairs]
        return model_data, main_data
606

Lawrence McAfee's avatar
Lawrence McAfee committed
607

608
    def _copy_model_grads_to_main_grads(self):
        """Move gradients from the model params onto the main params.

        float16 params: the grad (taken from `main_grad` when params have
        one, else from `grad` if present) is cast to fp32 and set on the
        matching main param; the model param's grad is then released, and so
        is its `main_grad` unless contiguous buffers own that memory.
        fp32 params (only when params have main grads): `grad` is pointed at
        `main_grad`, which is released afterwards unless contiguous buffers
        are in use.
        """
        release_main_grad = self.params_have_main_grad and \
            not self.use_contiguous_buffers_in_local_ddp

        # This only needs to be done for the float16 group.
        for model_group, main_group in zip(self.float16_groups,
                                           self.fp32_from_float16_groups):
            for model_param, main_param in zip(model_group, main_group):
                if self.params_have_main_grad and hasattr(model_param, 'main_grad'):
                    main_param.grad = model_param.main_grad.float()
                elif model_param.grad is not None:
                    main_param.grad = model_param.grad.float()

                # Safe to deallocate model's grad/main_grad after copying.
                # (If using contiguous buffers, main_grad's memory should
                # persist and therefore should not be deallocated.)
                model_param.grad = None
                if release_main_grad:
                    model_param.main_grad = None

        if not self.params_have_main_grad:
            return

        # For fp32 grads, we need to reset the grads to main grad.
        for model_group in self.fp32_from_fp32_groups:
            for model_param in model_group:
                model_param.grad = model_param.main_grad

                # Safe to de-reference model's main_grad after copying.
                # (If using contiguous buffers, main_grad's memory should
                # persist and therefore should not be deallocated.)
                if not self.use_contiguous_buffers_in_local_ddp:
                    model_param.main_grad = None


    def _copy_main_params_to_model_params(self):
        """Write the fp32 main param values back into the float16 model
        params. Only the float16 params are involved."""
        float16_tensors, fp32_tensors = \
            self._get_model_and_main_params_data_float16()
        # Direction: main (fp32) -> model (float16).
        _multi_tensor_copy_this_to_that(
            overflow_buf=self._dummy_overflow_buf,
            this=fp32_tensors,
            that=float16_tensors)


    def _copy_model_params_to_main_params(self):
        """Refresh the fp32 main params from the current float16 model
        params. Only the float16 params are involved."""
        float16_tensors, fp32_tensors = \
            self._get_model_and_main_params_data_float16()
        # Direction: model (float16) -> main (fp32).
        _multi_tensor_copy_this_to_that(
            overflow_buf=self._dummy_overflow_buf,
            this=float16_tensors,
            that=fp32_tensors)


    def state_dict(self):
        """Build a checkpointable dict for this optimizer.

        Keys, in insertion order:
            'optimizer': the wrapped optimizer's ``state_dict()``.
            'grad_scaler': the grad scaler's ``state_dict()``; present only
                when ``self.grad_scaler`` is truthy.
            'fp32_from_fp16_params': ``self.fp32_from_float16_groups``
                (stored by reference, not copied).
        """
        result = dict(optimizer=self.optimizer.state_dict())
        scaler = self.grad_scaler
        if scaler:
            result.update(grad_scaler=scaler.state_dict())
        result['fp32_from_fp16_params'] = self.fp32_from_float16_groups
        return result


    def load_state_dict(self, state_dict):
        """Restore optimizer state from a checkpoint dict.

        Older checkpoint layouts are accepted:
        * the optimizer state may be stored under 'optimizer_state_dict'
          instead of 'optimizer';
        * the fp32 main params may be stored under 'fp32_from_fp16'
          instead of 'fp32_from_fp16_params'.

        A missing 'grad_scaler' entry only produces a warning. So does a
        saved grad scaler when this optimizer has none. Saved main param
        values are copied in place into ``self.fp32_from_float16_groups``.

        Args:
            state_dict: checkpoint dict to load from.
        """
        # Optimizer.
        optimizer_key = 'optimizer'
        if optimizer_key not in state_dict:
            optimizer_key = 'optimizer_state_dict'
            print_rank_0('***WARNING*** loading optimizer from '
                         'an old checkpoint ...')
        self.optimizer.load_state_dict(state_dict[optimizer_key])

        # Grad scaler.
        if 'grad_scaler' not in state_dict:
            print_rank_0('***WARNING*** found an old checkpoint, will not '
                         'load grad scaler ...')
        else:
            if self.grad_scaler:
                self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
            else:
                print_rank_0('***WARNING*** found the grad scaler in the '
                             'checkpoint but it is None in the class. '
                             'Skipping loading grad scaler ...')

        # Copy data for the main params.
        fp32_from_float16_params_key = 'fp32_from_fp16_params'
        if fp32_from_float16_params_key not in state_dict:
            # Legacy key used by older checkpoints.
            fp32_from_float16_params_key = 'fp32_from_fp16'
        for current_group, saved_group in zip(
                self.fp32_from_float16_groups,
                state_dict[fp32_from_float16_params_key]):
            for current_param, saved_param in zip(current_group, saved_group):
                # In-place copy keeps the tensor objects the inner
                # optimizer already references.
                current_param.data.copy_(saved_param.data)


class FP32Optimizer(MegatronOptimizer):
    """Optimizer wrapper used when parameters are kept in fp32.

    No loss scaling is applied: `get_loss_scale` returns a constant 1.0
    tensor. No overflow check is performed, so `step` always reports
    success.
    """

    def __init__(self, optimizer, clip_grad,
                 log_num_zeros_in_grad,
                 params_have_main_grad,
                 use_contiguous_buffers_in_local_ddp,
                 models):

        super(FP32Optimizer, self).__init__(
            optimizer, clip_grad, log_num_zeros_in_grad,
            params_have_main_grad, use_contiguous_buffers_in_local_ddp,
            models)

        # Constant loss scale reported by `get_loss_scale`.
        self._scale = torch.cuda.FloatTensor([1.0])


    def zero_grad(self, set_to_none=True):
        """Zero (or, with `set_to_none`, drop) the grads of every param in
        the wrapped optimizer. Copied from torch.optim.optimizer."""
        for group in self.optimizer.param_groups:
            _zero_grad_group_helper(group['params'], set_to_none)


    def get_loss_scale(self):
        """FP32 optimizer does not do any scaling."""
        return self._scale


    @torch.no_grad()
    def step(self, args, timers):
        """Clip gradients (if needed) and step the base optimizer.

        Args:
            args: not used by this implementation.
            timers: callable returning a timer with `start()`/`stop()` for
                a given name.

        Returns:
            tuple: ``(True, grad_norm, num_zeros_in_grad)``. The first item
            is always True because there is no overflow in fp32.
            `grad_norm` is None unless `clip_grad > 0`.
            `num_zeros_in_grad` is None unless `log_num_zeros_in_grad`.
        """

        # Copy main_grads to grads.
        timers('optimizer-copy-to-main-grad').start()
        if self.params_have_main_grad:
            for param_group in self.optimizer.param_groups:
                for param in param_group['params']:
                    # A param without `main_grad` keeps its own `.grad`
                    # rather than raising AttributeError.
                    if hasattr(param, 'main_grad'):
                        param.grad = param.main_grad

                    # Safe to de-reference model's main_grad after copying.
                    # (If using contiguous buffers, main_grad's memory should
                    # persist and therefore should not be deallocated.)
                    if not self.use_contiguous_buffers_in_local_ddp:
                        param.main_grad = None
        timers('optimizer-copy-to-main-grad').stop()

        # Clip gradients.
        timers('optimizer-clip-main-grad').start()
        grad_norm = None
        if self.clip_grad > 0.0:
            grad_norm = self.clip_grad_norm(self.clip_grad)
        timers('optimizer-clip-main-grad').stop()

        # count the zeros in the grads
        timers('optimizer-count-zeros').start()
        num_zeros_in_grad = self.count_zeros() if \
                            self.log_num_zeros_in_grad else None
        timers('optimizer-count-zeros').stop()

        # Update parameters.
        timers('optimizer-inner-step').start()
        self.optimizer.step()
        timers('optimizer-inner-step').stop()

        # No overflow for FP32 optimizer.
        return True, grad_norm, num_zeros_in_grad


    def reload_model_params(self):
        """No-op: there is no separate copy of the params to refresh."""
        pass


    def state_dict(self):
        """Return the wrapped optimizer's state dict unchanged."""
        return self.optimizer.state_dict()


    def load_state_dict(self, state_dict):
        """Load `state_dict` directly into the wrapped optimizer."""
        self.optimizer.load_state_dict(state_dict)