Commit 4c185e65 authored by Paul's avatar Paul
Browse files

Fuse across binary slices

parent 40118191
...@@ -559,6 +559,42 @@ struct find_splits ...@@ -559,6 +559,42 @@ struct find_splits
return true; return true;
} }
static std::vector<instruction_ref> split_nary(const std::vector<instruction_ref>& group)
{
// All inputs have the same slices
if (not std::all_of(group.begin(), group.end(), [](auto ins) {
if (ins->inputs().empty())
return false;
auto first = ins->inputs().front();
if (first->name() != "slice")
return false;
return std::all_of(ins->inputs().begin()+1, ins->inputs().end(), [&](auto input) {
return input->get_operator() == first->get_operator();
});
}))
return {};
auto start = group.front();
std::vector<instruction_ref> inputs;
std::transform(start->inputs().begin(), start->inputs().end(), std::back_inserter(inputs), [](auto ins) {
return ins->inputs().front();
});
if (not std::all_of(group.begin(), group.end(), [&](auto ins) {
return std::equal(ins->inputs().begin(), ins->inputs().end(), inputs.begin(), inputs.end(), [](auto slice, auto input) {
return slice->inputs().front() == input;
});
}))
return {};
return inputs;
}
template<class Range>
static instruction_ref find_last_instruction(const module& m, const Range& r)
{
auto rm = reverse(m);
auto it = std::find_first_of(rm.begin(), rm.end(), r.begin(), r.end());
return std::prev(it.base());
}
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
...@@ -591,6 +627,15 @@ struct find_splits ...@@ -591,6 +627,15 @@ struct find_splits
assert(not std::none_of(start->inputs().begin(), start->inputs().end(), [](auto i) { assert(not std::none_of(start->inputs().begin(), start->inputs().end(), [](auto i) {
return i->name() == "slice"; return i->name() == "slice";
}) && "one argument must be a split"); }) && "one argument must be a split");
auto split_inputs = split_nary(group);
if (not split_inputs.empty())
{
auto last = find_last_instruction(m, split_inputs);
c = m.insert_instruction(std::next(last), op, split_inputs);
}
else
{
auto data_idx = 1; auto data_idx = 1;
if(start->inputs().back()->name() == "slice") if(start->inputs().back()->name() == "slice")
{ {
...@@ -628,6 +673,8 @@ struct find_splits ...@@ -628,6 +673,8 @@ struct find_splits
args[data_idx] = concat; args[data_idx] = concat;
c = m.insert_instruction(std::next(ins), op, args); c = m.insert_instruction(std::next(ins), op, args);
} }
}
if(c != m.end()) if(c != m.end())
{ {
for(auto i : group) for(auto i : group)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment