Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
c6078c1e
"vscode:/vscode.git/clone" did not exist on "204ef976cacee1b3452e8e9d38186933f601756e"
Commit
c6078c1e
authored
Jan 25, 2019
by
Shucai Xiao
Browse files
Merge branch 'develop' of
https://github.com/ROCmSoftwarePlatform/AMDMIGraphX
into gather_operator
parents
e344f80d
2d80965f
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
50 additions
and
27 deletions
+50
-27
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+50
-27
No files found.
src/onnx/onnx.cpp
View file @
c6078c1e
...
@@ -24,7 +24,8 @@ struct onnx_parser
...
@@ -24,7 +24,8 @@ struct onnx_parser
{
{
using
attribute_map
=
std
::
unordered_map
<
std
::
string
,
onnx
::
AttributeProto
>
;
using
attribute_map
=
std
::
unordered_map
<
std
::
string
,
onnx
::
AttributeProto
>
;
using
node_map
=
std
::
unordered_map
<
std
::
string
,
onnx
::
NodeProto
>
;
using
node_map
=
std
::
unordered_map
<
std
::
string
,
onnx
::
NodeProto
>
;
using
op_func
=
std
::
function
<
instruction_ref
(
attribute_map
,
std
::
vector
<
instruction_ref
>
)
>
;
using
op_func
=
std
::
function
<
std
::
vector
<
instruction_ref
>
(
attribute_map
,
std
::
vector
<
instruction_ref
>
)
>
;
node_map
nodes
;
node_map
nodes
;
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
instructions
;
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
instructions
;
program
prog
=
program
();
program
prog
=
program
();
...
@@ -88,6 +89,15 @@ struct onnx_parser
...
@@ -88,6 +89,15 @@ struct onnx_parser
template
<
class
F
>
template
<
class
F
>
void
add_op
(
std
::
string
name
,
F
f
)
void
add_op
(
std
::
string
name
,
F
f
)
{
ops
.
emplace
(
name
,
[
=
](
auto
&&
...
xs
)
{
return
std
::
vector
<
instruction_ref
>
{
f
(
std
::
forward
<
decltype
(
xs
)
>
(
xs
)...)};
});
}
// Multi output op
template
<
class
F
>
void
add_multi_op
(
std
::
string
name
,
F
f
)
{
{
ops
.
emplace
(
name
,
f
);
ops
.
emplace
(
name
,
f
);
}
}
...
@@ -95,7 +105,7 @@ struct onnx_parser
...
@@ -95,7 +105,7 @@ struct onnx_parser
template
<
class
F
>
template
<
class
F
>
void
add_mem_op
(
std
::
string
name
,
F
f
)
void
add_mem_op
(
std
::
string
name
,
F
f
)
{
{
ops
.
emplace
(
name
,
[
=
](
auto
&&
...
xs
)
{
add_op
(
name
,
[
=
](
auto
&&
...
xs
)
{
return
std
::
mem_fn
(
f
)(
*
this
,
name
,
std
::
forward
<
decltype
(
xs
)
>
(
xs
)...);
return
std
::
mem_fn
(
f
)(
*
this
,
name
,
std
::
forward
<
decltype
(
xs
)
>
(
xs
)...);
});
});
}
}
...
@@ -103,7 +113,7 @@ struct onnx_parser
...
@@ -103,7 +113,7 @@ struct onnx_parser
template
<
class
T
>
template
<
class
T
>
void
add_binary_op
(
std
::
string
name
,
T
x
)
void
add_binary_op
(
std
::
string
name
,
T
x
)
{
{
ops
.
emplace
(
name
,
[
this
,
x
](
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
add_op
(
name
,
[
this
,
x
](
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
if
(
args
.
size
()
!=
2
)
if
(
args
.
size
()
!=
2
)
MIGRAPHX_THROW
(
"binary operators should have 2 operands"
);
MIGRAPHX_THROW
(
"binary operators should have 2 operands"
);
if
(
contains
(
attributes
,
"broadcast"
)
and
contains
(
attributes
,
"axis"
))
if
(
contains
(
attributes
,
"broadcast"
)
and
contains
(
attributes
,
"axis"
))
...
@@ -172,7 +182,7 @@ struct onnx_parser
...
@@ -172,7 +182,7 @@ struct onnx_parser
template
<
class
T
>
template
<
class
T
>
void
add_generic_op
(
std
::
string
name
,
T
x
)
void
add_generic_op
(
std
::
string
name
,
T
x
)
{
{
ops
.
emplace
(
name
,
[
this
,
x
](
attribute_map
,
std
::
vector
<
instruction_ref
>
args
)
{
add_op
(
name
,
[
this
,
x
](
attribute_map
,
std
::
vector
<
instruction_ref
>
args
)
{
return
prog
.
add_instruction
(
x
,
args
);
return
prog
.
add_instruction
(
x
,
args
);
});
});
}
}
...
@@ -180,7 +190,7 @@ struct onnx_parser
...
@@ -180,7 +190,7 @@ struct onnx_parser
template
<
class
T
>
template
<
class
T
>
void
add_variadic_op
(
std
::
string
name
,
T
x
)
void
add_variadic_op
(
std
::
string
name
,
T
x
)
{
{
ops
.
emplace
(
name
,
[
this
,
x
](
attribute_map
,
std
::
vector
<
instruction_ref
>
args
)
{
add_op
(
name
,
[
this
,
x
](
attribute_map
,
std
::
vector
<
instruction_ref
>
args
)
{
return
std
::
accumulate
(
std
::
next
(
args
.
begin
()),
return
std
::
accumulate
(
std
::
next
(
args
.
begin
()),
args
.
end
(),
args
.
end
(),
args
.
front
(),
args
.
front
(),
...
@@ -643,7 +653,7 @@ struct onnx_parser
...
@@ -643,7 +653,7 @@ struct onnx_parser
}
}
else
else
{
{
throw
std
::
runtime_error
(
"Failed reading"
);
MIGRAPHX_THROW
(
"Failed reading
onnx file.
"
);
}
}
}
}
...
@@ -673,7 +683,7 @@ struct onnx_parser
...
@@ -673,7 +683,7 @@ struct onnx_parser
}
}
for
(
auto
&&
p
:
nodes
)
for
(
auto
&&
p
:
nodes
)
{
{
this
->
parse_node
(
get_name
(
p
.
second
)
);
this
->
parse_node
(
p
.
first
);
}
}
}
}
...
@@ -689,23 +699,37 @@ struct onnx_parser
...
@@ -689,23 +699,37 @@ struct onnx_parser
{
{
if
(
nodes
.
count
(
input
)
>
0
)
if
(
nodes
.
count
(
input
)
>
0
)
{
{
auto
&&
iname
=
get_name
(
nodes
.
at
(
input
));
assert
(
name
!=
input
);
assert
(
name
!=
iname
);
this
->
parse_node
(
input
);
this
->
parse_node
(
iname
);
args
.
push_back
(
instructions
.
at
(
input
));
args
.
push_back
(
instructions
.
at
(
iname
));
}
}
else
else
{
{
args
.
push_back
(
instructions
.
at
(
input
));
args
.
push_back
(
instructions
.
at
(
input
));
}
}
}
}
std
::
vector
<
instruction_ref
>
result
;
if
(
ops
.
count
(
node
.
op_type
())
==
0
)
if
(
ops
.
count
(
node
.
op_type
())
==
0
)
{
{
instructions
[
name
]
=
prog
.
add_instruction
(
unknown
{
node
.
op_type
()},
args
);
result
.
push_back
(
prog
.
add_instruction
(
unknown
{
node
.
op_type
()},
args
));
}
else
{
result
=
ops
[
node
.
op_type
()](
get_attributes
(
node
),
args
);
}
// Even no output nodes produce output in migraphx
if
(
node
.
output
().
empty
()
and
result
.
size
()
==
1
)
{
instructions
[
name
]
=
result
.
front
();
}
}
else
else
{
{
instructions
[
name
]
=
ops
[
node
.
op_type
()](
get_attributes
(
node
),
args
);
assert
(
node
.
output
().
size
()
>=
result
.
size
());
std
::
transform
(
result
.
begin
(),
result
.
end
(),
node
.
output
().
begin
(),
std
::
inserter
(
instructions
,
instructions
.
end
()),
[](
auto
&&
x
,
auto
&&
y
)
{
return
std
::
make_pair
(
y
,
x
);
});
}
}
}
}
}
}
...
@@ -720,25 +744,24 @@ struct onnx_parser
...
@@ -720,25 +744,24 @@ struct onnx_parser
return
result
;
return
result
;
}
}
static
std
::
string
get_name
(
const
onnx
::
NodeProto
&
node
)
{
if
(
node
.
name
().
empty
())
{
std
::
string
generated
=
"migraphx_unnamed_node"
;
return
std
::
accumulate
(
node
.
output
().
begin
(),
node
.
output
().
end
(),
generated
,
[](
auto
x
,
auto
y
)
{
return
x
+
"_"
+
y
;
});
}
return
node
.
name
();
}
static
node_map
get_nodes
(
const
onnx
::
GraphProto
&
graph
)
static
node_map
get_nodes
(
const
onnx
::
GraphProto
&
graph
)
{
{
std
::
unordered_map
<
std
::
string
,
onnx
::
NodeProto
>
result
;
std
::
unordered_map
<
std
::
string
,
onnx
::
NodeProto
>
result
;
std
::
size_t
n
=
0
;
for
(
auto
&&
node
:
graph
.
node
())
for
(
auto
&&
node
:
graph
.
node
())
{
{
result
[
get_name
(
node
)]
=
node
;
if
(
node
.
output
().
empty
())
{
if
(
node
.
name
().
empty
())
{
result
[
"migraphx_unamed_node_"
+
std
::
to_string
(
n
)]
=
node
;
n
++
;
}
else
{
result
[
node
.
name
()]
=
node
;
}
}
for
(
auto
&&
output
:
node
.
output
())
for
(
auto
&&
output
:
node
.
output
())
{
{
result
[
output
]
=
node
;
result
[
output
]
=
node
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment