gaoqiong / MIGraphX · Commits

Commit b717b473, authored Oct 13, 2023 by Manupa Karunaratne

    * standalone attention sane offloads to mlir but mlir is broken

Parent: 25d6b2e2
Showing 1 changed file with 17 additions and 4 deletions:

src/targets/gpu/fuse_mlir.cpp  (+17, -4)
@@ -81,6 +81,7 @@ struct mlir_op
             MIGRAPHX_THROW("should have at least two inputs.");
         module_ref mod = mods[0];
+        std::cerr << "mod:" << *mod << std::endl;
         auto type = mod->get_output_shapes().front().type();
         std::unordered_map<instruction_ref, shape> ins_shapes;
         size_t param_cnt = 0;
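The only functional change in this first hunk is the unconditional std::cerr dump of the fused submodule, evidently a bring-up aid (the commit message itself admits "mlir is broken"). A minimal sketch of how such a trace could be gated behind an environment variable instead of always printing, assuming only that *mod is streamable as the added line implies; the variable name MIGRAPHX_TRACE_MLIR_FUSION and the helper below are hypothetical, not part of this commit:

    #include <cstdlib>
    #include <iostream>

    // Hypothetical gate: dump only when the (made-up) MIGRAPHX_TRACE_MLIR_FUSION
    // environment variable is set to a non-empty value.
    inline bool trace_mlir_fusion_enabled()
    {
        const char* v = std::getenv("MIGRAPHX_TRACE_MLIR_FUSION");
        return v != nullptr and *v != '\0';
    }

    // At the site of the added line:
    //
    //     if(trace_mlir_fusion_enabled())
    //         std::cerr << "mod:" << *mod << std::endl;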
@@ -448,19 +449,31 @@ struct find_mlir_standalone_attention_op : find_mlir_standalone_op
                        [&](auto input) { return input != top_ins; });
         }
         auto softmax = mm->add_instruction(r.instructions["softmax"]->get_operator(), new_top_ins);
+        insert_to_map(ins_map, r.instructions["softmax"], softmax);
+        std::transform(r.instructions["bottom_dot"]->inputs().begin(),
+                       r.instructions["bottom_dot"]->inputs().end(),
+                       std::inserter(ins_map, ins_map.end()),
+                       [&](auto old_ins) {
+                           if(old_ins == r.instructions["softmax"])
+                           {
+                               return std::make_pair(old_ins, softmax);
+                           }
+                           inputs.push_back(old_ins);
+                           return std::make_pair(
+                               old_ins, mm->add_parameter("bdot_non_smax_in", old_ins->get_shape()));
+                       });
         auto bottom_dot_a = get_from_map(ins_map, r.instructions["bottom_dot"]->inputs().front());
         auto bottom_dot_b = get_from_map(ins_map, r.instructions["bottom_dot"]->inputs().back());
         auto new_bottom_dot = mm->add_instruction(make_op("dot"), {bottom_dot_a, bottom_dot_b});
         mm->add_return({new_bottom_dot});
         inputs.insert(inputs.end(), top_inputs.begin(), top_inputs.end());
         mpm.get_module().replace_instruction(
-            top_ins, mlir_op{new_bottom_dot->get_operator()}, inputs, {mm});
+            r.instructions["bottom_dot"], mlir_op{new_bottom_dot->get_operator()}, inputs, {mm});
     }

     auto matcher() const
     {
         auto match_softmax_input = match::any_of[match::inputs()](
             match::name("dot").bind("top_dot"),
             match::name("pointwise")(match::any_of[match::inputs()](
                                          match::name("dot").bind("top_dot")))
                 .bind("scale"));
-        auto is_mlir_attention = match::name("dot")(match::any_of[match::inputs()](
-                                     match::name("softmax").bind("softmax")))
-                                     .bind("bottom_dot");
+        auto is_mlir_attention = match::name("dot")(match::any_of[match::inputs()](
+                                     match::name("softmax")(match_softmax_input).bind("softmax")))
+                                     .bind("bottom_dot");
         return is_mlir_attention;
     }
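This hunk makes the standalone attention rewrite self-contained. The matched pattern is the classic gemm-softmax-gemm shape: the matcher now requires that the softmax feeding the bottom dot itself consumes the top dot, either directly or through a single pointwise scale (the two alternatives of match_softmax_input). On the rewrite side, the softmax result is registered in ins_map, every other input of the bottom dot is materialized as a submodule parameter and captured in inputs, the bottom dot is rebuilt inside the MLIR submodule, and the replacement now targets r.instructions["bottom_dot"] rather than top_ins. The std::transform-with-std::inserter idiom used to populate ins_map is worth isolating; here is a runnable sketch with plain ints and strings standing in for instruction_ref and module parameters (all names illustrative, not MIGraphX API):

    #include <algorithm>
    #include <iostream>
    #include <iterator>
    #include <string>
    #include <unordered_map>
    #include <vector>

    int main()
    {
        const int softmax_ins = 42; // plays r.instructions["softmax"]
        std::vector<int> bottom_dot_inputs = {7, softmax_ins};

        std::vector<int> captured; // plays the pass's "inputs" vector
        std::unordered_map<int, std::string> ins_map;

        // Pair each bottom-dot input either with the already-built softmax
        // or with a freshly created stand-in (a parameter in the real pass).
        std::transform(bottom_dot_inputs.begin(),
                       bottom_dot_inputs.end(),
                       std::inserter(ins_map, ins_map.end()),
                       [&](int old_ins) {
                           if(old_ins == softmax_ins)
                               return std::make_pair(old_ins, std::string{"softmax"});
                           captured.push_back(old_ins); // becomes an outer input
                           return std::make_pair(old_ins, std::string{"bdot_non_smax_in"});
                       });

        for(const auto& [k, v] : ins_map)
            std::cout << k << " -> " << v << "\n";
    }

Note that the hunk reuses the single parameter name "bdot_non_smax_in" for every non-softmax input, which would appear to collide if the bottom dot ever had more than one such input; given the "mlir is broken" caveat in the commit message, this path was clearly still work in progress.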
@@ -554,12 +567,12 @@ bool is_enabled(std::string_view op_name, context* ctx)

 void fuse_mlir::apply(module_pass_manager& mpm) const
 {
 #ifdef MIGRAPHX_MLIR
+    match::find_matches(mpm, find_mlir_standalone_attention_op{});
     if(is_enabled("fused", this->ctx))
     {
         match::find_matches(mpm, find_mlir_attention_fused_ops{});
         match::find_matches(mpm, find_mlir_fused_ops{});
     }
-    match::find_matches(mpm, find_mlir_standalone_attention_op{});
     if(is_enabled("convolution", this->ctx))
     {
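The final hunk is a pure reordering: find_mlir_standalone_attention_op now runs unconditionally before the is_enabled("fused") block instead of after it, so the standalone attention matcher gets first claim on the pattern before find_mlir_attention_fused_ops or find_mlir_fused_ops can rewrite any of its constituent instructions. Why the order matters can be shown with a toy pipeline in which each pass may consume a pattern and thereby hide it from later passes (everything below is invented for the sketch, not MIGraphX API):

    #include <iostream>
    #include <string>
    #include <vector>

    int main()
    {
        std::vector<std::string> program = {"dot", "softmax", "dot"};

        // Claims the whole gemm-softmax-gemm pattern at once.
        auto standalone_attention = [](std::vector<std::string>& p) {
            if(p == std::vector<std::string>{"dot", "softmax", "dot"})
                p = {"mlir_attention"};
        };
        // Rewrites individual dots, which would destroy the attention pattern.
        auto fuse_dots = [](std::vector<std::string>& p) {
            for(auto& op : p)
                if(op == "dot")
                    op = "mlir_dot";
        };

        // Ordering from this commit: attention first, then the rest.
        standalone_attention(program);
        fuse_dots(program);

        for(const auto& op : program)
            std::cout << op << "\n"; // prints: mlir_attention
    }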