Commit 11886dc8 authored by liucong
Browse files

限制dcu_alloc_extend_kernel的使用范围

parent ec78c4c5
...@@ -487,6 +487,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator): ...@@ -487,6 +487,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
(extend_num_tokens,), dtype=torch.int64, device=self.device (extend_num_tokens,), dtype=torch.int64, device=self.device
) )
if self.sglang_kvalloc_kernel: if self.sglang_kvalloc_kernel:
if bs < 3:
dcu_alloc_extend_kernel( dcu_alloc_extend_kernel(
pre_lens_ptr = prefix_lens, pre_lens_ptr = prefix_lens,
seq_lens_ptr = seq_lens, seq_lens_ptr = seq_lens,
...@@ -509,6 +510,17 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator): ...@@ -509,6 +510,17 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
self.page_size, self.page_size,
self.seen_max_num_extend_tokens_next_power_of_2, self.seen_max_num_extend_tokens_next_power_of_2,
) )
else:
alloc_extend_kernel[(bs,)](
prefix_lens,
seq_lens,
last_loc,
self.free_pages,
out_indices,
next_power_of_2(bs),
self.page_size,
self.seen_max_num_extend_tokens_next_power_of_2,
)
if self.debug_mode: if self.debug_mode:
assert len(torch.unique(out_indices)) == len(out_indices) assert len(torch.unique(out_indices)) == len(out_indices)
......
...@@ -664,7 +664,7 @@ __global__ void launch_alloc_extend_kernel( ...@@ -664,7 +664,7 @@ __global__ void launch_alloc_extend_kernel(
int64_t last_loc = last_loc_ptr[pid]; int64_t last_loc = last_loc_ptr[pid];
int64_t num_part1 = safe_min(seq_len, ceil_div(pre_len, page_size) * page_size) - pre_len; int64_t num_part1 = safe_min(seq_len, ceil_div(pre_len, page_size) * page_size) - pre_len;
for (int64_t offset = 0; offset < num_part1; offset++) { for (int64_t offset = 0; offset < num_part1 && offset < page_size; offset++) {
int64_t output_idx = output_start_loc + offset; int64_t output_idx = output_start_loc + offset;
out_indices[output_idx] = last_loc + 1 + offset; out_indices[output_idx] = last_loc + 1 + offset;
} }
...@@ -674,7 +674,7 @@ __global__ void launch_alloc_extend_kernel( ...@@ -674,7 +674,7 @@ __global__ void launch_alloc_extend_kernel(
} }
int64_t num_part2 = (seq_len / page_size) * page_size - ceil_div(pre_len, page_size) * page_size; int64_t num_part2 = (seq_len / page_size) * page_size - ceil_div(pre_len, page_size) * page_size;
for (int64_t offset = 0; offset < num_part2; offset++) { for (int64_t offset = 0; offset < num_part2 && offset < max_num_extend_tokens; offset++) {
int64_t page_idx = new_page_start_loc + offset / page_size; int64_t page_idx = new_page_start_loc + offset / page_size;
int64_t page_start = free_page_ptr[page_idx]; int64_t page_start = free_page_ptr[page_idx];
int64_t output_idx = output_start_loc + num_part1 + offset; int64_t output_idx = output_start_loc + num_part1 + offset;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment