"src/vscode:/vscode.git/clone" did not exist on "c075d3f7d91079d28340cda89d51e15117493968"
Commit f2726514 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

further optimization of the logsoftmax operator.

parent 7959306c
...@@ -38,7 +38,6 @@ argument logsoftmax(hipStream_t stream, ...@@ -38,7 +38,6 @@ argument logsoftmax(hipStream_t stream,
stream, batch_shape.elements() * block_size, block_size)([=](auto idx) __device__ { stream, batch_shape.elements() * block_size, block_size)([=](auto idx) __device__ {
size_t thr_idx = idx.local; size_t thr_idx = idx.local;
size_t blk_idx = idx.group; size_t blk_idx = idx.group;
// using type = typename decltype(input)::value_type;
using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>; using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
// all data can be loaded to the lds once, so all operations are // all data can be loaded to the lds once, so all operations are
...@@ -56,18 +55,29 @@ argument logsoftmax(hipStream_t stream, ...@@ -56,18 +55,29 @@ argument logsoftmax(hipStream_t stream,
__syncthreads(); __syncthreads();
// use thread 0 for batch_max
if(thr_idx == 0)
{
auto size = (item_num > block_size) ? block_size : item_num; auto size = (item_num > block_size) ? block_size : item_num;
for(size_t j = 0; j < size; j++) auto stride = (size + 1) / 2;
while (true)
{
if (thr_idx + stride < size)
{ {
lds_data[block_size] = lds_data[thr_idx] =
::max(to_hip_type(lds_data[block_size]), to_hip_type(lds_data[j])); ::max(to_hip_type(lds_data[thr_idx]), to_hip_type(lds_data[thr_idx + stride]));
} }
item_num -= block_size; __syncthreads();
size = stride;
stride = (stride + 1) / 2;
if (size == 1) break;
}
if (thr_idx == 0)
{
lds_data[block_size] = (lds_data[0] < lds_data[block_size]) ? lds_data[block_size] : lds_data[0];
} }
__syncthreads(); __syncthreads();
item_num -= block_size;
} }
const size_t block_size1 = block_size + 1; const size_t block_size1 = block_size + 1;
...@@ -76,22 +86,32 @@ argument logsoftmax(hipStream_t stream, ...@@ -76,22 +86,32 @@ argument logsoftmax(hipStream_t stream,
for(size_t i = thr_idx; i < num_in_batch; i += block_size) for(size_t i = thr_idx; i < num_in_batch; i += block_size)
{ {
data_idx[axis] = i; data_idx[axis] = i;
lds_data[i] = input_ptr[desc_data.linear(data_idx)]; lds_data[i] = input_ptr[desc_data.linear(data_idx)] - lds_data[block_size];
lds_data[i] = ::exp(to_hip_type(lds_data[i]));
__syncthreads(); __syncthreads();
// use thread 0 for batch_max
if(thr_idx == 0)
{
auto size = (item_num > block_size) ? block_size : item_num; auto size = (item_num > block_size) ? block_size : item_num;
for(size_t j = 0; j < size; j++) auto stride = (size + 1) / 2;
while (true)
{
if (thr_idx + stride < size)
{ {
lds_data[block_size1] += lds_data[thr_idx] += lds_data[thr_idx + stride];
::exp(to_hip_type(lds_data[j] - lds_data[block_size]));
} }
item_num -= block_size; __syncthreads();
size = stride;
stride = (stride + 1) / 2;
if (size == 1) break;
}
if (thr_idx == 0)
{
lds_data[block_size1] += lds_data[0];
} }
__syncthreads(); __syncthreads();
item_num -= block_size;
} }
auto log_batch_sum = auto log_batch_sum =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment