Commit 6ae2f087 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent aeb02070
...@@ -12,20 +12,21 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -12,20 +12,21 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
template<class T> template <class T>
__device__ void reduce_max(MIGRAPHX_DEVICE_SHARED T* data_ptr, size_t block_size, size_t thr_idx, size_t item_num) __device__ void
reduce_max(MIGRAPHX_DEVICE_SHARED T* data_ptr, size_t block_size, size_t thr_idx, size_t item_num)
{ {
auto stride = (item_num + 1) / 2; auto stride = (item_num + 1) / 2;
while (true) while(true)
{ {
if(thr_idx + stride < item_num) if(thr_idx + stride < item_num)
{ {
data_ptr[thr_idx] = ::max(to_hip_type(data_ptr[thr_idx]), data_ptr[thr_idx] =
to_hip_type(data_ptr[thr_idx + stride])); ::max(to_hip_type(data_ptr[thr_idx]), to_hip_type(data_ptr[thr_idx + stride]));
} }
__syncthreads(); __syncthreads();
item_num = stride; item_num = stride;
stride = (stride + 1) / 2; stride = (stride + 1) / 2;
if(item_num == 1) if(item_num == 1)
break; break;
...@@ -33,27 +34,27 @@ __device__ void reduce_max(MIGRAPHX_DEVICE_SHARED T* data_ptr, size_t block_size ...@@ -33,27 +34,27 @@ __device__ void reduce_max(MIGRAPHX_DEVICE_SHARED T* data_ptr, size_t block_size
if(thr_idx == 0) if(thr_idx == 0)
{ {
data_ptr[block_size] = (data_ptr[0] < data_ptr[block_size]) data_ptr[block_size] =
? data_ptr[block_size] (data_ptr[0] < data_ptr[block_size]) ? data_ptr[block_size] : data_ptr[0];
: data_ptr[0];
} }
__syncthreads(); __syncthreads();
} }
template<class T> template <class T>
__device__ void reduce_sum(MIGRAPHX_DEVICE_SHARED T* data_ptr, size_t block_size, size_t thr_idx, size_t item_num) __device__ void
reduce_sum(MIGRAPHX_DEVICE_SHARED T* data_ptr, size_t block_size, size_t thr_idx, size_t item_num)
{ {
auto stride = (item_num + 1) / 2; auto stride = (item_num + 1) / 2;
while (true) while(true)
{ {
if(thr_idx + stride < item_num) if(thr_idx + stride < item_num)
{ {
data_ptr[thr_idx] += data_ptr[thr_idx + stride]; data_ptr[thr_idx] += data_ptr[thr_idx + stride];
} }
__syncthreads(); __syncthreads();
item_num = stride; item_num = stride;
stride = (stride + 1) / 2; stride = (stride + 1) / 2;
if(item_num == 1) if(item_num == 1)
break; break;
...@@ -67,7 +68,6 @@ __device__ void reduce_sum(MIGRAPHX_DEVICE_SHARED T* data_ptr, size_t block_size ...@@ -67,7 +68,6 @@ __device__ void reduce_sum(MIGRAPHX_DEVICE_SHARED T* data_ptr, size_t block_size
__syncthreads(); __syncthreads();
} }
void softmax(hipStream_t stream, const argument& result, const argument& arg, int axis) void softmax(hipStream_t stream, const argument& result, const argument& arg, int axis)
{ {
auto lens = result.get_shape().lens(); auto lens = result.get_shape().lens();
...@@ -117,7 +117,7 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in ...@@ -117,7 +117,7 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
__syncthreads(); __syncthreads();
auto size = (item_num > block_size) ? block_size : item_num; auto size = (item_num > block_size) ? block_size : item_num;
reduce_max<type>(lds_data, block_size, thr_idx, size); reduce_max<type>(lds_data, block_size, thr_idx, size);
__syncthreads(); __syncthreads();
...@@ -138,7 +138,7 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in ...@@ -138,7 +138,7 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
__syncthreads(); __syncthreads();
auto size = (item_num > block_size) ? block_size : item_num; auto size = (item_num > block_size) ? block_size : item_num;
reduce_sum<type>(lds_data, block_size, thr_idx, size); reduce_sum<type>(lds_data, block_size, thr_idx, size);
__syncthreads(); __syncthreads();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment