Unverified Commit 9379da77 authored by Hanming Lu, committed by GitHub

SWA Prefix Cache (#7367)


Co-authored-by: Ying Sheng <sqy1415@gmail.com>
parent 0c55cbcf
@@ -711,7 +711,6 @@ def get_hybrid_layer_ids(model_architectures: List[str], num_hidden_layers: int)
             i for i in range(num_hidden_layers) if (i + 1) % 4 == 0
         ]
     else:
-        raise ValueError(
-            "get_hybrid_layer_ids is only implemented for Llama4ForConditionalGeneration"
-        )
+        swa_attention_layer_ids = None
+        full_attention_layer_ids = None
     return swa_attention_layer_ids, full_attention_layer_ids
@@ -439,7 +439,15 @@ class DecodePreallocQueue:
             else 0
         )

-        allocatable_tokens = self.token_to_kv_pool_allocator.available_size() - max(
+        if self.scheduler.model_config.is_hybrid:
+            available_size = min(
+                self.token_to_kv_pool_allocator.full_available_size(),
+                self.token_to_kv_pool_allocator.swa_available_size(),
+            )
+        else:
+            available_size = self.token_to_kv_pool_allocator.available_size()
+
+        allocatable_tokens = available_size - max(
             # preserve some space for future decode
             self.num_reserved_decode_tokens
             * (
...
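Note on the hybrid branch above: a decode token must occupy a slot in both the full-attention pool and the sliding-window pool, so the usable headroom is the smaller of the two. A toy illustration with hypothetical numbers (the real reservation term is more involved than the single constant used here):

# Hypothetical pool state for the hybrid branch of DecodePreallocQueue.
full_available_size = 5000        # free slots in the full-attention KV pool
swa_available_size = 1200         # free slots in the sliding-window KV pool
num_reserved_decode_tokens = 512  # space preserved for future decode steps

available_size = min(full_available_size, swa_available_size)    # 1200
allocatable_tokens = available_size - num_reserved_decode_tokens  # 688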
@@ -26,6 +26,7 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.utils import is_sm100_supported
+from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
 from sglang.srt.utils import is_flashinfer_available, next_power_of_2
@@ -589,6 +590,7 @@ class FlashInferIndicesUpdaterDecode:
         self.kv_indptr = attn_backend.kv_indptr
         self.kv_last_page_len = attn_backend.kv_last_page_len
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator

         # Dispatch the update function
         if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
@@ -655,6 +657,10 @@ class FlashInferIndicesUpdaterDecode:
                 paged_kernel_lens_sum_tmp = seq_lens_sum
                 kv_start_idx_tmp = None

+            use_sliding_window_kv_pool = wrapper_id == 0 and isinstance(
+                self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator
+            )
+
             self.call_begin_forward(
                 decode_wrappers[wrapper_id],
                 req_pool_indices,
@@ -663,6 +669,7 @@ class FlashInferIndicesUpdaterDecode:
                 self.kv_indptr[wrapper_id],
                 kv_start_idx_tmp,
                 spec_info,
+                use_sliding_window_kv_pool=use_sliding_window_kv_pool,
             )

     def update_cross_attention(
@@ -704,6 +711,7 @@ class FlashInferIndicesUpdaterDecode:
         kv_indptr: torch.Tensor,
         kv_start_idx: torch.Tensor,
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        use_sliding_window_kv_pool: bool = False,
     ):
         if spec_info is None:
             bs = len(req_pool_indices)
@@ -731,6 +739,14 @@ class FlashInferIndicesUpdaterDecode:
             kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
             bs = kv_indptr.shape[0] - 1

+        if use_sliding_window_kv_pool:
+            kv_last_index = kv_indptr[-1]
+            kv_indices[:kv_last_index] = (
+                self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
+                    kv_indices[:kv_last_index]
+                )
+            )
+
         wrapper.begin_forward(
             kv_indptr,
             kv_indices,
@@ -765,6 +781,7 @@ class FlashInferIndicesUpdaterPrefill:
         self.kv_last_page_len = attn_backend.kv_last_page_len
         self.qo_indptr = attn_backend.qo_indptr
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator
         self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged

         # Dispatch the update function
@@ -848,6 +865,9 @@ class FlashInferIndicesUpdaterPrefill:
                 paged_kernel_lens_sum = seq_lens_sum
                 kv_start_idx = seq_lens - paged_kernel_lens

+            use_sliding_window_kv_pool = wrapper_id == 0 and isinstance(
+                self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator
+            )
             self.call_begin_forward(
                 self.prefill_wrapper_ragged,
@@ -862,6 +882,7 @@ class FlashInferIndicesUpdaterPrefill:
                 self.qo_indptr[wrapper_id],
                 use_ragged,
                 spec_info,
+                use_sliding_window_kv_pool=use_sliding_window_kv_pool,
             )

     def update_cross_attention(
@@ -916,6 +937,7 @@ class FlashInferIndicesUpdaterPrefill:
         qo_indptr: torch.Tensor,
         use_ragged: bool,
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        use_sliding_window_kv_pool: bool = False,
     ):
         bs = len(seq_lens)
         if spec_info is None:
@@ -964,6 +986,14 @@ class FlashInferIndicesUpdaterPrefill:
                 q_data_type=self.q_data_type,
             )

+        if use_sliding_window_kv_pool:
+            kv_last_index = kv_indptr[-1]
+            kv_indices[:kv_last_index] = (
+                self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
+                    kv_indices[:kv_last_index]
+                )
+            )
+
         # cached part
         wrapper_paged.begin_forward(
             qo_indptr,
...
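For context on the two translation blocks above: a request's KV indices are produced for the full-attention pool, but sliding-window layers need the paired slot in the (smaller) SWA pool. A minimal sketch of such a lookup, assuming the allocator keeps a 1-D full_to_swa_index_mapping tensor as the diff suggests (an illustration only, not the PR's exact implementation):

import torch

def translate_loc_from_full_to_swa(
    full_to_swa_index_mapping: torch.Tensor, kv_indices: torch.Tensor
) -> torch.Tensor:
    # Every full-pool slot that also backs an SWA layer has a paired slot in
    # the SWA pool; the mapping tensor stores that pairing, so translating a
    # batch of indices is a single gather.
    return full_to_swa_index_mapping[kv_indices.long()]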
@@ -52,10 +52,14 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
     ScheduleBatchDisaggregationDecodeMixin,
 )
 from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
-from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+from sglang.srt.mem_cache.allocator import (
+    BaseTokenToKVPoolAllocator,
+    SWATokenToKVPoolAllocator,
+)
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
+from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
 from sglang.srt.metrics.collector import TimeStats
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
@@ -527,6 +531,8 @@ class Req:
         self.last_node: Any = None
         self.last_host_node: Any = None
         self.host_hit_length = 0
+        # The node to lock until for swa radix tree lock ref
+        self.swa_uuid_for_lock: Optional[int] = None

         # Whether or not if it is chunked. It increments whenever
         # it is chunked, and decrement whenever chunked request is
@@ -745,6 +751,7 @@ class Req:
     def reset_for_retract(self):
         self.prefix_indices = []
         self.last_node = None
+        self.swa_uuid_for_lock = None
         self.extend_input_len = 0
         self.is_retracted = True
         self.input_token_logprobs = None
@@ -813,6 +820,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     req_to_token_pool: ReqToTokenPool = None
     token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator = None
     tree_cache: BasePrefixCache = None
+    is_hybrid: bool = False

     # Batch configs
     model_config: ModelConfig = None
@@ -918,11 +926,19 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     ):
         return_logprob = any(req.return_logprob for req in reqs)

+        is_hybrid = False
+        if isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator):
+            assert isinstance(tree_cache, SWARadixCache) or isinstance(
+                tree_cache, SWAChunkCache
+            ), "SWARadixCache or SWAChunkCache is required for SWATokenToKVPoolAllocator"
+            is_hybrid = True
+
         return cls(
             reqs=reqs,
             req_to_token_pool=req_to_token_pool,
             token_to_kv_pool_allocator=token_to_kv_pool_allocator,
             tree_cache=tree_cache,
+            is_hybrid=is_hybrid,
             model_config=model_config,
             enable_overlap=enable_overlap,
             return_logprob=return_logprob,
@@ -953,9 +969,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         return req_pool_indices

     def alloc_token_slots(self, num_tokens: int, backup_state: bool = False):
-        if self.token_to_kv_pool_allocator.available_size() < num_tokens:
-            if self.tree_cache is not None:
-                self.tree_cache.evict(num_tokens)
+        self._evict_tree_cache_if_needed(num_tokens)

         if backup_state:
             state = self.token_to_kv_pool_allocator.backup_state()
@@ -966,7 +980,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             error_msg = (
                 f"{phase_str} out of memory. Try to lower your batch size.\n"
                 f"Try to allocate {num_tokens} tokens.\n"
-                f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
+                f"{self._available_and_evictable_str()}"
             )
             logger.error(error_msg)
             if self.tree_cache is not None:
@@ -986,16 +1000,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         extend_num_tokens: int,
         backup_state: bool = False,
     ):
-        if (
-            self.token_to_kv_pool_allocator.available_size()
-            < extend_num_tokens
-            + len(seq_lens) * self.token_to_kv_pool_allocator.page_size
-        ):
-            if self.tree_cache is not None:
-                self.tree_cache.evict(
-                    extend_num_tokens
-                    + len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
-                )
+        num_tokens = (
+            extend_num_tokens
+            + len(seq_lens) * self.token_to_kv_pool_allocator.page_size
+        )
+        self._evict_tree_cache_if_needed(num_tokens)

         if backup_state:
             state = self.token_to_kv_pool_allocator.backup_state()
@@ -1007,9 +1016,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             error_msg = (
                 f"Prefill out of memory. Try to lower your batch size.\n"
                 f"Try to allocate {extend_num_tokens} tokens.\n"
-                f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
-                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
-                f"{self.tree_cache.evictable_size()=}\n"
+                f"{self._available_and_evictable_str()}"
             )
             logger.error(error_msg)
             raise RuntimeError(error_msg)
@@ -1025,14 +1032,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         last_loc: torch.Tensor,
         backup_state: bool = False,
     ):
-        if self.tree_cache is not None:
-            if (
-                self.token_to_kv_pool_allocator.available_size()
-                < len(seq_lens) * self.token_to_kv_pool_allocator.page_size
-            ):
-                self.tree_cache.evict(
-                    len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
-                )
+        num_tokens = len(seq_lens) * self.token_to_kv_pool_allocator.page_size
+        self._evict_tree_cache_if_needed(num_tokens)

         if backup_state:
             state = self.token_to_kv_pool_allocator.backup_state()
@@ -1042,9 +1044,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             error_msg = (
                 f"Decode out of memory. Try to lower your batch size.\n"
                 f"Try to allocate {len(seq_lens)} tokens.\n"
-                f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
-                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
-                f"{self.tree_cache.evictable_size()=}\n"
+                f"{self._available_and_evictable_str()}"
             )
             logger.error(error_msg)
             raise RuntimeError(error_msg)
@@ -1181,7 +1181,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
             )
             if isinstance(self.tree_cache, SWAChunkCache):
-                self.tree_cache.evict(
+                self.tree_cache.evict_swa(
                     req, pre_len, self.model_config.attention_chunk_size
                 )
@@ -1371,17 +1371,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         )

     def check_decode_mem(self, buf_multiplier=1):
-        tokens_required = (
+        num_tokens = (
             self.new_page_count_next_decode()
             * buf_multiplier
             * self.token_to_kv_pool_allocator.page_size
         )
-        if self.token_to_kv_pool_allocator.available_size() >= tokens_required:
-            return True
-        self.tree_cache.evict(tokens_required)
-        return self.token_to_kv_pool_allocator.available_size() >= tokens_required
+        self._evict_tree_cache_if_needed(num_tokens)
+        return self._is_available_size_sufficient(num_tokens)

     def retract_decode(self, server_args: ServerArgs):
         """Retract the decoding requests when there is not enough memory."""
@@ -1414,19 +1411,38 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
         )

+        def _get_available_size():
+            if self.is_hybrid:
+                return min(
+                    self.token_to_kv_pool_allocator.full_available_size(),
+                    self.token_to_kv_pool_allocator.swa_available_size(),
+                )
+            else:
+                return self.token_to_kv_pool_allocator.available_size()
+
         retracted_reqs = []
         seq_lens_cpu = self.seq_lens.cpu().numpy()
         first_iter = True
         while (
-            self.token_to_kv_pool_allocator.available_size()
-            < get_required_tokens(len(sorted_indices))
+            _get_available_size() < get_required_tokens(len(sorted_indices))
             or first_iter
         ):
             if len(sorted_indices) == 1:
                 # Corner case: only one request left
-                assert (
-                    self.token_to_kv_pool_allocator.available_size() > 0
-                ), "No space left for only one request"
+                if self.is_hybrid:
+                    full_available_size = (
+                        self.token_to_kv_pool_allocator.full_available_size()
+                    )
+                    swa_available_size = (
+                        self.token_to_kv_pool_allocator.swa_available_size()
+                    )
+                    assert (
+                        full_available_size > 0 and swa_available_size > 0
+                    ), f"No space left for only one request in SWA mode {full_available_size=}, {swa_available_size=}"
+                else:
+                    assert (
+                        self.token_to_kv_pool_allocator.available_size() > 0
+                    ), f"No space left for only one request, {self.token_to_kv_pool_allocator.available_size()=}"
                 break

             first_iter = False
@@ -1458,15 +1474,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 self.req_to_token_pool.free(req.req_pool_idx)

                 # release the last node
-                self.tree_cache.dec_lock_ref(req.last_node)
+                if self.is_hybrid:
+                    self.tree_cache.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
+                else:
+                    self.tree_cache.dec_lock_ref(req.last_node)

                 # NOTE(lsyin): we should use the newly evictable memory instantly.
-                residual_size = (
-                    len(sorted_indices) * global_config.retract_decode_steps
-                    - self.token_to_kv_pool_allocator.available_size()
-                )
-                residual_size = max(0, residual_size)
-                self.tree_cache.evict(residual_size)
+                num_tokens = len(sorted_indices) * global_config.retract_decode_steps
+                self._evict_tree_cache_if_needed(num_tokens)

                 req.reset_for_retract()
@@ -1559,7 +1574,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         # free memory
         if isinstance(self.tree_cache, SWAChunkCache):
             for req in self.reqs:
                self.tree_cache.evict_swa(
                     req, req.seqlen - 1, self.model_config.attention_chunk_size
                 )
@@ -1778,6 +1793,53 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             is_extend_in_batch=self.is_extend_in_batch,
         )

+    def _evict_tree_cache_if_needed(
+        self,
+        num_tokens: int,
+    ) -> None:
+        if isinstance(self.tree_cache, SWAChunkCache):
+            return
+
+        if self.is_hybrid:
+            full_available_size = self.token_to_kv_pool_allocator.full_available_size()
+            swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()
+
+            if full_available_size < num_tokens or swa_available_size < num_tokens:
+                if self.tree_cache is not None:
+                    full_num_tokens = max(0, num_tokens - full_available_size)
+                    swa_num_tokens = max(0, num_tokens - swa_available_size)
+                    self.tree_cache.evict(full_num_tokens, swa_num_tokens)
+        else:
+            if self.token_to_kv_pool_allocator.available_size() < num_tokens:
+                if self.tree_cache is not None:
+                    self.tree_cache.evict(num_tokens)
+
+    def _is_available_size_sufficient(self, num_tokens: int) -> bool:
+        if self.is_hybrid:
+            return (
+                self.token_to_kv_pool_allocator.full_available_size() >= num_tokens
+                and self.token_to_kv_pool_allocator.swa_available_size() >= num_tokens
+            )
+        else:
+            return self.token_to_kv_pool_allocator.available_size() >= num_tokens
+
+    def _available_and_evictable_str(self) -> str:
+        if self.is_hybrid:
+            full_available_size = self.token_to_kv_pool_allocator.full_available_size()
+            swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()
+            full_evictable_size = self.tree_cache.full_evictable_size()
+            swa_evictable_size = self.tree_cache.swa_evictable_size()
+            return (
+                f"Available full tokens: {full_available_size + full_evictable_size} ({full_available_size=} + {full_evictable_size=})\n"
+                f"Available swa tokens: {swa_available_size + swa_evictable_size} ({swa_available_size=} + {swa_evictable_size=})\n"
+                f"Full LRU list evictable size: {self.tree_cache.full_lru_list_evictable_size()}\n"
+                f"SWA LRU list evictable size: {self.tree_cache.swa_lru_list_evictable_size()}\n"
+            )
+        else:
+            available_size = self.token_to_kv_pool_allocator.available_size()
+            evictable_size = self.tree_cache.evictable_size()
+            return f"Available tokens: {available_size + evictable_size} ({available_size=} + {evictable_size=})\n"
+
     def __str__(self):
         return (
             f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, "
...
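A small worked example of the shortfall arithmetic in _evict_tree_cache_if_needed above, with hypothetical pool sizes: each pool is asked to evict only what it is short by, so a request never triggers more eviction than necessary.

# Hypothetical numbers mirroring the hybrid branch of _evict_tree_cache_if_needed.
num_tokens = 400             # tokens the batch wants to allocate
full_available_size = 500    # free slots in the full-attention pool
swa_available_size = 250     # free slots in the SWA pool

full_num_tokens = max(0, num_tokens - full_available_size)  # 0: full pool already fits
swa_num_tokens = max(0, num_tokens - swa_available_size)    # 150: SWA pool is short
# tree_cache.evict(full_num_tokens, swa_num_tokens) then frees 150 SWA slots
# (and no full slots) before the allocation proceeds.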
@@ -25,6 +25,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union
 import torch

 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
+from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
@@ -311,22 +312,44 @@ class PrefillAdder:
             ]
         )

+        self.is_hybrid = isinstance(
+            self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator
+        )
+
     @property
     def rem_total_tokens(self):
-        return (
-            self.token_to_kv_pool_allocator.available_size()
-            + self.tree_cache.evictable_size()
-            - self.rem_total_token_offset
-        )
+        if self.is_hybrid:
+            available_and_evictable = min(
+                self.token_to_kv_pool_allocator.full_available_size()
+                + self.tree_cache.full_evictable_size(),
+                self.token_to_kv_pool_allocator.swa_available_size()
+                + self.tree_cache.swa_evictable_size(),
+            )
+        else:
+            available_and_evictable = (
+                self.token_to_kv_pool_allocator.available_size()
+                + self.tree_cache.evictable_size()
+            )
+
+        return available_and_evictable - self.rem_total_token_offset

     @property
     def cur_rem_tokens(self):
-        return (
-            self.token_to_kv_pool_allocator.available_size()
-            + self.tree_cache.evictable_size()
-            - self.cur_rem_token_offset
-        )
+        if self.is_hybrid:
+            available_and_evictable = min(
+                self.token_to_kv_pool_allocator.full_available_size()
+                + self.tree_cache.full_evictable_size(),
+                self.token_to_kv_pool_allocator.swa_available_size()
+                + self.tree_cache.swa_evictable_size(),
+            )
+        else:
+            available_and_evictable = (
+                self.token_to_kv_pool_allocator.available_size()
+                + self.tree_cache.evictable_size()
+            )
+
+        return available_and_evictable - self.cur_rem_token_offset

     def ceil_paged_tokens(self, tokens: int) -> int:
         return -(-tokens // self.page_size) * self.page_size
@@ -376,6 +399,13 @@ class PrefillAdder:
     @contextmanager
     def _lock_node(self, last_node: TreeNode):
-        try:
-            self.tree_cache.inc_lock_ref(last_node)
-            yield None
+        if self.is_hybrid:
+            try:
+                swa_uuid_for_lock = self.tree_cache.inc_lock_ref(last_node)
+                yield None
+            finally:
+                self.tree_cache.dec_lock_ref(last_node, swa_uuid_for_lock)
+        else:
+            try:
+                self.tree_cache.inc_lock_ref(last_node)
+                yield None
@@ -422,6 +452,9 @@ class PrefillAdder:
             else:
                 add_req_state(req, insert_sort=True)

-            cur_rem_tokens = self.cur_rem_tokens - len(req.origin_input_ids)
-            tokens_freed = 0
-            for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
+            if not self.is_hybrid:
+                # Skip this logic for swa. The SWA has different memory management, and
+                # this mechanism is underestimating the memory usage.
+                cur_rem_tokens = self.cur_rem_tokens - len(req.origin_input_ids)
+                tokens_freed = 0
+                for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
@@ -499,6 +532,10 @@ class PrefillAdder:
         if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens:
             # Non-chunked prefill
             self.can_run_list.append(req)
-            self.tree_cache.inc_lock_ref(req.last_node)
+            if self.is_hybrid:
+                swa_uuid_for_lock = self.tree_cache.inc_lock_ref(req.last_node)
+                req.swa_uuid_for_lock = swa_uuid_for_lock
+            else:
+                self.tree_cache.inc_lock_ref(req.last_node)
             self._update_prefill_budget(
                 prefix_len,
@@ -520,6 +557,10 @@ class PrefillAdder:
             self.can_run_list.append(req)
             self.new_chunked_req = req
-            self.tree_cache.inc_lock_ref(req.last_node)
+            if self.is_hybrid:
+                swa_uuid_for_lock = self.tree_cache.inc_lock_ref(req.last_node)
+                req.swa_uuid_for_lock = swa_uuid_for_lock
+            else:
+                self.tree_cache.inc_lock_ref(req.last_node)
             self._update_prefill_budget(prefix_len, trunc_len, 0)
...
@@ -129,10 +129,10 @@ from sglang.srt.managers.session_controller import Session
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
 from sglang.srt.managers.utils import validate_input_length
+from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
 from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
+from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
 from sglang.srt.reasoning_parser import ReasoningParser
@@ -390,6 +390,14 @@ class Scheduler(
         global_server_args_dict.update(worker_global_server_args_dict)
         set_random_seed(self.random_seed)

+        # Hybrid
+        self.is_hybrid = self.tp_worker.is_hybrid
+        if self.is_hybrid:
+            self.sliding_window_size = self.tp_worker.sliding_window_size
+            self.full_tokens_per_layer, self.swa_tokens_per_layer = (
+                self.tp_worker.get_tokens_per_layer_info()
+            )
+
         # Print debug info
         if tp_rank == 0:
             avail_mem = get_available_gpu_memory(
@@ -570,7 +578,7 @@ class Scheduler(
             server_args.chunked_prefill_size is not None
             and server_args.disable_radix_cache
         ):
-            if self.model_config.is_hybrid:
+            if self.is_hybrid:
                 ChunkCacheClass = SWAChunkCache
             else:
                 ChunkCacheClass = ChunkCache
@@ -603,6 +611,17 @@ class Scheduler(
                 self.tp_worker.register_hicache_layer_transfer_counter(
                     self.tree_cache.cache_controller.layer_done_counter
                 )
+        elif self.is_hybrid:
+            assert (
+                self.server_args.disaggregation_mode == "null"
+            ), "Hybrid mode does not support disaggregation yet"
+            self.tree_cache = SWARadixCache(
+                req_to_token_pool=self.req_to_token_pool,
+                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+                sliding_window_size=self.sliding_window_size,
+                page_size=self.page_size,
+                disable=server_args.disable_radix_cache,
+            )
         else:
             self.tree_cache = RadixCache(
@@ -774,6 +793,7 @@ class Scheduler(
             else:
                 # When the server is idle, do self-check and re-init some states
                 self.check_memory()
+                self.check_tree_cache()
                 self.new_token_ratio = self.init_new_token_ratio
                 self.maybe_sleep_on_idle()
@@ -819,6 +839,7 @@ class Scheduler(
             elif batch is None:
                 # When the server is idle, do self-check and re-init some states
                 self.check_memory()
+                self.check_tree_cache()
                 self.new_token_ratio = self.init_new_token_ratio
                 self.maybe_sleep_on_idle()
@@ -955,6 +976,7 @@ class Scheduler(
             # When the server is idle, self-check and re-init some states
             if server_is_idle:
                 self.check_memory()
+                self.check_tree_cache()
                 self.new_token_ratio = self.init_new_token_ratio
                 self.maybe_sleep_on_idle()
@@ -1306,9 +1328,26 @@ class Scheduler(
         self.last_input_throughput = self.last_prefill_tokens / gap_latency
         self.last_prefill_tokens = adder.log_input_tokens

-        usage_msg, num_used = self.token_to_kv_pool_allocator.log_usage(
-            self.tree_cache.evictable_size()
-        )
+        if self.is_hybrid:
+            (
+                full_num_used,
+                swa_num_used,
+                full_token_usage,
+                swa_token_usage,
+                _,
+                _,
+                _,
+                _,
+            ) = self._get_swa_token_info()
+            num_used = max(full_num_used, swa_num_used)
+            token_usage = max(full_token_usage, swa_token_usage)
+            token_msg = (
+                f"full token usage: {full_token_usage:.2f}, "
+                f"swa token usage: {swa_token_usage:.2f}, "
+            )
+        else:
+            num_used, token_usage, _, _ = self._get_token_info()
+            token_msg = f"token usage: {token_usage:.2f}, "

         num_new_seq = len(can_run_list)
         f = (
@@ -1316,7 +1355,7 @@ class Scheduler(
             f"#new-seq: {num_new_seq}, "
             f"#new-token: {adder.log_input_tokens}, "
             f"#cached-token: {adder.log_hit_tokens}, "
-            f"{usage_msg}"
+            f"{token_msg}"
         )

         if self.disaggregation_mode == DisaggregationMode.PREFILL:
@@ -1338,7 +1377,7 @@ class Scheduler(
             )
             self.stats.num_running_reqs = running_bs
             self.stats.num_used_tokens = num_used
-            self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
+            self.stats.token_usage = round(token_usage, 2)
             self.stats.num_queue_reqs = len(self.waiting_queue)
             self.stats.cache_hit_rate = cache_hit_rate
@@ -1361,16 +1400,35 @@ class Scheduler(
         self.last_gen_throughput = self.num_generated_tokens / gap_latency
         self.num_generated_tokens = 0
         num_running_reqs = len(batch.reqs)
-        usage_msg, num_used = self.token_to_kv_pool_allocator.log_usage(
-            self.tree_cache.evictable_size()
-        )
+        if self.is_hybrid:
+            (
+                full_num_used,
+                swa_num_used,
+                full_token_usage,
+                swa_token_usage,
+                _,
+                _,
+                _,
+                _,
+            ) = self._get_swa_token_info()
+            num_used = max(full_num_used, swa_num_used)
+            token_usage = max(full_token_usage, swa_token_usage)
+            token_msg = (
+                f"#full token: {full_num_used}, "
+                f"full token usage: {full_token_usage:.2f}, "
+                f"#swa token: {swa_num_used}, "
+                f"swa token usage: {swa_token_usage:.2f}, "
+            )
+        else:
+            num_used, token_usage, _, _ = self._get_token_info()
+            token_msg = f"#token: {num_used}, " f"token usage: {token_usage:.2f}, "

         if RECORD_STEP_TIME:
             self.step_time_dict[num_running_reqs].append(
                 gap_latency / self.server_args.decode_log_interval
             )

-        msg = f"Decode batch. " f"#running-req: {num_running_reqs}, " f"{usage_msg}"
+        msg = f"Decode batch. #running-req: {num_running_reqs}, {token_msg}"

         if self.spec_algorithm.is_none():
             spec_accept_length = 0
@@ -1398,7 +1456,7 @@ class Scheduler(
         if self.enable_metrics:
             self.stats.num_running_reqs = num_running_reqs
             self.stats.num_used_tokens = num_used
-            self.stats.token_usage = num_used / self.max_total_num_tokens
+            self.stats.token_usage = round(token_usage, 2)
             self.stats.cache_hit_rate = 0.0
             self.stats.gen_throughput = self.last_gen_throughput
             self.stats.num_queue_reqs = len(self.waiting_queue)
@@ -1409,24 +1467,34 @@ class Scheduler(
         self._publish_kv_events()

     def check_memory(self):
-        if isinstance(self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator):
-            available_token_size = self.token_to_kv_pool_allocator.full_available_size()
+        if self.is_hybrid:
+            (
+                full_num_used,
+                swa_num_used,
+                _,
+                _,
+                full_available_size,
+                full_evictable_size,
+                swa_available_size,
+                swa_evictable_size,
+            ) = self._get_swa_token_info()
+            memory_leak = full_num_used != 0 or swa_num_used != 0
+            token_msg = (
+                f"{self.full_tokens_per_layer=}, {full_available_size=}, {full_evictable_size=}, {self.tree_cache.full_protected_size()=}\n"
+                f"{self.swa_tokens_per_layer=}, {swa_available_size=}, {swa_evictable_size=}, {self.tree_cache.swa_protected_size()=}\n"
+            )
         else:
-            available_token_size = self.token_to_kv_pool_allocator.available_size()
-        available_size = available_token_size + self.tree_cache.evictable_size()
-        protected_size = self.tree_cache.protected_size()
-        memory_leak = available_size != (
-            self.max_total_num_tokens
-            if not self.enable_hierarchical_cache
-            else self.max_total_num_tokens - protected_size
-        )
+            _, _, available_size, evictable_size = self._get_token_info()
+            protected_size = self.tree_cache.protected_size()
+            memory_leak = (available_size + evictable_size) != (
+                self.max_total_num_tokens
+                if not self.enable_hierarchical_cache
+                else self.max_total_num_tokens - protected_size
+            )
+            token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"

         if memory_leak:
-            msg = (
-                "token_to_kv_pool_allocator memory leak detected! "
-                f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
-                f"{available_token_size=}\n"
-                f"{self.tree_cache.evictable_size()=}\n"
-            )
+            msg = "token_to_kv_pool_allocator memory leak detected! " f"{token_msg}"
             raise ValueError(msg)

         if self.disaggregation_mode == DisaggregationMode.DECODE:
@@ -1450,20 +1518,66 @@ class Scheduler(
             and time.perf_counter() > self.metrics_collector.last_log_time + 30
         ):
             # During idle time, also collect metrics every 30 seconds.
-            num_used = self.max_total_num_tokens - (
-                self.token_to_kv_pool_allocator.available_size()
-                + self.tree_cache.evictable_size()
-            )
+            if self.is_hybrid:
+                (
+                    full_num_used,
+                    swa_num_used,
+                    full_token_usage,
+                    swa_token_usage,
+                    _,
+                    _,
+                    _,
+                    _,
+                ) = self._get_swa_token_info()
+                num_used = max(full_num_used, swa_num_used)
+                token_usage = max(full_token_usage, swa_token_usage)
+            else:
+                num_used, token_usage, _, _ = self._get_token_info()
             num_running_reqs = len(self.running_batch.reqs)
             self.stats.num_running_reqs = num_running_reqs
             self.stats.num_used_tokens = num_used
-            self.stats.token_usage = num_used / self.max_total_num_tokens
+            self.stats.token_usage = round(token_usage, 2)
             self.stats.gen_throughput = 0
             self.stats.num_queue_reqs = len(self.waiting_queue)
             self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
             self.metrics_collector.log_stats(self.stats)
         self._publish_kv_events()

+    def check_tree_cache(self):
+        if self.is_hybrid and isinstance(self.tree_cache, SWARadixCache):
+            self.tree_cache.sanity_check()
+
+    def _get_token_info(self):
+        available_size = self.token_to_kv_pool_allocator.available_size()
+        evictable_size = self.tree_cache.evictable_size()
+        num_used = self.max_total_num_tokens - (available_size + evictable_size)
+        token_usage = num_used / self.max_total_num_tokens
+        return num_used, token_usage, available_size, evictable_size
+
+    def _get_swa_token_info(self):
+        full_available_size = self.token_to_kv_pool_allocator.full_available_size()
+        full_evictable_size = self.tree_cache.full_evictable_size()
+        swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()
+        swa_evictable_size = self.tree_cache.swa_evictable_size()
+        full_num_used = self.full_tokens_per_layer - (
+            full_available_size + full_evictable_size
+        )
+        swa_num_used = self.swa_tokens_per_layer - (
+            swa_available_size + swa_evictable_size
+        )
+        full_token_usage = full_num_used / self.full_tokens_per_layer
+        swa_token_usage = swa_num_used / self.swa_tokens_per_layer
+        return (
+            full_num_used,
+            swa_num_used,
+            full_token_usage,
+            swa_token_usage,
+            full_available_size,
+            full_evictable_size,
+            swa_available_size,
+            swa_evictable_size,
+        )
+
     def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
         # Merge the prefill batch into the running batch
         chunked_req_to_exclude = set()
@@ -2042,11 +2156,30 @@ class Scheduler(
         if not disable_request_logging():
             # Print batch size and memory pool info to check whether there are de-sync issues.
+            if self.is_hybrid:
+                (
+                    _,
+                    _,
+                    _,
+                    _,
+                    full_available_size,
+                    full_evictable_size,
+                    swa_available_size,
+                    swa_evictable_size,
+                ) = self._get_swa_token_info()
+                info_msg = (
+                    f"{full_available_size=}, "
+                    f"{full_evictable_size=}, "
+                    f"{swa_available_size=}, "
+                    f"{swa_evictable_size=}, "
+                )
+            else:
+                _, _, available_size, evictable_size = self._get_token_info()
+                info_msg = f"{available_size=}, " f"{evictable_size=}, "
             logger.error(
                 f"{self.cur_batch.batch_size()=}, "
                 f"{self.cur_batch.reqs=}, "
-                f"{self.token_to_kv_pool_allocator.available_size()=}, "
-                f"{self.tree_cache.evictable_size()=}, "
+                f"{info_msg}"
             )

         pyspy_dump_schedulers()
@@ -2101,6 +2234,19 @@ class Scheduler(
     def get_load(self):
         # TODO(lsyin): use dynamically maintained num_waiting_tokens
-        load = (
-            self.max_total_num_tokens
-            - self.token_to_kv_pool_allocator.available_size()
+        if self.is_hybrid:
+            load_full = (
+                self.full_tokens_per_layer
+                - self.token_to_kv_pool_allocator.full_available_size()
+                - self.tree_cache.full_evictable_size()
+            )
+            load_swa = (
+                self.swa_tokens_per_layer
+                - self.token_to_kv_pool_allocator.swa_available_size()
+                - self.tree_cache.swa_evictable_size()
+            )
+            load = max(load_full, load_swa)
+        else:
+            load = (
+                self.max_total_num_tokens
+                - self.token_to_kv_pool_allocator.available_size()
...
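To see what check_memory expects in hybrid mode, here is the idle-state accounting with hypothetical numbers: when no request is running, every per-layer token must be either free in the allocator or evictable in the radix tree, so both full_num_used and swa_num_used computed by _get_swa_token_info should be exactly zero.

# Hypothetical idle-state bookkeeping behind the hybrid leak check.
full_tokens_per_layer = 1000
full_available_size = 700    # free slots in the full-attention allocator
full_evictable_size = 300    # cached but unreferenced slots in the radix tree

full_num_used = full_tokens_per_layer - (full_available_size + full_evictable_size)
assert full_num_used == 0  # any other value means a KV slot leaked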
@@ -174,6 +174,20 @@ class TpModelWorker:
             self.model_runner.token_to_kv_pool.size,
         )

+    @property
+    def sliding_window_size(self) -> Optional[int]:
+        return self.model_runner.sliding_window_size
+
+    @property
+    def is_hybrid(self) -> bool:
+        return self.model_runner.is_hybrid is not None
+
+    def get_tokens_per_layer_info(self):
+        return (
+            self.model_runner.full_max_total_num_tokens,
+            self.model_runner.swa_max_total_num_tokens,
+        )
+
     def get_pad_input_ids_func(self):
         return getattr(self.model_runner.model, "pad_input_ids", None)
...
@@ -102,6 +102,17 @@ class TpModelWorkerClient:
     def get_worker_info(self):
         return self.worker.get_worker_info()

+    def get_tokens_per_layer_info(self):
+        return self.worker.get_tokens_per_layer_info()
+
+    @property
+    def sliding_window_size(self) -> Optional[int]:
+        return self.worker.sliding_window_size
+
+    @property
+    def is_hybrid(self) -> bool:
+        return self.worker.is_hybrid
+
     def get_pad_input_ids_func(self):
         return self.worker.get_pad_input_ids_func()
...
@@ -57,11 +57,6 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
     def debug_print(self) -> str:
         return ""

-    def log_usage(self, evictable_size: int = 0):
-        num_used = self.size - (self.available_size() + evictable_size)
-        msg = f"#token: {num_used}, token usage: {num_used / self.size:.2f}, "
-        return msg, num_used
-
     def available_size(self):
         return len(self.free_pages) * self.page_size
@@ -190,7 +185,7 @@ class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         self._kvcache.full_to_swa_index_mapping = self.full_to_swa_index_mapping

     def available_size(self):
-        return min(self.full_available_size(), self.swa_available_size())
+        raise NotImplementedError()

     def full_available_size(self):
         return self.full_attn_allocator.available_size()
@@ -214,16 +209,6 @@ class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         )
         return msg

-    def log_usage(self, swa_evictable_size: int = 0, full_evictable_size: int = 0):
-        used_full = self.size_full - (self.full_available_size() + full_evictable_size)
-        used_swa = self.size_swa - (self.swa_available_size() + swa_evictable_size)
-        msg = (
-            f"#token: full={used_full}, swa={used_swa}, "
-            f"token usage: full={used_full / self.size_full:.2f}, "
-            f"swa={used_swa / self.size_swa:.2f}, "
-        )
-        return msg, used_full
-
     def get_kvcache(self):
         return self._kvcache
...
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, List, NamedTuple, Tuple
+from typing import TYPE_CHECKING, Any, List, NamedTuple, Optional, Tuple

 import torch
@@ -56,15 +56,27 @@ class BasePrefixCache(ABC):
         pass

     @abstractmethod
-    def dec_lock_ref(self, node: Any):
+    def dec_lock_ref(self, node: Any, swa_uuid_for_lock: Optional[str] = None):
         pass

     def evictable_size(self):
         return 0

+    def full_evictable_size(self):
+        return 0
+
+    def swa_evictable_size(self):
+        return 0
+
     def protected_size(self):
         return 0

+    def full_protected_size(self):
+        return 0
+
+    def swa_protected_size(self):
+        return 0
+
     def total_size(self):
         raise NotImplementedError()
...
@@ -61,7 +61,7 @@ class ChunkCache(BasePrefixCache):
     def inc_lock_ref(self, node: Any):
         return 0

-    def dec_lock_ref(self, node: Any):
+    def dec_lock_ref(self, node: Any, swa_uuid_for_lock: Optional[str] = None):
         return 0

     def pretty_print(self):
@@ -80,7 +80,7 @@ class SWAChunkCache(ChunkCache):
         super().__init__(req_to_token_pool, token_to_kv_pool_allocator, page_size)
         assert isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator)

-    def evict(
+    def evict_swa(
         self,
         req: Req,
         prelen: int,
@@ -95,3 +95,6 @@ class SWAChunkCache(ChunkCache):
             ]
             self.token_to_kv_pool_allocator.free_swa(free_slots)
             req.evicted_seqlen_local = new_evicted_seqlen_local
+
+    def evict(self, num_tokens: int):
+        pass
This diff is collapsed.
@@ -275,6 +275,15 @@ class ModelRunner:
         self.sampler = Sampler()
         self.load_model()

+        if (
+            not self.server_args.disable_hybrid_swa_memory
+            and self.sliding_window_size is not None
+            and self.sliding_window_size > 0
+        ):
+            architectures = self.model_config.hf_config.architectures
+            if architectures and not any("Llama4" in arch for arch in architectures):
+                self.is_hybrid = self.model_config.is_hybrid = True
+
         self.start_layer = getattr(self.model, "start_layer", 0)
         self.end_layer = getattr(
             self.model, "end_layer", self.model_config.num_hidden_layers
@@ -471,10 +480,6 @@ class ModelRunner:
             if self.model_config.context_len > 8192:
                 self.mem_fraction_static *= 0.85

-        if self.is_hybrid and not server_args.disable_radix_cache:
-            logger.info("Automatically disable radix cache for hybrid cache.")
-            server_args.disable_radix_cache = True
-
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
@@ -645,11 +650,15 @@ class ModelRunner:
         )

         # Parse other args
-        self.sliding_window_size = (
-            self.model.get_attention_sliding_window_size()
-            if hasattr(self.model, "get_attention_sliding_window_size")
-            else None
-        )
+        self.sliding_window_size = None
+        if hasattr(self.model, "get_attention_sliding_window_size"):
+            self.sliding_window_size = self.model.get_attention_sliding_window_size()
+        elif self.model_config.attention_chunk_size is not None:
+            self.sliding_window_size = self.model_config.attention_chunk_size
+            print(
+                f"Setting sliding_window_size to be attention_chunk_size: {self.sliding_window_size}"
+            )
         self.dtype = self.model_config.dtype
         after_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
@@ -992,8 +1001,53 @@ class ModelRunner:
             )
             self.max_total_num_tokens = self.full_max_total_num_tokens
         else:
-            raise ValueError(
-                f"Unsupported model for hybrid cache: {self.model_config.hf_config.architectures}."
-            )
+            assert self.sliding_window_size is not None and self.sliding_window_size > 0
+            full_attention_layer_ids = []
+            swa_attention_layer_ids = []
+
+            try:
+                layers = self.model.model.layers
+            except:
+                try:
+                    layers = self.model.language_model.model.layers
+                except:
+                    self.is_hybrid = False
+                    return
+
+            for layer in layers:
+                if (
+                    layer.self_attn.attn.sliding_window_size is None
+                    or layer.self_attn.attn.sliding_window_size == -1
+                ):
+                    full_attention_layer_ids.append(layer.layer_id)
+                else:
+                    swa_attention_layer_ids.append(layer.layer_id)
+            self.model_config.swa_attention_layer_ids = swa_attention_layer_ids
+            self.model_config.full_attention_layer_ids = full_attention_layer_ids
+
+            # Algorithm:
+            # Existing max_total_num_tokens is per layer and assume all layers have
+            # the same number of tokens.
+            # - Find total # of tokens available across layers.
+            # - Calculate full_max_total_num_tokens and swa_max_total_num_tokens
+            #   based on the given swa_full_tokens_ratio.
+            total_tokens = (
+                self.max_total_num_tokens * self.model_config.num_hidden_layers
+            )
+            full_layers_num = len(full_attention_layer_ids)
+            swa_layers_num = len(swa_attention_layer_ids)
+            swa_full_tokens_ratio = self.server_args.swa_full_tokens_ratio
+
+            # Solve the equations:
+            # 1. swa_max_total_num_tokens * swa_layers_num
+            #    + full_max_total_num_tokens * full_layers_num == total_tokens
+            # 2. full_max_total_num_tokens * swa_full_tokens_ratio
+            #    == swa_max_total_num_tokens
+            denominator = swa_full_tokens_ratio * swa_layers_num + full_layers_num
+            self.full_max_total_num_tokens = int(total_tokens / denominator)
+            self.swa_max_total_num_tokens = int(
+                self.full_max_total_num_tokens * swa_full_tokens_ratio
+            )
+            self.max_total_num_tokens = self.full_max_total_num_tokens
+
+            logger.info(
+                f"Use Sliding window memory pool. full_layer_tokens={self.full_max_total_num_tokens}, swa_layer_tokens={self.swa_max_total_num_tokens}"
+            )
def init_memory_pool( def init_memory_pool(
...@@ -1072,7 +1126,6 @@ class ModelRunner: ...@@ -1072,7 +1126,6 @@ class ModelRunner:
// self.server_args.page_size // self.server_args.page_size
* self.server_args.page_size * self.server_args.page_size
) )
# create token size for hybrid cache # create token size for hybrid cache
if self.is_hybrid: if self.is_hybrid:
self.set_num_token_hybrid() self.set_num_token_hybrid()
......
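To make the pool-sizing equations above concrete, a worked example with hypothetical values: 48 layers of which 12 are full attention and 36 are SWA, a pre-split budget of 100,000 tokens per layer, and the default swa_full_tokens_ratio of 0.8.

# Hypothetical worked example of the hybrid sizing arithmetic above.
num_hidden_layers = 48
full_layers_num, swa_layers_num = 12, 36
max_total_num_tokens = 100_000      # per-layer budget before the split
swa_full_tokens_ratio = 0.8

total_tokens = max_total_num_tokens * num_hidden_layers  # 4,800,000 across layers

# Solve: swa * 36 + full * 12 == total_tokens, with swa == 0.8 * full.
denominator = swa_full_tokens_ratio * swa_layers_num + full_layers_num  # 40.8
full_max_total_num_tokens = int(total_tokens / denominator)  # 117,647
swa_max_total_num_tokens = int(full_max_total_num_tokens * swa_full_tokens_ratio)  # 94,117

# Because SWA layers hold fewer tokens, full layers end up with a larger
# per-layer budget (117,647) than the uniform figure (100,000) they started with.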
@@ -190,6 +190,7 @@ class Gemma2DecoderLayer(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        self.layer_id = layer_id
         self.hidden_size = config.hidden_size
         self.self_attn = Gemma2Attention(
             layer_id=layer_id,
...
@@ -63,6 +63,7 @@ class ServerArgs:
     enable_multimodal: Optional[bool] = None
     revision: Optional[str] = None
     hybrid_kvcache_ratio: Optional[float] = None
+    swa_full_tokens_ratio: float = 0.8
     impl: str = "auto"

     # Port for the HTTP server
@@ -225,6 +226,7 @@ class ServerArgs:
     enable_return_hidden_states: bool = False
     enable_triton_kernel_moe: bool = False
     warmups: Optional[str] = None
+    disable_hybrid_swa_memory: bool = False

     # Debug tensor dumps
     debug_tensor_dump_output_folder: Optional[str] = None
@@ -481,14 +483,22 @@ class ServerArgs:
         model_arch = get_model_arch(self)

-        # Auto set draft_model_path DeepSeek-V3/R1
         if model_arch == "DeepseekV3ForCausalLM":
+            # Auto set draft_model_path DeepSeek-V3/R1
             if self.speculative_draft_model_path is None:
                 self.speculative_draft_model_path = self.model_path
             else:
                 logger.warning(
                     "DeepSeek MTP does not require setting speculative_draft_model_path."
                 )
+        elif "Llama4" in model_arch:
+            # TODO: remove this after Llama4 supports in other backends
+            if self.attention_backend != "fa3":
+                self.attention_backend = "fa3"
+                logger.warning(
+                    "Llama4 requires using fa3 attention backend. "
+                    "Attention backend is automatically set to fa3."
+                )

         # Auto choose parameters
         if self.speculative_num_steps is None:
@@ -852,6 +862,18 @@ class ServerArgs:
                 "(1.0 = pure hybrid: swa_size / full_size = local_attention_size / context_length)"
             ),
         )
+        parser.add_argument(
+            "--swa-full-tokens-ratio",
+            type=float,
+            default=ServerArgs.swa_full_tokens_ratio,
+            help="The ratio of SWA layer KV tokens / full layer KV tokens, regardless of the number of swa:full layers. It should be between 0 and 1. "
+            "E.g. 0.5 means if each swa layer has 50 tokens, then each full layer has 100 tokens.",
+        )
+        parser.add_argument(
+            "--disable-hybrid-swa-memory",
+            action="store_true",
+            help="Disable the hybrid SWA memory.",
+        )

         # Other runtime options
         parser.add_argument(
@@ -1730,10 +1752,6 @@ class ServerArgs:
             else:
                 self.lora_paths[lora_path] = lora_path

-        model_arch = get_model_arch(self)
-        if "Llama4" in model_arch and self.hybrid_kvcache_ratio is not None:
-            assert self.attention_backend == "fa3"
-
 def prepare_server_args(argv: List[str]) -> ServerArgs:
     """
...
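As a quick sanity check of the ratio semantics described in the --swa-full-tokens-ratio help text (hypothetical numbers matching its own example):

swa_full_tokens_ratio = 0.5
full_layer_tokens = 100
swa_layer_tokens = int(full_layer_tokens * swa_full_tokens_ratio)
assert swa_layer_tokens == 50  # each SWA layer holds half as many tokens as a full layer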
import unittest
import torch
from sglang.srt.mem_cache.allocator import SWAKVPool, SWATokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.mem_cache.radix_cache import SWARadixCache
class TestSWA(unittest.TestCase):
@classmethod
def setUpClass(cls):
pass
@classmethod
def tearDownClass(cls):
pass
def test_swa_memory_pool(self):
size = 16
size_swa = 16
num_head = 8
head_dim = 128
num_layers = 48
global_interval = 4
dtype = torch.bfloat16
device = "cuda"
full_attention_layer_ids = [i for i in range(0, num_layers, global_interval)]
full_attention_layer_ids_set = set(full_attention_layer_ids)
swa_attention_layer_ids = [
i for i in range(num_layers) if i not in full_attention_layer_ids_set
]
pool = SWAKVPool(
size,
size_swa,
dtype,
num_head,
head_dim,
swa_attention_layer_ids,
full_attention_layer_ids,
device,
)
alloc = SWATokenToKVPoolAllocator(size, size_swa, dtype, device, pool)
assert alloc.available_size() == size + size_swa
index = alloc.alloc(1)
assert alloc.available_size() == size_swa + size_swa - 2
alloc.free_swa(index)
result = alloc.translate_loc_from_full_to_swa(index)
print(result)
def test_swa_radix_cache_1(self):
# args
req_size = 10
max_context_len = 128
kv_size = 128
kv_size_swa = 64
sliding_window_size = 4
num_head = 8
head_dim = 128
num_layers = 48
global_interval = 4
dtype = torch.bfloat16
device = "cuda"
full_attention_layer_ids = [i for i in range(0, num_layers, global_interval)]
full_attention_layer_ids_set = set(full_attention_layer_ids)
swa_attention_layer_ids = [
i for i in range(num_layers) if i not in full_attention_layer_ids_set
]
# setup req to token pool
req_to_token_pool = ReqToTokenPool(
size=req_size,
max_context_len=max_context_len,
device=device,
enable_memory_saver=False,
)
# setup kv pool
kv_pool = SWAKVPool(
kv_size,
kv_size_swa,
dtype,
num_head,
head_dim,
swa_attention_layer_ids,
full_attention_layer_ids,
device,
)
# setup token to kv pool allocator
allocator = SWATokenToKVPoolAllocator(
kv_size, kv_size_swa, dtype, device, kv_pool
)
# setup radix cache
tree = SWARadixCache(
req_to_token_pool=req_to_token_pool,
token_to_kv_pool_allocator=allocator,
sliding_window_size=sliding_window_size,
page_size=1,
disable=False,
)
# test
print(
f"[Start] allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}"
)
req1_token_ids, req1_kv_indices = [1, 2, 3], allocator.alloc(3)
assert len(req1_token_ids) == len(req1_kv_indices)
print(
f"req1: inserting, req1_token_ids: {req1_token_ids}, req1_kv_indices: {req1_kv_indices}"
)
prefix_len = tree.insert(req1_token_ids, req1_kv_indices)
print(
f"req1: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}"
)
req2_token_ids, req2_kv_indices = [1, 2, 3, 4, 5, 6, 7], allocator.alloc(7)
assert len(req2_token_ids) == len(req2_kv_indices)
print(
f"req2: inserting, req2_token_ids: {req2_token_ids}, req2_kv_indices: {req2_kv_indices}"
)
prefix_len = tree.insert(req2_token_ids, req2_kv_indices)
print(
f"req2: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}"
)
req3_token_ids, req3_kv_indices = [10, 11, 12], allocator.alloc(3)
assert len(req3_token_ids) == len(req3_kv_indices)
print(
f"req3: inserting, req3_token_ids: {req3_token_ids}, req3_kv_indices: {req3_kv_indices}"
)
prefix_len = tree.insert(req3_token_ids, req3_kv_indices)
print(
f"req3: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}"
)
req4_token_ids, req4_kv_indices = [1, 2, 3, 4, 5, 60, 70], allocator.alloc(7)
assert len(req4_token_ids) == len(req4_kv_indices)
print(
f"req4: inserting, req4_token_ids: {req4_token_ids}, req4_kv_indices: {req4_kv_indices}"
)
prefix_len = tree.insert(req4_token_ids, req4_kv_indices)
print(
f"req4: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}"
)
tree.pretty_print()
full_num_tokens, swa_num_tokens = 1, 0
print(f"evicting {full_num_tokens} full token and {swa_num_tokens} swa token")
tree.evict(full_num_tokens=full_num_tokens, swa_num_tokens=swa_num_tokens)
tree.pretty_print()
full_num_tokens, swa_num_tokens = 0, 1
print(f"evicting {full_num_tokens} full token and {swa_num_tokens} swa token")
tree.evict(full_num_tokens=full_num_tokens, swa_num_tokens=swa_num_tokens)
tree.pretty_print()
full_num_tokens, swa_num_tokens = 1, 2
print(f"evicting {full_num_tokens} full token and {swa_num_tokens} swa token")
tree.evict(full_num_tokens=full_num_tokens, swa_num_tokens=swa_num_tokens)
tree.pretty_print()
req5_token_ids = [1, 2, 3, 4, 5]
kv_indices, last_node = tree.match_prefix(req5_token_ids)
print(
f"req5: token_ids: {req5_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}"
)
assert len(kv_indices) == 0
req6_token_ids = [1, 2, 3, 4, 5, 60, 70]
kv_indices, last_node = tree.match_prefix(req6_token_ids)
print(
f"req6: token_ids: {req6_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}"
)
assert len(kv_indices) == 7
assert len(last_node.key) == 2
assert last_node.key[0] == 60
assert last_node.key[1] == 70
if __name__ == "__main__":
unittest.main()