Commit 8817e238 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

further code cleanup

parent 7da35f54
...@@ -54,9 +54,8 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in ...@@ -54,9 +54,8 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
__syncthreads(); __syncthreads();
auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num; auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num;
reduce_max(lds_data, block_size, thr_idx, item_num, max_block_size); block_reduce<type, max_op<type>>(
reduce_max(lds_data, block_size, thr_idx, item_num, max_block_size); lds_data, max_op<type>{}, block_size, thr_idx, item_num, max_block_size);
remaining_item_num -= block_size; remaining_item_num -= block_size;
} }
...@@ -77,7 +76,8 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in ...@@ -77,7 +76,8 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
__syncthreads(); __syncthreads();
auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num; auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num;
reduce_sum(lds_data, block_size, thr_idx, item_num, max_block_size); block_reduce<type, sum_op<type>>(
lds_data, sum_op<type>{}, block_size, thr_idx, item_num, max_block_size);
remaining_item_num -= block_size; remaining_item_num -= block_size;
} }
......
...@@ -62,68 +62,6 @@ inline __device__ void block_reduce(T* data_ptr, ...@@ -62,68 +62,6 @@ inline __device__ void block_reduce(T* data_ptr,
__syncthreads(); __syncthreads();
} }
template <class T>
inline __device__ void reduce_max(T* data_ptr,
std::size_t block_size,
std::size_t thr_idx,
std::size_t item_num,
std::size_t max_index)
{
while(true)
{
auto stride = (item_num + 1) / 2;
auto size = item_num / 2;
for(std::size_t i = thr_idx; i < size; i += block_size)
{
data_ptr[i] = ::max(to_hip_type(data_ptr[i]), to_hip_type(data_ptr[i + stride]));
}
__syncthreads();
item_num = stride;
if(item_num == 1)
break;
}
if(thr_idx == 0)
{
data_ptr[max_index] =
(data_ptr[0] < data_ptr[max_index]) ? data_ptr[max_index] : data_ptr[0];
}
__syncthreads();
}
template <class T>
inline __device__ void reduce_min(T* data_ptr,
std::size_t block_size,
std::size_t thr_idx,
std::size_t item_num,
std::size_t min_index)
{
while(true)
{
auto stride = (item_num + 1) / 2;
auto size = item_num / 2;
for(std::size_t i = thr_idx; i < size; i += block_size)
{
data_ptr[i] = ::min(to_hip_type(data_ptr[i]), to_hip_type(data_ptr[i + stride]));
}
__syncthreads();
item_num = stride;
if(item_num == 1)
break;
}
if(thr_idx == 0)
{
data_ptr[min_index] =
(data_ptr[0] > data_ptr[min_index]) ? data_ptr[min_index] : data_ptr[0];
}
__syncthreads();
}
template <class T> template <class T>
inline __device__ void reduce_argmax(T* data_ptr, inline __device__ void reduce_argmax(T* data_ptr,
int64_t* index_ptr, int64_t* index_ptr,
...@@ -202,36 +140,6 @@ inline __device__ void reduce_argmin(T* data_ptr, ...@@ -202,36 +140,6 @@ inline __device__ void reduce_argmin(T* data_ptr,
__syncthreads(); __syncthreads();
} }
template <class T>
inline __device__ void reduce_sum(T* data_ptr,
std::size_t block_size,
std::size_t thr_idx,
std::size_t item_num,
std::size_t sum_index)
{
while(true)
{
auto stride = (item_num + 1) / 2;
auto size = item_num / 2;
for(std::size_t i = thr_idx; i < size; i += block_size)
{
data_ptr[i] += data_ptr[i + stride];
}
__syncthreads();
item_num = stride;
if(item_num == 1)
break;
}
if(thr_idx == 0)
{
data_ptr[sum_index] += data_ptr[0];
}
__syncthreads();
}
} // namespace device } // namespace device
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment