Commit c46b5480 authored by Paul's avatar Paul
Browse files

Make sure recursions are use only once

parent efb704c5
......@@ -469,6 +469,12 @@ inline auto name(std::unordered_set<std::string> names)
});
}
template<class... Ts>
inline auto name(std::string s, Ts... xs)
{
return name(std::unordered_set<std::string>{s, xs...});
}
inline auto nargs(std::size_t n)
{
return make_basic_pred_matcher([=](instruction_ref ins) { return ins->inputs().size() == n; });
......
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/program.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/mul.hpp>
......@@ -19,7 +20,7 @@ auto op_lit_broadcast(std::string op, std::string x, std::string y)
auto conv_const_weights()
{
return match::name("convolution")(match::used_once_recursive(4),
return match::name("convolution")(match::used_once(),
match::args(match::any(), match::is_constant().bind("w")));
}
......@@ -61,7 +62,7 @@ struct find_mul_add
match::any().bind("x"),
match::any_of(conv_const_weights(), match::is_constant()).bind("y")),
match::none_of(match::args(match::is_constant(), match::is_constant())),
match::used_once_recursive(4)),
match::used_once()),
match::is_constant().bind("a")));
}
......@@ -135,16 +136,43 @@ struct find_double_add_lit_broadcast
}
};
struct find_inner_broadcast
{
auto matcher() const
{
return match::name("mul", "add")(match::args(match::name("broadcast").bind("x"), match::name("broadcast").bind("y")));
}
void apply(program& p, match::matcher_result r) const
{
auto ins = r.result;
auto x_ins = r.instructions["x"];
auto y_ins = r.instructions["y"];
auto xbroadcast = any_cast<op::broadcast>(x_ins->get_operator());
auto ybroadcast = any_cast<op::broadcast>(y_ins->get_operator());
if (xbroadcast.axis != ybroadcast.axis)
return;
auto op = p.insert_instruction(ins, ins->get_operator(), x_ins->inputs().front(), y_ins->inputs().front());
p.replace_instruction(ins, xbroadcast, op);
}
};
void simplify_algebra::apply(program& p) const
{
// Run simplifications multiple times
for(int i = 0; i < 4; i++)
{
match::find_matches(p,
match::skip_matches(match::is_unused(), match::is_constant()),
find_inner_broadcast{},
find_double_add_lit_broadcast{},
find_add_lit_broadcast{},
find_mul_conv{},
find_mul_add{});
dead_code_elimination{}.apply(p);
}
}
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -265,6 +265,7 @@ struct find_add_relu
auto matcher() const
{
return match::name("gpu::relu")(match::arg(0)(
match::used_once(),
match::any_of(match::name("gpu::add"),
match::name("hip::triadd"),
match::any_of(match::name("@literal"),
......@@ -294,7 +295,7 @@ struct find_triadd
auto matcher() const
{
return match::name("gpu::add")(match::either_arg(0, 1)(
match::name("gpu::add").bind("add"),
match::name("gpu::add")(match::used_once()).bind("add"),
match::any(match::any_of(match::name("@literal"),
match::any_of[match::inputs()](match::standard_shape())))
.bind("input")));
......@@ -325,7 +326,7 @@ struct find_mul_add
auto matcher() const
{
return match::name("gpu::add")(
match::either_arg(0, 1)(match::name("gpu::mul").bind("mul"), match::any().bind("b")));
match::either_arg(0, 1)(match::name("gpu::mul")(match::used_once()).bind("mul"), match::any().bind("b")));
}
void apply(program& p, match::matcher_result r) const
......@@ -349,7 +350,7 @@ struct find_mul_add_relu
{
auto matcher() const
{
return match::name("gpu::relu")(match::arg(0)(match::name("hip::mul_add").bind("mul_add")));
return match::name("gpu::relu")(match::arg(0)(match::name("hip::mul_add")(match::used_once()).bind("mul_add")));
}
void apply(program& p, match::matcher_result r) const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment