Unverified Commit bf603a76 authored by shivadbhavsar's avatar shivadbhavsar Committed by GitHub
Browse files

Parallelize evaluations in propagate_constant (#1220)

Addresses issue #1166: the propagate_constant pass currently uses a recursive approach to find all instructions in a module that can be evaluated to a literal, and performs the replacement in the same call.

New approach:

Perform a single pass through the instructions in the module to determine which instructions can be evaluated
Evaluate selected instructions in parallel
Replace the selected instructions with the corresponding literal
parent a401e72a
......@@ -3,6 +3,7 @@
#include <migraphx/matcher.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/par_for.hpp>
#include <unordered_set>
namespace migraphx {
......@@ -20,33 +21,42 @@ bool skip_propogate(instruction_ref ins)
return false;
}
// An instruction counts as constant when it reports can_eval() and is not
// excluded by skip_propogate(). can_eval() is checked first, and
// skip_propogate() is only consulted when that check succeeds.
bool is_const(instruction_ref ins)
{
    if(not ins->can_eval())
        return false;
    return not skip_propogate(ins);
}
// Runs the constant-propagation pass over module `m`.
// NOTE(review): this listing is a rendered diff whose +/- markers were lost, so lines of
// the removed recursive implementation (the `fix` lambda walking outputs) are interleaved
// with lines of the added collect / evaluate-in-parallel / replace implementation. The
// braces do not balance, so the text is not compilable as shown; consult the real diff
// before editing.
void propagate_constant::apply(module& m) const
{
// Set of instructions selected for evaluation; filled by the copy_if below.
std::unordered_set<instruction_ref> const_instrs;
// Iterator to the last instruction in the module; compared against `i` further down.
auto last = std::prev(m.end());
// Find instructions that can be evaluated to a literal
for(auto i : iterator_for(m))
{
// NOTE(review): the literal-rooted walk that follows appears to belong to the removed
// recursive version — confirm against the real diff.
if(i->name() != "@literal")
continue;
if(i->outputs().empty())
continue;
fix([&](auto self, auto ins) {
std::unordered_set<instruction_ref> children(ins->outputs().begin(),
ins->outputs().end());
for(auto child : children)
{
if(child->name() == "@literal" or skip_propogate(child))
{
self(child);
// NOTE(review): from here the lines appear to belong to the added version — confirm.
// Instructions that are is_const and not `last` are skipped; for the others, every
// input that is is_const and not already an "@literal" is inserted into const_instrs.
if(is_const(i) and i != last)
continue;
std::copy_if(
i->inputs().begin(),
i->inputs().end(),
std::inserter(const_instrs, const_instrs.begin()),
[&](const instruction_ref ins) { return is_const(ins) and ins->name() != "@literal"; });
}
auto r = child->eval();
if(not r.empty())
// Compute literals in parallel
// The set is copied into a vector so each selected instruction has an index;
// literals[i] receives the eval() result of const_instrs_vec[i].
std::vector<instruction_ref> const_instrs_vec{const_instrs.begin(), const_instrs.end()};
std::vector<argument> literals(const_instrs_vec.size());
par_for(const_instrs_vec.size(), 1, [&](const auto i) {
literals[i] = const_instrs_vec[i]->eval();
});
// Replace instructions in m
for(size_t i = 0; i < const_instrs_vec.size(); i++)
{
assert(r.get_shape() == child->get_shape());
auto l = m.add_literal(r.get_shape(), r.data());
self(m.replace_instruction(child, l));
}
// Only instructions whose eval() produced a non-empty argument are replaced by a
// newly added literal built from that argument's shape and data.
if(not literals[i].empty())
{
assert(literals[i].get_shape() == const_instrs_vec[i]->get_shape());
auto l = m.add_literal(literals[i].get_shape(), literals[i].data());
m.replace_instruction(const_instrs_vec[i], l);
}
})(i);
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment