Commit baee674c authored by Paul's avatar Paul
Browse files

Support commutative slices as well

parent 1dd8a1fd
...@@ -36,6 +36,15 @@ using instruction_ref = std::list<instruction>::iterator; ...@@ -36,6 +36,15 @@ using instruction_ref = std::list<instruction>::iterator;
migraphx::instruction* as_address(const instruction_ref& ins) noexcept; migraphx::instruction* as_address(const instruction_ref& ins) noexcept;
struct compare_instruction_ref
{
bool operator()(const instruction_ref& x,
const instruction_ref& y) const noexcept
{
return as_address(x) < as_address(y);
}
};
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -559,7 +559,7 @@ struct find_splits ...@@ -559,7 +559,7 @@ struct find_splits
return true; return true;
} }
static std::vector<instruction_ref> split_nary(const std::vector<instruction_ref>& group) static std::vector<instruction_ref> split_nary(const std::vector<instruction_ref>& group, bool commutative)
{ {
// All inputs have the same slices // All inputs have the same slices
if(not std::all_of(group.begin(), group.end(), [](auto ins) { if(not std::all_of(group.begin(), group.end(), [](auto ins) {
...@@ -574,20 +574,38 @@ struct find_splits ...@@ -574,20 +574,38 @@ struct find_splits
})) }))
return {}; return {};
auto start = group.front(); auto start = group.front();
std::vector<instruction_ref> inputs; auto get_inputs = [](auto ins) {
std::transform(start->inputs().begin(), std::vector<instruction_ref> result;
start->inputs().end(), std::transform(ins->inputs().begin(),
std::back_inserter(inputs), ins->inputs().end(),
[](auto ins) { return ins->inputs().front(); }); std::back_inserter(result),
if(not std::all_of(group.begin(), group.end(), [&](auto ins) { [](auto slice) { return slice->inputs().front(); });
return std::equal( return result;
ins->inputs().begin(), };
ins->inputs().end(), auto inputs = get_inputs(start);
inputs.begin(), if (commutative and inputs.size() > 1)
inputs.end(), {
[](auto slice, auto input) { return slice->inputs().front() == input; }); std::sort(inputs.begin(), inputs.end(), compare_instruction_ref{});
})) if(not std::all_of(group.begin(), group.end(), [&](auto ins) {
return {}; auto inputs2 = get_inputs(ins);
std::sort(inputs2.begin(), inputs2.end(), compare_instruction_ref{});
return inputs == inputs2;
}))
return {};
}
else
{
if(not std::all_of(group.begin(), group.end(), [&](auto ins) {
return std::equal(
ins->inputs().begin(),
ins->inputs().end(),
inputs.begin(),
inputs.end(),
[](auto slice, auto input) { return slice->inputs().front() == input; });
}))
return {};
}
return inputs; return inputs;
} }
...@@ -632,7 +650,7 @@ struct find_splits ...@@ -632,7 +650,7 @@ struct find_splits
return i->name() == "slice"; return i->name() == "slice";
}) && "one argument must be a split"); }) && "one argument must be a split");
auto split_inputs = split_nary(group); auto split_inputs = split_nary(group, op.attributes().get("commutative", false));
if(not split_inputs.empty()) if(not split_inputs.empty())
{ {
auto last = find_last_instruction(m, split_inputs); auto last = find_last_instruction(m, split_inputs);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment