Commit a7c83234 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

formatting

parent 59f6009b
...@@ -631,7 +631,8 @@ struct find_broadcast_transpose ...@@ -631,7 +631,8 @@ struct find_broadcast_transpose
{ {
auto matcher() const auto matcher() const
{ {
return match::name("multibroadcast")(match::all_of[match::outputs()](match::name("transpose").bind("trans_ins"))); return match::name("multibroadcast")(
match::all_of[match::outputs()](match::name("transpose").bind("trans_ins")));
} }
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
...@@ -669,8 +670,12 @@ struct find_broadcast_transpose ...@@ -669,8 +670,12 @@ struct find_broadcast_transpose
} }
} }
auto unsqueeze_ins = m.insert_instruction(trans_ins, make_op("unsqueeze", {{"axes", unsqueeze_axes}}), input); auto unsqueeze_ins = m.insert_instruction(
auto mbcast_ins = m.insert_instruction(trans_ins, make_op("multibroadcast", {{"out_lens", trans_shape.lens()}}), unsqueeze_ins); trans_ins, make_op("unsqueeze", {{"axes", unsqueeze_axes}}), input);
auto mbcast_ins =
m.insert_instruction(trans_ins,
make_op("multibroadcast", {{"out_lens", trans_shape.lens()}}),
unsqueeze_ins);
m.replace_instruction(trans_ins, mbcast_ins); m.replace_instruction(trans_ins, mbcast_ins);
} }
}; };
......
...@@ -207,7 +207,8 @@ TEST_CASE(broadcast_transpose) ...@@ -207,7 +207,8 @@ TEST_CASE(broadcast_transpose)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l = mm->add_parameter("x", {migraphx::shape::float_type, {1024}}); auto l = mm->add_parameter("x", {migraphx::shape::float_type, {1024}});
auto mb = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 3072, 1024}}}), l); auto mb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {1, 3072, 1024}}}), l);
auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), mb); auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), mb);
mm->add_return({t1}); mm->add_return({t1});
...@@ -216,7 +217,8 @@ TEST_CASE(broadcast_transpose) ...@@ -216,7 +217,8 @@ TEST_CASE(broadcast_transpose)
mm = p2.get_main_module(); mm = p2.get_main_module();
l = mm->add_parameter("x", {migraphx::shape::float_type, {1024}}); l = mm->add_parameter("x", {migraphx::shape::float_type, {1024}});
auto unsqueeze = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0, 2}}}), l); auto unsqueeze = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0, 2}}}), l);
mb = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 1024, 3072}}}), unsqueeze); mb = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 1024, 3072}}}),
unsqueeze);
mm->add_return({mb}); mm->add_return({mb});
EXPECT(p == p2); EXPECT(p == p2);
// EXPECT(not mm->get_output_shapes().back().standard()); // EXPECT(not mm->get_output_shapes().back().standard());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment