Unverified Commit ee80cee9 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Merge branch 'master' into gpu_slice_test

parents 6d06226d f958d56f
#ifndef MIGRAPH_GUARD_CONCAT_OPT_HPP
#define MIGRAPH_GUARD_CONCAT_OPT_HPP
#include <cassert>
#include <string>
#include <functional>
#include <memory>
#include <type_traits>
#include <utility>
#include <migraph/operation.hpp>
#include <migraph/operators.hpp>
#include <migraph/config.hpp>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
struct program;
#ifdef DOXYGEN
/// An interface for target-dependent optimization for the concat instruction
struct concat_optimization
{
/// The name of the target-dependent concat operator
std::string name() const;
/// A name of the target-dependent allocate operator
std::string allocate() const;
/// Return the target-independent concat operator
op::concat get_concat(const operation& op) const;
};
#else
/*
* Type-erased interface for:
*
* struct concat_optimization
* {
* std::string name() const;
* std::string allocate() const;
* op::concat get_concat(const operation& op) const;
* };
*
*/
struct concat_optimization
{
// Constructors
concat_optimization() = default;
template <typename PrivateDetailTypeErasedT>
concat_optimization(PrivateDetailTypeErasedT value)
: private_detail_te_handle_mem_var(
std::make_shared<private_detail_te_handle_type<
typename std::remove_reference<PrivateDetailTypeErasedT>::type>>(
std::forward<PrivateDetailTypeErasedT>(value)))
{
}
// Assignment
template <typename PrivateDetailTypeErasedT>
concat_optimization& operator=(PrivateDetailTypeErasedT value)
{
if(private_detail_te_handle_mem_var.unique())
*private_detail_te_handle_mem_var = std::forward<PrivateDetailTypeErasedT>(value);
else if(!private_detail_te_handle_mem_var)
private_detail_te_handle_mem_var = std::make_shared<PrivateDetailTypeErasedT>(
std::forward<PrivateDetailTypeErasedT>(value));
return *this;
}
// Cast
template <typename PrivateDetailTypeErasedT>
PrivateDetailTypeErasedT* any_cast()
{
return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle())
.private_detail_te_value)
: nullptr;
}
template <typename PrivateDetailTypeErasedT>
const typename std::remove_cv<PrivateDetailTypeErasedT>::type* any_cast() const
{
return private_detail_te_get_handle().type() == typeid(PrivateDetailTypeErasedT)
? std::addressof(static_cast<const private_detail_te_handle_type<
typename std::remove_cv<PrivateDetailTypeErasedT>::type>&>(
private_detail_te_get_handle())
.private_detail_te_value)
: nullptr;
}
const std::type_info& type_id() const
{
if(private_detail_te_handle_empty())
return typeid(std::nullptr_t);
else
return private_detail_te_get_handle().type();
}
std::string name() const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().name();
}
std::string allocate() const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().allocate();
}
op::concat get_concat(const operation& op) const
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().get_concat(op);
}
private:
struct private_detail_te_handle_base_type
{
virtual ~private_detail_te_handle_base_type() {}
virtual std::shared_ptr<private_detail_te_handle_base_type> clone() const = 0;
virtual const std::type_info& type() const = 0;
virtual std::string name() const = 0;
virtual std::string allocate() const = 0;
virtual op::concat get_concat(const operation& op) const = 0;
};
template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type : private_detail_te_handle_base_type
{
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type(
PrivateDetailTypeErasedT value,
typename std::enable_if<std::is_reference<PrivateDetailTypeErasedU>::value>::type* =
nullptr)
: private_detail_te_value(value)
{
}
template <typename PrivateDetailTypeErasedU = PrivateDetailTypeErasedT>
private_detail_te_handle_type(
PrivateDetailTypeErasedT value,
typename std::enable_if<!std::is_reference<PrivateDetailTypeErasedU>::value,
int>::type* = nullptr) noexcept
: private_detail_te_value(std::move(value))
{
}
std::shared_ptr<private_detail_te_handle_base_type> clone() const override
{
return std::make_shared<private_detail_te_handle_type>(private_detail_te_value);
}
const std::type_info& type() const override { return typeid(private_detail_te_value); }
std::string name() const override { return private_detail_te_value.name(); }
std::string allocate() const override { return private_detail_te_value.allocate(); }
op::concat get_concat(const operation& op) const override
{
return private_detail_te_value.get_concat(op);
}
PrivateDetailTypeErasedT private_detail_te_value;
};
template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type<std::reference_wrapper<PrivateDetailTypeErasedT>>
: private_detail_te_handle_type<PrivateDetailTypeErasedT&>
{
private_detail_te_handle_type(std::reference_wrapper<PrivateDetailTypeErasedT> ref)
: private_detail_te_handle_type<PrivateDetailTypeErasedT&>(ref.get())
{
}
};
bool private_detail_te_handle_empty() const
{
return private_detail_te_handle_mem_var == nullptr;
}
const private_detail_te_handle_base_type& private_detail_te_get_handle() const
{
assert(private_detail_te_handle_mem_var != nullptr);
return *private_detail_te_handle_mem_var;
}
private_detail_te_handle_base_type& private_detail_te_get_handle()
{
assert(private_detail_te_handle_mem_var != nullptr);
if(!private_detail_te_handle_mem_var.unique())
private_detail_te_handle_mem_var = private_detail_te_handle_mem_var->clone();
return *private_detail_te_handle_mem_var;
}
std::shared_ptr<private_detail_te_handle_base_type> private_detail_te_handle_mem_var;
};
template <typename ValueType>
inline const ValueType* any_cast(const concat_optimization* x)
{
return x->any_cast<ValueType>();
}
template <typename ValueType>
inline ValueType* any_cast(concat_optimization* x)
{
return x->any_cast<ValueType>();
}
template <typename ValueType>
inline ValueType& any_cast(concat_optimization& x)
{
auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
if(y == nullptr)
throw std::bad_cast();
return *y;
}
template <typename ValueType>
inline const ValueType& any_cast(const concat_optimization& x)
{
const auto* y = x.any_cast<typename std::remove_reference<ValueType>::type>();
if(y == nullptr)
throw std::bad_cast();
return *y;
}
#endif
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
#endif
#ifndef MIGRAPH_GUARD_CONFIG_HPP
#define MIGRAPH_GUARD_CONFIG_HPP
namespace migraph {
#ifndef MIGRAPH_USE_CLANG_TIDY
#define MIGRAPH_INLINE_NS version_1
#endif
} // namespace migraph
#endif
......@@ -2,8 +2,10 @@
#define MIGRAPH_GUARD_RTGLIB_CONSTANT_PROPAGATE_HPP
#include <string>
#include <migraph/config.hpp>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
struct program;
......@@ -13,6 +15,7 @@ struct constant_propagate
void apply(program& p) const;
};
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
#endif
......@@ -7,8 +7,10 @@
#include <memory>
#include <type_traits>
#include <utility>
#include <migraph/config.hpp>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
#ifdef DOXYGEN
......@@ -203,6 +205,7 @@ inline const ValueType& any_cast(const context& x)
#endif
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
#endif
......@@ -3,8 +3,10 @@
#include <string>
#include <migraph/instruction_ref.hpp>
#include <migraph/config.hpp>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
struct program;
......@@ -14,6 +16,7 @@ struct dead_code_elimination
void apply(program& p) const;
};
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
#endif
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_DFOR_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_DFOR_HPP
#include <migraph/config.hpp>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
// Multidimensional for loop
inline auto dfor()
......@@ -20,6 +23,7 @@ auto dfor(T x, Ts... xs)
};
}
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
#endif
......@@ -3,8 +3,11 @@
#include <string>
#include <migraph/instruction_ref.hpp>
#include <migraph/config.hpp>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
struct program;
struct eliminate_allocation
......@@ -14,6 +17,8 @@ struct eliminate_allocation
std::string name() const { return "eliminate_allocation"; }
void apply(program& p) const;
};
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
#endif
#ifndef MIGRAPH_GUARD_RTGLIB_ELIMINATE_CONCAT_HPP
#define MIGRAPH_GUARD_RTGLIB_ELIMINATE_CONCAT_HPP
#include <string>
#include <migraph/instruction_ref.hpp>
#include <migraph/concat_opt.hpp>
#include <migraph/config.hpp>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
struct program;
struct eliminate_concat
{
concat_optimization concat_opt;
std::string name() const { return "eliminate_concat"; }
void apply(program& p) const;
};
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
#endif
......@@ -3,8 +3,10 @@
#include <string>
#include <migraph/instruction_ref.hpp>
#include <migraph/config.hpp>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
struct program;
......@@ -14,6 +16,7 @@ struct eliminate_contiguous
void apply(program& p) const;
};
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
#endif
......@@ -3,8 +3,10 @@
#include <vector>
#include <string>
#include <migraph/config.hpp>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
// Declare a cached environment variable
#define MIGRAPH_DECLARE_ENV_VAR(x) \
......@@ -31,6 +33,7 @@ bool disabled(T)
return result;
}
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
#endif
......@@ -2,8 +2,10 @@
#define MIGRAPH_GUARD_ERASE_HPP
#include <algorithm>
#include <migraph/config.hpp>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
/**
* @brief Erase all elements from a container
......@@ -31,6 +33,7 @@ auto erase_if(R&& r, P&& pred)
return r.erase(std::remove_if(r.begin(), r.end(), pred), r.end());
}
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
#endif
......@@ -4,8 +4,10 @@
#include <exception>
#include <stdexcept>
#include <string>
#include <migraph/config.hpp>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
/// Represents exceptions that can be thrown by migraphlib
struct exception : std::runtime_error
......@@ -44,6 +46,7 @@ inline std::string make_source_context(const std::string& file, int line)
#define MIGRAPH_THROW(...) \
throw migraph::make_exception(migraph::make_source_context(__FILE__, __LINE__), __VA_ARGS__)
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
#endif
#ifndef MIGRAPH_GUARD_FALLTHROUGH_HPP
#define MIGRAPH_GUARD_FALLTHROUGH_HPP
#include <migraph/config.hpp>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
#ifdef __clang__
#define MIGRAPH_FALLTHROUGH [[clang::fallthrough]]
......@@ -9,6 +12,7 @@ namespace migraph {
#define MIGRAPH_FALLTHROUGH
#endif
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
#endif
......@@ -9,8 +9,10 @@
#endif
#include <migraph/requires.hpp>
#include <migraph/config.hpp>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
template <class... Ts>
using common_type = typename std::common_type<Ts...>::type;
......@@ -40,6 +42,7 @@ struct float_equal_fn
static constexpr float_equal_fn float_equal{};
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
#endif
......@@ -2,8 +2,10 @@
#define MIGRAPH_GUARD_RTGLIB_FUNCTIONAL_HPP
#include <utility>
#include <migraph/config.hpp>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
struct swallow
{
......@@ -129,6 +131,7 @@ auto fold(F f)
return [=](auto&&... xs) { return fold_impl(f, std::forward<decltype(xs)>(xs)...); };
}
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
#endif
......@@ -3,8 +3,10 @@
#include <string>
#include <migraph/instruction_ref.hpp>
#include <migraph/config.hpp>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
struct program;
......@@ -14,6 +16,7 @@ struct fwd_conv_batchnorm_rewrite
void apply(program& p) const;
};
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
#endif
......@@ -4,9 +4,11 @@
#include <migraph/argument.hpp>
#include <migraph/literal.hpp>
#include <migraph/type_traits.hpp>
#include <migraph/config.hpp>
#include <random>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
template <class T, MIGRAPH_REQUIRES(is_floating_point<T>{})>
constexpr T normalize(unsigned long z)
......@@ -91,6 +93,7 @@ literal generate_literal(shape s, unsigned long seed = 0);
literal abs(literal l);
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
#endif
......@@ -9,8 +9,10 @@
#define MIGRAPH_GUARD_RTGLIB_HALF_HPP
#include <half.hpp>
#include <migraph/config.hpp>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
using half = half_float::half;
......@@ -31,6 +33,7 @@ struct deduce<half_float::detail::expr>
template <class T>
using deduce = typename detail::deduce<T>::type;
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
#endif
......@@ -6,10 +6,12 @@
#include <migraph/instruction_ref.hpp>
#include <migraph/operation.hpp>
#include <migraph/erase.hpp>
#include <migraph/config.hpp>
#include <string>
#include <utility>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
shape compute_shape(const operation& op, const std::vector<instruction_ref>& args);
......@@ -69,6 +71,8 @@ struct instruction
static void
replace(instruction_ref ins, operation o, const shape& r, std::vector<instruction_ref> args);
static instruction_ref get_output_alias(instruction_ref ins);
private:
// internal
void replace(operation o, const shape& r, std::vector<instruction_ref> args);
......@@ -86,6 +90,7 @@ struct instruction
std::vector<instruction_ref> arguments;
literal lit;
};
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
namespace std {
......@@ -99,6 +104,7 @@ struct hash<migraph::instruction_ref>
return std::hash<migraph::instruction*>{}(&*x);
}
};
} // namespace std
#endif
......@@ -3,12 +3,15 @@
#include <list>
#include <functional>
#include <migraph/config.hpp>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
struct instruction;
using instruction_ref = std::list<instruction>::iterator;
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment