"""Test for the tilelang cached matmul kernel (test_tilelang_cache_matmul.py)."""
from tilelang import tvm as tvm
import tilelang.testing
from tilelang.cache import cached
import tilelang.language as T


def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    """
    Defines a matrix multiplication primitive function using tilelang.

    This function constructs a tilelang primitive function for matrix multiplication,
    optimized for execution on hardware accelerators. It utilizes shared memory and
    fragment memory for performance.

    Args:
        M (int): Number of rows in matrix A and C.
        N (int): Number of columns in matrix B and C.
        K (int): Number of columns in matrix A and rows in matrix B.
        block_M (int): Block size for M dimension in shared memory and fragment.
        block_N (int): Block size for N dimension in shared memory and fragment.
        block_K (int): Block size for K dimension in shared memory.
        dtype (str, optional): Data type for input matrices A and B, and output C. Defaults to "float16".
        accum_dtype (str, optional): Accumulation data type for internal computations. Defaults to "float".

    Returns:
        T.PrimFunc: A tilelang primitive function representing the matrix multiplication.
    """

    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        # Launch one kernel block per (block_N x block_M) output tile, 128 threads each.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            # Accumulator must start at zero before the K-reduction loop.
            T.clear(C_local)

            # Pipelined reduction over K tiles; num_stages=3 lets copies overlap compute.
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)

            # Write the accumulated tile back to global memory.
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def run_cache_matmul():
    """
    Demonstrates the usage of the cached matrix multiplication kernel.

    This function defines a reference PyTorch matrix multiplication,
    creates a cached kernel from the tilelang matmul function,
    runs the kernel with random input tensors, compares the output with the reference,
    and prints the CUDA kernel source code.

    Raises:
        AssertionError: If the kernel output does not match the reference
            within the given tolerances.
    """

    def ref_program(A, B):
        """
        Reference PyTorch matrix multiplication for comparison.
        """
        import torch
        # Compute in float32 for accuracy, then cast back to the kernel's dtype.
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(torch.half)  # Assuming dtype="float16" in matmul
        return C

    # 1024^3 GEMM with 128x128x32 tiling; output index [2] marks C as the result.
    func = matmul(1024, 1024, 1024, 128, 128, 32)
    kernel = cached(func, [2], execution_backend="cython")
    import torch

    a = torch.randn(1024, 1024).cuda().half()
    b = torch.randn(1024, 1024).cuda().half()

    c = kernel(a, b)
    print("\nOutput from Cached Kernel:")
    print(c)

    ref_c = ref_program(a, b)
    print("\nReference PyTorch Output:")
    print(ref_c)

    # Loose tolerances: fp16 inputs with float accumulation vs fp32 reference.
    torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
    print("\nOutputs are close (within tolerance).")

    # Get CUDA Source
    print("\nCUDA Kernel Source:")
    print(kernel.get_kernel_source())


def test_cache_matmul_f16f16f16_nn():
    """
    Test function for cached matrix multiplication (float16 inputs, float16 output, no transpose).
    """
    run_cache_matmul()


if __name__ == "__main__":
    # Discover and run the test functions defined in this file.
    tilelang.testing.main()