test_dp.py 1.08 KB
Newer Older
Rick Ho's avatar
Rick Ho committed
1
2
import os

Sengxian's avatar
Sengxian committed
3
4
import pytest
import torch
Rick Ho's avatar
Rick Ho committed
5

Sengxian's avatar
Sengxian committed
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from fmoe.gates import NaiveGate
from fmoe.layers import FMoE
from fmoe.transformer import _Expert

# Number of GPUs the DataParallel wrapper is spread over; override with the
# N_GPUS environment variable (defaults to 2 when unset).
n_devices = int(os.getenv("N_GPUS", "2"))


@pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("top_k", [2, 3])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("d_hidden", [32])
def test_fmoe_dp(
    num_expert,
    top_k,
    batch_size,
    d_model,
    d_hidden,
    activation=torch.nn.functional.gelu,
):
    """Run an FMoE layer wrapped in ``torch.nn.DataParallel``.

    Builds a single-process (``world_size=1``, no model-parallel group) FMoE
    layer using ``NaiveGate`` and a bank of ``_Expert`` modules. It then
    replicates the layer over ``n_devices`` GPUs with DataParallel and runs
    several forward passes on random input, checking the output each time.

    The test is skipped when CUDA is unavailable or fewer than ``n_devices``
    GPUs are visible, because DataParallel cannot be built over devices that
    do not exist.
    """
    if not torch.cuda.is_available() or torch.cuda.device_count() < n_devices:
        pytest.skip(f"requires {n_devices} CUDA devices (set N_GPUS to override)")

    # Fixed seeds keep expert initialisation and the random inputs reproducible.
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)

    experts = _Expert(num_expert, d_model, d_hidden, activation).cuda()

    # NOTE(review): FMoE supplies the second argument when it calls expert_fn;
    # the name ``gate`` is kept from the original — confirm what it carries.
    def expert_fn(inp, gate):
        return experts(inp, gate)

    moe = FMoE(
        num_expert=num_expert,
        d_model=d_model,
        gate=NaiveGate,
        world_size=1,
        mp_group=None,
        expert_fn=expert_fn,
        top_k=top_k,
    ).cuda()

    # DataParallel requires the wrapped module to live on device_ids[0];
    # the .cuda() call above places it on the current (default) device.
    moe_dp = torch.nn.DataParallel(moe, device_ids=list(range(n_devices)))

    for _ in range(5):
        output = moe_dp(torch.rand(batch_size, d_model).cuda())
        # DataParallel gathers replica outputs along dim 0. The MoE layer is
        # expected to return one d_model-wide row per input sample, so the
        # gathered result should match the input's shape.
        assert output.shape == (batch_size, d_model)
        assert torch.isfinite(output).all()