test_arith_simplify.py 3.82 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

from tilelang import tvm
import tilelang.testing
from tvm import tir
import tvm.ir


def test_simplify_reshape_flattened_index():
    """Splitting a flat index with div/mod and summing the pieces simplifies back to it.

    ``flat = i0 * 3 + i1`` with ``i0 in [0, 8)`` and ``i1 in [0, 3)``. The sum of
    its ``// 12 * 12``, ``% 12 // 4 * 4`` and ``% 4`` parts must fold to ``flat``.
    """
    analyzer = tvm.arith.Analyzer()

    i0 = tir.Var("i0", "int64")
    i1 = tir.Var("i1", "int64")
    analyzer.bind(i0, tvm.ir.Range(0, 8))
    analyzer.bind(i1, tvm.ir.Range(0, 3))

    flat = i0 * 3 + i1
    recombined = flat // 12 * 12 + flat % 12 // 4 * 4 + flat % 4
    tvm.ir.assert_structural_equal(analyzer.simplify(recombined), flat)


# Parametrized fixture: tests that take a ``dtype`` argument run once per value.
# Grouped by family; the flattened order is unsigned ints, signed ints, floats.
dtype = tvm.testing.parameter(
    *("uint8", "uint16", "uint32", "uint64"),
    *("int8", "int16", "int32", "int64"),
    *("float16", "float32", "float64"),
)


def test_can_prove_self_identity(dtype):
    """``can_prove`` accepts ``n == n`` for a variable of the parametrized dtype."""
    analyzer = tvm.arith.Analyzer()
    var = tir.Var("n", dtype)
    self_equality = var == var
    assert analyzer.can_prove(self_equality)


def test_can_prove_self_equal_to_self(dtype):
    """``can_prove_equal`` reports a variable of the parametrized dtype equal to itself."""
    analyzer = tvm.arith.Analyzer()
    var = tir.Var("n", dtype)
    assert analyzer.can_prove_equal(var, var)


def test_simplify_symbolic_comparison():
    """Comparisons against a symbolic, 32-rounded extent are provable with SYMBOLIC_BOUND.

    ``i0`` ranges over ``[0, (n + 31) // 32)`` and ``i1`` over ``[0, 32)``. So
    ``i0 * 32 + i1`` stays below ``(n + 31) // 32 * 32``, and also below that bound
    plus the non-negative ``m``.
    """
    ana = tvm.arith.Analyzer()

    i0 = tir.Var("i0", "int64")
    i1 = tir.Var("i1", "int64")
    n = tvm.tir.SizeVar("n", "int64")
    m = tvm.tir.SizeVar("m", "int64")

    num_chunks = (n + 31) // 32
    ana.bind(i0, tvm.ir.Range(0, num_chunks))
    ana.bind(i1, tvm.ir.Range(0, 32))

    flat = i0 * 32 + i1
    padded = num_chunks * 32
    strength = tvm.arith.ProofStrength.SYMBOLIC_BOUND

    # Strict and non-strict forms, written from both sides of the comparison.
    assert ana.can_prove(flat < padded, strength)
    assert ana.can_prove(flat < padded + m, strength)
    assert ana.can_prove(flat + 1 <= padded, strength)
    assert ana.can_prove(padded >= flat + 1, strength)
    assert ana.can_prove(padded >= flat, strength)


def test_regression_simplify_inf_recursion():
    """Regression: ``rewrite_simplify`` must return on ``cast(a - b) == 0``.

    Here ``a`` and ``b`` are two separately built copies of ``(cond != 0)`` cast
    to int8. Previously, the recursive calls between comparison proving and
    int-set evaluation could loop forever on this input. The test passes as long
    as the call returns.
    """
    ana = tvm.arith.Analyzer()
    cond = tir.Var("cond", "int32")

    def nonzero_as_int8():
        # Build a fresh node on every call so the two operands below are
        # distinct (but structurally equal) expressions.
        return tvm.tir.NE(cond, 0).astype("int8")

    difference = (nonzero_as_int8() - nonzero_as_int8()).astype("int32")
    ana.rewrite_simplify(difference == 0)


def test_simplify_floor_mod_with_linear_offset():
    """
    Test that the floor_mod is simplified correctly when the offset is linear.

    ``(L + 1) * 64`` is a multiple of ``(L + 1) * 32``, so the remainder must be
    provably zero. This is checked with the constant factor written on either
    side of the divisor.
    """
    ana = tvm.arith.Analyzer()
    seq_len = tir.Var("past_decoder_sequence_length", "int64")
    dividend = (seq_len + 1) * 64

    for divisor in ((seq_len + 1) * 32, 32 * (seq_len + 1)):
        assert ana.can_prove_equal(tvm.tir.floormod(dividend, divisor), 0)


def test_simplify_float_division():
    """``rewrite_simplify`` leaves float division by a constant unchanged.

    Older versions rewrote ``x / 27`` into a multiplication by the constant
    ``1 / 27``. For the background, see
    https://discuss.tvm.apache.org/t/discuss-is-constant-division-to-multiplication-rewrite-in-tvm-necessary/18615
    """
    analyzer = tvm.arith.Analyzer()
    x = tir.Var("x", "float32")

    quotient = x / 27
    rewritten = analyzer.rewrite_simplify(quotient)
    tvm.ir.assert_structural_equal(quotient, rewritten)


if __name__ == "__main__":
    # Allow running this file directly with ``python``; delegates to tilelang.testing.main().
    tilelang.testing.main()