Commit 4c185e65 authored by Paul's avatar Paul
Browse files

Fuse across binary slices

parent 40118191
......@@ -559,6 +559,42 @@ struct find_splits
return true;
}
static std::vector<instruction_ref> split_nary(const std::vector<instruction_ref>& group)
{
// All inputs have the same slices
if (not std::all_of(group.begin(), group.end(), [](auto ins) {
if (ins->inputs().empty())
return false;
auto first = ins->inputs().front();
if (first->name() != "slice")
return false;
return std::all_of(ins->inputs().begin()+1, ins->inputs().end(), [&](auto input) {
return input->get_operator() == first->get_operator();
});
}))
return {};
auto start = group.front();
std::vector<instruction_ref> inputs;
std::transform(start->inputs().begin(), start->inputs().end(), std::back_inserter(inputs), [](auto ins) {
return ins->inputs().front();
});
if (not std::all_of(group.begin(), group.end(), [&](auto ins) {
return std::equal(ins->inputs().begin(), ins->inputs().end(), inputs.begin(), inputs.end(), [](auto slice, auto input) {
return slice->inputs().front() == input;
});
}))
return {};
return inputs;
}
template<class Range>
static instruction_ref find_last_instruction(const module& m, const Range& r)
{
auto rm = reverse(m);
auto it = std::find_first_of(rm.begin(), rm.end(), r.begin(), r.end());
return std::prev(it.base());
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
......@@ -591,42 +627,53 @@ struct find_splits
assert(not std::none_of(start->inputs().begin(), start->inputs().end(), [](auto i) {
return i->name() == "slice";
}) && "one argument must be a split");
auto data_idx = 1;
if(start->inputs().back()->name() == "slice")
auto split_inputs = split_nary(group);
if (not split_inputs.empty())
{
split_idx = 1;
data_idx = 0;
auto last = find_last_instruction(m, split_inputs);
c = m.insert_instruction(std::next(last), op, split_inputs);
}
else
{
auto data_idx = 1;
if(start->inputs().back()->name() == "slice")
{
split_idx = 1;
data_idx = 0;
}
std::vector<instruction_ref> data_args;
std::transform(group.begin(),
group.end(),
std::back_inserter(data_args),
[&](auto i) { return i->inputs()[data_idx]; });
std::vector<instruction_ref> data_args;
std::transform(group.begin(),
group.end(),
std::back_inserter(data_args),
[&](auto i) { return i->inputs()[data_idx]; });
// Data arguments must be a constant
if(std::any_of(data_args.begin(), data_args.end(), [](auto i) {
return not i->can_eval();
}))
return;
// Data arguments must be a constant
if(std::any_of(data_args.begin(), data_args.end(), [](auto i) {
return not i->can_eval();
}))
return;
for(auto data : data_args)
m.move_instructions(data, ins);
for(auto data : data_args)
m.move_instructions(data, ins);
auto slice_op = any_cast<op::slice>(splits.front()->get_operator());
assert(not slice_op.axes.empty());
if(slice_op.axes.size() > 1)
return;
auto concat_axis = slice_op.axes.front();
// TODO: Check if axises match
auto concat = m.insert_instruction(
ins, make_op("concat", {{"axis", concat_axis}}), data_args);
std::vector<instruction_ref> args;
args.resize(2);
args[split_idx] = ins;
args[data_idx] = concat;
c = m.insert_instruction(std::next(ins), op, args);
}
auto slice_op = any_cast<op::slice>(splits.front()->get_operator());
assert(not slice_op.axes.empty());
if(slice_op.axes.size() > 1)
return;
auto concat_axis = slice_op.axes.front();
// TODO: Check if axises match
auto concat = m.insert_instruction(
ins, make_op("concat", {{"axis", concat_axis}}), data_args);
std::vector<instruction_ref> args;
args.resize(2);
args[split_idx] = ins;
args[data_idx] = concat;
c = m.insert_instruction(std::next(ins), op, args);
}
if(c != m.end())
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment