import itertools
import logging
import tilelang
import tilelang.testing
from tilelang.autotuner import set_autotune_inputs
import tilelang.language as T

# Configure logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


def ref_program(A, B):
    """
    A reference matrix multiplication program, used to compare performance.

    Parameters
    ----------
    A : numpy.ndarray
        The matrix with shape (M, K).
    B : numpy.ndarray
        The matrix with shape (N, K).

    Returns
    -------
    np.ndarray
        The result of A @ B.T, shape (M, N).
    """
    return A @ B.T


def get_configs():
    """Return the autotuning search space as a list of config dicts."""
    iter_params = dict(
        block_M=[64],
        block_N=[64],
        block_K=[32],
        num_stages=[0, 1],
        thread_num=[128],
        enable_rasterization=[False])
    return [{
        k: v for k, v in zip(iter_params, values)
    } for values in itertools.product(*iter_params.values())]
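
# With the search space above, the Cartesian product expands to exactly two
# candidate configs (only num_stages varies):
#   {"block_M": 64, "block_N": 64, "block_K": 32, "num_stages": 0,
#    "thread_num": 128, "enable_rasterization": False}
#   {"block_M": 64, "block_N": 64, "block_K": 32, "num_stages": 1,
#    "thread_num": 128, "enable_rasterization": False}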


@tilelang.autotune(configs=get_configs())
@tilelang.jit(out_idx=[-1])
def matmul(M,
           N,
           K,
           block_M=128,
           block_N=128,
           block_K=32,
           num_stages=0,
           thread_num=128,
           enable_rasterization=False):
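    """
    Build an autotunable, JIT-compiled block-level matmul kernel.

    Parameters
    ----------
    M, N, K : int
        Problem sizes: A is (M, K), B is (N, K), C is (M, N).
    block_M, block_N, block_K : int
        Tile sizes along the M, N and K dimensions.
    num_stages : int
        Number of pipeline stages for the K-loop.
    thread_num : int
        Number of threads per block.
    enable_rasterization : bool
        Whether to enable block swizzling (rasterization).

    Returns
    -------
    The T.prim_func computing C = A @ B.T for this configuration.
    """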

    dtype = "float16"
    accum_dtype = "float"

    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((N, K), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        """
        The compiled TVM function for block-level matrix multiplication.

        - We divide the entire (M, N) domain into blocks of shape
            (block_M, block_N).
        - Each block has its own allocated shared memory for sub-blocks
            of A and B.
        - The partial results go into C_local, and then we copy them back
            to global memory C.
        """
        # Bind x-dimension to block index in N,
        #     y-dimension to block index in M.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):

            # Allocate shared memory for A sub-block of shape (block_M, block_K)
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            # Allocate shared memory for B sub-block of shape (block_N, block_K)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            # Allocate a local fragment for intermediate accumulation
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            # Enable (or disable) swizzling optimization
            T.use_swizzle(panel_size=10, enable=enable_rasterization)

            # Clear out the accumulation buffer
            T.clear(C_local)

            # Loop over sub-blocks in K dimension, pipelined by num_stages
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                # Load a sub-block of A from global memory into A_shared
                T.copy(
                    A[by * block_M, k * block_K],
                    A_shared,
                )
                # Load a sub-block of B from global memory into B_shared
                T.copy(
                    B[bx * block_N, k * block_K],
                    B_shared,
                )
                # Perform a partial matrix multiplication:
                #   C_local += A_shared @ B_shared^T
                T.gemm(
                    A_shared,
                    B_shared,
                    C_local,
                    transpose_B=True,
                )
            # Write back the results from C_local to the global memory C
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def run_autotune(M: int, N: int, K: int):
    """Autotune the matmul kernel with concrete inputs and verify its result."""
    import torch
    a = torch.randn(M, K, dtype=torch.float16).cuda()
    b = torch.randn(N, K, dtype=torch.float16).cuda()

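    # set_autotune_inputs supplies these concrete tensors to the autotuner, so each
    # candidate config is profiled against the same real inputs; matmul(M, N, K)
    # then returns the best-performing compiled kernel found during the search.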
    with set_autotune_inputs([a, b]):
        kernel = matmul(M, N, K)

    c = kernel(a, b)

    ref_c = ref_program(a, b)
    torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)


def test_autotune_matmul():
    """
    Run the autotuning validation for the matmul kernel on a 1024x1024x1024 problem.
    
    This test constructs random CUDA tensors, autotunes the JIT-compiled block-level matrix-multiplication kernel,
    executes it, and asserts the result matches a reference CPU implementation within tolerances.
    """
    run_autotune(1024, 1024, 1024)


if __name__ == "__main__":
    tilelang.testing.main()