Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
d1dc3e35
Commit
d1dc3e35
authored
Nov 14, 2023
by
Paul
Browse files
Add fuse_concat pass
parent
8ba66ebc
Changes
14
Show whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
392 additions
and
13 deletions
+392
-13
src/CMakeLists.txt
src/CMakeLists.txt
+1
-0
src/fuse_concat.cpp
src/fuse_concat.cpp
+122
-0
src/include/migraphx/fuse_concat.hpp
src/include/migraphx/fuse_concat.hpp
+20
-0
src/include/migraphx/module.hpp
src/include/migraphx/module.hpp
+2
-0
src/include/migraphx/pass_manager.hpp
src/include/migraphx/pass_manager.hpp
+1
-0
src/include/migraphx/program.hpp
src/include/migraphx/program.hpp
+1
-0
src/include/migraphx/stringutils.hpp
src/include/migraphx/stringutils.hpp
+12
-0
src/module.cpp
src/module.cpp
+5
-12
src/pass_manager.cpp
src/pass_manager.cpp
+6
-0
src/program.cpp
src/program.cpp
+7
-0
src/targets/gpu/jit/concat.cpp
src/targets/gpu/jit/concat.cpp
+105
-0
src/targets/gpu/kernels/include/migraphx/kernels/concat.hpp
src/targets/gpu/kernels/include/migraphx/kernels/concat.hpp
+2
-1
src/targets/gpu/target.cpp
src/targets/gpu/target.cpp
+3
-0
test/fuse_concat.cpp
test/fuse_concat.cpp
+105
-0
No files found.
src/CMakeLists.txt
View file @
d1dc3e35
...
@@ -53,6 +53,7 @@ add_library(migraphx
...
@@ -53,6 +53,7 @@ add_library(migraphx
eliminate_pad.cpp
eliminate_pad.cpp
env.cpp
env.cpp
file_buffer.cpp
file_buffer.cpp
fuse_concat.cpp
fuse_pointwise.cpp
fuse_pointwise.cpp
fuse_reduce.cpp
fuse_reduce.cpp
generate.cpp
generate.cpp
...
...
src/fuse_concat.cpp
0 → 100644
View file @
d1dc3e35
#include <migraphx/fuse_concat.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/module.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/algorithm.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
fused_concat
{
int64_t
axis
=
0
;
std
::
string
name
()
const
{
return
"fused_concat"
;
}
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
axis
,
"axis"
));
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
,
const
std
::
vector
<
module_ref
>&
mods
)
const
{
check_shapes
{
inputs
,
*
this
}.
same_ndims
();
if
((
inputs
.
size
()
+
1
)
==
mods
.
size
())
MIGRAPHX_THROW
(
"FUSED_CONCAT: Missing fused modules"
);
auto
input_iter
=
inputs
.
begin
();
std
::
vector
<
shape
>
concat_inputs
;
for
(
module_ref
mod
:
range
(
mods
.
begin
(),
mods
.
end
()
-
1
))
{
concat_inputs
.
push_back
(
*
input_iter
);
input_iter
+=
mod
->
get_parameter_names
().
size
();
}
module_ref
post_mod
=
mods
.
back
();
auto
type
=
std
::
prev
(
post_mod
->
end
())
->
get_shape
().
type
();
const
auto
&
first_shape_lens
=
concat_inputs
.
front
().
lens
();
if
(
not
std
::
all_of
(
concat_inputs
.
begin
()
+
1
,
concat_inputs
.
end
(),
[
&
](
auto
s
)
{
const
auto
&
lens
=
s
.
lens
();
return
std
::
equal
(
lens
.
begin
(),
lens
.
begin
()
+
axis
,
first_shape_lens
.
begin
(),
first_shape_lens
.
begin
()
+
axis
)
and
std
::
equal
(
lens
.
begin
()
+
axis
+
1
,
lens
.
end
(),
first_shape_lens
.
begin
()
+
axis
+
1
,
first_shape_lens
.
end
());
}))
MIGRAPHX_THROW
(
"FUSED_CONCAT: all input dimensions should match along non-axis: "
+
std
::
to_string
(
axis
));
std
::
size_t
new_dim_axis
=
transform_accumulate
(
concat_inputs
.
begin
(),
concat_inputs
.
end
(),
0
,
std
::
plus
<>
{},
[
&
](
const
auto
&
input
)
{
return
input
.
lens
()[
axis
];
});
auto
new_lens
=
concat_inputs
.
front
().
lens
();
new_lens
[
axis
]
=
new_dim_axis
;
return
shape
::
from_permutation
(
type
,
new_lens
,
find_permutation
(
inputs
));
}
};
MIGRAPHX_REGISTER_OP
(
fused_concat
);
namespace
{
static
unsigned
int
counter
=
0
;
struct
find_pointwise_concat_pointwise
{
auto
matcher
()
const
{
auto
concat
=
match
::
name
(
"concat"
)(
match
::
used_once
(),
match
::
any_of
[
match
::
inputs
()](
match
::
name
(
"pointwise"
)(
match
::
used_once
())));
return
match
::
name
(
"pointwise"
)(
match
::
any_of
[
match
::
inputs
()](
concat
.
bind
(
"concat"
)));
}
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
concat_ins
=
r
.
instructions
[
"concat"
];
auto
concat_arg
=
std
::
find
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
concat_ins
)
-
ins
->
inputs
().
begin
();
std
::
vector
<
instruction_ref
>
inputs
;
for
(
auto
input
:
concat_ins
->
inputs
())
inputs
.
insert
(
inputs
.
end
(),
input
->
inputs
().
begin
(),
input
->
inputs
().
end
());
std
::
copy_if
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
std
::
back_inserter
(
inputs
),
[
&
](
auto
input
)
{
return
input
!=
concat_ins
;
});
std
::
vector
<
module_ref
>
module_inputs
;
std
::
transform
(
concat_ins
->
inputs
().
begin
(),
concat_ins
->
inputs
().
end
(),
std
::
back_inserter
(
module_inputs
),
[
&
](
instruction_ref
input
)
{
if
(
input
->
name
()
==
"pointwise"
)
{
auto
*
pm
=
input
->
module_inputs
().
front
();
return
mpm
.
create_module
(
"concat:"
+
pm
->
name
(),
*
pm
);
}
auto
*
pm
=
mpm
.
create_module
(
"concat"
+
std
::
to_string
(
counter
++
));
auto
x
=
pm
->
add_parameter
(
"x"
,
shape
{
input
->
get_shape
().
type
()});
auto
id
=
pm
->
add_instruction
(
make_op
(
"identity"
),
x
);
pm
->
add_return
({
id
});
return
pm
;
});
auto
*
post_pm
=
ins
->
module_inputs
().
front
();
auto
*
rm
=
mpm
.
create_module
(
post_pm
->
name
()
+
":concat"
,
*
post_pm
);
std
::
vector
<
std
::
string
>
names
=
rm
->
get_parameter_names
();
std
::
sort
(
names
.
begin
(),
names
.
end
());
auto
concat_param_name
=
names
[
concat_arg
];
auto
concat_param
=
rm
->
get_parameter
(
concat_param_name
);
auto
param
=
rm
->
add_parameter
(
"!"
+
concat_param_name
,
concat_param
->
get_shape
());
rm
->
replace_instruction
(
concat_param
,
param
);
rm
->
remove_instruction
(
concat_param
);
module_inputs
.
push_back
(
rm
);
mpm
.
get_module
().
replace_instruction
(
ins
,
make_op
(
"fused_concat"
,
concat_ins
->
normalized_operator
().
to_value
()),
inputs
,
module_inputs
);
}
};
}
void
fuse_concat
::
apply
(
module_pass_manager
&
mpm
)
const
{
match
::
find_matches
(
mpm
,
find_pointwise_concat_pointwise
{});
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/include/migraphx/fuse_concat.hpp
0 → 100644
View file @
d1dc3e35
#ifndef MIGRAPHX_GUARD_MIGRAPHX_FUSE_CONCAT_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_FUSE_CONCAT_HPP
#include <migraphx/config.hpp>
#include <string>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
module_pass_manager
;
struct
MIGRAPHX_EXPORT
fuse_concat
{
std
::
string
name
()
const
{
return
"fuse_concat"
;
}
void
apply
(
module_pass_manager
&
mpm
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_FUSE_CONCAT_HPP
src/include/migraphx/module.hpp
View file @
d1dc3e35
...
@@ -239,7 +239,9 @@ struct MIGRAPHX_EXPORT module
...
@@ -239,7 +239,9 @@ struct MIGRAPHX_EXPORT module
MIGRAPHX_EXPORT
friend
bool
operator
==
(
const
module
&
x
,
const
module
&
y
);
MIGRAPHX_EXPORT
friend
bool
operator
==
(
const
module
&
x
,
const
module
&
y
);
friend
bool
operator
!=
(
const
module
&
x
,
const
module
&
y
)
{
return
not
(
x
==
y
);
}
friend
bool
operator
!=
(
const
module
&
x
,
const
module
&
y
)
{
return
not
(
x
==
y
);
}
friend
struct
program
;
private:
private:
void
set_name
(
const
std
::
string
&
name
);
void
assign
(
const
module
&
m
);
void
assign
(
const
module
&
m
);
void
calc_implicit_deps
(
const
module
&
smod
,
void
calc_implicit_deps
(
const
module
&
smod
,
const
module
&
pmod
,
const
module
&
pmod
,
...
...
src/include/migraphx/pass_manager.hpp
View file @
d1dc3e35
...
@@ -39,6 +39,7 @@ struct module_pass_manager
...
@@ -39,6 +39,7 @@ struct module_pass_manager
module_pass_manager
(
const
module_pass_manager
&
)
=
delete
;
module_pass_manager
(
const
module_pass_manager
&
)
=
delete
;
virtual
module
&
get_module
()
=
0
;
virtual
module
&
get_module
()
=
0
;
virtual
module
*
create_module
(
const
std
::
string
&
name
)
=
0
;
virtual
module
*
create_module
(
const
std
::
string
&
name
)
=
0
;
virtual
module
*
create_module
(
const
std
::
string
&
name
,
const
module
&
m
)
=
0
;
virtual
module
*
get_common_parent
()
=
0
;
virtual
module
*
get_common_parent
()
=
0
;
virtual
module
*
get_root_module
()
=
0
;
virtual
module
*
get_root_module
()
=
0
;
virtual
void
run_pass
(
const
pass
&
p
)
=
0
;
virtual
void
run_pass
(
const
pass
&
p
)
=
0
;
...
...
src/include/migraphx/program.hpp
View file @
d1dc3e35
...
@@ -136,6 +136,7 @@ struct MIGRAPHX_EXPORT program
...
@@ -136,6 +136,7 @@ struct MIGRAPHX_EXPORT program
// module related api
// module related api
module
*
create_module
(
const
std
::
string
&
name
);
module
*
create_module
(
const
std
::
string
&
name
);
module
*
create_module
(
const
std
::
string
&
name
,
module
m
);
module
*
get_module
(
const
std
::
string
&
name
);
module
*
get_module
(
const
std
::
string
&
name
);
const
module
*
get_module
(
const
std
::
string
&
name
)
const
;
const
module
*
get_module
(
const
std
::
string
&
name
)
const
;
...
...
src/include/migraphx/stringutils.hpp
View file @
d1dc3e35
...
@@ -177,6 +177,18 @@ inline std::string interpolate_string(const std::string& input,
...
@@ -177,6 +177,18 @@ inline std::string interpolate_string(const std::string& input,
std
::
move
(
end
));
std
::
move
(
end
));
}
}
inline
std
::
string
to_c_id
(
const
std
::
string
&
name
,
char
rep
=
'_'
)
{
std
::
string
id
=
transform_string
(
name
,
[
&
](
auto
c
)
{
if
(
with_char
(
::
isalnum
)(
c
)
or
c
==
'_'
)
return
c
;
return
rep
;
});
while
(
id
.
find
(
"__"
)
!=
std
::
string
::
npos
)
replace_string_inplace
(
id
,
"__"
,
"_"
);
return
id
;
}
template
<
class
Iterator
>
template
<
class
Iterator
>
inline
std
::
string
to_string_range
(
Iterator
start
,
Iterator
last
,
const
char
*
delim
=
", "
)
inline
std
::
string
to_string_range
(
Iterator
start
,
Iterator
last
,
const
char
*
delim
=
", "
)
{
{
...
...
src/module.cpp
View file @
d1dc3e35
...
@@ -134,6 +134,11 @@ module& module::operator=(module m)
...
@@ -134,6 +134,11 @@ module& module::operator=(module m)
std
::
string
module
::
name
()
const
{
return
impl
->
name
;
}
std
::
string
module
::
name
()
const
{
return
impl
->
name
;
}
void
module
::
set_name
(
const
std
::
string
&
name
)
{
impl
->
name
=
name
;
}
bool
module
::
bypass
()
const
{
return
impl
->
bypass
;
}
bool
module
::
bypass
()
const
{
return
impl
->
bypass
;
}
void
module
::
set_bypass
(
bool
b
)
{
impl
->
bypass
=
b
;
}
void
module
::
set_bypass
(
bool
b
)
{
impl
->
bypass
=
b
;
}
...
@@ -784,18 +789,6 @@ void module::print_graph(std::ostream& os, bool brief) const
...
@@ -784,18 +789,6 @@ void module::print_graph(std::ostream& os, bool brief) const
os
<<
"}"
<<
std
::
endl
;
os
<<
"}"
<<
std
::
endl
;
}
}
static
std
::
string
to_c_id
(
const
std
::
string
&
name
,
char
rep
=
'_'
)
{
std
::
string
id
=
transform_string
(
name
,
[
&
](
auto
c
)
{
if
(
with_char
(
::
isalnum
)(
c
)
or
c
==
'_'
)
return
c
;
return
rep
;
});
while
(
contains
(
id
,
"__"
))
replace_string_inplace
(
id
,
"__"
,
"_"
);
return
id
;
}
static
std
::
string
cpp_var_name
(
const
std
::
string
&
name
)
static
std
::
string
cpp_var_name
(
const
std
::
string
&
name
)
{
{
std
::
string
prefix
=
"x_"
;
std
::
string
prefix
=
"x_"
;
...
...
src/pass_manager.cpp
View file @
d1dc3e35
...
@@ -99,6 +99,12 @@ struct module_pm : module_pass_manager
...
@@ -99,6 +99,12 @@ struct module_pm : module_pass_manager
return
prog
->
create_module
(
name
);
return
prog
->
create_module
(
name
);
}
}
virtual
module
*
create_module
(
const
std
::
string
&
name
,
const
module
&
m
)
override
{
assert
(
prog
);
return
prog
->
create_module
(
name
,
m
);
}
virtual
module
*
get_common_parent
()
override
{
return
common_parent
;
}
virtual
module
*
get_common_parent
()
override
{
return
common_parent
;
}
virtual
module
*
get_root_module
()
override
virtual
module
*
get_root_module
()
override
...
...
src/program.cpp
View file @
d1dc3e35
...
@@ -1064,6 +1064,13 @@ module* program::create_module(const std::string& name)
...
@@ -1064,6 +1064,13 @@ module* program::create_module(const std::string& name)
auto
r
=
impl
->
modules
.
emplace
(
name
,
name
);
auto
r
=
impl
->
modules
.
emplace
(
name
,
name
);
return
&
(
r
.
first
->
second
);
return
&
(
r
.
first
->
second
);
}
}
module
*
program
::
create_module
(
const
std
::
string
&
name
,
module
m
)
{
assert
(
not
contains
(
impl
->
modules
,
name
));
m
.
set_name
(
name
);
auto
r
=
impl
->
modules
.
emplace
(
name
,
std
::
move
(
m
));
return
&
(
r
.
first
->
second
);
}
module
*
program
::
get_module
(
const
std
::
string
&
name
)
{
return
&
impl
->
modules
.
at
(
name
);
}
module
*
program
::
get_module
(
const
std
::
string
&
name
)
{
return
&
impl
->
modules
.
at
(
name
);
}
...
...
src/targets/gpu/jit/concat.cpp
View file @
d1dc3e35
...
@@ -27,6 +27,7 @@
...
@@ -27,6 +27,7 @@
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/algorithm.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -60,6 +61,7 @@ MIGRAPHX_GLOBAL void ${kernel}(${params})
...
@@ -60,6 +61,7 @@ MIGRAPHX_GLOBAL void ${kernel}(${params})
)__migraphx__"
;
)__migraphx__"
;
struct
concat_compiler
:
compiler
<
concat_compiler
>
struct
concat_compiler
:
compiler
<
concat_compiler
>
{
{
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"concat"
};
}
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"concat"
};
}
...
@@ -112,6 +114,109 @@ struct concat_compiler : compiler<concat_compiler>
...
@@ -112,6 +114,109 @@ struct concat_compiler : compiler<concat_compiler>
}
}
};
};
// NOLINTNEXTLINE
static
const
char
*
const
fused_concat_kernel
=
R"__migraphx__(
#include <migraphx/kernels/concat.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <migraphx/kernels/ops.hpp>
#include <args.hpp>
namespace migraphx {
${preamble}
extern "C" {
MIGRAPHX_GLOBAL void ${kernel}(${params})
{
transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto y, ${concat_params}, auto... xs) {
concat2<${axis}>(${concat_args})(${post}, y, xs...);
});
}
}
} // namespace migraphx
)__migraphx__"
;
struct
fused_concat_compiler
:
compiler
<
fused_concat_compiler
>
{
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"fused_concat"
};
}
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
{
hip_compile_options
options
;
options
.
inputs
=
inputs
;
options
.
output
=
inputs
.
back
();
options
.
params
=
"-Wno-float-equal"
;
options
.
kernel_name
=
v
.
get
(
"kernel"
,
"concat_kernel"
);
auto
axis
=
find_fast_axis
(
options
.
inputs
);
auto
op_names
=
v
.
at
(
"ops"
).
to_vector
<
std
::
string
>
();
auto
args
=
v
.
at
(
"args"
);
vectorize
vec
{};
if
(
axis
!=
v
.
at
(
"axis"
).
to
<
std
::
size_t
>
())
vec
=
vectorize
::
elements
(
ctx
,
axis
,
options
.
inputs
);
auto
nelements_per_op
=
options
.
inputs
.
back
().
elements
()
/
op_names
.
size
();
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
nelements_per_op
/
vec
.
size
,
256
));
std
::
vector
<
std
::
string
>
concat_params
;
std
::
vector
<
std
::
string
>
concat_args
;
for
(
const
auto
&
name
:
op_names
)
{
auto
n
=
args
.
at
(
name
).
to
<
std
::
size_t
>
();
auto
prefix
=
name
+
"_concat_x"
;
transform
(
range
(
n
),
std
::
back_inserter
(
concat_params
),
[
&
](
auto
i
)
{
return
"auto "
+
prefix
+
std
::
to_string
(
i
);
});
std
::
vector
<
std
::
string
>
pack_args
=
{
"MIGRAPHX_LIFT("
+
name
+
")"
};
transform
(
range
(
n
),
std
::
back_inserter
(
pack_args
),
[
&
](
auto
i
)
{
return
prefix
+
std
::
to_string
(
i
);
});
concat_args
.
push_back
(
"pack("
+
join_strings
(
pack_args
,
", "
)
+
")"
);
}
auto
src
=
interpolate_string
(
fused_concat_kernel
,
{{
"kernel"
,
options
.
kernel_name
},
{
"params"
,
enum_params
(
inputs
.
size
(),
"void * private_p"
)},
{
"args"
,
enum_params
(
inputs
.
size
(),
"private_p"
)},
{
"concat_params"
,
join_strings
(
concat_params
,
", "
)},
{
"concat_args"
,
join_strings
(
concat_args
,
", "
)},
{
"post"
,
v
.
get
(
"post"
,
std
::
string
{
"op::id{}"
})},
{
"transformers"
,
make_transformer_args
(
vec
)},
{
"preamble"
,
v
.
get
(
"preamble"
,
std
::
string
{})},
{
"axis"
,
v
.
at
(
"axis"
).
to
<
std
::
string
>
()}});
return
compile_hip_code_object
(
src
,
options
);
}
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
{
auto
v
=
op
.
to_value
();
std
::
unordered_map
<
std
::
string
,
std
::
string
>
mod_names_lookup
;
transform
(
range
(
ins
->
module_inputs
().
size
()),
std
::
inserter
(
mod_names_lookup
,
mod_names_lookup
.
end
()),
[
&
](
auto
i
)
{
return
std
::
make_pair
(
ins
->
module_inputs
()[
i
]
->
name
(),
"pointwise"
+
std
::
to_string
(
i
));
});
v
[
"preamble"
]
=
transform_accumulate
(
ins
->
module_inputs
().
begin
(),
ins
->
module_inputs
().
end
(),
std
::
string
{},
std
::
plus
<>
{},
[
&
](
module_ref
mod
)
{
return
generate_pointwise
(
*
mod
,
mod_names_lookup
.
at
(
mod
->
name
()))
+
"
\n
"
;
});
std
::
vector
<
std
::
string
>
mod_names
;
std
::
transform
(
ins
->
module_inputs
().
begin
(),
ins
->
module_inputs
().
end
(),
std
::
back_inserter
(
mod_names
),
[
&
](
module_ref
mod
)
{
return
mod_names_lookup
.
at
(
mod
->
name
());
});
v
[
"ops"
]
=
mod_names
;
std
::
unordered_map
<
std
::
string
,
std
::
size_t
>
mod_args
;
std
::
transform
(
ins
->
module_inputs
().
begin
(),
ins
->
module_inputs
().
end
(),
std
::
inserter
(
mod_args
,
mod_args
.
end
()),
[
&
](
module_ref
mod
)
{
const
auto
&
name
=
mod_names_lookup
.
at
(
mod
->
name
());
return
std
::
make_pair
(
name
,
mod
->
get_parameter_names
().
size
());
});
v
[
"args"
]
=
mod_args
;
v
[
"kernel"
]
=
transform_accumulate
(
ins
->
module_inputs
().
begin
(),
ins
->
module_inputs
().
end
()
-
1
,
std
::
string
{},
std
::
plus
<>
{},
[
&
](
module_ref
mod
)
{
return
generate_name_from_ops
(
*
mod
)
+
"_"
;
})
+
"concat_"
+
generate_name_from_ops
(
*
(
ins
->
module_inputs
().
back
()))
+
"_kernel"
;
return
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
v
);
}
};
}
// namespace gpu
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
src/targets/gpu/kernels/include/migraphx/kernels/concat.hpp
View file @
d1dc3e35
...
@@ -89,7 +89,8 @@ __device__ auto concat2(InputPacks... input_packs)
...
@@ -89,7 +89,8 @@ __device__ auto concat2(InputPacks... input_packs)
});
});
});
});
})(
_c
<
0
>
,
input_packs
...);
})(
_c
<
0
>
,
input_packs
...);
}
};
}
}
// namespace migraphx
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_CONCAT_HPP
#endif // MIGRAPHX_GUARD_KERNELS_CONCAT_HPP
src/targets/gpu/target.cpp
View file @
d1dc3e35
...
@@ -31,6 +31,7 @@
...
@@ -31,6 +31,7 @@
#include <migraphx/eliminate_data_type.hpp>
#include <migraphx/eliminate_data_type.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/eliminate_pad.hpp>
#include <migraphx/eliminate_pad.hpp>
#include <migraphx/fuse_concat.hpp>
#include <migraphx/fuse_pointwise.hpp>
#include <migraphx/fuse_pointwise.hpp>
#include <migraphx/fuse_reduce.hpp>
#include <migraphx/fuse_reduce.hpp>
#include <migraphx/inline_module.hpp>
#include <migraphx/inline_module.hpp>
...
@@ -140,6 +141,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
...
@@ -140,6 +141,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination
{},
dead_code_elimination
{},
enable_pass
(
not
enabled
(
MIGRAPHX_DISABLE_REDUCE_FUSION
{}),
fuse_reduce
{}),
enable_pass
(
not
enabled
(
MIGRAPHX_DISABLE_REDUCE_FUSION
{}),
fuse_reduce
{}),
dead_code_elimination
{},
dead_code_elimination
{},
fuse_concat
{},
dead_code_elimination
{},
#ifndef _WIN32
#ifndef _WIN32
enable_pass
(
enabled
(
MIGRAPHX_ENABLE_CK
{}),
fuse_ck
{}),
enable_pass
(
enabled
(
MIGRAPHX_ENABLE_CK
{}),
fuse_ck
{}),
#endif
#endif
...
...
test/fuse_concat.cpp
0 → 100644
View file @
d1dc3e35
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/fuse_concat.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/program.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/functional.hpp>
#include <test.hpp>
#include <pointwise.hpp>
void
run_pass
(
migraphx
::
program
&
p
)
{
migraphx
::
run_passes
(
p
,
{
migraphx
::
fuse_concat
{},
migraphx
::
dead_code_elimination
{}});
}
template
<
class
F
>
struct
concat_arg
{
std
::
string
name
;
std
::
vector
<
migraphx
::
instruction_ref
>
inputs
;
F
f
;
};
template
<
class
F
>
concat_arg
<
F
>
arg
(
std
::
string
name
,
std
::
vector
<
migraphx
::
instruction_ref
>
inputs
,
F
f
)
{
return
{
std
::
move
(
name
),
std
::
move
(
inputs
),
std
::
move
(
f
)};
}
template
<
class
Arg
,
class
...
Args
>
migraphx
::
instruction_ref
add_concat
(
migraphx
::
program
&
p
,
std
::
size_t
axis
,
Arg
post_arg
,
Args
...
args
)
{
std
::
vector
<
migraphx
::
module_ref
>
module_inputs
;
std
::
vector
<
migraphx
::
instruction_ref
>
ins_inputs
;
migraphx
::
each_args
([
&
](
auto
arg
)
{
module_inputs
.
push_back
(
create_pointwise_module
(
p
,
arg
.
name
,
arg
.
inputs
,
arg
.
f
));
ins_inputs
.
insert
(
ins_inputs
.
end
(),
arg
.
inputs
.
begin
(),
arg
.
inputs
.
end
());
},
args
...);
module_inputs
.
push_back
(
create_pointwise_module
(
p
,
post_arg
.
name
,
{},
[
&
](
auto
*
pm
,
auto
&&
)
{
std
::
vector
<
migraphx
::
instruction_ref
>
params
;
params
.
push_back
(
pm
->
add_parameter
(
"!x0"
,
migraphx
::
shape
{
ins_inputs
.
back
()
->
get_shape
().
type
()}));
std
::
transform
(
post_arg
.
inputs
.
begin
(),
post_arg
.
inputs
.
end
(),
std
::
back_inserter
(
params
),
[
&
](
auto
input
)
{
return
pm
->
add_parameter
(
"x"
+
std
::
to_string
(
params
.
size
()),
migraphx
::
shape
{
input
->
get_shape
().
type
()});
});
return
post_arg
.
f
(
pm
,
params
);
}));
auto
*
mm
=
p
.
get_main_module
();
return
mm
->
add_instruction
(
migraphx
::
make_op
(
"fused_concat"
,
{{
"axis"
,
axis
}}),
ins_inputs
,
module_inputs
);
}
TEST_CASE
(
simple_pointwise_concat
)
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
3
}};
migraphx
::
program
p1
;
{
auto
*
mm
=
p1
.
get_main_module
();
auto
x
=
mm
->
add_parameter
(
"x"
,
s
);
auto
y
=
mm
->
add_parameter
(
"y"
,
s
);
auto
add
=
add_pointwise
(
p1
,
"main:pointwise0"
,
{
x
,
y
},
single_pointwise
(
"add"
));
auto
sub
=
add_pointwise
(
p1
,
"main:pointwise1"
,
{
x
,
y
},
single_pointwise
(
"sub"
));
auto
concat
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"concat"
,
{{
"axis"
,
1
}}),
add
,
sub
);
auto
relu
=
add_pointwise
(
p1
,
"main:pointwise2"
,
{
concat
},
single_pointwise
(
"relu"
));
mm
->
add_return
({
relu
});
}
run_pass
(
p1
);
migraphx
::
program
p2
;
{
auto
*
mm
=
p2
.
get_main_module
();
auto
x
=
mm
->
add_parameter
(
"x"
,
s
);
auto
y
=
mm
->
add_parameter
(
"y"
,
s
);
auto
fused_concat
=
add_concat
(
p2
,
1
,
arg
(
"main:pointwise2:concat"
,
{},
single_pointwise
(
"relu"
)),
arg
(
"concat:main:pointwise0"
,
{
x
,
y
},
single_pointwise
(
"add"
)),
arg
(
"concat:main:pointwise1"
,
{
x
,
y
},
single_pointwise
(
"sub"
)));
mm
->
add_return
({
fused_concat
});
}
EXPECT
(
p1
==
p2
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment