"...composable_kernel.git" did not exist on "d807d05e3a96acbc3dd7134cdab213e9f8168338"
Unverified commit 3546e2ee, authored by Lei Wang and committed by GitHub

[Atomic] Use ptr for atomicAdd dst instead of reference (#1425)

* [Enhancement] Update AtomicAdd function signature to accept pointer to destination

* Modified AtomicAdd in CUDA to take a pointer instead of a reference for the destination argument.
* Updated related code in atomicadd_vectorize.cc to ensure compatibility with the new signature.
* Adjusted Python interface in atomic.py to pass the destination by pointer, aligning with device function requirements.

* [Enhancement] Refactor AtomicAddRet function signature to accept pointer

* Updated AtomicAddRet in both CUDA and HIP to take a pointer instead of a reference for the address argument, improving consistency with the AtomicAdd function.
* Adjusted the implementation to ensure proper reinterpretation of the address type for atomic operations.

* lint fix

* [Enhancement] Refactor AtomicAddNode::MakeSIMTLoop to use destination pointer

* Updated the MakeSIMTLoop function to build a pointer to the destination element using builtin::address_of instead of loading the destination value directly.
* Simplified the handling of source and destination predicates, improving clarity and maintainability of the code.
* Ensured compatibility with the new pointer-based approach for atomic operations.

* lint fix

* test fix

* lint fix
parent 29051439
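Before the per-file hunks, a runnable sketch of what the change means at a call site. This is a hypothetical reduced form, not the project's header: AtomicAdd's body is collapsed to the builtin atomicAdd, and the dtype normalization and memory_order handling shown in the diffs below are omitted; accumulate and the launch shape are illustrative.

#include <cstdio>
#include <cuda_runtime.h>

// Reduced form of the new pointer-based signature: the destination
// arrives as T1* instead of T1&, so call sites read AtomicAdd(&dst[i], v),
// mirroring CUDA's builtin atomicAdd.
template <typename T1, typename T2>
__device__ void AtomicAdd(T1 *address, T2 val) {
  atomicAdd(address, static_cast<T1>(val));
}

__global__ void accumulate(float *out) {
  AtomicAdd(&out[0], 1.0f);  // destination passed by pointer
}

int main() {
  float h_out = 0.0f, *d_out;
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_out, &h_out, sizeof(float), cudaMemcpyHostToDevice);
  accumulate<<<4, 64>>>(d_out);
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("sum = %f\n", h_out);  // expect 256
  cudaFree(d_out);
  return 0;
}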
@@ -267,22 +267,22 @@ For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
   Array<PrimExpr> src_indices = MakeIndices(loop_vars, 0);
   Array<PrimExpr> dst_indices = MakeIndices(loop_vars, 1);
-  Array<PrimExpr> new_args;
+  // Optional bounds predicates for src and dst
   PrimExpr src_predicate = MakePredicate(analyzer, loop_vars, src->shape, 0);
   PrimExpr dst_predicate = MakePredicate(analyzer, loop_vars, dst->shape, 1);
+  Array<PrimExpr> new_args;
+  // Load source value and cast to dst dtype if needed
   PrimExpr src_value = BufferLoad(src, src_indices);
   if (src->dtype != dst->dtype)
     src_value = Cast(dst->dtype, src_value);
   if (src_predicate.defined())
     src_value = if_then_else(src_predicate, src_value, make_zero(dst->dtype));
-  PrimExpr dst_value = BufferLoad(dst, dst_indices);
-  if (dst_predicate.defined())
-    dst_value = if_then_else(dst_predicate, dst_value, make_zero(dst->dtype));
+  // Build a pointer to the destination element using builtin::address_of
+  PrimExpr dst_ptr = Call(DataType::Handle(), builtin::address_of(),
+                          {BufferLoad(dst, dst_indices)});
-  new_args.push_back(dst_value);
+  new_args.push_back(dst_ptr);
   new_args.push_back(src_value);
   new_args.push_back(memory_order);
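For orientation, a sketch of the loop shape this produces (not the actual codegen output): the destination reaches the device helper as &dst[i], via builtin::address_of, instead of as a loaded value. AtomicAdd here is the reduced pointer-based helper from the sketch above; simt_atomic_add, dst, src, and n are illustrative names.

__global__ void simt_atomic_add(float *dst, const float *src, int n) {
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    // before this change the first argument was the loaded value dst[i];
    // now it is the element's address
    AtomicAdd(&dst[i], src[i]);
  }
}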
@@ -171,10 +171,9 @@ TL_DEVICE T1 AtomicMinRet(T1 &ref, T2 val,
 #if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ > 890))
 template <typename T1, typename T2>
-TL_DEVICE void AtomicAdd(T1 &ref, T2 val,
+TL_DEVICE void AtomicAdd(T1 *address, T2 val,
                          int memory_order = int(cuda::memory_order_relaxed)) {
   using NT1 = typename normalize_atomic_type<T1>::type;
-  T1 *address = &ref;
   if constexpr (std::is_same_v<NT1, half> ||
                 std::is_same_v<NT1, __nv_bfloat16>) {
     if (memory_order == int(cuda::memory_order_relaxed)) {
@@ -242,19 +241,18 @@ TL_DEVICE void AtomicAdd(T1 &ref, T2 val,
   }
 }
 #else
 template <typename T1, typename T2>
-TL_DEVICE void AtomicAdd(T1 &ref, T2 val,
+TL_DEVICE void AtomicAdd(T1 *address, T2 val,
                          int memory_order = int(cuda::memory_order_relaxed)) {
   using NT1 = typename normalize_atomic_type<T1>::type;
   (void)memory_order;
-  atomicAdd(reinterpret_cast<NT1 *>(&ref), cuda_cast<NT1>(val));
+  atomicAdd(reinterpret_cast<NT1 *>(address), cuda_cast<NT1>(val));
 }
 #endif
 template <typename T1, typename T2>
-TL_DEVICE T1 AtomicAddRet(T1 &ref, T2 val,
+TL_DEVICE T1 AtomicAddRet(T1 *address, T2 val,
                           int memory_order = int(cuda::memory_order_relaxed)) {
   using NT1 = typename normalize_atomic_type<T1>::type;
-  T1 *address = &ref;
   if constexpr (std::is_same_v<NT1, half> ||
                 std::is_same_v<NT1, __nv_bfloat16>) {
     if (memory_order == int(cuda::memory_order_relaxed)) {
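A call-site view of the two updated CUDA signatures. This is a reduced sketch: TL_DEVICE and the memory_order plumbing are dropped, normalize_atomic_type is bypassed, and both bodies fall through to the builtin atomicAdd; demo, acc, ticket, and slots are illustrative names.

template <typename T1, typename T2>
__device__ void AtomicAdd(T1 *address, T2 val) {
  atomicAdd(address, static_cast<T1>(val));
}

template <typename T1, typename T2>
__device__ T1 AtomicAddRet(T1 *address, T2 val) {
  return atomicAdd(address, static_cast<T1>(val));  // returns the old value
}

__global__ void demo(float *acc, int *ticket, int *slots) {
  AtomicAdd(&acc[0], 1.0f);                           // fire-and-forget add
  slots[threadIdx.x] = AtomicAddRet(&ticket[0], 1);   // unique slot per thread
}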
@@ -116,6 +116,7 @@ TL_DEVICE void AtomicAdd(T1 address, T2 val) {
   atomicAdd(reinterpret_cast<T1 *>(&address), static_cast<T1>(val));
 }
-template <typename T1, typename T2> TL_DEVICE T1 AtomicAddRet(T1 &ref, T2 val) {
-  return atomicAdd(&ref, static_cast<T1>(val));
+template <typename T1, typename T2>
+TL_DEVICE T1 AtomicAddRet(T1 *address, T2 val) {
+  return atomicAdd(reinterpret_cast<T1 *>(address), static_cast<T1>(val));
 }
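HIP's builtin atomicAdd likewise returns the pre-add value, so the pointer-based AtomicAddRet can forward it directly. A usage sketch, reusing the reduced AtomicAddRet from the previous sketch; claim_slots, counter, and slots are illustrative names.

__global__ void claim_slots(int *counter, int *slots) {
  // each thread atomically claims a unique index from a shared counter
  slots[threadIdx.x] = AtomicAddRet(counter, 1);
}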
@@ -246,8 +246,9 @@ private:
       new_args.push_back(address_of_dst);
       new_args.push_back(address_of_value);
     } else {
+      // Scalar case: AtomicAdd now expects a pointer to destination.
       new_args.push_back(StringImm("AtomicAdd"));
-      new_args.push_back(dst_node);
+      new_args.push_back(address_of_dst);
       new_args.push_back(value_node);
     }
     new_args.push_back(memory_order);
@@ -259,8 +260,28 @@ private:
   } else {
     Array<PrimExpr> new_args;
     new_args.push_back(StringImm("AtomicAdd"));
-    for (auto x : node->args)
-      new_args.push_back(x);
+    // Ensure first argument is an address; keep value as-is.
+    if (!node->args.empty()) {
+      if (const auto *bl = node->args[0].as<BufferLoadNode>()) {
+        Call address_of_dst = Call(DataType::Handle(), builtin::address_of(),
+                                   {Downcast<BufferLoad>(node->args[0])});
+        new_args.push_back(address_of_dst);
+      } else if (const auto *call = node->args[0].as<CallNode>()) {
+        // If it's already an address_of, forward it; otherwise, keep
+        // original.
+        if (call->op.same_as(builtin::address_of())) {
+          new_args.push_back(node->args[0]);
+        } else {
+          new_args.push_back(node->args[0]);
+        }
+      } else {
+        new_args.push_back(node->args[0]);
+      }
+      // Push remaining args unchanged (value, optional memory_order, ...)
+      for (size_t i = 1; i < node->args.size(); ++i) {
+        new_args.push_back(node->args[i]);
+      }
+    }
     Call new_call =
         tvm::tir::Call(node->dtype, builtin::call_extern(), new_args);
@@ -101,7 +101,7 @@ def vectorize_access_with_atmoic_add_legalize(M: int = 64, N: int = 64, M_offset
         # Nested if-then-else is expected; do not flatten it, to pass the structural-equal check
         if j + N_offset < N:  # noqa: SIM102
             if tid + M_offset < M:
-                T.call_extern("handle", "AtomicAdd", A[tid + M_offset, j + N_offset], 1)
+                T.call_extern("handle", "AtomicAdd", T.address_of(A[tid + M_offset, j + N_offset]), 1)
     return main, expected
@@ -179,10 +179,17 @@ def atomic_add(dst: Buffer, value: PrimExpr, memory_order: str | None = None, re
         func_name = "AtomicAddRet" if return_prev else "AtomicAdd"
         return_type = dst.dtype if return_prev else "handle"
+        # Pass destination by pointer to match device signature
         if memory_order is None:
-            return T.call_extern(return_type, func_name, dst, value)
+            return T.call_extern(return_type, func_name, T.address_of(dst), value)
         else:
-            return T.call_extern(return_type, func_name, dst, value, _MEMORY_ORDER_ID_MAP[memory_order])
+            return T.call_extern(
+                return_type,
+                func_name,
+                T.address_of(dst),
+                value,
+                _MEMORY_ORDER_ID_MAP[memory_order],
+            )
     if isinstance(dst, Buffer) and isinstance(value, Buffer):
         ir.assert_structural_equal(dst.shape, value.shape)
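On the device side, the extern call this emits resolves against the pointer-based signature from the CUDA hunk above. A reduced sketch of the correspondence: lowered and A are illustrative names, the body ignores memory_order, and the order id passed through _MEMORY_ORDER_ID_MAP is shown here as the relaxed constant the header uses as its default.

#include <cuda/atomic>

template <typename T1, typename T2>
__device__ void AtomicAdd(T1 *address, T2 val,
                          int memory_order = int(cuda::memory_order_relaxed)) {
  (void)memory_order;  // reduced body: always the builtin
  atomicAdd(address, static_cast<T1>(val));
}

__global__ void lowered(float *A, float v) {
  // atomic_add(A[i], v, memory_order="relaxed") on the Python side emits
  // call_extern("AtomicAdd", address_of(A[i]), v, <order id>), i.e.:
  AtomicAdd(&A[threadIdx.x], v, int(cuda::memory_order_relaxed));
}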