Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
10b544ae
Unverified
Commit
10b544ae
authored
Mar 12, 2025
by
Zhiqiang Xie
Committed by
GitHub
Mar 12, 2025
Browse files
Hierarchical Caching Refactoring and Fixing TP issue (#4082)
parent
01090e8a
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
194 additions
and
56 deletions
+194
-56
python/sglang/srt/managers/cache_controller.py
python/sglang/srt/managers/cache_controller.py
+63
-2
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+16
-4
python/sglang/srt/managers/schedule_policy.py
python/sglang/srt/managers/schedule_policy.py
+30
-5
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+19
-28
python/sglang/srt/mem_cache/hiradix_cache.py
python/sglang/srt/mem_cache/hiradix_cache.py
+45
-17
python/sglang/srt/mem_cache/memory_pool.py
python/sglang/srt/mem_cache/memory_pool.py
+21
-0
No files found.
python/sglang/srt/managers/cache_controller.py
View file @
10b544ae
...
...
@@ -30,6 +30,26 @@ from sglang.srt.mem_cache.memory_pool import (
logger
=
logging
.
getLogger
(
__name__
)
class LayerDoneCounter:
    """Thread-safe count of finished layers, guarded by a condition variable.

    The count starts at ``num_layers``, is zeroed by :meth:`reset`, and grows
    by one on every :meth:`increment`. :meth:`wait_until` blocks the calling
    thread while the count has not yet risen above a given threshold.
    """

    def __init__(self, num_layers):
        self.condition = threading.Condition()
        # Start at num_layers so that wait_until() for any smaller threshold
        # returns immediately until the first reset().
        self.counter = num_layers

    def increment(self):
        """Add one to the counter and wake every thread waiting on it."""
        cond = self.condition
        with cond:
            self.counter = self.counter + 1
            cond.notify_all()

    def wait_until(self, threshold):
        """Block until the counter is strictly greater than ``threshold``."""
        with self.condition:
            self.condition.wait_for(lambda: self.counter > threshold)

    def reset(self):
        """Set the counter back to zero; waiting threads are not notified."""
        with self.condition:
            self.counter = 0
class
CacheOperation
:
counter
=
0
...
...
@@ -132,6 +152,7 @@ class HiCacheController:
self
,
token_to_kv_pool_allocator
:
TokenToKVPoolAllocator
,
mem_pool_host
:
MHATokenToKVPoolHost
,
load_cache_event
:
threading
.
Event
=
None
,
write_policy
:
str
=
"write_through_selective"
,
):
self
.
mem_pool_device_allocator
=
token_to_kv_pool_allocator
...
...
@@ -139,6 +160,10 @@ class HiCacheController:
self
.
mem_pool_host
=
mem_pool_host
self
.
write_policy
=
write_policy
self
.
load_cache_event
=
load_cache_event
self
.
layer_done_counter
=
LayerDoneCounter
(
self
.
mem_pool_device
.
layer_num
)
self
.
mem_pool_device
.
register_layer_transfer_counter
(
self
.
layer_done_counter
)
if
write_policy
not
in
[
"write_through"
,
"write_through_selective"
,
...
...
@@ -165,7 +190,7 @@ class HiCacheController:
target
=
self
.
write_thread_func_buffer
,
daemon
=
True
)
self
.
load_thread
=
threading
.
Thread
(
target
=
self
.
load_thread_func_
buff
er
,
daemon
=
True
target
=
self
.
load_thread_func_
layer_by_lay
er
,
daemon
=
True
)
self
.
write_thread
.
start
()
self
.
load_thread
.
start
()
...
...
@@ -186,7 +211,7 @@ class HiCacheController:
target
=
self
.
write_thread_func_buffer
,
daemon
=
True
)
self
.
load_thread
=
threading
.
Thread
(
target
=
self
.
load_thread_func_
buff
er
,
daemon
=
True
target
=
self
.
load_thread_func_
layer_by_lay
er
,
daemon
=
True
)
self
.
stop_event
.
clear
()
self
.
write_thread
.
start
()
...
...
@@ -273,6 +298,42 @@ class HiCacheController:
except
Exception
as
e
:
logger
.
error
(
e
)
def load_thread_func_layer_by_layer(self):
    """
    Load KV caches from host memory to device memory layer by layer.

    Runs until ``self.stop_event`` is set, with ``self.load_stream`` as the
    current CUDA stream. Each round waits for ``self.load_cache_event``,
    merges every operation queued in ``self.load_queue`` into one batch,
    copies that batch one layer at a time while advancing
    ``self.layer_done_counter``, and finally acknowledges the loaded nodes
    on ``self.ack_load_queue``.
    """
    with torch.cuda.stream(self.load_stream):
        while not self.stop_event.is_set():
            # Wake up at least once per second so stop_event is re-checked.
            self.load_cache_event.wait(timeout=1)
            if not self.load_cache_event.is_set():
                continue
            self.load_cache_event.clear()

            # Drain the queue, folding all pending operations into one.
            batch_operation = None
            while self.load_queue.qsize() > 0:
                op = self.load_queue.get(block=True)
                if batch_operation is None:
                    batch_operation = op
                else:
                    batch_operation.merge(op)
            if batch_operation is None:
                continue

            # Restart the per-layer progress count before copying; it is
            # incremented once after each layer has been transferred.
            # NOTE(review): if a transfer below raises, the counter stops
            # advancing and threads blocked in wait_until() are never woken
            # -- confirm whether error handling is needed here.
            self.layer_done_counter.reset()
            for i in range(self.mem_pool_host.layer_num):
                flat_data = self.mem_pool_host.get_flat_data_by_layer(
                    batch_operation.host_indices, i
                )
                self.mem_pool_device.transfer_per_layer(
                    batch_operation.device_indices, flat_data, i
                )
                self.layer_done_counter.increment()

            self.mem_pool_host.complete_io(batch_operation.host_indices)
            for node_id in batch_operation.node_ids:
                # Node id 0 is not acknowledged -- presumably a placeholder
                # id; TODO confirm against the code that enqueues loads.
                if node_id != 0:
                    self.ack_load_queue.put(node_id)
def
write_aux_func
(
self
,
no_wait
=
False
):
"""
Auxiliary function to prepare the buffer for write operations.
...
...
python/sglang/srt/managers/schedule_batch.py
View file @
10b544ae
...
...
@@ -315,6 +315,7 @@ class Req:
# The relative logprob_start_len in an extend batch
self
.
extend_logprob_start_len
=
0
self
.
last_node
=
None
self
.
last_node_global
=
None
# Whether or not the request is chunked. It is incremented whenever
# it is chunked, and decremented whenever a chunked request is
...
...
@@ -389,13 +390,24 @@ class Req:
# Whether request reached finished condition
return
self
.
finished_reason
is
not
None
def
init_next_round_input
(
self
,
tree_cache
:
Optional
[
BasePrefixCache
]
=
None
):
def
init_next_round_input
(
self
,
tree_cache
:
Optional
[
BasePrefixCache
]
=
None
,
enable_hierarchical_cache
=
False
,
):
self
.
fill_ids
=
self
.
origin_input_ids
+
self
.
output_ids
if
tree_cache
is
not
None
:
# tree cache is None if the prefix is not computed with tree cache.
self
.
prefix_indices
,
self
.
last_node
=
tree_cache
.
match_prefix
(
rid
=
self
.
rid
,
key
=
self
.
adjust_max_prefix_ids
()
)
if
enable_hierarchical_cache
:
self
.
prefix_indices
,
self
.
last_node
,
self
.
last_node_global
=
(
tree_cache
.
match_prefix
(
key
=
self
.
adjust_max_prefix_ids
(),
include_evicted
=
True
)
)
else
:
self
.
prefix_indices
,
self
.
last_node
=
tree_cache
.
match_prefix
(
rid
=
self
.
rid
,
key
=
self
.
adjust_max_prefix_ids
()
)
self
.
extend_input_len
=
len
(
self
.
fill_ids
)
-
len
(
self
.
prefix_indices
)
def
adjust_max_prefix_ids
(
self
):
...
...
python/sglang/srt/managers/schedule_policy.py
View file @
10b544ae
...
...
@@ -73,9 +73,15 @@ class CacheAgnosticPolicy(Enum):
class
SchedulePolicy
:
Policy
=
Union
[
CacheAwarePolicy
,
CacheAgnosticPolicy
]
def
__init__
(
self
,
policy
:
str
,
tree_cache
:
BasePrefixCache
):
def
__init__
(
self
,
policy
:
str
,
tree_cache
:
BasePrefixCache
,
enable_hierarchical_cache
:
bool
=
False
,
):
self
.
policy
=
self
.
_validate_and_adjust_policy
(
policy
,
tree_cache
)
self
.
tree_cache
=
tree_cache
self
.
enable_hierarchical_cache
=
enable_hierarchical_cache
# It is used to find the matching prefix for in-batch prefix caching.
self
.
waiting_queue_radix_tree
=
RadixCache
(
...
...
@@ -149,9 +155,14 @@ class SchedulePolicy:
prefix_ids
=
r
.
adjust_max_prefix_ids
()
# NOTE: the prefix_indices must always be aligned with last_node
r
.
prefix_indices
,
r
.
last_node
=
self
.
tree_cache
.
match_prefix
(
rid
=
r
.
rid
,
key
=
prefix_ids
)
if
self
.
enable_hierarchical_cache
:
r
.
prefix_indices
,
r
.
last_node
,
r
.
last_node_global
=
(
self
.
tree_cache
.
match_prefix
(
key
=
prefix_ids
,
include_evicted
=
True
)
)
else
:
r
.
prefix_indices
,
r
.
last_node
=
self
.
tree_cache
.
match_prefix
(
rid
=
r
.
rid
,
key
=
prefix_ids
)
# NOTE(sang): This logic is for in-batch prefix caching;
# If there are more than 1 request that have small matching prefix from
...
...
@@ -428,7 +439,9 @@ class PrefillAdder:
return
self
.
budget_state
()
def
add_one_req
(
self
,
req
:
Req
,
has_chunked_req
:
bool
):
def
add_one_req
(
self
,
req
:
Req
,
has_chunked_req
:
bool
,
enable_hierarchical_cache
:
bool
=
False
):
if
req
.
sampling_params
.
ignore_eos
and
self
.
tree_cache
.
disable
:
return
self
.
add_one_req_ignore_eos
(
req
,
has_chunked_req
)
...
...
@@ -448,6 +461,18 @@ class PrefillAdder:
if
total_tokens
>
self
.
rem_total_tokens
:
return
AddReqResult
.
NO_TOKEN
if
(
enable_hierarchical_cache
and
req
.
last_node_global
is
not
None
and
req
.
last_node_global
.
evicted
):
req
.
last_node
,
req
.
prefix_indices
=
self
.
tree_cache
.
init_load_back
(
req
.
last_node_global
,
req
.
prefix_indices
)
req
.
extend_input_len
=
len
(
req
.
fill_ids
)
-
len
(
req
.
prefix_indices
)
input_tokens
=
req
.
extend_input_len
prefix_len
=
len
(
req
.
prefix_indices
)
if
self
.
rem_chunk_tokens
is
None
or
input_tokens
<=
self
.
rem_chunk_tokens
:
# Non-chunked prefill
self
.
can_run_list
.
append
(
req
)
...
...
python/sglang/srt/managers/scheduler.py
View file @
10b544ae
...
...
@@ -265,12 +265,10 @@ class Scheduler:
f
"context_len=
{
self
.
model_config
.
context_len
}
"
)
# Init memory pool and cache
self
.
init_memory_pool_and_cache
()
# Init running status
self
.
waiting_queue
:
List
[
Req
]
=
[]
self
.
staging_reqs
=
{}
# The running decoding batch for continuous batching
self
.
running_batch
:
Optional
[
ScheduleBatch
]
=
None
# The current forward batch
...
...
@@ -308,7 +306,9 @@ class Scheduler:
self
.
grammar_backend
=
None
# Init schedule policy and new token estimation
self
.
policy
=
SchedulePolicy
(
self
.
schedule_policy
,
self
.
tree_cache
)
self
.
policy
=
SchedulePolicy
(
self
.
schedule_policy
,
self
.
tree_cache
,
self
.
enable_hierarchical_cache
)
assert
(
server_args
.
schedule_conservativeness
>=
0
),
"Invalid schedule_conservativeness"
...
...
@@ -431,6 +431,7 @@ class Scheduler:
self
.
tree_cache
=
HiRadixCache
(
req_to_token_pool
=
self
.
req_to_token_pool
,
token_to_kv_pool_allocator
=
self
.
token_to_kv_pool_allocator
,
tp_cache_group
=
self
.
tp_worker
.
get_tp_cpu_group
(),
)
else
:
self
.
tree_cache
=
RadixCache
(
...
...
@@ -1005,6 +1006,11 @@ class Scheduler:
self
.
batch_is_full
=
True
return
None
if
self
.
enable_hierarchical_cache
:
# check for completion of hierarchical cache activities to release memory
self
.
tree_cache
.
writing_check
()
self
.
tree_cache
.
loading_check
()
# Get priority queue
prefix_computed
=
self
.
policy
.
calc_priority
(
self
.
waiting_queue
)
...
...
@@ -1048,32 +1054,14 @@ class Scheduler:
self
.
batch_is_full
=
True
break
req
.
init_next_round_input
(
None
if
prefix_computed
else
self
.
tree_cache
)
req
.
init_next_round_input
(
None
if
prefix_computed
else
self
.
tree_cache
,
self
.
enable_hierarchical_cache
,
)
if
self
.
enable_hierarchical_cache
and
req
.
last_node
is
not
None
:
if
req
.
last_node
.
evicted
:
# loading KV cache for the request
req
.
last_node
,
req
.
prefix_indices
=
self
.
tree_cache
.
init_load_back
(
req
.
last_node
,
req
.
prefix_indices
,
adder
.
rem_total_tokens
,
)
if
req
.
last_node
.
loading
:
# to prevent frequent cache invalidation
if
req
.
rid
in
self
.
staging_reqs
:
self
.
tree_cache
.
dec_lock_ref
(
self
.
staging_reqs
[
req
.
rid
])
self
.
tree_cache
.
inc_lock_ref
(
req
.
last_node
)
self
.
staging_reqs
[
req
.
rid
]
=
req
.
last_node
continue
elif
req
.
last_node
.
loading
:
if
not
self
.
tree_cache
.
loading_complete
(
req
.
last_node
):
continue
if
req
.
rid
in
self
.
staging_reqs
:
self
.
tree_cache
.
dec_lock_ref
(
self
.
staging_reqs
[
req
.
rid
])
del
self
.
staging_reqs
[
req
.
rid
]
res
=
adder
.
add_one_req
(
req
,
self
.
chunked_req
)
res
=
adder
.
add_one_req
(
req
,
self
.
chunked_req
,
self
.
enable_hierarchical_cache
)
if
res
!=
AddReqResult
.
CONTINUE
:
if
res
==
AddReqResult
.
NO_TOKEN
:
if
self
.
enable_hierarchical_cache
:
...
...
@@ -1094,6 +1082,9 @@ class Scheduler:
x
for
x
in
self
.
waiting_queue
if
x
not
in
set
(
can_run_list
)
]
if
self
.
enable_hierarchical_cache
:
self
.
tree_cache
.
read_to_load_cache
()
if
adder
.
new_chunked_req
is
not
None
:
assert
self
.
chunked_req
is
None
self
.
chunked_req
=
adder
.
new_chunked_req
...
...
python/sglang/srt/mem_cache/hiradix_cache.py
View file @
10b544ae
import
heapq
import
logging
import
threading
import
time
from
typing
import
List
,
Optional
import
torch
from
sglang.srt.managers.cache_controller
import
HiCacheController
from
sglang.srt.managers.schedule_batch
import
Req
from
sglang.srt.mem_cache.memory_pool
import
(
MHATokenToKVPoolHost
,
ReqToTokenPool
,
...
...
@@ -22,12 +24,18 @@ class HiRadixCache(RadixCache):
self
,
req_to_token_pool
:
ReqToTokenPool
,
token_to_kv_pool_allocator
:
TokenToKVPoolAllocator
,
tp_cache_group
:
torch
.
distributed
.
ProcessGroup
,
):
self
.
token_to_kv_pool_host
=
MHATokenToKVPoolHost
(
token_to_kv_pool_allocator
.
get_kvcache
()
)
self
.
tp_group
=
tp_cache_group
self
.
load_cache_event
=
threading
.
Event
()
self
.
cache_controller
=
HiCacheController
(
token_to_kv_pool_allocator
,
self
.
token_to_kv_pool_host
token_to_kv_pool_allocator
,
self
.
token_to_kv_pool_host
,
load_cache_event
=
self
.
load_cache_event
,
)
# record the nodes with ongoing write through
...
...
@@ -55,14 +63,12 @@ class HiRadixCache(RadixCache):
def
write_backup
(
self
,
node
:
TreeNode
):
host_indices
=
self
.
cache_controller
.
write
(
device_indices
=
node
.
value
,
priority
=-
self
.
get_height
(
node
),
node_id
=
node
.
id
,
)
if
host_indices
is
None
:
self
.
evict_host
(
len
(
node
.
value
))
host_indices
=
self
.
cache_controller
.
write
(
device_indices
=
node
.
value
,
priority
=-
self
.
get_height
(
node
),
node_id
=
node
.
id
,
)
if
host_indices
is
not
None
:
...
...
@@ -83,14 +89,20 @@ class HiRadixCache(RadixCache):
node
.
hit_count
=
0
def
writing_check
(
self
):
while
not
self
.
cache_controller
.
ack_write_queue
.
empty
():
try
:
ack_id
=
self
.
cache_controller
.
ack_write_queue
.
get_nowait
()
self
.
dec_lock_ref
(
self
.
ongoing_write_through
[
ack_id
])
# clear the reference
del
self
.
ongoing_write_through
[
ack_id
]
except
Exception
:
break
queue_size
=
torch
.
tensor
(
self
.
cache_controller
.
ack_write_queue
.
qsize
(),
dtype
=
torch
.
int
)
if
torch
.
distributed
.
get_world_size
(
group
=
self
.
tp_group
)
>
1
:
# synchronize TP workers so that they all make the same update to the radix cache
torch
.
distributed
.
all_reduce
(
queue_size
,
op
=
torch
.
distributed
.
ReduceOp
.
MIN
,
group
=
self
.
tp_group
,
)
for
_
in
range
(
queue_size
.
item
()):
ack_id
=
self
.
cache_controller
.
ack_write_queue
.
get
()
self
.
dec_lock_ref
(
self
.
ongoing_write_through
[
ack_id
])
del
self
.
ongoing_write_through
[
ack_id
]
def
loading_check
(
self
):
while
not
self
.
cache_controller
.
ack_load_queue
.
empty
():
...
...
@@ -108,8 +120,6 @@ class HiRadixCache(RadixCache):
break
def
evictable_size
(
self
):
self
.
writing_check
()
self
.
loading_check
()
return
self
.
evictable_size_
def
evict
(
self
,
num_tokens
:
int
,
evict_callback
=
None
):
...
...
@@ -242,10 +252,6 @@ class HiRadixCache(RadixCache):
return
device_indices
def
loading_complete
(
self
,
node
:
TreeNode
):
self
.
loading_check
()
return
node
.
loading
==
False
def
init_load_back
(
self
,
last_node
:
TreeNode
,
...
...
@@ -272,6 +278,28 @@ class HiRadixCache(RadixCache):
return
last_node
,
prefix_indices
def read_to_load_cache(self):
    """Set ``self.load_cache_event``.

    The same event object is handed to ``HiCacheController`` as its
    ``load_cache_event`` when this cache is constructed.
    """
    self.load_cache_event.set()
def match_prefix(self, key: List[int], include_evicted=False, **kwargs):
    """Find the cached prefix of ``key`` in the radix tree.

    Args:
        key: Token ids to match against the tree.
        include_evicted: When true, additionally return the deepest matched
            node even if it is marked as evicted.
        **kwargs: Accepted and ignored, so callers may pass extra keyword
            arguments such as ``rid``.

    Returns:
        ``(value, last_node)`` when ``include_evicted`` is false, otherwise
        ``(value, last_node, last_node_global)``. ``value`` is the
        concatenation of the tensors returned by ``_match_prefix_helper``
        (an empty int32 tensor when there are none), ``last_node`` is the
        closest ancestor of the match that is not evicted, and
        ``last_node_global`` is the matched node itself. When the cache is
        disabled, ``value`` is ``[]`` and every node is ``self.root_node``.
    """
    if self.disable:
        # Keep the tuple arity in step with the enabled path: callers that
        # pass include_evicted=True unpack three values.
        if include_evicted:
            return [], self.root_node, self.root_node
        return [], self.root_node

    value, last_node = self._match_prefix_helper(self.root_node, key)
    if value:
        value = torch.concat(value)
    else:
        value = torch.tensor([], dtype=torch.int32)

    # Remember the full match, then walk up to the first non-evicted node.
    last_node_global = last_node
    while last_node.evicted:
        last_node = last_node.parent

    if include_evicted:
        return value, last_node, last_node_global
    return value, last_node
def
_match_prefix_helper
(
self
,
node
:
TreeNode
,
key
:
List
):
node
.
last_access_time
=
time
.
time
()
value
=
[]
...
...
python/sglang/srt/mem_cache/memory_pool.py
View file @
10b544ae
...
...
@@ -206,6 +206,8 @@ class MHATokenToKVPool(KVCache):
self
.
layer_num
=
layer_num
self
.
_create_buffers
()
self
.
layer_transfer_counter
=
None
k_size
,
v_size
=
self
.
get_kv_size_bytes
()
logger
.
info
(
f
"KV Cache is allocated. #tokens:
{
size
}
, K size:
{
k_size
/
GB
:.
2
f
}
GB, V size:
{
v_size
/
GB
:.
2
f
}
GB"
...
...
@@ -267,12 +269,28 @@ class MHATokenToKVPool(KVCache):
self
.
k_buffer
[
i
][
indices
]
=
k_data
[
i
]
self
.
v_buffer
[
i
][
indices
]
=
v_data
[
i
]
def register_layer_transfer_counter(self, layer_transfer_counter):
    """Store the counter consulted before a layer's buffers are returned.

    Once set, ``get_key_buffer`` and ``get_value_buffer`` call
    ``layer_transfer_counter.wait_until(layer_id)`` before returning.
    """
    self.layer_transfer_counter = layer_transfer_counter
def transfer_per_layer(self, indices, flat_data, layer_id):
    """Write one layer's prepared K/V data into the device buffers.

    ``flat_data`` is moved to ``self.device`` with a blocking copy; its
    element 0 is stored into ``self.k_buffer[layer_id]`` and its element 1
    into ``self.v_buffer[layer_id]``, both at positions ``indices``.
    """
    # Host-to-device transfer of the prepared data.
    on_device = flat_data.to(device=self.device, non_blocking=False)
    self.k_buffer[layer_id][indices] = on_device[0]
    self.v_buffer[layer_id][indices] = on_device[1]
def get_key_buffer(self, layer_id: int):
    """Return the key buffer of ``layer_id``.

    If a layer transfer counter is registered, first block in its
    ``wait_until(layer_id)``. The buffer is reinterpreted as ``self.dtype``
    when that differs from ``self.store_dtype``.
    """
    counter = self.layer_transfer_counter
    if counter is not None:
        counter.wait_until(layer_id)

    keys = self.k_buffer[layer_id]
    return keys if self.store_dtype == self.dtype else keys.view(self.dtype)
def get_value_buffer(self, layer_id: int):
    """Return the value buffer of ``layer_id``.

    If a layer transfer counter is registered, first block in its
    ``wait_until(layer_id)``. The buffer is reinterpreted as ``self.dtype``
    when that differs from ``self.store_dtype``.
    """
    counter = self.layer_transfer_counter
    if counter is not None:
        counter.wait_until(layer_id)

    values = self.v_buffer[layer_id]
    return values if self.store_dtype == self.dtype else values.view(self.dtype)
...
...
@@ -530,6 +548,9 @@ class MHATokenToKVPoolHost:
def get_flat_data(self, indices):
    """Return ``self.kv_buffer[:, :, indices]``.

    Selects ``indices`` along the third axis while keeping the first two
    axes whole.
    """
    # NOTE(review): presumably axis 0 is K/V and axis 1 is the layer --
    # confirm against where kv_buffer is allocated.
    return self.kv_buffer[:, :, indices]
def get_flat_data_by_layer(self, indices, layer_id):
    """Return ``self.kv_buffer[:, layer_id, indices]``.

    Like ``get_flat_data`` but restricted to one entry of the second axis.
    """
    # NOTE(review): presumably axis 0 is K/V and axis 1 is the layer --
    # confirm against where kv_buffer is allocated.
    return self.kv_buffer[:, layer_id, indices]
def
assign_flat_data
(
self
,
indices
,
flat_data
):
self
.
kv_buffer
[:,
:,
indices
]
=
flat_data
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment