Commit ec78c4c5 authored by liucong's avatar liucong
Browse files

增加dcu_alloc_extend_kernel实现

parent c9bcffd2
...@@ -28,7 +28,7 @@ import triton.language as tl ...@@ -28,7 +28,7 @@ import triton.language as tl
from sglang.srt.mem_cache.memory_pool import SWAKVPool from sglang.srt.mem_cache.memory_pool import SWAKVPool
from sglang.srt.utils import get_bool_env_var, get_num_new_pages, next_power_of_2 from sglang.srt.utils import get_bool_env_var, get_num_new_pages, next_power_of_2
from sgl_kernel.kvcacheio import dcu_alloc_decode_kernel from sgl_kernel.kvcacheio import dcu_alloc_decode_kernel, dcu_alloc_extend_kernel
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.mem_cache.memory_pool import KVCache from sglang.srt.mem_cache.memory_pool import KVCache
...@@ -431,7 +431,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator): ...@@ -431,7 +431,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
super().__init__(size, page_size, dtype, device, kvcache, need_sort) super().__init__(size, page_size, dtype, device, kvcache, need_sort)
self.num_pages = size // page_size self.num_pages = size // page_size
self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL") self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
self.use_dcu_decode_kernel = get_bool_env_var("USE_DCU_DECODE_KERNEL") self.sglang_kvalloc_kernel = get_bool_env_var("SGLANG_KVALLOC_KERNEL")
self.seen_max_num_extend_tokens_next_power_of_2 = 1 self.seen_max_num_extend_tokens_next_power_of_2 = 1
self.clear() self.clear()
...@@ -486,6 +486,19 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator): ...@@ -486,6 +486,19 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
out_indices = torch.empty( out_indices = torch.empty(
(extend_num_tokens,), dtype=torch.int64, device=self.device (extend_num_tokens,), dtype=torch.int64, device=self.device
) )
if self.sglang_kvalloc_kernel:
dcu_alloc_extend_kernel(
pre_lens_ptr = prefix_lens,
seq_lens_ptr = seq_lens,
last_loc_ptr = last_loc,
free_page_ptr = self.free_pages,
out_indices = out_indices,
bs = bs,
bs_upper = next_power_of_2(bs),
page_size = self.page_size,
max_num_extend_tokens = self.seen_max_num_extend_tokens_next_power_of_2,
)
else:
alloc_extend_kernel[(bs,)]( alloc_extend_kernel[(bs,)](
prefix_lens, prefix_lens,
seq_lens, seq_lens,
...@@ -528,7 +541,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator): ...@@ -528,7 +541,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device) out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
if self.use_dcu_decode_kernel: if self.sglang_kvalloc_kernel:
dcu_alloc_decode_kernel( dcu_alloc_decode_kernel(
seq_lens_ptr = seq_lens, seq_lens_ptr = seq_lens,
last_loc_ptr = last_loc, last_loc_ptr = last_loc,
......
...@@ -125,6 +125,8 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) { ...@@ -125,6 +125,8 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
/* /*
* From csrc/kvcacheio * From csrc/kvcacheio
*/ */
m.def("dcu_alloc_extend_kernel(Tensor pre_lens_ptr, Tensor seq_lens_ptr, Tensor last_loc_ptr, Tensor free_page_ptr, Tensor out_indices, int bs, int bs_upper, int page_size, int max_num_extend_tokens) -> ()");
m.impl("dcu_alloc_extend_kernel", torch::kCUDA, &dcu_alloc_extend_kernel);
m.def("dcu_alloc_decode_kernel(Tensor seq_lens_ptr, Tensor last_loc_ptr, Tensor free_page_ptr, Tensor out_indices, int bs, int bs_upper, int page_size) -> ()"); m.def("dcu_alloc_decode_kernel(Tensor seq_lens_ptr, Tensor last_loc_ptr, Tensor free_page_ptr, Tensor out_indices, int bs, int bs_upper, int page_size) -> ()");
m.impl("dcu_alloc_decode_kernel", torch::kCUDA, &dcu_alloc_decode_kernel); m.impl("dcu_alloc_decode_kernel", torch::kCUDA, &dcu_alloc_decode_kernel);
m.def( m.def(
......
...@@ -576,14 +576,17 @@ __device__ int64_t ceil_div(int64_t a, int64_t b) { ...@@ -576,14 +576,17 @@ __device__ int64_t ceil_div(int64_t a, int64_t b) {
return (a + b - 1) / b; return (a + b - 1) / b;
} }
__device__ int64_t safe_min(int64_t a, int64_t b) {
return a < b ? a : b;
}
__global__ void launch_alloc_decode_kernel( __global__ void launch_alloc_decode_kernel(
const int64_t* seq_lens_ptr, const int64_t* seq_lens_ptr,
const int32_t* last_loc_ptr, const int32_t* last_loc_ptr,
const int64_t* free_page_ptr, const int64_t* free_page_ptr,
int64_t* out_indices, int64_t* out_indices,
int64_t bs_upper, int64_t bs_upper,
int64_t page_size) int64_t page_size) {
{
int64_t pid = blockIdx.x * blockDim.x + threadIdx.x; int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
...@@ -595,7 +598,7 @@ __global__ void launch_alloc_decode_kernel( ...@@ -595,7 +598,7 @@ __global__ void launch_alloc_decode_kernel(
int64_t num_page_start_loc_self = ceil_div(seq_len, page_size) - ceil_div(pre_len, page_size); int64_t num_page_start_loc_self = ceil_div(seq_len, page_size) - ceil_div(pre_len, page_size);
int64_t sum_num_new_pages = 0; int64_t sum_num_new_pages = 0;
for (int64_t i = 0; i < pid; i++) { for (int64_t i = 0; i <= pid; i++) {
int64_t other_seq_len = seq_lens_ptr[i]; int64_t other_seq_len = seq_lens_ptr[i];
int64_t other_pre_len = (i <= pid) ? (other_seq_len - 1) : other_seq_len; int64_t other_pre_len = (i <= pid) ? (other_seq_len - 1) : other_seq_len;
...@@ -616,6 +619,82 @@ __global__ void launch_alloc_decode_kernel( ...@@ -616,6 +619,82 @@ __global__ void launch_alloc_decode_kernel(
} }
} }
__global__ void launch_alloc_extend_kernel(
const int64_t* pre_lens_ptr,
const int64_t* seq_lens_ptr,
const int64_t* last_loc_ptr,
const int64_t* free_page_ptr,
int64_t* out_indices,
int64_t bs_upper,
int64_t page_size,
int64_t max_num_extend_tokens)
{
int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
if (pid >= bs_upper) return;
int64_t seq_len = seq_lens_ptr[pid];
int64_t pre_len = pre_lens_ptr[pid];
int64_t extend_len = seq_len - pre_len;
int64_t sum_extend_lens = 0;
for (int64_t i = 0; i <= pid; i++) {
int64_t other_seq_len = seq_lens_ptr[i];
int64_t other_pre_len = pre_lens_ptr[i];
int64_t other_extend_len = other_seq_len - other_pre_len;
sum_extend_lens += other_extend_len;
}
int64_t output_start_loc = sum_extend_lens - extend_len;
int64_t num_page_start_loc_self = ceil_div(seq_len, page_size) - ceil_div(pre_len, page_size);
int64_t sum_num_new_pages = 0;
for (int64_t i = 0; i <= pid; i++) {
int64_t other_seq_len = seq_lens_ptr[i];
int64_t other_pre_len = pre_lens_ptr[i];
int64_t other_num_pages_after = ceil_div(other_seq_len, page_size);
int64_t other_num_pages_before = ceil_div(other_pre_len, page_size);
int64_t other_num_new_pages = other_num_pages_after - other_num_pages_before;
sum_num_new_pages += other_num_new_pages;
}
int64_t new_page_start_loc = sum_num_new_pages - num_page_start_loc_self;
int64_t last_loc = last_loc_ptr[pid];
int64_t num_part1 = safe_min(seq_len, ceil_div(pre_len, page_size) * page_size) - pre_len;
for (int64_t offset = 0; offset < num_part1; offset++) {
int64_t output_idx = output_start_loc + offset;
out_indices[output_idx] = last_loc + 1 + offset;
}
if (pre_len + num_part1 == seq_len) {
return;
}
int64_t num_part2 = (seq_len / page_size) * page_size - ceil_div(pre_len, page_size) * page_size;
for (int64_t offset = 0; offset < num_part2; offset++) {
int64_t page_idx = new_page_start_loc + offset / page_size;
int64_t page_start = free_page_ptr[page_idx];
int64_t output_idx = output_start_loc + num_part1 + offset;
out_indices[output_idx] = page_start * page_size + offset % page_size;
}
if (pre_len + num_part1 + num_part2 == seq_len) {
return;
}
int64_t num_part3 = seq_len - (seq_len / page_size) * page_size;
int64_t last_page_idx = new_page_start_loc + num_page_start_loc_self - 1;
int64_t start_loc = free_page_ptr[last_page_idx];
for (int64_t offset = 0; offset < num_part3 && offset < page_size; offset++) {
int64_t output_idx = output_start_loc + num_part1 + num_part2 + offset;
out_indices[output_idx] = start_loc * page_size + offset;
}
}
void dcu_alloc_decode_kernel( void dcu_alloc_decode_kernel(
const at::Tensor seq_lens_ptr, const at::Tensor seq_lens_ptr,
const at::Tensor last_loc_ptr, const at::Tensor last_loc_ptr,
...@@ -636,3 +715,27 @@ void dcu_alloc_decode_kernel( ...@@ -636,3 +715,27 @@ void dcu_alloc_decode_kernel(
launch_alloc_decode_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(seq_lens_ptr1, last_loc_ptr1, free_page_ptr1, out_indices1, bs_upper, page_size); launch_alloc_decode_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(seq_lens_ptr1, last_loc_ptr1, free_page_ptr1, out_indices1, bs_upper, page_size);
C10_CUDA_KERNEL_LAUNCH_CHECK(); C10_CUDA_KERNEL_LAUNCH_CHECK();
} }
void dcu_alloc_extend_kernel(
const at::Tensor pre_lens_ptr,
const at::Tensor seq_lens_ptr,
const at::Tensor last_loc_ptr,
const at::Tensor free_page_ptr,
at::Tensor out_indices,
int64_t bs,
int64_t bs_upper,
int64_t page_size,
int64_t max_num_extend_tokens) {
const int64_t* pre_lens_ptr1 = static_cast<const int64_t*>(pre_lens_ptr.data_ptr());
const int64_t* seq_lens_ptr1 = static_cast<const int64_t*>(seq_lens_ptr.data_ptr());
const int64_t* last_loc_ptr1 = static_cast<const int64_t*>(last_loc_ptr.data_ptr());
const int64_t* free_page_ptr1 = static_cast<const int64_t*>(free_page_ptr.data_ptr());
int64_t* out_indices1 = static_cast<int64_t*>(out_indices.data_ptr());
int64_t block_size = 64;
int64_t grid_size = (bs + block_size - 1) / block_size;
cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
launch_alloc_extend_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(pre_lens_ptr1, seq_lens_ptr1, last_loc_ptr1, free_page_ptr1, out_indices1, bs_upper, page_size, max_num_extend_tokens);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
\ No newline at end of file
...@@ -538,6 +538,17 @@ void segment_packbits( ...@@ -538,6 +538,17 @@ void segment_packbits(
/* /*
* From csrc/kvcacheio * From csrc/kvcacheio
*/ */
void dcu_alloc_extend_kernel(
const at::Tensor pre_lens_ptr,
const at::Tensor seq_lens_ptr,
const at::Tensor last_loc_ptr,
const at::Tensor free_page_ptr,
at::Tensor out_indices,
int64_t bs,
int64_t bs_upper,
int64_t page_size,
int64_t max_num_extend_tokens);
void dcu_alloc_decode_kernel( void dcu_alloc_decode_kernel(
const at::Tensor seq_lens_ptr, const at::Tensor seq_lens_ptr,
const at::Tensor last_loc_ptr, const at::Tensor last_loc_ptr,
......
...@@ -10,6 +10,29 @@ def is_hip() -> bool: ...@@ -10,6 +10,29 @@ def is_hip() -> bool:
_is_hip = is_hip() _is_hip = is_hip()
def dcu_alloc_extend_kernel(
pre_lens_ptr: torch.Tensor,
seq_lens_ptr: torch.Tensor,
last_loc_ptr: torch.Tensor,
free_page_ptr: torch.Tensor,
out_indices: torch.Tensor,
bs: int,
bs_upper: int,
page_size: int,
max_num_extend_tokens: int,
):
torch.ops.sgl_kernel.dcu_alloc_extend_kernel(
pre_lens_ptr,
seq_lens_ptr,
last_loc_ptr,
free_page_ptr,
out_indices,
bs,
bs_upper,
page_size,
max_num_extend_tokens,
)
def dcu_alloc_decode_kernel( def dcu_alloc_decode_kernel(
seq_lens_ptr: torch.Tensor, seq_lens_ptr: torch.Tensor,
last_loc_ptr: torch.Tensor, last_loc_ptr: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment