Unverified Commit 3a5c4306 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Get parent module in the pass manager (#1181)

* Add function to get a module tree
* Get parent module in the pass manager
parent 3b0a9116
......@@ -183,7 +183,7 @@ struct module
void annotate(std::ostream& os, std::function<void(instruction_ref)> a) const;
std::vector<module_ref> get_sub_modules() const;
std::vector<module_ref> get_sub_modules(bool shallow = false) const;
module& sort();
ins_dep_map calc_implicit_deps() const;
......
......@@ -38,6 +38,7 @@ struct module_pass_manager
module_pass_manager(const module_pass_manager&) = delete;
virtual module& get_module() = 0;
virtual module* create_module(const std::string& name) = 0;
virtual module* get_common_parent() = 0;
virtual void run_pass(const pass& p) = 0;
protected:
......
......@@ -132,6 +132,8 @@ struct program
std::vector<const module*> get_modules() const;
std::vector<module*> get_modules();
std::unordered_multimap<module_ref, module_ref> get_module_tree();
void remove_module(const std::string& name);
void remove_unused_modules();
......
......@@ -216,6 +216,12 @@ bool equal(R1&& r1, R2&& r2, Predicate... pred)
return std::equal(r1.begin(), r1.end(), r2.begin(), r2.end(), pred...);
}
template <class Range>
auto distance(Range&& r)
{
return std::distance(r.begin(), r.end());
}
template <class R>
using range_value = std::decay_t<decltype(*std::declval<R>().begin())>;
......
......@@ -819,19 +819,22 @@ void module::annotate(std::ostream& os, std::function<void(instruction_ref)> a)
});
}
std::vector<module_ref> module::get_sub_modules() const
std::vector<module_ref> module::get_sub_modules(bool shallow) const
{
std::vector<module_ref> vec_modules;
for(auto ins : iterator_for(*this))
{
const auto& mod_args = ins->module_inputs();
vec_modules.insert(vec_modules.end(), mod_args.begin(), mod_args.end());
if(not shallow)
{
for(const auto& smod : mod_args)
{
auto sub_mods = smod->get_sub_modules();
vec_modules.insert(vec_modules.end(), sub_mods.begin(), sub_mods.end());
}
}
}
return vec_modules;
}
......
......@@ -66,14 +66,12 @@ void run_pass(program& prog, const pass& p, tracer trace)
struct module_pm : module_pass_manager
{
module* mod;
program* prog;
tracer* t;
module* mod = nullptr;
tracer* t = nullptr;
module* common_parent = nullptr;
program* prog = nullptr;
module_pm(module* pmod = nullptr, program* pprog = nullptr, tracer* pt = nullptr)
: mod(pmod), prog(pprog), t(pt)
{
}
module_pm(module* pmod = nullptr, tracer* pt = nullptr) : mod(pmod), t(pt) {}
template <class... Ts>
void trace(Ts&&... xs) const
......@@ -92,6 +90,7 @@ struct module_pm : module_pass_manager
assert(prog);
return prog->create_module(name);
}
virtual module* get_common_parent() override { return common_parent; }
virtual void run_pass(const pass& p) override
{
assert(mod);
......@@ -111,7 +110,7 @@ void run_passes(module& mod, const std::vector<pass>& passes, tracer trace)
trace = tracer{std::cout};
for(const auto& p : passes)
{
module_pm{&mod, nullptr, &trace}.run_pass(p);
module_pm{&mod, &trace}.run_pass(p);
}
}
......@@ -119,14 +118,31 @@ void run_passes(program& prog, const std::vector<pass>& passes, tracer trace)
{
if(enabled(MIGRAPHX_TRACE_PASSES{}))
trace = tracer{std::cout};
std::unordered_set<module_ref> visited;
for(const auto& p : passes)
{
auto mods = prog.get_modules();
auto tree = prog.get_module_tree();
visited.clear();
for(const auto& mod : reverse(mods))
{
if(mod->bypass())
continue;
module_pm{mod, &prog, &trace}.run_pass(p);
if(not visited.insert(mod).second)
continue;
module_pm mpm{mod, &trace};
mpm.prog = &prog;
auto parents = range(tree.equal_range(mod));
auto nparents = distance(parents);
if(nparents == 0)
mpm.common_parent = nullptr;
else if(nparents == 1)
mpm.common_parent = parents.begin()->second;
else
// Just set common parent to main module when there is muliple parents for now
// TODO: Compute the common parent
mpm.common_parent = prog.get_main_module();
mpm.run_pass(p);
}
run_pass(prog, p, trace);
}
......
......@@ -869,6 +869,23 @@ std::vector<module*> program::get_modules()
return result;
}
template <class Module, class Map>
void generic_insert_module_tree(Module* pm, Map& m)
{
for(auto* sm : pm->get_sub_modules(true))
{
m.insert(std::make_pair(sm, pm));
generic_insert_module_tree(sm, m);
}
}
std::unordered_multimap<module_ref, module_ref> program::get_module_tree()
{
std::unordered_multimap<module_ref, module_ref> result;
generic_insert_module_tree(this->get_main_module(), result);
return result;
}
template <class Map, class T>
bool is_unused_module(Map& m, const std::vector<T*>& mods, const std::string& name)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment