"library/vscode:/vscode.git/clone" did not exist on "0f912e205eec6e349060f2203a8eeabc5e7ba075"
Unverified Commit 9c21586b authored by Kuris's avatar Kuris Committed by GitHub
Browse files

[Feat] Integrate Z3 in TVM Arith Analyzer (#1367)

parent 899f7bd5
Subproject commit 2b1ead1a375704c75af563cc800aa9347583ba2b
Subproject commit 4d3ec9253e346b2281513700e692124aefaff347
......@@ -222,6 +222,14 @@ elseif(USE_CUDA)
list(APPEND TILE_LANG_INCLUDES ${CUDAToolkit_INCLUDE_DIRS})
endif()
set(USE_Z3 ON CACHE STRING "Use Z3 SMT solver for TileLang optimizations")
set(USE_PYPI_Z3 ON CACHE BOOL "Use Z3 provided by PyPI z3-solver package")
if(USE_Z3 AND USE_PYPI_Z3)
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}/cmake/pypi-z3")
find_package(Z3 REQUIRED)
endif()
# Include tvm after configs have been populated
add_subdirectory(${TVM_SOURCE} tvm EXCLUDE_FROM_ALL)
......@@ -288,19 +296,32 @@ install(TARGETS tilelang_cython_wrapper
RUNTIME DESTINATION tilelang/lib
ARCHIVE DESTINATION tilelang/lib)
# let libtilelang to search tvm/tvm_runtime in same dir
# add python z3 lib path to rpath for running tests and dev in current folder
if(USE_Z3 AND USE_PYPI_Z3)
set_property(TARGET tvm APPEND PROPERTY BUILD_RPATH ${Python3_SITELIB}/z3/lib)
set_property(TARGET tvm APPEND PROPERTY BUILD_RPATH ${Python3_SITELIB}/z3/bin)
endif()
if(APPLE)
set_target_properties(tilelang PROPERTIES INSTALL_RPATH "@loader_path;@loader_path/../../tvm_ffi/lib")
set_target_properties(tilelang_module PROPERTIES INSTALL_RPATH "@loader_path;@loader_path/../../tvm_ffi/lib")
set_target_properties(tvm PROPERTIES INSTALL_RPATH "@loader_path;@loader_path/../../tvm_ffi/lib")
set_target_properties(tvm_runtime PROPERTIES INSTALL_RPATH "@loader_path;@loader_path/../../tvm_ffi/lib")
set(TILELANG_INSTALL_RPATH "@loader_path;@loader_path/../../tvm_ffi/lib")
if(USE_Z3 AND USE_PYPI_Z3)
# some z3 is placed in lib/ and some in bin/, we add both in rpath
list(APPEND TILELANG_INSTALL_RPATH "@loader_path/../../z3/lib" "@loader_path/../../z3/bin")
endif()
elseif(UNIX)
set_target_properties(tilelang PROPERTIES INSTALL_RPATH "\$ORIGIN:\$ORIGIN/../../tvm_ffi/lib")
set_target_properties(tilelang_module PROPERTIES INSTALL_RPATH "\$ORIGIN:\$ORIGIN/../../tvm_ffi/lib")
set_target_properties(tvm PROPERTIES INSTALL_RPATH "\$ORIGIN:\$ORIGIN/../../tvm_ffi/lib")
set_target_properties(tvm_runtime PROPERTIES INSTALL_RPATH "\$ORIGIN:\$ORIGIN/../../tvm_ffi/lib")
set(TILELANG_INSTALL_RPATH "\$ORIGIN:\$ORIGIN/../../tvm_ffi/lib")
if(USE_Z3 AND USE_PYPI_Z3)
# cmake uses ; by default, we explicitly use : for linux
string(APPEND TILELANG_INSTALL_RPATH ":\$ORIGIN/../../z3/lib")
endif()
endif()
# let libtilelang to search tvm/tvm_runtime in same dir
set_target_properties(tilelang PROPERTIES INSTALL_RPATH "${TILELANG_INSTALL_RPATH}")
set_target_properties(tilelang_module PROPERTIES INSTALL_RPATH "${TILELANG_INSTALL_RPATH}")
set_target_properties(tvm PROPERTIES INSTALL_RPATH "${TILELANG_INSTALL_RPATH}")
set_target_properties(tvm_runtime PROPERTIES INSTALL_RPATH "${TILELANG_INSTALL_RPATH}")
install(
TARGETS tvm tvm_runtime tilelang_module tilelang
LIBRARY DESTINATION tilelang/lib
......
if(Z3_FOUND)
return()
endif()
find_package(Python3 COMPONENTS Interpreter REQUIRED)
execute_process(
COMMAND "${Python3_EXECUTABLE}" -c "import z3; print(z3.__path__[0])"
OUTPUT_VARIABLE Z3_PATH
OUTPUT_STRIP_TRAILING_WHITESPACE
RESULT_VARIABLE Z3_PYTHON_RESULT
)
if(NOT Z3_PYTHON_RESULT EQUAL 0 OR Z3_PATH STREQUAL "")
message(FATAL_ERROR "Failed to locate z3 Python package. Ensure z3-solver>=4.13.0 is installed.")
endif()
message("-- Find Z3 in path: ${Z3_PATH}")
find_path(Z3_INCLUDE_DIR NO_DEFAULT_PATH NAMES z3++.h PATHS ${Z3_PATH}/include)
find_library(Z3_LIBRARY NO_DEFAULT_PATH NAMES z3 libz3 PATHS ${Z3_PATH}/bin ${Z3_PATH}/lib ${Z3_PATH}/lib64)
message("-- Found Z3 include dir: ${Z3_INCLUDE_DIR}")
message("-- Found Z3 library: ${Z3_LIBRARY}")
add_library(z3::libz3 SHARED IMPORTED GLOBAL)
set_target_properties(z3::libz3
PROPERTIES
IMPORTED_LOCATION ${Z3_LIBRARY}
INTERFACE_INCLUDE_DIRECTORIES ${Z3_INCLUDE_DIR}
)
if(NOT Z3_INCLUDE_DIR OR NOT Z3_LIBRARY)
message(FATAL_ERROR "Could not find Z3 library or include directory")
endif()
set(Z3_CXX_INCLUDE_DIRS ${Z3_INCLUDE_DIR})
set(Z3_C_INCLUDE_DIRS ${Z3_INCLUDE_DIR})
set(Z3_FOUND TRUE)
......@@ -43,6 +43,7 @@ dependencies = [
"torch>=2.7; platform_system == 'Darwin'",
"tqdm>=4.62.3",
"typing-extensions>=4.10.0",
"z3-solver>=4.13.0",
]
[project.optional-dependencies]
......@@ -53,7 +54,14 @@ fp4 = ["ml-dtypes>=0.5.1"]
vis = ["matplotlib"]
[build-system]
requires = ["cython>=3.0.0", "scikit-build-core"]
requires = [
"cython>=3.0.0",
"scikit-build-core",
"z3-solver>=4.13.0",
# Not for auditwheel, explicitly add patchelf for repairing libz3.so.
# See tvm's CMakeLists.txt for more information.
"patchelf>=0.17.2; platform_system == 'Linux'",
]
build-backend = "scikit_build_core.build"
[tool.scikit-build]
......@@ -227,7 +235,7 @@ environment.PYTHONUNBUFFERED = "1"
environment.PATH = "/usr/local/cuda/bin:$PATH"
environment.LD_LIBRARY_PATH = "/usr/local/cuda/lib64:/usr/local/cuda/lib64/stubs:$LD_LIBRARY_PATH"
manylinux-x86_64-image = "manylinux_2_28" # AlmaLinux 8
manylinux-aarch64-image = "manylinux_2_28" # AlmaLinux 8
manylinux-aarch64-image = "manylinux_2_34" # Z3 requires
# Install CUDA runtime and stub driver library
# manylinux_2_28 uses gcc 14, which needs CUDA >=12.8
before-all = """
......@@ -256,7 +264,7 @@ yum install -y "cuda-minimal-build-${v}" "cuda-driver-devel-${v}" "cuda-nvrtc-de
yum clean all
"""
repair-wheel-command = [
"auditwheel -v repair --exclude libtvm_ffi.so --exclude libcuda.so.1 --exclude '/usr/local/cuda*' -w {dest_dir} {wheel}",
"auditwheel -v repair --exclude libtvm_ffi.so --exclude libz3.so --exclude libcuda.so.1 --exclude '/usr/local/cuda*' -w {dest_dir} {wheel}",
"pipx run abi3audit --verbose --strict {wheel}",
]
......
......@@ -10,6 +10,7 @@ scikit-build-core
setuptools>=61
torch
wheel
z3-solver>=4.13.0
auditwheel; platform_system == 'Linux'
patchelf; platform_system == 'Linux'
......
......@@ -30,3 +30,4 @@ scipy
tabulate
tornado
wheel
z3-solver>=4.13.0
\ No newline at end of file
......@@ -9,3 +9,4 @@ torch
torch>=2.7; platform_system == 'Darwin'
tqdm>=4.62.3
typing-extensions>=4.10.0
z3-solver>=4.13.0
\ No newline at end of file
import tilelang.testing
import tilelang.language as T
from tvm.arith import Analyzer
from tvm.ir.expr import Range
from tvm.tir.expr import Not, Or
def implies(x, y):
return Or(Not(x), y)
def test_hard_prove():
a = T.Var("a", T.int32)
b = T.Var("b", T.int32)
c = T.Var("c", T.int32)
d = T.Var("d", T.int32)
def check_expr(expr):
analyzer = Analyzer()
result = analyzer.can_prove(expr, 1)
if not result:
smtlib2 = analyzer.get_smtlib2(expr)
raise AssertionError(f"Failed to prove: {expr}\nSMT-LIB2:\n{smtlib2}")
# assert result, f"Failed to prove: {expr}"
@T.macro
def complex_expr_1():
return implies(a > 0 and b > 0 and c > 0, ((b - a) // c) * c + a <= b)
check_expr(complex_expr_1())
@T.macro
def complex_expr_2():
return implies(a < b and b < c and a * d < b * d, b * d < c * d)
check_expr(complex_expr_2())
@T.macro
def complex_expr_3():
return implies(a >= 0 and a < 128, a // 128 == (a // 64 * 32 + a % 32 // 16 * 8) // 64)
check_expr(complex_expr_3())
@T.macro
def complex_expr_4():
return implies(
a >= 0 and a < 128,
(a % 16 * 64 + a // 64 * 32 + a % 8 // 4 * 32 + (a % 32 // 16 + a % 2) % 2 * 8 + 16 - (a // 64 + a % 8 // 4) // 2 * 64) // 512
== (a % 16 * 64 + a // 64 * 32 + a % 8 // 4 * 32 + (a % 32 // 16 + a % 2) % 2 * 8 - (a // 64 + a % 8 // 4) // 2 * 64) // 512,
)
check_expr(complex_expr_4())
def test_smtlib2():
import z3
a = T.Var("a", T.int32)
b = T.Var("b", T.int32)
c = T.Var("c", T.int32)
@T.macro
def complex_expr_1():
return implies(a > 0 and b > 0 and c > 0, ((b - a) // c) * c + a <= b)
e = complex_expr_1()
analyzer = Analyzer()
analyzer.set_z3_timeout_ms(1000)
smtlib2 = analyzer.get_smtlib2(e)
solver = z3.Solver()
solver.from_string(smtlib2)
assert solver.check() == z3.unsat, f"Expected unsat, got {solver.check()}"
def test_bind():
a = T.Var("a", T.int32)
b = T.Var("b", T.int32)
c = T.Var("c", T.int32)
analyzer = Analyzer()
analyzer.bind(a, Range(1, 100000))
analyzer.bind(b, Range(1, 100000))
analyzer.bind(c, Range(1, 100000))
expr = ((b - a) // c) * c + a <= b
smtlib2 = analyzer.get_smtlib2(expr)
try:
result = analyzer.can_prove(expr, 1)
assert result, f"Failed to prove with bindings: {expr}"
except Exception as e:
print(smtlib2)
raise e
if __name__ == "__main__":
tilelang.testing.main()
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from tilelang import tvm
import tvm.testing
from tvm import te
from tvm import tir
from tvm.arith.analyzer import Analyzer
class IntSetChecker:
def __init__(self):
self.analyzer = tvm.arith.Analyzer()
def verify(self, data, dmap, expected):
res = self.analyzer.int_set(data, dmap)
def err_msg():
return "\ndata={}\ndmap={}\nres={}\nexpected={}".format(data, dmap, res, expected)
assert self.analyzer.can_prove_equal(res.min_value, expected[0]), err_msg()
assert self.analyzer.can_prove_equal(res.max_value, expected[1]), err_msg()
def test_basic():
s = tvm.arith.IntervalSet(2, 3)
assert s.min_value.value == 2
assert s.max_value.value == 3
s = tvm.arith.IntSet.single_point(2)
assert s.min_value.value == 2
assert s.max_value.value == 2
def test_vector():
base = 10
stride = 3
lanes = 2
s = tvm.arith.IntSet.vector(tvm.tir.Ramp(base, stride, lanes))
assert s.min_value.value == base
assert s.max_value.value == base + stride * (lanes - 1)
def test_scalable_vector():
base = 5
s = tvm.arith.IntSet.vector(tvm.tir.Ramp(base, 2, tvm.tir.vscale() * 4))
assert s.min_value.value == base
assert s.max_value.same_as(tvm.arith.int_set.pos_inf())
def test_add_sub():
ck = IntSetChecker()
x, y = te.var("x"), te.var("y")
ck.verify(x + y, {x: tvm.arith.IntervalSet(0, 10)}, (y, 10 + y))
ck.verify(x + y, {x: tvm.arith.IntervalSet(0, 10), y: tvm.arith.IntervalSet(1, 11)}, (1, 21))
ck.verify(x - y, {x: tvm.arith.IntervalSet(0, 10), y: tvm.arith.IntervalSet(1, 11)}, (-11, 9))
def test_mul_div():
ck = IntSetChecker()
x, y = te.var("x"), te.var("y")
tdiv = tvm.tir.truncdiv
ck.analyzer.update(y, tvm.arith.ConstIntBound(1, 100), override=True)
ck.verify(x * y, {x: tvm.arith.IntervalSet(0, 10)}, (0, 10 * y))
ck.verify(x * 2, {x: tvm.arith.IntervalSet(1, 10)}, (2, 20))
ck.verify(x * -2, {x: tvm.arith.IntervalSet(1, 10)}, (-20, -2))
ck.verify(tdiv(x, y), {x: tvm.arith.IntervalSet(0, 10)}, (0, tdiv(10, y)))
ck.verify(tdiv(x, 2), {x: tvm.arith.IntervalSet(1, 10)}, (0, 5))
fld = tvm.te.floordiv
ck.verify(fld(x, y), {x: tvm.arith.IntervalSet(0, 10)}, (0, fld(10, y)))
ck.verify(fld(x, 2), {x: tvm.arith.IntervalSet(-1, 10)}, (-1, 5))
def test_mod():
ck = IntSetChecker()
x, y = te.var("x"), te.var("y")
tmod = tvm.tir.truncmod
ck.analyzer.update(y, tvm.arith.ConstIntBound(1, 100), override=True)
ck.verify(tmod(x, y), {x: tvm.arith.IntervalSet(0, 10)}, (0, y - 1))
ck.verify(tmod(x, 10), {x: tvm.arith.IntervalSet(1, 10)}, (0, 9))
flm = tvm.te.floormod
ck.verify(flm(x, 10), {x: tvm.arith.IntervalSet(-10, 10)}, (0, 9))
ck.verify(flm(x, 10), {x: tvm.arith.IntervalSet(3, 5)}, (3, 5))
ck.verify(flm(x, 10), {x: tvm.arith.IntervalSet(13, 15)}, (3, 5))
ck.verify(flm(x, 10), {x: tvm.arith.IntervalSet(3, 15)}, (0, 9))
ck.verify(flm(x, 10), {x: tvm.arith.IntervalSet(3, 11)}, (0, 9))
ck.verify(flm(x, 10), {x: tvm.arith.IntervalSet(1, 21)}, (0, 9))
fld = tvm.te.floordiv
z = te.var("z")
ck.analyzer.bind(x, tvm.ir.Range.from_min_extent(0, 3))
ck.verify(
flm(y, 8),
{y: tvm.arith.IntervalSet(z * 8 + x * 4, z * 8 + x * 4 + 3)},
(
z * 8 + x * 4 - 8 * fld(z * 8 + x * 4, 8),
z * 8 + x * 4 + 3 - 8 * fld(z * 8 + x * 4, 8),
),
)
ck1 = IntSetChecker()
ck1.analyzer.bind(x, tvm.ir.Range.from_min_extent(0, 2))
ck1.verify(flm(y, 8), {y: tvm.arith.IntervalSet(z * 8 + x * 4, z * 8 + x * 4 + 3)}, (x * 4, x * 4 + 3))
def test_max_min():
ck = IntSetChecker()
x, y = te.var("x"), te.var("y")
ck.verify(tvm.te.max(x, x + 1), {x: tvm.arith.IntervalSet(0, 10)}, (1, 11))
ck.verify(tvm.te.min(x - 1, x + 1), {x: tvm.arith.IntervalSet(0, 10)}, (-1, 9))
ck.verify(tvm.te.min(x, y), {}, (tvm.te.min(x, y), tvm.te.min(x, y)))
ck.verify(tvm.te.max(x, y), {}, (tvm.te.max(x, y), tvm.te.max(x, y)))
def test_select():
ck = IntSetChecker()
# x, y = te.var("x"), te.var("y")
x = te.var("x")
ck.verify(tvm.tir.Select(x > 0, x - 1, x + 1), {x: tvm.arith.IntervalSet(0, 10)}, (-1, 11))
def check_region_bound(expect_region, var_dom, mode, predicate=None):
"""Helper to check region bound estimation.
Parameters
----------
expect_region: dict
The keys are of form (begin, end) or PrimExpr as a single point. The values are
expected estimated region or region dict on different bindings.
var_dom: dict
Map var to iteration domain range.
mode: str
Specify "lowerbound", "upperbound" or else use strict bound estimation.
predicate: PrimExpr
Extra predicate, defaults to True.
"""
if predicate is None:
predicate = tvm.tir.IntImm("bool", 1)
region = []
expect = []
for k, v in expect_region.items():
if not isinstance(k, (tuple, list)):
k = (k, k + 1)
region.append(tvm.ir.Range.from_min_extent(k[0], Analyzer().simplify(k[1] - k[0])))
expect.append(v)
if mode == "lowerbound":
result = tvm.arith.estimate_region_lower_bound(region=region, var_dom=var_dom, predicate=predicate)
elif mode == "upperbound":
result = tvm.arith.estimate_region_upper_bound(region=region, var_dom=var_dom, predicate=predicate)
else:
result = tvm.arith.estimate_region_strict_bound(region=region, var_dom=var_dom, predicate=predicate)
if result is None:
assert all([_ is None for _ in expect])
return
assert len(result) == len(expect)
for intset, expect_desc in zip(result, expect):
if isinstance(expect_desc, dict):
# check range on different free var bindings
for binding in expect_desc:
analyzer = Analyzer()
for k, v in binding:
analyzer.bind(k, v)
expect_begin, expect_end = expect_desc[binding]
result_begin = analyzer.simplify(intset.min_value, 3)
result_end = analyzer.simplify(intset.max_value + 1, 3)
assert analyzer.can_prove_equal(result_begin - expect_begin, 0), f"{result_begin} vs {expect_begin}"
assert analyzer.can_prove_equal(result_end - expect_end, 0), f"{result_end} vs {expect_end}"
else:
# check range
expect_begin, expect_end = expect_desc
analyzer = Analyzer()
assert analyzer.can_prove_equal(intset.min_value - expect_begin, 0), f"{intset.min_value} vs {expect_begin}"
assert analyzer.can_prove_equal(intset.max_value - expect_end + 1, 0), f"{intset.max_value} vs {expect_end - 1}"
def test_region_bound_not_independent():
# (i, i+2) and (i+2, i+4) are dependent, this the lowerbound is not available
i = tvm.tir.Var("i", "int32")
var_dom = {
i: tvm.ir.Range(begin=0, end=64),
}
check_region_bound({(i, i + 2): None, (i + 2, i + 4): None}, var_dom, mode="lowerbound")
check_region_bound({(i, i + 2): (0, 65), (i + 2, i + 4): (2, 67)}, var_dom, mode="upperbound")
# when only a subset of access indices are affine
i, j, k = tvm.tir.Var("i", "int32"), tvm.tir.Var("j", "int32"), tvm.tir.Var("k", "int32")
var_dom = {
i: tvm.ir.Range(begin=0, end=16),
j: tvm.ir.Range(begin=0, end=16),
k: tvm.ir.Range(begin=0, end=16),
}
check_region_bound(
{i // 4: None, j * 4 + i % 4: None, tir.truncdiv(k, 2): None},
var_dom,
predicate=j * 4 + i % 4 > 3,
mode="lowerbound",
)
check_region_bound(
{i // 4: (0, 4), j * 4 + i % 4: (4, 64), tir.truncdiv(k, 2): (0, 8)},
var_dom,
predicate=j * 4 + i % 4 > 3,
mode="upperbound",
)
def test_region_bound_stride_too_wide():
i = tvm.tir.Var("i", "int32")
var_dom = {i: tvm.ir.Range(begin=0, end=64)}
check_region_bound({(i * 4, i * 4 + 2): None}, var_dom, mode="lowerbound")
check_region_bound({(i * 4, i * 4 + 2): (0, 254)}, var_dom, mode="upperbound")
def test_region_bound_small_stride():
i = tvm.tir.Var("i", "int32")
var_dom = {
i: tvm.ir.Range(begin=0, end=64),
}
check_region_bound({(i * 4, i * 4 + 8): (0, 260)}, var_dom, mode="lowerbound")
def test_region_lower_bound_split_predicate():
x_o = tvm.tir.Var("xo", "int32")
x_i = tvm.tir.Var("xi", "int32")
x = x_o * 4 + x_i
var_dom = {
x_o: tvm.ir.Range(begin=0, end=16),
x_i: tvm.ir.Range(begin=0, end=4),
}
check_region_bound({(x * 4, x * 4 + 8): (0, 256)}, var_dom, predicate=x < 63, mode="lowerbound")
check_region_bound(
{(x * 4, x * 4 + 8): (0, 256), (x * 3, x * 3 + 5): (0, 191)},
var_dom,
predicate=x < 63,
mode="upperbound",
)
def test_region_lower_bound_multiple_variables():
div = tvm.tir.floordiv
mod = tvm.tir.floormod
x = tvm.tir.Var("x", "int32")
wid = tvm.tir.Var("wid", "int32")
i = div(x, 16)
j = div(mod(x, 16), 4) * 8 + mod(x, 4) + div(wid, 32) * 4
k = wid % 32
var_dom = {
x: tvm.ir.Range(begin=0, end=32),
wid: tvm.ir.Range(begin=0, end=64),
}
check_region_bound({i: (0, 2), j: (0, 32), k: (0, 32)}, var_dom, mode="lowerbound")
def test_region_lower_bound_negative_scale():
i = tvm.tir.Var("i", "int32")
j = tvm.tir.Var("j", "int32")
var_dom = {
i: tvm.ir.Range(begin=0, end=4),
j: tvm.ir.Range(begin=0, end=4),
}
check_region_bound({(1 - i, 5 - i): (-2, 5), (20 - j * 4, 36 - j * 4): (8, 36)}, var_dom, mode="lowerbound")
def test_region_lower_bound_for_non_perfect_tile():
h1 = tvm.tir.Var("h1", "int32")
h2 = tvm.tir.Var("h2", "int32")
h3 = tvm.tir.Var("h3", "int32")
# non-uniform tiling, single inner variable
var_dom = {
h2: tvm.ir.Range(begin=0, end=10),
}
check_region_bound(
{
h3 * 8 + h2: {
(): (
tvm.tir.max(h3 * 8, 1),
tvm.tir.min(0, h3 * 8 - 214) + 224,
),
((h3, 0),): (1, 10), # h3 == 0: region is [1, 10)
((h3, 10),): (h3 * 8, h3 * 8 + 10), # 0 < h3 <= 26: region is [h3 * 8, h3 * 8 + 10)
((h3, 27),): (h3 * 8, 224), # h3 > 26: region is [h3 * 8, 224)
}
},
var_dom,
predicate=tvm.tir.all(h3 * 8 + h2 >= 1, h3 * 8 + h2 < 224),
mode="lowerbound",
)
# non-uniform tiling, two inner variables
var_dom = {
h1: tvm.ir.Range(begin=0, end=5),
h2: tvm.ir.Range(begin=0, end=2),
}
check_region_bound(
{
h3 * 8 + h2 * 5 + h1: {
(): (
tvm.tir.max(h3 * 8, 1),
tvm.tir.min(0, h3 * 8 - 214) + 224,
),
((h3, 0),): (1, 10),
((h3, 10),): (h3 * 8, h3 * 8 + 10),
((h3, 27),): (h3 * 8, 224),
}
},
var_dom,
predicate=tvm.tir.all(h3 * 8 + h2 * 5 + h1 >= 1, h3 * 8 + h2 * 5 + h1 < 224),
mode="lowerbound",
)
# lowerbound should fail on incompatible predicates
check_region_bound(
{h3 * 8 + h2 * 5 + h1: None},
var_dom,
predicate=tvm.tir.all(h3 * 8 + h2 * 5 + h1 >= 1, h3 * 8 + h1 * 2 + h2 < 224),
mode="lowerbound",
)
check_region_bound(
{h3 * 8 + h2 * 5 + h1: (h3 * 8, h3 * 8 + 10)},
var_dom,
predicate=tvm.tir.all(h3 * 8 + h2 * 5 + h1 >= 1, h3 * 8 + h1 * 2 + h2 < 224),
mode="upperbound",
)
def test_region_lower_bound_unfusable():
var_dom = {
tvm.tir.Var("i", "int32"): tvm.ir.Range(8),
tvm.tir.Var("j", "int32"): tvm.ir.Range(4),
}
i, j = var_dom
check_region_bound({(i + j) // 2: (0, 6)}, var_dom, mode="lowerbound")
def test_union_lower_bound():
neg_inf = tvm.arith.int_set.neg_inf()
pos_inf = tvm.arith.int_set.pos_inf()
set_0 = tvm.arith.IntervalSet(min_value=neg_inf, max_value=0)
set_1 = tvm.arith.IntervalSet(min_value=1, max_value=pos_inf)
result = tvm.arith.int_set.union_lower_bound([set_0, set_1])
assert result.min_value.same_as(neg_inf)
assert result.max_value.same_as(pos_inf)
set_2 = tvm.arith.IntervalSet(min_value=pos_inf, max_value=neg_inf)
result = tvm.arith.int_set.union_lower_bound([set_0, set_1, set_2])
assert result.min_value.same_as(neg_inf)
assert result.max_value.same_as(pos_inf)
def test_modular_set():
ck = IntSetChecker()
x = tvm.te.var("x", dtype="int32")
y = tvm.te.var("y", dtype="int32")
expr = (x * 2048 + y * 16) % 7168
ck.verify(expr, {x: tvm.arith.IntervalSet(0, 128), y: tvm.arith.IntervalSet(0, 3584)}, (0, 7152))
if __name__ == "__main__":
tvm.testing.main()
This diff is collapsed.
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from tilelang import tvm
import tilelang.testing
from tvm import tir
import tvm.ir
def test_simplify_reshape_flattened_index():
ana = tvm.arith.Analyzer()
i0 = tir.Var("i0", "int64")
i1 = tir.Var("i1", "int64")
ana.bind(i0, tvm.ir.Range(0, 8))
ana.bind(i1, tvm.ir.Range(0, 3))
i_flattened = i0 * 3 + i1
tvm.ir.assert_structural_equal(
ana.simplify((i_flattened) // 12 * 12 + (i_flattened) % 12 // 4 * 4 + (i_flattened) % 4),
i_flattened,
)
dtype = tvm.testing.parameter(
"uint8",
"uint16",
"uint32",
"uint64",
"int8",
"int16",
"int32",
"int64",
"float16",
"float32",
"float64",
)
def test_can_prove_self_identity(dtype):
ana = tvm.arith.Analyzer()
n = tir.Var("n", dtype)
assert ana.can_prove(n == n)
def test_can_prove_self_equal_to_self(dtype):
ana = tvm.arith.Analyzer()
n = tir.Var("n", dtype)
assert ana.can_prove_equal(n, n)
def test_simplify_symbolic_comparison():
ana = tvm.arith.Analyzer()
i0 = tir.Var("i0", "int64")
i1 = tir.Var("i1", "int64")
n, m = tvm.tir.SizeVar("n", "int64"), tvm.tir.SizeVar("m", "int64")
outer = (n + 31) // 32
ana.bind(i0, tvm.ir.Range(0, outer))
ana.bind(i1, tvm.ir.Range(0, 32))
PS = tvm.arith.ProofStrength
assert ana.can_prove(i0 * 32 + i1 < (n + 31) // 32 * 32, PS.SYMBOLIC_BOUND)
assert ana.can_prove(i0 * 32 + i1 < (n + 31) // 32 * 32 + m, PS.SYMBOLIC_BOUND)
assert ana.can_prove(i0 * 32 + i1 + 1 <= (n + 31) // 32 * 32, PS.SYMBOLIC_BOUND)
assert ana.can_prove((n + 31) // 32 * 32 >= i0 * 32 + i1 + 1, PS.SYMBOLIC_BOUND)
assert ana.can_prove((n + 31) // 32 * 32 >= i0 * 32 + i1, PS.SYMBOLIC_BOUND)
def test_regression_simplify_inf_recursion():
ana = tvm.arith.Analyzer()
cond = tir.Var("cond", "int32")
res = (tvm.tir.NE(cond, 0).astype("int8") - tvm.tir.NE(cond, 0).astype("int8")).astype("int32") == 0
# regression in a previous case
# try compare and int set recursive call can cause infinite loop
ana.rewrite_simplify(res)
def test_simplify_floor_mod_with_linear_offset():
"""
Test that the floor_mod is simplified correctly when the offset is linear.
"""
ana = tvm.arith.Analyzer()
past_decoder_sequence_length = tir.Var("past_decoder_sequence_length", "int64")
expr1 = (past_decoder_sequence_length + 1) * 64
divisor1 = (past_decoder_sequence_length + 1) * 32
assert ana.can_prove_equal(tvm.tir.floormod(expr1, divisor1), 0)
divisor2 = 32 * (past_decoder_sequence_length + 1)
assert ana.can_prove_equal(tvm.tir.floormod(expr1, divisor2), 0)
def test_simplify_float_division():
# Test for the discussion:
# https://discuss.tvm.apache.org/t/discuss-is-constant-division-to-multiplication-rewrite-in-tvm-necessary/18615
ana = tvm.arith.Analyzer()
x = tir.Var("x", "float32")
ry = x / 27
# in old version, the division will be rewritten into x * T.float32(1 / 27)
sy = ana.rewrite_simplify(ry)
tvm.ir.assert_structural_equal(ry, sy)
if __name__ == "__main__":
tilelang.testing.main()
......@@ -216,6 +216,8 @@ def run_gemm_sp(
print("pass")
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version(9, 0)
def run_gemm_sp_sm90(
M,
N,
......@@ -228,8 +230,8 @@ def run_gemm_sp_sm90(
block_K,
num_stages,
num_threads,
trans_A=False,
trans_B=False,
trans_A,
trans_B,
):
kernel = matmul_sp_sm90(
M,
......@@ -259,6 +261,9 @@ def run_gemm_sp_sm90(
)
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(8, 0)
@tilelang.testing.requires_cuda_compute_version_le(8, 9)
def run_gemm_sp_sm80(
M,
N,
......@@ -271,8 +276,8 @@ def run_gemm_sp_sm80(
block_K,
num_stages,
num_threads,
trans_A=False,
trans_B=False,
trans_A,
trans_B,
):
kernel = matmul_sp_sm80(
M,
......
......@@ -41,34 +41,35 @@ def assert_vectorize_access(M: int = 64, N: int = 64):
tvm.ir.assert_structural_equal(transformed["main"].body, expected.body)
def issue_1013_buggy_kernel():
# NOTE: This kernel is mainly to test some corner cases in boundary check
num_tokens = T.dynamic("num_tokens")
num_threads = 128
@T.prim_func
def main(x: T.Tensor((num_tokens,), dtype="int64")):
with T.Kernel(1, threads=num_threads) as _:
count = T.alloc_var("int")
thread_idx = T.get_thread_binding()
for i in T.serial(0, T.ceildiv(num_tokens - thread_idx, num_threads)):
idx = thread_idx + i * num_threads
count += x[idx] == 2
# NOTE(chaofan): Ideally, the prover should be able to prove that the access is safe
# and the padding value is not used. However, the current prover cannot handle this case.
# So for now the expected kernel is a if-else statement to check the boundary.
@T.prim_func
def expected(x: T.Tensor((num_tokens,), dtype="int64")):
with T.Kernel(1, threads=num_threads) as _:
count = T.alloc_var("int")
thread_idx = T.get_thread_binding()
for i in T.serial(0, T.ceildiv(num_tokens - thread_idx, num_threads)):
idx = thread_idx + i * num_threads
count += T.Cast("int32", T.if_then_else(idx < num_tokens, x[idx], T.int64(0)) == T.int64(2))
return main, expected
# def issue_1013_buggy_kernel():
# # NOTE: This kernel is mainly to test some corner cases in boundary check
# num_tokens = T.dynamic('num_tokens')
# num_threads = 128
# @T.prim_func
# def main(x: T.Tensor((num_tokens,), dtype="int64")):
# with T.Kernel(1, threads=num_threads) as _:
# count = T.alloc_var('int')
# thread_idx = T.get_thread_binding()
# for i in T.serial(0, T.ceildiv(num_tokens - thread_idx, num_threads)):
# idx = thread_idx + i * num_threads
# count += x[idx] == 2
# # NOTE(chaofan): Ideally, the prover should be able to prove that the access is safe
# # and the padding value is not used. However, the current prover cannot handle this case.
# # So for now the expected kernel is a if-else statement to check the boundary.
# @T.prim_func
# def expected(x: T.Tensor((num_tokens,), dtype="int64")):
# with T.Kernel(1, threads=num_threads) as _:
# count = T.alloc_var('int')
# thread_idx = T.get_thread_binding()
# for i in T.serial(0, T.ceildiv(num_tokens - thread_idx, num_threads)):
# idx = thread_idx + i * num_threads
# count += T.Cast("int32",
# value=T.if_then_else(idx < num_tokens, x[idx], T.int64(0)) == T.int64(2))
# return main, expected
def vectorize_access_with_atmoic_add_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_offset: int = 2):
......@@ -151,11 +152,11 @@ def test_vectorize_access():
assert_vectorize_access(64, 64)
def test_issue_1013():
func, expected = issue_1013_buggy_kernel()
mod = tvm.IRModule({func.attrs["global_symbol"]: func})
transformed = tl.transform.LegalizeSafeMemoryAccess()(mod)
tvm.ir.assert_structural_equal(transformed["main"].body, expected.body)
# def test_issue_1013():
# func, expected = issue_1013_buggy_kernel()
# mod = tvm.IRModule({func.attrs["global_symbol"]: func})
# transformed = tl.transform.LegalizeSafeMemoryAccess()(mod)
# tvm.ir.assert_structural_equal(transformed["main"].body, expected.body)
def test_vectorize_access_with_atmoic_add():
......
......@@ -243,8 +243,8 @@ def reduce_bitxor(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: boo
@macro
def cumsum_fragment(
src: tir.Buffer | tir.BufferRegion | tir.BufferLoad,
dst: tir.Buffer | tir.BufferRegion | tir.BufferLoad,
src: tir.Buffer,
dst: tir.Buffer,
dim: int,
reverse: bool,
) -> tir.PrimExpr:
......
......@@ -21,7 +21,6 @@ def _get_cached_lib():
try:
return _import_module_from_library(name, _CACHE_DIR, is_python_module=True)
except Exception:
# If loading fails, recompile
pass
# Set TORCH_CUDA_ARCH_LIST
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment