Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
ea166055
Commit
ea166055
authored
Jan 23, 2019
by
Paul
Browse files
Add initial support for multi output
parent
ecbb4de5
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
25 additions
and
12 deletions
+25
-12
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+25
-12
No files found.
src/onnx/onnx.cpp
View file @
ea166055
...
...
@@ -24,7 +24,7 @@ struct onnx_parser
{
using
attribute_map
=
std
::
unordered_map
<
std
::
string
,
onnx
::
AttributeProto
>
;
using
node_map
=
std
::
unordered_map
<
std
::
string
,
onnx
::
NodeProto
>
;
using
op_func
=
std
::
function
<
instruction_ref
(
attribute_map
,
std
::
vector
<
instruction_ref
>
)
>
;
using
op_func
=
std
::
function
<
std
::
vector
<
instruction_ref
>
(
attribute_map
,
std
::
vector
<
instruction_ref
>
)
>
;
node_map
nodes
;
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
instructions
;
program
prog
=
program
();
...
...
@@ -88,6 +88,15 @@ struct onnx_parser
template
<
class
F
>
void
add_op
(
std
::
string
name
,
F
f
)
{
ops
.
emplace
(
name
,
[
=
](
auto
&&
...
xs
)
{
return
std
::
vector
<
instruction_ref
>
{
f
(
std
::
forward
<
decltype
(
xs
)
>
(
xs
)...)};
});
}
// Multi output op
template
<
class
F
>
void
add_multi_op
(
std
::
string
name
,
F
f
)
{
ops
.
emplace
(
name
,
f
);
}
...
...
@@ -95,7 +104,7 @@ struct onnx_parser
template
<
class
F
>
void
add_mem_op
(
std
::
string
name
,
F
f
)
{
ops
.
emplace
(
name
,
[
=
](
auto
&&
...
xs
)
{
add_op
(
name
,
[
=
](
auto
&&
...
xs
)
{
return
std
::
mem_fn
(
f
)(
*
this
,
name
,
std
::
forward
<
decltype
(
xs
)
>
(
xs
)...);
});
}
...
...
@@ -103,7 +112,7 @@ struct onnx_parser
template
<
class
T
>
void
add_binary_op
(
std
::
string
name
,
T
x
)
{
ops
.
emplace
(
name
,
[
this
,
x
](
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
add_op
(
name
,
[
this
,
x
](
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
if
(
args
.
size
()
!=
2
)
MIGRAPHX_THROW
(
"binary operators should have 2 operands"
);
if
(
contains
(
attributes
,
"broadcast"
))
...
...
@@ -174,7 +183,7 @@ struct onnx_parser
template
<
class
T
>
void
add_generic_op
(
std
::
string
name
,
T
x
)
{
ops
.
emplace
(
name
,
[
this
,
x
](
attribute_map
,
std
::
vector
<
instruction_ref
>
args
)
{
add_op
(
name
,
[
this
,
x
](
attribute_map
,
std
::
vector
<
instruction_ref
>
args
)
{
return
prog
.
add_instruction
(
x
,
args
);
});
}
...
...
@@ -182,7 +191,7 @@ struct onnx_parser
template
<
class
T
>
void
add_variadic_op
(
std
::
string
name
,
T
x
)
{
ops
.
emplace
(
name
,
[
this
,
x
](
attribute_map
,
std
::
vector
<
instruction_ref
>
args
)
{
add_op
(
name
,
[
this
,
x
](
attribute_map
,
std
::
vector
<
instruction_ref
>
args
)
{
return
std
::
accumulate
(
std
::
next
(
args
.
begin
()),
args
.
end
(),
args
.
front
(),
...
...
@@ -645,7 +654,7 @@ struct onnx_parser
}
else
{
throw
std
::
runtime_error
(
"Failed reading"
);
MIGRAPHX_THROW
(
"Failed reading
onnx file.
"
);
}
}
...
...
@@ -691,24 +700,28 @@ struct onnx_parser
{
if
(
nodes
.
count
(
input
)
>
0
)
{
auto
&&
iname
=
get_name
(
nodes
.
at
(
input
));
assert
(
name
!=
in
ame
);
this
->
parse_node
(
in
ame
);
args
.
push_back
(
instructions
.
at
(
in
ame
));
//
auto&& iname = get_name(nodes.at(input));
assert
(
name
!=
in
put
);
this
->
parse_node
(
in
put
);
args
.
push_back
(
instructions
.
at
(
in
put
));
}
else
{
args
.
push_back
(
instructions
.
at
(
input
));
}
}
std
::
vector
<
instruction_ref
>
result
;
if
(
ops
.
count
(
node
.
op_type
())
==
0
)
{
instructions
[
name
]
=
prog
.
add_instruction
(
unknown
{
node
.
op_type
()},
args
);
result
.
push_back
(
prog
.
add_instruction
(
unknown
{
node
.
op_type
()},
args
)
)
;
}
else
{
instructions
[
name
]
=
ops
[
node
.
op_type
()](
get_attributes
(
node
),
args
);
result
=
ops
[
node
.
op_type
()](
get_attributes
(
node
),
args
);
}
std
::
transform
(
node
.
output
().
begin
(),
node
.
output
().
end
(),
result
.
begin
(),
std
::
inserter
(
instructions
,
instructions
.
end
()),
[](
auto
&&
onnx_out
,
auto
&&
parse_out
)
{
return
std
::
make_pair
(
onnx_out
,
parse_out
);
});
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment