Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
4b3474e4
Commit
4b3474e4
authored
Dec 23, 2024
by
Po Yen Chen
Browse files
Set kHasUnevenSplits=false if num_splits = 1
parent
346ba760
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
7 deletions
+17
-7
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
+17
-7
No files found.
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
View file @
4b3474e4
...
@@ -47,7 +47,7 @@ using fmha_dtype_{F_idx} = {F_dtype};
...
@@ -47,7 +47,7 @@ using fmha_dtype_{F_idx} = {F_dtype};
using fmha_mask_{F_idx} = {F_mask};
using fmha_mask_{F_idx} = {F_mask};
namespace {{
namespace {{
template <bool kHasUnevenSplits
, bool
kIsMultipleSplits>
template <bool
kIsMultipleSplits, bool
kHasUnevenSplits
=
kIsMultipleSplits>
struct kernel_runner {{
struct kernel_runner {{
using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
using fmha_warp_tile = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>;
using fmha_warp_tile = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>;
...
@@ -68,7 +68,7 @@ using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad},
...
@@ -68,7 +68,7 @@ using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad},
{F_lse},
{F_lse},
{F_squant},
{F_squant},
{F_pagedkv},
{F_pagedkv},
kHasUnevenSplits,
kIsMultipleSplits &&
kHasUnevenSplits,
{F_occupancy}>;
{F_occupancy}>;
using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem<
using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem<
...
@@ -131,12 +131,23 @@ using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F
...
@@ -131,12 +131,23 @@ using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F
template<>
template<>
void fmha_fwd_splitkv_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
void fmha_fwd_splitkv_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{
{{
/// NOTICE: kHasUnevenSplits=false may be able to speed-up the batch mode kernel,
/// but we use kHasUnevenSplits=true here to reduce compilation time
if (1 < a.num_splits) {{
if (1 < a.num_splits) {{
kernel_runner</*kHasUnevenSplits=*/true, /*kIsMultipleSplits=*/true>::run(s, a);
constexpr bool kIsMultipleSplits = true;
if constexpr({F_mode} == false) {{ // batch mode
// we don't check every seqlen_k values for kvcache
if (a.seqlen_k_ptr != nullptr) {{
kernel_runner<kIsMultipleSplits, /*kHasUnevenSplits=*/true>::run(s, a);
// make sure F_bn0 is divisible by F_bk1
}} else if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{
kernel_runner<kIsMultipleSplits, /*kHasUnevenSplits=*/false>::run(s, a);
}} else {{
kernel_runner<kIsMultipleSplits, /*kHasUnevenSplits=*/true>::run(s, a);
}}
}} else {{ // group mode
kernel_runner<kIsMultipleSplits, /*kHasUnevenSplits=*/true>::run(s, a);
}}
}} else {{
}} else {{
kernel_runner<
/*kHasUnevenSplits=*/true,
/*kIsMultipleSplits=*/false>::run(s, a);
kernel_runner</*kIsMultipleSplits=*/false>::run(s, a);
}}
}}
}}
}}
...
@@ -659,7 +670,6 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
...
@@ -659,7 +670,6 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
for
mask
,
bias
,
lse
,
pagedkv
in
itertools
.
product
(
get_mask_map
(
mask_impl
).
keys
(),
BIAS_MAP
.
keys
(),
[
"t"
,
"f"
],
[
"t"
,
"f"
]):
for
mask
,
bias
,
lse
,
pagedkv
in
itertools
.
product
(
get_mask_map
(
mask_impl
).
keys
(),
BIAS_MAP
.
keys
(),
[
"t"
,
"f"
],
[
"t"
,
"f"
]):
# TODO: use async pipeline when compiler is more stable
# TODO: use async pipeline when compiler is more stable
if
hdim
==
256
or
hdim
in
[
32
,
64
,
128
]:
### [32, 64, 96, 128]:
if
hdim
==
256
or
hdim
in
[
32
,
64
,
128
]:
### [32, 64, 96, 128]:
# if True:
pipelines
.
append
(
Pipeline
(
'qr'
,
'row'
,
'f'
,
't'
,
'f'
,
'f'
,
bias
,
lse
,
squant
,
pagedkv
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr'
,
'row'
,
'f'
,
't'
,
'f'
,
'f'
,
bias
,
lse
,
squant
,
pagedkv
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr'
,
'col'
,
'f'
,
't'
,
'f'
,
'f'
,
bias
,
lse
,
squant
,
pagedkv
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr'
,
'col'
,
'f'
,
't'
,
'f'
,
'f'
,
bias
,
lse
,
squant
,
pagedkv
,
mask
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment