Commit d13dcab5 authored by Shucai Xiao

clang format

parent ce649bb5
@@ -15,7 +15,7 @@ namespace device {
void argmax(hipStream_t stream, const argument& result, const argument& arg, int axis)
{
    arg.visit([&](auto input) {
        using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
        arg_op<pair_max<type, int64_t>>(pair_max<type, int64_t>{}, stream, result, arg, axis);
    });
}
...
@@ -15,7 +15,7 @@ namespace device {
void argmin(hipStream_t stream, const argument& result, const argument& arg, int axis)
{
    arg.visit([&](auto input) {
        using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
        arg_op<pair_min<type, int64_t>>(pair_min<type, int64_t>{}, stream, result, arg, axis);
    });
}
...
@@ -41,12 +41,12 @@ struct pair_min
template <class T, class Op>
inline __device__ void block_reduce_arg(T* data_ptr,
                                        int64_t* index_ptr,
                                        Op op,
                                        std::size_t block_size,
                                        std::size_t thr_idx,
                                        std::size_t item_num,
                                        std::size_t output_index)
{
    while(true)
    {
@@ -77,8 +77,7 @@ inline __device__ void block_reduce_arg(T* data_ptr,
    __syncthreads();
}
-template<class Op>
+template <class Op>
void arg_op(Op op, hipStream_t stream, const argument& result, const argument& arg, int axis)
{
    auto arg_shape = arg.get_shape();
@@ -124,13 +123,8 @@ void arg_op(Op op, hipStream_t stream, const argument& result, const argument& a
            __syncthreads();
            auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num;
-            block_reduce_arg<type, Op>(lds_data,
-                                       lds_index,
-                                       op,
-                                       block_size,
-                                       thr_idx,
-                                       item_num,
-                                       max_block_size);
+            block_reduce_arg<type, Op>(
+                lds_data, lds_index, op, block_size, thr_idx, item_num, max_block_size);
            remaining_item_num -= block_size;
        }
@@ -149,4 +143,3 @@ void arg_op(Op op, hipStream_t stream, const argument& result, const argument& a
} // namespace migraphx
#endif
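For context on the code being reformatted: the snippet below is a minimal host-side sketch of the (value, index) pair reduction that pair_max<type, int64_t> and block_reduce_arg perform on the GPU. It is not code from this commit; the names pair_max_sketch and argmax_reference, and the tie-break toward the lower index, are assumptions made for illustration.

// Hypothetical sketch (not part of this commit): a (value, index) pair-max
// functor in the spirit of pair_max<T, int64_t>, plus a sequential reference
// of the argmax reduction that block_reduce_arg computes per block in parallel.
#include <cstdint>
#include <utility>
#include <vector>

template <class T, class Index>
struct pair_max_sketch
{
    // Keep the element with the larger value; break ties toward the lower index.
    std::pair<T, Index> operator()(std::pair<T, Index> a, std::pair<T, Index> b) const
    {
        if(a.first > b.first)
            return a;
        if(b.first > a.first)
            return b;
        return a.second < b.second ? a : b;
    }
};

// Sequential reference of the reduction; assumes data is non-empty.
inline int64_t argmax_reference(const std::vector<float>& data)
{
    pair_max_sketch<float, int64_t> op{};
    std::pair<float, int64_t> best{data[0], 0};
    for(int64_t i = 1; i < static_cast<int64_t>(data.size()); ++i)
        best = op(best, {data[i], i});
    return best.second;
}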