"vscode:/vscode.git/clone" did not exist on "76285fdeea2cd533d2ca7e88eaf0a1f32c97f63d"
Commit a9e5d73c authored by Paul's avatar Paul
Browse files

Update get_passes for the cpu

parent cff16121
#ifndef MIGRAPH_GUARD_PASS_HPP
#define MIGRAPH_GUARD_PASS_HPP
#include <string>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
namespace migraph {
struct program;
/*
* Type-erased interface for:
*
* struct pass
* {
* std::string name() const;
* void apply(program & p) const;
* };
*
*/
struct pass
{
// Constructors
pass() = default;
template <typename PrivateDetailTypeErasedT>
pass(PrivateDetailTypeErasedT value)
: private_detail_te_handle_mem_var(
std::make_shared<private_detail_te_handle_type<
typename std::remove_reference<PrivateDetailTypeErasedT>::type>>(
std::forward<PrivateDetailTypeErasedT>(value)))
{
}
// Assignment
template <typename PrivateDetailTypeErasedT>
pass& operator=(PrivateDetailTypeErasedT value)
{
if(private_detail_te_handle_mem_var.unique())
*private_detail_te_handle_mem_var = std::forward<PrivateDetailTypeErasedT>(value);
else if(!private_detail_te_handle_mem_var)
private_detail_te_handle_mem_var = std::make_shared<PrivateDetailTypeErasedT>(
std::forward<PrivateDetailTypeErasedT>(value));
return *this;
}
// Cast
template <typename PrivateDetailTypeErasedT>
PrivateDetailTypeErasedT* any_cast()
{
return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle())
.private_detail_te_value)
: nullptr;
}
template <typename PrivateDetailTypeErasedT>
const typename std::remove_cv<PrivateDetailTypeErasedT>::type* any_cast() const
{
return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<const private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle())
.private_detail_te_value)
: nullptr;
}
const std::type_info& type_id() const
{
if(private_detail_te_handle_empty())
return typeid(std::nullptr_t);
else
return private_detail_te_get_handle().type();
}
std::string name() const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().name();
}
void apply(program& p) const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().apply(p);
}
private:
struct private_detail_te_handle_base_type
{
virtual ~private_detail_te_handle_base_type() {}
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0;
virtual std::string name() const = 0;
virtual void apply(program& p) const = 0;
};
template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type : private_detail_te_handle_base_type
{
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type(
PrivateDetailTypeErasedT value,
typename std::enable_if<std::is_reference<PrivateDetailTypeErasedU>::value>::type* =
nullptr)
: private_detail_te_value(value)
{
}
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type(
PrivateDetailTypeErasedT value,
typename std::enable_if<!std::is_reference<PrivateDetailTypeErasedU>::value,
int>::type* = nullptr) noexcept
: private_detail_te_value(std::move(value))
{
}
std::shared_ptr<private_detail_te_handle_base_type> clone() const override
{
return std::make_shared<private_detail_te_handle_type>(private_detail_te_value);
}
const std::type_info& type() const override { return typeid(private_detail_te_value); }
std::string name() const override { return private_detail_te_value.name(); }
void apply(program& p) const override { return private_detail_te_value.apply(p); }
PrivateDetailTypeErasedT private_detail_te_value;
};
template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type<std::reference_wrapper<PrivateDetailTypeErasedT>>
: private_detail_te_handle_type<PrivateDetailTypeErasedT&>
{
private_detail_te_handle_type(std::reference_wrapper<PrivateDetailTypeErasedT> ref)
: private_detail_te_handle_type<PrivateDetailTypeErasedT&>(ref.get())
{
}
};
bool private_detail_te_handle_empty() const
{
return private_detail_te_handle_mem_var == nullptr;
}
const private_detail_te_handle_base_type& private_detail_te_get_handle() const
{
assert(private_detail_te_handle_mem_var != nullptr);
return *private_detail_te_handle_mem_var;
}
private_detail_te_handle_base_type& private_detail_te_get_handle()
{
assert(private_detail_te_handle_mem_var != nullptr);
if(!private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var;
}
std::shared_ptr<private_detail_te_handle_base_type> private_detail_te_handle_mem_var;
};
template <typename ValueType>
inline const ValueType* any_cast(const pass* x)
{
return x->any_cast<ValueType>();
}
template <typename ValueType>
inline ValueType* any_cast(pass* x)
{
return x->any_cast<ValueType>();
}
template <typename ValueType>
inline ValueType& any_cast(pass& x)
{
auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
if(y == nullptr)
throw std::bad_cast();
return *y;
}
template <typename ValueType>
inline const ValueType& any_cast(const pass& x)
{
const auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
if(y == nullptr)
throw std::bad_cast();
return *y;
}
} // namespace migraph
#endif
...@@ -6,7 +6,9 @@ ...@@ -6,7 +6,9 @@
#include <memory> #include <memory>
#include <type_traits> #include <type_traits>
#include <utility> #include <utility>
#include <vector>
#include <migraph/context.hpp> #include <migraph/context.hpp>
#include <migraph/pass.hpp>
namespace migraph { namespace migraph {
...@@ -18,7 +20,7 @@ struct program; ...@@ -18,7 +20,7 @@ struct program;
* struct target * struct target
* { * {
* std::string name() const; * std::string name() const;
* void apply(program & p) const; * std::vector<pass> get_passes(context& ctx) const;
* context get_context() const; * context get_context() const;
* }; * };
* *
...@@ -87,10 +89,10 @@ struct target ...@@ -87,10 +89,10 @@ struct target
return (*this).private_detail_te_get_handle().name(); return (*this).private_detail_te_get_handle().name();
} }
void apply(program& p) const std::vector<pass> get_passes(context& ctx) const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().apply(p); return (*this).private_detail_te_get_handle().get_passes(ctx);
} }
context get_context() const context get_context() const
...@@ -106,9 +108,9 @@ struct target ...@@ -106,9 +108,9 @@ struct target
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0; virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0; virtual const std::type_info& type() const = 0;
virtual std::string name() const = 0; virtual std::string name() const = 0;
virtual void apply(program& p) const = 0; virtual std::vector<pass> get_passes(context& ctx) const = 0;
virtual context get_context() const = 0; virtual context get_context() const = 0;
}; };
template <typename PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedT>
...@@ -141,7 +143,11 @@ struct target ...@@ -141,7 +143,11 @@ struct target
std::string name() const override { return private_detail_te_value.name(); } std::string name() const override { return private_detail_te_value.name(); }
void apply(program& p) const override { return private_detail_te_value.apply(p); } std::vector<pass> get_passes(context& ctx) const override
{
return private_detail_te_value.get_passes(ctx);
}
context get_context() const override { return private_detail_te_value.get_context(); } context get_context() const override { return private_detail_te_value.get_context(); }
......
...@@ -111,7 +111,14 @@ void program::compile(const target& t) ...@@ -111,7 +111,14 @@ void program::compile(const target& t)
{ {
assert(this->validate() != impl->instructions.end()); assert(this->validate() != impl->instructions.end());
this->impl->ctx = t.get_context(); this->impl->ctx = t.get_context();
t.apply(*this); for(auto&& p:t.get_passes(this->impl->ctx))
{
p.apply(*this);
#ifndef NDEBUG
if(this->validate() == impl->instructions.end())
MIGRAPH_THROW(p.name() + " pass produces invalid program");
#endif
}
if(this->validate() == impl->instructions.end()) if(this->validate() == impl->instructions.end())
MIGRAPH_THROW("Invalid program from compilation"); MIGRAPH_THROW("Invalid program from compilation");
} }
......
...@@ -538,9 +538,25 @@ struct cpu_apply ...@@ -538,9 +538,25 @@ struct cpu_apply
} }
}; };
struct cpu_pass
{
std::string name() const { return "cpu::pass"; }
void apply(program & p) const
{
cpu_apply{&p}.apply();
}
};
std::string cpu_target::name() const { return "cpu"; } std::string cpu_target::name() const { return "cpu"; }
void cpu_target::apply(program& p) const { cpu_apply{&p}.apply(); } std::vector<pass> cpu_target::get_passes(context&) const
{
return
{
cpu_pass{}
};
}
} // namespace cpu } // namespace cpu
......
...@@ -9,7 +9,7 @@ namespace cpu { ...@@ -9,7 +9,7 @@ namespace cpu {
struct cpu_target struct cpu_target
{ {
std::string name() const; std::string name() const;
void apply(program& p) const; std::vector<pass> get_passes(context& ctx) const;
context get_context() const { return {}; } context get_context() const { return {}; }
}; };
......
...@@ -68,7 +68,10 @@ struct minus_op ...@@ -68,7 +68,10 @@ struct minus_op
struct id_target struct id_target
{ {
std::string name() const { return "id"; } std::string name() const { return "id"; }
void apply(migraph::program&) const {} std::vector<migraph::pass> get_passes(migraph::context&) const
{
return {};
}
migraph::context get_context() const { return {}; } migraph::context get_context() const { return {}; }
}; };
......
#ifndef MIGRAPH_GUARD_PASS_HPP
#define MIGRAPH_GUARD_PASS_HPP
#include <string>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
namespace migraph {
struct program;
<%
interface('pass',
virtual('name', returns='std::string', const=True),
virtual('apply', returns='void', p='program &', const=True)
)
%>
} // namespace migraph
#endif
...@@ -6,7 +6,9 @@ ...@@ -6,7 +6,9 @@
#include <memory> #include <memory>
#include <type_traits> #include <type_traits>
#include <utility> #include <utility>
#include <vector>
#include <migraph/context.hpp> #include <migraph/context.hpp>
#include <migraph/pass.hpp>
namespace migraph { namespace migraph {
...@@ -15,7 +17,7 @@ struct program; ...@@ -15,7 +17,7 @@ struct program;
<% <%
interface('target', interface('target',
virtual('name', returns='std::string', const=True), virtual('name', returns='std::string', const=True),
virtual('apply', returns='void', p='program &', const=True), virtual('get_passes', ctx='context&', returns='std::vector<pass>', const=True),
virtual('get_context', returns='context', const=True) virtual('get_context', returns='context', const=True)
) )
%> %>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment