Unverified Commit f6db2014 authored by Lei Wang, committed by GitHub

[ArgBinder] Enhance shape variable handling and assertions (#1467)

* feat(arg_binder): enhance shape variable handling and assertions

- Implemented special handling for comparing if_then_else expressions to simplify conditions involving NULL checks.
- Added methods to set shared shape variables and finalize deferred bindings, generating cascading if_then_else expressions and runtime assertions for non-NULL buffers.
- Updated the binding logic to defer shape variable bindings for shared variables, ensuring proper handling across multiple nullable buffers.
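
A minimal sketch of the cascade described above, using the TVM Python API (illustrative only; the names `a_is_null`, `b_is_null`, and the stand-in shape variables are assumptions, and the real binder emits this logic from C++):

```python
from tvm import tir

# Null flags and shape reads for two nullable buffers a and b that share
# the symbolic dimension m (all names here are illustrative stand-ins).
a_is_null = tir.Var("a_is_null", "bool")
b_is_null = tir.Var("b_is_null", "bool")
a_shape0 = tir.Var("a_shape_0", "int64")  # stands in for a->shape[0]
b_shape0 = tir.Var("b_shape_0", "int64")  # stands in for b->shape[0]

# Cascading if_then_else: read m from a when a is non-NULL, else from b.
m_value = tir.if_then_else(tir.Not(a_is_null), a_shape0, b_shape0)

# Runtime assertion condition: at least one carrier of m must be non-NULL.
any_nonnull = tir.Or(tir.Not(a_is_null), tir.Not(b_is_null))
```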

* refactor(arg_binder): clean up shape variable handling and remove unused code

- Removed deprecated methods for setting shared shape variables and finalizing deferred bindings, streamlining the argument binding process.
- Simplified the logic for handling shape values in the `BindDLTensor` function, ensuring immediate binding for normal shape variables.
- Enhanced clarity by eliminating unnecessary comments and code related to cascading if_then_else expressions for shared variables.

* refactor(arg_binder): enhance DLTensor binding with improved shape handling

- Replaced the single `BindDLTensor` method with `BindDLTensors` to support multiple buffers, improving flexibility in handling DLTensor bindings.
- Introduced a two-pass approach for shape variable handling, allowing for better management of symbolic dimensions and null checks.
- Updated the logic to assert non-null conditions at runtime and utilize cascaded if_then_else expressions for shape retrieval, enhancing robustness.
- Removed deprecated code and streamlined the binding process for clarity and maintainability.
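
The two-pass structure can be sketched in plain Python as follows (a simplified model, not the C++ implementation; the toy `buffer_shapes` input and printed actions are assumptions for illustration):

```python
# Pass 1: map every symbolic shape variable to the (buffer, dim) pairs
# that carry it, scanning all buffers before binding anything.
buffer_shapes = {"a": ["m"], "b": ["m"], "c": [4, "m"]}  # toy example
shape_var_sources = {}
for buf_name, shape in buffer_shapes.items():
    for dim_idx, dim in enumerate(shape):
        if isinstance(dim, str):  # a symbolic dimension such as "m"
            shape_var_sources.setdefault(dim, []).append((buf_name, dim_idx))

# Pass 2: a variable carried by several nullable buffers gets one runtime
# assert ("at least one source is non-NULL") plus a cascaded lookup over
# its sources; a variable with a single source is bound immediately.
for var, sources in shape_var_sources.items():
    if len(sources) > 1:
        print(var, "-> assert any non-null, cascade over", sources)
    else:
        print(var, "-> direct nullable binding from", sources[0])
```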

* fix(test_nullable_buffer_params): improve formatting and consistency in test output

- Updated string formatting for better readability in the `test_nullable_shared_shape` function.
- Ensured consistent use of double quotes for string literals.
- Added a missing newline at the end of the file for proper formatting.

* refactor(arg_binder): simplify allocation size calculation in BindDLTensors

- Streamlined the calculation of allocation size by replacing a lambda function with a direct loop, enhancing readability and maintainability.
- Improved clarity in the null check message for data pointers, ensuring better understanding of the binding process.
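
A hedged sketch of the same product-of-extents computation with the TVM Python API (the shape values are illustrative, not the binder's code):

```python
from tvm import tir

# Allocation size as the running product of the shape extents, built with
# a plain loop instead of a lambda.
shape = [tir.IntImm("int64", 4), tir.Var("m", "int64")]
alloc_size = tir.IntImm("int64", 1)
for dim in shape:
    alloc_size = alloc_size * dim  # stays a symbolic PrimExpr
```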

* Remove debug prints from phase.py

Removed debug print statements after MakePackedAPI transformation.
parent f0672603
@@ -311,29 +311,108 @@ inline PrimExpr TVMArrayGet(DataType t, Var arr,
return TVMStructGet(t, arr, 0, kind);
}
-void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
-const PrimExpr &device_id, const Var &handle,
-const std::string &arg_name, bool is_used) {
+void ArgBinder::BindDLTensors(
+const std::vector<std::pair<Var, Buffer>> &buffer_def,
+const PrimExpr &device_type, const PrimExpr &device_id,
+const std::string &func_name,
+const std::unordered_set<const VarNode *> &used_param_buffers) {
ffi::Array<Buffer> buffers;
ffi::Array<Var> handles;
// First pass: collect shape var -> list of (buffer_name, dim_idx, handle_ptr)
struct ShapeVarSource {
std::string buf_name;
size_t dim_idx;
const VarNode *handle_ptr; // Raw pointer to check used_param_buffers
};
std::unordered_map<const VarNode *, std::vector<ShapeVarSource>>
shape_var_sources;
for (const auto &[handle, buffer] : buffer_def) {
std::string arg_name = func_name + "." + buffer->data->name_hint;
// Scan buffer shape for symbolic variables
for (size_t k = 0; k < buffer->shape.size(); ++k) {
if (buffer->dtype == DataType::Int(4) ||
buffer->dtype == DataType::UInt(4) ||
buffer->dtype == DataType::Int(1)) {
break;
}
if (const VarNode *v = buffer->shape[k].as<VarNode>()) {
// This dimension is a symbolic variable
shape_var_sources[v].push_back({arg_name, k, handle.get()});
}
}
}
// Second pass: Create is_null vars and shape buffers for all buffers first
std::unordered_map<std::string, Var> is_null_map;
std::unordered_map<std::string, Buffer> shape_buffer_map;
std::unordered_map<std::string, PrimExpr>
is_null_expr_map; // arg_name -> is_null expression (const_false for used
// buffers)
const DataType tvm_shape_type = DataType::ShapeIndex();
const DataType tvm_ndim_type = DataType::Int(32);
const Stmt nop = Evaluate(0);
// Allow NULL DLTensor* for optional inputs. When the handle is NULL,
// avoid dereferencing it by using expression-level conditionals and
// short-circuiting guards in asserts. Cache the null check in a Let-bound
// boolean so codegen does not repeat `(handle == NULL)` everywhere.
// Create all is_null vars and shape buffers first
for (const auto &[handle, buffer] : buffer_def) {
bool is_used = used_param_buffers.count(handle.get());
std::string arg_name = func_name + "." + buffer->data->name_hint;
Var is_null_var(arg_name + "_is_null", DataType::Bool());
init_nest_.emplace_back(
LetStmt(is_null_var,
Call(DataType::Bool(), builtin::isnullptr(), {handle}), nop));
const PrimExpr &is_null = is_used ? const_false() : is_null_var;
is_null_map[arg_name] = is_null_var;
is_null_expr_map[arg_name] = is_null;
if (is_used) {
init_nest_.emplace_back(AssertStmt(
!is_null_var,
tvm::tir::StringImm(arg_name + " is expected to have non-NULL pointer"),
nop));
}
}
// Create all shape buffers before binding any shapes
for (const auto &[handle, buffer] : buffer_def) {
std::string arg_name = func_name + "." + buffer->data->name_hint;
const PrimExpr &is_null = is_null_expr_map[arg_name];
// Helper functions for shape/stride name formatting
auto shape_handle_name = [&]() { return arg_name + ".shape"; };
// shape field
Buffer buf_shape =
decl_buffer({IntImm(DataType::Int(32), buffer->shape.size())},
tvm_shape_type, shape_handle_name());
def_handle_dtype_.Set(buf_shape->data, make_const(tvm_shape_type, 0));
// Use if_then_else for NULL guard on the shape pointer itself, avoiding
// dereferencing TVMStructGet(handle, kArrShape) when handle is NULL.
init_nest_.emplace_back(
LetStmt(buf_shape->data,
tvm::if_then_else(
Not(is_null),
TVMArrayGet(DataType::Handle(), handle, builtin::kArrShape),
make_zero(DataType::Handle())),
nop));
init_nest_.emplace_back(DeclBuffer(buf_shape, nop));
// Save for later use in shape binding
shape_buffer_map[arg_name] = buf_shape;
}
// Now process each buffer fully
for (const auto &[handle, buffer] : buffer_def) {
bool is_used = used_param_buffers.count(handle.get());
std::string arg_name = func_name + "." + buffer->data->name_hint;
const PrimExpr &is_null = is_null_expr_map[arg_name];
// dimension checks
PrimExpr v_ndim = TVMArrayGet(tvm_ndim_type, handle, builtin::kArrNDim);
@@ -371,8 +450,8 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
ndim_args.push_back(StringImm(buf_nm));
ndim_args.push_back(cast(DataType::Int(64), a_ndim));
ndim_args.push_back(cast(DataType::Int(64), v_ndim));
Stmt ndim_call = Evaluate(
Call(DataType::Int(32), builtin::tvm_call_packed(), ndim_args));
init_nest_.emplace_back(
SeqStmt({IfThenElse(Not(is_null), IfThenElse(Not(ndim_ok), ndim_call),
Evaluate(0)),
@@ -401,7 +480,8 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
// Allow float8_e4m3 to match float8_e4m3fn/float8_e4m3fnuz at runtime.
if (buffer->dtype.is_float8_e4m3()) {
PrimExpr code_e4m3 = IntImm(DataType::UInt(8), DataType::kFloat8_e4m3);
PrimExpr code_e4m3fn = IntImm(DataType::UInt(8), DataType::kFloat8_e4m3fn);
PrimExpr code_e4m3fnuz =
IntImm(DataType::UInt(8), DataType::kFloat8_e4m3fnuz);
PrimExpr code_match =
@@ -441,7 +521,8 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
// Also accept any dtype whose bitwidth=1, regardless of code, to be
// defensive.
PrimExpr bit1_ok = (v_type_bits == bits1 && lanes_ok);
cond = cond || int8_ok || uint8_ok || kdlbool8_ok || kdlbool1_ok || bit1_ok;
}
// Allow float4 to match int8 at runtime (PyTorch uses int8 as storage for
// FP4).
@@ -458,8 +539,8 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
if (!(buffer->dtype == DataType::Int(1) ||
buffer->dtype == DataType::Int(4) ||
buffer->dtype == DataType::UInt(4) || buffer->dtype.is_float4())) {
// Build FFI packed call to __tvm_error_dtype_mismatch when mismatch
// occurs. Only issue the call when handle is non-NULL and cond is false.
ffi::Array<PrimExpr> packed_args;
packed_args.push_back(StringImm(tvm_error_dtype_mismatch));
// Split arg_name of the form "<kernel>.<buffer>" into parts for clearer
@@ -490,23 +571,10 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
asserts_.emplace_back(SeqStmt({guarded, nop}));
}
-// shape field
-Buffer buf_shape =
-decl_buffer({IntImm(DataType::Int(32), buffer->shape.size())},
-tvm_shape_type, shape_handle_name());
-Var v_shape(shape_handle_name(), DataType::Handle());
-def_handle_dtype_.Set(v_shape, make_const(tvm_shape_type, 0));
-// Use if_then_else for NULL guard on the shape pointer itself, avoiding
-// dereferencing TVMStructGet(handle, kArrShape) when handle is NULL.
-init_nest_.emplace_back(
-LetStmt(buf_shape->data,
-tvm::if_then_else(
-Not(is_null),
-TVMArrayGet(DataType::Handle(), handle, builtin::kArrShape),
-make_zero(DataType::Handle())),
-nop));
-init_nest_.emplace_back(DeclBuffer(buf_shape, nop));
+// Get the pre-created shape buffer
+Buffer buf_shape = shape_buffer_map[arg_name];
// Bind symbolic variables from buffer shape
for (size_t k = 0; k < buffer->shape.size(); ++k) {
// These packed-bit dtype shapes were not bound in the original
// implementation, so we just use them as is.
@@ -522,23 +590,124 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
BufferLoad(buf_shape,
{IntImm(DataType::Int(32), static_cast<int>(k))}));
// When first encountering a Var (e.g., m), this will generate:
// Let(m, bound_shape_val, ...)
// Constant dimensions will only generate consistency assertions.
// Check if this dimension is a symbolic variable
if (const VarNode *v = buffer->shape[k].as<VarNode>()) {
auto it = def_map_->find(v);
if (it == def_map_->end()) {
// First time binding this symbolic variable
auto sources_it = shape_var_sources.find(v);
if (sources_it != shape_var_sources.end() &&
sources_it->second.size() > 1) {
// This variable appears in multiple buffers
// Assert that at least one buffer is non-null
PrimExpr any_nonnull = const_false();
for (const auto &src : sources_it->second) {
bool buf_is_used = used_param_buffers.count(src.handle_ptr);
if (buf_is_used) {
any_nonnull = const_true();
break;
}
Var src_is_null = is_null_map[src.buf_name];
any_nonnull = Or(any_nonnull, Not(src_is_null));
}
std::ostringstream err_msg;
err_msg << "Symbolic shape variable "
<< ffi::GetRef<Var>(v)->name_hint
<< " requires at least one non-null buffer among: ";
bool first = true;
for (const auto &src : sources_it->second) {
if (!first)
err_msg << ", ";
err_msg << src.buf_name;
first = false;
}
init_nest_.emplace_back(AssertStmt(
any_nonnull, tvm::tir::StringImm(err_msg.str()), nop));
// Build cascaded if_then_else: if !is_null_a then a.shape[k] else
// if !is_null_b then b.shape[k] ... We need to construct this in
// reverse order
PrimExpr cascaded_value;
bool is_first_source = true;
for (auto rit = sources_it->second.rbegin();
rit != sources_it->second.rend(); ++rit) {
const auto &src = *rit;
// Get the shape buffer for this source
auto it_buf = shape_buffer_map.find(src.buf_name);
if (it_buf == shape_buffer_map.end()) {
LOG(FATAL) << "Shape buffer not found for " << src.buf_name;
}
Buffer src_shape_buf = it_buf->second;
// Construct the shape load
PrimExpr src_shape_val =
cast(buffer->shape[k].dtype(),
BufferLoad(src_shape_buf,
{IntImm(DataType::Int(32),
static_cast<int>(src.dim_idx))}));
// Check if this buffer is used (non-nullable)
bool src_is_used = used_param_buffers.count(src.handle_ptr);
if (is_first_source) {
// Base case: use this shape value directly (we know at least
// one is non-null from assert)
cascaded_value = src_shape_val;
is_first_source = false;
} else {
// If !is_null then use this shape, else use the previous cascaded
// value. But if the buffer is used (non-nullable), always use its
// shape.
if (src_is_used) {
cascaded_value = src_shape_val;
} else {
Var src_is_null = is_null_map[src.buf_name];
cascaded_value = tvm::if_then_else(
Not(src_is_null), src_shape_val, cascaded_value);
}
}
}
// Bind the variable to the cascaded expression
Var v_arg = ffi::GetRef<Var>(v);
defs_.emplace_back(v_arg);
(*def_map_)[v] = cascaded_value;
init_nest_.emplace_back(
LetStmt(v_arg, cascaded_value, Evaluate(0)));
} else {
// Single source or no special handling needed, use the original
// nullable binding
BindNullable(buffer->shape[k], shape_val, shape_element_name(k),
true, is_null);
}
} else {
// Variable already bound, add assertion with nullable guard
PrimExpr cond = (it->second == shape_val);
BinderAddAssert(&analyzer_, cond, shape_element_name(k), &asserts_,
is_null);
}
} else {
// Constant dimension, just add assertion
BindNullable(buffer->shape[k], shape_val, shape_element_name(k), true,
is_null);
}
}
// strides field
Buffer buf_strides =
decl_buffer({IntImm(DataType::Int(32), buffer->strides.size())},
tvm_shape_type, arg_name + ".strides");
def_handle_dtype_.Set(buf_strides->data, tir::TypeAnnotation(tvm_shape_type));
init_nest_.emplace_back(
LetStmt(buf_strides->data,
tvm::if_then_else(Not(is_null),
TVMArrayGet(DataType::Handle(), handle, builtin::kArrStrides),
make_zero(DataType::Handle())),
nop));
init_nest_.emplace_back(DeclBuffer(buf_strides, nop));
@@ -552,9 +721,9 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
ffi::Array<PrimExpr> conds;
for (size_t i = buffer->shape.size(); i != 0; --i) {
size_t k = i - 1;
PrimExpr svalue = cast(
stype, BufferLoad(buf_strides, {IntImm(DataType::Int(32), static_cast<int>(k))}));
conds.push_back(buffer->shape[k] == 1 || expect_stride == svalue);
expect_stride = expect_stride * buffer->shape[k];
}
@@ -563,7 +732,8 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
<< stride_handle_name()
<< ": expected to be compact array, but got non-compact strides";
if (!conds.empty()) {
PrimExpr all_ok =
foldl([](PrimExpr a, PrimExpr b, Span span) { return logical_and(a, b, span); },
const_true(1), conds);
// Packed generic violation for non-compact strides
@@ -579,10 +749,11 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
pargs4.push_back(StringImm(kernel_nm3));
pargs4.push_back(StringImm(buf_nm3));
pargs4.push_back(StringImm("strides"));
Stmt call_err4 = Evaluate(
Call(DataType::Int(32), builtin::tvm_call_packed(), pargs4));
// Only check when strides array is present and condition fails
Stmt check = IfThenElse(Not(v_strides_is_null),
IfThenElse(Not(all_ok), call_err4), Evaluate(0));
asserts_.emplace_back(SeqStmt({check, Evaluate(0)}));
}
@@ -599,8 +770,8 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
PrimExpr stride_val = tvm::if_then_else(
v_strides_is_null, stride_from_shape, explicit_stride);
BindNullable(buffer->strides[k], stride_val, stride_element_name(k), true,
is_null);
}
} else {
PrimExpr stride_from_shape = 1;
@@ -610,14 +781,15 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
PrimExpr explicit_stride =
cast(stride_dtype,
BufferLoad(buf_strides, {IntImm(DataType::Int(32), k)}));
PrimExpr shape_stride = cast(
stride_dtype, BufferLoad(buf_shape, {IntImm(DataType::Int(32), k)}));
PrimExpr stride_val = tvm::if_then_else(
v_strides_is_null, stride_from_shape, explicit_stride);
BindNullable(buffer->strides[k], stride_val, stride_element_name(k), true,
is_null);
}
}
@@ -625,8 +797,8 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
int data_bytes = GetVectorBytes(buffer->dtype);
if (const auto *const_offset = buffer->elem_offset.as<IntImmNode>()) {
// Constant elem_offset: only need consistency check, no need for
// additional Var binding.
PrimExpr actual_byte_offset = tvm::if_then_else(
Not(is_null),
TVMArrayGet(DataType::UInt(64), handle, builtin::kArrByteOffset),
@@ -650,8 +822,8 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
Not(is_null),
TVMArrayGet(DataType::UInt(64), handle, builtin::kArrByteOffset),
make_const(DataType::UInt(64), 0));
PrimExpr expect_elem_off = cast(
buffer->elem_offset.dtype(),
(actual_byte_offset / make_const(DataType::UInt(64), data_bytes)));
BindNullable(buffer->elem_offset, expect_elem_off,
@@ -661,8 +833,8 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
PrimExpr offset = buffer->elem_offset;
PrimExpr factor = make_const(offset.dtype(), buffer->offset_factor);
PrimExpr zero = make_zero(offset.dtype());
BindNullable(offset, truncmod(offset, factor), arg_name + ".elem_offset",
true, is_null);
}
}
@@ -680,8 +852,8 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
// Bind device_id to a safe expression (0 when NULL handle)
BindNullable(device_id, actual_dev_id, arg_name + ".device_id", true,
is_null);
// Check device_type consistency (device_id equality is implicitly ensured
// by binding above)
{
PrimExpr ok = (device_type == actual_dev_type);
ffi::Array<PrimExpr> pargs2;
@@ -692,8 +864,9 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
pargs2.push_back(cast(DataType::Int(64), actual_dev_type));
Stmt call_err2 =
Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), pargs2));
asserts_.emplace_back(
SeqStmt({IfThenElse(Not(is_null), IfThenElse(Not(ok), call_err2),
Evaluate(0)),
Evaluate(0)}));
}
@@ -711,12 +884,10 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
// Check if the data pointer is NULL. This check is skipped for
// size-0 arrays and also skipped when handle itself is NULL.
-auto alloc_size = [&]() -> PrimExpr {
-PrimExpr product = IntImm(buffer->DefaultIndexType(), 1);
-for (const auto &dim : buffer->shape)
-product *= dim;
-return product;
-}();
+PrimExpr alloc_size = IntImm(buffer->DefaultIndexType(), 1);
+for (const auto &dim : buffer->shape) {
+alloc_size = alloc_size * dim;
+}
// Improve message: kernel/buffer naming for data pointer null check
std::string kernel_nm2 = arg_name;
std::string buf_nm2 = arg_name;
@@ -750,6 +921,7 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
def_handle_dtype_.Set(vptr, tir::TypeAnnotation(buffer->dtype));
}
}
}
} // namespace tl
@@ -95,17 +95,21 @@ public:
*/
void BindBuffer(const Buffer &arg, const Buffer &value,
const std::string &arg_name, bool fuzzy_match);
/*!
* \brief Bind symbolic buffer to a DLTensor handle.
-* \param buffer The argument buffer to be binded.
-* \param device_type The device id to be binded.
+* \param buffer_def The buffer definition.
+* \param device_type The device type to be binded.
* \param device_id The device id to be binded.
-* \param handle The DLTensor handle.
-* \param arg_name argument name.
+* \param func_name The function name.
+* \param used_param_buffers The used param buffers.
*/
-void BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
-const PrimExpr &device_id, const Var &handle,
-const std::string &arg_name, bool is_used);
+void
+BindDLTensors(const std::vector<std::pair<Var, Buffer>> &buffer_def,
+const PrimExpr &device_type, const PrimExpr &device_id,
+const std::string &func_name,
+const std::unordered_set<const VarNode *> &used_param_buffers);
/*! \return The defs generated in binding. */
const std::vector<Var> &defs() const { return defs_; }
@@ -393,10 +393,15 @@ PrimFunc MakePackedAPI(PrimFunc func) {
break;
}
}
-if (!has_used_carrier && !carriers.empty()) {
-// Choose the first carrier to anchor this symbol.
-used_param_buffers.insert(carriers.front());
-}
+// NOTE: With the new nullable shape binding logic in
+// ArgBinder::BindDLTensors, we no longer need to force one carrier to be
+// non-NULL. The binder will:
+// 1. Assert that at least one carrier is non-NULL at runtime
+// 2. Use cascaded if_then_else to read from the first non-NULL carrier
+// So we can allow all carriers to be nullable.
+// if (!has_used_carrier && !carriers.empty()) {
+//   used_param_buffers.insert(carriers.front());
+// }
}
for (int i = 0; i < static_cast<int>(func_ptr->params.size()); ++i) {
@@ -508,14 +513,14 @@ PrimFunc MakePackedAPI(PrimFunc func) {
binder.Bind(param, expr, name_hint + "." + param->name_hint, true);
}
+binder.BindDLTensors(buffer_def, device_type, device_id, name_hint,
+used_param_buffers);
for (const auto &[var, buffer] : buffer_def) {
-// Prefer buffer data var name in diagnostics to avoid exposing low-level
-// handle vars
-std::string display = name_hint + "." + buffer->data->name_hint;
-binder.BindDLTensor(buffer, device_type, device_id, var, display,
-used_param_buffers.count(var.get()));
arg_buffer_declarations.push_back(DeclBuffer(buffer, nop));
}
// reset global symbol to attach prefix
func = WithAttrs(
std::move(func),
import torch
import tilelang
import tilelang.testing
from tilelang import language as T
def test_nullable_shared_shape():
    """Test that buffers sharing a shape variable can be nullable."""

    @tilelang.jit
    def get_kernel():
        m = T.dynamic("m")

        @T.prim_func
        def test_kernel(
                a: T.Tensor[(m,), T.int32],
                b: T.Tensor[(m,), T.int32],
                c: T.Tensor[(m,), T.int32],
        ):
            with T.Kernel(1, threads=64):
                tx = T.get_thread_binding()
                if tx == 0:
                    T.print(m)

        return test_kernel

    m = 200
    kernel = get_kernel()

    # Create test tensors
    tensor_a = torch.randn((m,), device="cuda", dtype=torch.float32).to(torch.int32)
    tensor_b = torch.randn((m,), device="cuda", dtype=torch.float32).to(torch.int32)
    tensor_c = torch.randn((m,), device="cuda", dtype=torch.float32).to(torch.int32)

    print("Test 1: All tensors provided")
    kernel(tensor_a, tensor_b, tensor_c)
    print("✓ PASS: All tensors provided")

    print("\nTest 2: Only first tensor provided")
    kernel(tensor_a, None, None)
    print("✓ PASS: Only first tensor provided")

    print("\nTest 3: Only middle tensor provided")
    kernel(None, tensor_b, None)
    print("✓ PASS: Only middle tensor provided")

    print("\nTest 4: Only last tensor provided")
    kernel(None, None, tensor_c)
    print("✓ PASS: Only last tensor provided")

    print("\nTest 5: First and last tensors provided")
    kernel(tensor_a, None, tensor_c)
    print("✓ PASS: First and last tensors provided")

    print("\nTest 6: All tensors are None (should fail)")
    try:
        kernel(None, None, None)
        print("✗ FAIL: Should have raised an error")
        return False
    except RuntimeError as e:
        if "at least one non-null buffer" in str(e):
            print(f"✓ PASS: Correctly rejected with error: {e}")
        else:
            print(f"✗ FAIL: Wrong error message: {e}")
            return False

    print("\n" + "=" * 60)
    print("All tests passed!")
    return True


if __name__ == "__main__":
    tilelang.testing.main()