# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Megatron distributed optimizer."""


import math
7
import torch
8
9

from megatron import get_args
10
11
from megatron import get_timers
from megatron import mpu
Lawrence McAfee's avatar
Lawrence McAfee committed
12
from megatron import print_rank_0
13
14
from megatron.model.module import param_is_not_shared
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
15
16

from .optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper
17

Lawrence McAfee's avatar
Lawrence McAfee committed
18

19
class Range:
    """
    A range represents the start and end points for indexing a shard
    from a full tensor.

    The derived `size` attribute is `end - start`; elsewhere in this file
    ranges are applied as `tensor[start:end]` slices.
    """

    def __init__(self, start, end):
        self.start = start
        self.end = end
        self.size = end - start

    def normalize(self, start=0):
        """Return a new Range of the same size that begins at `start`."""
        return Range(start, start + self.size)

    def __str__(self):
        return "%d,%d [%d]" % (self.start, self.end, self.size)


34
class DistributedOptimizer(MixedPrecisionOptimizer):
35
    """Distributed optimizer, for all data types (fp16, bf16, and fp32).

    Arguments:
        optimizer: base optimizer such as Adam or SGD
        clip_grad: clip gradients with this global L2 norm. Note
            that clipping is ignored if clip_grad == 0
        log_num_zeros_in_grad: return number of zeros in the gradients.
        params_have_main_grad: flag indicating if parameters have
            a `main_grad` field. If this is set, we are assuming
            that the gradients of the model parameters are stored in the
            `main_grad` field instead of the typical `grad` field. This
            happens for the DDP cases where there is a contiguous buffer
            holding the gradients. For example for bfloat16, we want
            to do gradient accumulation and all-reduces in float32
            and as a result we store those gradients in the main_grad.
            Note that main grad is not necessarily in float32.
        use_contiguous_buffers_in_local_ddp: if true, the local DDP model
            is using a contiguous buffer to hold the model grads.
        fp16: if true, the model is running in fp16.
        bf16: if true, the model is running in bfloat16.
        grad_scaler: used for scaling gradients. Note that this can be
            None. This case happens when `bf16 = True` and we don't
            use any loss scale. Note that for `bf16 = True`, we can have
            a constant gradient scaler. Also for `bf16 = False`, we
            always require a grad scaler.
        models: list of models (i.e., the virtual pipelining models). This
            is used by the distributed optimizer for mapping parameters.
    """
63
64

    @classmethod
    def build_model_gbuf_param_range_map(cls, model, dtype, gbuf_world_range):
        """
        Build mapping from param reference to grad buffer shard ranges.

        This method builds a mapping from parameter references to grad
        buffer shard ranges, specific to each data-parallel (DP) rank's
        set of 'owned' parameters. Each grad buffer (padded to be an even
        multiple of DP-world-size) is conceptually divided into DP-world-size
        contiguous regions, where each DP rank 'owns' one contiguous region.
        Ownership in this sense means the DP rank is responsible for reducing
        the relevant subset of grads, and updating the relevant subset of
        params.

        This conceptual partitioning of the grad buffer does NOT respect
        parameter boundaries, and as such it is assumed that each created
        range references a shard (or subset) of the full parameter. It is
        easiest to think of each DP rank as operating (i.e., reducing,
        gathering) purely on views into the grad buffer, for all model-to-
        main & main-to-model operations.

        This method creates three ranges per owned param:
        - "gbuf_world": the param's range within the entire grad buffer
          (i.e., world index).
        - "gbuf_local": the param's range within the DP rank's local view
          of the grad buffer.
        - "param": the param's range within itself (i.e., its shard).

        Arguments:
            model: object exposing `_grad_buffer_param_index_map[dtype]`, a
                mapping of param -> (world_start, world_end).
            dtype: key selecting the grad buffer.
            gbuf_world_range: Range of the grad buffer owned by this rank.

        Returns:
            dict mapping each param that overlaps `gbuf_world_range` to a
            dict of the three Range objects described above.
        """

        shard_world_start = gbuf_world_range.start
        shard_size = gbuf_world_range.size

        param_world_index_map = model._grad_buffer_param_index_map[dtype]
        param_range_map = {}
        for param, param_world_indexes in param_world_index_map.items():

            # Clip the param's world range to this rank's shard, expressed
            # relative to the start of the shard.
            param_world_start, param_world_end = param_world_indexes
            param_local_start = max(0, param_world_start - shard_world_start)
            param_local_end = min(shard_size,
                                  param_world_end - shard_world_start)

            # Skip params that do not overlap the local shard.
            if param_local_end <= param_local_start:
                continue

            param_local_range = Range(param_local_start, param_local_end)
            param_world_range = param_local_range.normalize(
                param_local_start + shard_world_start)
            # Offset into the param itself; non-zero only when the param
            # begins before this rank's shard does.
            sub_param_start = max(0, shard_world_start - param_world_start)
            sub_param_range = param_local_range.normalize(sub_param_start)
            param_range_map[param] = {
                "gbuf_world" : param_world_range,
                "gbuf_local" : param_local_range,
                "param" : sub_param_range,
            }

        return param_range_map
119

Lawrence McAfee's avatar
Lawrence McAfee committed
120

121
    @classmethod
    def build_model_gbuf_range(cls, model, dtype):
        """
        Build mapping between params and their grad buffers.

        This method does the initial setup for
        build_model_gbuf_param_range_map(). This setup includes determining
        the shard ranges into the DDP's grad buffer for each data-parallel
        (DP) rank. Each DP rank keeps range info for all other DP ranks,
        for the purpose of creating args for reduce-scatter and all-gather.

        Returns:
            dict with keys "local" (this rank's range, rebased to 0),
            "world" (this rank's range in the full buffer), "world_all"
            (list of every rank's world range), "param_map" (per-param
            ranges for this rank) and "max_range_size" (ceil of buffer
            size over DP world size).
        """

        data_parallel_rank = mpu.get_data_parallel_rank()
        data_parallel_world_size = mpu.get_data_parallel_world_size()

        # Grad buffer range.
        grad_buffer = model._grad_buffers[dtype]
        gbuf_size = grad_buffer.numel
        max_gbuf_range_size = int(math.ceil(gbuf_size / data_parallel_world_size))

        # All world ranges. (i.e., across all data parallel ranks)
        # NOTE(review): when gbuf_size is not a multiple of the world size,
        # a trailing rank's start can exceed gbuf_size, giving a Range with
        # negative size; such a range matches no params downstream.
        gbuf_world_all_ranges = []
        for r in range(data_parallel_world_size):
            gbuf_world_start = r * max_gbuf_range_size
            gbuf_world_end = min(gbuf_size,
                                 gbuf_world_start + max_gbuf_range_size)
            gbuf_world_all_ranges.append(
                Range(gbuf_world_start, gbuf_world_end))

        # Local DP's ranges.
        gbuf_world_range = gbuf_world_all_ranges[data_parallel_rank]
        gbuf_local_range = gbuf_world_range.normalize()

        # Get each param's ranges.
        param_range_map = cls.build_model_gbuf_param_range_map(
            model, dtype, gbuf_world_range)

        # Group into dict.
        return {
            "local" : gbuf_local_range,
            "world" : gbuf_world_range,
            "world_all" : gbuf_world_all_ranges,
            "param_map" : param_range_map,
            "max_range_size" : max_gbuf_range_size,
        }

Lawrence McAfee's avatar
Lawrence McAfee committed
169

170
    @classmethod
    def build_model_gbuf_range_map(cls, model):
        """
        Create param-to-grad-buffer mappings, for grad buffer data types
        within a specific virtual model.

        Returns:
            dict keyed by each dtype in `model._grad_buffers`, whose values
            are the dicts produced by build_model_gbuf_range().
        """
        return {
            dtype : cls.build_model_gbuf_range(model, dtype)
            for dtype in model._grad_buffers
        }

Lawrence McAfee's avatar
Lawrence McAfee committed
181

182
    @classmethod
183
    def build_model_param_gbuf_map(cls, model_gbuf_ranges):
184
185
186
187
        """
        Create a reverse of the model_gbuf_ranges, for referencing in
        opposite direction.
        """
188
        param_gbuf_map = {}
189
190
191
        for model_index, model_gbuf_range_map in enumerate(model_gbuf_ranges):
            for dtype, gbuf_range_map in model_gbuf_range_map.items():
                for param, param_range_map in gbuf_range_map["param_map"].items():
192
193
194
                    param_gbuf_map[param] = (model_index, dtype)
        return param_gbuf_map

Lawrence McAfee's avatar
Lawrence McAfee committed
195

196
    @classmethod
197
    def build_optimizer_group_ranges(cls, param_groups, model_gbuf_ranges):
198
199
200
201
202
203
204
205
        """
        Create optimizer groups.

        Given the set of parameter shard ranges that are owned by the current
        data-parallel (DP) rank, gather the set of parameters that will be
        used (in the method below) to create the current DP's optimizer
        groups.
        """
206
207
208
209
210
211
212
213
214
215

        num_groups = len(param_groups)

        # Param group map.
        param_group_map = {}
        for group_index, group in enumerate(param_groups):
            for param in group["params"]:
                assert param.requires_grad
                param_group_map[param] = group_index

216
217
218
219
220
        # Optimizer group ranges.
        group_ranges = [ {"params": []} for _ in param_groups ]
        for model_gbuf_range_map in model_gbuf_ranges:
            for dtype, gbuf_range_map in model_gbuf_range_map.items():
                for param in gbuf_range_map["param_map"]:
221
                    group_index = param_group_map[param]
222
223
                    group_range = group_ranges[group_index]
                    group_range["params"].append(param)
224

225
226
227
228
        # Squeeze zero-size group ranges.
        for group_index, group_range in enumerate(group_ranges):
            group_range["orig_group"] = param_groups[group_index]
        group_ranges = [ g for g in group_ranges if len(g["params"]) > 0 ]
229

230
        return group_ranges
231

232

233
    @classmethod
234
    def build_model_and_main_param_groups(cls,
235
236
237
238
239
240
241
242
243
244
245
246
247
                                          model_gbuf_ranges,
                                          param_gbuf_map,
                                          opt_group_ranges):
        """
        Create main parameter groups needed for the optimizer step.

        These groups encompass both: 1) groups used by this class, for
        reducing/gather, and 2) groups used by the inner optimizer for the
        parameter update. Given that the conceptual grad buffer partitioning
        (created in earlier method) doesn't respect parameter boundaries,
        the optimizer operates on shards of the model parameters, rather than
        the full parameters.
        """
248

Lawrence McAfee's avatar
Lawrence McAfee committed
249
250
251
252
253
254
255
256
        # Parameter groups:
        #   model_float16_groups: original float16 parameters
        #   model_fp32_groups: original fp32 parameters
        #   shard_float16_groups: shards of original float16 parameters
        #   shard_fp32_groups: shards of original fp32 parameters
        #   shard_fp32_from_float16_groups: fp32 copy of float16 parameters
        model_float16_groups = []
        model_fp32_groups = []
257
258
259
260
        shard_float16_groups = []
        shard_fp32_groups = []
        shard_fp32_from_float16_groups = []

Lawrence McAfee's avatar
Lawrence McAfee committed
261
        # Allocate (or slice) each group's param shard.
262
263
264
        for group_index, group_range in enumerate(opt_group_ranges):

            # Params of this group.
Lawrence McAfee's avatar
Lawrence McAfee committed
265
266
            model_float16_params_this_group = []
            model_fp32_params_this_group = []
267
268
269
            shard_float16_params_this_group = []
            shard_fp32_params_this_group = []
            shard_fp32_from_float16_params_this_group = []
Lawrence McAfee's avatar
Lawrence McAfee committed
270
271
            model_float16_groups.append(model_float16_params_this_group)
            model_fp32_groups.append(model_fp32_params_this_group)
272
273
274
275
276
277
278
            shard_float16_groups.append(shard_float16_params_this_group)
            shard_fp32_groups.append(shard_fp32_params_this_group)
            shard_fp32_from_float16_groups.append(
                shard_fp32_from_float16_params_this_group)

            for model_param in group_range["params"]:

279
280
                assert model_param.requires_grad

281
282
283
                model_index, dtype = param_gbuf_map[model_param]
                gbuf_range = model_gbuf_ranges[model_index][dtype]
                param_range = gbuf_range["param_map"][model_param]["param"]
284
285

                # fp16, bf16 params.
286
287
288
289
                if model_param.type() in ['torch.cuda.HalfTensor',
                                          'torch.cuda.BFloat16Tensor']:

                    # Clone model -> main.
Lawrence McAfee's avatar
Lawrence McAfee committed
290
291
                    shard_model_param = model_param.detach().view(-1) \
                        [param_range.start:param_range.end]
292
293
294
295
296
297
298
299
300
301
                    shard_main_param = shard_model_param.clone().float()
                    mpu.copy_tensor_model_parallel_attributes(
                        shard_model_param, model_param)
                    mpu.copy_tensor_model_parallel_attributes(
                        shard_main_param, model_param)
                    if hasattr(model_param, 'shared'):
                        shard_model_param.shared = model_param.shared
                        shard_main_param.shared = model_param.shared

                    # Add to group.
Lawrence McAfee's avatar
Lawrence McAfee committed
302
                    model_float16_params_this_group.append(model_param)
303
304
                    shard_float16_params_this_group.append(shard_model_param)
                    shard_fp32_from_float16_params_this_group.append(shard_main_param)
305
306

                # fp32 params.
307
                elif model_param.type() == 'torch.cuda.FloatTensor':
Lawrence McAfee's avatar
Lawrence McAfee committed
308
309
                    shard_model_param = model_param.view(-1) \
                        [param_range.start:param_range.end]
Lawrence McAfee's avatar
Lawrence McAfee committed
310
                    model_fp32_params_this_group.append(model_param)
311
                    shard_fp32_params_this_group.append(shard_model_param)
312
313
                    mpu.copy_tensor_model_parallel_attributes(
                        shard_model_param, model_param)
314
315
                    if hasattr(model_param, 'shared'):
                        shard_model_param.shared = model_param.shared
316
317
318
319
320
321
322
323

                else:
                    raise TypeError('Wrapped parameters must be one of '
                                    'torch.cuda.FloatTensor,  '
                                    'torch.cuda.HalfTensor, or '
                                    'torch.cuda.BFloat16Tensor. '
                                    'Received {}'.format(param.type()))

Lawrence McAfee's avatar
Lawrence McAfee committed
324
            # Update optimizer's params.
325
326
327
328
329
330
            group_range["orig_group"]["params"] = [
                *shard_fp32_params_this_group,
                *shard_fp32_from_float16_params_this_group,
            ]

        return (
Lawrence McAfee's avatar
Lawrence McAfee committed
331
332
            model_float16_groups,
            model_fp32_groups,
333
334
335
336
            shard_float16_groups,
            shard_fp32_groups,
            shard_fp32_from_float16_groups,
        )
337

Lawrence McAfee's avatar
Lawrence McAfee committed
338

339
340
    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                 fp16, bf16, params_dtype, grad_scaler, models):
        """
        See top of class definition for argument descriptions.

        The steps in this method create the core mapping between DDP grad
        buffers, parameters, and parameter shard ranges, that is needed for
        converting between model param indexes and main parameter shard
        indexes. This method also updates the optimizer parameter groups
        with the newly created shards.
        """

        super().__init__(
            optimizer, clip_grad, log_num_zeros_in_grad,
            params_have_main_grad, use_contiguous_buffers_in_local_ddp,
            fp16, bf16, params_dtype, grad_scaler, models)

        # Verify that contiguous buffers are being used.
        # - Note: this should already be checked in arguments.py.
        assert use_contiguous_buffers_in_local_ddp

        # Model grad buffer ranges (one entry per model), plus the reverse
        # param -> (model_index, dtype) lookup.
        self.model_gbuf_ranges = [
            self.build_model_gbuf_range_map(model) for model in self.models
        ]
        self.model_param_gbuf_map = \
            self.build_model_param_gbuf_map(self.model_gbuf_ranges)

        # Optimizer ranges.
        self.opt_group_ranges = self.build_optimizer_group_ranges(
            self.optimizer.param_groups,
            self.model_gbuf_ranges)

        # Allocate main param shards.
        (
            self.model_float16_groups,
            self.model_fp32_groups,
            self.shard_float16_groups,
            self.shard_fp32_groups,
            self.shard_fp32_from_float16_groups,
        ) = self.build_model_and_main_param_groups(self.model_gbuf_ranges,
                                                   self.model_param_gbuf_map,
                                                   self.opt_group_ranges)

        # Initialize param buffers.
        # - These are views on the DDP model's grad buffers, that share
        #   storage & have their own dtype. This is safe because the param
        #   dtype size is always <= grad dtype size.
        self.param_buffers = []
        for model in self.models:
            current_param_buffers = {}
            for dtype, grad_buffer in model._grad_buffers.items():
                param_buffer = torch.tensor(grad_buffer.data.storage()._untyped(),
                                            dtype = params_dtype,
                                            device = grad_buffer.data.device)
                param_buffer = param_buffer[:grad_buffer.numel_padded]
                current_param_buffers[dtype] = param_buffer
            self.param_buffers.append(current_param_buffers)

        # Update optimizer groups.
        # - Also, leverage state_dict() and load_state_dict() to
        #   recast preexisting per-param state tensors.
        self.optimizer.param_groups = \
            [ g["orig_group"] for g in self.opt_group_ranges ]
        self.optimizer.load_state_dict(self.optimizer.state_dict())

406

407
    def get_model_param_range_map(self, param):
408
        """
Lawrence McAfee's avatar
Lawrence McAfee committed
409
410
        Given a model param, get the index sub-range of the param that this
        data-parallel rank owns.
411
        """
412
413
414
415
416
417
        model_index, dtype = self.model_param_gbuf_map[param]
        gbuf_range_map = self.model_gbuf_ranges[model_index][dtype]
        param_range_map = gbuf_range_map["param_map"][param]
        return param_range_map


418
    def get_model_parallel_group(self):
        """
        With the distributed optimizer, the model parallel group is the
        entire world.

        Always returns None.
        NOTE(review): presumably callers treat None as the default (world)
        process group - confirm against the base class usage.
        """
        return None

425
426

    def state_dict(self):
427
428
429
        """
        The state dict must contain the fp32-from-float16 shards.
        """
430
431
432
433
434
435
436
437
438
        state_dict = {}
        state_dict['optimizer'] = self.optimizer.state_dict()
        if self.grad_scaler:
            state_dict['grad_scaler'] = self.grad_scaler.state_dict()
        state_dict['shard_fp32_from_float16_groups'] = \
            self.shard_fp32_from_float16_groups
        return state_dict


439
    def load_state_dict(self, state_dict):
440
441
442
        """
        Load the state dict.
        """
443
444
445
446
447
448
449
450
451
452
453

        # Optimizer.
        optimizer_key = 'optimizer'
        if optimizer_key not in state_dict:
            optimizer_key = 'optimizer_state_dict'
            print_rank_0('***WARNING*** loading optimizer from '
                         'an old checkpoint ...')
        self.optimizer.load_state_dict(state_dict[optimizer_key])

        # Grad scaler.
        if 'grad_scaler' not in state_dict:
454
455
456
            if self.fp16:
                print_rank_0('***WARNING*** found an old checkpoint, will not '
                             'load grad scaler ...')
457
458
459
460
461
462
463
464
465
466
467
468
469
470
        else:
            if self.grad_scaler:
                self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
            else:
                print_rank_0('***WARNING*** fould the grad scaler in the '
                             'checkpoint but it is None in the class. '
                             'Skipping loading grad scaler ...')

        # Copy data for the main params.
        for current_group, saved_group in zip(
                self.shard_fp32_from_float16_groups,
                state_dict["shard_fp32_from_float16_groups"]):
            for current_param, saved_param in zip(current_group, saved_group):
                current_param.data.copy_(saved_param.data)
Lawrence McAfee's avatar
Lawrence McAfee committed
471

472

473
    def zero_grad(self, set_to_none=True):
        """
        Zero grads.

        We only need to zero the model related parameters, i.e.,
        model_float16_groups & model_fp32_groups. We additionally zero
        the remaining groups as a memory optimization to reduce
        fragmentation; in the case of set_to_none==True, the space
        used by this field can be safely deallocated at this point.
        """
        for groups in (
                self.model_float16_groups,
                self.model_fp32_groups,
                self.shard_float16_groups, # grad empty/unused here?
                self.shard_fp32_groups, # throws grad-access warning
                self.shard_fp32_from_float16_groups):
            for group in groups:
                _zero_grad_group_helper(group, set_to_none)
491

492

493
494
    @staticmethod
    def get_model_buffer_dp_views(model_buffers):
        """
        Get shard views of each of the DDP's param/grad buffers.

        In this nested list, the top level is grouped by the virtual model
        index and the buffer's data type. The sub-level is a list of
        shards of that buffer, where each shard in the list represents
        a contiguous view of the buffer, that is owned by a data-parallel
        rank. The shard boundary does not respect parameter boundaries, and
        so the elements of some parameters are split across data parallel
        ranks.

        Additionally, return references to the entire buffers, for use
        in _reduce_scatter_base and _all_gather_base.

        Arguments:
            model_buffers: list (indexed by model) of dicts mapping
                dtype -> buffer tensor.

        Returns:
            list of (model_index, dtype, buf, buf_views) tuples, where
            buf_views holds one equally sized slice of buf per DP rank.
        """

        data_parallel_world_size = mpu.get_data_parallel_world_size()

        # Buffer views.
        view_items = []
        for model_index, buffers in enumerate(model_buffers):
            for dtype, buf in buffers.items():

                assert buf.numel() % data_parallel_world_size == 0
                # Exact integer division; divisibility is asserted above.
                shard_size = buf.numel() // data_parallel_world_size
                buf_views = [buf[(r*shard_size):((r+1)*shard_size)]
                             for r in range(data_parallel_world_size)]
                view_items.append((model_index, dtype, buf, buf_views))

        return view_items
524

525
526
527
528
529
530
531
532
533
534

    def get_model_grad_buffer_dp_views(self):
        return self.get_model_buffer_dp_views([
            {dtype : mem_buffer.data}
            for model in self.models
            for dtype, mem_buffer in model._grad_buffers.items()])


    def get_model_param_buffer_dp_views(self):
        return self.get_model_buffer_dp_views(self.param_buffers)
535

Lawrence McAfee's avatar
Lawrence McAfee committed
536

537
    def reduce_model_grads(self, args, timers):
        """
        Reduce-scatter model grads.

        The DDP's grad buffer is used for the reduce-scatter, and thus no
        tensors are dynamically allocated.

        Note: this is a different order of reduction, versus the non-
        distributed optimizer, which reduces: 1) layernorm grads, 2) all
        grads, 3) embedding grads.
        """

        # All-reduce layer-norm grads (for sequence parallelism).
        timers('layernorm-grads-all-reduce', log_level=1).start(
            barrier=args.barrier_with_L1_time)
        self.allreduce_layernorm_grads(args)
        timers('layernorm-grads-all-reduce').stop()

        # All-reduce embedding grads.
        timers('embedding-grads-all-reduce', log_level=1).start(
            barrier=args.barrier_with_L1_time)
        self.allreduce_embedding_grads(args)
        timers('embedding-grads-all-reduce').stop()

        # Reduce-scatter setup.
        timers('grads-reduce-scatter', log_level=1).start(
            barrier=args.barrier_with_L1_time)
        data_parallel_rank = mpu.get_data_parallel_rank()
        data_parallel_world_size = mpu.get_data_parallel_world_size()
        data_parallel_group = mpu.get_data_parallel_group()

        # Scale grad buffers by '1 / data_parallel_world_size', in place,
        # before the reduce-scatter is issued.
        for model in self.models:
            for gbuf in model._grad_buffers.values():
                gbuf.data /= data_parallel_world_size

        # Reduce-scatter all grads.
        # - The output is this rank's own view into the same buffer that
        #   is passed as input.
        gbuf_view_items = self.get_model_grad_buffer_dp_views()
        for _model_index, _dtype, gbuf, gbuf_views in gbuf_view_items:
            torch.distributed._reduce_scatter_base(
                gbuf_views[data_parallel_rank],
                gbuf,
                group = data_parallel_group,
            )

        timers('grads-reduce-scatter').stop()
585

Lawrence McAfee's avatar
Lawrence McAfee committed
586

587
    def gather_model_params(self, args, timers):
        """
        All-gather updated model params.

        The DDP's param buffer is used for the all-gather, and thus no
        tensors are dynamically allocated. After the all-gather, the params
        can be copied from the param buffer to the param.
        """

        timers('params-all-gather', log_level=1).start(
            barrier=args.barrier_with_L1_time)

        data_parallel_rank = mpu.get_data_parallel_rank()
        data_parallel_group = mpu.get_data_parallel_group()

        # All-gather updated main params.
        # - All param buffer views are guaranteed to have the same num elements
        #   across all data parallel ranks, due to grad buffer padding that is
        #   done in distributed.py, and extended to the param buffers. Thus,
        #   all sub-views will have consistent start/end indexes across data
        #   parallel ranks.
        pbuf_view_items = self.get_model_param_buffer_dp_views()
        for _model_index, _dtype, pbuf, pbuf_views in pbuf_view_items:
            torch.distributed._all_gather_base(
                pbuf,
                pbuf_views[data_parallel_rank],
                group = data_parallel_group,
            )

        # Copy from param buffer to each param.
        # - buf_range is the param's (start, end) within the full buffer.
        for model_id, model in enumerate(self.models):
            for dtype, param_map in model._grad_buffer_param_index_map.items():
                param_buf = self.param_buffers[model_id][dtype]
                for param, buf_range in param_map.items():
                    param_buf_shard = param_buf[buf_range[0]:buf_range[1]]
                    param.view(-1).detach().copy_(param_buf_shard)

        timers('params-all-gather').stop()
627

Lawrence McAfee's avatar
Lawrence McAfee committed
628

629
    def _collect_main_grad_data_for_unscaling(self):
630
631
632
633
        """
        Note: this should be equivalent to the float-16 optimizer's method,
        but writtent differently, so the two should be combined.
        """
Lawrence McAfee's avatar
Lawrence McAfee committed
634
        return [
Lawrence McAfee's avatar
Lawrence McAfee committed
635
636
637
638
            param.grad.data
            for group in self.optimizer.param_groups
            for param in group["params"]
        ]
639
640


Lawrence McAfee's avatar
Lawrence McAfee committed
641
    def _get_model_and_main_params_data_float16(self):
642
643
644
        """
        Get aligned list of model and main params.
        """
Lawrence McAfee's avatar
Lawrence McAfee committed
645
646
647
648
649
650
651
652
        model_data = []
        main_data = []
        for model_group, main_group in zip(self.shard_float16_groups,
                                           self.shard_fp32_from_float16_groups):
            for model_param, main_param in zip(model_group, main_group):
                model_data.append(model_param.data)
                main_data.append(main_param.data)
        return model_data, main_data
653
654


655
    def _copy_model_grads_to_main_grads(self):
656
657
658
659
660
661
662
        """
        Copy model grads to main grads.

        Since this step follows a reduce-scatter through the DDP's grad
        buffer, this method is responsible for copying the updated grads
        from the grad buffer to the main shard's grad field.
        """
663

664
        # Utility method for copying group grads.
Lawrence McAfee's avatar
Lawrence McAfee committed
665
666
667
668
669
        def copy_group_grads(model_groups, shard_main_groups):
            for model_group, shard_main_group in zip(model_groups,
                                                     shard_main_groups):
                for model_param, shard_main_param in zip(model_group,
                                                         shard_main_group):
Lawrence McAfee's avatar
Lawrence McAfee committed
670

Lawrence McAfee's avatar
Lawrence McAfee committed
671
                    param_range_map = self.get_model_param_range_map(model_param)
Lawrence McAfee's avatar
Lawrence McAfee committed
672
                    param_range = param_range_map["param"]
Lawrence McAfee's avatar
Lawrence McAfee committed
673
674
                    assert param_range.size == shard_main_param.nelement()

Lawrence McAfee's avatar
Lawrence McAfee committed
675
676
                    model_grad = model_param.main_grad
                    shard_model_grad = model_grad.view(-1) \
Lawrence McAfee's avatar
Lawrence McAfee committed
677
                        [param_range.start:param_range.end]
Lawrence McAfee's avatar
Lawrence McAfee committed
678
679
                    shard_main_param.grad = shard_model_grad.float()

680
        # Copy model groups to shard groups.
Lawrence McAfee's avatar
Lawrence McAfee committed
681
        copy_group_grads(self.model_float16_groups,
Lawrence McAfee's avatar
Lawrence McAfee committed
682
                         self.shard_fp32_from_float16_groups)
Lawrence McAfee's avatar
Lawrence McAfee committed
683
        copy_group_grads(self.model_fp32_groups,
Lawrence McAfee's avatar
Lawrence McAfee committed
684
                         self.shard_fp32_groups)
685

686
687

    def _copy_main_params_to_model_params(self):
688
689
690
691
692
693
694
        """
        Copy main params to model params.

        Since this step is followed by an all-gather through the DDP's grad
        buffer, this method is responsible for copying the updated params
        from the main shards into the correct position in the grad buffer.
        """
695

696
        # Utility method for copying group params.
Lawrence McAfee's avatar
Lawrence McAfee committed
697
698
699
700
701
        def copy_group_params(shard_main_groups, model_groups):
            for shard_main_group, model_group in zip(shard_main_groups,
                                                     model_groups):
                for shard_main_param, model_param in zip(shard_main_group,
                                                         model_group):
Lawrence McAfee's avatar
Lawrence McAfee committed
702

Lawrence McAfee's avatar
Lawrence McAfee committed
703
                    param_range_map = self.get_model_param_range_map(model_param)
704
                    world_range = param_range_map["gbuf_world"]
Lawrence McAfee's avatar
Lawrence McAfee committed
705

706
707
708
709
710
711
712
                    assert world_range.size == shard_main_param.nelement()

                    model_id, dtype = self.model_param_gbuf_map[model_param]
                    model_param_buffer = self.param_buffers[model_id][dtype]

                    shard_model_param = model_param_buffer.view(-1) \
                        [world_range.start:world_range.end]
713

714
                    shard_model_param.data.copy_(shard_main_param)
Lawrence McAfee's avatar
Lawrence McAfee committed
715

716
        # Copy shard groups to model groups.
Lawrence McAfee's avatar
Lawrence McAfee committed
717
        copy_group_params(self.shard_fp32_from_float16_groups,
Lawrence McAfee's avatar
Lawrence McAfee committed
718
                          self.model_float16_groups)
Lawrence McAfee's avatar
Lawrence McAfee committed
719
        copy_group_params(self.shard_fp32_groups,
Lawrence McAfee's avatar
Lawrence McAfee committed
720
                          self.model_fp32_groups)