Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
db5e6340
Commit
db5e6340
authored
Aug 15, 2023
by
Khalique Ahmed
Browse files
file cleanup
parent
5315d9bb
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
84 additions
and
178 deletions
+84
-178
src/eliminate_layout.cpp
src/eliminate_layout.cpp
+41
-71
src/include/migraphx/layout_nhwc.hpp
src/include/migraphx/layout_nhwc.hpp
+1
-2
src/layout_nhwc.cpp
src/layout_nhwc.cpp
+36
-98
src/targets/gpu/fuse_ops.cpp
src/targets/gpu/fuse_ops.cpp
+6
-5
src/targets/gpu/target.cpp
src/targets/gpu/target.cpp
+0
-2
No files found.
src/eliminate_layout.cpp
View file @
db5e6340
...
...
@@ -40,75 +40,51 @@
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
template
<
class
Predicate
>
std
::
vector
<
instruction_ref
>
find_lasts
(
const
module
&
m
,
Predicate
pred
)
{
std
::
vector
<
instruction_ref
>
result
;
fix
([
&
](
auto
self
,
auto
ins
)
{
if
(
pred
(
ins
))
{
result
.
push_back
(
ins
);
return
;
}
for
(
auto
input
:
ins
->
inputs
())
self
(
input
);
})(
std
::
prev
(
m
.
end
()));
return
result
;
}
std
::
unordered_set
<
instruction_ref
>
preserve_output_layout
(
module
&
m
)
{
std
::
unordered_set
<
instruction_ref
>
result
;
std
::
vector
<
instruction_ref
>
outputs
=
find_lasts
(
m
,
[](
auto
ins
)
{
return
ins
->
get_shape
().
lens
().
size
()
==
4
;
});
for
(
auto
output
:
outputs
)
{
auto
permutation
=
find_permutation
(
output
->
get_shape
());
// template <class Predicate>
// std::vector<instruction_ref> find_lasts(const module& m, Predicate pred)
// {
// std::vector<instruction_ref> result;
// fix([&](auto self, auto ins) {
// if(pred(ins))
// {
// result.push_back(ins);
// return;
// }
// for(auto input : ins->inputs())
// self(input);
// })(std::prev(m.end()));
// return result;
// }
auto
layout_ins
=
m
.
insert_instruction
(
std
::
next
(
output
),
make_op
(
"layout"
,
{{
"permutation"
,
permutation
}}),
output
);
// std::unordered_set<instruction_ref> preserve_output_layout(module& m)
// {
// std::unordered_set<instruction_ref> result;
// std::vector<instruction_ref> outputs =
// find_lasts(m, [](auto ins) { return ins->get_shape().lens().size() == 4; });
// for(auto output : outputs)
// {
// auto permutation = find_permutation(output->get_shape());
auto
output1
=
m
.
insert_instruction
(
layout_ins
,
make_op
(
"allocate"
,
{{
"shape"
,
to_value
(
layout_ins
->
get_shape
())}}));
std
::
vector
<
instruction_ref
>
refs
=
layout_ins
->
inputs
();
refs
.
push_back
(
output1
);
// auto layout_ins = m.insert_instruction(
// std::next(output), make_op("layout", {{"permutation", permutation}}), output);
auto
layout
=
m
.
replace_instruction
(
layout_ins
,
make_op
(
"gpu::precompile_op"
,
{{
"op"
,
to_value
(
layout_ins
->
get_operator
())}}),
refs
,
layout_ins
->
module_inputs
());
// auto output1 = m.insert_instruction(
// layout_ins, make_op("allocate", {{"shape", to_value(layout_ins->get_shape())}}));
// std::vector<instruction_ref> refs = layout_ins->inputs();
// refs.push_back(output1);
result
.
insert
(
layout
);
// m.debug_print(
layout
);
}
return
result
;
}
//
auto layout = m.replace_instruction(
//
layout
_ins,
//
make_op("gpu::precompile_op", {{"op", to_value(layout_ins->get_operator())}}),
//
refs,
// layout_ins->module_inputs());
void
remove_layout
(
module
&
m
)
{
for
(
auto
ins
:
iterator_for
(
m
))
{
if
(
ins
->
name
()
!=
"layout"
)
continue
;
auto
in_shape
=
ins
->
inputs
().
front
()
->
get_shape
();
if
(
in_shape
==
ins
->
get_shape
())
m
.
replace_instruction
(
ins
,
ins
->
inputs
().
front
());
}
}
// std::vector<instruction_ref> find_convs(const module& m)
// {
// std::vector<instruction_ref> convs;
// for(auto ins : iterator_for(m))
// {
// if(ins->name() == "gpu::miopen_op")
// convs.push_back(ins);
// result.insert(layout);
// }
// return
convs
;
// return
result
;
// }
void
remove_layout
(
module
&
m
,
const
std
::
unordered_set
<
instruction_ref
>&
output_layouts
)
void
remove_layout
(
module
&
m
)
{
for
(
auto
ins
:
iterator_for
(
m
))
{
...
...
@@ -120,18 +96,14 @@ void remove_layout(module& m, const std::unordered_set<instruction_ref>& output_
if
(
val
[
"op"
].
at
(
"name"
).
to
<
std
::
string
>
()
!=
"layout"
)
{
// std::cout << val["op"].at("name").to<std::string>() << std::endl;
continue
;
}
// m.debug_print(ins);
if
(
ins
->
get_shape
()
!=
ins
->
inputs
().
front
()
->
get_shape
())
{
// std::cout << ins->get_shape() << " " << ins->inputs().front()->get_shape() <<
// std::endl;
continue
;
}
if
(
contains
(
output_layouts
,
ins
))
continue
;
//
if(contains(output_layouts, ins))
//
continue;
m
.
replace_instruction
(
ins
,
ins
->
inputs
().
front
());
}
...
...
@@ -139,10 +111,8 @@ void remove_layout(module& m, const std::unordered_set<instruction_ref>& output_
void
eliminate_layout
::
apply
(
module_pass_manager
&
mpm
)
const
{
std
::
unordered_set
<
instruction_ref
>
output_layouts
=
preserve_output_layout
(
mpm
.
get_module
());
remove_layout
(
mpm
.
get_module
(),
output_layouts
);
// find_convs(mpm.get_module()));
// remove_layout(mpm.get_module());
// std::unordered_set<instruction_ref> output_layouts = preserve_output_layout(mpm.get_module());
remove_layout
(
mpm
.
get_module
());
mpm
.
run_pass
(
dead_code_elimination
{});
}
...
...
src/include/migraphx/layout_nhwc.hpp
100755 → 100644
View file @
db5e6340
...
...
@@ -38,9 +38,8 @@ struct module_pass_manager;
*/
struct
MIGRAPHX_EXPORT
layout_nhwc
{
bool
skip_elim_contiguous
=
false
;
std
::
string
name
()
const
{
return
"layout_nhwc"
;
}
void
apply
(
module_pass_manager
&
m
)
const
;
void
apply
(
module_pass_manager
&
mp
m
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/layout_nhwc.cpp
View file @
db5e6340
...
...
@@ -30,47 +30,43 @@
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/auto_contiguous.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <stdexcept>
#include <system_error>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
//
template <class Predicate>
//
std::vector<instruction_ref> find_lasts(const module& m, Predicate pred)
//
{
//
std::vector<instruction_ref> result;
//
fix([&](auto self, auto ins) {
//
if(pred(ins))
//
{
//
result.push_back(ins);
//
return;
//
}
//
for(auto input : ins->inputs())
//
self(input);
//
})(std::prev(m.end()));
//
return result;
//
}
template
<
class
Predicate
>
std
::
vector
<
instruction_ref
>
find_lasts
(
const
module
&
m
,
Predicate
pred
)
{
std
::
vector
<
instruction_ref
>
result
;
fix
([
&
](
auto
self
,
auto
ins
)
{
if
(
pred
(
ins
))
{
result
.
push_back
(
ins
);
return
;
}
for
(
auto
input
:
ins
->
inputs
())
self
(
input
);
})(
std
::
prev
(
m
.
end
()));
return
result
;
}
// std::unordered_set<instruction_ref> preserve_output_layout(module& m)
// {
// std::unordered_set<instruction_ref> result;
// std::vector<instruction_ref> outputs =
// find_lasts(m, [](auto ins) { return ins->get_shape().lens().size() == 4; });
// for(auto output : outputs)
// {
// auto permutation = find_permutation(output->get_shape());
// auto layout = m.insert_instruction(
// std::next(output), make_op("layout", {{"permutation", permutation}}), output);
// result.insert(m.replace_instruction(output, layout));
// }
// return result;
// }
void
preserve_output_layout
(
module
&
m
)
{
std
::
vector
<
instruction_ref
>
outputs
=
find_lasts
(
m
,
[](
auto
ins
)
{
return
ins
->
name
()
==
"convolution"
and
ins
->
get_shape
().
lens
().
size
()
==
4
;
});
for
(
auto
output
:
outputs
)
{
auto
permutation
=
find_permutation
(
output
->
get_shape
());
auto
layout
=
m
.
insert_instruction
(
std
::
next
(
output
),
make_op
(
"layout"
,
{{
"permutation"
,
permutation
}}),
output
);
m
.
replace_instruction
(
output
,
layout
);
}
}
void
transform_convolutions
(
module
&
m
,
bool
skip_elim_contiguous
)
void
transform_convolutions
(
module
&
m
)
{
for
(
auto
ins
:
iterator_for
(
m
))
{
...
...
@@ -82,79 +78,21 @@ void transform_convolutions(module& m, bool skip_elim_contiguous)
if
(
v
.
at
(
"group"
).
to
<
int
>
()
>
1
)
continue
;
auto
args
=
ins
->
inputs
();
if
(
skip_elim_contiguous
)
{
for
(
auto
i
=
0
;
i
<
args
.
size
();
i
++
)
{
if
(
args
[
i
]
->
name
()
!=
"layout"
and
args
[
i
]
->
get_shape
().
standard
())
{
args
[
i
]
=
m
.
insert_instruction
(
ins
,
make_op
(
"layout"
,
{{
"permutation"
,
{
0
,
2
,
3
,
1
}}}),
args
[
i
]);
}
}
}
else
std
::
transform
(
args
.
begin
(),
args
.
end
(),
args
.
begin
(),
[
&
](
auto
&
i
)
{
return
m
.
insert_instruction
(
ins
,
make_op
(
"layout"
,
{{
"permutation"
,
{
0
,
2
,
3
,
1
}}}),
i
);
});
std
::
transform
(
args
.
begin
(),
args
.
end
(),
args
.
begin
(),
[
&
](
const
auto
&
i
)
{
return
m
.
insert_instruction
(
ins
,
make_op
(
"layout"
,
{{
"permutation"
,
{
0
,
2
,
3
,
1
}}}),
i
);
});
auto
conv
=
m
.
insert_instruction
(
ins
,
ins
->
get_operator
(),
args
);
// m.debug_print(conv);
// auto c = conv;
// auto nchw = m.insert_instruction(ins, make_op("layout", {{"permutation", {0, 1, 2, 3}}}),
// conv); m.debug_print(); if(not skip_elim_contiguous)
auto
c
=
m
.
insert_instruction
(
ins
,
make_op
(
"contiguous"
),
conv
);
auto
c
=
m
.
insert_instruction
(
ins
,
make_op
(
"contiguous"
),
conv
);
m
.
replace_instruction
(
ins
,
c
);
}
}
void
insert_contiguous
(
module
&
m
)
{
for
(
auto
ins
:
iterator_for
(
m
))
{
if
(
ins
->
name
()
!=
"reshape"
and
ins
->
name
()
!=
"pooling"
)
continue
;
auto
c
=
m
.
insert_instruction
(
ins
,
make_op
(
"contiguous"
),
ins
->
inputs
().
front
());
auto
reshape
=
m
.
insert_instruction
(
ins
,
ins
->
get_operator
(),
c
);
m
.
replace_instruction
(
ins
,
reshape
);
}
// m.debug_print();
}
// void remove_layout(module& m, const std::unordered_set<instruction_ref>& output_layouts)
// {
// for(auto ins : iterator_for(m))
// {
// if(ins->name() != "layout")
// continue;
// if(ins->get_shape() != ins->inputs().front()->get_shape())
// continue;
// if(contains(output_layouts, ins))
// continue;
// m.replace_instruction(ins, ins->inputs().front());
// }
// }
void
layout_nhwc
::
apply
(
module_pass_manager
&
mpm
)
const
{
// std::unordered_set<instruction_ref> output_layouts =
// preserve_output_layout(mpm.get_module());
// insert_contiguous(mpm.get_module());
mpm
.
run_pass
(
dead_code_elimination
{});
// mpm.get_module().debug_print();
transform_convolutions
(
mpm
.
get_module
(),
this
->
skip_elim_contiguous
);
module
&
m
=
mpm
.
get_module
();
// preserve_output_layout(m);
transform_convolutions
(
m
);
mpm
.
run_pass
(
dead_code_elimination
{});
// std::cout << "after layout" << std::endl;
// mpm.get_module().debug_print();
// if(not this->skip_elim_contiguous)
// mpm.run_pass(eliminate_contiguous{"contiguous"});
// mpm.run_pass(dead_code_elimination{});
// mpm.run_pass(auto_contiguous{});
// mpm.run_pass(dead_code_elimination{});
// remove_layout(mpm.get_module(), output_layouts);
// mpm.run_pass(dead_code_elimination{});
}
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/targets/gpu/fuse_ops.cpp
View file @
db5e6340
...
...
@@ -782,6 +782,8 @@ struct find_contiguous_pointwise
auto
args
=
pw
->
inputs
();
args
.
back
()
=
alloc
;
if
(
ins
->
get_shape
()
!=
pw
->
get_shape
())
return
;
m
.
replace_instruction
(
ins
,
pw
->
get_operator
(),
args
,
pw
->
module_inputs
());
}
};
...
...
@@ -835,24 +837,23 @@ struct find_concat_pointwise
auto
op
=
concat
->
get_operator
();
op
.
from_value
({{
"additional_args"
,
ins
->
inputs
().
size
()
-
1
},
{
"ignore_modules"
,
true
}});
m
.
replace_instruction
(
ins
,
op
,
inputs
,
{
pm
});
}
};
void
fuse_ops
::
apply
(
module
&
m
)
const
{
//
match::find_matches(m, find_contiguous_pointwise{});
match
::
find_matches
(
m
,
find_contiguous_pointwise
{});
run_passes
(
m
,
{
dead_code_elimination
{}});
match
::
find_matches
(
m
,
find_conv_pointwise
{
ctx
},
find_conv_bias_relu
{
ctx
},
find_conv_bias
{
ctx
});
run_passes
(
m
,
{
dead_code_elimination
{}});
match
::
find_matches
(
m
,
find_layernorm_pointwise
{},
find_concat_pointwise
{},
//
find_concat_pointwise{},
find_gemm_pointwise
{},
find_contiguous_tranpose_gemm
{},
find_commutative_broadcast
{});
//
match::find_matches(m, find_contiguous{});
find_commutative_broadcast
{});
match
::
find_matches
(
m
,
find_contiguous
{});
}
}
// namespace gpu
...
...
src/targets/gpu/target.cpp
View file @
db5e6340
...
...
@@ -131,7 +131,6 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
optimize_module
{},
enable_pass
(
enabled
(
MIGRAPHX_ENABLE_NHWC
{}),
layout_nhwc
{}),
dead_code_elimination
{},
enable_pass
(
enabled
(
MIGRAPHX_ENABLE_NHWC
{}),
eliminate_layout
{}),
prefuse_ops
{},
dead_code_elimination
{},
optimize_module
{},
...
...
@@ -150,7 +149,6 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
eliminate_contiguous
{
"gpu::contiguous"
},
dead_code_elimination
{},
enable_pass
(
enabled
(
MIGRAPHX_ENABLE_NHWC
{}),
eliminate_layout
{}),
// dead_code_elimination{},
eliminate_concat
{
concat_gpu_optimization
{}},
dead_code_elimination
{},
compile_miopen
{
&
gctx
},
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment