Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
f0a9d415
Commit
f0a9d415
authored
Apr 23, 2018
by
Paul
Browse files
Add onnx reader to analyzer
parent
6dd3cc0e
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
41 additions
and
23 deletions
+41
-23
include/rtg/fallthrough.hpp
include/rtg/fallthrough.hpp
+14
-0
include/rtg/operators.hpp
include/rtg/operators.hpp
+6
-6
onnx/CMakeLists.txt
onnx/CMakeLists.txt
+7
-4
onnx/read_onnx.cpp
onnx/read_onnx.cpp
+14
-13
No files found.
include/rtg/fallthrough.hpp
0 → 100644
View file @
f0a9d415
#ifndef RTG_GUARD_FALLTHROUGH_HPP
#define RTG_GUARD_FALLTHROUGH_HPP
namespace
rtg
{
#ifdef __clang__
#define RTG_FALLTHROUGH [[clang::fallthrough]]
#else
#define RTG_FALLTHROUGH
#endif
}
// namespace rtg
#endif
include/rtg/operators.hpp
View file @
f0a9d415
...
...
@@ -14,9 +14,9 @@ struct not_computable
struct
convolution
{
std
::
array
<
std
::
size_t
,
2
>
padding
=
{
0
,
0
};
std
::
array
<
std
::
size_t
,
2
>
stride
=
{
1
,
1
};
std
::
array
<
std
::
size_t
,
2
>
dilation
=
{
1
,
1
};
std
::
array
<
std
::
size_t
,
2
>
padding
=
{
{
0
,
0
}
}
;
std
::
array
<
std
::
size_t
,
2
>
stride
=
{
{
1
,
1
}
}
;
std
::
array
<
std
::
size_t
,
2
>
dilation
=
{
{
1
,
1
}
}
;
std
::
string
name
()
const
{
return
"convolution[padding={"
+
to_string
(
padding
)
+
"}, stride={"
+
to_string
(
stride
)
+
...
...
@@ -61,9 +61,9 @@ struct convolution
struct
pooling
{
std
::
string
mode
;
std
::
array
<
std
::
size_t
,
2
>
padding
=
{
0
,
0
};
std
::
array
<
std
::
size_t
,
2
>
stride
=
{
1
,
1
};
std
::
array
<
std
::
size_t
,
2
>
lengths
=
{
1
,
1
};
std
::
array
<
std
::
size_t
,
2
>
padding
=
{
{
0
,
0
}
}
;
std
::
array
<
std
::
size_t
,
2
>
stride
=
{
{
1
,
1
}
}
;
std
::
array
<
std
::
size_t
,
2
>
lengths
=
{
{
1
,
1
}
}
;
std
::
string
name
()
const
{
return
"pooling:"
+
mode
+
"[padding={"
+
to_string
(
padding
)
+
"}, stride={"
+
...
...
onnx/CMakeLists.txt
View file @
f0a9d415
find_package
(
Protobuf REQUIRED
)
protobuf_generate_cpp
(
PROTO_SRCS PROTO_HDRS onnx.proto
)
include_directories
(
${
CMAKE_CURRENT_BINARY_DIR
}
)
add_library
(
onnx-proto STATIC
${
PROTO_SRCS
}
)
target_include_directories
(
onnx-proto SYSTEM PUBLIC
${
CMAKE_CURRENT_BINARY_DIR
}
${
PROTOBUF_INCLUDE_DIR
}
)
target_compile_options
(
onnx-proto PRIVATE -w
)
target_link_libraries
(
onnx-proto PRIVATE
${
PROTOBUF_LIBRARY
}
)
add_executable
(
read_onnx read_onnx.cpp
${
PROTO_SRCS
}
)
target_include_directories
(
read_onnx PUBLIC
${
PROTOBUF_INCLUDE_DIR
}
)
target_link_libraries
(
read_onnx
${
PROTOBUF_LIBRARY
}
rtg
)
add_executable
(
read_onnx read_onnx.cpp
)
rocm_clang_tidy_check
(
read_onnx
)
target_link_libraries
(
read_onnx
onnx-proto
rtg
)
onnx/read_onnx.cpp
View file @
f0a9d415
...
...
@@ -7,6 +7,7 @@
#include <unordered_map>
#include <functional>
#include <rtg/fallthrough.hpp>
#include <rtg/program.hpp>
#include <rtg/operators.hpp>
...
...
@@ -21,7 +22,7 @@ struct unknown
else
return
input
.
front
();
}
rtg
::
argument
compute
(
std
::
vector
<
rtg
::
argument
>
input
)
const
{
throw
"not computable"
;
}
rtg
::
argument
compute
(
std
::
vector
<
rtg
::
argument
>
)
const
{
RTG_THROW
(
"not computable"
)
;
}
};
template
<
class
C
,
class
T
>
...
...
@@ -84,7 +85,7 @@ struct onnx_parser
}
return
prog
->
add_instruction
(
op
,
args
);
});
add_op
(
"Relu"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
rtg
::
instruction
*>
args
)
{
add_op
(
"Relu"
,
[
this
](
attribute_map
,
std
::
vector
<
rtg
::
instruction
*>
args
)
{
return
prog
->
add_instruction
(
rtg
::
activation
{
"relu"
},
args
);
});
add_op
(
"Reshape"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
rtg
::
instruction
*>
args
)
{
...
...
@@ -126,7 +127,7 @@ struct onnx_parser
nodes
=
get_nodes
(
graph
);
for
(
auto
&&
input
:
graph
.
input
())
{
std
::
string
name
=
input
.
name
();
const
std
::
string
&
name
=
input
.
name
();
// TODO: Get shape of input parameter
rtg
::
shape
s
=
parse_type
(
input
.
type
());
instructions
[
name
]
=
prog
->
add_parameter
(
name
,
s
);
...
...
@@ -254,28 +255,28 @@ struct onnx_parser
static
rtg
::
shape
parse_type
(
const
onnx
::
TypeProto
&
t
)
{
rtg
::
shape
::
type_t
shape_type
;
rtg
::
shape
::
type_t
shape_type
{}
;
switch
(
t
.
tensor_type
().
elem_type
())
{
case
onnx
::
TensorProto
::
UNDEFINED
:
break
;
// throw std::runtime_error("Unsupported type UNDEFINED");
case
onnx
::
TensorProto
::
FLOAT
:
shape_type
=
rtg
::
shape
::
float_type
;
case
onnx
::
TensorProto
::
FLOAT
:
shape_type
=
rtg
::
shape
::
float_type
;
break
;
case
onnx
::
TensorProto
::
UINT8
:
break
;
// throw std::runtime_error("Unsupported type UINT8");
case
onnx
::
TensorProto
::
INT8
:
shape_type
=
rtg
::
shape
::
int8_type
;
case
onnx
::
TensorProto
::
UINT16
:
shape_type
=
rtg
::
shape
::
uint16_type
;
case
onnx
::
TensorProto
::
INT16
:
shape_type
=
rtg
::
shape
::
int16_type
;
case
onnx
::
TensorProto
::
INT32
:
shape_type
=
rtg
::
shape
::
int32_type
;
case
onnx
::
TensorProto
::
INT64
:
shape_type
=
rtg
::
shape
::
int64_type
;
case
onnx
::
TensorProto
::
INT8
:
shape_type
=
rtg
::
shape
::
int8_type
;
break
;
case
onnx
::
TensorProto
::
UINT16
:
shape_type
=
rtg
::
shape
::
uint16_type
;
break
;
case
onnx
::
TensorProto
::
INT16
:
shape_type
=
rtg
::
shape
::
int16_type
;
break
;
case
onnx
::
TensorProto
::
INT32
:
shape_type
=
rtg
::
shape
::
int32_type
;
break
;
case
onnx
::
TensorProto
::
INT64
:
shape_type
=
rtg
::
shape
::
int64_type
;
break
;
case
onnx
::
TensorProto
::
STRING
:
break
;
// throw std::runtime_error("Unsupported type STRING");
case
onnx
::
TensorProto
::
BOOL
:
break
;
// throw std::runtime_error("Unsupported type BOOL");
case
onnx
::
TensorProto
::
FLOAT16
:
break
;
// throw std::runtime_error("Unsupported type FLOAT16");
case
onnx
::
TensorProto
::
DOUBLE
:
shape_type
=
rtg
::
shape
::
double_type
;
case
onnx
::
TensorProto
::
UINT32
:
shape_type
=
rtg
::
shape
::
uint32_type
;
case
onnx
::
TensorProto
::
UINT64
:
shape_type
=
rtg
::
shape
::
uint64_type
;
case
onnx
::
TensorProto
::
DOUBLE
:
shape_type
=
rtg
::
shape
::
double_type
;
break
;
case
onnx
::
TensorProto
::
UINT32
:
shape_type
=
rtg
::
shape
::
uint32_type
;
break
;
case
onnx
::
TensorProto
::
UINT64
:
shape_type
=
rtg
::
shape
::
uint64_type
;
break
;
case
onnx
::
TensorProto
::
COMPLEX64
:
break
;
// throw std::runtime_error("Unsupported type COMPLEX64");
case
onnx
::
TensorProto
::
COMPLEX128
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment