Unverified Commit 3a5c4306 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Get parent module in the pass manager (#1181)

* Add function to get a module tree
* Get parent module in the pass manager
parent 3b0a9116
...@@ -183,7 +183,7 @@ struct module ...@@ -183,7 +183,7 @@ struct module
void annotate(std::ostream& os, std::function<void(instruction_ref)> a) const; void annotate(std::ostream& os, std::function<void(instruction_ref)> a) const;
std::vector<module_ref> get_sub_modules() const; std::vector<module_ref> get_sub_modules(bool shallow = false) const;
module& sort(); module& sort();
ins_dep_map calc_implicit_deps() const; ins_dep_map calc_implicit_deps() const;
......
...@@ -38,6 +38,7 @@ struct module_pass_manager ...@@ -38,6 +38,7 @@ struct module_pass_manager
module_pass_manager(const module_pass_manager&) = delete; module_pass_manager(const module_pass_manager&) = delete;
virtual module& get_module() = 0; virtual module& get_module() = 0;
virtual module* create_module(const std::string& name) = 0; virtual module* create_module(const std::string& name) = 0;
virtual module* get_common_parent() = 0;
virtual void run_pass(const pass& p) = 0; virtual void run_pass(const pass& p) = 0;
protected: protected:
......
...@@ -132,6 +132,8 @@ struct program ...@@ -132,6 +132,8 @@ struct program
std::vector<const module*> get_modules() const; std::vector<const module*> get_modules() const;
std::vector<module*> get_modules(); std::vector<module*> get_modules();
std::unordered_multimap<module_ref, module_ref> get_module_tree();
void remove_module(const std::string& name); void remove_module(const std::string& name);
void remove_unused_modules(); void remove_unused_modules();
......
...@@ -216,6 +216,12 @@ bool equal(R1&& r1, R2&& r2, Predicate... pred) ...@@ -216,6 +216,12 @@ bool equal(R1&& r1, R2&& r2, Predicate... pred)
return std::equal(r1.begin(), r1.end(), r2.begin(), r2.end(), pred...); return std::equal(r1.begin(), r1.end(), r2.begin(), r2.end(), pred...);
} }
template <class Range>
auto distance(Range&& r)
{
return std::distance(r.begin(), r.end());
}
template <class R> template <class R>
using range_value = std::decay_t<decltype(*std::declval<R>().begin())>; using range_value = std::decay_t<decltype(*std::declval<R>().begin())>;
......
...@@ -819,19 +819,22 @@ void module::annotate(std::ostream& os, std::function<void(instruction_ref)> a) ...@@ -819,19 +819,22 @@ void module::annotate(std::ostream& os, std::function<void(instruction_ref)> a)
}); });
} }
std::vector<module_ref> module::get_sub_modules() const std::vector<module_ref> module::get_sub_modules(bool shallow) const
{ {
std::vector<module_ref> vec_modules; std::vector<module_ref> vec_modules;
for(auto ins : iterator_for(*this)) for(auto ins : iterator_for(*this))
{ {
const auto& mod_args = ins->module_inputs(); const auto& mod_args = ins->module_inputs();
vec_modules.insert(vec_modules.end(), mod_args.begin(), mod_args.end()); vec_modules.insert(vec_modules.end(), mod_args.begin(), mod_args.end());
if(not shallow)
{
for(const auto& smod : mod_args) for(const auto& smod : mod_args)
{ {
auto sub_mods = smod->get_sub_modules(); auto sub_mods = smod->get_sub_modules();
vec_modules.insert(vec_modules.end(), sub_mods.begin(), sub_mods.end()); vec_modules.insert(vec_modules.end(), sub_mods.begin(), sub_mods.end());
} }
} }
}
return vec_modules; return vec_modules;
} }
......
...@@ -66,14 +66,12 @@ void run_pass(program& prog, const pass& p, tracer trace) ...@@ -66,14 +66,12 @@ void run_pass(program& prog, const pass& p, tracer trace)
struct module_pm : module_pass_manager struct module_pm : module_pass_manager
{ {
module* mod; module* mod = nullptr;
program* prog; tracer* t = nullptr;
tracer* t; module* common_parent = nullptr;
program* prog = nullptr;
module_pm(module* pmod = nullptr, program* pprog = nullptr, tracer* pt = nullptr) module_pm(module* pmod = nullptr, tracer* pt = nullptr) : mod(pmod), t(pt) {}
: mod(pmod), prog(pprog), t(pt)
{
}
template <class... Ts> template <class... Ts>
void trace(Ts&&... xs) const void trace(Ts&&... xs) const
...@@ -92,6 +90,7 @@ struct module_pm : module_pass_manager ...@@ -92,6 +90,7 @@ struct module_pm : module_pass_manager
assert(prog); assert(prog);
return prog->create_module(name); return prog->create_module(name);
} }
virtual module* get_common_parent() override { return common_parent; }
virtual void run_pass(const pass& p) override virtual void run_pass(const pass& p) override
{ {
assert(mod); assert(mod);
...@@ -111,7 +110,7 @@ void run_passes(module& mod, const std::vector<pass>& passes, tracer trace) ...@@ -111,7 +110,7 @@ void run_passes(module& mod, const std::vector<pass>& passes, tracer trace)
trace = tracer{std::cout}; trace = tracer{std::cout};
for(const auto& p : passes) for(const auto& p : passes)
{ {
module_pm{&mod, nullptr, &trace}.run_pass(p); module_pm{&mod, &trace}.run_pass(p);
} }
} }
...@@ -119,14 +118,31 @@ void run_passes(program& prog, const std::vector<pass>& passes, tracer trace) ...@@ -119,14 +118,31 @@ void run_passes(program& prog, const std::vector<pass>& passes, tracer trace)
{ {
if(enabled(MIGRAPHX_TRACE_PASSES{})) if(enabled(MIGRAPHX_TRACE_PASSES{}))
trace = tracer{std::cout}; trace = tracer{std::cout};
std::unordered_set<module_ref> visited;
for(const auto& p : passes) for(const auto& p : passes)
{ {
auto mods = prog.get_modules(); auto mods = prog.get_modules();
auto tree = prog.get_module_tree();
visited.clear();
for(const auto& mod : reverse(mods)) for(const auto& mod : reverse(mods))
{ {
if(mod->bypass()) if(mod->bypass())
continue; continue;
module_pm{mod, &prog, &trace}.run_pass(p); if(not visited.insert(mod).second)
continue;
module_pm mpm{mod, &trace};
mpm.prog = &prog;
auto parents = range(tree.equal_range(mod));
auto nparents = distance(parents);
if(nparents == 0)
mpm.common_parent = nullptr;
else if(nparents == 1)
mpm.common_parent = parents.begin()->second;
else
// Just set common parent to main module when there is muliple parents for now
// TODO: Compute the common parent
mpm.common_parent = prog.get_main_module();
mpm.run_pass(p);
} }
run_pass(prog, p, trace); run_pass(prog, p, trace);
} }
......
...@@ -869,6 +869,23 @@ std::vector<module*> program::get_modules() ...@@ -869,6 +869,23 @@ std::vector<module*> program::get_modules()
return result; return result;
} }
template <class Module, class Map>
void generic_insert_module_tree(Module* pm, Map& m)
{
for(auto* sm : pm->get_sub_modules(true))
{
m.insert(std::make_pair(sm, pm));
generic_insert_module_tree(sm, m);
}
}
std::unordered_multimap<module_ref, module_ref> program::get_module_tree()
{
std::unordered_multimap<module_ref, module_ref> result;
generic_insert_module_tree(this->get_main_module(), result);
return result;
}
template <class Map, class T> template <class Map, class T>
bool is_unused_module(Map& m, const std::vector<T*>& mods, const std::string& name) bool is_unused_module(Map& m, const std::vector<T*>& mods, const std::string& name)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment