Commit bb390b65 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'activationOperators' into addLogExpOperators

parents adfa1a93 682ad83a
#ifndef MIGRAPH_GUARD_RTGLIB_REFLECT_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_REFLECT_HPP
#define MIGRAPH_GUARD_RTGLIB_REFLECT_HPP #define MIGRAPHX_GUARD_RTGLIB_REFLECT_HPP
#include <migraphx/functional.hpp> #include <migraphx/functional.hpp>
#include <migraphx/rank.hpp> #include <migraphx/rank.hpp>
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#include <functional> #include <functional>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace detail { namespace detail {
...@@ -47,7 +47,7 @@ void reflect_each(T& x, F f) ...@@ -47,7 +47,7 @@ void reflect_each(T& x, F f)
}); });
} }
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
#endif #endif
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_REQUIRES_HPP #ifndef MIGRAPHX_GUARD_MIGRAPHLIB_REQUIRES_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_REQUIRES_HPP #define MIGRAPHX_GUARD_MIGRAPHLIB_REQUIRES_HPP
#include <type_traits> #include <type_traits>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
template <bool... Bs> template <bool... Bs>
struct and_ : std::is_same<and_<Bs...>, and_<(Bs || true)...>> // NOLINT struct and_ : std::is_same<and_<Bs...>, and_<(Bs || true)...>> // NOLINT
...@@ -24,29 +24,29 @@ struct requires_enum ...@@ -24,29 +24,29 @@ struct requires_enum
}; };
}; };
#define MIGRAPH_REQUIRES_CAT(x, y) x##y #define MIGRAPHX_REQUIRES_CAT(x, y) x##y
#ifdef CPPCHECK #ifdef CPPCHECK
#define MIGRAPH_REQUIRES(...) class = void #define MIGRAPHX_REQUIRES(...) class = void
#else #else
#if 0 #if 0
// TODO: This currently crashed on clang // TODO: This currently crashed on clang
#define MIGRAPH_REQUIRES(...) \ #define MIGRAPHX_REQUIRES(...) \
typename migraphx::requires_enum<__LINE__>::e MIGRAPH_REQUIRES_CAT( \ typename migraphx::requires_enum<__LINE__>::e MIGRAPHX_REQUIRES_CAT( \
PrivateRequires, \ PrivateRequires, \
__LINE__) = migraphx::requires_enum<__LINE__>::a, \ __LINE__) = migraphx::requires_enum<__LINE__>::a, \
class = typename std::enable_if<and_<__VA_ARGS__, \ class = typename std::enable_if<and_<__VA_ARGS__, \
MIGRAPH_REQUIRES_CAT(PrivateRequires, __LINE__) == \ MIGRAPHX_REQUIRES_CAT(PrivateRequires, __LINE__) == \
migraphx::requires_enum<__LINE__>::a>{}>::type migraphx::requires_enum<__LINE__>::a>{}>::type
#else #else
#define MIGRAPH_REQUIRES(...) \ #define MIGRAPHX_REQUIRES(...) \
typename migraphx::requires_enum<__LINE__>::e MIGRAPH_REQUIRES_CAT( \ typename migraphx::requires_enum<__LINE__>::e MIGRAPHX_REQUIRES_CAT( \
PrivateRequires, __LINE__) = migraphx::requires_enum<__LINE__>::a, \ PrivateRequires, __LINE__) = migraphx::requires_enum<__LINE__>::a, \
class = typename std::enable_if<and_<__VA_ARGS__>{}>::type class = typename std::enable_if<and_<__VA_ARGS__>{}>::type
#endif #endif
#endif #endif
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
#endif #endif
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_SHAPE_HPP #ifndef MIGRAPHX_GUARD_MIGRAPHLIB_SHAPE_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_SHAPE_HPP #define MIGRAPHX_GUARD_MIGRAPHLIB_SHAPE_HPP
#include <vector> #include <vector>
#include <cassert> #include <cassert>
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct shape_impl; struct shape_impl;
...@@ -21,7 +21,7 @@ struct shape ...@@ -21,7 +21,7 @@ struct shape
// Add new types here // Add new types here
// clang-format off // clang-format off
#define MIGRAPH_SHAPE_VISIT_TYPES(m) \ #define MIGRAPHX_SHAPE_VISIT_TYPES(m) \
m(half_type, half) \ m(half_type, half) \
m(float_type, float) \ m(float_type, float) \
m(double_type, double) \ m(double_type, double) \
...@@ -35,22 +35,22 @@ struct shape ...@@ -35,22 +35,22 @@ struct shape
m(uint64_type, uint64_t) m(uint64_type, uint64_t)
// clang-format on // clang-format on
#define MIGRAPH_SHAPE_ENUM_TYPES(x, t) x, #define MIGRAPHX_SHAPE_ENUM_TYPES(x, t) x,
enum type_t enum type_t
{ {
MIGRAPH_SHAPE_VISIT_TYPES(MIGRAPH_SHAPE_ENUM_TYPES) MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_ENUM_TYPES)
}; };
#undef MIGRAPH_SHAPE_ENUM_TYPES #undef MIGRAPHX_SHAPE_ENUM_TYPES
template <class T, class = void> template <class T, class = void>
struct get_type; struct get_type;
#define MIGRAPH_SHAPE_GET_TYPE(x, t) \ #define MIGRAPHX_SHAPE_GET_TYPE(x, t) \
template <class T> \ template <class T> \
struct get_type<t, T> : std::integral_constant<type_t, x> \ struct get_type<t, T> : std::integral_constant<type_t, x> \
{ \ { \
}; };
MIGRAPH_SHAPE_VISIT_TYPES(MIGRAPH_SHAPE_GET_TYPE) MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GET_TYPE)
#undef MIGRAPH_SHAPE_GET_TYPE #undef MIGRAPHX_SHAPE_GET_TYPE
template <class T> template <class T>
struct get_type<const T> : get_type<T> struct get_type<const T> : get_type<T>
...@@ -148,12 +148,12 @@ struct shape ...@@ -148,12 +148,12 @@ struct shape
{ {
switch(this->type()) switch(this->type())
{ {
#define MIGRAPH_SHAPE_VISITOR_CASE(x, t) \ #define MIGRAPHX_SHAPE_VISITOR_CASE(x, t) \
case x: v(as<t>()); return; case x: v(as<t>()); return;
MIGRAPH_SHAPE_VISIT_TYPES(MIGRAPH_SHAPE_VISITOR_CASE) MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_VISITOR_CASE)
#undef MIGRAPH_SHAPE_VISITOR_CASE #undef MIGRAPHX_SHAPE_VISITOR_CASE
} }
MIGRAPH_THROW("Unknown type"); MIGRAPHX_THROW("Unknown type");
} }
private: private:
...@@ -163,7 +163,7 @@ struct shape ...@@ -163,7 +163,7 @@ struct shape
std::string type_string() const; std::string type_string() const;
}; };
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
#endif #endif
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_SHAPE_FOR_EACH_HPP #ifndef MIGRAPHX_GUARD_MIGRAPHLIB_SHAPE_FOR_EACH_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_SHAPE_FOR_EACH_HPP #define MIGRAPHX_GUARD_MIGRAPHLIB_SHAPE_FOR_EACH_HPP
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <algorithm> #include <algorithm>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
template <class F> template <class F>
void shape_for_each(const migraphx::shape& s, F f) void shape_for_each(const migraphx::shape& s, F f)
...@@ -28,7 +28,7 @@ void shape_for_each(const migraphx::shape& s, F f) ...@@ -28,7 +28,7 @@ void shape_for_each(const migraphx::shape& s, F f)
} }
} }
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
#endif #endif
#ifndef MIGRAPH_GUARD_RTGLIB_SIMPLIFY_ALGEBRA_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_SIMPLIFY_ALGEBRA_HPP
#define MIGRAPH_GUARD_RTGLIB_SIMPLIFY_ALGEBRA_HPP #define MIGRAPHX_GUARD_RTGLIB_SIMPLIFY_ALGEBRA_HPP
#include <string> #include <string>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct program;
...@@ -18,7 +18,7 @@ struct simplify_algebra ...@@ -18,7 +18,7 @@ struct simplify_algebra
void apply(program& p) const; void apply(program& p) const;
}; };
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
#endif #endif
#ifndef MIGRAPH_GUARD_RTGLIB_SIMPLIFY_RESHAPES_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_SIMPLIFY_RESHAPES_HPP
#define MIGRAPH_GUARD_RTGLIB_SIMPLIFY_RESHAPES_HPP #define MIGRAPHX_GUARD_RTGLIB_SIMPLIFY_RESHAPES_HPP
#include <string> #include <string>
#include <migraphx/instruction_ref.hpp> #include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct program;
...@@ -19,7 +19,7 @@ struct simplify_reshapes ...@@ -19,7 +19,7 @@ struct simplify_reshapes
void apply(program& p) const; void apply(program& p) const;
}; };
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
#endif #endif
#ifndef MIGRAPH_GUARD_STREAMUTILS_HPP #ifndef MIGRAPHX_GUARD_STREAMUTILS_HPP
#define MIGRAPH_GUARD_STREAMUTILS_HPP #define MIGRAPHX_GUARD_STREAMUTILS_HPP
#include <ostream> #include <ostream>
#include <algorithm> #include <algorithm>
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
template <class T> template <class T>
struct stream_range_container struct stream_range_container
...@@ -56,7 +56,7 @@ void stream_write_value(std::ostream& os, const T& x) ...@@ -56,7 +56,7 @@ void stream_write_value(std::ostream& os, const T& x)
detail::stream_write_value_impl(rank<1>{}, os, x); detail::stream_write_value_impl(rank<1>{}, os, x);
} }
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
#endif #endif
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_STRINGUTILS_HPP #ifndef MIGRAPHX_GUARD_MIGRAPHLIB_STRINGUTILS_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_STRINGUTILS_HPP #define MIGRAPHX_GUARD_MIGRAPHLIB_STRINGUTILS_HPP
#include <algorithm> #include <algorithm>
#include <numeric> #include <numeric>
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
inline std::string inline std::string
replace_string(std::string subject, const std::string& search, const std::string& replace) replace_string(std::string subject, const std::string& search, const std::string& replace)
...@@ -87,7 +87,7 @@ inline std::string to_string(const T& x) ...@@ -87,7 +87,7 @@ inline std::string to_string(const T& x)
return ss.str(); return ss.str();
} }
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
#endif #endif
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_TARGET_HPP #ifndef MIGRAPHX_GUARD_MIGRAPHLIB_TARGET_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_TARGET_HPP #define MIGRAPHX_GUARD_MIGRAPHLIB_TARGET_HPP
#include <cassert> #include <cassert>
#include <string> #include <string>
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
#ifdef DOXYGEN #ifdef DOXYGEN
...@@ -244,7 +244,7 @@ inline const ValueType& any_cast(const target& x) ...@@ -244,7 +244,7 @@ inline const ValueType& any_cast(const target& x)
#endif #endif
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
#endif #endif
#ifndef MIGRAPH_GUARD_TENSOR_VIEW_HPP #ifndef MIGRAPHX_GUARD_TENSOR_VIEW_HPP
#define MIGRAPH_GUARD_TENSOR_VIEW_HPP #define MIGRAPHX_GUARD_TENSOR_VIEW_HPP
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
#include <migraphx/float_equal.hpp> #include <migraphx/float_equal.hpp>
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
#include <utility> #include <utility>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
template <class T> template <class T>
struct tensor_view struct tensor_view
...@@ -29,7 +29,7 @@ struct tensor_view ...@@ -29,7 +29,7 @@ struct tensor_view
const T* data() const { return this->m_data; } const T* data() const { return this->m_data; }
template <class... Ts, MIGRAPH_REQUIRES(std::is_integral<Ts>{}...)> template <class... Ts, MIGRAPHX_REQUIRES(std::is_integral<Ts>{}...)>
const T& operator()(Ts... xs) const const T& operator()(Ts... xs) const
{ {
assert(std::vector<std::size_t>{static_cast<std::size_t>(xs)...} < m_shape.lens()); assert(std::vector<std::size_t>{static_cast<std::size_t>(xs)...} < m_shape.lens());
...@@ -37,7 +37,7 @@ struct tensor_view ...@@ -37,7 +37,7 @@ struct tensor_view
return m_data[m_shape.index({static_cast<std::size_t>(xs)...})]; return m_data[m_shape.index({static_cast<std::size_t>(xs)...})];
} }
template <class... Ts, MIGRAPH_REQUIRES(std::is_integral<Ts>{}...)> template <class... Ts, MIGRAPHX_REQUIRES(std::is_integral<Ts>{}...)>
T& operator()(Ts... xs) T& operator()(Ts... xs)
{ {
assert(std::vector<std::size_t>{static_cast<std::size_t>(xs)...} < m_shape.lens()); assert(std::vector<std::size_t>{static_cast<std::size_t>(xs)...} < m_shape.lens());
...@@ -45,13 +45,13 @@ struct tensor_view ...@@ -45,13 +45,13 @@ struct tensor_view
return m_data[m_shape.index({static_cast<std::size_t>(xs)...})]; return m_data[m_shape.index({static_cast<std::size_t>(xs)...})];
} }
template <class Iterator, MIGRAPH_REQUIRES(not std::is_integral<Iterator>{})> template <class Iterator, MIGRAPHX_REQUIRES(not std::is_integral<Iterator>{})>
const T& operator()(Iterator start, Iterator last) const const T& operator()(Iterator start, Iterator last) const
{ {
return m_data[m_shape.index(start, last)]; return m_data[m_shape.index(start, last)];
} }
template <class Iterator, MIGRAPH_REQUIRES(not std::is_integral<Iterator>{})> template <class Iterator, MIGRAPHX_REQUIRES(not std::is_integral<Iterator>{})>
T& operator()(Iterator start, Iterator last) T& operator()(Iterator start, Iterator last)
{ {
return m_data[m_shape.index(start, last)]; return m_data[m_shape.index(start, last)];
...@@ -169,7 +169,7 @@ tensor_view<T> make_view(shape s, T* data) ...@@ -169,7 +169,7 @@ tensor_view<T> make_view(shape s, T* data)
return {s, data}; return {s, data};
} }
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
#endif #endif
#ifndef MIGRAPH_GUARD_RTGLIB_TIME_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_TIME_HPP
#define MIGRAPH_GUARD_RTGLIB_TIME_HPP #define MIGRAPHX_GUARD_RTGLIB_TIME_HPP
#include <chrono> #include <chrono>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
template <class Duration, class F> template <class Duration, class F>
auto time(F f) auto time(F f)
...@@ -16,7 +16,7 @@ auto time(F f) ...@@ -16,7 +16,7 @@ auto time(F f)
return std::chrono::duration_cast<Duration>(finish - start).count(); return std::chrono::duration_cast<Duration>(finish - start).count();
} }
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
#endif #endif
#ifndef MIGRAPH_GUARD_RTGLIB_TRACER_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_TRACER_HPP
#define MIGRAPH_GUARD_RTGLIB_TRACER_HPP #define MIGRAPHX_GUARD_RTGLIB_TRACER_HPP
#include <ostream> #include <ostream>
#include <migraphx/functional.hpp> #include <migraphx/functional.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct tracer struct tracer
{ {
...@@ -30,7 +30,7 @@ struct tracer ...@@ -30,7 +30,7 @@ struct tracer
std::ostream* os = nullptr; std::ostream* os = nullptr;
}; };
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
#endif #endif
#ifndef MIGRAPH_GUARD_RTGLIB_TYPE_NAME_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_TYPE_NAME_HPP
#define MIGRAPH_GUARD_RTGLIB_TYPE_NAME_HPP #define MIGRAPHX_GUARD_RTGLIB_TYPE_NAME_HPP
#include <string> #include <string>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
template <class PrivateMigraphTypeNameProbe> template <class PrivateMigraphTypeNameProbe>
const std::string& get_type_name() const std::string& get_type_name()
...@@ -41,7 +41,7 @@ const std::string& get_type_name(const T&) ...@@ -41,7 +41,7 @@ const std::string& get_type_name(const T&)
return migraphx::get_type_name<T>(); return migraphx::get_type_name<T>();
} }
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
#endif #endif
...@@ -5,32 +5,32 @@ ...@@ -5,32 +5,32 @@
file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
==============================================================================*/ ==============================================================================*/
#ifndef MIGRAPH_GUARD_RTGLIB_TYPE_TRAITS_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_TYPE_TRAITS_HPP
#define MIGRAPH_GUARD_RTGLIB_TYPE_TRAITS_HPP #define MIGRAPHX_GUARD_RTGLIB_TYPE_TRAITS_HPP
#include <type_traits> #include <type_traits>
#include <migraphx/half.hpp> #include <migraphx/half.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
#define MIGRAPH_DETAIL_EXTEND_TRAIT_FOR(trait, T) \ #define MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(trait, T) \
template <class X> \ template <class X> \
struct trait : std::trait<X> \ struct trait : std::trait<X> \
{ \ { \
}; \ }; \
\ \
template <> \ template <> \
struct trait<T> : std::true_type \ struct trait<T> : std::true_type \
{ \ { \
}; };
MIGRAPH_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, half) MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, half)
MIGRAPH_DETAIL_EXTEND_TRAIT_FOR(is_signed, half) MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, half)
MIGRAPH_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, half) MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, half)
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
#endif #endif
#ifndef MIGRAPH_GUARD_VERIFY_HPP #ifndef MIGRAPHX_GUARD_VERIFY_HPP
#define MIGRAPH_GUARD_VERIFY_HPP #define MIGRAPHX_GUARD_VERIFY_HPP
#include <algorithm> #include <algorithm>
#include <cmath> #include <cmath>
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
// Compute the value of a range // Compute the value of a range
template <class R> template <class R>
...@@ -173,6 +173,6 @@ bool verify_range(R1&& r1, R2&& r2, double tolerance = 80, double* out_error = n ...@@ -173,6 +173,6 @@ bool verify_range(R1&& r1, R2&& r2, double tolerance = 80, double* out_error = n
return error <= threshold; return error <= threshold;
} }
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
#endif #endif
#ifndef MIGRAPH_GUARD_RTGLIB_VERIFY_ARGS_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_VERIFY_ARGS_HPP
#define MIGRAPH_GUARD_RTGLIB_VERIFY_ARGS_HPP #define MIGRAPHX_GUARD_RTGLIB_VERIFY_ARGS_HPP
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
inline bool verify_args(const std::string& name, inline bool verify_args(const std::string& name,
const argument& cpu_arg, const argument& cpu_arg,
...@@ -84,7 +84,7 @@ inline bool verify_args(const std::string& name, ...@@ -84,7 +84,7 @@ inline bool verify_args(const std::string& name,
return passed; return passed;
} }
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
#endif #endif
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#include <migraphx/erase.hpp> #include <migraphx/erase.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
instruction::instruction(operation o, shape r, std::vector<instruction_ref> args) instruction::instruction(operation o, shape r, std::vector<instruction_ref> args)
: op(std::move(o)), result(std::move(r)), arguments(std::move(args)) : op(std::move(o)), result(std::move(r)), arguments(std::move(args))
...@@ -183,5 +183,5 @@ shape compute_shape(const operation& op, const std::vector<instruction_ref>& arg ...@@ -183,5 +183,5 @@ shape compute_shape(const operation& op, const std::vector<instruction_ref>& arg
return op.compute_shape(compute_shapes(args)); return op.compute_shape(compute_shapes(args));
} }
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -22,7 +22,7 @@ rocm_clang_tidy_check(read_onnx) ...@@ -22,7 +22,7 @@ rocm_clang_tidy_check(read_onnx)
target_link_libraries(read_onnx migraphx_onnx) target_link_libraries(read_onnx migraphx_onnx)
if(MIGRAPH_ENABLE_GPU) if(MIGRAPHX_ENABLE_GPU)
add_executable(mnist mnist.cpp) add_executable(mnist mnist.cpp)
rocm_clang_tidy_check(mnist) rocm_clang_tidy_check(mnist)
target_link_libraries(mnist migraphx_cpu migraphx_gpu migraphx_onnx) target_link_libraries(mnist migraphx_cpu migraphx_gpu migraphx_onnx)
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct unknown struct unknown
{ {
std::string op; std::string op;
...@@ -43,7 +43,8 @@ struct onnx_parser ...@@ -43,7 +43,8 @@ struct onnx_parser
using op_func = std::function<instruction_ref(attribute_map, std::vector<instruction_ref>)>; using op_func = std::function<instruction_ref(attribute_map, std::vector<instruction_ref>)>;
node_map nodes; node_map nodes;
std::unordered_map<std::string, instruction_ref> instructions; std::unordered_map<std::string, instruction_ref> instructions;
program prog = program(); program prog = program();
bool is_pytorch = false;
std::unordered_map<std::string, op_func> ops; std::unordered_map<std::string, op_func> ops;
...@@ -113,7 +114,7 @@ struct onnx_parser ...@@ -113,7 +114,7 @@ struct onnx_parser
{ {
ops.emplace(name, [this, x](attribute_map attributes, std::vector<instruction_ref> args) { ops.emplace(name, [this, x](attribute_map attributes, std::vector<instruction_ref> args) {
if(args.size() != 2) if(args.size() != 2)
MIGRAPH_THROW("binary operators should have 2 operands"); MIGRAPHX_THROW("binary operators should have 2 operands");
if(contains(attributes, "broadcast")) if(contains(attributes, "broadcast"))
{ {
uint64_t broadcasted = parse_value(attributes.at("broadcast")).at<uint64_t>(); uint64_t broadcasted = parse_value(attributes.at("broadcast")).at<uint64_t>();
...@@ -152,8 +153,8 @@ struct onnx_parser ...@@ -152,8 +153,8 @@ struct onnx_parser
std::swap(s0, s1); std::swap(s0, s1);
// Copy the larger vector to output_lens // Copy the larger vector to output_lens
std::vector<std::size_t> output_lens(s1->size()); std::vector<std::size_t> output_lens = *s1;
auto offset = s1->size() - s0->size(); auto offset = s1->size() - s0->size();
std::transform(s0->begin(), std::transform(s0->begin(),
s0->end(), s0->end(),
s1->begin() + offset, s1->begin() + offset,
...@@ -195,7 +196,22 @@ struct onnx_parser ...@@ -195,7 +196,22 @@ struct onnx_parser
op::convolution op; op::convolution op;
if(contains(attributes, "pads")) if(contains(attributes, "pads"))
{ {
copy(attributes["pads"].ints(), op.padding.begin()); if(contains(attributes, "auto_pad"))
{
MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously");
}
std::vector<std::size_t> padding(4);
copy(attributes["pads"].ints(), padding.begin());
if(padding.size() != 4)
{
MIGRAPHX_THROW("padding should have 4 values");
}
if(padding[0] != padding[2] || padding[1] != padding[3])
{
MIGRAPHX_THROW("migraphx does not support asymetric padding");
}
op.padding[0] = padding[0];
op.padding[1] = padding[1];
} }
if(contains(attributes, "strides")) if(contains(attributes, "strides"))
{ {
...@@ -205,6 +221,19 @@ struct onnx_parser ...@@ -205,6 +221,19 @@ struct onnx_parser
{ {
copy(attributes["dilations"].ints(), op.dilation.begin()); copy(attributes["dilations"].ints(), op.dilation.begin());
} }
if(contains(attributes, "auto_pad"))
{
auto s = attributes["auto_pad"].s();
if(contains(attributes, "pads") and to_upper(s) != "NOTSET")
{
MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously");
}
if(s.find("SAME") != std::string::npos)
{
op.padding_mode = op::convolution::same;
}
}
if(args.size() == 3) if(args.size() == 3)
{ {
uint64_t axis = 1; uint64_t axis = 1;
...@@ -227,7 +256,18 @@ struct onnx_parser ...@@ -227,7 +256,18 @@ struct onnx_parser
} }
if(contains(attributes, "pads")) if(contains(attributes, "pads"))
{ {
copy(attributes["pads"].ints(), op.padding.begin()); std::vector<std::size_t> padding(4);
copy(attributes["pads"].ints(), padding.begin());
if(padding.size() != 4)
{
MIGRAPHX_THROW("padding should have 4 values");
}
if(padding[0] != padding[2] || padding[1] != padding[3])
{
MIGRAPHX_THROW("migraphx does not support asymetric padding");
}
op.padding[0] = padding[0];
op.padding[1] = padding[1];
} }
if(contains(attributes, "strides")) if(contains(attributes, "strides"))
{ {
...@@ -237,6 +277,15 @@ struct onnx_parser ...@@ -237,6 +277,15 @@ struct onnx_parser
{ {
copy(attributes["kernel_shape"].ints(), op.lengths.begin()); copy(attributes["kernel_shape"].ints(), op.lengths.begin());
} }
if(contains(attributes, "auto_pad"))
{
auto s = attributes["auto_pad"].s();
if(to_upper(s) != "NOTSET")
{
MIGRAPHX_THROW("auto_pad is not supported for pooling");
}
}
return prog.add_instruction(op, std::move(args)); return prog.add_instruction(op, std::move(args));
} }
...@@ -502,7 +551,7 @@ struct onnx_parser ...@@ -502,7 +551,7 @@ struct onnx_parser
void parse_node(const std::string& name) void parse_node(const std::string& name)
{ {
if(name.empty()) if(name.empty())
MIGRAPH_THROW("Onnx node must have a name"); MIGRAPHX_THROW("Onnx node must have a name");
if(instructions.count(name) == 0) if(instructions.count(name) == 0)
{ {
auto&& node = nodes.at(name); auto&& node = nodes.at(name);
...@@ -592,7 +641,7 @@ struct onnx_parser ...@@ -592,7 +641,7 @@ struct onnx_parser
case onnx::AttributeProto::TENSORS: return {}; case onnx::AttributeProto::TENSORS: return {};
case onnx::AttributeProto::GRAPHS: return {}; case onnx::AttributeProto::GRAPHS: return {};
} }
MIGRAPH_THROW("Invalid attribute type"); MIGRAPHX_THROW("Invalid attribute type");
} }
static literal parse_tensor(const onnx::TensorProto& t) static literal parse_tensor(const onnx::TensorProto& t)
...@@ -620,7 +669,7 @@ struct onnx_parser ...@@ -620,7 +669,7 @@ struct onnx_parser
case onnx::TensorProto::COMPLEX64: throw std::runtime_error(""); case onnx::TensorProto::COMPLEX64: throw std::runtime_error("");
case onnx::TensorProto::COMPLEX128: throw std::runtime_error(""); case onnx::TensorProto::COMPLEX128: throw std::runtime_error("");
} }
MIGRAPH_THROW("Invalid tensor type"); MIGRAPHX_THROW("Invalid tensor type");
} }
switch(t.data_type()) switch(t.data_type())
{ {
...@@ -651,7 +700,7 @@ struct onnx_parser ...@@ -651,7 +700,7 @@ struct onnx_parser
case onnx::TensorProto::COMPLEX64: throw std::runtime_error(""); case onnx::TensorProto::COMPLEX64: throw std::runtime_error("");
case onnx::TensorProto::COMPLEX128: throw std::runtime_error(""); case onnx::TensorProto::COMPLEX128: throw std::runtime_error("");
} }
MIGRAPH_THROW("Invalid tensor type"); MIGRAPHX_THROW("Invalid tensor type");
} }
static shape parse_type(const onnx::TypeProto& t) static shape parse_type(const onnx::TypeProto& t)
...@@ -720,5 +769,5 @@ program parse_onnx(const std::string& name) ...@@ -720,5 +769,5 @@ program parse_onnx(const std::string& name)
return std::move(parser.prog); return std::move(parser.prog);
} }
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
#ifndef MIGRAPH_GUARD_RTGLIB_COMMON_HEADER_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_COMMON_HEADER_HPP
#define MIGRAPH_GUARD_RTGLIB_COMMON_HEADER_HPP #define MIGRAPHX_GUARD_RTGLIB_COMMON_HEADER_HPP
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
...@@ -14,17 +14,17 @@ ...@@ -14,17 +14,17 @@
#include <queue> #include <queue>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
//#define MIGRAPH_DEBUG_OPT //#define MIGRAPHX_DEBUG_OPT
#ifdef MIGRAPH_DEBUG_OPT #ifdef MIGRAPHX_DEBUG_OPT
#define MIGRAPH_DEBUG(s) s #define MIGRAPHX_DEBUG(s) s
#else #else
#define MIGRAPH_DEBUG(s) #define MIGRAPHX_DEBUG(s)
#endif // MIGRAPH_DEBUG_OPT #endif // MIGRAPHX_DEBUG_OPT
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
#endif // MIGRAPH_GUARD_RTGLIB_COMMON_HEADER_HPP #endif // MIGRAPHX_GUARD_RTGLIB_COMMON_HEADER_HPP
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment