Unverified Commit 17cfeb76 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Enhancement] Improve error handling and assertion messages across runtime and...

[Enhancement] Improve error handling and assertion messages across runtime and argument binding (#1356)

This commit enhances the error handling mechanisms in the runtime by introducing CPU-safe runtime helpers and refining assertion messages in the CodeGenCHost and ArgBinder. It includes structured packed error messages for various conditions, improving clarity in diagnostics. Additionally, the CMake configuration is updated to always include necessary runtime helpers, ensuring consistent error reporting. The changes aim to provide clearer feedback during runtime errors and improve the overall robustness of the argument binding process.
parent 36a2b2f3
Subproject commit e3af400013551755a8df668ba77b530735931ade Subproject commit fc7ed0b9cb7a52eb1c8bf6e8c26bbb8dff3655ce
...@@ -145,6 +145,11 @@ file(GLOB TILE_LANG_SRCS ...@@ -145,6 +145,11 @@ file(GLOB TILE_LANG_SRCS
src/target/intrin_rule*.cc src/target/intrin_rule*.cc
) )
# Always include CPU-safe runtime helpers so the structured packed error
# reporting functions are available regardless of which backend (if any)
# is selected below.
list(APPEND TILE_LANG_SRCS
  src/runtime/error_helpers.cc
)
# Track if the user explicitly selected a backend via cache options. # Track if the user explicitly selected a backend via cache options.
set(TILELANG_BACKEND_USER_SELECTED OFF) set(TILELANG_BACKEND_USER_SELECTED OFF)
foreach(BACKEND IN LISTS TILELANG_BACKENDS) foreach(BACKEND IN LISTS TILELANG_BACKENDS)
...@@ -206,7 +211,7 @@ elseif(USE_CUDA) ...@@ -206,7 +211,7 @@ elseif(USE_CUDA)
cmake_path(GET CUDAToolkit_BIN_DIR PARENT_PATH USE_CUDA) cmake_path(GET CUDAToolkit_BIN_DIR PARENT_PATH USE_CUDA)
file(GLOB TILE_LANG_CUDA_SRCS file(GLOB TILE_LANG_CUDA_SRCS
src/runtime/*.cc src/runtime/runtime.cc
src/target/ptx.cc src/target/ptx.cc
src/target/codegen_cuda.cc src/target/codegen_cuda.cc
src/target/rt_mod_cuda.cc src/target/rt_mod_cuda.cc
......
/* /*
* Helper functions for nicer runtime error messages. * Helper functions for nicer runtime error messages.
*/ */
#include "error_helpers.h"
#include <tvm/ffi/c_api.h> #include <tvm/ffi/c_api.h>
#include <tvm/ffi/error.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h> #include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/data_type.h> #include <tvm/runtime/data_type.h>
#include <tvm/runtime/device_api.h>
#include <sstream> #include <sstream>
#include <string> #include <string>
...@@ -25,8 +30,9 @@ static int DTypeMismatch(const tvm::ffi::String &kernel_name, ...@@ -25,8 +30,9 @@ static int DTypeMismatch(const tvm::ffi::String &kernel_name,
static_cast<int>(expect_bits), static_cast<int>(expect_bits),
static_cast<int>(expect_lanes)); static_cast<int>(expect_lanes));
std::ostringstream os; std::ostringstream os;
os << std::string(kernel_name) << ": dtype of " << std::string(buffer_name) os << "kernel " << std::string(kernel_name) << " input "
<< " is expected to be " << expect << ", but got " << actual; << std::string(buffer_name) << " dtype expected " << expect << ", but got "
<< actual;
TVMFFIErrorSetRaisedFromCStr("RuntimeError", os.str().c_str()); TVMFFIErrorSetRaisedFromCStr("RuntimeError", os.str().c_str());
return -1; return -1;
} }
...@@ -48,13 +54,169 @@ static int DTypeMismatchNoNames(int64_t actual_code, int64_t actual_bits, ...@@ -48,13 +54,169 @@ static int DTypeMismatchNoNames(int64_t actual_code, int64_t actual_bits,
return -1; return -1;
} }
} // namespace tl // Register packed versions, following the design in runtime.cc
} // namespace tvm
TVM_FFI_STATIC_INIT_BLOCK() { TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
  // Packed: __tvm_error_dtype_mismatch(kernel_name, buffer_name,
  //                                    actual_code, actual_bits, actual_lanes,
  //                                    expect_code, expect_bits, expect_lanes)
  // Raises a RuntimeError describing a dtype mismatch for a kernel input.
  refl::GlobalDef().def_packed(
      tl::tvm_error_dtype_mismatch,
      [](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
        ICHECK(args.size() == 8) << "Expected 8 args: kernel, buffer, "
                                    "actual_code, actual_bits, actual_lanes, "
                                 << "expect_code, expect_bits, expect_lanes";
        auto kernel_name = args[0].cast<tvm::ffi::String>();
        auto buffer_name = args[1].cast<tvm::ffi::String>();
        // Actual dtype triplet (code/bits/lanes) observed at runtime.
        int64_t actual_code = args[2].cast<int64_t>();
        int64_t actual_bits = args[3].cast<int64_t>();
        int64_t actual_lanes = args[4].cast<int64_t>();
        // Expected dtype triplet from the kernel signature.
        int64_t expect_code = args[5].cast<int64_t>();
        int64_t expect_bits = args[6].cast<int64_t>();
        int64_t expect_lanes = args[7].cast<int64_t>();
        // Reuse the helper to format the message; it records the raised FFI
        // error itself, so its return value is intentionally discarded.
        (void)DTypeMismatch(kernel_name, buffer_name, actual_code, actual_bits,
                            actual_lanes, expect_code, expect_bits,
                            expect_lanes);
        // Provide a return value for completeness, then signal the error
        // that was already set via the FFI error machinery.
        *ret = -1;
        throw ::tvm::ffi::EnvErrorAlreadySet();
      });
// kernel, buffer, expect:int64, got:int64
refl::GlobalDef().def_packed(
tl::tvm_error_ndim_mismatch,
[](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
ICHECK(args.size() == 4)
<< "__tvm_error_ndim_mismatch(kernel, buffer, expect, got)";
auto kernel = args[0].cast<tvm::ffi::String>();
auto buffer = args[1].cast<tvm::ffi::String>();
int64_t expect = args[2].cast<int64_t>();
int64_t got = args[3].cast<int64_t>();
std::ostringstream os;
os << "kernel " << std::string(kernel) << " input "
<< std::string(buffer) << " ndim expected " << expect << ", but got "
<< got;
TVMFFIErrorSetRaisedFromCStr("RuntimeError", os.str().c_str());
*ret = -1;
throw ::tvm::ffi::EnvErrorAlreadySet();
});
  // __tvm_error_byte_offset_mismatch(kernel, buffer, expect:int64, got:int64)
  // Raises a RuntimeError reporting a DLTensor byte_offset mismatch.
  refl::GlobalDef().def_packed(
      tl::tvm_error_byte_offset_mismatch,
      [](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
        ICHECK(args.size() == 4)
            << "__tvm_error_byte_offset_mismatch(kernel, buffer, expect, got)";
        auto kernel = args[0].cast<tvm::ffi::String>();
        auto buffer = args[1].cast<tvm::ffi::String>();
        int64_t expect = args[2].cast<int64_t>();
        int64_t got = args[3].cast<int64_t>();
        std::ostringstream os;
        os << "kernel " << std::string(kernel) << " input "
           << std::string(buffer) << " byte_offset expected " << expect
           << ", but got " << got;
        // Record the formatted message as the raised FFI error, then signal
        // the pending error to the caller.
        TVMFFIErrorSetRaisedFromCStr("RuntimeError", os.str().c_str());
        *ret = -1;
        throw ::tvm::ffi::EnvErrorAlreadySet();
      });
  // __tvm_error_device_type_mismatch(kernel, buffer, expect:int64, got:int64)
  // Raises a RuntimeError reporting a device-type mismatch; numeric DLPack
  // device codes are translated to readable names for the message.
  refl::GlobalDef().def_packed(
      tl::tvm_error_device_type_mismatch,
      [](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
        ICHECK(args.size() == 4)
            << "__tvm_error_device_type_mismatch(kernel, buffer, expect, got)";
        auto kernel = args[0].cast<tvm::ffi::String>();
        auto buffer = args[1].cast<tvm::ffi::String>();
        int64_t expect = args[2].cast<int64_t>();
        int64_t got = args[3].cast<int64_t>();
        // Map DLPack device codes (e.g. 1=CPU, 2=CUDA) to their names.
        const char *expect_str =
            tvm::runtime::DLDeviceType2Str(static_cast<int>(expect));
        const char *got_str =
            tvm::runtime::DLDeviceType2Str(static_cast<int>(got));
        std::ostringstream os;
        os << "kernel " << std::string(kernel) << " input "
           << std::string(buffer) << " device_type expected " << expect_str
           << ", but got " << got_str;
        TVMFFIErrorSetRaisedFromCStr("RuntimeError", os.str().c_str());
        *ret = -1;
        throw ::tvm::ffi::EnvErrorAlreadySet();
      });
  // __tvm_error_null_ptr(kernel, buffer, field:String)
  // Raises a RuntimeError reporting that a field of a kernel input (e.g. its
  // data pointer) was unexpectedly NULL.
  refl::GlobalDef().def_packed(
      tl::tvm_error_null_ptr,
      [](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
        ICHECK(args.size() == 3)
            << "__tvm_error_null_ptr(kernel, buffer, field)";
        auto kernel = args[0].cast<tvm::ffi::String>();
        auto buffer = args[1].cast<tvm::ffi::String>();
        auto field = args[2].cast<tvm::ffi::String>();
        std::ostringstream os;
        os << "kernel " << std::string(kernel) << " input "
           << std::string(buffer) << ' ' << std::string(field)
           << " expected non-NULL, but got NULL";
        TVMFFIErrorSetRaisedFromCStr("RuntimeError", os.str().c_str());
        *ret = -1;
        throw ::tvm::ffi::EnvErrorAlreadySet();
      });
  // __tvm_error_expect_eq(kernel, buffer, field:String, expect:int64,
  //                       got:int64)
  // Raises a RuntimeError reporting that an integer field of a kernel input
  // did not match its expected value.
  refl::GlobalDef().def_packed(
      tl::tvm_error_expect_eq,
      [](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
        ICHECK(args.size() == 5)
            << "__tvm_error_expect_eq(kernel, buffer, field, expect, got)";
        auto kernel = args[0].cast<tvm::ffi::String>();
        auto buffer = args[1].cast<tvm::ffi::String>();
        auto field = args[2].cast<tvm::ffi::String>();
        int64_t expect = args[3].cast<int64_t>();
        int64_t got = args[4].cast<int64_t>();
        std::ostringstream os;
        os << "kernel " << std::string(kernel) << " input "
           << std::string(buffer) << ' ' << std::string(field) << " expected "
           << expect << ", but got " << got;
        TVMFFIErrorSetRaisedFromCStr("RuntimeError", os.str().c_str());
        *ret = -1;
        throw ::tvm::ffi::EnvErrorAlreadySet();
      });
  // __tvm_error_constraint_violation(kernel, buffer, field:String
  //                                  [, reason:String])
  // Raises a RuntimeError for a generic constraint failure; the optional
  // reason string is appended to the message when provided.
  refl::GlobalDef().def_packed(
      tl::tvm_error_constraint_violation,
      [](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
        ICHECK(args.size() == 3 || args.size() == 4)
            << "__tvm_error_constraint_violation(kernel, buffer, field[, "
               "reason])";
        auto kernel = args[0].cast<tvm::ffi::String>();
        auto buffer = args[1].cast<tvm::ffi::String>();
        auto field = args[2].cast<tvm::ffi::String>();
        // Optional trailing reason argument; empty means "not supplied".
        std::string reason;
        if (args.size() == 4) {
          reason = args[3].cast<tvm::ffi::String>();
        }
        std::ostringstream os;
        os << "kernel " << std::string(kernel) << " input "
           << std::string(buffer) << ' ' << std::string(field)
           << " constraint not satisfied";
        if (!reason.empty()) {
          os << ": " << reason;
        }
        TVMFFIErrorSetRaisedFromCStr("RuntimeError", os.str().c_str());
        *ret = -1;
        throw ::tvm::ffi::EnvErrorAlreadySet();
      });
// Legacy typed registrations for backward compatibility
refl::GlobalDef().def("tilelang_error_dtype_mismatch", refl::GlobalDef().def("tilelang_error_dtype_mismatch",
&tvm::tl::DTypeMismatch); &tvm::tl::DTypeMismatch);
refl::GlobalDef().def("tilelang_error_dtype_mismatch2", refl::GlobalDef().def("tilelang_error_dtype_mismatch2",
&tvm::tl::DTypeMismatchNoNames); &tvm::tl::DTypeMismatchNoNames);
} }
} // namespace tl
} // namespace tvm
/*!
 * \file tl/runtime/error_helpers.h
 * \brief Error helper FFI names for TileLang runtime.
 *
 * Each constant below is the registered global name of a packed FFI
 * function that formats a structured runtime diagnostic (dtype/ndim/
 * byte_offset/device_type mismatches, NULL pointers, and generic
 * constraint violations) and raises it as a RuntimeError.
 */
#ifndef TVM_TL_RUNTIME_ERROR_HELPERS_H_
#define TVM_TL_RUNTIME_ERROR_HELPERS_H_

namespace tvm {
namespace tl {

// Error helper packed function names.
// `inline constexpr` (C++17 inline variables) gives each constant a single
// program-wide definition with external linkage, so every translation unit
// that includes this header observes the same pointer value instead of a
// per-TU internal-linkage copy.
inline constexpr const char *tvm_error_dtype_mismatch =
    "__tvm_error_dtype_mismatch";
inline constexpr const char *tvm_error_ndim_mismatch =
    "__tvm_error_ndim_mismatch";
inline constexpr const char *tvm_error_byte_offset_mismatch =
    "__tvm_error_byte_offset_mismatch";
inline constexpr const char *tvm_error_device_type_mismatch =
    "__tvm_error_device_type_mismatch";
inline constexpr const char *tvm_error_null_ptr = "__tvm_error_null_ptr";
inline constexpr const char *tvm_error_expect_eq = "__tvm_error_expect_eq";
inline constexpr const char *tvm_error_constraint_violation =
    "__tvm_error_constraint_violation";

} // namespace tl
} // namespace tvm

#endif // TVM_TL_RUNTIME_ERROR_HELPERS_H_
...@@ -354,32 +354,44 @@ void CodeGenCHost::VisitStmt_(const tvm::tir::AssertStmtNode *op) { // NOLINT(*) ...@@ -354,32 +354,44 @@ void CodeGenCHost::VisitStmt_(const tvm::tir::AssertStmtNode *op) { // NOLINT(*)
stream << "if (!(" << cond << ")) {\n"; stream << "if (!(" << cond << ")) {\n";
int assert_if_scope = this->BeginScope(); int assert_if_scope = this->BeginScope();
{ {
// Prepare the base error message // Prepare the base error message: allow StringImm or general PrimExpr
const auto *msg_node = op->message.as<tvm::tir::StringImmNode>(); const auto *msg_node = op->message.as<tvm::tir::StringImmNode>();
ICHECK(msg_node != nullptr) << "Assert message expected to be StringImm"; bool msg_is_literal = (msg_node != nullptr);
const std::string &raw_msg = msg_node->value; std::string esc_msg;
const std::string esc_msg = tvm::support::StrEscape( std::string msg_expr;
raw_msg.c_str(), raw_msg.length(), /*use_octal_escape=*/true, if (msg_is_literal) {
/*escape_whitespace_special_chars=*/true); const std::string &raw_msg = msg_node->value;
esc_msg = tvm::support::StrEscape(
// If the assertion is an equality check, append the actual LHS/RHS values raw_msg.c_str(), raw_msg.length(), /*use_octal_escape=*/true,
if (const auto *eq = op->condition.as<tvm::tir::EQNode>()) { /*escape_whitespace_special_chars=*/true);
std::string lhs = PrintExpr(eq->a); } else {
std::string rhs = PrintExpr(eq->b); msg_expr = PrintExpr(op->message);
PrintIndent(); }
stream << "char __tvm_assert_msg_buf[512];\n";
PrintIndent(); // Only print expected/got values for equality when message is StringImm
stream << "snprintf(__tvm_assert_msg_buf, 512, \"%s; expected: %lld, " if (msg_is_literal) {
"got: %lld\", \"" if (const auto *eq = op->condition.as<tvm::tir::EQNode>()) {
<< esc_msg << "\", (long long)(" << lhs << "), (long long)(" std::string lhs = PrintExpr(eq->a);
<< rhs << "));\n"; std::string rhs = PrintExpr(eq->b);
PrintIndent(); PrintIndent();
stream << "TVMFFIErrorSetRaisedFromCStr(\"RuntimeError\", " stream << "char __tvm_assert_msg_buf[512];\n";
"__tvm_assert_msg_buf);\n"; PrintIndent();
stream << "snprintf(__tvm_assert_msg_buf, 512, \"%s; expected: %lld, "
"got: %lld\", \""
<< esc_msg << "\", (long long)(" << lhs << "), (long long)("
<< rhs << "));\n";
PrintIndent();
stream << "TVMFFIErrorSetRaisedFromCStr(\"RuntimeError\", "
"__tvm_assert_msg_buf);\n";
} else {
PrintIndent();
stream << "TVMFFIErrorSetRaisedFromCStr(\"RuntimeError\", \""
<< esc_msg << "\");\n";
}
} else { } else {
PrintIndent(); PrintIndent();
stream << "TVMFFIErrorSetRaisedFromCStr(\"RuntimeError\", \"" << esc_msg stream << "TVMFFIErrorSetRaisedFromCStr(\"RuntimeError\", " << msg_expr
<< "\");\n"; << ");\n";
} }
} }
PrintIndent(); PrintIndent();
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
#include <sstream> #include <sstream>
#include <unordered_set> #include <unordered_set>
#include "../runtime/error_helpers.h"
#include "tir/transforms/ir_utils.h" #include "tir/transforms/ir_utils.h"
#include "tvm/arith/int_solver.h" #include "tvm/arith/int_solver.h"
#include "tvm/ffi/cast.h" #include "tvm/ffi/cast.h"
...@@ -35,22 +36,58 @@ void BinderAddAssert(arith::Analyzer *ana, PrimExpr cond, ...@@ -35,22 +36,58 @@ void BinderAddAssert(arith::Analyzer *ana, PrimExpr cond,
} }
if (!is_one(scond)) { if (!is_one(scond)) {
std::ostringstream os; // Extract kernel/buffer/field from arg_name (e.g., "main.A.shape[0]")
os << "Argument " << arg_name << " has an unsatisfied constraint: " << cond; std::string kernel = arg_name;
std::string buf_and_field = arg_name;
// Check if the condition is of the form "is_null || actual_cond" size_t dot_pos = arg_name.find('.');
// If so, generate "if !is_null: assert actual_cond" instead of "assert if (dot_pos != std::string::npos) {
// is_null || actual_cond" kernel = arg_name.substr(0, dot_pos);
if (nullable_guard.defined()) { buf_and_field = arg_name.substr(dot_pos + 1);
// Pattern: nullable_guard || actual_condition }
// We want to transform this into: if !nullable_guard: assert std::string buffer = buf_and_field;
// actual_condition std::string field;
Stmt check = AssertStmt(scond, StringImm(os.str()), Evaluate(0)); size_t dot2 = buf_and_field.find('.');
check = IfThenElse(Not(nullable_guard), check); if (dot2 != std::string::npos) {
asserts->emplace_back(SeqStmt({check, Evaluate(0)})); buffer = buf_and_field.substr(0, dot2);
field = buf_and_field.substr(dot2 + 1);
}
// If cond is an equality, prefer structured packed error with expect/got
if (const auto *eq = scond.as<tvm::tir::EQNode>()) {
PrimExpr lhs = eq->a;
PrimExpr rhs = eq->b;
// Choose rhs as expected and lhs as got for better semantics in most
// binding cases
ffi::Array<PrimExpr> pargs;
pargs.push_back(StringImm(tvm_error_expect_eq));
pargs.push_back(StringImm(kernel));
pargs.push_back(StringImm(buffer));
pargs.push_back(StringImm(field.empty() ? std::string("value") : field));
pargs.push_back(cast(DataType::Int(64), rhs)); // expected
pargs.push_back(cast(DataType::Int(64), lhs)); // got
Stmt call_err =
Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), pargs));
// Only emit at runtime when the equality fails
Stmt inner = IfThenElse(Not(scond), call_err);
if (nullable_guard.defined()) {
inner = IfThenElse(Not(nullable_guard), inner);
}
asserts->emplace_back(SeqStmt({inner, Evaluate(0)}));
} else { } else {
asserts->emplace_back( // Fallback: packed generic constraint violation without dumping cond
AssertStmt(scond, StringImm(os.str()), Evaluate(0))); ffi::Array<PrimExpr> pargs;
pargs.push_back(StringImm(tvm_error_constraint_violation));
pargs.push_back(StringImm(kernel));
pargs.push_back(StringImm(buffer));
pargs.push_back(StringImm(field.empty() ? std::string("value") : field));
Stmt call_err =
Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), pargs));
Stmt inner = IfThenElse(Not(scond), call_err);
if (nullable_guard.defined()) {
inner = IfThenElse(Not(nullable_guard), inner);
}
asserts->emplace_back(SeqStmt({inner, Evaluate(0)}));
} }
} }
} }
...@@ -318,22 +355,29 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type, ...@@ -318,22 +355,29 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
PrimExpr a_ndim = PrimExpr a_ndim =
make_const(tvm_ndim_type, static_cast<int64_t>(buffer->shape.size())); make_const(tvm_ndim_type, static_cast<int64_t>(buffer->shape.size()));
std::ostringstream ndim_err_msg; // Build clearer ndim message with kernel/buffer names
// Note: We cannot embed runtime values into the message string. std::string kernel_nm = arg_name;
// Keep message human-friendly without printing TIR exprs. std::string buf_nm = arg_name;
ndim_err_msg << arg_name << ".ndim is expected to equal " size_t dot_pos = arg_name.find('.');
<< buffer->shape.size() << ", but got mismatched ndim"; if (dot_pos != std::string::npos) {
auto msg = StringImm(ndim_err_msg.str()); kernel_nm = arg_name.substr(0, dot_pos);
// Only check ndim when handle is non-NULL (using if statement) buf_nm = arg_name.substr(dot_pos + 1);
Stmt ndim_check = AssertStmt(a_ndim == v_ndim, msg, nop); }
ndim_check = IfThenElse(Not(is_null), ndim_check); // Only check ndim when handle is non-NULL: use packed error helper
init_nest_.emplace_back(SeqStmt({ndim_check, nop})); PrimExpr ndim_ok = (a_ndim == v_ndim);
ffi::Array<PrimExpr> ndim_args;
ndim_args.push_back(StringImm(tvm_error_ndim_mismatch));
ndim_args.push_back(StringImm(kernel_nm));
ndim_args.push_back(StringImm(buf_nm));
ndim_args.push_back(cast(DataType::Int(64), a_ndim));
ndim_args.push_back(cast(DataType::Int(64), v_ndim));
Stmt ndim_call =
Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), ndim_args));
init_nest_.emplace_back(
SeqStmt({IfThenElse(Not(is_null), IfThenElse(Not(ndim_ok), ndim_call),
Evaluate(0)),
nop}));
// type checks // type checks
std::ostringstream type_err_msg;
// Avoid dumping TIR expressions in error text; just state mismatch.
// Include expected dtype triplet for clarity.
type_err_msg << arg_name << ".dtype is expected to be " << buffer->dtype
<< ", but got incompatible dtype";
// Guard all dtype field loads by `is_null` using if_then_else // Guard all dtype field loads by `is_null` using if_then_else
PrimExpr v_type_code = tvm::if_then_else( PrimExpr v_type_code = tvm::if_then_else(
Not(is_null), Not(is_null),
...@@ -402,11 +446,36 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type, ...@@ -402,11 +446,36 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
if (!(buffer->dtype == DataType::Int(1) || if (!(buffer->dtype == DataType::Int(1) ||
buffer->dtype == DataType::Int(4) || buffer->dtype == DataType::Int(4) ||
buffer->dtype == DataType::UInt(4))) { buffer->dtype == DataType::UInt(4))) {
auto type_msg = StringImm(type_err_msg.str()); // Build FFI packed call to __tvm_error_dtype_mismatch when mismatch occurs.
// Only check dtype when handle is non-NULL (using if statement) // Only issue the call when handle is non-NULL and cond is false.
Stmt dtype_check = AssertStmt(cond, type_msg, nop); ffi::Array<PrimExpr> packed_args;
dtype_check = IfThenElse(Not(is_null), dtype_check); packed_args.push_back(StringImm(tvm_error_dtype_mismatch));
asserts_.emplace_back(SeqStmt({dtype_check, nop})); // Split arg_name of the form "<kernel>.<buffer>" into parts for clearer
// diagnostics
std::string kernel_name = arg_name;
std::string buffer_name = arg_name;
size_t dot_pos = arg_name.find('.');
if (dot_pos != std::string::npos) {
kernel_name = arg_name.substr(0, dot_pos);
buffer_name = arg_name.substr(dot_pos + 1);
}
packed_args.push_back(StringImm(kernel_name));
packed_args.push_back(StringImm(buffer_name));
auto i64 = DataType::Int(64);
// Cast to int64 for FFI function signature
packed_args.push_back(cast(i64, v_type_code)); // actual_code
packed_args.push_back(cast(i64, v_type_bits)); // actual_bits
packed_args.push_back(cast(i64, v_type_lanes)); // actual_lanes
packed_args.push_back(cast(i64, expect_code)); // expect_code
packed_args.push_back(cast(i64, expect_bits)); // expect_bits
packed_args.push_back(cast(i64, expect_lanes)); // expect_lanes
Stmt call_err = Evaluate(
Call(DataType::Int(32), builtin::tvm_call_packed(), packed_args));
// Guard the call: only when handle is not null and cond fails
Stmt guarded = IfThenElse(Not(is_null) && Not(cond), call_err);
asserts_.emplace_back(SeqStmt({guarded, nop}));
} }
// shape field // shape field
...@@ -482,14 +551,27 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type, ...@@ -482,14 +551,27 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
<< stride_handle_name() << stride_handle_name()
<< ": expected to be compact array, but got non-compact strides"; << ": expected to be compact array, but got non-compact strides";
if (!conds.empty()) { if (!conds.empty()) {
auto stride_msg = StringImm(stride_err_msg.str()); PrimExpr all_ok = foldl([](PrimExpr a, PrimExpr b,
Stmt check = Span span) { return logical_and(a, b, span); },
AssertStmt(foldl([](PrimExpr a, PrimExpr b, const_true(1), conds);
Span span) { return logical_and(a, b, span); }, // Packed generic violation for non-compact strides
const_true(1), conds), std::string kernel_nm3 = arg_name;
stride_msg, Evaluate(0)); std::string buf_nm3 = arg_name;
// Only check when strides array is actually present at runtime size_t dot_pos3 = arg_name.find('.');
check = IfThenElse(Not(v_strides_is_null), check); if (dot_pos3 != std::string::npos) {
kernel_nm3 = arg_name.substr(0, dot_pos3);
buf_nm3 = arg_name.substr(dot_pos3 + 1);
}
ffi::Array<PrimExpr> pargs4;
pargs4.push_back(StringImm(tvm_error_constraint_violation));
pargs4.push_back(StringImm(kernel_nm3));
pargs4.push_back(StringImm(buf_nm3));
pargs4.push_back(StringImm("strides"));
Stmt call_err4 =
Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), pargs4));
// Only check when strides array is present and condition fails
Stmt check = IfThenElse(Not(v_strides_is_null),
IfThenElse(Not(all_ok), call_err4), Evaluate(0));
asserts_.emplace_back(SeqStmt({check, Evaluate(0)})); asserts_.emplace_back(SeqStmt({check, Evaluate(0)}));
} }
} else if (buffer->buffer_type == kAutoBroadcast) { } else if (buffer->buffer_type == kAutoBroadcast) {
...@@ -539,11 +621,18 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type, ...@@ -539,11 +621,18 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
make_const(DataType::UInt(64), 0)); make_const(DataType::UInt(64), 0));
PrimExpr expect_byte_offset = PrimExpr expect_byte_offset =
make_const(DataType::UInt(64), const_offset->value * data_bytes); make_const(DataType::UInt(64), const_offset->value * data_bytes);
Stmt byte_off_check = PrimExpr ok = (expect_byte_offset == actual_byte_offset);
AssertStmt(expect_byte_offset == actual_byte_offset, ffi::Array<PrimExpr> pargs;
StringImm(arg_name + ".byte_offset mismatch"), nop); pargs.push_back(StringImm(tvm_error_byte_offset_mismatch));
byte_off_check = IfThenElse(Not(is_null), byte_off_check); pargs.push_back(StringImm(kernel_nm));
asserts_.emplace_back(SeqStmt({byte_off_check, nop})); pargs.push_back(StringImm(buf_nm));
pargs.push_back(cast(DataType::Int(64), expect_byte_offset));
pargs.push_back(cast(DataType::Int(64), actual_byte_offset));
Stmt call_err =
Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), pargs));
asserts_.emplace_back(SeqStmt(
{IfThenElse(Not(is_null), IfThenElse(Not(ok), call_err), Evaluate(0)),
nop}));
} else { } else {
PrimExpr actual_byte_offset = tvm::if_then_else( PrimExpr actual_byte_offset = tvm::if_then_else(
Not(is_null), Not(is_null),
...@@ -582,21 +671,18 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type, ...@@ -582,21 +671,18 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
// Check device_type consistency (device_id equality is implicitly ensured by // Check device_type consistency (device_id equality is implicitly ensured by
// binding above) // binding above)
{ {
std::ostringstream dev_msg; PrimExpr ok = (device_type == actual_dev_type);
dev_msg << arg_name << ".device_type mismatch"; ffi::Array<PrimExpr> pargs2;
if (const auto *imm = device_type.as<IntImmNode>()) { pargs2.push_back(StringImm(tvm_error_device_type_mismatch));
dev_msg << " [expected: " << imm->value << " (" pargs2.push_back(StringImm(kernel_nm));
<< tvm::runtime::DLDeviceType2Str(static_cast<int>(imm->value)) pargs2.push_back(StringImm(buf_nm));
<< ")]"; pargs2.push_back(cast(DataType::Int(64), device_type));
} pargs2.push_back(cast(DataType::Int(64), actual_dev_type));
// Give a short legend so users can interpret numeric codes in the Stmt call_err2 =
// appended "got/expected" part printed by the runtime. Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), pargs2));
dev_msg << "; DLPack codes: 1=CPU, 2=CUDA, 7=Vulkan, 8=Metal, 10=ROCM, " asserts_.emplace_back(SeqStmt(
"14=OneAPI, 15=WebGPU"; {IfThenElse(Not(is_null), IfThenElse(Not(ok), call_err2), Evaluate(0)),
auto device_type_check = Evaluate(0)}));
IfThenElse(Not(is_null), AssertStmt(device_type == actual_dev_type,
StringImm(dev_msg.str()), nop));
asserts_.emplace_back(SeqStmt({device_type_check, Evaluate(0)}));
} }
// Data field. Because the validation of the data field may depend // Data field. Because the validation of the data field may depend
...@@ -619,14 +705,31 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type, ...@@ -619,14 +705,31 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
product *= dim; product *= dim;
return product; return product;
}(); }();
Stmt data_null_check = AssertStmt( // Improve message: kernel/buffer naming for data pointer null check
(alloc_size == 0) || std::string kernel_nm2 = arg_name;
!Call(DataType::Bool(), builtin::isnullptr(), {vptr}), std::string buf_nm2 = arg_name;
StringImm(arg_name + size_t dot_pos2 = arg_name.find('.');
" is expected to have non-NULL data pointer, but got NULL"), if (dot_pos2 != std::string::npos) {
nop); kernel_nm2 = arg_name.substr(0, dot_pos2);
data_null_check = IfThenElse(Not(is_null), data_null_check); buf_nm2 = arg_name.substr(dot_pos2 + 1);
asserts_.emplace_back(SeqStmt({data_null_check, nop})); }
// expand combined condition via nested IfThenElse for portability
ffi::Array<PrimExpr> pargs3;
pargs3.push_back(StringImm(tvm_error_null_ptr));
pargs3.push_back(StringImm(kernel_nm2));
pargs3.push_back(StringImm(buf_nm2));
pargs3.push_back(StringImm("data pointer"));
Stmt call_err3 =
Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), pargs3));
asserts_.emplace_back(SeqStmt(
{IfThenElse(Not(is_null),
IfThenElse(Not(alloc_size == 0),
IfThenElse(Call(DataType::Bool(),
builtin::isnullptr(), {vptr}),
call_err3),
Evaluate(0)),
Evaluate(0)),
nop}));
// mark alignment of external bufs // mark alignment of external bufs
init_nest_.emplace_back( init_nest_.emplace_back(
......
...@@ -44,28 +44,52 @@ private: ...@@ -44,28 +44,52 @@ private:
PrimExpr simplified = analyzer_.Simplify(indices[i]); PrimExpr simplified = analyzer_.Simplify(indices[i]);
IndexSignState state = IndexSignState::kUnknown; IndexSignState state = IndexSignState::kUnknown;
// Handle vector patterns first to avoid querying lanes() on // Handle scalar indices with the standard analyzer
// scalable vectors (which is not allowed at compile-time). if (simplified.dtype().lanes() == 1) {
if (const auto *ramp = simplified.as<RampNode>()) { if (analyzer_.CanProve(simplified >= 0))
// For scalable vectors, we cannot rely on a constant lane count.
// Use sufficient (but not necessary) conditions:
// - If base >= 0 and stride >= 0, all lanes are non-negative.
// - If base < 0 and stride <= 0, all lanes are negative.
bool base_nonneg = analyzer_.CanProve(ramp->base >= 0);
bool base_neg = analyzer_.CanProve(ramp->base < 0);
bool stride_nonneg = analyzer_.CanProve(ramp->stride >= 0);
bool stride_nonpos = analyzer_.CanProve(ramp->stride <= 0);
if (base_nonneg && stride_nonneg) {
state = IndexSignState::kNonNegative; state = IndexSignState::kNonNegative;
} else if (base_neg && stride_nonpos) { else if (analyzer_.CanProve(simplified < 0))
state = IndexSignState::kNegative; state = IndexSignState::kNegative;
} else { else
DLOG(WARNING)
<< "LegalizeNegativeIndex: cannot prove non-negative index "
<< simplified << " for buffer " << buffer_name << " (axis " << i
<< ", index " + indices[i]->Script() + ").";
}
// Vector indices: try to reason about non-negativity/negativity
// Common patterns are Ramp(base, stride, lanes) and Broadcast(value,
// lanes).
else if (const auto *ramp = simplified.as<RampNode>()) {
// Compute a safe lower/upper bound for the vector lanes
// lower_bound = base_min + min(0, stride_min) * (lanes - 1)
// upper_bound = base_max + max(0, stride_max) * (lanes - 1)
auto base_bound = analyzer_.const_int_bound(ramp->base);
auto stride_bound = analyzer_.const_int_bound(ramp->stride);
int lanes = *as_const_int(ramp->lanes);
int64_t base_min = base_bound->min_value;
int64_t base_max = base_bound->max_value;
int64_t s_min = stride_bound->min_value;
int64_t s_max = stride_bound->max_value;
// Guard against overflow is not strictly necessary here because
// bounds may be +/-inf represented by sentinel values.
int64_t lower = base_min;
if (s_min < 0)
lower += s_min * (lanes - 1);
int64_t upper = base_max;
if (s_max > 0)
upper += s_max * (lanes - 1);
if (lower >= 0)
state = IndexSignState::kNonNegative;
else if (upper < 0)
state = IndexSignState::kNegative;
else
DLOG(WARNING) DLOG(WARNING)
<< "LegalizeNegativeIndex: cannot prove non-negative index " << "LegalizeNegativeIndex: cannot prove non-negative index "
<< simplified << " for buffer " << buffer_name << " (axis " << i << simplified << " for buffer " << buffer_name << " (axis " << i
<< ", index " + indices[i]->Script() + ")."; << ", index " + indices[i]->Script() + ").";
}
} else if (const auto *broadcast = simplified.as<BroadcastNode>()) { } else if (const auto *broadcast = simplified.as<BroadcastNode>()) {
auto v = analyzer_.Simplify(broadcast->value); auto v = analyzer_.Simplify(broadcast->value);
if (analyzer_.CanProve(v >= 0)) if (analyzer_.CanProve(v >= 0))
...@@ -85,20 +109,6 @@ private: ...@@ -85,20 +109,6 @@ private:
<< simplified << " for buffer " << buffer_name << " (axis " << i << simplified << " for buffer " << buffer_name << " (axis " << i
<< ", index " + indices[i]->Script() + ")."; << ", index " + indices[i]->Script() + ").";
} }
} else {
// Assume scalar (or non-Ramp/Broadcast) index; avoid querying lanes().
// Fall back to scalar reasoning. If this expression is actually a
// vector-but-not-Ramp/Broadcast, treat as unknown to be safe.
// Try to prove scalar first; if proof fails, leave as unknown.
if (analyzer_.CanProve(simplified >= 0))
state = IndexSignState::kNonNegative;
else if (analyzer_.CanProve(simplified < 0))
state = IndexSignState::kNegative;
else
DLOG(WARNING)
<< "LegalizeNegativeIndex: cannot prove non-negative index "
<< simplified << " for buffer " << buffer_name << " (axis " << i
<< ", index " + indices[i]->Script() + ").";
} }
states.push_back(state); states.push_back(state);
} }
......
...@@ -402,8 +402,8 @@ PrimFunc MakePackedAPI(PrimFunc func) { ...@@ -402,8 +402,8 @@ PrimFunc MakePackedAPI(PrimFunc func) {
display_name.erase(display_name.size() - 7); display_name.erase(display_name.size() - 7);
} }
} }
msg << name_hint << ": Expect buffer " << display_name msg << "kernel " << name_hint << " input " << display_name
<< " to be pointer or tensor"; << " expected pointer or tensor handle";
seq_init.emplace_back( seq_init.emplace_back(
AssertStmt(type_index == ffi::TypeIndex::kTVMFFINone || AssertStmt(type_index == ffi::TypeIndex::kTVMFFINone ||
type_index == ffi::TypeIndex::kTVMFFIOpaquePtr || type_index == ffi::TypeIndex::kTVMFFIOpaquePtr ||
...@@ -423,7 +423,8 @@ PrimFunc MakePackedAPI(PrimFunc func) { ...@@ -423,7 +423,8 @@ PrimFunc MakePackedAPI(PrimFunc func) {
handle_from_tensor, arg_value); handle_from_tensor, arg_value);
} else if (dtype.is_bool()) { } else if (dtype.is_bool()) {
std::ostringstream msg; std::ostringstream msg;
msg << name_hint << ": Expect " << param->name_hint << " to be boolean"; msg << "kernel " << name_hint << " scalar " << param->name_hint
<< " expected boolean";
seq_init.emplace_back( seq_init.emplace_back(
AssertStmt(type_index == ffi::TypeIndex::kTVMFFIBool || AssertStmt(type_index == ffi::TypeIndex::kTVMFFIBool ||
type_index == ffi::TypeIndex::kTVMFFIInt, type_index == ffi::TypeIndex::kTVMFFIInt,
...@@ -433,7 +434,8 @@ PrimFunc MakePackedAPI(PrimFunc func) { ...@@ -433,7 +434,8 @@ PrimFunc MakePackedAPI(PrimFunc func) {
} else if (dtype.is_int() || dtype.is_uint()) { } else if (dtype.is_int() || dtype.is_uint()) {
std::ostringstream msg; std::ostringstream msg;
msg << name_hint << ": Expect " << param->name_hint << " to be int"; msg << "kernel " << name_hint << " scalar " << param->name_hint
<< " expected integer";
seq_init.emplace_back( seq_init.emplace_back(
AssertStmt(type_index == ffi::TypeIndex::kTVMFFIInt || AssertStmt(type_index == ffi::TypeIndex::kTVMFFIInt ||
type_index == ffi::TypeIndex::kTVMFFIBool, type_index == ffi::TypeIndex::kTVMFFIBool,
...@@ -442,7 +444,8 @@ PrimFunc MakePackedAPI(PrimFunc func) { ...@@ -442,7 +444,8 @@ PrimFunc MakePackedAPI(PrimFunc func) {
} else { } else {
ICHECK(dtype.is_float()); ICHECK(dtype.is_float());
std::ostringstream msg; std::ostringstream msg;
msg << name_hint << ": Expect " << param->name_hint << " to be float"; msg << "kernel " << name_hint << " scalar " << param->name_hint
<< " expected float";
seq_init.emplace_back( seq_init.emplace_back(
AssertStmt(type_index == ffi::TypeIndex::kTVMFFIFloat || AssertStmt(type_index == ffi::TypeIndex::kTVMFFIFloat ||
type_index == ffi::TypeIndex::kTVMFFIInt || type_index == ffi::TypeIndex::kTVMFFIInt ||
......
...@@ -195,8 +195,7 @@ class TVMFFIKernelAdapter(BaseKernelAdapter): ...@@ -195,8 +195,7 @@ class TVMFFIKernelAdapter(BaseKernelAdapter):
expected_inputs = len(self.params) - len(self.result_idx) expected_inputs = len(self.params) - len(self.result_idx)
if len(inputs) != expected_inputs: if len(inputs) != expected_inputs:
raise ValueError( raise ValueError(
f"Expected {expected_inputs} inputs, got {len(inputs)} (params={len(self.params)}, outputs={len(self.result_idx)})" f"Kernel expected {expected_inputs} inputs, but {len(inputs)} are provided.")
)
# Resolve the device used for outputs. Prefer the first tensor input's device # Resolve the device used for outputs. Prefer the first tensor input's device
# if available, otherwise use PyTorch's current device. # if available, otherwise use PyTorch's current device.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment