Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
70645f4d
Unverified
Commit
70645f4d
authored
Apr 20, 2025
by
Zhiqiang Xie
Committed by
GitHub
Apr 20, 2025
Browse files
upstream hicache fixes (#5570)
parent
188f0955
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
89 additions
and
46 deletions
+89
-46
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+8
-0
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+2
-0
python/sglang/srt/mem_cache/hiradix_cache.py
python/sglang/srt/mem_cache/hiradix_cache.py
+40
-32
python/sglang/srt/mem_cache/memory_pool.py
python/sglang/srt/mem_cache/memory_pool.py
+15
-12
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+15
-1
test/srt/test_hicache.py
test/srt/test_hicache.py
+4
-0
test/srt/test_hicache_mla.py
test/srt/test_hicache_mla.py
+2
-0
test/srt/test_hicache_page.py
test/srt/test_hicache_page.py
+3
-1
No files found.
python/sglang/srt/managers/schedule_batch.py
View file @
70645f4d
...
...
@@ -571,6 +571,14 @@ class Req:
self
.
prefix_indices
,
self
.
last_node
=
tree_cache
.
match_prefix
(
rid
=
self
.
rid
,
key
=
self
.
adjust_max_prefix_ids
()
)
elif
enable_hierarchical_cache
:
# in case last_node is evicted during scheduling, we need to update the prefix_indices
while
self
.
last_node
.
evicted
:
self
.
prefix_indices
=
self
.
prefix_indices
[
:
-
len
(
self
.
last_node
.
host_value
)
]
self
.
last_node
=
self
.
last_node
.
parent
self
.
extend_input_len
=
len
(
self
.
fill_ids
)
-
len
(
self
.
prefix_indices
)
def
adjust_max_prefix_ids
(
self
):
...
...
python/sglang/srt/managers/scheduler.py
View file @
70645f4d
...
...
@@ -489,6 +489,8 @@ class Scheduler(
tp_cache_group
=
self
.
tp_cpu_group
,
page_size
=
self
.
page_size
,
hicache_ratio
=
server_args
.
hicache_ratio
,
hicache_size
=
server_args
.
hicache_size
,
hicache_write_policy
=
server_args
.
hicache_write_policy
,
)
else
:
self
.
tree_cache
=
RadixCache
(
...
...
python/sglang/srt/mem_cache/hiradix_cache.py
View file @
70645f4d
...
...
@@ -29,15 +29,17 @@ class HiRadixCache(RadixCache):
tp_cache_group
:
torch
.
distributed
.
ProcessGroup
,
page_size
:
int
,
hicache_ratio
:
float
,
hicache_size
:
int
,
hicache_write_policy
:
str
,
):
self
.
kv_cache
=
token_to_kv_pool_allocator
.
get_kvcache
()
if
isinstance
(
self
.
kv_cache
,
MHATokenToKVPool
):
self
.
token_to_kv_pool_host
=
MHATokenToKVPoolHost
(
self
.
kv_cache
,
hicache_ratio
,
page_size
self
.
kv_cache
,
hicache_ratio
,
hicache_size
,
page_size
)
elif
isinstance
(
self
.
kv_cache
,
MLATokenToKVPool
):
self
.
token_to_kv_pool_host
=
MLATokenToKVPoolHost
(
self
.
kv_cache
,
hicache_ratio
,
page_size
self
.
kv_cache
,
hicache_ratio
,
hicache_size
,
page_size
)
else
:
raise
ValueError
(
f
"HiRadixCache only supports MHA and MLA yet"
)
...
...
@@ -50,6 +52,7 @@ class HiRadixCache(RadixCache):
self
.
token_to_kv_pool_host
,
page_size
,
load_cache_event
=
self
.
load_cache_event
,
write_policy
=
hicache_write_policy
,
)
# record the nodes with ongoing write through
...
...
@@ -57,7 +60,9 @@ class HiRadixCache(RadixCache):
# record the node segments with ongoing load back
self
.
ongoing_load_back
=
{}
# todo: dynamically adjust the threshold
self
.
write_through_threshold
=
1
self
.
write_through_threshold
=
(
1
if
hicache_write_policy
==
"write_through"
else
3
)
self
.
load_back_threshold
=
10
super
().
__init__
(
req_to_token_pool
,
token_to_kv_pool_allocator
,
page_size
,
disable
=
False
...
...
@@ -76,7 +81,7 @@ class HiRadixCache(RadixCache):
height
+=
1
return
height
def
write_backup
(
self
,
node
:
TreeNode
):
def
write_backup
(
self
,
node
:
TreeNode
,
write_back
=
False
):
host_indices
=
self
.
cache_controller
.
write
(
device_indices
=
node
.
value
,
node_id
=
node
.
id
,
...
...
@@ -90,21 +95,29 @@ class HiRadixCache(RadixCache):
if
host_indices
is
not
None
:
node
.
host_value
=
host_indices
self
.
ongoing_write_through
[
node
.
id
]
=
node
self
.
inc_lock_ref
(
node
)
if
not
write_back
:
# no need to lock nodes if write back
self
.
inc_lock_ref
(
node
)
else
:
return
0
return
len
(
host_indices
)
def
inc_hit_count
(
self
,
node
:
TreeNode
):
if
self
.
cache_controller
.
write_policy
!
=
"write_
through_selective
"
:
if
node
.
backuped
or
self
.
cache_controller
.
write_policy
=
=
"write_
back
"
:
return
node
.
hit_count
+=
1
if
node
.
host_value
is
None
and
node
.
hit_count
>
self
.
write_through_threshold
:
if
node
.
hit_count
>
=
self
.
write_through_threshold
:
self
.
write_backup
(
node
)
node
.
hit_count
=
0
def
writing_check
(
self
):
def
writing_check
(
self
,
write_back
=
False
):
if
write_back
:
# blocking till all write back complete
while
len
(
self
.
ongoing_write_through
)
>
0
:
ack_id
=
self
.
cache_controller
.
ack_write_queue
.
get
()
del
self
.
ongoing_write_through
[
ack_id
]
return
queue_size
=
torch
.
tensor
(
self
.
cache_controller
.
ack_write_queue
.
qsize
(),
dtype
=
torch
.
int
)
...
...
@@ -143,29 +156,25 @@ class HiRadixCache(RadixCache):
heapq
.
heapify
(
leaves
)
num_evicted
=
0
pending
_nodes
=
[]
write_back
_nodes
=
[]
while
num_evicted
<
num_tokens
and
len
(
leaves
):
x
=
heapq
.
heappop
(
leaves
)
if
x
.
lock_ref
>
0
:
continue
if
x
.
host_value
is
None
:
if
not
x
.
backuped
:
if
self
.
cache_controller
.
write_policy
==
"write_back"
:
num_evicted
+=
self
.
write_backup
(
x
)
pending_nodes
.
append
(
x
)
elif
self
.
cache_controller
.
write_policy
==
"write_through_selective"
:
num_evicted
+=
self
.
_evict_write_through_selective
(
x
)
# write to host if the node is not backuped
num_evicted
+=
self
.
write_backup
(
x
,
write_back
=
True
)
write_back_nodes
.
append
(
x
)
else
:
assert
(
self
.
cache_controller
.
write_policy
!=
"write_through"
),
"write_through should be inclusive"
raise
NotImplementedError
num_evicted
+=
self
.
_evict_regular
(
x
)
else
:
num_evicted
+=
self
.
_evict_
write_through
(
x
)
num_evicted
+=
self
.
_evict_
backuped
(
x
)
for
child
in
x
.
parent
.
children
.
values
():
if
child
in
pending
_nodes
:
if
child
in
write_back
_nodes
:
continue
if
not
child
.
evicted
:
break
...
...
@@ -174,15 +183,12 @@ class HiRadixCache(RadixCache):
heapq
.
heappush
(
leaves
,
x
.
parent
)
if
self
.
cache_controller
.
write_policy
==
"write_back"
:
# blocking till all write back complete
while
len
(
self
.
ongoing_write_through
)
>
0
:
self
.
writing_check
()
time
.
sleep
(
0.1
)
for
node
in
pending_nodes
:
assert
node
.
host_value
is
not
None
self
.
_evict_write_through
(
node
)
self
.
writing_check
(
write_back
=
True
)
for
node
in
write_back_nodes
:
assert
node
.
backuped
self
.
_evict_backuped
(
node
)
def
_evict_
write_through
(
self
,
node
:
TreeNode
):
def
_evict_
backuped
(
self
,
node
:
TreeNode
):
# evict a node already written to host
num_evicted
=
self
.
cache_controller
.
evict_device
(
node
.
value
,
node
.
host_value
)
assert
num_evicted
>
0
...
...
@@ -190,7 +196,7 @@ class HiRadixCache(RadixCache):
node
.
value
=
None
return
num_evicted
def
_evict_
write_through_selective
(
self
,
node
:
TreeNode
):
def
_evict_
regular
(
self
,
node
:
TreeNode
):
# evict a node not initiated write to host
self
.
cache_controller
.
mem_pool_device_allocator
.
free
(
node
.
value
)
num_evicted
=
len
(
node
.
value
)
...
...
@@ -339,11 +345,13 @@ class HiRadixCache(RadixCache):
prefix_len
=
self
.
key_match_fn
(
child
.
key
,
key
)
if
prefix_len
<
len
(
child
.
key
):
new_node
=
self
.
_split_node
(
child
.
key
,
child
,
prefix_len
)
self
.
inc_hit_count
(
new_node
)
if
not
new_node
.
evicted
:
value
.
append
(
new_node
.
value
)
node
=
new_node
break
else
:
self
.
inc_hit_count
(
child
)
if
not
child
.
evicted
:
value
.
append
(
child
.
value
)
node
=
child
...
...
@@ -369,7 +377,7 @@ class HiRadixCache(RadixCache):
else
:
new_node
.
value
=
child
.
value
[:
split_len
]
child
.
value
=
child
.
value
[
split_len
:]
if
child
.
host_value
is
not
None
:
if
child
.
backuped
:
new_node
.
host_value
=
child
.
host_value
[:
split_len
]
child
.
host_value
=
child
.
host_value
[
split_len
:]
child
.
parent
=
new_node
...
...
@@ -426,8 +434,8 @@ class HiRadixCache(RadixCache):
node
.
children
[
child_key
]
=
new_node
self
.
evictable_size_
+=
len
(
value
)
if
self
.
cache_controller
.
write_policy
=
=
"write_
through
"
:
self
.
write_backup
(
new_node
)
if
self
.
cache_controller
.
write_policy
!
=
"write_
back
"
:
self
.
inc_hit_count
(
new_node
)
return
total_prefix_length
def
_collect_leaves_device
(
self
):
...
...
python/sglang/srt/mem_cache/memory_pool.py
View file @
70645f4d
...
...
@@ -624,26 +624,27 @@ class HostKVCache(abc.ABC):
self
,
device_pool
:
MHATokenToKVPool
,
host_to_device_ratio
:
float
,
host_size
:
int
,
pin_memory
:
bool
,
device
:
str
,
page_size
:
int
,
):
assert
(
host_to_device_ratio
>=
1
),
"The host memory should be larger than the device memory with the current protocol"
# todo, other ways of configuring the size
self
.
device_pool
=
device_pool
self
.
host_to_device_ratio
=
host_to_device_ratio
self
.
dtype
=
device_pool
.
store_dtype
self
.
pin_memory
=
pin_memory
self
.
device
=
device
self
.
page_size
=
page_size
self
.
size
=
int
(
device_pool
.
size
*
host_to_device_ratio
)
self
.
size_per_token
=
self
.
get_size_per_token
()
if
host_size
>
0
:
self
.
size
=
int
(
host_size
*
1e9
//
self
.
size_per_token
)
else
:
self
.
size
=
int
(
device_pool
.
size
*
host_to_device_ratio
)
# Align the host memory pool size to the page size
self
.
size
=
self
.
size
-
(
self
.
size
%
self
.
page_size
)
self
.
dtype
=
device_pool
.
store_dtype
self
.
size_per_token
=
self
.
get_size_per_token
()
assert
(
self
.
size
>
device_pool
.
size
),
"The host memory should be larger than the device memory with the current protocol"
# Verify there is enough available host memory.
host_mem
=
psutil
.
virtual_memory
()
...
...
@@ -795,12 +796,13 @@ class MHATokenToKVPoolHost(HostKVCache):
self
,
device_pool
:
MHATokenToKVPool
,
host_to_device_ratio
:
float
,
host_size
:
int
,
page_size
:
int
,
pin_memory
:
bool
=
True
,
device
:
str
=
"cpu"
,
):
super
().
__init__
(
device_pool
,
host_to_device_ratio
,
pin_memory
,
device
,
page_size
device_pool
,
host_to_device_ratio
,
host_size
,
pin_memory
,
device
,
page_size
)
def
get_size_per_token
(
self
):
...
...
@@ -869,12 +871,13 @@ class MLATokenToKVPoolHost(HostKVCache):
self
,
device_pool
:
MLATokenToKVPool
,
host_to_device_ratio
:
float
,
host_size
:
int
,
page_size
:
int
,
pin_memory
:
bool
=
True
,
device
:
str
=
"cpu"
,
):
super
().
__init__
(
device_pool
,
host_to_device_ratio
,
pin_memory
,
device
,
page_size
device_pool
,
host_to_device_ratio
,
host_size
,
pin_memory
,
device
,
page_size
)
def
get_size_per_token
(
self
):
...
...
python/sglang/srt/server_args.py
View file @
70645f4d
...
...
@@ -180,6 +180,8 @@ class ServerArgs:
tool_call_parser
:
Optional
[
str
]
=
None
enable_hierarchical_cache
:
bool
=
False
hicache_ratio
:
float
=
2.0
hicache_size
:
int
=
0
hicache_write_policy
:
str
=
"write_through_selective"
flashinfer_mla_disable_ragged
:
bool
=
False
warmups
:
Optional
[
str
]
=
None
moe_dense_tp_size
:
Optional
[
int
]
=
None
...
...
@@ -1116,10 +1118,22 @@ class ServerArgs:
parser
.
add_argument
(
"--hicache-ratio"
,
type
=
float
,
required
=
False
,
default
=
ServerArgs
.
hicache_ratio
,
help
=
"The ratio of the size of host KV cache memory pool to the size of device pool."
,
)
parser
.
add_argument
(
"--hicache-size"
,
type
=
int
,
default
=
ServerArgs
.
hicache_size
,
help
=
"The size of host KV cache memory pool in gigabytes, which will override the hicache_ratio if set."
,
)
parser
.
add_argument
(
"--hicache-write-policy"
,
type
=
str
,
choices
=
[
"write_back"
,
"write_through"
,
"write_through_selective"
],
default
=
ServerArgs
.
hicache_write_policy
,
help
=
"The write policy of hierarchical cache."
,
)
parser
.
add_argument
(
"--enable-deepep-moe"
,
action
=
"store_true"
,
...
...
test/srt/test_hicache.py
View file @
70645f4d
...
...
@@ -23,6 +23,10 @@ class TestHiCache(CustomTestCase):
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
other_args
=
[
"--enable-hierarchical-cache"
,
"--mem-fraction-static"
,
0.7
,
"--hicache-size"
,
100
,
],
)
...
...
test/srt/test_hicache_mla.py
View file @
70645f4d
...
...
@@ -24,6 +24,8 @@ class TestHierarchicalMLA(CustomTestCase):
other_args
=
[
"--trust-remote-code"
,
"--enable-hierarchical-cache"
,
"--hicache-ratio"
,
2
,
],
)
...
...
test/srt/test_hicache_page.py
View file @
70645f4d
...
...
@@ -24,7 +24,9 @@ class TestHiCachePage(CustomTestCase):
other_args
=
[
"--enable-hierarchical-cache"
,
"--page-size"
,
"32"
,
32
,
"--hicache-write-policy"
,
"write-back"
,
],
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment