Commit 0728420c authored by zhanghj2
Browse files

优化sparse prefill

parent 1cb8a563
...@@ -124,7 +124,7 @@ static void run(const SparseAttnFwdParams &params); ...@@ -124,7 +124,7 @@ static void run(const SparseAttnFwdParams &params);
}; };
template<int D_QK, bool HAVE_TOPK_LENGTH, bool IS_TOPK_2048> template<int D_QK, bool HAVE_TOPK_LENGTH, bool IS_TOPK_2048, bool USE_ATTN_SINK = false, bool CACHE_INDICES_IN_LDS = false>
class KernelTemplate_B_H_64 class KernelTemplate_B_H_64
{ {
public: public:
......
...@@ -10,8 +10,8 @@ namespace gfx93::fwd { ...@@ -10,8 +10,8 @@ namespace gfx93::fwd {
using namespace cute; using namespace cute;
template<int D_QK, bool HAVE_TOPK_LENGTH, bool IS_TOPK_2048> template<int D_QK, bool HAVE_TOPK_LENGTH, bool IS_TOPK_2048, bool USE_ATTN_SINK, bool CACHE_INDICES_IN_LDS>
__device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::devfunc(const SparseAttnFwdParams &params) { __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_ATTN_SINK, CACHE_INDICES_IN_LDS>::devfunc(const SparseAttnFwdParams &params) {
const int tidx = threadIdx.x; const int tidx = threadIdx.x;
static constexpr int kBlockM = B_H; static constexpr int kBlockM = B_H;
...@@ -96,7 +96,7 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev ...@@ -96,7 +96,7 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
#endif #endif
// int row_offset = row + warp_idx * 16 + block_idx * kBlockN; // int row_offset = row + warp_idx * 16 + block_idx * kBlockN;
if constexpr (IS_TOPK_2048) { if constexpr (IS_TOPK_2048 || CACHE_INDICES_IN_LDS) {
row_offset = sIndices[row_offset % 1024]; row_offset = sIndices[row_offset % 1024];
} else { } else {
row_offset = gIndices[row_offset]; row_offset = gIndices[row_offset];
...@@ -109,7 +109,7 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev ...@@ -109,7 +109,7 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
// int col = lane_idx % 4; // int col = lane_idx % 4;
int row_offset = row + i * 16 + block_idx * kBlockN;; int row_offset = row + i * 16 + block_idx * kBlockN;;
// int col_offset = col * 8 + warp_idx * 32; // int col_offset = col * 8 + warp_idx * 32;
if constexpr (IS_TOPK_2048) { if constexpr (IS_TOPK_2048 || CACHE_INDICES_IN_LDS) {
row_offset = sIndices[row_offset % 1024]; row_offset = sIndices[row_offset % 1024];
} else { } else {
row_offset = gIndices[row_offset]; row_offset = gIndices[row_offset];
...@@ -132,8 +132,8 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev ...@@ -132,8 +132,8 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
global_addr_q[2] = 64; global_addr_q[2] = 64;
global_addr_q[3] = 0x00020000; global_addr_q[3] = 0x00020000;
auto buffer_load_lds_indices = [&] (int n) { auto buffer_load_lds_indices = [&] (int n, int num_indices) {
if constexpr (IS_TOPK_2048) { if constexpr (IS_TOPK_2048 || CACHE_INDICES_IN_LDS) {
PtrWrapper glob_ptr_indices; PtrWrapper glob_ptr_indices;
*(uint64_t*)&glob_ptr_indices = reinterpret_cast<uint64_t>(gIndices); *(uint64_t*)&glob_ptr_indices = reinterpret_cast<uint64_t>(gIndices);
glob_ptr_indices.latter |= 0x40000000; glob_ptr_indices.latter |= 0x40000000;
...@@ -146,6 +146,8 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev ...@@ -146,6 +146,8 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
int ldsAddrPerWave = reinterpret_cast<size_t>(sIndices) + warp_idx * 64 * 4 * 4; int ldsAddrPerWave = reinterpret_cast<size_t>(sIndices) + warp_idx * 64 * 4 * 4;
const int offset_v = lane_idx * 4 * 4 + warp_idx * 64 * 4 * 4; const int offset_v = lane_idx * 4 * 4 + warp_idx * 64 * 4 * 4;
const int offset_s = n * 1024 * 4; const int offset_s = n * 1024 * 4;
const int first_index = warp_idx * 256 + lane_idx * 4;
if (first_index < num_indices) {
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
asm volatile( asm volatile(
"s_mov_b32 m0, %1 \n\t" "s_mov_b32 m0, %1 \n\t"
...@@ -155,9 +157,10 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev ...@@ -155,9 +157,10 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
:); :);
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
} }
}
}; };
if constexpr (IS_TOPK_2048) { if constexpr (IS_TOPK_2048 || CACHE_INDICES_IN_LDS) {
buffer_load_lds_indices(0); buffer_load_lds_indices(0, IS_TOPK_2048 ? 1024 : params.topk);
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
...@@ -323,71 +326,20 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev ...@@ -323,71 +326,20 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
auto [row_offset, col] = calc_row_and_col_k(block_idx); auto [row_offset, col] = calc_row_and_col_k(block_idx);
row_offset = row_offset == -1 ? params.s_kv : row_offset; row_offset = row_offset == -1 ? params.s_kv : row_offset;
#if 1 #if 1
if constexpr (D_QK == 512) {
#define LOAD_K_AND_QK_GEMM_512(k) \
{ \
constexpr int k_val = (k); \
buffer_load_lds_k(row_offset, col, k_val - 3); \
flash::qk_gemm<Element, k_val>(q_reg[k_val].data_128, k_lds_read_ptr, accs_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
}
constexpr int k_val = 15;
buffer_load_lds_k(row_offset, col, k_val);
buffer_load_lds_k(row_offset, col, k_val - 1);
buffer_load_lds_k(row_offset, col, k_val - 2);
buffer_load_lds_k(row_offset, col, k_val - 3);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
flash::qk_gemm<Element, k_val>(q_reg[k_val].data_128, k_lds_read_ptr, accs_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
LOAD_K_AND_QK_GEMM_512(14);
LOAD_K_AND_QK_GEMM_512(13);
LOAD_K_AND_QK_GEMM_512(12);
LOAD_K_AND_QK_GEMM_512(11);
LOAD_K_AND_QK_GEMM_512(10);
LOAD_K_AND_QK_GEMM_512(9);
LOAD_K_AND_QK_GEMM_512(8);
LOAD_K_AND_QK_GEMM_512(7);
LOAD_K_AND_QK_GEMM_512(6);
LOAD_K_AND_QK_GEMM_512(5);
LOAD_K_AND_QK_GEMM_512(4);
LOAD_K_AND_QK_GEMM_512(3);
flash::qk_gemm<Element, 2>(q_reg[2].data_128, k_lds_read_ptr, accs_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
flash::qk_gemm<Element, 1>(q_reg[1].data_128, k_lds_read_ptr, accs_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
flash::qk_gemm<Element, 0>(q_reg[0].data_128, k_lds_read_ptr, accs_f32);
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0);
#undef LOAD_K_AND_QK_GEMM_512
} else {
#define LOAD_K_AND_QK_GEMM(k) \ #define LOAD_K_AND_QK_GEMM(k) \
{ \ { \
constexpr int k_val = (k); \ constexpr int k_val = (k); \
if constexpr (k_val < kQkChunks - 1) { \
buffer_load_lds_k(row_offset, col, k_val - 3); \ buffer_load_lds_k(row_offset, col, k_val - 3); \
flash::qk_gemm<Element, k_val>(q_reg[k_val].data_128, k_lds_read_ptr, accs_f32); \ flash::qk_gemm<Element, k_val>(q_reg[k_val].data_128, k_lds_read_ptr, accs_f32); \
__builtin_amdgcn_sched_barrier(0); \ __builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \ asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \ __builtin_amdgcn_sched_barrier(0); \
} \
} }
{ {
constexpr int k_val = (17); constexpr int k_val = kQkChunks - 1;
buffer_load_lds_k(row_offset, col, k_val); buffer_load_lds_k(row_offset, col, k_val);
buffer_load_lds_k(row_offset, col, k_val - 1); buffer_load_lds_k(row_offset, col, k_val - 1);
buffer_load_lds_k(row_offset, col, k_val - 2); buffer_load_lds_k(row_offset, col, k_val - 2);
...@@ -415,23 +367,22 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev ...@@ -415,23 +367,22 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
LOAD_K_AND_QK_GEMM(4); LOAD_K_AND_QK_GEMM(4);
LOAD_K_AND_QK_GEMM(3); LOAD_K_AND_QK_GEMM(3);
flash::qk_gemm<Element, k_val - 15>(q_reg[k_val - 15].data_128, k_lds_read_ptr, accs_f32); flash::qk_gemm<Element, 2>(q_reg[2].data_128, k_lds_read_ptr, accs_f32);
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
flash::qk_gemm<Element, k_val - 16>(q_reg[k_val - 16].data_128, k_lds_read_ptr, accs_f32); flash::qk_gemm<Element, 1>(q_reg[1].data_128, k_lds_read_ptr, accs_f32);
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
flash::qk_gemm<Element, k_val - 17>(q_reg[k_val - 17].data_128, k_lds_read_ptr, accs_f32); flash::qk_gemm<Element, 0>(q_reg[0].data_128, k_lds_read_ptr, accs_f32);
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
asm volatile("s_barrier\n\t"); asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
} }
#undef LOAD_K_AND_QK_GEMM #undef LOAD_K_AND_QK_GEMM
}
#else #else
#define LOAD_K_AND_QK_GEMM(k) \ #define LOAD_K_AND_QK_GEMM(k) \
{ \ { \
...@@ -495,7 +446,7 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev ...@@ -495,7 +446,7 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
const int n_idx = (lane_idx / 16) * 4 + (idx % 4) + (idx / 4) * 16; const int n_idx = (lane_idx / 16) * 4 + (idx % 4) + (idx / 4) * 16;
int offs = n_idx + block_idx * kBlockN; int offs = n_idx + block_idx * kBlockN;
int t; int t;
if constexpr (IS_TOPK_2048) { if constexpr (IS_TOPK_2048 || CACHE_INDICES_IN_LDS) {
t = sIndices[offs % 1024]; t = sIndices[offs % 1024];
} else { } else {
t = gIndices[offs]; t = gIndices[offs];
...@@ -724,7 +675,7 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev ...@@ -724,7 +675,7 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
{ {
process_one_block(block_idx, IsOtherBlock{}); process_one_block(block_idx, IsOtherBlock{});
} }
buffer_load_lds_indices(1); buffer_load_lds_indices(1, 1024);
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
...@@ -765,8 +716,7 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev ...@@ -765,8 +716,7 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
float* gMax_logits = reinterpret_cast<float *>(params.max_logits) + row_offset_lse; float* gMax_logits = reinterpret_cast<float *>(params.max_logits) + row_offset_lse;
float attn_sink_o_scale = 1.0f; float attn_sink_o_scale = 1.0f;
if constexpr (D_QK == 512 && HAVE_TOPK_LENGTH) { if constexpr (USE_ATTN_SINK) {
if (params.attn_sink != nullptr) {
float rAttn_sink = __ldg((float*)params.attn_sink + bidh * kBlockM + lane_idx % 16 + warp_idx * 16); float rAttn_sink = __ldg((float*)params.attn_sink + bidh * kBlockM + lane_idx % 16 + warp_idx * 16);
if (flash::is_positive_infinity(rAttn_sink)) { if (flash::is_positive_infinity(rAttn_sink)) {
attn_sink_o_scale = 0.0f; attn_sink_o_scale = 0.0f;
...@@ -776,7 +726,13 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev ...@@ -776,7 +726,13 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
attn_sink_o_scale = lse_exp2 / (lse_exp2 + rAttn_sink_exp2); attn_sink_o_scale = lse_exp2 / (lse_exp2 + rAttn_sink_exp2);
} }
} }
auto maybe_apply_attn_sink = [&] (float value) -> float {
if constexpr (USE_ATTN_SINK) {
return value * attn_sink_o_scale;
} else {
return value;
} }
};
{ {
// store O and gLSE // store O and gLSE
...@@ -792,13 +748,13 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev ...@@ -792,13 +748,13 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
#if defined(__gfx938__) #if defined(__gfx938__)
Bf16_storage res; Bf16_storage res;
col = (lane_idx / 16) * 8 + ni * 32 ; col = (lane_idx / 16) * 8 + ni * 32 ;
res.data_32[0] = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[ni * 2][0] * attn_sink_o_scale, 0, acco_f32[ni * 2 + 1][0] * attn_sink_o_scale, 0); res.data_32[0] = __builtin_hcu_cvt_pk_bf16_f32(0, maybe_apply_attn_sink(acco_f32[ni * 2][0]), 0, maybe_apply_attn_sink(acco_f32[ni * 2 + 1][0]), 0);
res.data_32[1] = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[ni * 2][1] * attn_sink_o_scale, 0, acco_f32[ni * 2 + 1][1] * attn_sink_o_scale, 0); res.data_32[1] = __builtin_hcu_cvt_pk_bf16_f32(0, maybe_apply_attn_sink(acco_f32[ni * 2][1]), 0, maybe_apply_attn_sink(acco_f32[ni * 2 + 1][1]), 0);
res.data_32[2] = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[ni * 2][2] * attn_sink_o_scale, 0, acco_f32[ni * 2 + 1][2] * attn_sink_o_scale, 0); res.data_32[2] = __builtin_hcu_cvt_pk_bf16_f32(0, maybe_apply_attn_sink(acco_f32[ni * 2][2]), 0, maybe_apply_attn_sink(acco_f32[ni * 2 + 1][2]), 0);
res.data_32[3] = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[ni * 2][3] * attn_sink_o_scale, 0, acco_f32[ni * 2 + 1][3] * attn_sink_o_scale, 0); res.data_32[3] = __builtin_hcu_cvt_pk_bf16_f32(0, maybe_apply_attn_sink(acco_f32[ni * 2][3]), 0, maybe_apply_attn_sink(acco_f32[ni * 2 + 1][3]), 0);
*(__fp16x8_t*)(&gO(row, col)) = res.data_128; *(__fp16x8_t*)(&gO(row, col)) = res.data_128;
...@@ -809,8 +765,8 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev ...@@ -809,8 +765,8 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
{ {
result_type res; result_type res;
Element e0, e1; Element e0, e1;
e0.storage = float2bf16(acco_f32[ni * 2][ei] * attn_sink_o_scale); e0.storage = float2bf16(maybe_apply_attn_sink(acco_f32[ni * 2][ei]));
e1.storage = float2bf16(acco_f32[ni * 2 + 1][ei] * attn_sink_o_scale); e1.storage = float2bf16(maybe_apply_attn_sink(acco_f32[ni * 2 + 1][ei]));
res[0] = e0; res[0] = e0;
res[1] = e1; res[1] = e1;
// gO(row, col) = res[0]; // gO(row, col) = res[0];
...@@ -1372,61 +1328,41 @@ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::run(const SparseAttnFwdParams &para ...@@ -1372,61 +1328,41 @@ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::run(const SparseAttnFwdParams &para
KU_CHECK_KERNEL_LAUNCH(); KU_CHECK_KERNEL_LAUNCH();
} }
// KernelTemplate_B_H_64::run -- host-side launcher for the B_H_64 kernel variant.
// (Each line of this rendering carries the pre-change text followed by the post-change text;
// the change here only threads the two new template flags, USE_ATTN_SINK and
// CACHE_INDICES_IN_LDS, through the signature and the kernel instantiation.)
// Steps visible below:
//   1. assert params.h_kv == 1 and params.topk > 0;
//   2. take the address of sparse_attn_fwd_kernel instantiated on this class;
//   3. launch with grid = (ceil(h_q / B_H), s_q, 1), NUM_THREADS threads per block,
//      16384 + 4096 bytes of dynamic shared memory, on params.stream;
//   4. check the launch via KU_CHECK_KERNEL_LAUNCH().
template<int D_QK, bool HAVE_TOPK_LENGTH, bool IS_TOPK_2048> template<int D_QK, bool HAVE_TOPK_LENGTH, bool IS_TOPK_2048, bool USE_ATTN_SINK, bool CACHE_INDICES_IN_LDS>
void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::run(const SparseAttnFwdParams &params) { void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_ATTN_SINK, CACHE_INDICES_IN_LDS>::run(const SparseAttnFwdParams &params) {
KU_ASSERT(params.h_kv == 1); KU_ASSERT(params.h_kv == 1);
// KU_ASSERT(params.topk % (2*B_TOPK) == 0); // To save some boundary checks // KU_ASSERT(params.topk % (2*B_TOPK) == 0); // To save some boundary checks
KU_ASSERT(params.topk > 0); KU_ASSERT(params.topk > 0);
// KU_ASSERT(params.h_q % B_H == 0); // KU_ASSERT(params.h_q % B_H == 0);
auto kernel = &sparse_attn_fwd_kernel<KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>>; auto kernel = &sparse_attn_fwd_kernel<KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048, USE_ATTN_SINK, CACHE_INDICES_IN_LDS>>;
constexpr size_t smem_size = 16384 + 4096; // LDS is reused constexpr size_t smem_size = 16384 + 4096; // LDS is reused
dim3 grid((params.h_q + B_H - 1) / B_H, params.s_q, 1); dim3 grid((params.h_q + B_H - 1) / B_H, params.s_q, 1);
kernel<<<grid, NUM_THREADS, smem_size, params.stream>>>(params); kernel<<<grid, NUM_THREADS, smem_size, params.stream>>>(params);
KU_CHECK_KERNEL_LAUNCH(); KU_CHECK_KERNEL_LAUNCH();
} }
// Two definitions are interleaved on the lines below (pre-change text first, post-change text second).
//
// Pre-change (appears to be deleted by this commit): class KernelTemplate_D512_H64_TopkLen_AttnSink.
//   Its devfunc forwards to KernelTemplate_B_H_64<512, true, false>::devfunc, and its run()
//   asserts h_kv == 1 and topk > 0 and launches sparse_attn_fwd_kernel with
//   grid = (ceil(h_q / 64), s_q, 1) and 16384 + 4096 bytes of dynamic shared memory.
//
// Post-change: run_h64_fast_path<D_QK, HAVE_TOPK_LENGTH, USE_ATTN_SINK>, a host-side selector that
// picks the KernelTemplate_B_H_64 instantiation from params.topk:
//   topk == 2048 -> IS_TOPK_2048 = true,  CACHE_INDICES_IN_LDS = false
//   topk <= 1024 -> IS_TOPK_2048 = false, CACHE_INDICES_IN_LDS = true
//   otherwise    -> IS_TOPK_2048 = false, CACHE_INDICES_IN_LDS = false
// NOTE(review): the 1024 cut-off presumably matches the `sIndices[... % 1024]` indexing used in
// devfunc -- confirm against the sIndices allocation, which is not visible here.
class KernelTemplate_D512_H64_TopkLen_AttnSink { template<int D_QK, bool HAVE_TOPK_LENGTH, bool USE_ATTN_SINK>
public: static void run_h64_fast_path(const SparseAttnFwdParams& params) {
static constexpr int NUM_THREADS = KernelTemplate_B_H_64<512, true, false>::NUM_THREADS; if (params.topk == 2048) {
KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, true, USE_ATTN_SINK, false>::run(params);
static __device__ __forceinline__ void } else if (params.topk <= 1024) {
devfunc(const SparseAttnFwdParams &params) { KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, false, USE_ATTN_SINK, true>::run(params);
KernelTemplate_B_H_64<512, true, false>::devfunc(params); } else {
} KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, false, USE_ATTN_SINK, false>::run(params);
}
static void run(const SparseAttnFwdParams &params) {
KU_ASSERT(params.h_kv == 1);
KU_ASSERT(params.topk > 0);
auto kernel = &sparse_attn_fwd_kernel<KernelTemplate_D512_H64_TopkLen_AttnSink>;
constexpr size_t smem_size = 16384 + 4096;
dim3 grid((params.h_q + 64 - 1) / 64, params.s_q, 1);
kernel<<<grid, NUM_THREADS, smem_size, params.stream>>>(params);
KU_CHECK_KERNEL_LAUNCH();
} }
};
// run_fwd_phase1_kernel -- host-side dispatch between the H64 kernels and the generic KernelTemplate.
// (Pre-change text first, post-change text second on each line.)
//
// Pre-change conditions visible below:
//   D_QK == 512 && HAVE_TOPK_LENGTH && h_q == 64 && attn_sink   -> KernelTemplate_D512_H64_TopkLen_AttnSink::run
//   h_q == 64 && !HAVE_TOPK_LENGTH && D_QK == 576 && !attn_sink -> KernelTemplate_B_H_64<..., topk == 2048>::run
//   anything else                                               -> KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::run
//
// Post-change, as far as the two-column rendering can be untangled:
//   h_q == 64 -> run_h64_fast_path, with USE_ATTN_SINK = true when params.attn_sink is non-null
//                and false otherwise, then return;
//   otherwise -> KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::run.
//
// NOTE(review): the new h_q == 64 test no longer looks at D_QK or HAVE_TOPK_LENGTH, so combinations
// that previously used the generic kernel (e.g. D_QK == 576 with attn_sink set) now reach
// KernelTemplate_B_H_64 -- confirm every such instantiation is supported and tested.
template<int D_QK, bool HAVE_TOPK_LENGTH> template<int D_QK, bool HAVE_TOPK_LENGTH>
void run_fwd_phase1_kernel(const SparseAttnFwdParams& params) { void run_fwd_phase1_kernel(const SparseAttnFwdParams& params) {
if (D_QK == 512 && HAVE_TOPK_LENGTH && params.h_q == 64 && params.attn_sink) if (params.h_q == 64) {
{ if (params.attn_sink) {
KernelTemplate_D512_H64_TopkLen_AttnSink::run(params); run_h64_fast_path<D_QK, HAVE_TOPK_LENGTH, true>(params);
} } else {
else if (params.h_q == 64 && !HAVE_TOPK_LENGTH && D_QK == 576 && !params.attn_sink) run_h64_fast_path<D_QK, HAVE_TOPK_LENGTH, false>(params);
{
if (params.topk == 2048)
{
KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, true>::run(params);
}
else
{
KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, false>::run(params);
}
return;
}
else
{
KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::run(params);
}
} }
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment