Unverified Commit 62b362b1 authored by luzengxiangcn's avatar luzengxiangcn Committed by GitHub
Browse files

Debug radixcache: refactor recursive helper methods (#3029)


Co-authored-by: default avatarZhiqiang Xie <xiezhq@stanford.edu>
parent 44d76463
...@@ -112,14 +112,12 @@ class RadixCache(BasePrefixCache): ...@@ -112,14 +112,12 @@ class RadixCache(BasePrefixCache):
if self.disable: if self.disable:
return [], self.root_node return [], self.root_node
value = [] value, last_node = self._match_prefix_helper(self.root_node, key)
last_node = [self.root_node]
self._match_prefix_helper(self.root_node, key, value, last_node)
if value: if value:
value = torch.concat(value) value = torch.concat(value)
else: else:
value = torch.tensor([], dtype=torch.int32) value = torch.tensor([], dtype=torch.int32)
return value, last_node[0] return value, last_node
def insert(self, key: List, value=None): def insert(self, key: List, value=None):
if self.disable: if self.disable:
...@@ -196,7 +194,7 @@ class RadixCache(BasePrefixCache): ...@@ -196,7 +194,7 @@ class RadixCache(BasePrefixCache):
print(f"#tokens: {self.total_size()}") print(f"#tokens: {self.total_size()}")
def total_size(self): def total_size(self):
return self._total_size_helper(self.root_node) return self._total_size_helper()
def evict(self, num_tokens: int, evict_callback: Callable): def evict(self, num_tokens: int, evict_callback: Callable):
if self.disable: if self.disable:
...@@ -258,24 +256,23 @@ class RadixCache(BasePrefixCache): ...@@ -258,24 +256,23 @@ class RadixCache(BasePrefixCache):
##### Internal Helper Functions ##### ##### Internal Helper Functions #####
def _match_prefix_helper( def _match_prefix_helper(self, node: TreeNode, key: List):
self, node: TreeNode, key: List, value, last_node: TreeNode
):
node.last_access_time = time.time() node.last_access_time = time.time()
if len(key) == 0: value = []
return while len(key) > 0 and key[0] in node.children.keys():
if key[0] in node.children.keys():
child = node.children[key[0]] child = node.children[key[0]]
child.last_access_time = time.time()
prefix_len = _key_match(child.key, key) prefix_len = _key_match(child.key, key)
if prefix_len < len(child.key): if prefix_len < len(child.key):
new_node = self._split_node(child.key, child, prefix_len) new_node = self._split_node(child.key, child, prefix_len)
value.append(new_node.value) value.append(new_node.value)
last_node[0] = new_node node = new_node
break
else: else:
value.append(child.value) value.append(child.value)
last_node[0] = child node = child
self._match_prefix_helper(child, key[prefix_len:], value, last_node) key = key[prefix_len:]
return value, node
def _split_node(self, key, child: TreeNode, split_len: int): def _split_node(self, key, child: TreeNode, split_len: int):
# new_node -> child # new_node -> child
...@@ -296,22 +293,18 @@ class RadixCache(BasePrefixCache): ...@@ -296,22 +293,18 @@ class RadixCache(BasePrefixCache):
if len(key) == 0: if len(key) == 0:
return 0 return 0
if key[0] in node.children.keys(): total_prefix_length = 0
child = node.children[key[0]] while len(key) > 0 and key[0] in node.children.keys():
prefix_len = _key_match(child.key, key) node = node.children[key[0]]
node.last_access_time = time.time()
prefix_len = _key_match(node.key, key)
total_prefix_length += prefix_len
key = key[prefix_len:]
value = value[prefix_len:]
if prefix_len == len(child.key): if prefix_len < len(node.key):
if prefix_len == len(key): new_node = self._split_node(node.key, node, prefix_len)
return prefix_len node = new_node
else:
key = key[prefix_len:]
value = value[prefix_len:]
return prefix_len + self._insert_helper(child, key, value)
new_node = self._split_node(child.key, child, prefix_len)
return prefix_len + self._insert_helper(
new_node, key[prefix_len:], value[prefix_len:]
)
if len(key): if len(key):
new_node = TreeNode() new_node = TreeNode()
...@@ -320,12 +313,21 @@ class RadixCache(BasePrefixCache): ...@@ -320,12 +313,21 @@ class RadixCache(BasePrefixCache):
new_node.value = value new_node.value = value
node.children[key[0]] = new_node node.children[key[0]] = new_node
self.evictable_size_ += len(value) self.evictable_size_ += len(value)
return 0 return total_prefix_length
def _print_helper(self, node: TreeNode, indent: int): def _print_helper(self, node: TreeNode, indent: int):
for _, child in node.children.items(): """Prints the radix tree in a human-readable format."""
print(" " * indent, len(child.key), child.key[:10], f"r={child.lock_ref}") stack = [(node, indent)]
self._print_helper(child, indent=indent + 2) while stack:
current_node, current_indent = stack.pop()
print(
" " * current_indent,
len(current_node.key),
current_node.key[:10],
f"r={current_node.lock_ref}",
)
for _, child in current_node.children.items():
stack.append((child, current_indent + 2))
def _delete_leaf(self, node): def _delete_leaf(self, node):
for k, v in node.parent.children.items(): for k, v in node.parent.children.items():
...@@ -334,13 +336,17 @@ class RadixCache(BasePrefixCache): ...@@ -334,13 +336,17 @@ class RadixCache(BasePrefixCache):
del node.parent.children[k] del node.parent.children[k]
self.evictable_size_ -= len(node.key) self.evictable_size_ -= len(node.key)
def _total_size_helper(self, node: TreeNode): def _total_size_helper(self):
if node.evicted: total_size = 0
return 0 stack = [self.root_node]
x = len(node.value) while stack:
for child in node.children.values(): current_node = stack.pop()
x += self._total_size_helper(child) total_size += len(current_node.value)
return x for child in current_node.children.values():
if child.evicted:
continue
stack.append(child)
return total_size
def _collect_leaves(self): def _collect_leaves(self):
ret_list = [] ret_list = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment