Commit 9fc184e5 authored by Shucai Xiao
Browse files

generalize the reduce operation implementation.

parent 42b24bd1
......@@ -13,17 +13,17 @@ namespace device {
template <class T>
inline __device__ void reduce_max(T* data_ptr, size_t block_size, size_t thr_idx, size_t item_num)
{
auto stride = (item_num + 1) / 2;
while(true)
{
if(thr_idx + stride < item_num)
auto stride = (item_num + 1) / 2;
auto size = item_num / 2;
for (size_t i = thr_idx; i < size; i += block_size)
{
data_ptr[thr_idx] =
::max(to_hip_type(data_ptr[thr_idx]), to_hip_type(data_ptr[thr_idx + stride]));
data_ptr[i] =
::max(to_hip_type(data_ptr[i]), to_hip_type(data_ptr[i + stride]));
}
__syncthreads();
item_num = stride;
stride = (stride + 1) / 2;
if(item_num == 1)
break;
......@@ -41,16 +41,16 @@ inline __device__ void reduce_max(T* data_ptr, size_t block_size, size_t thr_idx
template <class T>
inline __device__ void reduce_sum(T* data_ptr, size_t block_size, size_t thr_idx, size_t item_num)
{
auto stride = (item_num + 1) / 2;
while(true)
{
if(thr_idx + stride < item_num)
auto stride = (item_num + 1) / 2;
auto size = item_num / 2;
for (size_t i = thr_idx; i < size; i += block_size)
{
data_ptr[thr_idx] += data_ptr[thr_idx + stride];
data_ptr[i] += data_ptr[i + stride];
}
__syncthreads();
item_num = stride;
stride = (stride + 1) / 2;
if(item_num == 1)
break;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment