Commit f06f6aa3 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

merge changes from develop branch

parents 80a35596 ebfe9735
......@@ -7,6 +7,14 @@
#include <migraphx/onnx.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/eliminate_pad.hpp>
#include <migraphx/propagate_constant.hpp>
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/simplify_reshapes.hpp>
namespace migraphx {
namespace driver {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -17,6 +25,7 @@ struct loader
std::string file_type;
bool is_nhwc = true;
unsigned trim = 0;
bool optimize = false;
void parse(argument_parser& ap)
{
......@@ -26,6 +35,7 @@ struct loader
ap(is_nhwc, {"--nhwc"}, ap.help("Treat tensorflow format as nhwc"), ap.set_value(true));
ap(is_nhwc, {"--nchw"}, ap.help("Treat tensorflow format as nchw"), ap.set_value(false));
ap(trim, {"--trim", "-t"}, ap.help("Trim instructions from the end"));
ap(optimize, {"--optimize"}, ap.help("Optimize when reading"), ap.set_value(true));
}
program load()
......@@ -48,6 +58,20 @@ struct loader
auto last = std::prev(p.end(), trim);
p.remove_instructions(last, p.end());
}
if(optimize)
migraphx::run_passes(p,
{
migraphx::eliminate_identity{},
migraphx::dead_code_elimination{},
migraphx::simplify_algebra{},
migraphx::dead_code_elimination{},
migraphx::simplify_reshapes{},
migraphx::dead_code_elimination{},
migraphx::propagate_constant{},
migraphx::dead_code_elimination{},
migraphx::eliminate_pad{},
migraphx::dead_code_elimination{},
});
return p;
}
};
......
#ifndef MIGRAPHX_GUARD_RTGLIB_ARRAY_HPP
#define MIGRAPHX_GUARD_RTGLIB_ARRAY_HPP
#include <migraphx/config.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/requires.hpp>
#include <type_traits>
#include <array>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace detail {
// Computes the element type of the array produced by make_array: an
// explicitly supplied R wins; otherwise (R == void) fall back to the
// common type of all the argument types.
template <class R, class...>
struct array_type
{
using type = R;
};
template <class... Ts>
struct array_type<void, Ts...> : std::common_type<Ts...>
{
};
template <class R, class... Ts>
using array_type_t = typename array_type<R, Ts...>::type;
// Expands a built-in array into a std::array element-by-element;
// seq<I...> supplies the index pack 0..N-1.
template <class T, std::size_t N, std::size_t... I>
constexpr std::array<std::remove_cv_t<T>, N> to_array_impl(T (&a)[N], seq<I...>)
{
return {{a[I]...}};
}
} // namespace detail
// Builds a std::array from the arguments. The element type is Result when
// given, otherwise the common type of the arguments (detail::array_type).
template <class Result = void, class... Ts, MIGRAPHX_REQUIRES((sizeof...(Ts) > 0))>
constexpr std::array<detail::array_type_t<Result, Ts...>, sizeof...(Ts)> make_array(Ts&&... xs)
{
return {static_cast<detail::array_type_t<Result, Ts...>>(std::forward<Ts>(xs))...};
}
// Zero-argument overload: nothing to deduce from, so default the element
// type to int.
constexpr std::array<int, 0> make_array() { return {}; }
// Converts a built-in array to the equivalent std::array (the same idea as
// C++20's std::to_array).
template <class T, std::size_t N>
constexpr auto to_array(T (&a)[N])
{
return detail::to_array_impl(a, detail::gens<N>{});
}
namespace detail {
// Rebuilds an array from elements [Offset, Offset + sizeof...(I)) of `a`.
// With an index pack of N-1 elements, Offset selects which end is dropped:
// Offset = 0 drops the last element, Offset = 1 drops the first.
template <std::size_t Offset = 0, class Array, std::size_t... I>
constexpr auto rearray_impl(Array a, seq<I...>)
{
return make_array(a[I + Offset]...);
}
} // namespace detail
template <class T, std::size_t N>
constexpr auto pop_front(std::array<T, N> a)
{
return detail::rearray_impl(a, detail::gens<N - 1>{});
}
/// Returns a copy of `a` with the LAST element removed, i.e. elements
/// [0, N-1). This is the Offset-0 form of rearray_impl; the previous
/// Offset-1 call kept elements [1, N) and therefore dropped the FIRST
/// element (pop_front behavior) — the two functions had swapped
/// implementations.
template <class T, std::size_t N>
constexpr auto pop_back(std::array<T, N> a)
{
    return detail::rearray_impl(a, detail::gens<N - 1>{});
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -15,6 +15,12 @@ struct swallow
}
};
/// Yields std::tuple_size<T> as a value — an integral_constant object —
/// so the size can be passed around and later used as a compile-time
/// constant (e.g. as a non-type template argument).
template <class T>
auto tuple_size(const T&)
{
    using size_constant = typename std::tuple_size<T>::type;
    return size_constant{};
}
namespace detail {
template <class R, class F>
......@@ -83,6 +89,12 @@ constexpr auto sequence_c(F&& f)
return detail::sequence_c_impl(f, detail::gens<N>{});
}
// Invokes f with a compile-time index pack 0..ic-1. `ic` (an
// integral_constant-like object, e.g. from tuple_size above) is converted
// to the non-type template argument of sequence_c through its constexpr
// conversion operator; this is valid even though `ic` is a function
// parameter, because the conversion reads no object state.
template <class IntegerConstant, class F>
constexpr auto sequence(IntegerConstant ic, F&& f)
{
return sequence_c<ic>(f);
}
template <class F, class... Ts>
constexpr void each_args(F f, Ts&&... xs)
{
......@@ -95,9 +107,9 @@ constexpr void each_args(F)
}
template <class F, class T>
auto unpack(F f, T& x)
auto unpack(F f, T&& x)
{
return sequence_c<std::tuple_size<T>{}>([&](auto... is) { f(std::get<is>(x)...); });
return sequence(tuple_size(x), [&](auto... is) { f(std::get<is>(static_cast<T&&>(x))...); });
}
/// Implements a fix-point combinator
......@@ -149,6 +161,52 @@ auto index_of(T& x)
return [&](auto&& y) { return x[y]; };
}
/// Returns the first argument unchanged, preserving its value category;
/// all remaining arguments are ignored.
template <class T, class... Ts>
decltype(auto) front_args(T&& x, Ts&&...)
{
    return std::forward<T>(x);
}
/// Returns the last argument, preserving its value category. The arguments
/// are gathered into a tuple of references and the final element selected.
template <class... Ts>
decltype(auto) back_args(Ts&&... xs)
{
    return std::get<sizeof...(Ts) - 1>(std::forward_as_tuple(std::forward<Ts>(xs)...));
}
/// Drops the first argument and packages the rest into a callable: the
/// returned lambda invokes its argument f with the remaining arguments,
/// preserving their value categories. The arguments are captured by
/// reference, so the result must be used before the originals expire.
template <class T, class... Ts>
auto pop_front_args(T&&, Ts&&... xs)
{
    return [&](auto f) { f(std::forward<Ts>(xs)...); };
}
// Drops the last argument and packages the rest into a callable: the
// returned lambda invokes f with all but the final argument, preserving
// each argument's value category via a tuple of references indexed by
// sequence_c.
// NOTE(review): the arguments are held by reference, so the returned
// lambda must be invoked before the originals go out of scope.
template <class... Ts>
auto pop_back_args(Ts&&... xs)
{
return [&](auto f) {
using tuple_type = std::tuple<Ts&&...>;
auto t = tuple_type(static_cast<Ts&&>(xs)...);
return sequence_c<sizeof...(Ts) - 1>(
[&](auto... is) { return f(std::get<is>(static_cast<tuple_type&&>(t))...); });
};
}
/// Function object that ignores every argument and returns a stored value.
template <class T>
struct always_f
{
    T x;
    /// Invocable with any argument list; always yields the stored value.
    template <class... Us>
    constexpr T operator()(Us&&...) const
    {
        return x;
    }
};

/// Creates a callable that returns `x` no matter what it is called with.
template <class T>
auto always(T x)
{
    return always_f<T>{x};
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -25,7 +25,7 @@ constexpr T normalize(unsigned long z)
template <class T, MIGRAPHX_REQUIRES(is_signed<T>{} and not is_floating_point<T>{})>
constexpr T normalize(unsigned long z)
{
const auto max = std::numeric_limits<T>::max();
const auto max = std::numeric_limits<T>::max() / 64;
const auto half_max = max / 2;
return half_max - (z % max);
}
......@@ -33,7 +33,7 @@ constexpr T normalize(unsigned long z)
template <class T, MIGRAPHX_REQUIRES(not is_signed<T>{} and std::is_integral<T>{})>
constexpr T normalize(unsigned long z)
{
const auto max = std::numeric_limits<T>::max();
const auto max = std::numeric_limits<T>::max() / 64;
return z % max;
}
......
......@@ -79,6 +79,7 @@ struct literal : raw_data<literal>
template <class Iterator>
void fill(Iterator start, Iterator end)
{
assert(std::distance(start, end) == m_shape.elements());
if(m_shape.standard())
{
m_shape.visit_type([&](auto as) { std::copy(start, end, as.from(buffer.get())); });
......
......@@ -8,6 +8,7 @@
#include <migraphx/iterator_for.hpp>
#include <migraphx/config.hpp>
#include <unordered_map>
#include <unordered_set>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -20,6 +21,12 @@ struct matcher_context
std::unordered_map<std::string, instruction_ref> instructions;
instruction_ref not_found() const { return last; }
// Returns true if matcher `m` succeeds on `ins`; a failed match is
// signaled by the sentinel instruction returned from not_found().
template <class M>
bool matched(M m, instruction_ref ins)
{
return m.match(*this, ins) != this->not_found();
}
private:
instruction_ref last;
};
......@@ -205,74 +212,147 @@ matcher_result match_instruction(program& p, instruction_ref ins, M&& m)
return result;
}
/// Find matches for an instruction in the program
/// Tries each matcher in order against a single instruction; the first
/// matcher that succeeds has its apply() run and the remaining matchers
/// are skipped.
template <class... Ms>
void find_matches(program& p, instruction_ref ins, Ms&&... ms)
{
    bool found = false;
    each_args(
        [&](auto&& m) {
            if(found)
                return;
            auto r = match_instruction(p, ins, m.matcher());
            if(r.result != p.end())
            {
                m.apply(p, r);
                found = true;
            }
        },
        ms...);
}
/// Find matches in a program
/// Runs the matchers over every instruction in the program.
/// Delegates to the per-instruction overload so the matcher-dispatch
/// logic lives in exactly one place; the previous body also kept a
/// duplicated copy of that dispatch inline, which applied every matcher
/// twice per instruction.
template <class... Ms>
void find_matches(program& p, Ms&&... ms)
{
    for(auto ins : iterator_for(p))
    {
        find_matches(p, ins, ms...);
    }
}
/// Lazy boolean AND over two thunks: g() is evaluated only when f() is
/// true. Used by match_fold_f to fold matchers with short-circuiting.
/// (The legacy function-style all_of that was interleaved here is
/// superseded by the match_fold_f-based all_of object defined below.)
struct lazy_and
{
    template <class F, class G>
    bool operator()(F f, G g) const
    {
        return f() and g();
    }
};
/// Lazy boolean OR over two thunks: g() is evaluated only when f() is
/// false. Used by match_fold_f to fold matchers with short-circuiting.
/// (The legacy function-style none_of that was interleaved here is
/// superseded by the match_fold_f-based none_of object defined below.)
struct lazy_or
{
    template <class F, class G>
    bool operator()(F f, G g) const
    {
        return f() or g();
    }
};
/// Folds a set of matchers into one matcher. Op is the lazy boolean
/// combinator (lazy_and/lazy_or), Start the fold's initial value, and
/// Matches the folded result that counts as a successful match.
template <class Op, bool Start, bool Matches>
struct match_fold_f
{
// Lazily folds `ms` over Op: a matcher is only evaluated when the
// combinator actually needs its result (short-circuit semantics).
template <class... Ms>
static bool fold_matchers(matcher_context& ctx, instruction_ref ins, Ms... ms)
{
Op op;
auto matched = [&](auto m) { return [=, &ctx] { return ctx.matched(m, ins); }; };
return fold([&](auto x, auto y) { return op(always(x), matched(y)); })(Start, ms...);
}
// Same as fold_matchers, but the matchers arrive packed inside `p`.
template <class Pack>
static bool fold_matchers_pack(matcher_context& ctx, instruction_ref ins, Pack p)
{
return p([&](auto... ms) { return match_fold_f::fold_matchers(ctx, ins, ms...); });
}
// Builds the combined matcher applied to the instruction itself.
template <class... Ts>
auto operator()(Ts... ms) const
{
return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
bool matches = match_fold_f::fold_matchers(ctx, ins, ms...);
if(matches == Matches)
return ins;
return ctx.not_found();
});
}
// Builds a matcher that applies the fold to each instruction produced by
// `select` (e.g. inputs()/outputs()) instead of the instruction itself.
template <class Selector>
auto operator[](Selector select) const
{
return [=](auto... ms) {
// Workaround ICE on gcc by packing matchers into an object
auto mpack = pack(ms...);
return make_bf_matcher([=](matcher_context& ctx, instruction_ref start) {
Op op;
bool matches = Start;
select(start, [&](auto ins) {
auto fm = [&] { return match_fold_f::fold_matchers_pack(ctx, ins, mpack); };
matches = op(always(matches), fm);
});
if(matches == Matches)
return start;
return ctx.not_found();
});
};
}
};
// Matcher combinators built on match_fold_f:
// all_of: every sub-matcher must match (lazy AND; true when empty).
const constexpr auto all_of = match_fold_f<lazy_and, true, true>{};
// any_of: at least one sub-matcher must match (lazy OR; false when empty).
const constexpr auto any_of = match_fold_f<lazy_or, false, true>{};
// none_of: succeeds only when no sub-matcher matches.
const constexpr auto none_of = match_fold_f<lazy_or, false, false>{};
/// Selector that invokes the callback on each input of an instruction;
/// usable with match_fold_f's operator[].
inline auto inputs()
{
    return [](auto ins, auto f) {
        for(auto&& input : ins->inputs())
            f(input);
    };
}
/// Selector that invokes the callback on each output (user) of an
/// instruction; usable with match_fold_f's operator[]. The legacy
/// function-style any_of that was interleaved in this span is superseded
/// by the match_fold_f-based any_of object.
inline auto outputs()
{
    return [](auto ins, auto f) {
        for(auto&& x : ins->outputs())
            f(x);
    };
}
// Matches any instruction unconditionally.
MIGRAPHX_PRED_MATCHER(any, instruction_ref) { return true; }
// Never matches.
MIGRAPHX_PRED_MATCHER(none, instruction_ref) { return false; }
// Matches instructions whose output shape reports standard() layout.
MIGRAPHX_PRED_MATCHER(standard_shape, instruction_ref ins) { return ins->get_shape().standard(); }
MIGRAPHX_PRED_MATCHER(not_standard_shape, instruction_ref ins)
{
return not ins->get_shape().standard();
}
// Matches instructions whose output shape reports broadcasted().
MIGRAPHX_PRED_MATCHER(broadcast_shape, instruction_ref ins)
{
return ins->get_shape().broadcasted();
}
// Matches instructions whose output shape reports transposed().
MIGRAPHX_PRED_MATCHER(transpose_shape, instruction_ref ins)
{
return ins->get_shape().transposed();
}
// Matches instructions with at least one input where every input has the
// same shape as the first input.
MIGRAPHX_PRED_MATCHER(same_input_shapes, instruction_ref ins)
{
if(ins->inputs().empty())
return false;
auto s = ins->inputs().front()->get_shape();
return std::all_of(
ins->inputs().begin(), ins->inputs().end(), [&](auto x) { return x->get_shape() == s; });
}
MIGRAPHX_BASIC_MATCHER(output, matcher_context& ctx, instruction_ref ins)
{
if(ins->outputs().size() == 1)
......@@ -289,10 +369,39 @@ MIGRAPHX_BASIC_MATCHER(used_once, matcher_context& ctx, instruction_ref ins)
return ctx.not_found();
}
inline auto name(std::string name)
// Walks forward through single-output chains, skipping outputs that match
// any of `ms`, and yields the first non-skipped user. Fails (returns the
// not-found sentinel) when an instruction along the way does not have
// exactly one output.
template <class... Ms>
auto skip_output(Ms... ms)
{
auto m = any_of(ms...);
return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref start) {
// fix gives the lambda a handle to itself so it can recurse down the
// output chain.
return fix<instruction_ref>([&](auto self, auto ins) {
if(ins->outputs().size() == 1)
{
auto next = ins->outputs().front();
if(ctx.matched(m, next))
{
// next is skippable; keep walking, but fall back to next
// if the deeper walk dead-ends.
auto skipped_next = self(next);
if(skipped_next != ctx.not_found())
return skipped_next;
}
return next;
}
return ctx.not_found();
})(start);
});
}
/// Matches an instruction whose operator name equals `s`. This span of the
/// scraped diff contained both the old capture list (capturing `name`) and
/// the new one (capturing `s`); only the new, consistent version is kept.
inline auto name(std::string s)
{
    return make_basic_pred_matcher(
        [ =, s = std::move(s) ](instruction_ref ins) { return ins->name() == s; });
}
/// Matches an instruction whose operator name is any one of `names`.
inline auto name(std::unordered_set<std::string> names)
{
    return make_basic_pred_matcher([ =, names = std::move(names) ](instruction_ref ins) {
        return names.find(ins->name()) != names.end();
    });
}
inline auto nargs(std::size_t n)
......@@ -338,6 +447,23 @@ inline auto either_arg(std::size_t i, std::size_t j)
};
}
// Wraps matcher `m`: succeeds only when `m` matches and the instruction
// it matched has the same shape as the instruction being tested.
template <class M>
auto same_shape(M m)
{
return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) {
auto i = m.match(ctx, ins);
if(i != ctx.not_found() and i->get_shape() == ins->get_shape())
return ins;
return ctx.not_found();
});
}
// Variadic form: every matcher must match an instruction whose shape
// equals the tested instruction's shape.
template <class... Ms>
auto same_shape(Ms... ms)
{
return all_of(same_shape(ms)...);
}
} // namespace match
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
#ifndef MIGRAPHX_GUARD_OPERATORS_ARGMAX_HPP
#define MIGRAPHX_GUARD_OPERATORS_ARGMAX_HPP
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/par_dfor.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
/// Operator computing the index of the maximum value along `axis`.
/// The output keeps the input rank with the reduced axis collapsed to 1,
/// and always has int64 element type.
struct argmax
{
// Axis to reduce over; must lie in [0, rank). Negative axes are rejected
// by compute_shape.
int64_t axis = 0;
// Exposes `axis` for reflection/serialization.
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axis, "axis"));
}
std::string name() const { return "argmax"; }
// Validates the single standard-layout input and returns its shape with
// the reduced dimension set to 1 and int64 output type.
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1).standard();
auto lens = inputs[0].lens();
int64_t n_dim = static_cast<int64_t>(lens.size());
if(axis >= n_dim || axis < 0)
{
MIGRAPHX_THROW("ARGMAX: axis is out of range.");
}
lens[axis] = 1;
return {shape::int64_type, lens};
}
// Scans `item_num` elements along `axis` starting from `indices` (whose
// axis entry is 0 on entry) and returns the position of the largest
// value. `indices` is mutated in place as scratch space; ties keep the
// earliest index.
template <class T>
int64_t calc_argmax(T& input, std::vector<std::size_t>& indices, size_t item_num) const
{
auto max_val = input(indices.begin(), indices.end());
int64_t max_index = 0;
for(std::size_t i = 1; i < item_num; ++i)
{
indices[axis] = i;
auto cur_val = input(indices.begin(), indices.end());
if(max_val < cur_val)
{
max_val = cur_val;
max_index = i;
}
}
return max_index;
}
// Reduces each output element's lane of the input along `axis`, in
// parallel over the output elements.
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
auto batch_item_num = args.front().get_shape().lens()[axis];
result.visit([&](auto output) {
args[0].visit([&](auto input) {
par_for(output_shape.elements(), [&](auto i) {
// multi(i) has 0 at the reduced axis (its length is 1), which
// is the starting point calc_argmax expects.
auto data_idx = output_shape.multi(i);
output[i] = this->calc_argmax(input, data_idx, batch_item_num);
});
});
});
return result;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_ARGMIN_HPP
#define MIGRAPHX_GUARD_OPERATORS_ARGMIN_HPP
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/par_dfor.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
/// Operator computing the index of the minimum value along `axis`.
/// The output keeps the input rank with the reduced axis collapsed to 1,
/// and always has int64 element type. Mirror image of argmax.
struct argmin
{
// Axis to reduce over; must lie in [0, rank). Negative axes are rejected
// by compute_shape.
int64_t axis = 0;
// Exposes `axis` for reflection/serialization.
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axis, "axis"));
}
std::string name() const { return "argmin"; }
// Validates the single standard-layout input and returns its shape with
// the reduced dimension set to 1 and int64 output type.
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1).standard();
auto lens = inputs[0].lens();
int64_t n_dim = static_cast<int64_t>(lens.size());
if(axis >= n_dim || axis < 0)
{
MIGRAPHX_THROW("ARGMIN: axis is out of range.");
}
lens[axis] = 1;
return {shape::int64_type, lens};
}
// Scans `item_num` elements along `axis` starting from `indices` (whose
// axis entry is 0 on entry) and returns the position of the smallest
// value. `indices` is mutated in place as scratch space; ties keep the
// earliest index.
template <class T>
int64_t calc_argmin(T& input, std::vector<std::size_t>& indices, size_t item_num) const
{
auto min_val = input(indices.begin(), indices.end());
int64_t min_index = 0;
for(std::size_t i = 1; i < item_num; ++i)
{
indices[axis] = i;
auto cur_val = input(indices.begin(), indices.end());
if(min_val > cur_val)
{
min_val = cur_val;
min_index = i;
}
}
return min_index;
}
// Reduces each output element's lane of the input along `axis`, in
// parallel over the output elements.
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
std::size_t batch_item_num = args.front().get_shape().lens()[axis];
result.visit([&](auto output) {
args[0].visit([&](auto input) {
par_for(output_shape.elements(), [&](auto i) {
// multi(i) has 0 at the reduced axis (its length is 1), which
// is the starting point calc_argmin expects.
auto data_idx = output_shape.multi(i);
output[i] = this->calc_argmin(input, data_idx, batch_item_num);
});
});
});
return result;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_ERF_HPP
#define MIGRAPHX_GUARD_OPERATORS_ERF_HPP
#include <migraphx/op/unary.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
/// Elementwise Gauss error function, computed with std::erf.
struct erf : unary<erf>
{
    // The unary<erf> base drives the elementwise iteration; apply()
    // supplies the per-element operation.
    auto apply() const
    {
        return [](auto v) { return std::erf(v); };
    }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_LOGSOFTMAX_HPP
#define MIGRAPHX_GUARD_OPERATORS_LOGSOFTMAX_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -35,14 +35,28 @@ struct multibroadcast
auto input = inputs.at(0);
if(input.lens().empty())
MIGRAPHX_THROW("inputs dimensions should be > 0");
{
MIGRAPHX_THROW("MULTIBROADCAST: inputs dimensions should be > 0");
}
if(input.lens().size() > output_lens.size())
MIGRAPHX_THROW("inputs dimensions should <= output size");
{
MIGRAPHX_THROW("MULTIBROADCAST: inputs dimensions should <= output size");
}
std::vector<size_t> bcast_strides(output_lens.size(), 0);
auto offset = output_lens.size() - input.lens().size();
for(std::ptrdiff_t i = input.lens().size() - 1; i >= 0; i--)
{
if(output_lens[i + offset] != input.lens()[i] and input.lens()[i] != 1)
{
MIGRAPHX_THROW("MULTIBROADCAST: input shape {" + to_string_range(input.lens()) +
"} cannot be broadcasted to {" + to_string_range(output_lens) +
"}!");
}
}
std::vector<size_t> bcast_strides(output_lens.size(), 0);
for(std::ptrdiff_t i = input.lens().size() - 1; i >= 0; i--)
{
if(output_lens[i + offset] == input.lens()[i])
{
......
#ifndef MIGRAPHX_GUARD_OPERATORS_SUM_HPP
#define MIGRAPHX_GUARD_OPERATORS_SUM_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
/// Operator summing the input over `axes`. Each reduced dimension is kept
/// in the output with length 1.
struct reduce_sum
{
// Axes to sum over; expected to be valid non-negative indices into the
// input's dimensions.
std::vector<std::size_t> axes;
// Exposes `axes` for reflection/serialization.
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axes, "axes"));
}
std::string name() const { return "reduce_sum"; }
// Output shape: same type and lens as the input, with every reduced axis
// collapsed to 1.
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
auto s = inputs.at(0);
auto lens = s.lens();
for(auto axis : axes)
lens[axis] = 1;
return {s.type(), lens};
}
// Accumulates every input element into the output element obtained by
// zeroing the reduced axes of its multi-index.
// NOTE(review): the += accumulation assumes argument{output_shape}
// zero-initializes its buffer — verify against argument's allocator.
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0])([&](auto output, auto input) {
shape_for_each(input.get_shape(), [&](auto&& in_idx) {
auto out_idx = in_idx;
for(auto axis : axes)
out_idx[axis] = 0;
output(out_idx.begin(), out_idx.end()) += input(in_idx.begin(), in_idx.end());
});
});
return result;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_SOFTMAX_HPP
#define MIGRAPHX_GUARD_OPERATORS_SOFTMAX_HPP
#include <array>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -5,6 +5,8 @@
#include <migraphx/op/abs.hpp>
#include <migraphx/op/acos.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/argmax.hpp>
#include <migraphx/op/argmin.hpp>
#include <migraphx/op/asin.hpp>
#include <migraphx/op/as_shape.hpp>
#include <migraphx/op/atan.hpp>
......@@ -22,6 +24,7 @@
#include <migraphx/op/div.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/elu.hpp>
#include <migraphx/op/erf.hpp>
#include <migraphx/op/exp.hpp>
#include <migraphx/op/flatten.hpp>
#include <migraphx/op/gather.hpp>
......@@ -44,6 +47,7 @@
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/quant_convolution.hpp>
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/op/reduce_sum.hpp>
#include <migraphx/op/relu.hpp>
#include <migraphx/op/reshape.hpp>
#include <migraphx/op/rnn.hpp>
......
......@@ -33,6 +33,10 @@ auto generic_find_impl(rank<0>, C&& c, const T& x)
return std::find(c.begin(), c.end(), x);
}
// Tag type used to select the all_of/any_of/none_of overloads that take
// no range at all (the vacuous cases).
struct empty
{
};
} // namespace detail
template <class C, class T>
......@@ -71,6 +75,12 @@ bool all_of(const std::initializer_list<T>& c, const Predicate& p)
return std::all_of(c.begin(), c.end(), p);
}
// all_of over an empty range is vacuously true; the predicate is never
// evaluated.
template <class Predicate>
bool all_of(detail::empty, const Predicate&)
{
return true;
}
template <class C, class Predicate>
bool any_of(const C& c, const Predicate& p)
{
......@@ -83,6 +93,12 @@ bool any_of(const std::initializer_list<T>& c, const Predicate& p)
return std::any_of(c.begin(), c.end(), p);
}
// any_of over an empty range is vacuously false; the predicate is never
// evaluated.
template <class Predicate>
bool any_of(detail::empty, const Predicate&)
{
return false;
}
template <class C, class Predicate>
bool none_of(const C& c, const Predicate& p)
{
......@@ -95,6 +111,12 @@ bool none_of(const std::initializer_list<T>& c, const Predicate& p)
return std::none_of(c.begin(), c.end(), p);
}
// none_of over an empty range is vacuously true; the predicate is never
// evaluated.
template <class Predicate>
bool none_of(detail::empty, const Predicate&)
{
return true;
}
template <class Range, class Iterator>
void copy(Range&& r, Iterator it)
{
......
......@@ -212,6 +212,25 @@ auto visit_all(T&& x, Ts&&... xs)
};
}
/// Visits all arguments in `x` with a single elementwise type: the
/// visitor `v` receives a vector of tensor_views over the arguments'
/// data. All elements must share the front element's type.
/// NOTE: the returned lambda captures `x` by reference, so it must be
/// invoked before `x` goes out of scope.
template <class T>
auto visit_all(const std::vector<T>& x)
{
    // Guard first: x.front() below is undefined on an empty vector.
    if(x.empty())
        MIGRAPHX_THROW("visit_all: nothing to visit");
    auto&& s = x.front().get_shape();
    if(!std::all_of(
           x.begin(), x.end(), [&](const T& y) { return y.get_shape().type() == s.type(); }))
        MIGRAPHX_THROW("Types must be the same");
    return [&](auto v) {
        s.visit_type([&](auto as) {
            using type = typename decltype(as)::type;
            std::vector<tensor_view<type>> result;
            result.reserve(x.size());
            std::transform(x.begin(), x.end(), std::back_inserter(result), [&](const auto& y) {
                return make_view(y.get_shape(), as.from(y.data()));
            });
            v(result);
        });
    };
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -99,6 +99,8 @@ struct shape
/// Map element index to space index
std::size_t index(std::size_t i) const;
std::vector<std::size_t> multi(std::size_t i) const;
/// Returns true if the shape is packed with no padding
bool packed() const;
/// Returns true is the shape has been transposed. That is the strides are not in descending
......
......@@ -40,6 +40,7 @@ struct onnx_parser
add_generic_op("Sigmoid", op::sigmoid{});
add_generic_op("Abs", op::abs{});
add_generic_op("Exp", op::exp{});
add_generic_op("Erf", op::erf{});
add_generic_op("Log", op::log{});
// disable dropout for inference
add_generic_op("Dropout", op::identity{});
......@@ -63,6 +64,8 @@ struct onnx_parser
add_variadic_op("Max", op::max{});
add_variadic_op("Min", op::min{});
add_mem_op("ArgMax", &onnx_parser::parse_argmax);
add_mem_op("ArgMin", &onnx_parser::parse_argmin);
add_mem_op("Clip", &onnx_parser::parse_clip);
add_mem_op("LRN", &onnx_parser::parse_lrn);
add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler);
......@@ -93,6 +96,7 @@ struct onnx_parser
add_mem_op("GRU", &onnx_parser::parse_gru);
add_mem_op("LSTM", &onnx_parser::parse_lstm);
add_mem_op("Pad", &onnx_parser::parse_pad);
add_mem_op("ReduceSum", &onnx_parser::parse_reduce_sum);
// init the activation function map
init_actv_func();
......@@ -182,7 +186,15 @@ struct onnx_parser
s0.end(),
s1.begin() + offset,
out_lens.begin() + offset,
[](auto a, auto b) { return std::max(a, b); });
[&](auto a, auto b) {
if(a != b and a != 1 and b != 1)
{
MIGRAPHX_THROW("COMPUTE_BROADCASTLEN: shape {" +
to_string_range(s0) + "} and {" +
to_string_range(s1) + "} mismatch!");
}
return std::max(a, b);
});
return out_lens;
}
......@@ -266,6 +278,60 @@ struct onnx_parser
return prog.add_instruction(op::logsoftmax{axis}, std::move(args));
}
// Parses an ONNX ArgMax node: emits an argmax instruction and, when
// keepdims=0, squeezes the reduced axis away to match ONNX semantics.
instruction_ref parse_argmax(const std::string&,
                             const attribute_map& attributes,
                             std::vector<instruction_ref> args)
{
    int64_t axis = 0;
    if(contains(attributes, "axis"))
        axis = static_cast<int64_t>(parse_value(attributes.at("axis")).at<int>());
    int keep_dims = 1;
    if(contains(attributes, "keepdims"))
        keep_dims = parse_value(attributes.at("keepdims")).at<int>();
    auto ins = prog.add_instruction(op::argmax{axis}, std::move(args));
    if(keep_dims == 0)
        return prog.add_instruction(op::squeeze{{axis}}, ins);
    return ins;
}
// Parses an ONNX ArgMin node: emits an argmin instruction and, when
// keepdims=0, squeezes the reduced axis away to match ONNX semantics.
instruction_ref parse_argmin(const std::string&,
                             const attribute_map& attributes,
                             std::vector<instruction_ref> args)
{
    int64_t axis = 0;
    if(contains(attributes, "axis"))
        axis = static_cast<int64_t>(parse_value(attributes.at("axis")).at<int>());
    int keep_dims = 1;
    if(contains(attributes, "keepdims"))
        keep_dims = parse_value(attributes.at("keepdims")).at<int>();
    auto ins = prog.add_instruction(op::argmin{axis}, std::move(args));
    if(keep_dims == 0)
        return prog.add_instruction(op::squeeze{{axis}}, ins);
    return ins;
}
instruction_ref
parse_conv(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
......@@ -1222,6 +1288,40 @@ struct onnx_parser
return {hidden_states, last_output, last_cell_output};
}
// Parses an ONNX ReduceSum node. Defaults to reducing over every
// dimension; honors the "axes" and "keepdims" attributes. When
// keepdims=0 the reduced dimensions are squeezed away afterwards.
instruction_ref parse_reduce_sum(const std::string&,
                                 attribute_map attributes,
                                 std::vector<instruction_ref> args)
{
    std::size_t n_dim = args.front()->get_shape().lens().size();
    // default to reduce over all dimensions
    std::vector<std::size_t> axes(n_dim);
    std::iota(axes.begin(), axes.end(), 0);
    if(contains(attributes, "axes"))
    {
        axes.clear();
        for(auto axis : attributes["axes"].ints())
        {
            // ONNX allows negative axes counting from the back; a
            // negative value cast straight to size_t would index far out
            // of range, so normalize it first.
            if(axis < 0)
                axis += static_cast<int64_t>(n_dim);
            axes.push_back(static_cast<std::size_t>(axis));
        }
    }
    int keep_dims = 1;
    if(contains(attributes, "keepdims"))
    {
        keep_dims = parse_value(attributes.at("keepdims")).at<int>();
    }
    auto ins = prog.add_instruction(op::reduce_sum{axes}, std::move(args));
    if(keep_dims == 1)
    {
        return ins;
    }
    std::vector<int64_t> squeeze_axes{axes.begin(), axes.end()};
    return prog.add_instruction(op::squeeze{squeeze_axes}, ins);
}
void parse_from(std::istream& is)
{
onnx::ModelProto model;
......
......@@ -2,7 +2,6 @@
#include <migraphx/pass_manager.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/target.hpp>
#include <migraphx/env.hpp>
#include <migraphx/ranges.hpp>
......
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/concat.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/gru.hpp>
#include <migraphx/op/lstm.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/op/rnn.hpp>
#include <migraphx/op/rnn_last_output.hpp>
#include <migraphx/op/slice.hpp>
#include <migraphx/op/squeeze.hpp>
#include <migraphx/op/sub.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/op/unsqueeze.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/op/common.hpp>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment