Commit c5083c0f authored by Po Yen Chen
Browse files

Merge branch 'feature/add-splitkv-instance' into...

Merge branch 'feature/add-splitkv-instance' into feature/support-vllm-kcache-layout-add-splitkv-instance
parents 0739bc5a 3f29f232
...@@ -47,7 +47,7 @@ using fmha_dtype_{F_idx} = {F_dtype}; ...@@ -47,7 +47,7 @@ using fmha_dtype_{F_idx} = {F_dtype};
using fmha_mask_{F_idx} = {F_mask}; using fmha_mask_{F_idx} = {F_mask};
namespace {{ namespace {{
template <bool kHasUnevenSplits, bool kIsMultipleSplits> template <bool kIsMultipleSplits, bool kHasUnevenSplits = kIsMultipleSplits>
struct kernel_runner {{ struct kernel_runner {{
using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
using fmha_warp_tile = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>; using fmha_warp_tile = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>;
...@@ -68,7 +68,7 @@ using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad}, ...@@ -68,7 +68,7 @@ using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad},
{F_lse}, {F_lse},
{F_squant}, {F_squant},
{F_pagedkv}, {F_pagedkv},
kHasUnevenSplits, kIsMultipleSplits && kHasUnevenSplits,
{F_occupancy}>; {F_occupancy}>;
using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem< using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem<
...@@ -131,12 +131,23 @@ using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F ...@@ -131,12 +131,23 @@ using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F
template<> template<>
void fmha_fwd_splitkv_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) void fmha_fwd_splitkv_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{ {{
/// NOTICE: kHasUnevenSplits=false may be able to speed-up the batch mode kernel,
/// but we use kHasUnevenSplits=true here to reduce compilation time
if (1 < a.num_splits) {{ if (1 < a.num_splits) {{
kernel_runner</*kHasUnevenSplits=*/true, /*kIsMultipleSplits=*/true>::run(s, a); constexpr bool kIsMultipleSplits = true;
if constexpr({F_mode} == false) {{ // batch mode
// we don't check every seqlen_k values for kvcache
if (a.seqlen_k_ptr != nullptr) {{
kernel_runner<kIsMultipleSplits, /*kHasUnevenSplits=*/true>::run(s, a);
// make sure F_bn0 is divisible by F_bk1
}} else if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{
kernel_runner<kIsMultipleSplits, /*kHasUnevenSplits=*/false>::run(s, a);
}} else {{
kernel_runner<kIsMultipleSplits, /*kHasUnevenSplits=*/true>::run(s, a);
}}
}} else {{ // group mode
kernel_runner<kIsMultipleSplits, /*kHasUnevenSplits=*/true>::run(s, a);
}}
}} else {{ }} else {{
kernel_runner</*kHasUnevenSplits=*/true, /*kIsMultipleSplits=*/false>::run(s, a); kernel_runner</*kIsMultipleSplits=*/false>::run(s, a);
}} }}
}} }}
...@@ -658,10 +669,13 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> ...@@ -658,10 +669,13 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
if dtype in ['fp16', 'bf16']: if dtype in ['fp16', 'bf16']:
for mask, bias, lse, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]): for mask, bias, lse, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
# TODO: use async pipeline when compiler is more stable # TODO: use async pipeline when compiler is more stable
if hdim == 256: if hdim == 256 or hdim in [32, 64, 128]: ### [32, 64, 96, 128]:
pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask)) pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask)) pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'row', 't', 't', 'f', 'f', bias, lse, squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'col', 't', 't', 'f', 'f', bias, lse, squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
else: else:
......
...@@ -402,27 +402,31 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -402,27 +402,31 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
} }
move_tile_window(bias_dram_window, {0, kN0}); move_tile_window(bias_dram_window, {0, kN0});
/// TODO: only check in first/last iteration without increasing code size // only check in first/last iterations
if constexpr(kHasUnevenSplits) if constexpr(kHasUnevenSplits)
{ {
const auto k_origin = k_page_block_navigator.to_global_window_origin( if(1 < num_splits && (i_total_loops == 0 || i_total_loops == num_total_loop - 1))
i_page_block_k, k_dram_block_window.get_window_origin()); {
set_tile_if( const auto k_origin = k_page_block_navigator.to_global_window_origin(
s_acc, i_page_block_k, k_dram_block_window.get_window_origin());
-numeric<SMPLComputeDataType>::infinity(), set_tile_if(s_acc,
[&, -numeric<SMPLComputeDataType>::infinity(),
physical_seqlen_k_start_ = physical_seqlen_k_start, [&,
physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) { physical_seqlen_k_start_ = physical_seqlen_k_start,
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) {
if constexpr(kIsPagedKV) const auto col =
{ k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return col < physical_seqlen_k_start_ || physical_seqlen_k_end_ <= col; if constexpr(kIsPagedKV)
} {
else return col < physical_seqlen_k_start_ ||
{ physical_seqlen_k_end_ <= col;
return physical_seqlen_k_end_ <= col; }
} else
}); {
return physical_seqlen_k_end_ <= col;
}
});
}
} }
if constexpr(kPadSeqLenK || FmhaMask::IsMasking) if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
......
...@@ -102,7 +102,7 @@ struct BlockFmhaFwdSplitKVPipelineProblem ...@@ -102,7 +102,7 @@ struct BlockFmhaFwdSplitKVPipelineProblem
static constexpr bool kStoreLSE = Traits::kStoreLSE; static constexpr bool kStoreLSE = Traits::kStoreLSE;
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant; static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
static constexpr bool kIsPagedKV = Traits::kIsPagedKV; static constexpr bool kIsPagedKV = Traits::kIsPagedKV;
static constexpr bool kHasUnevenSplits = kIsGroupMode || Traits::kHasUnevenSplits; static constexpr bool kHasUnevenSplits = Traits::kHasUnevenSplits;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment