Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
98b8dff1
Commit
98b8dff1
authored
Oct 17, 2022
by
Khalique Ahmed
Browse files
workaround using two layout_nhwc passes
parent
7393cf1e
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
55 additions
and
18 deletions
+55
-18
src/include/migraphx/layout_nhwc.hpp
src/include/migraphx/layout_nhwc.hpp
+1
-0
src/layout_nhwc.cpp
src/layout_nhwc.cpp
+26
-7
src/targets/gpu/convolution.cpp
src/targets/gpu/convolution.cpp
+27
-11
src/targets/gpu/target.cpp
src/targets/gpu/target.cpp
+1
-0
No files found.
src/include/migraphx/layout_nhwc.hpp
View file @
98b8dff1
...
@@ -15,6 +15,7 @@ struct module_pass_manager;
...
@@ -15,6 +15,7 @@ struct module_pass_manager;
*/
*/
struct
layout_nhwc
struct
layout_nhwc
{
{
bool
skip_elim_contiguous
=
false
;
std
::
string
name
()
const
{
return
"layout_nhwc"
;
}
std
::
string
name
()
const
{
return
"layout_nhwc"
;
}
void
apply
(
module_pass_manager
&
m
)
const
;
void
apply
(
module_pass_manager
&
m
)
const
;
};
};
...
...
src/layout_nhwc.cpp
View file @
98b8dff1
...
@@ -44,7 +44,7 @@ std::unordered_set<instruction_ref> preserve_output_layout(module& m)
...
@@ -44,7 +44,7 @@ std::unordered_set<instruction_ref> preserve_output_layout(module& m)
return
result
;
return
result
;
}
}
void
transform_convolutions
(
module
&
m
)
void
transform_convolutions
(
module
&
m
,
bool
skip_elim_contiguous
)
{
{
for
(
auto
ins
:
iterator_for
(
m
))
for
(
auto
ins
:
iterator_for
(
m
))
{
{
...
@@ -56,11 +56,29 @@ void transform_convolutions(module& m)
...
@@ -56,11 +56,29 @@ void transform_convolutions(module& m)
if
(
v
.
at
(
"group"
).
to
<
int
>
()
>
1
)
if
(
v
.
at
(
"group"
).
to
<
int
>
()
>
1
)
continue
;
continue
;
auto
args
=
ins
->
inputs
();
auto
args
=
ins
->
inputs
();
std
::
transform
(
args
.
begin
(),
args
.
end
(),
args
.
begin
(),
[
&
](
auto
&
i
)
{
if
(
skip_elim_contiguous
)
return
m
.
insert_instruction
(
ins
,
make_op
(
"layout"
,
{{
"permutation"
,
{
0
,
2
,
3
,
1
}}}),
i
);
{
});
// std::cout << "HERE" << std::endl;
for
(
auto
i
=
0
;
i
<
args
.
size
();
i
++
)
{
// std::cout << args[i]->name() << std::endl;
if
(
args
[
i
]
->
name
()
!=
"layout"
and
args
[
i
]
->
get_shape
().
standard
())
{
// std::cout << "HERE2" << std::endl;
args
[
i
]
=
m
.
insert_instruction
(
ins
,
make_op
(
"layout"
,
{{
"permutation"
,
{
0
,
2
,
3
,
1
}}}),
args
[
i
]);
// m.debug_print(args);
}
}
}
else
std
::
transform
(
args
.
begin
(),
args
.
end
(),
args
.
begin
(),
[
&
](
auto
&
i
)
{
return
m
.
insert_instruction
(
ins
,
make_op
(
"layout"
,
{{
"permutation"
,
{
0
,
2
,
3
,
1
}}}),
i
);
});
auto
conv
=
m
.
insert_instruction
(
ins
,
ins
->
get_operator
(),
args
);
auto
conv
=
m
.
insert_instruction
(
ins
,
ins
->
get_operator
(),
args
);
auto
c
=
m
.
insert_instruction
(
ins
,
make_op
(
"contiguous"
),
conv
);
auto
c
=
conv
;
if
(
not
skip_elim_contiguous
)
c
=
m
.
insert_instruction
(
ins
,
make_op
(
"contiguous"
),
conv
);
m
.
replace_instruction
(
ins
,
c
);
m
.
replace_instruction
(
ins
,
c
);
}
}
}
}
...
@@ -82,9 +100,10 @@ void remove_layout(module& m, const std::unordered_set<instruction_ref>& output_
...
@@ -82,9 +100,10 @@ void remove_layout(module& m, const std::unordered_set<instruction_ref>& output_
void
layout_nhwc
::
apply
(
module_pass_manager
&
mpm
)
const
void
layout_nhwc
::
apply
(
module_pass_manager
&
mpm
)
const
{
{
std
::
unordered_set
<
instruction_ref
>
output_layouts
=
preserve_output_layout
(
mpm
.
get_module
());
std
::
unordered_set
<
instruction_ref
>
output_layouts
=
preserve_output_layout
(
mpm
.
get_module
());
transform_convolutions
(
mpm
.
get_module
());
transform_convolutions
(
mpm
.
get_module
()
,
this
->
skip_elim_contiguous
);
mpm
.
run_pass
(
dead_code_elimination
{});
mpm
.
run_pass
(
dead_code_elimination
{});
mpm
.
run_pass
(
eliminate_contiguous
{
"contiguous"
});
if
(
not
this
->
skip_elim_contiguous
)
mpm
.
run_pass
(
eliminate_contiguous
{
"contiguous"
});
mpm
.
run_pass
(
dead_code_elimination
{});
mpm
.
run_pass
(
dead_code_elimination
{});
remove_layout
(
mpm
.
get_module
(),
output_layouts
);
remove_layout
(
mpm
.
get_module
(),
output_layouts
);
mpm
.
run_pass
(
dead_code_elimination
{});
mpm
.
run_pass
(
dead_code_elimination
{});
...
...
src/targets/gpu/convolution.cpp
View file @
98b8dff1
...
@@ -63,17 +63,33 @@ argument miopen_convolution::compute(context& ctx,
...
@@ -63,17 +63,33 @@ argument miopen_convolution::compute(context& ctx,
if
(
solution_id
==
0
)
if
(
solution_id
==
0
)
MIGRAPHX_THROW
(
"MIOpen Convolution: invalid solution ID"
);
MIGRAPHX_THROW
(
"MIOpen Convolution: invalid solution ID"
);
auto
status
=
miopenConvolutionForwardImmediate
(
ctx
.
get_stream
().
get_miopen
(),
// auto status = miopenConvolutionForwardImmediate(ctx.get_stream().get_miopen(),
w_desc
.
get
(),
// w_desc.get(),
args
[
1
].
implicit
(),
// args[1].implicit(),
x_desc
.
get
(),
// x_desc.get(),
args
[
0
].
implicit
(),
// args[0].implicit(),
cd
.
get
(),
// cd.get(),
y_desc
.
get
(),
// y_desc.get(),
args
[
3
].
implicit
(),
// args[3].implicit(),
args
[
2
].
implicit
(),
// args[2].implicit(),
args
[
2
].
get_shape
().
bytes
(),
// args[2].get_shape().bytes(),
solution_id
);
// solution_id);
float
alpha
=
1
;
float
beta
=
0
;
auto
status
=
miopenConvolutionForward
(
ctx
.
get_stream
().
get_miopen
(),
&
alpha
,
x_desc
.
get
(),
args
[
0
].
implicit
(),
w_desc
.
get
(),
args
[
1
].
implicit
(),
cd
.
get
(),
algo
,
&
beta
,
y_desc
.
get
(),
args
[
3
].
implicit
(),
args
[
2
].
implicit
(),
args
[
2
].
get_shape
().
bytes
());
if
(
status
!=
miopenStatusSuccess
)
if
(
status
!=
miopenStatusSuccess
)
MIGRAPHX_THROW
(
"MIOpen Convolution: running convolution failed"
);
MIGRAPHX_THROW
(
"MIOpen Convolution: running convolution failed"
);
...
...
src/targets/gpu/target.cpp
View file @
98b8dff1
...
@@ -132,6 +132,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
...
@@ -132,6 +132,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
prefuse_ops
{},
prefuse_ops
{},
dead_code_elimination
{},
dead_code_elimination
{},
auto_contiguous
{},
auto_contiguous
{},
enable_pass
(
enabled
(
MIGRAPHX_ENABLE_NHWC
{}),
layout_nhwc
{
true
}),
simplify_reshapes
{},
simplify_reshapes
{},
propagate_constant
{},
propagate_constant
{},
dead_code_elimination
{},
dead_code_elimination
{},
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment