Commit 5967d68d authored by Paul's avatar Paul
Browse files

Format

parent da5c6162
...@@ -871,7 +871,8 @@ struct find_broadcast_reshaper ...@@ -871,7 +871,8 @@ struct find_broadcast_reshaper
{ {
auto broadcast = auto broadcast =
match::broadcast_shape(match::skip(match::broadcast_shape())(match::any().bind("x"))); match::broadcast_shape(match::skip(match::broadcast_shape())(match::any().bind("x")));
return match::name(reshaper_names())(match::args(match::skip(match::name("contiguous"))(broadcast.bind("broadcast")))); return match::name(reshaper_names())(
match::args(match::skip(match::name("contiguous"))(broadcast.bind("broadcast"))));
} }
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
......
...@@ -1467,20 +1467,23 @@ TEST_CASE(broadcast_transpose_reshape) ...@@ -1467,20 +1467,23 @@ TEST_CASE(broadcast_transpose_reshape)
{ {
migraphx::module m1; migraphx::module m1;
{ {
auto x = m1.add_parameter("x", {migraphx::shape::float_type, {320, 1, 1}}); auto x = m1.add_parameter("x", {migraphx::shape::float_type, {320, 1, 1}});
auto broadcast = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 320, 64, 64}}}), x); auto broadcast = m1.add_instruction(
auto transpose = migraphx::make_op("multibroadcast", {{"out_lens", {2, 320, 64, 64}}}), x);
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), broadcast); auto transpose = m1.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), broadcast);
auto contiguous = m1.add_instruction(migraphx::make_op("contiguous"), transpose); auto contiguous = m1.add_instruction(migraphx::make_op("contiguous"), transpose);
auto reshape = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 4096, 320}}}), contiguous); auto reshape = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 4096, 320}}}),
contiguous);
m1.add_return({reshape}); m1.add_return({reshape});
} }
run_pass(m1); run_pass(m1);
migraphx::module m2; migraphx::module m2;
{ {
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {320, 1, 1}}); auto x = m2.add_parameter("x", {migraphx::shape::float_type, {320, 1, 1}});
auto squeeze = m2.add_instruction(migraphx::make_op("squeeze"), x); auto squeeze = m2.add_instruction(migraphx::make_op("squeeze"), x);
auto broadcast = m2.add_instruction(migraphx::make_op("broadcast", {{"axis", 2}, {"out_lens", {2, 4096, 320}}}), squeeze); auto broadcast = m2.add_instruction(
migraphx::make_op("broadcast", {{"axis", 2}, {"out_lens", {2, 4096, 320}}}), squeeze);
m2.add_return({broadcast}); m2.add_return({broadcast});
} }
EXPECT(m1.sort() == m2.sort()); EXPECT(m1.sort() == m2.sort());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment