Commit 6ce3aefb authored by Paul's avatar Paul
Browse files

Rewrite conv add

parent 863bdfbf
......@@ -52,7 +52,7 @@ auto op_lit_broadcast(std::string op, std::string x, std::string y)
auto conv_const_weights()
{
return match::name("convolution")(match::used_once(),
match::args(match::any(), match::is_constant().bind("w")));
match::args(match::none_of(match::is_constant()), match::is_constant().bind("w")));
}
auto reduction() { return match::name_contains("reduce"); }
......@@ -267,6 +267,31 @@ struct find_dot_add
}
};
struct find_conv_add
{
auto matcher() const
{
auto add = match::name("add")(
match::either_arg(0, 1)(match::any().bind("x"),
match::any_of(match::is_constant()).bind("a")), match::used_once());
return match::name("convolution")(match::used_once(),
match::args(add, match::is_constant().bind("w")));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto a_ins = r.instructions["a"];
auto x_ins = r.instructions["x"];
auto w_ins = r.instructions["w"];
auto conv1 = m.insert_instruction(ins, ins->get_operator(), a_ins, w_ins);
auto conv2 = m.insert_instruction(ins, ins->get_operator(), x_ins, w_ins);
m.replace_instruction(ins, make_op("add"), conv1, conv2);
}
};
struct find_add_lit_broadcast
{
auto matcher() const
......@@ -1220,6 +1245,7 @@ void simplify_algebra::apply(module& m) const
find_neg_unit_ops{},
find_zero_ops{},
find_dot_add{},
find_conv_add{},
find_div_const{},
find_sub_const{},
find_rsqrt{},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment