Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
54e872d3
"docs/source/ja/index.md" did not exist on "8a7306457678dad1246ff767553c6200802828d4"
Unverified
Commit
54e872d3
authored
Aug 29, 2025
by
Zhiqiang Xie
Committed by
GitHub
Aug 30, 2025
Browse files
[HiCache] resolve conflict between chunked-prefill and hicache hit count (#9776)
parent
e5b29bf1
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
20 additions
and
17 deletions
+20
-17
python/sglang/srt/disaggregation/prefill.py
python/sglang/srt/disaggregation/prefill.py
+1
-1
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+1
-1
python/sglang/srt/mem_cache/chunk_cache.py
python/sglang/srt/mem_cache/chunk_cache.py
+1
-1
python/sglang/srt/mem_cache/hiradix_cache.py
python/sglang/srt/mem_cache/hiradix_cache.py
+9
-8
python/sglang/srt/mem_cache/lora_radix_cache.py
python/sglang/srt/mem_cache/lora_radix_cache.py
+1
-1
python/sglang/srt/mem_cache/radix_cache.py
python/sglang/srt/mem_cache/radix_cache.py
+5
-3
python/sglang/srt/mem_cache/radix_cache_cpp.py
python/sglang/srt/mem_cache/radix_cache_cpp.py
+1
-1
python/sglang/srt/mem_cache/swa_radix_cache.py
python/sglang/srt/mem_cache/swa_radix_cache.py
+1
-1
No files found.
python/sglang/srt/disaggregation/prefill.py
View file @
54e872d3
...
@@ -567,7 +567,7 @@ class SchedulerDisaggregationPrefillMixin:
...
@@ -567,7 +567,7 @@ class SchedulerDisaggregationPrefillMixin:
# Move the chunked request out of the batch so that we can merge
# Move the chunked request out of the batch so that we can merge
# only finished requests to running_batch.
# only finished requests to running_batch.
self
.
last_batch
.
filter_batch
(
chunked_req_to_exclude
=
self
.
chunked_req
)
self
.
last_batch
.
filter_batch
(
chunked_req_to_exclude
=
self
.
chunked_req
)
self
.
tree_cache
.
cache_unfinished_req
(
self
.
chunked_req
)
self
.
tree_cache
.
cache_unfinished_req
(
self
.
chunked_req
,
chunked
=
True
)
if
self
.
enable_overlap
:
if
self
.
enable_overlap
:
# Delay KV transfer to process_batch_result_disagg_prefill when overlap is enabled to ensure results are resolved
# Delay KV transfer to process_batch_result_disagg_prefill when overlap is enabled to ensure results are resolved
self
.
chunked_req
.
tmp_end_idx
=
min
(
self
.
chunked_req
.
tmp_end_idx
=
min
(
...
...
python/sglang/srt/managers/scheduler.py
View file @
54e872d3
...
@@ -1503,7 +1503,7 @@ class Scheduler(
...
@@ -1503,7 +1503,7 @@ class Scheduler(
# Move the chunked request out of the batch so that we can merge
# Move the chunked request out of the batch so that we can merge
# only finished requests to running_batch.
# only finished requests to running_batch.
chunked_req_to_exclude
.
add
(
self
.
chunked_req
)
chunked_req_to_exclude
.
add
(
self
.
chunked_req
)
self
.
tree_cache
.
cache_unfinished_req
(
self
.
chunked_req
)
self
.
tree_cache
.
cache_unfinished_req
(
self
.
chunked_req
,
chunked
=
True
)
# chunked request keeps its rid but will get a new req_pool_idx
# chunked request keeps its rid but will get a new req_pool_idx
self
.
req_to_token_pool
.
free
(
self
.
chunked_req
.
req_pool_idx
)
self
.
req_to_token_pool
.
free
(
self
.
chunked_req
.
req_pool_idx
)
if
self
.
last_batch
and
self
.
last_batch
.
forward_mode
.
is_extend
():
if
self
.
last_batch
and
self
.
last_batch
.
forward_mode
.
is_extend
():
...
...
python/sglang/srt/mem_cache/chunk_cache.py
View file @
54e872d3
...
@@ -47,7 +47,7 @@ class ChunkCache(BasePrefixCache):
...
@@ -47,7 +47,7 @@ class ChunkCache(BasePrefixCache):
self
.
req_to_token_pool
.
free
(
req
.
req_pool_idx
)
self
.
req_to_token_pool
.
free
(
req
.
req_pool_idx
)
self
.
token_to_kv_pool_allocator
.
free
(
kv_indices
)
self
.
token_to_kv_pool_allocator
.
free
(
kv_indices
)
def
cache_unfinished_req
(
self
,
req
:
Req
):
def
cache_unfinished_req
(
self
,
req
:
Req
,
chunked
=
False
):
kv_indices
=
self
.
req_to_token_pool
.
req_to_token
[
kv_indices
=
self
.
req_to_token_pool
.
req_to_token
[
req
.
req_pool_idx
,
:
len
(
req
.
fill_ids
)
req
.
req_pool_idx
,
:
len
(
req
.
fill_ids
)
]
]
...
...
python/sglang/srt/mem_cache/hiradix_cache.py
View file @
54e872d3
...
@@ -102,7 +102,7 @@ class HiRadixCache(RadixCache):
...
@@ -102,7 +102,7 @@ class HiRadixCache(RadixCache):
self
.
ongoing_backup
=
{}
self
.
ongoing_backup
=
{}
# todo: dynamically adjust the threshold
# todo: dynamically adjust the threshold
self
.
write_through_threshold
=
(
self
.
write_through_threshold
=
(
1
if
hicache_write_policy
==
"write_through"
else
3
1
if
hicache_write_policy
==
"write_through"
else
2
)
)
self
.
write_through_threshold_storage
=
(
self
.
write_through_threshold_storage
=
(
1
if
hicache_write_policy
==
"write_through"
else
3
1
if
hicache_write_policy
==
"write_through"
else
3
...
@@ -155,8 +155,9 @@ class HiRadixCache(RadixCache):
...
@@ -155,8 +155,9 @@ class HiRadixCache(RadixCache):
self
.
ongoing_backup
[
operation_id
]
=
node
self
.
ongoing_backup
[
operation_id
]
=
node
node
.
protect_host
()
node
.
protect_host
()
def
inc_hit_count
(
self
,
node
:
TreeNode
):
def
_inc_hit_count
(
self
,
node
:
TreeNode
,
chunked
=
False
):
if
self
.
cache_controller
.
write_policy
==
"write_back"
:
# skip the hit count update for chunked requests
if
self
.
cache_controller
.
write_policy
==
"write_back"
or
chunked
:
return
return
node
.
hit_count
+=
1
node
.
hit_count
+=
1
...
@@ -672,11 +673,11 @@ class HiRadixCache(RadixCache):
...
@@ -672,11 +673,11 @@ class HiRadixCache(RadixCache):
new_node
.
parent
.
children
[
self
.
get_child_key_fn
(
key
)]
=
new_node
new_node
.
parent
.
children
[
self
.
get_child_key_fn
(
key
)]
=
new_node
return
new_node
return
new_node
def
_insert_helper
(
self
,
node
:
TreeNode
,
key
:
List
,
value
):
def
insert
(
self
,
key
:
List
,
value
,
chunked
=
False
):
node
.
last_access_time
=
time
.
monotonic
()
if
len
(
key
)
==
0
:
if
len
(
key
)
==
0
:
return
0
return
0
node
=
self
.
root_node
child_key
=
self
.
get_child_key_fn
(
key
)
child_key
=
self
.
get_child_key_fn
(
key
)
total_prefix_length
=
0
total_prefix_length
=
0
...
@@ -693,7 +694,7 @@ class HiRadixCache(RadixCache):
...
@@ -693,7 +694,7 @@ class HiRadixCache(RadixCache):
self
.
token_to_kv_pool_host
.
update_synced
(
node
.
host_value
)
self
.
token_to_kv_pool_host
.
update_synced
(
node
.
host_value
)
self
.
evictable_size_
+=
len
(
node
.
value
)
self
.
evictable_size_
+=
len
(
node
.
value
)
else
:
else
:
self
.
inc_hit_count
(
node
)
self
.
_
inc_hit_count
(
node
,
chunked
)
total_prefix_length
+=
prefix_len
total_prefix_length
+=
prefix_len
else
:
else
:
# partial match, split the node
# partial match, split the node
...
@@ -703,7 +704,7 @@ class HiRadixCache(RadixCache):
...
@@ -703,7 +704,7 @@ class HiRadixCache(RadixCache):
self
.
token_to_kv_pool_host
.
update_synced
(
new_node
.
host_value
)
self
.
token_to_kv_pool_host
.
update_synced
(
new_node
.
host_value
)
self
.
evictable_size_
+=
len
(
new_node
.
value
)
self
.
evictable_size_
+=
len
(
new_node
.
value
)
else
:
else
:
self
.
inc_hit_count
(
new_node
)
self
.
_
inc_hit_count
(
new_node
,
chunked
)
total_prefix_length
+=
prefix_len
total_prefix_length
+=
prefix_len
node
=
new_node
node
=
new_node
...
@@ -737,7 +738,7 @@ class HiRadixCache(RadixCache):
...
@@ -737,7 +738,7 @@ class HiRadixCache(RadixCache):
last_hash
=
new_node
.
hash_value
[
-
1
]
last_hash
=
new_node
.
hash_value
[
-
1
]
if
self
.
cache_controller
.
write_policy
!=
"write_back"
:
if
self
.
cache_controller
.
write_policy
!=
"write_back"
:
self
.
inc_hit_count
(
new_node
)
self
.
_
inc_hit_count
(
new_node
,
chunked
)
return
total_prefix_length
return
total_prefix_length
def
_collect_leaves_device
(
self
):
def
_collect_leaves_device
(
self
):
...
...
python/sglang/srt/mem_cache/lora_radix_cache.py
View file @
54e872d3
...
@@ -183,7 +183,7 @@ class LoRARadixCache(BasePrefixCache):
...
@@ -183,7 +183,7 @@ class LoRARadixCache(BasePrefixCache):
self
.
req_to_token_pool
.
free
(
req
.
req_pool_idx
)
self
.
req_to_token_pool
.
free
(
req
.
req_pool_idx
)
self
.
dec_lock_ref
(
req
.
last_node
)
self
.
dec_lock_ref
(
req
.
last_node
)
def
cache_unfinished_req
(
self
,
req
:
Req
):
def
cache_unfinished_req
(
self
,
req
:
Req
,
chunked
=
False
):
"""Cache request when it is unfinished."""
"""Cache request when it is unfinished."""
if
self
.
disable
:
if
self
.
disable
:
return
return
...
...
python/sglang/srt/mem_cache/radix_cache.py
View file @
54e872d3
...
@@ -195,7 +195,7 @@ class RadixCache(BasePrefixCache):
...
@@ -195,7 +195,7 @@ class RadixCache(BasePrefixCache):
last_host_node
=
last_node
,
last_host_node
=
last_node
,
)
)
def
insert
(
self
,
key
:
List
,
value
=
None
):
def
insert
(
self
,
key
:
List
,
value
=
None
,
chunked
=
False
):
if
self
.
disable
:
if
self
.
disable
:
return
0
return
0
...
@@ -240,7 +240,7 @@ class RadixCache(BasePrefixCache):
...
@@ -240,7 +240,7 @@ class RadixCache(BasePrefixCache):
self
.
req_to_token_pool
.
free
(
req
.
req_pool_idx
)
self
.
req_to_token_pool
.
free
(
req
.
req_pool_idx
)
self
.
dec_lock_ref
(
req
.
last_node
)
self
.
dec_lock_ref
(
req
.
last_node
)
def
cache_unfinished_req
(
self
,
req
:
Req
):
def
cache_unfinished_req
(
self
,
req
:
Req
,
chunked
=
False
):
"""Cache request when it is unfinished."""
"""Cache request when it is unfinished."""
if
self
.
disable
:
if
self
.
disable
:
return
return
...
@@ -261,7 +261,9 @@ class RadixCache(BasePrefixCache):
...
@@ -261,7 +261,9 @@ class RadixCache(BasePrefixCache):
page_aligned_token_ids
=
token_ids
[:
page_aligned_len
]
page_aligned_token_ids
=
token_ids
[:
page_aligned_len
]
# Radix Cache takes one ref in memory pool
# Radix Cache takes one ref in memory pool
new_prefix_len
=
self
.
insert
(
page_aligned_token_ids
,
page_aligned_kv_indices
)
new_prefix_len
=
self
.
insert
(
page_aligned_token_ids
,
page_aligned_kv_indices
,
chunked
=
chunked
)
self
.
token_to_kv_pool_allocator
.
free
(
self
.
token_to_kv_pool_allocator
.
free
(
kv_indices
[
len
(
req
.
prefix_indices
)
:
new_prefix_len
]
kv_indices
[
len
(
req
.
prefix_indices
)
:
new_prefix_len
]
)
)
...
...
python/sglang/srt/mem_cache/radix_cache_cpp.py
View file @
54e872d3
...
@@ -181,7 +181,7 @@ class RadixCacheCpp(BasePrefixCache):
...
@@ -181,7 +181,7 @@ class RadixCacheCpp(BasePrefixCache):
self
.
dec_lock_ref
(
req
.
last_node
)
self
.
dec_lock_ref
(
req
.
last_node
)
self
.
req_to_token_pool
.
free
(
req
.
req_pool_idx
)
self
.
req_to_token_pool
.
free
(
req
.
req_pool_idx
)
def
cache_unfinished_req
(
self
,
req
:
Req
):
def
cache_unfinished_req
(
self
,
req
:
Req
,
chunked
=
False
):
"""Cache request when it is unfinished."""
"""Cache request when it is unfinished."""
assert
req
.
req_pool_idx
is
not
None
assert
req
.
req_pool_idx
is
not
None
token_ids
=
req
.
fill_ids
token_ids
=
req
.
fill_ids
...
...
python/sglang/srt/mem_cache/swa_radix_cache.py
View file @
54e872d3
...
@@ -464,7 +464,7 @@ class SWARadixCache(BasePrefixCache):
...
@@ -464,7 +464,7 @@ class SWARadixCache(BasePrefixCache):
self
.
req_to_token_pool
.
free
(
req
.
req_pool_idx
)
self
.
req_to_token_pool
.
free
(
req
.
req_pool_idx
)
self
.
dec_lock_ref
(
req
.
last_node
,
req
.
swa_uuid_for_lock
)
self
.
dec_lock_ref
(
req
.
last_node
,
req
.
swa_uuid_for_lock
)
def
cache_unfinished_req
(
self
,
req
:
Req
)
->
None
:
def
cache_unfinished_req
(
self
,
req
:
Req
,
chunked
=
False
)
->
None
:
"""Cache request when it is unfinished."""
"""Cache request when it is unfinished."""
if
self
.
disable
:
if
self
.
disable
:
kv_indices
=
self
.
req_to_token_pool
.
req_to_token
[
kv_indices
=
self
.
req_to_token_pool
.
req_to_token
[
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment