Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
fuhuangpei
flash-attention
Commits
dee03c7b
Commit
dee03c7b
authored
Jun 04, 2026
by
fuhuangpei
Browse files
return d_value
parent
484743c6
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
16 additions
and
16 deletions
+16
-16
csrc/flash_attn/src/flash_fwd_kernel.h
csrc/flash_attn/src/flash_fwd_kernel.h
+16
-16
No files found.
csrc/flash_attn/src/flash_fwd_kernel.h
View file @
dee03c7b
...
...
@@ -8148,20 +8148,20 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
S_WAITCNT;
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 0, 0);
S_BARRIER;
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV, sV, 0, params.v_row_stride,
128
, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV, sV, 1, params.v_row_stride,
128
, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV, sV, 0, params.v_row_stride,
params.d_value
, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV, sV, 1, params.v_row_stride,
params.d_value
, binfo.actual_seqlen_k - n_block * kBlockN);
asm volatile("s_waitcnt vmcnt(4) \n s_barrier");;
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 1, 1);
S_BARRIER;
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV, sV, 2, params.v_row_stride,
128
, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV, sV, 3, params.v_row_stride,
128
, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV, sV, 2, params.v_row_stride,
params.d_value
, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV, sV, 3, params.v_row_stride,
params.d_value
, binfo.actual_seqlen_k - n_block * kBlockN);
asm volatile("s_waitcnt vmcnt(5) \n s_barrier");
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 2, 2);
S_BARRIER;
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV_tail, sV_tail, 0, params.v_row_stride,
128
, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV_tail, sV_tail, 1, params.v_row_stride,
128
, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV_tail, sV_tail, 0, params.v_row_stride,
params.d_value
, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV_tail, sV_tail, 1, params.v_row_stride,
params.d_value
, binfo.actual_seqlen_k - n_block * kBlockN);
asm volatile("s_waitcnt vmcnt(6) \n s_barrier");
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 3, 3);
S_BARRIER;
...
...
@@ -8196,8 +8196,8 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
{
__builtin_amdgcn_sched_barrier(0);
// int token_id = n_block * kBlockN + ((tidx % 64) / 16) * 4;
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV_tail, sV_tail, 2, params.v_row_stride,
128
, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV_tail, sV_tail, 3, params.v_row_stride,
128
, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV_tail, sV_tail, 2, params.v_row_stride,
params.d_value
, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV_tail, sV_tail, 3, params.v_row_stride,
params.d_value
, binfo.actual_seqlen_k - n_block * kBlockN);
asm volatile("s_waitcnt vmcnt(7) \n s_barrier");
__builtin_amdgcn_sched_barrier(0);
if (!Is_even_MN && Is_need_pad && masking_step == 0) {
...
...
@@ -8277,20 +8277,20 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
S_WAITCNT;
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 0, 0);
S_BARRIER;
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV, sV, 0, params.v_row_stride,
128
, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV, sV, 1, params.v_row_stride,
128
, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV, sV, 0, params.v_row_stride,
params.d_value
, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV, sV, 1, params.v_row_stride,
params.d_value
, binfo.actual_seqlen_k - n_block * kBlockN);
asm volatile("s_waitcnt vmcnt(4) \n s_barrier");;
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 1, 1);
S_BARRIER;
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV, sV, 2, params.v_row_stride,
128
, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV, sV, 3, params.v_row_stride,
128
, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV, sV, 2, params.v_row_stride,
params.d_value
, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV, sV, 3, params.v_row_stride,
params.d_value
, binfo.actual_seqlen_k - n_block * kBlockN);
asm volatile("s_waitcnt vmcnt(5) \n s_barrier");
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 2, 2);
S_BARRIER;
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV_tail, sV_tail, 0, params.v_row_stride,
128
, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV_tail, sV_tail, 1, params.v_row_stride,
128
, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV_tail, sV_tail, 0, params.v_row_stride,
params.d_value
, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV_tail, sV_tail, 1, params.v_row_stride,
params.d_value
, binfo.actual_seqlen_k - n_block * kBlockN);
asm volatile("s_waitcnt vmcnt(6) \n s_barrier");
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 3, 3);
S_BARRIER;
...
...
@@ -8320,8 +8320,8 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
{
__builtin_amdgcn_sched_barrier(0);
S_BARRIER;
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV_tail, sV_tail, 2, params.v_row_stride,
128
);
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV_tail, sV_tail, 3, params.v_row_stride,
128
);
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV_tail, sV_tail, 2, params.v_row_stride,
params.d
);
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV_tail, sV_tail, 3, params.v_row_stride,
params.d
);
asm volatile("s_waitcnt vmcnt(7) \n s_barrier");
flash::gemm_k_rs(acc_o_ori, rP, tOrV, tSsV, tiled_mma_gemm1, smem_tiled_copy_V, smem_thr_copy_V, 0, 0);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment