Commit c119c0f0 authored by fuhuangpei's avatar fuhuangpei
Browse files

little perf

parent c2a1b310
This diff is collapsed.
......@@ -90,9 +90,9 @@ DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_lo
}
DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel_16x64_vllm_kvcache_prefetch, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, bool Has_block_table) {
DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel_16x64_vllm_kvcache_prefetch, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, bool Has_block_table, bool Identity_block_table=false, bool Contiguous_qo_layout=false, bool Contiguous_kv_layout=false) {
#if defined(ARCH_SUPPORTS_FLASH)
flash::compute_attn_splitkv_16x64_vllm_kvcache_prefetch<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV, Has_block_table>(params);
flash::compute_attn_splitkv_16x64_vllm_kvcache_prefetch<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV, Has_block_table, Identity_block_table, Contiguous_qo_layout, Contiguous_kv_layout>(params);
#else
FLASH_UNSUPPORTED_ARCH
#endif
......@@ -355,7 +355,22 @@ void run_flash_splitkv_fwd_16x64_vllm_kvcache_prefetch(Flash_fwd_params &params,
dim3 grid;
if constexpr(Is_causal)grid=dim3(params.h, params.b,num_m_block);
else grid=dim3(num_m_block,params.h, params.b);
const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
const bool is_even_MN = ((params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr) || params.b == 1) &&
params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
const bool has_paged_kv_layout = params.d == 128 && params.d_value == 128 && params.page_block_size == 64 &&
params.block_table != nullptr && params.cache_batch_idx == nullptr;
const bool is_identity_block_table = has_paged_kv_layout && params.b == 1 && params.h == 6 && params.h_k == 1 &&
params.seqlen_q == 10240 && params.seqlen_k == 51200 &&
params.block_table_batch_stride == cute::ceil_div(params.seqlen_k, params.page_block_size);
const bool is_contiguous_qo_layout = has_paged_kv_layout &&
params.q_head_stride == params.d && params.o_head_stride == params.d_value &&
params.q_row_stride == params.h * params.d && params.o_row_stride == params.h * params.d_value;
const bool is_contiguous_kv_layout = has_paged_kv_layout &&
params.k_row_stride == params.d && params.v_row_stride == params.page_block_size &&
params.k_head_stride == params.page_block_size * params.d &&
params.v_head_stride == params.d_value * params.page_block_size &&
params.k_batch_stride == params.h_k * params.page_block_size * params.d &&
params.v_batch_stride == params.h_k * params.d_value * params.page_block_size;
const bool is_even_K = true;
// const bool is_even_K = params.d == Kernel_traits::kHeadDim;
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
......@@ -370,7 +385,19 @@ void run_flash_splitkv_fwd_16x64_vllm_kvcache_prefetch(Flash_fwd_params &params,
constexpr static bool Has_block_table = true;
constexpr static bool Append_KV = false;
constexpr static bool Split = false;
auto kernel = &flash_fwd_splitkv_kernel_16x64_vllm_kvcache_prefetch<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split, Append_KV, Has_block_table>;
auto kernel = is_identity_block_table
? (is_contiguous_qo_layout && is_contiguous_kv_layout
? &flash_fwd_splitkv_kernel_16x64_vllm_kvcache_prefetch<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split, Append_KV, Has_block_table, true, true, true>
: (is_contiguous_qo_layout
? &flash_fwd_splitkv_kernel_16x64_vllm_kvcache_prefetch<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split, Append_KV, Has_block_table, true, true, false>
: &flash_fwd_splitkv_kernel_16x64_vllm_kvcache_prefetch<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split, Append_KV, Has_block_table, true, false, false>))
: (is_contiguous_qo_layout && is_contiguous_kv_layout
? &flash_fwd_splitkv_kernel_16x64_vllm_kvcache_prefetch<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split, Append_KV, Has_block_table, false, true, true>
: (is_contiguous_qo_layout
? &flash_fwd_splitkv_kernel_16x64_vllm_kvcache_prefetch<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split, Append_KV, Has_block_table, false, true, false>
: (is_contiguous_kv_layout
? &flash_fwd_splitkv_kernel_16x64_vllm_kvcache_prefetch<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split, Append_KV, Has_block_table, false, false, true>
: &flash_fwd_splitkv_kernel_16x64_vllm_kvcache_prefetch<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split, Append_KV, Has_block_table, false, false, false>)));
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
});
});
......
version = '2.8.3'
git_hash = 'c2a1b31'
git_branch = 'master'
abi = 'abi1'
dtk = '2604'
torch_version = '2.9'
hcu_version = '2.8.3+das.opt0.dtk2604'
......@@ -944,6 +944,7 @@ if not SKIP_CUDA_BUILD:
Path(this_dir) / "csrc" / "flash_attn",
Path(this_dir) / "csrc" / "flash_attn" / "src",
Path(this_dir) / "csrc" / "cutlass" / "include",
],
)
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment