# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Megatron optimizer."""

from abc import ABC
from abc import abstractmethod

import torch

from apex.multi_tensor_apply import multi_tensor_applier
import amp_C

from megatron import get_timers
from megatron import mpu
from megatron import print_rank_0

from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32

# >>>
from lutil import pax, tp

DEBUG_ITERATION = 1 # 10
# <<<


def _zero_grad_group_helper(group, set_to_none):
    """Zero out the gradient for a group of parameters.
    Note: copied from torch.optim.optimizer."""
    for param in group:
        if param.grad is not None:
            if set_to_none:
                param.grad = None
            else:
                if param.grad.grad_fn is not None:
                    param.grad.detach_()
                else:
                    param.grad.requires_grad_(False)
                param.grad.zero_()


def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
    """Use multi-tensor-applier to copy values from one list to another.
    We don't have a blfoat16 implementation so for now if the overflow_buf
    is not provided, we default back to simple loop copy to be compatible
    with bfloat16."""
    if overflow_buf:
        overflow_buf.fill_(0)
        # Scaling with factor `1.0` is equivalent to copy.
        multi_tensor_applier(amp_C.multi_tensor_scale,
                             overflow_buf,
                             [this, that],
                             1.0)
    else:
        for this_, that_ in zip(this, that):
            that_.copy_(this_)
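
# A minimal usage sketch (illustration only; `fp32_masters` and `fp16_params`
# are hypothetical lists of matching-shape CUDA tensors):
#
#   overflow_buf = torch.cuda.IntTensor([0])
#   _multi_tensor_copy_this_to_that(fp32_masters, fp16_params, overflow_buf)
#
# Passing overflow_buf=None (e.g., for bf16) selects the element-wise loop.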



class MegatronOptimizer(ABC):


    def __init__(self, optimizer, clip_grad,
                 log_num_zeros_in_grad,
                 params_have_main_grad,
                 use_contiguous_buffers_in_local_ddp):

        """Input optimizer is the base optimizer for example Adam."""
        self.optimizer = optimizer
        assert self.optimizer, 'no optimizer is provided.'
        # Set gradient clipping and logging params.
        self.clip_grad = clip_grad
        self.log_num_zeros_in_grad = log_num_zeros_in_grad
        self.params_have_main_grad = params_have_main_grad
        self.use_contiguous_buffers_in_local_ddp = use_contiguous_buffers_in_local_ddp

        if self.use_contiguous_buffers_in_local_ddp:
            assert self.params_have_main_grad, \
                "use of contiguous buffer requires that params have main grad"

    def get_parameters(self):
        params = []
        for param_group in self.optimizer.param_groups:
            for param in param_group['params']:
                params.append(param)
        return params


    def clip_grad_norm(self, clip_grad):
        params = self.get_parameters()
        # >>>
        # pax(0, {
        #     "clip_grad" : clip_grad,
        #     # "params": [ (p.tensor_model_parallel, tp(p)) for p in params ],
        #     "grads" : [ p.grad for p in params ],
        # })
        # <<<
        return clip_grad_norm_fp32(params, clip_grad)


    def count_zeros(self):
        params = self.get_parameters()
        return count_zeros_fp32(params)


    @abstractmethod
    def zero_grad(self, set_to_none=True):
        pass


    @abstractmethod
    def get_loss_scale(self):
        """The output should be a cuda tensor of size 1."""
        pass


    def scale_loss(self, loss):
        """Simple scaling."""
        return self.get_loss_scale() * loss
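
    # (Note on the round trip: scale_loss() multiplies the loss by the
    # current scale before backward, and the float16 optimizers below
    # multiply the resulting grads by `inv_scale` in
    # _unscale_main_grads_and_check_for_nan before stepping.)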


    @abstractmethod
    def reduce_grads(self):
        pass


    @abstractmethod
    def step(self):
        pass


    @abstractmethod
    def gather_params(self):
        pass


    @abstractmethod
    def reload_model_params(self):
        """Refreshes any internal state from the current model parameters.
        Call whenever the parameters are changed outside of the optimizer.
        For example, when we load a model from a checkpoint  without loading
        the optimizer, the model parameters are updated but for fp16 optimizer
        with main parameters, the main parameters need to also be updated."""
        pass


    @abstractmethod
    def state_dict(self):
        pass


    @abstractmethod
    def load_state_dict(self, state_dict):
        pass


    # Promote state so it can be retrieved or set via
    # "optimizer_instance.state"
    def _get_state(self):
        return self.optimizer.state

    def _set_state(self, value):
        self.optimizer.state = value

    state = property(_get_state, _set_state)


    # Promote param_groups so it can be retrieved or set via
    # "optimizer_instance.param_groups"
    # (for example, to adjust the learning rate)
    def _get_param_groups(self):
        return self.optimizer.param_groups

    def _set_param_groups(self, value):
        self.optimizer.param_groups = value

    param_groups = property(_get_param_groups, _set_param_groups)


class BaseFloat16Optimizer(MegatronOptimizer):

    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                 bf16, grad_scaler,
                 models):

        super().__init__(
            optimizer, clip_grad, log_num_zeros_in_grad,
            params_have_main_grad, use_contiguous_buffers_in_local_ddp)

        # >>>
        self.models = models
        # <<<
        self.bf16 = bf16
        self.grad_scaler = grad_scaler
        # None grad scaler is only supported for bf16.
        if self.grad_scaler is None:
            assert self.bf16, 'fp16 expects a grad scaler.'

        # Tensor used to determine if a nan/inf has happened.
        # Any non-zero value indicates inf/nan.
        # Note that we keep this for the cases that grad scaler is none.
        # We still record nan/inf if we have a bfloat16 with a grad scaler.
        if self.grad_scaler:
            self.found_inf = torch.cuda.FloatTensor([0.0])

        # Dummy tensor needed for apex multi-tensor apply.
        # For bfloat, we don't have multi-tensor apply and for now
        # we set it to none so the multi-tensor apply gets ignored.
        if bf16:
            self._dummy_overflow_buf = None
        else:
            self._dummy_overflow_buf = torch.cuda.IntTensor([0])

        # In case grad scaler is not passed, define the unity scale.
        if self.grad_scaler is None:
            self._scale_one = torch.cuda.FloatTensor([1.0])


    def get_loss_scale(self):
        if self.grad_scaler is None:
            return self._scale_one
        return self.grad_scaler.scale


    def reload_model_params(self):
        self._copy_model_params_to_main_params()


    def _unscale_main_grads_and_check_for_nan(self):

        # Collect main grads.
        main_grads = self._collect_main_grad_data_for_unscaling()
        # pax(1, {"main_grads": main_grads})

        # Reset found inf.
        self.found_inf.fill_(0.0)

        # Unscale and set found inf/nan
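        # (This in-place op multiplies every grad by inv_scale and sets
        # found_inf to 1.0 if any grad contains an inf/nan.)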
        torch._amp_foreach_non_finite_check_and_unscale_(
            main_grads, self.found_inf, self.grad_scaler.inv_scale)

        # Update across all model parallel instances.
        # >>>
        # torch.distributed.all_reduce(self.found_inf,
        #                              op=torch.distributed.ReduceOp.MAX,
        #                              group=mpu.get_model_parallel_group())
        # +++
        torch.distributed.all_reduce(self.found_inf,
                                     op=torch.distributed.ReduceOp.MAX)
        # <<<

        # Check for nan.
        found_inf_flag = (self.found_inf.item() > 0)

        return found_inf_flag


    @torch.no_grad()
    def step(self, ITERATION):

        timers = get_timers()

        # Copy gradients from model params to main params.
        timers('optimizer-copy-to-main-grad').start()
        self._copy_model_grads_to_main_grads(ITERATION)
        timers('optimizer-copy-to-main-grad').stop()

        # >>>
        # pax(0, {
        #     "[LOC]" : "[** BEFORE UNSCALE **]",
        #     "param_group / params" : [ p for g in self.optimizer.param_groups for p in g["params"] ],
        #     "param_group / grads" : [ p.grad for g in self.optimizer.param_groups for p in g["params"] ],
        # })
        # <<<

        # pax(0, {
        #     "params" : self.get_parameters(), # self.main_param_shards,
        #     "grads" : [ p.grad for p in self.get_parameters() ], # self.main_param_shards ],
        # })

        # Do unscale, check for inf, and update grad scaler only for
        # the case that grad scaler is provided.
        if self.grad_scaler:

            # Unscale and check for inf/nan.
            timers('optimizer-unscale-and-check-inf').start()
            found_inf_flag = self._unscale_main_grads_and_check_for_nan()
            timers('optimizer-unscale-and-check-inf').stop()

            # We are done with scaling gradients
            # so we can update the loss scale.
            self.grad_scaler.update(found_inf_flag)

            # If we found inf/nan, skip the update.
            if found_inf_flag:
                # pax(0, {
                #     "main params" : self.get_main_params(),
                #     "main grads" : self.get_main_grads(),
                #     "found_inf_flag" : found_inf_flag,
                # })
                return False, None, None

        # >>>
        # pax(0, {
        #     "[LOC]" : "[** BEFORE CLIP **]",
        #     "clip_grad" : self.clip_grad,
        #     # "param_group / params" : [ p for g in self.optimizer.param_groups for p in g["params"] ],
        #     "param_group / grads" : [ p.grad for g in self.optimizer.param_groups for p in g["params"] ],
        # })
        # <<<

        # Clip the main gradients.
        timers('optimizer-clip-main-grad').start()
        grad_norm = None
        if self.clip_grad > 0.0:
            grad_norm = self.clip_grad_norm(self.clip_grad)
        timers('optimizer-clip-main-grad').stop()

        # >>>
        # pax(1, {
        #     "[LOC]" : "[** BEFORE NONZERO **]",
        #     # "param_group / params" : [ p for g in self.optimizer.param_groups for p in g["params"] ],
        #     "param_group / grads" : [ p.grad for g in self.optimizer.param_groups for p in g["params"] ],
        # })
        # <<<

        # count the zeros in the grads
        num_zeros_in_grad = self.count_zeros() if \
                            self.log_num_zeros_in_grad else None

        # >>>
        # pax(0, {
        #     # "main params" : self.get_main_params(),
        #     # "main grads" : self.get_main_grads(),
        #     **{"param_groups / %d" % i : g for i, g in enumerate(self.optimizer.param_groups)},
        #     "param_group / grads" : [ p.grad for g in self.optimizer.param_groups for p in g["params"] ],
        # })
        # <<<

        # Step the optimizer.
        self.optimizer.step()

        # Update params from main params.
        timers('optimizer-copy-main-to-model-params').start()
        self._copy_main_params_to_model_params(ITERATION)
        timers('optimizer-copy-main-to-model-params').stop()

        # >>>
        # pax(1, {
        #     "ITERATION" : ITERATION,
        #     "model_params" : [ p for m in self.models for p in m.parameters() ],
        # })
        # <<<

        # Successful update.
        return True, grad_norm, num_zeros_in_grad


# class Float16OptimizerWithFloat16Params(MegatronOptimizer):
class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
    """Float16 optimizer for fp16 and bf16 data types.

    Arguments:
        optimizer: base optimizer such as Adam or SGD
        clip_grad: clip gradeints with this global L2 norm. Note
            that clipping is ignored if clip_grad == 0
        log_num_zeros_in_grad: return number of zeros in the gradients.
        params_have_main_grad: flag indicating if parameters have
            a `main_grad` field. If this is set, we are assuming
            that the model parameters are store in the `main_grad`
            field instead of the typical `grad` field. This happens
            for the DDP cases where there is a continuous buffer
            holding the gradients. For example for bfloat16, we want
            to do gradient accumulation and all-reduces in float32
            and as a result we store those gradients in the main_grad.
            Note that main grad is not necessarily in float32.
        bf16: if true, the model is running in bfloat16.
        grad_scaler: used for scaling gradients. Note that this can be
            None. This case happens when `bf16 = True` and we don't
            use any loss scale. Note that for `bf16 = True`, we can have
            a constnat gradient scaler. Also for `bf16 = False`, we
            always require a grad scaler.
    """

    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                 bf16, grad_scaler, models):

        super().__init__(
            optimizer, clip_grad, log_num_zeros_in_grad,
            params_have_main_grad, use_contiguous_buffers_in_local_ddp,
            bf16, grad_scaler, models)

        # ======================
        # main parameter stuff
        # ======================

        # Three groups of parameters:
        #   float16_groups: original float16 parameters
        #   fp32_from_float16_groups: fp32 copy of float16 parameters
        #   fp32_from_fp32_groups: original fp32 parameters
        self.float16_groups = []
        self.fp32_from_float16_groups = []
        self.fp32_from_fp32_groups = []

        # For all the groups in the original optimizer:
        for param_group in self.optimizer.param_groups:
            float16_params_this_group = []
            fp32_params_this_group = []
            fp32_from_float16_params_this_group = []
            # For all the parameters in this group:
            for i, param in enumerate(param_group['params']):
                if param.requires_grad:

                    # float16 params:
                    if param.type() in ['torch.cuda.HalfTensor',
                                        'torch.cuda.BFloat16Tensor']:
                        float16_params_this_group.append(param)
                        # Create a copy
                        main_param = param.detach().clone().float()
                        # Copy tensor model parallel attributes.
                        mpu.copy_tensor_model_parallel_attributes(main_param,
                                                                  param)
                        if hasattr(param, 'shared'):
                            main_param.shared = param.shared
                        # Replace the optimizer params with the new fp32 copy.
                        param_group['params'][i] = main_param

                        fp32_from_float16_params_this_group.append(main_param)
                        # Reset existing state dict key to the new main param.
                        if param in self.optimizer.state:
                            # >>>
                            raise Exception("hi.")
                            # <<<
                            self.optimizer.state[main_param] \
                                = self.optimizer.state.pop(param)

                    # fp32 params.
                    elif param.type() == 'torch.cuda.FloatTensor':
                        # >>>
                        pax(0, {"param": param})
                        # <<<
                        fp32_params_this_group.append(param)
                        param_group['params'][i] = param

                    else:
                        raise TypeError('Wrapped parameters must be one of '
                                        'torch.cuda.FloatTensor,  '
                                        'torch.cuda.HalfTensor, or '
                                        'torch.cuda.BFloat16Tensor. '
                                        'Received {}'.format(param.type()))

            self.float16_groups.append(float16_params_this_group)
            self.fp32_from_float16_groups.append(
                fp32_from_float16_params_this_group)
            self.fp32_from_fp32_groups.append(fp32_params_this_group)

        # Leverage state_dict() and load_state_dict() to
        # recast preexisting per-param state tensors
        self.optimizer.load_state_dict(self.optimizer.state_dict())

        # >>>
        # from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
        # params = self.get_parameters()
        # pax(0, {
        #     # "params / 0" : params[0],
        #     "params" : [ (p.tensor_model_parallel, tp(p)) for p in params ],
        #     "grads" : [ (param_is_not_tensor_parallel_duplicate(p.grad), tp(p.grad)) for p in params ],
        # })
        # <<<


    def zero_grad(self, set_to_none=True):
        """We only need to zero the model related parameters, i.e.,
        float16_groups & fp32_from_fp32_groups. We additionally zero
        fp32_from_float16_groups as a memory optimization to reduce
        fragmentation; in the case of set_to_none==True, the space
        used by this field can be safely deallocated at this point."""
        for group in self.float16_groups:
            _zero_grad_group_helper(group, set_to_none)
        for group in self.fp32_from_float16_groups:
            _zero_grad_group_helper(group, set_to_none)
        for group in self.fp32_from_fp32_groups:
            _zero_grad_group_helper(group, set_to_none)


    # >>>
    def reduce_grads(self, model):

        # >>>
        from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

        from megatron import get_args
        from megatron import get_timers
        from megatron.model import DistributedDataParallel as LocalDDP
        from megatron.model import Float16Module
        from megatron.utils import unwrap_model

        args = get_args()
        timers = get_timers()
        # <<<

        # >>>
        # pax(0, {
        #     "grads" : [ p.main_grad for m in model for p in m.parameters() ],
        # })
        # <<<

        # All-reduce if needed.
        if args.DDP_impl == 'local':
            timers('backward-params-all-reduce').start()
            for model_module in model:
                model_module.allreduce_gradients()
            timers('backward-params-all-reduce').stop()

        # >>>
        # pax(0, {
        #     "grads" : [ p.main_grad for m in model for p in m.parameters() ],
        # })
        # <<<

        # All-reduce word_embeddings' grad across first and last stages to ensure
        # that word_embeddings parameters stay in sync.
        # This should only run for models that support pipelined model parallelism
        # (BERT and GPT-2).
        timers('backward-embedding-all-reduce').start()
        if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
                mpu.get_pipeline_model_parallel_world_size() > 1:
            # >>>
            raise Exception("hi.")
            # <<<
            if mpu.is_pipeline_first_stage(ignore_virtual=True):
                unwrapped_model = model[0]
            elif mpu.is_pipeline_last_stage(ignore_virtual=True):
                unwrapped_model = model[-1]
            else:  # We do not support the interleaved schedule for T5 yet.
                unwrapped_model = model[0]
            unwrapped_model = unwrap_model(
                unwrapped_model, (torchDDP, LocalDDP, Float16Module))

            if unwrapped_model.share_word_embeddings:
                word_embeddings_weight = unwrapped_model.word_embeddings_weight()
                # >>>
                if args.DDP_impl == 'local':
                    grad = word_embeddings_weight.main_grad
                else:
                    grad = word_embeddings_weight.grad
                torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
                # +++
                # grad_shard = optimizer.get_grad_shard(word_embeddings)
                # torch.distributed.all_reduce(grad_shard,
                #                              group=mpu.get_embedding_group())
                # <<<

        # All-reduce position_embeddings grad across first (encoder) and split (decoder) 
        # stages to ensure that position embeddings parameters stay in sync.
        # This should only run for T5 models with pipeline parallelism
        if mpu.is_rank_in_position_embedding_group() and \
                mpu.get_pipeline_model_parallel_world_size() > 1 and \
                args.pipeline_model_parallel_split_rank is not None:
            unwrapped_model = model[0]
            unwrapped_model = unwrap_model(
                unwrapped_model, (torchDDP, LocalDDP, Float16Module))
            assert args.DDP_impl == 'local', \
                'T5 model is only supported with local DDP mode'
            # >>>
            grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
            torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
            # +++
            # grad_shard = optimizer.get_grad_shard(
            #     unwrapped_model.language_model.embedding.position_embeddings.weight)
            # torch.distributed.all_reduce(grad_shard,
            #                              group=mpu.get_position_embedding_group())
            # <<<
        timers('backward-embedding-all-reduce').stop()

    def gather_params(self):
        pass

    def _copy_model_grads_to_main_grads(self, ITERATION):
        # This only needs to be done for the float16 group.
        for model_group, main_group in zip(self.float16_groups,
                                           self.fp32_from_float16_groups):
            for model_param, main_param in zip(model_group, main_group):
                if self.params_have_main_grad and hasattr(model_param, 'main_grad'):
                    main_param.grad = model_param.main_grad.float()
                else:
                    if model_param.grad is not None:
                        main_param.grad = model_param.grad.float()

                # Safe to deallocate model's grad/main_grad after copying.
                # (If using contiguous buffers, main_grad's memory should
                # persist and therefore should not be deallocated.)
                model_param.grad = None
                if self.params_have_main_grad and \
                   not self.use_contiguous_buffers_in_local_ddp:
                    model_param.main_grad = None

        # For fp32 grads, we need to reset the grads to main grad.
        if self.params_have_main_grad:
            for model_group in self.fp32_from_fp32_groups:
                for model_param in model_group:
                    model_param.grad = model_param.main_grad

                    # Safe to de-reference model's main_grad after copying.
                    # (If using contiguous buffers, main_grad's memory should
                    # persist and therefore should not be deallocated.)
                    if not self.use_contiguous_buffers_in_local_ddp:
                        model_param.main_grad = None

        # >>>
        # if ITERATION == DEBUG_ITERATION:
        #     pax(0, {
        #         "** branch **" : "** main. **",
        #         "ITERATION" : ITERATION,
        #         "model grads" :
        #         [ p.main_grad for m in self.models for p in m.parameters() ],
        #     })
        # <<<

    def _collect_main_grad_data_for_unscaling(self):

        main_grads = []

        # fp32 params from float16 ones.
        for main_group in self.fp32_from_float16_groups:
            for main_param in main_group:
                if main_param.grad is not None:
                    main_grads.append(main_param.grad.data)

        # pax(1, {"main_grads": main_grads})

        # Append fp32 parameters.
        for main_group in self.fp32_from_fp32_groups:
            for main_param in main_group:
                if main_param.grad is not None:
                    main_grads.append(main_param.grad.data)
        
        # >>>
        # from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
        # pax(1, {"main_grads": [ (param_is_not_tensor_parallel_duplicate(t), tp(t)) for t in main_grads ]})
        # <<<

        return main_grads


    def _get_model_and_main_params_data_float16(self):
        model_data = []
        main_data = []
        for model_group, main_group in zip(self.float16_groups,
                                           self.fp32_from_float16_groups):
            for model_param, main_param in zip(model_group, main_group):
                model_data.append(model_param.data)
                main_data.append(main_param.data)
        return model_data, main_data


    def _copy_main_params_to_model_params(self, ITERATION):
        # Only needed for the float16 params.
        model_data, main_data = self._get_model_and_main_params_data_float16()
        _multi_tensor_copy_this_to_that(this=main_data, that=model_data,
                                        overflow_buf=self._dummy_overflow_buf)
        # >>>
        if ITERATION == DEBUG_ITERATION:
            pax(0, {
                "** branch **" : "** main. **",
                "ITERATION" : ITERATION,
                "model params" : [p for m in self.models for p in m.parameters()],
            })
        # <<<


    def _copy_model_params_to_main_params(self):
        # Only needed for the float16 params.
        model_data, main_data = self._get_model_and_main_params_data_float16()
        _multi_tensor_copy_this_to_that(this=model_data, that=main_data,
                                        overflow_buf=self._dummy_overflow_buf)


    def state_dict(self):
        state_dict = {}
        state_dict['optimizer'] = self.optimizer.state_dict()
        if self.grad_scaler:
            state_dict['grad_scaler'] = self.grad_scaler.state_dict()
        state_dict['fp32_from_fp16_params'] = self.fp32_from_float16_groups
        return state_dict


    def load_state_dict(self, state_dict):
        # Optimizer.
        optimizer_key = 'optimizer'
        if optimizer_key not in state_dict:
            optimizer_key = 'optimizer_state_dict'
            print_rank_0('***WARNING*** loading optimizer from '
                         'an old checkpoint ...')
        self.optimizer.load_state_dict(state_dict[optimizer_key])

        # Grad scaler.
        if 'grad_scaler' not in state_dict:
            print_rank_0('***WARNING*** found an old checkpoint, will not '
                         'load grad scaler ...')
        else:
            if self.grad_scaler:
                self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
            else:
                print_rank_0('***WARNING*** found the grad scaler in the '
                             'checkpoint but it is None in the class. '
                             'Skipping loading grad scaler ...')

        # Copy data for the main params.
        fp32_from_float16_params_key = 'fp32_from_fp16_params'
        if fp32_from_float16_params_key not in state_dict:
            fp32_from_float16_params_key = 'fp32_from_fp16'
        for current_group, saved_group in zip(
                self.fp32_from_float16_groups,
                state_dict[fp32_from_float16_params_key]):
            for current_param, saved_param in zip(current_group, saved_group):
                current_param.data.copy_(saved_param.data)


# >>>
import math

from megatron import get_args

# class ShardIndex:
class Shard:
    def __init__(self, start, end):
        self.start = start
        self.end = end
        self.size = end - start
    def normalize(self, start = 0):
        return Shard(start, start + self.size)
    def __str__(self):
        return "%d,%d [%d]" % (self.start, self.end, self.size)
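
# A sketch of how Shard composes (illustration only):
#   s = Shard(40, 50)   # 10 elements inside some world-sized buffer
#   s.normalize()       # -> Shard(0, 10): same extent, local coordinates
#   s.normalize(100)    # -> Shard(100, 110): same extent, re-based at 100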

# class Float16DistributedOptimizer(Float16OptimizerWithFloat16Params):
# class Float16DistributedOptimizer(MegatronOptimizer):
class Float16DistributedOptimizer(BaseFloat16Optimizer):

    # >>>
    # @classmethod
    # def test_reduce_scatter(cls):

    #     torch.manual_seed(mpu.get_data_parallel_rank())
    #     size = (20,)
    #     dtype = torch.float
    #     device = torch.cuda.current_device()
    #     data_parallel_world_size = mpu.get_data_parallel_world_size()
    #     data_parallel_group = mpu.get_data_parallel_group()

    #     input_list = [
    #         # torch.randn(size, dtype = dtype, device = device)
    #         5 * torch.randint(low = 1, high = 3, size = size, dtype = dtype, device = device)
    #         for _ in range(data_parallel_world_size)
    #     ]
    #     output = torch.empty(size, dtype = dtype, device = device)

    #     torch.distributed.reduce_scatter(
    #         output,
    #         input_list,
    #         group = data_parallel_group,
    #     )

    #     if torch.distributed.get_rank() == 0:
    #         print(output)
    #     pax(0, {
    #         "data_parallel_world_size" : data_parallel_world_size,
    #         "data_parallel_group" : data_parallel_group,
    #         "input_list" : input_list,
    #         "output" : tp(output),
    #     })
    # <<<

    @classmethod
    def get_model_gbuf_param_shard_map(cls, model, dtype, gbuf_world_shard):

        # Param shard map.
        param_world_index_map = model._grad_buffer_param_index_map[dtype]
        param_shard_map = {}
        for param, param_world_indexes in param_world_index_map.items():

            # Shard range.
            param_world_start, param_world_end = param_world_indexes
            param_local_start = max(
                0,
                param_world_start - gbuf_world_shard.start)
            param_local_end = min(
                gbuf_world_shard.size,
                param_world_end - gbuf_world_shard.start)

            # Add shard, if within range.
            if param_local_end > param_local_start:
                param_local_shard = Shard(param_local_start, param_local_end)
                # param_world_shard = param_local_shard.normalize(param_world_start)
                param_world_shard = param_local_shard.normalize(
                    param_local_start + gbuf_world_shard.start)
                sub_param_start = max(0, gbuf_world_shard.start-param_world_start)
                sub_param_shard = param_local_shard.normalize(sub_param_start)
                param_shard_map[param] = {
                    "gbuf_world" : param_world_shard,
                    "gbuf_local" : param_local_shard,
                    "param" : sub_param_shard,
                }
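                # (Illustration: with gbuf_world_shard = [100,200) and a
                # param spanning world indexes [90,130), the overlap is
                # "gbuf_world" [100,130), "gbuf_local" [0,30), and "param"
                # [10,40) of the param's own elements.)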

        # pax(0, {"param_shard_map": [ str((str(p.shape), s)) for p,s in param_shard_map.items() ]})

        return param_shard_map

    @classmethod
    def get_model_gbuf_shard(cls, model, dtype):

        data_parallel_rank = mpu.get_data_parallel_rank()
        data_parallel_world_size = mpu.get_data_parallel_world_size()

        # Grad buffer shard.
        grad_buffer = model._grad_buffers[dtype]
        gbuf_size = grad_buffer.numel
        max_gbuf_shard_size = int(math.ceil(gbuf_size / data_parallel_world_size))

        gbuf_world_all_shards = []
        for r in range(data_parallel_world_size):
            gbuf_world_start = r * max_gbuf_shard_size
            gbuf_world_end = min(gbuf_size, gbuf_world_start+max_gbuf_shard_size)
            gbuf_world_shard = Shard(gbuf_world_start, gbuf_world_end)
            gbuf_world_all_shards.append(gbuf_world_shard)
        gbuf_world_shard = gbuf_world_all_shards[data_parallel_rank]
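        # (Worked example, illustration only: gbuf_size=10 and
        # data_parallel_world_size=4 give max_gbuf_shard_size=3 and world
        # shards [0,3) [3,6) [6,9) [9,10); rank r owns gbuf_world_all_shards[r].)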
        gbuf_local_shard = gbuf_world_shard.normalize()

        # Param shards.
        param_shard_map = cls.get_model_gbuf_param_shard_map(model,
                                                             dtype,
                                                             gbuf_world_shard)

        # Altogether.
        data = {
            "local" : gbuf_local_shard,
            "world" : gbuf_world_shard,
            "world_all" : gbuf_world_all_shards,
            "param_map" : param_shard_map,
        }

        # pax(1, {"data": data})

        return data

    @classmethod
    def get_model_gbuf_shard_map(cls, model):
        return {
            dtype : cls.get_model_gbuf_shard(model, dtype)
            for dtype in model._grad_buffers
        }

    @classmethod
    def get_param_gbuf_map(cls, model_gbuf_shards):

        param_gbuf_map = {}
        for model_index, model_gbuf_shard_map in enumerate(model_gbuf_shards):
            for dtype, gbuf_shard_map in model_gbuf_shard_map.items():
                for param, param_shard_map in gbuf_shard_map["param_map"].items():
                    # assert param not in param_size_map
                    # param_size_map[param] = param_shard_map["local"].size
                    param_gbuf_map[param] = (model_index, dtype)
                    # pax(0, {
                    #     "dtype" : dtype,
                    #     "gbuf_shard_map" : gbuf_shard_map,
                    #     "param" : tp(param),
                    #     "param_shard_map" : param_shard_map,
                    # })

        # pax(0, {
        #     "model_gbuf_shards" : model_gbuf_shards,
        #     # "param_size_map" :
        #     # [ (str(p.shape), s) for p, s in param_size_map.items() ],
        #     "param_gbuf_map" : param_gbuf_map,
        # })

        return param_gbuf_map

    @classmethod
    def get_optimizer_group_shards(cls, param_groups, model_gbuf_shards):

        num_groups = len(param_groups)

        # Param group map.
        param_group_map = {}
        for group_index, group in enumerate(param_groups):
            for param in group["params"]:
                assert param.requires_grad
                param_group_map[param] = group_index

        # Optimizer group shards.
        group_shards = [ {"size": 0, "param_map": {}} for _ in param_groups ]
        for model_gbuf_shard_map in model_gbuf_shards:
            for dtype, gbuf_shard_map in model_gbuf_shard_map.items():
                for param in gbuf_shard_map["param_map"]:
                    
                    group_index = param_group_map[param]
                    group_shard = group_shards[group_index]
                    param_size = gbuf_shard_map["param_map"][param]["param"].size

                    param_group_start = group_shard["size"]
                    param_group_end = param_group_start + param_size
                    param_group_shard = Shard(param_group_start, param_group_end)

                    group_shard["size"] += param_size
                    group_shard["param_map"][param] = param_group_shard
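                    # (Each param's shard is packed contiguously into its
                    # optimizer group; param_group_shard records the param's
                    # [start, end) within the group's flat main-param buffer.)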

                    # >>>
                    # if torch.distributed.get_rank() == 1:
                    #     print(">>> [%d] ... group %d, size %d, param %s. <<<" % (
                    #         torch.distributed.get_rank(),
                    #         group_index,
                    #         param_size,
                    #         str(tuple(param.shape)),
                    #     ))
                    # <<<

        # Squeeze zero-size group shards.
        for group_index, group_shard in enumerate(group_shards):
            group_shard["orig_group"] = param_groups[group_index]
        group_shards = [ g for g in group_shards if g["size"] > 0 ]

        # pax(0, {
        #     "param_group_map": [
        #         (g, str(p.shape))
        #         for p, g in param_group_map.items()
        #     ],
        #     "group_shards" : group_shards,
        # })

        return group_shards

    @classmethod
    def allocate_main_param_shards(cls, opt_group_shards):

        # Allocate main param/grad shard.
        # ** torch.nn.Parameter ??
        # ** MemoryBuffer ??
        allocate_shard = lambda shard_size, dtype : torch.empty(
            (shard_size,),
            dtype = dtype,
            device = torch.cuda.current_device(),
            requires_grad = True)
        
        # main_param_shards = []
        for group_index, group_shard in enumerate(opt_group_shards):

            group_size = group_shard["size"]
            assert group_size != 0, "temporary check ... remove me."

            # ** todo: for dtype in model_main_dtypes ........ **

            # Allocate shard.
            # if group_size == 0:
            #     main_param = None
            # else:
            main_param = allocate_shard(group_size, torch.float)
            main_param.grad = allocate_shard(group_size, torch.float)
            mpu.set_tensor_model_parallel_attributes(main_param, True, 0, 1)
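            # (Marking the flat shard as tensor-parallel keeps utilities such
            # as clip_grad_norm_fp32 from skipping it as a TP duplicate;
            # every rank's shard must count toward the global grad norm.)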

            # main_param_shards.append(main_param)
            group_shard["orig_group"]["params"] = [ main_param ]

            # # Update optimizer group.
            # self.optimizer.param_groups[group_index]["params"] = [ main_param ]

        # pax(1, {
        #     "opt_group_shards" : opt_group_shards,
        #     "main_param_shards" : main_param_shards,
        # })

        # return main_param_shards

    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                 bf16, grad_scaler, models):

        super().__init__(
            optimizer, clip_grad, log_num_zeros_in_grad,
            params_have_main_grad, use_contiguous_buffers_in_local_ddp,
            bf16, grad_scaler, models)

        # >>>
        args = get_args()
        assert args.use_contiguous_buffers_in_local_ddp # already checked in args
        # <<<

        # # Data parallel info.
        # self.data_parallel_group = mpu.get_data_parallel_group()
        # self.data_parallel_rank = mpu.get_data_parallel_rank()
        # self.data_parallel_world_size = mpu.get_data_parallel_world_size()

        # Model grad buffer shards.
        self.model_gbuf_shards = []
        for model_index, model in enumerate(self.models):
            self.model_gbuf_shards.append(self.get_model_gbuf_shard_map(model))
        self.param_gbuf_map = self.get_param_gbuf_map(self.model_gbuf_shards)

        # pax(0, {"param_gbuf_map": [ (str(tuple(p.shape)), d) for p, d in self.param_gbuf_map.items() ]})

        # Optimizer shards.
        self.opt_group_shards = self.get_optimizer_group_shards(
            self.optimizer.param_groups,
            self.model_gbuf_shards)

        # pax(0, {**{"opt_group_shards / %d" % i : g for i, g in enumerate(self.opt_group_shards)}})

        # Allocate main param shards.
        # self.main_param_shards = \
        #     self.allocate_main_param_shards(self.opt_group_shards)
        self.allocate_main_param_shards(self.opt_group_shards)

        # >>>
        # pax(0, {
        #     "model_gbuf_shards" : self.model_gbuf_shards,
        #     "opt_group_shards" : self.opt_group_shards,
        #     "main_param_shards" : self.main_param_shards,
        # })
        # <<<

        # Update optimizer groups.
        # - Also, leverage state_dict() and load_state_dict() to
        #   recast preexisting per-param state tensors.
        self.optimizer.param_groups = \
            [ g["orig_group"] for g in self.opt_group_shards ]
        self.optimizer.load_state_dict(self.optimizer.state_dict())

        # pax(0, {
        #     # "opt_group_shards" : self.opt_group_shards,
        #     # "param_groups" : self.optimizer.param_groups,
        #     "optimizer" : self.optimizer,
        #     "optimizer / state" : self.optimizer.state,
        # })
        # pax(1, {
        #     "optimizer" : self.optimizer,
        #     **{"optimizer / param_groups / %d" % i : g
        #        for i, g in enumerate(self.optimizer.param_groups)},
        #     "optimizer / state" : self.optimizer.state,
        #     "optimizer / state_dict" : self.optimizer.state_dict(),
        # })

        # Initialize main params.
        self._copy_model_params_to_main_params()

    @staticmethod
    def has_nan_debug(tensors):
        if isinstance(tensors, torch.Tensor):
            tensors = [ tensors ]
        assert isinstance(tensors, list)
        has_nans = [ (not torch.all(torch.isfinite(t)).item()) for t in tensors ]
        has_nan = any(has_nans)
        return has_nan
    def get_local_model_param_views(self):
        '''** FOR DEBUGGING. **'''
        model_param_views = []
        for group_index, opt_group_shard in enumerate(self.opt_group_shards):
            for param, opt_shard in opt_group_shard["param_map"].items():
                model_index, dtype = self.param_gbuf_map[param]
                gbuf_shard_map = \
                    self.model_gbuf_shards[model_index][dtype]["param_map"][param]
                model_param_shard = gbuf_shard_map["param"]
                model_param_views.append(
                    param.view(-1)[model_param_shard.start:model_param_shard.end])
        return model_param_views
    def get_local_model_grad_views(self):
        '''** FOR DEBUGGING. **'''
        model_grad_views = []
        for group_index, opt_group_shard in enumerate(self.opt_group_shards):
            for param, opt_shard in opt_group_shard["param_map"].items():
                model_index, dtype = self.param_gbuf_map[param]
                gbuf = self.models[model_index]._grad_buffers[dtype].data
                gbuf_shard_map = \
                    self.model_gbuf_shards[model_index][dtype]["param_map"][param]
                gbuf_world_shard = gbuf_shard_map["gbuf_world"]
                model_grad_views.append(
                    gbuf[gbuf_world_shard.start:gbuf_world_shard.end])
        return model_grad_views
    def get_world_model_params(self):
        '''** FOR DEBUGGING. **'''
        return [ p for m in self.models for p in m.parameters() ]
    def get_world_model_grads(self):
        '''** FOR DEBUGGING. **'''
        return [ p.main_grad for p in self.get_world_model_params() ]

    def get_main_params(self):
        return [ g["params"][0] for g in self.optimizer.param_groups ]
    def get_main_grads(self):
        return [ p.grad for p in self.get_main_params() ]
    def get_main_param(self, group_index):
        # return self.optimizer.param_groups[group_index]["params"][0]
        return self.get_main_params()[group_index]
    def get_main_grad(self, group_index):
        return self.get_main_param(group_index).grad

    def load_state_dict(self):
        raise NotImplementedError(
            "load_state_dict() not yet supported by distributed optimizer.")
    def reload_model_params(self):
        raise NotImplementedError(
            "reload_model_params() not yet supported by distributed optimizer.")
    def state_dict(self):
        raise NotImplementedError(
            "state_dict() not yet supported by distributed optimizer.")

    def zero_grad(self, set_to_none=True):

        model_params = []
        for model in self.models:
            for dtype, param_map in model._grad_buffer_param_index_map.items():
                model_params.extend(param_map.keys())
        # main_params = []
        # for main_group in self.optimizer.param_groups:
        #     main_params.extend(main_group["params"])

        # ** using contiguous buffer; don't set_to_none **
        _zero_grad_group_helper(model_params, set_to_none = False) # set_to_none)
        # _zero_grad_group_helper(params, set_to_none = False)

        # pax(0, {"model_params": model_params})

    def get_model_grad_buffer_dp_views(self):

        # >>>
        # ** only contiguous grad buffer supported, for now [ TEMPORARY ] **
        args = get_args()
        assert args.use_contiguous_buffers_in_local_ddp
Lawrence McAfee's avatar
Lawrence McAfee committed
1138
        # <<<
1139
1140
1141
1142
1143
1144

        # Grad buffer views.
        gbuf_view_items = []
        for model_index, model in enumerate(self.models):
            for dtype, gbuf_shard in self.model_gbuf_shards[model_index].items():
                world_shards = gbuf_shard["world_all"]
1145
1146
                gbuf = model._grad_buffers[dtype].data
                gbuf_views = [ gbuf[s.start:s.end] for s in world_shards ]
1147
1148
                gbuf_view_items.append((model_index, dtype, gbuf_views))

1149
1150
1151
1152
1153
                # pax(0, {
                #     "world_shards" : world_shards,
                #     "gbuf_views" : gbuf_views,
                # })

1154
1155
1156
1157
        # pax(0, {"gbuf_view_items": gbuf_view_items})

        return gbuf_view_items
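
    # Illustration (sizes hypothetical): with 4 data-parallel ranks and a
    # single 1024-element grad buffer, the return value is shaped like:
    #
    #   [ (0, torch.float32, [ gbuf[0:256], gbuf[256:512],
    #                          gbuf[512:768], gbuf[768:1024] ]) ]
    #
    # i.e. one entry per (model, dtype), whose views partition the buffer
    # across the data-parallel group.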

    def reduce_grads(self, model):
        '''Sync embedding grads across pipeline stages, then reduce-scatter
        each grad buffer across the data-parallel ranks.'''

        timers = get_timers()

        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Sync word embedding params.
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

        # All-reduce word_embeddings' grad across first and last stages to
        # ensure that word_embeddings parameters stay in sync. This should
        # only run for models that support pipelined model parallelism
        # (BERT and GPT-2).
        timers('backward-embedding-all-reduce').start()
        if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
                mpu.get_pipeline_model_parallel_world_size() > 1:

            # >>>
            raise Exception("embedding grad sync not yet supported with "
                            "the distributed optimizer.")
            # <<<

            if mpu.is_pipeline_first_stage(ignore_virtual=True):
                unwrapped_model = model[0]
            elif mpu.is_pipeline_last_stage(ignore_virtual=True):
                unwrapped_model = model[-1]
            else:  # We do not support the interleaved schedule for T5 yet.
                unwrapped_model = model[0]
            unwrapped_model = unwrap_model(
                unwrapped_model, (torchDDP, LocalDDP, Float16Module))

            if unwrapped_model.share_word_embeddings:
                word_embeddings_weight = unwrapped_model.word_embeddings_weight()
                # Only the local DDP's '.main_grad' is supported here.
                from megatron import get_args
                args = get_args()
                if args.DDP_impl == 'local':
                    grad = word_embeddings_weight.main_grad
                else:
                    raise Exception("only 'main_grad' supported for distrib-opt.")
                torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
        timers('backward-embedding-all-reduce').stop()

        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Sync T5 position embedding params.
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

        # ... todo ...

        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Reduce-scatter.
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

        data_parallel_rank = mpu.get_data_parallel_rank()
        data_parallel_world_size = mpu.get_data_parallel_world_size()
        data_parallel_group = mpu.get_data_parallel_group()

        gbuf_view_items = self.get_model_grad_buffer_dp_views()
        for model_index, dtype, gbuf_views in gbuf_view_items:

            # Pre-scale by 1/world_size, so the reduce-scatter's sum yields
            # an average.
            gbuf = self.models[model_index]._grad_buffers[dtype].data
            gbuf.mul_(1. / data_parallel_world_size)

            # Each rank receives the reduced result for its own shard.
            torch.distributed.reduce_scatter(
                gbuf_views[data_parallel_rank],
                gbuf_views,
                group = data_parallel_group,
            )

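    # Collective sketch (standalone; assumes the buffer length divides
    # evenly by the world size): after reduce_scatter, each rank holds the
    # averaged values for its own shard only.
    #
    #   world = torch.distributed.get_world_size()
    #   gbuf.mul_(1. / world)                        # pre-scale for averaging
    #   shards = list(gbuf.chunk(world))             # equal-size views
    #   torch.distributed.reduce_scatter(
    #       shards[torch.distributed.get_rank()],    # output: my shard
    #       shards)                                  # inputs: all shards
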
    def gather_params(self):
        '''All-gather updated param shards across the data-parallel group,
        using the grad buffer as the communication buffer.'''

        data_parallel_rank = mpu.get_data_parallel_rank()
        data_parallel_group = mpu.get_data_parallel_group()

        gbuf_view_items = self.get_model_grad_buffer_dp_views()

        # All-gather updated main params; each rank contributes its own
        # (already updated) shard.
        for model_index, dtype, gbuf_views in gbuf_view_items:
            torch.distributed.all_gather(
                gbuf_views,
                gbuf_views[data_parallel_rank],
                group = data_parallel_group,
            )

        # Each model param now contains its updated values in its
        # '.main_grad' field (a view into the grad buffer); copy the values
        # into the param itself.
        for param in self.param_gbuf_map:
            param.detach().copy_(param.main_grad)

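    # Collective sketch (standalone): all_gather is the inverse of the
    # reduce-scatter above; afterwards, every rank holds the full buffer.
    #
    #   world = torch.distributed.get_world_size()
    #   shards = list(buf.chunk(world))               # equal-size views
    #   torch.distributed.all_gather(
    #       shards,                                   # outputs: all shards filled
    #       shards[torch.distributed.get_rank()])     # input: my updated shard
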
    def _collect_main_grad_data_for_unscaling(self):
        return [ g.data for g in self.get_main_grads() ]

    def _copy_model_params_to_main_params(self):

        for group_index, group_shard in enumerate(self.opt_group_shards):
            main_param = self.get_main_param(group_index)
            for model_param, main_shard in group_shard["param_map"].items():

                # Model shard.
                model_index, dtype = self.param_gbuf_map[model_param]
                model_shard = self.model_gbuf_shards \
                    [model_index][dtype]["param_map"][model_param]["param"]

                assert main_shard.size == model_shard.size

                # Copy shard data.
                main_view = main_param[main_shard.start:main_shard.end]
                model_view = model_param.view(-1)[model_shard.start:model_shard.end]

                main_view.detach().copy_(model_view)

    def _copy_model_grads_to_main_grads(self, ITERATION):

        for group_index, group_shard in enumerate(self.opt_group_shards):
            for model_param, main_shard in group_shard["param_map"].items():

                # Model shard (world index range within the grad buffer).
                model_index, dtype = self.param_gbuf_map[model_param]
                model_shard = self.model_gbuf_shards \
                    [model_index][dtype]["param_map"][model_param]["gbuf_world"]

                assert main_shard.size == model_shard.size

                # Copy from DDP's contiguous buffer to main shard's grad.
                model_grad = self.models[model_index]._grad_buffers[dtype].data
                main_grad = self.get_main_grad(group_index)

                # Copy sub-range within tensor.
                model_view = model_grad[model_shard.start:model_shard.end]
                main_view = main_grad[main_shard.start:main_shard.end]

                main_view.detach().copy_(model_view)
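
    # Illustration (ranges hypothetical): if a param occupies world indices
    # [512, 768) of its grad buffer, and this rank's optimizer group maps it
    # to main-shard indices [0, 256), the copy above is effectively:
    #
    #   main_grad[0:256].copy_(gbuf[512:768])
    #
    # i.e. a flat sub-range copy between 1-D buffers.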

    def _copy_main_params_to_model_params(self, ITERATION):

        for group_index, group_shard in enumerate(self.opt_group_shards):
            for model_param, main_shard in group_shard["param_map"].items():

                # Model shard (world index range within the grad buffer).
                model_index, dtype = self.param_gbuf_map[model_param]
                model_shard = self.model_gbuf_shards \
                    [model_index][dtype]["param_map"][model_param]["gbuf_world"]

                assert main_shard.size == model_shard.size

                # Use DDP's contiguous buffer to temporarily hold the updated
                # params, for subsequent all-gathering.
                model_param_buf = \
                    self.models[model_index]._grad_buffers[dtype].data
                main_param = self.get_main_param(group_index)

                # Copy sub-range within tensor.
                model_view = model_param_buf[model_shard.start:model_shard.end]
                main_view = main_param[main_shard.start:main_shard.end]

                model_view.detach().copy_(main_view)

        # >>>
        if ITERATION == DEBUG_ITERATION:
            pax(0, {
                "** branch **" : "** fix. **",
                "ITERATION" : ITERATION,
                "model params" : self.get_world_model_params(),
            })
        # <<<

# <<<
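
# A rough sketch of one training step with this distributed optimizer
# (ordering assumed from the methods above; 'step' follows the base
# optimizer path):
#
#   optimizer.zero_grad()                            # zero the grad buffers
#   loss.backward()                                  # grads accumulate in buffers
#   optimizer.reduce_grads(model)                    # reduce-scatter over DP ranks
#   optimizer._copy_model_grads_to_main_grads(it)    # local shard -> main grads
#   optimizer.step()                                 # update main param shards
#   optimizer._copy_main_params_to_model_params(it)  # main shard -> grad buffer
#   optimizer.gather_params()                        # all-gather over DP ranks
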

class FP32Optimizer(MegatronOptimizer):

    def __init__(self, optimizer, clip_grad,
                 log_num_zeros_in_grad,
                 params_have_main_grad,
                 use_contiguous_buffers_in_local_ddp):

        super(FP32Optimizer, self).__init__(
            optimizer, clip_grad, log_num_zeros_in_grad,
            params_have_main_grad, use_contiguous_buffers_in_local_ddp)

        self._scale = torch.cuda.FloatTensor([1.0])


    def zero_grad(self, set_to_none=True):
        """Copied from torch.optim.optimizer"""
        for group in self.optimizer.param_groups:
            _zero_grad_group_helper(group['params'], set_to_none)


    def get_loss_scale(self):
        """FP32 optimizer does not do any scaling."""
        return self._scale


    @torch.no_grad()
    def step(self):
        """Clip gradients (if needed) and step the base optimizer.
        Always return successful since there is no overflow."""

        # Copy main_grads to grads.
        if self.params_have_main_grad:
            for param_group in self.optimizer.param_groups:
                for param in param_group['params']:
                    param.grad = param.main_grad

                    # Safe to drop the reference to main_grad after copying.
                    # (If using contiguous buffers, main_grad's memory should
                    # persist and therefore should not be deallocated.)
                    if not self.use_contiguous_buffers_in_local_ddp:
                        param.main_grad = None

        # Clip gradients.
        grad_norm = None
        if self.clip_grad > 0.0:
            grad_norm = self.clip_grad_norm(self.clip_grad)

        # Count the zeros in the grads.
        num_zeros_in_grad = self.count_zeros() if \
                            self.log_num_zeros_in_grad else None

        # Update parameters.
        self.optimizer.step()

        # No overflow for FP32 optimizer.
        return True, grad_norm, num_zeros_in_grad
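
    # Sketch of how the return triple is typically consumed (training-loop
    # names assumed, not defined in this file):
    #
    #   update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
    #   if update_successful:
    #       ...  # e.g. advance the LR schedule, log grad_norm / num_zeros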
    def reload_model_params(self):
        pass


    def state_dict(self):
        return self.optimizer.state_dict()


    def load_state_dict(self, state_dict):
        self.optimizer.load_state_dict(state_dict)