Commit 605cce41 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

more optimization of reduce operation.

parent 93eae2df
...@@ -41,7 +41,7 @@ void logsoftmax(hipStream_t stream, argument result, argument arg, int axis) ...@@ -41,7 +41,7 @@ void logsoftmax(hipStream_t stream, argument result, argument arg, int axis)
// load data to lds and compute the batch max // load data to lds and compute the batch max
size_t remaining_item_num = batch_item_num; size_t remaining_item_num = batch_item_num;
size_t round_item_num = (batch_item_num + block_size - 1) / block_size * block_size; size_t round_item_num = (batch_item_num + block_size - 1) / block_size * block_size;
lds_data[block_size] = input[0]; lds_data[max_block_size] = input[0];
for(size_t i = thr_idx; i < round_item_num; i += block_size) for(size_t i = thr_idx; i < round_item_num; i += block_size)
{ {
if(i < batch_item_num) if(i < batch_item_num)
...@@ -52,15 +52,15 @@ void logsoftmax(hipStream_t stream, argument result, argument arg, int axis) ...@@ -52,15 +52,15 @@ void logsoftmax(hipStream_t stream, argument result, argument arg, int axis)
__syncthreads(); __syncthreads();
auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num; auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num;
reduce_max(lds_data, block_size, thr_idx, item_num); reduce_max(lds_data, block_size, thr_idx, item_num, max_block_size);
remaining_item_num -= block_size; remaining_item_num -= block_size;
} }
auto batch_max = lds_data[block_size]; auto batch_max = lds_data[max_block_size];
__syncthreads(); __syncthreads();
lds_data[block_size] = 0; lds_data[max_block_size] = 0;
remaining_item_num = batch_item_num; remaining_item_num = batch_item_num;
for(size_t i = thr_idx; i < round_item_num; i += block_size) for(size_t i = thr_idx; i < round_item_num; i += block_size)
{ {
...@@ -74,12 +74,12 @@ void logsoftmax(hipStream_t stream, argument result, argument arg, int axis) ...@@ -74,12 +74,12 @@ void logsoftmax(hipStream_t stream, argument result, argument arg, int axis)
__syncthreads(); __syncthreads();
auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num; auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num;
reduce_sum(lds_data, block_size, thr_idx, item_num); reduce_sum(lds_data, block_size, thr_idx, item_num, max_block_size);
remaining_item_num -= block_size; remaining_item_num -= block_size;
} }
auto log_batch_sum = ::log(to_hip_type(lds_data[block_size])) + batch_max; auto log_batch_sum = ::log(to_hip_type(lds_data[max_block_size])) + batch_max;
for(size_t i = thr_idx; i < batch_item_num; i += block_size) for(size_t i = thr_idx; i < batch_item_num; i += block_size)
{ {
......
...@@ -41,7 +41,7 @@ void softmax(hipStream_t stream, argument result, argument arg, int axis) ...@@ -41,7 +41,7 @@ void softmax(hipStream_t stream, argument result, argument arg, int axis)
// load data to lds and compute the batch max // load data to lds and compute the batch max
size_t remaining_item_num = batch_item_num; size_t remaining_item_num = batch_item_num;
size_t round_item_num = (batch_item_num + block_size - 1) / block_size * block_size; size_t round_item_num = (batch_item_num + block_size - 1) / block_size * block_size;
lds_data[block_size] = input[0]; lds_data[max_block_size] = input[0];
for(size_t i = thr_idx; i < round_item_num; i += block_size) for(size_t i = thr_idx; i < round_item_num; i += block_size)
{ {
if(i < batch_item_num) if(i < batch_item_num)
...@@ -53,15 +53,15 @@ void softmax(hipStream_t stream, argument result, argument arg, int axis) ...@@ -53,15 +53,15 @@ void softmax(hipStream_t stream, argument result, argument arg, int axis)
__syncthreads(); __syncthreads();
auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num; auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num;
reduce_max(lds_data, block_size, thr_idx, item_num); reduce_max(lds_data, block_size, thr_idx, item_num, max_block_size);
remaining_item_num -= block_size; remaining_item_num -= block_size;
} }
auto batch_max = lds_data[block_size]; auto batch_max = lds_data[max_block_size];
__syncthreads(); __syncthreads();
lds_data[block_size] = 0; lds_data[max_block_size] = 0;
remaining_item_num = batch_item_num; remaining_item_num = batch_item_num;
for(size_t i = thr_idx; i < round_item_num; i += block_size) for(size_t i = thr_idx; i < round_item_num; i += block_size)
{ {
...@@ -75,11 +75,11 @@ void softmax(hipStream_t stream, argument result, argument arg, int axis) ...@@ -75,11 +75,11 @@ void softmax(hipStream_t stream, argument result, argument arg, int axis)
__syncthreads(); __syncthreads();
auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num; auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num;
reduce_sum(lds_data, block_size, thr_idx, item_num); reduce_sum(lds_data, block_size, thr_idx, item_num, max_block_size);
remaining_item_num -= block_size; remaining_item_num -= block_size;
} }
auto batch_sum = lds_data[block_size]; auto batch_sum = lds_data[max_block_size];
for(size_t i = thr_idx; i < batch_item_num; i += block_size) for(size_t i = thr_idx; i < batch_item_num; i += block_size)
{ {
......
...@@ -11,7 +11,7 @@ namespace gpu { ...@@ -11,7 +11,7 @@ namespace gpu {
namespace device { namespace device {
template <class T> template <class T>
inline __device__ void reduce_max(T* data_ptr, size_t block_size, size_t thr_idx, size_t item_num) inline __device__ void reduce_max(T* data_ptr, size_t block_size, size_t thr_idx, size_t item_num, size_t max_index)
{ {
while(true) while(true)
{ {
...@@ -30,15 +30,42 @@ inline __device__ void reduce_max(T* data_ptr, size_t block_size, size_t thr_idx ...@@ -30,15 +30,42 @@ inline __device__ void reduce_max(T* data_ptr, size_t block_size, size_t thr_idx
if(thr_idx == 0) if(thr_idx == 0)
{ {
data_ptr[block_size] = data_ptr[max_index] =
(data_ptr[0] < data_ptr[block_size]) ? data_ptr[block_size] : data_ptr[0]; (data_ptr[0] < data_ptr[max_index]) ? data_ptr[max_index] : data_ptr[0];
} }
__syncthreads(); __syncthreads();
} }
template <class T> template <class T>
inline __device__ void reduce_sum(T* data_ptr, size_t block_size, size_t thr_idx, size_t item_num) inline __device__ void reduce_min(T* data_ptr, size_t block_size, size_t thr_idx, size_t item_num, size_t min_index)
{
while(true)
{
auto stride = (item_num + 1) / 2;
auto size = item_num / 2;
for(size_t i = thr_idx; i < size; i += block_size)
{
data_ptr[i] = ::min(to_hip_type(data_ptr[i]), to_hip_type(data_ptr[i + stride]));
}
__syncthreads();
item_num = stride;
if(item_num == 1)
break;
}
if(thr_idx == 0)
{
data_ptr[min_index] =
(data_ptr[0] > data_ptr[min_index]) ? data_ptr[min_index] : data_ptr[0];
}
__syncthreads();
}
template <class T>
inline __device__ void reduce_sum(T* data_ptr, size_t block_size, size_t thr_idx, size_t item_num, size_t sum_index)
{ {
while(true) while(true)
{ {
...@@ -57,7 +84,7 @@ inline __device__ void reduce_sum(T* data_ptr, size_t block_size, size_t thr_idx ...@@ -57,7 +84,7 @@ inline __device__ void reduce_sum(T* data_ptr, size_t block_size, size_t thr_idx
if(thr_idx == 0) if(thr_idx == 0)
{ {
data_ptr[block_size] += data_ptr[0]; data_ptr[sum_index] += data_ptr[0];
} }
__syncthreads(); __syncthreads();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment