"src/vscode:/vscode.git/clone" did not exist on "8fd3cf0012965e53c4bc322d9bcb28fc91ddfbdd"
test_tilelang_issue_1210.py 1.02 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import tilelang
import tilelang.language as T
import tilelang.testing


def _make_kernel(M, N):
    """Build the kernel that reproduces the pipelined-loop free-variable bug.

    Args:
        M: First dimension of the ``KV`` input tensor.
        N: Second dimension of ``KV``; also the length of the two buffers
            allocated with ``T.alloc_shared`` inside the kernel.

    Returns:
        The ``T.prim_func``-decorated function ``fwd_main``, which takes a
        ``(M, N)`` bfloat16 tensor ``KV`` and a length-4 int32 tensor ``ids``.
    """
    dtype = "bfloat16"

    @T.prim_func
    def fwd_main(KV: T.Tensor((M, N), dtype), ids: T.Tensor((4,), "int32")):
        with T.Kernel(4, threads=1):
            A = T.alloc_shared([N], dtype)
            B = T.alloc_shared([N], dtype)

            # Regression for a bug where InjectSoftwarePipeline left the loop
            # variable as a free var, causing MakePackedAPI to fail.
            for i in T.Pipelined(4, num_stages=1):
                # The row index is read from `ids` inside the loop, so the
                # source of the copy below depends on the loop variable `i`.
                _id = ids[i]
                T.copy(KV[_id, :], A)
                # B is cleared on every iteration and never read afterwards.
                # NOTE(review): presumably kept only to match the original
                # repro -- confirm before removing.
                T.clear(B)

    return fwd_main


def test_make_packed_api_no_free_loop_var():
    """Compile the pipelined kernel and expect compilation not to raise.

    The test has no explicit assertions: it passes when ``tilelang.compile``
    returns normally for the 4x4 kernel built by ``_make_kernel``.
    """
    kernel = _make_kernel(4, 4)
    # Warp specialization and TMA lowering both stay disabled so the
    # configuration matches the original repro.
    pass_configs = dict.fromkeys(
        (
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED,
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER,
        ),
        True,
    )
    tilelang.compile(kernel, pass_configs=pass_configs)


# Allow running this file directly as a script; test execution is delegated
# to the tilelang.testing entry point.
if __name__ == "__main__":
    tilelang.testing.main()