distrib_optimizer.py 29.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Megatron distributed optimizer."""


import math
20
import torch
21
22

from megatron import get_args
23
24
from megatron import get_timers
from megatron import mpu
Lawrence McAfee's avatar
Lawrence McAfee committed
25
from megatron import print_rank_0
26
27
from megatron.model.module import param_is_not_shared
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
28
29

from .optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper
30

Lawrence McAfee's avatar
Lawrence McAfee committed
31

32
class Range:
    """
    Half-open index interval ``[start, end)`` used to slice a shard out
    of a flattened tensor.

    Attributes:
        start: first index of the interval.
        end: one past the last index of the interval.
        size: ``end - start``, computed once at construction.
    """

    def __init__(self, start, end):
        self.start, self.end = start, end
        self.size = end - start

    def normalize(self, start=0):
        """Return a new Range of the same size that begins at `start`."""
        shifted_end = start + self.size
        return Range(start, shifted_end)

    def __str__(self):
        fields = (self.start, self.end, self.size)
        return "%d,%d [%d]" % fields


47
class DistributedOptimizer(MixedPrecisionOptimizer):
48
    """Distributed optimizer, for all data types (fp16, bf16, and fp32).
Lawrence McAfee's avatar
Lawrence McAfee committed
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74

    Arguments:
        optimizer: base optimizer such as Adam or SGD
        clip_grad: clip gradients with this global L2 norm. Note
            that clipping is ignored if clip_grad == 0
        log_num_zeros_in_grad: return number of zeros in the gradients.
        params_have_main_grad: flag indicating if parameters have
            a `main_grad` field. If this is set, we are assuming
            that the model parameters are stored in the `main_grad`
            field instead of the typical `grad` field. This happens
            for the DDP cases where there is a continuous buffer
            holding the gradients. For example for bfloat16, we want
            to do gradient accumulation and all-reduces in float32
            and as a result we store those gradients in the main_grad.
            Note that main grad is not necessarily in float32.
        use_contiguous_buffers_in_local_ddp: if true, the local DDP model
            is using a contiguous buffer to hold the model grads.
        fp16: if true, the model is running in fp16.
        bf16: if true, the model is running in bfloat16.
        grad_scaler: used for scaling gradients. Note that this can be
            None. This case happens when `bf16 = True` and we don't
            use any loss scale. Note that for `bf16 = True`, we can have
            a constant gradient scaler. Also for `bf16 = False`, we
            always require a grad scaler.
        models: list of models (i.e., the virtual pipelining models). This
            is used by the distributed optimizer for mapping parameters.
75
    """
76
77

    @classmethod
78
    def build_model_gbuf_param_range_map(cls, model, dtype, gbuf_world_range):
        """
        Map each param that overlaps this rank's grad buffer shard to the
        index ranges describing that overlap.

        `gbuf_world_range` is the contiguous slice of the grad buffer that
        the current data-parallel (DP) rank owns. That slice is chosen
        without regard to parameter boundaries, so a param may fall only
        partly inside it; params with no overlap are left out of the
        returned map.

        For every overlapping param, three ranges of equal size are stored:
        - "gbuf_world": the overlap, indexed within the entire grad buffer.
        - "gbuf_local": the overlap, indexed within this rank's shard.
        - "param": the overlap, indexed within the flattened param itself.

        Arguments:
            model: object exposing `_grad_buffer_param_index_map[dtype]`,
                a dict of param -> (world_start, world_end).
            dtype: key selecting the grad buffer.
            gbuf_world_range: Range owned by this rank (world indexes).

        Returns:
            dict of param -> {"gbuf_world", "gbuf_local", "param"} Ranges.
        """

        index_map = model._grad_buffer_param_index_map[dtype]
        shard_start = gbuf_world_range.start
        shard_size = gbuf_world_range.size

        owned = {}
        for param, (world_start, world_end) in index_map.items():

            # Clip the param's world interval to this rank's shard, in
            # shard-local coordinates.
            local_start = max(0, world_start - shard_start)
            local_end = min(shard_size, world_end - shard_start)

            # Empty (or inverted) interval: the param does not touch
            # this rank's shard.
            if local_end <= local_start:
                continue

            local_range = Range(local_start, local_end)
            world_range = local_range.normalize(local_start + shard_start)

            # Offset of the overlap inside the param; non-zero only when
            # the param begins before this rank's shard does.
            param_offset = max(0, shard_start - world_start)

            owned[param] = {
                "gbuf_world" : world_range,
                "gbuf_local" : local_range,
                "param" : local_range.normalize(param_offset),
            }

        return owned
132

Lawrence McAfee's avatar
Lawrence McAfee committed
133

134
    @classmethod
135
    def build_model_gbuf_range(cls, model, dtype):
        """
        Compute the grad buffer shard ranges for one (model, dtype) buffer.

        The buffer's `numel` is divided into data-parallel-world-size
        consecutive ranges of at most ceil(numel / world_size) elements.
        Ranges for every DP rank are kept (not just the local one); per
        the original author's note this is for building reduce-scatter
        and all-gather arguments.

        Returns:
            dict with keys:
            - "local": this rank's range, re-based to start at 0.
            - "world": this rank's range within the whole buffer.
            - "world_all": list of every rank's world range.
            - "param_map": output of build_model_gbuf_param_range_map()
              for this rank's world range.
            - "max_range_size": ceil(numel / world_size).
        """

        dp_rank = mpu.get_data_parallel_rank()
        dp_world_size = mpu.get_data_parallel_world_size()

        # Per-rank shard size, rounded up so all elements are covered.
        gbuf_size = model._grad_buffers[dtype].numel
        max_gbuf_range_size = int(math.ceil(gbuf_size / dp_world_size))

        # World ranges of all DP ranks. The end is clamped to the buffer
        # size, so trailing ranks may get a shorter range; a rank whose
        # start lies beyond the buffer end gets a negative `size`.
        gbuf_world_all_ranges = [
            Range(rank * max_gbuf_range_size,
                  min(gbuf_size, (rank + 1) * max_gbuf_range_size))
            for rank in range(dp_world_size)
        ]

        # This rank's range, in world and in local (0-based) coordinates.
        own_world_range = gbuf_world_all_ranges[dp_rank]
        own_local_range = own_world_range.normalize()

        # Ranges of each param overlapping this rank's shard.
        param_range_map = cls.build_model_gbuf_param_range_map(
            model, dtype, own_world_range)

        return {
            "local" : own_local_range,
            "world" : own_world_range,
            "world_all" : gbuf_world_all_ranges,
            "param_map" : param_range_map,
            "max_range_size" : max_gbuf_range_size,
        }

Lawrence McAfee's avatar
Lawrence McAfee committed
182

183
    @classmethod
184
    def build_model_gbuf_range_map(cls, model):
        """
        Build the grad buffer range info for every dtype of one
        (virtual) model.

        Returns:
            dict of dtype -> result of build_model_gbuf_range(model, dtype),
            with one entry per key of `model._grad_buffers`.
        """
        ranges_by_dtype = {}
        for dtype in model._grad_buffers:
            ranges_by_dtype[dtype] = cls.build_model_gbuf_range(model, dtype)
        return ranges_by_dtype

Lawrence McAfee's avatar
Lawrence McAfee committed
194

195
    @classmethod
196
    def build_model_param_gbuf_map(cls, model_gbuf_ranges):
197
198
199
200
        """
        Create a reverse of the model_gbuf_ranges, for referencing in
        opposite direction.
        """
201
        param_gbuf_map = {}
202
203
204
        for model_index, model_gbuf_range_map in enumerate(model_gbuf_ranges):
            for dtype, gbuf_range_map in model_gbuf_range_map.items():
                for param, param_range_map in gbuf_range_map["param_map"].items():
205
206
207
                    param_gbuf_map[param] = (model_index, dtype)
        return param_gbuf_map

Lawrence McAfee's avatar
Lawrence McAfee committed
208

209
    @classmethod
210
    def build_optimizer_group_ranges(cls, param_groups, model_gbuf_ranges):
211
212
213
214
215
216
217
218
        """
        Create optimizer groups.

        Given the set of parameter shard ranges that are owned by the current
        data-parallel (DP) rank, gather the set of parameters that will be
        used (in the method below) to create the current DP's optimizer
        groups.
        """
219
220
221
222
223
224
225
226
227
228

        num_groups = len(param_groups)

        # Param group map.
        param_group_map = {}
        for group_index, group in enumerate(param_groups):
            for param in group["params"]:
                assert param.requires_grad
                param_group_map[param] = group_index

229
230
231
232
233
        # Optimizer group ranges.
        group_ranges = [ {"params": []} for _ in param_groups ]
        for model_gbuf_range_map in model_gbuf_ranges:
            for dtype, gbuf_range_map in model_gbuf_range_map.items():
                for param in gbuf_range_map["param_map"]:
234
                    group_index = param_group_map[param]
235
236
                    group_range = group_ranges[group_index]
                    group_range["params"].append(param)
237

238
239
240
241
        # Squeeze zero-size group ranges.
        for group_index, group_range in enumerate(group_ranges):
            group_range["orig_group"] = param_groups[group_index]
        group_ranges = [ g for g in group_ranges if len(g["params"]) > 0 ]
242

243
        return group_ranges
244

245

246
    @classmethod
247
    def build_model_and_main_param_groups(cls,
248
249
250
251
252
253
254
255
256
257
258
259
260
                                          model_gbuf_ranges,
                                          param_gbuf_map,
                                          opt_group_ranges):
        """
        Create main parameter groups needed for the optimizer step.

        These groups encompass both: 1) groups used by this class, for
        reducing/gather, and 2) groups used by the inner optimizer for the
        parameter update. Given that the conceptual grad buffer partitioning
        (created in earlier method) doesn't respect parameter boundaries,
        the optimizer operates on shards of the model parameters, rather than
        the full parameters.
        """
261

Lawrence McAfee's avatar
Lawrence McAfee committed
262
263
264
265
266
267
268
269
        # Parameter groups:
        #   model_float16_groups: original float16 parameters
        #   model_fp32_groups: original fp32 parameters
        #   shard_float16_groups: shards of original float16 parameters
        #   shard_fp32_groups: shards of original fp32 parameters
        #   shard_fp32_from_float16_groups: fp32 copy of float16 parameters
        model_float16_groups = []
        model_fp32_groups = []
270
271
272
273
        shard_float16_groups = []
        shard_fp32_groups = []
        shard_fp32_from_float16_groups = []

Lawrence McAfee's avatar
Lawrence McAfee committed
274
        # Allocate (or slice) each group's param shard.
275
276
277
        for group_index, group_range in enumerate(opt_group_ranges):

            # Params of this group.
Lawrence McAfee's avatar
Lawrence McAfee committed
278
279
            model_float16_params_this_group = []
            model_fp32_params_this_group = []
280
281
282
            shard_float16_params_this_group = []
            shard_fp32_params_this_group = []
            shard_fp32_from_float16_params_this_group = []
Lawrence McAfee's avatar
Lawrence McAfee committed
283
284
            model_float16_groups.append(model_float16_params_this_group)
            model_fp32_groups.append(model_fp32_params_this_group)
285
286
287
288
289
290
291
            shard_float16_groups.append(shard_float16_params_this_group)
            shard_fp32_groups.append(shard_fp32_params_this_group)
            shard_fp32_from_float16_groups.append(
                shard_fp32_from_float16_params_this_group)

            for model_param in group_range["params"]:

292
293
                assert model_param.requires_grad

294
295
296
                model_index, dtype = param_gbuf_map[model_param]
                gbuf_range = model_gbuf_ranges[model_index][dtype]
                param_range = gbuf_range["param_map"][model_param]["param"]
297
298

                # fp16, bf16 params.
299
300
301
302
                if model_param.type() in ['torch.cuda.HalfTensor',
                                          'torch.cuda.BFloat16Tensor']:

                    # Clone model -> main.
Lawrence McAfee's avatar
Lawrence McAfee committed
303
304
                    shard_model_param = model_param.detach().view(-1) \
                        [param_range.start:param_range.end]
305
306
307
308
309
310
311
312
313
314
                    shard_main_param = shard_model_param.clone().float()
                    mpu.copy_tensor_model_parallel_attributes(
                        shard_model_param, model_param)
                    mpu.copy_tensor_model_parallel_attributes(
                        shard_main_param, model_param)
                    if hasattr(model_param, 'shared'):
                        shard_model_param.shared = model_param.shared
                        shard_main_param.shared = model_param.shared

                    # Add to group.
Lawrence McAfee's avatar
Lawrence McAfee committed
315
                    model_float16_params_this_group.append(model_param)
316
317
                    shard_float16_params_this_group.append(shard_model_param)
                    shard_fp32_from_float16_params_this_group.append(shard_main_param)
318
319

                # fp32 params.
320
                elif model_param.type() == 'torch.cuda.FloatTensor':
Lawrence McAfee's avatar
Lawrence McAfee committed
321
322
                    shard_model_param = model_param.view(-1) \
                        [param_range.start:param_range.end]
Lawrence McAfee's avatar
Lawrence McAfee committed
323
                    model_fp32_params_this_group.append(model_param)
324
                    shard_fp32_params_this_group.append(shard_model_param)
325
326
                    mpu.copy_tensor_model_parallel_attributes(
                        shard_model_param, model_param)
327
328
                    if hasattr(model_param, 'shared'):
                        shard_model_param.shared = model_param.shared
329
330
331
332
333
334
335
336

                else:
                    raise TypeError('Wrapped parameters must be one of '
                                    'torch.cuda.FloatTensor,  '
                                    'torch.cuda.HalfTensor, or '
                                    'torch.cuda.BFloat16Tensor. '
                                    'Received {}'.format(param.type()))

Lawrence McAfee's avatar
Lawrence McAfee committed
337
            # Update optimizer's params.
338
339
340
341
342
343
            group_range["orig_group"]["params"] = [
                *shard_fp32_params_this_group,
                *shard_fp32_from_float16_params_this_group,
            ]

        return (
Lawrence McAfee's avatar
Lawrence McAfee committed
344
345
            model_float16_groups,
            model_fp32_groups,
346
347
348
349
            shard_float16_groups,
            shard_fp32_groups,
            shard_fp32_from_float16_groups,
        )
350

Lawrence McAfee's avatar
Lawrence McAfee committed
351

352
353
    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                 fp16, bf16, params_dtype, grad_scaler, models):
        """
        See top of class definition for argument descriptions.

        Construction proceeds in this order:
        1. Build, per model, the grad buffer shard ranges and the reverse
           param -> (model_index, dtype) map.
        2. Bucket the params owned by this rank by optimizer group.
        3. Create the model/shard param groups (this also rewrites each
           original optimizer group's "params" to hold the shards).
        4. Create param buffers on top of the grad buffers' storage.
        5. Point the inner optimizer at the surviving groups and
           round-trip its state dict.
        """

        super().__init__(
            optimizer, clip_grad, log_num_zeros_in_grad,
            params_have_main_grad, use_contiguous_buffers_in_local_ddp,
            fp16, bf16, params_dtype, grad_scaler, models)

        # Verify that contiguous buffers are being used.
        # - Note: this should already be checked in arguments.py.
        assert use_contiguous_buffers_in_local_ddp

        # Model grad buffer ranges, and the reverse lookup.
        self.model_gbuf_ranges = [
            self.build_model_gbuf_range_map(model) for model in self.models
        ]
        self.model_param_gbuf_map = self.build_model_param_gbuf_map(
            self.model_gbuf_ranges)

        # Optimizer ranges.
        self.opt_group_ranges = self.build_optimizer_group_ranges(
            self.optimizer.param_groups, self.model_gbuf_ranges)

        # Allocate main param shards.
        param_groups = self.build_model_and_main_param_groups(
            self.model_gbuf_ranges,
            self.model_param_gbuf_map,
            self.opt_group_ranges)
        (
            self.model_float16_groups,
            self.model_fp32_groups,
            self.shard_float16_groups,
            self.shard_fp32_groups,
            self.shard_fp32_from_float16_groups,
        ) = param_groups

        # Initialize param buffers.
        # - Per the original author's note: these are views on the DDP
        #   model's grad buffers, that share storage & have their own
        #   dtype. This is safe because the param dtype size is always
        #   <= grad dtype size.
        self.param_buffers = []
        for model in self.models:
            buffers_by_dtype = {}
            for dtype, grad_buffer in model._grad_buffers.items():
                gbuf_data = grad_buffer.data
                param_buffer = torch.tensor(gbuf_data.storage()._untyped(),
                                            dtype=params_dtype,
                                            device=gbuf_data.device)
                # Keep only the (padded) length of the grad buffer.
                buffers_by_dtype[dtype] = \
                    param_buffer[:grad_buffer.numel_padded]
            self.param_buffers.append(buffers_by_dtype)

        # Update optimizer groups.
        # - Also, leverage state_dict() and load_state_dict() to
        #   recast preexisting per-param state tensors.
        self.optimizer.param_groups = [
            group_range["orig_group"] for group_range in self.opt_group_ranges
        ]
        self.optimizer.load_state_dict(self.optimizer.state_dict())

419

420
    def get_model_param_range_map(self, param):
421
        """
Lawrence McAfee's avatar
Lawrence McAfee committed
422
423
        Given a model param, get the index sub-range of the param that this
        data-parallel rank owns.
424
        """
425
426
427
428
429
430
        model_index, dtype = self.model_param_gbuf_map[param]
        gbuf_range_map = self.model_gbuf_ranges[model_index][dtype]
        param_range_map = gbuf_range_map["param_map"][param]
        return param_range_map


431
    def get_model_parallel_group(self):
432
433
434
435
        """
        With the distributed optimizer, the model parallel group is the
        entire world.
        """
436
437
        return None

438
439

    def state_dict(self):
440
441
442
        """
        The state dict must contain the fp32-from-float16 shards.
        """
443
444
445
446
447
448
449
450
451
        state_dict = {}
        state_dict['optimizer'] = self.optimizer.state_dict()
        if self.grad_scaler:
            state_dict['grad_scaler'] = self.grad_scaler.state_dict()
        state_dict['shard_fp32_from_float16_groups'] = \
            self.shard_fp32_from_float16_groups
        return state_dict


452
    def load_state_dict(self, state_dict):
453
454
455
        """
        Load the state dict.
        """
456
457
458
459
460
461
462
463
464
465
466

        # Optimizer.
        optimizer_key = 'optimizer'
        if optimizer_key not in state_dict:
            optimizer_key = 'optimizer_state_dict'
            print_rank_0('***WARNING*** loading optimizer from '
                         'an old checkpoint ...')
        self.optimizer.load_state_dict(state_dict[optimizer_key])

        # Grad scaler.
        if 'grad_scaler' not in state_dict:
467
468
469
            if self.fp16:
                print_rank_0('***WARNING*** found an old checkpoint, will not '
                             'load grad scaler ...')
470
471
472
473
474
475
476
477
478
479
480
481
482
483
        else:
            if self.grad_scaler:
                self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
            else:
                print_rank_0('***WARNING*** fould the grad scaler in the '
                             'checkpoint but it is None in the class. '
                             'Skipping loading grad scaler ...')

        # Copy data for the main params.
        for current_group, saved_group in zip(
                self.shard_fp32_from_float16_groups,
                state_dict["shard_fp32_from_float16_groups"]):
            for current_param, saved_param in zip(current_group, saved_group):
                current_param.data.copy_(saved_param.data)
Lawrence McAfee's avatar
Lawrence McAfee committed
484

485

486
    def zero_grad(self, set_to_none=True):
        """
        Zero (or drop) the grads of every param group held by this class.

        Only the model related parameters, i.e., model_float16_groups &
        model_fp32_groups, strictly need zeroing. The shard groups are
        included as a memory optimization to reduce fragmentation; with
        set_to_none==True, the space used by their grad field can be
        safely deallocated at this point.
        """
        all_group_lists = (
            self.model_float16_groups,
            self.model_fp32_groups,
            self.shard_float16_groups, # grad empty/unused here?
            self.shard_fp32_groups, # throws grad-access warning
            self.shard_fp32_from_float16_groups,
        )
        for group_list in all_group_lists:
            for param_group in group_list:
                _zero_grad_group_helper(param_group, set_to_none)
504

505

506
507
    @staticmethod
    def get_model_buffer_dp_views(model_buffers):
        """
        Split each of the DDP's param/grad buffers into one equal-size
        view per data-parallel rank.

        Arguments:
            model_buffers: list of dtype -> buffer dicts. Every buffer's
                element count must be divisible by the data-parallel
                world size.

        Returns:
            list of (index, dtype, buf, buf_views) tuples, where `index`
            is the position of the dict within `model_buffers`, `buf` is
            the entire buffer (for use in _reduce_scatter_base and
            _all_gather_base) and `buf_views[r]` is the contiguous slice
            belonging to rank r. The shard boundaries do not respect
            parameter boundaries, so the elements of some parameters are
            split across data parallel ranks.
        """

        dp_world_size = mpu.get_data_parallel_world_size()

        view_items = []
        for model_index, buffers in enumerate(model_buffers):
            for dtype, buf in buffers.items():

                assert buf.numel() % dp_world_size == 0
                shard_size = int(buf.numel() / dp_world_size)

                # Consecutive, non-overlapping slices; one per rank.
                buf_views = []
                for rank in range(dp_world_size):
                    offset = rank * shard_size
                    buf_views.append(buf[offset:(offset + shard_size)])

                view_items.append((model_index, dtype, buf, buf_views))

        return view_items
537

538
539
540
541
542
543
544
545
546
547

    def get_model_grad_buffer_dp_views(self):
        return self.get_model_buffer_dp_views([
            {dtype : mem_buffer.data}
            for model in self.models
            for dtype, mem_buffer in model._grad_buffers.items()])


    def get_model_param_buffer_dp_views(self):
        """
        Get the per-rank shard views of every model's param buffers,
        i.e., get_model_buffer_dp_views() applied to `self.param_buffers`.
        """
        param_buffers = self.param_buffers
        return self.get_model_buffer_dp_views(param_buffers)
548

Lawrence McAfee's avatar
Lawrence McAfee committed
549

550
    def reduce_model_grads(self, args, timers):
        """
        Reduce-scatter model grads across the data-parallel group.

        Steps, each wrapped in its own timer:
        1. All-reduce layer-norm grads.
        2. All-reduce embedding grads.
        3. Divide every grad buffer by the data-parallel world size, then
           reduce-scatter it so that this rank's view of the buffer holds
           the reduced values.

        The DDP's grad buffer is used for the reduce-scatter, and thus no
        tensors are dynamically allocated.

        Note: this is a different order of reduction, versus the non-
        distributed optimizer, which reduces: 1) layernorm grads, 2) all
        grads, 3) embedding grads.
        """

        # All-reduce layer-norm grads (for sequence parallelism).
        layernorm_timer = timers('layernorm-grads-all-reduce', log_level=1)
        layernorm_timer.start(barrier=args.barrier_with_L1_time)
        self.allreduce_layernorm_grads(args)
        timers('layernorm-grads-all-reduce').stop()

        # All-reduce embedding grads.
        embedding_timer = timers('embedding-grads-all-reduce', log_level=1)
        embedding_timer.start(barrier=args.barrier_with_L1_time)
        self.allreduce_embedding_grads(args)
        timers('embedding-grads-all-reduce').stop()

        # Reduce-scatter setup.
        scatter_timer = timers('grads-reduce-scatter', log_level=1)
        scatter_timer.start(barrier=args.barrier_with_L1_time)
        dp_rank = mpu.get_data_parallel_rank()
        dp_world_size = mpu.get_data_parallel_world_size()
        dp_group = mpu.get_data_parallel_group()

        # Scale grad buffers by '1 / data_parallel_world_size', done
        # before the reduction.
        for model in self.models:
            for gbuf in model._grad_buffers.values():
                gbuf.data /= dp_world_size

        # Reduce-scatter all grads: the whole buffer is the input, this
        # rank's view of it is the output.
        for _, _, gbuf, gbuf_views in self.get_model_grad_buffer_dp_views():
            torch.distributed._reduce_scatter_base(
                gbuf_views[dp_rank],
                gbuf,
                group=dp_group,
            )

        timers('grads-reduce-scatter').stop()
598

Lawrence McAfee's avatar
Lawrence McAfee committed
599

600
    def gather_model_params(self, args, timers):
        """
        All-gather updated model params across the data-parallel group.

        Each rank contributes its own view of every param buffer; after
        the all-gather the full buffer is populated and each param is
        overwritten from its slice of the buffer.

        The DDP's param buffer is used for the all-gather, and thus no
        tensors are dynamically allocated.
        """

        gather_timer = timers('params-all-gather', log_level=1)
        gather_timer.start(barrier=args.barrier_with_L1_time)

        dp_rank = mpu.get_data_parallel_rank()
        dp_group = mpu.get_data_parallel_group()

        # All-gather updated main params.
        # - All param buffer views are guaranteed to have the same num
        #   elements across all data parallel ranks, due to grad buffer
        #   padding that is done in distributed.py, and extended to the
        #   param buffers. Thus, all sub-views will have consistent
        #   start/end indexes across data parallel ranks.
        for _, _, pbuf, pbuf_views in self.get_model_param_buffer_dp_views():
            torch.distributed._all_gather_base(
                pbuf,
                pbuf_views[dp_rank],
                group=dp_group,
            )

        # Copy from param buffer to each param.
        for model_id, model in enumerate(self.models):
            index_maps = model._grad_buffer_param_index_map
            for dtype, param_map in index_maps.items():
                for param, (buf_start, buf_end) in param_map.items():
                    param_buf = self.param_buffers[model_id][dtype]
                    # In-place copy through a flat, detached view of the
                    # param.
                    flat_param = param.view(-1).detach()
                    flat_param.copy_(param_buf[buf_start:buf_end])

        timers('params-all-gather').stop()
640

Lawrence McAfee's avatar
Lawrence McAfee committed
641

642
    def _collect_main_grad_data_for_unscaling(self):
643
644
645
646
        """
        Note: this should be equivalent to the float-16 optimizer's method,
        but writtent differently, so the two should be combined.
        """
Lawrence McAfee's avatar
Lawrence McAfee committed
647
        return [
Lawrence McAfee's avatar
Lawrence McAfee committed
648
649
650
651
            param.grad.data
            for group in self.optimizer.param_groups
            for param in group["params"]
        ]
652
653


Lawrence McAfee's avatar
Lawrence McAfee committed
654
    def _get_model_and_main_params_data_float16(self):
655
656
657
        """
        Get aligned list of model and main params.
        """
Lawrence McAfee's avatar
Lawrence McAfee committed
658
659
660
661
662
663
664
665
        model_data = []
        main_data = []
        for model_group, main_group in zip(self.shard_float16_groups,
                                           self.shard_fp32_from_float16_groups):
            for model_param, main_param in zip(model_group, main_group):
                model_data.append(model_param.data)
                main_data.append(main_param.data)
        return model_data, main_data
666
667


668
    def _copy_model_grads_to_main_grads(self):
669
670
671
672
673
674
675
        """
        Copy model grads to main grads.

        Since this step follows a reduce-scatter through the DDP's grad
        buffer, this method is responsible for copying the updated grads
        from the grad buffer to the main shard's grad field.
        """
676

677
        # Utility method for copying group grads.
Lawrence McAfee's avatar
Lawrence McAfee committed
678
679
680
681
682
        def copy_group_grads(model_groups, shard_main_groups):
            for model_group, shard_main_group in zip(model_groups,
                                                     shard_main_groups):
                for model_param, shard_main_param in zip(model_group,
                                                         shard_main_group):
Lawrence McAfee's avatar
Lawrence McAfee committed
683

Lawrence McAfee's avatar
Lawrence McAfee committed
684
                    param_range_map = self.get_model_param_range_map(model_param)
Lawrence McAfee's avatar
Lawrence McAfee committed
685
                    param_range = param_range_map["param"]
Lawrence McAfee's avatar
Lawrence McAfee committed
686
687
                    assert param_range.size == shard_main_param.nelement()

Lawrence McAfee's avatar
Lawrence McAfee committed
688
689
                    model_grad = model_param.main_grad
                    shard_model_grad = model_grad.view(-1) \
Lawrence McAfee's avatar
Lawrence McAfee committed
690
                        [param_range.start:param_range.end]
Lawrence McAfee's avatar
Lawrence McAfee committed
691
692
                    shard_main_param.grad = shard_model_grad.float()

693
        # Copy model groups to shard groups.
Lawrence McAfee's avatar
Lawrence McAfee committed
694
        copy_group_grads(self.model_float16_groups,
Lawrence McAfee's avatar
Lawrence McAfee committed
695
                         self.shard_fp32_from_float16_groups)
Lawrence McAfee's avatar
Lawrence McAfee committed
696
        copy_group_grads(self.model_fp32_groups,
Lawrence McAfee's avatar
Lawrence McAfee committed
697
                         self.shard_fp32_groups)
698

699
700

    def _copy_main_params_to_model_params(self):
701
702
703
704
705
706
707
        """
        Copy main params to model params.

        Since this step is followed by an all-gather through the DDP's grad
        buffer, this method is responsible for copying the updated params
        from the main shards into the correct position in the grad buffer.
        """
708

709
        # Utility method for copying group params.
Lawrence McAfee's avatar
Lawrence McAfee committed
710
711
712
713
714
        def copy_group_params(shard_main_groups, model_groups):
            for shard_main_group, model_group in zip(shard_main_groups,
                                                     model_groups):
                for shard_main_param, model_param in zip(shard_main_group,
                                                         model_group):
Lawrence McAfee's avatar
Lawrence McAfee committed
715

Lawrence McAfee's avatar
Lawrence McAfee committed
716
                    param_range_map = self.get_model_param_range_map(model_param)
717
                    world_range = param_range_map["gbuf_world"]
Lawrence McAfee's avatar
Lawrence McAfee committed
718

719
720
721
722
723
724
725
                    assert world_range.size == shard_main_param.nelement()

                    model_id, dtype = self.model_param_gbuf_map[model_param]
                    model_param_buffer = self.param_buffers[model_id][dtype]

                    shard_model_param = model_param_buffer.view(-1) \
                        [world_range.start:world_range.end]
726

727
                    shard_model_param.data.copy_(shard_main_param)
Lawrence McAfee's avatar
Lawrence McAfee committed
728

729
        # Copy shard groups to model groups.
Lawrence McAfee's avatar
Lawrence McAfee committed
730
        copy_group_params(self.shard_fp32_from_float16_groups,
Lawrence McAfee's avatar
Lawrence McAfee committed
731
                          self.model_float16_groups)
Lawrence McAfee's avatar
Lawrence McAfee committed
732
        copy_group_params(self.shard_fp32_groups,
Lawrence McAfee's avatar
Lawrence McAfee committed
733
                          self.model_fp32_groups)