Commit c5083c0f authored by Po Yen Chen
Browse files

Merge branch 'feature/add-splitkv-instance' into...

Merge branch 'feature/add-splitkv-instance' into feature/support-vllm-kcache-layout-add-splitkv-instance
parents 0739bc5a 3f29f232
......@@ -47,7 +47,7 @@ using fmha_dtype_{F_idx} = {F_dtype};
using fmha_mask_{F_idx} = {F_mask};
namespace {{
template <bool kHasUnevenSplits, bool kIsMultipleSplits>
template <bool kIsMultipleSplits, bool kHasUnevenSplits = kIsMultipleSplits>
struct kernel_runner {{
using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
using fmha_warp_tile = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>;
......@@ -68,7 +68,7 @@ using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad},
{F_lse},
{F_squant},
{F_pagedkv},
kHasUnevenSplits,
kIsMultipleSplits && kHasUnevenSplits,
{F_occupancy}>;
using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem<
......@@ -131,12 +131,23 @@ using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F
// Explicit specialization of the split-KV one-shot launcher for this trait.
// It selects a kernel_runner instantiation from a.num_splits and, in batch mode,
// from a.seqlen_k_ptr and a.seqlen_k, then forwards (s, a) to that runner's run().
// NOTE(review): the doubled braces and F_* placeholders suggest this text lives inside
// a Python format-string template rather than a plain C++ file; confirm before adding
// any literal curly braces here.
// NOTE(review): this view appears to interleave pre-change and post-change lines of a
// diff (two calls label their arguments kHasUnevenSplits-first, the others
// kIsMultipleSplits-first); confirm which lines are live before relying on it.
template<>
void fmha_fwd_splitkv_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{
/// NOTICE: kHasUnevenSplits=false may be able to speed-up the batch mode kernel,
/// but we use kHasUnevenSplits=true here to reduce compilation time
// NOTE(review): the batch-mode branch below does pick kHasUnevenSplits=false in one
// case, so the notice above may be stale; confirm.
if (1 < a.num_splits) {{
// NOTE(review): argument labels on the next call are in the opposite order from the
// other calls in this function; looks like a leftover pre-change line, verify.
kernel_runner</*kHasUnevenSplits=*/true, /*kIsMultipleSplits=*/true>::run(s, a);
// More than one split was requested, so every runner chosen below is the
// multiple-splits variant.
constexpr bool kIsMultipleSplits = true;
// F_mode is substituted at generation time; the false case is treated as batch mode
// and resolved at compile time.
if constexpr({F_mode} == false) {{ // batch mode
// we don't check every seqlen_k values for kvcache
// When seqlen_k_ptr is set, the lengths it points to are not inspected here, so the
// uneven-splits variant is used unconditionally.
if (a.seqlen_k_ptr != nullptr) {{
kernel_runner<kIsMultipleSplits, /*kHasUnevenSplits=*/true>::run(s, a);
// make sure F_bn0 is divisible by F_bk1
// Even-splits path: taken only when seqlen_k is an exact multiple of
// num_splits times the F_bn0 value.
}} else if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{
kernel_runner<kIsMultipleSplits, /*kHasUnevenSplits=*/false>::run(s, a);
}} else {{
// NOTE(review): the next call also uses the kHasUnevenSplits-first labelling and is
// followed by a second run() in the same branch; presumably a pre-change line, verify.
kernel_runner</*kHasUnevenSplits=*/true, /*kIsMultipleSplits=*/false>::run(s, a);
kernel_runner<kIsMultipleSplits, /*kHasUnevenSplits=*/true>::run(s, a);
}}
}} else {{ // group mode
// Group mode always takes the uneven-splits variant.
kernel_runner<kIsMultipleSplits, /*kHasUnevenSplits=*/true>::run(s, a);
}}
}} else {{
// num_splits is 1 or less: only the first template argument is given, so any
// second argument falls back to whatever default kernel_runner declares.
kernel_runner</*kIsMultipleSplits=*/false>::run(s, a);
}}
}}
......@@ -658,10 +669,13 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
if dtype in ['fp16', 'bf16']:
for mask, bias, lse, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
# TODO: use async pipeline when compiler is more stable
if hdim == 256:
if hdim == 256 or hdim in [32, 64, 128]: ### [32, 64, 96, 128]:
pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'row', 't', 't', 'f', 'f', bias, lse, squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'col', 't', 't', 'f', 'f', bias, lse, squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
else:
......
......@@ -402,21 +402,24 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
}
move_tile_window(bias_dram_window, {0, kN0});
/// TODO: only check in first/last iteration without increasing code size
// only check in first/last iterations
if constexpr(kHasUnevenSplits)
{
if(1 < num_splits && (i_total_loops == 0 || i_total_loops == num_total_loop - 1))
{
const auto k_origin = k_page_block_navigator.to_global_window_origin(
i_page_block_k, k_dram_block_window.get_window_origin());
set_tile_if(
s_acc,
set_tile_if(s_acc,
-numeric<SMPLComputeDataType>::infinity(),
[&,
physical_seqlen_k_start_ = physical_seqlen_k_start,
physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) {
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
if constexpr(kIsPagedKV)
{
return col < physical_seqlen_k_start_ || physical_seqlen_k_end_ <= col;
return col < physical_seqlen_k_start_ ||
physical_seqlen_k_end_ <= col;
}
else
{
......@@ -424,6 +427,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
}
});
}
}
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{
......
......@@ -102,7 +102,7 @@ struct BlockFmhaFwdSplitKVPipelineProblem
static constexpr bool kStoreLSE = Traits::kStoreLSE;
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
static constexpr bool kIsPagedKV = Traits::kIsPagedKV;
static constexpr bool kHasUnevenSplits = kIsGroupMode || Traits::kHasUnevenSplits;
static constexpr bool kHasUnevenSplits = Traits::kHasUnevenSplits;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment