change / sglang / Commits / 948b01a0

Unverified commit 948b01a0, authored Sep 08, 2025 by DarkSharpness, committed by GitHub on Sep 08, 2025
Parent: cdc56ef6

[Refactor] Remove Hicache Load & Write threads (#10127)

Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>

Showing 10 changed files with 217 additions and 206 deletions (+217 −206)
python/sglang/srt/managers/cache_controller.py          +140  −148
python/sglang/srt/managers/schedule_batch.py              +2    −2
python/sglang/srt/managers/scheduler.py                    +0    −4
python/sglang/srt/managers/tp_worker.py                   +10    −3
python/sglang/srt/managers/tp_worker_overlap_thread.py     +8   −10
python/sglang/srt/mem_cache/hiradix_cache.py              +48   −32
python/sglang/srt/mem_cache/memory_pool.py                 +7    −2
python/sglang/srt/mem_cache/memory_pool_host.py            +2    −1
python/sglang/srt/mem_cache/radix_cache.py                 +0    −2
python/sglang/srt/mem_cache/swa_radix_cache.py             +0    −2
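The thread running through all ten files: the controller's dedicated write/load worker threads and their blocking queues are replaced by plain Python lists plus torch.cuda.Event objects recorded on the controller's CUDA streams, which callers poll (query) or wait on instead of blocking on Queue.get(). A minimal sketch of that event-and-stream pattern, not taken from this commit and assuming a CUDA-capable PyTorch build:

import torch

side_stream = torch.cuda.Stream()                 # stands in for write_stream / load_stream
src = torch.randn(1024, device="cuda")
dst = torch.empty(1024, pin_memory=True)          # host-side staging buffer

start_event = torch.cuda.Event()
finish_event = torch.cuda.Event()

start_event.record()                              # recorded on the caller's stream
with torch.cuda.stream(side_stream):
    start_event.wait(side_stream)                 # side stream waits for the caller's prior work
    dst.copy_(src, non_blocking=True)             # async device-to-host copy
    src.record_stream(side_stream)                # keep src alive while the side stream uses it
    finish_event.record()

# Instead of a background thread blocking on a queue, the caller polls:
if not finish_event.query():                      # non-blocking completion check
    finish_event.synchronize()                    # block only when the result is actually needed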
python/sglang/srt/managers/cache_controller.py
@@ -18,7 +18,7 @@ import math
 import threading
 import time
 from queue import Empty, Full, PriorityQueue, Queue
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, List, NamedTuple, Optional, Set, Tuple

 import torch
@@ -43,39 +43,53 @@ from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPool

 logger = logging.getLogger(__name__)


+class LayerLoadingEvent:
+    def __init__(self, num_layers: int):
+        self._num_layers = num_layers
+        self.load_events = [torch.cuda.Event() for _ in range(num_layers)]
+        self.start_event = torch.cuda.Event()  # start event on controller stream
+
+    def complete(self, layer_index: int):
+        assert 0 <= layer_index < self._num_layers
+        self.load_events[layer_index].record()
+
+    def wait(self, layer_index: int):
+        torch.cuda.current_stream().wait_event(self.load_events[layer_index])
+
+    @property
+    def finish_event(self):
+        return self.load_events[-1]
+
+
 class LayerDoneCounter:
-    def __init__(self, num_layers):
+    def __init__(self, num_layers: int):
         self.num_layers = num_layers
         # extra producer and consumer counters for overlap mode
         self.num_counters = 3
-        self.counters = [num_layers] * self.num_counters
-        self.conditions = [threading.Condition() for _ in range(self.num_counters)]
-        self.producer_index = 0
-        self.consumer_index = 0
-
-    def next_producer(self):
-        return (self.producer_index + 1) % self.num_counters
+        self.events = [LayerLoadingEvent(num_layers) for _ in range(self.num_counters)]
+        self.producer_index = -1
+        self.consumer_index = -1

     def update_producer(self):
-        self.producer_index = self.next_producer()
+        self.producer_index = (self.producer_index + 1) % self.num_counters
+        assert self.events[self.producer_index].finish_event.query(), (
+            "Producer finish event should be ready before being reused."
+        )
         return self.producer_index

-    def set_consumer(self, index):
+    def set_consumer(self, index: int):
         self.consumer_index = index

-    def increment(self):
-        with self.conditions[self.producer_index]:
-            self.counters[self.producer_index] += 1
-            self.conditions[self.producer_index].notify_all()
-
-    def wait_until(self, threshold):
-        with self.conditions[self.consumer_index]:
-            while self.counters[self.consumer_index] <= threshold:
-                self.conditions[self.consumer_index].wait()
+    def wait_until(self, threshold: int):
+        if self.consumer_index < 0:
+            return
+        self.events[self.consumer_index].wait(threshold)

     def reset(self):
-        with self.conditions[self.producer_index]:
-            self.counters[self.producer_index] = 0
+        self.producer_index = -1
+        self.consumer_index = -1


 class CacheOperation:
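To make the rotation above concrete: LayerDoneCounter keeps three LayerLoadingEvent slots so one batch can still be consumed while the next one is produced, and update_producer() asserts on finish_event.query() before a slot is handed out again. A hedged, self-contained sketch of that rotation (the class bodies above are the real ones; this harness is illustrative only and assumes a CUDA device):

import torch

num_counters, num_layers = 3, 2    # three slots, as in LayerDoneCounter above
events = [[torch.cuda.Event() for _ in range(num_layers)] for _ in range(num_counters)]
producer_index = -1

def update_producer() -> int:
    global producer_index
    producer_index = (producer_index + 1) % num_counters
    # a slot's finish event is its last layer's event; an unfinished slot must not be reused
    assert events[producer_index][-1].query(), "producer slot still in flight"
    return producer_index

print(update_producer())    # 0 -- the first load batch produces into slot 0
print(update_producer())    # 1 -- the next batch rotates to slot 1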
@@ -99,36 +113,30 @@ class CacheOperation:
         # default priority is the order of creation
         self.priority = priority if priority is not None else self.id

-    def merge(self, other: "CacheOperation") -> None:
-        # multiple operations can be merged into a single operation for batch processing
-        self.host_indices = torch.cat([self.host_indices, other.host_indices])
-        self.device_indices = torch.cat([self.device_indices, other.device_indices])
-        self.priority = min(self.priority, other.priority)
-        self.node_ids.extend(other.node_ids)
-
-    def split(self, factor) -> List["CacheOperation"]:
-        # split an operation into smaller operations to reduce the size of intermediate buffers
-        if factor <= 1:
-            return [self]
-
-        chunk_size = math.ceil(len(self.host_indices) / factor)
-        split_ops = []
-        for i in range(0, len(self.host_indices), chunk_size):
-            split_ops.append(
-                CacheOperation(
-                    host_indices=self.host_indices[i : i + chunk_size],
-                    device_indices=self.device_indices[i : i + chunk_size],
-                    node_id=0,
-                )
-            )
-        # Inherit the node_ids on the final chunk
-        if split_ops:
-            split_ops[-1].node_ids = self.node_ids
-
-        return split_ops
-
-    def __lt__(self, other: "CacheOperation"):
+    @staticmethod
+    def merge_ops(ops: List[CacheOperation]) -> CacheOperation:
+        assert len(ops) > 0
+        if len(ops) == 1:
+            return ops[0]
+
+        host_indices = torch.cat([op.host_indices for op in ops])
+        device_indices = torch.cat([op.device_indices for op in ops])
+        node_ids = []
+        priority = min(op.priority for op in ops)
+        for op in ops:
+            node_ids.extend(op.node_ids)
+        merged_op = CacheOperation(host_indices, device_indices, -1, priority)
+        merged_op.node_ids = node_ids
+        return merged_op
+
+    def __lt__(self, other: CacheOperation):
         return self.priority < other.priority


+class HiCacheAck(NamedTuple):
+    start_event: torch.cuda.Event
+    finish_event: torch.cuda.Event
+    node_ids: List[int]
+
+
 class TransferBuffer:
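The incremental merge()/split() pair is gone; start_writing()/start_loading() now drain the whole queue at once through merge_ops, which concatenates indices, keeps the lowest priority, and carries every node id forward so each waiting radix-tree node is still acknowledged. A rough sketch of that coalescing with dummy stand-ins for the queued operations (illustrative only, not the real CacheOperation class):

import torch

queued = [  # hypothetical queued operations: (host_indices, device_indices, node_ids, priority)
    (torch.tensor([0, 1]), torch.tensor([10, 11]), [3], 7),
    (torch.tensor([2, 3]), torch.tensor([12, 13]), [5], 2),
]

host_indices = torch.cat([h for h, _, _, _ in queued])
device_indices = torch.cat([d for _, d, _, _ in queued])
node_ids = [n for _, _, ids, _ in queued for n in ids]
priority = min(p for _, _, _, p in queued)

print(host_indices.tolist(), device_indices.tolist(), node_ids, priority)
# [0, 1, 2, 3] [10, 11, 12, 13] [3, 5] 2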
@@ -236,7 +244,7 @@ class HiCacheController:
         mem_pool_host: HostKVCache,
         page_size: int,
         tp_group: torch.distributed.ProcessGroup,
-        load_cache_event: threading.Event = None,
+        load_cache_event: threading.Event,
         write_policy: str = "write_through_selective",
         io_backend: str = "",
         storage_backend: Optional[str] = None,

@@ -340,8 +348,9 @@ class HiCacheController:
             self.page_set_func = self._3fs_zero_copy_page_set
             self.batch_exists_func = self._3fs_zero_copy_batch_exists

-        self.load_cache_event = load_cache_event
-        self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
+        self.device = self.mem_pool_device.device
+        self.layer_num = self.mem_pool_device.layer_num
+        self.layer_done_counter = LayerDoneCounter(self.layer_num)
         self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)

         if write_policy not in [
@@ -351,11 +360,11 @@ class HiCacheController:
         ]:
             raise ValueError(f"Invalid write policy: {write_policy}")

-        self.write_queue = PriorityQueue()
-        self.load_queue = PriorityQueue()
-
-        self.ack_write_queue = Queue()
-        self.ack_load_queue = Queue()
+        # self.write_queue = PriorityQueue[CacheOperation]()
+        self.load_queue: List[CacheOperation] = []
+        self.write_queue: List[CacheOperation] = []
+        self.ack_load_queue: List[HiCacheAck] = []
+        self.ack_write_queue: List[HiCacheAck] = []

         self.stop_event = threading.Event()
         self.write_buffer = TransferBuffer(self.stop_event)
@@ -366,16 +375,6 @@ class HiCacheController:
         self.write_stream = torch.cuda.Stream()
         self.load_stream = torch.cuda.Stream()

-        self.write_thread = threading.Thread(
-            target=self.write_thread_func_direct, daemon=True
-        )
-        self.load_thread = threading.Thread(
-            target=self.load_thread_func_layer_by_layer, daemon=True
-        )
-        self.write_thread.start()
-        self.load_thread.start()
-
         if self.enable_storage:
             self.prefetch_thread = threading.Thread(
                 target=self.prefetch_thread_func, daemon=True
@@ -432,15 +431,13 @@ class HiCacheController:
     def reset(self):
         self.stop_event.set()
-        self.write_thread.join()
-        self.load_thread.join()

-        self.write_queue.queue.clear()
-        self.load_queue.queue.clear()
+        self.write_queue.clear()
+        self.load_queue.clear()
         self.write_buffer.clear()
         self.load_buffer.clear()
-        self.ack_write_queue.queue.clear()
-        self.ack_load_queue.queue.clear()
+        self.ack_write_queue.clear()
+        self.ack_load_queue.clear()
         if self.enable_storage:
             self.prefetch_thread.join()
             self.backup_thread.join()
@@ -449,15 +446,7 @@ class HiCacheController:
             self.prefetch_revoke_queue.queue.clear()
             self.ack_backup_queue.queue.clear()

-        self.write_thread = threading.Thread(
-            target=self.write_thread_func_direct, daemon=True
-        )
-        self.load_thread = threading.Thread(
-            target=self.load_thread_func_layer_by_layer, daemon=True
-        )
         self.stop_event.clear()
-        self.write_thread.start()
-        self.load_thread.start()

         if self.enable_storage:
             self.prefetch_thread = threading.Thread(
@@ -473,7 +462,7 @@ class HiCacheController:
         self,
         device_indices: torch.Tensor,
         priority: Optional[int] = None,
-        node_id: int = 0,
+        node_id: int = -1,
     ) -> Optional[torch.Tensor]:
         """
         Back up KV caches from device memory to host memory.
@@ -482,17 +471,46 @@ class HiCacheController:
         if host_indices is None:
             return None
         self.mem_pool_host.protect_write(host_indices)
-        torch.cuda.current_stream().synchronize()
-        self.write_queue.put(
+        self.write_queue.append(
             CacheOperation(host_indices, device_indices, node_id, priority)
         )
+        self.start_writing()
         return host_indices

+    def start_writing(self) -> None:
+        if len(self.write_queue) == 0:
+            return
+
+        op = CacheOperation.merge_ops(self.write_queue)
+        host_indices, device_indices = self.move_indices(op)
+        self.write_queue.clear()
+
+        start_event = torch.cuda.Event()
+        finish_event = torch.cuda.Event()
+
+        start_event.record()
+        with torch.cuda.stream(self.write_stream):
+            start_event.wait(self.write_stream)
+            self.mem_pool_host.backup_from_device_all_layer(
+                self.mem_pool_device, host_indices, device_indices, self.io_backend
+            )
+            self.mem_pool_host.complete_io(op.host_indices)
+            finish_event.record()
+            # NOTE: We must save the host indices and device indices here,
+            # this is because we need to guarantee that these tensors are
+            # still alive when the write stream is executing.
+            if host_indices.is_cuda:
+                host_indices.record_stream(self.write_stream)
+            if device_indices.is_cuda:
+                device_indices.record_stream(self.write_stream)
+
+        self.ack_write_queue.append(HiCacheAck(start_event, finish_event, op.node_ids))
+
     def load(
         self,
         host_indices: torch.Tensor,
         priority: Optional[int] = None,
-        node_id: int = 0,
+        node_id: int = -1,
     ) -> Optional[torch.Tensor]:
         """
         Load KV caches from host memory to device memory.
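With write_thread_func_direct gone, nothing blocks on an ack queue anymore: start_writing() appends a HiCacheAck per batched write, and the radix cache later retires only those acks whose finish_event has completed. A hedged sketch of that poll-and-retire loop (HiCacheAck matches the NamedTuple added above; the harness is illustrative and assumes a CUDA device):

import torch
from typing import List, NamedTuple

class HiCacheAck(NamedTuple):
    start_event: torch.cuda.Event
    finish_event: torch.cuda.Event
    node_ids: List[int]

ack_write_queue: List[HiCacheAck] = []

# producer side, as in start_writing(): record events around the async copy
start_event, finish_event = torch.cuda.Event(), torch.cuda.Event()
start_event.record()
# ... the device-to-host copy would be enqueued on write_stream here ...
finish_event.record()
ack_write_queue.append(HiCacheAck(start_event, finish_event, node_ids=[42]))

# consumer side, as in writing_check(): retire only acks whose copy already finished
finished = 0
for ack in ack_write_queue:
    if not ack.finish_event.query():    # still running on the write stream
        break
    finished += 1
del ack_write_queue[:finished]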
@@ -501,17 +519,18 @@ class HiCacheController:
         if device_indices is None:
             return None
         self.mem_pool_host.protect_load(host_indices)
-        # to ensure the device indices are ready before accessed by another CUDA stream
-        torch.cuda.current_stream().synchronize()
-        self.load_queue.put(
+        self.load_queue.append(
             CacheOperation(host_indices, device_indices, node_id, priority)
         )
         return device_indices

-    def move_indices(self, host_indices, device_indices):
+    def move_indices(self, op: CacheOperation):
+        host_indices, device_indices = op.host_indices, op.device_indices
         # move indices to GPU if using kernels, to host if using direct indexing
         if self.io_backend == "kernel":
-            return host_indices.to(self.mem_pool_device.device), device_indices
+            if not host_indices.is_cuda:
+                host_indices = host_indices.to(self.device, non_blocking=True)
+            return host_indices, device_indices
         elif self.io_backend == "direct":
             device_indices = device_indices.cpu()
             host_indices, idx = host_indices.sort()
@@ -519,58 +538,20 @@ class HiCacheController:
         else:
             raise ValueError(f"Unsupported io backend")

-    def write_thread_func_direct(self):
-        """
-        Directly write through KV caches to host memory without buffering.
-        """
-        torch.cuda.set_stream(self.write_stream)
-        while not self.stop_event.is_set():
-            try:
-                operation = self.write_queue.get(block=True, timeout=1)
-                host_indices, device_indices = self.move_indices(
-                    operation.host_indices, operation.device_indices
-                )
-                self.mem_pool_host.backup_from_device_all_layer(
-                    self.mem_pool_device, host_indices, device_indices, self.io_backend
-                )
-                self.write_stream.synchronize()
-                self.mem_pool_host.complete_io(operation.host_indices)
-                for node_id in operation.node_ids:
-                    if node_id != 0:
-                        self.ack_write_queue.put(node_id)
-            except Empty:
-                continue
-            except Exception as e:
-                logger.error(e)
-
-    def load_thread_func_layer_by_layer(self):
-        """
-        Load KV caches from host memory to device memory layer by layer.
-        """
-        torch.cuda.set_stream(self.load_stream)
-        while not self.stop_event.is_set():
-            self.load_cache_event.wait(timeout=1)
-            if not self.load_cache_event.is_set():
-                continue
-            self.load_cache_event.clear()
-            self.layer_done_counter.update_producer()
-
-            batch_operation = None
-            while self.load_queue.qsize() > 0:
-                op = self.load_queue.get(block=True)
-                if batch_operation is None:
-                    batch_operation = op
-                else:
-                    batch_operation.merge(op)
-            if batch_operation is None:
-                continue
-
-            # start layer-wise KV cache transfer from CPU to GPU
-            self.layer_done_counter.reset()
-            host_indices, device_indices = self.move_indices(
-                batch_operation.host_indices, batch_operation.device_indices
-            )
-            for i in range(self.mem_pool_host.layer_num):
+    def start_loading(self) -> int:
+        if len(self.load_queue) == 0:
+            return -1
+
+        producer_id = self.layer_done_counter.update_producer()
+        op = CacheOperation.merge_ops(self.load_queue)
+        host_indices, device_indices = self.move_indices(op)
+        self.load_queue.clear()
+        producer_event = self.layer_done_counter.events[producer_id]
+        producer_event.start_event.record()
+
+        # start layer-wise KV cache transfer from CPU to GPU
+        with torch.cuda.stream(self.load_stream):
+            producer_event.start_event.wait(self.load_stream)
+            for i in range(self.layer_num):
                 self.mem_pool_host.load_to_device_per_layer(
                     self.mem_pool_device,
                     host_indices,
@@ -578,13 +559,24 @@ class HiCacheController:
                     i,
                     self.io_backend,
                 )
-            self.load_stream.synchronize()
-            self.layer_done_counter.increment()
-
-            self.mem_pool_host.complete_io(batch_operation.host_indices)
-            for node_id in batch_operation.node_ids:
-                if node_id != 0:
-                    self.ack_load_queue.put(node_id)
+                producer_event.complete(i)
+            self.mem_pool_host.complete_io(op.host_indices)
+            # NOTE: We must save the host indices and device indices here,
+            # this is because we need to guarantee that these tensors are
+            # still alive when the load stream is executing.
+            if host_indices.is_cuda:
+                host_indices.record_stream(self.load_stream)
+            if device_indices.is_cuda:
+                device_indices.record_stream(self.load_stream)
+
+        self.ack_load_queue.append(
+            HiCacheAck(
+                start_event=producer_event.start_event,
+                finish_event=producer_event.finish_event,
+                node_ids=op.node_ids,
+            )
+        )
+        return producer_id

     def evict_device(
         self, device_indices: torch.Tensor, host_indices: torch.Tensor
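start_loading() returns the producer index, which travels to the worker as hicache_consumer_index (see schedule_batch.py and tp_worker.py below); the KV pool registered via register_layer_transfer_counter is the consumer that calls wait_until(layer_id), so each layer's attention only waits for its own copy. A hedged sketch of that per-layer overlap on the consumer side (illustrative only, assumes a CUDA device):

import torch

num_layers = 4
load_stream = torch.cuda.Stream()
layer_events = [torch.cuda.Event() for _ in range(num_layers)]  # one event per layer

# producer, as in start_loading(): copy layer by layer on load_stream, record per layer
with torch.cuda.stream(load_stream):
    for i in range(num_layers):
        # ... mem_pool_host.load_to_device_per_layer(...) would run here ...
        layer_events[i].record()        # producer_event.complete(i)

consumer_index = 0                      # value carried by hicache_consumer_index; -1 means idle

def wait_until(layer_id: int) -> None:  # mirrors what LayerDoneCounter.wait_until now does
    if consumer_index < 0:
        return
    torch.cuda.current_stream().wait_event(layer_events[layer_id])

for layer_id in range(num_layers):
    wait_until(layer_id)                # attention for this layer may now read its KV cache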
python/sglang/srt/managers/schedule_batch.py

@@ -911,7 +911,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     is_prefill_only: bool = False

     # hicache pointer for synchronizing data loading from CPU to GPU
-    hicache_consumer_index: int = 0
+    hicache_consumer_index: int = -1

     @classmethod
     def init_new(

@@ -1897,7 +1897,7 @@ class ModelWorkerBatch:
     spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
     # If set, the output of the batch contains the hidden states of the run.
     capture_hidden_mode: CaptureHiddenMode = None
-    hicache_consumer_index: int = 0
+    hicache_consumer_index: int = -1

     # Overlap event
     launch_done: Optional[threading.Event] = None
python/sglang/srt/managers/scheduler.py

@@ -1807,10 +1807,6 @@ class Scheduler(
         if self.spec_algorithm.is_none():
             model_worker_batch = batch.get_model_worker_batch()

-            # update the consumer index of hicache to the running batch
-            self.tp_worker.set_hicache_consumer(
-                model_worker_batch.hicache_consumer_index
-            )
             if self.pp_group.is_last_rank:
                 logits_output, next_token_ids, can_run_cuda_graph = (
                     self.tp_worker.forward_batch_generation(model_worker_batch)
python/sglang/srt/managers/tp_worker.py

@@ -12,10 +12,11 @@
 # limitations under the License.
 # ==============================================================================
 """A tensor parallel worker."""
+from __future__ import annotations

 import logging
 import threading
-from typing import Optional, Tuple, Union
+from typing import TYPE_CHECKING, Optional, Tuple, Union

 import torch

@@ -45,6 +46,9 @@ from sglang.srt.patch_torch import monkey_patch_torch_reductions
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed

+if TYPE_CHECKING:
+    from sglang.srt.managers.cache_controller import LayerDoneCounter
+
 logger = logging.getLogger(__name__)

@@ -167,10 +171,10 @@ class TpModelWorker:
             self.hicache_layer_transfer_counter = None

-    def register_hicache_layer_transfer_counter(self, counter):
+    def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):
         self.hicache_layer_transfer_counter = counter

-    def set_hicache_consumer(self, consumer_index):
+    def set_hicache_consumer(self, consumer_index: int):
         if self.hicache_layer_transfer_counter is not None:
             self.hicache_layer_transfer_counter.set_consumer(consumer_index)

@@ -230,6 +234,9 @@ class TpModelWorker:
     ) -> Tuple[Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor], bool]:
+        # update the consumer index of hicache to the running batch
+        self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
+
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)

         pp_proxy_tensors = None
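Net effect of the scheduler/worker changes: the scheduler stops pushing the consumer index into the worker, and forward_batch_generation picks it up from the batch on every path, including the overlap worker, which delegates to TpModelWorker. A hedged sketch of that call flow with stub classes (names mirror the diffs, bodies are placeholders):

class StubCounter:                       # stands in for LayerDoneCounter
    def set_consumer(self, index: int) -> None:
        print("consumer index set to", index)

class StubTpModelWorker:
    def __init__(self) -> None:
        self.hicache_layer_transfer_counter = StubCounter()

    def set_hicache_consumer(self, consumer_index: int) -> None:
        if self.hicache_layer_transfer_counter is not None:
            self.hicache_layer_transfer_counter.set_consumer(consumer_index)

    def forward_batch_generation(self, hicache_consumer_index: int) -> None:
        # after this commit the worker reads the index from the batch itself
        self.set_hicache_consumer(hicache_consumer_index)
        # ... model forward would follow ...

StubTpModelWorker().forward_batch_generation(hicache_consumer_index=-1)   # -1: nothing loading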
python/sglang/srt/managers/tp_worker_overlap_thread.py

@@ -12,13 +12,14 @@
 # limitations under the License.
 # ==============================================================================
 """A tensor parallel worker."""
+from __future__ import annotations

 import dataclasses
 import logging
 import signal
 import threading
 from queue import Queue
-from typing import Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple

 import psutil
 import torch

@@ -38,6 +39,9 @@ from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import DynamicGradMode, get_compiler_backend
 from sglang.utils import get_exception_traceback

+if TYPE_CHECKING:
+    from sglang.srt.managers.cache_controller import LayerDoneCounter
+
 logger = logging.getLogger(__name__)

@@ -79,7 +83,7 @@ class TpModelWorkerClient:
         )

         # Launch threads
-        self.input_queue = Queue()
+        self.input_queue = Queue[Tuple[ModelWorkerBatch, int, torch.Event]]()
         self.output_queue = Queue()
         self.forward_stream = torch.get_device_module(self.device).Stream()
         self.forward_thread = threading.Thread(

@@ -93,13 +97,9 @@ class TpModelWorkerClient:
         self.hicache_layer_transfer_counter = None

-    def register_hicache_layer_transfer_counter(self, counter):
+    def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):
         self.hicache_layer_transfer_counter = counter

-    def set_hicache_consumer(self, consumer_index):
-        if self.hicache_layer_transfer_counter is not None:
-            self.hicache_layer_transfer_counter.set_consumer(consumer_index)
-
     def get_worker_info(self):
         return self.worker.get_worker_info()

@@ -147,7 +147,7 @@ class TpModelWorkerClient:
     @DynamicGradMode()
     def forward_thread_func_(self):
         batch_pt = 0
-        batch_lists = [None] * 2
+        batch_lists: List = [None] * 2
         while True:
             model_worker_batch, future_token_ids_ct, sync_event = self.input_queue.get()

@@ -169,8 +169,6 @@ class TpModelWorkerClient:
             input_ids = model_worker_batch.input_ids
             resolve_future_token_ids(input_ids, self.future_token_ids_map)

-            # update the consumer index of hicache to the running batch
-            self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
-
             # Run forward
             logits_output, next_token_ids, can_run_cuda_graph = (
                 self.worker.forward_batch_generation(
python/sglang/srt/mem_cache/hiradix_cache.py

@@ -201,41 +201,57 @@ class HiRadixCache(RadixCache):
         if write_back:
             # blocking till all write back complete
-            while len(self.ongoing_write_through) > 0:
-                ack_id = self.cache_controller.ack_write_queue.get()
-                del self.ongoing_write_through[ack_id]
+            for _, finish_event, ack_list in self.cache_controller.ack_write_queue:
+                finish_event.synchronize()
+                for ack_id in ack_list:
+                    del self.ongoing_write_through[ack_id]
+            self.cache_controller.ack_write_queue.clear()
+            assert len(self.ongoing_write_through) == 0
             return
-        queue_size = torch.tensor(
-            self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
-        )
+
+        # NOTE: all ranks has the same ongoing_write_through, can skip sync if empty
+        if len(self.ongoing_write_through) == 0:
+            return
+
+        finish_count = 0
+        for _, finish_event, ack_list in self.cache_controller.ack_write_queue:
+            if not finish_event.query():
+                break
+            finish_count += 1
+        queue_size = torch.tensor(finish_count, dtype=torch.int, device="cpu")
         if self.tp_world_size > 1:
-            # synchrnoize TP workers to make the same update to radix cache
+            # synchronize TP workers to make the same update to radix cache
             torch.distributed.all_reduce(
                 queue_size,
                 op=torch.distributed.ReduceOp.MIN,
                 group=self.tp_group,
             )
-        for _ in range(queue_size.item()):
-            ack_id = self.cache_controller.ack_write_queue.get()
-            backuped_node = self.ongoing_write_through[ack_id]
-            self.dec_lock_ref(backuped_node)
-            del self.ongoing_write_through[ack_id]
-            if self.enable_storage:
-                self.write_backup_storage(backuped_node)
+
+        finish_count = int(queue_size.item())
+        while finish_count > 0:
+            _, finish_event, ack_list = self.cache_controller.ack_write_queue.pop(0)
+            finish_event.synchronize()
+            for ack_id in ack_list:
+                backuped_node = self.ongoing_write_through.pop(ack_id)
+                self.dec_lock_ref(backuped_node)
+                if self.enable_storage:
+                    self.write_backup_storage(backuped_node)
+            finish_count -= 1

     def loading_check(self):
-        while not self.cache_controller.ack_load_queue.empty():
-            try:
-                ack_id = self.cache_controller.ack_load_queue.get_nowait()
-                start_node, end_node = self.ongoing_load_back[ack_id]
-                self.dec_lock_ref(end_node)
-                while end_node != start_node:
-                    assert end_node.loading
-                    end_node.loading = False
-                    end_node = end_node.parent
-                # clear the reference
-                del self.ongoing_load_back[ack_id]
-            except Exception:
-                break
+        finish_count = 0
+        for _, finish_event, ack_list in self.cache_controller.ack_load_queue:
+            if not finish_event.query():
+                # the KV cache loading is still ongoing
+                break
+            finish_count += 1
+            # no need to sync across TP workers as batch forwarding is synced
+            for ack_id in ack_list:
+                end_node = self.ongoing_load_back.pop(ack_id)
+                self.dec_lock_ref(end_node)
+        # ACK until all events are processed
+        del self.cache_controller.ack_load_queue[:finish_count]

     def evictable_size(self):
         return self.evictable_size_
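The MIN all-reduce in writing_check deserves a note: each TP rank now polls its own CUDA events, so at any instant the ranks may have observed different numbers of finished write-backs; reducing with MIN retires only the count every rank agrees on, keeping the radix tree identical across ranks. A tiny illustration with hypothetical per-rank poll results:

# hypothetical finish counts observed by 4 TP ranks at the same moment
finish_count_per_rank = [3, 1, 2, 3]

# torch.distributed.all_reduce(queue_size, op=ReduceOp.MIN) leaves every rank with:
safe_to_retire = min(finish_count_per_rank)
print(safe_to_retire)   # 1 -> every rank pops exactly one ack this round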
@@ -360,12 +376,11 @@ class HiRadixCache(RadixCache):
             # no sufficient GPU memory to load back KV caches
             return None

-        self.ongoing_load_back[last_hit_node.id] = (ancester_node, last_hit_node)
+        self.ongoing_load_back[last_hit_node.id] = last_hit_node
         offset = 0
         for node in nodes_to_load:
             node.value = device_indices[offset : offset + len(node.host_value)]
             offset += len(node.host_value)
-            node.loading = True
         self.evictable_size_ += len(device_indices)
         self.inc_lock_ref(last_hit_node)

@@ -394,10 +409,12 @@ class HiRadixCache(RadixCache):
             last_node,
         )

-    def ready_to_load_host_cache(self):
-        producer_index = self.cache_controller.layer_done_counter.next_producer()
-        self.load_cache_event.set()
-        return producer_index
+    def ready_to_load_host_cache(self) -> int:
+        """
+        Notify the cache controller to start the KV cache loading.
+        Return the consumer index for the schedule batch manager to track.
+        """
+        return self.cache_controller.start_loading()

     def check_hicache_events(self):
         self.writing_check()

@@ -702,7 +719,6 @@ class HiRadixCache(RadixCache):
         new_node.parent = child.parent
         new_node.lock_ref = child.lock_ref
         new_node.key = child.key[:split_len]
-        new_node.loading = child.loading
         new_node.hit_count = child.hit_count

         # split value and host value if exists
python/sglang/srt/mem_cache/memory_pool.py

@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

+from __future__ import annotations
+
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter

 """

@@ -27,7 +29,7 @@ KVCache actually holds the physical kv cache.
 import abc
 import logging
 from contextlib import nullcontext
-from typing import Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

 import numpy as np
 import torch

@@ -38,6 +40,9 @@ from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2

+if TYPE_CHECKING:
+    from sglang.srt.managers.cache_controller import LayerDoneCounter
+
 logger = logging.getLogger(__name__)

 GB = 1024 * 1024 * 1024

@@ -175,7 +180,7 @@ class KVCache(abc.ABC):
     ) -> None:
         raise NotImplementedError()

-    def register_layer_transfer_counter(self, layer_transfer_counter):
+    def register_layer_transfer_counter(self, layer_transfer_counter: LayerDoneCounter):
         self.layer_transfer_counter = layer_transfer_counter

     def get_cpu_copy(self, indices):
python/sglang/srt/mem_cache/memory_pool_host.py

@@ -3,6 +3,7 @@ import logging
 import threading
 from enum import IntEnum
 from functools import wraps
+from typing import Optional

 import psutil
 import torch

@@ -169,7 +170,7 @@ class HostKVCache(abc.ABC):
         return len(self.free_slots)

     @synchronized()
-    def alloc(self, need_size: int) -> torch.Tensor:
+    def alloc(self, need_size: int) -> Optional[torch.Tensor]:
         assert (
             need_size % self.page_size == 0
         ), "The requested size should be a multiple of the page size."
python/sglang/srt/mem_cache/radix_cache.py

@@ -53,8 +53,6 @@ class TreeNode:
         self.last_access_time = time.monotonic()

         self.hit_count = 0
-        # indicating the node is loading KV cache from host
-        self.loading = False
         # indicating the node is locked to protect from eviction
         # incremented when the node is referenced by a storage operation
         self.host_ref_counter = 0
python/sglang/srt/mem_cache/swa_radix_cache.py

@@ -60,8 +60,6 @@ class TreeNode:
         self.last_access_time = time.monotonic()

         self.hit_count = 0
-        # indicating the node is loading KV cache from host
-        self.loading = False
         # store the host indices of KV cache
         self.host_value = None