Commit a8393a04 authored by zhanghj2's avatar zhanghj2
Browse files

支持nhead<16

parent 945ced44
...@@ -137,7 +137,7 @@ static std::vector<at::Tensor> sparse_attn_prefill_interface( ...@@ -137,7 +137,7 @@ static std::vector<at::Tensor> sparse_attn_prefill_interface(
}; };
std::vector<FwdFeatures> required_features; std::vector<FwdFeatures> required_features;
if (h_q == 16) { if (h_q <= 16) {
required_features.push_back(FwdFeatures::HEAD_16); required_features.push_back(FwdFeatures::HEAD_16);
} else if (h_q == 64) { } else if (h_q == 64) {
required_features.push_back(FwdFeatures::HEAD_64); required_features.push_back(FwdFeatures::HEAD_64);
......
...@@ -934,7 +934,7 @@ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::run(const SparseAttnDecodeParams &pa ...@@ -934,7 +934,7 @@ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::run(const SparseAttnDecodeParams &pa
KU_ASSERT(params.topk % TOPK_BLOCK_SIZE == 0); KU_ASSERT(params.topk % TOPK_BLOCK_SIZE == 0);
KU_ASSERT(params.d_qk == HEAD_DIM_K); KU_ASSERT(params.d_qk == HEAD_DIM_K);
KU_ASSERT(params.d_v == HEAD_DIM_V); KU_ASSERT(params.d_v == HEAD_DIM_V);
KU_ASSERT(params.h_q % BLOCK_M == 0); // KU_ASSERT(params.h_q % BLOCK_M == 0);
if constexpr (MODEL_TYPE == ModelType::MODEL1) { if constexpr (MODEL_TYPE == ModelType::MODEL1) {
constexpr int BYTES_PER_TOKEN = HEAD_DIM_NOPE + 2*HEAD_DIM_ROPE + 8; constexpr int BYTES_PER_TOKEN = HEAD_DIM_NOPE + 2*HEAD_DIM_ROPE + 8;
KU_ASSERT(params.stride_kv_row == BYTES_PER_TOKEN, "Each page block in KV cache must be contiguous for head64 sparse fp8 decoding attention in MODEL1"); // Each block must be contiguous KU_ASSERT(params.stride_kv_row == BYTES_PER_TOKEN, "Each page block in KV cache must be contiguous for head64 sparse fp8 decoding attention in MODEL1"); // Each block must be contiguous
......
...@@ -483,10 +483,10 @@ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::run(const SparseAttnFwdParams &para ...@@ -483,10 +483,10 @@ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::run(const SparseAttnFwdParams &para
KU_ASSERT(params.h_kv == 1); KU_ASSERT(params.h_kv == 1);
KU_ASSERT(params.topk % (2*B_TOPK) == 0); // To save some boundry checkings KU_ASSERT(params.topk % (2*B_TOPK) == 0); // To save some boundry checkings
KU_ASSERT(params.topk > 0); KU_ASSERT(params.topk > 0);
KU_ASSERT(params.h_q % B_H == 0); // KU_ASSERT(params.h_q % B_H == 0);
auto kernel = &sparse_attn_fwd_kernel<KernelTemplate<D_QK, HAVE_TOPK_LENGTH>>; auto kernel = &sparse_attn_fwd_kernel<KernelTemplate<D_QK, HAVE_TOPK_LENGTH>>;
constexpr size_t smem_size = 16384 + 4096; // 做了lds复用 constexpr size_t smem_size = 16384 + 4096; // 做了lds复用
dim3 grid(params.s_q, params.h_q/B_H, 1); dim3 grid(params.s_q, (params.h_q + B_H - 1) / B_H, 1);
kernel<<<grid, NUM_THREADS, smem_size, params.stream>>>(params); kernel<<<grid, NUM_THREADS, smem_size, params.stream>>>(params);
KU_CHECK_KERNEL_LAUNCH(); KU_CHECK_KERNEL_LAUNCH();
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment