Unverified Commit 08104b56 authored by Zhiqiang Xie's avatar Zhiqiang Xie Committed by GitHub
Browse files

Sanity check to prevent performance regression (#3171)


Co-authored-by: default avatarLianmin Zheng <lianminzheng@gmail.com>
parent cf142b6e
......@@ -149,6 +149,7 @@ class Scheduler:
if not self.spec_algorithm.is_none()
else 1
)
self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
# Distributed rank info
self.dp_size = server_args.dp_size
......@@ -831,10 +832,16 @@ class Scheduler:
available_size = (
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
)
if available_size != self.max_total_num_tokens:
protected_size = self.tree_cache.protected_size()
memory_leak = available_size != (
self.max_total_num_tokens
if not self.enable_hierarchical_cache
else self.max_total_num_tokens - protected_size
)
if memory_leak:
msg = (
"KV cache pool leak detected!"
f"{available_size=}, {self.max_total_num_tokens=}\n"
f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
)
warnings.warn(msg)
if crash_on_warnings():
......@@ -949,7 +956,14 @@ class Scheduler:
res = adder.add_one_req(req)
if res != AddReqResult.CONTINUE:
if res == AddReqResult.NO_TOKEN:
self.batch_is_full = True
if self.enable_hierarchical_cache:
# Set batch_is_full after making sure there are requests that can be served
self.batch_is_full = len(adder.can_run_list) > 0 or (
self.running_batch is not None
and not self.running_batch.is_empty()
)
else:
self.batch_is_full = True
break
if self.server_args.prefill_only_one_req:
break
......
......@@ -41,6 +41,10 @@ class BasePrefixCache(ABC):
def evictable_size(self):
pass
@abstractmethod
def protected_size(self):
raise NotImplementedError()
def total_size(self):
raise NotImplementedError()
......
......@@ -85,3 +85,6 @@ class ChunkCache(BasePrefixCache):
def evictable_size(self):
return 0
def protected_size(self):
return 0
......@@ -34,7 +34,10 @@ if TYPE_CHECKING:
class TreeNode:
def __init__(self):
counter = 0
def __init__(self, id: Optional[int] = None):
self.children = defaultdict(TreeNode)
self.parent = None
self.key = None
......@@ -42,6 +45,23 @@ class TreeNode:
self.lock_ref = 0
self.last_access_time = time.time()
self.hit_count = 0
# indicating the node is loading KV cache from host
self.loading = False
# store the host indices of KV cache
self.host_value = None
self.id = TreeNode.counter if id is None else id
TreeNode.counter += 1
@property
def evicted(self):
return self.value is None
@property
def backuped(self):
return self.host_value is not None
def __lt__(self, other: "TreeNode"):
return self.last_access_time < other.last_access_time
......@@ -75,6 +95,7 @@ class RadixCache(BasePrefixCache):
self.root_node.value = []
self.root_node.lock_ref = 1
self.evictable_size_ = 0
self.protected_size_ = 0
def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]:
"""Find the matching prefix from the radix tree.
......@@ -203,6 +224,7 @@ class RadixCache(BasePrefixCache):
while node != self.root_node:
if node.lock_ref == 0:
self.evictable_size_ -= len(node.value)
self.protected_size_ += len(node.value)
delta -= len(node.value)
node.lock_ref += 1
node = node.parent
......@@ -216,6 +238,7 @@ class RadixCache(BasePrefixCache):
while node != self.root_node:
if node.lock_ref == 1:
self.evictable_size_ += len(node.value)
self.protected_size_ -= len(node.value)
delta += len(node.value)
node.lock_ref -= 1
node = node.parent
......@@ -224,6 +247,10 @@ class RadixCache(BasePrefixCache):
def evictable_size(self):
return self.evictable_size_
def protected_size(self):
# protected size refers to the size of the cache that is locked
return self.protected_size_
##### Internal Helper Functions #####
def _match_prefix_helper(
......@@ -303,6 +330,8 @@ class RadixCache(BasePrefixCache):
self.evictable_size_ -= len(node.key)
def _total_size_helper(self, node: TreeNode):
if node.evicted:
return 0
x = len(node.value)
for child in node.children.values():
x += self._total_size_helper(child)
......
......@@ -163,6 +163,7 @@ class ServerArgs:
# Custom logit processor
enable_custom_logit_processor: bool = False
tool_call_parser: str = None
enable_hierarchical_cache: bool = False
def __post_init__(self):
# Set missing default values
......@@ -892,6 +893,11 @@ class ServerArgs:
default=ServerArgs.tool_call_parser,
help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', and 'llama3'.",
)
parser.add_argument(
"--enable-hierarchical-cache",
action="store_true",
help="Enable hierarchical cache",
)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment