"""Tests for TileLang's nested-loop checker (test_tilelang_nested_loop_checker.py).

Each section builds small kernels that nest ``T.Parallel``, ``T.Pipelined``
and ``T.serial`` loops, then checks which nestings compile and produce correct
results and which are rejected with ``ValueError``.
"""
import tilelang
import tilelang.language as T
import torch
import tilelang.testing
import pytest

# Fix the random seed via the tilelang.testing helper so the random inputs used
# below (e.g. torch.randn in _require_cuda_tensor) are reproducible across runs.
# NOTE(review): presumably this seeds torch's RNG as well — confirm.
tilelang.testing.set_random_seed()


def _require_cuda_tensor(shape, dtype=torch.float32):
    """Return a random ``torch.randn`` tensor on the CUDA device, or skip the test.

    Args:
        shape: Tensor shape, unpacked into ``torch.randn``.
        dtype: Element dtype of the returned tensor.

    Returns:
        A freshly sampled tensor with ``device="cuda"``.

    The calling test is skipped via ``pytest.skip`` when torch reports no CUDA
    device, or when allocating the tensor raises ``RuntimeError``.
    """
    if torch.cuda.is_available():
        try:
            return torch.randn(*shape, device="cuda", dtype=dtype)
        except RuntimeError as exc:
            pytest.skip(f"CUDA runtime unavailable: {exc}")
    else:
        pytest.skip("CUDA not available")


"""
Nested Parallel cases:

T.Parallel
    T.Parallel

Rule:
    - Continuous (perfectly nested) T.Parallel loops are allowed and are merged into one T.Parallel.
    - Non-continuous nests (e.g. with other statements in the outer loop body) are forbidden.
"""


@tilelang.jit(out_idx=[1])
def nested_continuous_parallels(length=256, block=16, dtype="float32"):
    """Build ``B = A + 1`` with two perfectly nested ``T.Parallel`` loops.

    The outer loop walks ``length // block`` chunks and the inner loop walks
    the ``block`` elements of each chunk. No statement sits between the two
    loops, so the checker is expected to accept (and merge) the nest.

    Args:
        length: Number of elements in ``A`` and ``B``; also the thread count.
        block: Inner loop extent. It should divide ``length``; otherwise the
            tail elements of ``B`` are never written.
        dtype: Element dtype of both tensors.

    Returns:
        The ``T.prim_func``. Through ``tilelang.jit`` callers receive a compiled
        kernel that returns argument index 1 (``B``).
    """

    @T.prim_func
    def main(
        A: T.Tensor((length,), dtype),
        B: T.Tensor((length,), dtype),
    ):
        with T.Kernel(1, threads=length) as _:
            for i in T.Parallel(length // block):
                for j in T.Parallel(block):
                    B[i * block + j] = A[i * block + j] + 1.0

    return main


@tilelang.jit(out_idx=[1])
def nested_triple_continuous_parallels(length=256, block1=8, block2=2, dtype="float32"):
    """Build ``B = A + 1`` with three perfectly nested ``T.Parallel`` loops.

    The flat index is ``i * block1 * block2 + j * block2 + k``. Every loop
    directly contains the next, so the nest is expected to be accepted.

    Args:
        length: Number of elements in ``A`` and ``B``; also the thread count.
        block1: Extent of the middle loop.
        block2: Extent of the innermost loop. ``block1 * block2`` should divide
            ``length``; otherwise the tail elements are never written.
        dtype: Element dtype of both tensors.

    Returns:
        The ``T.prim_func``. Through ``tilelang.jit`` callers receive a compiled
        kernel that returns argument index 1 (``B``).
    """

    @T.prim_func
    def main(
        A: T.Tensor((length,), dtype),
        B: T.Tensor((length,), dtype),
    ):
        with T.Kernel(1, threads=length) as _:
            for i in T.Parallel(length // block1 // block2):
                for j in T.Parallel(block1):
                    for k in T.Parallel(block2):
                        B[i * block1 * block2 + j * block2 + k] = A[i * block1 * block2 + j * block2 + k] + 1.0

    return main


@tilelang.jit(out_idx=[1])
def nested_noncontinuous_parallels(length=256, block=16, dtype="float32"):
    """Build a *non-continuous* ``T.Parallel`` nest that the checker must reject.

    The outer parallel loop body contains a statement (``B[i] = 0``) before the
    inner ``T.Parallel``, so the two loops cannot be merged. The test expects
    building this kernel to raise ``ValueError``.

    Args:
        length: Number of elements in ``A`` and ``B``; also the thread count.
        block: Inner loop extent.
        dtype: Element dtype of both tensors.

    Returns:
        The ``T.prim_func`` (if the checker does not reject it first).
    """

    @T.prim_func
    def main(
        A: T.Tensor((length,), dtype),
        B: T.Tensor((length,), dtype),
    ):
        with T.Kernel(1, threads=length) as _:
            for i in T.Parallel(length // block):
                # This statement between the two loops is what makes the nest invalid.
                B[i] = 0
                for j in T.Parallel(block):
                    B[i * block + j] = A[i * block + j] + 1.0

    return main


def test_nested_parallels():
    """Continuous ``T.Parallel`` nests compile and run; a non-continuous nest is rejected."""
    # Acquire the CUDA input first so the test skips *before* any kernel is
    # compiled on machines without a usable CUDA device.
    data = _require_cuda_tensor((256,), torch.float32)

    kernel1 = nested_continuous_parallels(length=256, block=16)
    kernel2 = nested_triple_continuous_parallels(length=256, block1=8, block2=2)
    result1 = kernel1(data)
    result2 = kernel2(data)
    torch.testing.assert_close(result1, data + 1.0, atol=1e-5, rtol=1e-5)
    torch.testing.assert_close(result2, data + 1.0, atol=1e-5, rtol=1e-5)

    # This is invalid: a statement sits between the two parallel loops.
    with pytest.raises(ValueError):
        nested_noncontinuous_parallels(length=256, block=16)


"""
Nested Pipelined cases:

T.Pipelined
    T.Pipelined

A T.Pipelined loop nested inside another T.Pipelined loop is OK.
"""


def matmul_nested_pipelines(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    accum_dtype,
    threads,
    order,
    stage,
    extra_pipeline_repeats,
):
    """Build a tiled GEMM ``C = op(A) @ op(B)`` whose K-loop pipeline is nested in another pipeline.

    The outer ``T.Pipelined(extra_pipeline_repeats)`` loop repeats the whole
    tile computation. The accumulator is cleared at the start of every repeat,
    so the stored ``C`` equals a single GEMM. The inner ``T.Pipelined`` loop
    walks the K tiles with the given ``order``/``stage`` schedule.

    Args:
        M, N, K: GEMM problem sizes.
        block_M, block_N, block_K: Tile sizes along M, N and K.
        trans_A: If true, ``A`` is stored as ``(K, M)`` and transposed in the GEMM.
        trans_B: If true, ``B`` is stored as ``(N, K)`` and transposed in the GEMM.
        in_dtype: Dtype of ``A``, ``B`` and the shared tiles.
        out_dtype: Dtype of ``C``.
        accum_dtype: Dtype of the accumulator fragment.
        threads: Threads per block.
        order, stage: Schedule passed to the inner ``T.Pipelined`` loop.
        extra_pipeline_repeats: Extent of the outer ``T.Pipelined`` loop.

    Returns:
        The uncompiled ``T.prim_func`` taking ``(A, B, C)``.
    """
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

    @T.prim_func
    def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            for _ in T.Pipelined(extra_pipeline_repeats):
                # Reset per repeat so every repeat produces the same tile result.
                T.clear(C_local)
                for k in T.Pipelined(T.ceildiv(K, block_K), order=order, stage=stage):
                    if trans_A:
                        T.copy(A[k * block_K, by * block_M], A_shared)
                    else:
                        T.copy(A[by * block_M, k * block_K], A_shared)
                    if trans_B:
                        T.copy(B[bx * block_N, k * block_K], B_shared)
                    else:
                        T.copy(B[k * block_K, bx * block_N], B_shared)
                    T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
                T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def run_gemm_nested_pipelines(
    order,
    stage,
    extra_pipeline_repeats,
):
    """Compile ``matmul_nested_pipelines`` and check it against ``torch.matmul``.

    Problem and tile sizes are fixed (1024^3 GEMM, 128x128x32 tiles, fp16 in/out,
    fp32 accumulation, 128 threads). TMA lowering and warp specialization are
    disabled for the compile.

    Args:
        order, stage: Schedule for the inner ``T.Pipelined`` loop.
        extra_pipeline_repeats: Extent of the outer ``T.Pipelined`` loop.
    """
    M = N = K = 1024
    block_M = 128
    block_N = 128
    block_K = 32
    trans_A = False
    trans_B = False
    in_dtype = "float16"
    out_dtype = "float16"
    dtypeAccum = "float32"
    num_threads = 128

    program = matmul_nested_pipelines(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        trans_A,
        trans_B,
        in_dtype,
        out_dtype,
        dtypeAccum,
        num_threads,
        order,
        stage,
        extra_pipeline_repeats,
    )
    kernel = tilelang.compile(
        program,
        out_idx=[2],
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        },
    )
    profiler = kernel.get_profiler()

    def ref_program(A, B):
        """Reference GEMM computed in fp32 and cast to ``out_dtype``."""
        if trans_A:
            A = A.T
        if trans_B:
            B = B.T
        if in_dtype == "float32":
            # The tf32 MMA does not truncate float32 inputs automatically.
            # Subtracting 0x1000 from the raw int32 bit pattern adjusts the low
            # mantissa bits, presumably to mimic tf32 rounding (TODO: confirm).
            A = (A.view(torch.int32) - 0x1000).view(torch.float32)
            B = (B.view(torch.int32) - 0x1000).view(torch.float32)
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        return C.to(getattr(torch, out_dtype))

    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)


def test_nested_pipelines():
    """A ``T.Pipelined`` loop nested in another ``T.Pipelined`` compiles and matches torch."""
    schedule_order = [0, 1, 2]
    schedule_stage = [0, 0, 1]
    run_gemm_nested_pipelines(
        order=schedule_order,
        stage=schedule_stage,
        extra_pipeline_repeats=3,
    )


"""
Nested serial cases:

T.serial
    T.serial

is OK.
"""


@tilelang.jit(out_idx=[1])
def nested_continuous_serials(length=256, block=16, dtype="float32"):
    """Build ``B = A + 1`` with two perfectly nested ``T.serial`` loops.

    Args:
        length: Number of elements in ``A`` and ``B``; also the thread count.
        block: Inner loop extent. It should divide ``length``; otherwise the
            tail elements are never written.
        dtype: Element dtype of both tensors.

    Returns:
        The ``T.prim_func``. Through ``tilelang.jit`` callers receive a compiled
        kernel that returns argument index 1 (``B``).
    """

    @T.prim_func
    def main(
        A: T.Tensor((length,), dtype),
        B: T.Tensor((length,), dtype),
    ):
        with T.Kernel(1, threads=length) as _:
            for i in T.serial(length // block):
                for j in T.serial(block):
                    B[i * block + j] = A[i * block + j] + 1.0

    return main


@tilelang.jit(out_idx=[1])
def nested_noncontinuous_serials(length=256, block=16, dtype="float32"):
    """Build a non-continuous ``T.serial`` nest, which the checker should accept.

    Unlike the ``T.Parallel`` case, a statement (``B[i] = 0``) between nested
    serial loops is allowed. The test only checks that building succeeds.

    Args:
        length: Number of elements in ``A`` and ``B``; also the thread count.
        block: Inner loop extent.
        dtype: Element dtype of both tensors.

    Returns:
        The ``T.prim_func``. Through ``tilelang.jit`` callers receive a compiled
        kernel that returns argument index 1 (``B``).
    """

    @T.prim_func
    def main(
        A: T.Tensor((length,), dtype),
        B: T.Tensor((length,), dtype),
    ):
        with T.Kernel(1, threads=length) as _:
            for i in T.serial(length // block):
                B[i] = 0
                for j in T.serial(block):
                    B[i * block + j] = A[i * block + j] + 1.0

    return main


def test_nested_serials():
    """Nested ``T.serial`` loops compile whether or not the nest is continuous."""
    # Acquire the CUDA input first so the test skips *before* any kernel is
    # compiled on machines without a usable CUDA device.
    data = _require_cuda_tensor((256,), torch.float32)

    kernel1 = nested_continuous_serials(length=256, block=16)
    result1 = kernel1(data)
    torch.testing.assert_close(result1, data + 1.0, atol=1e-5, rtol=1e-5)

    # This is valid: serial loops may have statements between them.
    nested_noncontinuous_serials(length=256, block=16)


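"""
Mixed serial and Parallel loops:

(S-P)
T.serial
    T.Parallel

(P-S)
T.Parallel
    T.serial

Rule:
    - A T.Parallel must not be nested inside another T.Parallel with a
      non-parallel loop in between (Parallel - * - Parallel); e.g. P-S-P is
      rejected, while S-P, P-S and S-P-S are allowed.
"""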


@tilelang.jit(out_idx=[1])
def nested_continuous_sp(length=256, block=16, dtype="float32"):
    """Build ``B = A + 1`` with a ``T.Parallel`` loop nested in a ``T.serial`` loop (S-P).

    Args:
        length: Number of elements in ``A`` and ``B``; also the thread count.
        block: Extent of the inner parallel loop.
        dtype: Element dtype of both tensors.

    Returns:
        The ``T.prim_func``. Through ``tilelang.jit`` callers receive a compiled
        kernel that returns argument index 1 (``B``).
    """

    @T.prim_func
    def main(
        A: T.Tensor((length,), dtype),
        B: T.Tensor((length,), dtype),
    ):
        with T.Kernel(1, threads=length) as _:
            for i in T.serial(length // block):
                for j in T.Parallel(block):
                    B[i * block + j] = A[i * block + j] + 1.0

    return main


@tilelang.jit(out_idx=[1])
def nested_continuous_ps(length=256, block=16, dtype="float32"):
    """Build ``B = A + 1`` with a ``T.serial`` loop nested in a ``T.Parallel`` loop (P-S).

    Args:
        length: Number of elements in ``A`` and ``B``; also the thread count.
        block: Extent of the inner serial loop.
        dtype: Element dtype of both tensors.

    Returns:
        The ``T.prim_func``. Through ``tilelang.jit`` callers receive a compiled
        kernel that returns argument index 1 (``B``).
    """

    @T.prim_func
    def main(
        A: T.Tensor((length,), dtype),
        B: T.Tensor((length,), dtype),
    ):
        with T.Kernel(1, threads=length) as _:
            for i in T.Parallel(length // block):
                for j in T.serial(block):
                    B[i * block + j] = A[i * block + j] + 1.0

    return main


@tilelang.jit(out_idx=[1])
def nested_continuous_psp(length=256, block1=8, block2=2, dtype="float32"):
    """Build a Parallel-serial-Parallel (P-S-P) nest that the checker must reject.

    A ``T.Parallel`` nested under another ``T.Parallel`` with a serial loop in
    between is not allowed. The test expects building this kernel to raise
    ``ValueError``.

    Args:
        length: Number of elements in ``A`` and ``B``; also the thread count.
        block1: Extent of the middle serial loop.
        block2: Extent of the innermost parallel loop.
        dtype: Element dtype of both tensors.

    Returns:
        The ``T.prim_func`` (if the checker does not reject it first).
    """

    @T.prim_func
    def main(
        A: T.Tensor((length,), dtype),
        B: T.Tensor((length,), dtype),
    ):
        with T.Kernel(1, threads=length) as _:
            for i in T.Parallel(length // block1 // block2):
                for j in T.serial(block1):
                    for k in T.Parallel(block2):
                        B[i * block1 * block2 + j * block2 + k] = A[i * block1 * block2 + j * block2 + k] + 1.0

    return main


@tilelang.jit(out_idx=[1])
def nested_continuous_sps(length=256, block1=8, block2=2, dtype="float32"):
    """Build ``B = A + 1`` with a serial-Parallel-serial (S-P-S) nest, which is allowed.

    Args:
        length: Number of elements in ``A`` and ``B``; also the thread count.
        block1: Extent of the middle parallel loop.
        block2: Extent of the innermost serial loop.
        dtype: Element dtype of both tensors.

    Returns:
        The ``T.prim_func``. Through ``tilelang.jit`` callers receive a compiled
        kernel that returns argument index 1 (``B``).
    """

    @T.prim_func
    def main(
        A: T.Tensor((length,), dtype),
        B: T.Tensor((length,), dtype),
    ):
        with T.Kernel(1, threads=length) as _:
            for i in T.serial(length // block1 // block2):
                for j in T.Parallel(block1):
                    for k in T.serial(block2):
                        B[i * block1 * block2 + j * block2 + k] = A[i * block1 * block2 + j * block2 + k] + 1.0

    return main


def test_mixed_sp():
    """S-P, P-S and S-P-S nests compile and run; a P-S-P nest is rejected."""
    # Acquire the CUDA input first so the test skips *before* any kernel is
    # compiled on machines without a usable CUDA device.
    data = _require_cuda_tensor((256,), torch.float32)

    kernel1 = nested_continuous_sp(length=256, block=16)
    kernel2 = nested_continuous_ps(length=256, block=16)
    result1 = kernel1(data)
    result2 = kernel2(data)
    torch.testing.assert_close(result1, data + 1.0, atol=1e-5, rtol=1e-5)
    torch.testing.assert_close(result2, data + 1.0, atol=1e-5, rtol=1e-5)

    # This should be invalid (undefined behaviour): Parallel - serial - Parallel.
    with pytest.raises(ValueError):
        nested_continuous_psp(length=256, block1=16, block2=8)

    kernel3 = nested_continuous_sps(length=256, block1=8, block2=2)
    result3 = kernel3(data)
    torch.testing.assert_close(result3, data + 1.0, atol=1e-5, rtol=1e-5)


"""
Mixed Pipelined and Parallel loops:

(Pi-Pa)
T.Pipelined
    T.Parallel

(Pa-Pi)
T.Parallel
    T.Pipelined

Rule:
    - Pi-Pa is OK, whereas Pa-Pi is not allowed.
    - For deeper nests, the T.Parallel rules above apply.
"""


def matmul_nested_pipa(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    in_dtype,
    out_dtype,
    accum_dtype,
    threads,
    order,
    stage,
):
    """Build a GEMM ``C = A @ B`` with ``T.Parallel`` copies inside a ``T.Pipelined`` loop (Pi-Pa).

    Shared tiles are filled with explicit ``T.Parallel`` loops rather than
    ``T.copy``, so the K-loop pipeline directly contains parallel loops. This
    nesting is expected to be accepted.

    Args:
        M, N, K: GEMM problem sizes (``A`` is ``(M, K)``, ``B`` is ``(K, N)``).
        block_M, block_N, block_K: Tile sizes along M, N and K.
        in_dtype: Dtype of ``A``, ``B`` and the shared tiles.
        out_dtype: Dtype of ``C``.
        accum_dtype: Dtype of the accumulator fragment.
        threads: Threads per block.
        order, stage: Schedule passed to ``T.Pipelined``.

    Returns:
        The uncompiled ``T.prim_func`` taking ``(A, B, C)``.
    """
    A_shape = (M, K)
    B_shape = (K, N)
    A_shared_shape = (block_M, block_K)
    B_shared_shape = (block_K, block_N)

    @T.prim_func
    def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), order=order, stage=stage):
                # Element-wise parallel copies (instead of T.copy) so that the
                # pipeline body contains T.Parallel loops: the Pi-Pa case.
                for i, j in T.Parallel(block_M, block_K):
                    A_shared[i, j] = A[by * block_M + i, k * block_K + j]
                for i, j in T.Parallel(block_K, block_N):
                    B_shared[i, j] = B[k * block_K + i, bx * block_N + j]

                T.gemm(A_shared, B_shared, C_local, False, False)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def matmul_nested_papipa(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    in_dtype,
    out_dtype,
    accum_dtype,
    threads,
    order,
    stage,
):
    """Build a GEMM whose ``T.Pipelined`` loop sits inside a ``T.Parallel`` loop (Pa-Pi-Pa).

    Apart from the ``T.Parallel(1)`` wrapper, the body matches
    ``matmul_nested_pipa``. The wrapper puts the pipeline under a parallel loop,
    which is not allowed. The test expects compiling this program to raise
    ``ValueError``.

    Args:
        M, N, K: GEMM problem sizes (``A`` is ``(M, K)``, ``B`` is ``(K, N)``).
        block_M, block_N, block_K: Tile sizes along M, N and K.
        in_dtype: Dtype of ``A``, ``B`` and the shared tiles.
        out_dtype: Dtype of ``C``.
        accum_dtype: Dtype of the accumulator fragment.
        threads: Threads per block.
        order, stage: Schedule passed to ``T.Pipelined``.

    Returns:
        The uncompiled ``T.prim_func`` taking ``(A, B, C)``.
    """
    A_shape = (M, K)
    B_shape = (K, N)
    A_shared_shape = (block_M, block_K)
    B_shared_shape = (block_K, block_N)

    @T.prim_func
    def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            # Trivial parallel wrapper whose only purpose is to create Pa-Pi nesting.
            for _ in T.Parallel(1):
                for k in T.Pipelined(T.ceildiv(K, block_K), order=order, stage=stage):
                    for i, j in T.Parallel(block_M, block_K):
                        A_shared[i, j] = A[by * block_M + i, k * block_K + j]
                    for i, j in T.Parallel(block_K, block_N):
                        B_shared[i, j] = B[k * block_K + i, bx * block_N + j]

                    T.gemm(A_shared, B_shared, C_local, False, False)
                T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def run_gemm_mixed_pp(
    order,
    stage,
):
    """Check the Pipelined/Parallel nesting rules on a 1024^3 fp16 GEMM.

    The Pi-Pa program (``matmul_nested_pipa``) must compile and match
    ``torch.matmul``. The Pa-Pi-Pa program (``matmul_nested_papipa``) must be
    rejected with ``ValueError`` at compile time.

    Args:
        order, stage: Schedule for the ``T.Pipelined`` loop in both programs.
    """
    M = N = K = 1024
    block_M = 128
    block_N = 128
    block_K = 32
    in_dtype = "float16"
    out_dtype = "float16"
    dtypeAccum = "float32"
    num_threads = 128
    builder_args = (M, N, K, block_M, block_N, block_K, in_dtype, out_dtype, dtypeAccum, num_threads, order, stage)

    def compile_program(prog):
        # Build a fresh pass-config dict per compile; TMA lowering and warp
        # specialization are disabled for both programs.
        return tilelang.compile(
            prog,
            out_idx=[2],
            pass_configs={
                tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
                tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
            },
        )

    kernel = compile_program(matmul_nested_pipa(*builder_args))
    profiler = kernel.get_profiler()

    def ref_program(A, B):
        """Reference GEMM computed in fp32 and cast to ``out_dtype``."""
        if in_dtype == "float32":
            # The tf32 MMA does not truncate float32 inputs automatically.
            # Subtracting 0x1000 from the raw int32 bit pattern adjusts the low
            # mantissa bits, presumably to mimic tf32 rounding (TODO: confirm).
            A = (A.view(torch.int32) - 0x1000).view(torch.float32)
            B = (B.view(torch.int32) - 0x1000).view(torch.float32)
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        return C.to(getattr(torch, out_dtype))

    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)

    # Pa-Pi nesting must be rejected when the program is compiled.
    program1 = matmul_nested_papipa(*builder_args)
    with pytest.raises(ValueError):
        compile_program(program1)


def test_mixed_pp():
    """Pi-Pa GEMM compiles and matches torch; Pa-Pi-Pa GEMM is rejected."""
    schedule_order = [0, 1, 2]
    schedule_stage = [0, 0, 1]
    run_gemm_mixed_pp(order=schedule_order, stage=schedule_stage)
"""
Tiled operators (e.g. T.gemm) inside a T.Parallel loop are also not permitted.
"""


def matmul_with_parallel(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    in_dtype,
    out_dtype,
    accum_dtype,
    threads,
    order,
    stage,
):
    """Build a GEMM that calls the tiled ``T.gemm`` inside a ``T.Parallel`` loop.

    Apart from the ``T.Parallel(1)`` wrapper around ``T.gemm``, the body matches
    ``matmul_nested_pipa``. A tiled operator inside a parallel loop is not
    permitted. The test expects compiling this program to raise ``ValueError``.

    Args:
        M, N, K: GEMM problem sizes (``A`` is ``(M, K)``, ``B`` is ``(K, N)``).
        block_M, block_N, block_K: Tile sizes along M, N and K.
        in_dtype: Dtype of ``A``, ``B`` and the shared tiles.
        out_dtype: Dtype of ``C``.
        accum_dtype: Dtype of the accumulator fragment.
        threads: Threads per block.
        order, stage: Schedule passed to ``T.Pipelined``.

    Returns:
        The uncompiled ``T.prim_func`` taking ``(A, B, C)``.
    """
    A_shape = (M, K)
    B_shape = (K, N)
    A_shared_shape = (block_M, block_K)
    B_shared_shape = (block_K, block_N)

    @T.prim_func
    def main(
        A: T.Tensor(A_shape, in_dtype),
        B: T.Tensor(B_shape, in_dtype),
        C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), order=order, stage=stage):
                for i, j in T.Parallel(block_M, block_K):
                    A_shared[i, j] = A[by * block_M + i, k * block_K + j]
                for i, j in T.Parallel(block_K, block_N):
                    B_shared[i, j] = B[k * block_K + i, bx * block_N + j]

                # The invalid construct under test: a tiled op inside T.Parallel.
                for _ in T.Parallel(1):
                    T.gemm(A_shared, B_shared, C_local, False, False)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def run_gemm_tiled_op_with_parallel(
    order,
    stage,
):
    """Check that a tiled op inside ``T.Parallel`` is rejected on a 1024^3 fp16 GEMM.

    As a baseline, the valid ``matmul_nested_pipa`` program must compile and
    match ``torch.matmul``. ``matmul_with_parallel`` (``T.gemm`` wrapped in
    ``T.Parallel``) must then be rejected with ``ValueError`` at compile time.

    Args:
        order, stage: Schedule for the ``T.Pipelined`` loop in both programs.
    """
    M = N = K = 1024
    block_M = 128
    block_N = 128
    block_K = 32
    in_dtype = "float16"
    out_dtype = "float16"
    dtypeAccum = "float32"
    num_threads = 128
    builder_args = (M, N, K, block_M, block_N, block_K, in_dtype, out_dtype, dtypeAccum, num_threads, order, stage)

    def compile_program(prog):
        # Build a fresh pass-config dict per compile; TMA lowering and warp
        # specialization are disabled for both programs.
        return tilelang.compile(
            prog,
            out_idx=[2],
            pass_configs={
                tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
                tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
            },
        )

    kernel = compile_program(matmul_nested_pipa(*builder_args))
    profiler = kernel.get_profiler()

    def ref_program(A, B):
        """Reference GEMM computed in fp32 and cast to ``out_dtype``."""
        if in_dtype == "float32":
            # The tf32 MMA does not truncate float32 inputs automatically.
            # Subtracting 0x1000 from the raw int32 bit pattern adjusts the low
            # mantissa bits, presumably to mimic tf32 rounding (TODO: confirm).
            A = (A.view(torch.int32) - 0x1000).view(torch.float32)
            B = (B.view(torch.int32) - 0x1000).view(torch.float32)
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        return C.to(getattr(torch, out_dtype))

    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)

    # T.gemm inside T.Parallel must be rejected when the program is compiled.
    program1 = matmul_with_parallel(*builder_args)
    with pytest.raises(ValueError):
        compile_program(program1)


@tilelang.jit(out_idx=[1])
def tir_op_with_parallel(length=256, block=16, dtype="float32"):
    """Build ``B = max(A, 0)`` using a scalar TIR op (``T.max``) inside nested ``T.Parallel`` loops.

    Unlike tiled operators, plain TIR expressions are allowed inside parallel loops.

    Args:
        length: Number of elements in ``A`` and ``B``; also the thread count.
        block: Inner loop extent.
        dtype: Element dtype of both tensors.

    Returns:
        The ``T.prim_func``. Through ``tilelang.jit`` callers receive a compiled
        kernel that returns argument index 1 (``B``).
    """

    @T.prim_func
    def main(
        A: T.Tensor((length,), dtype),
        B: T.Tensor((length,), dtype),
    ):
        with T.Kernel(1, threads=length) as _:
            for i in T.Parallel(length // block):
                for j in T.Parallel(block):
                    B[i * block + j] = T.max(A[i * block + j], 0.0)

    return main


@tilelang.jit(out_idx=[1])
def customize_op_with_parallel(length=256, block=16, dtype="float32"):
    """Build ``B = A + 1`` using ``T.atomic_add`` inside nested ``T.Parallel`` loops.

    Each element is first copied from ``A`` and then incremented with
    ``T.atomic_add``. The test checks that this op is permitted in a parallel loop.

    Args:
        length: Number of elements in ``A`` and ``B``; also the thread count.
        block: Inner loop extent.
        dtype: Element dtype of both tensors.

    Returns:
        The ``T.prim_func``. Through ``tilelang.jit`` callers receive a compiled
        kernel that returns argument index 1 (``B``).
    """

    @T.prim_func
    def main(
        A: T.Tensor((length,), dtype),
        B: T.Tensor((length,), dtype),
    ):
        with T.Kernel(1, threads=length) as _:
            for i in T.Parallel(length // block):
                for j in T.Parallel(block):
                    B[i * block + j] = A[i * block + j]
                    T.atomic_add(B[i * block + j], 1.0)

    return main


def test_tiled_op_with_parallel():
    """Tiled ops in ``T.Parallel`` are rejected; TIR and atomic ops in ``T.Parallel`` work."""
    # Acquire the CUDA input first so the test skips *before* any kernel is
    # compiled on machines without a usable CUDA device.
    data = _require_cuda_tensor((256,), torch.float32)

    run_gemm_tiled_op_with_parallel(order=[0, 1, 2], stage=[0, 0, 1])

    kernel1 = tir_op_with_parallel(length=256, block=16)
    result1 = kernel1(data)
    torch.testing.assert_close(result1, torch.relu(data), atol=1e-5, rtol=1e-5)

    kernel2 = customize_op_with_parallel(length=256, block=16)
    result2 = kernel2(data)
    torch.testing.assert_close(result2, data + 1.0, atol=1e-5, rtol=1e-5)


if __name__ == "__main__":
    # Running this file directly hands control to tilelang's test entry point.
    tilelang.testing.main()