Commit f385f43d authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

formatting

parent b4a9a9c5
...@@ -642,8 +642,8 @@ struct find_broadcast_transpose ...@@ -642,8 +642,8 @@ struct find_broadcast_transpose
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto transpose = r.result; auto transpose = r.result;
auto transpose_lens = transpose->get_shape().lens(); auto transpose_lens = transpose->get_shape().lens();
auto bcast_ins = r.instructions["bcast_ins"]; auto bcast_ins = r.instructions["bcast_ins"];
auto input = bcast_ins->inputs().front(); auto input = bcast_ins->inputs().front();
// scalar transformation does not need extra transpose // scalar transformation does not need extra transpose
......
...@@ -74,15 +74,17 @@ TEST_CASE(broadcast_transpose) ...@@ -74,15 +74,17 @@ TEST_CASE(broadcast_transpose)
auto l = m1.add_parameter("x", {migraphx::shape::float_type, {5}}); auto l = m1.add_parameter("x", {migraphx::shape::float_type, {5}});
auto mb = auto mb =
m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 5}}}), l); m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 5}}}), l);
auto t1 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 1}}}), mb); auto t1 =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 1}}}), mb);
m1.add_return({t1}); m1.add_return({t1});
} }
run_pass(m1); run_pass(m1);
migraphx::module m2; migraphx::module m2;
{ {
auto l = m2.add_parameter("x", {migraphx::shape::float_type, {5}}); auto l = m2.add_parameter("x", {migraphx::shape::float_type, {5}});
auto u1 = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0, 1}}}), l); auto u1 = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0, 1}}}), l);
auto t1 = m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 1}}}), u1); auto t1 =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 1}}}), u1);
auto mb = auto mb =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5, 2, 3}}}), t1); m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5, 2, 3}}}), t1);
m2.add_return({mb}); m2.add_return({mb});
...@@ -99,13 +101,14 @@ TEST_CASE(broadcast_transpose_opt) ...@@ -99,13 +101,14 @@ TEST_CASE(broadcast_transpose_opt)
auto l = m1.add_parameter("x", {migraphx::shape::float_type, {5}}); auto l = m1.add_parameter("x", {migraphx::shape::float_type, {5}});
auto mb = auto mb =
m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 5}}}), l); m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 5}}}), l);
auto t1 = m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0, 2}}}), mb); auto t1 =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0, 2}}}), mb);
m1.add_return({t1}); m1.add_return({t1});
} }
run_pass(m1); run_pass(m1);
migraphx::module m2; migraphx::module m2;
{ {
auto l = m2.add_parameter("x", {migraphx::shape::float_type, {5}}); auto l = m2.add_parameter("x", {migraphx::shape::float_type, {5}});
auto u1 = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0, 1}}}), l); auto u1 = m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0, 1}}}), l);
auto mb = auto mb =
m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 2, 5}}}), u1); m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 2, 5}}}), u1);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment