# Tests for tilelang JIT GEMM kernels compiled with the Cython execution backend.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from tilelang import tvm as tvm
import tilelang.language as T
import tilelang.testing
import tilelang
import torch
from tilelang.utils.tensor import map_torch_type


def matmul(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    threads,
):
    """Build a tiled GEMM ``C = op(A) @ op(B)`` as a tilelang prim_func.

    Each kernel block produces one ``block_M x block_N`` tile of ``C``. It loops
    over ``K`` in ``block_K`` slices with ``T.Pipelined`` (``num_stages``
    stages) and accumulates into a fragment of ``accum_dtype``.

    Args:
        M, N, K: Problem sizes. Callers in this file pass either Python ints or
            ``T.symbolic`` variables.
        block_M, block_N, block_K: Tile sizes along M, N and K.
        trans_A: If true, A is stored as ``(K, M)`` instead of ``(M, K)``.
        trans_B: If true, B is stored as ``(N, K)`` instead of ``(K, N)``.
        in_dtype: dtype string of the A and B buffers.
        out_dtype: dtype string of the C buffer.
        accum_dtype: dtype string of the accumulator fragment.
        num_stages: Pipeline depth passed to ``T.Pipelined``.
        threads: Thread count passed to ``T.Kernel``.

    Returns:
        The ``main`` prim_func, whose parameters are the buffers ``(A, B, C)``.
    """
    # Global buffer shapes follow the storage layout selected by the transpose flags.
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    # Per-block staging tiles use the same orientation as the global buffers.
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

    @T.prim_func
    def main(
            A: T.Buffer(A_shape, in_dtype),
            B: T.Buffer(B_shape, in_dtype),
            C: T.Buffer((M, N), out_dtype),
    ):
        # bx indexes N-tiles and by indexes M-tiles (see the C copy at the end).
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                # Load the k-th K-slice of this block's A tile; the index order
                # matches A's storage layout.
                if trans_A:
                    T.copy(A[k * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                # Same for B.
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                # Accumulate this slice's partial product into C_local.
                T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
            # Write the finished tile back to its position in C.
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def run_gemm(
    M,
    N,
    K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    dtypeAccum,
    block_M,
    block_N,
    block_K,
    num_stages=3,
    num_threads=128,
):
    """Compile a GEMM with the Cython backend and check the CUDA post-processing hook.

    A function registered under ``tilelang_callback_cuda_postproc`` prepends a
    marker comment to the generated CUDA code. The test then asserts that the
    marker is present in ``get_kernel_source()``.
    """
    program = matmul(
        M=M,
        N=N,
        K=K,
        block_M=block_M,
        block_N=block_N,
        block_K=block_K,
        trans_A=trans_A,
        trans_B=trans_B,
        in_dtype=in_dtype,
        out_dtype=out_dtype,
        accum_dtype=dtypeAccum,
        num_stages=num_stages,
        threads=num_threads,
    )

    marker = "&*(XS)"

    # override=True so repeated runs in one process can re-register the hook.
    @tvm.register_func("tilelang_callback_cuda_postproc", override=True)
    def _prepend_marker(code, _):
        return f"// {marker}\n" + code

    compiled = tilelang.compile(program, out_idx=-1, execution_backend="cython")
    source = compiled.get_kernel_source()
    assert marker in source, f"Expected {marker} in the kernel source"


def test_gemm_f16f16f16_nn():
    """fp16 in/out/accumulate GEMM without transposes, via ``run_gemm``."""
    run_gemm(
        M=512,
        N=1024,
        K=768,
        trans_A=False,
        trans_B=False,
        in_dtype="float16",
        out_dtype="float16",
        dtypeAccum="float16",
        block_M=128,
        block_N=256,
        block_K=32,
        num_stages=2,
    )


def matmu_jit_kernel(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    threads,
):
    """Build the tiled GEMM prim_func used by the JIT-kernel test.

    The kernel is identical to the one built by ``matmul``, so this delegates to
    it instead of keeping a second copy (and a redundant local re-import of
    ``tilelang.language``) that could drift out of sync. The existing name and
    signature are kept so current callers keep working.

    Returns:
        The ``main`` prim_func taking buffers ``(A, B, C)``.
    """
    return matmul(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        trans_A,
        trans_B,
        in_dtype,
        out_dtype,
        accum_dtype,
        num_stages,
        threads,
    )


def run_gemm_jit_kernel(
    M,
    N,
    K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    dtypeAccum,
    block_M,
    block_N,
    block_K,
    num_stages=3,
    num_threads=128,
):
    """Compile the JIT GEMM with the Cython backend and compare it against torch.

    The kernel is compiled with ``out_idx=-1``, so it is called with ``(A, B)``
    and returns C. A and B are allocated directly in the storage layout the
    kernel's buffers declare: ``(K, M)`` for A when ``trans_A`` and ``(N, K)``
    for B when ``trans_B``. The reference transposes them back before
    multiplying in float32.
    """
    program = matmu_jit_kernel(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        trans_A,
        trans_B,
        in_dtype,
        out_dtype,
        dtypeAccum,
        num_stages,
        num_threads,
    )

    matmul_kernel = tilelang.compile(program, out_idx=-1, execution_backend="cython")

    torch_in_dtype = map_torch_type(in_dtype)
    torch_out_dtype = map_torch_type(out_dtype)

    # Allocate contiguous tensors in the kernel's storage layout. Taking `.T` of
    # an (M, K) tensor would hand the kernel a non-contiguous view, and would also
    # give the reference matmul mismatched shapes.
    a_shape = (K, M) if trans_A else (M, K)
    b_shape = (N, K) if trans_B else (K, N)
    A = torch.randn(*a_shape, dtype=torch_in_dtype).cuda()
    B = torch.randn(*b_shape, dtype=torch_in_dtype).cuda()

    def ref_program(A, B):
        # Undo the storage transposes so the reference computes (M, K) @ (K, N).
        A_ref = A.T if trans_A else A
        B_ref = B.T if trans_B else B
        C = torch.matmul(A_ref.to(torch.float), B_ref.to(torch.float))
        return C.to(torch_out_dtype)

    ref_C = ref_program(A, B)
    C = matmul_kernel(A, B)

    tilelang.testing.torch_assert_close(C, ref_C, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)


def test_gemm_jit_kernel():
    """fp16 GEMM without transposes, checked numerically via ``run_gemm_jit_kernel``."""
    run_gemm_jit_kernel(
        M=512,
        N=1024,
        K=768,
        trans_A=False,
        trans_B=False,
        in_dtype="float16",
        out_dtype="float16",
        dtypeAccum="float16",
        block_M=128,
        block_N=256,
        block_K=32,
        num_stages=2,
    )


def run_cython_kernel_do_bench(M,
                               N,
                               K,
                               trans_A,
                               trans_B,
                               in_dtype,
                               out_dtype,
                               dtypeAccum,
                               block_M,
                               block_N,
                               block_K,
                               num_stages=3,
                               num_threads=128):
    """Benchmark the same GEMM through the Cython and ctypes backends.

    Three measurements are taken and each is asserted non-None right after it
    is measured:
    the Cython kernel via the torch profiler, the profiler's default
    ``do_bench()`` (reported as "TVM"), and the ctypes kernel via the torch
    profiler.
    """
    program = matmul(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        trans_A,
        trans_B,
        in_dtype,
        out_dtype,
        dtypeAccum,
        num_stages,
        num_threads,
    )

    cython_matmul_kernel = tilelang.compile(program, execution_backend="cython")
    ctypes_matmul_kernel = tilelang.compile(program, execution_backend="ctypes")

    cython_profiler = cython_matmul_kernel.get_profiler()
    ctypes_profiler = ctypes_matmul_kernel.get_profiler()

    cython_latency = cython_profiler.do_bench(func=cython_matmul_kernel, profiler="torch")
    print(f"cython Latency: {cython_latency} ms")
    assert cython_latency is not None

    tvm_latency = cython_profiler.do_bench()
    print(f"TVM Latency: {tvm_latency} ms")
    assert tvm_latency is not None

    ctypes_latency = ctypes_profiler.do_bench(func=ctypes_matmul_kernel, profiler="torch")
    print(f"ctypes Latency: {ctypes_latency} ms")
    # Previously commented out (and placed before ctypes_latency existed), so the
    # ctypes measurement was never checked.
    assert ctypes_latency is not None


def test_cython_kernel_do_bench():
    """Benchmark an fp16 GEMM without transposes on the Cython and ctypes backends."""
    run_cython_kernel_do_bench(
        M=512,
        N=1024,
        K=768,
        trans_A=False,
        trans_B=False,
        in_dtype="float16",
        out_dtype="float16",
        dtypeAccum="float16",
        block_M=128,
        block_N=256,
        block_K=32,
        num_stages=2,
    )


def run_cython_kernel_multi_stream(M,
                                   N,
                                   K,
                                   trans_A,
                                   trans_B,
                                   in_dtype,
                                   out_dtype,
                                   dtypeAccum,
                                   block_M,
                                   block_N,
                                   block_K,
                                   num_stages=3,
                                   num_threads=128):
    """Launch the Cython-backend GEMM on several freshly created CUDA streams.

    The kernel is compiled without ``out_idx``, so C is passed in explicitly.
    Each launch happens on its own ``torch.cuda.Stream``. Every stream first
    waits on the current stream so the input copies are ordered before the
    launches. Afterwards the device is synchronized and C is checked against
    a float32 torch reference.
    """
    program = matmul(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        trans_A,
        trans_B,
        in_dtype,
        out_dtype,
        dtypeAccum,
        num_stages,
        num_threads,
    )

    matmul_kernel = tilelang.compile(program, execution_backend="cython")

    torch_in_dtype = map_torch_type(in_dtype)
    torch_out_dtype = map_torch_type(out_dtype)

    # Allocate contiguous inputs in the kernel's storage layout (see matmul's
    # A_shape/B_shape) rather than passing non-contiguous `.T` views.
    a_shape = (K, M) if trans_A else (M, K)
    b_shape = (N, K) if trans_B else (K, N)
    tensor_a = torch.randn(*a_shape, dtype=torch_in_dtype).cuda()
    tensor_b = torch.randn(*b_shape, dtype=torch_in_dtype).cuda()
    tensor_c = torch.randn(M, N, dtype=torch_out_dtype).cuda()

    num_streams = 4
    producer_stream = torch.cuda.current_stream()
    for _ in range(num_streams):
        stream = torch.cuda.Stream()
        # A new side stream is not implicitly ordered after work queued on the
        # current stream (the tensor copies above), so add the dependency.
        stream.wait_stream(producer_stream)
        with torch.cuda.stream(stream):
            matmul_kernel(tensor_a, tensor_b, tensor_c)

    # Launches are asynchronous; wait for all streams so that launch or
    # execution errors surface inside this test.
    torch.cuda.synchronize()

    a_ref = tensor_a.T if trans_A else tensor_a
    b_ref = tensor_b.T if trans_B else tensor_b
    tensor_ref_c = torch.matmul(a_ref.to(torch.float), b_ref.to(torch.float)).to(torch_out_dtype)
    tilelang.testing.torch_assert_close(
        tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)


def test_cython_kernel_multi_stream():
    """Run an fp16 GEMM without transposes across multiple CUDA streams."""
    run_cython_kernel_multi_stream(
        M=512,
        N=1024,
        K=768,
        trans_A=False,
        trans_B=False,
        in_dtype="float16",
        out_dtype="float16",
        dtypeAccum="float16",
        block_M=128,
        block_N=256,
        block_K=32,
        num_stages=2,
    )


def run_cython_dynamic_shape(M,
                             N,
                             K,
                             trans_A,
                             trans_B,
                             in_dtype,
                             out_dtype,
                             dtypeAccum,
                             block_M,
                             block_N,
                             block_K,
                             num_stages=3,
                             num_threads=128):
    """Compile a GEMM whose M/N/K may be symbolic, then run it on concrete sizes.

    Any dimension passed as a ``T.Var`` is bound to a fixed size for the launch:
    M=1024, N=1024, K=768. C is passed in explicitly and compared against a
    float32 torch reference.
    """
    program = matmul(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        trans_A,
        trans_B,
        in_dtype,
        out_dtype,
        dtypeAccum,
        num_stages,
        num_threads,
    )

    matmul_kernel = tilelang.compile(program, execution_backend="cython")

    # Concrete sizes for any dimension that was compiled symbolically.
    if isinstance(M, T.Var):
        M = 1024
    if isinstance(N, T.Var):
        N = 1024
    if isinstance(K, T.Var):
        K = 768

    torch_in_dtype = map_torch_type(in_dtype)
    torch_out_dtype = map_torch_type(out_dtype)

    # Allocate contiguous inputs in the kernel's storage layout; the reference
    # below undoes the transposes.
    a_shape = (K, M) if trans_A else (M, K)
    b_shape = (N, K) if trans_B else (K, N)
    tensor_a = torch.randn(*a_shape, dtype=torch_in_dtype).cuda()
    tensor_b = torch.randn(*b_shape, dtype=torch_in_dtype).cuda()
    tensor_c = torch.randn(M, N, dtype=torch_out_dtype).cuda()

    matmul_kernel(tensor_a, tensor_b, tensor_c)

    a_ref = tensor_a.T if trans_A else tensor_a
    b_ref = tensor_b.T if trans_B else tensor_b
    tensor_ref_c = torch.matmul(a_ref.to(torch.float), b_ref.to(torch.float)).to(torch_out_dtype)
    tilelang.testing.torch_assert_close(
        tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)


def test_cython_dynamic_shape():
    """Make M, then M and N, then M, N and K symbolic; all other settings fixed."""
    for symbolic_dims in ("m", "mn", "mnk"):
        run_cython_dynamic_shape(
            T.symbolic("m"),
            T.symbolic("n") if "n" in symbolic_dims else 1024,
            T.symbolic("k") if "k" in symbolic_dims else 768,
            False,
            False,
            "float16",
            "float16",
            "float16",
            128,
            256,
            32,
            2,
        )


def run_cython_dynamic_shape_with_out_idx(M,
                                          N,
                                          K,
                                          trans_A,
                                          trans_B,
                                          in_dtype,
                                          out_dtype,
                                          dtypeAccum,
                                          block_M,
                                          block_N,
                                          block_K,
                                          num_stages=3,
                                          num_threads=128):
    """Like ``run_cython_dynamic_shape``, but compiled with ``out_idx=-1``.

    The kernel is called with ``(A, B)`` and returns C. Symbolic dimensions are
    bound to M=1024, N=1024, K=768 for the launch. The result is compared
    against a float32 torch reference.
    """
    program = matmul(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        trans_A,
        trans_B,
        in_dtype,
        out_dtype,
        dtypeAccum,
        num_stages,
        num_threads,
    )

    matmul_kernel = tilelang.compile(program, execution_backend="cython", out_idx=-1)

    # Concrete sizes for any dimension that was compiled symbolically.
    if isinstance(M, T.Var):
        M = 1024
    if isinstance(N, T.Var):
        N = 1024
    if isinstance(K, T.Var):
        K = 768

    torch_in_dtype = map_torch_type(in_dtype)
    torch_out_dtype = map_torch_type(out_dtype)

    # Allocate contiguous inputs in the kernel's storage layout; the reference
    # below undoes the transposes.
    a_shape = (K, M) if trans_A else (M, K)
    b_shape = (N, K) if trans_B else (K, N)
    tensor_a = torch.randn(*a_shape, dtype=torch_in_dtype).cuda()
    tensor_b = torch.randn(*b_shape, dtype=torch_in_dtype).cuda()

    tensor_c = matmul_kernel(tensor_a, tensor_b)

    a_ref = tensor_a.T if trans_A else tensor_a
    b_ref = tensor_b.T if trans_B else tensor_b
    tensor_ref_c = torch.matmul(a_ref.to(torch.float), b_ref.to(torch.float)).to(torch_out_dtype)
    tilelang.testing.torch_assert_close(
        tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)


def test_cython_dynamic_shape_with_out_idx():
    """Symbolic-M fp16 GEMM whose output is returned by the kernel (out_idx=-1)."""
    run_cython_dynamic_shape_with_out_idx(
        M=T.symbolic("m"),
        N=1024,
        K=768,
        trans_A=False,
        trans_B=False,
        in_dtype="float16",
        out_dtype="float16",
        dtypeAccum="float16",
        block_M=128,
        block_N=256,
        block_K=32,
        num_stages=2,
    )


if __name__ == "__main__":
    # When this file is executed directly, hand control to tilelang's test entry point.
    tilelang.testing.main()