Unverified Commit c72a047f authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Dont match or bind to global instructions (#826)



* Add optional header

* Formatting

* Use optional in the matcher

* Foramtting

* Remove program from tests

* Formatting

* Dont bind or match non-local variables

* Formatting

* Fix gcc 5 error

* Format
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 9a5e0c06
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/module.hpp> #include <migraphx/module.hpp>
#include <migraphx/program.hpp> #include <migraphx/optional.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/type_name.hpp> #include <migraphx/type_name.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
...@@ -19,24 +19,51 @@ namespace match { ...@@ -19,24 +19,51 @@ namespace match {
struct matcher_context struct matcher_context
{ {
matcher_context(instruction_ref i) : last(i) {} matcher_context(module& m) : mod(&m) {}
std::unordered_map<std::string, instruction_ref> instructions; std::unordered_map<std::string, instruction_ref> instructions;
instruction_ref not_found() const { return last; }
template <class M> template <class M>
bool matched(M m, instruction_ref ins) bool matched(M m, instruction_ref ins)
{ {
return m.match(*this, ins) != this->not_found(); return has_value(m.match(*this, ins));
} }
template <class M> template <class M>
auto lazy_match(M m, instruction_ref ins) bool matched(M m, optional<instruction_ref> ins)
{
if(ins)
return has_value(m.match(*this, *ins));
return false;
}
template <class M, class I>
auto lazy_match(M m, I ins)
{ {
return [=] { return this->matched(m, ins); }; return [=] { return this->matched(m, ins); };
} }
bool has_instruction(instruction_ref ins) const
{
if(mod == nullptr)
return true;
return mod->has_instruction(ins);
}
bool has_instruction(optional<instruction_ref> ins) const
{
if(ins)
return this->has_instruction(*ins);
return false;
}
bool is_last(instruction_ref ins) const
{
assert(mod->begin() != mod->end());
assert(this->has_instruction(ins));
return ins == std::prev(mod->end());
}
private: private:
instruction_ref last; module* mod = nullptr;
}; };
/// Convert a predicate function into a matcher /// Convert a predicate function into a matcher
...@@ -45,12 +72,11 @@ struct predicate_matcher ...@@ -45,12 +72,11 @@ struct predicate_matcher
{ {
P p; P p;
instruction_ref match(const matcher_context& ctx, instruction_ref ins) const optional<instruction_ref> match(const matcher_context&, instruction_ref ins) const
{ {
assert(ins != ctx.not_found());
if(p(ins)) if(p(ins))
return ins; return optional<instruction_ref>{ins};
return ctx.not_found(); return nullopt;
} }
}; };
...@@ -60,11 +86,7 @@ struct function_matcher ...@@ -60,11 +86,7 @@ struct function_matcher
{ {
F f; F f;
instruction_ref match(matcher_context& ctx, instruction_ref ins) const auto match(matcher_context& ctx, instruction_ref ins) const { return f(ctx, ins); }
{
assert(ins != ctx.not_found());
return f(ctx, ins);
}
}; };
/// Convert a function into a matcher /// Convert a function into a matcher
...@@ -79,10 +101,15 @@ template <class M> ...@@ -79,10 +101,15 @@ template <class M>
auto bind_match(M m, std::string name) auto bind_match(M m, std::string name)
{ {
return make_function_matcher( return make_function_matcher(
[ =, name = std::move(name) ](matcher_context & ctx, instruction_ref ins) { [ =, name = std::move(name) ](matcher_context & ctx, instruction_ref ins)
->optional<instruction_ref> {
auto result = m.match(ctx, ins); auto result = m.match(ctx, ins);
if(result != ctx.not_found()) if(result)
{
if(not ctx.has_instruction(ins))
return nullopt;
ctx.instructions[name] = ins; ctx.instructions[name] = ins;
}
return result; return result;
}); });
} }
...@@ -95,10 +122,7 @@ struct bindable_matcher ...@@ -95,10 +122,7 @@ struct bindable_matcher
auto bind(std::string name) const { return bind_match(m, std::move(name)); } auto bind(std::string name) const { return bind_match(m, std::move(name)); }
instruction_ref match(matcher_context& ctx, instruction_ref ins) const auto match(matcher_context& ctx, instruction_ref ins) const { return m.match(ctx, ins); }
{
return m.match(ctx, ins);
}
}; };
/// Create a bindable matcher /// Create a bindable matcher
...@@ -126,7 +150,10 @@ using bool_list = std::initializer_list<bool>; ...@@ -126,7 +150,10 @@ using bool_list = std::initializer_list<bool>;
struct id_matcher struct id_matcher
{ {
instruction_ref match(matcher_context&, instruction_ref ins) const { return ins; } auto match(matcher_context&, instruction_ref ins) const
{
return optional<instruction_ref>{ins};
}
}; };
/// The basic matcher provides the all_of composability of the matcher /// The basic matcher provides the all_of composability of the matcher
...@@ -140,26 +167,23 @@ struct basic_matcher ...@@ -140,26 +167,23 @@ struct basic_matcher
{ {
// Copy m because we cant capture `this` by value // Copy m because we cant capture `this` by value
auto mm = m; auto mm = m;
return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) { return make_bf_matcher([=](matcher_context& ctx,
instruction_ref ins) -> optional<instruction_ref> {
auto result = mm.match(ctx, ins); auto result = mm.match(ctx, ins);
if(result != ctx.not_found()) if(result)
{ {
bool matches = fold([&](auto x, auto y) { bool matches =
return x and y.match(ctx, result) != ctx.not_found(); fold([&](auto x, auto y) { return x and ctx.matched(y, result); })(true, ms...);
})(true, ms...);
if(matches) if(matches)
return result; return result;
} }
return ctx.not_found(); return nullopt;
}); });
} }
auto bind(std::string name) const { return bind_match(m, std::move(name)); } auto bind(std::string name) const { return bind_match(m, std::move(name)); }
instruction_ref match(matcher_context& ctx, instruction_ref ins) const auto match(matcher_context& ctx, instruction_ref ins) const { return m.match(ctx, ins); }
{
return m.match(ctx, ins);
}
}; };
/// Create a basic matcher from a matcher /// Create a basic matcher from a matcher
...@@ -185,7 +209,7 @@ basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p) ...@@ -185,7 +209,7 @@ basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p)
/// Create a typed-erased matcher /// Create a typed-erased matcher
using any_matcher_base = basic_matcher< using any_matcher_base = basic_matcher<
function_matcher<std::function<instruction_ref(matcher_context&, instruction_ref)>>>; function_matcher<std::function<optional<instruction_ref>(matcher_context&, instruction_ref)>>>;
struct any_matcher : any_matcher_base struct any_matcher : any_matcher_base
{ {
template <class M> template <class M>
...@@ -198,10 +222,10 @@ struct any_matcher : any_matcher_base ...@@ -198,10 +222,10 @@ struct any_matcher : any_matcher_base
#define MIGRAPHX_BASIC_MATCHER(name, ...) \ #define MIGRAPHX_BASIC_MATCHER(name, ...) \
struct name##_m \ struct name##_m \
{ \ { \
instruction_ref match(__VA_ARGS__) const; \ optional<instruction_ref> match(__VA_ARGS__) const; \
}; \ }; \
const constexpr auto name = migraphx::match::basic_matcher<name##_m>{{}}; \ const constexpr auto name = migraphx::match::basic_matcher<name##_m>{{}}; \
inline instruction_ref name##_m::match(__VA_ARGS__) const inline optional<instruction_ref> name##_m::match(__VA_ARGS__) const
/// This macro takes care of the boilerplate for defining a predicate matcher /// This macro takes care of the boilerplate for defining a predicate matcher
#define MIGRAPHX_PRED_MATCHER(name, ...) \ #define MIGRAPHX_PRED_MATCHER(name, ...) \
...@@ -221,21 +245,29 @@ struct matcher_result ...@@ -221,21 +245,29 @@ struct matcher_result
/// Match a single instruction /// Match a single instruction
template <class M> template <class M>
matcher_result match_instruction(module& p, instruction_ref ins, M&& m) matcher_result match_instruction(module& mod, instruction_ref ins, M&& m)
{ {
assert(ins != p.end()); assert(ins != mod.end());
assert(mod.has_instruction(ins));
matcher_context ctx{mod};
matcher_result result; matcher_result result;
matcher_context ctx{p.end()}; if(m.match(ctx, ins))
result.result = m.match(ctx, ins); {
result.result = ins;
result.instructions = ctx.instructions; result.instructions = ctx.instructions;
}
else
{
result.result = mod.end();
}
return result; return result;
} }
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MATCHES)
/// Find matches for an instruction in the program /// Find matches for an instruction in the module
template <class... Ms> template <class... Ms>
void find_matches(module& p, instruction_ref ins, Ms&&... ms) void find_matches(module& mod, instruction_ref ins, Ms&&... ms)
{ {
#if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5 #if !defined(__GNUC__) || defined(__clang__) || __GNUC__ > 5
const const
...@@ -246,27 +278,27 @@ void find_matches(module& p, instruction_ref ins, Ms&&... ms) ...@@ -246,27 +278,27 @@ void find_matches(module& p, instruction_ref ins, Ms&&... ms)
[&](auto&& m) { [&](auto&& m) {
if(match) if(match)
return; return;
auto r = match_instruction(p, ins, m.matcher()); auto r = match_instruction(mod, ins, m.matcher());
if(r.result == p.end()) if(r.result == mod.end())
return; return;
if(trace) if(trace)
{ {
std::cout << "Matched by " << get_type_name(m) << std::endl; std::cout << "Matched by " << get_type_name(m) << std::endl;
p.debug_print(ins); mod.debug_print(ins);
} }
m.apply(p, r); m.apply(mod, r);
match = true; match = true;
}, },
ms...); ms...);
} }
/// Find matches in a program /// Find matches in a module
template <class... Ms> template <class... Ms>
void find_matches(module& p, Ms&&... ms) void find_matches(module& mod, Ms&&... ms)
{ {
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(mod))
{ {
find_matches(p, ins, ms...); find_matches(mod, ins, ms...);
} }
} }
...@@ -339,11 +371,12 @@ struct match_fold_f ...@@ -339,11 +371,12 @@ struct match_fold_f
template <class... Ts> template <class... Ts>
auto operator()(Ts... ms) const auto operator()(Ts... ms) const
{ {
return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) { return make_bf_matcher(
[=](matcher_context& ctx, instruction_ref ins) -> optional<instruction_ref> {
bool matches = match_fold_f::fold_matchers(ctx, ins, ms...); bool matches = match_fold_f::fold_matchers(ctx, ins, ms...);
if(matches == Matches) if(matches == Matches)
return ins; return {ins};
return ctx.not_found(); return nullopt;
}); });
} }
...@@ -353,7 +386,8 @@ struct match_fold_f ...@@ -353,7 +386,8 @@ struct match_fold_f
return [=](auto... ms) { return [=](auto... ms) {
// Workaround ICE on gcc by packing matchers into an object // Workaround ICE on gcc by packing matchers into an object
auto mpack = pack(ms...); auto mpack = pack(ms...);
return make_bf_matcher([=](matcher_context& ctx, instruction_ref start) { return make_bf_matcher(
[=](matcher_context& ctx, instruction_ref start) -> optional<instruction_ref> {
Op op; Op op;
bool matches = Start; bool matches = Start;
select(start, [&](auto ins) { select(start, [&](auto ins) {
...@@ -361,8 +395,8 @@ struct match_fold_f ...@@ -361,8 +395,8 @@ struct match_fold_f
matches = op(always(matches), fm)(); matches = op(always(matches), fm)();
}); });
if(matches == Matches) if(matches == Matches)
return start; return {start};
return ctx.not_found(); return nullopt;
}); });
}; };
} }
...@@ -420,64 +454,29 @@ MIGRAPHX_PRED_MATCHER(same_input_shapes, instruction_ref ins) ...@@ -420,64 +454,29 @@ MIGRAPHX_PRED_MATCHER(same_input_shapes, instruction_ref ins)
ins->inputs().begin(), ins->inputs().end(), [&](auto x) { return x->get_shape() == s; }); ins->inputs().begin(), ins->inputs().end(), [&](auto x) { return x->get_shape() == s; });
} }
MIGRAPHX_BASIC_MATCHER(output, const matcher_context& ctx, instruction_ref ins) MIGRAPHX_BASIC_MATCHER(output, const matcher_context&, instruction_ref ins)
{ {
if(ins->outputs().size() == 1) if(ins->outputs().size() == 1)
return ins->outputs().front(); return {ins->outputs().front()};
return ctx.not_found(); return nullopt;
} }
MIGRAPHX_BASIC_MATCHER(used_once, const matcher_context& ctx, instruction_ref ins) MIGRAPHX_BASIC_MATCHER(used_once, const matcher_context& ctx, instruction_ref ins)
{ {
if(ins->outputs().size() == 1) if(ins->outputs().size() == 1)
return ins; return {ins};
if(ins->outputs().empty() and std::next(ins) == ctx.not_found()) if(ins->outputs().empty() and ctx.is_last(ins))
return ins; return {ins};
return ctx.not_found(); return nullopt;
}
inline auto used_once_recursive(std::size_t depth)
{
return make_basic_fun_matcher([=](const matcher_context& ctx, instruction_ref start) {
// Used once
if(start->outputs().size() == 1)
return start;
// Unused
if(start->outputs().empty())
{
if(std::next(start) == ctx.not_found())
return start;
else
return ctx.not_found();
}
// Check for dead instructions
auto is_dead = fix<bool>([&](auto self, auto ins, auto n) {
if(n == 0)
return false;
if(ins->get_shape().elements() == 0)
return false;
if(ins->outputs().empty() and std::next(ins) != ctx.not_found())
return true;
return std::all_of(ins->outputs().begin(), ins->outputs().end(), [&](auto i) {
return self(i, n - 1);
});
});
auto dead = std::count_if(start->outputs().begin(), start->outputs().end(), [&](auto i) {
return is_dead(i, depth);
});
if(dead + 1 == start->outputs().size())
return start;
return ctx.not_found();
});
} }
MIGRAPHX_PRED_MATCHER(is_constant, instruction_ref ins) { return ins->can_eval(); } MIGRAPHX_PRED_MATCHER(is_constant, instruction_ref ins) { return ins->can_eval(); }
MIGRAPHX_BASIC_MATCHER(is_unused, const matcher_context& ctx, instruction_ref ins) MIGRAPHX_BASIC_MATCHER(is_unused, const matcher_context& ctx, instruction_ref ins)
{ {
if(ins->outputs().empty() and ins != std::prev(ctx.not_found())) if(ins->outputs().empty() and not ctx.is_last(ins))
return ins; return {ins};
return ctx.not_found(); return nullopt;
} }
template <class... Ms> template <class... Ms>
...@@ -485,7 +484,8 @@ auto skip(Ms... ms) ...@@ -485,7 +484,8 @@ auto skip(Ms... ms)
{ {
auto m = any_of(ms...); auto m = any_of(ms...);
return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref start) { return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref start) {
return fix<instruction_ref>([&](auto self, auto ins) { return fix<optional<instruction_ref>>(
[&](auto self, auto ins) -> optional<instruction_ref> {
if(ins->inputs().size() == 1 and ctx.matched(m, ins)) if(ins->inputs().size() == 1 and ctx.matched(m, ins))
{ {
auto next = ins->inputs().front(); auto next = ins->inputs().front();
...@@ -501,19 +501,20 @@ auto skip_output(Ms... ms) ...@@ -501,19 +501,20 @@ auto skip_output(Ms... ms)
{ {
auto m = any_of(ms...); auto m = any_of(ms...);
return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref start) { return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref start) {
return fix<instruction_ref>([&](auto self, auto ins) { return fix<optional<instruction_ref>>(
[&](auto self, auto ins) -> optional<instruction_ref> {
if(ins->outputs().size() == 1) if(ins->outputs().size() == 1)
{ {
auto next = ins->outputs().front(); auto next = ins->outputs().front();
if(ctx.matched(m, next)) if(ctx.matched(m, next))
{ {
auto skipped_next = self(next); auto skipped_next = self(next);
if(skipped_next != ctx.not_found()) if(skipped_next)
return skipped_next; return skipped_next;
} }
return next; return next;
} }
return ctx.not_found(); return nullopt;
})(start); })(start);
}); });
} }
...@@ -550,10 +551,11 @@ inline auto nargs(std::size_t n) ...@@ -550,10 +551,11 @@ inline auto nargs(std::size_t n)
inline auto arg(std::size_t i) inline auto arg(std::size_t i)
{ {
return make_basic_fun_matcher([=](const matcher_context& ctx, instruction_ref ins) { return make_basic_fun_matcher(
[=](const matcher_context&, instruction_ref ins) -> optional<instruction_ref> {
if(i < ins->inputs().size()) if(i < ins->inputs().size())
return ins->inputs()[i]; return ins->inputs()[i];
return ctx.not_found(); return nullopt;
}); });
} }
...@@ -616,18 +618,19 @@ std::size_t tree_leafs_impl(matcher_context& ctx, ...@@ -616,18 +618,19 @@ std::size_t tree_leafs_impl(matcher_context& ctx,
template <class M, class... Ms> template <class M, class... Ms>
auto tree(M main_op, Ms... ms) auto tree(M main_op, Ms... ms)
{ {
return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) { return make_basic_fun_matcher(
[=](matcher_context& ctx, instruction_ref ins) -> optional<instruction_ref> {
// Flatten leaf nodes // Flatten leaf nodes
std::array<instruction_ref, sizeof...(Ms)> leafs; std::array<instruction_ref, sizeof...(Ms)> leafs;
std::size_t idx = tree_leafs_impl(ctx, leafs, main_op, ins); std::size_t idx = tree_leafs_impl(ctx, leafs, main_op, ins);
if(idx != leafs.size()) if(idx != leafs.size())
return ctx.not_found(); return nullopt;
// Use explicit captures to workaround ICE on gcc // Use explicit captures to workaround ICE on gcc
bool found = sequence_c<sizeof...(Ms)>([&ms..., &ctx, &leafs](auto... is) { bool found = sequence_c<sizeof...(Ms)>([&ms..., &ctx, &leafs](auto... is) {
return fold(lazy_and{})(ctx.lazy_match(ms, leafs[is])...)(); return fold(lazy_and{})(ctx.lazy_match(ms, leafs[is])...)();
}); });
if(not found) if(not found)
return ctx.not_found(); return nullopt;
return ins; return ins;
}); });
} }
...@@ -635,12 +638,13 @@ auto tree(M main_op, Ms... ms) ...@@ -635,12 +638,13 @@ auto tree(M main_op, Ms... ms)
template <class M, class... Ms> template <class M, class... Ms>
auto unordered_tree(M main_op, Ms... ms) auto unordered_tree(M main_op, Ms... ms)
{ {
return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) { return make_basic_fun_matcher(
[=](matcher_context& ctx, instruction_ref ins) -> optional<instruction_ref> {
// Flatten leaf nodes // Flatten leaf nodes
std::array<instruction_ref, sizeof...(Ms)> leafs; std::array<instruction_ref, sizeof...(Ms)> leafs;
std::size_t idx = tree_leafs_impl(ctx, leafs, main_op, ins); std::size_t idx = tree_leafs_impl(ctx, leafs, main_op, ins);
if(idx != leafs.size()) if(idx != leafs.size())
return ctx.not_found(); return nullopt;
// Use explicit captures to workaround ICE on gcc // Use explicit captures to workaround ICE on gcc
bool found = sequence_c<sizeof...(Ms)>([ms..., &ctx, &leafs](auto... is) { bool found = sequence_c<sizeof...(Ms)>([ms..., &ctx, &leafs](auto... is) {
return by(fold(lazy_and{}), [is..., &ctx, &leafs](auto m) { return by(fold(lazy_and{}), [is..., &ctx, &leafs](auto m) {
...@@ -648,7 +652,7 @@ auto unordered_tree(M main_op, Ms... ms) ...@@ -648,7 +652,7 @@ auto unordered_tree(M main_op, Ms... ms)
})(ms...)(); })(ms...)();
}); });
if(not found) if(not found)
return ctx.not_found(); return nullopt;
return ins; return ins;
}); });
} }
...@@ -656,11 +660,12 @@ auto unordered_tree(M main_op, Ms... ms) ...@@ -656,11 +660,12 @@ auto unordered_tree(M main_op, Ms... ms)
template <class M> template <class M>
auto same_shape(M m) auto same_shape(M m)
{ {
return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) { return make_basic_fun_matcher(
[=](matcher_context& ctx, instruction_ref ins) -> optional<instruction_ref> {
auto i = m.match(ctx, ins); auto i = m.match(ctx, ins);
if(i != ctx.not_found() and i->get_shape() == ins->get_shape()) if(i and (*i)->get_shape() == ins->get_shape())
return ins; return ins;
return ctx.not_found(); return nullopt;
}); });
} }
......
#ifndef MIGRAPHX_GUARD_MIGRAPHX_OPTIONAL_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_OPTIONAL_HPP
#include <migraphx/config.hpp>
#if defined(__has_include) && !defined(CPPCHECK)
#if __has_include(<optional>) && __cplusplus >= 201703L
#define MIGRAPHX_HAS_OPTIONAL 1
#else
#define MIGRAPHX_HAS_OPTIONAL 0
#endif
#if __has_include(<experimental/optional>) && __cplusplus >= 201103L
#define MIGRAPHX_HAS_OPTIONAL_TS 1
#else
#define MIGRAPHX_HAS_OPTIONAL_TS 0
#endif
#else
#define MIGRAPHX_HAS_OPTIONAL 0
#define MIGRAPHX_HAS_OPTIONAL_TS 0
#endif
#if MIGRAPHX_HAS_OPTIONAL
#include <optional>
#elif MIGRAPHX_HAS_OPTIONAL_TS
#include <experimental/optional>
#else
#error "No optional include available"
#endif
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
#if MIGRAPHX_HAS_OPTIONAL
template <class T>
using optional = std::optional<T>;
using nullopt_t = std::nullopt_t;
constexpr auto nullopt = std::nullopt;
#elif MIGRAPHX_HAS_OPTIONAL_TS
template <class T>
using optional = std::experimental::optional<T>;
using nullopt_t = std::experimental::nullopt_t;
constexpr auto nullopt = std::experimental::nullopt;
#endif
template <class T>
bool has_value(const optional<T>& x)
{
return x != nullopt;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_OPTIONAL_HPP
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment