# test_tilelang_autotune_with_inputs.py
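"""Tests for autotuning with user-provided input tensors.

Exercises ``tilelang.autotuner.set_autotune_inputs`` to supply concrete CUDA
tensors for profiling candidate configurations of a block-level matmul kernel,
covering both static and symbolic (dynamic) problem sizes.
"""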
import itertools
import logging
import tilelang
import tilelang.testing
from tilelang.autotuner import set_autotune_inputs
import tilelang.language as T

# Configure logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


def ref_program(A, B):
    """
    A reference matrix multiplication implementation, used to validate the
    output of the autotuned kernel.

    Parameters
    ----------
    A : torch.Tensor
        The matrix with shape (M, K).
    B : torch.Tensor
        The matrix with shape (N, K).

    Returns
    -------
    torch.Tensor
        The result of A @ B.T, shape (M, N).
    """
    return A @ B.T


def get_configs():
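    """
    Build the autotuning search space as the Cartesian product of the candidate
    values below, one dict per configuration, e.g.
    {'block_M': 64, 'block_N': 64, 'block_K': 32, 'num_stages': 0,
     'thread_num': 128, 'enable_rasterization': False}.
    """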
    iter_params = dict(
        block_M=[64],
        block_N=[64],
        block_K=[32],
        num_stages=[0, 1],
        thread_num=[128],
        enable_rasterization=[False])
    return [
        dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())
    ]


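# Autotune over the configs above; out_idx=[-1] marks the last tensor argument (C)
# as the output, so the compiled kernel allocates and returns it
# (hence `c = kernel(a, b)` in run_autotune below).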
@tilelang.autotune(configs=get_configs())
@tilelang.jit(out_idx=[-1])
def matmul(M,
           N,
           K,
           block_M=128,
           block_N=128,
           block_K=32,
           num_stages=0,
           thread_num=128,
           enable_rasterization=False):

    dtype = "float16"
    accum_dtype = "float"

    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((N, K), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        """
        The compiled TVM function for block-level matrix multiplication.

        - We divide the entire (M, N) domain into blocks of shape
            (block_M, block_N).
        - Each block has its own allocated shared memory for sub-blocks
            of A and B.
        - The partial results go into C_local, and then we copy them back
            to global memory C.
        """
        # Bind x-dimension to block index in N,
        #     y-dimension to block index in M.
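        # Example: for the 1024x1024x1024 test problem with block_M = block_N = 64
        # and block_K = 32, this launches a 16 x 16 grid of blocks, each iterating
        # over 1024 / 32 = 32 K-tiles.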
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):

            # Allocate shared memory for A sub-block of shape (block_M, block_K)
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            # Allocate shared memory for B sub-block of shape (block_N, block_K)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            # Allocate a local fragment for intermediate accumulation
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            # Enable (or disable) swizzling optimization
            T.use_swizzle(panel_size=10, enable=enable_rasterization)

            # Clear out the accumulation buffer
            T.clear(C_local)

            # Loop over sub-blocks in K dimension, pipelined by num_stages
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                # Load a sub-block of A from global memory into A_shared
                T.copy(
                    A[by * block_M, k * block_K],
                    A_shared,
                )
                # Load a sub-block of B from global memory into B_shared
                T.copy(
                    B[bx * block_N, k * block_K],
                    B_shared,
                )
                # Perform a partial matrix multiplication:
                #   C_local += A_shared @ B_shared^T
                T.gemm(
                    A_shared,
                    B_shared,
                    C_local,
                    transpose_B=True,
                )
            # Write back the results from C_local to the global memory C
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def run_autotune(M, N, K, M_value=None, N_value=None, K_value=None):
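    """
    Autotune the matmul kernel for an (M, N, K) problem and check correctness.

    M, N and K may be concrete integers or symbolic variables; symbolic
    dimensions must be given a concrete value via M_value/N_value/K_value so
    that the input tensors can be allocated. Random float16 CUDA tensors are
    passed to the autotuner through set_autotune_inputs, the tuned kernel is
    run on them, and the result is compared against ref_program.
    """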
    import torch

    def _resolve(dim, provided, name):
        if isinstance(dim, T.Var):
            if provided is None:
                raise ValueError(f"Dynamic dimension {name} requires a concrete value.")
            return provided
        return dim

    actual_M = _resolve(M, M_value, "M")
    actual_N = _resolve(N, N_value, "N")
    actual_K = _resolve(K, K_value, "K")

    a = torch.randn(actual_M, actual_K, dtype=torch.float16).cuda()
    b = torch.randn(actual_N, actual_K, dtype=torch.float16).cuda()

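    # Supply concrete input tensors for the autotuner to profile each candidate config.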
    with set_autotune_inputs([a, b]):
        kernel = matmul(M, N, K)

    c = kernel(a, b)

    ref_c = ref_program(a, b)
    torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)


def test_autotune_matmul():
    """
    Run the autotuning validation for the matmul kernel on a 1024x1024x1024 problem.

    This test constructs random CUDA tensors, autotunes the JIT-compiled block-level
    matrix-multiplication kernel, executes it, and asserts the result matches a PyTorch
    reference implementation (A @ B.T) within tolerances.
    """
    run_autotune(1024, 1024, 1024)


def test_autotune_matmul_symbolic_m():
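    """
    Autotune the matmul kernel with a symbolic (dynamic) M dimension,
    resolving it to 1024 when allocating the input tensors.
    """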
    run_autotune(T.symbolic("m"), 1024, 1024, M_value=1024)


if __name__ == "__main__":
    tilelang.testing.main()