Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
8bc67132
Commit
8bc67132
authored
Jan 20, 2023
by
Paul
Browse files
Add fuse_reduce pass
parent
a5c87ec5
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
203 additions
and
1 deletion
+203
-1
src/CMakeLists.txt
src/CMakeLists.txt
+1
-0
src/fuse_reduce.cpp
src/fuse_reduce.cpp
+158
-0
src/include/migraphx/fuse_reduce.hpp
src/include/migraphx/fuse_reduce.hpp
+43
-0
src/include/migraphx/op/reduce_op.hpp
src/include/migraphx/op/reduce_op.hpp
+1
-1
No files found.
src/CMakeLists.txt
View file @
8bc67132
...
@@ -50,6 +50,7 @@ add_library(migraphx
...
@@ -50,6 +50,7 @@ add_library(migraphx
env.cpp
env.cpp
file_buffer.cpp
file_buffer.cpp
fuse_pointwise.cpp
fuse_pointwise.cpp
fuse_reduce.cpp
generate.cpp
generate.cpp
inline_module.cpp
inline_module.cpp
insert_pad.cpp
insert_pad.cpp
...
...
src/fuse_reduce.cpp
0 → 100644
View file @
8bc67132
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/fuse_reduce.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/matcher.hpp>
#include <iterator>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
fused_reduce
{
std
::
vector
<
std
::
int64_t
>
axes
{};
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
axes
,
"axes"
));
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
,
std
::
vector
<
module_ref
>
mods
)
const
{
if
(
mods
.
size
()
!=
1
)
{
MIGRAPHX_THROW
(
"should have one submodule."
);
}
auto
*
sm
=
mods
.
front
();
check_shapes
{
inputs
,
*
this
}.
has
(
sm
->
get_parameter_shapes
().
size
()).
same_dims
();
auto
s
=
inputs
.
at
(
0
);
auto
lens
=
s
.
lens
();
for
(
const
auto
&
axis
:
axes
)
{
lens
[
axis
]
=
1
;
}
if
(
sm
->
get_output_shapes
().
size
()
!=
1
)
MIGRAPHX_THROW
(
"Only one output supported"
);
return
inputs
[
0
].
with_lens
(
sm
->
get_output_shapes
().
front
().
type
(),
lens
);
}
std
::
string
name
()
const
{
return
"fused_reduce"
;
}
};
static
void
create_reduce_modules
(
module_pass_manager
&
mpm
)
{
std
::
size_t
n
=
0
;
for
(
auto
ins
:
iterator_for
(
mpm
.
get_module
()))
{
if
(
not
ins
->
get_operator
().
attributes
().
get
(
"reduce"
,
false
))
continue
;
if
(
ins
->
inputs
().
size
()
!=
1
)
continue
;
auto
*
rm
=
mpm
.
create_module
(
mpm
.
get_module
().
name
()
+
":"
+
ins
->
name
()
+
std
::
to_string
(
n
++
));
rm
->
set_bypass
();
// TODO: Ensure standard shape
auto
x0
=
rm
->
add_parameter
(
"x0"
,
ins
->
inputs
().
front
()
->
get_shape
());
auto
r
=
rm
->
add_instruction
(
ins
->
get_operator
(),
x0
);
rm
->
add_return
({
r
});
// TODO: Set axes
mpm
.
get_module
().
replace_instruction
(
ins
,
make_op
(
"fused_reduce"
),
ins
->
inputs
(),
{
rm
});
}
}
static
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
get_param_map
(
const
std
::
vector
<
instruction_ref
>&
inputs
,
const_module_ref
sm
)
{
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
result
;
auto
names
=
sm
->
get_parameter_names
();
std
::
sort
(
names
.
begin
(),
names
.
end
());
assert
(
names
.
size
()
==
inputs
.
size
());
std
::
transform
(
names
.
begin
(),
names
.
end
(),
inputs
.
begin
(),
std
::
inserter
(
result
,
result
.
end
()),
[
&
](
const
auto
&
name
,
auto
input
)
{
return
std
::
make_pair
(
input
,
sm
->
get_parameter
(
name
));
});
return
result
;
}
static
std
::
vector
<
instruction_ref
>
get_returns
(
module
&
m
)
{
auto
last
=
std
::
prev
(
m
.
end
());
if
(
last
->
name
()
==
"@return"
)
return
last
->
inputs
();
return
{
last
};
}
struct
find_reduce_pointwise
{
auto
matcher
()
const
{
return
match
::
name
(
"pointwise"
)(
match
::
any_of
[
match
::
inputs
()](
match
::
name
(
"fused_reduce"
)(
match
::
used_once
()).
bind
(
"reduce"
)));
}
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
reduce
=
r
.
instructions
[
"reduce"
];
auto
*
old_rm
=
reduce
->
module_inputs
().
front
();
auto
*
rm
=
mpm
.
create_module
(
old_rm
->
name
()
+
":pointwise"
);
// Copy module
*
rm
=
*
old_rm
;
auto
map_ins
=
get_param_map
(
reduce
->
inputs
(),
rm
);
auto
new_inputs
=
reduce
->
inputs
();
for
(
auto
input
:
ins
->
inputs
())
{
if
(
contains
(
map_ins
,
input
))
continue
;
if
(
input
==
reduce
)
{
map_ins
[
input
]
=
rm
->
}
map_ins
[
input
]
=
rm
->
add_parameter
(
"x"
+
std
::
to_string
(
new_inputs
.
size
()),
input
->
get_shape
());
new_inputs
.
push_back
(
input
);
}
auto
out
=
rm
->
insert_instructions
(
std
::
prev
(
rm
->
end
()),
{
ins
},
map_ins
);
rm
->
replace_return
(
out
);
mpm
.
get_module
().
replace_instruction
(
ins
,
reduce
->
get_operator
(),
new_inputs
,
{
rm
});
}
};
void
fuse_reduce
::
apply
(
module_pass_manager
&
mpm
)
const
{
create_reduce_modules
(
mpm
);
mpm
.
run_pass
(
dead_code_elimination
{});
match
::
find_matches
(
mpm
,
find_reduce_pointwise
{});
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/include/migraphx/fuse_reduce.hpp
0 → 100644
View file @
8bc67132
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_FUSE_REDUCE_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_FUSE_REDUCE_HPP
#include <migraphx/config.hpp>
#include <string>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
module_pass_manager
;
struct
fuse_reduce
{
std
::
string
name
()
const
{
return
"fuse_reduce"
;
}
void
apply
(
module_pass_manager
&
mpm
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_FUSE_POINTWISE_HPP
src/include/migraphx/op/reduce_op.hpp
View file @
8bc67132
...
@@ -91,7 +91,7 @@ struct reduce_op : op_name<Derived>
...
@@ -91,7 +91,7 @@ struct reduce_op : op_name<Derived>
{
{
value
normalize
;
value
normalize
;
normalize
[
"axes"
]
=
value
::
array
{
normalize_attribute
::
include_min
};
normalize
[
"axes"
]
=
value
::
array
{
normalize_attribute
::
include_min
};
return
{{
"normalize_axes"
,
normalize
}};
return
{{
"normalize_axes"
,
normalize
}
,
{
"reduce"
,
true
}
};
}
}
std
::
vector
<
int64_t
>
tune_axes
(
std
::
size_t
n_dim
)
const
std
::
vector
<
int64_t
>
tune_axes
(
std
::
size_t
n_dim
)
const
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment