Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
1f6b546e
Commit
1f6b546e
authored
Jan 01, 2025
by
Po Yen Chen
Browse files
Sync fmha fwd splitkv codegen logics
parent
7cd4e574
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
127 additions
and
68 deletions
+127
-68
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
+127
-68
No files found.
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
View file @
1f6b546e
...
@@ -12,9 +12,9 @@ from typing import List, Optional, Tuple, Union
...
@@ -12,9 +12,9 @@ from typing import List, Optional, Tuple, Union
from
codegen.cmake_config
import
*
from
codegen.cmake_config
import
*
from
codegen.cpp_symbol_map
import
*
from
codegen.cpp_symbol_map
import
*
import
codegen.ops.fmha_fwd
from
codegen.ops.fmha_fwd
import
(
from
codegen.ops.fmha_fwd
import
(
FmhaFwdTileSize
,
FmhaFwdTileSize
,
FmhaFwdApiTrait
,
FMHA_FWD_KERNEL_HEADER
,
FMHA_FWD_KERNEL_HEADER
,
FMHA_FWD_API_PER_DTYPE
,
FMHA_FWD_API_PER_DTYPE
,
FMHA_FWD_API_PER_HDIM_CASE
,
FMHA_FWD_API_PER_HDIM_CASE
,
...
@@ -48,7 +48,7 @@ using fmha_dtype_{F_idx} = {F_dtype};
...
@@ -48,7 +48,7 @@ using fmha_dtype_{F_idx} = {F_dtype};
using fmha_mask_{F_idx} = {F_mask};
using fmha_mask_{F_idx} = {F_mask};
namespace {{
namespace {{
template <bool k
HasUneven
Splits>
template <bool k
IsMultipleSplits, bool kHasUnevenSplits = kIsMultiple
Splits>
struct kernel_runner {{
struct kernel_runner {{
using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
...
@@ -68,7 +68,7 @@ using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad},
...
@@ -68,7 +68,7 @@ using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad},
{F_lse},
{F_lse},
{F_squant},
{F_squant},
{F_pagedkv},
{F_pagedkv},
kHasUnevenSplits,
kIsMultipleSplits &&
kHasUnevenSplits,
{F_occupancy}>;
{F_occupancy}>;
using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem<
using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem<
...
@@ -81,7 +81,11 @@ using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem<
...
@@ -81,7 +81,11 @@ using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
std::conditional_t<
kIsMultipleSplits,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType
>,
fmha_shape,
fmha_shape,
{F_mode},
{F_mode},
fmha_mask_{F_idx},
fmha_mask_{F_idx},
...
@@ -90,10 +94,17 @@ using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem<
...
@@ -90,10 +94,17 @@ using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem<
using fmha_pipeline = {F_pipeline}<
using fmha_pipeline = {F_pipeline}<
fmha_pipeline_problem>;
fmha_pipeline_problem>;
/// FIXME: use {F_spad}/{F_dvpad} as kPadM/kPadN parameters after solving
/// store_tile_raw() data corruption issue
using fmha_epilogue =
using fmha_epilogue =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<
typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
{F_spad}, {F_dvpad}>>;
std::conditional_t<
kIsMultipleSplits,
typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
typename FmhaFwdTypeConfig<{F_dtype}>::ODataType
>,
false, false>>;
using fmha_kernel =
using fmha_kernel =
ck_tile::FmhaFwdSplitKVKernel<fmha_pipeline, fmha_epilogue>;
ck_tile::FmhaFwdSplitKVKernel<fmha_pipeline, fmha_epilogue>;
...
@@ -118,25 +129,30 @@ using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F
...
@@ -118,25 +129,30 @@ using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F
template<>
template<>
void fmha_fwd_splitkv_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
void fmha_fwd_splitkv_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{
{{
if constexpr({F_mode} == false) {{ // batch mode
if (1 < a.num_splits) {{
// we don't check every seqlen_k values for kvcache
constexpr bool kIsMultipleSplits = true;
if (a.seqlen_k_ptr != nullptr) {{
if constexpr({F_mode} == false) {{ // batch mode
kernel_runner<true>::run(s, a);
// we don't check every seqlen_k values for kvcache
// make sure F_bn0 is divisible by F_bk1
if (a.seqlen_k_ptr != nullptr) {{
}} else if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{
kernel_runner<kIsMultipleSplits, /*kHasUnevenSplits=*/true>::run(s, a);
kernel_runner<false>::run(s, a);
// make sure F_bn0 is divisible by F_bk1
}} else {{
}} else if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{
kernel_runner<true>::run(s, a);
kernel_runner<kIsMultipleSplits, /*kHasUnevenSplits=*/false>::run(s, a);
}} else {{
kernel_runner<kIsMultipleSplits, /*kHasUnevenSplits=*/true>::run(s, a);
}}
}} else {{ // group mode
kernel_runner<kIsMultipleSplits, /*kHasUnevenSplits=*/true>::run(s, a);
}}
}}
}} else {{
}} else {{
kernel_runner<
tru
e>::run(s, a);
kernel_runner<
/*kIsMultipleSplits=*/fals
e>::run(s, a);
}}
}}
}}
}}
template<>
template<>
std::string fmha_fwd_splitkv_get_name_<trait_{F_idx}>()
std::string fmha_fwd_splitkv_get_name_<trait_{F_idx}>()
{{
{{
using k_ = kernel_runner<true>::fmha_kernel; /// FIXME: choose real kernel type
using k_ = kernel_runner<
true,
true>::fmha_kernel; /// FIXME: choose real kernel type
return k_::GetName();
return k_::GetName();
}}
}}
"""
"""
...
@@ -220,19 +236,32 @@ FMHA_FWD_SPLITKV_API_FILENAME="fmha_fwd_splitkv_api.cpp"
...
@@ -220,19 +236,32 @@ FMHA_FWD_SPLITKV_API_FILENAME="fmha_fwd_splitkv_api.cpp"
FMHA_FWD_SPLITKV_API
=
"""
FMHA_FWD_SPLITKV_API
=
"""
#include <iostream>
#include <iostream>
template<typename fmha_fwd_splitkv_traits_, typename fmha_fwd_splitkv_combine_traits_>
template<typename fmha_fwd_splitkv_traits_, typename fmha_fwd_splitkv_combine_traits_
= void
>
float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{
{{
if(s.log_level_ > 0)
// fmha_fwd_splitkv_combine_traits_=void, launch splitkv kernel only
std::cout
if constexpr (std::is_same_v<fmha_fwd_splitkv_combine_traits_, void>) {{
<< ", " << fmha_fwd_splitkv_get_name_<fmha_fwd_splitkv_traits_>()
if(s.log_level_ > 0)
<< ", " << fmha_fwd_splitkv_combine_get_name_<fmha_fwd_splitkv_combine_traits_>()
std::cout
<< std::flush;
<< ", " << fmha_fwd_splitkv_get_name_<fmha_fwd_splitkv_traits_>()
<< std::flush;
return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_<fmha_fwd_splitkv_traits_>(s_, a); }},
return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_<fmha_fwd_splitkv_combine_traits_>(s_, a); }}
[=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_<fmha_fwd_splitkv_traits_>(s_, a); }}
);
);
// launch both splitkv & combine kernels
}} else {{
if(s.log_level_ > 0)
std::cout
<< ", " << fmha_fwd_splitkv_get_name_<fmha_fwd_splitkv_traits_>()
<< ", " << fmha_fwd_splitkv_combine_get_name_<fmha_fwd_splitkv_combine_traits_>()
<< std::flush;
return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_<fmha_fwd_splitkv_traits_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_<fmha_fwd_splitkv_combine_traits_>(s_, a); }}
);
}}
}}
}}
float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const ck_tile::stream_config& s){{
float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const ck_tile::stream_config& s){{
...
@@ -244,28 +273,45 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const
...
@@ -244,28 +273,45 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const
FMHA_FWD_SPLITKV_API_INNER_DISPATCH
=
""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) &&
FMHA_FWD_SPLITKV_API_INNER_DISPATCH
=
""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) &&
((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
if (1 < a.num_splits) {{
// get combine kernel tile sizes
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
using OaccDataType = typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType;
constexpr ck_tile::index_t kM0 = ck_tile::BlockFmhaSplitKVCombinePipelineTileSizes<OaccDataType, /*F_bn1=*/32>::kM0;
// get combine kernel tile sizes
using OaccDataType = typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType;
// make sure we can reuse the padding flags in combine kernels
constexpr ck_tile::index_t kM0 = ck_tile::BlockFmhaSplitKVCombinePipelineTileSizes<OaccDataType, /*F_bn1=*/32>::kM0;
static_assert({F_bm0} % kM0 == 0);
static_assert({F_bn1} % 32 == 0);
// make sure we can reuse the padding flags in combine kernels
static_assert({F_bm0} % kM0 == 0);
if (t.has_lse) {{
static_assert({F_bn1} % 32 == 0);
if constexpr (std::is_same_v<{F_dtype}, FmhaFwdFp8>) {{
return -1;
if (t.has_lse) {{
if constexpr (std::is_same_v<{F_dtype}, FmhaFwdFp8>) {{
return -1;
}} else {{
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/32, true, {F_squant}, {F_spad}, {F_dvpad}>;
return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
}}
}} else {{
}} else {{
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/32,
tru
e, {F_squant}, {F_spad}, {F_dvpad}>;
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/32,
fals
e, {F_squant}, {F_spad}, {F_dvpad}>;
return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
}}
}}
}} else {{
}} else {{
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/32, false, {F_squant}, {F_spad}, {F_dvpad}>;
if (t.has_lse) {{
if constexpr (std::is_same_v<{F_dtype}, FmhaFwdFp8>) {{
return -1;
}} else {{
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
return fmha_fwd_splitkv_<traits_>(s, a);
}}
}} else {{
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, false, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
return fmha_fwd_splitkv_<traits_>(s, a);
}}
}}
}}
}}
}}
"""
"""
...
@@ -605,7 +651,7 @@ def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[d
...
@@ -605,7 +651,7 @@ def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[d
### '96' : FmhaFwdSplitKVCombineTileSize(32, -1),
### '96' : FmhaFwdSplitKVCombineTileSize(32, -1),
'128'
:
FmhaFwdSplitKVCombineTileSize
(
32
,
-
1
),
'128'
:
FmhaFwdSplitKVCombineTileSize
(
32
,
-
1
),
'256'
:
FmhaFwdSplitKVCombineTileSize
(
32
,
-
1
),
'256'
:
FmhaFwdSplitKVCombineTileSize
(
32
,
-
1
),
}
}
elif
dtype
==
'fp8'
or
dtype
==
'bf8'
:
elif
dtype
==
'fp8'
or
dtype
==
'bf8'
:
return
{
return
{
'64'
:
FmhaFwdSplitKVCombineTileSize
(
32
,
-
1
),
'64'
:
FmhaFwdSplitKVCombineTileSize
(
32
,
-
1
),
...
@@ -629,26 +675,28 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
...
@@ -629,26 +675,28 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
squant
=
't'
if
dtype
==
'fp8'
else
'f'
squant
=
't'
if
dtype
==
'fp8'
else
'f'
pipelines
=
[]
pipelines
=
[]
if
dtype
in
[
'fp16'
,
'bf16'
]:
if
dtype
in
[
'fp16'
,
'bf16'
]:
for
mask
,
bias
,
pagedkv
in
itertools
.
product
(
get_mask_map
(
mask_impl
).
keys
(),
BIAS_MAP
.
keys
(),
[
"t"
,
"f"
]):
for
mask
,
bias
,
lse
,
pagedkv
in
itertools
.
product
(
get_mask_map
(
mask_impl
).
keys
(),
BIAS_MAP
.
keys
(),
[
"t"
,
"f"
],
[
"t"
,
"f"
]):
# TODO: use async pipeline when compiler is more stable
# TODO: use async pipeline when compiler is more stable
if
hdim
==
256
or
hdim
in
[
32
,
64
,
128
]:
### [32, 64, 96, 128]:
if
hdim
==
256
or
hdim
in
[
32
,
64
,
128
]:
### [32, 64, 96, 128]:
# if True:
pipelines
.
append
(
Pipeline
(
'qr'
,
'row'
,
'f'
,
't'
,
'f'
,
'f'
,
bias
,
lse
,
squant
,
pagedkv
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr'
,
'row'
,
'f'
,
't'
,
'f'
,
'f'
,
bias
,
't'
,
squant
,
pagedkv
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr'
,
'col'
,
'f'
,
't'
,
'f'
,
'f'
,
bias
,
lse
,
squant
,
pagedkv
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr'
,
'col'
,
'f'
,
't'
,
'f'
,
'f'
,
bias
,
't'
,
squant
,
pagedkv
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr'
,
'row'
,
't'
,
't'
,
'f'
,
'f'
,
bias
,
lse
,
squant
,
pagedkv
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr'
,
'col'
,
't'
,
't'
,
'f'
,
'f'
,
bias
,
lse
,
squant
,
pagedkv
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr'
,
'row'
,
't'
,
't'
,
't'
,
't'
,
bias
,
't'
,
squant
,
pagedkv
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr'
,
'row'
,
't'
,
't'
,
't'
,
't'
,
bias
,
lse
,
squant
,
pagedkv
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr'
,
'col'
,
't'
,
't'
,
't'
,
't'
,
bias
,
't'
,
squant
,
pagedkv
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr'
,
'col'
,
't'
,
't'
,
't'
,
't'
,
bias
,
lse
,
squant
,
pagedkv
,
mask
))
else
:
else
:
pipelines
.
append
(
Pipeline
(
'qr_async'
,
'row'
,
't'
,
'f'
,
't'
,
't'
,
bias
,
't'
,
squant
,
pagedkv
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr_async'
,
'row'
,
't'
,
'f'
,
't'
,
't'
,
bias
,
lse
,
squant
,
pagedkv
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr_async'
,
'row'
,
't'
,
't'
,
't'
,
't'
,
bias
,
't'
,
squant
,
pagedkv
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr_async'
,
'row'
,
't'
,
't'
,
't'
,
't'
,
bias
,
lse
,
squant
,
pagedkv
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr_async'
,
'col'
,
't'
,
'f'
,
't'
,
't'
,
bias
,
't'
,
squant
,
pagedkv
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr_async'
,
'col'
,
't'
,
'f'
,
't'
,
't'
,
bias
,
lse
,
squant
,
pagedkv
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr_async'
,
'col'
,
't'
,
't'
,
't'
,
't'
,
bias
,
't'
,
squant
,
pagedkv
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr_async'
,
'col'
,
't'
,
't'
,
't'
,
't'
,
bias
,
lse
,
squant
,
pagedkv
,
mask
))
if
receipt
==
1
:
if
receipt
==
1
:
pipelines
.
append
(
Pipeline
(
'qr'
,
'row'
,
't'
,
't'
,
't'
,
't'
,
bias
,
't'
,
squant
,
pagedkv
,
mask
))
# TODO: cover arbitraty hdim
pipelines
.
append
(
Pipeline
(
'qr'
,
'row'
,
't'
,
't'
,
't'
,
't'
,
bias
,
lse
,
squant
,
pagedkv
,
mask
))
# TODO: cover arbitraty hdim
pipelines
.
append
(
Pipeline
(
'qr'
,
'col'
,
't'
,
'f'
,
't'
,
't'
,
bias
,
't'
,
squant
,
pagedkv
,
mask
))
# TODO: cover arbitraty hdim
pipelines
.
append
(
Pipeline
(
'qr'
,
'col'
,
't'
,
'f'
,
't'
,
't'
,
bias
,
lse
,
squant
,
pagedkv
,
mask
))
# TODO: cover arbitraty hdim
elif
dtype
in
[
'fp8'
,
'bf8'
]:
elif
dtype
in
[
'fp8'
,
'bf8'
]:
for
mask
,
bias
in
itertools
.
product
(
get_mask_map
(
mask_impl
).
keys
(),
BIAS_MAP
.
keys
()):
for
mask
,
bias
,
lse
in
itertools
.
product
(
get_mask_map
(
mask_impl
).
keys
(),
BIAS_MAP
.
keys
()
,
[
't'
,
'f'
]
):
pipelines
.
append
(
Pipeline
(
'qr'
,
'col'
,
'f'
,
'f'
,
'f'
,
'f'
,
bias
,
't'
,
squant
,
'f'
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr'
,
'col'
,
'f'
,
'f'
,
'f'
,
'f'
,
bias
,
lse
,
squant
,
'f'
,
mask
))
elif
dtype
in
[
'fp8fp16'
,
'fp8bf16'
]:
elif
dtype
in
[
'fp8fp16'
,
'fp8bf16'
]:
# TODO
# TODO
None
None
...
@@ -660,18 +708,27 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
...
@@ -660,18 +708,27 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
api_pool
=
FmhaFwdSplitKVApiPool
(
mask_impl
)
api_pool
=
FmhaFwdSplitKVApiPool
(
mask_impl
)
for
dtype
in
FWD_DTYPE_MAP
.
keys
():
for
dtype
in
FWD_DTYPE_MAP
.
keys
():
d
=
get_fmha_fwd_tile_dict_from_dtype
(
dtype
)
prefill_tiles
=
codegen
.
ops
.
fmha_fwd
.
get_fmha_fwd_tile_dict_from_dtype
(
dtype
)
if
d
==
None
:
decode_tiles
=
get_fmha_fwd_tile_dict_from_dtype
(
dtype
)
if
decode_tiles
==
None
:
continue
continue
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
for
hdim_str
,
mode
in
itertools
.
product
(
d
.
keys
(),
MODE_MAP
.
keys
()):
# make sure if all the hdim str keys in decode_tiles are also available in prefill_tiles
tile
=
d
[
hdim_str
]
assert
all
(
tile
in
prefill_tiles
.
keys
()
for
tile
in
decode_tiles
.
keys
())
for
hdim_str
,
mode
in
itertools
.
product
(
decode_tiles
.
keys
(),
MODE_MAP
.
keys
()):
prefill_tile
=
prefill_tiles
[
hdim_str
]
decode_tile
=
decode_tiles
[
hdim_str
]
hdim
=
int
(
hdim_str
)
hdim
=
int
(
hdim_str
)
for
pipeline
in
get_pipelines
(
dtype
,
hdim
):
for
pipeline
in
get_pipelines
(
dtype
,
hdim
):
if
mode
==
"group"
:
if
mode
==
"group"
:
if
pipeline
.
F_spad
!=
't'
or
pipeline
.
F_skpad
!=
't'
:
if
pipeline
.
F_spad
!=
't'
or
pipeline
.
F_skpad
!=
't'
:
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
continue
continue
is_prefill
=
(
dtype
in
[
'fp16'
,
'bf16'
]
and
mode
==
"group"
and
pipeline
.
F_pagedkv
==
't'
)
tile
=
prefill_tile
if
is_prefill
else
decode_tile
k
=
Kernel
(
F_idx
=
0
,
k
=
Kernel
(
F_idx
=
0
,
F_hdim
=
hdim
,
F_hdim
=
hdim
,
F_dtype
=
dtype
,
F_dtype
=
dtype
,
...
@@ -683,8 +740,11 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
...
@@ -683,8 +740,11 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
if
not
fnmatch
.
fnmatch
(
k
.
name
,
kernel_filter
):
if
not
fnmatch
.
fnmatch
(
k
.
name
,
kernel_filter
):
continue
continue
if
receipt
==
2
:
if
receipt
==
2
:
is_chunked_prefill
=
(
mode
==
'group'
and
pipeline
.
F_pagedkv
==
't'
)
cond
=
dtype
in
[
'fp16'
,
'bf16'
]
cond
=
dtype
in
[
'fp16'
,
'bf16'
]
cond
&=
pipeline
.
F_vlayout
==
'row'
# use vlayout=row for chunked prefill
cond
=
cond
and
((
pipeline
.
F_vlayout
==
'row'
and
not
is_chunked_prefill
)
or
(
pipeline
.
F_vlayout
==
'col'
and
is_chunked_prefill
))
cond
&=
pipeline
.
F_bias
in
[
'no'
,
'alibi'
]
cond
&=
pipeline
.
F_bias
in
[
'no'
,
'alibi'
]
cond
&=
pipeline
.
F_squant
==
'f'
cond
&=
pipeline
.
F_squant
==
'f'
if
not
cond
:
if
not
cond
:
...
@@ -723,12 +783,11 @@ def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt) -> Lis
...
@@ -723,12 +783,11 @@ def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt) -> Lis
d
=
get_fmha_fwd_splitkv_combine_tile_dict_from_dtype
(
dtype
)
d
=
get_fmha_fwd_splitkv_combine_tile_dict_from_dtype
(
dtype
)
if
d
==
None
:
if
d
==
None
:
continue
continue
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
for
hdim_str
,
mode
in
itertools
.
product
(
d
.
keys
(),
MODE_MAP
.
keys
()):
for
hdim_str
,
mode
in
itertools
.
product
(
d
.
keys
(),
MODE_MAP
.
keys
()):
tile
=
d
[
hdim_str
]
tile
=
d
[
hdim_str
]
hdim
=
int
(
hdim_str
)
hdim
=
int
(
hdim_str
)
for
pipeline
in
get_pipelines
(
dtype
,
hdim
):
for
pipeline
in
get_pipelines
(
dtype
,
hdim
):
if
mode
==
"
group
"
:
if
mode
==
'
group
'
:
if
pipeline
.
F_spad
!=
't'
:
if
pipeline
.
F_spad
!=
't'
:
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
continue
continue
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment