Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
2d80965f
Unverified
Commit
2d80965f
authored
Jan 25, 2019
by
Paul Fultz II
Committed by
GitHub
Jan 25, 2019
Browse files
Merge pull request #157 from ROCmSoftwarePlatform/onnx-multi-out
Add support for multiple outputs in onnx parser
parents
3e9358c2
066b62e0
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
50 additions
and
27 deletions
+50
-27
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+50
-27
No files found.
src/onnx/onnx.cpp
View file @
2d80965f
...
...
@@ -24,7 +24,8 @@ struct onnx_parser
{
using
attribute_map
=
std
::
unordered_map
<
std
::
string
,
onnx
::
AttributeProto
>
;
using
node_map
=
std
::
unordered_map
<
std
::
string
,
onnx
::
NodeProto
>
;
using
op_func
=
std
::
function
<
instruction_ref
(
attribute_map
,
std
::
vector
<
instruction_ref
>
)
>
;
using
op_func
=
std
::
function
<
std
::
vector
<
instruction_ref
>
(
attribute_map
,
std
::
vector
<
instruction_ref
>
)
>
;
node_map
nodes
;
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
instructions
;
program
prog
=
program
();
...
...
@@ -88,6 +89,15 @@ struct onnx_parser
template
<
class
F
>
void
add_op
(
std
::
string
name
,
F
f
)
{
ops
.
emplace
(
name
,
[
=
](
auto
&&
...
xs
)
{
return
std
::
vector
<
instruction_ref
>
{
f
(
std
::
forward
<
decltype
(
xs
)
>
(
xs
)...)};
});
}
// Multi output op
template
<
class
F
>
void
add_multi_op
(
std
::
string
name
,
F
f
)
{
ops
.
emplace
(
name
,
f
);
}
...
...
@@ -95,7 +105,7 @@ struct onnx_parser
template
<
class
F
>
void
add_mem_op
(
std
::
string
name
,
F
f
)
{
ops
.
emplace
(
name
,
[
=
](
auto
&&
...
xs
)
{
add_op
(
name
,
[
=
](
auto
&&
...
xs
)
{
return
std
::
mem_fn
(
f
)(
*
this
,
name
,
std
::
forward
<
decltype
(
xs
)
>
(
xs
)...);
});
}
...
...
@@ -103,7 +113,7 @@ struct onnx_parser
template
<
class
T
>
void
add_binary_op
(
std
::
string
name
,
T
x
)
{
ops
.
emplace
(
name
,
[
this
,
x
](
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
add_op
(
name
,
[
this
,
x
](
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
if
(
args
.
size
()
!=
2
)
MIGRAPHX_THROW
(
"binary operators should have 2 operands"
);
if
(
contains
(
attributes
,
"broadcast"
)
and
contains
(
attributes
,
"axis"
))
...
...
@@ -172,7 +182,7 @@ struct onnx_parser
template
<
class
T
>
void
add_generic_op
(
std
::
string
name
,
T
x
)
{
ops
.
emplace
(
name
,
[
this
,
x
](
attribute_map
,
std
::
vector
<
instruction_ref
>
args
)
{
add_op
(
name
,
[
this
,
x
](
attribute_map
,
std
::
vector
<
instruction_ref
>
args
)
{
return
prog
.
add_instruction
(
x
,
args
);
});
}
...
...
@@ -180,7 +190,7 @@ struct onnx_parser
template
<
class
T
>
void
add_variadic_op
(
std
::
string
name
,
T
x
)
{
ops
.
emplace
(
name
,
[
this
,
x
](
attribute_map
,
std
::
vector
<
instruction_ref
>
args
)
{
add_op
(
name
,
[
this
,
x
](
attribute_map
,
std
::
vector
<
instruction_ref
>
args
)
{
return
std
::
accumulate
(
std
::
next
(
args
.
begin
()),
args
.
end
(),
args
.
front
(),
...
...
@@ -643,7 +653,7 @@ struct onnx_parser
}
else
{
throw
std
::
runtime_error
(
"Failed reading"
);
MIGRAPHX_THROW
(
"Failed reading
onnx file.
"
);
}
}
...
...
@@ -673,7 +683,7 @@ struct onnx_parser
}
for
(
auto
&&
p
:
nodes
)
{
this
->
parse_node
(
get_name
(
p
.
second
)
);
this
->
parse_node
(
p
.
first
);
}
}
...
...
@@ -689,23 +699,37 @@ struct onnx_parser
{
if
(
nodes
.
count
(
input
)
>
0
)
{
auto
&&
iname
=
get_name
(
nodes
.
at
(
input
));
assert
(
name
!=
iname
);
this
->
parse_node
(
iname
);
args
.
push_back
(
instructions
.
at
(
iname
));
assert
(
name
!=
input
);
this
->
parse_node
(
input
);
args
.
push_back
(
instructions
.
at
(
input
));
}
else
{
args
.
push_back
(
instructions
.
at
(
input
));
}
}
std
::
vector
<
instruction_ref
>
result
;
if
(
ops
.
count
(
node
.
op_type
())
==
0
)
{
instructions
[
name
]
=
prog
.
add_instruction
(
unknown
{
node
.
op_type
()},
args
);
result
.
push_back
(
prog
.
add_instruction
(
unknown
{
node
.
op_type
()},
args
)
)
;
}
else
{
instructions
[
name
]
=
ops
[
node
.
op_type
()](
get_attributes
(
node
),
args
);
result
=
ops
[
node
.
op_type
()](
get_attributes
(
node
),
args
);
}
// Even no output nodes produce output in migraphx
if
(
node
.
output
().
empty
()
and
result
.
size
()
==
1
)
{
instructions
[
name
]
=
result
.
front
();
}
else
{
assert
(
node
.
output
().
size
()
>=
result
.
size
());
std
::
transform
(
result
.
begin
(),
result
.
end
(),
node
.
output
().
begin
(),
std
::
inserter
(
instructions
,
instructions
.
end
()),
[](
auto
&&
x
,
auto
&&
y
)
{
return
std
::
make_pair
(
y
,
x
);
});
}
}
}
...
...
@@ -720,25 +744,24 @@ struct onnx_parser
return
result
;
}
static
std
::
string
get_name
(
const
onnx
::
NodeProto
&
node
)
{
if
(
node
.
name
().
empty
())
{
std
::
string
generated
=
"migraphx_unnamed_node"
;
return
std
::
accumulate
(
node
.
output
().
begin
(),
node
.
output
().
end
(),
generated
,
[](
auto
x
,
auto
y
)
{
return
x
+
"_"
+
y
;
});
}
return
node
.
name
();
}
static
node_map
get_nodes
(
const
onnx
::
GraphProto
&
graph
)
{
std
::
unordered_map
<
std
::
string
,
onnx
::
NodeProto
>
result
;
std
::
size_t
n
=
0
;
for
(
auto
&&
node
:
graph
.
node
())
{
result
[
get_name
(
node
)]
=
node
;
if
(
node
.
output
().
empty
())
{
if
(
node
.
name
().
empty
())
{
result
[
"migraphx_unamed_node_"
+
std
::
to_string
(
n
)]
=
node
;
n
++
;
}
else
{
result
[
node
.
name
()]
=
node
;
}
}
for
(
auto
&&
output
:
node
.
output
())
{
result
[
output
]
=
node
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment