Unverified Commit a93f10a7 authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

[overlap-spec] support page size > 1 (#11772)

parent 585e1223
......@@ -14,6 +14,7 @@ from sglang.srt.managers.io_struct import (
BatchTokenIDOutput,
)
from sglang.srt.managers.schedule_batch import BaseFinishReason, Req, ScheduleBatch
from sglang.srt.utils.common import ceil_div
if TYPE_CHECKING:
from sglang.srt.managers.scheduler import (
......@@ -258,22 +259,22 @@ class SchedulerOutputProcessorMixin:
if self.enable_overlap and req.finished():
indices_to_free = None
if self.page_size == 1:
if batch.spec_algorithm.is_eagle():
from sglang.srt.speculative.eagle_info import EagleDraftInput
end_p = allocate_lens_list[i]
start_p = end_p - EagleDraftInput.ALLOC_LEN_PER_DECODE
indices_to_free = self.req_to_token_pool.req_to_token[
req.req_pool_idx
][start_p:end_p]
else:
if batch.spec_algorithm.is_eagle():
from sglang.srt.speculative.eagle_info import EagleDraftInput
end_p = allocate_lens_list[i]
start_p = end_p - EagleDraftInput.ALLOC_LEN_PER_DECODE
if self.page_size > 1:
start_p = ceil_div(start_p, self.page_size) * self.page_size
indices_to_free = self.req_to_token_pool.req_to_token[
req.req_pool_idx
][start_p:end_p]
else:
if self.page_size == 1:
# Free the one extra delayed token
indices_to_free = batch.out_cache_loc[i : i + 1]
else:
if batch.spec_algorithm.is_eagle():
# TODO(spec-v2): support eagle with page_size > 1
raise NotImplementedError()
else:
if (
len(req.origin_input_ids) + len(req.output_ids) - 1
......@@ -299,6 +300,10 @@ class SchedulerOutputProcessorMixin:
# 2) overlap eagle and the current batch is prefill. This seq will not run extra iteration.
start_p = batch.seq_lens_cpu[i] + accept_lens_list[i]
end_p = allocate_lens_list[i]
if self.page_size > 1:
start_p = ceil_div(start_p, self.page_size) * self.page_size
indices_to_free = self.req_to_token_pool.req_to_token[
req.req_pool_idx
][start_p:end_p]
......
......@@ -10,7 +10,11 @@ import triton.language as tl
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, ScheduleBatch
from sglang.srt.mem_cache.common import alloc_token_slots
from sglang.srt.mem_cache.common import (
alloc_paged_token_slots_extend,
alloc_token_slots,
get_last_loc,
)
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode,
......@@ -82,9 +86,31 @@ class EagleDraftInputV2Mixin:
# Now seq_lens and allocate_lens are correct
batch.maybe_wait_verify_done()
new_allocate_lens = batch.seq_lens + self.ALLOC_LEN_PER_DECODE
num_needed_tokens = (new_allocate_lens - self.allocate_lens).sum().item()
out_cache_loc = alloc_token_slots(batch.tree_cache, num_needed_tokens)
page_size = batch.token_to_kv_pool_allocator.page_size
if page_size == 1:
new_allocate_lens = batch.seq_lens + self.ALLOC_LEN_PER_DECODE
num_needed_tokens = (new_allocate_lens - self.allocate_lens).sum().item()
out_cache_loc = alloc_token_slots(batch.tree_cache, num_needed_tokens)
else:
last_loc = get_last_loc(
batch.req_to_token_pool.req_to_token,
batch.req_pool_indices,
self.allocate_lens,
)
new_allocate_lens = batch.seq_lens + self.ALLOC_LEN_PER_DECODE
new_allocate_lens_cpu = new_allocate_lens.cpu()
allocate_lens_cpu = self.allocate_lens.cpu()
extend_num_tokens = sum(new_allocate_lens_cpu - allocate_lens_cpu).item()
out_cache_loc = alloc_paged_token_slots_extend(
batch.tree_cache,
self.allocate_lens,
allocate_lens_cpu,
new_allocate_lens,
new_allocate_lens_cpu,
last_loc,
extend_num_tokens,
)
assign_req_to_token_pool[(bs,)](
batch.req_pool_indices,
......
Markdown is supported
Attach a file by drag &amp; drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment