import pytest
import torch
import torch.nn.functional as F
from sgl_kernel import dsv3_router_gemm


@pytest.mark.parametrize("num_tokens", [i + 1 for i in range(16)])
def test_dsv3_router_gemm(num_tokens):
    """Check ``dsv3_router_gemm`` against ``torch.nn.functional.linear``.

    Runs the DeepSeek-V3 router GEMM kernel for token counts 1..16 with the
    fixed DSV3 router shape (256 experts, hidden dim 7168) and compares both
    supported output dtypes (bf16 and fp32) against a reference computed with
    ``F.linear`` on bf16 inputs.

    Requires a CUDA device and the ``sgl_kernel`` extension.
    """
    # Fixed DeepSeek-V3 router dimensions (see mat_b shape below).
    num_experts = 256
    hidden_dim = 7168

    # torch.randn already returns contiguous tensors, so no explicit
    # .contiguous() call is needed here.
    mat_a = torch.randn((num_tokens, hidden_dim), dtype=torch.bfloat16, device="cuda")
    mat_b = torch.randn((num_experts, hidden_dim), dtype=torch.bfloat16, device="cuda")

    # Reference: F.linear(a, b) computes a @ b.T, i.e. per-token expert logits.
    bf16_ref = F.linear(mat_a, mat_b)
    float_ref = bf16_ref.to(torch.float32)

    # Kernel under test, once per supported output dtype.
    bf16_output = dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.bfloat16)
    float_output = dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.float32)

    # Loose tolerances (rtol=1e-2) account for bf16 accumulation differences
    # between the custom kernel and the cuBLAS-backed reference.
    assert torch.allclose(
        bf16_output, bf16_ref, rtol=1e-2, atol=1e-3
    ), "Router GEMM output in bf16 dtype mismatch with torch.nn.functional.linear reference"

    assert torch.allclose(
        float_output, float_ref, rtol=1e-2, atol=1e-3
    ), "Router GEMM output in float32 dtype mismatch with torch.nn.functional.linear reference"


# Allow running this file directly (`python test_dsv3_router_gemm.py`)
# instead of through the pytest CLI.
if __name__ == "__main__":
    pytest.main([__file__])