Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
fuhuangpei
flash-attention
Commits
484743c6
Commit
484743c6
authored
Jun 04, 2026
by
fuhuangpei
Browse files
del per compile params
parent
c119c0f0
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
72 additions
and
147 deletions
+72
-147
csrc/flash_attn/src/flash_fwd_kernel.h
csrc/flash_attn/src/flash_fwd_kernel.h
+68
-116
csrc/flash_attn/src/flash_fwd_launch_template.h
csrc/flash_attn/src/flash_fwd_launch_template.h
+4
-31
No files found.
csrc/flash_attn/src/flash_fwd_kernel.h
View file @
484743c6
...
@@ -7968,7 +7968,7 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_mla(const Params &pa
...
@@ -7968,7 +7968,7 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_mla(const Params &pa
// }
// }
}
}
template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, bool Has_block_table,
bool Identity_block_table, bool Contiguous_qo_layout, bool Contiguous_kv_layout,
typename Params>
template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, bool Has_block_table, typename Params>
inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetch(const Params ¶ms, const int bidb, const int bidh, const int m_block) {
inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetch(const Params ¶ms, const int bidb, const int bidh, const int m_block) {
#if 1
#if 1
using ElementAccum = typename Kernel_traits::ElementAccum;
using ElementAccum = typename Kernel_traits::ElementAccum;
...
@@ -8014,66 +8014,30 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
...
@@ -8014,66 +8014,30 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
// We move K and V to the last block.
// We move K and V to the last block.
const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb];
const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb];
const int *block_table = !Has_block_table? nullptr : params.block_table + bidb * params.block_table_batch_stride;
const int *block_table = !Has_block_table? nullptr : params.block_table + bidb * params.block_table_batch_stride;
const int kv_row_stride = Contiguous_kv_layout ? 128 : params.k_row_stride;
const int v_row_stride = Contiguous_kv_layout ? 64 : params.v_row_stride;
const int kv_batch_stride = params.k_batch_stride;
const int v_batch_stride = params.v_batch_stride;
const index_t row_offset_k = !Has_block_table
const index_t row_offset_k = !Has_block_table
? binfo.k_offset(k
v
_batch_stride, k
v
_row_stride, bidb_cache)
? binfo.k_offset(
params.
k_batch_stride,
params.
k_row_stride, bidb_cache)
+ (n_block_max - 1) * kBlockN * k
v
_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride
+ (n_block_max - 1) * kBlockN *
params.
k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride
: (bidh / params.h_h_k_ratio) * params.k_head_stride;
: (bidh / params.h_h_k_ratio) * params.k_head_stride;
const index_t row_offset_v = !Has_block_table
const index_t row_offset_v = !Has_block_table
? binfo.k_offset(v_batch_stride, v_row_stride, bidb_cache)
? binfo.k_offset(
params.
v_batch_stride,
params.
v_row_stride, bidb_cache)
+ (n_block_max - 1) * kBlockN * v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride
+ (n_block_max - 1) * kBlockN *
params.
v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride
: (bidh / params.h_h_k_ratio) * params.v_head_stride;
: (bidh / params.h_h_k_ratio) * params.v_head_stride;
Tensor mQ = [&] {
Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.q_ptr) + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)),
if constexpr (Contiguous_qo_layout) {
make_shape(binfo.actual_seqlen_q, params.h, params.d),
return make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.q_ptr) + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)),
make_stride(params.q_row_stride, params.q_head_stride, _1{}));
make_shape(binfo.actual_seqlen_q, params.h, params.d),
make_stride(params.q_row_stride, Int<128>{}, _1{}));
} else {
return make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.q_ptr) + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)),
make_shape(binfo.actual_seqlen_q, params.h, params.d),
make_stride(params.q_row_stride, params.q_head_stride, _1{}));
}
}();
Tensor gQ = local_tile(mQ(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
Tensor gQ = local_tile(mQ(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_coord(m_block, 0)); // (kBlockM, kHeadDim)
make_coord(m_block, 0)); // (kBlockM, kHeadDim)
Tensor gK = [&] {
Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
if constexpr (Contiguous_kv_layout) {
Shape<Int<kBlockN>, Int<kHeadDim>>{},
return make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
make_stride(params.k_row_stride, _1{}));
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(Int<128>{}, _1{}));
} else {
return make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.k_row_stride, _1{}));
}
}();
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("k_ptr = %p, row_offset_k = %d, gK_ptr = %p\n", params.k_ptr, row_offset_k, gK.data()); }
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("k_ptr = %p, row_offset_k = %d, gK_ptr = %p\n", params.k_ptr, row_offset_k, gK.data()); }
Tensor gV = [&] {
Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
if constexpr (Contiguous_kv_layout) {
Shape<Int<64>, Int<kBlockN>>{},
return make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
make_stride(params.v_row_stride, _1{}));
Shape<Int<64>, Int<kBlockN>>{},
Tensor gV_tail = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v + 64 * params.v_row_stride),
make_stride(Int<64>{}, _1{}));
Shape<Int<64>, Int<kBlockN>>{},
} else {
make_stride(params.v_row_stride, _1{}));
return make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
Shape<Int<64>, Int<kBlockN>>{},
make_stride(params.v_row_stride, _1{}));
}
}();
Tensor gV_tail = [&] {
if constexpr (Contiguous_kv_layout) {
return make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v + 64 * 64),
Shape<Int<64>, Int<kBlockN>>{},
make_stride(Int<64>{}, _1{}));
} else {
return make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v + 64 * params.v_row_stride),
Shape<Int<64>, Int<kBlockN>>{},
make_stride(params.v_row_stride, _1{}));
}
}();
Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
typename Kernel_traits::SmemLayoutQ{});
typename Kernel_traits::SmemLayoutQ{});
Tensor sK = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)), typename Kernel_traits::SmemLayoutKV{});
Tensor sK = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)), typename Kernel_traits::SmemLayoutKV{});
...
@@ -8157,25 +8121,21 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
...
@@ -8157,25 +8121,21 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
auto gV_tail_data = gV_tail.data();
auto gV_tail_data = gV_tail.data();
{
{
int cur_block_table;
int cur_block_table;
if constexpr (Identity_block_table) {
const int *cur_block_table_ptr = block_table + (n_block);
cur_block_table = n_block;
asm volatile("s_load_dword %1, %0, 0x0\n\t"
} else {
"s_waitcnt lgkmcnt(0)\n\t":
const int *cur_block_table_ptr = block_table + (n_block);
"+s"(cur_block_table_ptr),
asm volatile("s_load_dword %1, %0, 0x0\n\t"
"=s"(cur_block_table));
"s_waitcnt lgkmcnt(0)\n\t":
index_t offset_k = cur_block_table * params.k_batch_stride;
"+s"(cur_block_table_ptr),
index_t offset_v = cur_block_table * params.v_batch_stride;
"=s"(cur_block_table));
}
index_t offset_k = cur_block_table * kv_batch_stride;
index_t offset_v = cur_block_table * v_batch_stride;
gK.data() = gK_data + (offset_k);
gK.data() = gK_data + (offset_k);
gV.data() = gV_data + (offset_v);
gV.data() = gV_data + (offset_v);
gV_tail.data() = gV_tail_data + (offset_v);
gV_tail.data() = gV_tail_data + (offset_v);
}
}
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_sched_barrier(0);
lds_direct_copy<Is_even_K, Is_even_MN, _64x32, 0, true>(gK, sK, 0, k
v
_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x32, 0, true>(gK, sK, 0,
params.
k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x32, 0, true>(gK, sK, 1, k
v
_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x32, 0, true>(gK, sK, 1,
params.
k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x32, 0, true>(gK, sK, 2, k
v
_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x32, 0, true>(gK, sK, 2,
params.
k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_sched_barrier(0);
const bool Is_need_pad = binfo.actual_seqlen_k % 4 != 0;
const bool Is_need_pad = binfo.actual_seqlen_k % 4 != 0;
#if 1
#if 1
...
@@ -8184,24 +8144,24 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
...
@@ -8184,24 +8144,24 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
Tensor acc_s_ori = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
Tensor acc_s_ori = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
clear(acc_s_ori);
clear(acc_s_ori);
{
{
lds_direct_copy<Is_even_K, Is_even_MN, _64x32, 0, true>(gK, sK, 3, k
v
_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x32, 0, true>(gK, sK, 3,
params.
k_row_stride, params.d, binfo.actual_seqlen_k - n_block * kBlockN);
S_WAITCNT;
S_WAITCNT;
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 0, 0);
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 0, 0);
S_BARRIER;
S_BARRIER;
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV, sV, 0, v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV, sV, 0,
params.
v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV, sV, 1, v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV, sV, 1,
params.
v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
asm volatile("s_waitcnt vmcnt(4) \n s_barrier");;
asm volatile("s_waitcnt vmcnt(4) \n s_barrier");;
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 1, 1);
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 1, 1);
S_BARRIER;
S_BARRIER;
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV, sV, 2, v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV, sV, 2,
params.
v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV, sV, 3, v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV, sV, 3,
params.
v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
asm volatile("s_waitcnt vmcnt(5) \n s_barrier");
asm volatile("s_waitcnt vmcnt(5) \n s_barrier");
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 2, 2);
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 2, 2);
S_BARRIER;
S_BARRIER;
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV_tail, sV_tail, 0, v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV_tail, sV_tail, 0,
params.
v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV_tail, sV_tail, 1, v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV_tail, sV_tail, 1,
params.
v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
asm volatile("s_waitcnt vmcnt(6) \n s_barrier");
asm volatile("s_waitcnt vmcnt(6) \n s_barrier");
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 3, 3);
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 3, 3);
S_BARRIER;
S_BARRIER;
...
@@ -8236,8 +8196,8 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
...
@@ -8236,8 +8196,8 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
{
{
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_sched_barrier(0);
// int token_id = n_block * kBlockN + ((tidx % 64) / 16) * 4;
// int token_id = n_block * kBlockN + ((tidx % 64) / 16) * 4;
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV_tail, sV_tail, 2, v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV_tail, sV_tail, 2,
params.
v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV_tail, sV_tail, 3, v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV_tail, sV_tail, 3,
params.
v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
asm volatile("s_waitcnt vmcnt(7) \n s_barrier");
asm volatile("s_waitcnt vmcnt(7) \n s_barrier");
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_sched_barrier(0);
if (!Is_even_MN && Is_need_pad && masking_step == 0) {
if (!Is_even_MN && Is_need_pad && masking_step == 0) {
...
@@ -8285,24 +8245,20 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
...
@@ -8285,24 +8245,20 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
// gK.data() = gK.data() + (-int(kBlockN * params.k_row_stride));
// gK.data() = gK.data() + (-int(kBlockN * params.k_row_stride));
// gV.data() = gV.data() + (-int(kBlockN * params.v_row_stride));
// gV.data() = gV.data() + (-int(kBlockN * params.v_row_stride));
int cur_block_table;
int cur_block_table;
if constexpr (Identity_block_table) {
const int *cur_block_table_ptr = block_table + (n_block - 1);
cur_block_table = n_block - 1;
asm volatile("s_load_dword %1, %0, 0x0\n\t"
} else {
"s_waitcnt lgkmcnt(0)\n\t":
const int *cur_block_table_ptr = block_table + (n_block - 1);
"+s"(cur_block_table_ptr),
asm volatile("s_load_dword %1, %0, 0x0\n\t"
"=s"(cur_block_table));
"s_waitcnt lgkmcnt(0)\n\t":
index_t offset_k = cur_block_table * params.k_batch_stride;
"+s"(cur_block_table_ptr),
index_t offset_v = cur_block_table * params.v_batch_stride;
"=s"(cur_block_table));
}
index_t offset_k = cur_block_table * kv_batch_stride;
index_t offset_v = cur_block_table * v_batch_stride;
gK.data() = gK_data + (offset_k);
gK.data() = gK_data + (offset_k);
gV.data() = gV_data + (offset_v);
gV.data() = gV_data + (offset_v);
gV_tail.data() = gV_tail_data + (offset_v);
gV_tail.data() = gV_tail_data + (offset_v);
lds_direct_copy<Is_even_K, true, _64x32, 0, true>(gK, sK, 0, k
v
_row_stride, params.d);
lds_direct_copy<Is_even_K, true, _64x32, 0, true>(gK, sK, 0,
params.
k_row_stride, params.d);
lds_direct_copy<Is_even_K, true, _64x32, 0, true>(gK, sK, 1, k
v
_row_stride, params.d);
lds_direct_copy<Is_even_K, true, _64x32, 0, true>(gK, sK, 1,
params.
k_row_stride, params.d);
lds_direct_copy<Is_even_K, true, _64x32, 0, true>(gK, sK, 2, k
v
_row_stride, params.d);
lds_direct_copy<Is_even_K, true, _64x32, 0, true>(gK, sK, 2,
params.
k_row_stride, params.d);
}
}
if (n_masking_steps > 1 && n_block <= 0) {
if (n_masking_steps > 1 && n_block <= 0) {
...
@@ -8317,24 +8273,24 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
...
@@ -8317,24 +8273,24 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
clear(acc_s_ori);
clear(acc_s_ori);
{
{
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_sched_barrier(0);
lds_direct_copy<Is_even_K, true, _64x32, 0, true>(gK, sK, 3, k
v
_row_stride, params.d);
lds_direct_copy<Is_even_K, true, _64x32, 0, true>(gK, sK, 3,
params.
k_row_stride, params.d);
S_WAITCNT;
S_WAITCNT;
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 0, 0);
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 0, 0);
S_BARRIER;
S_BARRIER;
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV, sV, 0, v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV, sV, 0,
params.
v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV, sV, 1, v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV, sV, 1,
params.
v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
asm volatile("s_waitcnt vmcnt(4) \n s_barrier");;
asm volatile("s_waitcnt vmcnt(4) \n s_barrier");;
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 1, 1);
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 1, 1);
S_BARRIER;
S_BARRIER;
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV, sV, 2, v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV, sV, 2,
params.
v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV, sV, 3, v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV, sV, 3,
params.
v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
asm volatile("s_waitcnt vmcnt(5) \n s_barrier");
asm volatile("s_waitcnt vmcnt(5) \n s_barrier");
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 2, 2);
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 2, 2);
S_BARRIER;
S_BARRIER;
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV_tail, sV_tail, 0, v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV_tail, sV_tail, 0,
params.
v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV_tail, sV_tail, 1, v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV_tail, sV_tail, 1,
params.
v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
asm volatile("s_waitcnt vmcnt(6) \n s_barrier");
asm volatile("s_waitcnt vmcnt(6) \n s_barrier");
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 3, 3);
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 3, 3);
S_BARRIER;
S_BARRIER;
...
@@ -8364,8 +8320,8 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
...
@@ -8364,8 +8320,8 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
{
{
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_sched_barrier(0);
S_BARRIER;
S_BARRIER;
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV_tail, sV_tail, 2, v_row_stride, 128);
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV_tail, sV_tail, 2,
params.
v_row_stride, 128);
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV_tail, sV_tail, 3, v_row_stride, 128);
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV_tail, sV_tail, 3,
params.
v_row_stride, 128);
asm volatile("s_waitcnt vmcnt(7) \n s_barrier");
asm volatile("s_waitcnt vmcnt(7) \n s_barrier");
flash::gemm_k_rs(acc_o_ori, rP, tOrV, tSsV, tiled_mma_gemm1, smem_tiled_copy_V, smem_thr_copy_V, 0, 0);
flash::gemm_k_rs(acc_o_ori, rP, tOrV, tSsV, tiled_mma_gemm1, smem_tiled_copy_V, smem_thr_copy_V, 0, 0);
...
@@ -8389,24 +8345,20 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
...
@@ -8389,24 +8345,20 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
if (n_block > 0) {
if (n_block > 0) {
int cur_block_table;
int cur_block_table;
if constexpr (Identity_block_table) {
const int *cur_block_table_ptr = block_table + (n_block - 1);
cur_block_table = n_block - 1;
asm volatile("s_load_dword %1, %0, 0x0\n\t"
} else {
"s_waitcnt lgkmcnt(0)\n\t":
const int *cur_block_table_ptr = block_table + (n_block - 1);
"+s"(cur_block_table_ptr),
asm volatile("s_load_dword %1, %0, 0x0\n\t"
"=s"(cur_block_table));
"s_waitcnt lgkmcnt(0)\n\t":
index_t offset_k = cur_block_table * params.k_batch_stride;
"+s"(cur_block_table_ptr),
index_t offset_v = cur_block_table * params.v_batch_stride;
"=s"(cur_block_table));
}
index_t offset_k = cur_block_table * kv_batch_stride;
index_t offset_v = cur_block_table * v_batch_stride;
gK.data() = gK_data + (offset_k);
gK.data() = gK_data + (offset_k);
gV.data() = gV_data + (offset_v);
gV.data() = gV_data + (offset_v);
gV_tail.data() = gV_tail_data + (offset_v);
gV_tail.data() = gV_tail_data + (offset_v);
#pragma unroll
#pragma unroll
for (int i = 0; i < kStages; ++i) {
for (int i = 0; i < kStages; ++i) {
lds_direct_copy<Is_even_K, true, _64x32, 0, true>(gK, sK, i, k
v
_row_stride, params.d);
lds_direct_copy<Is_even_K, true, _64x32, 0, true>(gK, sK, i,
params.
k_row_stride, params.d);
}
}
}
}
...
@@ -8436,7 +8388,7 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
...
@@ -8436,7 +8388,7 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
}
}
const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
+
(Contiguous_qo_layout ? m_block * kBlockM * params.o_row_stride + bidh * 128 :
m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride
)
;
+ m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
const index_t row_offset_lseaccum = (Split || !params.unpadded_lse ?
const index_t row_offset_lseaccum = (Split || !params.unpadded_lse ?
( bidb * params.h + bidh) * params.seqlen_q : bidh * params.seqlen_q + binfo.q_offset(params.seqlen_q, 1, bidb)
( bidb * params.h + bidh) * params.seqlen_q : bidh * params.seqlen_q + binfo.q_offset(params.seqlen_q, 1, bidb)
) + m_block * kBlockM;
) + m_block * kBlockM;
...
@@ -8500,7 +8452,7 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
...
@@ -8500,7 +8452,7 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
#endif
#endif
}
}
template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, bool Has_block_table,
bool Identity_block_table, bool Contiguous_qo_layout, bool Contiguous_kv_layout,
typename Params>
template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, bool Has_block_table, typename Params>
inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetch_fp8(const Params ¶ms, const int bidb, const int bidh, const int m_block) {
inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetch_fp8(const Params ¶ms, const int bidb, const int bidh, const int m_block) {
#if 1
#if 1
using ElementAccum = typename Kernel_traits::ElementAccum;
using ElementAccum = typename Kernel_traits::ElementAccum;
...
@@ -16132,7 +16084,7 @@ inline __device__ void compute_attn_splitkv(const Params ¶ms) {
...
@@ -16132,7 +16084,7 @@ inline __device__ void compute_attn_splitkv(const Params ¶ms) {
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, bool Has_block_table,
bool Identity_block_table=false, bool Contiguous_qo_layout=false, bool Contiguous_kv_layout=false,
typename Params>
template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, bool Has_block_table, typename Params>
inline __device__ void compute_attn_splitkv_16x64_vllm_kvcache_prefetch(const Params ¶ms) {
inline __device__ void compute_attn_splitkv_16x64_vllm_kvcache_prefetch(const Params ¶ms) {
const int m_block = Is_causal?gridDim.z - 1 - blockIdx.z:blockIdx.x;
const int m_block = Is_causal?gridDim.z - 1 - blockIdx.z:blockIdx.x;
const int bidb = Is_causal?blockIdx.y:blockIdx.z;
const int bidb = Is_causal?blockIdx.y:blockIdx.z;
...
@@ -16141,7 +16093,7 @@ inline __device__ void compute_attn_splitkv_16x64_vllm_kvcache_prefetch(const Pa
...
@@ -16141,7 +16093,7 @@ inline __device__ void compute_attn_splitkv_16x64_vllm_kvcache_prefetch(const Pa
if constexpr (Kernel_traits::kHeadDim == 64){
if constexpr (Kernel_traits::kHeadDim == 64){
flash::compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetch_dim64<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV, Has_block_table>(params, bidb, bidh, m_block);
flash::compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetch_dim64<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV, Has_block_table>(params, bidb, bidh, m_block);
}else if constexpr (Kernel_traits::kHeadDim == 128) {
}else if constexpr (Kernel_traits::kHeadDim == 128) {
flash::compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetch<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV, Has_block_table
, Identity_block_table, Contiguous_qo_layout, Contiguous_kv_layout
>(params, bidb, bidh, m_block);
flash::compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetch<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV, Has_block_table>(params, bidb, bidh, m_block);
}else if constexpr (Kernel_traits::kHeadDim == 192) {
}else if constexpr (Kernel_traits::kHeadDim == 192) {
flash::compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetch_dim192<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV, Has_block_table>(params, bidb, bidh, m_block);
flash::compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetch_dim192<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV, Has_block_table>(params, bidb, bidh, m_block);
}else if constexpr (Kernel_traits::kHeadDim == 256) {
}else if constexpr (Kernel_traits::kHeadDim == 256) {
...
@@ -16150,7 +16102,7 @@ inline __device__ void compute_attn_splitkv_16x64_vllm_kvcache_prefetch(const Pa
...
@@ -16150,7 +16102,7 @@ inline __device__ void compute_attn_splitkv_16x64_vllm_kvcache_prefetch(const Pa
#endif
#endif
}
}
template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, bool Has_block_table,
bool Identity_block_table=false, bool Contiguous_qo_layout=false, bool Contiguous_kv_layout=false,
typename Params>
template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, bool Has_block_table, typename Params>
inline __device__ void compute_attn_splitkv_16x64_vllm_kvcache_prefetch_kv_fp8(const Params ¶ms) {
inline __device__ void compute_attn_splitkv_16x64_vllm_kvcache_prefetch_kv_fp8(const Params ¶ms) {
const int m_block = Is_causal?gridDim.z - 1 - blockIdx.z:blockIdx.x;
const int m_block = Is_causal?gridDim.z - 1 - blockIdx.z:blockIdx.x;
const int bidb = Is_causal?blockIdx.y:blockIdx.z;
const int bidb = Is_causal?blockIdx.y:blockIdx.z;
...
...
csrc/flash_attn/src/flash_fwd_launch_template.h
View file @
484743c6
...
@@ -90,9 +90,9 @@ DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_lo
...
@@ -90,9 +90,9 @@ DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_lo
}
}
DEFINE_FLASH_FORWARD_KERNEL
(
flash_fwd_splitkv_kernel_16x64_vllm_kvcache_prefetch
,
bool
Is_causal
,
bool
Is_local
,
bool
Has_alibi
,
bool
Is_even_MN
,
bool
Is_even_K
,
bool
Is_softcap
,
bool
Split
,
bool
Append_KV
,
bool
Has_block_table
,
bool
Identity_block_table
=
false
,
bool
Contiguous_qo_layout
=
false
,
bool
Contiguous_kv_layout
=
false
)
{
DEFINE_FLASH_FORWARD_KERNEL
(
flash_fwd_splitkv_kernel_16x64_vllm_kvcache_prefetch
,
bool
Is_causal
,
bool
Is_local
,
bool
Has_alibi
,
bool
Is_even_MN
,
bool
Is_even_K
,
bool
Is_softcap
,
bool
Split
,
bool
Append_KV
,
bool
Has_block_table
)
{
#if defined(ARCH_SUPPORTS_FLASH)
#if defined(ARCH_SUPPORTS_FLASH)
flash
::
compute_attn_splitkv_16x64_vllm_kvcache_prefetch
<
Kernel_traits
,
Is_causal
,
Is_local
,
Has_alibi
,
Is_even_MN
,
Is_even_K
,
Is_softcap
,
Split
,
Append_KV
,
Has_block_table
,
Identity_block_table
,
Contiguous_qo_layout
,
Contiguous_kv_layout
>
(
params
);
flash
::
compute_attn_splitkv_16x64_vllm_kvcache_prefetch
<
Kernel_traits
,
Is_causal
,
Is_local
,
Has_alibi
,
Is_even_MN
,
Is_even_K
,
Is_softcap
,
Split
,
Append_KV
,
Has_block_table
>
(
params
);
#else
#else
FLASH_UNSUPPORTED_ARCH
FLASH_UNSUPPORTED_ARCH
#endif
#endif
...
@@ -355,22 +355,7 @@ void run_flash_splitkv_fwd_16x64_vllm_kvcache_prefetch(Flash_fwd_params ¶ms,
...
@@ -355,22 +355,7 @@ void run_flash_splitkv_fwd_16x64_vllm_kvcache_prefetch(Flash_fwd_params ¶ms,
dim3
grid
;
dim3
grid
;
if
constexpr
(
Is_causal
)
grid
=
dim3
(
params
.
h
,
params
.
b
,
num_m_block
);
if
constexpr
(
Is_causal
)
grid
=
dim3
(
params
.
h
,
params
.
b
,
num_m_block
);
else
grid
=
dim3
(
num_m_block
,
params
.
h
,
params
.
b
);
else
grid
=
dim3
(
num_m_block
,
params
.
h
,
params
.
b
);
const
bool
is_even_MN
=
((
params
.
cu_seqlens_q
==
nullptr
&&
params
.
cu_seqlens_k
==
nullptr
)
||
params
.
b
==
1
)
&&
const
bool
is_even_MN
=
params
.
cu_seqlens_q
==
nullptr
&&
params
.
cu_seqlens_k
==
nullptr
&&
params
.
seqlen_k
%
Kernel_traits
::
kBlockN
==
0
&&
params
.
seqlen_q
%
Kernel_traits
::
kBlockM
==
0
;
params
.
seqlen_k
%
Kernel_traits
::
kBlockN
==
0
&&
params
.
seqlen_q
%
Kernel_traits
::
kBlockM
==
0
;
const
bool
has_paged_kv_layout
=
params
.
d
==
128
&&
params
.
d_value
==
128
&&
params
.
page_block_size
==
64
&&
params
.
block_table
!=
nullptr
&&
params
.
cache_batch_idx
==
nullptr
;
const
bool
is_identity_block_table
=
has_paged_kv_layout
&&
params
.
b
==
1
&&
params
.
h
==
6
&&
params
.
h_k
==
1
&&
params
.
seqlen_q
==
10240
&&
params
.
seqlen_k
==
51200
&&
params
.
block_table_batch_stride
==
cute
::
ceil_div
(
params
.
seqlen_k
,
params
.
page_block_size
);
const
bool
is_contiguous_qo_layout
=
has_paged_kv_layout
&&
params
.
q_head_stride
==
params
.
d
&&
params
.
o_head_stride
==
params
.
d_value
&&
params
.
q_row_stride
==
params
.
h
*
params
.
d
&&
params
.
o_row_stride
==
params
.
h
*
params
.
d_value
;
const
bool
is_contiguous_kv_layout
=
has_paged_kv_layout
&&
params
.
k_row_stride
==
params
.
d
&&
params
.
v_row_stride
==
params
.
page_block_size
&&
params
.
k_head_stride
==
params
.
page_block_size
*
params
.
d
&&
params
.
v_head_stride
==
params
.
d_value
*
params
.
page_block_size
&&
params
.
k_batch_stride
==
params
.
h_k
*
params
.
page_block_size
*
params
.
d
&&
params
.
v_batch_stride
==
params
.
h_k
*
params
.
d_value
*
params
.
page_block_size
;
const
bool
is_even_K
=
true
;
const
bool
is_even_K
=
true
;
// const bool is_even_K = params.d == Kernel_traits::kHeadDim;
// const bool is_even_K = params.d == Kernel_traits::kHeadDim;
BOOL_SWITCH
(
is_even_MN
,
IsEvenMNConst
,
[
&
]
{
BOOL_SWITCH
(
is_even_MN
,
IsEvenMNConst
,
[
&
]
{
...
@@ -385,19 +370,7 @@ void run_flash_splitkv_fwd_16x64_vllm_kvcache_prefetch(Flash_fwd_params ¶ms,
...
@@ -385,19 +370,7 @@ void run_flash_splitkv_fwd_16x64_vllm_kvcache_prefetch(Flash_fwd_params ¶ms,
constexpr
static
bool
Has_block_table
=
true
;
constexpr
static
bool
Has_block_table
=
true
;
constexpr
static
bool
Append_KV
=
false
;
constexpr
static
bool
Append_KV
=
false
;
constexpr
static
bool
Split
=
false
;
constexpr
static
bool
Split
=
false
;
auto
kernel
=
is_identity_block_table
auto
kernel
=
&
flash_fwd_splitkv_kernel_16x64_vllm_kvcache_prefetch
<
Kernel_traits
,
Is_causal
,
Is_local
&&
!
Is_causal
,
Has_alibi
,
IsEvenMNConst
&&
!
Append_KV
&&
IsEvenKConst
&&
!
Is_local
&&
Kernel_traits
::
kHeadDim
<=
128
,
IsEvenKConst
,
Is_softcap
,
Split
,
Append_KV
,
Has_block_table
>
;
?
(
is_contiguous_qo_layout
&&
is_contiguous_kv_layout
?
&
flash_fwd_splitkv_kernel_16x64_vllm_kvcache_prefetch
<
Kernel_traits
,
Is_causal
,
Is_local
&&
!
Is_causal
,
Has_alibi
,
IsEvenMNConst
&&
!
Append_KV
&&
IsEvenKConst
&&
!
Is_local
&&
Kernel_traits
::
kHeadDim
<=
128
,
IsEvenKConst
,
Is_softcap
,
Split
,
Append_KV
,
Has_block_table
,
true
,
true
,
true
>
:
(
is_contiguous_qo_layout
?
&
flash_fwd_splitkv_kernel_16x64_vllm_kvcache_prefetch
<
Kernel_traits
,
Is_causal
,
Is_local
&&
!
Is_causal
,
Has_alibi
,
IsEvenMNConst
&&
!
Append_KV
&&
IsEvenKConst
&&
!
Is_local
&&
Kernel_traits
::
kHeadDim
<=
128
,
IsEvenKConst
,
Is_softcap
,
Split
,
Append_KV
,
Has_block_table
,
true
,
true
,
false
>
:
&
flash_fwd_splitkv_kernel_16x64_vllm_kvcache_prefetch
<
Kernel_traits
,
Is_causal
,
Is_local
&&
!
Is_causal
,
Has_alibi
,
IsEvenMNConst
&&
!
Append_KV
&&
IsEvenKConst
&&
!
Is_local
&&
Kernel_traits
::
kHeadDim
<=
128
,
IsEvenKConst
,
Is_softcap
,
Split
,
Append_KV
,
Has_block_table
,
true
,
false
,
false
>
))
:
(
is_contiguous_qo_layout
&&
is_contiguous_kv_layout
?
&
flash_fwd_splitkv_kernel_16x64_vllm_kvcache_prefetch
<
Kernel_traits
,
Is_causal
,
Is_local
&&
!
Is_causal
,
Has_alibi
,
IsEvenMNConst
&&
!
Append_KV
&&
IsEvenKConst
&&
!
Is_local
&&
Kernel_traits
::
kHeadDim
<=
128
,
IsEvenKConst
,
Is_softcap
,
Split
,
Append_KV
,
Has_block_table
,
false
,
true
,
true
>
:
(
is_contiguous_qo_layout
?
&
flash_fwd_splitkv_kernel_16x64_vllm_kvcache_prefetch
<
Kernel_traits
,
Is_causal
,
Is_local
&&
!
Is_causal
,
Has_alibi
,
IsEvenMNConst
&&
!
Append_KV
&&
IsEvenKConst
&&
!
Is_local
&&
Kernel_traits
::
kHeadDim
<=
128
,
IsEvenKConst
,
Is_softcap
,
Split
,
Append_KV
,
Has_block_table
,
false
,
true
,
false
>
:
(
is_contiguous_kv_layout
?
&
flash_fwd_splitkv_kernel_16x64_vllm_kvcache_prefetch
<
Kernel_traits
,
Is_causal
,
Is_local
&&
!
Is_causal
,
Has_alibi
,
IsEvenMNConst
&&
!
Append_KV
&&
IsEvenKConst
&&
!
Is_local
&&
Kernel_traits
::
kHeadDim
<=
128
,
IsEvenKConst
,
Is_softcap
,
Split
,
Append_KV
,
Has_block_table
,
false
,
false
,
true
>
:
&
flash_fwd_splitkv_kernel_16x64_vllm_kvcache_prefetch
<
Kernel_traits
,
Is_causal
,
Is_local
&&
!
Is_causal
,
Has_alibi
,
IsEvenMNConst
&&
!
Append_KV
&&
IsEvenKConst
&&
!
Is_local
&&
Kernel_traits
::
kHeadDim
<=
128
,
IsEvenKConst
,
Is_softcap
,
Split
,
Append_KV
,
Has_block_table
,
false
,
false
,
false
>
)));
kernel
<<<
grid
,
Kernel_traits
::
kNThreads
,
smem_size
,
stream
>>>
(
params
);
kernel
<<<
grid
,
Kernel_traits
::
kNThreads
,
smem_size
,
stream
>>>
(
params
);
});
});
});
});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment