Unverified Commit ac0f79aa authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Allow creating modules in a module pass (#931)

* Add module pass manage
parent ebbaf8fc
......@@ -8,12 +8,14 @@
#include <utility>
#include <migraphx/functional.hpp>
#include <migraphx/config.hpp>
#include <migraphx/rank.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
struct module;
struct module_pass_manager;
#ifdef DOXYGEN
......@@ -24,6 +26,7 @@ struct pass
/// A unique name used to identify the pass
std::string name() const;
/// Run the pass on the module
void apply(module_pass_manager& mpm) const;
void apply(module& m) const;
/// Run the pass on the program
void apply(program& p) const;
......@@ -31,13 +34,37 @@ struct pass
#else
module& get_module(module_pass_manager& mpm);
namespace detail {
template <class T>
auto module_pass_manager_apply(rank<1>, const T& x, module_pass_manager& mpm)
-> decltype(x.apply(get_module(mpm)))
{
return x.apply(get_module(mpm));
}
template <class T>
void module_pass_manager_apply(rank<0>, const T&, module_pass_manager&)
{
}
template <class T>
void module_pass_manager_apply(const T& x, module_pass_manager& mpm)
{
module_pass_manager_apply(rank<1>{}, x, mpm);
}
} // namespace detail
/*
* Type-erased interface for:
*
* struct pass
* {
* std::string name() const;
* void apply(module & m) const;
* void apply(module_pass_manager & mpm) const;
* void apply(program & p) const;
* };
*
......@@ -112,10 +139,10 @@ struct pass
return (*this).private_detail_te_get_handle().name();
}
void apply(module& m) const
void apply(module_pass_manager& mpm) const
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().apply(m);
(*this).private_detail_te_get_handle().apply(mpm);
}
void apply(program& p) const
......@@ -137,22 +164,24 @@ struct pass
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0;
virtual std::string name() const = 0;
virtual void apply(module& m) const = 0;
virtual void apply(program& p) const = 0;
virtual std::string name() const = 0;
virtual void apply(module_pass_manager& mpm) const = 0;
virtual void apply(program& p) const = 0;
};
template <class T>
static auto private_detail_te_default_apply(char, T&& private_detail_te_self, module& m)
-> decltype(private_detail_te_self.apply(m))
static auto
private_detail_te_default_apply(char, T&& private_detail_te_self, module_pass_manager& mpm)
-> decltype(private_detail_te_self.apply(mpm))
{
private_detail_te_self.apply(m);
private_detail_te_self.apply(mpm);
}
template <class T>
static void private_detail_te_default_apply(float, T&& private_detail_te_self, module& m)
static void
private_detail_te_default_apply(float, T&& private_detail_te_self, module_pass_manager& mpm)
{
migraphx::nop(private_detail_te_self, m);
migraphx::detail::module_pass_manager_apply(private_detail_te_self, mpm);
}
template <class T>
......@@ -198,10 +227,10 @@ struct pass
std::string name() const override { return private_detail_te_value.name(); }
void apply(module& m) const override
void apply(module_pass_manager& mpm) const override
{
private_detail_te_default_apply(char(0), private_detail_te_value, m);
private_detail_te_default_apply(char(0), private_detail_te_value, mpm);
}
void apply(program& p) const override
......
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_PASS_MANAGER_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_PASS_MANAGER_HPP
#include <list>
#include <unordered_map>
#include <migraphx/operation.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/builtin.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/target.hpp>
#include <migraphx/tracer.hpp>
#include <migraphx/env.hpp>
#include <migraphx/config.hpp>
#include <algorithm>
#include <iostream>
#include <migraphx/pass.hpp>
#include <migraphx/tracer.hpp>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module_pass_manager
{
module_pass_manager() = default;
module_pass_manager(const module_pass_manager&) = delete;
virtual module& get_module() = 0;
virtual module* create_module(const std::string& name) = 0;
virtual void run_pass(const pass& p) = 0;
protected:
virtual ~module_pass_manager() {}
};
void run_passes(module& mod, const std::vector<pass>& passes, tracer trace = tracer{});
void run_passes(program& prog, const std::vector<pass>& passes, tracer trace = tracer{});
......
......@@ -32,14 +32,6 @@ void validate_pass(module& mod, const pass& p, tracer trace)
trace();
#endif
}
void run_pass(module& mod, const pass& p, tracer trace)
{
trace("Module: ", mod.name(), ", Pass: ", p.name());
assert(mod.validate() == mod.end());
p.apply(mod);
trace(mod);
validate_pass(mod, p, trace);
}
void run_pass(program& prog, const pass& p, tracer trace)
{
trace("Pass: ", p.name());
......@@ -47,11 +39,52 @@ void run_pass(program& prog, const pass& p, tracer trace)
trace(prog);
}
struct module_pm : module_pass_manager
{
module* mod;
program* prog;
tracer* t;
module_pm(module* pmod = nullptr, program* pprog = nullptr, tracer* pt = nullptr)
: mod(pmod), prog(pprog), t(pt)
{
}
template <class... Ts>
void trace(Ts&&... xs) const
{
assert(t);
(*t)(xs...);
}
virtual module& get_module() override
{
assert(mod);
return *mod;
}
virtual module* create_module(const std::string& name) override
{
assert(prog);
return prog->create_module(name);
}
virtual void run_pass(const pass& p) override
{
assert(mod);
trace("Module: ", mod->name(), ", Pass: ", p.name());
assert(mod->validate() == mod->end());
p.apply(*this);
trace(*mod);
validate_pass(*mod, p, *t);
}
};
module& get_module(module_pass_manager& mpm) { return mpm.get_module(); }
void run_passes(module& mod, const std::vector<pass>& passes, tracer trace)
{
for(const auto& p : passes)
{
run_pass(mod, p, trace);
module_pm{&mod, nullptr, &trace}.run_pass(p);
}
}
......@@ -62,7 +95,7 @@ void run_passes(program& prog, const std::vector<pass>& passes, tracer trace)
auto mods = prog.get_modules();
for(const auto& mod : reverse(mods))
{
run_pass(*mod, p, trace);
module_pm{mod, &prog, &trace}.run_pass(p);
}
run_pass(prog, p, trace);
}
......
......@@ -8,12 +8,14 @@
#include <utility>
#include <migraphx/functional.hpp>
#include <migraphx/config.hpp>
#include <migraphx/rank.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
struct module;
struct module_pass_manager;
#ifdef DOXYGEN
......@@ -24,6 +26,7 @@ struct pass
/// A unique name used to identify the pass
std::string name() const;
/// Run the pass on the module
void apply(module_pass_manager& mpm) const;
void apply(module& m) const;
/// Run the pass on the program
void apply(program& p) const;
......@@ -31,10 +34,34 @@ struct pass
#else
module& get_module(module_pass_manager& mpm);
namespace detail {
template <class T>
auto module_pass_manager_apply(rank<1>, const T& x, module_pass_manager& mpm)
-> decltype(x.apply(get_module(mpm)))
{
return x.apply(get_module(mpm));
}
template <class T>
void module_pass_manager_apply(rank<0>, const T&, module_pass_manager&)
{
}
template <class T>
void module_pass_manager_apply(const T& x, module_pass_manager& mpm)
{
module_pass_manager_apply(rank<1>{}, x, mpm);
}
} // namespace detail
<%
interface('pass',
virtual('name', returns='std::string', const=True),
virtual('apply', returns='void', m='module &', const=True, default='migraphx::nop'),
virtual('apply', returns='void', mpm='module_pass_manager &', const=True, default='migraphx::detail::module_pass_manager_apply'),
virtual('apply', returns='void', p='program &', const=True, default='migraphx::nop')
)
%>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment