Commit 4924bb45 authored by Shucai Xiao

Merge branch 'argmax_min' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into argmax_min

parents add40fd1 8ffe3180
@@ -45,9 +45,10 @@ struct argmax
         for(std::size_t i = 1; i < item_num; ++i)
         {
             indices[axis] = i;
-            if(max_val < input(indices.begin(), indices.end()))
+            auto cur_val  = input(indices.begin(), indices.end());
+            if(max_val < cur_val)
             {
-                max_val   = input(indices.begin(), indices.end());
+                max_val   = cur_val;
                 max_index = i;
             }
         }
...
@@ -50,9 +50,10 @@ struct argmin
         for(std::size_t i = 1; i < item_num; ++i)
         {
             indices[axis] = i;
-            if(min_val > input(indices.begin(), indices.end()))
+            auto cur_val  = input(indices.begin(), indices.end());
+            if(min_val > cur_val)
             {
-                min_val   = input(indices.begin(), indices.end());
+                min_val   = cur_val;
                 min_index = i;
             }
         }
...
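Note: both reference implementations previously evaluated input(indices.begin(), indices.end()) twice per improving element; hoisting the lookup into cur_val reads each element once per iteration. A minimal host-side sketch of the same argmax semantics (plain C++ over a std::vector for illustration, not the MIGraphX tensor API; the strict '<' keeps the earliest index on ties):

#include <cstdint>
#include <vector>

// Assumes data is non-empty, as the operator guarantees for its axis.
int64_t argmax_ref(const std::vector<float>& data)
{
    float max_val     = data[0];
    int64_t max_index = 0;
    for(std::size_t i = 1; i < data.size(); ++i)
    {
        auto cur_val = data[i]; // single read, as in the patched loop
        if(max_val < cur_val)   // strict comparison: first maximum wins
        {
            max_val   = cur_val;
            max_index = static_cast<int64_t>(i);
        }
    }
    return max_index;
}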
@@ -16,7 +16,7 @@ void argmax(hipStream_t stream, const argument& result, const argument& arg, int
 {
     arg.visit([&](auto input) {
         using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
-        arg_op<pair_max<type, int64_t>>(pair_max<type, int64_t>{}, stream, result, arg, axis);
+        arg_op<type, argmax_op<type>>(argmax_op<type>{}, stream, result, arg, axis);
     });
 }
...
@@ -16,7 +16,7 @@ void argmin(hipStream_t stream, const argument& result, const argument& arg, int
 {
     arg.visit([&](auto input) {
         using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
-        arg_op<pair_min<type, int64_t>>(pair_min<type, int64_t>{}, stream, result, arg, axis);
+        arg_op<type, argmin_op<type>>(argmin_op<type>{}, stream, result, arg, axis);
     });
 }
...
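Note: the GPU entry points now pass the element type and the reduction functor as separate template arguments instead of a single pair_max/pair_min instantiation, so arg_op can name val_index<T> directly and seed the reduction from op.init(). A generic illustration of that T/Op separation (self-contained C++; fold and max_op are hypothetical stand-ins, not MIGraphX code):

#include <vector>

// Element type T and functor Op arrive independently, mirroring
// arg_op<type, argmax_op<type>>(argmax_op<type>{}, ...).
template <class T, class Op>
T fold(Op op, const std::vector<T>& data, T init)
{
    T acc = init;
    for(const T& v : data)
        acc = op(acc, v);
    return acc;
}

struct max_op
{
    float operator()(float x, float y) const { return (x > y) ? x : y; }
};

int main()
{
    std::vector<float> v = {1.0f, 3.0f, 2.0f};
    float m = fold<float, max_op>(max_op{}, v, v.front());
    return (m == 3.0f) ? 0 : 1;
}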
@@ -6,6 +6,7 @@
 #include <migraphx/gpu/device/tensor.hpp>
 #include <migraphx/gpu/device/launch.hpp>
 #include <migraphx/gpu/device/types.hpp>
+#include <migraphx/gpu/device/reduce.hpp>
 #include <migraphx/gpu/hip.hpp>
 
 namespace migraphx {
@@ -13,71 +14,50 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {
 
-template <class T, class F>
-struct pair_max
+template <class T>
+struct val_index
 {
-    using type = std::pair<T, F>;
-    // This implementation is to ensure when multiple values
-    // are of max, the min index is returned
-    type operator()(type x, type y) const
+    T val;
+    int64_t index;
+};
+
+template <class T>
+struct argmax_op
+{
+    MIGRAPHX_DEVICE_CONSTEXPR val_index<T> operator()(val_index<T> x, val_index<T> y) const
     {
-        if(x.first > y.first)
+        if(x.val > y.val)
             return x;
-        else if(x.first < y.first)
+        else if(x.val < y.val)
             return y;
         else
         {
-            return (x.second < y.second) ? x : y;
+            return (x.index < y.index) ? x : y;
         }
     }
+
+    MIGRAPHX_DEVICE_CONSTEXPR T init() const { return lowest(); }
 };
 
-template <class T, class F>
-struct pair_min
+template <class T>
+struct argmin_op
 {
-    using type = std::pair<T, F>;
-    type operator()(type x, type y) const { return (x < y) ? x : y; }
-};
-
-template <class T, class Op>
-inline __device__ void block_reduce_arg(T* data_ptr,
-                                        int64_t* index_ptr,
-                                        Op op,
-                                        std::size_t block_size,
-                                        std::size_t thr_idx,
-                                        std::size_t item_num,
-                                        std::size_t output_index)
-{
-    while(true)
+    MIGRAPHX_DEVICE_CONSTEXPR val_index<T> operator()(val_index<T> x, val_index<T> y) const
     {
-        auto stride = (item_num + 1) / 2;
-        auto size   = item_num / 2;
-        for(std::size_t i = thr_idx; i < size; i += block_size)
+        if(x.val < y.val)
+            return x;
+        else if(x.val > y.val)
+            return y;
+        else
         {
-            auto output =
-                op({data_ptr[i], index_ptr[i]}, {data_ptr[i + stride], index_ptr[i + stride]});
-            data_ptr[i]  = output.first;
-            index_ptr[i] = output.second;
+            return (x.index < y.index) ? x : y;
         }
-        __syncthreads();
-        item_num = stride;
-        if(item_num == 1)
-            break;
     }
-    if(thr_idx == 0)
-    {
-        auto output =
-            op({data_ptr[output_index], index_ptr[output_index]}, {data_ptr[0], index_ptr[0]});
-        data_ptr[output_index]  = output.first;
-        index_ptr[output_index] = output.second;
-    }
-    __syncthreads();
-}
+
+    MIGRAPHX_DEVICE_CONSTEXPR T init() const { return highest(); }
+};
 
-template <class Op>
+template <class T, class Op>
 void arg_op(Op op, hipStream_t stream, const argument& result, const argument& arg, int axis)
 {
     auto arg_shape = arg.get_shape();
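Note: val_index<T> replaces std::pair<T, int64_t>, which is awkward in device code, and each functor now carries its own identity via init(): lowest() for argmax, highest() for argmin. Because ties resolve to the smaller index, the operator is associative and commutative, which is what allows a parallel tree reduction to combine partial results in any order and still return the first occurrence. A small host-side check of that property (illustrative only; std::numeric_limits stands in for the device-side lowest()):

#include <cassert>
#include <cstdint>
#include <limits>

template <class T>
struct val_index { T val; int64_t index; };

template <class T>
struct argmax_op
{
    val_index<T> operator()(val_index<T> x, val_index<T> y) const
    {
        if(x.val > y.val) return x;
        if(x.val < y.val) return y;
        return (x.index < y.index) ? x : y; // tie: smaller index wins
    }
    T init() const { return std::numeric_limits<T>::lowest(); }
};

int main()
{
    argmax_op<float> op;
    val_index<float> a{1.0f, 0}, b{3.0f, 1}, c{3.0f, 2};
    // Combination order must not change the winner:
    assert(op(op(a, b), c).index == 1);
    assert(op(a, op(b, c)).index == 1);
    assert(op(op(c, b), a).index == 1); // commuted operands, same result
    return 0;
}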
@@ -90,48 +70,25 @@ void arg_op(Op op, hipStream_t stream, const argument& result, const argument& a
     hip_visit_all(arg, arg_shape, batch_shape)([&](auto input, auto arg_s, auto batch_s) {
         auto output = device_cast(result.get<int64_t>().data());
         // use one block for items in one batch.
-        const size_t max_block_size = 1024;
-        size_t block_size           = 1;
-        while(block_size < max_block_size and block_size < batch_item_num)
-        {
-            block_size *= 2;
-        }
-        launch(stream, batch_shape.elements() * block_size, block_size)([=](auto idx) __device__ {
-            size_t thr_idx = idx.local;
-            size_t blk_idx = idx.group;
-            using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
-            auto batch_idx = batch_s.multi(blk_idx);
-            auto data_idx  = batch_idx;
-            MIGRAPHX_DEVICE_SHARED type lds_data[max_block_size + 1];
-            MIGRAPHX_DEVICE_SHARED int64_t lds_index[max_block_size + 1];
-            // load data to lds_data
-            size_t round_item_num = (batch_item_num + block_size - 1) / block_size * block_size;
-            size_t remaining_item_num = batch_item_num;
-            data_idx[axis] = 0;
-            lds_data[max_block_size]  = input[arg_s.index(data_idx)];
-            lds_index[max_block_size] = 0;
-            for(size_t i = thr_idx; i < round_item_num; i += block_size)
-            {
-                if(i < batch_item_num)
-                {
-                    data_idx[axis]     = i;
-                    lds_index[thr_idx] = i;
-                    lds_data[thr_idx]  = input[arg_s.index(data_idx)];
-                }
-                __syncthreads();
-                auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num;
-                block_reduce_arg<type, Op>(
-                    lds_data, lds_index, op, block_size, thr_idx, item_num, max_block_size);
-                remaining_item_num -= block_size;
-            }
-            if(thr_idx == 0)
-            {
-                output[batch_s.index(batch_idx)] = lds_index[max_block_size];
-            }
-        });
+        const size_t max_block_size  = 256;
+        const std::size_t block_size = compute_block_size(batch_item_num, max_block_size);
+        gs_launch(stream, batch_shape.elements() * block_size, block_size)(
+            [=](auto i, auto idx) __device__ {
+                auto batch_idx = batch_s.multi(i / block_size);
+                auto data_idx  = batch_idx;
+                T init_val     = op.init();
+                val_index<T> init = {init_val, -1};
+
+                auto op_output = block_reduce<max_block_size, Op, val_index<T>>(
+                    idx, op, init, batch_item_num, [&](auto j) __device__ {
+                        data_idx[axis] = j;
+                        T val          = input[arg_s.index(data_idx)];
+                        return val_index<T>{val, static_cast<int64_t>(j)};
+                    });
+
+                if(idx.local == 0)
+                {
+                    output[batch_s.index(batch_idx)] = op_output.index;
+                }
+            });
     });
...
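Note: the hand-rolled shared-memory loop (block_reduce_arg with its manual __syncthreads() bookkeeping and the max_block_size + 1 scratch slot) is replaced by the block_reduce helper from reduce.hpp. The structure is unchanged: one block handles one batch slice, each thread folds a strided subset of the reduction axis through the read lambda, and thread 0 of the block writes the winning index. A serial model of that control flow (self-contained C++; the inner loop stands in for the per-thread strides and the device-side tree reduction):

#include <cstdint>
#include <limits>
#include <vector>

template <class T>
struct val_index { T val; int64_t index; };

template <class T>
struct argmax_op
{
    val_index<T> operator()(val_index<T> x, val_index<T> y) const
    {
        if(x.val != y.val) return (x.val > y.val) ? x : y;
        return (x.index < y.index) ? x : y;
    }
    T init() const { return std::numeric_limits<T>::lowest(); }
};

// One "block" per batch row of a row-major [batch][axis_len] buffer.
template <class T, class Op>
std::vector<int64_t>
arg_op_model(Op op, const std::vector<T>& input, int64_t batch, int64_t axis_len)
{
    std::vector<int64_t> output(batch);
    for(int64_t b = 0; b < batch; ++b)
    {
        val_index<T> acc = {op.init(), -1};      // seeded from the functor's identity
        for(int64_t j = 0; j < axis_len; ++j)    // device: threads stride, then tree-reduce
            acc = op(acc, val_index<T>{input[b * axis_len + j], j});
        output[b] = acc.index;                   // device: idx.local == 0 writes
    }
    return output;
}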