"official/projects/simclr/multitask_train.py" did not exist on "63620f4ce55a608e76c0ff4a76d64d90507bf997"
Commit a9e5d73c authored by Paul's avatar Paul
Browse files

Update get_passes for the cpu

parent cff16121
#ifndef MIGRAPH_GUARD_PASS_HPP
#define MIGRAPH_GUARD_PASS_HPP
#include <string>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
namespace migraph {
struct program;
/*
* Type-erased interface for:
*
* struct pass
* {
* std::string name() const;
* void apply(program & p) const;
* };
*
*/
struct pass
{
// Constructors
pass() = default;
template <typename PrivateDetailTypeErasedT>
pass(PrivateDetailTypeErasedT value)
: private_detail_te_handle_mem_var(
std::make_shared<private_detail_te_handle_type<
typename std::remove_reference<PrivateDetailTypeErasedT>::type>>(
std::forward<PrivateDetailTypeErasedT>(value)))
{
}
// Assignment
template <typename PrivateDetailTypeErasedT>
pass& operator=(PrivateDetailTypeErasedT value)
{
if(private_detail_te_handle_mem_var.unique())
*private_detail_te_handle_mem_var = std::forward<PrivateDetailTypeErasedT>(value);
else if(!private_detail_te_handle_mem_var)
private_detail_te_handle_mem_var = std::make_shared<PrivateDetailTypeErasedT>(
std::forward<PrivateDetailTypeErasedT>(value));
return *this;
}
// Cast
template <typename PrivateDetailTypeErasedT>
PrivateDetailTypeErasedT* any_cast()
{
return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle())
.private_detail_te_value)
: nullptr;
}
template <typename PrivateDetailTypeErasedT>
const typename std::remove_cv<PrivateDetailTypeErasedT>::type* any_cast() const
{
return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<const private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle())
.private_detail_te_value)
: nullptr;
}
const std::type_info& type_id() const
{
if(private_detail_te_handle_empty())
return typeid(std::nullptr_t);
else
return private_detail_te_get_handle().type();
}
std::string name() const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().name();
}
void apply(program& p) const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().apply(p);
}
private:
struct private_detail_te_handle_base_type
{
virtual ~private_detail_te_handle_base_type() {}
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0;
virtual std::string name() const = 0;
virtual void apply(program& p) const = 0;
};
template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type : private_detail_te_handle_base_type
{
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type(
PrivateDetailTypeErasedT value,
typename std::enable_if<std::is_reference<PrivateDetailTypeErasedU>::value>::type* =
nullptr)
: private_detail_te_value(value)
{
}
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type(
PrivateDetailTypeErasedT value,
typename std::enable_if<!std::is_reference<PrivateDetailTypeErasedU>::value,
int>::type* = nullptr) noexcept
: private_detail_te_value(std::move(value))
{
}
std::shared_ptr<private_detail_te_handle_base_type> clone() const override
{
return std::make_shared<private_detail_te_handle_type>(private_detail_te_value);
}
const std::type_info& type() const override { return typeid(private_detail_te_value); }
std::string name() const override { return private_detail_te_value.name(); }
void apply(program& p) const override { return private_detail_te_value.apply(p); }
PrivateDetailTypeErasedT private_detail_te_value;
};
template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type<std::reference_wrapper<PrivateDetailTypeErasedT>>
: private_detail_te_handle_type<PrivateDetailTypeErasedT&>
{
private_detail_te_handle_type(std::reference_wrapper<PrivateDetailTypeErasedT> ref)
: private_detail_te_handle_type<PrivateDetailTypeErasedT&>(ref.get())
{
}
};
bool private_detail_te_handle_empty() const
{
return private_detail_te_handle_mem_var == nullptr;
}
const private_detail_te_handle_base_type& private_detail_te_get_handle() const
{
assert(private_detail_te_handle_mem_var != nullptr);
return *private_detail_te_handle_mem_var;
}
private_detail_te_handle_base_type& private_detail_te_get_handle()
{
assert(private_detail_te_handle_mem_var != nullptr);
if(!private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var;
}
std::shared_ptr<private_detail_te_handle_base_type> private_detail_te_handle_mem_var;
};
template <typename ValueType>
inline const ValueType* any_cast(const pass* x)
{
return x->any_cast<ValueType>();
}
template <typename ValueType>
inline ValueType* any_cast(pass* x)
{
return x->any_cast<ValueType>();
}
template <typename ValueType>
inline ValueType& any_cast(pass& x)
{
auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
if(y == nullptr)
throw std::bad_cast();
return *y;
}
template <typename ValueType>
inline const ValueType& any_cast(const pass& x)
{
const auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
if(y == nullptr)
throw std::bad_cast();
return *y;
}
} // namespace migraph
#endif
......@@ -6,7 +6,9 @@
#include <memory>
#include <type_traits>
#include <utility>
#include <vector>
#include <migraph/context.hpp>
#include <migraph/pass.hpp>
namespace migraph {
......@@ -18,7 +20,7 @@ struct program;
* struct target
* {
* std::string name() const;
* void apply(program & p) const;
* std::vector<pass> get_passes(context& ctx) const;
* context get_context() const;
* };
*
......@@ -87,10 +89,10 @@ struct target
return (*this).private_detail_te_get_handle().name();
}
void apply(program& p) const
std::vector<pass> get_passes(context& ctx) const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().apply(p);
return (*this).private_detail_te_get_handle().get_passes(ctx);
}
context get_context() const
......@@ -106,9 +108,9 @@ struct target
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0;
virtual std::string name() const = 0;
virtual void apply(program& p) const = 0;
virtual context get_context() const = 0;
virtual std::string name() const = 0;
virtual std::vector<pass> get_passes(context& ctx) const = 0;
virtual context get_context() const = 0;
};
template <typename PrivateDetailTypeErasedT>
......@@ -141,7 +143,11 @@ struct target
std::string name() const override { return private_detail_te_value.name(); }
void apply(program& p) const override { return private_detail_te_value.apply(p); }
std::vector<pass> get_passes(context& ctx) const override
{
return private_detail_te_value.get_passes(ctx);
}
context get_context() const override { return private_detail_te_value.get_context(); }
......
......@@ -111,7 +111,14 @@ void program::compile(const target& t)
{
assert(this->validate() != impl->instructions.end());
this->impl->ctx = t.get_context();
t.apply(*this);
for(auto&& p:t.get_passes(this->impl->ctx))
{
p.apply(*this);
#ifndef NDEBUG
if(this->validate() == impl->instructions.end())
MIGRAPH_THROW(p.name() + " pass produces invalid program");
#endif
}
if(this->validate() == impl->instructions.end())
MIGRAPH_THROW("Invalid program from compilation");
}
......
......@@ -538,9 +538,25 @@ struct cpu_apply
}
};
struct cpu_pass
{
std::string name() const { return "cpu::pass"; }
void apply(program & p) const
{
cpu_apply{&p}.apply();
}
};
std::string cpu_target::name() const { return "cpu"; }
void cpu_target::apply(program& p) const { cpu_apply{&p}.apply(); }
std::vector<pass> cpu_target::get_passes(context&) const
{
return
{
cpu_pass{}
};
}
} // namespace cpu
......
......@@ -9,7 +9,7 @@ namespace cpu {
struct cpu_target
{
std::string name() const;
void apply(program& p) const;
std::vector<pass> get_passes(context& ctx) const;
context get_context() const { return {}; }
};
......
......@@ -68,7 +68,10 @@ struct minus_op
struct id_target
{
std::string name() const { return "id"; }
void apply(migraph::program&) const {}
std::vector<migraph::pass> get_passes(migraph::context&) const
{
return {};
}
migraph::context get_context() const { return {}; }
};
......
#ifndef MIGRAPH_GUARD_PASS_HPP
#define MIGRAPH_GUARD_PASS_HPP
#include <string>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
namespace migraph {
struct program;
<%
interface('pass',
virtual('name', returns='std::string', const=True),
virtual('apply', returns='void', p='program &', const=True)
)
%>
} // namespace migraph
#endif
......@@ -6,7 +6,9 @@
#include <memory>
#include <type_traits>
#include <utility>
#include <vector>
#include <migraph/context.hpp>
#include <migraph/pass.hpp>
namespace migraph {
......@@ -15,7 +17,7 @@ struct program;
<%
interface('target',
virtual('name', returns='std::string', const=True),
virtual('apply', returns='void', p='program &', const=True),
virtual('get_passes', ctx='context&', returns='std::vector<pass>', const=True),
virtual('get_context', returns='context', const=True)
)
%>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment