Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
f0a9d415
Commit
f0a9d415
authored
Apr 23, 2018
by
Paul
Browse files
Add onnx reader to analyzer
parent
6dd3cc0e
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
41 additions
and
23 deletions
+41
-23
include/rtg/fallthrough.hpp
include/rtg/fallthrough.hpp
+14
-0
include/rtg/operators.hpp
include/rtg/operators.hpp
+6
-6
onnx/CMakeLists.txt
onnx/CMakeLists.txt
+7
-4
onnx/read_onnx.cpp
onnx/read_onnx.cpp
+14
-13
No files found.
include/rtg/fallthrough.hpp
0 → 100644
View file @
f0a9d415
#ifndef RTG_GUARD_FALLTHROUGH_HPP
#define RTG_GUARD_FALLTHROUGH_HPP
namespace
rtg
{
#ifdef __clang__
#define RTG_FALLTHROUGH [[clang::fallthrough]]
#else
#define RTG_FALLTHROUGH
#endif
}
// namespace rtg
#endif
include/rtg/operators.hpp
View file @
f0a9d415
...
...
@@ -14,9 +14,9 @@ struct not_computable
struct
convolution
{
std
::
array
<
std
::
size_t
,
2
>
padding
=
{
0
,
0
};
std
::
array
<
std
::
size_t
,
2
>
stride
=
{
1
,
1
};
std
::
array
<
std
::
size_t
,
2
>
dilation
=
{
1
,
1
};
std
::
array
<
std
::
size_t
,
2
>
padding
=
{
{
0
,
0
}
}
;
std
::
array
<
std
::
size_t
,
2
>
stride
=
{
{
1
,
1
}
}
;
std
::
array
<
std
::
size_t
,
2
>
dilation
=
{
{
1
,
1
}
}
;
std
::
string
name
()
const
{
return
"convolution[padding={"
+
to_string
(
padding
)
+
"}, stride={"
+
to_string
(
stride
)
+
...
...
@@ -61,9 +61,9 @@ struct convolution
struct
pooling
{
std
::
string
mode
;
std
::
array
<
std
::
size_t
,
2
>
padding
=
{
0
,
0
};
std
::
array
<
std
::
size_t
,
2
>
stride
=
{
1
,
1
};
std
::
array
<
std
::
size_t
,
2
>
lengths
=
{
1
,
1
};
std
::
array
<
std
::
size_t
,
2
>
padding
=
{
{
0
,
0
}
}
;
std
::
array
<
std
::
size_t
,
2
>
stride
=
{
{
1
,
1
}
}
;
std
::
array
<
std
::
size_t
,
2
>
lengths
=
{
{
1
,
1
}
}
;
std
::
string
name
()
const
{
return
"pooling:"
+
mode
+
"[padding={"
+
to_string
(
padding
)
+
"}, stride={"
+
...
...
onnx/CMakeLists.txt
View file @
f0a9d415
find_package
(
Protobuf REQUIRED
)
protobuf_generate_cpp
(
PROTO_SRCS PROTO_HDRS onnx.proto
)
include_directories
(
${
CMAKE_CURRENT_BINARY_DIR
}
)
add_library
(
onnx-proto STATIC
${
PROTO_SRCS
}
)
target_include_directories
(
onnx-proto SYSTEM PUBLIC
${
CMAKE_CURRENT_BINARY_DIR
}
${
PROTOBUF_INCLUDE_DIR
}
)
target_compile_options
(
onnx-proto PRIVATE -w
)
target_link_libraries
(
onnx-proto PRIVATE
${
PROTOBUF_LIBRARY
}
)
add_executable
(
read_onnx read_onnx.cpp
${
PROTO_SRCS
}
)
target_include_directories
(
read_onnx PUBLIC
${
PROTOBUF_INCLUDE_DIR
}
)
target_link_libraries
(
read_onnx
${
PROTOBUF_LIBRARY
}
rtg
)
add_executable
(
read_onnx read_onnx.cpp
)
rocm_clang_tidy_check
(
read_onnx
)
target_link_libraries
(
read_onnx
onnx-proto
rtg
)
onnx/read_onnx.cpp
View file @
f0a9d415
...
...
@@ -7,6 +7,7 @@
#include <unordered_map>
#include <functional>
#include <rtg/fallthrough.hpp>
#include <rtg/program.hpp>
#include <rtg/operators.hpp>
...
...
@@ -21,7 +22,7 @@ struct unknown
else
return
input
.
front
();
}
rtg
::
argument
compute
(
std
::
vector
<
rtg
::
argument
>
input
)
const
{
throw
"not computable"
;
}
rtg
::
argument
compute
(
std
::
vector
<
rtg
::
argument
>
)
const
{
RTG_THROW
(
"not computable"
)
;
}
};
template
<
class
C
,
class
T
>
...
...
@@ -84,7 +85,7 @@ struct onnx_parser
}
return
prog
->
add_instruction
(
op
,
args
);
});
add_op
(
"Relu"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
rtg
::
instruction
*>
args
)
{
add_op
(
"Relu"
,
[
this
](
attribute_map
,
std
::
vector
<
rtg
::
instruction
*>
args
)
{
return
prog
->
add_instruction
(
rtg
::
activation
{
"relu"
},
args
);
});
add_op
(
"Reshape"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
rtg
::
instruction
*>
args
)
{
...
...
@@ -126,7 +127,7 @@ struct onnx_parser
nodes
=
get_nodes
(
graph
);
for
(
auto
&&
input
:
graph
.
input
())
{
std
::
string
name
=
input
.
name
();
const
std
::
string
&
name
=
input
.
name
();
// TODO: Get shape of input parameter
rtg
::
shape
s
=
parse_type
(
input
.
type
());
instructions
[
name
]
=
prog
->
add_parameter
(
name
,
s
);
...
...
@@ -254,28 +255,28 @@ struct onnx_parser
static
rtg
::
shape
parse_type
(
const
onnx
::
TypeProto
&
t
)
{
rtg
::
shape
::
type_t
shape_type
;
rtg
::
shape
::
type_t
shape_type
{}
;
switch
(
t
.
tensor_type
().
elem_type
())
{
case
onnx
::
TensorProto
::
UNDEFINED
:
break
;
// throw std::runtime_error("Unsupported type UNDEFINED");
case
onnx
::
TensorProto
::
FLOAT
:
shape_type
=
rtg
::
shape
::
float_type
;
case
onnx
::
TensorProto
::
FLOAT
:
shape_type
=
rtg
::
shape
::
float_type
;
break
;
case
onnx
::
TensorProto
::
UINT8
:
break
;
// throw std::runtime_error("Unsupported type UINT8");
case
onnx
::
TensorProto
::
INT8
:
shape_type
=
rtg
::
shape
::
int8_type
;
case
onnx
::
TensorProto
::
UINT16
:
shape_type
=
rtg
::
shape
::
uint16_type
;
case
onnx
::
TensorProto
::
INT16
:
shape_type
=
rtg
::
shape
::
int16_type
;
case
onnx
::
TensorProto
::
INT32
:
shape_type
=
rtg
::
shape
::
int32_type
;
case
onnx
::
TensorProto
::
INT64
:
shape_type
=
rtg
::
shape
::
int64_type
;
case
onnx
::
TensorProto
::
INT8
:
shape_type
=
rtg
::
shape
::
int8_type
;
break
;
case
onnx
::
TensorProto
::
UINT16
:
shape_type
=
rtg
::
shape
::
uint16_type
;
break
;
case
onnx
::
TensorProto
::
INT16
:
shape_type
=
rtg
::
shape
::
int16_type
;
break
;
case
onnx
::
TensorProto
::
INT32
:
shape_type
=
rtg
::
shape
::
int32_type
;
break
;
case
onnx
::
TensorProto
::
INT64
:
shape_type
=
rtg
::
shape
::
int64_type
;
break
;
case
onnx
::
TensorProto
::
STRING
:
break
;
// throw std::runtime_error("Unsupported type STRING");
case
onnx
::
TensorProto
::
BOOL
:
break
;
// throw std::runtime_error("Unsupported type BOOL");
case
onnx
::
TensorProto
::
FLOAT16
:
break
;
// throw std::runtime_error("Unsupported type FLOAT16");
case
onnx
::
TensorProto
::
DOUBLE
:
shape_type
=
rtg
::
shape
::
double_type
;
case
onnx
::
TensorProto
::
UINT32
:
shape_type
=
rtg
::
shape
::
uint32_type
;
case
onnx
::
TensorProto
::
UINT64
:
shape_type
=
rtg
::
shape
::
uint64_type
;
case
onnx
::
TensorProto
::
DOUBLE
:
shape_type
=
rtg
::
shape
::
double_type
;
break
;
case
onnx
::
TensorProto
::
UINT32
:
shape_type
=
rtg
::
shape
::
uint32_type
;
break
;
case
onnx
::
TensorProto
::
UINT64
:
shape_type
=
rtg
::
shape
::
uint64_type
;
break
;
case
onnx
::
TensorProto
::
COMPLEX64
:
break
;
// throw std::runtime_error("Unsupported type COMPLEX64");
case
onnx
::
TensorProto
::
COMPLEX128
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment