Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
e56685ac
Unverified
Commit
e56685ac
authored
Jun 17, 2025
by
Zhiqiang Xie
Committed by
GitHub
Jun 17, 2025
Browse files
Upstreaming hicache bug fixes (#7267)
parent
c26d7349
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
76 additions
and
24 deletions
+76
-24
benchmark/hicache/bench_multiturn.py
benchmark/hicache/bench_multiturn.py
+1
-1
python/sglang/srt/managers/cache_controller.py
python/sglang/srt/managers/cache_controller.py
+28
-11
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+5
-8
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+19
-4
python/sglang/srt/managers/tp_worker.py
python/sglang/srt/managers/tp_worker.py
+9
-0
python/sglang/srt/managers/tp_worker_overlap_thread.py
python/sglang/srt/managers/tp_worker_overlap_thread.py
+11
-0
python/sglang/srt/mem_cache/hiradix_cache.py
python/sglang/srt/mem_cache/hiradix_cache.py
+3
-0
No files found.
benchmark/hicache/bench_multiturn.py
View file @
e56685ac
...
@@ -239,7 +239,7 @@ class WorkloadGenerator:
...
@@ -239,7 +239,7 @@ class WorkloadGenerator:
tokenizer
=
self
.
tokenizer
,
tokenizer
=
self
.
tokenizer
,
dataset_path
=
args
.
dataset_path
,
dataset_path
=
args
.
dataset_path
,
)
)
self
.
candidate_inputs
=
[
i
[
0
]
for
i
in
self
.
candidate_inputs
]
self
.
candidate_inputs
=
[
i
.
prompt
for
i
in
self
.
candidate_inputs
]
init_requests
=
[
init_requests
=
[
(
i
,
gen_payload
(
self
.
candidate_inputs
[
i
],
args
.
output_length
))
(
i
,
gen_payload
(
self
.
candidate_inputs
[
i
],
args
.
output_length
))
...
...
python/sglang/srt/managers/cache_controller.py
View file @
e56685ac
...
@@ -30,22 +30,37 @@ logger = logging.getLogger(__name__)
...
@@ -30,22 +30,37 @@ logger = logging.getLogger(__name__)
class
LayerDoneCounter
:
class
LayerDoneCounter
:
def
__init__
(
self
,
num_layers
):
def
__init__
(
self
,
num_layers
):
self
.
counter
=
num_layers
self
.
num_layers
=
num_layers
self
.
condition
=
threading
.
Condition
()
# extra producer and consumer counters for overlap mode
self
.
num_counters
=
3
self
.
counters
=
[
num_layers
]
*
self
.
num_counters
self
.
conditions
=
[
threading
.
Condition
()
for
_
in
range
(
self
.
num_counters
)]
self
.
producer_index
=
0
self
.
consumer_index
=
0
def
next_producer
(
self
):
return
(
self
.
producer_index
+
1
)
%
self
.
num_counters
def
update_producer
(
self
):
self
.
producer_index
=
self
.
next_producer
()
return
self
.
producer_index
def
set_consumer
(
self
,
index
):
self
.
consumer_index
=
index
def
increment
(
self
):
def
increment
(
self
):
with
self
.
condition
:
with
self
.
condition
s
[
self
.
producer_index
]
:
self
.
counter
+=
1
self
.
counter
s
[
self
.
producer_index
]
+=
1
self
.
condition
.
notify_all
()
self
.
condition
s
[
self
.
producer_index
]
.
notify_all
()
def
wait_until
(
self
,
threshold
):
def
wait_until
(
self
,
threshold
):
with
self
.
condition
:
with
self
.
condition
s
[
self
.
consumer_index
]
:
while
self
.
counter
<=
threshold
:
while
self
.
counter
s
[
self
.
consumer_index
]
<=
threshold
:
self
.
condition
.
wait
()
self
.
condition
s
[
self
.
consumer_index
]
.
wait
()
def
reset
(
self
):
def
reset
(
self
):
with
self
.
condition
:
with
self
.
condition
s
[
self
.
producer_index
]
:
self
.
counter
=
0
self
.
counter
s
[
self
.
producer_index
]
=
0
class
CacheOperation
:
class
CacheOperation
:
...
@@ -296,7 +311,6 @@ class HiCacheController:
...
@@ -296,7 +311,6 @@ class HiCacheController:
while
not
self
.
stop_event
.
is_set
():
while
not
self
.
stop_event
.
is_set
():
try
:
try
:
operation
=
self
.
load_queue
.
get
(
block
=
True
,
timeout
=
1
)
operation
=
self
.
load_queue
.
get
(
block
=
True
,
timeout
=
1
)
# time.sleep(18e-6 * len(operation.host_indices))
operation
.
data
=
self
.
mem_pool_host
.
get_flat_data
(
operation
.
data
=
self
.
mem_pool_host
.
get_flat_data
(
operation
.
host_indices
operation
.
host_indices
)
)
...
@@ -320,6 +334,7 @@ class HiCacheController:
...
@@ -320,6 +334,7 @@ class HiCacheController:
if
not
self
.
load_cache_event
.
is_set
():
if
not
self
.
load_cache_event
.
is_set
():
continue
continue
self
.
load_cache_event
.
clear
()
self
.
load_cache_event
.
clear
()
self
.
layer_done_counter
.
update_producer
()
batch_operation
=
None
batch_operation
=
None
while
self
.
load_queue
.
qsize
()
>
0
:
while
self
.
load_queue
.
qsize
()
>
0
:
...
@@ -331,6 +346,7 @@ class HiCacheController:
...
@@ -331,6 +346,7 @@ class HiCacheController:
if
batch_operation
is
None
:
if
batch_operation
is
None
:
continue
continue
# start layer-wise KV cache transfer from CPU to GPU
self
.
layer_done_counter
.
reset
()
self
.
layer_done_counter
.
reset
()
for
i
in
range
(
self
.
mem_pool_host
.
layer_num
):
for
i
in
range
(
self
.
mem_pool_host
.
layer_num
):
if
self
.
page_size
==
1
:
if
self
.
page_size
==
1
:
...
@@ -466,6 +482,7 @@ class HiCacheController:
...
@@ -466,6 +482,7 @@ class HiCacheController:
except
Exception
as
e
:
except
Exception
as
e
:
logger
.
error
(
e
)
logger
.
error
(
e
)
# todo (zhiqiang): double buffering to be deprecated
def
write_thread_func_buffer
(
self
):
def
write_thread_func_buffer
(
self
):
aux_thread
=
threading
.
Thread
(
target
=
self
.
write_aux_func
,
daemon
=
True
)
aux_thread
=
threading
.
Thread
(
target
=
self
.
write_aux_func
,
daemon
=
True
)
aux_thread
.
start
()
aux_thread
.
start
()
...
...
python/sglang/srt/managers/schedule_batch.py
View file @
e56685ac
...
@@ -659,14 +659,6 @@ class Req:
...
@@ -659,14 +659,6 @@ class Req:
self
.
prefix_indices
,
self
.
last_node
=
tree_cache
.
match_prefix
(
self
.
prefix_indices
,
self
.
last_node
=
tree_cache
.
match_prefix
(
rid
=
self
.
rid
,
key
=
self
.
adjust_max_prefix_ids
()
rid
=
self
.
rid
,
key
=
self
.
adjust_max_prefix_ids
()
)
)
elif
enable_hierarchical_cache
:
# in case last_node is evicted during scheduling, we need to update the prefix_indices
while
self
.
last_node
.
evicted
:
self
.
prefix_indices
=
self
.
prefix_indices
[
:
-
len
(
self
.
last_node
.
host_value
)
]
self
.
last_node
=
self
.
last_node
.
parent
self
.
extend_input_len
=
len
(
self
.
fill_ids
)
-
len
(
self
.
prefix_indices
)
self
.
extend_input_len
=
len
(
self
.
fill_ids
)
-
len
(
self
.
prefix_indices
)
def
adjust_max_prefix_ids
(
self
):
def
adjust_max_prefix_ids
(
self
):
...
@@ -909,6 +901,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
...
@@ -909,6 +901,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# Whether to return hidden states
# Whether to return hidden states
return_hidden_states
:
bool
=
False
return_hidden_states
:
bool
=
False
# hicache pointer for synchronizing data loading from CPU to GPU
hicache_consumer_index
:
int
=
0
@
classmethod
@
classmethod
def
init_new
(
def
init_new
(
cls
,
cls
,
...
@@ -1735,6 +1730,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
...
@@ -1735,6 +1730,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
token_type_ids
=
self
.
token_type_ids
,
token_type_ids
=
self
.
token_type_ids
,
spec_algorithm
=
self
.
spec_algorithm
,
spec_algorithm
=
self
.
spec_algorithm
,
spec_info
=
self
.
spec_info
,
spec_info
=
self
.
spec_info
,
hicache_consumer_index
=
self
.
hicache_consumer_index
,
capture_hidden_mode
=
(
capture_hidden_mode
=
(
CaptureHiddenMode
.
FULL
CaptureHiddenMode
.
FULL
if
self
.
return_hidden_states
if
self
.
return_hidden_states
...
@@ -1839,6 +1835,7 @@ class ModelWorkerBatch:
...
@@ -1839,6 +1835,7 @@ class ModelWorkerBatch:
# If set, the output of the batch contains the hidden states of the run.
# If set, the output of the batch contains the hidden states of the run.
capture_hidden_mode
:
CaptureHiddenMode
=
None
capture_hidden_mode
:
CaptureHiddenMode
=
None
spec_num_draft_tokens
:
Optional
[
int
]
=
None
spec_num_draft_tokens
:
Optional
[
int
]
=
None
hicache_consumer_index
:
int
=
0
# Overlap event
# Overlap event
launch_done
:
Optional
[
threading
.
Event
]
=
None
launch_done
:
Optional
[
threading
.
Event
]
=
None
...
...
python/sglang/srt/managers/scheduler.py
View file @
e56685ac
...
@@ -565,6 +565,10 @@ class Scheduler(
...
@@ -565,6 +565,10 @@ class Scheduler(
hicache_size
=
server_args
.
hicache_size
,
hicache_size
=
server_args
.
hicache_size
,
hicache_write_policy
=
server_args
.
hicache_write_policy
,
hicache_write_policy
=
server_args
.
hicache_write_policy
,
)
)
self
.
tp_worker
.
register_hicache_layer_transfer_counter
(
self
.
tree_cache
.
cache_controller
.
layer_done_counter
)
else
:
else
:
self
.
tree_cache
=
RadixCache
(
self
.
tree_cache
=
RadixCache
(
req_to_token_pool
=
self
.
req_to_token_pool
,
req_to_token_pool
=
self
.
req_to_token_pool
,
...
@@ -1514,8 +1518,13 @@ class Scheduler(
...
@@ -1514,8 +1518,13 @@ class Scheduler(
self
.
running_batch
.
batch_is_full
=
True
self
.
running_batch
.
batch_is_full
=
True
break
break
# bypass prefix_computed if enable_hierarchical_cache
req
.
init_next_round_input
(
req
.
init_next_round_input
(
None
if
prefix_computed
else
self
.
tree_cache
,
(
None
if
(
prefix_computed
and
not
self
.
enable_hierarchical_cache
)
else
self
.
tree_cache
),
self
.
enable_hierarchical_cache
,
self
.
enable_hierarchical_cache
,
)
)
...
@@ -1548,9 +1557,6 @@ class Scheduler(
...
@@ -1548,9 +1557,6 @@ class Scheduler(
x
for
x
in
self
.
waiting_queue
if
x
not
in
set
(
can_run_list
)
x
for
x
in
self
.
waiting_queue
if
x
not
in
set
(
can_run_list
)
]
]
if
self
.
enable_hierarchical_cache
:
self
.
tree_cache
.
ready_to_load_cache
()
if
adder
.
new_chunked_req
is
not
None
:
if
adder
.
new_chunked_req
is
not
None
:
assert
self
.
chunked_req
is
None
assert
self
.
chunked_req
is
None
self
.
chunked_req
=
adder
.
new_chunked_req
self
.
chunked_req
=
adder
.
new_chunked_req
...
@@ -1574,6 +1580,10 @@ class Scheduler(
...
@@ -1574,6 +1580,10 @@ class Scheduler(
self
.
server_args
.
enable_custom_logit_processor
,
self
.
server_args
.
enable_custom_logit_processor
,
chunked_req
=
self
.
chunked_req
,
chunked_req
=
self
.
chunked_req
,
)
)
if
self
.
enable_hierarchical_cache
:
# todo (zhiqiang): disable cuda graph execution if hicache loading triggered
new_batch
.
hicache_consumer_index
=
self
.
tree_cache
.
ready_to_load_cache
()
new_batch
.
prepare_for_extend
()
new_batch
.
prepare_for_extend
()
# Mixed-style chunked prefill
# Mixed-style chunked prefill
...
@@ -1649,6 +1659,11 @@ class Scheduler(
...
@@ -1649,6 +1659,11 @@ class Scheduler(
if
self
.
is_generation
:
if
self
.
is_generation
:
if
self
.
spec_algorithm
.
is_none
():
if
self
.
spec_algorithm
.
is_none
():
model_worker_batch
=
batch
.
get_model_worker_batch
()
model_worker_batch
=
batch
.
get_model_worker_batch
()
# update the consumer index of hicache to the running batch
self
.
tp_worker
.
set_hicache_consumer
(
model_worker_batch
.
hicache_consumer_index
)
if
self
.
pp_group
.
is_last_rank
:
if
self
.
pp_group
.
is_last_rank
:
logits_output
,
next_token_ids
,
can_run_cuda_graph
=
(
logits_output
,
next_token_ids
,
can_run_cuda_graph
=
(
self
.
tp_worker
.
forward_batch_generation
(
model_worker_batch
)
self
.
tp_worker
.
forward_batch_generation
(
model_worker_batch
)
...
...
python/sglang/srt/managers/tp_worker.py
View file @
e56685ac
...
@@ -147,6 +147,15 @@ class TpModelWorker:
...
@@ -147,6 +147,15 @@ class TpModelWorker:
# A reference make this class has the same member as TpModelWorkerClient
# A reference make this class has the same member as TpModelWorkerClient
self
.
worker
=
self
self
.
worker
=
self
self
.
hicache_layer_transfer_counter
=
None
def
register_hicache_layer_transfer_counter
(
self
,
counter
):
self
.
hicache_layer_transfer_counter
=
counter
def
set_hicache_consumer
(
self
,
consumer_index
):
if
self
.
hicache_layer_transfer_counter
is
not
None
:
self
.
hicache_layer_transfer_counter
.
set_consumer
(
consumer_index
)
def
get_worker_info
(
self
):
def
get_worker_info
(
self
):
return
(
return
(
self
.
max_total_num_tokens
,
self
.
max_total_num_tokens
,
...
...
python/sglang/srt/managers/tp_worker_overlap_thread.py
View file @
e56685ac
...
@@ -88,6 +88,15 @@ class TpModelWorkerClient:
...
@@ -88,6 +88,15 @@ class TpModelWorkerClient:
if
self
.
device
==
"cpu"
:
if
self
.
device
==
"cpu"
:
self
.
scheduler_stream
.
synchronize
=
lambda
:
None
# No-op for CPU
self
.
scheduler_stream
.
synchronize
=
lambda
:
None
# No-op for CPU
self
.
hicache_layer_transfer_counter
=
None
def
register_hicache_layer_transfer_counter
(
self
,
counter
):
self
.
hicache_layer_transfer_counter
=
counter
def
set_hicache_consumer
(
self
,
consumer_index
):
if
self
.
hicache_layer_transfer_counter
is
not
None
:
self
.
hicache_layer_transfer_counter
.
set_consumer
(
consumer_index
)
def
get_worker_info
(
self
):
def
get_worker_info
(
self
):
return
self
.
worker
.
get_worker_info
()
return
self
.
worker
.
get_worker_info
()
...
@@ -146,6 +155,8 @@ class TpModelWorkerClient:
...
@@ -146,6 +155,8 @@ class TpModelWorkerClient:
input_ids
=
model_worker_batch
.
input_ids
input_ids
=
model_worker_batch
.
input_ids
resolve_future_token_ids
(
input_ids
,
self
.
future_token_ids_map
)
resolve_future_token_ids
(
input_ids
,
self
.
future_token_ids_map
)
# update the consumer index of hicache to the running batch
self
.
set_hicache_consumer
(
model_worker_batch
.
hicache_consumer_index
)
# Run forward
# Run forward
logits_output
,
next_token_ids
,
can_run_cuda_graph
=
(
logits_output
,
next_token_ids
,
can_run_cuda_graph
=
(
self
.
worker
.
forward_batch_generation
(
self
.
worker
.
forward_batch_generation
(
...
...
python/sglang/srt/mem_cache/hiradix_cache.py
View file @
e56685ac
...
@@ -307,7 +307,9 @@ class HiRadixCache(RadixCache):
...
@@ -307,7 +307,9 @@ class HiRadixCache(RadixCache):
return
last_node
,
prefix_indices
return
last_node
,
prefix_indices
def
ready_to_load_cache
(
self
):
def
ready_to_load_cache
(
self
):
producer_index
=
self
.
cache_controller
.
layer_done_counter
.
next_producer
()
self
.
load_cache_event
.
set
()
self
.
load_cache_event
.
set
()
return
producer_index
def
match_prefix
(
self
,
key
:
List
[
int
],
include_evicted
=
False
,
**
kwargs
):
def
match_prefix
(
self
,
key
:
List
[
int
],
include_evicted
=
False
,
**
kwargs
):
empty_value
=
torch
.
empty
((
0
,),
dtype
=
torch
.
int64
,
device
=
self
.
device
)
empty_value
=
torch
.
empty
((
0
,),
dtype
=
torch
.
int64
,
device
=
self
.
device
)
...
@@ -372,6 +374,7 @@ class HiRadixCache(RadixCache):
...
@@ -372,6 +374,7 @@ class HiRadixCache(RadixCache):
new_node
.
lock_ref
=
child
.
lock_ref
new_node
.
lock_ref
=
child
.
lock_ref
new_node
.
key
=
child
.
key
[:
split_len
]
new_node
.
key
=
child
.
key
[:
split_len
]
new_node
.
loading
=
child
.
loading
new_node
.
loading
=
child
.
loading
new_node
.
hit_count
=
child
.
hit_count
# split value and host value if exists
# split value and host value if exists
if
child
.
evicted
:
if
child
.
evicted
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment