Unverified Commit 1b098fd7 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Merge branch 'develop' into type-string-driver

parents 05f2ee1c c0398ded
......@@ -16,7 +16,7 @@ struct manage_deleter
{
if(x != nullptr)
{
f(x);
(void)f(x);
}
}
};
......
#ifndef MIGRAPHX_GUARD_MARKER_HPP
#define MIGRAPHX_GUARD_MARKER_HPP
#include <cassert>
#include <string>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
#include <migraphx/config.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/program.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
#ifdef DOXYGEN
/// Marker is an interface to general marking functions, such as rocTX markers.
#else
#ifdef TYPE_ERASED_DECLARATION
// Type-erased interface for:
struct marker
{
//
void mark_start(instruction_ref ins_ref);
//
void mark_start(const program& prog);
//
void mark_stop(instruction_ref ins);
//
void mark_stop(const program& prog);
};
#else
struct marker
{
// Constructors
marker() = default;
template <typename PrivateDetailTypeErasedT>
marker(PrivateDetailTypeErasedT value)
: private_detail_te_handle_mem_var(
std::make_shared<private_detail_te_handle_type<
typename std::remove_reference<PrivateDetailTypeErasedT>::type>>(
std::forward<PrivateDetailTypeErasedT>(value)))
{
}
// Assignment
template <typename PrivateDetailTypeErasedT>
marker& operator=(PrivateDetailTypeErasedT value)
{
using std::swap;
auto* derived = this->any_cast<PrivateDetailTypeErasedT>();
if(derived and private_detail_te_handle_mem_var.unique())
{
*derived = std::forward<PrivateDetailTypeErasedT>(value);
}
else
{
marker rhs(value);
swap(private_detail_te_handle_mem_var, rhs.private_detail_te_handle_mem_var);
}
return *this;
}
// Cast
template <typename PrivateDetailTypeErasedT>
PrivateDetailTypeErasedT* any_cast()
{
return this->type_id() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle())
.private_detail_te_value)
: nullptr;
}
template <typename PrivateDetailTypeErasedT>
const typename std::remove_cv<PrivateDetailTypeErasedT>::type* any_cast() const
{
return this->type_id() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<const private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle())
.private_detail_te_value)
: nullptr;
}
const std::type_info& type_id() const
{
if(private_detail_te_handle_empty())
return typeid(std::nullptr_t);
else
return private_detail_te_get_handle().type();
}
void mark_start(instruction_ref ins_ref)
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().mark_start(ins_ref);
}
void mark_start(const program& prog)
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().mark_start(prog);
}
void mark_stop(instruction_ref ins)
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().mark_stop(ins);
}
void mark_stop(const program& prog)
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().mark_stop(prog);
}
friend bool is_shared(const marker& private_detail_x, const marker& private_detail_y)
{
return private_detail_x.private_detail_te_handle_mem_var ==
private_detail_y.private_detail_te_handle_mem_var;
}
private:
struct private_detail_te_handle_base_type
{
virtual ~private_detail_te_handle_base_type() {}
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0;
virtual void mark_start(instruction_ref ins_ref) = 0;
virtual void mark_start(const program& prog) = 0;
virtual void mark_stop(instruction_ref ins) = 0;
virtual void mark_stop(const program& prog) = 0;
};
template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type : private_detail_te_handle_base_type
{
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type(
PrivateDetailTypeErasedT value,
typename std::enable_if<std::is_reference<PrivateDetailTypeErasedU>::value>::type* =
nullptr)
: private_detail_te_value(value)
{
}
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type(
PrivateDetailTypeErasedT value,
typename std::enable_if<!std::is_reference<PrivateDetailTypeErasedU>::value,
int>::type* = nullptr) noexcept
: private_detail_te_value(std::move(value))
{
}
std::shared_ptr<private_detail_te_handle_base_type> clone() const override
{
return std::make_shared<private_detail_te_handle_type>(private_detail_te_value);
}
const std::type_info& type() const override { return typeid(private_detail_te_value); }
void mark_start(instruction_ref ins_ref) override
{
private_detail_te_value.mark_start(ins_ref);
}
void mark_start(const program& prog) override { private_detail_te_value.mark_start(prog); }
void mark_stop(instruction_ref ins) override { private_detail_te_value.mark_stop(ins); }
void mark_stop(const program& prog) override { private_detail_te_value.mark_stop(prog); }
PrivateDetailTypeErasedT private_detail_te_value;
};
template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type<std::reference_wrapper<PrivateDetailTypeErasedT>>
: private_detail_te_handle_type<PrivateDetailTypeErasedT&>
{
private_detail_te_handle_type(std::reference_wrapper<PrivateDetailTypeErasedT> ref)
: private_detail_te_handle_type<PrivateDetailTypeErasedT&>(ref.get())
{
}
};
bool private_detail_te_handle_empty() const
{
return private_detail_te_handle_mem_var == nullptr;
}
const private_detail_te_handle_base_type& private_detail_te_get_handle() const
{
assert(private_detail_te_handle_mem_var != nullptr);
return *private_detail_te_handle_mem_var;
}
private_detail_te_handle_base_type& private_detail_te_get_handle()
{
assert(private_detail_te_handle_mem_var != nullptr);
if(!private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var;
}
std::shared_ptr<private_detail_te_handle_base_type> private_detail_te_handle_mem_var;
};
template <typename ValueType>
inline const ValueType* any_cast(const marker* x)
{
return x->any_cast<ValueType>();
}
template <typename ValueType>
inline ValueType* any_cast(marker* x)
{
return x->any_cast<ValueType>();
}
template <typename ValueType>
inline ValueType& any_cast(marker& x)
{
auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
if(y == nullptr)
throw std::bad_cast();
return *y;
}
template <typename ValueType>
inline const ValueType& any_cast(const marker& x)
{
const auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
if(y == nullptr)
throw std::bad_cast();
return *y;
}
#endif
#endif
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_MATCH_GELU_ERF_HPP
#define MIGRAPHX_GUARD_MATCH_GELU_ERF_HPP
#include <migraphx/config.hpp>
#include <migraphx/matcher.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace match {
namespace detail {
template <class F>
struct gelu_erf_matcher
{
F f;
auto erf_fn() const
{
return f("erf")(
used_once(),
arg(0)(used_once(),
f("mul")(either_arg(0, 1)(none_of(has_value(M_SQRT1_2, 1e-3)).bind("x"),
has_value(M_SQRT1_2, 1e-3)))));
}
auto add_erf() const
{
return f("add")(used_once(), either_arg(0, 1)(erf_fn(), has_value(1.0f)));
}
auto one_half() const { return has_value(0.5f); }
auto matcher() const { return unordered_tree(f("mul"), one_half(), add_erf(), any()); }
};
} // namespace detail
template <class F>
auto gelu_erf(F f)
{
return detail::gelu_erf_matcher<F>{f}.matcher();
}
inline auto gelu_erf()
{
return gelu_erf([](auto x) { return name(x); });
}
} // namespace match
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MATCH_GELU_ERF_HPP
#ifndef MIGRAPHX_GUARD_MATCH_GELU_TANH_HPP
#define MIGRAPHX_GUARD_MATCH_GELU_TANH_HPP
#include <migraphx/config.hpp>
#include <migraphx/matcher.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace match {
namespace detail {
template <class F>
struct gelu_tanh_matcher
{
F f;
auto pow_fn() const { return f("pow")(used_once(), arg(1)(has_value(3.0f))); }
auto tanh_fn() const
{
return f("tanh")(
used_once(),
arg(0)(f("mul")(either_arg(0, 1)(has_value(sqrt(M_2_PI), 1e-3),
f("add")(any_arg(0, 1)(f("mul")(either_arg(0, 1)(
has_value(0.044715f), pow_fn()))))))));
}
auto matcher() const
{
return f("mul")(used_once(),
either_arg(0, 1)(any().bind("x"),
f("add")(any_arg(0, 1)(f("mul")(
either_arg(0, 1)(has_value(0.5f), tanh_fn()))))));
}
};
} // namespace detail
template <class F>
auto gelu_tanh(F f)
{
return detail::gelu_tanh_matcher<F>{f}.matcher();
}
inline auto gelu_tanh()
{
return gelu_tanh([](auto x) { return name(x); });
}
} // namespace match
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MATCH_GELU_TANH_HPP
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_MATCH_LAYERNORM_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_MATCH_LAYERNORM_HPP
#include <migraphx/config.hpp>
#include <migraphx/matcher.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace match {
namespace detail {
template <class F>
struct layernorm_matcher
{
F f;
auto x_minus_mean() const
{
return f("sub")(arg(0)(any().bind("x")), arg(1)(skip_broadcasts(f("reduce_mean"))));
}
auto variance() const
{
return f("reduce_mean")(arg(0)(f("pow")(arg(0)(x_minus_mean()), arg(1)(has_value(2.0f)))));
}
auto layernorm_onnx() const
{
return f("div")(arg(0)(x_minus_mean()),
arg(1)(skip_broadcasts(f("sqrt")(
arg(0)(f("add")(either_arg(0, 1)(variance(), has_value(1e-12f))))))));
}
auto matcher() const { return layernorm_onnx(); }
};
} // namespace detail
template <class F>
auto layernorm(F f)
{
return detail::layernorm_matcher<F>{f}.matcher();
}
inline auto layernorm()
{
return layernorm([](auto x) { return name(x); });
}
} // namespace match
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -4,8 +4,10 @@
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
#include <migraphx/module.hpp>
#include <migraphx/optional.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/type_name.hpp>
#include <migraphx/config.hpp>
#include <unordered_map>
#include <unordered_set>
......@@ -17,18 +19,51 @@ namespace match {
struct matcher_context
{
matcher_context(instruction_ref i) : last(i) {}
matcher_context(module& m) : mod(&m) {}
std::unordered_map<std::string, instruction_ref> instructions;
instruction_ref not_found() const { return last; }
template <class M>
bool matched(M m, instruction_ref ins)
{
return m.match(*this, ins) != this->not_found();
return has_value(m.match(*this, ins));
}
template <class M>
bool matched(M m, optional<instruction_ref> ins)
{
if(ins)
return has_value(m.match(*this, *ins));
return false;
}
template <class M, class I>
auto lazy_match(M m, I ins)
{
return [=] { return this->matched(m, ins); };
}
bool has_instruction(instruction_ref ins) const
{
if(mod == nullptr)
return true;
return mod->has_instruction(ins);
}
bool has_instruction(optional<instruction_ref> ins) const
{
if(ins)
return this->has_instruction(*ins);
return false;
}
bool is_last(instruction_ref ins) const
{
assert(mod->begin() != mod->end());
assert(this->has_instruction(ins));
return ins == std::prev(mod->end());
}
private:
instruction_ref last;
module* mod = nullptr;
};
/// Convert a predicate function into a matcher
......@@ -37,12 +72,11 @@ struct predicate_matcher
{
P p;
instruction_ref match(matcher_context& ctx, instruction_ref ins) const
optional<instruction_ref> match(const matcher_context&, instruction_ref ins) const
{
assert(ins != ctx.not_found());
if(p(ins))
return ins;
return ctx.not_found();
return optional<instruction_ref>{ins};
return nullopt;
}
};
......@@ -52,11 +86,7 @@ struct function_matcher
{
F f;
instruction_ref match(matcher_context& ctx, instruction_ref ins) const
{
assert(ins != ctx.not_found());
return f(ctx, ins);
}
auto match(matcher_context& ctx, instruction_ref ins) const { return f(ctx, ins); }
};
/// Convert a function into a matcher
......@@ -71,10 +101,15 @@ template <class M>
auto bind_match(M m, std::string name)
{
return make_function_matcher(
[ =, name = std::move(name) ](matcher_context & ctx, instruction_ref ins) {
[=, name = std::move(name)](matcher_context& ctx,
instruction_ref ins) -> optional<instruction_ref> {
auto result = m.match(ctx, ins);
if(result != ctx.not_found())
if(result)
{
if(not ctx.has_instruction(ins))
return nullopt;
ctx.instructions[name] = ins;
}
return result;
});
}
......@@ -87,10 +122,7 @@ struct bindable_matcher
auto bind(std::string name) const { return bind_match(m, std::move(name)); }
instruction_ref match(matcher_context& ctx, instruction_ref ins) const
{
return m.match(ctx, ins);
}
auto match(matcher_context& ctx, instruction_ref ins) const { return m.match(ctx, ins); }
};
/// Create a bindable matcher
......@@ -118,9 +150,25 @@ using bool_list = std::initializer_list<bool>;
struct id_matcher
{
instruction_ref match(matcher_context&, instruction_ref ins) const { return ins; }
auto match(matcher_context&, instruction_ref ins) const
{
return optional<instruction_ref>{ins};
}
};
// Forward declare class and constructors
template <class M>
struct basic_matcher;
template <class M>
basic_matcher<M> make_basic_matcher(M m);
template <class F>
basic_matcher<function_matcher<F>> make_basic_fun_matcher(F f);
template <class P>
basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p);
/// The basic matcher provides the all_of composability of the matcher
template <class M>
struct basic_matcher
......@@ -132,26 +180,23 @@ struct basic_matcher
{
// Copy m because we cant capture `this` by value
auto mm = m;
return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
return make_basic_fun_matcher([=](matcher_context& ctx,
instruction_ref ins) -> optional<instruction_ref> {
auto result = mm.match(ctx, ins);
if(result != ctx.not_found())
if(result)
{
bool matches = fold([&](auto x, auto y) {
return x and y.match(ctx, result) != ctx.not_found();
})(true, ms...);
bool matches =
fold([&](auto x, auto y) { return x and ctx.matched(y, result); })(true, ms...);
if(matches)
return result;
}
return ctx.not_found();
return nullopt;
});
}
auto bind(std::string name) const { return bind_match(m, std::move(name)); }
instruction_ref match(matcher_context& ctx, instruction_ref ins) const
{
return m.match(ctx, ins);
}
auto match(matcher_context& ctx, instruction_ref ins) const { return m.match(ctx, ins); }
};
/// Create a basic matcher from a matcher
......@@ -175,14 +220,25 @@ basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p)
return {{p}};
}
/// Create a typed-erased matcher
using any_matcher_base = basic_matcher<
function_matcher<std::function<optional<instruction_ref>(matcher_context&, instruction_ref)>>>;
struct any_matcher : any_matcher_base
{
template <class M>
any_matcher(M mm) : any_matcher_base({[=](auto& ctx, auto ins) { return mm.match(ctx, ins); }})
{
}
};
/// This macro takes care of the boilerplate for defining a matcher
#define MIGRAPHX_BASIC_MATCHER(name, ...) \
struct name##_m \
{ \
instruction_ref match(__VA_ARGS__) const; \
optional<instruction_ref> match(__VA_ARGS__) const; \
}; \
const constexpr auto name = migraphx::match::basic_matcher<name##_m>{{}}; \
inline instruction_ref name##_m::match(__VA_ARGS__) const
inline optional<instruction_ref> name##_m::match(__VA_ARGS__) const
/// This macro takes care of the boilerplate for defining a predicate matcher
#define MIGRAPHX_PRED_MATCHER(name, ...) \
......@@ -196,57 +252,139 @@ basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p)
struct matcher_result
{
std::unordered_map<std::string, instruction_ref> instructions;
struct instruction_container
{
instruction_container() = default;
instruction_container(std::unordered_map<std::string, instruction_ref> x)
: ins_map(std::move(x))
{
}
instruction_ref operator[](const std::string& name) const
{
auto it = ins_map.find(name);
if(it == ins_map.end())
MIGRAPHX_THROW("Accessing name that wasn't bound in matcher: " + name);
return it->second;
}
auto find(const std::string& name) const { return ins_map.find(name); }
auto begin() const { return ins_map.cbegin(); }
auto end() const { return ins_map.cend(); }
bool has_instructions_in(const module& mod) const
{
return std::all_of(ins_map.begin(), ins_map.end(), [&](auto&& p) {
return mod.has_instruction(p.second);
});
}
private:
std::unordered_map<std::string, instruction_ref> ins_map;
};
instruction_container instructions;
instruction_ref result;
};
/// Match a single instruction
template <class M>
matcher_result match_instruction(program& p, instruction_ref ins, M&& m)
matcher_result match_instruction(module& mod, instruction_ref ins, M&& m)
{
assert(ins != p.end());
assert(ins != mod.end());
assert(mod.has_instruction(ins));
matcher_context ctx{mod};
matcher_result result;
matcher_context ctx{p.end()};
result.result = m.match(ctx, ins);
result.instructions = ctx.instructions;
if(m.match(ctx, ins))
{
result.result = ins;
result.instructions = ctx.instructions;
assert(result.instructions.has_instructions_in(mod));
}
else
{
result.result = mod.end();
}
return result;
}
/// Find first instance of a matching instruction in a module
template <class M>
match::matcher_result find_match(module& modl, M&& m)
{
match::matcher_result result;
for(auto ins : iterator_for(modl))
{
result = match::match_instruction(modl, ins, m);
if(result.result != modl.end())
return result;
}
return result;
}
/// Find matches for an instruction in the program
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES)
/// Find matches for an instruction in the module
template <class... Ms>
void find_matches(program& p, instruction_ref ins, Ms&&... ms)
void find_matches(module& mod, instruction_ref ins, Ms&&... ms)
{
bool match = false;
#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5
const
#endif
bool trace = enabled(MIGRAPHX_TRACE_MATCHES{});
bool match = false;
each_args(
[&](auto&& m) {
if(match)
return;
auto r = match_instruction(p, ins, m.matcher());
if(r.result == p.end())
auto r = match_instruction(mod, ins, m.matcher());
if(r.result == mod.end())
return;
m.apply(p, r);
if(trace)
{
std::cout << "Matched by " << get_type_name(m) << std::endl;
mod.debug_print(ins);
}
m.apply(mod, r);
match = true;
},
ms...);
}
/// Find matches in a program
/// Find matches in a module
template <class... Ms>
void find_matches(program& p, Ms&&... ms)
void find_matches(module& mod, Ms&&... ms)
{
for(auto ins : iterator_for(p))
for(auto ins : iterator_for(mod))
{
find_matches(p, ins, ms...);
find_matches(mod, ins, ms...);
}
}
template <class M, class F>
struct find_generic_match
{
M m;
F f;
M matcher() const { return m; }
void apply(module& mod, const matcher_result& mr) const { f(mod, mr); }
};
template <class M, class F>
find_generic_match<M, F> make_match_finder(M m, F f)
{
return {m, f};
}
template <class M>
struct find_skip
{
M m;
M matcher() const { return m; }
void apply(program&, const matcher_result&) const {}
void apply(module&, const matcher_result&) const {}
};
template <class M>
......@@ -258,18 +396,18 @@ find_skip<M> make_find_skip(M m)
struct lazy_and
{
template <class F, class G>
bool operator()(F f, G g) const
auto operator()(F f, G g) const
{
return f() and g();
return [=] { return f() and g(); };
}
};
struct lazy_or
{
template <class F, class G>
bool operator()(F f, G g) const
auto operator()(F f, G g) const
{
return f() or g();
return [=] { return f() or g(); };
}
};
......@@ -281,7 +419,7 @@ struct match_fold_f
{
Op op;
auto matched = [&](auto m) { return [=, &ctx] { return ctx.matched(m, ins); }; };
return fold([&](auto x, auto y) { return op(always(x), matched(y)); })(Start, ms...);
return fold(op)(always(Start), matched(ms)...)();
}
template <class Pack>
......@@ -293,12 +431,13 @@ struct match_fold_f
template <class... Ts>
auto operator()(Ts... ms) const
{
return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
bool matches = match_fold_f::fold_matchers(ctx, ins, ms...);
if(matches == Matches)
return ins;
return ctx.not_found();
});
return make_bf_matcher(
[=](matcher_context& ctx, instruction_ref ins) -> optional<instruction_ref> {
bool matches = match_fold_f::fold_matchers(ctx, ins, ms...);
if(matches == Matches)
return {ins};
return nullopt;
});
}
template <class Selector>
......@@ -307,17 +446,18 @@ struct match_fold_f
return [=](auto... ms) {
// Workaround ICE on gcc by packing matchers into an object
auto mpack = pack(ms...);
return make_bf_matcher([=](matcher_context& ctx, instruction_ref start) {
Op op;
bool matches = Start;
select(start, [&](auto ins) {
auto fm = [&] { return match_fold_f::fold_matchers_pack(ctx, ins, mpack); };
matches = op(always(matches), fm);
return make_bf_matcher(
[=](matcher_context& ctx, instruction_ref start) -> optional<instruction_ref> {
Op op;
bool matches = Start;
select(start, [&](auto ins) {
auto fm = [&] { return match_fold_f::fold_matchers_pack(ctx, ins, mpack); };
matches = op(always(matches), fm)();
});
if(matches == Matches)
return {start};
return nullopt;
});
if(matches == Matches)
return start;
return ctx.not_found();
});
};
}
};
......@@ -374,64 +514,46 @@ MIGRAPHX_PRED_MATCHER(same_input_shapes, instruction_ref ins)
ins->inputs().begin(), ins->inputs().end(), [&](auto x) { return x->get_shape() == s; });
}
MIGRAPHX_BASIC_MATCHER(output, const matcher_context& ctx, instruction_ref ins)
MIGRAPHX_BASIC_MATCHER(output, const matcher_context&, instruction_ref ins)
{
if(ins->outputs().size() == 1)
return ins->outputs().front();
return ctx.not_found();
return {ins->outputs().front()};
return nullopt;
}
MIGRAPHX_BASIC_MATCHER(used_once, const matcher_context& ctx, instruction_ref ins)
{
if(ins->outputs().size() == 1)
return ins;
if(ins->outputs().empty() and std::next(ins) == ctx.not_found())
return ins;
return ctx.not_found();
}
inline auto used_once_recursive(std::size_t depth)
{
return make_basic_fun_matcher([=](const matcher_context& ctx, instruction_ref start) {
// Used once
if(start->outputs().size() == 1)
return start;
// Unused
if(start->outputs().empty())
{
if(std::next(start) == ctx.not_found())
return start;
else
return ctx.not_found();
}
// Check for dead instructions
auto is_dead = fix<bool>([&](auto self, auto ins, auto n) {
if(n == 0)
return false;
if(ins->get_shape().elements() == 0)
return false;
if(ins->outputs().empty() and std::next(ins) != ctx.not_found())
return true;
return std::all_of(ins->outputs().begin(), ins->outputs().end(), [&](auto i) {
return self(i, n - 1);
});
});
auto dead = std::count_if(start->outputs().begin(), start->outputs().end(), [&](auto i) {
return is_dead(i, depth);
});
if(dead + 1 == start->outputs().size())
return start;
return ctx.not_found();
});
return {ins};
if(ins->outputs().empty() and ctx.is_last(ins))
return {ins};
return nullopt;
}
MIGRAPHX_PRED_MATCHER(is_constant, instruction_ref ins) { return ins->can_eval(); }
MIGRAPHX_BASIC_MATCHER(is_unused, const matcher_context& ctx, instruction_ref ins)
{
if(ins->outputs().empty() and ins != std::prev(ctx.not_found()))
return ins;
return ctx.not_found();
if(ins->outputs().empty() and not ctx.is_last(ins))
return {ins};
return nullopt;
}
template <class... Ms>
auto skip(Ms... ms)
{
auto m = any_of(ms...);
return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref start) {
return fix<optional<instruction_ref>>(
[&](auto self, auto ins) -> optional<instruction_ref> {
if(ins->inputs().size() == 1 and ctx.matched(m, ins))
{
auto next = ins->inputs().front();
return self(next);
}
return ins;
})(start);
});
}
template <class... Ms>
......@@ -439,32 +561,51 @@ auto skip_output(Ms... ms)
{
auto m = any_of(ms...);
return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref start) {
return fix<instruction_ref>([&](auto self, auto ins) {
if(ins->outputs().size() == 1)
{
auto next = ins->outputs().front();
if(ctx.matched(m, next))
return fix<optional<instruction_ref>>(
[&](auto self, auto ins) -> optional<instruction_ref> {
if(ins->outputs().size() == 1)
{
auto skipped_next = self(next);
if(skipped_next != ctx.not_found())
return skipped_next;
auto next = ins->outputs().front();
if(ctx.matched(m, next))
{
auto skipped_next = self(next);
if(skipped_next)
return skipped_next;
}
return next;
}
return next;
}
return ctx.not_found();
})(start);
return nullopt;
})(start);
});
}
inline auto var(std::string s)
{
return make_basic_fun_matcher(
[=, s = std::move(s)](const matcher_context& ctx,
instruction_ref) -> optional<instruction_ref> {
auto it = ctx.instructions.find(s);
if(it == ctx.instructions.end())
return nullopt;
return it->second;
});
}
inline auto name(std::string s)
{
return make_basic_pred_matcher(
[ =, s = std::move(s) ](instruction_ref ins) { return ins->name() == s; });
[=, s = std::move(s)](instruction_ref ins) { return ins->name() == s; });
}
inline auto name_contains(const std::string& name)
{
return make_basic_pred_matcher(
[=](instruction_ref ins) { return contains(ins->get_operator().name(), name); });
}
inline auto name(std::unordered_set<std::string> names)
{
return make_basic_pred_matcher([ =, names = std::move(names) ](instruction_ref ins) {
return make_basic_pred_matcher([=, names = std::move(names)](instruction_ref ins) {
return names.count(ins->name()) > 0;
});
}
......@@ -482,11 +623,12 @@ inline auto nargs(std::size_t n)
inline auto arg(std::size_t i)
{
return make_basic_fun_matcher([=](const matcher_context& ctx, instruction_ref ins) {
if(i < ins->inputs().size())
return ins->inputs()[i];
return ctx.not_found();
});
return make_basic_fun_matcher(
[=](const matcher_context&, instruction_ref ins) -> optional<instruction_ref> {
if(i < ins->inputs().size())
return ins->inputs()[i];
return nullopt;
});
}
// Workaround for bugs in clang
......@@ -518,15 +660,86 @@ inline auto either_arg(std::size_t i, std::size_t j)
};
}
inline auto any_arg(std::size_t i, std::size_t j)
{
return [=](auto m) { return match::any_of(arg(i)(m), arg(j)(m)); };
}
template <std::size_t N, class M>
std::size_t tree_leafs_impl(matcher_context& ctx,
std::array<instruction_ref, N>& leafs,
M m,
instruction_ref ins)
{
std::size_t idx = 0;
fix([&](auto self, auto i) {
if(idx == leafs.size())
return;
if(ctx.matched(m, i) and i->inputs().size() >= 2)
{
self(i->inputs()[0]);
self(i->inputs()[1]);
return;
}
leafs[idx] = i;
idx++;
})(ins);
return idx;
}
template <class M, class... Ms>
auto tree(M main_op, Ms... ms)
{
return make_basic_fun_matcher(
[=](matcher_context& ctx, instruction_ref ins) -> optional<instruction_ref> {
// Flatten leaf nodes
std::array<instruction_ref, sizeof...(Ms)> leafs;
std::size_t idx = tree_leafs_impl(ctx, leafs, main_op, ins);
if(idx != leafs.size())
return nullopt;
// Use explicit captures to workaround ICE on gcc
// Capture by value to workaround compile error on gcc 9
bool found = sequence_c<sizeof...(Ms)>([ms..., &ctx, &leafs](auto... is) {
return fold(lazy_and{})(ctx.lazy_match(ms, leafs[is])...)();
});
if(not found)
return nullopt;
return ins;
});
}
template <class M, class... Ms>
auto unordered_tree(M main_op, Ms... ms)
{
return make_basic_fun_matcher(
[=](matcher_context& ctx, instruction_ref ins) -> optional<instruction_ref> {
// Flatten leaf nodes
std::array<instruction_ref, sizeof...(Ms)> leafs;
std::size_t idx = tree_leafs_impl(ctx, leafs, main_op, ins);
if(idx != leafs.size())
return nullopt;
// Use explicit captures to workaround ICE on gcc
bool found = sequence_c<sizeof...(Ms)>([ms..., &ctx, &leafs](auto... is) {
return by(fold(lazy_and{}), [is..., &ctx, &leafs](auto m) {
return fold(lazy_or{})(ctx.lazy_match(m, leafs[is])...);
})(ms...)();
});
if(not found)
return nullopt;
return ins;
});
}
template <class M>
auto same_shape(M m)
{
return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) {
auto i = m.match(ctx, ins);
if(i != ctx.not_found() and i->get_shape() == ins->get_shape())
return ins;
return ctx.not_found();
});
return make_basic_fun_matcher(
[=](matcher_context& ctx, instruction_ref ins) -> optional<instruction_ref> {
auto i = m.match(ctx, ins);
if(i and (*i)->get_shape() == ins->get_shape())
return ins;
return nullopt;
});
}
template <class... Ms>
......@@ -535,6 +748,50 @@ auto same_shape(Ms... ms)
return all_of(same_shape(ms)...);
}
template <class... Ms>
auto skip_broadcasts(Ms... ms)
{
return skip(name("broadcast", "multibroadcast", "contiguous"))(ms...);
}
template <class... Ms>
auto skip_broadcasts_converts(Ms... ms)
{
return skip(name("broadcast", "multibroadcast", "contiguous", "convert"))(ms...);
}
template <class T>
inline auto has_value(T x, float tolerance = 1e-6)
{
return skip_broadcasts_converts(make_basic_pred_matcher([=](instruction_ref ins) {
if(ins->name() != "@literal")
return false;
auto l = ins->get_literal();
if(l.empty())
return false;
bool b = false;
l.visit([&](auto v) {
if(std::all_of(
v.begin(), v.end(), [&](auto val) { return std::fabs(val - x) < tolerance; }))
b = true;
});
return b;
}));
}
inline auto has_attribute(const std::string& name)
{
return make_basic_pred_matcher(
[=](instruction_ref ins) { return ins->get_operator().attributes().contains(name); });
}
template <class... Ms>
auto pointwise(Ms... ms)
{
return match::has_attribute("pointwise")(match::any_of(match::nargs(1), match::nargs(2)),
ms...);
}
} // namespace match
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -7,7 +7,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
struct module;
/**
* Remove memory allocations. It uses graph coloring to find memory allocations that can be reused.
......@@ -17,7 +17,7 @@ struct memory_coloring
std::string allocation_op{};
bool verify = false;
std::string name() const { return "memory coloring"; }
void apply(program& p) const;
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
......
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_MODULE_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_MODULE_HPP
#include <list>
#include <unordered_set>
#include <unordered_map>
#include <migraphx/operation.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/builtin.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/target.hpp>
#include <migraphx/module_ref.hpp>
#include <migraphx/compile_options.hpp>
#include <migraphx/env.hpp>
#include <migraphx/config.hpp>
#include <algorithm>
#include <iostream>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
const operation& get_operation(instruction_ref ins);
struct module_impl;
using parameter_map = std::unordered_map<std::string, argument>;
using ins_dep_map = std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>>;
/**
* @brief Stores the instruction stream
*/
struct module
{
module(const std::string& name = "");
// move constructor
module(module&&) noexcept;
// copy constructor
module(const module&);
// copy assignment operator
module& operator=(module);
~module() noexcept;
std::string name() const;
bool bypass() const;
void set_bypass(bool b = true);
template <class... Ts, MIGRAPHX_REQUIRES(std::is_same<Ts, instruction_ref>{}...)>
instruction_ref add_instruction(operation op, Ts... args)
{
return add_instruction(op, {args...});
}
instruction_ref add_instruction(const operation& op, std::vector<instruction_ref> args);
instruction_ref add_instruction(const operation& op,
std::vector<instruction_ref> args,
std::vector<module_ref> module_args);
template <class... Ts, MIGRAPHX_REQUIRES(std::is_same<Ts, instruction_ref>{}...)>
instruction_ref insert_instruction(instruction_ref ins, operation op, Ts... args)
{
return insert_instruction(ins, op, {args...});
}
instruction_ref
insert_instruction(instruction_ref ins, const operation& op, std::vector<instruction_ref> args);
instruction_ref insert_instruction(instruction_ref ins,
const operation& op,
std::vector<instruction_ref> args,
std::vector<module_ref> module_args);
template <class... Ts, MIGRAPHX_REQUIRES(std::is_same<Ts, instruction_ref>{}...)>
instruction_ref replace_instruction(instruction_ref ins, operation op, Ts... args)
{
return replace_instruction(ins, op, {args...});
}
instruction_ref replace_instruction(instruction_ref ins,
const operation& op,
std::vector<instruction_ref> args) MIGRAPHX_TIDY_CONST;
instruction_ref replace_instruction(instruction_ref ins,
const operation& op,
std::vector<instruction_ref> args,
std::vector<module_ref> module_args) MIGRAPHX_TIDY_CONST;
instruction_ref replace_instruction(instruction_ref ins, instruction_ref rep);
instruction_ref remove_instruction(instruction_ref ins);
instruction_ref remove_instructions(instruction_ref first, instruction_ref last);
instruction_ref move_instruction(instruction_ref src, instruction_ref dst);
instruction_ref move_instructions(instruction_ref src, instruction_ref dst);
std::vector<instruction_ref>
insert_module_instructions(instruction_ref ins,
module_ref m,
std::unordered_map<instruction_ref, instruction_ref> map_ins = {});
template <class... Ts>
instruction_ref add_literal(Ts&&... xs)
{
return add_literal(literal{std::forward<Ts>(xs)...});
}
instruction_ref add_literal(literal l);
instruction_ref add_outline(const shape& s);
instruction_ref add_parameter(std::string name, shape s);
instruction_ref add_return(std::vector<instruction_ref> args);
instruction_ref replace_return(std::vector<instruction_ref> args);
std::vector<std::string> get_parameter_names() const;
shape get_parameter_shape(std::string name) const;
instruction_ref get_parameter(std::string name) const;
std::unordered_map<std::string, shape> get_parameter_shapes() const;
bool has_instruction(instruction_ref ins) const;
std::size_t size() const;
instruction_ref begin() const;
instruction_ref end() const;
std::vector<shape> get_output_shapes() const;
instruction_ref validate() const;
instruction_ref find_dangling_reference() const;
void finalize(context& ctx);
void debug_print() const;
void debug_print(instruction_ref ins) const;
void debug_print(instruction_ref ins,
std::unordered_map<instruction_ref, std::string>& names) const;
void debug_print(const std::vector<instruction_ref>& inss) const;
std::unordered_map<instruction_ref, std::string> print(
const std::function<void(
instruction_ref, const std::unordered_map<instruction_ref, std::string>&)>& print_func,
std::unordered_map<instruction_ref, std::string> names) const;
void print(const std::function<void(instruction_ref,
const std::unordered_map<instruction_ref, std::string>&)>&
print_func) const;
void print_graph(std::ostream& os, bool brief = false) const;
void print_cpp(std::ostream& os) const;
std::unordered_map<instruction_ref, std::string>
print_cpp(std::ostream& os, std::unordered_map<instruction_ref, std::string> names) const;
void annotate(std::ostream& os, std::function<void(instruction_ref)> a) const;
std::vector<module_ref> get_sub_modules() const;
module& sort();
ins_dep_map calc_implicit_deps() const;
friend std::ostream& operator<<(std::ostream& os, const module& m);
friend bool operator==(const module& x, const module& y);
friend bool operator!=(const module& x, const module& y) { return !(x == y); }
private:
void assign(const module& m);
void calc_implicit_deps(const module& smod,
const module& pmod,
instruction_ref ins,
ins_dep_map& deps) const;
std::unique_ptr<module_impl> impl;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_MODULE_REF_HPP
#define MIGRAPHX_GUARD_MODULE_REF_HPP
#include <list>
#include <functional>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
using module_ref = module*;
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_MSGPACK_HPP
#define MIGRAPHX_GUARD_RTGLIB_MSGPACK_HPP
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
std::vector<char> to_msgpack(const value& v);
value from_msgpack(const std::vector<char>& buffer);
value from_msgpack(const char* buffer, std::size_t size);
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_NORMALIZE_ATTRIBUTES_HPP
#define MIGRAPHX_GUARD_RTGLIB_NORMALIZE_ATTRIBUTES_HPP
#include <migraphx/config.hpp>
#include <migraphx/shape.hpp>
#include <cstring>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct operation;
template <class T, class...>
struct select_dependent_type
{
using type = T;
};
template <class T, class... Ts>
using dependent_type = typename select_dependent_type<T, Ts...>::type;
bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens);
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_NORMALIZE_OPS_HPP
#define MIGRAPHX_GUARD_RTGLIB_NORMALIZE_OPS_HPP
#include <string>
#include <vector>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/program.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
/**
* Process negative axis attributes of ops
*/
struct normalize_ops
{
std::string name() const { return "normalize_ops"; }
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -7,8 +7,31 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
/// struct to pass in onnx options to parser
struct onnx_options
{
/// default batch size to use (if not specified in onnx file)
std::size_t default_dim_value = 1;
/// Explicitly specify the dims of an input
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims = {};
/// Continue parsing onnx file if an unknown operator is found
bool skip_unknown_operators = false;
/// Print program if an error occurs
bool print_program_on_error = false;
/// Max iter num for the loop operator
int64_t max_loop_iterations = 10;
};
/// Create a program from an onnx file
program parse_onnx(const std::string& name);
program parse_onnx(const std::string& name, const onnx_options& = onnx_options{});
/// Create a program from an onnx buffer
program parse_onnx_buffer(const std::string& buffer, const onnx_options& options);
/// Create a program from an onnx buffer
program parse_onnx_buffer(const void* data, std::size_t size, const onnx_options& options);
std::vector<std::string> get_onnx_operators();
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -3,7 +3,6 @@
#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
......
......@@ -3,7 +3,6 @@
#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
......
#ifndef MIGRAPHX_GUARD_OPERATORS_ACOSH_HPP
#define MIGRAPHX_GUARD_OPERATORS_ACOSH_HPP
#include <migraphx/op/unary.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct acosh : unary<acosh>
{
auto apply() const
{
return [](auto x) { return std::acosh(x); };
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -3,7 +3,6 @@
#include <array>
#include <migraphx/op/binary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
......@@ -19,6 +18,13 @@ namespace op {
struct add : binary<add>
{
value attributes() const
{
auto a = base_attributes();
a["commutative"] = true;
return a;
}
std::string point_function() const { return "+"; }
auto apply() const
{
return [](auto x, auto y) { return x + y; };
......
#ifndef MIGRAPHX_GUARD_OPERATORS_ALLOCATE_HPP
#define MIGRAPHX_GUARD_OPERATORS_ALLOCATE_HPP
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
#include <migraphx/shape.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct allocate
{
shape s{};
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.s, "shape"));
}
std::string name() const { return "allocate"; }
shape compute_shape(const std::vector<shape>& inputs) const
{
migraphx::check_shapes{inputs, *this}.has(0);
return s;
}
argument compute(const shape& output_shape, const std::vector<argument>&) const
{
return {output_shape};
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_ARGMAX_HPP
#define MIGRAPHX_GUARD_OPERATORS_ARGMAX_HPP
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/par_dfor.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/tune_axis.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -20,17 +24,19 @@ struct argmax
return pack(f(self.axis, "axis"));
}
value attributes() const
{
value normalize;
normalize["axis"] = value::array{normalize_attribute::include_min};
return {{"normalize_axes", normalize}};
}
std::string name() const { return "argmax"; }
shape compute_shape(std::vector<shape> inputs) const
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1).standard();
auto lens = inputs[0].lens();
int64_t n_dim = static_cast<int64_t>(lens.size());
if(axis >= n_dim || axis < 0)
{
MIGRAPHX_THROW("ARGMAX: axis is out of range.");
}
check_shapes{inputs, *this}.has(1);
auto lens = inputs[0].lens();
lens[axis] = 1;
......
#ifndef MIGRAPHX_GUARD_OPERATORS_ARGMIN_HPP
#define MIGRAPHX_GUARD_OPERATORS_ARGMIN_HPP
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/par_dfor.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/tune_axis.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -20,17 +24,19 @@ struct argmin
return pack(f(self.axis, "axis"));
}
value attributes() const
{
value normalize;
normalize["axis"] = value::array{normalize_attribute::include_min};
return {{"normalize_axes", normalize}};
}
std::string name() const { return "argmin"; }
shape compute_shape(std::vector<shape> inputs) const
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1).standard();
auto lens = inputs[0].lens();
int64_t n_dim = static_cast<int64_t>(lens.size());
if(axis >= n_dim || axis < 0)
{
MIGRAPHX_THROW("ARGMIN: axis is out of range.");
}
check_shapes{inputs, *this}.has(1);
auto lens = inputs[0].lens();
lens[axis] = 1;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment