Commit bb390b65 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'activationOperators' into addLogExpOperators

parents adfa1a93 682ad83a
#ifndef MIGRAPH_GUARD_RTGLIB_REFLECT_HPP
#define MIGRAPH_GUARD_RTGLIB_REFLECT_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_REFLECT_HPP
#define MIGRAPHX_GUARD_RTGLIB_REFLECT_HPP
#include <migraphx/functional.hpp>
#include <migraphx/rank.hpp>
......@@ -7,7 +7,7 @@
#include <functional>
namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
inline namespace MIGRAPHX_INLINE_NS {
namespace detail {
......@@ -47,7 +47,7 @@ void reflect_each(T& x, F f)
});
}
} // namespace MIGRAPH_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_REQUIRES_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_REQUIRES_HPP
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_REQUIRES_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_REQUIRES_HPP
#include <type_traits>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
inline namespace MIGRAPHX_INLINE_NS {
template <bool... Bs>
struct and_ : std::is_same<and_<Bs...>, and_<(Bs || true)...>> // NOLINT
......@@ -24,29 +24,29 @@ struct requires_enum
};
};
#define MIGRAPH_REQUIRES_CAT(x, y) x##y
#define MIGRAPHX_REQUIRES_CAT(x, y) x##y
#ifdef CPPCHECK
#define MIGRAPH_REQUIRES(...) class = void
#define MIGRAPHX_REQUIRES(...) class = void
#else
#if 0
// TODO: This currently crashed on clang
#define MIGRAPH_REQUIRES(...) \
typename migraphx::requires_enum<__LINE__>::e MIGRAPH_REQUIRES_CAT( \
PrivateRequires, \
__LINE__) = migraphx::requires_enum<__LINE__>::a, \
class = typename std::enable_if<and_<__VA_ARGS__, \
MIGRAPH_REQUIRES_CAT(PrivateRequires, __LINE__) == \
#define MIGRAPHX_REQUIRES(...) \
typename migraphx::requires_enum<__LINE__>::e MIGRAPHX_REQUIRES_CAT( \
PrivateRequires, \
__LINE__) = migraphx::requires_enum<__LINE__>::a, \
class = typename std::enable_if<and_<__VA_ARGS__, \
MIGRAPHX_REQUIRES_CAT(PrivateRequires, __LINE__) == \
migraphx::requires_enum<__LINE__>::a>{}>::type
#else
#define MIGRAPH_REQUIRES(...) \
typename migraphx::requires_enum<__LINE__>::e MIGRAPH_REQUIRES_CAT( \
#define MIGRAPHX_REQUIRES(...) \
typename migraphx::requires_enum<__LINE__>::e MIGRAPHX_REQUIRES_CAT( \
PrivateRequires, __LINE__) = migraphx::requires_enum<__LINE__>::a, \
class = typename std::enable_if<and_<__VA_ARGS__>{}>::type
#endif
#endif
} // namespace MIGRAPH_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_SHAPE_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_SHAPE_HPP
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_SHAPE_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_SHAPE_HPP
#include <vector>
#include <cassert>
......@@ -12,7 +12,7 @@
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
inline namespace MIGRAPHX_INLINE_NS {
struct shape_impl;
......@@ -21,7 +21,7 @@ struct shape
// Add new types here
// clang-format off
#define MIGRAPH_SHAPE_VISIT_TYPES(m) \
#define MIGRAPHX_SHAPE_VISIT_TYPES(m) \
m(half_type, half) \
m(float_type, float) \
m(double_type, double) \
......@@ -35,22 +35,22 @@ struct shape
m(uint64_type, uint64_t)
// clang-format on
#define MIGRAPH_SHAPE_ENUM_TYPES(x, t) x,
#define MIGRAPHX_SHAPE_ENUM_TYPES(x, t) x,
enum type_t
{
MIGRAPH_SHAPE_VISIT_TYPES(MIGRAPH_SHAPE_ENUM_TYPES)
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_ENUM_TYPES)
};
#undef MIGRAPH_SHAPE_ENUM_TYPES
#undef MIGRAPHX_SHAPE_ENUM_TYPES
template <class T, class = void>
struct get_type;
#define MIGRAPH_SHAPE_GET_TYPE(x, t) \
#define MIGRAPHX_SHAPE_GET_TYPE(x, t) \
template <class T> \
struct get_type<t, T> : std::integral_constant<type_t, x> \
{ \
};
MIGRAPH_SHAPE_VISIT_TYPES(MIGRAPH_SHAPE_GET_TYPE)
#undef MIGRAPH_SHAPE_GET_TYPE
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GET_TYPE)
#undef MIGRAPHX_SHAPE_GET_TYPE
template <class T>
struct get_type<const T> : get_type<T>
......@@ -148,12 +148,12 @@ struct shape
{
switch(this->type())
{
#define MIGRAPH_SHAPE_VISITOR_CASE(x, t) \
#define MIGRAPHX_SHAPE_VISITOR_CASE(x, t) \
case x: v(as<t>()); return;
MIGRAPH_SHAPE_VISIT_TYPES(MIGRAPH_SHAPE_VISITOR_CASE)
#undef MIGRAPH_SHAPE_VISITOR_CASE
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_VISITOR_CASE)
#undef MIGRAPHX_SHAPE_VISITOR_CASE
}
MIGRAPH_THROW("Unknown type");
MIGRAPHX_THROW("Unknown type");
}
private:
......@@ -163,7 +163,7 @@ struct shape
std::string type_string() const;
};
} // namespace MIGRAPH_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_SHAPE_FOR_EACH_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_SHAPE_FOR_EACH_HPP
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_SHAPE_FOR_EACH_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_SHAPE_FOR_EACH_HPP
#include <migraphx/shape.hpp>
#include <migraphx/config.hpp>
#include <algorithm>
namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
inline namespace MIGRAPHX_INLINE_NS {
template <class F>
void shape_for_each(const migraphx::shape& s, F f)
......@@ -28,7 +28,7 @@ void shape_for_each(const migraphx::shape& s, F f)
}
}
} // namespace MIGRAPH_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPH_GUARD_RTGLIB_SIMPLIFY_ALGEBRA_HPP
#define MIGRAPH_GUARD_RTGLIB_SIMPLIFY_ALGEBRA_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_SIMPLIFY_ALGEBRA_HPP
#define MIGRAPHX_GUARD_RTGLIB_SIMPLIFY_ALGEBRA_HPP
#include <string>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
......@@ -18,7 +18,7 @@ struct simplify_algebra
void apply(program& p) const;
};
} // namespace MIGRAPH_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPH_GUARD_RTGLIB_SIMPLIFY_RESHAPES_HPP
#define MIGRAPH_GUARD_RTGLIB_SIMPLIFY_RESHAPES_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_SIMPLIFY_RESHAPES_HPP
#define MIGRAPHX_GUARD_RTGLIB_SIMPLIFY_RESHAPES_HPP
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
......@@ -19,7 +19,7 @@ struct simplify_reshapes
void apply(program& p) const;
};
} // namespace MIGRAPH_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPH_GUARD_STREAMUTILS_HPP
#define MIGRAPH_GUARD_STREAMUTILS_HPP
#ifndef MIGRAPHX_GUARD_STREAMUTILS_HPP
#define MIGRAPHX_GUARD_STREAMUTILS_HPP
#include <ostream>
#include <algorithm>
......@@ -7,7 +7,7 @@
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
inline namespace MIGRAPHX_INLINE_NS {
template <class T>
struct stream_range_container
......@@ -56,7 +56,7 @@ void stream_write_value(std::ostream& os, const T& x)
detail::stream_write_value_impl(rank<1>{}, os, x);
}
} // namespace MIGRAPH_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_STRINGUTILS_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_STRINGUTILS_HPP
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_STRINGUTILS_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_STRINGUTILS_HPP
#include <algorithm>
#include <numeric>
......@@ -8,7 +8,7 @@
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
inline namespace MIGRAPHX_INLINE_NS {
inline std::string
replace_string(std::string subject, const std::string& search, const std::string& replace)
......@@ -87,7 +87,7 @@ inline std::string to_string(const T& x)
return ss.str();
}
} // namespace MIGRAPH_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_TARGET_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_TARGET_HPP
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_TARGET_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_TARGET_HPP
#include <cassert>
#include <string>
......@@ -13,7 +13,7 @@
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
inline namespace MIGRAPHX_INLINE_NS {
#ifdef DOXYGEN
......@@ -244,7 +244,7 @@ inline const ValueType& any_cast(const target& x)
#endif
} // namespace MIGRAPH_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPH_GUARD_TENSOR_VIEW_HPP
#define MIGRAPH_GUARD_TENSOR_VIEW_HPP
#ifndef MIGRAPHX_GUARD_TENSOR_VIEW_HPP
#define MIGRAPHX_GUARD_TENSOR_VIEW_HPP
#include <migraphx/shape.hpp>
#include <migraphx/float_equal.hpp>
......@@ -10,7 +10,7 @@
#include <utility>
namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
inline namespace MIGRAPHX_INLINE_NS {
template <class T>
struct tensor_view
......@@ -29,7 +29,7 @@ struct tensor_view
const T* data() const { return this->m_data; }
template <class... Ts, MIGRAPH_REQUIRES(std::is_integral<Ts>{}...)>
template <class... Ts, MIGRAPHX_REQUIRES(std::is_integral<Ts>{}...)>
const T& operator()(Ts... xs) const
{
assert(std::vector<std::size_t>{static_cast<std::size_t>(xs)...} < m_shape.lens());
......@@ -37,7 +37,7 @@ struct tensor_view
return m_data[m_shape.index({static_cast<std::size_t>(xs)...})];
}
template <class... Ts, MIGRAPH_REQUIRES(std::is_integral<Ts>{}...)>
template <class... Ts, MIGRAPHX_REQUIRES(std::is_integral<Ts>{}...)>
T& operator()(Ts... xs)
{
assert(std::vector<std::size_t>{static_cast<std::size_t>(xs)...} < m_shape.lens());
......@@ -45,13 +45,13 @@ struct tensor_view
return m_data[m_shape.index({static_cast<std::size_t>(xs)...})];
}
template <class Iterator, MIGRAPH_REQUIRES(not std::is_integral<Iterator>{})>
template <class Iterator, MIGRAPHX_REQUIRES(not std::is_integral<Iterator>{})>
const T& operator()(Iterator start, Iterator last) const
{
return m_data[m_shape.index(start, last)];
}
template <class Iterator, MIGRAPH_REQUIRES(not std::is_integral<Iterator>{})>
template <class Iterator, MIGRAPHX_REQUIRES(not std::is_integral<Iterator>{})>
T& operator()(Iterator start, Iterator last)
{
return m_data[m_shape.index(start, last)];
......@@ -169,7 +169,7 @@ tensor_view<T> make_view(shape s, T* data)
return {s, data};
}
} // namespace MIGRAPH_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPH_GUARD_RTGLIB_TIME_HPP
#define MIGRAPH_GUARD_RTGLIB_TIME_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_TIME_HPP
#define MIGRAPHX_GUARD_RTGLIB_TIME_HPP
#include <chrono>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
inline namespace MIGRAPHX_INLINE_NS {
template <class Duration, class F>
auto time(F f)
......@@ -16,7 +16,7 @@ auto time(F f)
return std::chrono::duration_cast<Duration>(finish - start).count();
}
} // namespace MIGRAPH_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPH_GUARD_RTGLIB_TRACER_HPP
#define MIGRAPH_GUARD_RTGLIB_TRACER_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_TRACER_HPP
#define MIGRAPHX_GUARD_RTGLIB_TRACER_HPP
#include <ostream>
#include <migraphx/functional.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
inline namespace MIGRAPHX_INLINE_NS {
struct tracer
{
......@@ -30,7 +30,7 @@ struct tracer
std::ostream* os = nullptr;
};
} // namespace MIGRAPH_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPH_GUARD_RTGLIB_TYPE_NAME_HPP
#define MIGRAPH_GUARD_RTGLIB_TYPE_NAME_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_TYPE_NAME_HPP
#define MIGRAPHX_GUARD_RTGLIB_TYPE_NAME_HPP
#include <string>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
inline namespace MIGRAPHX_INLINE_NS {
template <class PrivateMigraphTypeNameProbe>
const std::string& get_type_name()
......@@ -41,7 +41,7 @@ const std::string& get_type_name(const T&)
return migraphx::get_type_name<T>();
}
} // namespace MIGRAPH_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -5,32 +5,32 @@
file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
==============================================================================*/
#ifndef MIGRAPH_GUARD_RTGLIB_TYPE_TRAITS_HPP
#define MIGRAPH_GUARD_RTGLIB_TYPE_TRAITS_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_TYPE_TRAITS_HPP
#define MIGRAPHX_GUARD_RTGLIB_TYPE_TRAITS_HPP
#include <type_traits>
#include <migraphx/half.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
inline namespace MIGRAPHX_INLINE_NS {
#define MIGRAPH_DETAIL_EXTEND_TRAIT_FOR(trait, T) \
template <class X> \
struct trait : std::trait<X> \
{ \
}; \
\
template <> \
struct trait<T> : std::true_type \
{ \
#define MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(trait, T) \
template <class X> \
struct trait : std::trait<X> \
{ \
}; \
\
template <> \
struct trait<T> : std::true_type \
{ \
};
MIGRAPH_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, half)
MIGRAPH_DETAIL_EXTEND_TRAIT_FOR(is_signed, half)
MIGRAPH_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, half)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, half)
} // namespace MIGRAPH_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPH_GUARD_VERIFY_HPP
#define MIGRAPH_GUARD_VERIFY_HPP
#ifndef MIGRAPHX_GUARD_VERIFY_HPP
#define MIGRAPHX_GUARD_VERIFY_HPP
#include <algorithm>
#include <cmath>
......@@ -11,7 +11,7 @@
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
inline namespace MIGRAPHX_INLINE_NS {
// Compute the value of a range
template <class R>
......@@ -173,6 +173,6 @@ bool verify_range(R1&& r1, R2&& r2, double tolerance = 80, double* out_error = n
return error <= threshold;
}
} // namespace MIGRAPH_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPH_GUARD_RTGLIB_VERIFY_ARGS_HPP
#define MIGRAPH_GUARD_RTGLIB_VERIFY_ARGS_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_VERIFY_ARGS_HPP
#define MIGRAPHX_GUARD_RTGLIB_VERIFY_ARGS_HPP
#include <migraphx/verify.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
inline namespace MIGRAPHX_INLINE_NS {
inline bool verify_args(const std::string& name,
const argument& cpu_arg,
......@@ -84,7 +84,7 @@ inline bool verify_args(const std::string& name,
return passed;
}
} // namespace MIGRAPH_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -3,7 +3,7 @@
#include <migraphx/erase.hpp>
namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
inline namespace MIGRAPHX_INLINE_NS {
instruction::instruction(operation o, shape r, std::vector<instruction_ref> args)
: op(std::move(o)), result(std::move(r)), arguments(std::move(args))
......@@ -183,5 +183,5 @@ shape compute_shape(const operation& op, const std::vector<instruction_ref>& arg
return op.compute_shape(compute_shapes(args));
}
} // namespace MIGRAPH_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -22,7 +22,7 @@ rocm_clang_tidy_check(read_onnx)
target_link_libraries(read_onnx migraphx_onnx)
if(MIGRAPH_ENABLE_GPU)
if(MIGRAPHX_ENABLE_GPU)
add_executable(mnist mnist.cpp)
rocm_clang_tidy_check(mnist)
target_link_libraries(mnist migraphx_cpu migraphx_gpu migraphx_onnx)
......
......@@ -17,7 +17,7 @@
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
inline namespace MIGRAPHX_INLINE_NS {
struct unknown
{
std::string op;
......@@ -43,7 +43,8 @@ struct onnx_parser
using op_func = std::function<instruction_ref(attribute_map, std::vector<instruction_ref>)>;
node_map nodes;
std::unordered_map<std::string, instruction_ref> instructions;
program prog = program();
program prog = program();
bool is_pytorch = false;
std::unordered_map<std::string, op_func> ops;
......@@ -113,7 +114,7 @@ struct onnx_parser
{
ops.emplace(name, [this, x](attribute_map attributes, std::vector<instruction_ref> args) {
if(args.size() != 2)
MIGRAPH_THROW("binary operators should have 2 operands");
MIGRAPHX_THROW("binary operators should have 2 operands");
if(contains(attributes, "broadcast"))
{
uint64_t broadcasted = parse_value(attributes.at("broadcast")).at<uint64_t>();
......@@ -152,8 +153,8 @@ struct onnx_parser
std::swap(s0, s1);
// Copy the larger vector to output_lens
std::vector<std::size_t> output_lens(s1->size());
auto offset = s1->size() - s0->size();
std::vector<std::size_t> output_lens = *s1;
auto offset = s1->size() - s0->size();
std::transform(s0->begin(),
s0->end(),
s1->begin() + offset,
......@@ -195,7 +196,22 @@ struct onnx_parser
op::convolution op;
if(contains(attributes, "pads"))
{
copy(attributes["pads"].ints(), op.padding.begin());
if(contains(attributes, "auto_pad"))
{
MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously");
}
std::vector<std::size_t> padding(4);
copy(attributes["pads"].ints(), padding.begin());
if(padding.size() != 4)
{
MIGRAPHX_THROW("padding should have 4 values");
}
if(padding[0] != padding[2] || padding[1] != padding[3])
{
MIGRAPHX_THROW("migraphx does not support asymetric padding");
}
op.padding[0] = padding[0];
op.padding[1] = padding[1];
}
if(contains(attributes, "strides"))
{
......@@ -205,6 +221,19 @@ struct onnx_parser
{
copy(attributes["dilations"].ints(), op.dilation.begin());
}
if(contains(attributes, "auto_pad"))
{
auto s = attributes["auto_pad"].s();
if(contains(attributes, "pads") and to_upper(s) != "NOTSET")
{
MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously");
}
if(s.find("SAME") != std::string::npos)
{
op.padding_mode = op::convolution::same;
}
}
if(args.size() == 3)
{
uint64_t axis = 1;
......@@ -227,7 +256,18 @@ struct onnx_parser
}
if(contains(attributes, "pads"))
{
copy(attributes["pads"].ints(), op.padding.begin());
std::vector<std::size_t> padding(4);
copy(attributes["pads"].ints(), padding.begin());
if(padding.size() != 4)
{
MIGRAPHX_THROW("padding should have 4 values");
}
if(padding[0] != padding[2] || padding[1] != padding[3])
{
MIGRAPHX_THROW("migraphx does not support asymetric padding");
}
op.padding[0] = padding[0];
op.padding[1] = padding[1];
}
if(contains(attributes, "strides"))
{
......@@ -237,6 +277,15 @@ struct onnx_parser
{
copy(attributes["kernel_shape"].ints(), op.lengths.begin());
}
if(contains(attributes, "auto_pad"))
{
auto s = attributes["auto_pad"].s();
if(to_upper(s) != "NOTSET")
{
MIGRAPHX_THROW("auto_pad is not supported for pooling");
}
}
return prog.add_instruction(op, std::move(args));
}
......@@ -502,7 +551,7 @@ struct onnx_parser
void parse_node(const std::string& name)
{
if(name.empty())
MIGRAPH_THROW("Onnx node must have a name");
MIGRAPHX_THROW("Onnx node must have a name");
if(instructions.count(name) == 0)
{
auto&& node = nodes.at(name);
......@@ -592,7 +641,7 @@ struct onnx_parser
case onnx::AttributeProto::TENSORS: return {};
case onnx::AttributeProto::GRAPHS: return {};
}
MIGRAPH_THROW("Invalid attribute type");
MIGRAPHX_THROW("Invalid attribute type");
}
static literal parse_tensor(const onnx::TensorProto& t)
......@@ -620,7 +669,7 @@ struct onnx_parser
case onnx::TensorProto::COMPLEX64: throw std::runtime_error("");
case onnx::TensorProto::COMPLEX128: throw std::runtime_error("");
}
MIGRAPH_THROW("Invalid tensor type");
MIGRAPHX_THROW("Invalid tensor type");
}
switch(t.data_type())
{
......@@ -651,7 +700,7 @@ struct onnx_parser
case onnx::TensorProto::COMPLEX64: throw std::runtime_error("");
case onnx::TensorProto::COMPLEX128: throw std::runtime_error("");
}
MIGRAPH_THROW("Invalid tensor type");
MIGRAPHX_THROW("Invalid tensor type");
}
static shape parse_type(const onnx::TypeProto& t)
......@@ -720,5 +769,5 @@ program parse_onnx(const std::string& name)
return std::move(parser.prog);
}
} // namespace MIGRAPH_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#ifndef MIGRAPH_GUARD_RTGLIB_COMMON_HEADER_HPP
#define MIGRAPH_GUARD_RTGLIB_COMMON_HEADER_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_COMMON_HEADER_HPP
#define MIGRAPHX_GUARD_RTGLIB_COMMON_HEADER_HPP
#include <migraphx/program.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
......@@ -14,17 +14,17 @@
#include <queue>
namespace migraphx {
inline namespace MIGRAPH_INLINE_NS {
inline namespace MIGRAPHX_INLINE_NS {
//#define MIGRAPH_DEBUG_OPT
//#define MIGRAPHX_DEBUG_OPT
#ifdef MIGRAPH_DEBUG_OPT
#define MIGRAPH_DEBUG(s) s
#ifdef MIGRAPHX_DEBUG_OPT
#define MIGRAPHX_DEBUG(s) s
#else
#define MIGRAPH_DEBUG(s)
#endif // MIGRAPH_DEBUG_OPT
#define MIGRAPHX_DEBUG(s)
#endif // MIGRAPHX_DEBUG_OPT
} // namespace MIGRAPH_INLINE_NS
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPH_GUARD_RTGLIB_COMMON_HEADER_HPP
#endif // MIGRAPHX_GUARD_RTGLIB_COMMON_HEADER_HPP
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment