Unverified Commit abb3efc1 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Simplify nested reshapes (#1932)

The find_reshaper is supposed to do this, but it doesnt work and there were no tests. So I updated for it to work and I added unit tests for it.
parent 75e6618c
......@@ -89,38 +89,23 @@ struct find_reshaper
{
auto matcher() const
{
return match::name(reshaper_names())(
match::any_of[match::outputs()](match::name(reshaper_names())));
auto reshaper = match::name(reshaper_names());
auto contiguous = match::name("contiguous");
auto no_output_reshape = match::none_of[match::outputs()](reshaper);
auto input_reshape = match::arg(0)(match::skip(contiguous)(reshaper));
auto input = match::skip(reshaper, contiguous)(match::any().bind("x"));
return reshaper(no_output_reshape, input_reshape, input);
}
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
std::vector<instruction_ref> reshapes{ins};
while(is_reshaper(reshapes.back()))
{
assert(not reshapes.back()->inputs().empty());
assert(m.has_instruction(reshapes.back()->inputs().front()));
auto input = reshapes.back()->inputs().front();
reshapes.push_back(input);
}
auto ins = mr.result;
auto input = mr.instructions["x"];
auto dims = ins->get_shape().lens();
std::pair<instruction_ref, instruction_ref> r{m.end(), m.end()};
for(auto start : iterator_for(reshapes))
{
auto last = std::find_if(reshapes.rbegin(), reshapes.rend(), [&](auto&& i) {
return i->get_shape() == (*start)->get_shape() and i != (*start);
});
if(last != reshapes.rend())
{
r = std::make_pair(*start, *last);
break;
}
}
if(r.first != r.second)
{
m.replace_instruction(r.first, r.second);
}
if(not input->get_shape().standard())
input = m.insert_instruction(ins, make_op("contiguous"), input);
m.replace_instruction(ins, make_op("reshape", {{"dims", dims}}), input);
}
};
......@@ -804,9 +789,9 @@ void simplify_reshapes::apply(module& m) const
match::find_matches(m,
find_where_op{},
find_resize{},
find_reshape_cont{},
find_nop_reshapes{},
find_reshaper{},
find_reshape_cont{},
find_transpose{},
find_concat_transpose{},
find_concat_multibroadcasts{},
......
......@@ -357,6 +357,106 @@ TEST_CASE(nop_convert)
EXPECT(std::distance(m.begin(), m.end()) == n - 1);
}
TEST_CASE(nested_reshape)
{
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4, 5, 6, 7}};
migraphx::module m1;
{
auto x = m1.add_parameter("x", s);
auto rshp1 =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {1, 2, 3, 4, 5, 42}}}), x);
auto rshp2 =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {1, 2, 12, 5, 42}}}), rshp1);
auto rshp3 =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 12, 5, 42}}}), rshp2);
auto rshp4 =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 60, 42}}}), rshp3);
auto rshp5 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {120, 42}}}), rshp4);
auto rshp6 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {5040}}}), rshp5);
m1.add_return({rshp6});
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", s);
auto rshp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {5040}}}), x);
m2.add_return({rshp});
}
EXPECT(m1 == m2);
}
TEST_CASE(nested_reshape_contiguous)
{
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4, 5, 6, 7}};
migraphx::module m1;
{
auto x = m1.add_parameter("x", s);
auto rshp1 =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {1, 2, 3, 4, 5, 42}}}), x);
auto c1 = m1.add_instruction(migraphx::make_op("contiguous"), rshp1);
auto rshp2 =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {1, 2, 12, 5, 42}}}), c1);
auto c2 = m1.add_instruction(migraphx::make_op("contiguous"), rshp2);
auto rshp3 =
m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 12, 5, 42}}}), c2);
auto c3 = m1.add_instruction(migraphx::make_op("contiguous"), rshp3);
auto rshp4 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 60, 42}}}), c3);
auto c4 = m1.add_instruction(migraphx::make_op("contiguous"), rshp4);
auto rshp5 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {120, 42}}}), c4);
auto c5 = m1.add_instruction(migraphx::make_op("contiguous"), rshp5);
auto rshp6 = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {5040}}}), c5);
m1.add_return({rshp6});
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", s);
auto rshp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {5040}}}), x);
m2.add_return({rshp});
}
EXPECT(m1 == m2);
}
TEST_CASE(nested_reshape_squeeze)
{
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
migraphx::module m1;
{
auto x = m1.add_parameter("x", s);
auto rshp = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {1, 2, 12}}}), x);
auto squeeze = m1.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), rshp);
m1.add_return({squeeze});
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", s);
auto rshp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 12}}}), x);
m2.add_return({rshp});
}
EXPECT(m1 == m2);
}
TEST_CASE(nested_squeeze_reshape)
{
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
migraphx::module m1;
{
auto x = m1.add_parameter("x", s);
auto squeeze = m1.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), x);
auto rshp = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 12}}}), squeeze);
m1.add_return({rshp});
}
run_pass(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", s);
auto rshp = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 12}}}), x);
m2.add_return({rshp});
}
EXPECT(m1 == m2);
}
TEST_CASE(concat_multibroadcasts1)
{
// Broadcasted batch dim, new axis < old axis
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment