# test_tilelang_language_let.py
# Provenance note carried over from the page this file was copied from:
# "test/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on
# "53ae92cc2b16679e120d4e164c66a4c9f40761d0".
import tilelang.testing
from tilelang import tvm as tvm
from tilelang import language as T


def test_let_vectorize_load():
    """Check that a let-bound 4-wide slice load is emitted as a ``float4``.

    Inside the kernel, ``b`` is bound to the 4-element ``float32`` slice
    ``A[0, 0:4]`` and then stored into ``A[0, 4:8]``. After compiling for the
    ``cuda`` target, the generated device source must declare ``b`` as a
    ``float4``, i.e. the four elements are held in one vector-typed local.

    The kernel is only compiled and its source inspected; it is never run.
    Compiling for ``cuda`` needs a CUDA toolchain to be available.
    """

    @T.prim_func
    def main(A_ptr: T.handle):
        # align=16 presumably lets a 16-byte (4 x float32) vector access be
        # emitted for the slices below -- NOTE(review): confirm.
        A = T.match_buffer(A_ptr, (16, 16), dtype="float32", align=16)

        # One block of 128 threads; the loop variables are unused, so every
        # thread performs the same copy.
        for _blockIdx in T.thread_binding(1, thread="blockIdx.x"):
            for _threadIdx in T.thread_binding(128, thread="threadIdx.x"):
                b = A[0, 0:4]
                A[0, 4:8] = b

    mod = tvm.IRModule({"main": main})
    mod = tvm.compile(mod, target="cuda")
    # Assumes the first imported module holds the generated CUDA device
    # source -- TODO confirm against the tvm.compile return layout.
    assert "float4 b" in mod.mod.imports[0].inspect_source()


if __name__ == "__main__":
    # Run this module's tests through tilelang's testing entry point.
    tilelang.testing.main()