Commit 38421051 authored by zhanghj2
Browse files

减少lds使用, 提高并行度

parent 6d68e3d1
...@@ -725,7 +725,7 @@ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::run(const SparseAttnDecodeParams &pa ...@@ -725,7 +725,7 @@ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::run(const SparseAttnDecodeParams &pa
KU_ASSERT(params.stride_kv_row == 656); // number of bytes per token (512 fp8 + 4 float32 + 64 bfloat16) KU_ASSERT(params.stride_kv_row == 656); // number of bytes per token (512 fp8 + 4 float32 + 64 bfloat16)
} }
auto mla_kernel = &flash_fwd_splitkv_mla_fp8_sparse_kernel<KernelTemplate<MODEL_TYPE, NUM_HEADS>>; auto mla_kernel = &flash_fwd_splitkv_mla_fp8_sparse_kernel<KernelTemplate<MODEL_TYPE, NUM_HEADS>>;
constexpr size_t smem_size = sizeof(SharedMemoryPlan); constexpr size_t smem_size = 32768; // lds复用
// zhj debug // zhj debug
// printf("NUM_M_BLOCKS = %d smem_size = %d \n",NUM_M_BLOCKS, smem_size); // printf("NUM_M_BLOCKS = %d smem_size = %d \n",NUM_M_BLOCKS, smem_size);
mla_kernel<<<dim3(NUM_M_BLOCKS, params.s_q, params.num_sm_parts), NUM_THREADS, smem_size, params.stream>>>(params); mla_kernel<<<dim3(NUM_M_BLOCKS, params.s_q, params.num_sm_parts), NUM_THREADS, smem_size, params.stream>>>(params);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment