Commit c10e3bb9 authored by Paul's avatar Paul
Browse files

Reorder two adds as well

parent eeb92b57
...@@ -240,6 +240,25 @@ void find_matches(program& p, Ms&&... ms) ...@@ -240,6 +240,25 @@ void find_matches(program& p, Ms&&... ms)
} }
} }
template<class M>
struct find_skip
{
M m;
M matcher() const
{
return m;
}
void apply(program&, matcher_result) const
{}
};
template<class M>
find_skip<M> make_find_skip(M m)
{
return {m};
}
struct lazy_and struct lazy_and
{ {
template <class F, class G> template <class F, class G>
...@@ -311,6 +330,12 @@ const constexpr auto all_of = match_fold_f<lazy_and, true, true>{}; ...@@ -311,6 +330,12 @@ const constexpr auto all_of = match_fold_f<lazy_and, true, true>{};
const constexpr auto any_of = match_fold_f<lazy_or, false, true>{}; const constexpr auto any_of = match_fold_f<lazy_or, false, true>{};
const constexpr auto none_of = match_fold_f<lazy_or, false, false>{}; const constexpr auto none_of = match_fold_f<lazy_or, false, false>{};
template<class... Ms>
auto skip_matches(Ms... ms)
{
return make_find_skip(any_of(ms...));
}
inline auto inputs() inline auto inputs()
{ {
return [](auto ins, auto f) { return [](auto ins, auto f) {
...@@ -371,6 +396,13 @@ MIGRAPHX_BASIC_MATCHER(used_once, matcher_context& ctx, instruction_ref ins) ...@@ -371,6 +396,13 @@ MIGRAPHX_BASIC_MATCHER(used_once, matcher_context& ctx, instruction_ref ins)
MIGRAPHX_PRED_MATCHER(is_constant, instruction_ref ins) { return ins->can_eval(); } MIGRAPHX_PRED_MATCHER(is_constant, instruction_ref ins) { return ins->can_eval(); }
MIGRAPHX_BASIC_MATCHER(is_unused, matcher_context& ctx, instruction_ref ins)
{
if (ins->outputs().empty() and ins != std::prev(ctx.not_found()))
return ins;
return ctx.not_found();
}
template <class... Ms> template <class... Ms>
auto skip_output(Ms... ms) auto skip_output(Ms... ms)
{ {
......
...@@ -58,7 +58,9 @@ struct find_mul_add ...@@ -58,7 +58,9 @@ struct find_mul_add
return match::name("mul")(match::either_arg(0, 1)( return match::name("mul")(match::either_arg(0, 1)(
match::name("add")(match::either_arg(0, 1)( match::name("add")(match::either_arg(0, 1)(
match::any().bind("x"), match::any().bind("x"),
match::any_of(conv_const_weights(), match::is_constant()).bind("y"))), match::any_of(conv_const_weights(), match::is_constant()).bind("y")),
match::none_of(match::args(match::is_constant(), match::is_constant()))
),
match::is_constant().bind("a"))); match::is_constant().bind("a")));
} }
...@@ -77,6 +79,26 @@ struct find_mul_add ...@@ -77,6 +79,26 @@ struct find_mul_add
}; };
struct find_add_lit_broadcast struct find_add_lit_broadcast
{
auto matcher() const
{
return match::name("add")(
match::either_arg(0, 1)(op_lit_broadcast("add", "a", "x"), lit_broadcast().bind("b")));
}
void apply(program& p, match::matcher_result r) const
{
auto ins = r.result;
auto x_ins = r.instructions["x"];
auto a_ins = r.instructions["a"];
auto b_ins = r.instructions["b"];
auto sumab = p.insert_instruction(ins, op::add{}, a_ins, b_ins);
p.replace_instruction(ins, op::add{}, x_ins, sumab);
}
};
struct find_double_add_lit_broadcast
{ {
auto matcher() const auto matcher() const
{ {
...@@ -92,11 +114,9 @@ struct find_add_lit_broadcast ...@@ -92,11 +114,9 @@ struct find_add_lit_broadcast
auto a_ins = r.instructions["a"]; auto a_ins = r.instructions["a"];
auto b_ins = r.instructions["b"]; auto b_ins = r.instructions["b"];
if(a_ins->name() != b_ins->name())
return;
instruction_ref sumab; instruction_ref sumab;
if(a_ins->name() == "broadcast") if(a_ins->name() == "broadcast" and b_ins->name() == "broadcast")
{ {
if(a_ins->inputs().at(0)->get_shape() != b_ins->inputs().at(0)->get_shape()) if(a_ins->inputs().at(0)->get_shape() != b_ins->inputs().at(0)->get_shape())
return; return;
...@@ -119,7 +139,7 @@ void simplify_algebra::apply(program& p) const ...@@ -119,7 +139,7 @@ void simplify_algebra::apply(program& p) const
{ {
// Run simplifications multiple times // Run simplifications multiple times
for(int i = 0; i < 4; i++) for(int i = 0; i < 4; i++)
match::find_matches(p, find_add_lit_broadcast{}, find_mul_conv{}, find_mul_add{}); match::find_matches(p, match::skip_matches(match::is_unused(), match::is_constant()), find_double_add_lit_broadcast{}, find_add_lit_broadcast{}, find_mul_conv{}, find_mul_add{});
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment