test_numerical.py 17.5 KB
Newer Older
1
import sys
2
from collections import OrderedDict
3
from typing import List, Type, Union
Sengxian's avatar
Sengxian committed
4

5
import pytest
Rick Ho's avatar
Rick Ho committed
6
import torch
7
import torch.nn as nn
8
import numpy as np
9

10
from copy import deepcopy
11
12
13
from fmoe.gates import NaiveGate
from fmoe.layers import FMoE
from fmoe.transformer import _Expert
14
from fmoe.distributed import DistributedGroupedDataParallel as LocalDDP
Rick Ho's avatar
Rick Ho committed
15
from fmoe.megatron.layers import _megatron_init_method
16
from moe import BruteForceMoELinear, BruteForceMoE, NaiveExpert, LinearExpert
Rick Ho's avatar
Rick Ho committed
17
18


19
def _perform_forward(
    moe: nn.Module, moe_raw: nn.Module, batch_size, d_model, top_k, rank, mp_group, data_type='torch.FloatTensor'
):
    """Run one forward/backward pass through both MoE implementations.

    Feeds the same random input to the fused ``moe`` and the brute-force
    reference ``moe_raw`` (the reference reuses ``moe``'s gate decision so
    both route tokens identically), backprops the mean of each output, and
    returns ``(moe_out, raw_out, inp.grad, inp_raw.grad)`` for comparison.
    ``top_k`` is accepted for signature compatibility but unused here.
    """
    for module in (moe, moe_raw):
        module.zero_grad()

    inp = torch.rand(batch_size, d_model).type(data_type).cuda()

    if mp_group is not None:
        # Every rank of a model-parallel group must see identical input and
        # gate parameters; broadcast them from the group's first rank.
        root = rank // mp_group.size() * mp_group.size()
        for tensor in (
            inp,
            moe.gate.gate.weight.data,
            moe.gate.gate.bias.data,
        ):
            torch.distributed.broadcast(tensor, root, group=mp_group)

    # Clone before enabling grad so each model gets an independent leaf.
    inp_raw = inp.clone()
    inp.requires_grad = True
    inp_raw.requires_grad = True

    # Evaluate the gate once and hand its decision to the reference model.
    gate_idx, gate_score = moe.gate(inp_raw)
    moe_out = moe(inp)
    raw_out = moe_raw(inp_raw, gate_idx, gate_score)

    raw_out.mean().backward()
    moe_out.mean().backward()

    return moe_out, raw_out, inp.grad, inp_raw.grad
49
50


51
def _assert_numerical(names, moe_out_list, raw_out_list, rank, precision=1e-3):
52
    for name, mo, ro in zip(names, moe_out_list, raw_out_list):
Rick Ho's avatar
Rick Ho committed
53
54
55
56
        if mo is None and ro is None:
            continue
        if mo is None or ro is None:
            assert False
57
        err = (mo - ro).abs().max()
Rick Ho's avatar
Rick Ho committed
58
        if err.dtype == torch.bfloat16 or err.dtype == torch.float16:
Rick Ho's avatar
Rick Ho committed
59
            precision *= 100
60
        print("Rank {} {} abs err {}".format(rank, name, err))
61
        if err > precision:
Sengxian's avatar
Sengxian committed
62
            sys.stderr.write(f"=========== {name} moe out ==============\n")
63
            sys.stderr.write("{}\n".format(mo))
zhanggzh's avatar
zhanggzh committed
64
            sys.stderr.write(f'------------->>>>>>>>>>>> mo dtype: {mo.dtype}\n')
Sengxian's avatar
Sengxian committed
65
            sys.stderr.write(f"=========== {name} raw out ==============\n")
66
            sys.stderr.write("{}\n".format(ro))
zhanggzh's avatar
zhanggzh committed
67
            sys.stderr.write(f'------------->>>>>>>>>>>> ro dtype: {ro.dtype}\n')
68
69
            sys.stderr.write(f"=========== {name} diff ==============\n")
            sys.stderr.write("{}\n{}\n".format((mo - ro).abs(), err))
70
71
72
            assert False


73
class MyMoE(FMoE):
    """FMoE with two-layer linear experts, used as the fused model in tests.

    Uses a NaiveGate and Megatron-style deterministic weight init so every
    rank starts from identical expert parameters.
    """

    def __init__(
        self, num_expert, d_model, d_hidden, world_size, mp_group, top_k, activation
    ):
        base_kwargs = dict(
            num_expert=num_expert,
            d_model=d_model,
            gate=NaiveGate,
            world_size=world_size,
            slice_group=mp_group,
            top_k=top_k,
        )
        super().__init__(**base_kwargs)
        self.experts = _Expert(num_expert, d_model, d_hidden, activation)
        # Fixed seed: initialization is reproducible across ranks and runs.
        rng = np.random.default_rng(1234)
        for linear in (self.experts.htoh4, self.experts.h4toh):
            _megatron_init_method(linear, rng, 1.0)
90

91

92
@pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("top_k", [2, 3])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("d_hidden", [32])
@pytest.mark.parametrize("rank", [0])
@pytest.mark.parametrize("world_size", [1])
@pytest.mark.parametrize("mp_group", [None])
@pytest.mark.parametrize("dp_group", [None])
@pytest.mark.parametrize("world_group", [None])
@pytest.mark.parametrize("data_type", ['torch.float32', 'torch.bfloat16', 'torch.float16'])
def test_fmoe_linear(
    num_expert,
    top_k,
    batch_size,
    d_model,
    d_hidden,
    rank,
    world_size,
    mp_group,
    dp_group,
    world_group,
    data_type,
    activation=torch.nn.functional.gelu,
):
    """Compare the fused MyMoE against BruteForceMoELinear.

    Checks forward output, input gradient, and every expert weight/bias
    gradient. ``dp_group``/``world_group`` are accepted for the shared
    multi-process launcher but unused in this test.
    """
    torch.manual_seed(42 + rank)
    torch.cuda.manual_seed(42 + rank)

    if isinstance(data_type, str):
        # Resolve e.g. 'torch.float16' to the torch dtype; avoids eval().
        data_type = getattr(torch, data_type.split(".", 1)[1])

    moe = MyMoE(
        num_expert, d_model, d_hidden, world_size, mp_group, top_k, activation
    ).type(data_type).cuda()

    moe_raw = BruteForceMoELinear(
        activation=activation,
        num_expert=num_expert,
        d_model=d_model,
        d_hidden=d_hidden,
        world_size=world_size,
        top_k=top_k,
    ).type(data_type).cuda()

    if world_size == 1:
        moe_raw.weight_htoh4.data = moe.experts.htoh4.weight.data.clone()
        moe_raw.bias_htoh4.data = moe.experts.htoh4.bias.data.clone()
        moe_raw.weight_h4toh.data = moe.experts.h4toh.weight.data.clone()
        moe_raw.bias_h4toh.data = moe.experts.h4toh.bias.data.clone()
    else:
        # The brute-force model holds all experts from all ranks: gather
        # every rank's expert parameters into it.
        weight_htoh4_array = [
            torch.empty_like(moe.experts.htoh4.weight.data) for _ in range(world_size)
        ]
        bias_htoh4_array = [
            torch.empty_like(moe.experts.htoh4.bias.data) for _ in range(world_size)
        ]
        torch.distributed.all_gather(weight_htoh4_array, moe.experts.htoh4.weight.data)
        torch.distributed.all_gather(bias_htoh4_array, moe.experts.htoh4.bias.data)
        moe_raw.weight_htoh4.data = torch.cat(weight_htoh4_array, dim=0)
        moe_raw.bias_htoh4.data = torch.cat(bias_htoh4_array, dim=0)

        weight_h4toh_array = [
            torch.empty_like(moe.experts.h4toh.weight.data) for _ in range(world_size)
        ]
        bias_h4toh_array = [
            torch.empty_like(moe.experts.h4toh.bias.data) for _ in range(world_size)
        ]
        torch.distributed.all_gather(weight_h4toh_array, moe.experts.h4toh.weight.data)
        torch.distributed.all_gather(bias_h4toh_array, moe.experts.h4toh.bias.data)
        moe_raw.weight_h4toh.data = torch.cat(weight_h4toh_array, dim=0)
        moe_raw.bias_h4toh.data = torch.cat(bias_h4toh_array, dim=0)

    moe_out, raw_out, moe_grad_in, raw_grad_in = _perform_forward(
        moe, moe_raw, batch_size, d_model, top_k, rank, mp_group, data_type=data_type
    )

    moe_out_list = (
        moe_out,
        moe_grad_in,
        moe.experts.htoh4.weight.grad,
        moe.experts.h4toh.weight.grad,
        moe.experts.htoh4.bias.grad,
        moe.experts.h4toh.bias.grad,
    )
    raw_out_list = (
        raw_out,
        raw_grad_in,
        moe_raw.weight_htoh4.grad,
        moe_raw.weight_h4toh.grad,
        moe_raw.bias_htoh4.grad,
        moe_raw.bias_h4toh.grad,
    )

    if world_size > 1:
        # The reference saw only this rank's tokens: sum gradients across
        # ranks, keep this rank's expert slice, average over mp replicas.
        _, __, htoh4_w_grad, h4toh_w_grad, htoh4_b_grad, h4toh_b_grad = raw_out_list
        torch.distributed.all_reduce(htoh4_w_grad)
        torch.distributed.all_reduce(h4toh_w_grad)
        torch.distributed.all_reduce(htoh4_b_grad)
        torch.distributed.all_reduce(h4toh_b_grad)
        mp_size = mp_group.size() if mp_group else 1
        htoh4_w_grad = (
            htoh4_w_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
        )
        h4toh_w_grad = (
            h4toh_w_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
        )
        htoh4_b_grad = (
            htoh4_b_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
        )
        h4toh_b_grad = (
            h4toh_b_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
        )
        raw_out_list = _, __, htoh4_w_grad, h4toh_w_grad, htoh4_b_grad, h4toh_b_grad

    names = [
        "output",
        "input grad",
        "htoh4 weight grad",
        "h4toh weight grad",
        "htoh4 bias grad",
        "h4toh bias grad",
    ]

    # BUG FIX: data_type was already resolved to a torch dtype above, so the
    # old comparison against the string 'torch.HalfTensor' never matched and
    # fp16 always ran with the strict 1e-3 tolerance. Compare the dtype.
    precision = 5e-1 if data_type == torch.float16 else 1e-3

    _assert_numerical(names, moe_out_list, raw_out_list, rank, precision=precision)
219

Sengxian's avatar
Sengxian committed
220

221
222
223
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("top_k", [2, 3])
@pytest.mark.parametrize("expert", [NaiveExpert, LinearExpert])
@pytest.mark.parametrize("rank", [0])
@pytest.mark.parametrize("world_size", [1])
@pytest.mark.parametrize("mp_group", [None])
@pytest.mark.parametrize("dp_group", [None])
@pytest.mark.parametrize("world_group", [None])
@pytest.mark.parametrize("data_type", [torch.float32, torch.float16, torch.bfloat16])
def test_fmoe(
    batch_size,
    num_expert,
    d_model,
    top_k,
    expert: Union[Type[nn.Module], str],
    rank,
    world_size,
    mp_group,
    dp_group,
    world_group,
    data_type
):
    """Compare a generic FMoE against the BruteForceMoE reference.

    Checks forward output, summed expert gradients, and the input gradient.
    ``dp_group``/``world_group`` are accepted for the shared multi-process
    launcher but unused here.
    """
    torch.manual_seed(42 + rank)
    torch.cuda.manual_seed(42 + rank)

    # The multi-process launcher may pass names instead of objects.
    if isinstance(expert, str):
        expert = globals()[expert]
        assert expert is not None
    if isinstance(data_type, str):
        data_type = eval(data_type)

    moe = FMoE(
        num_expert=num_expert,
        d_model=d_model,
        gate=NaiveGate,
        world_size=world_size,
        mp_group=mp_group,
        expert=expert,
        top_k=top_k,
    ).cuda().type(data_type)

    moe_raw = BruteForceMoE(
        expert=expert,
        num_expert=num_expert,
        d_model=d_model,
        world_size=world_size,
        top_k=top_k,
    ).cuda().type(data_type)

    if world_size == 1:
        # Single process: copy every expert parameter straight across.
        for src_expert, dst_expert in zip(moe.experts, moe_raw.experts):
            for src_p, dst_p in zip(src_expert.parameters(), dst_expert.parameters()):
                dst_p.data = src_p.data.clone()
    else:
        # Multi-process: the reference holds all experts of all ranks, so
        # gather each parameter slot across experts, then across ranks.
        assert len(moe.experts) >= 1
        for idx, para in enumerate(moe.experts[0].parameters()):
            assert para.device.type == 'cuda'
            para_tensor = torch.cat(
                [list(exp.parameters())[idx].unsqueeze(0) for exp in moe.experts]
            )
            assert para_tensor.device.type == 'cuda'
            para_array = [torch.empty_like(para_tensor) for _ in range(world_size)]
            torch.distributed.all_gather(para_array, para_tensor)
            para_tensor_gathered = torch.cat(para_array, dim=0)
            assert para_tensor_gathered.shape[0] == len(moe_raw.experts)
            for expert_id in range(para_tensor_gathered.shape[0]):
                params = list(moe_raw.experts[expert_id].parameters())
                params[idx].data = para_tensor_gathered[expert_id]

    moe_out, raw_out, moe_grad_in, raw_grad_in = _perform_forward(
        moe, moe_raw, batch_size, d_model, top_k, rank, mp_group, data_type
    )

    def get_experts_grad(experts: List[nn.Module]):
        # One scalar per expert: sum of all its parameter-gradient sums.
        per_expert = []
        for item in experts:
            sums = [
                p.grad.sum() if p.grad is not None else torch.zeros(1).cuda()
                for p in item.parameters()
            ]
            per_expert.append(torch.stack(sums).sum())
        return torch.stack(per_expert)

    moe_grad = get_experts_grad(moe.experts)
    raw_grad = get_experts_grad(moe_raw.experts)

    if world_size > 1:
        # Sum over ranks, keep this rank's experts, average over mp replicas.
        torch.distributed.all_reduce(raw_grad)
        mp_size = mp_group.size() if mp_group else 1
        raw_grad = raw_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size

    _assert_numerical(
        ["forward", "backward", "grad_in"],
        [moe_out, moe_grad, moe_grad_in],
        [raw_out, raw_grad, raw_grad_in],
        rank,
    )
Sengxian's avatar
Sengxian committed
327
328


329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
class MyModule(nn.Module):
    """Small three-linear-layer MLP used to exercise LocalDDP grad syncing.

    Each linear layer is tagged (via ``set_comm``) with a different
    ``dp_comm`` group name so the DDP wrapper reduces it over a different
    process group.
    """

    def __init__(self, dim=8):
        super(MyModule, self).__init__()
        layers = [
            ("linear1", nn.Linear(dim, dim)),
            ("relu1", nn.ReLU()),
            ("linear2", nn.Linear(dim, dim)),
            ("relu2", nn.ReLU()),
            ("linear3", nn.Linear(dim, dim)),
        ]
        self.model = nn.Sequential(OrderedDict(layers))

    def set_comm(self):
        # Tag each linear layer's parameters with the communication group
        # the DDP wrapper should all-reduce them over.
        comm_by_layer = {"linear1": "mp", "linear2": "dp", "linear3": "world"}
        for layer_name, comm in comm_by_layer.items():
            for p in self.model._modules[layer_name].parameters():
                setattr(p, "dp_comm", comm)

    def forward(self, inp):
        return self.model(inp)


def _test_fmoe_local_ddp(rank, world_size, mp_group, dp_group, world_group):
    """Check LocalDDP's per-group gradient all-reduce against a manual one.

    Runs the same model with and without the DDP wrapper, reduces the plain
    model's gradients by hand over each layer's designated group, then lets
    ``allreduce_params`` do it for the wrapped copy and compares. Intended
    to be launched under ``torch.distributed`` (not collected by pytest,
    hence the leading underscore).
    """
    batch_size, dim = 4, 8

    torch.manual_seed(42 + rank)
    torch.cuda.manual_seed(42 + rank)

    model = MyModule().cuda()
    model_ddp = LocalDDP(deepcopy(model),
            mp_group=mp_group, dp_group=dp_group, world_group=world_group)
    # Keep an unwrapped copy of exactly the weights the DDP wrapper holds.
    model = deepcopy(model_ddp.module)

    model.set_comm()
    model_ddp.module.set_comm()

    inp = torch.randn(batch_size, dim).cuda()

    raw_out = model(inp).mean()
    ddp_out = model_ddp(inp).mean()

    raw_out.backward()
    ddp_out.backward()

    # Manually average each layer's gradient over its designated group.
    for layer_name, group in (
        ("linear1", mp_group),
        ("linear2", dp_group),
        ("linear3", world_group),
    ):
        grad = model.model._modules[layer_name].weight.grad
        torch.distributed.all_reduce(grad.data, group=group)
        grad /= group.size()

    # Let the wrapper do the same for its copy.
    model_ddp.allreduce_params(reduce_after=False, fp32_allreduce=False)

    layer_names = ["linear1", "linear2", "linear3"]
    raw_out_list = [model.model._modules[n].weight.grad for n in layer_names]
    ddp_out_list = [
        model_ddp.module.model._modules[n].weight.grad for n in layer_names
    ]

    _assert_numerical(["mp grad", "dp grad", "wp grad"], ddp_out_list, raw_out_list, rank)
405
406


Colin's avatar
Colin committed
407
408
409
410
411
412
413
414
415
416
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("num_expert", [None])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("top_k", [2, 3])
@pytest.mark.parametrize("expert", [ [NaiveExpert for _ in range(4)], [LinearExpert, NaiveExpert, LinearExpert, NaiveExpert, LinearExpert, NaiveExpert, LinearExpert, NaiveExpert] ])
@pytest.mark.parametrize("rank", [0])
@pytest.mark.parametrize("world_size", [1])
@pytest.mark.parametrize("mp_group", [None])
@pytest.mark.parametrize("dp_group", [None])
@pytest.mark.parametrize("world_group", [None])
@pytest.mark.parametrize("data_type", [torch.float32])
def test_fmoe_experts(
    batch_size,
    num_expert,
    d_model,
    top_k,
    expert: Union[Type[nn.Module], str],
    rank,
    world_size,
    mp_group,
    dp_group,
    world_group,
    data_type
):
    """Compare FMoE built from a *list* of expert classes against BruteForceMoE.

    Same structure as ``test_fmoe`` but ``num_expert`` is ``None`` and the
    expert count comes from the length of the ``expert`` list. ``dp_group``
    and ``world_group`` are accepted for the shared launcher but unused.
    """
    torch.manual_seed(42 + rank)
    torch.cuda.manual_seed(42 + rank)

    if isinstance(expert, str):
        expert = globals()[expert]
    if isinstance(data_type, str):
        # Resolve e.g. 'torch.float32' to the torch dtype; avoids eval().
        data_type = getattr(torch, data_type.split(".", 1)[1])

    moe = FMoE(
        num_expert=num_expert,
        d_model=d_model,
        gate=NaiveGate,
        world_size=world_size,
        mp_group=mp_group,
        expert=expert,
        top_k=top_k,
    ).cuda().type(data_type)

    moe_raw = BruteForceMoE(
        expert=expert,
        num_expert=num_expert,
        d_model=d_model,
        world_size=world_size,
        top_k=top_k,
    ).cuda().to(data_type)

    if world_size == 1:
        for expert_moe, expert_raw in zip(moe.experts, moe_raw.experts):
            for para_moe, para_raw in zip(
                expert_moe.parameters(), expert_raw.parameters()
            ):
                para_raw.data = para_moe.data.clone()
    else:
        assert len(moe.experts) >= 1
        for idx, para in enumerate(moe.experts[0].parameters()):
            # BUG FIX: the old code iterated `expert.parameters()`, but
            # `expert` is the parametrized list of expert *classes* and has
            # no `.parameters()`; check the parameter's device instead,
            # matching test_fmoe.
            assert para.device.type == 'cuda'
            para_tensor = torch.cat(
                [list(exp.parameters())[idx].unsqueeze(0) for exp in moe.experts]
            )
            assert para_tensor.device.type == 'cuda'
            para_array = [torch.empty_like(para_tensor) for _ in range(world_size)]
            torch.distributed.all_gather(para_array, para_tensor)
            para_tensor_gathered = torch.cat(para_array, dim=0)
            assert para_tensor_gathered.shape[0] == len(moe_raw.experts)
            for expertID in range(para_tensor_gathered.shape[0]):
                list(moe_raw.experts[expertID].parameters())[
                    idx
                ].data = para_tensor_gathered[expertID]

    moe_out, raw_out, moe_grad_in, raw_grad_in = _perform_forward(
        moe, moe_raw, batch_size, d_model, top_k, rank, mp_group, data_type
    )

    def get_experts_grad(experts: List[nn.Module]):
        # One scalar per expert: sum of all its parameter-gradient sums.
        return torch.stack(
            [
                torch.stack(
                    [
                        p.grad.sum() if p.grad is not None else torch.zeros(1).cuda()
                        for p in item.parameters()
                    ]
                ).sum()
                for item in experts
            ]
        )

    moe_grad, raw_grad = (
        get_experts_grad(moe.experts),
        get_experts_grad(moe_raw.experts),
    )

    if world_size > 1:
        torch.distributed.all_reduce(raw_grad)
        mp_size = mp_group.size() if mp_group else 1
        # BUG FIX (latent): num_expert is parametrized as None in this test,
        # so the old `rank * num_expert` slice would TypeError if this path
        # ever ran; fall back to the local expert count.
        n_local = num_expert if num_expert is not None else len(moe.experts)
        raw_grad = raw_grad[rank * n_local : (rank + 1) * n_local] / mp_size

    moe_out_list = [moe_out, moe_grad, moe_grad_in]
    raw_out_list = [raw_out, raw_grad, raw_grad_in]
    names = ["forward", "backward", "grad_in"]

    _assert_numerical(names, moe_out_list, raw_out_list, rank)


515
if __name__ == "__main__":
    # Allow running this file directly (without pytest) as a quick smoke
    # test of one small single-process configuration.
    smoke_config = dict(
        batch_size=2,
        num_expert=2,
        d_model=2,
        top_k=2,
        expert=[NaiveExpert] * 4,
        rank=0,
        world_size=1,
        mp_group=None,
        dp_group=None,
        world_group=None,
        data_type=torch.bfloat16,
    )
    test_fmoe(**smoke_config)