Unverified Commit 12d6cf18 authored by Xinyuan Tong's avatar Xinyuan Tong Committed by GitHub
Browse files

Refactors radix cache for extra key support (#10317)


Signed-off-by: default avatarXinyuan Tong <xinyuantong.cs@gmail.com>
parent fc3e5420
...@@ -61,8 +61,8 @@ from sglang.srt.mem_cache.allocator import ( ...@@ -61,8 +61,8 @@ from sglang.srt.mem_cache.allocator import (
) )
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.lora_radix_cache import LoRAKey, LoRARadixCache
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
from sglang.srt.mem_cache.radix_cache import RadixKey
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
...@@ -457,6 +457,7 @@ class Req: ...@@ -457,6 +457,7 @@ class Req:
vocab_size: Optional[int] = None, vocab_size: Optional[int] = None,
priority: Optional[int] = None, priority: Optional[int] = None,
metrics_collector: Optional[SchedulerMetricsCollector] = None, metrics_collector: Optional[SchedulerMetricsCollector] = None,
extra_key: Optional[str] = None,
): ):
# Input and output info # Input and output info
self.rid = rid self.rid = rid
...@@ -489,6 +490,14 @@ class Req: ...@@ -489,6 +490,14 @@ class Req:
self.sampling_params = sampling_params self.sampling_params = sampling_params
self.custom_logit_processor = custom_logit_processor self.custom_logit_processor = custom_logit_processor
self.return_hidden_states = return_hidden_states self.return_hidden_states = return_hidden_states
# extra key for classifying the request (e.g. lora_id, cache_salt)
if lora_id is not None:
extra_key = (
extra_key or ""
) + lora_id # lora_id is concatenated to the extra key
self.extra_key = extra_key
self.lora_id = lora_id self.lora_id = lora_id
# Memory pool info # Memory pool info
...@@ -679,26 +688,16 @@ class Req: ...@@ -679,26 +688,16 @@ class Req:
): ):
self.fill_ids = self.origin_input_ids + self.output_ids self.fill_ids = self.origin_input_ids + self.output_ids
if tree_cache is not None: if tree_cache is not None:
if isinstance(tree_cache, LoRARadixCache): (
( self.prefix_indices,
self.prefix_indices, self.last_node,
self.last_node, self.last_host_node,
self.last_host_node, self.host_hit_length,
self.host_hit_length, ) = tree_cache.match_prefix(
) = tree_cache.match_prefix_with_lora_id( key=RadixKey(
key=LoRAKey( token_ids=self.adjust_max_prefix_ids(), extra_key=self.extra_key
lora_id=self.lora_id, token_ids=self.adjust_max_prefix_ids() ),
), )
)
else:
(
self.prefix_indices,
self.last_node,
self.last_host_node,
self.host_hit_length,
) = tree_cache.match_prefix(
key=self.adjust_max_prefix_ids(),
)
self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices) self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
def adjust_max_prefix_ids(self): def adjust_max_prefix_ids(self):
......
...@@ -27,7 +27,7 @@ import torch ...@@ -27,7 +27,7 @@ import torch
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode from sglang.srt.mem_cache.radix_cache import RadixCache, RadixKey, TreeNode
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -175,10 +175,13 @@ class SchedulePolicy: ...@@ -175,10 +175,13 @@ class SchedulePolicy:
for r in waiting_queue: for r in waiting_queue:
prefix_ids = r.adjust_max_prefix_ids() prefix_ids = r.adjust_max_prefix_ids()
extra_key = r.extra_key
# NOTE: the prefix_indices must always be aligned with last_node # NOTE: the prefix_indices must always be aligned with last_node
r.prefix_indices, r.last_node, r.last_host_node, r.host_hit_length = ( r.prefix_indices, r.last_node, r.last_host_node, r.host_hit_length = (
self.tree_cache.match_prefix(rid=r.rid, key=prefix_ids) self.tree_cache.match_prefix(
rid=r.rid, key=RadixKey(token_ids=prefix_ids, extra_key=extra_key)
)
) )
# NOTE(sang): This logic is for in-batch prefix caching; # NOTE(sang): This logic is for in-batch prefix caching;
...@@ -191,7 +194,8 @@ class SchedulePolicy: ...@@ -191,7 +194,8 @@ class SchedulePolicy:
if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD: if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD:
in_batch_matching_prefixes, _, _, _ = ( in_batch_matching_prefixes, _, _, _ = (
self.waiting_queue_radix_tree.match_prefix( self.waiting_queue_radix_tree.match_prefix(
rid=r.rid, key=prefix_ids rid=r.rid,
key=RadixKey(token_ids=prefix_ids, extra_key=extra_key),
) )
) )
if ( if (
...@@ -202,7 +206,8 @@ class SchedulePolicy: ...@@ -202,7 +206,8 @@ class SchedulePolicy:
else: else:
# Insert with a dummy key # Insert with a dummy key
self.waiting_queue_radix_tree.insert( self.waiting_queue_radix_tree.insert(
prefix_ids, torch.empty(len(prefix_ids), dtype=torch.bool) RadixKey(token_ids=prefix_ids, extra_key=extra_key),
torch.empty(len(prefix_ids), dtype=torch.bool),
) )
return temporary_deprioritized return temporary_deprioritized
......
...@@ -145,7 +145,6 @@ from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient ...@@ -145,7 +145,6 @@ from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
from sglang.srt.managers.utils import DPBalanceMeta, validate_input_length from sglang.srt.managers.utils import DPBalanceMeta, validate_input_length
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
from sglang.srt.mem_cache.lora_radix_cache import LoRARadixCache
from sglang.srt.mem_cache.radix_cache import RadixCache from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
...@@ -719,19 +718,6 @@ class Scheduler( ...@@ -719,19 +718,6 @@ class Scheduler(
page_size=self.page_size, page_size=self.page_size,
disable=server_args.disable_radix_cache, disable=server_args.disable_radix_cache,
) )
elif self.enable_lora:
assert (
not self.enable_hierarchical_cache
), "LoRA radix cache doesn't support hierarchical cache"
assert (
self.schedule_policy == "fcfs"
), "LoRA radix cache only supports FCFS policy"
self.tree_cache = LoRARadixCache(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
page_size=self.page_size,
disable=server_args.disable_radix_cache,
)
elif server_args.enable_lmcache: elif server_args.enable_lmcache:
from sglang.srt.mem_cache.storage.lmcache.lmc_radix_cache import ( from sglang.srt.mem_cache.storage.lmcache.lmc_radix_cache import (
LMCRadixCache, LMCRadixCache,
......
...@@ -36,7 +36,7 @@ class BasePrefixCache(ABC): ...@@ -36,7 +36,7 @@ class BasePrefixCache(ABC):
pass pass
@abstractmethod @abstractmethod
def match_prefix(self, key: List[int], **kwargs) -> MatchResult: def match_prefix(self, key: Any, **kwargs) -> MatchResult:
pass pass
@abstractmethod @abstractmethod
......
...@@ -19,7 +19,7 @@ from sglang.srt.mem_cache.memory_pool_host import ( ...@@ -19,7 +19,7 @@ from sglang.srt.mem_cache.memory_pool_host import (
MHATokenToKVPoolHost, MHATokenToKVPoolHost,
MLATokenToKVPoolHost, MLATokenToKVPoolHost,
) )
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode from sglang.srt.mem_cache.radix_cache import RadixCache, RadixKey, TreeNode
from sglang.srt.metrics.collector import StorageMetricsCollector from sglang.srt.metrics.collector import StorageMetricsCollector
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -570,7 +570,9 @@ class HiRadixCache(RadixCache): ...@@ -570,7 +570,9 @@ class HiRadixCache(RadixCache):
written_indices = host_indices[:min_completed_tokens] written_indices = host_indices[:min_completed_tokens]
matched_length = self._insert_helper_host( matched_length = self._insert_helper_host(
last_host_node, last_host_node,
fetched_token_ids, RadixKey(
token_ids=fetched_token_ids, extra_key=last_host_node.key.extra_key
),
written_indices, written_indices,
hash_value[: min_completed_tokens // self.page_size], hash_value[: min_completed_tokens // self.page_size],
) )
...@@ -592,7 +594,7 @@ class HiRadixCache(RadixCache): ...@@ -592,7 +594,7 @@ class HiRadixCache(RadixCache):
return True return True
def match_prefix(self, key: List[int], **kwargs): def match_prefix(self, key: RadixKey, **kwargs):
empty_value = torch.empty((0,), dtype=torch.int64, device=self.device) empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
if self.disable or len(key) == 0: if self.disable or len(key) == 0:
return MatchResult( return MatchResult(
...@@ -666,7 +668,9 @@ class HiRadixCache(RadixCache): ...@@ -666,7 +668,9 @@ class HiRadixCache(RadixCache):
) )
self.cache_controller.prefetch_tokens_occupied += len(new_input_tokens) self.cache_controller.prefetch_tokens_occupied += len(new_input_tokens)
def _insert_helper_host(self, node: TreeNode, key: List, host_value, hash_value): def _insert_helper_host(
self, node: TreeNode, key: RadixKey, host_value, hash_value
):
node.last_access_time = time.monotonic() node.last_access_time = time.monotonic()
if len(key) == 0: if len(key) == 0:
return 0 return 0
...@@ -700,7 +704,7 @@ class HiRadixCache(RadixCache): ...@@ -700,7 +704,7 @@ class HiRadixCache(RadixCache):
node.children[child_key] = new_node node.children[child_key] = new_node
return matched_length return matched_length
def _match_prefix_helper(self, node: TreeNode, key: List): def _match_prefix_helper(self, node: TreeNode, key: RadixKey):
node.last_access_time = time.monotonic() node.last_access_time = time.monotonic()
child_key = self.get_child_key_fn(key) child_key = self.get_child_key_fn(key)
value = [] value = []
...@@ -726,7 +730,7 @@ class HiRadixCache(RadixCache): ...@@ -726,7 +730,7 @@ class HiRadixCache(RadixCache):
return value, node return value, node
def _split_node(self, key, child: TreeNode, split_len: int): def _split_node(self, key: RadixKey, child: TreeNode, split_len: int):
# child node split into new_node -> child # child node split into new_node -> child
new_node = TreeNode() new_node = TreeNode()
new_node.children = {self.get_child_key_fn(key[split_len:]): child} new_node.children = {self.get_child_key_fn(key[split_len:]): child}
...@@ -753,7 +757,7 @@ class HiRadixCache(RadixCache): ...@@ -753,7 +757,7 @@ class HiRadixCache(RadixCache):
new_node.parent.children[self.get_child_key_fn(key)] = new_node new_node.parent.children[self.get_child_key_fn(key)] = new_node
return new_node return new_node
def insert(self, key: List, value, chunked=False): def insert(self, key: RadixKey, value=None, chunked=False):
if len(key) == 0: if len(key) == 0:
return 0 return 0
...@@ -811,7 +815,7 @@ class HiRadixCache(RadixCache): ...@@ -811,7 +815,7 @@ class HiRadixCache(RadixCache):
for idx in range(0, len(key), self.page_size): for idx in range(0, len(key), self.page_size):
new_node.hash_value.append( new_node.hash_value.append(
self.cache_controller.get_hash_str( self.cache_controller.get_hash_str(
key[idx : idx + self.page_size], key.token_ids[idx : idx + self.page_size],
prior_hash=last_hash, prior_hash=last_hash,
) )
) )
......
"""Radix cache for LoRA. It's modified based on RadixCache with lora_id added to the key of nodes."""
import heapq
import time
from collections import defaultdict
from typing import TYPE_CHECKING, Any, List, Optional
import torch
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req
else:
Req = Any # Placeholder for Req type when not type checking
class LoRAKey:
def __init__(self, lora_id: str, token_ids: List[int]):
self.lora_id = (
lora_id # lora_id of adaptor, should be hash value of adaptor path
)
self.token_ids = token_ids # token_ids of the key
def __len__(self):
return len(self.token_ids)
def get_child_key(key: LoRAKey):
# Here the key of children dict is the hash of lora_id + str(token_ids[0])
# So the child key can be matched only when lora_id and token_ids[0] are the same
if key.lora_id is None:
return hash(str(key.token_ids[0]))
else:
return hash(key.lora_id + str(key.token_ids[0]))
class LoRATreeNode:
counter = 0
def __init__(self, id: Optional[int] = None):
self.children = defaultdict(LoRATreeNode)
self.parent: LoRATreeNode = None
self.key: LoRAKey = None
self.value: Optional[torch.Tensor] = None
self.lock_ref = 0
self.last_access_time = time.monotonic()
self.id = LoRATreeNode.counter if id is None else id
LoRATreeNode.counter += 1
@property
def evicted(self):
return self.value is None
def __lt__(self, other: "LoRATreeNode"):
return self.last_access_time < other.last_access_time
def _key_match(key0: LoRAKey, key1: LoRAKey):
if key0.lora_id != key1.lora_id:
raise ValueError(
f"_key_match should be run on the same lora_id, but got key0.lora_id={key0.lora_id} != key1.lora_id={key1.lora_id}"
)
i = 0
for k0, k1 in zip(key0.token_ids, key1.token_ids):
if k0 != k1:
break
i += 1
return i
class LoRARadixCache(BasePrefixCache):
def __init__(
self,
req_to_token_pool: ReqToTokenPool,
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
page_size: int,
disable: bool = False,
):
if page_size > 1:
raise ValueError("LoRARadixCache currently only supports page_size = 1")
if token_to_kv_pool_allocator is None:
raise ValueError(
"token_to_kv_pool_allocator is required to run LoraRadixCache"
)
self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
self.page_size = page_size
self.disable = disable
self.device = self.token_to_kv_pool_allocator.device
self.key_match_fn = _key_match
self.get_child_key_fn = get_child_key
self.reset()
def reset(self):
self.root_node = LoRATreeNode()
self.root_node.key = LoRAKey(lora_id="", token_ids=[])
self.root_node.value = None
self.evictable_size_ = 0
self.protected_size_ = 0
def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
raise ValueError(
"LoRARadixCache needs both token ids and lora id as inputs for matching. Please use match_prefix_with_lora_id instead."
)
def match_prefix_with_lora_id(self, key: LoRAKey, **kwargs) -> MatchResult:
"""Find the matching prefix from the lora radix tree.
Args:
key: A LoRAKey to find a matching prefix.
Returns:
A tuple of a tensor of matching prefix token IDs and
the last node that contains the prefix values. Note that
this API can modify the internal state of the Radix tree.
The last node create a new child if the prefix is shorter
than the last node's value.
"""
if self.disable or len(key) == 0:
return MatchResult(
device_indices=torch.empty(
(0,),
dtype=torch.int64,
device=self.device,
),
last_device_node=self.root_node,
last_host_node=self.root_node,
)
value, last_node = self._match_prefix_helper(self.root_node, key)
if value:
value = torch.cat(value)
else:
value = torch.empty((0,), dtype=torch.int64, device=self.device)
return MatchResult(
device_indices=value,
last_device_node=last_node,
last_host_node=last_node,
)
def insert(self, key: LoRAKey, value=None):
if self.disable:
return 0
if value is None:
value = [x for x in key.token_ids]
return self._insert_helper(self.root_node, key, value)
def cache_finished_req(self, req: Req):
"""Cache request when it finishes."""
if self.disable:
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : len(req.origin_input_ids) + len(req.output_ids) - 1
]
self.token_to_kv_pool_allocator.free(kv_indices)
self.req_to_token_pool.free(req.req_pool_idx)
return
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : len(token_ids)
]
page_aligned_len = len(kv_indices)
page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
# Radix Cache takes one ref in memory pool
lora_key = LoRAKey(lora_id=req.lora_id, token_ids=token_ids[:page_aligned_len])
new_prefix_len = self.insert(lora_key, page_aligned_kv_indices)
self.token_to_kv_pool_allocator.free(
kv_indices[len(req.prefix_indices) : new_prefix_len]
)
# Remove req slot release the cache lock
self.req_to_token_pool.free(req.req_pool_idx)
self.dec_lock_ref(req.last_node)
def cache_unfinished_req(self, req: Req, chunked=False):
"""Cache request when it is unfinished."""
if self.disable:
return
token_ids = req.fill_ids
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : len(token_ids)
]
page_aligned_len = len(kv_indices)
page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
page_aligned_token_ids = token_ids[:page_aligned_len]
# Radix Cache takes one ref in memory pool
inserted_key = LoRAKey(lora_id=req.lora_id, token_ids=page_aligned_token_ids)
new_prefix_len = self.insert(inserted_key, page_aligned_kv_indices)
self.token_to_kv_pool_allocator.free(
kv_indices[len(req.prefix_indices) : new_prefix_len]
)
# The prefix indices could be updated, reuse it
new_indices, new_last_node, _, _ = self.match_prefix_with_lora_id(inserted_key)
self.req_to_token_pool.write(
(req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
new_indices[len(req.prefix_indices) :],
)
self.dec_lock_ref(req.last_node)
self.inc_lock_ref(new_last_node)
# `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
req.prefix_indices = new_indices
req.last_node = new_last_node
def pretty_print(self):
self._print_helper(self.root_node, 0)
print(f"#tokens: {self.total_size()}")
def total_size(self):
return self._total_size_helper()
def evict(self, num_tokens: int):
if self.disable:
return
leaves = self._collect_leaves()
heapq.heapify(leaves)
num_evicted = 0
while num_evicted < num_tokens and len(leaves):
x = heapq.heappop(leaves)
if x == self.root_node:
break
if x.lock_ref > 0:
continue
self.token_to_kv_pool_allocator.free(x.value)
num_evicted += len(x.value)
self._delete_leaf(x)
if len(x.parent.children) == 0:
heapq.heappush(leaves, x.parent)
def inc_lock_ref(self, node: LoRATreeNode):
if self.disable:
return 0
delta = 0
while node != self.root_node:
if node.lock_ref == 0:
self.evictable_size_ -= len(node.value)
self.protected_size_ += len(node.value)
delta -= len(node.value)
node.lock_ref += 1
node = node.parent
return delta
def dec_lock_ref(self, node: LoRATreeNode):
if self.disable:
return 0
delta = 0
while node != self.root_node:
if node.lock_ref == 1:
self.evictable_size_ += len(node.value)
self.protected_size_ -= len(node.value)
delta += len(node.value)
node.lock_ref -= 1
node = node.parent
return delta
def evictable_size(self):
return self.evictable_size_
def protected_size(self):
# protected size refers to the size of the cache that is locked
return self.protected_size_
def all_values_flatten(self):
values = []
def _dfs_helper(node: LoRATreeNode):
for _, child in node.children.items():
values.append(child.value)
_dfs_helper(child)
_dfs_helper(self.root_node)
return torch.cat(values)
##### Internal Helper Functions #####
def _match_prefix_helper(self, node: LoRATreeNode, key: LoRAKey):
node.last_access_time = time.monotonic()
child_key = self.get_child_key_fn(key)
value = []
while len(key) > 0 and child_key in node.children.keys():
child = node.children[child_key]
child.last_access_time = time.monotonic()
prefix_len = self.key_match_fn(child.key, key)
if prefix_len < len(child.key):
new_node = self._split_node(child.key, child, prefix_len)
value.append(new_node.value)
node = new_node
break
else:
value.append(child.value)
node = child
key = LoRAKey(lora_id=key.lora_id, token_ids=key.token_ids[prefix_len:])
if len(key):
child_key = self.get_child_key_fn(key)
return value, node
def _split_node(self, key: LoRAKey, child: LoRATreeNode, split_len: int):
# new_node -> child
new_node = LoRATreeNode()
key_split_1 = LoRAKey(lora_id=key.lora_id, token_ids=key.token_ids[:split_len])
key_split_2 = LoRAKey(lora_id=key.lora_id, token_ids=key.token_ids[split_len:])
new_node.children = {self.get_child_key_fn(key_split_2): child}
new_node.parent = child.parent
new_node.lock_ref = child.lock_ref
new_node.key = key_split_1
new_node.value = child.value[:split_len]
child.parent = new_node
child.key = key_split_2
child.value = child.value[split_len:]
new_node.parent.children[self.get_child_key_fn(key)] = new_node
return new_node
def _insert_helper(self, node: LoRATreeNode, key: LoRAKey, value):
node.last_access_time = time.monotonic()
if len(key) == 0:
return 0
child_key = self.get_child_key_fn(key)
total_prefix_length = 0
while len(key) > 0 and child_key in node.children.keys():
node = node.children[child_key]
node.last_access_time = time.monotonic()
prefix_len = self.key_match_fn(node.key, key)
total_prefix_length += prefix_len
key = LoRAKey(lora_id=key.lora_id, token_ids=key.token_ids[prefix_len:])
value = value[prefix_len:]
if prefix_len < len(node.key):
new_node = self._split_node(node.key, node, prefix_len)
node = new_node
if len(key):
child_key = self.get_child_key_fn(key)
if len(key):
new_node = LoRATreeNode()
new_node.parent = node
new_node.key = key
new_node.value = value
node.children[child_key] = new_node
self.evictable_size_ += len(value)
return total_prefix_length
def _print_helper(self, node: LoRATreeNode, indent: int):
"""Prints the radix tree in a human-readable format."""
stack = [(node, indent)]
while stack:
current_node, current_indent = stack.pop()
print(
" " * current_indent,
len(current_node.key),
current_node.key.token_ids[:10],
f"r={current_node.lock_ref}",
)
for key, child in current_node.children.items():
stack.append((child, current_indent + 2))
assert key == self.get_child_key_fn(
child.key
), f"{key=}, {self.get_child_key_fn(child.key)=}"
def _delete_leaf(self, node):
for k, v in node.parent.children.items():
if v == node:
break
del node.parent.children[k]
self.evictable_size_ -= len(node.key)
def _total_size_helper(self):
total_size = 0
stack = [self.root_node]
while stack:
current_node = stack.pop()
total_size += len(current_node.value)
for child in current_node.children.values():
if child.evicted:
continue
stack.append(child)
return total_size
def _collect_leaves(self):
ret_list = []
stack = [self.root_node]
while stack:
cur_node = stack.pop()
if len(cur_node.children) == 0:
ret_list.append(cur_node)
else:
stack.extend(cur_node.children.values())
return ret_list
...@@ -23,7 +23,7 @@ import heapq ...@@ -23,7 +23,7 @@ import heapq
import time import time
from collections import defaultdict from collections import defaultdict
from functools import partial from functools import partial
from typing import TYPE_CHECKING, List, Optional from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Union
import torch import torch
...@@ -41,6 +41,30 @@ if TYPE_CHECKING: ...@@ -41,6 +41,30 @@ if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req from sglang.srt.managers.schedule_batch import Req
class RadixKey:
def __init__(self, token_ids: List[int], extra_key: Optional[str] = None):
# token ids sequence
self.token_ids = token_ids
# extra key (e.g. lora_id, cache_salt)
self.extra_key = extra_key
def __len__(self) -> int:
return len(self.token_ids)
def __iter__(self) -> Iterator[int]:
return iter(self.token_ids)
def __getitem__(self, idx: Union[int, slice]) -> "RadixKey":
if isinstance(idx, slice):
return RadixKey(self.token_ids[idx], self.extra_key)
return RadixKey([self.token_ids[idx]], self.extra_key)
def __repr__(self) -> str:
preview = self.token_ids[:10]
return f"RadixKey(extra_key={self.extra_key!r}, token_ids={preview}{'...' if len(self.token_ids) > 10 else ''})"
class TreeNode: class TreeNode:
counter = 0 counter = 0
...@@ -48,7 +72,7 @@ class TreeNode: ...@@ -48,7 +72,7 @@ class TreeNode:
def __init__(self, id: Optional[int] = None): def __init__(self, id: Optional[int] = None):
self.children = defaultdict(TreeNode) self.children = defaultdict(TreeNode)
self.parent: TreeNode = None self.parent: TreeNode = None
self.key: List[int] = None self.key: RadixKey = None
self.value: Optional[torch.Tensor] = None self.value: Optional[torch.Tensor] = None
self.lock_ref = 0 self.lock_ref = 0
self.last_access_time = time.monotonic() self.last_access_time = time.monotonic()
...@@ -94,27 +118,47 @@ class TreeNode: ...@@ -94,27 +118,47 @@ class TreeNode:
return self.last_access_time < other.last_access_time return self.last_access_time < other.last_access_time
def _key_match_page_size1(key0: List, key1: List): def _check_extra_key(key0: RadixKey, key1: RadixKey):
if key0.extra_key != key1.extra_key:
raise ValueError(
f"_key_match should be run on the same extra key, but got key0.extra_key={key0.extra_key} != key1.extra_key={key1.extra_key}"
)
def _key_match_page_size1(key0: RadixKey, key1: RadixKey):
_check_extra_key(key0, key1)
i = 0 i = 0
for k0, k1 in zip(key0, key1): for k0, k1 in zip(key0.token_ids, key1.token_ids):
if k0 != k1: if k0 != k1:
break break
i += 1 i += 1
return i return i
def _key_match_paged(key0: List, key1: List, page_size: int): def _key_match_paged(key0: RadixKey, key1: RadixKey, page_size: int):
_check_extra_key(key0, key1)
min_len = min(len(key0), len(key1)) min_len = min(len(key0), len(key1))
i = 0 i = 0
while i < min_len: while i < min_len:
if key0[i : i + page_size] != key1[i : i + page_size]: if key0.token_ids[i : i + page_size] != key1.token_ids[i : i + page_size]:
break break
i += page_size i += page_size
return i return i
def get_child_key(key: RadixKey, page_size: int = 1):
if page_size == 1:
plain_key = key.token_ids[0]
else:
plain_key = tuple(key.token_ids[:page_size])
if key.extra_key is None:
return plain_key
else:
return (key.extra_key, plain_key)
class RadixCache(BasePrefixCache): class RadixCache(BasePrefixCache):
def __init__( def __init__(
self, self,
...@@ -139,10 +183,10 @@ class RadixCache(BasePrefixCache): ...@@ -139,10 +183,10 @@ class RadixCache(BasePrefixCache):
if self.page_size == 1: if self.page_size == 1:
self.key_match_fn = _key_match_page_size1 self.key_match_fn = _key_match_page_size1
self.get_child_key_fn = lambda key: key[0] self.get_child_key_fn = get_child_key
else: else:
self.key_match_fn = partial(_key_match_paged, page_size=page_size) self.key_match_fn = partial(_key_match_paged, page_size=page_size)
self.get_child_key_fn = lambda key: tuple(key[:page_size]) self.get_child_key_fn = partial(get_child_key, page_size=page_size)
if eviction_policy.lower() == "lru": if eviction_policy.lower() == "lru":
self.eviction_strategy: EvictionStrategy = LRUStrategy() self.eviction_strategy: EvictionStrategy = LRUStrategy()
...@@ -158,7 +202,7 @@ class RadixCache(BasePrefixCache): ...@@ -158,7 +202,7 @@ class RadixCache(BasePrefixCache):
def reset(self): def reset(self):
self.root_node = TreeNode() self.root_node = TreeNode()
self.root_node.key = [] self.root_node.key = RadixKey(token_ids=[], extra_key=None)
self.root_node.value = [] self.root_node.value = []
self.root_node.host_value = [] self.root_node.host_value = []
self.root_node.lock_ref = 1 self.root_node.lock_ref = 1
...@@ -166,16 +210,43 @@ class RadixCache(BasePrefixCache): ...@@ -166,16 +210,43 @@ class RadixCache(BasePrefixCache):
self.protected_size_ = 0 self.protected_size_ = 0
self._record_all_cleared_event() self._record_all_cleared_event()
def match_prefix(self, key: List[int], **kwargs) -> MatchResult: def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult:
"""Find the matching prefix from the radix tree. """Find the longest cached prefix of ``key`` in the radix tree.
The logical namespace for prefix matching is determined by both the
token id sequence and the optional ``extra_key`` carried by ``RadixKey``.
Entries that share identical leading token ids but have *different*
``extra_key`` values are intentionally kept disjoint and never share
prefix nodes. This is useful to:
* Isolate KV cache lines for different LoRA / adapter IDs.
* Separate requests that intentionally should not share state (e.g.,
different sampling salt, cache version, or retrieval augmentation
context) by supplying a distinct ``extra_key``.
Args: Args:
key: A list of token IDs to find a matching prefix. key (RadixKey): The lookup key containing a list of token ids and an
optional ``extra_key`` namespace tag. If ``page_size > 1`` the
length is internally truncated to a multiple of ``page_size``
before matching. Passing an empty key returns an empty result
with the root as the last node.
**kwargs: Reserved for future extensions (ignored currently).
Returns: Returns:
A tuple of a tensor of matching prefix token IDs and MatchResult: ``device_indices`` is a 1-D ``torch.int64`` tensor of
the last node that contains the prefix values. Note that the concatenated KV cache indices corresponding to the longest
this API can modify the internal state of the Radix tree. cached prefix (may be length 0). ``last_device_node`` and
The last node create a new child if the prefix is shorter ``last_host_node`` (currently the same) are the tree node objects
than the last node's value. representing the terminal node of the matched prefix. This method
may mutate internal structure by splitting an existing node if the
match ends inside a stored segment.
Internal updates:
* Refreshes access metadata (timestamps) used by the
configured eviction strategy.
* If the lookup ends inside a stored segment the node is split once
to expose a precise boundary; this structural refinement improves
subsequent match efficiency and does not duplicate data.
""" """
if self.disable or len(key) == 0: if self.disable or len(key) == 0:
return MatchResult( return MatchResult(
...@@ -203,12 +274,12 @@ class RadixCache(BasePrefixCache): ...@@ -203,12 +274,12 @@ class RadixCache(BasePrefixCache):
last_host_node=last_node, last_host_node=last_node,
) )
def insert(self, key: List, value=None, chunked=False): def insert(self, key: RadixKey, value=None, chunked=False):
if self.disable: if self.disable:
return 0 return 0
if value is None: if value is None:
value = [x for x in key] value = torch.tensor(key.token_ids, dtype=torch.int64)
return self._insert_helper(self.root_node, key, value) return self._insert_helper(self.root_node, key, value)
def cache_finished_req(self, req: Req): def cache_finished_req(self, req: Req):
...@@ -238,7 +309,8 @@ class RadixCache(BasePrefixCache): ...@@ -238,7 +309,8 @@ class RadixCache(BasePrefixCache):
# Radix Cache takes one ref in memory pool # Radix Cache takes one ref in memory pool
new_prefix_len = self.insert( new_prefix_len = self.insert(
token_ids[:page_aligned_len], page_aligned_kv_indices RadixKey(token_ids[:page_aligned_len], req.extra_key),
page_aligned_kv_indices,
) )
self.token_to_kv_pool_allocator.free( self.token_to_kv_pool_allocator.free(
kv_indices[len(req.prefix_indices) : new_prefix_len] kv_indices[len(req.prefix_indices) : new_prefix_len]
...@@ -270,14 +342,18 @@ class RadixCache(BasePrefixCache): ...@@ -270,14 +342,18 @@ class RadixCache(BasePrefixCache):
# Radix Cache takes one ref in memory pool # Radix Cache takes one ref in memory pool
new_prefix_len = self.insert( new_prefix_len = self.insert(
page_aligned_token_ids, page_aligned_kv_indices, chunked=chunked RadixKey(page_aligned_token_ids, req.extra_key),
page_aligned_kv_indices,
chunked=chunked,
) )
self.token_to_kv_pool_allocator.free( self.token_to_kv_pool_allocator.free(
kv_indices[len(req.prefix_indices) : new_prefix_len] kv_indices[len(req.prefix_indices) : new_prefix_len]
) )
# The prefix indices could be updated, reuse it # The prefix indices could be updated, reuse it
new_indices, new_last_node, _, _ = self.match_prefix(page_aligned_token_ids) new_indices, new_last_node, _, _ = self.match_prefix(
RadixKey(token_ids=page_aligned_token_ids, extra_key=req.extra_key)
)
self.req_to_token_pool.write( self.req_to_token_pool.write(
(req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))), (req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
new_indices[len(req.prefix_indices) :], new_indices[len(req.prefix_indices) :],
...@@ -379,7 +455,7 @@ class RadixCache(BasePrefixCache): ...@@ -379,7 +455,7 @@ class RadixCache(BasePrefixCache):
##### Internal Helper Functions ##### ##### Internal Helper Functions #####
def _match_prefix_helper(self, node: TreeNode, key: List): def _match_prefix_helper(self, node: TreeNode, key: RadixKey):
node.last_access_time = time.monotonic() node.last_access_time = time.monotonic()
child_key = self.get_child_key_fn(key) child_key = self.get_child_key_fn(key)
...@@ -404,7 +480,7 @@ class RadixCache(BasePrefixCache): ...@@ -404,7 +480,7 @@ class RadixCache(BasePrefixCache):
return value, node return value, node
def _split_node(self, key, child: TreeNode, split_len: int): def _split_node(self, key: RadixKey, child: TreeNode, split_len: int):
# new_node -> child # new_node -> child
self._record_remove_event(child) self._record_remove_event(child)
new_node = TreeNode() new_node = TreeNode()
...@@ -423,7 +499,7 @@ class RadixCache(BasePrefixCache): ...@@ -423,7 +499,7 @@ class RadixCache(BasePrefixCache):
return new_node return new_node
def _insert_helper(self, node: TreeNode, key: List, value): def _insert_helper(self, node: TreeNode, key: RadixKey, value):
node.last_access_time = time.monotonic() node.last_access_time = time.monotonic()
if len(key) == 0: if len(key) == 0:
return 0 return 0
...@@ -464,7 +540,7 @@ class RadixCache(BasePrefixCache): ...@@ -464,7 +540,7 @@ class RadixCache(BasePrefixCache):
print( print(
" " * current_indent, " " * current_indent,
len(current_node.key), len(current_node.key),
current_node.key[:10], current_node.key.token_ids[:10],
f"r={current_node.lock_ref}", f"r={current_node.lock_ref}",
) )
for key, child in current_node.children.items(): for key, child in current_node.children.items():
...@@ -516,11 +592,11 @@ class RadixCache(BasePrefixCache): ...@@ -516,11 +592,11 @@ class RadixCache(BasePrefixCache):
last_page_start = ( last_page_start = (
(len(node.parent.key) - 1) // self.page_size (len(node.parent.key) - 1) // self.page_size
) * self.page_size ) * self.page_size
parent_parent_tokens = node.parent.key[last_page_start:] parent_parent_tokens = node.parent.key.token_ids[last_page_start:]
parent_block_hash = hash(tuple(parent_parent_tokens)) parent_block_hash = hash(tuple(parent_parent_tokens))
for start in range(0, len(node.key), self.page_size): for start in range(0, len(node.key), self.page_size):
page_tokens = node.key[start : start + self.page_size] page_tokens = node.key.token_ids[start : start + self.page_size]
if not page_tokens: if not page_tokens:
continue continue
...@@ -543,7 +619,7 @@ class RadixCache(BasePrefixCache): ...@@ -543,7 +619,7 @@ class RadixCache(BasePrefixCache):
# One BlockRemoved per chunk. # One BlockRemoved per chunk.
if self.enable_kv_cache_events: if self.enable_kv_cache_events:
for start in range(0, len(node.key), self.page_size): for start in range(0, len(node.key), self.page_size):
page_tokens = node.key[start : start + self.page_size] page_tokens = node.key.token_ids[start : start + self.page_size]
if not page_tokens: if not page_tokens:
continue continue
block_hash = hash(tuple(page_tokens)) block_hash = hash(tuple(page_tokens))
...@@ -569,19 +645,12 @@ class RadixCache(BasePrefixCache): ...@@ -569,19 +645,12 @@ class RadixCache(BasePrefixCache):
if __name__ == "__main__": if __name__ == "__main__":
tree = RadixCache(None, None, page_size=1, disable=False) tree = RadixCache(None, None, page_size=1, disable=False)
tree.insert("Hello") # Example token id sequences (as lists of ints)
tree.insert("Hello") tree.insert(RadixKey(token_ids=[1, 2, 3], extra_key=None))
tree.insert("Hello_L.A.!") tree.insert(RadixKey(token_ids=[1, 2, 3], extra_key=None))
# tree.insert("Hello_world! Happy") tree.insert(RadixKey(token_ids=[1, 2, 4, 5], extra_key=None))
# tree.insert("I love you!") tree.insert(RadixKey(token_ids=[1, 2, 4, 5, 6, 7], extra_key=None))
tree.insert(RadixKey(token_ids=[8, 9, 10, 11, 12], extra_key=None))
tree.pretty_print() tree.pretty_print()
# print(tree.match_prefix("I love you! aha")) print(tree.match_prefix(RadixKey(token_ids=[1, 2, 3, 13, 14], extra_key=None)))
# def evict_callback(x):
# print("evict", x)
# return len(x)
# tree.evict(5, evict_callback)
# tree.evict(10, evict_callback)
# tree.pretty_print()
...@@ -13,6 +13,7 @@ from sglang.srt.mem_cache.cpp_radix_tree.radix_tree import ( ...@@ -13,6 +13,7 @@ from sglang.srt.mem_cache.cpp_radix_tree.radix_tree import (
TreeNodeCpp, TreeNodeCpp,
) )
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.mem_cache.radix_cache import RadixKey
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req from sglang.srt.managers.schedule_batch import Req
...@@ -93,9 +94,9 @@ class RadixCacheCpp(BasePrefixCache): ...@@ -93,9 +94,9 @@ class RadixCacheCpp(BasePrefixCache):
raise NotImplementedError("Host cache is not supported yet") raise NotImplementedError("Host cache is not supported yet")
self.tree.reset() self.tree.reset()
def match_prefix(self, key: List[int], **kwargs) -> MatchResult: def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult:
device_indices_vec, host_indices_length, node_gpu, node_cpu = ( device_indices_vec, host_indices_length, node_gpu, node_cpu = (
self.tree.match_prefix(key) self.tree.match_prefix(key.token_ids)
) )
return MatchResult( return MatchResult(
device_indices=self._merge_tensor(device_indices_vec), device_indices=self._merge_tensor(device_indices_vec),
...@@ -104,16 +105,16 @@ class RadixCacheCpp(BasePrefixCache): ...@@ -104,16 +105,16 @@ class RadixCacheCpp(BasePrefixCache):
host_hit_length=host_indices_length, host_hit_length=host_indices_length,
) )
def _insert(self, key: List[int], value: torch.Tensor) -> int: def _insert(self, key: RadixKey, value: torch.Tensor) -> int:
""" """
Insert a key-value pair into the radix tree. Insert a key-value pair into the radix tree.
Args: Args:
key (List[int]): The key to insert, represented as a list of integers. key (RadixKey): The key to insert, represented as a RadixKey.
value (torch.Tensor): The value to associate with the key. value (torch.Tensor): The value to associate with the key.
Returns: Returns:
int: Number of device indices that were already present in the tree before the insertion. int: Number of device indices that were already present in the tree before the insertion.
""" """
ongoing_write, length = self.tree.writing_through(key, value) ongoing_write, length = self.tree.writing_through(key.token_ids, value)
if self.cache_controller is None: if self.cache_controller is None:
assert len(ongoing_write) == 0, "Implementation error" assert len(ongoing_write) == 0, "Implementation error"
return length return length
...@@ -160,7 +161,7 @@ class RadixCacheCpp(BasePrefixCache): ...@@ -160,7 +161,7 @@ class RadixCacheCpp(BasePrefixCache):
# NOTE: our C++ implementation don't need `token_ids` and `kv_indices` to be page-aligned # NOTE: our C++ implementation don't need `token_ids` and `kv_indices` to be page-aligned
# it will automatically align them, but length of them should be equal # it will automatically align them, but length of them should be equal
old_prefix_len = len(req.prefix_indices) // self.page_size * self.page_size old_prefix_len = len(req.prefix_indices) // self.page_size * self.page_size
new_prefix_len = self._insert(token_ids, kv_indices) new_prefix_len = self._insert(RadixKey(token_ids, req.extra_key), kv_indices)
# NOTE: kv_indices[:old_prefix_len] == req.prefix_indices # NOTE: kv_indices[:old_prefix_len] == req.prefix_indices
assert old_prefix_len <= new_prefix_len, "Wrong prefix indices" assert old_prefix_len <= new_prefix_len, "Wrong prefix indices"
...@@ -191,14 +192,16 @@ class RadixCacheCpp(BasePrefixCache): ...@@ -191,14 +192,16 @@ class RadixCacheCpp(BasePrefixCache):
# NOTE: our C++ implementation don't need `token_ids` and `kv_indices` to be page-aligned # NOTE: our C++ implementation don't need `token_ids` and `kv_indices` to be page-aligned
# it will automatically align them, but length of them should be equal # it will automatically align them, but length of them should be equal
old_prefix_len = len(req.prefix_indices) // self.page_size * self.page_size old_prefix_len = len(req.prefix_indices) // self.page_size * self.page_size
new_prefix_len = self._insert(token_ids, kv_indices) new_prefix_len = self._insert(RadixKey(token_ids, req.extra_key), kv_indices)
# NOTE: kv_indices[:old_prefix_len] == req.prefix_indices # NOTE: kv_indices[:old_prefix_len] == req.prefix_indices
assert old_prefix_len <= new_prefix_len, "Wrong prefix indices" assert old_prefix_len <= new_prefix_len, "Wrong prefix indices"
# TODO(dark): optimize the `insert` and `match` (e.g. merge into 1 function) # TODO(dark): optimize the `insert` and `match` (e.g. merge into 1 function)
# The prefix indices need to updated to reuse the kv indices in the pool # The prefix indices need to updated to reuse the kv indices in the pool
new_indices_vec, _, new_last_node, _ = self.tree.match_prefix(token_ids) new_indices_vec, _, new_last_node, _ = self.tree.match_prefix(
RadixKey(token_ids, req.extra_key).token_ids
)
new_indices = self._merge_tensor(new_indices_vec) new_indices = self._merge_tensor(new_indices_vec)
assert new_prefix_len <= len(new_indices) assert new_prefix_len <= len(new_indices)
......
...@@ -9,7 +9,7 @@ import torch ...@@ -9,7 +9,7 @@ import torch
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import MatchResult from sglang.srt.mem_cache.base_prefix_cache import MatchResult
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode from sglang.srt.mem_cache.radix_cache import RadixCache, RadixKey, TreeNode
try: try:
from lmcache.integration.sglang.sglang_adapter import ( from lmcache.integration.sglang.sglang_adapter import (
...@@ -131,7 +131,7 @@ class LMCRadixCache(RadixCache): ...@@ -131,7 +131,7 @@ class LMCRadixCache(RadixCache):
with self._node_lock: with self._node_lock:
self._in_flight_nodes.clear() self._in_flight_nodes.clear()
def match_prefix(self, key: List[int], **kwargs) -> MatchResult: # type: ignore[override] def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult: # type: ignore[override]
"""Match cached prefix; if there's a tail miss, prefetch from LMCache. """Match cached prefix; if there's a tail miss, prefetch from LMCache.
Reuses the base matching logic to obtain (value, last_node). If there Reuses the base matching logic to obtain (value, last_node). If there
...@@ -178,7 +178,7 @@ class LMCRadixCache(RadixCache): ...@@ -178,7 +178,7 @@ class LMCRadixCache(RadixCache):
with torch.cuda.stream(self.load_stream): with torch.cuda.stream(self.load_stream):
num_retrieved = self.lmcache_connector.start_load_kv( num_retrieved = self.lmcache_connector.start_load_kv(
LoadMetadata( LoadMetadata(
token_ids=key, # full page-aligned key token_ids=key.token_ids, # full page-aligned key
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
offset=value.numel() - prefix_pad, # LMCache offset convention offset=value.numel() - prefix_pad, # LMCache offset convention
) )
...@@ -227,7 +227,7 @@ class LMCRadixCache(RadixCache): ...@@ -227,7 +227,7 @@ class LMCRadixCache(RadixCache):
req.req_pool_idx, : len(token_ids) req.req_pool_idx, : len(token_ids)
] ]
_, new_last_node, _, _ = self.match_prefix(token_ids) _, new_last_node, _, _ = self.match_prefix(RadixKey(token_ids, req.extra_key))
assert new_last_node is not None assert new_last_node is not None
self.inc_lock_ref(new_last_node) self.inc_lock_ref(new_last_node)
...@@ -277,6 +277,8 @@ if __name__ == "__main__": ...@@ -277,6 +277,8 @@ if __name__ == "__main__":
rank=0, rank=0,
tp_group=None, tp_group=None,
) )
cache.insert([1, 2, 3], torch.tensor([10, 11, 12], dtype=torch.int64)) cache.insert(RadixKey([1, 2, 3]), torch.tensor([10, 11, 12], dtype=torch.int64))
cache.insert([1, 2, 3, 4], torch.tensor([10, 11, 12, 13], dtype=torch.int64)) cache.insert(
RadixKey([1, 2, 3, 4]), torch.tensor([10, 11, 12, 13], dtype=torch.int64)
)
cache.pretty_print() cache.pretty_print()
...@@ -30,6 +30,12 @@ import torch ...@@ -30,6 +30,12 @@ import torch
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.mem_cache.radix_cache import (
RadixKey,
_key_match_page_size1,
_key_match_paged,
get_child_key,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req from sglang.srt.managers.schedule_batch import Req
...@@ -47,7 +53,7 @@ class TreeNode: ...@@ -47,7 +53,7 @@ class TreeNode:
def __init__(self, id: Optional[int] = None): def __init__(self, id: Optional[int] = None):
self.children = defaultdict(TreeNode) self.children = defaultdict(TreeNode)
self.parent: TreeNode = None self.parent: TreeNode = None
self.key: List[int] = None self.key: RadixKey = None
self.value: Optional[torch.Tensor] = None self.value: Optional[torch.Tensor] = None
# swa_tombstone is used to indicate the kv indices have been freed for swa layers # swa_tombstone is used to indicate the kv indices have been freed for swa layers
self.swa_tombstone = False self.swa_tombstone = False
...@@ -87,27 +93,6 @@ class TreeNode: ...@@ -87,27 +93,6 @@ class TreeNode:
return self.last_access_time < other.last_access_time return self.last_access_time < other.last_access_time
def _key_match_page_size1(key0: List, key1: List):
i = 0
for k0, k1 in zip(key0, key1):
if k0 != k1:
break
i += 1
return i
def _key_match_paged(key0: List, key1: List, page_size: int):
min_len = min(len(key0), len(key1))
i = 0
while i < min_len:
if key0[i : i + page_size] != key1[i : i + page_size]:
break
i += page_size
return i
def gen_swa_uuid() -> int: def gen_swa_uuid() -> int:
TreeNode.swa_uuid_counter += 1 TreeNode.swa_uuid_counter += 1
return TreeNode.swa_uuid_counter return TreeNode.swa_uuid_counter
...@@ -356,10 +341,10 @@ class SWARadixCache(BasePrefixCache): ...@@ -356,10 +341,10 @@ class SWARadixCache(BasePrefixCache):
if self.page_size == 1: if self.page_size == 1:
self.key_match_fn = _key_match_page_size1 self.key_match_fn = _key_match_page_size1
self.get_child_key_fn = lambda key: key[0] self.get_child_key_fn = get_child_key
else: else:
self.key_match_fn = partial(_key_match_paged, page_size=page_size) self.key_match_fn = partial(_key_match_paged, page_size=page_size)
self.get_child_key_fn = lambda key: tuple(key[:page_size]) self.get_child_key_fn = partial(get_child_key, page_size=page_size)
self.sliding_window_size = sliding_window_size self.sliding_window_size = sliding_window_size
self.reset() self.reset()
...@@ -380,10 +365,10 @@ class SWARadixCache(BasePrefixCache): ...@@ -380,10 +365,10 @@ class SWARadixCache(BasePrefixCache):
self.full_lru_list = LRUList(swa=False) self.full_lru_list = LRUList(swa=False)
self.swa_lru_list = LRUList(swa=True) self.swa_lru_list = LRUList(swa=True)
def match_prefix(self, key: List[int], **kwargs) -> MatchResult: def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult:
"""Find the matching prefix from the radix tree. """Find the matching prefix from the radix tree.
Args: Args:
key: A list of token IDs to find a matching prefix. key: A RadixKey contains token IDs to find a matching prefix.
Returns: Returns:
A tuple of a tensor of matching prefix token IDs and A tuple of a tensor of matching prefix token IDs and
the last node that contains the prefix values. Note that the last node that contains the prefix values. Note that
...@@ -417,12 +402,12 @@ class SWARadixCache(BasePrefixCache): ...@@ -417,12 +402,12 @@ class SWARadixCache(BasePrefixCache):
last_host_node=last_node, last_host_node=last_node,
) )
def insert(self, key: List, value=None, prev_prefix_len: int = 0) -> int: def insert(self, key: RadixKey, value=None, prev_prefix_len: int = 0) -> int:
if self.disable: if self.disable:
return 0 return 0
if value is None: if value is None:
value = [x for x in key] value = torch.tensor([x for x in key.token_ids], dtype=torch.int64)
return self._insert_helper(self.root_node, key, value, prev_prefix_len) return self._insert_helper(self.root_node, key, value, prev_prefix_len)
def cache_finished_req(self, req: Req) -> None: def cache_finished_req(self, req: Req) -> None:
...@@ -453,7 +438,7 @@ class SWARadixCache(BasePrefixCache): ...@@ -453,7 +438,7 @@ class SWARadixCache(BasePrefixCache):
# insert the token_ids and kv_indices into the radix tree # insert the token_ids and kv_indices into the radix tree
# Note: the insert function already frees the overlapped kv_indices # Note: the insert function already frees the overlapped kv_indices
new_prefix_len = self.insert( new_prefix_len = self.insert(
token_ids[:page_aligned_len], RadixKey(token_ids[:page_aligned_len], req.extra_key),
page_aligned_kv_indices, page_aligned_kv_indices,
len(req.prefix_indices), len(req.prefix_indices),
) )
...@@ -489,11 +474,15 @@ class SWARadixCache(BasePrefixCache): ...@@ -489,11 +474,15 @@ class SWARadixCache(BasePrefixCache):
# Radix Cache takes one ref in memory pool # Radix Cache takes one ref in memory pool
# Note: the insert function already frees the overlapped kv_indices # Note: the insert function already frees the overlapped kv_indices
new_prefix_len = self.insert( new_prefix_len = self.insert(
page_aligned_token_ids, page_aligned_kv_indices, len(req.prefix_indices) RadixKey(page_aligned_token_ids, req.extra_key),
page_aligned_kv_indices,
len(req.prefix_indices),
) )
# The prefix indices could be updated, reuse it # The prefix indices could be updated, reuse it
new_indices, new_last_node, _, _ = self.match_prefix(page_aligned_token_ids) new_indices, new_last_node, _, _ = self.match_prefix(
RadixKey(page_aligned_token_ids, req.extra_key)
)
assert len(req.prefix_indices) <= len( assert len(req.prefix_indices) <= len(
new_indices new_indices
), f"{req.prefix_indices=}, {new_indices=}" ), f"{req.prefix_indices=}, {new_indices=}"
...@@ -732,7 +721,9 @@ class SWARadixCache(BasePrefixCache): ...@@ -732,7 +721,9 @@ class SWARadixCache(BasePrefixCache):
##### Internal Helper Functions ##### ##### Internal Helper Functions #####
def _match_prefix_helper(self, key: List) -> Tuple[List[torch.Tensor], TreeNode]: def _match_prefix_helper(
self, key: RadixKey
) -> Tuple[List[torch.Tensor], TreeNode]:
""" """
SWA prefix matching helper. It factors in the sliding window size such that SWA prefix matching helper. It factors in the sliding window size such that
the matched node is guaranteed to either 1. connected to root without swa tombstone, the matched node is guaranteed to either 1. connected to root without swa tombstone,
...@@ -796,7 +787,7 @@ class SWARadixCache(BasePrefixCache): ...@@ -796,7 +787,7 @@ class SWARadixCache(BasePrefixCache):
return value[:best_value_len], best_last_node return value[:best_value_len], best_last_node
def _split_node(self, key: List[int], child: TreeNode, split_len: int) -> TreeNode: def _split_node(self, key: RadixKey, child: TreeNode, split_len: int) -> TreeNode:
# new_node -> child # new_node -> child
new_node = TreeNode() new_node = TreeNode()
new_node.children = {self.get_child_key_fn(key[split_len:]): child} new_node.children = {self.get_child_key_fn(key[split_len:]): child}
...@@ -831,7 +822,7 @@ class SWARadixCache(BasePrefixCache): ...@@ -831,7 +822,7 @@ class SWARadixCache(BasePrefixCache):
return new_node return new_node
def _insert_helper( def _insert_helper(
self, node: TreeNode, key: List, value, update_kv_after_len: int self, node: TreeNode, key: RadixKey, value, update_kv_after_len: int
) -> int: ) -> int:
# Update the last access time from root to leaf, so that # Update the last access time from root to leaf, so that
# swa will tombstone the node closer to root first # swa will tombstone the node closer to root first
......
...@@ -99,6 +99,7 @@ suites = { ...@@ -99,6 +99,7 @@ suites = {
TestFile("test_priority_scheduling.py", 100), TestFile("test_priority_scheduling.py", 100),
TestFile("test_pytorch_sampling_backend.py", 66), TestFile("test_pytorch_sampling_backend.py", 66),
TestFile("test_radix_attention.py", 105), TestFile("test_radix_attention.py", 105),
TestFile("test_radix_cache_unit.py", 5),
TestFile("test_regex_constrained.py", 64), TestFile("test_regex_constrained.py", 64),
TestFile("test_reasoning_parser.py", 5), TestFile("test_reasoning_parser.py", 5),
TestFile("test_retract_decode.py", 54), TestFile("test_retract_decode.py", 54),
......
This diff is collapsed.
...@@ -4,7 +4,8 @@ import torch ...@@ -4,7 +4,8 @@ import torch
from sglang.srt.mem_cache.allocator import SWAKVPool, SWATokenToKVPoolAllocator from sglang.srt.mem_cache.allocator import SWAKVPool, SWATokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.mem_cache.radix_cache import SWARadixCache from sglang.srt.mem_cache.radix_cache import RadixKey
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
class TestSWA(unittest.TestCase): class TestSWA(unittest.TestCase):
...@@ -19,7 +20,7 @@ class TestSWA(unittest.TestCase): ...@@ -19,7 +20,7 @@ class TestSWA(unittest.TestCase):
def test_swa_memory_pool(self): def test_swa_memory_pool(self):
size = 16 size = 16
size_swa = 16 size_swa = 16
num_head = 8 head_num = 8
head_dim = 128 head_dim = 128
num_layers = 48 num_layers = 48
global_interval = 4 global_interval = 4
...@@ -34,14 +35,20 @@ class TestSWA(unittest.TestCase): ...@@ -34,14 +35,20 @@ class TestSWA(unittest.TestCase):
size=size, size=size,
size_swa=size_swa, size_swa=size_swa,
dtype=dtype, dtype=dtype,
num_head=num_head, head_num=head_num,
head_dim=head_dim, head_dim=head_dim,
swa_attention_layer_ids=swa_attention_layer_ids, swa_attention_layer_ids=swa_attention_layer_ids,
full_attention_layer_ids=full_attention_layer_ids, full_attention_layer_ids=full_attention_layer_ids,
enable_kvcache_transpose=False,
device=device, device=device,
) )
alloc = SWATokenToKVPoolAllocator( alloc = SWATokenToKVPoolAllocator(
size=size, size_swa=size_swa, dtype=dtype, device=device, kvcache=pool size=size,
size_swa=size_swa,
dtype=dtype,
device=device,
kvcache=pool,
need_sort=False,
) )
assert alloc.available_size() == size + size_swa assert alloc.available_size() == size + size_swa
index = alloc.alloc(1) index = alloc.alloc(1)
...@@ -57,7 +64,7 @@ class TestSWA(unittest.TestCase): ...@@ -57,7 +64,7 @@ class TestSWA(unittest.TestCase):
kv_size = 128 kv_size = 128
kv_size_swa = 64 kv_size_swa = 64
sliding_window_size = 4 sliding_window_size = 4
num_head = 8 head_num = 8
head_dim = 128 head_dim = 128
num_layers = 48 num_layers = 48
global_interval = 4 global_interval = 4
...@@ -80,10 +87,11 @@ class TestSWA(unittest.TestCase): ...@@ -80,10 +87,11 @@ class TestSWA(unittest.TestCase):
size=kv_size, size=kv_size,
size_swa=kv_size_swa, size_swa=kv_size_swa,
dtype=dtype, dtype=dtype,
num_head=num_head, head_num=head_num,
head_dim=head_dim, head_dim=head_dim,
swa_attention_layer_ids=swa_attention_layer_ids, swa_attention_layer_ids=swa_attention_layer_ids,
full_attention_layer_ids=full_attention_layer_ids, full_attention_layer_ids=full_attention_layer_ids,
enable_kvcache_transpose=False,
device=device, device=device,
) )
# setup token to kv pool allocator # setup token to kv pool allocator
...@@ -93,6 +101,7 @@ class TestSWA(unittest.TestCase): ...@@ -93,6 +101,7 @@ class TestSWA(unittest.TestCase):
dtype=dtype, dtype=dtype,
device=device, device=device,
kvcache=kv_pool, kvcache=kv_pool,
need_sort=False,
) )
# setup radix cache # setup radix cache
tree = SWARadixCache( tree = SWARadixCache(
...@@ -112,7 +121,7 @@ class TestSWA(unittest.TestCase): ...@@ -112,7 +121,7 @@ class TestSWA(unittest.TestCase):
print( print(
f"req1: inserting, req1_token_ids: {req1_token_ids}, req1_kv_indices: {req1_kv_indices}" f"req1: inserting, req1_token_ids: {req1_token_ids}, req1_kv_indices: {req1_kv_indices}"
) )
prefix_len = tree.insert(req1_token_ids, req1_kv_indices) prefix_len = tree.insert(RadixKey(req1_token_ids), req1_kv_indices)
print( print(
f"req1: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}" f"req1: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}"
) )
...@@ -121,7 +130,7 @@ class TestSWA(unittest.TestCase): ...@@ -121,7 +130,7 @@ class TestSWA(unittest.TestCase):
print( print(
f"req2: inserting, req2_token_ids: {req2_token_ids}, req2_kv_indices: {req2_kv_indices}" f"req2: inserting, req2_token_ids: {req2_token_ids}, req2_kv_indices: {req2_kv_indices}"
) )
prefix_len = tree.insert(req2_token_ids, req2_kv_indices) prefix_len = tree.insert(RadixKey(req2_token_ids), req2_kv_indices)
print( print(
f"req2: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}" f"req2: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}"
) )
...@@ -130,7 +139,7 @@ class TestSWA(unittest.TestCase): ...@@ -130,7 +139,7 @@ class TestSWA(unittest.TestCase):
print( print(
f"req3: inserting, req3_token_ids: {req3_token_ids}, req3_kv_indices: {req3_kv_indices}" f"req3: inserting, req3_token_ids: {req3_token_ids}, req3_kv_indices: {req3_kv_indices}"
) )
prefix_len = tree.insert(req3_token_ids, req3_kv_indices) prefix_len = tree.insert(RadixKey(req3_token_ids), req3_kv_indices)
print( print(
f"req3: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}" f"req3: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}"
) )
...@@ -139,7 +148,7 @@ class TestSWA(unittest.TestCase): ...@@ -139,7 +148,7 @@ class TestSWA(unittest.TestCase):
print( print(
f"req4: inserting, req4_token_ids: {req4_token_ids}, req4_kv_indices: {req4_kv_indices}" f"req4: inserting, req4_token_ids: {req4_token_ids}, req4_kv_indices: {req4_kv_indices}"
) )
prefix_len = tree.insert(req4_token_ids, req4_kv_indices) prefix_len = tree.insert(RadixKey(req4_token_ids), req4_kv_indices)
print( print(
f"req4: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}" f"req4: prefix_len: {prefix_len}, allocator swa available size: {allocator.swa_available_size()}, full available size: {allocator.full_available_size()}"
) )
...@@ -161,21 +170,23 @@ class TestSWA(unittest.TestCase): ...@@ -161,21 +170,23 @@ class TestSWA(unittest.TestCase):
tree.pretty_print() tree.pretty_print()
req5_token_ids = [1, 2, 3, 4, 5] req5_token_ids = [1, 2, 3, 4, 5]
kv_indices, last_node = tree.match_prefix(req5_token_ids) result = tree.match_prefix(RadixKey(req5_token_ids))
kv_indices, last_node = result.device_indices, result.last_device_node
print( print(
f"req5: token_ids: {req5_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}" f"req5: token_ids: {req5_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}"
) )
assert len(kv_indices) == 0 assert len(kv_indices) == 0
req6_token_ids = [1, 2, 3, 4, 5, 60, 70] req6_token_ids = [1, 2, 3, 4, 5, 60, 70]
kv_indices, last_node = tree.match_prefix(req6_token_ids) result = tree.match_prefix(RadixKey(req6_token_ids))
kv_indices, last_node = result.device_indices, result.last_device_node
print( print(
f"req6: token_ids: {req6_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}" f"req6: token_ids: {req6_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}"
) )
assert len(kv_indices) == 7 assert len(kv_indices) == 7
assert len(last_node.key) == 2 assert len(last_node.key) == 2
assert last_node.key[0] == 60 assert last_node.key.token_ids[0] == 60
assert last_node.key[1] == 70 assert last_node.key.token_ids[1] == 70
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment