Unverified commit 47367b76, authored by DarkSharpness and committed by GitHub

[Refactor] Clean up radix cache related API (#7303)


Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
parent 650127a1
```diff
@@ -38,7 +38,7 @@ import logging
 import threading
 from enum import Enum, auto
 from http import HTTPStatus
-from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union
+from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union
 
 import numpy as np
 import torch
@@ -436,7 +436,7 @@ class Req:
         self,
         rid: str,
         origin_input_text: str,
-        origin_input_ids: Tuple[int],
+        origin_input_ids: List[int],
         sampling_params: SamplingParams,
         return_logprob: bool = False,
         top_logprobs_num: int = 0,
@@ -467,7 +467,7 @@ class Req:
         # Each decode stage's output ids
         self.output_ids = []
         # fill_ids = origin_input_ids + output_ids. Updated if chunked.
-        self.fill_ids = None
+        self.fill_ids = []
         self.session_id = session_id
         self.input_embeds = input_embeds
@@ -519,13 +519,14 @@ class Req:
         # Prefix info
         # The indices to kv cache for the shared prefix.
-        self.prefix_indices = []
+        self.prefix_indices: torch.Tensor = []
         # Number of tokens to run prefill.
         self.extend_input_len = 0
         # The relative logprob_start_len in an extend batch
         self.extend_logprob_start_len = 0
-        self.last_node = None
-        self.last_node_global = None
+        self.last_node: Any = None
+        self.last_host_node: Any = None
+        self.host_hit_length = 0
 
         # Whether or not if it is chunked. It increments whenever
         # it is chunked, and decrement whenever chunked request is
@@ -644,21 +645,17 @@ class Req:
     def init_next_round_input(
         self,
         tree_cache: Optional[BasePrefixCache] = None,
-        enable_hierarchical_cache=False,
     ):
         self.fill_ids = self.origin_input_ids + self.output_ids
         if tree_cache is not None:
-            # tree cache is None if the prefix is not computed with tree cache.
-            if enable_hierarchical_cache:
-                self.prefix_indices, self.last_node, self.last_node_global = (
-                    tree_cache.match_prefix(
-                        key=self.adjust_max_prefix_ids(), include_evicted=True
-                    )
-                )
-            else:
-                self.prefix_indices, self.last_node = tree_cache.match_prefix(
-                    rid=self.rid, key=self.adjust_max_prefix_ids()
-                )
+            (
+                self.prefix_indices,
+                self.last_node,
+                self.last_host_node,
+                self.host_hit_length,
+            ) = tree_cache.match_prefix(
+                key=self.adjust_max_prefix_ids(),
+            )
         self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
 
     def adjust_max_prefix_ids(self):
```
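To make the `Req` bookkeeping above concrete: `fill_ids` is always the prompt plus whatever has been decoded so far, and `extend_input_len` is what remains after subtracting the cached prefix. A minimal sketch with made-up token counts (none of these values appear in the PR):

```python
# Hypothetical numbers illustrating Req's fill_ids / extend_input_len math.
origin_input_ids = list(range(8))              # 8 prompt tokens
output_ids = [100, 101]                        # 2 tokens decoded so far
fill_ids = origin_input_ids + output_ids       # fill_ids = prompt + decoded

prefix_hit = 6                                 # tokens matched in the radix cache
extend_input_len = len(fill_ids) - prefix_hit  # tokens that still need prefill
assert extend_input_len == 4
```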
```diff
@@ -90,7 +90,7 @@ class SchedulePolicy:
     def calc_priority(self, waiting_queue: List[Req]) -> bool:
         if self.policy == CacheAgnosticPolicy.FCFS:
             # A shortcut for FCFS
-            return
+            return False
 
         policy = self._determine_active_policy(waiting_queue)
@@ -134,7 +134,7 @@ class SchedulePolicy:
         """
         try:
             policy_enum = CacheAwarePolicy(policy)
-            if tree_cache.disable:
+            if getattr(tree_cache, "disable", True):
                 # If tree_cache is disabled, using CacheAgnosticPolicy policy
                 return CacheAgnosticPolicy.FCFS
             return policy_enum
@@ -158,14 +158,9 @@ class SchedulePolicy:
             prefix_ids = r.adjust_max_prefix_ids()
             # NOTE: the prefix_indices must always be aligned with last_node
-            if self.enable_hierarchical_cache:
-                r.prefix_indices, r.last_node, r.last_node_global = (
-                    self.tree_cache.match_prefix(key=prefix_ids, include_evicted=True)
-                )
-            else:
-                r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
-                    rid=r.rid, key=prefix_ids
-                )
+            r.prefix_indices, r.last_node, r.last_host_node, r.host_hit_length = (
+                self.tree_cache.match_prefix(rid=r.rid, key=prefix_ids)
+            )
 
             # NOTE(sang): This logic is for in-batch prefix caching;
             # If there are more than 1 request that have small matching prefix from
@@ -175,7 +170,7 @@ class SchedulePolicy:
             # threshold means we cannot use in-batch prefix caching for short prefixes.
             # It is kind of common when the engine is long running (e.g., imagine the prefix "the").
             if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD:
-                in_batch_matching_prefixes, _ = (
+                in_batch_matching_prefixes, _, _, _ = (
                     self.waiting_queue_radix_tree.match_prefix(
                         rid=r.rid, key=prefix_ids
                     )
@@ -268,6 +263,7 @@ class AddReqResult(Enum):
 class PrefillAdder:
     def __init__(
         self,
+        page_size: int,
         tree_cache: BasePrefixCache,
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
         running_batch: ScheduleBatch,
@@ -276,6 +272,7 @@ class PrefillAdder:
         rem_chunk_tokens: Optional[int],
         mixed_with_decode_tokens: int = 0,
     ):
+        self.page_size = page_size
         self.tree_cache = tree_cache
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
         self.running_batch = running_batch
@@ -442,46 +439,43 @@ class PrefillAdder:
         return self.budget_state()
 
-    def add_one_req(
-        self, req: Req, has_chunked_req: bool, enable_hierarchical_cache: bool = False
-    ):
+    def add_one_req(self, req: Req, has_chunked_req: bool):
         if req.sampling_params.ignore_eos and getattr(self.tree_cache, "disable", True):
             return self.add_one_req_ignore_eos(req, has_chunked_req)
 
         total_tokens = req.extend_input_len + min(
             req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION
         )
-        input_tokens = (
-            -(-req.extend_input_len // self.tree_cache.page_size)
-            * self.tree_cache.page_size
-        )
+
+        # adjusting the input_tokens based on host_hit_length and page_size
+        real_input_tokens = req.extend_input_len - req.host_hit_length
+        real_input_tokens = -(-real_input_tokens // self.page_size) * self.page_size
         prefix_len = len(req.prefix_indices)
 
         if total_tokens >= self.rem_total_tokens:
             return AddReqResult.NO_TOKEN
 
-        if input_tokens > self.rem_input_tokens and len(self.can_run_list) != 0:
+        if real_input_tokens >= self.rem_input_tokens and len(self.can_run_list) != 0:
             return AddReqResult.OTHER
 
         with self._lock_node(req.last_node):
-            if total_tokens > self.rem_total_tokens:
+            # self.rem_total_tokens may decrease after the lock acquisition
+            if total_tokens >= self.rem_total_tokens:
                 return AddReqResult.NO_TOKEN
 
-            if (
-                enable_hierarchical_cache
-                and req.last_node_global is not None
-                and req.last_node_global.evicted
-            ):
-                req.last_node, req.prefix_indices = self.tree_cache.init_load_back(
-                    req.last_node_global, req.prefix_indices
+            if req.host_hit_length > 0:
+                new_indices, req.last_node = self.tree_cache.init_load_back(
+                    req.last_host_node, req.host_hit_length
                 )
+                req.prefix_indices = torch.cat([req.prefix_indices, new_indices])
                 req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
-                input_tokens = (
-                    -(-req.extend_input_len // self.tree_cache.page_size)
-                    * self.tree_cache.page_size
-                )
                 prefix_len = len(req.prefix_indices)
 
+            input_tokens = -(-req.extend_input_len // self.page_size) * self.page_size
+
+            if input_tokens >= self.rem_input_tokens and len(self.can_run_list) != 0:
+                return AddReqResult.OTHER
+
             if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens:
                 # Non-chunked prefill
                 self.can_run_list.append(req)
@@ -496,7 +490,7 @@ class PrefillAdder:
                 )
             else:
                 # Make sure at least one page is available
-                trunc_len = self.rem_chunk_tokens - self.tree_cache.page_size + 1
+                trunc_len = self.rem_chunk_tokens - self.page_size + 1
                 if trunc_len <= 0:
                     return AddReqResult.OTHER
```
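`PrefillAdder` rounds the prefill length up to whole pages with the `-(-n // page_size)` idiom, after first subtracting `host_hit_length`, since tokens that will be loaded back from host memory do not need fresh prefill input budget. A standalone sketch of the rounding (the helper name is ours, not part of the PR):

```python
def round_up_to_page(num_tokens: int, page_size: int) -> int:
    # -(-a // b) is ceiling division expressed with floor division only.
    return -(-num_tokens // page_size) * page_size

assert round_up_to_page(30, 16) == 32  # partial page rounds up
assert round_up_to_page(32, 16) == 32  # exact multiple is unchanged
assert round_up_to_page(0, 16) == 0
```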
```diff
@@ -1467,15 +1467,14 @@ class Scheduler(
             return None
 
         if self.enable_hierarchical_cache:
-            # check for completion of hierarchical cache activities to release memory
-            self.tree_cache.writing_check()
-            self.tree_cache.loading_check()
+            self.tree_cache.check_hicache_events()
 
         # Get priority queue
-        prefix_computed = self.policy.calc_priority(self.waiting_queue)
+        self.policy.calc_priority(self.waiting_queue)
 
         # Prefill policy
         adder = PrefillAdder(
+            self.page_size,
             self.tree_cache,
             self.token_to_kv_pool_allocator,
             self.running_batch,
@@ -1517,19 +1516,8 @@ class Scheduler(
                 self.running_batch.batch_is_full = True
                 break
 
-            # bypass prefix_computed if enable_hierarchical_cache
-            req.init_next_round_input(
-                (
-                    None
-                    if (prefix_computed and not self.enable_hierarchical_cache)
-                    else self.tree_cache
-                ),
-                self.enable_hierarchical_cache,
-            )
-            res = adder.add_one_req(
-                req, self.chunked_req, self.enable_hierarchical_cache
-            )
+            req.init_next_round_input(self.tree_cache)
+            res = adder.add_one_req(req, has_chunked_req=(self.chunked_req is not None))
 
             if res != AddReqResult.CONTINUE:
                 if res == AddReqResult.NO_TOKEN:
@@ -1581,7 +1569,9 @@ class Scheduler(
         )
 
         if self.enable_hierarchical_cache:
             # todo (zhiqiang): disable cuda graph execution if hicache loading triggered
-            new_batch.hicache_consumer_index = self.tree_cache.ready_to_load_cache()
+            new_batch.hicache_consumer_index = (
+                self.tree_cache.ready_to_load_host_cache()
+            )
 
         new_batch.prepare_for_extend()
```
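A subtle fix in the scheduler loop above: the old call forwarded `self.chunked_req` itself (a request object or `None`) into the `has_chunked_req: bool` parameter, relying on truthiness; the new call passes an explicit bool. A tiny illustration of the difference:

```python
class Req:  # stand-in for the scheduler's request type
    pass

chunked_req = Req()
assert bool(chunked_req) is True           # any non-None object is truthy
has_chunked_req = chunked_req is not None  # the new call states the intent
assert has_chunked_req is True
```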
```diff
 from abc import ABC, abstractmethod
-from typing import Any, List, Tuple
+from typing import TYPE_CHECKING, Any, List, NamedTuple, Tuple
+
+import torch
+
+if TYPE_CHECKING:
+    from sglang.srt.managers.schedule_batch import Req
+else:
+    Req = Any  # Placeholder for Req type when not type checking
+
+
+class MatchResult(NamedTuple):
+    """Result of a prefix match operation.
+
+    Attributes:
+        device_indices  : Indices of the KV cache on the device matched by common prefix.
+        last_device_node: The last TreeNode on the device that was matched.
+        last_host_node  : The last TreeNode on the host that was matched.
+                          Note that if HiCache is not enabled,
+                          this **must** be the same as `last_device_node`.
+        host_hit_length : Length of the KV cache hit on the host, if applicable.
+                          0 if HiCache is not enabled.
+    """
+
+    device_indices: torch.Tensor
+    last_device_node: Any
+    last_host_node: Any
+    host_hit_length: int = 0
 
 
 class BasePrefixCache(ABC):
```
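Because `MatchResult` is a `NamedTuple`, its fields are readable by name and `host_hit_length` defaults to 0 when HiCache is off. A minimal sketch constructing a cache-miss result the same way `ChunkCache` does later in this commit:

```python
import torch

from sglang.srt.mem_cache.base_prefix_cache import MatchResult

# A cache-miss result: no device indices, no matched nodes.
miss = MatchResult(
    device_indices=torch.empty((0,), dtype=torch.int64),
    last_device_node=None,
    last_host_node=None,
)
assert miss.host_hit_length == 0         # default when HiCache is not enabled
assert miss.device_indices.numel() == 0  # nothing reusable on the device
```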
```diff
@@ -10,19 +36,15 @@ class BasePrefixCache(ABC):
         pass
 
     @abstractmethod
-    def match_prefix(self, **kwargs) -> Tuple[List[int], int]:
+    def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
         pass
 
     @abstractmethod
-    def insert(self, **kwargs):
-        pass
-
-    @abstractmethod
-    def cache_finished_req(self, **kwargs):
+    def cache_finished_req(self, req: Req, **kwargs):
         pass
 
     @abstractmethod
-    def cache_unfinished_req(self, **kwargs):
+    def cache_unfinished_req(self, req: Req, **kwargs):
         pass
 
     @abstractmethod
@@ -49,5 +71,27 @@ class BasePrefixCache(ABC):
     def pretty_print(self):
         raise NotImplementedError()
 
+    def init_load_back(
+        self,
+        last_host_node: Any,
+        host_hit_length: int,
+    ) -> Tuple[torch.Tensor, Any]:
+        """
+        Preparing KV cache loading from host to device.
+        """
+        raise NotImplementedError()
+
+    def ready_to_load_host_cache(self) -> Any:
+        """
+        Notify the cache controller to start the KV cache loading
+        """
+        raise NotImplementedError()
+
+    def check_hicache_events(self) -> Any:
+        """
+        Check HiCache related activities to update radix tree and synchronize across TP workers if needed
+        """
+        raise NotImplementedError()
+
     def take_events(self):
         return []
```
```diff
@@ -6,19 +6,13 @@ from typing import TYPE_CHECKING, Any, Callable, List, Tuple
 import torch
 
-from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
+from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
 
 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import Req
 
 
-class ChunkCacheEntry:
-    def __init__(self, rid: str, value: torch.Tensor):
-        self.rid = rid
-        self.value = value
-
-
 class ChunkCache(BasePrefixCache):
     def __init__(
         self,
@@ -29,13 +23,16 @@ class ChunkCache(BasePrefixCache):
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
         self.page_size = page_size
+        self.disable = True
 
     def reset(self):
         pass
 
-    def match_prefix(self, **unused_kwargs) -> Tuple[List[int], int]:
-        return [], None
+    def match_prefix(self, **unused_kwargs) -> MatchResult:
+        return MatchResult(
+            device_indices=torch.empty((0,), dtype=torch.int64),
+            last_device_node=None,
+            last_host_node=None,
+        )
 
     def cache_finished_req(self, req: Req):
         kv_indices = self.req_to_token_pool.req_to_token[
@@ -54,9 +51,6 @@ class ChunkCache(BasePrefixCache):
         # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
         req.prefix_indices = kv_indices
 
-    def insert(self):
-        raise NotImplementedError()
-
     def evict(self, num_tokens: int):
         pass
```
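`ChunkCache` now sets `disable = True`, which is exactly what the `getattr(tree_cache, "disable", True)` guards in `SchedulePolicy` and `PrefillAdder` read; a cache object without the attribute is conservatively treated as disabled. A quick check of that behavior:

```python
class ChunkLikeCache:
    disable = True  # opts out of prefix reuse, as ChunkCache does

class AttrlessCache:
    pass            # hypothetical cache that never defines the attribute

assert getattr(ChunkLikeCache(), "disable", True) is True
assert getattr(AttrlessCache(), "disable", True) is True  # default applies
```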
```diff
@@ -7,6 +7,7 @@ from typing import List, Optional
 import torch
 
 from sglang.srt.managers.cache_controller import HiCacheController
+from sglang.srt.mem_cache.base_prefix_cache import MatchResult
 from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPool,
     MLATokenToKVPool,
@@ -283,41 +284,44 @@ class HiRadixCache(RadixCache):
     def init_load_back(
         self,
         last_node: TreeNode,
-        prefix_indices: torch.Tensor,
+        host_hit_length: int,
         mem_quota: Optional[int] = None,
     ):
-        assert (
-            len(prefix_indices) == 0 or prefix_indices.is_cuda
-        ), "indices of device kV caches should be on GPU"
+        _ = host_hit_length  # unused, but kept for compatibility
         if last_node.evicted:
             loading_values = self.load_back(last_node, mem_quota)
             if loading_values is not None:
-                prefix_indices = (
-                    loading_values
-                    if len(prefix_indices) == 0
-                    else torch.cat([prefix_indices, loading_values])
-                )
                 logger.debug(
                     f"loading back {len(loading_values)} tokens for node {last_node.id}"
                 )
+                return loading_values, last_node
+
             while last_node.evicted:
                 last_node = last_node.parent
 
-        return last_node, prefix_indices
+        return (
+            torch.empty((0,), dtype=torch.int64, device=self.device),
+            last_node,
+        )
 
-    def ready_to_load_cache(self):
+    def ready_to_load_host_cache(self):
         producer_index = self.cache_controller.layer_done_counter.next_producer()
         self.load_cache_event.set()
         return producer_index
 
-    def match_prefix(self, key: List[int], include_evicted=False, **kwargs):
+    def check_hicache_events(self):
+        self.writing_check()
+        self.loading_check()
+
+    def match_prefix(self, key: List[int], **kwargs):
         empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
         if self.disable or len(key) == 0:
-            if include_evicted:
-                return empty_value, self.root_node, self.root_node
-            else:
-                return empty_value, self.root_node
+            return MatchResult(
+                device_indices=empty_value,
+                last_device_node=self.root_node,
+                last_host_node=self.root_node,
+                host_hit_length=0,
+            )
 
         if self.page_size != 1:
             page_aligned_len = len(key) // self.page_size * self.page_size
@@ -329,14 +333,18 @@ class HiRadixCache(RadixCache):
         else:
             value = empty_value
 
-        last_node_global = last_node
+        host_hit_length = 0
+        last_host_node = last_node
         while last_node.evicted:
+            host_hit_length += len(last_node.host_value)
             last_node = last_node.parent
 
-        if include_evicted:
-            return value, last_node, last_node_global
-        else:
-            return value, last_node
+        return MatchResult(
+            device_indices=value,
+            last_device_node=last_node,
+            last_host_node=last_host_node,
+            host_hit_length=host_hit_length,
+        )
 
     def _match_prefix_helper(self, node: TreeNode, key: List):
         node.last_access_time = time.monotonic()
```
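The `match_prefix` walk above climbs from the deepest matched node back to the first ancestor whose KV cache still lives on the device, accumulating `host_value` lengths into `host_hit_length`. A self-contained sketch of that walk, using a stand-in node type rather than the real `TreeNode`:

```python
class Node:  # stand-in for TreeNode with only the fields the walk touches
    def __init__(self, parent=None, evicted=False, host_len=0):
        self.parent = parent
        self.evicted = evicted
        self.host_value = list(range(host_len))  # host-side KV indices

root = Node()
on_device = Node(parent=root)  # KV cache still resident on the GPU
evicted_a = Node(parent=on_device, evicted=True, host_len=16)
evicted_b = Node(parent=evicted_a, evicted=True, host_len=32)

last_host_node = last_node = evicted_b
host_hit_length = 0
while last_node.evicted:
    host_hit_length += len(last_node.host_value)
    last_node = last_node.parent

assert last_node is on_device and host_hit_length == 48
```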
```diff
@@ -33,8 +33,7 @@ from sglang.srt.disaggregation.kv_events import (
     BlockStored,
     KVCacheEvent,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
+from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
 
 if TYPE_CHECKING:
@@ -47,9 +46,9 @@ class TreeNode:
     def __init__(self, id: Optional[int] = None):
         self.children = defaultdict(TreeNode)
-        self.parent = None
-        self.key = None
-        self.value = None
+        self.parent: TreeNode = None
+        self.key: List[int] = None
+        self.value: Optional[torch.Tensor] = None
         self.lock_ref = 0
         self.last_access_time = time.monotonic()
@@ -57,7 +56,7 @@ class TreeNode:
         # indicating the node is loading KV cache from host
         self.loading = False
         # store the host indices of KV cache
-        self.host_value = None
+        self.host_value: Optional[torch.Tensor] = None
 
         self.id = TreeNode.counter if id is None else id
         TreeNode.counter += 1
@@ -135,7 +134,7 @@ class RadixCache(BasePrefixCache):
         self.protected_size_ = 0
         self._record_all_cleared_event()
 
-    def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]:
+    def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
         """Find the matching prefix from the radix tree.
         Args:
             key: A list of token IDs to find a matching prefix.
@@ -147,13 +146,14 @@ class RadixCache(BasePrefixCache):
             than the last node's value.
         """
         if self.disable or len(key) == 0:
-            return (
-                torch.empty(
+            return MatchResult(
+                device_indices=torch.empty(
                     (0,),
                     dtype=torch.int64,
                     device=self.device,
                 ),
-                self.root_node,
+                last_device_node=self.root_node,
+                last_host_node=self.root_node,
             )
 
         if self.page_size != 1:
@@ -165,7 +165,11 @@ class RadixCache(BasePrefixCache):
             value = torch.cat(value)
         else:
             value = torch.empty((0,), dtype=torch.int64, device=self.device)
-        return value, last_node
+        return MatchResult(
+            device_indices=value,
+            last_device_node=last_node,
+            last_host_node=last_node,
+        )
 
     def insert(self, key: List, value=None):
         if self.disable:
@@ -235,7 +239,7 @@ class RadixCache(BasePrefixCache):
         )
 
         # The prefix indices could be updated, reuse it
-        new_indices, new_last_node = self.match_prefix(page_aligned_token_ids)
+        new_indices, new_last_node, _, _ = self.match_prefix(page_aligned_token_ids)
         self.req_to_token_pool.write(
             (req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
             new_indices[len(req.prefix_indices) :],
```
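Since every `match_prefix` now returns the same `NamedTuple`, call sites such as `cache_unfinished_req` above can keep compact positional unpacking (`new_indices, new_last_node, _, _ = ...`) while newer code reads fields by name. Both views of one value, on a result built by hand for illustration:

```python
import torch

from sglang.srt.mem_cache.base_prefix_cache import MatchResult

result = MatchResult(
    device_indices=torch.arange(4, dtype=torch.int64),
    last_device_node="node",  # placeholder; real callers receive a TreeNode
    last_host_node="node",
)
indices, last_node, _, _ = result  # tuple-style unpacking, as in the diff
assert torch.equal(indices, result.device_indices)  # same data, named access
```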