Unverified commit 1d4b7180 authored by Zhengju Tang, committed by GitHub

[BugFix] Add memory order argument for non-vectorized atomic add (#1081)

* [BugFix] Add memory order argument for non-vectorized atomic add

* [Lint]

* [BugFix] Memory order

* [Lint]

* [BugFix] Argument in cuda template

* [Lint]
parent 792e5d5b
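
The memory order is threaded through the TIR as a plain integer immediate. A minimal sketch of the assumed encoding is given below; it mirrors the std::memory_order / libcu++ cuda::memory_order enumerator values, which is why IntImm(0) serves as the relaxed default throughout the diff (the frontend's _MEMORY_ORDER_ID_MAP itself is not visible here, so this mapping is an assumption):

```cpp
// Sketch only: the integer ids assumed to be carried by the new memory_order
// argument. They follow the libcu++/std enumerator order, so 0 means relaxed.
#include <cuda/atomic>

static_assert(int(cuda::memory_order_relaxed) == 0, "relaxed is the default id");
static_assert(int(cuda::memory_order_acquire) == 2, "acquire");
static_assert(int(cuda::memory_order_release) == 3, "release");
static_assert(int(cuda::memory_order_acq_rel) == 4, "acq_rel");
static_assert(int(cuda::memory_order_seq_cst) == 5, "seq_cst");
```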
@@ -58,8 +58,12 @@ AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) {
   if (args.size() >= 3) {
     node->use_tma = Downcast<IntImm>(args[2]);
   }
+  node->memory_order = IntImm(0);
   if (args.size() >= 4) {
-    node->coalesced_width = Downcast<IntImm>(args[3]);
+    node->memory_order = Downcast<IntImm>(args[3]);
+  }
+  if (args.size() >= 5) {
+    node->coalesced_width = Downcast<IntImm>(args[4]);
   }
   data_ = std::move(node);
 }
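
For orientation, the argument layout this constructor now parses can be inferred by pairing it with the frontend call at the end of this diff (T.call_intrin(..., value, dst, use_tma, memory_order, ...)); the summary below is an inferred sketch, not code from the commit:

```cpp
// Inferred tl.atomicadd argument layout after this change:
//   args[0] -> source value region
//   args[1] -> destination region
//   args[2] -> use_tma
//   args[3] -> memory_order     (new; defaults to 0 == relaxed)
//   args[4] -> coalesced_width  (shifted from args[3])
```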
@@ -285,6 +289,7 @@ For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
   new_args.push_back(dst_value);
   new_args.push_back(src_value);
+  new_args.push_back(memory_order);
   Call atomicadd_call =
       tvm::tir::Call(dst->dtype, atomicadd_elem_op(), new_args);
......
@@ -22,6 +22,7 @@ public:
       dst_range;          ///< Access ranges for source and destination
   IntImm use_tma;         ///< Whether to use TMA for memory operations
   IntImm coalesced_width; ///< Width for memory coalescing optimization
+  IntImm memory_order;    ///< Memory order for atomic operations
   mutable ParallelOp par_op_; ///< Associated parallel operation
 
   static constexpr const char *_type_key = "tl.AtomicAdd";
@@ -41,7 +42,8 @@ public:
         .def_ro("src_range", &AtomicAddNode::src_range)
         .def_ro("dst_range", &AtomicAddNode::dst_range)
         .def_ro("use_tma", &AtomicAddNode::use_tma)
-        .def_ro("coalesced_width", &AtomicAddNode::coalesced_width);
+        .def_ro("coalesced_width", &AtomicAddNode::coalesced_width)
+        .def_ro("memory_order", &AtomicAddNode::memory_order);
   }
 
   bool SEqualReduce(const AtomicAddNode *other, SEqualReducer equal) const {
@@ -49,7 +51,8 @@ public:
            equal(src_range, other->src_range) &&
            equal(dst_range, other->dst_range) &&
            equal(use_tma, other->use_tma) &&
-           equal(coalesced_width, other->coalesced_width);
+           equal(coalesced_width, other->coalesced_width) &&
+           equal(memory_order, other->memory_order);
   }
 
   void SHashReduce(SHashReducer hash_reduce) const {
@@ -59,6 +62,7 @@ public:
     hash_reduce(dst_range);
     hash_reduce(use_tma);
     hash_reduce(coalesced_width);
+    hash_reduce(memory_order);
   }
 
   static constexpr bool _type_has_method_sequal_reduce = true;
......
@@ -296,7 +296,7 @@ TIR_DEFINE_TL_BUILTIN(increase_descriptor_offset)
                          Integer(CallEffectKind::kOpaque));
 
 TIR_DEFINE_TL_BUILTIN(atomicadd_elem_op)
-    .set_num_inputs(2)
+    .set_num_inputs(3)
     .set_attr<TCallEffectKind>("TCallEffectKind",
                                Integer(CallEffectKind::kOpaque));
......
@@ -105,8 +105,9 @@ TL_DEVICE void AtomicAdd(T1 &ref, T2 val,
                          int memory_order = int(cuda::memory_order_relaxed)) {
   using NT1 = typename normalize_atomic_type<T1>::type;
   T1 *address = &ref;
-  if constexpr (std::is_same_v<NT1, half> ||
-                std::is_same_v<NT1, __nv_bfloat16>) {
+  if constexpr ((std::is_same_v<NT1, half> ||
+                 std::is_same_v<NT1, __nv_bfloat16>)&&memory_order ==
+                    int(cuda::memory_order_relaxed)) {
     atomicAdd(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val));
   } else {
     cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*address);
@@ -119,8 +120,9 @@ TL_DEVICE T1 AtomicAddRet(T1 &ref, T2 val,
                           int memory_order = int(cuda::memory_order_relaxed)) {
   using NT1 = typename normalize_atomic_type<T1>::type;
   T1 *address = &ref;
-  if constexpr (std::is_same_v<NT1, half> ||
-                std::is_same_v<NT1, __nv_bfloat16>) {
+  if constexpr ((std::is_same_v<NT1, half> ||
+                 std::is_same_v<NT1, __nv_bfloat16>)&&memory_order ==
+                    int(cuda::memory_order_relaxed)) {
     return static_cast<T1>(
         atomicAdd(reinterpret_cast<NT1 *>(address), static_cast<NT1>(val)));
   } else {
@@ -130,24 +132,31 @@ TL_DEVICE T1 AtomicAddRet(T1 &ref, T2 val,
   }
 }
 
-TL_DEVICE void AtomicAddx2(half_t *ref, half_t *val) {
+// TODO add memory_order for vectorized atomic add
+TL_DEVICE void AtomicAddx2(half_t *ref, half_t *val,
+                           int memory_order = int(cuda::memory_order_relaxed)) {
   atomicAdd(reinterpret_cast<half2 *>(ref),
             static_cast<half2>(*reinterpret_cast<half2 *>(val)));
 }
 
-TL_DEVICE half2 AtomicAddx2Ret(half_t *ref, half_t *val) {
+TL_DEVICE half2
+AtomicAddx2Ret(half_t *ref, half_t *val,
+               int memory_order = int(cuda::memory_order_relaxed)) {
   return atomicAdd(reinterpret_cast<half2 *>(ref),
                    static_cast<half2>(*reinterpret_cast<half2 *>(val)));
 }
 
 #if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ > 750))
-TL_DEVICE void AtomicAddx2(bfloat16_t *ref, bfloat16_t *val) {
+TL_DEVICE void AtomicAddx2(bfloat16_t *ref, bfloat16_t *val,
+                           int memory_order = int(cuda::memory_order_relaxed)) {
   atomicAdd(
       reinterpret_cast<__nv_bfloat162 *>(ref),
       static_cast<__nv_bfloat162>(*reinterpret_cast<__nv_bfloat162 *>(val)));
 }
 
-TL_DEVICE __nv_bfloat162 AtomicAddx2Ret(bfloat16_t *ref, bfloat16_t *val) {
+TL_DEVICE __nv_bfloat162
+AtomicAddx2Ret(bfloat16_t *ref, bfloat16_t *val,
+               int memory_order = int(cuda::memory_order_relaxed)) {
   return atomicAdd(
       reinterpret_cast<__nv_bfloat162 *>(ref),
       static_cast<__nv_bfloat162>(*reinterpret_cast<__nv_bfloat162 *>(val)));
@@ -155,22 +164,28 @@ TL_DEVICE __nv_bfloat162 AtomicAddx2Ret(bfloat16_t *ref, bfloat16_t *val) {
 #endif
 
 #if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 900))
-TL_DEVICE void AtomicAddx2(float *ref, float *val) {
+TL_DEVICE void AtomicAddx2(float *ref, float *val,
+                           int memory_order = int(cuda::memory_order_relaxed)) {
   atomicAdd(reinterpret_cast<float2 *>(ref),
             static_cast<float2>(*reinterpret_cast<float2 *>(val)));
 }
 
-TL_DEVICE float2 AtomicAddx2Ret(float *ref, float *val) {
+TL_DEVICE float2
+AtomicAddx2Ret(float *ref, float *val,
+               int memory_order = int(cuda::memory_order_relaxed)) {
   return atomicAdd(reinterpret_cast<float2 *>(ref),
                    static_cast<float2>(*reinterpret_cast<float2 *>(val)));
 }
 
-TL_DEVICE void AtomicAddx4(float *ref, float *val) {
+TL_DEVICE void AtomicAddx4(float *ref, float *val,
+                           int memory_order = int(cuda::memory_order_relaxed)) {
   atomicAdd(reinterpret_cast<float4 *>(ref),
             static_cast<float4>(*reinterpret_cast<float4 *>(val)));
 }
 
-TL_DEVICE float4 AtomicAddx4Ret(float *ref, float *val) {
+TL_DEVICE float4
+AtomicAddx4Ret(float *ref, float *val,
+               int memory_order = int(cuda::memory_order_relaxed)) {
   return atomicAdd(reinterpret_cast<float4 *>(ref),
                    static_cast<float4>(*reinterpret_cast<float4 *>(val)));
 }
......
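
The else branches of the scalar AtomicAdd/AtomicAddRet templates (collapsed in the hunks above) are where the memory order actually takes effect: the raw half/__nv_bfloat16 atomicAdd intrinsic has no ordering parameter, so it is now taken only when the requested order is relaxed, and every other case goes through cuda::atomic_ref. A simplified, self-contained sketch of that fallback path, assuming an arithmetic element type (the real templates in src/tl_templates/cuda/atomic.h additionally normalize T1 via normalize_atomic_type):

```cuda
#include <cuda/atomic>

// Minimal sketch of the non-relaxed path: cuda::atomic_ref::fetch_add accepts
// the requested memory order, unlike the hardware atomicAdd shortcut.
template <typename T>
__device__ T atomic_add_with_order(T &ref, T val,
                                   int memory_order = int(cuda::memory_order_relaxed)) {
  cuda::atomic_ref<T, cuda::thread_scope_device> aref(ref);
  // fetch_add returns the value held before the addition, matching AtomicAddRet;
  // AtomicAdd simply discards it.
  return aref.fetch_add(val, static_cast<cuda::memory_order>(memory_order));
}
```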
@@ -227,6 +227,10 @@ private:
     if (legal_vectorize) {
       const BufferLoad dst_node = Downcast<BufferLoad>(node->args[0]);
       const BufferLoad value_node = Downcast<BufferLoad>(node->args[1]);
+      // The default memory order is relaxed
+      // Ref: src/tl_templates/cuda/atomic.h::AtomicAdd
+      const IntImm memory_order =
+          node->args.size() >= 3 ? Downcast<IntImm>(node->args[2]) : IntImm(0);
       Call address_of_dst =
           Call(DataType::Handle(), builtin::address_of(), {dst_node});
@@ -242,6 +246,7 @@ private:
       }
       new_args.push_back(address_of_dst);
      new_args.push_back(address_of_value);
+      new_args.push_back(memory_order);
       Call new_call =
           tvm::tir::Call(node->dtype, builtin::call_extern(), new_args);
......
@@ -227,7 +227,11 @@ def atomic_add(dst: Buffer,
         raise NotImplementedError(
             "return_prev is not supported for tile-region-based atomic operations")
-    return T.call_intrin("handle", op.Op.get("tl.atomicadd"), value, dst, use_tma)
+    if memory_order is None:
+        return T.call_intrin("handle", op.Op.get("tl.atomicadd"), value, dst, use_tma, 0)
+    else:
+        return T.call_intrin("handle", op.Op.get("tl.atomicadd"), value, dst, use_tma,
+                             _MEMORY_ORDER_ID_MAP[memory_order])
 
 
 def atomic_addx2(dst: Buffer, value: PrimExpr, return_prev: bool = False) -> PrimExpr:
......
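
At the frontend, memory_order is an optional keyword that is forwarded to tl.atomicadd as an integer id, with 0 (relaxed) as the default. A hedged usage sketch follows; the exact keys accepted by _MEMORY_ORDER_ID_MAP are not visible in this hunk, so the strings below, and the assumption that atomic_add is re-exported as T.atomic_add, are for illustration only:

```python
# Call-shape sketch: dst and value stand for buffer regions inside a TileLang
# kernel; "release" and "seq_cst" are assumed _MEMORY_ORDER_ID_MAP keys.
import tilelang.language as T

T.atomic_add(dst, value)                          # default: relaxed (id 0)
T.atomic_add(dst, value, memory_order="release")  # assumed key
T.atomic_add(dst, value, memory_order="seq_cst")  # assumed key
```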