# test_tilelang_language_var_init.py
import tilelang
import tilelang.language as T
import tilelang.testing


def test_var_assign() -> None:
    """Check that ``T.alloc_var(init=<var>)`` uses the source's current value.

    The kernel initialises ``b`` from ``a`` while ``a`` is 1, reassigns ``a``
    to 2, then initialises ``d`` from ``a``. The two results are written to
    ``A[0]`` and ``A[1]`` and the test asserts they are 1 and 2 respectively,
    i.e. ``b`` keeps the value ``a`` had when ``b`` was allocated.
    """

    # NOTE(review): out_idx=-1 presumably marks the last parameter (A) as the
    # returned output, since kernel() is called with no arguments below and its
    # result is indexed -- confirm against the tilelang.jit documentation.
    @tilelang.jit(out_idx=-1)
    def jit_kernel():
        @T.prim_func
        def test_var_assign(A: T.Tensor((2,), "int32")):
            with T.Kernel(1) as _:
                a = T.alloc_var("int32", init=1)
                b = T.alloc_var("int32", init=a)  # b gets value of a
                a = 2
                d = T.alloc_var("int32", init=a)  # d gets new value of a
                A[0] = b
                A[1] = d

        # Print the prim_func so it is visible in the test output.
        print(test_var_assign)
        return test_var_assign

    kernel = jit_kernel()
    # Print the generated kernel source to make failures easier to diagnose.
    print(kernel.get_kernel_source())
    res = kernel()
    assert res[0] == 1  # b was initialised from a before the reassignment
    assert res[1] == 2  # d was initialised from a after the reassignment


# Allow running this test file directly as a script.
if __name__ == "__main__":
    tilelang.testing.main()