Commit c2b3881f authored by Shucai Xiao

clang format

parent ad583f24
@@ -14,17 +14,19 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

template <class T>
struct val_index
{
    T val;
    int64_t index;
    // MIGRAPHX_DEVICE_CONSTEXPR val_index(T v, int64_t idx) : val(v), index(idx) { }
};

template <class T>
struct argmax_op
{
    MIGRAPHX_DEVICE_CONSTEXPR val_index<T> operator()(val_index<T> x, val_index<T> y) const
    {
        if(x.val > y.val)
            return x;
@@ -36,14 +38,13 @@ struct argmax_op {
        }
    }

    MIGRAPHX_DEVICE_CONSTEXPR T init() const { return lowest(); }
};

template <class T>
struct argmin_op
{
    MIGRAPHX_DEVICE_CONSTEXPR val_index<T> operator()(val_index<T> x, val_index<T> y) const
    {
        if(x.val < y.val)
            return x;
@@ -55,9 +56,7 @@ struct argmin_op {
        }
    }

    MIGRAPHX_DEVICE_CONSTEXPR T init() const { return highest(); }
};
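For orientation only, a minimal host-side sketch (not part of the diff) of the value/index reduction argmax_op implements: each element is paired with its position, the binary operator keeps the pair holding the larger value, and the reduction starts from the lowest representable value with index -1. The names host_val_index, host_argmax_op, and host_argmax are illustrative, and the tie-handling branch elided by the hunk boundary above is not reproduced here.

// Illustrative host-side analogue; assumes float data and standard-library facilities.
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

struct host_val_index // host-side analogue of val_index<T>
{
    float val;
    std::int64_t index;
};

// Analogue of argmax_op<T>::operator(): keep the pair with the larger value.
inline host_val_index host_argmax_op(host_val_index x, host_val_index y)
{
    return (x.val > y.val) ? x : y;
}

// Sequential version of the reduction the kernel performs for one batch.
inline std::int64_t host_argmax(const std::vector<float>& data)
{
    host_val_index best = {std::numeric_limits<float>::lowest(), -1}; // analogue of op.init()
    for(std::size_t j = 0; j < data.size(); ++j)
        best = host_argmax_op(best, {data[j], static_cast<std::int64_t>(j)});
    return best.index;
}

argmin_op is the mirror image: it compares with < and initializes from highest() instead of lowest().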
template <class T, class Op>
@@ -73,28 +72,27 @@ void arg_op(Op op, hipStream_t stream, const argument& result, const argument& a
    hip_visit_all(arg, arg_shape, batch_shape)([&](auto input, auto arg_s, auto batch_s) {
        auto output = device_cast(result.get<int64_t>().data());
        // use one block for items in one batch.
        const size_t max_block_size  = 256;
        const std::size_t block_size = compute_block_size(batch_item_num, max_block_size);
        gs_launch(stream, batch_shape.elements() * block_size, block_size)(
            [=](auto i, auto idx) __device__ {
                auto batch_idx    = batch_s.multi(i / block_size);
                auto data_idx     = batch_idx;
                T init_val        = op.init();
                val_index<T> init = {init_val, -1};

                auto op_output = block_reduce<max_block_size, Op, val_index<T>>(
                    idx, op, init, batch_item_num, [&](auto j) __device__ {
                        data_idx[axis] = j;
                        T val          = input[arg_s.index(data_idx)];
                        return val_index<T>{val, static_cast<int64_t>(j)};
                    });

                if(idx.local == 0)
                {
                    output[batch_s.index(batch_idx)] = op_output.index;
                }
            });
    });
}
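To make the launch geometry above concrete: one block of up to 256 threads is assigned to each batch position, the threads cooperatively reduce the batch_item_num entries along the reduced axis via block_reduce, and thread 0 of each block writes the winning index. A sequential CPU reference, under the simplifying assumption of a packed row-major tensor viewed as (outer, axis_len, inner), could look like the sketch below; the function name and layout handling are illustrative, not MIGraphX API.

// CPU reference for argmax along the middle dimension of a tensor viewed as
// (outer, axis_len, inner), packed row-major. Illustrative only.
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

std::vector<std::int64_t> argmax_reference(const std::vector<float>& data,
                                           std::size_t outer,
                                           std::size_t axis_len,
                                           std::size_t inner)
{
    std::vector<std::int64_t> result(outer * inner, -1);
    for(std::size_t o = 0; o < outer; ++o)
    {
        for(std::size_t in = 0; in < inner; ++in)
        {
            // Same start state as the kernel: lowest value, index -1.
            float best_val        = std::numeric_limits<float>::lowest();
            std::int64_t best_idx = -1;
            // The loop that block_reduce parallelizes across the block's threads.
            for(std::size_t j = 0; j < axis_len; ++j)
            {
                float val = data[(o * axis_len + j) * inner + in];
                if(val > best_val)
                {
                    best_val = val;
                    best_idx = static_cast<std::int64_t>(j);
                }
            }
            result[o * inner + in] = best_idx; // what thread 0 writes per batch element
        }
    }
    return result;
}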