Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
98b8dff1
"docs/vscode:/vscode.git/clone" did not exist on "a03570a0200f4356079fdf23beae2c717810accc"
Commit
98b8dff1
authored
Oct 17, 2022
by
Khalique Ahmed
Browse files
workaround using two layout_nhwc passes
parent
7393cf1e
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
55 additions
and
18 deletions
+55
-18
src/include/migraphx/layout_nhwc.hpp
src/include/migraphx/layout_nhwc.hpp
+1
-0
src/layout_nhwc.cpp
src/layout_nhwc.cpp
+26
-7
src/targets/gpu/convolution.cpp
src/targets/gpu/convolution.cpp
+27
-11
src/targets/gpu/target.cpp
src/targets/gpu/target.cpp
+1
-0
No files found.
src/include/migraphx/layout_nhwc.hpp
View file @
98b8dff1
...
@@ -15,6 +15,7 @@ struct module_pass_manager;
...
@@ -15,6 +15,7 @@ struct module_pass_manager;
*/
*/
struct
layout_nhwc
struct
layout_nhwc
{
{
bool
skip_elim_contiguous
=
false
;
std
::
string
name
()
const
{
return
"layout_nhwc"
;
}
std
::
string
name
()
const
{
return
"layout_nhwc"
;
}
void
apply
(
module_pass_manager
&
m
)
const
;
void
apply
(
module_pass_manager
&
m
)
const
;
};
};
...
...
src/layout_nhwc.cpp
View file @
98b8dff1
...
@@ -44,7 +44,7 @@ std::unordered_set<instruction_ref> preserve_output_layout(module& m)
...
@@ -44,7 +44,7 @@ std::unordered_set<instruction_ref> preserve_output_layout(module& m)
return
result
;
return
result
;
}
}
void
transform_convolutions
(
module
&
m
)
void
transform_convolutions
(
module
&
m
,
bool
skip_elim_contiguous
)
{
{
for
(
auto
ins
:
iterator_for
(
m
))
for
(
auto
ins
:
iterator_for
(
m
))
{
{
...
@@ -56,11 +56,29 @@ void transform_convolutions(module& m)
...
@@ -56,11 +56,29 @@ void transform_convolutions(module& m)
if
(
v
.
at
(
"group"
).
to
<
int
>
()
>
1
)
if
(
v
.
at
(
"group"
).
to
<
int
>
()
>
1
)
continue
;
continue
;
auto
args
=
ins
->
inputs
();
auto
args
=
ins
->
inputs
();
std
::
transform
(
args
.
begin
(),
args
.
end
(),
args
.
begin
(),
[
&
](
auto
&
i
)
{
if
(
skip_elim_contiguous
)
return
m
.
insert_instruction
(
ins
,
make_op
(
"layout"
,
{{
"permutation"
,
{
0
,
2
,
3
,
1
}}}),
i
);
{
});
// std::cout << "HERE" << std::endl;
for
(
auto
i
=
0
;
i
<
args
.
size
();
i
++
)
{
// std::cout << args[i]->name() << std::endl;
if
(
args
[
i
]
->
name
()
!=
"layout"
and
args
[
i
]
->
get_shape
().
standard
())
{
// std::cout << "HERE2" << std::endl;
args
[
i
]
=
m
.
insert_instruction
(
ins
,
make_op
(
"layout"
,
{{
"permutation"
,
{
0
,
2
,
3
,
1
}}}),
args
[
i
]);
// m.debug_print(args);
}
}
}
else
std
::
transform
(
args
.
begin
(),
args
.
end
(),
args
.
begin
(),
[
&
](
auto
&
i
)
{
return
m
.
insert_instruction
(
ins
,
make_op
(
"layout"
,
{{
"permutation"
,
{
0
,
2
,
3
,
1
}}}),
i
);
});
auto
conv
=
m
.
insert_instruction
(
ins
,
ins
->
get_operator
(),
args
);
auto
conv
=
m
.
insert_instruction
(
ins
,
ins
->
get_operator
(),
args
);
auto
c
=
m
.
insert_instruction
(
ins
,
make_op
(
"contiguous"
),
conv
);
auto
c
=
conv
;
if
(
not
skip_elim_contiguous
)
c
=
m
.
insert_instruction
(
ins
,
make_op
(
"contiguous"
),
conv
);
m
.
replace_instruction
(
ins
,
c
);
m
.
replace_instruction
(
ins
,
c
);
}
}
}
}
...
@@ -82,9 +100,10 @@ void remove_layout(module& m, const std::unordered_set<instruction_ref>& output_
...
@@ -82,9 +100,10 @@ void remove_layout(module& m, const std::unordered_set<instruction_ref>& output_
void
layout_nhwc
::
apply
(
module_pass_manager
&
mpm
)
const
void
layout_nhwc
::
apply
(
module_pass_manager
&
mpm
)
const
{
{
std
::
unordered_set
<
instruction_ref
>
output_layouts
=
preserve_output_layout
(
mpm
.
get_module
());
std
::
unordered_set
<
instruction_ref
>
output_layouts
=
preserve_output_layout
(
mpm
.
get_module
());
transform_convolutions
(
mpm
.
get_module
());
transform_convolutions
(
mpm
.
get_module
()
,
this
->
skip_elim_contiguous
);
mpm
.
run_pass
(
dead_code_elimination
{});
mpm
.
run_pass
(
dead_code_elimination
{});
mpm
.
run_pass
(
eliminate_contiguous
{
"contiguous"
});
if
(
not
this
->
skip_elim_contiguous
)
mpm
.
run_pass
(
eliminate_contiguous
{
"contiguous"
});
mpm
.
run_pass
(
dead_code_elimination
{});
mpm
.
run_pass
(
dead_code_elimination
{});
remove_layout
(
mpm
.
get_module
(),
output_layouts
);
remove_layout
(
mpm
.
get_module
(),
output_layouts
);
mpm
.
run_pass
(
dead_code_elimination
{});
mpm
.
run_pass
(
dead_code_elimination
{});
...
...
src/targets/gpu/convolution.cpp
View file @
98b8dff1
...
@@ -63,17 +63,33 @@ argument miopen_convolution::compute(context& ctx,
...
@@ -63,17 +63,33 @@ argument miopen_convolution::compute(context& ctx,
if
(
solution_id
==
0
)
if
(
solution_id
==
0
)
MIGRAPHX_THROW
(
"MIOpen Convolution: invalid solution ID"
);
MIGRAPHX_THROW
(
"MIOpen Convolution: invalid solution ID"
);
auto
status
=
miopenConvolutionForwardImmediate
(
ctx
.
get_stream
().
get_miopen
(),
// auto status = miopenConvolutionForwardImmediate(ctx.get_stream().get_miopen(),
w_desc
.
get
(),
// w_desc.get(),
args
[
1
].
implicit
(),
// args[1].implicit(),
x_desc
.
get
(),
// x_desc.get(),
args
[
0
].
implicit
(),
// args[0].implicit(),
cd
.
get
(),
// cd.get(),
y_desc
.
get
(),
// y_desc.get(),
args
[
3
].
implicit
(),
// args[3].implicit(),
args
[
2
].
implicit
(),
// args[2].implicit(),
args
[
2
].
get_shape
().
bytes
(),
// args[2].get_shape().bytes(),
solution_id
);
// solution_id);
float
alpha
=
1
;
float
beta
=
0
;
auto
status
=
miopenConvolutionForward
(
ctx
.
get_stream
().
get_miopen
(),
&
alpha
,
x_desc
.
get
(),
args
[
0
].
implicit
(),
w_desc
.
get
(),
args
[
1
].
implicit
(),
cd
.
get
(),
algo
,
&
beta
,
y_desc
.
get
(),
args
[
3
].
implicit
(),
args
[
2
].
implicit
(),
args
[
2
].
get_shape
().
bytes
());
if
(
status
!=
miopenStatusSuccess
)
if
(
status
!=
miopenStatusSuccess
)
MIGRAPHX_THROW
(
"MIOpen Convolution: running convolution failed"
);
MIGRAPHX_THROW
(
"MIOpen Convolution: running convolution failed"
);
...
...
src/targets/gpu/target.cpp
View file @
98b8dff1
...
@@ -132,6 +132,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
...
@@ -132,6 +132,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
prefuse_ops
{},
prefuse_ops
{},
dead_code_elimination
{},
dead_code_elimination
{},
auto_contiguous
{},
auto_contiguous
{},
enable_pass
(
enabled
(
MIGRAPHX_ENABLE_NHWC
{}),
layout_nhwc
{
true
}),
simplify_reshapes
{},
simplify_reshapes
{},
propagate_constant
{},
propagate_constant
{},
dead_code_elimination
{},
dead_code_elimination
{},
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment