"experiments/segmentation/demo.py" did not exist on "3ba8d2f7244b1f43f5fa0fb4ebdabafbfb33bdd2"
Unverified commit a55cf530 authored by Yi Zhang, committed by GitHub

[Feature] Support mamba radix cache v0 (#11214)


Co-authored-by: hanming-lu <hanming@x.ai>
Co-authored-by: hzh0425 <hzh0425@apache.org>
Co-authored-by: thalahors <ericalcaide1@gmail.com>
parent 19ba16aa
@@ -65,7 +65,8 @@ from sglang.srt.mem_cache.common import (
     alloc_for_extend,
     alloc_token_slots,
 )
-from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
+from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
+from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
 from sglang.srt.mem_cache.radix_cache import RadixKey
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
@@ -522,6 +523,7 @@ class Req:
         # Memory pool info
         self.req_pool_idx: Optional[int] = None
+        self.mamba_pool_idx: Optional[torch.Tensor] = None  # shape (1)

         # Check finish
         self.tokenizer = None
@@ -727,7 +729,12 @@ class Req:
                 self.last_host_node,
                 self.host_hit_length,
             ) = tree_cache.match_prefix(
-                key=RadixKey(token_ids=token_ids, extra_key=self.extra_key)
+                key=RadixKey(token_ids=token_ids, extra_key=self.extra_key),
+                **(
+                    {"req": self, "cow_mamba": True}
+                    if isinstance(tree_cache, MambaRadixCache)
+                    else {}
+                ),
             )
         self.last_matched_prefix_len = len(self.prefix_indices)
         self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
@@ -877,6 +884,7 @@ class Req:
         self.extend_logprob_start_len = 0
         self.is_chunked = 0
         self.req_pool_idx = None
+        self.mamba_pool_idx = None
         self.already_computed = 0

     def offload_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
@@ -1071,6 +1079,27 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     def is_empty(self):
         return len(self.reqs) == 0

+    def alloc_req_slots(self, num_reqs: int, reqs: Optional[List[Req]] = None):
+        if isinstance(self.req_to_token_pool, HybridReqToTokenPool):
+            mamba_available_size = self.req_to_token_pool.mamba_pool.available_size()
+            if mamba_available_size < num_reqs:
+                if self.tree_cache is not None and isinstance(
+                    self.tree_cache, MambaRadixCache
+                ):
+                    mamba_num = max(0, num_reqs - mamba_available_size)
+                    self.tree_cache.evict_mamba(mamba_num)
+            req_pool_indices = self.req_to_token_pool.alloc(num_reqs, reqs)
+        else:
+            req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
+        if req_pool_indices is None:
+            raise RuntimeError(
+                "alloc_req_slots runs out of memory. "
+                "Please set a smaller number for `--max-running-requests`. "
+                f"{self.req_to_token_pool.available_size()=}, "
+                f"{num_reqs=}, "
+            )
+        return req_pool_indices
+
     def allocate_for_eagle_v2(self):
         from sglang.srt.speculative.eagle_info import EagleDraftInput
         from sglang.srt.speculative.spec_utils import assign_req_to_token_pool
...
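
The alloc_req_slots added above evicts mamba states before allocating request slots. A minimal sketch of that evict-then-alloc pattern, assuming pool and cache stand in for self.req_to_token_pool (a HybridReqToTokenPool) and self.tree_cache (a MambaRadixCache):

def alloc_with_mamba_eviction(pool, cache, num_reqs, reqs):
    # Evict just enough LRU mamba states so that every incoming request
    # can get its own mamba slot from the hybrid pool.
    shortfall = num_reqs - pool.mamba_pool.available_size()
    if shortfall > 0:
        cache.evict_mamba(shortfall)
    req_pool_indices = pool.alloc(num_reqs, reqs)  # may still fail on req slots
    if req_pool_indices is None:
        raise RuntimeError("out of request slots; lower --max-running-requests")
    return req_pool_indices
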
@@ -27,6 +27,7 @@ import torch
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
+from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache, RadixKey, TreeNode
 from sglang.srt.server_args import ServerArgs
@@ -357,6 +358,7 @@ class PrefillAdder:
         self.is_hybrid = isinstance(
             self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator
         )
+        self.is_hybrid_gdn_cache = isinstance(self.tree_cache, MambaRadixCache)

         self.priority_scheduling_preemption_threshold = (
             priority_scheduling_preemption_threshold
@@ -380,6 +382,11 @@ class PrefillAdder:
                 self.token_to_kv_pool_allocator.swa_available_size()
                 + self.tree_cache.swa_evictable_size(),
             )
+        elif self.is_hybrid_gdn_cache:
+            available_and_evictable = (
+                self.token_to_kv_pool_allocator.available_size()
+                + self.tree_cache.full_evictable_size()
+            )
         else:
             available_and_evictable = (
                 self.token_to_kv_pool_allocator.available_size()
@@ -397,6 +404,11 @@ class PrefillAdder:
                 self.token_to_kv_pool_allocator.swa_available_size()
                 + self.tree_cache.swa_evictable_size(),
             )
+        elif self.is_hybrid_gdn_cache:
+            available_and_evictable = (
+                self.token_to_kv_pool_allocator.available_size()
+                + self.tree_cache.full_evictable_size()
+            )
         else:
             available_and_evictable = (
                 self.token_to_kv_pool_allocator.available_size()
...
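
Both PrefillAdder branches above compute the same budget for hybrid GDN models: the tokens a prefill can still place are the free KV slots plus the full-attention tokens the MambaRadixCache could evict. A one-line sketch (names taken from the diff, bindings assumed):

available_and_evictable = (
    token_to_kv_pool_allocator.available_size() + tree_cache.full_evictable_size()
)
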
@@ -146,6 +146,7 @@ from sglang.srt.managers.session_controller import Session
 from sglang.srt.managers.utils import validate_input_length
 from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
+from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
@@ -470,6 +471,10 @@ class Scheduler(
         # Hybrid memory pool
         self.is_hybrid = self.tp_worker.is_hybrid
+        self.is_hybrid_gdn = (
+            self.tp_worker.worker.model_runner.hybrid_gdn_config is not None
+        )

         if self.is_hybrid:
             self.sliding_window_size = self.tp_worker.sliding_window_size
             self.full_tokens_per_layer, self.swa_tokens_per_layer = (
@@ -816,6 +821,16 @@ class Scheduler(
                 disable=server_args.disable_radix_cache,
                 is_eagle=self.spec_algorithm.is_eagle(),
             )
+        elif self.is_hybrid_gdn:
+            assert (
+                self.server_args.disaggregation_mode == "null"
+            ), "Hybrid GDN mode does not support disaggregation yet"
+            self.tree_cache = MambaRadixCache(
+                req_to_token_pool=self.req_to_token_pool,
+                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+                page_size=self.page_size,
+                disable=server_args.disable_radix_cache,
+            )
         elif server_args.enable_lmcache:
             from sglang.srt.mem_cache.storage.lmcache.lmc_radix_cache import (
                 LMCRadixCache,
@@ -1689,6 +1704,25 @@ class Scheduler(
                 f"{self.full_tokens_per_layer=}, {full_available_size=}, {full_evictable_size=}, {self.tree_cache.full_protected_size()=}\n"
                 f"{self.swa_tokens_per_layer=}, {swa_available_size=}, {swa_evictable_size=}, {self.tree_cache.swa_protected_size()=}\n"
             )
+        elif self.is_hybrid_gdn and isinstance(self.tree_cache, MambaRadixCache):
+            (
+                full_num_used,
+                mamba_num_used,
+                _,
+                _,
+                full_available_size,
+                full_evictable_size,
+                mamba_available_size,
+                mamba_evictable_size,
+            ) = self._get_mamba_token_info()
+            memory_leak = (
+                full_num_used != self.tree_cache.full_protected_size()
+                or mamba_num_used != self.tree_cache.mamba_protected_size()
+            )
+            token_msg = (
+                f"{full_available_size=}, {full_evictable_size=}, {self.token_to_kv_pool_allocator.size=}, {self.tree_cache.full_protected_size()=}\n"
+                f"{mamba_available_size=}, {mamba_evictable_size=}, {self.req_to_token_pool.mamba_pool.size=}, {self.tree_cache.mamba_protected_size()=}\n"
+            )
         else:
             _, _, available_size, evictable_size = self._get_token_info()
             protected_size = self.tree_cache.protected_size()
@@ -1739,6 +1773,17 @@ class Scheduler(
             ) = self._get_swa_token_info()
             num_used = max(full_num_used, swa_num_used)
             token_usage = max(full_token_usage, swa_token_usage)
+        elif self.is_hybrid_gdn:
+            (
+                num_used,
+                _,
+                token_usage,
+                _,
+                _,
+                _,
+                _,
+                _,
+            ) = self._get_mamba_token_info()
         else:
             num_used, token_usage, _, _ = self._get_token_info()

         num_running_reqs = len(self.running_batch.reqs)
@@ -1766,7 +1811,9 @@ class Scheduler(
         self._publish_kv_events()

     def check_tree_cache(self):
-        if self.is_hybrid and isinstance(self.tree_cache, SWARadixCache):
+        if (self.is_hybrid and isinstance(self.tree_cache, SWARadixCache)) or (
+            self.is_hybrid_gdn and isinstance(self.tree_cache, MambaRadixCache)
+        ):
             self.tree_cache.sanity_check()

     def _get_token_info(self):
@@ -1776,6 +1823,35 @@ class Scheduler(
         token_usage = num_used / self.max_total_num_tokens
         return num_used, token_usage, available_size, evictable_size

+    def _get_mamba_token_info(self):
+        is_radix_tree = isinstance(self.tree_cache, MambaRadixCache)
+        full_available_size = self.token_to_kv_pool_allocator.available_size()
+        full_evictable_size = (
+            self.tree_cache.full_evictable_size() if is_radix_tree else 0
+        )
+        mamba_available_size = self.req_to_token_pool.mamba_pool.available_size()
+        mamba_evictable_size = (
+            self.tree_cache.mamba_evictable_size() if is_radix_tree else 0
+        )
+        full_num_used = self.token_to_kv_pool_allocator.size - (
+            full_available_size + full_evictable_size
+        )
+        mamba_num_used = self.req_to_token_pool.mamba_pool.size - (
+            mamba_available_size + mamba_evictable_size
+        )
+        full_token_usage = full_num_used / self.token_to_kv_pool_allocator.size
+        mamba_usage = mamba_num_used / self.req_to_token_pool.mamba_pool.size
+        return (
+            full_num_used,
+            mamba_num_used,
+            full_token_usage,
+            mamba_usage,
+            full_available_size,
+            full_evictable_size,
+            mamba_available_size,
+            mamba_evictable_size,
+        )
+
     def _get_swa_token_info(self):
         full_available_size = self.token_to_kv_pool_allocator.full_available_size()
         full_evictable_size = self.tree_cache.full_evictable_size()
...
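
A hedged sketch of how the idle memory-leak check above consumes the 8-tuple returned by _get_mamba_token_info() (scheduler and tree_cache are assumed bindings):

(full_used, mamba_used, full_usage, mamba_usage,
 full_avail, full_evict, mamba_avail, mamba_evict) = scheduler._get_mamba_token_info()
# At idle, every slot that is neither free nor evictable must be protected
# (lock-ref'd) by the radix tree; any other used slot indicates a leak.
leaked = (full_used != tree_cache.full_protected_size()
          or mamba_used != tree_cache.mamba_protected_size())
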
@@ -104,6 +104,23 @@ class SchedulerMetricsMixin:
                 f"full token usage: {full_token_usage:.2f}, "
                 f"swa token usage: {swa_token_usage:.2f}, "
             )
+        elif self.is_hybrid_gdn:
+            (
+                full_num_used,
+                _,
+                full_token_usage,
+                mamba_usage,
+                _,
+                _,
+                _,
+                _,
+            ) = self._get_mamba_token_info()
+            num_used = full_num_used
+            token_usage = full_token_usage
+            token_usage_msg = (
+                f"full token usage: {full_token_usage:.2f}, "
+                f"mamba usage: {mamba_usage:.2f}, "
+            )
         else:
             num_used, token_usage, _, _ = self._get_token_info()
             token_usage_msg = f"token usage: {token_usage:.2f}, "
@@ -203,6 +220,25 @@ class SchedulerMetricsMixin:
                 f"#swa token: {swa_num_used}, "
                 f"swa token usage: {swa_token_usage:.2f}, "
             )
+        elif self.is_hybrid_gdn:
+            (
+                full_num_used,
+                mamba_used,
+                full_token_usage,
+                mamba_usage,
+                _,
+                _,
+                _,
+                _,
+            ) = self._get_mamba_token_info()
+            num_used = full_num_used
+            token_usage = full_token_usage
+            token_usage_msg = (
+                f"#full token: {full_num_used}, "
+                f"full token usage: {full_token_usage:.2f}, "
+                f"mamba num: {mamba_used}, "
+                f"mamba usage: {mamba_usage:.2f}, "
+            )
         else:
             num_used, token_usage, _, _ = self._get_token_info()
             token_usage_msg = f"#token: {num_used}, token usage: {token_usage:.2f}, "
...
from __future__ import annotations
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""
The radix tree data structure for managing the hybrid (full and Mamba) KV cache.
"""
import heapq
import time
from collections import defaultdict
from functools import partial
from typing import TYPE_CHECKING, List, Optional, Tuple
import torch
from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool
from sglang.srt.mem_cache.radix_cache import (
RadixKey,
_key_match_page_size1,
_key_match_paged,
get_child_key,
)
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req
import logging
logger = logging.getLogger(__name__)
class TreeNode:
counter = 0
def __init__(self, id: Optional[int] = None):
self.children = defaultdict(TreeNode)
self.parent: TreeNode = None
self.key: RadixKey = None
self.value: Optional[torch.Tensor] = None
self.mamba_value: Optional[torch.Tensor] = None
        # Invariant: whenever mamba_lock_ref is held, full_lock_ref must be held
        # as well; the reverse is not required. Hence full_lock_ref is always
        # >= mamba_lock_ref.
        # full_lock propagates upward: locking a node also locks its ancestors.
        # mamba_lock only needs to lock the node itself.
self.full_lock_ref = 0
self.mamba_lock_ref = 0
# last access time is only used for sanity check. LRU is maintained by the lru list.
self.last_access_time = time.monotonic()
self.hit_count = 0
# store the host indices of KV cache
self.host_value = None
# for lru list, invariant:
# 1. prev has greater last_access_time
# 2. next has smaller last_access_time
self.prev = None
self.next = None
self.mamba_prev = None
self.mamba_next = None
self.id = TreeNode.counter if id is None else id
TreeNode.counter += 1
@property
def evicted(self):
return self.value is None
@property
def backuped(self):
return self.host_value is not None
def __lt__(self, other: "TreeNode"):
return self.last_access_time < other.last_access_time
class LRUList:
def __init__(self, mamba: bool = False):
self.mamba = mamba
if self.mamba:
self.prv = "mamba_prev"
self.nxt = "mamba_next"
self.lock_ref = "mamba_lock_ref"
else:
self.prv = "prev"
self.nxt = "next"
self.lock_ref = "full_lock_ref"
# Initialize dummy head and tail nodes
self.head = TreeNode() # Most recently used side
self.tail = TreeNode() # Least recently used side
setattr(self.head, self.nxt, self.tail) # self.head.next = self.tail
setattr(self.tail, self.prv, self.head) # self.tail.prev = self.head
self.cache = {}
def _add_node(self, node):
"""Helper to add node right after head (most recently used)"""
self._add_node_after(self.head, node)
def _add_node_after(self, old_node, new_node):
"""Helper to add node right after old_node"""
setattr(new_node, self.prv, old_node) # new_node.prev = old_node
setattr(
new_node, self.nxt, getattr(old_node, self.nxt)
) # new_node.next = old_node.next
setattr(
getattr(old_node, self.nxt), self.prv, new_node
) # old_node.next.prev = new_node
setattr(old_node, self.nxt, new_node) # old_node.next = new_node
def _remove_node(self, node):
"""Helper to remove node from linked list"""
setattr(
getattr(node, self.prv), self.nxt, getattr(node, self.nxt)
) # node.prev.next = node.next
setattr(
getattr(node, self.nxt), self.prv, getattr(node, self.prv)
) # node.next.prev = node.prev
def _get_lru(self) -> Optional[TreeNode]:
"""
Get the least recently used node
"""
if len(self.cache) == 0:
return None
return getattr(self.tail, self.prv)
def reset_node_mru(self, node):
"""
        Move an existing node to the most recently used position
"""
assert node.id in self.cache, f"Resetting node {node.id=} not in lru list"
assert (
not self.mamba or node.mamba_value is not None
), f"Resetting mamba tombstone node in mamba lru list: {node.id=}"
self._remove_node(node)
self._add_node(node)
def reset_node_and_parents_mru(self, node, root_node):
"""
Move an (existing) node and its parents to most recently used position. Child node is
more recently used than parent node.
"""
prev_node = self.head
while node != root_node:
if not self.mamba or node.mamba_value is not None:
assert (
node.id in self.cache
), f"Resetting node {node.id=} not in lru list when resetting node and parents mru"
self._remove_node(node)
self._add_node_after(prev_node, node)
prev_node = node
node = node.parent
def insert_mru(self, node):
"""
Insert a (new) node as most recently used
"""
assert (
not self.mamba or node.mamba_value is not None
), f"Inserting mamba tombstone node in mamba lru list: {node.id=}"
assert (
node.id not in self.cache
), f"Inserting node {node.id=} already in lru list, existing node: {self.cache[node.id].id=}"
self.cache[node.id] = node
self._add_node(node)
def remove_node(self, node: TreeNode):
"""
Remove node from lru list
"""
assert node.id in self.cache, f"Removing node {node.id=} not in lru list"
assert (
not self.mamba or node.mamba_value is not None
), f"Removing mamba tombstone node from mamba lru list: {node.id=}"
del self.cache[node.id]
self._remove_node(node)
def get_lru_no_lock(self) -> Optional[TreeNode]:
"""
Get the least recently used node that is not locked
"""
return self.get_prev_no_lock(self.tail, check_id=False)
def get_leaf_lru_no_lock(self) -> Optional[TreeNode]:
"""
Get the least recently used leaf node that is not locked
"""
return self.get_prev_leaf_no_lock(self.tail, check_id=False)
def get_prev_no_lock(
self, node: TreeNode, check_id: bool = True
) -> Optional[TreeNode]:
"""
Get the previous (i.e. more recently used) node that is not locked
"""
if check_id:
assert (
node.id in self.cache
), f"Getting prev of node {node.id=} not in lru list"
x = getattr(node, self.prv) # x = node.prev
while getattr(x, self.lock_ref) > 0:
x = getattr(x, self.prv) # x = x.prev
# if x is the head, it means there is no node in the lru list without lock
if x == self.head:
return None
return x
def get_prev_leaf_no_lock(self, node: TreeNode, check_id: bool = True):
"""
Get the previous (i.e. more recently used) leaf node that is not locked
"""
if check_id:
assert (
node.id in self.cache
), f"Getting prev of node {node.id=} not in lru list"
x = getattr(node, self.prv) # x = node.prev
while getattr(x, self.lock_ref) > 0 or len(x.children) > 0:
x = getattr(x, self.prv) # x = x.prev
# if x is the head, it means there is no leaf node in the lru list without lock
if x == self.head:
return None
return x
def in_list(self, node: Optional[TreeNode]):
"""
Check if the node is in the lru list
"""
if not node:
return False
return node.id in self.cache
# Note: this is expensive, only use for debug
def sanity_check_evictable_size(self):
"""
Check the evictable size (i.e. the size of the nodes that are not locked)
"""
node = self.get_lru_no_lock()
evictable_size = 0
while self.in_list(node):
evictable_size += (
len(node.value) if not self.mamba else len(node.mamba_value)
)
node = self.get_prev_no_lock(node)
return evictable_size
# Note: this is expensive, only use for debug or idle check
def sanity_check(self, tree_cache: "MambaRadixCache"):
"""
        Validate the lru list against the tree: collect the tree's nodes, order
        them by last_access_time via a heap, and check that this order matches
        the lru list and that the evictable sizes agree.
"""
try:
if self.mamba:
nodes = tree_cache._collect_nontombstone_nodes()
else:
nodes = tree_cache._collect_all_nodes()
total_nodes = len(nodes)
total_lru = len(self.cache)
# heapify based on last_access_time
heapq.heapify(nodes)
# the root node is not in the lru list
assert len(nodes) == (
total_lru + (0 if self.mamba else 1)
), f"len(nodes): {len(nodes)}, total_lru: {total_lru}"
x_lru = self._get_lru()
while len(nodes):
x = heapq.heappop(nodes)
if x == tree_cache.root_node:
# root node is not in the lru list
continue
assert (
x == x_lru
), f"Incorrect LRU list, {self.mamba=}, x: {x.id=} != x_lru: {x_lru.id=}"
assert (
x_lru.full_lock_ref == 0
), f"x_lru should not be locked when idle, {x_lru.full_lock_ref=}, {x_lru.id=}"
assert (
x_lru.mamba_lock_ref == 0
), f"x_lru should not be locked when idle, {x_lru.mamba_lock_ref=}, {x_lru.id=}"
x_lru = getattr(x, self.prv)
if self.mamba:
evictable_size = tree_cache.mamba_evictable_size()
lru_list_evictable_size = tree_cache.mamba_lru_list_evictable_size()
else:
evictable_size = tree_cache.full_evictable_size()
lru_list_evictable_size = tree_cache.full_lru_list_evictable_size()
assert (
evictable_size == lru_list_evictable_size
), f"{self.mamba=}, total nodes: {total_nodes}, total lru: {total_lru}, evictable size: {evictable_size} != lru list evictable size: {lru_list_evictable_size}"
except Exception as e:
msg = f"Mamba Radix tree sanity check failed, ping @yizhang2077: {e}"
logger.error(msg)
raise Exception(msg)
class MambaRadixCache(BasePrefixCache):
def __init__(
self,
req_to_token_pool: HybridReqToTokenPool,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
page_size: int,
disable: bool = False,
):
assert isinstance(token_to_kv_pool_allocator, TokenToKVPoolAllocator)
self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
assert page_size == 1, "Only support page_size=1 in mamba radix cache now."
self.page_size = page_size
self.disable = disable
if self.token_to_kv_pool_allocator:
self.device = self.token_to_kv_pool_allocator.device
else:
self.device = torch.device("cpu")
self.key_match_fn = _key_match_page_size1
self.get_child_key_fn = get_child_key
self.reset()
##### Public API #####
def reset(self) -> None:
self.root_node = TreeNode()
self.root_node.key = []
self.root_node.value = []
self.root_node.full_lock_ref = 1
self.root_node.mamba_lock_ref = 1
self.full_evictable_size_ = 0
self.mamba_evictable_size_ = 0
self.full_protected_size_ = 0
self.mamba_protected_size_ = 0
# LRU lists are used to maintain the order of eviction of the nodes in the tree
self.full_lru_list = LRUList(mamba=False)
self.mamba_lru_list = LRUList(mamba=True)
def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult:
"""Find the matching prefix from the radix tree.
Args:
            key: A RadixKey containing the token IDs to match against.
        Returns:
            A MatchResult holding the device indices of the matched KV cache
            and the last node that contains the prefix values. Note that this
            API can modify the internal state of the radix tree: if the prefix
            ends inside a node's value, the node is split and the new parent
            node becomes the last node.
"""
cow_mamba: bool = kwargs.get("cow_mamba", False)
req: Req = kwargs.get("req", None)
if self.disable or len(key) == 0:
return MatchResult(
device_indices=torch.empty(
(0,),
dtype=torch.int64,
device=self.device,
),
last_device_node=self.root_node,
last_host_node=self.root_node,
)
value, last_node = self._match_prefix_helper(key)
# copy mamba state to req local space if cow is true
if cow_mamba and last_node.mamba_value is not None:
            assert req.req_pool_idx is None  # req_pool_idx is uninitialized
# for reqs without mamba cache
if req.mamba_pool_idx is None:
dst_index = self.req_to_token_pool.mamba_pool.alloc(1)
# try to alloc again, protect last_node from eviction
if dst_index is None:
self.inc_lock_ref(last_node)
self.evict_mamba(1)
dst_index = self.req_to_token_pool.mamba_pool.alloc(1)
self.dec_lock_ref(last_node)
assert dst_index is not None, "Can not alloc mamba cache"
src_index = last_node.mamba_value
self.req_to_token_pool.mamba_pool.copy_from(src_index, dst_index)
req.mamba_pool_idx = dst_index[0]
else:
src_index = last_node.mamba_value
dst_index = req.mamba_pool_idx.unsqueeze(0)
self.req_to_token_pool.mamba_pool.copy_from(src_index, dst_index)
if value:
value = torch.cat(value)
else:
value = torch.empty((0,), dtype=torch.int64, device=self.device)
return MatchResult(
device_indices=value,
last_device_node=last_node,
last_host_node=last_node,
)
def insert(self, key: RadixKey, value=None, mamba_value=None) -> Tuple[int, bool]:
        if self.disable:
            return 0, False
if value is None:
value = torch.tensor([x for x in key.token_ids], dtype=torch.int64)
return self._insert_helper(self.root_node, key, value, mamba_value)
def cache_finished_req(self, req: Req) -> None:
"""Cache request when it finishes."""
if self.disable:
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx,
: len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0),
]
self.token_to_kv_pool_allocator.free(kv_indices)
self.req_to_token_pool.free(req.req_pool_idx)
return
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : len(token_ids)
]
page_aligned_len = len(kv_indices)
page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
# Radix Cache takes one ref in memory pool
# insert the token_ids and kv_indices into the radix tree
# Note: the insert function already frees the overlapped kv_indices
mamba_value = (
self.req_to_token_pool.get_mamba_indices(req.req_pool_idx)
.unsqueeze(-1)
.clone()
)
new_prefix_len, mamba_exist = self.insert(
RadixKey(token_ids[:page_aligned_len], req.extra_key),
page_aligned_kv_indices,
mamba_value,
)
self.token_to_kv_pool_allocator.free(
kv_indices[len(req.prefix_indices) : new_prefix_len]
)
self.req_to_token_pool.free(req.req_pool_idx, free_mamba_cache=mamba_exist)
self.dec_lock_ref(req.last_node)
def cache_unfinished_req(self, req: Req, chunked=False) -> None:
"""Cache request when it is unfinished."""
if self.disable:
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : len(req.fill_ids)
]
# `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
req.prefix_indices = kv_indices
return
token_ids = req.fill_ids
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : len(token_ids)
]
page_aligned_len = len(kv_indices)
page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
page_aligned_token_ids = token_ids[:page_aligned_len]
mamba_value = self.req_to_token_pool.get_mamba_indices(
req.req_pool_idx
).unsqueeze(-1)
# radix tree mamba value is forked from req space
mamba_value_forked = self.req_to_token_pool.mamba_pool.fork_from(mamba_value)
# if alloc mamba cache failed, do evict and alloc again
if mamba_value_forked is None:
self.evict_mamba(1)
mamba_value_forked = self.req_to_token_pool.mamba_pool.fork_from(
mamba_value
)
assert mamba_value_forked is not None, "Can not alloc mamba cache"
new_prefix_len, mamba_exist = self.insert(
RadixKey(page_aligned_token_ids, req.extra_key),
page_aligned_kv_indices,
mamba_value_forked,
)
self.token_to_kv_pool_allocator.free(
kv_indices[len(req.prefix_indices) : new_prefix_len]
)
        # the radix tree already holds a mamba state for this prefix, so release the forked copy
if mamba_exist:
self.req_to_token_pool.mamba_pool.free(mamba_value_forked)
# The prefix indices could be updated, reuse it
new_indices, new_last_node, _, _ = self.match_prefix(
RadixKey(page_aligned_token_ids, req.extra_key)
)
if not mamba_exist:
assert torch.equal(new_last_node.mamba_value, mamba_value_forked)
assert len(req.prefix_indices) <= len(
new_indices
), f"{req.prefix_indices=}, {new_indices=}"
assert new_prefix_len <= len(new_indices), f"{new_prefix_len=}, {new_indices=}"
self.req_to_token_pool.write(
(req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
new_indices[len(req.prefix_indices) :],
)
self.dec_lock_ref(req.last_node)
self.inc_lock_ref(new_last_node)
# `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
req.prefix_indices = new_indices
req.last_node = new_last_node
def pretty_print(self) -> None:
self._print_helper(self.root_node, 0)
total_size, total_mamba_size = self._total_size_helper()
print(f"#full_tokens: {total_size}, #mamba_num: {total_mamba_size}")
def total_size(self) -> Tuple[int, int]:
return self._total_size_helper()
def _evict_leaf_node(
self, x: TreeNode, is_evict_mamba: bool
) -> Tuple[int, int, TreeNode, TreeNode]:
assert (
x.full_lock_ref == 0 and x.mamba_lock_ref == 0
), f"evict leaf node invalid with {x.id=} {x.full_lock_ref=} {x.mamba_lock_ref=}"
        assert x.mamba_value is not None, f"leaf node mamba value should not be None, {x.id=}"
# 1. a leaf node, free full tokens and mamba
self.token_to_kv_pool_allocator.free(x.value)
full_num_evicted = len(x.value)
self.req_to_token_pool.mamba_pool.free(x.mamba_value)
mamba_num_evicted = len(x.mamba_value)
# 2. get the next node, update the lru lists
if is_evict_mamba:
x_next = self.mamba_lru_list.get_prev_no_lock(x)
else:
x_next = self.full_lru_list.get_prev_leaf_no_lock(x)
self.full_lru_list.remove_node(x)
self.mamba_lru_list.remove_node(x)
# 3. delete the leaf node
self._delete_leaf(x)
# 4. Iteratively delete tombstone leaves to maintain invariant that leaf nodes are not tombstone
x, leaf_full_num_evicted = self._iteratively_delete_tombstone_leaf(x)
full_num_evicted += leaf_full_num_evicted
return full_num_evicted, mamba_num_evicted, x, x_next
def evict_mamba(self, mamba_num: int) -> None:
if self.disable or mamba_num <= 0:
return
# get the least recently used node that is not locked, doesn't have to be a leaf
x = self.mamba_lru_list.get_lru_no_lock()
mamba_num_evicted = 0
        # evict lru nodes until mamba_num is reached (leaves are freed, internal nodes are tombstoned)
while mamba_num_evicted < mamba_num and (self.mamba_lru_list.in_list(x)):
assert x.mamba_value is not None, f"node has no mamba value, {x.id=}"
assert (
len(x.mamba_value) == 1
), f"node has abnormal mamba length, {x.id=}, {len(x.mamba_value)=}"
assert x != self.root_node, f"root node is not evictable, {x.id=}"
assert x.mamba_lock_ref == 0, f"node is in use by mamba kv indices, {x.id=}"
if len(x.children) > 0:
# 1. an internal node, free mamba tokens.
self.req_to_token_pool.mamba_pool.free(x.mamba_value)
mamba_num_evicted += len(x.mamba_value)
# 2. get the next node, update the lru lists
x_next = self.mamba_lru_list.get_prev_no_lock(x)
self.mamba_lru_list.remove_node(x)
# 3. tombstone the node
self._tombstone_internal_node(x)
else:
_, mamba_evicted_delta, _, x_next = self._evict_leaf_node(x, True)
mamba_num_evicted += mamba_evicted_delta
x = x_next
def evict(self, full_num_tokens: int) -> None:
if self.disable or full_num_tokens <= 0:
return
full_num_evicted = 0
# get the least recently used leaf node that is not locked
x = self.full_lru_list.get_leaf_lru_no_lock()
while full_num_evicted < full_num_tokens and self.full_lru_list.in_list(x):
assert (
x != self.root_node
), f"root node should not exist in full lru list, {x.id=}"
full_num_evicted_delta, _, x, x_next = self._evict_leaf_node(x, False)
full_num_evicted += full_num_evicted_delta
# if parent has no more children, it is a leaf. It is possible that this node is lru, so
# we need to get the first leaf node in the lru list
if len(x.parent.children) == 0:
x_next = self.full_lru_list.get_leaf_lru_no_lock()
x = x_next
def inc_lock_ref(self, node: TreeNode) -> Optional[int]:
"""
Increment the lock reference count for the node.
It locks the full_lock_ref for nodes between the [last node, root), exclusive.
It locks the mamba_lock_ref for current node if its mamba_value exists.
"""
if self.disable:
return None
# protect mamba value in current node if it exists
if node.mamba_value is not None:
if node.mamba_lock_ref == 0:
self.mamba_evictable_size_ -= len(node.mamba_value)
self.mamba_protected_size_ += len(node.mamba_value)
node.mamba_lock_ref += 1
while node != self.root_node:
# lock full from node to root
assert (
node.full_lock_ref >= 0
), f"inc_lock_ref on node with {node.full_lock_ref=}, {node.id=}"
if node.full_lock_ref == 0:
self.full_evictable_size_ -= len(node.value)
self.full_protected_size_ += len(node.value)
node.full_lock_ref += 1
node = node.parent
return None
def dec_lock_ref(self, node: TreeNode):
"""
Decrement the lock reference count for the node.
It unlocks the full_lock_ref for nodes between the [last node, root), exclusive.
It unlocks the mamba_lock_ref for current node if its mamba_value exists.
"""
if self.disable:
return
if node.mamba_value is not None:
assert (
node.mamba_lock_ref > 0
), f"dec_lock_ref on node with {node.mamba_lock_ref=}, {node.id=}"
if node.mamba_lock_ref == 1:
self.mamba_evictable_size_ += len(node.mamba_value)
self.mamba_protected_size_ -= len(node.mamba_value)
node.mamba_lock_ref -= 1
while node != self.root_node:
assert (
node.full_lock_ref > 0
), f"dec_lock_ref on node with {node.full_lock_ref=}, {node.id=}"
if node.full_lock_ref == 1:
self.full_evictable_size_ += len(node.value)
self.full_protected_size_ -= len(node.value)
node.full_lock_ref -= 1
node = node.parent
def sanity_check(self):
self.full_lru_list.sanity_check(self)
self.mamba_lru_list.sanity_check(self)
def evictable_size(self) -> Tuple[int, int]:
# Note: use full_evictable_size() and mamba_evictable_size() instead.
raise NotImplementedError
def full_evictable_size(self) -> int:
return self.full_evictable_size_
def mamba_evictable_size(self) -> int:
return self.mamba_evictable_size_
# Note: this is expensive, only use for debug
def full_lru_list_evictable_size(self) -> int:
return self.full_lru_list.sanity_check_evictable_size()
# Note: this is expensive, only use for debug
def mamba_lru_list_evictable_size(self) -> int:
return self.mamba_lru_list.sanity_check_evictable_size()
def protected_size(self) -> Tuple[int, int]:
# Note: use full_protected_size() and mamba_protected_size() instead.
raise NotImplementedError
def full_protected_size(self) -> int:
# protected size refers to the size of the full cache that is locked
return self.full_protected_size_
def mamba_protected_size(self) -> int:
# protected size refers to the size of the mamba cache that is locked
return self.mamba_protected_size_
def all_values_flatten(self) -> torch.Tensor:
values = []
def _dfs_helper(node: TreeNode):
for _, child in node.children.items():
values.append(child.value)
_dfs_helper(child)
_dfs_helper(self.root_node)
return torch.cat(values)
##### Internal Helper Functions #####
def _match_prefix_helper(
self, key: RadixKey
) -> Tuple[List[torch.Tensor], TreeNode]:
"""
        Mamba prefix matching helper. The returned prefix always ends at the
        deepest matched node that still holds a mamba state (i.e. is not a
        mamba tombstone), so the caller can resume decoding from that state.
"""
node = self.root_node
child_key = self.get_child_key_fn(key)
value = []
best_value_len = 0
best_last_node = node
while len(key) > 0 and child_key in node.children.keys():
child = node.children[child_key]
# update best_value_len and best_last_node if needed
if node.mamba_value is not None:
best_value_len = len(value)
best_last_node = node
prefix_len = self.key_match_fn(child.key, key)
if prefix_len < len(child.key):
new_node = self._split_node(child.key, child, prefix_len)
value.append(new_node.value)
node = new_node
break
else:
value.append(child.value)
node = child
key = key[prefix_len:]
if len(key):
child_key = self.get_child_key_fn(key)
# handle best_value_len and best_last_node, for the case that last node is fully matched
if node.mamba_value is not None:
best_value_len = len(value)
best_last_node = node
# update time for matched nodes, and make nodes closer to root to be least recently used
# this allows mamba to evict nodes closer to root first
self.full_lru_list.reset_node_and_parents_mru(best_last_node, self.root_node)
self.mamba_lru_list.reset_node_and_parents_mru(best_last_node, self.root_node)
# This last_access_time is for sanity check, can be deleted after validation in production
cur_time = time.monotonic()
while node:
node.last_access_time = cur_time
cur_time -= 0.0001
node = node.parent
return value[:best_value_len], best_last_node
def _split_node(self, key: RadixKey, child: TreeNode, split_len: int) -> TreeNode:
# new_node -> child
new_node = TreeNode()
new_node.children = {self.get_child_key_fn(key[split_len:]): child}
new_node.parent = child.parent
new_node.mamba_value = None # mamba cache can not be split
new_node.full_lock_ref = child.full_lock_ref
new_node.mamba_lock_ref = 0
new_node.key = child.key[:split_len]
new_node.value = child.value[:split_len]
# child time should be later than parent's time for mamba tombstone
child.last_access_time = time.monotonic()
self.full_lru_list.remove_node(child)
if child.mamba_value is not None:
self.mamba_lru_list.remove_node(child)
child.parent = new_node
child.key = child.key[split_len:]
child.value = child.value[split_len:]
new_node.parent.children[self.get_child_key_fn(key)] = new_node
# insert the new node and child into the lru lists, insert
# parent first so that parent is after child in the lru list
self.full_lru_list.insert_mru(new_node)
self.full_lru_list.insert_mru(child)
if child.mamba_value is not None:
self.mamba_lru_list.insert_mru(child)
return new_node
def _insert_helper(
self,
node: TreeNode,
key: RadixKey,
value,
mamba_value,
) -> Tuple[int, bool]:
# Update the last access time from root to leaf, so that
# mamba will tombstone the node closer to root first
assert mamba_value is not None, "Mamba value should not be None here."
node.last_access_time = time.monotonic()
if node != self.root_node:
self.full_lru_list.reset_node_mru(node)
if node.mamba_value is not None:
self.mamba_lru_list.reset_node_mru(node)
if len(key) == 0:
return 0, True
child_key = self.get_child_key_fn(key)
total_prefix_length = 0
while len(key) > 0 and child_key in node.children.keys():
node = node.children[child_key]
node.last_access_time = time.monotonic()
self.full_lru_list.reset_node_mru(node)
if node.mamba_value is not None:
self.mamba_lru_list.reset_node_mru(node)
prefix_len = self.key_match_fn(node.key, key)
total_prefix_length += prefix_len
key = key[prefix_len:]
value = value[prefix_len:]
if prefix_len < len(node.key):
new_node = self._split_node(node.key, node, prefix_len)
node = new_node
if len(key):
child_key = self.get_child_key_fn(key)
mamba_value_exist = False
if len(key):
new_node = TreeNode()
new_node.parent = node
new_node.key = key
new_node.value = value
new_node.mamba_value = mamba_value
self.full_lru_list.insert_mru(new_node)
self.full_evictable_size_ += len(value)
self.mamba_evictable_size_ += len(mamba_value)
self.mamba_lru_list.insert_mru(new_node)
node.children[child_key] = new_node
elif node.mamba_value is None: # add for mamba tombstone
node.mamba_value = mamba_value
self.mamba_evictable_size_ += len(mamba_value)
self.mamba_lru_list.insert_mru(node)
else:
mamba_value_exist = True
self.mamba_lru_list.reset_node_mru(node)
return total_prefix_length, mamba_value_exist
def _iteratively_delete_tombstone_leaf(
self, node: TreeNode
) -> Tuple[TreeNode, int]:
full_num_evicted = 0
while node.parent.mamba_value is None and len(node.parent.children) == 0:
# root node is not evictable
if node.parent == self.root_node:
break
# if locked, means node is in use, skip
if node.parent.full_lock_ref > 0:
break
assert (
node.parent.mamba_lock_ref == 0
), f"tombstone mamba_lock_ref should always be 0, {node.parent.full_lock_ref=}, {node.parent.mamba_lock_ref=}, {node.parent.id=}"
# delete tombstone node evicts full tokens
self.token_to_kv_pool_allocator.free(node.parent.value)
full_num_evicted += len(node.parent.value)
self.full_lru_list.remove_node(node.parent)
self._delete_tombstone_leaf(node.parent)
node = node.parent
return node, full_num_evicted
def _delete_leaf(self, node: TreeNode) -> None:
assert (
node.mamba_value is not None
), f"Invariant violated: leaf node is a tombstone, {node.id=}"
assert len(node.children) == 0, f"leaf node has children, {node.id=}"
for k, v in node.parent.children.items():
if v == node:
break
del node.parent.children[k]
self.full_evictable_size_ -= len(node.key)
self.mamba_evictable_size_ -= len(node.mamba_value)
def _tombstone_internal_node(self, node: TreeNode) -> None:
assert len(node.children) != 0, f"Cannot tombstone a leaf node, {node.id=}"
self.mamba_evictable_size_ -= len(node.mamba_value)
node.mamba_value = None
def _delete_tombstone_leaf(self, node: TreeNode) -> None:
assert (
node.mamba_value is None
), f"Deleting a unexpected non-tombstone leaf node, {node.id=}"
assert len(node.children) == 0, f"leaf node has children, {node.id=}"
for k, v in node.parent.children.items():
if v == node:
break
del node.parent.children[k]
self.full_evictable_size_ -= len(node.key)
def _collect_leaves(self) -> List[TreeNode]:
ret_list = []
stack = [self.root_node]
while stack:
cur_node = stack.pop()
if len(cur_node.children) == 0:
ret_list.append(cur_node)
else:
stack.extend(cur_node.children.values())
return ret_list
def _collect_nontombstone_nodes(self) -> List[TreeNode]:
ret_list = []
stack = [self.root_node]
while stack:
cur_node = stack.pop()
if cur_node.mamba_value is not None:
ret_list.append(cur_node)
stack.extend(cur_node.children.values())
return ret_list
def _collect_all_nodes(self) -> List[TreeNode]:
ret_list = []
stack = [self.root_node]
while stack:
cur_node = stack.pop()
ret_list.append(cur_node)
stack.extend(cur_node.children.values())
return ret_list
def _print_helper(self, node: TreeNode, indent: int) -> None:
"""Prints the radix tree in a human-readable format."""
stack = [(node, indent)]
while stack:
current_node, current_indent = stack.pop()
print(
" " * current_indent,
f"[{current_node.id}]",
len(current_node.key),
f"fr={current_node.full_lock_ref}",
f"mr={current_node.mamba_lock_ref}",
f"fll={self.full_lru_list.in_list(current_node)}",
f"mll={self.mamba_lru_list.in_list(current_node)}",
f"mv={current_node.mamba_value}",
)
for key, child in current_node.children.items():
stack.append((child, current_indent + 2))
assert key == self.get_child_key_fn(
child.key
), f"{key=}, {self.get_child_key_fn(child.key)=}"
def _total_size_helper(self) -> Tuple[int, int]:
total_size = 0
total_mamba_size = 0
stack = [self.root_node]
while stack:
current_node = stack.pop()
total_size += len(current_node.value)
if current_node.mamba_value is not None:
total_mamba_size += len(current_node.mamba_value)
for child in current_node.children.values():
if child.evicted:
continue
stack.append(child)
return total_size, total_mamba_size
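
A hedged usage sketch of the public API above (the pools, token_ids, kv_indices, mamba_state_index, and req are assumed to exist; page_size must be 1):

cache = MambaRadixCache(
    req_to_token_pool=req_to_token_pool,        # a HybridReqToTokenPool
    token_to_kv_pool_allocator=kv_allocator,    # a TokenToKVPoolAllocator
    page_size=1,
)
# insert() returns (matched_prefix_len, mamba_state_already_cached)
prefix_len, mamba_exist = cache.insert(
    RadixKey(token_ids, extra_key=None), kv_indices, mamba_state_index
)
# cow_mamba=True copies the matched node's mamba state into the request's
# own slot, so decoding can mutate it without corrupting the cached copy.
result = cache.match_prefix(RadixKey(token_ids, extra_key=None), req=req, cow_mamba=True)
cache.evict_mamba(1)   # tombstone the least recently used mamba state
cache.evict(4096)      # free 4096 full-attention KV tokens from LRU leaves
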
@@ -190,6 +190,7 @@ class MambaPool:
             )
             logger.info(
                 f"Mamba Cache is allocated. "
+                f"max_mamba_cache_size: {size}, "
                 f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
                 f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB "
                 f"intermediate_ssm_state_cache size: {get_tensor_size_bytes(intermediate_ssm_state_cache) / GB:.2f}GB "
@@ -199,11 +200,13 @@ class MambaPool:
             self.mamba_cache = self.State(conv=conv_state, temporal=temporal_state)
             logger.info(
                 f"Mamba Cache is allocated. "
+                f"max_mamba_cache_size: {size}, "
                 f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
                 f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB "
             )
         self.size = size
-        self.free_slots = list(range(size))
+        self.device = device
+        self.free_slots = torch.arange(self.size, dtype=torch.int64, device=self.device)
         self.mem_usage = self.mamba_cache.mem_usage_bytes() / GB
def get_speculative_mamba2_params_all_layers(self) -> SpeculativeState: def get_speculative_mamba2_params_all_layers(self) -> SpeculativeState:
...@@ -216,7 +219,7 @@ class MambaPool: ...@@ -216,7 +219,7 @@ class MambaPool:
def available_size(self): def available_size(self):
return len(self.free_slots) return len(self.free_slots)
def alloc(self, need_size: int) -> Optional[List[int]]: def alloc(self, need_size: int) -> Optional[torch.Tensor]:
if need_size > len(self.free_slots): if need_size > len(self.free_slots):
return None return None
@@ -225,17 +228,30 @@ class MambaPool:
         return select_index

-    def free(self, free_index: Union[int, List[int]]):
-        if isinstance(free_index, (int,)):
-            self.free_slots.append(free_index)
-        else:
-            self.free_slots.extend(free_index)
+    def free(self, free_index: torch.Tensor):
+        if free_index.numel() == 0:
+            return
+        self.free_slots = torch.cat((self.free_slots, free_index))
         self.mamba_cache.conv[:, free_index] = self.mamba_cache.temporal[
             :, free_index
         ] = 0

     def clear(self):
-        self.free_slots = list(range(self.size))
+        self.free_slots = torch.arange(self.size, dtype=torch.int64, device=self.device)
+
+    def copy_from(self, src_index: torch.Tensor, dst_index: torch.Tensor):
+        self.mamba_cache.conv[:, dst_index] = self.mamba_cache.conv[:, src_index]
+        self.mamba_cache.temporal[:, dst_index] = self.mamba_cache.temporal[
+            :, src_index
+        ]
+
+    def fork_from(self, src_index: torch.Tensor) -> Optional[torch.Tensor]:
+        dst_index = self.alloc(1)
+        if dst_index is None:
+            return None
+        self.copy_from(src_index, dst_index)
+        return dst_index
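
fork_from() above is the copy-on-write primitive the radix cache builds on; a hedged sketch of the retry pattern its callers (e.g. cache_unfinished_req) use when the pool is exhausted:

dst = mamba_pool.fork_from(src_index)   # src_index: shape (1,) slot tensor
if dst is None:                         # pool exhausted: evict one LRU state
    tree_cache.evict_mamba(1)
    dst = mamba_pool.fork_from(src_index)
assert dst is not None, "Can not alloc mamba cache"
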
class HybridReqToTokenPool(ReqToTokenPool):

@@ -245,6 +261,7 @@ class HybridReqToTokenPool(ReqToTokenPool):
         self,
         *,
         size: int,
+        mamba_size: int,
         max_context_len: int,
         device: str,
         enable_memory_saver: bool,

@@ -259,7 +276,7 @@ class HybridReqToTokenPool(ReqToTokenPool):
         )
         self.mamba_pool = MambaPool(
-            size=size,
+            size=mamba_size,
             cache_params=cache_params,
             device=device,
             speculative_num_draft_tokens=speculative_num_draft_tokens,

@@ -271,9 +288,6 @@ class HybridReqToTokenPool(ReqToTokenPool):
             size, dtype=torch.int32, device=self.device
         )
-        self.rid_to_mamba_index_mapping: Dict[str, int] = {}
-        self.mamba_index_to_rid_mapping: Dict[int, str] = {}
-
     # For chunk prefill req, we do not need to allocate mamba cache,
     # We could use allocated mamba cache instead.
     def alloc(
@@ -285,14 +299,14 @@ class HybridReqToTokenPool(ReqToTokenPool):
         mamba_index = []
         for req in reqs:
-            rid = req.rid
-            if rid in self.rid_to_mamba_index_mapping:
-                mid = self.rid_to_mamba_index_mapping[rid]
-            elif (mid := self.mamba_pool.alloc(1)) is not None:
-                mid = mid[0]
-                self.rid_to_mamba_index_mapping[rid] = mid
-                self.mamba_index_to_rid_mapping[mid] = rid
-            mamba_index.append(mid)
+            mid = None
+            if req.mamba_pool_idx is not None:  # for radix cache
+                mid = req.mamba_pool_idx
+            else:
+                mid = self.mamba_pool.alloc(1)[0]
+                req.mamba_pool_idx = mid
+            if mid is not None:
+                mamba_index.append(mid)
         assert len(select_index) == len(
             mamba_index
         ), f"Not enough space for mamba cache, try to increase --max-mamba-cache-size."
@@ -313,17 +327,12 @@ class HybridReqToTokenPool(ReqToTokenPool):
     # For chunk prefill, we can not free mamba cache, we need use it in the future
     def free(self, free_index: Union[int, List[int]], free_mamba_cache: bool = True):
-        if isinstance(free_index, (int,)):
-            free_index = [free_index]
         super().free(free_index)
         if free_mamba_cache:
             mamba_index = self.req_index_to_mamba_index_mapping[free_index]
-            mamba_index_list = mamba_index.tolist()
-            if isinstance(mamba_index_list, int):
-                mamba_index_list = [mamba_index_list]
-            self.mamba_pool.free(mamba_index_list)
-            for mid in mamba_index_list:
-                rid = self.mamba_index_to_rid_mapping[mid]
-                self.mamba_index_to_rid_mapping.pop(mid)
-                self.rid_to_mamba_index_mapping.pop(rid)
+            self.mamba_pool.free(mamba_index)

     def clear(self):
         super().clear()
...
@@ -191,6 +191,9 @@ SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
 # Detect straggler ranks in model loading
 UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300

+# the ratio of mamba cache pool size to max_running_requests; it is safe when larger than 2 (yizhang2077)
+MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO = 3

 logger = logging.getLogger(__name__)
@@ -382,26 +385,10 @@ class ModelRunner:
         if architectures and not any("Llama4" in arch for arch in architectures):
             self.is_hybrid = self.model_config.is_hybrid = True

-        if config := self.mambaish_config:
+        if config := self.mamba2_config:
             class_name = config.__class__.__name__
             logger.warning(f"{class_name} model detected, disable radix cache")
             self.server_args.disable_radix_cache = True
-            if self.server_args.max_mamba_cache_size is None:
-                if self.server_args.max_running_requests is not None:
-                    self.server_args.max_mamba_cache_size = (
-                        self.server_args.max_running_requests
-                    )
-                else:
-                    self.server_args.max_mamba_cache_size = 512
-            if self.hybrid_gdn_config is not None:
-                self.server_args.max_mamba_cache_size = (
-                    self.server_args.max_mamba_cache_size
-                    // (
-                        self.server_args.dp_size
-                        if self.server_args.enable_dp_attention
-                        else 1
-                    )
-                )

         # For MTP models like DeepSeek-V3 or GLM-4.5, the MTP layer(s) are used separately as draft
         # models for speculative decoding. In those cases, `num_nextn_predict_layers` is used to
@@ -1330,15 +1317,60 @@ class ModelRunner:
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
         )
-        if config := self.mambaish_config:
-            rest_memory -= (
-                self.server_args.max_mamba_cache_size
-                * config.mamba2_cache_params.mamba_cache_per_req
-                / (1 << 30)
-            )
+        if self.mambaish_config is not None:
+            rest_memory = self.handle_max_mamba_cache(rest_memory)
         max_num_token = int(rest_memory * (1 << 30) // cell_size)
         return max_num_token
+    def handle_max_mamba_cache(self, total_rest_memory):
+        config = self.mambaish_config
+        server_args = self.server_args
+        assert config is not None
+        speculative_ratio = (
+            0
+            if server_args.speculative_num_draft_tokens is None
+            else server_args.speculative_num_draft_tokens
+        )
+        if (
+            server_args.disable_radix_cache
+            or config.mamba2_cache_params.mamba_cache_per_req == 0
+        ):
+            # with radix cache disabled, set max_mamba_cache_size based on max_running_requests
+            if server_args.max_mamba_cache_size is None:
+                if server_args.max_running_requests is not None:
+                    server_args.max_mamba_cache_size = server_args.max_running_requests
+                else:
+                    server_args.max_mamba_cache_size = 512
+        else:
+            # split the memory based on the ratio between mamba state memory and full kv cache memory
+            # solve the equations:
+            #   1. mamba_state_memory + full_kv_cache_memory == total_rest_memory
+            #   2. mamba_state_memory / full_kv_cache_memory == server_args.mamba_full_memory_ratio
+            mamba_state_memory_raw = (
+                total_rest_memory
+                * server_args.mamba_full_memory_ratio
+                / (1 + server_args.mamba_full_memory_ratio)
+            )
+            # calculate the max_mamba_cache_size based on the given total mamba memory
+            server_args.max_mamba_cache_size = int(
+                (mamba_state_memory_raw * (1 << 30))
+                // config.mamba2_cache_params.mamba_cache_per_req
+                // (1 + speculative_ratio)
+            )
+        if self.hybrid_gdn_config is not None:
+            server_args.max_mamba_cache_size = server_args.max_mamba_cache_size // (
+                server_args.dp_size if server_args.enable_dp_attention else 1
+            )
+        mamba_state_memory = (
+            server_args.max_mamba_cache_size
+            * config.mamba2_cache_params.mamba_cache_per_req
+            * (1 + speculative_ratio)
+            / (1 << 30)
+        )
+        return total_rest_memory - mamba_state_memory
     @property
     def hybrid_gdn_config(self):
         config = self.model_config.hf_config
...@@ -1511,8 +1543,16 @@ class ModelRunner: ...@@ -1511,8 +1543,16 @@ class ModelRunner:
), ),
4096, 4096,
) )
if self.mambaish_config is not None: if self.mambaish_config is not None:
max_num_reqs = min(max_num_reqs, self.server_args.max_mamba_cache_size) ratio = (
MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO
if not self.server_args.disable_radix_cache
else 1
)
max_num_reqs = min(
max_num_reqs, self.server_args.max_mamba_cache_size // ratio
)
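# With the radix cache enabled, running requests and cached prefixes share
# the mamba pool, so the cap reserves a fraction of the slots
# (1 / MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO) for running requests;
# with it disabled, every mamba slot can back a running request (ratio 1).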
if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone():
if self.is_draft_worker:
...@@ -1595,6 +1635,7 @@ class ModelRunner:
elif config := self.mambaish_config:
self.req_to_token_pool = HybridReqToTokenPool(
size=max_num_reqs,
mamba_size=self.server_args.max_mamba_cache_size,
max_context_len=self.model_config.context_len
+ extra_max_context_len,
device=self.device,
...
...@@ -362,6 +362,7 @@ class ServerArgs:
# Mamba cache
max_mamba_cache_size: Optional[int] = None
mamba_ssm_dtype: str = "float32"
mamba_full_memory_ratio: float = 0.2
# Hierarchical cache
enable_hierarchical_cache: bool = False
...@@ -2433,6 +2434,12 @@ class ServerArgs:
choices=["float32", "bfloat16"],
help="The data type of the SSM states in mamba cache.",
)
parser.add_argument(
"--mamba-full-memory-ratio",
type=float,
default=ServerArgs.mamba_full_memory_ratio,
help="The ratio of mamba state memory to full kv cache memory.",
)
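# Illustrative example (hypothetical value): --mamba-full-memory-ratio 0.3
# gives the mamba pool r / (1 + r) = 0.3 / 1.3, roughly 23% of the remaining
# cache budget, with the rest going to the full-attention KV cache.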
# Args for multi-item-scoring
parser.add_argument(
"--multi-item-scoring-delimiter",
...
...@@ -84,6 +84,7 @@ suites = {
TestFile("test_io_struct.py", 8),
TestFile("test_jinja_template_utils.py", 1),
TestFile("test_logprobs.py", 55),
TestFile("test_mamba_unittest.py", 4),
TestFile("test_metrics.py", 32),
TestFile("test_metrics_utils.py", 1),
TestFile("test_mla.py", 167),
...
import inspect
import os
import unittest
import torch
from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
from sglang.srt.managers.schedule_batch import Req
from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
from sglang.srt.mem_cache.memory_pool import HybridLinearKVPool, HybridReqToTokenPool
from sglang.srt.mem_cache.radix_cache import RadixKey
from sglang.srt.sampling.sampling_params import SamplingParams
class TestMamba(unittest.TestCase):
@classmethod
def setUpClass(cls):
pass
@classmethod
def tearDownClass(cls):
pass
def test_hybrid_linear_kv_pool(self):
size = 16
head_num = 2
head_dim = 256
num_layers = 48
global_interval = 4
dtype = torch.bfloat16
device = "cuda"
full_attention_layer_ids = [
i for i in range(global_interval - 1, num_layers, global_interval)
]
pool = HybridLinearKVPool(
size=size,
dtype=dtype,
page_size=1,
head_num=head_num,
head_dim=head_dim,
full_attention_layer_ids=full_attention_layer_ids,
enable_kvcache_transpose=False,
device=device,
)
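# With global_interval=4, the full attention layers are 3, 7, 11, ...;
# _transfer_full_attention_id maps a global layer id to its position in that
# list (3 -> 0, 7 -> 1) and raises ValueError for non-full-attention layers.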
assert pool._transfer_full_attention_id(global_interval - 1) == 0
assert pool._transfer_full_attention_id(2 * global_interval - 1) == 1
with self.assertRaises(ValueError) as context:
pool._transfer_full_attention_id(1)
self.assertIn(
"layer_id=1 not in full attention layers:", str(context.exception)
)
def test_mamba_pool(self):
max_num_reqs = 10
mamba_cache_size = 20
max_context_len = 128
device = "cuda"
global_interval = 4
num_layers = 48
full_attention_layer_ids = [
i for i in range(global_interval - 1, num_layers, global_interval)
]
mamba_layers = [
i for i in range(num_layers) if i not in full_attention_layer_ids
]
shape = Mamba2StateShape.create(
tp_world_size=1,
intermediate_size=4096,
n_groups=16,
num_heads=32,
head_dim=128,
state_size=128,
conv_kernel=4,
)
os.environ["SGLANG_MAMBA_SSM_DTYPE"] = "bfloat16"
mamba2_cache_params = Mamba2CacheParams(shape=shape, layers=mamba_layers)
req_to_token_pool = HybridReqToTokenPool(
size=max_num_reqs,
mamba_size=mamba_cache_size,
max_context_len=max_context_len,
device=device,
enable_memory_saver=False,
cache_params=mamba2_cache_params,
speculative_num_draft_tokens=3,
)
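# The hybrid pool tracks request slots and mamba state slots separately:
# up to max_num_reqs concurrent requests here, backed by a larger pool of
# mamba_cache_size states so freed requests can leave their states cached.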
assert req_to_token_pool.available_size() == max_num_reqs
assert req_to_token_pool.mamba_pool.available_size() == mamba_cache_size
sampling_params = SamplingParams(
temperature=0,
max_new_tokens=1,
)
req = Req(
rid=0,
origin_input_text="",
origin_input_ids=[],
sampling_params=sampling_params,
)
# alloc req
req_index = req_to_token_pool.alloc(1, [req])
assert req_to_token_pool.available_size() == max_num_reqs - 1
assert req_to_token_pool.mamba_pool.available_size() == mamba_cache_size - 1
# free req
req_to_token_pool.free(req_index)
assert req_to_token_pool.available_size() == max_num_reqs
assert req_to_token_pool.mamba_pool.available_size() == mamba_cache_size
# alloc req without free mamba cache
req.mamba_pool_idx = None
req_index = req_to_token_pool.alloc(1, [req])
req_to_token_pool.free(req_index, free_mamba_cache=False)
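# free_mamba_cache=False returns the request slot but keeps the mamba state
# allocated (e.g. when the radix cache still references it), which is why
# the mamba pool stays one entry short below.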
assert req_to_token_pool.available_size() == max_num_reqs
assert req_to_token_pool.mamba_pool.available_size() == mamba_cache_size - 1
# alloc again
req_index = req_to_token_pool.alloc(1, [req])
assert req_to_token_pool.available_size() == max_num_reqs - 1
assert req_to_token_pool.mamba_pool.available_size() == mamba_cache_size - 1
def test_mamba_radix_cache_1(self):
# kv cache
size = 128
dtype = torch.bfloat16
head_num = 2
head_dim = 256
num_layers = 48
global_interval = 4
max_num_reqs = 10
mamba_cache_size = 20
max_context_len = 128
device = "cuda"
full_attention_layer_ids = [
i for i in range(global_interval - 1, num_layers, global_interval)
]
# mamba
mamba_layers = [
i for i in range(num_layers) if i not in full_attention_layer_ids
]
os.environ["SGLANG_MAMBA_SSM_DTYPE"] = "bfloat16"
shape = Mamba2StateShape.create(
tp_world_size=1,
intermediate_size=4096,
n_groups=16,
num_heads=32,
head_dim=128,
state_size=128,
conv_kernel=4,
)
mamba2_cache_params = Mamba2CacheParams(shape=shape, layers=mamba_layers)
req_to_token_pool = HybridReqToTokenPool(
size=max_num_reqs,
mamba_size=mamba_cache_size,
max_context_len=max_context_len,
device=device,
enable_memory_saver=False,
cache_params=mamba2_cache_params,
speculative_num_draft_tokens=3,
)
# setup kv pool
pool = HybridLinearKVPool(
size=size,
dtype=dtype,
page_size=1,
head_num=head_num,
head_dim=head_dim,
full_attention_layer_ids=full_attention_layer_ids,
enable_kvcache_transpose=False,
device=device,
)
# setup token to kv pool allocator
allocator = TokenToKVPoolAllocator(
size=size,
dtype=dtype,
device=device,
kvcache=pool,
need_sort=False,
)
# setup radix cache
tree = MambaRadixCache(
req_to_token_pool=req_to_token_pool,
token_to_kv_pool_allocator=allocator,
page_size=1,
disable=False,
)
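# MambaRadixCache pairs each cached prefix with a mamba state slot in
# addition to the usual full-attention KV indices; the two resources are
# evicted independently via evict() and evict_mamba() below.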
def make_dummy_req():
sampling_params = SamplingParams(
temperature=0,
max_new_tokens=1,
)
req = Req(
rid=0,
origin_input_text="",
origin_input_ids=[],
sampling_params=sampling_params,
)
req_to_token_pool.alloc(1, reqs=[req])
return req
mamba_pool = req_to_token_pool.mamba_pool
# test
print(
f"[Start] allocator mamba available size: {mamba_pool.available_size()}, full available size: {allocator.available_size()}"
)
req1 = make_dummy_req()
req1_token_ids, req1_kv_indices = [1, 2, 3], allocator.alloc(3)
assert len(req1_token_ids) == len(req1_kv_indices)
print(
f"req1: inserting, req1_token_ids: {req1_token_ids}, req1_kv_indices: {req1_kv_indices}"
)
prefix_len = tree.insert(
RadixKey(req1_token_ids), req1_kv_indices, req1.mamba_pool_idx.unsqueeze(0)
)
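# tree.insert(key, kv_indices, mamba_idx) caches the token sequence with its
# KV indices and the request's mamba state, returning how many leading
# tokens were already present in the tree.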
print(
f"req1: prefix_len: {prefix_len}, allocator mamba available size: {mamba_pool.available_size()}, full available size: {allocator.available_size()}"
)
req2 = make_dummy_req()
req2_token_ids, req2_kv_indices = [1, 2, 3, 4, 5, 6, 7], allocator.alloc(7)
assert len(req2_token_ids) == len(req2_kv_indices)
print(
f"req2: inserting, req2_token_ids: {req2_token_ids}, req2_kv_indices: {req2_kv_indices}"
)
prefix_len = tree.insert(
RadixKey(req2_token_ids), req2_kv_indices, req2.mamba_pool_idx.unsqueeze(0)
)
print(
f"req2: prefix_len: {prefix_len}, allocator mamba available size: {mamba_pool.available_size()}, full available size: {allocator.available_size()}"
)
req3 = make_dummy_req()
req3_token_ids, req3_kv_indices = [10, 11, 12], allocator.alloc(3)
assert len(req3_token_ids) == len(req3_kv_indices)
print(
f"req3: inserting, req3_token_ids: {req3_token_ids}, req3_kv_indices: {req3_kv_indices}"
)
prefix_len = tree.insert(
RadixKey(req3_token_ids), req3_kv_indices, req3.mamba_pool_idx.unsqueeze(0)
)
print(
f"req3: prefix_len: {prefix_len}, allocator mamba available size: {mamba_pool.available_size()}, full available size: {allocator.available_size()}"
)
req4 = make_dummy_req()
req4_token_ids, req4_kv_indices = [1, 2, 3, 4, 5, 60, 70], allocator.alloc(7)
assert len(req4_token_ids) == len(req4_kv_indices)
print(
f"req4: inserting, req4_token_ids: {req4_token_ids}, req4_kv_indices: {req4_kv_indices}"
)
prefix_len = tree.insert(
RadixKey(req4_token_ids), req4_kv_indices, req4.mamba_pool_idx.unsqueeze(0)
)
print(
f"req4: prefix_len: {prefix_len}, allocator mamba available size: {mamba_pool.available_size()}, full available size: {allocator.available_size()}"
)
tree.pretty_print()
full_num_tokens = 1
print(f"evicting {full_num_tokens} full token")
tree.evict(full_num_tokens=full_num_tokens)
tree.pretty_print()
mamba_num = 1
print(f"evicting {mamba_num} mamba")
tree.evict_mamba(mamba_num=mamba_num)
tree.pretty_print()
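# Matching after eviction: a prefix whose backing mamba state was evicted
# can no longer be served (req5 below matches nothing), while branches that
# still hold a state return their full KV indices (req6 and req7).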
req5_token_ids = [1, 2, 3, 4, 5]
result = tree.match_prefix(RadixKey(req5_token_ids))
kv_indices, last_node = result.device_indices, result.last_device_node
print(
f"req5: token_ids: {req5_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}"
)
assert len(kv_indices) == 0
req6_token_ids = [1, 2, 3, 4, 5, 60, 70]
result = tree.match_prefix(RadixKey(req6_token_ids))
kv_indices, last_node = result.device_indices, result.last_device_node
print(
f"req6: token_ids: {req6_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}"
)
assert len(kv_indices) == 7
assert len(last_node.key) == 2
req7_token_ids = [1, 2, 3, 4, 5, 6, 7]
result = tree.match_prefix(RadixKey(req7_token_ids))
kv_indices, last_node = result.device_indices, result.last_device_node
print(
f"req7: token_ids: {req7_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}"
)
assert len(kv_indices) == 7
assert len(last_node.key) == 2
mamba_num = 1
print(f"evicting {mamba_num} mamba")
tree.evict_mamba(mamba_num=mamba_num)
tree.pretty_print()
req8_token_ids = [1, 2, 3, 4, 5, 60, 70]
result = tree.match_prefix(RadixKey(req8_token_ids))
kv_indices, last_node = result.device_indices, result.last_device_node
print(
f"req8: token_ids: {req8_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}"
)
assert len(kv_indices) == 0
assert len(last_node.key) == 0
req9_token_ids = [1, 2, 3, 4, 5, 6, 7]
req9 = make_dummy_req()
result = tree.match_prefix(
RadixKey(req9_token_ids), **({"req": req9, "cow_mamba": True})
)
kv_indices, last_node = result.device_indices, result.last_device_node
assert req9.mamba_pool_idx is not None
assert torch.all(
mamba_pool.mamba_cache.conv[:, req9.mamba_pool_idx]
== mamba_pool.mamba_cache.conv[:, last_node.mamba_value]
)
assert torch.all(
mamba_pool.mamba_cache.temporal[:, req9.mamba_pool_idx]
== mamba_pool.mamba_cache.temporal[:, last_node.mamba_value]
)
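# cow_mamba=True copies the matched node's mamba state into a fresh slot for
# req9 (copy-on-write), so its conv and temporal states start out identical
# to the cached node's.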
if __name__ == "__main__":
unittest.main()