test_numerical.py 12.6 KB
Newer Older
1
import sys
2
from collections import OrderedDict
3
from typing import List, Type, Union
Sengxian's avatar
Sengxian committed
4

5
import pytest
Rick Ho's avatar
Rick Ho committed
6
import torch
7
import torch.nn as nn
8
import numpy as np
9

10
from copy import deepcopy
11
12
13
from fmoe.gates import NaiveGate
from fmoe.layers import FMoE
from fmoe.transformer import _Expert
14
from fmoe.distributed import DistributedGroupedDataParallel as LocalDDP
Rick Ho's avatar
Rick Ho committed
15
from fmoe.megatron.layers import _megatron_init_method
16
from moe import BruteForceMoELinear, BruteForceMoE, NaiveExpert, LinearExpert
Rick Ho's avatar
Rick Ho committed
17
18


19
20
21
def _perform_forward(
    moe: nn.Module, moe_raw: nn.Module, batch_size, d_model, top_k, rank, mp_group
):
    """Run one forward/backward pass through the fused MoE and the reference.

    Feeds the same random input to ``moe`` (fused path) and ``moe_raw``
    (brute-force reference), backprops the mean of each output, and returns
    ``(moe_out, raw_out, moe_input_grad, raw_input_grad)`` for comparison.
    """
    moe.zero_grad()
    moe_raw.zero_grad()

    x = torch.rand(batch_size, d_model).cuda()
    if mp_group:
        # Every rank of a model-parallel group must see identical input and
        # identical gate parameters; broadcast from the group's first rank.
        sender = rank // mp_group.size() * mp_group.size()
        torch.distributed.broadcast(x, sender, group=mp_group)
        for tensor in (moe.gate.gate.weight.data, moe.gate.gate.bias.data):
            torch.distributed.broadcast(tensor, sender, group=mp_group)

    # Clone before flipping requires_grad so both inputs stay leaf tensors
    # with their own .grad.
    x_ref = x.clone()
    x.requires_grad = True
    x_ref.requires_grad = True

    # The reference module takes the gate decision explicitly and expects the
    # input repeated once per selected expert.
    topk_idx, topk_score, _ = moe.gate(x_ref)
    x_ref_repeated = x_ref.repeat_interleave(repeats=top_k, dim=0)

    out_fused = moe(x)
    out_ref = moe_raw(x_ref_repeated, topk_idx, topk_score)

    out_ref.mean().backward()
    out_fused.mean().backward()

    return out_fused, out_ref, x.grad, x_ref.grad
50
51


52
def _assert_numercial(names, moe_out_list, raw_out_list, rank):
53
54
55
56
    for name, mo, ro in zip(names, moe_out_list, raw_out_list):
        err = (mo - ro).abs().sum()
        print("Rank {} {} abs err {}".format(rank, name, err))
        if err > 1e-3:
Sengxian's avatar
Sengxian committed
57
            sys.stderr.write(f"=========== {name} moe out ==============\n")
58
            sys.stderr.write("{}\n".format(mo))
Sengxian's avatar
Sengxian committed
59
            sys.stderr.write(f"=========== {name} raw out ==============\n")
60
61
62
63
            sys.stderr.write("{}\n".format(ro))
            assert False


64
class MyMoE(FMoE):
    """FMoE with a NaiveGate and a two-layer FFN expert, whose expert weights
    are initialized with megatron's init method from a fixed-seed RNG."""

    def __init__(
        self, num_expert, d_model, d_hidden, world_size, mp_group, top_k, activation
    ):
        super().__init__(
            num_expert=num_expert,
            d_model=d_model,
            gate=NaiveGate,
            world_size=world_size,
            mp_group=mp_group,
            top_k=top_k,
        )
        self.experts = _Expert(num_expert, d_model, d_hidden, activation)
        # Fixed seed keeps expert initialization reproducible across runs.
        rng = np.random.default_rng(1234)
        for linear in (self.experts.htoh4, self.experts.h4toh):
            _megatron_init_method(linear, rng, 1.0)
81

82

83
@pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("top_k", [2, 3])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("d_hidden", [32])
@pytest.mark.parametrize("rank", [0])
@pytest.mark.parametrize("world_size", [1])
@pytest.mark.parametrize("mp_group", [None])
@pytest.mark.parametrize("dp_group", [None])
@pytest.mark.parametrize("world_group", [None])
def test_fmoe_linear(
    num_expert,
    top_k,
    batch_size,
    d_model,
    d_hidden,
    rank,
    world_size,
    mp_group,
    dp_group,
    world_group,
    activation=torch.nn.functional.gelu,
):
    """Numerically compare MyMoE (fused linear experts) against
    BruteForceMoELinear: forward output, input gradient, and the weight/bias
    gradients of both expert layers must agree.

    The pytest parametrization only covers the single-process case
    (world_size=1, all groups None); the world_size > 1 branches are
    exercised when this function is driven by an external distributed
    launcher — presumably via the __main__-style entry points; verify
    against the repo's distributed test harness.
    """
    torch.manual_seed(42 + rank)
    torch.cuda.manual_seed(42 + rank)

    moe = MyMoE(
        num_expert, d_model, d_hidden, world_size, mp_group, top_k, activation
    ).cuda()

    moe_raw = BruteForceMoELinear(
        activation=activation,
        num_expert=num_expert,
        d_model=d_model,
        d_hidden=d_hidden,
        world_size=world_size,
        top_k=top_k,
    ).cuda()

    if world_size == 1:
        # Single process: copy the fused model's expert weights directly
        # into the reference model.
        moe_raw.weight_htoh4.data = moe.experts.htoh4.weight.data.clone()
        moe_raw.bias_htoh4.data = moe.experts.htoh4.bias.data.clone()
        moe_raw.weight_h4toh.data = moe.experts.h4toh.weight.data.clone()
        moe_raw.bias_h4toh.data = moe.experts.h4toh.bias.data.clone()
    else:
        # Distributed: the reference model holds ALL experts, so gather each
        # rank's expert weights and concatenate along the expert dimension.
        weight_htoh4_array = [
            torch.empty_like(moe.experts.htoh4.weight.data) for _ in range(world_size)
        ]
        bias_htoh4_array = [
            torch.empty_like(moe.experts.htoh4.bias.data) for _ in range(world_size)
        ]
        torch.distributed.all_gather(weight_htoh4_array, moe.experts.htoh4.weight.data)
        torch.distributed.all_gather(bias_htoh4_array, moe.experts.htoh4.bias.data)
        moe_raw.weight_htoh4.data = torch.cat(weight_htoh4_array, dim=0)
        moe_raw.bias_htoh4.data = torch.cat(bias_htoh4_array, dim=0)

        weight_h4toh_array = [
            torch.empty_like(moe.experts.h4toh.weight.data) for _ in range(world_size)
        ]
        bias_h4toh_array = [
            torch.empty_like(moe.experts.h4toh.bias.data) for _ in range(world_size)
        ]
        torch.distributed.all_gather(weight_h4toh_array, moe.experts.h4toh.weight.data)
        torch.distributed.all_gather(bias_h4toh_array, moe.experts.h4toh.bias.data)
        moe_raw.weight_h4toh.data = torch.cat(weight_h4toh_array, dim=0)
        moe_raw.bias_h4toh.data = torch.cat(bias_h4toh_array, dim=0)

    moe_out, raw_out, moe_grad_in, raw_grad_in = _perform_forward(
        moe, moe_raw, batch_size, d_model, top_k, rank, mp_group
    )

    # Pair up every quantity that must match between the two implementations.
    moe_out_list = (
        moe_out,
        moe_grad_in,
        moe.experts.htoh4.weight.grad,
        moe.experts.h4toh.weight.grad,
        moe.experts.htoh4.bias.grad,
        moe.experts.h4toh.bias.grad,
    )
    raw_out_list = (
        raw_out,
        raw_grad_in,
        moe_raw.weight_htoh4.grad,
        moe_raw.weight_h4toh.grad,
        moe_raw.bias_htoh4.grad,
        moe_raw.bias_h4toh.grad,
    )

    if world_size > 1:
        # The reference model saw only this rank's shard of the data, so sum
        # its gradients across ranks, then slice out this rank's experts and
        # average over the model-parallel replicas.
        _, __, htoh4_w_grad, h4toh_w_grad, htoh4_b_grad, h4toh_b_grad = raw_out_list
        torch.distributed.all_reduce(htoh4_w_grad)
        torch.distributed.all_reduce(h4toh_w_grad)
        torch.distributed.all_reduce(htoh4_b_grad)
        torch.distributed.all_reduce(h4toh_b_grad)
        mp_size = mp_group.size() if mp_group else 1
        htoh4_w_grad = (
            htoh4_w_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
        )
        h4toh_w_grad = (
            h4toh_w_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
        )
        htoh4_b_grad = (
            htoh4_b_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
        )
        h4toh_b_grad = (
            h4toh_b_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
        )
        raw_out_list = _, __, htoh4_w_grad, h4toh_w_grad, htoh4_b_grad, h4toh_b_grad

    names = [
        "output",
        "input grad",
        "htoh4 weight grad",
        "h4toh weight grad",
        "htoh4 bias grad",
        "h4toh bias grad",
    ]
    _assert_numercial(names, moe_out_list, raw_out_list, rank)
202

Sengxian's avatar
Sengxian committed
203

204
205
206
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("top_k", [2, 3])
@pytest.mark.parametrize("expert", [NaiveExpert, LinearExpert])
@pytest.mark.parametrize("rank", [0])
@pytest.mark.parametrize("world_size", [1])
@pytest.mark.parametrize("mp_group", [None])
@pytest.mark.parametrize("dp_group", [None])
@pytest.mark.parametrize("world_group", [None])
def test_fmoe(
    batch_size,
    num_expert,
    d_model,
    top_k,
    expert: Union[Type[nn.Module], str],
    rank,
    world_size,
    mp_group,
    dp_group,
    world_group,
):
    """Numerically compare the generic FMoE layer against BruteForceMoE for
    arbitrary expert modules: forward output, per-expert summed parameter
    gradients, and input gradient must agree.

    ``expert`` may be a module class or its name as a string (resolved via
    globals()), so the test can be driven from a launcher that passes
    strings.
    """
    torch.manual_seed(42 + rank)
    torch.cuda.manual_seed(42 + rank)

    if isinstance(expert, str):
        expert = globals()[expert]

    moe = FMoE(
        num_expert=num_expert,
        d_model=d_model,
        gate=NaiveGate,
        world_size=world_size,
        mp_group=mp_group,
        expert=expert,
        top_k=top_k,
    ).cuda()

    moe_raw = BruteForceMoE(
        expert=expert,
        num_expert=num_expert,
        d_model=d_model,
        world_size=world_size,
        top_k=top_k,
    ).cuda()

    if world_size == 1:
        # Single process: copy every expert parameter into the reference.
        for expert_moe, expert_raw in zip(moe.experts, moe_raw.experts):
            for para_moe, para_raw in zip(
                expert_moe.parameters(), expert_raw.parameters()
            ):
                para_raw.data = para_moe.data.clone()
    else:
        # Distributed: the reference holds all world_size * num_expert
        # experts. For each parameter slot, stack this rank's experts,
        # all_gather across ranks, and scatter into the reference experts.
        assert len(moe.experts) >= 1
        for idx, para in enumerate(moe.experts[0].parameters()):
            para_tensor = torch.cat(
                [list(expert.parameters())[idx].unsqueeze(0) for expert in moe.experts]
            )
            para_array = [torch.empty_like(para_tensor) for _ in range(world_size)]
            torch.distributed.all_gather(para_array, para_tensor)
            para_tensor_gathered = torch.cat(para_array, dim=0)
            assert para_tensor_gathered.shape[0] == len(moe_raw.experts)
            for expertID in range(para_tensor_gathered.shape[0]):
                list(moe_raw.experts[expertID].parameters())[
                    idx
                ].data = para_tensor_gathered[expertID]

    moe_out, raw_out, moe_grad_in, raw_grad_in = _perform_forward(
        moe, moe_raw, batch_size, d_model, top_k, rank, mp_group
    )

    def get_experts_grad(experts: List[nn.Module]):
        # Collapse each expert's parameter gradients to a single scalar so
        # experts with arbitrary parameter shapes can be compared uniformly.
        return torch.stack(
            [
                torch.stack(
                    [
                        p.grad.sum() if p.grad is not None else torch.zeros(1).cuda()
                        for p in item.parameters()
                    ]
                ).sum()
                for item in experts
            ]
        )

    moe_grad, raw_grad = (
        get_experts_grad(moe.experts),
        get_experts_grad(moe_raw.experts),
    )

    if world_size > 1:
        # Sum reference gradients across ranks, then keep only this rank's
        # expert slice, averaged over model-parallel replicas.
        torch.distributed.all_reduce(raw_grad)
        mp_size = mp_group.size() if mp_group else 1
        raw_grad = raw_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size

    moe_out_list = [moe_out, moe_grad, moe_grad_in]
    raw_out_list = [raw_out, raw_grad, raw_grad_in]
    names = ["forward", "backward", "grad_in"]

    _assert_numercial(names, moe_out_list, raw_out_list, rank)
Sengxian's avatar
Sengxian committed
303
304


305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
class MyModule(nn.Module):
    """Small three-linear-layer MLP whose layers are tagged with distinct
    dp_comm attributes ("mp"/"dp"/"world") to exercise the grouped DDP
    gradient reduction paths."""

    def __init__(self, dim=8):
        super(MyModule, self).__init__()
        named_layers = [
            ("linear1", nn.Linear(dim, dim)),
            ("relu1", nn.ReLU()),
            ("linear2", nn.Linear(dim, dim)),
            ("relu2", nn.ReLU()),
            ("linear3", nn.Linear(dim, dim)),
        ]
        self.model = nn.Sequential(OrderedDict(named_layers))

    def set_comm(self):
        # Tag each linear layer's parameters with the communication group its
        # gradients should be reduced over.
        for layer_name, comm in (
            ("linear1", "mp"),
            ("linear2", "dp"),
            ("linear3", "world"),
        ):
            for p in self.model._modules[layer_name].parameters():
                setattr(p, "dp_comm", comm)

    def forward(self, inp):
        return self.model(inp)


def _test_fmoe_local_ddp(rank, world_size, mp_group, dp_group, world_group):
    """Check DistributedGroupedDataParallel's allreduce_params against
    manually reduced gradients.

    A plain copy of the model has its per-layer weight gradients all-reduced
    and averaged over the exact group each layer was tagged with; the DDP
    wrapper must produce the same gradients via allreduce_params.
    """
    batch_size, dim = 4, 8

    torch.manual_seed(42 + rank)
    torch.cuda.manual_seed(42 + rank)

    model = MyModule().cuda()
    # Deep-copy BEFORE tagging so both copies start from identical weights.
    model_ddp = LocalDDP(deepcopy(model), mp_group, dp_group, world_group)
    model.set_comm()
    model_ddp.module.set_comm()

    inp = torch.randn(batch_size, dim).cuda()

    loss_ref = model(inp).mean()
    loss_ddp = model_ddp(inp).mean()

    loss_ref.backward()
    loss_ddp.backward()

    # Manually average the reference model's gradients over the group each
    # layer is assigned to (same order as the tags set in set_comm).
    for layer_name, group in (
        ("linear1", mp_group),
        ("linear2", dp_group),
        ("linear3", world_group),
    ):
        grad = model.model._modules[layer_name].weight.grad
        torch.distributed.all_reduce(grad.data, group=group)
        grad /= group.size()

    model_ddp.allreduce_params(reduce_after=False, fp32_allreduce=False)

    ref_grads = [
        model.model._modules[name].weight.grad
        for name in ("linear1", "linear2", "linear3")
    ]
    ddp_grads = [
        model_ddp.module.model._modules[name].weight.grad
        for name in ("linear1", "linear2", "linear3")
    ]

    names = ["mp grad", "dp grad", "wp grad"]

    _assert_numercial(names, ddp_grads, ref_grads, rank)


381
382
if __name__ == "__main__":
    # Quick single-process smoke run of the linear-expert test with a tiny
    # configuration (no distributed groups).
    smoke_cfg = dict(
        batch_size=2,
        num_expert=2,
        d_model=2,
        top_k=2,
        d_hidden=16,
        rank=0,
        world_size=1,
        mp_group=None,
        dp_group=None,
        world_group=None,
    )
    test_fmoe_linear(**smoke_cfg)