Commit 9ee5b15f authored by Shucai Xiao's avatar Shucai Xiao
Browse files

code refactoring for argmax and argmin operators

parent cffb1b1b
#ifndef MIGRAPHX_GUARD_OPERATORS_ARGMAX_HPP
#define MIGRAPHX_GUARD_OPERATORS_ARGMAX_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/par_dfor.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -42,6 +36,42 @@ struct argmax
return {shape::int64_type, lens};
}
template <class T>
int64_t
calc_argmax(T& input, std::vector<std::size_t>& indices, size_t item_num) const
{
auto max_val = input(indices.begin(), indices.end());
int64_t max_index = 0;
for(std::size_t i = 1; i < item_num; ++i)
{
indices[axis] = i;
if(max_val < input(indices.begin(), indices.end()))
{
max_val = input(indices.begin(), indices.end());
max_index = i;
}
}
return max_index;
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
auto batch_item_num = args.front().get_shape().lens()[axis];
result.visit([&](auto output) {
args[0].visit([&](auto input) {
par_for(output_shape.elements(), [&](auto i) {
auto data_idx = output_shape.multi(i);
output[i] = this->calc_argmax(input, data_idx, batch_item_num);
});
});
});
return result;
}
};
} // namespace op
......
#ifndef MIGRAPHX_GUARD_OPERATORS_ARGMIN_HPP
#define MIGRAPHX_GUARD_OPERATORS_ARGMIN_HPP
#include <array>
//#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
//#include <migraphx/stringutils.hpp>
//#include <migraphx/literal.hpp>
#include <migraphx/par_dfor.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
//#include <cmath>
//#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -42,6 +41,42 @@ struct argmin
return {shape::int64_type, lens};
}
template <class T>
int64_t
calc_argmin(T& input, std::vector<std::size_t>& indices, size_t item_num) const
{
auto min_val = input(indices.begin(), indices.end());
int64_t min_index = 0;
for(std::size_t i = 1; i < item_num; ++i)
{
indices[axis] = i;
if(min_val > input(indices.begin(), indices.end()))
{
min_val = input(indices.begin(), indices.end());
min_index = i;
}
}
return min_index;
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
std::size_t batch_item_num = args.front().get_shape().lens()[axis];
result.visit([&](auto output) {
args[0].visit([&](auto input) {
par_for(output_shape.elements(), [&](auto i) {
auto data_idx = output_shape.multi(i);
output[i] = this->calc_argmin(input, data_idx, batch_item_num);
});
});
});
return result;
}
};
} // namespace op
......
#ifndef MIGRAPHX_GUARD_OPERATORS_LOGSOFTMAX_HPP
#define MIGRAPHX_GUARD_OPERATORS_LOGSOFTMAX_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
#ifndef MIGRAPHX_GUARD_OPERATORS_SOFTMAX_HPP
#define MIGRAPHX_GUARD_OPERATORS_SOFTMAX_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -663,43 +663,9 @@ struct cpu_argmax
std::string name() const { return "cpu::argmax"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
template <class T>
int64_t
calc_argmax(T& input, std::vector<std::size_t>& indices, size_t item_num, int axis) const
{
auto max_val = input(indices.begin(), indices.end());
int64_t max_index = 0;
for(std::size_t i = 1; i < item_num; ++i)
{
indices[axis] = i;
if(max_val < input(indices.begin(), indices.end()))
{
max_val = input(indices.begin(), indices.end());
max_index = i;
}
}
return max_index;
}
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
auto batch_lens = args.front().get_shape().lens();
std::size_t batch_item_num = batch_lens[op.axis];
batch_lens[op.axis] = 1;
shape batch_shape{shape::int32_type, batch_lens};
result.visit([&](auto output) {
args[0].visit([&](auto input) {
par_for(batch_shape.elements(), [&](auto i) {
auto data_idx = batch_shape.multi(i);
output[i] = this->calc_argmax(input, data_idx, batch_item_num, op.axis);
});
});
});
return result;
return op.compute(output_shape, args);
}
};
......@@ -716,43 +682,9 @@ struct cpu_argmin
std::string name() const { return "cpu::argmin"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
template <class T>
int64_t
calc_argmin(T& input, std::vector<std::size_t>& indices, size_t item_num, int axis) const
{
auto min_val = input(indices.begin(), indices.end());
int64_t min_index = 0;
for(std::size_t i = 1; i < item_num; ++i)
{
indices[axis] = i;
if(min_val > input(indices.begin(), indices.end()))
{
min_val = input(indices.begin(), indices.end());
min_index = i;
}
}
return min_index;
}
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
auto batch_lens = args.front().get_shape().lens();
std::size_t batch_item_num = batch_lens[op.axis];
batch_lens[op.axis] = 1;
shape batch_shape{shape::int32_type, batch_lens};
result.visit([&](auto output) {
args[0].visit([&](auto input) {
par_for(batch_shape.elements(), [&](auto i) {
auto data_idx = batch_shape.multi(i);
output[i] = this->calc_argmin(input, data_idx, batch_item_num, op.axis);
});
});
});
return result;
return op.compute(output_shape, args);
}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment