Commit 6d68e3d1 authored by zhanghj2
Browse files

减少lds用量

parent 5d62c0d7
...@@ -485,7 +485,7 @@ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::run(const SparseAttnFwdParams &para ...@@ -485,7 +485,7 @@ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::run(const SparseAttnFwdParams &para
KU_ASSERT(params.topk > 0); KU_ASSERT(params.topk > 0);
KU_ASSERT(params.h_q % B_H == 0); KU_ASSERT(params.h_q % B_H == 0);
auto kernel = &sparse_attn_fwd_kernel<KernelTemplate<D_QK, HAVE_TOPK_LENGTH>>; auto kernel = &sparse_attn_fwd_kernel<KernelTemplate<D_QK, HAVE_TOPK_LENGTH>>;
constexpr size_t smem_size = 65536; constexpr size_t smem_size = 16384 + 4096; // 做了lds复用
dim3 grid(params.s_q, params.h_q/B_H, 1); dim3 grid(params.s_q, params.h_q/B_H, 1);
kernel<<<grid, NUM_THREADS, smem_size, params.stream>>>(params); kernel<<<grid, NUM_THREADS, smem_size, params.stream>>>(params);
KU_CHECK_KERNEL_LAUNCH(); KU_CHECK_KERNEL_LAUNCH();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment