r"""
The adaptor to seamlessly enable FastMoE in Megatron-LM v2.0 with at most two
lines of modification.
See `examples/megatron` for usage instructions.
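
A minimal sketch of the intended integration (illustrative only: the expert
count and the `model` variable are assumptions, not part of this module):

    from fmoe.megatron import fmoefy, DistributedDataParallel

    model = fmoefy(model, num_experts=8)    # replace each MLP with an MoE layer
    model = DistributedDataParallel(model)  # FastMoE-aware DDP wrapper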
"""
import os
import math
import random
from collections import OrderedDict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .transformer import FMoETransformerMLP
from .distributed import DistributedGroupedDataParallel
from .balance import update_balance_profile, reset_balance_profile
from .utils import get_torch_default_comm


class _FakeMegatronMLP(nn.Module):
    r"""
    A fake MLP without model parallelism, used for correctness testing.
    """

    def __init__(self, args, _):
        super().__init__()
        self.fc1 = nn.Linear(args.hidden_size, args.hidden_hidden_size)
        self.fc2 = nn.Linear(args.hidden_hidden_size, args.hidden_size)

    def forward(self, x):
        r"""
        Directly use GeLU
        """
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.fc2(x)
        return x, torch.zeros_like(x)


def _megatron_init_method(self, rng, sigma):
    r"""
    Init method based on N(0, sigma).
    Copied from Megatron-LM
    """
    device = self.weight.device
    dtype = self.weight.dtype
    weight = rng.normal(loc=0.0, scale=sigma, size=tuple(self.weight.size()))
    self.weight.data = torch.from_numpy(weight).to(dtype=dtype, device=device)

    if self.bias is not None:
        # Always initialize bias to zero.
        with torch.no_grad():
            self.bias.zero_()


def _random_init_weight(self, rng):
    r"""
    Copied from torch.nn.init.kaiming_uniform_
    """
    fan = nn.init._calculate_correct_fan(self.weight[0], "fan_in")
    gain = nn.init.calculate_gain("leaky_relu", math.sqrt(5))
    std = gain / math.sqrt(fan)
    bound = math.sqrt(3.0) * std
    device = self.weight.device
    dtype = self.weight.dtype
    weight = rng.uniform(-bound, bound, size=tuple(self.weight.size()))
    self.weight.data = torch.from_numpy(weight).to(dtype=dtype, device=device)

    if self.bias is not None:
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[0])
        bound = 1 / math.sqrt(fan_in)
        bias = rng.uniform(-bound, bound, size=tuple(self.bias.size()))
        self.bias.data = torch.from_numpy(bias).to(dtype=dtype, device=device)


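# Module-level state shared between the gate hooks and the logging helpers:
# `balance_dict` accumulates per-layer gate balance metrics for the current
# logging interval, and `num_layers` is set by `fmoefy` when the MoE layers
# are created.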
balance_dict = {}
num_layers = 0


def reset_gate_hook():
    from megatron import get_args

    global balance_dict, num_layers
    reset_balance_profile(balance_dict, num_layers, get_args().balance_strategy)


def get_balance_profile():
    global balance_dict
    return balance_dict


def generate_megatron_gate_hook(layer_idx, num_expert_global):
    from megatron import get_args

    balance_strategy = get_args().balance_strategy

    def megatron_gate_hook(gate_top_k_idx, gate_score_top_k, gate_state_dict):
        global balance_dict
        update_balance_profile(
            balance_dict,
            gate_top_k_idx,
            gate_score_top_k,
            gate_state_dict,
            layer_idx,
            num_expert_global,
            balance_strategy,
        )

    return megatron_gate_hook


def add_fmoe_args(parser):
    group = parser.add_argument_group(title="fastmoe")

    group.add_argument("--fmoefy", action="store_true")
    group.add_argument("--num-experts", type=int, default=None)
    group.add_argument("--top-k", type=int, default=2)
    group.add_argument("--balance-loss-weight", type=float, default=1)
    group.add_argument("--balance-strategy", type=str, default=None)

    return parser
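
# Illustrative wiring (a sketch, not part of this module): Megatron-LM's
# `pretrain` entry point accepts an `extra_args_provider` callback, so the
# FastMoE arguments can be registered with, e.g.,
#
#     pretrain(..., extra_args_provider=add_fmoe_args)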


def add_balance_log(writer, iteration):
    from megatron import is_last_rank

    balance_dict_tensor = torch.vstack(
        [torch.tensor(item, device=item[0].device) for item in balance_dict.values()]
    ).detach()
    world_group = get_torch_default_comm()
    world_size = torch.distributed.get_world_size(group=world_group)
    torch.distributed.all_reduce(balance_dict_tensor, group=world_group)
    balance_dict_tensor /= world_size

    if writer and is_last_rank():
        for idx, metric_name in enumerate(balance_dict):
            for layer_id, val in enumerate(balance_dict_tensor[idx]):
                writer.add_scalar(
                    f"balance-{metric_name}/layer-{layer_id}", val.item(), iteration
                )
            writer.add_scalar(
                f"balance-{metric_name}/all",
                balance_dict_tensor[idx].mean().item(),
                iteration,
            )

    reset_gate_hook()


def patch_forward_step(forward_step_func):
    r"""
    Patch model's forward_step_func to support balance loss
    """

    from megatron.mpu import is_pipeline_last_stage
    from megatron import get_args

    if not get_args().balance_strategy:
        return forward_step_func

    def forward_step_with_balance_loss(data_iterator, model, input_tensor):
        args = get_args()
        output = forward_step_func(data_iterator, model, input_tensor)

        if is_pipeline_last_stage():
            loss_name = args.balance_strategy + "_loss"

            loss, state_dict = output
            bal_loss = (
                torch.tensor(
                    balance_dict[loss_name],
                    device=balance_dict[loss_name][0].device,
                ).mean()
                * args.balance_loss_weight
            ).float()

            # average across the world group
            world_group = get_torch_default_comm()
            world_size = torch.distributed.get_world_size(group=world_group)
            averaged_bal_loss = bal_loss.clone().detach()
            torch.distributed.all_reduce(averaged_bal_loss, group=world_group)
            averaged_bal_loss /= world_size

            loss += bal_loss
            state_dict[loss_name] = averaged_bal_loss

            return loss, state_dict
        else:
            return output

    return forward_step_with_balance_loss


def patch_model_provider(model_provider):
    from megatron import get_args

    def fmoefied_model_provider():
        args = get_args()
        return fmoefy(
            model_provider(),
            num_experts=args.num_experts,
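            # 4 * hidden_size // top_k keeps the combined hidden width of the
            # top_k active experts equal to that of the dense 4*h MLP it replaces.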
            hidden_hidden_size=4 * args.hidden_size // args.top_k,
            top_k=args.top_k,
        )

    return fmoefied_model_provider
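
# Illustrative wiring of the two patches above (a sketch; `model_provider` and
# `forward_step` stand for the callbacks a Megatron pretraining script passes
# to `pretrain`):
#
#     model_provider = patch_model_provider(model_provider)
#     forward_step = patch_forward_step(forward_step)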


class MegatronMLP(FMoETransformerMLP):
    r"""
    An FMoETransformerMLP layer that distributes experts across the
    communication group `group`, replacing the original MLP layer in Megatron.
    """

    def __init__(self, args, group, layer_idx):
        assert (
            args.seq_length * args.micro_batch_size % args.tensor_model_parallel_size
            == 0
        ), "Batch size x sequence length should be multiple of mp size"
        if not args.distributed_experts:
            world_size = 1
        else:
            world_size = args.world_size
        gate = None
        if not args.balance_strategy or args.balance_strategy == "gshard":
            from .gates import NaiveGate

            gate = NaiveGate
        elif args.balance_strategy == "noisy":
            from .gates import NoisyGate

            gate = NoisyGate
        else:
            assert False, "Undefined balance strategy {}".format(args.balance_strategy)
        super().__init__(
            args.num_experts,
            top_k=args.top_k,
            d_model=args.hidden_size,
            d_hidden=args.hidden_hidden_size,
            world_size=world_size,
            mp_group=group,
            expert_dp_comm="none" if args.distributed_experts else "dp",
            gate_hook=generate_megatron_gate_hook(
                layer_idx, args.num_experts * world_size
            ),
            gate=gate,
        )
        self.hidden_size = args.hidden_size
        if args.distributed_experts:
            self.rank = args.rank
        else:
            self.rank = 0
        self.sigma = args.init_method_std
        self.num_layers = args.num_layers
        self.reset_parameters()

    def reset_parameters(self):
        r"""
        Initialize the weights as ordinary linear layers.
        As Megatron uses a fixed random seed for some nasty stuff, an
        additional numpy rng is used.
        """
        rng = np.random.default_rng(np.random.randint(2048) + self.rank)
        _megatron_init_method(self.experts.htoh4, rng, self.sigma)
        std = self.sigma / math.sqrt(2.0 * self.num_layers)
        _megatron_init_method(self.experts.h4toh, rng, std)

    def forward(self, inp):
        return (
            super().forward(inp),
            torch.zeros(self.hidden_size, dtype=inp.dtype, device=inp.device),
        )


def fmoefy(
    model,
    num_experts=None,
    distributed_experts=True,
    hidden_hidden_size=None,
    top_k=None,
):
    r"""
    Replace the MLP layers in a Megatron transformer-based model with MoE layers.
    * `model` should be a standard Megatron model that has
    `model.language_model.transformer.layers` as its transformer layers, i.e. an
    array of transformer blocks that each contain an `mlp` member.
    * `distributed_experts` is set to True if different experts are located on
    different workers. Otherwise, the experts on all workers are identical, and
    they are trained in data-parallel mode. This can be useful when testing on
    small models that do not require high training throughput or large parameter
    capacity.
    Note that pipeline parallelism is not supported yet. When distributed experts
    are enabled, their communicator should be Megatron's
    tensor_model_parallel_comm x data_parallel_comm, which is not created.
    """
    from megatron import get_args
    from megatron import mpu

    args = get_args()
    if num_experts is not None:
        args.num_experts = num_experts
    assert (
        "num_experts" in args
    ), "num_experts should be specified in arguments or fmoefy function"

    if hidden_hidden_size is not None:
        args.hidden_hidden_size = hidden_hidden_size
    elif not hasattr(args, "hidden_hidden_size"):
        args.hidden_hidden_size = args.hidden_size * 4

    if top_k is not None:
        args.top_k = top_k
    elif not hasattr(args, "top_k"):
        args.top_k = 2

    # Set distributed_experts to None to use default setting in args
    if distributed_experts is not None:
        args.distributed_experts = distributed_experts

    for idx, l in enumerate(model.language_model.transformer.layers):
        l.mlp = MegatronMLP(args, mpu.get_model_parallel_group(), idx)

    # initialize gate hook
    global num_layers, balance_dict
    num_layers = len(model.language_model.transformer.layers)
    reset_gate_hook()

    return model


class DistributedDataParallel(DistributedGroupedDataParallel):
    r"""
    A wrapper used to replace the DDP module provided by Megatron, adapted to
    enable the sophisticated parallel and reduction strategies in FastMoE.
    """

    def __init__(self, module):
        from megatron import mpu

        super().__init__(
            module,
            mp_group=mpu.get_model_parallel_group(),
            dp_group=mpu.get_data_parallel_group(),
        )

    def state_dict(self, *args, **kwargs):
        r"""
        Keep consistency with Megatron
        """
        return self.module.state_dict(*args, **kwargs)

    def state_dict_for_save_checkpoint(self, *args, **kwargs):
        r"""
        Keep consistency with Megatron
        """
        return self.module.state_dict_for_save_checkpoint(*args, **kwargs)

    def load_state_dict(self, *args, **kwargs):
        r"""
        Keep consistency with Megatron
        """
        return self.module.load_state_dict(*args, **kwargs)

def get_checkpoint_name(checkpoints_path, iteration,
                        release=False):
    """A unified checkpoint name."""
    from megatron import mpu

    if release:
        directory = 'release'
    else:
        directory = 'iter_{:07d}'.format(iteration)
    # Use both the tensor and pipeline MP rank.
    if mpu.get_pipeline_model_parallel_world_size() == 1:
        return os.path.join(checkpoints_path, directory,
                            'mp_rank_{:02d}_dp_rank_{:04d}'.format(
                                mpu.get_tensor_model_parallel_rank(),
                                mpu.get_data_parallel_rank()
                                ),
                            'model_optim_rng.pt')
    return os.path.join(checkpoints_path, directory,
                        'mp_rank_{:02d}_{:03d}_dp_rank_{:04d}'.format(
                            mpu.get_tensor_model_parallel_rank(),
                            mpu.get_pipeline_model_parallel_rank(),
                            mpu.get_data_parallel_rank()
                            ),
                        'model_optim_rng.pt')

def save_checkpoint(iteration, model, optimizer, lr_scheduler):
    """Save a model checkpoint with expert parallel """
    # TODO: update patch
    from megatron import get_args
    from megatron import mpu

    args = get_args()

    # Unlike stock Megatron, every data-parallel rank writes to disk here;
    # non-zero data-parallel ranks keep only their expert parameters (see below).
    if isinstance(model, DistributedDataParallel):
        model = model.module

    if torch.distributed.get_rank() == 0:
        print('saving checkpoint at iteration {:7d} to {}'.format(
            iteration, args.save), flush=True)

    # Arguments, iteration, and model.
    state_dict = {}
    state_dict['args'] = args
    state_dict['checkpoint_version'] = 3.0
    state_dict['iteration'] = iteration
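
    # keep_vars=True on non-zero data-parallel ranks so that the returned
    # parameters retain their `dp_comm` attribute, which extract_expert_param
    # below relies on to filter out everything except the expert parameters.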
    state_dict['model'] = model.state_dict_for_save_checkpoint(
        keep_vars=(mpu.get_data_parallel_rank() > 0))

    if mpu.get_data_parallel_rank() != 0:

        def extract_expert_param(state_dict, expert_dp_comm='none'):
            state_dict_new = state_dict.__class__()
            for k, v in state_dict.items():
                # megatron uses both dict and OrderedDict in its state_dict
                if isinstance(v, (OrderedDict, dict)):
                    v_new = extract_expert_param(v, expert_dp_comm)
                    if len(v_new) > 0:
                        state_dict_new[k] = v_new
                elif hasattr(v, 'dp_comm') and v.dp_comm == expert_dp_comm:
                    state_dict_new[k] = v.detach()
            return state_dict_new

        state_dict['model'] = extract_expert_param(
                    state_dict['model'],
                    expert_dp_comm='none')

    # Optimizer stuff.
    if not args.no_save_optim:
        if optimizer is not None:
            state_dict['optimizer'] = optimizer.state_dict()
        if lr_scheduler is not None:
            state_dict['lr_scheduler'] = lr_scheduler.state_dict()

    # RNG states.
    if not args.no_save_rng:
        state_dict['random_rng_state'] = random.getstate()
        state_dict['np_rng_state'] = np.random.get_state()
        state_dict['torch_rng_state'] = torch.get_rng_state()
        state_dict['cuda_rng_state'] = torch.cuda.get_rng_state()
        state_dict['rng_tracker_states'] \
            = mpu.get_cuda_rng_tracker().get_states()

    # Save.
    checkpoint_name = get_checkpoint_name(args.save, iteration)
    from megatron.checkpointing import ensure_directory_exists
    from megatron.checkpointing import get_checkpoint_tracker_filename
    ensure_directory_exists(checkpoint_name)
    torch.save(state_dict, checkpoint_name)

    # Wait so everyone is done (necessary)
    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('  successfully saved checkpoint at iteration {:7d} to {}'.format(
            iteration, args.save), flush=True)
    # And update the latest iteration
    if torch.distributed.get_rank() == 0:
        tracker_filename = get_checkpoint_tracker_filename(args.save)
        with open(tracker_filename, 'w') as f:
            f.write(str(iteration))
    # Wait so everyone is done (not necessary)
    torch.distributed.barrier()