"ts/webui/src/static/style/common/trialStatus.css" did not exist on "b2b4f4581deafebdd172c18e91ece060a6d62e47"
example_gemm_autotune.py 9.4 KB
Newer Older
1
2
3
4
5
6
7
import argparse
import itertools

import torch

import tilelang as tl
import tilelang.language as T
from tilelang.autotuner import AutoTuner
from tilelang.carver.arch import CDNA, CUDA
from tilelang.carver.roller.rasterization import NoRasterization
from tilelang.carver.template import MatmulTemplate


def ref_program(A, B):
    """
    Compute the matrix product of A and the transpose of B.

    A and B are expected to be 2-D tensors where A has shape (M, K) and B has shape (N, K).
    The result is a tensor with shape (M, N) equal to A @ B.T, using the inputs' dtypes.
    """
    return A @ B.T


def get_configs(M, N, K, with_roller=False, topk=20):
    """
    Generate a list of kernel tuning configuration dictionaries for a tiled matrix multiply.

    When with_roller is True, this queries the MatmulTemplate roller to produce up to `topk`
    recommended configurations (device-specific, TensorCore-friendly tilings). When with_roller
    is False, it returns the Cartesian product of a fixed grid of candidate parameters.

    Each returned dict contains:
      - block_M, block_N, block_K: tile sizes
      - num_stages: number of pipeline stages (0 means no explicit software pipelining)
      - thread_num: total threads per block
      - enable_rasteration: whether to enable rasterization/swizzling (this spelling is kept
        for backward compatibility with the kernel parameter name)

    Parameters:
        M, N, K (int): GEMM dimensions used to generate valid tile sizes.
        with_roller (bool): If True, use MatmulTemplate's roller to generate device-aware hints;
            otherwise use a predefined candidate grid.
        topk (int): Maximum number of roller hints to request when with_roller is True.

    Returns:
        List[dict]: A list of configuration dictionaries as described above.

    Raises:
        ValueError: If with_roller is True but the roller returns no hints.
    """
    if with_roller:
        arch = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
        carve_template = MatmulTemplate(
            M=M,
            N=N,
            K=K,
            in_dtype="float16",
            out_dtype="float16",
            accum_dtype="float",
        ).with_arch(arch)

        func = carve_template.equivalent_function()
        assert func is not None, "Function is None"
        roller_hints = carve_template.recommend_hints(topk=topk)
        if roller_hints is None:
            raise ValueError("No Roller Hints Found for TensorCore Scheduling")
        configs = []
        for hint in roller_hints:
            config = {}
            block_m, block_n = hint.block
            warp_m, warp_n = hint.warp
            # block_rows, block_cols represent the warp partitioning within the block
            block_rows, block_cols = block_m // warp_m, block_n // warp_n
            config["block_M"] = block_m
            config["block_N"] = block_n
            config["block_K"] = hint.rstep[0]
            config["num_stages"] = hint.pipeline_stage if hint.pipeline_stage > 1 else 0
            config["thread_num"] = block_rows * block_cols * 32
            config["enable_rasteration"] = hint.rasterization_plan is not NoRasterization
            configs.append(config)
    else:
        block_M = [64, 128, 256]
        block_N = [64, 128, 256]
        block_K = [32, 64]
        num_stages = [0, 1, 2, 3]
        thread_num = [128, 256]
        enable_rasterization = [True, False]
        _configs = list(
            itertools.product(
                block_M,
                block_N,
                block_K,
                num_stages,
                thread_num,
                enable_rasterization,
            )
        )

        configs = [
            {
                "block_M": c[0],
                "block_N": c[1],
                "block_K": c[2],
                "num_stages": c[3],
                "thread_num": c[4],
                "enable_rasteration": c[5],  # keep param name for backward-compat
            }
            for c in _configs
        ]
    return configs


def get_best_config(M, N, K, with_roller=False):
    def kernel(
        block_M=None,
        block_N=None,
        block_K=None,
        num_stages=None,
        thread_num=None,
        enable_rasteration=None,
    ):
        dtype = "bfloat16"
        accum_dtype = "float"

        @T.prim_func
        def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((N, K), dtype),
            C: T.Tensor((M, N), dtype),
        ):
            with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
                A_shared = T.alloc_shared((block_M, block_K), dtype)
                B_shared = T.alloc_shared((block_N, block_K), dtype)
                C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
                C_shared = T.alloc_shared((block_M, block_N), dtype)
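                # Optionally swizzle the block launch order (rasterization) to improve L2 reuse.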
                T.use_swizzle(panel_size=10, enable=enable_rasteration)
                T.clear(C_local)
                for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                    T.copy(A[by * block_M, k * block_K], A_shared)
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                    T.gemm(
                        A_shared,
                        B_shared,
                        C_local,
                        transpose_B=True,
                    )
                T.copy(C_local, C_shared)
                T.copy(C_shared, C[by * block_M, bx * block_N])

        return main

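    # Build the autotuner over all candidate configs: each candidate is compiled with
    # out_idx=[-1] (the last tensor argument, C, is treated as the kernel output) and
    # profiled against ref_program using integer-filled test tensors, with correctness
    # checking enabled (skip_check=False).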
    autotuner = (
        AutoTuner.from_kernel(kernel=kernel, configs=get_configs(M, N, K, with_roller))
        .set_compile_args(
            out_idx=[-1],
            target="auto",
        )
        .set_profile_args(
            supply_type=tl.TensorSupplyType.Integer,
            ref_prog=ref_program,
            skip_check=False,
        )
    )
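    # Compile and benchmark every configuration (3 warmup runs, 20 timed reps) and return
    # the best result, which carries the winning config and its compiled kernel.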
    return autotuner.run(warmup=3, rep=20)


def get_heuristic_config() -> dict:
    # Get CUDA device properties
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available")
    device = torch.cuda.current_device()
    sm_major, sm_minor = torch.cuda.get_device_capability(device)
    sm_version = sm_major * 10 + sm_minor
    print(f"CUDA device capability: {sm_version}")
    if sm_version in {80}:  # Ampere (e.g. A100)
        return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 2, "thread_num": 128, "enable_rasteration": True}
    elif sm_version in {90}:  # Hopper (e.g. H100)
        return {"block_M": 128, "block_N": 256, "block_K": 64, "num_stages": 3, "thread_num": 256, "enable_rasteration": True}
    else:
        return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 0, "thread_num": 128, "enable_rasteration": True}


@tl.jit(out_idx=[-1])
def matmul(M, N, K, block_M, block_N, block_K, num_stages, thread_num,
           enable_rasteration, dtype="float16", accum_dtype="float"):
    @T.prim_func
    def gemm_autotune(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((N, K), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
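            # Each block computes one (block_M, block_N) tile of C: A and B tiles are staged
            # through shared memory, partial products are accumulated in register fragments
            # (C_local) in accum_dtype, and the K dimension is streamed in block_K chunks
            # through a software pipeline with num_stages stages. B is laid out as (N, K),
            # so transpose_B=True computes A @ B.T.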
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            C_shared = T.alloc_shared((block_M, block_N), dtype)
            T.use_swizzle(panel_size=10, enable=enable_rasteration)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                T.gemm(
                    A_shared,
                    B_shared,
                    C_local,
                    transpose_B=True,
                )
            T.copy(C_local, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])

    return gemm_autotune


def main(M: int = 4096, N: int = 4096, K: int = 4096, use_autotune: bool = False, with_roller: bool = False):
    if use_autotune:
        result = get_best_config(M, N, K, with_roller)
        print(result.config)
        kernel = result.kernel
    else:
        config = get_heuristic_config()
        kernel = matmul(M, N, K, **config)

    # benchmark
    profiler = kernel.get_profiler(tensor_supply_type=tl.TensorSupplyType.Auto)
    tilelang_latency = profiler.do_bench()
    ref_latency = profiler.do_bench(ref_program)
    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
    print(f"TileLang latency: {tilelang_latency}")
    print(f"Ref latency: {ref_latency}")
    print(f"TileLang TFlops: {2 * M * N * K / tilelang_latency * 1e-9}")
    print(f"Ref TFlops: {2 * M * N * K / ref_latency * 1e-9}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
    parser.add_argument("--m", type=int, default=4096, help="Matrix dimension M")
    parser.add_argument("--n", type=int, default=4096, help="Matrix dimension N")
    parser.add_argument("--k", type=int, default=4096, help="Matrix dimension K")
    parser.add_argument("--use_autotune", action="store_true", default=False, help="Whether to use autotune for matmul configs")
    parser.add_argument("--with_roller", action="store_true", default=False, help="Whether to enable BitBLAS roller for search space")
    args = parser.parse_args()
    main(args.m, args.n, args.k, args.use_autotune, args.with_roller)