r"""
nn modules to replace Megatron's native ones
"""
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from fmoe.transformer import FMoETransformerMLP
from .balance import reset_gate_hook
from .balance import generate_megatron_gate_hook


class _FakeMegatronMLP(nn.Module):
    r"""
    A fake mlp without model parallelism for correctness testing
    """

    def __init__(self, args, _):
        super().__init__()
        self.fc1 = nn.Linear(args.hidden_size, args.hidden_hidden_size)
        self.fc2 = nn.Linear(args.hidden_hidden_size, args.hidden_size)

    def forward(self, x):
        r"""
        Directly use GeLU
        """
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.fc2(x)
        # Return a zero bias as the second output to mimic Megatron's MLP
        # interface of (output, output_bias).
        return x, torch.zeros_like(x)


def _megatron_init_method(self, rng, sigma):
    r"""
    Init method based on N(0, sigma).
    Copied from Megatron-LM
    """
    device = self.weight.device
    dtype = self.weight.dtype
    weight = rng.normal(loc=0.0, scale=sigma, size=tuple(self.weight.size()))
    self.weight.data = torch.from_numpy(weight).to(dtype=dtype, device=device)

    if self.bias is not None:
        # Always initialize bias to zero.
        with torch.no_grad():
            self.bias.zero_()


def _random_init_weight(self, rng):
    r"""
    Copied from torch.nn.init.kaiming_uniform_
    """
    fan = nn.init._calculate_correct_fan(self.weight[0], "fan_in")
    gain = nn.init.calculate_gain("leaky_relu", math.sqrt(5))
    std = gain / math.sqrt(fan)
    bound = math.sqrt(3.0) * std
    device = self.weight.device
    dtype = self.weight.dtype
    weight = rng.uniform(-bound, bound, size=tuple(self.weight.size()))
    self.weight.data = torch.from_numpy(weight).to(dtype=dtype, device=device)

    if self.bias is not None:
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[0])
        bound = 1 / math.sqrt(fan_in)
        bias = rng.uniform(-bound, bound, size=tuple(self.bias.size()))
        self.bias.data = torch.from_numpy(bias).to(dtype=dtype, device=device)


class MegatronMLP(FMoETransformerMLP):
    r"""
    Make the FMoETransformerMLP layer that distributes experts across the
    expert communication group (the data-parallel group when
    `distributed_experts` is enabled) to replace the original MLP layer in
    Megatron.
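    Instances are normally installed by `fmoefy` below, roughly as
    `layer.mlp = MegatronMLP(args, layer_idx)`.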
    """

    def __init__(self, args, layer_idx, gate=None):
        if not args.distributed_experts:
            world_size = 1
            moe_group = None
        else:
            world_size = args.data_parallel_size
            from megatron.mpu import get_data_parallel_group
            moe_group = get_data_parallel_group()

        if not args.balance_strategy or args.balance_strategy == "naive":
            from fmoe.gates import NaiveGate
            gate = NaiveGate
        elif args.balance_strategy == "noisy":
            from fmoe.gates import NoisyGate
            gate = NoisyGate
        elif args.balance_strategy == "gshard":
            from fmoe.gates import GShardGate
            gate = GShardGate
        elif args.balance_strategy == "switch":
            from fmoe.gates import SwitchGate
            gate = SwitchGate
        elif args.balance_strategy == "swipe":
            from fmoe.gates import SwipeGate
            gate = SwipeGate
        elif gate is None:
            assert False, f"Undefined balance strategy {args.balance_strategy}"

        super().__init__(
            args.fmoe_num_experts,
            top_k=args.top_k,
            d_model=args.hidden_size,
            d_hidden=args.hidden_hidden_size,
            world_size=world_size,
            moe_group=moe_group,
            expert_dp_comm="none" if args.distributed_experts else "dp",
            gate_hook=generate_megatron_gate_hook(
                layer_idx, args.fmoe_num_experts * world_size
            ),
            gate=gate,
        )
        self.hidden_size = args.hidden_size
        if args.distributed_experts:
            self.rank = args.rank
        else:
            self.rank = 0
        self.sigma = args.init_method_std
        self.num_layers = args.num_layers
        self.reset_parameters()

    def reset_parameters(self):
        r"""
        Initialize the expert weights in the same way as Megatron initializes
        its linear layers. As Megatron fixes the global random seed for its
        own purposes, a separate numpy RNG (offset by the worker rank) is used
        here.
        """
        rng = np.random.default_rng(np.random.randint(2048) + self.rank)
        
        # First linear layer of each expert (htoh4): N(0, sigma).
        if type(self.experts) is nn.ModuleList:
            for expert in self.experts:
                _megatron_init_method(expert.htoh4, rng, self.sigma)
        else:
            _megatron_init_method(self.experts.htoh4, rng, self.sigma)
        
        # Second linear layer of each expert (h4toh): Megatron scales the init
        # std of output projections by 1/sqrt(2 * num_layers).
        std = self.sigma / math.sqrt(2.0 * self.num_layers)
        
        if type(self.experts) is nn.ModuleList:
            for expert in self.experts:
                _megatron_init_method(expert.h4toh, rng, std)
        else:
            _megatron_init_method(self.experts.h4toh, rng, std)

    def forward(self, inp):
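        r"""
        Run the MoE MLP, all-reduce the result across the tensor model
        parallel group, and return a zero bias as the second element to match
        Megatron's MLP interface.
        """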
        from megatron import mpu
        x = super().forward(inp)
        x = mpu.reduce_from_tensor_model_parallel_region(x)
        return (
            x,
            torch.zeros(self.hidden_size, dtype=inp.dtype, device=inp.device),
        )


def fmoefy(
    model,
    fmoe_num_experts=None,
    distributed_experts=True,
    hidden_hidden_size=None,
    top_k=None,
    gate=None,
    megatron_version=None
):
    r"""
    Replace the MLP layers in a Megatron transformer model with MoE layers.
    * `model` should be a standard Megatron model that has
    `model.language_model.transformer.layers` (Megatron v2.2) or
    `model.language_model.encoder.layers` (v2.5 / v3.0.2) as its transformer
    layers, an array of transformer blocks that each contain an `mlp` member.
    * `distributed_experts` should be set to True if different experts are
    located on different workers. Otherwise, the experts on all workers are
    identical, and they are trained in data-parallel mode. This can be useful
    when testing on small models that do not require high training throughput
    or large parameter capacity.
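
    Example (a minimal sketch, assuming this module is importable as
    `fmoe.megatron` and that Megatron's arguments have already been parsed):

        from fmoe.megatron import fmoefy

        model = fmoefy(
            model,
            fmoe_num_experts=4,
            top_k=2,
            megatron_version="v2.2",
        )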
    """
    from megatron import get_args

    args = get_args()

    # Set distributed_experts to None to use default setting in args
    if distributed_experts is not None:
        args.distributed_experts = distributed_experts

    if fmoe_num_experts is not None:
        args.fmoe_num_experts = fmoe_num_experts
    assert (
        "fmoe_num_experts" in args
    ), "fmoe_num_experts should be specified in arguments or fmoefy function"

    if top_k is not None:
        args.top_k = top_k
    elif not hasattr(args, "top_k"):
        args.top_k = 2

    # Only override args.hidden_hidden_size when a value is passed in, so an
    # existing setting is preserved.
    if hidden_hidden_size is not None:
        args.hidden_hidden_size = hidden_hidden_size

    if megatron_version == "v2.2":

        for idx, l in enumerate(model.language_model.transformer.layers):
            l.mlp = MegatronMLP(args, idx, gate=gate)

        # initialize gate hook
        num_layers = len(model.language_model.transformer.layers)
    elif megatron_version in ["v2.5", "v3.0.2"]:
        
        for idx, l in enumerate(model.language_model.encoder.layers):
            l.mlp = MegatronMLP(args, idx, gate=gate)
        if hasattr(model.language_model, "decoder") and model.language_model.decoder is not None:
            for idx, l in enumerate(model.language_model.decoder.layers):
                l.mlp = MegatronMLP(args, idx, gate=gate)

        # initialize gate hook
        num_layers = len(model.language_model.encoder.layers)
        if hasattr(model.language_model, "decoder") and model.language_model.decoder is not None:
            num_layers += len(model.language_model.decoder.layers)
    else:
        print(model.language_model)
        assert False, f"megatron_version {megatron_version} not known."

    reset_gate_hook(num_layers)

    return model