Commit 604d5fcd authored by Shucai Xiao's avatar Shucai Xiao
Browse files

fix review comments.

parent 6fa72229
...@@ -14,10 +14,7 @@ namespace device { ...@@ -14,10 +14,7 @@ namespace device {
// Device entry point for the argmax operator.
//
// Reduces `arg` along `axis` and writes, for each output position, the
// index of the maximum element into `result` (int64 indices).
// Element-type dispatch is handled inside arg_op itself (it visits the
// input's value type), so this wrapper only forwards to the shared
// arg-reduction driver with the argmax comparison functor.
void argmax(hipStream_t stream, const argument& result, const argument& arg, int axis)
{
    arg_op(argmax_op{}, stream, result, arg, axis);
}
} // namespace device } // namespace device
......
...@@ -14,10 +14,7 @@ namespace device { ...@@ -14,10 +14,7 @@ namespace device {
// Device entry point for the argmin operator.
//
// Reduces `arg` along `axis` and writes, for each output position, the
// index of the minimum element into `result` (int64 indices).
// Mirrors argmax: element-type dispatch happens inside arg_op, so this
// wrapper only selects the argmin comparison functor and forwards.
void argmin(hipStream_t stream, const argument& result, const argument& arg, int axis)
{
    arg_op(argmin_op{}, stream, result, arg, axis);
}
} // namespace device } // namespace device
......
...@@ -21,9 +21,21 @@ struct val_index ...@@ -21,9 +21,21 @@ struct val_index
int64_t index; int64_t index;
}; };
// Helpers that build a val_index<T> (value/position pair) for the
// arg_op block reduction.

// Initial accumulator for the reduction: carries the op's identity
// value `v` together with the sentinel index -1 ("no element seen yet").
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR val_index<T> make_val_index(T v)
{
    return {v, -1};
}

// Pairs an element value `v` with its position `i` along the reduced axis.
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR val_index<T> make_val_index(T v, int64_t i)
{
    return {v, i};
}
struct argmax_op struct argmax_op
{ {
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR val_index<T> operator()(val_index<T> x, val_index<T> y) const MIGRAPHX_DEVICE_CONSTEXPR val_index<T> operator()(val_index<T> x, val_index<T> y) const
{ {
if(x.val > y.val) if(x.val > y.val)
...@@ -36,12 +48,12 @@ struct argmax_op ...@@ -36,12 +48,12 @@ struct argmax_op
} }
} }
MIGRAPHX_DEVICE_CONSTEXPR T init() const { return lowest(); } MIGRAPHX_DEVICE_CONSTEXPR auto init() const { return lowest(); }
}; };
template <class T>
struct argmin_op struct argmin_op
{ {
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR val_index<T> operator()(val_index<T> x, val_index<T> y) const MIGRAPHX_DEVICE_CONSTEXPR val_index<T> operator()(val_index<T> x, val_index<T> y) const
{ {
if(x.val < y.val) if(x.val < y.val)
...@@ -54,10 +66,10 @@ struct argmin_op ...@@ -54,10 +66,10 @@ struct argmin_op
} }
} }
MIGRAPHX_DEVICE_CONSTEXPR T init() const { return highest(); } MIGRAPHX_DEVICE_CONSTEXPR auto init() const { return highest(); }
}; };
template <class T, class Op> template <class Op>
void arg_op(Op op, hipStream_t stream, const argument& result, const argument& arg, int axis) void arg_op(Op op, hipStream_t stream, const argument& result, const argument& arg, int axis)
{ {
auto arg_shape = arg.get_shape(); auto arg_shape = arg.get_shape();
...@@ -69,6 +81,7 @@ void arg_op(Op op, hipStream_t stream, const argument& result, const argument& a ...@@ -69,6 +81,7 @@ void arg_op(Op op, hipStream_t stream, const argument& result, const argument& a
hip_visit_all(arg, arg_shape, batch_shape)([&](auto input, auto arg_s, auto batch_s) { hip_visit_all(arg, arg_shape, batch_shape)([&](auto input, auto arg_s, auto batch_s) {
auto output = device_cast(result.get<int64_t>().data()); auto output = device_cast(result.get<int64_t>().data());
using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
// use one block for items in one batch. // use one block for items in one batch.
const size_t max_block_size = 256; const size_t max_block_size = 256;
const std::size_t block_size = compute_block_size(batch_item_num, max_block_size); const std::size_t block_size = compute_block_size(batch_item_num, max_block_size);
...@@ -76,14 +89,12 @@ void arg_op(Op op, hipStream_t stream, const argument& result, const argument& a ...@@ -76,14 +89,12 @@ void arg_op(Op op, hipStream_t stream, const argument& result, const argument& a
[=](auto i, auto idx) __device__ { [=](auto i, auto idx) __device__ {
auto batch_idx = batch_s.multi(i / block_size); auto batch_idx = batch_s.multi(i / block_size);
auto data_idx = batch_idx; auto data_idx = batch_idx;
T init_val = op.init(); auto init = make_val_index<type>(op.init());
val_index<T> init = {init_val, -1};
auto op_output = block_reduce<max_block_size, Op, val_index<T>>( auto op_output = block_reduce<max_block_size>(
idx, op, init, batch_item_num, [&](auto j) __device__ { idx, op, init, batch_item_num, [&](auto j) __device__ {
data_idx[axis] = j; data_idx[axis] = j;
T val = input[arg_s.index(data_idx)]; return make_val_index(input[arg_s.index(data_idx)], j);
return val_index<T>{val, static_cast<int64_t>(j)};
}); });
if(idx.local == 0) if(idx.local == 0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment