Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
7ac2d6a5
Commit
7ac2d6a5
authored
Jun 21, 2018
by
Paul
Browse files
Move onnx parser to seperaate cpp file
parent
7e522c7e
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
384 additions
and
352 deletions
+384
-352
src/include/rtg/onnx.hpp
src/include/rtg/onnx.hpp
+12
-0
src/onnx/CMakeLists.txt
src/onnx/CMakeLists.txt
+5
-1
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+359
-0
src/onnx/read_onnx.cpp
src/onnx/read_onnx.cpp
+8
-351
No files found.
src/include/rtg/onnx.hpp
0 → 100644
View file @
7ac2d6a5
#ifndef GUARD_RTGLIB_ONNX_HPP
#define GUARD_RTGLIB_ONNX_HPP
#include <rtg/program.hpp>
namespace
rtg
{
program
parse_onnx
(
const
std
::
string
&
name
);
}
// namespace rtg
#endif
src/onnx/CMakeLists.txt
View file @
7ac2d6a5
...
...
@@ -6,6 +6,10 @@ target_include_directories(onnx-proto SYSTEM PUBLIC ${CMAKE_CURRENT_BINARY_DIR}
target_compile_options
(
onnx-proto PRIVATE -w
)
target_link_libraries
(
onnx-proto PRIVATE
${
PROTOBUF_LIBRARY
}
)
add_library
(
rtg_onnx onnx.cpp
)
rocm_clang_tidy_check
(
rtg_onnx
)
target_link_libraries
(
rtg_onnx onnx-proto rtg
)
add_executable
(
read_onnx read_onnx.cpp
)
rocm_clang_tidy_check
(
read_onnx
)
target_link_libraries
(
read_onnx onnx
-proto rtg
rtg_cpu
)
target_link_libraries
(
read_onnx
rtg_
onnx rtg_cpu
)
src/onnx/onnx.cpp
0 → 100644
View file @
7ac2d6a5
#include <google/protobuf/text_format.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <onnx.pb.h>
#include <iostream>
#include <fstream>
#include <unordered_map>
#include <functional>
#include <array>
#include <rtg/fallthrough.hpp>
#include <rtg/program.hpp>
#include <rtg/operators.hpp>
namespace
rtg
{
struct
unknown
{
std
::
string
op
;
std
::
string
name
()
const
{
return
"unknown:"
+
op
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
{
if
(
input
.
empty
())
return
{};
else
return
input
.
front
();
}
argument
compute
(
shape
,
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
unknown
&
x
)
{
os
<<
x
.
name
();
return
os
;
}
};
template
<
class
C
,
class
T
>
bool
contains
(
C
&&
c
,
T
&&
x
)
{
return
c
.
find
(
x
)
!=
c
.
end
();
}
template
<
class
Range
,
class
Iterator
>
void
copy
(
Range
&&
r
,
Iterator
it
)
{
std
::
copy
(
r
.
begin
(),
r
.
end
(),
it
);
}
struct
onnx_parser
{
using
attribute_map
=
std
::
unordered_map
<
std
::
string
,
onnx
::
AttributeProto
>
;
using
node_map
=
std
::
unordered_map
<
std
::
string
,
onnx
::
NodeProto
>
;
using
op_func
=
std
::
function
<
instruction_ref
(
attribute_map
,
std
::
vector
<
instruction_ref
>
)
>
;
node_map
nodes
;
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
instructions
;
program
prog
=
program
();
std
::
unordered_map
<
std
::
string
,
op_func
>
ops
;
onnx_parser
()
{
add_op
(
"Conv"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
convolution
op
;
if
(
contains
(
attributes
,
"pads"
))
{
copy
(
attributes
[
"pads"
].
ints
(),
op
.
padding
.
begin
());
}
if
(
contains
(
attributes
,
"strides"
))
{
copy
(
attributes
[
"strides"
].
ints
(),
op
.
stride
.
begin
());
}
if
(
contains
(
attributes
,
"dilations"
))
{
copy
(
attributes
[
"dilations"
].
ints
(),
op
.
dilation
.
begin
());
}
return
prog
.
add_instruction
(
op
,
args
);
});
add_op
(
"MatMul"
,
[
this
](
attribute_map
,
std
::
vector
<
instruction_ref
>
args
)
{
return
prog
.
add_instruction
(
gemm
{},
args
);
});
add_op
(
"MaxPool"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
pooling
op
{
"max"
};
// for(auto&& p:attributes) std::cout << p.first << std::endl;
if
(
contains
(
attributes
,
"pads"
))
{
copy
(
attributes
[
"pads"
].
ints
(),
op
.
padding
.
begin
());
}
if
(
contains
(
attributes
,
"strides"
))
{
copy
(
attributes
[
"strides"
].
ints
(),
op
.
stride
.
begin
());
}
if
(
contains
(
attributes
,
"kernel_shape"
))
{
copy
(
attributes
[
"kernel_shape"
].
ints
(),
op
.
lengths
.
begin
());
}
return
prog
.
add_instruction
(
op
,
args
);
});
add_op
(
"Relu"
,
[
this
](
attribute_map
,
std
::
vector
<
instruction_ref
>
args
)
{
return
prog
.
add_instruction
(
activation
{
"relu"
},
args
);
});
add_op
(
"Reshape"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
reshape
op
;
literal
s
=
parse_value
(
attributes
.
at
(
"shape"
));
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
op
.
dims
));
});
return
prog
.
add_instruction
(
op
,
args
);
});
add_op
(
"Constant"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
)
{
literal
v
=
parse_value
(
attributes
.
at
(
"value"
));
return
prog
.
add_literal
(
v
);
});
add_op
(
"Add"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
if
(
contains
(
attributes
,
"broadcast"
))
{
uint64_t
broadcasted
=
parse_value
(
attributes
.
at
(
"broadcast"
)).
at
<
uint64_t
>
();
if
(
broadcasted
!=
0
)
{
uint64_t
axis
=
(
contains
(
attributes
,
"axis"
))
?
parse_value
(
attributes
.
at
(
"axis"
)).
at
<
uint64_t
>
()
:
0
;
auto
l
=
prog
.
add_instruction
(
broadcast
{
axis
},
args
);
return
prog
.
add_instruction
(
add
{},
args
[
0
],
l
);
}
}
return
prog
.
add_instruction
(
add
{},
args
);
});
add_op
(
"Sub"
,
[
this
](
attribute_map
,
std
::
vector
<
instruction_ref
>
args
)
{
return
prog
.
add_instruction
(
sub
{},
args
);
});
add_op
(
"Mul"
,
[
this
](
attribute_map
,
std
::
vector
<
instruction_ref
>
args
)
{
return
prog
.
add_instruction
(
mul
{},
args
);
});
add_op
(
"Div"
,
[
this
](
attribute_map
,
std
::
vector
<
instruction_ref
>
args
)
{
return
prog
.
add_instruction
(
div
{},
args
);
});
}
template
<
class
F
>
void
add_op
(
std
::
string
name
,
F
f
)
{
ops
.
emplace
(
name
,
f
);
}
void
parse_from
(
std
::
istream
&
is
)
{
onnx
::
ModelProto
model
;
if
(
model
.
ParseFromIstream
(
&
is
))
{
if
(
model
.
has_graph
())
{
this
->
parse_graph
(
model
.
graph
());
}
}
else
{
throw
std
::
runtime_error
(
"Failed reading"
);
}
}
void
parse_graph
(
const
onnx
::
GraphProto
&
graph
)
{
nodes
=
get_nodes
(
graph
);
for
(
auto
&&
input
:
graph
.
input
())
{
const
std
::
string
&
name
=
input
.
name
();
// TODO: Get shape of input parameter
shape
s
=
parse_type
(
input
.
type
());
instructions
[
name
]
=
prog
.
add_parameter
(
name
,
s
);
}
for
(
auto
&&
p
:
nodes
)
{
this
->
parse_node
(
p
.
second
.
name
());
}
}
void
parse_node
(
std
::
string
name
)
{
if
(
instructions
.
count
(
name
)
==
0
)
{
auto
&&
node
=
nodes
.
at
(
name
);
std
::
vector
<
instruction_ref
>
args
;
for
(
auto
&&
input
:
node
.
input
())
{
if
(
nodes
.
count
(
input
)
>
0
)
{
auto
&&
iname
=
nodes
.
at
(
input
).
name
();
this
->
parse_node
(
iname
);
args
.
push_back
(
instructions
.
at
(
iname
));
}
else
{
args
.
push_back
(
instructions
.
at
(
input
));
}
}
if
(
ops
.
count
(
node
.
op_type
())
==
0
)
{
instructions
[
name
]
=
prog
.
add_instruction
(
unknown
{
node
.
op_type
()},
args
);
}
else
{
instructions
[
name
]
=
ops
[
node
.
op_type
()](
get_attributes
(
node
),
args
);
}
}
}
static
attribute_map
get_attributes
(
const
onnx
::
NodeProto
&
node
)
{
std
::
unordered_map
<
std
::
string
,
onnx
::
AttributeProto
>
result
;
for
(
auto
&&
attr
:
node
.
attribute
())
{
result
[
attr
.
name
()]
=
attr
;
}
return
result
;
}
static
node_map
get_nodes
(
const
onnx
::
GraphProto
&
graph
)
{
std
::
unordered_map
<
std
::
string
,
onnx
::
NodeProto
>
result
;
for
(
auto
&&
node
:
graph
.
node
())
{
result
[
node
.
name
()]
=
node
;
for
(
auto
&&
output
:
node
.
output
())
{
result
[
output
]
=
node
;
}
}
return
result
;
}
template
<
class
T
>
static
literal
from_repeated
(
shape
::
type_t
t
,
const
T
&
r
)
{
std
::
size_t
size
=
r
.
size
();
return
literal
{{
t
,
{
size
}},
r
.
begin
(),
r
.
end
()};
}
static
literal
parse_value
(
const
onnx
::
AttributeProto
&
attr
)
{
switch
(
attr
.
type
())
{
case
onnx
::
AttributeProto
::
UNDEFINED
:
return
{};
case
onnx
::
AttributeProto
::
FLOAT
:
return
literal
{
attr
.
f
()};
case
onnx
::
AttributeProto
::
INT
:
return
literal
{
attr
.
i
()};
case
onnx
::
AttributeProto
::
STRING
:
return
{};
case
onnx
::
AttributeProto
::
TENSOR
:
return
parse_tensor
(
attr
.
t
());
case
onnx
::
AttributeProto
::
GRAPH
:
return
{};
case
onnx
::
AttributeProto
::
FLOATS
:
return
from_repeated
(
shape
::
float_type
,
attr
.
floats
());
case
onnx
::
AttributeProto
::
INTS
:
return
from_repeated
(
shape
::
int64_type
,
attr
.
ints
());
case
onnx
::
AttributeProto
::
STRINGS
:
return
{};
case
onnx
::
AttributeProto
::
TENSORS
:
return
{};
case
onnx
::
AttributeProto
::
GRAPHS
:
return
{};
}
RTG_THROW
(
"Invalid attribute type"
);
}
static
literal
parse_tensor
(
const
onnx
::
TensorProto
&
t
)
{
std
::
vector
<
std
::
size_t
>
dims
(
t
.
dims
().
begin
(),
t
.
dims
().
end
());
switch
(
t
.
data_type
())
{
case
onnx
::
TensorProto
::
UNDEFINED
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
FLOAT
:
return
literal
{
{
shape
::
float_type
,
dims
},
t
.
float_data
().
begin
(),
t
.
float_data
().
end
()};
case
onnx
::
TensorProto
::
UINT8
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
INT8
:
return
literal
{
{
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
UINT16
:
return
literal
{
{
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
INT16
:
return
literal
{
{
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
INT32
:
return
literal
{
{
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
INT64
:
return
literal
{
{
shape
::
int64_type
,
dims
},
t
.
int64_data
().
begin
(),
t
.
int64_data
().
end
()};
case
onnx
::
TensorProto
::
STRING
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
BOOL
:
return
literal
{
{
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
FLOAT16
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
DOUBLE
:
return
literal
{
{
shape
::
double_type
,
dims
},
t
.
double_data
().
begin
(),
t
.
double_data
().
end
()};
case
onnx
::
TensorProto
::
UINT32
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
UINT64
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
COMPLEX64
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
COMPLEX128
:
throw
std
::
runtime_error
(
""
);
}
RTG_THROW
(
"Invalid tensor type"
);
}
static
shape
parse_type
(
const
onnx
::
TypeProto
&
t
)
{
shape
::
type_t
shape_type
{};
switch
(
t
.
tensor_type
().
elem_type
())
{
case
onnx
::
TensorProto
::
UNDEFINED
:
break
;
// throw std::runtime_error("Unsupported type UNDEFINED");
case
onnx
::
TensorProto
::
FLOAT
:
shape_type
=
shape
::
float_type
;
break
;
case
onnx
::
TensorProto
::
UINT8
:
break
;
// throw std::runtime_error("Unsupported type UINT8");
case
onnx
::
TensorProto
::
INT8
:
shape_type
=
shape
::
int8_type
;
break
;
case
onnx
::
TensorProto
::
UINT16
:
shape_type
=
shape
::
uint16_type
;
break
;
case
onnx
::
TensorProto
::
INT16
:
shape_type
=
shape
::
int16_type
;
break
;
case
onnx
::
TensorProto
::
INT32
:
shape_type
=
shape
::
int32_type
;
break
;
case
onnx
::
TensorProto
::
INT64
:
shape_type
=
shape
::
int64_type
;
break
;
case
onnx
::
TensorProto
::
STRING
:
break
;
// throw std::runtime_error("Unsupported type STRING");
case
onnx
::
TensorProto
::
BOOL
:
break
;
// throw std::runtime_error("Unsupported type BOOL");
case
onnx
::
TensorProto
::
FLOAT16
:
break
;
// throw std::runtime_error("Unsupported type FLOAT16");
case
onnx
::
TensorProto
::
DOUBLE
:
shape_type
=
shape
::
double_type
;
break
;
case
onnx
::
TensorProto
::
UINT32
:
shape_type
=
shape
::
uint32_type
;
break
;
case
onnx
::
TensorProto
::
UINT64
:
shape_type
=
shape
::
uint64_type
;
break
;
case
onnx
::
TensorProto
::
COMPLEX64
:
break
;
// throw std::runtime_error("Unsupported type COMPLEX64");
case
onnx
::
TensorProto
::
COMPLEX128
:
break
;
// throw std::runtime_error("Unsupported type COMPLEX128");
}
std
::
vector
<
std
::
size_t
>
dims
;
// TODO: USe std::transform
for
(
auto
&&
d
:
t
.
tensor_type
().
shape
().
dim
())
{
dims
.
push_back
(
d
.
dim_value
());
}
return
{
shape_type
,
dims
};
}
};
program
parse_onnx
(
const
std
::
string
&
name
)
{
std
::
fstream
input
(
name
.
c_str
(),
std
::
ios
::
in
|
std
::
ios
::
binary
);
onnx_parser
parser
;
#ifndef NDEBUG
// Log the program when it can't be parsed
try
{
parser
.
parse_from
(
input
);
}
catch
(...)
{
std
::
cerr
<<
parser
.
prog
<<
std
::
endl
;
throw
;
}
#else
parser
.
parse_from
(
input
);
#endif
return
std
::
move
(
parser
.
prog
);
}
}
// namespace rtg
src/onnx/read_onnx.cpp
View file @
7ac2d6a5
#include <google/protobuf/text_format.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <onnx.pb.h>
#include <iostream>
#include <fstream>
#include <unordered_map>
#include <functional>
#include <array>
#include <rtg/fallthrough.hpp>
#include <rtg/program.hpp>
#include <rtg/operators.hpp>
#include <rtg/onnx.hpp>
#include <rtg/cpu/cpu_target.hpp>
#include <random>
struct
unknown
{
std
::
string
op
;
std
::
string
name
()
const
{
return
"unknown:"
+
op
;
}
rtg
::
shape
compute_shape
(
std
::
vector
<
rtg
::
shape
>
input
)
const
{
if
(
input
.
empty
())
return
{};
else
return
input
.
front
();
}
rtg
::
argument
compute
(
rtg
::
shape
,
std
::
vector
<
rtg
::
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
unknown
&
x
)
{
os
<<
x
.
name
();
return
os
;
}
};
template
<
class
C
,
class
T
>
bool
contains
(
C
&&
c
,
T
&&
x
)
{
return
c
.
find
(
x
)
!=
c
.
end
();
}
template
<
class
Range
,
class
Iterator
>
void
copy
(
Range
&&
r
,
Iterator
it
)
{
std
::
copy
(
r
.
begin
(),
r
.
end
(),
it
);
}
struct
onnx_parser
{
using
attribute_map
=
std
::
unordered_map
<
std
::
string
,
onnx
::
AttributeProto
>
;
using
node_map
=
std
::
unordered_map
<
std
::
string
,
onnx
::
NodeProto
>
;
using
op_func
=
std
::
function
<
rtg
::
instruction_ref
(
attribute_map
,
std
::
vector
<
rtg
::
instruction_ref
>
)
>
;
node_map
nodes
;
std
::
unordered_map
<
std
::
string
,
rtg
::
instruction_ref
>
instructions
;
rtg
::
program
prog
=
rtg
::
program
();
std
::
unordered_map
<
std
::
string
,
op_func
>
ops
;
onnx_parser
()
{
add_op
(
"Conv"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
rtg
::
instruction_ref
>
args
)
{
rtg
::
convolution
op
;
if
(
contains
(
attributes
,
"pads"
))
{
copy
(
attributes
[
"pads"
].
ints
(),
op
.
padding
.
begin
());
}
if
(
contains
(
attributes
,
"strides"
))
{
copy
(
attributes
[
"strides"
].
ints
(),
op
.
stride
.
begin
());
}
if
(
contains
(
attributes
,
"dilations"
))
{
copy
(
attributes
[
"dilations"
].
ints
(),
op
.
dilation
.
begin
());
}
return
prog
.
add_instruction
(
op
,
args
);
});
add_op
(
"MatMul"
,
[
this
](
attribute_map
,
std
::
vector
<
rtg
::
instruction_ref
>
args
)
{
return
prog
.
add_instruction
(
rtg
::
gemm
{},
args
);
});
add_op
(
"MaxPool"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
rtg
::
instruction_ref
>
args
)
{
rtg
::
pooling
op
{
"max"
};
// for(auto&& p:attributes) std::cout << p.first << std::endl;
if
(
contains
(
attributes
,
"pads"
))
{
copy
(
attributes
[
"pads"
].
ints
(),
op
.
padding
.
begin
());
}
if
(
contains
(
attributes
,
"strides"
))
{
copy
(
attributes
[
"strides"
].
ints
(),
op
.
stride
.
begin
());
}
if
(
contains
(
attributes
,
"kernel_shape"
))
{
copy
(
attributes
[
"kernel_shape"
].
ints
(),
op
.
lengths
.
begin
());
}
return
prog
.
add_instruction
(
op
,
args
);
});
add_op
(
"Relu"
,
[
this
](
attribute_map
,
std
::
vector
<
rtg
::
instruction_ref
>
args
)
{
return
prog
.
add_instruction
(
rtg
::
activation
{
"relu"
},
args
);
});
add_op
(
"Reshape"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
rtg
::
instruction_ref
>
args
)
{
rtg
::
reshape
op
;
rtg
::
literal
s
=
parse_value
(
attributes
.
at
(
"shape"
));
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
op
.
dims
));
});
return
prog
.
add_instruction
(
op
,
args
);
});
add_op
(
"Constant"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
rtg
::
instruction_ref
>
)
{
rtg
::
literal
v
=
parse_value
(
attributes
.
at
(
"value"
));
return
prog
.
add_literal
(
v
);
});
add_op
(
"Add"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
rtg
::
instruction_ref
>
args
)
{
if
(
contains
(
attributes
,
"broadcast"
))
{
uint64_t
broadcast
=
parse_value
(
attributes
.
at
(
"broadcast"
)).
at
<
uint64_t
>
();
if
(
broadcast
!=
0
)
{
uint64_t
axis
=
(
contains
(
attributes
,
"axis"
))
?
parse_value
(
attributes
.
at
(
"axis"
)).
at
<
uint64_t
>
()
:
0
;
auto
l
=
prog
.
add_instruction
(
rtg
::
broadcast
{
axis
},
args
);
return
prog
.
add_instruction
(
rtg
::
add
{},
args
[
0
],
l
);
}
}
return
prog
.
add_instruction
(
rtg
::
add
{},
args
);
});
add_op
(
"Sub"
,
[
this
](
attribute_map
,
std
::
vector
<
rtg
::
instruction_ref
>
args
)
{
return
prog
.
add_instruction
(
rtg
::
sub
{},
args
);
});
add_op
(
"Mul"
,
[
this
](
attribute_map
,
std
::
vector
<
rtg
::
instruction_ref
>
args
)
{
return
prog
.
add_instruction
(
rtg
::
mul
{},
args
);
});
add_op
(
"Div"
,
[
this
](
attribute_map
,
std
::
vector
<
rtg
::
instruction_ref
>
args
)
{
return
prog
.
add_instruction
(
rtg
::
div
{},
args
);
});
}
template
<
class
F
>
void
add_op
(
std
::
string
name
,
F
f
)
{
ops
.
emplace
(
name
,
f
);
}
void
parse_from
(
std
::
istream
&
is
)
{
onnx
::
ModelProto
model
;
if
(
model
.
ParseFromIstream
(
&
is
))
{
if
(
model
.
has_graph
())
{
this
->
parse_graph
(
model
.
graph
());
}
}
else
{
throw
std
::
runtime_error
(
"Failed reading"
);
}
}
void
parse_graph
(
const
onnx
::
GraphProto
&
graph
)
{
nodes
=
get_nodes
(
graph
);
for
(
auto
&&
input
:
graph
.
input
())
{
const
std
::
string
&
name
=
input
.
name
();
// TODO: Get shape of input parameter
rtg
::
shape
s
=
parse_type
(
input
.
type
());
instructions
[
name
]
=
prog
.
add_parameter
(
name
,
s
);
}
for
(
auto
&&
p
:
nodes
)
{
this
->
parse_node
(
p
.
second
.
name
());
}
}
void
parse_node
(
std
::
string
name
)
{
if
(
instructions
.
count
(
name
)
==
0
)
{
auto
&&
node
=
nodes
.
at
(
name
);
std
::
vector
<
rtg
::
instruction_ref
>
args
;
for
(
auto
&&
input
:
node
.
input
())
{
if
(
nodes
.
count
(
input
)
>
0
)
{
auto
&&
iname
=
nodes
.
at
(
input
).
name
();
this
->
parse_node
(
iname
);
args
.
push_back
(
instructions
.
at
(
iname
));
}
else
{
args
.
push_back
(
instructions
.
at
(
input
));
}
}
if
(
ops
.
count
(
node
.
op_type
())
==
0
)
{
instructions
[
name
]
=
prog
.
add_instruction
(
unknown
{
node
.
op_type
()},
args
);
}
else
{
instructions
[
name
]
=
ops
[
node
.
op_type
()](
get_attributes
(
node
),
args
);
}
}
}
static
attribute_map
get_attributes
(
const
onnx
::
NodeProto
&
node
)
{
std
::
unordered_map
<
std
::
string
,
onnx
::
AttributeProto
>
result
;
for
(
auto
&&
attr
:
node
.
attribute
())
{
result
[
attr
.
name
()]
=
attr
;
}
return
result
;
}
static
node_map
get_nodes
(
const
onnx
::
GraphProto
&
graph
)
{
std
::
unordered_map
<
std
::
string
,
onnx
::
NodeProto
>
result
;
for
(
auto
&&
node
:
graph
.
node
())
{
result
[
node
.
name
()]
=
node
;
for
(
auto
&&
output
:
node
.
output
())
{
result
[
output
]
=
node
;
}
}
return
result
;
}
template
<
class
T
>
static
rtg
::
literal
from_repeated
(
rtg
::
shape
::
type_t
t
,
const
T
&
r
)
{
std
::
size_t
size
=
r
.
size
();
return
rtg
::
literal
{{
t
,
{
size
}},
r
.
begin
(),
r
.
end
()};
}
static
rtg
::
literal
parse_value
(
const
onnx
::
AttributeProto
&
attr
)
{
switch
(
attr
.
type
())
{
case
onnx
::
AttributeProto
::
UNDEFINED
:
return
{};
case
onnx
::
AttributeProto
::
FLOAT
:
return
rtg
::
literal
{
attr
.
f
()};
case
onnx
::
AttributeProto
::
INT
:
return
rtg
::
literal
{
attr
.
i
()};
case
onnx
::
AttributeProto
::
STRING
:
return
{};
case
onnx
::
AttributeProto
::
TENSOR
:
return
parse_tensor
(
attr
.
t
());
case
onnx
::
AttributeProto
::
GRAPH
:
return
{};
case
onnx
::
AttributeProto
::
FLOATS
:
return
from_repeated
(
rtg
::
shape
::
float_type
,
attr
.
floats
());
case
onnx
::
AttributeProto
::
INTS
:
return
from_repeated
(
rtg
::
shape
::
int64_type
,
attr
.
ints
());
case
onnx
::
AttributeProto
::
STRINGS
:
return
{};
case
onnx
::
AttributeProto
::
TENSORS
:
return
{};
case
onnx
::
AttributeProto
::
GRAPHS
:
return
{};
}
RTG_THROW
(
"Invalid attribute type"
);
}
static
rtg
::
literal
parse_tensor
(
const
onnx
::
TensorProto
&
t
)
{
std
::
vector
<
std
::
size_t
>
dims
(
t
.
dims
().
begin
(),
t
.
dims
().
end
());
switch
(
t
.
data_type
())
{
case
onnx
::
TensorProto
::
UNDEFINED
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
FLOAT
:
return
rtg
::
literal
{
{
rtg
::
shape
::
float_type
,
dims
},
t
.
float_data
().
begin
(),
t
.
float_data
().
end
()};
case
onnx
::
TensorProto
::
UINT8
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
INT8
:
return
rtg
::
literal
{
{
rtg
::
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
UINT16
:
return
rtg
::
literal
{
{
rtg
::
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
INT16
:
return
rtg
::
literal
{
{
rtg
::
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
INT32
:
return
rtg
::
literal
{
{
rtg
::
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
INT64
:
return
rtg
::
literal
{
{
rtg
::
shape
::
int64_type
,
dims
},
t
.
int64_data
().
begin
(),
t
.
int64_data
().
end
()};
case
onnx
::
TensorProto
::
STRING
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
BOOL
:
return
rtg
::
literal
{
{
rtg
::
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
FLOAT16
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
DOUBLE
:
return
rtg
::
literal
{
{
rtg
::
shape
::
double_type
,
dims
},
t
.
double_data
().
begin
(),
t
.
double_data
().
end
()};
case
onnx
::
TensorProto
::
UINT32
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
UINT64
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
COMPLEX64
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
COMPLEX128
:
throw
std
::
runtime_error
(
""
);
}
RTG_THROW
(
"Invalid tensor type"
);
}
static
rtg
::
shape
parse_type
(
const
onnx
::
TypeProto
&
t
)
{
rtg
::
shape
::
type_t
shape_type
{};
switch
(
t
.
tensor_type
().
elem_type
())
{
case
onnx
::
TensorProto
::
UNDEFINED
:
break
;
// throw std::runtime_error("Unsupported type UNDEFINED");
case
onnx
::
TensorProto
::
FLOAT
:
shape_type
=
rtg
::
shape
::
float_type
;
break
;
case
onnx
::
TensorProto
::
UINT8
:
break
;
// throw std::runtime_error("Unsupported type UINT8");
case
onnx
::
TensorProto
::
INT8
:
shape_type
=
rtg
::
shape
::
int8_type
;
break
;
case
onnx
::
TensorProto
::
UINT16
:
shape_type
=
rtg
::
shape
::
uint16_type
;
break
;
case
onnx
::
TensorProto
::
INT16
:
shape_type
=
rtg
::
shape
::
int16_type
;
break
;
case
onnx
::
TensorProto
::
INT32
:
shape_type
=
rtg
::
shape
::
int32_type
;
break
;
case
onnx
::
TensorProto
::
INT64
:
shape_type
=
rtg
::
shape
::
int64_type
;
break
;
case
onnx
::
TensorProto
::
STRING
:
break
;
// throw std::runtime_error("Unsupported type STRING");
case
onnx
::
TensorProto
::
BOOL
:
break
;
// throw std::runtime_error("Unsupported type BOOL");
case
onnx
::
TensorProto
::
FLOAT16
:
break
;
// throw std::runtime_error("Unsupported type FLOAT16");
case
onnx
::
TensorProto
::
DOUBLE
:
shape_type
=
rtg
::
shape
::
double_type
;
break
;
case
onnx
::
TensorProto
::
UINT32
:
shape_type
=
rtg
::
shape
::
uint32_type
;
break
;
case
onnx
::
TensorProto
::
UINT64
:
shape_type
=
rtg
::
shape
::
uint64_type
;
break
;
case
onnx
::
TensorProto
::
COMPLEX64
:
break
;
// throw std::runtime_error("Unsupported type COMPLEX64");
case
onnx
::
TensorProto
::
COMPLEX128
:
break
;
// throw std::runtime_error("Unsupported type COMPLEX128");
}
std
::
vector
<
std
::
size_t
>
dims
;
// TODO: USe std::transform
for
(
auto
&&
d
:
t
.
tensor_type
().
shape
().
dim
())
{
dims
.
push_back
(
d
.
dim_value
());
}
return
{
shape_type
,
dims
};
}
};
// TODO: Move this to a seperate header
std
::
vector
<
float
>
get_tensor_data
(
rtg
::
shape
s
)
{
...
...
@@ -358,22 +25,12 @@ int main(int argc, char const* argv[])
if
(
argc
>
1
)
{
std
::
string
file
=
argv
[
1
];
std
::
fstream
input
(
file
.
c_str
(),
std
::
ios
::
in
|
std
::
ios
::
binary
);
onnx_parser
parser
;
try
{
parser
.
parse_from
(
input
);
parser
.
prog
.
compile
(
rtg
::
cpu
::
cpu_target
{});
auto
s
=
parser
.
prog
.
get_parameter_shape
(
"Input3"
);
auto
input3
=
get_tensor_argument
(
s
);
auto
out
=
parser
.
prog
.
eval
({{
"Input3"
,
input3
}});
(
void
)
out
;
}
catch
(...)
{
std
::
cout
<<
parser
.
prog
<<
std
::
endl
;
throw
;
}
std
::
cout
<<
parser
.
prog
<<
std
::
endl
;
auto
prog
=
rtg
::
parse_onnx
(
file
);
prog
.
compile
(
rtg
::
cpu
::
cpu_target
{});
auto
s
=
prog
.
get_parameter_shape
(
"Input3"
);
auto
input3
=
get_tensor_argument
(
s
);
auto
out
=
prog
.
eval
({{
"Input3"
,
input3
}});
(
void
)
out
;
std
::
cout
<<
prog
<<
std
::
endl
;
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment