megatron.py 11.4 KB
Newer Older
Sengxian's avatar
Sengxian committed
1
r"""
Rick Ho's avatar
Rick Ho committed
2
3
4
The adaptor to seamlessly enable FastMoE in Megatron-LM v2.0 with at most two
lines of modification.
See `examples/megatron` for usage instructions.
Sengxian's avatar
Sengxian committed
5
"""
Rick Ho's avatar
Rick Ho committed
6
7
8
import math
import numpy as np
import torch
Rick Ho's avatar
Rick Ho committed
9
import torch.nn as nn
Rick Ho's avatar
Rick Ho committed
10
import torch.nn.functional as F
Rick Ho's avatar
Rick Ho committed
11
12
13

from .transformer import FMoETransformerMLP
from .distributed import DistributedGroupedDataParallel
14
15
from .balance import update_balance_profile, reset_balance_profile
from .utils import get_torch_default_comm
Rick Ho's avatar
Rick Ho committed
16
17


18
class _FakeMegatronMLP(nn.Module):
    r"""
    Plain two-layer MLP without any model parallelism, used as a reference
    implementation in correctness tests.

    The second constructor argument is accepted for signature compatibility
    and ignored. ``forward`` returns an ``(output, bias)`` pair whose bias is
    an all-zero tensor shaped like the output.
    """

    def __init__(self, args, _):
        super().__init__()
        d_model = args.hidden_size
        d_hidden = args.hidden_hidden_size
        # fc1 must be created before fc2 so parameter init consumes the torch
        # RNG in the same order.
        self.fc1 = nn.Linear(d_model, d_hidden)
        self.fc2 = nn.Linear(d_hidden, d_model)

    def forward(self, x):
        r"""
        Apply fc1, a GeLU activation, then fc2; return ``(out, zeros_like(out))``.
        """
        activated = F.gelu(self.fc1(x))
        out = self.fc2(activated)
        return out, torch.zeros_like(out)

Sengxian's avatar
Sengxian committed
37

38
def _megatron_init_method(self, rng, sigma):
    r"""
    Re-initialize ``self.weight`` with samples from N(0, sigma), drawn from
    the numpy generator ``rng`` and cast back to the weight's original dtype
    and device. If ``self.bias`` exists it is set to zero.
    Copied from Megatron-LM.
    """
    target = self.weight
    target_dtype, target_device = target.dtype, target.device
    samples = rng.normal(loc=0.0, scale=sigma, size=tuple(target.size()))
    self.weight.data = torch.from_numpy(samples).to(
        dtype=target_dtype, device=target_device
    )

    if self.bias is None:
        return
    # Biases always start at zero.
    with torch.no_grad():
        self.bias.zero_()
Rick Ho's avatar
Rick Ho committed
52

Sengxian's avatar
Sengxian committed
53

Rick Ho's avatar
Rick Ho committed
54
def _random_init_weight(self, rng):
    r"""
    Kaiming-uniform initialization, copied from
    ``torch.nn.init.kaiming_uniform_`` (leaky_relu gain with a=sqrt(5)), but
    drawing from the numpy generator ``rng``.

    Fan-in is computed from ``self.weight[0]``, so the weight presumably has a
    leading stacking dimension (e.g. one slice per expert) — TODO confirm.
    If ``self.bias`` exists it is drawn from U(-1/sqrt(fan_in), 1/sqrt(fan_in)).
    Weight values are drawn before bias values.
    """
    weight = self.weight
    fan_in = nn.init._calculate_correct_fan(weight[0], "fan_in")
    gain = nn.init.calculate_gain("leaky_relu", math.sqrt(5))
    weight_bound = math.sqrt(3.0) * (gain / math.sqrt(fan_in))
    dtype, device = weight.dtype, weight.device

    drawn = rng.uniform(-weight_bound, weight_bound, size=tuple(weight.size()))
    self.weight.data = torch.from_numpy(drawn).to(dtype=dtype, device=device)

    if self.bias is None:
        return
    bias_fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[0])
    bias_bound = 1 / math.sqrt(bias_fan_in)
    drawn_bias = rng.uniform(-bias_bound, bias_bound, size=tuple(self.bias.size()))
    self.bias.data = torch.from_numpy(drawn_bias).to(dtype=dtype, device=device)
Rick Ho's avatar
Rick Ho committed
72
73


74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
# Shared balance statistics; updated by the hooks built in
# `generate_megatron_gate_hook` and re-initialized by `reset_gate_hook`.
balance_dict = dict()
# Number of transformer layers converted by `fmoefy` (0 until it runs).
num_layers = 0


def reset_gate_hook():
    r"""
    Re-initialize the module-level ``balance_dict`` for ``num_layers`` layers
    through ``reset_balance_profile``, using Megatron's current
    ``balance_strategy`` argument.
    """
    from megatron import get_args

    # Only reads the module globals, so no `global` statement is required.
    reset_balance_profile(balance_dict, num_layers, get_args().balance_strategy)


def get_balance_profile():
    r"""Return the module-level ``balance_dict`` object itself (not a copy)."""
    return balance_dict


def generate_megatron_gate_hook(layer_idx, num_expert_global):
    r"""
    Build the gate hook for layer ``layer_idx``.

    The returned callable forwards the gate's top-k indices, top-k scores and
    state dict to ``update_balance_profile``, recording into the module-level
    ``balance_dict``. The balance strategy is read from Megatron's args once,
    when the hook is created, not on every call.
    """
    from megatron import get_args

    strategy = get_args().balance_strategy

    def megatron_gate_hook(gate_top_k_idx, gate_score_top_k, gate_state_dict):
        # balance_dict is only read here (passed along), so no `global` needed.
        update_balance_profile(
            balance_dict,
            gate_top_k_idx,
            gate_score_top_k,
            gate_state_dict,
            layer_idx,
            num_expert_global,
            strategy,
        )

    return megatron_gate_hook


def add_fmoe_args(parser):
    r"""
    Register FastMoE's command-line options in an argument group titled
    "fastmoe" on ``parser`` and return the same parser.
    """
    group = parser.add_argument_group(title="fastmoe")

    group.add_argument(
        "--fmoefy",
        action="store_true",
        help="Enable FastMoE for the model.",
    )
    group.add_argument(
        "--num-experts",
        type=int,
        default=None,
        help="Number of experts on each worker.",
    )
    group.add_argument(
        "--top-k",
        type=int,
        default=2,
        help="k used by the gate's top-k expert selection.",
    )
    # argparse only applies `type` to string defaults, so the default must
    # already be a float for the attribute to be a float when omitted.
    group.add_argument(
        "--balance-loss-weight",
        type=float,
        default=1.0,
        help="Factor applied to the balance loss before it is added to the loss.",
    )
    group.add_argument(
        "--balance-strategy",
        type=str,
        default=None,
        help="Load-balance strategy (e.g. 'gshard' or 'noisy'); "
        "no balance loss is added when unset.",
    )

    return parser


def add_balance_log(writer, iteration):
    r"""
    Average the collected balance statistics over the default process group,
    log them to ``writer`` on the last rank, then reset the statistics.

    Each ``balance_dict`` entry becomes one row: per-layer values are logged
    as ``balance-<metric>/layer-<i>`` and the row mean as
    ``balance-<metric>/all``. ``all_reduce`` is a collective, so presumably
    every rank of the group must call this function.
    """
    from megatron import is_last_rank

    rows = [
        torch.tensor(per_layer, device=per_layer[0].device)
        for per_layer in balance_dict.values()
    ]
    stats = torch.vstack(rows).detach()

    comm = get_torch_default_comm()
    n_ranks = torch.distributed.get_world_size(group=comm)
    torch.distributed.all_reduce(stats, group=comm)
    stats /= n_ranks

    if writer and is_last_rank():
        # Rows of `stats` follow the iteration order of balance_dict's keys.
        for metric, row in zip(balance_dict, stats):
            for layer_id, value in enumerate(row):
                writer.add_scalar(
                    f"balance-{metric}/layer-{layer_id}", value.item(), iteration
                )
            writer.add_scalar(f"balance-{metric}/all", row.mean().item(), iteration)

    reset_gate_hook()


def patch_forward_step(forward_step_func):
    r"""
    Patch model's forward_step_func to support balance loss.

    If ``--balance-strategy`` is unset, ``forward_step_func`` is returned
    unchanged. Otherwise the wrapper, on the last pipeline stage only:

    * adds ``balance_loss_weight * mean(balance_dict["<strategy>_loss"])`` to
      the loss, keeping the autograd graph of the per-layer losses so the
      balance term can back-propagate into the gates;
    * stores the balance loss averaged over the default process group in the
      returned state dict under ``"<strategy>_loss"`` (detached, for logging).

    ``forward_step_func`` is expected to return ``(loss, state_dict)`` on the
    last stage.
    """
    from megatron.mpu import is_pipeline_last_stage
    from megatron import get_args

    if not get_args().balance_strategy:
        return forward_step_func

    def forward_step_with_balance_loss(data_iterator, model, input_tensor):
        args = get_args()
        output = forward_step_func(data_iterator, model, input_tensor)

        if not is_pipeline_last_stage():
            return output

        loss, state_dict = output
        loss_name = args.balance_strategy + "_loss"
        layer_losses = balance_dict[loss_name]
        device = layer_losses[0].device

        # torch.tensor(list_of_tensors) would copy the values into a fresh
        # leaf tensor and cut the autograd graph; as_tensor + cat keep it.
        flat_losses = torch.cat(
            [
                torch.as_tensor(layer_loss, device=device).reshape(-1).float()
                for layer_loss in layer_losses
            ]
        )
        bal_loss = flat_losses.mean() * args.balance_loss_weight

        # Average across the world group; this copy is only reported.
        world_group = get_torch_default_comm()
        world_size = torch.distributed.get_world_size(group=world_group)
        averaged_bal_loss = bal_loss.detach().clone()
        torch.distributed.all_reduce(averaged_bal_loss, group=world_group)
        averaged_bal_loss /= world_size

        # Out-of-place add so the incoming loss tensor is not mutated.
        loss = loss + bal_loss
        state_dict[loss_name] = averaged_bal_loss

        return loss, state_dict

    return forward_step_with_balance_loss


def patch_model_provider(model_provider):
    r"""
    Wrap Megatron's ``model_provider`` so the model it builds is passed
    through ``fmoefy``.

    Expert count and top-k come from Megatron's args. The expert hidden size
    is ``4 * hidden_size // top_k``, presumably to keep the per-token FFN cost
    close to a dense 4x MLP — TODO confirm.
    """
    from megatron import get_args

    def fmoefied_model_provider():
        args = get_args()
        base_model = model_provider()
        expert_hidden = 4 * args.hidden_size // args.top_k
        return fmoefy(
            base_model,
            num_experts=args.num_experts,
            hidden_hidden_size=expert_hidden,
            top_k=args.top_k,
        )

    return fmoefied_model_provider


Rick Ho's avatar
Rick Ho committed
209
class MegatronMLP(FMoETransformerMLP):
    r"""
    Make the FMoETransformerMLP layer that distributes experts across
    communication group `group` to replace the original MLP layer in Megatron.

    The gate class is chosen from ``args.balance_strategy``: unset or
    ``"gshard"`` selects ``NaiveGate``, ``"noisy"`` selects ``NoisyGate``, and
    any other value raises ``ValueError``.
    """

    def __init__(self, args, group, layer_idx):
        assert (
            args.seq_length * args.micro_batch_size % args.tensor_model_parallel_size
            == 0
        ), "Batch size x sequence length should be multiple of mp size"
        # Without distributed experts the experts are replicated on every
        # worker (trained data-parallel, see expert_dp_comm below), so the MoE
        # layer is built with world_size=1.
        if not args.distributed_experts:
            world_size = 1
        else:
            world_size = args.world_size

        if not args.balance_strategy or args.balance_strategy == "gshard":
            from .gates import NaiveGate

            gate = NaiveGate
        elif args.balance_strategy == "noisy":
            from .gates import NoisyGate

            gate = NoisyGate
        else:
            # Raise explicitly: an assert is stripped under `python -O`, which
            # would silently pass gate=None to the MoE layer.
            raise ValueError(
                "Undefined balance strategy {}".format(args.balance_strategy)
            )

        super().__init__(
            args.num_experts,
            top_k=args.top_k,
            d_model=args.hidden_size,
            d_hidden=args.hidden_hidden_size,
            world_size=world_size,
            mp_group=group,
            expert_dp_comm="none" if args.distributed_experts else "dp",
            gate_hook=generate_megatron_gate_hook(
                layer_idx, args.num_experts * world_size
            ),
            gate=gate,
        )
        self.hidden_size = args.hidden_size
        # The rank offsets the numpy init seed in reset_parameters; it is only
        # non-zero when experts differ across workers.
        if args.distributed_experts:
            self.rank = args.rank
        else:
            self.rank = 0
        self.sigma = args.init_method_std
        self.num_layers = args.num_layers
        self.reset_parameters()

    def reset_parameters(self):
        r"""
        Initialize the weight as linear layers.
        As megatron is using fixed random seed for some nasty stuff, an
        additional numpy rng is used.

        ``htoh4`` is drawn from N(0, sigma) and ``h4toh`` from
        N(0, sigma / sqrt(2 * num_layers)); any biases are zeroed.
        """
        rng = np.random.default_rng(np.random.randint(2048) + self.rank)
        _megatron_init_method(self.experts.htoh4, rng, self.sigma)
        std = self.sigma / math.sqrt(2.0 * self.num_layers)
        _megatron_init_method(self.experts.h4toh, rng, std)

    def forward(self, inp):
        r"""
        Run the MoE layer and return ``(output, bias)``, where ``bias`` is a
        zero vector of length ``hidden_size`` with ``inp``'s dtype and device.
        """
        return (
            super().forward(inp),
            torch.zeros(self.hidden_size, dtype=inp.dtype, device=inp.device),
        )
Rick Ho's avatar
Rick Ho committed
273
274


Sengxian's avatar
Sengxian committed
275
276
277
278
279
280
281
282
def fmoefy(
    model,
    num_experts=None,
    distributed_experts=True,
    hidden_hidden_size=None,
    top_k=None,
):
    r"""
    Replace MLP layers in a transformer-based model in Megatron by MoE.

    * `model` should be a standard Megatron model that has
      `model.language_model.transformer.layers` as transformer layers, which
      is a sequence of transformer blocks that each contain an `mlp` member.
    * `num_experts`, `hidden_hidden_size` and `top_k`, when given, override
      the values in Megatron's args. `num_experts` must end up set (here or
      via `--num-experts`). If neither source sets them, `hidden_hidden_size`
      defaults to `4 * hidden_size` and `top_k` to 2.
    * `distributed_experts` is set to True if different experts are located in
      different workers. Otherwise, the experts on the workers are identical,
      and they are trained in data-parallel mode. This can be useful when
      testing on small models that do not require high training throughput or
      large parameter capacity. Pass None to keep the value already in args.

    Note that pipeline parallel is not supported yet. When distributed experts
    are enabled, their communicator should be Megatron's
    tensor_model_parallel_comm x data_parallel_comm, which is not created.

    Returns the same `model`, modified in place.
    """
    from megatron import get_args
    from megatron import mpu

    args = get_args()
    if num_experts is not None:
        args.num_experts = num_experts
    # `"num_experts" in args` is not enough: add_fmoe_args registers
    # --num-experts with default None, so the attribute exists even when unset.
    assert (
        getattr(args, "num_experts", None) is not None
    ), "num_experts should be specified in arguments or fmoefy function"

    if hidden_hidden_size is not None:
        args.hidden_hidden_size = hidden_hidden_size
    elif not hasattr(args, "hidden_hidden_size"):
        args.hidden_hidden_size = args.hidden_size * 4

    if top_k is not None:
        args.top_k = top_k
    elif not hasattr(args, "top_k"):
        args.top_k = 2

    # Set distributed_experts to None to use default setting in args
    if distributed_experts is not None:
        args.distributed_experts = distributed_experts

    for idx, l in enumerate(model.language_model.transformer.layers):
        l.mlp = MegatronMLP(args, mpu.get_model_parallel_group(), idx)

    # initialize gate hook
    global num_layers
    num_layers = len(model.language_model.transformer.layers)
    reset_gate_hook()

    return model


class DistributedDataParallel(DistributedGroupedDataParallel):
    r"""
    Drop-in replacement for the DDP module provided by Megatron, built on
    FastMoE's ``DistributedGroupedDataParallel`` so that FastMoE's parallel
    and reduction strategies are used.

    Model-parallel and data-parallel groups are taken from ``megatron.mpu``.
    The state-dict methods delegate straight to the wrapped module, for
    consistency with Megatron.
    """

    def __init__(self, module):
        from megatron import mpu

        model_parallel_group = mpu.get_model_parallel_group()
        data_parallel_group = mpu.get_data_parallel_group()
        super().__init__(
            module, mp_group=model_parallel_group, dp_group=data_parallel_group
        )

    def state_dict(self, *args, **kwargs):
        r"""Return the wrapped module's ``state_dict`` (Megatron-consistent)."""
        inner = self.module
        return inner.state_dict(*args, **kwargs)

    def state_dict_for_save_checkpoint(self, *args, **kwargs):
        r"""Delegate to the wrapped module's ``state_dict_for_save_checkpoint``."""
        inner = self.module
        return inner.state_dict_for_save_checkpoint(*args, **kwargs)

    def load_state_dict(self, *args, **kwargs):
        r"""Load the state dict into the wrapped module (Megatron-consistent)."""
        inner = self.module
        return inner.load_state_dict(*args, **kwargs)