# test_tilelang_jit_nvrtc.py -- tests for tilelang's NVRTC execution backend.
from tilelang import tvm as tvm
import tilelang.language as T
import tilelang.testing
import tilelang
import torch
from tilelang.utils.tensor import map_torch_type


def matmul(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    threads,
):
    """Build a tiled GEMM prim_func computing ``C = op(A) @ op(B)``.

    A is stored as ``(K, M)`` when ``trans_A`` is set and as ``(M, K)``
    otherwise; B is stored as ``(N, K)`` when ``trans_B`` is set and as
    ``(K, N)`` otherwise. Each thread block produces one
    ``block_M x block_N`` tile of C, iterating over K in ``block_K`` steps
    with a ``num_stages``-deep pipeline and accumulating in ``accum_dtype``.

    Returns:
        The ``T.prim_func`` ``main(A, B, C)``.
    """
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

    @T.prim_func
    def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                # The tile origin depends on whether the operand is stored transposed.
                if trans_A:
                    T.copy(A[k * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def run_gemm(
    M,
    N,
    K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    dtypeAccum,
    block_M,
    block_N,
    block_K,
    num_stages=3,
    num_threads=128,
):
    """Check that the ``tilelang_callback_cuda_postproc`` hook runs under NVRTC.

    Registers a post-processing callback that prepends a marker comment to the
    generated CUDA source, compiles the GEMM with the NVRTC backend, and
    asserts the marker shows up in the kernel source. The hook is reset to a
    pass-through afterwards.
    """
    program = matmul(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        trans_A,
        trans_B,
        in_dtype,
        out_dtype,
        dtypeAccum,
        num_stages,
        num_threads,
    )

    stramp = "&*(XS)"

    @tvm.register_global_func("tilelang_callback_cuda_postproc", override=True)
    def tilelang_callback_cuda_postproc(code, _):
        code = f"// {stramp}\n" + code
        return code

    try:
        matmul_kernel = tilelang.compile(program, out_idx=-1, execution_backend="nvrtc")
        kernel_source = matmul_kernel.get_kernel_source()
    finally:
        # The hook is registered process-wide, so without this reset every
        # later compile in the same session would keep prepending the marker.
        # Re-register a pass-through that returns the code unchanged.
        @tvm.register_global_func("tilelang_callback_cuda_postproc", override=True)
        def _passthrough_cuda_postproc(code, _):
            return code

    assert stramp in kernel_source, f"Expected {stramp} in the kernel source"


def test_gemm_f16f16f16_nn():
    """fp16 GEMM without transposes, exercising the NVRTC CUDA post-processing hook."""
    run_gemm(
        512,
        1024,
        768,
        False,
        False,
        T.float16,
        T.float16,
        T.float16,
        128,
        256,
        32,
        2,
    )


def matmu_jit_kernel(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    threads,
):
    """Build the tiled GEMM prim_func used by the JIT-kernel test.

    This previously duplicated :func:`matmul` line for line. It now delegates
    to it so the two kernel definitions cannot drift apart; the name is kept
    for existing callers.

    Returns:
        The ``T.prim_func`` ``main(A, B, C)`` produced by ``matmul``.
    """
    return matmul(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        trans_A,
        trans_B,
        in_dtype,
        out_dtype,
        accum_dtype,
        num_stages,
        threads,
    )


def run_gemm_jit_kernel(
    M,
    N,
    K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    dtypeAccum,
    block_M,
    block_N,
    block_K,
    num_stages=3,
    num_threads=128,
):
    """Compile the GEMM with the NVRTC backend and compare it with torch.matmul.

    Inputs are created directly in their stored layout: A is ``(K, M)`` when
    ``trans_A`` is set and ``(M, K)`` otherwise, and B is ``(N, K)`` when
    ``trans_B`` is set and ``(K, N)`` otherwise. This keeps them contiguous.
    The reference transposes them back before multiplying, so it computes the
    same ``op(A) @ op(B)`` as the kernel.
    """
    program = matmu_jit_kernel(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        trans_A,
        trans_B,
        in_dtype,
        out_dtype,
        dtypeAccum,
        num_stages,
        num_threads,
    )

    matmul_kernel = tilelang.compile(program, out_idx=-1, execution_backend="nvrtc")

    in_dtype = map_torch_type(in_dtype)
    out_dtype = map_torch_type(out_dtype)

    a_shape = (K, M) if trans_A else (M, K)
    b_shape = (N, K) if trans_B else (K, N)
    A = torch.randn(*a_shape, dtype=in_dtype).cuda()
    B = torch.randn(*b_shape, dtype=in_dtype).cuda()

    def ref_program(A, B):
        # Undo the storage transposes so the reference multiplies op(A) @ op(B).
        if trans_A:
            A = A.T
        if trans_B:
            B = B.T
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        return C.to(out_dtype)

    ref_C = ref_program(A, B)
    C = matmul_kernel(A, B)

    tilelang.testing.torch_assert_close(C, ref_C, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)


def test_gemm_jit_kernel():
    """NVRTC-compiled fp16 GEMM matches a torch reference."""
    run_gemm_jit_kernel(
        512,
        1024,
        768,
        False,
        False,
        T.float16,
        T.float16,
        T.float16,
        128,
        256,
        32,
        2,
    )


def run_nvrtc_kernel_do_bench(
    M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128
):
    """Benchmark an NVRTC-compiled GEMM through its profiler.

    Latency is measured twice: once with the compiled kernel passed explicitly
    as ``func``, and once with the profiler's default (no ``func`` argument).
    Both measurements must produce a value.
    """
    program = matmul(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        trans_A,
        trans_B,
        in_dtype,
        out_dtype,
        dtypeAccum,
        num_stages,
        num_threads,
    )

    matmul_kernel = tilelang.compile(program, execution_backend="nvrtc")

    profiler = matmul_kernel.get_profiler()

    nvrtc_latency = profiler.do_bench(func=matmul_kernel)
    print(f"NVRTC Latency: {nvrtc_latency} ms")

    assert nvrtc_latency is not None

    tvm_latency = profiler.do_bench()
    print(f"TVM Latency: {tvm_latency} ms")

    assert tvm_latency is not None


def test_nvrtc_kernel_do_bench():
    """Profiling an NVRTC fp16 GEMM yields latency measurements."""
    run_nvrtc_kernel_do_bench(512, 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2)


def run_nvrtc_kernel_multi_stream(
    M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128
):
    """Launch an NVRTC-compiled GEMM from several CUDA streams and check the output.

    The kernel has no ``out_idx``, so the output tensor is passed in and
    written by every launch. Each launch computes the same product, so the
    final ``tensor_c`` must equal the torch reference.
    """
    program = matmul(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        trans_A,
        trans_B,
        in_dtype,
        out_dtype,
        dtypeAccum,
        num_stages,
        num_threads,
    )

    matmul_kernel = tilelang.compile(program, execution_backend="nvrtc")
    in_dtype = map_torch_type(in_dtype)
    out_dtype = map_torch_type(out_dtype)

    # Allocate inputs directly in their stored (possibly transposed) layout so
    # they are contiguous, instead of passing strided `.T` views.
    a_shape = (K, M) if trans_A else (M, K)
    b_shape = (N, K) if trans_B else (K, N)
    tensor_a = torch.randn(*a_shape, dtype=in_dtype).cuda()
    tensor_b = torch.randn(*b_shape, dtype=in_dtype).cuda()
    tensor_c = torch.randn(M, N, dtype=out_dtype).cuda()

    num_streams = 4
    producer = torch.cuda.current_stream()
    for _ in range(num_streams):
        stream = torch.cuda.Stream()
        # The inputs were filled on the current stream; order the side stream
        # after that work so the kernel cannot read them early.
        stream.wait_stream(producer)
        with torch.cuda.stream(stream):
            matmul_kernel(tensor_a, tensor_b, tensor_c)

    # Wait for every stream so asynchronous launch errors surface in this test
    # and tensor_c is safe to read.
    torch.cuda.synchronize()

    ref_a = tensor_a.T if trans_A else tensor_a
    ref_b = tensor_b.T if trans_B else tensor_b
    tensor_ref_c = torch.matmul(ref_a.to(torch.float), ref_b.to(torch.float)).to(out_dtype)
    tilelang.testing.torch_assert_close(tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)


def test_nvrtc_kernel_multi_stream():
    """NVRTC fp16 GEMM launched from multiple CUDA streams."""
    run_nvrtc_kernel_multi_stream(512, 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2)


def run_nvrtc_dynamic_shape(
    M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages=3, num_threads=128
):
    """Compile a GEMM whose M/N/K may be symbolic and check it at concrete sizes.

    Any of ``M``, ``N``, ``K`` that is a ``T.Var`` is replaced with a fixed
    runtime size (1024, 1024, 768) after compilation. The kernel writes into
    the caller-provided ``tensor_c``, which is compared against torch.matmul
    on the un-transposed operands.
    """
    program = matmul(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        trans_A,
        trans_B,
        in_dtype,
        out_dtype,
        dtypeAccum,
        num_stages,
        num_threads,
    )

    matmul_kernel = tilelang.compile(program, execution_backend="nvrtc")
    # Pick concrete runtime sizes for the dimensions that were left symbolic.
    if isinstance(M, T.Var):
        M = 1024
    if isinstance(N, T.Var):
        N = 1024
    if isinstance(K, T.Var):
        K = 768

    in_dtype = map_torch_type(in_dtype)
    out_dtype = map_torch_type(out_dtype)

    # Allocate inputs in their stored layout so they are contiguous.
    a_shape = (K, M) if trans_A else (M, K)
    b_shape = (N, K) if trans_B else (K, N)
    tensor_a = torch.randn(*a_shape, dtype=in_dtype).cuda()
    tensor_b = torch.randn(*b_shape, dtype=in_dtype).cuda()
    tensor_c = torch.randn(M, N, dtype=out_dtype).cuda()

    matmul_kernel(tensor_a, tensor_b, tensor_c)

    # The reference must multiply op(A) @ op(B), not the stored tensors.
    ref_a = tensor_a.T if trans_A else tensor_a
    ref_b = tensor_b.T if trans_B else tensor_b
    tensor_ref_c = torch.matmul(ref_a.to(torch.float), ref_b.to(torch.float)).to(out_dtype)
    tilelang.testing.torch_assert_close(tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)


def test_nvrtc_dynamic_shape():
    """NVRTC GEMM with one, two and then all three dimensions symbolic."""
    run_nvrtc_dynamic_shape(T.dynamic("m"), 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2)

    run_nvrtc_dynamic_shape(
        T.dynamic("m"), T.dynamic("n"), 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2
    )

    run_nvrtc_dynamic_shape(
        T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2
    )


def check_hopper():
    """Return True when CUDA device 0 reports compute capability 9.0.

    Returns False whenever CUDA is unavailable.
    """
    if torch.cuda.is_available():
        device_props = torch.cuda.get_device_properties(0)
        return (device_props.major, device_props.minor) == (9, 0)
    return False


def convolution_im2col(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype=T.float16, accum_dtype=T.float32):
    """Build an NHWC 2-D convolution prim_func lowered to a GEMM over im2col tiles.

    Args:
        N, C, H, W: batch size, input channels and input spatial size
            (``data`` is ``(N, H, W, C)``).
        F: number of output channels (``kernel`` is ``(K, K, C, F)``).
        K: square kernel size; S: stride; D: dilation; P: padding.
        block_M, block_N, block_K: GEMM tile sizes over
            ``(N*OH*OW) x F x (K*K*C)``.
        num_stages: software-pipeline depth of the reduction loop.
        threads: threads per block.
        dtype: storage dtype of inputs and output.
        accum_dtype: dtype of the per-tile accumulator.

    Returns:
        The ``T.prim_func`` ``main(data, kernel, out)``, where ``out`` is
        ``(N, OH, OW, F)``.
    """
    KH, KW = K, K
    OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
    OW = (W + 2 * P - D * (K - 1) - 1) // S + 1

    @T.prim_func
    def main(
        data: T.Tensor((N, H, W, C), dtype),
        kernel: T.Tensor((KH, KW, C, F), dtype),
        out: T.Tensor((N, OH, OW, F), dtype),
    ):
        with T.Kernel(T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), threads=threads) as (bx, by):
            data_shared = T.alloc_shared((block_M, block_K), dtype)
            kernel_shared = T.alloc_shared((block_K, block_N), dtype)
            out_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            out_shared = T.alloc_shared((block_M, block_N), dtype)

            # 2-D views over the same storage as `kernel` and `out`, so the
            # convolution can be written as a plain (rows x F) GEMM.
            kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data)
            out_flat = T.Tensor((N * OH * OW, F), dtype, out.data)

            T.annotate_layout(
                {
                    out_shared: tilelang.layout.make_swizzled_layout(out_shared),
                    data_shared: tilelang.layout.make_swizzled_layout(data_shared),
                    kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared),
                }
            )

            T.clear(out_local)
            for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages):
                # Gather the im2col tile of the input for this (row-block, k-step).
                T.c2d_im2col(data, data_shared, by, k_iter, KH, S, D, P)
                T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared)
                T.gemm(data_shared, kernel_shared, out_local)

            T.copy(out_local, out_shared)
            T.copy(out_shared, out_flat[by * block_M, bx * block_N])

    return main


def run_nvrtc_im2col_tma_desc(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages=3, num_threads=256):
    """Test im2col TMA descriptor functionality in NVRTC backend."""
    program = convolution_im2col(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, num_threads)

    conv_kernel = tilelang.compile(program, out_idx=-1, execution_backend="nvrtc")

    a = torch.randn(N, H, W, C).cuda().half()
    b = torch.randn(K, K, C, F).cuda().half()

    out_c = conv_kernel(a, b)

    # Reference implementation using torch.conv2d
    def ref_program(A, B):
        A = A.permute(0, 3, 1, 2)  # N, H, W, C -> N, C, H, W
        B = B.permute(3, 2, 0, 1)  # KH, KW, C, F -> F, C, KH, KW
        C = torch.conv2d(A, B, stride=S, padding=P, dilation=D)
        C = C.permute(0, 2, 3, 1)  # N, F, OH, OW -> N, OH, OW, F
        return C

    ref_c = ref_program(a, b)
    tilelang.testing.torch_assert_close(out_c, ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)


def test_nvrtc_im2col_tma_desc():
    """Test im2col TMA descriptor with NVRTC backend."""
    if not check_hopper():
        import pytest

        pytest.skip("Test requires Hopper GPU (compute capability 9.0)")

    # Small test case for im2col TMA descriptor
    run_nvrtc_im2col_tma_desc(
        N=4, C=64, H=32, W=32, F=64, K=3, S=1, D=1, P=1, block_M=64, block_N=128, block_K=32, num_stages=3, num_threads=256
    )


def test_nvrtc_l2_persistent_map():
    """Test L2 persistent cache annotation with elementwise add."""
    from tilelang.language import annotate_l2_hit_ratio

    M = 1024
    N = 1024

    @tilelang.jit(out_idx=[-1], execution_backend="nvrtc")
    def elementwise_add_with_l2_cache(
        M,
        N,
        block_size=256,
        dtype=T.float32,
    ):
        @T.prim_func
        def kernel(
            A: T.Tensor((M, N), dtype),
            B: T.Tensor((M, N), dtype),
            C: T.Tensor((M, N), dtype),
        ):
            # Round the grid up so the tail elements are covered when M * N is
            # not a multiple of block_size; the `idx < M * N` guard below then
            # masks the overhang. Floor division would silently skip the tail.
            with T.Kernel(T.ceildiv(M * N, block_size), threads=block_size) as bx:
                # Annotate L2 persistent cache for buffer B
                # B will be accessed multiple times and benefit from L2 caching
                annotate_l2_hit_ratio({B: 0.8})

                # NOTE(review): this serial loop is not bound to a thread index;
                # confirm whether every thread of the block repeats it.
                for i in T.serial(block_size):
                    idx = bx * block_size + i
                    if idx < M * N:
                        row = idx // N
                        col = idx % N
                        C[row, col] = A[row, col] + B[row, col]

        return kernel

    # Compile the kernel
    kernel = elementwise_add_with_l2_cache(M, N)

    # Create test tensors
    a = torch.randn(M, N, dtype=torch.float32).cuda()
    b = torch.randn(M, N, dtype=torch.float32).cuda()

    # Run kernel with out_idx=[-1], C is returned not passed in
    c = kernel(a, b)

    # Verify correctness
    ref_c = a + b
    tilelang.testing.torch_assert_close(c, ref_c, atol=1e-5, rtol=1e-5)

    print("L2 persistent map test passed!")


if __name__ == "__main__":
    # Allow running this test module directly as a script via tilelang's test runner.
    tilelang.testing.main()