Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
c6978f5d
Commit
c6978f5d
authored
Jun 21, 2018
by
Paul
Browse files
Add extra checks in onnx parser
parent
2265e0d8
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
25 additions
and
11 deletions
+25
-11
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+25
-11
No files found.
src/onnx/onnx.cpp
View file @
c6978f5d
...
@@ -24,7 +24,10 @@ struct unknown
...
@@ -24,7 +24,10 @@ struct unknown
else
else
return
input
.
front
();
return
input
.
front
();
}
}
argument
compute
(
shape
,
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
argument
compute
(
shape
,
std
::
vector
<
argument
>
)
const
{
RTG_THROW
(
"not computable"
);
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
unknown
&
x
)
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
unknown
&
x
)
{
{
os
<<
x
.
name
();
os
<<
x
.
name
();
...
@@ -48,7 +51,8 @@ struct onnx_parser
...
@@ -48,7 +51,8 @@ struct onnx_parser
{
{
using
attribute_map
=
std
::
unordered_map
<
std
::
string
,
onnx
::
AttributeProto
>
;
using
attribute_map
=
std
::
unordered_map
<
std
::
string
,
onnx
::
AttributeProto
>
;
using
node_map
=
std
::
unordered_map
<
std
::
string
,
onnx
::
NodeProto
>
;
using
node_map
=
std
::
unordered_map
<
std
::
string
,
onnx
::
NodeProto
>
;
using
op_func
=
std
::
function
<
instruction_ref
(
attribute_map
,
std
::
vector
<
instruction_ref
>
)
>
;
using
op_func
=
std
::
function
<
instruction_ref
(
attribute_map
,
std
::
vector
<
instruction_ref
>
)
>
;
node_map
nodes
;
node_map
nodes
;
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
instructions
;
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
instructions
;
program
prog
=
program
();
program
prog
=
program
();
...
@@ -161,7 +165,7 @@ struct onnx_parser
...
@@ -161,7 +165,7 @@ struct onnx_parser
{
{
const
std
::
string
&
name
=
input
.
name
();
const
std
::
string
&
name
=
input
.
name
();
// TODO: Get shape of input parameter
// TODO: Get shape of input parameter
shape
s
=
parse_type
(
input
.
type
());
shape
s
=
parse_type
(
input
.
type
());
instructions
[
name
]
=
prog
.
add_parameter
(
name
,
s
);
instructions
[
name
]
=
prog
.
add_parameter
(
name
,
s
);
}
}
for
(
auto
&&
p
:
nodes
)
for
(
auto
&&
p
:
nodes
)
...
@@ -172,6 +176,7 @@ struct onnx_parser
...
@@ -172,6 +176,7 @@ struct onnx_parser
void
parse_node
(
std
::
string
name
)
void
parse_node
(
std
::
string
name
)
{
{
if
(
name
.
empty
())
RTG_THROW
(
"Onnx node must have a name"
);
if
(
instructions
.
count
(
name
)
==
0
)
if
(
instructions
.
count
(
name
)
==
0
)
{
{
auto
&&
node
=
nodes
.
at
(
name
);
auto
&&
node
=
nodes
.
at
(
name
);
...
@@ -181,6 +186,7 @@ struct onnx_parser
...
@@ -181,6 +186,7 @@ struct onnx_parser
if
(
nodes
.
count
(
input
)
>
0
)
if
(
nodes
.
count
(
input
)
>
0
)
{
{
auto
&&
iname
=
nodes
.
at
(
input
).
name
();
auto
&&
iname
=
nodes
.
at
(
input
).
name
();
assert
(
name
!=
iname
);
this
->
parse_node
(
iname
);
this
->
parse_node
(
iname
);
args
.
push_back
(
instructions
.
at
(
iname
));
args
.
push_back
(
instructions
.
at
(
iname
));
}
}
...
@@ -241,7 +247,8 @@ struct onnx_parser
...
@@ -241,7 +247,8 @@ struct onnx_parser
case
onnx
::
AttributeProto
::
STRING
:
return
{};
case
onnx
::
AttributeProto
::
STRING
:
return
{};
case
onnx
::
AttributeProto
::
TENSOR
:
return
parse_tensor
(
attr
.
t
());
case
onnx
::
AttributeProto
::
TENSOR
:
return
parse_tensor
(
attr
.
t
());
case
onnx
::
AttributeProto
::
GRAPH
:
return
{};
case
onnx
::
AttributeProto
::
GRAPH
:
return
{};
case
onnx
::
AttributeProto
::
FLOATS
:
return
from_repeated
(
shape
::
float_type
,
attr
.
floats
());
case
onnx
::
AttributeProto
::
FLOATS
:
return
from_repeated
(
shape
::
float_type
,
attr
.
floats
());
case
onnx
::
AttributeProto
::
INTS
:
return
from_repeated
(
shape
::
int64_type
,
attr
.
ints
());
case
onnx
::
AttributeProto
::
INTS
:
return
from_repeated
(
shape
::
int64_type
,
attr
.
ints
());
case
onnx
::
AttributeProto
::
STRINGS
:
return
{};
case
onnx
::
AttributeProto
::
STRINGS
:
return
{};
case
onnx
::
AttributeProto
::
TENSORS
:
return
{};
case
onnx
::
AttributeProto
::
TENSORS
:
return
{};
...
@@ -257,21 +264,28 @@ struct onnx_parser
...
@@ -257,21 +264,28 @@ struct onnx_parser
{
{
case
onnx
::
TensorProto
::
UNDEFINED
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
UNDEFINED
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
FLOAT
:
case
onnx
::
TensorProto
::
FLOAT
:
return
literal
{{
shape
::
float_type
,
dims
},
t
.
float_data
().
begin
(),
t
.
float_data
().
end
()};
return
literal
{
{
shape
::
float_type
,
dims
},
t
.
float_data
().
begin
(),
t
.
float_data
().
end
()};
case
onnx
::
TensorProto
::
UINT8
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
UINT8
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
INT8
:
case
onnx
::
TensorProto
::
INT8
:
return
literal
{{
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
return
literal
{
{
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
UINT16
:
case
onnx
::
TensorProto
::
UINT16
:
return
literal
{{
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
return
literal
{
{
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
INT16
:
case
onnx
::
TensorProto
::
INT16
:
return
literal
{{
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
return
literal
{
{
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
INT32
:
case
onnx
::
TensorProto
::
INT32
:
return
literal
{{
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
return
literal
{
{
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
INT64
:
case
onnx
::
TensorProto
::
INT64
:
return
literal
{{
shape
::
int64_type
,
dims
},
t
.
int64_data
().
begin
(),
t
.
int64_data
().
end
()};
return
literal
{
{
shape
::
int64_type
,
dims
},
t
.
int64_data
().
begin
(),
t
.
int64_data
().
end
()};
case
onnx
::
TensorProto
::
STRING
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
STRING
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
BOOL
:
case
onnx
::
TensorProto
::
BOOL
:
return
literal
{{
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
return
literal
{
{
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
FLOAT16
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
FLOAT16
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
DOUBLE
:
case
onnx
::
TensorProto
::
DOUBLE
:
return
literal
{
return
literal
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment