"...composable_kernel.git" did not exist on "9f453d424e4263fef04e60b525cd3acc719b7121"
Unverified commit 7d389a43, authored by Lei Wang and committed by GitHub
Browse files

[Bugfix] Correctly construct the argument list for atomic add based on the vector size (#1137)

* atomic_fix

* atomic_fix
parent 853f9c3d
...@@ -231,21 +231,25 @@ private: ...@@ -231,21 +231,25 @@ private:
// Ref: src/tl_templates/cuda/atomic.h::AtomicAdd // Ref: src/tl_templates/cuda/atomic.h::AtomicAdd
const IntImm memory_order = const IntImm memory_order =
node->args.size() >= 3 ? Downcast<IntImm>(node->args[2]) : IntImm(0); node->args.size() >= 3 ? Downcast<IntImm>(node->args[2]) : IntImm(0);
Array<PrimExpr> new_args;
Call address_of_dst = Call address_of_dst =
Call(DataType::Handle(), builtin::address_of(), {dst_node}); Call(DataType::Handle(), builtin::address_of(), {dst_node});
Call address_of_value = Call address_of_value =
Call(DataType::Handle(), builtin::address_of(), {value_node}); Call(DataType::Handle(), builtin::address_of(), {value_node});
Array<PrimExpr> new_args;
if (vector_size_ == 4) { if (vector_size_ == 4) {
new_args.push_back(StringImm("AtomicAddx4")); new_args.push_back(StringImm("AtomicAddx4"));
new_args.push_back(address_of_dst);
new_args.push_back(address_of_value);
} else if (vector_size_ == 2) { } else if (vector_size_ == 2) {
new_args.push_back(StringImm("AtomicAddx2")); new_args.push_back(StringImm("AtomicAddx2"));
new_args.push_back(address_of_dst);
new_args.push_back(address_of_value);
} else { } else {
new_args.push_back(StringImm("AtomicAdd")); new_args.push_back(StringImm("AtomicAdd"));
new_args.push_back(dst_node);
new_args.push_back(value_node);
} }
new_args.push_back(address_of_dst);
new_args.push_back(address_of_value);
new_args.push_back(memory_order); new_args.push_back(memory_order);
Call new_call = Call new_call =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment