Commit 1cb8a563 authored by zhanghj2's avatar zhanghj2
Browse files

Optimize sparse prefill d512 h64 topk sink path

parent 79408d6d
...@@ -132,39 +132,36 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev ...@@ -132,39 +132,36 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
global_addr_q[2] = 64; global_addr_q[2] = 64;
global_addr_q[3] = 0x00020000; global_addr_q[3] = 0x00020000;
PtrWrapper glob_ptr_indices;
*(uint64_t*)&glob_ptr_indices = reinterpret_cast<uint64_t>(gIndices);
// glob_ptr_indices.latter |= ((params.stride_indices_s_q * 4) << 16);
// *(uint64_t*)&glob_ptr_indices = reinterpret_cast<uint64_t>(params.indices);
// glob_ptr_indices.latter |= ((params.stride_indices_s_q * 4) << 16);
glob_ptr_indices.latter |= 0x40000000;
uint32x4_t global_addr_indices = {0};
global_addr_indices[0] = (glob_ptr_indices.former);
global_addr_indices[1] = (glob_ptr_indices.latter);
global_addr_indices[2] = 0x80000000;
global_addr_indices[3] = 0x00020000;
auto buffer_load_lds_indices = [&] (int n) { auto buffer_load_lds_indices = [&] (int n) {
constexpr int element_size = 4; if constexpr (IS_TOPK_2048) {
int ldsAddrPerWave = reinterpret_cast<size_t>(sIndices) + warp_idx * 64 * 4 * 4; PtrWrapper glob_ptr_indices;
typedef uint32_t uint32x2_t __attribute__((ext_vector_type(2))); *(uint64_t*)&glob_ptr_indices = reinterpret_cast<uint64_t>(gIndices);
// uint32x2_t index_offset = {0}; glob_ptr_indices.latter |= 0x40000000;
// index_offset[0] = s_q_idx; uint32x4_t global_addr_indices = {0};
// index_offset[1] = lane_idx * 4 * 4 + warp_idx * 64 * 4 * 4; global_addr_indices[0] = (glob_ptr_indices.former);
const int offset_v = lane_idx * 4 * 4 + warp_idx * 64 * 4 * 4; global_addr_indices[1] = (glob_ptr_indices.latter);
const int offset_s = n * 1024 * 4; global_addr_indices[2] = 0x80000000;
global_addr_indices[3] = 0x00020000;
int ldsAddrPerWave = reinterpret_cast<size_t>(sIndices) + warp_idx * 64 * 4 * 4;
const int offset_v = lane_idx * 4 * 4 + warp_idx * 64 * 4 * 4;
const int offset_s = n * 1024 * 4;
__builtin_amdgcn_sched_barrier(0);
asm volatile(
"s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t"
"buffer_load_dwordx4 %0, %2, %3 offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr_indices), "s"(offset_s)
:);
__builtin_amdgcn_sched_barrier(0);
}
};
if constexpr (IS_TOPK_2048) {
buffer_load_lds_indices(0);
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
asm volatile( asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
"s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t"
"buffer_load_dwordx4 %0, %2, %3 offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr_indices), "s"(offset_s)
:);
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
}; }
buffer_load_lds_indices(0);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
PtrWrapper glob_ptr_k; PtrWrapper glob_ptr_k;
*(uint64_t*)&glob_ptr_k = reinterpret_cast<uint64_t>(gK.data().get()); *(uint64_t*)&glob_ptr_k = reinterpret_cast<uint64_t>(gK.data().get());
...@@ -276,7 +273,8 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev ...@@ -276,7 +273,8 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
Element* q_lds_read_ptr = (q_lds + warp_idx * 16 * 32 + lane_idx * 8); Element* q_lds_read_ptr = (q_lds + warp_idx * 16 * 32 + lane_idx * 8);
Element* k_lds_read_ptr = (k_lds + k_lds_read_offset()); Element* k_lds_read_ptr = (k_lds + k_lds_read_offset());
Bf16_storage q_reg[18]; Bf16_storage q_reg[18];
for (int i = 0; i < 18; i++) static constexpr int kQkChunks = D_QK / 32;
for (int i = 0; i < kQkChunks; i++)
{ {
constexpr int elements_per_thread = 8; constexpr int elements_per_thread = 8;
int row = lane_idx % 16; int row = lane_idx % 16;
...@@ -325,6 +323,59 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev ...@@ -325,6 +323,59 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
auto [row_offset, col] = calc_row_and_col_k(block_idx); auto [row_offset, col] = calc_row_and_col_k(block_idx);
row_offset = row_offset == -1 ? params.s_kv : row_offset; row_offset = row_offset == -1 ? params.s_kv : row_offset;
#if 1 #if 1
if constexpr (D_QK == 512) {
#define LOAD_K_AND_QK_GEMM_512(k) \
{ \
constexpr int k_val = (k); \
buffer_load_lds_k(row_offset, col, k_val - 3); \
flash::qk_gemm<Element, k_val>(q_reg[k_val].data_128, k_lds_read_ptr, accs_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
}
constexpr int k_val = 15;
buffer_load_lds_k(row_offset, col, k_val);
buffer_load_lds_k(row_offset, col, k_val - 1);
buffer_load_lds_k(row_offset, col, k_val - 2);
buffer_load_lds_k(row_offset, col, k_val - 3);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
flash::qk_gemm<Element, k_val>(q_reg[k_val].data_128, k_lds_read_ptr, accs_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
LOAD_K_AND_QK_GEMM_512(14);
LOAD_K_AND_QK_GEMM_512(13);
LOAD_K_AND_QK_GEMM_512(12);
LOAD_K_AND_QK_GEMM_512(11);
LOAD_K_AND_QK_GEMM_512(10);
LOAD_K_AND_QK_GEMM_512(9);
LOAD_K_AND_QK_GEMM_512(8);
LOAD_K_AND_QK_GEMM_512(7);
LOAD_K_AND_QK_GEMM_512(6);
LOAD_K_AND_QK_GEMM_512(5);
LOAD_K_AND_QK_GEMM_512(4);
LOAD_K_AND_QK_GEMM_512(3);
flash::qk_gemm<Element, 2>(q_reg[2].data_128, k_lds_read_ptr, accs_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
flash::qk_gemm<Element, 1>(q_reg[1].data_128, k_lds_read_ptr, accs_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
flash::qk_gemm<Element, 0>(q_reg[0].data_128, k_lds_read_ptr, accs_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
#undef LOAD_K_AND_QK_GEMM_512
} else {
#define LOAD_K_AND_QK_GEMM(k) \ #define LOAD_K_AND_QK_GEMM(k) \
{ \ { \
constexpr int k_val = (k); \ constexpr int k_val = (k); \
...@@ -379,6 +430,8 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev ...@@ -379,6 +430,8 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
asm volatile("s_barrier\n\t"); asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
} }
#undef LOAD_K_AND_QK_GEMM
}
#else #else
#define LOAD_K_AND_QK_GEMM(k) \ #define LOAD_K_AND_QK_GEMM(k) \
{ \ { \
...@@ -710,6 +763,20 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev ...@@ -710,6 +763,20 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
float* gLSE = reinterpret_cast<float *>(params.lse) + row_offset_lse; float* gLSE = reinterpret_cast<float *>(params.lse) + row_offset_lse;
// const index_t row_offset_lse = m_block * params.h_q; // const index_t row_offset_lse = m_block * params.h_q;
float* gMax_logits = reinterpret_cast<float *>(params.max_logits) + row_offset_lse; float* gMax_logits = reinterpret_cast<float *>(params.max_logits) + row_offset_lse;
float attn_sink_o_scale = 1.0f;
if constexpr (D_QK == 512 && HAVE_TOPK_LENGTH) {
if (params.attn_sink != nullptr) {
float rAttn_sink = __ldg((float*)params.attn_sink + bidh * kBlockM + lane_idx % 16 + warp_idx * 16);
if (flash::is_positive_infinity(rAttn_sink)) {
attn_sink_o_scale = 0.0f;
} else if (!flash::is_positive_infinity(lse(0))) {
float lse_exp2 = __builtin_amdgcn_exp2f(lse[0] * CUDART_L2E_F);
float rAttn_sink_exp2 = __builtin_amdgcn_exp2f(rAttn_sink * CUDART_L2E_F);
attn_sink_o_scale = lse_exp2 / (lse_exp2 + rAttn_sink_exp2);
}
}
}
{ {
// store O and gLSE // store O and gLSE
...@@ -725,13 +792,13 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev ...@@ -725,13 +792,13 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
#if defined(__gfx938__) #if defined(__gfx938__)
Bf16_storage res; Bf16_storage res;
col = (lane_idx / 16) * 8 + ni * 32 ; col = (lane_idx / 16) * 8 + ni * 32 ;
res.data_32[0] = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[ni * 2][0], 0, acco_f32[ni * 2 + 1][0], 0); res.data_32[0] = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[ni * 2][0] * attn_sink_o_scale, 0, acco_f32[ni * 2 + 1][0] * attn_sink_o_scale, 0);
res.data_32[1] = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[ni * 2][1], 0, acco_f32[ni * 2 + 1][1], 0); res.data_32[1] = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[ni * 2][1] * attn_sink_o_scale, 0, acco_f32[ni * 2 + 1][1] * attn_sink_o_scale, 0);
res.data_32[2] = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[ni * 2][2], 0, acco_f32[ni * 2 + 1][2], 0); res.data_32[2] = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[ni * 2][2] * attn_sink_o_scale, 0, acco_f32[ni * 2 + 1][2] * attn_sink_o_scale, 0);
res.data_32[3] = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[ni * 2][3], 0, acco_f32[ni * 2 + 1][3], 0); res.data_32[3] = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[ni * 2][3] * attn_sink_o_scale, 0, acco_f32[ni * 2 + 1][3] * attn_sink_o_scale, 0);
*(__fp16x8_t*)(&gO(row, col)) = res.data_128; *(__fp16x8_t*)(&gO(row, col)) = res.data_128;
...@@ -742,8 +809,8 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev ...@@ -742,8 +809,8 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
{ {
result_type res; result_type res;
Element e0, e1; Element e0, e1;
e0.storage = float2bf16(acco_f32[ni * 2][ei]); e0.storage = float2bf16(acco_f32[ni * 2][ei] * attn_sink_o_scale);
e1.storage = float2bf16(acco_f32[ni * 2 + 1][ei]); e1.storage = float2bf16(acco_f32[ni * 2 + 1][ei] * attn_sink_o_scale);
res[0] = e0; res[0] = e0;
res[1] = e1; res[1] = e1;
// gO(row, col) = res[0]; // gO(row, col) = res[0];
...@@ -1318,9 +1385,34 @@ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::run(const Spar ...@@ -1318,9 +1385,34 @@ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::run(const Spar
KU_CHECK_KERNEL_LAUNCH(); KU_CHECK_KERNEL_LAUNCH();
} }
class KernelTemplate_D512_H64_TopkLen_AttnSink {
public:
static constexpr int NUM_THREADS = KernelTemplate_B_H_64<512, true, false>::NUM_THREADS;
static __device__ __forceinline__ void
devfunc(const SparseAttnFwdParams &params) {
KernelTemplate_B_H_64<512, true, false>::devfunc(params);
}
static void run(const SparseAttnFwdParams &params) {
KU_ASSERT(params.h_kv == 1);
KU_ASSERT(params.topk > 0);
auto kernel = &sparse_attn_fwd_kernel<KernelTemplate_D512_H64_TopkLen_AttnSink>;
constexpr size_t smem_size = 16384 + 4096;
dim3 grid((params.h_q + 64 - 1) / 64, params.s_q, 1);
kernel<<<grid, NUM_THREADS, smem_size, params.stream>>>(params);
KU_CHECK_KERNEL_LAUNCH();
}
};
template<int D_QK, bool HAVE_TOPK_LENGTH> template<int D_QK, bool HAVE_TOPK_LENGTH>
void run_fwd_phase1_kernel(const SparseAttnFwdParams& params) { void run_fwd_phase1_kernel(const SparseAttnFwdParams& params) {
if (params.h_q == 64 && !HAVE_TOPK_LENGTH && D_QK == 576 && !params.attn_sink) if (D_QK == 512 && HAVE_TOPK_LENGTH && params.h_q == 64 && params.attn_sink)
{
KernelTemplate_D512_H64_TopkLen_AttnSink::run(params);
}
else if (params.h_q == 64 && !HAVE_TOPK_LENGTH && D_QK == 576 && !params.attn_sink)
{ {
if (params.topk == 2048) if (params.topk == 2048)
{ {
......
2026-05-22 sparse prefill phase1 optimization
Scope
- Target kernel: csrc/gfx93/prefill/sparse/phase1.cuh
- Target shape: D_QK=512, h_q=64, topk=512
- Optimized dispatch: HAVE_TOPK_LENGTH=true and attn_sink enabled.
- Non-target D512/H64 combinations remain on the generic KernelTemplate path to avoid the measured slowdown risk from routing all D512/H64 cases through the H64 fast path.
Implementation
- Extended KernelTemplate_B_H_64 so the QK pipeline supports D_QK=512 with 16 q/k chunks instead of the existing D_QK=576-only 18 chunk schedule.
- Avoided index LDS prefetch overhead when IS_TOPK_2048 is false.
- Added attn_sink output scaling for the D512/H64 topk_length path.
- Added KernelTemplate_D512_H64_TopkLen_AttnSink wrapper and dispatch for D_QK=512 && HAVE_TOPK_LENGTH && h_q=64 && attn_sink.
Build
- Command:
source /parastor/home/public_user/zhanghj/dtk-26.04-DCC2602-0317/env.sh
touch csrc/gfx93/prefill/sparse/instantiations/phase1_k512.hip csrc/gfx93/prefill/sparse/instantiations/phase1_k512_topklen.hip
FLASH_MLA_OPT=phase1_d512_h64 python setup.py build_ext --inplace
- The touch is needed because phase1.cuh is included by generated instantiation .hip sources.
Benchmark
- Command:
PYTHONPATH=/parastor/home/public_user/zhanghj/flashmla/tests:$PYTHONPATH HIP_VISIBLE_DEVICES=1 python /parastor/home/public_user/zhanghj/hygon_tmp/bench_sparse_prefill_target.py --runs 20 --correctness --topk-length
- Device reported by PyTorch in this run: gfx936:sramecc+:xnack-
- hy-smi was run immediately before the measurement; all HCUs were idle at the start.
Target results: D_QK=512, h_q=64, topk=512, s_q=4096, HAVE_TOPK_LENGTH=true, attn_sink=true
| s_kv | baseline us | optimized us | latency reduction |
| ---: | ----------: | -----------: | ----------------: |
| 8192 | 3727.012 | 1955.701 | 47.53% |
| 32768 | 7798.721 | 2955.996 | 62.10% |
| 49152 | 8790.056 | 3162.484 | 64.02% |
| 65536 | 8959.296 | 3212.299 | 64.15% |
Average latency reduction: 59.45%.
The baseline values are from the same target benchmark before enabling this fast path. The optimized run also checked correctness for every measured target row.
Full regression
- Command:
HIP_VISIBLE_DEVICES=1 python tests/test_flash_mla_sparse_prefill.py
- Result:
All 958 cases passed.
Notes
- A broader D512/H64 dispatch was tested for HAVE_TOPK_LENGTH=false and attn_sink true/false. It gave only small, noisy gains in some rows and could regress the existing performance path, so the committed dispatch is limited to the target slow path that clears the 30% improvement requirement.
- The test environment prints NumPy 2.x compatibility warnings from PyTorch import. They did not prevent correctness or benchmark execution.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment