Commit 254abcfa authored by Paul's avatar Paul
Browse files

Formatting

parent e2e3606f
...@@ -100,24 +100,22 @@ auto pack(Ts... xs) ...@@ -100,24 +100,22 @@ auto pack(Ts... xs)
return [=](auto f) { return f(xs...); }; return [=](auto f) { return f(xs...); };
} }
template<class F, class T> template <class F, class T>
auto fold_impl(F&&, T&& x) auto fold_impl(F&&, T&& x)
{ {
return x; return x;
} }
template<class F, class T, class U, class... Ts> template <class F, class T, class U, class... Ts>
auto fold_impl(F&& f, T&& x, U&& y, Ts&&... xs) auto fold_impl(F&& f, T&& x, U&& y, Ts&&... xs)
{ {
return fold_impl(f, f(std::forward<T>(x), std::forward<U>(y)), std::forward<Ts>(xs)...); return fold_impl(f, f(std::forward<T>(x), std::forward<U>(y)), std::forward<Ts>(xs)...);
} }
template<class F> template <class F>
auto fold(F f) auto fold(F f)
{ {
return [=](auto&&... xs) { return [=](auto&&... xs) { return fold_impl(f, std::forward<decltype(xs)>(xs)...); };
return fold_impl(f, std::forward<decltype(xs)>(xs)...);
};
} }
} // namespace migraph } // namespace migraph
......
...@@ -12,19 +12,15 @@ namespace migraph { ...@@ -12,19 +12,15 @@ namespace migraph {
struct matcher_context struct matcher_context
{ {
matcher_context(instruction_ref i) matcher_context(instruction_ref i) : last(i) {}
: last(i)
{}
std::unordered_map<std::string, instruction_ref> instructions; std::unordered_map<std::string, instruction_ref> instructions;
instruction_ref not_found() const instruction_ref not_found() const { return last; }
{
return last;
}
private: private:
instruction_ref last; instruction_ref last;
}; };
template<class P> template <class P>
struct predicate_matcher struct predicate_matcher
{ {
P p; P p;
...@@ -38,7 +34,7 @@ struct predicate_matcher ...@@ -38,7 +34,7 @@ struct predicate_matcher
} }
}; };
template<class F> template <class F>
struct function_matcher struct function_matcher
{ {
F f; F f;
...@@ -50,13 +46,13 @@ struct function_matcher ...@@ -50,13 +46,13 @@ struct function_matcher
} }
}; };
template<class F> template <class F>
function_matcher<F> make_function_matcher(F f) function_matcher<F> make_function_matcher(F f)
{ {
return {f}; return {f};
} }
template<class M> template <class M>
auto bind_match(M m, std::string name) auto bind_match(M m, std::string name)
{ {
return make_function_matcher([=](matcher_context& ctx, instruction_ref ins) { return make_function_matcher([=](matcher_context& ctx, instruction_ref ins) {
...@@ -67,15 +63,12 @@ auto bind_match(M m, std::string name) ...@@ -67,15 +63,12 @@ auto bind_match(M m, std::string name)
}); });
} }
template<class M> template <class M>
struct bindable_matcher struct bindable_matcher
{ {
M m; M m;
auto bind(std::string name) auto bind(std::string name) { return bind_match(m, name); }
{
return bind_match(m, name);
}
instruction_ref match(matcher_context& ctx, instruction_ref ins) const instruction_ref match(matcher_context& ctx, instruction_ref ins) const
{ {
...@@ -83,19 +76,19 @@ struct bindable_matcher ...@@ -83,19 +76,19 @@ struct bindable_matcher
} }
}; };
template<class M> template <class M>
bindable_matcher<M> make_bindable_matcher(M m) bindable_matcher<M> make_bindable_matcher(M m)
{ {
return {m}; return {m};
} }
template<class F> template <class F>
bindable_matcher<function_matcher<F>> make_bf_matcher(F f) bindable_matcher<function_matcher<F>> make_bf_matcher(F f)
{ {
return {{f}}; return {{f}};
} }
template<class F> template <class F>
bindable_matcher<predicate_matcher<F>> make_bp_matcher(F f) bindable_matcher<predicate_matcher<F>> make_bp_matcher(F f)
{ {
return {{f}}; return {{f}};
...@@ -105,25 +98,22 @@ using bool_list = std::initializer_list<bool>; ...@@ -105,25 +98,22 @@ using bool_list = std::initializer_list<bool>;
struct id_matcher struct id_matcher
{ {
instruction_ref match(matcher_context&, instruction_ref ins) const instruction_ref match(matcher_context&, instruction_ref ins) const { return ins; }
{
return ins;
}
}; };
template<class M> template <class M>
struct basic_matcher struct basic_matcher
{ {
M m; M m;
template<class... Ts> template <class... Ts>
auto operator()(Ts... ms) const auto operator()(Ts... ms) const
{ {
// Copy m because we cant capture `this` by value // Copy m because we cant capture `this` by value
auto mm = m; auto mm = m;
return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) { return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
auto result = mm.match(ctx, ins); auto result = mm.match(ctx, ins);
if(result != ctx.not_found()) if(result != ctx.not_found())
{ {
bool matches = fold([&](auto x, auto y) { bool matches = fold([&](auto x, auto y) {
return x and y.match(ctx, result) != ctx.not_found(); return x and y.match(ctx, result) != ctx.not_found();
...@@ -135,10 +125,7 @@ struct basic_matcher ...@@ -135,10 +125,7 @@ struct basic_matcher
}); });
} }
auto bind(std::string name) auto bind(std::string name) { return bind_match(m, name); }
{
return bind_match(m, name);
}
instruction_ref match(matcher_context& ctx, instruction_ref ins) const instruction_ref match(matcher_context& ctx, instruction_ref ins) const
{ {
...@@ -146,41 +133,39 @@ struct basic_matcher ...@@ -146,41 +133,39 @@ struct basic_matcher
} }
}; };
template<class M> template <class M>
basic_matcher<M> make_basic_matcher(M m) basic_matcher<M> make_basic_matcher(M m)
{ {
return {m}; return {m};
} }
template<class F> template <class F>
basic_matcher<function_matcher<F>> make_basic_fun_matcher(F f) basic_matcher<function_matcher<F>> make_basic_fun_matcher(F f)
{ {
return {{f}}; return {{f}};
} }
template<class P> template <class P>
basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p) basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p)
{ {
return {{p}}; return {{p}};
} }
#define MIGRAPH_BASIC_MATCHER(name, ...) \
#define MIGRAPH_BASIC_MATCHER(name, ...) \ struct name##_m \
struct name ## _m \ { \
{ \ instruction_ref match(__VA_ARGS__) const; \
instruction_ref match(__VA_ARGS__) const; \ }; \
}; \ const constexpr auto name = migraph::basic_matcher<name##_m>{{}}; \
const constexpr auto name = migraph::basic_matcher<name ## _m>{{}}; \ inline instruction_ref name##_m::match(__VA_ARGS__) const
inline instruction_ref name ## _m::match(__VA_ARGS__) const
#define MIGRAPH_PRED_MATCHER(name, ...) \
#define MIGRAPH_PRED_MATCHER(name, ...) \ struct name##_m \
struct name ## _m \ { \
{ \ bool operator()(__VA_ARGS__) const; \
bool operator()(__VA_ARGS__) const; \ }; \
}; \ const constexpr auto name = migraph::basic_matcher<predicate_matcher<name##_m>>{{}}; \
const constexpr auto name = migraph::basic_matcher<predicate_matcher<name ## _m>>{{}}; \ inline bool name##_m::operator()(__VA_ARGS__) const
inline bool name ## _m::operator()(__VA_ARGS__) const
struct matcher_result struct matcher_result
{ {
...@@ -188,23 +173,23 @@ struct matcher_result ...@@ -188,23 +173,23 @@ struct matcher_result
instruction_ref result; instruction_ref result;
}; };
template<class M> template <class M>
matcher_result match_instruction(program& p, instruction_ref ins, M&& m) matcher_result match_instruction(program& p, instruction_ref ins, M&& m)
{ {
assert(ins != p.end()); assert(ins != p.end());
matcher_result result; matcher_result result;
matcher_context ctx{p.end()}; matcher_context ctx{p.end()};
result.result = m.match(ctx, ins); result.result = m.match(ctx, ins);
return result; return result;
} }
template<class T, class... Ts> template <class T, class... Ts>
std::array<T, sizeof...(Ts)+1> make_array(T x, Ts... xs) std::array<T, sizeof...(Ts) + 1> make_array(T x, Ts... xs)
{ {
return {x, xs...}; return {x, xs...};
} }
template<class... Ts> template <class... Ts>
bool all_of_eager(Ts... xs) bool all_of_eager(Ts... xs)
{ {
return make_array((xs, true)...) == make_array(static_cast<bool>(xs)...); return make_array((xs, true)...) == make_array(static_cast<bool>(xs)...);
...@@ -212,7 +197,7 @@ bool all_of_eager(Ts... xs) ...@@ -212,7 +197,7 @@ bool all_of_eager(Ts... xs)
namespace matchers { namespace matchers {
template<class... Ts> template <class... Ts>
auto all_of(Ts... ms) auto all_of(Ts... ms)
{ {
return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) { return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
...@@ -225,7 +210,7 @@ auto all_of(Ts... ms) ...@@ -225,7 +210,7 @@ auto all_of(Ts... ms)
}); });
} }
template<class... Ts> template <class... Ts>
auto none_of(Ts... ms) auto none_of(Ts... ms)
{ {
return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) { return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
...@@ -238,29 +223,23 @@ auto none_of(Ts... ms) ...@@ -238,29 +223,23 @@ auto none_of(Ts... ms)
}); });
} }
template<class... Ts> template <class... Ts>
auto any_of(Ts... ms) auto any_of(Ts... ms)
{ {
return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) { return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
bool matches = fold([&](auto x, auto y) { bool matches = fold(
return x or y.match(ctx, ins) != ctx.not_found(); [&](auto x, auto y) { return x or y.match(ctx, ins) != ctx.not_found(); })(true, ms...);
})(true, ms...);
if(matches) if(matches)
return ins; return ins;
return ctx.not_found(); return ctx.not_found();
}); });
} }
MIGRAPH_PRED_MATCHER(standard_shape, instruction_ref ins) MIGRAPH_PRED_MATCHER(standard_shape, instruction_ref ins) { return ins->get_shape().standard(); }
{
return ins->get_shape().standard();
}
inline auto name(std::string name) inline auto name(std::string name)
{ {
return make_basic_pred_matcher([=](instruction_ref ins) { return make_basic_pred_matcher([=](instruction_ref ins) { return ins->name() == name; });
return ins->name() == name;
});
} }
inline auto arg(std::size_t i) inline auto arg(std::size_t i)
...@@ -273,22 +252,24 @@ inline auto arg(std::size_t i) ...@@ -273,22 +252,24 @@ inline auto arg(std::size_t i)
} }
// Workaround for bugs in clang // Workaround for bugs in clang
template<std::size_t...> template <std::size_t...>
struct args_impl_ints {}; struct args_impl_ints
{
};
template<std::size_t... Ns, class... Ms> template <std::size_t... Ns, class... Ms>
auto args_impl(args_impl_ints<Ns...>, Ms... ms) auto args_impl(args_impl_ints<Ns...>, Ms... ms)
{ {
return matchers::all_of(arg(Ns)(ms)...); return matchers::all_of(arg(Ns)(ms)...);
} }
template<class... Ms> template <class... Ms>
auto args(Ms... ms) auto args(Ms... ms)
{ {
return sequence_c<sizeof...(Ms)>([=](auto... is) { return sequence_c<sizeof...(Ms)>([=](auto... is) {
// It needs to be written as `decltype(is)::value` for gcc 5 // It needs to be written as `decltype(is)::value` for gcc 5
return args_impl(args_impl_ints<decltype(is)::value...>{}, ms...); return args_impl(args_impl_ints<decltype(is)::value...>{}, ms...);
}); });
} }
} // namespace matchers } // namespace matchers
......
...@@ -5,11 +5,11 @@ ...@@ -5,11 +5,11 @@
namespace matchers = migraph::matchers; namespace matchers = migraph::matchers;
template<class M> template <class M>
migraph::matcher_result find_match(migraph::program& p, M&& m) migraph::matcher_result find_match(migraph::program& p, M&& m)
{ {
migraph::matcher_result result; migraph::matcher_result result;
for(auto ins:migraph::iterator_for(p)) for(auto ins : migraph::iterator_for(p))
{ {
result = migraph::match_instruction(p, ins, m); result = migraph::match_instruction(p, ins, m);
if(result.result != p.end()) if(result.result != p.end())
...@@ -30,8 +30,8 @@ void match1() ...@@ -30,8 +30,8 @@ void match1()
void match_name1() void match_name1()
{ {
migraph::program p; migraph::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum); p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum"); auto m = matchers::name("sum");
...@@ -42,8 +42,8 @@ void match_name1() ...@@ -42,8 +42,8 @@ void match_name1()
void match_name2() void match_name2()
{ {
migraph::program p; migraph::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum); p.add_instruction(pass_op{}, sum);
auto m = matchers::name("min"); auto m = matchers::name("min");
...@@ -54,8 +54,8 @@ void match_name2() ...@@ -54,8 +54,8 @@ void match_name2()
void match_name3() void match_name3()
{ {
migraph::program p; migraph::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum); p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::standard_shape()); auto m = matchers::name("sum")(matchers::standard_shape());
...@@ -66,11 +66,12 @@ void match_name3() ...@@ -66,11 +66,12 @@ void match_name3()
void match_arg1() void match_arg1()
{ {
migraph::program p; migraph::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum); p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::arg(0)(matchers::name("@literal")), matchers::standard_shape()); auto m = matchers::name("sum")(matchers::arg(0)(matchers::name("@literal")),
matchers::standard_shape());
auto r = find_match(p, m); auto r = find_match(p, m);
EXPECT(bool{r.result == sum}); EXPECT(bool{r.result == sum});
} }
...@@ -78,11 +79,12 @@ void match_arg1() ...@@ -78,11 +79,12 @@ void match_arg1()
void match_arg2() void match_arg2()
{ {
migraph::program p; migraph::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum); p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::arg(0)(matchers::name("sum")), matchers::standard_shape()); auto m =
matchers::name("sum")(matchers::arg(0)(matchers::name("sum")), matchers::standard_shape());
auto r = find_match(p, m); auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()}); EXPECT(bool{r.result == p.end()});
} }
...@@ -90,11 +92,12 @@ void match_arg2() ...@@ -90,11 +92,12 @@ void match_arg2()
void match_arg3() void match_arg3()
{ {
migraph::program p; migraph::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum); p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::arg(1)(matchers::name("@literal")), matchers::standard_shape()); auto m = matchers::name("sum")(matchers::arg(1)(matchers::name("@literal")),
matchers::standard_shape());
auto r = find_match(p, m); auto r = find_match(p, m);
EXPECT(bool{r.result == sum}); EXPECT(bool{r.result == sum});
} }
...@@ -104,9 +107,10 @@ void match_arg4() ...@@ -104,9 +107,10 @@ void match_arg4()
migraph::program p; migraph::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
auto pass = p.add_instruction(pass_op{}, sum); auto pass = p.add_instruction(pass_op{}, sum);
auto m = matchers::name("pass")(matchers::arg(0)(matchers::name("sum")), matchers::standard_shape()); auto m =
matchers::name("pass")(matchers::arg(0)(matchers::name("sum")), matchers::standard_shape());
auto r = find_match(p, m); auto r = find_match(p, m);
EXPECT(bool{r.result == pass}); EXPECT(bool{r.result == pass});
} }
...@@ -114,11 +118,12 @@ void match_arg4() ...@@ -114,11 +118,12 @@ void match_arg4()
void match_arg5() void match_arg5()
{ {
migraph::program p; migraph::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum); p.add_instruction(pass_op{}, sum);
auto m = matchers::name("pass")(matchers::arg(1)(matchers::name("sum")), matchers::standard_shape()); auto m =
matchers::name("pass")(matchers::arg(1)(matchers::name("sum")), matchers::standard_shape());
auto r = find_match(p, m); auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()}); EXPECT(bool{r.result == p.end()});
} }
...@@ -126,8 +131,8 @@ void match_arg5() ...@@ -126,8 +131,8 @@ void match_arg5()
void match_arg6() void match_arg6()
{ {
migraph::program p; migraph::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum); p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::arg(0)(matchers::name("@literal"))); auto m = matchers::name("sum")(matchers::arg(0)(matchers::name("@literal")));
...@@ -138,11 +143,12 @@ void match_arg6() ...@@ -138,11 +143,12 @@ void match_arg6()
void match_arg7() void match_arg7()
{ {
migraph::program p; migraph::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum); p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::arg(0)(matchers::name("@literal")), matchers::arg(1)(matchers::name("@literal"))); auto m = matchers::name("sum")(matchers::arg(0)(matchers::name("@literal")),
matchers::arg(1)(matchers::name("@literal")));
auto r = find_match(p, m); auto r = find_match(p, m);
EXPECT(bool{r.result == sum}); EXPECT(bool{r.result == sum});
} }
...@@ -150,16 +156,19 @@ void match_arg7() ...@@ -150,16 +156,19 @@ void match_arg7()
void match_args1() void match_args1()
{ {
migraph::program p; migraph::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two); auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum); p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::args(matchers::name("@literal"), matchers::name("@literal")), matchers::standard_shape()); auto m = matchers::name("sum")(
matchers::args(matchers::name("@literal"), matchers::name("@literal")),
matchers::standard_shape());
auto r = find_match(p, m); auto r = find_match(p, m);
EXPECT(bool{r.result == sum}); EXPECT(bool{r.result == sum});
} }
int main() { int main()
{
match1(); match1();
match_name1(); match_name1();
match_name2(); match_name2();
...@@ -174,5 +183,4 @@ int main() { ...@@ -174,5 +183,4 @@ int main() {
match_arg7(); match_arg7();
match_args1(); match_args1();
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment