# test_tilelang_language_negative_index.py 1.91 KB
# Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from tilelang import tvm
import tilelang as tl
import tilelang.testing
from tvm.script import tir as T


# Input fixture for the scalar case: loads A at the constant int32 index -1
# and stores the value into B[0]. Both buffers hold 16 float32 elements.
# Used as the pass input in test_legalize_negative_index_scalar.
@T.prim_func
def negative_index_before(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")):
    T.func_attr({"tir.noalias": True})
    # The index is spelled as an explicit int32 constant; keep this exact form,
    # since the tests compare IR structurally.
    B[0] = A[T.int32(-1)]


# Expected-output fixture for the scalar case: identical to
# negative_index_before except that the load index is the constant 15,
# i.e. -1 offset by the 16-element extent of A.
@T.prim_func
def negative_index_expected(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")):
    T.func_attr({"tir.noalias": True})
    B[0] = A[T.int32(15)]


# Input fixture for the affine case: a 4-iteration serial loop that loads
# A[-i - 1] into B[i]. For i in 0..3 the index takes the values -1..-4,
# so it is negative on every iteration. A has 16 elements, B has 4.
@T.prim_func
def negative_index_loop_before(A: T.Buffer((16,), "float32"), B: T.Buffer((4,), "float32")):
    T.func_attr({"tir.noalias": True})
    for i in T.serial(4):
        B[i] = A[-i - 1]


# Expected-output fixture for the affine case: the same loop as
# negative_index_loop_before with the load index written as 15 - i
# (values 15..12 for i in 0..3), which is the exact expression shape
# the affine test asserts against.
@T.prim_func
def negative_index_loop_expected(A: T.Buffer((16,), "float32"), B: T.Buffer((4,), "float32")):
    T.func_attr({"tir.noalias": True})
    for i in T.serial(4):
        B[i] = A[15 - i]


# Fixture for the symbolic case: the load index is shift + i, where shift is
# an int32 function parameter with no value known in this function.
# test_legalize_negative_index_symbolic_passthrough uses this function as both
# the pass input and the expected output (the body must come back unchanged).
@T.prim_func
def negative_index_symbolic_before(shift: T.int32, A: T.Buffer((16,), "float32"),
                                   B: T.Buffer((16,), "float32")):
    T.func_attr({"tir.noalias": True})
    for i in T.serial(16):
        B[i] = A[shift + i]


def test_legalize_negative_index_scalar():
    """Check the pass output for a constant ``-1`` index against the ``15`` fixture.

    Runs ``LegalizeNegativeIndex`` on a module wrapping
    ``negative_index_before`` and requires the resulting body to be
    structurally equal to the body of ``negative_index_expected``.
    """
    input_module = tvm.IRModule({"main": negative_index_before})
    legalize_pass = tl.transform.LegalizeNegativeIndex()
    legalized_body = legalize_pass(input_module)["main"].body
    tvm.ir.assert_structural_equal(legalized_body, negative_index_expected.body)


def test_legalize_negative_index_affine_expr():
    """Check the pass output for the loop index ``-i - 1`` against ``15 - i``.

    Runs ``LegalizeNegativeIndex`` on a module wrapping
    ``negative_index_loop_before`` and requires the resulting body to be
    structurally equal to the body of ``negative_index_loop_expected``.
    """
    input_module = tvm.IRModule({"main": negative_index_loop_before})
    legalize_pass = tl.transform.LegalizeNegativeIndex()
    legalized_func = legalize_pass(input_module)["main"]
    expected_body = negative_index_loop_expected.body
    tvm.ir.assert_structural_equal(legalized_func.body, expected_body)


def test_legalize_negative_index_symbolic_passthrough():
    """Check that an index built from a symbolic ``shift`` is left as-is.

    Runs ``LegalizeNegativeIndex`` on a module wrapping
    ``negative_index_symbolic_before`` and requires the resulting body to be
    structurally equal to the body of that same input function.
    """
    original_func = negative_index_symbolic_before
    input_module = tvm.IRModule({"main": original_func})
    legalize_pass = tl.transform.LegalizeNegativeIndex()
    output_module = legalize_pass(input_module)
    tvm.ir.assert_structural_equal(output_module["main"].body, original_func.body)


# Script entry point: when this file is executed directly (rather than
# imported), hand control to tilelang's test entry point.
if __name__ == "__main__":
    tilelang.testing.main()