Commit 00df057a authored by Paul's avatar Paul
Browse files

Rewrite dot add for better efficiency

parent 7271ddbc
......@@ -185,6 +185,43 @@ struct find_mul_add
}
};
// Distributes a dot over an add when the other dot operand is constant:
//   dot(add(x, b), a) -> add(dot(x, a), dot(b, a))
//   dot(a, add(x, b)) -> add(dot(a, x), dot(a, b))
// where "a" and "b" are constant and "x" is the remaining add operand.
// NOTE(review): presumably this lets the constant-only product be folded by a
// later pass -- confirm against the passes that run after this one.
struct find_dot_add
{
    // Matches a "dot" whose two operands (in either order) are:
    //  - an "add" that is used once, has one operand bound to "x" and a
    //    constant operand bound to "b", and whose operands are not both
    //    constant;
    //  - a constant, bound to "a".
    auto matcher() const
    {
        auto add_operands = match::either_arg(0, 1)(
            match::any().bind("x"), match::any_of(match::is_constant()).bind("b"));
        auto both_constant = match::args(match::is_constant(), match::is_constant());
        auto add_term =
            match::name("add")(add_operands, match::none_of(both_constant), match::used_once());
        auto constant_term = match::is_constant().bind("a");
        return match::name("dot")(match::either_arg(0, 1)(add_term, constant_term));
    }

    // Replaces the matched dot with an add of two new dots, one per operand of
    // the original add, keeping the constant "a" on the side it was on.
    void apply(module& m, const match::matcher_result& r) const
    {
        auto dot_ins  = r.result;
        auto constant = r.instructions["a"];
        auto addend   = r.instructions["b"];
        auto variable = r.instructions["x"];
        assert(variable != addend);

        // The dot operand order must be preserved, so record whether the
        // constant was the last input of the matched dot.
        const bool constant_on_right = (constant == dot_ins->inputs().back());

        // Inserts dot(operand, constant) or dot(constant, operand) before the
        // matched dot, depending on which side the constant was on.
        auto distribute = [&](auto operand) {
            auto lhs = constant_on_right ? operand : constant;
            auto rhs = constant_on_right ? constant : operand;
            return m.insert_instruction(dot_ins, make_op("dot"), lhs, rhs);
        };

        auto x_product = distribute(variable);
        auto b_product = distribute(addend);
        m.replace_instruction(dot_ins, make_op("add"), x_product, b_product);
    }
};
struct find_add_lit_broadcast
{
auto matcher() const
......@@ -247,25 +284,34 @@ struct find_inner_broadcast
auto matcher() const
{
return pointwise(
match::nargs(2),
match::args(match::name("broadcast").bind("x"), match::name("broadcast").bind("y")));
match::all_of[match::inputs()](match::broadcast_shape(), match::name("broadcast", "multibroadcast")));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto x_ins = r.instructions["x"];
auto y_ins = r.instructions["y"];
auto xbroadcast = any_cast<op::broadcast>(x_ins->get_operator());
auto ybroadcast = any_cast<op::broadcast>(y_ins->get_operator());
auto inputs = ins->inputs();
if (inputs.empty())
return;
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto i) {
if (contains({"broadcast", "multibroadcast"}, i->name()))
return i->inputs().front();
else
return i;
});
if(xbroadcast.axis != ybroadcast.axis)
if (not std::all_of(inputs.begin(), inputs.end(), [&](auto& x) {
return x->get_shape() == inputs.front()->get_shape();
}))
return;
auto op = m.insert_instruction(
ins, ins->get_operator(), x_ins->inputs().front(), y_ins->inputs().front());
m.replace_instruction(ins, xbroadcast, op);
auto op = m.insert_instruction(ins, ins->get_operator(), inputs);
auto bop = std::find_if(ins->inputs().begin(), ins->inputs().end(), [&](auto i) {
return contains({"broadcast", "multibroadcast"}, i->name());
});
assert(bop != ins->inputs().end());
m.replace_instruction(ins, (*bop)->get_operator(), op);
}
};
......@@ -1025,6 +1071,7 @@ void simplify_algebra::apply(module& m) const
find_mul_conv{},
find_mul_slice_conv{},
find_mul_add{},
find_dot_add{},
find_div_const{},
find_sub_const{},
find_rsqrt{},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment