"testing/vscode:/vscode.git/clone" did not exist on "2fff0eeca445c318f11b18553638a825c66b38c4"
Commit 9d1f9399 authored by Paul's avatar Paul
Browse files

Update reshape matchers

parent 51b79c51
...@@ -49,7 +49,6 @@ const auto& reshaper_names() ...@@ -49,7 +49,6 @@ const auto& reshaper_names()
static const std::unordered_set<std::string> names = { static const std::unordered_set<std::string> names = {
"flatten", "flatten",
"reshape", "reshape",
"contiguous",
"squeeze", "squeeze",
"unsqueeze" "unsqueeze"
}; };
...@@ -89,38 +88,21 @@ struct find_reshaper ...@@ -89,38 +88,21 @@ struct find_reshaper
{ {
auto matcher() const auto matcher() const
{ {
return match::name(reshaper_names())( auto no_output_reshape = match::none_of[match::outputs()](match::name(reshaper_names()));
match::any_of[match::outputs()](match::name(reshaper_names()))); auto input_reshape = match::arg(0)(match::skip(match::name("contiguous"))(match::name(reshaper_names())));
auto input = match::skip(match::name(reshaper_names()), match::name("contiguous"))(match::arg(0).bind("x"));
return match::name(reshaper_names())(no_output_reshape, input_reshape, input);
} }
void apply(module& m, const match::matcher_result& mr) const void apply(module& m, const match::matcher_result& mr) const
{ {
auto ins = mr.result; auto ins = mr.result;
std::vector<instruction_ref> reshapes{ins}; auto input = mr.instructions["x"];
while(is_reshaper(reshapes.back())) auto dims = ins->get_shape().lens();
{
assert(not reshapes.back()->inputs().empty());
assert(m.has_instruction(reshapes.back()->inputs().front()));
auto input = reshapes.back()->inputs().front();
reshapes.push_back(input);
}
std::pair<instruction_ref, instruction_ref> r{m.end(), m.end()}; if (not input->get_shape().standard())
for(auto start : iterator_for(reshapes)) input = m.insert_instruction(input, make_op("contiguous"), input);
{ m.replace_instruction(ins, make_op("reshape", {{"dims", dims}}), input);
auto last = std::find_if(reshapes.rbegin(), reshapes.rend(), [&](auto&& i) {
return i->get_shape() == (*start)->get_shape() and i != (*start);
});
if(last != reshapes.rend())
{
r = std::make_pair(*start, *last);
break;
}
}
if(r.first != r.second)
{
m.replace_instruction(r.first, r.second);
}
} }
}; };
...@@ -848,9 +830,9 @@ void simplify_reshapes::apply(module& m) const ...@@ -848,9 +830,9 @@ void simplify_reshapes::apply(module& m) const
match::find_matches(m, match::find_matches(m,
find_where_op{}, find_where_op{},
find_resize{}, find_resize{},
find_reshape_cont{},
find_nop_reshapes{}, find_nop_reshapes{},
find_reshaper{}, find_reshaper{},
find_reshape_cont{},
find_transpose{}, find_transpose{},
find_concat_transpose{}, find_concat_transpose{},
find_concat_multibroadcasts{}, find_concat_multibroadcasts{},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment