# test_tilelang_primitives_mma.py
from tilelang import tvm as tvm
import tilelang.testing
from tilelang import primitives as P


def matmul_ssr(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    threads,
):
    """Build a tiled GEMM prim_func that feeds `P.gemm` from two shared buffers.

    Each iteration of the pipelined loop copies one tile of A and one tile of B
    into buffers allocated with `T.alloc_shared`, then calls `P.gemm` to
    accumulate into a fragment buffer. After the loop the fragment is copied
    out to the matching tile of C.

    Args:
        M, N, K: Problem sizes. C is declared as (M, N).
        block_M, block_N, block_K: Tile sizes along M, N and K.
        trans_A: When truthy, A is declared as (K, M) instead of (M, K).
        trans_B: When truthy, B is declared as (N, K) instead of (K, N).
        in_dtype: Dtype of the A and B tensors and of their shared tiles.
        out_dtype: Dtype of the C tensor.
        accum_dtype: Dtype of the accumulator fragment.
        num_stages: Forwarded to `T.Pipelined`.
        threads: Forwarded to `T.Kernel`.

    Returns:
        The `@T.prim_func`-decorated kernel function.
    """
    # Global and tile shapes follow the (optional) transposed storage of A / B.
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
    shared_scope = "shared"  # or "shared.dyn" for dynamic shared memory

    import tilelang.language as T

    @T.prim_func
    def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        # bx walks the N tiles and by walks the M tiles (see the C store below).
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                # Tile origin in A / B depends on whether the operand is stored transposed.
                if trans_A:
                    T.copy(A[ko * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, ko * block_K], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, ko * block_K], B_shared)
                else:
                    T.copy(B[ko * block_K, bx * block_N], B_shared)
                P.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def run_matmul_ssr(
    M,
    N,
    K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    dtypeAccum,
    block_M,
    block_N,
    block_K,
    num_stages=3,
    num_threads=128,
):
    """Compile the `matmul_ssr` kernel and compare it against a torch reference.

    Args:
        M, N, K: Problem sizes forwarded to `matmul_ssr`.
        trans_A, trans_B: Transpose flags forwarded to `matmul_ssr`; the
            reference transposes the corresponding input when a flag is set.
        in_dtype: Dtype name for A and B.
        out_dtype: Dtype name for C; also looked up as an attribute of `torch`
            to cast the reference result.
        dtypeAccum: Accumulator dtype name (passed as `accum_dtype`).
        block_M, block_N, block_K: Tile sizes forwarded to `matmul_ssr`.
        num_stages: Pipeline stage count forwarded to `matmul_ssr`.
        num_threads: Thread count forwarded to `matmul_ssr` as `threads`.

    Raises:
        Whatever `profiler.assert_allclose` raises when the kernel output and
        the reference disagree beyond the configured tolerances.
    """
    program = matmul_ssr(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        trans_A,
        trans_B,
        in_dtype,
        out_dtype,
        dtypeAccum,
        num_stages,
        num_threads,
    )
    # TODO(lei): gemm_v2 with tma is not fully tested.
    kernel = tilelang.compile(
        program,
        out_idx=[2],  # presumably marks the third parameter (C) as the output — confirm
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        },
    )
    profiler = kernel.get_profiler()

    def ref_program(A, B):
        """Reference GEMM: multiply in torch.float, then cast to `out_dtype`."""
        import torch

        if trans_A:
            A = A.T
        if trans_B:
            B = B.T
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(getattr(torch, out_dtype))
        return C

    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)


def test_gemm_f16f16f16_nt_ssr():
    """Check the `matmul_ssr` kernel with float16 in/out/accum, trans_A=False, trans_B=True.

    Three problem sizes are exercised, each with its own tile shape, pipeline
    stage count and thread count.
    """
    # Positional order: M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum,
    #                   block_M, block_N, block_K, num_stages
    run_matmul_ssr(16, 16, 16, False, True, "float16", "float16", "float16", 16, 16, 16, 0, num_threads=32)
    run_matmul_ssr(128, 128, 128, False, True, "float16", "float16", "float16", 32, 32, 32, 0, num_threads=64)
    run_matmul_ssr(1024, 1024, 1024, False, True, "float16", "float16", "float16", 128, 128, 32, 2, num_threads=128)


def matmul_rsr(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    threads,
):
    """Build a tiled GEMM prim_func that feeds `P.gemm` a fragment A and a shared B.

    Each iteration of the pipelined loop copies one tile of A and one tile of B
    into buffers allocated with `T.alloc_shared`, copies the A tile on into a
    fragment buffer, and calls `P.gemm` with the A fragment and the shared B
    tile. After the loop the accumulator fragment is copied out to C.

    Args:
        M, N, K: Problem sizes. C is declared as (M, N).
        block_M, block_N, block_K: Tile sizes along M, N and K.
        trans_A: When truthy, A is declared as (K, M) instead of (M, K).
        trans_B: When truthy, B is declared as (N, K) instead of (K, N).
        in_dtype: Dtype of the A and B tensors and of their tiles.
        out_dtype: Dtype of the C tensor.
        accum_dtype: Dtype of the accumulator fragment.
        num_stages: Forwarded to `T.Pipelined`.
        threads: Forwarded to `T.Kernel`.

    Returns:
        The `@T.prim_func`-decorated kernel function.
    """
    # Global and tile shapes follow the (optional) transposed storage of A / B.
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
    # The A fragment mirrors the shape of the shared A tile.
    A_local_shape = A_shared_shape
    shared_scope = "shared"  # or "shared.dyn" for dynamic shared memory

    import tilelang.language as T

    @T.prim_func
    def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        # bx walks the N tiles and by walks the M tiles (see the C store below).
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
            A_local = T.alloc_fragment(A_local_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                # Tile origin in A / B depends on whether the operand is stored transposed.
                if trans_A:
                    T.copy(A[ko * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, ko * block_K], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, ko * block_K], B_shared)
                else:
                    T.copy(B[ko * block_K, bx * block_N], B_shared)
                # Stage the A tile from the shared buffer into the fragment before the GEMM.
                T.copy(A_shared, A_local)
                P.gemm(A_local, B_shared, C_local, trans_A, trans_B)
                # T.gemm(A_local, B_shared, C_local, trans_A, trans_B)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def run_matmul_rsr(
    M,
    N,
    K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    dtypeAccum,
    block_M,
    block_N,
    block_K,
    num_stages=3,
    num_threads=128,
):
    """Compile the `matmul_rsr` kernel and compare it against a torch reference.

    Args:
        M, N, K: Problem sizes forwarded to `matmul_rsr`.
        trans_A, trans_B: Transpose flags forwarded to `matmul_rsr`; the
            reference transposes the corresponding input when a flag is set.
        in_dtype: Dtype name for A and B.
        out_dtype: Dtype name for C; also looked up as an attribute of `torch`
            to cast the reference result.
        dtypeAccum: Accumulator dtype name (passed as `accum_dtype`).
        block_M, block_N, block_K: Tile sizes forwarded to `matmul_rsr`.
        num_stages: Pipeline stage count forwarded to `matmul_rsr`.
        num_threads: Thread count forwarded to `matmul_rsr` as `threads`.

    Raises:
        Whatever `profiler.assert_allclose` raises when the kernel output and
        the reference disagree beyond the configured tolerances.
    """
    program = matmul_rsr(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        trans_A,
        trans_B,
        in_dtype,
        out_dtype,
        dtypeAccum,
        num_stages,
        num_threads,
    )
    kernel = tilelang.compile(
        program,
        out_idx=[2],  # presumably marks the third parameter (C) as the output — confirm
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        },
    )
    profiler = kernel.get_profiler()

    def ref_program(A, B):
        """Reference GEMM: multiply in torch.float, then cast to `out_dtype`."""
        import torch

        if trans_A:
            A = A.T
        if trans_B:
            B = B.T
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(getattr(torch, out_dtype))
        return C

    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)


# TODO(lei): Fix this test case in a future release.
# It is currently disabled because of bugs related to is_m_first.
# def test_gemm_f16f16f16_nt_rsr():
#     run_matmul_rsr(
#         1024,
#         1024,
#         1024,
#         False,
#         True,
#         "float16",
#         "float16",
#         "float16",
#         128,
#         128,
#         32,
#         0,
#         num_threads=128,
#     )


def matmul_rrr(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    threads,
):
    """Build a tiled GEMM prim_func that feeds `P.gemm` from two fragment buffers.

    Each iteration of the pipelined loop copies one tile of A and one tile of B
    into buffers allocated with `T.alloc_shared`, copies each tile on into its
    own fragment buffer, and calls `P.gemm` on the two fragments. After the
    loop the accumulator fragment is copied out to the matching tile of C.

    Args:
        M, N, K: Problem sizes. C is declared as (M, N).
        block_M, block_N, block_K: Tile sizes along M, N and K.
        trans_A: When truthy, A is declared as (K, M) instead of (M, K).
        trans_B: When truthy, B is declared as (N, K) instead of (K, N).
        in_dtype: Dtype of the A and B tensors and of their tiles.
        out_dtype: Dtype of the C tensor.
        accum_dtype: Dtype of the accumulator fragment.
        num_stages: Forwarded to `T.Pipelined`.
        threads: Forwarded to `T.Kernel`.

    Returns:
        The `@T.prim_func`-decorated kernel function.
    """
    # Global and tile shapes follow the (optional) transposed storage of A / B.
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
    # Each fragment mirrors the shape of the shared tile it is filled from.
    A_local_shape = A_shared_shape
    B_local_shape = B_shared_shape
    import tilelang.language as T

    @T.prim_func
    def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        # bx walks the N tiles and by walks the M tiles (see the C store below).
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            A_local = T.alloc_fragment(A_local_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            B_local = T.alloc_fragment(B_local_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                # Only the global tile origin depends on the transpose flag; the
                # shared -> fragment copy that follows is the same in both cases.
                if trans_A:
                    T.copy(A[k * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(A_shared, A_local)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.copy(B_shared, B_local)
                P.gemm(A_local, B_local, C_local, trans_A, trans_B)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def run_matmul_rrr(
    M,
    N,
    K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    dtypeAccum,
    block_M,
    block_N,
    block_K,
    num_stages=3,
    num_threads=128,
):
    """Compile the `matmul_rrr` kernel and compare it against a torch reference.

    Args:
        M, N, K: Problem sizes forwarded to `matmul_rrr`.
        trans_A, trans_B: Transpose flags forwarded to `matmul_rrr`; the
            reference transposes the corresponding input when a flag is set.
        in_dtype: Dtype name for A and B.
        out_dtype: Dtype name for C; also looked up as an attribute of `torch`
            to cast the reference result.
        dtypeAccum: Accumulator dtype name (passed as `accum_dtype`).
        block_M, block_N, block_K: Tile sizes forwarded to `matmul_rrr`.
        num_stages: Pipeline stage count forwarded to `matmul_rrr`.
        num_threads: Thread count forwarded to `matmul_rrr` as `threads`.

    Raises:
        Whatever `profiler.assert_allclose` raises when the kernel output and
        the reference disagree beyond the configured tolerances.
    """
    program = matmul_rrr(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        trans_A,
        trans_B,
        in_dtype,
        out_dtype,
        dtypeAccum,
        num_stages,
        num_threads,
    )
    kernel = tilelang.compile(
        program,
        out_idx=[2],  # presumably marks the third parameter (C) as the output — confirm
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        },
    )
    profiler = kernel.get_profiler()

    def ref_program(A, B):
        """Reference GEMM: multiply in torch.float, then cast to `out_dtype`."""
        import torch

        if trans_A:
            A = A.T
        if trans_B:
            B = B.T
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(getattr(torch, out_dtype))
        return C

    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)


# NOTE(review): this test is commented out; the reason is not recorded here.
# def test_gemm_f16f16f16_nt_rrr():
#     run_matmul_rrr(
#         1024,
#         1024,
#         1024,
#         False,
#         True,
#         "float16",
#         "float16",
#         "float16",
#         128,
#         128,
#         32,
#         2,
#     )

if __name__ == "__main__":
    # Hand control to tilelang's test entry point when run as a script.
    tilelang.testing.main()