Unverified Commit 9c21586b authored by Kuris's avatar Kuris Committed by GitHub
Browse files

[Feat] Integrate Z3 in TVM Arith Analyzer (#1367)

parent 899f7bd5
Subproject commit 2b1ead1a375704c75af563cc800aa9347583ba2b
Subproject commit 4d3ec9253e346b2281513700e692124aefaff347
......@@ -222,6 +222,14 @@ elseif(USE_CUDA)
list(APPEND TILE_LANG_INCLUDES ${CUDAToolkit_INCLUDE_DIRS})
endif()
set(USE_Z3 ON CACHE STRING "Use Z3 SMT solver for TileLang optimizations")
set(USE_PYPI_Z3 ON CACHE BOOL "Use Z3 provided by PyPI z3-solver package")
if(USE_Z3 AND USE_PYPI_Z3)
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}/cmake/pypi-z3")
find_package(Z3 REQUIRED)
endif()
# Include tvm after configs have been populated
add_subdirectory(${TVM_SOURCE} tvm EXCLUDE_FROM_ALL)
......@@ -288,19 +296,32 @@ install(TARGETS tilelang_cython_wrapper
RUNTIME DESTINATION tilelang/lib
ARCHIVE DESTINATION tilelang/lib)
# let libtilelang to search tvm/tvm_runtime in same dir
# add python z3 lib path to rpath for running tests and dev in current folder
if(USE_Z3 AND USE_PYPI_Z3)
set_property(TARGET tvm APPEND PROPERTY BUILD_RPATH ${Python3_SITELIB}/z3/lib)
set_property(TARGET tvm APPEND PROPERTY BUILD_RPATH ${Python3_SITELIB}/z3/bin)
endif()
if(APPLE)
set_target_properties(tilelang PROPERTIES INSTALL_RPATH "@loader_path;@loader_path/../../tvm_ffi/lib")
set_target_properties(tilelang_module PROPERTIES INSTALL_RPATH "@loader_path;@loader_path/../../tvm_ffi/lib")
set_target_properties(tvm PROPERTIES INSTALL_RPATH "@loader_path;@loader_path/../../tvm_ffi/lib")
set_target_properties(tvm_runtime PROPERTIES INSTALL_RPATH "@loader_path;@loader_path/../../tvm_ffi/lib")
set(TILELANG_INSTALL_RPATH "@loader_path;@loader_path/../../tvm_ffi/lib")
if(USE_Z3 AND USE_PYPI_Z3)
# some z3 is placed in lib/ and some in bin/, we add both in rpath
list(APPEND TILELANG_INSTALL_RPATH "@loader_path/../../z3/lib" "@loader_path/../../z3/bin")
endif()
elseif(UNIX)
set_target_properties(tilelang PROPERTIES INSTALL_RPATH "\$ORIGIN:\$ORIGIN/../../tvm_ffi/lib")
set_target_properties(tilelang_module PROPERTIES INSTALL_RPATH "\$ORIGIN:\$ORIGIN/../../tvm_ffi/lib")
set_target_properties(tvm PROPERTIES INSTALL_RPATH "\$ORIGIN:\$ORIGIN/../../tvm_ffi/lib")
set_target_properties(tvm_runtime PROPERTIES INSTALL_RPATH "\$ORIGIN:\$ORIGIN/../../tvm_ffi/lib")
set(TILELANG_INSTALL_RPATH "\$ORIGIN:\$ORIGIN/../../tvm_ffi/lib")
if(USE_Z3 AND USE_PYPI_Z3)
# cmake uses ; by default, we explicitly use : for linux
string(APPEND TILELANG_INSTALL_RPATH ":\$ORIGIN/../../z3/lib")
endif()
endif()
# let libtilelang to search tvm/tvm_runtime in same dir
set_target_properties(tilelang PROPERTIES INSTALL_RPATH "${TILELANG_INSTALL_RPATH}")
set_target_properties(tilelang_module PROPERTIES INSTALL_RPATH "${TILELANG_INSTALL_RPATH}")
set_target_properties(tvm PROPERTIES INSTALL_RPATH "${TILELANG_INSTALL_RPATH}")
set_target_properties(tvm_runtime PROPERTIES INSTALL_RPATH "${TILELANG_INSTALL_RPATH}")
install(
TARGETS tvm tvm_runtime tilelang_module tilelang
LIBRARY DESTINATION tilelang/lib
......
if(Z3_FOUND)
return()
endif()
find_package(Python3 COMPONENTS Interpreter REQUIRED)
execute_process(
COMMAND "${Python3_EXECUTABLE}" -c "import z3; print(z3.__path__[0])"
OUTPUT_VARIABLE Z3_PATH
OUTPUT_STRIP_TRAILING_WHITESPACE
RESULT_VARIABLE Z3_PYTHON_RESULT
)
if(NOT Z3_PYTHON_RESULT EQUAL 0 OR Z3_PATH STREQUAL "")
message(FATAL_ERROR "Failed to locate z3 Python package. Ensure z3-solver>=4.13.0 is installed.")
endif()
message("-- Find Z3 in path: ${Z3_PATH}")
find_path(Z3_INCLUDE_DIR NO_DEFAULT_PATH NAMES z3++.h PATHS ${Z3_PATH}/include)
find_library(Z3_LIBRARY NO_DEFAULT_PATH NAMES z3 libz3 PATHS ${Z3_PATH}/bin ${Z3_PATH}/lib ${Z3_PATH}/lib64)
message("-- Found Z3 include dir: ${Z3_INCLUDE_DIR}")
message("-- Found Z3 library: ${Z3_LIBRARY}")
add_library(z3::libz3 SHARED IMPORTED GLOBAL)
set_target_properties(z3::libz3
PROPERTIES
IMPORTED_LOCATION ${Z3_LIBRARY}
INTERFACE_INCLUDE_DIRECTORIES ${Z3_INCLUDE_DIR}
)
if(NOT Z3_INCLUDE_DIR OR NOT Z3_LIBRARY)
message(FATAL_ERROR "Could not find Z3 library or include directory")
endif()
set(Z3_CXX_INCLUDE_DIRS ${Z3_INCLUDE_DIR})
set(Z3_C_INCLUDE_DIRS ${Z3_INCLUDE_DIR})
set(Z3_FOUND TRUE)
......@@ -43,6 +43,7 @@ dependencies = [
"torch>=2.7; platform_system == 'Darwin'",
"tqdm>=4.62.3",
"typing-extensions>=4.10.0",
"z3-solver>=4.13.0",
]
[project.optional-dependencies]
......@@ -53,7 +54,14 @@ fp4 = ["ml-dtypes>=0.5.1"]
vis = ["matplotlib"]
[build-system]
requires = ["cython>=3.0.0", "scikit-build-core"]
requires = [
"cython>=3.0.0",
"scikit-build-core",
"z3-solver>=4.13.0",
# Not for auditwheel, explicitly add patchelf for repairing libz3.so.
# See tvm's CMakeLists.txt for more information.
"patchelf>=0.17.2; platform_system == 'Linux'",
]
build-backend = "scikit_build_core.build"
[tool.scikit-build]
......@@ -227,7 +235,7 @@ environment.PYTHONUNBUFFERED = "1"
environment.PATH = "/usr/local/cuda/bin:$PATH"
environment.LD_LIBRARY_PATH = "/usr/local/cuda/lib64:/usr/local/cuda/lib64/stubs:$LD_LIBRARY_PATH"
manylinux-x86_64-image = "manylinux_2_28" # AlmaLinux 8
manylinux-aarch64-image = "manylinux_2_28" # AlmaLinux 8
manylinux-aarch64-image = "manylinux_2_34" # Z3 requires
# Install CUDA runtime and stub driver library
# manylinux_2_28 uses gcc 14, which needs CUDA >=12.8
before-all = """
......@@ -256,7 +264,7 @@ yum install -y "cuda-minimal-build-${v}" "cuda-driver-devel-${v}" "cuda-nvrtc-de
yum clean all
"""
repair-wheel-command = [
"auditwheel -v repair --exclude libtvm_ffi.so --exclude libcuda.so.1 --exclude '/usr/local/cuda*' -w {dest_dir} {wheel}",
"auditwheel -v repair --exclude libtvm_ffi.so --exclude libz3.so --exclude libcuda.so.1 --exclude '/usr/local/cuda*' -w {dest_dir} {wheel}",
"pipx run abi3audit --verbose --strict {wheel}",
]
......
......@@ -10,6 +10,7 @@ scikit-build-core
setuptools>=61
torch
wheel
z3-solver>=4.13.0
auditwheel; platform_system == 'Linux'
patchelf; platform_system == 'Linux'
......
......@@ -30,3 +30,4 @@ scipy
tabulate
tornado
wheel
z3-solver>=4.13.0
\ No newline at end of file
......@@ -9,3 +9,4 @@ torch
torch>=2.7; platform_system == 'Darwin'
tqdm>=4.62.3
typing-extensions>=4.10.0
z3-solver>=4.13.0
\ No newline at end of file
import tilelang.testing
import tilelang.language as T
from tvm.arith import Analyzer
from tvm.ir.expr import Range
from tvm.tir.expr import Not, Or
def implies(x, y):
return Or(Not(x), y)
def test_hard_prove():
a = T.Var("a", T.int32)
b = T.Var("b", T.int32)
c = T.Var("c", T.int32)
d = T.Var("d", T.int32)
def check_expr(expr):
analyzer = Analyzer()
result = analyzer.can_prove(expr, 1)
if not result:
smtlib2 = analyzer.get_smtlib2(expr)
raise AssertionError(f"Failed to prove: {expr}\nSMT-LIB2:\n{smtlib2}")
# assert result, f"Failed to prove: {expr}"
@T.macro
def complex_expr_1():
return implies(a > 0 and b > 0 and c > 0, ((b - a) // c) * c + a <= b)
check_expr(complex_expr_1())
@T.macro
def complex_expr_2():
return implies(a < b and b < c and a * d < b * d, b * d < c * d)
check_expr(complex_expr_2())
@T.macro
def complex_expr_3():
return implies(a >= 0 and a < 128, a // 128 == (a // 64 * 32 + a % 32 // 16 * 8) // 64)
check_expr(complex_expr_3())
@T.macro
def complex_expr_4():
return implies(
a >= 0 and a < 128,
(a % 16 * 64 + a // 64 * 32 + a % 8 // 4 * 32 + (a % 32 // 16 + a % 2) % 2 * 8 + 16 - (a // 64 + a % 8 // 4) // 2 * 64) // 512
== (a % 16 * 64 + a // 64 * 32 + a % 8 // 4 * 32 + (a % 32 // 16 + a % 2) % 2 * 8 - (a // 64 + a % 8 // 4) // 2 * 64) // 512,
)
check_expr(complex_expr_4())
def test_smtlib2():
import z3
a = T.Var("a", T.int32)
b = T.Var("b", T.int32)
c = T.Var("c", T.int32)
@T.macro
def complex_expr_1():
return implies(a > 0 and b > 0 and c > 0, ((b - a) // c) * c + a <= b)
e = complex_expr_1()
analyzer = Analyzer()
analyzer.set_z3_timeout_ms(1000)
smtlib2 = analyzer.get_smtlib2(e)
solver = z3.Solver()
solver.from_string(smtlib2)
assert solver.check() == z3.unsat, f"Expected unsat, got {solver.check()}"
def test_bind():
a = T.Var("a", T.int32)
b = T.Var("b", T.int32)
c = T.Var("c", T.int32)
analyzer = Analyzer()
analyzer.bind(a, Range(1, 100000))
analyzer.bind(b, Range(1, 100000))
analyzer.bind(c, Range(1, 100000))
expr = ((b - a) // c) * c + a <= b
smtlib2 = analyzer.get_smtlib2(expr)
try:
result = analyzer.can_prove(expr, 1)
assert result, f"Failed to prove with bindings: {expr}"
except Exception as e:
print(smtlib2)
raise e
if __name__ == "__main__":
tilelang.testing.main()
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from tilelang import tvm
import tvm.testing
from tvm import te
from tvm import tir
from tvm.arith.analyzer import Analyzer
class IntSetChecker:
def __init__(self):
self.analyzer = tvm.arith.Analyzer()
def verify(self, data, dmap, expected):
res = self.analyzer.int_set(data, dmap)
def err_msg():
return "\ndata={}\ndmap={}\nres={}\nexpected={}".format(data, dmap, res, expected)
assert self.analyzer.can_prove_equal(res.min_value, expected[0]), err_msg()
assert self.analyzer.can_prove_equal(res.max_value, expected[1]), err_msg()
def test_basic():
s = tvm.arith.IntervalSet(2, 3)
assert s.min_value.value == 2
assert s.max_value.value == 3
s = tvm.arith.IntSet.single_point(2)
assert s.min_value.value == 2
assert s.max_value.value == 2
def test_vector():
base = 10
stride = 3
lanes = 2
s = tvm.arith.IntSet.vector(tvm.tir.Ramp(base, stride, lanes))
assert s.min_value.value == base
assert s.max_value.value == base + stride * (lanes - 1)
def test_scalable_vector():
base = 5
s = tvm.arith.IntSet.vector(tvm.tir.Ramp(base, 2, tvm.tir.vscale() * 4))
assert s.min_value.value == base
assert s.max_value.same_as(tvm.arith.int_set.pos_inf())
def test_add_sub():
ck = IntSetChecker()
x, y = te.var("x"), te.var("y")
ck.verify(x + y, {x: tvm.arith.IntervalSet(0, 10)}, (y, 10 + y))
ck.verify(x + y, {x: tvm.arith.IntervalSet(0, 10), y: tvm.arith.IntervalSet(1, 11)}, (1, 21))
ck.verify(x - y, {x: tvm.arith.IntervalSet(0, 10), y: tvm.arith.IntervalSet(1, 11)}, (-11, 9))
def test_mul_div():
ck = IntSetChecker()
x, y = te.var("x"), te.var("y")
tdiv = tvm.tir.truncdiv
ck.analyzer.update(y, tvm.arith.ConstIntBound(1, 100), override=True)
ck.verify(x * y, {x: tvm.arith.IntervalSet(0, 10)}, (0, 10 * y))
ck.verify(x * 2, {x: tvm.arith.IntervalSet(1, 10)}, (2, 20))
ck.verify(x * -2, {x: tvm.arith.IntervalSet(1, 10)}, (-20, -2))
ck.verify(tdiv(x, y), {x: tvm.arith.IntervalSet(0, 10)}, (0, tdiv(10, y)))
ck.verify(tdiv(x, 2), {x: tvm.arith.IntervalSet(1, 10)}, (0, 5))
fld = tvm.te.floordiv
ck.verify(fld(x, y), {x: tvm.arith.IntervalSet(0, 10)}, (0, fld(10, y)))
ck.verify(fld(x, 2), {x: tvm.arith.IntervalSet(-1, 10)}, (-1, 5))
def test_mod():
ck = IntSetChecker()
x, y = te.var("x"), te.var("y")
tmod = tvm.tir.truncmod
ck.analyzer.update(y, tvm.arith.ConstIntBound(1, 100), override=True)
ck.verify(tmod(x, y), {x: tvm.arith.IntervalSet(0, 10)}, (0, y - 1))
ck.verify(tmod(x, 10), {x: tvm.arith.IntervalSet(1, 10)}, (0, 9))
flm = tvm.te.floormod
ck.verify(flm(x, 10), {x: tvm.arith.IntervalSet(-10, 10)}, (0, 9))
ck.verify(flm(x, 10), {x: tvm.arith.IntervalSet(3, 5)}, (3, 5))
ck.verify(flm(x, 10), {x: tvm.arith.IntervalSet(13, 15)}, (3, 5))
ck.verify(flm(x, 10), {x: tvm.arith.IntervalSet(3, 15)}, (0, 9))
ck.verify(flm(x, 10), {x: tvm.arith.IntervalSet(3, 11)}, (0, 9))
ck.verify(flm(x, 10), {x: tvm.arith.IntervalSet(1, 21)}, (0, 9))
fld = tvm.te.floordiv
z = te.var("z")
ck.analyzer.bind(x, tvm.ir.Range.from_min_extent(0, 3))
ck.verify(
flm(y, 8),
{y: tvm.arith.IntervalSet(z * 8 + x * 4, z * 8 + x * 4 + 3)},
(
z * 8 + x * 4 - 8 * fld(z * 8 + x * 4, 8),
z * 8 + x * 4 + 3 - 8 * fld(z * 8 + x * 4, 8),
),
)
ck1 = IntSetChecker()
ck1.analyzer.bind(x, tvm.ir.Range.from_min_extent(0, 2))
ck1.verify(flm(y, 8), {y: tvm.arith.IntervalSet(z * 8 + x * 4, z * 8 + x * 4 + 3)}, (x * 4, x * 4 + 3))
def test_max_min():
ck = IntSetChecker()
x, y = te.var("x"), te.var("y")
ck.verify(tvm.te.max(x, x + 1), {x: tvm.arith.IntervalSet(0, 10)}, (1, 11))
ck.verify(tvm.te.min(x - 1, x + 1), {x: tvm.arith.IntervalSet(0, 10)}, (-1, 9))
ck.verify(tvm.te.min(x, y), {}, (tvm.te.min(x, y), tvm.te.min(x, y)))
ck.verify(tvm.te.max(x, y), {}, (tvm.te.max(x, y), tvm.te.max(x, y)))
def test_select():
ck = IntSetChecker()
# x, y = te.var("x"), te.var("y")
x = te.var("x")
ck.verify(tvm.tir.Select(x > 0, x - 1, x + 1), {x: tvm.arith.IntervalSet(0, 10)}, (-1, 11))
def check_region_bound(expect_region, var_dom, mode, predicate=None):
"""Helper to check region bound estimation.
Parameters
----------
expect_region: dict
The keys are of form (begin, end) or PrimExpr as a single point. The values are
expected estimated region or region dict on different bindings.
var_dom: dict
Map var to iteration domain range.
mode: str
Specify "lowerbound", "upperbound" or else use strict bound estimation.
predicate: PrimExpr
Extra predicate, defaults to True.
"""
if predicate is None:
predicate = tvm.tir.IntImm("bool", 1)
region = []
expect = []
for k, v in expect_region.items():
if not isinstance(k, (tuple, list)):
k = (k, k + 1)
region.append(tvm.ir.Range.from_min_extent(k[0], Analyzer().simplify(k[1] - k[0])))
expect.append(v)
if mode == "lowerbound":
result = tvm.arith.estimate_region_lower_bound(region=region, var_dom=var_dom, predicate=predicate)
elif mode == "upperbound":
result = tvm.arith.estimate_region_upper_bound(region=region, var_dom=var_dom, predicate=predicate)
else:
result = tvm.arith.estimate_region_strict_bound(region=region, var_dom=var_dom, predicate=predicate)
if result is None:
assert all([_ is None for _ in expect])
return
assert len(result) == len(expect)
for intset, expect_desc in zip(result, expect):
if isinstance(expect_desc, dict):
# check range on different free var bindings
for binding in expect_desc:
analyzer = Analyzer()
for k, v in binding:
analyzer.bind(k, v)
expect_begin, expect_end = expect_desc[binding]
result_begin = analyzer.simplify(intset.min_value, 3)
result_end = analyzer.simplify(intset.max_value + 1, 3)
assert analyzer.can_prove_equal(result_begin - expect_begin, 0), f"{result_begin} vs {expect_begin}"
assert analyzer.can_prove_equal(result_end - expect_end, 0), f"{result_end} vs {expect_end}"
else:
# check range
expect_begin, expect_end = expect_desc
analyzer = Analyzer()
assert analyzer.can_prove_equal(intset.min_value - expect_begin, 0), f"{intset.min_value} vs {expect_begin}"
assert analyzer.can_prove_equal(intset.max_value - expect_end + 1, 0), f"{intset.max_value} vs {expect_end - 1}"
def test_region_bound_not_independent():
# (i, i+2) and (i+2, i+4) are dependent, this the lowerbound is not available
i = tvm.tir.Var("i", "int32")
var_dom = {
i: tvm.ir.Range(begin=0, end=64),
}
check_region_bound({(i, i + 2): None, (i + 2, i + 4): None}, var_dom, mode="lowerbound")
check_region_bound({(i, i + 2): (0, 65), (i + 2, i + 4): (2, 67)}, var_dom, mode="upperbound")
# when only a subset of access indices are affine
i, j, k = tvm.tir.Var("i", "int32"), tvm.tir.Var("j", "int32"), tvm.tir.Var("k", "int32")
var_dom = {
i: tvm.ir.Range(begin=0, end=16),
j: tvm.ir.Range(begin=0, end=16),
k: tvm.ir.Range(begin=0, end=16),
}
check_region_bound(
{i // 4: None, j * 4 + i % 4: None, tir.truncdiv(k, 2): None},
var_dom,
predicate=j * 4 + i % 4 > 3,
mode="lowerbound",
)
check_region_bound(
{i // 4: (0, 4), j * 4 + i % 4: (4, 64), tir.truncdiv(k, 2): (0, 8)},
var_dom,
predicate=j * 4 + i % 4 > 3,
mode="upperbound",
)
def test_region_bound_stride_too_wide():
i = tvm.tir.Var("i", "int32")
var_dom = {i: tvm.ir.Range(begin=0, end=64)}
check_region_bound({(i * 4, i * 4 + 2): None}, var_dom, mode="lowerbound")
check_region_bound({(i * 4, i * 4 + 2): (0, 254)}, var_dom, mode="upperbound")
def test_region_bound_small_stride():
i = tvm.tir.Var("i", "int32")
var_dom = {
i: tvm.ir.Range(begin=0, end=64),
}
check_region_bound({(i * 4, i * 4 + 8): (0, 260)}, var_dom, mode="lowerbound")
def test_region_lower_bound_split_predicate():
x_o = tvm.tir.Var("xo", "int32")
x_i = tvm.tir.Var("xi", "int32")
x = x_o * 4 + x_i
var_dom = {
x_o: tvm.ir.Range(begin=0, end=16),
x_i: tvm.ir.Range(begin=0, end=4),
}
check_region_bound({(x * 4, x * 4 + 8): (0, 256)}, var_dom, predicate=x < 63, mode="lowerbound")
check_region_bound(
{(x * 4, x * 4 + 8): (0, 256), (x * 3, x * 3 + 5): (0, 191)},
var_dom,
predicate=x < 63,
mode="upperbound",
)
def test_region_lower_bound_multiple_variables():
div = tvm.tir.floordiv
mod = tvm.tir.floormod
x = tvm.tir.Var("x", "int32")
wid = tvm.tir.Var("wid", "int32")
i = div(x, 16)
j = div(mod(x, 16), 4) * 8 + mod(x, 4) + div(wid, 32) * 4
k = wid % 32
var_dom = {
x: tvm.ir.Range(begin=0, end=32),
wid: tvm.ir.Range(begin=0, end=64),
}
check_region_bound({i: (0, 2), j: (0, 32), k: (0, 32)}, var_dom, mode="lowerbound")
def test_region_lower_bound_negative_scale():
i = tvm.tir.Var("i", "int32")
j = tvm.tir.Var("j", "int32")
var_dom = {
i: tvm.ir.Range(begin=0, end=4),
j: tvm.ir.Range(begin=0, end=4),
}
check_region_bound({(1 - i, 5 - i): (-2, 5), (20 - j * 4, 36 - j * 4): (8, 36)}, var_dom, mode="lowerbound")
def test_region_lower_bound_for_non_perfect_tile():
h1 = tvm.tir.Var("h1", "int32")
h2 = tvm.tir.Var("h2", "int32")
h3 = tvm.tir.Var("h3", "int32")
# non-uniform tiling, single inner variable
var_dom = {
h2: tvm.ir.Range(begin=0, end=10),
}
check_region_bound(
{
h3 * 8 + h2: {
(): (
tvm.tir.max(h3 * 8, 1),
tvm.tir.min(0, h3 * 8 - 214) + 224,
),
((h3, 0),): (1, 10), # h3 == 0: region is [1, 10)
((h3, 10),): (h3 * 8, h3 * 8 + 10), # 0 < h3 <= 26: region is [h3 * 8, h3 * 8 + 10)
((h3, 27),): (h3 * 8, 224), # h3 > 26: region is [h3 * 8, 224)
}
},
var_dom,
predicate=tvm.tir.all(h3 * 8 + h2 >= 1, h3 * 8 + h2 < 224),
mode="lowerbound",
)
# non-uniform tiling, two inner variables
var_dom = {
h1: tvm.ir.Range(begin=0, end=5),
h2: tvm.ir.Range(begin=0, end=2),
}
check_region_bound(
{
h3 * 8 + h2 * 5 + h1: {
(): (
tvm.tir.max(h3 * 8, 1),
tvm.tir.min(0, h3 * 8 - 214) + 224,
),
((h3, 0),): (1, 10),
((h3, 10),): (h3 * 8, h3 * 8 + 10),
((h3, 27),): (h3 * 8, 224),
}
},
var_dom,
predicate=tvm.tir.all(h3 * 8 + h2 * 5 + h1 >= 1, h3 * 8 + h2 * 5 + h1 < 224),
mode="lowerbound",
)
# lowerbound should fail on incompatible predicates
check_region_bound(
{h3 * 8 + h2 * 5 + h1: None},
var_dom,
predicate=tvm.tir.all(h3 * 8 + h2 * 5 + h1 >= 1, h3 * 8 + h1 * 2 + h2 < 224),
mode="lowerbound",
)
check_region_bound(
{h3 * 8 + h2 * 5 + h1: (h3 * 8, h3 * 8 + 10)},
var_dom,
predicate=tvm.tir.all(h3 * 8 + h2 * 5 + h1 >= 1, h3 * 8 + h1 * 2 + h2 < 224),
mode="upperbound",
)
def test_region_lower_bound_unfusable():
var_dom = {
tvm.tir.Var("i", "int32"): tvm.ir.Range(8),
tvm.tir.Var("j", "int32"): tvm.ir.Range(4),
}
i, j = var_dom
check_region_bound({(i + j) // 2: (0, 6)}, var_dom, mode="lowerbound")
def test_union_lower_bound():
neg_inf = tvm.arith.int_set.neg_inf()
pos_inf = tvm.arith.int_set.pos_inf()
set_0 = tvm.arith.IntervalSet(min_value=neg_inf, max_value=0)
set_1 = tvm.arith.IntervalSet(min_value=1, max_value=pos_inf)
result = tvm.arith.int_set.union_lower_bound([set_0, set_1])
assert result.min_value.same_as(neg_inf)
assert result.max_value.same_as(pos_inf)
set_2 = tvm.arith.IntervalSet(min_value=pos_inf, max_value=neg_inf)
result = tvm.arith.int_set.union_lower_bound([set_0, set_1, set_2])
assert result.min_value.same_as(neg_inf)
assert result.max_value.same_as(pos_inf)
def test_modular_set():
ck = IntSetChecker()
x = tvm.te.var("x", dtype="int32")
y = tvm.te.var("y", dtype="int32")
expr = (x * 2048 + y * 16) % 7168
ck.verify(expr, {x: tvm.arith.IntervalSet(0, 128), y: tvm.arith.IntervalSet(0, 3584)}, (0, 7152))
if __name__ == "__main__":
tvm.testing.main()
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from tilelang import tvm
import tilelang.testing
from tvm.tir import floordiv, floormod
from tvm.script import tir as T
def ifuse(inputs, pred_extent=None):
"""Fuse iterators"""
value, extent = 0, 1
for i, ext in inputs:
value = value * ext + i
extent = extent * ext
return value, extent if pred_extent is None else pred_extent
def isplit(axis, factor):
"""Split iterators"""
fld = tvm.tir.floordiv
flm = tvm.tir.floormod
return [
(fld(axis[0], factor), fld(axis[1] + (factor - 1), factor)),
(flm(axis[0], factor), factor),
]
def var_dom(iters):
"""Get domains of iterators"""
return {var: tvm.ir.Range(0, ext) for var, ext in iters}
def convert_iter_expr(expr):
return tvm.arith.normalize_iter_map_to_expr(expr)
def assert_iter_sum_pattern(expect_dict, dom_map, predicate=True, check_level="surjective", simplify_trivial_iterators=True):
keys = list(expect_dict.keys())
res = tvm.arith.detect_iter_map(
keys,
dom_map,
predicate=predicate,
check_level=check_level,
simplify_trivial_iterators=simplify_trivial_iterators,
)
indices = res.indices
assert len(indices) == len(keys), res.errors
for i, input_iter in enumerate(keys):
spec = expect_dict[input_iter]
(
extent,
base,
) = spec[0:2]
scale = spec[2] if len(spec) > 2 else 1
expect_iter = spec[3] if len(spec) > 3 else None
sum_expr = indices[i]
assert isinstance(sum_expr, tvm.arith.IterSumExpr)
if extent == 1:
assert len(sum_expr.args) == 0
else:
assert len(sum_expr.args) == 1
tvm.testing.assert_prim_expr_equal(sum_expr.args[0].extent, extent)
tvm.testing.assert_prim_expr_equal(sum_expr.args[0].scale, scale)
tvm.testing.assert_prim_expr_equal(sum_expr.base, base)
if expect_iter is not None:
if not isinstance(expect_iter, tvm.arith.IterMapExpr):
sum_expr = convert_iter_expr(sum_expr)
tvm.ir.assert_structural_equal(sum_expr, expect_iter)
def assert_iter_map_simplify(expect_dict, dom_map, predicate=True, check_level="surjective", simplify_trivial_iterators=True):
keys = list(expect_dict.keys())
_imap = tvm.arith.detect_iter_map(
keys,
dom_map,
predicate=predicate,
check_level=check_level,
simplify_trivial_iterators=simplify_trivial_iterators,
)
res = tvm.arith.iter_map_simplify(
keys,
dom_map,
predicate=predicate,
check_level=check_level,
simplify_trivial_iterators=simplify_trivial_iterators,
)
for i, input_expr in enumerate(keys):
expected_expr = expect_dict[input_expr]
tvm.ir.assert_structural_equal(res[i], expected_expr)
def assert_iter_sum_failure(iters, dom_map, predicate=True, check_level="surjective"):
res = tvm.arith.detect_iter_map(list(iters), dom_map, predicate=predicate, check_level=check_level).indices
assert len(res) == 0
def test_trivial():
x = tvm.tir.Var("x", "int32")
y = tvm.tir.Var("y", "int32")
z = tvm.tir.Var("z", "int32")
dom_map = var_dom([(x, 3), (y, 4), (z, 1)])
assert_iter_sum_pattern({x: (3, 0), y: (4, 0), 3: (1, 3)}, dom_map)
assert_iter_sum_pattern({x: (3, 0), 3: (1, 3)}, dom_map)
# not independent
assert_iter_sum_failure([x, x, 3], dom_map)
assert_iter_sum_pattern({x: (3, 0), y: (4, 0)}, dom_map, check_level="bijective", simplify_trivial_iterators=True)
assert_iter_sum_pattern({x: (3, 0), y: (4, 0)}, dom_map, check_level="bijective", simplify_trivial_iterators=False)
assert_iter_sum_failure([x, z], dom_map, check_level="bijective")
def test_fuse():
x = tvm.tir.Var("x", "int32")
y = tvm.tir.Var("y", "int32")
c = tvm.tir.SizeVar("c", "int32")
c0 = tvm.tir.SizeVar("c0", "int32")
assert_iter_sum_pattern({y * 3 + 1 + c + x: (12, 1 + c)}, var_dom([(x, 3), (y, 4)]))
assert_iter_sum_pattern({ifuse([(x, 3), (y, 4)])[0]: (12, 0)}, var_dom([(x, 3), (y, 4)]))
# fuse with symbolic factor
assert_iter_sum_pattern({(y + 1) * c + x: (4 * c, c)}, var_dom([(x, c), (y, 4)]))
# duplication
assert_iter_sum_failure([y * 3 + x, y], var_dom([(x, 3), (y, 4)]))
assert_iter_sum_failure([y, x + 1, y], var_dom([(x, 3), (y, 4)]))
# factor mismatch
assert_iter_sum_failure([y * 4 + x], var_dom([(x, 3), (y, 4)]))
# simple stride pattern
assert_iter_sum_pattern({x * 4 + y * 2: (6, 0, 2, (x * 2 + y) * 2)}, var_dom([(x, 3), (y, 2)]))
# simple stride pattern with symbolic
assert_iter_sum_pattern({x * 2 * c0 + y * 2: (3 * c0, 0, 2, (x * c0 + y) * 2)}, var_dom([(x, 3), (y, c0)]))
def test_split():
x = tvm.tir.Var("x", "int32")
y = tvm.tir.Var("y", "int32")
c0 = tvm.tir.SizeVar("c0", "int32")
c1 = tvm.tir.SizeVar("c1", "int32")
fld = tvm.tir.floordiv
flm = tvm.tir.floormod
assert_iter_sum_pattern({fld(x, 3): (8, 0), flm(x, 3) * 2 + c1: (3, c1, 2)}, var_dom([(x, 24)]))
assert_iter_sum_pattern({fld(x, 6): (4, 0), fld(flm(x, 6), 2): (3, 0), flm(x, 2): (2, 0)}, var_dom([(x, 24)]))
# simple symbolic bound
# TODO(tvm-team) improve symbolic divisible check to enable
# more complicated symbolic bound
assert_iter_sum_pattern({fld(x, c0): (c1, 0), flm(x, c0): (c0, 0)}, var_dom([(x, c1 * c0)]))
assert_iter_sum_pattern({fld(x * 2, 4): (4, 0, 1), flm(x * 2, 4): (2, 0, 2)}, var_dom([(x, 8)]))
assert_iter_sum_pattern(
{
fld(x * 2, 4) * 4 + flm(x * 2, 4): (8, 0, 2),
},
var_dom([(x, 8)]),
)
assert_iter_sum_failure([fld(x, flm(flm(y, 8), 6))], var_dom([(x, 24), (y, 8)]))
# domain of x is undefined
assert_iter_sum_pattern({fld(flm(x, 49) + y, 49): (1, fld(flm(x, 49) + y, 49))}, var_dom([(y, 1)]))
def test_compound():
x = tvm.tir.Var("x", "int32")
y = tvm.tir.Var("y", "int32")
xo, xi = isplit((x, 10), 5)
yo, yi = isplit((y, 9), 3)
z = ifuse([yo, xo, yi])
# reconstruct the pattern manually
mx = tvm.arith.IterMark(x, 10)
my = tvm.arith.IterMark(y, 9)
xoscale = 3
yoscale = 6
yiscale = 1
mxo = tvm.arith.IterSplitExpr(mx, 5, 2, xoscale)
myo = tvm.arith.IterSplitExpr(my, 3, 3, yoscale)
myi = tvm.arith.IterSplitExpr(my, 1, 3, yiscale)
mz = tvm.arith.IterMark(tvm.arith.IterSumExpr([myo, mxo, myi], 0), 18)
sz = tvm.arith.IterSumExpr([tvm.arith.IterSplitExpr(mz, 1, 18, 1)], 0)
assert_iter_sum_pattern({z[0]: (18, 0, 1, sz), xi[0]: (5, 0)}, var_dom([(x, 10), (y, 9)]))
def test_compound_floormod_two_regression():
x = tvm.tir.Var("x", "int32")
fld = tvm.tir.floordiv
flm = tvm.tir.floormod
# regression
# extent of 2 of negative scale cannot be normalized
assert_iter_sum_failure(
[fld(x, 2) * 2 - flm(x, 2) + 1],
dom_map=var_dom([(x, 8)]),
)
def test_predicate():
x = tvm.tir.Var("x", "int32")
y = tvm.tir.Var("y", "int32")
z = tvm.tir.Var("z", "int32")
# available constraints
# upper bound only
assert_iter_sum_pattern({x * 10 + y: (128, 0)}, var_dom([(x, 13), (y, 10)]), predicate=x * 10 + y < 128)
assert_iter_sum_pattern({x * 10 + y: (128, 0)}, var_dom([(x, 13), (y, 10)]), predicate=x * 10 + y <= 127)
# lower bound only
assert_iter_sum_pattern({x * 10 + y: (124, 6)}, var_dom([(x, 13), (y, 10)]), predicate=x * 10 + y > 5)
assert_iter_sum_pattern({x * 10 + y: (124, 6)}, var_dom([(x, 13), (y, 10)]), predicate=x * 10 + y >= 6)
# lower bound + upper bound
assert_iter_sum_pattern(
{x * 10 + y: (122, 6)},
var_dom([(x, 13), (y, 10)]),
predicate=tvm.tir.And(x * 10 + y > 5, x * 10 + y < 128),
)
assert_iter_sum_pattern(
{x * 10 + y: (122, 6)},
var_dom([(x, 13), (y, 10)]),
predicate=tvm.tir.And(x * 10 + y >= 6, x * 10 + y <= 127),
)
assert_iter_sum_pattern(
{x * 64 + y * 4 + z: (16, 16)},
var_dom([(x, 16), (y, 16), (z, 4)]),
predicate=tvm.tir.And(x * 64 + y * 4 + z < 32, x * 16 + y >= 4),
)
# constraints on one fused iter
i = tvm.tir.Var("i", "int32")
j = tvm.tir.Var("j", "int32")
k = tvm.tir.Var("k", "int32")
assert_iter_sum_pattern(
{i * 8 + j * 2 + k: (88, 1)},
var_dom([(i, 11), (j, 5), (k, 2)]),
predicate=tvm.tir.all(j * 2 + k >= 1, j * 2 + k < 9),
)
# constraints on single var
assert_iter_sum_pattern({i: (10, 0)}, var_dom([(i, 48)]), predicate=i < 10)
# iterations are subparts of constraint, invalid case 1
assert_iter_sum_failure(
[i, j, k],
var_dom([(i, 128), (j, 128), (k, 128)]),
predicate=tvm.tir.all(i * 16384 + j * 128 + k < 100),
)
# iterations are subparts of constraint, invalid case 2
assert_iter_sum_failure(
[i * 128 + j, k],
var_dom([(i, 128), (j, 128), (k, 128)]),
predicate=i * 16384 + j * 128 + k < 100,
)
# irrelevant predicate
assert_iter_sum_pattern({i + j: (1, j)}, var_dom([(i, 1)]), predicate=j <= 24)
# constraint on nested fused iters
assert_iter_sum_pattern(
{i * 8 + j * 2 + k: (22, 3)},
var_dom([(i, 11), (j, 5), (k, 2)]),
predicate=tvm.tir.all(j * 2 + k >= 1, j * 2 + k < 9, i * 8 + j * 2 + k >= 3, i * 8 + j * 2 + k < 25),
)
# duplicate constraint on one fused iter
assert_iter_sum_pattern(
{i * 6 + j * 2 + k: (66, 2)},
var_dom([(i, 11), (j, 5), (k, 2)]),
predicate=tvm.tir.all(j * 2 + k >= 1, j * 2 + k >= 2, j * 2 + k < 8, j * 2 + k < 9),
)
# duplicate constraint on nested fused iters
assert_iter_sum_pattern(
{i * 6 + j * 2 + k: (15, 3)},
var_dom([(i, 11), (j, 5), (k, 2)]),
predicate=tvm.tir.all(
j * 2 + k >= 1,
j * 2 + k >= 2,
j * 2 + k < 8,
j * 2 + k < 9,
i * 6 + j * 2 + k >= 3,
i * 6 + j * 2 + k < 25,
i * 6 + j * 2 + k >= 1,
i * 6 + j * 2 + k < 18,
),
)
# constraint on non-disjoint fused iters should fail
assert_iter_sum_failure(
[i * 8 + j * 2 + k],
var_dom([(i, 11), (j, 5), (k, 2)]),
predicate=tvm.tir.all(j * 2 + k >= 2, i * 4 + j >= 0),
)
# constraints with different lower bound
assert_iter_sum_pattern(
{
(i * 16 + j) // 23 * 8 + (i * 16 + j) % 23 - 15: (
64,
0,
1,
(i * 16 + j) // 23 * 8 + ((i * 16 + j) % 23 + tvm.tir.IntImm("int32", -15)),
)
},
var_dom([(i, 12), (j, 16)]),
predicate=tvm.tir.And(
tvm.tir.And(i * 16 + j < 184, tvm.tir.LE(tvm.tir.IntImm("int32", 8), (i * 16 + j) % 23)),
tvm.tir.LE(tvm.tir.IntImm("int32", 15), (i * 16 + j) % 23),
),
)
# constraint on many disjoint fused iters, case 1
# i4 * 6 + i5 in [3, 9), extent=6 (= scale of i2)
# i2 * 30 + i3 * 15 in [30, 90), extent=60 (= scale of i1)
# i1 * 60 in [60, 240), extent=180 (= scale of i0)
i0 = tvm.tir.Var("i0", "int32")
i1 = tvm.tir.Var("i1", "int32")
i2 = tvm.tir.Var("i2", "int32")
i3 = tvm.tir.Var("i3", "int32")
i4 = tvm.tir.Var("i4", "int32")
i5 = tvm.tir.Var("i5", "int32")
assert_iter_sum_pattern(
{i0 * 180 + i1 * 60 + i2 * 30 + i3 * 15 + i4 * 6 + i5: (540, 93)},
var_dom([(i0, 3), (i1, 4), (i2, 3), (i3, 2), (i4, 3), (i5, 6)]),
predicate=tvm.tir.all(i1 >= 1, i2 * 2 + i3 >= 2, i4 * 6 + i5 >= 3),
)
# constraint on many disjoint fused iters, case 2
assert_iter_sum_pattern(
{i0 * 45 + i1 * 45 + i2 * 9 + i3 * 4 + i4: (135, 28)},
var_dom([(i0, 3), (i1, 2), (i2, 5), (i3, 3), (i4, 4)]),
predicate=tvm.tir.all(i1 * 5 + i2 >= 3, i1 * 5 + i2 < 8, i3 * 4 + i4 >= 1, i3 * 4 + i4 < 10),
)
# constraint on split iters
assert_iter_sum_pattern(
{i % 16: (7, 3), i // 16: (8, 4)},
var_dom([(i, 1024)]),
predicate=tvm.tir.all(i % 16 >= 3, i % 16 < 10, i // 16 >= 4, i // 16 < 12),
check_level="bijective",
)
# constraint on split iters, nested case 1
assert_iter_sum_pattern(
{(i * 32 + j) % 16: (7, 3)},
var_dom([(i, 5), (j, 32)]),
predicate=tvm.tir.all((i * 32 + j) % 16 >= 3, (i * 32 + j) % 16 < 10),
)
# constraint on split iters, nested case 2
assert_iter_sum_failure(
[
(i * 32 + j) % 16,
],
var_dom([(i, 5), (j, 32)]),
predicate=tvm.tir.all(i * 32 + j >= 1, i * 32 + j <= 32),
check_level="bijective",
)
assert_iter_sum_pattern(
{(i * 32 + j) % 16: (16, 0)},
var_dom([(i, 5), (j, 32)]),
predicate=tvm.tir.all(i * 32 + j >= 1, i * 32 + j <= 32),
)
assert_iter_sum_pattern(
{(i * 32 + j - 1) % 16: (16, 0), (i * 32 + j - 1) // 16: (4, 0)},
var_dom([(i, 5), (j, 32)]),
predicate=tvm.tir.all(i * 32 + j >= 1, i * 32 + j <= 64),
)
# non-standard form of predicate
assert_iter_sum_pattern({x * 10 + y: (128, 0)}, var_dom([(x, 13), (y, 10)]), predicate=x * 10 < 128 - y)
# duplicate constraint
assert_iter_sum_pattern(
{x * 10 + y: (64, 0)},
var_dom([(x, 13), (y, 10)]),
predicate=tvm.tir.all(x * 10 + y < 128, x * 10 + y < 64),
)
# useless constraint
assert_iter_sum_pattern({x * 10 + y: (130, 0)}, var_dom([(x, 13), (y, 10)]), predicate=x * 10 + y < 140)
i1 = tvm.tir.Var("i1", "int32")
i2 = tvm.tir.Var("i2", "int32")
i3 = tvm.tir.Var("i3", "int32")
i4 = tvm.tir.Var("i4", "int32")
assert_iter_sum_pattern(
{i1 * 20 + i2 * 10 + i3 * 3 + i4: (128, 0)},
var_dom([(i1, 7), (i2, 2), (i3, 4), (i4, 3)]),
predicate=(
tvm.tir.all(
i1 * 2 + i2 < 13,
i1 * 20 + i2 * 10 + i3 * 3 + i4 < 128,
i3 * 3 + i4 < 10,
)
),
)
# wrong constraint
assert_iter_sum_failure(
[i1 * 20 + i2 * 10 + i3 * 3 + i4],
var_dom([(i1, 7), (i2, 2), (i3, 4), (i4, 3)]),
predicate=(
tvm.tir.all(
i1 * 2 + i2 < 13,
i1 * 20 + i2 * 10 + i3 * 3 + i4 < 128,
i3 * 3 + i4 < 7,
)
),
)
# incompatible constraint
assert_iter_sum_failure(
[i1 * 20 + i2 * 10 + i3 * 3 + i4],
var_dom([(i1, 7), (i2, 2), (i3, 4), (i4, 3)]),
predicate=(
tvm.tir.all(
i1 * 2 + i2 < 13,
i1 * 20 + i2 * 10 + i3 * 3 + i4 < 128,
i3 * 3 + i4 < 10,
i1 * 4 + i3 < 20,
)
),
)
assert_iter_sum_failure(
[i1 * 20 + i2 * 10 + i3 * 3 + i4],
var_dom([(i1, 7), (i2, 2), (i3, 4), (i4, 3)]),
predicate=(
tvm.tir.all(
i1 * 2 + i2 < 13,
i1 * 20 + i2 * 10 + i3 * 3 + i4 < 128,
i1 * 4 + i3 < 20,
)
),
)
# zero iter
xo = tvm.tir.Var("xo", "int32")
xi = tvm.tir.Var("xi", "int32")
y = tvm.tir.Var("y", "int32")
assert_iter_sum_pattern(
{xo * 129 + xi: (128, 0), y: (128, 0)},
var_dom([(xo, 1), (xi, 129), (y, 128)]),
predicate=xo * 129 + xi < 128,
)
# strided iteration predicate
assert_iter_sum_pattern(
{xo * 16 + xi * 4: (10, 0, 4)},
var_dom([(xo, 3), (xi, 4)]),
predicate=xo * 4 + xi < 10,
)
def convert_division(divisions):
if divisions is None or len(divisions) == 0:
return []
res = []
for division in divisions[:-1]:
res.append(
[
tvm.arith.normalize_iter_map_to_expr(division[0].source),
tvm.arith.normalize_iter_map_to_expr(division[1].source),
]
)
res.append([divisions[-1][0].extent, divisions[-1][1].extent])
return res
def create_iter(name, extent):
return tvm.tir.Var(name, "int32"), extent
def test_subspace_division():
x = tvm.tir.Var("x", "int32")
y = tvm.tir.Var("y", "int32")
z = tvm.tir.Var("z", "int32")
c = tvm.tir.SizeVar("c", "int32")
# simple 1.1
res = tvm.arith.subspace_divide([z * 12 + y * 3 + x + c], var_dom([(x, 3), (y, 4), (z, 5)]), [x])
res = convert_division(res)
assert len(res) == 2
tvm.ir.assert_structural_equal(res[0][0], z * 4 + y)
tvm.ir.assert_structural_equal(res[0][1], x + c)
# simple 1.2
res = tvm.arith.subspace_divide([z * 12 + y * 3 + x + c], var_dom([(x, 3), (y, 4), (z, 5)]), [x], z * 4 + y < 18)
res = convert_division(res)
assert len(res) == 2
tvm.ir.assert_structural_equal(res[0][0], z * 4 + y)
tvm.ir.assert_structural_equal(res[0][1], x + c)
tvm.ir.assert_structural_equal(res[1][0], z * 4 + y < 18)
tvm.ir.assert_structural_equal(res[1][1], T.bool(True))
# compound 1
i0 = create_iter("i0", 4)
j0 = create_iter("j0", 8)
i3 = create_iter("i3", 2)
i1, i2 = isplit(j0, 4)
k0 = ifuse([i0, i1])
k1 = ifuse([i2, i3])
# compound 1.1
res = tvm.arith.subspace_divide([k0[0], k1[0]], var_dom([i0, j0, i3]), [i3[0]])
res = convert_division(res)
assert len(res) == 3
tvm.ir.assert_structural_equal(res[0][0], (i0[0] * 2) + floordiv(j0[0], 4))
tvm.ir.assert_structural_equal(res[0][1], T.int32(0))
tvm.ir.assert_structural_equal(res[1][0], floormod(j0[0], 4))
tvm.ir.assert_structural_equal(res[1][1], i3[0])
# assert_iter_sum_pattern
res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([i3])).indices
assert len(res1) == 2
res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([i0, j0])).indices
assert len(res2) == 2
# compound 1.2
res = tvm.arith.subspace_divide([k0[0], k1[0]], var_dom([i0, j0, i3]), [j0[0], i3[0]])
res = convert_division(res)
assert len(res) == 3
tvm.ir.assert_structural_equal(res[0][0], i0[0])
tvm.ir.assert_structural_equal(res[0][1], floordiv(j0[0], 4))
tvm.ir.assert_structural_equal(res[1][0], T.int32(0))
tvm.ir.assert_structural_equal(res[1][1], (floormod(j0[0], 4) * 2) + i3[0])
res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([j0, i3])).indices
assert len(res1) == 2
res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([i0])).indices
assert len(res2) == 2
# compound 1.3
res = tvm.arith.subspace_divide([k0[0], k1[0]], var_dom([i0, j0, i3]), [i0[0], i3[0]])
res = convert_division(res)
assert len(res) == 0
# compound 1.4
res = tvm.arith.subspace_divide([k0[0], k1[0]], var_dom([i0, j0, i3]), [i3[0]], k0[0] < 7)
res = convert_division(res)
assert len(res) == 3
tvm.ir.assert_structural_equal(res[0][0], (i0[0] * 2) + floordiv(j0[0], 4))
tvm.ir.assert_structural_equal(res[0][1], T.int32(0))
tvm.ir.assert_structural_equal(res[1][0], floormod(j0[0], 4))
tvm.ir.assert_structural_equal(res[1][1], i3[0])
tvm.ir.assert_structural_equal(res[2][0], (i0[0] * 2) + floordiv(j0[0], 4) < 7)
tvm.ir.assert_structural_equal(res[2][1], T.bool(True))
res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([i3])).indices
assert len(res1) == 2
res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([i0, j0])).indices
assert len(res2) == 2
# compound 1.5
res = tvm.arith.subspace_divide([k0[0], k1[0]], var_dom([i0, j0, i3]), [j0[0], i3[0]], k1[0] < 7)
res = convert_division(res)
assert len(res) == 3
tvm.ir.assert_structural_equal(res[0][0], i0[0])
tvm.ir.assert_structural_equal(res[0][1], floordiv(j0[0], 4))
tvm.ir.assert_structural_equal(res[1][0], T.int32(0))
tvm.ir.assert_structural_equal(res[1][1], (floormod(j0[0], 4) * 2) + i3[0])
tvm.ir.assert_structural_equal(res[2][0], T.bool(True))
tvm.ir.assert_structural_equal(res[2][1], (floormod(j0[0], 4) * 2) + i3[0] < 7)
res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([j0, i3])).indices
assert len(res1) == 2
res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([i0])).indices
assert len(res2) == 2
# compound 1.6
res = tvm.arith.subspace_divide([k0[0], k1[0]], var_dom([i0, j0, i3]), [i3[0]], tvm.tir.all(k0[0] < 7, k1[0] < 7))
res = convert_division(res)
assert len(res) == 0
# compound 2
j0 = create_iter("j0", 4)
l0 = create_iter("l0", 2)
l1 = create_iter("l1", 6)
j3 = create_iter("j3", 3)
k0 = ifuse([l0, l1])
i1, j2 = isplit(k0, 3)
j1, i1 = isplit(i1, 2)
i0 = ifuse([j0, j1])
i2 = ifuse([j2, j3])
# compound 2.1
res = tvm.arith.subspace_divide([i0[0], i1[0], i2[0]], var_dom([j0, l0, l1, j3]), [l1[0], j3[0]])
res = convert_division(res)
assert len(res) == 4
tvm.ir.assert_structural_equal(res[0][0], (j0[0] * 2) + l0[0])
tvm.ir.assert_structural_equal(res[0][1], T.int32(0))
tvm.ir.assert_structural_equal(res[1][0], T.int32(0))
tvm.ir.assert_structural_equal(res[1][1], floordiv(l1[0], 3))
tvm.ir.assert_structural_equal(res[2][0], T.int32(0))
tvm.ir.assert_structural_equal(res[2][1], (floormod(l1[0], 3) * 3) + j3[0])
res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1], res[2][1]], var_dom([l1, j3])).indices
assert len(res1) == 3
res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0], res[2][0]], var_dom([j0, l0])).indices
assert len(res2) == 3
# compound 2.2
res = tvm.arith.subspace_divide([i0[0], i1[0], i2[0]], var_dom([j0, l0, l1, j3]), [l0[0], l1[0], j3[0]])
res = convert_division(res)
assert len(res) == 4
tvm.ir.assert_structural_equal(res[0][0], j0[0])
tvm.ir.assert_structural_equal(res[0][1], floordiv(l0[0] * 6 + l1[0], 6))
tvm.ir.assert_structural_equal(res[1][0], T.int32(0))
tvm.ir.assert_structural_equal(res[1][1], floordiv(floormod(l0[0] * 6 + l1[0], 6), 3))
tvm.ir.assert_structural_equal(res[2][0], T.int32(0))
tvm.ir.assert_structural_equal(res[2][1], (floormod(l0[0] * 6 + l1[0], 3) * 3) + j3[0])
res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1], res[2][1]], var_dom([l0, l1, j3])).indices
assert len(res1) == 3
res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0], res[2][0]], var_dom([j0])).indices
assert len(res2) == 3
# compound 2.3
res = tvm.arith.subspace_divide([i0[0], i1[0], i2[0]], var_dom([j0, l0, l1, j3]), [l0[0], j3[0]])
res = convert_division(res)
assert len(res) == 0
# compound 2.4
res = tvm.arith.subspace_divide(
[i0[0], i1[0], i2[0]],
var_dom([j0, l0, l1, j3]),
[l1[0], j3[0]],
tvm.tir.all(i0[0] < 7, i2[0] < 8),
)
res = convert_division(res)
assert len(res) == 4
tvm.ir.assert_structural_equal(res[0][0], (j0[0] * 2) + l0[0])
tvm.ir.assert_structural_equal(res[0][1], T.int32(0))
tvm.ir.assert_structural_equal(res[1][0], T.int32(0))
tvm.ir.assert_structural_equal(res[1][1], floordiv(l1[0], 3))
tvm.ir.assert_structural_equal(res[2][0], T.int32(0))
tvm.ir.assert_structural_equal(res[2][1], (floormod(l1[0], 3) * 3) + j3[0])
tvm.ir.assert_structural_equal(res[3][0], (j0[0] * 2) + l0[0] < 7)
tvm.ir.assert_structural_equal(res[3][1], (floormod(l1[0], 3) * 3) + j3[0] < 8)
res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1], res[2][1]], var_dom([l1, j3])).indices
assert len(res1) == 3
res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0], res[2][0]], var_dom([j0, l0])).indices
assert len(res2) == 3
# compound 2.5
res = tvm.arith.subspace_divide([i0[0], i1[0], i2[0]], var_dom([j0, l0, l1, j3]), [j3[0]], i2[0] < 8)
res = convert_division(res)
assert len(res) == 0
def test_subspace_divide_trivial_iters():
x = tvm.tir.Var("x", "int32")
y = tvm.tir.Var("y", "int32")
# z = tvm.tir.Var("z", "int32")
# trivial 1.1
res = tvm.arith.subspace_divide([x * 16 + y], var_dom([(x, 1), (y, 16)]), [y], simplify_trivial_iterators=False)
res = convert_division(res)
assert len(res) == 2
tvm.ir.assert_structural_equal(res[0][0], x)
tvm.ir.assert_structural_equal(res[0][1], y)
# trivial 1.2
res = tvm.arith.subspace_divide(
[x, y],
var_dom([(x, 1), (y, 1)]),
[y],
simplify_trivial_iterators=False,
)
res = convert_division(res)
assert len(res) == 3
tvm.ir.assert_structural_equal(res[0][0], x)
tvm.ir.assert_structural_equal(res[0][1], T.int32(0))
tvm.ir.assert_structural_equal(res[1][0], T.int32(0))
tvm.ir.assert_structural_equal(res[1][1], y)
def test_complex():
n0 = create_iter("n0", 2)
n1 = create_iter("n1", 4)
m0 = ifuse([n0, n1], 6)
m1 = create_iter("m1", 3)
l0 = create_iter("l0", 4)
l1 = create_iter("l1", 8)
l2 = ifuse([m0, m1], 16)
l3 = create_iter("l3", 32)
k0, k4 = isplit(l0, 2)
k1, k5 = isplit(l1, 2)
k2, k6 = isplit(l2, 4)
k3, k7 = isplit(l3, 4)
j0 = ifuse([k0, k1], 7)
j1 = ifuse([k2, k3])
j2 = ifuse([k4, k5])
j3 = ifuse([k6, k7], 15)
i0 = ifuse([j0, j1], 200)
i1 = ifuse([j2, j3], 50)
n0_mark = tvm.arith.IterMark(n0[0], n0[1])
n1_mark = tvm.arith.IterMark(n1[0], n1[1])
l0_mark = tvm.arith.IterMark(l0[0], l0[1])
l1_mark = tvm.arith.IterMark(l1[0], l1[1])
m1_mark = tvm.arith.IterMark(m1[0], m1[1])
l3_mark = tvm.arith.IterMark(l3[0], l3[1])
m0_expr = tvm.arith.IterSumExpr(
[
tvm.arith.IterSplitExpr(n0_mark, 1, n0[1], 4),
tvm.arith.IterSplitExpr(n1_mark, 1, n1[1], 1),
],
0,
)
m0_mark = tvm.arith.IterMark(m0_expr, 6)
l2_expr = tvm.arith.IterSumExpr(
[tvm.arith.IterSplitExpr(m0_mark, 1, 6, 3), tvm.arith.IterSplitExpr(m1_mark, 1, m1[1], 1)],
0,
)
l2_mark = tvm.arith.IterMark(l2_expr, 16)
k0_expr = tvm.arith.IterSplitExpr(l0_mark, 2, 2, 4)
k1_expr = tvm.arith.IterSplitExpr(l1_mark, 2, 4, 1)
k2_expr = tvm.arith.IterSplitExpr(l2_mark, 4, 4, 8)
k3_expr = tvm.arith.IterSplitExpr(l3_mark, 4, 8, 1)
k4_expr = tvm.arith.IterSplitExpr(l0_mark, 1, 2, 30)
k5_expr = tvm.arith.IterSplitExpr(l1_mark, 1, 2, 15)
k6_expr = tvm.arith.IterSplitExpr(l2_mark, 1, 4, 4)
k7_expr = tvm.arith.IterSplitExpr(l3_mark, 1, 4, 1)
j0_expr = tvm.arith.IterSumExpr([k0_expr, k1_expr], 0)
j0_mark = tvm.arith.IterMark(j0_expr, 7)
i0_expr = tvm.arith.IterSumExpr([tvm.arith.IterSplitExpr(j0_mark, 1, 7, 32), k2_expr, k3_expr], 0)
j3_expr = tvm.arith.IterSumExpr([k6_expr, k7_expr], 0)
j3_mark = tvm.arith.IterMark(j3_expr, 15)
i1_expr = tvm.arith.IterSumExpr([k4_expr, k5_expr, tvm.arith.IterSplitExpr(j3_mark, 1, 15, 1)], 0)
i0_mark = tvm.arith.IterMark(i0_expr, i0[1])
i1_mark = tvm.arith.IterMark(i1_expr, i1[1])
i0_final = tvm.arith.IterSumExpr([tvm.arith.IterSplitExpr(i0_mark, 1, i0[1], 1)], 0)
i1_final = tvm.arith.IterSumExpr([tvm.arith.IterSplitExpr(i1_mark, 1, i1[1], 1)], 0)
assert_iter_sum_pattern(
{i0[0]: (200, 0, 1, i0_final), i1[0]: (50, 0, 1, i1_final)},
var_dom([l0, l1, n0, n1, m1, l3]),
predicate=tvm.tir.all(i0[0] < 200, i1[0] < 50, m0[0] < 6, l2[0] < 16, j0[0] < 7, j3[0] < 15),
)
# wrong constraint
assert_iter_sum_failure(
[i0[0], i1[0]],
var_dom([l0, l1, n0, n1, m1, l3]),
tvm.tir.all(i0[0] < 200, i1[0] < 50, m0[0] < 9, l2[0] < 16, j0[0] < 7, j3[0] < 14),
)
# subspace_division
res = tvm.arith.subspace_divide(
[i0[0], i1[0]],
var_dom([l0, l1, n0, n1, m1, l3]),
[n0[0], n1[0], m1[0], l3[0]],
tvm.tir.all(m0[0] < 6, l2[0] < 16, j0[0] < 7, j3[0] < 15),
)
res = convert_division(res)
assert len(res) == 3
tvm.ir.assert_structural_equal(res[0][0], floordiv(l0[0], 2) * 4 + floordiv(l1[0], 2))
tvm.ir.assert_structural_equal(res[0][1], (floordiv((n0[0] * 4 + n1[0]) * 3 + m1[0], 4) * 8) + floordiv(l3[0], 4))
tvm.ir.assert_structural_equal(res[1][0], ((floormod(l0[0], 2) * 2) + floormod(l1[0], 2)))
tvm.ir.assert_structural_equal(res[1][1], ((floormod(((n0[0] * 4 + n1[0]) * 3 + m1[0]), 4) * 4) + floormod(l3[0], 4)))
tvm.ir.assert_structural_equal(res[2][0], (floordiv(l0[0], 2) * 4) + floordiv(l1[0], 2) < 7)
tvm.ir.assert_structural_equal(
res[2][1],
tvm.tir.all(
n0[0] * 4 + n1[0] < 6,
(n0[0] * 4 + n1[0]) * 3 + m1[0] < 16,
floormod(((n0[0] * 4 + n1[0]) * 3 + m1[0]), 4) * 4 + floormod(l3[0], 4) < 15,
),
)
assert_iter_sum_pattern({res[0][1]: (32, 0), res[1][1]: (15, 0)}, var_dom([n0, n1, m1, l3]), res[2][1])
assert_iter_sum_pattern({res[0][0]: (8, 0), res[1][0]: (4, 0)}, var_dom([l0, l1]))
def test_normalize_iter_map_to_expr():
fld = tvm.tir.floordiv
flm = tvm.tir.floormod
x = tvm.tir.Var("x", "int32")
y = tvm.tir.Var("y", "int32")
xo, xi = isplit((x, 10), 5)
yo, yi = isplit((y, 9), 3)
z = ifuse([yo, xo, yi])
res = tvm.arith.detect_iter_map([z[0], xi[0]], var_dom([(x, 10), (y, 9)]))
tvm.ir.assert_structural_equal(
tvm.arith.normalize_iter_map_to_expr(res.indices[0]),
fld(y, 3) * 6 + fld(x, 5) * 3 + flm(y, 3),
)
tvm.ir.assert_structural_equal(tvm.arith.normalize_iter_map_to_expr(res.indices[1]), flm(x, 5))
# iter mark wrap a complex expr
split = tvm.arith.IterSplitExpr(tvm.arith.IterMark(x * y + 1, 1024), 1, 1024, 1)
tvm.ir.assert_structural_equal(tvm.arith.normalize_iter_map_to_expr(split), x * y + 1)
def test_inverse_affine_iter_map():
analyzer = tvm.arith.Analyzer()
l0 = create_iter("l0", 64)
l1 = create_iter("l1", 64)
l2 = create_iter("l2", 64)
# simple case
l0_0, l0_1 = isplit(l0, 16)
l1_0, l1_1 = isplit(l1, 4)
l0_1_l1_1_fused = ifuse([l0_1, l1_1])
iter_map = tvm.arith.detect_iter_map([l0_1_l1_1_fused[0], l0_0[0], l1_0[0]], var_dom([l0, l1])).indices
outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in range(len(iter_map))]
res = tvm.arith.inverse_affine_iter_map(iter_map, outputs)
assert len(res) == 2
l0_inverse = floordiv(outputs[0], 4) + outputs[1] * 16
l1_inverse = floormod(outputs[0], 4) + outputs[2] * 4
assert analyzer.can_prove_equal(res[l0[0]], l0_inverse)
assert analyzer.can_prove_equal(res[l1[0]], l1_inverse)
# compound case
l0_0, l0_1 = isplit(l0, 16)
l1_0, l1_1 = isplit(l1, 4)
l2_1, l2_2 = isplit(l2, 4)
l2_0, l2_1 = isplit(l2_1, 4)
l0_1_l2_1_l1_1_l2_0_fused = ifuse([l0_1, l2_1, l1_1, l2_0])
iter_map = tvm.arith.detect_iter_map([l0_1_l2_1_l1_1_l2_0_fused[0], l0_0[0], l2_2[0], l1_0[0]], var_dom([l0, l1, l2])).indices
outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in range(len(iter_map))]
res = tvm.arith.inverse_affine_iter_map(iter_map, outputs)
assert len(res) == 3
l0_inverse = floordiv(outputs[0], 64) + outputs[1] * 16
l1_inverse = floormod(floordiv(outputs[0], 4), 4) + outputs[3] * 4
l2_inverse = floormod(outputs[0], 4) * 16 + floormod(floordiv(outputs[0], 16), 4) * 4 + outputs[2]
assert analyzer.can_prove_equal(res[l0[0]], l0_inverse)
assert analyzer.can_prove_equal(res[l1[0]], l1_inverse)
assert analyzer.can_prove_equal(res[l2[0]], l2_inverse)
# diamond-shape DAG
l0_0, l0_1 = isplit(l0, 16)
l1 = ifuse([l0_1, l0_0])
l1_0, l1_1 = isplit(l1, 8)
l2 = ifuse([l1_1, l1_0])
iter_map = tvm.arith.detect_iter_map([l2[0]], var_dom([l0])).indices
outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in range(len(iter_map))]
res = tvm.arith.inverse_affine_iter_map(iter_map, outputs)
assert len(res) == 1
l1_inverse = floormod(outputs[0], 8) * 8 + floordiv(outputs[0], 8)
l0_inverse = floormod(l1_inverse, 4) * 16 + floordiv(l1_inverse, 4)
assert analyzer.can_prove_equal(res[l0[0]], l0_inverse)
def test_inverse_affine_map_trivial_iter():
analyzer = tvm.arith.Analyzer()
l0 = create_iter("l0", 64)
l1 = create_iter("l1", 64)
iter_map = tvm.arith.detect_iter_map([0, l0[0], l1[0]], var_dom([l0, l1])).indices
outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in range(len(iter_map))]
res = tvm.arith.inverse_affine_iter_map(iter_map, outputs)
# output_0 is expected to be constant and it is not included in the inverse map
assert len(res) == 2
assert analyzer.can_prove_equal(res[l0[0]], outputs[1])
assert analyzer.can_prove_equal(res[l1[0]], outputs[2])
def test_free_variables():
x = tvm.tir.Var("x", "int32")
y = tvm.tir.Var("y", "int32")
z = tvm.tir.Var("z", "int32")
# illegal iter if z is within dom
assert_iter_sum_failure([z * 19 + y * 3 + x], var_dom([(x, 3), (y, 3), (z, 3)]))
# iter is valid if z is free, even there are linear forms of z
assert_iter_sum_pattern(
{z * 19 + y * 3 + x: (9, z * 19)},
var_dom(
[
(x, 3),
(y, 3),
]
),
)
assert_iter_sum_pattern(
{z * z + y * 3 + x: (9, z * z)},
var_dom(
[
(x, 3),
(y, 3),
]
),
)
class TestPadding:
x = tvm.tir.Var("x", "int32")
y = tvm.tir.Var("y", "int32")
fld = tvm.tir.floordiv
flm = tvm.tir.floormod
positive_test_case = tvm.testing.parameter(
# left padding only, offset divisible
({y: 192}, {fld(64 + y, 32): (6, 2, 1), flm(64 + y, 32): (32, 0, 1)}, "bijective"),
# left padding only, offset non-divisible
({y: 176}, {fld(80 + y, 32): (6, 2, 1)}),
({y: 176}, {flm(fld(80 + y, 2), 16): (16, 0, 1), flm(80 + y, 2): (2, 0, 1)}),
# right padding only, offset divisible
({x: 5, y: 4}, {fld(x * 32 + y * 8, 16): (10, 0, 1), flm(x * 32 + y * 8, 16): (2, 0, 8)}),
# right padding only, offset non-divisible
({x: 26}, {fld(x, 15): (2, 0, 1)}),
({x: 26}, {flm(fld(x, 3), 5): (5, 0, 1), flm(x, 3): (3, 0, 1)}),
# padding constants on both side
({x: 45}, {fld(x + 71, 32): (2, 2, 1)}),
({x: 45}, {flm(fld(x, 4), 8): (8, 0, 1), flm(x, 4): (4, 0, 1)}),
# padding for free iteration part
({y: 360}, {fld(x * 360 + y, 16): (23, fld(x * 360 - flm(x, 2) * 8, 16), 1)}),
({y: 360}, {flm(x * 360 + y, 16): (16, 0, 1)}),
# multiple split with same mark offset, could
# be surjective on missing (padded // LCM)
(
{x: 240},
{
flm(x + 10, 3): (3, 0),
flm(fld(x + 10, 3), 4): (4, 0),
flm(fld(fld(x + 10, 3), 4), 5): (5, 0),
},
),
# different offsets on splits
(
{x: 240},
{
flm(x + 1, 3): (3, 0),
flm(fld(x + 10, 3) + 2, 4): (4, 0),
flm(fld(fld(x + 10, 3), 4) + 3, 5): (5, 0),
},
),
)
negative_test_case = tvm.testing.parameter(
# left padding only, offset non-divisible
({y: 176}, {fld(80 + y, 32), flm(80 + y, 32)}),
({y: 176}, {fld(80 + y, 32), fld(80 + y, 4)}),
# right padding only, offset divisible
({x: 5, y: 4}, {fld(x * 32 + y * 8, 5)}),
# multiple split with same mark offset, could
# be surjective on missing (padded // LCM)
(
{x: 240},
{
flm(x + 10, 3),
flm(fld(x + 10, 3), 4),
flm(fld(fld(x + 10, 3), 4), 5),
fld(fld(fld(x + 10, 3), 4), 5),
},
),
# original extent is smaller than the divident
# it is not surjective wrt to the region [0, 16)
({x: 3}, {flm(x, 16)}),
# (x % c1) // c2 is not proved as surjective if c1 % c2 != 0
({x: 255}, {fld(flm(x, 255), 16)}),
)
def test_padding(self, positive_test_case):
iter_extent, mapped_iterators, *args = positive_test_case
check_level = args[0] if args else "surjective"
dom_map = {var: tvm.ir.Range(0, ext) for var, ext in iter_extent.items()}
assert_iter_sum_pattern(mapped_iterators, dom_map, check_level=check_level)
def test_padding_error(self, negative_test_case):
iter_extent, mapped_iterators, *args = negative_test_case
check_level = args[0] if args else "surjective"
dom_map = {var: tvm.ir.Range(0, ext) for var, ext in iter_extent.items()}
assert_iter_sum_failure(mapped_iterators, dom_map, check_level=check_level)
def test_overlapped_fuse():
x = tvm.tir.Var("x", "int32")
y = tvm.tir.Var("y", "int32")
z = tvm.tir.Var("z", "int32")
a = tvm.tir.Var("x", "int32")
b = tvm.tir.Var("y", "int32")
# non-bijective fuse of two
assert_iter_sum_pattern(
{
x * 7 + y: (22, 0, 1),
},
var_dom([(x, 3), (y, 8)]),
check_level="surjective",
)
assert_iter_sum_failure([x * 7 + y], var_dom([(x, 3), (y, 8)]), check_level="bijective")
# non-bijective fuse of three
assert_iter_sum_pattern(
{
x * 18 + y * 7 + z: (40, 0, 1),
},
var_dom([(x, 2), (y, 3), (z, 8)]),
check_level="surjective",
)
assert_iter_sum_failure([x * 7 + y], var_dom([(x, 2), (y, 3), (z, 8)]), check_level="bijective")
# negative scale fusion is not allowed
assert_iter_sum_failure([x * -7 + y], var_dom([(x, 3), (y, 8)]), check_level="surjective")
assert_iter_sum_failure([x * 7 - y], var_dom([(x, 3), (y, 8)]), check_level="surjective")
# with predicate
assert_iter_sum_pattern(
{
a * 40 + b * 20 + x * 18 + y * 3 + z: (125, 6, 1),
},
var_dom([(a, 3), (b, 2), (x, 2), (y, 6), (z, 8)]),
predicate=tvm.tir.all(z < 4, x * 6 + y > 1, x * 6 + y < 10),
check_level="surjective",
)
# stride=1 kernel
assert_iter_sum_pattern({x + a: (230, 0, 1)}, var_dom([(x, 224), (a, 7)]), check_level="surjective")
# do not allow both strided and overlapped
assert_iter_sum_failure([5 * x + 2 * y], var_dom([(x, 4), (y, 3)]), check_level="surjective")
def test_iter_map_simplify_symbolic_case():
"""Test itermap simplify"""
x = tvm.tir.Var("x", "int64")
y = tvm.tir.Var("y", "int64")
z = x * 32 + y
n = tvm.tir.SizeVar("n", "int64")
def simple_fuse0(x):
return (x // n) * n + x % n
assert_iter_map_simplify({simple_fuse0(x): x}, var_dom([(x, n * 32)]))
assert_iter_map_simplify({simple_fuse0(z): z}, var_dom([(x, n), (y, 32)]))
def fsymbolic_fuse0(x):
return ((x // (n * n)) % 32) * (n * n) + ((x // n) % n) * n + x % n
assert_iter_map_simplify({fsymbolic_fuse0(x): x}, var_dom([(x, n * n * 32)]))
assert_iter_map_simplify({fsymbolic_fuse0(z): z}, var_dom([(x, n * n), (y, 32)]))
def fsymbolic_fuse1(x):
return ((x % (n * n * 32)) // (n * n) * n + (x % (n * n) // n)) * n + x % n
assert_iter_map_simplify({fsymbolic_fuse1(x): x}, var_dom([(x, n * n * 32)]))
assert_iter_map_simplify({fsymbolic_fuse1(z): z}, var_dom([(x, n * n), (y, 32)]))
def fsymbolic_fuse2(i):
return (i // (n * n) * n + i % (n * n) // n) * n + i % n
assert_iter_map_simplify({fsymbolic_fuse2(x): x}, var_dom([(x, n * n * 32)]))
def test_iter_map_simplify_symbolic_predicate():
"""Test itermap simplify"""
x = tvm.tir.Var("x", "int64")
y = tvm.tir.Var("y", "int64")
n = tvm.tir.SizeVar("n", "int64")
def simple_fuse0(x):
return (x // n) * n + x % n
z = x * 32 + y
assert_iter_map_simplify({simple_fuse0(z): z}, var_dom([(x, (n + 1) // 2), (y, 32)]), predicate=(z < n * 16))
def fsymbolic_fuse2(i):
return (i // (n * n) * n + i % (n * n) // n) * n + i % n
z = x * 64 + y
assert_iter_map_simplify(
{fsymbolic_fuse2(z): z},
var_dom([(x, (n * n + 1) // 2), (y, 64)]),
predicate=(z < n * n * 32),
)
def test_iter_map_simplify_symbolic_reshape():
n = tvm.tir.Var("n", "int64")
fused = tvm.tir.Var("fused", "int64")
ax0 = (fused // 4096) // n
ax1 = (fused // 4096) % n
ax2 = fused % 4096
rhs_index = ((ax2 // 4096 + ax0 * n + ax1) % n) * 4096 + ax2 % 4096
assert_iter_map_simplify({rhs_index: fused}, var_dom([(fused, n * 4096)]))
def test_iter_map_simplify_unit_loop_order():
"""Test itermap simplify"""
x = tvm.tir.Var("x", "int64")
y = tvm.tir.Var("y", "int64")
z = tvm.tir.Var("z", "int64")
# trivial iterators can be found at any when comparing via scale
# ensure order unchange
assert_iter_map_simplify({x + y + z: x + y + z}, var_dom([(x, 1), (y, 1), (z, 1)]), simplify_trivial_iterators=False)
# Even with simplification, it should follow the original order
assert_iter_map_simplify(
{x + y + (z // 4) * 4 + z % 4: z + x + y},
var_dom([(x, 1), (y, 1), (z, 32)]),
simplify_trivial_iterators=False,
)
assert_iter_map_simplify(
{y + 64 - x % 2 * 64: y + 64 - x % 2 * 64},
var_dom([(x, 6), (y, 64)]),
simplify_trivial_iterators=False,
)
# When we have iterators that have same scale but one of them come
# with unit extent, we should prioritize unit extent
assert_iter_map_simplify(
{x // 128 + y + z: y + z},
var_dom([(x, 128), (y, 128), (z, 1)]),
simplify_trivial_iterators=False,
)
def assert_normalize_to_iter_sum(index, input_iters, args, base):
"""Assert the result of arith.normalize_to_iter_sum is correct
Parameters
----------
index : tvm.tir.PrimExpr
The index to be normalized
input_iters : Mapping[Var, Range]
The input iterators
args : List[Union[tvm.arith.IterSplitExpr, Tuple[PrimExpr, PrimExpr]]]
The expected result. Ordered list of args of the expected IterSumExpr. Each arg can be
either IterSplitExpr or a tuple of (PrimExpr, PrimExpr) where the first element is the
iterator normalized to PrimExpr and the second element is the scale.
base : tvm.tir.PrimExpr
The expected base
"""
res = tvm.arith.normalize_to_iter_sum(index, input_iters)
assert isinstance(res, tvm.arith.IterSumExpr)
assert len(res.args) == len(args)
for split, item in zip(res.args, args):
if isinstance(item, tvm.arith.IterSplitExpr):
tvm.ir.assert_structural_equal(split, item)
continue
tvm.testing.assert_prim_expr_equal(split.scale, item[1])
tvm.testing.assert_prim_expr_equal(tvm.arith.normalize_iter_map_to_expr(split), item[0] * item[1])
tvm.testing.assert_prim_expr_equal(res.base, base)
def test_normalize_to_iter_sum():
x = tvm.tir.Var("x", "int64")
y = tvm.tir.Var("y", "int64")
z = tvm.tir.Var("z", "int64")
a = tvm.tir.Var("a", "int64")
n = tvm.tir.Var("n", "int64")
# flm = tvm.tir.floormod
assert_normalize_to_iter_sum(
z + ((y + x * 4 + 2) * n) + 3,
var_dom([(x, 9), (y, 4), (z, 3)]),
[(x, n * 4), (y, n), (z, 1)],
2 * n + 3,
)
# max cannot detected so it goes into base
assert_normalize_to_iter_sum(
tvm.tir.max(z, a) + ((y + x * 4 + 2) * n) + 3,
var_dom([(x, 9), (y, 4), (z, 3)]),
[(x, n * 4), (y, n)],
tvm.tir.max(z, a) + 2 * n + 3,
)
# order by symbolic prod
assert_normalize_to_iter_sum(
z + ((y * 4 * a + x * 4 + 2) * n) + 3,
var_dom([(y, a * n * 4), (x, n * 4), (z, a)]),
[(y, a * n * 4), (x, n * 4), (z, 1)],
2 * n + 3,
)
# order by cscale
assert_normalize_to_iter_sum(
z + 2 * y * 3 + 4 * x,
var_dom([(y, a * n * 4), (x, n * 4), (z, a)]),
[(y, 6), (x, 4), (z, 1)],
0,
)
# split pattern
assert_normalize_to_iter_sum(
z + 2 * y * 3 + 4 * (x // 2),
var_dom([(y, a * n * 4), (x, n * 4), (z, a)]),
[(y, 6), (x // 2, 4), (z, 1)],
0,
)
# non-divisible
assert_normalize_to_iter_sum(
x // 5,
var_dom([(x, 4096)]),
[
tvm.arith.IterSplitExpr(
tvm.arith.IterMark(x, 4096),
lower_factor=tvm.tir.const(5, "int64"),
extent=tvm.tir.const(820, "int64"),
scale=tvm.tir.const(1, "int64"),
)
],
0,
)
# iter simplify
assert_normalize_to_iter_sum(
z * 2 + 2 * y * 3 + 4 * (x // 4) + (x % 4),
var_dom([(y, a * n * 4), (x, n * 4), (z, a)]),
[(y, 6), (z, 2), (x, 1)],
0,
)
def test_detect_iter_map_with_bufferload_recursion():
n = tvm.tir.Var("n", "int32")
m = tvm.tir.Var("m", "int32")
divisor = tvm.tir.Var("divisor", "int32")
i = tvm.tir.Var("i", "int32")
j = tvm.tir.Var("j", "int32")
buffer = tvm.tir.decl_buffer((n,), "int32", name="seqlen")
indices = [(buffer[i] + j) // divisor]
iter_vars = {
i: tvm.ir.Range(tvm.tir.const(0, "int32"), n),
j: tvm.ir.Range(tvm.tir.const(0, "int32"), m),
}
result = tvm.arith.detect_iter_map(indices, iter_vars)
assert len(result.indices) == 0
if __name__ == "__main__":
tilelang.testing.main()
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from tilelang import tvm
import tilelang.testing
from tvm import tir
import tvm.ir
def test_simplify_reshape_flattened_index():
ana = tvm.arith.Analyzer()
i0 = tir.Var("i0", "int64")
i1 = tir.Var("i1", "int64")
ana.bind(i0, tvm.ir.Range(0, 8))
ana.bind(i1, tvm.ir.Range(0, 3))
i_flattened = i0 * 3 + i1
tvm.ir.assert_structural_equal(
ana.simplify((i_flattened) // 12 * 12 + (i_flattened) % 12 // 4 * 4 + (i_flattened) % 4),
i_flattened,
)
dtype = tvm.testing.parameter(
"uint8",
"uint16",
"uint32",
"uint64",
"int8",
"int16",
"int32",
"int64",
"float16",
"float32",
"float64",
)
def test_can_prove_self_identity(dtype):
ana = tvm.arith.Analyzer()
n = tir.Var("n", dtype)
assert ana.can_prove(n == n)
def test_can_prove_self_equal_to_self(dtype):
ana = tvm.arith.Analyzer()
n = tir.Var("n", dtype)
assert ana.can_prove_equal(n, n)
def test_simplify_symbolic_comparison():
ana = tvm.arith.Analyzer()
i0 = tir.Var("i0", "int64")
i1 = tir.Var("i1", "int64")
n, m = tvm.tir.SizeVar("n", "int64"), tvm.tir.SizeVar("m", "int64")
outer = (n + 31) // 32
ana.bind(i0, tvm.ir.Range(0, outer))
ana.bind(i1, tvm.ir.Range(0, 32))
PS = tvm.arith.ProofStrength
assert ana.can_prove(i0 * 32 + i1 < (n + 31) // 32 * 32, PS.SYMBOLIC_BOUND)
assert ana.can_prove(i0 * 32 + i1 < (n + 31) // 32 * 32 + m, PS.SYMBOLIC_BOUND)
assert ana.can_prove(i0 * 32 + i1 + 1 <= (n + 31) // 32 * 32, PS.SYMBOLIC_BOUND)
assert ana.can_prove((n + 31) // 32 * 32 >= i0 * 32 + i1 + 1, PS.SYMBOLIC_BOUND)
assert ana.can_prove((n + 31) // 32 * 32 >= i0 * 32 + i1, PS.SYMBOLIC_BOUND)
def test_regression_simplify_inf_recursion():
ana = tvm.arith.Analyzer()
cond = tir.Var("cond", "int32")
res = (tvm.tir.NE(cond, 0).astype("int8") - tvm.tir.NE(cond, 0).astype("int8")).astype("int32") == 0
# regression in a previous case
# try compare and int set recursive call can cause infinite loop
ana.rewrite_simplify(res)
def test_simplify_floor_mod_with_linear_offset():
"""
Test that the floor_mod is simplified correctly when the offset is linear.
"""
ana = tvm.arith.Analyzer()
past_decoder_sequence_length = tir.Var("past_decoder_sequence_length", "int64")
expr1 = (past_decoder_sequence_length + 1) * 64
divisor1 = (past_decoder_sequence_length + 1) * 32
assert ana.can_prove_equal(tvm.tir.floormod(expr1, divisor1), 0)
divisor2 = 32 * (past_decoder_sequence_length + 1)
assert ana.can_prove_equal(tvm.tir.floormod(expr1, divisor2), 0)
def test_simplify_float_division():
# Test for the discussion:
# https://discuss.tvm.apache.org/t/discuss-is-constant-division-to-multiplication-rewrite-in-tvm-necessary/18615
ana = tvm.arith.Analyzer()
x = tir.Var("x", "float32")
ry = x / 27
# in old version, the division will be rewritten into x * T.float32(1 / 27)
sy = ana.rewrite_simplify(ry)
tvm.ir.assert_structural_equal(ry, sy)
if __name__ == "__main__":
tilelang.testing.main()
......@@ -216,6 +216,8 @@ def run_gemm_sp(
print("pass")
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version(9, 0)
def run_gemm_sp_sm90(
M,
N,
......@@ -228,8 +230,8 @@ def run_gemm_sp_sm90(
block_K,
num_stages,
num_threads,
trans_A=False,
trans_B=False,
trans_A,
trans_B,
):
kernel = matmul_sp_sm90(
M,
......@@ -259,6 +261,9 @@ def run_gemm_sp_sm90(
)
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(8, 0)
@tilelang.testing.requires_cuda_compute_version_le(8, 9)
def run_gemm_sp_sm80(
M,
N,
......@@ -271,8 +276,8 @@ def run_gemm_sp_sm80(
block_K,
num_stages,
num_threads,
trans_A=False,
trans_B=False,
trans_A,
trans_B,
):
kernel = matmul_sp_sm80(
M,
......
......@@ -41,34 +41,35 @@ def assert_vectorize_access(M: int = 64, N: int = 64):
tvm.ir.assert_structural_equal(transformed["main"].body, expected.body)
def issue_1013_buggy_kernel():
# NOTE: This kernel is mainly to test some corner cases in boundary check
num_tokens = T.dynamic("num_tokens")
num_threads = 128
@T.prim_func
def main(x: T.Tensor((num_tokens,), dtype="int64")):
with T.Kernel(1, threads=num_threads) as _:
count = T.alloc_var("int")
thread_idx = T.get_thread_binding()
for i in T.serial(0, T.ceildiv(num_tokens - thread_idx, num_threads)):
idx = thread_idx + i * num_threads
count += x[idx] == 2
# NOTE(chaofan): Ideally, the prover should be able to prove that the access is safe
# and the padding value is not used. However, the current prover cannot handle this case.
# So for now the expected kernel is a if-else statement to check the boundary.
@T.prim_func
def expected(x: T.Tensor((num_tokens,), dtype="int64")):
with T.Kernel(1, threads=num_threads) as _:
count = T.alloc_var("int")
thread_idx = T.get_thread_binding()
for i in T.serial(0, T.ceildiv(num_tokens - thread_idx, num_threads)):
idx = thread_idx + i * num_threads
count += T.Cast("int32", T.if_then_else(idx < num_tokens, x[idx], T.int64(0)) == T.int64(2))
return main, expected
# def issue_1013_buggy_kernel():
# # NOTE: This kernel is mainly to test some corner cases in boundary check
# num_tokens = T.dynamic('num_tokens')
# num_threads = 128
# @T.prim_func
# def main(x: T.Tensor((num_tokens,), dtype="int64")):
# with T.Kernel(1, threads=num_threads) as _:
# count = T.alloc_var('int')
# thread_idx = T.get_thread_binding()
# for i in T.serial(0, T.ceildiv(num_tokens - thread_idx, num_threads)):
# idx = thread_idx + i * num_threads
# count += x[idx] == 2
# # NOTE(chaofan): Ideally, the prover should be able to prove that the access is safe
# # and the padding value is not used. However, the current prover cannot handle this case.
# # So for now the expected kernel is a if-else statement to check the boundary.
# @T.prim_func
# def expected(x: T.Tensor((num_tokens,), dtype="int64")):
# with T.Kernel(1, threads=num_threads) as _:
# count = T.alloc_var('int')
# thread_idx = T.get_thread_binding()
# for i in T.serial(0, T.ceildiv(num_tokens - thread_idx, num_threads)):
# idx = thread_idx + i * num_threads
# count += T.Cast("int32",
# value=T.if_then_else(idx < num_tokens, x[idx], T.int64(0)) == T.int64(2))
# return main, expected
def vectorize_access_with_atmoic_add_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_offset: int = 2):
......@@ -151,11 +152,11 @@ def test_vectorize_access():
assert_vectorize_access(64, 64)
def test_issue_1013():
func, expected = issue_1013_buggy_kernel()
mod = tvm.IRModule({func.attrs["global_symbol"]: func})
transformed = tl.transform.LegalizeSafeMemoryAccess()(mod)
tvm.ir.assert_structural_equal(transformed["main"].body, expected.body)
# def test_issue_1013():
# func, expected = issue_1013_buggy_kernel()
# mod = tvm.IRModule({func.attrs["global_symbol"]: func})
# transformed = tl.transform.LegalizeSafeMemoryAccess()(mod)
# tvm.ir.assert_structural_equal(transformed["main"].body, expected.body)
def test_vectorize_access_with_atmoic_add():
......
......@@ -243,8 +243,8 @@ def reduce_bitxor(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: boo
@macro
def cumsum_fragment(
src: tir.Buffer | tir.BufferRegion | tir.BufferLoad,
dst: tir.Buffer | tir.BufferRegion | tir.BufferLoad,
src: tir.Buffer,
dst: tir.Buffer,
dim: int,
reverse: bool,
) -> tir.PrimExpr:
......
......@@ -21,7 +21,6 @@ def _get_cached_lib():
try:
return _import_module_from_library(name, _CACHE_DIR, is_python_module=True)
except Exception:
# If loading fails, recompile
pass
# Set TORCH_CUDA_ARCH_LIST
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment