# test_example_warp_specialize.py
import tilelang.testing

import example_warp_specialize_flashmla
import example_warp_specialize_gemm_barrierpipe_stage2
import example_warp_specialize_gemm_copy_0_gemm_1
import example_warp_specialize_gemm_copy_1_gemm_0
import example_warp_specialize_gemm_softpipe_stage2


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_eq(9, 0)
def test_example_warp_specialize_flashmla():
    """Run the warp-specialized FlashMLA example through its ``main()`` entry point.

    Gated by the tilelang.testing CUDA decorators (presumably skipped unless a
    CUDA device with compute capability 9.0 is available -- confirm against
    tilelang.testing).
    """
    entry_point = example_warp_specialize_flashmla.main
    entry_point()


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_eq(9, 0)
def test_example_warp_specialize_gemm_barrierpipe_stage2():
    """Run the barrier-pipelined, 2-stage warp-specialized GEMM example.

    Calls the example's ``main`` with M=N=K=1024 (presumably the GEMM problem
    dimensions -- confirm in the example module). Gated by the tilelang.testing
    CUDA decorators (presumably skipped unless a compute-capability-9.0 device
    is present).
    """
    example_warp_specialize_gemm_barrierpipe_stage2.main(M=1024, N=1024, K=1024)


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_eq(9, 0)
def test_example_warp_specialize_gemm_copy_0_gemm_1():
    """Run the ``copy_0_gemm_1`` warp-specialized GEMM example.

    Calls the example's ``main`` with M=N=K=1024 (presumably the GEMM problem
    dimensions -- confirm in the example module). Gated by the tilelang.testing
    CUDA decorators (presumably skipped unless a compute-capability-9.0 device
    is present).
    """
    example_warp_specialize_gemm_copy_0_gemm_1.main(M=1024, N=1024, K=1024)


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_eq(9, 0)
def test_example_warp_specialize_gemm_copy_1_gemm_0():
    """Run the ``copy_1_gemm_0`` warp-specialized GEMM example.

    Calls the example's ``main`` with M=N=K=1024 (presumably the GEMM problem
    dimensions -- confirm in the example module). Gated by the tilelang.testing
    CUDA decorators (presumably skipped unless a compute-capability-9.0 device
    is present).
    """
    example_warp_specialize_gemm_copy_1_gemm_0.main(M=1024, N=1024, K=1024)


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_eq(9, 0)
def test_example_warp_specialize_gemm_softpipe_stage2():
    """Run the software-pipelined, 2-stage warp-specialized GEMM example.

    Calls the example's ``main`` with M=N=K=1024 (presumably the GEMM problem
    dimensions -- confirm in the example module). Gated by the tilelang.testing
    CUDA decorators (presumably skipped unless a compute-capability-9.0 device
    is present).
    """
    example_warp_specialize_gemm_softpipe_stage2.main(M=1024, N=1024, K=1024)


if __name__ == "__main__":
    # Running this file directly hands control to tilelang.testing.main().
    run_tests = tilelang.testing.main
    run_tests()