# Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. from tilelang import tvm import tilelang.testing from tvm.tir import floordiv, floormod from tvm.script import tir as T def ifuse(inputs, pred_extent=None): """Fuse iterators""" value, extent = 0, 1 for i, ext in inputs: value = value * ext + i extent = extent * ext return value, extent if pred_extent is None else pred_extent def isplit(axis, factor): """Split iterators""" fld = tvm.tir.floordiv flm = tvm.tir.floormod return [ (fld(axis[0], factor), fld(axis[1] + (factor - 1), factor)), (flm(axis[0], factor), factor), ] def var_dom(iters): """Get domains of iterators""" return {var: tvm.ir.Range(0, ext) for var, ext in iters} def convert_iter_expr(expr): return tvm.arith.normalize_iter_map_to_expr(expr) def assert_iter_sum_pattern(expect_dict, dom_map, predicate=True, check_level="surjective", simplify_trivial_iterators=True): keys = list(expect_dict.keys()) res = tvm.arith.detect_iter_map( keys, dom_map, predicate=predicate, check_level=check_level, simplify_trivial_iterators=simplify_trivial_iterators, ) indices = res.indices assert len(indices) == len(keys), res.errors for i, input_iter in enumerate(keys): spec = expect_dict[input_iter] ( extent, base, ) = spec[0:2] scale = spec[2] if len(spec) > 2 else 1 expect_iter = spec[3] if len(spec) > 3 else None sum_expr = indices[i] assert isinstance(sum_expr, tvm.arith.IterSumExpr) if extent == 1: assert len(sum_expr.args) == 0 else: assert len(sum_expr.args) == 1 tvm.testing.assert_prim_expr_equal(sum_expr.args[0].extent, extent) tvm.testing.assert_prim_expr_equal(sum_expr.args[0].scale, scale) tvm.testing.assert_prim_expr_equal(sum_expr.base, base) if expect_iter is not None: if not isinstance(expect_iter, tvm.arith.IterMapExpr): sum_expr = convert_iter_expr(sum_expr) tvm.ir.assert_structural_equal(sum_expr, expect_iter) def assert_iter_map_simplify(expect_dict, dom_map, predicate=True, check_level="surjective", simplify_trivial_iterators=True): keys = list(expect_dict.keys()) _imap = tvm.arith.detect_iter_map( keys, dom_map, predicate=predicate, check_level=check_level, simplify_trivial_iterators=simplify_trivial_iterators, ) res = tvm.arith.iter_map_simplify( keys, dom_map, predicate=predicate, check_level=check_level, simplify_trivial_iterators=simplify_trivial_iterators, ) for i, input_expr in enumerate(keys): expected_expr = expect_dict[input_expr] tvm.ir.assert_structural_equal(res[i], expected_expr) def assert_iter_sum_failure(iters, dom_map, predicate=True, check_level="surjective"): res = tvm.arith.detect_iter_map(list(iters), dom_map, predicate=predicate, check_level=check_level).indices assert len(res) == 0 def test_trivial(): x = tvm.tir.Var("x", "int32") y = tvm.tir.Var("y", "int32") z = tvm.tir.Var("z", "int32") dom_map = var_dom([(x, 3), (y, 4), (z, 1)]) assert_iter_sum_pattern({x: (3, 0), y: (4, 0), 3: (1, 3)}, dom_map) assert_iter_sum_pattern({x: (3, 0), 3: (1, 3)}, dom_map) # not independent assert_iter_sum_failure([x, x, 3], dom_map) assert_iter_sum_pattern({x: (3, 0), y: (4, 0)}, dom_map, check_level="bijective", simplify_trivial_iterators=True) assert_iter_sum_pattern({x: (3, 0), y: (4, 0)}, dom_map, check_level="bijective", simplify_trivial_iterators=False) assert_iter_sum_failure([x, z], dom_map, check_level="bijective") def test_fuse(): x = tvm.tir.Var("x", "int32") y = tvm.tir.Var("y", "int32") c = tvm.tir.SizeVar("c", "int32") c0 = tvm.tir.SizeVar("c0", "int32") assert_iter_sum_pattern({y * 3 + 1 + c + x: (12, 1 + c)}, var_dom([(x, 3), (y, 4)])) assert_iter_sum_pattern({ifuse([(x, 3), (y, 4)])[0]: (12, 0)}, var_dom([(x, 3), (y, 4)])) # fuse with symbolic factor assert_iter_sum_pattern({(y + 1) * c + x: (4 * c, c)}, var_dom([(x, c), (y, 4)])) # duplication assert_iter_sum_failure([y * 3 + x, y], var_dom([(x, 3), (y, 4)])) assert_iter_sum_failure([y, x + 1, y], var_dom([(x, 3), (y, 4)])) # factor mismatch assert_iter_sum_failure([y * 4 + x], var_dom([(x, 3), (y, 4)])) # simple stride pattern assert_iter_sum_pattern({x * 4 + y * 2: (6, 0, 2, (x * 2 + y) * 2)}, var_dom([(x, 3), (y, 2)])) # simple stride pattern with symbolic assert_iter_sum_pattern({x * 2 * c0 + y * 2: (3 * c0, 0, 2, (x * c0 + y) * 2)}, var_dom([(x, 3), (y, c0)])) def test_split(): x = tvm.tir.Var("x", "int32") y = tvm.tir.Var("y", "int32") c0 = tvm.tir.SizeVar("c0", "int32") c1 = tvm.tir.SizeVar("c1", "int32") fld = tvm.tir.floordiv flm = tvm.tir.floormod assert_iter_sum_pattern({fld(x, 3): (8, 0), flm(x, 3) * 2 + c1: (3, c1, 2)}, var_dom([(x, 24)])) assert_iter_sum_pattern({fld(x, 6): (4, 0), fld(flm(x, 6), 2): (3, 0), flm(x, 2): (2, 0)}, var_dom([(x, 24)])) # simple symbolic bound # TODO(tvm-team) improve symbolic divisible check to enable # more complicated symbolic bound assert_iter_sum_pattern({fld(x, c0): (c1, 0), flm(x, c0): (c0, 0)}, var_dom([(x, c1 * c0)])) assert_iter_sum_pattern({fld(x * 2, 4): (4, 0, 1), flm(x * 2, 4): (2, 0, 2)}, var_dom([(x, 8)])) assert_iter_sum_pattern( { fld(x * 2, 4) * 4 + flm(x * 2, 4): (8, 0, 2), }, var_dom([(x, 8)]), ) assert_iter_sum_failure([fld(x, flm(flm(y, 8), 6))], var_dom([(x, 24), (y, 8)])) # domain of x is undefined assert_iter_sum_pattern({fld(flm(x, 49) + y, 49): (1, fld(flm(x, 49) + y, 49))}, var_dom([(y, 1)])) def test_compound(): x = tvm.tir.Var("x", "int32") y = tvm.tir.Var("y", "int32") xo, xi = isplit((x, 10), 5) yo, yi = isplit((y, 9), 3) z = ifuse([yo, xo, yi]) # reconstruct the pattern manually mx = tvm.arith.IterMark(x, 10) my = tvm.arith.IterMark(y, 9) xoscale = 3 yoscale = 6 yiscale = 1 mxo = tvm.arith.IterSplitExpr(mx, 5, 2, xoscale) myo = tvm.arith.IterSplitExpr(my, 3, 3, yoscale) myi = tvm.arith.IterSplitExpr(my, 1, 3, yiscale) mz = tvm.arith.IterMark(tvm.arith.IterSumExpr([myo, mxo, myi], 0), 18) sz = tvm.arith.IterSumExpr([tvm.arith.IterSplitExpr(mz, 1, 18, 1)], 0) assert_iter_sum_pattern({z[0]: (18, 0, 1, sz), xi[0]: (5, 0)}, var_dom([(x, 10), (y, 9)])) def test_compound_floormod_two_regression(): x = tvm.tir.Var("x", "int32") fld = tvm.tir.floordiv flm = tvm.tir.floormod # regression # extent of 2 of negative scale cannot be normalized assert_iter_sum_failure( [fld(x, 2) * 2 - flm(x, 2) + 1], dom_map=var_dom([(x, 8)]), ) def test_predicate(): x = tvm.tir.Var("x", "int32") y = tvm.tir.Var("y", "int32") z = tvm.tir.Var("z", "int32") # available constraints # upper bound only assert_iter_sum_pattern({x * 10 + y: (128, 0)}, var_dom([(x, 13), (y, 10)]), predicate=x * 10 + y < 128) assert_iter_sum_pattern({x * 10 + y: (128, 0)}, var_dom([(x, 13), (y, 10)]), predicate=x * 10 + y <= 127) # lower bound only assert_iter_sum_pattern({x * 10 + y: (124, 6)}, var_dom([(x, 13), (y, 10)]), predicate=x * 10 + y > 5) assert_iter_sum_pattern({x * 10 + y: (124, 6)}, var_dom([(x, 13), (y, 10)]), predicate=x * 10 + y >= 6) # lower bound + upper bound assert_iter_sum_pattern( {x * 10 + y: (122, 6)}, var_dom([(x, 13), (y, 10)]), predicate=tvm.tir.And(x * 10 + y > 5, x * 10 + y < 128), ) assert_iter_sum_pattern( {x * 10 + y: (122, 6)}, var_dom([(x, 13), (y, 10)]), predicate=tvm.tir.And(x * 10 + y >= 6, x * 10 + y <= 127), ) assert_iter_sum_pattern( {x * 64 + y * 4 + z: (16, 16)}, var_dom([(x, 16), (y, 16), (z, 4)]), predicate=tvm.tir.And(x * 64 + y * 4 + z < 32, x * 16 + y >= 4), ) # constraints on one fused iter i = tvm.tir.Var("i", "int32") j = tvm.tir.Var("j", "int32") k = tvm.tir.Var("k", "int32") assert_iter_sum_pattern( {i * 8 + j * 2 + k: (88, 1)}, var_dom([(i, 11), (j, 5), (k, 2)]), predicate=tvm.tir.all(j * 2 + k >= 1, j * 2 + k < 9), ) # constraints on single var assert_iter_sum_pattern({i: (10, 0)}, var_dom([(i, 48)]), predicate=i < 10) # iterations are subparts of constraint, invalid case 1 assert_iter_sum_failure( [i, j, k], var_dom([(i, 128), (j, 128), (k, 128)]), predicate=tvm.tir.all(i * 16384 + j * 128 + k < 100), ) # iterations are subparts of constraint, invalid case 2 assert_iter_sum_failure( [i * 128 + j, k], var_dom([(i, 128), (j, 128), (k, 128)]), predicate=i * 16384 + j * 128 + k < 100, ) # irrelevant predicate assert_iter_sum_pattern({i + j: (1, j)}, var_dom([(i, 1)]), predicate=j <= 24) # constraint on nested fused iters assert_iter_sum_pattern( {i * 8 + j * 2 + k: (22, 3)}, var_dom([(i, 11), (j, 5), (k, 2)]), predicate=tvm.tir.all(j * 2 + k >= 1, j * 2 + k < 9, i * 8 + j * 2 + k >= 3, i * 8 + j * 2 + k < 25), ) # duplicate constraint on one fused iter assert_iter_sum_pattern( {i * 6 + j * 2 + k: (66, 2)}, var_dom([(i, 11), (j, 5), (k, 2)]), predicate=tvm.tir.all(j * 2 + k >= 1, j * 2 + k >= 2, j * 2 + k < 8, j * 2 + k < 9), ) # duplicate constraint on nested fused iters assert_iter_sum_pattern( {i * 6 + j * 2 + k: (15, 3)}, var_dom([(i, 11), (j, 5), (k, 2)]), predicate=tvm.tir.all( j * 2 + k >= 1, j * 2 + k >= 2, j * 2 + k < 8, j * 2 + k < 9, i * 6 + j * 2 + k >= 3, i * 6 + j * 2 + k < 25, i * 6 + j * 2 + k >= 1, i * 6 + j * 2 + k < 18, ), ) # constraint on non-disjoint fused iters should fail assert_iter_sum_failure( [i * 8 + j * 2 + k], var_dom([(i, 11), (j, 5), (k, 2)]), predicate=tvm.tir.all(j * 2 + k >= 2, i * 4 + j >= 0), ) # constraints with different lower bound assert_iter_sum_pattern( { (i * 16 + j) // 23 * 8 + (i * 16 + j) % 23 - 15: ( 64, 0, 1, (i * 16 + j) // 23 * 8 + ((i * 16 + j) % 23 + tvm.tir.IntImm("int32", -15)), ) }, var_dom([(i, 12), (j, 16)]), predicate=tvm.tir.And( tvm.tir.And(i * 16 + j < 184, tvm.tir.LE(tvm.tir.IntImm("int32", 8), (i * 16 + j) % 23)), tvm.tir.LE(tvm.tir.IntImm("int32", 15), (i * 16 + j) % 23), ), ) # constraint on many disjoint fused iters, case 1 # i4 * 6 + i5 in [3, 9), extent=6 (= scale of i2) # i2 * 30 + i3 * 15 in [30, 90), extent=60 (= scale of i1) # i1 * 60 in [60, 240), extent=180 (= scale of i0) i0 = tvm.tir.Var("i0", "int32") i1 = tvm.tir.Var("i1", "int32") i2 = tvm.tir.Var("i2", "int32") i3 = tvm.tir.Var("i3", "int32") i4 = tvm.tir.Var("i4", "int32") i5 = tvm.tir.Var("i5", "int32") assert_iter_sum_pattern( {i0 * 180 + i1 * 60 + i2 * 30 + i3 * 15 + i4 * 6 + i5: (540, 93)}, var_dom([(i0, 3), (i1, 4), (i2, 3), (i3, 2), (i4, 3), (i5, 6)]), predicate=tvm.tir.all(i1 >= 1, i2 * 2 + i3 >= 2, i4 * 6 + i5 >= 3), ) # constraint on many disjoint fused iters, case 2 assert_iter_sum_pattern( {i0 * 45 + i1 * 45 + i2 * 9 + i3 * 4 + i4: (135, 28)}, var_dom([(i0, 3), (i1, 2), (i2, 5), (i3, 3), (i4, 4)]), predicate=tvm.tir.all(i1 * 5 + i2 >= 3, i1 * 5 + i2 < 8, i3 * 4 + i4 >= 1, i3 * 4 + i4 < 10), ) # constraint on split iters assert_iter_sum_pattern( {i % 16: (7, 3), i // 16: (8, 4)}, var_dom([(i, 1024)]), predicate=tvm.tir.all(i % 16 >= 3, i % 16 < 10, i // 16 >= 4, i // 16 < 12), check_level="bijective", ) # constraint on split iters, nested case 1 assert_iter_sum_pattern( {(i * 32 + j) % 16: (7, 3)}, var_dom([(i, 5), (j, 32)]), predicate=tvm.tir.all((i * 32 + j) % 16 >= 3, (i * 32 + j) % 16 < 10), ) # constraint on split iters, nested case 2 assert_iter_sum_failure( [ (i * 32 + j) % 16, ], var_dom([(i, 5), (j, 32)]), predicate=tvm.tir.all(i * 32 + j >= 1, i * 32 + j <= 32), check_level="bijective", ) assert_iter_sum_pattern( {(i * 32 + j) % 16: (16, 0)}, var_dom([(i, 5), (j, 32)]), predicate=tvm.tir.all(i * 32 + j >= 1, i * 32 + j <= 32), ) assert_iter_sum_pattern( {(i * 32 + j - 1) % 16: (16, 0), (i * 32 + j - 1) // 16: (4, 0)}, var_dom([(i, 5), (j, 32)]), predicate=tvm.tir.all(i * 32 + j >= 1, i * 32 + j <= 64), ) # non-standard form of predicate assert_iter_sum_pattern({x * 10 + y: (128, 0)}, var_dom([(x, 13), (y, 10)]), predicate=x * 10 < 128 - y) # duplicate constraint assert_iter_sum_pattern( {x * 10 + y: (64, 0)}, var_dom([(x, 13), (y, 10)]), predicate=tvm.tir.all(x * 10 + y < 128, x * 10 + y < 64), ) # useless constraint assert_iter_sum_pattern({x * 10 + y: (130, 0)}, var_dom([(x, 13), (y, 10)]), predicate=x * 10 + y < 140) i1 = tvm.tir.Var("i1", "int32") i2 = tvm.tir.Var("i2", "int32") i3 = tvm.tir.Var("i3", "int32") i4 = tvm.tir.Var("i4", "int32") assert_iter_sum_pattern( {i1 * 20 + i2 * 10 + i3 * 3 + i4: (128, 0)}, var_dom([(i1, 7), (i2, 2), (i3, 4), (i4, 3)]), predicate=( tvm.tir.all( i1 * 2 + i2 < 13, i1 * 20 + i2 * 10 + i3 * 3 + i4 < 128, i3 * 3 + i4 < 10, ) ), ) # wrong constraint assert_iter_sum_failure( [i1 * 20 + i2 * 10 + i3 * 3 + i4], var_dom([(i1, 7), (i2, 2), (i3, 4), (i4, 3)]), predicate=( tvm.tir.all( i1 * 2 + i2 < 13, i1 * 20 + i2 * 10 + i3 * 3 + i4 < 128, i3 * 3 + i4 < 7, ) ), ) # incompatible constraint assert_iter_sum_failure( [i1 * 20 + i2 * 10 + i3 * 3 + i4], var_dom([(i1, 7), (i2, 2), (i3, 4), (i4, 3)]), predicate=( tvm.tir.all( i1 * 2 + i2 < 13, i1 * 20 + i2 * 10 + i3 * 3 + i4 < 128, i3 * 3 + i4 < 10, i1 * 4 + i3 < 20, ) ), ) assert_iter_sum_failure( [i1 * 20 + i2 * 10 + i3 * 3 + i4], var_dom([(i1, 7), (i2, 2), (i3, 4), (i4, 3)]), predicate=( tvm.tir.all( i1 * 2 + i2 < 13, i1 * 20 + i2 * 10 + i3 * 3 + i4 < 128, i1 * 4 + i3 < 20, ) ), ) # zero iter xo = tvm.tir.Var("xo", "int32") xi = tvm.tir.Var("xi", "int32") y = tvm.tir.Var("y", "int32") assert_iter_sum_pattern( {xo * 129 + xi: (128, 0), y: (128, 0)}, var_dom([(xo, 1), (xi, 129), (y, 128)]), predicate=xo * 129 + xi < 128, ) # strided iteration predicate assert_iter_sum_pattern( {xo * 16 + xi * 4: (10, 0, 4)}, var_dom([(xo, 3), (xi, 4)]), predicate=xo * 4 + xi < 10, ) def convert_division(divisions): if divisions is None or len(divisions) == 0: return [] res = [] for division in divisions[:-1]: res.append( [ tvm.arith.normalize_iter_map_to_expr(division[0].source), tvm.arith.normalize_iter_map_to_expr(division[1].source), ] ) res.append([divisions[-1][0].extent, divisions[-1][1].extent]) return res def create_iter(name, extent): return tvm.tir.Var(name, "int32"), extent def test_subspace_division(): x = tvm.tir.Var("x", "int32") y = tvm.tir.Var("y", "int32") z = tvm.tir.Var("z", "int32") c = tvm.tir.SizeVar("c", "int32") # simple 1.1 res = tvm.arith.subspace_divide([z * 12 + y * 3 + x + c], var_dom([(x, 3), (y, 4), (z, 5)]), [x]) res = convert_division(res) assert len(res) == 2 tvm.ir.assert_structural_equal(res[0][0], z * 4 + y) tvm.ir.assert_structural_equal(res[0][1], x + c) # simple 1.2 res = tvm.arith.subspace_divide([z * 12 + y * 3 + x + c], var_dom([(x, 3), (y, 4), (z, 5)]), [x], z * 4 + y < 18) res = convert_division(res) assert len(res) == 2 tvm.ir.assert_structural_equal(res[0][0], z * 4 + y) tvm.ir.assert_structural_equal(res[0][1], x + c) tvm.ir.assert_structural_equal(res[1][0], z * 4 + y < 18) tvm.ir.assert_structural_equal(res[1][1], T.bool(True)) # compound 1 i0 = create_iter("i0", 4) j0 = create_iter("j0", 8) i3 = create_iter("i3", 2) i1, i2 = isplit(j0, 4) k0 = ifuse([i0, i1]) k1 = ifuse([i2, i3]) # compound 1.1 res = tvm.arith.subspace_divide([k0[0], k1[0]], var_dom([i0, j0, i3]), [i3[0]]) res = convert_division(res) assert len(res) == 3 tvm.ir.assert_structural_equal(res[0][0], (i0[0] * 2) + floordiv(j0[0], 4)) tvm.ir.assert_structural_equal(res[0][1], T.int32(0)) tvm.ir.assert_structural_equal(res[1][0], floormod(j0[0], 4)) tvm.ir.assert_structural_equal(res[1][1], i3[0]) # assert_iter_sum_pattern res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([i3])).indices assert len(res1) == 2 res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([i0, j0])).indices assert len(res2) == 2 # compound 1.2 res = tvm.arith.subspace_divide([k0[0], k1[0]], var_dom([i0, j0, i3]), [j0[0], i3[0]]) res = convert_division(res) assert len(res) == 3 tvm.ir.assert_structural_equal(res[0][0], i0[0]) tvm.ir.assert_structural_equal(res[0][1], floordiv(j0[0], 4)) tvm.ir.assert_structural_equal(res[1][0], T.int32(0)) tvm.ir.assert_structural_equal(res[1][1], (floormod(j0[0], 4) * 2) + i3[0]) res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([j0, i3])).indices assert len(res1) == 2 res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([i0])).indices assert len(res2) == 2 # compound 1.3 res = tvm.arith.subspace_divide([k0[0], k1[0]], var_dom([i0, j0, i3]), [i0[0], i3[0]]) res = convert_division(res) assert len(res) == 0 # compound 1.4 res = tvm.arith.subspace_divide([k0[0], k1[0]], var_dom([i0, j0, i3]), [i3[0]], k0[0] < 7) res = convert_division(res) assert len(res) == 3 tvm.ir.assert_structural_equal(res[0][0], (i0[0] * 2) + floordiv(j0[0], 4)) tvm.ir.assert_structural_equal(res[0][1], T.int32(0)) tvm.ir.assert_structural_equal(res[1][0], floormod(j0[0], 4)) tvm.ir.assert_structural_equal(res[1][1], i3[0]) tvm.ir.assert_structural_equal(res[2][0], (i0[0] * 2) + floordiv(j0[0], 4) < 7) tvm.ir.assert_structural_equal(res[2][1], T.bool(True)) res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([i3])).indices assert len(res1) == 2 res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([i0, j0])).indices assert len(res2) == 2 # compound 1.5 res = tvm.arith.subspace_divide([k0[0], k1[0]], var_dom([i0, j0, i3]), [j0[0], i3[0]], k1[0] < 7) res = convert_division(res) assert len(res) == 3 tvm.ir.assert_structural_equal(res[0][0], i0[0]) tvm.ir.assert_structural_equal(res[0][1], floordiv(j0[0], 4)) tvm.ir.assert_structural_equal(res[1][0], T.int32(0)) tvm.ir.assert_structural_equal(res[1][1], (floormod(j0[0], 4) * 2) + i3[0]) tvm.ir.assert_structural_equal(res[2][0], T.bool(True)) tvm.ir.assert_structural_equal(res[2][1], (floormod(j0[0], 4) * 2) + i3[0] < 7) res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([j0, i3])).indices assert len(res1) == 2 res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([i0])).indices assert len(res2) == 2 # compound 1.6 res = tvm.arith.subspace_divide([k0[0], k1[0]], var_dom([i0, j0, i3]), [i3[0]], tvm.tir.all(k0[0] < 7, k1[0] < 7)) res = convert_division(res) assert len(res) == 0 # compound 2 j0 = create_iter("j0", 4) l0 = create_iter("l0", 2) l1 = create_iter("l1", 6) j3 = create_iter("j3", 3) k0 = ifuse([l0, l1]) i1, j2 = isplit(k0, 3) j1, i1 = isplit(i1, 2) i0 = ifuse([j0, j1]) i2 = ifuse([j2, j3]) # compound 2.1 res = tvm.arith.subspace_divide([i0[0], i1[0], i2[0]], var_dom([j0, l0, l1, j3]), [l1[0], j3[0]]) res = convert_division(res) assert len(res) == 4 tvm.ir.assert_structural_equal(res[0][0], (j0[0] * 2) + l0[0]) tvm.ir.assert_structural_equal(res[0][1], T.int32(0)) tvm.ir.assert_structural_equal(res[1][0], T.int32(0)) tvm.ir.assert_structural_equal(res[1][1], floordiv(l1[0], 3)) tvm.ir.assert_structural_equal(res[2][0], T.int32(0)) tvm.ir.assert_structural_equal(res[2][1], (floormod(l1[0], 3) * 3) + j3[0]) res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1], res[2][1]], var_dom([l1, j3])).indices assert len(res1) == 3 res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0], res[2][0]], var_dom([j0, l0])).indices assert len(res2) == 3 # compound 2.2 res = tvm.arith.subspace_divide([i0[0], i1[0], i2[0]], var_dom([j0, l0, l1, j3]), [l0[0], l1[0], j3[0]]) res = convert_division(res) assert len(res) == 4 tvm.ir.assert_structural_equal(res[0][0], j0[0]) tvm.ir.assert_structural_equal(res[0][1], floordiv(l0[0] * 6 + l1[0], 6)) tvm.ir.assert_structural_equal(res[1][0], T.int32(0)) tvm.ir.assert_structural_equal(res[1][1], floordiv(floormod(l0[0] * 6 + l1[0], 6), 3)) tvm.ir.assert_structural_equal(res[2][0], T.int32(0)) tvm.ir.assert_structural_equal(res[2][1], (floormod(l0[0] * 6 + l1[0], 3) * 3) + j3[0]) res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1], res[2][1]], var_dom([l0, l1, j3])).indices assert len(res1) == 3 res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0], res[2][0]], var_dom([j0])).indices assert len(res2) == 3 # compound 2.3 res = tvm.arith.subspace_divide([i0[0], i1[0], i2[0]], var_dom([j0, l0, l1, j3]), [l0[0], j3[0]]) res = convert_division(res) assert len(res) == 0 # compound 2.4 res = tvm.arith.subspace_divide( [i0[0], i1[0], i2[0]], var_dom([j0, l0, l1, j3]), [l1[0], j3[0]], tvm.tir.all(i0[0] < 7, i2[0] < 8), ) res = convert_division(res) assert len(res) == 4 tvm.ir.assert_structural_equal(res[0][0], (j0[0] * 2) + l0[0]) tvm.ir.assert_structural_equal(res[0][1], T.int32(0)) tvm.ir.assert_structural_equal(res[1][0], T.int32(0)) tvm.ir.assert_structural_equal(res[1][1], floordiv(l1[0], 3)) tvm.ir.assert_structural_equal(res[2][0], T.int32(0)) tvm.ir.assert_structural_equal(res[2][1], (floormod(l1[0], 3) * 3) + j3[0]) tvm.ir.assert_structural_equal(res[3][0], (j0[0] * 2) + l0[0] < 7) tvm.ir.assert_structural_equal(res[3][1], (floormod(l1[0], 3) * 3) + j3[0] < 8) res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1], res[2][1]], var_dom([l1, j3])).indices assert len(res1) == 3 res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0], res[2][0]], var_dom([j0, l0])).indices assert len(res2) == 3 # compound 2.5 res = tvm.arith.subspace_divide([i0[0], i1[0], i2[0]], var_dom([j0, l0, l1, j3]), [j3[0]], i2[0] < 8) res = convert_division(res) assert len(res) == 0 def test_subspace_divide_trivial_iters(): x = tvm.tir.Var("x", "int32") y = tvm.tir.Var("y", "int32") # z = tvm.tir.Var("z", "int32") # trivial 1.1 res = tvm.arith.subspace_divide([x * 16 + y], var_dom([(x, 1), (y, 16)]), [y], simplify_trivial_iterators=False) res = convert_division(res) assert len(res) == 2 tvm.ir.assert_structural_equal(res[0][0], x) tvm.ir.assert_structural_equal(res[0][1], y) # trivial 1.2 res = tvm.arith.subspace_divide( [x, y], var_dom([(x, 1), (y, 1)]), [y], simplify_trivial_iterators=False, ) res = convert_division(res) assert len(res) == 3 tvm.ir.assert_structural_equal(res[0][0], x) tvm.ir.assert_structural_equal(res[0][1], T.int32(0)) tvm.ir.assert_structural_equal(res[1][0], T.int32(0)) tvm.ir.assert_structural_equal(res[1][1], y) def test_complex(): n0 = create_iter("n0", 2) n1 = create_iter("n1", 4) m0 = ifuse([n0, n1], 6) m1 = create_iter("m1", 3) l0 = create_iter("l0", 4) l1 = create_iter("l1", 8) l2 = ifuse([m0, m1], 16) l3 = create_iter("l3", 32) k0, k4 = isplit(l0, 2) k1, k5 = isplit(l1, 2) k2, k6 = isplit(l2, 4) k3, k7 = isplit(l3, 4) j0 = ifuse([k0, k1], 7) j1 = ifuse([k2, k3]) j2 = ifuse([k4, k5]) j3 = ifuse([k6, k7], 15) i0 = ifuse([j0, j1], 200) i1 = ifuse([j2, j3], 50) n0_mark = tvm.arith.IterMark(n0[0], n0[1]) n1_mark = tvm.arith.IterMark(n1[0], n1[1]) l0_mark = tvm.arith.IterMark(l0[0], l0[1]) l1_mark = tvm.arith.IterMark(l1[0], l1[1]) m1_mark = tvm.arith.IterMark(m1[0], m1[1]) l3_mark = tvm.arith.IterMark(l3[0], l3[1]) m0_expr = tvm.arith.IterSumExpr( [ tvm.arith.IterSplitExpr(n0_mark, 1, n0[1], 4), tvm.arith.IterSplitExpr(n1_mark, 1, n1[1], 1), ], 0, ) m0_mark = tvm.arith.IterMark(m0_expr, 6) l2_expr = tvm.arith.IterSumExpr( [tvm.arith.IterSplitExpr(m0_mark, 1, 6, 3), tvm.arith.IterSplitExpr(m1_mark, 1, m1[1], 1)], 0, ) l2_mark = tvm.arith.IterMark(l2_expr, 16) k0_expr = tvm.arith.IterSplitExpr(l0_mark, 2, 2, 4) k1_expr = tvm.arith.IterSplitExpr(l1_mark, 2, 4, 1) k2_expr = tvm.arith.IterSplitExpr(l2_mark, 4, 4, 8) k3_expr = tvm.arith.IterSplitExpr(l3_mark, 4, 8, 1) k4_expr = tvm.arith.IterSplitExpr(l0_mark, 1, 2, 30) k5_expr = tvm.arith.IterSplitExpr(l1_mark, 1, 2, 15) k6_expr = tvm.arith.IterSplitExpr(l2_mark, 1, 4, 4) k7_expr = tvm.arith.IterSplitExpr(l3_mark, 1, 4, 1) j0_expr = tvm.arith.IterSumExpr([k0_expr, k1_expr], 0) j0_mark = tvm.arith.IterMark(j0_expr, 7) i0_expr = tvm.arith.IterSumExpr([tvm.arith.IterSplitExpr(j0_mark, 1, 7, 32), k2_expr, k3_expr], 0) j3_expr = tvm.arith.IterSumExpr([k6_expr, k7_expr], 0) j3_mark = tvm.arith.IterMark(j3_expr, 15) i1_expr = tvm.arith.IterSumExpr([k4_expr, k5_expr, tvm.arith.IterSplitExpr(j3_mark, 1, 15, 1)], 0) i0_mark = tvm.arith.IterMark(i0_expr, i0[1]) i1_mark = tvm.arith.IterMark(i1_expr, i1[1]) i0_final = tvm.arith.IterSumExpr([tvm.arith.IterSplitExpr(i0_mark, 1, i0[1], 1)], 0) i1_final = tvm.arith.IterSumExpr([tvm.arith.IterSplitExpr(i1_mark, 1, i1[1], 1)], 0) assert_iter_sum_pattern( {i0[0]: (200, 0, 1, i0_final), i1[0]: (50, 0, 1, i1_final)}, var_dom([l0, l1, n0, n1, m1, l3]), predicate=tvm.tir.all(i0[0] < 200, i1[0] < 50, m0[0] < 6, l2[0] < 16, j0[0] < 7, j3[0] < 15), ) # wrong constraint assert_iter_sum_failure( [i0[0], i1[0]], var_dom([l0, l1, n0, n1, m1, l3]), tvm.tir.all(i0[0] < 200, i1[0] < 50, m0[0] < 9, l2[0] < 16, j0[0] < 7, j3[0] < 14), ) # subspace_division res = tvm.arith.subspace_divide( [i0[0], i1[0]], var_dom([l0, l1, n0, n1, m1, l3]), [n0[0], n1[0], m1[0], l3[0]], tvm.tir.all(m0[0] < 6, l2[0] < 16, j0[0] < 7, j3[0] < 15), ) res = convert_division(res) assert len(res) == 3 tvm.ir.assert_structural_equal(res[0][0], floordiv(l0[0], 2) * 4 + floordiv(l1[0], 2)) tvm.ir.assert_structural_equal(res[0][1], (floordiv((n0[0] * 4 + n1[0]) * 3 + m1[0], 4) * 8) + floordiv(l3[0], 4)) tvm.ir.assert_structural_equal(res[1][0], ((floormod(l0[0], 2) * 2) + floormod(l1[0], 2))) tvm.ir.assert_structural_equal(res[1][1], ((floormod(((n0[0] * 4 + n1[0]) * 3 + m1[0]), 4) * 4) + floormod(l3[0], 4))) tvm.ir.assert_structural_equal(res[2][0], (floordiv(l0[0], 2) * 4) + floordiv(l1[0], 2) < 7) tvm.ir.assert_structural_equal( res[2][1], tvm.tir.all( n0[0] * 4 + n1[0] < 6, (n0[0] * 4 + n1[0]) * 3 + m1[0] < 16, floormod(((n0[0] * 4 + n1[0]) * 3 + m1[0]), 4) * 4 + floormod(l3[0], 4) < 15, ), ) assert_iter_sum_pattern({res[0][1]: (32, 0), res[1][1]: (15, 0)}, var_dom([n0, n1, m1, l3]), res[2][1]) assert_iter_sum_pattern({res[0][0]: (8, 0), res[1][0]: (4, 0)}, var_dom([l0, l1])) def test_normalize_iter_map_to_expr(): fld = tvm.tir.floordiv flm = tvm.tir.floormod x = tvm.tir.Var("x", "int32") y = tvm.tir.Var("y", "int32") xo, xi = isplit((x, 10), 5) yo, yi = isplit((y, 9), 3) z = ifuse([yo, xo, yi]) res = tvm.arith.detect_iter_map([z[0], xi[0]], var_dom([(x, 10), (y, 9)])) tvm.ir.assert_structural_equal( tvm.arith.normalize_iter_map_to_expr(res.indices[0]), fld(y, 3) * 6 + fld(x, 5) * 3 + flm(y, 3), ) tvm.ir.assert_structural_equal(tvm.arith.normalize_iter_map_to_expr(res.indices[1]), flm(x, 5)) # iter mark wrap a complex expr split = tvm.arith.IterSplitExpr(tvm.arith.IterMark(x * y + 1, 1024), 1, 1024, 1) tvm.ir.assert_structural_equal(tvm.arith.normalize_iter_map_to_expr(split), x * y + 1) def test_inverse_affine_iter_map(): analyzer = tvm.arith.Analyzer() l0 = create_iter("l0", 64) l1 = create_iter("l1", 64) l2 = create_iter("l2", 64) # simple case l0_0, l0_1 = isplit(l0, 16) l1_0, l1_1 = isplit(l1, 4) l0_1_l1_1_fused = ifuse([l0_1, l1_1]) iter_map = tvm.arith.detect_iter_map([l0_1_l1_1_fused[0], l0_0[0], l1_0[0]], var_dom([l0, l1])).indices outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in range(len(iter_map))] res = tvm.arith.inverse_affine_iter_map(iter_map, outputs) assert len(res) == 2 l0_inverse = floordiv(outputs[0], 4) + outputs[1] * 16 l1_inverse = floormod(outputs[0], 4) + outputs[2] * 4 assert analyzer.can_prove_equal(res[l0[0]], l0_inverse) assert analyzer.can_prove_equal(res[l1[0]], l1_inverse) # compound case l0_0, l0_1 = isplit(l0, 16) l1_0, l1_1 = isplit(l1, 4) l2_1, l2_2 = isplit(l2, 4) l2_0, l2_1 = isplit(l2_1, 4) l0_1_l2_1_l1_1_l2_0_fused = ifuse([l0_1, l2_1, l1_1, l2_0]) iter_map = tvm.arith.detect_iter_map([l0_1_l2_1_l1_1_l2_0_fused[0], l0_0[0], l2_2[0], l1_0[0]], var_dom([l0, l1, l2])).indices outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in range(len(iter_map))] res = tvm.arith.inverse_affine_iter_map(iter_map, outputs) assert len(res) == 3 l0_inverse = floordiv(outputs[0], 64) + outputs[1] * 16 l1_inverse = floormod(floordiv(outputs[0], 4), 4) + outputs[3] * 4 l2_inverse = floormod(outputs[0], 4) * 16 + floormod(floordiv(outputs[0], 16), 4) * 4 + outputs[2] assert analyzer.can_prove_equal(res[l0[0]], l0_inverse) assert analyzer.can_prove_equal(res[l1[0]], l1_inverse) assert analyzer.can_prove_equal(res[l2[0]], l2_inverse) # diamond-shape DAG l0_0, l0_1 = isplit(l0, 16) l1 = ifuse([l0_1, l0_0]) l1_0, l1_1 = isplit(l1, 8) l2 = ifuse([l1_1, l1_0]) iter_map = tvm.arith.detect_iter_map([l2[0]], var_dom([l0])).indices outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in range(len(iter_map))] res = tvm.arith.inverse_affine_iter_map(iter_map, outputs) assert len(res) == 1 l1_inverse = floormod(outputs[0], 8) * 8 + floordiv(outputs[0], 8) l0_inverse = floormod(l1_inverse, 4) * 16 + floordiv(l1_inverse, 4) assert analyzer.can_prove_equal(res[l0[0]], l0_inverse) def test_inverse_affine_map_trivial_iter(): analyzer = tvm.arith.Analyzer() l0 = create_iter("l0", 64) l1 = create_iter("l1", 64) iter_map = tvm.arith.detect_iter_map([0, l0[0], l1[0]], var_dom([l0, l1])).indices outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in range(len(iter_map))] res = tvm.arith.inverse_affine_iter_map(iter_map, outputs) # output_0 is expected to be constant and it is not included in the inverse map assert len(res) == 2 assert analyzer.can_prove_equal(res[l0[0]], outputs[1]) assert analyzer.can_prove_equal(res[l1[0]], outputs[2]) def test_free_variables(): x = tvm.tir.Var("x", "int32") y = tvm.tir.Var("y", "int32") z = tvm.tir.Var("z", "int32") # illegal iter if z is within dom assert_iter_sum_failure([z * 19 + y * 3 + x], var_dom([(x, 3), (y, 3), (z, 3)])) # iter is valid if z is free, even there are linear forms of z assert_iter_sum_pattern( {z * 19 + y * 3 + x: (9, z * 19)}, var_dom( [ (x, 3), (y, 3), ] ), ) assert_iter_sum_pattern( {z * z + y * 3 + x: (9, z * z)}, var_dom( [ (x, 3), (y, 3), ] ), ) class TestPadding: x = tvm.tir.Var("x", "int32") y = tvm.tir.Var("y", "int32") fld = tvm.tir.floordiv flm = tvm.tir.floormod positive_test_case = tvm.testing.parameter( # left padding only, offset divisible ({y: 192}, {fld(64 + y, 32): (6, 2, 1), flm(64 + y, 32): (32, 0, 1)}, "bijective"), # left padding only, offset non-divisible ({y: 176}, {fld(80 + y, 32): (6, 2, 1)}), ({y: 176}, {flm(fld(80 + y, 2), 16): (16, 0, 1), flm(80 + y, 2): (2, 0, 1)}), # right padding only, offset divisible ({x: 5, y: 4}, {fld(x * 32 + y * 8, 16): (10, 0, 1), flm(x * 32 + y * 8, 16): (2, 0, 8)}), # right padding only, offset non-divisible ({x: 26}, {fld(x, 15): (2, 0, 1)}), ({x: 26}, {flm(fld(x, 3), 5): (5, 0, 1), flm(x, 3): (3, 0, 1)}), # padding constants on both side ({x: 45}, {fld(x + 71, 32): (2, 2, 1)}), ({x: 45}, {flm(fld(x, 4), 8): (8, 0, 1), flm(x, 4): (4, 0, 1)}), # padding for free iteration part ({y: 360}, {fld(x * 360 + y, 16): (23, fld(x * 360 - flm(x, 2) * 8, 16), 1)}), ({y: 360}, {flm(x * 360 + y, 16): (16, 0, 1)}), # multiple split with same mark offset, could # be surjective on missing (padded // LCM) ( {x: 240}, { flm(x + 10, 3): (3, 0), flm(fld(x + 10, 3), 4): (4, 0), flm(fld(fld(x + 10, 3), 4), 5): (5, 0), }, ), # different offsets on splits ( {x: 240}, { flm(x + 1, 3): (3, 0), flm(fld(x + 10, 3) + 2, 4): (4, 0), flm(fld(fld(x + 10, 3), 4) + 3, 5): (5, 0), }, ), ) negative_test_case = tvm.testing.parameter( # left padding only, offset non-divisible ({y: 176}, {fld(80 + y, 32), flm(80 + y, 32)}), ({y: 176}, {fld(80 + y, 32), fld(80 + y, 4)}), # right padding only, offset divisible ({x: 5, y: 4}, {fld(x * 32 + y * 8, 5)}), # multiple split with same mark offset, could # be surjective on missing (padded // LCM) ( {x: 240}, { flm(x + 10, 3), flm(fld(x + 10, 3), 4), flm(fld(fld(x + 10, 3), 4), 5), fld(fld(fld(x + 10, 3), 4), 5), }, ), # original extent is smaller than the divident # it is not surjective wrt to the region [0, 16) ({x: 3}, {flm(x, 16)}), # (x % c1) // c2 is not proved as surjective if c1 % c2 != 0 ({x: 255}, {fld(flm(x, 255), 16)}), ) def test_padding(self, positive_test_case): iter_extent, mapped_iterators, *args = positive_test_case check_level = args[0] if args else "surjective" dom_map = {var: tvm.ir.Range(0, ext) for var, ext in iter_extent.items()} assert_iter_sum_pattern(mapped_iterators, dom_map, check_level=check_level) def test_padding_error(self, negative_test_case): iter_extent, mapped_iterators, *args = negative_test_case check_level = args[0] if args else "surjective" dom_map = {var: tvm.ir.Range(0, ext) for var, ext in iter_extent.items()} assert_iter_sum_failure(mapped_iterators, dom_map, check_level=check_level) def test_overlapped_fuse(): x = tvm.tir.Var("x", "int32") y = tvm.tir.Var("y", "int32") z = tvm.tir.Var("z", "int32") a = tvm.tir.Var("x", "int32") b = tvm.tir.Var("y", "int32") # non-bijective fuse of two assert_iter_sum_pattern( { x * 7 + y: (22, 0, 1), }, var_dom([(x, 3), (y, 8)]), check_level="surjective", ) assert_iter_sum_failure([x * 7 + y], var_dom([(x, 3), (y, 8)]), check_level="bijective") # non-bijective fuse of three assert_iter_sum_pattern( { x * 18 + y * 7 + z: (40, 0, 1), }, var_dom([(x, 2), (y, 3), (z, 8)]), check_level="surjective", ) assert_iter_sum_failure([x * 7 + y], var_dom([(x, 2), (y, 3), (z, 8)]), check_level="bijective") # negative scale fusion is not allowed assert_iter_sum_failure([x * -7 + y], var_dom([(x, 3), (y, 8)]), check_level="surjective") assert_iter_sum_failure([x * 7 - y], var_dom([(x, 3), (y, 8)]), check_level="surjective") # with predicate assert_iter_sum_pattern( { a * 40 + b * 20 + x * 18 + y * 3 + z: (125, 6, 1), }, var_dom([(a, 3), (b, 2), (x, 2), (y, 6), (z, 8)]), predicate=tvm.tir.all(z < 4, x * 6 + y > 1, x * 6 + y < 10), check_level="surjective", ) # stride=1 kernel assert_iter_sum_pattern({x + a: (230, 0, 1)}, var_dom([(x, 224), (a, 7)]), check_level="surjective") # do not allow both strided and overlapped assert_iter_sum_failure([5 * x + 2 * y], var_dom([(x, 4), (y, 3)]), check_level="surjective") def test_iter_map_simplify_symbolic_case(): """Test itermap simplify""" x = tvm.tir.Var("x", "int64") y = tvm.tir.Var("y", "int64") z = x * 32 + y n = tvm.tir.SizeVar("n", "int64") def simple_fuse0(x): return (x // n) * n + x % n assert_iter_map_simplify({simple_fuse0(x): x}, var_dom([(x, n * 32)])) assert_iter_map_simplify({simple_fuse0(z): z}, var_dom([(x, n), (y, 32)])) def fsymbolic_fuse0(x): return ((x // (n * n)) % 32) * (n * n) + ((x // n) % n) * n + x % n assert_iter_map_simplify({fsymbolic_fuse0(x): x}, var_dom([(x, n * n * 32)])) assert_iter_map_simplify({fsymbolic_fuse0(z): z}, var_dom([(x, n * n), (y, 32)])) def fsymbolic_fuse1(x): return ((x % (n * n * 32)) // (n * n) * n + (x % (n * n) // n)) * n + x % n assert_iter_map_simplify({fsymbolic_fuse1(x): x}, var_dom([(x, n * n * 32)])) assert_iter_map_simplify({fsymbolic_fuse1(z): z}, var_dom([(x, n * n), (y, 32)])) def fsymbolic_fuse2(i): return (i // (n * n) * n + i % (n * n) // n) * n + i % n assert_iter_map_simplify({fsymbolic_fuse2(x): x}, var_dom([(x, n * n * 32)])) def test_iter_map_simplify_symbolic_predicate(): """Test itermap simplify""" x = tvm.tir.Var("x", "int64") y = tvm.tir.Var("y", "int64") n = tvm.tir.SizeVar("n", "int64") def simple_fuse0(x): return (x // n) * n + x % n z = x * 32 + y assert_iter_map_simplify({simple_fuse0(z): z}, var_dom([(x, (n + 1) // 2), (y, 32)]), predicate=(z < n * 16)) def fsymbolic_fuse2(i): return (i // (n * n) * n + i % (n * n) // n) * n + i % n z = x * 64 + y assert_iter_map_simplify( {fsymbolic_fuse2(z): z}, var_dom([(x, (n * n + 1) // 2), (y, 64)]), predicate=(z < n * n * 32), ) def test_iter_map_simplify_symbolic_reshape(): n = tvm.tir.Var("n", "int64") fused = tvm.tir.Var("fused", "int64") ax0 = (fused // 4096) // n ax1 = (fused // 4096) % n ax2 = fused % 4096 rhs_index = ((ax2 // 4096 + ax0 * n + ax1) % n) * 4096 + ax2 % 4096 assert_iter_map_simplify({rhs_index: fused}, var_dom([(fused, n * 4096)])) def test_iter_map_simplify_unit_loop_order(): """Test itermap simplify""" x = tvm.tir.Var("x", "int64") y = tvm.tir.Var("y", "int64") z = tvm.tir.Var("z", "int64") # trivial iterators can be found at any when comparing via scale # ensure order unchange assert_iter_map_simplify({x + y + z: x + y + z}, var_dom([(x, 1), (y, 1), (z, 1)]), simplify_trivial_iterators=False) # Even with simplification, it should follow the original order assert_iter_map_simplify( {x + y + (z // 4) * 4 + z % 4: z + x + y}, var_dom([(x, 1), (y, 1), (z, 32)]), simplify_trivial_iterators=False, ) assert_iter_map_simplify( {y + 64 - x % 2 * 64: y + 64 - x % 2 * 64}, var_dom([(x, 6), (y, 64)]), simplify_trivial_iterators=False, ) # When we have iterators that have same scale but one of them come # with unit extent, we should prioritize unit extent assert_iter_map_simplify( {x // 128 + y + z: y + z}, var_dom([(x, 128), (y, 128), (z, 1)]), simplify_trivial_iterators=False, ) def assert_normalize_to_iter_sum(index, input_iters, args, base): """Assert the result of arith.normalize_to_iter_sum is correct Parameters ---------- index : tvm.tir.PrimExpr The index to be normalized input_iters : Mapping[Var, Range] The input iterators args : List[Union[tvm.arith.IterSplitExpr, Tuple[PrimExpr, PrimExpr]]] The expected result. Ordered list of args of the expected IterSumExpr. Each arg can be either IterSplitExpr or a tuple of (PrimExpr, PrimExpr) where the first element is the iterator normalized to PrimExpr and the second element is the scale. base : tvm.tir.PrimExpr The expected base """ res = tvm.arith.normalize_to_iter_sum(index, input_iters) assert isinstance(res, tvm.arith.IterSumExpr) assert len(res.args) == len(args) for split, item in zip(res.args, args): if isinstance(item, tvm.arith.IterSplitExpr): tvm.ir.assert_structural_equal(split, item) continue tvm.testing.assert_prim_expr_equal(split.scale, item[1]) tvm.testing.assert_prim_expr_equal(tvm.arith.normalize_iter_map_to_expr(split), item[0] * item[1]) tvm.testing.assert_prim_expr_equal(res.base, base) def test_normalize_to_iter_sum(): x = tvm.tir.Var("x", "int64") y = tvm.tir.Var("y", "int64") z = tvm.tir.Var("z", "int64") a = tvm.tir.Var("a", "int64") n = tvm.tir.Var("n", "int64") # flm = tvm.tir.floormod assert_normalize_to_iter_sum( z + ((y + x * 4 + 2) * n) + 3, var_dom([(x, 9), (y, 4), (z, 3)]), [(x, n * 4), (y, n), (z, 1)], 2 * n + 3, ) # max cannot detected so it goes into base assert_normalize_to_iter_sum( tvm.tir.max(z, a) + ((y + x * 4 + 2) * n) + 3, var_dom([(x, 9), (y, 4), (z, 3)]), [(x, n * 4), (y, n)], tvm.tir.max(z, a) + 2 * n + 3, ) # order by symbolic prod assert_normalize_to_iter_sum( z + ((y * 4 * a + x * 4 + 2) * n) + 3, var_dom([(y, a * n * 4), (x, n * 4), (z, a)]), [(y, a * n * 4), (x, n * 4), (z, 1)], 2 * n + 3, ) # order by cscale assert_normalize_to_iter_sum( z + 2 * y * 3 + 4 * x, var_dom([(y, a * n * 4), (x, n * 4), (z, a)]), [(y, 6), (x, 4), (z, 1)], 0, ) # split pattern assert_normalize_to_iter_sum( z + 2 * y * 3 + 4 * (x // 2), var_dom([(y, a * n * 4), (x, n * 4), (z, a)]), [(y, 6), (x // 2, 4), (z, 1)], 0, ) # non-divisible assert_normalize_to_iter_sum( x // 5, var_dom([(x, 4096)]), [ tvm.arith.IterSplitExpr( tvm.arith.IterMark(x, 4096), lower_factor=tvm.tir.const(5, "int64"), extent=tvm.tir.const(820, "int64"), scale=tvm.tir.const(1, "int64"), ) ], 0, ) # iter simplify assert_normalize_to_iter_sum( z * 2 + 2 * y * 3 + 4 * (x // 4) + (x % 4), var_dom([(y, a * n * 4), (x, n * 4), (z, a)]), [(y, 6), (z, 2), (x, 1)], 0, ) def test_detect_iter_map_with_bufferload_recursion(): n = tvm.tir.Var("n", "int32") m = tvm.tir.Var("m", "int32") divisor = tvm.tir.Var("divisor", "int32") i = tvm.tir.Var("i", "int32") j = tvm.tir.Var("j", "int32") buffer = tvm.tir.decl_buffer((n,), "int32", name="seqlen") indices = [(buffer[i] + j) // divisor] iter_vars = { i: tvm.ir.Range(tvm.tir.const(0, "int32"), n), j: tvm.ir.Range(tvm.tir.const(0, "int32"), m), } result = tvm.arith.detect_iter_map(indices, iter_vars) assert len(result.indices) == 0 if __name__ == "__main__": tilelang.testing.main()