Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
9bbea3ad
Commit
9bbea3ad
authored
Nov 28, 2022
by
jungpark-mlir
Browse files
add mlir gemm-pointwise fusion
parent
4627ea59
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
72 additions
and
2 deletions
+72
-2
src/targets/gpu/fuse_mlir.cpp
src/targets/gpu/fuse_mlir.cpp
+70
-0
src/targets/gpu/mlir.cpp
src/targets/gpu/mlir.cpp
+2
-2
No files found.
src/targets/gpu/fuse_mlir.cpp
View file @
9bbea3ad
...
@@ -136,6 +136,75 @@ struct find_conv_pointwise
...
@@ -136,6 +136,75 @@ struct find_conv_pointwise
ins
,
mlir_conv
{
conv_ins
->
get_operator
()},
inputs
,
{
mm
});
ins
,
mlir_conv
{
conv_ins
->
get_operator
()},
inputs
,
{
mm
});
}
}
};
};
MIGRAPHX_PRED_MATCHER
(
is_mlir_gemm
,
instruction_ref
ins
)
{
if
(
ins
->
name
()
!=
"dot"
)
return
false
;
return
true
;
}
struct
find_gemm_pointwise
{
// Find a convolution followed by a pointwise operation.
auto
matcher
()
const
{
auto
gemm
=
match
::
skip
(
match
::
name
(
"contiguous"
))(
is_mlir_gemm
().
bind
(
"dot"
));
return
match
::
name
(
"pointwise"
)(
match
::
any_of
[
match
::
inputs
()](
gemm
.
bind
(
"x"
)));
// turn match::any_of[match::inputs()](gemm.bind("x"));
}
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
gemm_ins
=
r
.
instructions
[
"dot"
];
auto
x_ins
=
r
.
instructions
[
"x"
];
// input after contiguous
auto
*
pm
=
ins
->
module_inputs
().
front
();
auto
names
=
pm
->
get_parameter_names
();
// Whitelist pointwise operators
if
(
std
::
any_of
(
pm
->
begin
(),
pm
->
end
(),
[](
const
auto
&
i
)
{
return
not
contains
({
"@literal"
,
"@param"
,
"@return"
,
"dot"
,
"add"
,
"relu"
},
i
.
name
());
}))
return
;
// Only fuse with fp32/fp16
if
(
std
::
any_of
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
[
&
](
auto
i
)
{
return
not
contains
({
shape
::
type_t
::
float_type
,
shape
::
type_t
::
half_type
},
i
->
get_shape
().
type
());
}))
return
;
std
::
sort
(
names
.
begin
(),
names
.
end
());
module_ref
mm
=
mpm
.
create_module
(
"mlir_"
+
pm
->
name
());
mm
->
set_bypass
();
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
param_map
;
auto
x
=
mm
->
add_parameter
(
"x"
+
std
::
to_string
(
names
.
size
()),
gemm_ins
->
inputs
().
at
(
0
)
->
get_shape
());
auto
w
=
mm
->
add_parameter
(
"x"
+
std
::
to_string
(
names
.
size
()
+
1
),
gemm_ins
->
inputs
().
at
(
1
)
->
get_shape
());
auto
gemm
=
mm
->
add_instruction
(
gemm_ins
->
get_operator
(),
{
x
,
w
});
std
::
transform
(
names
.
begin
(),
names
.
end
(),
ins
->
inputs
().
begin
(),
std
::
inserter
(
param_map
,
param_map
.
end
()),
[
&
](
auto
name
,
auto
input
)
{
if
(
input
==
x_ins
)
return
std
::
make_pair
(
pm
->
get_parameter
(
name
),
gemm
);
return
std
::
make_pair
(
pm
->
get_parameter
(
name
),
mm
->
add_parameter
(
name
,
input
->
get_shape
()));
});
mm
->
add_return
(
mm
->
insert_instructions
(
mm
->
end
(),
pm
,
param_map
));
std
::
vector
<
instruction_ref
>
inputs
;
std
::
copy_if
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
std
::
back_inserter
(
inputs
),
[
&
](
auto
input
)
{
return
input
!=
gemm_ins
;
});
inputs
.
insert
(
inputs
.
end
(),
gemm_ins
->
inputs
().
begin
(),
gemm_ins
->
inputs
().
end
());
mpm
.
get_module
().
replace_instruction
(
ins
,
mlir_conv
{
gemm_ins
->
get_operator
()},
inputs
,
{
mm
});
}
};
}
// namespace
}
// namespace
#endif
#endif
...
@@ -144,6 +213,7 @@ void fuse_mlir::apply(module_pass_manager& mpm) const
...
@@ -144,6 +213,7 @@ void fuse_mlir::apply(module_pass_manager& mpm) const
{
{
#ifdef MIGRAPHX_MLIR
#ifdef MIGRAPHX_MLIR
match
::
find_matches
(
mpm
,
find_conv_pointwise
{});
match
::
find_matches
(
mpm
,
find_conv_pointwise
{});
match
::
find_matches
(
mpm
,
find_gemm_pointwise
{});
#else
#else
(
void
)
mpm
;
(
void
)
mpm
;
#endif
#endif
...
...
src/targets/gpu/mlir.cpp
View file @
9bbea3ad
...
@@ -455,7 +455,7 @@ struct mlir_program
...
@@ -455,7 +455,7 @@ struct mlir_program
auto
ops
=
create_operation_state
(
"func.func"
);
auto
ops
=
create_operation_state
(
"func.func"
);
ops
.
add_attributes
({{
"function_type"
,
make_function_type
(
inputs
,
outputs
)},
ops
.
add_attributes
({{
"function_type"
,
make_function_type
(
inputs
,
outputs
)},
{
"sym_name"
,
std
::
string
(
"main"
)},
{
"sym_name"
,
std
::
string
(
"
mlir_
main"
)},
{
"kernel"
,
std
::
string
(
"mixr"
)},
{
"kernel"
,
std
::
string
(
"mixr"
)},
{
"arch"
,
target_arch
}});
{
"arch"
,
target_arch
}});
ops
.
add_region
(
std
::
move
(
region
));
ops
.
add_region
(
std
::
move
(
region
));
...
@@ -550,7 +550,7 @@ struct mlir_program
...
@@ -550,7 +550,7 @@ struct mlir_program
mlirPassManagerRun
(
pm
.
get
(),
mmodule
.
get
());
mlirPassManagerRun
(
pm
.
get
(),
mmodule
.
get
());
code_object_op
op
{};
code_object_op
op
{};
op
.
symbol_name
=
"main"
;
op
.
symbol_name
=
"
mlir_
main"
;
op
.
code_object
=
get_binary
();
op
.
code_object
=
get_binary
();
std
::
tie
(
op
.
global
,
op
.
local
)
=
get_launch_params
();
std
::
tie
(
op
.
global
,
op
.
local
)
=
get_launch_params
();
return
op
;
return
op
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment