Commit ccdacf44 authored by Shucai Xiao

code refactor: move the reduce_max/reduce_sum device helpers into a shared reduce_opers.hpp header and reuse them in the logsoftmax kernel instead of its hand-rolled reduction loops.

parent 17a269a4
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/logsoftmax.hpp>
+#include <migraphx/gpu/device/reduce_opers.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
@@ -15,7 +16,7 @@ void logsoftmax(hipStream_t stream, const argument& result, const argument& arg, int axis)
{
    auto lens = result.get_shape().lens();
-   auto n_dims     = lens[axis];
+   auto batch_item_num = lens[axis];
    auto batch_lens = lens;
    batch_lens[axis] = 1;
    migraphx::shape batch_shape{result.get_shape().type(), batch_lens};
@@ -28,8 +29,6 @@ void logsoftmax(hipStream_t stream, const argument& result, const argument& arg, int axis)
    hip_tensor_descriptor<n_dim> desc_data(result.get_shape());

    // use one block for items in one batch.
-   // opt 1, load all data to lds then use the same approach as
-   // the current optimization
    const size_t max_block_size = 1024;
    size_t block_size           = 1;
    while(block_size < max_block_size and block_size < n_dim)
@@ -43,94 +42,55 @@ void logsoftmax(hipStream_t stream, const argument& result, const argument& arg, int axis)
        size_t blk_idx = idx.group;
        using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;

-       // all data can be loaded to the lds once, so all operations are
-       // done in lds
        MIGRAPHX_DEVICE_SHARED type lds_data[max_block_size + 2];
        auto batch_idx = desc_batch.multi(blk_idx);
        auto data_idx  = batch_idx;
        // load data to lds and compute the batch max
-       size_t item_num   = n_dims;
-       size_t thread_num = (n_dims + block_size - 1) / block_size * block_size;
+       size_t remaining_item_num = batch_item_num;
+       size_t thread_num = (batch_item_num + block_size - 1) / block_size * block_size;
        lds_data[block_size] = input_ptr[0];
        for(size_t i = thr_idx; i < thread_num; i += block_size)
        {
-           if(i < n_dims)
+           if(i < batch_item_num)
            {
                data_idx[axis]    = i;
                lds_data[thr_idx] = input_ptr[desc_data.linear(data_idx)];
            }
            __syncthreads();

-           auto size   = (item_num > block_size) ? block_size : item_num;
-           auto stride = (size + 1) / 2;
-           while(true)
-           {
-               if(thr_idx + stride < size)
-               {
-                   lds_data[thr_idx] = ::max(to_hip_type(lds_data[thr_idx]),
-                                             to_hip_type(lds_data[thr_idx + stride]));
-               }
-               __syncthreads();
-               size   = stride;
-               stride = (stride + 1) / 2;
-               if(size == 1)
-                   break;
-           }
-           if(thr_idx == 0)
-           {
-               lds_data[block_size] = (lds_data[0] < lds_data[block_size])
-                                          ? lds_data[block_size]
-                                          : lds_data[0];
-           }
+           auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num;
+           reduce_max(lds_data, block_size, thr_idx, item_num);
+           remaining_item_num -= block_size;
+       }
+       auto batch_max = lds_data[block_size];
        __syncthreads();
-           item_num -= block_size;
-       }
-       const size_t block_size1 = block_size + 1;
-       lds_data[block_size1]    = 0;
-       item_num                 = n_dims;
+       lds_data[block_size] = 0;
+       remaining_item_num   = batch_item_num;
        for(size_t i = thr_idx; i < thread_num; i += block_size)
        {
-           if(i < n_dims)
+           if(i < batch_item_num)
            {
                data_idx[axis] = i;
                lds_data[thr_idx] =
-                   input_ptr[desc_data.linear(data_idx)] - lds_data[block_size];
+                   input_ptr[desc_data.linear(data_idx)] - batch_max;
                lds_data[thr_idx] = ::exp(to_hip_type(lds_data[thr_idx]));
            }
            __syncthreads();

-           auto size   = (item_num > block_size) ? block_size : item_num;
-           auto stride = (size + 1) / 2;
-           while(true)
-           {
-               if(thr_idx + stride < size)
-               {
-                   lds_data[thr_idx] += lds_data[thr_idx + stride];
-               }
-               __syncthreads();
-               size   = stride;
-               stride = (stride + 1) / 2;
-               if(size == 1)
-                   break;
-           }
-           if(thr_idx == 0)
-           {
-               lds_data[block_size1] += lds_data[0];
-           }
-           __syncthreads();
-           item_num -= block_size;
+           auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num;
+           reduce_sum(lds_data, block_size, thr_idx, item_num);
+           remaining_item_num -= block_size;
        }
        auto log_batch_sum =
-           ::log(to_hip_type(lds_data[block_size1])) + lds_data[block_size];
-       for(size_t i = thr_idx; i < n_dims; i += block_size)
+           ::log(to_hip_type(lds_data[block_size])) + batch_max;
+       for(size_t i = thr_idx; i < batch_item_num; i += block_size)
        {
            data_idx[axis] = i;
            size_t index   = desc_data.linear(data_idx);
...
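For orientation, the per-slice math this kernel implements is the numerically stable log-softmax: subtract the slice maximum, exponentiate, sum, then take the log. A minimal host-side C++ sketch of the same computation (illustrative only; the function name and std::vector interface are assumptions, not MIGraphX API):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Reference log-softmax over one slice along the reduction axis (assumes a non-empty slice).
std::vector<float> logsoftmax_slice(const std::vector<float>& x)
{
    // slice maximum, mirrors the reduce_max pass over lds_data
    float batch_max = *std::max_element(x.begin(), x.end());
    // sum of exp(x - max), mirrors the reduce_sum pass
    float sum = 0.0f;
    for(float v : x)
        sum += std::exp(v - batch_max);
    // log_batch_sum plays the same role as in the kernel
    float log_batch_sum = std::log(sum) + batch_max;
    std::vector<float> out(x.size());
    for(std::size_t i = 0; i < x.size(); ++i)
        out[i] = x[i] - log_batch_sum;
    return out;
}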
@@ -2,6 +2,7 @@
#include <migraphx/argument.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/gpu/device/softmax.hpp>
+#include <migraphx/gpu/device/reduce_opers.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
@@ -12,60 +13,6 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

-template <class T>
-__device__ void reduce_max(T* data_ptr, size_t block_size, size_t thr_idx, size_t item_num)
-{
-    auto stride = (item_num + 1) / 2;
-    while(true)
-    {
-        if(thr_idx + stride < item_num)
-        {
-            data_ptr[thr_idx] =
-                ::max(to_hip_type(data_ptr[thr_idx]), to_hip_type(data_ptr[thr_idx + stride]));
-        }
-        __syncthreads();
-        item_num = stride;
-        stride   = (stride + 1) / 2;
-        if(item_num == 1)
-            break;
-    }
-    if(thr_idx == 0)
-    {
-        data_ptr[block_size] =
-            (data_ptr[0] < data_ptr[block_size]) ? data_ptr[block_size] : data_ptr[0];
-    }
-    __syncthreads();
-}
-
-template <class T>
-__device__ void reduce_sum(T* data_ptr, size_t block_size, size_t thr_idx, size_t item_num)
-{
-    auto stride = (item_num + 1) / 2;
-    while(true)
-    {
-        if(thr_idx + stride < item_num)
-        {
-            data_ptr[thr_idx] += data_ptr[thr_idx + stride];
-        }
-        __syncthreads();
-        item_num = stride;
-        stride   = (stride + 1) / 2;
-        if(item_num == 1)
-            break;
-    }
-    if(thr_idx == 0)
-    {
-        data_ptr[block_size] += data_ptr[0];
-    }
-    __syncthreads();
-}
-
void softmax(hipStream_t stream, const argument& result, const argument& arg, int axis)
{
    auto lens = result.get_shape().lens();
@@ -112,8 +59,8 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, int axis)
            __syncthreads();

-           auto size = (remaining_item_num > block_size) ? block_size : remaining_item_num;
-           reduce_max<type>(lds_data, block_size, thr_idx, size);
+           auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num;
+           reduce_max<type>(lds_data, block_size, thr_idx, item_num);
            remaining_item_num -= block_size;
        }
@@ -134,8 +81,8 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, int axis)
            __syncthreads();

-           auto size = (remaining_item_num > block_size) ? block_size : remaining_item_num;
-           reduce_sum<type>(lds_data, block_size, thr_idx, size);
+           auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num;
+           reduce_sum<type>(lds_data, block_size, thr_idx, item_num);
            remaining_item_num -= block_size;
        }
...
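Both kernels drive these helpers the same way: a strided loop stages the batch in block_size-sized chunks in LDS, reduce_max/reduce_sum fold the chunk, and remaining_item_num tracks how much of the batch is left so the final partial chunk only reduces its valid entries (the variable renamed above from size to item_num). A sequential sketch of that bookkeeping, with hypothetical names and a plain std::vector standing in for the input tensor:

#include <algorithm>
#include <cstddef>
#include <vector>

// Sequential model of the chunked per-batch maximum; not MIGraphX code.
// Assumes a non-empty input slice.
float chunked_batch_max(const std::vector<float>& input, std::size_t block_size)
{
    std::size_t batch_item_num     = input.size();
    std::size_t remaining_item_num = batch_item_num;
    float batch_max = input[0]; // plays the role of lds_data[block_size]

    for(std::size_t base = 0; base < batch_item_num; base += block_size)
    {
        // the last chunk may hold fewer than block_size valid items
        std::size_t item_num = std::min(remaining_item_num, block_size);
        // stands in for reduce_max over the chunk staged in LDS
        for(std::size_t i = 0; i < item_num; ++i)
            batch_max = std::max(batch_max, input[base + i]);
        remaining_item_num -= item_num;
    }
    return batch_max;
}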
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_REDUCE_OPERS_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_REDUCE_OPERS_HPP
#include <migraphx/gpu/hip.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/device/nary.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
// Block-wide tree reduction over the first item_num entries of data_ptr (shared memory);
// the surviving value is folded into the running per-batch maximum at data_ptr[block_size].
// Must be called by all threads of the block; thr_idx is the calling thread's index.
template <class T>
__device__ void reduce_max(T* data_ptr, size_t block_size, size_t thr_idx, size_t item_num)
{
    auto stride = (item_num + 1) / 2;
    while(true)
    {
        // fold the upper half of the active range onto the lower half
        if(thr_idx + stride < item_num)
        {
            data_ptr[thr_idx] =
                ::max(to_hip_type(data_ptr[thr_idx]), to_hip_type(data_ptr[thr_idx + stride]));
        }
        __syncthreads();
        item_num = stride;
        stride   = (stride + 1) / 2;
        if(item_num == 1)
            break;
    }
    if(thr_idx == 0)
    {
        data_ptr[block_size] =
            (data_ptr[0] < data_ptr[block_size]) ? data_ptr[block_size] : data_ptr[0];
    }
    __syncthreads();
}

// Same folding scheme as reduce_max, but accumulating a sum into data_ptr[block_size].
template <class T>
__device__ void reduce_sum(T* data_ptr, size_t block_size, size_t thr_idx, size_t item_num)
{
    auto stride = (item_num + 1) / 2;
    while(true)
    {
        if(thr_idx + stride < item_num)
        {
            data_ptr[thr_idx] += data_ptr[thr_idx + stride];
        }
        __syncthreads();
        item_num = stride;
        stride   = (stride + 1) / 2;
        if(item_num == 1)
            break;
    }
    if(thr_idx == 0)
    {
        data_ptr[block_size] += data_ptr[0];
    }
    __syncthreads();
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
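Concretely, the two helpers implement a shared-memory tree reduction: each pass folds the upper half of the active range onto the lower half until one element survives in data_ptr[0], which thread 0 then merges into the running per-batch result kept at data_ptr[block_size]. A sequential host-side model of that folding (names are assumptions; the real helpers execute one inner-loop iteration per thread of the block):

#include <algorithm>
#include <cstddef>
#include <vector>

// Sequential model of reduce_max; 'data' must have at least block_size + 1 elements,
// with the running per-batch maximum stored at data[block_size].
void reduce_max_model(std::vector<float>& data, std::size_t block_size, std::size_t item_num)
{
    std::size_t stride = (item_num + 1) / 2;
    while(true)
    {
        // fold the upper half of the active range onto the lower half
        for(std::size_t t = 0; t + stride < item_num; ++t)
            data[t] = std::max(data[t], data[t + stride]);
        item_num = stride;
        stride   = (stride + 1) / 2;
        if(item_num == 1)
            break;
    }
    // merge the surviving value into the running per-batch maximum
    data[block_size] = std::max(data[block_size], data[0]);
}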