Commit 2e51006e authored by Paul's avatar Paul
Browse files

Use matchers

parent 0fcf61e0
......@@ -8,6 +8,7 @@
#include <migraphx/iterator_for.hpp>
#include <migraphx/config.hpp>
#include <unordered_map>
#include <unordered_set>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -20,6 +21,12 @@ struct matcher_context
std::unordered_map<std::string, instruction_ref> instructions;
instruction_ref not_found() const { return last; }
template<class M>
bool matched(M m, instruction_ref ins)
{
return m.match(*this, ins) != this->not_found();
}
private:
instruction_ref last;
};
......@@ -205,64 +212,89 @@ matcher_result match_instruction(program& p, instruction_ref ins, M&& m)
return result;
}
/// Find matches for an instruction in the program
template <class... Ms>
void find_matches(program& p, instruction_ref ins, Ms&&... ms)
{
bool match = false;
each_args(
[&](auto&& m) {
if(match)
return;
auto r = match_instruction(p, ins, m.matcher());
if(r.result == p.end())
return;
m.apply(p, r);
match = true;
},
ms...);
}
/// Find matches in a program
template <class... Ms>
void find_matches(program& p, Ms&&... ms)
{
for(auto ins : iterator_for(p))
{
bool match = false;
each_args(
[&](auto&& m) {
if(match)
return;
auto r = match_instruction(p, ins, m.matcher());
if(r.result == p.end())
return;
m.apply(p, r);
match = true;
},
ms...);
find_matches(p, ins, ms...);
}
}
template <class... Ts>
auto all_of(Ts... ms)
template<class Op, bool Start, bool Matches>
struct folder
{
return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
bool matches = fold([&](auto x, auto y) {
return x and y.match(ctx, ins) != ctx.not_found();
})(true, ms...);
if(matches)
return ins;
return ctx.not_found();
});
}
template <class... Ts>
auto operator()(Ts... ms) const
{
return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
Op op;
bool matches = fold([&](auto x, auto y) {
return op(x, y.match(ctx, ins) != ctx.not_found());
})(Start, ms...);
if(matches == Matches)
return ins;
return ctx.not_found();
});
}
template<class Selector>
auto operator[](Selector select) const
{
return [=](auto... ms) {
return make_bf_matcher([=](matcher_context& ctx, instruction_ref start) {
Op op;
bool matches = Start;
select(start, [&](auto ins) {
matches = op(matches, fold([&](auto x, auto y) {
return op(x, y.match(ctx, ins) == ctx.not_found());
})(Start, ms...));
});
if(matches == Matches)
return start;
return ctx.not_found();
});
};
}
};
template <class... Ts>
auto none_of(Ts... ms)
const constexpr auto all_of = folder<std::logical_and<bool>, true, true>{};
const constexpr auto any_of = folder<std::logical_or<bool>, false, true>{};
const constexpr auto none_of = folder<std::logical_or<bool>, false, false>{};
inline auto inputs()
{
return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
bool matches = fold([&](auto x, auto y) {
return x and y.match(ctx, ins) == ctx.not_found();
})(true, ms...);
if(matches)
return ins;
return ctx.not_found();
});
return [](auto ins, auto f) {
for(auto&& x:ins->inputs())
f(x);
};
}
template <class... Ts>
auto any_of(Ts... ms)
inline auto outputs()
{
return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
bool matches = fold([&](auto x, auto y) {
return x or y.match(ctx, ins) != ctx.not_found();
})(false, ms...);
if(matches)
return ins;
return ctx.not_found();
});
return [](auto ins, auto f) {
for(auto&& x:ins->outputs())
f(x);
};
}
MIGRAPHX_PRED_MATCHER(any, instruction_ref) { return true; }
......@@ -273,6 +305,21 @@ MIGRAPHX_PRED_MATCHER(broadcast_shape, instruction_ref ins)
return ins->get_shape().broadcasted();
}
MIGRAPHX_PRED_MATCHER(transpose_shape, instruction_ref ins)
{
return ins->get_shape().transposed();
}
MIGRAPHX_PRED_MATCHER(same_shapes, instruction_ref ins)
{
if (ins->inputs().empty())
return false;
auto s = ins->inputs().front()->get_shape();
return std::all_of(ins->inputs().begin(), ins->inputs().end(), [&](auto x) {
return x->get_shape() == s;
});
}
MIGRAPHX_BASIC_MATCHER(output, matcher_context& ctx, instruction_ref ins)
{
if(ins->outputs().size() == 1)
......@@ -289,10 +336,38 @@ MIGRAPHX_BASIC_MATCHER(used_once, matcher_context& ctx, instruction_ref ins)
return ctx.not_found();
}
inline auto name(std::string name)
template<class... Ms>
auto skip_output(Ms... ms)
{
auto m = any_of(ms...);
return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref start) {
return fix<instruction_ref>([&](auto self, auto ins) {
if(ins->outputs().size() == 1)
{
auto next = ins->outputs().front();
if (ctx.matched(m, next))
{
auto skipped_next = self(next);
if (skipped_next != ctx.not_found())
return skipped_next;
}
return next;
}
return ctx.not_found();
})(start);
});
}
inline auto name(std::string s)
{
return make_basic_pred_matcher(
[ =, s = std::move(s) ](instruction_ref ins) { return ins->name() == s; });
}
inline auto name(std::unordered_set<std::string> names)
{
return make_basic_pred_matcher(
[ =, name = std::move(name) ](instruction_ref ins) { return ins->name() == name; });
[ =, names = std::move(names) ](instruction_ref ins) { return names.count(ins->name()) > 0; });
}
inline auto nargs(std::size_t n)
......
......@@ -6,12 +6,13 @@
#include <migraphx/op/concat.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/matcher.hpp>
#include <unordered_set>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
bool is_reshaper(instruction_ref ins)
const auto& reshaper_names()
{
// clang-format off
static const std::unordered_set<std::string> names = {
......@@ -21,16 +22,12 @@ bool is_reshaper(instruction_ref ins)
"unsqueeze"
};
// clang-format on
return contains(names, ins->name());
return names;
}
bool is_transpose_output(instruction_ref ins)
bool is_reshaper(instruction_ref ins)
{
if(ins->outputs().size() != 1)
return false;
if(ins->outputs().front()->name() == "contiguous")
return is_transpose_output(ins->outputs().front());
return ins->outputs().front()->name() == "transpose";
return contains(reshaper_names(), ins->name());
}
instruction_ref find_transpose_input(instruction_ref ins)
......@@ -89,96 +86,133 @@ std::vector<int64_t> find_permutation(const shape& s)
return sort_permutation(s.strides(), std::greater<>{});
}
void simplify_reshapes::apply(program& p) const
struct find_reshaper
{
auto end = std::prev(p.end());
for(auto ins : iterator_for(p))
auto matcher() const
{
if(ins == end and ins->name() == "contiguous")
continue;
// Skip possible dead instructions
if(ins->outputs().empty() and ins != end)
continue;
if(is_reshaper(ins))
return match::name(reshaper_names())(match::any_of[match::outputs()](match::name(reshaper_names())));
}
void apply(program& p, match::matcher_result mr) const
{
auto ins = mr.result;
std::vector<instruction_ref> reshapes{ins};
while(is_reshaper(reshapes.back()))
{
if(std::any_of(ins->outputs().begin(), ins->outputs().end(), &is_reshaper))
continue;
// Gather reshapes
std::vector<instruction_ref> reshapes{ins};
while(is_reshaper(reshapes.back()))
{
assert(!reshapes.back()->inputs().empty());
assert(p.has_instruction(reshapes.back()->inputs().front()));
auto input = reshapes.back()->inputs().front();
reshapes.push_back(input);
}
assert(!reshapes.back()->inputs().empty());
assert(p.has_instruction(reshapes.back()->inputs().front()));
auto input = reshapes.back()->inputs().front();
reshapes.push_back(input);
}
std::pair<instruction_ref, instruction_ref> r{p.end(), p.end()};
for(auto start : iterator_for(reshapes))
{
auto last = std::find_if(reshapes.rbegin(), reshapes.rend(), [&](auto&& i) {
return i->get_shape() == (*start)->get_shape() and i != (*start);
});
if(last != reshapes.rend())
{
r = std::make_pair(*start, *last);
break;
}
}
if(r.first != r.second)
std::pair<instruction_ref, instruction_ref> r{p.end(), p.end()};
for(auto start : iterator_for(reshapes))
{
auto last = std::find_if(reshapes.rbegin(), reshapes.rend(), [&](auto&& i) {
return i->get_shape() == (*start)->get_shape() and i != (*start);
});
if(last != reshapes.rend())
{
p.replace_instruction(r.first, r.second);
r = std::make_pair(*start, *last);
break;
}
}
else if(ins->name() == "transpose")
if(r.first != r.second)
{
if(is_transpose_output(ins))
continue;
auto x = ins;
auto t = ins;
std::vector<std::int64_t> dims(ins->get_shape().lens().size());
std::iota(dims.begin(), dims.end(), 0);
do
{
dims = reorder_dims(get_transpose_dims(t), dims);
x = t;
t = find_transpose_input(x);
} while(x != t and t->name() == "transpose");
if(t == ins or t->name() != "transpose")
continue;
if(is_no_transpose(dims))
{
p.replace_instruction(ins, t->inputs().front());
}
else
{
p.replace_instruction(ins, op::transpose{{dims}}, t->inputs().front());
}
p.replace_instruction(r.first, r.second);
}
else if(ins->name() == "concat")
}
};
MIGRAPHX_PRED_MATCHER(is_transpose_output, instruction_ref start)
{
return fix<bool>([&](auto self, auto ins) {
if(ins->outputs().size() != 1)
return false;
if(ins->outputs().front()->name() == "contiguous")
return self(ins->outputs().front());
return ins->outputs().front()->name() == "transpose";
})(start);
}
struct find_transpose
{
auto matcher() const
{
return match::name("transpose")(match::none_of(match::skip_output(match::name("contiguous"))(match::name("transpose"))));
}
void apply(program& p, match::matcher_result mr) const
{
auto ins = mr.result;
auto x = ins;
auto t = ins;
std::vector<std::int64_t> dims(ins->get_shape().lens().size());
std::iota(dims.begin(), dims.end(), 0);
do
{
dims = reorder_dims(get_transpose_dims(t), dims);
x = t;
t = find_transpose_input(x);
} while(x != t and t->name() == "transpose");
if(t == ins or t->name() != "transpose")
return;
if(is_no_transpose(dims))
{
p.replace_instruction(ins, t->inputs().front());
}
else
{
if(ins->inputs().empty())
continue;
auto s = ins->inputs().front()->get_shape();
if(none_of(ins->inputs(), [&](auto i) { return i->get_shape().transposed(); }) or
none_of(ins->inputs(), [&](auto i) { return i->get_shape() == s; }))
continue;
auto op = any_cast<op::concat>(ins->get_operator());
auto permutation = find_permutation(s);
auto ipermutaion = invert_permutation(permutation);
op.axis = ipermutaion[op.axis];
std::vector<instruction_ref> inputs;
std::transform(
ins->inputs().begin(),
ins->inputs().end(),
std::back_inserter(inputs),
[&](auto i) { return p.insert_instruction(ins, op::transpose{permutation}, i); });
auto concat = p.insert_instruction(ins, op, inputs);
auto t = p.insert_instruction(ins, op::transpose{ipermutaion}, concat);
p.replace_instruction(ins, t);
p.replace_instruction(ins, op::transpose{{dims}}, t->inputs().front());
}
}
};
struct find_concat_transpose
{
auto matcher() const
{
return match::name("concat")(match::same_shapes(), match::all_of[match::inputs()](match::transpose_shape()));
}
void apply(program& p, match::matcher_result mr) const
{
auto ins = mr.result;
auto s = ins->inputs().front()->get_shape();
auto op = any_cast<op::concat>(ins->get_operator());
auto permutation = find_permutation(s);
auto ipermutaion = invert_permutation(permutation);
op.axis = ipermutaion[op.axis];
std::vector<instruction_ref> inputs;
std::transform(
ins->inputs().begin(),
ins->inputs().end(),
std::back_inserter(inputs),
[&](auto i) { return p.insert_instruction(ins, op::transpose{permutation}, i); });
auto concat = p.insert_instruction(ins, op, inputs);
auto t = p.insert_instruction(ins, op::transpose{ipermutaion}, concat);
p.replace_instruction(ins, t);
}
};
void simplify_reshapes::apply(program& p) const
{
auto end = std::prev(p.end());
for(auto ins : iterator_for(p))
{
if(ins == end and ins->name() == "contiguous")
continue;
// Skip possible dead instructions
if(ins->outputs().empty() and ins != end)
continue;
match::find_matches(p, ins,
find_reshaper{},
find_transpose{},
find_concat_transpose{}
);
}
}
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -359,6 +359,132 @@ TEST_CASE(match_none_of2)
EXPECT(bool{r.result == p.end()});
}
TEST_CASE(match_output1)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto minus = p.add_instruction(minus_op{}, two, one);
auto sum = p.add_instruction(sum_op{}, minus, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("minus")(match::output(match::name("sum")));
auto r = find_match(p, m);
EXPECT(bool{r.result == minus});
}
TEST_CASE(match_output2)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto minus = p.add_instruction(minus_op{}, two, one);
auto sum = p.add_instruction(sum_op{}, minus, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("@literal")(match::output(match::name("sum")));
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
TEST_CASE(match_skip_output1)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto minus = p.add_instruction(minus_op{}, two, one);
auto sum = p.add_instruction(sum_op{}, minus, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(p, m);
EXPECT(bool{r.result == minus});
}
TEST_CASE(match_skip_output2)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto minus = p.add_instruction(minus_op{}, two, one);
auto minus_pass = p.add_instruction(pass_op{}, minus);
auto sum = p.add_instruction(sum_op{}, minus_pass, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(p, m);
EXPECT(bool{r.result == minus});
}
TEST_CASE(match_skip_output3)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto minus = p.add_instruction(minus_op{}, two, one);
auto minus_pass1 = p.add_instruction(pass_op{}, minus);
auto minus_pass2 = p.add_instruction(pass_op{}, minus_pass1);
auto minus_pass3 = p.add_instruction(pass_op{}, minus_pass2);
auto sum = p.add_instruction(sum_op{}, minus_pass3, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(p, m);
EXPECT(bool{r.result == minus});
}
TEST_CASE(match_skip_output4)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto pass = p.add_instruction(pass_op{}, one);
auto sum = p.add_instruction(sum_op{}, pass, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("@literal")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(p, m);
EXPECT(bool{r.result == two});
}
TEST_CASE(match_skip_output5)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto pass = p.add_instruction(pass_op{}, one);
auto sum1 = p.add_instruction(sum_op{}, pass, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, one);
auto sum3 = p.add_instruction(sum_op{}, sum2, two);
p.add_instruction(pass_op{}, sum3);
auto m = match::name("@literal")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
TEST_CASE(match_skip_output6)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto minus = p.add_instruction(minus_op{}, two, one);
auto sum1 = p.add_instruction(sum_op{}, minus, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, one);
auto sum3 = p.add_instruction(sum_op{}, sum2, two);
p.add_instruction(pass_op{}, sum3);
auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(p, m);
EXPECT(bool{r.result == minus});
}
TEST_CASE(match_skip_output7)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto minus1 = p.add_instruction(minus_op{}, two, one);
auto minus2 = p.add_instruction(minus_op{}, two, minus1);
auto sum = p.add_instruction(sum_op{}, one, minus2);
p.add_instruction(pass_op{}, sum);
auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("minus")));
auto r = find_match(p, m);
EXPECT(bool{r.result == minus1});
}
TEST_CASE(match_bind1)
{
migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment