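"""Autotuned FP16 GEMM (C = A @ B.T) benchmark written with TileLang.

Usage:
    python performance.py --m 16384 --n 16384 --k 16384
"""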
import argparse
import tilelang.language as T
from tilelang.autotuner import AutoTuner


def ref_program(A, B):
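    """Reference matmul (C = A @ B.T) used for correctness checks and baseline timing."""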
    return A @ B.T


def get_configs():
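    """Candidate kernel configurations explored by the autotuner (a single tile/pipeline setup here)."""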
    configs = [
        {
            "block_M": 128,
            "block_N": 128,
            "block_K": 64,
            "num_stages": 2,
            "thread_num": 256,
            "enable_rasteration": True,  # keep param name for backward-compat
        }
    ]
    return configs


def run(M, N, K):
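    """Autotune and profile the GEMM kernel for an (M, N, K) problem, returning the tuning result."""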
    def kernel(
        block_M=None,
        block_N=None,
        block_K=None,
        num_stages=None,
        thread_num=None,
        enable_rasteration=None,
    ):
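        # Compute in FP16 and accumulate in FP32.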
        dtype = T.float16
        accum_dtype = T.float32

        @T.prim_func
        def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((N, K), dtype),
            C: T.Tensor((M, N), dtype),
        ):
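            # 2-D launch grid: bx tiles the N dimension, by tiles the M dimension.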
            with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
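                # Stage tiles of A and B in shared memory; accumulate into an FP32 register fragment.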
                A_shared = T.alloc_shared((block_M, block_K), dtype)
                B_shared = T.alloc_shared((block_N, block_K), dtype)
                C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
                C_shared = T.alloc_shared((block_M, block_N), dtype)
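                # Swizzle the block-scheduling order (rasterization) to improve L2 reuse.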
                T.use_swizzle(panel_size=10, enable=enable_rasteration)
                T.clear(C_local)
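                # Software-pipelined reduction over K: overlap global-to-shared copies with compute.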
                for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                    T.copy(A[by * block_M, k * block_K], A_shared)
                    T.copy(B[bx * block_N, k * block_K], B_shared)
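                    # B tiles are stored (N, K) row-major, so the GEMM transposes B to compute A @ B^T.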
                    T.gemm(
                        A_shared,
                        B_shared,
                        C_local,
                        transpose_B=True,
                    )
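                # Write the result tile back: registers -> shared memory -> global C.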
                T.copy(C_local, C_shared)
                T.copy(C_shared, C[by * block_M, bx * block_N])

        return main

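    # Build the autotuner: compile each candidate config and benchmark it against ref_program.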
    autotuner = (
        AutoTuner.from_kernel(kernel=kernel, configs=get_configs())
        .set_compile_args(
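            # Treat the last tensor argument (C) as the kernel output.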
            out_idx=[-1],
            target="auto",
        )
        .set_profile_args(
            ref_prog=ref_program,
        )
    )
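    # 3 warmup iterations and 20 timed repetitions per candidate.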
    return autotuner.run(warmup=3, rep=20)


if __name__ == "__main__":
    # Parse command-line arguments for matrix dimensions
    parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
    parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
    parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
    parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
    args = parser.parse_args()

    M, N, K = args.m, args.n, args.k

    # Compute total floating-point operations to measure throughput
    total_flops = 2 * M * N * K

    result = run(M, N, K)

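    # Latency is reported in milliseconds, so FLOPs / latency * 1e-9 gives TFLOP/s.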
    print(f"Latency: {result.latency}")
    print(f"TFlops: {total_flops / result.latency * 1e-9:.3f}")
    print(f"Config: {result.config}")

    print(f"Reference TFlops: {total_flops / result.ref_latency * 1e-9:.3f}")