"...text-generation-inference.git" did not exist on "8f99f165ce1a261c89ea2edef437ef23c03a0716"
Unverified commit d3d4d767, authored by Ying Sheng and committed by GitHub

[Eagle] Refactor eagle speculative decoding (#3986)


Co-authored-by: Ke Bao <ISPObaoke@163.com>
parent 5be8f1ed
@@ -230,7 +230,7 @@ def extend(reqs, model_runner):
     batch = ScheduleBatch.init_new(
         reqs=reqs,
         req_to_token_pool=model_runner.req_to_token_pool,
-        token_to_kv_pool=model_runner.token_to_kv_pool,
+        token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator,
         tree_cache=None,
         model_config=model_runner.model_config,
         enable_overlap=False,
@@ -326,7 +326,7 @@ def latency_test_run_once(
     # Clear the pools.
     model_runner.req_to_token_pool.clear()
-    model_runner.token_to_kv_pool.clear()
+    model_runner.token_to_kv_pool_allocator.clear()
     measurement_results = {
         "run_name": run_name,
...
@@ -20,14 +20,15 @@ import triton.language as tl
 from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
 from sglang.srt.utils import is_flashinfer_available
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.spec_info import SpecInfo
 if is_flashinfer_available():
     from flashinfer import (
@@ -36,6 +37,7 @@ if is_flashinfer_available():
         BatchPrefillWithRaggedKVCacheWrapper,
     )
     from flashinfer.cascade import merge_state
+    from flashinfer.decode import PosEncodingMode
 class WrapperDispatch(Enum):
@@ -113,6 +115,7 @@ class FlashInferAttnBackend(AttentionBackend):
                 device=model_runner.device,
             )
             self.workspace_buffer = global_workspace_buffer
         max_bs = model_runner.req_to_token_pool.size
         if kv_indptr_buf is None:
             self.kv_indptr = [
@@ -133,10 +136,13 @@ class FlashInferAttnBackend(AttentionBackend):
             assert self.num_wrappers == 1
             self.kv_last_page_len = kv_last_page_len_buf
-        self.qo_indptr = [
-            torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device)
-            for _ in range(self.num_wrappers)
-        ]
+        if not self.skip_prefill:
+            self.qo_indptr = [
+                torch.zeros(
+                    (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+                )
+                for _ in range(self.num_wrappers)
+            ]
         self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
             self.workspace_buffer, "NHD"
@@ -276,7 +282,7 @@ class FlashInferAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         if forward_mode.is_decode_or_idle():
             decode_wrappers = []
@@ -346,7 +352,7 @@ class FlashInferAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         if forward_mode.is_decode_or_idle():
             self.indices_updater_decode.update(
@@ -526,7 +532,7 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
@@ -538,7 +544,7 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         decode_wrappers = decode_wrappers or self.decode_wrappers
         self.call_begin_forward(
@@ -558,7 +564,7 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -592,7 +598,7 @@ class FlashInferIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -623,7 +629,7 @@ class FlashInferIndicesUpdaterDecode:
         paged_kernel_lens_sum: int,
         kv_indptr: torch.Tensor,
         kv_start_idx: torch.Tensor,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         if spec_info is None:
             bs = len(req_pool_indices)
@@ -642,9 +648,9 @@ class FlashInferIndicesUpdaterDecode:
                 self.req_to_token.shape[1],
             )
         else:
+            assert isinstance(spec_info, EagleDraftInput)
             kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
             bs = kv_indptr.shape[0] - 1
         wrapper.begin_forward(
             kv_indptr,
             kv_indices,
@@ -699,7 +705,7 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
@@ -713,7 +719,7 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         if use_ragged:
             paged_kernel_lens = prefix_lens
@@ -746,7 +752,7 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -787,7 +793,7 @@ class FlashInferIndicesUpdaterPrefill:
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         for wrapper_id in range(2):
             if wrapper_id == 0:
@@ -829,10 +835,11 @@ class FlashInferIndicesUpdaterPrefill:
         kv_indptr: torch.Tensor,
         qo_indptr: torch.Tensor,
         use_ragged: bool,
-        spec_info: Optional[SpecInfo],
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
-        bs = len(req_pool_indices)
+        bs = len(seq_lens)
         if spec_info is None:
+            assert len(seq_lens) == len(req_pool_indices)
             # Normal extend
             kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
@@ -855,10 +862,14 @@ class FlashInferIndicesUpdaterPrefill:
             qo_indptr = qo_indptr[: bs + 1]
             custom_mask = None
         else:
+            assert isinstance(spec_info, EagleDraftInput) or isinstance(
+                spec_info, EagleVerifyInput
+            )
             kv_indices, kv_indptr, qo_indptr, custom_mask = (
                 spec_info.generate_attn_arg_prefill(
                     req_pool_indices,
                     paged_kernel_lens,
+                    paged_kernel_lens_sum,
                     self.req_to_token,
                 )
             )
@@ -890,6 +901,11 @@ class FlashInferIndicesUpdaterPrefill:
         )
+# Use as a fast path to override the indptr in flashinfer's plan function
+# This is used to remove some host-to-device copy overhead.
+global global_override_indptr_cpu
 class FlashInferMultiStepDraftBackend:
     """
     Wrap multiple flashinfer attention backends as one for multiple consecutive
@@ -907,6 +923,7 @@ class FlashInferMultiStepDraftBackend:
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
         self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
         max_bs = model_runner.req_to_token_pool.size * self.topk
         self.kv_indptr = torch.zeros(
             (
@@ -929,7 +946,9 @@ class FlashInferMultiStepDraftBackend:
                     kv_last_page_len_buf=self.kv_last_page_len,
                 )
             )
         self.max_context_len = self.attn_backends[0].max_context_len
         # Cached variables for generate_draft_decode_kv_indices
         self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
@@ -959,13 +978,23 @@ class FlashInferMultiStepDraftBackend:
             triton.next_power_of_2(bs),
         )
+        assert forward_batch.spec_info is not None
+        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+        # Copy the kv_indptr once to avoid multiple device-to-host copies in flashinfer's plan.
+        indptr_cpu_whole = self.kv_indptr[:, : bs + 1].cpu()
+        global global_override_indptr_cpu
         for i in range(self.speculative_num_steps - 1):
             forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
             forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
                 : seq_lens_sum * self.topk + bs * (i + 1)
             ]
+            global_override_indptr_cpu = indptr_cpu_whole[i]
             call_fn(i, forward_batch)
+        global_override_indptr_cpu = None
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         kv_indices = torch.zeros(
             (
@@ -977,6 +1006,8 @@ class FlashInferMultiStepDraftBackend:
         )
         def call_fn(i, forward_batch):
+            assert forward_batch.spec_info is not None
+            assert isinstance(forward_batch.spec_info, EagleDraftInput)
             forward_batch.spec_info.kv_indptr = (
                 forward_batch.spec_info.kv_indptr.clone()
             )
@@ -993,6 +1024,7 @@ class FlashInferMultiStepDraftBackend:
             dtype=torch.int32,
             device="cuda",
         )
         for i in range(self.speculative_num_steps):
             self.attn_backends[i].init_cuda_graph_state(
                 max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
@@ -1031,43 +1063,6 @@ class FlashInferMultiStepDraftBackend:
         self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
-@triton.jit
-def create_flashinfer_kv_indices_triton(
-    req_to_token_ptr,  # [max_batch, max_context_len]
-    req_pool_indices_ptr,
-    page_kernel_lens_ptr,
-    kv_indptr,
-    kv_start_idx,
-    kv_indices_ptr,
-    req_to_token_ptr_stride: tl.constexpr,
-):
-    BLOCK_SIZE: tl.constexpr = 512
-    pid = tl.program_id(axis=0)
-    req_pool_index = tl.load(req_pool_indices_ptr + pid)
-    kv_indices_offset = tl.load(kv_indptr + pid)
-    kv_start = 0
-    kv_end = 0
-    if kv_start_idx:
-        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
-        kv_end = kv_start
-    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
-    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
-    for i in range(num_loop):
-        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
-        mask = offset < kv_end - kv_start
-        data = tl.load(
-            req_to_token_ptr
-            + req_pool_index * req_to_token_ptr_stride
-            + kv_start
-            + offset,
-            mask=mask,
-        )
-        tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
 def should_use_tensor_core(
     kv_cache_dtype: torch.dtype,
     num_attention_heads: int,
@@ -1089,6 +1084,21 @@ def should_use_tensor_core(
     if env_override is not None:
         return env_override.lower() == "true"
+    # Try to use _grouped_size_compiled_for_decode_kernels if available
+    # This is for flashinfer <=0.1.6. Otherwise, there is an accuracy bug
+    try:
+        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+        if not _grouped_size_compiled_for_decode_kernels(
+            num_attention_heads,
+            num_kv_heads,
+        ):
+            return True
+        else:
+            return False
+    except (ImportError, AttributeError):
+        pass
     # Calculate GQA group size
     gqa_group_size = num_attention_heads // num_kv_heads
@@ -1118,12 +1128,18 @@ def fast_decode_plan(
     sm_scale: Optional[float] = None,
     rope_scale: Optional[float] = None,
     rope_theta: Optional[float] = None,
-    **kwargs,
+    non_blocking: bool = True,
 ) -> None:
-    """A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend."""
+    """
+    A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend.
+    Modifications:
+    - Remove unnecessary device-to-device copy for the cuda graph buffers.
+    - Remove unnecessary host-to-device copy for the metadata buffers.
+    """
     batch_size = len(last_page_len)
     if logits_soft_cap is None:
         logits_soft_cap = 0.0
     if self.is_cuda_graph_enabled:
         if batch_size != self._fixed_batch_size:
             raise ValueError(
@@ -1136,13 +1152,19 @@ def fast_decode_plan(
             raise ValueError(
                 "The size of indices should be less than or equal to the allocated buffer"
             )
+        # Skip these copies
+        # self._paged_kv_indptr_buf.copy_(indptr)
+        # self._paged_kv_indices_buf[: len(indices)] = indices
+        # self._paged_kv_last_page_len_buf.copy_(last_page_len)
     else:
         self._paged_kv_indptr_buf = indptr
         self._paged_kv_indices_buf = indices
         self._paged_kv_last_page_len_buf = last_page_len
     # NOTE(Zihao): the following tensors acts as placeholder to pass dtype info
     if not q_data_type:
         q_data_type = data_type
     if not hasattr(self, "empty_q_data"):
         self.empty_q_data = torch.empty(
             0,
@@ -1159,6 +1181,7 @@ def fast_decode_plan(
             ),
         )
         self.last_page_len = torch.ones(32768, dtype=torch.int32)
     empty_q_data = self.empty_q_data
     empty_kv_cache = self.empty_kv_cache
     stream = torch.cuda.current_stream()
...
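A note on the indptr fast path introduced in the flashinfer hunks above: common_template copies the stacked kv_indptr tensor to the host once, then points the module-level global_override_indptr_cpu at the per-step slice so the patched plan call can skip a device-to-host copy on every draft step. Below is a minimal, self-contained sketch of that pattern; the planner function is a toy stand-in, not flashinfer's real plan.

import torch

# Module-level override, mirroring global_override_indptr_cpu in the diff.
# When set, the planner reuses this host-side indptr instead of copying from device.
global_override_indptr_cpu = None


def plan_decode(kv_indptr_device: torch.Tensor) -> torch.Tensor:
    """Stand-in for a planning step that needs indptr on the host."""
    if global_override_indptr_cpu is not None:
        # Fast path: reuse the host copy prepared once by the caller.
        return global_override_indptr_cpu
    # Slow path: one device-to-host copy (and sync) per call.
    return kv_indptr_device.cpu()


def run_draft_steps(kv_indptr: torch.Tensor, num_steps: int) -> None:
    """Copy the whole [num_steps, bs + 1] indptr to host once, then override per step."""
    global global_override_indptr_cpu
    indptr_cpu_whole = kv_indptr.cpu()    # single copy shared by all steps
    for i in range(num_steps - 1):
        global_override_indptr_cpu = indptr_cpu_whole[i]
        plan_decode(kv_indptr[i])         # hits the fast path
    global_override_indptr_cpu = None     # restore the default behavior


if __name__ == "__main__":
    steps, bs = 4, 3
    run_draft_steps(torch.zeros((steps, bs + 1), dtype=torch.int32), steps)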
@@ -156,6 +156,7 @@ class TritonAttnBackend(AttentionBackend):
                 spec_info.generate_attn_arg_prefill(
                     forward_batch.req_pool_indices,
                     forward_batch.seq_lens,
+                    None,
                     self.req_to_token,
                 )
             )
...
@@ -22,7 +22,7 @@ from typing import List, Optional
 import torch
-from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPoolHost
+from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MHATokenToKVPoolHost
 logger = logging.getLogger(__name__)
@@ -128,7 +128,7 @@ class HiCacheController:
     def __init__(
         self,
         mem_pool_device: MHATokenToKVPool,
-        mem_pool_host: MLATokenToKVPoolHost,
+        mem_pool_host: MHATokenToKVPoolHost,
         write_policy: str = "write_through_selective",
     ):
...
@@ -44,18 +44,16 @@ from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
-from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
 if TYPE_CHECKING:
-    from sglang.srt.server_args import ServerArgs
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
     from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 # Put some global args for easy access
@@ -523,7 +521,7 @@ class ScheduleBatch:
     # Request, memory pool, and cache
     reqs: List[Req]
     req_to_token_pool: ReqToTokenPool = None
-    token_to_kv_pool: BaseTokenToKVPool = None
+    token_to_kv_pool_allocator: TokenToKVPoolAllocator = None
     tree_cache: BasePrefixCache = None
     # Batch configs
@@ -596,7 +594,7 @@ class ScheduleBatch:
         cls,
         reqs: List[Req],
         req_to_token_pool: ReqToTokenPool,
-        token_to_kv_pool: ReqToTokenPool,
+        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
         tree_cache: BasePrefixCache,
         model_config: ModelConfig,
         enable_overlap: bool,
@@ -606,7 +604,7 @@ class ScheduleBatch:
         return cls(
             reqs=reqs,
             req_to_token_pool=req_to_token_pool,
-            token_to_kv_pool=token_to_kv_pool,
+            token_to_kv_pool_allocator=token_to_kv_pool_allocator,
             tree_cache=tree_cache,
             model_config=model_config,
             enable_overlap=enable_overlap,
@@ -637,19 +635,19 @@ class ScheduleBatch:
         return req_pool_indices
     def alloc_token_slots(self, num_tokens: int):
-        out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)
+        out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
         if out_cache_loc is None:
             if self.tree_cache is not None:
-                self.tree_cache.evict(num_tokens, self.token_to_kv_pool.free)
-                out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)
+                self.tree_cache.evict(num_tokens, self.token_to_kv_pool_allocator.free)
+                out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
             if out_cache_loc is None:
                 phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
                 logger.error(
                     f"{phase_str} out of memory. Try to lower your batch size.\n"
                     f"Try to allocate {num_tokens} tokens.\n"
-                    f"Avaliable tokens: {self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()}\n"
+                    f"Avaliable tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
                 )
                 if self.tree_cache is not None:
                     self.tree_cache.pretty_print()
@@ -917,12 +915,12 @@ class ScheduleBatch:
     def check_decode_mem(self, buf_multiplier=1):
         bs = len(self.reqs) * buf_multiplier
-        if self.token_to_kv_pool.available_size() >= bs:
+        if self.token_to_kv_pool_allocator.available_size() >= bs:
             return True
-        self.tree_cache.evict(bs, self.token_to_kv_pool.free)
-        if self.token_to_kv_pool.available_size() >= bs:
+        self.tree_cache.evict(bs, self.token_to_kv_pool_allocator.free)
+        if self.token_to_kv_pool_allocator.available_size() >= bs:
             return True
         return False
@@ -945,6 +943,10 @@ class ScheduleBatch:
             reverse=True,
         )
+        retracted_reqs = []
+        seq_lens_cpu = self.seq_lens.cpu().numpy()
+        first_iter = True
         def get_required_tokens(num_reqs: int):
             headroom_for_spec_decode = 0
             if server_args.speculative_algorithm:
@@ -958,18 +960,15 @@ class ScheduleBatch:
                 num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
             )
-        retracted_reqs = []
-        seq_lens_cpu = self.seq_lens.cpu().numpy()
-        first_iter = True
         while (
-            self.token_to_kv_pool.available_size()
+            self.token_to_kv_pool_allocator.available_size()
             < get_required_tokens(len(sorted_indices))
            or first_iter
         ):
             if len(sorted_indices) == 1:
                 # Corner case: only one request left
                 assert (
-                    self.token_to_kv_pool.available_size() > 0
+                    self.token_to_kv_pool_allocator.available_size() > 0
                 ), "No space left for only one request"
                 break
@@ -983,7 +982,7 @@ class ScheduleBatch:
                 token_indices = self.req_to_token_pool.req_to_token[
                     req.req_pool_idx, : seq_lens_cpu[idx]
                 ]
-                self.token_to_kv_pool.free(token_indices)
+                self.token_to_kv_pool_allocator.free(token_indices)
                 self.req_to_token_pool.free(req.req_pool_idx)
                 del self.tree_cache.entries[req.rid]
             else:
@@ -992,7 +991,7 @@ class ScheduleBatch:
                 token_indices = self.req_to_token_pool.req_to_token[
                     req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
                 ]
-                self.token_to_kv_pool.free(token_indices)
+                self.token_to_kv_pool_allocator.free(token_indices)
                 self.req_to_token_pool.free(req.req_pool_idx)
                 # release the last node
@@ -1001,10 +1000,13 @@ class ScheduleBatch:
                 # NOTE(lsyin): we should use the newly evictable memory instantly.
                 residual_size = (
                     len(sorted_indices) * global_config.retract_decode_steps
-                    - self.token_to_kv_pool.available_size()
+                    - self.token_to_kv_pool_allocator.available_size()
                 )
                 residual_size = max(0, residual_size)
-                self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)
+                self.tree_cache.evict(
+                    residual_size, self.token_to_kv_pool_allocator.free
+                )
             req.reset_for_retract()
         self.filter_batch(keep_indices=sorted_indices)
@@ -1183,7 +1185,7 @@ class ScheduleBatch:
         if self.spec_info:
             self.spec_info.merge_batch(other.spec_info)
-    def get_model_worker_batch(self):
+    def get_model_worker_batch(self) -> ModelWorkerBatch:
         if self.forward_mode.is_decode_or_idle():
             extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
         else:
@@ -1273,7 +1275,7 @@ class ModelWorkerBatch:
     req_pool_indices: torch.Tensor
     # The sequence length
     seq_lens: torch.Tensor
-    # The indices of output tokens in the token_to_kv_pool
+    # The indices of output tokens in the token_to_kv_pool_allocator
     out_cache_loc: torch.Tensor
     # The sum of all sequence lengths
...
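The alloc_token_slots hunk above keeps the existing allocate, evict, retry flow and only reroutes it through token_to_kv_pool_allocator. A minimal sketch of that flow with toy stand-ins for the allocator and the tree cache follows; it is illustrative only, the real classes live in sglang.srt.mem_cache.

from typing import Callable, List, Optional


class ToyAllocator:
    """Hands out token slot indices; stands in for TokenToKVPoolAllocator."""

    def __init__(self, size: int):
        self.free_slots: List[int] = list(range(size))

    def available_size(self) -> int:
        return len(self.free_slots)

    def alloc(self, num_tokens: int) -> Optional[List[int]]:
        if num_tokens > len(self.free_slots):
            return None  # signal out-of-memory, as in the diff
        out, self.free_slots = self.free_slots[:num_tokens], self.free_slots[num_tokens:]
        return out

    def free(self, slots: List[int]) -> None:
        self.free_slots.extend(slots)


class ToyTreeCache:
    """Holds evictable cached slots; stands in for the radix tree cache."""

    def __init__(self, cached_slots: List[int]):
        self.cached_slots = cached_slots

    def evictable_size(self) -> int:
        return len(self.cached_slots)

    def evict(self, num_tokens: int, free_fn: Callable[[List[int]], None]) -> None:
        victims, self.cached_slots = self.cached_slots[:num_tokens], self.cached_slots[num_tokens:]
        free_fn(victims)  # return evicted slots to the allocator


def alloc_token_slots(allocator: ToyAllocator, tree_cache: ToyTreeCache, num_tokens: int):
    out_cache_loc = allocator.alloc(num_tokens)
    if out_cache_loc is None:
        # Evict from the prefix cache back into the allocator, then retry once.
        tree_cache.evict(num_tokens, allocator.free)
        out_cache_loc = allocator.alloc(num_tokens)
    if out_cache_loc is None:
        raise RuntimeError(
            f"Out of memory: need {num_tokens}, "
            f"available {allocator.available_size() + tree_cache.evictable_size()}"
        )
    return out_cache_loc


if __name__ == "__main__":
    alloc = ToyAllocator(size=4)
    cache = ToyTreeCache(cached_slots=[100, 101, 102, 103])
    print(alloc_token_slots(alloc, cache, 6))  # triggers eviction, then succeeds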
@@ -22,9 +22,13 @@ from typing import Dict, List, Optional, Set, Union
 import torch
-from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
+from sglang.srt.managers.schedule_batch import (
+    Req,
+    ScheduleBatch,
+    global_server_args_dict,
+)
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool
+from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
 from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
 # Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large.
@@ -75,7 +79,7 @@ class SchedulePolicy:
         # It is used to find the matching prefix for in-batch prefix caching.
         self.waiting_queue_radix_tree = RadixCache(
-            req_to_token_pool=None, token_to_kv_pool=None, disable=False
+            req_to_token_pool=None, token_to_kv_pool_allocator=None, disable=False
         )
     def calc_priority(self, waiting_queue: List[Req]) -> bool:
@@ -251,7 +255,7 @@ class PrefillAdder:
     def __init__(
         self,
         tree_cache: BasePrefixCache,
-        token_to_kv_pool: BaseTokenToKVPool,
+        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
         running_batch: ScheduleBatch,
         new_token_ratio: float,
         rem_input_tokens: int,
@@ -259,7 +263,7 @@ class PrefillAdder:
         mixed_with_decode_tokens: int = 0,
     ):
         self.tree_cache = tree_cache
-        self.token_to_kv_pool = token_to_kv_pool
+        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
         self.running_batch = running_batch
         self.new_token_ratio = new_token_ratio
         self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
@@ -291,7 +295,7 @@ class PrefillAdder:
     @property
     def rem_total_tokens(self):
         return (
-            self.token_to_kv_pool.available_size()
+            self.token_to_kv_pool_allocator.available_size()
             + self.tree_cache.evictable_size()
             - self.rem_total_token_offset
         )
@@ -299,7 +303,7 @@ class PrefillAdder:
     @property
     def cur_rem_tokens(self):
         return (
-            self.token_to_kv_pool.available_size()
+            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
             - self.cur_rem_token_offset
         )
@@ -332,7 +336,6 @@ class PrefillAdder:
         req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
         req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
         self.can_run_list.append(req)
         self._prefill_one_req(
             0,
             req.extend_input_len,
@@ -400,8 +403,8 @@ class PrefillAdder:
                 tokens_freed += tokens_occupied
         if (
-            self.rem_chunk_tokens is None
-            or req.extend_input_len <= self.rem_chunk_tokens
+            self.rem_chunk_tokens is None  # chunked prefill is disabled
+            or req.extend_input_len <= self.rem_chunk_tokens  # it is the last chunk
         ):
             # Non-chunked prefill
             self.can_run_list.append(req)
@@ -411,10 +414,11 @@ class PrefillAdder:
                 min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION),
             )
         else:
+            if self.rem_chunk_tokens == 0:
+                return AddReqResult.OTHER
             # Chunked prefill
             trunc_len = self.rem_chunk_tokens
-            if trunc_len == 0:
-                return AddReqResult.OTHER
             req.extend_input_len = trunc_len
             req.fill_ids = req.fill_ids[:trunc_len]
@@ -457,10 +461,11 @@ class PrefillAdder:
                 ),
             )
         else:
+            if self.rem_chunk_tokens == 0:
+                return AddReqResult.OTHER
             # Chunked prefill
             trunc_len = self.rem_chunk_tokens
-            if trunc_len == 0:
-                return AddReqResult.OTHER
             req.extend_input_len = trunc_len
             req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
...
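For the PrefillAdder hunks above: the schedulable budget is free allocator slots plus evictable cache minus a reserved offset, and the chunked-prefill branch now returns AddReqResult.OTHER as soon as the per-chunk budget hits zero, before computing the truncation length. A small numeric illustration follows; all values are made up.

# Illustrative numbers only; the real values come from the allocator and radix cache.
available_size = 1_000        # token_to_kv_pool_allocator.available_size()
evictable_size = 300          # tree_cache.evictable_size()
rem_total_token_offset = 450  # tokens reserved for already-admitted requests

rem_total_tokens = available_size + evictable_size - rem_total_token_offset
print(rem_total_tokens)       # 850 tokens of schedulable budget

rem_chunk_tokens = 0          # per-chunk prefill budget already exhausted
extend_input_len = 512
if rem_chunk_tokens == 0:
    print("AddReqResult.OTHER: stop adding prefill chunks this round")
else:
    trunc_len = rem_chunk_tokens
    print(f"schedule a chunk of {min(extend_input_len, trunc_len)} tokens")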
@@ -164,7 +164,7 @@ class Scheduler:
                 self.server_args.speculative_num_draft_tokens
                 + (
                     self.server_args.speculative_eagle_topk
-                    * self.server_args.speculative_num_steps
+                    * self.server_args.speculative_num_draft_tokens
                 )
             )
             if not self.spec_algorithm.is_none()
@@ -309,7 +309,9 @@ class Scheduler:
         )
         # Init memory pool and cache
-        self.req_to_token_pool, self.token_to_kv_pool = self.tp_worker.get_memory_pool()
+        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
+            self.tp_worker.get_memory_pool()
+        )
         if (
             server_args.chunked_prefill_size is not None
@@ -317,18 +319,18 @@
         ):
             self.tree_cache = ChunkCache(
                 req_to_token_pool=self.req_to_token_pool,
-                token_to_kv_pool=self.token_to_kv_pool,
+                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
             )
         else:
             if self.enable_hierarchical_cache:
                 self.tree_cache = HiRadixCache(
                     req_to_token_pool=self.req_to_token_pool,
-                    token_to_kv_pool=self.token_to_kv_pool,
+                    token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                 )
             else:
                 self.tree_cache = RadixCache(
                     req_to_token_pool=self.req_to_token_pool,
-                    token_to_kv_pool=self.token_to_kv_pool,
+                    token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                     disable=server_args.disable_radix_cache,
                 )
@@ -458,7 +460,6 @@ class Scheduler:
                 (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
                 (ProfileReq, self.profile),
                 (GetInternalStateReq, self.get_internal_state),
-                (SetInternalStateReq, self.set_internal_state),
             ]
         )
@@ -809,7 +810,8 @@ class Scheduler:
         running_bs: int,
     ):
         num_used = self.max_total_num_tokens - (
-            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+            self.token_to_kv_pool_allocator.available_size()
+            + self.tree_cache.evictable_size()
         )
         self._largest_prefill_len = max(
             self._largest_prefill_len, adder.log_input_tokens
@@ -844,7 +846,8 @@ class Scheduler:
         self.num_generated_tokens = 0
         num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
         num_used = self.max_total_num_tokens - (
-            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+            self.token_to_kv_pool_allocator.available_size()
+            + self.tree_cache.evictable_size()
         )
         if RECORD_STEP_TIME:
@@ -894,7 +897,8 @@ class Scheduler:
     def check_memory(self):
         available_size = (
-            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+            self.token_to_kv_pool_allocator.available_size()
+            + self.tree_cache.evictable_size()
         )
         protected_size = self.tree_cache.protected_size()
         memory_leak = available_size != (
@@ -999,7 +1003,7 @@ class Scheduler:
         # Prefill policy
         adder = PrefillAdder(
             self.tree_cache,
-            self.token_to_kv_pool,
+            self.token_to_kv_pool_allocator,
             self.running_batch,
             self.new_token_ratio,
             self.max_prefill_tokens,
@@ -1099,7 +1103,7 @@ class Scheduler:
         new_batch = ScheduleBatch.init_new(
             can_run_list,
             self.req_to_token_pool,
-            self.token_to_kv_pool,
+            self.token_to_kv_pool_allocator,
             self.tree_cache,
             self.model_config,
             self.enable_overlap,
@@ -1143,8 +1147,6 @@ class Scheduler:
             retracted_reqs, new_token_ratio = batch.retract_decode(self.server_args)
             self.new_token_ratio = new_token_ratio
-            if self.draft_worker:
-                self.draft_worker.finish_request(retracted_reqs)
             logger.info(
                 "Decode out of memory happened. "
@@ -1184,11 +1186,12 @@ class Scheduler:
                 logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
                     model_worker_batch
                 )
+                bid = model_worker_batch.bid
             else:
                 (
                     logits_output,
                     next_token_ids,
-                    model_worker_batch,
+                    bid,
                     num_accepted_tokens,
                 ) = self.draft_worker.forward_batch_speculative_generation(batch)
                 self.spec_num_total_accepted_tokens += (
@@ -1214,7 +1217,7 @@ class Scheduler:
                 next_token_ids=next_token_ids,
                 extend_input_len_per_req=extend_input_len_per_req,
                 extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
-                bid=model_worker_batch.bid,
+                bid=bid,
             )
         else:  # embedding or reward model
             model_worker_batch = batch.get_model_worker_batch()
@@ -1230,6 +1233,7 @@ class Scheduler:
         result: Union[GenerationBatchResult, EmbeddingBatchResult],
     ):
         if batch.forward_mode.is_decode():
+            assert isinstance(result, GenerationBatchResult)
             self.process_batch_result_decode(batch, result)
             if batch.is_empty():
                 self.running_batch = None
@@ -1302,7 +1306,7 @@ class Scheduler:
             if self.is_mixed_chunk and self.enable_overlap and req.finished():
                 # Free the one delayed token for the mixed decode batch
                 j = len(batch.out_cache_loc) - len(batch.reqs) + i
-                self.token_to_kv_pool.free(batch.out_cache_loc[j : j + 1])
+                self.token_to_kv_pool_allocator.free(batch.out_cache_loc[j : j + 1])
                 continue
             if req.is_chunked <= 0:
@@ -1420,23 +1424,27 @@ class Scheduler:
         self.num_generated_tokens += len(batch.reqs)
         if self.enable_overlap:
+            assert batch.spec_algorithm.is_none()
             logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
             next_token_logprobs = logits_output.next_token_logprobs
-        else:
+        elif batch.spec_algorithm.is_none():
+            # spec decoding handles output logprobs inside verify process.
             next_token_ids = next_token_ids.tolist()
             if batch.return_logprob:
                 next_token_logprobs = logits_output.next_token_logprobs.tolist()
-        self.token_to_kv_pool.free_group_begin()
+        self.token_to_kv_pool_allocator.free_group_begin()
         # Check finish condition
+        # NOTE: the length of reqs and next_token_ids don't match if it is spec decoding.
+        # We should ignore using next_token_ids for spec decoding cases.
         for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
             if req.is_retracted:
                 continue
             if self.enable_overlap and req.finished():
                 # Free the one delayed token
-                self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
+                self.token_to_kv_pool_allocator.free(batch.out_cache_loc[i : i + 1])
                 continue
             if batch.spec_algorithm.is_none():
@@ -1479,7 +1487,7 @@ class Scheduler:
             batch.next_batch_sampling_info.sampling_info_done.set()
         self.stream_output(batch.reqs, batch.return_logprob)
-        self.token_to_kv_pool.free_group_end()
+        self.token_to_kv_pool_allocator.free_group_end()
         self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
         if (
@@ -1718,9 +1726,6 @@ class Scheduler:
                     and not self.model_config.is_multimodal_gen
                 )
             ):
-                if self.draft_worker and req.finished():
-                    self.draft_worker.finish_request(req)
                 rids.append(req.rid)
                 finished_reasons.append(
                     req.finished_reason.to_json() if req.finished_reason else None
@@ -1860,7 +1865,7 @@ class Scheduler:
         idle_batch = ScheduleBatch.init_new(
             [],
             self.req_to_token_pool,
-            self.token_to_kv_pool,
+            self.token_to_kv_pool_allocator,
             self.tree_cache,
             self.model_config,
             self.enable_overlap,
@@ -1916,11 +1921,11 @@ class Scheduler:
         if self.grammar_backend:
             self.grammar_backend.reset()
         self.req_to_token_pool.clear()
-        self.token_to_kv_pool.clear()
+        self.token_to_kv_pool_allocator.clear()
         if not self.spec_algorithm.is_none():
             self.draft_worker.model_runner.req_to_token_pool.clear()
-            self.draft_worker.model_runner.token_to_kv_pool.clear()
+            self.draft_worker.model_runner.token_to_kv_pool_allocator.clear()
         self.num_generated_tokens = 0
         self.forward_ct_decode = 0
...
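Two recurring computations in the scheduler hunks above: the decode headroom reserved for an EAGLE request, which now multiplies the topk by speculative_num_draft_tokens instead of speculative_num_steps, and the token-usage figure used for logging and the memory-leak check. A short numeric sketch follows; the parameter values are examples, not SGLang defaults.

# Example EAGLE settings; illustrative only.
speculative_num_draft_tokens = 4
speculative_eagle_topk = 8

# Headroom reserved per decode step for a speculative request, as in the hunk above.
new_token_budget = speculative_num_draft_tokens + (
    speculative_eagle_topk * speculative_num_draft_tokens
)
print(new_token_budget)  # 4 + 8 * 4 = 36 token slots

# Token accounting used in the logging and check_memory paths.
max_total_num_tokens = 10_000
available_size = 7_200   # token_to_kv_pool_allocator.available_size()
evictable_size = 1_300   # tree_cache.evictable_size()
num_used = max_total_num_tokens - (available_size + evictable_size)
print(num_used)          # 1500 tokens currently pinned by running requests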
@@ -82,8 +82,6 @@ from sglang.srt.managers.io_struct import (
     ResumeMemoryOccupationReqInput,
     ResumeMemoryOccupationReqOutput,
     SessionParams,
-    SetInternalStateReq,
-    SetInternalStateReqOutput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
     UpdateWeightFromDiskReqInput,
@@ -257,9 +255,6 @@ class TokenizerManager:
         self.get_internal_state_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
-        self.set_internal_state_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
         self._result_dispatcher = TypeBasedDispatcher(
             [
@@ -309,10 +304,6 @@ class TokenizerManager:
                     GetInternalStateReqOutput,
                     self.get_internal_state_communicator.handle_recv,
                 ),
-                (
-                    SetInternalStateReqOutput,
-                    self.set_internal_state_communicator.handle_recv,
-                ),
                 (HealthCheckOutput, lambda x: None),
             ]
         )
@@ -774,14 +765,6 @@ class TokenizerManager:
         )
         return res[0].internal_state
-    async def set_internal_state(
-        self, obj: SetInternalStateReq
-    ) -> SetInternalStateReqOutput:
-        res: List[SetInternalStateReqOutput] = (
-            await self.set_internal_state_communicator(obj)
-        )
-        return res[0]
     def get_log_request_metadata(self):
         max_length = None
         skip_names = None
...
@@ -30,6 +30,7 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromTensorReqInput,
 )
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
@@ -49,6 +50,8 @@ class TpModelWorker:
         dp_rank: Optional[int],
         nccl_port: int,
         is_draft_worker: bool = False,
+        req_to_token_pool: Optional[ReqToTokenPool] = None,
+        token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
     ):
         # Parse args
         self.tp_rank = tp_rank
@@ -77,6 +80,8 @@ class TpModelWorker:
             nccl_port=nccl_port,
             server_args=server_args,
             is_draft_worker=is_draft_worker,
+            req_to_token_pool=req_to_token_pool,
+            token_to_kv_pool_allocator=token_to_kv_pool_allocator,
         )
         if server_args.skip_tokenizer_init:
             self.tokenizer = self.processor = None
@@ -154,7 +159,7 @@ class TpModelWorker:
     def get_memory_pool(self):
         return (
             self.model_runner.req_to_token_pool,
-            self.model_runner.token_to_kv_pool,
+            self.model_runner.token_to_kv_pool_allocator,
         )
     def forward_batch_generation(
...
@@ -100,7 +100,7 @@ class TpModelWorkerClient:
     def get_memory_pool(self):
         return (
             self.worker.model_runner.req_to_token_pool,
-            self.worker.model_runner.token_to_kv_pool,
+            self.worker.model_runner.token_to_kv_pool_allocator,
         )
     def forward_thread_func(self):
...
 from __future__ import annotations
 """Cache for chunked prefill, used when RadixCache is disabled."""
 from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple
 import torch
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import Req
@@ -21,11 +20,13 @@ class ChunkCacheEntry:
 class ChunkCache(BasePrefixCache):
     def __init__(
-        self, req_to_token_pool: ReqToTokenPool, token_to_kv_pool: BaseTokenToKVPool
+        self,
+        req_to_token_pool: ReqToTokenPool,
+        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
     ):
         self.disable = True
         self.req_to_token_pool = req_to_token_pool
-        self.token_to_kv_pool = token_to_kv_pool
+        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
         self.entries: Dict[str, ChunkCacheEntry] = {}
         self.reset()
@@ -51,7 +52,7 @@ class ChunkCache(BasePrefixCache):
             req.req_pool_idx, :token_id_len
         ]
         self.req_to_token_pool.free(req.req_pool_idx)
-        self.token_to_kv_pool.free(kv_indices)
+        self.token_to_kv_pool_allocator.free(kv_indices)
         if req.rid in self.entries:
             del self.entries[req.rid]
@@ -91,3 +92,6 @@ class ChunkCache(BasePrefixCache):
     def protected_size(self):
         return 0
+    def pretty_print(self):
+        return ""
@@ -7,8 +7,8 @@ import torch
 from sglang.srt.managers.cache_controller import HiCacheController
 from sglang.srt.mem_cache.memory_pool import (
-    BaseTokenToKVPool,
-    MLATokenToKVPoolHost,
+    MHATokenToKVPool,
+    MHATokenToKVPoolHost,
     ReqToTokenPool,
 )
 from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode, _key_match
@@ -21,9 +21,9 @@ class HiRadixCache(RadixCache):
     def __init__(
         self,
         req_to_token_pool: ReqToTokenPool,
-        token_to_kv_pool: BaseTokenToKVPool,
+        token_to_kv_pool: MHATokenToKVPool,
     ):
-        self.token_to_kv_pool_host = MLATokenToKVPoolHost(token_to_kv_pool)
+        self.token_to_kv_pool_host = MHATokenToKVPoolHost(token_to_kv_pool)
         self.cache_controller = HiCacheController(
             token_to_kv_pool, self.token_to_kv_pool_host
         )
...
...@@ -20,9 +20,12 @@ Memory pool. ...@@ -20,9 +20,12 @@ Memory pool.
SGLang has two levels of memory pool. SGLang has two levels of memory pool.
ReqToTokenPool maps a a request to its token locations. ReqToTokenPool maps a a request to its token locations.
BaseTokenToKVPool maps a token location to its KV cache data. TokenToKVPoolAllocator maps a token location to its KV cache data.
KVCache actually holds the physical KV cache. Token slot indices are allocated
by TokenToKVPoolAllocator.
""" """
import abc
import logging import logging
import threading import threading
from enum import IntEnum from enum import IntEnum
...@@ -89,7 +92,7 @@ class ReqToTokenPool: ...@@ -89,7 +92,7 @@ class ReqToTokenPool:
self.free_slots = list(range(self.size)) self.free_slots = list(range(self.size))
class BaseTokenToKVPool: class TokenToKVPoolAllocator:
"""A memory pool that maps a token location to its kv cache data.""" """A memory pool that maps a token location to its kv cache data."""
def __init__( def __init__(
...@@ -100,11 +103,6 @@ class BaseTokenToKVPool: ...@@ -100,11 +103,6 @@ class BaseTokenToKVPool:
): ):
self.size = size self.size = size
self.dtype = dtype self.dtype = dtype
if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
# NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
self.store_dtype = torch.uint8
else:
self.store_dtype = dtype
self.device = device self.device = device
self.free_slots = None self.free_slots = None
...@@ -148,15 +146,22 @@ class BaseTokenToKVPool: ...@@ -148,15 +146,22 @@ class BaseTokenToKVPool:
self.is_in_free_group = False self.is_in_free_group = False
self.free_group = [] self.free_group = []
class KVCache(abc.ABC):
@abc.abstractmethod
def get_key_buffer(self, layer_id: int) -> torch.Tensor: def get_key_buffer(self, layer_id: int) -> torch.Tensor:
raise NotImplementedError() raise NotImplementedError()
@abc.abstractmethod
def get_value_buffer(self, layer_id: int) -> torch.Tensor: def get_value_buffer(self, layer_id: int) -> torch.Tensor:
raise NotImplementedError() raise NotImplementedError()
@abc.abstractmethod
def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]: def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError() raise NotImplementedError()
@abc.abstractmethod
def set_kv_buffer( def set_kv_buffer(
self, self,
layer: RadixAttention, layer: RadixAttention,
...@@ -167,7 +172,7 @@ class BaseTokenToKVPool: ...@@ -167,7 +172,7 @@ class BaseTokenToKVPool:
raise NotImplementedError() raise NotImplementedError()
class MHATokenToKVPool(BaseTokenToKVPool): class MHATokenToKVPool(KVCache):
def __init__( def __init__(
self, self,
...@@ -179,8 +184,14 @@ class MHATokenToKVPool(BaseTokenToKVPool): ...@@ -179,8 +184,14 @@ class MHATokenToKVPool(BaseTokenToKVPool):
device: str, device: str,
enable_memory_saver: bool, enable_memory_saver: bool,
): ):
super().__init__(size, dtype, device) self.size = size
self.dtype = dtype
self.device = device
if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
# NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
self.store_dtype = torch.uint8
else:
self.store_dtype = dtype
self.memory_saver_adapter = TorchMemorySaverAdapter.create( self.memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=enable_memory_saver enable=enable_memory_saver
) )
...@@ -297,7 +308,7 @@ def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype): ...@@ -297,7 +308,7 @@ def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
dst_2[loc] = src_2.to(dtype).view(store_dtype) dst_2[loc] = src_2.to(dtype).view(store_dtype)
class MLATokenToKVPool(BaseTokenToKVPool): class MLATokenToKVPool(KVCache):
def __init__( def __init__(
self, self,
size: int, size: int,
...@@ -308,8 +319,14 @@ class MLATokenToKVPool(BaseTokenToKVPool): ...@@ -308,8 +319,14 @@ class MLATokenToKVPool(BaseTokenToKVPool):
device: str, device: str,
enable_memory_saver: bool, enable_memory_saver: bool,
): ):
super().__init__(size, dtype, device) self.size = size
self.dtype = dtype
self.device = device
if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
# NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
self.store_dtype = torch.uint8
else:
self.store_dtype = dtype
self.kv_lora_rank = kv_lora_rank self.kv_lora_rank = kv_lora_rank
memory_saver_adapter = TorchMemorySaverAdapter.create( memory_saver_adapter = TorchMemorySaverAdapter.create(
...@@ -356,7 +373,7 @@ class MLATokenToKVPool(BaseTokenToKVPool): ...@@ -356,7 +373,7 @@ class MLATokenToKVPool(BaseTokenToKVPool):
self.kv_buffer[layer_id][loc] = cache_k self.kv_buffer[layer_id][loc] = cache_k
class DoubleSparseTokenToKVPool(BaseTokenToKVPool): class DoubleSparseTokenToKVPool(KVCache):
def __init__( def __init__(
self, self,
size: int, size: int,
...@@ -368,8 +385,14 @@ class DoubleSparseTokenToKVPool(BaseTokenToKVPool): ...@@ -368,8 +385,14 @@ class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
heavy_channel_num: int, heavy_channel_num: int,
enable_memory_saver: bool, enable_memory_saver: bool,
): ):
super().__init__(size, dtype, device) self.size = size
self.dtype = dtype
self.device = device
if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
# NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
self.store_dtype = torch.uint8
else:
self.store_dtype = dtype
memory_saver_adapter = TorchMemorySaverAdapter.create( memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=enable_memory_saver enable=enable_memory_saver
) )
...@@ -437,12 +460,12 @@ def synchronized(func): ...@@ -437,12 +460,12 @@ def synchronized(func):
return wrapper return wrapper
class MLATokenToKVPoolHost: class MHATokenToKVPoolHost:
def __init__( def __init__(
self, self,
device_pool: MHATokenToKVPool, device_pool: MHATokenToKVPool,
host_to_device_ratio: float = 4.0, host_to_device_ratio: float = 2.0,
pin_memory: bool = False, # no need to use pin memory with the double buffering pin_memory: bool = False, # no need to use pin memory with the double buffering
device: str = "cpu", device: str = "cpu",
): ):
......
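As a side note on the refactor above: the module docstring now describes a two-level design, where TokenToKVPoolAllocator hands out token slot indices and the KVCache subclasses (MHATokenToKVPool, MLATokenToKVPool, DoubleSparseTokenToKVPool) hold the physical buffers. The following is a minimal sketch of that split, with simplified shapes and invented class names (SketchAllocator, SketchKVCache) that are not part of SGLang:

import torch

class SketchAllocator:
    """Hands out token slot indices from a free list (allocator level)."""

    def __init__(self, size: int):
        # Slot 0 is reserved for padding, so usable slots start at 1.
        self.free_slots = torch.arange(1, size + 1, dtype=torch.int64)

    def alloc(self, need: int):
        if need > self.free_slots.numel():
            return None  # out of free slots
        out, self.free_slots = self.free_slots[:need], self.free_slots[need:]
        return out

    def free(self, indices: torch.Tensor):
        self.free_slots = torch.cat((self.free_slots, indices))

class SketchKVCache:
    """Holds the physical K/V buffers addressed by the allocated indices (cache level)."""

    def __init__(self, size: int, layer_num: int, head_num: int, head_dim: int):
        self.k_buffer = [torch.zeros(size + 1, head_num, head_dim) for _ in range(layer_num)]
        self.v_buffer = [torch.zeros(size + 1, head_num, head_dim) for _ in range(layer_num)]

    def set_kv_buffer(self, layer_id: int, loc: torch.Tensor, cache_k, cache_v):
        self.k_buffer[layer_id][loc] = cache_k
        self.v_buffer[layer_id][loc] = cache_v

allocator = SketchAllocator(size=32)
cache = SketchKVCache(size=32, layer_num=2, head_num=4, head_dim=16)
loc = allocator.alloc(5)  # token locations for 5 new tokens
cache.set_kv_buffer(0, loc, torch.randn(5, 4, 16), torch.randn(5, 4, 16))
allocator.free(loc)  # slots are returned when the request finishes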
...@@ -26,8 +26,9 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Tuple ...@@ -26,8 +26,9 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Tuple
import torch import torch
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req from sglang.srt.managers.schedule_batch import Req
...@@ -79,11 +80,11 @@ class RadixCache(BasePrefixCache): ...@@ -79,11 +80,11 @@ class RadixCache(BasePrefixCache):
def __init__( def __init__(
self, self,
req_to_token_pool: ReqToTokenPool, req_to_token_pool: ReqToTokenPool,
token_to_kv_pool: BaseTokenToKVPool, token_to_kv_pool_allocator: TokenToKVPoolAllocator,
disable: bool = False, disable: bool = False,
): ):
self.req_to_token_pool = req_to_token_pool self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool = token_to_kv_pool self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
self.disable = disable self.disable = disable
self.reset() self.reset()
...@@ -139,7 +140,7 @@ class RadixCache(BasePrefixCache): ...@@ -139,7 +140,7 @@ class RadixCache(BasePrefixCache):
kv_indices = self.req_to_token_pool.req_to_token[ kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, :token_ids_len req.req_pool_idx, :token_ids_len
] ]
self.token_to_kv_pool.free(kv_indices) self.token_to_kv_pool_allocator.free(kv_indices)
self.req_to_token_pool.free(req.req_pool_idx) self.req_to_token_pool.free(req.req_pool_idx)
return return
...@@ -151,7 +152,9 @@ class RadixCache(BasePrefixCache): ...@@ -151,7 +152,9 @@ class RadixCache(BasePrefixCache):
# Radix Cache takes one ref in memory pool # Radix Cache takes one ref in memory pool
new_prefix_len = self.insert(token_ids, kv_indices.clone()) new_prefix_len = self.insert(token_ids, kv_indices.clone())
self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len]) self.token_to_kv_pool_allocator.free(
kv_indices[len(req.prefix_indices) : new_prefix_len]
)
# Remove the req slot, release the cache lock # Remove the req slot, release the cache lock
self.req_to_token_pool.free(req.req_pool_idx) self.req_to_token_pool.free(req.req_pool_idx)
...@@ -171,7 +174,9 @@ class RadixCache(BasePrefixCache): ...@@ -171,7 +174,9 @@ class RadixCache(BasePrefixCache):
# Radix Cache takes one ref in memory pool # Radix Cache takes one ref in memory pool
new_prefix_len = self.insert(token_ids, kv_indices.clone()) new_prefix_len = self.insert(token_ids, kv_indices.clone())
self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len]) self.token_to_kv_pool_allocator.free(
kv_indices[len(req.prefix_indices) : new_prefix_len]
)
# The prefix indices could be updated, reuse it # The prefix indices could be updated, reuse it
new_indices, new_last_node = self.match_prefix(token_ids) new_indices, new_last_node = self.match_prefix(token_ids)
......
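For reference, the two hunks above free only a slice of the request's KV indices after inserting into the radix tree: the slots between the request's matched prefix and the prefix the tree already holds are duplicates and go back to the allocator, while the newly inserted tail stays referenced by the tree. A small illustration with made-up numbers (my reading of the code, not taken from SGLang):

# Hypothetical numbers only; they illustrate which slice of kv_indices is freed
# after insert() reports how much of the prefix the tree already had.
prefix_indices_len = 4              # tokens this request matched in the tree before running
new_prefix_len = 6                  # prefix length already present in the tree after insert()
kv_indices = list(range(100, 110))  # 10 KV slot indices written for this request

# Positions 4..5 duplicate KV data the tree already owns, so they are freed;
# positions 6..9 were newly inserted and stay referenced by the tree.
freed = kv_indices[prefix_indices_len:new_prefix_len]
assert freed == [104, 105]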
...@@ -55,6 +55,7 @@ from sglang.srt.mem_cache.memory_pool import ( ...@@ -55,6 +55,7 @@ from sglang.srt.mem_cache.memory_pool import (
MHATokenToKVPool, MHATokenToKVPool,
MLATokenToKVPool, MLATokenToKVPool,
ReqToTokenPool, ReqToTokenPool,
TokenToKVPoolAllocator,
) )
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
...@@ -98,6 +99,8 @@ class ModelRunner: ...@@ -98,6 +99,8 @@ class ModelRunner:
nccl_port: int, nccl_port: int,
server_args: ServerArgs, server_args: ServerArgs,
is_draft_worker: bool = False, is_draft_worker: bool = False,
req_to_token_pool: Optional[ReqToTokenPool] = None,
token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
): ):
# Parse args # Parse args
self.model_config = model_config self.model_config = model_config
...@@ -115,6 +118,8 @@ class ModelRunner: ...@@ -115,6 +118,8 @@ class ModelRunner:
self.spec_algorithm = SpeculativeAlgorithm.from_string( self.spec_algorithm = SpeculativeAlgorithm.from_string(
server_args.speculative_algorithm server_args.speculative_algorithm
) )
self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
# Model-specific adjustment # Model-specific adjustment
if ( if (
...@@ -257,8 +262,8 @@ class ModelRunner: ...@@ -257,8 +262,8 @@ class ModelRunner:
def init_torch_distributed(self): def init_torch_distributed(self):
logger.info("Init torch distributed begin.") logger.info("Init torch distributed begin.")
torch.get_device_module(self.device).set_device(self.gpu_id) torch.get_device_module(self.device).set_device(self.gpu_id)
if self.device == "cuda": if self.device == "cuda":
backend = "nccl" backend = "nccl"
elif self.device == "xpu": elif self.device == "xpu":
...@@ -660,12 +665,25 @@ class ModelRunner: ...@@ -660,12 +665,25 @@ class ModelRunner:
if not self.spec_algorithm.is_none(): if not self.spec_algorithm.is_none():
if self.is_draft_worker: if self.is_draft_worker:
self.max_total_num_tokens = self.server_args.draft_runner_cache_size self.max_total_num_tokens = self.server_args.draft_runner_cache_size
max_num_reqs = self.server_args.max_num_reqs
else: else:
# We are sharing the `token_to_kv_pool`, and both verify and draft tokens
# can be allocated concurrently, so we should leave some headroom for it.
self.server_args.draft_runner_cache_size = ( self.server_args.draft_runner_cache_size = (
self.max_total_num_tokens self.max_total_num_tokens
+ max_num_reqs * self.server_args.speculative_num_steps # draft
+ max_num_reqs
* self.server_args.speculative_num_steps
* self.server_args.speculative_eagle_topk
# verify
+ max_num_reqs * self.server_args.speculative_num_draft_tokens
# buffer
+ 100 + 100
) )
# The target worker and the draft worker share the same indices for the
# token_to_kv_pool, so we should make sure to match max_total_num_tokens.
self.max_total_num_tokens = self.server_args.draft_runner_cache_size
self.server_args.max_num_reqs = max_num_reqs
if max_total_tokens is not None: if max_total_tokens is not None:
if max_total_tokens > self.max_total_num_tokens: if max_total_tokens > self.max_total_num_tokens:
...@@ -681,12 +699,25 @@ class ModelRunner: ...@@ -681,12 +699,25 @@ class ModelRunner:
"Not enough memory. Please try to increase --mem-fraction-static." "Not enough memory. Please try to increase --mem-fraction-static."
) )
self.req_to_token_pool = ReqToTokenPool( if self.req_to_token_pool is None:
size=max_num_reqs + 1, self.req_to_token_pool = ReqToTokenPool(
max_context_len=self.model_config.context_len + 4, size=max_num_reqs + 1,
device=self.device, max_context_len=self.model_config.context_len + 4,
enable_memory_saver=self.server_args.enable_memory_saver, device=self.device,
) enable_memory_saver=self.server_args.enable_memory_saver,
)
else:
# Draft worker shares req_to_token_pool with the target worker.
assert self.is_draft_worker
if self.token_to_kv_pool_allocator is None:
self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
self.max_total_num_tokens,
dtype=self.kv_cache_dtype,
device=self.device,
)
else:
assert self.is_draft_worker
if ( if (
self.model_config.attention_arch == AttentionArch.MLA self.model_config.attention_arch == AttentionArch.MLA
......
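To make the headroom formula in the hunk above concrete, here is a small arithmetic sketch with hypothetical server arguments (the structure mirrors the diff; the numbers are invented):

# Hypothetical server arguments; the formula itself mirrors the diff above.
max_total_num_tokens = 200_000
max_num_reqs = 32
speculative_num_steps = 5
speculative_eagle_topk = 4
speculative_num_draft_tokens = 8

draft_runner_cache_size = (
    max_total_num_tokens
    # draft: tokens that can be in flight across all draft steps
    + max_num_reqs * speculative_num_steps * speculative_eagle_topk
    # verify: tokens sent to the target model for verification
    + max_num_reqs * speculative_num_draft_tokens
    # safety buffer
    + 100
)
assert draft_runner_cache_size == 200_000 + 640 + 256 + 100 == 200_996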
...@@ -280,11 +280,16 @@ class ServerArgs: ...@@ -280,11 +280,16 @@ class ServerArgs:
self.disable_overlap_schedule = True self.disable_overlap_schedule = True
self.prefill_only_one_req = True self.prefill_only_one_req = True
self.disable_cuda_graph_padding = True self.disable_cuda_graph_padding = True
self.disable_radix_cache = True if self.max_running_requests is None:
self.chunked_prefill_size = -1 self.max_running_requests = 32
logger.info( logger.info(
f"The radix cache, chunked prefill, and overlap scheduler are disabled because of using {self.speculative_algorithm} speculative decoding." "Overlap scheduler are disabled because of using "
"eagle speculative decoding."
"Max running request set to 32 because of using eagle speculative decoding."
) )
# The token generated from the verify step is counted.
# If speculative_num_steps >= speculative_num_draft_tokens, the additional tokens will definitely be discarded.
assert self.speculative_num_steps < self.speculative_num_draft_tokens
# GGUF # GGUF
if ( if (
......
...@@ -3,14 +3,8 @@ ...@@ -3,14 +3,8 @@
from typing import List from typing import List
import torch import torch
from sgl_kernel import build_tree_kernel as sgl_build_tree_kernel
from sglang.srt.utils import is_cuda_available from sgl_kernel import build_tree_kernel_efficient as sgl_build_tree_kernel_efficient
if is_cuda_available():
from sgl_kernel import build_tree_kernel as sgl_build_tree_kernel
from sgl_kernel import (
build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
)
def build_tree_kernel_efficient_preprocess( def build_tree_kernel_efficient_preprocess(
......
...@@ -21,7 +21,6 @@ from sglang.srt.model_executor.forward_batch_info import ( ...@@ -21,7 +21,6 @@ from sglang.srt.model_executor.forward_batch_info import (
from sglang.srt.speculative.eagle_utils import EagleDraftInput from sglang.srt.speculative.eagle_utils import EagleDraftInput
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.eagle_worker import EAGLEWorker from sglang.srt.speculative.eagle_worker import EAGLEWorker
......
from __future__ import annotations from __future__ import annotations
import dataclasses from dataclasses import dataclass
from typing import TYPE_CHECKING, List from typing import TYPE_CHECKING, Dict, List
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import triton import triton
import triton.language as tl import triton.language as tl
from sglang.srt.layers.attention.flashinfer_backend import ( from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
create_flashinfer_kv_indices_triton, from sglang.srt.layers.logits_processor import LogitsProcessorOutput
) from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
from sglang.srt.speculative.build_eagle_tree import ( from sglang.srt.speculative.build_eagle_tree import (
build_tree_kernel, build_tree_kernel,
...@@ -25,7 +26,7 @@ if TYPE_CHECKING: ...@@ -25,7 +26,7 @@ if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import ScheduleBatch from sglang.srt.managers.schedule_batch import ScheduleBatch
@dataclasses.dataclass @dataclass
class EagleDraftInput: class EagleDraftInput:
# The inputs for decode # The inputs for decode
# shape: (b, topk) # shape: (b, topk)
...@@ -46,57 +47,46 @@ class EagleDraftInput: ...@@ -46,57 +47,46 @@ class EagleDraftInput:
kv_indptr: torch.Tensor = None kv_indptr: torch.Tensor = None
kv_indices: torch.Tensor = None kv_indices: torch.Tensor = None
# indices of unfinished requests during extend-after-decode
# e.g. [0, 2, 3, 4] if only the 2nd request is finished
keep_indices: List[int] = None
def prepare_for_extend(self, batch: ScheduleBatch): def prepare_for_extend(self, batch: ScheduleBatch):
req_pool_indices = batch.alloc_req_slots(len(batch.reqs)) assert batch.input_ids.numel() == batch.out_cache_loc.shape[0]
out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel()) # Prefill only generates 1 token.
batch.out_cache_loc = out_cache_loc assert len(self.verified_id) == len(batch.seq_lens)
pt = 0 pt = 0
for i, req in enumerate(batch.reqs): for i, extend_len in enumerate(batch.extend_lens):
req.req_pool_idx = req_pool_indices[i] input_ids = batch.input_ids[pt : pt + extend_len]
pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids) batch.input_ids[pt : pt + extend_len] = torch.concat(
assert seq_len - pre_len == req.extend_input_len (input_ids[1:], self.verified_id[i].reshape(1))
if pre_len > 0:
batch.req_to_token_pool.req_to_token[req.req_pool_idx][
:pre_len
] = req.prefix_indices
batch.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = (
out_cache_loc[pt : pt + req.extend_input_len]
) )
pt += req.extend_input_len
# TODO: support batching inputs
assert len(batch.extend_lens) == 1
batch.input_ids = torch.concat((batch.input_ids[1:], self.verified_id))
def prepare_extend_after_decode(self, batch: ScheduleBatch, speculative_num_steps): def prepare_extend_after_decode(self, batch: ScheduleBatch, speculative_num_steps):
batch.out_cache_loc = batch.alloc_token_slots(self.verified_id.numel()) assert self.verified_id.numel() == batch.out_cache_loc.shape[0]
accept_length_cpu = batch.spec_info.accept_length_cpu accept_length_cpu = batch.spec_info.accept_length_cpu
batch.extend_lens = [x + 1 for x in accept_length_cpu] batch.extend_lens = [x + 1 for x in accept_length_cpu]
batch.extend_num_tokens = sum(batch.extend_lens)
batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend
seq_lens_cpu = batch.seq_lens.tolist() seq_lens_cpu = batch.seq_lens.tolist()
assert len(batch.req_pool_indices) == len(batch.reqs)
pt = 0 pt = 0
i = 0 i = 0
for req in batch.reqs: self.keep_indices = []
for idx, req in enumerate(batch.reqs):
if req.finished(): if req.finished():
continue continue
self.keep_indices.append(idx)
# assert seq_len - pre_len == req.extend_input_len # assert seq_len - pre_len == req.extend_input_len
input_len = batch.extend_lens[i] input_len = batch.extend_lens[i]
seq_len = seq_lens_cpu[i] seq_len = seq_lens_cpu[i]
batch.req_to_token_pool.req_to_token[req.req_pool_idx][
seq_len - input_len : seq_len
] = batch.out_cache_loc[pt : pt + input_len]
pt += input_len pt += input_len
i += 1 i += 1
assert pt == batch.out_cache_loc.shape[0]
self.positions = torch.empty_like(self.verified_id) self.positions = torch.empty_like(self.verified_id, dtype=torch.long)
new_verified_id = torch.empty_like(self.accept_length, dtype=torch.long) new_verified_id = torch.empty_like(self.accept_length, dtype=torch.int32)
self.accept_length.add_(1) self.accept_length.add_(1)
create_extend_spec_info[(self.accept_length.numel(),)]( create_extend_spec_info[(self.accept_length.numel(),)](
...@@ -117,14 +107,22 @@ class EagleDraftInput: ...@@ -117,14 +107,22 @@ class EagleDraftInput:
self, self,
req_pool_indices: torch.Tensor, req_pool_indices: torch.Tensor,
paged_kernel_lens: torch.Tensor, paged_kernel_lens: torch.Tensor,
paged_kernel_lens_sum: int,
req_to_token: torch.Tensor, req_to_token: torch.Tensor,
): ):
bs = self.accept_length.numel() bs = self.accept_length.numel()
keep_indices = torch.tensor(self.keep_indices, device=req_pool_indices.device)
req_pool_indices = req_pool_indices[keep_indices]
assert req_pool_indices.shape[0] == bs
assert req_pool_indices.shape[0] == paged_kernel_lens.shape[0]
qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda") qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0) qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda") cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0) cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
# TODO: replace cum_kv_seq_len[-1] with paged_kernel_lens_sum to avoid the device sync.
kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda") kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda")
create_flashinfer_kv_indices_triton[(bs,)]( create_flashinfer_kv_indices_triton[(bs,)](
...@@ -162,7 +160,21 @@ class EagleDraftInput: ...@@ -162,7 +160,21 @@ class EagleDraftInput:
self.topk_index = torch.cat([self.topk_index, spec_info.topk_index]) self.topk_index = torch.cat([self.topk_index, spec_info.topk_index])
@dataclasses.dataclass @dataclass
class EagleVerifyOutput:
# Draft input batch
draft_input: EagleDraftInput
# Logit outputs from target worker
logits_output: LogitsProcessorOutput
# Accepted token ids, including the bonus token
verified_id: torch.Tensor
# Accepted token length per sequence in a batch, on CPU.
accept_length_per_req_cpu: List[int]
# Accepted indices from logits_output.next_token_logits
accepeted_indices_cpu: List[int]
@dataclass
class EagleVerifyInput: class EagleVerifyInput:
draft_token: torch.Tensor draft_token: torch.Tensor
custom_mask: torch.Tensor custom_mask: torch.Tensor
...@@ -267,6 +279,7 @@ class EagleVerifyInput: ...@@ -267,6 +279,7 @@ class EagleVerifyInput:
self, self,
req_pool_indices: torch.Tensor, req_pool_indices: torch.Tensor,
paged_kernel_lens: torch.Tensor, paged_kernel_lens: torch.Tensor,
paged_kernel_lens_sum: int,
req_to_token: torch.Tensor, req_to_token: torch.Tensor,
): ):
batch_size = len(req_pool_indices) batch_size = len(req_pool_indices)
...@@ -285,7 +298,11 @@ class EagleVerifyInput: ...@@ -285,7 +298,11 @@ class EagleVerifyInput:
paged_kernel_lens = paged_kernel_lens + self.draft_token_num paged_kernel_lens = paged_kernel_lens + self.draft_token_num
cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0) cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda") kv_indices = torch.empty(
paged_kernel_lens_sum + self.draft_token_num * batch_size,
dtype=torch.int32,
device="cuda",
)
create_flashinfer_kv_indices_triton[(batch_size,)]( create_flashinfer_kv_indices_triton[(batch_size,)](
req_to_token, req_to_token,
...@@ -298,7 +315,21 @@ class EagleVerifyInput: ...@@ -298,7 +315,21 @@ class EagleVerifyInput:
) )
return kv_indices, cum_kv_seq_len, qo_indptr, self.custom_mask return kv_indices, cum_kv_seq_len, qo_indptr, self.custom_mask
def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Tensor: def verify(
self,
batch: ScheduleBatch,
logits_output: torch.Tensor,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
) -> torch.Tensor:
"""WARNING: This API in-place modifies the states of logits_output
Verify and find accepted tokens based on logits output and batch
(which contains spec decoding information).
This API updates values inside logits_output based on the accepted
tokens. I.e., logits_output.next_token_logits only contains
accepted token logits.
"""
draft_token = torch.cat( draft_token = torch.cat(
[self.draft_token, torch.full([1], -1, dtype=torch.int32, device="cuda")], [self.draft_token, torch.full([1], -1, dtype=torch.int32, device="cuda")],
dim=-1, dim=-1,
...@@ -367,7 +398,6 @@ class EagleVerifyInput: ...@@ -367,7 +398,6 @@ class EagleVerifyInput:
new_accept_index = [] new_accept_index = []
unfinished_index = [] unfinished_index = []
finished_extend_len = {} # {rid:accept_length + 1}
accept_index_cpu = accept_index.tolist() accept_index_cpu = accept_index.tolist()
predict_cpu = predict.tolist() predict_cpu = predict.tolist()
has_finished = False has_finished = False
...@@ -382,7 +412,6 @@ class EagleVerifyInput: ...@@ -382,7 +412,6 @@ class EagleVerifyInput:
id = predict_cpu[idx] id = predict_cpu[idx]
# if not found_finished: # if not found_finished:
req.output_ids.append(id) req.output_ids.append(id)
finished_extend_len[req.rid] = j + 1
req.check_finished() req.check_finished()
if req.finished(): if req.finished():
has_finished = True has_finished = True
...@@ -400,11 +429,10 @@ class EagleVerifyInput: ...@@ -400,11 +429,10 @@ class EagleVerifyInput:
accept_index = accept_index[accept_index != -1] accept_index = accept_index[accept_index != -1]
accept_length_cpu = accept_length.tolist() accept_length_cpu = accept_length.tolist()
verified_id = predict[accept_index] verified_id = predict[accept_index]
evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool) evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
evict_mask[accept_index] = False evict_mask[accept_index] = False
mem_need_free_idx = batch.out_cache_loc[evict_mask] mem_need_free_idx = batch.out_cache_loc[evict_mask]
batch.token_to_kv_pool.free(mem_need_free_idx) token_to_kv_pool_allocator.free(mem_need_free_idx)
assign_req_to_token_pool[(bs,)]( assign_req_to_token_pool[(bs,)](
batch.req_pool_indices, batch.req_pool_indices,
batch.req_to_token_pool.req_to_token, batch.req_to_token_pool.req_to_token,
...@@ -427,20 +455,16 @@ class EagleVerifyInput: ...@@ -427,20 +455,16 @@ class EagleVerifyInput:
] ]
if has_finished: if has_finished:
draft_input.seq_lens_for_draft_extend = batch.seq_lens[unfinished_index] draft_input.seq_lens_for_draft_extend = batch.seq_lens[unfinished_index]
draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices[
unfinished_index
]
else: else:
draft_input.seq_lens_for_draft_extend = batch.seq_lens draft_input.seq_lens_for_draft_extend = batch.seq_lens
draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices batch.out_cache_loc = batch.out_cache_loc[new_accept_index]
logits_output.next_token_logits = logits_output.next_token_logits[accept_index] return EagleVerifyOutput(
return ( draft_input=draft_input,
draft_input, logits_output=logits_output,
logits_output, verified_id=verified_id,
verified_id, accept_length_per_req_cpu=accept_length_cpu,
finished_extend_len, accepeted_indices_cpu=accept_index,
accept_length_cpu,
) )
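The eviction logic in verify() above frees the KV slots of every draft token that the target model rejected. A toy illustration with made-up tensors (standalone PyTorch, not SGLang code) of how the evict mask selects the slots to return to the allocator:

import torch

draft_token = torch.tensor([11, 12, 13, 14, 15, 16])          # 6 draft tokens for one request
out_cache_loc = torch.tensor([100, 101, 102, 103, 104, 105])  # their KV slot indices
accept_index = torch.tensor([0, 2, 3])                        # positions the target model accepted

evict_mask = torch.full_like(draft_token, True, dtype=torch.bool)
evict_mask[accept_index] = False
mem_need_free_idx = out_cache_loc[evict_mask]                 # rejected slots go back to the allocator
assert mem_need_free_idx.tolist() == [101, 104, 105]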
...@@ -456,6 +480,18 @@ def eagle_verify_retrive( ...@@ -456,6 +480,18 @@ def eagle_verify_retrive(
draft_token_num: tl.constexpr, draft_token_num: tl.constexpr,
max_len_upper: tl.constexpr, max_len_upper: tl.constexpr,
): ):
"""
Args:
retrive_index: Pointer to indices of draft tokens
accept_mask: Mask indicating which tokens were accepted
retrive_cum_len: Cumulative lengths of token sequences in a batch
accept_index (out): Accepted token indices
accept_length (out): Length of accepted tokens per sequence in a batch
extract_index (out): Index for last accepted tokens
max_len: Maximum length in a batch
draft_token_num: Number of tokens speculatively generated
max_len_upper: An upper bound for the token sequence length
"""
pid = tl.program_id(axis=0) pid = tl.program_id(axis=0)
retrive_end = tl.load(retrive_cum_len + pid + 1) retrive_end = tl.load(retrive_cum_len + pid + 1)
...@@ -649,7 +685,7 @@ def generate_draft_decode_kv_indices( ...@@ -649,7 +685,7 @@ def generate_draft_decode_kv_indices(
tl.store(kv_indptr + zid, base + zid * iters) tl.store(kv_indptr + zid, base + zid * iters)
@torch.compile @torch.compile(dynamic=True)
def select_top_k_tokens( def select_top_k_tokens(
i: int, i: int,
topk_p: torch.Tensor, topk_p: torch.Tensor,
...@@ -671,13 +707,11 @@ def select_top_k_tokens( ...@@ -671,13 +707,11 @@ def select_top_k_tokens(
.unsqueeze(0) .unsqueeze(0)
.repeat(topk_p.shape[0], 1), # shape: (b, topk + 1) .repeat(topk_p.shape[0], 1), # shape: (b, topk + 1)
) )
else: else:
# The later decode steps # The later decode steps
expand_scores = torch.mul( expand_scores = torch.mul(
scores.unsqueeze(2), topk_p.reshape(-1, topk, topk) scores.unsqueeze(2), topk_p.reshape(-1, topk, topk)
) # (b, topk, 1) x (b, topk ,topk) -> (b, topk, topk) ) # (b, topk, 1) x (b, topk ,topk) -> (b, topk, topk)
topk_cs_p, topk_cs_index = fast_topk( topk_cs_p, topk_cs_index = fast_topk(
expand_scores.flatten(start_dim=1), topk, dim=-1 expand_scores.flatten(start_dim=1), topk, dim=-1
) # (b, topk) ) # (b, topk)
......
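For the later decode steps of select_top_k_tokens above, the cumulative path scores are multiplied by each path's new top-k probabilities, and the best topk of the topk * topk candidates are kept. A shape walk-through with assumed sizes, using torch.topk as a stand-in for SGLang's fast_topk helper:

import torch

b, topk = 2, 4
scores = torch.rand(b, topk)            # cumulative score of each kept path
topk_p = torch.rand(b * topk, topk)     # new top-k probs from expanding each path

expand_scores = torch.mul(
    scores.unsqueeze(2), topk_p.reshape(-1, topk, topk)
)                                       # (b, topk, 1) x (b, topk, topk) -> (b, topk, topk)
topk_cs_p, topk_cs_index = torch.topk(
    expand_scores.flatten(start_dim=1), topk, dim=-1
)                                       # keep the best topk of the topk * topk candidates
assert topk_cs_p.shape == (b, topk) and topk_cs_index.shape == (b, topk)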
import logging import logging
import os import os
import time import time
from typing import List, Optional, Union from typing import Dict, List, Optional, Tuple, Union
import torch import torch
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
...@@ -22,11 +22,13 @@ from sglang.srt.speculative.eagle_draft_cuda_graph_runner import ( ...@@ -22,11 +22,13 @@ from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
from sglang.srt.speculative.eagle_utils import ( from sglang.srt.speculative.eagle_utils import (
EagleDraftInput, EagleDraftInput,
EagleVerifyInput, EagleVerifyInput,
EagleVerifyOutput,
assign_draft_cache_locs, assign_draft_cache_locs,
fast_topk, fast_topk,
select_top_k_tokens, select_top_k_tokens,
) )
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import get_available_gpu_memory
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -42,12 +44,16 @@ class EAGLEWorker(TpModelWorker): ...@@ -42,12 +44,16 @@ class EAGLEWorker(TpModelWorker):
nccl_port: int, nccl_port: int,
target_worker: TpModelWorker, target_worker: TpModelWorker,
): ):
# Override context length with target model's context length
server_args.context_length = target_worker.model_runner.model_config.context_len
os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = "1"
# Do not capture cuda graph in `super().__init__()` # Do not capture cuda graph in `super().__init__()`
# We will capture it later # We will capture it later
backup_disable_cuda_graph = server_args.disable_cuda_graph backup_disable_cuda_graph = server_args.disable_cuda_graph
server_args.disable_cuda_graph = True server_args.disable_cuda_graph = True
# Load hot token ids # Lossy optimization by using hot tokens
if server_args.speculative_token_map is not None: if server_args.speculative_token_map is not None:
self.hot_token_id = load_token_map(server_args.speculative_token_map) self.hot_token_id = load_token_map(server_args.speculative_token_map)
server_args.json_model_override_args = ( server_args.json_model_override_args = (
...@@ -56,6 +62,12 @@ class EAGLEWorker(TpModelWorker): ...@@ -56,6 +62,12 @@ class EAGLEWorker(TpModelWorker):
else: else:
self.hot_token_id = None self.hot_token_id = None
# We share the allocator with the target worker, while the draft and
# target workers each own their own KV cache.
self.req_to_token_pool, self.token_to_kv_pool_allocator = (
target_worker.get_memory_pool()
)
# Init target worker # Init target worker
super().__init__( super().__init__(
gpu_id=gpu_id, gpu_id=gpu_id,
...@@ -64,9 +76,10 @@ class EAGLEWorker(TpModelWorker): ...@@ -64,9 +76,10 @@ class EAGLEWorker(TpModelWorker):
nccl_port=nccl_port, nccl_port=nccl_port,
dp_rank=dp_rank, dp_rank=dp_rank,
is_draft_worker=True, is_draft_worker=True,
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
) )
self.target_worker = target_worker self.target_worker = target_worker
self.finish_extend_len = []
# Parse arguments # Parse arguments
self.topk = server_args.speculative_eagle_topk self.topk = server_args.speculative_eagle_topk
...@@ -75,6 +88,9 @@ class EAGLEWorker(TpModelWorker): ...@@ -75,6 +88,9 @@ class EAGLEWorker(TpModelWorker):
server_args.speculative_algorithm server_args.speculative_algorithm
) )
self.server_args = server_args self.server_args = server_args
self.use_nan_detection = self.server_args.enable_nan_detection
self.device = self.model_runner.device
self.gpu_id = self.model_runner.gpu_id
# Share the embedding and lm_head # Share the embedding and lm_head
embed, head = self.target_worker.model_runner.model.get_embed_and_head() embed, head = self.target_worker.model_runner.model.get_embed_and_head()
...@@ -82,8 +98,10 @@ class EAGLEWorker(TpModelWorker): ...@@ -82,8 +98,10 @@ class EAGLEWorker(TpModelWorker):
head = head.clone() head = head.clone()
self.hot_token_id = self.hot_token_id.to(head.device) self.hot_token_id = self.hot_token_id.to(head.device)
head.data = head.data[self.hot_token_id] head.data = head.data[self.hot_token_id]
self.model_runner.model.set_embed_and_head(embed, head) self.draft_model_runner.model.set_embed_and_head(embed, head)
self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph self.draft_model_runner.server_args.disable_cuda_graph = (
backup_disable_cuda_graph
)
# Create multi-step attn backends and cuda graph runners # Create multi-step attn backends and cuda graph runners
if server_args.attention_backend == "flashinfer": if server_args.attention_backend == "flashinfer":
...@@ -111,7 +129,7 @@ class EAGLEWorker(TpModelWorker): ...@@ -111,7 +129,7 @@ class EAGLEWorker(TpModelWorker):
f"EAGLE is not supportted in attention backend {server_args.attention_backend}" f"EAGLE is not supportted in attention backend {server_args.attention_backend}"
) )
self.model_runner.draft_attn_backend = self.draft_attn_backend self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
self.init_cuda_graphs() self.init_cuda_graphs()
def init_cuda_graphs(self): def init_cuda_graphs(self):
...@@ -122,55 +140,81 @@ class EAGLEWorker(TpModelWorker): ...@@ -122,55 +140,81 @@ class EAGLEWorker(TpModelWorker):
return return
tic = time.time() tic = time.time()
logger.info("Capture cuda graph begin. This can take up to several minutes.") logger.info(
f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
)
self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self) self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s") logger.info(
f"Capture draft cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
)
def forward_batch_speculative_generation(self, batch: ScheduleBatch): @property
def draft_model_runner(self):
return self.model_runner
def forward_batch_speculative_generation(
self, batch: ScheduleBatch
) -> Tuple[LogitsProcessorOutput, List[int], int, int]:
"""Run speculative decoding forward.
NOTE: Many states of the batch are modified in place as this runs. The batch is
not guaranteed to be in the same state as the input when this function returns.
Args:
batch: The batch to run forward. The state of the batch is modified as it runs.
Returns:
A tuple of the final logits output of the target model, the accepted next tokens,
the batch id (used for the overlap schedule), and the number of accepted tokens.
"""
assert not batch.spec_algorithm.is_none()
if batch.forward_mode.is_decode(): if batch.forward_mode.is_decode():
# Draft spec_info, to_free_cache_loc = self.draft(batch)
spec_info: EagleVerifyInput = self.draft(batch) logits_output, verify_output, model_worker_batch = self.verify(
batch, spec_info
# Verify )
( # Free the cache locations (done here to avoid synchronization and hide kernel launch overhead).
next_draft_input, self.token_to_kv_pool_allocator.free(to_free_cache_loc)
logits_output, # If verified_id is None, all requests are finished.
verified_id,
self.finish_extend_len,
accept_length_cpu,
model_worker_batch,
) = self.verify(batch, spec_info)
batch.spec_info = next_draft_input
# if it is None, all requests are finished
if batch.spec_info.verified_id is not None: if batch.spec_info.verified_id is not None:
self.forward_draft_extend_after_decode(batch) self.forward_draft_extend_after_decode(batch)
return ( return (
logits_output, logits_output,
verified_id, verify_output.verified_id,
model_worker_batch, model_worker_batch.bid,
sum(accept_length_cpu), sum(verify_output.accept_length_per_req_cpu),
) )
else: else:
# Forward with the target model and get hidden states. logits_output, next_token_ids, bid = self.forward_target_extend(batch)
# We need the full hidden states to prefill the KV cache of the draft model. self.forward_draft_extend(
model_worker_batch = batch.get_model_worker_batch() batch, logits_output.hidden_states, next_token_ids
model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
logits_output, next_token_ids = self.target_worker.forward_batch_generation(
model_worker_batch
)
# Forward with the draft model.
batch.spec_info = EagleDraftInput(
hidden_states=logits_output.hidden_states,
verified_id=next_token_ids,
) )
self.forward_draft_extend(batch) return logits_output, next_token_ids, bid, 0
return logits_output, next_token_ids, model_worker_batch, 0
def forward_target_extend(
self, batch: ScheduleBatch
) -> Tuple[LogitsProcessorOutput, List[int], int]:
"""Run the target extend.
Args:
batch: The batch to run. States could be modified.
Returns:
logits_output: The logits output. It contains the full hidden states.
next_token_ids: Next token ids generated.
bid: The model batch ID. Used for overlap schedule.
"""
# Forward with the target model and get hidden states.
# We need the full hidden states to prefill the KV cache of the draft model.
model_worker_batch = batch.get_model_worker_batch()
model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
logits_output, next_token_ids = self.target_worker.forward_batch_generation(
model_worker_batch
)
return logits_output, next_token_ids, model_worker_batch.bid
def draft(self, batch: ScheduleBatch): def draft(self, batch: ScheduleBatch):
self._set_mem_pool(batch, self.model_runner)
# Parse args # Parse args
num_seqs = batch.batch_size() num_seqs = batch.batch_size()
spec_info = batch.spec_info spec_info = batch.spec_info
...@@ -188,7 +232,6 @@ class EAGLEWorker(TpModelWorker): ...@@ -188,7 +232,6 @@ class EAGLEWorker(TpModelWorker):
self.topk, self.topk,
self.speculative_num_steps, self.speculative_num_steps,
) )
batch.out_cache_loc = out_cache_loc batch.out_cache_loc = out_cache_loc
batch.seq_lens_sum = torch.sum(batch.seq_lens).item() batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0) spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)
...@@ -196,11 +239,12 @@ class EAGLEWorker(TpModelWorker): ...@@ -196,11 +239,12 @@ class EAGLEWorker(TpModelWorker):
# Get forward batch # Get forward batch
spec_info.capture_hidden_mode = CaptureHiddenMode.LAST spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
model_worker_batch = batch.get_model_worker_batch() model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner
)
can_cuda_graph = self.cuda_graph_runner and self.cuda_graph_runner.can_run( can_cuda_graph = self.cuda_graph_runner and self.cuda_graph_runner.can_run(
forward_batch forward_batch
) )
if can_cuda_graph: if can_cuda_graph:
score_list, token_list, parents_list = self.cuda_graph_runner.replay( score_list, token_list, parents_list = self.cuda_graph_runner.replay(
forward_batch forward_batch
...@@ -208,7 +252,9 @@ class EAGLEWorker(TpModelWorker): ...@@ -208,7 +252,9 @@ class EAGLEWorker(TpModelWorker):
else: else:
# Initialize attention backend # Initialize attention backend
self.draft_attn_backend.init_forward_metadata(forward_batch) self.draft_attn_backend.init_forward_metadata(forward_batch)
forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner
)
# Run forward steps # Run forward steps
score_list, token_list, parents_list = self.draft_forward(forward_batch) score_list, token_list, parents_list = self.draft_forward(forward_batch)
...@@ -225,10 +271,7 @@ class EAGLEWorker(TpModelWorker): ...@@ -225,10 +271,7 @@ class EAGLEWorker(TpModelWorker):
batch.sampling_info.is_all_greedy, batch.sampling_info.is_all_greedy,
) )
# Free cache locations return ret, out_cache_loc
batch.token_to_kv_pool.free(out_cache_loc)
self._set_mem_pool(batch, self.target_worker.model_runner)
return ret
def draft_forward(self, forward_batch: ForwardBatch): def draft_forward(self, forward_batch: ForwardBatch):
# Parse args # Parse args
...@@ -278,6 +321,7 @@ class EAGLEWorker(TpModelWorker): ...@@ -278,6 +321,7 @@ class EAGLEWorker(TpModelWorker):
logits_output = self.model_runner.model.forward( logits_output = self.model_runner.model.forward(
forward_batch.input_ids, forward_batch.positions, forward_batch forward_batch.input_ids, forward_batch.positions, forward_batch
) )
self._detect_nan_if_needed(logits_output)
probs = torch.softmax(logits_output.next_token_logits, dim=-1) probs = torch.softmax(logits_output.next_token_logits, dim=-1)
topk_p, topk_index = fast_topk(probs, self.topk, dim=-1) topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
if self.hot_token_id is not None: if self.hot_token_id is not None:
...@@ -294,71 +338,88 @@ class EAGLEWorker(TpModelWorker): ...@@ -294,71 +338,88 @@ class EAGLEWorker(TpModelWorker):
logits_output, _ = self.target_worker.forward_batch_generation( logits_output, _ = self.target_worker.forward_batch_generation(
model_worker_batch, skip_sample=True model_worker_batch, skip_sample=True
) )
self._detect_nan_if_needed(logits_output)
spec_info.hidden_states = logits_output.hidden_states spec_info.hidden_states = logits_output.hidden_states
res = spec_info.verify(batch, logits_output) res: EagleVerifyOutput = spec_info.verify(
batch, logits_output, self.token_to_kv_pool_allocator
)
# Post process based on verified outputs.
# Pick the indices that we care about (accepted)
logits_output.next_token_logits = logits_output.next_token_logits[
res.accepeted_indices_cpu
]
logits_output.hidden_states = logits_output.hidden_states[
res.accepeted_indices_cpu
]
# Prepare the batch for the next draft forward passes.
batch.forward_mode = ForwardMode.DECODE batch.forward_mode = ForwardMode.DECODE
return res + (model_worker_batch,) batch.spec_info = res.draft_input
def forward_draft_extend(self, batch: ScheduleBatch): return logits_output, res, model_worker_batch
self._set_mem_pool(batch, self.model_runner)
def forward_draft_extend(
self,
batch: ScheduleBatch,
hidden_states: torch.Tensor,
next_token_ids: List[int],
):
"""Run draft model extend. This API modifies the states of the batch.
Args:
batch: The batch to run.
hidden_states: Hidden states from the target model forward
next_token_ids: Next token ids generated from the target forward.
"""
batch.spec_info = EagleDraftInput(
hidden_states=hidden_states,
verified_id=next_token_ids,
)
batch.spec_info.prepare_for_extend(batch) batch.spec_info.prepare_for_extend(batch)
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
model_worker_batch = batch.get_model_worker_batch() model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) forward_batch = ForwardBatch.init_new(
logits_output = self.model_runner.forward(forward_batch) model_worker_batch, self.draft_model_runner
self.capture_for_decode(logits_output, forward_batch) )
self._set_mem_pool(batch, self.target_worker.model_runner) logits_output = self.draft_model_runner.forward(forward_batch)
self._detect_nan_if_needed(logits_output)
def _set_mem_pool(self, batch: ScheduleBatch, runner: ModelRunner): assert isinstance(forward_batch.spec_info, EagleDraftInput)
batch.token_to_kv_pool = runner.token_to_kv_pool assert forward_batch.spec_info is batch.spec_info
batch.req_to_token_pool = runner.req_to_token_pool self.capture_for_decode(logits_output, forward_batch.spec_info)
def forward_draft_extend_after_decode(self, batch: ScheduleBatch): def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
seq_lens_backup = batch.seq_lens seq_lens_backup = batch.seq_lens
req_pool_indices_backup = batch.req_pool_indices
self._set_mem_pool(batch, self.model_runner)
batch.forward_mode = ForwardMode.DRAFT_EXTEND batch.forward_mode = ForwardMode.DRAFT_EXTEND
batch.spec_info.prepare_extend_after_decode(batch, self.speculative_num_steps) batch.spec_info.prepare_extend_after_decode(batch, self.speculative_num_steps)
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
# We don't need logprob for this extend.
model_worker_batch = batch.get_model_worker_batch() model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) forward_batch = ForwardBatch.init_new(
logits_output = self.model_runner.forward(forward_batch) model_worker_batch, self.draft_model_runner
self.capture_for_decode(logits_output, forward_batch) )
self._set_mem_pool(batch, self.target_worker.model_runner) logits_output = self.draft_model_runner.forward(forward_batch)
self._detect_nan_if_needed(logits_output)
assert forward_batch.spec_info is batch.spec_info
self.capture_for_decode(logits_output, forward_batch.spec_info)
# Restore backup. # Restore backup.
# This is because `seq_lens` can be modified in `prepare_extend_after_decode` # This is because `seq_lens` can be modified in `prepare_extend_after_decode`
batch.forward_mode = ForwardMode.DECODE batch.forward_mode = ForwardMode.DECODE
batch.seq_lens = seq_lens_backup batch.seq_lens = seq_lens_backup
batch.req_pool_indices = req_pool_indices_backup
def capture_for_decode( def capture_for_decode(
self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch self, logits_output: LogitsProcessorOutput, draft_input: EagleDraftInput
): ):
probs = torch.softmax(logits_output.next_token_logits, dim=-1) probs = torch.softmax(logits_output.next_token_logits, dim=-1)
spec_info = forward_batch.spec_info draft_input.topk_p, draft_input.topk_index = fast_topk(probs, self.topk, dim=-1)
spec_info.topk_p, spec_info.topk_index = fast_topk(probs, self.topk, dim=-1) draft_input.hidden_states = logits_output.hidden_states
spec_info.hidden_states = logits_output.hidden_states
def _detect_nan_if_needed(self, logits_output: LogitsProcessorOutput):
# Don't support prefix share now. if self.use_nan_detection:
def finish_request(self, reqs: Union[Req, List[Req]]): logits = logits_output.next_token_logits
if not isinstance(reqs, List): if torch.any(torch.isnan(logits)):
reqs = [reqs] logger.warning("Detected errors during sampling! NaN in the logits.")
for req in reqs: raise ValueError("Detected errors during sampling! NaN in the logits.")
if req.rid not in self.finish_extend_len:
continue
req_len = (
len(req.origin_input_ids)
+ len(req.output_ids)
- self.finish_extend_len[req.rid]
- 1
)
kv_indices = self.model_runner.req_to_token_pool.req_to_token[
req.req_pool_idx
][:req_len]
self.model_runner.token_to_kv_pool.free(kv_indices)
self.model_runner.req_to_token_pool.free(req.req_pool_idx)
def load_token_map(token_map_path: str) -> List[int]: def load_token_map(token_map_path: str) -> List[int]:
......
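Stepping back from the diff, the comments in EAGLEWorker.__init__ state the key memory design: the draft worker reuses the target worker's req_to_token_pool and token_to_kv_pool_allocator, while each worker keeps its own physical KV cache. A toy sketch with invented classes (not SGLang APIs) of what that sharing means:

import torch

class ToyWorker:
    def __init__(self, free_slots: list, layer_num: int, num_slots: int):
        self.free_slots = free_slots  # shared: a slot index means the same location in both workers
        self.kv = [torch.zeros(num_slots + 1, 8) for _ in range(layer_num)]  # private physical buffers

shared_free_slots = list(range(1, 65))
target = ToyWorker(shared_free_slots, layer_num=4, num_slots=64)
draft = ToyWorker(shared_free_slots, layer_num=1, num_slots=64)

slots = [shared_free_slots.pop() for _ in range(3)]  # "allocate" three token locations
target.kv[0][slots] = 1.0                            # the same indices address both caches
draft.kv[0][slots] = 2.0
shared_free_slots.extend(slots)                      # freeing returns them to both workers at once
assert target.free_slots is draft.free_slots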