Commit bc5d7f75 authored by Paul's avatar Paul
Browse files

Merge from develop

parents 47c0854d a5b0afa0
#ifndef MIGRAPHX_GUARD_RTGLIB_REWRITE_RNN_HPP
#define MIGRAPHX_GUARD_RTGLIB_REWRITE_RNN_HPP
#include <string>
#include <vector>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
/**
* Rewrite rnn to gemm and add.
*/
struct rewrite_rnn
{
std::string name() const { return "rewrite_rnn"; }
void apply(program& prog) const;
private:
// for vanilla rnn operators
void apply_vanilla_rnn(program& prog, instruction_ref ins) const;
std::vector<instruction_ref> vanilla_rnn_cell(bool is_forward,
program& prog,
instruction_ref ins,
instruction_ref input,
instruction_ref w,
instruction_ref r,
instruction_ref bias,
instruction_ref ih,
operation& actv_func) const;
std::vector<operation> vanilla_rnn_actv_funcs(instruction_ref ins) const;
// for gru operators
void apply_gru(program& prog, instruction_ref ins) const;
std::vector<instruction_ref> gru_cell(bool is_forward,
program& prog,
instruction_ref ins,
std::vector<instruction_ref> inputs,
int linear_before_reset,
const operation& actv_func1,
const operation& actv_func2) const;
std::vector<operation> gru_actv_funcs(instruction_ref ins) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_SHAPE_HPP #ifndef MIGRAPHX_GUARD_MIGRAPHLIB_SHAPE_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_SHAPE_HPP #define MIGRAPHX_GUARD_MIGRAPHLIB_SHAPE_HPP
#include <vector> #include <vector>
#include <cassert> #include <cassert>
...@@ -7,12 +7,12 @@ ...@@ -7,12 +7,12 @@
#include <numeric> #include <numeric>
#include <memory> #include <memory>
#include <migraph/errors.hpp> #include <migraphx/errors.hpp>
#include <migraph/half.hpp> #include <migraphx/half.hpp>
#include <migraph/config.hpp> #include <migraphx/config.hpp>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct shape_impl; struct shape_impl;
...@@ -21,7 +21,7 @@ struct shape ...@@ -21,7 +21,7 @@ struct shape
// Add new types here // Add new types here
// clang-format off // clang-format off
#define MIGRAPH_SHAPE_VISIT_TYPES(m) \ #define MIGRAPHX_SHAPE_VISIT_TYPES(m) \
m(half_type, half) \ m(half_type, half) \
m(float_type, float) \ m(float_type, float) \
m(double_type, double) \ m(double_type, double) \
...@@ -35,22 +35,22 @@ struct shape ...@@ -35,22 +35,22 @@ struct shape
m(uint64_type, uint64_t) m(uint64_type, uint64_t)
// clang-format on // clang-format on
#define MIGRAPH_SHAPE_ENUM_TYPES(x, t) x, #define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x,
enum type_t enum type_t
{ {
MIGRAPH_SHAPE_VISIT_TYPES(MIGRAPH_SHAPE_ENUM_TYPES) MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES)
}; };
#undef MIGRAPH_SHAPE_ENUM_TYPES #undef MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES
template <class T, class = void> template <class T, class = void>
struct get_type; struct get_type;
#define MIGRAPH_SHAPE_GET_TYPE(x, t) \ #define MIGRAPHX_SHAPE_GENERATE_GET_TYPE(x, t) \
template <class T> \ template <class T> \
struct get_type<t, T> : std::integral_constant<type_t, x> \ struct get_type<t, T> : std::integral_constant<type_t, x> \
{ \ { \
}; };
MIGRAPH_SHAPE_VISIT_TYPES(MIGRAPH_SHAPE_GET_TYPE) MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_GET_TYPE)
#undef MIGRAPH_SHAPE_GET_TYPE #undef MIGRAPHX_SHAPE_GENERATE_GET_TYPE
template <class T> template <class T>
struct get_type<const T> : get_type<T> struct get_type<const T> : get_type<T>
...@@ -62,6 +62,19 @@ struct shape ...@@ -62,6 +62,19 @@ struct shape
shape(type_t t, std::vector<std::size_t> l); shape(type_t t, std::vector<std::size_t> l);
shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s); shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s);
template <class Range>
shape(type_t t, const Range& l) : shape(t, std::vector<std::size_t>(l.begin(), l.end()))
{
}
template <class Range1, class Range2>
shape(type_t t, const Range1& l, const Range2& s)
: shape(t,
std::vector<std::size_t>(l.begin(), l.end()),
std::vector<std::size_t>(s.begin(), s.end()))
{
}
type_t type() const; type_t type() const;
const std::vector<std::size_t>& lens() const; const std::vector<std::size_t>& lens() const;
const std::vector<std::size_t>& strides() const; const std::vector<std::size_t>& strides() const;
...@@ -141,6 +154,8 @@ struct shape ...@@ -141,6 +154,8 @@ struct shape
{ {
return reinterpret_cast<const T*>(buffer) + n; return reinterpret_cast<const T*>(buffer) + n;
} }
type_t type_enum() const { return get_type<T>{}; }
}; };
template <class Visitor> template <class Visitor>
...@@ -148,12 +163,20 @@ struct shape ...@@ -148,12 +163,20 @@ struct shape
{ {
switch(this->type()) switch(this->type())
{ {
#define MIGRAPH_SHAPE_VISITOR_CASE(x, t) \ #define MIGRAPHX_SHAPE_GENERATE_VISITOR_CASE(x, t) \
case x: v(as<t>()); return; case x: v(as<t>()); return;
MIGRAPH_SHAPE_VISIT_TYPES(MIGRAPH_SHAPE_VISITOR_CASE) MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_VISITOR_CASE)
#undef MIGRAPH_SHAPE_VISITOR_CASE #undef MIGRAPHX_SHAPE_GENERATE_VISITOR_CASE
} }
MIGRAPH_THROW("Unknown type"); MIGRAPHX_THROW("Unknown type");
}
template <class Visitor>
static void visit_types(Visitor v)
{
#define MIGRAPHX_SHAPE_GENERATE_VISITOR_ALL(x, t) v(as<t>());
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_VISITOR_ALL)
#undef MIGRAPHX_SHAPE_GENERATE_VISITOR_ALL
} }
private: private:
...@@ -163,7 +186,7 @@ struct shape ...@@ -163,7 +186,7 @@ struct shape
std::string type_string() const; std::string type_string() const;
}; };
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#endif #endif
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_SHAPE_FOR_EACH_HPP #ifndef MIGRAPHX_GUARD_MIGRAPHLIB_SHAPE_FOR_EACH_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_SHAPE_FOR_EACH_HPP #define MIGRAPHX_GUARD_MIGRAPHLIB_SHAPE_FOR_EACH_HPP
#include <migraph/shape.hpp> #include <migraphx/shape.hpp>
#include <migraph/config.hpp> #include <migraphx/config.hpp>
#include <algorithm> #include <algorithm>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
template <class F> template <class F>
void shape_for_each(const migraph::shape& s, F f) void shape_for_each(const migraphx::shape& s, F f)
{ {
// Ensure calls to f use const ref to vector // Ensure calls to f use const ref to vector
auto call = [&f](const std::vector<std::size_t>& i) { f(i); }; auto call = [&f](const std::vector<std::size_t>& i) { f(i); };
...@@ -28,7 +28,7 @@ void shape_for_each(const migraph::shape& s, F f) ...@@ -28,7 +28,7 @@ void shape_for_each(const migraph::shape& s, F f)
} }
} }
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#endif #endif
#ifndef MIGRAPH_GUARD_RTGLIB_SIMPLIFY_ALGEBRA_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_SIMPLIFY_ALGEBRA_HPP
#define MIGRAPH_GUARD_RTGLIB_SIMPLIFY_ALGEBRA_HPP #define MIGRAPHX_GUARD_RTGLIB_SIMPLIFY_ALGEBRA_HPP
#include <string> #include <string>
#include <migraph/config.hpp> #include <migraphx/config.hpp>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct program;
...@@ -18,7 +18,7 @@ struct simplify_algebra ...@@ -18,7 +18,7 @@ struct simplify_algebra
void apply(program& p) const; void apply(program& p) const;
}; };
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#endif #endif
#ifndef MIGRAPH_GUARD_RTGLIB_SIMPLIFY_RESHAPES_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_SIMPLIFY_RESHAPES_HPP
#define MIGRAPH_GUARD_RTGLIB_SIMPLIFY_RESHAPES_HPP #define MIGRAPHX_GUARD_RTGLIB_SIMPLIFY_RESHAPES_HPP
#include <string> #include <string>
#include <migraph/instruction_ref.hpp> #include <migraphx/instruction_ref.hpp>
#include <migraph/config.hpp> #include <migraphx/config.hpp>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct program; struct program;
...@@ -19,7 +19,7 @@ struct simplify_reshapes ...@@ -19,7 +19,7 @@ struct simplify_reshapes
void apply(program& p) const; void apply(program& p) const;
}; };
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#endif #endif
#ifndef MIGRAPH_GUARD_STREAMUTILS_HPP #ifndef MIGRAPHX_GUARD_STREAMUTILS_HPP
#define MIGRAPH_GUARD_STREAMUTILS_HPP #define MIGRAPHX_GUARD_STREAMUTILS_HPP
#include <ostream> #include <ostream>
#include <algorithm> #include <algorithm>
#include <migraph/rank.hpp> #include <migraphx/rank.hpp>
#include <migraph/config.hpp> #include <migraphx/config.hpp>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
template <class T> template <class T>
struct stream_range_container struct stream_range_container
...@@ -56,7 +56,7 @@ void stream_write_value(std::ostream& os, const T& x) ...@@ -56,7 +56,7 @@ void stream_write_value(std::ostream& os, const T& x)
detail::stream_write_value_impl(rank<1>{}, os, x); detail::stream_write_value_impl(rank<1>{}, os, x);
} }
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#endif #endif
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_STRINGUTILS_HPP #ifndef MIGRAPHX_GUARD_MIGRAPHLIB_STRINGUTILS_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_STRINGUTILS_HPP #define MIGRAPHX_GUARD_MIGRAPHLIB_STRINGUTILS_HPP
#include <algorithm> #include <algorithm>
#include <numeric> #include <numeric>
#include <string> #include <string>
#include <sstream> #include <sstream>
#include <migraph/config.hpp> #include <migraphx/config.hpp>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
inline std::string inline std::string
replace_string(std::string subject, const std::string& search, const std::string& replace) replace_string(std::string subject, const std::string& search, const std::string& replace)
...@@ -87,7 +87,7 @@ inline std::string to_string(const T& x) ...@@ -87,7 +87,7 @@ inline std::string to_string(const T& x)
return ss.str(); return ss.str();
} }
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#endif #endif
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_TARGET_HPP #ifndef MIGRAPHX_GUARD_MIGRAPHLIB_TARGET_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_TARGET_HPP #define MIGRAPHX_GUARD_MIGRAPHLIB_TARGET_HPP
#include <cassert> #include <cassert>
#include <string> #include <string>
...@@ -8,12 +8,12 @@ ...@@ -8,12 +8,12 @@
#include <type_traits> #include <type_traits>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include <migraph/context.hpp> #include <migraphx/context.hpp>
#include <migraph/pass.hpp> #include <migraphx/pass.hpp>
#include <migraph/config.hpp> #include <migraphx/config.hpp>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
#ifdef DOXYGEN #ifdef DOXYGEN
...@@ -127,6 +127,12 @@ struct target ...@@ -127,6 +127,12 @@ struct target
return (*this).private_detail_te_get_handle().get_context(); return (*this).private_detail_te_get_handle().get_context();
} }
friend bool is_shared(const target& private_detail_x, const target& private_detail_y)
{
return private_detail_x.private_detail_te_handle_mem_var ==
private_detail_y.private_detail_te_handle_mem_var;
}
private: private:
struct private_detail_te_handle_base_type struct private_detail_te_handle_base_type
{ {
...@@ -244,7 +250,7 @@ inline const ValueType& any_cast(const target& x) ...@@ -244,7 +250,7 @@ inline const ValueType& any_cast(const target& x)
#endif #endif
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#endif #endif
#ifndef MIGRAPH_GUARD_TENSOR_VIEW_HPP #ifndef MIGRAPHX_GUARD_TENSOR_VIEW_HPP
#define MIGRAPH_GUARD_TENSOR_VIEW_HPP #define MIGRAPHX_GUARD_TENSOR_VIEW_HPP
#include <migraph/shape.hpp> #include <migraphx/shape.hpp>
#include <migraph/float_equal.hpp> #include <migraphx/float_equal.hpp>
#include <migraph/requires.hpp> #include <migraphx/requires.hpp>
#include <migraph/config.hpp> #include <migraphx/config.hpp>
#include <iostream> #include <iostream>
#include <utility> #include <utility>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
template <class T> template <class T>
struct tensor_view struct tensor_view
...@@ -29,7 +29,7 @@ struct tensor_view ...@@ -29,7 +29,7 @@ struct tensor_view
const T* data() const { return this->m_data; } const T* data() const { return this->m_data; }
template <class... Ts, MIGRAPH_REQUIRES(std::is_integral<Ts>{}...)> template <class... Ts, MIGRAPHX_REQUIRES(std::is_integral<Ts>{}...)>
const T& operator()(Ts... xs) const const T& operator()(Ts... xs) const
{ {
assert(std::vector<std::size_t>{static_cast<std::size_t>(xs)...} < m_shape.lens()); assert(std::vector<std::size_t>{static_cast<std::size_t>(xs)...} < m_shape.lens());
...@@ -37,7 +37,7 @@ struct tensor_view ...@@ -37,7 +37,7 @@ struct tensor_view
return m_data[m_shape.index({static_cast<std::size_t>(xs)...})]; return m_data[m_shape.index({static_cast<std::size_t>(xs)...})];
} }
template <class... Ts, MIGRAPH_REQUIRES(std::is_integral<Ts>{}...)> template <class... Ts, MIGRAPHX_REQUIRES(std::is_integral<Ts>{}...)>
T& operator()(Ts... xs) T& operator()(Ts... xs)
{ {
assert(std::vector<std::size_t>{static_cast<std::size_t>(xs)...} < m_shape.lens()); assert(std::vector<std::size_t>{static_cast<std::size_t>(xs)...} < m_shape.lens());
...@@ -45,13 +45,13 @@ struct tensor_view ...@@ -45,13 +45,13 @@ struct tensor_view
return m_data[m_shape.index({static_cast<std::size_t>(xs)...})]; return m_data[m_shape.index({static_cast<std::size_t>(xs)...})];
} }
template <class Iterator, MIGRAPH_REQUIRES(not std::is_integral<Iterator>{})> template <class Iterator, MIGRAPHX_REQUIRES(not std::is_integral<Iterator>{})>
const T& operator()(Iterator start, Iterator last) const const T& operator()(Iterator start, Iterator last) const
{ {
return m_data[m_shape.index(start, last)]; return m_data[m_shape.index(start, last)];
} }
template <class Iterator, MIGRAPH_REQUIRES(not std::is_integral<Iterator>{})> template <class Iterator, MIGRAPHX_REQUIRES(not std::is_integral<Iterator>{})>
T& operator()(Iterator start, Iterator last) T& operator()(Iterator start, Iterator last)
{ {
return m_data[m_shape.index(start, last)]; return m_data[m_shape.index(start, last)];
...@@ -164,12 +164,12 @@ bool operator!=(const tensor_view<T>& x, const tensor_view<U>& y) ...@@ -164,12 +164,12 @@ bool operator!=(const tensor_view<T>& x, const tensor_view<U>& y)
} }
template <class T> template <class T>
tensor_view<T> make_view(shape s, T* data) tensor_view<T> make_view(const shape& s, T* data)
{ {
return {s, data}; return {s, data};
} }
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#endif #endif
#ifndef MIGRAPH_GUARD_RTGLIB_TIME_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_TIME_HPP
#define MIGRAPH_GUARD_RTGLIB_TIME_HPP #define MIGRAPHX_GUARD_RTGLIB_TIME_HPP
#include <chrono> #include <chrono>
#include <migraph/config.hpp> #include <migraphx/config.hpp>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
template <class Duration, class F> template <class Duration, class F>
auto time(F f) auto time(F f)
...@@ -16,7 +16,7 @@ auto time(F f) ...@@ -16,7 +16,7 @@ auto time(F f)
return std::chrono::duration_cast<Duration>(finish - start).count(); return std::chrono::duration_cast<Duration>(finish - start).count();
} }
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#endif #endif
#ifndef MIGRAPH_GUARD_RTGLIB_TRACER_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_TRACER_HPP
#define MIGRAPH_GUARD_RTGLIB_TRACER_HPP #define MIGRAPHX_GUARD_RTGLIB_TRACER_HPP
#include <ostream> #include <ostream>
#include <migraph/functional.hpp> #include <migraphx/functional.hpp>
#include <migraph/config.hpp> #include <migraphx/config.hpp>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
struct tracer struct tracer
{ {
...@@ -30,7 +30,7 @@ struct tracer ...@@ -30,7 +30,7 @@ struct tracer
std::ostream* os = nullptr; std::ostream* os = nullptr;
}; };
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#endif #endif
#ifndef MIGRAPH_GUARD_RTGLIB_TYPE_NAME_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_TYPE_NAME_HPP
#define MIGRAPH_GUARD_RTGLIB_TYPE_NAME_HPP #define MIGRAPHX_GUARD_RTGLIB_TYPE_NAME_HPP
#include <string> #include <string>
#include <migraph/config.hpp> #include <migraphx/config.hpp>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
template <class PrivateMigraphTypeNameProbe> template <class PrivateMigraphTypeNameProbe>
const std::string& get_type_name() const std::string& get_type_name()
...@@ -18,7 +18,7 @@ const std::string& get_type_name() ...@@ -18,7 +18,7 @@ const std::string& get_type_name()
name = typeid(PrivateMigraphTypeNameProbe).name(); name = typeid(PrivateMigraphTypeNameProbe).name();
name = name.substr(7); name = name.substr(7);
#else #else
const char parameter_name[] = "PrivateMigraphTypeNameProbe ="; const char parameter_name[] = "PrivateMigraphTypeNameProbe ="; // NOLINT
name = __PRETTY_FUNCTION__; name = __PRETTY_FUNCTION__;
...@@ -38,10 +38,10 @@ const std::string& get_type_name() ...@@ -38,10 +38,10 @@ const std::string& get_type_name()
template <class T> template <class T>
const std::string& get_type_name(const T&) const std::string& get_type_name(const T&)
{ {
return migraph::get_type_name<T>(); return migraphx::get_type_name<T>();
} }
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#endif #endif
...@@ -5,32 +5,32 @@ ...@@ -5,32 +5,32 @@
file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
==============================================================================*/ ==============================================================================*/
#ifndef MIGRAPH_GUARD_RTGLIB_TYPE_TRAITS_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_TYPE_TRAITS_HPP
#define MIGRAPH_GUARD_RTGLIB_TYPE_TRAITS_HPP #define MIGRAPHX_GUARD_RTGLIB_TYPE_TRAITS_HPP
#include <type_traits> #include <type_traits>
#include <migraph/half.hpp> #include <migraphx/half.hpp>
#include <migraph/config.hpp> #include <migraphx/config.hpp>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
#define MIGRAPH_DETAIL_EXTEND_TRAIT_FOR(trait, T) \ #define MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(trait, T) \
template <class X> \ template <class X> \
struct trait : std::trait<X> \ struct trait : std::trait<X> \
{ \ { \
}; \ }; \
\ \
template <> \ template <> \
struct trait<T> : std::true_type \ struct trait<T> : std::true_type \
{ \ { \
}; };
MIGRAPH_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, half) MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, half)
MIGRAPH_DETAIL_EXTEND_TRAIT_FOR(is_signed, half) MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, half)
MIGRAPH_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, half) MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, half)
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#endif #endif
#ifndef MIGRAPH_GUARD_VERIFY_HPP #ifndef MIGRAPHX_GUARD_VERIFY_HPP
#define MIGRAPH_GUARD_VERIFY_HPP #define MIGRAPHX_GUARD_VERIFY_HPP
#include <algorithm> #include <algorithm>
#include <cmath> #include <cmath>
...@@ -7,11 +7,11 @@ ...@@ -7,11 +7,11 @@
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
#include <migraph/float_equal.hpp> #include <migraphx/float_equal.hpp>
#include <migraph/config.hpp> #include <migraphx/config.hpp>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
// Compute the value of a range // Compute the value of a range
template <class R> template <class R>
...@@ -173,6 +173,6 @@ bool verify_range(R1&& r1, R2&& r2, double tolerance = 80, double* out_error = n ...@@ -173,6 +173,6 @@ bool verify_range(R1&& r1, R2&& r2, double tolerance = 80, double* out_error = n
return error <= threshold; return error <= threshold;
} }
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#endif #endif
#ifndef MIGRAPH_GUARD_RTGLIB_VERIFY_ARGS_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_VERIFY_ARGS_HPP
#define MIGRAPH_GUARD_RTGLIB_VERIFY_ARGS_HPP #define MIGRAPHX_GUARD_RTGLIB_VERIFY_ARGS_HPP
#include <migraph/verify.hpp> #include <migraphx/verify.hpp>
#include <migraph/argument.hpp> #include <migraphx/argument.hpp>
#include <migraph/config.hpp> #include <migraphx/config.hpp>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
inline bool verify_args(const std::string& name, inline bool verify_args(const std::string& name,
const argument& cpu_arg, const argument& cpu_arg,
...@@ -84,7 +84,7 @@ inline bool verify_args(const std::string& name, ...@@ -84,7 +84,7 @@ inline bool verify_args(const std::string& name,
return passed; return passed;
} }
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
#endif #endif
#include <migraph/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraph/builtin.hpp> #include <migraphx/builtin.hpp>
#include <migraph/erase.hpp> #include <migraphx/erase.hpp>
namespace migraph { namespace migraphx {
inline namespace MIGRAPH_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
instruction::instruction(operation o, shape r, std::vector<instruction_ref> args) instruction::instruction(operation o, shape r, std::vector<instruction_ref> args)
: op(std::move(o)), result(std::move(r)), arguments(std::move(args)) : op(std::move(o)), result(std::move(r)), arguments(std::move(args))
...@@ -70,7 +70,7 @@ bool instruction::valid() const ...@@ -70,7 +70,7 @@ bool instruction::valid() const
{ {
computed = compute_shape(op, arguments); computed = compute_shape(op, arguments);
} }
catch(migraph::exception&) catch(migraphx::exception&)
{ {
return false; return false;
} }
...@@ -97,7 +97,7 @@ const std::vector<instruction_ref>& instruction::outputs() const { return output ...@@ -97,7 +97,7 @@ const std::vector<instruction_ref>& instruction::outputs() const { return output
bool operator==(const instruction& x, const instruction& y) bool operator==(const instruction& x, const instruction& y)
{ {
if(not(x.result == y.result and x.op == y.op and x.arguments == y.arguments)) if(std::tie(x.result, x.op, x.arguments) != std::tie(y.result, y.op, y.arguments))
return false; return false;
if(x.name() == "@literal") if(x.name() == "@literal")
return x.lit == y.lit; return x.lit == y.lit;
...@@ -162,26 +162,55 @@ void instruction::replace_argument(instruction_ref old, instruction_ref new_ins) ...@@ -162,26 +162,55 @@ void instruction::replace_argument(instruction_ref old, instruction_ref new_ins)
old->remove_output(*this); old->remove_output(*this);
} }
std::vector<shape> compute_shapes(const std::vector<instruction_ref>& args) argument instruction::eval() const
{ {
std::vector<shape> shapes(args.size()); if(op.name() == "@literal")
std::transform( {
args.begin(), args.end(), shapes.begin(), [](instruction_ref i) { return i->get_shape(); }); return this->get_literal().get_argument();
return shapes; }
if(is_context_free(op))
{
std::vector<argument> args;
for(auto&& arg : this->inputs())
{
argument a = arg->eval();
if(a.empty())
return {};
args.push_back(a);
}
return op.compute(result, args);
}
return {};
} }
instruction_ref instruction::get_output_alias(instruction_ref ins) void instruction::finalize(context& ctx)
{ {
auto i = ins->get_operator().output_alias(compute_shapes(ins->inputs())); if(has_finalize(this->op))
this->op.finalize(ctx, this->get_shape(), to_shapes(this->inputs()));
}
instruction_ref instruction::get_output_alias(instruction_ref ins, bool shallow)
{
auto i = ins->get_operator().output_alias(to_shapes(ins->inputs()));
if(i < 0) if(i < 0)
return ins; return ins;
if(shallow)
return ins->inputs().at(i);
return get_output_alias(ins->inputs().at(i)); return get_output_alias(ins->inputs().at(i));
} }
std::vector<shape> to_shapes(const std::vector<instruction_ref>& args)
{
std::vector<shape> shapes(args.size());
std::transform(
args.begin(), args.end(), shapes.begin(), [](instruction_ref i) { return i->get_shape(); });
return shapes;
}
shape compute_shape(const operation& op, const std::vector<instruction_ref>& args) shape compute_shape(const operation& op, const std::vector<instruction_ref>& args)
{ {
return op.compute_shape(compute_shapes(args)); return op.compute_shape(to_shapes(args));
} }
} // namespace MIGRAPH_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraph } // namespace migraphx
...@@ -7,35 +7,35 @@ target_compile_options(onnx-proto PRIVATE -w) ...@@ -7,35 +7,35 @@ target_compile_options(onnx-proto PRIVATE -w)
target_link_libraries(onnx-proto PRIVATE ${PROTOBUF_LIBRARY}) target_link_libraries(onnx-proto PRIVATE ${PROTOBUF_LIBRARY})
set_target_properties(onnx-proto PROPERTIES POSITION_INDEPENDENT_CODE On) set_target_properties(onnx-proto PROPERTIES POSITION_INDEPENDENT_CODE On)
add_library(migraph_onnx onnx.cpp) add_library(migraphx_onnx onnx.cpp)
set_target_properties(migraph_onnx PROPERTIES EXPORT_NAME onnx) set_target_properties(migraphx_onnx PROPERTIES EXPORT_NAME onnx)
rocm_clang_tidy_check(migraph_onnx) rocm_clang_tidy_check(migraphx_onnx)
target_link_libraries(migraph_onnx PRIVATE onnx-proto) target_link_libraries(migraphx_onnx PRIVATE onnx-proto)
target_link_libraries(migraph_onnx PUBLIC migraph) target_link_libraries(migraphx_onnx PUBLIC migraphx)
rocm_install_targets( rocm_install_targets(
TARGETS migraph_onnx TARGETS migraphx_onnx
) )
add_executable(read_onnx read_onnx.cpp) add_executable(read_onnx read_onnx.cpp)
rocm_clang_tidy_check(read_onnx) rocm_clang_tidy_check(read_onnx)
target_link_libraries(read_onnx migraph_onnx) target_link_libraries(read_onnx migraphx_onnx)
if(MIGRAPH_ENABLE_GPU) if(MIGRAPHX_ENABLE_GPU)
add_executable(mnist mnist.cpp) add_executable(mnist mnist.cpp)
rocm_clang_tidy_check(mnist) rocm_clang_tidy_check(mnist)
target_link_libraries(mnist migraph_cpu migraph_gpu migraph_onnx) target_link_libraries(mnist migraphx_cpu migraphx_gpu migraphx_onnx)
add_executable(cifar10 cifar10.cpp) add_executable(cifar10 cifar10.cpp)
rocm_clang_tidy_check(cifar10) rocm_clang_tidy_check(cifar10)
target_link_libraries(cifar10 migraph_cpu migraph_gpu migraph_onnx) target_link_libraries(cifar10 migraphx_cpu migraphx_gpu migraphx_onnx)
add_executable(verify_onnx verify_onnx.cpp) add_executable(verify_onnx verify_onnx.cpp)
rocm_clang_tidy_check(verify_onnx) rocm_clang_tidy_check(verify_onnx)
target_link_libraries(verify_onnx migraph_onnx migraph_cpu migraph_gpu) target_link_libraries(verify_onnx migraphx_onnx migraphx_cpu migraphx_gpu)
add_executable(perf_onnx perf_onnx.cpp) add_executable(perf_onnx perf_onnx.cpp)
rocm_clang_tidy_check(perf_onnx) rocm_clang_tidy_check(perf_onnx)
target_link_libraries(perf_onnx migraph_onnx migraph_cpu migraph_gpu) target_link_libraries(perf_onnx migraphx_onnx migraphx_cpu migraphx_gpu)
endif() endif()
...@@ -4,12 +4,12 @@ ...@@ -4,12 +4,12 @@
#include <numeric> #include <numeric>
#include <stdexcept> #include <stdexcept>
#include <migraph/onnx.hpp> #include <migraphx/onnx.hpp>
#include <migraph/cpu/target.hpp> #include <migraphx/cpu/target.hpp>
#include <migraph/gpu/target.hpp> #include <migraphx/gpu/target.hpp>
#include <migraph/gpu/hip.hpp> #include <migraphx/gpu/hip.hpp>
#include <migraph/generate.hpp> #include <migraphx/generate.hpp>
#include "softmax.hpp" #include "softmax.hpp"
...@@ -53,19 +53,19 @@ int main(int argc, char const* argv[]) ...@@ -53,19 +53,19 @@ int main(int argc, char const* argv[])
std::string gpu_cpu = argv[1]; std::string gpu_cpu = argv[1];
std::string file = argv[2]; std::string file = argv[2];
std::string datafile = argv[3]; std::string datafile = argv[3];
auto prog = migraph::parse_onnx(file); auto prog = migraphx::parse_onnx(file);
std::cout << prog << std::endl; std::cout << prog << std::endl;
auto imageset = read_cifar10_images(datafile); auto imageset = read_cifar10_images(datafile);
if(gpu_cpu == "gpu") if(gpu_cpu == "gpu")
{ {
// GPU target // GPU target
prog.compile(migraph::gpu::target{}); prog.compile(migraphx::gpu::target{});
migraph::program::parameter_map m; migraphx::program::parameter_map m;
auto s = migraph::shape{migraph::shape::float_type, {1, 3, 32, 32}}; auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 32, 32}};
for(auto&& x : prog.get_parameter_shapes()) for(auto&& x : prog.get_parameter_shapes())
{ {
m[x.first] = migraph::gpu::to_gpu(migraph::generate_argument(x.second)); m[x.first] = migraphx::gpu::to_gpu(migraphx::generate_argument(x.second));
} }
auto labels = imageset.first; auto labels = imageset.first;
auto input = imageset.second; auto input = imageset.second;
...@@ -73,8 +73,8 @@ int main(int argc, char const* argv[]) ...@@ -73,8 +73,8 @@ int main(int argc, char const* argv[])
for(int i = 0; i < 10; i++) for(int i = 0; i < 10; i++)
{ {
std::cout << "label: " << static_cast<uint32_t>(labels[i]) << " ----> "; std::cout << "label: " << static_cast<uint32_t>(labels[i]) << " ----> ";
m["0"] = migraph::gpu::to_gpu(migraph::argument{s, &ptr[3072 * i]}); m["0"] = migraphx::gpu::to_gpu(migraphx::argument{s, &ptr[3072 * i]});
auto result = migraph::gpu::from_gpu(prog.eval(m)); auto result = migraphx::gpu::from_gpu(prog.eval(m));
std::vector<float> logits; std::vector<float> logits;
result.visit([&](auto output) { logits.assign(output.begin(), output.end()); }); result.visit([&](auto output) { logits.assign(output.begin(), output.end()); });
std::vector<float> probs = softmax<float>(logits); std::vector<float> probs = softmax<float>(logits);
...@@ -86,15 +86,15 @@ int main(int argc, char const* argv[]) ...@@ -86,15 +86,15 @@ int main(int argc, char const* argv[])
else else
{ {
// CPU target // CPU target
prog.compile(migraph::cpu::target{}); prog.compile(migraphx::cpu::target{});
auto s = migraph::shape{migraph::shape::float_type, {1, 3, 32, 32}}; auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 32, 32}};
auto labels = imageset.first; auto labels = imageset.first;
auto input = imageset.second; auto input = imageset.second;
auto ptr = input.data(); auto ptr = input.data();
for(int i = 0; i < 10; i++) for(int i = 0; i < 10; i++)
{ {
std::cout << "label: " << static_cast<uint32_t>(labels[i]) << " ----> "; std::cout << "label: " << static_cast<uint32_t>(labels[i]) << " ----> ";
auto input3 = migraph::argument{s, &ptr[3072 * i]}; auto input3 = migraphx::argument{s, &ptr[3072 * i]};
auto result = prog.eval({{"0", input3}}); auto result = prog.eval({{"0", input3}});
std::vector<float> logits; std::vector<float> logits;
result.visit([&](auto output) { logits.assign(output.begin(), output.end()); }); result.visit([&](auto output) { logits.assign(output.begin(), output.end()); });
......
...@@ -4,17 +4,20 @@ ...@@ -4,17 +4,20 @@
#include <numeric> #include <numeric>
#include <stdexcept> #include <stdexcept>
#include <migraph/onnx.hpp> #include <migraphx/onnx.hpp>
#include <migraph/gpu/target.hpp> #include <migraphx/gpu/target.hpp>
#include <migraph/gpu/hip.hpp> #include <migraphx/gpu/hip.hpp>
#include <migraph/generate.hpp> #include <migraphx/generate.hpp>
#include "softmax.hpp" #include "softmax.hpp"
auto reverse_int(unsigned int i) auto reverse_int(unsigned int i)
{ {
unsigned char c1, c2, c3, c4; unsigned char c1;
unsigned char c2;
unsigned char c3;
unsigned char c4;
c1 = i & 255u; c1 = i & 255u;
c2 = (i >> 8u) & 255u; c2 = (i >> 8u) & 255u;
c3 = (i >> 16u) & 255u; c3 = (i >> 16u) & 255u;
...@@ -32,7 +35,9 @@ read_mnist_images(const std::string& full_path, int& number_of_images, int& imag ...@@ -32,7 +35,9 @@ read_mnist_images(const std::string& full_path, int& number_of_images, int& imag
if(file.is_open()) if(file.is_open())
{ {
int magic_number = 0, n_rows = 0, n_cols = 0; int magic_number = 0;
int n_rows = 0;
int n_cols = 0;
file.read(reinterpret_cast<char*>(&magic_number), sizeof(magic_number)); file.read(reinterpret_cast<char*>(&magic_number), sizeof(magic_number));
magic_number = reverse_int(magic_number); magic_number = reverse_int(magic_number);
...@@ -113,20 +118,20 @@ int main(int argc, char const* argv[]) ...@@ -113,20 +118,20 @@ int main(int argc, char const* argv[])
std::vector<int32_t> labels = read_mnist_labels(labelfile, nlabels); std::vector<int32_t> labels = read_mnist_labels(labelfile, nlabels);
std::string file = argv[1]; std::string file = argv[1];
auto prog = migraph::parse_onnx(file); auto prog = migraphx::parse_onnx(file);
std::cout << prog << std::endl << std::endl; std::cout << prog << std::endl << std::endl;
prog.compile(migraph::gpu::target{}); prog.compile(migraphx::gpu::target{});
auto s = migraph::shape{migraph::shape::float_type, {1, 1, 28, 28}}; auto s = migraphx::shape{migraphx::shape::float_type, {1, 1, 28, 28}};
std::cout << s << std::endl; std::cout << s << std::endl;
auto ptr = input.data(); auto ptr = input.data();
migraph::program::parameter_map m; migraphx::program::parameter_map m;
m["output"] = m["output"] =
migraph::gpu::to_gpu(migraph::generate_argument(prog.get_parameter_shape("output"))); migraphx::gpu::to_gpu(migraphx::generate_argument(prog.get_parameter_shape("output")));
for(int i = 0; i < 20; i++) for(int i = 0; i < 20; i++)
{ {
std::cout << "label: " << labels[i] << " ----> "; std::cout << "label: " << labels[i] << " ----> ";
m["0"] = migraph::gpu::to_gpu(migraph::argument{s, &ptr[784 * i]}); m["0"] = migraphx::gpu::to_gpu(migraphx::argument{s, &ptr[784 * i]});
auto result = migraph::gpu::from_gpu(prog.eval(m)); auto result = migraphx::gpu::from_gpu(prog.eval(m));
std::vector<float> logits; std::vector<float> logits;
result.visit([&](auto output) { logits.assign(output.begin(), output.end()); }); result.visit([&](auto output) { logits.assign(output.begin(), output.end()); });
std::vector<float> probs = softmax(logits); std::vector<float> probs = softmax(logits);
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment