Unverified Commit 1e92d11c authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Refactor] Improve assertion handling in CodeGenCHost and ArgBinder (#1352)

* [Refactor] Improve assertion handling in CodeGenCHost and ArgBinder

This commit refines the assertion message generation in CodeGenCHost by optimizing the handling of equality checks and reducing buffer size for error messages. Additionally, it enhances the ArgBinder by introducing a nullable guard mechanism for assertions, allowing for more precise error handling when binding arguments. The changes improve the clarity and efficiency of assertion handling across the codebase.

* [Enhancement] Update matmul kernel and optimize argument binding

This commit enhances the matmul kernel by introducing additional tensor parameters and refining the pipeline stages for improved performance. It also updates the argument binding mechanism to include a flag indicating whether buffers are used, enhancing the efficiency of buffer management. Furthermore, the optimization phase in the engine is improved by adding a simplification step, ensuring better performance and clarity in the generated code.

* lint fix

* [Enhancement] Add tensor checks documentation and improve argument binding assertions

This commit introduces a new documentation page for host-side tensor checks, detailing the automatic validations performed by TileLang on kernel arguments. It enhances the ArgBinder by adding assertions for non-null pointers when arguments are used, improving error handling. Additionally, the optimization phase in the engine is updated to include a simplification step, ensuring better performance and clarity in the generated code.

* [Enhancement] Update .gitignore and refine matmul kernel for improved performance

This commit adds host checks logs to the .gitignore file to prevent unnecessary log files from being tracked. Additionally, it refines the matmul kernel by adjusting pipeline stages, updating tensor parameters, and enhancing argument handling for better performance. The changes also include improved error messages in the argument binding process, ensuring clearer diagnostics for users.

* lint fix

* lint fix

* [Refactor] Simplify tensor_null_test function and remove ptr_null_test

This commit refactors the tensor_null_test function by adding a with_bias parameter and removing the ptr_null_test function, which was previously unused. The run_test function is updated to reflect these changes, streamlining the testing process for tensor operations.

* lint fix

* fix
parent b8240b7a
...@@ -105,7 +105,7 @@ public: ...@@ -105,7 +105,7 @@ public:
*/ */
void BindDLTensor(const Buffer &buffer, const PrimExpr &device_type, void BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
const PrimExpr &device_id, const Var &handle, const PrimExpr &device_id, const Var &handle,
const std::string &arg_name); const std::string &arg_name, bool is_used);
/*! \return The defs generated in binding. */ /*! \return The defs generated in binding. */
const std::vector<Var> &defs() const { return defs_; } const std::vector<Var> &defs() const { return defs_; }
......
...@@ -39,6 +39,7 @@ ...@@ -39,6 +39,7 @@
#include "../op/builtin.h" #include "../op/builtin.h"
#include "arg_binder.h" #include "arg_binder.h"
#include "merge_if_stmt.h"
#include "tir/transforms/ir_utils.h" #include "tir/transforms/ir_utils.h"
namespace tvm { namespace tvm {
...@@ -297,6 +298,81 @@ PrimFunc MakePackedAPI(PrimFunc func) { ...@@ -297,6 +298,81 @@ PrimFunc MakePackedAPI(PrimFunc func) {
std::vector<std::pair<PrimExpr, Var>> var_def; std::vector<std::pair<PrimExpr, Var>> var_def;
std::vector<std::pair<Var, Buffer>> buffer_def; std::vector<std::pair<Var, Buffer>> buffer_def;
// First, collect a reverse map from Buffer->data var to parameter var so we
// can detect whether a buffer is actually used by the function body. In
// addition, collect variables that appear in the buffer's shape/stride so we
// can consider uses of those symbols as a use of the buffer itself.
std::unordered_map<const VarNode *, const VarNode *> data_var2param;
std::unordered_map<const VarNode *, std::vector<const VarNode *>>
shape_var2params;
for (const auto &kv : func_ptr->buffer_map) {
const Var &param = kv.first;
const Buffer &buf = kv.second;
data_var2param[buf->data.get()] = param.get();
auto record_shape_vars = [&](const PrimExpr &e) {
PostOrderVisit(e, [&](const ObjectRef &n) {
if (const auto *v = n.as<VarNode>()) {
shape_var2params[v].push_back(param.get());
}
});
};
for (const PrimExpr &e : buf->shape)
record_shape_vars(e);
for (const PrimExpr &e : buf->strides)
record_shape_vars(e);
if (buf->elem_offset.defined())
record_shape_vars(buf->elem_offset);
}
// A visitor that marks a buffer as used when its underlying data var is
// referenced (e.g. BufferLoad/BufferStore or any direct var usage).
struct UsedBufferDetector : public StmtExprVisitor {
UsedBufferDetector(
const std::unordered_map<const VarNode *, const VarNode *> &data2param,
const std::unordered_map<const VarNode *, std::vector<const VarNode *>>
&shape2params)
: data2param(data2param), shape2params(shape2params) {}
void VisitExpr_(const VarNode *op) override {
auto it = data2param.find(op);
if (it != data2param.end()) {
used_params.insert(it->second);
}
auto it2 = shape2params.find(op);
if (it2 != shape2params.end()) {
for (const VarNode *p : it2->second)
used_params.insert(p);
}
StmtExprVisitor::VisitExpr_(op);
}
void VisitStmt_(const BufferStoreNode *op) override {
auto it = data2param.find(op->buffer->data.get());
if (it != data2param.end()) {
used_params.insert(it->second);
}
StmtExprVisitor::VisitStmt_(op);
}
void VisitExpr_(const BufferLoadNode *op) override {
auto it = data2param.find(op->buffer->data.get());
if (it != data2param.end()) {
used_params.insert(it->second);
}
StmtExprVisitor::VisitExpr_(op);
}
const std::unordered_map<const VarNode *, const VarNode *> &data2param;
const std::unordered_map<const VarNode *, std::vector<const VarNode *>>
&shape2params;
std::unordered_set<const VarNode *> used_params;
};
UsedBufferDetector detector(data_var2param, shape_var2params);
detector(func_ptr->body);
// Build the packed argument handling. While doing so, keep track of whether
// each parameter buffer is actually used. Unused input buffers can be
// nullable and do not require DLTensor field dereferences.
std::unordered_set<const VarNode *> used_param_buffers = detector.used_params;
for (int i = 0; i < static_cast<int>(func_ptr->params.size()); ++i) { for (int i = 0; i < static_cast<int>(func_ptr->params.size()); ++i) {
Var param = func_ptr->params[i]; Var param = func_ptr->params[i];
PrimExpr arg_value; PrimExpr arg_value;
...@@ -311,7 +387,23 @@ PrimFunc MakePackedAPI(PrimFunc func) { ...@@ -311,7 +387,23 @@ PrimFunc MakePackedAPI(PrimFunc func) {
DataType dtype = param.dtype(); DataType dtype = param.dtype();
if (dtype.is_handle()) { if (dtype.is_handle()) {
std::ostringstream msg; std::ostringstream msg;
msg << name_hint << ": Expect arg[" << i << "] to be pointer"; // Prefer the Buffer name if available; otherwise, fall back to param name
// (trim _handle).
std::string display_name;
auto it_buf = func_ptr->buffer_map.find(param);
if (it_buf != func_ptr->buffer_map.end()) {
const auto &kv = *it_buf;
display_name = kv.second->data->name_hint;
} else {
display_name = param->name_hint;
const char *suffix = "_handle";
if (display_name.size() >= 7 &&
display_name.compare(display_name.size() - 7, 7, suffix) == 0) {
display_name.erase(display_name.size() - 7);
}
}
msg << name_hint << ": Expect buffer " << display_name
<< " to be pointer or tensor";
seq_init.emplace_back( seq_init.emplace_back(
AssertStmt(type_index == ffi::TypeIndex::kTVMFFINone || AssertStmt(type_index == ffi::TypeIndex::kTVMFFINone ||
type_index == ffi::TypeIndex::kTVMFFIOpaquePtr || type_index == ffi::TypeIndex::kTVMFFIOpaquePtr ||
...@@ -331,7 +423,7 @@ PrimFunc MakePackedAPI(PrimFunc func) { ...@@ -331,7 +423,7 @@ PrimFunc MakePackedAPI(PrimFunc func) {
handle_from_tensor, arg_value); handle_from_tensor, arg_value);
} else if (dtype.is_bool()) { } else if (dtype.is_bool()) {
std::ostringstream msg; std::ostringstream msg;
msg << name_hint << ": Expect arg[" << i << "] to be boolean"; msg << name_hint << ": Expect " << param->name_hint << " to be boolean";
seq_init.emplace_back( seq_init.emplace_back(
AssertStmt(type_index == ffi::TypeIndex::kTVMFFIBool || AssertStmt(type_index == ffi::TypeIndex::kTVMFFIBool ||
type_index == ffi::TypeIndex::kTVMFFIInt, type_index == ffi::TypeIndex::kTVMFFIInt,
...@@ -341,7 +433,7 @@ PrimFunc MakePackedAPI(PrimFunc func) { ...@@ -341,7 +433,7 @@ PrimFunc MakePackedAPI(PrimFunc func) {
} else if (dtype.is_int() || dtype.is_uint()) { } else if (dtype.is_int() || dtype.is_uint()) {
std::ostringstream msg; std::ostringstream msg;
msg << name_hint << ": Expect arg[" << i << "] to be int"; msg << name_hint << ": Expect " << param->name_hint << " to be int";
seq_init.emplace_back( seq_init.emplace_back(
AssertStmt(type_index == ffi::TypeIndex::kTVMFFIInt || AssertStmt(type_index == ffi::TypeIndex::kTVMFFIInt ||
type_index == ffi::TypeIndex::kTVMFFIBool, type_index == ffi::TypeIndex::kTVMFFIBool,
...@@ -350,7 +442,7 @@ PrimFunc MakePackedAPI(PrimFunc func) { ...@@ -350,7 +442,7 @@ PrimFunc MakePackedAPI(PrimFunc func) {
} else { } else {
ICHECK(dtype.is_float()); ICHECK(dtype.is_float());
std::ostringstream msg; std::ostringstream msg;
msg << name_hint << ": Expect arg[" << i << "] to be float"; msg << name_hint << ": Expect " << param->name_hint << " to be float";
seq_init.emplace_back( seq_init.emplace_back(
AssertStmt(type_index == ffi::TypeIndex::kTVMFFIFloat || AssertStmt(type_index == ffi::TypeIndex::kTVMFFIFloat ||
type_index == ffi::TypeIndex::kTVMFFIInt || type_index == ffi::TypeIndex::kTVMFFIInt ||
...@@ -388,8 +480,11 @@ PrimFunc MakePackedAPI(PrimFunc func) { ...@@ -388,8 +480,11 @@ PrimFunc MakePackedAPI(PrimFunc func) {
} }
for (const auto &[var, buffer] : buffer_def) { for (const auto &[var, buffer] : buffer_def) {
binder.BindDLTensor(buffer, device_type, device_id, var, // Prefer buffer data var name in diagnostics to avoid exposing low-level
name_hint + "." + var->name_hint); // handle vars
std::string display = name_hint + "." + buffer->data->name_hint;
binder.BindDLTensor(buffer, device_type, device_id, var, display,
used_param_buffers.count(var.get()));
arg_buffer_declarations.push_back(DeclBuffer(buffer, nop)); arg_buffer_declarations.push_back(DeclBuffer(buffer, nop));
} }
// reset global symbol to attach prefix // reset global symbol to attach prefix
...@@ -436,7 +531,6 @@ PrimFunc MakePackedAPI(PrimFunc func) { ...@@ -436,7 +531,6 @@ PrimFunc MakePackedAPI(PrimFunc func) {
func_ptr->buffer_map = ffi::Map<Var, Buffer>(); func_ptr->buffer_map = ffi::Map<Var, Buffer>();
func_ptr->ret_type = PrimType(DataType::Int(32)); func_ptr->ret_type = PrimType(DataType::Int(32));
// return the function. // return the function.
return func; return func;
} }
...@@ -467,6 +561,7 @@ tvm::transform::Pass MakePackedAPI() { ...@@ -467,6 +561,7 @@ tvm::transform::Pass MakePackedAPI() {
func.CopyOnWrite()->body = body.value(); func.CopyOnWrite()->body = body.value();
} }
func = MakePackedAPI(std::move(func)); func = MakePackedAPI(std::move(func));
func = MergeIfStmtSubstitute(func);
if (!func.same_as(orig_func)) { if (!func.same_as(orig_func)) {
updates->Add(gvar, func); updates->Add(gvar, func);
......
...@@ -3,6 +3,8 @@ ...@@ -3,6 +3,8 @@
* \brief Merge the If Stmt in SeqStmt * \brief Merge the If Stmt in SeqStmt
*/ */
#include "merge_if_stmt.h"
#include <tvm/ffi/reflection/registry.h> #include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h> #include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h> #include <tvm/tir/builtin.h>
...@@ -20,23 +22,46 @@ using namespace tir; ...@@ -20,23 +22,46 @@ using namespace tir;
class MergeIfStmtRewriter : public StmtExprMutator { class MergeIfStmtRewriter : public StmtExprMutator {
public: public:
static PrimFunc Substitute(PrimFunc &f) { static PrimFunc Substitute(PrimFunc &f) {
auto rewriter = MergeIfStmtRewriter(); f.CopyOnWrite()->body = MergeIfStmtRewriter::Apply(f->body);
f.CopyOnWrite()->body = rewriter(f->body);
return f; return f;
} }
static Stmt Apply(Stmt stmt) {
auto rewriter = MergeIfStmtRewriter();
return rewriter(stmt);
}
private: private:
MergeIfStmtRewriter() = default; MergeIfStmtRewriter() = default;
void FlattenAppend(const Stmt &s, Array<Stmt> *out) {
if (const auto *seq = s.as<SeqStmtNode>()) {
for (const Stmt &e : seq->seq) {
FlattenAppend(e, out);
}
} else {
out->push_back(s);
}
}
Stmt VisitStmt_(const SeqStmtNode *op) final { Stmt VisitStmt_(const SeqStmtNode *op) final {
Array<Stmt> new_seq; // First, recursively flatten nested SeqStmt so that
// SeqStmt{ if, SeqStmt{ if, SeqStmt{ if } } }
// becomes a single-level sequence of [if, if, if].
Array<Stmt> flat_seq;
for (const Stmt &stmt : op->seq) {
Stmt new_stmt = this->VisitStmt(stmt);
FlattenAppend(new_stmt, &flat_seq);
}
// Then, merge consecutive IfThenElse (without else) that share the same
// condition.
Array<Stmt> new_seq;
PrimExpr current_condition; PrimExpr current_condition;
Array<Stmt> current_if_bodies; Array<Stmt> current_if_bodies;
for (const Stmt &stmt : op->seq) { for (const Stmt &stmt : flat_seq) {
Stmt new_stmt = this->VisitStmt(stmt); if (const auto *if_node = stmt.as<IfThenElseNode>()) {
if (const IfThenElseNode *if_node = new_stmt.as<IfThenElseNode>()) {
if (!if_node->else_case.defined()) { if (!if_node->else_case.defined()) {
if (current_condition.defined() && if (current_condition.defined() &&
ExprDeepEqual()(current_condition, if_node->condition)) { ExprDeepEqual()(current_condition, if_node->condition)) {
...@@ -73,7 +98,7 @@ private: ...@@ -73,7 +98,7 @@ private:
current_if_bodies.clear(); current_if_bodies.clear();
} }
new_seq.push_back(new_stmt); new_seq.push_back(stmt);
} }
if (!current_if_bodies.empty()) { if (!current_if_bodies.empty()) {
...@@ -90,6 +115,12 @@ private: ...@@ -90,6 +115,12 @@ private:
} }
}; };
PrimFunc MergeIfStmtSubstitute(PrimFunc &f) {
return MergeIfStmtRewriter::Substitute(f);
}
Stmt ApplyMergeIfStmt(Stmt stmt) { return MergeIfStmtRewriter::Apply(stmt); }
using namespace tir::transform; using namespace tir::transform;
tvm::transform::Pass MergeIfStmt() { tvm::transform::Pass MergeIfStmt() {
auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) { auto pass_func = [=](PrimFunc f, const IRModule &m, const PassContext &ctx) {
......
/*!
* \file merge_if_stmt.h
* \brief Merge consecutive If statements with the same condition
*/
#ifndef TVM_TL_TRANSFORM_MERGE_IF_STMT_H_
#define TVM_TL_TRANSFORM_MERGE_IF_STMT_H_
#include <tvm/tir/function.h>
#include <tvm/tir/stmt.h>
namespace tvm {
namespace tl {
using namespace tir;
// Forward declaration
class MergeIfStmtRewriter;
/*!
* \brief Apply MergeIfStmt transformation to a PrimFunc
*
* This function merges consecutive IfThenElse statements that have the same
* condition into a single if statement with a SeqStmt body.
*
* Example:
* if (cond) { stmt1 }
* if (cond) { stmt2 }
* if (cond) { stmt3 }
*
* Becomes:
* if (cond) {
* stmt1
* stmt2
* stmt3
* }
*
* \param f The PrimFunc to transform
* \return Transformed PrimFunc with merged if statements
*/
PrimFunc MergeIfStmtSubstitute(PrimFunc &f);
/*!
* \brief Apply MergeIfStmt transformation to a statement
* \param stmt The statement to transform
* \return Transformed statement with merged if statements
*/
Stmt ApplyMergeIfStmt(Stmt stmt);
} // namespace tl
} // namespace tvm
#endif // TVM_TL_TRANSFORM_MERGE_IF_STMT_H_
...@@ -7,49 +7,15 @@ from tilelang.utils import map_torch_type ...@@ -7,49 +7,15 @@ from tilelang.utils import map_torch_type
@tl.jit @tl.jit
def ptr_null_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): def tensor_null_test(M,
N,
@T.prim_func K,
def main( block_M,
a_ptr: T.ptr, block_N,
b_ptr: T.ptr, block_K,
c_ptr: T.ptr, dtype="float16",
bias_ptr: T.ptr, accum_dtype="float",
m: T.int32, with_bias=False):
n: T.int32,
k: T.int32,
with_bias: T.bool,
):
A = T.make_tensor(a_ptr, (m, k), dtype)
B = T.make_tensor(b_ptr, (k, n), dtype)
C = T.make_tensor(c_ptr, (m, n), accum_dtype)
Bias = T.make_tensor(bias_ptr, (n), accum_dtype)
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_N, block_K), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for ko in T.Pipelined(T.ceildiv(k, block_K), num_stages=3):
# Copy tile of A
T.copy(A[by * block_M, ko * block_K], A_shared)
T.copy(B[bx * block_N, ko * block_K], B_shared)
T.gemm(A_shared, B_shared, C_local, transpose_B=True)
if with_bias:
for i, j in T.Parallel(block_M, block_N):
C_local[i, j] += Bias[bx * block_N + j]
T.copy(C_local, C[by * block_M, bx * block_N])
return main
@tl.jit
def tensor_null_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func @T.prim_func
def main( def main(
...@@ -57,7 +23,6 @@ def tensor_null_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_ ...@@ -57,7 +23,6 @@ def tensor_null_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_
B: T.Tensor((K, N), dtype), B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), accum_dtype), C: T.Tensor((M, N), accum_dtype),
Bias: T.Tensor((N), accum_dtype), Bias: T.Tensor((N), accum_dtype),
with_bias: T.bool,
): ):
# Initialize Kernel Context # Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
...@@ -83,28 +48,13 @@ def tensor_null_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_ ...@@ -83,28 +48,13 @@ def tensor_null_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_
def run_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): def run_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
kernel = ptr_null_test(M, N, K, block_M, block_N, block_K, dtype, accum_dtype)
a = torch.randn(M, K, device="cuda", dtype=map_torch_type(dtype)) a = torch.randn(M, K, device="cuda", dtype=map_torch_type(dtype))
b = torch.randn(N, K, device="cuda", dtype=map_torch_type(dtype)) b = torch.randn(N, K, device="cuda", dtype=map_torch_type(dtype))
c = torch.zeros(M, N, device="cuda", dtype=map_torch_type(accum_dtype)) c = torch.zeros(M, N, device="cuda", dtype=map_torch_type(accum_dtype))
d = torch.randn(N, device="cuda", dtype=map_torch_type(accum_dtype)) kernel = tensor_null_test(
kernel(a, b, c, None, M, N, K, False) M, N, K, block_M, block_N, block_K, dtype, accum_dtype, with_bias=False)
kernel(a, b, c, None)
ref_no_bias = (a @ b.T).to(map_torch_type(accum_dtype))
ref_with_bias = ref_no_bias + d
torch.testing.assert_close(c, ref_no_bias, atol=1e-2, rtol=1e-2)
kernel(a, b, c, d, M, N, K, True)
torch.testing.assert_close(c, ref_with_bias, atol=1e-2, rtol=1e-2)
kernel = tensor_null_test(M, N, K, block_M, block_N, block_K, dtype, accum_dtype)
kernel(a, b, c, None, False)
torch.testing.assert_close(c, ref_no_bias, atol=1e-2, rtol=1e-2)
kernel(a, b, c, d, True)
torch.testing.assert_close(c, ref_with_bias, atol=1e-2, rtol=1e-2)
def test_nullptr(): def test_nullptr():
......
...@@ -225,6 +225,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: ...@@ -225,6 +225,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
if allow_tma_and_warp_specialized(pass_ctx=pass_ctx, target=target): if allow_tma_and_warp_specialized(pass_ctx=pass_ctx, target=target):
mod = tilelang.transform.AnnotateWarpGroupRegAlloc()(mod) mod = tilelang.transform.AnnotateWarpGroupRegAlloc()(mod)
mod = tilelang.transform.MakePackedAPI()(mod) mod = tilelang.transform.MakePackedAPI()(mod)
mod = tilelang.transform.Simplify()(mod)
mod = tilelang.transform.LowerDeviceKernelLaunch()(mod) mod = tilelang.transform.LowerDeviceKernelLaunch()(mod)
# Transform threadblock to persistent threadblock # Transform threadblock to persistent threadblock
......
...@@ -166,8 +166,6 @@ class TVMFFIKernelAdapter(BaseKernelAdapter): ...@@ -166,8 +166,6 @@ class TVMFFIKernelAdapter(BaseKernelAdapter):
else: else:
expected_dtype_strs.append(None) expected_dtype_strs.append(None)
is_buffer_param.append(False) is_buffer_param.append(False)
# Global function name used in error messages
global_symbol = str(prim_func.attrs.get("global_symbol", "main"))
# Map torch dtype to TVM-style dtype string # Map torch dtype to TVM-style dtype string
def torch_dtype_to_tvm_str(dtype: torch.dtype) -> str: def torch_dtype_to_tvm_str(dtype: torch.dtype) -> str:
...@@ -236,21 +234,6 @@ class TVMFFIKernelAdapter(BaseKernelAdapter): ...@@ -236,21 +234,6 @@ class TVMFFIKernelAdapter(BaseKernelAdapter):
tensor = torch.empty(*shape, dtype=dtype, device=out_device) tensor = torch.empty(*shape, dtype=dtype, device=out_device)
else: else:
tensor = inputs[ins_idx] tensor = inputs[ins_idx]
# Input dtype validation with clear error message
if is_buffer_param[i]:
expected_dtype_str = expected_dtype_strs[i]
expected_torch_dtype = param_dtypes[i]
# Only check when the argument is a tensor and expected dtype is known
if isinstance(
tensor, torch.Tensor
) and expected_dtype_str is not None and tensor.dtype != expected_torch_dtype:
param_var = params[i]
# Reconstruct TVM-like handle name A_handle for error clarity
handle_name = f"{param_var.name}_handle"
actual_dtype_str = torch_dtype_to_tvm_str(tensor.dtype)
raise RuntimeError(
f"{global_symbol}.{handle_name}.dtype is expected to be {expected_dtype_str}, but got {actual_dtype_str}"
)
ins_idx += 1 ins_idx += 1
tensor_list.append(tensor) tensor_list.append(tensor)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment