# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Megatron distributed optimizer."""


import math
7
import torch
8
9

from megatron import get_args
10
from megatron import get_timers
Lawrence McAfee's avatar
Lawrence McAfee committed
11
from megatron import print_rank_0
12
from megatron.core import mpu, tensor_parallel
13
from megatron.model.module import param_is_not_shared
14
15

from .optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper
16

Lawrence McAfee's avatar
Lawrence McAfee committed
17

18
class Range:
    """
    A range represents the start and end points for indexing a shard
    from a full tensor.

    `start` is inclusive and `end` is exclusive (Python slice semantics,
    i.e. `tensor[start:end]`), so `size == end - start`.
    """

    def __init__(self, start, end):
        self.start = start
        self.end = end
        self.size = end - start

    def normalize(self, start=0):
        """Return a new Range of the same size, shifted to begin at `start`."""
        return Range(start, start + self.size)

    def __str__(self):
        return "%d,%d [%d]" % (self.start, self.end, self.size)


class DistributedOptimizer(MixedPrecisionOptimizer):
    """Distributed optimizer, for all data types (fp16, bf16, and fp32).

    Arguments:
        optimizer: base optimizer such as Adam or SGD
        clip_grad: clip gradients with this global L2 norm. Note
            that clipping is ignored if clip_grad == 0
        log_num_zeros_in_grad: return number of zeros in the gradients.
        params_have_main_grad: flag indicating if parameters have
            a `main_grad` field. If this is set, we are assuming
            that the model parameters' gradients are stored in the
            `main_grad` field instead of the typical `grad` field. This
            happens for the DDP cases where there is a contiguous buffer
            holding the gradients. For example for bfloat16, we want
            to do gradient accumulation and all-reduces in float32
            and as a result we store those gradients in the main_grad.
            Note that main grad is not necessarily in float32.
        use_contiguous_buffers_in_local_ddp: if true, the local DDP model
            is using a contiguous buffer to hold the model grads.
        fp16: if true, the model is running in fp16.
        bf16: if true, the model is running in bfloat16.
        grad_scaler: used for scaling gradients. Note that this can be
            None. This case happens when `bf16 = True` and we don't
            use any loss scale. Note that for `bf16 = True`, we can have
            a constant gradient scaler. Also for `bf16 = False`, we
            always require a grad scaler.
        models: list of models (i.e., the virtual pipelining models). This
            is used by the distributed optimizer for mapping parameters.
    """

    @classmethod
    def build_model_gbuf_param_range_map(cls, model, dtype, gbuf_world_range):
        """
        Build mapping from param reference to grad buffer shard ranges.

        This method builds a mapping from parameter references to grad
        buffer shard ranges, specific to each data-parallel (DP) rank's
        set of 'owned' parameters. Each grad buffer (padded to be an even
        multiple of DP-world-size) is conceptually divided into DP-world-size
        contiguous regions, where each DP rank 'owns' a contiguous region.
        Ownership in this sense means DP rank is responsible for reducing
        the relevant subset of grads, and updating the relevant subset of
        params.

        This conceptual partitioning of the grad buffer does NOT respect
        parameter boundaries, and as such it is assumed that each created
        range references a shard (or subset) of the full parameter. It is
        easiest to think of each DP rank as operating (i.e., reducing,
        gathering) purely on views into the grad buffer, for all model-to-
        main & main-to-model operations.

        This method creates three ranges:
        - The param's range within the entire grad buffer (i.e., world index).
        - The param's range within the DP rank's local view of the grad buffer.
        - The param's range within itself (i.e., its shard).

        Args:
            model: model whose `_grad_buffer_param_index_map[dtype]` maps each
                param to its (start, end) indexes in the full grad buffer.
            dtype: key selecting which grad buffer to use.
            gbuf_world_range: Range of the grad buffer owned by this DP rank.

        Returns:
            dict mapping each param that overlaps `gbuf_world_range` to a dict
            of Ranges keyed by "gbuf_world", "gbuf_local", and "param".
            Params with no overlap are omitted.
        """

        # Param range map.
        param_world_index_map = model._grad_buffer_param_index_map[dtype]
        param_range_map = {}
        for param, param_world_indexes in param_world_index_map.items():

            # Param range, clipped to this rank's shard and expressed relative
            # to the start of that shard.
            param_world_start, param_world_end = param_world_indexes
            param_local_start = max(
                0,
                param_world_start - gbuf_world_range.start)
            param_local_end = min(
                gbuf_world_range.size,
                param_world_end - gbuf_world_range.start)

            # Add param, if within local gbuf range.
            if param_local_end > param_local_start:
                param_local_range = Range(param_local_start, param_local_end)
                param_world_range = param_local_range.normalize(
                    param_local_start + gbuf_world_range.start)
                # Offset of the owned slice within the param's own elements.
                sub_param_start = max(0, gbuf_world_range.start - param_world_start)
                sub_param_range = param_local_range.normalize(sub_param_start)
                param_range_map[param] = {
                    "gbuf_world" : param_world_range,
                    "gbuf_local" : param_local_range,
                    "param" : sub_param_range,
                }

        return param_range_map

    @classmethod
    def build_model_gbuf_range(cls, model, dtype):
        """
        Build mapping between params and their grad buffers.

        This method does the initial setup for the method above. This setup
        includes determining the shard ranges into the DDP's grad buffer for
        each data-parallel (DP) rank. Each DP rank keeps range info for
        all other DP ranks, for the purpose of creating args for
        reduce-scatter and all-gather.

        Args:
            model: model whose `_grad_buffers[dtype]` exposes `numel`.
            dtype: key selecting which grad buffer to partition.

        Returns:
            dict with keys:
            - "local": this rank's shard Range, normalized to start at 0.
            - "world": this rank's shard Range within the full grad buffer.
            - "world_all": list of every DP rank's world Range, by rank.
            - "param_map": output of build_model_gbuf_param_range_map().
            - "max_range_size": ceil(numel / DP world size).
        """

        data_parallel_rank = mpu.get_data_parallel_rank()
        data_parallel_world_size = mpu.get_data_parallel_world_size()

        # Grad buffer range.
        grad_buffer = model._grad_buffers[dtype]
        gbuf_size = grad_buffer.numel
        # Integer ceil-division: exact for arbitrarily large buffers, unlike
        # math.ceil() applied to a float quotient.
        max_gbuf_range_size = -(-gbuf_size // data_parallel_world_size)

        # All world ranges. (i.e., across all data parallel ranks)
        gbuf_world_all_ranges = []
        for r in range(data_parallel_world_size):
            gbuf_world_start = r * max_gbuf_range_size
            gbuf_world_end = min(gbuf_size,
                                 gbuf_world_start + max_gbuf_range_size)
            gbuf_world_all_ranges.append(
                Range(gbuf_world_start, gbuf_world_end))

        # Local DP's ranges.
        gbuf_world_range = gbuf_world_all_ranges[data_parallel_rank]
        gbuf_local_range = gbuf_world_range.normalize()

        # Get each param's ranges.
        param_range_map = cls.build_model_gbuf_param_range_map(model,
                                                               dtype,
                                                               gbuf_world_range)

        # Group into dict.
        return {
            "local" : gbuf_local_range,
            "world" : gbuf_world_range,
            "world_all" : gbuf_world_all_ranges,
            "param_map" : param_range_map,
            "max_range_size" : max_gbuf_range_size,
        }

    @classmethod
    def build_model_gbuf_range_map(cls, model):
        """
        Create param-to-grad-buffer mappings, for grad buffer data types
        within a specific virtual model.

        Returns:
            dict mapping each dtype key of `model._grad_buffers` to the
            range info built by build_model_gbuf_range() for that buffer.
        """
        return {
            dtype : cls.build_model_gbuf_range(model, dtype)
            for dtype in model._grad_buffers
        }

    @classmethod
    def build_model_param_gbuf_map(cls, model_gbuf_ranges):
        """
        Create a reverse of the model_gbuf_ranges, for referencing in
        opposite direction.

        Args:
            model_gbuf_ranges: list (one entry per virtual model) of dicts
                mapping dtype -> grad buffer range info with a "param_map".

        Returns:
            dict mapping each param owned by this DP rank to the tuple
            (model_index, dtype) locating its grad buffer range info.
        """
        param_gbuf_map = {}
        for model_index, model_gbuf_range_map in enumerate(model_gbuf_ranges):
            for dtype, gbuf_range_map in model_gbuf_range_map.items():
                # Only the param keys are needed, not their range dicts.
                for param in gbuf_range_map["param_map"]:
                    param_gbuf_map[param] = (model_index, dtype)
        return param_gbuf_map

    @classmethod
    def build_optimizer_group_ranges(cls, param_groups, model_gbuf_ranges):
        """
        Create optimizer groups.

        Given the set of parameter shard ranges that are owned by the current
        data-parallel (DP) rank, gather the set of parameters that will be
        used (in the method below) to create the current DP's optimizer
        groups.

        Args:
            param_groups: the inner optimizer's param groups (dicts with a
                "params" list).
            model_gbuf_ranges: list of per-model dtype -> range-info dicts.

        Returns:
            list of dicts, one per original group that has at least one
            param owned by this rank, each with "params" (owned model params)
            and "orig_group" (the original param-group dict).
        """

        # Param group map.
        param_group_map = {}
        for group_index, group in enumerate(param_groups):
            for param in group["params"]:
                assert param.requires_grad
                param_group_map[param] = group_index

        # Optimizer group ranges.
        group_ranges = [{"params": []} for _ in param_groups]
        for model_gbuf_range_map in model_gbuf_ranges:
            for gbuf_range_map in model_gbuf_range_map.values():
                for param in gbuf_range_map["param_map"]:
                    group_index = param_group_map[param]
                    group_ranges[group_index]["params"].append(param)

        # Squeeze zero-size group ranges (after linking each range back to
        # its original group, so the link survives the filtering).
        for group_range, orig_group in zip(group_ranges, param_groups):
            group_range["orig_group"] = orig_group
        group_ranges = [g for g in group_ranges if len(g["params"]) > 0]

        return group_ranges

    @classmethod
    def build_model_and_main_param_groups(cls,
                                          model_gbuf_ranges,
                                          param_gbuf_map,
                                          opt_group_ranges):
        """
        Create main parameter groups needed for the optimizer step.

        These groups encompass both: 1) groups used by this class, for
        reducing/gather, and 2) groups used by the inner optimizer for the
        parameter update. Given that the conceptual grad buffer partitioning
        (created in earlier method) doesn't respect parameter boundaries,
        the optimizer operates on shards of the model parameters, rather than
        the full parameters.

        Side effect: each `group_range["orig_group"]["params"]` is replaced
        in place with the fp32 shards (fp32 shards first, then fp32 copies
        of float16 shards).

        Returns:
            tuple (model_float16_groups, model_fp32_groups,
            shard_float16_groups, shard_fp32_groups,
            shard_fp32_from_float16_groups), each a list with one list of
            tensors per entry of `opt_group_ranges`.

        Raises:
            TypeError: if a param is not a torch.cuda Float/Half/BFloat16
                tensor.
        """

        # Parameter groups:
        #   model_float16_groups: original float16 parameters
        #   model_fp32_groups: original fp32 parameters
        #   shard_float16_groups: shards of original float16 parameters
        #   shard_fp32_groups: shards of original fp32 parameters
        #   shard_fp32_from_float16_groups: fp32 copy of float16 parameters
        model_float16_groups = []
        model_fp32_groups = []
        shard_float16_groups = []
        shard_fp32_groups = []
        shard_fp32_from_float16_groups = []

        # Allocate (or slice) each group's param shard.
        for group_range in opt_group_ranges:

            # Params of this group.
            model_float16_params_this_group = []
            model_fp32_params_this_group = []
            shard_float16_params_this_group = []
            shard_fp32_params_this_group = []
            shard_fp32_from_float16_params_this_group = []
            model_float16_groups.append(model_float16_params_this_group)
            model_fp32_groups.append(model_fp32_params_this_group)
            shard_float16_groups.append(shard_float16_params_this_group)
            shard_fp32_groups.append(shard_fp32_params_this_group)
            shard_fp32_from_float16_groups.append(
                shard_fp32_from_float16_params_this_group)

            for model_param in group_range["params"]:

                assert model_param.requires_grad

                model_index, dtype = param_gbuf_map[model_param]
                gbuf_range = model_gbuf_ranges[model_index][dtype]
                param_range = gbuf_range["param_map"][model_param]["param"]
                param_type = model_param.type()

                # fp16, bf16 params.
                if param_type in ['torch.cuda.HalfTensor',
                                  'torch.cuda.BFloat16Tensor']:

                    # Clone model -> main. The model shard is a view of the
                    # flattened param; the main shard is a new fp32 copy.
                    shard_model_param = model_param.detach().view(-1) \
                        [param_range.start:param_range.end]
                    shard_main_param = shard_model_param.clone().float()
                    tensor_parallel.copy_tensor_model_parallel_attributes(
                        shard_model_param, model_param)
                    tensor_parallel.copy_tensor_model_parallel_attributes(
                        shard_main_param, model_param)
                    if hasattr(model_param, 'shared'):
                        shard_model_param.shared = model_param.shared
                        shard_main_param.shared = model_param.shared

                    # Add to group.
                    model_float16_params_this_group.append(model_param)
                    shard_float16_params_this_group.append(shard_model_param)
                    shard_fp32_from_float16_params_this_group.append(shard_main_param)

                # fp32 params.
                elif param_type == 'torch.cuda.FloatTensor':
                    # fp32 params are updated directly through a view; no copy.
                    shard_model_param = model_param.view(-1) \
                        [param_range.start:param_range.end]
                    model_fp32_params_this_group.append(model_param)
                    shard_fp32_params_this_group.append(shard_model_param)
                    tensor_parallel.copy_tensor_model_parallel_attributes(
                        shard_model_param, model_param)
                    if hasattr(model_param, 'shared'):
                        shard_model_param.shared = model_param.shared

                else:
                    # Report the offending param's type (previously referenced
                    # an undefined name, raising NameError instead).
                    raise TypeError('Wrapped parameters must be one of '
                                    'torch.cuda.FloatTensor,  '
                                    'torch.cuda.HalfTensor, or '
                                    'torch.cuda.BFloat16Tensor. '
                                    'Received {}'.format(param_type))

            # Update optimizer's params.
            group_range["orig_group"]["params"] = [
                *shard_fp32_params_this_group,
                *shard_fp32_from_float16_params_this_group,
            ]

        return (
            model_float16_groups,
            model_fp32_groups,
            shard_float16_groups,
            shard_fp32_groups,
            shard_fp32_from_float16_groups,
        )

    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                 fp16, bf16, grad_scaler, models):
        """
        See top of class definition for argument descriptions.

        The steps in this method create the core mapping between DDP grad
        buffers, parameters, and parameter shard ranges, that is needed for
        converting between model param indexes and main parameter shard
        indexes. This method also updates the optimizer parameter groups
        with the newly created shards.
        """

        super().__init__(
            optimizer, clip_grad, log_num_zeros_in_grad,
            params_have_main_grad, use_contiguous_buffers_in_local_ddp,
            fp16, bf16, grad_scaler, models)

        # Verify that contiguous buffers are being used.
        # - Note: this should already be checked in arguments.py.
        assert use_contiguous_buffers_in_local_ddp

        # Model grad buffer ranges (one dtype -> range-info dict per model).
        self.model_gbuf_ranges = [
            self.build_model_gbuf_range_map(model) for model in self.models
        ]
        self.model_param_gbuf_map = \
            self.build_model_param_gbuf_map(self.model_gbuf_ranges)

        # Optimizer ranges.
        self.opt_group_ranges = self.build_optimizer_group_ranges(
            self.optimizer.param_groups,
            self.model_gbuf_ranges)

        # Allocate main param shards.
        (
            self.model_float16_groups,
            self.model_fp32_groups,
            self.shard_float16_groups,
            self.shard_fp32_groups,
            self.shard_fp32_from_float16_groups,
        ) = self.build_model_and_main_param_groups(self.model_gbuf_ranges,
                                                   self.model_param_gbuf_map,
                                                   self.opt_group_ranges)

        # Update optimizer groups.
        # - Also, leverage state_dict() and load_state_dict() to
        #   recast preexisting per-param state tensors.
        self.optimizer.param_groups = \
            [ g["orig_group"] for g in self.opt_group_ranges ]
        self.optimizer.load_state_dict(self.optimizer.state_dict())

    def get_model_param_range_map(self, param):
        """
        Given a model param, get the index sub-range of the param that this
        data-parallel rank owns.

        Returns:
            dict of Ranges keyed by "gbuf_world", "gbuf_local", and "param".
            Raises KeyError if `param` is not owned by this rank.
        """
        model_index, dtype = self.model_param_gbuf_map[param]
        gbuf_range_map = self.model_gbuf_ranges[model_index][dtype]
        param_range_map = gbuf_range_map["param_map"][param]
        return param_range_map

    def get_model_parallel_group(self):
        """
        With the distributed optimizer, the model parallel group is the
        entire world.

        Returns None; presumably consumed by torch.distributed collectives,
        where `group=None` selects the default (world) process group.
        """
        return None

    def state_dict(self):
        """
        The state dict must contain the fp32-from-float16 shards.

        Returns:
            dict with 'optimizer' (inner optimizer state), 'grad_scaler'
            (only when a grad scaler is set), and
            'shard_fp32_from_float16_groups' (references to the live shard
            tensors, not copies).
        """
        state_dict = {}
        state_dict['optimizer'] = self.optimizer.state_dict()
        if self.grad_scaler:
            state_dict['grad_scaler'] = self.grad_scaler.state_dict()
        state_dict['shard_fp32_from_float16_groups'] = \
            self.shard_fp32_from_float16_groups
        return state_dict

    def load_state_dict(self, state_dict):
        """
        Load the state dict.

        Accepts the optimizer state under 'optimizer' or, for old
        checkpoints, 'optimizer_state_dict'. The grad scaler state is loaded
        only when both the checkpoint and this optimizer have one. The saved
        fp32-from-float16 shards are copied in place into the current shards.
        """

        # Optimizer.
        optimizer_key = 'optimizer'
        if optimizer_key not in state_dict:
            optimizer_key = 'optimizer_state_dict'
            print_rank_0('***WARNING*** loading optimizer from '
                         'an old checkpoint ...')
        self.optimizer.load_state_dict(state_dict[optimizer_key])

        # Grad scaler.
        if 'grad_scaler' not in state_dict:
            if self.fp16:
                print_rank_0('***WARNING*** found an old checkpoint, will not '
                             'load grad scaler ...')
        else:
            if self.grad_scaler:
                self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
            else:
                print_rank_0('***WARNING*** found the grad scaler in the '
                             'checkpoint but it is None in the class. '
                             'Skipping loading grad scaler ...')

        # Copy data for the main params.
        for current_group, saved_group in zip(
                self.shard_fp32_from_float16_groups,
                state_dict["shard_fp32_from_float16_groups"]):
            for current_param, saved_param in zip(current_group, saved_group):
                current_param.data.copy_(saved_param.data)

    def zero_grad(self, set_to_none=True):
        """
        Zero grads.

        We only need to zero the model related parameters, i.e.,
        model_float16_groups & model_fp32_groups. We additionally zero
        the remaining groups as a memory optimization to reduce
        fragmentation; in the case of set_to_none==True, the space
        used by this field can be safely deallocated at this point.

        Args:
            set_to_none: forwarded to _zero_grad_group_helper() for every
                group.
        """
        for groups in (
                self.model_float16_groups,
                self.model_fp32_groups,
                self.shard_float16_groups, # grad empty/unused here?
                self.shard_fp32_groups, # throws grad-access warning
                self.shard_fp32_from_float16_groups):
            for group in groups:
                _zero_grad_group_helper(group, set_to_none)

    def get_model_grad_buffer_dp_views(self):
        """
        Get shard views of each of the DDP's grad buffers.

        In this nested list, the top level is grouped by the virtual model
        index and the grad buffer's data type. The sub-level is a list of
        shards of that grad buffer, where each shard in the list represents
        a contiguous view of the grad buffer, that is owned by a data-parallel
        rank. The shard boundary does not respect parameter boundaries, and
        so the elements of some parameters are split across data parallel
        ranks.

        Additionally, return references to the entire grad buffers, for use
        in _reduce_scatter_base and _all_gather_base.

        Returns:
            list of tuples (model_index, dtype, gbuf.data, gbuf_views), where
            gbuf_views[r] is the slice of gbuf.data owned by DP rank r.
        """

        data_parallel_world_size = mpu.get_data_parallel_world_size()

        # Grad buffer views.
        gbuf_view_items = []
        for model_index, model in enumerate(self.models):
            for dtype, gbuf in model._grad_buffers.items():

                assert gbuf.numel_padded % data_parallel_world_size == 0
                # Exact integer division (divisibility asserted above).
                shard_size = gbuf.numel_padded // data_parallel_world_size
                gbuf_views = [gbuf.data[(r*shard_size):((r+1)*shard_size)]
                              for r in range(data_parallel_world_size)]
                gbuf_view_items.append((model_index, dtype, gbuf.data, gbuf_views))

        return gbuf_view_items

    def reduce_model_grads(self, args, timers):
        """
        Reduce-scatter model grads.

        The DDP's grad buffer is used for the reduce-scatter, and thus no
        tensors are dynamically allocated.

        Note: this is a different order of reduction, versus the non-
        distributed optimizer, which reduces: 1) layernorm grads, 2) all
        grads, 3) embedding grads.

        Args:
            args: parsed arguments; `barrier_with_L1_time` is read here and
                `args` is forwarded to the layernorm/embedding all-reduces.
            timers: timer factory used to time each phase.
        """

        # All-reduce layer-norm grads (for sequence parallelism).
        timers('layernorm-grads-all-reduce', log_level=1).start(
            barrier=args.barrier_with_L1_time)
        self.allreduce_layernorm_grads(args)
        timers('layernorm-grads-all-reduce').stop()

        # All-reduce embedding grads.
        timers('embedding-grads-all-reduce', log_level=1).start(
            barrier=args.barrier_with_L1_time)
        self.allreduce_embedding_grads(args)
        timers('embedding-grads-all-reduce').stop()

        # Reduce-scatter setup.
        timers('grads-reduce-scatter', log_level=1).start(
            barrier=args.barrier_with_L1_time)
        data_parallel_rank = mpu.get_data_parallel_rank()
        data_parallel_world_size = mpu.get_data_parallel_world_size()
        data_parallel_group = mpu.get_data_parallel_group()

        # Scale grad buffers by '1 / data_parallel_world_size', so the
        # reduce-scatter below (which sums by default) yields an average.
        for model in self.models:
            for gbuf in model._grad_buffers.values():
                gbuf.data /= data_parallel_world_size

        # Reduce-scatter all grads: each rank receives its owned shard.
        gbuf_view_items = self.get_model_grad_buffer_dp_views()
        for _, _, gbuf, gbuf_views in gbuf_view_items:
            torch.distributed._reduce_scatter_base(
                gbuf_views[data_parallel_rank],
                gbuf,
                group = data_parallel_group,
            )

        timers('grads-reduce-scatter').stop()

    def gather_model_params(self, args, timers):
        """
        All-gather updated model params.

        The DDP's grad buffer is used for the all-gather, and thus no
        tensors are dynamically allocated. After the all-gather, the params
        can be copied from param.main_grad to param.

        Args:
            args: parsed arguments; only `barrier_with_L1_time` is read here.
            timers: timer factory used to time the all-gather.
        """

        timers('params-all-gather', log_level=1).start(
            barrier=args.barrier_with_L1_time)

        data_parallel_rank = mpu.get_data_parallel_rank()
        data_parallel_group = mpu.get_data_parallel_group()

        # All-gather updated main params.
        # - All grad buffer views are guaranteed to have the same num elements
        #   across all data parallel ranks, with grad buffer padding that is done
        #   in distributed.py. Thus, all sub-views will have consistent start/end
        #   indexes across data parallel ranks.
        gbuf_view_items = self.get_model_grad_buffer_dp_views()
        for _, _, gbuf, gbuf_views in gbuf_view_items:
            torch.distributed._all_gather_base(
                gbuf,
                gbuf_views[data_parallel_rank],
                group = data_parallel_group,
            )

        # Each model param now contains its updated values in its
        # '.main_grad' field (presumably a view into the grad buffer that
        # was just gathered).
        for model in self.models:
            for param_map in model._grad_buffer_param_index_map.values():
                for param in param_map:
                    param.detach().copy_(param.main_grad)

        timers('params-all-gather').stop()

    def _collect_main_grad_data_for_unscaling(self):
        """
        Collect the `.grad.data` of every param in the inner optimizer's
        param groups (i.e., the main param shards).

        Note: this should be equivalent to the float-16 optimizer's method,
        but written differently, so the two should be combined.
        """
        return [
            param.grad.data
            for group in self.optimizer.param_groups
            for param in group["params"]
        ]

    def _get_model_and_main_params_data_float16(self):
        """
        Get aligned list of model and main params.

        Returns:
            (model_data, main_data): parallel lists of `.data` tensors, where
            model_data[i] is a float16 model-param shard and main_data[i] is
            its fp32 main copy.
        """
        model_data = []
        main_data = []
        for model_group, main_group in zip(self.shard_float16_groups,
                                           self.shard_fp32_from_float16_groups):
            for model_param, main_param in zip(model_group, main_group):
                model_data.append(model_param.data)
                main_data.append(main_param.data)
        return model_data, main_data

    def _copy_model_grads_to_main_grads(self):
        """
        Copy model grads to main grads.

        Since this step follows a reduce-scatter through the DDP's grad
        buffer, this method is responsible for copying the updated grads
        from the grad buffer to the main shard's grad field.
        """

        # Utility method for copying group grads.
        def copy_group_grads(model_groups, shard_main_groups):
            for model_group, shard_main_group in zip(model_groups,
                                                     shard_main_groups):
                for model_param, shard_main_param in zip(model_group,
                                                         shard_main_group):

                    param_range_map = self.get_model_param_range_map(model_param)
                    param_range = param_range_map["param"]
                    assert param_range.size == shard_main_param.nelement()

                    # Slice this rank's owned part of the flattened main_grad
                    # and attach an fp32 version as the shard's .grad.
                    model_grad = model_param.main_grad
                    shard_model_grad = model_grad.view(-1) \
                        [param_range.start:param_range.end]
                    shard_main_param.grad = shard_model_grad.float()

        # Copy model groups to shard groups.
        copy_group_grads(self.model_float16_groups,
                         self.shard_fp32_from_float16_groups)
        copy_group_grads(self.model_fp32_groups,
                         self.shard_fp32_groups)

    def _copy_main_params_to_model_params(self):
        """
        Copy main params to model params.

        Since this step is followed by an all-gather through the DDP's grad
        buffer, this method is responsible for copying the updated params
        from the main shards into the correct position in the grad buffer.
        """

        # Utility method for copying group params.
        def copy_group_params(shard_main_groups, model_groups):
            for shard_main_group, model_group in zip(shard_main_groups,
                                                     model_groups):
                for shard_main_param, model_param in zip(shard_main_group,
                                                         model_group):

                    param_range_map = self.get_model_param_range_map(model_param)
                    param_range = param_range_map["param"]
                    assert param_range.size == shard_main_param.nelement()

                    # Write the updated shard into this rank's slice of the
                    # flattened main_grad (not into the param itself); the
                    # subsequent all-gather distributes it.
                    model_grad = model_param.main_grad
                    shard_model_grad = model_grad.view(-1) \
                        [param_range.start:param_range.end]
                    shard_model_grad.data.copy_(shard_main_param)

        # Copy shard groups to model groups.
        copy_group_params(self.shard_fp32_from_float16_groups,
                          self.model_float16_groups)
        copy_group_params(self.shard_fp32_groups,
                          self.model_fp32_groups)