Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
6c7a3bf4
Commit
6c7a3bf4
authored
Dec 17, 2024
by
Po Yen Chen
Browse files
Only launch splitkv kernel if num_splits == 1
parent
fa34e87c
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
114 additions
and
69 deletions
+114
-69
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
+77
-48
example/ck_tile/01_fmha/fmha_fwd.cpp
example/ck_tile/01_fmha/fmha_fwd.cpp
+28
-12
example/ck_tile/01_fmha/fmha_fwd.hpp
example/ck_tile/01_fmha/fmha_fwd.hpp
+9
-9
No files found.
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
View file @
6c7a3bf4
...
@@ -47,7 +47,7 @@ using fmha_dtype_{F_idx} = {F_dtype};
...
@@ -47,7 +47,7 @@ using fmha_dtype_{F_idx} = {F_dtype};
using fmha_mask_{F_idx} = {F_mask};
using fmha_mask_{F_idx} = {F_mask};
namespace {{
namespace {{
template <bool kHasUnevenSplits>
template <bool kHasUnevenSplits
, bool kIsMultipleSplits
>
struct kernel_runner {{
struct kernel_runner {{
using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
using fmha_warp_tile = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>;
using fmha_warp_tile = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>;
...
@@ -81,7 +81,11 @@ using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem<
...
@@ -81,7 +81,11 @@ using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
std::conditional_t<
kIsMultipleSplits,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType
>,
fmha_shape,
fmha_shape,
{F_mode},
{F_mode},
fmha_mask_{F_idx},
fmha_mask_{F_idx},
...
@@ -93,9 +97,14 @@ using fmha_pipeline = {F_pipeline}<
...
@@ -93,9 +97,14 @@ using fmha_pipeline = {F_pipeline}<
/// FIXME: use {F_spad}/{F_dvpad} as kPadM/kPadN parameters after solving
/// FIXME: use {F_spad}/{F_dvpad} as kPadM/kPadN parameters after solving
/// store_tile_raw() data corruption issue
/// store_tile_raw() data corruption issue
using fmha_epilogue =
using fmha_epilogue =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<
typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
false, false>>;
std::conditional_t<
kIsMultipleSplits,
typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
typename FmhaFwdTypeConfig<{F_dtype}>::ODataType
>,
false, false>>;
using fmha_kernel =
using fmha_kernel =
ck_tile::FmhaFwdSplitKVKernel<ck_tile::FmhaFwdSplitKVTilePartitioner<fmha_shape>,
ck_tile::FmhaFwdSplitKVKernel<ck_tile::FmhaFwdSplitKVTilePartitioner<fmha_shape>,
...
@@ -122,25 +131,19 @@ using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F
...
@@ -122,25 +131,19 @@ using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F
template<>
template<>
void fmha_fwd_splitkv_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
void fmha_fwd_splitkv_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{
{{
if constexpr({F_mode} == false) {{ // batch mode
/// NOTICE: kHasUnevenSplits=false may be able to speed-up the batch mode kernel,
// we don't check every seqlen_k values for kvcache
/// but we use kHasUnevenSplits=true here to reduce compilation time
if (a.seqlen_k_ptr != nullptr) {{
if (1 < a.num_splits) {{
kernel_runner<true>::run(s, a);
kernel_runner</*kHasUnevenSplits=*/true, /*kIsMultipleSplits=*/true>::run(s, a);
// make sure F_bn0 is divisible by F_bk1
}} else if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{
kernel_runner<false>::run(s, a);
}} else {{
kernel_runner<true>::run(s, a);
}}
}} else {{
}} else {{
kernel_runner<
tru
e>::run(s, a);
kernel_runner<
/*kHasUnevenSplits=*/true, /*kIsMultipleSplits=*/fals
e>::run(s, a);
}}
}}
}}
}}
template<>
template<>
std::string fmha_fwd_splitkv_get_name_<trait_{F_idx}>()
std::string fmha_fwd_splitkv_get_name_<trait_{F_idx}>()
{{
{{
using k_ = kernel_runner<true>::fmha_kernel; /// FIXME: choose real kernel type
using k_ = kernel_runner<
true,
true>::fmha_kernel; /// FIXME: choose real kernel type
return k_::GetName();
return k_::GetName();
}}
}}
"""
"""
...
@@ -227,19 +230,32 @@ FMHA_FWD_SPLITKV_API_FILENAME="fmha_fwd_splitkv_api.cpp"
...
@@ -227,19 +230,32 @@ FMHA_FWD_SPLITKV_API_FILENAME="fmha_fwd_splitkv_api.cpp"
FMHA_FWD_SPLITKV_API
=
"""
FMHA_FWD_SPLITKV_API
=
"""
#include <iostream>
#include <iostream>
template<typename fmha_fwd_splitkv_traits_, typename fmha_fwd_splitkv_combine_traits_>
template<typename fmha_fwd_splitkv_traits_, typename fmha_fwd_splitkv_combine_traits_
= void
>
float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{
{{
if(s.log_level_ > 0)
// fmha_fwd_splitkv_combine_traits_=void, launch splitkv kernel only
std::cout
if constexpr (std::is_same_v<fmha_fwd_splitkv_combine_traits_, void>) {{
<< ", " << fmha_fwd_splitkv_get_name_<fmha_fwd_splitkv_traits_>()
if(s.log_level_ > 0)
<< ", " << fmha_fwd_splitkv_combine_get_name_<fmha_fwd_splitkv_combine_traits_>()
std::cout
<< std::flush;
<< ", " << fmha_fwd_splitkv_get_name_<fmha_fwd_splitkv_traits_>()
<< std::flush;
return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_<fmha_fwd_splitkv_traits_>(s_, a); }},
return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_<fmha_fwd_splitkv_combine_traits_>(s_, a); }}
[=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_<fmha_fwd_splitkv_traits_>(s_, a); }}
);
);
// launch both splitkv & combine kernels
}} else {{
if(s.log_level_ > 0)
std::cout
<< ", " << fmha_fwd_splitkv_get_name_<fmha_fwd_splitkv_traits_>()
<< ", " << fmha_fwd_splitkv_combine_get_name_<fmha_fwd_splitkv_combine_traits_>()
<< std::flush;
return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_<fmha_fwd_splitkv_traits_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_<fmha_fwd_splitkv_combine_traits_>(s_, a); }}
);
}}
}}
}}
float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const ck_tile::stream_config& s){{
float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const ck_tile::stream_config& s){{
...
@@ -251,19 +267,32 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const
...
@@ -251,19 +267,32 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const
FMHA_FWD_SPLITKV_API_INNER_DISPATCH
=
""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) &&
FMHA_FWD_SPLITKV_API_INNER_DISPATCH
=
""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) &&
((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
if (t.has_lse) {{
if (1 < a.num_splits) {{
if constexpr (std::is_same_v<{F_dtype}, ck_tile::fp8_t>) {{
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
return -1;
if (t.has_lse) {{
if constexpr (std::is_same_v<{F_dtype}, ck_tile::fp8_t>) {{
return -1;
}} else {{
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}/2, true, {F_squant}, {F_spad}, {F_dvpad}>;
return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
}}
}} else {{
}} else {{
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}/2,
tru
e, {F_squant}, {F_spad}, {F_dvpad}>;
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}/2,
fals
e, {F_squant}, {F_spad}, {F_dvpad}>;
return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
}}
}}
}} else {{
}} else {{
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}/2, false, {F_squant}, {F_spad}, {F_dvpad}>;
if (t.has_lse) {{
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
return fmha_fwd_splitkv_<traits_>(s, a);
}} else {{
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
return fmha_fwd_splitkv_<traits_>(s, a);
}}
}}
}}
}}
}}
"""
"""
...
@@ -626,26 +655,26 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
...
@@ -626,26 +655,26 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
squant
=
't'
if
dtype
==
'fp8'
else
'f'
squant
=
't'
if
dtype
==
'fp8'
else
'f'
pipelines
=
[]
pipelines
=
[]
if
dtype
in
[
'fp16'
,
'bf16'
]:
if
dtype
in
[
'fp16'
,
'bf16'
]:
for
mask
,
bias
,
pagedkv
in
itertools
.
product
(
get_mask_map
(
mask_impl
).
keys
(),
BIAS_MAP
.
keys
(),
[
"t"
,
"f"
]):
for
mask
,
bias
,
lse
,
pagedkv
in
itertools
.
product
(
get_mask_map
(
mask_impl
).
keys
(),
BIAS_MAP
.
keys
(),
[
"t"
,
"f"
],
[
"t"
,
"f"
]):
# TODO: use async pipeline when compiler is more stable
# TODO: use async pipeline when compiler is more stable
if
hdim
==
256
or
hdim
in
[
32
,
64
,
128
]:
### [32, 64, 96, 128]:
if
hdim
==
256
or
hdim
in
[
32
,
64
,
128
]:
### [32, 64, 96, 128]:
# if True:
# if True:
pipelines
.
append
(
Pipeline
(
'qr'
,
'row'
,
'f'
,
't'
,
'f'
,
'f'
,
bias
,
't'
,
squant
,
pagedkv
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr'
,
'row'
,
'f'
,
't'
,
'f'
,
'f'
,
bias
,
lse
,
squant
,
pagedkv
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr'
,
'col'
,
'f'
,
't'
,
'f'
,
'f'
,
bias
,
't'
,
squant
,
pagedkv
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr'
,
'col'
,
'f'
,
't'
,
'f'
,
'f'
,
bias
,
lse
,
squant
,
pagedkv
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr'
,
'row'
,
't'
,
't'
,
'f'
,
'f'
,
bias
,
't'
,
squant
,
pagedkv
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr'
,
'row'
,
't'
,
't'
,
'f'
,
'f'
,
bias
,
lse
,
squant
,
pagedkv
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr'
,
'col'
,
't'
,
't'
,
'f'
,
'f'
,
bias
,
't'
,
squant
,
pagedkv
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr'
,
'col'
,
't'
,
't'
,
'f'
,
'f'
,
bias
,
lse
,
squant
,
pagedkv
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr'
,
'row'
,
't'
,
't'
,
't'
,
't'
,
bias
,
't'
,
squant
,
pagedkv
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr'
,
'row'
,
't'
,
't'
,
't'
,
't'
,
bias
,
lse
,
squant
,
pagedkv
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr'
,
'col'
,
't'
,
't'
,
't'
,
't'
,
bias
,
't'
,
squant
,
pagedkv
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr'
,
'col'
,
't'
,
't'
,
't'
,
't'
,
bias
,
lse
,
squant
,
pagedkv
,
mask
))
else
:
else
:
pipelines
.
append
(
Pipeline
(
'qr_async'
,
'row'
,
't'
,
'f'
,
't'
,
't'
,
bias
,
't'
,
squant
,
pagedkv
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr_async'
,
'row'
,
't'
,
'f'
,
't'
,
't'
,
bias
,
lse
,
squant
,
pagedkv
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr_async'
,
'row'
,
't'
,
't'
,
't'
,
't'
,
bias
,
't'
,
squant
,
pagedkv
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr_async'
,
'row'
,
't'
,
't'
,
't'
,
't'
,
bias
,
lse
,
squant
,
pagedkv
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr_async'
,
'col'
,
't'
,
'f'
,
't'
,
't'
,
bias
,
't'
,
squant
,
pagedkv
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr_async'
,
'col'
,
't'
,
'f'
,
't'
,
't'
,
bias
,
lse
,
squant
,
pagedkv
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr_async'
,
'col'
,
't'
,
't'
,
't'
,
't'
,
bias
,
't'
,
squant
,
pagedkv
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr_async'
,
'col'
,
't'
,
't'
,
't'
,
't'
,
bias
,
lse
,
squant
,
pagedkv
,
mask
))
if
receipt
==
1
:
if
receipt
==
1
:
pipelines
.
append
(
Pipeline
(
'qr'
,
'row'
,
't'
,
't'
,
't'
,
't'
,
bias
,
't'
,
squant
,
pagedkv
,
mask
))
# TODO: cover arbitraty hdim
pipelines
.
append
(
Pipeline
(
'qr'
,
'row'
,
't'
,
't'
,
't'
,
't'
,
bias
,
lse
,
squant
,
pagedkv
,
mask
))
# TODO: cover arbitraty hdim
pipelines
.
append
(
Pipeline
(
'qr'
,
'col'
,
't'
,
'f'
,
't'
,
't'
,
bias
,
't'
,
squant
,
pagedkv
,
mask
))
# TODO: cover arbitraty hdim
pipelines
.
append
(
Pipeline
(
'qr'
,
'col'
,
't'
,
'f'
,
't'
,
't'
,
bias
,
lse
,
squant
,
pagedkv
,
mask
))
# TODO: cover arbitraty hdim
elif
dtype
in
[
'fp8'
,
'bf8'
]:
elif
dtype
in
[
'fp8'
,
'bf8'
]:
for
mask
,
bias
in
itertools
.
product
(
get_mask_map
(
mask_impl
).
keys
(),
BIAS_MAP
.
keys
()):
for
mask
,
bias
in
itertools
.
product
(
get_mask_map
(
mask_impl
).
keys
(),
BIAS_MAP
.
keys
()):
pipelines
.
append
(
Pipeline
(
'qr'
,
'col'
,
'f'
,
'f'
,
'f'
,
'f'
,
bias
,
't'
,
squant
,
'f'
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr'
,
'col'
,
'f'
,
'f'
,
'f'
,
'f'
,
bias
,
't'
,
squant
,
'f'
,
mask
))
...
...
example/ck_tile/01_fmha/fmha_fwd.cpp
View file @
6c7a3bf4
...
@@ -632,12 +632,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -632,12 +632,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
auto
[
rotary_cos_host
,
rotary_sin_host
]
=
generate_rotary_cos_sin
<
KDataType
>
(
auto
[
rotary_cos_host
,
rotary_sin_host
]
=
generate_rotary_cos_sin
<
KDataType
>
(
std
::
max
(
shape_seqlen_q
,
shape_seqlen_k
),
rotary_dim
,
seed
);
std
::
max
(
shape_seqlen_q
,
shape_seqlen_k
),
rotary_dim
,
seed
);
// lse_acc_host & o_acc_host are only used when 1 < num_spilts
ck_tile
::
HostTensor
<
LSEDataType
>
lse_acc_host
(
ck_tile
::
HostTensor
<
LSEDataType
>
lse_acc_host
(
1
<
num_splits
||
use_kvcache
1
<
num_splits
?
std
::
array
<
ck_tile
::
index_t
,
4
>
{
shape_batch
,
nhead
,
num_splits
,
shape_seqlen_q
}
?
std
::
array
<
ck_tile
::
index_t
,
4
>
{
shape_batch
,
nhead
,
num_splits
,
shape_seqlen_q
}
:
std
::
array
<
ck_tile
::
index_t
,
4
>
{
1
,
1
,
1
,
1
});
:
std
::
array
<
ck_tile
::
index_t
,
4
>
{
1
,
1
,
1
,
1
});
ck_tile
::
HostTensor
<
OaccDataType
>
o_acc_host
(
ck_tile
::
HostTensor
<
OaccDataType
>
o_acc_host
(
1
<
num_splits
||
use_kvcache
?
std
::
array
<
ck_tile
::
index_t
,
5
>
{
shape_batch
,
1
<
num_splits
?
std
::
array
<
ck_tile
::
index_t
,
5
>
{
shape_batch
,
nhead
,
nhead
,
num_splits
,
num_splits
,
shape_seqlen_q
,
shape_seqlen_q
,
...
@@ -1043,9 +1044,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -1043,9 +1044,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
}
else
if
constexpr
(
std
::
is_same_v
<
fmha_fwd_splitkv_args
,
std
::
decay_t
<
decltype
(
args
)
>>
)
else
if
constexpr
(
std
::
is_same_v
<
fmha_fwd_splitkv_args
,
std
::
decay_t
<
decltype
(
args
)
>>
)
{
{
args
.
lse_acc_ptr
=
lse_acc_buf
.
GetDeviceBuffer
();
// lse_acc_buf & o_acc_buf are only used when 1 < num_spilts
args
.
o_acc_ptr
=
o_acc_buf
.
GetDeviceBuffer
();
args
.
block_table_ptr
=
args
.
block_table_ptr
=
(
0
<
page_block_size
?
block_table_buf
.
GetDeviceBuffer
()
:
nullptr
);
(
0
<
page_block_size
?
block_table_buf
.
GetDeviceBuffer
()
:
nullptr
);
args
.
batch_stride_block_table
=
batch_stride_block_table
;
args
.
batch_stride_block_table
=
batch_stride_block_table
;
...
@@ -1057,13 +1056,30 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -1057,13 +1056,30 @@ bool run(const ck_tile::ArgParser& arg_parser)
args
.
num_splits
=
num_splits
;
args
.
num_splits
=
num_splits
;
args
.
stride_o_acc
=
stride_o_acc
;
if
(
1
<
num_splits
)
{
args
.
nhead_stride_lse_acc
=
nhead_stride_lse_acc
;
args
.
lse_acc_ptr
=
lse_acc_buf
.
GetDeviceBuffer
();
args
.
nhead_stride_o_acc
=
nhead_stride_o_acc
;
args
.
o_acc_ptr
=
o_acc_buf
.
GetDeviceBuffer
();
args
.
batch_stride_lse_acc
=
batch_stride_lse_acc
;
args
.
batch_stride_o_acc
=
batch_stride_o_acc
;
args
.
stride_o_acc
=
stride_o_acc
;
args
.
split_stride_lse_acc
=
split_stride_lse_acc
;
args
.
nhead_stride_lse_acc
=
nhead_stride_lse_acc
;
args
.
split_stride_o_acc
=
split_stride_o_acc
;
args
.
nhead_stride_o_acc
=
nhead_stride_o_acc
;
args
.
batch_stride_lse_acc
=
batch_stride_lse_acc
;
args
.
batch_stride_o_acc
=
batch_stride_o_acc
;
args
.
split_stride_lse_acc
=
split_stride_lse_acc
;
args
.
split_stride_o_acc
=
split_stride_o_acc
;
}
else
{
// following attribues are ignored by fmha_fwd_splitkv()
args
.
lse_acc_ptr
=
nullptr
;
args
.
o_acc_ptr
=
nullptr
;
args
.
stride_o_acc
=
0
;
args
.
nhead_stride_lse_acc
=
0
;
args
.
nhead_stride_o_acc
=
0
;
args
.
batch_stride_lse_acc
=
0
;
args
.
batch_stride_o_acc
=
0
;
args
.
split_stride_lse_acc
=
0
;
args
.
split_stride_o_acc
=
0
;
}
}
}
}
}
};
};
...
...
example/ck_tile/01_fmha/fmha_fwd.hpp
View file @
6c7a3bf4
...
@@ -458,8 +458,8 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
...
@@ -458,8 +458,8 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
args
.
k_ptr
,
args
.
k_ptr
,
args
.
v_ptr
,
args
.
v_ptr
,
args
.
bias_ptr
,
args
.
bias_ptr
,
args
.
lse_acc_ptr
,
(
1
<
args
.
num_splits
?
args
.
lse_acc_ptr
:
args
.
lse_ptr
)
,
args
.
o_acc_ptr
,
(
1
<
args
.
num_splits
?
args
.
o_acc_ptr
:
args
.
o_ptr
)
,
args
.
batch
,
args
.
batch
,
args
.
seqlen_q
,
args
.
seqlen_q
,
args
.
seqlen_k
,
args
.
seqlen_k
,
...
@@ -479,21 +479,21 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
...
@@ -479,21 +479,21 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
args
.
stride_k
,
args
.
stride_k
,
args
.
stride_v
,
args
.
stride_v
,
args
.
stride_bias
,
args
.
stride_bias
,
args
.
stride_o_acc
,
(
1
<
args
.
num_splits
?
args
.
stride_o_acc
:
args
.
stride_o
)
,
args
.
nhead_stride_q
,
args
.
nhead_stride_q
,
args
.
nhead_stride_k
,
args
.
nhead_stride_k
,
args
.
nhead_stride_v
,
args
.
nhead_stride_v
,
args
.
nhead_stride_bias
,
args
.
nhead_stride_bias
,
args
.
nhead_stride_lse_acc
,
(
1
<
args
.
num_splits
?
args
.
nhead_stride_lse_acc
:
args
.
nhead_stride_lse
)
,
args
.
nhead_stride_o_acc
,
(
1
<
args
.
num_splits
?
args
.
nhead_stride_o_acc
:
args
.
nhead_stride_o
)
,
args
.
batch_stride_q
,
args
.
batch_stride_q
,
args
.
batch_stride_k
,
args
.
batch_stride_k
,
args
.
batch_stride_v
,
args
.
batch_stride_v
,
args
.
batch_stride_bias
,
args
.
batch_stride_bias
,
args
.
batch_stride_lse_acc
,
(
1
<
args
.
num_splits
?
args
.
batch_stride_lse_acc
:
args
.
batch_stride_lse
)
,
args
.
batch_stride_o_acc
,
(
1
<
args
.
num_splits
?
args
.
batch_stride_o_acc
:
args
.
batch_stride_o
)
,
args
.
split_stride_lse_acc
,
(
1
<
args
.
num_splits
?
args
.
split_stride_lse_acc
:
0
)
,
args
.
split_stride_o_acc
,
(
1
<
args
.
num_splits
?
args
.
split_stride_o_acc
:
0
)
,
args
.
window_size_left
,
args
.
window_size_left
,
args
.
window_size_right
,
args
.
window_size_right
,
args
.
mask_type
);
args
.
mask_type
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment