import heapq
import logging
import threading
import time
from queue import Queue
from typing import List, Optional

import torch

from sglang.srt.managers.cache_controller import HiCacheController, PrefetchOperation
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import MatchResult
from sglang.srt.mem_cache.memory_pool import (
    MHATokenToKVPool,
    MLATokenToKVPool,
    ReqToTokenPool,
)
from sglang.srt.mem_cache.memory_pool_host import (
    MHATokenToKVPoolHost,
    MLATokenToKVPoolHost,
)
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode

logger = logging.getLogger(__name__)


class HiRadixCache(RadixCache):

    def __init__(
        self,
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
        tp_cache_group: torch.distributed.ProcessGroup,
        page_size: int,
        hicache_ratio: float,
        hicache_size: int,
        hicache_write_policy: str,
        hicache_io_backend: str,
        hicache_mem_layout: str,
        hicache_storage_backend: Optional[str] = None,
        hicache_storage_prefetch_policy: Optional[str] = "best_effort",
        model_name: Optional[str] = None,
        storage_backend_extra_config: Optional[str] = None,
    ):
        if hicache_io_backend == "direct":
            if hicache_mem_layout == "page_first":
                hicache_mem_layout = "layer_first"
                logger.warning(
                    "Page first layout is not supported with direct IO backend, switching to layer first layout"
                )

        self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
        if isinstance(self.kv_cache, MHATokenToKVPool):
            self.token_to_kv_pool_host = MHATokenToKVPoolHost(
                self.kv_cache,
                hicache_ratio,
                hicache_size,
                page_size,
                hicache_mem_layout,
            )
        elif isinstance(self.kv_cache, MLATokenToKVPool):
            self.token_to_kv_pool_host = MLATokenToKVPoolHost(
                self.kv_cache,
                hicache_ratio,
                hicache_size,
                page_size,
                hicache_mem_layout,
            )
        else:
            raise ValueError("HiRadixCache only supports MHA and MLA KV caches for now")

        self.tp_group = tp_cache_group
        self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
        self.enable_storage = hicache_storage_backend is not None
        # todo: customizable storage prefetch threshold and timeout
        self.prefetch_threshold = 256
        self.prefetch_timeout = 3  # seconds
        self.prefetch_stop_policy = hicache_storage_prefetch_policy

        self.load_cache_event = threading.Event()
        self.cache_controller = HiCacheController(
            token_to_kv_pool_allocator,
            self.token_to_kv_pool_host,
            page_size,
            self.tp_group,
            load_cache_event=self.load_cache_event,
            write_policy=hicache_write_policy,
            io_backend=hicache_io_backend,
            storage_backend=hicache_storage_backend,
            prefetch_threshold=self.prefetch_threshold,
            model_name=model_name,
            storage_backend_extra_config=storage_backend_extra_config,
        )

        # record the nodes with ongoing write-through
        self.ongoing_write_through = {}
        # record the node segments with ongoing load-back
        self.ongoing_load_back = {}
        # record the ongoing prefetch requests
        self.ongoing_prefetch = {}
        self.ongoing_backup = {}
        # todo: dynamically adjust the threshold
        self.write_through_threshold = (
            1 if hicache_write_policy == "write_through" else 2
        )
        self.load_back_threshold = 10
        super().__init__(
            req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
        )

    def reset(self):
        TreeNode.counter = 0
        self.cache_controller.reset()
        self.token_to_kv_pool_host.clear()
        super().reset()

    def get_height(self, node: TreeNode):
        height = 0
        while node != self.root_node:
            node = node.parent
            height += 1
        return height

    def clear_storage_backend(self):
        if self.enable_storage:
            self.cache_controller.storage_backend.clear()
            logger.info("Hierarchical cache storage backend cleared successfully!")
            return True
        else:
            logger.warning("Hierarchical cache storage backend is not enabled.")
            return False

    def write_backup(self, node: TreeNode, write_back=False):
        host_indices = self.cache_controller.write(
            device_indices=node.value,
            node_id=node.id,
        )
        if host_indices is None:
            self.evict_host(len(node.value))
            host_indices = self.cache_controller.write(
                device_indices=node.value,
                node_id=node.id,
            )
        if host_indices is not None:
            node.host_value = host_indices
            assert len(node.host_value) > 0
            self.ongoing_write_through[node.id] = node
            if not write_back:
                # no need to lock nodes for write-back
                self.inc_lock_ref(node)
        else:
            return 0

        return len(host_indices)

    def write_backup_storage(self, node: TreeNode):
        operation_id = self.cache_controller.write_storage(
            node.host_value, node.key, node.hash_value
        )
        self.ongoing_backup[operation_id] = node
        node.protect_host()

    def _inc_hit_count(self, node: TreeNode, chunked=False):
        # skip the hit count update for chunked requests
        if self.cache_controller.write_policy == "write_back" or chunked:
            return
        node.hit_count += 1

        if not node.backuped:
            if node.hit_count >= self.write_through_threshold:
                # write to host if the node is not backed up yet
                self.write_backup(node)

    def writing_check(self, write_back=False):
        if write_back:
            # block until all write-backs complete
            while len(self.ongoing_write_through) > 0:
                ack_id = self.cache_controller.ack_write_queue.get()
                del self.ongoing_write_through[ack_id]
            return
        queue_size = torch.tensor(
            self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
        )
        if self.tp_world_size > 1:
            # synchronize TP workers to make the same update to the radix cache
            torch.distributed.all_reduce(
                queue_size,
                op=torch.distributed.ReduceOp.MIN,
                group=self.tp_group,
            )
        for _ in range(queue_size.item()):
            ack_id = self.cache_controller.ack_write_queue.get()
            backuped_node = self.ongoing_write_through[ack_id]
            self.dec_lock_ref(backuped_node)
            del self.ongoing_write_through[ack_id]
            if self.enable_storage:
                self.write_backup_storage(backuped_node)

    def loading_check(self):
        while not self.cache_controller.ack_load_queue.empty():
            try:
                ack_id = self.cache_controller.ack_load_queue.get_nowait()
                start_node, end_node = self.ongoing_load_back[ack_id]
                self.dec_lock_ref(end_node)
                while end_node != start_node:
                    assert end_node.loading
                    end_node.loading = False
                    end_node = end_node.parent
                # clear the reference
                del self.ongoing_load_back[ack_id]
            except Exception:
                break

    def evictable_size(self):
        return self.evictable_size_

    def evict(self, num_tokens: int):
        leaves = self._collect_leaves_device()
        heapq.heapify(leaves)

        num_evicted = 0
        write_back_nodes = []
        while num_evicted < num_tokens and len(leaves):
            x = heapq.heappop(leaves)

            if x.lock_ref > 0:
                continue

            if not x.backuped:
                if self.cache_controller.write_policy == "write_back":
                    # write to host if the node is not backed up yet
                    num_evicted += self.write_backup(x, write_back=True)
                    write_back_nodes.append(x)
                else:
                    num_evicted += self._evict_regular(x)
            else:
                num_evicted += self._evict_backuped(x)

            for child in x.parent.children.values():
                if child in write_back_nodes:
                    continue
                if not child.evicted:
                    break
            else:
                # all children are evicted or no children
                heapq.heappush(leaves, x.parent)

        if self.cache_controller.write_policy == "write_back":
            self.writing_check(write_back=True)
            for node in write_back_nodes:
                assert node.backuped
                self._evict_backuped(node)

    def _evict_backuped(self, node: TreeNode):
        # evict a node that has already been written to host
        num_evicted = self.cache_controller.evict_device(node.value, node.host_value)
        assert num_evicted > 0
        self.evictable_size_ -= num_evicted
        node.value = None
        return num_evicted

    def _evict_regular(self, node: TreeNode):
        # evict a node that has not initiated a write to host
        self.cache_controller.mem_pool_device_allocator.free(node.value)
        num_evicted = len(node.value)
        self._delete_leaf(node)
        return num_evicted

    def evict_host(self, num_tokens: int):
        leaves = self._collect_leaves()
        heapq.heapify(leaves)

        num_evicted = 0
        while num_evicted < num_tokens and len(leaves):
            x = heapq.heappop(leaves)
            if x == self.root_node:
                break
            # only evict the host value of evicted nodes
            if not x.evicted:
                continue

            # node is protected from eviction as it has an ongoing prefetch or backup to storage
            if x.host_ref_counter > 0:
                continue

            num_evicted += self.cache_controller.evict_host(x.host_value)

            for k, v in x.parent.children.items():
                if v == x:
                    break
            del x.parent.children[k]

            if len(x.parent.children) == 0 and x.parent.evicted:
                heapq.heappush(leaves, x.parent)

    def load_back(
        self, node: TreeNode, mem_quota: Optional[int] = None
    ) -> Optional[torch.Tensor]:
        # todo: more loading policies

        last_hit_node = node
        nodes_to_load = []
        while node.evicted:
            assert (
                node.backuped
            ), "No backup available on evicted nodes, should not happen"
            nodes_to_load.insert(0, node)
            node = node.parent
        else:
            ancestor_node = node

        # protect the ancestor nodes from eviction
        delta = self.inc_lock_ref(ancestor_node)

        # load it all or not at all
        host_indices = torch.cat([n.host_value for n in nodes_to_load])
        if len(host_indices) < self.load_back_threshold or (
            len(host_indices) > mem_quota + delta if mem_quota is not None else False
        ):
            # skip loading back if the total size is too small or exceeds the memory quota
            self.dec_lock_ref(ancestor_node)
            return None

        device_indices = self.cache_controller.load(
            host_indices=host_indices, node_id=last_hit_node.id
        )
        if device_indices is None:
            self.evict(len(host_indices))
            device_indices = self.cache_controller.load(
                host_indices=host_indices, node_id=last_hit_node.id
            )
        self.dec_lock_ref(ancestor_node)
        if device_indices is None:
            # insufficient GPU memory to load back KV caches
            return None

        self.ongoing_load_back[last_hit_node.id] = (ancestor_node, last_hit_node)
        offset = 0
        for node in nodes_to_load:
            node.value = device_indices[offset : offset + len(node.host_value)]
            offset += len(node.host_value)
            node.loading = True
        self.evictable_size_ += len(device_indices)
        self.inc_lock_ref(last_hit_node)

        return device_indices

    def init_load_back(
        self,
        last_node: TreeNode,
        host_hit_length: int,
        mem_quota: Optional[int] = None,
    ):
        _ = host_hit_length  # unused, but kept for compatibility

        if last_node.evicted:
            loading_values = self.load_back(last_node, mem_quota)
            if loading_values is not None:
                logger.debug(
                    f"loading back {len(loading_values)} tokens for node {last_node.id}"
                )
                return loading_values, last_node

            while last_node.evicted:
                last_node = last_node.parent

        return (
            torch.empty((0,), dtype=torch.int64, device=self.device),
            last_node,
        )

    def ready_to_load_host_cache(self):
        producer_index = self.cache_controller.layer_done_counter.next_producer()
        self.load_cache_event.set()
        return producer_index

    def check_hicache_events(self):
        self.writing_check()
        self.loading_check()
        if self.enable_storage:
            self.drain_storage_control_queues()

    def drain_storage_control_queues(self):
        """
        Combine prefetch revoke, backup ack, and host memory release checks
        to minimize TP synchronization and Python overhead.
        """
        cc = self.cache_controller

        qsizes = torch.tensor(
            [
                cc.prefetch_revoke_queue.qsize(),
                cc.ack_backup_queue.qsize(),
                cc.host_mem_release_queue.qsize(),
            ],
            dtype=torch.int,
        )
        if self.tp_world_size > 1:
            torch.distributed.all_reduce(
                qsizes, op=torch.distributed.ReduceOp.MIN, group=self.tp_group
            )

        n_revoke, n_backup, n_release = map(int, qsizes.tolist())

        # process prefetch revokes
        for _ in range(n_revoke):
            req_id = cc.prefetch_revoke_queue.get()
            info = self.ongoing_prefetch.pop(req_id, None)
            if info is not None:
                last_host_node, token_ids, _, _ = info
                last_host_node.release_host()
                cc.prefetch_tokens_occupied -= len(token_ids)
            # else: the revoked operation has already been terminated, nothing to do

        # process backup acks
        for _ in range(n_backup):
            ack_id = cc.ack_backup_queue.get()
            entry = self.ongoing_backup.pop(ack_id, None)
            if entry is not None:
                entry.release_host()

        # release host memory
        host_indices_list = []
        for _ in range(n_release):
            host_indices_list.append(cc.host_mem_release_queue.get())
        if host_indices_list:
            host_indices = torch.cat(host_indices_list, dim=0)
            cc.mem_pool_host.free(host_indices)

    def can_terminate_prefetch(self, operation: PrefetchOperation):
        can_terminate = True

        if self.prefetch_stop_policy == "best_effort":
            return can_terminate

        if len(operation.hash_value) == 0:
            completed = False
        else:
            completed = (
                operation.completed_tokens
                == len(operation.hash_value) * self.page_size
            )

        if self.prefetch_stop_policy == "wait_complete":
            can_terminate = completed
        elif self.prefetch_stop_policy == "timeout":
            can_terminate = completed or (
                time.monotonic() - operation.start_time > self.prefetch_timeout
            )
        else:
            # unknown prefetch stop policy, just return True
            return True

        if self.tp_world_size > 1:
            can_terminate = torch.tensor(can_terminate, dtype=torch.int)
            torch.distributed.all_reduce(
                can_terminate,
                op=torch.distributed.ReduceOp.MIN,
                group=self.tp_group,
            )
            can_terminate = bool(can_terminate.item())

        return can_terminate

    def check_prefetch_progress(self, req_id: str) -> bool:
        if req_id not in self.ongoing_prefetch:
            # there is no ongoing prefetch for this request, or it has been revoked
            return True

        # todo: more policies for prefetch progress such as timeout
        # the current policy is to prefetch with best effort and terminate when queuing is over
        last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[
            req_id
        ]
        if operation.host_indices is None:
            # prefetch has not been issued due to insufficient host memory
            return True

        if not self.can_terminate_prefetch(operation):
            return False

        completed_tokens, hash_value = self.cache_controller.terminate_prefetch(
            operation
        )
        logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")

        min_completed_tokens = completed_tokens
        if self.tp_world_size > 1 and self.prefetch_stop_policy != "wait_complete":
            # synchronize TP workers to make the same update to the hiradix cache
            completed_tokens_tensor = torch.tensor(
                min_completed_tokens, dtype=torch.int
            )
            torch.distributed.all_reduce(
                completed_tokens_tensor,
                op=torch.distributed.ReduceOp.MIN,
                group=self.tp_group,
            )
            min_completed_tokens = completed_tokens_tensor.item()

        fetched_token_ids = token_ids[:min_completed_tokens]
        written_indices = host_indices[:min_completed_tokens]
        matched_length = self._insert_helper_host(
            last_host_node,
            fetched_token_ids,
            written_indices,
            hash_value[: min_completed_tokens // self.page_size],
        )

        if len(written_indices):
            self.cache_controller.mem_pool_host.update_prefetch(written_indices)

        self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
        self.cache_controller.append_host_mem_release(
            host_indices[min_completed_tokens:completed_tokens]
        )
        last_host_node.release_host()
        del self.ongoing_prefetch[req_id]
        self.cache_controller.prefetch_tokens_occupied -= len(token_ids)

        return True

    def match_prefix(self, key: List[int], **kwargs):
        empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
        if self.disable or len(key) == 0:
            return MatchResult(
                device_indices=empty_value,
                last_device_node=self.root_node,
                last_host_node=self.root_node,
                host_hit_length=0,
            )

        if self.page_size != 1:
            page_aligned_len = len(key) // self.page_size * self.page_size
            key = key[:page_aligned_len]

        value, last_node = self._match_prefix_helper(self.root_node, key)
        if value:
            value = torch.cat(value)
        else:
            value = empty_value

        host_hit_length = 0
        last_host_node = last_node
        while last_node.evicted:
            host_hit_length += len(last_node.host_value)
            last_node = last_node.parent
        while not last_host_node.backuped:
            last_host_node = last_host_node.parent

        return MatchResult(
            device_indices=value,
            last_device_node=last_node,
            last_host_node=last_host_node,
            host_hit_length=host_hit_length,
        )

    def prefetch_from_storage(
        self,
        req_id: str,
        last_host_node: TreeNode,
        new_input_tokens: List[int],
        last_hash: Optional[str] = None,
    ):
        # align the number of fetched tokens to the page size
        prefetch_length = len(new_input_tokens) - (
            len(new_input_tokens) % self.page_size
        )
        new_input_tokens = new_input_tokens[:prefetch_length]
        if (
            not self.enable_storage
            or prefetch_length < self.prefetch_threshold
            or self.cache_controller.prefetch_rate_limited()
        ):
            return

        last_host_node.protect_host()
        host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
        if host_indices is None:
            self.evict_host(prefetch_length)
            host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
        if host_indices is None:
            last_host_node.release_host()
            # insufficient host memory for prefetch
            return

        operation = self.cache_controller.prefetch(
            req_id, host_indices, new_input_tokens, last_hash
        )
        self.ongoing_prefetch[req_id] = (
            last_host_node,
            new_input_tokens,
            host_indices,
            operation,
        )
        self.cache_controller.prefetch_tokens_occupied += len(new_input_tokens)

    def _insert_helper_host(self, node: TreeNode, key: List, host_value, hash_value):
        node.last_access_time = time.monotonic()
        if len(key) == 0:
            return 0

        child_key = self.get_child_key_fn(key)

        matched_length = 0
        while len(key) > 0 and child_key in node.children.keys():
            node = node.children[child_key]
            node.last_access_time = time.monotonic()
            prefix_len = self.key_match_fn(node.key, key)

            key = key[prefix_len:]
            host_value = host_value[prefix_len:]
            hash_value = hash_value[prefix_len // self.page_size :]
            matched_length += prefix_len

            if prefix_len < len(node.key):
                new_node = self._split_node(node.key, node, prefix_len)
                node = new_node

            if len(key):
                child_key = self.get_child_key_fn(key)

        if len(key):
            new_node = TreeNode()
            new_node.parent = node
            new_node.key = key
            new_node.value = None
            new_node.host_value = host_value
            new_node.hash_value = hash_value
            node.children[child_key] = new_node
        return matched_length

    def _match_prefix_helper(self, node: TreeNode, key: List):
        node.last_access_time = time.monotonic()
        child_key = self.get_child_key_fn(key)
        value = []

        while len(key) > 0 and child_key in node.children.keys():
            child = node.children[child_key]
            child.last_access_time = time.monotonic()
            prefix_len = self.key_match_fn(child.key, key)
            if prefix_len < len(child.key):
                new_node = self._split_node(child.key, child, prefix_len)
                if not new_node.evicted:
                    value.append(new_node.value)
                node = new_node
                break
            else:
                if not child.evicted:
                    value.append(child.value)
                node = child
                key = key[prefix_len:]

                if len(key):
                    child_key = self.get_child_key_fn(key)

        return value, node

    def _split_node(self, key, child: TreeNode, split_len: int):
        # split the child node into new_node -> child
        new_node = TreeNode()
        new_node.children = {self.get_child_key_fn(key[split_len:]): child}
        new_node.parent = child.parent
        new_node.lock_ref = child.lock_ref
        new_node.key = child.key[:split_len]
        new_node.loading = child.loading
        new_node.hit_count = child.hit_count

        # split the device value and host value if they exist
        if child.evicted:
            new_node.value = None
        else:
            new_node.value = child.value[:split_len]
            child.value = child.value[split_len:]
        if child.backuped:
            new_node.host_value = child.host_value[:split_len]
            child.host_value = child.host_value[split_len:]

        if child.hash_value:
            new_node.hash_value = child.hash_value[: split_len // self.page_size]
            child.hash_value = child.hash_value[split_len // self.page_size :]

        child.parent = new_node
        child.key = child.key[split_len:]
        new_node.parent.children[self.get_child_key_fn(key)] = new_node
        return new_node

    def insert(self, key: List, value, chunked=False):
        if len(key) == 0:
            return 0

        node = self.root_node
        child_key = self.get_child_key_fn(key)
        total_prefix_length = 0

        while len(key) > 0 and child_key in node.children.keys():
            node = node.children[child_key]
            node.last_access_time = time.monotonic()
            prefix_len = self.key_match_fn(node.key, key)

            if prefix_len == len(node.key):
                if node.evicted:
                    # change the reference if the node is evicted
                    # this often happens in the case of KV cache recomputation
                    node.value = value[:prefix_len]
                    self.token_to_kv_pool_host.update_synced(node.host_value)
                    self.evictable_size_ += len(node.value)
                else:
                    self._inc_hit_count(node, chunked)
                    total_prefix_length += prefix_len
            else:
                # partial match, split the node
                new_node = self._split_node(node.key, node, prefix_len)
                if new_node.evicted:
                    new_node.value = value[:prefix_len]
                    self.token_to_kv_pool_host.update_synced(new_node.host_value)
                    self.evictable_size_ += len(new_node.value)
                else:
                    self._inc_hit_count(new_node, chunked)
                    total_prefix_length += prefix_len
                node = new_node

            key = key[prefix_len:]
            value = value[prefix_len:]

            if len(key):
                child_key = self.get_child_key_fn(key)

        if len(key):
            new_node = TreeNode()
            new_node.parent = node
            new_node.key = key
            new_node.value = value
            node.children[child_key] = new_node
            self.evictable_size_ += len(value)

            if self.enable_storage:
                last_hash = node.get_last_hash_value()
                assert (node == self.root_node) or (
                    last_hash is not None
                ), "Parent node must have a hash value with storage enabled"
                new_node.hash_value = []
                for idx in range(0, len(key), self.page_size):
                    new_node.hash_value.append(
                        self.cache_controller.get_hash_str(
                            key[idx : idx + self.page_size],
                            prior_hash=last_hash,
                        )
                    )
                    last_hash = new_node.hash_value[-1]

            if self.cache_controller.write_policy != "write_back":
                self._inc_hit_count(new_node, chunked)
        return total_prefix_length

    def _collect_leaves_device(self):
        def is_leaf(node):
            if node.evicted:
                return False
            if node == self.root_node:
                return False
            if len(node.children) == 0:
                return True
            for child in node.children.values():
                if not child.evicted:
                    return False
            return True

        ret_list = []
        stack = [self.root_node]
        while stack:
            cur_node = stack.pop()
            if is_leaf(cur_node):
                ret_list.append(cur_node)
            else:
                for cur_child in cur_node.children.values():
                    if not cur_child.evicted:
                        stack.append(cur_child)
        return ret_list

    def release_aborted_request(self, rid: str):
        if rid not in self.ongoing_prefetch:
            return

        last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[rid]
        if operation.host_indices is None:
            return

        completed_tokens, _ = self.cache_controller.terminate_prefetch(operation)
        if self.tp_world_size > 1:
            torch.distributed.barrier(group=self.tp_group)
        last_host_node.release_host()
        del self.ongoing_prefetch[rid]
        self.cache_controller.append_host_mem_release(host_indices[:completed_tokens])
        self.cache_controller.prefetch_tokens_occupied -= len(token_ids)