Commit e279b402 authored by rocking's avatar rocking
Browse files

Add more operations

parent b3bc666b
......@@ -25,6 +25,7 @@ __global__ void kernel_put_element_1d(const InGrid1dDesc in_grid_1d_desc,
in_grid_1d_desc, p_in_global, p_indices_global, p_out_global, elementwise_op);
}
// output[indices] = input
template <typename InGrid1dDesc,
typename InDataType,
typename IndexDataType,
......@@ -116,10 +117,6 @@ struct GridwisePutElement_1D
static_for<0, InVectorSize, 1>{}([&](auto iM) {
if(indices_thread_buf[iM] >= 0)
{
// TODO - Support other operations
static_assert(MemOp == InMemoryDataOperationEnum::Set ||
MemOp == InMemoryDataOperationEnum::AtomicAdd);
if constexpr(MemOp == InMemoryDataOperationEnum::Set)
{
// User should guarantee each index in p_indices_global is different
......@@ -131,9 +128,23 @@ struct GridwisePutElement_1D
atomic_add<OutDataType>(p_out_global + indices_thread_buf[iM],
ck::type_convert<OutDataType>(in_thread_buf[iM]));
}
else if constexpr(MemOp == InMemoryDataOperationEnum::AtomicMax)
{
atomic_max<OutDataType>(p_out_global + indices_thread_buf[iM],
ck::type_convert<OutDataType>(in_thread_buf[iM]));
}
else if constexpr(MemOp == InMemoryDataOperationEnum::Add)
{
// User should guarantee each index in p_indices_global is different
*(p_out_global + indices_thread_buf[iM]) +=
ck::type_convert<OutDataType>(in_thread_buf[iM]);
}
else
{
// TODO
static_assert(MemOp == InMemoryDataOperationEnum::Set ||
MemOp == InMemoryDataOperationEnum::AtomicAdd ||
MemOp == InMemoryDataOperationEnum::AtomicMax ||
MemOp == InMemoryDataOperationEnum::Add);
}
}
});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment