Commit 142846b5 authored by zhanghj2
Browse files

Fix precision issue (fix精度问题)

parent a9e4de8d
...@@ -1141,10 +1141,12 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_nope_pe_gfx93 ...@@ -1141,10 +1141,12 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_nope_pe_gfx93
int cur_block_table; int cur_block_table;
const int *cur_block_table_ptr = block_table + n_block; const int *cur_block_table_ptr = block_table + n_block;
// cur_block_table = block_table[n_block - 1]; // cur_block_table = block_table[n_block - 1];
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_load_dword %1, %0, 0x0\n\t" asm volatile("s_load_dword %1, %0, 0x0\n\t"
"s_waitcnt lgkmcnt(0)\n\t": "s_waitcnt lgkmcnt(0)\n\t":
"+s"(cur_block_table_ptr), "+s"(cur_block_table_ptr),
"=s"(cur_block_table)); "=s"(cur_block_table));
__builtin_amdgcn_sched_barrier(0);
index_t offset_k = cur_block_table * params.k_batch_stride; index_t offset_k = cur_block_table * params.k_batch_stride;
gK.data() = gK.data() + (offset_k); gK.data() = gK.data() + (offset_k);
...@@ -1165,81 +1167,111 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_nope_pe_gfx93 ...@@ -1165,81 +1167,111 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_nope_pe_gfx93
int k_idx = 0; int k_idx = 0;
// k_idx++; // k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(14 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(14 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(13 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(13 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(12 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(12 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(11 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(11 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(10 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(10 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(9+ 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(9+ 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(8+ 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(8+ 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(7+ 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(7+ 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
// asm volatile("s_barrier\n\t"); // asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(6+ 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(6+ 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
k_idx++; k_idx++;
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
// asm volatile("s_barrier\n\t"); // asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(5 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(5 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
k_idx++; k_idx++;
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
// asm volatile("s_barrier\n\t"); // asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(4 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(4 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
k_idx++; k_idx++;
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(3 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(3 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(2 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(2 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
k_idx++; k_idx++;
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
// asm volatile("s_barrier\n\t"); // asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(1 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(1 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
k_idx++; k_idx++;
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
// asm volatile("s_barrier\n\t"); // asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(0 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
k_idx++; k_idx++;
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
...@@ -1248,13 +1280,19 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_nope_pe_gfx93 ...@@ -1248,13 +1280,19 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_nope_pe_gfx93
} }
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(2) \n\t \n\t"); asm volatile("s_waitcnt vmcnt(2) \n\t \n\t");
__builtin_amdgcn_sched_barrier(0);
buffer_to_tensor(buffer[0], tSrK_smem, 15); buffer_to_tensor(buffer[0], tSrK_smem, 15);
cute::gemm(tiled_mma, tSrQ(_, _, 15), tSrK_smem(_, _, 15), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, 15), tSrK_smem(_, _, 15), acc_s);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(1) \n\t \n\t"); asm volatile("s_waitcnt vmcnt(1) \n\t \n\t");
__builtin_amdgcn_sched_barrier(0);
buffer_to_tensor(buffer[1], tSrK_smem, 16); buffer_to_tensor(buffer[1], tSrK_smem, 16);
cute::gemm(tiled_mma, tSrQ(_, _, 16), tSrK_smem(_, _, 16), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, 16), tSrK_smem(_, _, 16), acc_s);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0) \n\t \n\t"); asm volatile("s_waitcnt vmcnt(0) \n\t \n\t");
__builtin_amdgcn_sched_barrier(0);
buffer_to_tensor(buffer[2], tSrK_smem, 17); buffer_to_tensor(buffer[2], tSrK_smem, 17);
cute::gemm(tiled_mma, tSrQ(_, _, 17), tSrK_smem(_, _, 17), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, 17), tSrK_smem(_, _, 17), acc_s);
// asm volatile("s_barrier\n\t"); // asm volatile("s_barrier\n\t");
...@@ -1297,7 +1335,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_nope_pe_gfx93 ...@@ -1297,7 +1335,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_nope_pe_gfx93
lds_direct_copy<false, true>(gK, sK, 15, params.k_row_stride, seqlen_k - n_block * kBlockN); lds_direct_copy<false, true>(gK, sK, 15, params.k_row_stride, seqlen_k - n_block * kBlockN);
// asm_ds_write(buffer[0], tVsV, 15); // asm_ds_write(buffer[0], tVsV, 15);
// asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t"); // asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
gK.data() = gK.data() + (-offset_k); gK.data() = gK.data() + (-offset_k);
...@@ -1321,10 +1361,12 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_nope_pe_gfx93 ...@@ -1321,10 +1361,12 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_nope_pe_gfx93
int cur_block_table; int cur_block_table;
const int *cur_block_table_ptr = block_table + n_block; const int *cur_block_table_ptr = block_table + n_block;
// cur_block_table = block_table[n_block - 1]; // cur_block_table = block_table[n_block - 1];
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_load_dword %1, %0, 0x0\n\t" asm volatile("s_load_dword %1, %0, 0x0\n\t"
"s_waitcnt lgkmcnt(0)\n\t": "s_waitcnt lgkmcnt(0)\n\t":
"+s"(cur_block_table_ptr), "+s"(cur_block_table_ptr),
"=s"(cur_block_table)); "=s"(cur_block_table));
__builtin_amdgcn_sched_barrier(0);
index_t offset_k = cur_block_table * params.k_batch_stride; index_t offset_k = cur_block_table * params.k_batch_stride;
gK.data() = gK.data() + (offset_k); gK.data() = gK.data() + (offset_k);
...@@ -1344,85 +1386,117 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_nope_pe_gfx93 ...@@ -1344,85 +1386,117 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_nope_pe_gfx93
int k_idx = 0; int k_idx = 0;
// k_idx++; // k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(14 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(14 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(13 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(13 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(12 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(12 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(11 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(11 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(10 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(10 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(9+ 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(9+ 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(8+ 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(8+ 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(7+ 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(7+ 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
// asm volatile("s_barrier\n\t"); // asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(6+ 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(6+ 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
k_idx++; k_idx++;
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
// asm volatile("s_barrier\n\t"); // asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(5 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(5 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
k_idx++; k_idx++;
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
// asm volatile("s_barrier\n\t"); // asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(4 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(4 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
k_idx++; k_idx++;
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(3 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(3 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(2 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(2 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
k_idx++; k_idx++;
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
// asm volatile("s_barrier\n\t"); // asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(1 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(1 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
k_idx++; k_idx++;
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
// asm volatile("s_barrier\n\t"); // asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(0 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
k_idx++; k_idx++;
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0 + 2) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(0 + 2) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
k_idx++; k_idx++;
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
...@@ -1435,14 +1509,20 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_nope_pe_gfx93 ...@@ -1435,14 +1509,20 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_nope_pe_gfx93
// asm volatile("s_barrier\n\t"); // asm volatile("s_barrier\n\t");
} }
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(1) \n\t \n\t"); asm volatile("s_waitcnt vmcnt(1) \n\t \n\t");
__builtin_amdgcn_sched_barrier(0);
buffer_to_tensor(buffer[0], tSrK_smem, 16); buffer_to_tensor(buffer[0], tSrK_smem, 16);
cute::gemm(tiled_mma, tSrQ(_, _, 16), tSrK_smem(_, _, 16), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, 16), tSrK_smem(_, _, 16), acc_s);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0) \n\t \n\t"); asm volatile("s_waitcnt vmcnt(0) \n\t \n\t");
__builtin_amdgcn_sched_barrier(0);
buffer_to_tensor(buffer[1], tSrK_smem, 17); buffer_to_tensor(buffer[1], tSrK_smem, 17);
cute::gemm(tiled_mma, tSrQ(_, _, 17), tSrK_smem(_, _, 17), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, 17), tSrK_smem(_, _, 17), acc_s);
// asm volatile("s_barrier\n\t"); // asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t"); asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
gK.data() = gK.data() + (-offset_k); gK.data() = gK.data() + (-offset_k);
// We have key_padding_mask so we'll need to Check_inf // We have key_padding_mask so we'll need to Check_inf
...@@ -1734,10 +1814,12 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_nope_pe_g ...@@ -1734,10 +1814,12 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_nope_pe_g
int cur_block_table; int cur_block_table;
const int *cur_block_table_ptr = block_table + n_block; const int *cur_block_table_ptr = block_table + n_block;
// cur_block_table = block_table[n_block - 1]; // cur_block_table = block_table[n_block - 1];
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_load_dword %1, %0, 0x0\n\t" asm volatile("s_load_dword %1, %0, 0x0\n\t"
"s_waitcnt lgkmcnt(0)\n\t": "s_waitcnt lgkmcnt(0)\n\t":
"+s"(cur_block_table_ptr), "+s"(cur_block_table_ptr),
"=s"(cur_block_table)); "=s"(cur_block_table));
__builtin_amdgcn_sched_barrier(0);
index_t offset_k = cur_block_table * params.k_batch_stride; index_t offset_k = cur_block_table * params.k_batch_stride;
gK.data() = gK_data + (offset_k); gK.data() = gK_data + (offset_k);
...@@ -1791,7 +1873,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_nope_pe_g ...@@ -1791,7 +1873,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_nope_pe_g
#else #else
gemm_rs_fp8<Element, is_scale_equal_one, KV_DTYPE>(acc_s, tSrQ, tSrK_int8, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, k_scale); gemm_rs_fp8<Element, is_scale_equal_one, KV_DTYPE>(acc_s, tSrQ, tSrK_int8, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, k_scale);
// asm volatile("s_barrier\n\t"); // asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0) \n\t \n\t"); asm volatile("s_waitcnt vmcnt(0) \n\t \n\t");
__builtin_amdgcn_sched_barrier(0);
gemm_k_rs_fp8<Element, 8, is_scale_equal_one, KV_DTYPE>(acc_s, tSrQ, tSrK, tiled_mma, buffer[0], k_scale); gemm_k_rs_fp8<Element, 8, is_scale_equal_one, KV_DTYPE>(acc_s, tSrQ, tSrK, tiled_mma, buffer[0], k_scale);
#endif #endif
// asm volatile("s_barrier\n\t"); // asm volatile("s_barrier\n\t");
...@@ -1839,10 +1923,12 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_nope_pe_g ...@@ -1839,10 +1923,12 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_nope_pe_g
int cur_block_table; int cur_block_table;
const int *cur_block_table_ptr = block_table + n_block - 1; const int *cur_block_table_ptr = block_table + n_block - 1;
// cur_block_table = block_table[n_block - 1]; // cur_block_table = block_table[n_block - 1];
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_load_dword %1, %0, 0x0\n\t" asm volatile("s_load_dword %1, %0, 0x0\n\t"
"s_waitcnt lgkmcnt(0)\n\t": "s_waitcnt lgkmcnt(0)\n\t":
"+s"(cur_block_table_ptr), "+s"(cur_block_table_ptr),
"=s"(cur_block_table)); "=s"(cur_block_table));
__builtin_amdgcn_sched_barrier(0);
index_t offset_k = cur_block_table * params.k_batch_stride; index_t offset_k = cur_block_table * params.k_batch_stride;
gK.data() = gK_data + (offset_k); gK.data() = gK_data + (offset_k);
...@@ -1884,7 +1970,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_nope_pe_g ...@@ -1884,7 +1970,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_nope_pe_g
gemm_rs_fp8<Element, is_scale_equal_one, KV_DTYPE>(acc_s, tSrQ, tSrK_int8, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, k_scale); gemm_rs_fp8<Element, is_scale_equal_one, KV_DTYPE>(acc_s, tSrQ, tSrK_int8, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, k_scale);
#endif #endif
// asm volatile("s_barrier\n\t"); // asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0) \n\t \n\t"); asm volatile("s_waitcnt vmcnt(0) \n\t \n\t");
__builtin_amdgcn_sched_barrier(0);
gemm_k_rs_fp8<Element, 8, is_scale_equal_one, KV_DTYPE>(acc_s, tSrQ, tSrK, tiled_mma, buffer[0], k_scale); gemm_k_rs_fp8<Element, 8, is_scale_equal_one, KV_DTYPE>(acc_s, tSrQ, tSrK, tiled_mma, buffer[0], k_scale);
// asm volatile("s_barrier\n\t"); // asm volatile("s_barrier\n\t");
...@@ -1910,10 +1998,12 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_nope_pe_g ...@@ -1910,10 +1998,12 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_nope_pe_g
int cur_block_table; int cur_block_table;
const int *cur_block_table_ptr = block_table + n_block - 1; const int *cur_block_table_ptr = block_table + n_block - 1;
// cur_block_table = block_table[n_block - 1]; // cur_block_table = block_table[n_block - 1];
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_load_dword %1, %0, 0x0\n\t" asm volatile("s_load_dword %1, %0, 0x0\n\t"
"s_waitcnt lgkmcnt(0)\n\t": "s_waitcnt lgkmcnt(0)\n\t":
"+s"(cur_block_table_ptr), "+s"(cur_block_table_ptr),
"=s"(cur_block_table)); "=s"(cur_block_table));
__builtin_amdgcn_sched_barrier(0);
index_t offset_k = cur_block_table * params.k_batch_stride; index_t offset_k = cur_block_table * params.k_batch_stride;
gK.data() = gK_data + (offset_k); gK.data() = gK_data + (offset_k);
...@@ -2107,10 +2197,12 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx936(co ...@@ -2107,10 +2197,12 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx936(co
int cur_block_table; int cur_block_table;
const int *cur_block_table_ptr = block_table + n_block; const int *cur_block_table_ptr = block_table + n_block;
// cur_block_table = block_table[n_block - 1]; // cur_block_table = block_table[n_block - 1];
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_load_dword %1, %0, 0x0\n\t" asm volatile("s_load_dword %1, %0, 0x0\n\t"
"s_waitcnt lgkmcnt(0)\n\t": "s_waitcnt lgkmcnt(0)\n\t":
"+s"(cur_block_table_ptr), "+s"(cur_block_table_ptr),
"=s"(cur_block_table)); "=s"(cur_block_table));
__builtin_amdgcn_sched_barrier(0);
index_t offset_k = cur_block_table * params.k_batch_stride; index_t offset_k = cur_block_table * params.k_batch_stride;
gK.data() = gK_data + (offset_k); gK.data() = gK_data + (offset_k);
...@@ -2164,7 +2256,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx936(co ...@@ -2164,7 +2256,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx936(co
#else #else
gemm_rs_fp8<Element, is_scale_equal_one, KV_DTYPE>(acc_s, tSrQ, tSrK_int8, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, k_scale); gemm_rs_fp8<Element, is_scale_equal_one, KV_DTYPE>(acc_s, tSrQ, tSrK_int8, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, k_scale);
// asm volatile("s_barrier\n\t"); // asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0) \n\t \n\t"); asm volatile("s_waitcnt vmcnt(0) \n\t \n\t");
__builtin_amdgcn_sched_barrier(0);
gemm_k_rs_fp8<Element, 8, is_scale_equal_one, KV_DTYPE>(acc_s, tSrQ, tSrK, tiled_mma, buffer[0], k_scale); gemm_k_rs_fp8<Element, 8, is_scale_equal_one, KV_DTYPE>(acc_s, tSrQ, tSrK, tiled_mma, buffer[0], k_scale);
#endif #endif
// asm volatile("s_barrier\n\t"); // asm volatile("s_barrier\n\t");
...@@ -2212,10 +2306,12 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx936(co ...@@ -2212,10 +2306,12 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx936(co
int cur_block_table; int cur_block_table;
const int *cur_block_table_ptr = block_table + n_block - 1; const int *cur_block_table_ptr = block_table + n_block - 1;
// cur_block_table = block_table[n_block - 1]; // cur_block_table = block_table[n_block - 1];
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_load_dword %1, %0, 0x0\n\t" asm volatile("s_load_dword %1, %0, 0x0\n\t"
"s_waitcnt lgkmcnt(0)\n\t": "s_waitcnt lgkmcnt(0)\n\t":
"+s"(cur_block_table_ptr), "+s"(cur_block_table_ptr),
"=s"(cur_block_table)); "=s"(cur_block_table));
__builtin_amdgcn_sched_barrier(0);
index_t offset_k = cur_block_table * params.k_batch_stride; index_t offset_k = cur_block_table * params.k_batch_stride;
gK.data() = gK_data + (offset_k); gK.data() = gK_data + (offset_k);
...@@ -2257,7 +2353,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx936(co ...@@ -2257,7 +2353,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx936(co
gemm_rs_fp8<Element, is_scale_equal_one, KV_DTYPE>(acc_s, tSrQ, tSrK_int8, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, k_scale); gemm_rs_fp8<Element, is_scale_equal_one, KV_DTYPE>(acc_s, tSrQ, tSrK_int8, tSrK, tSsK, tiled_mma, smem_tiled_copy_K, smem_thr_copy_K, k_scale);
#endif #endif
// asm volatile("s_barrier\n\t"); // asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0) \n\t \n\t"); asm volatile("s_waitcnt vmcnt(0) \n\t \n\t");
__builtin_amdgcn_sched_barrier(0);
gemm_k_rs_fp8<Element, 8, is_scale_equal_one, KV_DTYPE>(acc_s, tSrQ, tSrK, tiled_mma, buffer[0], k_scale); gemm_k_rs_fp8<Element, 8, is_scale_equal_one, KV_DTYPE>(acc_s, tSrQ, tSrK, tiled_mma, buffer[0], k_scale);
// asm volatile("s_barrier\n\t"); // asm volatile("s_barrier\n\t");
...@@ -2283,10 +2381,12 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx936(co ...@@ -2283,10 +2381,12 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx936(co
int cur_block_table; int cur_block_table;
const int *cur_block_table_ptr = block_table + n_block - 1; const int *cur_block_table_ptr = block_table + n_block - 1;
// cur_block_table = block_table[n_block - 1]; // cur_block_table = block_table[n_block - 1];
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_load_dword %1, %0, 0x0\n\t" asm volatile("s_load_dword %1, %0, 0x0\n\t"
"s_waitcnt lgkmcnt(0)\n\t": "s_waitcnt lgkmcnt(0)\n\t":
"+s"(cur_block_table_ptr), "+s"(cur_block_table_ptr),
"=s"(cur_block_table)); "=s"(cur_block_table));
__builtin_amdgcn_sched_barrier(0);
index_t offset_k = cur_block_table * params.k_batch_stride; index_t offset_k = cur_block_table * params.k_batch_stride;
gK.data() = gK_data + (offset_k); gK.data() = gK_data + (offset_k);
...@@ -2465,10 +2565,12 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx936_tp ...@@ -2465,10 +2565,12 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx936_tp
int cur_block_table; int cur_block_table;
const int *cur_block_table_ptr = block_table + n_block; const int *cur_block_table_ptr = block_table + n_block;
// cur_block_table = block_table[n_block - 1]; // cur_block_table = block_table[n_block - 1];
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_load_dword %1, %0, 0x0\n\t" asm volatile("s_load_dword %1, %0, 0x0\n\t"
"s_waitcnt lgkmcnt(0)\n\t": "s_waitcnt lgkmcnt(0)\n\t":
"+s"(cur_block_table_ptr), "+s"(cur_block_table_ptr),
"=s"(cur_block_table)); "=s"(cur_block_table));
__builtin_amdgcn_sched_barrier(0);
index_t offset_k = cur_block_table * params.k_batch_stride; index_t offset_k = cur_block_table * params.k_batch_stride;
gK.data() = gK_data + (offset_k); gK.data() = gK_data + (offset_k);
...@@ -2696,7 +2798,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx936_tp ...@@ -2696,7 +2798,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx936_tp
} }
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_barrier \n\t"); asm volatile("s_barrier \n\t");
__builtin_amdgcn_sched_barrier(0);
} }
...@@ -2801,10 +2905,12 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_gfx936_tp1(co ...@@ -2801,10 +2905,12 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_gfx936_tp1(co
int cur_block_table; int cur_block_table;
const int *cur_block_table_ptr = block_table + n_block; const int *cur_block_table_ptr = block_table + n_block;
// cur_block_table = block_table[n_block - 1]; // cur_block_table = block_table[n_block - 1];
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_load_dword %1, %0, 0x0\n\t" asm volatile("s_load_dword %1, %0, 0x0\n\t"
"s_waitcnt lgkmcnt(0)\n\t": "s_waitcnt lgkmcnt(0)\n\t":
"+s"(cur_block_table_ptr), "+s"(cur_block_table_ptr),
"=s"(cur_block_table)); "=s"(cur_block_table));
__builtin_amdgcn_sched_barrier(0);
index_t offset_k = cur_block_table * params.k_batch_stride; index_t offset_k = cur_block_table * params.k_batch_stride;
gK.data() = gK_data + (offset_k); gK.data() = gK_data + (offset_k);
...@@ -2829,7 +2935,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_gfx936_tp1(co ...@@ -2829,7 +2935,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_gfx936_tp1(co
|| lane_id == 15 || lane_id == 31 || lane_id == 47 || lane_id == 63 || lane_id == 15 || lane_id == 31 || lane_id == 47 || lane_id == 63
) * (- 8 * 64) ) + 4 * 8; ) * (- 8 * 64) ) + 4 * 8;
// for (int k = 0; k < 1; k+=2) // for (int k = 0; k < 1; k+=2)
__builtin_amdgcn_sched_barrier(0);
asm volatile(" s_waitcnt vmcnt(4) \n\t s_barrier\n\t"); asm volatile(" s_waitcnt vmcnt(4) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
for (int k = 0; k < 4; k+=2) for (int k = 0; k < 4; k+=2)
{ {
bf16_storage data; bf16_storage data;
...@@ -2864,7 +2972,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_gfx936_tp1(co ...@@ -2864,7 +2972,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_gfx936_tp1(co
k_lds_ptr += 32 * 64; k_lds_ptr += 32 * 64;
k_lds_ptr1 += 32 * 64; k_lds_ptr1 += 32 * 64;
} }
__builtin_amdgcn_sched_barrier(0);
asm volatile(" s_waitcnt vmcnt(3) \n\t s_barrier\n\t"); asm volatile(" s_waitcnt vmcnt(3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
for (int k = 4; k < 8; k+=2) for (int k = 4; k < 8; k+=2)
{ {
bf16_storage data; bf16_storage data;
...@@ -2899,7 +3009,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_gfx936_tp1(co ...@@ -2899,7 +3009,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_gfx936_tp1(co
k_lds_ptr += 32 * 64; k_lds_ptr += 32 * 64;
k_lds_ptr1 += 32 * 64; k_lds_ptr1 += 32 * 64;
} }
__builtin_amdgcn_sched_barrier(0);
asm volatile(" s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); asm volatile(" s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
for (int k = 8; k < 12; k+=2) for (int k = 8; k < 12; k+=2)
{ {
bf16_storage data; bf16_storage data;
...@@ -2934,7 +3046,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_gfx936_tp1(co ...@@ -2934,7 +3046,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_gfx936_tp1(co
k_lds_ptr += 32 * 64; k_lds_ptr += 32 * 64;
k_lds_ptr1 += 32 * 64; k_lds_ptr1 += 32 * 64;
} }
__builtin_amdgcn_sched_barrier(0);
asm volatile(" s_waitcnt vmcnt(1) \n\t s_barrier\n\t"); asm volatile(" s_waitcnt vmcnt(1) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
for (int k = 12; k < 16; k+=2) for (int k = 12; k < 16; k+=2)
{ {
bf16_storage data; bf16_storage data;
...@@ -2969,7 +3083,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_gfx936_tp1(co ...@@ -2969,7 +3083,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_gfx936_tp1(co
k_lds_ptr += 32 * 64; k_lds_ptr += 32 * 64;
k_lds_ptr1 += 32 * 64; k_lds_ptr1 += 32 * 64;
} }
__builtin_amdgcn_sched_barrier(0);
asm volatile(" s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); asm volatile(" s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
for (int k = 16; k < 18; k+=2) for (int k = 16; k < 18; k+=2)
{ {
bf16_storage data; bf16_storage data;
...@@ -3078,7 +3194,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_gfx936_tp1(co ...@@ -3078,7 +3194,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_gfx936_tp1(co
#endif #endif
} }
__builtin_amdgcn_sched_barrier(0);
asm volatile(" s_barrier\n\t"); asm volatile(" s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
} }
} }
...@@ -3316,10 +3434,12 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_gfx936(const ...@@ -3316,10 +3434,12 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_gfx936(const
int cur_block_table; int cur_block_table;
const int *cur_block_table_ptr = block_table + n_block; const int *cur_block_table_ptr = block_table + n_block;
// cur_block_table = block_table[n_block - 1]; // cur_block_table = block_table[n_block - 1];
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_load_dword %1, %0, 0x0\n\t" asm volatile("s_load_dword %1, %0, 0x0\n\t"
"s_waitcnt lgkmcnt(0)\n\t": "s_waitcnt lgkmcnt(0)\n\t":
"+s"(cur_block_table_ptr), "+s"(cur_block_table_ptr),
"=s"(cur_block_table)); "=s"(cur_block_table));
__builtin_amdgcn_sched_barrier(0);
index_t offset_k = cur_block_table * params.k_batch_stride; index_t offset_k = cur_block_table * params.k_batch_stride;
gK.data() = gK.data() + (offset_k); gK.data() = gK.data() + (offset_k);
...@@ -3340,81 +3460,111 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_gfx936(const ...@@ -3340,81 +3460,111 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_gfx936(const
int k_idx = 0; int k_idx = 0;
// k_idx++; // k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(14 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(14 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(13 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(13 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(12 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(12 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(11 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(11 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(10 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(10 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(9+ 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(9+ 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(8+ 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(8+ 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(7+ 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(7+ 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
// asm volatile("s_barrier\n\t"); // asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(6+ 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(6+ 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
k_idx++; k_idx++;
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
// asm volatile("s_barrier\n\t"); // asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(5 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(5 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
k_idx++; k_idx++;
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
// asm volatile("s_barrier\n\t"); // asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(4 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(4 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
k_idx++; k_idx++;
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(3 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(3 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(2 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(2 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
k_idx++; k_idx++;
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
// asm volatile("s_barrier\n\t"); // asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(1 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(1 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
k_idx++; k_idx++;
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
// asm volatile("s_barrier\n\t"); // asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(0 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
k_idx++; k_idx++;
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
...@@ -3423,13 +3573,19 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_gfx936(const ...@@ -3423,13 +3573,19 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_gfx936(const
} }
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(2) \n\t \n\t"); asm volatile("s_waitcnt vmcnt(2) \n\t \n\t");
__builtin_amdgcn_sched_barrier(0);
buffer_to_tensor(buffer[0], tSrK_smem, 15); buffer_to_tensor(buffer[0], tSrK_smem, 15);
cute::gemm(tiled_mma, tSrQ(_, _, 15), tSrK_smem(_, _, 15), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, 15), tSrK_smem(_, _, 15), acc_s);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(1) \n\t \n\t"); asm volatile("s_waitcnt vmcnt(1) \n\t \n\t");
__builtin_amdgcn_sched_barrier(0);
buffer_to_tensor(buffer[1], tSrK_smem, 16); buffer_to_tensor(buffer[1], tSrK_smem, 16);
cute::gemm(tiled_mma, tSrQ(_, _, 16), tSrK_smem(_, _, 16), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, 16), tSrK_smem(_, _, 16), acc_s);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0) \n\t \n\t"); asm volatile("s_waitcnt vmcnt(0) \n\t \n\t");
__builtin_amdgcn_sched_barrier(0);
buffer_to_tensor(buffer[2], tSrK_smem, 17); buffer_to_tensor(buffer[2], tSrK_smem, 17);
cute::gemm(tiled_mma, tSrQ(_, _, 17), tSrK_smem(_, _, 17), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, 17), tSrK_smem(_, _, 17), acc_s);
// asm volatile("s_barrier\n\t"); // asm volatile("s_barrier\n\t");
...@@ -3472,7 +3628,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_gfx936(const ...@@ -3472,7 +3628,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_gfx936(const
lds_direct_copy<false, true>(gK, sK, 15, params.k_row_stride, seqlen_k - n_block * kBlockN); lds_direct_copy<false, true>(gK, sK, 15, params.k_row_stride, seqlen_k - n_block * kBlockN);
// asm_ds_write(buffer[0], tVsV, 15); // asm_ds_write(buffer[0], tVsV, 15);
// asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t"); // asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
gK.data() = gK.data() + (-offset_k); gK.data() = gK.data() + (-offset_k);
...@@ -3496,10 +3654,12 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_gfx936(const ...@@ -3496,10 +3654,12 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_gfx936(const
int cur_block_table; int cur_block_table;
const int *cur_block_table_ptr = block_table + n_block; const int *cur_block_table_ptr = block_table + n_block;
// cur_block_table = block_table[n_block - 1]; // cur_block_table = block_table[n_block - 1];
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_load_dword %1, %0, 0x0\n\t" asm volatile("s_load_dword %1, %0, 0x0\n\t"
"s_waitcnt lgkmcnt(0)\n\t": "s_waitcnt lgkmcnt(0)\n\t":
"+s"(cur_block_table_ptr), "+s"(cur_block_table_ptr),
"=s"(cur_block_table)); "=s"(cur_block_table));
__builtin_amdgcn_sched_barrier(0);
index_t offset_k = cur_block_table * params.k_batch_stride; index_t offset_k = cur_block_table * params.k_batch_stride;
gK.data() = gK.data() + (offset_k); gK.data() = gK.data() + (offset_k);
...@@ -3519,85 +3679,117 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_gfx936(const ...@@ -3519,85 +3679,117 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_gfx936(const
int k_idx = 0; int k_idx = 0;
// k_idx++; // k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(14 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(14 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(13 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(13 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(12 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(12 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(11 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(11 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(10 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(10 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(9+ 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(9+ 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(8+ 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(8+ 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(7+ 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(7+ 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
// asm volatile("s_barrier\n\t"); // asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(6+ 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(6+ 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
k_idx++; k_idx++;
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
// asm volatile("s_barrier\n\t"); // asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(5 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(5 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
k_idx++; k_idx++;
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
// asm volatile("s_barrier\n\t"); // asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(4 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(4 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
k_idx++; k_idx++;
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(3 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(3 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(2 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(2 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
k_idx++; k_idx++;
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
// asm volatile("s_barrier\n\t"); // asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(1 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(1 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
k_idx++; k_idx++;
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
// asm volatile("s_barrier\n\t"); // asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(0 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
k_idx++; k_idx++;
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0 + 2) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(0 + 2) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
k_idx++; k_idx++;
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
...@@ -3610,14 +3802,20 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_gfx936(const ...@@ -3610,14 +3802,20 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_gfx936(const
// asm volatile("s_barrier\n\t"); // asm volatile("s_barrier\n\t");
} }
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(1) \n\t \n\t"); asm volatile("s_waitcnt vmcnt(1) \n\t \n\t");
__builtin_amdgcn_sched_barrier(0);
buffer_to_tensor(buffer[0], tSrK_smem, 16); buffer_to_tensor(buffer[0], tSrK_smem, 16);
cute::gemm(tiled_mma, tSrQ(_, _, 16), tSrK_smem(_, _, 16), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, 16), tSrK_smem(_, _, 16), acc_s);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0) \n\t \n\t"); asm volatile("s_waitcnt vmcnt(0) \n\t \n\t");
__builtin_amdgcn_sched_barrier(0);
buffer_to_tensor(buffer[1], tSrK_smem, 17); buffer_to_tensor(buffer[1], tSrK_smem, 17);
cute::gemm(tiled_mma, tSrQ(_, _, 17), tSrK_smem(_, _, 17), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, 17), tSrK_smem(_, _, 17), acc_s);
// asm volatile("s_barrier\n\t"); // asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t"); asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
gK.data() = gK.data() + (-offset_k); gK.data() = gK.data() + (-offset_k);
// We have key_padding_mask so we'll need to Check_inf // We have key_padding_mask so we'll need to Check_inf
...@@ -3814,7 +4012,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_gfx928(const ...@@ -3814,7 +4012,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_gfx928(const
#if 1 #if 1
#pragma unroll #pragma unroll
for (int masking_step = n_masking_steps; n_block >= n_block_min; --masking_step, --n_block) { for (int masking_step = n_masking_steps; n_block >= n_block_min; --masking_step, --n_block) {
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_barrier\n\t"); asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
clear(acc_s); clear(acc_s);
...@@ -3826,7 +4026,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_gfx928(const ...@@ -3826,7 +4026,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_gfx928(const
for (int i = 0; i < k0_lds_loops - BUFFER_SIZE + 1; i++) { for (int i = 0; i < k0_lds_loops - BUFFER_SIZE + 1; i++) {
// asm volatile("s_waitcnt vmcnt(3) \n\t \n\t"); // asm volatile("s_waitcnt vmcnt(3) \n\t \n\t");
asm_ds_write(buffer[i % BUFFER_SIZE], tKsK, i); asm_ds_write(buffer[i % BUFFER_SIZE], tKsK, i);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t"); asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, i), tSrK_copy_view(_, _, i)); cute::copy(smem_tiled_copy_K, tSsK(_, _, i), tSrK_copy_view(_, _, i));
buffer_load_copy<false, true, false>(gK, buffer[(i + BUFFER_SIZE - 1) % BUFFER_SIZE], i + BUFFER_SIZE - 1, params.k_row_stride, offset_k, seqlen_k - n_block * kBlockN); buffer_load_copy<false, true, false>(gK, buffer[(i + BUFFER_SIZE - 1) % BUFFER_SIZE], i + BUFFER_SIZE - 1, params.k_row_stride, offset_k, seqlen_k - n_block * kBlockN);
cute::gemm(tiled_mma, tSrQ(_, _, i), tSrK(_, _, i), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, i), tSrK(_, _, i), acc_s);
...@@ -3842,17 +4044,23 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_gfx928(const ...@@ -3842,17 +4044,23 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_gfx928(const
// 计算 13-15 // 计算 13-15
const int k_idx = k0_lds_loops - BUFFER_SIZE + 1; const int k_idx = k0_lds_loops - BUFFER_SIZE + 1;
asm_ds_write(buffer[k_idx % BUFFER_SIZE], tKsK, k_idx); asm_ds_write(buffer[k_idx % BUFFER_SIZE], tKsK, k_idx);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t"); asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
asm_ds_write(buffer[(k_idx + 1) % BUFFER_SIZE], tKsK, k_idx + 1); asm_ds_write(buffer[(k_idx + 1) % BUFFER_SIZE], tKsK, k_idx + 1);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t"); asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx + 1), tSrK_copy_view(_, _, k_idx + 1)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx + 1), tSrK_copy_view(_, _, k_idx + 1));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx + 1), tSrK(_, _, k_idx + 1), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx + 1), tSrK(_, _, k_idx + 1), acc_s);
asm_ds_write(buffer[(k_idx + 2) % BUFFER_SIZE], tKsK, k_idx + 2); asm_ds_write(buffer[(k_idx + 2) % BUFFER_SIZE], tKsK, k_idx + 2);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t"); asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx + 2), tSrK_copy_view(_, _, k_idx + 2)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx + 2), tSrK_copy_view(_, _, k_idx + 2));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx + 2), tSrK(_, _, k_idx + 2), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx + 2), tSrK(_, _, k_idx + 2), acc_s);
...@@ -3869,7 +4077,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_gfx928(const ...@@ -3869,7 +4077,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_gfx928(const
buffer_to_tensor(buffer[2], tSrK_smem, 17); buffer_to_tensor(buffer[2], tSrK_smem, 17);
cute::gemm(tiled_mma, tSrQ(_, _, 17), tSrK_smem(_, _, 17), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, 17), tSrK_smem(_, _, 17), acc_s);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_barrier\n\t"); asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
#endif #endif
...@@ -3904,7 +4114,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_gfx928(const ...@@ -3904,7 +4114,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_gfx928(const
#if 1 #if 1
// Block 15 has already been read into buffer[3]
asm_ds_write(buffer[3], tVsV, 15); asm_ds_write(buffer[3], tVsV, 15);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t"); asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
#endif #endif
gK.data() = gK.data() + (-offset_k); gK.data() = gK.data() + (-offset_k);
...@@ -3923,7 +4135,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_gfx928(const ...@@ -3923,7 +4135,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_gfx928(const
cute::copy(smem_tiled_copy_V, tOsVt(_, _, i), tOrVt_copy_view(_, _, i)); cute::copy(smem_tiled_copy_V, tOsVt(_, _, i), tOrVt_copy_view(_, _, i));
cute::gemm(tiled_mma_o, tOrP(_, _, i), tOrVt(_, _, i), acc_o); cute::gemm(tiled_mma_o, tOrP(_, _, i), tOrVt(_, _, i), acc_o);
} }
__builtin_amdgcn_sched_barrier(0);
asm volatile(" s_barrier\n\t"); asm volatile(" s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
} }
#endif #endif
......
...@@ -642,17 +642,23 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co ...@@ -642,17 +642,23 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
lds_direct_copy_qkvfp8<false, true, true>(gQ, sQ, 0, params.q_row_stride, params.seqlen_q - m_block * kBlockM); lds_direct_copy_qkvfp8<false, true, true>(gQ, sQ, 0, params.q_row_stride, params.seqlen_q - m_block * kBlockM);
lds_direct_copy_qkvfp8<false, true, true>(gQ, sQ, 1, params.q_row_stride, params.seqlen_q - m_block * kBlockM); lds_direct_copy_qkvfp8<false, true, true>(gQ, sQ, 1, params.q_row_stride, params.seqlen_q - m_block * kBlockM);
lds_direct_copy_qkvfp8<false, false, true>(gQ, sQ, 2, params.q_row_stride, params.seqlen_q - m_block * kBlockM); lds_direct_copy_qkvfp8<false, false, true>(gQ, sQ, 2, params.q_row_stride, params.seqlen_q - m_block * kBlockM);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 0), tSrQ_copy_view(_, _, 0)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 0), tSrQ_copy_view(_, _, 0));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 1), tSrQ_copy_view(_, _, 1)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 1), tSrQ_copy_view(_, _, 1));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 2), tSrQ_copy_view(_, _, 2)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 2), tSrQ_copy_view(_, _, 2));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 3), tSrQ_copy_view(_, _, 3)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 3), tSrQ_copy_view(_, _, 3));
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 4), tSrQ_copy_view(_, _, 4)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 4), tSrQ_copy_view(_, _, 4));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 5), tSrQ_copy_view(_, _, 5)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 5), tSrQ_copy_view(_, _, 5));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 6), tSrQ_copy_view(_, _, 6)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 6), tSrQ_copy_view(_, _, 6));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 7), tSrQ_copy_view(_, _, 7)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 7), tSrQ_copy_view(_, _, 7));
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 8), tSrQ_copy_view(_, _, 8)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 8), tSrQ_copy_view(_, _, 8));
__syncthreads(); __syncthreads();
} }
...@@ -708,12 +714,14 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co ...@@ -708,12 +714,14 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
const int offset_s = 0; const int offset_s = 0;
int ldsAddrPerWave = reinterpret_cast<size_t>(s_q) + warp_idx * bytes_per_warp + k_idx * bytes_per_block int ldsAddrPerWave = reinterpret_cast<size_t>(s_q) + warp_idx * bytes_per_warp + k_idx * bytes_per_block
+ offset_k * 3 * bytes_per_block; + offset_k * 3 * bytes_per_block;
__builtin_amdgcn_sched_barrier(0);
asm volatile( asm volatile(
"s_mov_b32 m0, %1 \n\t" "s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t" "s_nop 0 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v), "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:); :);
__builtin_amdgcn_sched_barrier(0);
}; };
lds_direct_copy_q(0, 0); lds_direct_copy_q(0, 0);
...@@ -723,7 +731,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co ...@@ -723,7 +731,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
lds_direct_copy_q(2, 0); lds_direct_copy_q(2, 0);
ElementQ* s_q_read_ptr = s_q + lane_idx * 8; ElementQ* s_q_read_ptr = s_q + lane_idx * 8;
Fp8_storage bf16_data; Fp8_storage bf16_data;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(4) \n s_barrier"); asm volatile("s_waitcnt vmcnt(4) \n s_barrier");
__builtin_amdgcn_sched_barrier(0);
float fp32[8]; float fp32[8];
union Fp8_temp{ union Fp8_temp{
int32_t data; int32_t data;
...@@ -747,7 +757,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co ...@@ -747,7 +757,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
} }
s_q_read_ptr += 16 * 32; s_q_read_ptr += 16 * 32;
} }
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(3) \n s_barrier"); asm volatile("s_waitcnt vmcnt(3) \n s_barrier");
__builtin_amdgcn_sched_barrier(0);
for (int k = 4; k < 8; k++) { for (int k = 4; k < 8; k++) {
bf16_data.data = *reinterpret_cast<intx4_t*>(s_q_read_ptr); bf16_data.data = *reinterpret_cast<intx4_t*>(s_q_read_ptr);
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
...@@ -766,7 +778,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co ...@@ -766,7 +778,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
} }
s_q_read_ptr += 16 * 32; s_q_read_ptr += 16 * 32;
} }
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(2) \n s_barrier"); asm volatile("s_waitcnt vmcnt(2) \n s_barrier");
__builtin_amdgcn_sched_barrier(0);
s_q_read_ptr = s_q + lane_idx * 8 + 3 * 4 * 16 * 4 * 8; s_q_read_ptr = s_q + lane_idx * 8 + 3 * 4 * 16 * 4 * 8;
for (int k = 0; k < 4; k++) { for (int k = 0; k < 4; k++) {
bf16_data.data = *reinterpret_cast<intx4_t*>(s_q_read_ptr); bf16_data.data = *reinterpret_cast<intx4_t*>(s_q_read_ptr);
...@@ -786,7 +800,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co ...@@ -786,7 +800,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
} }
s_q_read_ptr += 16 * 32; s_q_read_ptr += 16 * 32;
} }
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(1) \n s_barrier"); asm volatile("s_waitcnt vmcnt(1) \n s_barrier");
__builtin_amdgcn_sched_barrier(0);
for (int k = 4; k < 8; k++) { for (int k = 4; k < 8; k++) {
bf16_data.data = *reinterpret_cast<intx4_t*>(s_q_read_ptr); bf16_data.data = *reinterpret_cast<intx4_t*>(s_q_read_ptr);
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
...@@ -805,7 +821,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co ...@@ -805,7 +821,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
} }
s_q_read_ptr += 16 * 32; s_q_read_ptr += 16 * 32;
} }
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0) \n s_barrier"); asm volatile("s_waitcnt vmcnt(0) \n s_barrier");
__builtin_amdgcn_sched_barrier(0);
s_q_read_ptr = s_q + lane_idx * 8 + 2 * 4 * 16 * 4 * 8; s_q_read_ptr = s_q + lane_idx * 8 + 2 * 4 * 16 * 4 * 8;
for (int k = 8; k < 9; k++) { for (int k = 8; k < 9; k++) {
bf16_data.data = *reinterpret_cast<intx4_t*>(s_q_read_ptr); bf16_data.data = *reinterpret_cast<intx4_t*>(s_q_read_ptr);
...@@ -848,17 +866,23 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co ...@@ -848,17 +866,23 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
lds_direct_copy_qkvfp8<false, true, true>(gQ_nope, sQ, 0, params.q_nope_head_stride, params.seqlen_q - m_block * kBlockM); lds_direct_copy_qkvfp8<false, true, true>(gQ_nope, sQ, 0, params.q_nope_head_stride, params.seqlen_q - m_block * kBlockM);
lds_direct_copy_qkvfp8<false, true, true>(gQ_nope, sQ, 1, params.q_nope_head_stride, params.seqlen_q - m_block * kBlockM); lds_direct_copy_qkvfp8<false, true, true>(gQ_nope, sQ, 1, params.q_nope_head_stride, params.seqlen_q - m_block * kBlockM);
lds_direct_copy_qkvfp8_pe<false, false, true>(gQ_pe, sQ, 2, params.q_pe_head_stride, params.seqlen_q - m_block * kBlockM); lds_direct_copy_qkvfp8_pe<false, false, true>(gQ_pe, sQ, 2, params.q_pe_head_stride, params.seqlen_q - m_block * kBlockM);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 0), tSrQ_copy_view(_, _, 0)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 0), tSrQ_copy_view(_, _, 0));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 1), tSrQ_copy_view(_, _, 1)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 1), tSrQ_copy_view(_, _, 1));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 2), tSrQ_copy_view(_, _, 2)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 2), tSrQ_copy_view(_, _, 2));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 3), tSrQ_copy_view(_, _, 3)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 3), tSrQ_copy_view(_, _, 3));
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 4), tSrQ_copy_view(_, _, 4)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 4), tSrQ_copy_view(_, _, 4));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 5), tSrQ_copy_view(_, _, 5)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 5), tSrQ_copy_view(_, _, 5));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 6), tSrQ_copy_view(_, _, 6)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 6), tSrQ_copy_view(_, _, 6));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 7), tSrQ_copy_view(_, _, 7)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 7), tSrQ_copy_view(_, _, 7));
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 8), tSrQ_copy_view(_, _, 8)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 8), tSrQ_copy_view(_, _, 8));
__syncthreads(); __syncthreads();
} }
...@@ -968,7 +992,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co ...@@ -968,7 +992,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
for (int i = 0; i < 8; i++) { for (int i = 0; i < 8; i++) {
cute::gemm(tiled_mma, tSrQ(_, _, i), tSrK(_, _, i), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, i), tSrK(_, _, i), acc_s);
} }
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t"); asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
// if (thread0()) { // if (thread0()) {
// printf(" %.2f %.2f %.2f %.2f \n", acc_s(0), acc_s(1), acc_s(2), acc_s(3)); // printf(" %.2f %.2f %.2f %.2f \n", acc_s(0), acc_s(1), acc_s(2), acc_s(3));
// } // }
...@@ -1252,26 +1278,36 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP ...@@ -1252,26 +1278,36 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
} }
auto q_lds_read_ptr = sQ.data().get() + (warp_id % 4) * 16 * 64; auto q_lds_read_ptr = sQ.data().get() + (warp_id % 4) * 16 * 64;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(4) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(4) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
q_r[0].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(q_lds_read_ptr), 0, 3, 1, 0); q_r[0].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(q_lds_read_ptr), 0, 3, 1, 0);
// q_lds_read_ptr += 64 * 64; // q_lds_read_ptr += 64 * 64;
q_r[1].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(q_lds_read_ptr), 64*64, 3, 1, 0); q_r[1].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(q_lds_read_ptr), 64*64, 3, 1, 0);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
// q_lds_read_ptr += 64 * 64; // q_lds_read_ptr += 64 * 64;
q_r[2].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(q_lds_read_ptr), 2*64*64, 3, 1, 0); q_r[2].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(q_lds_read_ptr), 2*64*64, 3, 1, 0);
// q_lds_read_ptr += 64 * 64; // q_lds_read_ptr += 64 * 64;
q_r[3].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(q_lds_read_ptr), 3*64*64, 3, 1, 0); q_r[3].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(q_lds_read_ptr), 3*64*64, 3, 1, 0);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
// q_lds_read_ptr += 64 * 64; // q_lds_read_ptr += 64 * 64;
q_r[4].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(q_lds_read_ptr), 4*64*64, 3, 1, 0); q_r[4].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(q_lds_read_ptr), 4*64*64, 3, 1, 0);
// q_lds_read_ptr += 64 * 64; // q_lds_read_ptr += 64 * 64;
q_r[5].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(q_lds_read_ptr), 5*64*64, 3, 1, 0); q_r[5].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(q_lds_read_ptr), 5*64*64, 3, 1, 0);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
// q_lds_read_ptr += 64 * 64; // q_lds_read_ptr += 64 * 64;
q_r[6].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(q_lds_read_ptr), 6*64*64, 3, 1, 0); q_r[6].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(q_lds_read_ptr), 6*64*64, 3, 1, 0);
// q_lds_read_ptr += 64 * 64; // q_lds_read_ptr += 64 * 64;
q_r[7].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(q_lds_read_ptr), 7*64*64, 3, 1, 0); q_r[7].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(q_lds_read_ptr), 7*64*64, 3, 1, 0);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
// q_lds_read_ptr += 64 * 64; // q_lds_read_ptr += 64 * 64;
q_r[8].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(q_lds_read_ptr), 8*64*64, 3, 1, 0); q_r[8].data = __builtin_hcu_ds_read_matrix_trans_format_u8((__attribute__((address_space(3))) int*)(q_lds_read_ptr), 8*64*64, 3, 1, 0);
__syncthreads(); __syncthreads();
...@@ -1939,7 +1975,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP ...@@ -1939,7 +1975,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
lds_direct_copy_qkvfp8_q_tp4<false, true>(gQ, sQ, 2, params.q_row_stride, params.seqlen_q - m_block * kBlockM); lds_direct_copy_qkvfp8_q_tp4<false, true>(gQ, sQ, 2, params.q_row_stride, params.seqlen_q - m_block * kBlockM);
lds_direct_copy_qkvfp8_q_tp4<false, true>(gQ, sQ, 3, params.q_row_stride, params.seqlen_q - m_block * kBlockM); lds_direct_copy_qkvfp8_q_tp4<false, true>(gQ, sQ, 3, params.q_row_stride, params.seqlen_q - m_block * kBlockM);
lds_direct_copy_qkvfp8_q_tp4<false, false>(gQ, sQ, 4, params.q_row_stride, params.seqlen_q - m_block * kBlockM); lds_direct_copy_qkvfp8_q_tp4<false, false>(gQ, sQ, 4, params.q_row_stride, params.seqlen_q - m_block * kBlockM);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(4) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(4) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
uint8_t* q_lds_read_ptr = reinterpret_cast<uint8_t*>(sQ.data().get()) + (tidx % 64) * 16 + (warp_id % 2) * (16 * 64); uint8_t* q_lds_read_ptr = reinterpret_cast<uint8_t*>(sQ.data().get()) + (tidx % 64) * 16 + (warp_id % 2) * (16 * 64);
{ {
int k = 0; int k = 0;
...@@ -1961,7 +1999,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP ...@@ -1961,7 +1999,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
// tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k + 1))); // tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k + 1)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr); // *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
} }
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
{ {
q_lds_read_ptr += 32*64; q_lds_read_ptr += 32*64;
int k = 2; int k = 2;
...@@ -1984,7 +2024,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP ...@@ -1984,7 +2024,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
// tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k + 1))); // tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k + 1)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr); // *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
} }
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
{ {
q_lds_read_ptr += 32*64; q_lds_read_ptr += 32*64;
int k = 4; int k = 4;
...@@ -2007,7 +2049,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP ...@@ -2007,7 +2049,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
// tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k + 1))); // tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k + 1)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr); // *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
} }
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
{ {
q_lds_read_ptr += 32*64; q_lds_read_ptr += 32*64;
int k = 6; int k = 6;
...@@ -2030,7 +2074,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP ...@@ -2030,7 +2074,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
// tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k + 1))); // tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k + 1)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr); // *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
} }
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
{ {
q_lds_read_ptr += 32*64; q_lds_read_ptr += 32*64;
int k = 8; int k = 8;
...@@ -2092,10 +2138,12 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP ...@@ -2092,10 +2138,12 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
int cur_block_table; int cur_block_table;
const int *cur_block_table_ptr = block_table + n_block; const int *cur_block_table_ptr = block_table + n_block;
// cur_block_table = block_table[n_block - 1]; // cur_block_table = block_table[n_block - 1];
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_load_dword %1, %0, 0x0\n\t" asm volatile("s_load_dword %1, %0, 0x0\n\t"
"s_waitcnt lgkmcnt(0)\n\t": "s_waitcnt lgkmcnt(0)\n\t":
"+s"(cur_block_table_ptr), "+s"(cur_block_table_ptr),
"=s"(cur_block_table)); "=s"(cur_block_table));
__builtin_amdgcn_sched_barrier(0);
index_t offset_k = cur_block_table * params.k_batch_stride; index_t offset_k = cur_block_table * params.k_batch_stride;
gK.data() = gK.data() + (offset_k); gK.data() = gK.data() + (offset_k);
lds_direct_copy_qkvfp8<false, true>(gK, sK, 0, params.k_row_stride, seqlen_k - n_block * kBlockN); lds_direct_copy_qkvfp8<false, true>(gK, sK, 0, params.k_row_stride, seqlen_k - n_block * kBlockN);
...@@ -2107,31 +2155,49 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP ...@@ -2107,31 +2155,49 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
lds_direct_copy_qkvfp8<false, true>(gK, sK, 6, params.k_row_stride, seqlen_k - n_block * kBlockN); lds_direct_copy_qkvfp8<false, true>(gK, sK, 6, params.k_row_stride, seqlen_k - n_block * kBlockN);
lds_direct_copy_qkvfp8<false, true>(gK, sK, 7, params.k_row_stride, seqlen_k - n_block * kBlockN); lds_direct_copy_qkvfp8<false, true>(gK, sK, 7, params.k_row_stride, seqlen_k - n_block * kBlockN);
lds_direct_copy_qkvfp8<false, true>(gK, sK, 8, params.k_row_stride, seqlen_k - n_block * kBlockN); lds_direct_copy_qkvfp8<false, true>(gK, sK, 8, params.k_row_stride, seqlen_k - n_block * kBlockN);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(8) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(8) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, 0), tSrK_copy_view(_, _, 0)); cute::copy(smem_tiled_copy_K, tSsK(_, _, 0), tSrK_copy_view(_, _, 0));
cute::gemm(tiled_mma, tSrQ(_, _, 0), tSrK(_, _, 0), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, 0), tSrK(_, _, 0), acc_s);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(7) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(7) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, 1), tSrK_copy_view(_, _, 1)); cute::copy(smem_tiled_copy_K, tSsK(_, _, 1), tSrK_copy_view(_, _, 1));
cute::gemm(tiled_mma, tSrQ(_, _, 1), tSrK(_, _, 1), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, 1), tSrK(_, _, 1), acc_s);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(6) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(6) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, 2), tSrK_copy_view(_, _, 2)); cute::copy(smem_tiled_copy_K, tSsK(_, _, 2), tSrK_copy_view(_, _, 2));
cute::gemm(tiled_mma, tSrQ(_, _, 2), tSrK(_, _, 2), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, 2), tSrK(_, _, 2), acc_s);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(5) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(5) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, 3), tSrK_copy_view(_, _, 3)); cute::copy(smem_tiled_copy_K, tSsK(_, _, 3), tSrK_copy_view(_, _, 3));
cute::gemm(tiled_mma, tSrQ(_, _, 3), tSrK(_, _, 3), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, 3), tSrK(_, _, 3), acc_s);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(4) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(4) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, 4), tSrK_copy_view(_, _, 4)); cute::copy(smem_tiled_copy_K, tSsK(_, _, 4), tSrK_copy_view(_, _, 4));
cute::gemm(tiled_mma, tSrQ(_, _, 4), tSrK(_, _, 4), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, 4), tSrK(_, _, 4), acc_s);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, 5), tSrK_copy_view(_, _, 5)); cute::copy(smem_tiled_copy_K, tSsK(_, _, 5), tSrK_copy_view(_, _, 5));
cute::gemm(tiled_mma, tSrQ(_, _, 5), tSrK(_, _, 5), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, 5), tSrK(_, _, 5), acc_s);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, 6), tSrK_copy_view(_, _, 6)); cute::copy(smem_tiled_copy_K, tSsK(_, _, 6), tSrK_copy_view(_, _, 6));
cute::gemm(tiled_mma, tSrQ(_, _, 6), tSrK(_, _, 6), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, 6), tSrK(_, _, 6), acc_s);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, 7), tSrK_copy_view(_, _, 7)); cute::copy(smem_tiled_copy_K, tSsK(_, _, 7), tSrK_copy_view(_, _, 7));
cute::gemm(tiled_mma, tSrQ(_, _, 7), tSrK(_, _, 7), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, 7), tSrK(_, _, 7), acc_s);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, 8), tSrK_copy_view(_, _, 8)); cute::copy(smem_tiled_copy_K, tSsK(_, _, 8), tSrK_copy_view(_, _, 8));
cute::gemm(tiled_mma, tSrQ(_, _, 8), tSrK(_, _, 8), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, 8), tSrK(_, _, 8), acc_s);
gK.data() = gK.data() + (-offset_k); gK.data() = gK.data() + (-offset_k);
......
...@@ -300,9 +300,11 @@ __forceinline__ __device__ void copy_k_idx(TiledCopy tiled_copy, Tensor<Engine0, ...@@ -300,9 +300,11 @@ __forceinline__ __device__ void copy_k_idx(TiledCopy tiled_copy, Tensor<Engine0,
/// Emit `s_waitcnt vmcnt(N)` immediately followed by `s_barrier`.
///
/// Both instructions are issued from a single `asm volatile` statement, and
/// that statement is bracketed by `__builtin_amdgcn_sched_barrier(0)` on each
/// side.
///
/// \tparam N  Value placed in the `vmcnt(...)` field. It is passed through the
///            "n" (immediate) asm constraint, so it has to be a compile-time
///            integer constant.
///
/// NOTE(review): the surrounding sched_barrier(0) calls are presumably there
/// to stop the compiler from scheduling memory or LDS instructions across the
/// wait/barrier pair -- confirm against the LLVM AMDGPU documentation for the
/// mask value 0.
template <int N>
CUTE_HOST_DEVICE
void wait_vmcnt() {
    // Fence placed before the wait/barrier pair.
    __builtin_amdgcn_sched_barrier(0);
    // No outputs; the only input is the immediate N substituted for %0.
    asm volatile("s_waitcnt vmcnt(%0) ;\n\t"
                 "s_barrier; \n\t"
                 :: "n"(N));
    // Fence placed after the wait/barrier pair.
    __builtin_amdgcn_sched_barrier(0);
}
template< template<
...@@ -377,11 +379,13 @@ buffer_load_copy( ...@@ -377,11 +379,13 @@ buffer_load_copy(
if constexpr(use_asm) { if constexpr(use_asm) {
__builtin_amdgcn_sched_barrier(0);
asm volatile( asm volatile(
"buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0 \n" "buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0 \n"
" \n\t" :"=v"(dst), " \n\t" :"=v"(dst),
"+v"(offset_v), "+s"(global_addr) "+v"(offset_v), "+s"(global_addr)
); );
__builtin_amdgcn_sched_barrier(0);
} }
else { else {
auto res = __builtin_amdgcn_buffer_load_dwordx4(global_addr, 0, offset_v, false, false); auto res = __builtin_amdgcn_buffer_load_dwordx4(global_addr, 0, offset_v, false, false);
...@@ -408,11 +412,13 @@ buffer_load_copy( ...@@ -408,11 +412,13 @@ buffer_load_copy(
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1; if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
if constexpr(use_asm) { if constexpr(use_asm) {
__builtin_amdgcn_sched_barrier(0);
asm volatile( asm volatile(
"buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0 \n" "buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0 \n"
" \n\t" :"=v"(dst), " \n\t" :"=v"(dst),
"+v"(offset_v), "+s"(global_addr) "+v"(offset_v), "+s"(global_addr)
); );
__builtin_amdgcn_sched_barrier(0);
} }
else { else {
auto res = __builtin_amdgcn_buffer_load_dwordx4(global_addr, 0, offset_v, false, false); auto res = __builtin_amdgcn_buffer_load_dwordx4(global_addr, 0, offset_v, false, false);
...@@ -473,11 +479,13 @@ buffer_load_copy_fp8( ...@@ -473,11 +479,13 @@ buffer_load_copy_fp8(
if constexpr(use_asm) { if constexpr(use_asm) {
__builtin_amdgcn_sched_barrier(0);
asm volatile( asm volatile(
"buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0 \n" "buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0 \n"
" \n\t" :"=v"(dst), " \n\t" :"=v"(dst),
"+v"(offset_v), "+s"(global_addr) "+v"(offset_v), "+s"(global_addr)
); );
__builtin_amdgcn_sched_barrier(0);
} }
else { else {
auto res = __builtin_amdgcn_buffer_load_dwordx4(global_addr, 0, offset_v, false, false); auto res = __builtin_amdgcn_buffer_load_dwordx4(global_addr, 0, offset_v, false, false);
...@@ -542,11 +550,13 @@ buffer_load_copy_fp8x2( ...@@ -542,11 +550,13 @@ buffer_load_copy_fp8x2(
if constexpr(use_asm) { if constexpr(use_asm) {
__builtin_amdgcn_sched_barrier(0);
asm volatile( asm volatile(
"buffer_load_dwordx2 %0, %1, %2 ,0 offen offset:0 \n" "buffer_load_dwordx2 %0, %1, %2 ,0 offen offset:0 \n"
" \n\t" :"=v"(dst), " \n\t" :"=v"(dst),
"+v"(offset_v), "+s"(global_addr) "+v"(offset_v), "+s"(global_addr)
); );
__builtin_amdgcn_sched_barrier(0);
} }
else { else {
auto res = __builtin_amdgcn_buffer_load_dwordx2(global_addr, 0, offset_v, false, false); auto res = __builtin_amdgcn_buffer_load_dwordx2(global_addr, 0, offset_v, false, false);
...@@ -711,12 +721,14 @@ lds_direct_copy_qkvfp8_pe( ...@@ -711,12 +721,14 @@ lds_direct_copy_qkvfp8_pe(
if (!Is_even_K && col_offset >= 576) offset_v = -1; if (!Is_even_K && col_offset >= 576) offset_v = -1;
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_idx * mma_k * element_size; int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_idx * mma_k * element_size;
__builtin_amdgcn_sched_barrier(0);
asm volatile( asm volatile(
"s_mov_b32 m0, %1 \n\t" "s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t" "s_nop 0 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v), "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:); :);
__builtin_amdgcn_sched_barrier(0);
...@@ -778,12 +790,14 @@ lds_direct_copy_qkvfp8( ...@@ -778,12 +790,14 @@ lds_direct_copy_qkvfp8(
if (!Is_even_K && col_offset >= 576) offset_v = -1; if (!Is_even_K && col_offset >= 576) offset_v = -1;
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_idx * mma_k * element_size; int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_idx * mma_k * element_size;
__builtin_amdgcn_sched_barrier(0);
asm volatile( asm volatile(
"s_mov_b32 m0, %1 \n\t" "s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t" "s_nop 0 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v), "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:); :);
__builtin_amdgcn_sched_barrier(0);
...@@ -845,12 +859,14 @@ lds_direct_copy_qkvfp8( ...@@ -845,12 +859,14 @@ lds_direct_copy_qkvfp8(
#if defined(__gfx938__) #if defined(__gfx938__)
__builtin_amdgcn_sched_barrier(0);
asm volatile( asm volatile(
"s_mov_b32 m0, %1 \n\t" "s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t" "s_nop 0 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v), "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:); :);
__builtin_amdgcn_sched_barrier(0);
#endif #endif
} }
} }
...@@ -933,12 +949,14 @@ lds_direct_copy_fp8( ...@@ -933,12 +949,14 @@ lds_direct_copy_fp8(
// } // }
#if defined(__gfx936__) || defined(__gfx938__) #if defined(__gfx936__) || defined(__gfx938__)
__builtin_amdgcn_sched_barrier(0);
asm volatile( asm volatile(
"s_mov_b32 m0, %1 \n\t" "s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t" "s_nop 0 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v), "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:); :);
__builtin_amdgcn_sched_barrier(0);
#endif #endif
} }
} }
...@@ -1005,12 +1023,14 @@ lds_direct_copy_tp1( ...@@ -1005,12 +1023,14 @@ lds_direct_copy_tp1(
// { // {
// printf(" %x \n", ldsAddrPerWave); // printf(" %x \n", ldsAddrPerWave);
// } // }
__builtin_amdgcn_sched_barrier(0);
asm volatile( asm volatile(
"s_mov_b32 m0, %1 \n\t" "s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t" "s_nop 0 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v), "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:); :);
__builtin_amdgcn_sched_barrier(0);
} }
...@@ -1186,12 +1206,14 @@ lds_direct_copy_sparse_k( ...@@ -1186,12 +1206,14 @@ lds_direct_copy_sparse_k(
if (!Is_even_K && col_offset >= 576) offset_v = -1; if (!Is_even_K && col_offset >= 576) offset_v = -1;
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + 0 * mma_k * element_size; int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + 0 * mma_k * element_size;
__builtin_amdgcn_sched_barrier(0);
asm volatile( asm volatile(
"s_mov_b32 m0, %1 \n\t" "s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t" "s_nop 0 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v), "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:); :);
__builtin_amdgcn_sched_barrier(0);
...@@ -1255,12 +1277,14 @@ lds_direct_copy( ...@@ -1255,12 +1277,14 @@ lds_direct_copy(
if (!Is_even_K && col_offset >= 576) offset_v = -1; if (!Is_even_K && col_offset >= 576) offset_v = -1;
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_idx * mma_k * element_size; int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_idx * mma_k * element_size;
__builtin_amdgcn_sched_barrier(0);
asm volatile( asm volatile(
"s_mov_b32 m0, %1 \n\t" "s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t" "s_nop 0 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v), "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:); :);
__builtin_amdgcn_sched_barrier(0);
...@@ -1319,12 +1343,14 @@ lds_direct_copy( ...@@ -1319,12 +1343,14 @@ lds_direct_copy(
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1; if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_idx * mma_k * element_size; int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_idx * mma_k * element_size;
__builtin_amdgcn_sched_barrier(0);
asm volatile( asm volatile(
"s_mov_b32 m0, %1 \n\t" "s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t" "s_nop 0 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v), "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:); :);
__builtin_amdgcn_sched_barrier(0);
} }
} }
...@@ -1409,12 +1435,14 @@ lds_direct_copy_for_prefill_sparse_mla( ...@@ -1409,12 +1435,14 @@ lds_direct_copy_for_prefill_sparse_mla(
uint32x2_t index_offset = {0}; uint32x2_t index_offset = {0};
index_offset[0] = row_offset == -1 ? max_MN : row_offset; index_offset[0] = row_offset == -1 ? max_MN : row_offset;
index_offset[1] = offset_v; index_offset[1] = offset_v;
__builtin_amdgcn_sched_barrier(0);
asm volatile( asm volatile(
"s_mov_b32 m0, %1 \n\t" "s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t" "s_nop 0 \n\t"
"buffer_load_dwordx4 %0, %2, %3 , idxen offen offset:0, lds \n" ::"v"(index_offset), "buffer_load_dwordx4 %0, %2, %3 , idxen offen offset:0, lds \n" ::"v"(index_offset),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:); :);
__builtin_amdgcn_sched_barrier(0);
} }
...@@ -1474,11 +1502,13 @@ buffer_load_copy_sparse_fp8( ...@@ -1474,11 +1502,13 @@ buffer_load_copy_sparse_fp8(
uint32x2_t index_offset = {0}; uint32x2_t index_offset = {0};
index_offset[0] = (row_offset + 64 ) % 64; index_offset[0] = (row_offset + 64 ) % 64;
index_offset[1] = offset_v; index_offset[1] = offset_v;
__builtin_amdgcn_sched_barrier(0);
asm volatile( asm volatile(
"buffer_load_dwordx2 %0, %1, %2 ,0 offen offset:0 \n" "buffer_load_dwordx2 %0, %1, %2 ,0 offen offset:0 \n"
" \n\t" :"=v"(dst), " \n\t" :"=v"(dst),
"+v"(index_offset), "+s"(global_addr) "+v"(index_offset), "+s"(global_addr)
); );
__builtin_amdgcn_sched_barrier(0);
} }
else { else {
// auto res = __builtin_amdgcn_buffer_load_dwordx2(global_addr, (row_offset + 64 ) % 64 , offset_v, false, false); // auto res = __builtin_amdgcn_buffer_load_dwordx2(global_addr, (row_offset + 64 ) % 64 , offset_v, false, false);
...@@ -1555,11 +1585,13 @@ buffer_load_copy_sparse_decoding( ...@@ -1555,11 +1585,13 @@ buffer_load_copy_sparse_decoding(
uint32x2_t index_offset = {0}; uint32x2_t index_offset = {0};
index_offset[0] = (row_offset + 64 ) % 64; index_offset[0] = (row_offset + 64 ) % 64;
index_offset[1] = offset_v; index_offset[1] = offset_v;
__builtin_amdgcn_sched_barrier(0);
asm volatile( asm volatile(
"buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0 \n" "buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0 \n"
" \n\t" :"=v"(dst), " \n\t" :"=v"(dst),
"+v"(index_offset), "+s"(global_addr) "+v"(index_offset), "+s"(global_addr)
); );
__builtin_amdgcn_sched_barrier(0);
} }
else { else {
// auto res = __builtin_amdgcn_buffer_load_dwordx4(global_addr, (row_offset + 64 ) % 64 + (batch_stride / row_stride) * block_idx , offset_v, false, false); // auto res = __builtin_amdgcn_buffer_load_dwordx4(global_addr, (row_offset + 64 ) % 64 + (batch_stride / row_stride) * block_idx , offset_v, false, false);
...@@ -2303,12 +2335,14 @@ lds_direct_copy_qkvfp8_q_tp1( ...@@ -2303,12 +2335,14 @@ lds_direct_copy_qkvfp8_q_tp1(
#if defined(__gfx938__) #if defined(__gfx938__)
__builtin_amdgcn_sched_barrier(0);
asm volatile( asm volatile(
"s_mov_b32 m0, %1 \n\t" "s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t" "s_nop 0 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v), "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:); :);
__builtin_amdgcn_sched_barrier(0);
#endif #endif
} }
...@@ -2362,12 +2396,14 @@ lds_direct_copy_qkvfp8_q_tp4( ...@@ -2362,12 +2396,14 @@ lds_direct_copy_qkvfp8_q_tp4(
#if defined(__gfx938__) #if defined(__gfx938__)
__builtin_amdgcn_sched_barrier(0);
asm volatile( asm volatile(
"s_mov_b32 m0, %1 \n\t" "s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t" "s_nop 0 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v), "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:); :);
__builtin_amdgcn_sched_barrier(0);
#endif #endif
} }
...@@ -2444,12 +2480,14 @@ lds_direct_copy_qkvfp8_tp1( ...@@ -2444,12 +2480,14 @@ lds_direct_copy_qkvfp8_tp1(
#if defined(__gfx938__) #if defined(__gfx938__)
__builtin_amdgcn_sched_barrier(0);
asm volatile( asm volatile(
"s_mov_b32 m0, %1 \n\t" "s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t" "s_nop 0 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v), "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:); :);
__builtin_amdgcn_sched_barrier(0);
#endif #endif
} }
...@@ -2792,12 +2830,14 @@ lds_direct_copy_qkvfp8_zero_lds( ...@@ -2792,12 +2830,14 @@ lds_direct_copy_qkvfp8_zero_lds(
#if defined(__gfx938__) #if defined(__gfx938__)
__builtin_amdgcn_sched_barrier(0);
asm volatile( asm volatile(
"s_mov_b32 m0, %1 \n\t" "s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t" "s_nop 0 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v), "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:); :);
__builtin_amdgcn_sched_barrier(0);
#endif #endif
} }
......
...@@ -744,10 +744,13 @@ compute_attn_1rowblock_splitkv_mla_gfx936(const DenseAttnDecodeParams& params, ...@@ -744,10 +744,13 @@ compute_attn_1rowblock_splitkv_mla_gfx936(const DenseAttnDecodeParams& params,
int cur_block_table; int cur_block_table;
const int *cur_block_table_ptr = block_table + n_block; const int *cur_block_table_ptr = block_table + n_block;
// cur_block_table = block_table[n_block - 1]; // cur_block_table = block_table[n_block - 1];
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_load_dword %1, %0, 0x0\n\t" asm volatile("s_load_dword %1, %0, 0x0\n\t"
"s_waitcnt lgkmcnt(0)\n\t": "s_waitcnt lgkmcnt(0)\n\t":
"+s"(cur_block_table_ptr), "+s"(cur_block_table_ptr),
"=s"(cur_block_table)); "=s"(cur_block_table));
__builtin_amdgcn_sched_barrier(0);
index_t offset_k = cur_block_table * params.k_batch_stride; index_t offset_k = cur_block_table * params.k_batch_stride;
gK.data() = gK.data() + (offset_k); gK.data() = gK.data() + (offset_k);
...@@ -768,81 +771,119 @@ compute_attn_1rowblock_splitkv_mla_gfx936(const DenseAttnDecodeParams& params, ...@@ -768,81 +771,119 @@ compute_attn_1rowblock_splitkv_mla_gfx936(const DenseAttnDecodeParams& params,
int k_idx = 0; int k_idx = 0;
// k_idx++; // k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(14 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(14 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(13 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(13 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(12 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(12 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(11 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(11 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(10 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(10 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(9+ 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(9+ 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(8+ 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(8+ 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(7+ 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(7+ 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
// asm volatile("s_barrier\n\t"); // asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(6+ 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(6+ 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
k_idx++; k_idx++;
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
// asm volatile("s_barrier\n\t"); // asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(5 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(5 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
k_idx++; k_idx++;
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
// asm volatile("s_barrier\n\t"); // asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(4 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(4 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
k_idx++; k_idx++;
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(3 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(3 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(2 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(2 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
k_idx++; k_idx++;
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
// asm volatile("s_barrier\n\t"); // asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(1 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(1 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
k_idx++; k_idx++;
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
// asm volatile("s_barrier\n\t"); // asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(0 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
k_idx++; k_idx++;
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
...@@ -850,14 +891,20 @@ compute_attn_1rowblock_splitkv_mla_gfx936(const DenseAttnDecodeParams& params, ...@@ -850,14 +891,20 @@ compute_attn_1rowblock_splitkv_mla_gfx936(const DenseAttnDecodeParams& params,
} }
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(2) \n\t \n\t"); asm volatile("s_waitcnt vmcnt(2) \n\t \n\t");
__builtin_amdgcn_sched_barrier(0);
flash::buffer_to_tensor(buffer[0], tSrK_smem, 15); flash::buffer_to_tensor(buffer[0], tSrK_smem, 15);
cute::gemm(tiled_mma, tSrQ(_, _, 15), tSrK_smem(_, _, 15), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, 15), tSrK_smem(_, _, 15), acc_s);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(1) \n\t \n\t"); asm volatile("s_waitcnt vmcnt(1) \n\t \n\t");
__builtin_amdgcn_sched_barrier(0);
flash::buffer_to_tensor(buffer[1], tSrK_smem, 16); flash::buffer_to_tensor(buffer[1], tSrK_smem, 16);
cute::gemm(tiled_mma, tSrQ(_, _, 16), tSrK_smem(_, _, 16), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, 16), tSrK_smem(_, _, 16), acc_s);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0) \n\t \n\t"); asm volatile("s_waitcnt vmcnt(0) \n\t \n\t");
__builtin_amdgcn_sched_barrier(0);
flash::buffer_to_tensor(buffer[2], tSrK_smem, 17); flash::buffer_to_tensor(buffer[2], tSrK_smem, 17);
cute::gemm(tiled_mma, tSrQ(_, _, 17), tSrK_smem(_, _, 17), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, 17), tSrK_smem(_, _, 17), acc_s);
// asm volatile("s_barrier\n\t"); // asm volatile("s_barrier\n\t");
...@@ -903,7 +950,9 @@ compute_attn_1rowblock_splitkv_mla_gfx936(const DenseAttnDecodeParams& params, ...@@ -903,7 +950,9 @@ compute_attn_1rowblock_splitkv_mla_gfx936(const DenseAttnDecodeParams& params,
flash::lds_direct_copy<false, true>(gK, sK, 15, params.k_row_stride, seqlen_k - n_block * kBlockN); flash::lds_direct_copy<false, true>(gK, sK, 15, params.k_row_stride, seqlen_k - n_block * kBlockN);
// asm_ds_write(buffer[0], tVsV, 15); // asm_ds_write(buffer[0], tVsV, 15);
// asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t"); // asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
gK.data() = gK.data() + (-offset_k); gK.data() = gK.data() + (-offset_k);
...@@ -925,10 +974,12 @@ compute_attn_1rowblock_splitkv_mla_gfx936(const DenseAttnDecodeParams& params, ...@@ -925,10 +974,12 @@ compute_attn_1rowblock_splitkv_mla_gfx936(const DenseAttnDecodeParams& params,
int cur_block_table; int cur_block_table;
const int *cur_block_table_ptr = block_table + n_block; const int *cur_block_table_ptr = block_table + n_block;
// cur_block_table = block_table[n_block - 1]; // cur_block_table = block_table[n_block - 1];
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_load_dword %1, %0, 0x0\n\t" asm volatile("s_load_dword %1, %0, 0x0\n\t"
"s_waitcnt lgkmcnt(0)\n\t": "s_waitcnt lgkmcnt(0)\n\t":
"+s"(cur_block_table_ptr), "+s"(cur_block_table_ptr),
"=s"(cur_block_table)); "=s"(cur_block_table));
__builtin_amdgcn_sched_barrier(0);
index_t offset_k = cur_block_table * params.k_batch_stride; index_t offset_k = cur_block_table * params.k_batch_stride;
gK.data() = gK.data() + (offset_k); gK.data() = gK.data() + (offset_k);
...@@ -948,85 +999,117 @@ compute_attn_1rowblock_splitkv_mla_gfx936(const DenseAttnDecodeParams& params, ...@@ -948,85 +999,117 @@ compute_attn_1rowblock_splitkv_mla_gfx936(const DenseAttnDecodeParams& params,
int k_idx = 0; int k_idx = 0;
// k_idx++; // k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(14 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(14 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(13 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(13 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(12 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(12 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(11 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(11 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(10 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(10 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(9+ 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(9+ 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(8+ 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(8+ 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(7+ 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(7+ 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
// asm volatile("s_barrier\n\t"); // asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(6+ 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(6+ 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
k_idx++; k_idx++;
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
// asm volatile("s_barrier\n\t"); // asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(5 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(5 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
k_idx++; k_idx++;
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
// asm volatile("s_barrier\n\t"); // asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(4 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(4 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
k_idx++; k_idx++;
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(3 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(3 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(2 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(2 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
k_idx++; k_idx++;
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
// asm volatile("s_barrier\n\t"); // asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(1 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(1 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
k_idx++; k_idx++;
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
// asm volatile("s_barrier\n\t"); // asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0 + 3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(0 + 3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
k_idx++; k_idx++;
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0 + 2) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(0 + 2) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
k_idx++; k_idx++;
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
...@@ -1039,14 +1122,20 @@ compute_attn_1rowblock_splitkv_mla_gfx936(const DenseAttnDecodeParams& params, ...@@ -1039,14 +1122,20 @@ compute_attn_1rowblock_splitkv_mla_gfx936(const DenseAttnDecodeParams& params,
// asm volatile("s_barrier\n\t"); // asm volatile("s_barrier\n\t");
} }
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(1) \n\t \n\t"); asm volatile("s_waitcnt vmcnt(1) \n\t \n\t");
__builtin_amdgcn_sched_barrier(0);
flash::buffer_to_tensor(buffer[0], tSrK_smem, 16); flash::buffer_to_tensor(buffer[0], tSrK_smem, 16);
cute::gemm(tiled_mma, tSrQ(_, _, 16), tSrK_smem(_, _, 16), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, 16), tSrK_smem(_, _, 16), acc_s);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0) \n\t \n\t"); asm volatile("s_waitcnt vmcnt(0) \n\t \n\t");
__builtin_amdgcn_sched_barrier(0);
flash::buffer_to_tensor(buffer[1], tSrK_smem, 17); flash::buffer_to_tensor(buffer[1], tSrK_smem, 17);
cute::gemm(tiled_mma, tSrQ(_, _, 17), tSrK_smem(_, _, 17), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, 17), tSrK_smem(_, _, 17), acc_s);
// asm volatile("s_barrier\n\t"); // asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t"); asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
gK.data() = gK.data() + (-offset_k); gK.data() = gK.data() + (-offset_k);
// We have key_padding_mask so we'll need to Check_inf // We have key_padding_mask so we'll need to Check_inf
...@@ -1325,7 +1414,9 @@ compute_attn_1rowblock_splitkv_mla_gfx928(const DenseAttnDecodeParams& params, ...@@ -1325,7 +1414,9 @@ compute_attn_1rowblock_splitkv_mla_gfx928(const DenseAttnDecodeParams& params,
#if 1 #if 1
#pragma unroll #pragma unroll
for (int masking_step = n_masking_steps; n_block >= n_block_min; --masking_step, --n_block) { for (int masking_step = n_masking_steps; n_block >= n_block_min; --masking_step, --n_block) {
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_barrier\n\t"); asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
clear(acc_s); clear(acc_s);
...@@ -1337,7 +1428,9 @@ compute_attn_1rowblock_splitkv_mla_gfx928(const DenseAttnDecodeParams& params, ...@@ -1337,7 +1428,9 @@ compute_attn_1rowblock_splitkv_mla_gfx928(const DenseAttnDecodeParams& params,
for (int i = 0; i < k0_lds_loops - BUFFER_SIZE + 1; i++) { for (int i = 0; i < k0_lds_loops - BUFFER_SIZE + 1; i++) {
// asm volatile("s_waitcnt vmcnt(3) \n\t \n\t"); // asm volatile("s_waitcnt vmcnt(3) \n\t \n\t");
flash::asm_ds_write(buffer[i % BUFFER_SIZE], tKsK, i); flash::asm_ds_write(buffer[i % BUFFER_SIZE], tKsK, i);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t"); asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, i), tSrK_copy_view(_, _, i)); cute::copy(smem_tiled_copy_K, tSsK(_, _, i), tSrK_copy_view(_, _, i));
flash::buffer_load_copy<false, true, false>(gK, buffer[(i + BUFFER_SIZE - 1) % BUFFER_SIZE], i + BUFFER_SIZE - 1, params.k_row_stride, offset_k, seqlen_k - n_block * kBlockN); flash::buffer_load_copy<false, true, false>(gK, buffer[(i + BUFFER_SIZE - 1) % BUFFER_SIZE], i + BUFFER_SIZE - 1, params.k_row_stride, offset_k, seqlen_k - n_block * kBlockN);
cute::gemm(tiled_mma, tSrQ(_, _, i), tSrK(_, _, i), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, i), tSrK(_, _, i), acc_s);
...@@ -1353,17 +1446,23 @@ compute_attn_1rowblock_splitkv_mla_gfx928(const DenseAttnDecodeParams& params, ...@@ -1353,17 +1446,23 @@ compute_attn_1rowblock_splitkv_mla_gfx928(const DenseAttnDecodeParams& params,
// 计算 13-15 // 计算 13-15
const int k_idx = k0_lds_loops - BUFFER_SIZE + 1; const int k_idx = k0_lds_loops - BUFFER_SIZE + 1;
flash::asm_ds_write(buffer[k_idx % BUFFER_SIZE], tKsK, k_idx); flash::asm_ds_write(buffer[k_idx % BUFFER_SIZE], tKsK, k_idx);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t"); asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
flash::asm_ds_write(buffer[(k_idx + 1) % BUFFER_SIZE], tKsK, k_idx + 1); flash::asm_ds_write(buffer[(k_idx + 1) % BUFFER_SIZE], tKsK, k_idx + 1);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t"); asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx + 1), tSrK_copy_view(_, _, k_idx + 1)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx + 1), tSrK_copy_view(_, _, k_idx + 1));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx + 1), tSrK(_, _, k_idx + 1), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx + 1), tSrK(_, _, k_idx + 1), acc_s);
flash::asm_ds_write(buffer[(k_idx + 2) % BUFFER_SIZE], tKsK, k_idx + 2); flash::asm_ds_write(buffer[(k_idx + 2) % BUFFER_SIZE], tKsK, k_idx + 2);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t"); asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx + 2), tSrK_copy_view(_, _, k_idx + 2)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx + 2), tSrK_copy_view(_, _, k_idx + 2));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx + 2), tSrK(_, _, k_idx + 2), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx + 2), tSrK(_, _, k_idx + 2), acc_s);
...@@ -1380,7 +1479,9 @@ compute_attn_1rowblock_splitkv_mla_gfx928(const DenseAttnDecodeParams& params, ...@@ -1380,7 +1479,9 @@ compute_attn_1rowblock_splitkv_mla_gfx928(const DenseAttnDecodeParams& params,
flash::buffer_to_tensor(buffer[2], tSrK_smem, 17); flash::buffer_to_tensor(buffer[2], tSrK_smem, 17);
cute::gemm(tiled_mma, tSrQ(_, _, 17), tSrK_smem(_, _, 17), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, 17), tSrK_smem(_, _, 17), acc_s);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_barrier\n\t"); asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
#endif #endif
...@@ -1415,7 +1516,9 @@ compute_attn_1rowblock_splitkv_mla_gfx928(const DenseAttnDecodeParams& params, ...@@ -1415,7 +1516,9 @@ compute_attn_1rowblock_splitkv_mla_gfx928(const DenseAttnDecodeParams& params,
#if 1 #if 1
// 第15块已经读取到了buffer[3]中 // 第15块已经读取到了buffer[3]中
flash::asm_ds_write(buffer[3], tVsV, 15); flash::asm_ds_write(buffer[3], tVsV, 15);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t"); asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
#endif #endif
gK.data() = gK.data() + (-offset_k); gK.data() = gK.data() + (-offset_k);
...@@ -1434,7 +1537,9 @@ compute_attn_1rowblock_splitkv_mla_gfx928(const DenseAttnDecodeParams& params, ...@@ -1434,7 +1537,9 @@ compute_attn_1rowblock_splitkv_mla_gfx928(const DenseAttnDecodeParams& params,
cute::copy(smem_tiled_copy_V, tOsVt(_, _, i), tOrVt_copy_view(_, _, i)); cute::copy(smem_tiled_copy_V, tOsVt(_, _, i), tOrVt_copy_view(_, _, i));
cute::gemm(tiled_mma_o, tOrP(_, _, i), tOrVt(_, _, i), acc_o); cute::gemm(tiled_mma_o, tOrP(_, _, i), tOrVt(_, _, i), acc_o);
} }
__builtin_amdgcn_sched_barrier(0);
asm volatile(" s_barrier\n\t"); asm volatile(" s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
} }
#endif #endif
......
...@@ -800,48 +800,66 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn ...@@ -800,48 +800,66 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn
// asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); // asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
if constexpr (D_QK == 576) if constexpr (D_QK == 576)
{ {
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(4) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(4) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 0), tSrQ_copy_view(_, _, 0)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 0), tSrQ_copy_view(_, _, 0));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 1), tSrQ_copy_view(_, _, 1)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 1), tSrQ_copy_view(_, _, 1));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 2), tSrQ_copy_view(_, _, 2)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 2), tSrQ_copy_view(_, _, 2));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 3), tSrQ_copy_view(_, _, 3)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 3), tSrQ_copy_view(_, _, 3));
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 4), tSrQ_copy_view(_, _, 4)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 4), tSrQ_copy_view(_, _, 4));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 5), tSrQ_copy_view(_, _, 5)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 5), tSrQ_copy_view(_, _, 5));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 6), tSrQ_copy_view(_, _, 6)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 6), tSrQ_copy_view(_, _, 6));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 7), tSrQ_copy_view(_, _, 7)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 7), tSrQ_copy_view(_, _, 7));
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 8), tSrQ_copy_view(_, _, 8)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 8), tSrQ_copy_view(_, _, 8));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 9), tSrQ_copy_view(_, _, 9)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 9), tSrQ_copy_view(_, _, 9));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 10), tSrQ_copy_view(_, _, 10)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 10), tSrQ_copy_view(_, _, 10));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 11), tSrQ_copy_view(_, _, 11)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 11), tSrQ_copy_view(_, _, 11));
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 12), tSrQ_copy_view(_, _, 12)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 12), tSrQ_copy_view(_, _, 12));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 13), tSrQ_copy_view(_, _, 13)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 13), tSrQ_copy_view(_, _, 13));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 14), tSrQ_copy_view(_, _, 14)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 14), tSrQ_copy_view(_, _, 14));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 15), tSrQ_copy_view(_, _, 15)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 15), tSrQ_copy_view(_, _, 15));
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 16), tSrQ_copy_view(_, _, 16)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 16), tSrQ_copy_view(_, _, 16));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 17), tSrQ_copy_view(_, _, 17)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 17), tSrQ_copy_view(_, _, 17));
} }
else else
{ {
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 0), tSrQ_copy_view(_, _, 0)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 0), tSrQ_copy_view(_, _, 0));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 1), tSrQ_copy_view(_, _, 1)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 1), tSrQ_copy_view(_, _, 1));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 2), tSrQ_copy_view(_, _, 2)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 2), tSrQ_copy_view(_, _, 2));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 3), tSrQ_copy_view(_, _, 3)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 3), tSrQ_copy_view(_, _, 3));
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 4), tSrQ_copy_view(_, _, 4)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 4), tSrQ_copy_view(_, _, 4));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 5), tSrQ_copy_view(_, _, 5)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 5), tSrQ_copy_view(_, _, 5));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 6), tSrQ_copy_view(_, _, 6)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 6), tSrQ_copy_view(_, _, 6));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 7), tSrQ_copy_view(_, _, 7)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 7), tSrQ_copy_view(_, _, 7));
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 8), tSrQ_copy_view(_, _, 8)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 8), tSrQ_copy_view(_, _, 8));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 9), tSrQ_copy_view(_, _, 9)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 9), tSrQ_copy_view(_, _, 9));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 10), tSrQ_copy_view(_, _, 10)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 10), tSrQ_copy_view(_, _, 10));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 11), tSrQ_copy_view(_, _, 11)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 11), tSrQ_copy_view(_, _, 11));
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 12), tSrQ_copy_view(_, _, 12)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 12), tSrQ_copy_view(_, _, 12));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 13), tSrQ_copy_view(_, _, 13)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 13), tSrQ_copy_view(_, _, 13));
cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 14), tSrQ_copy_view(_, _, 14)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 14), tSrQ_copy_view(_, _, 14));
...@@ -898,10 +916,14 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn ...@@ -898,10 +916,14 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn
flash::lds_direct_copy_for_prefill_sparse_mla<true, false, false>(gK, sK, row_offset, col, i, params.stride_kv_s_kv, params.s_kv); flash::lds_direct_copy_for_prefill_sparse_mla<true, false, false>(gK, sK, row_offset, col, i, params.stride_kv_s_kv, params.s_kv);
} }
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(1) \n s_barrier"); asm volatile("s_waitcnt vmcnt(1) \n s_barrier");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, 0), tSrK_copy_view(_, _, 0)); cute::copy(smem_tiled_copy_K, tSsK(_, _, 0), tSrK_copy_view(_, _, 0));
cute::gemm(tiled_mma, tSrQ(_, _, 0 + 16), tSrK(_, _, 0), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, 0 + 16), tSrK(_, _, 0), acc_s);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0) \n s_barrier"); asm volatile("s_waitcnt vmcnt(0) \n s_barrier");
__builtin_amdgcn_sched_barrier(0);
flash::lds_direct_copy_for_prefill_sparse_mla<true, false, false>(gK, sK, row_offset, col, 0, params.stride_kv_s_kv, params.s_kv); flash::lds_direct_copy_for_prefill_sparse_mla<true, false, false>(gK, sK, row_offset, col, 0, params.stride_kv_s_kv, params.s_kv);
cute::copy(smem_tiled_copy_K, tSsK(_, _, 1), tSrK_copy_view(_, _, 1)); cute::copy(smem_tiled_copy_K, tSsK(_, _, 1), tSrK_copy_view(_, _, 1));
cute::gemm(tiled_mma, tSrQ(_, _, 1 + 16), tSrK(_, _, 1), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, 1 + 16), tSrK(_, _, 1), acc_s);
...@@ -917,58 +939,79 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn ...@@ -917,58 +939,79 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn
} }
int k_idx = 0; int k_idx = 0;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx));
flash::__ds_read_m32x16_row_col_rrow_alt<0, 0, 0>(tOsVt, tOrVt_copy_view); flash::__ds_read_m32x16_row_col_rrow_alt<0, 0, 0>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow_alt<0, 1, 0>(tOsVt, tOrVt_copy_view); flash::__ds_read_m32x16_row_col_rrow_alt<0, 1, 0>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow_alt<0, 2, 0>(tOsVt, tOrVt_copy_view); flash::__ds_read_m32x16_row_col_rrow_alt<0, 2, 0>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow_alt<0, 3, 0>(tOsVt, tOrVt_copy_view); flash::__ds_read_m32x16_row_col_rrow_alt<0, 3, 0>(tOsVt, tOrVt_copy_view);
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t"); asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
flash::lds_direct_copy_for_prefill_sparse_mla<true, false, false>(gK, sK, row_offset, col, 4, params.stride_kv_s_kv, params.s_kv); flash::lds_direct_copy_for_prefill_sparse_mla<true, false, false>(gK, sK, row_offset, col, 4, params.stride_kv_s_kv, params.s_kv);
flash::lds_direct_copy_for_prefill_sparse_mla<true, false, false>(gK, sK, row_offset, col, 5, params.stride_kv_s_kv, params.s_kv); flash::lds_direct_copy_for_prefill_sparse_mla<true, false, false>(gK, sK, row_offset, col, 5, params.stride_kv_s_kv, params.s_kv);
flash::lds_direct_copy_for_prefill_sparse_mla<true, false, false>(gK, sK, row_offset, col, 6, params.stride_kv_s_kv, params.s_kv); flash::lds_direct_copy_for_prefill_sparse_mla<true, false, false>(gK, sK, row_offset, col, 6, params.stride_kv_s_kv, params.s_kv);
flash::lds_direct_copy_for_prefill_sparse_mla<true, false, false>(gK, sK, row_offset, col, 7, params.stride_kv_s_kv, params.s_kv); flash::lds_direct_copy_for_prefill_sparse_mla<true, false, false>(gK, sK, row_offset, col, 7, params.stride_kv_s_kv, params.s_kv);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx));
flash::__ds_read_m32x16_row_col_rrow_alt<0, 0, 1>(tOsVt, tOrVt_copy_view); flash::__ds_read_m32x16_row_col_rrow_alt<0, 0, 1>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow_alt<0, 1, 1>(tOsVt, tOrVt_copy_view); flash::__ds_read_m32x16_row_col_rrow_alt<0, 1, 1>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow_alt<0, 2, 1>(tOsVt, tOrVt_copy_view); flash::__ds_read_m32x16_row_col_rrow_alt<0, 2, 1>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow_alt<0, 3, 1>(tOsVt, tOrVt_copy_view); flash::__ds_read_m32x16_row_col_rrow_alt<0, 3, 1>(tOsVt, tOrVt_copy_view);
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t"); asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
flash::lds_direct_copy_for_prefill_sparse_mla<true, false, false>(gK, sK, row_offset, col, 8, params.stride_kv_s_kv, params.s_kv); flash::lds_direct_copy_for_prefill_sparse_mla<true, false, false>(gK, sK, row_offset, col, 8, params.stride_kv_s_kv, params.s_kv);
flash::lds_direct_copy_for_prefill_sparse_mla<true, false, false>(gK, sK, row_offset, col, 9, params.stride_kv_s_kv, params.s_kv); flash::lds_direct_copy_for_prefill_sparse_mla<true, false, false>(gK, sK, row_offset, col, 9, params.stride_kv_s_kv, params.s_kv);
...@@ -976,29 +1019,39 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn ...@@ -976,29 +1019,39 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn
flash::lds_direct_copy_for_prefill_sparse_mla<true, false, false>(gK, sK, row_offset, col, 11, params.stride_kv_s_kv, params.s_kv); flash::lds_direct_copy_for_prefill_sparse_mla<true, false, false>(gK, sK, row_offset, col, 11, params.stride_kv_s_kv, params.s_kv);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx));
flash::__ds_read_m32x16_row_col_rrow_alt<0, 0, 2>(tOsVt, tOrVt_copy_view); flash::__ds_read_m32x16_row_col_rrow_alt<0, 0, 2>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow_alt<0, 1, 2>(tOsVt, tOrVt_copy_view); flash::__ds_read_m32x16_row_col_rrow_alt<0, 1, 2>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow_alt<0, 2, 2>(tOsVt, tOrVt_copy_view); flash::__ds_read_m32x16_row_col_rrow_alt<0, 2, 2>(tOsVt, tOrVt_copy_view);
flash::__ds_read_m32x16_row_col_rrow_alt<0, 3, 2>(tOsVt, tOrVt_copy_view); flash::__ds_read_m32x16_row_col_rrow_alt<0, 3, 2>(tOsVt, tOrVt_copy_view);
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t"); asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
flash::lds_direct_copy_for_prefill_sparse_mla<true, false, false>(gK, sK, row_offset, col, 12, params.stride_kv_s_kv, params.s_kv); flash::lds_direct_copy_for_prefill_sparse_mla<true, false, false>(gK, sK, row_offset, col, 12, params.stride_kv_s_kv, params.s_kv);
flash::lds_direct_copy_for_prefill_sparse_mla<true, false, false>(gK, sK, row_offset, col, 13, params.stride_kv_s_kv, params.s_kv); flash::lds_direct_copy_for_prefill_sparse_mla<true, false, false>(gK, sK, row_offset, col, 13, params.stride_kv_s_kv, params.s_kv);
...@@ -1006,25 +1059,39 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn ...@@ -1006,25 +1059,39 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn
flash::lds_direct_copy_for_prefill_sparse_mla<true, false, false>(gK, sK, row_offset, col, 15, params.stride_kv_s_kv, params.s_kv); flash::lds_direct_copy_for_prefill_sparse_mla<true, false, false>(gK, sK, row_offset, col, 15, params.stride_kv_s_kv, params.s_kv);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
k_idx++; k_idx++;
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx)); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx));
cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_barrier\n\t"); asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
// if (block0()) // if (block0())
// { // {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment