test_dsv3_fused_a_gemm.py 860 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import pytest
import torch
import torch.nn.functional as F
from sgl_kernel import dsv3_fused_a_gemm


@pytest.mark.parametrize("num_tokens", list(range(1, 17)))
def test_dsv3_fused_a_gemm(num_tokens):
    """Compare ``dsv3_fused_a_gemm`` with a ``torch.nn.functional.linear`` reference.

    For each ``num_tokens`` in 1..16, random bfloat16 inputs are created on the
    CUDA device, the fused kernel is invoked, and its result is checked against
    ``mat_a @ mat_b`` computed by PyTorch.

    Args:
        num_tokens: Number of rows in the activation matrix ``mat_a``.
    """
    # Problem dimensions used by this test (input / output hidden sizes).
    # NOTE(review): presumably these mirror constants hard-coded in the
    # kernel -- confirm against the kernel source.
    kHdIn = 7168
    kHdOut = 2112

    # torch.randn already returns a contiguous tensor, so no extra
    # .contiguous() call is needed.
    mat_a = torch.randn((num_tokens, kHdIn), dtype=torch.bfloat16, device="cuda")

    # Allocated as (kHdOut, kHdIn) and then transposed: the logical shape is
    # (kHdIn, kHdOut) but the storage is non-contiguous (a transposed view).
    # NOTE(review): presumably this is the layout the kernel expects for its
    # second operand -- confirm against the kernel documentation.
    mat_b = torch.randn(
        (kHdOut, kHdIn), dtype=torch.bfloat16, device="cuda"
    ).transpose(0, 1)

    # F.linear(x, w) computes x @ w.T, so passing mat_b.T yields mat_a @ mat_b.
    ref = F.linear(mat_a, mat_b.T)

    # The kernel returns its result; no pre-allocated output buffer is passed.
    output = dsv3_fused_a_gemm(mat_a, mat_b)

    assert torch.allclose(
        output, ref, rtol=1e-2, atol=1e-3
    ), "Fused GEMM output mismatch with torch.nn.functional.linear reference"


if __name__ == "__main__":
    # Propagate pytest's exit code so that running this file directly
    # reports test failures to the calling shell / CI instead of exiting 0.
    raise SystemExit(pytest.main([__file__]))