Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
1110ef29
"vscode:/vscode.git/clone" did not exist on "e59e3058190cb7d7a590a3ed8cb6ca189f198799"
Commit
1110ef29
authored
Aug 10, 2018
by
Scott Thornton
Browse files
Added flatten to ONNX
parent
39151d27
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
33 additions
and
0 deletions
+33
-0
src/include/migraph/operators.hpp
src/include/migraph/operators.hpp
+21
-0
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+12
-0
No files found.
src/include/migraph/operators.hpp
View file @
1110ef29
...
...
@@ -422,7 +422,28 @@ struct neg : unary
struct
flatten
{
uint64_t
axis
=
0
;
std
::
string
name
()
const
{
return
"flatten"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
}.
has
(
1
);
if
(
axis
==
0
)
{
return
{
inputs
.
at
(
0
).
type
(),
{
1
,
inputs
.
at
(
0
).
elements
()}};
}
if
(
axis
==
1
)
{
return
{
inputs
.
at
(
0
).
type
(),
{
inputs
.
at
(
0
).
elements
(),
1
}};
}
else
{
MIGRAPH_THROW
(
"axis can only be either 0 or 1"
);
}
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
return
{
output_shape
,
std
::
move
(
args
.
front
().
data
)};
}
};
struct
broadcast
...
...
src/onnx/onnx.cpp
View file @
1110ef29
...
...
@@ -62,6 +62,7 @@ struct onnx_parser
add_mem_op
(
"Conv"
,
&
onnx_parser
::
parse_conv
);
add_mem_op
(
"MaxPool"
,
&
onnx_parser
::
parse_pooling
);
add_mem_op
(
"Reshape"
,
&
onnx_parser
::
parse_reshape
);
add_mem_op
(
"Flatten"
,
&
onnx_parser
::
parse_flatten
);
add_mem_op
(
"BatchNormalization"
,
&
onnx_parser
::
parse_batchnorm
);
}
...
...
@@ -161,6 +162,17 @@ struct onnx_parser
return
prog
.
add_instruction
(
op
,
args
[
0
]);
}
instruction_ref
parse_flatten
(
std
::
string
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
uint64_t
axis
=
0
;
if
(
contains
(
attributes
,
"axis"
))
{
axis
=
parse_value
(
attributes
.
at
(
"axis"
)).
at
<
int
>
();
}
return
prog
.
add_instruction
(
flatten
{
axis
},
args
[
0
]);
}
instruction_ref
parse_constant
(
std
::
string
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment