Unverified Commit cf0b6d6d authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

DepthToSpace and pointwise unary operations fusion (#986)

In migraphx, DepthToSpace (d2s) is implemented as reshape --> transpose --> contiguous --> reshape.

This PR adds matcher to find d2s + unary pointwise ops.

Application of the matcher moves the pointwise unary operation before the contiguous and reshape of the d2s.
So it becomes
reshape --> transpose --> unary --> contiguous --> reshape.

Motivation is that, later pointwise module would be created out of unary --> contiguous --> reshape. Codegen for this pointwise module can write out buffer such that explicit contiguous and reshape wouldn't be required.

This transformation is not always guaranteed to improve performance, since unary op will operate on non-standard shape. So, we would need some tuning mechanism to make decision.

#905 pending PR for binary operations.
parent 912c8d22
File mode changed from 100755 to 100644
......@@ -539,6 +539,44 @@ struct find_reshape_cont
}
};
// match sequence of transpose --> contiguous --> reshaper_op
auto match_transpose_contiguous_reshaper()
{
return match::name({"reshape", "squeeze", "unsqueeze"})(
match::used_once(),
match::args(
match::name("contiguous")(
match::used_once(), match::args(match::transpose_shape().bind("trans_ins")))
.bind("cont_ins")))
.bind("reshaper_ins");
};
// finds the pattern of transpose --> contiguous --> reshaper_op --> unary
// application of this matcher moves the unary operation before the contiguous so it becomes
// transpose --> unary --> contiguous --> reshaper_op. later pointwise sub-module can be created out
// of unary --> contiguous --> reshaper_op. Such pattern appears in depthToSpace or spaceToDepth
// operator.
struct find_transpose_contiguous_reshaper_unary
{
auto matcher() const
{
return pointwise(match::used_once(), match::args(match_transpose_contiguous_reshaper()));
}
void apply(module& p, match::matcher_result r) const
{
auto ins = r.result;
auto reshaper_ins = r.instructions["reshaper_ins"];
auto trans_ins = r.instructions["trans_ins"];
auto cont_ins = r.instructions["cont_ins"];
auto unary_op_name = ins->get_operator().name();
auto unary_ins = p.insert_instruction(cont_ins, make_op(unary_op_name), trans_ins);
auto new_cont_ins = p.insert_instruction(cont_ins, make_op("contiguous"), unary_ins);
// older cont and reshape are removed by deadcode elimination
p.replace_instruction(ins, reshaper_ins->get_operator(), new_cont_ins);
}
};
void simplify_reshapes::apply(module& p) const
{
for(int i = 0; i < 2; i++)
......@@ -553,7 +591,8 @@ void simplify_reshapes::apply(module& p) const
find_concat_transpose{},
find_nested_convert{},
find_nested_slice{},
find_nested_concat{});
find_nested_concat{},
find_transpose_contiguous_reshaper_unary{});
dead_code_elimination{}.apply(p);
}
}
......
......@@ -945,4 +945,90 @@ TEST_CASE(reshape_cont_nonpw)
EXPECT(m1 == create_module());
}
TEST_CASE(transpose_contiguous_reshape_unary)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::float_type, {2, 8, 5, 5}});
auto reshape_ins1 =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), x);
auto transpose_ins = m1.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), reshape_ins1);
auto cont_ins = m1.add_instruction(migraphx::make_op("contiguous"), transpose_ins);
auto reshape_ins2 =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 10, 10}}}), cont_ins);
auto relu = m1.add_instruction(migraphx::make_op("relu"), reshape_ins2);
m1.add_instruction(pass_op{}, relu);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 8, 5, 5}});
auto reshape_ins1 =
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), x);
auto transpose_ins = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), reshape_ins1);
auto relu = m2.add_instruction(migraphx::make_op("relu"), transpose_ins);
auto cont_ins = m2.add_instruction(migraphx::make_op("contiguous"), relu);
auto reshape_ins2 =
m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 10, 10}}}), cont_ins);
m2.add_instruction(pass_op{}, reshape_ins2);
}
EXPECT(m1 == m2);
}
TEST_CASE(transpose_contiguous_squeeze_unary)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::float_type, {2, 8, 1, 5}});
auto transpose_ins =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x);
auto cont_ins = m1.add_instruction(migraphx::make_op("contiguous"), transpose_ins);
auto sq_ins = m1.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), cont_ins);
auto rsqrt = m1.add_instruction(migraphx::make_op("rsqrt"), sq_ins);
m1.add_instruction(pass_op{}, rsqrt);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 8, 1, 5}});
auto transpose_ins =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x);
auto rsqrt = m2.add_instruction(migraphx::make_op("rsqrt"), transpose_ins);
auto cont_ins = m2.add_instruction(migraphx::make_op("contiguous"), rsqrt);
auto sq_ins = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {1}}}), cont_ins);
m2.add_instruction(pass_op{}, sq_ins);
}
EXPECT(m1 == m2);
}
TEST_CASE(transpose_contiguous_unsqueeze_unary)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::float_type, {2, 8, 5, 5}});
auto transpose_ins =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x);
auto cont_ins = m1.add_instruction(migraphx::make_op("contiguous"), transpose_ins);
auto unsq_ins =
m1.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), cont_ins);
auto round = m1.add_instruction(migraphx::make_op("round"), unsq_ins);
m1.add_instruction(pass_op{}, round);
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 8, 5, 5}});
auto transpose_ins =
m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x);
auto round = m2.add_instruction(migraphx::make_op("round"), transpose_ins);
auto cont_ins = m2.add_instruction(migraphx::make_op("contiguous"), round);
auto unsq_ins =
m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), cont_ins);
m2.add_instruction(pass_op{}, unsq_ins);
}
EXPECT(m1 == m2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment