Commit ad583f24 authored by Shucai Xiao

Change argmax/argmin to call the DPP-based block_reduce implementation

parent a64fb36d
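Note: this commit replaces the hand-rolled LDS tree reduction (block_reduce_arg) in arg_op with the shared block_reduce from reduce.hpp, whose per-wavefront step uses AMD's DPP cross-lane instructions. std::pair gives way to a plain val_index struct, and each comparator (argmax_op/argmin_op) now supplies its own identity element through init(). A minimal host-side sketch of the operator semantics (illustrative only: the device code uses MIGRAPHX_DEVICE_CONSTEXPR and the device-side lowest()/highest() helpers, modeled here with std::numeric_limits):

```cpp
#include <cstdint>
#include <iostream>
#include <limits>
#include <vector>

// Host-side model of the device types introduced by this commit.
template <class T>
struct val_index
{
    T val;
    int64_t index;
};

template <class T>
struct argmax_op
{
    val_index<T> operator()(val_index<T> x, val_index<T> y) const
    {
        if(x.val > y.val)
            return x;
        if(x.val < y.val)
            return y;
        // Tie on the value: keep the smaller index for a deterministic result
        return (x.index < y.index) ? x : y;
    }
    // Identity element: any real element wins against the lowest value
    T init() const { return std::numeric_limits<T>::lowest(); }
};

int main()
{
    std::vector<float> data = {1.0f, 7.0f, 3.0f, 7.0f};
    argmax_op<float> op;
    val_index<float> acc = {op.init(), -1};
    for(int64_t j = 0; j < static_cast<int64_t>(data.size()); j++)
        acc = op(acc, {data[j], j});
    std::cout << acc.index << "\n"; // prints 1: first of the tied maxima
}
```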
@@ -16,7 +16,7 @@ void argmax(hipStream_t stream, const argument& result, const argument& arg, int
 {
     arg.visit([&](auto input) {
         using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
-        arg_op<pair_max<type, int64_t>>(pair_max<type, int64_t>{}, stream, result, arg, axis);
+        arg_op<type, argmax_op<type>>(argmax_op<type>{}, stream, result, arg, axis);
     });
 }
@@ -16,7 +16,7 @@ void argmin(hipStream_t stream, const argument& result, const argument& arg, int
 {
     arg.visit([&](auto input) {
         using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
-        arg_op<pair_min<type, int64_t>>(pair_min<type, int64_t>{}, stream, result, arg, axis);
+        arg_op<type, argmin_op<type>>(argmin_op<type>{}, stream, result, arg, axis);
     });
 }
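An aside on the argmin call site: the old pair_min relied on std::pair's lexicographic operator<, which already broke ties toward the smaller index, while pair_max spelled the rule out by hand; the new argmax_op/argmin_op pair makes the tie-break explicit and symmetric in both directions. A quick host-side check of the old lexicographic behavior (illustrative only):

```cpp
#include <cstdint>
#include <iostream>
#include <utility>

int main()
{
    using type = std::pair<float, int64_t>; // (value, index), as in the old pair_min
    auto pair_min = [](type x, type y) { return (x < y) ? x : y; };
    // Equal values: pair's operator< falls through to the index,
    // so the smaller index wins -- the rule argmin_op now states explicitly.
    type a{2.0f, 5}, b{2.0f, 1};
    std::cout << pair_min(a, b).second << "\n"; // prints 1
}
```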
@@ -182,7 +182,7 @@ __device__ auto block_reduce(index idx, Op op, T init, std::size_t n, F f)
     }
     __syncthreads();
-    type y = 0;
+    type y = init;
     for(std::size_t i = 0; i < idx.nlocal() / 64; i++)
     {
         y = op(y, buffer[i]);
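The one-line fix above matters because `type y = 0` hard-codes zero as the identity of the cross-wavefront accumulation: that is wrong for reductions such as min over positive values, and it does not even compile once `type` is a struct like `val_index`. Seeding with the caller-supplied `init` makes block_reduce correct for any operator. A scalar illustration of the bug:

```cpp
#include <algorithm>
#include <iostream>
#include <limits>
#include <vector>

int main()
{
    std::vector<int> partials = {5, 3, 9}; // per-wavefront partial minima
    auto op = [](int a, int b) { return std::min(a, b); };

    int wrong = 0;                               // old seed: 0 beats every positive partial
    int right = std::numeric_limits<int>::max(); // proper identity for min
    for(int p : partials)
    {
        wrong = op(wrong, p);
        right = op(right, p);
    }
    std::cout << wrong << " vs " << right << "\n"; // prints "0 vs 3"
}
```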
@@ -6,6 +6,7 @@
 #include <migraphx/gpu/device/tensor.hpp>
 #include <migraphx/gpu/device/launch.hpp>
 #include <migraphx/gpu/device/types.hpp>
+#include <migraphx/gpu/device/reduce.hpp>
 #include <migraphx/gpu/hip.hpp>
 
 namespace migraphx {
@@ -13,71 +14,53 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {
 
-template <class T, class F>
-struct pair_max
-{
-    using type = std::pair<T, F>;
-    // This implementation is to ensure when multiple values
-    // are of max, the min index is returned
-    type operator()(type x, type y) const
+template <class T>
+struct val_index
+{
+    T val;
+    int64_t index;
+    // MIGRAPHX_DEVICE_CONSTEXPR val_index(T v, int64_t idx) : val(v), index(idx) {}
+};
+
+template <class T>
+struct argmax_op
+{
+    // When several values equal the max, the smallest index is returned
+    MIGRAPHX_DEVICE_CONSTEXPR val_index<T> operator()(val_index<T> x, val_index<T> y) const
     {
-        if(x.first > y.first)
+        if(x.val > y.val)
             return x;
-        else if(x.first < y.first)
+        else if(x.val < y.val)
             return y;
         else
         {
-            return (x.second < y.second) ? x : y;
+            return (x.index < y.index) ? x : y;
         }
     }
-};
 
-template <class T, class F>
-struct pair_min
-{
-    using type = std::pair<T, F>;
-    type operator()(type x, type y) const { return (x < y) ? x : y; }
+    // Identity element for max: nothing compares less than lowest()
+    MIGRAPHX_DEVICE_CONSTEXPR T init() const { return lowest(); }
 };
 
-template <class T, class Op>
-inline __device__ void block_reduce_arg(T* data_ptr,
-                                        int64_t* index_ptr,
-                                        Op op,
-                                        std::size_t block_size,
-                                        std::size_t thr_idx,
-                                        std::size_t item_num,
-                                        std::size_t output_index)
-{
-    while(true)
+template <class T>
+struct argmin_op
+{
+    // When several values equal the min, the smallest index is returned
+    MIGRAPHX_DEVICE_CONSTEXPR val_index<T> operator()(val_index<T> x, val_index<T> y) const
     {
-        auto stride = (item_num + 1) / 2;
-        auto size   = item_num / 2;
-        for(std::size_t i = thr_idx; i < size; i += block_size)
+        if(x.val < y.val)
+            return x;
+        else if(x.val > y.val)
+            return y;
+        else
         {
-            auto output =
-                op({data_ptr[i], index_ptr[i]}, {data_ptr[i + stride], index_ptr[i + stride]});
-            data_ptr[i]  = output.first;
-            index_ptr[i] = output.second;
+            return (x.index < y.index) ? x : y;
         }
-        __syncthreads();
-        item_num = stride;
-        if(item_num == 1)
-            break;
     }
-    if(thr_idx == 0)
-    {
-        auto output =
-            op({data_ptr[output_index], index_ptr[output_index]}, {data_ptr[0], index_ptr[0]});
-        data_ptr[output_index] = output.first;
-        index_ptr[output_index] = output.second;
-    }
-    __syncthreads();
-}
+
+    // Identity element for min: nothing compares greater than highest()
+    MIGRAPHX_DEVICE_CONSTEXPR T init() const { return highest(); }
+};
 
-template <class Op>
+template <class T, class Op>
 void arg_op(Op op, hipStream_t stream, const argument& result, const argument& arg, int axis)
 {
     auto arg_shape = arg.get_shape();
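A small design note on the hunk above: val_index is a plain aggregate, and its constructor is left commented out, presumably so the type stays trivially copyable and brace-initializable on device, where the reducer moves it between lanes and through LDS bit-for-bit. A compile-time check one could add (hypothetical, not part of the commit):

```cpp
#include <cstdint>
#include <type_traits>

template <class T>
struct val_index
{
    T val;
    int64_t index;
};

// A user-declared constructor would make val_index non-trivial and
// non-aggregate, breaking {val, index} init and bitwise lane transfers.
static_assert(std::is_trivially_copyable<val_index<float>>::value,
              "val_index must be trivially copyable");

int main() { return 0; }
```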
@@ -90,48 +73,26 @@ void arg_op(Op op, hipStream_t stream, const argument& result, const argument& a
     hip_visit_all(arg, arg_shape, batch_shape)([&](auto input, auto arg_s, auto batch_s) {
         auto output = device_cast(result.get<int64_t>().data());
         // use one block for items in one batch.
-        const size_t max_block_size = 1024;
-        size_t block_size           = 1;
-        while(block_size < max_block_size and block_size < batch_item_num)
-        {
-            block_size *= 2;
-        }
-        launch(stream, batch_shape.elements() * block_size, block_size)([=](auto idx) __device__ {
-            size_t thr_idx = idx.local;
-            size_t blk_idx = idx.group;
-            using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
-            auto batch_idx = batch_s.multi(blk_idx);
-            auto data_idx  = batch_idx;
-            MIGRAPHX_DEVICE_SHARED type lds_data[max_block_size + 1];
-            MIGRAPHX_DEVICE_SHARED int64_t lds_index[max_block_size + 1];
-            // load data to lds_data
-            size_t round_item_num = (batch_item_num + block_size - 1) / block_size * block_size;
-            size_t remaining_item_num = batch_item_num;
-            data_idx[axis]            = 0;
-            lds_data[max_block_size]  = input[arg_s.index(data_idx)];
-            lds_index[max_block_size] = 0;
-            for(size_t i = thr_idx; i < round_item_num; i += block_size)
-            {
-                if(i < batch_item_num)
-                {
-                    data_idx[axis]     = i;
-                    lds_index[thr_idx] = i;
-                    lds_data[thr_idx]  = input[arg_s.index(data_idx)];
-                }
-                __syncthreads();
-                auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num;
-                block_reduce_arg<type, Op>(
-                    lds_data, lds_index, op, block_size, thr_idx, item_num, max_block_size);
-                remaining_item_num -= block_size;
-            }
-            if(thr_idx == 0)
+        const size_t max_block_size  = 256;
+        const std::size_t block_size = compute_block_size(batch_item_num, max_block_size);
+        gs_launch(stream,
+                  batch_shape.elements() * block_size,
+                  block_size)([=](auto i, auto idx) __device__ {
+            auto batch_idx = batch_s.multi(i / block_size);
+            auto data_idx  = batch_idx;
+            T init_val     = op.init();
+            val_index<T> init = {init_val, -1};
+            auto op_output = block_reduce<max_block_size, Op, val_index<T>>(
+                idx, op, init, batch_item_num, [&](auto j) __device__ {
+                    data_idx[axis] = j;
+                    T val          = input[arg_s.index(data_idx)];
+                    return val_index<T>{val, static_cast<int64_t>(j)};
+                });
+            if(idx.local == 0)
             {
-                output[batch_s.index(batch_idx)] = lds_index[max_block_size];
+                output[batch_s.index(batch_idx)] = op_output.index;
             }
         });
     });
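To summarize the new kernel shape: gs_launch issues batch_shape.elements() * block_size work-items, so i / block_size recovers the batch slice a thread belongs to, and block_reduce pulls that slice's elements through the generator lambda, which walks index j along the reduced axis. A serial host-side model of the same indexing, assuming a contiguous row-major [batch][item] layout for simplicity:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Serial model of the new arg_op kernel: one "block" per batch slice,
// reducing batch_item_num elements along the axis. The real code maps
// indices through tensor shapes instead of this flat layout.
int main()
{
    std::size_t batches = 2, batch_item_num = 4;
    std::vector<float> input = {3, 9, 9, 1,   // batch 0 -> argmax 1
                                4, 2, 8, 8};  // batch 1 -> argmax 2
    std::vector<int64_t> output(batches);
    for(std::size_t b = 0; b < batches; b++)
    {
        float best     = input[b * batch_item_num];
        int64_t best_j = 0;
        for(std::size_t j = 1; j < batch_item_num; j++)
        {
            float v = input[b * batch_item_num + j];
            if(v > best) // strict compare: ties keep the earlier index
            {
                best   = v;
                best_j = static_cast<int64_t>(j);
            }
        }
        output[b] = best_j;
    }
    std::cout << output[0] << " " << output[1] << "\n"; // prints "1 2"
}
```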