# test_tilelang_language_let.py
# Checks that a let-bound vector slice (``b = A[0, 0:4]``) is emitted as a
# ``float4`` variable in the generated CUDA source.
import tilelang.testing
from tilelang import tvm as tvm
from tilelang import language as T


def test_let_vectorize_load():
    """Check that a let-bound 4-wide float32 slice is emitted as ``float4``.

    The kernel binds ``b = A[0, 0:4]`` (four contiguous float32 values) and
    stores it back into ``A[0, 4:8]``. After compiling for the CUDA target,
    the generated source must declare ``b`` as a single ``float4`` variable.
    The kernel is only compiled and its source inspected; it is never launched.
    """

    @T.prim_func
    def main(A_ptr: T.handle):
        # align=16 requests 16-byte alignment, the size of one float4
        # (4 x float32). NOTE(review): presumably needed so the vectorized
        # load is legal for codegen — confirm.
        A = T.match_buffer(A_ptr, (16, 16), dtype=T.float32, align=16)

        for _blockIdx in T.thread_binding(1, thread="blockIdx.x"):
            for _threadIdx in T.thread_binding(128, thread="threadIdx.x"):
                # Every thread performs the same copy; only the emitted code
                # matters for this test, not the runtime result.
                b = A[0, 0:4]
                A[0, 4:8] = b

    ir_module = tvm.IRModule({"main": main})
    executable = tvm.compile(ir_module, target="cuda")
    # NOTE(review): imports[0] is presumably the CUDA device module holding
    # the kernel source — confirm against the TVM version in use.
    cuda_source = executable.mod.imports[0].inspect_source()
    # Include the generated source in the failure message to aid debugging.
    assert "float4 b" in cuda_source, cuda_source


if __name__ == "__main__":
    # When this file is executed directly, hand control to tilelang's
    # test runner instead of requiring an explicit test-runner invocation.
    tilelang.testing.main()