sglang / Commits / e56685ac

Upstreaming hicache bug fixes (#7267)

Commit e56685ac (Unverified), authored Jun 17, 2025 by Zhiqiang Xie, committed by GitHub Jun 17, 2025.
Parent: c26d7349
Showing 7 changed files with 76 additions and 24 deletions:

benchmark/hicache/bench_multiturn.py                    +1  -1
python/sglang/srt/managers/cache_controller.py          +28 -11
python/sglang/srt/managers/schedule_batch.py            +5  -8
python/sglang/srt/managers/scheduler.py                 +19 -4
python/sglang/srt/managers/tp_worker.py                 +9  -0
python/sglang/srt/managers/tp_worker_overlap_thread.py  +11 -0
python/sglang/srt/mem_cache/hiradix_cache.py            +3  -0
benchmark/hicache/bench_multiturn.py

@@ -239,7 +239,7 @@ class WorkloadGenerator:
             tokenizer=self.tokenizer,
             dataset_path=args.dataset_path,
         )
-        self.candidate_inputs = [i[0] for i in self.candidate_inputs]
+        self.candidate_inputs = [i.prompt for i in self.candidate_inputs]
         init_requests = [
             (i, gen_payload(self.candidate_inputs[i], args.output_length))
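The benchmark fix swaps tuple indexing (i[0]) for attribute access (i.prompt), which suggests the dataset sampler now yields structured records rather than tuples. A minimal sketch of that shape change; SampleEntry is a hypothetical stand-in for whatever record type the sampler actually returns:

    # Hypothetical sketch; only the access pattern matters here.
    from dataclasses import dataclass

    @dataclass
    class SampleEntry:
        prompt: str
        output_len: int

    tuple_entries = [("hello", 32), ("world", 64)]                          # old shape
    record_entries = [SampleEntry("hello", 32), SampleEntry("world", 64)]   # new shape

    assert [e[0] for e in tuple_entries] == [e.prompt for e in record_entries]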
python/sglang/srt/managers/cache_controller.py

@@ -30,22 +30,37 @@ logger = logging.getLogger(__name__)

 class LayerDoneCounter:
     def __init__(self, num_layers):
-        self.counter = num_layers
-        self.condition = threading.Condition()
+        self.num_layers = num_layers
+        # extra producer and consumer counters for overlap mode
+        self.num_counters = 3
+        self.counters = [num_layers] * self.num_counters
+        self.conditions = [threading.Condition() for _ in range(self.num_counters)]
+        self.producer_index = 0
+        self.consumer_index = 0
+
+    def next_producer(self):
+        return (self.producer_index + 1) % self.num_counters
+
+    def update_producer(self):
+        self.producer_index = self.next_producer()
+        return self.producer_index
+
+    def set_consumer(self, index):
+        self.consumer_index = index

     def increment(self):
-        with self.condition:
-            self.counter += 1
-            self.condition.notify_all()
+        with self.conditions[self.producer_index]:
+            self.counters[self.producer_index] += 1
+            self.conditions[self.producer_index].notify_all()

     def wait_until(self, threshold):
-        with self.condition:
-            while self.counter <= threshold:
-                self.condition.wait()
+        with self.conditions[self.consumer_index]:
+            while self.counters[self.consumer_index] <= threshold:
+                self.conditions[self.consumer_index].wait()

     def reset(self):
-        with self.condition:
-            self.counter = 0
+        with self.conditions[self.producer_index]:
+            self.counters[self.producer_index] = 0


 class CacheOperation:
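The original class had a single counter/condition pair, so in overlap mode the producer could reset() the counter while a consumer from a previous batch was still waiting on it. The fix rotates through three independent counter slots, one per in-flight transfer. A minimal usage sketch of the rotation (assumes sglang is installed so the class above is importable):

    import threading

    from sglang.srt.managers.cache_controller import LayerDoneCounter

    counter = LayerDoneCounter(num_layers=4)

    # Producer (load thread): claim a fresh slot, zero it, then count layers.
    slot = counter.update_producer()
    counter.reset()

    def load_layers():
        for _ in range(counter.num_layers):
            counter.increment()  # one more layer finished loading

    threading.Thread(target=load_layers, daemon=True).start()

    # Consumer (forward pass): bind to the same slot, wait layer by layer.
    counter.set_consumer(slot)
    for layer_id in range(counter.num_layers):
        counter.wait_until(layer_id)  # returns once counters[slot] > layer_id
    print("all 4 layers ready")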
@@ -296,7 +311,6 @@ class HiCacheController:
         while not self.stop_event.is_set():
             try:
                 operation = self.load_queue.get(block=True, timeout=1)
-                # time.sleep(18e-6 * len(operation.host_indices))
                 operation.data = self.mem_pool_host.get_flat_data(
                     operation.host_indices
                 )

@@ -320,6 +334,7 @@ class HiCacheController:
             if not self.load_cache_event.is_set():
                 continue
             self.load_cache_event.clear()
+            self.layer_done_counter.update_producer()

             batch_operation = None
             while self.load_queue.qsize() > 0:

@@ -331,6 +346,7 @@ class HiCacheController:
             if batch_operation is None:
                 continue

             # start layer-wise KV cache transfer from CPU to GPU
+            self.layer_done_counter.reset()
             for i in range(self.mem_pool_host.layer_num):
                 if self.page_size == 1:

@@ -466,6 +482,7 @@ class HiCacheController:
             except Exception as e:
                 logger.error(e)

+    # todo (zhiqiang): double buffering to be deprecated
     def write_thread_func_buffer(self):
         aux_thread = threading.Thread(target=self.write_aux_func, daemon=True)
         aux_thread.start()
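For context, the load thread is event-driven: it sleeps until the scheduler sets load_cache_event, rotates to the next producer slot, then drains the queue and merges all pending operations into one batch before the layer-wise transfer begins. A simplified standalone sketch of that drain-and-merge loop (names and wait mechanics simplified; not the actual controller):

    import queue
    import threading

    load_queue: queue.Queue = queue.Queue()
    load_cache_event = threading.Event()
    stop_event = threading.Event()

    def load_thread_func():
        while not stop_event.is_set():
            if not load_cache_event.wait(timeout=1):
                continue
            load_cache_event.clear()

            # Merge everything queued so far into one batch operation.
            batch = []
            while load_queue.qsize() > 0:
                batch.extend(load_queue.get())
            if not batch:
                continue

            # The real controller calls layer_done_counter.reset() here, then
            # transfers layer by layer, incrementing the counter per layer.
            print(f"transferring {len(batch)} host indices layer by layer")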
python/sglang/srt/managers/schedule_batch.py

@@ -659,14 +659,6 @@ class Req:
             self.prefix_indices, self.last_node = tree_cache.match_prefix(
                 rid=self.rid, key=self.adjust_max_prefix_ids()
             )
-        elif enable_hierarchical_cache:
-            # in case last_node is evicted during scheduling, we need to update the prefix_indices
-            while self.last_node.evicted:
-                self.prefix_indices = self.prefix_indices[
-                    : -len(self.last_node.host_value)
-                ]
-                self.last_node = self.last_node.parent
-
         self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

     def adjust_max_prefix_ids(self):

@@ -909,6 +901,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # Whether to return hidden states
     return_hidden_states: bool = False

+    # hicache pointer for synchronizing data loading from CPU to GPU
+    hicache_consumer_index: int = 0
+
     @classmethod
     def init_new(
         cls,

@@ -1735,6 +1730,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             token_type_ids=self.token_type_ids,
             spec_algorithm=self.spec_algorithm,
             spec_info=self.spec_info,
+            hicache_consumer_index=self.hicache_consumer_index,
             capture_hidden_mode=(
                 CaptureHiddenMode.FULL
                 if self.return_hidden_states

@@ -1839,6 +1835,7 @@ class ModelWorkerBatch:
     # If set, the output of the batch contains the hidden states of the run.
     capture_hidden_mode: CaptureHiddenMode = None
     spec_num_draft_tokens: Optional[int] = None
+    hicache_consumer_index: int = 0

     # Overlap event
     launch_done: Optional[threading.Event] = None
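Two things happen in this file. First, the in-place patching of prefix_indices for evicted nodes is dropped from Req.init_next_round_input; the scheduler change below appears to supersede it by passing the tree_cache back in when hierarchical cache is enabled, so the prefix is re-matched instead of patched. Second, a hicache_consumer_index field is threaded from ScheduleBatch into ModelWorkerBatch. A minimal sketch of that plumbing with simplified stand-in dataclasses (not the real batch classes):

    from dataclasses import dataclass

    @dataclass
    class MiniModelWorkerBatch:  # stand-in for ModelWorkerBatch
        hicache_consumer_index: int = 0

    @dataclass
    class MiniScheduleBatch:  # stand-in for ScheduleBatch
        # hicache pointer for synchronizing data loading from CPU to GPU
        hicache_consumer_index: int = 0

        def get_model_worker_batch(self) -> MiniModelWorkerBatch:
            # The index rides along so the worker knows which slot to wait on.
            return MiniModelWorkerBatch(
                hicache_consumer_index=self.hicache_consumer_index
            )

    batch = MiniScheduleBatch(hicache_consumer_index=2)
    assert batch.get_model_worker_batch().hicache_consumer_index == 2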
python/sglang/srt/managers/scheduler.py

@@ -565,6 +565,10 @@ class Scheduler(
                 hicache_size=server_args.hicache_size,
                 hicache_write_policy=server_args.hicache_write_policy,
             )
+            self.tp_worker.register_hicache_layer_transfer_counter(
+                self.tree_cache.cache_controller.layer_done_counter
+            )
+
         else:
             self.tree_cache = RadixCache(
                 req_to_token_pool=self.req_to_token_pool,

@@ -1514,8 +1518,13 @@ class Scheduler(
                 self.running_batch.batch_is_full = True
                 break

+            # bypass prefix_computed if enable_hierarchical_cache
             req.init_next_round_input(
-                None if prefix_computed else self.tree_cache,
+                (
+                    None
+                    if (prefix_computed and not self.enable_hierarchical_cache)
+                    else self.tree_cache
+                ),
                 self.enable_hierarchical_cache,
             )

@@ -1548,9 +1557,6 @@ class Scheduler(
             x for x in self.waiting_queue if x not in set(can_run_list)
         ]

-        if self.enable_hierarchical_cache:
-            self.tree_cache.ready_to_load_cache()
-
         if adder.new_chunked_req is not None:
             assert self.chunked_req is None
             self.chunked_req = adder.new_chunked_req

@@ -1574,6 +1580,10 @@ class Scheduler(
             self.server_args.enable_custom_logit_processor,
             chunked_req=self.chunked_req,
         )
+        if self.enable_hierarchical_cache:
+            # todo (zhiqiang): disable cuda graph execution if hicache loading triggered
+            new_batch.hicache_consumer_index = self.tree_cache.ready_to_load_cache()
+
         new_batch.prepare_for_extend()

         # Mixed-style chunked prefill

@@ -1649,6 +1659,11 @@ class Scheduler(
         if self.is_generation:
             if self.spec_algorithm.is_none():
                 model_worker_batch = batch.get_model_worker_batch()
+
+                # update the consumer index of hicache to the running batch
+                self.tp_worker.set_hicache_consumer(
+                    model_worker_batch.hicache_consumer_index
+                )
                 if self.pp_group.is_last_rank:
                     logits_output, next_token_ids, can_run_cuda_graph = (
                         self.tp_worker.forward_batch_generation(model_worker_batch)
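Taken together with the hiradix_cache change below, these edits close the loop: ready_to_load_cache() no longer fires during queue filtering; it runs once the new batch exists, returns the producer slot it claimed, and that slot travels with the batch so the worker can bind its consumer side right before the forward pass. A condensed sketch of the handshake, written sequentially for clarity (assumes sglang is installed for the LayerDoneCounter import):

    from sglang.srt.managers.cache_controller import LayerDoneCounter

    counter = LayerDoneCounter(num_layers=2)

    # Scheduler: claim the slot the upcoming transfer will use
    # (ready_to_load_cache() returns next_producer() and sets the load event).
    hicache_consumer_index = counter.next_producer()

    # Load thread: rotate to that slot and start a fresh count.
    counter.update_producer()
    counter.reset()
    for _ in range(counter.num_layers):
        counter.increment()  # one layer copied CPU -> GPU

    # Worker: set_hicache_consumer(...) binds the slot, then each layer's
    # forward can wait_until(layer_id) before touching its KV cache.
    counter.set_consumer(hicache_consumer_index)
    for layer_id in range(counter.num_layers):
        counter.wait_until(layer_id)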
python/sglang/srt/managers/tp_worker.py

@@ -147,6 +147,15 @@ class TpModelWorker:
         # A reference make this class has the same member as TpModelWorkerClient
         self.worker = self

+        self.hicache_layer_transfer_counter = None
+
+    def register_hicache_layer_transfer_counter(self, counter):
+        self.hicache_layer_transfer_counter = counter
+
+    def set_hicache_consumer(self, consumer_index):
+        if self.hicache_layer_transfer_counter is not None:
+            self.hicache_layer_transfer_counter.set_consumer(consumer_index)
+
     def get_worker_info(self):
         return (
             self.max_total_num_tokens,
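TpModelWorker and the overlap-mode TpModelWorkerClient (next file) gain the same two methods, so the scheduler can register the counter and set the consumer index without caring which worker flavor is running; the None check makes both calls safe no-ops when hicache is disabled. A minimal sketch of that guard (stand-in class, illustrative only):

    class MiniWorker:  # stand-in for either TP worker flavor
        def __init__(self):
            self.hicache_layer_transfer_counter = None

        def register_hicache_layer_transfer_counter(self, counter):
            self.hicache_layer_transfer_counter = counter

        def set_hicache_consumer(self, consumer_index):
            # No counter registered (hicache disabled): silently do nothing.
            if self.hicache_layer_transfer_counter is not None:
                self.hicache_layer_transfer_counter.set_consumer(consumer_index)

    worker = MiniWorker()
    worker.set_hicache_consumer(0)  # safe no-op without a registered counter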
python/sglang/srt/managers/tp_worker_overlap_thread.py

@@ -88,6 +88,15 @@ class TpModelWorkerClient:
         if self.device == "cpu":
             self.scheduler_stream.synchronize = lambda: None  # No-op for CPU

+        self.hicache_layer_transfer_counter = None
+
+    def register_hicache_layer_transfer_counter(self, counter):
+        self.hicache_layer_transfer_counter = counter
+
+    def set_hicache_consumer(self, consumer_index):
+        if self.hicache_layer_transfer_counter is not None:
+            self.hicache_layer_transfer_counter.set_consumer(consumer_index)
+
     def get_worker_info(self):
         return self.worker.get_worker_info()

@@ -146,6 +155,8 @@ class TpModelWorkerClient:
             input_ids = model_worker_batch.input_ids
             resolve_future_token_ids(input_ids, self.future_token_ids_map)

+            # update the consumer index of hicache to the running batch
+            self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
+
             # Run forward
             logits_output, next_token_ids, can_run_cuda_graph = (
                 self.worker.forward_batch_generation(
python/sglang/srt/mem_cache/hiradix_cache.py

@@ -307,7 +307,9 @@ class HiRadixCache(RadixCache):
         return last_node, prefix_indices

     def ready_to_load_cache(self):
+        producer_index = self.cache_controller.layer_done_counter.next_producer()
         self.load_cache_event.set()
+        return producer_index

     def match_prefix(self, key: List[int], include_evicted=False, **kwargs):
         empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)

@@ -372,6 +374,7 @@ class HiRadixCache(RadixCache):
             new_node.lock_ref = child.lock_ref
             new_node.key = child.key[:split_len]
             new_node.loading = child.loading
+            new_node.hit_count = child.hit_count

             # split value and host value if exists
             if child.evicted:
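ready_to_load_cache() now returns the producer slot it reserved, which is what the scheduler stores on the batch above. The second hunk is a subtle bug fix: when a radix node is split, the new parent must inherit the child's hit_count; hit counts feed the selective write-through policy that decides which nodes get backed up to host memory, so resetting them on split would erase the node's access history. A minimal sketch of a metadata-preserving split (simplified node type, not the real TreeNode):

    from dataclasses import dataclass, field

    @dataclass
    class MiniNode:  # simplified stand-in for the radix tree node
        key: list
        hit_count: int = 0
        loading: bool = False
        children: dict = field(default_factory=dict)

    def split_node(child: MiniNode, split_len: int) -> MiniNode:
        # The new parent keeps the first split_len tokens of the key and,
        # crucially, inherits the child's metadata, including hit_count.
        new_node = MiniNode(key=child.key[:split_len])
        new_node.loading = child.loading
        new_node.hit_count = child.hit_count  # the fix: do not reset to 0
        child.key = child.key[split_len:]
        new_node.children[child.key[0]] = child
        return new_node

    node = MiniNode(key=[1, 2, 3, 4], hit_count=7)
    parent = split_node(node, 2)
    assert parent.hit_count == 7 and parent.key == [1, 2]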