# Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. from tilelang import tvm import tvm.testing from tvm import te from tvm import tir from tvm.arith.analyzer import Analyzer class IntSetChecker: def __init__(self): self.analyzer = tvm.arith.Analyzer() def verify(self, data, dmap, expected): res = self.analyzer.int_set(data, dmap) def err_msg(): return "\ndata={}\ndmap={}\nres={}\nexpected={}".format(data, dmap, res, expected) assert self.analyzer.can_prove_equal(res.min_value, expected[0]), err_msg() assert self.analyzer.can_prove_equal(res.max_value, expected[1]), err_msg() def test_basic(): s = tvm.arith.IntervalSet(2, 3) assert s.min_value.value == 2 assert s.max_value.value == 3 s = tvm.arith.IntSet.single_point(2) assert s.min_value.value == 2 assert s.max_value.value == 2 def test_vector(): base = 10 stride = 3 lanes = 2 s = tvm.arith.IntSet.vector(tvm.tir.Ramp(base, stride, lanes)) assert s.min_value.value == base assert s.max_value.value == base + stride * (lanes - 1) def test_scalable_vector(): base = 5 s = tvm.arith.IntSet.vector(tvm.tir.Ramp(base, 2, tvm.tir.vscale() * 4)) assert s.min_value.value == base assert s.max_value.same_as(tvm.arith.int_set.pos_inf()) def test_add_sub(): ck = IntSetChecker() x, y = te.var("x"), te.var("y") ck.verify(x + y, {x: tvm.arith.IntervalSet(0, 10)}, (y, 10 + y)) ck.verify(x + y, {x: tvm.arith.IntervalSet(0, 10), y: tvm.arith.IntervalSet(1, 11)}, (1, 21)) ck.verify(x - y, {x: tvm.arith.IntervalSet(0, 10), y: tvm.arith.IntervalSet(1, 11)}, (-11, 9)) def test_mul_div(): ck = IntSetChecker() x, y = te.var("x"), te.var("y") tdiv = tvm.tir.truncdiv ck.analyzer.update(y, tvm.arith.ConstIntBound(1, 100), override=True) ck.verify(x * y, {x: tvm.arith.IntervalSet(0, 10)}, (0, 10 * y)) ck.verify(x * 2, {x: tvm.arith.IntervalSet(1, 10)}, (2, 20)) ck.verify(x * -2, {x: tvm.arith.IntervalSet(1, 10)}, (-20, -2)) ck.verify(tdiv(x, y), {x: tvm.arith.IntervalSet(0, 10)}, (0, tdiv(10, y))) ck.verify(tdiv(x, 2), {x: tvm.arith.IntervalSet(1, 10)}, (0, 5)) fld = tvm.te.floordiv ck.verify(fld(x, y), {x: tvm.arith.IntervalSet(0, 10)}, (0, fld(10, y))) ck.verify(fld(x, 2), {x: tvm.arith.IntervalSet(-1, 10)}, (-1, 5)) def test_mod(): ck = IntSetChecker() x, y = te.var("x"), te.var("y") tmod = tvm.tir.truncmod ck.analyzer.update(y, tvm.arith.ConstIntBound(1, 100), override=True) ck.verify(tmod(x, y), {x: tvm.arith.IntervalSet(0, 10)}, (0, y - 1)) ck.verify(tmod(x, 10), {x: tvm.arith.IntervalSet(1, 10)}, (0, 9)) flm = tvm.te.floormod ck.verify(flm(x, 10), {x: tvm.arith.IntervalSet(-10, 10)}, (0, 9)) ck.verify(flm(x, 10), {x: tvm.arith.IntervalSet(3, 5)}, (3, 5)) ck.verify(flm(x, 10), {x: tvm.arith.IntervalSet(13, 15)}, (3, 5)) ck.verify(flm(x, 10), {x: tvm.arith.IntervalSet(3, 15)}, (0, 9)) ck.verify(flm(x, 10), {x: tvm.arith.IntervalSet(3, 11)}, (0, 9)) ck.verify(flm(x, 10), {x: tvm.arith.IntervalSet(1, 21)}, (0, 9)) fld = tvm.te.floordiv z = te.var("z") ck.analyzer.bind(x, tvm.ir.Range.from_min_extent(0, 3)) ck.verify( flm(y, 8), {y: tvm.arith.IntervalSet(z * 8 + x * 4, z * 8 + x * 4 + 3)}, ( z * 8 + x * 4 - 8 * fld(z * 8 + x * 4, 8), z * 8 + x * 4 + 3 - 8 * fld(z * 8 + x * 4, 8), ), ) ck1 = IntSetChecker() ck1.analyzer.bind(x, tvm.ir.Range.from_min_extent(0, 2)) ck1.verify(flm(y, 8), {y: tvm.arith.IntervalSet(z * 8 + x * 4, z * 8 + x * 4 + 3)}, (x * 4, x * 4 + 3)) def test_max_min(): ck = IntSetChecker() x, y = te.var("x"), te.var("y") ck.verify(tvm.te.max(x, x + 1), {x: tvm.arith.IntervalSet(0, 10)}, (1, 11)) ck.verify(tvm.te.min(x - 1, x + 1), {x: tvm.arith.IntervalSet(0, 10)}, (-1, 9)) ck.verify(tvm.te.min(x, y), {}, (tvm.te.min(x, y), tvm.te.min(x, y))) ck.verify(tvm.te.max(x, y), {}, (tvm.te.max(x, y), tvm.te.max(x, y))) def test_select(): ck = IntSetChecker() # x, y = te.var("x"), te.var("y") x = te.var("x") ck.verify(tvm.tir.Select(x > 0, x - 1, x + 1), {x: tvm.arith.IntervalSet(0, 10)}, (-1, 11)) def check_region_bound(expect_region, var_dom, mode, predicate=None): """Helper to check region bound estimation. Parameters ---------- expect_region: dict The keys are of form (begin, end) or PrimExpr as a single point. The values are expected estimated region or region dict on different bindings. var_dom: dict Map var to iteration domain range. mode: str Specify "lowerbound", "upperbound" or else use strict bound estimation. predicate: PrimExpr Extra predicate, defaults to True. """ if predicate is None: predicate = tvm.tir.IntImm("bool", 1) region = [] expect = [] for k, v in expect_region.items(): if not isinstance(k, (tuple, list)): k = (k, k + 1) region.append(tvm.ir.Range.from_min_extent(k[0], Analyzer().simplify(k[1] - k[0]))) expect.append(v) if mode == "lowerbound": result = tvm.arith.estimate_region_lower_bound(region=region, var_dom=var_dom, predicate=predicate) elif mode == "upperbound": result = tvm.arith.estimate_region_upper_bound(region=region, var_dom=var_dom, predicate=predicate) else: result = tvm.arith.estimate_region_strict_bound(region=region, var_dom=var_dom, predicate=predicate) if result is None: assert all([_ is None for _ in expect]) return assert len(result) == len(expect) for intset, expect_desc in zip(result, expect): if isinstance(expect_desc, dict): # check range on different free var bindings for binding in expect_desc: analyzer = Analyzer() for k, v in binding: analyzer.bind(k, v) expect_begin, expect_end = expect_desc[binding] result_begin = analyzer.simplify(intset.min_value, 3) result_end = analyzer.simplify(intset.max_value + 1, 3) assert analyzer.can_prove_equal(result_begin - expect_begin, 0), f"{result_begin} vs {expect_begin}" assert analyzer.can_prove_equal(result_end - expect_end, 0), f"{result_end} vs {expect_end}" else: # check range expect_begin, expect_end = expect_desc analyzer = Analyzer() assert analyzer.can_prove_equal(intset.min_value - expect_begin, 0), f"{intset.min_value} vs {expect_begin}" assert analyzer.can_prove_equal(intset.max_value - expect_end + 1, 0), f"{intset.max_value} vs {expect_end - 1}" def test_region_bound_not_independent(): # (i, i+2) and (i+2, i+4) are dependent, this the lowerbound is not available i = tvm.tir.Var("i", "int32") var_dom = { i: tvm.ir.Range(begin=0, end=64), } check_region_bound({(i, i + 2): None, (i + 2, i + 4): None}, var_dom, mode="lowerbound") check_region_bound({(i, i + 2): (0, 65), (i + 2, i + 4): (2, 67)}, var_dom, mode="upperbound") # when only a subset of access indices are affine i, j, k = tvm.tir.Var("i", "int32"), tvm.tir.Var("j", "int32"), tvm.tir.Var("k", "int32") var_dom = { i: tvm.ir.Range(begin=0, end=16), j: tvm.ir.Range(begin=0, end=16), k: tvm.ir.Range(begin=0, end=16), } check_region_bound( {i // 4: None, j * 4 + i % 4: None, tir.truncdiv(k, 2): None}, var_dom, predicate=j * 4 + i % 4 > 3, mode="lowerbound", ) check_region_bound( {i // 4: (0, 4), j * 4 + i % 4: (4, 64), tir.truncdiv(k, 2): (0, 8)}, var_dom, predicate=j * 4 + i % 4 > 3, mode="upperbound", ) def test_region_bound_stride_too_wide(): i = tvm.tir.Var("i", "int32") var_dom = {i: tvm.ir.Range(begin=0, end=64)} check_region_bound({(i * 4, i * 4 + 2): None}, var_dom, mode="lowerbound") check_region_bound({(i * 4, i * 4 + 2): (0, 254)}, var_dom, mode="upperbound") def test_region_bound_small_stride(): i = tvm.tir.Var("i", "int32") var_dom = { i: tvm.ir.Range(begin=0, end=64), } check_region_bound({(i * 4, i * 4 + 8): (0, 260)}, var_dom, mode="lowerbound") def test_region_lower_bound_split_predicate(): x_o = tvm.tir.Var("xo", "int32") x_i = tvm.tir.Var("xi", "int32") x = x_o * 4 + x_i var_dom = { x_o: tvm.ir.Range(begin=0, end=16), x_i: tvm.ir.Range(begin=0, end=4), } check_region_bound({(x * 4, x * 4 + 8): (0, 256)}, var_dom, predicate=x < 63, mode="lowerbound") check_region_bound( {(x * 4, x * 4 + 8): (0, 256), (x * 3, x * 3 + 5): (0, 191)}, var_dom, predicate=x < 63, mode="upperbound", ) def test_region_lower_bound_multiple_variables(): div = tvm.tir.floordiv mod = tvm.tir.floormod x = tvm.tir.Var("x", "int32") wid = tvm.tir.Var("wid", "int32") i = div(x, 16) j = div(mod(x, 16), 4) * 8 + mod(x, 4) + div(wid, 32) * 4 k = wid % 32 var_dom = { x: tvm.ir.Range(begin=0, end=32), wid: tvm.ir.Range(begin=0, end=64), } check_region_bound({i: (0, 2), j: (0, 32), k: (0, 32)}, var_dom, mode="lowerbound") def test_region_lower_bound_negative_scale(): i = tvm.tir.Var("i", "int32") j = tvm.tir.Var("j", "int32") var_dom = { i: tvm.ir.Range(begin=0, end=4), j: tvm.ir.Range(begin=0, end=4), } check_region_bound({(1 - i, 5 - i): (-2, 5), (20 - j * 4, 36 - j * 4): (8, 36)}, var_dom, mode="lowerbound") def test_region_lower_bound_for_non_perfect_tile(): h1 = tvm.tir.Var("h1", "int32") h2 = tvm.tir.Var("h2", "int32") h3 = tvm.tir.Var("h3", "int32") # non-uniform tiling, single inner variable var_dom = { h2: tvm.ir.Range(begin=0, end=10), } check_region_bound( { h3 * 8 + h2: { (): ( tvm.tir.max(h3 * 8, 1), tvm.tir.min(0, h3 * 8 - 214) + 224, ), ((h3, 0),): (1, 10), # h3 == 0: region is [1, 10) ((h3, 10),): (h3 * 8, h3 * 8 + 10), # 0 < h3 <= 26: region is [h3 * 8, h3 * 8 + 10) ((h3, 27),): (h3 * 8, 224), # h3 > 26: region is [h3 * 8, 224) } }, var_dom, predicate=tvm.tir.all(h3 * 8 + h2 >= 1, h3 * 8 + h2 < 224), mode="lowerbound", ) # non-uniform tiling, two inner variables var_dom = { h1: tvm.ir.Range(begin=0, end=5), h2: tvm.ir.Range(begin=0, end=2), } check_region_bound( { h3 * 8 + h2 * 5 + h1: { (): ( tvm.tir.max(h3 * 8, 1), tvm.tir.min(0, h3 * 8 - 214) + 224, ), ((h3, 0),): (1, 10), ((h3, 10),): (h3 * 8, h3 * 8 + 10), ((h3, 27),): (h3 * 8, 224), } }, var_dom, predicate=tvm.tir.all(h3 * 8 + h2 * 5 + h1 >= 1, h3 * 8 + h2 * 5 + h1 < 224), mode="lowerbound", ) # lowerbound should fail on incompatible predicates check_region_bound( {h3 * 8 + h2 * 5 + h1: None}, var_dom, predicate=tvm.tir.all(h3 * 8 + h2 * 5 + h1 >= 1, h3 * 8 + h1 * 2 + h2 < 224), mode="lowerbound", ) check_region_bound( {h3 * 8 + h2 * 5 + h1: (h3 * 8, h3 * 8 + 10)}, var_dom, predicate=tvm.tir.all(h3 * 8 + h2 * 5 + h1 >= 1, h3 * 8 + h1 * 2 + h2 < 224), mode="upperbound", ) def test_region_lower_bound_unfusable(): var_dom = { tvm.tir.Var("i", "int32"): tvm.ir.Range(8), tvm.tir.Var("j", "int32"): tvm.ir.Range(4), } i, j = var_dom check_region_bound({(i + j) // 2: (0, 6)}, var_dom, mode="lowerbound") def test_union_lower_bound(): neg_inf = tvm.arith.int_set.neg_inf() pos_inf = tvm.arith.int_set.pos_inf() set_0 = tvm.arith.IntervalSet(min_value=neg_inf, max_value=0) set_1 = tvm.arith.IntervalSet(min_value=1, max_value=pos_inf) result = tvm.arith.int_set.union_lower_bound([set_0, set_1]) assert result.min_value.same_as(neg_inf) assert result.max_value.same_as(pos_inf) set_2 = tvm.arith.IntervalSet(min_value=pos_inf, max_value=neg_inf) result = tvm.arith.int_set.union_lower_bound([set_0, set_1, set_2]) assert result.min_value.same_as(neg_inf) assert result.max_value.same_as(pos_inf) def test_modular_set(): ck = IntSetChecker() x = tvm.te.var("x", dtype="int32") y = tvm.te.var("y", dtype="int32") expr = (x * 2048 + y * 16) % 7168 ck.verify(expr, {x: tvm.arith.IntervalSet(0, 128), y: tvm.arith.IntervalSet(0, 3584)}, (0, 7152)) if __name__ == "__main__": tvm.testing.main()