"library/vscode:/vscode.git/clone" did not exist on "627054b941166aedd08724575b036ffe0ad74ca8"
Commit 0dc81c3b authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

formatting

parent 00dae07f
......@@ -524,7 +524,7 @@ struct find_inner_broadcast
auto bcast_strides = broadcasts.front()->get_shape().strides().size();
std::vector<size_t> common_axis(bcast_strides, 0);
// go through the strides of each broadcast,
// keep track of values that are equal to 0 in a dimension
// keep track of values that are equal to 0 in a dimension
for(auto i = 0; i < bcast_strides; i++)
{
for(auto j = 0; j < broadcasts.size(); j++)
......@@ -562,15 +562,19 @@ struct find_inner_broadcast
}));
auto op = insert_common_op(m, ins, ins->get_operator(), inputs);
std::vector<shape> broadcast_shapes;
std::transform(broadcasts.begin(), broadcasts.end(), std::back_inserter(broadcast_shapes), [](auto broadcast){
return broadcast->get_shape();
});
std::transform(broadcasts.begin(),
broadcasts.end(),
std::back_inserter(broadcast_shapes),
[](auto broadcast) { return broadcast->get_shape(); });
std::vector<shape> common_shapes;
std::transform(op->inputs().begin(), op->inputs().end(), std::back_inserter(common_shapes), [](auto common){
return common->get_shape();
});
if(broadcast_shapes == common_shapes and std::all_of(op->inputs().begin(), op->inputs().end(), [](auto i){
return i->name() == "broadcast" or i->name() == "multibroadcast";}))
std::transform(op->inputs().begin(),
op->inputs().end(),
std::back_inserter(common_shapes),
[](auto common) { return common->get_shape(); });
if(broadcast_shapes == common_shapes and
std::all_of(op->inputs().begin(), op->inputs().end(), [](auto i) {
return i->name() == "broadcast" or i->name() == "multibroadcast";
}))
return;
m.replace_instruction(ins, broadcasts.front()->get_operator(), op);
}
......
......@@ -645,13 +645,14 @@ struct find_broadcast_transpose
if(not input->get_shape().scalar())
{
// find common shape
auto in_lens = input->get_shape().lens();
auto in_lens = input->get_shape().lens();
int lens_diff = ins_lens.size() - in_lens.size();
if(lens_diff > 0)
{
std::vector<size_t> unsqueeze_axes(lens_diff);
std::iota(unsqueeze_axes.begin(), unsqueeze_axes.end(), 0);
input = m.insert_instruction(bcast_ins, make_op("unsqueeze", {{"axes", unsqueeze_axes}}), input);
input = m.insert_instruction(
bcast_ins, make_op("unsqueeze", {{"axes", unsqueeze_axes}}), input);
}
input = m.insert_instruction(bcast_ins, ins->get_operator(), input);
}
......
......@@ -82,15 +82,18 @@ TEST_CASE(broadcast_transpose_inner_broadcast_generic)
run_pass(m1);
migraphx::module m2;
{
auto l1 = m2.add_parameter("x", {migraphx::shape::float_type, {5, 10}});
auto l2 = m2.add_parameter("y", {migraphx::shape::float_type, {5}});
auto l1 = m2.add_parameter("x", {migraphx::shape::float_type, {5, 10}});
auto l2 = m2.add_parameter("y", {migraphx::shape::float_type, {5}});
auto unsqueeze = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0, 1}}}), l2);
auto transpose = m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), unsqueeze);
auto mb1 = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 5, 10}}}), l1);
auto mb2 = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 5, 10}}}), transpose);
auto transpose = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), unsqueeze);
auto mb1 =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 5, 10}}}), l1);
auto mb2 = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {1, 5, 10}}}), transpose);
auto mul = m2.add_instruction(migraphx::make_op("mul"), mb1, mb2);
auto mb3 =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 5, 10}}}), mul);
auto mb3 = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {3, 5, 10}}}), mul);
m2.add_return({mb3});
}
EXPECT(m1 == m2);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment