"...resnet50_tensorflow.git" did not exist on "1d8b2263f00bde1730b52f2e78c95cb14e9899b9"
test_tilelang_issue_1210.py 1.02 KB
Newer Older
1
2
3
4
5
6
import tilelang
import tilelang.language as T
import tilelang.testing


def _make_kernel(M, N):
7
    dtype = T.bfloat16
8
9

    @T.prim_func
10
    def fwd_main(KV: T.Tensor((M, N), dtype), ids: T.Tensor((4,), T.int32)):
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
        with T.Kernel(4, threads=1):
            A = T.alloc_shared([N], dtype)
            B = T.alloc_shared([N], dtype)

            # Regression for a bug where InjectSoftwarePipeline left the loop
            # variable as a free var, causing MakePackedAPI to fail
            for i in T.Pipelined(4, num_stages=1):
                _id = ids[i]
                T.copy(KV[_id, :], A)
                T.clear(B)

    return fwd_main


def test_make_packed_api_no_free_loop_var():
    func = _make_kernel(4, 4)
    # Keep warp-specialization/TMA disabled to match the original repro
    cfg = {
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    }
    tilelang.compile(func, pass_configs=cfg)


if __name__ == "__main__":
    tilelang.testing.main()