Unverified commit fb1ccfa9 authored by Po Yen Chen, committed by GitHub

[CK_TILE] Add paged-kvcache support in group mode fmha fwd splitkv kernels (#1678)

* Generate group mode paged-attn kernel

* Enable paged-kvcache + group mode support

* Add missing header: fused_moe.hpp

* Add comment to explain kernel arg usage

* Make error message more clear

* Add comment for confusing data member names

* Add more comment for confusing variable names

* Fix typo in option description
parent 6916d8cc
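
For context: a paged kvcache stores K/V rows in fixed-size page-blocks and resolves each logical token position through a per-batch row of a block table. The sketch below is illustrative only (the helper and its signature are not part of this commit; only the names `block_table`, `batch_stride_block_table`, and `page_block_size` mirror the kernel arguments added below):

```cpp
#include <cstdint>

// Hypothetical helper: compute the element offset of one K/V row under a
// paged kvcache. A per-batch row of the block table maps each logical block
// of seqlen_k to a physical page-block in the shared pool.
std::int64_t paged_kv_row(const std::int32_t* block_table,       // [batch, max_blocks]
                          std::int64_t batch_stride_block_table, // row stride of the table
                          std::int64_t page_block_stride,        // elements per page-block
                          std::int64_t row_stride,               // elements per K/V row
                          std::int32_t i_batch,
                          std::int32_t token, // logical position in seqlen_k
                          std::int32_t page_block_size)
{
    const std::int32_t physical_block =
        block_table[i_batch * batch_stride_block_table + token / page_block_size];
    const std::int32_t row_in_block = token % page_block_size;
    // element offset of this token's K/V row from the base pointer
    return physical_block * page_block_stride + row_in_block * row_stride;
}
```

The batch-mode splitkv kernels already supported this addressing; the commit removes the restriction so group-mode (variable-length) batches can use the same page pool.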
...@@ -655,9 +655,6 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> ...@@ -655,9 +655,6 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
if pipeline.F_spad != 't' or pipeline.F_skpad != 't': if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
continue continue
if pipeline.F_pagedkv == 't':
# we only use batch mode kernels to handle (paged-) kvcache problems
continue
k = Kernel(F_idx=0, k = Kernel(F_idx=0,
F_hdim=hdim, F_hdim=hdim,
F_dtype=dtype, F_dtype=dtype,
......
@@ -62,7 +62,7 @@ auto create_args(int argc, char* argv[])
                 "-1 to choose s_knew in [1, s] randomly.")
         .insert("s_kpad",
                 "-1",
-                "seqlen_k stride between 2 tokens, currently used in group-mode only\n"
+                "seqlen_k stride between 2 batches, currently used in group-mode only\n"
                 "for kv-cache case, each batch [1,s,h,d]/[1,h,s,d] can have a stride\n"
                 "along seqlen, instead of packed. same as xformer kv_padding")
         .insert("d", "128", "head dim for q, k")
@@ -294,7 +294,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
 #if !CK_TILE_FMHA_FWD_APPENDKV_API
     if(seqlen_knew != 0)
     {
-        std::cerr << "kvcache is not supported. ignoring the 's_knew' option" << std::endl;
+        std::cerr << "fmha_fwd_appendkv() is not enabled. ignoring the 's_knew' option"
+                  << std::endl;
         seqlen_knew = 0;
     }
 #endif
@@ -321,6 +322,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
         rotary_dim = 0;
     }
 #endif
+    // to use fmha_fwd_appendkv(), make sure it's in batch mode
+    const bool need_append_kvcache = (0 < seqlen_knew || 0 < rotary_dim);
+    if(need_append_kvcache && mode == mode_enum::group)
+    {
+        std::cerr << "fmha_fwd_appendkv() will be invoked. ignoring the 'mode' option"
+                  << std::endl;
+        mode = mode_enum::batch;
+    }
     if(!(rotary_dim <= hdim_q))
     {
         std::cerr << "rotary_dim should be less than or equal to head dim for q" << std::endl;
@@ -356,22 +364,26 @@ bool run(const ck_tile::ArgParser& arg_parser)
                   << std::endl;
         use_cache_batch_idx = false;
     }
-#endif
-    if(0 < page_block_size && use_cache_batch_idx)
+#else
+    if(use_cache_batch_idx)
     {
-        std::cerr << "paged-kvcache does not support cache_batch_idx. ignoring the "
-                     "'cache_batch_idx' option"
-                  << std::endl;
-        use_cache_batch_idx = false;
+        if(0 < page_block_size)
+        {
+            std::cerr << "paged-kvcache does not support cache_batch_idx. ignoring the "
+                         "'cache_batch_idx' option"
+                      << std::endl;
+            use_cache_batch_idx = false;
+        }
+        else if(mode == mode_enum::group)
+        {
+            std::cerr << "group mode will not use cache_batch_idx. ignoring the "
+                         "'cache_batch_idx' option"
+                      << std::endl;
+            use_cache_batch_idx = false;
+        }
     }
-    // the input tensor layout for kvcache is same as batch mode
-    const bool need_append_kvcache = (0 < seqlen_knew || 0 < rotary_dim);
+#endif
     const bool use_kvcache = (need_append_kvcache || use_cache_batch_idx || 0 < page_block_size);
-    if(use_kvcache && mode != mode_enum::batch)
-    {
-        std::cerr << "kvcache enabled. ignoring the 'mode' option" << std::endl;
-        mode = mode_enum::batch;
-    }

     auto [seqlen_qs, seqlen_ks, seqlen_kpads] =
         decode_seqlen(mode,
@@ -380,7 +392,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
                       arg_parser.get_str("s_k"),
                       arg_parser.get_str("s_kpad"),
                       /*seqlen_k_min=*/0 < seqlen_knew ? seqlen_knew : 0,
-                      use_kvcache);
+                      need_append_kvcache);
     // compute kvcache seqlen_k (before appending knew/vnew)
     auto cache_seqlen_ks = seqlen_ks;
     std::transform(cache_seqlen_ks.begin(),
@@ -741,8 +753,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
     ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
     ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
-    ck_tile::DeviceMem seqlen_k_buf(
-        use_kvcache || 0 <= seqlen_kpads[0] ? seqlen_ks.size() * sizeof(int32_t) : 0);
+    ck_tile::DeviceMem seqlen_k_buf((mode == mode_enum::batch && use_kvcache) ||
+                                            0 <= seqlen_kpads[0]
+                                        ? seqlen_ks.size() * sizeof(int32_t)
+                                        : 0);
     ck_tile::DeviceMem cache_seqlen_k_buf(
         need_append_kvcache ? cache_seqlen_ks.size() * sizeof(int32_t) : 0);
     ck_tile::DeviceMem rotary_cos_buf(rotary_cos_host.get_element_space_size_in_bytes());
@@ -763,7 +777,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
     seqstart_q.ToDevice(seqstart_q_host.data());
     seqstart_k.ToDevice(seqlen_kpads[0] < 0 ? seqstart_k_host.data()
                                             : seqstart_k_with_padding_host.data());
-    seqlen_k_buf.ToDevice(use_kvcache || 0 <= seqlen_kpads[0] ? seqlen_ks.data() : nullptr);
+    seqlen_k_buf.ToDevice((mode == mode_enum::batch && use_kvcache) || 0 <= seqlen_kpads[0]
+                              ? seqlen_ks.data()
+                              : nullptr);
     cache_seqlen_k_buf.ToDevice(need_append_kvcache ? cache_seqlen_ks.data() : nullptr);
     rotary_cos_buf.ToDevice(rotary_cos_host.data());
     rotary_sin_buf.ToDevice(rotary_sin_host.data());
@@ -976,8 +992,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
             (mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr);
         args.seqstart_k_ptr =
             (mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr);
-        args.seqlen_k_ptr =
-            (use_kvcache || 0 <= k_paddings_[0] ? seqlen_k_buf.GetDeviceBuffer() : nullptr);
+        args.seqlen_k_ptr = ((mode == mode_enum::batch && use_kvcache) || 0 <= k_paddings_[0]
+                                 ? seqlen_k_buf.GetDeviceBuffer()
+                                 : nullptr);
         args.seqlen_k = shape_seqlen_k; // unused in group mode (or kvcache enabled)
         args.max_seqlen_q = max_seqlen_q;
...
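
The `std::transform` shown above is truncated in this view. A plausible standalone completion, assuming the cache length is simply the post-append length minus the appended tokens (consistent with `seqlen_k_min = seqlen_knew` being passed to `decode_seqlen`):

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

int main()
{
    const std::int32_t seqlen_knew = 16;              // tokens appended per batch
    std::vector<std::int32_t> seqlen_ks{64, 128, 96}; // seqlen_k after appending
    // kvcache seqlen_k before appending knew/vnew
    auto cache_seqlen_ks = seqlen_ks;
    std::transform(cache_seqlen_ks.begin(),
                   cache_seqlen_ks.end(),
                   cache_seqlen_ks.begin(),
                   [&](std::int32_t s) { return s - seqlen_knew; });
    return 0;
}
```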
@@ -173,8 +173,11 @@ struct fmha_fwd_splitkv_args
    //              seqlen_k = kargs.seqlen_k
    // group mode: seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b]
    //             seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b]
-   // kvcache mode (use same kernel as batch mode):
+   // batch mode (kvcache):
    //             seqlen_q = kargs.seqlen_q
+   //             seqlen_k = kargs.seqlen_k_ptr[b]
+   // group mode (kvcache):
+   //             seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b]
    //             seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b]
    const void* seqstart_q_ptr;
    const void* seqstart_k_ptr;
@@ -251,7 +254,7 @@ struct fmha_fwd_appendkv_args
    ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr
    ck_tile::index_t page_block_size;          // only used if 'block_table_ptr' is not nullptr
-   const void* cache_batch_idx;
+   const void* cache_batch_idx; // only used if block_table_ptr is nullptr -> batch mode (kvcache)

    ck_tile::index_t stride_q;
    ck_tile::index_t stride_k;
@@ -389,6 +392,9 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
                                             args.nhead_q,
                                             args.nhead_q / args.nhead_k,
                                             args.num_splits,
+                                            args.block_table_ptr,
+                                            args.batch_stride_block_table,
+                                            args.page_block_size,
                                             args.scale_s,
                                             args.scale_p,
                                             args.stride_q,
...
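
A hedged usage sketch of the three forwarded fields: how a caller might populate `fmha_fwd_splitkv_args` for a group-mode paged-kvcache launch. Only the field names come from this diff; the helper, its parameter types, and the value 128 are illustrative assumptions:

```cpp
#include "fmha_fwd.hpp" // this repo's header declaring fmha_fwd_splitkv_args

// All pointers are assumed to be device buffers prepared by the caller.
fmha_fwd_splitkv_args make_paged_group_args(const void* block_table, // [batch, max_num_blocks]
                                            ck_tile::index_t max_num_blocks,
                                            const void* seqstart_q,
                                            const void* seqstart_k)
{
    fmha_fwd_splitkv_args args{};
    args.block_table_ptr          = block_table;
    args.batch_stride_block_table = max_num_blocks; // row stride of the block table
    args.page_block_size          = 128;            // tokens per page-block (illustrative)
    args.seqstart_q_ptr           = seqstart_q;     // group mode: per-batch q offsets
    args.seqstart_k_ptr           = seqstart_k;     // group mode: per-batch k offsets
    args.seqlen_k_ptr             = nullptr;        // group mode (kvcache) derives seqlen_k
                                                    // from seqstart_k unless s_kpad is used
    return args;
}
```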
@@ -145,7 +145,7 @@ decode_seqlen(mode_enum mode,
              std::string k_val,
              std::string k_pad_val,
              ck_tile::index_t seqlen_k_min = 0,
-             bool use_kvcache              = false,
+             bool need_append_kvcache      = false,
              std::optional<unsigned> seed  = std::nullopt)
{
#define _S2I_(str_) static_cast<ck_tile::index_t>(std::atoi((str_).c_str()))
@@ -159,7 +159,7 @@ decode_seqlen(mode_enum mode,
    const ck_tile::index_t seqlen_k_max = (k < 0 ? q : k);
    std::vector<ck_tile::index_t> seqlen_ks(batch, seqlen_k_max);
-   if(1 < batch && use_kvcache)
+   if(1 < batch && need_append_kvcache)
    {
        // to keep the original s_k value, we always use seqlen_k_max in first batch
        randints(std::next(seqlen_ks.begin()),
...
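
A standalone sketch of the behavior this rename pins down: when appending to the kvcache with more than one batch, batch 0 keeps `seqlen_k_max` (so the requested `s_k` is preserved) and the remaining batches are randomized in `[seqlen_k_min, seqlen_k_max]`. `randints()` is the repo's helper; a plain `<random>` equivalent is used here:

```cpp
#include <algorithm>
#include <iterator>
#include <random>
#include <vector>

std::vector<int> make_seqlen_ks(int batch, int seqlen_k_min, int seqlen_k_max,
                                bool need_append_kvcache, unsigned seed = 0)
{
    std::vector<int> seqlen_ks(batch, seqlen_k_max);
    if(1 < batch && need_append_kvcache)
    {
        // keep seqlen_k_max in the first batch; randomize the rest
        std::mt19937 gen(seed);
        std::uniform_int_distribution<int> dist(seqlen_k_min, seqlen_k_max);
        std::generate(std::next(seqlen_ks.begin()), seqlen_ks.end(),
                      [&] { return dist(gen); });
    }
    return seqlen_ks;
}
```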
...@@ -46,8 +46,7 @@ struct FmhaFwdSplitKVKernel ...@@ -46,8 +46,7 @@ struct FmhaFwdSplitKVKernel
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum; static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV; static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV;
static_assert(!kIsGroupMode || (kIsGroupMode && !kIsPagedKV),
"paged-kvcache only supported by batch mode kernels");
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>; using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
static constexpr bool kHasMask = FmhaMask::IsMasking; static constexpr bool kHasMask = FmhaMask::IsMasking;
...@@ -198,8 +197,10 @@ struct FmhaFwdSplitKVKernel ...@@ -198,8 +197,10 @@ struct FmhaFwdSplitKVKernel
const int32_t* seqlen_k_ptr; const int32_t* seqlen_k_ptr;
ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_k; // when using paged-kvcache, this will be stride/size for
ck_tile::index_t batch_stride_v; // single kcache page-block
ck_tile::index_t batch_stride_v; // when using paged-kvcache, this will be stride/size for
// single vcache page-block
ck_tile::index_t batch_stride_lse_acc; ck_tile::index_t batch_stride_lse_acc;
ck_tile::index_t batch_stride_o_acc; ck_tile::index_t batch_stride_o_acc;
}; };
@@ -212,14 +213,17 @@ struct FmhaFwdSplitKVKernel
                               AlibiKargs,
                               EmptyKargs<0>>>,
          std::conditional_t<kHasMask, MaskKargs, EmptyKargs<1>>,
-         std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<2>>
+         std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<2>>,
+         std::conditional_t<kIsPagedKV, PageBlockTableKargs, EmptyKargs<3>>
    {
        const int32_t* seqstart_q_ptr;
        const int32_t* seqstart_k_ptr;
        const int32_t* seqlen_k_ptr;
-       ck_tile::index_t batch_stride_k; // only used for paged-kvcache
-       ck_tile::index_t batch_stride_v; // only used for paged-kvcache
+       ck_tile::index_t batch_stride_k; // only used for paged-kvcache, this will be stride/size
+                                        // for single kcache page-block
+       ck_tile::index_t batch_stride_v; // only used for paged-kvcache, this will be stride/size
+                                        // for single vcache page-block
    };

    using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>;
@@ -363,6 +367,9 @@ struct FmhaFwdSplitKVKernel
              ck_tile::index_t num_head_q,
              ck_tile::index_t nhead_ratio_qk,
              ck_tile::index_t num_splits,
+             const void* block_table_ptr,
+             ck_tile::index_t batch_stride_block_table,
+             ck_tile::index_t page_block_size,
              float scale_s,
              float scale_p,
              ck_tile::index_t stride_q,
@@ -416,6 +423,7 @@ struct FmhaFwdSplitKVKernel
            {}, // placeholder for bias
            {}, // placeholder for mask
            {}, // placeholder for fp8_static_quant args
+           {}, // placeholder for paged-block table
            reinterpret_cast<const int32_t*>(seqstart_q_ptr),
            reinterpret_cast<const int32_t*>(seqstart_k_ptr),
            reinterpret_cast<const int32_t*>(seqlen_k_ptr),
@@ -443,6 +451,12 @@ struct FmhaFwdSplitKVKernel
        {
            kargs.scale_p = scale_p;
        }
+       if constexpr(kIsPagedKV)
+       {
+           kargs.block_table_ptr          = reinterpret_cast<const int32_t*>(block_table_ptr);
+           kargs.batch_stride_block_table = batch_stride_block_table;
+           kargs.page_block_size          = page_block_size;
+       }

        return kargs;
    }
@@ -489,15 +503,22 @@ struct FmhaFwdSplitKVKernel
            const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];

            batch_offset_q = query_start * kargs.stride_q;
-           batch_offset_k = key_start * kargs.stride_k;
-           if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
+           if constexpr(kIsPagedKV)
            {
-               batch_offset_v = key_start * kargs.stride_v;
+               batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
+               batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
            }
            else
            {
-               batch_offset_v = key_start;
+               batch_offset_k = key_start * kargs.stride_k;
+               if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
+               {
+                   batch_offset_v = key_start * kargs.stride_v;
+               }
+               else
+               {
+                   batch_offset_v = key_start;
+               }
            }
            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
            {
@@ -685,7 +706,7 @@ struct FmhaFwdSplitKVKernel
            return make_page_block_navigator<const KDataType, 0>(
                kargs.k_ptr,
-               kargs.batch_stride_k,
+               kargs.batch_stride_k, // kcache page-block stride/size
                fixed_offset,
                block_indices,
                num_blocks,
@@ -715,7 +736,7 @@ struct FmhaFwdSplitKVKernel
            return make_page_block_navigator<const VDataType, 1>(
                kargs.v_ptr,
-               kargs.batch_stride_v,
+               kargs.batch_stride_v, // vcache page-block stride/size
                fixed_offset,
                block_indices,
                num_blocks,
...
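
A small worked example of the mapping the navigator performs with that per-block stride (all values made up; `block_table_row` stands in for this batch's `block_indices`):

```cpp
#include <cassert>
#include <cstdint>

int main()
{
    const std::int32_t page_block_size   = 128;       // tokens per page-block
    const std::int32_t block_table_row[] = {7, 2, 9}; // logical -> physical blocks
    const std::int32_t token             = 200;       // logical position in seqlen_k

    // token 200 lives in logical block 1, which maps to physical block 2,
    // at row 72 inside that page-block
    assert(block_table_row[token / page_block_size] == 2);
    assert(token % page_block_size == 72);
    return 0;
}
```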
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include "ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp"
+#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp"
+#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp"
+#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp"
+#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
+#include "ck_tile/ops/common/tensor_layout.hpp"