Unverified Commit 1ccd59c7 authored by Xuchun Shang's avatar Xuchun Shang Committed by GitHub
Browse files

[HICache] introduce evict policy (#10190)


Signed-off-by: default avatarXuchun Shang <xuchun.shang@linux.alibaba.com>
Co-authored-by: default avatarTeng Ma <sima.mt@alibaba-inc.com>
parent c32fb7a2
...@@ -667,6 +667,7 @@ class Scheduler( ...@@ -667,6 +667,7 @@ class Scheduler(
else self.tp_cpu_group else self.tp_cpu_group
), ),
page_size=self.page_size, page_size=self.page_size,
eviction_policy=server_args.radix_eviction_policy,
hicache_ratio=server_args.hicache_ratio, hicache_ratio=server_args.hicache_ratio,
hicache_size=server_args.hicache_size, hicache_size=server_args.hicache_size,
hicache_write_policy=server_args.hicache_write_policy, hicache_write_policy=server_args.hicache_write_policy,
...@@ -719,6 +720,7 @@ class Scheduler( ...@@ -719,6 +720,7 @@ class Scheduler(
tp_size=self.tp_size, tp_size=self.tp_size,
rank=self.tp_rank, rank=self.tp_rank,
tp_group=self.tp_group, tp_group=self.tp_group,
eviction_policy=server_args.radix_eviction_policy,
) )
else: else:
self.tree_cache = RadixCache( self.tree_cache = RadixCache(
...@@ -727,6 +729,7 @@ class Scheduler( ...@@ -727,6 +729,7 @@ class Scheduler(
page_size=self.page_size, page_size=self.page_size,
disable=server_args.disable_radix_cache, disable=server_args.disable_radix_cache,
enable_kv_cache_events=self.enable_kv_cache_events, enable_kv_cache_events=self.enable_kv_cache_events,
eviction_policy=server_args.radix_eviction_policy,
) )
self.decode_mem_cache_buf_multiplier = ( self.decode_mem_cache_buf_multiplier = (
......
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, List, Tuple, Union
if TYPE_CHECKING:
from sglang.srt.mem_cache.radix_cache import TreeNode
class EvictionStrategy(ABC):
@abstractmethod
def get_priority(self, node: "TreeNode") -> Union[float, Tuple]:
pass
class LRUStrategy(EvictionStrategy):
def get_priority(self, node: "TreeNode") -> float:
return node.last_access_time
class LFUStrategy(EvictionStrategy):
def get_priority(self, node: "TreeNode") -> Tuple[int, float]:
return (node.hit_count, node.last_access_time)
...@@ -39,6 +39,7 @@ class HiRadixCache(RadixCache): ...@@ -39,6 +39,7 @@ class HiRadixCache(RadixCache):
hicache_io_backend: str, hicache_io_backend: str,
hicache_mem_layout: str, hicache_mem_layout: str,
enable_metrics: bool, enable_metrics: bool,
eviction_policy: str = "lru",
hicache_storage_backend: Optional[str] = None, hicache_storage_backend: Optional[str] = None,
hicache_storage_prefetch_policy: Optional[str] = "best_effort", hicache_storage_prefetch_policy: Optional[str] = "best_effort",
model_name: Optional[str] = None, model_name: Optional[str] = None,
...@@ -117,8 +118,13 @@ class HiRadixCache(RadixCache): ...@@ -117,8 +118,13 @@ class HiRadixCache(RadixCache):
1 if hicache_write_policy == "write_through" else 2 1 if hicache_write_policy == "write_through" else 2
) )
self.load_back_threshold = 10 self.load_back_threshold = 10
super().__init__( super().__init__(
req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False req_to_token_pool,
token_to_kv_pool_allocator,
page_size,
disable=False,
eviction_policy=eviction_policy,
) )
def reset(self): def reset(self):
...@@ -258,12 +264,15 @@ class HiRadixCache(RadixCache): ...@@ -258,12 +264,15 @@ class HiRadixCache(RadixCache):
def evict(self, num_tokens: int): def evict(self, num_tokens: int):
leaves = self._collect_leaves_device() leaves = self._collect_leaves_device()
heapq.heapify(leaves) eviction_heap = [
(self.eviction_strategy.get_priority(node), node) for node in leaves
]
heapq.heapify(eviction_heap)
num_evicted = 0 num_evicted = 0
write_back_nodes = [] write_back_nodes = []
while num_evicted < num_tokens and len(leaves): while num_evicted < num_tokens and len(eviction_heap):
x = heapq.heappop(leaves) _priority, x = heapq.heappop(eviction_heap)
if x.lock_ref > 0: if x.lock_ref > 0:
continue continue
...@@ -285,7 +294,8 @@ class HiRadixCache(RadixCache): ...@@ -285,7 +294,8 @@ class HiRadixCache(RadixCache):
break break
else: else:
# all children are evicted or no children # all children are evicted or no children
heapq.heappush(leaves, x.parent) new_priority = self.eviction_strategy.get_priority(x.parent)
heapq.heappush(eviction_heap, (new_priority, x.parent))
if self.cache_controller.write_policy == "write_back": if self.cache_controller.write_policy == "write_back":
self.writing_check(write_back=True) self.writing_check(write_back=True)
...@@ -310,11 +320,14 @@ class HiRadixCache(RadixCache): ...@@ -310,11 +320,14 @@ class HiRadixCache(RadixCache):
def evict_host(self, num_tokens: int): def evict_host(self, num_tokens: int):
leaves = self._collect_leaves() leaves = self._collect_leaves()
heapq.heapify(leaves) eviction_heap = [
(self.eviction_strategy.get_priority(node), node) for node in leaves
]
heapq.heapify(eviction_heap)
num_evicted = 0 num_evicted = 0
while num_evicted < num_tokens and len(leaves): while num_evicted < num_tokens and len(eviction_heap):
x = heapq.heappop(leaves) _priority, x = heapq.heappop(eviction_heap)
if x == self.root_node: if x == self.root_node:
break break
# only evict the host value of evicted nodes # only evict the host value of evicted nodes
...@@ -333,7 +346,8 @@ class HiRadixCache(RadixCache): ...@@ -333,7 +346,8 @@ class HiRadixCache(RadixCache):
del x.parent.children[k] del x.parent.children[k]
if len(x.parent.children) == 0 and x.parent.evicted: if len(x.parent.children) == 0 and x.parent.evicted:
heapq.heappush(leaves, x.parent) new_priority = self.eviction_strategy.get_priority(x.parent)
heapq.heappush(eviction_heap, (new_priority, x.parent))
def load_back( def load_back(
self, node: TreeNode, mem_quota: Optional[int] = None self, node: TreeNode, mem_quota: Optional[int] = None
......
...@@ -34,6 +34,7 @@ from sglang.srt.disaggregation.kv_events import ( ...@@ -34,6 +34,7 @@ from sglang.srt.disaggregation.kv_events import (
) )
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
from sglang.srt.mem_cache.evict_policy import EvictionStrategy, LFUStrategy, LRUStrategy
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -122,6 +123,7 @@ class RadixCache(BasePrefixCache): ...@@ -122,6 +123,7 @@ class RadixCache(BasePrefixCache):
page_size: int, page_size: int,
disable: bool = False, disable: bool = False,
enable_kv_cache_events: bool = False, enable_kv_cache_events: bool = False,
eviction_policy: str = "lru",
): ):
self.req_to_token_pool = req_to_token_pool self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
...@@ -141,6 +143,15 @@ class RadixCache(BasePrefixCache): ...@@ -141,6 +143,15 @@ class RadixCache(BasePrefixCache):
else: else:
self.key_match_fn = partial(_key_match_paged, page_size=page_size) self.key_match_fn = partial(_key_match_paged, page_size=page_size)
self.get_child_key_fn = lambda key: tuple(key[:page_size]) self.get_child_key_fn = lambda key: tuple(key[:page_size])
if eviction_policy.lower() == "lru":
self.eviction_strategy: EvictionStrategy = LRUStrategy()
elif eviction_policy.lower() == "lfu":
self.eviction_strategy: EvictionStrategy = LFUStrategy()
else:
raise ValueError(
f"Unknown eviction policy: {eviction_policy}. Supported policies: 'lru', 'lfu'."
)
self.reset() self.reset()
##### Public API ##### ##### Public API #####
...@@ -296,11 +307,14 @@ class RadixCache(BasePrefixCache): ...@@ -296,11 +307,14 @@ class RadixCache(BasePrefixCache):
return return
leaves = self._collect_leaves() leaves = self._collect_leaves()
heapq.heapify(leaves) eviction_heap = [
(self.eviction_strategy.get_priority(node), node) for node in leaves
]
heapq.heapify(eviction_heap)
num_evicted = 0 num_evicted = 0
while num_evicted < num_tokens and len(leaves): while num_evicted < num_tokens and len(eviction_heap):
x = heapq.heappop(leaves) _priority, x = heapq.heappop(eviction_heap)
if x == self.root_node: if x == self.root_node:
break break
...@@ -312,7 +326,8 @@ class RadixCache(BasePrefixCache): ...@@ -312,7 +326,8 @@ class RadixCache(BasePrefixCache):
self._delete_leaf(x) self._delete_leaf(x)
if len(x.parent.children) == 0: if len(x.parent.children) == 0:
heapq.heappush(leaves, x.parent) new_priority = self.eviction_strategy.get_priority(x.parent)
heapq.heappush(eviction_heap, (new_priority, x.parent))
self._record_remove_event(x) self._record_remove_event(x)
......
...@@ -78,6 +78,7 @@ class LMCRadixCache(RadixCache): ...@@ -78,6 +78,7 @@ class LMCRadixCache(RadixCache):
tp_size: int = 1, tp_size: int = 1,
rank: int = 0, rank: int = 0,
tp_group: Optional[torch.distributed.ProcessGroup] = None, tp_group: Optional[torch.distributed.ProcessGroup] = None,
eviction_policy: str = "lru",
): ):
super().__init__( super().__init__(
req_to_token_pool=req_to_token_pool, req_to_token_pool=req_to_token_pool,
...@@ -85,6 +86,7 @@ class LMCRadixCache(RadixCache): ...@@ -85,6 +86,7 @@ class LMCRadixCache(RadixCache):
page_size=page_size, page_size=page_size,
disable=disable, disable=disable,
enable_kv_cache_events=enable_kv_cache_events, enable_kv_cache_events=enable_kv_cache_events,
eviction_policy=eviction_policy,
) )
kvcache = self.token_to_kv_pool_allocator.get_kvcache() kvcache = self.token_to_kv_pool_allocator.get_kvcache()
......
...@@ -185,6 +185,7 @@ class ServerArgs: ...@@ -185,6 +185,7 @@ class ServerArgs:
hybrid_kvcache_ratio: Optional[float] = None hybrid_kvcache_ratio: Optional[float] = None
swa_full_tokens_ratio: float = 0.8 swa_full_tokens_ratio: float = 0.8
disable_hybrid_swa_memory: bool = False disable_hybrid_swa_memory: bool = False
radix_eviction_policy: str = "lru"
# Runtime options # Runtime options
device: Optional[str] = None device: Optional[str] = None
...@@ -1907,6 +1908,13 @@ class ServerArgs: ...@@ -1907,6 +1908,13 @@ class ServerArgs:
default=ServerArgs.hicache_write_policy, default=ServerArgs.hicache_write_policy,
help="The write policy of hierarchical cache.", help="The write policy of hierarchical cache.",
) )
parser.add_argument(
"--radix-eviction-policy",
type=str,
choices=["lru", "lfu"],
default=ServerArgs.radix_eviction_policy,
help="The eviction policy of radix trees. 'lru' stands for Least Recently Used, 'lfu' stands for Least Frequently Used.",
)
parser.add_argument( parser.add_argument(
"--hicache-io-backend", "--hicache-io-backend",
type=str, type=str,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment