test_numerical.py 12.4 KB
Newer Older
1
import sys
2
from collections import OrderedDict
3
from typing import List, Type, Union
Sengxian's avatar
Sengxian committed
4

5
import pytest
Rick Ho's avatar
Rick Ho committed
6
import torch
7
import torch.nn as nn
8
import numpy as np
9

10
from copy import deepcopy
11
12
13
from fmoe.gates import NaiveGate
from fmoe.layers import FMoE
from fmoe.transformer import _Expert
14
from fmoe.distributed import DistributedGroupedDataParallel as LocalDDP
15
from fmoe.megatron import _megatron_init_method
16
from moe import BruteForceMoELinear, BruteForceMoE, NaiveExpert, LinearExpert
Rick Ho's avatar
Rick Ho committed
17
18


def _perform_forward(
    moe: nn.Module, moe_raw: nn.Module, batch_size, d_model, top_k, rank, mp_group
):
    """Run one forward/backward pass through ``moe`` and its brute-force twin.

    A random ``(batch_size, d_model)`` input is drawn on the GPU. When
    ``mp_group`` is given, the input and the gate's linear weight/bias are
    broadcast from the first rank of this rank's model-parallel group, so all
    members of the group use identical data and routing. The same input
    (as an independent copy) is fed to ``moe`` directly and to ``moe_raw``
    through the gate indices/scores produced by ``moe.gate``.

    Args:
        moe: module under test; must expose ``moe.gate`` (returning
            ``(gate_idx, gate_score)``) and, when ``mp_group`` is set,
            ``moe.gate.gate.weight`` / ``moe.gate.gate.bias``.
        moe_raw: reference module, called as
            ``moe_raw(inp_repeated, gate_idx, gate_score)``.
        batch_size: number of rows in the random input.
        d_model: width of the random input.
        top_k: each input row is repeated ``top_k`` times for the reference.
        rank: global rank of this process, used to locate the group sender.
        mp_group: model-parallel process group, or ``None``.

    Returns:
        ``(moe_out, raw_out, moe_input_grad, raw_input_grad)``.
    """
    moe.zero_grad()
    moe_raw.zero_grad()

    # Both code paths draw the input identically; only the sync step differs.
    inp = torch.rand(batch_size, d_model).cuda()
    if mp_group:
        # Assumes model-parallel groups are contiguous runs of ranks, so the
        # group's first rank is ``rank`` rounded down to a multiple of its size.
        group_sender = rank // mp_group.size() * mp_group.size()
        torch.distributed.broadcast(inp, group_sender, group=mp_group)
        torch.distributed.broadcast(
            moe.gate.gate.weight.data, group_sender, group=mp_group
        )
        torch.distributed.broadcast(
            moe.gate.gate.bias.data, group_sender, group=mp_group
        )

    # Clone before enabling grad so both tensors are separate leaves and each
    # path accumulates its own input gradient.
    inp_raw = inp.clone()
    inp.requires_grad = True
    inp_raw.requires_grad = True

    gate_idx, gate_score = moe.gate(inp_raw)
    inp_repeated = inp_raw.repeat_interleave(repeats=top_k, dim=0)
    moe_out = moe(inp)
    raw_out = moe_raw(inp_repeated, gate_idx, gate_score)

    raw_out.mean().backward()
    moe_out.mean().backward()

    return moe_out, raw_out, inp.grad, inp_raw.grad


def _assert_numercial(names, moe_out_list, raw_out_list, rank, tol=1e-3):
    """Compare paired tensors and fail loudly on the first mismatch.

    For each ``(name, mo, ro)`` triple the summed absolute difference
    ``(mo - ro).abs().sum()`` is printed. If it exceeds ``tol``, both tensors
    are dumped to stderr and an ``AssertionError`` is raised.

    Args:
        names: labels used in log lines, one per compared pair.
        moe_out_list: tensors produced by the implementation under test.
        raw_out_list: reference tensors, aligned with ``moe_out_list``.
        rank: process rank, included in every message.
        tol: largest allowed summed absolute error (default ``1e-3``).

    Raises:
        ValueError: if the three sequences have different lengths.
        AssertionError: if any pair differs by more than ``tol``.
    """
    # zip() would silently drop trailing entries and skip comparisons.
    if not len(names) == len(moe_out_list) == len(raw_out_list):
        raise ValueError(
            f"length mismatch: {len(names)} names, {len(moe_out_list)} outputs, "
            f"{len(raw_out_list)} references"
        )
    for name, mo, ro in zip(names, moe_out_list, raw_out_list):
        err = (mo - ro).abs().sum()
        print("Rank {} {} abs err {}".format(rank, name, err))
        if err > tol:
            sys.stderr.write(f"=========== {name} moe out ==============\n")
            sys.stderr.write("{}\n".format(mo))
            sys.stderr.write(f"=========== {name} raw out ==============\n")
            sys.stderr.write("{}\n".format(ro))
            # Explicit raise: a bare ``assert False`` disappears under ``python -O``.
            raise AssertionError(
                f"Rank {rank} {name} abs err {err} exceeds tolerance {tol}"
            )


class MyMoE(FMoE):
    """FMoE layer with a ``NaiveGate`` and an ``_Expert`` MLP expert bank.

    After the base ``FMoE`` is built, ``self.experts`` is replaced by
    ``_Expert(num_expert, d_model, d_hidden, activation)``. Its ``htoh4`` and
    ``h4toh`` sub-layers are then re-initialised with
    ``_megatron_init_method``, using a NumPy generator seeded with 1234.
    NOTE(review): presumably this makes the initial expert weights
    deterministic; confirm against ``_megatron_init_method``.
    """

    def __init__(
        self, num_expert, d_model, d_hidden, world_size, mp_group, top_k, activation
    ):
        super().__init__(
            num_expert=num_expert,
            d_model=d_model,
            gate=NaiveGate,
            world_size=world_size,
            mp_group=mp_group,
            top_k=top_k,
        )
        self.experts = _Expert(num_expert, d_model, d_hidden, activation)

        # A single seeded generator is shared by both calls. The third argument
        # (1.0) is presumably a scale/sigma; verify in _megatron_init_method.
        rng = np.random.default_rng(1234)
        _megatron_init_method(self.experts.htoh4, rng, 1.0)
        _megatron_init_method(self.experts.h4toh, rng, 1.0)

@pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("top_k", [2, 3])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("d_hidden", [32])
@pytest.mark.parametrize("rank", [0])
@pytest.mark.parametrize("world_size", [1])
@pytest.mark.parametrize("mp_group", [None])
@pytest.mark.parametrize("dp_group", [None])
@pytest.mark.parametrize("world_group", [None])
def test_fmoe_linear(
    num_expert,
    top_k,
    batch_size,
    d_model,
    d_hidden,
    rank,
    world_size,
    mp_group,
    dp_group,
    world_group,
    activation=torch.nn.functional.gelu,
):
    """Check ``MyMoE`` against ``BruteForceMoELinear`` (outputs and gradients).

    The reference model receives a copy of the expert weights and biases,
    gathered from every rank when ``world_size > 1``. After one
    forward/backward pass this compares:

    * the outputs and input gradients;
    * the weight and bias gradients of both expert projections.

    In the distributed case the reference gradients are all-reduced, sliced
    to this rank's experts (``rank * num_expert`` up to
    ``(rank + 1) * num_expert``) and divided by the model-parallel group size.
    """
    torch.manual_seed(42 + rank)
    torch.cuda.manual_seed(42 + rank)

    moe = MyMoE(
        num_expert, d_model, d_hidden, world_size, mp_group, top_k, activation
    ).cuda()

    moe_raw = BruteForceMoELinear(
        activation=activation,
        num_expert=num_expert,
        d_model=d_model,
        d_hidden=d_hidden,
        world_size=world_size,
        top_k=top_k,
    ).cuda()

    def _gather_experts(local):
        # Concatenate every rank's local expert tensor along dim 0; on a single
        # process this is just a detached copy.
        if world_size == 1:
            return local.clone()
        chunks = [torch.empty_like(local) for _ in range(world_size)]
        torch.distributed.all_gather(chunks, local)
        return torch.cat(chunks, dim=0)

    # Keep this call order identical on every rank: all_gather is collective.
    moe_raw.weight_htoh4.data = _gather_experts(moe.experts.htoh4.weight.data)
    moe_raw.bias_htoh4.data = _gather_experts(moe.experts.htoh4.bias.data)
    moe_raw.weight_h4toh.data = _gather_experts(moe.experts.h4toh.weight.data)
    moe_raw.bias_h4toh.data = _gather_experts(moe.experts.h4toh.bias.data)

    moe_out, raw_out, moe_grad_in, raw_grad_in = _perform_forward(
        moe, moe_raw, batch_size, d_model, top_k, rank, mp_group
    )

    moe_out_list = [
        moe_out,
        moe_grad_in,
        moe.experts.htoh4.weight.grad,
        moe.experts.h4toh.weight.grad,
        moe.experts.htoh4.bias.grad,
        moe.experts.h4toh.bias.grad,
    ]
    raw_param_grads = [
        moe_raw.weight_htoh4.grad,
        moe_raw.weight_h4toh.grad,
        moe_raw.bias_htoh4.grad,
        moe_raw.bias_h4toh.grad,
    ]

    if world_size > 1:
        for grad in raw_param_grads:
            torch.distributed.all_reduce(grad)
        mp_size = mp_group.size() if mp_group else 1
        lo, hi = rank * num_expert, (rank + 1) * num_expert
        raw_param_grads = [grad[lo:hi] / mp_size for grad in raw_param_grads]

    raw_out_list = [raw_out, raw_grad_in, *raw_param_grads]

    names = [
        "output",
        "input grad",
        "htoh4 weight grad",
        "h4toh weight grad",
        "htoh4 bias grad",
        "h4toh bias grad",
    ]

    _assert_numercial(names, moe_out_list, raw_out_list, rank)

@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("top_k", [2, 3])
@pytest.mark.parametrize("expert", [NaiveExpert, LinearExpert])
@pytest.mark.parametrize("rank", [0])
@pytest.mark.parametrize("world_size", [1])
@pytest.mark.parametrize("mp_group", [None])
@pytest.mark.parametrize("dp_group", [None])
@pytest.mark.parametrize("world_group", [None])
def test_fmoe(
    batch_size,
    num_expert,
    d_model,
    top_k,
    expert: Union[Type[nn.Module], str],
    rank,
    world_size,
    mp_group,
    dp_group,
    world_group,
):
    """Check ``FMoE`` with a generic expert class against ``BruteForceMoE``.

    Expert parameters are copied into the reference model, gathered from every
    rank when ``world_size > 1``. After one forward/backward pass this compares:

    * the outputs;
    * the input gradients;
    * each expert's gradients, reduced to one scalar per expert.

    ``expert`` may be a class or the name of one defined in this module.
    """
    torch.manual_seed(42 + rank)
    torch.cuda.manual_seed(42 + rank)

    if isinstance(expert, str):
        # Presumably used by launchers that can only pass the class by name.
        expert = globals()[expert]

    moe = FMoE(
        num_expert=num_expert,
        d_model=d_model,
        gate=NaiveGate,
        world_size=world_size,
        mp_group=mp_group,
        expert=expert,
        top_k=top_k,
    ).cuda()

    moe_raw = BruteForceMoE(
        expert=expert,
        num_expert=num_expert,
        d_model=d_model,
        world_size=world_size,
        top_k=top_k,
    ).cuda()

    if world_size == 1:
        for expert_moe, expert_raw in zip(moe.experts, moe_raw.experts):
            for para_moe, para_raw in zip(
                expert_moe.parameters(), expert_raw.parameters()
            ):
                para_raw.data = para_moe.data.clone()
    else:
        assert len(moe.experts) >= 1
        # Materialise parameter lists once instead of once per index/expert.
        local_params = [list(e.parameters()) for e in moe.experts]
        raw_params = [list(e.parameters()) for e in moe_raw.experts]
        for idx in range(len(local_params[0])):
            # Stack the idx-th parameter of every local expert: (num_expert, ...).
            para_tensor = torch.stack([params[idx] for params in local_params])
            para_array = [torch.empty_like(para_tensor) for _ in range(world_size)]
            torch.distributed.all_gather(para_array, para_tensor)
            para_tensor_gathered = torch.cat(para_array, dim=0)
            assert para_tensor_gathered.shape[0] == len(moe_raw.experts)
            for expert_id in range(para_tensor_gathered.shape[0]):
                raw_params[expert_id][idx].data = para_tensor_gathered[expert_id]

    moe_out, raw_out, moe_grad_in, raw_grad_in = _perform_forward(
        moe, moe_raw, batch_size, d_model, top_k, rank, mp_group
    )

    def get_experts_grad(experts: List[nn.Module]):
        # One scalar per expert: the sum of all its parameter gradients, with
        # missing grads counted as 0. The placeholder must be 0-dim like
        # ``p.grad.sum()``; a shape-(1,) tensor would make torch.stack fail
        # when an expert mixes present and missing grads.
        return torch.stack(
            [
                torch.stack(
                    [
                        p.grad.sum() if p.grad is not None else torch.zeros(()).cuda()
                        for p in item.parameters()
                    ]
                ).sum()
                for item in experts
            ]
        )

    moe_grad, raw_grad = (
        get_experts_grad(moe.experts),
        get_experts_grad(moe_raw.experts),
    )

    if world_size > 1:
        torch.distributed.all_reduce(raw_grad)
        mp_size = mp_group.size() if mp_group else 1
        raw_grad = raw_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size

    moe_out_list = [moe_out, moe_grad, moe_grad_in]
    raw_out_list = [raw_out, raw_grad, raw_grad_in]
    names = ["forward", "backward", "grad_in"]

    _assert_numercial(names, moe_out_list, raw_out_list, rank)


class MyModule(nn.Module):
    """Three ``Linear(dim, dim)`` layers separated by ReLUs, for LocalDDP tests.

    The submodule names ``linear1``, ``relu1``, ``linear2``, ``relu2`` and
    ``linear3`` are part of the contract: callers look layers up by these keys
    in ``self.model._modules``.
    """

    # Communication group tag assigned to each linear layer's parameters.
    _COMM_BY_LAYER = (("linear1", "mp"), ("linear2", "dp"), ("linear3", "world"))

    def __init__(self, dim=8):
        super().__init__()
        layers = OrderedDict()
        layers["linear1"] = nn.Linear(dim, dim)
        layers["relu1"] = nn.ReLU()
        layers["linear2"] = nn.Linear(dim, dim)
        layers["relu2"] = nn.ReLU()
        layers["linear3"] = nn.Linear(dim, dim)
        self.model = nn.Sequential(layers)

    def set_comm(self):
        """Tag every linear parameter with a ``dp_comm`` attribute naming its group."""
        for layer_name, comm in self._COMM_BY_LAYER:
            for param in self.model._modules[layer_name].parameters():
                param.dp_comm = comm

    def forward(self, inp):
        """Apply the sequential stack to ``inp``."""
        return self.model(inp)


def _test_fmoe_local_ddp(rank, world_size, mp_group, dp_group, world_group):
    """Check LocalDDP gradient sync against a hand-reduced reference model.

    ``MyModule`` and a deep copy wrapped in ``LocalDDP`` are run on the same
    input. The plain model's gradients are then manually all-reduced and
    averaged over the group each layer is tagged with by ``set_comm``:

    * ``linear1`` over ``mp_group``;
    * ``linear2`` over ``dp_group``;
    * ``linear3`` over ``world_group``.

    The result is compared with what ``allreduce_params`` produces. Both weight
    and bias gradients are checked, since ``set_comm`` tags both.

    All three groups must be non-``None`` (``.size()`` is called on each).
    ``world_size`` is not used in this body.
    """
    batch_size, dim = 4, 8

    torch.manual_seed(42 + rank)
    torch.cuda.manual_seed(42 + rank)

    model = MyModule().cuda()
    model_ddp = LocalDDP(deepcopy(model), mp_group, dp_group, world_group)
    model.set_comm()
    model_ddp.module.set_comm()

    inp = torch.randn(batch_size, dim).cuda()

    raw_out = model(inp).mean()
    ddp_out = model_ddp(inp).mean()

    raw_out.backward()
    ddp_out.backward()

    # (layer name, process group, label prefix used in the report).
    layer_groups = [
        ("linear1", mp_group, "mp"),
        ("linear2", dp_group, "dp"),
        ("linear3", world_group, "wp"),
    ]

    # Reference reduction; every rank issues these collectives in the same order.
    for layer_name, group, _ in layer_groups:
        layer = model.model._modules[layer_name]
        for param in (layer.weight, layer.bias):
            torch.distributed.all_reduce(param.grad.data, group=group)
            param.grad /= group.size()
    model_ddp.allreduce_params(reduce_after=False, fp32_allreduce=False)

    raw_out_list, ddp_out_list, names = [], [], []
    for layer_name, _, prefix in layer_groups:
        raw_layer = model.model._modules[layer_name]
        ddp_layer = model_ddp.module.model._modules[layer_name]
        raw_out_list += [raw_layer.weight.grad, raw_layer.bias.grad]
        ddp_out_list += [ddp_layer.weight.grad, ddp_layer.bias.grad]
        names += [f"{prefix} grad", f"{prefix} bias grad"]

    _assert_numercial(names, ddp_out_list, raw_out_list, rank)


if __name__ == "__main__":
    # Single-process run of the linear-expert check with small shapes
    # (test_fmoe_linear moves its models to CUDA, so a GPU is required).
    test_fmoe_linear(
        batch_size=2,
        num_expert=2,
        d_model=2,
        top_k=2,
        d_hidden=16,
        rank=0,
        world_size=1,
        mp_group=None,
        dp_group=None,
        world_group=None,
    )