Unverified Commit f6db2014 authored by Lei Wang, committed by GitHub

[ArgBinder] Enhance shape variable handling and assertions (#1467)

* feat(arg_binder): enhance shape variable handling and assertions

- Implemented special handling for comparing if_then_else expressions to simplify conditions involving NULL checks.
- Added methods to set shared shape variables and finalize deferred bindings, generating cascading if_then_else expressions and runtime assertions for non-NULL buffers.
- Updated the binding logic to defer shape variable bindings for shared variables, ensuring proper handling across multiple nullable buffers.
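For intuition, the following is a minimal Python sketch of the value the cascaded expression effectively computes for a shape variable `m` shared by several nullable buffers. It is plain Python rather than the generated TIR, and the buffer names `a`, `b`, `c` and the helper `resolve_shared_dim` are hypothetical:

```python
# Illustrative sketch only: models the cascaded selection plus the runtime
# assertion described above, not the actual TIR the binder emits.
def resolve_shared_dim(a, b, c):
    # Runtime assertion: at least one carrier of `m` must be non-NULL.
    assert a is not None or b is not None or c is not None, \
        "shape variable m requires at least one non-null buffer"
    # Cascaded if_then_else: take `m` from the first non-NULL buffer.
    if a is not None:
        return a.shape[0]
    if b is not None:
        return b.shape[0]
    return c.shape[0]
```

In the generated code this selection is expressed as nested if_then_else expressions over Let-bound `*_is_null` booleans rather than Python control flow.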

* refactor(arg_binder): clean up shape variable handling and remove unused code

- Removed deprecated methods for setting shared shape variables and finalizing deferred bindings, streamlining the argument binding process.
- Simplified the logic for handling shape values in the `BindDLTensor` function, ensuring immediate binding for normal shape variables.
- Enhanced clarity by eliminating unnecessary comments and code related to cascading if_then_else expressions for shared variables.

* refactor(arg_binder): enhance DLTensor binding with improved shape handling

- Replaced the single `BindDLTensor` method with `BindDLTensors` to support multiple buffers, improving flexibility in handling DLTensor bindings.
- Introduced a two-pass approach for shape variable handling: the first pass records which buffers and dimensions carry each symbolic shape variable, and the second pass binds those variables with null checks.
- Updated the logic to assert at runtime that at least one carrying buffer is non-NULL and to retrieve shared shape values through cascaded if_then_else expressions.
- Removed deprecated code and streamlined the binding process for clarity and maintainability.
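A rough sketch of the two-pass structure, written as ordinary Python over concrete tensors. Assumptions: `buffers` is a list of `(name, shape, tensor_or_None)` tuples and symbolic dimensions are represented as strings; the real implementation records `(buffer_name, dim_idx, handle)` sources and emits TIR assertions and Let bindings instead of computing Python values:

```python
# Simplified two-pass sketch of shared-shape resolution; the data layout here
# is hypothetical and does not reflect the ArgBinder API.
def bind_shapes(buffers):
    # Pass 1: map each symbolic dimension to the (buffer, dim index, tensor)
    # triples that carry it.
    shape_var_sources = {}
    for name, shape, tensor in buffers:
        for dim_idx, dim in enumerate(shape):
            if isinstance(dim, str):  # symbolic dimension such as "m"
                shape_var_sources.setdefault(dim, []).append((name, dim_idx, tensor))
    # Pass 2: for each symbolic dimension, require at least one non-null
    # carrier and read the value from the first one (cascaded selection).
    bound = {}
    for var, sources in shape_var_sources.items():
        non_null = [(name, k, t) for name, k, t in sources if t is not None]
        assert non_null, f"shape variable {var} requires at least one non-null buffer"
        _, dim_idx, tensor = non_null[0]
        bound[var] = tensor.shape[dim_idx]
    return bound
```

With the three `(m,)` tensors from the test added in this commit, such a sketch would resolve `m` from whichever argument is not None and reject the all-None case.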

* fix(test_nullable_buffer_params): improve formatting and consistency in test output

- Updated string formatting for better readability in the `test_nullable_shared_shape` function.
- Ensured consistent use of double quotes for string literals.
- Added a missing newline at the end of the file for proper formatting.

* refactor(arg_binder): simplify allocation size calculation in BindDLTensors

- Streamlined the calculation of allocation size by replacing a lambda function with a direct loop, enhancing readability and maintainability.
- Improved clarity in the null check message for data pointers, ensuring better understanding of the binding process.
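For reference, the allocation size used by that null check is just the element count, i.e. the product of the buffer's shape dimensions. A tiny sketch of the direct loop (plain Python, illustrative only):

```python
# Sketch of the direct-loop size computation; the data-pointer NULL check is
# skipped when this product is zero or when the DLTensor handle itself is NULL.
def alloc_size(shape):
    product = 1
    for dim in shape:
        product *= dim
    return product
```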

* Remove debug prints from phase.py

Removed debug print statements after MakePackedAPI transformation.
parent f0672603
@@ -311,446 +311,618 @@ inline PrimExpr TVMArrayGet(DataType t, Var arr,
return TVMStructGet(t, arr, 0, kind);
}
void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
const PrimExpr &device_id, const Var &handle,
const std::string &arg_name, bool is_used) {
const DataType tvm_shape_type = DataType::ShapeIndex();
const DataType tvm_ndim_type = DataType::Int(32);
const Stmt nop = Evaluate(0);
void ArgBinder::BindDLTensors(
const std::vector<std::pair<Var, Buffer>> &buffer_def,
const PrimExpr &device_type, const PrimExpr &device_id,
const std::string &func_name,
const std::unordered_set<const VarNode *> &used_param_buffers) {
ffi::Array<Buffer> buffers;
ffi::Array<Var> handles;
// First pass: collect shape var -> list of (buffer_name, dim_idx, handle_ptr)
struct ShapeVarSource {
std::string buf_name;
size_t dim_idx;
const VarNode *handle_ptr; // Raw pointer to check used_param_buffers
};
std::unordered_map<const VarNode *, std::vector<ShapeVarSource>>
shape_var_sources;
for (const auto &[handle, buffer] : buffer_def) {
std::string arg_name = func_name + "." + buffer->data->name_hint;
// Scan buffer shape for symbolic variables
for (size_t k = 0; k < buffer->shape.size(); ++k) {
if (buffer->dtype == DataType::Int(4) ||
buffer->dtype == DataType::UInt(4) ||
buffer->dtype == DataType::Int(1)) {
break;
}
// Allow NULL DLTensor* for optional inputs. When the handle is NULL,
// avoid dereferencing it by using expression-level conditionals and
// short-circuiting guards in asserts. Cache the null check in a Let-bound
// boolean so codegen does not repeat `(handle == NULL)` everywhere.
Var is_null_var(arg_name + "_is_null", DataType::Bool());
init_nest_.emplace_back(
LetStmt(is_null_var,
Call(DataType::Bool(), builtin::isnullptr(), {handle}), nop));
const PrimExpr &is_null = is_used ? const_false() : is_null_var;
if (is_used) {
init_nest_.emplace_back(AssertStmt(
!is_null_var,
tvm::tir::StringImm(arg_name + " is expected to have non-NULL pointer"),
nop));
if (const VarNode *v = buffer->shape[k].as<VarNode>()) {
// This dimension is a symbolic variable
shape_var_sources[v].push_back({arg_name, k, handle.get()});
}
}
}
// dimension checks
PrimExpr v_ndim = TVMArrayGet(tvm_ndim_type, handle, builtin::kArrNDim);
// Second pass: Create is_null vars and shape buffers for all buffers first
std::unordered_map<std::string, Var> is_null_map;
std::unordered_map<std::string, Buffer> shape_buffer_map;
std::unordered_map<std::string, PrimExpr>
is_null_expr_map; // arg_name -> is_null expression (const_false for used
// buffers)
// Helper functions for shape/stride name formatting
auto shape_handle_name = [&]() { return arg_name + ".shape"; };
auto stride_handle_name = [&]() { return arg_name + ".strides"; };
auto array_element_name = [&](const std::string &arr_name, size_t k) {
std::stringstream ss;
ss << arr_name << '[' << k << ']';
return ss.str();
};
auto shape_element_name = [&](size_t k) {
return array_element_name(shape_handle_name(), k);
};
auto stride_element_name = [&](size_t k) {
return array_element_name(stride_handle_name(), k);
};
const DataType tvm_shape_type = DataType::ShapeIndex();
const DataType tvm_ndim_type = DataType::Int(32);
const Stmt nop = Evaluate(0);
PrimExpr a_ndim =
make_const(tvm_ndim_type, static_cast<int64_t>(buffer->shape.size()));
// Build clearer ndim message with kernel/buffer names
std::string kernel_nm = arg_name;
std::string buf_nm = arg_name;
size_t dot_pos = arg_name.find('.');
if (dot_pos != std::string::npos) {
kernel_nm = arg_name.substr(0, dot_pos);
buf_nm = arg_name.substr(dot_pos + 1);
}
// Only check ndim when handle is non-NULL: use packed error helper
PrimExpr ndim_ok = (a_ndim == v_ndim);
ffi::Array<PrimExpr> ndim_args;
ndim_args.push_back(StringImm(tvm_error_ndim_mismatch));
ndim_args.push_back(StringImm(kernel_nm));
ndim_args.push_back(StringImm(buf_nm));
ndim_args.push_back(cast(DataType::Int(64), a_ndim));
ndim_args.push_back(cast(DataType::Int(64), v_ndim));
Stmt ndim_call =
Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), ndim_args));
init_nest_.emplace_back(
SeqStmt({IfThenElse(Not(is_null), IfThenElse(Not(ndim_ok), ndim_call),
Evaluate(0)),
nop}));
// type checks
// Guard all dtype field loads by `is_null` using if_then_else
PrimExpr v_type_code = tvm::if_then_else(
Not(is_null),
TVMArrayGet(DataType::UInt(8), handle, builtin::kArrTypeCode),
IntImm(DataType::UInt(8), buffer->dtype.code()));
PrimExpr v_type_bits = tvm::if_then_else(
Not(is_null),
TVMArrayGet(DataType::UInt(8), handle, builtin::kArrTypeBits),
IntImm(DataType::UInt(8), buffer->dtype.bits()));
PrimExpr v_type_lanes = tvm::if_then_else(
Not(is_null),
TVMArrayGet(DataType::UInt(16), handle, builtin::kArrTypeLanes),
IntImm(DataType::UInt(16), buffer->dtype.lanes()));
PrimExpr expect_code = IntImm(DataType::UInt(8), buffer->dtype.code());
PrimExpr expect_bits = IntImm(DataType::UInt(8), buffer->dtype.bits());
PrimExpr expect_lanes = IntImm(DataType::UInt(16), buffer->dtype.lanes());
PrimExpr cond = (v_type_code == expect_code && v_type_bits == expect_bits &&
v_type_lanes == expect_lanes);
// Allow float8_e4m3 to match float8_e4m3fn/float8_e4m3fnuz at runtime.
if (buffer->dtype.is_float8_e4m3()) {
PrimExpr code_e4m3 = IntImm(DataType::UInt(8), DataType::kFloat8_e4m3);
PrimExpr code_e4m3fn = IntImm(DataType::UInt(8), DataType::kFloat8_e4m3fn);
PrimExpr code_e4m3fnuz =
IntImm(DataType::UInt(8), DataType::kFloat8_e4m3fnuz);
PrimExpr code_match =
(v_type_code == code_e4m3 || v_type_code == code_e4m3fn ||
v_type_code == code_e4m3fnuz);
cond = cond || (code_match && v_type_bits == expect_bits &&
v_type_lanes == expect_lanes);
}
// Allow float8_e5m2 to match float8_e5m2fnuz at runtime.
if (buffer->dtype.is_float8_e5m2()) {
PrimExpr code_e5m2 = IntImm(DataType::UInt(8), DataType::kFloat8_e5m2);
PrimExpr code_e5m2fnuz =
IntImm(DataType::UInt(8), DataType::kFloat8_e5m2fnuz);
PrimExpr code_match =
(v_type_code == code_e5m2 || v_type_code == code_e5m2fnuz);
cond = cond || (code_match && v_type_bits == expect_bits &&
v_type_lanes == expect_lanes);
}
// Allow bool to match int8/uint8 at runtime, and also kDLBool(code=6).
if (buffer->dtype.is_bool()) {
PrimExpr code_int = IntImm(DataType::UInt(8), DataType::kInt);
PrimExpr code_uint = IntImm(DataType::UInt(8), DataType::kUInt);
PrimExpr code_kdlbool = IntImm(DataType::UInt(8), 6);
PrimExpr bits8 = IntImm(DataType::UInt(8), 8);
PrimExpr bits1 = IntImm(DataType::UInt(8), 1);
PrimExpr lanes_ok = (v_type_lanes == expect_lanes);
PrimExpr int8_ok =
(v_type_code == code_int && v_type_bits == bits8 && lanes_ok);
PrimExpr uint8_ok =
(v_type_code == code_uint && v_type_bits == bits8 && lanes_ok);
// Some frontends may tag bool tensors as kDLBool(code=6), commonly with
// bits=8 or bits=1.
PrimExpr kdlbool8_ok =
(v_type_code == code_kdlbool && v_type_bits == bits8 && lanes_ok);
PrimExpr kdlbool1_ok =
(v_type_code == code_kdlbool && v_type_bits == bits1 && lanes_ok);
// Also accept any dtype whose bitwidth=1, regardless of code, to be
// defensive.
PrimExpr bit1_ok = (v_type_bits == bits1 && lanes_ok);
cond = cond || int8_ok || uint8_ok || kdlbool8_ok || kdlbool1_ok || bit1_ok;
// Create all is_null vars and shape buffers first
for (const auto &[handle, buffer] : buffer_def) {
bool is_used = used_param_buffers.count(handle.get());
std::string arg_name = func_name + "." + buffer->data->name_hint;
Var is_null_var(arg_name + "_is_null", DataType::Bool());
init_nest_.emplace_back(
LetStmt(is_null_var,
Call(DataType::Bool(), builtin::isnullptr(), {handle}), nop));
const PrimExpr &is_null = is_used ? const_false() : is_null_var;
is_null_map[arg_name] = is_null_var;
is_null_expr_map[arg_name] = is_null;
if (is_used) {
init_nest_.emplace_back(
AssertStmt(!is_null_var,
tvm::tir::StringImm(
arg_name + " is expected to have non-NULL pointer"),
nop));
}
}
// Allow float4 to match int8 at runtime (PyTorch uses int8 as storage for
// FP4).
if (buffer->dtype.is_float4()) {
PrimExpr code_int = IntImm(DataType::UInt(8), DataType::kInt);
PrimExpr bits8 = IntImm(DataType::UInt(8), 8);
// For FP4, we pack 2 elements per byte, but we still use same lanes at
// storage level Accept int8 with same lanes as the fp4 type
PrimExpr fp4_lanes_ok = (v_type_lanes == expect_lanes);
PrimExpr int8_ok =
(v_type_code == code_int && v_type_bits == bits8 && fp4_lanes_ok);
cond = cond || int8_ok;
// Create all shape buffers before binding any shapes
for (const auto &[handle, buffer] : buffer_def) {
std::string arg_name = func_name + "." + buffer->data->name_hint;
const PrimExpr &is_null = is_null_expr_map[arg_name];
// Helper functions for shape/stride name formatting
auto shape_handle_name = [&]() { return arg_name + ".shape"; };
// shape field
Buffer buf_shape =
decl_buffer({IntImm(DataType::Int(32), buffer->shape.size())},
tvm_shape_type, shape_handle_name());
def_handle_dtype_.Set(buf_shape->data, make_const(tvm_shape_type, 0));
// Use if_then_else for NULL guard on the shape pointer itself, avoiding
// dereferencing TVMStructGet(handle, kArrShape) when handle is NULL.
init_nest_.emplace_back(
LetStmt(buf_shape->data,
tvm::if_then_else(
Not(is_null),
TVMArrayGet(DataType::Handle(), handle, builtin::kArrShape),
make_zero(DataType::Handle())),
nop));
init_nest_.emplace_back(DeclBuffer(buf_shape, nop));
// Save for later use in shape binding
shape_buffer_map[arg_name] = buf_shape;
}
if (!(buffer->dtype == DataType::Int(1) ||
buffer->dtype == DataType::Int(4) ||
buffer->dtype == DataType::UInt(4) || buffer->dtype.is_float4())) {
// Build FFI packed call to __tvm_error_dtype_mismatch when mismatch occurs.
// Only issue the call when handle is non-NULL and cond is false.
ffi::Array<PrimExpr> packed_args;
packed_args.push_back(StringImm(tvm_error_dtype_mismatch));
// Split arg_name of the form "<kernel>.<buffer>" into parts for clearer
// diagnostics
std::string kernel_name = arg_name;
std::string buffer_name = arg_name;
// Now process each buffer fully
for (const auto &[handle, buffer] : buffer_def) {
bool is_used = used_param_buffers.count(handle.get());
std::string arg_name = func_name + "." + buffer->data->name_hint;
const PrimExpr &is_null = is_null_expr_map[arg_name];
// dimension checks
PrimExpr v_ndim = TVMArrayGet(tvm_ndim_type, handle, builtin::kArrNDim);
// Helper functions for shape/stride name formatting
auto shape_handle_name = [&]() { return arg_name + ".shape"; };
auto stride_handle_name = [&]() { return arg_name + ".strides"; };
auto array_element_name = [&](const std::string &arr_name, size_t k) {
std::stringstream ss;
ss << arr_name << '[' << k << ']';
return ss.str();
};
auto shape_element_name = [&](size_t k) {
return array_element_name(shape_handle_name(), k);
};
auto stride_element_name = [&](size_t k) {
return array_element_name(stride_handle_name(), k);
};
PrimExpr a_ndim =
make_const(tvm_ndim_type, static_cast<int64_t>(buffer->shape.size()));
// Build clearer ndim message with kernel/buffer names
std::string kernel_nm = arg_name;
std::string buf_nm = arg_name;
size_t dot_pos = arg_name.find('.');
if (dot_pos != std::string::npos) {
kernel_name = arg_name.substr(0, dot_pos);
buffer_name = arg_name.substr(dot_pos + 1);
kernel_nm = arg_name.substr(0, dot_pos);
buf_nm = arg_name.substr(dot_pos + 1);
}
packed_args.push_back(StringImm(kernel_name));
packed_args.push_back(StringImm(buffer_name));
auto i64 = DataType::Int(64);
// Cast to int64 for FFI function signature
packed_args.push_back(cast(i64, v_type_code)); // actual_code
packed_args.push_back(cast(i64, v_type_bits)); // actual_bits
packed_args.push_back(cast(i64, v_type_lanes)); // actual_lanes
packed_args.push_back(cast(i64, expect_code)); // expect_code
packed_args.push_back(cast(i64, expect_bits)); // expect_bits
packed_args.push_back(cast(i64, expect_lanes)); // expect_lanes
Stmt call_err = Evaluate(
Call(DataType::Int(32), builtin::tvm_call_packed(), packed_args));
// Guard the call: only when handle is not null and cond fails
Stmt guarded = IfThenElse(Not(is_null) && Not(cond), call_err);
asserts_.emplace_back(SeqStmt({guarded, nop}));
}
// shape field
Buffer buf_shape =
decl_buffer({IntImm(DataType::Int(32), buffer->shape.size())},
tvm_shape_type, shape_handle_name());
Var v_shape(shape_handle_name(), DataType::Handle());
def_handle_dtype_.Set(v_shape, make_const(tvm_shape_type, 0));
// Use if_then_else for NULL guard on the shape pointer itself, avoiding
// dereferencing TVMStructGet(handle, kArrShape) when handle is NULL.
init_nest_.emplace_back(
LetStmt(buf_shape->data,
tvm::if_then_else(
Not(is_null),
TVMArrayGet(DataType::Handle(), handle, builtin::kArrShape),
make_zero(DataType::Handle())),
nop));
init_nest_.emplace_back(DeclBuffer(buf_shape, nop));
for (size_t k = 0; k < buffer->shape.size(); ++k) {
// These packed-bit dtype shapes were not bound in the original
// implementation, so we just use them as is.
if (buffer->dtype == DataType::Int(4) ||
buffer->dtype == DataType::UInt(4) ||
buffer->dtype == DataType::Int(1)) {
break;
// Only check ndim when handle is non-NULL: use packed error helper
PrimExpr ndim_ok = (a_ndim == v_ndim);
ffi::Array<PrimExpr> ndim_args;
ndim_args.push_back(StringImm(tvm_error_ndim_mismatch));
ndim_args.push_back(StringImm(kernel_nm));
ndim_args.push_back(StringImm(buf_nm));
ndim_args.push_back(cast(DataType::Int(64), a_ndim));
ndim_args.push_back(cast(DataType::Int(64), v_ndim));
Stmt ndim_call = Evaluate(
Call(DataType::Int(32), builtin::tvm_call_packed(), ndim_args));
init_nest_.emplace_back(
SeqStmt({IfThenElse(Not(is_null), IfThenElse(Not(ndim_ok), ndim_call),
Evaluate(0)),
nop}));
// type checks
// Guard all dtype field loads by `is_null` using if_then_else
PrimExpr v_type_code = tvm::if_then_else(
Not(is_null),
TVMArrayGet(DataType::UInt(8), handle, builtin::kArrTypeCode),
IntImm(DataType::UInt(8), buffer->dtype.code()));
PrimExpr v_type_bits = tvm::if_then_else(
Not(is_null),
TVMArrayGet(DataType::UInt(8), handle, builtin::kArrTypeBits),
IntImm(DataType::UInt(8), buffer->dtype.bits()));
PrimExpr v_type_lanes = tvm::if_then_else(
Not(is_null),
TVMArrayGet(DataType::UInt(16), handle, builtin::kArrTypeLanes),
IntImm(DataType::UInt(16), buffer->dtype.lanes()));
PrimExpr expect_code = IntImm(DataType::UInt(8), buffer->dtype.code());
PrimExpr expect_bits = IntImm(DataType::UInt(8), buffer->dtype.bits());
PrimExpr expect_lanes = IntImm(DataType::UInt(16), buffer->dtype.lanes());
PrimExpr cond = (v_type_code == expect_code && v_type_bits == expect_bits &&
v_type_lanes == expect_lanes);
// Allow float8_e4m3 to match float8_e4m3fn/float8_e4m3fnuz at runtime.
if (buffer->dtype.is_float8_e4m3()) {
PrimExpr code_e4m3 = IntImm(DataType::UInt(8), DataType::kFloat8_e4m3);
PrimExpr code_e4m3fn =
IntImm(DataType::UInt(8), DataType::kFloat8_e4m3fn);
PrimExpr code_e4m3fnuz =
IntImm(DataType::UInt(8), DataType::kFloat8_e4m3fnuz);
PrimExpr code_match =
(v_type_code == code_e4m3 || v_type_code == code_e4m3fn ||
v_type_code == code_e4m3fnuz);
cond = cond || (code_match && v_type_bits == expect_bits &&
v_type_lanes == expect_lanes);
}
// The "real" runtime shape value read from DLTensor
PrimExpr shape_val =
cast(buffer->shape[k].dtype(),
BufferLoad(buf_shape,
{IntImm(DataType::Int(32), static_cast<int>(k))}));
// When first encountering a Var (e.g., m), this will generate:
// Let(m, bound_shape_val, ...)
// Constant dimensions will only generate consistency assertions.
BindNullable(buffer->shape[k], shape_val, shape_element_name(k), true,
is_null);
}
// strides field
Buffer buf_strides =
decl_buffer({IntImm(DataType::Int(32), buffer->strides.size())},
tvm_shape_type, arg_name + ".strides");
def_handle_dtype_.Set(buf_strides->data, tir::TypeAnnotation(tvm_shape_type));
init_nest_.emplace_back(
LetStmt(buf_strides->data,
tvm::if_then_else(
Not(is_null),
TVMArrayGet(DataType::Handle(), handle, builtin::kArrStrides),
make_zero(DataType::Handle())),
nop));
init_nest_.emplace_back(DeclBuffer(buf_strides, nop));
PrimExpr v_strides_is_null =
Call(DataType::Bool(1), builtin::isnullptr(), {buf_strides->data});
if (buffer->strides.empty()) {
// Assert the buffer is compact
DataType stype = buffer->DefaultIndexType();
PrimExpr expect_stride = make_const(stype, 1);
ffi::Array<PrimExpr> conds;
for (size_t i = buffer->shape.size(); i != 0; --i) {
size_t k = i - 1;
PrimExpr svalue = cast(
stype, BufferLoad(buf_strides,
{IntImm(DataType::Int(32), static_cast<int>(k))}));
conds.push_back(buffer->shape[k] == 1 || expect_stride == svalue);
expect_stride = expect_stride * buffer->shape[k];
// Allow float8_e5m2 to match float8_e5m2fnuz at runtime.
if (buffer->dtype.is_float8_e5m2()) {
PrimExpr code_e5m2 = IntImm(DataType::UInt(8), DataType::kFloat8_e5m2);
PrimExpr code_e5m2fnuz =
IntImm(DataType::UInt(8), DataType::kFloat8_e5m2fnuz);
PrimExpr code_match =
(v_type_code == code_e5m2 || v_type_code == code_e5m2fnuz);
cond = cond || (code_match && v_type_bits == expect_bits &&
v_type_lanes == expect_lanes);
}
// Allow bool to match int8/uint8 at runtime, and also kDLBool(code=6).
if (buffer->dtype.is_bool()) {
PrimExpr code_int = IntImm(DataType::UInt(8), DataType::kInt);
PrimExpr code_uint = IntImm(DataType::UInt(8), DataType::kUInt);
PrimExpr code_kdlbool = IntImm(DataType::UInt(8), 6);
PrimExpr bits8 = IntImm(DataType::UInt(8), 8);
PrimExpr bits1 = IntImm(DataType::UInt(8), 1);
PrimExpr lanes_ok = (v_type_lanes == expect_lanes);
PrimExpr int8_ok =
(v_type_code == code_int && v_type_bits == bits8 && lanes_ok);
PrimExpr uint8_ok =
(v_type_code == code_uint && v_type_bits == bits8 && lanes_ok);
// Some frontends may tag bool tensors as kDLBool(code=6), commonly with
// bits=8 or bits=1.
PrimExpr kdlbool8_ok =
(v_type_code == code_kdlbool && v_type_bits == bits8 && lanes_ok);
PrimExpr kdlbool1_ok =
(v_type_code == code_kdlbool && v_type_bits == bits1 && lanes_ok);
// Also accept any dtype whose bitwidth=1, regardless of code, to be
// defensive.
PrimExpr bit1_ok = (v_type_bits == bits1 && lanes_ok);
cond =
cond || int8_ok || uint8_ok || kdlbool8_ok || kdlbool1_ok || bit1_ok;
}
// Allow float4 to match int8 at runtime (PyTorch uses int8 as storage for
// FP4).
if (buffer->dtype.is_float4()) {
PrimExpr code_int = IntImm(DataType::UInt(8), DataType::kInt);
PrimExpr bits8 = IntImm(DataType::UInt(8), 8);
// For FP4, we pack 2 elements per byte, but we still use same lanes at
// storage level Accept int8 with same lanes as the fp4 type
PrimExpr fp4_lanes_ok = (v_type_lanes == expect_lanes);
PrimExpr int8_ok =
(v_type_code == code_int && v_type_bits == bits8 && fp4_lanes_ok);
cond = cond || int8_ok;
}
std::ostringstream stride_err_msg;
stride_err_msg
<< stride_handle_name()
<< ": expected to be compact array, but got non-compact strides";
if (!conds.empty()) {
PrimExpr all_ok = foldl([](PrimExpr a, PrimExpr b,
Span span) { return logical_and(a, b, span); },
const_true(1), conds);
// Packed generic violation for non-compact strides
std::string kernel_nm3 = arg_name;
std::string buf_nm3 = arg_name;
size_t dot_pos3 = arg_name.find('.');
if (dot_pos3 != std::string::npos) {
kernel_nm3 = arg_name.substr(0, dot_pos3);
buf_nm3 = arg_name.substr(dot_pos3 + 1);
if (!(buffer->dtype == DataType::Int(1) ||
buffer->dtype == DataType::Int(4) ||
buffer->dtype == DataType::UInt(4) || buffer->dtype.is_float4())) {
// Build FFI packed call to __tvm_error_dtype_mismatch when mismatch
// occurs. Only issue the call when handle is non-NULL and cond is false.
ffi::Array<PrimExpr> packed_args;
packed_args.push_back(StringImm(tvm_error_dtype_mismatch));
// Split arg_name of the form "<kernel>.<buffer>" into parts for clearer
// diagnostics
std::string kernel_name = arg_name;
std::string buffer_name = arg_name;
size_t dot_pos = arg_name.find('.');
if (dot_pos != std::string::npos) {
kernel_name = arg_name.substr(0, dot_pos);
buffer_name = arg_name.substr(dot_pos + 1);
}
ffi::Array<PrimExpr> pargs4;
pargs4.push_back(StringImm(tvm_error_constraint_violation));
pargs4.push_back(StringImm(kernel_nm3));
pargs4.push_back(StringImm(buf_nm3));
pargs4.push_back(StringImm("strides"));
Stmt call_err4 =
Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), pargs4));
// Only check when strides array is present and condition fails
Stmt check = IfThenElse(Not(v_strides_is_null),
IfThenElse(Not(all_ok), call_err4), Evaluate(0));
asserts_.emplace_back(SeqStmt({check, Evaluate(0)}));
packed_args.push_back(StringImm(kernel_name));
packed_args.push_back(StringImm(buffer_name));
auto i64 = DataType::Int(64);
// Cast to int64 for FFI function signature
packed_args.push_back(cast(i64, v_type_code)); // actual_code
packed_args.push_back(cast(i64, v_type_bits)); // actual_bits
packed_args.push_back(cast(i64, v_type_lanes)); // actual_lanes
packed_args.push_back(cast(i64, expect_code)); // expect_code
packed_args.push_back(cast(i64, expect_bits)); // expect_bits
packed_args.push_back(cast(i64, expect_lanes)); // expect_lanes
Stmt call_err = Evaluate(
Call(DataType::Int(32), builtin::tvm_call_packed(), packed_args));
// Guard the call: only when handle is not null and cond fails
Stmt guarded = IfThenElse(Not(is_null) && Not(cond), call_err);
asserts_.emplace_back(SeqStmt({guarded, nop}));
}
} else if (buffer->buffer_type == kAutoBroadcast) {
PrimExpr stride_from_shape = 1;
for (size_t i = buffer->shape.size(); i != 0; --i) {
size_t k = i - 1;
DataType stride_dtype = buffer->strides[k].dtype();
PrimExpr explicit_stride =
cast(stride_dtype,
BufferLoad(buf_strides,
{IntImm(DataType::Int(32), static_cast<int>(k))}));
PrimExpr stride_val = tvm::if_then_else(
v_strides_is_null, stride_from_shape, explicit_stride);
// Get the pre-created shape buffer
Buffer buf_shape = shape_buffer_map[arg_name];
// Bind symbolic variables from buffer shape
for (size_t k = 0; k < buffer->shape.size(); ++k) {
// These packed-bit dtype shapes were not bound in the original
// implementation, so we just use them as is.
if (buffer->dtype == DataType::Int(4) ||
buffer->dtype == DataType::UInt(4) ||
buffer->dtype == DataType::Int(1)) {
break;
}
BindNullable(buffer->strides[k], stride_val, stride_element_name(k), true,
is_null);
// The "real" runtime shape value read from DLTensor
PrimExpr shape_val =
cast(buffer->shape[k].dtype(),
BufferLoad(buf_shape,
{IntImm(DataType::Int(32), static_cast<int>(k))}));
// Check if this dimension is a symbolic variable
if (const VarNode *v = buffer->shape[k].as<VarNode>()) {
auto it = def_map_->find(v);
if (it == def_map_->end()) {
// First time binding this symbolic variable
auto sources_it = shape_var_sources.find(v);
if (sources_it != shape_var_sources.end() &&
sources_it->second.size() > 1) {
// This variable appears in multiple buffers
// Assert that at least one buffer is non-null
PrimExpr any_nonnull = const_false();
for (const auto &src : sources_it->second) {
bool buf_is_used = used_param_buffers.count(src.handle_ptr);
if (buf_is_used) {
any_nonnull = const_true();
break;
}
Var src_is_null = is_null_map[src.buf_name];
any_nonnull = Or(any_nonnull, Not(src_is_null));
}
std::ostringstream err_msg;
err_msg << "Symbolic shape variable "
<< ffi::GetRef<Var>(v)->name_hint
<< " requires at least one non-null buffer among: ";
bool first = true;
for (const auto &src : sources_it->second) {
if (!first)
err_msg << ", ";
err_msg << src.buf_name;
first = false;
}
init_nest_.emplace_back(AssertStmt(
any_nonnull, tvm::tir::StringImm(err_msg.str()), nop));
// Build cascaded if_then_else: if !is_null_a then a.shape[k] else
// if !is_null_b then b.shape[k] ... We need to construct this in
// reverse order
PrimExpr cascaded_value;
bool is_first_source = true;
for (auto rit = sources_it->second.rbegin();
rit != sources_it->second.rend(); ++rit) {
const auto &src = *rit;
// Get the shape buffer for this source
auto it_buf = shape_buffer_map.find(src.buf_name);
if (it_buf == shape_buffer_map.end()) {
LOG(FATAL) << "Shape buffer not found for " << src.buf_name;
}
Buffer src_shape_buf = it_buf->second;
// Construct the shape load
PrimExpr src_shape_val =
cast(buffer->shape[k].dtype(),
BufferLoad(src_shape_buf,
{IntImm(DataType::Int(32),
static_cast<int>(src.dim_idx))}));
// Check if this buffer is used (non-nullable)
bool src_is_used = used_param_buffers.count(src.handle_ptr);
if (is_first_source) {
// Base case: use this shape value directly (we know at least
// one is non-null from assert)
cascaded_value = src_shape_val;
is_first_source = false;
} else {
// if !is_null then use this shape, else use previous cascaded
// value But if buffer is used (non-nullable), always use its
// shape
if (src_is_used) {
cascaded_value = src_shape_val;
} else {
Var src_is_null = is_null_map[src.buf_name];
cascaded_value = tvm::if_then_else(
Not(src_is_null), src_shape_val, cascaded_value);
}
}
}
// Bind the variable to the cascaded expression
Var v_arg = ffi::GetRef<Var>(v);
defs_.emplace_back(v_arg);
(*def_map_)[v] = cascaded_value;
init_nest_.emplace_back(
LetStmt(v_arg, cascaded_value, Evaluate(0)));
} else {
// Single source or no special handling needed, use the original
// nullable binding
BindNullable(buffer->shape[k], shape_val, shape_element_name(k),
true, is_null);
}
} else {
// Variable already bound, add assertion with nullable guard
PrimExpr cond = (it->second == shape_val);
BinderAddAssert(&analyzer_, cond, shape_element_name(k), &asserts_,
is_null);
}
} else {
// Constant dimension, just add assertion
BindNullable(buffer->shape[k], shape_val, shape_element_name(k), true,
is_null);
}
}
} else {
PrimExpr stride_from_shape = 1;
for (int k = static_cast<int>(buffer->strides.size()) - 1; k >= 0; --k) {
DataType stride_dtype = buffer->strides[k].dtype();
PrimExpr explicit_stride =
cast(stride_dtype,
BufferLoad(buf_strides, {IntImm(DataType::Int(32), k)}));
PrimExpr shape_stride = cast(
stride_dtype, BufferLoad(buf_shape, {IntImm(DataType::Int(32), k)}));
// strides field
Buffer buf_strides =
decl_buffer({IntImm(DataType::Int(32), buffer->strides.size())},
tvm_shape_type, arg_name + ".strides");
def_handle_dtype_.Set(buf_strides->data,
tir::TypeAnnotation(tvm_shape_type));
init_nest_.emplace_back(
LetStmt(buf_strides->data,
tvm::if_then_else(Not(is_null),
TVMArrayGet(DataType::Handle(), handle,
builtin::kArrStrides),
make_zero(DataType::Handle())),
nop));
init_nest_.emplace_back(DeclBuffer(buf_strides, nop));
PrimExpr v_strides_is_null =
Call(DataType::Bool(1), builtin::isnullptr(), {buf_strides->data});
if (buffer->strides.empty()) {
// Assert the buffer is compact
DataType stype = buffer->DefaultIndexType();
PrimExpr expect_stride = make_const(stype, 1);
ffi::Array<PrimExpr> conds;
for (size_t i = buffer->shape.size(); i != 0; --i) {
size_t k = i - 1;
PrimExpr svalue =
cast(stype, BufferLoad(buf_strides, {IntImm(DataType::Int(32),
static_cast<int>(k))}));
conds.push_back(buffer->shape[k] == 1 || expect_stride == svalue);
expect_stride = expect_stride * buffer->shape[k];
}
std::ostringstream stride_err_msg;
stride_err_msg
<< stride_handle_name()
<< ": expected to be compact array, but got non-compact strides";
if (!conds.empty()) {
PrimExpr all_ok =
foldl([](PrimExpr a, PrimExpr b,
Span span) { return logical_and(a, b, span); },
const_true(1), conds);
// Packed generic violation for non-compact strides
std::string kernel_nm3 = arg_name;
std::string buf_nm3 = arg_name;
size_t dot_pos3 = arg_name.find('.');
if (dot_pos3 != std::string::npos) {
kernel_nm3 = arg_name.substr(0, dot_pos3);
buf_nm3 = arg_name.substr(dot_pos3 + 1);
}
ffi::Array<PrimExpr> pargs4;
pargs4.push_back(StringImm(tvm_error_constraint_violation));
pargs4.push_back(StringImm(kernel_nm3));
pargs4.push_back(StringImm(buf_nm3));
pargs4.push_back(StringImm("strides"));
Stmt call_err4 = Evaluate(
Call(DataType::Int(32), builtin::tvm_call_packed(), pargs4));
// Only check when strides array is present and condition fails
Stmt check =
IfThenElse(Not(v_strides_is_null),
IfThenElse(Not(all_ok), call_err4), Evaluate(0));
asserts_.emplace_back(SeqStmt({check, Evaluate(0)}));
}
} else if (buffer->buffer_type == kAutoBroadcast) {
PrimExpr stride_from_shape = 1;
for (size_t i = buffer->shape.size(); i != 0; --i) {
size_t k = i - 1;
DataType stride_dtype = buffer->strides[k].dtype();
PrimExpr explicit_stride =
cast(stride_dtype,
BufferLoad(buf_strides,
{IntImm(DataType::Int(32), static_cast<int>(k))}));
PrimExpr stride_val = tvm::if_then_else(
v_strides_is_null, stride_from_shape, explicit_stride);
PrimExpr stride_val = tvm::if_then_else(
v_strides_is_null, stride_from_shape, explicit_stride);
BindNullable(buffer->strides[k], stride_val, stride_element_name(k), true,
is_null);
BindNullable(buffer->strides[k], stride_val, stride_element_name(k),
true, is_null);
}
} else {
PrimExpr stride_from_shape = 1;
for (int k = static_cast<int>(buffer->strides.size()) - 1; k >= 0; --k) {
DataType stride_dtype = buffer->strides[k].dtype();
PrimExpr explicit_stride =
cast(stride_dtype,
BufferLoad(buf_strides, {IntImm(DataType::Int(32), k)}));
PrimExpr shape_stride =
cast(stride_dtype,
BufferLoad(buf_shape, {IntImm(DataType::Int(32), k)}));
PrimExpr stride_val = tvm::if_then_else(
v_strides_is_null, stride_from_shape, explicit_stride);
BindNullable(buffer->strides[k], stride_val, stride_element_name(k),
true, is_null);
}
}
}
// Byte_offset field.
int data_bytes = GetVectorBytes(buffer->dtype);
// Byte_offset field.
int data_bytes = GetVectorBytes(buffer->dtype);
if (const auto *const_offset = buffer->elem_offset.as<IntImmNode>()) {
// Constant elem_offset: only need consistency check, no need for
// additional Var binding.
PrimExpr actual_byte_offset = tvm::if_then_else(
Not(is_null),
TVMArrayGet(DataType::UInt(64), handle, builtin::kArrByteOffset),
make_const(DataType::UInt(64), 0));
PrimExpr expect_byte_offset =
make_const(DataType::UInt(64), const_offset->value * data_bytes);
PrimExpr ok = (expect_byte_offset == actual_byte_offset);
ffi::Array<PrimExpr> pargs;
pargs.push_back(StringImm(tvm_error_byte_offset_mismatch));
pargs.push_back(StringImm(kernel_nm));
pargs.push_back(StringImm(buf_nm));
pargs.push_back(cast(DataType::Int(64), expect_byte_offset));
pargs.push_back(cast(DataType::Int(64), actual_byte_offset));
Stmt call_err =
Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), pargs));
asserts_.emplace_back(SeqStmt(
{IfThenElse(Not(is_null), IfThenElse(Not(ok), call_err), Evaluate(0)),
nop}));
} else {
PrimExpr actual_byte_offset = tvm::if_then_else(
Not(is_null),
TVMArrayGet(DataType::UInt(64), handle, builtin::kArrByteOffset),
make_const(DataType::UInt(64), 0));
PrimExpr expect_elem_off = cast(
buffer->elem_offset.dtype(),
(actual_byte_offset / make_const(DataType::UInt(64), data_bytes)));
BindNullable(buffer->elem_offset, expect_elem_off,
arg_name + ".elem_offset", true, is_null);
if (buffer->offset_factor > 1) {
PrimExpr offset = buffer->elem_offset;
PrimExpr factor = make_const(offset.dtype(), buffer->offset_factor);
PrimExpr zero = make_zero(offset.dtype());
BindNullable(offset, truncmod(offset, factor),
arg_name + ".elem_offset", true, is_null);
}
}
if (const auto *const_offset = buffer->elem_offset.as<IntImmNode>()) {
// Constant elem_offset: only need consistency check, no need for additional
// Var binding.
PrimExpr actual_byte_offset = tvm::if_then_else(
// device info.
// Define device_id from handle when available (so later passes can use it)
PrimExpr actual_dev_type = tvm::if_then_else(
Not(is_null),
TVMArrayGet(DataType::UInt(64), handle, builtin::kArrByteOffset),
make_const(DataType::UInt(64), 0));
PrimExpr expect_byte_offset =
make_const(DataType::UInt(64), const_offset->value * data_bytes);
PrimExpr ok = (expect_byte_offset == actual_byte_offset);
ffi::Array<PrimExpr> pargs;
pargs.push_back(StringImm(tvm_error_byte_offset_mismatch));
pargs.push_back(StringImm(kernel_nm));
pargs.push_back(StringImm(buf_nm));
pargs.push_back(cast(DataType::Int(64), expect_byte_offset));
pargs.push_back(cast(DataType::Int(64), actual_byte_offset));
Stmt call_err =
Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), pargs));
asserts_.emplace_back(SeqStmt(
{IfThenElse(Not(is_null), IfThenElse(Not(ok), call_err), Evaluate(0)),
nop}));
} else {
PrimExpr actual_byte_offset = tvm::if_then_else(
TVMArrayGet(DataType::Int(32), handle, builtin::kArrDeviceType),
make_zero(DataType::Int(32)));
PrimExpr actual_dev_id = tvm::if_then_else(
Not(is_null),
TVMArrayGet(DataType::UInt(64), handle, builtin::kArrByteOffset),
make_const(DataType::UInt(64), 0));
PrimExpr expect_elem_off =
cast(buffer->elem_offset.dtype(),
(actual_byte_offset / make_const(DataType::UInt(64), data_bytes)));
BindNullable(buffer->elem_offset, expect_elem_off,
arg_name + ".elem_offset", true, is_null);
if (buffer->offset_factor > 1) {
PrimExpr offset = buffer->elem_offset;
PrimExpr factor = make_const(offset.dtype(), buffer->offset_factor);
PrimExpr zero = make_zero(offset.dtype());
BindNullable(offset, truncmod(offset, factor), arg_name + ".elem_offset",
true, is_null);
}
}
// device info.
// Define device_id from handle when available (so later passes can use it)
PrimExpr actual_dev_type = tvm::if_then_else(
Not(is_null),
TVMArrayGet(DataType::Int(32), handle, builtin::kArrDeviceType),
make_zero(DataType::Int(32)));
PrimExpr actual_dev_id = tvm::if_then_else(
Not(is_null),
TVMArrayGet(DataType::Int(32), handle, builtin::kArrDeviceId),
make_zero(DataType::Int(32)));
// Bind device_id to a safe expression (0 when NULL handle)
BindNullable(device_id, actual_dev_id, arg_name + ".device_id", true,
is_null);
// Check device_type consistency (device_id equality is implicitly ensured by
// binding above)
{
PrimExpr ok = (device_type == actual_dev_type);
ffi::Array<PrimExpr> pargs2;
pargs2.push_back(StringImm(tvm_error_device_type_mismatch));
pargs2.push_back(StringImm(kernel_nm));
pargs2.push_back(StringImm(buf_nm));
pargs2.push_back(cast(DataType::Int(64), device_type));
pargs2.push_back(cast(DataType::Int(64), actual_dev_type));
Stmt call_err2 =
Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), pargs2));
asserts_.emplace_back(SeqStmt(
{IfThenElse(Not(is_null), IfThenElse(Not(ok), call_err2), Evaluate(0)),
Evaluate(0)}));
}
TVMArrayGet(DataType::Int(32), handle, builtin::kArrDeviceId),
make_zero(DataType::Int(32)));
// Data field. Because the validation of the data field may depend
// on a dynamic size defined by the other DLTensor* parameters, this
// field must be generated last.
// Bind data pointer using expression-level guard to avoid deref on NULL.
{
Var vptr(buffer->data);
PrimExpr data_ptr = tvm::if_then_else(
Not(is_null),
TVMArrayGet(DataType::Handle(), handle, builtin::kArrData),
make_zero(DataType::Handle()));
BindNullable(buffer->data, data_ptr, arg_name + ".data", true, is_null);
// Check if the data pointer is NULL. This check is skipped for
// size-0 arrays and also skipped when handle itself is NULL.
auto alloc_size = [&]() -> PrimExpr {
PrimExpr product = IntImm(buffer->DefaultIndexType(), 1);
for (const auto &dim : buffer->shape)
product *= dim;
return product;
}();
// Improve message: kernel/buffer naming for data pointer null check
std::string kernel_nm2 = arg_name;
std::string buf_nm2 = arg_name;
size_t dot_pos2 = arg_name.find('.');
if (dot_pos2 != std::string::npos) {
kernel_nm2 = arg_name.substr(0, dot_pos2);
buf_nm2 = arg_name.substr(dot_pos2 + 1);
// Bind device_id to a safe expression (0 when NULL handle)
BindNullable(device_id, actual_dev_id, arg_name + ".device_id", true,
is_null);
// Check device_type consistency (device_id equality is implicitly ensured
// by binding above)
{
PrimExpr ok = (device_type == actual_dev_type);
ffi::Array<PrimExpr> pargs2;
pargs2.push_back(StringImm(tvm_error_device_type_mismatch));
pargs2.push_back(StringImm(kernel_nm));
pargs2.push_back(StringImm(buf_nm));
pargs2.push_back(cast(DataType::Int(64), device_type));
pargs2.push_back(cast(DataType::Int(64), actual_dev_type));
Stmt call_err2 =
Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), pargs2));
asserts_.emplace_back(
SeqStmt({IfThenElse(Not(is_null), IfThenElse(Not(ok), call_err2),
Evaluate(0)),
Evaluate(0)}));
}
// expand combined condition via nested IfThenElse for portability
ffi::Array<PrimExpr> pargs3;
pargs3.push_back(StringImm(tvm_error_null_ptr));
pargs3.push_back(StringImm(kernel_nm2));
pargs3.push_back(StringImm(buf_nm2));
pargs3.push_back(StringImm("data pointer"));
Stmt call_err3 =
Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), pargs3));
asserts_.emplace_back(SeqStmt(
{IfThenElse(Not(is_null),
IfThenElse(Not(alloc_size == 0),
IfThenElse(Call(DataType::Bool(),
builtin::isnullptr(), {vptr}),
call_err3),
Evaluate(0)),
Evaluate(0)),
nop}));
// mark alignment of external bufs
init_nest_.emplace_back(
AttrStmt(vptr, tir::attr::storage_alignment,
IntImm(DataType::Int(32), buffer->data_alignment), nop));
def_handle_dtype_.Set(vptr, tir::TypeAnnotation(buffer->dtype));
// Data field. Because the validation of the data field may depend
// on a dynamic size defined by the other DLTensor* parameters, this
// field must be generated last.
// Bind data pointer using expression-level guard to avoid deref on NULL.
{
Var vptr(buffer->data);
PrimExpr data_ptr = tvm::if_then_else(
Not(is_null),
TVMArrayGet(DataType::Handle(), handle, builtin::kArrData),
make_zero(DataType::Handle()));
BindNullable(buffer->data, data_ptr, arg_name + ".data", true, is_null);
// Check if the data pointer is NULL. This check is skipped for
// size-0 arrays and also skipped when handle itself is NULL.
PrimExpr alloc_size = IntImm(buffer->DefaultIndexType(), 1);
for (const auto &dim : buffer->shape) {
alloc_size = alloc_size * dim;
}
// Improve message: kernel/buffer naming for data pointer null check
std::string kernel_nm2 = arg_name;
std::string buf_nm2 = arg_name;
size_t dot_pos2 = arg_name.find('.');
if (dot_pos2 != std::string::npos) {
kernel_nm2 = arg_name.substr(0, dot_pos2);
buf_nm2 = arg_name.substr(dot_pos2 + 1);
}
// expand combined condition via nested IfThenElse for portability
ffi::Array<PrimExpr> pargs3;
pargs3.push_back(StringImm(tvm_error_null_ptr));
pargs3.push_back(StringImm(kernel_nm2));
pargs3.push_back(StringImm(buf_nm2));
pargs3.push_back(StringImm("data pointer"));
Stmt call_err3 =
Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), pargs3));
asserts_.emplace_back(SeqStmt(
{IfThenElse(Not(is_null),
IfThenElse(Not(alloc_size == 0),
IfThenElse(Call(DataType::Bool(),
builtin::isnullptr(), {vptr}),
call_err3),
Evaluate(0)),
Evaluate(0)),
nop}));
// mark alignment of external bufs
init_nest_.emplace_back(
AttrStmt(vptr, tir::attr::storage_alignment,
IntImm(DataType::Int(32), buffer->data_alignment), nop));
def_handle_dtype_.Set(vptr, tir::TypeAnnotation(buffer->dtype));
}
}
}
} // namespace tl
} // namespace tvm
} // namespace tvm
\ No newline at end of file
@@ -95,17 +95,21 @@ public:
*/
void BindBuffer(const Buffer &arg, const Buffer &value,
const std::string &arg_name, bool fuzzy_match);
/*!
* \brief Bind symbolic buffer to a DLTensor handle.
* \param buffer The argument buffer to be binded.
* \param device_type The device id to be binded.
* \param device_type The device type to be binded.
* \param device_id The device id to be binded.
* \param handle The DLTensor handle.
* \param arg_name argument name.
* \param buffer_def The buffer definition.
* \param func_name The function name.
* \param used_param_buffers The used param buffers.
*/
void BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
const PrimExpr &device_id, const Var &handle,
const std::string &arg_name, bool is_used);
void
BindDLTensors(const std::vector<std::pair<Var, Buffer>> &buffer_def,
const PrimExpr &device_type, const PrimExpr &device_id,
const std::string &func_name,
const std::unordered_set<const VarNode *> &used_param_buffers);
/*! \return The defs generated in binding. */
const std::vector<Var> &defs() const { return defs_; }
@@ -178,4 +182,4 @@ private:
};
} // namespace tl
} // namespace tvm
#endif // TVM_TL_TRANSFORM_ARG_BINDER_H_
#endif // TVM_TL_TRANSFORM_ARG_BINDER_H_
\ No newline at end of file
@@ -393,10 +393,15 @@ PrimFunc MakePackedAPI(PrimFunc func) {
break;
}
}
if (!has_used_carrier && !carriers.empty()) {
// Choose the first carrier to anchor this symbol.
used_param_buffers.insert(carriers.front());
}
// NOTE: With the new nullable shape binding logic in
// ArgBinder::BindDLTensors, we no longer need to force one carrier to be
// non-NULL. The binder will:
// 1. Assert that at least one carrier is non-NULL at runtime
// 2. Use cascaded if_then_else to read from the first non-NULL carrier
// So we can allow all carriers to be nullable.
// if (!has_used_carrier && !carriers.empty()) {
// used_param_buffers.insert(carriers.front());
// }
}
for (int i = 0; i < static_cast<int>(func_ptr->params.size()); ++i) {
@@ -508,14 +513,14 @@ PrimFunc MakePackedAPI(PrimFunc func) {
binder.Bind(param, expr, name_hint + "." + param->name_hint, true);
}
binder.BindDLTensors(buffer_def, device_type, device_id, name_hint,
used_param_buffers);
for (const auto &[var, buffer] : buffer_def) {
// Prefer buffer data var name in diagnostics to avoid exposing low-level
// handle vars
std::string display = name_hint + "." + buffer->data->name_hint;
binder.BindDLTensor(buffer, device_type, device_id, var, display,
used_param_buffers.count(var.get()));
arg_buffer_declarations.push_back(DeclBuffer(buffer, nop));
}
// reset global symbol to attach prefix
func = WithAttrs(
std::move(func),
@@ -614,4 +619,4 @@ TVM_FFI_STATIC_INIT_BLOCK() {
}
} // namespace tl
} // namespace tvm
} // namespace tvm
\ No newline at end of file
import torch
import tilelang
import tilelang.testing
from tilelang import language as T
def test_nullable_shared_shape():
"""Test that buffers sharing a shape variable can be nullable."""
@tilelang.jit
def get_kernel():
m = T.dynamic("m")
@T.prim_func
def test_kernel(
a: T.Tensor[(m,), T.int32],
b: T.Tensor[(m,), T.int32],
c: T.Tensor[(m,), T.int32],
):
with T.Kernel(1, threads=64):
tx = T.get_thread_binding()
if tx == 0:
T.print(m)
return test_kernel
m = 200
kernel = get_kernel()
# Create test tensors
tensor_a = torch.randn((m,), device="cuda", dtype=torch.float32).to(torch.int32)
tensor_b = torch.randn((m,), device="cuda", dtype=torch.float32).to(torch.int32)
tensor_c = torch.randn((m,), device="cuda", dtype=torch.float32).to(torch.int32)
print("Test 1: All tensors provided")
kernel(tensor_a, tensor_b, tensor_c)
print("✓ PASS: All tensors provided")
print("\nTest 2: Only first tensor provided")
kernel(tensor_a, None, None)
print("✓ PASS: Only first tensor provided")
print("\nTest 3: Only middle tensor provided")
kernel(None, tensor_b, None)
print("✓ PASS: Only middle tensor provided")
print("\nTest 4: Only last tensor provided")
kernel(None, None, tensor_c)
print("✓ PASS: Only last tensor provided")
print("\nTest 5: First and last tensors provided")
kernel(tensor_a, None, tensor_c)
print("✓ PASS: First and last tensors provided")
print("\nTest 6: All tensors are None (should fail)")
try:
kernel(None, None, None)
print("✗ FAIL: Should have raised an error")
return False
except RuntimeError as e:
if "at least one non-null buffer" in str(e):
print(f"✓ PASS: Correctly rejected with error: {e}")
else:
print(f"✗ FAIL: Wrong error message: {e}")
return False
print("\n" + "=" * 60)
print("All tests passed!")
return True
if __name__ == "__main__":
tilelang.testing.main()