Commit d2549384 authored by Khalique

manual merge

parents 67048d04 ab6cd9d3
-#ifndef MIGRAPH_GUARD_MIGRAPHLIB_SHAPE_HPP
-#define MIGRAPH_GUARD_MIGRAPHLIB_SHAPE_HPP
+#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_SHAPE_HPP
+#define MIGRAPHX_GUARD_MIGRAPHLIB_SHAPE_HPP

#include <vector>
#include <cassert>
@@ -7,12 +7,12 @@
#include <numeric>
#include <memory>
-#include <migraph/errors.hpp>
-#include <migraph/half.hpp>
-#include <migraph/config.hpp>
+#include <migraphx/errors.hpp>
+#include <migraphx/half.hpp>
+#include <migraphx/config.hpp>

-namespace migraph {
-inline namespace MIGRAPH_INLINE_NS {
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {

struct shape_impl;
@@ -21,7 +21,7 @@ struct shape
    // Add new types here
    // clang-format off
-#define MIGRAPH_SHAPE_VISIT_TYPES(m) \
+#define MIGRAPHX_SHAPE_VISIT_TYPES(m) \
    m(half_type, half) \
    m(float_type, float) \
    m(double_type, double) \
@@ -35,22 +35,22 @@ struct shape
    m(uint64_type, uint64_t)
    // clang-format on

-#define MIGRAPH_SHAPE_ENUM_TYPES(x, t) x,
+#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x,
    enum type_t
    {
-        MIGRAPH_SHAPE_VISIT_TYPES(MIGRAPH_SHAPE_ENUM_TYPES)
+        MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES)
    };
-#undef MIGRAPH_SHAPE_ENUM_TYPES
+#undef MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES

    template <class T, class = void>
    struct get_type;
-#define MIGRAPH_SHAPE_GET_TYPE(x, t)                          \
+#define MIGRAPHX_SHAPE_GENERATE_GET_TYPE(x, t)                \
    template <class T>                                        \
    struct get_type<t, T> : std::integral_constant<type_t, x> \
    {                                                         \
    };
-    MIGRAPH_SHAPE_VISIT_TYPES(MIGRAPH_SHAPE_GET_TYPE)
-#undef MIGRAPH_SHAPE_GET_TYPE
+    MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_GET_TYPE)
+#undef MIGRAPHX_SHAPE_GENERATE_GET_TYPE

    template <class T>
    struct get_type<const T> : get_type<T>
@@ -148,12 +148,12 @@ struct shape
    {
        switch(this->type())
        {
-#define MIGRAPH_SHAPE_VISITOR_CASE(x, t) \
+#define MIGRAPHX_SHAPE_GENERATE_VISITOR_CASE(x, t) \
    case x: v(as<t>()); return;
-            MIGRAPH_SHAPE_VISIT_TYPES(MIGRAPH_SHAPE_VISITOR_CASE)
-#undef MIGRAPH_SHAPE_VISITOR_CASE
+            MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_VISITOR_CASE)
+#undef MIGRAPHX_SHAPE_GENERATE_VISITOR_CASE
        }
-        MIGRAPH_THROW("Unknown type");
+        MIGRAPHX_THROW("Unknown type");
    }

    private:
@@ -163,7 +163,7 @@
    std::string type_string() const;
};

-} // namespace MIGRAPH_INLINE_NS
-} // namespace migraph
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
#endif
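
// [Editor's sketch, not part of this commit] MIGRAPHX_SHAPE_VISIT_TYPES above is
// the classic X-macro pattern: one master list of (enum, type) pairs that gets
// re-expanded to generate the type enum, the get_type traits, and the visitor
// switch. A minimal, self-contained illustration of the same technique:

#include <iostream>

#define VISIT_TYPES(m) \
    m(float_type, float) \
    m(int32_type, int)

// Expansion 1: generate an enum from the master list.
#define GENERATE_ENUM(x, t) x,
enum type_t
{
    VISIT_TYPES(GENERATE_ENUM)
};
#undef GENERATE_ENUM

// Expansion 2: generate one switch case per type.
inline const char* type_name(type_t ty)
{
#define GENERATE_CASE(x, t) \
    case x: return #t;
    switch(ty)
    {
        VISIT_TYPES(GENERATE_CASE)
    }
#undef GENERATE_CASE
    return "unknown";
}

int main() { std::cout << type_name(float_type) << "\n"; } // prints "float"
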
-#ifndef MIGRAPH_GUARD_MIGRAPHLIB_SHAPE_FOR_EACH_HPP
-#define MIGRAPH_GUARD_MIGRAPHLIB_SHAPE_FOR_EACH_HPP
+#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_SHAPE_FOR_EACH_HPP
+#define MIGRAPHX_GUARD_MIGRAPHLIB_SHAPE_FOR_EACH_HPP

-#include <migraph/shape.hpp>
-#include <migraph/config.hpp>
+#include <migraphx/shape.hpp>
+#include <migraphx/config.hpp>
#include <algorithm>

-namespace migraph {
-inline namespace MIGRAPH_INLINE_NS {
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {

template <class F>
-void shape_for_each(const migraph::shape& s, F f)
+void shape_for_each(const migraphx::shape& s, F f)
{
    // Ensure calls to f use const ref to vector
    auto call = [&f](const std::vector<std::size_t>& i) { f(i); };
@@ -28,7 +28,7 @@ void shape_for_each(const migraphx::shape& s, F f)
    }
}

-} // namespace MIGRAPH_INLINE_NS
-} // namespace migraph
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
#endif
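
// [Editor's sketch, not part of this commit] The body of shape_for_each is elided
// by the hunk above; presumably it enumerates every multi-dimensional index of the
// shape in row-major order. A standalone version under that assumption, using a
// plain vector of dimension lengths instead of the shape class:

#include <cstddef>
#include <functional>
#include <iostream>
#include <vector>

void for_each_index(const std::vector<std::size_t>& lens,
                    const std::function<void(const std::vector<std::size_t>&)>& f)
{
    std::size_t n = 1;
    for(auto l : lens)
        n *= l; // total number of elements
    std::vector<std::size_t> idx(lens.size(), 0);
    for(std::size_t i = 0; i < n; i++)
    {
        f(idx);
        // Advance like an odometer: bump the last dimension and carry leftwards.
        for(std::size_t d = lens.size(); d > 0; d--)
        {
            if(++idx[d - 1] < lens[d - 1])
                break;
            idx[d - 1] = 0;
        }
    }
}

int main()
{
    // Visits (0,0), (0,1), (0,2), (1,0), (1,1), (1,2)
    for_each_index({2, 3}, [](const auto& i) { std::cout << i[0] << "," << i[1] << "\n"; });
}
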
-#ifndef MIGRAPH_GUARD_RTGLIB_SIMPLIFY_ALGEBRA_HPP
-#define MIGRAPH_GUARD_RTGLIB_SIMPLIFY_ALGEBRA_HPP
+#ifndef MIGRAPHX_GUARD_RTGLIB_SIMPLIFY_ALGEBRA_HPP
+#define MIGRAPHX_GUARD_RTGLIB_SIMPLIFY_ALGEBRA_HPP

#include <string>
-#include <migraph/config.hpp>
+#include <migraphx/config.hpp>

-namespace migraph {
-inline namespace MIGRAPH_INLINE_NS {
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {

struct program;
@@ -18,7 +18,7 @@ struct simplify_algebra
    void apply(program& p) const;
};

-} // namespace MIGRAPH_INLINE_NS
-} // namespace migraph
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
#endif
-#ifndef MIGRAPH_GUARD_RTGLIB_SIMPLIFY_RESHAPES_HPP
-#define MIGRAPH_GUARD_RTGLIB_SIMPLIFY_RESHAPES_HPP
+#ifndef MIGRAPHX_GUARD_RTGLIB_SIMPLIFY_RESHAPES_HPP
+#define MIGRAPHX_GUARD_RTGLIB_SIMPLIFY_RESHAPES_HPP

#include <string>
-#include <migraph/instruction_ref.hpp>
-#include <migraph/config.hpp>
+#include <migraphx/instruction_ref.hpp>
+#include <migraphx/config.hpp>

-namespace migraph {
-inline namespace MIGRAPH_INLINE_NS {
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {

struct program;
@@ -19,7 +19,7 @@ struct simplify_reshapes
    void apply(program& p) const;
};

-} // namespace MIGRAPH_INLINE_NS
-} // namespace migraph
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
#endif
-#ifndef MIGRAPH_GUARD_STREAMUTILS_HPP
-#define MIGRAPH_GUARD_STREAMUTILS_HPP
+#ifndef MIGRAPHX_GUARD_STREAMUTILS_HPP
+#define MIGRAPHX_GUARD_STREAMUTILS_HPP

#include <ostream>
#include <algorithm>
-#include <migraph/rank.hpp>
-#include <migraph/config.hpp>
+#include <migraphx/rank.hpp>
+#include <migraphx/config.hpp>

-namespace migraph {
-inline namespace MIGRAPH_INLINE_NS {
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {

template <class T>
struct stream_range_container
@@ -56,7 +56,7 @@ void stream_write_value(std::ostream& os, const T& x)
    detail::stream_write_value_impl(rank<1>{}, os, x);
}

-} // namespace MIGRAPH_INLINE_NS
-} // namespace migraph
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
#endif
-#ifndef MIGRAPH_GUARD_MIGRAPHLIB_STRINGUTILS_HPP
-#define MIGRAPH_GUARD_MIGRAPHLIB_STRINGUTILS_HPP
+#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_STRINGUTILS_HPP
+#define MIGRAPHX_GUARD_MIGRAPHLIB_STRINGUTILS_HPP

#include <algorithm>
#include <numeric>
#include <string>
#include <sstream>
-#include <migraph/config.hpp>
+#include <migraphx/config.hpp>

-namespace migraph {
-inline namespace MIGRAPH_INLINE_NS {
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {

inline std::string
replace_string(std::string subject, const std::string& search, const std::string& replace)
@@ -87,7 +87,7 @@ inline std::string to_string(const T& x)
    return ss.str();
}

-} // namespace MIGRAPH_INLINE_NS
-} // namespace migraph
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
#endif
-#ifndef MIGRAPH_GUARD_MIGRAPHLIB_TARGET_HPP
-#define MIGRAPH_GUARD_MIGRAPHLIB_TARGET_HPP
+#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_TARGET_HPP
+#define MIGRAPHX_GUARD_MIGRAPHLIB_TARGET_HPP

#include <cassert>
#include <string>
@@ -8,12 +8,12 @@
#include <type_traits>
#include <utility>
#include <vector>
-#include <migraph/context.hpp>
-#include <migraph/pass.hpp>
-#include <migraph/config.hpp>
+#include <migraphx/context.hpp>
+#include <migraphx/pass.hpp>
+#include <migraphx/config.hpp>

-namespace migraph {
-inline namespace MIGRAPH_INLINE_NS {
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {

#ifdef DOXYGEN
@@ -127,6 +127,12 @@ struct target
        return (*this).private_detail_te_get_handle().get_context();
    }

+    friend bool is_shared(const target& private_detail_x, const target& private_detail_y)
+    {
+        return private_detail_x.private_detail_te_handle_mem_var ==
+               private_detail_y.private_detail_te_handle_mem_var;
+    }
+
    private:
    struct private_detail_te_handle_base_type
    {
@@ -244,7 +250,7 @@ inline const ValueType& any_cast(const target& x)
#endif

-} // namespace MIGRAPH_INLINE_NS
-} // namespace migraph
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
#endif
-#ifndef MIGRAPH_GUARD_TENSOR_VIEW_HPP
-#define MIGRAPH_GUARD_TENSOR_VIEW_HPP
+#ifndef MIGRAPHX_GUARD_TENSOR_VIEW_HPP
+#define MIGRAPHX_GUARD_TENSOR_VIEW_HPP

-#include <migraph/shape.hpp>
-#include <migraph/float_equal.hpp>
-#include <migraph/requires.hpp>
-#include <migraph/config.hpp>
+#include <migraphx/shape.hpp>
+#include <migraphx/float_equal.hpp>
+#include <migraphx/requires.hpp>
+#include <migraphx/config.hpp>

#include <iostream>
#include <utility>

-namespace migraph {
-inline namespace MIGRAPH_INLINE_NS {
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {

template <class T>
struct tensor_view
@@ -29,7 +29,7 @@ struct tensor_view
    const T* data() const { return this->m_data; }

-    template <class... Ts, MIGRAPH_REQUIRES(std::is_integral<Ts>{}...)>
+    template <class... Ts, MIGRAPHX_REQUIRES(std::is_integral<Ts>{}...)>
    const T& operator()(Ts... xs) const
    {
        assert(std::vector<std::size_t>{static_cast<std::size_t>(xs)...} < m_shape.lens());
@@ -37,7 +37,7 @@ struct tensor_view
        return m_data[m_shape.index({static_cast<std::size_t>(xs)...})];
    }

-    template <class... Ts, MIGRAPH_REQUIRES(std::is_integral<Ts>{}...)>
+    template <class... Ts, MIGRAPHX_REQUIRES(std::is_integral<Ts>{}...)>
    T& operator()(Ts... xs)
    {
        assert(std::vector<std::size_t>{static_cast<std::size_t>(xs)...} < m_shape.lens());
@@ -45,13 +45,13 @@ struct tensor_view
        return m_data[m_shape.index({static_cast<std::size_t>(xs)...})];
    }

-    template <class Iterator, MIGRAPH_REQUIRES(not std::is_integral<Iterator>{})>
+    template <class Iterator, MIGRAPHX_REQUIRES(not std::is_integral<Iterator>{})>
    const T& operator()(Iterator start, Iterator last) const
    {
        return m_data[m_shape.index(start, last)];
    }

-    template <class Iterator, MIGRAPH_REQUIRES(not std::is_integral<Iterator>{})>
+    template <class Iterator, MIGRAPHX_REQUIRES(not std::is_integral<Iterator>{})>
    T& operator()(Iterator start, Iterator last)
    {
        return m_data[m_shape.index(start, last)];
@@ -164,12 +164,12 @@ bool operator!=(const tensor_view<T>& x, const tensor_view<U>& y)
}

template <class T>
-tensor_view<T> make_view(shape s, T* data)
+tensor_view<T> make_view(const shape& s, T* data)
{
    return {s, data};
}

-} // namespace MIGRAPH_INLINE_NS
-} // namespace migraph
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
#endif
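
// [Editor's sketch, not part of this commit] tensor_view addresses a flat buffer
// through the shape's index() method, which for a packed row-major shape is
// flat = sum over d of idx[d] * stride[d]. A standalone version of just that
// computation, assuming row-major (last dimension fastest) layout:

#include <cassert>
#include <cstddef>
#include <iostream>
#include <vector>

std::size_t flat_index(const std::vector<std::size_t>& lens, const std::vector<std::size_t>& idx)
{
    assert(lens.size() == idx.size());
    std::size_t flat   = 0;
    std::size_t stride = 1; // the last dimension has stride 1
    for(std::size_t d = lens.size(); d > 0; d--)
    {
        assert(idx[d - 1] < lens[d - 1]);
        flat += idx[d - 1] * stride;
        stride *= lens[d - 1];
    }
    return flat;
}

int main()
{
    // For lens = {2, 3}, element (1, 2) lives at flat offset 1*3 + 2 = 5.
    std::cout << flat_index({2, 3}, {1, 2}) << "\n"; // prints 5
}
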
-#ifndef MIGRAPH_GUARD_RTGLIB_TIME_HPP
-#define MIGRAPH_GUARD_RTGLIB_TIME_HPP
+#ifndef MIGRAPHX_GUARD_RTGLIB_TIME_HPP
+#define MIGRAPHX_GUARD_RTGLIB_TIME_HPP

#include <chrono>
-#include <migraph/config.hpp>
+#include <migraphx/config.hpp>

-namespace migraph {
-inline namespace MIGRAPH_INLINE_NS {
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {

template <class Duration, class F>
auto time(F f)
@@ -16,7 +16,7 @@ auto time(F f)
    return std::chrono::duration_cast<Duration>(finish - start).count();
}

-} // namespace MIGRAPH_INLINE_NS
-} // namespace migraph
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
#endif
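
// [Editor's sketch, not part of this commit] The body of time() is elided above;
// it is presumably the usual steady_clock sandwich around the callable. A
// self-contained version of the same helper under that assumption, with a usage
// line:

#include <chrono>
#include <iostream>
#include <thread>

template <class Duration, class F>
auto time(F f)
{
    auto start = std::chrono::steady_clock::now();
    f(); // run the workload being measured
    auto finish = std::chrono::steady_clock::now();
    return std::chrono::duration_cast<Duration>(finish - start).count();
}

int main()
{
    auto ms = time<std::chrono::milliseconds>(
        [] { std::this_thread::sleep_for(std::chrono::milliseconds(5)); });
    std::cout << "took " << ms << "ms\n"; // roughly 5
}
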
-#ifndef MIGRAPH_GUARD_RTGLIB_TRACER_HPP
-#define MIGRAPH_GUARD_RTGLIB_TRACER_HPP
+#ifndef MIGRAPHX_GUARD_RTGLIB_TRACER_HPP
+#define MIGRAPHX_GUARD_RTGLIB_TRACER_HPP

#include <ostream>
-#include <migraph/functional.hpp>
-#include <migraph/config.hpp>
+#include <migraphx/functional.hpp>
+#include <migraphx/config.hpp>

-namespace migraph {
-inline namespace MIGRAPH_INLINE_NS {
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {

struct tracer
{
@@ -30,7 +30,7 @@ struct tracer
    std::ostream* os = nullptr;
};

-} // namespace MIGRAPH_INLINE_NS
-} // namespace migraph
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
#endif
-#ifndef MIGRAPH_GUARD_RTGLIB_TYPE_NAME_HPP
-#define MIGRAPH_GUARD_RTGLIB_TYPE_NAME_HPP
+#ifndef MIGRAPHX_GUARD_RTGLIB_TYPE_NAME_HPP
+#define MIGRAPHX_GUARD_RTGLIB_TYPE_NAME_HPP

#include <string>
-#include <migraph/config.hpp>
+#include <migraphx/config.hpp>

-namespace migraph {
-inline namespace MIGRAPH_INLINE_NS {
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {

template <class PrivateMigraphTypeNameProbe>
const std::string& get_type_name()
@@ -18,7 +18,7 @@ const std::string& get_type_name()
    name = typeid(PrivateMigraphTypeNameProbe).name();
    name = name.substr(7);
#else
-    const char parameter_name[] = "PrivateMigraphTypeNameProbe =";
+    const char parameter_name[] = "PrivateMigraphTypeNameProbe ="; // NOLINT

    name = __PRETTY_FUNCTION__;
@@ -38,10 +38,10 @@ const std::string& get_type_name()

template <class T>
const std::string& get_type_name(const T&)
{
-    return migraph::get_type_name<T>();
+    return migraphx::get_type_name<T>();
}

-} // namespace MIGRAPH_INLINE_NS
-} // namespace migraph
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
#endif
@@ -5,32 +5,32 @@
   file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
==============================================================================*/
-#ifndef MIGRAPH_GUARD_RTGLIB_TYPE_TRAITS_HPP
-#define MIGRAPH_GUARD_RTGLIB_TYPE_TRAITS_HPP
+#ifndef MIGRAPHX_GUARD_RTGLIB_TYPE_TRAITS_HPP
+#define MIGRAPHX_GUARD_RTGLIB_TYPE_TRAITS_HPP

#include <type_traits>
-#include <migraph/half.hpp>
-#include <migraph/config.hpp>
+#include <migraphx/half.hpp>
+#include <migraphx/config.hpp>

-namespace migraph {
-inline namespace MIGRAPH_INLINE_NS {
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {

-#define MIGRAPH_DETAIL_EXTEND_TRAIT_FOR(trait, T) \
+#define MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(trait, T) \
    template <class X>               \
    struct trait : std::trait<X>     \
    {                                \
    };                               \
                                     \
    template <>                      \
    struct trait<T> : std::true_type \
    {                                \
    };

-MIGRAPH_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, half)
-MIGRAPH_DETAIL_EXTEND_TRAIT_FOR(is_signed, half)
-MIGRAPH_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, half)
+MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, half)
+MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, half)
+MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, half)

-} // namespace MIGRAPH_INLINE_NS
-} // namespace migraph
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
#endif
-#ifndef MIGRAPH_GUARD_VERIFY_HPP
-#define MIGRAPH_GUARD_VERIFY_HPP
+#ifndef MIGRAPHX_GUARD_VERIFY_HPP
+#define MIGRAPHX_GUARD_VERIFY_HPP

#include <algorithm>
#include <cmath>
@@ -7,11 +7,11 @@
#include <iostream>
#include <numeric>
-#include <migraph/float_equal.hpp>
-#include <migraph/config.hpp>
+#include <migraphx/float_equal.hpp>
+#include <migraphx/config.hpp>

-namespace migraph {
-inline namespace MIGRAPH_INLINE_NS {
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {

// Compute the value of a range
template <class R>
@@ -173,6 +173,6 @@ bool verify_range(R1&& r1, R2&& r2, double tolerance = 80, double* out_error = n
    return error <= threshold;
}

-} // namespace MIGRAPH_INLINE_NS
-} // namespace migraph
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
#endif
-#ifndef MIGRAPH_GUARD_RTGLIB_VERIFY_ARGS_HPP
-#define MIGRAPH_GUARD_RTGLIB_VERIFY_ARGS_HPP
+#ifndef MIGRAPHX_GUARD_RTGLIB_VERIFY_ARGS_HPP
+#define MIGRAPHX_GUARD_RTGLIB_VERIFY_ARGS_HPP

-#include <migraph/verify.hpp>
-#include <migraph/argument.hpp>
-#include <migraph/config.hpp>
+#include <migraphx/verify.hpp>
+#include <migraphx/argument.hpp>
+#include <migraphx/config.hpp>

-namespace migraph {
-inline namespace MIGRAPH_INLINE_NS {
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {

inline bool verify_args(const std::string& name,
                        const argument& cpu_arg,
@@ -84,7 +84,7 @@ inline bool verify_args(const std::string& name,
    return passed;
}

-} // namespace MIGRAPH_INLINE_NS
-} // namespace migraph
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
#endif
-#include <migraph/instruction.hpp>
-#include <migraph/builtin.hpp>
-#include <migraph/erase.hpp>
+#include <migraphx/instruction.hpp>
+#include <migraphx/builtin.hpp>
+#include <migraphx/erase.hpp>

-namespace migraph {
-inline namespace MIGRAPH_INLINE_NS {
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {

instruction::instruction(operation o, shape r, std::vector<instruction_ref> args)
    : op(std::move(o)), result(std::move(r)), arguments(std::move(args))
@@ -70,7 +70,7 @@ bool instruction::valid() const
    {
        computed = compute_shape(op, arguments);
    }
-    catch(migraph::exception&)
+    catch(migraphx::exception&)
    {
        return false;
    }
@@ -162,26 +162,55 @@ void instruction::replace_argument(instruction_ref old, instruction_ref new_ins)
    old->remove_output(*this);
}

-std::vector<shape> compute_shapes(const std::vector<instruction_ref>& args)
-{
-    std::vector<shape> shapes(args.size());
-    std::transform(
-        args.begin(), args.end(), shapes.begin(), [](instruction_ref i) { return i->get_shape(); });
-    return shapes;
-}
+argument instruction::eval() const
+{
+    if(op.name() == "@literal")
+    {
+        return this->get_literal().get_argument();
+    }
+    if(is_context_free(op))
+    {
+        std::vector<argument> args;
+        for(auto&& arg : this->inputs())
+        {
+            argument a = arg->eval();
+            if(a.empty())
+                return {};
+            args.push_back(a);
+        }
+        return op.compute(result, args);
+    }
+    return {};
+}

-instruction_ref instruction::get_output_alias(instruction_ref ins)
-{
-    auto i = ins->get_operator().output_alias(compute_shapes(ins->inputs()));
+void instruction::finalize(context& ctx)
+{
+    if(has_finalize(this->op))
+        this->op.finalize(ctx, this->get_shape(), to_shapes(this->inputs()));
+}
+
+instruction_ref instruction::get_output_alias(instruction_ref ins, bool shallow)
+{
+    auto i = ins->get_operator().output_alias(to_shapes(ins->inputs()));
    if(i < 0)
        return ins;
+    if(shallow)
+        return ins->inputs().at(i);
    return get_output_alias(ins->inputs().at(i));
}

+std::vector<shape> to_shapes(const std::vector<instruction_ref>& args)
+{
+    std::vector<shape> shapes(args.size());
+    std::transform(
+        args.begin(), args.end(), shapes.begin(), [](instruction_ref i) { return i->get_shape(); });
+    return shapes;
+}
+
shape compute_shape(const operation& op, const std::vector<instruction_ref>& args)
{
-    return op.compute_shape(compute_shapes(args));
+    return op.compute_shape(to_shapes(args));
}

-} // namespace MIGRAPH_INLINE_NS
-} // namespace migraph
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
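
// [Editor's sketch, not part of this commit] instruction::eval() above is a
// recursive constant folder: an instruction evaluates iff it is a literal, or a
// context-free op whose inputs all evaluate. The same shape on a toy expression
// tree (the names here are illustrative, not MIGraphX API):

#include <iostream>
#include <memory>
#include <optional>
#include <vector>

struct node
{
    std::optional<double> literal;             // set for leaf constants
    char op = 0;                               // '+' or '*' for interior nodes
    std::vector<std::shared_ptr<node>> inputs;

    std::optional<double> eval() const
    {
        if(literal)
            return literal; // a literal evaluates to itself
        std::vector<double> vals;
        for(auto&& in : inputs)
        {
            auto v = in->eval(); // recurse; give up if any input is not constant
            if(not v)
                return std::nullopt;
            vals.push_back(*v);
        }
        return op == '+' ? vals.at(0) + vals.at(1) : vals.at(0) * vals.at(1);
    }
};

int main()
{
    auto two   = std::make_shared<node>(node{2.0, 0, {}});
    auto three = std::make_shared<node>(node{3.0, 0, {}});
    node sum{std::nullopt, '+', {two, three}};
    std::cout << *sum.eval() << "\n"; // prints 5
}
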
@@ -7,35 +7,35 @@ target_compile_options(onnx-proto PRIVATE -w)
target_link_libraries(onnx-proto PRIVATE ${PROTOBUF_LIBRARY})
set_target_properties(onnx-proto PROPERTIES POSITION_INDEPENDENT_CODE On)

-add_library(migraph_onnx onnx.cpp)
-set_target_properties(migraph_onnx PROPERTIES EXPORT_NAME onnx)
-rocm_clang_tidy_check(migraph_onnx)
-target_link_libraries(migraph_onnx PRIVATE onnx-proto)
-target_link_libraries(migraph_onnx PUBLIC migraph)
+add_library(migraphx_onnx onnx.cpp)
+set_target_properties(migraphx_onnx PROPERTIES EXPORT_NAME onnx)
+rocm_clang_tidy_check(migraphx_onnx)
+target_link_libraries(migraphx_onnx PRIVATE onnx-proto)
+target_link_libraries(migraphx_onnx PUBLIC migraphx)
rocm_install_targets(
-  TARGETS migraph_onnx
+  TARGETS migraphx_onnx
)

add_executable(read_onnx read_onnx.cpp)
rocm_clang_tidy_check(read_onnx)
-target_link_libraries(read_onnx migraph_onnx)
+target_link_libraries(read_onnx migraphx_onnx)

-if(MIGRAPH_ENABLE_GPU)
+if(MIGRAPHX_ENABLE_GPU)
add_executable(mnist mnist.cpp)
rocm_clang_tidy_check(mnist)
-target_link_libraries(mnist migraph_cpu migraph_gpu migraph_onnx)
+target_link_libraries(mnist migraphx_cpu migraphx_gpu migraphx_onnx)

add_executable(cifar10 cifar10.cpp)
rocm_clang_tidy_check(cifar10)
-target_link_libraries(cifar10 migraph_cpu migraph_gpu migraph_onnx)
+target_link_libraries(cifar10 migraphx_cpu migraphx_gpu migraphx_onnx)

add_executable(verify_onnx verify_onnx.cpp)
rocm_clang_tidy_check(verify_onnx)
-target_link_libraries(verify_onnx migraph_onnx migraph_cpu migraph_gpu)
+target_link_libraries(verify_onnx migraphx_onnx migraphx_cpu migraphx_gpu)

add_executable(perf_onnx perf_onnx.cpp)
rocm_clang_tidy_check(perf_onnx)
-target_link_libraries(perf_onnx migraph_onnx migraph_cpu migraph_gpu)
+target_link_libraries(perf_onnx migraphx_onnx migraphx_cpu migraphx_gpu)
endif()
@@ -4,12 +4,12 @@
#include <numeric>
#include <stdexcept>

-#include <migraph/onnx.hpp>
-#include <migraph/cpu/target.hpp>
-#include <migraph/gpu/target.hpp>
-#include <migraph/gpu/hip.hpp>
-#include <migraph/generate.hpp>
+#include <migraphx/onnx.hpp>
+#include <migraphx/cpu/target.hpp>
+#include <migraphx/gpu/target.hpp>
+#include <migraphx/gpu/hip.hpp>
+#include <migraphx/generate.hpp>

#include "softmax.hpp"
@@ -53,19 +53,19 @@ int main(int argc, char const* argv[])
    std::string gpu_cpu  = argv[1];
    std::string file     = argv[2];
    std::string datafile = argv[3];
-    auto prog            = migraph::parse_onnx(file);
+    auto prog            = migraphx::parse_onnx(file);
    std::cout << prog << std::endl;
    auto imageset = read_cifar10_images(datafile);

    if(gpu_cpu == "gpu")
    {
        // GPU target
-        prog.compile(migraph::gpu::target{});
-        migraph::program::parameter_map m;
-        auto s = migraph::shape{migraph::shape::float_type, {1, 3, 32, 32}};
+        prog.compile(migraphx::gpu::target{});
+        migraphx::program::parameter_map m;
+        auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 32, 32}};
        for(auto&& x : prog.get_parameter_shapes())
        {
-            m[x.first] = migraph::gpu::to_gpu(migraph::generate_argument(x.second));
+            m[x.first] = migraphx::gpu::to_gpu(migraphx::generate_argument(x.second));
        }
        auto labels = imageset.first;
        auto input  = imageset.second;
@@ -73,8 +73,8 @@ int main(int argc, char const* argv[])
        for(int i = 0; i < 10; i++)
        {
            std::cout << "label: " << static_cast<uint32_t>(labels[i]) << " ----> ";
-            m["0"]      = migraph::gpu::to_gpu(migraph::argument{s, &ptr[3072 * i]});
-            auto result = migraph::gpu::from_gpu(prog.eval(m));
+            m["0"]      = migraphx::gpu::to_gpu(migraphx::argument{s, &ptr[3072 * i]});
+            auto result = migraphx::gpu::from_gpu(prog.eval(m));
            std::vector<float> logits;
            result.visit([&](auto output) { logits.assign(output.begin(), output.end()); });
            std::vector<float> probs = softmax<float>(logits);
@@ -86,15 +86,15 @@ int main(int argc, char const* argv[])
    else
    {
        // CPU target
-        prog.compile(migraph::cpu::target{});
-        auto s      = migraph::shape{migraph::shape::float_type, {1, 3, 32, 32}};
+        prog.compile(migraphx::cpu::target{});
+        auto s      = migraphx::shape{migraphx::shape::float_type, {1, 3, 32, 32}};
        auto labels = imageset.first;
        auto input  = imageset.second;
        auto ptr    = input.data();
        for(int i = 0; i < 10; i++)
        {
            std::cout << "label: " << static_cast<uint32_t>(labels[i]) << " ----> ";
-            auto input3 = migraph::argument{s, &ptr[3072 * i]};
+            auto input3 = migraphx::argument{s, &ptr[3072 * i]};
            auto result = prog.eval({{"0", input3}});
            std::vector<float> logits;
            result.visit([&](auto output) { logits.assign(output.begin(), output.end()); });
...
@@ -4,17 +4,20 @@
#include <numeric>
#include <stdexcept>

-#include <migraph/onnx.hpp>
-#include <migraph/gpu/target.hpp>
-#include <migraph/gpu/hip.hpp>
-#include <migraph/generate.hpp>
+#include <migraphx/onnx.hpp>
+#include <migraphx/gpu/target.hpp>
+#include <migraphx/gpu/hip.hpp>
+#include <migraphx/generate.hpp>

#include "softmax.hpp"

auto reverse_int(unsigned int i)
{
-    unsigned char c1, c2, c3, c4;
+    unsigned char c1;
+    unsigned char c2;
+    unsigned char c3;
+    unsigned char c4;

    c1 = i & 255u;
    c2 = (i >> 8u) & 255u;
    c3 = (i >> 16u) & 255u;
@@ -32,7 +35,9 @@ read_mnist_images(const std::string& full_path, int& number_of_images, int& imag
    if(file.is_open())
    {
-        int magic_number = 0, n_rows = 0, n_cols = 0;
+        int magic_number = 0;
+        int n_rows       = 0;
+        int n_cols       = 0;

        file.read(reinterpret_cast<char*>(&magic_number), sizeof(magic_number));
        magic_number = reverse_int(magic_number);
@@ -113,20 +118,20 @@ int main(int argc, char const* argv[])
    std::vector<int32_t> labels = read_mnist_labels(labelfile, nlabels);

    std::string file = argv[1];
-    auto prog        = migraph::parse_onnx(file);
+    auto prog        = migraphx::parse_onnx(file);
    std::cout << prog << std::endl << std::endl;
-    prog.compile(migraph::gpu::target{});
+    prog.compile(migraphx::gpu::target{});

-    auto s = migraph::shape{migraph::shape::float_type, {1, 1, 28, 28}};
+    auto s = migraphx::shape{migraphx::shape::float_type, {1, 1, 28, 28}};
    std::cout << s << std::endl;
    auto ptr = input.data();
-    migraph::program::parameter_map m;
+    migraphx::program::parameter_map m;
    m["output"] =
-        migraph::gpu::to_gpu(migraph::generate_argument(prog.get_parameter_shape("output")));
+        migraphx::gpu::to_gpu(migraphx::generate_argument(prog.get_parameter_shape("output")));
    for(int i = 0; i < 20; i++)
    {
        std::cout << "label: " << labels[i] << " ----> ";
-        m["0"]      = migraph::gpu::to_gpu(migraph::argument{s, &ptr[784 * i]});
-        auto result = migraph::gpu::from_gpu(prog.eval(m));
+        m["0"]      = migraphx::gpu::to_gpu(migraphx::argument{s, &ptr[784 * i]});
+        auto result = migraphx::gpu::from_gpu(prog.eval(m));
        std::vector<float> logits;
        result.visit([&](auto output) { logits.assign(output.begin(), output.end()); });
        std::vector<float> probs = softmax(logits);
...
@@ -9,41 +9,27 @@
#include <utility>
#include <vector>

-#include <migraph/fallthrough.hpp>
-#include <migraph/program.hpp>
-#include <migraph/operators.hpp>
-#include <migraph/ranges.hpp>
-#include <migraph/instruction.hpp>
-#include <migraph/config.hpp>
+#include <migraphx/fallthrough.hpp>
+#include <migraphx/program.hpp>
+#include <migraphx/operators.hpp>
+#include <migraphx/ranges.hpp>
+#include <migraphx/instruction.hpp>
+#include <migraphx/config.hpp>
+#include <migraphx/onnx.hpp>

-namespace migraph {
-inline namespace MIGRAPH_INLINE_NS {
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {

-struct unknown
-{
-    std::string op;
-    std::string name() const { return "unknown:" + op; }
-    shape compute_shape(std::vector<shape> input) const
-    {
-        if(input.empty())
-            return {};
-        else
-            return input.front();
-    }
-    friend std::ostream& operator<<(std::ostream& os, const unknown& x)
-    {
-        os << x.name();
-        return os;
-    }
-};

struct onnx_parser
{
    using attribute_map = std::unordered_map<std::string, onnx::AttributeProto>;
    using node_map      = std::unordered_map<std::string, onnx::NodeProto>;
-    using op_func = std::function<instruction_ref(attribute_map, std::vector<instruction_ref>)>;
+    using op_func =
+        std::function<std::vector<instruction_ref>(attribute_map, std::vector<instruction_ref>)>;

    node_map nodes;
    std::unordered_map<std::string, instruction_ref> instructions;
    program prog = program();
+    bool is_pytorch = false;

    std::unordered_map<std::string, op_func> ops;
@@ -51,18 +37,36 @@ struct onnx_parser
    {
        add_generic_op("MatMul", op::dot{});
        add_generic_op("Relu", op::relu{});
+        add_generic_op("Sigmoid", op::sigmoid{});
+        add_generic_op("Abs", op::abs{});
+        add_generic_op("Exp", op::exp{});
+        add_generic_op("Log", op::log{});
        // disable dropout for inference
        add_generic_op("Dropout", op::identity{});
+        add_generic_op("Identity", op::identity{});
+        add_generic_op("Sin", op::sin{});
+        add_generic_op("Cos", op::cos{});
+        add_generic_op("Tan", op::tan{});
+        add_generic_op("Sinh", op::sinh{});
+        add_generic_op("Cosh", op::cosh{});
+        add_generic_op("Tanh", op::tanh{});
+        add_generic_op("Asin", op::asin{});
+        add_generic_op("Acos", op::acos{});
+        add_generic_op("Atan", op::atan{});

-        add_broadcastable_binary_op("Add", op::add{});
-        add_broadcastable_binary_op("Div", op::div{});
-        add_broadcastable_binary_op("Mul", op::mul{});
-        add_broadcastable_binary_op("Sub", op::sub{});
-        add_broadcastable_binary_op("Sum", op::add{});
+        add_binary_op("Add", op::add{});
+        add_binary_op("Div", op::div{});
+        add_binary_op("Mul", op::mul{});
+        add_binary_op("Sub", op::sub{});
+        add_variadic_op("Sum", op::add{});
+        add_variadic_op("Max", op::max{});
+        add_variadic_op("Min", op::min{});

        add_mem_op("LRN", &onnx_parser::parse_lrn);
        add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler);
        add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu);
+        add_mem_op("Elu", &onnx_parser::parse_elu);
        add_mem_op("Constant", &onnx_parser::parse_constant);
        add_mem_op("Conv", &onnx_parser::parse_conv);
        add_mem_op("MaxPool", &onnx_parser::parse_pooling);
@@ -78,11 +82,24 @@ struct onnx_parser
        add_mem_op("Unsqueeze", &onnx_parser::parse_unsqueeze);
        add_mem_op("Slice", &onnx_parser::parse_slice);
        add_mem_op("Concat", &onnx_parser::parse_concat);
+        add_mem_op("Gather", &onnx_parser::parse_gather);
+        add_mem_op("Shape", &onnx_parser::parse_shape);
+        add_mem_op("ConstantFill", &onnx_parser::parse_constant_fill);
        add_mem_op("Transpose", &onnx_parser::parse_transpose);
+        add_mem_op("Pad", &onnx_parser::parse_pad);
    }

    template <class F>
    void add_op(std::string name, F f)
+    {
+        ops.emplace(name, [=](auto&&... xs) {
+            return std::vector<instruction_ref>{f(std::forward<decltype(xs)>(xs)...)};
+        });
+    }
+
+    // Multi output op
+    template <class F>
+    void add_multi_op(std::string name, F f)
    {
        ops.emplace(name, f);
    }
@@ -90,24 +107,23 @@ struct onnx_parser
    template <class F>
    void add_mem_op(std::string name, F f)
    {
-        ops.emplace(name, [=](auto&&... xs) {
+        add_op(name, [=](auto&&... xs) {
            return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
        });
    }

    template <class T>
-    void add_broadcastable_binary_op(std::string name, T x)
+    void add_binary_op(std::string name, T x)
    {
-        ops.emplace(name, [this, x](attribute_map attributes, std::vector<instruction_ref> args) {
+        add_op(name, [this, x](attribute_map attributes, std::vector<instruction_ref> args) {
            if(args.size() != 2)
-                MIGRAPH_THROW("binary operators should have 2 operands");
-            if(contains(attributes, "broadcast"))
+                MIGRAPHX_THROW("binary operators should have 2 operands");
+            if(contains(attributes, "broadcast") and contains(attributes, "axis"))
            {
                uint64_t broadcasted = parse_value(attributes.at("broadcast")).at<uint64_t>();
                if(broadcasted != 0)
                {
-                    uint64_t axis = (contains(attributes, "axis"))
-                                        ? parse_value(attributes.at("axis")).at<uint64_t>()
-                                        : 0;
+                    uint64_t axis = parse_value(attributes.at("axis")).at<uint64_t>();
                    auto l =
                        prog.add_instruction(op::broadcast{axis, args[0]->get_shape()}, args[1]);
                    return prog.add_instruction(x, args[0], l);
@@ -116,51 +132,76 @@ struct onnx_parser
                }
            }
            else
            {
-                // Example:
-                // s0 = (3,2,4,5) and s1 = (2,1,1)
-                //
-                // In this case we need to broadcast (:,1,1) portion of
-                // s1 plus broadcast the 1st dimension of s1
-                // giving output_lens = (3,2,4,5)
-                //
-                // Another example:
-                // s0 = (3,2,1,5) and s1 = (2,7,5)
-                // In this case we need to broadcast the (:,:,1:,:) axis
-                // of s0 plus the 1st dimension of s1 giving
-                // output_lens = (3,2,7,5)
-                //
-                // Get lengths for both arguments
-                const std::vector<std::size_t>* s0 = &args[0]->get_shape().lens();
-                const std::vector<std::size_t>* s1 = &args[1]->get_shape().lens();
-                // Make sure s0 is the smaller size
-                if(s0->size() > s1->size())
-                    std::swap(s0, s1);
-                // Copy the larger vector to output_lens
-                std::vector<std::size_t> output_lens(s1->size());
-                auto offset = s1->size() - s0->size();
-                std::transform(s0->begin(),
-                               s0->end(),
-                               s1->begin() + offset,
-                               output_lens.begin() + offset,
-                               [](auto a, auto b) { return std::max(a, b); });
-                auto l0 = prog.add_instruction(op::multibroadcast{output_lens}, args[0]);
-                auto l1 = prog.add_instruction(op::multibroadcast{output_lens}, args[1]);
-                return prog.add_instruction(x, l0, l1);
+                return add_broadcastable_binary_op(args[0], args[1], x);
            }
        });
    }
+    template <class T>
+    instruction_ref add_broadcastable_binary_op(instruction_ref arg0, instruction_ref arg1, T x)
+    {
+        if(arg0->get_shape() != arg1->get_shape())
+        {
+            // Example:
+            // s0 = (3,2,4,5) and s1 = (2,1,1)
+            //
+            // In this case we need to broadcast (:,1,1) portion of
+            // s1 plus broadcast the 1st dimension of s1
+            // giving output_lens = (3,2,4,5)
+            //
+            // Another example:
+            // s0 = (3,2,1,5) and s1 = (2,7,5)
+            // In this case we need to broadcast the (:,:,1:,:) axis
+            // of s0 plus the 1st dimension of s1 giving
+            // output_lens = (3,2,7,5)
+            //
+            // Get lengths for both arguments
+            const std::vector<std::size_t>* s0 = &arg0->get_shape().lens();
+            const std::vector<std::size_t>* s1 = &arg1->get_shape().lens();
+            // Make sure s0 is the smaller size
+            if(s0->size() > s1->size())
+                std::swap(s0, s1);
+            std::vector<std::size_t> output_lens(*s1);
+            auto offset = s1->size() - s0->size();
+            std::transform(s0->begin(),
+                           s0->end(),
+                           s1->begin() + offset,
+                           output_lens.begin() + offset,
+                           [](auto a, auto b) { return std::max(a, b); });
+            auto l0 = prog.add_instruction(op::multibroadcast{output_lens}, arg0);
+            auto l1 = prog.add_instruction(op::multibroadcast{output_lens}, arg1);
+            return prog.add_instruction(x, l0, l1);
+        }
+        else
+        {
+            return prog.add_instruction(x, {arg0, arg1});
+        }
+    }
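
// [Editor's sketch, not part of this commit] The comment in
// add_broadcastable_binary_op describes NumPy-style broadcasting: align the two
// length vectors at their trailing dimensions and take the per-axis maximum. A
// standalone version of just that output-length computation:

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

std::vector<std::size_t> broadcast_lens(std::vector<std::size_t> s0, std::vector<std::size_t> s1)
{
    // Make sure s0 is the lower-rank (or equal) operand.
    if(s0.size() > s1.size())
        std::swap(s0, s1);
    std::vector<std::size_t> out(s1);
    auto offset = s1.size() - s0.size();
    // Align trailing axes and take the max length per axis.
    std::transform(s0.begin(), s0.end(), s1.begin() + offset, out.begin() + offset,
                   [](auto a, auto b) { return std::max(a, b); });
    return out;
}

int main()
{
    // (3,2,1,5) with (2,7,5) -> (3,2,7,5), matching the second example above.
    for(auto l : broadcast_lens({3, 2, 1, 5}, {2, 7, 5}))
        std::cout << l << " ";
    std::cout << "\n";
}
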
    template <class T>
    void add_generic_op(std::string name, T x)
    {
-        ops.emplace(name, [this, x](attribute_map, std::vector<instruction_ref> args) {
+        add_op(name, [this, x](attribute_map, std::vector<instruction_ref> args) {
            return prog.add_instruction(x, args);
        });
    }

+    template <class T>
+    void add_variadic_op(std::string name, T x)
+    {
+        add_op(name, [this, x](attribute_map, std::vector<instruction_ref> args) {
+            return std::accumulate(std::next(args.begin()),
+                                   args.end(),
+                                   args.front(),
+                                   [this, x](instruction_ref a, instruction_ref b) {
+                                       return add_broadcastable_binary_op(a, b, x);
+                                   });
+        });
+    }
    instruction_ref
    parse_softmax(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
    {
@@ -175,9 +216,30 @@ struct onnx_parser
    parse_conv(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        op::convolution op;
+        auto l0 = args[0];
        if(contains(attributes, "pads"))
        {
-            copy(attributes["pads"].ints(), op.padding.begin());
+            if(contains(attributes, "auto_pad"))
+            {
+                MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously");
+            }
+            std::vector<std::int64_t> padding;
+            copy(attributes["pads"].ints(), std::back_inserter(padding));
+            if(padding.size() != 4)
+            {
+                MIGRAPHX_THROW("padding should have 4 values");
+            }
+            if(padding[0] != padding[2] || padding[1] != padding[3])
+            {
+                // insert zeros for pad op (args[0] has 4 dims)
+                padding = {0, 0, padding[0], padding[1], 0, 0, padding[2], padding[3]};
+                l0      = prog.add_instruction(op::pad{padding}, l0);
+            }
+            else
+            {
+                op.padding[0] = padding[0];
+                op.padding[1] = padding[1];
+            }
        }
        if(contains(attributes, "strides"))
        {
@@ -187,6 +249,23 @@ struct onnx_parser
        {
            copy(attributes["dilations"].ints(), op.dilation.begin());
        }
+        if(contains(attributes, "auto_pad"))
+        {
+            auto s = attributes["auto_pad"].s();
+            if(contains(attributes, "pads") and to_upper(s) != "NOTSET")
+            {
+                MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously");
+            }
+            if(s.find("SAME") != std::string::npos)
+            {
+                op.padding_mode = op::padding_mode_t::same;
+            }
+        }
+        if(contains(attributes, "group"))
+        {
+            op.group = parse_value(attributes.at("group")).at<int>();
+        }
        if(args.size() == 3)
        {
            uint64_t axis = 1;
@@ -194,7 +273,7 @@ struct onnx_parser
            auto l2 = prog.add_instruction(op::broadcast{axis, l1->get_shape()}, args[2]);
            return prog.add_instruction(op::add{}, l1, l2);
        }
-        return prog.add_instruction(op, args);
+        return prog.add_instruction(op, l0, args[1]);
    }
    instruction_ref parse_pooling(const std::string& name,
@@ -202,6 +281,7 @@ struct onnx_parser
                                  std::vector<instruction_ref> args)
    {
        op::pooling op{ends_with(name, "MaxPool") ? "max" : "average"};
+        auto l0 = args[0];
        if(starts_with(name, "Global"))
        {
            auto lens = args.front()->get_shape().lens();
@@ -209,7 +289,23 @@ struct onnx_parser
        }
        if(contains(attributes, "pads"))
        {
-            copy(attributes["pads"].ints(), op.padding.begin());
+            std::vector<std::int64_t> padding;
+            copy(attributes["pads"].ints(), std::back_inserter(padding));
+            if(padding.size() != 4)
+            {
+                MIGRAPHX_THROW("padding should have 4 values");
+            }
+            if(padding[0] != padding[2] || padding[1] != padding[3])
+            {
+                // insert zeros for pad op (args[0] has 4 dims)
+                padding = {0, 0, padding[0], padding[1], 0, 0, padding[2], padding[3]};
+                l0      = prog.add_instruction(op::pad{padding}, l0);
+            }
+            else
+            {
+                op.padding[0] = padding[0];
+                op.padding[1] = padding[1];
+            }
        }
        if(contains(attributes, "strides"))
        {
@@ -219,7 +315,17 @@ struct onnx_parser
        {
            copy(attributes["kernel_shape"].ints(), op.lengths.begin());
        }
-        return prog.add_instruction(op, std::move(args));
+        if(contains(attributes, "auto_pad"))
+        {
+            auto s = attributes["auto_pad"].s();
+            if(s.find("SAME_UPPER") == std::string::npos)
+            {
+                MIGRAPHX_THROW("auto_pad only supports SAME_UPPER for pooling");
+            }
+            op.padding_mode = op::padding_mode_t::same;
+        }
+        return prog.add_instruction(op, l0);
    }
    instruction_ref
@@ -242,7 +348,7 @@ struct onnx_parser
    instruction_ref
    parse_flatten(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
-        uint64_t axis = 0;
+        uint64_t axis = 1;
        if(contains(attributes, "axis"))
        {
            axis = parse_value(attributes.at("axis")).at<int>();
@@ -276,6 +382,18 @@ struct onnx_parser
        return prog.add_instruction(op, std::move(args));
    }
+    instruction_ref
+    parse_gather(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
+    {
+        int axis = 0;
+        if(contains(attributes, "axis"))
+        {
+            axis = parse_value(attributes.at("axis")).at<int>();
+        }
+        op::gather op{axis};
+        return prog.add_instruction(op, std::move(args));
+    }

    instruction_ref
    parse_slice(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
@@ -308,7 +426,7 @@ struct onnx_parser
    parse_gemm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
        float alpha = 1.0f;
-        float beta  = 0.0f;
+        float beta  = 1.0f;
        bool transa = false;
        bool transb = false;
        if(contains(attributes, "alpha"))
@@ -317,7 +435,7 @@ struct onnx_parser
        }
        if(contains(attributes, "beta"))
        {
-            alpha = parse_value(attributes.at("beta")).at<float>();
+            beta = parse_value(attributes.at("beta")).at<float>();
        }
        if(contains(attributes, "transA"))
        {
@@ -332,10 +450,20 @@ struct onnx_parser
        auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1];
        if(args.size() == 3)
        {
-            uint64_t axis = 1;
-            auto l3 = prog.add_instruction(op::dot{alpha, beta}, l1, l2);
-            auto l4 = prog.add_instruction(op::broadcast{axis, l3->get_shape()}, args[2]);
-            return prog.add_instruction(op::add{}, l3, l4);
+            if(beta != 0.f)
+            {
+                auto l3 = prog.add_instruction(op::dot{alpha}, l1, l2);
+                auto l4 = args[2];
+                if(l4->get_shape().scalar()) // ignore args[2] (no C value added to alpha*A*B)
+                    return l3;
+                if(beta != 1.f)
+                {
+                    auto beta_val = prog.add_literal(beta);
+                    auto l5 = prog.add_instruction(op::scalar{args[2]->get_shape()}, beta_val);
+                    l4      = prog.add_instruction(op::mul{}, args[2], l5);
+                }
+                return add_broadcastable_binary_op(l3, l4, op::add{});
+            }
        }
        return prog.add_instruction(op::dot{alpha, beta}, l1, l2);
    }
@@ -383,6 +511,18 @@ struct onnx_parser
        return prog.add_instruction(op, args.front());
    }

+    instruction_ref
+    parse_elu(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
+    {
+        float alpha = 1.0; // default alpha val for elu
+        if(contains(attributes, "alpha"))
+        {
+            alpha = parse_value(attributes.at("alpha")).at<float>();
+        }
+        op::elu op{alpha};
+        return prog.add_instruction(op, args.front());
+    }
    instruction_ref
    parse_lrn(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
    {
@@ -422,12 +562,12 @@ struct onnx_parser
        auto scale_val = prog.add_literal(scale);
        auto bias_vals = prog.add_literal(
-            migraph::literal{migraph::shape{migraph::shape::float_type, {bias.size()}}, bias});
+            migraphx::literal{migraphx::shape{migraphx::shape::float_type, {bias.size()}}, bias});

-        auto scale_tensor = prog.add_instruction(migraph::op::scalar{input_shape}, scale_val);
-        auto img_scaled   = prog.add_instruction(migraph::op::mul{}, args.front(), scale_tensor);
-        auto bias_bcast   = prog.add_instruction(migraph::op::broadcast{1, input_shape}, bias_vals);
-        return prog.add_instruction(migraph::op::add{}, img_scaled, bias_bcast);
+        auto scale_tensor = prog.add_instruction(migraphx::op::scalar{input_shape}, scale_val);
+        auto img_scaled   = prog.add_instruction(migraphx::op::mul{}, args.front(), scale_tensor);
+        auto bias_bcast   = prog.add_instruction(migraphx::op::broadcast{1, input_shape}, bias_vals);
+        return prog.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast);
    }

    instruction_ref
@@ -439,7 +579,122 @@ struct onnx_parser
            auto&& perm_vals = attributes["perm"].ints();
            perm             = std::vector<int64_t>(perm_vals.begin(), perm_vals.end());
        }
-        return prog.add_instruction(migraph::op::transpose{perm}, args.front());
+        return prog.add_instruction(migraphx::op::transpose{perm}, args.front());
    }
+    instruction_ref
+    parse_pad(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
+    {
+        std::vector<int64_t> pads{};
+        float value = 0.0f;
+        if(contains(attributes, "pads"))
+        {
+            auto&& pad_vals = attributes["pads"].ints();
+            pads            = std::vector<int64_t>(pad_vals.begin(), pad_vals.end());
+        }
+        if(contains(attributes, "value"))
+        {
+            value = parse_value(attributes.at("value")).at<float>();
+        }
+        if(contains(attributes, "mode"))
+        {
+            auto mode = attributes.at("mode").s();
+            if(mode != "constant")
+                MIGRAPHX_THROW("migraphx currently only supports constant padding");
+        }
+        return prog.add_instruction(migraphx::op::pad{pads, value}, args.front());
+    }

+    // Use a literal instruction to replace Shape, since the output of the shape
+    // operator is a literal in migraphx
+    instruction_ref
+    parse_shape(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
+    {
+        if(args.size() != 1)
+            MIGRAPHX_THROW("Shape: operator should have 1 operand");
+        std::vector<std::size_t> arg_shape = args[0]->get_shape().lens();
+        std::vector<int64_t> vec_shape(arg_shape.size());
+        migraphx::shape s(migraphx::shape::int64_type, {arg_shape.size()});
+        std::transform(arg_shape.begin(), arg_shape.end(), vec_shape.begin(), [](auto i) {
+            return int64_t(i);
+        });
+        return prog.add_literal(migraphx::literal{s, vec_shape});
+    }

+    // Use a literal instruction to replace the ConstantFill operator. In RNNs, the
+    // input shape and value are fixed, so there is no need to do the actual
+    // computation for ConstantFill.
+    instruction_ref parse_constant_fill(const std::string&,
+                                        attribute_map attributes,
+                                        std::vector<instruction_ref> args)
+    {
+        int input_as_shape = 0;
+        int dtype          = 1;
+        float value        = 0.0f;
+        if(contains(attributes, "dtype"))
+        {
+            dtype = parse_value(attributes.at("dtype")).at<int>();
+        }
+        migraphx::shape::type_t type = get_type(dtype);
+        if(contains(attributes, "input_as_shape"))
+        {
+            input_as_shape = parse_value(attributes.at("input_as_shape")).at<int>();
+        }
+        if(contains(attributes, "value"))
+        {
+            value = parse_value(attributes.at("value")).at<float>();
+        }
+        if(contains(attributes, "extra_shape"))
+        {
+            MIGRAPHX_THROW("ConstantFill: cannot handle extra shape attribute");
+        }
+        if(input_as_shape == 1)
+        {
+            if(args.size() != 1)
+            {
+                MIGRAPHX_THROW("ConstantFill: need an input argument as output shape");
+            }
+            if(contains(attributes, "shape"))
+            {
+                MIGRAPHX_THROW("ConstantFill: cannot set the shape argument and pass in an input "
+                               "at the same time");
+            }
+            migraphx::argument in = args[0]->eval();
+            if(in.empty())
+            {
+                MIGRAPHX_THROW("ConstantFill: cannot handle dynamic shape as input");
+            }
+            std::vector<std::size_t> dims;
+            in.visit([&](auto input) { dims.assign(input.begin(), input.end()); });
+            migraphx::shape s(type, dims);
+            std::vector<float> values(s.elements(), value);
+            return prog.add_literal(migraphx::literal(s, values));
+        }
+        else if(input_as_shape == 0)
+        {
+            if(!contains(attributes, "shape"))
+            {
+                MIGRAPHX_THROW("ConstantFill: attribute output shape is needed");
+            }
+            literal ls = parse_value(attributes.at("shape"));
+            std::vector<std::size_t> dims;
+            ls.visit([&](auto s) { dims.assign(s.begin(), s.end()); });
+            migraphx::shape s{type, dims};
+            std::vector<float> values(s.elements(), value);
+            return prog.add_literal(migraphx::literal(s, values));
+        }
+        else
+        {
+            MIGRAPHX_THROW("ConstantFill: wrong value of attribute input_as_shape");
+        }
+    }
    void parse_from(std::istream& is)
@@ -454,7 +709,7 @@ struct onnx_parser
        }
        else
        {
-            throw std::runtime_error("Failed reading");
+            MIGRAPHX_THROW("Failed reading onnx file.");
        }
    }
@@ -484,14 +739,14 @@ struct onnx_parser
        }
        for(auto&& p : nodes)
        {
-            this->parse_node(get_name(p.second));
+            this->parse_node(p.first);
        }
    }
    void parse_node(const std::string& name)
    {
        if(name.empty())
-            MIGRAPH_THROW("Onnx node must have a name");
+            MIGRAPHX_THROW("Onnx node must have a name");
        if(instructions.count(name) == 0)
        {
            auto&& node = nodes.at(name);
@@ -500,23 +755,37 @@ struct onnx_parser
            {
                if(nodes.count(input) > 0)
                {
-                    auto&& iname = get_name(nodes.at(input));
-                    assert(name != iname);
-                    this->parse_node(iname);
-                    args.push_back(instructions.at(iname));
+                    assert(name != input);
+                    this->parse_node(input);
+                    args.push_back(instructions.at(input));
                }
                else
                {
                    args.push_back(instructions.at(input));
                }
            }
+            std::vector<instruction_ref> result;
            if(ops.count(node.op_type()) == 0)
            {
-                instructions[name] = prog.add_instruction(unknown{node.op_type()}, args);
+                result.push_back(prog.add_instruction(unknown{node.op_type()}, args));
+            }
+            else
+            {
+                result = ops[node.op_type()](get_attributes(node), args);
+            }
+            // Even nodes with no outputs produce an output in migraphx
+            if(node.output().empty() and result.size() == 1)
+            {
+                instructions[name] = result.front();
            }
            else
            {
-                instructions[name] = ops[node.op_type()](get_attributes(node), args);
+                assert(node.output().size() >= result.size());
+                std::transform(result.begin(),
+                               result.end(),
+                               node.output().begin(),
+                               std::inserter(instructions, instructions.end()),
+                               [](auto&& x, auto&& y) { return std::make_pair(y, x); });
            }
        }
    }
@@ -531,25 +800,24 @@ struct onnx_parser
        return result;
    }

-    static std::string get_name(const onnx::NodeProto& node)
-    {
-        if(node.name().empty())
-        {
-            std::string generated = "migraph_unnamed_node";
-            return std::accumulate(node.output().begin(),
-                                   node.output().end(),
-                                   generated,
-                                   [](auto x, auto y) { return x + "_" + y; });
-        }
-        return node.name();
-    }

    static node_map get_nodes(const onnx::GraphProto& graph)
    {
        std::unordered_map<std::string, onnx::NodeProto> result;
+        std::size_t n = 0;
        for(auto&& node : graph.node())
        {
-            result[get_name(node)] = node;
+            if(node.output().empty())
+            {
+                if(node.name().empty())
+                {
+                    result["migraphx_unamed_node_" + std::to_string(n)] = node;
+                    n++;
+                }
+                else
+                {
+                    result[node.name()] = node;
+                }
+            }
            for(auto&& output : node.output())
            {
                result[output] = node;
@@ -581,12 +849,17 @@ struct onnx_parser
        case onnx::AttributeProto::TENSORS: return {};
        case onnx::AttributeProto::GRAPHS: return {};
        }
-        MIGRAPH_THROW("Invalid attribute type");
+        MIGRAPHX_THROW("Invalid attribute type");
    }

    static literal parse_tensor(const onnx::TensorProto& t)
    {
        std::vector<std::size_t> dims(t.dims().begin(), t.dims().end());
+        // in case of scalar constants in onnx file, use dims=1 to fill initializer data
+        if(dims.empty())
+        {
+            dims = {1};
+        }
        if(t.has_raw_data())
        {
            const std::string& s = t.raw_data();
@@ -609,7 +882,7 @@ struct onnx_parser
            case onnx::TensorProto::COMPLEX64: throw std::runtime_error("");
            case onnx::TensorProto::COMPLEX128: throw std::runtime_error("");
            }
-            MIGRAPH_THROW("Invalid tensor type");
+            MIGRAPHX_THROW("Invalid tensor type");
        }
        switch(t.data_type())
        {
@@ -640,7 +913,7 @@ struct onnx_parser
        case onnx::TensorProto::COMPLEX64: throw std::runtime_error("");
        case onnx::TensorProto::COMPLEX128: throw std::runtime_error("");
        }
-        MIGRAPH_THROW("Invalid tensor type");
+        MIGRAPHX_THROW("Invalid tensor type");
    }

    static shape parse_type(const onnx::TypeProto& t)
@@ -686,6 +959,28 @@ struct onnx_parser
        });
        return {shape_type, dims};
    }
+    shape::type_t get_type(int dtype)
+    {
+        switch(dtype)
+        {
+        case 1: return shape::float_type;
+        case 2: return shape::uint8_type;
+        case 3: return shape::int8_type;
+        case 4: return shape::uint16_type;
+        case 5: return shape::int16_type;
+        case 6: return shape::int32_type;
+        case 7: return shape::int64_type;
+        case 10: return shape::half_type;
+        case 11: return shape::double_type;
+        case 12: return shape::uint32_type;
+        case 13: return shape::uint64_type;
+        default:
+        {
+            MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported");
+        }
+        }
+    }
};

program parse_onnx(const std::string& name)
@@ -709,5 +1004,5 @@ program parse_onnx(const std::string& name)
    return std::move(parser.prog);
}

-} // namespace MIGRAPH_INLINE_NS
-} // namespace migraph
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
-#include <migraph/onnx.hpp>
-#include <migraph/gpu/target.hpp>
-#include <migraph/gpu/hip.hpp>
-#include <migraph/generate.hpp>
-#include <migraph/verify.hpp>
+#include <migraphx/onnx.hpp>
+#include <migraphx/gpu/target.hpp>
+#include <migraphx/gpu/hip.hpp>
+#include <migraphx/generate.hpp>
+#include <migraphx/verify.hpp>

-migraph::program::parameter_map create_param_map(const migraph::program& p, bool gpu = true)
+migraphx::program::parameter_map create_param_map(const migraphx::program& p, bool gpu = true)
{
-    migraph::program::parameter_map m;
+    migraphx::program::parameter_map m;
    for(auto&& x : p.get_parameter_shapes())
    {
        if(gpu)
-            m[x.first] = migraph::gpu::to_gpu(migraph::generate_argument(x.second));
+            m[x.first] = migraphx::gpu::to_gpu(migraphx::generate_argument(x.second));
        else
-            m[x.first] = migraph::generate_argument(x.second);
+            m[x.first] = migraphx::generate_argument(x.second);
    }
    return m;
}
@@ -25,9 +25,9 @@ int main(int argc, char const* argv[])
{
    std::string file = argv[1];
    std::size_t n    = argc > 2 ? std::stoul(argv[2]) : 50;
-    auto p           = migraph::parse_onnx(file);
+    auto p           = migraphx::parse_onnx(file);
    std::cout << "Compiling ... " << std::endl;
-    p.compile(migraph::gpu::target{});
+    p.compile(migraphx::gpu::target{});
    std::cout << "Allocating params ... " << std::endl;
    auto m = create_param_map(p);
    std::cout << "Running performance report ... " << std::endl;
...