Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
7ed60279
Commit
7ed60279
authored
Mar 29, 2023
by
Khalique Ahmed
Browse files
testing
parent
30c49503
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
49 additions
and
46 deletions
+49
-46
src/CMakeLists.txt
src/CMakeLists.txt
+1
-0
src/layout_nhwc.cpp
src/layout_nhwc.cpp
+45
-45
src/targets/gpu/target.cpp
src/targets/gpu/target.cpp
+3
-1
No files found.
src/CMakeLists.txt
View file @
7ed60279
...
@@ -46,6 +46,7 @@ add_library(migraphx
...
@@ -46,6 +46,7 @@ add_library(migraphx
eliminate_contiguous.cpp
eliminate_contiguous.cpp
eliminate_data_type.cpp
eliminate_data_type.cpp
eliminate_identity.cpp
eliminate_identity.cpp
eliminate_layout.cpp
eliminate_pad.cpp
eliminate_pad.cpp
env.cpp
env.cpp
file_buffer.cpp
file_buffer.cpp
...
...
src/layout_nhwc.cpp
View file @
7ed60279
...
@@ -36,36 +36,36 @@
...
@@ -36,36 +36,36 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
template
<
class
Predicate
>
//
template <class Predicate>
std
::
vector
<
instruction_ref
>
find_lasts
(
const
module
&
m
,
Predicate
pred
)
//
std::vector<instruction_ref> find_lasts(const module& m, Predicate pred)
{
//
{
std
::
vector
<
instruction_ref
>
result
;
//
std::vector<instruction_ref> result;
fix
([
&
](
auto
self
,
auto
ins
)
{
//
fix([&](auto self, auto ins) {
if
(
pred
(
ins
))
//
if(pred(ins))
{
//
{
result
.
push_back
(
ins
);
//
result.push_back(ins);
return
;
//
return;
}
//
}
for
(
auto
input
:
ins
->
inputs
())
//
for(auto input : ins->inputs())
self
(
input
);
//
self(input);
})(
std
::
prev
(
m
.
end
()));
//
})(std::prev(m.end()));
return
result
;
//
return result;
}
//
}
std
::
unordered_set
<
instruction_ref
>
preserve_output_layout
(
module
&
m
)
//
std::unordered_set<instruction_ref> preserve_output_layout(module& m)
{
//
{
std
::
unordered_set
<
instruction_ref
>
result
;
//
std::unordered_set<instruction_ref> result;
std
::
vector
<
instruction_ref
>
outputs
=
//
std::vector<instruction_ref> outputs =
find_lasts
(
m
,
[](
auto
ins
)
{
return
ins
->
get_shape
().
lens
().
size
()
==
4
;
});
//
find_lasts(m, [](auto ins) { return ins->get_shape().lens().size() == 4; });
for
(
auto
output
:
outputs
)
//
for(auto output : outputs)
{
//
{
auto
permutation
=
find_permutation
(
output
->
get_shape
());
//
auto permutation = find_permutation(output->get_shape());
auto
layout
=
m
.
insert_instruction
(
//
auto layout = m.insert_instruction(
std
::
next
(
output
),
make_op
(
"layout"
,
{{
"permutation"
,
permutation
}}),
output
);
//
std::next(output), make_op("layout", {{"permutation", permutation}}), output);
result
.
insert
(
m
.
replace_instruction
(
output
,
layout
));
//
result.insert(m.replace_instruction(output, layout));
}
//
}
return
result
;
//
return result;
}
//
}
void
transform_convolutions
(
module
&
m
,
bool
skip_elim_contiguous
)
void
transform_convolutions
(
module
&
m
,
bool
skip_elim_contiguous
)
{
{
...
@@ -108,30 +108,30 @@ void transform_convolutions(module& m, bool skip_elim_contiguous)
...
@@ -108,30 +108,30 @@ void transform_convolutions(module& m, bool skip_elim_contiguous)
}
}
}
}
void
remove_layout
(
module
&
m
,
const
std
::
unordered_set
<
instruction_ref
>&
output_layouts
)
//
void remove_layout(module& m, const std::unordered_set<instruction_ref>& output_layouts)
{
//
{
for
(
auto
ins
:
iterator_for
(
m
))
//
for(auto ins : iterator_for(m))
{
//
{
if
(
ins
->
name
()
!=
"layout"
)
//
if(ins->name() != "layout")
continue
;
//
continue;
if
(
ins
->
get_shape
()
!=
ins
->
inputs
().
front
()
->
get_shape
())
//
if(ins->get_shape() != ins->inputs().front()->get_shape())
continue
;
//
continue;
if
(
contains
(
output_layouts
,
ins
))
//
if(contains(output_layouts, ins))
continue
;
//
continue;
m
.
replace_instruction
(
ins
,
ins
->
inputs
().
front
());
//
m.replace_instruction(ins, ins->inputs().front());
}
//
}
}
//
}
void
layout_nhwc
::
apply
(
module_pass_manager
&
mpm
)
const
void
layout_nhwc
::
apply
(
module_pass_manager
&
mpm
)
const
{
{
std
::
unordered_set
<
instruction_ref
>
output_layouts
=
preserve_output_layout
(
mpm
.
get_module
());
//
std::unordered_set<instruction_ref> output_layouts = preserve_output_layout(mpm.get_module());
transform_convolutions
(
mpm
.
get_module
(),
this
->
skip_elim_contiguous
);
transform_convolutions
(
mpm
.
get_module
(),
this
->
skip_elim_contiguous
);
mpm
.
run_pass
(
dead_code_elimination
{});
mpm
.
run_pass
(
dead_code_elimination
{});
if
(
not
this
->
skip_elim_contiguous
)
if
(
not
this
->
skip_elim_contiguous
)
mpm
.
run_pass
(
eliminate_contiguous
{
"contiguous"
});
mpm
.
run_pass
(
eliminate_contiguous
{
"contiguous"
});
mpm
.
run_pass
(
dead_code_elimination
{});
mpm
.
run_pass
(
dead_code_elimination
{});
remove_layout
(
mpm
.
get_module
(),
output_layouts
);
//
remove_layout(mpm.get_module(), output_layouts);
mpm
.
run_pass
(
dead_code_elimination
{});
//
mpm.run_pass(dead_code_elimination{});
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/targets/gpu/target.cpp
View file @
7ed60279
...
@@ -31,6 +31,7 @@
...
@@ -31,6 +31,7 @@
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/eliminate_data_type.hpp>
#include <migraphx/eliminate_data_type.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/eliminate_layout.hpp>
#include <migraphx/eliminate_pad.hpp>
#include <migraphx/eliminate_pad.hpp>
#include <migraphx/fuse_pointwise.hpp>
#include <migraphx/fuse_pointwise.hpp>
#include <migraphx/inline_module.hpp>
#include <migraphx/inline_module.hpp>
...
@@ -125,7 +126,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
...
@@ -125,7 +126,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
prefuse_ops
{},
prefuse_ops
{},
dead_code_elimination
{},
dead_code_elimination
{},
auto_contiguous
{},
auto_contiguous
{},
enable_pass
(
enabled
(
MIGRAPHX_ENABLE_NHWC
{}),
layout_nhwc
{
true
}),
//
enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), layout_nhwc{true}),
simplify_reshapes
{},
simplify_reshapes
{},
propagate_constant
{},
propagate_constant
{},
dead_code_elimination
{},
dead_code_elimination
{},
...
@@ -134,6 +135,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
...
@@ -134,6 +135,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
fuse_mlir
{
&
ctx
},
fuse_mlir
{
&
ctx
},
dead_code_elimination
{},
dead_code_elimination
{},
lowering
{
&
ctx
,
options
.
offload_copy
},
lowering
{
&
ctx
,
options
.
offload_copy
},
enable_pass
(
enabled
(
MIGRAPHX_ENABLE_NHWC
{}),
eliminate_layout
{}),
eliminate_contiguous
{
"gpu::contiguous"
},
eliminate_contiguous
{
"gpu::contiguous"
},
dead_code_elimination
{},
dead_code_elimination
{},
eliminate_concat
{
concat_gpu_optimization
{}},
eliminate_concat
{
concat_gpu_optimization
{}},
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment