gaoqiong / MIGraphX · Commit 83b9164b

authored Oct 18, 2023 by Manupa Karunaratne

* fix pm + attention test

parent a22ec139
Showing 1 changed file with 22 additions and 14 deletions.

src/targets/gpu/fuse_mlir.cpp (+22, -14)

In short, the change renames the attention matcher bindings "top_dot" and "bottom_dot" to "gemm0" and "gemm1", renames the fused module parameter "bdot_non_smax_in" to "v", and teaches find_mlir_standalone_attention_op to fold an optional trailing pointwise module ("trailing_pm") into the fused mlir_op, replacing that pointwise instruction rather than the second GEMM.
@@ -411,7 +411,7 @@ struct find_mlir_standalone_attention_op
         mm->set_bypass();
         std::unordered_map<instruction_ref, instruction_ref> ins_map;
-        auto top_ins = r.instructions["top_dot"];
+        auto top_ins = r.instructions["gemm0"];
         auto [new_top_ins, top_inputs] = fuse_input_ops_and_gemm_based_op(mm, top_ins);
         inputs.insert(inputs.begin(), top_inputs.begin(), top_inputs.end());
         ins_map[top_ins] = new_top_ins;
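A note on the renames in this hunk (a reading inferred from the matcher further down, not stated in the commit message): the fused pattern is the standard attention chain, so the two dots are now named for their position around the softmax:

    gemm0  = dot(Q, K)              // previously bound as "top_dot"
    scaled = mul(gemm0, s)          // optional pointwise, bound as "scale"
    smax   = softmax(scaled or gemm0)
    gemm1  = dot(smax, V)           // previously bound as "bottom_dot"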
@@ -424,8 +424,8 @@ struct find_mlir_standalone_attention_op
                      [&](auto input) { return input != top_ins; });
         }
         auto softmax = mm->add_instruction(r.instructions["softmax"]->get_operator(), new_top_ins);
-        std::transform(r.instructions["bottom_dot"]->inputs().begin(),
-                       r.instructions["bottom_dot"]->inputs().end(),
+        std::transform(r.instructions["gemm1"]->inputs().begin(),
+                       r.instructions["gemm1"]->inputs().end(),
                        std::inserter(ins_map, ins_map.end()),
                        [&](auto old_ins) {
                            if(old_ins == r.instructions["softmax"]){
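For readers unfamiliar with the idiom in this hunk: std::transform with std::inserter maps each input of the second GEMM to its counterpart inside the new module, inserting the (old, new) pairs straight into ins_map. A minimal standalone sketch with strings standing in for instruction_refs (illustrative names, not MIGraphX code):

    #include <algorithm>
    #include <iostream>
    #include <iterator>
    #include <string>
    #include <unordered_map>
    #include <vector>

    int main()
    {
        std::vector<std::string> gemm1_inputs = {"softmax_out", "v"};
        std::unordered_map<std::string, std::string> ins_map;
        // For each old input, produce an (old, new) pair and insert it into the map.
        std::transform(gemm1_inputs.begin(),
                       gemm1_inputs.end(),
                       std::inserter(ins_map, ins_map.end()),
                       [](const std::string& old_ins) {
                           return std::make_pair(old_ins, "new_" + old_ins);
                       });
        for(const auto& [old_ins, new_ins] : ins_map)
            std::cout << old_ins << " -> " << new_ins << '\n';
        return 0;
    }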
@@ -433,22 +433,30 @@ struct find_mlir_standalone_attention_op
                            }
                            inputs.push_back(old_ins);
                            return std::make_pair(old_ins,
-                                                 mm->add_parameter("bdot_non_smax_in", old_ins->get_shape()));
+                                                 mm->add_parameter("v", old_ins->get_shape()));
                        });
-        auto bottom_dot_a = ins_map[r.instructions["bottom_dot"]->inputs().front()];
-        auto bottom_dot_b = ins_map[r.instructions["bottom_dot"]->inputs().back()];
-        auto new_bottom_dot = mm->add_instruction(make_op("dot"), {bottom_dot_a, bottom_dot_b});
+        auto gemm1_a = ins_map[r.instructions["gemm1"]->inputs().front()];
+        auto gemm1_b = ins_map[r.instructions["gemm1"]->inputs().back()];
+        auto new_gemm1 = mm->add_instruction(make_op("dot"), {gemm1_a, gemm1_b});
+        ins_map[r.instructions["gemm1"]] = new_gemm1;
+        auto ins_to_replace = new_gemm1;
+        auto ins_to_be_replaced = r.instructions["gemm1"];
         if(r.instructions.find("trailing_pm") != r.instructions.end()){
-            new_bottom_dot = fold_pointwise_mod(r.instructions["trailing_pm"], mm, ins_map)[0];
+            ins_to_replace = fold_pointwise_mod(r.instructions["trailing_pm"], mm, ins_map)[0];
+            std::copy_if(r.instructions["trailing_pm"]->inputs().begin(),
+                         r.instructions["trailing_pm"]->inputs().end(),
+                         std::back_inserter(inputs),
+                         [&](auto input) { return input != r.instructions["gemm1"]; });
+            ins_to_be_replaced = r.instructions["trailing_pm"];
         }
-        mm->add_return({new_bottom_dot});
+        mm->add_return({ins_to_replace});
         mpm.get_module().replace_instruction(
-            r.instructions["bottom_dot"], mlir_op{new_bottom_dot->get_operator()}, inputs, {mm});
+            ins_to_be_replaced, mlir_op{new_gemm1->get_operator()}, inputs, {mm});
     }
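The new trailing_pm branch above forwards the pointwise module's extra operands as fusion inputs while skipping the operand that is the gemm1 result itself. A minimal standalone sketch of that std::copy_if + std::back_inserter filter (illustrative values, not MIGraphX code):

    #include <algorithm>
    #include <iostream>
    #include <iterator>
    #include <string>
    #include <vector>

    int main()
    {
        // Inputs of the trailing pointwise op: the fused GEMM result plus a bias.
        std::vector<std::string> trailing_pm_inputs = {"gemm1", "bias"};
        std::vector<std::string> inputs = {"q", "k", "v"};
        // Append every pointwise input except the in-module gemm1 result.
        std::copy_if(trailing_pm_inputs.begin(),
                     trailing_pm_inputs.end(),
                     std::back_inserter(inputs),
                     [](const std::string& input) { return input != "gemm1"; });
        for(const auto& i : inputs)
            std::cout << i << '\n'; // q, k, v, bias
        return 0;
    }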
     auto matcher() const {
-        auto match_softmax_input = match::any_of[match::inputs()](
-            match::name("dot").bind("top_dot"),
-            match::name("pointwise")(match::any_of[match::inputs()](match::name("dot").bind("top_dot"))).bind("scale"));
+        auto match_softmax_input = match::any_of[match::inputs()](
+            match::name("dot").bind("gemm0"),
+            match::name("pointwise")(match::any_of[match::inputs()](match::name("dot").bind("gemm0"))).bind("scale"));
-        auto is_mlir_attention = match::name("dot")(match::any_of[match::inputs()](
-            match::name("softmax")(match_softmax_input).bind("softmax"))).bind("bottom_dot");
+        auto is_mlir_attention = match::name("dot")(match::any_of[match::inputs()](
+            match::name("softmax")(match_softmax_input).bind("softmax"))).bind("gemm1");
         return is_mlir_attention;
     }
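The matcher composes nested predicates: a dot whose input is a softmax whose input is (directly or through a pointwise scale) another dot. A hypothetical standalone analog of such combinators built on std::function, to show how the nesting reads; this is not the MIGraphX match:: API:

    #include <functional>
    #include <iostream>
    #include <string>
    #include <vector>

    struct node
    {
        std::string name;
        std::vector<const node*> inputs;
    };

    using matcher = std::function<bool(const node&)>;

    // name("dot"): match a node by operator name.
    static matcher name(std::string n)
    {
        return [n](const node& x) { return x.name == n; };
    }

    // any_input(m): match when any input of the node satisfies m.
    static matcher any_input(matcher m)
    {
        return [m](const node& x) {
            for(const node* i : x.inputs)
                if(m(*i))
                    return true;
            return false;
        };
    }

    // both(a, b): match when the node satisfies a and b.
    static matcher both(matcher a, matcher b)
    {
        return [a, b](const node& x) { return a(x) and b(x); };
    }

    int main()
    {
        node q{"@param", {}}, k{"@param", {}}, v{"@param", {}};
        node gemm0{"dot", {&q, &k}};
        node smax{"softmax", {&gemm0}};
        node gemm1{"dot", {&smax, &v}};
        // dot(any input is softmax(any input is dot)) ~ is_mlir_attention above.
        matcher is_attention =
            both(name("dot"), any_input(both(name("softmax"), any_input(name("dot")))));
        std::cout << std::boolalpha << is_attention(gemm1) << '\n'; // true
        return 0;
    }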
@@ -458,7 +466,7 @@ struct find_mlir_standalone_attention_op
         if(mode != mlir_mode::all){
             return false;
         }
-        auto top_dot = r.instructions["top_dot"];
+        auto gemm0 = r.instructions["gemm0"];
         // Check the pointwise mod only contains a single mul
         if(r.instructions.find("scale") != r.instructions.end()){
             auto scale_pm = r.instructions["scale"];
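Per the comment, the fusion only accepts a scale module that is a single multiply. A standalone analog of such a check over a flat list of op names (purely illustrative; the real code inspects the pointwise module's instructions):

    #include <algorithm>
    #include <iostream>
    #include <string>
    #include <vector>

    int main()
    {
        // Op names inside a hypothetical pointwise "scale" module.
        std::vector<std::string> scale_ops = {"@param", "@param", "mul", "@return"};
        // Accept only when the single non-param/non-return op is a mul.
        auto is_compute = [](const std::string& op) { return op.front() != '@'; };
        bool single_mul =
            std::count_if(scale_ops.begin(), scale_ops.end(), is_compute) == 1 and
            std::find(scale_ops.begin(), scale_ops.end(), "mul") != scale_ops.end();
        std::cout << (single_mul ? "fuse scale\n" : "reject\n");
        return 0;
    }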
@@ -475,7 +483,7 @@ struct find_mlir_standalone_attention_op
             }
         }
         // enable only for fp32/fp16/i8 types
-        if(std::any_of(top_dot->inputs().begin(), top_dot->inputs().end(), [&](auto i) {
+        if(std::any_of(gemm0->inputs().begin(), gemm0->inputs().end(), [&](auto i) {
               return not contains(
                   {shape::type_t::float_type, shape::type_t::half_type, shape::type_t::int8_type},
                   i->get_shape().type());
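The last hunk gates the fusion on element types: it bails out when any input of the first GEMM is not fp32, fp16, or int8. A standalone sketch of that gate, with a hand-rolled contains standing in for the MIGraphX helper (names here are illustrative):

    #include <algorithm>
    #include <initializer_list>
    #include <iostream>
    #include <vector>

    enum class type_t { float_type, half_type, int8_type, double_type };

    static bool contains(std::initializer_list<type_t> set, type_t t)
    {
        return std::find(set.begin(), set.end(), t) != set.end();
    }

    int main()
    {
        std::vector<type_t> input_types = {type_t::float_type, type_t::double_type};
        bool unsupported =
            std::any_of(input_types.begin(), input_types.end(), [](type_t t) {
                return not contains(
                    {type_t::float_type, type_t::half_type, type_t::int8_type}, t);
            });
        std::cout << (unsupported ? "skip fusion\n" : "fuse\n"); // prints "skip fusion"
        return 0;
    }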