Commit 5967d68d authored by Paul's avatar Paul
Browse files

Format

parent da5c6162
......@@ -871,7 +871,8 @@ struct find_broadcast_reshaper
{
auto broadcast =
match::broadcast_shape(match::skip(match::broadcast_shape())(match::any().bind("x")));
return match::name(reshaper_names())(match::args(match::skip(match::name("contiguous"))(broadcast.bind("broadcast"))));
return match::name(reshaper_names())(
match::args(match::skip(match::name("contiguous"))(broadcast.bind("broadcast"))));
}
void apply(module& m, const match::matcher_result& r) const
......
......@@ -1467,20 +1467,23 @@ TEST_CASE(broadcast_transpose_reshape)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::float_type, {320, 1, 1}});
auto broadcast = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 320, 64, 64}}}), x);
auto transpose =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), broadcast);
auto x = m1.add_parameter("x", {migraphx::shape::float_type, {320, 1, 1}});
auto broadcast = m1.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 320, 64, 64}}}), x);
auto transpose = m1.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), broadcast);
auto contiguous = m1.add_instruction(migraphx::make_op("contiguous"), transpose);
auto reshape = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 4096, 320}}}), contiguous);
auto reshape = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 4096, 320}}}),
contiguous);
m1.add_return({reshape});
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {320, 1, 1}});
auto squeeze = m2.add_instruction(migraphx::make_op("squeeze"), x);
auto broadcast = m2.add_instruction(migraphx::make_op("broadcast", {{"axis", 2}, {"out_lens", {2, 4096, 320}}}), squeeze);
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {320, 1, 1}});
auto squeeze = m2.add_instruction(migraphx::make_op("squeeze"), x);
auto broadcast = m2.add_instruction(
migraphx::make_op("broadcast", {{"axis", 2}, {"out_lens", {2, 4096, 320}}}), squeeze);
m2.add_return({broadcast});
}
EXPECT(m1.sort() == m2.sort());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment