Commit a7c83234 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

formatting

parent 59f6009b
...@@ -631,16 +631,17 @@ struct find_broadcast_transpose ...@@ -631,16 +631,17 @@ struct find_broadcast_transpose
{ {
auto matcher() const auto matcher() const
{ {
return match::name("multibroadcast")(match::all_of[match::outputs()](match::name("transpose").bind("trans_ins"))); return match::name("multibroadcast")(
match::all_of[match::outputs()](match::name("transpose").bind("trans_ins")));
} }
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto ins_lens = ins->get_shape().lens(); auto ins_lens = ins->get_shape().lens();
auto input = ins->inputs().front(); auto input = ins->inputs().front();
auto input_lens = input->get_shape().lens(); auto input_lens = input->get_shape().lens();
auto trans_ins = r.instructions["trans_ins"]; auto trans_ins = r.instructions["trans_ins"];
auto trans_shape = trans_ins->get_shape(); auto trans_shape = trans_ins->get_shape();
auto permutation = trans_ins->get_operator().to_value()["permutation"].to_vector<int64_t>(); auto permutation = trans_ins->get_operator().to_value()["permutation"].to_vector<int64_t>();
...@@ -669,8 +670,12 @@ struct find_broadcast_transpose ...@@ -669,8 +670,12 @@ struct find_broadcast_transpose
} }
} }
auto unsqueeze_ins = m.insert_instruction(trans_ins, make_op("unsqueeze", {{"axes", unsqueeze_axes}}), input); auto unsqueeze_ins = m.insert_instruction(
auto mbcast_ins = m.insert_instruction(trans_ins, make_op("multibroadcast", {{"out_lens", trans_shape.lens()}}), unsqueeze_ins); trans_ins, make_op("unsqueeze", {{"axes", unsqueeze_axes}}), input);
auto mbcast_ins =
m.insert_instruction(trans_ins,
make_op("multibroadcast", {{"out_lens", trans_shape.lens()}}),
unsqueeze_ins);
m.replace_instruction(trans_ins, mbcast_ins); m.replace_instruction(trans_ins, mbcast_ins);
} }
}; };
......
...@@ -207,16 +207,18 @@ TEST_CASE(broadcast_transpose) ...@@ -207,16 +207,18 @@ TEST_CASE(broadcast_transpose)
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto l = mm->add_parameter("x", {migraphx::shape::float_type, {1024}}); auto l = mm->add_parameter("x", {migraphx::shape::float_type, {1024}});
auto mb = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 3072, 1024}}}), l); auto mb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {1, 3072, 1024}}}), l);
auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), mb); auto t1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), mb);
mm->add_return({t1}); mm->add_return({t1});
run_pass(*mm); run_pass(*mm);
migraphx::program p2; migraphx::program p2;
mm = p2.get_main_module(); mm = p2.get_main_module();
l = mm->add_parameter("x", {migraphx::shape::float_type, {1024}}); l = mm->add_parameter("x", {migraphx::shape::float_type, {1024}});
auto unsqueeze = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0, 2}}}), l); auto unsqueeze = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0, 2}}}), l);
mb = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 1024, 3072}}}), unsqueeze); mb = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 1024, 3072}}}),
unsqueeze);
mm->add_return({mb}); mm->add_return({mb});
EXPECT(p == p2); EXPECT(p == p2);
// EXPECT(not mm->get_output_shapes().back().standard()); // EXPECT(not mm->get_output_shapes().back().standard());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment