Commit 9d1f9399 authored by Paul's avatar Paul
Browse files

Update reshape matchers

parent 51b79c51
......@@ -49,7 +49,6 @@ const auto& reshaper_names()
static const std::unordered_set<std::string> names = {
"flatten",
"reshape",
"contiguous",
"squeeze",
"unsqueeze"
};
......@@ -89,38 +88,21 @@ struct find_reshaper
{
auto matcher() const
{
return match::name(reshaper_names())(
match::any_of[match::outputs()](match::name(reshaper_names())));
auto no_output_reshape = match::none_of[match::outputs()](match::name(reshaper_names()));
auto input_reshape = match::arg(0)(match::skip(match::name("contiguous"))(match::name(reshaper_names())));
auto input = match::skip(match::name(reshaper_names()), match::name("contiguous"))(match::arg(0).bind("x"));
return match::name(reshaper_names())(no_output_reshape, input_reshape, input);
}
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
std::vector<instruction_ref> reshapes{ins};
while(is_reshaper(reshapes.back()))
{
assert(not reshapes.back()->inputs().empty());
assert(m.has_instruction(reshapes.back()->inputs().front()));
auto input = reshapes.back()->inputs().front();
reshapes.push_back(input);
}
auto input = mr.instructions["x"];
auto dims = ins->get_shape().lens();
std::pair<instruction_ref, instruction_ref> r{m.end(), m.end()};
for(auto start : iterator_for(reshapes))
{
auto last = std::find_if(reshapes.rbegin(), reshapes.rend(), [&](auto&& i) {
return i->get_shape() == (*start)->get_shape() and i != (*start);
});
if(last != reshapes.rend())
{
r = std::make_pair(*start, *last);
break;
}
}
if(r.first != r.second)
{
m.replace_instruction(r.first, r.second);
}
if (not input->get_shape().standard())
input = m.insert_instruction(input, make_op("contiguous"), input);
m.replace_instruction(ins, make_op("reshape", {{"dims", dims}}), input);
}
};
......@@ -848,9 +830,9 @@ void simplify_reshapes::apply(module& m) const
match::find_matches(m,
find_where_op{},
find_resize{},
find_reshape_cont{},
find_nop_reshapes{},
find_reshaper{},
find_reshape_cont{},
find_transpose{},
find_concat_transpose{},
find_concat_multibroadcasts{},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment