"include/vscode:/vscode.git/clone" did not exist on "5d73dd3e579db7e58fe27dfa5d1d52d19dff9c5e"
Commit 1adf6096 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

add scatter op

parent e00479af
...@@ -144,6 +144,7 @@ register_migraphx_ops( ...@@ -144,6 +144,7 @@ register_migraphx_ops(
round round
rsqrt rsqrt
scalar scalar
scatter
sigmoid sigmoid
sign sign
sinh sinh
......
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTER_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTER_HPP
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct scatter
{
int64_t axis = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axis, "axis"));
}
value attributes() const
{
value normalize;
normalize["axis"] = value::array{normalize_attribute::include_min};
return {{"normalize_axes", normalize}};
}
std::string name() const { return "scatter"; }
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(3).standard();
return inputs.front();
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
// max dimension in axis
visit_all(result, args[0], args[2])([&](auto output, auto data, auto update) {
std::copy(data.begin(), data.end(), output.begin());
args[1].visit([&](auto indices) {
auto ind_s = indices.get_shape();
shape_for_each(ind_s, [&](const auto& idx) {
auto out_idx = idx;
out_idx[axis] = indices[ind_s.index(idx)];
output[output_shape.index(out_idx)] = update[ind_s.index(idx)];
});
});
});
return result;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -80,6 +80,7 @@ ...@@ -80,6 +80,7 @@
#include <migraphx/op/round.hpp> #include <migraphx/op/round.hpp>
#include <migraphx/op/rsqrt.hpp> #include <migraphx/op/rsqrt.hpp>
#include <migraphx/op/scalar.hpp> #include <migraphx/op/scalar.hpp>
#include <migraphx/op/scatter.hpp>
#include <migraphx/op/sigmoid.hpp> #include <migraphx/op/sigmoid.hpp>
#include <migraphx/op/sign.hpp> #include <migraphx/op/sign.hpp>
#include <migraphx/op/sinh.hpp> #include <migraphx/op/sinh.hpp>
......
...@@ -35,6 +35,7 @@ struct parse_generic_op : op_parser<parse_generic_op> ...@@ -35,6 +35,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
{"Reciprocal", "recip"}, {"Reciprocal", "recip"},
{"Relu", "relu"}, {"Relu", "relu"},
{"Round", "round"}, {"Round", "round"},
{"Scatter", "scatter"},
{"Sigmoid", "sigmoid"}, {"Sigmoid", "sigmoid"},
{"Sign", "sign"}, {"Sign", "sign"},
{"Sin", "sin"}, {"Sin", "sin"},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment