Commit f2726514 authored by Shucai Xiao
Browse files

further optimization of the logsoftmax operator.

parent 7959306c
......@@ -38,7 +38,6 @@ argument logsoftmax(hipStream_t stream,
stream, batch_shape.elements() * block_size, block_size)([=](auto idx) __device__ {
size_t thr_idx = idx.local;
size_t blk_idx = idx.group;
// using type = typename decltype(input)::value_type;
using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
// all data can be loaded to the lds once, so all operations are
......@@ -56,18 +55,29 @@ argument logsoftmax(hipStream_t stream,
__syncthreads();
// use thread 0 for batch_max
if(thr_idx == 0)
{
auto size = (item_num > block_size) ? block_size : item_num;
for(size_t j = 0; j < size; j++)
auto stride = (size + 1) / 2;
while (true)
{
if (thr_idx + stride < size)
{
lds_data[block_size] =
::max(to_hip_type(lds_data[block_size]), to_hip_type(lds_data[j]));
lds_data[thr_idx] =
::max(to_hip_type(lds_data[thr_idx]), to_hip_type(lds_data[thr_idx + stride]));
}
item_num -= block_size;
__syncthreads();
size = stride;
stride = (stride + 1) / 2;
if (size == 1) break;
}
if (thr_idx == 0)
{
lds_data[block_size] = (lds_data[0] < lds_data[block_size]) ? lds_data[block_size] : lds_data[0];
}
__syncthreads();
item_num -= block_size;
}
const size_t block_size1 = block_size + 1;
......@@ -76,22 +86,32 @@ argument logsoftmax(hipStream_t stream,
for(size_t i = thr_idx; i < num_in_batch; i += block_size)
{
data_idx[axis] = i;
lds_data[i] = input_ptr[desc_data.linear(data_idx)];
lds_data[i] = input_ptr[desc_data.linear(data_idx)] - lds_data[block_size];
lds_data[i] = ::exp(to_hip_type(lds_data[i]));
__syncthreads();
// use thread 0 for batch_sum
if(thr_idx == 0)
{
auto size = (item_num > block_size) ? block_size : item_num;
for(size_t j = 0; j < size; j++)
auto stride = (size + 1) / 2;
while (true)
{
if (thr_idx + stride < size)
{
lds_data[block_size1] +=
::exp(to_hip_type(lds_data[j] - lds_data[block_size]));
lds_data[thr_idx] += lds_data[thr_idx + stride];
}
item_num -= block_size;
__syncthreads();
size = stride;
stride = (stride + 1) / 2;
if (size == 1) break;
}
if (thr_idx == 0)
{
lds_data[block_size1] += lds_data[0];
}
__syncthreads();
item_num -= block_size;
}
auto log_batch_sum =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment