Commit ee46bc9f authored by Shucai Xiao
Browse files

clang format

parent 604d5fcd
...@@ -21,13 +21,13 @@ struct val_index ...@@ -21,13 +21,13 @@ struct val_index
int64_t index; int64_t index;
}; };
template<class T> template <class T>
MIGRAPHX_DEVICE_CONSTEXPR val_index<T> make_val_index(T v) MIGRAPHX_DEVICE_CONSTEXPR val_index<T> make_val_index(T v)
{ {
return {v, -1}; return {v, -1};
} }
template<class T> template <class T>
MIGRAPHX_DEVICE_CONSTEXPR val_index<T> make_val_index(T v, int64_t i) MIGRAPHX_DEVICE_CONSTEXPR val_index<T> make_val_index(T v, int64_t i)
{ {
return {v, i}; return {v, i};
...@@ -81,27 +81,28 @@ void arg_op(Op op, hipStream_t stream, const argument& result, const argument& a ...@@ -81,27 +81,28 @@ void arg_op(Op op, hipStream_t stream, const argument& result, const argument& a
hip_visit_all(arg, arg_shape, batch_shape)([&](auto input, auto arg_s, auto batch_s) { hip_visit_all(arg, arg_shape, batch_shape)([&](auto input, auto arg_s, auto batch_s) {
auto output = device_cast(result.get<int64_t>().data()); auto output = device_cast(result.get<int64_t>().data());
using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>; using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
// use one block for items in one batch. // use one block for items in one batch.
const size_t max_block_size = 256; const size_t max_block_size = 256;
const std::size_t block_size = compute_block_size(batch_item_num, max_block_size); const std::size_t block_size = compute_block_size(batch_item_num, max_block_size);
gs_launch(stream, batch_shape.elements() * block_size, block_size)( gs_launch(stream,
[=](auto i, auto idx) __device__ { batch_shape.elements() * block_size,
auto batch_idx = batch_s.multi(i / block_size); block_size)([=](auto i, auto idx) __device__ {
auto data_idx = batch_idx; auto batch_idx = batch_s.multi(i / block_size);
auto init = make_val_index<type>(op.init()); auto data_idx = batch_idx;
auto init = make_val_index<type>(op.init());
auto op_output = block_reduce<max_block_size>( auto op_output =
idx, op, init, batch_item_num, [&](auto j) __device__ { block_reduce<max_block_size>(idx, op, init, batch_item_num, [&](auto j) __device__ {
data_idx[axis] = j; data_idx[axis] = j;
return make_val_index(input[arg_s.index(data_idx)], j); return make_val_index(input[arg_s.index(data_idx)], j);
}); });
if(idx.local == 0) if(idx.local == 0)
{ {
output[batch_s.index(batch_idx)] = op_output.index; output[batch_s.index(batch_idx)] = op_output.index;
} }
}); });
}); });
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment