Unverified Commit dbd9435d authored by Roger Young's avatar Roger Young Committed by GitHub
Browse files

Fix mamba radix cache eviction logic in `alloc_req_slots` (#11616)


Signed-off-by: rogeryoungh <rogeryoungh@foxmail.com>
parent 8ae9d4bb
...@@ -66,7 +66,7 @@ from sglang.srt.mem_cache.common import ( ...@@ -66,7 +66,7 @@ from sglang.srt.mem_cache.common import (
evict_from_tree_cache, evict_from_tree_cache,
) )
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.mem_cache.radix_cache import RadixKey from sglang.srt.mem_cache.radix_cache import RadixKey
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
...@@ -1080,27 +1080,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): ...@@ -1080,27 +1080,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
def is_empty(self): def is_empty(self):
return len(self.reqs) == 0 return len(self.reqs) == 0
def alloc_req_slots(self, num_reqs: int, reqs: Optional[List[Req]] = None):
    """Allocate `num_reqs` request slots from the req-to-token pool.

    For a hybrid (mamba) pool, the mamba state pool may be the limiting
    resource: when it cannot cover `num_reqs`, first evict mamba entries
    from the tree cache (if it is a MambaRadixCache) to free states, then
    allocate with the per-request `reqs` list the hybrid pool expects.

    Args:
        num_reqs: number of request slots to allocate.
        reqs: the requests being scheduled; forwarded only to hybrid pools.

    Returns:
        The allocated request-pool indices.

    Raises:
        RuntimeError: if the pool cannot supply `num_reqs` slots.
    """
    pool = self.req_to_token_pool
    if isinstance(pool, HybridReqToTokenPool):
        free_mamba = pool.mamba_pool.available_size()
        # isinstance(None, MambaRadixCache) is False, so a None tree_cache
        # is implicitly skipped here.
        if free_mamba < num_reqs and isinstance(self.tree_cache, MambaRadixCache):
            # Evict exactly the shortfall worth of mamba states.
            self.tree_cache.evict_mamba(max(0, num_reqs - free_mamba))
        slot_indices = pool.alloc(num_reqs, reqs)
    else:
        slot_indices = pool.alloc(num_reqs)

    if slot_indices is None:
        raise RuntimeError(
            "alloc_req_slots runs out of memory. "
            "Please set a smaller number for `--max-running-requests`. "
            f"{self.req_to_token_pool.available_size()=}, "
            f"{num_reqs=}, "
        )
    return slot_indices
def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]): def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
self.encoder_lens_cpu = [] self.encoder_lens_cpu = []
self.encoder_cached = [] self.encoder_cached = []
......
...@@ -10,6 +10,7 @@ import triton.language as tl ...@@ -10,6 +10,7 @@ import triton.language as tl
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
from sglang.srt.server_args import get_global_server_args from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import support_triton from sglang.srt.utils import support_triton
...@@ -292,9 +293,15 @@ def alloc_req_slots( ...@@ -292,9 +293,15 @@ def alloc_req_slots(
req_to_token_pool: ReqToTokenPool, req_to_token_pool: ReqToTokenPool,
num_reqs: int, num_reqs: int,
reqs: list[Req] | None, reqs: list[Req] | None,
tree_cache: BasePrefixCache | None,
) -> list[int]: ) -> list[int]:
"""Allocate request slots from the pool.""" """Allocate request slots from the pool."""
if isinstance(req_to_token_pool, HybridReqToTokenPool): if isinstance(req_to_token_pool, HybridReqToTokenPool):
mamba_available_size = req_to_token_pool.mamba_pool.available_size()
if mamba_available_size < num_reqs:
if tree_cache is not None and isinstance(tree_cache, MambaRadixCache):
mamba_num = max(0, num_reqs - mamba_available_size)
tree_cache.evict_mamba(mamba_num)
req_pool_indices = req_to_token_pool.alloc(num_reqs, reqs) req_pool_indices = req_to_token_pool.alloc(num_reqs, reqs)
else: else:
req_pool_indices = req_to_token_pool.alloc(num_reqs) req_pool_indices = req_to_token_pool.alloc(num_reqs)
...@@ -337,7 +344,9 @@ def alloc_for_extend( ...@@ -337,7 +344,9 @@ def alloc_for_extend(
extend_lens_device = extend_lens_cpu.to(batch.device, non_blocking=True) extend_lens_device = extend_lens_cpu.to(batch.device, non_blocking=True)
# Allocate req slots # Allocate req slots
req_pool_indices = alloc_req_slots(batch.req_to_token_pool, bs, batch.reqs) req_pool_indices = alloc_req_slots(
batch.req_to_token_pool, bs, batch.reqs, batch.tree_cache
)
req_pool_indices_cpu = torch.tensor(req_pool_indices, dtype=torch.int64) req_pool_indices_cpu = torch.tensor(req_pool_indices, dtype=torch.int64)
req_pool_indices_device = req_pool_indices_cpu.to(batch.device, non_blocking=True) req_pool_indices_device = req_pool_indices_cpu.to(batch.device, non_blocking=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment