Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
87d35010
"vllm_flash_attn/utils/generation.py" did not exist on "605655bc66bc0fb11ac10dad2e656a97f3729b5b"
Commit
87d35010
authored
Nov 29, 2018
by
Khalique
Browse files
manual merge
parents
6aebef15
84e7335e
Changes
159
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
212 additions
and
124 deletions
+212
-124
src/include/migraphx/reflect.hpp
src/include/migraphx/reflect.hpp
+4
-4
src/include/migraphx/requires.hpp
src/include/migraphx/requires.hpp
+14
-14
src/include/migraphx/shape.hpp
src/include/migraphx/shape.hpp
+15
-15
src/include/migraphx/shape_for_each.hpp
src/include/migraphx/shape_for_each.hpp
+4
-4
src/include/migraphx/simplify_algebra.hpp
src/include/migraphx/simplify_algebra.hpp
+4
-4
src/include/migraphx/simplify_reshapes.hpp
src/include/migraphx/simplify_reshapes.hpp
+4
-4
src/include/migraphx/streamutils.hpp
src/include/migraphx/streamutils.hpp
+4
-4
src/include/migraphx/stringutils.hpp
src/include/migraphx/stringutils.hpp
+4
-4
src/include/migraphx/target.hpp
src/include/migraphx/target.hpp
+4
-4
src/include/migraphx/tensor_view.hpp
src/include/migraphx/tensor_view.hpp
+8
-8
src/include/migraphx/time.hpp
src/include/migraphx/time.hpp
+4
-4
src/include/migraphx/tracer.hpp
src/include/migraphx/tracer.hpp
+4
-4
src/include/migraphx/type_name.hpp
src/include/migraphx/type_name.hpp
+4
-4
src/include/migraphx/type_traits.hpp
src/include/migraphx/type_traits.hpp
+16
-16
src/include/migraphx/verify.hpp
src/include/migraphx/verify.hpp
+4
-4
src/include/migraphx/verify_args.hpp
src/include/migraphx/verify_args.hpp
+4
-4
src/instruction.cpp
src/instruction.cpp
+2
-2
src/onnx/CMakeLists.txt
src/onnx/CMakeLists.txt
+1
-1
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+98
-10
src/opt/common_header.hpp
src/opt/common_header.hpp
+10
-10
No files found.
src/include/migraphx/reflect.hpp
View file @
87d35010
#ifndef MIGRAPH_GUARD_RTGLIB_REFLECT_HPP
#define MIGRAPH_GUARD_RTGLIB_REFLECT_HPP
#ifndef MIGRAPH
X
_GUARD_RTGLIB_REFLECT_HPP
#define MIGRAPH
X
_GUARD_RTGLIB_REFLECT_HPP
#include <migraphx/functional.hpp>
#include <migraphx/rank.hpp>
...
...
@@ -7,7 +7,7 @@
#include <functional>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
namespace
detail
{
...
...
@@ -47,7 +47,7 @@ void reflect_each(T& x, F f)
});
}
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/requires.hpp
View file @
87d35010
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_REQUIRES_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_REQUIRES_HPP
#ifndef MIGRAPH
X
_GUARD_MIGRAPHLIB_REQUIRES_HPP
#define MIGRAPH
X
_GUARD_MIGRAPHLIB_REQUIRES_HPP
#include <type_traits>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
template
<
bool
...
Bs
>
struct
and_
:
std
::
is_same
<
and_
<
Bs
...
>
,
and_
<
(
Bs
||
true
)...
>>
// NOLINT
...
...
@@ -24,29 +24,29 @@ struct requires_enum
};
};
#define MIGRAPH_REQUIRES_CAT(x, y) x##y
#define MIGRAPH
X
_REQUIRES_CAT(x, y) x##y
#ifdef CPPCHECK
#define MIGRAPH_REQUIRES(...) class = void
#define MIGRAPH
X
_REQUIRES(...) class = void
#else
#if 0
// TODO: This currently crashed on clang
#define MIGRAPH_REQUIRES(...) \
typename migraphx::requires_enum<__LINE__>::e MIGRAPH_REQUIRES_CAT( \
PrivateRequires, \
__LINE__) = migraphx::requires_enum<__LINE__>::a, \
class = typename std::enable_if<and_<__VA_ARGS__, \
MIGRAPH_REQUIRES_CAT(PrivateRequires, __LINE__) == \
#define MIGRAPH
X
_REQUIRES(...) \
typename migraphx::requires_enum<__LINE__>::e MIGRAPH
X
_REQUIRES_CAT( \
PrivateRequires,
\
__LINE__) = migraphx::requires_enum<__LINE__>::a,
\
class = typename std::enable_if<and_<__VA_ARGS__,
\
MIGRAPH
X
_REQUIRES_CAT(PrivateRequires, __LINE__) == \
migraphx::requires_enum<__LINE__>::a>{}>::type
#else
#define MIGRAPH_REQUIRES(...)
\
typename migraphx::requires_enum<__LINE__>::e MIGRAPH_REQUIRES_CAT(
\
#define MIGRAPH
X
_REQUIRES(...) \
typename migraphx::requires_enum<__LINE__>::e MIGRAPH
X
_REQUIRES_CAT( \
PrivateRequires, __LINE__) = migraphx::requires_enum<__LINE__>::a, \
class = typename std::enable_if<and_<__VA_ARGS__>{}>::type
#endif
#endif
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/shape.hpp
View file @
87d35010
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_SHAPE_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_SHAPE_HPP
#ifndef MIGRAPH
X
_GUARD_MIGRAPHLIB_SHAPE_HPP
#define MIGRAPH
X
_GUARD_MIGRAPHLIB_SHAPE_HPP
#include <vector>
#include <cassert>
...
...
@@ -12,7 +12,7 @@
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
struct
shape_impl
;
...
...
@@ -21,7 +21,7 @@ struct shape
// Add new types here
// clang-format off
#define MIGRAPH_SHAPE_VISIT_TYPES(m) \
#define MIGRAPH
X
_SHAPE_VISIT_TYPES(m) \
m(half_type, half) \
m(float_type, float) \
m(double_type, double) \
...
...
@@ -35,22 +35,22 @@ struct shape
m(uint64_type, uint64_t)
// clang-format on
#define MIGRAPH_SHAPE_ENUM_TYPES(x, t) x,
#define MIGRAPH
X
_SHAPE_ENUM_TYPES(x, t) x,
enum
type_t
{
MIGRAPH_SHAPE_VISIT_TYPES
(
MIGRAPH_SHAPE_ENUM_TYPES
)
MIGRAPH
X
_SHAPE_VISIT_TYPES
(
MIGRAPH
X
_SHAPE_ENUM_TYPES
)
};
#undef MIGRAPH_SHAPE_ENUM_TYPES
#undef MIGRAPH
X
_SHAPE_ENUM_TYPES
template
<
class
T
,
class
=
void
>
struct
get_type
;
#define MIGRAPH_SHAPE_GET_TYPE(x, t)
\
#define MIGRAPH
X
_SHAPE_GET_TYPE(x, t) \
template <class T> \
struct get_type<t, T> : std::integral_constant<type_t, x> \
{ \
};
MIGRAPH_SHAPE_VISIT_TYPES
(
MIGRAPH_SHAPE_GET_TYPE
)
#undef MIGRAPH_SHAPE_GET_TYPE
MIGRAPH
X
_SHAPE_VISIT_TYPES
(
MIGRAPH
X
_SHAPE_GET_TYPE
)
#undef MIGRAPH
X
_SHAPE_GET_TYPE
template
<
class
T
>
struct
get_type
<
const
T
>
:
get_type
<
T
>
...
...
@@ -148,12 +148,12 @@ struct shape
{
switch
(
this
->
type
())
{
#define MIGRAPH_SHAPE_VISITOR_CASE(x, t) \
#define MIGRAPH
X
_SHAPE_VISITOR_CASE(x, t) \
case x: v(as<t>()); return;
MIGRAPH_SHAPE_VISIT_TYPES
(
MIGRAPH_SHAPE_VISITOR_CASE
)
#undef MIGRAPH_SHAPE_VISITOR_CASE
MIGRAPH
X
_SHAPE_VISIT_TYPES
(
MIGRAPH
X
_SHAPE_VISITOR_CASE
)
#undef MIGRAPH
X
_SHAPE_VISITOR_CASE
}
MIGRAPH_THROW
(
"Unknown type"
);
MIGRAPH
X
_THROW
(
"Unknown type"
);
}
private:
...
...
@@ -163,7 +163,7 @@ struct shape
std
::
string
type_string
()
const
;
};
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/shape_for_each.hpp
View file @
87d35010
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_SHAPE_FOR_EACH_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_SHAPE_FOR_EACH_HPP
#ifndef MIGRAPH
X
_GUARD_MIGRAPHLIB_SHAPE_FOR_EACH_HPP
#define MIGRAPH
X
_GUARD_MIGRAPHLIB_SHAPE_FOR_EACH_HPP
#include <migraphx/shape.hpp>
#include <migraphx/config.hpp>
#include <algorithm>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
template
<
class
F
>
void
shape_for_each
(
const
migraphx
::
shape
&
s
,
F
f
)
...
...
@@ -28,7 +28,7 @@ void shape_for_each(const migraphx::shape& s, F f)
}
}
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/simplify_algebra.hpp
View file @
87d35010
#ifndef MIGRAPH_GUARD_RTGLIB_SIMPLIFY_ALGEBRA_HPP
#define MIGRAPH_GUARD_RTGLIB_SIMPLIFY_ALGEBRA_HPP
#ifndef MIGRAPH
X
_GUARD_RTGLIB_SIMPLIFY_ALGEBRA_HPP
#define MIGRAPH
X
_GUARD_RTGLIB_SIMPLIFY_ALGEBRA_HPP
#include <string>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
struct
program
;
...
...
@@ -18,7 +18,7 @@ struct simplify_algebra
void
apply
(
program
&
p
)
const
;
};
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/simplify_reshapes.hpp
View file @
87d35010
#ifndef MIGRAPH_GUARD_RTGLIB_SIMPLIFY_RESHAPES_HPP
#define MIGRAPH_GUARD_RTGLIB_SIMPLIFY_RESHAPES_HPP
#ifndef MIGRAPH
X
_GUARD_RTGLIB_SIMPLIFY_RESHAPES_HPP
#define MIGRAPH
X
_GUARD_RTGLIB_SIMPLIFY_RESHAPES_HPP
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
struct
program
;
...
...
@@ -19,7 +19,7 @@ struct simplify_reshapes
void
apply
(
program
&
p
)
const
;
};
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/streamutils.hpp
View file @
87d35010
#ifndef MIGRAPH_GUARD_STREAMUTILS_HPP
#define MIGRAPH_GUARD_STREAMUTILS_HPP
#ifndef MIGRAPH
X
_GUARD_STREAMUTILS_HPP
#define MIGRAPH
X
_GUARD_STREAMUTILS_HPP
#include <ostream>
#include <algorithm>
...
...
@@ -7,7 +7,7 @@
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
template
<
class
T
>
struct
stream_range_container
...
...
@@ -56,7 +56,7 @@ void stream_write_value(std::ostream& os, const T& x)
detail
::
stream_write_value_impl
(
rank
<
1
>
{},
os
,
x
);
}
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/stringutils.hpp
View file @
87d35010
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_STRINGUTILS_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_STRINGUTILS_HPP
#ifndef MIGRAPH
X
_GUARD_MIGRAPHLIB_STRINGUTILS_HPP
#define MIGRAPH
X
_GUARD_MIGRAPHLIB_STRINGUTILS_HPP
#include <algorithm>
#include <numeric>
...
...
@@ -8,7 +8,7 @@
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
inline
std
::
string
replace_string
(
std
::
string
subject
,
const
std
::
string
&
search
,
const
std
::
string
&
replace
)
...
...
@@ -87,7 +87,7 @@ inline std::string to_string(const T& x)
return
ss
.
str
();
}
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/target.hpp
View file @
87d35010
#ifndef MIGRAPH_GUARD_MIGRAPHLIB_TARGET_HPP
#define MIGRAPH_GUARD_MIGRAPHLIB_TARGET_HPP
#ifndef MIGRAPH
X
_GUARD_MIGRAPHLIB_TARGET_HPP
#define MIGRAPH
X
_GUARD_MIGRAPHLIB_TARGET_HPP
#include <cassert>
#include <string>
...
...
@@ -13,7 +13,7 @@
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
#ifdef DOXYGEN
...
...
@@ -244,7 +244,7 @@ inline const ValueType& any_cast(const target& x)
#endif
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/tensor_view.hpp
View file @
87d35010
#ifndef MIGRAPH_GUARD_TENSOR_VIEW_HPP
#define MIGRAPH_GUARD_TENSOR_VIEW_HPP
#ifndef MIGRAPH
X
_GUARD_TENSOR_VIEW_HPP
#define MIGRAPH
X
_GUARD_TENSOR_VIEW_HPP
#include <migraphx/shape.hpp>
#include <migraphx/float_equal.hpp>
...
...
@@ -10,7 +10,7 @@
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
template
<
class
T
>
struct
tensor_view
...
...
@@ -29,7 +29,7 @@ struct tensor_view
const
T
*
data
()
const
{
return
this
->
m_data
;
}
template
<
class
...
Ts
,
MIGRAPH_REQUIRES
(
std
::
is_integral
<
Ts
>{}...)
>
template
<
class
...
Ts
,
MIGRAPH
X
_REQUIRES
(
std
::
is_integral
<
Ts
>{}...)
>
const
T
&
operator
()(
Ts
...
xs
)
const
{
assert
(
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
xs
)...}
<
m_shape
.
lens
());
...
...
@@ -37,7 +37,7 @@ struct tensor_view
return
m_data
[
m_shape
.
index
({
static_cast
<
std
::
size_t
>
(
xs
)...})];
}
template
<
class
...
Ts
,
MIGRAPH_REQUIRES
(
std
::
is_integral
<
Ts
>{}...)
>
template
<
class
...
Ts
,
MIGRAPH
X
_REQUIRES
(
std
::
is_integral
<
Ts
>{}...)
>
T
&
operator
()(
Ts
...
xs
)
{
assert
(
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
xs
)...}
<
m_shape
.
lens
());
...
...
@@ -45,13 +45,13 @@ struct tensor_view
return
m_data
[
m_shape
.
index
({
static_cast
<
std
::
size_t
>
(
xs
)...})];
}
template
<
class
Iterator
,
MIGRAPH_REQUIRES
(
not
std
::
is_integral
<
Iterator
>{})
>
template
<
class
Iterator
,
MIGRAPH
X
_REQUIRES
(
not
std
::
is_integral
<
Iterator
>{})
>
const
T
&
operator
()(
Iterator
start
,
Iterator
last
)
const
{
return
m_data
[
m_shape
.
index
(
start
,
last
)];
}
template
<
class
Iterator
,
MIGRAPH_REQUIRES
(
not
std
::
is_integral
<
Iterator
>{})
>
template
<
class
Iterator
,
MIGRAPH
X
_REQUIRES
(
not
std
::
is_integral
<
Iterator
>{})
>
T
&
operator
()(
Iterator
start
,
Iterator
last
)
{
return
m_data
[
m_shape
.
index
(
start
,
last
)];
...
...
@@ -169,7 +169,7 @@ tensor_view<T> make_view(shape s, T* data)
return
{
s
,
data
};
}
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/time.hpp
View file @
87d35010
#ifndef MIGRAPH_GUARD_RTGLIB_TIME_HPP
#define MIGRAPH_GUARD_RTGLIB_TIME_HPP
#ifndef MIGRAPH
X
_GUARD_RTGLIB_TIME_HPP
#define MIGRAPH
X
_GUARD_RTGLIB_TIME_HPP
#include <chrono>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
template
<
class
Duration
,
class
F
>
auto
time
(
F
f
)
...
...
@@ -16,7 +16,7 @@ auto time(F f)
return
std
::
chrono
::
duration_cast
<
Duration
>
(
finish
-
start
).
count
();
}
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/tracer.hpp
View file @
87d35010
#ifndef MIGRAPH_GUARD_RTGLIB_TRACER_HPP
#define MIGRAPH_GUARD_RTGLIB_TRACER_HPP
#ifndef MIGRAPH
X
_GUARD_RTGLIB_TRACER_HPP
#define MIGRAPH
X
_GUARD_RTGLIB_TRACER_HPP
#include <ostream>
#include <migraphx/functional.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
struct
tracer
{
...
...
@@ -30,7 +30,7 @@ struct tracer
std
::
ostream
*
os
=
nullptr
;
};
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/type_name.hpp
View file @
87d35010
#ifndef MIGRAPH_GUARD_RTGLIB_TYPE_NAME_HPP
#define MIGRAPH_GUARD_RTGLIB_TYPE_NAME_HPP
#ifndef MIGRAPH
X
_GUARD_RTGLIB_TYPE_NAME_HPP
#define MIGRAPH
X
_GUARD_RTGLIB_TYPE_NAME_HPP
#include <string>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
template
<
class
PrivateMigraphTypeNameProbe
>
const
std
::
string
&
get_type_name
()
...
...
@@ -41,7 +41,7 @@ const std::string& get_type_name(const T&)
return
migraphx
::
get_type_name
<
T
>
();
}
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/type_traits.hpp
View file @
87d35010
...
...
@@ -5,32 +5,32 @@
file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
==============================================================================*/
#ifndef MIGRAPH_GUARD_RTGLIB_TYPE_TRAITS_HPP
#define MIGRAPH_GUARD_RTGLIB_TYPE_TRAITS_HPP
#ifndef MIGRAPH
X
_GUARD_RTGLIB_TYPE_TRAITS_HPP
#define MIGRAPH
X
_GUARD_RTGLIB_TYPE_TRAITS_HPP
#include <type_traits>
#include <migraphx/half.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
#define MIGRAPH_DETAIL_EXTEND_TRAIT_FOR(trait, T) \
template <class X> \
struct trait : std::trait<X> \
{ \
}; \
\
template <> \
struct trait<T> : std::true_type \
{ \
#define MIGRAPH
X
_DETAIL_EXTEND_TRAIT_FOR(trait, T) \
template <class X>
\
struct trait : std::trait<X>
\
{
\
};
\
\
template <>
\
struct trait<T> : std::true_type
\
{
\
};
MIGRAPH_DETAIL_EXTEND_TRAIT_FOR
(
is_floating_point
,
half
)
MIGRAPH_DETAIL_EXTEND_TRAIT_FOR
(
is_signed
,
half
)
MIGRAPH_DETAIL_EXTEND_TRAIT_FOR
(
is_arithmetic
,
half
)
MIGRAPH
X
_DETAIL_EXTEND_TRAIT_FOR
(
is_floating_point
,
half
)
MIGRAPH
X
_DETAIL_EXTEND_TRAIT_FOR
(
is_signed
,
half
)
MIGRAPH
X
_DETAIL_EXTEND_TRAIT_FOR
(
is_arithmetic
,
half
)
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/verify.hpp
View file @
87d35010
#ifndef MIGRAPH_GUARD_VERIFY_HPP
#define MIGRAPH_GUARD_VERIFY_HPP
#ifndef MIGRAPH
X
_GUARD_VERIFY_HPP
#define MIGRAPH
X
_GUARD_VERIFY_HPP
#include <algorithm>
#include <cmath>
...
...
@@ -11,7 +11,7 @@
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
// Compute the value of a range
template
<
class
R
>
...
...
@@ -173,6 +173,6 @@ bool verify_range(R1&& r1, R2&& r2, double tolerance = 80, double* out_error = n
return
error
<=
threshold
;
}
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/verify_args.hpp
View file @
87d35010
#ifndef MIGRAPH_GUARD_RTGLIB_VERIFY_ARGS_HPP
#define MIGRAPH_GUARD_RTGLIB_VERIFY_ARGS_HPP
#ifndef MIGRAPH
X
_GUARD_RTGLIB_VERIFY_ARGS_HPP
#define MIGRAPH
X
_GUARD_RTGLIB_VERIFY_ARGS_HPP
#include <migraphx/verify.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
inline
bool
verify_args
(
const
std
::
string
&
name
,
const
argument
&
cpu_arg
,
...
...
@@ -84,7 +84,7 @@ inline bool verify_args(const std::string& name,
return
passed
;
}
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif
src/instruction.cpp
View file @
87d35010
...
...
@@ -3,7 +3,7 @@
#include <migraphx/erase.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
instruction
::
instruction
(
operation
o
,
shape
r
,
std
::
vector
<
instruction_ref
>
args
)
:
op
(
std
::
move
(
o
)),
result
(
std
::
move
(
r
)),
arguments
(
std
::
move
(
args
))
...
...
@@ -183,5 +183,5 @@ shape compute_shape(const operation& op, const std::vector<instruction_ref>& arg
return
op
.
compute_shape
(
compute_shapes
(
args
));
}
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
src/onnx/CMakeLists.txt
View file @
87d35010
...
...
@@ -22,7 +22,7 @@ rocm_clang_tidy_check(read_onnx)
target_link_libraries
(
read_onnx migraphx_onnx
)
if
(
MIGRAPH_ENABLE_GPU
)
if
(
MIGRAPH
X
_ENABLE_GPU
)
add_executable
(
mnist mnist.cpp
)
rocm_clang_tidy_check
(
mnist
)
target_link_libraries
(
mnist migraphx_cpu migraphx_gpu migraphx_onnx
)
...
...
src/onnx/onnx.cpp
View file @
87d35010
...
...
@@ -17,7 +17,7 @@
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
struct
unknown
{
std
::
string
op
;
...
...
@@ -43,7 +43,8 @@ struct onnx_parser
using
op_func
=
std
::
function
<
instruction_ref
(
attribute_map
,
std
::
vector
<
instruction_ref
>
)
>
;
node_map
nodes
;
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
instructions
;
program
prog
=
program
();
program
prog
=
program
();
bool
is_pytorch
=
false
;
std
::
unordered_map
<
std
::
string
,
op_func
>
ops
;
...
...
@@ -107,7 +108,7 @@ struct onnx_parser
{
ops
.
emplace
(
name
,
[
this
,
x
](
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
if
(
args
.
size
()
!=
2
)
MIGRAPH_THROW
(
"binary operators should have 2 operands"
);
MIGRAPH
X
_THROW
(
"binary operators should have 2 operands"
);
if
(
contains
(
attributes
,
"broadcast"
))
{
uint64_t
broadcasted
=
parse_value
(
attributes
.
at
(
"broadcast"
)).
at
<
uint64_t
>
();
...
...
@@ -122,6 +123,45 @@ struct onnx_parser
}
return
prog
.
add_instruction
(
x
,
args
);
}
<<<<<<<
HEAD
=======
else
if
(
args
[
0
]
->
get_shape
()
!=
args
[
1
]
->
get_shape
())
{
// Example:
// s0 = (3,2,4,5) and s1 = (2,1,1)
//
// In this case we need to broadcast (:,1,1) portion of
// s1 plus broadcast the 1st dimension of s1
// giving output_lens = (3,2,4,5)
//
// Another example:
// s0 = (3,2,1,5) and s1 = (2,7,5)
// In this case we need to broadcast the (:,:,1:,:) axis
// of s0 plus the 1st dimension of s1 giving
// output_lens = (3,2,7,5)
//
// Get lengths for both arguments
const
std
::
vector
<
std
::
size_t
>*
s0
=
&
args
[
0
]
->
get_shape
().
lens
();
const
std
::
vector
<
std
::
size_t
>*
s1
=
&
args
[
1
]
->
get_shape
().
lens
();
// Make sure s0 is the smaller size
if
(
s0
->
size
()
>
s1
->
size
())
std
::
swap
(
s0
,
s1
);
// Copy the larger vector to output_lens
std
::
vector
<
std
::
size_t
>
output_lens
=
*
s1
;
auto
offset
=
s1
->
size
()
-
s0
->
size
();
std
::
transform
(
s0
->
begin
(),
s0
->
end
(),
s1
->
begin
()
+
offset
,
output_lens
.
begin
()
+
offset
,
[](
auto
a
,
auto
b
)
{
return
std
::
max
(
a
,
b
);
});
auto
l0
=
prog
.
add_instruction
(
op
::
multibroadcast
{
output_lens
},
args
[
0
]);
auto
l1
=
prog
.
add_instruction
(
op
::
multibroadcast
{
output_lens
},
args
[
1
]);
return
prog
.
add_instruction
(
x
,
l0
,
l1
);
}
>>>>>>>
84e7335
eb6088f9918dcf86f9fc1b58ef27c3360
else
{
return
add_broadcastable_binary_op
(
args
[
0
],
args
[
1
],
x
);
...
...
@@ -239,7 +279,22 @@ struct onnx_parser
op
::
convolution
op
;
if
(
contains
(
attributes
,
"pads"
))
{
copy
(
attributes
[
"pads"
].
ints
(),
op
.
padding
.
begin
());
if
(
contains
(
attributes
,
"auto_pad"
))
{
MIGRAPHX_THROW
(
"auto_pad and padding cannot be specified simultaneously"
);
}
std
::
vector
<
std
::
size_t
>
padding
(
4
);
copy
(
attributes
[
"pads"
].
ints
(),
padding
.
begin
());
if
(
padding
.
size
()
!=
4
)
{
MIGRAPHX_THROW
(
"padding should have 4 values"
);
}
if
(
padding
[
0
]
!=
padding
[
2
]
||
padding
[
1
]
!=
padding
[
3
])
{
MIGRAPHX_THROW
(
"migraphx does not support asymetric padding"
);
}
op
.
padding
[
0
]
=
padding
[
0
];
op
.
padding
[
1
]
=
padding
[
1
];
}
if
(
contains
(
attributes
,
"strides"
))
{
...
...
@@ -249,6 +304,19 @@ struct onnx_parser
{
copy
(
attributes
[
"dilations"
].
ints
(),
op
.
dilation
.
begin
());
}
if
(
contains
(
attributes
,
"auto_pad"
))
{
auto
s
=
attributes
[
"auto_pad"
].
s
();
if
(
contains
(
attributes
,
"pads"
)
and
to_upper
(
s
)
!=
"NOTSET"
)
{
MIGRAPHX_THROW
(
"auto_pad and padding cannot be specified simultaneously"
);
}
if
(
s
.
find
(
"SAME"
)
!=
std
::
string
::
npos
)
{
op
.
padding_mode
=
op
::
convolution
::
same
;
}
}
if
(
args
.
size
()
==
3
)
{
uint64_t
axis
=
1
;
...
...
@@ -271,7 +339,18 @@ struct onnx_parser
}
if
(
contains
(
attributes
,
"pads"
))
{
copy
(
attributes
[
"pads"
].
ints
(),
op
.
padding
.
begin
());
std
::
vector
<
std
::
size_t
>
padding
(
4
);
copy
(
attributes
[
"pads"
].
ints
(),
padding
.
begin
());
if
(
padding
.
size
()
!=
4
)
{
MIGRAPHX_THROW
(
"padding should have 4 values"
);
}
if
(
padding
[
0
]
!=
padding
[
2
]
||
padding
[
1
]
!=
padding
[
3
])
{
MIGRAPHX_THROW
(
"migraphx does not support asymetric padding"
);
}
op
.
padding
[
0
]
=
padding
[
0
];
op
.
padding
[
1
]
=
padding
[
1
];
}
if
(
contains
(
attributes
,
"strides"
))
{
...
...
@@ -281,6 +360,15 @@ struct onnx_parser
{
copy
(
attributes
[
"kernel_shape"
].
ints
(),
op
.
lengths
.
begin
());
}
if
(
contains
(
attributes
,
"auto_pad"
))
{
auto
s
=
attributes
[
"auto_pad"
].
s
();
if
(
to_upper
(
s
)
!=
"NOTSET"
)
{
MIGRAPHX_THROW
(
"auto_pad is not supported for pooling"
);
}
}
return
prog
.
add_instruction
(
op
,
std
::
move
(
args
));
}
...
...
@@ -546,7 +634,7 @@ struct onnx_parser
void
parse_node
(
const
std
::
string
&
name
)
{
if
(
name
.
empty
())
MIGRAPH_THROW
(
"Onnx node must have a name"
);
MIGRAPH
X
_THROW
(
"Onnx node must have a name"
);
if
(
instructions
.
count
(
name
)
==
0
)
{
auto
&&
node
=
nodes
.
at
(
name
);
...
...
@@ -636,7 +724,7 @@ struct onnx_parser
case
onnx
::
AttributeProto
::
TENSORS
:
return
{};
case
onnx
::
AttributeProto
::
GRAPHS
:
return
{};
}
MIGRAPH_THROW
(
"Invalid attribute type"
);
MIGRAPH
X
_THROW
(
"Invalid attribute type"
);
}
static
literal
parse_tensor
(
const
onnx
::
TensorProto
&
t
)
...
...
@@ -669,7 +757,7 @@ struct onnx_parser
case
onnx
::
TensorProto
::
COMPLEX64
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
COMPLEX128
:
throw
std
::
runtime_error
(
""
);
}
MIGRAPH_THROW
(
"Invalid tensor type"
);
MIGRAPH
X
_THROW
(
"Invalid tensor type"
);
}
switch
(
t
.
data_type
())
{
...
...
@@ -700,7 +788,7 @@ struct onnx_parser
case
onnx
::
TensorProto
::
COMPLEX64
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
COMPLEX128
:
throw
std
::
runtime_error
(
""
);
}
MIGRAPH_THROW
(
"Invalid tensor type"
);
MIGRAPH
X
_THROW
(
"Invalid tensor type"
);
}
static
shape
parse_type
(
const
onnx
::
TypeProto
&
t
)
...
...
@@ -769,5 +857,5 @@ program parse_onnx(const std::string& name)
return
std
::
move
(
parser
.
prog
);
}
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
src/opt/common_header.hpp
View file @
87d35010
#ifndef MIGRAPH_GUARD_RTGLIB_COMMON_HEADER_HPP
#define MIGRAPH_GUARD_RTGLIB_COMMON_HEADER_HPP
#ifndef MIGRAPH
X
_GUARD_RTGLIB_COMMON_HEADER_HPP
#define MIGRAPH
X
_GUARD_RTGLIB_COMMON_HEADER_HPP
#include <migraphx/program.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
...
...
@@ -14,17 +14,17 @@
#include <queue>
namespace
migraphx
{
inline
namespace
MIGRAPH_INLINE_NS
{
inline
namespace
MIGRAPH
X
_INLINE_NS
{
//#define MIGRAPH_DEBUG_OPT
//#define MIGRAPH
X
_DEBUG_OPT
#ifdef MIGRAPH_DEBUG_OPT
#define MIGRAPH_DEBUG(s) s
#ifdef MIGRAPH
X
_DEBUG_OPT
#define MIGRAPH
X
_DEBUG(s) s
#else
#define MIGRAPH_DEBUG(s)
#endif // MIGRAPH_DEBUG_OPT
#define MIGRAPH
X
_DEBUG(s)
#endif // MIGRAPH
X
_DEBUG_OPT
}
// namespace MIGRAPH_INLINE_NS
}
// namespace MIGRAPH
X
_INLINE_NS
}
// namespace migraphx
#endif // MIGRAPH_GUARD_RTGLIB_COMMON_HEADER_HPP
#endif // MIGRAPH
X
_GUARD_RTGLIB_COMMON_HEADER_HPP
Prev
1
2
3
4
5
6
7
8
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment