Commit 11886dc8 authored by liucong
Browse files

限制dcu_alloc_extend_kernel的使用范围

parent ec78c4c5
......@@ -487,17 +487,29 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
(extend_num_tokens,), dtype=torch.int64, device=self.device
)
if self.sglang_kvalloc_kernel:
dcu_alloc_extend_kernel(
pre_lens_ptr = prefix_lens,
seq_lens_ptr = seq_lens,
last_loc_ptr = last_loc,
free_page_ptr = self.free_pages,
out_indices = out_indices,
bs = bs,
bs_upper = next_power_of_2(bs),
page_size = self.page_size,
max_num_extend_tokens = self.seen_max_num_extend_tokens_next_power_of_2,
)
if bs < 3:
dcu_alloc_extend_kernel(
pre_lens_ptr = prefix_lens,
seq_lens_ptr = seq_lens,
last_loc_ptr = last_loc,
free_page_ptr = self.free_pages,
out_indices = out_indices,
bs = bs,
bs_upper = next_power_of_2(bs),
page_size = self.page_size,
max_num_extend_tokens = self.seen_max_num_extend_tokens_next_power_of_2,
)
else:
alloc_extend_kernel[(bs,)](
prefix_lens,
seq_lens,
last_loc,
self.free_pages,
out_indices,
next_power_of_2(bs),
self.page_size,
self.seen_max_num_extend_tokens_next_power_of_2,
)
else:
alloc_extend_kernel[(bs,)](
prefix_lens,
......
......@@ -664,7 +664,7 @@ __global__ void launch_alloc_extend_kernel(
int64_t last_loc = last_loc_ptr[pid];
int64_t num_part1 = safe_min(seq_len, ceil_div(pre_len, page_size) * page_size) - pre_len;
for (int64_t offset = 0; offset < num_part1; offset++) {
for (int64_t offset = 0; offset < num_part1 && offset < page_size; offset++) {
int64_t output_idx = output_start_loc + offset;
out_indices[output_idx] = last_loc + 1 + offset;
}
......@@ -674,7 +674,7 @@ __global__ void launch_alloc_extend_kernel(
}
int64_t num_part2 = (seq_len / page_size) * page_size - ceil_div(pre_len, page_size) * page_size;
for (int64_t offset = 0; offset < num_part2; offset++) {
for (int64_t offset = 0; offset < num_part2 && offset < max_num_extend_tokens; offset++) {
int64_t page_idx = new_page_start_loc + offset / page_size;
int64_t page_start = free_page_ptr[page_idx];
int64_t output_idx = output_start_loc + num_part1 + offset;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment