r"""
The adaptor to seamlessly enable FastMoE in Megatron-LM v2.0 with at most two
lines of modification.
See `examples/megatron` for usage instructions.
"""
import os
import sys
import math
import random
from collections import OrderedDict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .transformer import FMoETransformerMLP
from .distributed import DistributedGroupedDataParallel
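
# A minimal usage sketch (illustrative; the import path `fmoe.megatron` and the
# argument values are assumptions, see `examples/megatron` for the
# authoritative instructions): in a Megatron-LM v2.0 pretraining script, the
# two modifications are roughly
#
#     from fmoe.megatron import fmoefy, DistributedDataParallel
#     model = fmoefy(model, num_experts=<experts per worker>)
#     model = DistributedDataParallel(model)
#
# where `fmoefy` swaps the MLP layers for MoE layers and the wrapper defined
# below replaces Megatron's own DDP module.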


class _FakeMegatronMLP(nn.Module):
    r"""
    A fake mlp without model parallelism for correctness testing
    """

    def __init__(self, args, _):
        super().__init__()
        self.fc1 = nn.Linear(args.hidden_size, args.hidden_hidden_size)
        self.fc2 = nn.Linear(args.hidden_hidden_size, args.hidden_size)

    def forward(self, x):
        r"""
        Directly use GeLU
        """
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.fc2(x)
        return x, torch.zeros_like(x)


def _megatron_init_method(self, rng, sigma):
    r"""
    Init method based on N(0, sigma).
    Copied from Megatron-LM
    """
    device = self.weight.device
    dtype = self.weight.dtype
    weight = rng.normal(loc=0.0, scale=sigma, size=tuple(self.weight.size()))
    self.weight.data = torch.tensor(weight, dtype=dtype, device=device)

    if self.bias is not None:
        # Always initialize bias to zero.
        with torch.no_grad():
            self.bias.zero_()


def _random_init_weight(self, rng):
    r"""
    Copied from torch.nn.init.kaiming_uniform_
    """
    fan = nn.init._calculate_correct_fan(self.weight[0], "fan_in")
    gain = nn.init.calculate_gain("leaky_relu", math.sqrt(5))
    std = gain / math.sqrt(fan)
    bound = math.sqrt(3.0) * std
    device = self.weight.device
    dtype = self.weight.dtype
    weight = rng.uniform(-bound, bound, size=tuple(self.weight.size()))
    self.weight.data = torch.tensor(weight, dtype=dtype, device=device)

    if self.bias is not None:
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[0])
        bound = 1 / math.sqrt(fan_in)
        bias = rng.uniform(-bound, bound, size=tuple(self.bias.size()))
        self.bias.data = torch.tensor(bias, dtype=dtype, device=device)


class MegatronMLP(FMoETransformerMLP):
    r"""
    An FMoETransformerMLP layer that distributes experts across the
    communication group `group`, replacing the original MLP layer in Megatron.
    """

    def __init__(self, args, group):
        assert (
            args.seq_length * args.micro_batch_size % args.tensor_model_parallel_size
            == 0
        ), "Batch size x sequence length should be a multiple of mp size"
        if not args.distributed_experts:
            world_size = 1
        else:
            world_size = args.world_size
        super().__init__(
            args.num_experts,
            top_k=args.top_k,
            d_model=args.hidden_size,
            d_hidden=args.hidden_hidden_size,
            world_size=world_size,
            mp_group=group,
            expert_dp_comm="none" if args.distributed_experts else "dp",
        )
        self.hidden_size = args.hidden_size
        if args.distributed_experts:
            self.rank = args.rank
        else:
            self.rank = 0
        self.sigma = args.init_method_std
        self.num_layers = args.num_layers
        self.reset_parameters()

    def reset_parameters(self):
        r"""
        Initialize the weights in the same way as Megatron's linear layers.
        Because Megatron fixes its random seed, an additional numpy rng,
        offset by the expert rank, is used here so that experts on different
        ranks are initialized differently.
        """
        rng = np.random.default_rng(np.random.randint(2048) + self.rank)
        _megatron_init_method(self.experts.htoh4, rng, self.sigma)
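        # The second linear layer follows Megatron's scaled initialization:
        # std = sigma / sqrt(2 * num_layers).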
        std = self.sigma / math.sqrt(2.0 * self.num_layers)
        _megatron_init_method(self.experts.h4toh, rng, std)

    def forward(self, inp):
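        # Megatron's MLP returns an (output, bias) pair; the experts carry
        # their own biases, so a zero bias is returned to keep the interface.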
        return (
            super().forward(inp),
            torch.zeros(self.hidden_size, dtype=inp.dtype, device=inp.device),
        )


def fmoefy(
    model,
    num_experts=None,
    distributed_experts=True,
    hidden_hidden_size=None,
    top_k=None,
):
    r"""
    Replace the MLP layers in a Megatron transformer-based model with MoE
    layers (a commented usage sketch follows this function).
    * `model` should be a standard Megatron model that has
    `model.language_model.transformer.layers` as its transformer layers, i.e.
    an array of transformer blocks, each containing an `mlp` member.
    * `distributed_experts` is set to True if different experts are located on
    different workers. Otherwise, the experts on all workers are identical and
    are trained in data-parallel mode. This can be useful when testing on
    small models that do not require high training throughput or large
    parameter capacity.
    Note that pipeline parallelism is not supported yet. When distributed
    experts are enabled, their communicator should be Megatron's
    tensor_model_parallel_comm x data_parallel_comm, which is not created.
    """
    from megatron import get_args
    from megatron import mpu

    args = get_args()
    if num_experts is not None:
        args.num_experts = num_experts
    assert (
        "num_experts" in args
    ), "num_experts should be specified in arguments or fmoefy function"

    if hidden_hidden_size is not None:
        args.hidden_hidden_size = hidden_hidden_size
    elif not hasattr(args, "hidden_hidden_size"):
        args.hidden_hidden_size = args.hidden_size * 4

    if top_k is not None:
        args.top_k = top_k
    elif not hasattr(args, "top_k"):
        args.top_k = 2

    # Set distributed_experts to None to use default setting in args
    if distributed_experts is not None:
        args.distributed_experts = distributed_experts

    for l in model.language_model.transformer.layers:
        l.mlp = MegatronMLP(args, mpu.get_model_parallel_group())
    return model
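
# An illustrative call for debugging (argument values are assumptions): with
# `distributed_experts=False`, every worker holds an identical copy of the
# experts and trains them in data-parallel mode, e.g.
#
#     model = fmoefy(model, num_experts=4, distributed_experts=False, top_k=2)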


class DistributedDataParallel(DistributedGroupedDataParallel):
    r"""
    A wrapper that replaces the DDP module provided by Megatron, adapted to
    enable the sophisticated parallel and reduction strategies in FastMoE.
    """

    def __init__(self, module):
        from megatron import mpu

        super().__init__(
            module,
            mp_group=mpu.get_model_parallel_group(),
            dp_group=mpu.get_data_parallel_group(),
        )

    def state_dict(self, *args, **kwargs):
        r"""
        Keep consistency with Megatron
        """
        return self.module.state_dict(*args, **kwargs)

    def state_dict_for_save_checkpoint(self, *args, **kwargs):
        r"""
        Keep consistency with Megatron
        """
        return self.module.state_dict_for_save_checkpoint(*args, **kwargs)

    def load_state_dict(self, *args, **kwargs):
        r"""
        Keep consistency with Megatron
        """
        return self.module.load_state_dict(*args, **kwargs)

def get_fmoe_checkpoint_name(checkpoints_path, iteration,
                        release=False, data_parallel_rank=-1):
    """A unified checkpoint name that allows specifying a data parallel rank"""
    from megatron import mpu
    from megatron.checkpointing import get_checkpoint_name
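    # For example (illustrative values): with a single pipeline stage,
    # tensor-parallel rank 0 and data-parallel rank 3 at iteration 1000, this
    # yields <checkpoints_path>/iter_0001000/mp_rank_00_dp_rank_0003/model_optim_rng.pt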
    if data_parallel_rank == -1:
        data_parallel_rank = mpu.get_data_parallel_rank()
    if data_parallel_rank == 0:
        return get_checkpoint_name(checkpoints_path, iteration, release)

    if release:
        directory = 'release'
    else:
        directory = 'iter_{:07d}'.format(iteration)
    # Use both the tensor and pipeline MP rank.
    if mpu.get_pipeline_model_parallel_world_size() == 1:
        return os.path.join(checkpoints_path, directory,
                            'mp_rank_{:02d}_dp_rank_{:04d}'.format(
                                mpu.get_tensor_model_parallel_rank(),
                                data_parallel_rank
                                ),
                            'model_optim_rng.pt')
    return os.path.join(checkpoints_path, directory,
                        'mp_rank_{:02d}_{:03d}_dp_rank_{:04d}'.format(
                            mpu.get_tensor_model_parallel_rank(),
                            mpu.get_pipeline_model_parallel_rank(),
                            data_parallel_rank
                            ),
                        'model_optim_rng.pt')

def save_checkpoint(iteration, model, optimizer, lr_scheduler, expert_dp_comm='none'):
    """Save a model checkpoint with expert parallelism."""
    # TODO: update patch
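    # Scheme (as implemented below): data parallel rank 0 writes a full native
    # Megatron checkpoint, while every other data parallel rank writes only the
    # parameters and optimizer states whose `dp_comm` attribute equals
    # `expert_dp_comm`, i.e. the expert parameters that differ across ranks.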
    from megatron import get_args
    from megatron import mpu

    if mpu.get_data_parallel_rank() == 0:
        # at dp rank 0, we still follow the native save_checkpoint provided by Megatron
        from megatron.checkpointing import save_checkpoint as save_checkpoint_native
        save_checkpoint_native(iteration, model, optimizer, lr_scheduler)
        return

    args = get_args()

    # Only rank zero of the data parallel writes to the disk.
    if isinstance(model, DistributedDataParallel):
        model = model.module

    if torch.distributed.get_rank() == 0:
        print('saving checkpoint at iteration {:7d} to {}'.format(
            iteration, args.save), flush=True)

    # Arguments, iteration, and model.
    state_dict = {}
    state_dict['model'] = model.state_dict_for_save_checkpoint(
        keep_vars=(mpu.get_data_parallel_rank() > 0))

    def extract_expert_param(state_dict, expert_dp_comm='none'):
        state_dict_new = state_dict.__class__()
        for k, v in state_dict.items():
            # megatron uses both dict and OrderedDict in its state_dict
            if isinstance(v, (OrderedDict, dict)):
                v_new = extract_expert_param(v, expert_dp_comm)
                if len(v_new) > 0:
                    state_dict_new[k] = v_new
            elif hasattr(v, 'dp_comm') and v.dp_comm == expert_dp_comm:
                state_dict_new[k] = v.detach()
        return state_dict_new

    state_dict['model'] = extract_expert_param(
                state_dict['model'],
                expert_dp_comm)

    # Optimizer stuff.
    if not args.no_save_optim:
        if optimizer is not None:
            state_dict['optimizer'] = optimizer.state_dict()
            index = 0
            for param_group in optimizer.optimizer.param_groups:
                for param in param_group['params']:
                    if not (hasattr(param, 'dp_comm') and \
                        param.dp_comm == expert_dp_comm):
                        # this parameter is not an expert parameter
                        # thus there is no need to save its state in current rank
                        # since it has been saved by data parallel rank 0
                        if args.fp16:
                            # fp16 optimizer may have empty state due to overflow
                            state_dict['optimizer']['optimizer']['state'].pop(index, None)
                        else:
                            state_dict['optimizer']['state'].pop(index)
                    index += 1
            if args.fp16:
                state_dict['optimizer']['optimizer'].pop('param_groups')
            else:
                state_dict['optimizer'].pop('param_groups')

    # Save.
    checkpoint_name = get_fmoe_checkpoint_name(args.save, iteration)
    from megatron.checkpointing import ensure_directory_exists
    from megatron.checkpointing import get_checkpoint_tracker_filename
    ensure_directory_exists(checkpoint_name)
    torch.save(state_dict, checkpoint_name)

    # Wait so everyone is done (necessary)
    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('  successfully saved checkpoint at iteration {:7d} to {}'.format(
            iteration, args.save), flush=True)
    # And update the latest iteration
    if torch.distributed.get_rank() == 0:
        tracker_filename = get_checkpoint_tracker_filename(args.save)
        with open(tracker_filename, 'w') as f:
            f.write(str(iteration))
    # Wait so everyone is done (not necessary)
    torch.distributed.barrier()


def merge_state_dict(state_dict_rank0, state_dict_local, fp16):
    """Merge two state dicts: one from data parallel rank 0, the other
    containing only the expert states."""
    from megatron import print_rank_last
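    # The expert entries in `state_dict_local` overwrite the corresponding
    # entries from data parallel rank 0: recursively by key for the model
    # weights, and by parameter index for the optimizer state.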
    def merge_model(state_dict_rank0, state_dict_local):
        for k, v in state_dict_local.items():
            # megatron uses both dict and OrderedDict in its state_dict
            if isinstance(v, (OrderedDict, dict)):
                print_rank_last("[merge model] go recursively to {}".format(k))
                merge_model(state_dict_rank0[k], v)
            else:
                before = state_dict_rank0[k].sum().item()
                state_dict_rank0[k] = v
                after = state_dict_rank0[k].sum().item()
                print_rank_last("[merge model] copy parameter {}, \
                    before.sum={:7f}, after.sum={:7f}".format(k, before, after))
    merge_model(state_dict_rank0['model'], state_dict_local['model'])

    optimizer_rank0 = state_dict_rank0['optimizer']['optimizer'] if fp16 else state_dict_rank0['optimizer']
    optimizer_local = state_dict_local['optimizer']['optimizer'] if fp16 else state_dict_local['optimizer']

    for k, v in optimizer_local['state'].items():
        before = {kk: vv.sum().item() \
            for kk, vv in optimizer_rank0['state'][k].items()}
        optimizer_rank0['state'][k] = v
        after = {kk: vv.sum().item() \
            for kk, vv in optimizer_rank0['state'][k].items()}
        print_rank_last("[merge optimizer] copy {}, \
               before.sum={}, after.sum={}".format(k, str(before), str(after)))
    return state_dict_rank0

def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
    """Load a model checkpoint and return the iteration."""

    from megatron import get_args
    from megatron import mpu
    from megatron import print_rank_last
    from megatron.checkpointing import get_checkpoint_tracker_filename, set_checkpoint_version, check_checkpoint_args, update_num_microbatches
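    # Counterpart of `save_checkpoint` above: data parallel rank 0 loads the
    # native Megatron checkpoint; every other rank loads both rank 0's
    # checkpoint and its own expert-only checkpoint, then merges the two.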
    if mpu.get_data_parallel_rank() == 0:
        # at dp rank 0, we still follow the native load_checkpoint by megatron
        from megatron.checkpointing import load_checkpoint as load_checkpoint_native
        return load_checkpoint_native(model, optimizer, lr_scheduler, load_arg)

    args = get_args()
    load_dir = getattr(args, load_arg)

    if isinstance(model, DistributedDataParallel):
        model = model.module
    # Read the tracker file and set the iteration.
    tracker_filename = get_checkpoint_tracker_filename(load_dir)

    # If no tracker file, return iteration zero.
    if not os.path.isfile(tracker_filename):
        print_rank_last('WARNING: could not find the metadata file {} '.format(
            tracker_filename))
        print_rank_last('    will not load any checkpoints and will start from '
                     'random')
        return 0

    # Otherwise, read the tracker file and either set the iteration or
    # mark it as a release checkpoint.
    iteration = 0
    release = False
    with open(tracker_filename, 'r') as f:
        metastring = f.read().strip()
        try:
            iteration = int(metastring)
        except ValueError:
            release = metastring == 'release'
            if not release:
                print_rank_last('ERROR: Invalid metadata file {}. Exiting'.format(
                    tracker_filename))
                sys.exit()

    assert iteration > 0 or release, 'error parsing metadata file {}'.format(
        tracker_filename)

    # Checkpoint.
    checkpoint_name_rank0 = get_fmoe_checkpoint_name(
        load_dir, iteration, release, 0)
    checkpoint_name_local = get_fmoe_checkpoint_name(
        load_dir, iteration, release, mpu.get_data_parallel_rank())
    print_rank_last(' loading checkpoint at rank 0 from {} and rank {} from {} at iteration {}, will merge them later'.format(
        checkpoint_name_rank0, mpu.get_data_parallel_rank(),
        checkpoint_name_local, iteration))

    # Load the checkpoint.
    def load_state_dict(checkpoint_name):
        try:
            state_dict = torch.load(checkpoint_name, map_location='cpu')
        except ModuleNotFoundError:
            from megatron.fp16_deprecated import loss_scaler
            # For backward compatibility.
            print_rank_last(' > deserializing using the old code structure ...')
            sys.modules['fp16.loss_scaler'] = sys.modules[
                'megatron.fp16_deprecated.loss_scaler']
            sys.modules['megatron.fp16.loss_scaler'] = sys.modules[
                'megatron.fp16_deprecated.loss_scaler']
            state_dict = torch.load(checkpoint_name, map_location='cpu')
            sys.modules.pop('fp16.loss_scaler', None)
            sys.modules.pop('megatron.fp16.loss_scaler', None)
        except BaseException:
            print_rank_last('could not load the checkpoint')
            sys.exit()
        return state_dict
    state_dict_rank0 = load_state_dict(checkpoint_name_rank0)
    state_dict_local = load_state_dict(checkpoint_name_local)

    state_dict = merge_state_dict(state_dict_rank0, state_dict_local, args.fp16)

    # set checkpoint version
    set_checkpoint_version(state_dict.get('checkpoint_version', 0))

    # Set iteration.
    if args.finetune or release:
        iteration = 0
    else:
        try:
            iteration = state_dict['iteration']
        except KeyError:
            try:  # Backward compatible with older checkpoints
                iteration = state_dict['total_iters']
            except KeyError:
                print_rank_last('A metadata file exists but unable to load '
                             'iteration from checkpoint {}, exiting'.format(
                                 checkpoint_name_local))
                sys.exit()

    # Check arguments.
    assert args.consumed_train_samples == 0
    assert args.consumed_valid_samples == 0
    if 'args' in state_dict:
        checkpoint_args = state_dict['args']
        check_checkpoint_args(checkpoint_args)
        args.consumed_train_samples = getattr(checkpoint_args,
                                              'consumed_train_samples', 0)
        update_num_microbatches(consumed_samples=args.consumed_train_samples)
        args.consumed_valid_samples = getattr(checkpoint_args,
                                              'consumed_valid_samples', 0)
    else:
        print_rank_last('could not find arguments in the checkpoint ...')

    # Model.
    model.load_state_dict(state_dict['model'])

    # Optimizer.
    if not release and not args.finetune and not args.no_load_optim:
        try:
            if optimizer is not None:
                optimizer.load_state_dict(state_dict['optimizer'])
            if lr_scheduler is not None:
                lr_scheduler.load_state_dict(state_dict['lr_scheduler'])
        except KeyError:
            print_rank_last('Unable to load optimizer from checkpoint {}. '
                         'Specify --no-load-optim or --finetune to prevent '
                         'attempting to load the optimizer state, '
                         'exiting ...'.format(checkpoint_name_local))
            sys.exit()

    # rng states.
    if not release and not args.finetune and not args.no_load_rng:
        try:
            random.setstate(state_dict['random_rng_state'])
            np.random.set_state(state_dict['np_rng_state'])
            torch.set_rng_state(state_dict['torch_rng_state'])
            torch.cuda.set_rng_state(state_dict['cuda_rng_state'])
            mpu.get_cuda_rng_tracker().set_states(
                state_dict['rng_tracker_states'])
        except KeyError:
            print_rank_last('Unable to load rng state from checkpoint {}. '
                         'Specify --no-load-rng or --finetune to prevent '
                         'attempting to load the rng state, '
                         'exiting ...'.format(checkpoint_name_local))
            sys.exit()

    torch.distributed.barrier()
    print_rank_last('  successfully loaded checkpoint (with expert parameters updated) from {} at iteration {}'.format(
        args.load, iteration))

    return iteration