change / sglang · Commits · 948b01a0

Commit 948b01a0 (Unverified), authored Sep 08, 2025 by DarkSharpness, committed by GitHub on Sep 08, 2025

[Refactor] Remove Hicache Load & Write threads (#10127)

Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
Parent: cdc56ef6
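The diff below removes the controller's dedicated write/load worker threads: write() and load() now only append CacheOperation objects to plain lists, the new start_writing() / start_loading() methods launch the copies on the existing CUDA write/load streams, and completion is tracked with CUDA events (HiCacheAck, LayerLoadingEvent) that callers poll or synchronize on. A minimal, self-contained sketch of that event-based hand-off, assuming a CUDA device is available (the dummy tensor copy only stands in for the real KV-cache transfer kernels):

    import torch

    # Sketch only: one side stream plus start/finish events, as in start_writing().
    side_stream = torch.cuda.Stream()
    start_event, finish_event = torch.cuda.Event(), torch.cuda.Event()

    src = torch.randn(1 << 20, device="cuda")
    dst = torch.empty_like(src)

    start_event.record()                   # mark the inputs ready on the current stream
    with torch.cuda.stream(side_stream):
        start_event.wait(side_stream)      # side stream waits for the inputs
        dst.copy_(src, non_blocking=True)  # stand-in for backup/load kernels
        finish_event.record()

    # Consumers poll instead of joining a thread, and only block when they must.
    if not finish_event.query():
        finish_event.synchronize()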
Changes: 10 changed files with 217 additions and 206 deletions (+217 -206)

  python/sglang/srt/managers/cache_controller.py           +140  -148
  python/sglang/srt/managers/schedule_batch.py                +2    -2
  python/sglang/srt/managers/scheduler.py                     +0    -4
  python/sglang/srt/managers/tp_worker.py                    +10    -3
  python/sglang/srt/managers/tp_worker_overlap_thread.py      +8   -10
  python/sglang/srt/mem_cache/hiradix_cache.py               +48   -32
  python/sglang/srt/mem_cache/memory_pool.py                  +7    -2
  python/sglang/srt/mem_cache/memory_pool_host.py             +2    -1
  python/sglang/srt/mem_cache/radix_cache.py                  +0    -2
  python/sglang/srt/mem_cache/swa_radix_cache.py              +0    -2
python/sglang/srt/managers/cache_controller.py  (view file @ 948b01a0)

@@ -18,7 +18,7 @@ import math
 import threading
 import time
 from queue import Empty, Full, PriorityQueue, Queue
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, List, NamedTuple, Optional, Set, Tuple

 import torch

@@ -43,39 +43,53 @@ from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPool
 logger = logging.getLogger(__name__)


+class LayerLoadingEvent:
+    def __init__(self, num_layers: int):
+        self._num_layers = num_layers
+        self.load_events = [torch.cuda.Event() for _ in range(num_layers)]
+        self.start_event = torch.cuda.Event()  # start event on controller stream
+
+    def complete(self, layer_index: int):
+        assert 0 <= layer_index < self._num_layers
+        self.load_events[layer_index].record()
+
+    def wait(self, layer_index: int):
+        torch.cuda.current_stream().wait_event(self.load_events[layer_index])
+
+    @property
+    def finish_event(self):
+        return self.load_events[-1]
+
+
 class LayerDoneCounter:
-    def __init__(self, num_layers):
+    def __init__(self, num_layers: int):
         self.num_layers = num_layers
         # extra producer and consumer counters for overlap mode
         self.num_counters = 3
-        self.counters = [num_layers] * self.num_counters
-        self.conditions = [threading.Condition() for _ in range(self.num_counters)]
-        self.producer_index = -1
-        self.consumer_index = -1
-
-    def next_producer(self):
-        return (self.producer_index + 1) % self.num_counters
+        self.events = [LayerLoadingEvent(num_layers) for _ in range(self.num_counters)]
+        self.producer_index = 0
+        self.consumer_index = 0

     def update_producer(self):
-        self.producer_index = self.next_producer()
+        self.producer_index = (self.producer_index + 1) % self.num_counters
+        assert self.events[self.producer_index].finish_event.query(), (
+            "Producer finish event should be ready before being reused."
+        )
         return self.producer_index

-    def set_consumer(self, index):
+    def set_consumer(self, index: int):
         self.consumer_index = index

-    def increment(self):
-        with self.conditions[self.producer_index]:
-            self.counters[self.producer_index] += 1
-            self.conditions[self.producer_index].notify_all()
-
-    def wait_until(self, threshold):
-        with self.conditions[self.consumer_index]:
-            while self.counters[self.consumer_index] <= threshold:
-                self.conditions[self.consumer_index].wait()
+    def wait_until(self, threshold: int):
+        if self.consumer_index < 0:
+            return
+        self.events[self.consumer_index].wait(threshold)

     def reset(self):
-        with self.conditions[self.producer_index]:
-            self.counters[self.producer_index] = 0
+        self.producer_index = -1
+        self.consumer_index = -1


 class CacheOperation:

@@ -99,36 +113,30 @@ class CacheOperation:
         # default priority is the order of creation
         self.priority = priority if priority is not None else self.id

-    def merge(self, other: "CacheOperation") -> None:
-        self.host_indices = torch.cat([self.host_indices, other.host_indices])
-        self.device_indices = torch.cat([self.device_indices, other.device_indices])
-        self.priority = min(self.priority, other.priority)
-        self.node_ids.extend(other.node_ids)
-
-    def split(self, factor) -> List["CacheOperation"]:
-        # split an operation into smaller operations to reduce the size of intermediate buffers
-        if factor <= 1:
-            return [self]
-
-        chunk_size = math.ceil(len(self.host_indices) / factor)
-        split_ops = []
-        for i in range(0, len(self.host_indices), chunk_size):
-            split_ops.append(
-                CacheOperation(
-                    host_indices=self.host_indices[i : i + chunk_size],
-                    device_indices=self.device_indices[i : i + chunk_size],
-                    node_id=0,
-                )
-            )
-        # Inherit the node_ids on the final chunk
-        if split_ops:
-            split_ops[-1].node_ids = self.node_ids
-
-        return split_ops
-
-    def __lt__(self, other: "CacheOperation"):
+    @staticmethod
+    # multiple operations can be merged into a single operation for batch processing
+    def merge_ops(ops: List[CacheOperation]) -> CacheOperation:
+        assert len(ops) > 0
+        if len(ops) == 1:
+            return ops[0]
+
+        host_indices = torch.cat([op.host_indices for op in ops])
+        device_indices = torch.cat([op.device_indices for op in ops])
+        node_ids = []
+        priority = min(op.priority for op in ops)
+        for op in ops:
+            node_ids.extend(op.node_ids)
+        merged_op = CacheOperation(host_indices, device_indices, -1, priority)
+        merged_op.node_ids = node_ids
+        return merged_op
+
+    def __lt__(self, other: CacheOperation):
         return self.priority < other.priority


+class HiCacheAck(NamedTuple):
+    start_event: torch.cuda.Event
+    finish_event: torch.cuda.Event
+    node_ids: List[int]
+
+
 class TransferBuffer:

@@ -236,7 +244,7 @@ class HiCacheController:
         mem_pool_host: HostKVCache,
         page_size: int,
         tp_group: torch.distributed.ProcessGroup,
-        load_cache_event: threading.Event = None,
+        load_cache_event: threading.Event,
         write_policy: str = "write_through_selective",
         io_backend: str = "",
         storage_backend: Optional[str] = None,

@@ -340,8 +348,9 @@ class HiCacheController:
             self.page_set_func = self._3fs_zero_copy_page_set
             self.batch_exists_func = self._3fs_zero_copy_batch_exists

-        self.load_cache_event = load_cache_event
-        self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
+        self.device = self.mem_pool_device.device
+        self.layer_num = self.mem_pool_device.layer_num
+        self.layer_done_counter = LayerDoneCounter(self.layer_num)
         self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)

         if write_policy not in [

@@ -351,11 +360,11 @@ class HiCacheController:
         ]:
             raise ValueError(f"Invalid write policy: {write_policy}")

-        self.write_queue = PriorityQueue()
-        self.load_queue = PriorityQueue()
+        self.load_queue: List[CacheOperation] = []
+        self.write_queue: List[CacheOperation] = []

-        self.ack_write_queue = Queue()
-        self.ack_load_queue = Queue()
+        self.ack_load_queue: List[HiCacheAck] = []
+        self.ack_write_queue: List[HiCacheAck] = []

         self.stop_event = threading.Event()
         self.write_buffer = TransferBuffer(self.stop_event)

@@ -366,16 +375,6 @@ class HiCacheController:
         self.write_stream = torch.cuda.Stream()
         self.load_stream = torch.cuda.Stream()

-        self.write_thread = threading.Thread(
-            target=self.write_thread_func_direct, daemon=True
-        )
-        self.load_thread = threading.Thread(
-            target=self.load_thread_func_layer_by_layer, daemon=True
-        )
-        self.write_thread.start()
-        self.load_thread.start()
-
         if self.enable_storage:
             self.prefetch_thread = threading.Thread(
                 target=self.prefetch_thread_func, daemon=True

@@ -432,15 +431,13 @@ class HiCacheController:
     def reset(self):
         self.stop_event.set()
-        self.write_thread.join()
-        self.load_thread.join()

-        self.write_queue.queue.clear()
-        self.load_queue.queue.clear()
+        self.write_queue.clear()
+        self.load_queue.clear()
         self.write_buffer.clear()
         self.load_buffer.clear()
-        self.ack_write_queue.queue.clear()
-        self.ack_load_queue.queue.clear()
+        self.ack_write_queue.clear()
+        self.ack_load_queue.clear()
         if self.enable_storage:
             self.prefetch_thread.join()
             self.backup_thread.join()

@@ -449,15 +446,7 @@ class HiCacheController:
             self.prefetch_revoke_queue.queue.clear()
             self.ack_backup_queue.queue.clear()

-        self.write_thread = threading.Thread(
-            target=self.write_thread_func_direct, daemon=True
-        )
-        self.load_thread = threading.Thread(
-            target=self.load_thread_func_layer_by_layer, daemon=True
-        )
         self.stop_event.clear()
-        self.write_thread.start()
-        self.load_thread.start()

         if self.enable_storage:
             self.prefetch_thread = threading.Thread(

@@ -473,7 +462,7 @@ class HiCacheController:
         self,
         device_indices: torch.Tensor,
         priority: Optional[int] = None,
-        node_id: int = 0,
+        node_id: int = -1,
     ) -> Optional[torch.Tensor]:
         """
         Back up KV caches from device memory to host memory.

@@ -482,17 +471,46 @@ class HiCacheController:
         if host_indices is None:
             return None

         self.mem_pool_host.protect_write(host_indices)
-        torch.cuda.current_stream().synchronize()
-        self.write_queue.put(
+        self.write_queue.append(
             CacheOperation(host_indices, device_indices, node_id, priority)
         )
+        self.start_writing()
         return host_indices

+    def start_writing(self) -> None:
+        if len(self.write_queue) == 0:
+            return
+
+        op = CacheOperation.merge_ops(self.write_queue)
+        host_indices, device_indices = self.move_indices(op)
+        self.write_queue.clear()
+
+        start_event = torch.cuda.Event()
+        finish_event = torch.cuda.Event()
+
+        start_event.record()
+        with torch.cuda.stream(self.write_stream):
+            start_event.wait(self.write_stream)
+            self.mem_pool_host.backup_from_device_all_layer(
+                self.mem_pool_device, host_indices, device_indices, self.io_backend
+            )
+            self.mem_pool_host.complete_io(op.host_indices)
+            finish_event.record()
+
+        # NOTE: We must save the host indices and device indices here,
+        # this is because we need to guarantee that these tensors are
+        # still alive when the write stream is executing.
+        if host_indices.is_cuda:
+            host_indices.record_stream(self.write_stream)
+        if device_indices.is_cuda:
+            device_indices.record_stream(self.write_stream)
+
+        self.ack_write_queue.append(HiCacheAck(start_event, finish_event, op.node_ids))
+
     def load(
         self,
         host_indices: torch.Tensor,
         priority: Optional[int] = None,
-        node_id: int = 0,
+        node_id: int = -1,
     ) -> Optional[torch.Tensor]:
         """
         Load KV caches from host memory to device memory.

@@ -501,17 +519,18 @@ class HiCacheController:
         if device_indices is None:
             return None

         self.mem_pool_host.protect_load(host_indices)
-        self.load_queue.put(
+        # to ensure the device indices are ready before accessed by another CUDA stream
+        torch.cuda.current_stream().synchronize()
+        self.load_queue.append(
             CacheOperation(host_indices, device_indices, node_id, priority)
         )
         return device_indices

-    def move_indices(self, host_indices, device_indices):
+    def move_indices(self, op: CacheOperation):
+        host_indices, device_indices = op.host_indices, op.device_indices
         # move indices to GPU if using kernels, to host if using direct indexing
         if self.io_backend == "kernel":
-            return host_indices.to(self.mem_pool_device.device), device_indices
+            if not host_indices.is_cuda:
+                host_indices = host_indices.to(self.device, non_blocking=True)
+            return host_indices, device_indices
         elif self.io_backend == "direct":
             device_indices = device_indices.cpu()
             host_indices, idx = host_indices.sort()

@@ -519,58 +538,20 @@ class HiCacheController:
         else:
             raise ValueError(f"Unsupported io backend")

-    def write_thread_func_direct(self):
-        """
-        Directly write through KV caches to host memory without buffering.
-        """
-        torch.cuda.set_stream(self.write_stream)
-        while not self.stop_event.is_set():
-            try:
-                operation = self.write_queue.get(block=True, timeout=1)
-                host_indices, device_indices = self.move_indices(
-                    operation.host_indices, operation.device_indices
-                )
-                self.mem_pool_host.backup_from_device_all_layer(
-                    self.mem_pool_device, host_indices, device_indices, self.io_backend
-                )
-                self.write_stream.synchronize()
-                self.mem_pool_host.complete_io(operation.host_indices)
-                for node_id in operation.node_ids:
-                    if node_id != 0:
-                        self.ack_write_queue.put(node_id)
-            except Empty:
-                continue
-            except Exception as e:
-                logger.error(e)
-
-    def load_thread_func_layer_by_layer(self):
-        """
-        Load KV caches from host memory to device memory layer by layer.
-        """
-        torch.cuda.set_stream(self.load_stream)
-        while not self.stop_event.is_set():
-            self.load_cache_event.wait(timeout=1)
-            if not self.load_cache_event.is_set():
-                continue
-            self.load_cache_event.clear()
-            self.layer_done_counter.update_producer()
-
-            batch_operation = None
-            while self.load_queue.qsize() > 0:
-                op = self.load_queue.get(block=True)
-                if batch_operation is None:
-                    batch_operation = op
-                else:
-                    batch_operation.merge(op)
-            if batch_operation is None:
-                continue
-
-            # start layer-wise KV cache transfer from CPU to GPU
-            self.layer_done_counter.reset()
-            host_indices, device_indices = self.move_indices(
-                batch_operation.host_indices, batch_operation.device_indices
-            )
-            for i in range(self.mem_pool_host.layer_num):
-                self.mem_pool_host.load_to_device_per_layer(
-                    self.mem_pool_device,
-                    host_indices,
+    def start_loading(self) -> int:
+        if len(self.load_queue) == 0:
+            return -1
+
+        producer_id = self.layer_done_counter.update_producer()
+        op = CacheOperation.merge_ops(self.load_queue)
+        host_indices, device_indices = self.move_indices(op)
+        self.load_queue.clear()
+
+        producer_event = self.layer_done_counter.events[producer_id]
+        producer_event.start_event.record()
+        with torch.cuda.stream(self.load_stream):
+            producer_event.start_event.wait(self.load_stream)
+            for i in range(self.layer_num):
+                self.mem_pool_host.load_to_device_per_layer(
+                    self.mem_pool_device,
+                    host_indices,

@@ -578,13 +559,24 @@ class HiCacheController:
                     i,
                     self.io_backend,
                 )
-                self.load_stream.synchronize()
-                self.layer_done_counter.increment()
-
-            self.mem_pool_host.complete_io(batch_operation.host_indices)
-            for node_id in batch_operation.node_ids:
-                if node_id != 0:
-                    self.ack_load_queue.put(node_id)
+                producer_event.complete(i)
+            self.mem_pool_host.complete_io(op.host_indices)
+
+        # NOTE: We must save the host indices and device indices here,
+        # this is because we need to guarantee that these tensors are
+        # still alive when the load stream is executing.
+        if host_indices.is_cuda:
+            host_indices.record_stream(self.load_stream)
+        if device_indices.is_cuda:
+            device_indices.record_stream(self.load_stream)
+
+        self.ack_load_queue.append(
+            HiCacheAck(
+                start_event=producer_event.start_event,
+                finish_event=producer_event.finish_event,
+                node_ids=op.node_ids,
+            )
+        )
+        return producer_id

     def evict_device(
         self, device_indices: torch.Tensor, host_indices: torch.Tensor
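In the new start_loading() path, each LayerLoadingEvent records one CUDA event per transferred layer on the load stream, and the consumer side (LayerDoneCounter.wait_until, reached from the attention path through the memory pool's layer transfer counter) makes its own stream wait on the event for the layer it is about to touch. A rough standalone sketch of that per-layer handshake, with made-up sizes and plain tensor copies in place of the real KV pools:

    import torch

    num_layers = 4
    load_stream = torch.cuda.Stream()
    layer_events = [torch.cuda.Event() for _ in range(num_layers)]  # cf. LayerLoadingEvent.load_events

    host_kv = [torch.randn(1 << 16, pin_memory=True) for _ in range(num_layers)]
    device_kv = [torch.empty(1 << 16, device="cuda") for _ in range(num_layers)]

    # Producer: copy layer by layer on the load stream, recording an event per layer,
    # as start_loading() does via producer_event.complete(i).
    with torch.cuda.stream(load_stream):
        for i in range(num_layers):
            device_kv[i].copy_(host_kv[i], non_blocking=True)
            layer_events[i].record()

    # Consumer: before computing layer i, make the compute stream wait on that layer's
    # event (cf. LayerLoadingEvent.wait / LayerDoneCounter.wait_until).
    for i in range(num_layers):
        torch.cuda.current_stream().wait_event(layer_events[i])
        # ... attention for layer i may now read device_kv[i] ...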
python/sglang/srt/managers/schedule_batch.py  (view file @ 948b01a0)

@@ -911,7 +911,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     is_prefill_only: bool = False

     # hicache pointer for synchronizing data loading from CPU to GPU
-    hicache_consumer_index: int = 0
+    hicache_consumer_index: int = -1

     @classmethod
     def init_new(

@@ -1897,7 +1897,7 @@ class ModelWorkerBatch:
     spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
     # If set, the output of the batch contains the hidden states of the run.
     capture_hidden_mode: CaptureHiddenMode = None
-    hicache_consumer_index: int = 0
+    hicache_consumer_index: int = -1

     # Overlap event
     launch_done: Optional[threading.Event] = None
python/sglang/srt/managers/scheduler.py  (view file @ 948b01a0)

@@ -1807,10 +1807,6 @@ class Scheduler(
         if self.spec_algorithm.is_none():
             model_worker_batch = batch.get_model_worker_batch()

-            # update the consumer index of hicache to the running batch
-            self.tp_worker.set_hicache_consumer(
-                model_worker_batch.hicache_consumer_index
-            )
             if self.pp_group.is_last_rank:
                 logits_output, next_token_ids, can_run_cuda_graph = (
                     self.tp_worker.forward_batch_generation(model_worker_batch)
python/sglang/srt/managers/tp_worker.py  (view file @ 948b01a0)

@@ -12,10 +12,11 @@
 # limitations under the License.
 # ==============================================================================
 """A tensor parallel worker."""
+from __future__ import annotations

 import logging
 import threading
-from typing import Optional, Tuple, Union
+from typing import TYPE_CHECKING, Optional, Tuple, Union

 import torch

@@ -45,6 +46,9 @@ from sglang.srt.patch_torch import monkey_patch_torch_reductions
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed

+if TYPE_CHECKING:
+    from sglang.srt.managers.cache_controller import LayerDoneCounter
+
 logger = logging.getLogger(__name__)

@@ -167,10 +171,10 @@ class TpModelWorker:
         self.hicache_layer_transfer_counter = None

-    def register_hicache_layer_transfer_counter(self, counter):
+    def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):
         self.hicache_layer_transfer_counter = counter

-    def set_hicache_consumer(self, consumer_index):
+    def set_hicache_consumer(self, consumer_index: int):
         if self.hicache_layer_transfer_counter is not None:
             self.hicache_layer_transfer_counter.set_consumer(consumer_index)

@@ -230,6 +234,9 @@ class TpModelWorker:
     ) -> Tuple[
         Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor], bool
     ]:
+        # update the consumer index of hicache to the running batch
+        self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
+
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
         pp_proxy_tensors = None
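With the scheduler no longer calling set_hicache_consumer() itself, the consumer index travels on the ModelWorkerBatch (defaulting to -1, see schedule_batch.py above) and is applied by the worker at the top of forward_batch_generation(). A toy sketch of that hand-off; the stub counter only mirrors the behaviour suggested by the diff and is not the real LayerDoneCounter:

    from dataclasses import dataclass

    class StubLayerDoneCounter:
        # Stand-in for LayerDoneCounter: -1 means "no pending hicache load for this batch".
        def __init__(self):
            self.consumer_index = -1

        def set_consumer(self, index: int):
            self.consumer_index = index

        def wait_until(self, layer_index: int):
            if self.consumer_index < 0:
                return  # nothing to wait on
            print(f"wait for layer {layer_index} of slot {self.consumer_index}")

    @dataclass
    class StubBatch:
        hicache_consumer_index: int = -1  # new default in ModelWorkerBatch

    counter = StubLayerDoneCounter()

    def forward_batch_generation(batch: StubBatch):
        counter.set_consumer(batch.hicache_consumer_index)  # the worker applies the index itself
        counter.wait_until(0)

    forward_batch_generation(StubBatch())                          # no hicache load: no waiting
    forward_batch_generation(StubBatch(hicache_consumer_index=2))  # index returned by start_loading()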
python/sglang/srt/managers/tp_worker_overlap_thread.py  (view file @ 948b01a0)

@@ -12,13 +12,14 @@
 # limitations under the License.
 # ==============================================================================
 """A tensor parallel worker."""
+from __future__ import annotations

 import dataclasses
 import logging
 import signal
 import threading
 from queue import Queue
-from typing import Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple

 import psutil
 import torch

@@ -38,6 +39,9 @@ from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import DynamicGradMode, get_compiler_backend
 from sglang.utils import get_exception_traceback

+if TYPE_CHECKING:
+    from sglang.srt.managers.cache_controller import LayerDoneCounter
+
 logger = logging.getLogger(__name__)

@@ -79,7 +83,7 @@ class TpModelWorkerClient:
         )

         # Launch threads
-        self.input_queue = Queue()
+        self.input_queue = Queue[Tuple[ModelWorkerBatch, int, torch.Event]]()
         self.output_queue = Queue()
         self.forward_stream = torch.get_device_module(self.device).Stream()
         self.forward_thread = threading.Thread(

@@ -93,13 +97,9 @@ class TpModelWorkerClient:
         self.hicache_layer_transfer_counter = None

-    def register_hicache_layer_transfer_counter(self, counter):
+    def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):
         self.hicache_layer_transfer_counter = counter

-    def set_hicache_consumer(self, consumer_index):
-        if self.hicache_layer_transfer_counter is not None:
-            self.hicache_layer_transfer_counter.set_consumer(consumer_index)
-
     def get_worker_info(self):
         return self.worker.get_worker_info()

@@ -147,7 +147,7 @@ class TpModelWorkerClient:
     @DynamicGradMode()
     def forward_thread_func_(self):
         batch_pt = 0
-        batch_lists = [None] * 2
+        batch_lists: List = [None] * 2
         while True:
             model_worker_batch, future_token_ids_ct, sync_event = self.input_queue.get()

@@ -169,8 +169,6 @@ class TpModelWorkerClient:
             input_ids = model_worker_batch.input_ids
             resolve_future_token_ids(input_ids, self.future_token_ids_map)

-            # update the consumer index of hicache to the running batch
-            self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
-
             # Run forward
             logits_output, next_token_ids, can_run_cuda_graph = (
                 self.worker.forward_batch_generation(
python/sglang/srt/mem_cache/hiradix_cache.py  (view file @ 948b01a0)

@@ -201,41 +201,57 @@ class HiRadixCache(RadixCache):
         if write_back:
             # blocking till all write back complete
             while len(self.ongoing_write_through) > 0:
-                ack_id = self.cache_controller.ack_write_queue.get()
-                del self.ongoing_write_through[ack_id]
+                for _, finish_event, ack_list in self.cache_controller.ack_write_queue:
+                    finish_event.synchronize()
+                    for ack_id in ack_list:
+                        del self.ongoing_write_through[ack_id]
+                self.cache_controller.ack_write_queue.clear()
+            assert len(self.ongoing_write_through) == 0
             return
-        queue_size = torch.tensor(
-            self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
-        )
+
+        # NOTE: all ranks has the same ongoing_write_through, can skip sync if empty
+        if len(self.ongoing_write_through) == 0:
+            return
+
+        finish_count = 0
+        for _, finish_event, ack_list in self.cache_controller.ack_write_queue:
+            if not finish_event.query():
+                break
+            finish_count += 1
+
+        queue_size = torch.tensor(finish_count, dtype=torch.int, device="cpu")
         if self.tp_world_size > 1:
-            # synchrnoize TP workers to make the same update to radix cache
+            # synchronize TP workers to make the same update to radix cache
             torch.distributed.all_reduce(
                 queue_size,
                 op=torch.distributed.ReduceOp.MIN,
                 group=self.tp_group,
            )
-        for _ in range(queue_size.item()):
-            ack_id = self.cache_controller.ack_write_queue.get()
-            backuped_node = self.ongoing_write_through[ack_id]
-            self.dec_lock_ref(backuped_node)
-            del self.ongoing_write_through[ack_id]
-            if self.enable_storage:
-                self.write_backup_storage(backuped_node)
+
+        finish_count = int(queue_size.item())
+        while finish_count > 0:
+            _, finish_event, ack_list = self.cache_controller.ack_write_queue.pop(0)
+            finish_event.synchronize()
+            for ack_id in ack_list:
+                backuped_node = self.ongoing_write_through.pop(ack_id)
+                self.dec_lock_ref(backuped_node)
+                if self.enable_storage:
+                    self.write_backup_storage(backuped_node)
+            finish_count -= 1

     def loading_check(self):
-        while not self.cache_controller.ack_load_queue.empty():
-            try:
-                ack_id = self.cache_controller.ack_load_queue.get_nowait()
-                start_node, end_node = self.ongoing_load_back[ack_id]
-                self.dec_lock_ref(end_node)
-                while end_node != start_node:
-                    assert end_node.loading
-                    end_node.loading = False
-                    end_node = end_node.parent
-                # clear the reference
-                del self.ongoing_load_back[ack_id]
-            except Exception:
-                break
+        finish_count = 0
+        for _, finish_event, ack_list in self.cache_controller.ack_load_queue:
+            if not finish_event.query():
+                # the KV cache loading is still ongoing
+                break
+            finish_count += 1
+            # no need to sync across TP workers as batch forwarding is synced
+            for ack_id in ack_list:
+                end_node = self.ongoing_load_back.pop(ack_id)
+                self.dec_lock_ref(end_node)
+
+        # ACK until all events are processed
+        del self.cache_controller.ack_load_queue[:finish_count]

     def evictable_size(self):
         return self.evictable_size_

@@ -360,12 +376,11 @@ class HiRadixCache(RadixCache):
             # no sufficient GPU memory to load back KV caches
             return None

-        self.ongoing_load_back[last_hit_node.id] = (ancester_node, last_hit_node)
+        self.ongoing_load_back[last_hit_node.id] = last_hit_node
         offset = 0
         for node in nodes_to_load:
             node.value = device_indices[offset : offset + len(node.host_value)]
             offset += len(node.host_value)
-            node.loading = True
         self.evictable_size_ += len(device_indices)
         self.inc_lock_ref(last_hit_node)

@@ -394,10 +409,12 @@ class HiRadixCache(RadixCache):
             last_node,
         )

-    def ready_to_load_host_cache(self):
-        producer_index = self.cache_controller.layer_done_counter.next_producer()
-        self.load_cache_event.set()
-        return producer_index
+    def ready_to_load_host_cache(self) -> int:
+        """
+        Notify the cache controller to start the KV cache loading.
+        Return the consumer index for the schedule batch manager to track.
+        """
+        return self.cache_controller.start_loading()

     def check_hicache_events(self):
         self.writing_check()

@@ -702,7 +719,6 @@ class HiRadixCache(RadixCache):
         new_node.parent = child.parent
         new_node.lock_ref = child.lock_ref
         new_node.key = child.key[:split_len]
-        new_node.loading = child.loading
         new_node.hit_count = child.hit_count

         # split value and host value if exists
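Because ack_write_queue and ack_load_queue are now plain lists of HiCacheAck entries instead of thread-fed Queues, the radix cache drains them by polling finish_event.query() and consuming only the leading entries whose GPU work has completed. A small standalone sketch of that drain pattern, with dummy GPU work standing in for the real write-back (requires a CUDA device):

    import torch

    side_stream = torch.cuda.Stream()
    ack_queue = []  # entries of (finish_event, ack_ids), cf. HiCacheAck

    for batch_id in range(3):
        finish_event = torch.cuda.Event()
        with torch.cuda.stream(side_stream):
            torch.empty(1 << 20, device="cuda").normal_()  # stand-in for backup kernels
            finish_event.record()
        ack_queue.append((finish_event, [batch_id]))

    # Non-blocking check: count the leading entries whose events have fired ...
    finish_count = 0
    for finish_event, _ in ack_queue:
        if not finish_event.query():
            break
        finish_count += 1

    # ... then consume exactly that prefix, as writing_check()/loading_check() do.
    for finish_event, ack_ids in ack_queue[:finish_count]:
        print("completed:", ack_ids)
    del ack_queue[:finish_count]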
python/sglang/srt/mem_cache/memory_pool.py  (view file @ 948b01a0)

@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

+from __future__ import annotations
+
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter

 """

@@ -27,7 +29,7 @@ KVCache actually holds the physical kv cache.
 import abc
 import logging
 from contextlib import nullcontext
-from typing import Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

 import numpy as np
 import torch

@@ -38,6 +40,9 @@ from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2

+if TYPE_CHECKING:
+    from sglang.srt.managers.cache_controller import LayerDoneCounter
+
 logger = logging.getLogger(__name__)

 GB = 1024 * 1024 * 1024

@@ -175,7 +180,7 @@ class KVCache(abc.ABC):
     ) -> None:
         raise NotImplementedError()

-    def register_layer_transfer_counter(self, layer_transfer_counter):
+    def register_layer_transfer_counter(self, layer_transfer_counter: LayerDoneCounter):
         self.layer_transfer_counter = layer_transfer_counter

     def get_cpu_copy(self, indices):
python/sglang/srt/mem_cache/memory_pool_host.py  (view file @ 948b01a0)

@@ -3,6 +3,7 @@ import logging
 import threading
 from enum import IntEnum
 from functools import wraps
+from typing import Optional

 import psutil
 import torch

@@ -169,7 +170,7 @@ class HostKVCache(abc.ABC):
         return len(self.free_slots)

     @synchronized()
-    def alloc(self, need_size: int) -> torch.Tensor:
+    def alloc(self, need_size: int) -> Optional[torch.Tensor]:
         assert (
             need_size % self.page_size == 0
         ), "The requested size should be a multiple of the page size."
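The Optional return type makes explicit that alloc() may return None when the host pool has no free pages, so callers have to check before using the indices. A toy, CPU-only illustration of that contract (ToyHostPool is hypothetical, not the real HostKVCache):

    from typing import Optional
    import torch

    class ToyHostPool:
        def __init__(self, num_pages: int, page_size: int):
            self.page_size = page_size
            self.free_slots = list(range(num_pages * page_size))

        def alloc(self, need_size: int) -> Optional[torch.Tensor]:
            assert need_size % self.page_size == 0, "need_size must be page aligned"
            if need_size > len(self.free_slots):
                return None  # this is why the annotation is Optional[torch.Tensor]
            out, self.free_slots = self.free_slots[:need_size], self.free_slots[need_size:]
            return torch.tensor(out, dtype=torch.int64)

    pool = ToyHostPool(num_pages=2, page_size=4)
    print(pool.alloc(4))  # tensor of 4 slot indices
    print(pool.alloc(8))  # None: not enough free host slots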
python/sglang/srt/mem_cache/radix_cache.py  (view file @ 948b01a0)

@@ -53,8 +53,6 @@ class TreeNode:
         self.last_access_time = time.monotonic()

         self.hit_count = 0
-        # indicating the node is loading KV cache from host
-        self.loading = False
         # indicating the node is locked to protect from eviction
         # incremented when the node is referenced by a storage operation
         self.host_ref_counter = 0
python/sglang/srt/mem_cache/swa_radix_cache.py  (view file @ 948b01a0)

@@ -60,8 +60,6 @@ class TreeNode:
         self.last_access_time = time.monotonic()

         self.hit_count = 0
-        # indicating the node is loading KV cache from host
-        self.loading = False
         # store the host indices of KV cache
         self.host_value = None