Commit 12b60933 authored by lizhigong

Merge branch 'v0.5.4_dev' into v0.5.4_dev_linhai

parents cd0b5891 d297cda2
@@ -1618,7 +1618,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.seq_lens_cpu = self.seq_lens_cpu[keep_indices]
         self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
         self.out_cache_loc = None
-        self.seq_lens_sum = self.seq_lens.sum()
+        self.seq_lens_sum = self.seq_lens.sum().item()
         self.output_ids = self.output_ids[keep_indices_device]
         self.return_logprob = any(req.return_logprob for req in self.reqs)
         if self.return_logprob:
...
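The functional change in this hunk is the trailing .item(): Tensor.sum() returns a 0-dim tensor, while .item() extracts the value as a plain Python int (and, for GPU tensors, forces the device sync at this one known point). A minimal illustration:

import torch

seq_lens = torch.tensor([3, 5, 2])
as_tensor = seq_lens.sum()          # tensor(10), still a 0-dim tensor
as_int = seq_lens.sum().item()      # 10, a plain Python int
assert isinstance(as_int, int) and not torch.is_tensor(as_int)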
@@ -28,6 +28,7 @@ import triton.language as tl
 from sglang.srt.mem_cache.memory_pool import SWAKVPool
 from sglang.srt.utils import get_bool_env_var, get_num_new_pages, next_power_of_2
+from sgl_kernel.kvcacheio import dcu_alloc_decode_kernel, dcu_alloc_extend_kernel
 
 if TYPE_CHECKING:
     from sglang.srt.mem_cache.memory_pool import KVCache
@@ -430,6 +431,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         super().__init__(size, page_size, dtype, device, kvcache, need_sort)
         self.num_pages = size // page_size
         self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
+        self.sglang_kvalloc_kernel = get_bool_env_var("SGLANG_KVALLOC_KERNEL")
         self.seen_max_num_extend_tokens_next_power_of_2 = 1
         self.clear()
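The new allocation path is opt-in: the flag is read from the environment once, when the allocator is constructed, so it must be set before start-up. A sketch, assuming get_bool_env_var accepts the usual truthy strings such as "1" or "true":

import os

# Set before PagedTokenToKVPoolAllocator is instantiated; the flag is
# read in __init__, not per allocation call.
os.environ["SGLANG_KVALLOC_KERNEL"] = "1"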
@@ -484,16 +486,41 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         out_indices = torch.empty(
             (extend_num_tokens,), dtype=torch.int64, device=self.device
         )
-        alloc_extend_kernel[(bs,)](
-            prefix_lens,
-            seq_lens,
-            last_loc,
-            self.free_pages,
-            out_indices,
-            next_power_of_2(bs),
-            self.page_size,
-            self.seen_max_num_extend_tokens_next_power_of_2,
-        )
+        if self.sglang_kvalloc_kernel:
+            if bs < 3:
+                dcu_alloc_extend_kernel(
+                    pre_lens_ptr=prefix_lens,
+                    seq_lens_ptr=seq_lens,
+                    last_loc_ptr=last_loc,
+                    free_page_ptr=self.free_pages,
+                    out_indices=out_indices,
+                    bs=bs,
+                    bs_upper=next_power_of_2(bs),
+                    page_size=self.page_size,
+                    max_num_extend_tokens=self.seen_max_num_extend_tokens_next_power_of_2,
+                )
+            else:
+                alloc_extend_kernel[(bs,)](
+                    prefix_lens,
+                    seq_lens,
+                    last_loc,
+                    self.free_pages,
+                    out_indices,
+                    next_power_of_2(bs),
+                    self.page_size,
+                    self.seen_max_num_extend_tokens_next_power_of_2,
+                )
+        else:
+            alloc_extend_kernel[(bs,)](
+                prefix_lens,
+                seq_lens,
+                last_loc,
+                self.free_pages,
+                out_indices,
+                next_power_of_2(bs),
+                self.page_size,
+                self.seen_max_num_extend_tokens_next_power_of_2,
+            )
 
         if self.debug_mode:
             assert len(torch.unique(out_indices)) == len(out_indices)
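Under the flag, batches with bs < 3 take the new DCU op while larger batches keep the Triton kernel. Both must implement the same contract: for each request, assign KV-cache slots for tokens [pre_len, seq_len), first topping up the prefix's partially filled last page, then drawing whole pages (plus a trailing partial page) from free_pages. A pure-Python reference sketch of that contract (alloc_extend_reference is a hypothetical helper, not part of this commit):

def alloc_extend_reference(pre_lens, seq_lens, last_locs, free_pages, page_size):
    out, next_free = [], 0
    for pre, seq, last in zip(pre_lens, seq_lens, last_locs):
        # Part 1: finish the partially filled last page of the prefix.
        boundary = -(-pre // page_size) * page_size  # ceil(pre / page_size) pages
        part1 = min(seq, boundary) - pre
        out.extend(last + 1 + k for k in range(part1))
        # Parts 2 and 3: whole new pages, then a trailing partial page.
        remaining = (seq - pre) - part1
        while remaining > 0:
            page = free_pages[next_free]
            next_free += 1
            take = min(remaining, page_size)
            out.extend(page * page_size + k for k in range(take))
            remaining -= take
    return out

# e.g. page_size=4: pre=5, seq=12, last_loc=18, free_pages=[7]
# -> [19, 20, 21, 28, 29, 30, 31]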
@@ -525,14 +552,26 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
             self.merge_and_sort_free()
 
         out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
-        alloc_decode_kernel[(bs,)](
-            seq_lens,
-            last_loc,
-            self.free_pages,
-            out_indices,
-            next_power_of_2(bs),
-            self.page_size,
-        )
+
+        if self.sglang_kvalloc_kernel:
+            dcu_alloc_decode_kernel(
+                seq_lens_ptr=seq_lens,
+                last_loc_ptr=last_loc,
+                free_page_ptr=self.free_pages,
+                out_indices=out_indices,
+                bs=bs,
+                bs_upper=next_power_of_2(bs),
+                page_size=self.page_size,
+            )
+        else:
+            alloc_decode_kernel[(bs,)](
+                seq_lens,
+                last_loc,
+                self.free_pages,
+                out_indices,
+                next_power_of_2(bs),
+                self.page_size,
+            )
 
         if self.debug_mode:
             assert len(torch.unique(out_indices)) == len(out_indices)
...
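Decode allocates exactly one slot per request, and the page math collapses to a parity check: ceil(seq/P) - ceil((seq-1)/P) equals 1 exactly when (seq - 1) % P == 0, i.e. when the previous token just filled a page (or seq == 1). A reference sketch (hypothetical helper):

def alloc_decode_reference(seq_lens, last_locs, free_pages, page_size):
    out, next_free = [], 0
    for seq, last in zip(seq_lens, last_locs):
        if (seq - 1) % page_size == 0:
            # Previous page is exactly full (or seq == 1): open a new page.
            out.append(free_pages[next_free] * page_size)
            next_free += 1
        else:
            # Room remains on the current page.
            out.append(last + 1)
    return out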
@@ -174,6 +174,7 @@ MLA_ATTENTION_BACKENDS = [
 CHUNKED_PREFIX_CACHE_SUPPORTED_ATTENTION_BACKENDS = [
     "flashinfer",
     "fa3",
+    "dcu_mla",
     "fa4",
     "flashmla",
     "cutlass_mla",
@@ -2239,7 +2240,6 @@ class ModelRunner:
             and self.graph_runner
             and self.graph_runner.can_run(forward_batch)
         )
-
         if can_run_graph:
             ret = self.graph_runner.replay(
                 forward_batch,
...
@@ -185,6 +185,7 @@ elif _is_hip:
     from sglang.srt.layers.quantization.awq_triton import (
         awq_dequantize_triton as awq_dequantize,
     )
+    from sgl_kernel import merge_state_v2
 elif _is_npu:
     import custom_ops  # noqa: F401
     import sgl_kernel_npu  # noqa: F401
...
@@ -4,7 +4,7 @@
 #include <algorithm>
 #include <optional>
 
-#include "pytorch_extension_utils.h"
+#include "pytorch_extension_utils_rocm.h"
 
 // Helper functions to convert between different data types
 // (float, half, bfloat16) for the merge attention states kernel.
@@ -27,6 +27,19 @@ inline __device__ void from_float(__nv_bfloat16& d, float s) {
   d = __float2bfloat16(s);
 }
 
+inline void check_shape(const at::Tensor& a, const at::Tensor& b, const char* a_name, const char* b_name) {
+  TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). ", a.dim(), " vs ", b.dim());
+  for (int i = 0; i < a.dim(); ++i) {
+    TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name, ".size(", i, ")");
+  }
+}
+
+#define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b)
+#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
+#define CHECK_GE(a, b) TORCH_CHECK((a) >= (b), "CHECK_GE(" #a ", " #b ") failed. ", a, " vs ", b)
+
 // Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
 template <typename scalar_t, const uint NUM_THREADS>
 __global__ void merge_attn_states_kernel(
...
@@ -34,6 +34,12 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
   m.def("gelu_quick(Tensor! out, Tensor input) -> ()");
   m.impl("gelu_quick", torch::kCUDA, &gelu_quick);
 
+  /*
+   * From csrc/attention
+   */
+  m.def("merge_state_v2(Tensor v_a, Tensor s_a, Tensor v_b, Tensor s_b, Tensor! v_merged, Tensor! s_merged) -> ()");
+  m.impl("merge_state_v2", torch::kCUDA, &merge_state_v2);
+
   /*
    * From csrc/allreduce
    */
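merge_state_v2 merges two partial attention results, value tensors plus their log-sum-exp states, into one, per section 2.2 of arXiv:2501.01005 (see the kernel file above). A call through the newly registered op might look like this; the shapes and dtypes are illustrative assumptions, not taken from this commit:

import torch

tokens, heads, head_dim = 8, 4, 64
v_a = torch.randn(tokens, heads, head_dim, device="cuda", dtype=torch.half)
s_a = torch.randn(tokens, heads, device="cuda", dtype=torch.float32)
v_b = torch.randn_like(v_a)
s_b = torch.randn_like(s_a)
v_merged = torch.empty_like(v_a)
s_merged = torch.empty_like(s_a)
torch.ops.sgl_kernel.merge_state_v2(v_a, s_a, v_b, s_b, v_merged, s_merged)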
@@ -125,6 +131,10 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
   /*
    * From csrc/kvcacheio
    */
+  m.def("dcu_alloc_extend_kernel(Tensor pre_lens_ptr, Tensor seq_lens_ptr, Tensor last_loc_ptr, Tensor free_page_ptr, Tensor out_indices, int bs, int bs_upper, int page_size, int max_num_extend_tokens) -> ()");
+  m.impl("dcu_alloc_extend_kernel", torch::kCUDA, &dcu_alloc_extend_kernel);
+  m.def("dcu_alloc_decode_kernel(Tensor seq_lens_ptr, Tensor last_loc_ptr, Tensor free_page_ptr, Tensor out_indices, int bs, int bs_upper, int page_size) -> ()");
+  m.impl("dcu_alloc_decode_kernel", torch::kCUDA, &dcu_alloc_decode_kernel);
   m.def(
       "transfer_kv_per_layer(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
       "dst_indices, int item_size, int block_quota, int num_warps_per_block) -> ()");
...
@@ -571,3 +571,171 @@ void transfer_kv_all_layer_direct_lf_pf(
     int64_t page_size) {
   transfer_kv_page_first_direct_impl<true>(src_ptrs, dst_ptrs, src_indices, dst_indices, 0, page_size);
 }
+
+__device__ int64_t ceil_div(int64_t a, int64_t b) {
+  return (a + b - 1) / b;
+}
+
+__device__ int64_t safe_min(int64_t a, int64_t b) {
+  return a < b ? a : b;
+}
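+
+// One thread per request: a prefix sum over requests 0..pid counts how many
+// new pages earlier requests consume, then this request's single decode token
+// goes either right after last_loc or at the start of the next free page.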
+__global__ void launch_alloc_decode_kernel(
+    const int64_t* seq_lens_ptr,
+    const int32_t* last_loc_ptr,
+    const int64_t* free_page_ptr,
+    int64_t* out_indices,
+    int64_t bs,
+    int64_t page_size) {
+  int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
+  // Bound by the true batch size: threads beyond bs must not touch
+  // seq_lens_ptr or out_indices (next_power_of_2(bs) can exceed bs).
+  if (pid >= bs) return;
+
+  int64_t seq_len = seq_lens_ptr[pid];
+  int64_t pre_len = seq_len - 1;  // decode extends every request by one token
+  int64_t num_page_start_loc_self = ceil_div(seq_len, page_size) - ceil_div(pre_len, page_size);
+
+  int64_t sum_num_new_pages = 0;
+  for (int64_t i = 0; i <= pid; i++) {
+    int64_t other_seq_len = seq_lens_ptr[i];
+    int64_t other_pre_len = other_seq_len - 1;
+    sum_num_new_pages += ceil_div(other_seq_len, page_size) - ceil_div(other_pre_len, page_size);
+  }
+  int64_t new_page_start_loc = sum_num_new_pages - num_page_start_loc_self;
+
+  if (num_page_start_loc_self == 0) {
+    int32_t last_loc = last_loc_ptr[pid];
+    out_indices[pid] = last_loc + 1;
+  } else {
+    int64_t page = free_page_ptr[new_page_start_loc];
+    out_indices[pid] = page * page_size;
+  }
+}
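+
+// One thread per request: tokens [pre_len, seq_len) are assigned in three
+// parts (tail of the prefix's partially filled page, whole new pages, then a
+// trailing partial page), drawing page ids from free_page_ptr.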
+__global__ void launch_alloc_extend_kernel(
+    const int64_t* pre_lens_ptr,
+    const int64_t* seq_lens_ptr,
+    const int64_t* last_loc_ptr,
+    const int64_t* free_page_ptr,
+    int64_t* out_indices,
+    int64_t bs,
+    int64_t page_size,
+    int64_t max_num_extend_tokens) {
+  int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
+  if (pid >= bs) return;  // as in the decode kernel: bound by the true batch size
+
+  int64_t seq_len = seq_lens_ptr[pid];
+  int64_t pre_len = pre_lens_ptr[pid];
+  int64_t extend_len = seq_len - pre_len;
+
+  // Exclusive prefix sum of extend lengths: where this request's output starts.
+  int64_t sum_extend_lens = 0;
+  for (int64_t i = 0; i <= pid; i++) {
+    sum_extend_lens += seq_lens_ptr[i] - pre_lens_ptr[i];
+  }
+  int64_t output_start_loc = sum_extend_lens - extend_len;
+
+  // Exclusive prefix sum of new-page counts: where this request's pages start.
+  int64_t num_page_start_loc_self = ceil_div(seq_len, page_size) - ceil_div(pre_len, page_size);
+  int64_t sum_num_new_pages = 0;
+  for (int64_t i = 0; i <= pid; i++) {
+    sum_num_new_pages += ceil_div(seq_lens_ptr[i], page_size) - ceil_div(pre_lens_ptr[i], page_size);
+  }
+  int64_t new_page_start_loc = sum_num_new_pages - num_page_start_loc_self;
+
+  // Part 1: fill the rest of the prefix's last, partially used page.
+  int64_t last_loc = last_loc_ptr[pid];
+  int64_t num_part1 = safe_min(seq_len, ceil_div(pre_len, page_size) * page_size) - pre_len;
+  for (int64_t offset = 0; offset < num_part1 && offset < page_size; offset++) {
+    out_indices[output_start_loc + offset] = last_loc + 1 + offset;
+  }
+  if (pre_len + num_part1 == seq_len) {
+    return;
+  }
+
+  // Part 2: whole pages taken from the free-page list.
+  int64_t num_part2 = (seq_len / page_size) * page_size - ceil_div(pre_len, page_size) * page_size;
+  for (int64_t offset = 0; offset < num_part2 && offset < max_num_extend_tokens; offset++) {
+    int64_t page_start = free_page_ptr[new_page_start_loc + offset / page_size];
+    out_indices[output_start_loc + num_part1 + offset] = page_start * page_size + offset % page_size;
+  }
+  if (pre_len + num_part1 + num_part2 == seq_len) {
+    return;
+  }
+
+  // Part 3: the trailing partial page.
+  int64_t num_part3 = seq_len - (seq_len / page_size) * page_size;
+  int64_t start_loc = free_page_ptr[new_page_start_loc + num_page_start_loc_self - 1];
+  for (int64_t offset = 0; offset < num_part3 && offset < page_size; offset++) {
+    out_indices[output_start_loc + num_part1 + num_part2 + offset] = start_loc * page_size + offset;
+  }
+}
+
+void dcu_alloc_decode_kernel(
+    const at::Tensor seq_lens_ptr,
+    const at::Tensor last_loc_ptr,
+    const at::Tensor free_page_ptr,
+    at::Tensor out_indices,
+    int64_t bs,
+    int64_t bs_upper,
+    int64_t page_size) {
+  (void)bs_upper;  // retained in the op schema; the kernel bounds by the true bs
+  const int64_t* seq_lens = static_cast<const int64_t*>(seq_lens_ptr.data_ptr());
+  const int32_t* last_loc = static_cast<const int32_t*>(last_loc_ptr.data_ptr());
+  const int64_t* free_page = static_cast<const int64_t*>(free_page_ptr.data_ptr());
+  int64_t* out = static_cast<int64_t*>(out_indices.data_ptr());
+  int64_t block_size = 64;
+  int64_t grid_size = (bs + block_size - 1) / block_size;
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  launch_alloc_decode_kernel<<<grid_size, block_size, 0, stream>>>(
+      seq_lens, last_loc, free_page, out, bs, page_size);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
+
+void dcu_alloc_extend_kernel(
+    const at::Tensor pre_lens_ptr,
+    const at::Tensor seq_lens_ptr,
+    const at::Tensor last_loc_ptr,
+    const at::Tensor free_page_ptr,
+    at::Tensor out_indices,
+    int64_t bs,
+    int64_t bs_upper,
+    int64_t page_size,
+    int64_t max_num_extend_tokens) {
+  (void)bs_upper;  // retained in the op schema; the kernel bounds by the true bs
+  const int64_t* pre_lens = static_cast<const int64_t*>(pre_lens_ptr.data_ptr());
+  const int64_t* seq_lens = static_cast<const int64_t*>(seq_lens_ptr.data_ptr());
+  const int64_t* last_loc = static_cast<const int64_t*>(last_loc_ptr.data_ptr());
+  const int64_t* free_page = static_cast<const int64_t*>(free_page_ptr.data_ptr());
+  int64_t* out = static_cast<int64_t*>(out_indices.data_ptr());
+  int64_t block_size = 64;
+  int64_t grid_size = (bs + block_size - 1) / block_size;
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  launch_alloc_extend_kernel<<<grid_size, block_size, 0, stream>>>(
+      pre_lens, seq_lens, last_loc, free_page, out, bs, page_size, max_num_extend_tokens);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
\ No newline at end of file
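Each thread redoes its prefix sum over requests 0..pid, so a launch does O(bs^2) total work; that is presumably why the extend call site above routes only bs < 3 batches to this op. The loops are equivalent to an inclusive prefix sum over per-request new-page counts, sketched here for the decode case (hypothetical helper names):

from itertools import accumulate

def ceil_div(a, b):
    return -(-a // b)

def decode_new_page_counts(seq_lens, page_size):
    # 1 when the next token starts a fresh page, i.e. (seq - 1) % page_size == 0
    return [ceil_div(s, page_size) - ceil_div(s - 1, page_size) for s in seq_lens]

# Inclusive prefix sum, which is what each thread recomputes in O(pid) time;
# subtracting a request's own count gives its exclusive start into free_page_ptr.
starts = list(accumulate(decode_new_page_counts([5, 8, 9], page_size=4)))  # [1, 1, 2]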
@@ -538,6 +538,26 @@ void segment_packbits(
 /*
  * From csrc/kvcacheio
  */
+void dcu_alloc_extend_kernel(
+    const at::Tensor pre_lens_ptr,
+    const at::Tensor seq_lens_ptr,
+    const at::Tensor last_loc_ptr,
+    const at::Tensor free_page_ptr,
+    at::Tensor out_indices,
+    int64_t bs,
+    int64_t bs_upper,
+    int64_t page_size,
+    int64_t max_num_extend_tokens);
+
+void dcu_alloc_decode_kernel(
+    const at::Tensor seq_lens_ptr,
+    const at::Tensor last_loc_ptr,
+    const at::Tensor free_page_ptr,
+    at::Tensor out_indices,
+    int64_t bs,
+    int64_t bs_upper,
+    int64_t page_size);
+
 void transfer_kv_per_layer(
     const at::Tensor src_k,
     at::Tensor dst_k,
...
@@ -10,6 +10,48 @@ def is_hip() -> bool:
 _is_hip = is_hip()
 
+
+def dcu_alloc_extend_kernel(
+    pre_lens_ptr: torch.Tensor,
+    seq_lens_ptr: torch.Tensor,
+    last_loc_ptr: torch.Tensor,
+    free_page_ptr: torch.Tensor,
+    out_indices: torch.Tensor,
+    bs: int,
+    bs_upper: int,
+    page_size: int,
+    max_num_extend_tokens: int,
+):
+    torch.ops.sgl_kernel.dcu_alloc_extend_kernel(
+        pre_lens_ptr,
+        seq_lens_ptr,
+        last_loc_ptr,
+        free_page_ptr,
+        out_indices,
+        bs,
+        bs_upper,
+        page_size,
+        max_num_extend_tokens,
+    )
+
+
+def dcu_alloc_decode_kernel(
+    seq_lens_ptr: torch.Tensor,
+    last_loc_ptr: torch.Tensor,
+    free_page_ptr: torch.Tensor,
+    out_indices: torch.Tensor,
+    bs: int,
+    bs_upper: int,
+    page_size: int,
+):
+    torch.ops.sgl_kernel.dcu_alloc_decode_kernel(
+        seq_lens_ptr,
+        last_loc_ptr,
+        free_page_ptr,
+        out_indices,
+        bs,
+        bs_upper,
+        page_size,
+    )
+
+
 def transfer_kv_per_layer(
     src_k: torch.Tensor,
     dst_k: torch.Tensor,
...
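A hypothetical smoke test for the decode wrapper (the values are illustrative; dtypes follow the host-side casts in transfer.cu above, int64 throughout except that the decode op reads last_loc as int32):

import torch
from sgl_kernel.kvcacheio import dcu_alloc_decode_kernel

page_size = 4
seq_lens = torch.tensor([5, 7], dtype=torch.int64, device="cuda")
last_loc = torch.tensor([16, 25], dtype=torch.int32, device="cuda")
free_pages = torch.arange(8, dtype=torch.int64, device="cuda")
out = torch.empty(2, dtype=torch.int64, device="cuda")

dcu_alloc_decode_kernel(seq_lens, last_loc, free_pages, out,
                        bs=2, bs_upper=2, page_size=page_size)
# Expected: out == tensor([0, 26]); seq 5 crosses a page boundary and takes
# free page 0 (slot 0), while seq 7 stays on its page (last_loc + 1 = 26).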
@@ -50,6 +50,7 @@ sources = [
     "csrc/moe/moe_topk_softmax_kernels.cu",
     "csrc/speculative/eagle_utils.cu",
     "csrc/kvcacheio/transfer.cu",
+    "csrc/attention/merge_attn_states.cu",
 ]
 
 cxx_flags = ["-O3"]