# example_gemm_persistent.py (5.55 KB)
1
2
3
4
5
6
import tilelang
import tilelang.language as T
from tilelang.carver.arch import driver
import argparse


7
@tilelang.jit(out_idx=[-1])
def matmul_non_persistent(M, N, K, block_M, block_N, block_K, threads, num_stages, dtype="float16", accum_dtype="float"):
    """Build a conventional (one block per output tile) GEMM kernel ``C = A @ B``.

    The launch grid is ``ceildiv(M, block_M) x ceildiv(N, block_N)``; block
    ``(bx, by)`` computes the ``block_M x block_N`` tile of ``C`` starting at
    row ``bx * block_M`` and column ``by * block_N``.

    Args:
        M, N, K: Problem size; ``A`` is ``(M, K)``, ``B`` is ``(K, N)``, ``C`` is ``(M, N)``.
        block_M, block_N, block_K: Tile sizes along M, N and the K reduction.
        threads: Threads per block passed to ``T.Kernel``.
        num_stages: Software-pipelining depth for the K loop (``T.Pipelined``).
        dtype: Element type of ``A``, ``B`` and ``C``.
        accum_dtype: Element type of the per-tile accumulator.

    Returns:
        The ``T.prim_func``; ``tilelang.jit(out_idx=[-1])`` treats ``C`` as the output.
    """

    @T.prim_func
    def main(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((K, N), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=threads) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            # Accumulate in accum_dtype; the result is narrowed to dtype on the way out.
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            C_shared = T.alloc_shared((block_M, block_N), dtype)

            # Block-index swizzle with panel size 10 — presumably for L2 reuse;
            # see the TileLang docs for T.use_swizzle semantics.
            T.use_swizzle(10)

            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[bx * block_M, k * block_K], A_shared)
                T.copy(B[k * block_K, by * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)

            # Stage the finished tile through C_shared before writing it to C.
            T.copy(C_local, C_shared)
            T.copy(C_shared, C[bx * block_M, by * block_N])

    return main


35
@tilelang.jit(out_idx=[-1])
def matmul_persistent(
    M, N, K, block_M, block_N, block_K, threads, num_stages, dtype="float16", accum_dtype="float", use_persistent_primitive=True
):
    """Build a persistent GEMM kernel ``C = A @ B`` that launches ``sm_num`` blocks.

    ``sm_num`` comes from ``driver.get_num_sms()``. Each launched block loops over
    several ``block_M x block_N`` output tiles instead of computing just one.

    Two equivalent schedules are built:

    * ``main`` — hand-written tile scheduler. Linear tile id
      ``tile_id = sm_num * w + block_id`` (``w`` = wave) is split as
      ``tile_id = q * group_size + r`` and mapped to
      ``bx = q % m_blocks`` and ``by = r + (q // m_blocks) * group_size``.
      Consecutive ids therefore sweep ``group_size`` N-tiles of one M-tile row,
      then move to the next M-tile, and only after all ``m_blocks`` rows advance
      to the next panel of ``group_size`` N-tiles.
    * ``main_persistent_primitive`` — same tile body, scheduled by ``T.Persistent``.

    Args:
        M, N, K: Problem size; ``A`` is ``(M, K)``, ``B`` is ``(K, N)``, ``C`` is ``(M, N)``.
        block_M, block_N, block_K: Tile sizes along M, N and the K reduction.
        threads: Threads per block passed to ``T.Kernel``.
        num_stages: Software-pipelining depth for the K loop (``T.Pipelined``).
        dtype: Element type of ``A``, ``B`` and ``C``.
        accum_dtype: Element type of the per-tile accumulator.
        use_persistent_primitive: Return ``main_persistent_primitive`` when true,
            otherwise the hand-scheduled ``main``.

    Returns:
        The selected ``T.prim_func``; ``tilelang.jit(out_idx=[-1])`` treats ``C`` as the output.
    """
    sm_num = driver.get_num_sms()
    m_blocks = T.ceildiv(M, block_M)
    n_blocks = T.ceildiv(N, block_N)
    group_size = 8
    # The hand-written schedule walks tiles in panels of `group_size` N-tiles, so
    # its linear id space is padded up to a whole number of panels:
    #   m_blocks * ceil(n_blocks / group_size) * group_size
    # Sizing `waves` from the unpadded m_blocks * n_blocks would leave the tail of
    # the last panel unvisited whenever n_blocks % group_size != 0, so those C tiles
    # would never be written. Padded ids are skipped by the bounds check below.
    padded_tiles = m_blocks * T.ceildiv(n_blocks, group_size) * group_size
    waves = T.ceildiv(padded_tiles, sm_num)

    @T.prim_func
    def main(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((K, N), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(sm_num, threads=threads) as (block_id):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            C_shared = T.alloc_shared((block_M, block_N), dtype)

            for w in T.serial(waves):
                # Grouped (panel) tile ordering; see the docstring for the mapping.
                tile_id = sm_num * w + block_id
                bx = (tile_id // group_size) % m_blocks
                by = (tile_id % group_size) + (tile_id // group_size) // m_blocks * group_size

                # Skip padded ids in the last panel and ids past the padded tile space.
                if bx * block_M < M and by * block_N < N:
                    T.clear(C_local)
                    for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                        T.copy(A[bx * block_M, k * block_K], A_shared)
                        T.copy(B[k * block_K, by * block_N], B_shared)
                        T.gemm(A_shared, B_shared, C_local)

                    T.copy(C_local, C_shared)
                    T.copy(C_shared, C[bx * block_M, by * block_N])

    @T.prim_func
    def main_persistent_primitive(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((K, N), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(sm_num, threads=threads) as (block_id):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            C_shared = T.alloc_shared((block_M, block_N), dtype)

            # T.Persistent yields the (bx, by) tiles assigned to this block over a
            # ceildiv(M, block_M) x ceildiv(N, block_N) tile grid shared by sm_num blocks.
            for bx, by in T.Persistent([T.ceildiv(M, block_M), T.ceildiv(N, block_N)], sm_num, block_id):
                T.clear(C_local)
                for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                    T.copy(A[bx * block_M, k * block_K], A_shared)
                    T.copy(B[k * block_K, by * block_N], B_shared)
                    T.gemm(A_shared, B_shared, C_local)

                T.copy(C_local, C_shared)
                T.copy(C_shared, C[bx * block_M, by * block_N])

    return main_persistent_primitive if use_persistent_primitive else main
95
96
97
98
99
100


def ref_program(A, B):
    """Return the reference matrix product used to validate the kernels.

    Args:
        A: Left operand (any object supporting the ``@`` operator).
        B: Right operand.

    Returns:
        ``A @ B``.
    """
    product = A @ B
    return product


101
def _verify_and_benchmark(label, kernel, total_flops):
    """Check ``kernel`` against ``ref_program`` and print its latency and throughput.

    Args:
        label: Prefix for every printed line (e.g. ``"Persistent GEMM"``).
        kernel: A compiled kernel returned by one of the ``matmul_*`` builders.
        total_flops: Floating-point operation count of one GEMM call.

    Returns:
        The latency returned by ``profiler.do_bench`` (reported as milliseconds).
    """
    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
    profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
    print(f"{label}: All check passed.")
    latency = profiler.do_bench(warmup=500)
    print(f"{label} Latency: {latency} ms")
    # flops / (latency_ms * 1e-3 s) / 1e12 == flops / latency_ms * 1e-9  -> TFLOP/s
    print(f"{label} TFlops: {total_flops / latency * 1e-9} TFlops")
    return latency


def main(M=4096, N=4096, K=4096):
    """Validate and benchmark the persistent and non-persistent GEMM kernels.

    Both kernels use the same tiling (128x256x64 tiles, 256 threads, 3 pipeline
    stages) and are checked against ``ref_program`` with ``rtol=atol=0.01``.
    The final line reports ``non_persistent_latency / persistent_latency``.

    Args:
        M, N, K: Problem size; ``A`` is ``(M, K)`` and ``B`` is ``(K, N)``.
    """
    total_flops = 2 * M * N * K  # one multiply and one add per (m, n, k)

    BLOCK_M = 128
    BLOCK_N = 256
    BLOCK_K = 64
    threads = 256
    num_stages = 3

    persistent_kernel = matmul_persistent(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, threads, num_stages)
    persistent_latency = _verify_and_benchmark("Persistent GEMM", persistent_kernel, total_flops)

    non_persistent_kernel = matmul_non_persistent(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, threads, num_stages)
    non_persistent_latency = _verify_and_benchmark("Non-Persistent GEMM", non_persistent_kernel, total_flops)

    print(f"Persistent GEMM Speedup: {non_persistent_latency / persistent_latency}")


if __name__ == "__main__":
    # NOTE: the CLI defaults (8192) are larger than main()'s own defaults (4096).
    parser = argparse.ArgumentParser()
    parser.add_argument("--M", type=int, default=8192, help="M dimension")
    parser.add_argument("--N", type=int, default=8192, help="N dimension")
    parser.add_argument("--K", type=int, default=8192, help="K dimension")
    args = parser.parse_args()
    M, N, K = args.M, args.N, args.K
    main(M, N, K)