"""Tests for ``tl.transform.LegalizeNegativeIndex``.

Each case pairs an input ``prim_func`` with the body the pass is expected to
produce:

* a constant negative index is rewritten by adding the buffer extent
  (``A[-1]`` on a 16-element buffer becomes ``A[15]``);
* an affine index that is negative for every loop iteration is rewritten the
  same way (``A[-i - 1]`` becomes ``A[15 - i]``);
* an index whose sign depends on a runtime parameter (``A[shift + i]``) is
  expected to pass through unchanged.
"""

# NOTE: ``tvm`` is imported through ``tilelang`` (not directly) so that the
# tilelang-bundled TVM is the one in use; keep this import first.
from tilelang import tvm
import tilelang as tl
import tilelang.testing
from tvm.script import tir as T


@T.prim_func
def negative_index_before(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")):
    """Input: read ``A`` at the constant negative index -1."""
    T.func_attr({"tir.noalias": True})
    B[0] = A[T.int32(-1)]


@T.prim_func
def negative_index_expected(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")):
    """Expected: -1 wrapped to 15 (the last element of the 16-element ``A``)."""
    T.func_attr({"tir.noalias": True})
    B[0] = A[T.int32(15)]


@T.prim_func
def negative_index_loop_before(A: T.Buffer((16,), "float32"), B: T.Buffer((4,), "float32")):
    """Input: loop reading ``A`` at -1, -2, -3, -4 via the affine index ``-i - 1``."""
    T.func_attr({"tir.noalias": True})
    for i in T.serial(4):
        B[i] = A[-i - 1]


@T.prim_func
def negative_index_loop_expected(A: T.Buffer((16,), "float32"), B: T.Buffer((4,), "float32")):
    """Expected: the affine index shifted by the extent, i.e. ``15 - i``."""
    T.func_attr({"tir.noalias": True})
    for i in T.serial(4):
        B[i] = A[15 - i]


@T.prim_func
def negative_index_symbolic_before(
    shift: T.int32, A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")
):
    """Input (and expected output): index depends on the runtime ``shift``."""
    T.func_attr({"tir.noalias": True})
    for i in T.serial(16):
        B[i] = A[shift + i]


def _legalize(func):
    """Run ``LegalizeNegativeIndex`` on ``func`` as module ``main`` and return the result."""
    mod = tvm.IRModule({"main": func})
    return tl.transform.LegalizeNegativeIndex()(mod)["main"]


def test_legalize_negative_index_scalar():
    """A constant index of -1 is rewritten to 15."""
    transformed = _legalize(negative_index_before)
    tvm.ir.assert_structural_equal(transformed.body, negative_index_expected.body)


def test_legalize_negative_index_affine_expr():
    """An always-negative affine index ``-i - 1`` is rewritten to ``15 - i``."""
    transformed = _legalize(negative_index_loop_before)
    tvm.ir.assert_structural_equal(transformed.body, negative_index_loop_expected.body)


def test_legalize_negative_index_symbolic_passthrough():
    """An index whose sign depends on a runtime parameter is left unchanged."""
    transformed = _legalize(negative_index_symbolic_before)
    tvm.ir.assert_structural_equal(transformed.body, negative_index_symbolic_before.body)


if __name__ == "__main__":
    tilelang.testing.main()