"vscode:/vscode.git/clone" did not exist on "09c946b24d23954f50807fbe8d7777cbcce33550"
Commit 0dc81c3b authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

formatting

parent 00dae07f
......@@ -562,15 +562,19 @@ struct find_inner_broadcast
}));
auto op = insert_common_op(m, ins, ins->get_operator(), inputs);
std::vector<shape> broadcast_shapes;
std::transform(broadcasts.begin(), broadcasts.end(), std::back_inserter(broadcast_shapes), [](auto broadcast){
return broadcast->get_shape();
});
std::transform(broadcasts.begin(),
broadcasts.end(),
std::back_inserter(broadcast_shapes),
[](auto broadcast) { return broadcast->get_shape(); });
std::vector<shape> common_shapes;
std::transform(op->inputs().begin(), op->inputs().end(), std::back_inserter(common_shapes), [](auto common){
return common->get_shape();
});
if(broadcast_shapes == common_shapes and std::all_of(op->inputs().begin(), op->inputs().end(), [](auto i){
return i->name() == "broadcast" or i->name() == "multibroadcast";}))
std::transform(op->inputs().begin(),
op->inputs().end(),
std::back_inserter(common_shapes),
[](auto common) { return common->get_shape(); });
if(broadcast_shapes == common_shapes and
std::all_of(op->inputs().begin(), op->inputs().end(), [](auto i) {
return i->name() == "broadcast" or i->name() == "multibroadcast";
}))
return;
m.replace_instruction(ins, broadcasts.front()->get_operator(), op);
}
......
......@@ -651,7 +651,8 @@ struct find_broadcast_transpose
{
std::vector<size_t> unsqueeze_axes(lens_diff);
std::iota(unsqueeze_axes.begin(), unsqueeze_axes.end(), 0);
input = m.insert_instruction(bcast_ins, make_op("unsqueeze", {{"axes", unsqueeze_axes}}), input);
input = m.insert_instruction(
bcast_ins, make_op("unsqueeze", {{"axes", unsqueeze_axes}}), input);
}
input = m.insert_instruction(bcast_ins, ins->get_operator(), input);
}
......
......@@ -85,12 +85,15 @@ TEST_CASE(broadcast_transpose_inner_broadcast_generic)
auto l1 = m2.add_parameter("x", {migraphx::shape::float_type, {5, 10}});
auto l2 = m2.add_parameter("y", {migraphx::shape::float_type, {5}});
auto unsqueeze = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0, 1}}}), l2);
auto transpose = m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), unsqueeze);
auto mb1 = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 5, 10}}}), l1);
auto mb2 = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 5, 10}}}), transpose);
auto transpose = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), unsqueeze);
auto mb1 =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 5, 10}}}), l1);
auto mb2 = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {1, 5, 10}}}), transpose);
auto mul = m2.add_instruction(migraphx::make_op("mul"), mb1, mb2);
auto mb3 =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 5, 10}}}), mul);
auto mb3 = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {3, 5, 10}}}), mul);
m2.add_return({mb3});
}
EXPECT(m1 == m2);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment