"src/vscode:/vscode.git/clone" did not exist on "cd6e1f1171530c76be872877a395dc90b90cfb36"
Commit 5073dd76 authored by liucong8560's avatar liucong8560 Committed by maxiao1
Browse files

V0.5.4 dev liucong

parent a5156371
......@@ -28,6 +28,7 @@ import triton.language as tl
from sglang.srt.mem_cache.memory_pool import SWAKVPool
from sglang.srt.utils import get_bool_env_var, get_num_new_pages, next_power_of_2
from sgl_kernel.kvcacheio import dcu_alloc_decode_kernel, dcu_alloc_extend_kernel
if TYPE_CHECKING:
from sglang.srt.mem_cache.memory_pool import KVCache
......@@ -430,6 +431,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
super().__init__(size, page_size, dtype, device, kvcache, need_sort)
self.num_pages = size // page_size
self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
self.sglang_kvalloc_kernel = get_bool_env_var("SGLANG_KVALLOC_KERNEL")
self.seen_max_num_extend_tokens_next_power_of_2 = 1
self.clear()
......@@ -484,16 +486,41 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
out_indices = torch.empty(
(extend_num_tokens,), dtype=torch.int64, device=self.device
)
alloc_extend_kernel[(bs,)](
prefix_lens,
seq_lens,
last_loc,
self.free_pages,
out_indices,
next_power_of_2(bs),
self.page_size,
self.seen_max_num_extend_tokens_next_power_of_2,
)
if self.sglang_kvalloc_kernel:
if bs < 3:
dcu_alloc_extend_kernel(
pre_lens_ptr = prefix_lens,
seq_lens_ptr = seq_lens,
last_loc_ptr = last_loc,
free_page_ptr = self.free_pages,
out_indices = out_indices,
bs = bs,
bs_upper = next_power_of_2(bs),
page_size = self.page_size,
max_num_extend_tokens = self.seen_max_num_extend_tokens_next_power_of_2,
)
else:
alloc_extend_kernel[(bs,)](
prefix_lens,
seq_lens,
last_loc,
self.free_pages,
out_indices,
next_power_of_2(bs),
self.page_size,
self.seen_max_num_extend_tokens_next_power_of_2,
)
else:
alloc_extend_kernel[(bs,)](
prefix_lens,
seq_lens,
last_loc,
self.free_pages,
out_indices,
next_power_of_2(bs),
self.page_size,
self.seen_max_num_extend_tokens_next_power_of_2,
)
if self.debug_mode:
assert len(torch.unique(out_indices)) == len(out_indices)
......@@ -525,14 +552,26 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
self.merge_and_sort_free()
out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
alloc_decode_kernel[(bs,)](
seq_lens,
last_loc,
self.free_pages,
out_indices,
next_power_of_2(bs),
self.page_size,
)
if self.sglang_kvalloc_kernel:
dcu_alloc_decode_kernel(
seq_lens_ptr = seq_lens,
last_loc_ptr = last_loc,
free_page_ptr = self.free_pages,
out_indices = out_indices,
bs = bs,
bs_upper = next_power_of_2(bs),
page_size = self.page_size,
)
else:
alloc_decode_kernel[(bs,)](
seq_lens,
last_loc,
self.free_pages,
out_indices,
next_power_of_2(bs),
self.page_size,
)
if self.debug_mode:
assert len(torch.unique(out_indices)) == len(out_indices)
......
......@@ -125,6 +125,10 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
/*
* From csrc/kvcacheio
*/
m.def("dcu_alloc_extend_kernel(Tensor pre_lens_ptr, Tensor seq_lens_ptr, Tensor last_loc_ptr, Tensor free_page_ptr, Tensor out_indices, int bs, int bs_upper, int page_size, int max_num_extend_tokens) -> ()");
m.impl("dcu_alloc_extend_kernel", torch::kCUDA, &dcu_alloc_extend_kernel);
m.def("dcu_alloc_decode_kernel(Tensor seq_lens_ptr, Tensor last_loc_ptr, Tensor free_page_ptr, Tensor out_indices, int bs, int bs_upper, int page_size) -> ()");
m.impl("dcu_alloc_decode_kernel", torch::kCUDA, &dcu_alloc_decode_kernel);
m.def(
"transfer_kv_per_layer(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
"dst_indices, int item_size, int block_quota, int num_warps_per_block) -> ()");
......
......@@ -571,3 +571,171 @@ void transfer_kv_all_layer_direct_lf_pf(
int64_t page_size) {
transfer_kv_page_first_direct_impl<true>(src_ptrs, dst_ptrs, src_indices, dst_indices, 0, page_size);
}
// Integer ceiling division: smallest q such that q * b >= a
// (intended for non-negative a and positive b, as used by the paging math below).
__device__ int64_t ceil_div(int64_t a, int64_t b) {
  int64_t biased = a + b - 1;
  return biased / b;
}
// Minimum of two 64-bit integers.
__device__ int64_t safe_min(int64_t a, int64_t b) {
  if (b < a) {
    return b;
  }
  return a;
}
// One thread per request: compute the KV-cache slot of the single new decode
// token of request `pid` and write it to out_indices[pid].
//
//   seq_lens_ptr[i]  : sequence length of request i AFTER the decode step
//   last_loc_ptr[i]  : cache slot of request i's previous (last) token
//   free_page_ptr    : page ids of the free pages handed to this batch
//   out_indices[i]   : output slot for request i's new token
//   bs_upper         : thread guard bound — see NOTE below
//   page_size        : tokens per KV-cache page
//
// NOTE(review): the guard below compares against bs_upper, which the host
// launcher passes as next_power_of_2(bs). When bs is not a power of two,
// threads with pid in [bs, bs_upper) read past the end of the bs-sized input
// tensors and write past the end of out_indices. The launch bound must be the
// true batch size.
//
// NOTE(review): last_loc is read here as int32 but as int64 in the extend
// kernel — confirm the Python-side tensor dtypes actually differ between the
// two paths.
__global__ void launch_alloc_decode_kernel(
    const int64_t* seq_lens_ptr,
    const int32_t* last_loc_ptr,
    const int64_t* free_page_ptr,
    int64_t* out_indices,
    int64_t bs_upper,
    int64_t page_size) {
  int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
  if (pid >= bs_upper) return;
  int64_t seq_len = seq_lens_ptr[pid];
  // Length before this decode step: exactly one token is being appended.
  int64_t pre_len = seq_len - 1;
  // 1 if this request crosses into a fresh page this step, else 0.
  int64_t num_page_start_loc_self = ceil_div(seq_len, page_size) - ceil_div(pre_len, page_size);
  // Inclusive prefix sum (over requests 0..pid) of new-page counts; used to
  // locate this request's page inside free_page_ptr. O(bs) per thread, so
  // O(bs^2) overall — acceptable only for small batches.
  int64_t sum_num_new_pages = 0;
  for (int64_t i = 0; i <= pid; i++) {
    int64_t other_seq_len = seq_lens_ptr[i];
    // NOTE(review): (i <= pid) is always true inside this loop, so the
    // ternary always takes the first branch; likely a leftover from a
    // masked-vector (Triton) formulation.
    int64_t other_pre_len = (i <= pid) ? (other_seq_len - 1) : other_seq_len;
    int64_t other_num_pages_after = ceil_div(other_seq_len, page_size);
    int64_t other_num_pages_before = ceil_div(other_pre_len, page_size);
    int64_t other_num_new_pages = other_num_pages_after - other_num_pages_before;
    sum_num_new_pages += other_num_new_pages;
  }
  // Exclusive prefix sum: index of this request's new page (if any).
  int64_t new_page_start_loc = sum_num_new_pages - num_page_start_loc_self;
  if (num_page_start_loc_self == 0) {
    // Still inside the current page: the new token goes right after the last one.
    int32_t last_loc = last_loc_ptr[pid];
    out_indices[pid] = last_loc + 1;
  } else {
    // Crossing a page boundary: the new token opens a fresh page.
    int64_t page = free_page_ptr[new_page_start_loc];
    out_indices[pid] = page * page_size;
  }
}
// One thread per request: compute KV-cache slot indices for every newly
// extended (prefill) token of request `pid` and write them into that
// request's segment of out_indices.
//
//   pre_lens_ptr[i]  : tokens of request i already in the cache
//   seq_lens_ptr[i]  : total length of request i after the extend
//   last_loc_ptr[i]  : cache slot of request i's last cached token
//   free_page_ptr    : page ids of the free pages handed to this batch
//   out_indices      : flat slot output, one entry per new token, requests
//                      laid out back-to-back in batch order
//   bs_upper         : thread guard bound — see NOTE below
//   page_size        : tokens per KV-cache page
//   max_num_extend_tokens : cap on the full-page copy loop (part 2)
//
// NOTE(review): the guard below compares against bs_upper, which the host
// launcher passes as next_power_of_2(bs). When bs is not a power of two,
// threads with pid in [bs, bs_upper) read past the end of the bs-sized input
// tensors and write past the end of out_indices. The launch bound must be the
// true batch size.
__global__ void launch_alloc_extend_kernel(
    const int64_t* pre_lens_ptr,
    const int64_t* seq_lens_ptr,
    const int64_t* last_loc_ptr,
    const int64_t* free_page_ptr,
    int64_t* out_indices,
    int64_t bs_upper,
    int64_t page_size,
    int64_t max_num_extend_tokens)
{
  int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
  if (pid >= bs_upper) return;
  int64_t seq_len = seq_lens_ptr[pid];
  int64_t pre_len = pre_lens_ptr[pid];
  int64_t extend_len = seq_len - pre_len;
  // Inclusive prefix sum of extend lengths over requests 0..pid; subtracting
  // our own length below gives this request's offset into out_indices.
  // O(bs) per thread — O(bs^2) overall; acceptable only for small batches.
  int64_t sum_extend_lens = 0;
  for (int64_t i = 0; i <= pid; i++) {
    int64_t other_seq_len = seq_lens_ptr[i];
    int64_t other_pre_len = pre_lens_ptr[i];
    int64_t other_extend_len = other_seq_len - other_pre_len;
    sum_extend_lens += other_extend_len;
  }
  int64_t output_start_loc = sum_extend_lens - extend_len;
  // Fresh pages this request consumes.
  int64_t num_page_start_loc_self = ceil_div(seq_len, page_size) - ceil_div(pre_len, page_size);
  // Prefix sum of new-page counts, to find our first page in free_page_ptr.
  int64_t sum_num_new_pages = 0;
  for (int64_t i = 0; i <= pid; i++) {
    int64_t other_seq_len = seq_lens_ptr[i];
    int64_t other_pre_len = pre_lens_ptr[i];
    int64_t other_num_pages_after = ceil_div(other_seq_len, page_size);
    int64_t other_num_pages_before = ceil_div(other_pre_len, page_size);
    int64_t other_num_new_pages = other_num_pages_after - other_num_pages_before;
    sum_num_new_pages += other_num_new_pages;
  }
  int64_t new_page_start_loc = sum_num_new_pages - num_page_start_loc_self;
  int64_t last_loc = last_loc_ptr[pid];
  // Part 1: fill the remaining tail of the partially-used last page
  // (slots immediately after last_loc, up to the page boundary or seq_len).
  int64_t num_part1 = safe_min(seq_len, ceil_div(pre_len, page_size) * page_size) - pre_len;
  for (int64_t offset = 0; offset < num_part1 && offset < page_size; offset++) {
    int64_t output_idx = output_start_loc + offset;
    out_indices[output_idx] = last_loc + 1 + offset;
  }
  if (pre_len + num_part1 == seq_len) {
    return;
  }
  // Part 2: whole pages fully covered by the extend range.
  int64_t num_part2 = (seq_len / page_size) * page_size - ceil_div(pre_len, page_size) * page_size;
  for (int64_t offset = 0; offset < num_part2 && offset < max_num_extend_tokens; offset++) {
    int64_t page_idx = new_page_start_loc + offset / page_size;
    int64_t page_start = free_page_ptr[page_idx];
    int64_t output_idx = output_start_loc + num_part1 + offset;
    out_indices[output_idx] = page_start * page_size + offset % page_size;
  }
  if (pre_len + num_part1 + num_part2 == seq_len) {
    return;
  }
  // Part 3: the trailing partial page (we only reach here when at least one
  // new page was allocated, so last_page_idx is valid).
  int64_t num_part3 = seq_len - (seq_len / page_size) * page_size;
  int64_t last_page_idx = new_page_start_loc + num_page_start_loc_self - 1;
  int64_t start_loc = free_page_ptr[last_page_idx];
  for (int64_t offset = 0; offset < num_part3 && offset < page_size; offset++) {
    int64_t output_idx = output_start_loc + num_part1 + num_part2 + offset;
    out_indices[output_idx] = start_loc * page_size + offset;
  }
}
// Host launcher for the decode-step KV-slot allocation kernel.
//
// seq_lens_ptr / last_loc_ptr / free_page_ptr: per-request inputs (see the
// kernel for their meaning); out_indices is written in place (bs entries).
// bs is the true batch size; bs_upper is kept in the signature for op-schema
// compatibility but is no longer forwarded (see fix note below).
//
// NOTE(review): last_loc_ptr is read as int32 here but as int64 in the
// extend launcher — confirm the Python-side dtypes really differ.
void dcu_alloc_decode_kernel(
    const at::Tensor seq_lens_ptr,
    const at::Tensor last_loc_ptr,
    const at::Tensor free_page_ptr,
    at::Tensor out_indices,
    int64_t bs,
    int64_t bs_upper,
    int64_t page_size) {
  // An empty batch would launch a zero-sized grid, which is a CUDA error.
  if (bs <= 0) {
    return;
  }
  const int64_t* seq_lens_ptr1 = static_cast<const int64_t*>(seq_lens_ptr.data_ptr());
  const int32_t* last_loc_ptr1 = static_cast<const int32_t*>(last_loc_ptr.data_ptr());
  const int64_t* free_page_ptr1 = static_cast<const int64_t*>(free_page_ptr.data_ptr());
  int64_t* out_indices1 = static_cast<int64_t*>(out_indices.data_ptr());
  int64_t block_size = 64;
  int64_t grid_size = (bs + block_size - 1) / block_size;
  cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
  // FIX: pass the true batch size `bs` as the kernel's thread guard. The
  // previous code passed bs_upper (= next_power_of_2(bs)), letting threads
  // with pid in [bs, bs_upper) read the bs-sized input tensors and write
  // out_indices out of bounds whenever bs was not a power of two.
  (void)bs_upper;  // retained only for op-schema compatibility
  launch_alloc_decode_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(
      seq_lens_ptr1, last_loc_ptr1, free_page_ptr1, out_indices1, bs, page_size);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
// Host launcher for the extend (prefill) KV-slot allocation kernel.
//
// pre_lens_ptr / seq_lens_ptr / last_loc_ptr / free_page_ptr: per-request
// inputs (see the kernel for their meaning); out_indices is written in place
// with one slot per newly extended token. bs is the true batch size;
// bs_upper is kept in the signature for op-schema compatibility but is no
// longer forwarded (see fix note below). max_num_extend_tokens bounds the
// kernel's full-page copy loop.
void dcu_alloc_extend_kernel(
    const at::Tensor pre_lens_ptr,
    const at::Tensor seq_lens_ptr,
    const at::Tensor last_loc_ptr,
    const at::Tensor free_page_ptr,
    at::Tensor out_indices,
    int64_t bs,
    int64_t bs_upper,
    int64_t page_size,
    int64_t max_num_extend_tokens) {
  // An empty batch would launch a zero-sized grid, which is a CUDA error.
  if (bs <= 0) {
    return;
  }
  const int64_t* pre_lens_ptr1 = static_cast<const int64_t*>(pre_lens_ptr.data_ptr());
  const int64_t* seq_lens_ptr1 = static_cast<const int64_t*>(seq_lens_ptr.data_ptr());
  const int64_t* last_loc_ptr1 = static_cast<const int64_t*>(last_loc_ptr.data_ptr());
  const int64_t* free_page_ptr1 = static_cast<const int64_t*>(free_page_ptr.data_ptr());
  int64_t* out_indices1 = static_cast<int64_t*>(out_indices.data_ptr());
  int64_t block_size = 64;
  int64_t grid_size = (bs + block_size - 1) / block_size;
  cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
  // FIX: pass the true batch size `bs` as the kernel's thread guard. The
  // previous code passed bs_upper (= next_power_of_2(bs)), letting threads
  // with pid in [bs, bs_upper) read the bs-sized input tensors and write
  // out_indices out of bounds whenever bs was not a power of two.
  (void)bs_upper;  // retained only for op-schema compatibility
  launch_alloc_extend_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(
      pre_lens_ptr1, seq_lens_ptr1, last_loc_ptr1, free_page_ptr1, out_indices1,
      bs, page_size, max_num_extend_tokens);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
\ No newline at end of file
......@@ -538,6 +538,26 @@ void segment_packbits(
/*
* From csrc/kvcacheio
*/
// Computes KV-cache slot indices for all newly extended (prefill) tokens of a
// batch; writes one slot per new token into out_indices. bs is the true batch
// size, bs_upper its next-power-of-two bound, page_size the tokens per page,
// max_num_extend_tokens a bound on any request's extend length.
void dcu_alloc_extend_kernel(
    const at::Tensor pre_lens_ptr,
    const at::Tensor seq_lens_ptr,
    const at::Tensor last_loc_ptr,
    const at::Tensor free_page_ptr,
    at::Tensor out_indices,
    int64_t bs,
    int64_t bs_upper,
    int64_t page_size,
    int64_t max_num_extend_tokens);
// Computes the KV-cache slot of the single new decode token of each request;
// writes bs entries into out_indices. Parameters as above.
void dcu_alloc_decode_kernel(
    const at::Tensor seq_lens_ptr,
    const at::Tensor last_loc_ptr,
    const at::Tensor free_page_ptr,
    at::Tensor out_indices,
    int64_t bs,
    int64_t bs_upper,
    int64_t page_size);
void transfer_kv_per_layer(
const at::Tensor src_k,
at::Tensor dst_k,
......
......@@ -10,6 +10,48 @@ def is_hip() -> bool:
_is_hip = is_hip()
def dcu_alloc_extend_kernel(
    pre_lens_ptr: torch.Tensor,
    seq_lens_ptr: torch.Tensor,
    last_loc_ptr: torch.Tensor,
    free_page_ptr: torch.Tensor,
    out_indices: torch.Tensor,
    bs: int,
    bs_upper: int,
    page_size: int,
    max_num_extend_tokens: int,
):
    """Fill ``out_indices`` with KV-cache slots for newly extended tokens.

    Thin pass-through to the ``sgl_kernel`` CUDA op; all work happens on the
    device and ``out_indices`` is modified in place.
    """
    op_args = (
        pre_lens_ptr,
        seq_lens_ptr,
        last_loc_ptr,
        free_page_ptr,
        out_indices,
        bs,
        bs_upper,
        page_size,
        max_num_extend_tokens,
    )
    torch.ops.sgl_kernel.dcu_alloc_extend_kernel(*op_args)
def dcu_alloc_decode_kernel(
    seq_lens_ptr: torch.Tensor,
    last_loc_ptr: torch.Tensor,
    free_page_ptr: torch.Tensor,
    out_indices: torch.Tensor,
    bs: int,
    bs_upper: int,
    page_size: int,
):
    """Fill ``out_indices`` with one KV-cache slot per decoding request.

    Thin pass-through to the ``sgl_kernel`` CUDA op; all work happens on the
    device and ``out_indices`` is modified in place.
    """
    op_args = (
        seq_lens_ptr,
        last_loc_ptr,
        free_page_ptr,
        out_indices,
        bs,
        bs_upper,
        page_size,
    )
    torch.ops.sgl_kernel.dcu_alloc_decode_kernel(*op_args)
def transfer_kv_per_layer(
src_k: torch.Tensor,
dst_k: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment