Commit cf86db72 authored by Paul's avatar Paul
Browse files

Merge branch 'master' into fp16

parents af454aeb 414e2fac
...@@ -18,7 +18,7 @@ CheckOptions: ...@@ -18,7 +18,7 @@ CheckOptions:
- key: readability-identifier-naming.NamespaceCase - key: readability-identifier-naming.NamespaceCase
value: lower_case value: lower_case
- key: readability-identifier-naming.InlineNamespaceCase - key: readability-identifier-naming.InlineNamespaceCase
value: lower_case value: UPPER_CASE
- key: readability-identifier-naming.EnumConstantCase - key: readability-identifier-naming.EnumConstantCase
value: lower_case value: lower_case
- key: readability-identifier-naming.ConstexprVariableCase - key: readability-identifier-naming.ConstexprVariableCase
......
...@@ -6,6 +6,7 @@ add_library(migraph ...@@ -6,6 +6,7 @@ add_library(migraph
dead_code_elimination.cpp dead_code_elimination.cpp
eliminate_allocation.cpp eliminate_allocation.cpp
eliminate_contiguous.cpp eliminate_contiguous.cpp
eliminate_concat.cpp
fwd_conv_batchnorm_rewrite.cpp fwd_conv_batchnorm_rewrite.cpp
env.cpp env.cpp
generate.cpp generate.cpp
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include <migraph/iterator_for.hpp> #include <migraph/iterator_for.hpp>
namespace migraph { namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
void auto_contiguous::apply(program& p) const void auto_contiguous::apply(program& p) const
{ {
...@@ -19,4 +20,5 @@ void auto_contiguous::apply(program& p) const ...@@ -19,4 +20,5 @@ void auto_contiguous::apply(program& p) const
} }
} }
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph } // namespace migraph
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <unordered_set> #include <unordered_set>
namespace migraph { namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
template <class Range> template <class Range>
void cse_range(program& p, Range&& r) void cse_range(program& p, Range&& r)
...@@ -34,4 +35,5 @@ void cse_range(program& p, Range&& r) ...@@ -34,4 +35,5 @@ void cse_range(program& p, Range&& r)
void common_subexpression_elimination::apply(program& p) const { cse_range(p, iterator_for(p)); } void common_subexpression_elimination::apply(program& p) const { cse_range(p, iterator_for(p)); }
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph } // namespace migraph
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include <migraph/literal.hpp> #include <migraph/literal.hpp>
namespace migraph { namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
struct match_const_add struct match_const_add
{ {
...@@ -25,4 +26,5 @@ struct match_const_add ...@@ -25,4 +26,5 @@ struct match_const_add
void constant_propagate::apply(program& p) const { match::find_matches(p, match_const_add{}); } void constant_propagate::apply(program& p) const { match::find_matches(p, match_const_add{}); }
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph } // namespace migraph
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <migraph/ranges.hpp> #include <migraph/ranges.hpp>
namespace migraph { namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
template <class Range, class Iterator> template <class Range, class Iterator>
std::ptrdiff_t bidistance(const Range& r, Iterator start, Iterator last) std::ptrdiff_t bidistance(const Range& r, Iterator start, Iterator last)
...@@ -61,4 +62,5 @@ void dead_code_elimination::apply(program& p) const ...@@ -61,4 +62,5 @@ void dead_code_elimination::apply(program& p) const
p.remove_instructions(std::next(last), p.end()); p.remove_instructions(std::next(last), p.end());
} }
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph } // namespace migraph
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <migraph/pass_config.hpp> #include <migraph/pass_config.hpp>
namespace migraph { namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
void eliminate_allocation::apply(program& p) const void eliminate_allocation::apply(program& p) const
{ {
...@@ -35,4 +36,6 @@ void eliminate_allocation::apply(program& p) const ...@@ -35,4 +36,6 @@ void eliminate_allocation::apply(program& p) const
p.replace_instruction(ins, op::load{s, offset}, mem); p.replace_instruction(ins, op::load{s, offset}, mem);
} }
} }
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph } // namespace migraph
#include <iterator>
#include <migraph/eliminate_concat.hpp>
#include <migraph/program.hpp>
#include <migraph/instruction.hpp>
#include <migraph/operators.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/dfor.hpp>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
void eliminate_concat::apply(program& p) const
{
for(auto ins : iterator_for(p))
{
// Look for the concat operator
if(ins->name() != concat_opt.name())
continue;
// If any inputs are literals then abort
if(std::any_of(ins->inputs().begin() + 1, ins->inputs().end(), [](auto arg) {
return arg->name() == "@literal";
}))
continue;
// We can only do this optimization when concat axis is either the leftmost
// axis OR the sizes to the left of this axis are all equal to 1
// Since we've already checked that the non-axis dimensions are identical
// we only need to check the first input
auto lens = ins->inputs().front()->get_shape().lens();
auto concat_op = concat_opt.get_concat(ins->get_operator());
if(concat_op.axis == 0 ||
std::all_of(lens.begin(), lens.begin() + concat_op.axis, [](auto x) { return x == 1; }))
{
// Last input should be an allocation
auto last = ins->inputs().back();
if(last->name() != concat_opt.allocate())
continue;
// Where are the allocations for the tensors to be concatenated?
std::vector<instruction_ref> allocations;
for(auto ins2 = ins->inputs().begin(); ins2 != ins->inputs().end() - 1; ins2++)
{
auto last2 = (*ins2)->inputs().back();
if(last2->name() == concat_opt.allocate())
{
allocations.push_back(last2);
}
}
// Need to sort the allocations, so that we know where to
// insert the "super"-allocation
std::sort(
allocations.begin(), allocations.end(), [&](instruction_ref x, instruction_ref y) {
return std::distance(p.begin(), x) < std::distance(p.begin(), y);
});
// Move "super" allocation to the front
auto first = allocations.front();
auto super = p.move_instruction(last, first);
std::size_t offset = 0;
for(auto x : allocations)
{
migraph::op::load op{x->get_shape(), offset};
// migraph::op::load op{x->get_shape(), 0};
p.replace_instruction(x, op, {super});
offset += x->get_shape().bytes();
}
std::vector<instruction_ref> args = {super};
std::copy(ins->inputs().begin(), ins->inputs().end() - 1, std::back_inserter(args));
p.replace_instruction(ins, migraph::op::identity{}, args);
}
}
}
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <utility> #include <utility>
namespace migraph { namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
bool try_compute_shape(const operation& op, const std::vector<instruction_ref>& args) bool try_compute_shape(const operation& op, const std::vector<instruction_ref>& args)
{ {
...@@ -46,4 +47,5 @@ void eliminate_contiguous::apply(program& p) const ...@@ -46,4 +47,5 @@ void eliminate_contiguous::apply(program& p) const
} }
} }
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph } // namespace migraph
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <cstdlib> #include <cstdlib>
namespace migraph { namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
bool enabled(const char* name) bool enabled(const char* name)
{ {
...@@ -29,4 +30,5 @@ std::vector<std::string> env(const char* name) ...@@ -29,4 +30,5 @@ std::vector<std::string> env(const char* name)
return {{p}}; return {{p}};
} }
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph } // namespace migraph
...@@ -6,6 +6,8 @@ ...@@ -6,6 +6,8 @@
#include <migraph/dfor.hpp> #include <migraph/dfor.hpp>
namespace migraph { namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
void fwd_conv_batchnorm_rewrite::apply(program& p) const void fwd_conv_batchnorm_rewrite::apply(program& p) const
{ {
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
...@@ -64,4 +66,6 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const ...@@ -64,4 +66,6 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
p.replace_instruction(ins, op::add{}, {c, b}); p.replace_instruction(ins, op::add{}, {c, b});
} }
} }
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph } // namespace migraph
#include <migraph/generate.hpp> #include <migraph/generate.hpp>
namespace migraph { namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
argument generate_argument(shape s, unsigned long seed) argument generate_argument(shape s, unsigned long seed)
{ {
...@@ -30,4 +31,5 @@ literal abs(literal l) ...@@ -30,4 +31,5 @@ literal abs(literal l)
return transform(std::move(l), [](auto x) { return std::fabs(x); }); return transform(std::move(l), [](auto x) { return std::fabs(x); });
} }
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph } // namespace migraph
...@@ -3,10 +3,12 @@ ...@@ -3,10 +3,12 @@
#include <migraph/shape.hpp> #include <migraph/shape.hpp>
#include <migraph/raw_data.hpp> #include <migraph/raw_data.hpp>
#include <migraph/config.hpp>
#include <functional> #include <functional>
#include <utility> #include <utility>
namespace migraph { namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
/** /**
* @brief Arguments passed to instructions * @brief Arguments passed to instructions
...@@ -45,6 +47,7 @@ struct argument : raw_data<argument> ...@@ -45,6 +47,7 @@ struct argument : raw_data<argument>
shape m_shape; shape m_shape;
}; };
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph } // namespace migraph
#endif #endif
#ifndef MIGRAPH_GUARD_RTGLIB_AUTO_ANY_CAST_HPP #ifndef MIGRAPH_GUARD_RTGLIB_AUTO_ANY_CAST_HPP
#define MIGRAPH_GUARD_RTGLIB_AUTO_ANY_CAST_HPP #define MIGRAPH_GUARD_RTGLIB_AUTO_ANY_CAST_HPP
#include <migraph/config.hpp>
namespace migraph { namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
namespace detail { namespace detail {
...@@ -32,6 +34,7 @@ detail::auto_any_caster<T> auto_any_cast(T& x) ...@@ -32,6 +34,7 @@ detail::auto_any_caster<T> auto_any_cast(T& x)
return {x}; return {x};
} }
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph } // namespace migraph
#endif #endif
...@@ -3,8 +3,10 @@ ...@@ -3,8 +3,10 @@
#include <string> #include <string>
#include <migraph/instruction_ref.hpp> #include <migraph/instruction_ref.hpp>
#include <migraph/config.hpp>
namespace migraph { namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
struct program; struct program;
...@@ -14,6 +16,7 @@ struct auto_contiguous ...@@ -14,6 +16,7 @@ struct auto_contiguous
void apply(program& p) const; void apply(program& p) const;
}; };
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph } // namespace migraph
#endif #endif
...@@ -5,8 +5,10 @@ ...@@ -5,8 +5,10 @@
#include <migraph/errors.hpp> #include <migraph/errors.hpp>
#include <migraph/argument.hpp> #include <migraph/argument.hpp>
#include <migraph/reflect.hpp> #include <migraph/reflect.hpp>
#include <migraph/config.hpp>
namespace migraph { namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
namespace builtin { namespace builtin {
...@@ -62,7 +64,7 @@ struct param ...@@ -62,7 +64,7 @@ struct param
}; };
} // namespace builtin } // namespace builtin
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph } // namespace migraph
#endif #endif
...@@ -2,8 +2,10 @@ ...@@ -2,8 +2,10 @@
#define MIGRAPH_GUARD_RTGLIB_CHECK_CONTEXT_HPP #define MIGRAPH_GUARD_RTGLIB_CHECK_CONTEXT_HPP
#include <migraph/program.hpp> #include <migraph/program.hpp>
#include <migraph/config.hpp>
namespace migraph { namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
template <class T> template <class T>
struct check_context struct check_context
...@@ -25,6 +27,7 @@ struct check_context ...@@ -25,6 +27,7 @@ struct check_context
void apply(program& p) const { p.insert_instruction(p.begin(), op{}); } void apply(program& p) const { p.insert_instruction(p.begin(), op{}); }
}; };
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph } // namespace migraph
#endif #endif
...@@ -2,9 +2,11 @@ ...@@ -2,9 +2,11 @@
#define MIGRAPH_GUARD_RTGLIB_CHECK_SHAPES_HPP #define MIGRAPH_GUARD_RTGLIB_CHECK_SHAPES_HPP
#include <migraph/shape.hpp> #include <migraph/shape.hpp>
#include <migraph/config.hpp>
#include <algorithm> #include <algorithm>
namespace migraph { namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
struct check_shapes struct check_shapes
{ {
...@@ -154,6 +156,7 @@ struct check_shapes ...@@ -154,6 +156,7 @@ struct check_shapes
check_shapes slice(long start, long last) { return {get(start), get(last), name}; } check_shapes slice(long start, long last) { return {get(start), get(last), name}; }
}; };
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph } // namespace migraph
#endif #endif
...@@ -3,8 +3,10 @@ ...@@ -3,8 +3,10 @@
#include <string> #include <string>
#include <migraph/instruction_ref.hpp> #include <migraph/instruction_ref.hpp>
#include <migraph/config.hpp>
namespace migraph { namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
struct program; struct program;
...@@ -14,6 +16,7 @@ struct common_subexpression_elimination ...@@ -14,6 +16,7 @@ struct common_subexpression_elimination
void apply(program& p) const; void apply(program& p) const;
}; };
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph } // namespace migraph
#endif #endif
#ifndef MIGRAPH_GUARD_CONCAT_OPT_HPP
#define MIGRAPH_GUARD_CONCAT_OPT_HPP
#include <cassert>
#include <string>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
#include <migraph/operation.hpp>
#include <migraph/operators.hpp>
#include <migraph/config.hpp>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
struct program;
#ifdef DOXYGEN
/// An interface for target-dependent optimization for the concat instruction
struct concat_optimization
{
/// The name of the target-dependent concat operator
std::string name() const;
/// A name of the target-dependent allocate operator
std::string allocate() const;
/// Return the target-independent concat operator
op::concat get_concat(const operation& op) const;
};
#else
/*
* Type-erased interface for:
*
* struct concat_optimization
* {
* std::string name() const;
* std::string allocate() const;
* op::concat get_concat(const operation& op) const;
* };
*
*/
struct concat_optimization
{
// Constructors
concat_optimization() = default;
template <typename PrivateDetailTypeErasedT>
concat_optimization(PrivateDetailTypeErasedT value)
: private_detail_te_handle_mem_var(
std::make_shared<private_detail_te_handle_type<
typename std::remove_reference<PrivateDetailTypeErasedT>::type>>(
std::forward<PrivateDetailTypeErasedT>(value)))
{
}
// Assignment
template <typename PrivateDetailTypeErasedT>
concat_optimization& operator=(PrivateDetailTypeErasedT value)
{
if(private_detail_te_handle_mem_var.unique())
*private_detail_te_handle_mem_var = std::forward<PrivateDetailTypeErasedT>(value);
else if(!private_detail_te_handle_mem_var)
private_detail_te_handle_mem_var = std::make_shared<PrivateDetailTypeErasedT>(
std::forward<PrivateDetailTypeErasedT>(value));
return *this;
}
// Cast
template <typename PrivateDetailTypeErasedT>
PrivateDetailTypeErasedT* any_cast()
{
return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle())
.private_detail_te_value)
: nullptr;
}
template <typename PrivateDetailTypeErasedT>
const typename std::remove_cv<PrivateDetailTypeErasedT>::type* any_cast() const
{
return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<const private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle())
.private_detail_te_value)
: nullptr;
}
const std::type_info& type_id() const
{
if(private_detail_te_handle_empty())
return typeid(std::nullptr_t);
else
return private_detail_te_get_handle().type();
}
std::string name() const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().name();
}
std::string allocate() const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().allocate();
}
op::concat get_concat(const operation& op) const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().get_concat(op);
}
private:
struct private_detail_te_handle_base_type
{
virtual ~private_detail_te_handle_base_type() {}
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0;
virtual std::string name() const = 0;
virtual std::string allocate() const = 0;
virtual op::concat get_concat(const operation& op) const = 0;
};
template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type : private_detail_te_handle_base_type
{
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type(
PrivateDetailTypeErasedT value,
typename std::enable_if<std::is_reference<PrivateDetailTypeErasedU>::value>::type* =
nullptr)
: private_detail_te_value(value)
{
}
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type(
PrivateDetailTypeErasedT value,
typename std::enable_if<!std::is_reference<PrivateDetailTypeErasedU>::value,
int>::type* = nullptr) noexcept
: private_detail_te_value(std::move(value))
{
}
std::shared_ptr<private_detail_te_handle_base_type> clone() const override
{
return std::make_shared<private_detail_te_handle_type>(private_detail_te_value);
}
const std::type_info& type() const override { return typeid(private_detail_te_value); }
std::string name() const override { return private_detail_te_value.name(); }
std::string allocate() const override { return private_detail_te_value.allocate(); }
op::concat get_concat(const operation& op) const override
{
return private_detail_te_value.get_concat(op);
}
PrivateDetailTypeErasedT private_detail_te_value;
};
template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type<std::reference_wrapper<PrivateDetailTypeErasedT>>
: private_detail_te_handle_type<PrivateDetailTypeErasedT&>
{
private_detail_te_handle_type(std::reference_wrapper<PrivateDetailTypeErasedT> ref)
: private_detail_te_handle_type<PrivateDetailTypeErasedT&>(ref.get())
{
}
};
bool private_detail_te_handle_empty() const
{
return private_detail_te_handle_mem_var == nullptr;
}
const private_detail_te_handle_base_type& private_detail_te_get_handle() const
{
assert(private_detail_te_handle_mem_var != nullptr);
return *private_detail_te_handle_mem_var;
}
private_detail_te_handle_base_type& private_detail_te_get_handle()
{
assert(private_detail_te_handle_mem_var != nullptr);
if(!private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var;
}
std::shared_ptr<private_detail_te_handle_base_type> private_detail_te_handle_mem_var;
};
template <typename ValueType>
inline const ValueType* any_cast(const concat_optimization* x)
{
return x->any_cast<ValueType>();
}
template <typename ValueType>
inline ValueType* any_cast(concat_optimization* x)
{
return x->any_cast<ValueType>();
}
template <typename ValueType>
inline ValueType& any_cast(concat_optimization& x)
{
auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
if(y == nullptr)
throw std::bad_cast();
return *y;
}
template <typename ValueType>
inline const ValueType& any_cast(const concat_optimization& x)
{
const auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
if(y == nullptr)
throw std::bad_cast();
return *y;
}
#endif
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment