Unverified Commit 1b098fd7 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Merge branch 'develop' into type-string-driver

parents 05f2ee1c c0398ded
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTER_ADD_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTER_ADD_HPP
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
#include <migraphx/op/scatter.hpp>
// Scatter op. with "add" function as reduction.
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct scatter_add : scatter<scatter_add>
{
// reduction (pointwise operation) is called by the parent struct's compute() method.
// It works much like a virtual function overload.
// For the scatter methods, there are three different reduction functions.
auto reduction() const
{
return [](auto& x, const auto& y) { x += y; };
}
// name of this struct is automatically assigned by the op_name<>
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTER_MUL_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTER_MUL_HPP
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
#include <migraphx/op/scatter.hpp>
// Scatter op. with "multiply" as the reduction function.
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct scatter_mul : scatter<scatter_mul>
{
// reduction (pointwise operation) is called by the parent struct's compute() method.
// It works much like a virtual function overload.
// For the scatter operators, there are three different reduction functions.
auto reduction() const
{
return [](auto& x, const auto& y) { x *= y; };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTER_NONE_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTER_NONE_HPP
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/op/scatter.hpp>
#include <cmath>
#include <utility>
// Scatter op. with "none" as the reduction function (just copies the value). This is identical to
// the previously existing Scatter op.
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct scatter_none : scatter<scatter_none>
{
// reduction (pointwise operation) is called by the parent struct's compute() method.
// It works much like a virtual function overload.
// For the scatter operators, there are three different reduction functions.
auto reduction() const
{
return [](auto& x, const auto& y) { x = y; };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTERND_ADD_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTERND_ADD_HPP
#include <migraphx/op/scatternd_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct scatternd_add : scatternd_op<scatternd_add>
{
scatternd_add() {}
auto reduction() const
{
return [](auto& x, const auto& y) { x += y; };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTERND_MUL_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTERND_MUL_HPP
#include <migraphx/op/scatternd_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct scatternd_mul : scatternd_op<scatternd_mul>
{
scatternd_mul() {}
auto reduction() const
{
return [](auto& x, const auto& y) { x *= y; };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTERND_NONE_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTERND_NONE_HPP
#include <migraphx/op/scatternd_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct scatternd_none : scatternd_op<scatternd_none>
{
scatternd_none() {}
auto reduction() const
{
return [](auto& x, const auto& y) { x = y; };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTERND_OP_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTERND_OP_HPP
#include <migraphx/op/name.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/par_for.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
template <class Derived>
struct scatternd_op : op_name<Derived>
{
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(3);
auto r = inputs.front().lens().size();
auto q = inputs.at(1).lens().size();
auto k = inputs.at(1).lens().back();
auto ind_lens = inputs.at(1).lens();
auto upd_lens = inputs.back().lens();
auto data_lens = inputs.front().lens();
if(k > r)
MIGRAPHX_THROW("ScatterND: index of size " + std::to_string(k) +
" is too large for tensor of rank " + std::to_string(r));
if(not(std::equal(ind_lens.begin(), ind_lens.begin() + q - 1, upd_lens.begin()) and
std::equal(data_lens.begin() + k, data_lens.end(), upd_lens.begin() + q - 1)))
MIGRAPHX_THROW("ScatterND: incorrect update shape. update.lens != indices.lens[0:q-1] "
"++ data.lens[k:r-1]");
auto s = inputs.front();
if(s.broadcasted())
{
return {s.type(), s.lens()};
}
else
{
return s.with_lens(s.lens());
}
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
auto& self = static_cast<const Derived&>(*this);
visit_all(result, args[0], args[2])([&](auto output, auto data, auto updates) {
std::copy(data.begin(), data.end(), output.begin());
args[1].visit([&](auto indices) {
auto updates_shape = updates.get_shape();
auto updates_std = shape{updates_shape.type(), updates_shape.lens()};
auto indices_shape = indices.get_shape();
auto k = indices_shape.lens().back();
auto q = indices_shape.lens().size();
auto r = output_shape.lens().size();
par_for(updates_shape.elements(), [&](const auto i) {
auto updates_idx = updates_std.multi(i);
std::vector<std::size_t> indices_idx(q, 0);
std::copy(
updates_idx.begin(), updates_idx.begin() + q - 1, indices_idx.begin());
auto index_start = indices.begin() +
indices_shape.index(indices_idx.begin(), indices_idx.end());
auto index_end = index_start + k;
std::vector<std::size_t> out_idx(r, 0);
std::copy(index_start, index_end, out_idx.begin());
std::copy(updates_idx.begin() + q - 1, updates_idx.end(), out_idx.begin() + k);
self.reduction()(output[output_shape.index(out_idx)], updates[i]);
});
});
});
return result;
}
auto init() const {}
scatternd_op() {}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
#include <array> #include <array>
#include <migraphx/op/unary.hpp> #include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp> #include <migraphx/streamutils.hpp>
...@@ -19,6 +18,7 @@ namespace op { ...@@ -19,6 +18,7 @@ namespace op {
struct sigmoid : unary<sigmoid> struct sigmoid : unary<sigmoid>
{ {
std::string point_op() const { return "1.f / (1.f + ${function:exp}(-${0}))"; }
auto apply() const auto apply() const
{ {
return [](auto x) { return 1.f / (1.f + std::exp(-x)); }; return [](auto x) { return 1.f / (1.f + std::exp(-x)); };
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
#include <array> #include <array>
#include <migraphx/op/unary.hpp> #include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp> #include <migraphx/streamutils.hpp>
...@@ -19,6 +18,7 @@ namespace op { ...@@ -19,6 +18,7 @@ namespace op {
struct sign : unary<sign> struct sign : unary<sign>
{ {
std::string point_op() const { return "(${0} > 0 ? 1 : ((${0} < 0) ? -1 : 0))"; }
auto apply() const auto apply() const
{ {
return [](auto x) { return (x > 0 ? 1 : ((x < 0) ? -1 : 0)); }; return [](auto x) { return (x > 0 ? 1 : ((x < 0) ? -1 : 0)); };
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
#include <array> #include <array>
#include <migraphx/op/unary.hpp> #include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp> #include <migraphx/streamutils.hpp>
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
#include <array> #include <array>
#include <migraphx/op/unary.hpp> #include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp> #include <migraphx/streamutils.hpp>
......
#ifndef MIGRAPHX_GUARD_OPERATORS_SLICE_HPP #ifndef MIGRAPHX_GUARD_OPERATORS_SLICE_HPP
#define MIGRAPHX_GUARD_OPERATORS_SLICE_HPP #define MIGRAPHX_GUARD_OPERATORS_SLICE_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp> #include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
#include <vector>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -28,6 +27,23 @@ struct slice ...@@ -28,6 +27,23 @@ struct slice
return pack(f(self.axes, "axes"), f(self.starts, "starts"), f(self.ends, "ends")); return pack(f(self.axes, "axes"), f(self.starts, "starts"), f(self.ends, "ends"));
} }
value attributes() const
{
value normalize = value::object{};
normalize["axes"] = value::array{normalize_attribute::include_min};
normalize["starts"] = value::array{normalize_attribute::clip_max,
normalize_attribute::clip_min,
normalize_attribute::include_max,
normalize_attribute::use_len,
normalize_attribute::include_min};
normalize["ends"] = value::array{normalize_attribute::clip_max,
normalize_attribute::clip_min,
normalize_attribute::include_max,
normalize_attribute::use_len,
normalize_attribute::include_min};
return {{"normalize_axes", normalize}};
}
std::string name() const { return "slice"; } std::string name() const { return "slice"; }
auto fix_index(const std::vector<std::size_t>& lens, std::size_t axis, int64_t index) const auto fix_index(const std::vector<std::size_t>& lens, std::size_t axis, int64_t index) const
...@@ -61,16 +77,24 @@ struct slice ...@@ -61,16 +77,24 @@ struct slice
return offset; return offset;
} }
shape compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
auto input_shape = inputs[0]; auto input_shape = inputs[0];
auto t = input_shape.type(); auto t = input_shape.type();
const auto& old_lens = input_shape.lens(); const auto& old_lens = input_shape.lens();
const auto& old_strides = input_shape.strides(); const auto& old_strides = input_shape.strides();
if(std::any_of(
axes.begin(), axes.end(), [&](auto i) { return (i >= old_lens.size() and i < 0); }))
{
MIGRAPHX_THROW("SLICE: input axis " + to_string_range(axes) + " out of range");
}
if(starts.size() != axes.size() || axes.size() != ends.size()) if(starts.size() != axes.size() || axes.size() != ends.size())
{ {
MIGRAPHX_THROW("inconsistent sizes"); MIGRAPHX_THROW("SLICE: inconsistent sizes");
} }
std::vector<std::size_t> new_lens = old_lens; std::vector<std::size_t> new_lens = old_lens;
for(std::size_t i = 0; i < axes.size(); i++) for(std::size_t i = 0; i < axes.size(); i++)
{ {
...@@ -80,6 +104,7 @@ struct slice ...@@ -80,6 +104,7 @@ struct slice
} }
return shape{t, new_lens, old_strides}; return shape{t, new_lens, old_strides};
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
auto input = args[0]; auto input = args[0];
......
#ifndef MIGRAPHX_GUARD_OPERATORS_SOFTMAX_HPP #ifndef MIGRAPHX_GUARD_OPERATORS_SOFTMAX_HPP
#define MIGRAPHX_GUARD_OPERATORS_SOFTMAX_HPP #define MIGRAPHX_GUARD_OPERATORS_SOFTMAX_HPP
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
namespace migraphx { namespace migraphx {
...@@ -11,7 +12,7 @@ namespace op { ...@@ -11,7 +12,7 @@ namespace op {
struct softmax struct softmax
{ {
int axis = 1; int64_t axis = 1;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
...@@ -19,16 +20,26 @@ struct softmax ...@@ -19,16 +20,26 @@ struct softmax
return pack(f(self.axis, "axis")); return pack(f(self.axis, "axis"));
} }
value attributes() const
{
value normalize;
normalize["axis"] = value::array{normalize_attribute::include_min};
return {{"normalize_axes", normalize}};
}
std::string name() const { return "softmax"; } std::string name() const { return "softmax"; }
shape compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(1).standard(); check_shapes{inputs, *this}.has(1);
if(axis < 0 || axis >= inputs[0].lens().size()) if(inputs.at(0).packed())
{
return inputs.at(0);
}
else
{ {
MIGRAPHX_THROW("SoftMax: input axis value " + std::to_string(axis) + auto lens = inputs.at(0).lens();
" is out of range"); return {inputs.at(0).type(), lens};
} }
return inputs.at(0);
} }
auto output() const auto output() const
......
...@@ -9,6 +9,7 @@ namespace op { ...@@ -9,6 +9,7 @@ namespace op {
struct sqdiff : binary<sqdiff> struct sqdiff : binary<sqdiff>
{ {
std::string point_op() const { return "(${0} - ${1}) * (${0} - ${1})"; }
auto apply() const auto apply() const
{ {
return [](auto x, auto y) { return (x - y) * (x - y); }; return [](auto x, auto y) { return (x - y) * (x - y); };
......
...@@ -2,13 +2,14 @@ ...@@ -2,13 +2,14 @@
#define MIGRAPHX_GUARD_OPERATORS_SQUEEZE_HPP #define MIGRAPHX_GUARD_OPERATORS_SQUEEZE_HPP
#include <array> #include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp> #include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/lifetime.hpp>
#include <cmath> #include <cmath>
#include <utility> #include <utility>
...@@ -26,49 +27,62 @@ struct squeeze ...@@ -26,49 +27,62 @@ struct squeeze
return pack(f(self.axes, "axes")); return pack(f(self.axes, "axes"));
} }
value attributes() const
{
value normalize;
normalize["axes"] = value::array{normalize_attribute::include_min};
return {{"normalize_axes", normalize}};
}
std::string name() const { return "squeeze"; } std::string name() const { return "squeeze"; }
shape compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1).standard(); check_shapes{inputs, *this}.has(1);
auto input_shape = inputs[0]; auto input_shape = inputs[0];
auto type = input_shape.type(); auto type = input_shape.type();
auto old_lens = input_shape.lens(); auto old_lens = input_shape.lens();
if(std::any_of( auto old_strides = input_shape.strides();
axes.begin(), axes.end(), [&](auto axis) { return input_shape.lens()[axis] != 1; })) if(std::any_of(axes.begin(), axes.end(), [&](auto axis) { return old_lens[axis] != 1; }))
{ {
MIGRAPHX_THROW("squeeze axis dimension should be equal to 1"); MIGRAPHX_THROW("squeeze axis dimension should be equal to 1");
} }
std::vector<std::size_t> new_lens; std::vector<std::size_t> new_lens;
std::vector<std::size_t> new_strides;
if(axes.empty()) if(axes.empty())
{ {
std::copy_if(old_lens.begin(), for(auto i : range(old_lens.size()))
old_lens.end(), {
std::back_inserter(new_lens), if(old_lens[i] != 1)
[](auto len) { return len != 1; }); {
new_lens.push_back(old_lens[i]);
new_strides.push_back(old_strides[i]);
}
}
} }
else else
{ {
for(std::size_t i = 0; i < old_lens.size(); i++) for(auto i : range(old_lens.size()))
{ {
if(std::find(axes.begin(), axes.end(), i) == axes.end()) if(std::find(axes.begin(), axes.end(), i) == axes.end())
{ {
new_lens.push_back(old_lens[i]); new_lens.push_back(old_lens[i]);
new_strides.push_back(old_strides[i]);
} }
} }
} }
if(new_lens.empty()) if(new_lens.empty())
{ {
return shape{type}; return shape{type};
} }
else else
{ {
return shape{type, new_lens}; return shape{type, new_lens, new_strides};
} }
} }
argument compute(shape output_shape, std::vector<argument> args) const argument compute(shape output_shape, std::vector<argument> args) const
{ {
return {std::move(output_shape), std::move(args.front().data)}; return args[0].reshape(output_shape);
} }
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; } std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
}; };
......
#ifndef MIGRAPHX_GUARD_OPERATORS_STEP_HPP
#define MIGRAPHX_GUARD_OPERATORS_STEP_HPP
#include "migraphx/stringutils.hpp"
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/config.hpp>
#include <migraphx/lifetime.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct step
{
std::vector<int64_t> axes;
std::vector<int64_t> steps;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axes, "axes"), f(self.steps, "steps"));
}
value attributes() const
{
value normalize;
normalize["axes"] = value::array{normalize_attribute::include_min};
return {{"normalize_axes", normalize}};
}
std::string name() const { return "step"; }
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
auto input = inputs.at(0);
auto in_lens = input.lens();
auto t = input.type();
if(axes.size() != steps.size())
{
MIGRAPHX_THROW("STEP: attribute axes {" + to_string_range(axes) +
"} has different dimensions from step {" + to_string_range(steps) +
"}.");
}
if(std::any_of(axes.begin(), axes.end(), [&](auto axis) { return axis >= in_lens.size(); }))
{
MIGRAPHX_THROW("STEP: axis value is out of range!");
}
auto lens = in_lens;
auto strides = input.strides();
for(auto i : range(axes.size()))
{
auto axis = axes[i];
auto step = steps[i];
lens[axis] = (in_lens[axis] + step - 1) / step;
strides[axis] *= step;
}
return {t, lens, strides};
}
argument compute(shape output_shape, std::vector<argument> args) const
{
return args[0].reshape(output_shape);
}
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 0; }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
#include <array> #include <array>
#include <migraphx/op/binary.hpp> #include <migraphx/op/binary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp> #include <migraphx/streamutils.hpp>
...@@ -19,6 +18,7 @@ namespace op { ...@@ -19,6 +18,7 @@ namespace op {
struct sub : binary<sub> struct sub : binary<sub>
{ {
std::string point_function() const { return "-"; }
auto apply() const auto apply() const
{ {
return [](auto x, auto y) { return x - y; }; return [](auto x, auto y) { return x - y; };
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
#include <array> #include <array>
#include <migraphx/op/unary.hpp> #include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp> #include <migraphx/streamutils.hpp>
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
#include <array> #include <array>
#include <migraphx/op/unary.hpp> #include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp> #include <migraphx/streamutils.hpp>
......
#ifndef MIGRAPHX_GUARD_OPERATORS_GATHER_HPP
#define MIGRAPHX_GUARD_OPERATORS_GATHER_HPP
#include <algorithm>
#include <migraphx/check_shapes.hpp>
#include <migraphx/config.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/value.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct topk
{
int64_t k = 1;
int64_t axis = 0;
bool largest = true;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.k, "k"), f(self.axis, "axis"), f(self.largest, "largest"));
}
value attributes() const
{
value normalize;
normalize["axis"] = value::array{normalize_attribute::include_min};
return {{"normalize_axes", normalize}};
}
std::string name() const { return "topk"; }
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1).standard();
auto lens = inputs.at(0).lens();
auto type = inputs.at(0).type();
lens[axis] = k;
shape s_val{type, lens};
shape s_ind{shape::int64_type, lens};
return {{s_val, s_ind}};
}
template <class T, class Compare>
struct heap_vector
{
std::vector<T> data;
Compare compare;
heap_vector(const std::vector<T>& val, Compare comp) : data(val), compare(std::move(comp))
{
std::make_heap(data.begin(), data.end(), compare);
}
void try_push(T val)
{
if(not compare(val, data.front()))
return;
std::pop_heap(data.begin(), data.end(), compare);
data.back() = val;
std::push_heap(data.begin(), data.end(), compare);
}
std::vector<T> sort()
{
auto sorted_data = data;
std::sort_heap(sorted_data.begin(), sorted_data.end(), compare);
return sorted_data;
}
};
template <class T, class Compare>
heap_vector<T, Compare> make_heap(std::vector<T> val, Compare compare) const
{
return {std::move(val), std::move(compare)};
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
auto vec_ss = output_shape.sub_shapes();
argument res_val{vec_ss.front()};
argument res_ind{vec_ss.back()};
auto in_s = args.front().get_shape();
auto out_s = vec_ss.front();
auto comp_lens = in_s.lens();
auto axis_dim = comp_lens[axis];
// compute shape
comp_lens[axis] = 1;
shape comp_s{in_s.type(), comp_lens};
visit_all(res_val, args.front())([&](auto out_val, auto input) {
auto* out_ind = res_ind.cast<int64_t>();
par_for(comp_s.elements(), [&](auto i) {
auto idx = comp_s.multi(i);
std::vector<std::size_t> indices(k);
std::iota(indices.begin(), indices.end(), 0);
auto comp = [&](auto i1, auto i2) {
auto idx1 = idx;
auto idx2 = idx;
idx1[axis] = i1;
idx2[axis] = i2;
return this->largest
? std::greater<>{}(input[in_s.index(idx1)], input[in_s.index(idx2)])
: std::less<>{}(input[in_s.index(idx1)], input[in_s.index(idx2)]);
};
auto hp = this->make_heap(indices, comp);
for(std::size_t ii = indices.size(); ii < axis_dim; ++ii)
{
hp.try_push(ii);
}
auto sorted_indices = hp.sort();
auto out_idx = idx;
auto in_idx = idx;
for(auto j : range(sorted_indices.size()))
{
out_idx[axis] = j;
in_idx[axis] = sorted_indices[j];
out_val[out_s.index(out_idx)] = input[in_s.index(in_idx)];
out_ind[out_s.index(out_idx)] = sorted_indices[j];
}
});
});
return {{res_val, res_ind}};
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment