Commit a851f699 authored by Paul's avatar Paul
Browse files

Hash modules for quicker lookup of modules

parent 56584fa2
......@@ -7,6 +7,20 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class Iterator, class Output, class Predicate, class F>
void transform_if(Iterator start, Iterator last, Output out, Predicate pred, F f)
{
while (start != last)
{
if (pred(*start))
{
*out = f(*start);
++out;
}
++start;
}
}
template <class Iterator, class Output, class Predicate>
void group_by(Iterator start, Iterator last, Output out, Predicate pred)
{
......
......@@ -9,6 +9,7 @@
#include <migraphx/pass_manager.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/algorithm.hpp>
#include <migraphx/make_op.hpp>
#include <iostream>
#include <sstream>
......@@ -26,13 +27,12 @@ inline namespace MIGRAPHX_INLINE_NS {
struct program_impl
{
// A map is used to keep references to modules of the program
// all the modules are store in the depth-first order
std::list<module> modules;
std::unordered_map<std::string, module> modules;
context ctx;
std::string target_name;
};
program::program() : impl(std::make_unique<program_impl>()) { impl->modules.push_back({"main"}); }
program::program() : impl(std::make_unique<program_impl>()) { this->create_module("main"); }
program::program(program&&) noexcept = default;
program::~program() noexcept = default;
......@@ -67,9 +67,8 @@ void program::assign(const program& p)
std::unordered_map<module_ref, module_ref> mod_map;
std::transform(impl->modules.begin(),
impl->modules.end(),
p.impl->modules.begin(),
std::inserter(mod_map, mod_map.begin()),
[](auto&& x, auto&& y) { return std::make_pair(&y, &x); });
[&](auto&& xp) { return std::make_pair(&p.impl->modules.at(xp.first), &xp.second); });
std::unordered_map<instruction_ref, instruction_ref> ins_map;
for(auto&& pp : mod_map)
......@@ -86,7 +85,7 @@ void program::assign(const program& p)
// Update all references from all modules
for(auto&& mp : impl->modules)
{
for(auto ins : iterator_for(mp))
for(auto ins : iterator_for(mp.second))
instruction::replace_refs(ins, ins_map, mod_map);
}
}
......@@ -316,14 +315,14 @@ value program::to_value() const
if(not this->impl->target_name.empty())
result["context"] = this->impl->ctx.to_value();
value module_vals = value::array{};
value module_vals = value::object{};
std::unordered_map<instruction_ref, std::string> names;
for(auto& mod : this->impl->modules)
for(auto& mod : this->get_modules())
{
value mod_val;
value nodes;
mod_val["name"] = mod.name();
names = mod.print(
mod_val["name"] = mod->name();
names = mod->print(
[&](auto ins, auto ins_names) {
value node;
node["output"] = ins_names.at(ins);
......@@ -358,7 +357,7 @@ value program::to_value() const
names);
mod_val["nodes"] = nodes;
module_vals.push_back(mod_val);
module_vals[mod->name()] = mod_val;
}
result["modules"] = module_vals;
......@@ -371,12 +370,7 @@ static void mod_from_val(module_ref mod,
std::unordered_map<std::string, instruction_ref>& instructions,
const std::unordered_map<std::string, module_ref>& map_mods)
{
const auto* it = std::find_if(v.begin(), v.end(), [&](auto& mv) {
return mv.at("name").template to<std::string>() == mod->name();
});
assert(it != v.end());
const auto& module_val = *it;
const auto& module_val = v.at(mod->name());
for(const value& node : module_val.at("nodes"))
{
instruction_ref output;
......@@ -455,15 +449,17 @@ void program::from_value(const value& v)
}
auto module_vals = v.at("modules");
std::unordered_map<std::string, module_ref> map_mods;
for(const auto& vv : module_vals)
{
const auto& name = vv.at("name").to<std::string>();
const auto& name = vv.get_key();
if(name == "main")
continue;
impl->modules.push_back({name});
map_mods[name] = &impl->modules.back();
impl->modules.emplace(name, name);
}
std::unordered_map<std::string, module_ref> map_mods;
std::transform(impl->modules.begin(), impl->modules.end(), std::inserter(map_mods, map_mods.end()), [&](auto&& pp) {
return std::make_pair(pp.first, &pp.second);
});
std::unordered_map<std::string, instruction_ref> map_insts;
auto* mm = get_main_module();
......@@ -585,8 +581,8 @@ void program::debug_print() const { std::cout << *this << std::endl; }
void program::debug_print(instruction_ref ins) const
{
std::unordered_map<instruction_ref, std::string> names;
if(std::any_of(this->impl->modules.begin(), this->impl->modules.end(), [&](const auto& it) {
return (it.end() == ins);
if(std::any_of(this->impl->modules.begin(), this->impl->modules.end(), [&](const auto& pp) {
return (pp.second.end() == ins);
}))
{
std::cout << "End instruction" << std::endl;
......@@ -594,7 +590,7 @@ void program::debug_print(instruction_ref ins) const
}
else if(std::none_of(this->impl->modules.begin(),
this->impl->modules.end(),
[&](const auto& it) { return it.has_instruction(ins); }))
[&](const auto& pp) { return pp.second.has_instruction(ins); }))
{
std::cout << "Instruction not part of program" << std::endl;
return;
......@@ -615,9 +611,9 @@ void program::print(
const std::function<void(instruction_ref, std::unordered_map<instruction_ref, std::string>)>&
print_func) const
{
for(const auto& mod : this->impl->modules)
for(const auto& pp : this->impl->modules)
{
names = mod.print(print_func, names);
names = pp.second.print(print_func, names);
}
}
......@@ -647,74 +643,72 @@ void program::dry_run(std::unordered_map<std::string, argument> params) const
void program::annotate(std::ostream& os, const std::function<void(instruction_ref)>& a) const
{
for(auto& mod : this->impl->modules)
for(auto& pp : this->impl->modules)
{
std::cout << mod.name() << ":" << std::endl;
mod.annotate(os, a);
std::cout << pp.first << ":" << std::endl;
pp.second.annotate(os, a);
}
}
const module* program::get_module(const std::string& name) const
{
auto it = std::find_if(
impl->modules.begin(), impl->modules.end(), [&](auto& m) { return (m.name() == name); });
if(it == impl->modules.end())
{
return nullptr;
}
return &(*it);
return &impl->modules.at(name);
}
module* program::create_module(const std::string& name)
{
auto it = impl->modules.insert(impl->modules.end(), {name});
return &(*it);
auto r = impl->modules.emplace(name, name);
return &(r.first->second);
}
module* program::get_module(const std::string& name)
{
auto it = std::find_if(
impl->modules.begin(), impl->modules.end(), [&](auto& m) { return (m.name() == name); });
if(it == impl->modules.end())
{
return nullptr;
}
return &(*it);
return &impl->modules.at(name);
}
module* program::get_main_module() { return get_module("main"); }
const module* program::get_main_module() const { return get_module("main"); }
std::vector<const module*> program::get_modules() const
template<class T>
std::vector<T*> generic_get_modules(T* mm)
{
const module* mm = this->get_main_module();
std::vector<const module*> vec_modules;
std::vector<T*> vec_modules;
vec_modules.push_back(mm);
auto sub_modules = mm->get_sub_modules();
vec_modules.insert(vec_modules.end(), sub_modules.begin(), sub_modules.end());
return vec_modules;
}
std::vector<module*> program::get_modules()
template<class Map, class T, class OutputIterator>
void generic_get_unused_modules(Map& m, const std::vector<T*>& mods, OutputIterator out)
{
module* mm = this->get_main_module();
std::vector<module*> vec_modules;
vec_modules.push_back(mm);
auto sub_modules = mm->get_sub_modules();
vec_modules.insert(vec_modules.end(), sub_modules.begin(), sub_modules.end());
std::unordered_set<std::string> used;
std::transform(mods.begin(), mods.end(), std::inserter(used, used.end()), [](auto&& mod) {
return mod->name();
});
transform_if(m.begin(), m.end(), out, [&](auto&& pp){ return not contains(used, pp.first); }, [](auto&& pp) {return &pp.second; });
}
return vec_modules;
std::vector<const module*> program::get_modules() const
{
auto result = generic_get_modules(this->get_main_module());
generic_get_unused_modules(impl->modules, result, std::back_inserter(result));
return result;
}
std::vector<module*> program::get_modules()
{
auto result = generic_get_modules(this->get_main_module());
generic_get_unused_modules(impl->modules, result, std::back_inserter(result));
return result;
}
program& program::sort()
{
for(auto& mod : this->impl->modules)
for(auto& pp : this->impl->modules)
{
mod.sort();
pp.second.sort();
}
return *this;
......
......@@ -89,17 +89,17 @@ struct invert_pass
{
std::string name() const { return "invert_pass"; }
void apply(migraphx::module& p) const
void apply(migraphx::module& m) const
{
for(auto ins : migraphx::iterator_for(p))
for(auto ins : migraphx::iterator_for(m))
{
if(ins->name() == "sum")
{
p.replace_instruction(ins, minus_op{}, ins->inputs());
m.replace_instruction(ins, minus_op{}, ins->inputs());
}
else if(ins->name() == "minus")
{
p.replace_instruction(ins, sum_op{}, ins->inputs());
m.replace_instruction(ins, sum_op{}, ins->inputs());
}
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment