# test_tilelang_language_assume.py
# Tests checking that kernels using T.assume produce generated source without
# `if (` guards and, where expected, with `float4` accesses.
import tilelang
import tilelang.language as T
import tilelang.testing


def test_assume_remove_boundary_check():
    """Check that a kernel with asserted index bounds is emitted without guards.

    The kernel writes 0 to ``A[l + i]`` for ``i`` in ``[0, r - l + 1)`` and
    states through ``T.assume`` that every index ``l + i`` lies in ``[0, N)``.
    The test passes when the generated kernel source contains no ``if (``.
    """

    @tilelang.jit
    def kernel_with_assume():
        # N is a dynamic (symbolic) extent, so the bounds are not known statically.
        N = T.dynamic("N")

        @T.prim_func
        def main(A: T.Tensor((N,), "float32"), l: T.int32, r: T.int32):
            with T.Kernel(1, threads=32) as _:
                for i in T.serial(r - l + 1):
                    # State the bounds of the index used on the next line.
                    T.assume(l + i >= 0 and l + i < N)
                    A[l + i] = 0

        return main

    jit_kernel = kernel_with_assume()
    source = jit_kernel.get_kernel_source()

    assert "if (" not in source


def test_assume_enable_vectorization():
    """Check that asserted bounds yield vectorized, guard-free generated source.

    Each thread copies ``vectorize_size`` consecutive elements of row ``tid``
    from ``A`` to ``B``, starting at column ``tid * vectorize_size``. The
    kernel states through ``T.assume`` that ``N`` is a multiple of
    ``vectorize_size`` and that every accessed column is below ``N``. The test
    passes when the generated source contains ``float4`` and no ``if (``.
    """

    @tilelang.jit
    def kernel_vectorize(M):
        # N is a dynamic (symbolic) extent; M is fixed when the kernel is built.
        N = T.dynamic("N")
        vectorize_size = 4

        @T.prim_func
        def main(
            A: T.Tensor((M, N), "float32"),
            B: T.Tensor((M, N), "float32"),
        ):
            with T.Kernel(1, threads=32) as _:
                tid = T.get_thread_binding()

                # Use the shared constant so the start offset and the
                # vectorized loop extent cannot drift apart.
                base_idx = tid * vectorize_size
                T.assume(N % vectorize_size == 0)

                for i in T.vectorized(vectorize_size):
                    T.assume(base_idx + i < N)
                    B[tid, base_idx + i] = A[tid, base_idx + i]

        return main

    jit_kernel = kernel_vectorize(128)
    source = jit_kernel.get_kernel_source()

    # Separate assertions so a failure shows which expectation was violated.
    assert "float4" in source
    assert "if (" not in source


def test_assume_complex_indexing():
    """Check that asserted bounds on computed indices give guard-free source.

    The source indices ``i_src`` and ``j_src`` are built from ``T.min`` and
    ``T.ceildiv`` expressions of the loop variable and the thread binding.
    The kernel states through ``T.assume`` that ``i_src`` lies in ``[0, M)``
    and ``j_src`` in ``[0, N)``. The test passes when the generated kernel
    source contains no ``if (``.
    """

    @tilelang.jit
    def kernel_complex():
        # Both extents are dynamic (symbolic).
        M = T.dynamic("M")
        N = T.dynamic("N")

        @T.prim_func
        def main(
            A: T.Tensor((M, N), "float32"),
            B: T.Tensor((M, N), "float32"),
        ):
            with T.Kernel(1, threads=32) as _:
                tid = T.get_thread_binding()
                for j in T.serial(N):
                    i_src = T.min(j + 233, tid + 2)
                    j_src = j * T.ceildiv(j, i_src) * j - 1

                    # State the bounds of both indices used in the load below.
                    T.assume(i_src >= 0 and i_src < M)
                    T.assume(j_src >= 0 and j_src < N)

                    B[tid, j] = A[i_src, j_src]

        return main

    jit_kernel = kernel_complex()
    source = jit_kernel.get_kernel_source()

    assert "if (" not in source


# Delegate to tilelang's test entry point when this file is run as a script.
if __name__ == "__main__":
    tilelang.testing.main()