"src/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "aa7b76b50895fe52675877619fea508d07d06993"
Commit 0dc81c3b authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

formatting

parent 00dae07f
...@@ -524,7 +524,7 @@ struct find_inner_broadcast ...@@ -524,7 +524,7 @@ struct find_inner_broadcast
auto bcast_strides = broadcasts.front()->get_shape().strides().size(); auto bcast_strides = broadcasts.front()->get_shape().strides().size();
std::vector<size_t> common_axis(bcast_strides, 0); std::vector<size_t> common_axis(bcast_strides, 0);
// go through the strides of each broadcast, // go through the strides of each broadcast,
// keep track of values that are equal to 0 in a dimension // keep track of values that are equal to 0 in a dimension
for(auto i = 0; i < bcast_strides; i++) for(auto i = 0; i < bcast_strides; i++)
{ {
for(auto j = 0; j < broadcasts.size(); j++) for(auto j = 0; j < broadcasts.size(); j++)
...@@ -562,15 +562,19 @@ struct find_inner_broadcast ...@@ -562,15 +562,19 @@ struct find_inner_broadcast
})); }));
auto op = insert_common_op(m, ins, ins->get_operator(), inputs); auto op = insert_common_op(m, ins, ins->get_operator(), inputs);
std::vector<shape> broadcast_shapes; std::vector<shape> broadcast_shapes;
std::transform(broadcasts.begin(), broadcasts.end(), std::back_inserter(broadcast_shapes), [](auto broadcast){ std::transform(broadcasts.begin(),
return broadcast->get_shape(); broadcasts.end(),
}); std::back_inserter(broadcast_shapes),
[](auto broadcast) { return broadcast->get_shape(); });
std::vector<shape> common_shapes; std::vector<shape> common_shapes;
std::transform(op->inputs().begin(), op->inputs().end(), std::back_inserter(common_shapes), [](auto common){ std::transform(op->inputs().begin(),
return common->get_shape(); op->inputs().end(),
}); std::back_inserter(common_shapes),
if(broadcast_shapes == common_shapes and std::all_of(op->inputs().begin(), op->inputs().end(), [](auto i){ [](auto common) { return common->get_shape(); });
return i->name() == "broadcast" or i->name() == "multibroadcast";})) if(broadcast_shapes == common_shapes and
std::all_of(op->inputs().begin(), op->inputs().end(), [](auto i) {
return i->name() == "broadcast" or i->name() == "multibroadcast";
}))
return; return;
m.replace_instruction(ins, broadcasts.front()->get_operator(), op); m.replace_instruction(ins, broadcasts.front()->get_operator(), op);
} }
......
...@@ -645,13 +645,14 @@ struct find_broadcast_transpose ...@@ -645,13 +645,14 @@ struct find_broadcast_transpose
if(not input->get_shape().scalar()) if(not input->get_shape().scalar())
{ {
// find common shape // find common shape
auto in_lens = input->get_shape().lens(); auto in_lens = input->get_shape().lens();
int lens_diff = ins_lens.size() - in_lens.size(); int lens_diff = ins_lens.size() - in_lens.size();
if(lens_diff > 0) if(lens_diff > 0)
{ {
std::vector<size_t> unsqueeze_axes(lens_diff); std::vector<size_t> unsqueeze_axes(lens_diff);
std::iota(unsqueeze_axes.begin(), unsqueeze_axes.end(), 0); std::iota(unsqueeze_axes.begin(), unsqueeze_axes.end(), 0);
input = m.insert_instruction(bcast_ins, make_op("unsqueeze", {{"axes", unsqueeze_axes}}), input); input = m.insert_instruction(
bcast_ins, make_op("unsqueeze", {{"axes", unsqueeze_axes}}), input);
} }
input = m.insert_instruction(bcast_ins, ins->get_operator(), input); input = m.insert_instruction(bcast_ins, ins->get_operator(), input);
} }
......
...@@ -82,15 +82,18 @@ TEST_CASE(broadcast_transpose_inner_broadcast_generic) ...@@ -82,15 +82,18 @@ TEST_CASE(broadcast_transpose_inner_broadcast_generic)
run_pass(m1); run_pass(m1);
migraphx::module m2; migraphx::module m2;
{ {
auto l1 = m2.add_parameter("x", {migraphx::shape::float_type, {5, 10}}); auto l1 = m2.add_parameter("x", {migraphx::shape::float_type, {5, 10}});
auto l2 = m2.add_parameter("y", {migraphx::shape::float_type, {5}}); auto l2 = m2.add_parameter("y", {migraphx::shape::float_type, {5}});
auto unsqueeze = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0, 1}}}), l2); auto unsqueeze = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0, 1}}}), l2);
auto transpose = m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), unsqueeze); auto transpose = m2.add_instruction(
auto mb1 = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 5, 10}}}), l1); migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), unsqueeze);
auto mb2 = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 5, 10}}}), transpose); auto mb1 =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 5, 10}}}), l1);
auto mb2 = m2.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {1, 5, 10}}}), transpose);
auto mul = m2.add_instruction(migraphx::make_op("mul"), mb1, mb2); auto mul = m2.add_instruction(migraphx::make_op("mul"), mb1, mb2);
auto mb3 = auto mb3 = m2.add_instruction(
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 5, 10}}}), mul); migraphx::make_op("multibroadcast", {{"out_lens", {3, 5, 10}}}), mul);
m2.add_return({mb3}); m2.add_return({mb3});
} }
EXPECT(m1 == m2); EXPECT(m1 == m2);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment