Commit c119c0f0 authored by fuhuangpei
Browse files

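Minor perf improvement: add compile-time specializations (identity block table, contiguous Q/O layout, contiguous K/V layout) to the 16x64 vLLM KV-cache prefetch split-KV kernel for head dim 128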

parent c2a1b310
...@@ -7968,7 +7968,7 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_mla(const Params &pa ...@@ -7968,7 +7968,7 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_mla(const Params &pa
// } // }
} }
   
template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, bool Has_block_table, typename Params> template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, bool Has_block_table, bool Identity_block_table, bool Contiguous_qo_layout, bool Contiguous_kv_layout, typename Params>
inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetch(const Params &params, const int bidb, const int bidh, const int m_block) { inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetch(const Params &params, const int bidb, const int bidh, const int m_block) {
#if 1 #if 1
using ElementAccum = typename Kernel_traits::ElementAccum; using ElementAccum = typename Kernel_traits::ElementAccum;
...@@ -8014,30 +8014,66 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc ...@@ -8014,30 +8014,66 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
// We move K and V to the last block. // We move K and V to the last block.
const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb]; const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb];
const int *block_table = !Has_block_table? nullptr : params.block_table + bidb * params.block_table_batch_stride; const int *block_table = !Has_block_table? nullptr : params.block_table + bidb * params.block_table_batch_stride;
const int kv_row_stride = Contiguous_kv_layout ? 128 : params.k_row_stride;
const int v_row_stride = Contiguous_kv_layout ? 64 : params.v_row_stride;
const int kv_batch_stride = params.k_batch_stride;
const int v_batch_stride = params.v_batch_stride;
const index_t row_offset_k = !Has_block_table const index_t row_offset_k = !Has_block_table
? binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache) ? binfo.k_offset(kv_batch_stride, kv_row_stride, bidb_cache)
+ (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride + (n_block_max - 1) * kBlockN * kv_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride
: (bidh / params.h_h_k_ratio) * params.k_head_stride; : (bidh / params.h_h_k_ratio) * params.k_head_stride;
const index_t row_offset_v = !Has_block_table const index_t row_offset_v = !Has_block_table
? binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache) ? binfo.k_offset(v_batch_stride, v_row_stride, bidb_cache)
+ (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride + (n_block_max - 1) * kBlockN * v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride
: (bidh / params.h_h_k_ratio) * params.v_head_stride; : (bidh / params.h_h_k_ratio) * params.v_head_stride;
   
Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.q_ptr) + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)), Tensor mQ = [&] {
make_shape(binfo.actual_seqlen_q, params.h, params.d), if constexpr (Contiguous_qo_layout) {
make_stride(params.q_row_stride, params.q_head_stride, _1{})); return make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.q_ptr) + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)),
make_shape(binfo.actual_seqlen_q, params.h, params.d),
make_stride(params.q_row_stride, Int<128>{}, _1{}));
} else {
return make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.q_ptr) + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)),
make_shape(binfo.actual_seqlen_q, params.h, params.d),
make_stride(params.q_row_stride, params.q_head_stride, _1{}));
}
}();
Tensor gQ = local_tile(mQ(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{}, Tensor gQ = local_tile(mQ(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_coord(m_block, 0)); // (kBlockM, kHeadDim) make_coord(m_block, 0)); // (kBlockM, kHeadDim)
Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k), Tensor gK = [&] {
Shape<Int<kBlockN>, Int<kHeadDim>>{}, if constexpr (Contiguous_kv_layout) {
make_stride(params.k_row_stride, _1{})); return make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(Int<128>{}, _1{}));
} else {
return make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.k_row_stride, _1{}));
}
}();
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("k_ptr = %p, row_offset_k = %d, gK_ptr = %p\n", params.k_ptr, row_offset_k, gK.data()); } // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("k_ptr = %p, row_offset_k = %d, gK_ptr = %p\n", params.k_ptr, row_offset_k, gK.data()); }
Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v), Tensor gV = [&] {
Shape<Int<64>, Int<kBlockN>>{}, if constexpr (Contiguous_kv_layout) {
make_stride(params.v_row_stride, _1{})); return make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
Tensor gV_tail = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v + 64 * params.v_row_stride), Shape<Int<64>, Int<kBlockN>>{},
Shape<Int<64>, Int<kBlockN>>{}, make_stride(Int<64>{}, _1{}));
make_stride(params.v_row_stride, _1{})); } else {
return make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
Shape<Int<64>, Int<kBlockN>>{},
make_stride(params.v_row_stride, _1{}));
}
}();
Tensor gV_tail = [&] {
if constexpr (Contiguous_kv_layout) {
return make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v + 64 * 64),
Shape<Int<64>, Int<kBlockN>>{},
make_stride(Int<64>{}, _1{}));
} else {
return make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v + 64 * params.v_row_stride),
Shape<Int<64>, Int<kBlockN>>{},
make_stride(params.v_row_stride, _1{}));
}
}();
Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)), Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
typename Kernel_traits::SmemLayoutQ{}); typename Kernel_traits::SmemLayoutQ{});
Tensor sK = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)), typename Kernel_traits::SmemLayoutKV{}); Tensor sK = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)), typename Kernel_traits::SmemLayoutKV{});
...@@ -8121,22 +8157,25 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc ...@@ -8121,22 +8157,25 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
auto gV_tail_data = gV_tail.data(); auto gV_tail_data = gV_tail.data();
{ {
int cur_block_table; int cur_block_table;
const int *cur_block_table_ptr = block_table + (n_block); if constexpr (Identity_block_table) {
// cur_block_table = block_table[n_block - 1]; cur_block_table = n_block;
asm volatile("s_load_dword %1, %0, 0x0\n\t" } else {
"s_waitcnt lgkmcnt(0)\n\t": const int *cur_block_table_ptr = block_table + (n_block);
"+s"(cur_block_table_ptr), asm volatile("s_load_dword %1, %0, 0x0\n\t"
"=s"(cur_block_table)); "s_waitcnt lgkmcnt(0)\n\t":
index_t offset_k = cur_block_table * params.k_batch_stride; "+s"(cur_block_table_ptr),
index_t offset_v = cur_block_table * params.v_batch_stride; "=s"(cur_block_table));
}
index_t offset_k = cur_block_table * kv_batch_stride;
index_t offset_v = cur_block_table * v_batch_stride;
gK.data() = gK_data + (offset_k); gK.data() = gK_data + (offset_k);
gV.data() = gV_data + (offset_v); gV.data() = gV_data + (offset_v);
gV_tail.data() = gV_tail_data + (offset_v); gV_tail.data() = gV_tail_data + (offset_v);
} }
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
lds_direct_copy<Is_even_K, Is_even_MN, _64x32, 0, false>(gK, sK, 0, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN); lds_direct_copy<Is_even_K, Is_even_MN, _64x32, 0, true>(gK, sK, 0, kv_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x32, 0, false>(gK, sK, 1, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN); lds_direct_copy<Is_even_K, Is_even_MN, _64x32, 0, true>(gK, sK, 1, kv_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x32, 0, false>(gK, sK, 2, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN); lds_direct_copy<Is_even_K, Is_even_MN, _64x32, 0, true>(gK, sK, 2, kv_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
const bool Is_need_pad = binfo.actual_seqlen_k % 4 != 0; const bool Is_need_pad = binfo.actual_seqlen_k % 4 != 0;
#if 1 #if 1
...@@ -8145,24 +8184,24 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc ...@@ -8145,24 +8184,24 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
Tensor acc_s_ori = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); Tensor acc_s_ori = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
clear(acc_s_ori); clear(acc_s_ori);
{ {
lds_direct_copy<Is_even_K, Is_even_MN, _64x32, 0, false>(gK, sK, 3, params.k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN); lds_direct_copy<Is_even_K, Is_even_MN, _64x32, 0, true>(gK, sK, 3, kv_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
S_WAITCNT; S_WAITCNT;
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 0, 0); flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 0, 0);
S_BARRIER; S_BARRIER;
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, false>(gV, sV, 0, params.v_row_stride, params.d_value, binfo.actual_seqlen_k - n_block * kBlockN); lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV, sV, 0, v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, false>(gV, sV, 1, params.v_row_stride, params.d_value, binfo.actual_seqlen_k - n_block * kBlockN); lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV, sV, 1, v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
asm volatile("s_waitcnt vmcnt(4) \n s_barrier");; asm volatile("s_waitcnt vmcnt(4) \n s_barrier");;
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 1, 1); flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 1, 1);
S_BARRIER; S_BARRIER;
   
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, false>(gV, sV, 2, params.v_row_stride, params.d_value, binfo.actual_seqlen_k - n_block * kBlockN); lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV, sV, 2, v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, false>(gV, sV, 3, params.v_row_stride, params.d_value, binfo.actual_seqlen_k - n_block * kBlockN); lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV, sV, 3, v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
asm volatile("s_waitcnt vmcnt(5) \n s_barrier"); asm volatile("s_waitcnt vmcnt(5) \n s_barrier");
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 2, 2); flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 2, 2);
S_BARRIER; S_BARRIER;
   
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, false>(gV_tail, sV_tail, 0, params.v_row_stride, params.d_value, binfo.actual_seqlen_k - n_block * kBlockN); lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV_tail, sV_tail, 0, v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, false>(gV_tail, sV_tail, 1, params.v_row_stride, params.d_value, binfo.actual_seqlen_k - n_block * kBlockN); lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV_tail, sV_tail, 1, v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
asm volatile("s_waitcnt vmcnt(6) \n s_barrier"); asm volatile("s_waitcnt vmcnt(6) \n s_barrier");
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 3, 3); flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 3, 3);
S_BARRIER; S_BARRIER;
...@@ -8197,8 +8236,8 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc ...@@ -8197,8 +8236,8 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
{ {
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
// int token_id = n_block * kBlockN + ((tidx % 64) / 16) * 4; // int token_id = n_block * kBlockN + ((tidx % 64) / 16) * 4;
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, false>(gV_tail, sV_tail, 2, params.v_row_stride, params.d_value, binfo.actual_seqlen_k - n_block * kBlockN); lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV_tail, sV_tail, 2, v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, false>(gV_tail, sV_tail, 3, params.v_row_stride, params.d_value, binfo.actual_seqlen_k - n_block * kBlockN); lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV_tail, sV_tail, 3, v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
asm volatile("s_waitcnt vmcnt(7) \n s_barrier"); asm volatile("s_waitcnt vmcnt(7) \n s_barrier");
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
if (!Is_even_MN && Is_need_pad && masking_step == 0) { if (!Is_even_MN && Is_need_pad && masking_step == 0) {
...@@ -8246,21 +8285,24 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc ...@@ -8246,21 +8285,24 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
// gK.data() = gK.data() + (-int(kBlockN * params.k_row_stride)); // gK.data() = gK.data() + (-int(kBlockN * params.k_row_stride));
// gV.data() = gV.data() + (-int(kBlockN * params.v_row_stride)); // gV.data() = gV.data() + (-int(kBlockN * params.v_row_stride));
int cur_block_table; int cur_block_table;
const int *cur_block_table_ptr = block_table + (n_block - 1); if constexpr (Identity_block_table) {
// cur_block_table = block_table[n_block - 1]; cur_block_table = n_block - 1;
asm volatile("s_load_dword %1, %0, 0x0\n\t" } else {
"s_waitcnt lgkmcnt(0)\n\t": const int *cur_block_table_ptr = block_table + (n_block - 1);
"+s"(cur_block_table_ptr), asm volatile("s_load_dword %1, %0, 0x0\n\t"
"=s"(cur_block_table)); "s_waitcnt lgkmcnt(0)\n\t":
index_t offset_k = cur_block_table * params.k_batch_stride; "+s"(cur_block_table_ptr),
index_t offset_v = cur_block_table * params.v_batch_stride; "=s"(cur_block_table));
}
index_t offset_k = cur_block_table * kv_batch_stride;
index_t offset_v = cur_block_table * v_batch_stride;
gK.data() = gK_data + (offset_k); gK.data() = gK_data + (offset_k);
gV.data() = gV_data + (offset_v); gV.data() = gV_data + (offset_v);
gV_tail.data() = gV_tail_data + (offset_v); gV_tail.data() = gV_tail_data + (offset_v);
   
lds_direct_copy<Is_even_K, true, _64x32, 0, false>(gK, sK, 0, params.k_row_stride, params.d); lds_direct_copy<Is_even_K, true, _64x32, 0, true>(gK, sK, 0, kv_row_stride, params.d);
lds_direct_copy<Is_even_K, true, _64x32, 0, false>(gK, sK, 1, params.k_row_stride, params.d); lds_direct_copy<Is_even_K, true, _64x32, 0, true>(gK, sK, 1, kv_row_stride, params.d);
lds_direct_copy<Is_even_K, true, _64x32, 0, false>(gK, sK, 2, params.k_row_stride, params.d); lds_direct_copy<Is_even_K, true, _64x32, 0, true>(gK, sK, 2, kv_row_stride, params.d);
} }
if (n_masking_steps > 1 && n_block <= 0) { if (n_masking_steps > 1 && n_block <= 0) {
...@@ -8275,24 +8317,24 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc ...@@ -8275,24 +8317,24 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
clear(acc_s_ori); clear(acc_s_ori);
{ {
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
lds_direct_copy<Is_even_K, true, _64x32, 0, false>(gK, sK, 3, params.k_row_stride, params.d); lds_direct_copy<Is_even_K, true, _64x32, 0, true>(gK, sK, 3, kv_row_stride, params.d);
S_WAITCNT; S_WAITCNT;
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 0, 0); flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 0, 0);
S_BARRIER; S_BARRIER;
lds_direct_copy<Is_even_K, true, _64x16, 0, false>(gV, sV, 0, params.v_row_stride, params.d_value, binfo.actual_seqlen_k - n_block * kBlockN); lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV, sV, 0, v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, true, _64x16, 0, false>(gV, sV, 1, params.v_row_stride, params.d_value, binfo.actual_seqlen_k - n_block * kBlockN); lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV, sV, 1, v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
asm volatile("s_waitcnt vmcnt(4) \n s_barrier");; asm volatile("s_waitcnt vmcnt(4) \n s_barrier");;
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 1, 1); flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 1, 1);
S_BARRIER; S_BARRIER;
   
lds_direct_copy<Is_even_K, true, _64x16, 0, false>(gV, sV, 2, params.v_row_stride, params.d_value, binfo.actual_seqlen_k - n_block * kBlockN); lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV, sV, 2, v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, true, _64x16, 0, false>(gV, sV, 3, params.v_row_stride, params.d_value, binfo.actual_seqlen_k - n_block * kBlockN); lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV, sV, 3, v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
asm volatile("s_waitcnt vmcnt(5) \n s_barrier"); asm volatile("s_waitcnt vmcnt(5) \n s_barrier");
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 2, 2); flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 2, 2);
S_BARRIER; S_BARRIER;
   
lds_direct_copy<Is_even_K, true, _64x16, 0, false>(gV_tail, sV_tail, 0, params.v_row_stride, params.d_value, binfo.actual_seqlen_k - n_block * kBlockN); lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV_tail, sV_tail, 0, v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, true, _64x16, 0, false>(gV_tail, sV_tail, 1, params.v_row_stride, params.d_value, binfo.actual_seqlen_k - n_block * kBlockN); lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV_tail, sV_tail, 1, v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
asm volatile("s_waitcnt vmcnt(6) \n s_barrier"); asm volatile("s_waitcnt vmcnt(6) \n s_barrier");
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 3, 3); flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 3, 3);
S_BARRIER; S_BARRIER;
...@@ -8322,8 +8364,8 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc ...@@ -8322,8 +8364,8 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
{ {
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
S_BARRIER; S_BARRIER;
lds_direct_copy<Is_even_K, true, _64x16, 0, false>(gV_tail, sV_tail, 2, params.v_row_stride, params.d); lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV_tail, sV_tail, 2, v_row_stride, 128);
lds_direct_copy<Is_even_K, true, _64x16, 0, false>(gV_tail, sV_tail, 3, params.v_row_stride, params.d); lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV_tail, sV_tail, 3, v_row_stride, 128);
asm volatile("s_waitcnt vmcnt(7) \n s_barrier"); asm volatile("s_waitcnt vmcnt(7) \n s_barrier");
flash::gemm_k_rs(acc_o_ori, rP, tOrV, tSsV, tiled_mma_gemm1, smem_tiled_copy_V, smem_thr_copy_V, 0, 0); flash::gemm_k_rs(acc_o_ori, rP, tOrV, tSsV, tiled_mma_gemm1, smem_tiled_copy_V, smem_thr_copy_V, 0, 0);
...@@ -8347,21 +8389,24 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc ...@@ -8347,21 +8389,24 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
   
if (n_block > 0) { if (n_block > 0) {
int cur_block_table; int cur_block_table;
const int *cur_block_table_ptr = block_table + (n_block - 1); if constexpr (Identity_block_table) {
// cur_block_table = block_table[n_block - 1]; cur_block_table = n_block - 1;
asm volatile("s_load_dword %1, %0, 0x0\n\t" } else {
"s_waitcnt lgkmcnt(0)\n\t": const int *cur_block_table_ptr = block_table + (n_block - 1);
"+s"(cur_block_table_ptr), asm volatile("s_load_dword %1, %0, 0x0\n\t"
"=s"(cur_block_table)); "s_waitcnt lgkmcnt(0)\n\t":
index_t offset_k = cur_block_table * params.k_batch_stride; "+s"(cur_block_table_ptr),
index_t offset_v = cur_block_table * params.v_batch_stride; "=s"(cur_block_table));
}
index_t offset_k = cur_block_table * kv_batch_stride;
index_t offset_v = cur_block_table * v_batch_stride;
gK.data() = gK_data + (offset_k); gK.data() = gK_data + (offset_k);
gV.data() = gV_data + (offset_v); gV.data() = gV_data + (offset_v);
gV_tail.data() = gV_tail_data + (offset_v); gV_tail.data() = gV_tail_data + (offset_v);
#pragma unroll #pragma unroll
for (int i = 0; i < kStages; ++i) { for (int i = 0; i < kStages; ++i) {
lds_direct_copy<Is_even_K, true, _64x32, 0, false>(gK, sK, i, params.k_row_stride, params.d); lds_direct_copy<Is_even_K, true, _64x32, 0, true>(gK, sK, i, kv_row_stride, params.d);
} }
} }
   
...@@ -8391,7 +8436,7 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc ...@@ -8391,7 +8436,7 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
} }
const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
+ m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + (Contiguous_qo_layout ? m_block * kBlockM * params.o_row_stride + bidh * 128 : m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride);
const index_t row_offset_lseaccum = (Split || !params.unpadded_lse ? const index_t row_offset_lseaccum = (Split || !params.unpadded_lse ?
( bidb * params.h + bidh) * params.seqlen_q : bidh * params.seqlen_q + binfo.q_offset(params.seqlen_q, 1, bidb) ( bidb * params.h + bidh) * params.seqlen_q : bidh * params.seqlen_q + binfo.q_offset(params.seqlen_q, 1, bidb)
) + m_block * kBlockM; ) + m_block * kBlockM;
...@@ -8455,7 +8500,7 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc ...@@ -8455,7 +8500,7 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
#endif #endif
} }
   
template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, bool Has_block_table, typename Params> template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, bool Has_block_table, bool Identity_block_table, bool Contiguous_qo_layout, bool Contiguous_kv_layout, typename Params>
inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetch_fp8(const Params &params, const int bidb, const int bidh, const int m_block) { inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetch_fp8(const Params &params, const int bidb, const int bidh, const int m_block) {
#if 1 #if 1
using ElementAccum = typename Kernel_traits::ElementAccum; using ElementAccum = typename Kernel_traits::ElementAccum;
...@@ -16087,7 +16132,7 @@ inline __device__ void compute_attn_splitkv(const Params &params) { ...@@ -16087,7 +16132,7 @@ inline __device__ void compute_attn_splitkv(const Params &params) {
   
   
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, bool Has_block_table, typename Params> template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, bool Has_block_table, bool Identity_block_table=false, bool Contiguous_qo_layout=false, bool Contiguous_kv_layout=false, typename Params>
inline __device__ void compute_attn_splitkv_16x64_vllm_kvcache_prefetch(const Params &params) { inline __device__ void compute_attn_splitkv_16x64_vllm_kvcache_prefetch(const Params &params) {
const int m_block = Is_causal?gridDim.z - 1 - blockIdx.z:blockIdx.x; const int m_block = Is_causal?gridDim.z - 1 - blockIdx.z:blockIdx.x;
const int bidb = Is_causal?blockIdx.y:blockIdx.z; const int bidb = Is_causal?blockIdx.y:blockIdx.z;
...@@ -16096,7 +16141,7 @@ inline __device__ void compute_attn_splitkv_16x64_vllm_kvcache_prefetch(const Pa ...@@ -16096,7 +16141,7 @@ inline __device__ void compute_attn_splitkv_16x64_vllm_kvcache_prefetch(const Pa
if constexpr (Kernel_traits::kHeadDim == 64){ if constexpr (Kernel_traits::kHeadDim == 64){
flash::compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetch_dim64<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV, Has_block_table>(params, bidb, bidh, m_block); flash::compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetch_dim64<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV, Has_block_table>(params, bidb, bidh, m_block);
}else if constexpr (Kernel_traits::kHeadDim == 128) { }else if constexpr (Kernel_traits::kHeadDim == 128) {
flash::compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetch<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV, Has_block_table>(params, bidb, bidh, m_block); flash::compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetch<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV, Has_block_table, Identity_block_table, Contiguous_qo_layout, Contiguous_kv_layout>(params, bidb, bidh, m_block);
}else if constexpr (Kernel_traits::kHeadDim == 192) { }else if constexpr (Kernel_traits::kHeadDim == 192) {
flash::compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetch_dim192<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV, Has_block_table>(params, bidb, bidh, m_block); flash::compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetch_dim192<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV, Has_block_table>(params, bidb, bidh, m_block);
}else if constexpr (Kernel_traits::kHeadDim == 256) { }else if constexpr (Kernel_traits::kHeadDim == 256) {
...@@ -16105,7 +16150,7 @@ inline __device__ void compute_attn_splitkv_16x64_vllm_kvcache_prefetch(const Pa ...@@ -16105,7 +16150,7 @@ inline __device__ void compute_attn_splitkv_16x64_vllm_kvcache_prefetch(const Pa
#endif #endif
   
} }
template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, bool Has_block_table, typename Params> template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, bool Has_block_table, bool Identity_block_table=false, bool Contiguous_qo_layout=false, bool Contiguous_kv_layout=false, typename Params>
inline __device__ void compute_attn_splitkv_16x64_vllm_kvcache_prefetch_kv_fp8(const Params &params) { inline __device__ void compute_attn_splitkv_16x64_vllm_kvcache_prefetch_kv_fp8(const Params &params) {
const int m_block = Is_causal?gridDim.z - 1 - blockIdx.z:blockIdx.x; const int m_block = Is_causal?gridDim.z - 1 - blockIdx.z:blockIdx.x;
const int bidb = Is_causal?blockIdx.y:blockIdx.z; const int bidb = Is_causal?blockIdx.y:blockIdx.z;
......
...@@ -90,9 +90,9 @@ DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_lo ...@@ -90,9 +90,9 @@ DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_lo
} }
DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel_16x64_vllm_kvcache_prefetch, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, bool Has_block_table) { DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel_16x64_vllm_kvcache_prefetch, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, bool Has_block_table, bool Identity_block_table=false, bool Contiguous_qo_layout=false, bool Contiguous_kv_layout=false) {
#if defined(ARCH_SUPPORTS_FLASH) #if defined(ARCH_SUPPORTS_FLASH)
flash::compute_attn_splitkv_16x64_vllm_kvcache_prefetch<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV, Has_block_table>(params); flash::compute_attn_splitkv_16x64_vllm_kvcache_prefetch<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV, Has_block_table, Identity_block_table, Contiguous_qo_layout, Contiguous_kv_layout>(params);
#else #else
FLASH_UNSUPPORTED_ARCH FLASH_UNSUPPORTED_ARCH
#endif #endif
...@@ -355,7 +355,22 @@ void run_flash_splitkv_fwd_16x64_vllm_kvcache_prefetch(Flash_fwd_params &params, ...@@ -355,7 +355,22 @@ void run_flash_splitkv_fwd_16x64_vllm_kvcache_prefetch(Flash_fwd_params &params,
dim3 grid; dim3 grid;
if constexpr(Is_causal)grid=dim3(params.h, params.b,num_m_block); if constexpr(Is_causal)grid=dim3(params.h, params.b,num_m_block);
else grid=dim3(num_m_block,params.h, params.b); else grid=dim3(num_m_block,params.h, params.b);
const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0; const bool is_even_MN = ((params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr) || params.b == 1) &&
params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
const bool has_paged_kv_layout = params.d == 128 && params.d_value == 128 && params.page_block_size == 64 &&
params.block_table != nullptr && params.cache_batch_idx == nullptr;
const bool is_identity_block_table = has_paged_kv_layout && params.b == 1 && params.h == 6 && params.h_k == 1 &&
params.seqlen_q == 10240 && params.seqlen_k == 51200 &&
params.block_table_batch_stride == cute::ceil_div(params.seqlen_k, params.page_block_size);
const bool is_contiguous_qo_layout = has_paged_kv_layout &&
params.q_head_stride == params.d && params.o_head_stride == params.d_value &&
params.q_row_stride == params.h * params.d && params.o_row_stride == params.h * params.d_value;
const bool is_contiguous_kv_layout = has_paged_kv_layout &&
params.k_row_stride == params.d && params.v_row_stride == params.page_block_size &&
params.k_head_stride == params.page_block_size * params.d &&
params.v_head_stride == params.d_value * params.page_block_size &&
params.k_batch_stride == params.h_k * params.page_block_size * params.d &&
params.v_batch_stride == params.h_k * params.d_value * params.page_block_size;
const bool is_even_K = true; const bool is_even_K = true;
// const bool is_even_K = params.d == Kernel_traits::kHeadDim; // const bool is_even_K = params.d == Kernel_traits::kHeadDim;
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
...@@ -370,7 +385,19 @@ void run_flash_splitkv_fwd_16x64_vllm_kvcache_prefetch(Flash_fwd_params &params, ...@@ -370,7 +385,19 @@ void run_flash_splitkv_fwd_16x64_vllm_kvcache_prefetch(Flash_fwd_params &params,
constexpr static bool Has_block_table = true; constexpr static bool Has_block_table = true;
constexpr static bool Append_KV = false; constexpr static bool Append_KV = false;
constexpr static bool Split = false; constexpr static bool Split = false;
auto kernel = &flash_fwd_splitkv_kernel_16x64_vllm_kvcache_prefetch<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split, Append_KV, Has_block_table>; auto kernel = is_identity_block_table
? (is_contiguous_qo_layout && is_contiguous_kv_layout
? &flash_fwd_splitkv_kernel_16x64_vllm_kvcache_prefetch<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split, Append_KV, Has_block_table, true, true, true>
: (is_contiguous_qo_layout
? &flash_fwd_splitkv_kernel_16x64_vllm_kvcache_prefetch<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split, Append_KV, Has_block_table, true, true, false>
: &flash_fwd_splitkv_kernel_16x64_vllm_kvcache_prefetch<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split, Append_KV, Has_block_table, true, false, false>))
: (is_contiguous_qo_layout && is_contiguous_kv_layout
? &flash_fwd_splitkv_kernel_16x64_vllm_kvcache_prefetch<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split, Append_KV, Has_block_table, false, true, true>
: (is_contiguous_qo_layout
? &flash_fwd_splitkv_kernel_16x64_vllm_kvcache_prefetch<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split, Append_KV, Has_block_table, false, true, false>
: (is_contiguous_kv_layout
? &flash_fwd_splitkv_kernel_16x64_vllm_kvcache_prefetch<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split, Append_KV, Has_block_table, false, false, true>
: &flash_fwd_splitkv_kernel_16x64_vllm_kvcache_prefetch<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split, Append_KV, Has_block_table, false, false, false>)));
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params); kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
}); });
}); });
...@@ -610,7 +637,7 @@ void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream) ...@@ -610,7 +637,7 @@ void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream)
} else { } else {
using prefetch_kernel_traits = Flash_fwd_kernel_16x64_splitkv_prefetch_vllm_kvcache_traits<128, 128, 64, 4, T, 3, 128>; using prefetch_kernel_traits = Flash_fwd_kernel_16x64_splitkv_prefetch_vllm_kvcache_traits<128, 128, 64, 4, T, 3, 128>;
run_flash_splitkv_fwd_16x64_vllm_kvcache_prefetch<prefetch_kernel_traits, combine_kernel_traits, Is_causal>(params, stream); run_flash_splitkv_fwd_16x64_vllm_kvcache_prefetch<prefetch_kernel_traits, combine_kernel_traits, Is_causal>(params, stream);
} }
} }
else { else {
using combine_kernel_traits = Flash_fwd_kernel_16x64_traits_splitkv<128, kBlockM, kBlockN, 4, false, false, T, 128>; using combine_kernel_traits = Flash_fwd_kernel_16x64_traits_splitkv<128, kBlockM, kBlockN, 4, false, false, T, 128>;
......
version = '2.8.3'
git_hash = 'c2a1b31'
git_branch = 'master'
abi = 'abi1'
dtk = '2604'
torch_version = '2.9'
hcu_version = '2.8.3+das.opt0.dtk2604'
...@@ -944,6 +944,7 @@ if not SKIP_CUDA_BUILD: ...@@ -944,6 +944,7 @@ if not SKIP_CUDA_BUILD:
Path(this_dir) / "csrc" / "flash_attn", Path(this_dir) / "csrc" / "flash_attn",
Path(this_dir) / "csrc" / "flash_attn" / "src", Path(this_dir) / "csrc" / "flash_attn" / "src",
Path(this_dir) / "csrc" / "cutlass" / "include", Path(this_dir) / "csrc" / "cutlass" / "include",
],
) )
) )
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment