# test_cublas_grouped_gemm.py
import unittest

import torch
from sgl_kernel import cublas_grouped_gemm


def torch_grouped_gemm(a_array, b_array, out_dtype):
    """Compute a reference grouped GEMM with plain ``torch.matmul``.

    Each pair ``(a, b)`` taken in lockstep from ``a_array`` and ``b_array``
    is multiplied as ``a @ b.T`` and the product is cast to ``out_dtype``.

    Args:
        a_array: Sequence of left-hand operand tensors.
        b_array: Sequence of right-hand operand tensors; each one is
            transposed with ``.t()`` before the multiplication.
        out_dtype: ``torch.dtype`` every result is converted to.

    Returns:
        A list with one result tensor per ``(a, b)`` pair, in input order.
    """
    return [
        torch.matmul(lhs, rhs.t()).to(out_dtype)
        for lhs, rhs in zip(a_array, b_array)
    ]


class TestGroupedGemm(unittest.TestCase):
    """Accuracy tests for ``cublas_grouped_gemm`` against a torch reference."""

    # Lowest CUDA toolkit version (major, minor) these tests are run on.
    MIN_CUDA_VERSION = (12, 5)

    @classmethod
    def setUpClass(cls):
        """Skip the whole class when the environment cannot run the tests.

        Enforcing the requirement here (rather than only when the file is
        executed as a script) means test runners that import and discover
        this class report a skip instead of an error on machines without a
        CUDA device or with a CUDA version below ``MIN_CUDA_VERSION``.

        Raises:
            unittest.SkipTest: If CUDA is unavailable, PyTorch reports no
                CUDA version, or the reported version is too old.
        """
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is not available")

        # torch.version.cuda is None when PyTorch was not built against CUDA.
        raw_version = torch.version.cuda
        if raw_version is None:
            raise unittest.SkipTest("PyTorch reports no CUDA version")

        cuda_version = tuple(map(int, raw_version.split(".")))
        if cuda_version < cls.MIN_CUDA_VERSION:
            raise unittest.SkipTest(
                f"Cuda version {cuda_version} lower than 12.5, not executing tests."
            )

    def _test_accuracy(self, Ms, Ns, Ks, out_dtype):
        """Check one grouped GEMM call against the torch reference.

        Group ``i`` multiplies a random ``(Ms[i], Ks[i])`` matrix by the
        transpose of a random ``(Ns[i], Ks[i])`` matrix, producing an
        ``(Ms[i], Ns[i])`` result. Inputs and outputs all use ``out_dtype``
        and live on the CUDA device.

        Args:
            Ms: Row count of the left operand, one entry per group.
            Ns: Row count of the right operand, one entry per group.
            Ks: Shared inner dimension, one entry per group.
            out_dtype: ``torch.dtype`` used for inputs and outputs.
        """
        # The three lists describe the same groups, so they must line up.
        self.assertEqual(len(Ms), len(Ns))
        self.assertEqual(len(Ms), len(Ks))

        a_array = []
        b_array = []
        c_array_cublas = []
        for M, N, K in zip(Ms, Ns, Ks):
            a_array.append(torch.randn((M, K), device="cuda", dtype=out_dtype) * 5)
            b_array.append(torch.randn((N, K), device="cuda", dtype=out_dtype) * 5)
            # Output buffers are preallocated and handed to the kernel call.
            c_array_cublas.append(torch.empty((M, N), device="cuda", dtype=out_dtype))

        c_array_torch = torch_grouped_gemm(a_array, b_array, out_dtype)
        cublas_grouped_gemm(a_array, b_array, c_array_cublas, out_dtype)

        for M, N, K, expected, actual in zip(
            Ms, Ns, Ks, c_array_torch, c_array_cublas
        ):
            torch.testing.assert_close(expected, actual)
            print(f"M={M}, N={N}, K={K}, out_dtype={out_dtype}: OK")

    def test_accuracy(self):
        """Run the accuracy check for float16 and bfloat16 over mixed sizes."""
        Ms = [1, 16, 32, 256, 1024]
        Ns = [2, 16, 128, 256, 4096]
        Ks = [3, 16, 32, 512, 8192]
        out_dtypes = [torch.float16, torch.bfloat16]
        for out_dtype in out_dtypes:
            # subTest keeps a failure for one dtype from hiding the other.
            with self.subTest(out_dtype=out_dtype):
                self._test_accuracy(Ms, Ns, Ks, out_dtype)


if __name__ == "__main__":
    # Run the tests only on a CUDA build of PyTorch with CUDA >= 12.5 and a
    # usable device; otherwise say why nothing was executed.
    if not torch.cuda.is_available():
        print("CUDA is not available, not executing tests.")
    elif torch.version.cuda is None:
        # torch.version.cuda is None when PyTorch was not built against CUDA,
        # so there is no version string to parse.
        print("PyTorch reports no CUDA version, not executing tests.")
    else:
        cuda_version = tuple(map(int, torch.version.cuda.split(".")))
        if cuda_version >= (12, 5):
            unittest.main()
        else:
            print(f"Cuda version {cuda_version} lower than 12.5, not executing tests.")