Commit e279b402 authored by rocking's avatar rocking
Browse files

Add more operations

parent b3bc666b
...@@ -25,6 +25,7 @@ __global__ void kernel_put_element_1d(const InGrid1dDesc in_grid_1d_desc, ...@@ -25,6 +25,7 @@ __global__ void kernel_put_element_1d(const InGrid1dDesc in_grid_1d_desc,
in_grid_1d_desc, p_in_global, p_indices_global, p_out_global, elementwise_op); in_grid_1d_desc, p_in_global, p_indices_global, p_out_global, elementwise_op);
} }
// output[indices] = input
template <typename InGrid1dDesc, template <typename InGrid1dDesc,
typename InDataType, typename InDataType,
typename IndexDataType, typename IndexDataType,
...@@ -116,10 +117,6 @@ struct GridwisePutElement_1D ...@@ -116,10 +117,6 @@ struct GridwisePutElement_1D
static_for<0, InVectorSize, 1>{}([&](auto iM) { static_for<0, InVectorSize, 1>{}([&](auto iM) {
if(indices_thread_buf[iM] >= 0) if(indices_thread_buf[iM] >= 0)
{ {
// TODO - Support other operations
static_assert(MemOp == InMemoryDataOperationEnum::Set ||
MemOp == InMemoryDataOperationEnum::AtomicAdd);
if constexpr(MemOp == InMemoryDataOperationEnum::Set) if constexpr(MemOp == InMemoryDataOperationEnum::Set)
{ {
// User should guarantee each index in p_indices_global is different // User should guarantee each index in p_indices_global is different
...@@ -131,9 +128,23 @@ struct GridwisePutElement_1D ...@@ -131,9 +128,23 @@ struct GridwisePutElement_1D
atomic_add<OutDataType>(p_out_global + indices_thread_buf[iM], atomic_add<OutDataType>(p_out_global + indices_thread_buf[iM],
ck::type_convert<OutDataType>(in_thread_buf[iM])); ck::type_convert<OutDataType>(in_thread_buf[iM]));
} }
else if constexpr(MemOp == InMemoryDataOperationEnum::AtomicMax)
{
atomic_max<OutDataType>(p_out_global + indices_thread_buf[iM],
ck::type_convert<OutDataType>(in_thread_buf[iM]));
}
else if constexpr(MemOp == InMemoryDataOperationEnum::Add)
{
// User should guarantee each index in p_indices_global is different
*(p_out_global + indices_thread_buf[iM]) +=
ck::type_convert<OutDataType>(in_thread_buf[iM]);
}
else else
{ {
// TODO static_assert(MemOp == InMemoryDataOperationEnum::Set ||
MemOp == InMemoryDataOperationEnum::AtomicAdd ||
MemOp == InMemoryDataOperationEnum::AtomicMax ||
MemOp == InMemoryDataOperationEnum::Add);
} }
} }
}); });
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment