Commit a7c83234 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

formatting

parent 59f6009b
......@@ -631,7 +631,8 @@ struct find_broadcast_transpose
{
auto matcher() const
{
return match::name("multibroadcast")(match::all_of[match::outputs()](match::name("transpose").bind("trans_ins")));
return match::name("multibroadcast")(
match::all_of[match::outputs()](match::name("transpose").bind("trans_ins")));
}
void apply(module& m, const match::matcher_result& r) const
......@@ -669,8 +670,12 @@ struct find_broadcast_transpose
}
}
auto unsqueeze_ins = m.insert_instruction(trans_ins, make_op("unsqueeze", {{"axes", unsqueeze_axes}}), input);
auto mbcast_ins = m.insert_instruction(trans_ins, make_op("multibroadcast", {{"out_lens", trans_shape.lens()}}), unsqueeze_ins);
auto unsqueeze_ins = m.insert_instruction(
trans_ins, make_op("unsqueeze", {{"axes", unsqueeze_axes}}), input);
auto mbcast_ins =
m.insert_instruction(trans_ins,
make_op("multibroadcast", {{"out_lens", trans_shape.lens()}}),
unsqueeze_ins);
m.replace_instruction(trans_ins, mbcast_ins);
}
};
......
......@@ -207,7 +207,8 @@ TEST_CASE(broadcast_transpose)
auto* mm = p.get_main_module();
auto l = mm->add_parameter("x", {migraphx::shape::float_type, {1024}});
auto mb = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 3072, 1024}}}), l);
auto mb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {1, 3072, 1024}}}), l);
auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), mb);
mm->add_return({t1});
......@@ -216,7 +217,8 @@ TEST_CASE(broadcast_transpose)
mm = p2.get_main_module();
l = mm->add_parameter("x", {migraphx::shape::float_type, {1024}});
auto unsqueeze = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0, 2}}}), l);
mb = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 1024, 3072}}}), unsqueeze);
mb = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 1024, 3072}}}),
unsqueeze);
mm->add_return({mb});
EXPECT(p == p2);
// EXPECT(not mm->get_output_shapes().back().standard());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment