# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Megatron distributed optimizer."""


import math

import torch

from megatron import get_args
from megatron import get_timers
from megatron import mpu
from megatron import print_rank_0
from megatron.model.module import param_is_not_shared
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate

from .optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper


class Range:
    """
    A range represents the start and end points for indexing a shard
    from a full (flattened) tensor.

    The range is half-open, ``[start, end)``, matching Python slice
    semantics (it is used as ``tensor.view(-1)[r.start:r.end]``).

    Attributes:
        start: first index covered by the range.
        end: one past the last index covered by the range.
        size: number of elements, ``end - start``.
    """

    def __init__(self, start, end):
        self.start = start
        self.end = end
        self.size = end - start

    def normalize(self, start=0):
        """Return a new Range of the same size, shifted to begin at `start`."""
        return Range(start, start + self.size)

    def __str__(self):
        return "%d,%d [%d]" % (self.start, self.end, self.size)


class DistributedOptimizer(MixedPrecisionOptimizer):
    """Distributed optimizer, for all data types (fp16, bf16, and fp32).

    Each data-parallel rank owns one contiguous slice of every model grad
    buffer, and keeps main (optimizer) parameters only for the portions of
    the model parameters that fall inside its slice.

    Arguments:
        optimizer: base optimizer such as Adam or SGD
        clip_grad: clip gradients with this global L2 norm. Note
            that clipping is ignored if clip_grad == 0
        log_num_zeros_in_grad: return number of zeros in the gradients.
        params_have_main_grad: flag indicating if parameters have
            a `main_grad` field. If this is set, we are assuming
            that the model parameters are stored in the `main_grad`
            field instead of the typical `grad` field. This happens
            for the DDP cases where there is a continuous buffer
            holding the gradients. For example for bfloat16, we want
            to do gradient accumulation and all-reduces in float32
            and as a result we store those gradients in the main_grad.
            Note that main grad is not necessarily in float32.
        use_contiguous_buffers_in_local_ddp: if true, the local DDP model
            is using a contiguous buffer to hold the model grads. Must be
            true for the distributed optimizer (asserted in __init__).
        fp16: if true, the model is running in fp16.
        bf16: if true, the model is running in bfloat16.
        grad_scaler: used for scaling gradients. Note that this can be
            None. This case happens when `bf16 = True` and we don't
            use any loss scale. Note that for `bf16 = True`, we can have
            a constant gradient scaler. Also for `bf16 = False`, we
            always require a grad scaler.
        models: list of models (i.e., the virtual pipelining models). This
            is used by the distributed optimizer for mapping parameters.
    """

    @classmethod
    def build_model_gbuf_param_range_map(cls, model, dtype, gbuf_world_range):
        """Map each param of `model`'s `dtype` grad buffer to the sub-ranges
        that overlap this rank's slice of that buffer.

        Args:
            model: model exposing `_grad_buffer_param_index_map[dtype]`, a
                mapping from param to its (start, end) indexes in the grad
                buffer.
            dtype: grad buffer dtype key.
            gbuf_world_range: Range of the grad buffer owned by this rank,
                in buffer (world) coordinates.

        Returns:
            dict mapping each overlapping param to a dict with keys:
                "gbuf_world": the overlap, in grad-buffer coordinates.
                "gbuf_local": the overlap, relative to this rank's slice.
                "param": the overlap, relative to the flattened param.
            Params that do not overlap the slice are omitted.
        """
        param_world_index_map = model._grad_buffer_param_index_map[dtype]
        param_range_map = {}
        for param, (param_world_start, param_world_end) in \
                param_world_index_map.items():

            # Clip the param's buffer extent to this rank's slice, expressed
            # relative to the start of the slice.
            param_local_start = max(
                0,
                param_world_start - gbuf_world_range.start)
            param_local_end = min(
                gbuf_world_range.size,
                param_world_end - gbuf_world_range.start)

            # Skip params that do not overlap the local slice.
            if param_local_end <= param_local_start:
                continue

            param_local_range = Range(param_local_start, param_local_end)
            param_world_range = param_local_range.normalize(
                param_local_start + gbuf_world_range.start)
            # Offset of the overlap within the param itself (non-zero when
            # the param begins before this rank's slice).
            sub_param_start = max(0, gbuf_world_range.start - param_world_start)
            sub_param_range = param_local_range.normalize(sub_param_start)
            param_range_map[param] = {
                "gbuf_world" : param_world_range,
                "gbuf_local" : param_local_range,
                "param" : sub_param_range,
            }

        return param_range_map

    @classmethod
    def build_model_gbuf_range(cls, model, dtype):
        """Partition `model`'s `dtype` grad buffer across data-parallel ranks.

        The buffer's `numel` elements are split into one contiguous range per
        rank, each at most `ceil(numel / data_parallel_world_size)` long.

        Returns:
            dict with keys:
                "local": this rank's range, normalized to start at 0.
                "world": this rank's range, in buffer coordinates.
                "world_all": every rank's range, in rank order.
                "param_map": see `build_model_gbuf_param_range_map`.
                "max_range_size": the unclipped per-rank range size.
        """
        data_parallel_rank = mpu.get_data_parallel_rank()
        data_parallel_world_size = mpu.get_data_parallel_world_size()

        # Grad buffer range.
        grad_buffer = model._grad_buffers[dtype]
        gbuf_size = grad_buffer.numel
        # Exact integer ceil-division (float division can round for very
        # large buffers).
        max_gbuf_range_size = \
            (gbuf_size + data_parallel_world_size - 1) // data_parallel_world_size

        # All world ranges (i.e., across all data parallel ranks). Starts are
        # clamped to the buffer size so that, for tiny buffers, trailing ranks
        # get an empty range instead of a negative-size one.
        gbuf_world_all_ranges = []
        for r in range(data_parallel_world_size):
            gbuf_world_start = min(gbuf_size, r * max_gbuf_range_size)
            gbuf_world_end = min(gbuf_size,
                                 gbuf_world_start + max_gbuf_range_size)
            gbuf_world_all_ranges.append(
                Range(gbuf_world_start, gbuf_world_end))

        # Local DP's ranges.
        gbuf_world_range = gbuf_world_all_ranges[data_parallel_rank]
        gbuf_local_range = gbuf_world_range.normalize()

        # Get each param's ranges.
        param_range_map = cls.build_model_gbuf_param_range_map(model,
                                                               dtype,
                                                               gbuf_world_range)

        # Altogether.
        data = {
            "local" : gbuf_local_range,
            "world" : gbuf_world_range,
            "world_all" : gbuf_world_all_ranges,
            "param_map" : param_range_map,
            "max_range_size" : max_gbuf_range_size,
        }

        return data

    @classmethod
    def build_model_gbuf_range_map(cls, model):
        """Return `{dtype: build_model_gbuf_range(model, dtype)}` for every
        grad buffer of `model`."""
        return {
            dtype : cls.build_model_gbuf_range(model, dtype)
            for dtype in model._grad_buffers
        }

    @classmethod
    def build_model_param_gbuf_map(cls, model_gbuf_ranges):
        """
        Create a reverse of the model_gbuf_ranges, for referencing in
        opposite direction: maps each locally-owned param to the
        `(model_index, dtype)` key of the grad buffer that holds it.
        """
        param_gbuf_map = {}
        for model_index, model_gbuf_range_map in enumerate(model_gbuf_ranges):
            for dtype, gbuf_range_map in model_gbuf_range_map.items():
                for param in gbuf_range_map["param_map"]:
                    param_gbuf_map[param] = (model_index, dtype)
        return param_gbuf_map

    @classmethod
    def build_optimizer_group_ranges(cls, param_groups, model_gbuf_ranges):
        """Restrict the optimizer's param groups to this rank's params.

        Args:
            param_groups: the wrapped optimizer's `param_groups`; every param
                must have `requires_grad` set.
            model_gbuf_ranges: per-model output of
                `build_model_gbuf_range_map`.

        Returns:
            list of dicts, one per original group that contains at least one
            locally-owned param. Each has keys "params" (the locally-owned
            model params of that group) and "orig_group" (the original
            param group dict).
        """
        # Param group map.
        param_group_map = {}
        for group_index, group in enumerate(param_groups):
            for param in group["params"]:
                assert param.requires_grad
                param_group_map[param] = group_index

        # Optimizer group ranges.
        group_ranges = [ {"params": []} for _ in param_groups ]
        for model_gbuf_range_map in model_gbuf_ranges:
            for gbuf_range_map in model_gbuf_range_map.values():
                for param in gbuf_range_map["param_map"]:
                    group_index = param_group_map[param]
                    group_ranges[group_index]["params"].append(param)

        # Squeeze zero-size group ranges.
        for group_range, orig_group in zip(group_ranges, param_groups):
            group_range["orig_group"] = orig_group
        group_ranges = [ g for g in group_ranges if len(g["params"]) > 0 ]

        return group_ranges

    @classmethod
    def build_model_and_main_param_groups(cls,
                                          model_gbuf_ranges,
                                          param_gbuf_map,
                                          opt_group_ranges):
        """Create this rank's shards of every optimizer group's params.

        Each model param is flattened and sliced down to the part this rank
        owns:
          - fp16/bf16 params: the shard is a detached view of the model
            param, and an fp32 clone of it becomes the main param.
          - fp32 params: the shard is a view of the model param and is used
            directly as the main param.
        Tensor-model-parallel attributes and the `shared` flag are copied
        onto each shard.

        Side effect: each group's `orig_group["params"]` is replaced with the
        main-param shards (fp32 shards first, then fp32 copies of the
        float16 shards).

        Returns:
            (model_float16_groups, model_fp32_groups, shard_float16_groups,
             shard_fp32_groups, shard_fp32_from_float16_groups), each a list
            holding one list of tensors per optimizer group.

        Raises:
            TypeError: if a param is not a torch.cuda Float/Half/BFloat16
                tensor.
        """
        # Parameter groups:
        #   model_float16_groups: original float16 parameters
        #   model_fp32_groups: original fp32 parameters
        #   shard_float16_groups: shards of original float16 parameters
        #   shard_fp32_groups: shards of original fp32 parameters
        #   shard_fp32_from_float16_groups: fp32 copy of float16 parameters
        model_float16_groups = []
        model_fp32_groups = []
        shard_float16_groups = []
        shard_fp32_groups = []
        shard_fp32_from_float16_groups = []

        # Allocate (or slice) each group's param shard.
        for group_range in opt_group_ranges:

            # Params of this group.
            model_float16_params_this_group = []
            model_fp32_params_this_group = []
            shard_float16_params_this_group = []
            shard_fp32_params_this_group = []
            shard_fp32_from_float16_params_this_group = []
            model_float16_groups.append(model_float16_params_this_group)
            model_fp32_groups.append(model_fp32_params_this_group)
            shard_float16_groups.append(shard_float16_params_this_group)
            shard_fp32_groups.append(shard_fp32_params_this_group)
            shard_fp32_from_float16_groups.append(
                shard_fp32_from_float16_params_this_group)

            for model_param in group_range["params"]:

                assert model_param.requires_grad

                model_index, dtype = param_gbuf_map[model_param]
                gbuf_range = model_gbuf_ranges[model_index][dtype]
                param_range = gbuf_range["param_map"][model_param]["param"]

                # fp16, bf16 params.
                if model_param.type() in ['torch.cuda.HalfTensor',
                                          'torch.cuda.BFloat16Tensor']:

                    # Clone model -> main.
                    shard_model_param = model_param.detach().view(-1) \
                        [param_range.start:param_range.end]
                    shard_main_param = shard_model_param.clone().float()
                    mpu.copy_tensor_model_parallel_attributes(
                        shard_model_param, model_param)
                    mpu.copy_tensor_model_parallel_attributes(
                        shard_main_param, model_param)
                    if hasattr(model_param, 'shared'):
                        shard_model_param.shared = model_param.shared
                        shard_main_param.shared = model_param.shared

                    # Add to group.
                    model_float16_params_this_group.append(model_param)
                    shard_float16_params_this_group.append(shard_model_param)
                    shard_fp32_from_float16_params_this_group.append(
                        shard_main_param)

                # fp32 params.
                elif model_param.type() == 'torch.cuda.FloatTensor':
                    shard_model_param = model_param.view(-1) \
                        [param_range.start:param_range.end]
                    model_fp32_params_this_group.append(model_param)
                    shard_fp32_params_this_group.append(shard_model_param)
                    mpu.copy_tensor_model_parallel_attributes(
                        shard_model_param, model_param)
                    if hasattr(model_param, 'shared'):
                        shard_model_param.shared = model_param.shared

                else:
                    raise TypeError('Wrapped parameters must be one of '
                                    'torch.cuda.FloatTensor,  '
                                    'torch.cuda.HalfTensor, or '
                                    'torch.cuda.BFloat16Tensor. '
                                    'Received {}'.format(model_param.type()))

            # Update optimizer's params.
            group_range["orig_group"]["params"] = [
                *shard_fp32_params_this_group,
                *shard_fp32_from_float16_params_this_group,
            ]

        return (
            model_float16_groups,
            model_fp32_groups,
            shard_float16_groups,
            shard_fp32_groups,
            shard_fp32_from_float16_groups,
        )

    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                 fp16, bf16, grad_scaler, models):
        """See the class docstring for argument descriptions."""

        super().__init__(
            optimizer, clip_grad, log_num_zeros_in_grad,
            params_have_main_grad, use_contiguous_buffers_in_local_ddp,
            fp16, bf16, grad_scaler, models)

        # Verify that contiguous buffers are being used
        # - Note: this should already be checked in arguments.py
        assert use_contiguous_buffers_in_local_ddp

        # Model grad buffer ranges.
        self.model_gbuf_ranges = [self.build_model_gbuf_range_map(model)
                                  for model in self.models]
        self.model_param_gbuf_map = \
            self.build_model_param_gbuf_map(self.model_gbuf_ranges)

        # Optimizer ranges.
        self.opt_group_ranges = self.build_optimizer_group_ranges(
            self.optimizer.param_groups,
            self.model_gbuf_ranges)

        # Allocate main param shards.
        (
            self.model_float16_groups,
            self.model_fp32_groups,
            self.shard_float16_groups,
            self.shard_fp32_groups,
            self.shard_fp32_from_float16_groups,
        ) = self.build_model_and_main_param_groups(self.model_gbuf_ranges,
                                                   self.model_param_gbuf_map,
                                                   self.opt_group_ranges)

        # Update optimizer groups.
        # - Also, leverage state_dict() and load_state_dict() to
        #   recast preexisting per-param state tensors.
        self.optimizer.param_groups = \
            [ g["orig_group"] for g in self.opt_group_ranges ]
        self.optimizer.load_state_dict(self.optimizer.state_dict())

    def get_model_param_range_map(self, param):
        """
        Given a model param, get the index sub-range of the param that this
        data-parallel rank owns (the "gbuf_world"/"gbuf_local"/"param"
        range dict built by `build_model_gbuf_param_range_map`).
        """
        model_index, dtype = self.model_param_gbuf_map[param]
        gbuf_range_map = self.model_gbuf_ranges[model_index][dtype]
        param_range_map = gbuf_range_map["param_map"][param]
        return param_range_map

    def get_model_parallel_group(self):
        """
        With the distributed optimizer, the model parallel group is the
        entire world.
        """
        return None

    def state_dict(self):
        """
        The state dict must contain the fp32-from-float16 shards.

        Keys: 'optimizer' (wrapped optimizer state), 'grad_scaler' (only if
        a grad scaler is set), and 'shard_fp32_from_float16_groups'.
        """
        state_dict = {}
        state_dict['optimizer'] = self.optimizer.state_dict()
        if self.grad_scaler:
            state_dict['grad_scaler'] = self.grad_scaler.state_dict()
        state_dict['shard_fp32_from_float16_groups'] = \
            self.shard_fp32_from_float16_groups
        return state_dict

    def load_state_dict(self, state_dict):
        """Restore optimizer, grad-scaler and fp32 main-param shard state
        from a dict produced by `state_dict` (older checkpoints that use
        'optimizer_state_dict' or lack 'grad_scaler' are tolerated)."""

        # Optimizer.
        optimizer_key = 'optimizer'
        if optimizer_key not in state_dict:
            optimizer_key = 'optimizer_state_dict'
            print_rank_0('***WARNING*** loading optimizer from '
                         'an old checkpoint ...')
        self.optimizer.load_state_dict(state_dict[optimizer_key])

        # Grad scaler.
        if 'grad_scaler' not in state_dict:
            print_rank_0('***WARNING*** found an old checkpoint, will not '
                         'load grad scaler ...')
        else:
            if self.grad_scaler:
                self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
            else:
                print_rank_0('***WARNING*** fould the grad scaler in the '
                             'checkpoint but it is None in the class. '
                             'Skipping loading grad scaler ...')

        # Copy data for the main params.
        for current_group, saved_group in zip(
                self.shard_fp32_from_float16_groups,
                state_dict["shard_fp32_from_float16_groups"]):
            for current_param, saved_param in zip(current_group, saved_group):
                current_param.data.copy_(saved_param.data)

    def zero_grad(self, set_to_none=True):
        """Zero (or, with `set_to_none`, drop) the grads of the model params
        and of every param-shard group.

        Only the model params (model_float16_groups & model_fp32_groups)
        strictly need zeroing. The shard groups are cleared as well as a
        memory optimization to reduce fragmentation; with set_to_none=True
        their grad storage can be released at this point.
        """
        for groups in (
                self.model_float16_groups,
                self.model_fp32_groups,
                self.shard_float16_groups, # grad empty/unused here?
                self.shard_fp32_groups, # throws grad-access warning
                self.shard_fp32_from_float16_groups):
            for group in groups:
                _zero_grad_group_helper(group, set_to_none)

    def get_model_grad_buffer_dp_views(self):
        """Split every model grad buffer into one equal-size view per
        data-parallel rank.

        Returns:
            list of `(model_index, dtype, gbuf.data, views)` tuples, where
            `views[r]` is rank r's slice of `gbuf.data`.
        """
        data_parallel_world_size = mpu.get_data_parallel_world_size()

        # Grad buffer views.
        gbuf_view_items = []
        for model_index, model in enumerate(self.models):
            for dtype, gbuf in model._grad_buffers.items():

                # Padding (done when the buffer is built) guarantees an
                # exact split; integer division keeps the size exact.
                assert gbuf.numel_padded % data_parallel_world_size == 0
                shard_size = gbuf.numel_padded // data_parallel_world_size
                gbuf_views = [gbuf.data[(r*shard_size):((r+1)*shard_size)]
                              for r in range(data_parallel_world_size)]
                gbuf_view_items.append((model_index, dtype, gbuf.data, gbuf_views))

        return gbuf_view_items

    def reduce_model_grads(self, args, timers):
        """
        Note: this is a different order of reduction, versus the non-
        distributed optimizer, which reduces: 1) all grads, 2) embedding
        grads.

        Here embedding grads are all-reduced first; then every grad buffer
        is scaled by 1 / data_parallel_world_size and reduce-scattered, so
        each rank ends up with the averaged grads for its own slice.
        """

        # All-reduce embedding grads.
        timers('backward-embedding-all-reduce').start()
        self.allreduce_embedding_grads(args)
        timers('backward-embedding-all-reduce').stop()

        # Reduce-scatter setup.
        timers('backward-params-all-reduce').start()
        data_parallel_rank = mpu.get_data_parallel_rank()
        data_parallel_world_size = mpu.get_data_parallel_world_size()
        data_parallel_group = mpu.get_data_parallel_group()

        # Scale grad buffers by '1 / data_parallel_world_size'.
        for model in self.models:
            for gbuf in model._grad_buffers.values():
                gbuf.data /= data_parallel_world_size

        # Reduce-scatter all grads.
        gbuf_view_items = self.get_model_grad_buffer_dp_views()
        for _, _, gbuf, gbuf_views in gbuf_view_items:
            torch.distributed._reduce_scatter_base(
                gbuf_views[data_parallel_rank],
                gbuf,
                group = data_parallel_group,
            )

        timers('backward-params-all-reduce').stop()

    def gather_model_params(self, args, timers):
        """All-gather the updated param slices across data-parallel ranks and
        copy the result into the model params.

        Each rank's updated values are expected to already be staged in its
        slice of the grad buffers (see `_copy_main_params_to_model_params`).
        """

        timers('backward-params-all-gather').start()

        data_parallel_rank = mpu.get_data_parallel_rank()
        data_parallel_group = mpu.get_data_parallel_group()

        # All-gather updated main params.
        # - All grad buffer views are guaranteed to have the same num elements
        #   across all data parallel ranks, with grad buffer padding that is done
        #   in distributed.py. Thus, all sub-views will have consistent start/end
        #   indexes across data parallel ranks.
        gbuf_view_items = self.get_model_grad_buffer_dp_views()
        for _, _, gbuf, gbuf_views in gbuf_view_items:
            torch.distributed._all_gather_base(
                gbuf,
                gbuf_views[data_parallel_rank],
                group = data_parallel_group,
            )

        # Each model param now contains its updated values in its
        # '.main_grad' field.
        # NOTE(review): relies on each param's `main_grad` being a view into
        # the grad buffer that was just all-gathered — confirm in the DDP
        # wrapper.
        for model in self.models:
            for param_map in model._grad_buffer_param_index_map.values():
                for param in param_map:
                    param.detach().copy_(param.main_grad)

        timers('backward-params-all-gather').stop()

    def _collect_main_grad_data_for_unscaling(self):
        """Return the `.grad.data` of every param currently held by the
        wrapped optimizer (i.e., this rank's main-param shards)."""
        return [
            param.grad.data
            for group in self.optimizer.param_groups
            for param in group["params"]
        ]

    def _get_model_and_main_params_data_float16(self):
        """Return parallel lists of `.data` tensors: the float16 model-param
        shards and their fp32 main-param copies."""
        model_data = []
        main_data = []
        for model_group, main_group in zip(self.shard_float16_groups,
                                           self.shard_fp32_from_float16_groups):
            for model_param, main_param in zip(model_group, main_group):
                model_data.append(model_param.data)
                main_data.append(main_param.data)
        return model_data, main_data

    def _copy_model_grads_to_main_grads(self):
        """Point each main-param shard's `.grad` at the matching slice of its
        model param's `main_grad` (cast to fp32)."""

        def copy_group_grads(model_groups, shard_main_groups):
            for model_group, shard_main_group in zip(model_groups,
                                                     shard_main_groups):
                for model_param, shard_main_param in zip(model_group,
                                                         shard_main_group):

                    param_range_map = self.get_model_param_range_map(model_param)
                    param_range = param_range_map["param"]
                    assert param_range.size == shard_main_param.nelement()

                    model_grad = model_param.main_grad
                    shard_model_grad = model_grad.view(-1) \
                        [param_range.start:param_range.end]
                    shard_main_param.grad = shard_model_grad.float()

        copy_group_grads(self.model_float16_groups,
                         self.shard_fp32_from_float16_groups)
        copy_group_grads(self.model_fp32_groups,
                         self.shard_fp32_groups)

    def _copy_main_params_to_model_params(self):
        """Write each updated main-param shard into the matching slice of its
        model param's `main_grad`.

        The grad storage is reused as a staging area: `gather_model_params`
        all-gathers the grad buffers and then copies `main_grad` into the
        model params.
        """

        def copy_group_params(shard_main_groups, model_groups):
            for shard_main_group, model_group in zip(shard_main_groups,
                                                     model_groups):
                for shard_main_param, model_param in zip(shard_main_group,
                                                         model_group):

                    param_range_map = self.get_model_param_range_map(model_param)
                    param_range = param_range_map["param"]
                    assert param_range.size == shard_main_param.nelement()

                    staging_buffer = model_param.main_grad
                    shard_staging_buffer = staging_buffer.view(-1) \
                        [param_range.start:param_range.end]
                    shard_staging_buffer.data.copy_(shard_main_param)

        copy_group_params(self.shard_fp32_from_float16_groups,
                          self.model_float16_groups)
        copy_group_params(self.shard_fp32_groups,
                          self.model_fp32_groups)