Commit 0728420c authored by zhanghj2
Browse files

优化sparse prefill

parent 1cb8a563
......@@ -124,7 +124,7 @@ static void run(const SparseAttnFwdParams &params);
};
template<int D_QK, bool HAVE_TOPK_LENGTH, bool IS_TOPK_2048>
template<int D_QK, bool HAVE_TOPK_LENGTH, bool IS_TOPK_2048, bool USE_ATTN_SINK = false, bool CACHE_INDICES_IN_LDS = false>
class KernelTemplate_B_H_64
{
public:
......
......@@ -10,8 +10,8 @@ namespace gfx93::fwd {
using namespace cute;
template<int D_QK, bool HAVE_TOPK_LENGTH, bool IS_TOPK_2048>
__device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::devfunc(const SparseAttnFwdParams &params) {
template<int D_QK, bool HAVE_TOPK_LENGTH, bool IS_TOPK_2048, bool USE_ATTN_SINK, bool CACHE_INDICES_IN_LDS>
__device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_ATTN_SINK, CACHE_INDICES_IN_LDS>::devfunc(const SparseAttnFwdParams &params) {
const int tidx = threadIdx.x;
static constexpr int kBlockM = B_H;
......@@ -96,7 +96,7 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
#endif
// int row_offset = row + warp_idx * 16 + block_idx * kBlockN;
if constexpr (IS_TOPK_2048) {
if constexpr (IS_TOPK_2048 || CACHE_INDICES_IN_LDS) {
row_offset = sIndices[row_offset % 1024];
} else {
row_offset = gIndices[row_offset];
......@@ -109,7 +109,7 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
// int col = lane_idx % 4;
int row_offset = row + i * 16 + block_idx * kBlockN;;
// int col_offset = col * 8 + warp_idx * 32;
if constexpr (IS_TOPK_2048) {
if constexpr (IS_TOPK_2048 || CACHE_INDICES_IN_LDS) {
row_offset = sIndices[row_offset % 1024];
} else {
row_offset = gIndices[row_offset];
......@@ -132,8 +132,8 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
global_addr_q[2] = 64;
global_addr_q[3] = 0x00020000;
auto buffer_load_lds_indices = [&] (int n) {
if constexpr (IS_TOPK_2048) {
auto buffer_load_lds_indices = [&] (int n, int num_indices) {
if constexpr (IS_TOPK_2048 || CACHE_INDICES_IN_LDS) {
PtrWrapper glob_ptr_indices;
*(uint64_t*)&glob_ptr_indices = reinterpret_cast<uint64_t>(gIndices);
glob_ptr_indices.latter |= 0x40000000;
......@@ -146,6 +146,8 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
int ldsAddrPerWave = reinterpret_cast<size_t>(sIndices) + warp_idx * 64 * 4 * 4;
const int offset_v = lane_idx * 4 * 4 + warp_idx * 64 * 4 * 4;
const int offset_s = n * 1024 * 4;
const int first_index = warp_idx * 256 + lane_idx * 4;
if (first_index < num_indices) {
__builtin_amdgcn_sched_barrier(0);
asm volatile(
"s_mov_b32 m0, %1 \n\t"
......@@ -155,9 +157,10 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
:);
__builtin_amdgcn_sched_barrier(0);
}
}
};
if constexpr (IS_TOPK_2048) {
buffer_load_lds_indices(0);
if constexpr (IS_TOPK_2048 || CACHE_INDICES_IN_LDS) {
buffer_load_lds_indices(0, IS_TOPK_2048 ? 1024 : params.topk);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
......@@ -323,71 +326,20 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
auto [row_offset, col] = calc_row_and_col_k(block_idx);
row_offset = row_offset == -1 ? params.s_kv : row_offset;
#if 1
if constexpr (D_QK == 512) {
#define LOAD_K_AND_QK_GEMM_512(k) \
{ \
constexpr int k_val = (k); \
buffer_load_lds_k(row_offset, col, k_val - 3); \
flash::qk_gemm<Element, k_val>(q_reg[k_val].data_128, k_lds_read_ptr, accs_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
}
constexpr int k_val = 15;
buffer_load_lds_k(row_offset, col, k_val);
buffer_load_lds_k(row_offset, col, k_val - 1);
buffer_load_lds_k(row_offset, col, k_val - 2);
buffer_load_lds_k(row_offset, col, k_val - 3);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
flash::qk_gemm<Element, k_val>(q_reg[k_val].data_128, k_lds_read_ptr, accs_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
LOAD_K_AND_QK_GEMM_512(14);
LOAD_K_AND_QK_GEMM_512(13);
LOAD_K_AND_QK_GEMM_512(12);
LOAD_K_AND_QK_GEMM_512(11);
LOAD_K_AND_QK_GEMM_512(10);
LOAD_K_AND_QK_GEMM_512(9);
LOAD_K_AND_QK_GEMM_512(8);
LOAD_K_AND_QK_GEMM_512(7);
LOAD_K_AND_QK_GEMM_512(6);
LOAD_K_AND_QK_GEMM_512(5);
LOAD_K_AND_QK_GEMM_512(4);
LOAD_K_AND_QK_GEMM_512(3);
flash::qk_gemm<Element, 2>(q_reg[2].data_128, k_lds_read_ptr, accs_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
flash::qk_gemm<Element, 1>(q_reg[1].data_128, k_lds_read_ptr, accs_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
flash::qk_gemm<Element, 0>(q_reg[0].data_128, k_lds_read_ptr, accs_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
#undef LOAD_K_AND_QK_GEMM_512
} else {
#define LOAD_K_AND_QK_GEMM(k) \
{ \
constexpr int k_val = (k); \
if constexpr (k_val < kQkChunks - 1) { \
buffer_load_lds_k(row_offset, col, k_val - 3); \
flash::qk_gemm<Element, k_val>(q_reg[k_val].data_128, k_lds_read_ptr, accs_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
} \
}
{
constexpr int k_val = (17);
constexpr int k_val = kQkChunks - 1;
buffer_load_lds_k(row_offset, col, k_val);
buffer_load_lds_k(row_offset, col, k_val - 1);
buffer_load_lds_k(row_offset, col, k_val - 2);
......@@ -415,23 +367,22 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
LOAD_K_AND_QK_GEMM(4);
LOAD_K_AND_QK_GEMM(3);
flash::qk_gemm<Element, k_val - 15>(q_reg[k_val - 15].data_128, k_lds_read_ptr, accs_f32);
flash::qk_gemm<Element, 2>(q_reg[2].data_128, k_lds_read_ptr, accs_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
flash::qk_gemm<Element, k_val - 16>(q_reg[k_val - 16].data_128, k_lds_read_ptr, accs_f32);
flash::qk_gemm<Element, 1>(q_reg[1].data_128, k_lds_read_ptr, accs_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
flash::qk_gemm<Element, k_val - 17>(q_reg[k_val - 17].data_128, k_lds_read_ptr, accs_f32);
flash::qk_gemm<Element, 0>(q_reg[0].data_128, k_lds_read_ptr, accs_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
}
#undef LOAD_K_AND_QK_GEMM
}
#else
#define LOAD_K_AND_QK_GEMM(k) \
{ \
......@@ -495,7 +446,7 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
const int n_idx = (lane_idx / 16) * 4 + (idx % 4) + (idx / 4) * 16;
int offs = n_idx + block_idx * kBlockN;
int t;
if constexpr (IS_TOPK_2048) {
if constexpr (IS_TOPK_2048 || CACHE_INDICES_IN_LDS) {
t = sIndices[offs % 1024];
} else {
t = gIndices[offs];
......@@ -724,7 +675,7 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
{
process_one_block(block_idx, IsOtherBlock{});
}
buffer_load_lds_indices(1);
buffer_load_lds_indices(1, 1024);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
......@@ -765,8 +716,7 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
float* gMax_logits = reinterpret_cast<float *>(params.max_logits) + row_offset_lse;
float attn_sink_o_scale = 1.0f;
if constexpr (D_QK == 512 && HAVE_TOPK_LENGTH) {
if (params.attn_sink != nullptr) {
if constexpr (USE_ATTN_SINK) {
float rAttn_sink = __ldg((float*)params.attn_sink + bidh * kBlockM + lane_idx % 16 + warp_idx * 16);
if (flash::is_positive_infinity(rAttn_sink)) {
attn_sink_o_scale = 0.0f;
......@@ -776,7 +726,13 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
attn_sink_o_scale = lse_exp2 / (lse_exp2 + rAttn_sink_exp2);
}
}
auto maybe_apply_attn_sink = [&] (float value) -> float {
if constexpr (USE_ATTN_SINK) {
return value * attn_sink_o_scale;
} else {
return value;
}
};
{
// store O and gLSE
......@@ -792,13 +748,13 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
#if defined(__gfx938__)
Bf16_storage res;
col = (lane_idx / 16) * 8 + ni * 32 ;
res.data_32[0] = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[ni * 2][0] * attn_sink_o_scale, 0, acco_f32[ni * 2 + 1][0] * attn_sink_o_scale, 0);
res.data_32[0] = __builtin_hcu_cvt_pk_bf16_f32(0, maybe_apply_attn_sink(acco_f32[ni * 2][0]), 0, maybe_apply_attn_sink(acco_f32[ni * 2 + 1][0]), 0);
res.data_32[1] = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[ni * 2][1] * attn_sink_o_scale, 0, acco_f32[ni * 2 + 1][1] * attn_sink_o_scale, 0);
res.data_32[1] = __builtin_hcu_cvt_pk_bf16_f32(0, maybe_apply_attn_sink(acco_f32[ni * 2][1]), 0, maybe_apply_attn_sink(acco_f32[ni * 2 + 1][1]), 0);
res.data_32[2] = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[ni * 2][2] * attn_sink_o_scale, 0, acco_f32[ni * 2 + 1][2] * attn_sink_o_scale, 0);
res.data_32[2] = __builtin_hcu_cvt_pk_bf16_f32(0, maybe_apply_attn_sink(acco_f32[ni * 2][2]), 0, maybe_apply_attn_sink(acco_f32[ni * 2 + 1][2]), 0);
res.data_32[3] = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[ni * 2][3] * attn_sink_o_scale, 0, acco_f32[ni * 2 + 1][3] * attn_sink_o_scale, 0);
res.data_32[3] = __builtin_hcu_cvt_pk_bf16_f32(0, maybe_apply_attn_sink(acco_f32[ni * 2][3]), 0, maybe_apply_attn_sink(acco_f32[ni * 2 + 1][3]), 0);
*(__fp16x8_t*)(&gO(row, col)) = res.data_128;
......@@ -809,8 +765,8 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
{
result_type res;
Element e0, e1;
e0.storage = float2bf16(acco_f32[ni * 2][ei] * attn_sink_o_scale);
e1.storage = float2bf16(acco_f32[ni * 2 + 1][ei] * attn_sink_o_scale);
e0.storage = float2bf16(maybe_apply_attn_sink(acco_f32[ni * 2][ei]));
e1.storage = float2bf16(maybe_apply_attn_sink(acco_f32[ni * 2 + 1][ei]));
res[0] = e0;
res[1] = e1;
// gO(row, col) = res[0];
......@@ -1372,61 +1328,41 @@ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::run(const SparseAttnFwdParams &para
KU_CHECK_KERNEL_LAUNCH();
}
// Host-side launcher for the KernelTemplate_B_H_64 sparse-attention forward kernel.
// Asserts params.h_kv == 1 and params.topk > 0, then launches
// sparse_attn_fwd_kernel on params.stream with a grid of
// ceil(params.h_q / B_H) x params.s_q blocks, NUM_THREADS threads per block
// and 16384 + 4096 bytes passed as the dynamic shared-memory launch argument.
// NOTE(review): this text is a rendered diff without +/- markers; the
// three-parameter template header, signature and kernel alias below appear to
// be the pre-change lines, each followed by its five-parameter replacement.
template<int D_QK, bool HAVE_TOPK_LENGTH, bool IS_TOPK_2048>
void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::run(const SparseAttnFwdParams &params) {
template<int D_QK, bool HAVE_TOPK_LENGTH, bool IS_TOPK_2048, bool USE_ATTN_SINK, bool CACHE_INDICES_IN_LDS>
void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_ATTN_SINK, CACHE_INDICES_IN_LDS>::run(const SparseAttnFwdParams &params) {
KU_ASSERT(params.h_kv == 1);
// KU_ASSERT(params.topk % (2*B_TOPK) == 0); // To save some boundary checks
KU_ASSERT(params.topk > 0);
// KU_ASSERT(params.h_q % B_H == 0);
// The kernel is instantiated on the full template argument list so that each
// flag combination gets its own specialization.
auto kernel = &sparse_attn_fwd_kernel<KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>>;
auto kernel = &sparse_attn_fwd_kernel<KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_ATTN_SINK, CACHE_INDICES_IN_LDS>>;
constexpr size_t smem_size = 16384 + 4096; // LDS is reused
// One block per group of B_H query heads (rounded up) and per query position.
dim3 grid((params.h_q + B_H - 1) / B_H, params.s_q, 1);
kernel<<<grid, NUM_THREADS, smem_size, params.stream>>>(params);
KU_CHECK_KERNEL_LAUNCH();
}
// NOTE(review): this span is a rendered diff without +/- markers. It appears to
// interleave a removed wrapper class (KernelTemplate_D512_H64_TopkLen_AttnSink)
// with the function template that replaces it (run_h64_fast_path), so the
// braces below belong to two different definitions.
//
// Wrapper class: fixes the kernel to KernelTemplate_B_H_64<512, true, false>,
// forwards devfunc to it, and launches with the same assertions, shared-memory
// size (16384 + 4096) and grid shape (ceil(h_q / 64) x s_q) as the generic
// B_H_64 launcher.
class KernelTemplate_D512_H64_TopkLen_AttnSink {
public:
static constexpr int NUM_THREADS = KernelTemplate_B_H_64<512, true, false>::NUM_THREADS;
static __device__ __forceinline__ void
devfunc(const SparseAttnFwdParams &params) {
KernelTemplate_B_H_64<512, true, false>::devfunc(params);
}
static void run(const SparseAttnFwdParams &params) {
KU_ASSERT(params.h_kv == 1);
KU_ASSERT(params.topk > 0);
auto kernel = &sparse_attn_fwd_kernel<KernelTemplate_D512_H64_TopkLen_AttnSink>;
constexpr size_t smem_size = 16384 + 4096;
dim3 grid((params.h_q + 64 - 1) / 64, params.s_q, 1);
kernel<<<grid, NUM_THREADS, smem_size, params.stream>>>(params);
KU_CHECK_KERNEL_LAUNCH();
// run_h64_fast_path: picks the KernelTemplate_B_H_64 specialization from the
// runtime value of params.topk.
//   topk == 2048 -> IS_TOPK_2048 = true,  CACHE_INDICES_IN_LDS = false
//   topk <= 1024 -> IS_TOPK_2048 = false, CACHE_INDICES_IN_LDS = true
//   otherwise    -> both flags false
// USE_ATTN_SINK is forwarded unchanged from the caller's template argument.
template<int D_QK, bool HAVE_TOPK_LENGTH, bool USE_ATTN_SINK>
static void run_h64_fast_path(const SparseAttnFwdParams& params) {
if (params.topk == 2048) {
KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, true, USE_ATTN_SINK, false>::run(params);
} else if (params.topk <= 1024) {
KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, false, USE_ATTN_SINK, true>::run(params);
} else {
KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, false, USE_ATTN_SINK, false>::run(params);
}
}
};
// Chooses which phase-1 forward kernel to launch for a given
// D_QK / HAVE_TOPK_LENGTH instantiation, based on runtime fields of params
// (h_q, attn_sink, topk).
// NOTE(review): this text is a rendered diff without +/- markers, so the
// superseded dispatch chain and its replacement are interleaved below.
template<int D_QK, bool HAVE_TOPK_LENGTH>
void run_fwd_phase1_kernel(const SparseAttnFwdParams& params) {
// Appears to be the pre-change dispatch: the dedicated attention-sink wrapper
// is used only for D_QK == 512 with a top-k length, h_q == 64 and a non-null
// attn_sink.
if (D_QK == 512 && HAVE_TOPK_LENGTH && params.h_q == 64 && params.attn_sink)
{
KernelTemplate_D512_H64_TopkLen_AttnSink::run(params);
}
// Appears to be the pre-change dispatch: the B_H_64 kernel without attention
// sink is used only for D_QK == 576, no top-k length and h_q == 64, with
// IS_TOPK_2048 set from params.topk.
else if (params.h_q == 64 && !HAVE_TOPK_LENGTH && D_QK == 576 && !params.attn_sink)
{
if (params.topk == 2048)
{
KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, true>::run(params);
}
else
{
KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, false>::run(params);
// Appears to be the post-change dispatch: every h_q == 64 call goes through
// run_h64_fast_path, with USE_ATTN_SINK chosen by whether params.attn_sink is
// non-null.
// NOTE(review): devfunc computes attn_sink_o_scale only inside
// `if constexpr (D_QK == 512 && HAVE_TOPK_LENGTH)`; for any other instantiation
// a non-null attn_sink would leave the scale at 1.0f. Confirm that such
// combinations are unreachable or intended, since the earlier conditions
// routed them to the generic kernel instead.
if (params.h_q == 64) {
if (params.attn_sink) {
run_h64_fast_path<D_QK, HAVE_TOPK_LENGTH, true>(params);
} else {
run_h64_fast_path<D_QK, HAVE_TOPK_LENGTH, false>(params);
}
return;
}
// Generic kernel for every case not handled above.
else
{
KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::run(params);
}
}
}
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment