"docs/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "1f28d136d687ee46e9a63f253155aae39981d9cb"
Unverified commit 17cfeb76, authored by Lei Wang, committed by GitHub

[Enhancement] Improve error handling and assertion messages across runtime and argument binding (#1356)

This commit improves error handling in the runtime by introducing CPU-safe runtime helpers and refining assertion messages in CodeGenCHost and ArgBinder. It adds structured packed error messages for common failure conditions (dtype, ndim, byte offset, device type, null pointers, and general constraint violations), making diagnostics clearer. The CMake configuration is also updated to always compile the runtime error helpers so error reporting is consistent across backends. Overall, the changes give clearer feedback on runtime errors and make the argument-binding checks more robust.
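For context, the packed error helpers registered by src/runtime/error_helpers.cc (shown later in this diff) surface in Python as ordinary global functions, so the new message format can be inspected directly. The sketch below is illustrative only and not part of this commit: it assumes the TileLang library that registers these helpers has been loaded (for example via import tilelang), and the kernel/buffer names "main" and "A" are made up.

# Minimal sketch: probe one of the packed error helpers to see the new diagnostics.
import tvm
import tilelang  # noqa: F401  -- loads the runtime library that registers the helpers

err_fn = tvm.get_global_func("__tvm_error_ndim_mismatch", allow_missing=True)
if err_fn is not None:
    try:
        err_fn("main", "A", 2, 3)  # kernel, buffer, expected ndim, actual ndim
    except Exception as exc:       # surfaces as a RuntimeError on the Python side
        print(exc)                 # e.g. "kernel main input A ndim expected 2, but got 3"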
parent 36a2b2f3
Subproject commit e3af400013551755a8df668ba77b530735931ade
Subproject commit fc7ed0b9cb7a52eb1c8bf6e8c26bbb8dff3655ce
......@@ -145,6 +145,11 @@ file(GLOB TILE_LANG_SRCS
src/target/intrin_rule*.cc
)
# Always include CPU-safe runtime helpers
list(APPEND TILE_LANG_SRCS
src/runtime/error_helpers.cc
)
# Track if the user explicitly selected a backend via cache options.
set(TILELANG_BACKEND_USER_SELECTED OFF)
foreach(BACKEND IN LISTS TILELANG_BACKENDS)
......@@ -206,7 +211,7 @@ elseif(USE_CUDA)
cmake_path(GET CUDAToolkit_BIN_DIR PARENT_PATH USE_CUDA)
file(GLOB TILE_LANG_CUDA_SRCS
src/runtime/*.cc
src/runtime/runtime.cc
src/target/ptx.cc
src/target/codegen_cuda.cc
src/target/rt_mod_cuda.cc
......
/*
* Helper functions for nicer runtime error messages.
*/
#include "error_helpers.h"
#include <tvm/ffi/c_api.h>
#include <tvm/ffi/error.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/device_api.h>
#include <sstream>
#include <string>
......@@ -25,8 +30,9 @@ static int DTypeMismatch(const tvm::ffi::String &kernel_name,
static_cast<int>(expect_bits),
static_cast<int>(expect_lanes));
std::ostringstream os;
os << std::string(kernel_name) << ": dtype of " << std::string(buffer_name)
<< " is expected to be " << expect << ", but got " << actual;
os << "kernel " << std::string(kernel_name) << " input "
<< std::string(buffer_name) << " dtype expected " << expect << ", but got "
<< actual;
TVMFFIErrorSetRaisedFromCStr("RuntimeError", os.str().c_str());
return -1;
}
......@@ -48,13 +54,169 @@ static int DTypeMismatchNoNames(int64_t actual_code, int64_t actual_bits,
return -1;
}
} // namespace tl
} // namespace tvm
// Register packed versions, following the design in runtime.cc
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
// Packed: __tvm_error_dtype_mismatch(kernel_name, buffer_name,
// actual_code, actual_bits, actual_lanes,
// expect_code, expect_bits, expect_lanes)
refl::GlobalDef().def_packed(
tl::tvm_error_dtype_mismatch,
[](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
ICHECK(args.size() == 8) << "Expected 8 args: kernel, buffer, "
"actual_code, actual_bits, actual_lanes, "
<< "expect_code, expect_bits, expect_lanes";
auto kernel_name = args[0].cast<tvm::ffi::String>();
auto buffer_name = args[1].cast<tvm::ffi::String>();
int64_t actual_code = args[2].cast<int64_t>();
int64_t actual_bits = args[3].cast<int64_t>();
int64_t actual_lanes = args[4].cast<int64_t>();
int64_t expect_code = args[5].cast<int64_t>();
int64_t expect_bits = args[6].cast<int64_t>();
int64_t expect_lanes = args[7].cast<int64_t>();
// Reuse the helper to format the message
(void)DTypeMismatch(kernel_name, buffer_name, actual_code, actual_bits,
actual_lanes, expect_code, expect_bits,
expect_lanes);
// Provide a return value for completeness, then signal the error
*ret = -1;
throw ::tvm::ffi::EnvErrorAlreadySet();
});
// kernel, buffer, expect:int64, got:int64
refl::GlobalDef().def_packed(
tl::tvm_error_ndim_mismatch,
[](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
ICHECK(args.size() == 4)
<< "__tvm_error_ndim_mismatch(kernel, buffer, expect, got)";
auto kernel = args[0].cast<tvm::ffi::String>();
auto buffer = args[1].cast<tvm::ffi::String>();
int64_t expect = args[2].cast<int64_t>();
int64_t got = args[3].cast<int64_t>();
std::ostringstream os;
os << "kernel " << std::string(kernel) << " input "
<< std::string(buffer) << " ndim expected " << expect << ", but got "
<< got;
TVMFFIErrorSetRaisedFromCStr("RuntimeError", os.str().c_str());
*ret = -1;
throw ::tvm::ffi::EnvErrorAlreadySet();
});
// kernel, buffer, expect:int64, got:int64
refl::GlobalDef().def_packed(
tl::tvm_error_byte_offset_mismatch,
[](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
ICHECK(args.size() == 4)
<< "__tvm_error_byte_offset_mismatch(kernel, buffer, expect, got)";
auto kernel = args[0].cast<tvm::ffi::String>();
auto buffer = args[1].cast<tvm::ffi::String>();
int64_t expect = args[2].cast<int64_t>();
int64_t got = args[3].cast<int64_t>();
std::ostringstream os;
os << "kernel " << std::string(kernel) << " input "
<< std::string(buffer) << " byte_offset expected " << expect
<< ", but got " << got;
TVMFFIErrorSetRaisedFromCStr("RuntimeError", os.str().c_str());
*ret = -1;
throw ::tvm::ffi::EnvErrorAlreadySet();
});
// kernel, buffer, expect:int64, got:int64
refl::GlobalDef().def_packed(
tl::tvm_error_device_type_mismatch,
[](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
ICHECK(args.size() == 4)
<< "__tvm_error_device_type_mismatch(kernel, buffer, expect, got)";
auto kernel = args[0].cast<tvm::ffi::String>();
auto buffer = args[1].cast<tvm::ffi::String>();
int64_t expect = args[2].cast<int64_t>();
int64_t got = args[3].cast<int64_t>();
const char *expect_str =
tvm::runtime::DLDeviceType2Str(static_cast<int>(expect));
const char *got_str =
tvm::runtime::DLDeviceType2Str(static_cast<int>(got));
std::ostringstream os;
os << "kernel " << std::string(kernel) << " input "
<< std::string(buffer) << " device_type expected " << expect_str
<< ", but got " << got_str;
TVMFFIErrorSetRaisedFromCStr("RuntimeError", os.str().c_str());
*ret = -1;
throw ::tvm::ffi::EnvErrorAlreadySet();
});
// kernel, buffer, field:String
refl::GlobalDef().def_packed(
tl::tvm_error_null_ptr,
[](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
ICHECK(args.size() == 3)
<< "__tvm_error_null_ptr(kernel, buffer, field)";
auto kernel = args[0].cast<tvm::ffi::String>();
auto buffer = args[1].cast<tvm::ffi::String>();
auto field = args[2].cast<tvm::ffi::String>();
std::ostringstream os;
os << "kernel " << std::string(kernel) << " input "
<< std::string(buffer) << ' ' << std::string(field)
<< " expected non-NULL, but got NULL";
TVMFFIErrorSetRaisedFromCStr("RuntimeError", os.str().c_str());
*ret = -1;
throw ::tvm::ffi::EnvErrorAlreadySet();
});
// kernel, buffer, field:String, expect:int64, got:int64
refl::GlobalDef().def_packed(
tl::tvm_error_expect_eq,
[](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
ICHECK(args.size() == 5)
<< "__tvm_error_expect_eq(kernel, buffer, field, expect, got)";
auto kernel = args[0].cast<tvm::ffi::String>();
auto buffer = args[1].cast<tvm::ffi::String>();
auto field = args[2].cast<tvm::ffi::String>();
int64_t expect = args[3].cast<int64_t>();
int64_t got = args[4].cast<int64_t>();
std::ostringstream os;
os << "kernel " << std::string(kernel) << " input "
<< std::string(buffer) << ' ' << std::string(field) << " expected "
<< expect << ", but got " << got;
TVMFFIErrorSetRaisedFromCStr("RuntimeError", os.str().c_str());
*ret = -1;
throw ::tvm::ffi::EnvErrorAlreadySet();
});
// kernel, buffer, field:String [, reason:String]
refl::GlobalDef().def_packed(
tl::tvm_error_constraint_violation,
[](tvm::ffi::PackedArgs args, tvm::ffi::Any *ret) {
ICHECK(args.size() == 3 || args.size() == 4)
<< "__tvm_error_constraint_violation(kernel, buffer, field[, "
"reason])";
auto kernel = args[0].cast<tvm::ffi::String>();
auto buffer = args[1].cast<tvm::ffi::String>();
auto field = args[2].cast<tvm::ffi::String>();
std::string reason;
if (args.size() == 4) {
reason = args[3].cast<tvm::ffi::String>();
}
std::ostringstream os;
os << "kernel " << std::string(kernel) << " input "
<< std::string(buffer) << ' ' << std::string(field)
<< " constraint not satisfied";
if (!reason.empty()) {
os << ": " << reason;
}
TVMFFIErrorSetRaisedFromCStr("RuntimeError", os.str().c_str());
*ret = -1;
throw ::tvm::ffi::EnvErrorAlreadySet();
});
// Legacy typed registrations for backward compatibility
refl::GlobalDef().def("tilelang_error_dtype_mismatch",
&tvm::tl::DTypeMismatch);
refl::GlobalDef().def("tilelang_error_dtype_mismatch2",
&tvm::tl::DTypeMismatchNoNames);
}
} // namespace tl
} // namespace tvm
/*!
* \file tl/runtime/error_helpers.h
* \brief Error helper FFI names for TileLang runtime.
*/
#ifndef TVM_TL_RUNTIME_ERROR_HELPERS_H_
#define TVM_TL_RUNTIME_ERROR_HELPERS_H_
namespace tvm {
namespace tl {
// Error helper packed functions
constexpr const char *tvm_error_dtype_mismatch = "__tvm_error_dtype_mismatch";
constexpr const char *tvm_error_ndim_mismatch = "__tvm_error_ndim_mismatch";
constexpr const char *tvm_error_byte_offset_mismatch =
"__tvm_error_byte_offset_mismatch";
constexpr const char *tvm_error_device_type_mismatch =
"__tvm_error_device_type_mismatch";
constexpr const char *tvm_error_null_ptr = "__tvm_error_null_ptr";
constexpr const char *tvm_error_expect_eq = "__tvm_error_expect_eq";
constexpr const char *tvm_error_constraint_violation =
"__tvm_error_constraint_violation";
} // namespace tl
} // namespace tvm
#endif // TVM_TL_RUNTIME_ERROR_HELPERS_H_
......@@ -354,32 +354,44 @@ void CodeGenCHost::VisitStmt_(const tvm::tir::AssertStmtNode *op) { // NOLINT(*)
stream << "if (!(" << cond << ")) {\n";
int assert_if_scope = this->BeginScope();
{
// Prepare the base error message
// Prepare the base error message: allow StringImm or general PrimExpr
const auto *msg_node = op->message.as<tvm::tir::StringImmNode>();
ICHECK(msg_node != nullptr) << "Assert message expected to be StringImm";
const std::string &raw_msg = msg_node->value;
const std::string esc_msg = tvm::support::StrEscape(
raw_msg.c_str(), raw_msg.length(), /*use_octal_escape=*/true,
/*escape_whitespace_special_chars=*/true);
// If the assertion is an equality check, append the actual LHS/RHS values
if (const auto *eq = op->condition.as<tvm::tir::EQNode>()) {
std::string lhs = PrintExpr(eq->a);
std::string rhs = PrintExpr(eq->b);
PrintIndent();
stream << "char __tvm_assert_msg_buf[512];\n";
PrintIndent();
stream << "snprintf(__tvm_assert_msg_buf, 512, \"%s; expected: %lld, "
"got: %lld\", \""
<< esc_msg << "\", (long long)(" << lhs << "), (long long)("
<< rhs << "));\n";
PrintIndent();
stream << "TVMFFIErrorSetRaisedFromCStr(\"RuntimeError\", "
"__tvm_assert_msg_buf);\n";
bool msg_is_literal = (msg_node != nullptr);
std::string esc_msg;
std::string msg_expr;
if (msg_is_literal) {
const std::string &raw_msg = msg_node->value;
esc_msg = tvm::support::StrEscape(
raw_msg.c_str(), raw_msg.length(), /*use_octal_escape=*/true,
/*escape_whitespace_special_chars=*/true);
} else {
msg_expr = PrintExpr(op->message);
}
// Only print expected/got values for equality when message is StringImm
if (msg_is_literal) {
if (const auto *eq = op->condition.as<tvm::tir::EQNode>()) {
std::string lhs = PrintExpr(eq->a);
std::string rhs = PrintExpr(eq->b);
PrintIndent();
stream << "char __tvm_assert_msg_buf[512];\n";
PrintIndent();
stream << "snprintf(__tvm_assert_msg_buf, 512, \"%s; expected: %lld, "
"got: %lld\", \""
<< esc_msg << "\", (long long)(" << lhs << "), (long long)("
<< rhs << "));\n";
PrintIndent();
stream << "TVMFFIErrorSetRaisedFromCStr(\"RuntimeError\", "
"__tvm_assert_msg_buf);\n";
} else {
PrintIndent();
stream << "TVMFFIErrorSetRaisedFromCStr(\"RuntimeError\", \""
<< esc_msg << "\");\n";
}
} else {
PrintIndent();
stream << "TVMFFIErrorSetRaisedFromCStr(\"RuntimeError\", \"" << esc_msg
<< "\");\n";
stream << "TVMFFIErrorSetRaisedFromCStr(\"RuntimeError\", " << msg_expr
<< ");\n";
}
}
PrintIndent();
......
......@@ -13,6 +13,7 @@
#include <sstream>
#include <unordered_set>
#include "../runtime/error_helpers.h"
#include "tir/transforms/ir_utils.h"
#include "tvm/arith/int_solver.h"
#include "tvm/ffi/cast.h"
......@@ -35,22 +36,58 @@ void BinderAddAssert(arith::Analyzer *ana, PrimExpr cond,
}
if (!is_one(scond)) {
std::ostringstream os;
os << "Argument " << arg_name << " has an unsatisfied constraint: " << cond;
// Check if the condition is of the form "is_null || actual_cond"
// If so, generate "if !is_null: assert actual_cond" instead of "assert
// is_null || actual_cond"
if (nullable_guard.defined()) {
// Pattern: nullable_guard || actual_condition
// We want to transform this into: if !nullable_guard: assert
// actual_condition
Stmt check = AssertStmt(scond, StringImm(os.str()), Evaluate(0));
check = IfThenElse(Not(nullable_guard), check);
asserts->emplace_back(SeqStmt({check, Evaluate(0)}));
// Extract kernel/buffer/field from arg_name (e.g., "main.A.shape[0]")
std::string kernel = arg_name;
std::string buf_and_field = arg_name;
size_t dot_pos = arg_name.find('.');
if (dot_pos != std::string::npos) {
kernel = arg_name.substr(0, dot_pos);
buf_and_field = arg_name.substr(dot_pos + 1);
}
std::string buffer = buf_and_field;
std::string field;
size_t dot2 = buf_and_field.find('.');
if (dot2 != std::string::npos) {
buffer = buf_and_field.substr(0, dot2);
field = buf_and_field.substr(dot2 + 1);
}
// If cond is an equality, prefer structured packed error with expect/got
if (const auto *eq = scond.as<tvm::tir::EQNode>()) {
PrimExpr lhs = eq->a;
PrimExpr rhs = eq->b;
// Choose rhs as expected and lhs as got for better semantics in most
// binding cases
ffi::Array<PrimExpr> pargs;
pargs.push_back(StringImm(tvm_error_expect_eq));
pargs.push_back(StringImm(kernel));
pargs.push_back(StringImm(buffer));
pargs.push_back(StringImm(field.empty() ? std::string("value") : field));
pargs.push_back(cast(DataType::Int(64), rhs)); // expected
pargs.push_back(cast(DataType::Int(64), lhs)); // got
Stmt call_err =
Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), pargs));
// Only emit at runtime when the equality fails
Stmt inner = IfThenElse(Not(scond), call_err);
if (nullable_guard.defined()) {
inner = IfThenElse(Not(nullable_guard), inner);
}
asserts->emplace_back(SeqStmt({inner, Evaluate(0)}));
} else {
asserts->emplace_back(
AssertStmt(scond, StringImm(os.str()), Evaluate(0)));
// Fallback: packed generic constraint violation without dumping cond
ffi::Array<PrimExpr> pargs;
pargs.push_back(StringImm(tvm_error_constraint_violation));
pargs.push_back(StringImm(kernel));
pargs.push_back(StringImm(buffer));
pargs.push_back(StringImm(field.empty() ? std::string("value") : field));
Stmt call_err =
Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), pargs));
Stmt inner = IfThenElse(Not(scond), call_err);
if (nullable_guard.defined()) {
inner = IfThenElse(Not(nullable_guard), inner);
}
asserts->emplace_back(SeqStmt({inner, Evaluate(0)}));
}
}
}
......@@ -318,22 +355,29 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
PrimExpr a_ndim =
make_const(tvm_ndim_type, static_cast<int64_t>(buffer->shape.size()));
std::ostringstream ndim_err_msg;
// Note: We cannot embed runtime values into the message string.
// Keep message human-friendly without printing TIR exprs.
ndim_err_msg << arg_name << ".ndim is expected to equal "
<< buffer->shape.size() << ", but got mismatched ndim";
auto msg = StringImm(ndim_err_msg.str());
// Only check ndim when handle is non-NULL (using if statement)
Stmt ndim_check = AssertStmt(a_ndim == v_ndim, msg, nop);
ndim_check = IfThenElse(Not(is_null), ndim_check);
init_nest_.emplace_back(SeqStmt({ndim_check, nop}));
// Build clearer ndim message with kernel/buffer names
std::string kernel_nm = arg_name;
std::string buf_nm = arg_name;
size_t dot_pos = arg_name.find('.');
if (dot_pos != std::string::npos) {
kernel_nm = arg_name.substr(0, dot_pos);
buf_nm = arg_name.substr(dot_pos + 1);
}
// Only check ndim when handle is non-NULL: use packed error helper
PrimExpr ndim_ok = (a_ndim == v_ndim);
ffi::Array<PrimExpr> ndim_args;
ndim_args.push_back(StringImm(tvm_error_ndim_mismatch));
ndim_args.push_back(StringImm(kernel_nm));
ndim_args.push_back(StringImm(buf_nm));
ndim_args.push_back(cast(DataType::Int(64), a_ndim));
ndim_args.push_back(cast(DataType::Int(64), v_ndim));
Stmt ndim_call =
Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), ndim_args));
init_nest_.emplace_back(
SeqStmt({IfThenElse(Not(is_null), IfThenElse(Not(ndim_ok), ndim_call),
Evaluate(0)),
nop}));
// type checks
std::ostringstream type_err_msg;
// Avoid dumping TIR expressions in error text; just state mismatch.
// Include expected dtype triplet for clarity.
type_err_msg << arg_name << ".dtype is expected to be " << buffer->dtype
<< ", but got incompatible dtype";
// Guard all dtype field loads by `is_null` using if_then_else
PrimExpr v_type_code = tvm::if_then_else(
Not(is_null),
......@@ -402,11 +446,36 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
if (!(buffer->dtype == DataType::Int(1) ||
buffer->dtype == DataType::Int(4) ||
buffer->dtype == DataType::UInt(4))) {
auto type_msg = StringImm(type_err_msg.str());
// Only check dtype when handle is non-NULL (using if statement)
Stmt dtype_check = AssertStmt(cond, type_msg, nop);
dtype_check = IfThenElse(Not(is_null), dtype_check);
asserts_.emplace_back(SeqStmt({dtype_check, nop}));
// Build FFI packed call to __tvm_error_dtype_mismatch when mismatch occurs.
// Only issue the call when handle is non-NULL and cond is false.
ffi::Array<PrimExpr> packed_args;
packed_args.push_back(StringImm(tvm_error_dtype_mismatch));
// Split arg_name of the form "<kernel>.<buffer>" into parts for clearer
// diagnostics
std::string kernel_name = arg_name;
std::string buffer_name = arg_name;
size_t dot_pos = arg_name.find('.');
if (dot_pos != std::string::npos) {
kernel_name = arg_name.substr(0, dot_pos);
buffer_name = arg_name.substr(dot_pos + 1);
}
packed_args.push_back(StringImm(kernel_name));
packed_args.push_back(StringImm(buffer_name));
auto i64 = DataType::Int(64);
// Cast to int64 for FFI function signature
packed_args.push_back(cast(i64, v_type_code)); // actual_code
packed_args.push_back(cast(i64, v_type_bits)); // actual_bits
packed_args.push_back(cast(i64, v_type_lanes)); // actual_lanes
packed_args.push_back(cast(i64, expect_code)); // expect_code
packed_args.push_back(cast(i64, expect_bits)); // expect_bits
packed_args.push_back(cast(i64, expect_lanes)); // expect_lanes
Stmt call_err = Evaluate(
Call(DataType::Int(32), builtin::tvm_call_packed(), packed_args));
// Guard the call: only when handle is not null and cond fails
Stmt guarded = IfThenElse(Not(is_null) && Not(cond), call_err);
asserts_.emplace_back(SeqStmt({guarded, nop}));
}
// shape field
......@@ -482,14 +551,27 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
<< stride_handle_name()
<< ": expected to be compact array, but got non-compact strides";
if (!conds.empty()) {
auto stride_msg = StringImm(stride_err_msg.str());
Stmt check =
AssertStmt(foldl([](PrimExpr a, PrimExpr b,
Span span) { return logical_and(a, b, span); },
const_true(1), conds),
stride_msg, Evaluate(0));
// Only check when strides array is actually present at runtime
check = IfThenElse(Not(v_strides_is_null), check);
PrimExpr all_ok = foldl([](PrimExpr a, PrimExpr b,
Span span) { return logical_and(a, b, span); },
const_true(1), conds);
// Packed generic violation for non-compact strides
std::string kernel_nm3 = arg_name;
std::string buf_nm3 = arg_name;
size_t dot_pos3 = arg_name.find('.');
if (dot_pos3 != std::string::npos) {
kernel_nm3 = arg_name.substr(0, dot_pos3);
buf_nm3 = arg_name.substr(dot_pos3 + 1);
}
ffi::Array<PrimExpr> pargs4;
pargs4.push_back(StringImm(tvm_error_constraint_violation));
pargs4.push_back(StringImm(kernel_nm3));
pargs4.push_back(StringImm(buf_nm3));
pargs4.push_back(StringImm("strides"));
Stmt call_err4 =
Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), pargs4));
// Only check when strides array is present and condition fails
Stmt check = IfThenElse(Not(v_strides_is_null),
IfThenElse(Not(all_ok), call_err4), Evaluate(0));
asserts_.emplace_back(SeqStmt({check, Evaluate(0)}));
}
} else if (buffer->buffer_type == kAutoBroadcast) {
......@@ -539,11 +621,18 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
make_const(DataType::UInt(64), 0));
PrimExpr expect_byte_offset =
make_const(DataType::UInt(64), const_offset->value * data_bytes);
Stmt byte_off_check =
AssertStmt(expect_byte_offset == actual_byte_offset,
StringImm(arg_name + ".byte_offset mismatch"), nop);
byte_off_check = IfThenElse(Not(is_null), byte_off_check);
asserts_.emplace_back(SeqStmt({byte_off_check, nop}));
PrimExpr ok = (expect_byte_offset == actual_byte_offset);
ffi::Array<PrimExpr> pargs;
pargs.push_back(StringImm(tvm_error_byte_offset_mismatch));
pargs.push_back(StringImm(kernel_nm));
pargs.push_back(StringImm(buf_nm));
pargs.push_back(cast(DataType::Int(64), expect_byte_offset));
pargs.push_back(cast(DataType::Int(64), actual_byte_offset));
Stmt call_err =
Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), pargs));
asserts_.emplace_back(SeqStmt(
{IfThenElse(Not(is_null), IfThenElse(Not(ok), call_err), Evaluate(0)),
nop}));
} else {
PrimExpr actual_byte_offset = tvm::if_then_else(
Not(is_null),
......@@ -582,21 +671,18 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
// Check device_type consistency (device_id equality is implicitly ensured by
// binding above)
{
std::ostringstream dev_msg;
dev_msg << arg_name << ".device_type mismatch";
if (const auto *imm = device_type.as<IntImmNode>()) {
dev_msg << " [expected: " << imm->value << " ("
<< tvm::runtime::DLDeviceType2Str(static_cast<int>(imm->value))
<< ")]";
}
// Give a short legend so users can interpret numeric codes in the
// appended "got/expected" part printed by the runtime.
dev_msg << "; DLPack codes: 1=CPU, 2=CUDA, 7=Vulkan, 8=Metal, 10=ROCM, "
"14=OneAPI, 15=WebGPU";
auto device_type_check =
IfThenElse(Not(is_null), AssertStmt(device_type == actual_dev_type,
StringImm(dev_msg.str()), nop));
asserts_.emplace_back(SeqStmt({device_type_check, Evaluate(0)}));
PrimExpr ok = (device_type == actual_dev_type);
ffi::Array<PrimExpr> pargs2;
pargs2.push_back(StringImm(tvm_error_device_type_mismatch));
pargs2.push_back(StringImm(kernel_nm));
pargs2.push_back(StringImm(buf_nm));
pargs2.push_back(cast(DataType::Int(64), device_type));
pargs2.push_back(cast(DataType::Int(64), actual_dev_type));
Stmt call_err2 =
Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), pargs2));
asserts_.emplace_back(SeqStmt(
{IfThenElse(Not(is_null), IfThenElse(Not(ok), call_err2), Evaluate(0)),
Evaluate(0)}));
}
// Data field. Because the validation of the data field may depend
......@@ -619,14 +705,31 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
product *= dim;
return product;
}();
Stmt data_null_check = AssertStmt(
(alloc_size == 0) ||
!Call(DataType::Bool(), builtin::isnullptr(), {vptr}),
StringImm(arg_name +
" is expected to have non-NULL data pointer, but got NULL"),
nop);
data_null_check = IfThenElse(Not(is_null), data_null_check);
asserts_.emplace_back(SeqStmt({data_null_check, nop}));
// Improve message: kernel/buffer naming for data pointer null check
std::string kernel_nm2 = arg_name;
std::string buf_nm2 = arg_name;
size_t dot_pos2 = arg_name.find('.');
if (dot_pos2 != std::string::npos) {
kernel_nm2 = arg_name.substr(0, dot_pos2);
buf_nm2 = arg_name.substr(dot_pos2 + 1);
}
// expand combined condition via nested IfThenElse for portability
ffi::Array<PrimExpr> pargs3;
pargs3.push_back(StringImm(tvm_error_null_ptr));
pargs3.push_back(StringImm(kernel_nm2));
pargs3.push_back(StringImm(buf_nm2));
pargs3.push_back(StringImm("data pointer"));
Stmt call_err3 =
Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), pargs3));
asserts_.emplace_back(SeqStmt(
{IfThenElse(Not(is_null),
IfThenElse(Not(alloc_size == 0),
IfThenElse(Call(DataType::Bool(),
builtin::isnullptr(), {vptr}),
call_err3),
Evaluate(0)),
Evaluate(0)),
nop}));
// mark alignment of external bufs
init_nest_.emplace_back(
......
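The guarded-call pattern that BinderAddAssert and BindDLTensor emit above can be reproduced in a few lines of TIR from Python. The sketch below is only illustrative: the variable names and the expected ndim of 2 are hypothetical, but the structure (a null-handle guard, then a packed call to the error helper only when the check fails) mirrors the C++ in this hunk.

# Illustrative sketch of the check structure ArgBinder now emits (hypothetical names).
from tvm import tir

a_ndim = tir.IntImm("int64", 2)            # expected ndim, known from the Buffer
v_ndim = tir.Var("arg_A_ndim", "int64")    # ndim read from the incoming DLTensor
is_null = tir.Var("arg_A_is_null", "bool")

report = tir.Evaluate(
    tir.call_packed("__tvm_error_ndim_mismatch", "main", "A", a_ndim, v_ndim))
nop = tir.Evaluate(tir.IntImm("int32", 0))
check = tir.IfThenElse(
    tir.Not(is_null),
    tir.IfThenElse(tir.Not(tir.EQ(a_ndim, v_ndim)), report, nop),
    nop)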
......@@ -44,28 +44,52 @@ private:
PrimExpr simplified = analyzer_.Simplify(indices[i]);
IndexSignState state = IndexSignState::kUnknown;
// Handle vector patterns first to avoid querying lanes() on
// scalable vectors (which is not allowed at compile-time).
if (const auto *ramp = simplified.as<RampNode>()) {
// For scalable vectors, we cannot rely on a constant lane count.
// Use sufficient (but not necessary) conditions:
// - If base >= 0 and stride >= 0, all lanes are non-negative.
// - If base < 0 and stride <= 0, all lanes are negative.
bool base_nonneg = analyzer_.CanProve(ramp->base >= 0);
bool base_neg = analyzer_.CanProve(ramp->base < 0);
bool stride_nonneg = analyzer_.CanProve(ramp->stride >= 0);
bool stride_nonpos = analyzer_.CanProve(ramp->stride <= 0);
if (base_nonneg && stride_nonneg) {
// Handle scalar indices with the standard analyzer
if (simplified.dtype().lanes() == 1) {
if (analyzer_.CanProve(simplified >= 0))
state = IndexSignState::kNonNegative;
} else if (base_neg && stride_nonpos) {
else if (analyzer_.CanProve(simplified < 0))
state = IndexSignState::kNegative;
} else {
else
DLOG(WARNING)
<< "LegalizeNegativeIndex: cannot prove non-negative index "
<< simplified << " for buffer " << buffer_name << " (axis " << i
<< ", index " + indices[i]->Script() + ").";
}
// Vector indices: try to reason about non-negativity/negativity
// Common patterns are Ramp(base, stride, lanes) and Broadcast(value,
// lanes).
else if (const auto *ramp = simplified.as<RampNode>()) {
// Compute a safe lower/upper bound for the vector lanes
// lower_bound = base_min + min(0, stride_min) * (lanes - 1)
// upper_bound = base_max + max(0, stride_max) * (lanes - 1)
auto base_bound = analyzer_.const_int_bound(ramp->base);
auto stride_bound = analyzer_.const_int_bound(ramp->stride);
int lanes = *as_const_int(ramp->lanes);
int64_t base_min = base_bound->min_value;
int64_t base_max = base_bound->max_value;
int64_t s_min = stride_bound->min_value;
int64_t s_max = stride_bound->max_value;
// Guard against overflow is not strictly necessary here because
// bounds may be +/-inf represented by sentinel values.
int64_t lower = base_min;
if (s_min < 0)
lower += s_min * (lanes - 1);
int64_t upper = base_max;
if (s_max > 0)
upper += s_max * (lanes - 1);
if (lower >= 0)
state = IndexSignState::kNonNegative;
else if (upper < 0)
state = IndexSignState::kNegative;
else
DLOG(WARNING)
<< "LegalizeNegativeIndex: cannot prove non-negative index "
<< simplified << " for buffer " << buffer_name << " (axis " << i
<< ", index " + indices[i]->Script() + ").";
}
} else if (const auto *broadcast = simplified.as<BroadcastNode>()) {
auto v = analyzer_.Simplify(broadcast->value);
if (analyzer_.CanProve(v >= 0))
......@@ -85,20 +109,6 @@ private:
<< simplified << " for buffer " << buffer_name << " (axis " << i
<< ", index " + indices[i]->Script() + ").";
}
} else {
// Assume scalar (or non-Ramp/Broadcast) index; avoid querying lanes().
// Fall back to scalar reasoning. If this expression is actually a
// vector-but-not-Ramp/Broadcast, treat as unknown to be safe.
// Try to prove scalar first; if proof fails, leave as unknown.
if (analyzer_.CanProve(simplified >= 0))
state = IndexSignState::kNonNegative;
else if (analyzer_.CanProve(simplified < 0))
state = IndexSignState::kNegative;
else
DLOG(WARNING)
<< "LegalizeNegativeIndex: cannot prove non-negative index "
<< simplified << " for buffer " << buffer_name << " (axis " << i
<< ", index " + indices[i]->Script() + ").";
}
states.push_back(state);
}
......
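As a quick sanity check of the Ramp bound rule used above, here is a small standalone Python sketch (not part of the patch) that applies the same lower/upper formulas to a few concrete ramps:

# lower = base_min + min(0, stride_min) * (lanes - 1)
# upper = base_max + max(0, stride_max) * (lanes - 1)
def ramp_sign(base_min, base_max, stride_min, stride_max, lanes):
    lower = base_min + min(0, stride_min) * (lanes - 1)
    upper = base_max + max(0, stride_max) * (lanes - 1)
    if lower >= 0:
        return "non-negative"
    if upper < 0:
        return "negative"
    return "unknown"

print(ramp_sign(0, 0, 4, 4, 8))      # Ramp(0, 4, 8): values 0..28    -> non-negative
print(ramp_sign(-32, -32, 4, 4, 8))  # Ramp(-32, 4, 8): values -32..-4 -> negative
print(ramp_sign(-4, -4, 4, 4, 8))    # Ramp(-4, 4, 8): values -4..24   -> unknown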
......@@ -402,8 +402,8 @@ PrimFunc MakePackedAPI(PrimFunc func) {
display_name.erase(display_name.size() - 7);
}
}
msg << name_hint << ": Expect buffer " << display_name
<< " to be pointer or tensor";
msg << "kernel " << name_hint << " input " << display_name
<< " expected pointer or tensor handle";
seq_init.emplace_back(
AssertStmt(type_index == ffi::TypeIndex::kTVMFFINone ||
type_index == ffi::TypeIndex::kTVMFFIOpaquePtr ||
......@@ -423,7 +423,8 @@ PrimFunc MakePackedAPI(PrimFunc func) {
handle_from_tensor, arg_value);
} else if (dtype.is_bool()) {
std::ostringstream msg;
msg << name_hint << ": Expect " << param->name_hint << " to be boolean";
msg << "kernel " << name_hint << " scalar " << param->name_hint
<< " expected boolean";
seq_init.emplace_back(
AssertStmt(type_index == ffi::TypeIndex::kTVMFFIBool ||
type_index == ffi::TypeIndex::kTVMFFIInt,
......@@ -433,7 +434,8 @@ PrimFunc MakePackedAPI(PrimFunc func) {
} else if (dtype.is_int() || dtype.is_uint()) {
std::ostringstream msg;
msg << name_hint << ": Expect " << param->name_hint << " to be int";
msg << "kernel " << name_hint << " scalar " << param->name_hint
<< " expected integer";
seq_init.emplace_back(
AssertStmt(type_index == ffi::TypeIndex::kTVMFFIInt ||
type_index == ffi::TypeIndex::kTVMFFIBool,
......@@ -442,7 +444,8 @@ PrimFunc MakePackedAPI(PrimFunc func) {
} else {
ICHECK(dtype.is_float());
std::ostringstream msg;
msg << name_hint << ": Expect " << param->name_hint << " to be float";
msg << "kernel " << name_hint << " scalar " << param->name_hint
<< " expected float";
seq_init.emplace_back(
AssertStmt(type_index == ffi::TypeIndex::kTVMFFIFloat ||
type_index == ffi::TypeIndex::kTVMFFIInt ||
......
......@@ -195,8 +195,7 @@ class TVMFFIKernelAdapter(BaseKernelAdapter):
expected_inputs = len(self.params) - len(self.result_idx)
if len(inputs) != expected_inputs:
raise ValueError(
f"Expected {expected_inputs} inputs, got {len(inputs)} (params={len(self.params)}, outputs={len(self.result_idx)})"
)
f"Kernel expected {expected_inputs} inputs, but {len(inputs)} are provided.")
# Resolve the device used for outputs. Prefer the first tensor input's device
# if available, otherwise use PyTorch's current device.
......