# Tests that run the warp-specialize example scripts via their main() entry points.
import tilelang.testing

import example_warp_specialize_gemm_barrierpipe_stage2
import example_warp_specialize_gemm_copy_0_gemm_1
import example_warp_specialize_gemm_copy_1_gemm_0
import example_warp_specialize_gemm_softpipe_stage2

# TODO: the flashmla example test is skipped for now because it is
# non-deterministic on H20. CC @cunxiao
# @tilelang.testing.requires_cuda
# @tilelang.testing.requires_cuda_compute_version_eq(9, 0)
# def test_example_warp_specialize_flashmla():
#     import example_warp_specialize_flashmla
#     example_warp_specialize_flashmla.main()


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_eq(9, 0)
def test_example_warp_specialize_gemm_barrierpipe_stage2():
    """Run the barrierpipe_stage2 warp-specialize GEMM example with M = N = K = 1024.

    Gated by the decorators to CUDA devices with compute version 9.0.
    The test passes if the example's main() completes without raising.
    """
    example_warp_specialize_gemm_barrierpipe_stage2.main(M=1024, N=1024, K=1024)


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_eq(9, 0)
def test_example_warp_specialize_gemm_copy_0_gemm_1():
    """Run the copy_0_gemm_1 warp-specialize GEMM example with M = N = K = 1024.

    Gated by the decorators to CUDA devices with compute version 9.0.
    The test passes if the example's main() completes without raising.
    """
    example_warp_specialize_gemm_copy_0_gemm_1.main(M=1024, N=1024, K=1024)


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_eq(9, 0)
def test_example_warp_specialize_gemm_copy_1_gemm_0():
    """Run the copy_1_gemm_0 warp-specialize GEMM example with M = N = K = 1024.

    Gated by the decorators to CUDA devices with compute version 9.0.
    The test passes if the example's main() completes without raising.
    """
    example_warp_specialize_gemm_copy_1_gemm_0.main(M=1024, N=1024, K=1024)


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_eq(9, 0)
def test_example_warp_specialize_gemm_softpipe_stage2():
    """Run the softpipe_stage2 warp-specialize GEMM example with M = N = K = 1024.

    Gated by the decorators to CUDA devices with compute version 9.0.
    The test passes if the example's main() completes without raising.
    """
    example_warp_specialize_gemm_softpipe_stage2.main(M=1024, N=1024, K=1024)


if "__main__" == __name__:
    # Executed directly (not imported): hand control to tilelang's test runner.
    tilelang.testing.main()