Unverified Commit 9379da77 authored by Hanming Lu, committed by GitHub

SWA Prefix Cache (#7367)


Co-authored-by: Ying Sheng <sqy1415@gmail.com>
parent 0c55cbcf
......@@ -711,7 +711,6 @@ def get_hybrid_layer_ids(model_architectures: List[str], num_hidden_layers: int)
i for i in range(num_hidden_layers) if (i + 1) % 4 == 0
]
else:
raise ValueError(
"get_hybrid_layer_ids is only implemented for Llama4ForConditionalGeneration"
)
swa_attention_layer_ids = None
full_attention_layer_ids = None
return swa_attention_layer_ids, full_attention_layer_ids
......@@ -439,7 +439,15 @@ class DecodePreallocQueue:
else 0
)
allocatable_tokens = self.token_to_kv_pool_allocator.available_size() - max(
if self.scheduler.model_config.is_hybrid:
available_size = min(
self.token_to_kv_pool_allocator.full_available_size(),
self.token_to_kv_pool_allocator.swa_available_size(),
)
else:
available_size = self.token_to_kv_pool_allocator.available_size()
allocatable_tokens = available_size - max(
# preserve some space for future decode
self.num_reserved_decode_tokens
* (
......
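A minimal sketch of the availability rule this hunk applies, and which recurs in later hunks; the helper name here is hypothetical, and only the allocator methods named in this PR (full_available_size, swa_available_size, available_size) are assumed:

def allocatable_size(allocator, is_hybrid: bool) -> int:
    # Each new token needs a slot in the full-attention pool and, for SWA layers,
    # a slot in the sliding-window pool, so the tighter pool is the binding constraint.
    if is_hybrid:
        return min(allocator.full_available_size(), allocator.swa_available_size())
    return allocator.available_size()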
......@@ -26,6 +26,7 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.layers.utils import is_sm100_supported
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.utils import is_flashinfer_available, next_power_of_2
......@@ -589,6 +590,7 @@ class FlashInferIndicesUpdaterDecode:
self.kv_indptr = attn_backend.kv_indptr
self.kv_last_page_len = attn_backend.kv_last_page_len
self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator
# Dispatch the update function
if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
......@@ -655,6 +657,10 @@ class FlashInferIndicesUpdaterDecode:
paged_kernel_lens_sum_tmp = seq_lens_sum
kv_start_idx_tmp = None
use_sliding_window_kv_pool = wrapper_id == 0 and isinstance(
self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator
)
self.call_begin_forward(
decode_wrappers[wrapper_id],
req_pool_indices,
......@@ -663,6 +669,7 @@ class FlashInferIndicesUpdaterDecode:
self.kv_indptr[wrapper_id],
kv_start_idx_tmp,
spec_info,
use_sliding_window_kv_pool=use_sliding_window_kv_pool,
)
def update_cross_attention(
......@@ -704,6 +711,7 @@ class FlashInferIndicesUpdaterDecode:
kv_indptr: torch.Tensor,
kv_start_idx: torch.Tensor,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
use_sliding_window_kv_pool: bool = False,
):
if spec_info is None:
bs = len(req_pool_indices)
......@@ -731,6 +739,14 @@ class FlashInferIndicesUpdaterDecode:
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
bs = kv_indptr.shape[0] - 1
if use_sliding_window_kv_pool:
kv_last_index = kv_indptr[-1]
kv_indices[:kv_last_index] = (
self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
kv_indices[:kv_last_index]
)
)
wrapper.begin_forward(
kv_indptr,
kv_indices,
......@@ -765,6 +781,7 @@ class FlashInferIndicesUpdaterPrefill:
self.kv_last_page_len = attn_backend.kv_last_page_len
self.qo_indptr = attn_backend.qo_indptr
self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator
self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged
# Dispatch the update function
......@@ -848,6 +865,9 @@ class FlashInferIndicesUpdaterPrefill:
paged_kernel_lens_sum = seq_lens_sum
kv_start_idx = seq_lens - paged_kernel_lens
use_sliding_window_kv_pool = wrapper_id == 0 and isinstance(
self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator
)
self.call_begin_forward(
self.prefill_wrapper_ragged,
......@@ -862,6 +882,7 @@ class FlashInferIndicesUpdaterPrefill:
self.qo_indptr[wrapper_id],
use_ragged,
spec_info,
use_sliding_window_kv_pool=use_sliding_window_kv_pool,
)
def update_cross_attention(
......@@ -916,6 +937,7 @@ class FlashInferIndicesUpdaterPrefill:
qo_indptr: torch.Tensor,
use_ragged: bool,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
use_sliding_window_kv_pool: bool = False,
):
bs = len(seq_lens)
if spec_info is None:
......@@ -964,6 +986,14 @@ class FlashInferIndicesUpdaterPrefill:
q_data_type=self.q_data_type,
)
if use_sliding_window_kv_pool:
kv_last_index = kv_indptr[-1]
kv_indices[:kv_last_index] = (
self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
kv_indices[:kv_last_index]
)
)
# cached part
wrapper_paged.begin_forward(
qo_indptr,
......
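The hunks above remap kv_indices from the full-attention pool into the SWA pool for wrapper 0 (the sliding-window wrapper). A plausible shape of that translation is sketched below as a lookup table; the PR wires a full_to_swa_index_mapping tensor into the KV cache, but the exact body of translate_loc_from_full_to_swa is not shown in this diff, and the table size here is hypothetical:

import torch

# Hypothetical table: slot i in the full-attention pool -> its slot in the SWA pool.
full_to_swa_index_mapping = torch.zeros(1 << 16, dtype=torch.int64)

def translate_loc_from_full_to_swa(full_locs: torch.Tensor) -> torch.Tensor:
    # Gather the SWA-pool locations for the given full-pool locations.
    return full_to_swa_index_mapping[full_locs]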
......@@ -52,10 +52,14 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
ScheduleBatchDisaggregationDecodeMixin,
)
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.allocator import (
BaseTokenToKVPoolAllocator,
SWATokenToKVPoolAllocator,
)
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.metrics.collector import TimeStats
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
......@@ -527,6 +531,8 @@ class Req:
self.last_node: Any = None
self.last_host_node: Any = None
self.host_hit_length = 0
# The swa uuid of the node to lock until, for the swa radix tree lock ref
self.swa_uuid_for_lock: Optional[int] = None
# Whether or not it is chunked. It increments whenever
# it is chunked, and decrements whenever the chunked request is
......@@ -745,6 +751,7 @@ class Req:
def reset_for_retract(self):
self.prefix_indices = []
self.last_node = None
self.swa_uuid_for_lock = None
self.extend_input_len = 0
self.is_retracted = True
self.input_token_logprobs = None
......@@ -813,6 +820,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
req_to_token_pool: ReqToTokenPool = None
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator = None
tree_cache: BasePrefixCache = None
is_hybrid: bool = False
# Batch configs
model_config: ModelConfig = None
......@@ -918,11 +926,19 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
):
return_logprob = any(req.return_logprob for req in reqs)
is_hybrid = False
if isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator):
assert isinstance(tree_cache, SWARadixCache) or isinstance(
tree_cache, SWAChunkCache
), "SWARadixCache or SWAChunkCache is required for SWATokenToKVPoolAllocator"
is_hybrid = True
return cls(
reqs=reqs,
req_to_token_pool=req_to_token_pool,
token_to_kv_pool_allocator=token_to_kv_pool_allocator,
tree_cache=tree_cache,
is_hybrid=is_hybrid,
model_config=model_config,
enable_overlap=enable_overlap,
return_logprob=return_logprob,
......@@ -953,9 +969,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
return req_pool_indices
def alloc_token_slots(self, num_tokens: int, backup_state: bool = False):
if self.token_to_kv_pool_allocator.available_size() < num_tokens:
if self.tree_cache is not None:
self.tree_cache.evict(num_tokens)
self._evict_tree_cache_if_needed(num_tokens)
if backup_state:
state = self.token_to_kv_pool_allocator.backup_state()
......@@ -966,7 +980,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
error_msg = (
f"{phase_str} out of memory. Try to lower your batch size.\n"
f"Try to allocate {num_tokens} tokens.\n"
f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
f"{self._available_and_evictable_str()}"
)
logger.error(error_msg)
if self.tree_cache is not None:
......@@ -986,16 +1000,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
extend_num_tokens: int,
backup_state: bool = False,
):
if (
self.token_to_kv_pool_allocator.available_size()
< extend_num_tokens
num_tokens = (
extend_num_tokens
+ len(seq_lens) * self.token_to_kv_pool_allocator.page_size
):
if self.tree_cache is not None:
self.tree_cache.evict(
extend_num_tokens
+ len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
)
)
self._evict_tree_cache_if_needed(num_tokens)
if backup_state:
state = self.token_to_kv_pool_allocator.backup_state()
......@@ -1007,9 +1016,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
error_msg = (
f"Prefill out of memory. Try to lower your batch size.\n"
f"Try to allocate {extend_num_tokens} tokens.\n"
f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
f"{self.token_to_kv_pool_allocator.available_size()=}\n"
f"{self.tree_cache.evictable_size()=}\n"
f"{self._available_and_evictable_str()}"
)
logger.error(error_msg)
raise RuntimeError(error_msg)
......@@ -1025,14 +1032,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
last_loc: torch.Tensor,
backup_state: bool = False,
):
if self.tree_cache is not None:
if (
self.token_to_kv_pool_allocator.available_size()
< len(seq_lens) * self.token_to_kv_pool_allocator.page_size
):
self.tree_cache.evict(
len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
)
num_tokens = len(seq_lens) * self.token_to_kv_pool_allocator.page_size
self._evict_tree_cache_if_needed(num_tokens)
if backup_state:
state = self.token_to_kv_pool_allocator.backup_state()
......@@ -1042,9 +1044,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
error_msg = (
f"Decode out of memory. Try to lower your batch size.\n"
f"Try to allocate {len(seq_lens)} tokens.\n"
f"Available tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
f"{self.token_to_kv_pool_allocator.available_size()=}\n"
f"{self.tree_cache.evictable_size()=}\n"
f"{self._available_and_evictable_str()}"
)
logger.error(error_msg)
raise RuntimeError(error_msg)
......@@ -1181,7 +1181,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
(req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
)
if isinstance(self.tree_cache, SWAChunkCache):
self.tree_cache.evict(
self.tree_cache.evict_swa(
req, pre_len, self.model_config.attention_chunk_size
)
......@@ -1371,17 +1371,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
)
def check_decode_mem(self, buf_multiplier=1):
tokens_required = (
num_tokens = (
self.new_page_count_next_decode()
* buf_multiplier
* self.token_to_kv_pool_allocator.page_size
)
if self.token_to_kv_pool_allocator.available_size() >= tokens_required:
return True
self.tree_cache.evict(tokens_required)
return self.token_to_kv_pool_allocator.available_size() >= tokens_required
self._evict_tree_cache_if_needed(num_tokens)
return self._is_available_size_sufficient(num_tokens)
def retract_decode(self, server_args: ServerArgs):
"""Retract the decoding requests when there is not enough memory."""
......@@ -1414,19 +1411,38 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
)
def _get_available_size():
if self.is_hybrid:
return min(
self.token_to_kv_pool_allocator.full_available_size(),
self.token_to_kv_pool_allocator.swa_available_size(),
)
else:
return self.token_to_kv_pool_allocator.available_size()
retracted_reqs = []
seq_lens_cpu = self.seq_lens.cpu().numpy()
first_iter = True
while (
self.token_to_kv_pool_allocator.available_size()
< get_required_tokens(len(sorted_indices))
_get_available_size() < get_required_tokens(len(sorted_indices))
or first_iter
):
if len(sorted_indices) == 1:
# Corner case: only one request left
assert (
self.token_to_kv_pool_allocator.available_size() > 0
), "No space left for only one request"
if self.is_hybrid:
full_available_size = (
self.token_to_kv_pool_allocator.full_available_size()
)
swa_available_size = (
self.token_to_kv_pool_allocator.swa_available_size()
)
assert (
full_available_size > 0 and swa_available_size > 0
), f"No space left for only one request in SWA mode {full_available_size=}, {swa_available_size=}"
else:
assert (
self.token_to_kv_pool_allocator.available_size() > 0
), f"No space left for only one request, {self.token_to_kv_pool_allocator.available_size()=}"
break
first_iter = False
......@@ -1458,15 +1474,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.req_to_token_pool.free(req.req_pool_idx)
# release the last node
self.tree_cache.dec_lock_ref(req.last_node)
if self.is_hybrid:
self.tree_cache.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
else:
self.tree_cache.dec_lock_ref(req.last_node)
# NOTE(lsyin): we should use the newly evictable memory instantly.
residual_size = (
len(sorted_indices) * global_config.retract_decode_steps
- self.token_to_kv_pool_allocator.available_size()
)
residual_size = max(0, residual_size)
self.tree_cache.evict(residual_size)
num_tokens = len(sorted_indices) * global_config.retract_decode_steps
self._evict_tree_cache_if_needed(num_tokens)
req.reset_for_retract()
......@@ -1559,7 +1574,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# free memory
if isinstance(self.tree_cache, SWAChunkCache):
for req in self.reqs:
self.tree_cache.evict(
self.tree_cache.evict_swa(
req, req.seqlen - 1, self.model_config.attention_chunk_size
)
......@@ -1778,6 +1793,53 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
is_extend_in_batch=self.is_extend_in_batch,
)
def _evict_tree_cache_if_needed(
self,
num_tokens: int,
) -> None:
if isinstance(self.tree_cache, SWAChunkCache):
return
if self.is_hybrid:
full_available_size = self.token_to_kv_pool_allocator.full_available_size()
swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()
if full_available_size < num_tokens or swa_available_size < num_tokens:
if self.tree_cache is not None:
full_num_tokens = max(0, num_tokens - full_available_size)
swa_num_tokens = max(0, num_tokens - swa_available_size)
self.tree_cache.evict(full_num_tokens, swa_num_tokens)
else:
if self.token_to_kv_pool_allocator.available_size() < num_tokens:
if self.tree_cache is not None:
self.tree_cache.evict(num_tokens)
def _is_available_size_sufficient(self, num_tokens: int) -> bool:
if self.is_hybrid:
return (
self.token_to_kv_pool_allocator.full_available_size() >= num_tokens
and self.token_to_kv_pool_allocator.swa_available_size() >= num_tokens
)
else:
return self.token_to_kv_pool_allocator.available_size() >= num_tokens
def _available_and_evictable_str(self) -> str:
if self.is_hybrid:
full_available_size = self.token_to_kv_pool_allocator.full_available_size()
swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()
full_evictable_size = self.tree_cache.full_evictable_size()
swa_evictable_size = self.tree_cache.swa_evictable_size()
return (
f"Available full tokens: {full_available_size + full_evictable_size} ({full_available_size=} + {full_evictable_size=})\n"
f"Available swa tokens: {swa_available_size + swa_evictable_size} ({swa_available_size=} + {swa_evictable_size=})\n"
f"Full LRU list evictable size: {self.tree_cache.full_lru_list_evictable_size()}\n"
f"SWA LRU list evictable size: {self.tree_cache.swa_lru_list_evictable_size()}\n"
)
else:
available_size = self.token_to_kv_pool_allocator.available_size()
evictable_size = self.tree_cache.evictable_size()
return f"Available tokens: {available_size + evictable_size} ({available_size=} + {evictable_size=})\n"
def __str__(self):
return (
f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, "
......
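A worked example of the hybrid eviction split in _evict_tree_cache_if_needed above; the numbers are illustrative:

num_tokens = 100
full_available_size, swa_available_size = 120, 40           # hypothetical pool states
full_num_tokens = max(0, num_tokens - full_available_size)  # 0: the full pool already has room
swa_num_tokens = max(0, num_tokens - swa_available_size)    # 60: only the SWA pool must evict
# -> tree_cache.evict(0, 60): the SWA radix cache frees at least 60 swa-pool tokens.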
......@@ -25,6 +25,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union
import torch
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
......@@ -311,21 +312,43 @@ class PrefillAdder:
]
)
self.is_hybrid = isinstance(
self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator
)
@property
def rem_total_tokens(self):
return (
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.evictable_size()
- self.rem_total_token_offset
)
if self.is_hybrid:
available_and_evictable = min(
self.token_to_kv_pool_allocator.full_available_size()
+ self.tree_cache.full_evictable_size(),
self.token_to_kv_pool_allocator.swa_available_size()
+ self.tree_cache.swa_evictable_size(),
)
else:
available_and_evictable = (
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.evictable_size()
)
return available_and_evictable - self.rem_total_token_offset
@property
def cur_rem_tokens(self):
return (
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.evictable_size()
- self.cur_rem_token_offset
)
if self.is_hybrid:
available_and_evictable = min(
self.token_to_kv_pool_allocator.full_available_size()
+ self.tree_cache.full_evictable_size(),
self.token_to_kv_pool_allocator.swa_available_size()
+ self.tree_cache.swa_evictable_size(),
)
else:
available_and_evictable = (
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.evictable_size()
)
return available_and_evictable - self.cur_rem_token_offset
def ceil_paged_tokens(self, tokens: int) -> int:
return -(-tokens // self.page_size) * self.page_size
......@@ -376,11 +399,18 @@ class PrefillAdder:
@contextmanager
def _lock_node(self, last_node: TreeNode):
try:
self.tree_cache.inc_lock_ref(last_node)
yield None
finally:
self.tree_cache.dec_lock_ref(last_node)
if self.is_hybrid:
try:
swa_uuid_for_lock = self.tree_cache.inc_lock_ref(last_node)
yield None
finally:
self.tree_cache.dec_lock_ref(last_node, swa_uuid_for_lock)
else:
try:
self.tree_cache.inc_lock_ref(last_node)
yield None
finally:
self.tree_cache.dec_lock_ref(last_node)
def add_one_req_ignore_eos(self, req: Req, has_chunked_req: bool):
# Early exit if not enough tokens for the input tokens
......@@ -422,16 +452,19 @@ class PrefillAdder:
else:
add_req_state(req, insert_sort=True)
cur_rem_tokens = self.cur_rem_tokens - len(req.origin_input_ids)
tokens_freed = 0
for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
# tokens_left gives a conservative calculation as the last token is not stored
bs = len(self.req_states) - i
min_free_tokens = cur_rem_tokens + tokens_freed - tokens_left * bs
# reserve tokens for corner cases
if min_free_tokens <= IGNORE_EOS_RESERVE_TOKENS * bs:
return AddReqResult.NO_TOKEN
tokens_freed += tokens_occupied
if not self.is_hybrid:
# Skip this logic for SWA. SWA uses different memory management, and
# this mechanism underestimates the memory usage.
cur_rem_tokens = self.cur_rem_tokens - len(req.origin_input_ids)
tokens_freed = 0
for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
# tokens_left gives a conservative calculation as the last token is not stored
bs = len(self.req_states) - i
min_free_tokens = cur_rem_tokens + tokens_freed - tokens_left * bs
# reserve tokens for corner cases
if min_free_tokens <= IGNORE_EOS_RESERVE_TOKENS * bs:
return AddReqResult.NO_TOKEN
tokens_freed += tokens_occupied
if (
self.rem_chunk_tokens is None # chunked prefill is disabled
......@@ -499,7 +532,11 @@ class PrefillAdder:
if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens:
# Non-chunked prefill
self.can_run_list.append(req)
self.tree_cache.inc_lock_ref(req.last_node)
if self.is_hybrid:
swa_uuid_for_lock = self.tree_cache.inc_lock_ref(req.last_node)
req.swa_uuid_for_lock = swa_uuid_for_lock
else:
self.tree_cache.inc_lock_ref(req.last_node)
self._update_prefill_budget(
prefix_len,
input_tokens,
......@@ -520,7 +557,11 @@ class PrefillAdder:
self.can_run_list.append(req)
self.new_chunked_req = req
self.tree_cache.inc_lock_ref(req.last_node)
if self.is_hybrid:
swa_uuid_for_lock = self.tree_cache.inc_lock_ref(req.last_node)
req.swa_uuid_for_lock = swa_uuid_for_lock
else:
self.tree_cache.inc_lock_ref(req.last_node)
self._update_prefill_budget(prefix_len, trunc_len, 0)
return self.budget_state()
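The hybrid branches above thread the uuid returned by SWARadixCache.inc_lock_ref back into dec_lock_ref. A condensed usage sketch of that contract follows; the helper names are illustrative, and tree_cache and req stand in for the objects used in this file:

def lock_for_request(tree_cache, req):
    # Locks full refs up to the root and swa refs up to the sliding-window boundary.
    req.swa_uuid_for_lock = tree_cache.inc_lock_ref(req.last_node)

def unlock_for_request(tree_cache, req):
    # The same uuid must be passed back so the swa locks are released over the same span.
    tree_cache.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)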
......@@ -129,10 +129,10 @@ from sglang.srt.managers.session_controller import Session
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
from sglang.srt.managers.utils import validate_input_length
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
from sglang.srt.reasoning_parser import ReasoningParser
......@@ -390,6 +390,14 @@ class Scheduler(
global_server_args_dict.update(worker_global_server_args_dict)
set_random_seed(self.random_seed)
# Hybrid
self.is_hybrid = self.tp_worker.is_hybrid
if self.is_hybrid:
self.sliding_window_size = self.tp_worker.sliding_window_size
self.full_tokens_per_layer, self.swa_tokens_per_layer = (
self.tp_worker.get_tokens_per_layer_info()
)
# Print debug info
if tp_rank == 0:
avail_mem = get_available_gpu_memory(
......@@ -570,7 +578,7 @@ class Scheduler(
server_args.chunked_prefill_size is not None
and server_args.disable_radix_cache
):
if self.model_config.is_hybrid:
if self.is_hybrid:
ChunkCacheClass = SWAChunkCache
else:
ChunkCacheClass = ChunkCache
......@@ -603,6 +611,17 @@ class Scheduler(
self.tp_worker.register_hicache_layer_transfer_counter(
self.tree_cache.cache_controller.layer_done_counter
)
elif self.is_hybrid:
assert (
self.server_args.disaggregation_mode == "null"
), "Hybrid mode does not support disaggregation yet"
self.tree_cache = SWARadixCache(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
sliding_window_size=self.sliding_window_size,
page_size=self.page_size,
disable=server_args.disable_radix_cache,
)
else:
self.tree_cache = RadixCache(
......@@ -774,6 +793,7 @@ class Scheduler(
else:
# When the server is idle, do self-check and re-init some states
self.check_memory()
self.check_tree_cache()
self.new_token_ratio = self.init_new_token_ratio
self.maybe_sleep_on_idle()
......@@ -819,6 +839,7 @@ class Scheduler(
elif batch is None:
# When the server is idle, do self-check and re-init some states
self.check_memory()
self.check_tree_cache()
self.new_token_ratio = self.init_new_token_ratio
self.maybe_sleep_on_idle()
......@@ -955,6 +976,7 @@ class Scheduler(
# When the server is idle, self-check and re-init some states
if server_is_idle:
self.check_memory()
self.check_tree_cache()
self.new_token_ratio = self.init_new_token_ratio
self.maybe_sleep_on_idle()
......@@ -1306,9 +1328,26 @@ class Scheduler(
self.last_input_throughput = self.last_prefill_tokens / gap_latency
self.last_prefill_tokens = adder.log_input_tokens
usage_msg, num_used = self.token_to_kv_pool_allocator.log_usage(
self.tree_cache.evictable_size()
)
if self.is_hybrid:
(
full_num_used,
swa_num_used,
full_token_usage,
swa_token_usage,
_,
_,
_,
_,
) = self._get_swa_token_info()
num_used = max(full_num_used, swa_num_used)
token_usage = max(full_token_usage, swa_token_usage)
token_msg = (
f"full token usage: {full_token_usage:.2f}, "
f"swa token usage: {swa_token_usage:.2f}, "
)
else:
num_used, token_usage, _, _ = self._get_token_info()
token_msg = f"token usage: {token_usage:.2f}, "
num_new_seq = len(can_run_list)
f = (
......@@ -1316,7 +1355,7 @@ class Scheduler(
f"#new-seq: {num_new_seq}, "
f"#new-token: {adder.log_input_tokens}, "
f"#cached-token: {adder.log_hit_tokens}, "
f"{usage_msg}"
f"{token_msg}"
)
if self.disaggregation_mode == DisaggregationMode.PREFILL:
......@@ -1338,7 +1377,7 @@ class Scheduler(
)
self.stats.num_running_reqs = running_bs
self.stats.num_used_tokens = num_used
self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
self.stats.token_usage = round(token_usage, 2)
self.stats.num_queue_reqs = len(self.waiting_queue)
self.stats.cache_hit_rate = cache_hit_rate
......@@ -1361,16 +1400,35 @@ class Scheduler(
self.last_gen_throughput = self.num_generated_tokens / gap_latency
self.num_generated_tokens = 0
num_running_reqs = len(batch.reqs)
usage_msg, num_used = self.token_to_kv_pool_allocator.log_usage(
self.tree_cache.evictable_size()
)
if self.is_hybrid:
(
full_num_used,
swa_num_used,
full_token_usage,
swa_token_usage,
_,
_,
_,
_,
) = self._get_swa_token_info()
num_used = max(full_num_used, swa_num_used)
token_usage = max(full_token_usage, swa_token_usage)
token_msg = (
f"#full token: {full_num_used}, "
f"full token usage: {full_token_usage:.2f}, "
f"#swa token: {swa_num_used}, "
f"swa token usage: {swa_token_usage:.2f}, "
)
else:
num_used, token_usage, _, _ = self._get_token_info()
token_msg = f"#token: {num_used}, " f"token usage: {token_usage:.2f}, "
if RECORD_STEP_TIME:
self.step_time_dict[num_running_reqs].append(
gap_latency / self.server_args.decode_log_interval
)
msg = f"Decode batch. " f"#running-req: {num_running_reqs}, " f"{usage_msg}"
msg = f"Decode batch. #running-req: {num_running_reqs}, {token_msg}"
if self.spec_algorithm.is_none():
spec_accept_length = 0
......@@ -1398,7 +1456,7 @@ class Scheduler(
if self.enable_metrics:
self.stats.num_running_reqs = num_running_reqs
self.stats.num_used_tokens = num_used
self.stats.token_usage = num_used / self.max_total_num_tokens
self.stats.token_usage = round(token_usage, 2)
self.stats.cache_hit_rate = 0.0
self.stats.gen_throughput = self.last_gen_throughput
self.stats.num_queue_reqs = len(self.waiting_queue)
......@@ -1409,24 +1467,34 @@ class Scheduler(
self._publish_kv_events()
def check_memory(self):
if isinstance(self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator):
available_token_size = self.token_to_kv_pool_allocator.full_available_size()
if self.is_hybrid:
(
full_num_used,
swa_num_used,
_,
_,
full_available_size,
full_evictable_size,
swa_available_size,
swa_evictable_size,
) = self._get_swa_token_info()
memory_leak = full_num_used != 0 or swa_num_used != 0
token_msg = (
f"{self.full_tokens_per_layer=}, {full_available_size=}, {full_evictable_size=}, {self.tree_cache.full_protected_size()=}\n"
f"{self.swa_tokens_per_layer=}, {swa_available_size=}, {swa_evictable_size=}, {self.tree_cache.swa_protected_size()=}\n"
)
else:
available_token_size = self.token_to_kv_pool_allocator.available_size()
available_size = available_token_size + self.tree_cache.evictable_size()
protected_size = self.tree_cache.protected_size()
memory_leak = available_size != (
self.max_total_num_tokens
if not self.enable_hierarchical_cache
else self.max_total_num_tokens - protected_size
)
if memory_leak:
msg = (
"token_to_kv_pool_allocator memory leak detected! "
f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
f"{available_token_size=}\n"
f"{self.tree_cache.evictable_size()=}\n"
_, _, available_size, evictable_size = self._get_token_info()
protected_size = self.tree_cache.protected_size()
memory_leak = (available_size + evictable_size) != (
self.max_total_num_tokens
if not self.enable_hierarchical_cache
else self.max_total_num_tokens - protected_size
)
token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"
if memory_leak:
msg = "token_to_kv_pool_allocator memory leak detected! " f"{token_msg}"
raise ValueError(msg)
if self.disaggregation_mode == DisaggregationMode.DECODE:
......@@ -1450,20 +1518,66 @@ class Scheduler(
and time.perf_counter() > self.metrics_collector.last_log_time + 30
):
# During idle time, also collect metrics every 30 seconds.
num_used = self.max_total_num_tokens - (
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.evictable_size()
)
if self.is_hybrid:
(
full_num_used,
swa_num_used,
full_token_usage,
swa_token_usage,
_,
_,
_,
_,
) = self._get_swa_token_info()
num_used = max(full_num_used, swa_num_used)
token_usage = max(full_token_usage, swa_token_usage)
else:
num_used, token_usage, _, _ = self._get_token_info()
num_running_reqs = len(self.running_batch.reqs)
self.stats.num_running_reqs = num_running_reqs
self.stats.num_used_tokens = num_used
self.stats.token_usage = num_used / self.max_total_num_tokens
self.stats.token_usage = round(token_usage, 2)
self.stats.gen_throughput = 0
self.stats.num_queue_reqs = len(self.waiting_queue)
self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
self.metrics_collector.log_stats(self.stats)
self._publish_kv_events()
def check_tree_cache(self):
if self.is_hybrid and isinstance(self.tree_cache, SWARadixCache):
self.tree_cache.sanity_check()
def _get_token_info(self):
available_size = self.token_to_kv_pool_allocator.available_size()
evictable_size = self.tree_cache.evictable_size()
num_used = self.max_total_num_tokens - (available_size + evictable_size)
token_usage = num_used / self.max_total_num_tokens
return num_used, token_usage, available_size, evictable_size
def _get_swa_token_info(self):
full_available_size = self.token_to_kv_pool_allocator.full_available_size()
full_evictable_size = self.tree_cache.full_evictable_size()
swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()
swa_evictable_size = self.tree_cache.swa_evictable_size()
full_num_used = self.full_tokens_per_layer - (
full_available_size + full_evictable_size
)
swa_num_used = self.swa_tokens_per_layer - (
swa_available_size + swa_evictable_size
)
full_token_usage = full_num_used / self.full_tokens_per_layer
swa_token_usage = swa_num_used / self.swa_tokens_per_layer
return (
full_num_used,
swa_num_used,
full_token_usage,
swa_token_usage,
full_available_size,
full_evictable_size,
swa_available_size,
swa_evictable_size,
)
def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
# Merge the prefill batch into the running batch
chunked_req_to_exclude = set()
......@@ -2042,11 +2156,30 @@ class Scheduler(
if not disable_request_logging():
# Print batch size and memory pool info to check whether there are de-sync issues.
if self.is_hybrid:
(
_,
_,
_,
_,
full_available_size,
full_evictable_size,
swa_available_size,
swa_evictable_size,
) = self._get_swa_token_info()
info_msg = (
f"{full_available_size=}, "
f"{full_evictable_size=}, "
f"{swa_available_size=}, "
f"{swa_evictable_size=}, "
)
else:
_, _, available_size, evictable_size = self._get_token_info()
info_msg = f"{available_size=}, " f"{evictable_size=}, "
logger.error(
f"{self.cur_batch.batch_size()=}, "
f"{self.cur_batch.reqs=}, "
f"{self.token_to_kv_pool_allocator.available_size()=}, "
f"{self.tree_cache.evictable_size()=}, "
f"{info_msg}"
)
pyspy_dump_schedulers()
......@@ -2101,11 +2234,24 @@ class Scheduler(
def get_load(self):
# TODO(lsyin): use dynamically maintained num_waiting_tokens
load = (
self.max_total_num_tokens
- self.token_to_kv_pool_allocator.available_size()
- self.tree_cache.evictable_size()
)
if self.is_hybrid:
load_full = (
self.full_tokens_per_layer
- self.token_to_kv_pool_allocator.full_available_size()
- self.tree_cache.full_evictable_size()
)
load_swa = (
self.swa_tokens_per_layer
- self.token_to_kv_pool_allocator.swa_available_size()
- self.tree_cache.swa_evictable_size()
)
load = max(load_full, load_swa)
else:
load = (
self.max_total_num_tokens
- self.token_to_kv_pool_allocator.available_size()
- self.tree_cache.evictable_size()
)
load += sum(len(req.origin_input_ids) for req in self.waiting_queue)
if self.disaggregation_mode == DisaggregationMode.PREFILL:
load += sum(
......
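A worked example of the hybrid token accounting that _get_swa_token_info feeds into the logging, metrics, and get_load paths above; the numbers are illustrative:

full_tokens_per_layer, swa_tokens_per_layer = 10_000, 4_000
full_available_size, full_evictable_size = 7_000, 1_000
swa_available_size, swa_evictable_size = 1_000, 500
full_num_used = full_tokens_per_layer - (full_available_size + full_evictable_size)  # 2_000
swa_num_used = swa_tokens_per_layer - (swa_available_size + swa_evictable_size)      # 2_500
full_token_usage = full_num_used / full_tokens_per_layer  # 0.20
swa_token_usage = swa_num_used / swa_tokens_per_layer     # 0.625
token_usage = max(full_token_usage, swa_token_usage)      # 0.625 is what gets reported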
......@@ -174,6 +174,20 @@ class TpModelWorker:
self.model_runner.token_to_kv_pool.size,
)
@property
def sliding_window_size(self) -> Optional[int]:
return self.model_runner.sliding_window_size
@property
def is_hybrid(self) -> bool:
return self.model_runner.is_hybrid is not None
def get_tokens_per_layer_info(self):
return (
self.model_runner.full_max_total_num_tokens,
self.model_runner.swa_max_total_num_tokens,
)
def get_pad_input_ids_func(self):
return getattr(self.model_runner.model, "pad_input_ids", None)
......
......@@ -102,6 +102,17 @@ class TpModelWorkerClient:
def get_worker_info(self):
return self.worker.get_worker_info()
def get_tokens_per_layer_info(self):
return self.worker.get_tokens_per_layer_info()
@property
def sliding_window_size(self) -> Optional[int]:
return self.worker.sliding_window_size
@property
def is_hybrid(self) -> bool:
return self.worker.is_hybrid
def get_pad_input_ids_func(self):
return self.worker.get_pad_input_ids_func()
......
......@@ -57,11 +57,6 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
def debug_print(self) -> str:
return ""
def log_usage(self, evictable_size: int = 0):
num_used = self.size - (self.available_size() + evictable_size)
msg = f"#token: {num_used}, token usage: {num_used / self.size:.2f}, "
return msg, num_used
def available_size(self):
return len(self.free_pages) * self.page_size
......@@ -190,7 +185,7 @@ class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
self._kvcache.full_to_swa_index_mapping = self.full_to_swa_index_mapping
def available_size(self):
return min(self.full_available_size(), self.swa_available_size())
raise NotImplementedError()
def full_available_size(self):
return self.full_attn_allocator.available_size()
......@@ -214,16 +209,6 @@ class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
)
return msg
def log_usage(self, swa_evictable_size: int = 0, full_evictable_size: int = 0):
used_full = self.size_full - (self.full_available_size() + full_evictable_size)
used_swa = self.size_swa - (self.swa_available_size() + swa_evictable_size)
msg = (
f"#token: full={used_full}, swa={used_swa}, "
f"token usage: full={used_full / self.size_full:.2f}, "
f"swa={used_swa / self.size_swa:.2f}, "
)
return msg, used_full
def get_kvcache(self):
return self._kvcache
......
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, List, NamedTuple, Tuple
from typing import TYPE_CHECKING, Any, List, NamedTuple, Optional, Tuple
import torch
......@@ -56,15 +56,27 @@ class BasePrefixCache(ABC):
pass
@abstractmethod
def dec_lock_ref(self, node: Any):
def dec_lock_ref(self, node: Any, swa_uuid_for_lock: Optional[int] = None):
pass
def evictable_size(self):
return 0
def full_evictable_size(self):
return 0
def swa_evictable_size(self):
return 0
def protected_size(self):
return 0
def full_protected_size(self):
return 0
def swa_protected_size(self):
return 0
def total_size(self):
raise NotImplementedError()
......
......@@ -61,7 +61,7 @@ class ChunkCache(BasePrefixCache):
def inc_lock_ref(self, node: Any):
return 0
def dec_lock_ref(self, node: Any):
def dec_lock_ref(self, node: Any, swa_uuid_for_lock: Optional[int] = None):
return 0
def pretty_print(self):
......@@ -80,7 +80,7 @@ class SWAChunkCache(ChunkCache):
super().__init__(req_to_token_pool, token_to_kv_pool_allocator, page_size)
assert isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator)
def evict(
def evict_swa(
self,
req: Req,
prelen: int,
......@@ -95,3 +95,6 @@ class SWAChunkCache(ChunkCache):
]
self.token_to_kv_pool_allocator.free_swa(free_slots)
req.evicted_seqlen_local = new_evicted_seqlen_local
def evict(self, num_tokens: int):
pass
from __future__ import annotations
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""
The radix tree data structure for managing the hybrid (full and SWA) KV cache.
"""
import heapq
import time
from collections import defaultdict
from functools import partial
from typing import TYPE_CHECKING, List, Optional, Tuple
import torch
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req
import logging
logger = logging.getLogger(__name__)
class TreeNode:
counter = 0
swa_uuid_counter = 1
def __init__(self, id: Optional[int] = None):
self.children = defaultdict(TreeNode)
self.parent: TreeNode = None
self.key: List[int] = None
self.value: Optional[torch.Tensor] = None
# swa_tombstone is used to indicate the kv indices have been freed for swa layers
self.swa_tombstone = False
# invariant: for any node, if swa_lock_ref is locked, full_lock_ref must be locked;
# if full_lock_ref is locked, swa_lock_ref doesn't need to be locked. So,
# full_lock_ref is always >= swa_lock_ref.
self.full_lock_ref = 0
self.swa_lock_ref = 0
# last access time is only used for sanity check. LRU is maintained by the lru list.
self.last_access_time = time.monotonic()
self.hit_count = 0
# indicating the node is loading KV cache from host
self.loading = False
# store the host indices of KV cache
self.host_value = None
# for lru list, invariant:
# 1. prev has greater last_access_time
# 2. next has smaller last_access_time
self.prev = None
self.next = None
self.swa_prev = None
self.swa_next = None
self.id = TreeNode.counter if id is None else id
TreeNode.counter += 1
self.swa_uuid = None
@property
def evicted(self):
return self.value is None
@property
def backuped(self):
return self.host_value is not None
def __lt__(self, other: "TreeNode"):
return self.last_access_time < other.last_access_time
def _key_match_page_size1(key0: List, key1: List):
i = 0
for k0, k1 in zip(key0, key1):
if k0 != k1:
break
i += 1
return i
def _key_match_paged(key0: List, key1: List, page_size: int):
min_len = min(len(key0), len(key1))
i = 0
while i < min_len:
if key0[i : i + page_size] != key1[i : i + page_size]:
break
i += page_size
return i
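# Example (illustrative): with page_size = 1 the match stops at the first mismatching
# token, while the paged variant only counts whole matching pages:
#   _key_match_page_size1([1, 2, 3, 4], [1, 2, 3, 9])          -> 3
#   _key_match_paged([1, 2, 3, 4], [1, 2, 3, 9], page_size=2)  -> 2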
def gen_swa_uuid() -> int:
TreeNode.swa_uuid_counter += 1
return TreeNode.swa_uuid_counter
class LRUList:
def __init__(self, swa: bool = False):
self.swa = swa
if self.swa:
self.prv = "swa_prev"
self.nxt = "swa_next"
self.lock_ref = "swa_lock_ref"
else:
self.prv = "prev"
self.nxt = "next"
self.lock_ref = "full_lock_ref"
# Initialize dummy head and tail nodes
self.head = TreeNode() # Most recently used side
self.tail = TreeNode() # Least recently used side
setattr(self.head, self.nxt, self.tail) # self.head.next = self.tail
setattr(self.tail, self.prv, self.head) # self.tail.prev = self.head
self.cache = {}
def _add_node(self, node):
"""Helper to add node right after head (most recently used)"""
self._add_node_after(self.head, node)
def _add_node_after(self, old_node, new_node):
"""Helper to add node right after old_node"""
setattr(new_node, self.prv, old_node) # new_node.prev = old_node
setattr(
new_node, self.nxt, getattr(old_node, self.nxt)
) # new_node.next = old_node.next
setattr(
getattr(old_node, self.nxt), self.prv, new_node
) # old_node.next.prev = new_node
setattr(old_node, self.nxt, new_node) # old_node.next = new_node
def _remove_node(self, node):
"""Helper to remove node from linked list"""
setattr(
getattr(node, self.prv), self.nxt, getattr(node, self.nxt)
) # node.prev.next = node.next
setattr(
getattr(node, self.nxt), self.prv, getattr(node, self.prv)
) # node.next.prev = node.prev
def _get_lru(self) -> Optional[TreeNode]:
"""
Get the least recently used node
"""
if len(self.cache) == 0:
return None
return getattr(self.tail, self.prv)
def reset_node_mru(self, node):
"""
Move an (existing) node to the most recently used position
"""
assert node.id in self.cache, f"Resetting node {node.id=} not in lru list"
assert (
not self.swa or not node.swa_tombstone
), f"Resetting swa tombstone node in swa lru list: {node.id=}"
self._remove_node(node)
self._add_node(node)
def reset_node_and_parents_mru(self, node, root_node):
"""
Move an (existing) node and its parents to the most recently used position. A child
node is more recently used than its parent node.
"""
prev_node = self.head
while node != root_node:
# for swa lru list, only reset non-tombstone nodes
if not self.swa or not node.swa_tombstone:
assert (
node.id in self.cache
), f"Resetting node {node.id=} not in lru list when resetting node and parents mru"
self._remove_node(node)
self._add_node_after(prev_node, node)
prev_node = node
node = node.parent
def insert_mru(self, node):
"""
Insert a (new) node as most recently used
"""
assert (
not self.swa or not node.swa_tombstone
), f"Inserting swa tombstone node in swa lru list: {node.id=}"
assert (
node.id not in self.cache
), f"Inserting node {node.id=} already in lru list, existing node: {self.cache[node.id].id=}"
self.cache[node.id] = node
self._add_node(node)
def remove_node(self, node: TreeNode):
"""
Remove node from lru list
"""
assert node.id in self.cache, f"Removing node {node.id=} not in lru list"
assert (
not self.swa or not node.swa_tombstone
), f"Removing swa tombstone node from swa lru list: {node.id=}"
del self.cache[node.id]
self._remove_node(node)
def get_lru_no_lock(self) -> Optional[TreeNode]:
"""
Get the least recently used node that is not locked
"""
return self.get_prev_no_lock(self.tail, check_id=False)
def get_leaf_lru_no_lock(self) -> Optional[TreeNode]:
"""
Get the least recently used leaf node that is not locked
"""
return self.get_prev_leaf_no_lock(self.tail, check_id=False)
def get_prev_no_lock(
self, node: TreeNode, check_id: bool = True
) -> Optional[TreeNode]:
"""
Get the previous (i.e. more recently used) node that is not locked
"""
if check_id:
assert (
node.id in self.cache
), f"Getting prev of node {node.id=} not in lru list"
x = getattr(node, self.prv) # x = node.prev
while getattr(x, self.lock_ref) > 0:
x = getattr(x, self.prv) # x = x.prev
# if x is the head, it means there is no node in the lru list without lock
if x == self.head:
return None
return x
def get_prev_leaf_no_lock(self, node: TreeNode, check_id: bool = True):
"""
Get the previous (i.e. more recently used) leaf node that is not locked
"""
if check_id:
assert (
node.id in self.cache
), f"Getting prev of node {node.id=} not in lru list"
x = getattr(node, self.prv) # x = node.prev
while getattr(x, self.lock_ref) > 0 or len(x.children) > 0:
x = getattr(x, self.prv) # x = x.prev
# if x is the head, it means there is no leaf node in the lru list without lock
if x == self.head:
return None
return x
def in_list(self, node: Optional[TreeNode]):
"""
Check if the node is in the lru list
"""
if not node:
return False
return node.id in self.cache
# Note: this is expensive, only use for debug
def sanity_check_evictable_size(self):
"""
Check the evictable size (i.e. the size of the nodes that are not locked)
"""
node = self.get_lru_no_lock()
evictable_size = 0
while self.in_list(node):
evictable_size += len(node.value)
node = self.get_prev_no_lock(node)
return evictable_size
# Note: this is expensive, only use for debug or idle check
def sanity_check(self, tree_cache: "SWARadixCache"):
"""
Check if the lru list is valid by collecting all nodes from the tree, heapifying them by
last_access_time, and verifying that the heap order matches the lru list order.
"""
try:
if self.swa:
nodes = tree_cache._collect_nontombstone_nodes()
else:
nodes = tree_cache._collect_all_nodes()
total_nodes = len(nodes)
total_lru_plus_1 = len(self.cache) + 1
# heapify based on last_access_time
heapq.heapify(nodes)
# the root node is not in the lru list
assert (
len(nodes) == len(self.cache) + 1
), f"len(nodes): {len(nodes)} != len(self.cache) + 1: {len(self.cache) + 1}"
x_lru = self._get_lru()
while len(nodes):
x = heapq.heappop(nodes)
if x == tree_cache.root_node:
# root node is not in the lru list
continue
assert (
x == x_lru
), f"Incorrect LRU list, {self.swa=}, x: {x.id=} != x_lru: {x_lru.id=}"
assert (
x_lru.full_lock_ref == 0
), f"x_lru should not be locked when idle, {x_lru.full_lock_ref=}, {x_lru.swa_uuid=}, {x_lru.id=}"
assert (
x_lru.swa_lock_ref == 0
), f"x_lru should not be locked when idle, {x_lru.swa_lock_ref=}, {x_lru.swa_uuid=}, {x_lru.id=}"
x_lru = getattr(x, self.prv)
if self.swa:
evictable_size = tree_cache.swa_evictable_size()
lru_list_evictable_size = tree_cache.swa_lru_list_evictable_size()
else:
evictable_size = tree_cache.full_evictable_size()
lru_list_evictable_size = tree_cache.full_lru_list_evictable_size()
assert (
evictable_size == lru_list_evictable_size
), f"{self.swa=}, total nodes: {total_nodes}, total lru plus 1: {total_lru_plus_1}, evictable size: {evictable_size} != lru list evictable size: {lru_list_evictable_size}"
except Exception as e:
msg = f"SWA Radix tree sanity check failed, ping @hanming-lu: {e}"
logger.error(msg)
raise Exception(msg)
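# Usage sketch (illustrative, not part of the module): the same TreeNode can sit in both
# LRU lists at once, since each list threads through its own prev/next pointer pair and
# consults its own lock-ref field:
#   full_lru, swa_lru = LRUList(swa=False), LRUList(swa=True)
#   node = TreeNode(); node.value = torch.arange(3)   # hypothetical KV indices
#   full_lru.insert_mru(node); swa_lru.insert_mru(node)
#   assert full_lru.in_list(node) and swa_lru.in_list(node)
#   assert full_lru.get_lru_no_lock() is node         # unlocked, hence evictable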
class SWARadixCache(BasePrefixCache):
def __init__(
self,
req_to_token_pool: ReqToTokenPool,
token_to_kv_pool_allocator: SWATokenToKVPoolAllocator,
sliding_window_size: int,
page_size: int,
disable: bool = False,
):
assert isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator)
self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
self.page_size = page_size
self.disable = disable
if self.token_to_kv_pool_allocator:
self.device = self.token_to_kv_pool_allocator.device
else:
self.device = torch.device("cpu")
if self.page_size == 1:
self.key_match_fn = _key_match_page_size1
self.get_child_key_fn = lambda key: key[0]
else:
self.key_match_fn = partial(_key_match_paged, page_size=page_size)
self.get_child_key_fn = lambda key: tuple(key[:page_size])
self.sliding_window_size = sliding_window_size
self.reset()
##### Public API #####
def reset(self) -> None:
self.root_node = TreeNode()
self.root_node.key = []
self.root_node.value = []
self.root_node.full_lock_ref = 1
self.root_node.swa_lock_ref = 1
self.full_evictable_size_ = 0
self.swa_evictable_size_ = 0
self.full_protected_size_ = 0
self.swa_protected_size_ = 0
# LRU lists are used to maintain the order of eviction of the nodes in the tree
self.full_lru_list = LRUList(swa=False)
self.swa_lru_list = LRUList(swa=True)
def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
"""Find the matching prefix from the radix tree.
Args:
key: A list of token IDs to find a matching prefix.
Returns:
A tuple of a tensor of matching prefix token IDs and
the last node that contains the prefix values. Note that
this API can modify the internal state of the Radix tree.
The last node may create a new child if the prefix is shorter
than the last node's value.
"""
if self.disable or len(key) == 0:
return MatchResult(
device_indices=torch.empty(
(0,),
dtype=torch.int64,
device=self.device,
),
last_device_node=self.root_node,
last_host_node=self.root_node,
)
if self.page_size != 1:
page_aligned_len = len(key) // self.page_size * self.page_size
key = key[:page_aligned_len]
value, last_node = self._match_prefix_helper(key)
if value:
value = torch.cat(value)
else:
value = torch.empty((0,), dtype=torch.int64, device=self.device)
return MatchResult(
device_indices=value,
last_device_node=last_node,
last_host_node=last_node,
)
def insert(self, key: List, value=None, prev_prefix_len: int = 0) -> int:
if self.disable:
return 0
if value is None:
value = [x for x in key]
return self._insert_helper(self.root_node, key, value, prev_prefix_len)
def cache_finished_req(self, req: Req) -> None:
"""Cache request when it finishes."""
if self.disable:
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx,
: len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0),
]
self.token_to_kv_pool_allocator.free(kv_indices)
self.req_to_token_pool.free(req.req_pool_idx)
return
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : len(token_ids)
]
if self.page_size != 1:
page_aligned_len = len(kv_indices) // self.page_size * self.page_size
page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
else:
page_aligned_len = len(kv_indices)
page_aligned_kv_indices = kv_indices.clone()
# Radix Cache takes one ref in memory pool
# insert the token_ids and kv_indices into the radix tree
# Note: the insert function already frees the overlapped kv_indices
new_prefix_len = self.insert(
token_ids[:page_aligned_len],
page_aligned_kv_indices,
len(req.prefix_indices),
)
# Remove the req slot and release the cache lock
self.req_to_token_pool.free(req.req_pool_idx)
self.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
def cache_unfinished_req(self, req: Req) -> None:
"""Cache request when it is unfinished."""
if self.disable:
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : len(req.fill_ids)
]
# `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
req.prefix_indices = kv_indices
return
token_ids = req.fill_ids
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : len(token_ids)
]
if self.page_size != 1:
page_aligned_len = len(kv_indices) // self.page_size * self.page_size
page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
else:
page_aligned_len = len(kv_indices)
page_aligned_kv_indices = kv_indices.clone()
page_aligned_token_ids = token_ids[:page_aligned_len]
# Radix Cache takes one ref in memory pool
# Note: the insert function already frees the overlapped kv_indices
new_prefix_len = self.insert(
page_aligned_token_ids, page_aligned_kv_indices, len(req.prefix_indices)
)
# The prefix indices could be updated, reuse them
new_indices, new_last_node, _, _ = self.match_prefix(page_aligned_token_ids)
assert len(req.prefix_indices) <= len(
new_indices
), f"{req.prefix_indices=}, {new_indices=}"
assert new_prefix_len <= len(new_indices), f"{new_prefix_len=}, {new_indices=}"
self.req_to_token_pool.write(
(req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
new_indices[len(req.prefix_indices) :],
)
self.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
swa_uuid_for_lock = self.inc_lock_ref(new_last_node)
# `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
if self.page_size != 1:
req.prefix_indices = torch.cat(
[new_indices, kv_indices[len(new_indices) :]]
)
else:
req.prefix_indices = new_indices
req.last_node = new_last_node
req.swa_uuid_for_lock = swa_uuid_for_lock
def pretty_print(self) -> None:
self._print_helper(self.root_node, 0)
total_size, total_swa_size = self._total_size_helper()
print(f"#full_tokens: {total_size}, #swa_tokens: {total_swa_size}")
def total_size(self) -> Tuple[int, int]:
return self._total_size_helper()
def evict(self, full_num_tokens: int, swa_num_tokens: int = 0) -> None:
if self.disable:
return
full_num_evicted = 0
swa_num_evicted = 0
if full_num_tokens > 0:
# get the least recently used leaf node that is not locked
x = self.full_lru_list.get_leaf_lru_no_lock()
while full_num_evicted < full_num_tokens and self.full_lru_list.in_list(x):
assert (
x != self.root_node
), f"root node should not exist in full lru list, {x.id=}"
assert x.full_lock_ref == 0, f"node is in use, {x.id=}"
# 1. free node kv indices, evict full and swa tokens
self.token_to_kv_pool_allocator.free(x.value)
full_num_evicted += len(x.value)
swa_num_evicted += len(x.value)
# 2. get the next leaf, update the lru lists
x_next = self.full_lru_list.get_prev_leaf_no_lock(x)
self.full_lru_list.remove_node(x)
self.swa_lru_list.remove_node(x)
# 3. delete the leaf node
self._delete_leaf(x)
# 4. Iteratively delete tombstone leaves to maintain the invariant that leaf nodes are not tombstones
x, leaf_full_num_evicted = self._iteratively_delete_tombstone_leaf(x)
full_num_evicted += leaf_full_num_evicted
# 5. if parent has no more children, it is a leaf. It is possible that this node is lru, so
# we need to get the first leaf node in the lru list
if len(x.parent.children) == 0:
x_next = self.full_lru_list.get_leaf_lru_no_lock()
x = x_next
if swa_num_evicted < swa_num_tokens:
# get the least recently used node that is not locked, doesn't have to be a leaf
x = self.swa_lru_list.get_lru_no_lock()
# evict lru nodes until swa_num_tokens is reached
while swa_num_evicted < swa_num_tokens and (self.swa_lru_list.in_list(x)):
assert not x.swa_tombstone, f"duplicate swa tombstone node, {x.id=}"
assert x != self.root_node, f"root node is not evictable, {x.id=}"
assert x.swa_lock_ref == 0, f"node is in use by swa kv indices, {x.id=}"
if len(x.children) > 0:
# 1. an internal node, free swa tokens.
self.token_to_kv_pool_allocator.free_swa(x.value)
swa_num_evicted += len(x.value)
# 2. get the next node, update the lru lists
x_next = self.swa_lru_list.get_prev_no_lock(x)
self.swa_lru_list.remove_node(x)
# 3. tombstone the node
self._tombstone_internal_node(x)
else:
assert (
x.full_lock_ref == 0
), f"leaf node with full lock must also have swa lock, {x.id=}"
# 1. a leaf node, free full and swa tokens
self.token_to_kv_pool_allocator.free(x.value)
full_num_evicted += len(x.value)
swa_num_evicted += len(x.value)
# 2. get the next node, update the lru lists
x_next = self.swa_lru_list.get_prev_no_lock(x)
self.full_lru_list.remove_node(x)
self.swa_lru_list.remove_node(x)
# 3. delete the leaf node
self._delete_leaf(x)
# 4. Iteratively delete tombstone leaves to maintain the invariant that leaf nodes are not tombstones
self._iteratively_delete_tombstone_leaf(x)
x = x_next
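# Summary of the two passes above: the full pass walks the full LRU list and frees unlocked
# LRU leaves from both pools (each freed leaf counts toward both budgets); the swa pass walks
# the swa LRU list and, for unlocked internal nodes, frees only their swa slots and tombstones
# them, while unlocked leaves are freed from both pools as in the full pass.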
def inc_lock_ref(self, node: TreeNode) -> Optional[int]:
"""
Increment the lock reference count for the node. Returns the swa_uuid_for_lock, which needs
to be passed to dec_lock_ref.
It locks the full_lock_ref for nodes between the [last node, root), exclusive.
It locks the swa_lock_ref for nodes between the [last node, swa_uuid_for_lock], inclusive.
"""
if self.disable:
return None
swa_lock_size = 0
swa_uuid_for_lock = None
while node != self.root_node:
# lock full from node to root
assert (
node.full_lock_ref >= 0
), f"inc_lock_ref on node with {node.full_lock_ref=}, {node.id=}"
if node.full_lock_ref == 0:
self.full_evictable_size_ -= len(node.value)
self.full_protected_size_ += len(node.value)
node.full_lock_ref += 1
# lock swa if we have not reached the sliding window size.
# When we reach the sliding window size, we will set the swa_uuid_for_lock.
# caller needs to pass the swa_uuid_for_lock to dec_lock_ref
if swa_lock_size < self.sliding_window_size:
assert (
not node.swa_tombstone
), f"inc_lock_swa on swa_tombstone node, {node.id=}"
if node.swa_lock_ref == 0:
self.swa_evictable_size_ -= len(node.value)
self.swa_protected_size_ += len(node.value)
node.swa_lock_ref += 1
swa_lock_size += len(node.value)
if swa_lock_size >= self.sliding_window_size:
if node.swa_uuid is None:
node.swa_uuid = gen_swa_uuid()
swa_uuid_for_lock = node.swa_uuid
node = node.parent
return swa_uuid_for_lock
def dec_lock_ref(self, node: TreeNode, swa_uuid_for_lock: Optional[int] = None):
"""
Decrement the lock reference count for the node.
It unlocks the full_lock_ref for nodes between the [last node, root), exclusive.
It unlocks the swa_lock_ref for nodes between the [last node, swa_uuid_for_lock], inclusive.
If swa_uuid_for_lock is None, it unlocks to the root, exclusive.
"""
if self.disable:
return
dec_lock_swa = True
while node != self.root_node:
assert (
node.full_lock_ref > 0
), f"dec_lock_ref on node with {node.full_lock_ref=}, {node.id=}"
if node.full_lock_ref == 1:
self.full_evictable_size_ += len(node.value)
self.full_protected_size_ -= len(node.value)
node.full_lock_ref -= 1
if dec_lock_swa:
assert (
not node.swa_tombstone
), f"dec_lock_ref on swa_tombstone node, {node.id=}"
assert (
node.swa_lock_ref > 0
), f"dec_lock_ref on node with {node.swa_lock_ref=}, {node.id=}"
if node.swa_lock_ref == 1:
self.swa_evictable_size_ += len(node.value)
self.swa_protected_size_ -= len(node.value)
node.swa_lock_ref -= 1
if swa_uuid_for_lock and node.swa_uuid == swa_uuid_for_lock:
dec_lock_swa = False
node = node.parent
def sanity_check(self):
self.full_lru_list.sanity_check(self)
self.swa_lru_list.sanity_check(self)
def evictable_size(self) -> Tuple[int, int]:
# Note: use full_evictable_size() and swa_evictable_size() instead.
raise NotImplementedError
def full_evictable_size(self) -> int:
return self.full_evictable_size_
def swa_evictable_size(self) -> int:
return self.swa_evictable_size_
# Note: this is expensive, only use for debug
def full_lru_list_evictable_size(self) -> int:
return self.full_lru_list.sanity_check_evictable_size()
# Note: this is expensive, only use for debug
def swa_lru_list_evictable_size(self) -> int:
return self.swa_lru_list.sanity_check_evictable_size()
def protected_size(self) -> Tuple[int, int]:
# Note: use full_protected_size() and swa_protected_size() instead.
raise NotImplementedError
def full_protected_size(self) -> int:
# protected size refers to the size of the full cache that is locked
return self.full_protected_size_
def swa_protected_size(self) -> int:
# protected size refers to the size of the swa cache that is locked
return self.swa_protected_size_
def all_values_flatten(self) -> torch.Tensor:
values = []
def _dfs_helper(node: TreeNode):
for _, child in node.children.items():
values.append(child.value)
_dfs_helper(child)
_dfs_helper(self.root_node)
return torch.cat(values)
##### Internal Helper Functions #####
def _match_prefix_helper(self, key: List) -> Tuple[List[torch.Tensor], TreeNode]:
"""
SWA prefix matching helper. It factors in the sliding window size so that
the matched node is guaranteed either 1. to be connected to the root without any swa
tombstone on the path, or 2. to have a number of matching tokens between it and the
last swa tombstone node that is greater than or equal to the sliding window size.
"""
node = self.root_node
child_key = self.get_child_key_fn(key)
value = []
# for path connected to root without tombstone, always match, so set to inf
match_len_since_tombstone = float("inf")
best_value_len = 0
best_last_node = node
while len(key) > 0 and child_key in node.children.keys():
child = node.children[child_key]
# update best_value_len and best_last_node if needed
if (
child.swa_tombstone
and match_len_since_tombstone >= self.sliding_window_size
):
best_value_len = len(value)
best_last_node = node
match_len_since_tombstone = 0
prefix_len = self.key_match_fn(child.key, key)
if prefix_len < len(child.key):
new_node = self._split_node(child.key, child, prefix_len)
value.append(new_node.value)
if not new_node.swa_tombstone:
match_len_since_tombstone += len(new_node.value)
node = new_node
break
else:
value.append(child.value)
if not child.swa_tombstone:
match_len_since_tombstone += len(child.value)
node = child
key = key[prefix_len:]
if len(key):
child_key = self.get_child_key_fn(key)
# handle best_value_len and best_last_node, for the case that last node is fully matched
if match_len_since_tombstone >= self.sliding_window_size:
best_value_len = len(value)
best_last_node = node
# update the access time for matched nodes, making nodes closer to the root the least
# recently used; this lets swa evict nodes closer to the root first
self.full_lru_list.reset_node_and_parents_mru(best_last_node, self.root_node)
self.swa_lru_list.reset_node_and_parents_mru(best_last_node, self.root_node)
# This last_access_time update is only for the sanity check; it can be removed after validation in production
cur_time = time.monotonic()
while node:
node.last_access_time = cur_time
cur_time -= 0.0001
node = node.parent
return value[:best_value_len], best_last_node
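# Worked example of the guarantee above (illustrative sizes): with
# sliding_window_size = 4 and a matched path root -> A (tombstone, 3 tokens)
# -> B (2 tokens) -> C (3 tokens), the helper returns C as best_last_node,
# since 2 + 3 = 5 >= 4 tokens were matched after the last tombstone. If the
# path ended at B, only 2 tokens follow the tombstone, so the match falls
# back to the root and no prefix is reused.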
def _split_node(self, key: List[int], child: TreeNode, split_len: int) -> TreeNode:
# new_node -> child
new_node = TreeNode()
new_node.children = {self.get_child_key_fn(key[split_len:]): child}
new_node.parent = child.parent
new_node.swa_tombstone = child.swa_tombstone
new_node.full_lock_ref = child.full_lock_ref
new_node.swa_lock_ref = child.swa_lock_ref
new_node.key = child.key[:split_len]
new_node.value = child.value[:split_len]
# parent inherits the swa_uuid from child for swa lock ref
new_node.swa_uuid = child.swa_uuid
child.swa_uuid = None
# child time should be later than parent's time for swa tombstone
child.last_access_time = time.monotonic()
# remove the child from the lru lists because it is being split
self.full_lru_list.remove_node(child)
if not new_node.swa_tombstone:
self.swa_lru_list.remove_node(child)
child.parent = new_node
child.key = child.key[split_len:]
child.value = child.value[split_len:]
new_node.parent.children[self.get_child_key_fn(key)] = new_node
# insert the new node and child into the lru lists, insert
# parent first so that parent is after child in the lru list
self.full_lru_list.insert_mru(new_node)
self.full_lru_list.insert_mru(child)
if not new_node.swa_tombstone:
self.swa_lru_list.insert_mru(new_node)
self.swa_lru_list.insert_mru(child)
return new_node
def _insert_helper(
self, node: TreeNode, key: List, value, update_kv_after_len: int
) -> int:
# Update the last access time from root to leaf, so that
# swa tombstones nodes closer to the root first
node.last_access_time = time.monotonic()
if node != self.root_node:
self.full_lru_list.reset_node_mru(node)
if not node.swa_tombstone:
self.swa_lru_list.reset_node_mru(node)
if len(key) == 0:
return 0
child_key = self.get_child_key_fn(key)
total_prefix_length = 0
while len(key) > 0 and child_key in node.children.keys():
node = node.children[child_key]
node.last_access_time = time.monotonic()
self.full_lru_list.reset_node_mru(node)
if not node.swa_tombstone:
self.swa_lru_list.reset_node_mru(node)
prefix_len = self.key_match_fn(node.key, key)
if prefix_len < len(node.key):
new_node = self._split_node(node.key, node, prefix_len)
node = new_node
# If a tombstone occurs after update_kv_after_len, update node.value to the input value.
# This is needed because the last sliding-window-size tokens may contain a tombstone;
# if they do and we don't update the kv value, prefill prefix matching will get stuck.
if update_kv_after_len < total_prefix_length + prefix_len:
first_diff_idx = max(0, update_kv_after_len - total_prefix_length)
if node.swa_tombstone:
assert (
node.swa_lock_ref == 0
), f"tombstone swa_lock_ref should always be 0, {node.full_lock_ref=}, {node.swa_lock_ref=}, {node.id=}"
self.token_to_kv_pool_allocator.free(node.value[first_diff_idx:])
node.value = value[:prefix_len]
node.swa_tombstone = False
# insert the node into the lru lists
self.swa_lru_list.insert_mru(node)
self.swa_evictable_size_ += len(node.value)
else:
self.token_to_kv_pool_allocator.free(
value[first_diff_idx:prefix_len]
)
total_prefix_length += prefix_len
key = key[prefix_len:]
value = value[prefix_len:]
if len(key):
child_key = self.get_child_key_fn(key)
if len(key):
new_node = TreeNode()
new_node.parent = node
new_node.key = key
new_node.value = value
self.full_lru_list.insert_mru(new_node)
self.swa_lru_list.insert_mru(new_node)
node.children[child_key] = new_node
self.full_evictable_size_ += len(value)
self.swa_evictable_size_ += len(value)
return total_prefix_length
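# Example of the update_kv_after_len handling above (illustrative indices): suppose
# a cached 10-token prefix whose node covering tokens 4..7 is tombstoned, and a
# re-insert passes update_kv_after_len = 6. The tombstoned node overlaps the region
# after index 6, so its value is replaced with the freshly allocated KV indices
# (the stale tail is freed) and it is un-tombstoned and re-added to the swa LRU
# list; for non-tombstone overlaps the duplicate new indices are simply freed.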
def _iteratively_delete_tombstone_leaf(
self, node: TreeNode
) -> Tuple[TreeNode, int]:
full_num_evicted = 0
while node.parent.swa_tombstone and len(node.parent.children) == 0:
# root node is not evictable
if node.parent == self.root_node:
break
# if locked, means node is in use, skip
if node.parent.full_lock_ref > 0:
break
assert (
node.parent.swa_lock_ref == 0
), f"tombstone swa_lock_ref should always be 0, {node.parent.full_lock_ref=}, {node.parent.swa_lock_ref=}, {node.parent.id=}"
# deleting a tombstone node evicts full tokens
self.token_to_kv_pool_allocator.free(node.parent.value)
full_num_evicted += len(node.parent.value)
self.full_lru_list.remove_node(node.parent)
self._delete_tombstone_leaf(node.parent)
node = node.parent
return node, full_num_evicted
def _delete_leaf(self, node: TreeNode) -> None:
assert (
not node.swa_tombstone
), f"Invariant violated: leaf node is a tombstone, {node.id=}"
assert len(node.children) == 0, f"leaf node has children, {node.id=}"
for k, v in node.parent.children.items():
if v == node:
break
del node.parent.children[k]
self.full_evictable_size_ -= len(node.key)
self.swa_evictable_size_ -= len(node.key)
def _tombstone_internal_node(self, node: TreeNode) -> None:
assert len(node.children) != 0, f"Cannot tombstone a leaf node, {node.id=}"
node.swa_tombstone = True
self.swa_evictable_size_ -= len(node.key)
def _delete_tombstone_leaf(self, node: TreeNode) -> None:
assert (
node.swa_tombstone
), f"Deleting a unexpected non-tombstone leaf node, {node.id=}"
assert len(node.children) == 0, f"leaf node has children, {node.id=}"
for k, v in node.parent.children.items():
if v == node:
break
del node.parent.children[k]
self.full_evictable_size_ -= len(node.key)
def _collect_leaves(self) -> List[TreeNode]:
ret_list = []
stack = [self.root_node]
while stack:
cur_node = stack.pop()
if len(cur_node.children) == 0:
ret_list.append(cur_node)
else:
stack.extend(cur_node.children.values())
return ret_list
def _collect_nontombstone_nodes(self) -> List[TreeNode]:
ret_list = []
stack = [self.root_node]
while stack:
cur_node = stack.pop()
if not cur_node.swa_tombstone:
ret_list.append(cur_node)
stack.extend(cur_node.children.values())
return ret_list
def _collect_all_nodes(self) -> List[TreeNode]:
ret_list = []
stack = [self.root_node]
while stack:
cur_node = stack.pop()
ret_list.append(cur_node)
stack.extend(cur_node.children.values())
return ret_list
def _print_helper(self, node: TreeNode, indent: int) -> None:
"""Prints the radix tree in a human-readable format."""
stack = [(node, indent)]
while stack:
current_node, current_indent = stack.pop()
print(
" " * current_indent,
current_node.id,
len(current_node.key),
f"fr={current_node.full_lock_ref}",
f"sr={current_node.swa_lock_ref}",
f"fll={self.full_lru_list.in_list(current_node)}",
f"sll={self.swa_lru_list.in_list(current_node)}",
f"ts={current_node.swa_tombstone}",
)
for key, child in current_node.children.items():
stack.append((child, current_indent + 2))
assert key == self.get_child_key_fn(
child.key
), f"{key=}, {self.get_child_key_fn(child.key)=}"
def _total_size_helper(self) -> Tuple[int, int]:
total_size = 0
total_swa_size = 0
stack = [self.root_node]
while stack:
current_node = stack.pop()
total_size += len(current_node.value)
if not current_node.swa_tombstone:
total_swa_size += len(current_node.value)
for child in current_node.children.values():
if child.evicted:
continue
stack.append(child)
return total_size, total_swa_size
......@@ -275,6 +275,15 @@ class ModelRunner:
self.sampler = Sampler()
self.load_model()
if (
not self.server_args.disable_hybrid_swa_memory
and self.sliding_window_size is not None
and self.sliding_window_size > 0
):
architectures = self.model_config.hf_config.architectures
if architectures and not any("Llama4" in arch for arch in architectures):
self.is_hybrid = self.model_config.is_hybrid = True
self.start_layer = getattr(self.model, "start_layer", 0)
self.end_layer = getattr(
self.model, "end_layer", self.model_config.num_hidden_layers
......@@ -471,10 +480,6 @@ class ModelRunner:
if self.model_config.context_len > 8192:
self.mem_fraction_static *= 0.85
if self.is_hybrid and not server_args.disable_radix_cache:
logger.info("Automatically disable radix cache for hybrid cache.")
server_args.disable_radix_cache = True
def init_torch_distributed(self):
logger.info("Init torch distributed begin.")
......@@ -645,11 +650,15 @@ class ModelRunner:
)
# Parse other args
self.sliding_window_size = (
self.model.get_attention_sliding_window_size()
if hasattr(self.model, "get_attention_sliding_window_size")
else None
)
self.sliding_window_size = None
if hasattr(self.model, "get_attention_sliding_window_size"):
self.sliding_window_size = self.model.get_attention_sliding_window_size()
elif self.model_config.attention_chunk_size is not None:
self.sliding_window_size = self.model_config.attention_chunk_size
logger.info(
f"Setting sliding_window_size to attention_chunk_size: {self.sliding_window_size}"
)
self.dtype = self.model_config.dtype
after_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
......@@ -992,8 +1001,53 @@ class ModelRunner:
)
self.max_total_num_tokens = self.full_max_total_num_tokens
else:
raise ValueError(
f"Unsupported model for hybrid cache: {self.model_config.hf_config.architectures}."
assert self.sliding_window_size is not None and self.sliding_window_size > 0
full_attention_layer_ids = []
swa_attention_layer_ids = []
try:
layers = self.model.model.layers
except AttributeError:
try:
layers = self.model.language_model.model.layers
except AttributeError:
self.is_hybrid = False
return
for layer in layers:
if (
layer.self_attn.attn.sliding_window_size is None
or layer.self_attn.attn.sliding_window_size == -1
):
full_attention_layer_ids.append(layer.layer_id)
else:
swa_attention_layer_ids.append(layer.layer_id)
self.model_config.swa_attention_layer_ids = swa_attention_layer_ids
self.model_config.full_attention_layer_ids = full_attention_layer_ids
# Algorithm:
# Existing max_total_num_tokens is per layer and assumes all layers have the same number of tokens.
# - Find total # of tokens available across layers.
# - Calculate full_max_total_num_tokens and swa_max_total_num_tokens based on the given swa_full_tokens_ratio.
total_tokens = (
self.max_total_num_tokens * self.model_config.num_hidden_layers
)
full_layers_num = len(full_attention_layer_ids)
swa_layers_num = len(swa_attention_layer_ids)
swa_full_tokens_ratio = self.server_args.swa_full_tokens_ratio
# Solve the equations:
# 1. swa_max_total_num_tokens * swa_layers_num + full_max_total_num_tokens * full_layers_num == total_tokens
# 2. full_max_total_num_tokens * swa_full_tokens_ratio == swa_max_total_num_tokens
denominator = swa_full_tokens_ratio * swa_layers_num + full_layers_num
self.full_max_total_num_tokens = int(total_tokens / denominator)
self.swa_max_total_num_tokens = int(
self.full_max_total_num_tokens * swa_full_tokens_ratio
)
self.max_total_num_tokens = self.full_max_total_num_tokens
logger.info(
f"Use Sliding window memory pool. full_layer_tokens={self.full_max_total_num_tokens}, swa_layer_tokens={self.swa_max_total_num_tokens}"
)
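# Worked example of the split above (illustrative numbers): with 48 hidden layers,
# 12 full-attention layers, 36 SWA layers, a per-layer max_total_num_tokens of
# 100_000, and swa_full_tokens_ratio = 0.8:
#   total_tokens = 100_000 * 48 = 4_800_000
#   denominator  = 0.8 * 36 + 12 = 40.8
#   full_max_total_num_tokens = 4_800_000 / 40.8 ≈ 117_647
#   swa_max_total_num_tokens  ≈ 117_647 * 0.8 ≈ 94_117
# and indeed 117_647 * 12 + 94_117 * 36 ≈ 4_800_000, so both equations hold.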
def init_memory_pool(
......@@ -1072,7 +1126,6 @@ class ModelRunner:
// self.server_args.page_size
* self.server_args.page_size
)
# create token size for hybrid cache
if self.is_hybrid:
self.set_num_token_hybrid()
......
......@@ -190,6 +190,7 @@ class Gemma2DecoderLayer(nn.Module):
prefix: str = "",
) -> None:
super().__init__()
self.layer_id = layer_id
self.hidden_size = config.hidden_size
self.self_attn = Gemma2Attention(
layer_id=layer_id,
......
......@@ -63,6 +63,7 @@ class ServerArgs:
enable_multimodal: Optional[bool] = None
revision: Optional[str] = None
hybrid_kvcache_ratio: Optional[float] = None
swa_full_tokens_ratio: float = 0.8
impl: str = "auto"
# Port for the HTTP server
......@@ -225,6 +226,7 @@ class ServerArgs:
enable_return_hidden_states: bool = False
enable_triton_kernel_moe: bool = False
warmups: Optional[str] = None
disable_hybrid_swa_memory: bool = False
# Debug tensor dumps
debug_tensor_dump_output_folder: Optional[str] = None
......@@ -481,14 +483,22 @@ class ServerArgs:
model_arch = get_model_arch(self)
# Auto set draft_model_path DeepSeek-V3/R1
if model_arch == "DeepseekV3ForCausalLM":
# Auto set draft_model_path DeepSeek-V3/R1
if self.speculative_draft_model_path is None:
self.speculative_draft_model_path = self.model_path
else:
logger.warning(
"DeepSeek MTP does not require setting speculative_draft_model_path."
)
elif "Llama4" in model_arch:
# TODO: remove this after Llama4 is supported in other backends
if self.attention_backend != "fa3":
self.attention_backend = "fa3"
logger.warning(
"Llama4 requires using fa3 attention backend. "
"Attention backend is automatically set to fa3."
)
# Auto choose parameters
if self.speculative_num_steps is None:
......@@ -852,6 +862,18 @@ class ServerArgs:
"(1.0 = pure hybrid: swa_size / full_size = local_attention_size / context_length)"
),
)
parser.add_argument(
"--swa-full-tokens-ratio",
type=float,
default=ServerArgs.swa_full_tokens_ratio,
help="The ratio of SWA layer KV tokens / full layer KV tokens, regardless of the number of swa:full layers. It should be between 0 and 1. "
"E.g. 0.5 means if each swa layer has 50 tokens, then each full layer has 100 tokens.",
)
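# Usage sketch (assumes the standard launch entry point; the model path is a placeholder):
#   python -m sglang.launch_server --model-path <model> --swa-full-tokens-ratio 0.5
# With 0.5, each SWA layer gets half as many KV tokens as each full-attention layer.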
parser.add_argument(
"--disable-hybrid-swa-memory",
action="store_true",
help="Disable the hybrid SWA memory.",
)
# Other runtime options
parser.add_argument(
......@@ -1730,10 +1752,6 @@ class ServerArgs:
else:
self.lora_paths[lora_path] = lora_path
model_arch = get_model_arch(self)
if "Llama4" in model_arch and self.hybrid_kvcache_ratio is not None:
assert self.attention_backend == "fa3"
def prepare_server_args(argv: List[str]) -> ServerArgs:
"""
......
import unittest
import torch
from sglang.srt.mem_cache.allocator import SWAKVPool, SWATokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.mem_cache.radix_cache import SWARadixCache
class TestSWA(unittest.TestCase):
@classmethod
def setUpClass(cls):
pass
@classmethod
def tearDownClass(cls):
pass
def test_swa_memory_pool(self):
size = 16
size_swa = 16
num_head = 8
head_dim = 128
num_layers = 48
global_interval = 4
dtype = torch.bfloat16
device = "cuda"
full_attention_layer_ids = [i for i in range(0, num_layers, global_interval)]
full_attention_layer_ids_set = set(full_attention_layer_ids)
swa_attention_layer_ids = [
i for i in range(num_layers) if i not in full_attention_layer_ids_set
]
pool = SWAKVPool(
size,
size_swa,
dtype,
num_head,
head_dim,
swa_attention_layer_ids,
full_attention_layer_ids,
device,
)
alloc = SWATokenToKVPoolAllocator(size, size_swa, dtype, device, pool)
assert alloc.available_size() == size + size_swa
index = alloc.alloc(1)
assert alloc.available_size() == size + size_swa - 2
alloc.free_swa(index)
result = alloc.translate_loc_from_full_to_swa(index)
print(result)
def test_swa_radix_cache_1(self):
# args
req_size = 10
max_context_len = 128
kv_size = 128
kv_size_swa = 64
sliding_window_size = 4
num_head = 8
head_dim = 128
num_layers = 48
global_interval = 4
dtype = torch.bfloat16
device = "cuda"
full_attention_layer_ids = [i for i in range(0, num_layers, global_interval)]
full_attention_layer_ids_set = set(full_attention_layer_ids)
swa_attention_layer_ids = [
i for i in range(num_layers) if i not in full_attention_layer_ids_set
]
# setup req to token pool
req_to_token_pool = ReqToTokenPool(
size=req_size,
max_context_len=max_context_len,
device=device,
enable_memory_saver=False,
)
# setup kv pool
kv_pool = SWAKVPool(
kv_size,
kv_size_swa,
dtype,
num_head,
head_dim,
swa_attention_layer_ids,
full_attention_layer_ids,
device,
)
# setup token to kv pool allocator
allocator = SWATokenToKVPoolAllocator(
kv_size, kv_size_swa, dtype, device, kv_pool
)
# setup radix cache
tree = SWARadixCache(
req_to_token_pool=req_to_token_pool,
token_to_kv_pool_allocator=allocator,
sliding_window_size=sliding_window_size,
page_size=1,
disable=False,
)
# test
print(
f"[Start] allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}"
)
req1_token_ids, req1_kv_indices = [1, 2, 3], allocator.alloc(3)
assert len(req1_token_ids) == len(req1_kv_indices)
print(
f"req1: inserting, req1_token_ids: {req1_token_ids}, req1_kv_indices: {req1_kv_indices}"
)
prefix_len = tree.insert(req1_token_ids, req1_kv_indices)
print(
f"req1: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}"
)
req2_token_ids, req2_kv_indices = [1, 2, 3, 4, 5, 6, 7], allocator.alloc(7)
assert len(req2_token_ids) == len(req2_kv_indices)
print(
f"req2: inserting, req2_token_ids: {req2_token_ids}, req2_kv_indices: {req2_kv_indices}"
)
prefix_len = tree.insert(req2_token_ids, req2_kv_indices)
print(
f"req2: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}"
)
req3_token_ids, req3_kv_indices = [10, 11, 12], allocator.alloc(3)
assert len(req3_token_ids) == len(req3_kv_indices)
print(
f"req3: inserting, req3_token_ids: {req3_token_ids}, req3_kv_indices: {req3_kv_indices}"
)
prefix_len = tree.insert(req3_token_ids, req3_kv_indices)
print(
f"req3: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}"
)
req4_token_ids, req4_kv_indices = [1, 2, 3, 4, 5, 60, 70], allocator.alloc(7)
assert len(req4_token_ids) == len(req4_kv_indices)
print(
f"req4: inserting, req4_token_ids: {req4_token_ids}, req4_kv_indices: {req4_kv_indices}"
)
prefix_len = tree.insert(req4_token_ids, req4_kv_indices)
print(
f"req4: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}"
)
tree.pretty_print()
full_num_tokens, swa_num_tokens = 1, 0
print(f"evicting {full_num_tokens} full token and {swa_num_tokens} swa token")
tree.evict(full_num_tokens=full_num_tokens, swa_num_tokens=swa_num_tokens)
tree.pretty_print()
full_num_tokens, swa_num_tokens = 0, 1
print(f"evicting {full_num_tokens} full token and {swa_num_tokens} swa token")
tree.evict(full_num_tokens=full_num_tokens, swa_num_tokens=swa_num_tokens)
tree.pretty_print()
full_num_tokens, swa_num_tokens = 1, 2
print(f"evicting {full_num_tokens} full token and {swa_num_tokens} swa token")
tree.evict(full_num_tokens=full_num_tokens, swa_num_tokens=swa_num_tokens)
tree.pretty_print()
req5_token_ids = [1, 2, 3, 4, 5]
kv_indices, last_node = tree.match_prefix(req5_token_ids)
print(
f"req5: token_ids: {req5_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}"
)
assert len(kv_indices) == 0
req6_token_ids = [1, 2, 3, 4, 5, 60, 70]
kv_indices, last_node = tree.match_prefix(req6_token_ids)
print(
f"req6: token_ids: {req6_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}"
)
assert len(kv_indices) == 7
assert len(last_node.key) == 2
assert last_node.key[0] == 60
assert last_node.key[1] == 70
if __name__ == "__main__":
unittest.main()