#ifndef MIGRAPH_GUARD_RTGLIB_MATCHER_HPP #define MIGRAPH_GUARD_RTGLIB_MATCHER_HPP #include #include #include #include #include #include namespace migraph { struct matcher_context { matcher_context(instruction_ref i) : last(i) {} std::unordered_map instructions; instruction_ref not_found() const { return last; } private: instruction_ref last; }; template struct predicate_matcher { P p; instruction_ref match(matcher_context& ctx, instruction_ref ins) const { assert(ins != ctx.not_found()); if(p(ins)) return ins; return ctx.not_found(); } }; template struct function_matcher { F f; instruction_ref match(matcher_context& ctx, instruction_ref ins) const { assert(ins != ctx.not_found()); return f(ctx, ins); } }; template function_matcher make_function_matcher(F f) { return {f}; } template auto bind_match(M m, std::string name) { return make_function_matcher([=](matcher_context& ctx, instruction_ref ins) { auto result = m.match(ctx, ins); if(result != ctx.not_found()) ctx.instructions.emplace(name, ins); return result; }); } template struct bindable_matcher { M m; auto bind(std::string name) { return bind_match(m, name); } instruction_ref match(matcher_context& ctx, instruction_ref ins) const { return m.match(ctx, ins); } }; template bindable_matcher make_bindable_matcher(M m) { return {m}; } template bindable_matcher> make_bf_matcher(F f) { return {{f}}; } template bindable_matcher> make_bp_matcher(F f) { return {{f}}; } using bool_list = std::initializer_list; struct id_matcher { instruction_ref match(matcher_context&, instruction_ref ins) const { return ins; } }; template struct basic_matcher { M m; template auto operator()(Ts... ms) const { // Copy m because we cant capture `this` by value auto mm = m; return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) { auto result = mm.match(ctx, ins); if(result != ctx.not_found()) { bool matches = fold([&](auto x, auto y) { return x and y.match(ctx, result) != ctx.not_found(); })(true, ms...); if(matches) return result; } return ctx.not_found(); }); } auto bind(std::string name) { return bind_match(m, name); } instruction_ref match(matcher_context& ctx, instruction_ref ins) const { return m.match(ctx, ins); } }; template basic_matcher make_basic_matcher(M m) { return {m}; } template basic_matcher> make_basic_fun_matcher(F f) { return {{f}}; } template basic_matcher> make_basic_pred_matcher(P p) { return {{p}}; } #define MIGRAPH_BASIC_MATCHER(name, ...) \ struct name##_m \ { \ instruction_ref match(__VA_ARGS__) const; \ }; \ const constexpr auto name = migraph::basic_matcher{{}}; \ inline instruction_ref name##_m::match(__VA_ARGS__) const #define MIGRAPH_PRED_MATCHER(name, ...) \ struct name##_m \ { \ bool operator()(__VA_ARGS__) const; \ }; \ const constexpr auto name = migraph::basic_matcher>{{}}; \ inline bool name##_m::operator()(__VA_ARGS__) const struct matcher_result { std::unordered_map instructions; instruction_ref result; }; template matcher_result match_instruction(program& p, instruction_ref ins, M&& m) { assert(ins != p.end()); matcher_result result; matcher_context ctx{p.end()}; result.result = m.match(ctx, ins); result.instructions = ctx.instructions; return result; } template std::array make_array(T x, Ts... xs) { return {x, xs...}; } template bool all_of_eager(Ts... xs) { return make_array((xs, true)...) == make_array(static_cast(xs)...); } namespace matchers { template auto all_of(Ts... ms) { return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) { bool matches = fold([&](auto x, auto y) { return x and y.match(ctx, ins) != ctx.not_found(); })(true, ms...); if(matches) return ins; return ctx.not_found(); }); } template auto none_of(Ts... ms) { return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) { bool matches = fold([&](auto x, auto y) { return x and y.match(ctx, ins) == ctx.not_found(); })(true, ms...); if(matches) return ins; return ctx.not_found(); }); } template auto any_of(Ts... ms) { return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) { bool matches = fold( [&](auto x, auto y) { return x or y.match(ctx, ins) != ctx.not_found(); })(true, ms...); if(matches) return ins; return ctx.not_found(); }); } MIGRAPH_PRED_MATCHER(standard_shape, instruction_ref ins) { return ins->get_shape().standard(); } inline auto name(std::string name) { return make_basic_pred_matcher([=](instruction_ref ins) { return ins->name() == name; }); } inline auto arg(std::size_t i) { return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) { if(i < ins->inputs().size()) return ins->inputs()[i]; return ctx.not_found(); }); } // Workaround for bugs in clang template struct args_impl_ints { }; template auto args_impl(args_impl_ints, Ms... ms) { return matchers::all_of(arg(Ns)(ms)...); } template auto args(Ms... ms) { return sequence_c([=](auto... is) { // It needs to be written as `decltype(is)::value` for gcc 5 return args_impl(args_impl_ints{}, ms...); }); } } // namespace matchers } // namespace migraph #endif