Unverified Commit c72a047f authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Dont match or bind to global instructions (#826)



* Add optional header

* Formatting

* Use optional in the matcher

* Foramtting

* Remove program from tests

* Formatting

* Dont bind or match non-local variables

* Formatting

* Fix gcc 5 error

* Format
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 9a5e0c06
......@@ -5,7 +5,7 @@
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/module.hpp>
#include <migraphx/program.hpp>
#include <migraphx/optional.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/type_name.hpp>
#include <migraphx/config.hpp>
......@@ -19,24 +19,51 @@ namespace match {
struct matcher_context
{
matcher_context(instruction_ref i) : last(i) {}
matcher_context(module& m) : mod(&m) {}
std::unordered_map<std::string, instruction_ref> instructions;
instruction_ref not_found() const { return last; }
template <class M>
bool matched(M m, instruction_ref ins)
{
return m.match(*this, ins) != this->not_found();
return has_value(m.match(*this, ins));
}
template <class M>
auto lazy_match(M m, instruction_ref ins)
bool matched(M m, optional<instruction_ref> ins)
{
if(ins)
return has_value(m.match(*this, *ins));
return false;
}
template <class M, class I>
auto lazy_match(M m, I ins)
{
return [=] { return this->matched(m, ins); };
}
bool has_instruction(instruction_ref ins) const
{
if(mod == nullptr)
return true;
return mod->has_instruction(ins);
}
bool has_instruction(optional<instruction_ref> ins) const
{
if(ins)
return this->has_instruction(*ins);
return false;
}
bool is_last(instruction_ref ins) const
{
assert(mod->begin() != mod->end());
assert(this->has_instruction(ins));
return ins == std::prev(mod->end());
}
private:
instruction_ref last;
module* mod = nullptr;
};
/// Convert a predicate function into a matcher
......@@ -45,12 +72,11 @@ struct predicate_matcher
{
P p;
instruction_ref match(const matcher_context& ctx, instruction_ref ins) const
optional<instruction_ref> match(const matcher_context&, instruction_ref ins) const
{
assert(ins != ctx.not_found());
if(p(ins))
return ins;
return ctx.not_found();
return optional<instruction_ref>{ins};
return nullopt;
}
};
......@@ -60,11 +86,7 @@ struct function_matcher
{
F f;
instruction_ref match(matcher_context& ctx, instruction_ref ins) const
{
assert(ins != ctx.not_found());
return f(ctx, ins);
}
auto match(matcher_context& ctx, instruction_ref ins) const { return f(ctx, ins); }
};
/// Convert a function into a matcher
......@@ -79,10 +101,15 @@ template <class M>
auto bind_match(M m, std::string name)
{
return make_function_matcher(
[ =, name = std::move(name) ](matcher_context & ctx, instruction_ref ins) {
[ =, name = std::move(name) ](matcher_context & ctx, instruction_ref ins)
->optional<instruction_ref> {
auto result = m.match(ctx, ins);
if(result != ctx.not_found())
if(result)
{
if(not ctx.has_instruction(ins))
return nullopt;
ctx.instructions[name] = ins;
}
return result;
});
}
......@@ -95,10 +122,7 @@ struct bindable_matcher
auto bind(std::string name) const { return bind_match(m, std::move(name)); }
instruction_ref match(matcher_context& ctx, instruction_ref ins) const
{
return m.match(ctx, ins);
}
auto match(matcher_context& ctx, instruction_ref ins) const { return m.match(ctx, ins); }
};
/// Create a bindable matcher
......@@ -126,7 +150,10 @@ using bool_list = std::initializer_list<bool>;
struct id_matcher
{
instruction_ref match(matcher_context&, instruction_ref ins) const { return ins; }
auto match(matcher_context&, instruction_ref ins) const
{
return optional<instruction_ref>{ins};
}
};
/// The basic matcher provides the all_of composability of the matcher
......@@ -140,26 +167,23 @@ struct basic_matcher
{
// Copy m because we cant capture `this` by value
auto mm = m;
return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
return make_bf_matcher([=](matcher_context& ctx,
instruction_ref ins) -> optional<instruction_ref> {
auto result = mm.match(ctx, ins);
if(result != ctx.not_found())
if(result)
{
bool matches = fold([&](auto x, auto y) {
return x and y.match(ctx, result) != ctx.not_found();
})(true, ms...);
bool matches =
fold([&](auto x, auto y) { return x and ctx.matched(y, result); })(true, ms...);
if(matches)
return result;
}
return ctx.not_found();
return nullopt;
});
}
auto bind(std::string name) const { return bind_match(m, std::move(name)); }
instruction_ref match(matcher_context& ctx, instruction_ref ins) const
{
return m.match(ctx, ins);
}
auto match(matcher_context& ctx, instruction_ref ins) const { return m.match(ctx, ins); }
};
/// Create a basic matcher from a matcher
......@@ -185,7 +209,7 @@ basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p)
/// Create a typed-erased matcher
using any_matcher_base = basic_matcher<
function_matcher<std::function<instruction_ref(matcher_context&, instruction_ref)>>>;
function_matcher<std::function<optional<instruction_ref>(matcher_context&, instruction_ref)>>>;
struct any_matcher : any_matcher_base
{
template <class M>
......@@ -198,10 +222,10 @@ struct any_matcher : any_matcher_base
#define MIGRAPHX_BASIC_MATCHER(name, ...) \
struct name##_m \
{ \
instruction_ref match(__VA_ARGS__) const; \
optional<instruction_ref> match(__VA_ARGS__) const; \
}; \
const constexpr auto name = migraphx::match::basic_matcher<name##_m>{{}}; \
inline instruction_ref name##_m::match(__VA_ARGS__) const
inline optional<instruction_ref> name##_m::match(__VA_ARGS__) const
/// This macro takes care of the boilerplate for defining a predicate matcher
#define MIGRAPHX_PRED_MATCHER(name, ...) \
......@@ -221,21 +245,29 @@ struct matcher_result
/// Match a single instruction
template <class M>
matcher_result match_instruction(module& p, instruction_ref ins, M&& m)
matcher_result match_instruction(module& mod, instruction_ref ins, M&& m)
{
assert(ins != p.end());
assert(ins != mod.end());
assert(mod.has_instruction(ins));
matcher_context ctx{mod};
matcher_result result;
matcher_context ctx{p.end()};
result.result = m.match(ctx, ins);
if(m.match(ctx, ins))
{
result.result = ins;
result.instructions = ctx.instructions;
}
else
{
result.result = mod.end();
}
return result;
}
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES)
/// Find matches for an instruction in the program
/// Find matches for an instruction in the module
template <class... Ms>
void find_matches(module& p, instruction_ref ins, Ms&&... ms)
void find_matches(module& mod, instruction_ref ins, Ms&&... ms)
{
#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5
const
......@@ -246,27 +278,27 @@ void find_matches(module& p, instruction_ref ins, Ms&&... ms)
[&](auto&& m) {
if(match)
return;
auto r = match_instruction(p, ins, m.matcher());
if(r.result == p.end())
auto r = match_instruction(mod, ins, m.matcher());
if(r.result == mod.end())
return;
if(trace)
{
std::cout << "Matched by " << get_type_name(m) << std::endl;
p.debug_print(ins);
mod.debug_print(ins);
}
m.apply(p, r);
m.apply(mod, r);
match = true;
},
ms...);
}
/// Find matches in a program
/// Find matches in a module
template <class... Ms>
void find_matches(module& p, Ms&&... ms)
void find_matches(module& mod, Ms&&... ms)
{
for(auto ins : iterator_for(p))
for(auto ins : iterator_for(mod))
{
find_matches(p, ins, ms...);
find_matches(mod, ins, ms...);
}
}
......@@ -339,11 +371,12 @@ struct match_fold_f
template <class... Ts>
auto operator()(Ts... ms) const
{
return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
return make_bf_matcher(
[=](matcher_context& ctx, instruction_ref ins) -> optional<instruction_ref> {
bool matches = match_fold_f::fold_matchers(ctx, ins, ms...);
if(matches == Matches)
return ins;
return ctx.not_found();
return {ins};
return nullopt;
});
}
......@@ -353,7 +386,8 @@ struct match_fold_f
return [=](auto... ms) {
// Workaround ICE on gcc by packing matchers into an object
auto mpack = pack(ms...);
return make_bf_matcher([=](matcher_context& ctx, instruction_ref start) {
return make_bf_matcher(
[=](matcher_context& ctx, instruction_ref start) -> optional<instruction_ref> {
Op op;
bool matches = Start;
select(start, [&](auto ins) {
......@@ -361,8 +395,8 @@ struct match_fold_f
matches = op(always(matches), fm)();
});
if(matches == Matches)
return start;
return ctx.not_found();
return {start};
return nullopt;
});
};
}
......@@ -420,64 +454,29 @@ MIGRAPHX_PRED_MATCHER(same_input_shapes, instruction_ref ins)
ins->inputs().begin(), ins->inputs().end(), [&](auto x) { return x->get_shape() == s; });
}
MIGRAPHX_BASIC_MATCHER(output, const matcher_context& ctx, instruction_ref ins)
MIGRAPHX_BASIC_MATCHER(output, const matcher_context&, instruction_ref ins)
{
if(ins->outputs().size() == 1)
return ins->outputs().front();
return ctx.not_found();
return {ins->outputs().front()};
return nullopt;
}
MIGRAPHX_BASIC_MATCHER(used_once, const matcher_context& ctx, instruction_ref ins)
{
if(ins->outputs().size() == 1)
return ins;
if(ins->outputs().empty() and std::next(ins) == ctx.not_found())
return ins;
return ctx.not_found();
}
inline auto used_once_recursive(std::size_t depth)
{
return make_basic_fun_matcher([=](const matcher_context& ctx, instruction_ref start) {
// Used once
if(start->outputs().size() == 1)
return start;
// Unused
if(start->outputs().empty())
{
if(std::next(start) == ctx.not_found())
return start;
else
return ctx.not_found();
}
// Check for dead instructions
auto is_dead = fix<bool>([&](auto self, auto ins, auto n) {
if(n == 0)
return false;
if(ins->get_shape().elements() == 0)
return false;
if(ins->outputs().empty() and std::next(ins) != ctx.not_found())
return true;
return std::all_of(ins->outputs().begin(), ins->outputs().end(), [&](auto i) {
return self(i, n - 1);
});
});
auto dead = std::count_if(start->outputs().begin(), start->outputs().end(), [&](auto i) {
return is_dead(i, depth);
});
if(dead + 1 == start->outputs().size())
return start;
return ctx.not_found();
});
return {ins};
if(ins->outputs().empty() and ctx.is_last(ins))
return {ins};
return nullopt;
}
MIGRAPHX_PRED_MATCHER(is_constant, instruction_ref ins) { return ins->can_eval(); }
MIGRAPHX_BASIC_MATCHER(is_unused, const matcher_context& ctx, instruction_ref ins)
{
if(ins->outputs().empty() and ins != std::prev(ctx.not_found()))
return ins;
return ctx.not_found();
if(ins->outputs().empty() and not ctx.is_last(ins))
return {ins};
return nullopt;
}
template <class... Ms>
......@@ -485,7 +484,8 @@ auto skip(Ms... ms)
{
auto m = any_of(ms...);
return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref start) {
return fix<instruction_ref>([&](auto self, auto ins) {
return fix<optional<instruction_ref>>(
[&](auto self, auto ins) -> optional<instruction_ref> {
if(ins->inputs().size() == 1 and ctx.matched(m, ins))
{
auto next = ins->inputs().front();
......@@ -501,19 +501,20 @@ auto skip_output(Ms... ms)
{
auto m = any_of(ms...);
return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref start) {
return fix<instruction_ref>([&](auto self, auto ins) {
return fix<optional<instruction_ref>>(
[&](auto self, auto ins) -> optional<instruction_ref> {
if(ins->outputs().size() == 1)
{
auto next = ins->outputs().front();
if(ctx.matched(m, next))
{
auto skipped_next = self(next);
if(skipped_next != ctx.not_found())
if(skipped_next)
return skipped_next;
}
return next;
}
return ctx.not_found();
return nullopt;
})(start);
});
}
......@@ -550,10 +551,11 @@ inline auto nargs(std::size_t n)
inline auto arg(std::size_t i)
{
return make_basic_fun_matcher([=](const matcher_context& ctx, instruction_ref ins) {
return make_basic_fun_matcher(
[=](const matcher_context&, instruction_ref ins) -> optional<instruction_ref> {
if(i < ins->inputs().size())
return ins->inputs()[i];
return ctx.not_found();
return nullopt;
});
}
......@@ -616,18 +618,19 @@ std::size_t tree_leafs_impl(matcher_context& ctx,
template <class M, class... Ms>
auto tree(M main_op, Ms... ms)
{
return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) {
return make_basic_fun_matcher(
[=](matcher_context& ctx, instruction_ref ins) -> optional<instruction_ref> {
// Flatten leaf nodes
std::array<instruction_ref, sizeof...(Ms)> leafs;
std::size_t idx = tree_leafs_impl(ctx, leafs, main_op, ins);
if(idx != leafs.size())
return ctx.not_found();
return nullopt;
// Use explicit captures to workaround ICE on gcc
bool found = sequence_c<sizeof...(Ms)>([&ms..., &ctx, &leafs](auto... is) {
return fold(lazy_and{})(ctx.lazy_match(ms, leafs[is])...)();
});
if(not found)
return ctx.not_found();
return nullopt;
return ins;
});
}
......@@ -635,12 +638,13 @@ auto tree(M main_op, Ms... ms)
template <class M, class... Ms>
auto unordered_tree(M main_op, Ms... ms)
{
return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) {
return make_basic_fun_matcher(
[=](matcher_context& ctx, instruction_ref ins) -> optional<instruction_ref> {
// Flatten leaf nodes
std::array<instruction_ref, sizeof...(Ms)> leafs;
std::size_t idx = tree_leafs_impl(ctx, leafs, main_op, ins);
if(idx != leafs.size())
return ctx.not_found();
return nullopt;
// Use explicit captures to workaround ICE on gcc
bool found = sequence_c<sizeof...(Ms)>([ms..., &ctx, &leafs](auto... is) {
return by(fold(lazy_and{}), [is..., &ctx, &leafs](auto m) {
......@@ -648,7 +652,7 @@ auto unordered_tree(M main_op, Ms... ms)
})(ms...)();
});
if(not found)
return ctx.not_found();
return nullopt;
return ins;
});
}
......@@ -656,11 +660,12 @@ auto unordered_tree(M main_op, Ms... ms)
template <class M>
auto same_shape(M m)
{
return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) {
return make_basic_fun_matcher(
[=](matcher_context& ctx, instruction_ref ins) -> optional<instruction_ref> {
auto i = m.match(ctx, ins);
if(i != ctx.not_found() and i->get_shape() == ins->get_shape())
if(i and (*i)->get_shape() == ins->get_shape())
return ins;
return ctx.not_found();
return nullopt;
});
}
......
#ifndef MIGRAPHX_GUARD_MIGRAPHX_OPTIONAL_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_OPTIONAL_HPP
#include <migraphx/config.hpp>
#if defined(__has_include) && !defined(CPPCHECK)
#if __has_include(<optional>) && __cplusplus >= 201703L
#define MIGRAPHX_HAS_OPTIONAL 1
#else
#define MIGRAPHX_HAS_OPTIONAL 0
#endif
#if __has_include(<experimental/optional>) && __cplusplus >= 201103L
#define MIGRAPHX_HAS_OPTIONAL_TS 1
#else
#define MIGRAPHX_HAS_OPTIONAL_TS 0
#endif
#else
#define MIGRAPHX_HAS_OPTIONAL 0
#define MIGRAPHX_HAS_OPTIONAL_TS 0
#endif
#if MIGRAPHX_HAS_OPTIONAL
#include <optional>
#elif MIGRAPHX_HAS_OPTIONAL_TS
#include <experimental/optional>
#else
#error "No optional include available"
#endif
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
#if MIGRAPHX_HAS_OPTIONAL
template <class T>
using optional = std::optional<T>;
using nullopt_t = std::nullopt_t;
constexpr auto nullopt = std::nullopt;
#elif MIGRAPHX_HAS_OPTIONAL_TS
template <class T>
using optional = std::experimental::optional<T>;
using nullopt_t = std::experimental::nullopt_t;
constexpr auto nullopt = std::experimental::nullopt;
#endif
template <class T>
bool has_value(const optional<T>& x)
{
return x != nullopt;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_OPTIONAL_HPP
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment