Unverified Commit bccb6485 authored by Kuris's avatar Kuris Committed by GitHub
Browse files

[Feat] Add support for using `T.Tensor(n * 2 + 1)` in function annotation (#1285)



* [Feature] Add support for A: T.Tensor(n + 1) and A: T.Tensor(2*n)

* issue fix

* fix

* fix

* decrease nproc for debugging

---------
Co-authored-by: Lei Wang <leiwang1999@outlook.com>
parent bef7e52e
...@@ -352,7 +352,7 @@ jobs: ...@@ -352,7 +352,7 @@ jobs:
uv run --no-project -m -- uv run --no-project -m --
pytest --verbose --color=yes --durations=0 --showlocals --cache-clear pytest --verbose --color=yes --durations=0 --showlocals --cache-clear
) )
"${PYTEST[@]}" --maxfail=3 --numprocesses=4 \ "${PYTEST[@]}" --maxfail=3 --numprocesses=1 \
../examples ../examples
# NVIDIA CUDA tests # NVIDIA CUDA tests
......
# ruff: noqa # ruff: noqa
import tilelang
import tilelang.testing import tilelang.testing
import topk_selector import topk_selector
......
...@@ -29,8 +29,14 @@ ...@@ -29,8 +29,14 @@
#include <tvm/tir/op.h> #include <tvm/tir/op.h>
#include <sstream> #include <sstream>
#include <unordered_set>
#include "tir/transforms/ir_utils.h" #include "tir/transforms/ir_utils.h"
#include "tvm/arith/int_solver.h"
#include "tvm/ffi/cast.h"
#include "tvm/ffi/container/array.h"
#include "tvm/tir/stmt.h"
#include "tvm/tir/stmt_functor.h"
namespace tvm { namespace tvm {
namespace tl { namespace tl {
...@@ -51,6 +57,26 @@ void BinderAddAssert(arith::Analyzer *ana, PrimExpr cond, ...@@ -51,6 +57,26 @@ void BinderAddAssert(arith::Analyzer *ana, PrimExpr cond,
} }
} }
std::vector<Var> ArgBinder::getUndefVars(const std::vector<PrimExpr> &args) {
std::unordered_set<const VarNode *> visit;
std::vector<Var> res;
for (const auto &arg : args) {
PostOrderVisit(arg, [&](ObjectRef r) {
if (auto var = r.as<VarNode>()) {
if (!visit.count(var)) {
visit.insert(var);
}
auto it = def_map_->find(var);
if (it == def_map_->end()) {
// res.push_back(var);
res.push_back(ffi::GetRef<Var>(var));
}
}
});
}
return res;
}
bool ArgBinder::BindNullable(const PrimExpr &arg, const PrimExpr &value, bool ArgBinder::BindNullable(const PrimExpr &arg, const PrimExpr &value,
const std::string &arg_name, bool with_lets, const std::string &arg_name, bool with_lets,
const PrimExpr &nullable_guard) { const PrimExpr &nullable_guard) {
...@@ -60,20 +86,23 @@ bool ArgBinder::BindNullable(const PrimExpr &arg, const PrimExpr &value, ...@@ -60,20 +86,23 @@ bool ArgBinder::BindNullable(const PrimExpr &arg, const PrimExpr &value,
// is_null || basic // is_null || basic
return Or(nullable_guard, basic); return Or(nullable_guard, basic);
}; };
ICHECK_EQ(arg.dtype(), value.dtype()) << "arg " << arg << " value " << value; ICHECK_EQ(arg.dtype(), value.dtype()) << "arg " << arg << " value " << value;
auto BindVar = [&](const VarNode *v, PrimExpr value) {
auto v_arg = ffi::GetRef<Var>(v);
defs_.emplace_back(v_arg);
if (with_lets) {
(*def_map_)[v] = value;
init_nest_.emplace_back(LetStmt(v_arg, value, Evaluate(0)));
} else {
(*def_map_)[v] = value;
}
};
// 1. simple binding var = value
if (const VarNode *v = arg.as<VarNode>()) { if (const VarNode *v = arg.as<VarNode>()) {
auto it = def_map_->find(v); auto it = def_map_->find(v);
if (it == def_map_->end()) { if (it == def_map_->end()) {
BindVar(v, value);
// First time binding: identical behavior as Bind_ // First time binding: identical behavior as Bind_
Var v_arg = Downcast<Var>(arg);
defs_.emplace_back(v_arg);
if (with_lets) {
(*def_map_)[v] = arg;
init_nest_.emplace_back(LetStmt(v_arg, value, Evaluate(0)));
} else {
(*def_map_)[v] = value;
}
return true; return true;
} else { } else {
// Second or later binding: add is_null short-circuit // Second or later binding: add is_null short-circuit
...@@ -81,7 +110,34 @@ bool ArgBinder::BindNullable(const PrimExpr &arg, const PrimExpr &value, ...@@ -81,7 +110,34 @@ bool ArgBinder::BindNullable(const PrimExpr &arg, const PrimExpr &value,
BinderAddAssert(&analyzer_, cond, arg_name, &asserts_); BinderAddAssert(&analyzer_, cond, arg_name, &asserts_);
} }
} else { } else {
// For non-Var expressions, also add is_null short-circuit // 2. complex binding expr = value
// get undefined variables
auto undefs = ffi::Array<Var>(getUndefVars({arg}));
if (!undefs.empty()) {
// if value is not integer, such as float, we are unable to solve it
if (!value.dtype().is_int() && !value.dtype().is_uint()) {
LOG(FATAL) << "Unable to solve non-integer variables " << undefs
<< " from equation `" << value << "`";
}
arith::IntConstraints constraints(undefs, {}, {arg == value});
auto sol = arith::SolveLinearEquations(constraints);
if (!sol->dst->variables.empty()) {
LOG(FATAL) << "TVM is unable to solve variables " << undefs
<< " from equation " << constraints;
}
for (const auto &v : undefs) {
auto value_opt = sol->src_to_dst.Get(v);
ICHECK(value_opt->defined())
<< "Unable to solve variable `" << v << "` from expression `"
<< (arg == value) << "`";
auto value = ffi::GetRef<PrimExpr>(sol->src_to_dst.Get(v)->get());
BindVar(v.as<VarNode>(), value);
}
}
// we must add the assert again
// because the solved expression may contain floordiv (e.g. 3 * m == n
// ==> m = n // 3) we re-compute the constraint to verify the solution
// is correct
PrimExpr cond = MakeGuarded(arg == value); PrimExpr cond = MakeGuarded(arg == value);
BinderAddAssert(&analyzer_, cond, arg_name, &asserts_); BinderAddAssert(&analyzer_, cond, arg_name, &asserts_);
} }
......
...@@ -159,6 +159,7 @@ public: ...@@ -159,6 +159,7 @@ public:
const PrimExpr &nullable_guard); const PrimExpr &nullable_guard);
private: private:
std::vector<Var> getUndefVars(const std::vector<PrimExpr> &arg);
// Internal bind function // Internal bind function
bool Bind_(const PrimExpr &arg, const PrimExpr &value, bool Bind_(const PrimExpr &arg, const PrimExpr &value,
const std::string &arg_name, bool with_lets); const std::string &arg_name, bool with_lets);
......
...@@ -91,7 +91,9 @@ def run_gemm( ...@@ -91,7 +91,9 @@ def run_gemm(
code = f"// {stramp}\n" + code code = f"// {stramp}\n" + code
return code return code
tilelang.disable_cache()
matmul_kernel = tilelang.compile(program, out_idx=-1) matmul_kernel = tilelang.compile(program, out_idx=-1)
tilelang.enable_cache()
kernel_source = matmul_kernel.get_kernel_source() kernel_source = matmul_kernel.get_kernel_source()
......
...@@ -52,68 +52,6 @@ def matmul( ...@@ -52,68 +52,6 @@ def matmul(
return main return main
def run_gemm(
    M,
    N,
    K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    dtypeAccum,
    block_M,
    block_N,
    block_K,
    num_stages=3,
    num_threads=128,
):
    """Compile a matmul program through the tvm_ffi backend and verify that
    the registered CUDA post-processing callback stamped the kernel source."""
    program = matmul(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        trans_A,
        trans_B,
        in_dtype,
        out_dtype,
        dtypeAccum,
        num_stages,
        num_threads,
    )

    # Unique watermark the postproc hook injects into the generated source.
    watermark = "&*(XS)"

    @tvm.register_global_func("tilelang_callback_cuda_postproc", override=True)
    def tilelang_callback_cuda_postproc(code, _):
        return f"// {watermark}\n" + code

    matmul_kernel = tilelang.compile(program, out_idx=-1, execution_backend="tvm_ffi")
    kernel_source = matmul_kernel.get_kernel_source()
    assert watermark in kernel_source, f"Expected {watermark} in the kernel source"
def test_gemm_f16f16f16_nn():
    """GEMM smoke test: fp16 inputs/outputs/accumulator, no transposes."""
    run_gemm(
        M=512,
        N=1024,
        K=768,
        trans_A=False,
        trans_B=False,
        in_dtype="float16",
        out_dtype="float16",
        dtypeAccum="float16",
        block_M=128,
        block_N=256,
        block_K=32,
        num_stages=2,
    )
def matmu_jit_kernel( def matmu_jit_kernel(
M, M,
N, N,
......
import tilelang
import tilelang.language as T
import tilelang.testing
import torch
def test_tensor_annot_mul():
    """Annotation `T.Tensor((n * 4,))`: the symbolic dim must be solved from
    the runtime argument's shape, and the kernel must zero the whole buffer."""

    @tilelang.jit
    def build():
        n = T.symbolic('n')

        @T.prim_func
        def kernel(A: T.Tensor((n * 4,), T.int32),):
            with T.Kernel(1) as _:
                for i in range(n * 4):
                    A[i] = 0

        return kernel

    zero_fill = build()
    buf = torch.arange(16, dtype=torch.int32, device='cuda')
    zero_fill(buf)
    assert torch.equal(buf, torch.zeros_like(buf))
def test_tensor_annot_add():
    """Annotation `T.Tensor((n + 1,))`: additive symbolic shape expression is
    solved from the runtime argument's shape; kernel zeros the buffer."""

    @tilelang.jit
    def build():
        n = T.symbolic('n')

        @T.prim_func
        def kernel(A: T.Tensor((n + 1,), T.int32),):
            with T.Kernel(1) as _:
                for i in range(n + 1):
                    A[i] = 0

        return kernel

    zero_fill = build()
    buf = torch.arange(16, dtype=torch.int32, device='cuda')
    zero_fill(buf)
    assert torch.equal(buf, torch.zeros_like(buf))
def test_tensor_annot_mul_add():
    """Annotation `T.Tensor((n * 3 + 1,))`: combined multiply-add symbolic
    shape is solved from the runtime argument (16 = 3*5 + 1 => n = 5)."""

    @tilelang.jit
    def build():
        n = T.symbolic('n')

        @T.prim_func
        def kernel(A: T.Tensor((n * 3 + 1,), T.int32),):
            with T.Kernel(1) as _:
                for i in range(n * 3 + 1):
                    A[i] = 0

        return kernel

    zero_fill = build()
    buf = torch.arange(16, dtype=torch.int32, device='cuda')
    zero_fill(buf)
    assert torch.equal(buf, torch.zeros_like(buf))
if __name__ == '__main__':
tilelang.testing.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment