test_moe_fused_gate.py
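"""Test the moe_fused_gate CUDA kernel against the biased_grouped_topk reference."""
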
import pytest
import torch
from sgl_kernel import moe_fused_gate

from sglang.srt.layers.moe.topk import biased_grouped_topk


@pytest.mark.parametrize(
    "seq_length",
    list(range(1, 10))
    + [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536],
)
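# Each params tuple is (num_experts, num_expert_group, topk_group, topk).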
@pytest.mark.parametrize(
    "params",
    [
        (128, 4, 2, 4),
        (256, 8, 4, 8),  # deepseek v3
        (512, 16, 8, 16),
    ],
)
@pytest.mark.parametrize("num_fused_shared_experts", [0, 1, 2])
@pytest.mark.parametrize("apply_routed_scaling_factor_on_output", [True, False])
def test_moe_fused_gate_combined(
    seq_length, params, num_fused_shared_experts, apply_routed_scaling_factor_on_output
):
    num_experts, num_expert_group, topk_group, topk = params
    dtype = torch.float32

    torch.manual_seed(seq_length)
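    # Deterministic random scores and bias shared by the fused kernel and the reference path.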
    tensor = torch.rand((seq_length, num_experts), dtype=dtype, device="cuda")
    scores = tensor.clone()
    bias = torch.rand(num_experts, dtype=dtype, device="cuda")
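    # Fused shared experts occupy extra top-k slots, so widen topk accordingly.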
    topk = topk + num_fused_shared_experts

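    # Fused CUDA gating kernel under test.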
    output, indices = moe_fused_gate(
        tensor,
        bias,
        num_expert_group=num_expert_group,
        topk_group=topk_group,
        topk=topk,
        num_fused_shared_experts=num_fused_shared_experts,
        routed_scaling_factor=2.5,
        apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
    )
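    # Reference result from the Python biased_grouped_topk implementation.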
    ref_output, ref_indices = biased_grouped_topk(
        scores,
        scores,
        bias,
        topk=topk,
        renormalize=True,
        num_expert_group=num_expert_group,
        topk_group=topk_group,
        num_fused_shared_experts=num_fused_shared_experts,
        routed_scaling_factor=2.5,
        apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
    )

    # When num_fused_shared_experts > 0, ignore the comparison of the last topk dimension
    if num_fused_shared_experts > 0:
        original_indices = indices.clone()
        original_ref_indices = ref_indices.clone()

        indices = indices[:, :-1]
        ref_indices = ref_indices[:, :-1]

        valid_min = num_experts
        valid_max = num_experts + num_fused_shared_experts
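        # The last column must hold a shared-expert index, i.e. fall in [valid_min, valid_max).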
        shared_indices = original_indices[:, -1]
        shared_ref_indices = original_ref_indices[:, -1]
        if shared_indices is not None:
            assert torch.all(
                (shared_indices >= valid_min) & (shared_indices < valid_max)
            ), f"Shared expert indices out of range: found values outside [{valid_min}, {valid_max})"
        if shared_ref_indices is not None:
            assert torch.all(
                (shared_ref_indices >= valid_min) & (shared_ref_indices < valid_max)
            ), f"Shared expert reference indices out of range: found values outside [{valid_min}, {valid_max})"

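    # Ordering within a row may differ between the kernel and the reference, so compare sorted values.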
    idx_check = torch.allclose(
        ref_indices.sort()[0].to(torch.int32),
        indices.sort()[0].to(torch.int32),
        rtol=1e-04,
        atol=1e-05,
    )
    output_check = torch.allclose(
        ref_output.sort()[0].to(torch.float32),
        output.sort()[0].to(torch.float32),
        rtol=1e-02,
        atol=1e-03,
    )

    assert idx_check, (
        f"Indices mismatch at seq_length {seq_length}, dtype {dtype}, "
        f"params {params}, num_fused_shared_experts {num_fused_shared_experts}"
    )
    assert output_check, (
        f"Output mismatch at seq_length {seq_length}, dtype {dtype}, "
        f"params {params}, num_fused_shared_experts {num_fused_shared_experts}"
    )


if __name__ == "__main__":
    pytest.main([__file__])