You need to sign in or sign up before continuing.
Unverified commit 08104b56, authored by Zhiqiang Xie, committed by GitHub
Browse files

Sanity check to prevent performance regression (#3171)


Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
parent cf142b6e
...@@ -149,6 +149,7 @@ class Scheduler: ...@@ -149,6 +149,7 @@ class Scheduler:
if not self.spec_algorithm.is_none() if not self.spec_algorithm.is_none()
else 1 else 1
) )
self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
# Distributed rank info # Distributed rank info
self.dp_size = server_args.dp_size self.dp_size = server_args.dp_size
...@@ -831,10 +832,16 @@ class Scheduler: ...@@ -831,10 +832,16 @@ class Scheduler:
available_size = ( available_size = (
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size() self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
) )
if available_size != self.max_total_num_tokens: protected_size = self.tree_cache.protected_size()
memory_leak = available_size != (
self.max_total_num_tokens
if not self.enable_hierarchical_cache
else self.max_total_num_tokens - protected_size
)
if memory_leak:
msg = ( msg = (
"KV cache pool leak detected!" "KV cache pool leak detected!"
f"{available_size=}, {self.max_total_num_tokens=}\n" f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
) )
warnings.warn(msg) warnings.warn(msg)
if crash_on_warnings(): if crash_on_warnings():
...@@ -949,7 +956,14 @@ class Scheduler: ...@@ -949,7 +956,14 @@ class Scheduler:
res = adder.add_one_req(req) res = adder.add_one_req(req)
if res != AddReqResult.CONTINUE: if res != AddReqResult.CONTINUE:
if res == AddReqResult.NO_TOKEN: if res == AddReqResult.NO_TOKEN:
self.batch_is_full = True if self.enable_hierarchical_cache:
# Set batch_is_full after making sure there are requests that can be served
self.batch_is_full = len(adder.can_run_list) > 0 or (
self.running_batch is not None
and not self.running_batch.is_empty()
)
else:
self.batch_is_full = True
break break
if self.server_args.prefill_only_one_req: if self.server_args.prefill_only_one_req:
break break
......
...@@ -41,6 +41,10 @@ class BasePrefixCache(ABC): ...@@ -41,6 +41,10 @@ class BasePrefixCache(ABC):
def evictable_size(self): def evictable_size(self):
pass pass
@abstractmethod
def protected_size(self):
raise NotImplementedError()
def total_size(self): def total_size(self):
raise NotImplementedError() raise NotImplementedError()
......
...@@ -85,3 +85,6 @@ class ChunkCache(BasePrefixCache): ...@@ -85,3 +85,6 @@ class ChunkCache(BasePrefixCache):
def evictable_size(self): def evictable_size(self):
return 0 return 0
def protected_size(self):
return 0
...@@ -34,7 +34,10 @@ if TYPE_CHECKING: ...@@ -34,7 +34,10 @@ if TYPE_CHECKING:
class TreeNode: class TreeNode:
def __init__(self):
counter = 0
def __init__(self, id: Optional[int] = None):
self.children = defaultdict(TreeNode) self.children = defaultdict(TreeNode)
self.parent = None self.parent = None
self.key = None self.key = None
...@@ -42,6 +45,23 @@ class TreeNode: ...@@ -42,6 +45,23 @@ class TreeNode:
self.lock_ref = 0 self.lock_ref = 0
self.last_access_time = time.time() self.last_access_time = time.time()
self.hit_count = 0
# indicating the node is loading KV cache from host
self.loading = False
# store the host indices of KV cache
self.host_value = None
self.id = TreeNode.counter if id is None else id
TreeNode.counter += 1
@property
def evicted(self):
return self.value is None
@property
def backuped(self):
return self.host_value is not None
def __lt__(self, other: "TreeNode"): def __lt__(self, other: "TreeNode"):
return self.last_access_time < other.last_access_time return self.last_access_time < other.last_access_time
...@@ -75,6 +95,7 @@ class RadixCache(BasePrefixCache): ...@@ -75,6 +95,7 @@ class RadixCache(BasePrefixCache):
self.root_node.value = [] self.root_node.value = []
self.root_node.lock_ref = 1 self.root_node.lock_ref = 1
self.evictable_size_ = 0 self.evictable_size_ = 0
self.protected_size_ = 0
def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]: def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]:
"""Find the matching prefix from the radix tree. """Find the matching prefix from the radix tree.
...@@ -203,6 +224,7 @@ class RadixCache(BasePrefixCache): ...@@ -203,6 +224,7 @@ class RadixCache(BasePrefixCache):
while node != self.root_node: while node != self.root_node:
if node.lock_ref == 0: if node.lock_ref == 0:
self.evictable_size_ -= len(node.value) self.evictable_size_ -= len(node.value)
self.protected_size_ += len(node.value)
delta -= len(node.value) delta -= len(node.value)
node.lock_ref += 1 node.lock_ref += 1
node = node.parent node = node.parent
...@@ -216,6 +238,7 @@ class RadixCache(BasePrefixCache): ...@@ -216,6 +238,7 @@ class RadixCache(BasePrefixCache):
while node != self.root_node: while node != self.root_node:
if node.lock_ref == 1: if node.lock_ref == 1:
self.evictable_size_ += len(node.value) self.evictable_size_ += len(node.value)
self.protected_size_ -= len(node.value)
delta += len(node.value) delta += len(node.value)
node.lock_ref -= 1 node.lock_ref -= 1
node = node.parent node = node.parent
...@@ -224,6 +247,10 @@ class RadixCache(BasePrefixCache): ...@@ -224,6 +247,10 @@ class RadixCache(BasePrefixCache):
def evictable_size(self): def evictable_size(self):
return self.evictable_size_ return self.evictable_size_
def protected_size(self):
# protected size refers to the size of the cache that is locked
return self.protected_size_
##### Internal Helper Functions ##### ##### Internal Helper Functions #####
def _match_prefix_helper( def _match_prefix_helper(
...@@ -303,6 +330,8 @@ class RadixCache(BasePrefixCache): ...@@ -303,6 +330,8 @@ class RadixCache(BasePrefixCache):
self.evictable_size_ -= len(node.key) self.evictable_size_ -= len(node.key)
def _total_size_helper(self, node: TreeNode): def _total_size_helper(self, node: TreeNode):
if node.evicted:
return 0
x = len(node.value) x = len(node.value)
for child in node.children.values(): for child in node.children.values():
x += self._total_size_helper(child) x += self._total_size_helper(child)
......
...@@ -163,6 +163,7 @@ class ServerArgs: ...@@ -163,6 +163,7 @@ class ServerArgs:
# Custom logit processor # Custom logit processor
enable_custom_logit_processor: bool = False enable_custom_logit_processor: bool = False
tool_call_parser: str = None tool_call_parser: str = None
enable_hierarchical_cache: bool = False
def __post_init__(self): def __post_init__(self):
# Set missing default values # Set missing default values
...@@ -892,6 +893,11 @@ class ServerArgs: ...@@ -892,6 +893,11 @@ class ServerArgs:
default=ServerArgs.tool_call_parser, default=ServerArgs.tool_call_parser,
help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', and 'llama3'.", help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', and 'llama3'.",
) )
parser.add_argument(
"--enable-hierarchical-cache",
action="store_true",
help="Enable hierarchical cache",
)
@classmethod @classmethod
def from_cli_args(cls, args: argparse.Namespace): def from_cli_args(cls, args: argparse.Namespace):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment