Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
7ac2d6a5
Commit
7ac2d6a5
authored
Jun 21, 2018
by
Paul
Browse files
Move onnx parser to seperaate cpp file
parent
7e522c7e
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
384 additions
and
352 deletions
+384
-352
src/include/rtg/onnx.hpp
src/include/rtg/onnx.hpp
+12
-0
src/onnx/CMakeLists.txt
src/onnx/CMakeLists.txt
+5
-1
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+359
-0
src/onnx/read_onnx.cpp
src/onnx/read_onnx.cpp
+8
-351
No files found.
src/include/rtg/onnx.hpp
0 → 100644
View file @
7ac2d6a5
#ifndef GUARD_RTGLIB_ONNX_HPP
#define GUARD_RTGLIB_ONNX_HPP
#include <rtg/program.hpp>
namespace
rtg
{
program
parse_onnx
(
const
std
::
string
&
name
);
}
// namespace rtg
#endif
src/onnx/CMakeLists.txt
View file @
7ac2d6a5
...
@@ -6,6 +6,10 @@ target_include_directories(onnx-proto SYSTEM PUBLIC ${CMAKE_CURRENT_BINARY_DIR}
...
@@ -6,6 +6,10 @@ target_include_directories(onnx-proto SYSTEM PUBLIC ${CMAKE_CURRENT_BINARY_DIR}
target_compile_options
(
onnx-proto PRIVATE -w
)
target_compile_options
(
onnx-proto PRIVATE -w
)
target_link_libraries
(
onnx-proto PRIVATE
${
PROTOBUF_LIBRARY
}
)
target_link_libraries
(
onnx-proto PRIVATE
${
PROTOBUF_LIBRARY
}
)
add_library
(
rtg_onnx onnx.cpp
)
rocm_clang_tidy_check
(
rtg_onnx
)
target_link_libraries
(
rtg_onnx onnx-proto rtg
)
add_executable
(
read_onnx read_onnx.cpp
)
add_executable
(
read_onnx read_onnx.cpp
)
rocm_clang_tidy_check
(
read_onnx
)
rocm_clang_tidy_check
(
read_onnx
)
target_link_libraries
(
read_onnx onnx
-proto rtg
rtg_cpu
)
target_link_libraries
(
read_onnx
rtg_
onnx rtg_cpu
)
src/onnx/onnx.cpp
0 → 100644
View file @
7ac2d6a5
#include <google/protobuf/text_format.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <onnx.pb.h>
#include <iostream>
#include <fstream>
#include <unordered_map>
#include <functional>
#include <array>
#include <rtg/fallthrough.hpp>
#include <rtg/program.hpp>
#include <rtg/operators.hpp>
namespace
rtg
{
struct
unknown
{
std
::
string
op
;
std
::
string
name
()
const
{
return
"unknown:"
+
op
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
input
)
const
{
if
(
input
.
empty
())
return
{};
else
return
input
.
front
();
}
argument
compute
(
shape
,
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
unknown
&
x
)
{
os
<<
x
.
name
();
return
os
;
}
};
template
<
class
C
,
class
T
>
bool
contains
(
C
&&
c
,
T
&&
x
)
{
return
c
.
find
(
x
)
!=
c
.
end
();
}
template
<
class
Range
,
class
Iterator
>
void
copy
(
Range
&&
r
,
Iterator
it
)
{
std
::
copy
(
r
.
begin
(),
r
.
end
(),
it
);
}
struct
onnx_parser
{
using
attribute_map
=
std
::
unordered_map
<
std
::
string
,
onnx
::
AttributeProto
>
;
using
node_map
=
std
::
unordered_map
<
std
::
string
,
onnx
::
NodeProto
>
;
using
op_func
=
std
::
function
<
instruction_ref
(
attribute_map
,
std
::
vector
<
instruction_ref
>
)
>
;
node_map
nodes
;
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
instructions
;
program
prog
=
program
();
std
::
unordered_map
<
std
::
string
,
op_func
>
ops
;
onnx_parser
()
{
add_op
(
"Conv"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
convolution
op
;
if
(
contains
(
attributes
,
"pads"
))
{
copy
(
attributes
[
"pads"
].
ints
(),
op
.
padding
.
begin
());
}
if
(
contains
(
attributes
,
"strides"
))
{
copy
(
attributes
[
"strides"
].
ints
(),
op
.
stride
.
begin
());
}
if
(
contains
(
attributes
,
"dilations"
))
{
copy
(
attributes
[
"dilations"
].
ints
(),
op
.
dilation
.
begin
());
}
return
prog
.
add_instruction
(
op
,
args
);
});
add_op
(
"MatMul"
,
[
this
](
attribute_map
,
std
::
vector
<
instruction_ref
>
args
)
{
return
prog
.
add_instruction
(
gemm
{},
args
);
});
add_op
(
"MaxPool"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
pooling
op
{
"max"
};
// for(auto&& p:attributes) std::cout << p.first << std::endl;
if
(
contains
(
attributes
,
"pads"
))
{
copy
(
attributes
[
"pads"
].
ints
(),
op
.
padding
.
begin
());
}
if
(
contains
(
attributes
,
"strides"
))
{
copy
(
attributes
[
"strides"
].
ints
(),
op
.
stride
.
begin
());
}
if
(
contains
(
attributes
,
"kernel_shape"
))
{
copy
(
attributes
[
"kernel_shape"
].
ints
(),
op
.
lengths
.
begin
());
}
return
prog
.
add_instruction
(
op
,
args
);
});
add_op
(
"Relu"
,
[
this
](
attribute_map
,
std
::
vector
<
instruction_ref
>
args
)
{
return
prog
.
add_instruction
(
activation
{
"relu"
},
args
);
});
add_op
(
"Reshape"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
reshape
op
;
literal
s
=
parse_value
(
attributes
.
at
(
"shape"
));
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
op
.
dims
));
});
return
prog
.
add_instruction
(
op
,
args
);
});
add_op
(
"Constant"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
)
{
literal
v
=
parse_value
(
attributes
.
at
(
"value"
));
return
prog
.
add_literal
(
v
);
});
add_op
(
"Add"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
if
(
contains
(
attributes
,
"broadcast"
))
{
uint64_t
broadcasted
=
parse_value
(
attributes
.
at
(
"broadcast"
)).
at
<
uint64_t
>
();
if
(
broadcasted
!=
0
)
{
uint64_t
axis
=
(
contains
(
attributes
,
"axis"
))
?
parse_value
(
attributes
.
at
(
"axis"
)).
at
<
uint64_t
>
()
:
0
;
auto
l
=
prog
.
add_instruction
(
broadcast
{
axis
},
args
);
return
prog
.
add_instruction
(
add
{},
args
[
0
],
l
);
}
}
return
prog
.
add_instruction
(
add
{},
args
);
});
add_op
(
"Sub"
,
[
this
](
attribute_map
,
std
::
vector
<
instruction_ref
>
args
)
{
return
prog
.
add_instruction
(
sub
{},
args
);
});
add_op
(
"Mul"
,
[
this
](
attribute_map
,
std
::
vector
<
instruction_ref
>
args
)
{
return
prog
.
add_instruction
(
mul
{},
args
);
});
add_op
(
"Div"
,
[
this
](
attribute_map
,
std
::
vector
<
instruction_ref
>
args
)
{
return
prog
.
add_instruction
(
div
{},
args
);
});
}
template
<
class
F
>
void
add_op
(
std
::
string
name
,
F
f
)
{
ops
.
emplace
(
name
,
f
);
}
void
parse_from
(
std
::
istream
&
is
)
{
onnx
::
ModelProto
model
;
if
(
model
.
ParseFromIstream
(
&
is
))
{
if
(
model
.
has_graph
())
{
this
->
parse_graph
(
model
.
graph
());
}
}
else
{
throw
std
::
runtime_error
(
"Failed reading"
);
}
}
void
parse_graph
(
const
onnx
::
GraphProto
&
graph
)
{
nodes
=
get_nodes
(
graph
);
for
(
auto
&&
input
:
graph
.
input
())
{
const
std
::
string
&
name
=
input
.
name
();
// TODO: Get shape of input parameter
shape
s
=
parse_type
(
input
.
type
());
instructions
[
name
]
=
prog
.
add_parameter
(
name
,
s
);
}
for
(
auto
&&
p
:
nodes
)
{
this
->
parse_node
(
p
.
second
.
name
());
}
}
void
parse_node
(
std
::
string
name
)
{
if
(
instructions
.
count
(
name
)
==
0
)
{
auto
&&
node
=
nodes
.
at
(
name
);
std
::
vector
<
instruction_ref
>
args
;
for
(
auto
&&
input
:
node
.
input
())
{
if
(
nodes
.
count
(
input
)
>
0
)
{
auto
&&
iname
=
nodes
.
at
(
input
).
name
();
this
->
parse_node
(
iname
);
args
.
push_back
(
instructions
.
at
(
iname
));
}
else
{
args
.
push_back
(
instructions
.
at
(
input
));
}
}
if
(
ops
.
count
(
node
.
op_type
())
==
0
)
{
instructions
[
name
]
=
prog
.
add_instruction
(
unknown
{
node
.
op_type
()},
args
);
}
else
{
instructions
[
name
]
=
ops
[
node
.
op_type
()](
get_attributes
(
node
),
args
);
}
}
}
static
attribute_map
get_attributes
(
const
onnx
::
NodeProto
&
node
)
{
std
::
unordered_map
<
std
::
string
,
onnx
::
AttributeProto
>
result
;
for
(
auto
&&
attr
:
node
.
attribute
())
{
result
[
attr
.
name
()]
=
attr
;
}
return
result
;
}
static
node_map
get_nodes
(
const
onnx
::
GraphProto
&
graph
)
{
std
::
unordered_map
<
std
::
string
,
onnx
::
NodeProto
>
result
;
for
(
auto
&&
node
:
graph
.
node
())
{
result
[
node
.
name
()]
=
node
;
for
(
auto
&&
output
:
node
.
output
())
{
result
[
output
]
=
node
;
}
}
return
result
;
}
template
<
class
T
>
static
literal
from_repeated
(
shape
::
type_t
t
,
const
T
&
r
)
{
std
::
size_t
size
=
r
.
size
();
return
literal
{{
t
,
{
size
}},
r
.
begin
(),
r
.
end
()};
}
static
literal
parse_value
(
const
onnx
::
AttributeProto
&
attr
)
{
switch
(
attr
.
type
())
{
case
onnx
::
AttributeProto
::
UNDEFINED
:
return
{};
case
onnx
::
AttributeProto
::
FLOAT
:
return
literal
{
attr
.
f
()};
case
onnx
::
AttributeProto
::
INT
:
return
literal
{
attr
.
i
()};
case
onnx
::
AttributeProto
::
STRING
:
return
{};
case
onnx
::
AttributeProto
::
TENSOR
:
return
parse_tensor
(
attr
.
t
());
case
onnx
::
AttributeProto
::
GRAPH
:
return
{};
case
onnx
::
AttributeProto
::
FLOATS
:
return
from_repeated
(
shape
::
float_type
,
attr
.
floats
());
case
onnx
::
AttributeProto
::
INTS
:
return
from_repeated
(
shape
::
int64_type
,
attr
.
ints
());
case
onnx
::
AttributeProto
::
STRINGS
:
return
{};
case
onnx
::
AttributeProto
::
TENSORS
:
return
{};
case
onnx
::
AttributeProto
::
GRAPHS
:
return
{};
}
RTG_THROW
(
"Invalid attribute type"
);
}
static
literal
parse_tensor
(
const
onnx
::
TensorProto
&
t
)
{
std
::
vector
<
std
::
size_t
>
dims
(
t
.
dims
().
begin
(),
t
.
dims
().
end
());
switch
(
t
.
data_type
())
{
case
onnx
::
TensorProto
::
UNDEFINED
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
FLOAT
:
return
literal
{
{
shape
::
float_type
,
dims
},
t
.
float_data
().
begin
(),
t
.
float_data
().
end
()};
case
onnx
::
TensorProto
::
UINT8
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
INT8
:
return
literal
{
{
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
UINT16
:
return
literal
{
{
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
INT16
:
return
literal
{
{
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
INT32
:
return
literal
{
{
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
INT64
:
return
literal
{
{
shape
::
int64_type
,
dims
},
t
.
int64_data
().
begin
(),
t
.
int64_data
().
end
()};
case
onnx
::
TensorProto
::
STRING
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
BOOL
:
return
literal
{
{
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
FLOAT16
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
DOUBLE
:
return
literal
{
{
shape
::
double_type
,
dims
},
t
.
double_data
().
begin
(),
t
.
double_data
().
end
()};
case
onnx
::
TensorProto
::
UINT32
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
UINT64
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
COMPLEX64
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
COMPLEX128
:
throw
std
::
runtime_error
(
""
);
}
RTG_THROW
(
"Invalid tensor type"
);
}
static
shape
parse_type
(
const
onnx
::
TypeProto
&
t
)
{
shape
::
type_t
shape_type
{};
switch
(
t
.
tensor_type
().
elem_type
())
{
case
onnx
::
TensorProto
::
UNDEFINED
:
break
;
// throw std::runtime_error("Unsupported type UNDEFINED");
case
onnx
::
TensorProto
::
FLOAT
:
shape_type
=
shape
::
float_type
;
break
;
case
onnx
::
TensorProto
::
UINT8
:
break
;
// throw std::runtime_error("Unsupported type UINT8");
case
onnx
::
TensorProto
::
INT8
:
shape_type
=
shape
::
int8_type
;
break
;
case
onnx
::
TensorProto
::
UINT16
:
shape_type
=
shape
::
uint16_type
;
break
;
case
onnx
::
TensorProto
::
INT16
:
shape_type
=
shape
::
int16_type
;
break
;
case
onnx
::
TensorProto
::
INT32
:
shape_type
=
shape
::
int32_type
;
break
;
case
onnx
::
TensorProto
::
INT64
:
shape_type
=
shape
::
int64_type
;
break
;
case
onnx
::
TensorProto
::
STRING
:
break
;
// throw std::runtime_error("Unsupported type STRING");
case
onnx
::
TensorProto
::
BOOL
:
break
;
// throw std::runtime_error("Unsupported type BOOL");
case
onnx
::
TensorProto
::
FLOAT16
:
break
;
// throw std::runtime_error("Unsupported type FLOAT16");
case
onnx
::
TensorProto
::
DOUBLE
:
shape_type
=
shape
::
double_type
;
break
;
case
onnx
::
TensorProto
::
UINT32
:
shape_type
=
shape
::
uint32_type
;
break
;
case
onnx
::
TensorProto
::
UINT64
:
shape_type
=
shape
::
uint64_type
;
break
;
case
onnx
::
TensorProto
::
COMPLEX64
:
break
;
// throw std::runtime_error("Unsupported type COMPLEX64");
case
onnx
::
TensorProto
::
COMPLEX128
:
break
;
// throw std::runtime_error("Unsupported type COMPLEX128");
}
std
::
vector
<
std
::
size_t
>
dims
;
// TODO: USe std::transform
for
(
auto
&&
d
:
t
.
tensor_type
().
shape
().
dim
())
{
dims
.
push_back
(
d
.
dim_value
());
}
return
{
shape_type
,
dims
};
}
};
program
parse_onnx
(
const
std
::
string
&
name
)
{
std
::
fstream
input
(
name
.
c_str
(),
std
::
ios
::
in
|
std
::
ios
::
binary
);
onnx_parser
parser
;
#ifndef NDEBUG
// Log the program when it can't be parsed
try
{
parser
.
parse_from
(
input
);
}
catch
(...)
{
std
::
cerr
<<
parser
.
prog
<<
std
::
endl
;
throw
;
}
#else
parser
.
parse_from
(
input
);
#endif
return
std
::
move
(
parser
.
prog
);
}
}
// namespace rtg
src/onnx/read_onnx.cpp
View file @
7ac2d6a5
#include <google/protobuf/text_format.h>
#include <rtg/onnx.hpp>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <onnx.pb.h>
#include <iostream>
#include <fstream>
#include <unordered_map>
#include <functional>
#include <array>
#include <rtg/fallthrough.hpp>
#include <rtg/program.hpp>
#include <rtg/operators.hpp>
#include <rtg/cpu/cpu_target.hpp>
#include <rtg/cpu/cpu_target.hpp>
#include <random>
#include <random>
struct
unknown
{
std
::
string
op
;
std
::
string
name
()
const
{
return
"unknown:"
+
op
;
}
rtg
::
shape
compute_shape
(
std
::
vector
<
rtg
::
shape
>
input
)
const
{
if
(
input
.
empty
())
return
{};
else
return
input
.
front
();
}
rtg
::
argument
compute
(
rtg
::
shape
,
std
::
vector
<
rtg
::
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
unknown
&
x
)
{
os
<<
x
.
name
();
return
os
;
}
};
template
<
class
C
,
class
T
>
bool
contains
(
C
&&
c
,
T
&&
x
)
{
return
c
.
find
(
x
)
!=
c
.
end
();
}
template
<
class
Range
,
class
Iterator
>
void
copy
(
Range
&&
r
,
Iterator
it
)
{
std
::
copy
(
r
.
begin
(),
r
.
end
(),
it
);
}
struct
onnx_parser
{
using
attribute_map
=
std
::
unordered_map
<
std
::
string
,
onnx
::
AttributeProto
>
;
using
node_map
=
std
::
unordered_map
<
std
::
string
,
onnx
::
NodeProto
>
;
using
op_func
=
std
::
function
<
rtg
::
instruction_ref
(
attribute_map
,
std
::
vector
<
rtg
::
instruction_ref
>
)
>
;
node_map
nodes
;
std
::
unordered_map
<
std
::
string
,
rtg
::
instruction_ref
>
instructions
;
rtg
::
program
prog
=
rtg
::
program
();
std
::
unordered_map
<
std
::
string
,
op_func
>
ops
;
onnx_parser
()
{
add_op
(
"Conv"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
rtg
::
instruction_ref
>
args
)
{
rtg
::
convolution
op
;
if
(
contains
(
attributes
,
"pads"
))
{
copy
(
attributes
[
"pads"
].
ints
(),
op
.
padding
.
begin
());
}
if
(
contains
(
attributes
,
"strides"
))
{
copy
(
attributes
[
"strides"
].
ints
(),
op
.
stride
.
begin
());
}
if
(
contains
(
attributes
,
"dilations"
))
{
copy
(
attributes
[
"dilations"
].
ints
(),
op
.
dilation
.
begin
());
}
return
prog
.
add_instruction
(
op
,
args
);
});
add_op
(
"MatMul"
,
[
this
](
attribute_map
,
std
::
vector
<
rtg
::
instruction_ref
>
args
)
{
return
prog
.
add_instruction
(
rtg
::
gemm
{},
args
);
});
add_op
(
"MaxPool"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
rtg
::
instruction_ref
>
args
)
{
rtg
::
pooling
op
{
"max"
};
// for(auto&& p:attributes) std::cout << p.first << std::endl;
if
(
contains
(
attributes
,
"pads"
))
{
copy
(
attributes
[
"pads"
].
ints
(),
op
.
padding
.
begin
());
}
if
(
contains
(
attributes
,
"strides"
))
{
copy
(
attributes
[
"strides"
].
ints
(),
op
.
stride
.
begin
());
}
if
(
contains
(
attributes
,
"kernel_shape"
))
{
copy
(
attributes
[
"kernel_shape"
].
ints
(),
op
.
lengths
.
begin
());
}
return
prog
.
add_instruction
(
op
,
args
);
});
add_op
(
"Relu"
,
[
this
](
attribute_map
,
std
::
vector
<
rtg
::
instruction_ref
>
args
)
{
return
prog
.
add_instruction
(
rtg
::
activation
{
"relu"
},
args
);
});
add_op
(
"Reshape"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
rtg
::
instruction_ref
>
args
)
{
rtg
::
reshape
op
;
rtg
::
literal
s
=
parse_value
(
attributes
.
at
(
"shape"
));
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
op
.
dims
));
});
return
prog
.
add_instruction
(
op
,
args
);
});
add_op
(
"Constant"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
rtg
::
instruction_ref
>
)
{
rtg
::
literal
v
=
parse_value
(
attributes
.
at
(
"value"
));
return
prog
.
add_literal
(
v
);
});
add_op
(
"Add"
,
[
this
](
attribute_map
attributes
,
std
::
vector
<
rtg
::
instruction_ref
>
args
)
{
if
(
contains
(
attributes
,
"broadcast"
))
{
uint64_t
broadcast
=
parse_value
(
attributes
.
at
(
"broadcast"
)).
at
<
uint64_t
>
();
if
(
broadcast
!=
0
)
{
uint64_t
axis
=
(
contains
(
attributes
,
"axis"
))
?
parse_value
(
attributes
.
at
(
"axis"
)).
at
<
uint64_t
>
()
:
0
;
auto
l
=
prog
.
add_instruction
(
rtg
::
broadcast
{
axis
},
args
);
return
prog
.
add_instruction
(
rtg
::
add
{},
args
[
0
],
l
);
}
}
return
prog
.
add_instruction
(
rtg
::
add
{},
args
);
});
add_op
(
"Sub"
,
[
this
](
attribute_map
,
std
::
vector
<
rtg
::
instruction_ref
>
args
)
{
return
prog
.
add_instruction
(
rtg
::
sub
{},
args
);
});
add_op
(
"Mul"
,
[
this
](
attribute_map
,
std
::
vector
<
rtg
::
instruction_ref
>
args
)
{
return
prog
.
add_instruction
(
rtg
::
mul
{},
args
);
});
add_op
(
"Div"
,
[
this
](
attribute_map
,
std
::
vector
<
rtg
::
instruction_ref
>
args
)
{
return
prog
.
add_instruction
(
rtg
::
div
{},
args
);
});
}
template
<
class
F
>
void
add_op
(
std
::
string
name
,
F
f
)
{
ops
.
emplace
(
name
,
f
);
}
void
parse_from
(
std
::
istream
&
is
)
{
onnx
::
ModelProto
model
;
if
(
model
.
ParseFromIstream
(
&
is
))
{
if
(
model
.
has_graph
())
{
this
->
parse_graph
(
model
.
graph
());
}
}
else
{
throw
std
::
runtime_error
(
"Failed reading"
);
}
}
void
parse_graph
(
const
onnx
::
GraphProto
&
graph
)
{
nodes
=
get_nodes
(
graph
);
for
(
auto
&&
input
:
graph
.
input
())
{
const
std
::
string
&
name
=
input
.
name
();
// TODO: Get shape of input parameter
rtg
::
shape
s
=
parse_type
(
input
.
type
());
instructions
[
name
]
=
prog
.
add_parameter
(
name
,
s
);
}
for
(
auto
&&
p
:
nodes
)
{
this
->
parse_node
(
p
.
second
.
name
());
}
}
void
parse_node
(
std
::
string
name
)
{
if
(
instructions
.
count
(
name
)
==
0
)
{
auto
&&
node
=
nodes
.
at
(
name
);
std
::
vector
<
rtg
::
instruction_ref
>
args
;
for
(
auto
&&
input
:
node
.
input
())
{
if
(
nodes
.
count
(
input
)
>
0
)
{
auto
&&
iname
=
nodes
.
at
(
input
).
name
();
this
->
parse_node
(
iname
);
args
.
push_back
(
instructions
.
at
(
iname
));
}
else
{
args
.
push_back
(
instructions
.
at
(
input
));
}
}
if
(
ops
.
count
(
node
.
op_type
())
==
0
)
{
instructions
[
name
]
=
prog
.
add_instruction
(
unknown
{
node
.
op_type
()},
args
);
}
else
{
instructions
[
name
]
=
ops
[
node
.
op_type
()](
get_attributes
(
node
),
args
);
}
}
}
static
attribute_map
get_attributes
(
const
onnx
::
NodeProto
&
node
)
{
std
::
unordered_map
<
std
::
string
,
onnx
::
AttributeProto
>
result
;
for
(
auto
&&
attr
:
node
.
attribute
())
{
result
[
attr
.
name
()]
=
attr
;
}
return
result
;
}
static
node_map
get_nodes
(
const
onnx
::
GraphProto
&
graph
)
{
std
::
unordered_map
<
std
::
string
,
onnx
::
NodeProto
>
result
;
for
(
auto
&&
node
:
graph
.
node
())
{
result
[
node
.
name
()]
=
node
;
for
(
auto
&&
output
:
node
.
output
())
{
result
[
output
]
=
node
;
}
}
return
result
;
}
template
<
class
T
>
static
rtg
::
literal
from_repeated
(
rtg
::
shape
::
type_t
t
,
const
T
&
r
)
{
std
::
size_t
size
=
r
.
size
();
return
rtg
::
literal
{{
t
,
{
size
}},
r
.
begin
(),
r
.
end
()};
}
static
rtg
::
literal
parse_value
(
const
onnx
::
AttributeProto
&
attr
)
{
switch
(
attr
.
type
())
{
case
onnx
::
AttributeProto
::
UNDEFINED
:
return
{};
case
onnx
::
AttributeProto
::
FLOAT
:
return
rtg
::
literal
{
attr
.
f
()};
case
onnx
::
AttributeProto
::
INT
:
return
rtg
::
literal
{
attr
.
i
()};
case
onnx
::
AttributeProto
::
STRING
:
return
{};
case
onnx
::
AttributeProto
::
TENSOR
:
return
parse_tensor
(
attr
.
t
());
case
onnx
::
AttributeProto
::
GRAPH
:
return
{};
case
onnx
::
AttributeProto
::
FLOATS
:
return
from_repeated
(
rtg
::
shape
::
float_type
,
attr
.
floats
());
case
onnx
::
AttributeProto
::
INTS
:
return
from_repeated
(
rtg
::
shape
::
int64_type
,
attr
.
ints
());
case
onnx
::
AttributeProto
::
STRINGS
:
return
{};
case
onnx
::
AttributeProto
::
TENSORS
:
return
{};
case
onnx
::
AttributeProto
::
GRAPHS
:
return
{};
}
RTG_THROW
(
"Invalid attribute type"
);
}
static
rtg
::
literal
parse_tensor
(
const
onnx
::
TensorProto
&
t
)
{
std
::
vector
<
std
::
size_t
>
dims
(
t
.
dims
().
begin
(),
t
.
dims
().
end
());
switch
(
t
.
data_type
())
{
case
onnx
::
TensorProto
::
UNDEFINED
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
FLOAT
:
return
rtg
::
literal
{
{
rtg
::
shape
::
float_type
,
dims
},
t
.
float_data
().
begin
(),
t
.
float_data
().
end
()};
case
onnx
::
TensorProto
::
UINT8
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
INT8
:
return
rtg
::
literal
{
{
rtg
::
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
UINT16
:
return
rtg
::
literal
{
{
rtg
::
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
INT16
:
return
rtg
::
literal
{
{
rtg
::
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
INT32
:
return
rtg
::
literal
{
{
rtg
::
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
INT64
:
return
rtg
::
literal
{
{
rtg
::
shape
::
int64_type
,
dims
},
t
.
int64_data
().
begin
(),
t
.
int64_data
().
end
()};
case
onnx
::
TensorProto
::
STRING
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
BOOL
:
return
rtg
::
literal
{
{
rtg
::
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
FLOAT16
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
DOUBLE
:
return
rtg
::
literal
{
{
rtg
::
shape
::
double_type
,
dims
},
t
.
double_data
().
begin
(),
t
.
double_data
().
end
()};
case
onnx
::
TensorProto
::
UINT32
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
UINT64
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
COMPLEX64
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
COMPLEX128
:
throw
std
::
runtime_error
(
""
);
}
RTG_THROW
(
"Invalid tensor type"
);
}
static
rtg
::
shape
parse_type
(
const
onnx
::
TypeProto
&
t
)
{
rtg
::
shape
::
type_t
shape_type
{};
switch
(
t
.
tensor_type
().
elem_type
())
{
case
onnx
::
TensorProto
::
UNDEFINED
:
break
;
// throw std::runtime_error("Unsupported type UNDEFINED");
case
onnx
::
TensorProto
::
FLOAT
:
shape_type
=
rtg
::
shape
::
float_type
;
break
;
case
onnx
::
TensorProto
::
UINT8
:
break
;
// throw std::runtime_error("Unsupported type UINT8");
case
onnx
::
TensorProto
::
INT8
:
shape_type
=
rtg
::
shape
::
int8_type
;
break
;
case
onnx
::
TensorProto
::
UINT16
:
shape_type
=
rtg
::
shape
::
uint16_type
;
break
;
case
onnx
::
TensorProto
::
INT16
:
shape_type
=
rtg
::
shape
::
int16_type
;
break
;
case
onnx
::
TensorProto
::
INT32
:
shape_type
=
rtg
::
shape
::
int32_type
;
break
;
case
onnx
::
TensorProto
::
INT64
:
shape_type
=
rtg
::
shape
::
int64_type
;
break
;
case
onnx
::
TensorProto
::
STRING
:
break
;
// throw std::runtime_error("Unsupported type STRING");
case
onnx
::
TensorProto
::
BOOL
:
break
;
// throw std::runtime_error("Unsupported type BOOL");
case
onnx
::
TensorProto
::
FLOAT16
:
break
;
// throw std::runtime_error("Unsupported type FLOAT16");
case
onnx
::
TensorProto
::
DOUBLE
:
shape_type
=
rtg
::
shape
::
double_type
;
break
;
case
onnx
::
TensorProto
::
UINT32
:
shape_type
=
rtg
::
shape
::
uint32_type
;
break
;
case
onnx
::
TensorProto
::
UINT64
:
shape_type
=
rtg
::
shape
::
uint64_type
;
break
;
case
onnx
::
TensorProto
::
COMPLEX64
:
break
;
// throw std::runtime_error("Unsupported type COMPLEX64");
case
onnx
::
TensorProto
::
COMPLEX128
:
break
;
// throw std::runtime_error("Unsupported type COMPLEX128");
}
std
::
vector
<
std
::
size_t
>
dims
;
// TODO: USe std::transform
for
(
auto
&&
d
:
t
.
tensor_type
().
shape
().
dim
())
{
dims
.
push_back
(
d
.
dim_value
());
}
return
{
shape_type
,
dims
};
}
};
// TODO: Move this to a seperate header
// TODO: Move this to a seperate header
std
::
vector
<
float
>
get_tensor_data
(
rtg
::
shape
s
)
std
::
vector
<
float
>
get_tensor_data
(
rtg
::
shape
s
)
{
{
...
@@ -358,22 +25,12 @@ int main(int argc, char const* argv[])
...
@@ -358,22 +25,12 @@ int main(int argc, char const* argv[])
if
(
argc
>
1
)
if
(
argc
>
1
)
{
{
std
::
string
file
=
argv
[
1
];
std
::
string
file
=
argv
[
1
];
std
::
fstream
input
(
file
.
c_str
(),
std
::
ios
::
in
|
std
::
ios
::
binary
);
auto
prog
=
rtg
::
parse_onnx
(
file
);
onnx_parser
parser
;
prog
.
compile
(
rtg
::
cpu
::
cpu_target
{});
try
auto
s
=
prog
.
get_parameter_shape
(
"Input3"
);
{
auto
input3
=
get_tensor_argument
(
s
);
parser
.
parse_from
(
input
);
auto
out
=
prog
.
eval
({{
"Input3"
,
input3
}});
parser
.
prog
.
compile
(
rtg
::
cpu
::
cpu_target
{});
(
void
)
out
;
auto
s
=
parser
.
prog
.
get_parameter_shape
(
"Input3"
);
std
::
cout
<<
prog
<<
std
::
endl
;
auto
input3
=
get_tensor_argument
(
s
);
auto
out
=
parser
.
prog
.
eval
({{
"Input3"
,
input3
}});
(
void
)
out
;
}
catch
(...)
{
std
::
cout
<<
parser
.
prog
<<
std
::
endl
;
throw
;
}
std
::
cout
<<
parser
.
prog
<<
std
::
endl
;
}
}
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment