Commit 99d1fed4 authored by Shucai Xiao

add CPU implementations of the argmax and argmin operators.

parent 66bae091
#ifndef MIGRAPHX_GUARD_OPERATORS_ARGMAX_HPP
#define MIGRAPHX_GUARD_OPERATORS_ARGMAX_HPP

#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

struct argmax
{
    int axis      = 0;
    int keep_dims = 1;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.axis, "axis"), f(self.keep_dims, "keep_dims"));
    }

    std::string name() const { return "argmax"; }

    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1).standard();
        auto lens = inputs[0].lens();
        int n_dim = static_cast<int>(lens.size());
        // negative axes are not supported; axis must address an existing dimension
        if(axis >= n_dim || axis < 0)
        {
            MIGRAPHX_THROW("ARGMAX: axis is out of range.");
        }
        // the reduced axis collapses to length 1 and is dropped when keep_dims is 0
        lens[axis] = 1;
        if(!keep_dims)
        {
            lens.erase(lens.begin() + axis);
        }
        return {shape::int64_type, lens};
    }
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
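
Note: the shape logic above is easiest to see on a concrete example. The following standalone sketch (plain C++, not MIGraphX code; the helper argmax_lens is invented for illustration) reproduces the lens transformation that compute_shape performs:

#include <cassert>
#include <cstddef>
#include <vector>

// Standalone illustration of compute_shape's lens transformation: the reduced
// axis collapses to length 1 and is erased entirely when keep_dims is 0.
std::vector<std::size_t> argmax_lens(std::vector<std::size_t> lens, int axis, int keep_dims)
{
    lens[axis] = 1;
    if(!keep_dims)
        lens.erase(lens.begin() + axis);
    return lens;
}

int main()
{
    assert((argmax_lens({2, 3, 4}, 1, 1) == std::vector<std::size_t>{2, 1, 4}));
    assert((argmax_lens({2, 3, 4}, 1, 0) == std::vector<std::size_t>{2, 4}));
}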
#ifndef MIGRAPHX_GUARD_OPERATORS_ARGMIN_HPP
#define MIGRAPHX_GUARD_OPERATORS_ARGMIN_HPP

#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

struct argmin
{
    int axis      = 0;
    int keep_dims = 1;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.axis, "axis"), f(self.keep_dims, "keep_dims"));
    }

    std::string name() const { return "argmin"; }

    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1).standard();
        auto lens = inputs[0].lens();
        int n_dim = static_cast<int>(lens.size());
        // negative axes are not supported; axis must address an existing dimension
        if(axis >= n_dim || axis < 0)
        {
            MIGRAPHX_THROW("ARGMIN: axis is out of range.");
        }
        // the reduced axis collapses to length 1 and is dropped when keep_dims is 0
        lens[axis] = 1;
        if(!keep_dims)
        {
            lens.erase(lens.begin() + axis);
        }
        return {shape::int64_type, lens};
    }
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
@@ -5,6 +5,8 @@
#include <migraphx/op/abs.hpp>
#include <migraphx/op/acos.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/argmax.hpp>
#include <migraphx/op/argmin.hpp>
#include <migraphx/op/asin.hpp>
#include <migraphx/op/as_shape.hpp>
#include <migraphx/op/atan.hpp>
...
@@ -637,6 +637,82 @@ struct cpu_logsoftmax
    }
};
struct cpu_argmax
{
    op::argmax op;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::reflect(self.op, f);
    }

    std::string name() const { return "cpu::argmax"; }
    shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }

    argument compute(context&, const shape& output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};
        result.visit([&](auto output) {
            args[0].visit([&](auto input) {
                using value_type = typename decltype(input)::value_type;
                // running maximum for every output location
                std::vector<value_type> batch_max(output_shape.elements(),
                                                  std::numeric_limits<value_type>::lowest());
                auto data_shape = args[0].get_shape();
                // a shape with the reduced axis collapsed to 1 maps every input
                // coordinate to a dense offset into batch_max and output
                auto batch_lens     = data_shape.lens();
                batch_lens[op.axis] = 1;
                shape batch_shape{output_shape.type(), batch_lens};
                shape_for_each(data_shape, [&](auto idx) {
                    auto data_index = data_shape.index(idx);
                    auto axis_index = idx[op.axis];
                    idx[op.axis]    = 0;
                    auto out_index  = batch_shape.index(idx);
                    if(batch_max[out_index] < input[data_index])
                    {
                        batch_max[out_index] = input[data_index];
                        // record the coordinate along the reduced axis
                        output[out_index] = static_cast<int64_t>(axis_index);
                    }
                });
            });
        });
        return result;
    }
};
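
The compute loop above is a single pass over the input: zeroing the coordinate along op.axis maps each element onto the output slot it reduces into, and a running maximum per slot decides when to record that element's axis coordinate. A minimal standalone sketch of the same scan over a row-major 2x3 buffer (plain C++; the names are illustrative, not MIGraphX APIs):

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

int main()
{
    // 2x3 row-major input, reduced along axis 1 (the columns)
    std::vector<float> input = {1.0f, 5.0f, 3.0f,
                                4.0f, 2.0f, 6.0f};
    std::size_t rows = 2;
    std::size_t cols = 3;

    // one running maximum and one recorded index per output slot (per row)
    std::vector<float> running_max(rows, std::numeric_limits<float>::lowest());
    std::vector<std::int64_t> output(rows, 0);

    for(std::size_t r = 0; r < rows; ++r)
    {
        for(std::size_t c = 0; c < cols; ++c)
        {
            // zeroing the reduced coordinate maps (r, c) onto output slot r
            if(running_max[r] < input[r * cols + c])
            {
                running_max[r] = input[r * cols + c];
                output[r]      = static_cast<std::int64_t>(c); // index along the axis
            }
        }
    }

    assert(output[0] == 1); // row 0: max 5.0 at column 1
    assert(output[1] == 2); // row 1: max 6.0 at column 2
}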
struct cpu_argmin
{
    op::argmin op;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::reflect(self.op, f);
    }

    std::string name() const { return "cpu::argmin"; }
    shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }

    argument compute(context&, const shape& output_shape, std::vector<argument> args) const
    {
        argument result{output_shape};
        result.visit([&](auto output) {
            args[0].visit([&](auto input) {
                using value_type = typename decltype(input)::value_type;
                // running minimum for every output location
                std::vector<value_type> batch_min(output_shape.elements(),
                                                  std::numeric_limits<value_type>::max());
                auto data_shape = args[0].get_shape();
                // a shape with the reduced axis collapsed to 1 maps every input
                // coordinate to a dense offset into batch_min and output
                auto batch_lens     = data_shape.lens();
                batch_lens[op.axis] = 1;
                shape batch_shape{output_shape.type(), batch_lens};
                shape_for_each(data_shape, [&](auto idx) {
                    auto data_index = data_shape.index(idx);
                    auto axis_index = idx[op.axis];
                    idx[op.axis]    = 0;
                    auto out_index  = batch_shape.index(idx);
                    if(batch_min[out_index] > input[data_index])
                    {
                        batch_min[out_index] = input[data_index];
                        // record the coordinate along the reduced axis
                        output[out_index] = static_cast<int64_t>(axis_index);
                    }
                });
            });
        });
        return result;
    }
};
struct cpu_apply
{
    program* prog;
...
@@ -656,6 +732,8 @@ struct cpu_apply
    void init()
    {
        apply_map["argmax"] = extend_op<cpu_argmax, op::argmax>();
        apply_map["argmin"] = extend_op<cpu_argmin, op::argmin>();
        apply_map["batch_norm_inference"] =
            extend_op<cpu_batch_norm_inference, op::batch_norm_inference>();
        apply_map["convolution"] = extend_op<cpu_convolution, op::convolution>();
...
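
With these registrations, compiling for the cpu target lowers argmax/argmin instructions to the implementations above. A hedged usage sketch, assuming the MIGraphX C++ API of this era (program::add_parameter, program::add_instruction, program::compile with migraphx::cpu::target, and program::eval with a parameter_map):

#include <migraphx/program.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/cpu/target.hpp>
#include <vector>

int main()
{
    migraphx::program p;
    migraphx::shape s{migraphx::shape::float_type, {2, 3, 4}};
    auto x = p.add_parameter("x", s);

    migraphx::op::argmax am;
    am.axis = 1; // reduce along axis 1; keep_dims defaults to 1, so the output shape is {2, 1, 4}
    p.add_instruction(am, x);

    p.compile(migraphx::cpu::target{});

    std::vector<float> data(s.elements(), 0.0f);
    migraphx::program::parameter_map params;
    params["x"] = migraphx::argument{s, data.data()};
    auto result = p.eval(params); // int64 indices of the per-slice maxima
    (void)result;
}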