Commit 63773ec0 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

code cleanup for softmax.

parent 6ae2f087
...@@ -14,7 +14,7 @@ namespace device { ...@@ -14,7 +14,7 @@ namespace device {
template <class T> template <class T>
__device__ void __device__ void
reduce_max(MIGRAPHX_DEVICE_SHARED T* data_ptr, size_t block_size, size_t thr_idx, size_t item_num) reduce_max(T* data_ptr, size_t block_size, size_t thr_idx, size_t item_num)
{ {
auto stride = (item_num + 1) / 2; auto stride = (item_num + 1) / 2;
while(true) while(true)
...@@ -43,7 +43,7 @@ reduce_max(MIGRAPHX_DEVICE_SHARED T* data_ptr, size_t block_size, size_t thr_idx ...@@ -43,7 +43,7 @@ reduce_max(MIGRAPHX_DEVICE_SHARED T* data_ptr, size_t block_size, size_t thr_idx
template <class T> template <class T>
__device__ void __device__ void
reduce_sum(MIGRAPHX_DEVICE_SHARED T* data_ptr, size_t block_size, size_t thr_idx, size_t item_num) reduce_sum(T* data_ptr, size_t block_size, size_t thr_idx, size_t item_num)
{ {
auto stride = (item_num + 1) / 2; auto stride = (item_num + 1) / 2;
while(true) while(true)
...@@ -62,7 +62,7 @@ reduce_sum(MIGRAPHX_DEVICE_SHARED T* data_ptr, size_t block_size, size_t thr_idx ...@@ -62,7 +62,7 @@ reduce_sum(MIGRAPHX_DEVICE_SHARED T* data_ptr, size_t block_size, size_t thr_idx
if(thr_idx == 0) if(thr_idx == 0)
{ {
data_ptr[block_size + 1] += data_ptr[0]; data_ptr[block_size] += data_ptr[0];
} }
__syncthreads(); __syncthreads();
...@@ -72,7 +72,7 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in ...@@ -72,7 +72,7 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
{ {
auto lens = result.get_shape().lens(); auto lens = result.get_shape().lens();
auto batch_lens = lens; auto batch_lens = lens;
size_t n_dims = lens[axis]; size_t batch_item_num = lens[axis];
batch_lens[axis] = 1; batch_lens[axis] = 1;
migraphx::shape batch_shape{result.get_shape().type(), batch_lens}; migraphx::shape batch_shape{result.get_shape().type(), batch_lens};
...@@ -86,7 +86,7 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in ...@@ -86,7 +86,7 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
// use one block for items in one batch. // use one block for items in one batch.
const size_t max_block_size = 1024; const size_t max_block_size = 1024;
size_t block_size = 1; size_t block_size = 1;
while(block_size < max_block_size and block_size < n_dims) while(block_size < max_block_size and block_size < batch_item_num)
{ {
block_size *= 2; block_size *= 2;
} }
...@@ -97,19 +97,16 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in ...@@ -97,19 +97,16 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
size_t blk_idx = idx.group; size_t blk_idx = idx.group;
using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>; using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
// all data can be loaded to the lds once, so all operations are MIGRAPHX_DEVICE_SHARED type lds_data[max_block_size + 1];
// done in lds
MIGRAPHX_DEVICE_SHARED type lds_data[max_block_size + 2];
auto batch_idx = desc_batch.multi(blk_idx); auto batch_idx = desc_batch.multi(blk_idx);
auto data_idx = batch_idx; auto data_idx = batch_idx;
// load data to lds and compute the batch max // load data to lds and compute the batch max
size_t item_num = n_dims; size_t remaining_item_num = batch_item_num;
size_t thread_num = (n_dims + block_size - 1) / block_size * block_size; size_t round_item_num = (batch_item_num + block_size - 1) / block_size * block_size;
lds_data[block_size] = input_ptr[0]; lds_data[block_size] = input_ptr[0];
lds_data[block_size + 1] = 0; for(size_t i = thr_idx; i < round_item_num; i += block_size)
for(size_t i = thr_idx; i < thread_num; i += block_size)
{ {
if(i < n_dims) if(i < batch_item_num)
{ {
data_idx[axis] = i; data_idx[axis] = i;
lds_data[thr_idx] = input_ptr[desc_data.linear(data_idx)]; lds_data[thr_idx] = input_ptr[desc_data.linear(data_idx)];
...@@ -117,40 +114,42 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in ...@@ -117,40 +114,42 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
__syncthreads(); __syncthreads();
auto size = (item_num > block_size) ? block_size : item_num; auto size = (remaining_item_num > block_size) ? block_size : remaining_item_num;
reduce_max<type>(lds_data, block_size, thr_idx, size); reduce_max<type>(lds_data, block_size, thr_idx, size);
__syncthreads(); remaining_item_num -= block_size;
item_num -= block_size;
} }
item_num = n_dims; auto batch_max = lds_data[block_size];
for(size_t i = thr_idx; i < thread_num; i += block_size) __syncthreads();
lds_data[block_size] = 0;
remaining_item_num = batch_item_num;
for(size_t i = thr_idx; i < round_item_num; i += block_size)
{ {
if(i < n_dims) if(i < batch_item_num)
{ {
data_idx[axis] = i; data_idx[axis] = i;
lds_data[thr_idx] = lds_data[thr_idx] =
input_ptr[desc_data.linear(data_idx)] - lds_data[block_size]; input_ptr[desc_data.linear(data_idx)] - batch_max;
lds_data[thr_idx] = ::exp(to_hip_type(lds_data[thr_idx])); lds_data[thr_idx] = ::exp(to_hip_type(lds_data[thr_idx]));
} }
__syncthreads(); __syncthreads();
auto size = (item_num > block_size) ? block_size : item_num; auto size = (remaining_item_num > block_size) ? block_size : remaining_item_num;
reduce_sum<type>(lds_data, block_size, thr_idx, size); reduce_sum<type>(lds_data, block_size, thr_idx, size);
__syncthreads();
item_num -= block_size; remaining_item_num -= block_size;
} }
auto batch_sum = lds_data[block_size];
for(size_t i = thr_idx; i < n_dims; i += block_size) for(size_t i = thr_idx; i < batch_item_num; i += block_size)
{ {
data_idx[axis] = i; data_idx[axis] = i;
size_t index = desc_data.linear(data_idx); size_t index = desc_data.linear(data_idx);
auto val = input_ptr[index] - lds_data[block_size]; auto val = input_ptr[index] - batch_max;
output_ptr[index] = ::exp(to_hip_type(val)) / lds_data[block_size + 1]; output_ptr[index] = ::exp(to_hip_type(val)) / batch_sum;
} }
}); });
}); });
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment