distrib_optimizer.py 28 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Megatron distributed optimizer."""


import math

import torch

from megatron import get_args
from megatron import get_timers
from megatron import mpu
from megatron import print_rank_0
from megatron.model.module import param_is_not_shared
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate

from .optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper


class Range:
    """
    A half-open range ``[start, end)`` of indexes into a flattened tensor.

    Ranges are used to describe shards of a full tensor (e.g., a slice of a
    DDP grad buffer, or the slice of a parameter owned by a data-parallel
    rank); ``start``/``end`` are used directly as slice bounds.

    Attributes:
        start: first index covered by the range.
        end: one past the last index covered by the range.
        size: number of indexes covered, i.e. ``end - start``.
    """

    def __init__(self, start, end):
        self.start = start
        self.end = end
        self.size = end - start

    def normalize(self, start = 0):
        """Return a new Range of the same size, shifted to begin at `start`."""
        return Range(start, start + self.size)

    def __str__(self):
        return "%d,%d [%d]" % (self.start, self.end, self.size)

    def __repr__(self):
        # Unambiguous form for debugging, e.g. when printing a list of ranges.
        return "%s(%r, %r)" % (type(self).__name__, self.start, self.end)


46
class DistributedOptimizer(MixedPrecisionOptimizer):
    """Distributed optimizer, for all data types (fp16, bf16, and fp32).

    Each data-parallel (DP) rank owns one contiguous shard of every DDP grad
    buffer, and only reduces the grads / updates the params that fall inside
    its shard. Shard boundaries do not respect parameter boundaries, so the
    inner optimizer operates on shards (views) of parameters.

    Arguments:
        optimizer: base optimizer such as Adam or SGD
        clip_grad: clip gradients with this global L2 norm. Note
            that clipping is ignored if clip_grad == 0
        log_num_zeros_in_grad: return number of zeros in the gradients.
        params_have_main_grad: flag indicating if parameters have
            a `main_grad` field. If this is set, we are assuming
            that the model gradients are stored in the `main_grad`
            field instead of the typical `grad` field. This happens
            for the DDP cases where there is a contiguous buffer
            holding the gradients. For example for bfloat16, we want
            to do gradient accumulation and all-reduces in float32
            and as a result we store those gradients in the main_grad.
            Note that main grad is not necessarily in float32.
        use_contiguous_buffers_in_local_ddp: if true, the local DDP model
            is using a contiguous buffer to hold the model grads.
        fp16: if true, the model is running in fp16.
        bf16: if true, the model is running in bfloat16.
        grad_scaler: used for scaling gradients. Note that this can be
            None. This case happens when `bf16 = True` and we don't
            use any loss scale. Note that for `bf16 = True`, we can have
            a constant gradient scaler. Also for `bf16 = False`, we
            always require a grad scaler.
        models: list of models (i.e., the virtual pipelining models). This
            is used by the distributed optimizer for mapping parameters.
    """

    @classmethod
    def build_model_gbuf_param_range_map(cls, model, dtype, gbuf_world_range):
        """
        Build mapping from param reference to grad buffer shard ranges.

        This method builds a mapping from parameter references to grad
        buffer shard ranges, specific to each data-parallel (DP) rank's
        set of 'owned' parameters. Each grad buffer (padded to be an even
        multiple of DP-world-size) is conceptually divided into DP-world-size
        contiguous regions, where each DP rank 'owns' a contiguous region.
        Ownership in this sense means DP rank is responsible for reducing
        the relevant subset of grads, and updating the relevant subset of
        params.

        This conceptual partitioning of the grad buffer does NOT respect
        parameter boundaries, and as such it is assumed that each created
        range references a shard (or subset) of the full parameter. It is
        easiest to think of each DP rank as operating (i.e., reducing,
        gathering) purely on views into the grad buffer, for all model-to-
        main & main-to-model operations.

        This method creates three ranges:
        - The param's range within the entire grad buffer (i.e., world index).
        - The param's range within the DP rank's local view of the grad buffer.
        - The param's range within itself (i.e., its shard).

        Returns:
            dict mapping each param that overlaps `gbuf_world_range` to a
            dict with keys "gbuf_world", "gbuf_local" and "param" (Ranges).
            Params with no overlap are omitted.
        """

        # Param range map.
        param_world_index_map = model._grad_buffer_param_index_map[dtype]
        param_range_map = {}
        for param, param_world_indexes in param_world_index_map.items():

            # Param range, clipped to this rank's shard and expressed
            # relative to the shard's start.
            param_world_start, param_world_end = param_world_indexes
            param_local_start = max(
                0,
                param_world_start - gbuf_world_range.start)
            param_local_end = min(
                gbuf_world_range.size,
                param_world_end - gbuf_world_range.start)

            # Add param, if within local gbuf range.
            if param_local_end > param_local_start:
                param_local_range = Range(param_local_start, param_local_end)
                param_world_range = param_local_range.normalize(
                    param_local_start + gbuf_world_range.start)
                # Offset of the owned piece within the flattened param; > 0
                # only when the param starts before this rank's shard.
                sub_param_start = max(0, gbuf_world_range.start-param_world_start)
                sub_param_range = param_local_range.normalize(sub_param_start)
                param_range_map[param] = {
                    "gbuf_world" : param_world_range,
                    "gbuf_local" : param_local_range,
                    "param" : sub_param_range,
                }

        return param_range_map

    @classmethod
    def build_model_gbuf_range(cls, model, dtype):
        """
        Build mapping between params and their grad buffers.

        This method does the initial setup for the method above. This setup
        includes determining the shard ranges into the DDP's grad buffer for
        each data-parallel (DP) rank. Each DP rank keeps range info for
        all other DP ranks, for the purpose of creating args for
        reduce-scatter and all-gather.
        """

        data_parallel_rank = mpu.get_data_parallel_rank()
        data_parallel_world_size = mpu.get_data_parallel_world_size()

        # Grad buffer range.
        grad_buffer = model._grad_buffers[dtype]
        gbuf_size = grad_buffer.numel
        max_gbuf_range_size = int(math.ceil(gbuf_size / data_parallel_world_size))

        # All world ranges. (i.e., across all data parallel ranks)
        # The last rank's range may be shorter (it is clipped to gbuf_size).
        gbuf_world_all_ranges = []
        for r in range(data_parallel_world_size):
            gbuf_world_start = r * max_gbuf_range_size
            gbuf_world_end = min(gbuf_size, gbuf_world_start+max_gbuf_range_size)
            gbuf_world_range = Range(gbuf_world_start, gbuf_world_end)
            gbuf_world_all_ranges.append(gbuf_world_range)

        # Local DP's ranges.
        gbuf_world_range = gbuf_world_all_ranges[data_parallel_rank]
        gbuf_local_range = gbuf_world_range.normalize()

        # Get each param's ranges.
        param_range_map = cls.build_model_gbuf_param_range_map(model,
                                                               dtype,
                                                               gbuf_world_range)

        # Group into dict.
        data = {
            "local" : gbuf_local_range,
            "world" : gbuf_world_range,
            "world_all" : gbuf_world_all_ranges,
            "param_map" : param_range_map,
            "max_range_size" : max_gbuf_range_size,
        }

        return data

    @classmethod
    def build_model_gbuf_range_map(cls, model):
        """
        Create param-to-grad-buffer mappings, for grad buffer data types
        within a specific virtual model.
        """
        return {
            dtype : cls.build_model_gbuf_range(model, dtype)
            for dtype in model._grad_buffers
        }

    @classmethod
    def build_model_param_gbuf_map(cls, model_gbuf_ranges):
        """
        Create a reverse of the model_gbuf_ranges, for referencing in
        opposite direction.

        Returns:
            dict mapping each locally-owned param to (model_index, dtype).
        """
        param_gbuf_map = {}
        for model_index, model_gbuf_range_map in enumerate(model_gbuf_ranges):
            for dtype, gbuf_range_map in model_gbuf_range_map.items():
                for param in gbuf_range_map["param_map"]:
                    param_gbuf_map[param] = (model_index, dtype)
        return param_gbuf_map

    @classmethod
    def build_optimizer_group_ranges(cls, param_groups, model_gbuf_ranges):
        """
        Create optimizer groups.

        Given the set of parameter shard ranges that are owned by the current
        data-parallel (DP) rank, gather the set of parameters that will be
        used (in the method below) to create the current DP's optimizer
        groups.

        Returns:
            list of dicts, one per non-empty optimizer group, each holding
            the owned "params" and a reference to its "orig_group".
        """

        # Param group map: param -> index of its optimizer param group.
        param_group_map = {}
        for group_index, group in enumerate(param_groups):
            for param in group["params"]:
                assert param.requires_grad
                param_group_map[param] = group_index

        # Optimizer group ranges.
        group_ranges = [ {"params": []} for _ in param_groups ]
        for model_gbuf_range_map in model_gbuf_ranges:
            for gbuf_range_map in model_gbuf_range_map.values():
                for param in gbuf_range_map["param_map"]:
                    group_index = param_group_map[param]
                    group_ranges[group_index]["params"].append(param)

        # Squeeze zero-size group ranges.
        for group_index, group_range in enumerate(group_ranges):
            group_range["orig_group"] = param_groups[group_index]
        group_ranges = [ g for g in group_ranges if len(g["params"]) > 0 ]

        return group_ranges

    @classmethod
    def build_model_and_main_param_groups(cls,
                                          model_gbuf_ranges,
                                          param_gbuf_map,
                                          opt_group_ranges):
        """
        Create main parameter groups needed for the optimizer step.

        These groups encompass both: 1) groups used by this class, for
        reducing/gather, and 2) groups used by the inner optimizer for the
        parameter update. Given that the conceptual grad buffer partitioning
        (created in earlier method) doesn't respect parameter boundaries,
        the optimizer operates on shards of the model parameters, rather than
        the full parameters.

        Side effect: each entry's "orig_group"["params"] is replaced with the
        fp32 shards (fp32 params first, then fp32 copies of float16 params).

        Raises:
            TypeError: if a param is not a cuda Float/Half/BFloat16 tensor.
        """

        # Parameter groups:
        #   model_float16_groups: original float16 parameters
        #   model_fp32_groups: original fp32 parameters
        #   shard_float16_groups: shards of original float16 parameters
        #   shard_fp32_groups: shards of original fp32 parameters
        #   shard_fp32_from_float16_groups: fp32 copy of float16 parameters
        model_float16_groups = []
        model_fp32_groups = []
        shard_float16_groups = []
        shard_fp32_groups = []
        shard_fp32_from_float16_groups = []

        # Allocate (or slice) each group's param shard.
        for group_range in opt_group_ranges:

            # Params of this group.
            model_float16_params_this_group = []
            model_fp32_params_this_group = []
            shard_float16_params_this_group = []
            shard_fp32_params_this_group = []
            shard_fp32_from_float16_params_this_group = []
            model_float16_groups.append(model_float16_params_this_group)
            model_fp32_groups.append(model_fp32_params_this_group)
            shard_float16_groups.append(shard_float16_params_this_group)
            shard_fp32_groups.append(shard_fp32_params_this_group)
            shard_fp32_from_float16_groups.append(
                shard_fp32_from_float16_params_this_group)

            for model_param in group_range["params"]:

                assert model_param.requires_grad

                model_index, dtype = param_gbuf_map[model_param]
                gbuf_range = model_gbuf_ranges[model_index][dtype]
                param_range = gbuf_range["param_map"][model_param]["param"]
                param_type = model_param.type()

                # fp16, bf16 params.
                if param_type in ['torch.cuda.HalfTensor',
                                  'torch.cuda.BFloat16Tensor']:

                    # Clone model -> main. The model shard is a (detached)
                    # view of the model param; the main shard is an fp32 copy.
                    shard_model_param = model_param.detach().view(-1) \
                        [param_range.start:param_range.end]
                    shard_main_param = shard_model_param.clone().float()
                    mpu.copy_tensor_model_parallel_attributes(
                        shard_model_param, model_param)
                    mpu.copy_tensor_model_parallel_attributes(
                        shard_main_param, model_param)
                    if hasattr(model_param, 'shared'):
                        shard_model_param.shared = model_param.shared
                        shard_main_param.shared = model_param.shared

                    # Add to group.
                    model_float16_params_this_group.append(model_param)
                    shard_float16_params_this_group.append(shard_model_param)
                    shard_fp32_from_float16_params_this_group.append(shard_main_param)

                # fp32 params.
                elif param_type == 'torch.cuda.FloatTensor':
                    # The shard is a view into the model param itself, so the
                    # optimizer updates the model param in place.
                    shard_model_param = model_param.view(-1) \
                        [param_range.start:param_range.end]
                    model_fp32_params_this_group.append(model_param)
                    shard_fp32_params_this_group.append(shard_model_param)
                    mpu.copy_tensor_model_parallel_attributes(
                        shard_model_param, model_param)
                    if hasattr(model_param, 'shared'):
                        shard_model_param.shared = model_param.shared

                else:
                    # Report the offending param's type (previously this
                    # referenced an undefined name and raised NameError).
                    raise TypeError('Wrapped parameters must be one of '
                                    'torch.cuda.FloatTensor,  '
                                    'torch.cuda.HalfTensor, or '
                                    'torch.cuda.BFloat16Tensor. '
                                    'Received {}'.format(param_type))

            # Update optimizer's params.
            group_range["orig_group"]["params"] = [
                *shard_fp32_params_this_group,
                *shard_fp32_from_float16_params_this_group,
            ]

        return (
            model_float16_groups,
            model_fp32_groups,
            shard_float16_groups,
            shard_fp32_groups,
            shard_fp32_from_float16_groups,
        )

    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                 fp16, bf16, grad_scaler, models):
        """
        See top of class definition for argument descriptions.

        The steps in this method create the core mapping between DDP grad
        buffers, parameters, and parameter shard ranges, that is needed for
        converting between model param indexes and main parameter shard
        indexes. This method also updates the optimizer parameter groups
        with the newly created shards.
        """

        super().__init__(
            optimizer, clip_grad, log_num_zeros_in_grad,
            params_have_main_grad, use_contiguous_buffers_in_local_ddp,
            fp16, bf16, grad_scaler, models)

        # Verify that contiguous buffers are being used.
        # - Note: this should already be checked in arguments.py.
        assert use_contiguous_buffers_in_local_ddp

        # Model grad buffer ranges.
        self.model_gbuf_ranges = [
            self.build_model_gbuf_range_map(model) for model in self.models
        ]
        self.model_param_gbuf_map = \
            self.build_model_param_gbuf_map(self.model_gbuf_ranges)

        # Optimizer ranges.
        self.opt_group_ranges = self.build_optimizer_group_ranges(
            self.optimizer.param_groups,
            self.model_gbuf_ranges)

        # Allocate main param shards.
        (
            self.model_float16_groups,
            self.model_fp32_groups,
            self.shard_float16_groups,
            self.shard_fp32_groups,
            self.shard_fp32_from_float16_groups,
        ) = self.build_model_and_main_param_groups(self.model_gbuf_ranges,
                                                   self.model_param_gbuf_map,
                                                   self.opt_group_ranges)

        # Update optimizer groups.
        # - Also, leverage state_dict() and load_state_dict() to
        #   recast preexisting per-param state tensors.
        self.optimizer.param_groups = \
            [ g["orig_group"] for g in self.opt_group_ranges ]
        self.optimizer.load_state_dict(self.optimizer.state_dict())

    def get_model_param_range_map(self, param):
        """
        Given a model param, get the index sub-range of the param that this
        data-parallel rank owns.

        Returns:
            dict with "gbuf_world", "gbuf_local" and "param" Ranges.
        """
        model_index, dtype = self.model_param_gbuf_map[param]
        gbuf_range_map = self.model_gbuf_ranges[model_index][dtype]
        return gbuf_range_map["param_map"][param]

    def get_model_parallel_group(self):
        """
        With the distributed optimizer, the model parallel group is the
        entire world.
        """
        return None

    def state_dict(self):
        """
        The state dict must contain the fp32-from-float16 shards.
        """
        state_dict = {}
        state_dict['optimizer'] = self.optimizer.state_dict()
        if self.grad_scaler:
            state_dict['grad_scaler'] = self.grad_scaler.state_dict()
        state_dict['shard_fp32_from_float16_groups'] = \
            self.shard_fp32_from_float16_groups
        return state_dict

    def load_state_dict(self, state_dict):
        """
        Load the state dict.

        Loads the inner optimizer state (accepting the legacy
        'optimizer_state_dict' key), the grad scaler state when both the
        checkpoint and this optimizer have one, and copies the saved
        fp32-from-float16 shards into the current shards in place.
        """

        # Optimizer.
        optimizer_key = 'optimizer'
        if optimizer_key not in state_dict:
            optimizer_key = 'optimizer_state_dict'
            print_rank_0('***WARNING*** loading optimizer from '
                         'an old checkpoint ...')
        self.optimizer.load_state_dict(state_dict[optimizer_key])

        # Grad scaler.
        if 'grad_scaler' not in state_dict:
            print_rank_0('***WARNING*** found an old checkpoint, will not '
                         'load grad scaler ...')
        elif self.grad_scaler:
            self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
        else:
            print_rank_0('***WARNING*** fould the grad scaler in the '
                         'checkpoint but it is None in the class. '
                         'Skipping loading grad scaler ...')

        # Copy data for the main params.
        for current_group, saved_group in zip(
                self.shard_fp32_from_float16_groups,
                state_dict["shard_fp32_from_float16_groups"]):
            for current_param, saved_param in zip(current_group, saved_group):
                current_param.data.copy_(saved_param.data)

    def zero_grad(self, set_to_none=True):
        """
        Zero grads.

        We only need to zero the model related parameters, i.e.,
        model_float16_groups & model_fp32_groups. We additionally zero
        the remaining groups as a memory optimization to reduce
        fragmentation; in the case of set_to_none==True, the space
        used by this field can be safely deallocated at this point.
        """
        for groups in (
                self.model_float16_groups,
                self.model_fp32_groups,
                # NOTE(review): grads of these views look unused — confirm.
                self.shard_float16_groups,
                # NOTE(review): these are non-detached views of model params,
                # so accessing their .grad throws a grad-access warning.
                self.shard_fp32_groups,
                self.shard_fp32_from_float16_groups):
            for group in groups:
                _zero_grad_group_helper(group, set_to_none)

    def get_model_grad_buffer_dp_views(self):
        """
        Get shard views of each of the DDP's grad buffers.

        In this nested list, the top level is grouped by the virtual model
        index and the grad buffer's data type. The sub-level is a list of
        shards of that grad buffer, where each shard in the list represents
        a contiguous view of the grad buffer, that is owned by a data-parallel
        rank. The shard boundary does not respect parameter boundaries, and
        so the elements of some parameters are split across data parallel
        ranks.

        Additionally, return references to the entire grad buffers, for use
        in _reduce_scatter_base and _all_gather_base.

        Returns:
            list of (model_index, dtype, full_buffer_tensor, shard_views).
        """

        data_parallel_world_size = mpu.get_data_parallel_world_size()

        # Grad buffer views.
        gbuf_view_items = []
        for model_index, model in enumerate(self.models):
            for dtype, gbuf in model._grad_buffers.items():

                # Padding (done by the DDP wrapper) guarantees equal shards.
                assert gbuf.numel_padded % data_parallel_world_size == 0
                shard_size = int(gbuf.numel_padded / data_parallel_world_size)
                gbuf_views = [gbuf.data[(r*shard_size):((r+1)*shard_size)]
                              for r in range(data_parallel_world_size)]
                gbuf_view_items.append((model_index, dtype, gbuf.data, gbuf_views))

        return gbuf_view_items

    def reduce_model_grads(self, args, timers):
        """
        Reduce-scatter model grads.

        The DDP's grad buffer is used for the reduce-scatter, and thus no
        tensors are dynamically allocated.

        Note: this is a different order of reduction, versus the non-
        distributed optimizer, which reduces: 1) all grads, 2) embedding
        grads.
        """

        # All-reduce embedding grads.
        timers('backward-embedding-all-reduce').start()
        self.allreduce_embedding_grads(args)
        timers('backward-embedding-all-reduce').stop()

        # Reduce-scatter setup.
        timers('backward-params-all-reduce').start()
        data_parallel_rank = mpu.get_data_parallel_rank()
        data_parallel_world_size = mpu.get_data_parallel_world_size()
        data_parallel_group = mpu.get_data_parallel_group()

        # Scale grad buffers by '1 / data_parallel_world_size', so the
        # summing reduce-scatter below yields the mean gradient.
        for model in self.models:
            for gbuf in model._grad_buffers.values():
                gbuf.data /= data_parallel_world_size

        # Reduce-scatter all grads: each rank receives the reduced values of
        # its own shard, written into its view of the grad buffer.
        gbuf_view_items = self.get_model_grad_buffer_dp_views()
        for _, _, gbuf, gbuf_views in gbuf_view_items:
            torch.distributed._reduce_scatter_base(
                gbuf_views[data_parallel_rank],
                gbuf,
                group = data_parallel_group,
            )

        timers('backward-params-all-reduce').stop()

    def gather_model_params(self, args, timers):
        """
        All-gather updated model params.

        The DDP's grad buffer is used for the all-gather, and thus no
        tensors are dynamically allocated. After the all-gather, the params
        can be copied from param.main_grad to param.
        """

        timers('backward-params-all-gather').start()

        data_parallel_rank = mpu.get_data_parallel_rank()
        data_parallel_group = mpu.get_data_parallel_group()

        # All-gather updated main params.
        # - All grad buffer views are guaranteed to have the same num elements
        #   across all data parallel ranks, with grad buffer padding that is done
        #   in distributed.py. Thus, all sub-views will have consistent start/end
        #   indexes across data parallel ranks.
        gbuf_view_items = self.get_model_grad_buffer_dp_views()
        for _, _, gbuf, gbuf_views in gbuf_view_items:
            torch.distributed._all_gather_base(
                gbuf,
                gbuf_views[data_parallel_rank],
                group = data_parallel_group,
            )

        # Each model param now contains its updated values in its
        # '.main_grad' field.
        for model in self.models:
            for param_map in model._grad_buffer_param_index_map.values():
                for param in param_map:
                    param.detach().copy_(param.main_grad)

        timers('backward-params-all-gather').stop()

    def _collect_main_grad_data_for_unscaling(self):
        """
        Note: this should be equivalent to the float-16 optimizer's method,
        but written differently, so the two should be combined.
        """
        return [
            param.grad.data
            for group in self.optimizer.param_groups
            for param in group["params"]
        ]

    def _get_model_and_main_params_data_float16(self):
        """
        Get aligned list of model and main params.

        Returns:
            (model_data, main_data): parallel lists of the `.data` of the
            float16 model shards and of their fp32 main copies.
        """
        model_data = []
        main_data = []
        for model_group, main_group in zip(self.shard_float16_groups,
                                           self.shard_fp32_from_float16_groups):
            for model_param, main_param in zip(model_group, main_group):
                model_data.append(model_param.data)
                main_data.append(main_param.data)
        return model_data, main_data

    def _copy_model_grads_to_main_grads(self):
        """
        Copy model grads to main grads.

        Since this step follows a reduce-scatter through the DDP's grad
        buffer, this method is responsible for copying the updated grads
        from the grad buffer to the main shard's grad field.
        """

        # Utility method for copying group grads.
        def copy_group_grads(model_groups, shard_main_groups):
            for model_group, shard_main_group in zip(model_groups,
                                                     shard_main_groups):
                for model_param, shard_main_param in zip(model_group,
                                                         shard_main_group):

                    param_range_map = self.get_model_param_range_map(model_param)
                    param_range = param_range_map["param"]
                    assert param_range.size == shard_main_param.nelement()

                    # Slice this rank's piece out of the flattened main_grad.
                    model_grad = model_param.main_grad
                    shard_model_grad = model_grad.view(-1) \
                        [param_range.start:param_range.end]
                    shard_main_param.grad = shard_model_grad.float()

        # Copy model groups to shard groups.
        copy_group_grads(self.model_float16_groups,
                         self.shard_fp32_from_float16_groups)
        copy_group_grads(self.model_fp32_groups,
                         self.shard_fp32_groups)

    def _copy_main_params_to_model_params(self):
        """
        Copy main params to model params.

        Since this step is followed by an all-gather through the DDP's grad
        buffer, this method is responsible for copying the updated params
        from the main shards into the correct position in the grad buffer.
        """

        # Utility method for copying group params.
        def copy_group_params(shard_main_groups, model_groups):
            for shard_main_group, model_group in zip(shard_main_groups,
                                                     model_groups):
                for shard_main_param, model_param in zip(shard_main_group,
                                                         model_group):

                    param_range_map = self.get_model_param_range_map(model_param)
                    param_range = param_range_map["param"]
                    assert param_range.size == shard_main_param.nelement()

                    # The grad buffer (via main_grad) is reused as the
                    # staging area for the subsequent all-gather.
                    model_grad = model_param.main_grad
                    shard_model_grad = model_grad.view(-1) \
                        [param_range.start:param_range.end]

                    shard_model_grad.data.copy_(shard_main_param)

        # Copy shard groups to model groups.
        copy_group_params(self.shard_fp32_from_float16_groups,
                          self.model_float16_groups)
        copy_group_params(self.shard_fp32_groups,
                          self.model_fp32_groups)