Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
c5e8e14f
"...composable_kernel_rocm.git" did not exist on "39acaea36d7f9af41f3587bad9fba3dd5d370426"
Commit
c5e8e14f
authored
Dec 31, 2024
by
Po Yen, Chen
Browse files
Add kStoreLSE=true fp8 fmha fwd splitkv instances
parent
ca1a816d
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
9 additions
and
5 deletions
+9
-5
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
+9
-5
No files found.
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
View file @
c5e8e14f
...
@@ -294,9 +294,13 @@ FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F
...
@@ -294,9 +294,13 @@ FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F
}}
}}
}} else {{
}} else {{
if (t.has_lse) {{
if (t.has_lse) {{
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
if constexpr (std::is_same_v<{F_dtype}, FmhaFwdFp8>) {{
return -1;
}} else {{
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
return fmha_fwd_splitkv_<traits_>(s, a);
return fmha_fwd_splitkv_<traits_>(s, a);
}}
}} else {{
}} else {{
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, false, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, false, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
...
@@ -685,8 +689,8 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
...
@@ -685,8 +689,8 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
pipelines
.
append
(
Pipeline
(
'qr'
,
'row'
,
't'
,
't'
,
't'
,
't'
,
bias
,
lse
,
squant
,
pagedkv
,
mask
))
# TODO: cover arbitraty hdim
pipelines
.
append
(
Pipeline
(
'qr'
,
'row'
,
't'
,
't'
,
't'
,
't'
,
bias
,
lse
,
squant
,
pagedkv
,
mask
))
# TODO: cover arbitraty hdim
pipelines
.
append
(
Pipeline
(
'qr'
,
'col'
,
't'
,
'f'
,
't'
,
't'
,
bias
,
lse
,
squant
,
pagedkv
,
mask
))
# TODO: cover arbitraty hdim
pipelines
.
append
(
Pipeline
(
'qr'
,
'col'
,
't'
,
'f'
,
't'
,
't'
,
bias
,
lse
,
squant
,
pagedkv
,
mask
))
# TODO: cover arbitraty hdim
elif
dtype
in
[
'fp8'
,
'bf8'
]:
elif
dtype
in
[
'fp8'
,
'bf8'
]:
for
mask
,
bias
in
itertools
.
product
(
get_mask_map
(
mask_impl
).
keys
(),
BIAS_MAP
.
keys
()):
for
mask
,
bias
,
lse
in
itertools
.
product
(
get_mask_map
(
mask_impl
).
keys
(),
BIAS_MAP
.
keys
()
,
[
't'
,
'f'
]
):
pipelines
.
append
(
Pipeline
(
'qr'
,
'col'
,
'f'
,
'f'
,
'f'
,
'f'
,
bias
,
't'
,
squant
,
'f'
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr'
,
'col'
,
'f'
,
'f'
,
'f'
,
'f'
,
bias
,
lse
,
squant
,
'f'
,
mask
))
elif
dtype
in
[
'fp8fp16'
,
'fp8bf16'
]:
elif
dtype
in
[
'fp8fp16'
,
'fp8bf16'
]:
# TODO
# TODO
None
None
...
@@ -717,7 +721,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
...
@@ -717,7 +721,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
continue
continue
is_prefill
=
(
mode
==
"group"
and
pipeline
.
F_pagedkv
==
't'
)
is_prefill
=
(
dtype
in
[
'fp16'
,
'bf16'
]
and
mode
==
"group"
and
pipeline
.
F_pagedkv
==
't'
)
tile
=
prefill_tile
if
is_prefill
else
decode_tile
tile
=
prefill_tile
if
is_prefill
else
decode_tile
k
=
Kernel
(
F_idx
=
0
,
k
=
Kernel
(
F_idx
=
0
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment