test_dp.py 1.24 KB
Newer Older
Rick Ho's avatar
Rick Ho committed
1
2
import os

Sengxian's avatar
Sengxian committed
3
4
import pytest
import torch
Rick Ho's avatar
Rick Ho committed
5

Sengxian's avatar
Sengxian committed
6
7
8
9
10
11
12
from fmoe.gates import NaiveGate
from fmoe.layers import FMoE
from fmoe.transformer import _Expert

# Number of GPUs handed to DataParallel below; read from the N_GPUS
# environment variable, falling back to 2 when it is not set.
n_devices = int(os.getenv("N_GPUS", "2"))


13
14
15
16
17
18
19
20
class MyMoE(FMoE):
    """FMoE subclass used by the data-parallel test in this module.

    The base layer is configured with ``NaiveGate`` as its gate,
    ``world_size=1`` and ``mp_group=None``; afterwards ``self.experts`` is
    assigned an ``_Expert`` built from the given sizes and activation.
    """

    def __init__(self, num_expert, d_model, d_hidden, top_k, activation):
        # Collect the base-class configuration in one place before
        # forwarding it, so the fixed settings are easy to spot.
        base_kwargs = dict(
            num_expert=num_expert,
            d_model=d_model,
            top_k=top_k,
            gate=NaiveGate,
            world_size=1,
            mp_group=None,
        )
        super().__init__(**base_kwargs)

        # Assigned after the base initializer has run, exactly as before.
        self.experts = _Expert(num_expert, d_model, d_hidden, activation)


Sengxian's avatar
Sengxian committed
26
27
28
29
30
31
32
33
34
35
36
37
38
@pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("top_k", [2, 3])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("d_hidden", [32])
def test_fmoe_dp(
    num_expert,
    top_k,
    batch_size,
    d_model,
    d_hidden,
    activation=torch.nn.functional.gelu,
):
    """Run a ``MyMoE`` layer wrapped in ``torch.nn.DataParallel``.

    The layer is moved to CUDA, wrapped for devices ``0 .. n_devices - 1``
    and fed five random ``(batch_size, d_model)`` batches. Nothing is
    asserted about the outputs: the test passes as long as every forward
    call completes without raising.
    """
    # Seed the default generator and the CUDA generator with the same value.
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(42)

    layer = MyMoE(num_expert, d_model, d_hidden, top_k, activation).cuda()
    device_ids = list(range(n_devices))
    parallel_layer = torch.nn.DataParallel(layer, device_ids=device_ids)

    for _ in range(5):
        # The batch is sampled first and then moved to CUDA.
        batch = torch.rand(batch_size, d_model).cuda()
        parallel_layer(batch)
47
48


Sengxian's avatar
Sengxian committed
49
if __name__ == "__main__":
    # Direct-run entry point: exercise a single configuration without pytest.
    test_fmoe_dp(num_expert=4, top_k=2, batch_size=4, d_model=16, d_hidden=32)