Commit 41c6d737 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 2e7398f3
...@@ -58,8 +58,13 @@ void argmax(hipStream_t stream, const argument& result, const argument& arg, int ...@@ -58,8 +58,13 @@ void argmax(hipStream_t stream, const argument& result, const argument& arg, int
__syncthreads(); __syncthreads();
auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num; auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num;
block_reduce_pair<type, pair_max_op<type, int64_t>>(lds_data, lds_index, pair_max_op<type, int64_t>{}, block_reduce_pair<type, pair_max_op<type, int64_t>>(lds_data,
block_size, thr_idx, item_num, max_block_size); lds_index,
pair_max_op<type, int64_t>{},
block_size,
thr_idx,
item_num,
max_block_size);
remaining_item_num -= block_size; remaining_item_num -= block_size;
} }
......
...@@ -58,8 +58,13 @@ void argmin(hipStream_t stream, const argument& result, const argument& arg, int ...@@ -58,8 +58,13 @@ void argmin(hipStream_t stream, const argument& result, const argument& arg, int
__syncthreads(); __syncthreads();
auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num; auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num;
block_reduce_pair<type, pair_min_op<type, int64_t>>(lds_data, lds_index, pair_min_op<type, int64_t>{}, block_reduce_pair<type, pair_min_op<type, int64_t>>(lds_data,
block_size, thr_idx, item_num, max_block_size); lds_index,
pair_min_op<type, int64_t>{},
block_size,
thr_idx,
item_num,
max_block_size);
remaining_item_num -= block_size; remaining_item_num -= block_size;
} }
......
...@@ -66,10 +66,11 @@ struct pair_max_op ...@@ -66,10 +66,11 @@ struct pair_max_op
using type = std::pair<T, F>; using type = std::pair<T, F>;
// This implementation is to ensure when multiple values // This implementation is to ensure when multiple values
// are of max, the min index is returned // are of max, the min index is returned
type operator()(type x, type y) const { type operator()(type x, type y) const
if (x.first > y.first) {
if(x.first > y.first)
return x; return x;
else if (x.first < y.first) else if(x.first < y.first)
return y; return y;
else else
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment