Commit beaccf94 authored by Shucai Xiao

add operators argmax and argmin.

parents bf6fc5f8 bfa455a1
@@ -684,23 +684,46 @@ struct cpu_argmax
std::string name() const { return "cpu::argmax"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
std::vector<size_t> compute_batch_indices(size_t idx, const shape& s) const
{
std::vector<std::size_t> indices(s.lens().size());
std::transform(s.strides().begin(),
s.strides().end(),
s.lens().begin(),
indices.begin(),
[&](std::size_t stride, std::size_t len) {
assert(len > 0 and stride > 0);
return (idx / stride) % len;
});
return indices;
}
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
auto batch_lens = args.front().get_shape().lens();
size_t batch_item_num = batch_lens[op.axis];
batch_lens[op.axis] = 1;
shape batch_shape{shape::int32_type, batch_lens};
result.visit([&](auto output) {
args[0].visit([&](auto input) {
par_for(batch_shape.elements(), [&](auto i) {
auto data_idx = this->compute_batch_indices(i, batch_shape);
// seed with the element at index 0 along the reduced axis
auto max_val = input(data_idx.begin(), data_idx.end());
int64_t max_index = 0;
for (size_t j = 1; j < batch_item_num; ++j)
{
data_idx[op.axis] = j;
if (max_val < input(data_idx.begin(), data_idx.end()))
{
max_val = input(data_idx.begin(), data_idx.end());
max_index = j;
}
}
output[i] = max_index;
});
});
});
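As a point of reference, here is a minimal standalone sketch of the index arithmetic inside compute_batch_indices, assuming row-major (standard) strides; the helper name unflatten and the concrete shape are illustrative only, not part of this commit:

#include <cassert>
#include <cstddef>
#include <vector>

// Recover per-dimension indices from a linear element index with the
// same (idx / stride) % len pattern used by compute_batch_indices.
std::vector<std::size_t> unflatten(std::size_t idx, const std::vector<std::size_t>& lens)
{
    // Row-major strides: stride[k] is the product of lens[k+1..end].
    std::vector<std::size_t> strides(lens.size(), 1);
    for(std::size_t k = lens.size() - 1; k > 0; --k)
        strides[k - 1] = strides[k] * lens[k];

    std::vector<std::size_t> indices(lens.size());
    for(std::size_t k = 0; k < lens.size(); ++k)
        indices[k] = (idx / strides[k]) % lens[k];
    return indices;
}

int main()
{
    // A batch shape like {2, 1, 4}, i.e. the reduced axis already set to 1.
    auto idx = unflatten(6, {2, 1, 4}); // 6 = 1*4 + 0*4 + 2*1
    assert(idx[0] == 1 && idx[1] == 0 && idx[2] == 2);
}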
@@ -722,23 +745,46 @@ struct cpu_argmin
std::string name() const { return "cpu::argmin"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
std::vector<size_t> compute_batch_indices(size_t idx, const shape& s) const
{
std::vector<std::size_t> indices(s.lens().size());
std::transform(s.strides().begin(),
s.strides().end(),
s.lens().begin(),
indices.begin(),
[&](std::size_t stride, std::size_t len) {
assert(len > 0 and stride > 0);
return (idx / stride) % len;
});
return indices;
}
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
auto batch_lens = args.front().get_shape().lens();
size_t batch_item_num = batch_lens[op.axis];
batch_lens[op.axis] = 1;
shape batch_shape{shape::int32_type, batch_lens};
result.visit([&](auto output) {
args[0].visit([&](auto input) {
par_for(batch_shape.elements(), [&](auto i) {
auto data_idx = this->compute_batch_indices(i, batch_shape);
// seed with the element at index 0 along the reduced axis
auto min_val = input(data_idx.begin(), data_idx.end());
int64_t min_index = 0;
for (size_t j = 1; j < batch_item_num; ++j)
{
data_idx[op.axis] = j;
if (min_val > input(data_idx.begin(), data_idx.end()))
{
min_val = input(data_idx.begin(), data_idx.end());
min_index = j;
}
}
output[i] = min_index;
});
});
});
......
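A quick worked example of the semantics both operators implement, as a hypothetical host-side walk-through (plain C++, not MIGraphX code), using a tie-free 2x3 input with axis = 1:

#include <array>
#include <cstddef>
#include <iostream>

int main()
{
    // argmax/argmin along axis 1 (the columns) of a 2x3 input.
    std::array<std::array<float, 3>, 2> in{{{1, 5, 2}, {4, 0, 3}}};
    for(std::size_t r = 0; r < in.size(); ++r)
    {
        std::size_t amax = 0;
        std::size_t amin = 0;
        for(std::size_t c = 1; c < in[r].size(); ++c)
        {
            if(in[r][amax] < in[r][c]) amax = c; // same comparison order as above
            if(in[r][amin] > in[r][c]) amin = c;
        }
        std::cout << "row " << r << ": argmax = " << amax << ", argmin = " << amin << "\n";
    }
    // Prints: row 0: argmax = 1, argmin = 0
    //         row 1: argmax = 0, argmin = 1
}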
@@ -12,6 +12,8 @@ endif()
add_library(migraphx_device
device/add.cpp
device/argmax.cpp
device/argmin.cpp
device/max.cpp
device/min.cpp
device/exp.cpp
@@ -43,6 +45,8 @@ target_include_directories(migraphx_device PUBLIC $<BUILD_INTERFACE:${CMAKE_CURR
target_include_directories(migraphx_device PRIVATE $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/device/include>)
add_library(migraphx_gpu
argmax.cpp
argmin.cpp
eliminate_workspace.cpp
fuse_ops.cpp
hip.cpp
......
#include <migraphx/gpu/argmax.hpp>
#include <migraphx/gpu/device/argmax.hpp>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape hip_argmax::compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(2).standard();
return op.compute_shape({inputs.at(0)});
}
argument hip_argmax::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::argmax(ctx.get_stream().get(), args.back(), args.front(), op.axis);
return args.back();
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
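Note the buffer convention in compute above: args.back() is evidently the preallocated output buffer, so the kernel writes into it and the function returns it directly. This matches output_alias in the hip_argmax header later in this commit, which reports the last input as the output alias.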
#include <migraphx/gpu/argmin.hpp>
#include <migraphx/gpu/device/argmin.hpp>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape hip_argmin::compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(2).standard();
return op.compute_shape({inputs.at(0)});
}
argument hip_argmin::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::argmin(ctx.get_stream().get(), args.back(), args.front(), op.axis);
return args.back();
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
@@ -55,7 +55,7 @@ void argmax(hipStream_t stream, const argument& result, const argument& arg, int
__syncthreads();
auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num;
reduce_argmax(lds_data, lds_index, block_size, thr_idx, item_num, max_block_size);
remaining_item_num -= block_size;
}
@@ -66,8 +66,6 @@ void argmax(hipStream_t stream, const argument& result, const argument& arg, int
}
});
});
}
} // namespace device
......
@@ -55,7 +55,7 @@ void argmin(hipStream_t stream, const argument& result, const argument& arg, int
__syncthreads();
auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num;
reduce_argmin(lds_data, lds_index, block_size, thr_idx, item_num, max_block_size);
remaining_item_num -= block_size;
}
@@ -66,8 +66,6 @@ void argmin(hipStream_t stream, const argument& result, const argument& arg, int
}
});
});
}
} // namespace device
......
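In both device kernels, each reduction pass is clamped to item_num, the number of elements still outstanding (at most block_size), so the final partial block only folds lanes that were actually loaded into LDS.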
@@ -41,7 +41,7 @@ void logsoftmax(hipStream_t stream, argument result, argument arg, int axis)
// load data to lds and compute the batch max
size_t remaining_item_num = batch_item_num;
size_t round_item_num = (batch_item_num + block_size - 1) / block_size * block_size;
lds_data[max_block_size] = input[0];
for(size_t i = thr_idx; i < round_item_num; i += block_size)
{
if(i < batch_item_num)
@@ -52,16 +52,16 @@ void logsoftmax(hipStream_t stream, argument result, argument arg, int axis)
__syncthreads();
auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num;
reduce_max(lds_data, block_size, thr_idx, item_num, max_block_size);
remaining_item_num -= block_size;
}
auto batch_max = lds_data[max_block_size];
__syncthreads();
lds_data[max_block_size] = 0;
remaining_item_num = batch_item_num;
for(size_t i = thr_idx; i < round_item_num; i += block_size)
{
if(i < batch_item_num)
@@ -74,12 +74,12 @@ void logsoftmax(hipStream_t stream, argument result, argument arg, int axis)
__syncthreads();
auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num;
reduce_sum(lds_data, block_size, thr_idx, item_num, max_block_size);
remaining_item_num -= block_size;
}
auto log_batch_sum = ::log(to_hip_type(lds_data[max_block_size])) + batch_max;
for(size_t i = thr_idx; i < batch_item_num; i += block_size)
{
......
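The kernel above computes the usual max-shifted, numerically stable formulation log_softmax(x_i) = x_i - (log(sum_j exp(x_j - m)) + m) with m = max_j x_j; the softmax kernel below applies the same shift. A minimal host-side sketch of that math, independent of the HIP code and assuming a non-empty batch:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

std::vector<float> log_softmax(const std::vector<float>& x)
{
    // m plays the role of batch_max in the kernel.
    float m = *std::max_element(x.begin(), x.end());
    float sum = 0.0f;
    for(float v : x)
        sum += std::exp(v - m); // shifting by m keeps exp() from overflowing
    // Corresponds to log_batch_sum = log(...) + batch_max above.
    float log_sum = std::log(sum) + m;
    std::vector<float> out(x.size());
    for(std::size_t i = 0; i < x.size(); ++i)
        out[i] = x[i] - log_sum;
    return out;
}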
@@ -41,7 +41,7 @@ void softmax(hipStream_t stream, argument result, argument arg, int axis)
// load data to lds and compute the batch max
size_t remaining_item_num = batch_item_num;
size_t round_item_num = (batch_item_num + block_size - 1) / block_size * block_size;
lds_data[max_block_size] = input[0];
for(size_t i = thr_idx; i < round_item_num; i += block_size)
{
if(i < batch_item_num)
@@ -53,16 +53,16 @@ void softmax(hipStream_t stream, argument result, argument arg, int axis)
__syncthreads();
auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num;
reduce_max(lds_data, block_size, thr_idx, item_num, max_block_size);
remaining_item_num -= block_size;
}
auto batch_max = lds_data[max_block_size];
__syncthreads();
lds_data[max_block_size] = 0;
remaining_item_num = batch_item_num;
for(size_t i = thr_idx; i < round_item_num; i += block_size)
{
if(i < batch_item_num)
@@ -75,11 +75,11 @@ void softmax(hipStream_t stream, argument result, argument arg, int axis)
__syncthreads();
auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num;
reduce_sum(lds_data, block_size, thr_idx, item_num, max_block_size);
remaining_item_num -= block_size;
}
auto batch_sum = lds_data[max_block_size];
for(size_t i = thr_idx; i < batch_item_num; i += block_size)
{
......
#ifndef MIGRAPHX_GUARD_RTGLIB_ARGMAX_HPP
#define MIGRAPHX_GUARD_RTGLIB_ARGMAX_HPP
#include <migraphx/shape.hpp>
#include <migraphx/op/argmax.hpp>
#include <migraphx/gpu/device/argmax.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
struct hip_argmax
{
op::argmax op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "gpu::argmax"; }
shape compute_shape(const std::vector<shape>& inputs) const;
argument
compute(context& ctx, const shape&, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_ARGMIN_HPP
#define MIGRAPHX_GUARD_RTGLIB_ARGMIN_HPP
#include <migraphx/shape.hpp>
#include <migraphx/op/argmin.hpp>
#include <migraphx/gpu/device/argmin.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
struct hip_argmin
{
op::argmin op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "gpu::argmin"; }
shape compute_shape(const std::vector<shape>& inputs) const;
argument
compute(context& ctx, const shape&, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
@@ -11,7 +11,8 @@ namespace gpu {
namespace device {
template <class T>
inline __device__ void
reduce_max(T* data_ptr, size_t block_size, size_t thr_idx, size_t item_num, size_t max_index)
{
while(true)
{
@@ -30,8 +31,36 @@ inline __device__ void reduce_max(T* data_ptr, size_t block_size, size_t thr_idx
if(thr_idx == 0)
{
data_ptr[max_index] =
(data_ptr[0] < data_ptr[max_index]) ? data_ptr[max_index] : data_ptr[0];
}
__syncthreads();
}
template <class T>
inline __device__ void
reduce_min(T* data_ptr, size_t block_size, size_t thr_idx, size_t item_num, size_t min_index)
{
while(true)
{
auto stride = (item_num + 1) / 2;
auto size = item_num / 2;
for(size_t i = thr_idx; i < size; i += block_size)
{
data_ptr[i] = ::min(to_hip_type(data_ptr[i]), to_hip_type(data_ptr[i + stride]));
}
__syncthreads();
item_num = stride;
if(item_num == 1)
break;
}
if(thr_idx == 0)
{
data_ptr[min_index] =
(data_ptr[0] > data_ptr[min_index]) ? data_ptr[min_index] : data_ptr[0];
}
__syncthreads();
@@ -78,9 +107,8 @@ inline __device__ void reduce_argmax(T* data_ptr,
template <class T>
inline __device__ void
reduce_argmin(T* data_ptr, int64_t* index_ptr, size_t block_size, size_t thr_idx, size_t item_num, size_t min_index)
{
while(true)
{
auto stride = (item_num + 1) / 2;
@@ -113,7 +141,8 @@ reduce_argmin(T* data_ptr, int64_t* index_ptr, size_t block_size, size_t thr_idx
}
template <class T>
inline __device__ void
reduce_sum(T* data_ptr, size_t block_size, size_t thr_idx, size_t item_num, size_t sum_index)
{
while(true)
{
......@@ -132,7 +161,7 @@ inline __device__ void reduce_sum(T* data_ptr, size_t block_size, size_t thr_idx
if(thr_idx == 0)
{
data_ptr[sum_index] += data_ptr[0];
}
__syncthreads();
......
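To make the halving loop shared by reduce_max, reduce_min, reduce_sum, reduce_argmax, and reduce_argmin easier to follow, here is a serial model of the same pass structure; stride = (item_num + 1) / 2 means an odd leftover element simply survives in place until a later pass. This is an illustrative model under those assumptions, not the device code:

#include <cassert>
#include <cstddef>
#include <vector>

// Serial model of the device-side halving reduction: each pass folds
// element i + stride into element i, then shrinks item_num to stride.
float reduce_sum_model(std::vector<float> data, std::size_t item_num)
{
    while(item_num > 1)
    {
        std::size_t stride = (item_num + 1) / 2;
        std::size_t size   = item_num / 2; // number of pairs that fold this pass
        for(std::size_t i = 0; i < size; ++i)
            data[i] += data[i + stride];
        item_num = stride; // any unpaired element is carried over untouched
    }
    return data[0];
}

int main()
{
    std::vector<float> v{1, 2, 3, 4, 5}; // odd count exercises the carry-over
    assert(reduce_sum_model(v, v.size()) == 15.0f);
}

The device versions additionally fold the per-pass result data_ptr[0] into an accumulator slot past the working region (data_ptr[max_index], data_ptr[min_index], or data_ptr[sum_index]) so the partial result survives across the outer strip-mining loop.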
@@ -11,6 +11,8 @@
#include <migraphx/gpu/device/contiguous.hpp>
#include <migraphx/gpu/device/add.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/gpu/argmax.hpp>
#include <migraphx/gpu/argmin.hpp>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/convolution.hpp>
......