Commit 6ae2f087 authored by Shucai Xiao
Browse files

clang format

parent aeb02070
......@@ -12,20 +12,21 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
template<class T>
__device__ void reduce_max(MIGRAPHX_DEVICE_SHARED T* data_ptr, size_t block_size, size_t thr_idx, size_t item_num)
template <class T>
__device__ void
reduce_max(MIGRAPHX_DEVICE_SHARED T* data_ptr, size_t block_size, size_t thr_idx, size_t item_num)
{
auto stride = (item_num + 1) / 2;
while (true)
while(true)
{
if(thr_idx + stride < item_num)
{
data_ptr[thr_idx] = ::max(to_hip_type(data_ptr[thr_idx]),
to_hip_type(data_ptr[thr_idx + stride]));
data_ptr[thr_idx] =
::max(to_hip_type(data_ptr[thr_idx]), to_hip_type(data_ptr[thr_idx + stride]));
}
__syncthreads();
item_num = stride;
stride = (stride + 1) / 2;
item_num = stride;
stride = (stride + 1) / 2;
if(item_num == 1)
break;
......@@ -33,27 +34,27 @@ __device__ void reduce_max(MIGRAPHX_DEVICE_SHARED T* data_ptr, size_t block_size
if(thr_idx == 0)
{
data_ptr[block_size] = (data_ptr[0] < data_ptr[block_size])
? data_ptr[block_size]
: data_ptr[0];
data_ptr[block_size] =
(data_ptr[0] < data_ptr[block_size]) ? data_ptr[block_size] : data_ptr[0];
}
__syncthreads();
}
template<class T>
__device__ void reduce_sum(MIGRAPHX_DEVICE_SHARED T* data_ptr, size_t block_size, size_t thr_idx, size_t item_num)
template <class T>
__device__ void
reduce_sum(MIGRAPHX_DEVICE_SHARED T* data_ptr, size_t block_size, size_t thr_idx, size_t item_num)
{
auto stride = (item_num + 1) / 2;
while (true)
while(true)
{
if(thr_idx + stride < item_num)
{
data_ptr[thr_idx] += data_ptr[thr_idx + stride];
}
__syncthreads();
item_num = stride;
stride = (stride + 1) / 2;
item_num = stride;
stride = (stride + 1) / 2;
if(item_num == 1)
break;
......@@ -67,7 +68,6 @@ __device__ void reduce_sum(MIGRAPHX_DEVICE_SHARED T* data_ptr, size_t block_size
__syncthreads();
}
void softmax(hipStream_t stream, const argument& result, const argument& arg, int axis)
{
auto lens = result.get_shape().lens();
......@@ -117,7 +117,7 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
__syncthreads();
auto size = (item_num > block_size) ? block_size : item_num;
auto size = (item_num > block_size) ? block_size : item_num;
reduce_max<type>(lds_data, block_size, thr_idx, size);
__syncthreads();
......@@ -138,7 +138,7 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
__syncthreads();
auto size = (item_num > block_size) ? block_size : item_num;
auto size = (item_num > block_size) ? block_size : item_num;
reduce_sum<type>(lds_data, block_size, thr_idx, size);
__syncthreads();
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment