Unverified Commit 3d264140 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Merge pull request #56 from ROCmSoftwarePlatform/matcher

Add matcher API
parents b56c8be0 f1102c26
Matchers
========
Introduction
------------
The matchers provide a way compose several predicates together. Many of the matchers can be composed so that ``m(m1, m2)`` will first check that ``m`` matches and then it will check that ``m1`` and ``m2`` will match.
The most commonly-used matcher is the ``name`` matcher. It will match the instruction that have the operator that is equal to the name specified::
auto match_sum = name("sum");
This will find ``sum`` operators. We can also find ``sum`` operators which the output is ``standard_shape``:
auto match_sum = name("sum")(standard_shape());
Arguments
---------
We also want to match arguments to the instructions as well. One way, is to match each argument using the ``arg`` matcher::
auto match_sum = name("sum")(arg(0)(name("@literal"), arg(1)(name("@literal"))));
This will match a ``sum`` operator with the two arguments that are literals. Of course, instead of writing ``arg(0)`` and ``arg(1)`` everytime, the ``args`` matcher can be used::
auto match_sum = name("sum")(args(name("@literal"), name("@literal")));
Binding
-------
As we traverse through the instructions we may want reference some of the instructions we find along the way. We can do this by calling ``.bind``::
auto match_sum = name("sum")(args(
name("@literal").bind("one"),
name("@literal").bind("two")
)).bind("sum");
This will associate the instruction to a name that can be read from the ``matcher_result`` when it matches.
Finding matches
---------------
Finally, when you want to use the matchers to find instructions a callback object can be written which has the matcher and an ``apply`` function which will take the ``matcher_result`` when the match is found::
struct match_find_sum
{
auto matcher() const { return name("sum"); }
void apply(program& p, matcher_result r) const
{
// Do something with the result
}
};
find_matches(prog, match_find_sum{});
Creating matchers
-----------------
There are several ways to create matchers. The macros ``MIGRAPH_BASIC_MATCHER`` and ``MIGRAPH_PRED_MATCHER`` help with creating matchers. For example, we can create a matcher for shapes that are broadcasted::
MIGRAPH_PRED_MATCHER(broadcasted_shape, instruction_ref ins)
{
return ins->get_shape().broadcasted();
}
If we want parameters to the predicate, then we will need to use the ``make_basic_pred_matcher`` to create the matcher. For example, here is how we would create a matcher to check the number of dimensions of the shape::
inline auto number_of_dims(std::size_t n)
{
return make_basic_pred_matcher([=](instruction_ref ins) {
return ins->get_shape().lens().size() == n;
});
}
Developer Guide
===============
.. toctree::
:maxdepth: 2
:caption: Contents:
dev/matchers
......@@ -7,15 +7,11 @@ Welcome to MIGraph's documentation!
===================================
.. toctree::
:maxdepth: 2
:maxdepth: 3
:caption: Contents:
overview
reference/data
reference/operators
reference/program
reference/targets
reference/pass
user_guide
developer_guide
Indices and tables
......
User Guide
==========
.. toctree::
:maxdepth: 2
:caption: Contents:
overview
reference/data
reference/operators
reference/program
reference/targets
reference/pass
......@@ -61,6 +61,12 @@ constexpr void repeat_c_impl(F f, seq<Ns...>)
swallow{(f(std::integral_constant<std::size_t, Ns>{}), 0)...};
}
template <class F, std::size_t... Ns>
constexpr auto sequence_c_impl(F&& f, seq<Ns...>)
{
return f(std::integral_constant<std::size_t, Ns>{}...);
}
} // namespace detail
template <std::size_t N, class F>
......@@ -69,6 +75,18 @@ constexpr void repeat_c(F f)
detail::repeat_c_impl(f, detail::gens<N>{});
}
template <std::size_t N, class F>
constexpr auto sequence_c(F&& f)
{
return detail::sequence_c_impl(f, detail::gens<N>{});
}
template <class F, class... Ts>
constexpr void each_args(F f, Ts&&... xs)
{
swallow{(f(std::forward<Ts>(xs)), 0)...};
}
/// Implements a fix-point combinator
template <class R, class F>
detail::fix_f<R, F> fix(F f)
......@@ -88,6 +106,24 @@ auto pack(Ts... xs)
return [=](auto f) { return f(xs...); };
}
template <class F, class T>
auto fold_impl(F&&, T&& x)
{
return x;
}
template <class F, class T, class U, class... Ts>
auto fold_impl(F&& f, T&& x, U&& y, Ts&&... xs)
{
return fold_impl(f, f(std::forward<T>(x), std::forward<U>(y)), std::forward<Ts>(xs)...);
}
template <class F>
auto fold(F f)
{
return [=](auto&&... xs) { return fold_impl(f, std::forward<decltype(xs)>(xs)...); };
}
} // namespace migraph
#endif
#ifndef MIGRAPH_GUARD_RTGLIB_MATCHER_HPP
#define MIGRAPH_GUARD_RTGLIB_MATCHER_HPP
#include <migraph/functional.hpp>
#include <migraph/ranges.hpp>
#include <migraph/instruction.hpp>
#include <migraph/program.hpp>
#include <migraph/iterator_for.hpp>
#include <unordered_map>
namespace migraph {
namespace matchers {
struct matcher_context
{
matcher_context(instruction_ref i) : last(i) {}
std::unordered_map<std::string, instruction_ref> instructions;
instruction_ref not_found() const { return last; }
private:
instruction_ref last;
};
/// Convert a predicate function into a matcher
template <class P>
struct predicate_matcher
{
P p;
instruction_ref match(matcher_context& ctx, instruction_ref ins) const
{
assert(ins != ctx.not_found());
if(p(ins))
return ins;
return ctx.not_found();
}
};
/// Convert a function into a matcher
template <class F>
struct function_matcher
{
F f;
instruction_ref match(matcher_context& ctx, instruction_ref ins) const
{
assert(ins != ctx.not_found());
return f(ctx, ins);
}
};
/// Convert a function into a matcher
template <class F>
function_matcher<F> make_function_matcher(F f)
{
return {f};
}
/// Converts a matcher to bind the instruction to name
template <class M>
auto bind_match(M m, std::string name)
{
return make_function_matcher(
[ =, name = std::move(name) ](matcher_context & ctx, instruction_ref ins) {
auto result = m.match(ctx, ins);
if(result != ctx.not_found())
ctx.instructions.emplace(name, ins);
return result;
});
}
/// Convert a matcher to a bindable matcher
template <class M>
struct bindable_matcher
{
M m;
auto bind(std::string name) { return bind_match(m, std::move(name)); }
instruction_ref match(matcher_context& ctx, instruction_ref ins) const
{
return m.match(ctx, ins);
}
};
/// Create a bindable matcher
template <class M>
bindable_matcher<M> make_bindable_matcher(M m)
{
return {m};
}
/// Create a bindable matcher from a function
template <class F>
bindable_matcher<function_matcher<F>> make_bf_matcher(F f)
{
return {{f}};
}
/// Create a bindable matcher from a predicate function
template <class F>
bindable_matcher<predicate_matcher<F>> make_bp_matcher(F f)
{
return {{f}};
}
using bool_list = std::initializer_list<bool>;
struct id_matcher
{
instruction_ref match(matcher_context&, instruction_ref ins) const { return ins; }
};
/// The basic matcher provides the all_of composability of the matcher
template <class M>
struct basic_matcher
{
M m;
template <class... Ts>
auto operator()(Ts... ms) const
{
// Copy m because we cant capture `this` by value
auto mm = m;
return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
auto result = mm.match(ctx, ins);
if(result != ctx.not_found())
{
bool matches = fold([&](auto x, auto y) {
return x and y.match(ctx, result) != ctx.not_found();
})(true, ms...);
if(matches)
return result;
}
return ctx.not_found();
});
}
auto bind(std::string name) { return bind_match(m, name); }
instruction_ref match(matcher_context& ctx, instruction_ref ins) const
{
return m.match(ctx, ins);
}
};
/// Create a basic matcher from a matcher
template <class M>
basic_matcher<M> make_basic_matcher(M m)
{
return {m};
}
/// Create a basic matcher from a function
template <class F>
basic_matcher<function_matcher<F>> make_basic_fun_matcher(F f)
{
return {{f}};
}
/// Create a basic matcher from a predicate function
template <class P>
basic_matcher<predicate_matcher<P>> make_basic_pred_matcher(P p)
{
return {{p}};
}
/// This macro takes care of the boilerplate for defining a matcher
#define MIGRAPH_BASIC_MATCHER(name, ...) \
struct name##_m \
{ \
instruction_ref match(__VA_ARGS__) const; \
}; \
const constexpr auto name = migraph::matchers::basic_matcher<name##_m>{{}}; \
inline instruction_ref name##_m::match(__VA_ARGS__) const
/// This macro takes care of the boilerplate for defining a predicate matcher
#define MIGRAPH_PRED_MATCHER(name, ...) \
struct name##_m \
{ \
bool operator()(__VA_ARGS__) const; \
}; \
const constexpr auto name = migraph::matchers::basic_matcher<predicate_matcher<name##_m>>{{}}; \
inline bool name##_m::operator()(__VA_ARGS__) const
struct matcher_result
{
std::unordered_map<std::string, instruction_ref> instructions;
instruction_ref result;
};
/// Match a single instruction
template <class M>
matcher_result match_instruction(program& p, instruction_ref ins, M&& m)
{
assert(ins != p.end());
matcher_result result;
matcher_context ctx{p.end()};
result.result = m.match(ctx, ins);
result.instructions = ctx.instructions;
return result;
}
/// Find matches in a program
template <class... Ms>
void find_matches(program& p, Ms&&... ms)
{
for(auto ins : iterator_for(p))
{
bool match = false;
each_args(
[&](auto&& m) {
// cppcheck-suppress knownConditionTrueFalse
if(match)
return;
auto r = match_instruction(p, ins, m.matcher());
if(r.result == p.end())
return;
m.apply(p, r);
match = true;
},
ms...);
}
}
template <class... Ts>
auto all_of(Ts... ms)
{
return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
bool matches = fold([&](auto x, auto y) {
return x and y.match(ctx, ins) != ctx.not_found();
})(true, ms...);
if(matches)
return ins;
return ctx.not_found();
});
}
template <class... Ts>
auto none_of(Ts... ms)
{
return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
bool matches = fold([&](auto x, auto y) {
return x and y.match(ctx, ins) == ctx.not_found();
})(true, ms...);
if(matches)
return ins;
return ctx.not_found();
});
}
template <class... Ts>
auto any_of(Ts... ms)
{
return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
bool matches = fold([&](auto x, auto y) {
return x or y.match(ctx, ins) != ctx.not_found();
})(false, ms...);
if(matches)
return ins;
return ctx.not_found();
});
}
MIGRAPH_PRED_MATCHER(standard_shape, instruction_ref ins) { return ins->get_shape().standard(); }
inline auto name(std::string name)
{
return make_basic_pred_matcher(
[ =, name = std::move(name) ](instruction_ref ins) { return ins->name() == name; });
}
inline auto nargs(std::size_t n)
{
return make_basic_pred_matcher([=](instruction_ref ins) { return ins->inputs().size() == n; });
}
inline auto arg(std::size_t i)
{
return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) {
if(i < ins->inputs().size())
return ins->inputs()[i];
return ctx.not_found();
});
}
// Workaround for bugs in clang
template <std::size_t...>
struct args_impl_ints
{
};
template <std::size_t... Ns, class... Ms>
auto args_impl(args_impl_ints<Ns...>, Ms... ms)
{
return matchers::all_of(nargs(sizeof...(Ns)), arg(Ns)(ms)...);
}
template <class... Ms>
auto args(Ms... ms)
{
return sequence_c<sizeof...(Ms)>([=](auto... is) {
// It needs to be written as `decltype(is)::value` for gcc 5
return args_impl(args_impl_ints<decltype(is)::value...>{}, ms...);
});
}
} // namespace matchers
} // namespace migraph
#endif
#include <migraph/matcher.hpp>
#include <migraph/iterator_for.hpp>
#include <test.hpp>
#include <basic_ops.hpp>
namespace matchers = migraph::matchers;
template <class M>
migraph::matchers::matcher_result find_match(migraph::program& p, M&& m)
{
migraph::matchers::matcher_result result;
for(auto ins : migraph::iterator_for(p))
{
result = migraph::matchers::match_instruction(p, ins, m);
if(result.result != p.end())
return result;
}
return result;
}
void match1()
{
migraph::program p;
auto l = p.add_literal(1);
auto m = matchers::standard_shape();
auto r = find_match(p, m);
EXPECT(bool{r.result == l});
}
void match_name1()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum");
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
void match_name2()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = matchers::name("min");
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
void match_name3()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::standard_shape());
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
void match_arg1()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::arg(0)(matchers::name("@literal")),
matchers::standard_shape());
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
void match_arg2()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m =
matchers::name("sum")(matchers::arg(0)(matchers::name("sum")), matchers::standard_shape());
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
void match_arg3()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::arg(1)(matchers::name("@literal")),
matchers::standard_shape());
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
void match_arg4()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
auto pass = p.add_instruction(pass_op{}, sum);
auto m =
matchers::name("pass")(matchers::arg(0)(matchers::name("sum")), matchers::standard_shape());
auto r = find_match(p, m);
EXPECT(bool{r.result == pass});
}
void match_arg5()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m =
matchers::name("pass")(matchers::arg(1)(matchers::name("sum")), matchers::standard_shape());
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
void match_arg6()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::arg(0)(matchers::name("@literal")));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
void match_arg7()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::arg(0)(matchers::name("@literal")),
matchers::arg(1)(matchers::name("@literal")));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
void match_args1()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(
matchers::args(matchers::name("@literal"), matchers::name("@literal")),
matchers::standard_shape());
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
void match_args2()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m =
matchers::name("sum")(matchers::args(matchers::name("@literal"), matchers::name("sum")),
matchers::standard_shape());
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
void match_args3()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::args(matchers::name("@literal")),
matchers::standard_shape());
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
void match_args4()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum1 = p.add_instruction(sum_op{}, one, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, two);
p.add_instruction(pass_op{}, sum2);
auto m =
matchers::name("sum")(matchers::args(matchers::name("sum"), matchers::name("@literal")),
matchers::standard_shape());
auto r = find_match(p, m);
EXPECT(bool{r.result == sum2});
}
void match_args5()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m =
matchers::name("sum")(matchers::args(matchers::name("sum"), matchers::name("@literal")),
matchers::standard_shape());
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
void match_args6()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
auto pass = p.add_instruction(pass_op{}, sum);
auto m =
matchers::name("pass")(matchers::args(matchers::name("sum")), matchers::standard_shape());
auto r = find_match(p, m);
EXPECT(bool{r.result == pass});
}
void match_args7()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
auto pass = p.add_instruction(pass_op{}, sum);
auto m = matchers::name("pass")(matchers::args(matchers::name("sum")(matchers::args(
matchers::name("@literal"), matchers::name("@literal")))),
matchers::standard_shape());
auto r = find_match(p, m);
EXPECT(bool{r.result == pass});
}
void match_all_of1()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::all_of(matchers::arg(0)(matchers::name("@literal")),
matchers::arg(1)(matchers::name("@literal"))));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
void match_all_of2()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::all_of(matchers::arg(0)(matchers::name("sum")),
matchers::arg(1)(matchers::name("@literal"))));
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
void match_any_of1()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::any_of(matchers::arg(0)(matchers::name("sum")),
matchers::arg(1)(matchers::name("@literal"))));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
void match_any_of2()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::any_of(matchers::arg(0)(matchers::name("sum")),
matchers::arg(1)(matchers::name("sum"))));
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
void match_none_of1()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::none_of(matchers::arg(0)(matchers::name("sum")),
matchers::arg(1)(matchers::name("sum"))));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
void match_none_of2()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = matchers::name("sum")(matchers::none_of(matchers::arg(0)(matchers::name("@literal")),
matchers::arg(1)(matchers::name("@literal"))));
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
void match_bind1()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
auto pass = p.add_instruction(pass_op{}, sum);
auto m = matchers::name("pass")(
matchers::args(
matchers::name("sum")(matchers::args(matchers::name("@literal").bind("one"),
matchers::name("@literal").bind("two")))
.bind("sum")),
matchers::standard_shape())
.bind("pass");
auto r = find_match(p, m);
EXPECT(bool{r.instructions.at("one") == one});
EXPECT(bool{r.instructions.at("two") == two});
EXPECT(bool{r.instructions.at("sum") == sum});
EXPECT(bool{r.instructions.at("pass") == pass});
EXPECT(bool{r.result == pass});
}
struct match_find_sum
{
migraph::instruction_ref ins;
auto matcher() const { return matchers::name("sum"); }
void apply(migraph::program&, matchers::matcher_result r) const
{
EXPECT(bool{r.result == ins});
}
};
struct match_find_literal
{
migraph::instruction_ref ins;
auto matcher() const { return matchers::name("@literal"); }
void apply(migraph::program&, matchers::matcher_result r) const
{
EXPECT(bool{r.result != ins});
EXPECT(r.result->name() == "@literal");
}
};
void match_finder()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
matchers::find_matches(p, match_find_sum{sum}, match_find_literal{sum});
}
int main()
{
match1();
match_name1();
match_name2();
match_name3();
match_arg1();
match_arg2();
match_arg3();
match_arg4();
match_arg5();
match_arg6();
match_arg7();
match_args1();
match_args2();
match_args3();
match_args4();
match_args5();
match_args6();
match_args7();
match_all_of1();
match_all_of2();
match_any_of1();
match_any_of2();
match_none_of1();
match_none_of2();
match_bind1();
match_finder();
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment