Commit 1da2689b authored by Shucai Xiao
Browse files

clang format

parent f2726514
...@@ -55,25 +55,28 @@ argument logsoftmax(hipStream_t stream, ...@@ -55,25 +55,28 @@ argument logsoftmax(hipStream_t stream,
__syncthreads(); __syncthreads();
auto size = (item_num > block_size) ? block_size : item_num; auto size = (item_num > block_size) ? block_size : item_num;
auto stride = (size + 1) / 2; auto stride = (size + 1) / 2;
while (true) while(true)
{ {
if (thr_idx + stride < size) if(thr_idx + stride < size)
{ {
lds_data[thr_idx] = lds_data[thr_idx] = ::max(to_hip_type(lds_data[thr_idx]),
::max(to_hip_type(lds_data[thr_idx]), to_hip_type(lds_data[thr_idx + stride])); to_hip_type(lds_data[thr_idx + stride]));
} }
__syncthreads(); __syncthreads();
size = stride; size = stride;
stride = (stride + 1) / 2; stride = (stride + 1) / 2;
if (size == 1) break; if(size == 1)
break;
} }
if (thr_idx == 0) if(thr_idx == 0)
{ {
lds_data[block_size] = (lds_data[0] < lds_data[block_size]) ? lds_data[block_size] : lds_data[0]; lds_data[block_size] = (lds_data[0] < lds_data[block_size])
? lds_data[block_size]
: lds_data[0];
} }
__syncthreads(); __syncthreads();
...@@ -87,25 +90,26 @@ argument logsoftmax(hipStream_t stream, ...@@ -87,25 +90,26 @@ argument logsoftmax(hipStream_t stream,
{ {
data_idx[axis] = i; data_idx[axis] = i;
lds_data[i] = input_ptr[desc_data.linear(data_idx)] - lds_data[block_size]; lds_data[i] = input_ptr[desc_data.linear(data_idx)] - lds_data[block_size];
lds_data[i] = ::exp(to_hip_type(lds_data[i])); lds_data[i] = ::exp(to_hip_type(lds_data[i]));
__syncthreads(); __syncthreads();
auto size = (item_num > block_size) ? block_size : item_num; auto size = (item_num > block_size) ? block_size : item_num;
auto stride = (size + 1) / 2; auto stride = (size + 1) / 2;
while (true) while(true)
{ {
if (thr_idx + stride < size) if(thr_idx + stride < size)
{ {
lds_data[thr_idx] += lds_data[thr_idx + stride]; lds_data[thr_idx] += lds_data[thr_idx + stride];
} }
__syncthreads(); __syncthreads();
size = stride; size = stride;
stride = (stride + 1) / 2; stride = (stride + 1) / 2;
if (size == 1) break; if(size == 1)
break;
} }
if (thr_idx == 0) if(thr_idx == 0)
{ {
lds_data[block_size1] += lds_data[0]; lds_data[block_size1] += lds_data[0];
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment