Commit 40e06e7c authored by Paul's avatar Paul
Browse files

Format

parent 1e34e1ad
......@@ -869,34 +869,47 @@ struct find_broadcast_reshaper
{
auto matcher() const
{
auto broadcast = match::broadcast_shape(match::skip(match::broadcast_shape())(match::any().bind("x"))).bind("broadcast");
auto broadcast =
match::broadcast_shape(match::skip(match::broadcast_shape())(match::any().bind("x")))
.bind("broadcast");
return match::name(reshaper_names())(match::args(broadcast));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto broadcast_ins = r.instructions["broadcast"];
auto x_ins = r.instructions["x"];
auto broadcast_ins = r.instructions["broadcast"];
auto x_ins = r.instructions["x"];
auto broadcast_shape = broadcast_ins->get_shape();
auto result_shape = ins->get_shape();
auto result_shape = ins->get_shape();
if (std::accumulate(broadcast_shape.strides().begin(), broadcast_shape.strides().end(), 0) != 1)
if(std::accumulate(broadcast_shape.strides().begin(), broadcast_shape.strides().end(), 0) !=
1)
return;
auto baxis = std::find(broadcast_shape.strides().begin(), broadcast_shape.strides().end(), 1) - broadcast_shape.strides().begin();
auto baxis =
std::find(broadcast_shape.strides().begin(), broadcast_shape.strides().end(), 1) -
broadcast_shape.strides().begin();
auto relements = result_shape.lens();
std::partial_sum(relements.begin(), relements.end(), relements.begin(), std::multiplies<>{});
auto prefix_elements = std::accumulate(broadcast_shape.lens().begin(), broadcast_shape.lens().begin()+baxis+1, 1, std::multiplies<>{});
auto axis = std::find(relements.begin(), relements.end(), prefix_elements) - relements.begin();
if (axis >= relements.size())
std::partial_sum(
relements.begin(), relements.end(), relements.begin(), std::multiplies<>{});
auto prefix_elements = std::accumulate(broadcast_shape.lens().begin(),
broadcast_shape.lens().begin() + baxis + 1,
1,
std::multiplies<>{});
auto axis =
std::find(relements.begin(), relements.end(), prefix_elements) - relements.begin();
if(axis >= relements.size())
return;
if (x_ins->get_shape().lens().size() > 1)
if(x_ins->get_shape().lens().size() > 1)
x_ins = m.insert_instruction(ins, make_op("squeeze"), x_ins);
m.replace_instruction(ins, make_op("broadcast", {{"axis", axis}, {"out_lens", ins->get_shape().lens()}}), x_ins);
m.replace_instruction(
ins,
make_op("broadcast", {{"axis", axis}, {"out_lens", ins->get_shape().lens()}}),
x_ins);
}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment