"testing/vscode:/vscode.git/clone" did not exist on "8e1845d279172db9ee07f18bb0cccfd1862c4d78"
# test_tilelang_language_unroll.py
import tilelang.testing
from tilelang import tvm as tvm
from tilelang import language as T


def test_unroll_with_step():
    """Compile a kernel using ``T.unroll`` with ``step=4`` and check its source.

    The kernel binds one ``blockIdx.x`` and 128 ``threadIdx.x`` threads and
    writes ``1.0`` into row 0 of a (16, 16) float32 buffer from inside a
    ``T.unroll(0, 16, step=4)`` loop. The test passes when the source returned
    by ``kernel.get_kernel_source()`` contains ``"#pragma unroll"``.
    """

    @T.prim_func
    def main(A_ptr: T.handle):
        A = T.match_buffer(A_ptr, (16, 16), dtype=T.float32, align=16)

        for _blockIdx in T.thread_binding(1, thread="blockIdx.x"):
            for _threadIdx in T.thread_binding(128, thread="threadIdx.x"):
                for i in T.unroll(0, 16, step=4):
                    A[0, i] = 1.0

    # Only the generated source text is inspected; the kernel is never launched.
    kernel = tilelang.compile(main, target="cuda")
    assert "#pragma unroll" in kernel.get_kernel_source()


def test_unroll_with_unroll_factor():
    """Compile a kernel using ``T.unroll`` with ``unroll_factor=4`` and check its source.

    The kernel binds one ``blockIdx.x`` and 128 ``threadIdx.x`` threads and
    writes ``1.0`` into row 0 of a (16, 16) float32 buffer from inside a
    ``T.unroll(0, 16, unroll_factor=4)`` loop. The test passes when the source
    returned by ``kernel.get_kernel_source()`` contains ``"#pragma unroll 4"``,
    i.e. the requested factor appears in the emitted pragma.
    """

    @T.prim_func
    def main(A_ptr: T.handle):
        A = T.match_buffer(A_ptr, (16, 16), dtype=T.float32, align=16)

        for _blockIdx in T.thread_binding(1, thread="blockIdx.x"):
            for _threadIdx in T.thread_binding(128, thread="threadIdx.x"):
                for i in T.unroll(0, 16, unroll_factor=4):
                    A[0, i] = 1.0

    # Only the generated source text is inspected; the kernel is never launched.
    kernel = tilelang.compile(main, target="cuda")
    assert "#pragma unroll 4" in kernel.get_kernel_source()


# When executed directly as a script, hand control to tilelang's test entry point.
if __name__ == "__main__":
    tilelang.testing.main()