r"""
The adaptor to seamlessly enable FastMoE in Megatron-LM v2.0 with at most two
lines of modification.
See `examples/megatron` for usage instructions.
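
A minimal usage sketch of the two-line integration (a hedged example, assuming
the package is importable as `fmoe`, `model` is built by a standard Megatron-LM
v2.0 model provider, and `num_experts=4` is an arbitrary illustrative value)::

    from fmoe.megatron import fmoefy, DistributedDataParallel

    model = fmoefy(model, num_experts=4)    # replace each MLP with an MoE layer
    model = DistributedDataParallel(model)  # use FastMoE's DDP in place of Megatron's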
"""
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .transformer import FMoETransformerMLP
from .distributed import DistributedGroupedDataParallel


class _FakeMegatronMLP(nn.Module):
    r"""
    A fake MLP without model parallelism, for correctness testing.
    """

    def __init__(self, args, _):
        super().__init__()
        self.fc1 = nn.Linear(args.hidden_size, args.hidden_hidden_size)
        self.fc2 = nn.Linear(args.hidden_hidden_size, args.hidden_size)

    def forward(self, x):
        r"""
        Directly use GeLU
        """
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.fc2(x)
        return x, torch.zeros_like(x)


def _megatron_init_method(self, rng, sigma):
    r"""
    Init method based on N(0, sigma).
    Copied from Megatron-LM
    """
    device = self.weight.device
    dtype = self.weight.dtype
    weight = rng.normal(loc=0.0, scale=sigma, size=tuple(self.weight.size()))
    self.weight.data = torch.tensor(weight, dtype=dtype, device=device)

    if self.bias is not None:
        # Always initialize bias to zero.
        with torch.no_grad():
            self.bias.zero_()


def _random_init_weight(self, rng):
    r"""
    Copied from torch.nn.init.kaiming_uniform_
    """
    fan = nn.init._calculate_correct_fan(self.weight[0], "fan_in")
    gain = nn.init.calculate_gain("leaky_relu", math.sqrt(5))
    std = gain / math.sqrt(fan)
    bound = math.sqrt(3.0) * std
    device = self.weight.device
    dtype = self.weight.dtype
    weight = rng.uniform(-bound, bound, size=tuple(self.weight.size()))
    self.weight.data = torch.tensor(weight, dtype=dtype, device=device)

    if self.bias is not None:
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[0])
        bound = 1 / math.sqrt(fan_in)
        bias = rng.uniform(-bound, bound, size=tuple(self.bias.size()))
        self.bias.data = torch.tensor(bias, dtype=dtype, device=device)


class MegatronMLP(FMoETransformerMLP):
    r"""
    An FMoETransformerMLP layer that distributes experts across the
    communication group `group`, used to replace the original MLP layer in
    Megatron.
    """

    def __init__(self, args, group):
        assert (
            args.seq_length * args.micro_batch_size % args.tensor_model_parallel_size
            == 0
        ), "Batch size x sequence length should be a multiple of mp size"
        if not args.distributed_experts:
            world_size = 1
        else:
            world_size = args.world_size
        super().__init__(
            args.num_experts,
            top_k=args.top_k,
            d_model=args.hidden_size,
            d_hidden=args.hidden_hidden_size,
            world_size=world_size,
            mp_group=group,
            expert_dp_comm="none" if args.distributed_experts else "dp",
        )
        self.hidden_size = args.hidden_size
        if args.distributed_experts:
            self.rank = args.rank
        else:
            self.rank = 0
        self.sigma = args.init_method_std
        self.num_layers = args.num_layers
        self.reset_parameters()

    def reset_parameters(self):
        r"""
        Initialize the weights as linear layers.
        As Megatron fixes the random seed for its own purposes, an additional
        numpy rng is used here.
        """
        rng = np.random.default_rng(np.random.randint(2048) + self.rank)
        _megatron_init_method(self.experts.htoh4, rng, self.sigma)
        std = self.sigma / math.sqrt(2.0 * self.num_layers)
        _megatron_init_method(self.experts.h4toh, rng, std)

    def forward(self, inp):
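        r"""
        Return the MoE output together with an all-zero bias tensor, mimicking
        the (output, bias) interface of Megatron's MLP.
        """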
        return (
            super().forward(inp),
            torch.zeros(self.hidden_size, dtype=inp.dtype, device=inp.device),
        )


def fmoefy(
    model,
    num_experts=None,
    distributed_experts=True,
    hidden_hidden_size=None,
    top_k=None,
):
    r"""
    Replace the MLP layers in a Megatron transformer-based model with MoE layers.
    * `model` should be a standard Megatron model that has
    `model.language_model.transformer.layers` as its transformer layers, which is
    an array of transformer blocks that each contain an `mlp` member.
    * `distributed_experts` is set to True if different experts are located on
    different workers. Otherwise, the experts on all workers are identical, and
    they are trained in data-parallel mode. This can be useful when testing on
    small models that do not require high training throughput or large parameter
    capacity.
    Note that pipeline parallelism is not supported yet. When distributed experts
    are enabled, their communicator should be Megatron's
    tensor_model_parallel_comm x data_parallel_comm, which is not created.
    A commented usage sketch follows the function body.
    """
    from megatron import get_args
    from megatron import mpu

    args = get_args()
    if num_experts is not None:
        args.num_experts = num_experts
    assert (
        "num_experts" in args
    ), "num_experts should be specified in arguments or fmoefy function"

    if hidden_hidden_size is not None:
        args.hidden_hidden_size = hidden_hidden_size
    elif not hasattr(args, "hidden_hidden_size"):
        args.hidden_hidden_size = args.hidden_size * 4

    if top_k is not None:
        args.top_k = top_k
    elif not hasattr(args, "top_k"):
        args.top_k = 2

    # Passing distributed_experts=None keeps the default setting in args
    if distributed_experts is not None:
        args.distributed_experts = distributed_experts

    for l in model.language_model.transformer.layers:
        l.mlp = MegatronMLP(args, mpu.get_model_parallel_group())
    return model
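
# A hedged sketch of calling `fmoefy` with every optional argument spelled out.
# The values are illustrative only; as implemented above, arguments left as None
# fall back to the corresponding fields of Megatron's `args`:
#
#     model = fmoefy(
#         model,
#         num_experts=8,             # otherwise args.num_experts must already be set
#         distributed_experts=True,  # pass None to keep args.distributed_experts
#         hidden_hidden_size=4096,   # defaults to 4 * args.hidden_size
#         top_k=2,                   # defaults to 2 if args.top_k is absent
#     )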


class DistributedDataParallel(DistributedGroupedDataParallel):
    r"""
    A wrapper that replaces the DDP module provided by Megatron, adapted to
    enable the sophisticated parallel and reduction strategies of FastMoE.
    """

    def __init__(self, module):
        from megatron import mpu

        super().__init__(
            module,
            mp_group=mpu.get_model_parallel_group(),
            dp_group=mpu.get_data_parallel_group(),
        )

    def state_dict(self, *args, **kwargs):
        r"""
        Keep consistency with Megatron
        """
        return self.module.state_dict(*args, **kwargs)

    def state_dict_for_save_checkpoint(self, *args, **kwargs):
        r"""
        Keep consistency with Megatron
        """
        return self.module.state_dict_for_save_checkpoint(*args, **kwargs)

    def load_state_dict(self, *args, **kwargs):
        r"""
        Keep consistency with Megatron
        """
        return self.module.load_state_dict(*args, **kwargs)