Commit bc5d7f75 authored by Paul's avatar Paul
Browse files

Merge from develop

parents 47c0854d a5b0afa0
#ifndef MIGRAPH_GUARD_RTGLIB_COMMON_SUBEXPRESSION_ELIMINATION_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_COMMON_SUBEXPRESSION_ELIMINATION_HPP
#define MIGRAPH_GUARD_RTGLIB_COMMON_SUBEXPRESSION_ELIMINATION_HPP #define MIGRAPHX_GUARD_RTGLIB_COMMON_SUBEXPRESSION_ELIMINATION_HPP
#include <string> #include <string>
#include <migraph/instruction_ref.hpp> #include <migraphx/instruction_ref.hpp>
#include <migraph/config.hpp> #include <migraphx/config.hpp>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct program;
...@@ -19,7 +19,7 @@ struct common_subexpression_elimination ...@@ -19,7 +19,7 @@ struct common_subexpression_elimination
void apply(program& p) const; void apply(program& p) const;
}; };
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#endif #endif
#ifndef MIGRAPH_GUARD_CONCAT_OPT_HPP #ifndef MIGRAPHX_GUARD_CONCAT_OPT_HPP
#define MIGRAPH_GUARD_CONCAT_OPT_HPP #define MIGRAPHX_GUARD_CONCAT_OPT_HPP
#include <cassert> #include <cassert>
#include <string> #include <string>
...@@ -8,12 +8,12 @@ ...@@ -8,12 +8,12 @@
#include <type_traits> #include <type_traits>
#include <utility> #include <utility>
#include <migraph/operation.hpp> #include <migraphx/operation.hpp>
#include <migraph/operators.hpp> #include <migraphx/operators.hpp>
#include <migraph/config.hpp> #include <migraphx/config.hpp>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct program;
...@@ -119,6 +119,13 @@ struct concat_optimization ...@@ -119,6 +119,13 @@ struct concat_optimization
return (*this).private_detail_te_get_handle().get_concat(op); return (*this).private_detail_te_get_handle().get_concat(op);
} }
friend bool is_shared(const concat_optimization& private_detail_x,
const concat_optimization& private_detail_y)
{
return private_detail_x.private_detail_te_handle_mem_var ==
private_detail_y.private_detail_te_handle_mem_var;
}
private: private:
struct private_detail_te_handle_base_type struct private_detail_te_handle_base_type
{ {
...@@ -235,7 +242,8 @@ inline const ValueType& any_cast(const concat_optimization& x) ...@@ -235,7 +242,8 @@ inline const ValueType& any_cast(const concat_optimization& x)
} }
#endif #endif
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif #endif
#ifndef MIGRAPHX_GUARD_CONFIG_HPP
#define MIGRAPHX_GUARD_CONFIG_HPP
namespace migraphx {
#if !defined(MIGRAPHX_USE_CLANG_TIDY) && !defined(DOXYGEN)
#define MIGRAPHX_INLINE_NS version_1
#endif
} // namespace migraphx
#endif
#ifndef MIGRAPH_GUARD_RTGLIB_CONSTANT_PROPAGATE_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_CONSTANT_PROPAGATE_HPP
#define MIGRAPH_GUARD_RTGLIB_CONSTANT_PROPAGATE_HPP #define MIGRAPHX_GUARD_RTGLIB_CONSTANT_PROPAGATE_HPP
#include <string> #include <string>
#include <migraph/config.hpp> #include <migraphx/config.hpp>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct program;
...@@ -18,7 +18,7 @@ struct constant_propagate ...@@ -18,7 +18,7 @@ struct constant_propagate
void apply(program& p) const; void apply(program& p) const;
}; };
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#endif #endif
#ifndef MIGRAPH_GUARD_CONTEXT_HPP #ifndef MIGRAPHX_GUARD_CONTEXT_HPP
#define MIGRAPH_GUARD_CONTEXT_HPP #define MIGRAPHX_GUARD_CONTEXT_HPP
#include <cassert> #include <cassert>
#include <string> #include <string>
...@@ -7,10 +7,10 @@ ...@@ -7,10 +7,10 @@
#include <memory> #include <memory>
#include <type_traits> #include <type_traits>
#include <utility> #include <utility>
#include <migraph/config.hpp> #include <migraphx/config.hpp>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
#ifdef DOXYGEN #ifdef DOXYGEN
...@@ -95,7 +95,13 @@ struct context ...@@ -95,7 +95,13 @@ struct context
void finish() const void finish() const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().finish(); (*this).private_detail_te_get_handle().finish();
}
friend bool is_shared(const context& private_detail_x, const context& private_detail_y)
{
return private_detail_x.private_detail_te_handle_mem_var ==
private_detail_y.private_detail_te_handle_mem_var;
} }
private: private:
...@@ -136,7 +142,7 @@ struct context ...@@ -136,7 +142,7 @@ struct context
const std::type_info& type() const override { return typeid(private_detail_te_value); } const std::type_info& type() const override { return typeid(private_detail_te_value); }
void finish() const override { return private_detail_te_value.finish(); } void finish() const override { private_detail_te_value.finish(); }
PrivateDetailTypeErasedT private_detail_te_value; PrivateDetailTypeErasedT private_detail_te_value;
}; };
...@@ -205,7 +211,7 @@ inline const ValueType& any_cast(const context& x) ...@@ -205,7 +211,7 @@ inline const ValueType& any_cast(const context& x)
#endif #endif
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#endif #endif
#ifndef MIGRAPH_GUARD_RTGLIB_DEAD_CODE_ELIMINATION_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_DEAD_CODE_ELIMINATION_HPP
#define MIGRAPH_GUARD_RTGLIB_DEAD_CODE_ELIMINATION_HPP #define MIGRAPHX_GUARD_RTGLIB_DEAD_CODE_ELIMINATION_HPP
#include <string> #include <string>
#include <migraph/instruction_ref.hpp> #include <migraphx/instruction_ref.hpp>
#include <migraph/config.hpp> #include <migraphx/config.hpp>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct program;
...@@ -19,7 +19,7 @@ struct dead_code_elimination ...@@ -19,7 +19,7 @@ struct dead_code_elimination
void apply(program& p) const; void apply(program& p) const;
}; };
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#endif #endif
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_DFOR_HPP #ifndef MIGRAPHX_GUARD_MIGRAPHLIB_DFOR_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_DFOR_HPP #define MIGRAPHX_GUARD_MIGRAPHLIB_DFOR_HPP
#include <migraph/config.hpp> #include <migraphx/config.hpp>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
// Multidimensional for loop // Multidimensional for loop
inline auto dfor() inline auto dfor()
...@@ -23,7 +23,7 @@ auto dfor(T x, Ts... xs) ...@@ -23,7 +23,7 @@ auto dfor(T x, Ts... xs)
}; };
} }
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#endif #endif
#ifndef MIGRAPH_GUARD_RTGLIB_ELIMINATE_ALLOCATION_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_ELIMINATE_ALLOCATION_HPP
#define MIGRAPH_GUARD_RTGLIB_ELIMINATE_ALLOCATION_HPP #define MIGRAPHX_GUARD_RTGLIB_ELIMINATE_ALLOCATION_HPP
#include <string> #include <string>
#include <migraph/instruction_ref.hpp> #include <migraphx/instruction_ref.hpp>
#include <migraph/config.hpp> #include <migraphx/config.hpp>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct program;
...@@ -22,7 +22,7 @@ struct eliminate_allocation ...@@ -22,7 +22,7 @@ struct eliminate_allocation
void apply(program& p) const; void apply(program& p) const;
}; };
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#endif #endif
#ifndef MIGRAPH_GUARD_RTGLIB_ELIMINATE_CONCAT_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_ELIMINATE_CONCAT_HPP
#define MIGRAPH_GUARD_RTGLIB_ELIMINATE_CONCAT_HPP #define MIGRAPHX_GUARD_RTGLIB_ELIMINATE_CONCAT_HPP
#include <string> #include <string>
#include <migraph/instruction_ref.hpp> #include <migraphx/instruction_ref.hpp>
#include <migraph/concat_opt.hpp> #include <migraphx/concat_opt.hpp>
#include <migraph/config.hpp> #include <migraphx/config.hpp>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct program;
...@@ -21,7 +21,7 @@ struct eliminate_concat ...@@ -21,7 +21,7 @@ struct eliminate_concat
void apply(program& p) const; void apply(program& p) const;
}; };
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#endif #endif
#ifndef MIGRAPH_GUARD_RTGLIB_ELIMINATE_CONTIGUOUS_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_ELIMINATE_CONTIGUOUS_HPP
#define MIGRAPH_GUARD_RTGLIB_ELIMINATE_CONTIGUOUS_HPP #define MIGRAPHX_GUARD_RTGLIB_ELIMINATE_CONTIGUOUS_HPP
#include <string> #include <string>
#include <migraph/instruction_ref.hpp> #include <migraphx/instruction_ref.hpp>
#include <migraph/config.hpp> #include <migraphx/config.hpp>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct program;
...@@ -19,7 +19,7 @@ struct eliminate_contiguous ...@@ -19,7 +19,7 @@ struct eliminate_contiguous
void apply(program& p) const; void apply(program& p) const;
}; };
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#endif #endif
#ifndef MIGRAPH_GUARD_RTGLIB_ENV_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_ENV_HPP
#define MIGRAPH_GUARD_RTGLIB_ENV_HPP #define MIGRAPHX_GUARD_RTGLIB_ENV_HPP
#include <vector> #include <vector>
#include <string> #include <string>
#include <migraph/config.hpp> #include <migraphx/config.hpp>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
// Declare a cached environment variable // Declare a cached environment variable
#define MIGRAPH_DECLARE_ENV_VAR(x) \ #define MIGRAPHX_DECLARE_ENV_VAR(x) \
struct x \ struct x \
{ \ { \
static const char* value() { return #x; } \ static const char* value() { return #x; } \
...@@ -33,7 +33,7 @@ bool disabled(T) ...@@ -33,7 +33,7 @@ bool disabled(T)
return result; return result;
} }
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#endif #endif
#ifndef MIGRAPH_GUARD_ERASE_HPP #ifndef MIGRAPHX_GUARD_ERASE_HPP
#define MIGRAPH_GUARD_ERASE_HPP #define MIGRAPHX_GUARD_ERASE_HPP
#include <algorithm> #include <algorithm>
#include <migraph/config.hpp> #include <migraphx/config.hpp>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
/** /**
* @brief Erase all elements from a container * @brief Erase all elements from a container
...@@ -33,7 +33,7 @@ auto erase_if(R&& r, P&& pred) ...@@ -33,7 +33,7 @@ auto erase_if(R&& r, P&& pred)
return r.erase(std::remove_if(r.begin(), r.end(), pred), r.end()); return r.erase(std::remove_if(r.begin(), r.end(), pred), r.end());
} }
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#endif #endif
#ifndef MIGRAPH_GUARD_ERRORS_HPP #ifndef MIGRAPHX_GUARD_ERRORS_HPP
#define MIGRAPH_GUARD_ERRORS_HPP #define MIGRAPHX_GUARD_ERRORS_HPP
#include <exception> #include <exception>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
#include <migraph/config.hpp> #include <migraphx/config.hpp>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
/// Represents exceptions that can be thrown by migraphlib /// Represents exceptions that can be thrown by migraphxlib
struct exception : std::runtime_error struct exception : std::runtime_error
{ {
exception(const std::string& msg = "") : std::runtime_error(msg) {} exception(const std::string& msg = "") : std::runtime_error(msg) {}
...@@ -43,10 +43,10 @@ inline std::string make_source_context(const std::string& file, int line) ...@@ -43,10 +43,10 @@ inline std::string make_source_context(const std::string& file, int line)
/** /**
* @brief Throw an exception with context information * @brief Throw an exception with context information
*/ */
#define MIGRAPH_THROW(...) \ #define MIGRAPHX_THROW(...) \
throw migraph::make_exception(migraph::make_source_context(__FILE__, __LINE__), __VA_ARGS__) throw migraphx::make_exception(migraphx::make_source_context(__FILE__, __LINE__), __VA_ARGS__)
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#endif #endif
#ifndef MIGRAPHX_GUARD_FALLTHROUGH_HPP
#define MIGRAPHX_GUARD_FALLTHROUGH_HPP
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
#ifdef __clang__
#define MIGRAPHX_FALLTHROUGH [[clang::fallthrough]]
#else
#define MIGRAPHX_FALLTHROUGH
#endif
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_FLOAT_EQUAL_HPP #ifndef MIGRAPHX_GUARD_MIGRAPHLIB_FLOAT_EQUAL_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_FLOAT_EQUAL_HPP #define MIGRAPHX_GUARD_MIGRAPHLIB_FLOAT_EQUAL_HPP
#include <algorithm> #include <algorithm>
#include <cmath> #include <cmath>
...@@ -8,18 +8,18 @@ ...@@ -8,18 +8,18 @@
#include <iso646.h> #include <iso646.h>
#endif #endif
#include <migraph/requires.hpp> #include <migraphx/requires.hpp>
#include <migraph/config.hpp> #include <migraphx/config.hpp>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
template <class... Ts> template <class... Ts>
using common_type = typename std::common_type<Ts...>::type; using common_type = typename std::common_type<Ts...>::type;
struct float_equal_fn struct float_equal_fn
{ {
template <class T, MIGRAPH_REQUIRES(std::is_floating_point<T>{})> template <class T, MIGRAPHX_REQUIRES(std::is_floating_point<T>{})>
static bool apply(T x, T y) static bool apply(T x, T y)
{ {
return std::isfinite(x) and std::isfinite(y) and return std::isfinite(x) and std::isfinite(y) and
...@@ -27,7 +27,7 @@ struct float_equal_fn ...@@ -27,7 +27,7 @@ struct float_equal_fn
std::nextafter(x, std::numeric_limits<T>::max()) >= y; std::nextafter(x, std::numeric_limits<T>::max()) >= y;
} }
template <class T, MIGRAPH_REQUIRES(not std::is_floating_point<T>{})> template <class T, MIGRAPHX_REQUIRES(not std::is_floating_point<T>{})>
static bool apply(T x, T y) static bool apply(T x, T y)
{ {
return x == y; return x == y;
...@@ -42,7 +42,7 @@ struct float_equal_fn ...@@ -42,7 +42,7 @@ struct float_equal_fn
static constexpr float_equal_fn float_equal{}; static constexpr float_equal_fn float_equal{};
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#endif #endif
#ifndef MIGRAPH_GUARD_RTGLIB_FUNCTIONAL_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_FUNCTIONAL_HPP
#define MIGRAPH_GUARD_RTGLIB_FUNCTIONAL_HPP #define MIGRAPHX_GUARD_RTGLIB_FUNCTIONAL_HPP
#include <utility> #include <utility>
#include <migraph/config.hpp> #include <migraphx/config.hpp>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct swallow struct swallow
{ {
...@@ -94,6 +94,12 @@ constexpr void each_args(F) ...@@ -94,6 +94,12 @@ constexpr void each_args(F)
{ {
} }
template <class F, class T>
auto unpack(F f, T& x)
{
return sequence_c<std::tuple_size<T>{}>([&](auto... is) { f(std::get<is>(x)...); });
}
/// Implements a fix-point combinator /// Implements a fix-point combinator
template <class R, class F> template <class R, class F>
detail::fix_f<R, F> fix(F f) detail::fix_f<R, F> fix(F f)
...@@ -131,7 +137,7 @@ auto fold(F f) ...@@ -131,7 +137,7 @@ auto fold(F f)
return [=](auto&&... xs) { return fold_impl(f, std::forward<decltype(xs)>(xs)...); }; return [=](auto&&... xs) { return fold_impl(f, std::forward<decltype(xs)>(xs)...); };
} }
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#endif #endif
#ifndef MIGRAPH_GUARD_RTGLIB_FWD_CONV_BATCHNORM_REWRITE_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_FWD_CONV_BATCHNORM_REWRITE_HPP
#define MIGRAPH_GUARD_RTGLIB_FWD_CONV_BATCHNORM_REWRITE_HPP #define MIGRAPHX_GUARD_RTGLIB_FWD_CONV_BATCHNORM_REWRITE_HPP
#include <string> #include <string>
#include <migraph/instruction_ref.hpp> #include <migraphx/instruction_ref.hpp>
#include <migraph/config.hpp> #include <migraphx/config.hpp>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct program;
...@@ -19,7 +19,7 @@ struct fwd_conv_batchnorm_rewrite ...@@ -19,7 +19,7 @@ struct fwd_conv_batchnorm_rewrite
void apply(program& p) const; void apply(program& p) const;
}; };
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#endif #endif
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_GENERATE_HPP #ifndef MIGRAPHX_GUARD_MIGRAPHLIB_GENERATE_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_GENERATE_HPP #define MIGRAPHX_GUARD_MIGRAPHLIB_GENERATE_HPP
#include <migraph/argument.hpp> #include <migraphx/argument.hpp>
#include <migraph/literal.hpp> #include <migraphx/literal.hpp>
#include <migraph/type_traits.hpp> #include <migraphx/type_traits.hpp>
#include <migraph/config.hpp> #include <migraphx/config.hpp>
#include <random> #include <random>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
template <class T, MIGRAPH_REQUIRES(is_floating_point<T>{})> template <class T, MIGRAPHX_REQUIRES(is_floating_point<T>{})>
constexpr T normalize(unsigned long z) constexpr T normalize(unsigned long z)
{ {
if(z == 0) if(z == 0)
...@@ -22,7 +22,7 @@ constexpr T normalize(unsigned long z) ...@@ -22,7 +22,7 @@ constexpr T normalize(unsigned long z)
return T(result); return T(result);
} }
template <class T, MIGRAPH_REQUIRES(is_signed<T>{} and not is_floating_point<T>{})> template <class T, MIGRAPHX_REQUIRES(is_signed<T>{} and not is_floating_point<T>{})>
constexpr T normalize(unsigned long z) constexpr T normalize(unsigned long z)
{ {
const auto max = std::numeric_limits<T>::max(); const auto max = std::numeric_limits<T>::max();
...@@ -30,7 +30,7 @@ constexpr T normalize(unsigned long z) ...@@ -30,7 +30,7 @@ constexpr T normalize(unsigned long z)
return half_max - (z % max); return half_max - (z % max);
} }
template <class T, MIGRAPH_REQUIRES(not is_signed<T>{} and std::is_integral<T>{})> template <class T, MIGRAPHX_REQUIRES(not is_signed<T>{} and std::is_integral<T>{})>
constexpr T normalize(unsigned long z) constexpr T normalize(unsigned long z)
{ {
const auto max = std::numeric_limits<T>::max(); const auto max = std::numeric_limits<T>::max();
...@@ -78,7 +78,7 @@ struct xorshift_generator ...@@ -78,7 +78,7 @@ struct xorshift_generator
}; };
template <class T> template <class T>
std::vector<T> generate_tensor_data(const migraph::shape& s, unsigned long seed = 0) std::vector<T> generate_tensor_data(const migraphx::shape& s, unsigned long seed = 0)
{ {
std::vector<T> result(s.elements()); std::vector<T> result(s.elements());
std::generate(result.begin(), result.end(), xorshf96_generator<T>{seed}); std::generate(result.begin(), result.end(), xorshf96_generator<T>{seed});
...@@ -93,7 +93,7 @@ literal generate_literal(shape s, unsigned long seed = 0); ...@@ -93,7 +93,7 @@ literal generate_literal(shape s, unsigned long seed = 0);
literal abs(literal l); literal abs(literal l);
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#endif #endif
...@@ -5,14 +5,14 @@ ...@@ -5,14 +5,14 @@
file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
==============================================================================*/ ==============================================================================*/
#ifndef MIGRAPH_GUARD_RTGLIB_HALF_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_HALF_HPP
#define MIGRAPH_GUARD_RTGLIB_HALF_HPP #define MIGRAPHX_GUARD_RTGLIB_HALF_HPP
#include <half.hpp> #include <half.hpp>
#include <migraph/config.hpp> #include <migraphx/config.hpp>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
using half = half_float::half; using half = half_float::half;
...@@ -33,7 +33,7 @@ struct deduce<half_float::detail::expr> ...@@ -33,7 +33,7 @@ struct deduce<half_float::detail::expr>
template <class T> template <class T>
using deduce = typename detail::deduce<T>::type; using deduce = typename detail::deduce<T>::type;
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#endif #endif
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_INSTRUCTION_HPP #ifndef MIGRAPHX_GUARD_MIGRAPHLIB_INSTRUCTION_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_INSTRUCTION_HPP #define MIGRAPHX_GUARD_MIGRAPHLIB_INSTRUCTION_HPP
#include <migraph/literal.hpp> #include <migraphx/literal.hpp>
#include <migraph/shape.hpp> #include <migraphx/shape.hpp>
#include <migraph/instruction_ref.hpp> #include <migraphx/instruction_ref.hpp>
#include <migraph/operation.hpp> #include <migraphx/operation.hpp>
#include <migraph/erase.hpp> #include <migraphx/erase.hpp>
#include <migraph/config.hpp> #include <migraphx/config.hpp>
#include <string> #include <string>
#include <utility> #include <utility>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
shape compute_shape(const operation& op, const std::vector<instruction_ref>& args); shape compute_shape(const operation& op, const std::vector<instruction_ref>& args);
std::vector<shape> to_shapes(const std::vector<instruction_ref>& args);
struct instruction struct instruction
{ {
...@@ -61,7 +62,7 @@ struct instruction ...@@ -61,7 +62,7 @@ struct instruction
template <class T> template <class T>
void remove_output(const T& ins) void remove_output(const T& ins)
{ {
migraph::erase(output, ins); migraphx::erase(output, ins);
} }
static void backreference(instruction_ref ref); static void backreference(instruction_ref ref);
...@@ -71,7 +72,11 @@ struct instruction ...@@ -71,7 +72,11 @@ struct instruction
static void static void
replace(instruction_ref ins, operation o, const shape& r, std::vector<instruction_ref> args); replace(instruction_ref ins, operation o, const shape& r, std::vector<instruction_ref> args);
static instruction_ref get_output_alias(instruction_ref ins); argument eval() const;
void finalize(context& ctx);
static instruction_ref get_output_alias(instruction_ref ins, bool shallow = false);
private: private:
// internal // internal
...@@ -90,18 +95,18 @@ struct instruction ...@@ -90,18 +95,18 @@ struct instruction
std::vector<instruction_ref> arguments; std::vector<instruction_ref> arguments;
literal lit; literal lit;
}; };
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
namespace std { namespace std {
template <> template <>
struct hash<migraph::instruction_ref> struct hash<migraphx::instruction_ref>
{ {
using argument_type = migraph::instruction_ref; using argument_type = migraphx::instruction_ref;
using result_type = std::size_t; using result_type = std::size_t;
result_type operator()(const argument_type& x) const noexcept result_type operator()(const argument_type& x) const noexcept
{ {
return std::hash<migraph::instruction*>{}(&*x); return std::hash<migraphx::instruction*>{}(&*x);
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment