Commit 1e34e1ad authored by Paul's avatar Paul
Browse files

Rewrite reshapes and broadcast

parent 5722eb1b
......@@ -865,6 +865,41 @@ struct find_reshape_gemm
}
};
struct find_broadcast_reshaper
{
auto matcher() const
{
auto broadcast = match::broadcast_shape(match::skip(match::broadcast_shape())(match::any().bind("x"))).bind("broadcast");
return match::name(reshaper_names())(match::args(broadcast));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto broadcast_ins = r.instructions["broadcast"];
auto x_ins = r.instructions["x"];
auto broadcast_shape = broadcast_ins->get_shape();
auto result_shape = ins->get_shape();
if (std::accumulate(broadcast_shape.strides().begin(), broadcast_shape.strides().end(), 0) != 1)
return;
auto baxis = std::find(broadcast_shape.strides().begin(), broadcast_shape.strides().end(), 1) - broadcast_shape.strides().begin();
auto relements = result_shape.lens();
std::partial_sum(relements.begin(), relements.end(), relements.begin(), std::multiplies<>{});
auto prefix_elements = std::accumulate(broadcast_shape.lens().begin(), broadcast_shape.lens().begin()+baxis+1, 1, std::multiplies<>{});
auto axis = std::find(relements.begin(), relements.end(), prefix_elements) - relements.begin();
if (axis >= relements.size())
return;
if (x_ins->get_shape().lens().size() > 1)
x_ins = m.insert_instruction(ins, make_op("squeeze"), x_ins);
m.replace_instruction(ins, make_op("broadcast", {{"axis", axis}, {"out_lens", ins->get_shape().lens()}}), x_ins);
}
};
void simplify_reshapes::apply(module& m) const
{
for(int i = 0; i < 4; i++)
......@@ -874,6 +909,7 @@ void simplify_reshapes::apply(module& m) const
find_resize{},
find_nop_reshapes{},
find_reshaper{},
find_broadcast_reshaper{},
// find_reshape_cont{},
find_transpose{},
find_concat_transpose{},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment