Commit 583c76f2 authored by Scott Thornton's avatar Scott Thornton
Browse files

Merge branch 'master' into mnist2

parents 06fb0905 603adbe6
#ifndef MIGRAPH_GUARD_RTGLIB_ITERATOR_FOR_HPP
#define MIGRAPH_GUARD_RTGLIB_ITERATOR_FOR_HPP
namespace migraph {
template <class T>
struct iterator_for_range
{
T* base;
using base_iterator = decltype(base->begin());
struct iterator
{
base_iterator i;
base_iterator operator*() { return i; }
base_iterator operator++() { return ++i; }
bool operator!=(const iterator& rhs) { return i != rhs.i; }
};
iterator begin() { return {base->begin()}; }
iterator end() { return {base->end()}; }
};
template <class T>
iterator_for_range<T> iterator_for(T& x)
{
return {&x};
}
} // namespace migraph
#endif
#ifndef MIGRAPH_GUARD_PASS_HPP
#define MIGRAPH_GUARD_PASS_HPP
#include <string>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
namespace migraph {
struct program;
/*
* Type-erased interface for:
*
* struct pass
* {
* std::string name() const;
* void apply(program & p) const;
* };
*
*/
struct pass
{
// Constructors
pass() = default;
template <typename PrivateDetailTypeErasedT>
pass(PrivateDetailTypeErasedT value)
: private_detail_te_handle_mem_var(
std::make_shared<private_detail_te_handle_type<
typename std::remove_reference<PrivateDetailTypeErasedT>::type>>(
std::forward<PrivateDetailTypeErasedT>(value)))
{
}
// Assignment
template <typename PrivateDetailTypeErasedT>
pass& operator=(PrivateDetailTypeErasedT value)
{
if(private_detail_te_handle_mem_var.unique())
*private_detail_te_handle_mem_var = std::forward<PrivateDetailTypeErasedT>(value);
else if(!private_detail_te_handle_mem_var)
private_detail_te_handle_mem_var = std::make_shared<PrivateDetailTypeErasedT>(
std::forward<PrivateDetailTypeErasedT>(value));
return *this;
}
// Cast
template <typename PrivateDetailTypeErasedT>
PrivateDetailTypeErasedT* any_cast()
{
return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle())
.private_detail_te_value)
: nullptr;
}
template <typename PrivateDetailTypeErasedT>
const typename std::remove_cv<PrivateDetailTypeErasedT>::type* any_cast() const
{
return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<const private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle())
.private_detail_te_value)
: nullptr;
}
const std::type_info& type_id() const
{
if(private_detail_te_handle_empty())
return typeid(std::nullptr_t);
else
return private_detail_te_get_handle().type();
}
std::string name() const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().name();
}
void apply(program& p) const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().apply(p);
}
private:
struct private_detail_te_handle_base_type
{
virtual ~private_detail_te_handle_base_type() {}
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0;
virtual std::string name() const = 0;
virtual void apply(program& p) const = 0;
};
template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type : private_detail_te_handle_base_type
{
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type(
PrivateDetailTypeErasedT value,
typename std::enable_if<std::is_reference<PrivateDetailTypeErasedU>::value>::type* =
nullptr)
: private_detail_te_value(value)
{
}
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type(
PrivateDetailTypeErasedT value,
typename std::enable_if<!std::is_reference<PrivateDetailTypeErasedU>::value,
int>::type* = nullptr) noexcept
: private_detail_te_value(std::move(value))
{
}
std::shared_ptr<private_detail_te_handle_base_type> clone() const override
{
return std::make_shared<private_detail_te_handle_type>(private_detail_te_value);
}
const std::type_info& type() const override { return typeid(private_detail_te_value); }
std::string name() const override { return private_detail_te_value.name(); }
void apply(program& p) const override { return private_detail_te_value.apply(p); }
PrivateDetailTypeErasedT private_detail_te_value;
};
template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type<std::reference_wrapper<PrivateDetailTypeErasedT>>
: private_detail_te_handle_type<PrivateDetailTypeErasedT&>
{
private_detail_te_handle_type(std::reference_wrapper<PrivateDetailTypeErasedT> ref)
: private_detail_te_handle_type<PrivateDetailTypeErasedT&>(ref.get())
{
}
};
bool private_detail_te_handle_empty() const
{
return private_detail_te_handle_mem_var == nullptr;
}
const private_detail_te_handle_base_type& private_detail_te_get_handle() const
{
assert(private_detail_te_handle_mem_var != nullptr);
return *private_detail_te_handle_mem_var;
}
private_detail_te_handle_base_type& private_detail_te_get_handle()
{
assert(private_detail_te_handle_mem_var != nullptr);
if(!private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var;
}
std::shared_ptr<private_detail_te_handle_base_type> private_detail_te_handle_mem_var;
};
template <typename ValueType>
inline const ValueType* any_cast(const pass* x)
{
return x->any_cast<ValueType>();
}
template <typename ValueType>
inline ValueType* any_cast(pass* x)
{
return x->any_cast<ValueType>();
}
template <typename ValueType>
inline ValueType& any_cast(pass& x)
{
auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
if(y == nullptr)
throw std::bad_cast();
return *y;
}
template <typename ValueType>
inline const ValueType& any_cast(const pass& x)
{
const auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
if(y == nullptr)
throw std::bad_cast();
return *y;
}
} // namespace migraph
#endif
......@@ -6,7 +6,9 @@
#include <memory>
#include <type_traits>
#include <utility>
#include <vector>
#include <migraph/context.hpp>
#include <migraph/pass.hpp>
namespace migraph {
......@@ -18,7 +20,7 @@ struct program;
* struct target
* {
* std::string name() const;
* void apply(program & p) const;
* std::vector<pass> get_passes(context& ctx) const;
* context get_context() const;
* };
*
......@@ -87,10 +89,10 @@ struct target
return (*this).private_detail_te_get_handle().name();
}
void apply(program& p) const
std::vector<pass> get_passes(context& ctx) const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().apply(p);
return (*this).private_detail_te_get_handle().get_passes(ctx);
}
context get_context() const
......@@ -106,9 +108,9 @@ struct target
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0;
virtual std::string name() const = 0;
virtual void apply(program& p) const = 0;
virtual context get_context() const = 0;
virtual std::string name() const = 0;
virtual std::vector<pass> get_passes(context& ctx) const = 0;
virtual context get_context() const = 0;
};
template <typename PrivateDetailTypeErasedT>
......@@ -141,7 +143,11 @@ struct target
std::string name() const override { return private_detail_te_value.name(); }
void apply(program& p) const override { return private_detail_te_value.apply(p); }
std::vector<pass> get_passes(context& ctx) const override
{
return private_detail_te_value.get_passes(ctx);
}
context get_context() const override { return private_detail_te_value.get_context(); }
......
......@@ -111,7 +111,14 @@ void program::compile(const target& t)
{
assert(this->validate() != impl->instructions.end());
this->impl->ctx = t.get_context();
t.apply(*this);
for(auto&& p : t.get_passes(this->impl->ctx))
{
p.apply(*this);
#ifndef NDEBUG
if(this->validate() == impl->instructions.end())
MIGRAPH_THROW(p.name() + " pass produces invalid program");
#endif
}
if(this->validate() == impl->instructions.end())
MIGRAPH_THROW("Invalid program from compilation");
}
......
......@@ -4,6 +4,7 @@
#include <migraph/dfor.hpp>
#include <migraph/operators.hpp>
#include <migraph/shape_for_each.hpp>
#include <migraph/iterator_for.hpp>
namespace migraph {
namespace cpu {
......@@ -491,7 +492,7 @@ struct cpu_apply
void apply()
{
init();
for(auto it = prog->begin(); it != prog->end(); it++)
for(auto it : iterator_for(*prog))
{
if(it->op.name() == "activation")
{
......@@ -538,9 +539,16 @@ struct cpu_apply
}
};
struct cpu_pass
{
std::string name() const { return "cpu::pass"; }
void apply(program& p) const { cpu_apply{&p}.apply(); }
};
std::string cpu_target::name() const { return "cpu"; }
void cpu_target::apply(program& p) const { cpu_apply{&p}.apply(); }
std::vector<pass> cpu_target::get_passes(context&) const { return {cpu_pass{}}; }
} // namespace cpu
......
......@@ -9,7 +9,7 @@ namespace cpu {
struct cpu_target
{
std::string name() const;
void apply(program& p) const;
std::vector<pass> get_passes(context& ctx) const;
context get_context() const { return {}; }
};
......
......@@ -9,7 +9,7 @@ namespace miopen {
struct miopen_target
{
std::string name() const;
void apply(program& p) const;
std::vector<pass> get_passes(context& ctx) const;
context get_context() const;
};
......
......@@ -299,9 +299,16 @@ struct miopen_apply
}
};
std::string miopen_target::name() const { return "miopen"; }
struct miopen_pass
{
std::string name() const { return "miopen::pass"; }
void apply(program& p) const { miopen_apply{&p}.apply(); }
};
std::vector<pass> miopen_target::get_passes(context&) const { return {miopen_pass{}}; }
void miopen_target::apply(program& p) const { miopen_apply{&p}.apply(); }
std::string miopen_target::name() const { return "miopen"; }
context miopen_target::get_context() const
{
......
......@@ -2,6 +2,8 @@
#include <migraph/program.hpp>
#include <migraph/argument.hpp>
#include <migraph/shape.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/instruction.hpp>
#include <sstream>
#include "test.hpp"
......@@ -68,7 +70,44 @@ struct minus_op
struct id_target
{
std::string name() const { return "id"; }
void apply(migraph::program&) const {}
std::vector<migraph::pass> get_passes(migraph::context&) const { return {}; }
migraph::context get_context() const { return {}; }
};
struct reverse_pass
{
std::string name() const { return "reverse_pass"; }
void apply(migraph::program& p) const
{
for(auto ins : migraph::iterator_for(p))
{
if(ins->op.name() == "sum")
{
p.replace_instruction(ins, minus_op{}, ins->arguments);
}
else if(ins->op.name() == "minus")
{
p.replace_instruction(ins, sum_op{}, ins->arguments);
}
}
}
};
struct reverse_target
{
std::string name() const { return "reverse"; }
std::vector<migraph::pass> get_passes(migraph::context&) const { return {reverse_pass{}}; }
migraph::context get_context() const { return {}; }
};
struct double_reverse_target
{
std::string name() const { return "double_reverse"; }
std::vector<migraph::pass> get_passes(migraph::context&) const
{
return {reverse_pass{}, reverse_pass{}};
}
migraph::context get_context() const { return {}; }
};
......@@ -170,6 +209,32 @@ void target_test()
EXPECT(result != migraph::literal{4});
}
void reverse_target_test()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
p.add_instruction(sum_op{}, two, one);
p.compile(reverse_target{});
auto result = p.eval({});
EXPECT(result == migraph::literal{1});
EXPECT(result != migraph::literal{4});
}
void double_reverse_target_test()
{
migraph::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
p.add_instruction(sum_op{}, two, one);
p.compile(double_reverse_target{});
auto result = p.eval({});
EXPECT(result == migraph::literal{3});
EXPECT(result != migraph::literal{4});
}
int main()
{
literal_test1();
......@@ -179,4 +244,5 @@ int main()
replace_test();
insert_replace_test();
target_test();
reverse_target_test();
}
#ifndef MIGRAPH_GUARD_PASS_HPP
#define MIGRAPH_GUARD_PASS_HPP
#include <string>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
namespace migraph {
struct program;
<%
interface('pass',
virtual('name', returns='std::string', const=True),
virtual('apply', returns='void', p='program &', const=True)
)
%>
} // namespace migraph
#endif
......@@ -6,7 +6,9 @@
#include <memory>
#include <type_traits>
#include <utility>
#include <vector>
#include <migraph/context.hpp>
#include <migraph/pass.hpp>
namespace migraph {
......@@ -15,7 +17,7 @@ struct program;
<%
interface('target',
virtual('name', returns='std::string', const=True),
virtual('apply', returns='void', p='program &', const=True),
virtual('get_passes', ctx='context&', returns='std::vector<pass>', const=True),
virtual('get_context', returns='context', const=True)
)
%>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment