Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
48d6bea1
Unverified
Commit
48d6bea1
authored
Nov 03, 2025
by
Hanming Lu
Committed by
GitHub
Nov 04, 2025
Browse files
[GDN/SWA] mamba and swa radix cache edge case fix (#12111)
Co-authored-by:
yizhang2077
<
1109276519@qq.com
>
parent
1689c0e3
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
73 additions
and
40 deletions
+73
-40
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+6
-0
python/sglang/srt/mem_cache/mamba_radix_cache.py
python/sglang/srt/mem_cache/mamba_radix_cache.py
+45
-28
python/sglang/srt/mem_cache/swa_radix_cache.py
python/sglang/srt/mem_cache/swa_radix_cache.py
+22
-12
No files found.
python/sglang/srt/managers/scheduler.py
View file @
48d6bea1
...
...
@@ -2386,6 +2386,12 @@ class Scheduler(
-
self
.
tree_cache
.
swa_evictable_size
()
)
num_tokens
=
max
(
num_tokens_full
,
num_tokens_swa
)
elif
self
.
is_hybrid_gdn
:
num_tokens
=
(
self
.
max_total_num_tokens
-
self
.
token_to_kv_pool_allocator
.
available_size
()
-
self
.
tree_cache
.
full_evictable_size
()
)
else
:
num_tokens
=
(
self
.
max_total_num_tokens
...
...
python/sglang/srt/mem_cache/mamba_radix_cache.py
View file @
48d6bea1
...
...
@@ -20,11 +20,11 @@ The radix tree data structure for managing the hybrid (full and Mamba) KV cache.
"""
import
heapq
import
time
from
collections
import
defaultdict
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Tuple
import
torch
from
numpy
import
float64
from
sglang.srt.mem_cache.allocator
import
TokenToKVPoolAllocator
from
sglang.srt.mem_cache.base_prefix_cache
import
BasePrefixCache
,
MatchResult
...
...
@@ -46,6 +46,7 @@ logger = logging.getLogger(__name__)
class
TreeNode
:
counter
=
0
last_access_time_counter_float
=
float64
(
1.0
)
def
__init__
(
self
,
id
:
Optional
[
int
]
=
None
):
self
.
children
=
defaultdict
(
TreeNode
)
...
...
@@ -61,7 +62,7 @@ class TreeNode:
self
.
full_lock_ref
=
0
self
.
mamba_lock_ref
=
0
# last access time is only used for sanity check. LRU is maintained by the lru list.
self
.
last_access_time
=
time
.
monotonic
()
self
.
last_access_time
=
get_last_access_time
()
self
.
hit_count
=
0
# store the host indices of KV cache
...
...
@@ -90,6 +91,12 @@ class TreeNode:
return
self
.
last_access_time
<
other
.
last_access_time
def
get_last_access_time
()
->
float64
:
ret
=
TreeNode
.
last_access_time_counter_float
TreeNode
.
last_access_time_counter_float
+=
1.0
return
ret
class
LRUList
:
def
__init__
(
self
,
mamba
:
bool
=
False
):
self
.
mamba
=
mamba
...
...
@@ -382,8 +389,6 @@ class MambaRadixCache(BasePrefixCache):
# copy mamba state to req local space if cow is true
if
cow_mamba
and
last_node
.
mamba_value
is
not
None
:
assert
req
.
req_pool_idx
is
None
# req_pool_idx is uninitialed
# for reqs without mamba cache
if
req
.
mamba_pool_idx
is
None
:
dst_index
=
self
.
req_to_token_pool
.
mamba_pool
.
alloc
(
1
)
...
...
@@ -421,7 +426,7 @@ class MambaRadixCache(BasePrefixCache):
value
=
torch
.
tensor
([
x
for
x
in
key
.
token_ids
],
dtype
=
torch
.
int64
)
return
self
.
_insert_helper
(
self
.
root_node
,
key
,
value
,
mamba_value
)
def
cache_finished_req
(
self
,
req
:
Req
)
->
None
:
def
cache_finished_req
(
self
,
req
:
Req
,
is_insert
=
True
)
->
None
:
"""Cache request when it finishes."""
if
self
.
disable
:
kv_indices
=
self
.
req_to_token_pool
.
req_to_token
[
...
...
@@ -449,15 +454,20 @@ class MambaRadixCache(BasePrefixCache):
.
clone
()
)
if
is_insert
:
new_prefix_len
,
mamba_exist
=
self
.
insert
(
RadixKey
(
token_ids
[:
page_aligned_len
],
req
.
extra_key
),
page_aligned_kv_indices
,
mamba_value
,
)
self
.
token_to_kv_pool_allocator
.
free
(
kv_indices
[
len
(
req
.
prefix_indices
)
:
new_prefix_len
]
)
else
:
self
.
token_to_kv_pool_allocator
.
free
(
kv_indices
[
len
(
req
.
prefix_indices
)
:
page_aligned_len
]
)
mamba_exist
=
True
self
.
req_to_token_pool
.
free
(
req
.
req_pool_idx
,
free_mamba_cache
=
mamba_exist
)
self
.
dec_lock_ref
(
req
.
last_node
)
...
...
@@ -767,15 +777,18 @@ class MambaRadixCache(BasePrefixCache):
# update time for matched nodes, and make nodes closer to root to be least recently used
# this allows mamba to evict nodes closer to root first
self
.
full_lru_list
.
reset_node_and_parents_mru
(
best_last_node
,
self
.
root_node
)
self
.
mamba_lru_list
.
reset_node_and_parents_mru
(
best_last_node
,
self
.
root_node
)
node_update
=
best_last_node
self
.
full_lru_list
.
reset_node_and_parents_mru
(
node_update
,
self
.
root_node
)
self
.
mamba_lru_list
.
reset_node_and_parents_mru
(
node_update
,
self
.
root_node
)
# This last_access_time is for sanity check, can be deleted after validation in production
cur_time
=
time
.
monotonic
()
while
node
:
node
.
last_access_time
=
cur_time
cur_time
-=
0.0001
node
=
node
.
parent
cur_time
=
get_last_access_time
()
while
node_update
:
node_update
.
last_access_time
=
cur_time
cur_time
-=
(
0.00001
# assuming less than 100000 nodes in a branch of the tree
)
node_update
=
node_update
.
parent
return
value
[:
best_value_len
],
best_last_node
...
...
@@ -791,7 +804,7 @@ class MambaRadixCache(BasePrefixCache):
new_node
.
value
=
child
.
value
[:
split_len
]
# child time should be later than parent's time for mamba tombstone
child
.
last_access_time
=
time
.
monotonic
()
child
.
last_access_time
=
get_last_access_time
()
self
.
full_lru_list
.
remove_node
(
child
)
if
child
.
mamba_value
is
not
None
:
...
...
@@ -819,7 +832,7 @@ class MambaRadixCache(BasePrefixCache):
# Update the last access time from root to leaf, so that
# mamba will tombstone the node closer to root first
assert
mamba_value
is
not
None
,
"Mamba value should not be None here."
node
.
last_access_time
=
time
.
monotonic
()
node
.
last_access_time
=
get_last_access_time
()
if
node
!=
self
.
root_node
:
self
.
full_lru_list
.
reset_node_mru
(
node
)
if
node
.
mamba_value
is
not
None
:
...
...
@@ -832,7 +845,7 @@ class MambaRadixCache(BasePrefixCache):
total_prefix_length
=
0
while
len
(
key
)
>
0
and
child_key
in
node
.
children
.
keys
():
node
=
node
.
children
[
child_key
]
node
.
last_access_time
=
time
.
monotonic
()
node
.
last_access_time
=
get_last_access_time
()
self
.
full_lru_list
.
reset_node_mru
(
node
)
if
node
.
mamba_value
is
not
None
:
self
.
mamba_lru_list
.
reset_node_mru
(
node
)
...
...
@@ -856,17 +869,21 @@ class MambaRadixCache(BasePrefixCache):
new_node
.
value
=
value
new_node
.
mamba_value
=
mamba_value
self
.
full_lru_list
.
insert_mru
(
new_node
)
self
.
full_evictable_size_
+=
len
(
value
)
self
.
mamba_evictable_size_
+=
len
(
mamba_value
)
self
.
mamba_lru_list
.
insert_mru
(
new_node
)
node
.
children
[
child_key
]
=
new_node
self
.
full_evictable_size_
+=
len
(
value
)
self
.
mamba_evictable_size_
+=
len
(
mamba_value
)
elif
node
.
mamba_value
is
None
:
# add for mamba tombstone
node
.
mamba_value
=
mamba_value
self
.
mamba_evictable_size_
+=
len
(
mamba_valu
e
)
self
.
full_lru_list
.
reset_node_mru
(
nod
e
)
self
.
mamba_lru_list
.
insert_mru
(
node
)
else
:
self
.
mamba_evictable_size_
+=
len
(
mamba_value
)
node
.
last_access_time
=
get_last_access_time
()
else
:
# mamba value already exists
mamba_value_exist
=
True
self
.
full_lru_list
.
reset_node_mru
(
node
)
self
.
mamba_lru_list
.
reset_node_mru
(
node
)
node
.
last_access_time
=
get_last_access_time
()
return
total_prefix_length
,
mamba_value_exist
...
...
python/sglang/srt/mem_cache/swa_radix_cache.py
View file @
48d6bea1
...
...
@@ -20,12 +20,12 @@ The radix tree data structure for managing the hybrid (full and SWA) KV cache.
"""
import
heapq
import
time
from
collections
import
defaultdict
from
functools
import
partial
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Tuple
import
torch
from
numpy
import
float64
from
sglang.srt.mem_cache.allocator
import
SWATokenToKVPoolAllocator
from
sglang.srt.mem_cache.base_prefix_cache
import
BasePrefixCache
,
MatchResult
...
...
@@ -50,6 +50,7 @@ class TreeNode:
counter
=
0
swa_uuid_counter
=
1
last_access_time_counter_float
=
float64
(
1.0
)
def
__init__
(
self
,
id
:
Optional
[
int
]
=
None
):
self
.
children
=
defaultdict
(
TreeNode
)
...
...
@@ -64,7 +65,7 @@ class TreeNode:
self
.
full_lock_ref
=
0
self
.
swa_lock_ref
=
0
# last access time is only used for sanity check. LRU is maintained by the lru list.
self
.
last_access_time
=
time
.
monotonic
()
self
.
last_access_time
=
get_last_access_time
()
self
.
hit_count
=
0
# store the host indices of KV cache
...
...
@@ -99,6 +100,12 @@ def gen_swa_uuid() -> int:
return
TreeNode
.
swa_uuid_counter
def
get_last_access_time
()
->
float64
:
ret
=
TreeNode
.
last_access_time_counter_float
TreeNode
.
last_access_time_counter_float
+=
1.0
return
ret
class
LRUList
:
def
__init__
(
self
,
swa
:
bool
=
False
):
self
.
swa
=
swa
...
...
@@ -841,15 +848,18 @@ class SWARadixCache(BasePrefixCache):
# update time for matched nodes, and make nodes closer to root to be least recently used
# this allows swa to evict nodes closer to root first
self
.
full_lru_list
.
reset_node_and_parents_mru
(
best_last_node
,
self
.
root_node
)
self
.
swa_lru_list
.
reset_node_and_parents_mru
(
best_last_node
,
self
.
root_node
)
node_update
=
best_last_node
self
.
full_lru_list
.
reset_node_and_parents_mru
(
node_update
,
self
.
root_node
)
self
.
swa_lru_list
.
reset_node_and_parents_mru
(
node_update
,
self
.
root_node
)
# This last_access_time is for sanity check, can be deleted after validation in production
cur_time
=
time
.
monotonic
()
while
node
:
node
.
last_access_time
=
cur_time
cur_time
-=
0.0001
node
=
node
.
parent
cur_time
=
get_last_access_time
()
while
node_update
:
node_update
.
last_access_time
=
cur_time
cur_time
-=
(
0.00001
# assuming less than 100000 nodes in a branch of the tree
)
node_update
=
node_update
.
parent
return
value
[:
best_value_len
],
best_last_node
...
...
@@ -867,7 +877,7 @@ class SWARadixCache(BasePrefixCache):
new_node
.
swa_uuid
=
child
.
swa_uuid
child
.
swa_uuid
=
None
# child time should be later than parent's time for swa tombstone
child
.
last_access_time
=
time
.
monotonic
()
child
.
last_access_time
=
get_last_access_time
()
# remove the child from the lru lists because it is being split
self
.
full_lru_list
.
remove_node
(
child
)
...
...
@@ -892,7 +902,7 @@ class SWARadixCache(BasePrefixCache):
)
->
int
:
# Update the last access time from root to leaf, so that
# swa will tombstone the node closer to root first
node
.
last_access_time
=
time
.
monotonic
()
node
.
last_access_time
=
get_last_access_time
()
if
node
!=
self
.
root_node
:
self
.
full_lru_list
.
reset_node_mru
(
node
)
if
not
node
.
swa_tombstone
:
...
...
@@ -905,7 +915,7 @@ class SWARadixCache(BasePrefixCache):
total_prefix_length
=
0
while
len
(
key
)
>
0
and
child_key
in
node
.
children
.
keys
():
node
=
node
.
children
[
child_key
]
node
.
last_access_time
=
time
.
monotonic
()
node
.
last_access_time
=
get_last_access_time
()
self
.
full_lru_list
.
reset_node_mru
(
node
)
if
not
node
.
swa_tombstone
:
self
.
swa_lru_list
.
reset_node_mru
(
node
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment