Commit 484743c6 authored by fuhuangpei
Browse files

Remove the per-compile template parameters (Identity_block_table, Contiguous_qo_layout, Contiguous_kv_layout)

parent c119c0f0
......@@ -7968,7 +7968,7 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_mla(const Params &pa
// }
}
 
template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, bool Has_block_table, bool Identity_block_table, bool Contiguous_qo_layout, bool Contiguous_kv_layout, typename Params>
template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, bool Has_block_table, typename Params>
inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetch(const Params &params, const int bidb, const int bidh, const int m_block) {
#if 1
using ElementAccum = typename Kernel_traits::ElementAccum;
......@@ -8014,66 +8014,30 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
// We move K and V to the last block.
const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb];
const int *block_table = !Has_block_table? nullptr : params.block_table + bidb * params.block_table_batch_stride;
const int kv_row_stride = Contiguous_kv_layout ? 128 : params.k_row_stride;
const int v_row_stride = Contiguous_kv_layout ? 64 : params.v_row_stride;
const int kv_batch_stride = params.k_batch_stride;
const int v_batch_stride = params.v_batch_stride;
const index_t row_offset_k = !Has_block_table
? binfo.k_offset(kv_batch_stride, kv_row_stride, bidb_cache)
+ (n_block_max - 1) * kBlockN * kv_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride
? binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache)
+ (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride
: (bidh / params.h_h_k_ratio) * params.k_head_stride;
const index_t row_offset_v = !Has_block_table
? binfo.k_offset(v_batch_stride, v_row_stride, bidb_cache)
+ (n_block_max - 1) * kBlockN * v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride
? binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache)
+ (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride
: (bidh / params.h_h_k_ratio) * params.v_head_stride;
 
Tensor mQ = [&] {
if constexpr (Contiguous_qo_layout) {
return make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.q_ptr) + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)),
make_shape(binfo.actual_seqlen_q, params.h, params.d),
make_stride(params.q_row_stride, Int<128>{}, _1{}));
} else {
return make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.q_ptr) + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)),
make_shape(binfo.actual_seqlen_q, params.h, params.d),
make_stride(params.q_row_stride, params.q_head_stride, _1{}));
}
}();
Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.q_ptr) + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)),
make_shape(binfo.actual_seqlen_q, params.h, params.d),
make_stride(params.q_row_stride, params.q_head_stride, _1{}));
Tensor gQ = local_tile(mQ(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_coord(m_block, 0)); // (kBlockM, kHeadDim)
Tensor gK = [&] {
if constexpr (Contiguous_kv_layout) {
return make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(Int<128>{}, _1{}));
} else {
return make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.k_row_stride, _1{}));
}
}();
Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.k_row_stride, _1{}));
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("k_ptr = %p, row_offset_k = %d, gK_ptr = %p\n", params.k_ptr, row_offset_k, gK.data()); }
Tensor gV = [&] {
if constexpr (Contiguous_kv_layout) {
return make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
Shape<Int<64>, Int<kBlockN>>{},
make_stride(Int<64>{}, _1{}));
} else {
return make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
Shape<Int<64>, Int<kBlockN>>{},
make_stride(params.v_row_stride, _1{}));
}
}();
Tensor gV_tail = [&] {
if constexpr (Contiguous_kv_layout) {
return make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v + 64 * 64),
Shape<Int<64>, Int<kBlockN>>{},
make_stride(Int<64>{}, _1{}));
} else {
return make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v + 64 * params.v_row_stride),
Shape<Int<64>, Int<kBlockN>>{},
make_stride(params.v_row_stride, _1{}));
}
}();
Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
Shape<Int<64>, Int<kBlockN>>{},
make_stride(params.v_row_stride, _1{}));
Tensor gV_tail = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v + 64 * params.v_row_stride),
Shape<Int<64>, Int<kBlockN>>{},
make_stride(params.v_row_stride, _1{}));
Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
typename Kernel_traits::SmemLayoutQ{});
Tensor sK = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)), typename Kernel_traits::SmemLayoutKV{});
......@@ -8157,25 +8121,21 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
auto gV_tail_data = gV_tail.data();
{
int cur_block_table;
if constexpr (Identity_block_table) {
cur_block_table = n_block;
} else {
const int *cur_block_table_ptr = block_table + (n_block);
asm volatile("s_load_dword %1, %0, 0x0\n\t"
"s_waitcnt lgkmcnt(0)\n\t":
"+s"(cur_block_table_ptr),
"=s"(cur_block_table));
}
index_t offset_k = cur_block_table * kv_batch_stride;
index_t offset_v = cur_block_table * v_batch_stride;
const int *cur_block_table_ptr = block_table + (n_block);
asm volatile("s_load_dword %1, %0, 0x0\n\t"
"s_waitcnt lgkmcnt(0)\n\t":
"+s"(cur_block_table_ptr),
"=s"(cur_block_table));
index_t offset_k = cur_block_table * params.k_batch_stride;
index_t offset_v = cur_block_table * params.v_batch_stride;
gK.data() = gK_data + (offset_k);
gV.data() = gV_data + (offset_v);
gV_tail.data() = gV_tail_data + (offset_v);
}
__builtin_amdgcn_sched_barrier(0);
lds_direct_copy<Is_even_K, Is_even_MN, _64x32, 0, true>(gK, sK, 0, kv_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x32, 0, true>(gK, sK, 1, kv_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x32, 0, true>(gK, sK, 2, kv_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x32, 0, true>(gK, sK, 0, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x32, 0, true>(gK, sK, 1, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x32, 0, true>(gK, sK, 2, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
__builtin_amdgcn_sched_barrier(0);
const bool Is_need_pad = binfo.actual_seqlen_k % 4 != 0;
#if 1
......@@ -8184,24 +8144,24 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
Tensor acc_s_ori = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
clear(acc_s_ori);
{
lds_direct_copy<Is_even_K, Is_even_MN, _64x32, 0, true>(gK, sK, 3, kv_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x32, 0, true>(gK, sK, 3, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
S_WAITCNT;
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 0, 0);
S_BARRIER;
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV, sV, 0, v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV, sV, 1, v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV, sV, 0, params.v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV, sV, 1, params.v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
asm volatile("s_waitcnt vmcnt(4) \n s_barrier");;
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 1, 1);
S_BARRIER;
 
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV, sV, 2, v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV, sV, 3, v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV, sV, 2, params.v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV, sV, 3, params.v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
asm volatile("s_waitcnt vmcnt(5) \n s_barrier");
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 2, 2);
S_BARRIER;
 
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV_tail, sV_tail, 0, v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV_tail, sV_tail, 1, v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV_tail, sV_tail, 0, params.v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV_tail, sV_tail, 1, params.v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
asm volatile("s_waitcnt vmcnt(6) \n s_barrier");
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 3, 3);
S_BARRIER;
......@@ -8236,8 +8196,8 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
{
__builtin_amdgcn_sched_barrier(0);
// int token_id = n_block * kBlockN + ((tidx % 64) / 16) * 4;
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV_tail, sV_tail, 2, v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV_tail, sV_tail, 3, v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV_tail, sV_tail, 2, params.v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV_tail, sV_tail, 3, params.v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
asm volatile("s_waitcnt vmcnt(7) \n s_barrier");
__builtin_amdgcn_sched_barrier(0);
if (!Is_even_MN && Is_need_pad && masking_step == 0) {
......@@ -8285,24 +8245,20 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
// gK.data() = gK.data() + (-int(kBlockN * params.k_row_stride));
// gV.data() = gV.data() + (-int(kBlockN * params.v_row_stride));
int cur_block_table;
if constexpr (Identity_block_table) {
cur_block_table = n_block - 1;
} else {
const int *cur_block_table_ptr = block_table + (n_block - 1);
asm volatile("s_load_dword %1, %0, 0x0\n\t"
"s_waitcnt lgkmcnt(0)\n\t":
"+s"(cur_block_table_ptr),
"=s"(cur_block_table));
}
index_t offset_k = cur_block_table * kv_batch_stride;
index_t offset_v = cur_block_table * v_batch_stride;
const int *cur_block_table_ptr = block_table + (n_block - 1);
asm volatile("s_load_dword %1, %0, 0x0\n\t"
"s_waitcnt lgkmcnt(0)\n\t":
"+s"(cur_block_table_ptr),
"=s"(cur_block_table));
index_t offset_k = cur_block_table * params.k_batch_stride;
index_t offset_v = cur_block_table * params.v_batch_stride;
gK.data() = gK_data + (offset_k);
gV.data() = gV_data + (offset_v);
gV_tail.data() = gV_tail_data + (offset_v);
 
lds_direct_copy<Is_even_K, true, _64x32, 0, true>(gK, sK, 0, kv_row_stride, params.d);
lds_direct_copy<Is_even_K, true, _64x32, 0, true>(gK, sK, 1, kv_row_stride, params.d);
lds_direct_copy<Is_even_K, true, _64x32, 0, true>(gK, sK, 2, kv_row_stride, params.d);
lds_direct_copy<Is_even_K, true, _64x32, 0, true>(gK, sK, 0, params.k_row_stride, params.d);
lds_direct_copy<Is_even_K, true, _64x32, 0, true>(gK, sK, 1, params.k_row_stride, params.d);
lds_direct_copy<Is_even_K, true, _64x32, 0, true>(gK, sK, 2, params.k_row_stride, params.d);
}
if (n_masking_steps > 1 && n_block <= 0) {
......@@ -8317,24 +8273,24 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
clear(acc_s_ori);
{
__builtin_amdgcn_sched_barrier(0);
lds_direct_copy<Is_even_K, true, _64x32, 0, true>(gK, sK, 3, kv_row_stride, params.d);
lds_direct_copy<Is_even_K, true, _64x32, 0, true>(gK, sK, 3, params.k_row_stride, params.d);
S_WAITCNT;
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 0, 0);
S_BARRIER;
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV, sV, 0, v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV, sV, 1, v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV, sV, 0, params.v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV, sV, 1, params.v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
asm volatile("s_waitcnt vmcnt(4) \n s_barrier");;
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 1, 1);
S_BARRIER;
 
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV, sV, 2, v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV, sV, 3, v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV, sV, 2, params.v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV, sV, 3, params.v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
asm volatile("s_waitcnt vmcnt(5) \n s_barrier");
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 2, 2);
S_BARRIER;
 
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV_tail, sV_tail, 0, v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV_tail, sV_tail, 1, v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV_tail, sV_tail, 0, params.v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV_tail, sV_tail, 1, params.v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
asm volatile("s_waitcnt vmcnt(6) \n s_barrier");
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 3, 3);
S_BARRIER;
......@@ -8364,8 +8320,8 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
{
__builtin_amdgcn_sched_barrier(0);
S_BARRIER;
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV_tail, sV_tail, 2, v_row_stride, 128);
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV_tail, sV_tail, 3, v_row_stride, 128);
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV_tail, sV_tail, 2, params.v_row_stride, 128);
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV_tail, sV_tail, 3, params.v_row_stride, 128);
asm volatile("s_waitcnt vmcnt(7) \n s_barrier");
flash::gemm_k_rs(acc_o_ori, rP, tOrV, tSsV, tiled_mma_gemm1, smem_tiled_copy_V, smem_thr_copy_V, 0, 0);
......@@ -8389,24 +8345,20 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
 
if (n_block > 0) {
int cur_block_table;
if constexpr (Identity_block_table) {
cur_block_table = n_block - 1;
} else {
const int *cur_block_table_ptr = block_table + (n_block - 1);
asm volatile("s_load_dword %1, %0, 0x0\n\t"
"s_waitcnt lgkmcnt(0)\n\t":
"+s"(cur_block_table_ptr),
"=s"(cur_block_table));
}
index_t offset_k = cur_block_table * kv_batch_stride;
index_t offset_v = cur_block_table * v_batch_stride;
const int *cur_block_table_ptr = block_table + (n_block - 1);
asm volatile("s_load_dword %1, %0, 0x0\n\t"
"s_waitcnt lgkmcnt(0)\n\t":
"+s"(cur_block_table_ptr),
"=s"(cur_block_table));
index_t offset_k = cur_block_table * params.k_batch_stride;
index_t offset_v = cur_block_table * params.v_batch_stride;
gK.data() = gK_data + (offset_k);
gV.data() = gV_data + (offset_v);
gV_tail.data() = gV_tail_data + (offset_v);
#pragma unroll
for (int i = 0; i < kStages; ++i) {
lds_direct_copy<Is_even_K, true, _64x32, 0, true>(gK, sK, i, kv_row_stride, params.d);
lds_direct_copy<Is_even_K, true, _64x32, 0, true>(gK, sK, i, params.k_row_stride, params.d);
}
}
 
......@@ -8436,7 +8388,7 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
}
const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
+ (Contiguous_qo_layout ? m_block * kBlockM * params.o_row_stride + bidh * 128 : m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride);
+ m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
const index_t row_offset_lseaccum = (Split || !params.unpadded_lse ?
( bidb * params.h + bidh) * params.seqlen_q : bidh * params.seqlen_q + binfo.q_offset(params.seqlen_q, 1, bidb)
) + m_block * kBlockM;
......@@ -8500,7 +8452,7 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
#endif
}
 
template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, bool Has_block_table, bool Identity_block_table, bool Contiguous_qo_layout, bool Contiguous_kv_layout, typename Params>
template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, bool Has_block_table, typename Params>
inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetch_fp8(const Params &params, const int bidb, const int bidh, const int m_block) {
#if 1
using ElementAccum = typename Kernel_traits::ElementAccum;
......@@ -16132,7 +16084,7 @@ inline __device__ void compute_attn_splitkv(const Params &params) {
 
 
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, bool Has_block_table, bool Identity_block_table=false, bool Contiguous_qo_layout=false, bool Contiguous_kv_layout=false, typename Params>
template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, bool Has_block_table, typename Params>
inline __device__ void compute_attn_splitkv_16x64_vllm_kvcache_prefetch(const Params &params) {
const int m_block = Is_causal?gridDim.z - 1 - blockIdx.z:blockIdx.x;
const int bidb = Is_causal?blockIdx.y:blockIdx.z;
......@@ -16141,7 +16093,7 @@ inline __device__ void compute_attn_splitkv_16x64_vllm_kvcache_prefetch(const Pa
if constexpr (Kernel_traits::kHeadDim == 64){
flash::compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetch_dim64<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV, Has_block_table>(params, bidb, bidh, m_block);
}else if constexpr (Kernel_traits::kHeadDim == 128) {
flash::compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetch<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV, Has_block_table, Identity_block_table, Contiguous_qo_layout, Contiguous_kv_layout>(params, bidb, bidh, m_block);
flash::compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetch<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV, Has_block_table>(params, bidb, bidh, m_block);
}else if constexpr (Kernel_traits::kHeadDim == 192) {
flash::compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetch_dim192<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV, Has_block_table>(params, bidb, bidh, m_block);
}else if constexpr (Kernel_traits::kHeadDim == 256) {
......@@ -16150,7 +16102,7 @@ inline __device__ void compute_attn_splitkv_16x64_vllm_kvcache_prefetch(const Pa
#endif
 
}
template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, bool Has_block_table, bool Identity_block_table=false, bool Contiguous_qo_layout=false, bool Contiguous_kv_layout=false, typename Params>
template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, bool Has_block_table, typename Params>
inline __device__ void compute_attn_splitkv_16x64_vllm_kvcache_prefetch_kv_fp8(const Params &params) {
const int m_block = Is_causal?gridDim.z - 1 - blockIdx.z:blockIdx.x;
const int bidb = Is_causal?blockIdx.y:blockIdx.z;
......
......@@ -90,9 +90,9 @@ DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_lo
}
DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel_16x64_vllm_kvcache_prefetch, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, bool Has_block_table, bool Identity_block_table=false, bool Contiguous_qo_layout=false, bool Contiguous_kv_layout=false) {
DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel_16x64_vllm_kvcache_prefetch, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, bool Has_block_table) {
#if defined(ARCH_SUPPORTS_FLASH)
flash::compute_attn_splitkv_16x64_vllm_kvcache_prefetch<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV, Has_block_table, Identity_block_table, Contiguous_qo_layout, Contiguous_kv_layout>(params);
flash::compute_attn_splitkv_16x64_vllm_kvcache_prefetch<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV, Has_block_table>(params);
#else
FLASH_UNSUPPORTED_ARCH
#endif
......@@ -355,22 +355,7 @@ void run_flash_splitkv_fwd_16x64_vllm_kvcache_prefetch(Flash_fwd_params &params,
dim3 grid;
if constexpr(Is_causal)grid=dim3(params.h, params.b,num_m_block);
else grid=dim3(num_m_block,params.h, params.b);
const bool is_even_MN = ((params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr) || params.b == 1) &&
params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
const bool has_paged_kv_layout = params.d == 128 && params.d_value == 128 && params.page_block_size == 64 &&
params.block_table != nullptr && params.cache_batch_idx == nullptr;
const bool is_identity_block_table = has_paged_kv_layout && params.b == 1 && params.h == 6 && params.h_k == 1 &&
params.seqlen_q == 10240 && params.seqlen_k == 51200 &&
params.block_table_batch_stride == cute::ceil_div(params.seqlen_k, params.page_block_size);
const bool is_contiguous_qo_layout = has_paged_kv_layout &&
params.q_head_stride == params.d && params.o_head_stride == params.d_value &&
params.q_row_stride == params.h * params.d && params.o_row_stride == params.h * params.d_value;
const bool is_contiguous_kv_layout = has_paged_kv_layout &&
params.k_row_stride == params.d && params.v_row_stride == params.page_block_size &&
params.k_head_stride == params.page_block_size * params.d &&
params.v_head_stride == params.d_value * params.page_block_size &&
params.k_batch_stride == params.h_k * params.page_block_size * params.d &&
params.v_batch_stride == params.h_k * params.d_value * params.page_block_size;
const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
const bool is_even_K = true;
// const bool is_even_K = params.d == Kernel_traits::kHeadDim;
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
......@@ -385,19 +370,7 @@ void run_flash_splitkv_fwd_16x64_vllm_kvcache_prefetch(Flash_fwd_params &params,
constexpr static bool Has_block_table = true;
constexpr static bool Append_KV = false;
constexpr static bool Split = false;
auto kernel = is_identity_block_table
? (is_contiguous_qo_layout && is_contiguous_kv_layout
? &flash_fwd_splitkv_kernel_16x64_vllm_kvcache_prefetch<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split, Append_KV, Has_block_table, true, true, true>
: (is_contiguous_qo_layout
? &flash_fwd_splitkv_kernel_16x64_vllm_kvcache_prefetch<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split, Append_KV, Has_block_table, true, true, false>
: &flash_fwd_splitkv_kernel_16x64_vllm_kvcache_prefetch<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split, Append_KV, Has_block_table, true, false, false>))
: (is_contiguous_qo_layout && is_contiguous_kv_layout
? &flash_fwd_splitkv_kernel_16x64_vllm_kvcache_prefetch<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split, Append_KV, Has_block_table, false, true, true>
: (is_contiguous_qo_layout
? &flash_fwd_splitkv_kernel_16x64_vllm_kvcache_prefetch<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split, Append_KV, Has_block_table, false, true, false>
: (is_contiguous_kv_layout
? &flash_fwd_splitkv_kernel_16x64_vllm_kvcache_prefetch<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split, Append_KV, Has_block_table, false, false, true>
: &flash_fwd_splitkv_kernel_16x64_vllm_kvcache_prefetch<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split, Append_KV, Has_block_table, false, false, false>)));
auto kernel = &flash_fwd_splitkv_kernel_16x64_vllm_kvcache_prefetch<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split, Append_KV, Has_block_table>;
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
});
});
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment