Unverified commit 1a31229c, authored by Alex Chi Z and committed by GitHub
Browse files

fix: radix cache memory accounting (#10637)


Signed-off-by: Alex Chi Z <iskyzh@gmail.com>
parent de89ef49
...@@ -267,7 +267,7 @@ class RadixCache(BasePrefixCache): ...@@ -267,7 +267,7 @@ class RadixCache(BasePrefixCache):
""" """
key.token_ids = self.key_convert_fn(key.token_ids) key.token_ids = self.key_convert_fn(key.token_ids)
if self.disable or len(key) == 0: def empty_match_result():
return MatchResult( return MatchResult(
device_indices=torch.empty( device_indices=torch.empty(
(0,), (0,),
...@@ -278,10 +278,16 @@ class RadixCache(BasePrefixCache): ...@@ -278,10 +278,16 @@ class RadixCache(BasePrefixCache):
last_host_node=self.root_node, last_host_node=self.root_node,
) )
if self.disable or len(key) == 0:
return empty_match_result()
if self.page_size != 1: if self.page_size != 1:
page_aligned_len = len(key) // self.page_size * self.page_size page_aligned_len = len(key) // self.page_size * self.page_size
key = key[:page_aligned_len] key = key[:page_aligned_len]
if len(key) == 0:
return empty_match_result()
value, last_node = self._match_prefix_helper(self.root_node, key) value, last_node = self._match_prefix_helper(self.root_node, key)
if value: if value:
value = torch.cat(value) value = torch.cat(value)
...@@ -475,9 +481,9 @@ class RadixCache(BasePrefixCache): ...@@ -475,9 +481,9 @@ class RadixCache(BasePrefixCache):
delta = 0 delta = 0
while node != self.root_node: while node != self.root_node:
if node.lock_ref == 0: if node.lock_ref == 0:
self.evictable_size_ -= len(node.value) self.evictable_size_ -= len(node.key)
self.protected_size_ += len(node.value) self.protected_size_ += len(node.key)
delta -= len(node.value) delta -= len(node.key)
node.lock_ref += 1 node.lock_ref += 1
node = node.parent node = node.parent
return delta return delta
...@@ -489,9 +495,9 @@ class RadixCache(BasePrefixCache): ...@@ -489,9 +495,9 @@ class RadixCache(BasePrefixCache):
delta = 0 delta = 0
while node != self.root_node: while node != self.root_node:
if node.lock_ref == 1: if node.lock_ref == 1:
self.evictable_size_ += len(node.value) self.evictable_size_ += len(node.key)
self.protected_size_ -= len(node.value) self.protected_size_ -= len(node.key)
delta += len(node.value) delta += len(node.key)
node.lock_ref -= 1 node.lock_ref -= 1
node = node.parent node = node.parent
return delta return delta
...@@ -589,7 +595,7 @@ class RadixCache(BasePrefixCache): ...@@ -589,7 +595,7 @@ class RadixCache(BasePrefixCache):
new_node.key = key new_node.key = key
new_node.value = value new_node.value = value
node.children[child_key] = new_node node.children[child_key] = new_node
self.evictable_size_ += len(value) self.evictable_size_ += len(key)
self._record_store_event(new_node) self._record_store_event(new_node)
return total_prefix_length return total_prefix_length
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment