Commit dee03c7b authored by fuhuangpei
Browse files

return d_value

parent 484743c6
......@@ -8148,20 +8148,20 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
S_WAITCNT;
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 0, 0);
S_BARRIER;
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV, sV, 0, params.v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV, sV, 1, params.v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV, sV, 0, params.v_row_stride, params.d_value, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV, sV, 1, params.v_row_stride, params.d_value, binfo.actual_seqlen_k - n_block * kBlockN);
asm volatile("s_waitcnt vmcnt(4) \n s_barrier");;
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 1, 1);
S_BARRIER;
 
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV, sV, 2, params.v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV, sV, 3, params.v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV, sV, 2, params.v_row_stride, params.d_value, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV, sV, 3, params.v_row_stride, params.d_value, binfo.actual_seqlen_k - n_block * kBlockN);
asm volatile("s_waitcnt vmcnt(5) \n s_barrier");
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 2, 2);
S_BARRIER;
 
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV_tail, sV_tail, 0, params.v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV_tail, sV_tail, 1, params.v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV_tail, sV_tail, 0, params.v_row_stride, params.d_value, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV_tail, sV_tail, 1, params.v_row_stride, params.d_value, binfo.actual_seqlen_k - n_block * kBlockN);
asm volatile("s_waitcnt vmcnt(6) \n s_barrier");
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 3, 3);
S_BARRIER;
......@@ -8196,8 +8196,8 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
{
__builtin_amdgcn_sched_barrier(0);
// int token_id = n_block * kBlockN + ((tidx % 64) / 16) * 4;
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV_tail, sV_tail, 2, params.v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV_tail, sV_tail, 3, params.v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV_tail, sV_tail, 2, params.v_row_stride, params.d_value, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, Is_even_MN, _64x16, 0, true>(gV_tail, sV_tail, 3, params.v_row_stride, params.d_value, binfo.actual_seqlen_k - n_block * kBlockN);
asm volatile("s_waitcnt vmcnt(7) \n s_barrier");
__builtin_amdgcn_sched_barrier(0);
if (!Is_even_MN && Is_need_pad && masking_step == 0) {
......@@ -8277,20 +8277,20 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
S_WAITCNT;
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 0, 0);
S_BARRIER;
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV, sV, 0, params.v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV, sV, 1, params.v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV, sV, 0, params.v_row_stride, params.d_value, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV, sV, 1, params.v_row_stride, params.d_value, binfo.actual_seqlen_k - n_block * kBlockN);
asm volatile("s_waitcnt vmcnt(4) \n s_barrier");;
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 1, 1);
S_BARRIER;
 
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV, sV, 2, params.v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV, sV, 3, params.v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV, sV, 2, params.v_row_stride, params.d_value, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV, sV, 3, params.v_row_stride, params.d_value, binfo.actual_seqlen_k - n_block * kBlockN);
asm volatile("s_waitcnt vmcnt(5) \n s_barrier");
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 2, 2);
S_BARRIER;
 
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV_tail, sV_tail, 0, params.v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV_tail, sV_tail, 1, params.v_row_stride, 128, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV_tail, sV_tail, 0, params.v_row_stride, params.d_value, binfo.actual_seqlen_k - n_block * kBlockN);
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV_tail, sV_tail, 1, params.v_row_stride, params.d_value, binfo.actual_seqlen_k - n_block * kBlockN);
asm volatile("s_waitcnt vmcnt(6) \n s_barrier");
flash::gemm_k_rs(acc_s_ori, tGrQ, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, 3, 3);
S_BARRIER;
......@@ -8320,8 +8320,8 @@ inline __device__ void compute_attn_1rowblock_splitkv_16x64_vllm_kvcache_prefetc
{
__builtin_amdgcn_sched_barrier(0);
S_BARRIER;
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV_tail, sV_tail, 2, params.v_row_stride, 128);
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV_tail, sV_tail, 3, params.v_row_stride, 128);
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV_tail, sV_tail, 2, params.v_row_stride, params.d);
lds_direct_copy<Is_even_K, true, _64x16, 0, true>(gV_tail, sV_tail, 3, params.v_row_stride, params.d);
asm volatile("s_waitcnt vmcnt(7) \n s_barrier");
flash::gemm_k_rs(acc_o_ori, rP, tOrV, tSsV, tiled_mma_gemm1, smem_tiled_copy_V, smem_thr_copy_V, 0, 0);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment