Commit 40e06e7c authored by Paul's avatar Paul
Browse files

Format

parent 1e34e1ad
...@@ -869,34 +869,47 @@ struct find_broadcast_reshaper ...@@ -869,34 +869,47 @@ struct find_broadcast_reshaper
{ {
auto matcher() const auto matcher() const
{ {
auto broadcast = match::broadcast_shape(match::skip(match::broadcast_shape())(match::any().bind("x"))).bind("broadcast"); auto broadcast =
match::broadcast_shape(match::skip(match::broadcast_shape())(match::any().bind("x")))
.bind("broadcast");
return match::name(reshaper_names())(match::args(broadcast)); return match::name(reshaper_names())(match::args(broadcast));
} }
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto broadcast_ins = r.instructions["broadcast"]; auto broadcast_ins = r.instructions["broadcast"];
auto x_ins = r.instructions["x"]; auto x_ins = r.instructions["x"];
auto broadcast_shape = broadcast_ins->get_shape(); auto broadcast_shape = broadcast_ins->get_shape();
auto result_shape = ins->get_shape(); auto result_shape = ins->get_shape();
if (std::accumulate(broadcast_shape.strides().begin(), broadcast_shape.strides().end(), 0) != 1) if(std::accumulate(broadcast_shape.strides().begin(), broadcast_shape.strides().end(), 0) !=
1)
return; return;
auto baxis = std::find(broadcast_shape.strides().begin(), broadcast_shape.strides().end(), 1) - broadcast_shape.strides().begin(); auto baxis =
std::find(broadcast_shape.strides().begin(), broadcast_shape.strides().end(), 1) -
broadcast_shape.strides().begin();
auto relements = result_shape.lens(); auto relements = result_shape.lens();
std::partial_sum(relements.begin(), relements.end(), relements.begin(), std::multiplies<>{}); std::partial_sum(
auto prefix_elements = std::accumulate(broadcast_shape.lens().begin(), broadcast_shape.lens().begin()+baxis+1, 1, std::multiplies<>{}); relements.begin(), relements.end(), relements.begin(), std::multiplies<>{});
auto axis = std::find(relements.begin(), relements.end(), prefix_elements) - relements.begin(); auto prefix_elements = std::accumulate(broadcast_shape.lens().begin(),
if (axis >= relements.size()) broadcast_shape.lens().begin() + baxis + 1,
1,
std::multiplies<>{});
auto axis =
std::find(relements.begin(), relements.end(), prefix_elements) - relements.begin();
if(axis >= relements.size())
return; return;
if (x_ins->get_shape().lens().size() > 1) if(x_ins->get_shape().lens().size() > 1)
x_ins = m.insert_instruction(ins, make_op("squeeze"), x_ins); x_ins = m.insert_instruction(ins, make_op("squeeze"), x_ins);
m.replace_instruction(ins, make_op("broadcast", {{"axis", axis}, {"out_lens", ins->get_shape().lens()}}), x_ins); m.replace_instruction(
ins,
make_op("broadcast", {{"axis", axis}, {"out_lens", ins->get_shape().lens()}}),
x_ins);
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment