Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
632f819e
Commit
632f819e
authored
Aug 12, 2021
by
Paul
Browse files
Preserve outputs
parent
8004868b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
33 additions
and
0 deletions
+33
-0
src/layout_nhwc.cpp
src/layout_nhwc.cpp
+33
-0
No files found.
src/layout_nhwc.cpp
View file @
632f819e
...
@@ -2,6 +2,8 @@
...
@@ -2,6 +2,8 @@
#include <migraphx/module.hpp>
#include <migraphx/module.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/eliminate_contiguous.hpp>
...
@@ -10,6 +12,36 @@
...
@@ -10,6 +12,36 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
template
<
class
Predicate
>
std
::
vector
<
instruction_ref
>
find_lasts
(
const
module
&
m
,
Predicate
pred
)
{
std
::
vector
<
instruction_ref
>
result
;
fix
([
&
](
auto
self
,
auto
ins
)
{
if
(
pred
(
ins
))
{
result
.
push_back
(
ins
);
return
;
}
for
(
auto
input
:
ins
->
inputs
())
self
(
input
);
})(
std
::
prev
(
m
.
end
()));
return
result
;
}
void
preserve_output_layout
(
module
&
m
)
{
std
::
vector
<
instruction_ref
>
outputs
=
find_lasts
(
m
,
[](
auto
ins
)
{
return
ins
->
get_shape
().
lens
().
size
()
==
4
;
});
for
(
auto
output
:
outputs
)
{
auto
permutation
=
find_permutation
(
output
->
get_shape
());
auto
layout
=
m
.
insert_instruction
(
std
::
next
(
output
),
make_op
(
"layout"
,
{{
"permutation"
,
permutation
}}),
output
);
m
.
replace_instruction
(
output
,
layout
);
}
}
void
transform_convolutions
(
module
&
m
)
void
transform_convolutions
(
module
&
m
)
{
{
for
(
auto
ins
:
iterator_for
(
m
))
for
(
auto
ins
:
iterator_for
(
m
))
...
@@ -30,6 +62,7 @@ void transform_convolutions(module& m)
...
@@ -30,6 +62,7 @@ void transform_convolutions(module& m)
void
layout_nhwc
::
apply
(
module
&
m
)
const
void
layout_nhwc
::
apply
(
module
&
m
)
const
{
{
preserve_output_layout
(
m
);
transform_convolutions
(
m
);
transform_convolutions
(
m
);
dead_code_elimination
{}.
apply
(
m
);
dead_code_elimination
{}.
apply
(
m
);
eliminate_contiguous
{
"contiguous"
}.
apply
(
m
);
eliminate_contiguous
{
"contiguous"
}.
apply
(
m
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment