Commit 4e0bdf6e, authored by shenzhe, committed by zhanghj2
Browse files

Support no-split sparse decode for large batch

parent 97ab7511
...@@ -289,6 +289,17 @@ sparse_attn_decode_interface( ...@@ -289,6 +289,17 @@ sparse_attn_decode_interface(
} }
DecodeImplMeta impl_meta = impl->get_meta(h_q, s_q); DecodeImplMeta impl_meta = impl->get_meta(h_q, s_q);
bool force_no_split_kv = false;
if (const char* val = std::getenv("FLASH_MLA_SPARSE_DECODE_DISABLE_SPLITKV")) {
force_no_split_kv = (std::string(val) == "1");
}
constexpr int max_sched_meta_smem_size = 64 * 1024;
bool sched_meta_smem_overflow = sizeof(int) * (static_cast<int64_t>(b) * 5 + 1) > max_sched_meta_smem_size;
bool use_no_split_kv = force_no_split_kv || sched_meta_smem_overflow;
if (use_no_split_kv) {
constexpr int max_grid_z = 65535;
impl_meta.num_sm_parts = std::min(b, max_grid_z);
}
SparseAttnDecodeParams params = { SparseAttnDecodeParams params = {
b, s_q, h_q, h_kv, d_qk, d_v, b, s_q, h_q, h_kv, d_qk, d_v,
...@@ -344,7 +355,11 @@ sparse_attn_decode_interface( ...@@ -344,7 +355,11 @@ sparse_attn_decode_interface(
impl_meta.num_sm_parts, impl_meta.num_sm_parts,
at::cuda::getCurrentCUDAStream().stream() at::cuda::getCurrentCUDAStream().stream()
}; };
gfx9::decode::run_get_decoding_sched_meta_kernel(get_sched_meta_params); if (use_no_split_kv) {
gfx9::decode::run_get_decoding_sched_meta_no_split_kernel(get_sched_meta_params);
} else {
gfx9::decode::run_get_decoding_sched_meta_kernel(get_sched_meta_params);
}
} }
// Stick the metadata pointers to `params` // Stick the metadata pointers to `params`
KU_CHECK_DEVICE(tile_scheduler_metadata); KU_CHECK_DEVICE(tile_scheduler_metadata);
...@@ -359,43 +374,55 @@ sparse_attn_decode_interface( ...@@ -359,43 +374,55 @@ sparse_attn_decode_interface(
params.num_splits_ptr = num_splits->data_ptr<int>(); params.num_splits_ptr = num_splits->data_ptr<int>();
params.num_sm_parts = impl_meta.num_sm_parts; params.num_sm_parts = impl_meta.num_sm_parts;
// Allocate intermediate buffers for split-KV if (!use_no_split_kv) {
const int total_num_splits = b + impl_meta.num_sm_parts; // Allocate intermediate buffers for split-KV
lse_accum = torch::empty({total_num_splits, s_q, h_q}, opts.dtype(at::kFloat)); const int total_num_splits = b + impl_meta.num_sm_parts;
o_accum = torch::empty({total_num_splits, s_q, h_q, d_v}, opts.dtype(at::kFloat)); lse_accum = torch::empty({total_num_splits, s_q, h_q}, opts.dtype(at::kFloat));
KU_CHECK_CONTIGUOUS(lse_accum); o_accum = torch::empty({total_num_splits, s_q, h_q, d_v}, opts.dtype(at::kFloat));
KU_CHECK_CONTIGUOUS(o_accum); KU_CHECK_CONTIGUOUS(lse_accum);
params.lse_accum = lse_accum.data_ptr<float>(); KU_CHECK_CONTIGUOUS(o_accum);
params.o_accum = o_accum.data_ptr<float>(); params.lse_accum = lse_accum.data_ptr<float>();
params.stride_lse_accum_split = int64_stride_to_int(lse_accum.stride(0)); params.o_accum = o_accum.data_ptr<float>();
params.stride_lse_accum_s_q = int64_stride_to_int(lse_accum.stride(1)); params.stride_lse_accum_split = int64_stride_to_int(lse_accum.stride(0));
params.stride_o_accum_split = int64_stride_to_int(o_accum.stride(0)); params.stride_lse_accum_s_q = int64_stride_to_int(lse_accum.stride(1));
params.stride_o_accum_s_q = int64_stride_to_int(o_accum.stride(1)); params.stride_o_accum_split = int64_stride_to_int(o_accum.stride(0));
params.stride_o_accum_h_q = int64_stride_to_int(o_accum.stride(2)); params.stride_o_accum_s_q = int64_stride_to_int(o_accum.stride(1));
params.stride_o_accum_h_q = int64_stride_to_int(o_accum.stride(2));
} else {
params.lse_accum = nullptr;
params.o_accum = nullptr;
params.stride_lse_accum_split = 0;
params.stride_lse_accum_s_q = 0;
params.stride_o_accum_split = 0;
params.stride_o_accum_s_q = 0;
params.stride_o_accum_h_q = 0;
}
impl->run(params, features); impl->run(params, features);
CombineParams combine_params = { if (!use_no_split_kv) {
b, s_q, h_q, d_v, CombineParams combine_params = {
b, s_q, h_q, d_v,
params.lse, params.lse,
params.out, params.out,
params.stride_lse_b, params.stride_lse_s_q, params.stride_lse_b, params.stride_lse_s_q,
params.stride_o_b, params.stride_o_s_q, params.stride_o_h_q, params.stride_o_b, params.stride_o_s_q, params.stride_o_h_q,
params.lse_accum, params.lse_accum,
params.o_accum, params.o_accum,
params.stride_lse_accum_split, params.stride_lse_accum_s_q, params.stride_lse_accum_split, params.stride_lse_accum_s_q,
params.stride_o_accum_split, params.stride_o_accum_s_q, params.stride_o_accum_h_q, params.stride_o_accum_split, params.stride_o_accum_s_q, params.stride_o_accum_h_q,
params.tile_scheduler_metadata_ptr, params.tile_scheduler_metadata_ptr,
params.num_splits_ptr, params.num_splits_ptr,
params.num_sm_parts, params.num_sm_parts,
ku::get_optional_tensor_ptr<float>(attn_sink), ku::get_optional_tensor_ptr<float>(attn_sink),
at::cuda::getCurrentCUDAStream().stream() at::cuda::getCurrentCUDAStream().stream()
}; };
gfx9::decode::run_flash_mla_combine_kernel<bf16>(combine_params); gfx9::decode::run_flash_mla_combine_kernel<bf16>(combine_params);
}
delete impl; delete impl;
......
#include "get_decoding_sched_meta.h" #include "get_decoding_sched_meta.h"
#include <algorithm>
#include <cuda_runtime_api.h> #include <cuda_runtime_api.h>
#include <cutlass/fast_math.h> #include <cutlass/fast_math.h>
#include <kerutils/kerutils.cuh> #include <kerutils/kerutils.cuh>
...@@ -105,10 +106,64 @@ get_mla_metadata_kernel(const GetDecodeSchedMetaParams params) { ...@@ -105,10 +106,64 @@ get_mla_metadata_kernel(const GetDecodeSchedMetaParams params) {
} }
} }
// Fills the decode scheduling metadata for the "no split-KV" path.
//
// Every thread walks the partition indices [0, params.num_sm_parts) and then
// the request indices [0, params.b], starting at its global thread index and
// advancing by the total number of launched threads, so any 1-D launch
// configuration covers all entries.
//
// Outputs:
//   params.tile_scheduler_metadata_ptr[p]  - one DecodingSchedMeta per partition p
//   params.num_splits_ptr[i] = i           - for every i in [0, params.b]
__global__ void __launch_bounds__(256, 1)
get_mla_metadata_no_split_kernel(const GetDecodeSchedMetaParams params) {
    const int first_idx = blockIdx.x * blockDim.x + threadIdx.x;
    const unsigned int stride = blockDim.x * gridDim.x;

    const int num_reqs = params.b;
    const int num_parts = params.num_sm_parts;
    int kv_block = params.block_size_n;

    for (int part = first_idx; part < num_parts; part += stride) {
        // Evenly divide the requests among the partitions. The products are
        // formed in 64 bits so part * num_reqs cannot overflow a 32-bit int.
        const int req_lo = (static_cast<int64_t>(part) * num_reqs) / num_parts;
        const int req_hi = (static_cast<int64_t>(part + 1) * num_reqs) / num_parts;  // exclusive

        DecodingSchedMeta meta;
        meta.begin_req_idx = req_lo;
        meta.end_req_idx = req_hi - 1;  // stored as an inclusive index
        meta.begin_block_idx = 0;
        meta.begin_split_idx = 0;
        // No request is ever marked as split on this path.
        meta.is_first_req_splitted = 0;
        meta.is_last_req_splitted = 0;
        meta._pad[0] = 0;

        // KV length of the LAST request in this partition; it stays 0 when
        // the partition received no request at all.
        int kv_len = 0;
        if (req_lo < req_hi) {
            const auto last_req = meta.end_req_idx;
            if (params.topk == -1) {
                // topk == -1: take the length from seqlens_k_ptr.
                kv_len = __ldg(params.seqlens_k_ptr + last_req);
            } else {
                // Per-request topk length when the array is provided,
                // otherwise the uniform params.topk.
                kv_len = params.topk_length ? __ldg(params.topk_length + last_req) : params.topk;
                if (kv_len == 0) {
                    kv_len = 1;  // a zero length is bumped to one
                }
                if (params.extra_topk) {
                    // NOTE(review): ku::ceil presumably rounds kv_len up to a
                    // multiple of the block size before the extra part is
                    // appended - confirm against kerutils.
                    kv_len = ku::ceil(kv_len, kv_block);
                    kv_len += params.extra_topk_length ? __ldg(params.extra_topk_length + last_req) : params.extra_topk;
                }
            }
        }

        meta.end_block_idx = cutlass::ceil_div(kv_len, kv_block);
        params.tile_scheduler_metadata_ptr[part] = meta;
    }

    // num_splits_ptr gets the identity sequence 0, 1, ..., b (b + 1 entries).
    for (int req = first_idx; req <= num_reqs; req += stride) {
        params.num_splits_ptr[req] = req;
    }
}
void run_get_decoding_sched_meta_kernel(GetDecodeSchedMetaParams &params) { void run_get_decoding_sched_meta_kernel(GetDecodeSchedMetaParams &params) {
int smem_size = sizeof(int) * (params.b*5+1); int smem_size = sizeof(int) * (static_cast<int64_t>(params.b) * 5 + 1);
get_mla_metadata_kernel<<<1, 64, smem_size, params.stream>>>(params); get_mla_metadata_kernel<<<1, 64, smem_size, params.stream>>>(params);
CHECK_CUDA_KERNEL_LAUNCH(); CHECK_CUDA_KERNEL_LAUNCH();
} }
// Launches get_mla_metadata_no_split_kernel on params.stream.
//
// The launch is sized so that one pass of 256-thread blocks covers the larger
// of the kernel's two index ranges (num_sm_parts partitions, b + 1 split
// counters), but never uses more than 1024 blocks; the kernel's loops step by
// the total thread count, so a capped grid still visits every index.
void run_get_decoding_sched_meta_no_split_kernel(GetDecodeSchedMetaParams &params) {
    constexpr int threads_per_block = 256;
    constexpr int max_blocks = 1024;

    const int work_items = std::max(params.num_sm_parts, params.b + 1);
    const int num_blocks = std::min(cutlass::ceil_div(work_items, threads_per_block), max_blocks);

    get_mla_metadata_no_split_kernel<<<num_blocks, threads_per_block, 0, params.stream>>>(params);
    CHECK_CUDA_KERNEL_LAUNCH();
}
} }
...@@ -5,5 +5,6 @@ ...@@ -5,5 +5,6 @@
namespace gfx9::decode { namespace gfx9::decode {
void run_get_decoding_sched_meta_kernel(GetDecodeSchedMetaParams &params); void run_get_decoding_sched_meta_kernel(GetDecodeSchedMetaParams &params);
void run_get_decoding_sched_meta_no_split_kernel(GetDecodeSchedMetaParams &params);
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment