Commit baee674c authored by Paul's avatar Paul
Browse files

Support commutative slices as well

parent 1dd8a1fd
......@@ -36,6 +36,15 @@ using instruction_ref = std::list<instruction>::iterator;
migraphx::instruction* as_address(const instruction_ref& ins) noexcept;
struct compare_instruction_ref
{
bool operator()(const instruction_ref& x,
const instruction_ref& y) const noexcept
{
return as_address(x) < as_address(y);
}
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -559,7 +559,7 @@ struct find_splits
return true;
}
static std::vector<instruction_ref> split_nary(const std::vector<instruction_ref>& group)
static std::vector<instruction_ref> split_nary(const std::vector<instruction_ref>& group, bool commutative)
{
// All inputs have the same slices
if(not std::all_of(group.begin(), group.end(), [](auto ins) {
......@@ -574,20 +574,38 @@ struct find_splits
}))
return {};
auto start = group.front();
std::vector<instruction_ref> inputs;
std::transform(start->inputs().begin(),
start->inputs().end(),
std::back_inserter(inputs),
[](auto ins) { return ins->inputs().front(); });
if(not std::all_of(group.begin(), group.end(), [&](auto ins) {
return std::equal(
ins->inputs().begin(),
ins->inputs().end(),
inputs.begin(),
inputs.end(),
[](auto slice, auto input) { return slice->inputs().front() == input; });
}))
return {};
auto get_inputs = [](auto ins) {
std::vector<instruction_ref> result;
std::transform(ins->inputs().begin(),
ins->inputs().end(),
std::back_inserter(result),
[](auto slice) { return slice->inputs().front(); });
return result;
};
auto inputs = get_inputs(start);
if (commutative and inputs.size() > 1)
{
std::sort(inputs.begin(), inputs.end(), compare_instruction_ref{});
if(not std::all_of(group.begin(), group.end(), [&](auto ins) {
auto inputs2 = get_inputs(ins);
std::sort(inputs2.begin(), inputs2.end(), compare_instruction_ref{});
return inputs == inputs2;
}))
return {};
}
else
{
if(not std::all_of(group.begin(), group.end(), [&](auto ins) {
return std::equal(
ins->inputs().begin(),
ins->inputs().end(),
inputs.begin(),
inputs.end(),
[](auto slice, auto input) { return slice->inputs().front() == input; });
}))
return {};
}
return inputs;
}
......@@ -632,7 +650,7 @@ struct find_splits
return i->name() == "slice";
}) && "one argument must be a split");
auto split_inputs = split_nary(group);
auto split_inputs = split_nary(group, op.attributes().get("commutative", false));
if(not split_inputs.empty())
{
auto last = find_last_instruction(m, split_inputs);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment