Unverified Commit 08104b56 authored by Zhiqiang Xie's avatar Zhiqiang Xie Committed by GitHub
Browse files

Sanity check to prevent performance regression (#3171)


Co-authored-by: default avatarLianmin Zheng <lianminzheng@gmail.com>
parent cf142b6e
...@@ -149,6 +149,7 @@ class Scheduler: ...@@ -149,6 +149,7 @@ class Scheduler:
if not self.spec_algorithm.is_none() if not self.spec_algorithm.is_none()
else 1 else 1
) )
self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
# Distributed rank info # Distributed rank info
self.dp_size = server_args.dp_size self.dp_size = server_args.dp_size
...@@ -831,10 +832,16 @@ class Scheduler: ...@@ -831,10 +832,16 @@ class Scheduler:
available_size = ( available_size = (
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size() self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
) )
if available_size != self.max_total_num_tokens: protected_size = self.tree_cache.protected_size()
memory_leak = available_size != (
self.max_total_num_tokens
if not self.enable_hierarchical_cache
else self.max_total_num_tokens - protected_size
)
if memory_leak:
msg = ( msg = (
"KV cache pool leak detected!" "KV cache pool leak detected!"
f"{available_size=}, {self.max_total_num_tokens=}\n" f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
) )
warnings.warn(msg) warnings.warn(msg)
if crash_on_warnings(): if crash_on_warnings():
...@@ -949,7 +956,14 @@ class Scheduler: ...@@ -949,7 +956,14 @@ class Scheduler:
res = adder.add_one_req(req) res = adder.add_one_req(req)
if res != AddReqResult.CONTINUE: if res != AddReqResult.CONTINUE:
if res == AddReqResult.NO_TOKEN: if res == AddReqResult.NO_TOKEN:
self.batch_is_full = True if self.enable_hierarchical_cache:
# Set batch_is_full after making sure there are requests that can be served
self.batch_is_full = len(adder.can_run_list) > 0 or (
self.running_batch is not None
and not self.running_batch.is_empty()
)
else:
self.batch_is_full = True
break break
if self.server_args.prefill_only_one_req: if self.server_args.prefill_only_one_req:
break break
......
...@@ -41,6 +41,10 @@ class BasePrefixCache(ABC): ...@@ -41,6 +41,10 @@ class BasePrefixCache(ABC):
def evictable_size(self): def evictable_size(self):
pass pass
@abstractmethod
def protected_size(self):
raise NotImplementedError()
def total_size(self): def total_size(self):
raise NotImplementedError() raise NotImplementedError()
......
...@@ -85,3 +85,6 @@ class ChunkCache(BasePrefixCache): ...@@ -85,3 +85,6 @@ class ChunkCache(BasePrefixCache):
def evictable_size(self): def evictable_size(self):
return 0 return 0
def protected_size(self):
return 0
...@@ -34,7 +34,10 @@ if TYPE_CHECKING: ...@@ -34,7 +34,10 @@ if TYPE_CHECKING:
class TreeNode: class TreeNode:
def __init__(self):
counter = 0
def __init__(self, id: Optional[int] = None):
self.children = defaultdict(TreeNode) self.children = defaultdict(TreeNode)
self.parent = None self.parent = None
self.key = None self.key = None
...@@ -42,6 +45,23 @@ class TreeNode: ...@@ -42,6 +45,23 @@ class TreeNode:
self.lock_ref = 0 self.lock_ref = 0
self.last_access_time = time.time() self.last_access_time = time.time()
self.hit_count = 0
# indicating the node is loading KV cache from host
self.loading = False
# store the host indices of KV cache
self.host_value = None
self.id = TreeNode.counter if id is None else id
TreeNode.counter += 1
@property
def evicted(self):
return self.value is None
@property
def backuped(self):
return self.host_value is not None
def __lt__(self, other: "TreeNode"): def __lt__(self, other: "TreeNode"):
return self.last_access_time < other.last_access_time return self.last_access_time < other.last_access_time
...@@ -75,6 +95,7 @@ class RadixCache(BasePrefixCache): ...@@ -75,6 +95,7 @@ class RadixCache(BasePrefixCache):
self.root_node.value = [] self.root_node.value = []
self.root_node.lock_ref = 1 self.root_node.lock_ref = 1
self.evictable_size_ = 0 self.evictable_size_ = 0
self.protected_size_ = 0
def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]: def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]:
"""Find the matching prefix from the radix tree. """Find the matching prefix from the radix tree.
...@@ -203,6 +224,7 @@ class RadixCache(BasePrefixCache): ...@@ -203,6 +224,7 @@ class RadixCache(BasePrefixCache):
while node != self.root_node: while node != self.root_node:
if node.lock_ref == 0: if node.lock_ref == 0:
self.evictable_size_ -= len(node.value) self.evictable_size_ -= len(node.value)
self.protected_size_ += len(node.value)
delta -= len(node.value) delta -= len(node.value)
node.lock_ref += 1 node.lock_ref += 1
node = node.parent node = node.parent
...@@ -216,6 +238,7 @@ class RadixCache(BasePrefixCache): ...@@ -216,6 +238,7 @@ class RadixCache(BasePrefixCache):
while node != self.root_node: while node != self.root_node:
if node.lock_ref == 1: if node.lock_ref == 1:
self.evictable_size_ += len(node.value) self.evictable_size_ += len(node.value)
self.protected_size_ -= len(node.value)
delta += len(node.value) delta += len(node.value)
node.lock_ref -= 1 node.lock_ref -= 1
node = node.parent node = node.parent
...@@ -224,6 +247,10 @@ class RadixCache(BasePrefixCache): ...@@ -224,6 +247,10 @@ class RadixCache(BasePrefixCache):
def evictable_size(self): def evictable_size(self):
return self.evictable_size_ return self.evictable_size_
def protected_size(self):
# protected size refers to the size of the cache that is locked
return self.protected_size_
##### Internal Helper Functions ##### ##### Internal Helper Functions #####
def _match_prefix_helper( def _match_prefix_helper(
...@@ -303,6 +330,8 @@ class RadixCache(BasePrefixCache): ...@@ -303,6 +330,8 @@ class RadixCache(BasePrefixCache):
self.evictable_size_ -= len(node.key) self.evictable_size_ -= len(node.key)
def _total_size_helper(self, node: TreeNode): def _total_size_helper(self, node: TreeNode):
if node.evicted:
return 0
x = len(node.value) x = len(node.value)
for child in node.children.values(): for child in node.children.values():
x += self._total_size_helper(child) x += self._total_size_helper(child)
......
...@@ -163,6 +163,7 @@ class ServerArgs: ...@@ -163,6 +163,7 @@ class ServerArgs:
# Custom logit processor # Custom logit processor
enable_custom_logit_processor: bool = False enable_custom_logit_processor: bool = False
tool_call_parser: str = None tool_call_parser: str = None
enable_hierarchical_cache: bool = False
def __post_init__(self): def __post_init__(self):
# Set missing default values # Set missing default values
...@@ -892,6 +893,11 @@ class ServerArgs: ...@@ -892,6 +893,11 @@ class ServerArgs:
default=ServerArgs.tool_call_parser, default=ServerArgs.tool_call_parser,
help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', and 'llama3'.", help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', and 'llama3'.",
) )
parser.add_argument(
"--enable-hierarchical-cache",
action="store_true",
help="Enable hierarchical cache",
)
@classmethod @classmethod
def from_cli_args(cls, args: argparse.Namespace): def from_cli_args(cls, args: argparse.Namespace):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment