Unverified Commit eb6abd27 authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Add tests for the DepthToSpace+Binary pointwise operations fusion (#987)

In migraphx, DepthToSpace (d2s) is implemented as reshape --> transpose --> contiguous --> reshape.

If there is trailing binary pointwise operator after depthToSpace then, migraphx can move binary operator before contiguous and reshape of the depthtospce.

So, it becomes reshape-->transpose-->binary_op-->contiguous-->reshape.

Explicit contiguous wouldn't be required since binary_op outputs standard shape. So, it becomes reshape-->transpose-->binary-->reshape.

simplify_reshapes already has matcher that can do this transformation. This PR adds test for cases like depthtospace +binary op.

solves #905
parent c98b22d8
...@@ -560,7 +560,9 @@ struct find_transpose_contiguous_reshaper_unary ...@@ -560,7 +560,9 @@ struct find_transpose_contiguous_reshaper_unary
{ {
auto matcher() const auto matcher() const
{ {
return pointwise(match::used_once(), match::args(match_transpose_contiguous_reshaper())); return pointwise(match::used_once(),
match::nargs(1),
match::args(match_transpose_contiguous_reshaper()));
} }
void apply(module& p, match::matcher_result r) const void apply(module& p, match::matcher_result r) const
......
...@@ -1031,4 +1031,91 @@ TEST_CASE(transpose_contiguous_unsqueeze_unary) ...@@ -1031,4 +1031,91 @@ TEST_CASE(transpose_contiguous_unsqueeze_unary)
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
TEST_CASE(transpose_contiguous_reshape_binary_packed)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::float_type, {2, 128, 28, 28}});
auto w1 = m1.add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 1, 1}}));
auto conv1 = m1.add_instruction(
migraphx::make_op("convolution",
{{"padding", {0, 0}}, {"stride", {1, 1}}, {"dilation", {1, 1}}}),
x,
w1); // (2, 256, 28, 28)
auto w2 = m1.add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {512, 256, 1, 1}}));
auto conv2 = m1.add_instruction(
migraphx::make_op("convolution",
{{"padding", {0, 0}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}),
conv1,
w2); // (2, 512, 14, 14)
auto conv2_rsp1 = m1.add_instruction(
migraphx::make_op("reshape", {{"dims", {2, 2, 2, 128, 14, 14}}}), conv2);
auto conv2_trans = m1.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), conv2_rsp1);
auto conv2_cont = m1.add_instruction(migraphx::make_op("contiguous"), conv2_trans);
auto conv2_rsp2 = m1.add_instruction(
migraphx::make_op("reshape", {{"dims", {2, 128, 28, 28}}}), conv2_cont);
auto add_ins = m1.add_instruction(migraphx::make_op("add"), conv2_rsp2, x);
m1.add_instruction(pass_op{}, add_ins);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 128, 28, 28}});
auto w1 = m2.add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 1, 1}}));
auto conv1 = m2.add_instruction(
migraphx::make_op("convolution",
{{"padding", {0, 0}}, {"stride", {1, 1}}, {"dilation", {1, 1}}}),
x,
w1); // (2, 256, 28, 28)
auto w2 = m2.add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {512, 256, 1, 1}}));
auto conv2 = m2.add_instruction(
migraphx::make_op("convolution",
{{"padding", {0, 0}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}),
conv1,
w2); // (2, 512, 14, 14)
auto conv2_rsp = m2.add_instruction(
migraphx::make_op("reshape", {{"dims", {2, 2, 2, 128, 14, 14}}}), conv2);
auto conv2_trans = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), conv2_rsp);
auto x_rsp =
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 128, 14, 2, 14, 2}}}), x);
auto add_ins = m2.add_instruction(migraphx::make_op("add"), conv2_trans, x_rsp);
auto add_rsp =
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 128, 28, 28}}}), add_ins);
m2.add_instruction(pass_op{}, add_rsp);
}
EXPECT(m1 == m2);
}
TEST_CASE(transpose_contiguous_reshape_binary_broadcast)
{
migraphx::module m1;
{
migraphx::shape sx{migraphx::shape::float_type, {4}};
migraphx::shape sy{migraphx::shape::float_type, {2, 6, 2, 2}};
auto x = m1.add_parameter("x", sx);
auto y = m1.add_parameter("y", sy);
auto x_brcst = m1.add_instruction(
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 4, 6}}}), x);
auto y_trans =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), y);
auto y_cont = m1.add_instruction(migraphx::make_op("contiguous"), y_trans);
auto y_rsp =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 4, 6}}}), y_cont);
auto r = m1.add_instruction(migraphx::make_op("add"), y_rsp, x_brcst);
m1.add_return({r});
}
migraphx::module m2 = m1;
run_pass(m1);
EXPECT(m1 == m2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment