Commit 93eb92f8 authored by maxiao1

Merge branch 'v0.5.4_dev_liucong' into 'v0.5.4_dev'

V0.5.4 dev liucong

See merge request OpenDAS/sglang!16
parents 698bc661 5f643074
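In short, this merge drops the `bs_upper` and `max_num_extend_tokens` arguments from the DCU paged KV allocation kernels (`dcu_alloc_extend_kernel` / `dcu_alloc_decode_kernel`): the host side already launches `ceil(bs / 64)` blocks of 64 threads, so the kernels can guard with `bs` directly instead of a `next_power_of_2(bs)` upper bound, and the `bs < 3` special-casing in the Python allocator goes away. Below is a minimal sketch of a call against the new Python signature; the import path and the tensor values are illustrative assumptions, not part of this merge request, and running it requires the DCU build of sgl_kernel.

```python
import torch

# Import path assumed for illustration only.
from sgl_kernel.kvcacheio import dcu_alloc_extend_kernel

page_size, bs = 16, 2
prefix_lens = torch.tensor([0, 16], dtype=torch.int64, device="cuda")  # cached tokens per request
seq_lens = torch.tensor([8, 40], dtype=torch.int64, device="cuda")     # total tokens after extend
last_loc = torch.tensor([-1, 255], dtype=torch.int64, device="cuda")   # last slot of each prefix (illustrative)
free_pages = torch.arange(1, 1025, dtype=torch.int64, device="cuda")   # indices of free KV pages

extend_num_tokens = int((seq_lens - prefix_lens).sum())
out_indices = torch.empty((extend_num_tokens,), dtype=torch.int64, device="cuda")

# New signature: bs_upper and max_num_extend_tokens are gone.
dcu_alloc_extend_kernel(
    pre_lens_ptr=prefix_lens,
    seq_lens_ptr=seq_lens,
    last_loc_ptr=last_loc,
    free_page_ptr=free_pages,
    out_indices=out_indices,
    bs=bs,
    page_size=page_size,
)
```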
@@ -487,29 +487,15 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
             (extend_num_tokens,), dtype=torch.int64, device=self.device
         )
         if self.sglang_kvalloc_kernel:
-            if bs < 3:
-                dcu_alloc_extend_kernel(
-                    pre_lens_ptr = prefix_lens,
-                    seq_lens_ptr = seq_lens,
-                    last_loc_ptr = last_loc,
-                    free_page_ptr = self.free_pages,
-                    out_indices = out_indices,
-                    bs = bs,
-                    bs_upper = next_power_of_2(bs),
-                    page_size = self.page_size,
-                    max_num_extend_tokens = self.seen_max_num_extend_tokens_next_power_of_2,
-                )
-            else:
-                alloc_extend_kernel[(bs,)](
-                    prefix_lens,
-                    seq_lens,
-                    last_loc,
-                    self.free_pages,
-                    out_indices,
-                    next_power_of_2(bs),
-                    self.page_size,
-                    self.seen_max_num_extend_tokens_next_power_of_2,
-                )
+            dcu_alloc_extend_kernel(
+                pre_lens_ptr = prefix_lens,
+                seq_lens_ptr = seq_lens,
+                last_loc_ptr = last_loc,
+                free_page_ptr = self.free_pages,
+                out_indices = out_indices,
+                bs = bs,
+                page_size = self.page_size,
+            )
         else:
             alloc_extend_kernel[(bs,)](
                 prefix_lens,
@@ -560,7 +546,6 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
                 free_page_ptr = self.free_pages,
                 out_indices = out_indices,
                 bs = bs,
-                bs_upper = next_power_of_2(bs),
                 page_size = self.page_size,
             )
         else:

@@ -131,9 +131,9 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
   /*
    * From csrc/kvcacheio
    */
-  m.def("dcu_alloc_extend_kernel(Tensor pre_lens_ptr, Tensor seq_lens_ptr, Tensor last_loc_ptr, Tensor free_page_ptr, Tensor out_indices, int bs, int bs_upper, int page_size, int max_num_extend_tokens) -> ()");
+  m.def("dcu_alloc_extend_kernel(Tensor pre_lens_ptr, Tensor seq_lens_ptr, Tensor last_loc_ptr, Tensor free_page_ptr, Tensor out_indices, int bs, int page_size) -> ()");
   m.impl("dcu_alloc_extend_kernel", torch::kCUDA, &dcu_alloc_extend_kernel);
-  m.def("dcu_alloc_decode_kernel(Tensor seq_lens_ptr, Tensor last_loc_ptr, Tensor free_page_ptr, Tensor out_indices, int bs, int bs_upper, int page_size) -> ()");
+  m.def("dcu_alloc_decode_kernel(Tensor seq_lens_ptr, Tensor last_loc_ptr, Tensor free_page_ptr, Tensor out_indices, int bs, int page_size) -> ()");
   m.impl("dcu_alloc_decode_kernel", torch::kCUDA, &dcu_alloc_decode_kernel);
   m.def(
       "transfer_kv_per_layer(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "

@@ -585,12 +585,12 @@ __global__ void launch_alloc_decode_kernel(
     const int32_t* last_loc_ptr,
     const int64_t* free_page_ptr,
     int64_t* out_indices,
-    int64_t bs_upper,
+    int64_t bs,
     int64_t page_size) {
   int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
-  if (pid >= bs_upper) return;
+  if (pid >= bs) return;
   int64_t seq_len = seq_lens_ptr[pid];
   int64_t pre_len = seq_len - 1;
@@ -625,13 +625,12 @@ __global__ void launch_alloc_extend_kernel(
     const int64_t* last_loc_ptr,
     const int64_t* free_page_ptr,
     int64_t* out_indices,
-    int64_t bs_upper,
-    int64_t page_size,
-    int64_t max_num_extend_tokens)
+    int64_t bs,
+    int64_t page_size)
 {
   int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
-  if (pid >= bs_upper) return;
+  if (pid >= bs) return;
   int64_t seq_len = seq_lens_ptr[pid];
   int64_t pre_len = pre_lens_ptr[pid];
@@ -674,7 +673,7 @@ __global__ void launch_alloc_extend_kernel(
   }
   int64_t num_part2 = (seq_len / page_size) * page_size - ceil_div(pre_len, page_size) * page_size;
-  for (int64_t offset = 0; offset < num_part2 && offset < max_num_extend_tokens; offset++) {
+  for (int64_t offset = 0; offset < num_part2; offset++) {
     int64_t page_idx = new_page_start_loc + offset / page_size;
     int64_t page_start = free_page_ptr[page_idx];
     int64_t output_idx = output_start_loc + num_part1 + offset;
@@ -701,7 +700,6 @@ void dcu_alloc_decode_kernel(
     const at::Tensor free_page_ptr,
     at::Tensor out_indices,
     int64_t bs,
-    int64_t bs_upper,
     int64_t page_size) {
   const int64_t* seq_lens_ptr1 = static_cast<const int64_t*>(seq_lens_ptr.data_ptr());
@@ -712,7 +710,7 @@ void dcu_alloc_decode_kernel(
   int64_t block_size = 64;
   int64_t grid_size = (bs + block_size - 1) / block_size;
   cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
-  launch_alloc_decode_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(seq_lens_ptr1, last_loc_ptr1, free_page_ptr1, out_indices1, bs_upper, page_size);
+  launch_alloc_decode_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(seq_lens_ptr1, last_loc_ptr1, free_page_ptr1, out_indices1, bs, page_size);
   C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
@@ -723,9 +721,7 @@ void dcu_alloc_extend_kernel(
     const at::Tensor free_page_ptr,
     at::Tensor out_indices,
     int64_t bs,
-    int64_t bs_upper,
-    int64_t page_size,
-    int64_t max_num_extend_tokens) {
+    int64_t page_size) {
   const int64_t* pre_lens_ptr1 = static_cast<const int64_t*>(pre_lens_ptr.data_ptr());
   const int64_t* seq_lens_ptr1 = static_cast<const int64_t*>(seq_lens_ptr.data_ptr());
@@ -736,6 +732,6 @@ void dcu_alloc_extend_kernel(
   int64_t block_size = 64;
   int64_t grid_size = (bs + block_size - 1) / block_size;
   cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
-  launch_alloc_extend_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(pre_lens_ptr1, seq_lens_ptr1, last_loc_ptr1, free_page_ptr1, out_indices1, bs_upper, page_size, max_num_extend_tokens);
+  launch_alloc_extend_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(pre_lens_ptr1, seq_lens_ptr1, last_loc_ptr1, free_page_ptr1, out_indices1, bs, page_size);
   C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
\ No newline at end of file

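For context on why the `bs`-based guard in the kernels above is sufficient: the host wrappers launch `grid_size = (bs + 63) / 64` blocks of 64 threads, so the launched thread ids always cover `0..bs-1` and may overshoot by at most one block, and `if (pid >= bs) return;` simply retires the overshoot. A small plain-Python illustration of that launch arithmetic (not part of the merge request):

```python
def threads_launched(bs: int, block_size: int = 64) -> int:
    """Mirror the host-side math: ceil(bs / block_size) blocks of block_size threads."""
    grid_size = (bs + block_size - 1) // block_size
    return grid_size * block_size

for bs in (1, 2, 64, 65, 200):
    total = threads_launched(bs)
    # Threads with pid >= bs exit immediately via the in-kernel guard.
    print(f"bs={bs:3d}  launched={total:3d}  guarded out={total - bs}")
```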
@@ -545,9 +545,7 @@ void dcu_alloc_extend_kernel(
     const at::Tensor free_page_ptr,
     at::Tensor out_indices,
     int64_t bs,
-    int64_t bs_upper,
-    int64_t page_size,
-    int64_t max_num_extend_tokens);
+    int64_t page_size);
 
 void dcu_alloc_decode_kernel(
     const at::Tensor seq_lens_ptr,
@@ -555,7 +553,6 @@ void dcu_alloc_decode_kernel(
     const at::Tensor free_page_ptr,
     at::Tensor out_indices,
     int64_t bs,
-    int64_t bs_upper,
     int64_t page_size);
 
 void transfer_kv_per_layer(

@@ -17,9 +17,7 @@ def dcu_alloc_extend_kernel(
     free_page_ptr: torch.Tensor,
     out_indices: torch.Tensor,
     bs: int,
-    bs_upper: int,
     page_size: int,
-    max_num_extend_tokens: int,
 ):
     torch.ops.sgl_kernel.dcu_alloc_extend_kernel(
         pre_lens_ptr,
@@ -28,9 +26,7 @@ def dcu_alloc_extend_kernel(
         free_page_ptr,
         out_indices,
         bs,
-        bs_upper,
         page_size,
-        max_num_extend_tokens,
     )
 
 def dcu_alloc_decode_kernel(
@@ -39,7 +35,6 @@ def dcu_alloc_decode_kernel(
     free_page_ptr: torch.Tensor ,
     out_indices: torch.Tensor ,
     bs: int,
-    bs_upper: int,
     page_size: int,
 ):
     torch.ops.sgl_kernel.dcu_alloc_decode_kernel(
@@ -48,7 +43,6 @@ def dcu_alloc_decode_kernel(
         free_page_ptr,
         out_indices,
         bs,
-        bs_upper,
         page_size,
     )