Commit 4e0bdf6e authored by shenzhe's avatar shenzhe Committed by zhanghj2
Browse files

Support no-split sparse decode for large batch

parent 97ab7511
......@@ -289,6 +289,17 @@ sparse_attn_decode_interface(
}
DecodeImplMeta impl_meta = impl->get_meta(h_q, s_q);
bool force_no_split_kv = false;
if (const char* val = std::getenv("FLASH_MLA_SPARSE_DECODE_DISABLE_SPLITKV")) {
force_no_split_kv = (std::string(val) == "1");
}
constexpr int max_sched_meta_smem_size = 64 * 1024;
bool sched_meta_smem_overflow = sizeof(int) * (static_cast<int64_t>(b) * 5 + 1) > max_sched_meta_smem_size;
bool use_no_split_kv = force_no_split_kv || sched_meta_smem_overflow;
if (use_no_split_kv) {
constexpr int max_grid_z = 65535;
impl_meta.num_sm_parts = std::min(b, max_grid_z);
}
SparseAttnDecodeParams params = {
b, s_q, h_q, h_kv, d_qk, d_v,
......@@ -344,8 +355,12 @@ sparse_attn_decode_interface(
impl_meta.num_sm_parts,
at::cuda::getCurrentCUDAStream().stream()
};
if (use_no_split_kv) {
gfx9::decode::run_get_decoding_sched_meta_no_split_kernel(get_sched_meta_params);
} else {
gfx9::decode::run_get_decoding_sched_meta_kernel(get_sched_meta_params);
}
}
// Stick the metadata pointers to `params`
KU_CHECK_DEVICE(tile_scheduler_metadata);
KU_CHECK_DEVICE(num_splits);
......@@ -359,6 +374,7 @@ sparse_attn_decode_interface(
params.num_splits_ptr = num_splits->data_ptr<int>();
params.num_sm_parts = impl_meta.num_sm_parts;
if (!use_no_split_kv) {
// Allocate intermediate buffers for split-KV
const int total_num_splits = b + impl_meta.num_sm_parts;
lse_accum = torch::empty({total_num_splits, s_q, h_q}, opts.dtype(at::kFloat));
......@@ -372,9 +388,19 @@ sparse_attn_decode_interface(
params.stride_o_accum_split = int64_stride_to_int(o_accum.stride(0));
params.stride_o_accum_s_q = int64_stride_to_int(o_accum.stride(1));
params.stride_o_accum_h_q = int64_stride_to_int(o_accum.stride(2));
} else {
params.lse_accum = nullptr;
params.o_accum = nullptr;
params.stride_lse_accum_split = 0;
params.stride_lse_accum_s_q = 0;
params.stride_o_accum_split = 0;
params.stride_o_accum_s_q = 0;
params.stride_o_accum_h_q = 0;
}
impl->run(params, features);
if (!use_no_split_kv) {
CombineParams combine_params = {
b, s_q, h_q, d_v,
......@@ -396,6 +422,7 @@ sparse_attn_decode_interface(
at::cuda::getCurrentCUDAStream().stream()
};
gfx9::decode::run_flash_mla_combine_kernel<bf16>(combine_params);
}
delete impl;
......
#include "get_decoding_sched_meta.h"
#include <algorithm>
#include <cuda_runtime_api.h>
#include <cutlass/fast_math.h>
#include <kerutils/kerutils.cuh>
......@@ -105,10 +106,64 @@ get_mla_metadata_kernel(const GetDecodeSchedMetaParams params) {
}
}
__global__ void __launch_bounds__(256, 1)
get_mla_metadata_no_split_kernel(const GetDecodeSchedMetaParams params) {
DecodingSchedMeta *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr;
int *num_splits_ptr = params.num_splits_ptr;
int batch_size = params.b;
int block_size_n = params.block_size_n;
int num_sm_parts = params.num_sm_parts;
for (int part_idx = blockIdx.x * blockDim.x + threadIdx.x;
part_idx < num_sm_parts;
part_idx += blockDim.x * gridDim.x) {
int begin_req_idx = (static_cast<int64_t>(part_idx) * batch_size) / num_sm_parts;
int end_req_idx_exclusive = (static_cast<int64_t>(part_idx + 1) * batch_size) / num_sm_parts;
DecodingSchedMeta cur_meta;
cur_meta.begin_req_idx = begin_req_idx;
cur_meta.end_req_idx = end_req_idx_exclusive - 1;
cur_meta.begin_block_idx = 0;
cur_meta.begin_split_idx = 0;
cur_meta.is_first_req_splitted = 0;
cur_meta.is_last_req_splitted = 0;
cur_meta._pad[0] = 0;
int cur_s_k = 0;
if (begin_req_idx < end_req_idx_exclusive) {
if (params.topk == -1) {
cur_s_k = __ldg(params.seqlens_k_ptr + cur_meta.end_req_idx);
} else {
cur_s_k = params.topk_length ? __ldg(params.topk_length + cur_meta.end_req_idx) : params.topk;
if (cur_s_k == 0) cur_s_k = 1;
if (params.extra_topk) {
cur_s_k = ku::ceil(cur_s_k, block_size_n);
cur_s_k += params.extra_topk_length ? __ldg(params.extra_topk_length + cur_meta.end_req_idx) : params.extra_topk;
}
}
}
cur_meta.end_block_idx = cutlass::ceil_div(cur_s_k, block_size_n);
tile_scheduler_metadata_ptr[part_idx] = cur_meta;
}
for (int i = blockIdx.x * blockDim.x + threadIdx.x;
i <= batch_size;
i += blockDim.x * gridDim.x) {
num_splits_ptr[i] = i;
}
}
void run_get_decoding_sched_meta_kernel(GetDecodeSchedMetaParams &params) {
int smem_size = sizeof(int) * (params.b*5+1);
int smem_size = sizeof(int) * (static_cast<int64_t>(params.b) * 5 + 1);
get_mla_metadata_kernel<<<1, 64, smem_size, params.stream>>>(params);
CHECK_CUDA_KERNEL_LAUNCH();
}
void run_get_decoding_sched_meta_no_split_kernel(GetDecodeSchedMetaParams &params) {
int grid = cutlass::ceil_div(std::max(params.num_sm_parts, params.b + 1), 256);
grid = std::min(grid, 1024);
get_mla_metadata_no_split_kernel<<<grid, 256, 0, params.stream>>>(params);
CHECK_CUDA_KERNEL_LAUNCH();
}
}
......@@ -5,5 +5,6 @@
namespace gfx9::decode {
void run_get_decoding_sched_meta_kernel(GetDecodeSchedMetaParams &params);
void run_get_decoding_sched_meta_no_split_kernel(GetDecodeSchedMetaParams &params);
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment