Commit 4b3474e4 authored by Po Yen Chen's avatar Po Yen Chen
Browse files

Set kHasUnevenSplits=false if num_splits = 1

parent 346ba760
...@@ -47,7 +47,7 @@ using fmha_dtype_{F_idx} = {F_dtype}; ...@@ -47,7 +47,7 @@ using fmha_dtype_{F_idx} = {F_dtype};
using fmha_mask_{F_idx} = {F_mask}; using fmha_mask_{F_idx} = {F_mask};
namespace {{ namespace {{
template <bool kHasUnevenSplits, bool kIsMultipleSplits> template <bool kIsMultipleSplits, bool kHasUnevenSplits = kIsMultipleSplits>
struct kernel_runner {{ struct kernel_runner {{
using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
using fmha_warp_tile = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>; using fmha_warp_tile = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>;
...@@ -68,7 +68,7 @@ using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad}, ...@@ -68,7 +68,7 @@ using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad},
{F_lse}, {F_lse},
{F_squant}, {F_squant},
{F_pagedkv}, {F_pagedkv},
kHasUnevenSplits, kIsMultipleSplits && kHasUnevenSplits,
{F_occupancy}>; {F_occupancy}>;
using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem< using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem<
...@@ -131,12 +131,23 @@ using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F ...@@ -131,12 +131,23 @@ using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F
template<> template<>
void fmha_fwd_splitkv_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) void fmha_fwd_splitkv_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{ {{
/// NOTICE: kHasUnevenSplits=false may be able to speed-up the batch mode kernel,
/// but we use kHasUnevenSplits=true here to reduce compilation time
if (1 < a.num_splits) {{ if (1 < a.num_splits) {{
kernel_runner</*kHasUnevenSplits=*/true, /*kIsMultipleSplits=*/true>::run(s, a); constexpr bool kIsMultipleSplits = true;
if constexpr({F_mode} == false) {{ // batch mode
// we don't check every seqlen_k values for kvcache
if (a.seqlen_k_ptr != nullptr) {{
kernel_runner<kIsMultipleSplits, /*kHasUnevenSplits=*/true>::run(s, a);
// make sure F_bn0 is divisible by F_bk1
}} else if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{
kernel_runner<kIsMultipleSplits, /*kHasUnevenSplits=*/false>::run(s, a);
}} else {{ }} else {{
kernel_runner</*kHasUnevenSplits=*/true, /*kIsMultipleSplits=*/false>::run(s, a); kernel_runner<kIsMultipleSplits, /*kHasUnevenSplits=*/true>::run(s, a);
}}
}} else {{ // group mode
kernel_runner<kIsMultipleSplits, /*kHasUnevenSplits=*/true>::run(s, a);
}}
}} else {{
kernel_runner</*kIsMultipleSplits=*/false>::run(s, a);
}} }}
}} }}
...@@ -659,7 +670,6 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> ...@@ -659,7 +670,6 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
for mask, bias, lse, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]): for mask, bias, lse, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
# TODO: use async pipeline when compiler is more stable # TODO: use async pipeline when compiler is more stable
if hdim == 256 or hdim in [32, 64, 128]: ### [32, 64, 96, 128]: if hdim == 256 or hdim in [32, 64, 128]: ### [32, 64, 96, 128]:
# if True:
pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask)) pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask)) pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment