Commit e1f448cf authored by Paul's avatar Paul
Browse files

Insert identity instructions

parent 5b3cded5
#include <migraphx/schedule.hpp> #include <migraphx/schedule.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/functional.hpp> #include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <unordered_map> #include <unordered_map>
#include <unordered_set>
#include <set> #include <set>
namespace migraphx { namespace migraphx {
...@@ -92,10 +94,11 @@ struct stream_info ...@@ -92,10 +94,11 @@ struct stream_info
return not std::all_of(v.begin(), v.end(), [&](std::size_t x) { return x == v.front(); }); return not std::all_of(v.begin(), v.end(), [&](std::size_t x) { return x == v.front(); });
} }
std::vector<std::size_t> get_input_streams(instruction_ref ins) const template<class Selector>
std::vector<std::size_t> get_streams(instruction_ref ins, Selector select) const
{ {
std::vector<std::size_t> result; std::vector<std::size_t> result;
for(auto i : ins->inputs()) for(auto i : select(ins))
{ {
if(weights.at(i) == 0) if(weights.at(i) == 0)
{ {
...@@ -110,7 +113,23 @@ struct stream_info ...@@ -110,7 +113,23 @@ struct stream_info
return result; return result;
} }
std::vector<std::size_t> get_input_streams(instruction_ref ins) const
{
return get_streams(ins, [](auto i) {
return i->inputs();
});
}
std::vector<std::size_t> get_output_streams(instruction_ref ins) const
{
return get_streams(ins, [](auto i) {
return i->outputs();
});
}
bool is_merge_point(instruction_ref ins) const { return different(get_input_streams(ins)); } bool is_merge_point(instruction_ref ins) const { return different(get_input_streams(ins)); }
bool is_split_point(instruction_ref ins) const { return different(get_output_streams(ins)); }
std::vector<std::size_t> wait_for(instruction_ref ins) const std::vector<std::size_t> wait_for(instruction_ref ins) const
{ {
...@@ -119,11 +138,33 @@ struct stream_info ...@@ -119,11 +138,33 @@ struct stream_info
result.erase(std::unique(result.begin(), result.end()), result.end()); result.erase(std::unique(result.begin(), result.end()), result.end());
return result; return result;
} }
template<class F>
void find_concurrent_instructions(program& p, F f)
{
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> split_from;
for(auto ins : iterator_for(p))
{
if (weights[ins] == 0)
continue;
for(auto&& arg : ins->inputs())
{
if (is_split_point(arg))
split_from[ins].insert(arg);
split_from[ins].insert(split_from[arg].begin(), split_from[arg].end());
}
// Collect concur instructions for each split point.
for(auto& split : split_from[ins])
{
f(ins, split);
}
}
}
}; };
void schedule::apply(program& p) const void schedule::apply(program& p) const
{ {
stream_info si; stream_info si;
auto last = std::prev(p.end()); auto last = std::prev(p.end());
si.accumulate_weights(last, model); si.accumulate_weights(last, model);
...@@ -148,6 +189,10 @@ void schedule::apply(program& p) const ...@@ -148,6 +189,10 @@ void schedule::apply(program& p) const
else else
model.schedule_instruction(p, ins, si.get_stream(ins)); model.schedule_instruction(p, ins, si.get_stream(ins));
} }
si.find_concurrent_instructions(p, [&](auto x, auto y) {
p.insert_instruction(std::next(x), op::identity{}, x, y);
});
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment