Commit aeb02070 authored by Shucai Xiao

Simplify the code for the softmax GPU implementation by factoring the block-level max and sum reductions into reusable reduce_max and reduce_sum device helpers.

parent ee877777
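
The two helpers added in this commit follow the standard shared-memory tree reduction: each round folds the upper half of the live elements into the lower half, so a block finishes in O(log n) synchronized steps. A minimal, self-contained HIP sketch of the same pattern (illustration only; the block_reduce and tile_max names are hypothetical and not part of MIGraphX):

#include <cfloat>
#include <hip/hip_runtime.h>

// Illustration only: a generic block-level tree reduction over shared memory.
// block_reduce and tile_max are hypothetical names, not MIGraphX code.
template <class T, class Op>
__device__ void block_reduce(T* sdata, size_t thr_idx, size_t item_num, Op op)
{
    // Each round folds the upper half of the live elements into the lower
    // half; the "+ 1" keeps odd element counts covered.
    size_t stride = (item_num + 1) / 2;
    while(item_num > 1)
    {
        if(thr_idx + stride < item_num)
            sdata[thr_idx] = op(sdata[thr_idx], sdata[thr_idx + stride]);
        __syncthreads();
        item_num = stride;
        stride   = (stride + 1) / 2;
    }
    // sdata[0] now holds the reduction of the original elements.
}

// Example kernel: each block writes the max of its 256-element tile.
__global__ void tile_max(const float* in, float* out, size_t n)
{
    __shared__ float sdata[256];
    size_t tid = threadIdx.x;
    size_t gid = blockIdx.x * blockDim.x + tid;
    sdata[tid] = (gid < n) ? in[gid] : -FLT_MAX;
    __syncthreads();
    block_reduce(sdata, tid, blockDim.x, [](float a, float b) { return a < b ? b : a; });
    if(tid == 0)
        out[blockIdx.x] = sdata[0];
}

The "+ 1" in the stride update is the same trick used by reduce_max and reduce_sum below to handle odd element counts.
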
@@ -12,6 +12,62 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
template <class T>
__device__ void reduce_max(MIGRAPHX_DEVICE_SHARED T* data_ptr, size_t block_size, size_t thr_idx, size_t item_num)
{
    // Tree reduction: each round folds the upper half of the live elements
    // into the lower half, halving the number of live elements per round.
    auto stride = (item_num + 1) / 2;
    while(true)
    {
        if(thr_idx + stride < item_num)
        {
            data_ptr[thr_idx] = ::max(to_hip_type(data_ptr[thr_idx]),
                                      to_hip_type(data_ptr[thr_idx + stride]));
        }
        __syncthreads();
        item_num = stride;
        stride   = (stride + 1) / 2;
        if(item_num == 1)
            break;
    }
    // data_ptr[block_size] accumulates the running max across chunks.
    if(thr_idx == 0)
    {
        data_ptr[block_size] =
            (data_ptr[0] < data_ptr[block_size]) ? data_ptr[block_size] : data_ptr[0];
    }
    __syncthreads();
}
template <class T>
__device__ void reduce_sum(MIGRAPHX_DEVICE_SHARED T* data_ptr, size_t block_size, size_t thr_idx, size_t item_num)
{
    // Same tree-reduction structure as reduce_max, accumulating a sum instead.
    auto stride = (item_num + 1) / 2;
    while(true)
    {
        if(thr_idx + stride < item_num)
        {
            data_ptr[thr_idx] += data_ptr[thr_idx + stride];
        }
        __syncthreads();
        item_num = stride;
        stride   = (stride + 1) / 2;
        if(item_num == 1)
            break;
    }
    // data_ptr[block_size + 1] accumulates the running sum across chunks.
    if(thr_idx == 0)
    {
        data_ptr[block_size + 1] += data_ptr[0];
    }
    __syncthreads();
}
void softmax(hipStream_t stream, const argument& result, const argument& arg, int axis)
{
    auto lens = result.get_shape().lens();
@@ -62,28 +118,8 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, int axis)
        __syncthreads();
        auto size   = (item_num > block_size) ? block_size : item_num;
        auto stride = (size + 1) / 2;
        while(true)
        {
            if(thr_idx + stride < size)
            {
                lds_data[thr_idx] = ::max(to_hip_type(lds_data[thr_idx]),
                                          to_hip_type(lds_data[thr_idx + stride]));
            }
            __syncthreads();
            size   = stride;
            stride = (stride + 1) / 2;
            if(size == 1)
                break;
        }
        reduce_max<type>(lds_data, block_size, thr_idx, size);
        if(thr_idx == 0)
        {
            lds_data[block_size] = (lds_data[0] < lds_data[block_size])
                                       ? lds_data[block_size]
                                       : lds_data[0];
        }
        __syncthreads();
        item_num -= block_size;
@@ -103,24 +139,7 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, int axis)
        __syncthreads();
        auto size   = (item_num > block_size) ? block_size : item_num;
        auto stride = (size + 1) / 2;
        while(true)
        {
            if(thr_idx + stride < size)
            {
                lds_data[thr_idx] += lds_data[thr_idx + stride];
            }
            __syncthreads();
            size   = stride;
            stride = (stride + 1) / 2;
            if(size == 1)
                break;
        }
        if(thr_idx == 0)
        {
            lds_data[block_size + 1] += lds_data[0];
        }
        reduce_sum<type>(lds_data, block_size, thr_idx, size);
        __syncthreads();
        item_num -= block_size;
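
For reference, the kernel performs two reductions per slice because it computes a numerically stable softmax: softmax(x)_i = exp(x_i - max(x)) / sum_j exp(x_j - max(x)). A host-side sketch of the same computation (illustration only; softmax_reference is a hypothetical name, not MIGraphX code):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Reference-only sketch: numerically stable softmax over one slice, showing
// the two reductions the kernel performs -- first a max (for stability),
// then a sum of the shifted exponentials.
std::vector<float> softmax_reference(const std::vector<float>& x)
{
    float max_val = *std::max_element(x.begin(), x.end());
    std::vector<float> result(x.size());
    float sum = 0.0f;
    for(std::size_t i = 0; i < x.size(); ++i)
    {
        result[i] = std::exp(x[i] - max_val);
        sum += result[i];
    }
    for(auto& v : result)
        v /= sum;
    return result;
}
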