Commit c94a436d authored by Paul's avatar Paul
Browse files

Reduce number of identiy ops added

parent d2198a2a
......@@ -256,15 +256,13 @@ void schedule::apply(program& p) const
return;
for(auto ins1 : split.second[i])
{
auto idx1 = std::distance(split.first, ins1);
for(auto ins2 : split.second[j])
{
if(ins1 == ins2)
continue;
auto idx2 = std::distance(split.first, ins2);
auto point = idx1 > idx2 ? ins1 : ins2;
p.insert_instruction(std::next(point), op::identity{}, ins1, ins2);
}
auto args = split.second[j];
args.push_back(ins1);
auto point = std::max_element(args.begin(), args.end(), [&](auto x, auto y) {
return std::distance(split.first, x) < std::distance(split.first, y);
});
p.insert_instruction(std::next(*point), op::identity{}, args);
}
});
}
......
......@@ -3,6 +3,7 @@
#include <migraphx/generate.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/dfor.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
......@@ -110,11 +111,10 @@ bool check_conflicts(migraphx::program& p, migraphx::instruction_ref x, migraphx
{
if(ins->name() != "identity")
continue;
if(ins->inputs().size() != 2)
if (not migraphx::contains(ins->inputs(), x))
continue;
if (not migraphx::contains(ins->inputs(), y))
continue;
if(ins->inputs() == std::vector<migraphx::instruction_ref>{x, y})
return true;
if(ins->inputs() == std::vector<migraphx::instruction_ref>{y, x})
return true;
}
return false;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment