Commit c10e3bb9 authored by Paul's avatar Paul
Browse files

Reorder two adds as well

parent eeb92b57
......@@ -240,6 +240,25 @@ void find_matches(program& p, Ms&&... ms)
}
}
template<class M>
struct find_skip
{
M m;
M matcher() const
{
return m;
}
void apply(program&, matcher_result) const
{}
};
template<class M>
find_skip<M> make_find_skip(M m)
{
return {m};
}
struct lazy_and
{
template <class F, class G>
......@@ -311,6 +330,12 @@ const constexpr auto all_of = match_fold_f<lazy_and, true, true>{};
const constexpr auto any_of = match_fold_f<lazy_or, false, true>{};
const constexpr auto none_of = match_fold_f<lazy_or, false, false>{};
template<class... Ms>
auto skip_matches(Ms... ms)
{
return make_find_skip(any_of(ms...));
}
inline auto inputs()
{
return [](auto ins, auto f) {
......@@ -371,6 +396,13 @@ MIGRAPHX_BASIC_MATCHER(used_once, matcher_context& ctx, instruction_ref ins)
MIGRAPHX_PRED_MATCHER(is_constant, instruction_ref ins) { return ins->can_eval(); }
MIGRAPHX_BASIC_MATCHER(is_unused, matcher_context& ctx, instruction_ref ins)
{
if (ins->outputs().empty() and ins != std::prev(ctx.not_found()))
return ins;
return ctx.not_found();
}
template <class... Ms>
auto skip_output(Ms... ms)
{
......
......@@ -58,7 +58,9 @@ struct find_mul_add
return match::name("mul")(match::either_arg(0, 1)(
match::name("add")(match::either_arg(0, 1)(
match::any().bind("x"),
match::any_of(conv_const_weights(), match::is_constant()).bind("y"))),
match::any_of(conv_const_weights(), match::is_constant()).bind("y")),
match::none_of(match::args(match::is_constant(), match::is_constant()))
),
match::is_constant().bind("a")));
}
......@@ -77,6 +79,26 @@ struct find_mul_add
};
struct find_add_lit_broadcast
{
auto matcher() const
{
return match::name("add")(
match::either_arg(0, 1)(op_lit_broadcast("add", "a", "x"), lit_broadcast().bind("b")));
}
void apply(program& p, match::matcher_result r) const
{
auto ins = r.result;
auto x_ins = r.instructions["x"];
auto a_ins = r.instructions["a"];
auto b_ins = r.instructions["b"];
auto sumab = p.insert_instruction(ins, op::add{}, a_ins, b_ins);
p.replace_instruction(ins, op::add{}, x_ins, sumab);
}
};
struct find_double_add_lit_broadcast
{
auto matcher() const
{
......@@ -92,11 +114,9 @@ struct find_add_lit_broadcast
auto a_ins = r.instructions["a"];
auto b_ins = r.instructions["b"];
if(a_ins->name() != b_ins->name())
return;
instruction_ref sumab;
if(a_ins->name() == "broadcast")
if(a_ins->name() == "broadcast" and b_ins->name() == "broadcast")
{
if(a_ins->inputs().at(0)->get_shape() != b_ins->inputs().at(0)->get_shape())
return;
......@@ -119,7 +139,7 @@ void simplify_algebra::apply(program& p) const
{
// Run simplifications multiple times
for(int i = 0; i < 4; i++)
match::find_matches(p, find_add_lit_broadcast{}, find_mul_conv{}, find_mul_add{});
match::find_matches(p, match::skip_matches(match::is_unused(), match::is_constant()), find_double_add_lit_broadcast{}, find_add_lit_broadcast{}, find_mul_conv{}, find_mul_add{});
}
} // namespace MIGRAPHX_INLINE_NS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment