Commit bc5d7f75
Authored Feb 15, 2019 by Paul
Merge from develop
Parents: 47c0854d, a5b0afa0

This merge changes 337 files in total; the page below shows 20 changed files, with 1033 additions and 292 deletions (+1033 / −292).
src/include/migraphx/rewrite_rnn.hpp        +53   −0
src/include/migraphx/shape.hpp              +43   −20
src/include/migraphx/shape_for_each.hpp     +9    −9
src/include/migraphx/simplify_algebra.hpp   +24   −0
src/include/migraphx/simplify_reshapes.hpp  +25   −0
src/include/migraphx/streamutils.hpp        +8    −8
src/include/migraphx/stringutils.hpp        +7    −7
src/include/migraphx/target.hpp             +15   −9
src/include/migraphx/tensor_view.hpp        +15   −15
src/include/migraphx/time.hpp               +7    −7
src/include/migraphx/tracer.hpp             +8    −8
src/include/migraphx/type_name.hpp          +9    −9
src/include/migraphx/type_traits.hpp        +36   −0
src/include/migraphx/verify.hpp             +8    −8
src/include/migraphx/verify_args.hpp        +9    −9
src/instruction.cpp                         +46   −17
src/onnx/CMakeLists.txt                     +12   −12
src/onnx/cifar10.cpp                        +15   −15
src/onnx/mnist.cpp                          +18   −13
src/onnx/onnx.cpp                           +666  −126
src/include/migraphx/rewrite_rnn.hpp (new file, mode 100644)

#ifndef MIGRAPHX_GUARD_RTGLIB_REWRITE_RNN_HPP
#define MIGRAPHX_GUARD_RTGLIB_REWRITE_RNN_HPP

#include <string>
#include <vector>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/config.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct program;

/**
 * Rewrite rnn to gemm and add.
 */
struct rewrite_rnn
{
    std::string name() const { return "rewrite_rnn"; }
    void apply(program& prog) const;

    private:
    // for vanilla rnn operators
    void apply_vanilla_rnn(program& prog, instruction_ref ins) const;
    std::vector<instruction_ref> vanilla_rnn_cell(bool is_forward,
                                                  program& prog,
                                                  instruction_ref ins,
                                                  instruction_ref input,
                                                  instruction_ref w,
                                                  instruction_ref r,
                                                  instruction_ref bias,
                                                  instruction_ref ih,
                                                  operation& actv_func) const;
    std::vector<operation> vanilla_rnn_actv_funcs(instruction_ref ins) const;

    // for gru operators
    void apply_gru(program& prog, instruction_ref ins) const;
    std::vector<instruction_ref> gru_cell(bool is_forward,
                                          program& prog,
                                          instruction_ref ins,
                                          std::vector<instruction_ref> inputs,
                                          int linear_before_reset,
                                          const operation& actv_func1,
                                          const operation& actv_func2) const;
    std::vector<operation> gru_actv_funcs(instruction_ref ins) const;
};

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
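Note: a pass like this is normally run for you by the target's pass list during compilation, but it can also be driven by hand. A minimal sketch, assuming only the public headers above (the helper name lower_rnn_ops is illustrative, not part of the library):

#include <migraphx/program.hpp>
#include <migraphx/rewrite_rnn.hpp>

// Sketch: apply the rewrite pass to a program directly. In normal use
// prog.compile(target) runs this as part of the target's pass pipeline.
void lower_rnn_ops(migraphx::program& prog)
{
    migraphx::rewrite_rnn pass;
    pass.apply(prog); // RNN/GRU instructions become gemm + add + activation
}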
src/include/migraph/shape.hpp → src/include/migraphx/shape.hpp

-#ifndef MIGRAPH_GUARD_MIGRAPHLIB_SHAPE_HPP
-#define MIGRAPH_GUARD_MIGRAPHLIB_SHAPE_HPP
+#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_SHAPE_HPP
+#define MIGRAPHX_GUARD_MIGRAPHLIB_SHAPE_HPP
 #include <vector>
 #include <cassert>
...
@@ -7,12 +7,12 @@
 #include <numeric>
 #include <memory>
-#include <migraph/errors.hpp>
-#include <migraph/half.hpp>
-#include <migraph/config.hpp>
+#include <migraphx/errors.hpp>
+#include <migraphx/half.hpp>
+#include <migraphx/config.hpp>

-namespace migraph {
-inline namespace MIGRAPH_INLINE_NS {
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {

 struct shape_impl;
...
@@ -21,7 +21,7 @@ struct shape
 // Add new types here
 // clang-format off
-#define MIGRAPH_SHAPE_VISIT_TYPES(m) \
+#define MIGRAPHX_SHAPE_VISIT_TYPES(m) \
     m(half_type, half) \
     m(float_type, float) \
     m(double_type, double) \
...
@@ -35,22 +35,22 @@ struct shape
     m(uint64_type, uint64_t)
 // clang-format on

-#define MIGRAPH_SHAPE_ENUM_TYPES(x, t) x,
+#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x,
     enum type_t
     {
-        MIGRAPH_SHAPE_VISIT_TYPES(MIGRAPH_SHAPE_ENUM_TYPES)
+        MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES)
     };
-#undef MIGRAPH_SHAPE_ENUM_TYPES
+#undef MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES

     template <class T, class = void>
     struct get_type;
-#define MIGRAPH_SHAPE_GET_TYPE(x, t)                        \
+#define MIGRAPHX_SHAPE_GENERATE_GET_TYPE(x, t)              \
     template <class T>                                      \
     struct get_type<t, T> : std::integral_constant<type_t, x> \
     {                                                       \
     };

-    MIGRAPH_SHAPE_VISIT_TYPES(MIGRAPH_SHAPE_GET_TYPE)
-#undef MIGRAPH_SHAPE_GET_TYPE
+    MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_GET_TYPE)
+#undef MIGRAPHX_SHAPE_GENERATE_GET_TYPE

     template <class T>
     struct get_type<const T> : get_type<T>
...
@@ -62,6 +62,19 @@ struct shape
     shape(type_t t, std::vector<std::size_t> l);
     shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s);
+
+    template <class Range>
+    shape(type_t t, const Range& l) : shape(t, std::vector<std::size_t>(l.begin(), l.end()))
+    {
+    }
+
+    template <class Range1, class Range2>
+    shape(type_t t, const Range1& l, const Range2& s)
+        : shape(t,
+                std::vector<std::size_t>(l.begin(), l.end()),
+                std::vector<std::size_t>(s.begin(), s.end()))
+    {
+    }

     type_t type() const;
     const std::vector<std::size_t>& lens() const;
     const std::vector<std::size_t>& strides() const;
...
@@ -141,6 +154,8 @@ struct shape
         {
             return reinterpret_cast<const T*>(buffer) + n;
         }
+
+        type_t type_enum() const { return get_type<T>{}; }
     };

     template <class Visitor>
...
@@ -148,12 +163,20 @@ struct shape
     {
         switch(this->type())
         {
-#define MIGRAPH_SHAPE_VISITOR_CASE(x, t) \
+#define MIGRAPHX_SHAPE_GENERATE_VISITOR_CASE(x, t) \
     case x: v(as<t>()); return;
-            MIGRAPH_SHAPE_VISIT_TYPES(MIGRAPH_SHAPE_VISITOR_CASE)
-#undef MIGRAPH_SHAPE_VISITOR_CASE
+            MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_VISITOR_CASE)
+#undef MIGRAPHX_SHAPE_GENERATE_VISITOR_CASE
         }
-        MIGRAPH_THROW("Unknown type");
+        MIGRAPHX_THROW("Unknown type");
     }

+    template <class Visitor>
+    static void visit_types(Visitor v)
+    {
+#define MIGRAPHX_SHAPE_GENERATE_VISITOR_ALL(x, t) v(as<t>());
+        MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_VISITOR_ALL)
+#undef MIGRAPHX_SHAPE_GENERATE_VISITOR_ALL
+    }
+
     private:
...
@@ -163,7 +186,7 @@ struct shape
     std::string type_string() const;
 };

-} // namespace MIGRAPH_INLINE_NS
-} // namespace migraph
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
 #endif
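Note: the new templated constructors let a shape be built from any range of extents with begin()/end(), not only a std::vector<std::size_t>. A small sketch of what that enables, assuming just the header above:

#include <array>
#include <migraphx/shape.hpp>

// The Range overload copies l.begin()..l.end() into the internal
// std::vector<std::size_t>, so any container of extents works.
std::array<std::size_t, 3> dims = {2, 3, 4};
migraphx::shape s{migraphx::shape::float_type, dims};
// equivalent to: migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}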
src/include/migraph/shape_for_each.hpp → src/include/migraphx/shape_for_each.hpp

-#ifndef MIGRAPH_GUARD_MIGRAPHLIB_SHAPE_FOR_EACH_HPP
-#define MIGRAPH_GUARD_MIGRAPHLIB_SHAPE_FOR_EACH_HPP
+#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_SHAPE_FOR_EACH_HPP
+#define MIGRAPHX_GUARD_MIGRAPHLIB_SHAPE_FOR_EACH_HPP
-#include <migraph/shape.hpp>
-#include <migraph/config.hpp>
+#include <migraphx/shape.hpp>
+#include <migraphx/config.hpp>
 #include <algorithm>

-namespace migraph {
-inline namespace MIGRAPH_INLINE_NS {
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {

 template <class F>
-void shape_for_each(const migraph::shape& s, F f)
+void shape_for_each(const migraphx::shape& s, F f)
 {
     // Ensure calls to f use const ref to vector
     auto call = [&f](const std::vector<std::size_t>& i) { f(i); };
...
@@ -28,7 +28,7 @@ void shape_for_each(const migraph::shape& s, F f)
 }

-} // namespace MIGRAPH_INLINE_NS
-} // namespace migraph
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
 #endif
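Note: apart from the rename, shape_for_each is unchanged; it invokes the callback once for every multi-dimensional index of the shape. A usage sketch, assuming the headers above:

#include <cstddef>
#include <vector>
#include <migraphx/shape.hpp>
#include <migraphx/shape_for_each.hpp>

void visit_indices()
{
    migraphx::shape s{migraphx::shape::float_type, {2, 3}};
    // the callback sees each index tuple of the 2x3 shape,
    // e.g. {0,0}, {0,1}, {0,2}, {1,0}, {1,1}, {1,2}
    migraphx::shape_for_each(s, [&](const std::vector<std::size_t>& idx) {
        (void)s.index(idx); // e.g. map the index tuple to a flat offset
    });
}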
src/include/migraph/simplify_algebra.hpp → src/include/migraphx/simplify_algebra.hpp

-#ifndef MIGRAPH_GUARD_RTGLIB_SIMPLIFY_ALGEBRA_HPP
-#define MIGRAPH_GUARD_RTGLIB_SIMPLIFY_ALGEBRA_HPP
+#ifndef MIGRAPHX_GUARD_RTGLIB_SIMPLIFY_ALGEBRA_HPP
+#define MIGRAPHX_GUARD_RTGLIB_SIMPLIFY_ALGEBRA_HPP
 #include <string>
-#include <migraph/config.hpp>
+#include <migraphx/config.hpp>

-namespace migraph {
-inline namespace MIGRAPH_INLINE_NS {
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {

 struct program;
...
@@ -18,7 +18,7 @@ struct simplify_algebra
     void apply(program& p) const;
 };

-} // namespace MIGRAPH_INLINE_NS
-} // namespace migraph
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
 #endif
src/include/migraph/simplify_reshapes.hpp → src/include/migraphx/simplify_reshapes.hpp

-#ifndef MIGRAPH_GUARD_RTGLIB_SIMPLIFY_RESHAPES_HPP
-#define MIGRAPH_GUARD_RTGLIB_SIMPLIFY_RESHAPES_HPP
+#ifndef MIGRAPHX_GUARD_RTGLIB_SIMPLIFY_RESHAPES_HPP
+#define MIGRAPHX_GUARD_RTGLIB_SIMPLIFY_RESHAPES_HPP
 #include <string>
-#include <migraph/instruction_ref.hpp>
-#include <migraph/config.hpp>
+#include <migraphx/instruction_ref.hpp>
+#include <migraphx/config.hpp>

-namespace migraph {
-inline namespace MIGRAPH_INLINE_NS {
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {

 struct program;
...
@@ -19,7 +19,7 @@ struct simplify_reshapes
     void apply(program& p) const;
 };

-} // namespace MIGRAPH_INLINE_NS
-} // namespace migraph
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
 #endif
src/include/migraph/streamutils.hpp → src/include/migraphx/streamutils.hpp

-#ifndef MIGRAPH_GUARD_STREAMUTILS_HPP
-#define MIGRAPH_GUARD_STREAMUTILS_HPP
+#ifndef MIGRAPHX_GUARD_STREAMUTILS_HPP
+#define MIGRAPHX_GUARD_STREAMUTILS_HPP
 #include <ostream>
 #include <algorithm>
-#include <migraph/rank.hpp>
-#include <migraph/config.hpp>
+#include <migraphx/rank.hpp>
+#include <migraphx/config.hpp>

-namespace migraph {
-inline namespace MIGRAPH_INLINE_NS {
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {

 template <class T>
 struct stream_range_container
...
@@ -56,7 +56,7 @@ void stream_write_value(std::ostream& os, const T& x)
     detail::stream_write_value_impl(rank<1>{}, os, x);
 }

-} // namespace MIGRAPH_INLINE_NS
-} // namespace migraph
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
 #endif
src/include/migraph/stringutils.hpp → src/include/migraphx/stringutils.hpp

-#ifndef MIGRAPH_GUARD_MIGRAPHLIB_STRINGUTILS_HPP
-#define MIGRAPH_GUARD_MIGRAPHLIB_STRINGUTILS_HPP
+#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_STRINGUTILS_HPP
+#define MIGRAPHX_GUARD_MIGRAPHLIB_STRINGUTILS_HPP
 #include <algorithm>
 #include <numeric>
 #include <string>
 #include <sstream>
-#include <migraph/config.hpp>
+#include <migraphx/config.hpp>

-namespace migraph {
-inline namespace MIGRAPH_INLINE_NS {
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {

 inline std::string
 replace_string(std::string subject, const std::string& search, const std::string& replace)
...
@@ -87,7 +87,7 @@ inline std::string to_string(const T& x)
     return ss.str();
 }

-} // namespace MIGRAPH_INLINE_NS
-} // namespace migraph
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
 #endif
src/include/migraph/target.hpp → src/include/migraphx/target.hpp

-#ifndef MIGRAPH_GUARD_MIGRAPHLIB_TARGET_HPP
-#define MIGRAPH_GUARD_MIGRAPHLIB_TARGET_HPP
+#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_TARGET_HPP
+#define MIGRAPHX_GUARD_MIGRAPHLIB_TARGET_HPP
 #include <cassert>
 #include <string>
...
@@ -8,12 +8,12 @@
 #include <type_traits>
 #include <utility>
 #include <vector>
-#include <migraph/context.hpp>
-#include <migraph/pass.hpp>
-#include <migraph/config.hpp>
+#include <migraphx/context.hpp>
+#include <migraphx/pass.hpp>
+#include <migraphx/config.hpp>

-namespace migraph {
-inline namespace MIGRAPH_INLINE_NS {
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {

 #ifdef DOXYGEN
...
@@ -127,6 +127,12 @@ struct target
         return (*this).private_detail_te_get_handle().get_context();
     }

+    friend bool is_shared(const target& private_detail_x, const target& private_detail_y)
+    {
+        return private_detail_x.private_detail_te_handle_mem_var ==
+               private_detail_y.private_detail_te_handle_mem_var;
+    }
+
     private:
     struct private_detail_te_handle_base_type
     {
...
@@ -244,7 +250,7 @@ inline const ValueType& any_cast(const target& x)
 #endif

-} // namespace MIGRAPH_INLINE_NS
-} // namespace migraph
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
 #endif
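Note: the added is_shared friend compares the type-erased handle members, so it reports whether two target wrappers refer to the same underlying object, which is identity rather than value equality. A sketch (cpu::target is taken from the example programs later in this diff; whether a copied wrapper shares or clones its handle depends on the type-erasure implementation, so the t2 result below is not guaranteed):

#include <migraphx/target.hpp>
#include <migraphx/cpu/target.hpp>

void check_sharing()
{
    migraphx::target t1 = migraphx::cpu::target{};
    bool same = is_shared(t1, t1); // trivially true: identical handle
    migraphx::target t2 = t1;
    bool maybe = is_shared(t1, t2); // true only if copying shares the handle
    (void)same;
    (void)maybe;
}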
src/include/migraph/tensor_view.hpp → src/include/migraphx/tensor_view.hpp

-#ifndef MIGRAPH_GUARD_TENSOR_VIEW_HPP
-#define MIGRAPH_GUARD_TENSOR_VIEW_HPP
+#ifndef MIGRAPHX_GUARD_TENSOR_VIEW_HPP
+#define MIGRAPHX_GUARD_TENSOR_VIEW_HPP

-#include <migraph/shape.hpp>
-#include <migraph/float_equal.hpp>
-#include <migraph/requires.hpp>
-#include <migraph/config.hpp>
+#include <migraphx/shape.hpp>
+#include <migraphx/float_equal.hpp>
+#include <migraphx/requires.hpp>
+#include <migraphx/config.hpp>

 #include <iostream>
 #include <utility>

-namespace migraph {
-inline namespace MIGRAPH_INLINE_NS {
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {

 template <class T>
 struct tensor_view
...
@@ -29,7 +29,7 @@ struct tensor_view
     const T* data() const { return this->m_data; }

-    template <class... Ts, MIGRAPH_REQUIRES(std::is_integral<Ts>{}...)>
+    template <class... Ts, MIGRAPHX_REQUIRES(std::is_integral<Ts>{}...)>
     const T& operator()(Ts... xs) const
     {
         assert(std::vector<std::size_t>{static_cast<std::size_t>(xs)...} < m_shape.lens());
...
@@ -37,7 +37,7 @@ struct tensor_view
         return m_data[m_shape.index({static_cast<std::size_t>(xs)...})];
     }

-    template <class... Ts, MIGRAPH_REQUIRES(std::is_integral<Ts>{}...)>
+    template <class... Ts, MIGRAPHX_REQUIRES(std::is_integral<Ts>{}...)>
     T& operator()(Ts... xs)
     {
         assert(std::vector<std::size_t>{static_cast<std::size_t>(xs)...} < m_shape.lens());
...
@@ -45,13 +45,13 @@ struct tensor_view
         return m_data[m_shape.index({static_cast<std::size_t>(xs)...})];
     }

-    template <class Iterator, MIGRAPH_REQUIRES(not std::is_integral<Iterator>{})>
+    template <class Iterator, MIGRAPHX_REQUIRES(not std::is_integral<Iterator>{})>
     const T& operator()(Iterator start, Iterator last) const
     {
         return m_data[m_shape.index(start, last)];
     }

-    template <class Iterator, MIGRAPH_REQUIRES(not std::is_integral<Iterator>{})>
+    template <class Iterator, MIGRAPHX_REQUIRES(not std::is_integral<Iterator>{})>
     T& operator()(Iterator start, Iterator last)
     {
         return m_data[m_shape.index(start, last)];
...
@@ -164,12 +164,12 @@ bool operator!=(const tensor_view<T>& x, const tensor_view<U>& y)
 }

 template <class T>
-tensor_view<T> make_view(shape s, T* data)
+tensor_view<T> make_view(const shape& s, T* data)
 {
     return {s, data};
 }

-} // namespace MIGRAPH_INLINE_NS
-} // namespace migraph
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
 #endif
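Note: besides the macro rename, make_view now takes the shape by const reference instead of by value, which avoids copying the lens/strides vectors at every call site; the returned view still holds its own copy of the shape. Usage is unchanged, as a minimal sketch:

#include <migraphx/shape.hpp>
#include <migraphx/tensor_view.hpp>

void view_buffer(float* buffer)
{
    migraphx::shape s{migraphx::shape::float_type, {2, 2}};
    auto view  = migraphx::make_view(s, buffer); // no shape copy at the call site
    view(1, 0) = 3.0f;                           // indexed through s.index({1, 0})
}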
src/include/migraph/time.hpp → src/include/migraphx/time.hpp

-#ifndef MIGRAPH_GUARD_RTGLIB_TIME_HPP
-#define MIGRAPH_GUARD_RTGLIB_TIME_HPP
+#ifndef MIGRAPHX_GUARD_RTGLIB_TIME_HPP
+#define MIGRAPHX_GUARD_RTGLIB_TIME_HPP
 #include <chrono>
-#include <migraph/config.hpp>
+#include <migraphx/config.hpp>

-namespace migraph {
-inline namespace MIGRAPH_INLINE_NS {
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {

 template <class Duration, class F>
 auto time(F f)
...
@@ -16,7 +16,7 @@ auto time(F f)
     return std::chrono::duration_cast<Duration>(finish - start).count();
 }

-} // namespace MIGRAPH_INLINE_NS
-} // namespace migraph
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
 #endif
src/include/migraph/tracer.hpp → src/include/migraphx/tracer.hpp

-#ifndef MIGRAPH_GUARD_RTGLIB_TRACER_HPP
-#define MIGRAPH_GUARD_RTGLIB_TRACER_HPP
+#ifndef MIGRAPHX_GUARD_RTGLIB_TRACER_HPP
+#define MIGRAPHX_GUARD_RTGLIB_TRACER_HPP
 #include <ostream>
-#include <migraph/functional.hpp>
-#include <migraph/config.hpp>
+#include <migraphx/functional.hpp>
+#include <migraphx/config.hpp>

-namespace migraph {
-inline namespace MIGRAPH_INLINE_NS {
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {

 struct tracer
 {
...
@@ -30,7 +30,7 @@ struct tracer
     std::ostream* os = nullptr;
 };

-} // namespace MIGRAPH_INLINE_NS
-} // namespace migraph
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
 #endif
src/include/migraph/type_name.hpp → src/include/migraphx/type_name.hpp

-#ifndef MIGRAPH_GUARD_RTGLIB_TYPE_NAME_HPP
-#define MIGRAPH_GUARD_RTGLIB_TYPE_NAME_HPP
+#ifndef MIGRAPHX_GUARD_RTGLIB_TYPE_NAME_HPP
+#define MIGRAPHX_GUARD_RTGLIB_TYPE_NAME_HPP
 #include <string>
-#include <migraph/config.hpp>
+#include <migraphx/config.hpp>

-namespace migraph {
-inline namespace MIGRAPH_INLINE_NS {
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {

 template <class PrivateMigraphTypeNameProbe>
 const std::string& get_type_name()
...
@@ -18,7 +18,7 @@ const std::string& get_type_name()
     name = typeid(PrivateMigraphTypeNameProbe).name();
     name = name.substr(7);
 #else
-    const char parameter_name[] = "PrivateMigraphTypeNameProbe =";
+    const char parameter_name[] = "PrivateMigraphTypeNameProbe ="; // NOLINT

     name = __PRETTY_FUNCTION__;
...
@@ -38,10 +38,10 @@ const std::string& get_type_name()
 template <class T>
 const std::string& get_type_name(const T&)
 {
-    return migraph::get_type_name<T>();
+    return migraphx::get_type_name<T>();
 }

-} // namespace MIGRAPH_INLINE_NS
-} // namespace migraph
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
 #endif
src/include/migraph/type_traits.hpp → src/include/migraphx/type_traits.hpp

...
@@ -5,32 +5,32 @@
     file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
 ==============================================================================*/
-#ifndef MIGRAPH_GUARD_RTGLIB_TYPE_TRAITS_HPP
-#define MIGRAPH_GUARD_RTGLIB_TYPE_TRAITS_HPP
+#ifndef MIGRAPHX_GUARD_RTGLIB_TYPE_TRAITS_HPP
+#define MIGRAPHX_GUARD_RTGLIB_TYPE_TRAITS_HPP

 #include <type_traits>
-#include <migraph/half.hpp>
-#include <migraph/config.hpp>
+#include <migraphx/half.hpp>
+#include <migraphx/config.hpp>

-namespace migraph {
-inline namespace MIGRAPH_INLINE_NS {
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {

-#define MIGRAPH_DETAIL_EXTEND_TRAIT_FOR(trait, T) \
-    template <class X>                            \
-    struct trait : std::trait<X>                  \
-    {                                             \
-    };                                            \
-                                                  \
-    template <>                                   \
-    struct trait<T> : std::true_type              \
-    {                                             \
-    };
+#define MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(trait, T) \
+    template <class X>                             \
+    struct trait : std::trait<X>                   \
+    {                                              \
+    };                                             \
+                                                   \
+    template <>                                    \
+    struct trait<T> : std::true_type               \
+    {                                              \
+    };

-MIGRAPH_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, half)
-MIGRAPH_DETAIL_EXTEND_TRAIT_FOR(is_signed, half)
-MIGRAPH_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, half)
+MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, half)
+MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, half)
+MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, half)

-} // namespace MIGRAPH_INLINE_NS
-} // namespace migraph
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
 #endif
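Note: the macro extends selected std:: type traits so that they also report true for migraphx::half, which the std versions would reject because half is a class type. A sketch of the effect:

#include <migraphx/type_traits.hpp>

// std::is_floating_point<migraphx::half> is false (half is a class type),
// but the extended trait in namespace migraphx accepts it and forwards
// every other type to the corresponding std:: trait.
static_assert(migraphx::is_floating_point<migraphx::half>{}, "half counts as floating point");
static_assert(migraphx::is_floating_point<float>{}, "other types forward to std::");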
src/include/migraph/verify.hpp → src/include/migraphx/verify.hpp

-#ifndef MIGRAPH_GUARD_VERIFY_HPP
-#define MIGRAPH_GUARD_VERIFY_HPP
+#ifndef MIGRAPHX_GUARD_VERIFY_HPP
+#define MIGRAPHX_GUARD_VERIFY_HPP
 #include <algorithm>
 #include <cmath>
...
@@ -7,11 +7,11 @@
 #include <iostream>
 #include <numeric>
-#include <migraph/float_equal.hpp>
-#include <migraph/config.hpp>
+#include <migraphx/float_equal.hpp>
+#include <migraphx/config.hpp>

-namespace migraph {
-inline namespace MIGRAPH_INLINE_NS {
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {

 // Compute the value of a range
 template <class R>
...
@@ -173,6 +173,6 @@ bool verify_range(R1&& r1, R2&& r2, double tolerance = 80, double* out_error = n
     return error <= threshold;
 }

-} // namespace MIGRAPH_INLINE_NS
-} // namespace migraph
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
 #endif
src/include/migraph/verify_args.hpp → src/include/migraphx/verify_args.hpp

-#ifndef MIGRAPH_GUARD_RTGLIB_VERIFY_ARGS_HPP
-#define MIGRAPH_GUARD_RTGLIB_VERIFY_ARGS_HPP
+#ifndef MIGRAPHX_GUARD_RTGLIB_VERIFY_ARGS_HPP
+#define MIGRAPHX_GUARD_RTGLIB_VERIFY_ARGS_HPP
-#include <migraph/verify.hpp>
-#include <migraph/argument.hpp>
-#include <migraph/config.hpp>
+#include <migraphx/verify.hpp>
+#include <migraphx/argument.hpp>
+#include <migraphx/config.hpp>

-namespace migraph {
-inline namespace MIGRAPH_INLINE_NS {
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {

 inline bool verify_args(const std::string& name,
                         const argument& cpu_arg,
...
@@ -84,7 +84,7 @@ inline bool verify_args(const std::string& name,
     return passed;
 }

-} // namespace MIGRAPH_INLINE_NS
-} // namespace migraph
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
 #endif
src/instruction.cpp

-#include <migraph/instruction.hpp>
-#include <migraph/builtin.hpp>
-#include <migraph/erase.hpp>
+#include <migraphx/instruction.hpp>
+#include <migraphx/builtin.hpp>
+#include <migraphx/erase.hpp>

-namespace migraph {
-inline namespace MIGRAPH_INLINE_NS {
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {

 instruction::instruction(operation o, shape r, std::vector<instruction_ref> args)
     : op(std::move(o)), result(std::move(r)), arguments(std::move(args))
...
@@ -70,7 +70,7 @@ bool instruction::valid() const
     {
         computed = compute_shape(op, arguments);
     }
-    catch(migraph::exception&)
+    catch(migraphx::exception&)
     {
         return false;
     }
...
@@ -97,7 +97,7 @@ const std::vector<instruction_ref>& instruction::outputs() const { return output
 bool operator==(const instruction& x, const instruction& y)
 {
-    if(not(x.result == y.result and x.op == y.op and x.arguments == y.arguments))
+    if(std::tie(x.result, x.op, x.arguments) != std::tie(y.result, y.op, y.arguments))
         return false;
     if(x.name() == "@literal")
         return x.lit == y.lit;
...
@@ -162,26 +162,55 @@ void instruction::replace_argument(instruction_ref old, instruction_ref new_ins)
     old->remove_output(*this);
 }

-std::vector<shape> compute_shapes(const std::vector<instruction_ref>& args)
+argument instruction::eval() const
 {
-    std::vector<shape> shapes(args.size());
-    std::transform(args.begin(), args.end(), shapes.begin(), [](instruction_ref i) {
-        return i->get_shape();
-    });
-    return shapes;
+    if(op.name() == "@literal")
+    {
+        return this->get_literal().get_argument();
+    }
+    if(is_context_free(op))
+    {
+        std::vector<argument> args;
+        for(auto&& arg : this->inputs())
+        {
+            argument a = arg->eval();
+            if(a.empty())
+                return {};
+            args.push_back(a);
+        }
+        return op.compute(result, args);
+    }
+    return {};
 }

-instruction_ref instruction::get_output_alias(instruction_ref ins)
+void instruction::finalize(context& ctx)
+{
+    if(has_finalize(this->op))
+        this->op.finalize(ctx, this->get_shape(), to_shapes(this->inputs()));
+}
+
+instruction_ref instruction::get_output_alias(instruction_ref ins, bool shallow)
 {
-    auto i = ins->get_operator().output_alias(compute_shapes(ins->inputs()));
+    auto i = ins->get_operator().output_alias(to_shapes(ins->inputs()));
     if(i < 0)
         return ins;
+    if(shallow)
+        return ins->inputs().at(i);
     return get_output_alias(ins->inputs().at(i));
 }

+std::vector<shape> to_shapes(const std::vector<instruction_ref>& args)
+{
+    std::vector<shape> shapes(args.size());
+    std::transform(args.begin(), args.end(), shapes.begin(), [](instruction_ref i) {
+        return i->get_shape();
+    });
+    return shapes;
+}
+
 shape compute_shape(const operation& op, const std::vector<instruction_ref>& args)
 {
-    return op.compute_shape(compute_shapes(args));
+    return op.compute_shape(to_shapes(args));
 }

-} // namespace MIGRAPH_INLINE_NS
-} // namespace migraph
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
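Note: the new instruction::eval gives the compiler a constant-folding hook. A literal evaluates to its stored argument; a context-free operation whose inputs all evaluate to constants is computed on the spot; anything else yields an empty argument. The ConstantFill parser in onnx.cpp below relies on exactly this to turn a Shape followed by ConstantFill into a literal. A sketch of the call pattern (try_fold is a hypothetical helper, not part of the library):

#include <migraphx/argument.hpp>
#include <migraphx/instruction.hpp>

// Sketch: fold an instruction to a concrete value when possible.
// Returns true and fills `out` only if the whole input chain is constant.
bool try_fold(migraphx::instruction_ref ins, migraphx::argument& out)
{
    migraphx::argument a = ins->eval(); // empty if any input is non-constant
    if(a.empty())
        return false;
    out = a;
    return true;
}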
src/onnx/CMakeLists.txt

@@ -7,35 +7,35 @@ target_compile_options(onnx-proto PRIVATE -w)
 target_link_libraries(onnx-proto PRIVATE ${PROTOBUF_LIBRARY})
 set_target_properties(onnx-proto PROPERTIES POSITION_INDEPENDENT_CODE On)

-add_library(migraph_onnx onnx.cpp)
-set_target_properties(migraph_onnx PROPERTIES EXPORT_NAME onnx)
-rocm_clang_tidy_check(migraph_onnx)
-target_link_libraries(migraph_onnx PRIVATE onnx-proto)
-target_link_libraries(migraph_onnx PUBLIC migraph)
+add_library(migraphx_onnx onnx.cpp)
+set_target_properties(migraphx_onnx PROPERTIES EXPORT_NAME onnx)
+rocm_clang_tidy_check(migraphx_onnx)
+target_link_libraries(migraphx_onnx PRIVATE onnx-proto)
+target_link_libraries(migraphx_onnx PUBLIC migraphx)
 rocm_install_targets(
-  TARGETS migraph_onnx
+  TARGETS migraphx_onnx
 )

 add_executable(read_onnx read_onnx.cpp)
 rocm_clang_tidy_check(read_onnx)
-target_link_libraries(read_onnx migraph_onnx)
+target_link_libraries(read_onnx migraphx_onnx)

-if(MIGRAPH_ENABLE_GPU)
+if(MIGRAPHX_ENABLE_GPU)
 add_executable(mnist mnist.cpp)
 rocm_clang_tidy_check(mnist)
-target_link_libraries(mnist migraph_cpu migraph_gpu migraph_onnx)
+target_link_libraries(mnist migraphx_cpu migraphx_gpu migraphx_onnx)

 add_executable(cifar10 cifar10.cpp)
 rocm_clang_tidy_check(cifar10)
-target_link_libraries(cifar10 migraph_cpu migraph_gpu migraph_onnx)
+target_link_libraries(cifar10 migraphx_cpu migraphx_gpu migraphx_onnx)

 add_executable(verify_onnx verify_onnx.cpp)
 rocm_clang_tidy_check(verify_onnx)
-target_link_libraries(verify_onnx migraph_onnx migraph_cpu migraph_gpu)
+target_link_libraries(verify_onnx migraphx_onnx migraphx_cpu migraphx_gpu)

 add_executable(perf_onnx perf_onnx.cpp)
 rocm_clang_tidy_check(perf_onnx)
-target_link_libraries(perf_onnx migraph_onnx migraph_cpu migraph_gpu)
+target_link_libraries(perf_onnx migraphx_onnx migraphx_cpu migraphx_gpu)
 endif()
src/onnx/cifar10.cpp

@@ -4,12 +4,12 @@
 #include <numeric>
 #include <stdexcept>
-#include <migraph/onnx.hpp>
+#include <migraphx/onnx.hpp>

-#include <migraph/cpu/target.hpp>
-#include <migraph/gpu/target.hpp>
-#include <migraph/gpu/hip.hpp>
-#include <migraph/generate.hpp>
+#include <migraphx/cpu/target.hpp>
+#include <migraphx/gpu/target.hpp>
+#include <migraphx/gpu/hip.hpp>
+#include <migraphx/generate.hpp>
 #include "softmax.hpp"
...
@@ -53,19 +53,19 @@ int main(int argc, char const* argv[])
     std::string gpu_cpu  = argv[1];
     std::string file     = argv[2];
     std::string datafile = argv[3];
-    auto prog            = migraph::parse_onnx(file);
+    auto prog            = migraphx::parse_onnx(file);
     std::cout << prog << std::endl;
     auto imageset = read_cifar10_images(datafile);

     if(gpu_cpu == "gpu")
     {
         // GPU target
-        prog.compile(migraph::gpu::target{});
-        migraph::program::parameter_map m;
-        auto s = migraph::shape{migraph::shape::float_type, {1, 3, 32, 32}};
+        prog.compile(migraphx::gpu::target{});
+        migraphx::program::parameter_map m;
+        auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 32, 32}};
         for(auto&& x : prog.get_parameter_shapes())
         {
-            m[x.first] = migraph::gpu::to_gpu(migraph::generate_argument(x.second));
+            m[x.first] = migraphx::gpu::to_gpu(migraphx::generate_argument(x.second));
         }
         auto labels = imageset.first;
         auto input  = imageset.second;
...
@@ -73,8 +73,8 @@ int main(int argc, char const* argv[])
         for(int i = 0; i < 10; i++)
         {
             std::cout << "label: " << static_cast<uint32_t>(labels[i]) << " ----> ";
-            m["0"]      = migraph::gpu::to_gpu(migraph::argument{s, &ptr[3072 * i]});
-            auto result = migraph::gpu::from_gpu(prog.eval(m));
+            m["0"]      = migraphx::gpu::to_gpu(migraphx::argument{s, &ptr[3072 * i]});
+            auto result = migraphx::gpu::from_gpu(prog.eval(m));
             std::vector<float> logits;
             result.visit([&](auto output) { logits.assign(output.begin(), output.end()); });
             std::vector<float> probs = softmax<float>(logits);
...
@@ -86,15 +86,15 @@ int main(int argc, char const* argv[])
     else
     {
         // CPU target
-        prog.compile(migraph::cpu::target{});
-        auto s = migraph::shape{migraph::shape::float_type, {1, 3, 32, 32}};
+        prog.compile(migraphx::cpu::target{});
+        auto s = migraphx::shape{migraphx::shape::float_type, {1, 3, 32, 32}};
         auto labels = imageset.first;
         auto input  = imageset.second;
         auto ptr    = input.data();
         for(int i = 0; i < 10; i++)
         {
             std::cout << "label: " << static_cast<uint32_t>(labels[i]) << " ----> ";
-            auto input3 = migraph::argument{s, &ptr[3072 * i]};
+            auto input3 = migraphx::argument{s, &ptr[3072 * i]};
             auto result = prog.eval({{"0", input3}});
             std::vector<float> logits;
             result.visit([&](auto output) { logits.assign(output.begin(), output.end()); });
...
src/onnx/mnist.cpp

@@ -4,17 +4,20 @@
 #include <numeric>
 #include <stdexcept>
-#include <migraph/onnx.hpp>
+#include <migraphx/onnx.hpp>

-#include <migraph/gpu/target.hpp>
-#include <migraph/gpu/hip.hpp>
-#include <migraph/generate.hpp>
+#include <migraphx/gpu/target.hpp>
+#include <migraphx/gpu/hip.hpp>
+#include <migraphx/generate.hpp>
 #include "softmax.hpp"

 auto reverse_int(unsigned int i)
 {
-    unsigned char c1, c2, c3, c4;
+    unsigned char c1;
+    unsigned char c2;
+    unsigned char c3;
+    unsigned char c4;

     c1 = i & 255u;
     c2 = (i >> 8u) & 255u;
     c3 = (i >> 16u) & 255u;
...
@@ -32,7 +35,9 @@ read_mnist_images(const std::string& full_path, int& number_of_images, int& image
     if(file.is_open())
     {
-        int magic_number = 0, n_rows = 0, n_cols = 0;
+        int magic_number = 0;
+        int n_rows       = 0;
+        int n_cols       = 0;

         file.read(reinterpret_cast<char*>(&magic_number), sizeof(magic_number));
         magic_number = reverse_int(magic_number);
...
@@ -113,20 +118,20 @@ int main(int argc, char const* argv[])
     std::vector<int32_t> labels = read_mnist_labels(labelfile, nlabels);

     std::string file = argv[1];
-    auto prog        = migraph::parse_onnx(file);
+    auto prog        = migraphx::parse_onnx(file);
     std::cout << prog << std::endl << std::endl;
-    prog.compile(migraph::gpu::target{});
-    auto s = migraph::shape{migraph::shape::float_type, {1, 1, 28, 28}};
+    prog.compile(migraphx::gpu::target{});
+    auto s = migraphx::shape{migraphx::shape::float_type, {1, 1, 28, 28}};
     std::cout << s << std::endl;

     auto ptr = input.data();
-    migraph::program::parameter_map m;
+    migraphx::program::parameter_map m;
     m["output"] =
-        migraph::gpu::to_gpu(migraph::generate_argument(prog.get_parameter_shape("output")));
+        migraphx::gpu::to_gpu(migraphx::generate_argument(prog.get_parameter_shape("output")));
     for(int i = 0; i < 20; i++)
     {
         std::cout << "label: " << labels[i] << " ----> ";
-        m["0"]      = migraph::gpu::to_gpu(migraph::argument{s, &ptr[784 * i]});
-        auto result = migraph::gpu::from_gpu(prog.eval(m));
+        m["0"]      = migraphx::gpu::to_gpu(migraphx::argument{s, &ptr[784 * i]});
+        auto result = migraphx::gpu::from_gpu(prog.eval(m));
         std::vector<float> logits;
         result.visit([&](auto output) { logits.assign(output.begin(), output.end()); });
         std::vector<float> probs = softmax(logits);
...
src/onnx/onnx.cpp

...
@@ -9,59 +9,65 @@
 #include <utility>
 #include <vector>
-#include <migraph/fallthrough.hpp>
-#include <migraph/program.hpp>
-#include <migraph/operators.hpp>
-#include <migraph/ranges.hpp>
-#include <migraph/instruction.hpp>
-#include <migraph/config.hpp>
-
-namespace migraph {
-inline namespace MIGRAPH_INLINE_NS {
-
-struct unknown
-{
-    std::string op;
-    std::string name() const { return "unknown:" + op; }
-    shape compute_shape(std::vector<shape> input) const
-    {
-        if(input.empty())
-            return {};
-        else
-            return input.front();
-    }
-    friend std::ostream& operator<<(std::ostream& os, const unknown& x)
-    {
-        os << x.name();
-        return os;
-    }
-};
+#include <migraphx/fallthrough.hpp>
+#include <migraphx/program.hpp>
+#include <migraphx/operators.hpp>
+#include <migraphx/ranges.hpp>
+#include <migraphx/instruction.hpp>
+#include <migraphx/config.hpp>
+#include <migraphx/onnx.hpp>
+
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {

 struct onnx_parser
 {
     using attribute_map = std::unordered_map<std::string, onnx::AttributeProto>;
     using node_map      = std::unordered_map<std::string, onnx::NodeProto>;
-    using op_func = std::function<instruction_ref(attribute_map, std::vector<instruction_ref>)>;
+    using op_func =
+        std::function<std::vector<instruction_ref>(attribute_map, std::vector<instruction_ref>)>;

     node_map nodes;
     std::unordered_map<std::string, instruction_ref> instructions;
-    program prog = program();
+    program prog    = program();
+    bool is_pytorch = false;

     std::unordered_map<std::string, op_func> ops;
+    std::unordered_map<std::string, operation> map_actv_funcs;

     onnx_parser()
     {
         add_generic_op("MatMul", op::dot{});
         add_generic_op("Relu", op::relu{});
         add_generic_op("Sigmoid", op::sigmoid{});
         add_generic_op("Abs", op::abs{});
         add_generic_op("Exp", op::exp{});
         add_generic_op("Log", op::log{});
         // disable dropout for inference
         add_generic_op("Dropout", op::identity{});
         add_generic_op("Identity", op::identity{});
         add_generic_op("Sin", op::sin{});
         add_generic_op("Cos", op::cos{});
         add_generic_op("Tan", op::tan{});
         add_generic_op("Sinh", op::sinh{});
         add_generic_op("Cosh", op::cosh{});
         add_generic_op("Tanh", op::tanh{});
         add_generic_op("Asin", op::asin{});
         add_generic_op("Acos", op::acos{});
         add_generic_op("Atan", op::atan{});
-        add_binary_op("Add", op::add{});
-        add_binary_op("Div", op::div{});
-        add_binary_op("Mul", op::mul{});
-        add_binary_op("Sub", op::sub{});
+        add_broadcastable_binary_op("Add", op::add{});
+        add_broadcastable_binary_op("Div", op::div{});
+        add_broadcastable_binary_op("Mul", op::mul{});
+        add_broadcastable_binary_op("Sub", op::sub{});
-        add_broadcastable_binary_op("Sum", op::add{});
+        add_variadic_op("Sum", op::add{});
+        add_variadic_op("Max", op::max{});
+        add_variadic_op("Min", op::min{});
         add_mem_op("LRN", &onnx_parser::parse_lrn);
         add_mem_op("ImageScaler", &onnx_parser::parse_imagescaler);
         add_mem_op("LeakyRelu", &onnx_parser::parse_leaky_relu);
         add_mem_op("Elu", &onnx_parser::parse_elu);
         add_mem_op("Constant", &onnx_parser::parse_constant);
         add_mem_op("Conv", &onnx_parser::parse_conv);
         add_mem_op("MaxPool", &onnx_parser::parse_pooling);
...
@@ -77,11 +83,38 @@ struct onnx_parser
         add_mem_op("Unsqueeze", &onnx_parser::parse_unsqueeze);
         add_mem_op("Slice", &onnx_parser::parse_slice);
         add_mem_op("Concat", &onnx_parser::parse_concat);
+        add_mem_op("Gather", &onnx_parser::parse_gather);
+        add_mem_op("Shape", &onnx_parser::parse_shape);
+        add_mem_op("ConstantFill", &onnx_parser::parse_constant_fill);
         add_mem_op("Transpose", &onnx_parser::parse_transpose);
+        add_mem_op("RNN", &onnx_parser::parse_rnn);
+        add_mem_op("GRU", &onnx_parser::parse_gru);
+        add_mem_op("Pad", &onnx_parser::parse_pad);
+
+        // init the activation function map
+        init_actv_func();
+    }
+
+    void init_actv_func()
+    {
+        map_actv_funcs.insert(std::make_pair("tanh", op::tanh{}));
+        map_actv_funcs.insert(std::make_pair("relu", op::relu{}));
+        map_actv_funcs.insert(std::make_pair("sigmoid", op::sigmoid{}));
+        map_actv_funcs.insert(std::make_pair("leakyrelu", op::leaky_relu{}));
+        map_actv_funcs.insert(std::make_pair("elu", op::elu{}));
+    }
+
+    template <class F>
+    void add_op(std::string name, F f)
+    {
+        ops.emplace(name, [=](auto&&... xs) {
+            return std::vector<instruction_ref>{f(std::forward<decltype(xs)>(xs)...)};
+        });
+    }
+
+    // Multi output op
+    template <class F>
+    void add_multi_op(std::string name, F f)
+    {
+        ops.emplace(name, f);
     }
...
@@ -89,81 +122,101 @@ struct onnx_parser
     template <class F>
     void add_mem_op(std::string name, F f)
     {
-        ops.emplace(name, [=](auto&&... xs) {
+        add_op(name, [=](auto&&... xs) {
             return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
         });
     }

     template <class T>
-    void add_binary_op(std::string name, T x)
+    void add_broadcastable_binary_op(std::string name, T x)
     {
-        ops.emplace(name, [this, x](attribute_map attributes, std::vector<instruction_ref> args) {
+        add_op(name, [this, x](attribute_map attributes, std::vector<instruction_ref> args) {
             if(args.size() != 2)
-                MIGRAPH_THROW("binary operators should have 2 operands");
-            if(contains(attributes, "broadcast"))
+                MIGRAPHX_THROW("binary operators should have 2 operands");
+            if(contains(attributes, "broadcast") and contains(attributes, "axis"))
             {
                 uint64_t broadcasted = parse_value(attributes.at("broadcast")).at<uint64_t>();
                 if(broadcasted != 0)
                 {
-                    uint64_t axis = (contains(attributes, "axis"))
-                                        ? parse_value(attributes.at("axis")).at<uint64_t>()
-                                        : 0;
+                    uint64_t axis = parse_value(attributes.at("axis")).at<uint64_t>();
                     auto l =
                         prog.add_instruction(op::broadcast{axis, args[0]->get_shape()}, args[1]);
                     return prog.add_instruction(x, args[0], l);
                 }
                 return prog.add_instruction(x, args);
             }
-            else if(args[0]->get_shape() != args[1]->get_shape())
-            {
-                // Example:
-                // s0 = (3,2,4,5) and s1 = (2,1,1)
-                //
-                // In this case we need to broadcast (:,1,1) portion of
-                // s1 plus broadcast the 1st dimension of s1
-                // giving output_lens = (3,2,4,5)
-                //
-                // Another example:
-                // s0 = (3,2,1,5) and s1 = (2,7,5)
-                // In this case we need to broadcast the (:,:,1:,:) axis
-                // of s0 plus the 1st dimension of s1 giving
-                // output_lens = (3,2,7,5)
-                //
-                // Get lengths for both arguments
-                const std::vector<std::size_t>* s0 = &args[0]->get_shape().lens();
-                const std::vector<std::size_t>* s1 = &args[1]->get_shape().lens();
-                // Make sure s0 is the smaller size
-                if(s0->size() > s1->size())
-                    std::swap(s0, s1);
-                // Copy the larger vector to output_lens
-                std::vector<std::size_t> output_lens(s1->size());
-                auto offset = s1->size() - s0->size();
-                std::transform(s0->begin(),
-                               s0->end(),
-                               s1->begin() + offset,
-                               output_lens.begin() + offset,
-                               [](auto a, auto b) { return std::max(a, b); });
-                auto l0 = prog.add_instruction(op::multibroadcast{output_lens}, args[0]);
-                auto l1 = prog.add_instruction(op::multibroadcast{output_lens}, args[1]);
-                return prog.add_instruction(x, l0, l1);
-            }
             else
             {
-                return prog.add_instruction(x, args);
+                return add_broadcastable_binary_op(args[0], args[1], x);
             }
         });
     }

+    template <class T>
+    instruction_ref add_broadcastable_binary_op(instruction_ref arg0, instruction_ref arg1, T x)
+    {
+        if(arg0->get_shape() != arg1->get_shape())
+        {
+            // Example:
+            // s0 = (3,2,4,5) and s1 = (2,1,1)
+            //
+            // In this case we need to broadcast (:,1,1) portion of
+            // s1 plus broadcast the 1st dimension of s1
+            // giving output_lens = (3,2,4,5)
+            //
+            // Another example:
+            // s0 = (3,2,1,5) and s1 = (2,7,5)
+            // In this case we need to broadcast the (:,:,1:,:) axis
+            // of s0 plus the 1st dimension of s1 giving
+            // output_lens = (3,2,7,5)
+            //
+            // Get lengths for both arguments
+            const std::vector<std::size_t>* s0 = &arg0->get_shape().lens();
+            const std::vector<std::size_t>* s1 = &arg1->get_shape().lens();
+            // Make sure s0 is the smaller size
+            if(s0->size() > s1->size())
+                std::swap(s0, s1);
+            std::vector<std::size_t> output_lens(*s1);
+            auto offset = s1->size() - s0->size();
+            std::transform(s0->begin(),
+                           s0->end(),
+                           s1->begin() + offset,
+                           output_lens.begin() + offset,
+                           [](auto a, auto b) { return std::max(a, b); });
+            auto l0 = prog.add_instruction(op::multibroadcast{output_lens}, arg0);
+            auto l1 = prog.add_instruction(op::multibroadcast{output_lens}, arg1);
+            return prog.add_instruction(x, l0, l1);
+        }
+        else
+        {
+            return prog.add_instruction(x, {arg0, arg1});
+        }
+    }
+
     template <class T>
     void add_generic_op(std::string name, T x)
     {
-        ops.emplace(name, [this, x](attribute_map, std::vector<instruction_ref> args) {
+        add_op(name, [this, x](attribute_map, std::vector<instruction_ref> args) {
            return prog.add_instruction(x, args);
        });
    }

+    template <class T>
+    void add_variadic_op(std::string name, T x)
+    {
+        add_op(name, [this, x](attribute_map, std::vector<instruction_ref> args) {
+            return std::accumulate(std::next(args.begin()),
+                                   args.end(),
+                                   args.front(),
+                                   [this, x](instruction_ref a, instruction_ref b) {
+                                       return add_broadcastable_binary_op(a, b, x);
+                                   });
+        });
+    }
+
     instruction_ref
     parse_softmax(const std::string&, const attribute_map&, std::vector<instruction_ref> args)
     {
...
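Note: the refactor moves the NumPy-style shape alignment into a reusable helper so both the legacy ONNX broadcast/axis path and the new variadic ops can share it. The rule, as the comments above describe, is to right-align the two length vectors and take the per-dimension maximum, with size-1 dimensions stretching. A standalone sketch of that computation on plain vectors (broadcast_lens is an illustrative name, not a library function):

#include <algorithm>
#include <cstddef>
#include <vector>

// Sketch of the output-shape rule used by add_broadcastable_binary_op:
// right-align the shorter shape with the longer one and take the
// elementwise max over the overlapping dimensions.
std::vector<std::size_t> broadcast_lens(std::vector<std::size_t> s0,
                                        std::vector<std::size_t> s1)
{
    if(s0.size() > s1.size())
        std::swap(s0, s1);            // s0 is now the shorter shape
    std::vector<std::size_t> out(s1); // start from the longer shape
    auto offset = s1.size() - s0.size();
    std::transform(s0.begin(),
                   s0.end(),
                   s1.begin() + offset,
                   out.begin() + offset,
                   [](auto a, auto b) { return std::max(a, b); });
    return out;
}
// broadcast_lens({2,1,1}, {3,2,4,5}) == {3,2,4,5}
// broadcast_lens({3,2,1,5}, {2,7,5}) == {3,2,7,5}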
@@ -178,9 +231,30 @@ struct onnx_parser
     parse_conv(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
     {
         op::convolution op;
+        auto l0 = args[0];
         if(contains(attributes, "pads"))
         {
-            copy(attributes["pads"].ints(), op.padding.begin());
+            if(contains(attributes, "auto_pad"))
+            {
+                MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously");
+            }
+            std::vector<std::int64_t> padding;
+            copy(attributes["pads"].ints(), std::back_inserter(padding));
+            if(padding.size() != 4)
+            {
+                MIGRAPHX_THROW("padding should have 4 values");
+            }
+            if(padding[0] != padding[2] || padding[1] != padding[3])
+            {
+                // insert zeros for pad op (args[0] has 4 dims)
+                padding = {0, 0, padding[0], padding[1], 0, 0, padding[2], padding[3]};
+                l0      = prog.add_instruction(op::pad{padding}, l0);
+            }
+            else
+            {
+                op.padding[0] = padding[0];
+                op.padding[1] = padding[1];
+            }
         }
         if(contains(attributes, "strides"))
         {
...
@@ -190,6 +264,23 @@ struct onnx_parser
         {
             copy(attributes["dilations"].ints(), op.dilation.begin());
         }
+        if(contains(attributes, "auto_pad"))
+        {
+            auto s = attributes["auto_pad"].s();
+            if(contains(attributes, "pads") and to_upper(s) != "NOTSET")
+            {
+                MIGRAPHX_THROW("auto_pad and padding cannot be specified simultaneously");
+            }
+
+            if(s.find("SAME") != std::string::npos)
+            {
+                op.padding_mode = op::padding_mode_t::same;
+            }
+        }
+        if(contains(attributes, "group"))
+        {
+            op.group = parse_value(attributes.at("group")).at<int>();
+        }
         if(args.size() == 3)
         {
             uint64_t axis = 1;
...
@@ -197,7 +288,7 @@ struct onnx_parser
             auto l2 = prog.add_instruction(op::broadcast{axis, l1->get_shape()}, args[2]);
             return prog.add_instruction(op::add{}, l1, l2);
         }
-        return prog.add_instruction(op, args);
+        return prog.add_instruction(op, l0, args[1]);
     }

     instruction_ref parse_pooling(const std::string& name,
...
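Note: both here and in parse_pooling below, asymmetric ONNX pads (begin-padding different from end-padding) cannot be expressed through the operator's symmetric padding field, so the parser inserts an explicit op::pad on the 4-D input instead; only symmetric pads stay on the convolution or pooling op. The 8-element vector interleaves zeros because the pad op takes begin values for all four NCHW dimensions followed by end values. A sketch of the mapping (to_pad_op_layout is an illustrative name):

#include <cstdint>
#include <vector>

// Sketch: ONNX 2-D pads arrive as {top, left, bottom, right}; the pad op
// used above takes {n,c,h,w} begin values followed by {n,c,h,w} end values.
std::vector<std::int64_t> to_pad_op_layout(const std::vector<std::int64_t>& pads)
{
    // batch and channel dimensions are never padded
    return {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]};
}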
@@ -205,6 +296,7 @@ struct onnx_parser
                                   std::vector<instruction_ref> args)
     {
         op::pooling op{ends_with(name, "MaxPool") ? "max" : "average"};
+        auto l0 = args[0];
         if(starts_with(name, "Global"))
         {
             auto lens = args.front()->get_shape().lens();
...
@@ -212,7 +304,23 @@ struct onnx_parser
         }
         if(contains(attributes, "pads"))
         {
-            copy(attributes["pads"].ints(), op.padding.begin());
+            std::vector<std::int64_t> padding;
+            copy(attributes["pads"].ints(), std::back_inserter(padding));
+            if(padding.size() != 4)
+            {
+                MIGRAPHX_THROW("padding should have 4 values");
+            }
+            if(padding[0] != padding[2] || padding[1] != padding[3])
+            {
+                // insert zeros for pad op (args[0] has 4 dims)
+                padding = {0, 0, padding[0], padding[1], 0, 0, padding[2], padding[3]};
+                l0      = prog.add_instruction(op::pad{padding}, l0);
+            }
+            else
+            {
+                op.padding[0] = padding[0];
+                op.padding[1] = padding[1];
+            }
         }
         if(contains(attributes, "strides"))
         {
...
@@ -222,7 +330,17 @@ struct onnx_parser
         {
             copy(attributes["kernel_shape"].ints(), op.lengths.begin());
         }
-        return prog.add_instruction(op, std::move(args));
+        if(contains(attributes, "auto_pad"))
+        {
+            auto s = attributes["auto_pad"].s();
+            if(s.find("SAME_UPPER") == std::string::npos)
+            {
+                MIGRAPHX_THROW("auto_pad only supports SAME_UPPER for pooling");
+            }
+            op.padding_mode = op::padding_mode_t::same;
+        }
+        return prog.add_instruction(op, l0);
     }

     instruction_ref
...
@@ -245,7 +363,7 @@ struct onnx_parser
     instruction_ref
     parse_flatten(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
     {
-        uint64_t axis = 0;
+        uint64_t axis = 1;
         if(contains(attributes, "axis"))
         {
             axis = parse_value(attributes.at("axis")).at<int>();
...
@@ -279,6 +397,18 @@ struct onnx_parser
         return prog.add_instruction(op, std::move(args));
     }

+    instruction_ref
+    parse_gather(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
+    {
+        int axis = 0;
+        if(contains(attributes, "axis"))
+        {
+            axis = parse_value(attributes.at("axis")).at<int>();
+        }
+
+        op::gather op{axis};
+        return prog.add_instruction(op, std::move(args));
+    }
+
     instruction_ref
     parse_slice(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
     {
...
@@ -311,7 +441,7 @@ struct onnx_parser
     parse_gemm(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
     {
         float alpha = 1.0f;
-        float beta  = 0.0f;
+        float beta  = 1.0f;
         bool transa = false;
         bool transb = false;
         if(contains(attributes, "alpha"))
...
@@ -320,7 +450,7 @@ struct onnx_parser
         }
         if(contains(attributes, "beta"))
         {
-            alpha = parse_value(attributes.at("beta")).at<float>();
+            beta = parse_value(attributes.at("beta")).at<float>();
         }
         if(contains(attributes, "transA"))
         {
...
@@ -335,10 +465,20 @@ struct onnx_parser
         auto l2 = (transb) ? prog.add_instruction(op::transpose{perm}, args[1]) : args[1];
         if(args.size() == 3)
         {
-            uint64_t axis = 1;
-            auto l3       = prog.add_instruction(op::dot{alpha, beta}, l1, l2);
-            auto l4       = prog.add_instruction(op::broadcast{axis, l3->get_shape()}, args[2]);
-            return prog.add_instruction(op::add{}, l3, l4);
+            if(beta != 0.f)
+            {
+                auto l3 = prog.add_instruction(op::dot{alpha}, l1, l2);
+                auto l4 = args[2];
+                if(l4->get_shape().scalar())
+                    // ignore args[2] (no C value added to alpha*A*B)
+                    return l3;
+                if(beta != 1.f)
+                {
+                    auto beta_val = prog.add_literal(beta);
+                    auto l5 = prog.add_instruction(op::scalar{args[2]->get_shape()}, beta_val);
+                    l4      = prog.add_instruction(op::mul{}, args[2], l5);
+                }
+                return add_broadcastable_binary_op(l3, l4, op::add{});
+            }
         }
         return prog.add_instruction(op::dot{alpha, beta}, l1, l2);
     }
...
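Note: the beta changes bring parse_gemm in line with the ONNX Gemm specification, which computes Y = alpha*A*B + beta*C with both alpha and beta defaulting to 1.0; the old code defaulted beta to 0 and, due to a typo, overwrote alpha whenever a beta attribute was present. The new path folds beta into C before the add, as this sketch of the per-element arithmetic shows:

// Sketch of the scalar arithmetic the new instruction graph computes:
//   l3 = dot{alpha}(A, B)                -> alpha * (A @ B)
//   l4 = (beta == 1) ? C : C * beta      (beta broadcast to C's shape)
//   Y  = l3 + l4                         (multibroadcast if shapes differ)
// When beta == 0 the C term is dropped and plain dot{alpha, beta} is emitted.
float gemm_element(float ab, float c, float alpha, float beta)
{
    return alpha * ab + beta * c; // per-element result of ONNX Gemm
}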
@@ -386,6 +526,37 @@ struct onnx_parser
         return prog.add_instruction(op, args.front());
     }

+    instruction_ref
+    parse_elu(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
+    {
+        float alpha = 1.0; // default alpha val for elu
+        if(contains(attributes, "alpha"))
+        {
+            alpha = parse_value(attributes.at("alpha")).at<float>();
+        }
+
+        op::elu op{alpha};
+        return prog.add_instruction(op, args.front());
+    }
+
+    instruction_ref
+    parse_lrn(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
+    {
+        float alpha = 0.0001;
+        float beta  = 0.75;
+        float bias  = 1.0;
+        int size    = 1;
+        if(contains(attributes, "alpha"))
+            alpha = parse_value(attributes.at("alpha")).at<float>();
+        if(contains(attributes, "beta"))
+            beta = parse_value(attributes.at("beta")).at<float>();
+        if(contains(attributes, "bias"))
+            bias = parse_value(attributes.at("bias")).at<float>();
+        if(contains(attributes, "size"))
+            size = parse_value(attributes.at("size")).at<int>();
+
+        op::lrn op{alpha, beta, bias, size};
+        return prog.add_instruction(op, args.front());
+    }
+
     instruction_ref
     parse_imagescaler(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
...
@@ -406,12 +577,12 @@ struct onnx_parser
         auto scale_val = prog.add_literal(scale);
         auto bias_vals = prog.add_literal(
-            migraph::literal{migraph::shape{migraph::shape::float_type, {bias.size()}}, bias});
+            migraphx::literal{migraphx::shape{migraphx::shape::float_type, {bias.size()}}, bias});

-        auto scale_tensor = prog.add_instruction(migraph::op::scalar{input_shape}, scale_val);
-        auto img_scaled   = prog.add_instruction(migraph::op::mul{}, args.front(), scale_tensor);
-        auto bias_bcast   = prog.add_instruction(migraph::op::broadcast{1, input_shape}, bias_vals);
-        return prog.add_instruction(migraph::op::add{}, img_scaled, bias_bcast);
+        auto scale_tensor = prog.add_instruction(migraphx::op::scalar{input_shape}, scale_val);
+        auto img_scaled   = prog.add_instruction(migraphx::op::mul{}, args.front(), scale_tensor);
+        auto bias_bcast   = prog.add_instruction(migraphx::op::broadcast{1, input_shape}, bias_vals);
+        return prog.add_instruction(migraphx::op::add{}, img_scaled, bias_bcast);
     }

     instruction_ref
...
@@ -423,7 +594,330 @@ struct onnx_parser
auto
&&
perm_vals
=
attributes
[
"perm"
].
ints
();
perm
=
std
::
vector
<
int64_t
>
(
perm_vals
.
begin
(),
perm_vals
.
end
());
}
return
prog
.
add_instruction
(
migraph
::
op
::
transpose
{
perm
},
args
.
front
());
return
prog
.
add_instruction
(
migraphx
::
op
::
transpose
{
perm
},
args
.
front
());
}
instruction_ref
parse_pad
(
const
std
::
string
&
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
std
::
vector
<
int64_t
>
pads
{};
float
value
=
0.0
f
;
if
(
contains
(
attributes
,
"pads"
))
{
auto
&&
pad_vals
=
attributes
[
"pads"
].
ints
();
pads
=
std
::
vector
<
int64_t
>
(
pad_vals
.
begin
(),
pad_vals
.
end
());
}
if
(
contains
(
attributes
,
"value"
))
{
value
=
parse_value
(
attributes
.
at
(
"value"
)).
at
<
float
>
();
}
if
(
contains
(
attributes
,
"mode"
))
{
auto
mode
=
attributes
.
at
(
"mode"
).
s
();
if
(
mode
!=
"constant"
)
MIGRAPHX_THROW
(
"migraphx currently only supports constant padding"
);
}
return
prog
.
add_instruction
(
migraphx
::
op
::
pad
{
pads
,
value
},
args
.
front
());
}
// Use a literal instruction to replace the shape since, output of
// shape operator are literals in migraphx
instruction_ref
parse_shape
(
const
std
::
string
&
,
const
attribute_map
&
,
std
::
vector
<
instruction_ref
>
args
)
{
if
(
args
.
size
()
!=
1
)
MIGRAPHX_THROW
(
"Shape: operator should have 1 operand"
);
std
::
vector
<
std
::
size_t
>
arg_shape
=
args
[
0
]
->
get_shape
().
lens
();
std
::
vector
<
int64_t
>
vec_shape
(
arg_shape
.
size
());
migraphx
::
shape
s
(
migraphx
::
shape
::
int64_type
,
{
arg_shape
.
size
()});
std
::
transform
(
arg_shape
.
begin
(),
arg_shape
.
end
(),
vec_shape
.
begin
(),
[](
auto
i
)
{
return
int64_t
(
i
);
});
return
prog
.
add_literal
(
migraphx
::
literal
{
s
,
vec_shape
});
}
// Use a literal instruction to replace the constantFill operator. In RNN, input shape
// and value are fixed, so no need to do the actual computation for the constantFill
// operator
instruction_ref
parse_constant_fill
(
const
std
::
string
&
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
int
input_as_shape
=
0
;
int
dtype
=
1
;
float
value
=
0.0
f
;
if
(
contains
(
attributes
,
"dtype"
))
{
dtype
=
parse_value
(
attributes
.
at
(
"dtype"
)).
at
<
int
>
();
}
migraphx
::
shape
::
type_t
type
=
get_type
(
dtype
);
if
(
contains
(
attributes
,
"input_as_shape"
))
{
input_as_shape
=
parse_value
(
attributes
.
at
(
"input_as_shape"
)).
at
<
int
>
();
}
if
(
contains
(
attributes
,
"value"
))
{
value
=
parse_value
(
attributes
.
at
(
"value"
)).
at
<
float
>
();
}
if
(
contains
(
attributes
,
"extra_shape"
))
{
MIGRAPHX_THROW
(
"ConstantFill: cannot handle extra shape attribute"
);
}
if
(
input_as_shape
==
1
)
{
if
(
args
.
size
()
!=
1
)
{
MIGRAPHX_THROW
(
"ConstantFill: need an input argument as output shape"
);
}
if
(
contains
(
attributes
,
"shape"
))
{
MIGRAPHX_THROW
(
"ConstantFill: cannot set the shape argument and pass in an input "
"at the same time"
);
}
migraphx
::
argument
in
=
args
[
0
]
->
eval
();
if
(
in
.
empty
())
{
MIGRAPHX_THROW
(
"ConstantFill: cannot handle dynamic shape as input"
);
}
std
::
vector
<
std
::
size_t
>
dims
;
in
.
visit
([
&
](
auto
input
)
{
dims
.
assign
(
input
.
begin
(),
input
.
end
());
});
migraphx
::
shape
s
(
type
,
dims
);
std
::
vector
<
float
>
values
(
s
.
elements
(),
value
);
return
prog
.
add_literal
(
migraphx
::
literal
(
s
,
values
));
}
else
if
(
input_as_shape
==
0
)
{
if
(
!
contains
(
attributes
,
"shape"
))
{
MIGRAPHX_THROW
(
"ConstantFill: attribute output shape is needed"
);
}
literal
ls
=
parse_value
(
attributes
.
at
(
"shape"
));
std
::
vector
<
std
::
size_t
>
dims
;
ls
.
visit
([
&
](
auto
s
)
{
dims
.
assign
(
s
.
begin
(),
s
.
end
());
});
migraphx
::
shape
s
{
type
,
dims
};
std
::
vector
<
float
>
values
(
s
.
elements
(),
value
);
return
prog
.
add_literal
(
migraphx
::
literal
(
s
,
values
));
}
else
{
MIGRAPHX_THROW
(
"ConstantFill: wrong value of attribute input_as_shape"
);
}
}
std::vector<instruction_ref>
parse_rnn(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
    migraphx::shape input_shape = args[0]->get_shape();
    std::size_t hidden_size     = args[1]->get_shape().lens()[1];

    if(contains(attributes, "hidden_size"))
    {
        std::size_t hidden_size_att = parse_value(attributes.at("hidden_size")).at<int>();
        if(hidden_size != hidden_size_att)
        {
            MIGRAPHX_THROW("RNN: hidden size mismatch in input and attribute");
        }
    }

    // Handling of direction to be added later
    std::string direction{"forward"};
    if(contains(attributes, "direction"))
    {
        direction = attributes.at("direction").s();
    }

    op::rnn_direction dirct = op::rnn_direction::forward;
    if(direction == "bidirectional")
    {
        dirct = op::rnn_direction::bidirectional;
    }
    else if(direction == "reverse")
    {
        dirct = op::rnn_direction::reverse;
    }

    std::vector<std::string> vec_names{"tanh"};
    if(contains(attributes, "activations"))
    {
        auto names = attributes.at("activations").strings();
        vec_names.clear();
        for_each(names.begin(), names.end(), [&](auto& fn) { vec_names.push_back(fn); });
    }

    for_each(vec_names.begin(), vec_names.end(), [&](auto& fn) {
        if(map_actv_funcs.count(fn) == 0)
        {
            MIGRAPHX_THROW("RNN: activation function " + std::string(fn) + " not supported");
        }
    });

    // The bidirectional case needs two activation functions: one for the
    // forward pass and one for the reverse pass. If only one activation
    // function is provided, use it in both directions.
    if(dirct == op::rnn_direction::bidirectional)
    {
        if(vec_names.size() == 1)
        {
            vec_names.push_back(vec_names.at(0));
        }
    }

    std::vector<operation> vec_actv_funcs(vec_names.size());
    std::transform(vec_names.begin(), vec_names.end(), vec_actv_funcs.begin(), [&](auto& fn) {
        return map_actv_funcs[fn];
    });

    // To be added later
    float clip = 0.0;
    if(contains(attributes, "clip"))
    {
        clip = parse_value(attributes.at("clip")).at<float>();
    }

    // If fewer than 6 arguments are given, append the undefined operator
    // so there are always 6 arguments.
    if(args.size() < 6)
    {
        auto ins = prog.add_instruction(op::undefined{});
        args.insert(args.end(), (6 - args.size()), ins);
    }

    // first output: the concatenation of the hidden states
    auto hidden_states =
        prog.add_instruction(op::rnn{hidden_size, vec_actv_funcs, dirct, clip}, std::move(args));

    // second output: the last hidden state
    auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);

    return {hidden_states, last_output};
}
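The activation handling above can be distilled into a small, self-contained rule: default to "tanh", and let a bidirectional RNN that names only one function reuse it for both directions. A sketch under those assumptions (normalize_rnn_activations is a hypothetical helper, not a function in this file), using only the standard library:

#include <cassert>
#include <string>
#include <vector>

std::vector<std::string> normalize_rnn_activations(std::vector<std::string> names,
                                                   bool bidirectional)
{
    if(names.empty())
        names.push_back("tanh"); // default when no "activations" attribute is given
    if(bidirectional && names.size() == 1)
        names.push_back(names.at(0)); // same function for forward and reverse
    return names;
}

int main()
{
    using v = std::vector<std::string>;
    assert(normalize_rnn_activations({}, true) == (v{"tanh", "tanh"}));
    assert(normalize_rnn_activations({"relu"}, true) == (v{"relu", "relu"}));
    assert(normalize_rnn_activations({"sigmoid"}, false) == (v{"sigmoid"}));
    return 0;
}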
std::vector<instruction_ref>
parse_gru(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
    migraphx::shape input_shape = args[0]->get_shape();
    std::size_t hidden_size     = args[2]->get_shape().lens()[2];

    if(contains(attributes, "hidden_size"))
    {
        std::size_t hidden_size_att = parse_value(attributes.at("hidden_size")).at<int>();
        if(hidden_size != hidden_size_att)
        {
            MIGRAPHX_THROW("GRU: hidden size mismatch in input and attribute");
        }
    }

    // Handling of direction to be added later
    std::string direction{"forward"};
    if(contains(attributes, "direction"))
    {
        direction = attributes.at("direction").s();
    }

    op::rnn_direction dirct = op::rnn_direction::forward;
    if(direction == "bidirectional")
    {
        dirct = op::rnn_direction::bidirectional;
    }
    else if(direction == "reverse")
    {
        dirct = op::rnn_direction::reverse;
    }

    std::vector<std::string> vec_names = {"sigmoid", "tanh"};
    if(contains(attributes, "activations"))
    {
        auto names = attributes.at("activations").strings();
        vec_names.clear();
        vec_names.resize(names.size());
        std::transform(names.begin(), names.end(), vec_names.begin(), [](auto& str) {
            return str;
        });
    }

    // need 4 activation functions
    if(dirct == op::rnn_direction::bidirectional)
    {
        // 4 activation functions are used in the bidirectional scenario, but
        // the ONNX operator spec does not say how to expand a shorter list.
        // The algorithm used here: if 1 activation function is provided,
        // repeat it four times. If 2 activation functions are provided,
        // assume forward and reverse use the same pair. If 3 activation
        // functions are provided, assume the 3rd one is repeated once and
        // used by the reverse direction.
        // This may need to change later.
        if(vec_names.size() == 1)
        {
            vec_names.insert(vec_names.end(), 3, vec_names.at(0));
        }
        else if(vec_names.size() == 2)
        {
            // repeat the pair of activation functions
            vec_names.push_back(vec_names.at(0));
            vec_names.push_back(vec_names.at(1));
        }
        else if(vec_names.size() == 3)
        {
            vec_names.push_back(vec_names.at(2));
        }
    }
    else
    {
        if(vec_names.size() == 1)
        {
            vec_names.push_back(vec_names.at(0));
        }
    }

    for_each(vec_names.begin(), vec_names.end(), [&](auto& name) {
        if(map_actv_funcs.count(name) == 0)
        {
            MIGRAPHX_THROW("GRU: activation function " + std::string(name) + " not supported");
        }
    });

    std::vector<operation> vec_actv_funcs(vec_names.size());
    std::transform(vec_names.begin(), vec_names.end(), vec_actv_funcs.begin(), [&](auto& name) {
        return map_actv_funcs[name];
    });

    float clip = 0.0;
    if(contains(attributes, "clip"))
    {
        clip = parse_value(attributes.at("clip")).at<float>();
    }

    int linear_before_reset = 0;
    if(contains(attributes, "linear_before_reset"))
    {
        linear_before_reset = parse_value(attributes.at("linear_before_reset")).at<int>();
    }

    // append the undefined operator to make 6 arguments
    if(args.size() < 6)
    {
        auto ins = prog.add_instruction(op::undefined{});
        args.insert(args.end(), 6 - args.size(), ins);
    }

    // first output: the concatenation of the hidden states
    auto hidden_states = prog.add_instruction(
        op::gru{hidden_size, vec_actv_funcs, dirct, clip, linear_before_reset},
        std::move(args));

    // second output: the last gru output
    auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);

    return {hidden_states, last_output};
}
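The expansion rule described in the comment above is easiest to check in isolation. A sketch with the same branch structure (expand_gru_activations is a hypothetical stand-in for the inline logic, not a function in this file):

#include <cassert>
#include <string>
#include <vector>

std::vector<std::string> expand_gru_activations(std::vector<std::string> names,
                                                bool bidirectional)
{
    if(names.empty())
        names = {"sigmoid", "tanh"}; // defaults used above
    if(bidirectional)
    {
        if(names.size() == 1)
            names.insert(names.end(), 3, names.at(0)); // repeat it four times
        else if(names.size() == 2)
        {
            names.push_back(names.at(0)); // reverse reuses the forward pair
            names.push_back(names.at(1));
        }
        else if(names.size() == 3)
            names.push_back(names.at(2)); // repeat the third for reverse
    }
    else if(names.size() == 1)
        names.push_back(names.at(0));
    return names;
}

int main()
{
    using v = std::vector<std::string>;
    assert(expand_gru_activations({"relu"}, true) == (v{"relu", "relu", "relu", "relu"}));
    assert(expand_gru_activations({"sigmoid", "tanh"}, true) ==
           (v{"sigmoid", "tanh", "sigmoid", "tanh"}));
    assert(expand_gru_activations({"a", "b", "c"}, true) == (v{"a", "b", "c", "c"}));
    return 0;
}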
void parse_from(std::istream& is)
...

@@ -438,7 +932,7 @@ struct onnx_parser
        }
        else
        {
-            throw std::runtime_error("Failed reading");
+            MIGRAPHX_THROW("Failed reading onnx file.");
        }
    }
...

@@ -468,14 +962,20 @@ struct onnx_parser
        }
        for(auto&& p : nodes)
        {
-            this->parse_node(get_name(p.second));
+            this->parse_node(p.first);
        }
    }
    void parse_undefined(const std::string& name)
    {
        auto ins           = prog.add_instruction(op::undefined{});
        instructions[name] = ins;
    }
    void parse_node(const std::string& name)
    {
        if(name.empty())
-            MIGRAPH_THROW("Onnx node must have a name");
+            MIGRAPHX_THROW("Onnx node must have a name");
        if(instructions.count(name) == 0)
        {
            auto&& node = nodes.at(name);
...
...

@@ -484,23 +984,37 @@ struct onnx_parser
            {
                if(nodes.count(input) > 0)
                {
-                    auto&& iname = get_name(nodes.at(input));
-                    assert(name != iname);
-                    this->parse_node(iname);
-                    args.push_back(instructions.at(iname));
+                    assert(name != input);
+                    this->parse_node(input);
                }
-                else
+                else if(input.empty())
                {
-                    args.push_back(instructions.at(input));
+                    this->parse_undefined(input);
                }
+                args.push_back(instructions.at(input));
            }
+            std::vector<instruction_ref> result;
            if(ops.count(node.op_type()) == 0)
            {
-                instructions[name] = prog.add_instruction(unknown{node.op_type()}, args);
+                result.push_back(prog.add_instruction(unknown{node.op_type()}, args));
            }
            else
            {
-                instructions[name] = ops[node.op_type()](get_attributes(node), args);
+                result = ops[node.op_type()](get_attributes(node), args);
            }
+            // Even nodes with no outputs produce output in migraphx
+            if(node.output().empty() and result.size() == 1)
+            {
+                instructions[name] = result.front();
+            }
+            else
+            {
+                assert(node.output().size() >= result.size());
+                std::transform(result.begin(),
+                               result.end(),
+                               node.output().begin(),
+                               std::inserter(instructions, instructions.end()),
+                               [](auto&& x, auto&& y) { return std::make_pair(y, x); });
+            }
        }
    }
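The std::transform/std::inserter combination at the end of parse_node pairs the i-th result with the i-th ONNX output name and records the pair in the name-to-instruction map. A sketch with stand-in types (int in place of instruction_ref; names invented), standard library only:

#include <algorithm>
#include <cassert>
#include <iterator>
#include <string>
#include <unordered_map>
#include <vector>

int main()
{
    std::vector<int> result{10, 20};                  // stand-ins for instruction_ref
    std::vector<std::string> outputs{"out0", "out1"}; // stand-in for node.output()
    std::unordered_map<std::string, int> instructions;

    // pair each result with its output name: (name, instruction)
    std::transform(result.begin(),
                   result.end(),
                   outputs.begin(),
                   std::inserter(instructions, instructions.end()),
                   [](auto&& x, auto&& y) { return std::make_pair(y, x); });

    assert(instructions.at("out0") == 10);
    assert(instructions.at("out1") == 20);
    return 0;
}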
...

@@ -515,25 +1029,24 @@ struct onnx_parser
        return result;
    }

    static std::string get_name(const onnx::NodeProto& node)
    {
        if(node.name().empty())
        {
            std::string generated = "migraph_unnamed_node";
            return std::accumulate(node.output().begin(),
                                   node.output().end(),
                                   generated,
                                   [](auto x, auto y) { return x + "_" + y; });
        }
        return node.name();
    }
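The fallback in get_name joins the node's output names onto a fixed prefix with underscores via std::accumulate. A sketch of that one step, with invented output names:

#include <cassert>
#include <numeric>
#include <string>
#include <vector>

int main()
{
    std::vector<std::string> outputs{"gemm_out", "relu_out"}; // invented names
    std::string generated = "migraph_unnamed_node";
    std::string name =
        std::accumulate(outputs.begin(), outputs.end(), generated, [](auto x, auto y) {
            return x + "_" + y;
        });
    assert(name == "migraph_unnamed_node_gemm_out_relu_out");
    return 0;
}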
    static node_map get_nodes(const onnx::GraphProto& graph)
    {
        std::unordered_map<std::string, onnx::NodeProto> result;
        std::size_t n = 0;
        for(auto&& node : graph.node())
        {
-            result[get_name(node)] = node;
+            if(node.output().empty())
+            {
+                if(node.name().empty())
+                {
+                    result["migraphx_unamed_node_" + std::to_string(n)] = node;
+                    n++;
+                }
+                else
+                {
+                    result[node.name()] = node;
+                }
+            }
            for(auto&& output : node.output())
            {
                result[output] = node;
...
@@ -565,12 +1078,17 @@ struct onnx_parser
        case onnx::AttributeProto::TENSORS: return {};
        case onnx::AttributeProto::GRAPHS: return {};
        }
-        MIGRAPH_THROW("Invalid attribute type");
+        MIGRAPHX_THROW("Invalid attribute type");
    }

    static literal parse_tensor(const onnx::TensorProto& t)
    {
        std::vector<std::size_t> dims(t.dims().begin(), t.dims().end());
        // in case of scalar constants in the onnx file, use dims = {1} to fill initializer data
        if(dims.empty())
        {
            dims = {1};
        }
        if(t.has_raw_data())
        {
            const std::string& s = t.raw_data();
...
@@ -593,7 +1111,7 @@ struct onnx_parser
            case onnx::TensorProto::COMPLEX64: throw std::runtime_error("");
            case onnx::TensorProto::COMPLEX128: throw std::runtime_error("");
            }
-            MIGRAPH_THROW("Invalid tensor type");
+            MIGRAPHX_THROW("Invalid tensor type");
        }
        switch(t.data_type())
        {
...
@@ -624,7 +1142,7 @@ struct onnx_parser
        case onnx::TensorProto::COMPLEX64: throw std::runtime_error("");
        case onnx::TensorProto::COMPLEX128: throw std::runtime_error("");
        }
-        MIGRAPH_THROW("Invalid tensor type");
+        MIGRAPHX_THROW("Invalid tensor type");
    }

    static shape parse_type(const onnx::TypeProto& t)
...
@@ -670,6 +1188,28 @@ struct onnx_parser
        });
        return {shape_type, dims};
    }

    shape::type_t get_type(int dtype)
    {
        switch(dtype)
        {
        case 1: return shape::float_type;
        case 2: return shape::uint8_type;
        case 3: return shape::int8_type;
        case 4: return shape::uint16_type;
        case 5: return shape::int16_type;
        case 6: return shape::int32_type;
        case 7: return shape::int64_type;
        case 10: return shape::half_type;
        case 11: return shape::double_type;
        case 12: return shape::uint32_type;
        case 13: return shape::uint64_type;
        default:
        {
            MIGRAPHX_THROW("Prototensor data type " + std::to_string(dtype) + " not supported");
        }
        }
    }
};
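The switch in get_type is a straight mapping from onnx::TensorProto data-type codes to migraphx types; codes with no case (8 string, 9 bool, 14/15 complex) fall through to the throw. The same table written out as data, so the supported codes can be audited at a glance (a sketch, not code from this commit):

#include <cassert>
#include <map>
#include <string>

int main()
{
    // keys follow onnx::TensorProto::DataType codes, values name the migraphx type
    const std::map<int, std::string> supported{{1, "float_type"},
                                               {2, "uint8_type"},
                                               {3, "int8_type"},
                                               {4, "uint16_type"},
                                               {5, "int16_type"},
                                               {6, "int32_type"},
                                               {7, "int64_type"},
                                               {10, "half_type"},
                                               {11, "double_type"},
                                               {12, "uint32_type"},
                                               {13, "uint64_type"}};
    assert(supported.count(8) == 0); // e.g. string tensors are rejected
    assert(supported.at(10) == "half_type");
    return 0;
}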
program parse_onnx(const std::string& name)
...

@@ -693,5 +1233,5 @@ program parse_onnx(const std::string& name)
    return std::move(parser.prog);
}

-} // namespace MIGRAPH_INLINE_NS
-} // namespace migraph
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx