"tests/git@developer.sourcefind.cn:xdb4_94051/vllm.git" did not exist on "ba0bfd40e21cacfd5da6a1e43028a37258a29cb4"
Commit b6786993 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent ccdacf44
...@@ -58,7 +58,8 @@ void logsoftmax(hipStream_t stream, const argument& result, const argument& arg, ...@@ -58,7 +58,8 @@ void logsoftmax(hipStream_t stream, const argument& result, const argument& arg,
} }
__syncthreads(); __syncthreads();
auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num; auto item_num =
(remaining_item_num > block_size) ? block_size : remaining_item_num;
reduce_max(lds_data, block_size, thr_idx, item_num); reduce_max(lds_data, block_size, thr_idx, item_num);
remaining_item_num -= block_size; remaining_item_num -= block_size;
...@@ -74,21 +75,20 @@ void logsoftmax(hipStream_t stream, const argument& result, const argument& arg, ...@@ -74,21 +75,20 @@ void logsoftmax(hipStream_t stream, const argument& result, const argument& arg,
if(i < batch_item_num) if(i < batch_item_num)
{ {
data_idx[axis] = i; data_idx[axis] = i;
lds_data[thr_idx] = lds_data[thr_idx] = input_ptr[desc_data.linear(data_idx)] - batch_max;
input_ptr[desc_data.linear(data_idx)] - batch_max;
lds_data[thr_idx] = ::exp(to_hip_type(lds_data[thr_idx])); lds_data[thr_idx] = ::exp(to_hip_type(lds_data[thr_idx]));
} }
__syncthreads(); __syncthreads();
auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num; auto item_num =
(remaining_item_num > block_size) ? block_size : remaining_item_num;
reduce_sum(lds_data, block_size, thr_idx, item_num); reduce_sum(lds_data, block_size, thr_idx, item_num);
remaining_item_num -= block_size; remaining_item_num -= block_size;
} }
auto log_batch_sum = auto log_batch_sum = ::log(to_hip_type(lds_data[block_size])) + batch_max;
::log(to_hip_type(lds_data[block_size])) + batch_max;
for(size_t i = thr_idx; i < batch_item_num; i += block_size) for(size_t i = thr_idx; i < batch_item_num; i += block_size)
{ {
......
...@@ -59,7 +59,8 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in ...@@ -59,7 +59,8 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
__syncthreads(); __syncthreads();
auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num; auto item_num =
(remaining_item_num > block_size) ? block_size : remaining_item_num;
reduce_max<type>(lds_data, block_size, thr_idx, item_num); reduce_max<type>(lds_data, block_size, thr_idx, item_num);
remaining_item_num -= block_size; remaining_item_num -= block_size;
...@@ -81,7 +82,8 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in ...@@ -81,7 +82,8 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
__syncthreads(); __syncthreads();
auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num; auto item_num =
(remaining_item_num > block_size) ? block_size : remaining_item_num;
reduce_sum<type>(lds_data, block_size, thr_idx, item_num); reduce_sum<type>(lds_data, block_size, thr_idx, item_num);
remaining_item_num -= block_size; remaining_item_num -= block_size;
......
...@@ -70,4 +70,3 @@ __device__ void reduce_sum(T* data_ptr, size_t block_size, size_t thr_idx, size_ ...@@ -70,4 +70,3 @@ __device__ void reduce_sum(T* data_ptr, size_t block_size, size_t thr_idx, size_
} // namespace migraphx } // namespace migraphx
#endif #endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment