# test_tilelang_autotune_with_inputs.py
import itertools
import logging
import tilelang
import tilelang.testing
from tilelang.autotuner import set_autotune_inputs
import tilelang.language as T

# Configure the module-level logger. Setting DEBUG here means DEBUG records
# sent to this logger are not filtered out at the logger level (handlers may
# still filter). NOTE(review): nothing visible in this file logs through
# `logger` — confirm it is still needed.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


def ref_program(A, B):
    """
    Reference matrix multiplication used to check the tuned kernel's output.

    Computes ``A @ B.T`` with the array library's own matmul. Works with any
    array type that supports ``@`` and ``.T``. In this module it is called
    with ``torch.Tensor`` inputs; ``numpy.ndarray`` works as well.

    Parameters
    ----------
    A : torch.Tensor or numpy.ndarray
        The matrix with shape (M, K).
    B : torch.Tensor or numpy.ndarray
        The matrix with shape (N, K).

    Returns
    -------
    torch.Tensor or numpy.ndarray
        The result of ``A @ B.T``, shape (M, N), of the same array type as
        the inputs.
    """
    return A @ B.T


def get_configs():
    """
    Build the autotuning search space for ``matmul``.

    Each entry of ``iter_params`` lists the candidate values for one keyword
    argument of ``matmul``. The Cartesian product of those lists becomes a
    list of keyword-argument dicts, one per configuration. With the values
    below there are two configurations, differing only in ``num_stages``.

    Returns
    -------
    list[dict]
        One dict per configuration, keyed by ``matmul`` parameter name.
    """
    iter_params = dict(
        block_M=[64],
        block_N=[64],
        block_K=[32],
        num_stages=[0, 1],
        thread_num=[128],
        enable_rasterization=[False],
    )
    keys = list(iter_params)
    return [dict(zip(keys, values)) for values in itertools.product(*iter_params.values())]


@tilelang.autotune(
    configs=get_configs(),
)
@tilelang.jit(out_idx=[-1])
def matmul(M, N, K, block_M=128, block_N=128, block_K=32, num_stages=0, thread_num=128, enable_rasterization=False):
    """
    Build a tiled block-level GEMM kernel computing ``C = A @ B.T``.

    Inputs and output are float16 and accumulation is float32. The function
    is wrapped by ``tilelang.jit(out_idx=[-1])``; ``run_autotune`` calls the
    resulting kernel as ``kernel(a, b)`` and receives ``C``. It is also
    wrapped by ``tilelang.autotune(configs=get_configs())``, which presumably
    sweeps the keyword arguments listed in ``get_configs()`` — confirm
    against the tilelang autotuner docs.

    Parameters
    ----------
    M, N, K : int or symbolic var
        GEMM problem sizes: A is (M, K), B is (N, K), C is (M, N).
    block_M, block_N, block_K : int
        Tile sizes along M, N and K.
    num_stages : int
        Pipelining depth for the K-loop (passed to ``T.Pipelined``).
    thread_num : int
        Threads per kernel block.
    enable_rasterization : bool
        Passed as ``enable`` to ``T.use_swizzle``.

    Returns
    -------
    The ``T.prim_func`` describing the kernel.
    """
    dtype = T.float16
    accum_dtype = T.float32

    @T.prim_func
    def main(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((N, K), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        """
        TileLang kernel body for block-level matrix multiplication.

        - The (M, N) output domain is divided into tiles of shape
            (block_M, block_N), one per kernel block.
        - Each block allocates shared-memory buffers for its sub-tiles
            of A and B.
        - Partial results accumulate in the C_local fragment, which is
            then copied back to global memory C.
        """
        # Bind x-dimension to block index in N,
        #     y-dimension to block index in M.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
            # Allocate shared memory for A sub-block of shape (block_M, block_K)
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            # Allocate shared memory for B sub-block of shape (block_N, block_K)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            # Allocate a local fragment for intermediate accumulation (fp32)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            # Enable (or disable) block swizzling with a panel size of 10
            T.use_swizzle(panel_size=10, enable=enable_rasterization)

            # Clear out the accumulation buffer
            T.clear(C_local)

            # Loop over sub-blocks in K dimension, pipelined by num_stages
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                # Load a sub-block of A from global memory into A_shared
                T.copy(
                    A[by * block_M, k * block_K],
                    A_shared,
                )
                # Load a sub-block of B from global memory into B_shared
                T.copy(
                    B[bx * block_N, k * block_K],
                    B_shared,
                )
                # Perform a partial matrix multiplication:
                #   C_local += A_shared @ B_shared^T
                T.gemm(
                    A_shared,
                    B_shared,
                    C_local,
                    transpose_B=True,
                )
            # Write back the results from C_local to the global memory C
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def run_autotune(M, N, K, M_value=None, N_value=None, K_value=None):
    """
    Autotune ``matmul`` on random fp16 CUDA inputs and check it against ``ref_program``.

    Parameters
    ----------
    M, N, K : int or T.Var
        Problem dimensions passed to ``matmul``. Any of them may be a symbolic
        variable (e.g. from ``T.symbolic``) to build a dynamic-shape kernel.
    M_value, N_value, K_value : int, optional
        Concrete sizes used to allocate the test tensors. Required for
        symbolic dimensions. For static dimensions they may be omitted; if
        given, they must equal the static size.

    Raises
    ------
    ValueError
        If a symbolic dimension has no concrete value, or a concrete value
        contradicts a static dimension.
    """
    import torch

    def _resolve(dim, provided, name):
        # Symbolic dims need a concrete size to allocate input tensors;
        # static dims are used as-is.
        if isinstance(dim, T.Var):
            if provided is None:
                raise ValueError(f"Dynamic dimension {name} requires a concrete value.")
            return provided
        if provided is not None and provided != dim:
            raise ValueError(f"Static dimension {name}={dim} conflicts with provided value {provided}.")
        return dim

    actual_M = _resolve(M, M_value, "M")
    actual_N = _resolve(N, N_value, "N")
    actual_K = _resolve(K, K_value, "K")

    # Allocate directly on the GPU rather than creating on CPU and copying.
    a = torch.randn(actual_M, actual_K, dtype=torch.float16, device="cuda")
    b = torch.randn(actual_N, actual_K, dtype=torch.float16, device="cuda")

    # Hand a and b to the autotuner as the inputs it uses while evaluating
    # candidate configurations (presumably instead of synthesizing its own).
    with set_autotune_inputs([a, b]):
        kernel = matmul(M, N, K)

    c = kernel(a, b)

    ref_c = ref_program(a, b)
    torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)


def test_autotune_matmul():
    """
    Run the autotuning validation for the matmul kernel on a 1024x1024x1024 problem.

    Builds random fp16 CUDA tensors and autotunes the JIT-compiled block-level
    matrix-multiplication kernel on them. It then executes the kernel and
    asserts the result matches ``ref_program`` (``A @ B.T`` evaluated by torch
    on the same CUDA tensors) within rtol=atol=1e-2.
    """
    run_autotune(1024, 1024, 1024)


def test_autotune_matmul_symbolic_m():
    """
    Autotune and validate the matmul kernel with a symbolic (dynamic) M.

    M is declared as the symbolic variable ``m`` when building the kernel.
    The test tensors are allocated with M = 1024; N and K are static at 1024.
    """
    dynamic_m = T.symbolic("m")
    run_autotune(dynamic_m, 1024, 1024, M_value=1024)


if __name__ == "__main__":
    # Run this module's tests through tilelang's testing entry point.
    tilelang.testing.main()