# example_gemm_persistent.py
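"""Compare persistent and non-persistent GEMM scheduling in TileLang.

Implements a conventional tiled GEMM (one thread block per output tile), a
hand-rolled persistent GEMM (one resident block per SM looping over tiles),
and a variant using the T.Persistent primitive, then benchmarks the
persistent and non-persistent kernels against a torch reference.
"""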
import tilelang
import tilelang.language as T
from tilelang.carver.arch import driver
import argparse


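# Baseline: one thread block per (block_M, block_N) output tile on a 2D grid
# of ceildiv(M, block_M) x ceildiv(N, block_N) blocks. out_idx=[-1] marks the
# last argument (C) as an output that the JIT runtime allocates and returns.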
@tilelang.jit(out_idx=[-1])
def matmul_non_persistent(M,
                          N,
                          K,
                          block_M,
                          block_N,
                          block_K,
                          threads,
                          num_stages,
                          dtype="float16",
                          accum_dtype="float"):

    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=threads) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            C_shared = T.alloc_shared((block_M, block_N), dtype)

            # Swizzle the block index (panel size 10) to improve L2 locality.
            T.use_swizzle(10)

            T.clear(C_local)
            # Software-pipeline the K loop: overlap global->shared copies with GEMM.
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[bx * block_M, k * block_K], A_shared)
                T.copy(B[k * block_K, by * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)

            # Stage the accumulator through shared memory for coalesced stores.
            T.copy(C_local, C_shared)
            T.copy(C_shared, C[bx * block_M, by * block_N])

    return main


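# Persistent variant: launch exactly one block per SM; each block loops over
# output tiles, amortizing block launch and scheduling overhead across tiles.
# use_persistent_primitive selects between the T.Persistent primitive and the
# hand-written wave loop.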
@tilelang.jit(out_idx=[-1])
def matmul_persistent(M,
                      N,
                      K,
                      block_M,
                      block_N,
                      block_K,
                      threads,
                      num_stages,
                      dtype="float16",
                      accum_dtype="float",
                      use_persistent_primitive=True):

    sm_num = driver.get_num_sms()
    m_blocks = T.ceildiv(M, block_M)
    n_blocks = T.ceildiv(N, block_N)
    group_size = 8  # tiles are traversed in groups of 8 along N for L2 reuse
    # The swizzle below pads the tile grid's N dimension to a multiple of
    # group_size, so size the waves against the padded grid; otherwise trailing
    # tiles are silently skipped when n_blocks % group_size != 0.
    n_blocks_padded = T.ceildiv(n_blocks, group_size) * group_size
    waves = T.ceildiv(m_blocks * n_blocks_padded, sm_num)

    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        # Grid of exactly sm_num blocks; each stays resident for the whole kernel.
        with T.Kernel(sm_num, threads=threads) as block_id:
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            C_shared = T.alloc_shared((block_M, block_N), dtype)

            # One tile per wave; block block_id takes every sm_num-th tile.
            for w in T.serial(waves):
                tile_id = sm_num * w + block_id
                # Swizzled mapping: each run of group_size consecutive tile_ids
                # covers group_size tiles along N at a fixed bx, and successive
                # runs sweep bx through M before advancing to the next N-group.
                bx = (tile_id // group_size) % m_blocks
                by = (tile_id % group_size) + (tile_id // group_size) // m_blocks * group_size

                # The padded grid can yield out-of-range tiles; skip them.
                if bx * block_M < M and by * block_N < N:
                    T.clear(C_local)
                    for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                        T.copy(A[bx * block_M, k * block_K], A_shared)
                        T.copy(B[k * block_K, by * block_N], B_shared)
                        T.gemm(A_shared, B_shared, C_local)

                    T.copy(C_local, C_shared)
                    T.copy(C_shared, C[bx * block_M, by * block_N])

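
    # Same persistent schedule, but the tile-assignment loop is generated by
    # the T.Persistent primitive instead of being written by hand.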
    @T.prim_func
    def main_persistent_primitive(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(sm_num, threads=threads) as block_id:
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            C_shared = T.alloc_shared((block_M, block_N), dtype)

            # T.Persistent generates the persistent scheduling loop (the wave
            # loop written out by hand above) over the 2D tile grid.
            for bx, by in T.Persistent(
                [T.ceildiv(M, block_M), T.ceildiv(N, block_N)], sm_num, block_id):
                T.clear(C_local)
                for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                    T.copy(A[bx * block_M, k * block_K], A_shared)
                    T.copy(B[k * block_K, by * block_N], B_shared)
                    T.gemm(A_shared, B_shared, C_local)

                T.copy(C_local, C_shared)
                T.copy(C_shared, C[bx * block_M, by * block_N])

    return main_persistent_primitive if use_persistent_primitive else main


def ref_program(A, B):
    return A @ B


def main(M=4096, N=4096, K=4096):
    total_flops = 2 * M * N * K  # M*N outputs, K multiply-adds (2 flops) each

    BLOCK_M = 128
    BLOCK_N = 256
    BLOCK_K = 64
    threads = 256
    num_stages = 3

    persistent_kernel = matmul_persistent(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, threads, num_stages)
    persistent_profiler = persistent_kernel.get_profiler(
        tensor_supply_type=tilelang.TensorSupplyType.Randn)
    persistent_profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
    print("Persistent GEMM: All check passed.")
    persistent_latency = persistent_profiler.do_bench(warmup=500)
    print(f"Persistent GEMM Latency: {persistent_latency} ms")
    print(f"Persistent GEMM TFlops: {total_flops / persistent_latency * 1e-9} TFlops")

    non_persistent_kernel = matmul_non_persistent(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, threads,
                                                  num_stages)
    non_persistent_profiler = non_persistent_kernel.get_profiler(
        tensor_supply_type=tilelang.TensorSupplyType.Randn)
    non_persistent_profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
    print("Non-Persistent GEMM: All check passed.")
    non_persistent_latency = non_persistent_profiler.do_bench(warmup=500)
    print(f"Non-Persistent GEMM Latency: {non_persistent_latency} ms")
    print(f"Non-Persistent GEMM TFlops: {total_flops / non_persistent_latency * 1e-9} TFlops")

    print(f"Persistent GEMM Speedup: {non_persistent_latency / persistent_latency}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--M', type=int, default=8192, help='M dimension')
    parser.add_argument('--N', type=int, default=8192, help='N dimension')
    parser.add_argument('--K', type=int, default=8192, help='K dimension')
    args = parser.parse_args()
    M, N, K = args.M, args.N, args.K
    main(M, N, K)