import pytest
import torch
from sgl_kernel import cublas_grouped_gemm


def torch_grouped_gemm(a_array, b_array, out_dtype):
    # Reference implementation: per-group C_i = A_i @ B_i^T, cast to out_dtype.
    return [torch.matmul(a, b.t()).to(out_dtype) for a, b in zip(a_array, b_array)]


# Skip unless CUDA is available and the toolkit version is at least 12.5,
# which cublas_grouped_gemm requires.
skip_condition = not torch.cuda.is_available() or (
    torch.version.cuda is None
    or tuple(map(int, torch.version.cuda.split("."))) < (12, 5)
)


@pytest.mark.skipif(
    skip_condition, reason="CUDA not available or CUDA version lower than 12.5"
)
@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("M", [1, 16, 32, 256, 1024])
@pytest.mark.parametrize("N", [2, 16, 128, 256, 4096])
@pytest.mark.parametrize("K", [3, 16, 32, 512, 8192])
def test_grouped_gemm_accuracy(out_dtype, M, N, K):
    # Single-group accuracy check: A is (M, K), B is (N, K); the kernel output
    # should match torch.matmul(a, b.t()) within assert_close tolerances.
    a = torch.randn((M, K), device="cuda", dtype=out_dtype) * 5
    b = torch.randn((N, K), device="cuda", dtype=out_dtype) * 5
    expected = torch.matmul(a, b.t()).to(out_dtype)

    # cublas_grouped_gemm operates on lists of per-group operands; wrap the
    # single problem in one-element lists.
    a_array = [a]
    b_array = [b]
    c_array = [torch.empty((M, N), device="cuda", dtype=out_dtype)]

    result_torch = torch_grouped_gemm(a_array, b_array, out_dtype)[0]
    cublas_grouped_gemm(a_array, b_array, c_array, out_dtype)

    torch.testing.assert_close(result_torch, expected)
    torch.testing.assert_close(c_array[0], expected)
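

# The test above exercises a single group per call. The sketch below is a
# hedged extension that checks several groups of different shapes in one call;
# it assumes cublas_grouped_gemm accepts per-group (M_i, K_i) and (N_i, K_i)
# operands, mirroring torch_grouped_gemm. Group shapes are illustrative only.
@pytest.mark.skipif(
    skip_condition, reason="CUDA not available or CUDA version lower than 12.5"
)
@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16])
def test_grouped_gemm_multi_group_accuracy(out_dtype):
    shapes = [(8, 16, 32), (64, 128, 256), (5, 7, 9)]  # (M, N, K) per group
    a_array = [
        torch.randn((M, K), device="cuda", dtype=out_dtype) for M, _, K in shapes
    ]
    b_array = [
        torch.randn((N, K), device="cuda", dtype=out_dtype) for _, N, K in shapes
    ]
    c_array = [
        torch.empty((M, N), device="cuda", dtype=out_dtype) for M, N, _ in shapes
    ]

    expected = torch_grouped_gemm(a_array, b_array, out_dtype)
    cublas_grouped_gemm(a_array, b_array, c_array, out_dtype)

    for got, ref in zip(c_array, expected):
        torch.testing.assert_close(got, ref)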


if __name__ == "__main__":
    pytest.main([__file__])