"git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "6dfdc546587ef7fa500fa98f82475dc6b45c5a94"
Unverified Commit f6db2014 authored by Lei Wang, committed by GitHub

[ArgBinder] Enhance shape variable handling and assertions (#1467)

* feat(arg_binder): enhance shape variable handling and assertions

- Implemented special handling for comparing if_then_else expressions to simplify conditions involving NULL checks.
- Added methods to set shared shape variables and finalize deferred bindings, generating cascading if_then_else expressions and runtime assertions for non-NULL buffers.
- Updated the binding logic to defer shape variable bindings for shared variables, ensuring proper handling across multiple nullable buffers.

* refactor(arg_binder): clean up shape variable handling and remove unused code

- Removed deprecated methods for setting shared shape variables and finalizing deferred bindings, streamlining the argument binding process.
- Simplified the logic for handling shape values in the `BindDLTensor` function, ensuring immediate binding for normal shape variables.
- Enhanced clarity by eliminating unnecessary comments and code related to cascading if_then_else expressions for shared variables.

* refactor(arg_binder): enhance DLTensor binding with improved shape handling

- Replaced the single `BindDLTensor` method with `BindDLTensors` to support multiple buffers, improving flexibility in handling DLTensor bindings.
- Introduced a two-pass approach for shape variable handling, allowing for better management of symbolic dimensions and null checks.
- Updated the logic to assert non-null conditions at runtime and use cascaded if_then_else expressions for shape retrieval (sketched below), enhancing robustness.
- Removed deprecated code and streamlined the binding process for clarity and maintainability.
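
The cascade described above can be pictured with TVM's Python `tir` API. This is a minimal sketch, assuming `tvm.tir.if_then_else`; the `*_is_null` and `*_dim0` variables are illustrative placeholders for the NULL checks and DLTensor shape reads that the C++ binder actually emits:

```python
import tvm
from tvm import tir

# Placeholders for the "handle is NULL" checks and per-buffer shape reads.
a_is_null = tir.Var("a_is_null", "bool")
b_is_null = tir.Var("b_is_null", "bool")
a_dim0 = tir.Var("a_dim0", "int64")
b_dim0 = tir.Var("b_dim0", "int64")
c_dim0 = tir.Var("c_dim0", "int64")

# Read m from the first non-NULL carrier. The final fallback to c_dim0 is
# safe because a separate runtime assert requires at least one non-NULL
# carrier, so the cascade never reads through a NULL handle.
m = tir.if_then_else(
    a_is_null,
    tir.if_then_else(b_is_null, c_dim0, b_dim0),
    a_dim0,
)
```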

* fix(test_nullable_buffer_params): improve formatting and consistency in test output

- Updated string formatting for better readability in the `test_nullable_shared_shape` function.
- Ensured consistent use of double quotes for string literals.
- Added a missing newline at the end of the file for proper formatting.

* refactor(arg_binder): simplify allocation size calculation in BindDLTensors

- Streamlined the allocation-size calculation by replacing a lambda function with a direct loop (see the sketch after this list), improving readability and maintainability.
- Improved clarity in the null check message for data pointers, ensuring better understanding of the binding process.
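
Illustratively, the direct loop amounts to a running product over the shape dims, as in this hedged Python rendering (the actual change lives in the C++ binder):

```python
def alloc_elems(shape):
    # Running product over the dims gives the element count; multiply by
    # the dtype size to get a byte count.
    size = 1
    for dim in shape:
        size *= dim
    return size

assert alloc_elems((4, 8, 16)) == 512
```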

* Remove debug prints from phase.py

Removed debug print statements after MakePackedAPI transformation.
parent f0672603
@@ -95,17 +95,21 @@ public:
    */
   void BindBuffer(const Buffer &arg, const Buffer &value,
                   const std::string &arg_name, bool fuzzy_match);
   /*!
    * \brief Bind symbolic buffer to a DLTensor handle.
    * \param buffer The argument buffer to be binded.
-   * \param device_type The device id to be binded.
+   * \param device_type The device type to be binded.
    * \param device_id The device id to be binded.
-   * \param handle The DLTensor handle.
-   * \param arg_name argument name.
+   * \param buffer_def The buffer definition.
+   * \param func_name The function name.
+   * \param used_param_buffers The used param buffers.
    */
-  void BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
-                    const PrimExpr &device_id, const Var &handle,
-                    const std::string &arg_name, bool is_used);
+  void
+  BindDLTensors(const std::vector<std::pair<Var, Buffer>> &buffer_def,
+                const PrimExpr &device_type, const PrimExpr &device_id,
+                const std::string &func_name,
+                const std::unordered_set<const VarNode *> &used_param_buffers);
   /*! \return The defs generated in binding. */
   const std::vector<Var> &defs() const { return defs_; }
...
@@ -393,10 +393,15 @@ PrimFunc MakePackedAPI(PrimFunc func) {
         break;
       }
     }
-    if (!has_used_carrier && !carriers.empty()) {
-      // Choose the first carrier to anchor this symbol.
-      used_param_buffers.insert(carriers.front());
-    }
+    // NOTE: With the new nullable shape binding logic in
+    // ArgBinder::BindDLTensors, we no longer need to force one carrier to be
+    // non-NULL. The binder will:
+    // 1. Assert that at least one carrier is non-NULL at runtime
+    // 2. Use cascaded if_then_else to read from the first non-NULL carrier
+    // So we can allow all carriers to be nullable.
+    // if (!has_used_carrier && !carriers.empty()) {
+    //   used_param_buffers.insert(carriers.front());
+    // }
   }
   for (int i = 0; i < static_cast<int>(func_ptr->params.size()); ++i) {
@@ -508,14 +513,14 @@ PrimFunc MakePackedAPI(PrimFunc func) {
     binder.Bind(param, expr, name_hint + "." + param->name_hint, true);
   }
+  binder.BindDLTensors(buffer_def, device_type, device_id, name_hint,
+                       used_param_buffers);
   for (const auto &[var, buffer] : buffer_def) {
     // Prefer buffer data var name in diagnostics to avoid exposing low-level
     // handle vars
-    std::string display = name_hint + "." + buffer->data->name_hint;
-    binder.BindDLTensor(buffer, device_type, device_id, var, display,
-                        used_param_buffers.count(var.get()));
     arg_buffer_declarations.push_back(DeclBuffer(buffer, nop));
   }
   // reset global symbol to attach prefix
   func = WithAttrs(
       std::move(func),
...
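
For intuition, here is a conceptual Python model (not the generated TIR) of the contract that `BindDLTensors` enforces for a shape variable shared by several nullable buffers; `bind_shared_dim` is a hypothetical helper, not part of the codebase:

```python
import torch

def bind_shared_dim(buffers, dim=0):
    carriers = [b for b in buffers if b is not None]
    # 1. Runtime assertion: at least one carrier must be non-NULL.
    if not carriers:
        raise RuntimeError("expected at least one non-null buffer")
    # 2. Cascade: take the dim from the first non-NULL carrier.
    return carriers[0].shape[dim]

a = torch.zeros(200, dtype=torch.int32)
assert bind_shared_dim([None, a, None]) == 200
```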
import torch
import tilelang
import tilelang.testing
from tilelang import language as T


def test_nullable_shared_shape():
    """Test that buffers sharing a shape variable can be nullable."""

    @tilelang.jit
    def get_kernel():
        m = T.dynamic("m")

        @T.prim_func
        def test_kernel(
                a: T.Tensor[(m,), T.int32],
                b: T.Tensor[(m,), T.int32],
                c: T.Tensor[(m,), T.int32],
        ):
            with T.Kernel(1, threads=64):
                tx = T.get_thread_binding()
                if tx == 0:
                    T.print(m)

        return test_kernel

    m = 200
    kernel = get_kernel()

    # Create test tensors
    tensor_a = torch.randn((m,), device="cuda", dtype=torch.float32).to(torch.int32)
    tensor_b = torch.randn((m,), device="cuda", dtype=torch.float32).to(torch.int32)
    tensor_c = torch.randn((m,), device="cuda", dtype=torch.float32).to(torch.int32)

    print("Test 1: All tensors provided")
    kernel(tensor_a, tensor_b, tensor_c)
    print("✓ PASS: All tensors provided")

    print("\nTest 2: Only first tensor provided")
    kernel(tensor_a, None, None)
    print("✓ PASS: Only first tensor provided")

    print("\nTest 3: Only middle tensor provided")
    kernel(None, tensor_b, None)
    print("✓ PASS: Only middle tensor provided")

    print("\nTest 4: Only last tensor provided")
    kernel(None, None, tensor_c)
    print("✓ PASS: Only last tensor provided")

    print("\nTest 5: First and last tensors provided")
    kernel(tensor_a, None, tensor_c)
    print("✓ PASS: First and last tensors provided")

    print("\nTest 6: All tensors are None (should fail)")
    try:
        kernel(None, None, None)
        print("✗ FAIL: Should have raised an error")
        return False
    except RuntimeError as e:
        if "at least one non-null buffer" in str(e):
            print(f"✓ PASS: Correctly rejected with error: {e}")
        else:
            print(f"✗ FAIL: Wrong error message: {e}")
            return False

    print("\n" + "=" * 60)
    print("All tests passed!")
    return True


if __name__ == "__main__":
    tilelang.testing.main()