distrib_optimizer.py 22.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Megatron distributed optimizer."""


import math
20
import torch
21
22

from megatron import get_args
23
24
from megatron import get_timers
from megatron import mpu
25
26
from megatron.model.module import param_is_not_shared
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
27
28

from .optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper
29

Lawrence McAfee's avatar
Lawrence McAfee committed
30

31
class Range:
    '''Half-open index interval [start, end) used to address a shard
    within a flattened tensor. `size` is kept alongside the end points.
    '''

    def __init__(self, start, end):
        self.start, self.end = start, end
        self.size = end - start

    def normalize(self, start = 0):
        '''Return a new Range with the same size that begins at `start`.'''
        shifted_end = start + self.size
        return Range(start, shifted_end)

    def __str__(self):
        fields = (self.start, self.end, self.size)
        return "%d,%d [%d]" % fields


45
class DistributedOptimizer(MixedPrecisionOptimizer):
Lawrence McAfee's avatar
Lawrence McAfee committed
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
    '''Distributed optimizer, for all data types (fp16, bf16, and fp32).

    Arguments:
        optimizer: base optimizer such as Adam or SGD
        clip_grad: clip gradients with this global L2 norm. Note
            that clipping is ignored if clip_grad == 0
        log_num_zeros_in_grad: return number of zeros in the gradients.
        params_have_main_grad: flag indicating if parameters have
            a `main_grad` field. If this is set, we are assuming
            that the model parameters are stored in the `main_grad`
            field instead of the typical `grad` field. This happens
            for the DDP cases where there is a continuous buffer
            holding the gradients. For example for bfloat16, we want
            to do gradient accumulation and all-reduces in float32
            and as a result we store those gradients in the main_grad.
            Note that main grad is not necessarily in float32.
        use_contiguous_buffers_in_local_ddp: if true, the local DDP model
            is using a contiguous buffer to hold the model grads.
        fp16: if true, the model is running in fp16.
        bf16: if true, the model is running in bfloat16.
        grad_scaler: used for scaling gradients. Note that this can be
            None. This case happens when `bf16 = True` and we don't
            use any loss scale. Note that for `bf16 = True`, we can have
            a constant gradient scaler. Also for `bf16 = False`, we
            always require a grad scaler.
        models: list of models (i.e., the virtual pipelining models). This
            is used by the distributed optimizer for mapping parameters.
    '''
74
75

    @classmethod
    def build_model_gbuf_param_range_map(cls, model, dtype, gbuf_world_range):
        '''Map each param that overlaps `gbuf_world_range` to its ranges.

        Reads the (start, end) indexes of every param from
        `model._grad_buffer_param_index_map[dtype]` and clips them to the
        given world range. Params with no overlap are left out. Each kept
        param maps to a dict of three Range objects:
            "gbuf_world": the overlap, in grad-buffer (world) indexes.
            "gbuf_local": the overlap, relative to the start of
                `gbuf_world_range`.
            "param": the overlap, relative to the start of the param.
        '''

        shard_start = gbuf_world_range.start
        shard_size = gbuf_world_range.size

        param_range_map = {}
        index_map = model._grad_buffer_param_index_map[dtype]
        for param, (world_start, world_end) in index_map.items():

            # Clip the param's span to this shard, in shard-local indexes.
            local_start = max(0, world_start - shard_start)
            local_end = min(shard_size, world_end - shard_start)
            if local_end <= local_start:
                # Param lies entirely outside of this shard.
                continue

            local_range = Range(local_start, local_end)
            param_range_map[param] = {
                "gbuf_world" : local_range.normalize(local_start + shard_start),
                "gbuf_local" : local_range,
                # Non-zero start only when the shard begins inside the param.
                "param" : local_range.normalize(max(0, shard_start - world_start)),
            }

        return param_range_map
106

Lawrence McAfee's avatar
Lawrence McAfee committed
107

108
    @classmethod
    def build_model_gbuf_range(cls, model, dtype):
        '''Build the grad-buffer range info of this data-parallel rank.

        The buffer `model._grad_buffers[dtype]` is split into one Range per
        data-parallel rank, each at most ceil(numel / world_size) long and
        capped at the buffer's `numel`.

        Returns a dict with:
            "local": this rank's range, normalized to start at 0.
            "world": this rank's range, in grad-buffer indexes.
            "world_all": the ranges of all data-parallel ranks.
            "param_map": result of build_model_gbuf_param_range_map()
                for this rank's world range.
            "max_range_size": the per-rank range size upper bound.
        '''

        dp_rank = mpu.get_data_parallel_rank()
        dp_world_size = mpu.get_data_parallel_world_size()

        # Per-rank range size, rounded up.
        gbuf_size = model._grad_buffers[dtype].numel
        max_gbuf_range_size = int(math.ceil(gbuf_size / dp_world_size))

        def world_range_of(rank):
            # Range owned by `rank`, capped at the end of the buffer.
            first = rank * max_gbuf_range_size
            return Range(first, min(gbuf_size, first + max_gbuf_range_size))

        # Ranges of every data parallel rank, then the one owned locally.
        gbuf_world_all_ranges = [world_range_of(r) for r in range(dp_world_size)]
        gbuf_world_range = gbuf_world_all_ranges[dp_rank]

        return {
            "local" : gbuf_world_range.normalize(),
            "world" : gbuf_world_range,
            "world_all" : gbuf_world_all_ranges,
            "param_map" : cls.build_model_gbuf_param_range_map(
                model, dtype, gbuf_world_range),
            "max_range_size" : max_gbuf_range_size,
        }

Lawrence McAfee's avatar
Lawrence McAfee committed
147

148
    @classmethod
149
    def build_model_gbuf_range_map(cls, model):
150
        return {
151
            dtype : cls.build_model_gbuf_range(model, dtype)
152
153
154
            for dtype in model._grad_buffers
        }

Lawrence McAfee's avatar
Lawrence McAfee committed
155

156
    @classmethod
157
158
    def build_model_param_gbuf_map(cls, model_gbuf_ranges):
        '''Create a reverse of the model_gbuf_ranges, for referencing in
159
        opposite direction.'''
160
        param_gbuf_map = {}
161
162
163
        for model_index, model_gbuf_range_map in enumerate(model_gbuf_ranges):
            for dtype, gbuf_range_map in model_gbuf_range_map.items():
                for param, param_range_map in gbuf_range_map["param_map"].items():
164
165
166
                    param_gbuf_map[param] = (model_index, dtype)
        return param_gbuf_map

Lawrence McAfee's avatar
Lawrence McAfee committed
167

168
    @classmethod
169
    def build_optimizer_group_ranges(cls, param_groups, model_gbuf_ranges):
170
171
172
173
174
175
176
177
178
179

        num_groups = len(param_groups)

        # Param group map.
        param_group_map = {}
        for group_index, group in enumerate(param_groups):
            for param in group["params"]:
                assert param.requires_grad
                param_group_map[param] = group_index

180
181
182
183
184
        # Optimizer group ranges.
        group_ranges = [ {"params": []} for _ in param_groups ]
        for model_gbuf_range_map in model_gbuf_ranges:
            for dtype, gbuf_range_map in model_gbuf_range_map.items():
                for param in gbuf_range_map["param_map"]:
185
                    group_index = param_group_map[param]
186
187
                    group_range = group_ranges[group_index]
                    group_range["params"].append(param)
188

189
190
191
192
        # Squeeze zero-size group ranges.
        for group_index, group_range in enumerate(group_ranges):
            group_range["orig_group"] = param_groups[group_index]
        group_ranges = [ g for g in group_ranges if len(g["params"]) > 0 ]
193

194
        return group_ranges
195

196

197
    @classmethod
198
199
200
201
202
    def build_model_and_main_param_groups(cls,
                                        model_gbuf_ranges,
                                        param_gbuf_map,
                                        opt_group_ranges):

Lawrence McAfee's avatar
Lawrence McAfee committed
203
204
205
206
207
208
209
210
        # Parameter groups:
        #   model_float16_groups: original float16 parameters
        #   model_fp32_groups: original fp32 parameters
        #   shard_float16_groups: shards of original float16 parameters
        #   shard_fp32_groups: shards of original fp32 parameters
        #   shard_fp32_from_float16_groups: fp32 copy of float16 parameters
        model_float16_groups = []
        model_fp32_groups = []
211
212
213
214
        shard_float16_groups = []
        shard_fp32_groups = []
        shard_fp32_from_float16_groups = []

Lawrence McAfee's avatar
Lawrence McAfee committed
215
        # Allocate (or slice) each group's param shard.
216
217
218
        for group_index, group_range in enumerate(opt_group_ranges):

            # Params of this group.
Lawrence McAfee's avatar
Lawrence McAfee committed
219
220
            model_float16_params_this_group = []
            model_fp32_params_this_group = []
221
222
223
            shard_float16_params_this_group = []
            shard_fp32_params_this_group = []
            shard_fp32_from_float16_params_this_group = []
Lawrence McAfee's avatar
Lawrence McAfee committed
224
225
            model_float16_groups.append(model_float16_params_this_group)
            model_fp32_groups.append(model_fp32_params_this_group)
226
227
228
229
230
231
232
            shard_float16_groups.append(shard_float16_params_this_group)
            shard_fp32_groups.append(shard_fp32_params_this_group)
            shard_fp32_from_float16_groups.append(
                shard_fp32_from_float16_params_this_group)

            for model_param in group_range["params"]:

233
234
                assert model_param.requires_grad

235
236
237
                model_index, dtype = param_gbuf_map[model_param]
                gbuf_range = model_gbuf_ranges[model_index][dtype]
                param_range = gbuf_range["param_map"][model_param]["param"]
238
239

                # fp16, bf16 params.
240
241
242
243
                if model_param.type() in ['torch.cuda.HalfTensor',
                                          'torch.cuda.BFloat16Tensor']:

                    # Clone model -> main.
Lawrence McAfee's avatar
Lawrence McAfee committed
244
245
                    shard_model_param = model_param.detach().view(-1) \
                        [param_range.start:param_range.end]
246
247
248
249
250
251
252
253
254
255
                    shard_main_param = shard_model_param.clone().float()
                    mpu.copy_tensor_model_parallel_attributes(
                        shard_model_param, model_param)
                    mpu.copy_tensor_model_parallel_attributes(
                        shard_main_param, model_param)
                    if hasattr(model_param, 'shared'):
                        shard_model_param.shared = model_param.shared
                        shard_main_param.shared = model_param.shared

                    # Add to group.
Lawrence McAfee's avatar
Lawrence McAfee committed
256
                    model_float16_params_this_group.append(model_param)
257
258
                    shard_float16_params_this_group.append(shard_model_param)
                    shard_fp32_from_float16_params_this_group.append(shard_main_param)
259
260

                # fp32 params.
261
                elif model_param.type() == 'torch.cuda.FloatTensor':
Lawrence McAfee's avatar
Lawrence McAfee committed
262
263
                    shard_model_param = model_param.view(-1) \
                        [param_range.start:param_range.end]
Lawrence McAfee's avatar
Lawrence McAfee committed
264
                    model_fp32_params_this_group.append(model_param)
265
                    shard_fp32_params_this_group.append(shard_model_param)
266
267
                    mpu.copy_tensor_model_parallel_attributes(
                        shard_model_param, model_param)
268
269
                    if hasattr(model_param, 'shared'):
                        shard_model_param.shared = model_param.shared
270
271
272
273
274
275
276
277

                else:
                    raise TypeError('Wrapped parameters must be one of '
                                    'torch.cuda.FloatTensor,  '
                                    'torch.cuda.HalfTensor, or '
                                    'torch.cuda.BFloat16Tensor. '
                                    'Received {}'.format(param.type()))

Lawrence McAfee's avatar
Lawrence McAfee committed
278
            # Update optimizer's params.
279
280
281
282
283
284
            group_range["orig_group"]["params"] = [
                *shard_fp32_params_this_group,
                *shard_fp32_from_float16_params_this_group,
            ]

        return (
Lawrence McAfee's avatar
Lawrence McAfee committed
285
286
            model_float16_groups,
            model_fp32_groups,
287
288
289
290
            shard_float16_groups,
            shard_fp32_groups,
            shard_fp32_from_float16_groups,
        )
291

Lawrence McAfee's avatar
Lawrence McAfee committed
292

293
294
    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                 fp16, bf16, grad_scaler, models):
        '''Set up the range maps and param shards, then point the wrapped
        optimizer at the shards. See the class docstring for the arguments.
        '''

        super().__init__(
            optimizer, clip_grad, log_num_zeros_in_grad,
            params_have_main_grad, use_contiguous_buffers_in_local_ddp,
            fp16, bf16, grad_scaler, models)

        # Contiguous grad buffers are required.
        # - Note: this should already be checked in arguments.py
        assert use_contiguous_buffers_in_local_ddp

        # Grad buffer ranges of every model, plus the reverse
        # (param -> model index, dtype) lookup.
        self.model_gbuf_ranges = [
            self.build_model_gbuf_range_map(model) for model in self.models
        ]
        self.model_param_gbuf_map = self.build_model_param_gbuf_map(
            self.model_gbuf_ranges)

        # Params of each optimizer group that are found in the ranges above.
        self.opt_group_ranges = self.build_optimizer_group_ranges(
            self.optimizer.param_groups, self.model_gbuf_ranges)

        # Allocate the model / shard param groups.
        param_groups = self.build_model_and_main_param_groups(
            self.model_gbuf_ranges,
            self.model_param_gbuf_map,
            self.opt_group_ranges)
        (
            self.model_float16_groups,
            self.model_fp32_groups,
            self.shard_float16_groups,
            self.shard_fp32_groups,
            self.shard_fp32_from_float16_groups,
        ) = param_groups

        # Keep only the optimizer groups that received params, then
        # round-trip the state through state_dict() / load_state_dict()
        # to recast preexisting per-param state tensors.
        self.optimizer.param_groups = [
            group_range["orig_group"] for group_range in self.opt_group_ranges
        ]
        self.optimizer.load_state_dict(self.optimizer.state_dict())

336

337
    def get_model_param_range_map(self, param):
        '''
        Look up the range map stored for a model param, i.e. the entry of
        the "param_map" belonging to the param's (model index, dtype).
        '''
        model_index, dtype = self.model_param_gbuf_map[param]
        param_map = self.model_gbuf_ranges[model_index][dtype]["param_map"]
        return param_map[param]


348
349
350
    def get_model_parallel_group(self):
        '''This optimizer reports no model parallel group: always None.'''
        return None

351
352

    def state_dict(self):
        '''Return a dict with the wrapped optimizer's state, the grad
        scaler's state (only when a grad scaler is set), and the list of
        fp32-from-float16 shard groups.'''
        checkpoint = {'optimizer': self.optimizer.state_dict()}
        scaler = self.grad_scaler
        if scaler:
            checkpoint['grad_scaler'] = scaler.state_dict()
        checkpoint['shard_fp32_from_float16_groups'] = \
            self.shard_fp32_from_float16_groups
        return checkpoint


362
    def load_state_dict(self, state_dict):
        '''Restore the optimizer from a dict produced by state_dict().

        Arguments:
            state_dict: dict holding the wrapped optimizer's state under
                'optimizer' (or, for old checkpoints, under
                'optimizer_state_dict'), optionally a 'grad_scaler' entry,
                and the saved 'shard_fp32_from_float16_groups'.

        A warning is printed when the old optimizer key is used, when no
        grad scaler is found in `state_dict`, or when one is found but this
        optimizer has no grad scaler. The saved main-param shards are
        copied in place into the current ones.
        '''

        def warn(message):
            # print_rank_0 is not imported at the top of this file, so it
            # is brought into scope here, only when a warning is emitted.
            from megatron import print_rank_0
            print_rank_0(message)

        # Optimizer.
        optimizer_key = 'optimizer'
        if optimizer_key not in state_dict:
            optimizer_key = 'optimizer_state_dict'
            warn('***WARNING*** loading optimizer from '
                 'an old checkpoint ...')
        self.optimizer.load_state_dict(state_dict[optimizer_key])

        # Grad scaler.
        if 'grad_scaler' not in state_dict:
            warn('***WARNING*** found an old checkpoint, will not '
                 'load grad scaler ...')
        elif self.grad_scaler:
            self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
        else:
            warn('***WARNING*** fould the grad scaler in the '
                 'checkpoint but it is None in the class. '
                 'Skipping loading grad scaler ...')

        # Copy data for the main params.
        for current_group, saved_group in zip(
                self.shard_fp32_from_float16_groups,
                state_dict["shard_fp32_from_float16_groups"]):
            for current_param, saved_param in zip(current_group, saved_group):
                current_param.data.copy_(saved_param.data)
Lawrence McAfee's avatar
Lawrence McAfee committed
390

391

392
393
394
395
396
397
398
    def zero_grad(self, set_to_none=True):
        """Clear the grads of every param group tracked by this optimizer.

        Each group of the model param groups and of the shard param groups
        is passed to _zero_grad_group_helper together with `set_to_none`.
        The shard groups are included as well, so that with
        set_to_none==True the space held by their grads can be released."""
        all_groups = (
            self.model_float16_groups,
            self.model_fp32_groups,
            self.shard_float16_groups, # grad empty/unused here?
            self.shard_fp32_groups, # throws grad-access warning
            self.shard_fp32_from_float16_groups,
        )
        for groups in all_groups:
            for param_group in groups:
                _zero_grad_group_helper(param_group, set_to_none)
406

407

408
409
410
411
412
413
414
415
416
417
418
419
420
    def get_model_grad_buffer_dp_views(self):
        '''Split every grad buffer into one equal-size view per
        data-parallel rank.

        Returns a list with one (model_index, dtype, gbuf.data, views)
        tuple per grad buffer, where `views` holds the consecutive slices
        of `gbuf.data`, indexed by data-parallel rank. The padded buffer
        size must be divisible by the data-parallel world size.
        '''

        world_size = mpu.get_data_parallel_world_size()

        view_items = []
        for model_index, model in enumerate(self.models):
            for dtype, gbuf in model._grad_buffers.items():

                assert gbuf.numel_padded % world_size == 0
                shard_size = int(gbuf.numel_padded / world_size)

                # Consecutive, non-overlapping slices of the buffer.
                views = []
                for rank in range(world_size):
                    offset = rank * shard_size
                    views.append(gbuf.data[offset:(offset + shard_size)])

                view_items.append((model_index, dtype, gbuf.data, views))

        return view_items
424

Lawrence McAfee's avatar
Lawrence McAfee committed
425

426
    def reduce_model_grads(self, args, timers):
        '''All-reduce the embedding grads, then reduce-scatter the grad
        buffers across the data-parallel group.

        Note: this is a different order of reduction, versus the non-
           distributed optimizer, which reduces: 1) all grads, 2) embedding
           grads.
        '''

        # Embedding grads first.
        embedding_timer = 'backward-embedding-all-reduce'
        timers(embedding_timer).start()
        self.allreduce_embedding_grads(args)
        timers(embedding_timer).stop()

        params_timer = 'backward-params-all-reduce'
        timers(params_timer).start()
        dp_rank = mpu.get_data_parallel_rank()
        dp_world_size = mpu.get_data_parallel_world_size()
        dp_group = mpu.get_data_parallel_group()

        # Divide the grad buffers by the data parallel world size before
        # they are reduced.
        for model in self.models:
            for gbuf in model._grad_buffers.values():
                gbuf.data /= dp_world_size

        # Reduce-scatter each buffer into this rank's own view of it.
        for _, _, gbuf, gbuf_views in self.get_model_grad_buffer_dp_views():
            torch.distributed._reduce_scatter_base(
                gbuf_views[dp_rank],
                gbuf,
                group = dp_group,
            )

        timers(params_timer).stop()
458

Lawrence McAfee's avatar
Lawrence McAfee committed
459

460
    def gather_model_params(self, args, timers):
        '''All-gather this rank's view of every grad buffer into the full
        buffer, then copy each param's `main_grad` into the param.'''

        timer_name = 'backward-params-all-gather'
        timers(timer_name).start()

        dp_rank = mpu.get_data_parallel_rank()
        dp_group = mpu.get_data_parallel_group()

        # All-gather updated main params.
        # - All grad buffer views are guaranteed to have the same num elements
        #   across all data parallel ranks, with grad buffer padding that is done
        #   in distributed.py. Thus, all sub-views will have consistent start/end
        #   indexes across data parallel ranks.
        for _, _, gbuf, gbuf_views in self.get_model_grad_buffer_dp_views():
            torch.distributed._all_gather_base(
                gbuf,
                gbuf_views[dp_rank],
                group = dp_group,
            )

        # Each model param now contains its updated values in its
        # '.main_grad' field.
        for model in self.models:
            for param_map in model._grad_buffer_param_index_map.values():
                for param in param_map:
                    param.detach().copy_(param.main_grad)

        timers(timer_name).stop()
488

Lawrence McAfee's avatar
Lawrence McAfee committed
489

490
    def _collect_main_grad_data_for_unscaling(self):
Lawrence McAfee's avatar
Lawrence McAfee committed
491
        return [
Lawrence McAfee's avatar
Lawrence McAfee committed
492
493
494
495
            param.grad.data
            for group in self.optimizer.param_groups
            for param in group["params"]
        ]
496
497


Lawrence McAfee's avatar
Lawrence McAfee committed
498
499
500
501
502
503
504
505
506
    def _get_model_and_main_params_data_float16(self):
        model_data = []
        main_data = []
        for model_group, main_group in zip(self.shard_float16_groups,
                                           self.shard_fp32_from_float16_groups):
            for model_param, main_param in zip(model_group, main_group):
                model_data.append(model_param.data)
                main_data.append(main_param.data)
        return model_data, main_data
507
508


509
    def _copy_model_grads_to_main_grads(self):
510

Lawrence McAfee's avatar
Lawrence McAfee committed
511
512
513
514
515
        def copy_group_grads(model_groups, shard_main_groups):
            for model_group, shard_main_group in zip(model_groups,
                                                     shard_main_groups):
                for model_param, shard_main_param in zip(model_group,
                                                         shard_main_group):
Lawrence McAfee's avatar
Lawrence McAfee committed
516

Lawrence McAfee's avatar
Lawrence McAfee committed
517
                    param_range_map = self.get_model_param_range_map(model_param)
Lawrence McAfee's avatar
Lawrence McAfee committed
518
                    param_range = param_range_map["param"]
Lawrence McAfee's avatar
Lawrence McAfee committed
519
520
                    assert param_range.size == shard_main_param.nelement()

Lawrence McAfee's avatar
Lawrence McAfee committed
521
522
                    model_grad = model_param.main_grad
                    shard_model_grad = model_grad.view(-1) \
Lawrence McAfee's avatar
Lawrence McAfee committed
523
                        [param_range.start:param_range.end]
Lawrence McAfee's avatar
Lawrence McAfee committed
524
525
                    shard_main_param.grad = shard_model_grad.float()

Lawrence McAfee's avatar
Lawrence McAfee committed
526
        copy_group_grads(self.model_float16_groups,
Lawrence McAfee's avatar
Lawrence McAfee committed
527
                         self.shard_fp32_from_float16_groups)
Lawrence McAfee's avatar
Lawrence McAfee committed
528
        copy_group_grads(self.model_fp32_groups,
Lawrence McAfee's avatar
Lawrence McAfee committed
529
                         self.shard_fp32_groups)
530

531
532
533

    def _copy_main_params_to_model_params(self):

Lawrence McAfee's avatar
Lawrence McAfee committed
534
535
536
537
538
        def copy_group_params(shard_main_groups, model_groups):
            for shard_main_group, model_group in zip(shard_main_groups,
                                                     model_groups):
                for shard_main_param, model_param in zip(shard_main_group,
                                                         model_group):
Lawrence McAfee's avatar
Lawrence McAfee committed
539

Lawrence McAfee's avatar
Lawrence McAfee committed
540
                    param_range_map = self.get_model_param_range_map(model_param)
Lawrence McAfee's avatar
Lawrence McAfee committed
541
542
543
                    param_range = param_range_map["param"]
                    assert param_range.size == shard_main_param.nelement()

Lawrence McAfee's avatar
Lawrence McAfee committed
544
545
                    model_grad = model_param.main_grad
                    shard_model_grad = model_grad.view(-1) \
Lawrence McAfee's avatar
Lawrence McAfee committed
546
                        [param_range.start:param_range.end]
547

548
                    shard_model_grad.data.copy_(shard_main_param)
Lawrence McAfee's avatar
Lawrence McAfee committed
549
550

        copy_group_params(self.shard_fp32_from_float16_groups,
Lawrence McAfee's avatar
Lawrence McAfee committed
551
                          self.model_float16_groups)
Lawrence McAfee's avatar
Lawrence McAfee committed
552
        copy_group_params(self.shard_fp32_groups,
Lawrence McAfee's avatar
Lawrence McAfee committed
553
                          self.model_fp32_groups)