Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
62b362b1
Unverified
Commit
62b362b1
authored
Mar 06, 2025
by
luzengxiangcn
Committed by
GitHub
Mar 05, 2025
Browse files
Debug radixcache: refactor recursive helper methods (#3029)
Co-authored-by:
Zhiqiang Xie
<
xiezhq@stanford.edu
>
parent
44d76463
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
47 additions
and
41 deletions
+47
-41
python/sglang/srt/mem_cache/radix_cache.py
python/sglang/srt/mem_cache/radix_cache.py
+47
-41
No files found.
python/sglang/srt/mem_cache/radix_cache.py
View file @
62b362b1
...
@@ -112,14 +112,12 @@ class RadixCache(BasePrefixCache):
...
@@ -112,14 +112,12 @@ class RadixCache(BasePrefixCache):
if
self
.
disable
:
if
self
.
disable
:
return
[],
self
.
root_node
return
[],
self
.
root_node
value
=
[]
value
,
last_node
=
self
.
_match_prefix_helper
(
self
.
root_node
,
key
)
last_node
=
[
self
.
root_node
]
self
.
_match_prefix_helper
(
self
.
root_node
,
key
,
value
,
last_node
)
if
value
:
if
value
:
value
=
torch
.
concat
(
value
)
value
=
torch
.
concat
(
value
)
else
:
else
:
value
=
torch
.
tensor
([],
dtype
=
torch
.
int32
)
value
=
torch
.
tensor
([],
dtype
=
torch
.
int32
)
return
value
,
last_node
[
0
]
return
value
,
last_node
def
insert
(
self
,
key
:
List
,
value
=
None
):
def
insert
(
self
,
key
:
List
,
value
=
None
):
if
self
.
disable
:
if
self
.
disable
:
...
@@ -196,7 +194,7 @@ class RadixCache(BasePrefixCache):
...
@@ -196,7 +194,7 @@ class RadixCache(BasePrefixCache):
print
(
f
"#tokens:
{
self
.
total_size
()
}
"
)
print
(
f
"#tokens:
{
self
.
total_size
()
}
"
)
def
total_size
(
self
):
def
total_size
(
self
):
return
self
.
_total_size_helper
(
self
.
root_node
)
return
self
.
_total_size_helper
()
def
evict
(
self
,
num_tokens
:
int
,
evict_callback
:
Callable
):
def
evict
(
self
,
num_tokens
:
int
,
evict_callback
:
Callable
):
if
self
.
disable
:
if
self
.
disable
:
...
@@ -258,24 +256,23 @@ class RadixCache(BasePrefixCache):
...
@@ -258,24 +256,23 @@ class RadixCache(BasePrefixCache):
##### Internal Helper Functions #####
##### Internal Helper Functions #####
def
_match_prefix_helper
(
def
_match_prefix_helper
(
self
,
node
:
TreeNode
,
key
:
List
):
self
,
node
:
TreeNode
,
key
:
List
,
value
,
last_node
:
TreeNode
):
node
.
last_access_time
=
time
.
time
()
node
.
last_access_time
=
time
.
time
()
if
len
(
key
)
==
0
:
value
=
[]
return
while
len
(
key
)
>
0
and
key
[
0
]
in
node
.
children
.
keys
():
if
key
[
0
]
in
node
.
children
.
keys
():
child
=
node
.
children
[
key
[
0
]]
child
=
node
.
children
[
key
[
0
]]
child
.
last_access_time
=
time
.
time
()
prefix_len
=
_key_match
(
child
.
key
,
key
)
prefix_len
=
_key_match
(
child
.
key
,
key
)
if
prefix_len
<
len
(
child
.
key
):
if
prefix_len
<
len
(
child
.
key
):
new_node
=
self
.
_split_node
(
child
.
key
,
child
,
prefix_len
)
new_node
=
self
.
_split_node
(
child
.
key
,
child
,
prefix_len
)
value
.
append
(
new_node
.
value
)
value
.
append
(
new_node
.
value
)
last_node
[
0
]
=
new_node
node
=
new_node
break
else
:
else
:
value
.
append
(
child
.
value
)
value
.
append
(
child
.
value
)
last_node
[
0
]
=
child
node
=
child
self
.
_match_prefix_helper
(
child
,
key
[
prefix_len
:],
value
,
last_node
)
key
=
key
[
prefix_len
:]
return
value
,
node
def
_split_node
(
self
,
key
,
child
:
TreeNode
,
split_len
:
int
):
def
_split_node
(
self
,
key
,
child
:
TreeNode
,
split_len
:
int
):
# new_node -> child
# new_node -> child
...
@@ -296,22 +293,18 @@ class RadixCache(BasePrefixCache):
...
@@ -296,22 +293,18 @@ class RadixCache(BasePrefixCache):
if
len
(
key
)
==
0
:
if
len
(
key
)
==
0
:
return
0
return
0
if
key
[
0
]
in
node
.
children
.
keys
():
total_prefix_length
=
0
child
=
node
.
children
[
key
[
0
]]
while
len
(
key
)
>
0
and
key
[
0
]
in
node
.
children
.
keys
():
prefix_len
=
_key_match
(
child
.
key
,
key
)
node
=
node
.
children
[
key
[
0
]]
node
.
last_access_time
=
time
.
time
()
prefix_len
=
_key_match
(
node
.
key
,
key
)
total_prefix_length
+=
prefix_len
key
=
key
[
prefix_len
:]
value
=
value
[
prefix_len
:]
if
prefix_len
==
len
(
child
.
key
):
if
prefix_len
<
len
(
node
.
key
):
if
prefix_len
==
len
(
key
):
new_node
=
self
.
_split_node
(
node
.
key
,
node
,
prefix_len
)
return
prefix_len
node
=
new_node
else
:
key
=
key
[
prefix_len
:]
value
=
value
[
prefix_len
:]
return
prefix_len
+
self
.
_insert_helper
(
child
,
key
,
value
)
new_node
=
self
.
_split_node
(
child
.
key
,
child
,
prefix_len
)
return
prefix_len
+
self
.
_insert_helper
(
new_node
,
key
[
prefix_len
:],
value
[
prefix_len
:]
)
if
len
(
key
):
if
len
(
key
):
new_node
=
TreeNode
()
new_node
=
TreeNode
()
...
@@ -320,12 +313,21 @@ class RadixCache(BasePrefixCache):
...
@@ -320,12 +313,21 @@ class RadixCache(BasePrefixCache):
new_node
.
value
=
value
new_node
.
value
=
value
node
.
children
[
key
[
0
]]
=
new_node
node
.
children
[
key
[
0
]]
=
new_node
self
.
evictable_size_
+=
len
(
value
)
self
.
evictable_size_
+=
len
(
value
)
return
0
return
total_prefix_length
def
_print_helper
(
self
,
node
:
TreeNode
,
indent
:
int
):
def
_print_helper
(
self
,
node
:
TreeNode
,
indent
:
int
):
for
_
,
child
in
node
.
children
.
items
():
"""Prints the radix tree in a human-readable format."""
print
(
" "
*
indent
,
len
(
child
.
key
),
child
.
key
[:
10
],
f
"r=
{
child
.
lock_ref
}
"
)
stack
=
[(
node
,
indent
)]
self
.
_print_helper
(
child
,
indent
=
indent
+
2
)
while
stack
:
current_node
,
current_indent
=
stack
.
pop
()
print
(
" "
*
current_indent
,
len
(
current_node
.
key
),
current_node
.
key
[:
10
],
f
"r=
{
current_node
.
lock_ref
}
"
,
)
for
_
,
child
in
current_node
.
children
.
items
():
stack
.
append
((
child
,
current_indent
+
2
))
def
_delete_leaf
(
self
,
node
):
def
_delete_leaf
(
self
,
node
):
for
k
,
v
in
node
.
parent
.
children
.
items
():
for
k
,
v
in
node
.
parent
.
children
.
items
():
...
@@ -334,13 +336,17 @@ class RadixCache(BasePrefixCache):
...
@@ -334,13 +336,17 @@ class RadixCache(BasePrefixCache):
del
node
.
parent
.
children
[
k
]
del
node
.
parent
.
children
[
k
]
self
.
evictable_size_
-=
len
(
node
.
key
)
self
.
evictable_size_
-=
len
(
node
.
key
)
def
_total_size_helper
(
self
,
node
:
TreeNode
):
def
_total_size_helper
(
self
):
if
node
.
evicted
:
total_size
=
0
return
0
stack
=
[
self
.
root_node
]
x
=
len
(
node
.
value
)
while
stack
:
for
child
in
node
.
children
.
values
():
current_node
=
stack
.
pop
()
x
+=
self
.
_total_size_helper
(
child
)
total_size
+=
len
(
current_node
.
value
)
return
x
for
child
in
current_node
.
children
.
values
():
if
child
.
evicted
:
continue
stack
.
append
(
child
)
return
total_size
def
_collect_leaves
(
self
):
def
_collect_leaves
(
self
):
ret_list
=
[]
ret_list
=
[]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment