optimizer.py 28.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Megatron optimizer."""
mohammad's avatar
mohammad committed
17
18
19
20
21

from abc import ABC
from abc import abstractmethod
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
22
23
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
mohammad's avatar
mohammad committed
24

mohammad's avatar
mohammad committed
25
26
from megatron import get_timers
from megatron import mpu
mohammad's avatar
mohammad committed
27
from megatron import print_rank_0
28
29
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
30
31
from megatron.model.module import param_is_not_shared
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
32
from megatron.utils import unwrap_model
33

34
from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32
35

Lawrence McAfee's avatar
Lawrence McAfee committed
36

mohammad's avatar
mohammad committed
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
def _zero_grad_group_helper(group, set_to_none):
    """Clear the gradient of every parameter in ``group``.

    Note: copied from torch.optim.optimizer.

    Args:
        group: iterable of parameters whose ``.grad`` should be cleared.
        set_to_none: if True, release each gradient by assigning None to
            ``.grad``; otherwise detach it from any autograd graph and
            zero it in place.
    """
    for p in group:
        g = p.grad
        if g is None:
            # Nothing to clear for this parameter.
            continue
        if set_to_none:
            p.grad = None
            continue
        # Turn the gradient into a plain leaf before zeroing it in place:
        # a grad produced by an autograd graph is detached, a leaf grad
        # just stops requiring grad.
        if g.grad_fn is None:
            g.requires_grad_(False)
        else:
            g.detach_()
        g.zero_()


52
def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
    """Copy each tensor in ``this`` into the tensor at the same position in ``that``.

    If ``overflow_buf`` is provided, apex's fused ``multi_tensor_scale``
    kernel is used with a scale factor of 1.0, which is equivalent to a
    copy. We don't have a bfloat16 implementation of that path, so when
    ``overflow_buf`` is None we fall back to a simple per-tensor
    ``copy_`` loop, which is compatible with bfloat16.

    Args:
        this: list of source tensors.
        that: list of destination tensors, paired with ``this`` by position.
        overflow_buf: optional tensor handed to apex as its flag buffer; it
            is reset to 0 before the fused copy and not inspected afterwards.
    """
    # Nothing to copy. NOTE(review): this also avoids passing empty tensor
    # lists to the fused apex kernel, which presumably rejects them — confirm.
    if not this:
        return
    # Compare against None explicitly. Callers may pass a one-element tensor
    # holding 0 (e.g. torch.cuda.IntTensor([0])), and bool() of such a
    # tensor is False, so a plain truthiness test would always fall
    # through to the slow loop and the fused path would never run.
    if overflow_buf is not None:
        overflow_buf.fill_(0)
        # Scaling with factor `1.0` is equivalent to copy.
        multi_tensor_applier(amp_C.multi_tensor_scale,
                             overflow_buf,
                             [this, that],
                             1.0)
    else:
        for this_, that_ in zip(this, that):
            that_.copy_(this_)

68

mohammad's avatar
mohammad committed
69
70
71

class MegatronOptimizer(ABC):
    """Abstract base for Megatron optimizers that wrap a torch optimizer.

    Stores the clipping/logging configuration and the model chunks,
    exposes the wrapped optimizer's ``state`` and ``param_groups``, and
    implements shared helpers: grad-norm clipping, zero counting, and the
    data-parallel / embedding gradient all-reduces. Subclasses implement
    ``zero_grad``, ``get_loss_scale``, ``reload_model_params``,
    ``state_dict``, ``load_state_dict`` and ``step``.
    """

    def __init__(self, optimizer, clip_grad,
                 log_num_zeros_in_grad,
                 params_have_main_grad,
                 use_contiguous_buffers_in_local_ddp,
                 models):
        """Input optimizer is the base optimizer for example Adam.

        Args:
            optimizer: base optimizer being wrapped; must be truthy.
            clip_grad: gradient clipping value, stored as ``self.clip_grad``.
            log_num_zeros_in_grad: flag stored as
                ``self.log_num_zeros_in_grad``.
            params_have_main_grad: whether parameters carry a ``main_grad``
                field holding their gradient.
            use_contiguous_buffers_in_local_ddp: whether the local DDP
                wrapper keeps gradients in contiguous buffers; requires
                ``params_have_main_grad``.
            models: list of model chunks, retained for access to the
                contiguous grad buffers and for the embedding all-reduces.

        Raises:
            AssertionError: if ``optimizer`` is falsy, or if contiguous
                buffers are requested without ``params_have_main_grad``.
        """
        self.optimizer = optimizer
        assert self.optimizer, 'no optimizer is provided.'

        # Set gradient clipping and logging params.
        self.clip_grad = clip_grad
        self.log_num_zeros_in_grad = log_num_zeros_in_grad
        self.params_have_main_grad = params_have_main_grad
        self.use_contiguous_buffers_in_local_ddp = use_contiguous_buffers_in_local_ddp

        # 'models' are retained for access to the contiguous grad buffers.
        # (see distributed optimizer)
        self.models = models

        if self.use_contiguous_buffers_in_local_ddp:
            assert self.params_have_main_grad, \
                "use of contiguous buffer requires that params have main grad"

    def get_parameters(self):
        """Return a flat list of all parameters in the base optimizer's groups."""
        return [param
                for param_group in self.optimizer.param_groups
                for param in param_group['params']]

    def get_main_grads_for_grad_norm(self):
        """Return the gradients that contribute to the global grad norm.

        A parameter's ``.grad`` is kept only when:
          - the grad is not None,
          - the parameter is not shared,
          - it is not a replica due to tensor model parallelism.
        """
        return [
            param.grad
            for param in self.get_parameters()
            if param.grad is not None
            and param_is_not_shared(param)
            and param_is_not_tensor_parallel_duplicate(param)
        ]

    def get_model_parallel_group(self):
        """Default returned here, but the distributed optimizer overrides this."""
        return mpu.get_model_parallel_group()

    def clip_grad_norm(self, clip_grad):
        """Clip the gradients to ``clip_grad`` and return the result of
        ``clip_grad_norm_fp32`` (computed over the model-parallel group)."""
        params = self.get_parameters()
        grads_for_norm = self.get_main_grads_for_grad_norm()
        return clip_grad_norm_fp32(
            params, grads_for_norm, clip_grad,
            model_parallel_group=self.get_model_parallel_group())

    def count_zeros(self):
        """Return ``count_zeros_fp32`` over all parameters, across the
        model-parallel group."""
        params = self.get_parameters()
        return count_zeros_fp32(params,
                                model_parallel_group=self.get_model_parallel_group())

    @abstractmethod
    def zero_grad(self, set_to_none=True):
        """Clear the gradients of the managed parameters."""
        pass

    @abstractmethod
    def get_loss_scale(self):
        """The output should be a cuda tensor of size 1."""
        pass

    def scale_loss(self, loss):
        """Simple scaling: multiply ``loss`` by the current loss scale."""
        return self.get_loss_scale() * loss

    @abstractmethod
    def reload_model_params(self):
        """Refreshes any internal state from the current model parameters.
        Call whenever the parameters are changed outside of the optimizer.
        For example, when we load a model from a checkpoint  without loading
        the optimizer, the model parameters are updated but for fp16 optimizer
        with main parameters, the main parameters need to also be updated."""
        pass

    @abstractmethod
    def state_dict(self):
        """Return a checkpointable state dict."""
        pass

    @abstractmethod
    def load_state_dict(self, state_dict):
        """Restore state produced by ``state_dict``."""
        pass

    # Promote state so it can be retrieved or set via
    # "optimizer_instance.state"
    def _get_state(self):
        return self.optimizer.state

    def _set_state(self, value):
        self.optimizer.state = value

    state = property(_get_state, _set_state)

    # Promote param_groups so it can be retrieved or set via
    # "optimizer_instance.param_groups"
    # (for example, to adjust the learning rate)
    def _get_param_groups(self):
        return self.optimizer.param_groups

    def _set_param_groups(self, value):
        self.optimizer.param_groups = value

    param_groups = property(_get_param_groups, _set_param_groups)

    @abstractmethod
    def step(self, args, timers):
        """Perform one optimizer step.

        The implementations visible in this file return a tuple
        ``(update_successful, grad_norm, num_zeros_in_grad)``.
        """
        pass

    def gather_model_params(self, args, timers):
        """
        For the case of a non-distributed-optimizer, there is nothing to
        do here.
        """
        pass

    def allreduce_word_embedding_grads(self, args):
        """
        All-reduce word embedding grads.

        Reduce grads across first and last stages to ensure that word_embeddings
        parameters stay in sync. This should only run for models that support
        pipelined model parallelism (BERT and GPT-2).

        Only ranks in the embedding group act, and only when the
        pipeline-parallel world size is greater than 1. With
        ``args.DDP_impl == 'local'`` the ``main_grad`` field is reduced,
        otherwise ``.grad``.
        """
        if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
                mpu.get_pipeline_model_parallel_world_size() > 1:
            # Pick the model chunk to read the embedding from: the first
            # chunk on the first stage, the last chunk on the last stage.
            if mpu.is_pipeline_first_stage(ignore_virtual=True):
                unwrapped_model = self.models[0]
            elif mpu.is_pipeline_last_stage(ignore_virtual=True):
                unwrapped_model = self.models[-1]
            else:  # We do not support the interleaved schedule for T5 yet.
                unwrapped_model = self.models[0]
            unwrapped_model = unwrap_model(
                unwrapped_model, (torchDDP, LocalDDP, Float16Module))

            if unwrapped_model.share_word_embeddings:
                word_embeddings_weight = unwrapped_model.word_embeddings_weight()
                if args.DDP_impl == 'local':
                    grad = word_embeddings_weight.main_grad
                else:
                    grad = word_embeddings_weight.grad
                torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())

    def allreduce_position_embedding_grads(self, args):
        """
        All-reduce position_embeddings grad across first (encoder) and
        split (decoder) stages to ensure that position embeddings parameters
        stay in sync. This should only run for T5 models with pipeline
        parallelism.

        Requires ``args.DDP_impl == 'local'`` (asserted); the ``main_grad``
        of the position embedding weight is reduced.
        """
        if mpu.is_rank_in_position_embedding_group() and \
                mpu.get_pipeline_model_parallel_world_size() > 1 and \
                args.pipeline_model_parallel_split_rank is not None:
            unwrapped_model = self.models[0]
            unwrapped_model = unwrap_model(
                unwrapped_model, (torchDDP, LocalDDP, Float16Module))
            assert args.DDP_impl == 'local', \
                'T5 model is only supported with local DDP mode'
            grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
            torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())

    def allreduce_embedding_grads(self, args):
        """All-reduce both word and position embeddings."""
        self.allreduce_word_embedding_grads(args)
        self.allreduce_position_embedding_grads(args)

    def reduce_model_grads(self, args, timers):
        """All-reduce all grads, and all-reduce embeddings.

        With ``args.DDP_impl == 'local'`` each model chunk's
        ``allreduce_gradients()`` is called first. ``timers`` is a callable
        taking a name and returning an object with ``start()``/``stop()``.
        """
        # All-reduce if needed.
        if args.DDP_impl == 'local':
            timers('backward-params-all-reduce').start()
            for model in self.models:
                model.allreduce_gradients()
            timers('backward-params-all-reduce').stop()

        # All-reduce embedding grads.
        timers('backward-embedding-all-reduce').start()
        self.allreduce_embedding_grads(args)
        timers('backward-embedding-all-reduce').stop()

284

285
class MixedPrecisionOptimizer(MegatronOptimizer):
    """Base class for both the float-16 and the distributed optimizer.

    Arguments:
        optimizer: base optimizer such as Adam or SGD
        clip_grad: clip gradients with this global L2 norm. Note
            that clipping is skipped unless clip_grad > 0
        log_num_zeros_in_grad: return number of zeros in the gradients.
        params_have_main_grad: flag indicating if parameters have
            a `main_grad` field. If this is set, we are assuming
            that the model gradients are stored in the `main_grad`
            field instead of the typical `grad` field. This happens
            for the DDP cases where there is a continuous buffer
            holding the gradients. For example for bfloat16, we want
            to do gradient accumulation and all-reduces in float32
            and as a result we store those gradients in the main_grad.
            Note that main grad is not necessarily in float32.
        use_contiguous_buffers_in_local_ddp: if true, the local DDP model
            is using a contiguous buffer to hold the model grads.
        fp16: if true, the model is running in fp16.
        bf16: if true, the model is running in bfloat16.
        grad_scaler: used for scaling gradients. Note that this can be
            None. This case happens when `bf16 = True` and we don't
            use any loss scale. Note that for `bf16 = True`, we can have
            a constant gradient scaler. Also for `bf16 = False`, we
            always require a grad scaler.
        models: list of models (i.e., the virtual pipelining models). This
            is used by the distributed optimizer for mapping parameters.

    Subclasses provide the hooks used here:
    ``_collect_main_grad_data_for_unscaling``,
    ``_copy_model_grads_to_main_grads``,
    ``_copy_main_params_to_model_params`` and
    ``_copy_model_params_to_main_params``.
    """

    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                 fp16, bf16, grad_scaler,
                 models):

        super().__init__(
            optimizer, clip_grad, log_num_zeros_in_grad,
            params_have_main_grad, use_contiguous_buffers_in_local_ddp,
            models)

        self.fp16 = fp16
        self.bf16 = bf16
        self.grad_scaler = grad_scaler

        # None grad scaler is only supported for bf16.
        if self.grad_scaler is None:
            assert not self.fp16, 'fp16 expects a grad scaler.'

        # Tensor used to record whether a nan/inf has happened.
        # Any non-zero value indicates inf/nan.
        # It is only allocated when a grad scaler is present (this includes
        # bf16 with a grad scaler); without a scaler, step() never checks
        # for nan/inf.
        if self.grad_scaler is not None:
            self.found_inf = torch.cuda.FloatTensor([0.0])

        # Dummy tensor needed for apex multi-apply tensor.
        # For bfloat, we don't have multi-tensor apply and for now
        # we set it to none so the multi-tensor apply gets ignored.
        if bf16:
            self._dummy_overflow_buf = None
        else:
            self._dummy_overflow_buf = torch.cuda.IntTensor([0])

        # In case grad scaler is not passed, define the unity scale.
        if self.grad_scaler is None:
            self._scale_one = torch.cuda.FloatTensor([1.0])

    def get_loss_scale(self):
        """Return the grad scaler's current scale, or a unit scale tensor
        when no grad scaler is configured."""
        if self.grad_scaler is None:
            return self._scale_one
        return self.grad_scaler.scale

    def reload_model_params(self):
        """Refresh the main params from the (externally modified) model params."""
        self._copy_model_params_to_main_params()

    def _unscale_main_grads_and_check_for_nan(self):
        """Unscale the main grads in place and report whether any inf/nan
        was found on any rank of the model-parallel group.

        Only called when a grad scaler is present (``self.found_inf`` and
        ``self.grad_scaler.inv_scale`` are used).

        Returns:
            bool: True if any model-parallel rank saw a non-finite grad.
        """
        # Collect main grads.
        main_grads = self._collect_main_grad_data_for_unscaling()

        # Reset found inf.
        self.found_inf.fill_(0.0)

        # Unscale and set found inf/nan
        torch._amp_foreach_non_finite_check_and_unscale_(
            main_grads, self.found_inf, self.grad_scaler.inv_scale)

        # Update across all model parallel instances, so every rank agrees
        # on whether to skip this step.
        torch.distributed.all_reduce(self.found_inf,
                                     op=torch.distributed.ReduceOp.MAX,
                                     group=self.get_model_parallel_group())

        # Check for nan.
        found_inf_flag = (self.found_inf.item() > 0)

        return found_inf_flag

    @torch.no_grad()
    def step(self, args, timers):
        """Run one mixed-precision optimizer step.

        Sequence: copy model grads to main grads; if a grad scaler is set,
        unscale/check for inf and update the scaler (returning early on
        inf/nan); optionally clip and count zeros; step the base optimizer;
        copy main params back to the model params.

        Args:
            args: run configuration (unused here, kept for the interface).
            timers: callable taking a name and returning an object with
                ``start()``/``stop()``.

        Returns:
            ``(False, None, None)`` if the update was skipped because of
            inf/nan, otherwise ``(True, grad_norm, num_zeros_in_grad)``
            where either value may be None when clipping / zero counting
            is disabled.
        """
        # Copy gradients from model params to main params.
        timers('optimizer-copy-to-main-grad').start()
        self._copy_model_grads_to_main_grads()
        timers('optimizer-copy-to-main-grad').stop()

        # Do unscale, check for inf, and update grad scaler only for
        # the case that grad scaler is provided.
        if self.grad_scaler is not None:

            # Unscale and check for inf/nan.
            timers('optimizer-unscale-and-check-inf').start()
            found_inf_flag = self._unscale_main_grads_and_check_for_nan()
            timers('optimizer-unscale-and-check-inf').stop()

            # We are done with scaling gradients
            # so we can update the loss scale.
            self.grad_scaler.update(found_inf_flag)

            # If we found inf/nan, skip the update.
            if found_inf_flag:
                return False, None, None

        # Clip the main gradients.
        timers('optimizer-clip-main-grad').start()
        grad_norm = None
        if self.clip_grad > 0.0:
            grad_norm = self.clip_grad_norm(self.clip_grad)
        timers('optimizer-clip-main-grad').stop()

        # Count the zeros in the grads.
        timers('optimizer-count-zeros').start()
        num_zeros_in_grad = self.count_zeros() if \
                            self.log_num_zeros_in_grad else None
        timers('optimizer-count-zeros').stop()

        # Step the optimizer.
        timers('optimizer-inner-step').start()
        self.optimizer.step()
        timers('optimizer-inner-step').stop()

        # Update params from main params.
        timers('optimizer-copy-main-to-model-params').start()
        self._copy_main_params_to_model_params()
        timers('optimizer-copy-main-to-model-params').stop()

        # Successful update.
        return True, grad_norm, num_zeros_in_grad


438
class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
    """Float16 optimizer for fp16 and bf16 data types.

    Keeps an fp32 "main" copy of every float16 model parameter. The base
    optimizer is re-pointed at those fp32 copies, so the optimizer step
    runs in fp32. The results are then copied back into the float16 model
    params.

    Arguments:
        optimizer: base optimizer such as Adam or SGD
        clip_grad: clip gradients with this global L2 norm. Note
            that clipping is ignored if clip_grad == 0
        log_num_zeros_in_grad: return number of zeros in the gradients.
        params_have_main_grad: flag indicating if parameters have
            a `main_grad` field. If this is set, we are assuming
            that the model gradients are stored in the `main_grad`
            field instead of the typical `grad` field. This happens
            for the DDP cases where there is a continuous buffer
            holding the gradients. For example for bfloat16, we want
            to do gradient accumulation and all-reduces in float32
            and as a result we store those gradients in the main_grad.
            Note that main grad is not necessarily in float32.
        use_contiguous_buffers_in_local_ddp: if true, the local DDP model
            is using a contiguous buffer to hold the model grads.
        fp16: if true, the model is running in fp16.
        bf16: if true, the model is running in bfloat16.
        grad_scaler: used for scaling gradients. Note that this can be
            None. This case happens when `bf16 = True` and we don't
            use any loss scale. Note that for `bf16 = True`, we can have
            a constant gradient scaler. Also for `bf16 = False`, we
            always require a grad scaler.
        models: list of models (i.e., the virtual pipelining models). This
            is used by the distributed optimizer for mapping parameters.

    Raises:
        TypeError: (from __init__) if a trainable parameter is not a CUDA
            float, half or bfloat16 tensor.
    """

    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                 fp16, bf16, grad_scaler, models):

        super().__init__(
            optimizer, clip_grad, log_num_zeros_in_grad,
            params_have_main_grad, use_contiguous_buffers_in_local_ddp,
            fp16, bf16, grad_scaler, models)

        # ======================
        # main parameter stuff
        # ======================

        # Three groups of parameters:
        #   float16_groups: original float16 parameters
        #   fp32_from_float16_groups: fp32 copy of float16 parameters
        #   fp32_from_fp32_groups: original fp32 parameters
        # The three lists are parallel to self.optimizer.param_groups (one
        # inner list per param group). Within a group, float16_groups and
        # fp32_from_float16_groups are aligned element by element.
        self.float16_groups = []
        self.fp32_from_float16_groups = []
        self.fp32_from_fp32_groups = []

        # For all the groups in the original optimizer:
        for param_group in self.optimizer.param_groups:
            float16_params_this_group = []
            fp32_params_this_group = []
            fp32_from_float16_params_this_group = []
            # For all the parameters in this group:
            # (Parameters with requires_grad == False are left in the
            # param group untouched and are not tracked in any list.)
            for i, param in enumerate(param_group['params']):
                if param.requires_grad:

                    # float16 params:
                    if param.type() in ['torch.cuda.HalfTensor',
                                        'torch.cuda.BFloat16Tensor']:
                        float16_params_this_group.append(param)
                        # Create a copy
                        main_param = param.detach().clone().float()
                        # Copy tensor model parallel attributes.
                        mpu.copy_tensor_model_parallel_attributes(main_param,
                                                                  param)
                        if hasattr(param, 'shared'):
                            main_param.shared = param.shared
                        # Replace the optimizer params with the new fp32 copy.
                        # (The list is mutated in place while being
                        # enumerated; only the element at index i changes.)
                        param_group['params'][i] = main_param

                        fp32_from_float16_params_this_group.append(main_param)
                        # Reset existing state dict key to the new main param.
                        if param in self.optimizer.state:
                            self.optimizer.state[main_param] \
                                = self.optimizer.state.pop(param)
                    # fp32 params.
                    elif param.type() == 'torch.cuda.FloatTensor':
                        fp32_params_this_group.append(param)
                        param_group['params'][i] = param

                    else:
                        raise TypeError('Wrapped parameters must be one of '
                                        'torch.cuda.FloatTensor,  '
                                        'torch.cuda.HalfTensor, or '
                                        'torch.cuda.BFloat16Tensor. '
                                        'Received {}'.format(param.type()))

            self.float16_groups.append(float16_params_this_group)
            self.fp32_from_float16_groups.append(
                fp32_from_float16_params_this_group)
            self.fp32_from_fp32_groups.append(fp32_params_this_group)

    def zero_grad(self, set_to_none=True):
        """We only need to zero the model related parameters, i.e.,
        float16_groups & fp32_from_fp32_groups. We additionally zero
        fp32_from_float16_groups as a memory optimization to reduce
        fragmentation; in the case of set_to_none==True, the space
        used by this field can be safely deallocated at this point."""
        for group in self.float16_groups:
            _zero_grad_group_helper(group, set_to_none)
        for group in self.fp32_from_float16_groups:
            _zero_grad_group_helper(group, set_to_none)
        for group in self.fp32_from_fp32_groups:
            _zero_grad_group_helper(group, set_to_none)

    def _collect_main_grad_data_for_unscaling(self):
        """Return the ``.grad.data`` of every main (fp32) param that has a
        grad: first the fp32 copies of float16 params, then the original
        fp32 params. Used by the parent class for unscaling/inf checks."""

        main_grads = []

        # fp32 params from float16 ones.
        for main_group in self.fp32_from_float16_groups:
            for main_param in main_group:
                if main_param.grad is not None:
                    main_grads.append(main_param.grad.data)

        # Append fp32 parameters.
        for main_group in self.fp32_from_fp32_groups:
            for main_param in main_group:
                if main_param.grad is not None:
                    main_grads.append(main_param.grad.data)

        return main_grads

    def _get_model_and_main_params_data_float16(self):
        """Return two aligned lists: the ``.data`` of each float16 model
        param and the ``.data`` of its fp32 main copy."""
        model_data = []
        main_data = []
        for model_group, main_group in zip(self.float16_groups,
                                           self.fp32_from_float16_groups):
            for model_param, main_param in zip(model_group, main_group):
                model_data.append(model_param.data)
                main_data.append(main_param.data)
        return model_data, main_data

    def _copy_model_grads_to_main_grads(self):
        """Move gradients from the model params onto the main params.

        For float16 params, the grad (``main_grad`` if available and
        ``params_have_main_grad`` is set, otherwise ``.grad``) is converted
        with ``.float()`` and assigned to the fp32 copy. If neither is
        available, the main param's grad is left unchanged. For fp32
        params, ``.grad`` is pointed at ``main_grad`` when
        ``params_have_main_grad`` is set.
        """
        # This only needs to be done for the float16 group.
        for model_group, main_group in zip(self.float16_groups,
                                           self.fp32_from_float16_groups):
            for model_param, main_param in zip(model_group, main_group):
                if self.params_have_main_grad and hasattr(model_param, 'main_grad'):
                    main_param.grad = model_param.main_grad.float()
                else:
                    if model_param.grad is not None:
                        main_param.grad = model_param.grad.float()

                # Safe to deallocate model's grad/main_grad after copying.
                # (If using contiguous buffers, main_grad's memory should
                # persist and therefore should not be deallocated.)
                model_param.grad = None
                if self.params_have_main_grad and \
                   not self.use_contiguous_buffers_in_local_ddp:
                    model_param.main_grad = None

        # For fp32 grads, we need to reset the grads to main grad.
        # NOTE(review): unlike the float16 loop above, there is no
        # hasattr(model_param, 'main_grad') guard here.
        if self.params_have_main_grad:
            for model_group in self.fp32_from_fp32_groups:
                for model_param in model_group:
                    model_param.grad = model_param.main_grad

                    # Safe to de-reference model's main_grad after copying.
                    # (If using contiguous buffers, main_grad's memory should
                    # persist and therefore should not be deallocated.)
                    if not self.use_contiguous_buffers_in_local_ddp:
                        model_param.main_grad = None

    def _copy_main_params_to_model_params(self):
        """Write the fp32 main params back into the float16 model params."""
        # Only needed for the float16 params.
        model_data, main_data = self._get_model_and_main_params_data_float16()
        _multi_tensor_copy_this_to_that(this=main_data, that=model_data,
                                        overflow_buf=self._dummy_overflow_buf)

    def _copy_model_params_to_main_params(self):
        """Refresh the fp32 main params from the float16 model params."""
        # Only needed for the float16 params.
        model_data, main_data = self._get_model_and_main_params_data_float16()
        _multi_tensor_copy_this_to_that(this=model_data, that=main_data,
                                        overflow_buf=self._dummy_overflow_buf)

    def state_dict(self):
        """Return the base optimizer state, the grad scaler state (when a
        scaler is set) and the fp32 main copies of the float16 params."""
        state_dict = {}
        state_dict['optimizer'] = self.optimizer.state_dict()
        if self.grad_scaler:
            state_dict['grad_scaler'] = self.grad_scaler.state_dict()
        state_dict['fp32_from_fp16_params'] = self.fp32_from_float16_groups
        return state_dict

    def load_state_dict(self, state_dict):
        """Restore state produced by ``state_dict``.

        Also accepts older checkpoints that use the keys
        'optimizer_state_dict' and 'fp32_from_fp16', or lack 'grad_scaler'
        (a warning is printed on rank 0 in those cases).
        """
        # Optimizer.
        optimizer_key = 'optimizer'
        if optimizer_key not in state_dict:
            optimizer_key = 'optimizer_state_dict'
            print_rank_0('***WARNING*** loading optimizer from '
                         'an old checkpoint ...')
        self.optimizer.load_state_dict(state_dict[optimizer_key])

        # Grad scaler.
        if 'grad_scaler' not in state_dict:
            print_rank_0('***WARNING*** found an old checkpoint, will not '
                         'load grad scaler ...')
        else:
            if self.grad_scaler:
                self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
            else:
                print_rank_0('***WARNING*** fould the grad scaler in the '
                             'checkpoint but it is None in the class. '
                             'Skipping loading grad scaler ...')

        # Copy data for the main params.
        # Fall back to the older key name if the current one is missing.
        fp32_from_float16_params_key = 'fp32_from_fp16_params'
        if fp32_from_float16_params_key not in state_dict:
            fp32_from_float16_params_key = 'fp32_from_fp16'
        # NOTE(review): zip() silently stops at the shorter side if the
        # saved group/param counts differ from the current ones.
        for current_group, saved_group in zip(
                self.fp32_from_float16_groups,
                state_dict[fp32_from_float16_params_key]):
            for current_param, saved_param in zip(current_group, saved_group):
                current_param.data.copy_(saved_param.data)


mohammad's avatar
mohammad committed
666
667
class FP32Optimizer(MegatronOptimizer):
    """Megatron optimizer wrapper for plain fp32 training.

    No loss scaling is performed (the reported scale is a constant 1.0) and
    ``step`` never reports an overflow.
    """

    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                 models):

        super().__init__(optimizer,
                         clip_grad,
                         log_num_zeros_in_grad,
                         params_have_main_grad,
                         use_contiguous_buffers_in_local_ddp,
                         models)

        # Constant unit scale returned by get_loss_scale().
        self._scale = torch.cuda.FloatTensor([1.0])


    def zero_grad(self, set_to_none=True):
        """Zero (or set to None) the grads of every param group.

        Follows ``torch.optim.Optimizer.zero_grad``.
        """
        param_lists = (group['params'] for group in self.optimizer.param_groups)
        for params in param_lists:
            _zero_grad_group_helper(params, set_to_none)


    def get_loss_scale(self):
        """Return the constant unit scale; fp32 training is never scaled."""
        return self._scale


    @torch.no_grad()
    def step(self, args, timers):
        """Optionally clip gradients, then step the wrapped optimizer.

        Args:
            args: not used by this implementation.
            timers: callable mapping a name to a timer with start()/stop().

        Returns:
            Tuple ``(True, grad_norm, num_zeros_in_grad)``. The first element
            is always True because there is no overflow to detect.
            ``grad_norm`` is None unless ``clip_grad > 0``.
            ``num_zeros_in_grad`` is None unless ``log_num_zeros_in_grad``
            is set.
        """
        # Expose each param's main_grad as its .grad for the base optimizer.
        timers('optimizer-copy-to-main-grad').start()
        if self.params_have_main_grad:
            # With contiguous buffers main_grad's memory must persist, so the
            # reference is only dropped when those buffers are not in use.
            drop_main_grad = not self.use_contiguous_buffers_in_local_ddp
            for group in self.optimizer.param_groups:
                for param in group['params']:
                    param.grad = param.main_grad
                    if drop_main_grad:
                        param.main_grad = None
        timers('optimizer-copy-to-main-grad').stop()

        # Gradient clipping (only when a positive threshold is configured).
        timers('optimizer-clip-main-grad').start()
        if self.clip_grad > 0.0:
            grad_norm = self.clip_grad_norm(self.clip_grad)
        else:
            grad_norm = None
        timers('optimizer-clip-main-grad').stop()

        # Optional count of zero entries in the grads, for logging.
        timers('optimizer-count-zeros').start()
        if self.log_num_zeros_in_grad:
            num_zeros_in_grad = self.count_zeros()
        else:
            num_zeros_in_grad = None
        timers('optimizer-count-zeros').stop()

        # Parameter update.
        timers('optimizer-inner-step').start()
        self.optimizer.step()
        timers('optimizer-inner-step').stop()

        # fp32 training cannot overflow, so the step always succeeds.
        return True, grad_norm, num_zeros_in_grad


    def reload_model_params(self):
        """No-op: this class keeps no separate copy of the model params."""
        pass


    def state_dict(self):
        """Return the wrapped optimizer's state dict unchanged."""
        return self.optimizer.state_dict()


    def load_state_dict(self, state_dict):
        """Load ``state_dict`` straight into the wrapped optimizer."""
        self.optimizer.load_state_dict(state_dict)