Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
70645f4d
Unverified
Commit
70645f4d
authored
Apr 20, 2025
by
Zhiqiang Xie
Committed by
GitHub
Apr 20, 2025
Browse files
upstream hicache fixes (#5570)
parent
188f0955
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
89 additions
and
46 deletions
+89
-46
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+8
-0
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+2
-0
python/sglang/srt/mem_cache/hiradix_cache.py
python/sglang/srt/mem_cache/hiradix_cache.py
+40
-32
python/sglang/srt/mem_cache/memory_pool.py
python/sglang/srt/mem_cache/memory_pool.py
+15
-12
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+15
-1
test/srt/test_hicache.py
test/srt/test_hicache.py
+4
-0
test/srt/test_hicache_mla.py
test/srt/test_hicache_mla.py
+2
-0
test/srt/test_hicache_page.py
test/srt/test_hicache_page.py
+3
-1
No files found.
python/sglang/srt/managers/schedule_batch.py
View file @
70645f4d
...
@@ -571,6 +571,14 @@ class Req:
...
@@ -571,6 +571,14 @@ class Req:
self
.
prefix_indices
,
self
.
last_node
=
tree_cache
.
match_prefix
(
self
.
prefix_indices
,
self
.
last_node
=
tree_cache
.
match_prefix
(
rid
=
self
.
rid
,
key
=
self
.
adjust_max_prefix_ids
()
rid
=
self
.
rid
,
key
=
self
.
adjust_max_prefix_ids
()
)
)
elif
enable_hierarchical_cache
:
# in case last_node is evicted during scheduling, we need to update the prefix_indices
while
self
.
last_node
.
evicted
:
self
.
prefix_indices
=
self
.
prefix_indices
[
:
-
len
(
self
.
last_node
.
host_value
)
]
self
.
last_node
=
self
.
last_node
.
parent
self
.
extend_input_len
=
len
(
self
.
fill_ids
)
-
len
(
self
.
prefix_indices
)
self
.
extend_input_len
=
len
(
self
.
fill_ids
)
-
len
(
self
.
prefix_indices
)
def
adjust_max_prefix_ids
(
self
):
def
adjust_max_prefix_ids
(
self
):
...
...
python/sglang/srt/managers/scheduler.py
View file @
70645f4d
...
@@ -489,6 +489,8 @@ class Scheduler(
...
@@ -489,6 +489,8 @@ class Scheduler(
tp_cache_group
=
self
.
tp_cpu_group
,
tp_cache_group
=
self
.
tp_cpu_group
,
page_size
=
self
.
page_size
,
page_size
=
self
.
page_size
,
hicache_ratio
=
server_args
.
hicache_ratio
,
hicache_ratio
=
server_args
.
hicache_ratio
,
hicache_size
=
server_args
.
hicache_size
,
hicache_write_policy
=
server_args
.
hicache_write_policy
,
)
)
else
:
else
:
self
.
tree_cache
=
RadixCache
(
self
.
tree_cache
=
RadixCache
(
...
...
python/sglang/srt/mem_cache/hiradix_cache.py
View file @
70645f4d
...
@@ -29,15 +29,17 @@ class HiRadixCache(RadixCache):
...
@@ -29,15 +29,17 @@ class HiRadixCache(RadixCache):
tp_cache_group
:
torch
.
distributed
.
ProcessGroup
,
tp_cache_group
:
torch
.
distributed
.
ProcessGroup
,
page_size
:
int
,
page_size
:
int
,
hicache_ratio
:
float
,
hicache_ratio
:
float
,
hicache_size
:
int
,
hicache_write_policy
:
str
,
):
):
self
.
kv_cache
=
token_to_kv_pool_allocator
.
get_kvcache
()
self
.
kv_cache
=
token_to_kv_pool_allocator
.
get_kvcache
()
if
isinstance
(
self
.
kv_cache
,
MHATokenToKVPool
):
if
isinstance
(
self
.
kv_cache
,
MHATokenToKVPool
):
self
.
token_to_kv_pool_host
=
MHATokenToKVPoolHost
(
self
.
token_to_kv_pool_host
=
MHATokenToKVPoolHost
(
self
.
kv_cache
,
hicache_ratio
,
page_size
self
.
kv_cache
,
hicache_ratio
,
hicache_size
,
page_size
)
)
elif
isinstance
(
self
.
kv_cache
,
MLATokenToKVPool
):
elif
isinstance
(
self
.
kv_cache
,
MLATokenToKVPool
):
self
.
token_to_kv_pool_host
=
MLATokenToKVPoolHost
(
self
.
token_to_kv_pool_host
=
MLATokenToKVPoolHost
(
self
.
kv_cache
,
hicache_ratio
,
page_size
self
.
kv_cache
,
hicache_ratio
,
hicache_size
,
page_size
)
)
else
:
else
:
raise
ValueError
(
f
"HiRadixCache only supports MHA and MLA yet"
)
raise
ValueError
(
f
"HiRadixCache only supports MHA and MLA yet"
)
...
@@ -50,6 +52,7 @@ class HiRadixCache(RadixCache):
...
@@ -50,6 +52,7 @@ class HiRadixCache(RadixCache):
self
.
token_to_kv_pool_host
,
self
.
token_to_kv_pool_host
,
page_size
,
page_size
,
load_cache_event
=
self
.
load_cache_event
,
load_cache_event
=
self
.
load_cache_event
,
write_policy
=
hicache_write_policy
,
)
)
# record the nodes with ongoing write through
# record the nodes with ongoing write through
...
@@ -57,7 +60,9 @@ class HiRadixCache(RadixCache):
...
@@ -57,7 +60,9 @@ class HiRadixCache(RadixCache):
# record the node segments with ongoing load back
# record the node segments with ongoing load back
self
.
ongoing_load_back
=
{}
self
.
ongoing_load_back
=
{}
# todo: dynamically adjust the threshold
# todo: dynamically adjust the threshold
self
.
write_through_threshold
=
1
self
.
write_through_threshold
=
(
1
if
hicache_write_policy
==
"write_through"
else
3
)
self
.
load_back_threshold
=
10
self
.
load_back_threshold
=
10
super
().
__init__
(
super
().
__init__
(
req_to_token_pool
,
token_to_kv_pool_allocator
,
page_size
,
disable
=
False
req_to_token_pool
,
token_to_kv_pool_allocator
,
page_size
,
disable
=
False
...
@@ -76,7 +81,7 @@ class HiRadixCache(RadixCache):
...
@@ -76,7 +81,7 @@ class HiRadixCache(RadixCache):
height
+=
1
height
+=
1
return
height
return
height
def
write_backup
(
self
,
node
:
TreeNode
):
def
write_backup
(
self
,
node
:
TreeNode
,
write_back
=
False
):
host_indices
=
self
.
cache_controller
.
write
(
host_indices
=
self
.
cache_controller
.
write
(
device_indices
=
node
.
value
,
device_indices
=
node
.
value
,
node_id
=
node
.
id
,
node_id
=
node
.
id
,
...
@@ -90,21 +95,29 @@ class HiRadixCache(RadixCache):
...
@@ -90,21 +95,29 @@ class HiRadixCache(RadixCache):
if
host_indices
is
not
None
:
if
host_indices
is
not
None
:
node
.
host_value
=
host_indices
node
.
host_value
=
host_indices
self
.
ongoing_write_through
[
node
.
id
]
=
node
self
.
ongoing_write_through
[
node
.
id
]
=
node
self
.
inc_lock_ref
(
node
)
if
not
write_back
:
# no need to lock nodes if write back
self
.
inc_lock_ref
(
node
)
else
:
else
:
return
0
return
0
return
len
(
host_indices
)
return
len
(
host_indices
)
def
inc_hit_count
(
self
,
node
:
TreeNode
):
def
inc_hit_count
(
self
,
node
:
TreeNode
):
if
self
.
cache_controller
.
write_policy
!
=
"write_
through_selective
"
:
if
node
.
backuped
or
self
.
cache_controller
.
write_policy
=
=
"write_
back
"
:
return
return
node
.
hit_count
+=
1
node
.
hit_count
+=
1
if
node
.
host_value
is
None
and
node
.
hit_count
>
self
.
write_through_threshold
:
if
node
.
hit_count
>
=
self
.
write_through_threshold
:
self
.
write_backup
(
node
)
self
.
write_backup
(
node
)
node
.
hit_count
=
0
node
.
hit_count
=
0
def
writing_check
(
self
):
def
writing_check
(
self
,
write_back
=
False
):
if
write_back
:
# blocking till all write back complete
while
len
(
self
.
ongoing_write_through
)
>
0
:
ack_id
=
self
.
cache_controller
.
ack_write_queue
.
get
()
del
self
.
ongoing_write_through
[
ack_id
]
return
queue_size
=
torch
.
tensor
(
queue_size
=
torch
.
tensor
(
self
.
cache_controller
.
ack_write_queue
.
qsize
(),
dtype
=
torch
.
int
self
.
cache_controller
.
ack_write_queue
.
qsize
(),
dtype
=
torch
.
int
)
)
...
@@ -143,29 +156,25 @@ class HiRadixCache(RadixCache):
...
@@ -143,29 +156,25 @@ class HiRadixCache(RadixCache):
heapq
.
heapify
(
leaves
)
heapq
.
heapify
(
leaves
)
num_evicted
=
0
num_evicted
=
0
pending
_nodes
=
[]
write_back
_nodes
=
[]
while
num_evicted
<
num_tokens
and
len
(
leaves
):
while
num_evicted
<
num_tokens
and
len
(
leaves
):
x
=
heapq
.
heappop
(
leaves
)
x
=
heapq
.
heappop
(
leaves
)
if
x
.
lock_ref
>
0
:
if
x
.
lock_ref
>
0
:
continue
continue
if
x
.
host_value
is
None
:
if
not
x
.
backuped
:
if
self
.
cache_controller
.
write_policy
==
"write_back"
:
if
self
.
cache_controller
.
write_policy
==
"write_back"
:
num_evicted
+=
self
.
write_backup
(
x
)
# write to host if the node is not backuped
pending_nodes
.
append
(
x
)
num_evicted
+=
self
.
write_backup
(
x
,
write_back
=
True
)
elif
self
.
cache_controller
.
write_policy
==
"write_through_selective"
:
write_back_nodes
.
append
(
x
)
num_evicted
+=
self
.
_evict_write_through_selective
(
x
)
else
:
else
:
assert
(
num_evicted
+=
self
.
_evict_regular
(
x
)
self
.
cache_controller
.
write_policy
!=
"write_through"
),
"write_through should be inclusive"
raise
NotImplementedError
else
:
else
:
num_evicted
+=
self
.
_evict_
write_through
(
x
)
num_evicted
+=
self
.
_evict_
backuped
(
x
)
for
child
in
x
.
parent
.
children
.
values
():
for
child
in
x
.
parent
.
children
.
values
():
if
child
in
pending
_nodes
:
if
child
in
write_back
_nodes
:
continue
continue
if
not
child
.
evicted
:
if
not
child
.
evicted
:
break
break
...
@@ -174,15 +183,12 @@ class HiRadixCache(RadixCache):
...
@@ -174,15 +183,12 @@ class HiRadixCache(RadixCache):
heapq
.
heappush
(
leaves
,
x
.
parent
)
heapq
.
heappush
(
leaves
,
x
.
parent
)
if
self
.
cache_controller
.
write_policy
==
"write_back"
:
if
self
.
cache_controller
.
write_policy
==
"write_back"
:
# blocking till all write back complete
self
.
writing_check
(
write_back
=
True
)
while
len
(
self
.
ongoing_write_through
)
>
0
:
for
node
in
write_back_nodes
:
self
.
writing_check
()
assert
node
.
backuped
time
.
sleep
(
0.1
)
self
.
_evict_backuped
(
node
)
for
node
in
pending_nodes
:
assert
node
.
host_value
is
not
None
self
.
_evict_write_through
(
node
)
def
_evict_
write_through
(
self
,
node
:
TreeNode
):
def
_evict_
backuped
(
self
,
node
:
TreeNode
):
# evict a node already written to host
# evict a node already written to host
num_evicted
=
self
.
cache_controller
.
evict_device
(
node
.
value
,
node
.
host_value
)
num_evicted
=
self
.
cache_controller
.
evict_device
(
node
.
value
,
node
.
host_value
)
assert
num_evicted
>
0
assert
num_evicted
>
0
...
@@ -190,7 +196,7 @@ class HiRadixCache(RadixCache):
...
@@ -190,7 +196,7 @@ class HiRadixCache(RadixCache):
node
.
value
=
None
node
.
value
=
None
return
num_evicted
return
num_evicted
def
_evict_
write_through_selective
(
self
,
node
:
TreeNode
):
def
_evict_
regular
(
self
,
node
:
TreeNode
):
# evict a node not initiated write to host
# evict a node not initiated write to host
self
.
cache_controller
.
mem_pool_device_allocator
.
free
(
node
.
value
)
self
.
cache_controller
.
mem_pool_device_allocator
.
free
(
node
.
value
)
num_evicted
=
len
(
node
.
value
)
num_evicted
=
len
(
node
.
value
)
...
@@ -339,11 +345,13 @@ class HiRadixCache(RadixCache):
...
@@ -339,11 +345,13 @@ class HiRadixCache(RadixCache):
prefix_len
=
self
.
key_match_fn
(
child
.
key
,
key
)
prefix_len
=
self
.
key_match_fn
(
child
.
key
,
key
)
if
prefix_len
<
len
(
child
.
key
):
if
prefix_len
<
len
(
child
.
key
):
new_node
=
self
.
_split_node
(
child
.
key
,
child
,
prefix_len
)
new_node
=
self
.
_split_node
(
child
.
key
,
child
,
prefix_len
)
self
.
inc_hit_count
(
new_node
)
if
not
new_node
.
evicted
:
if
not
new_node
.
evicted
:
value
.
append
(
new_node
.
value
)
value
.
append
(
new_node
.
value
)
node
=
new_node
node
=
new_node
break
break
else
:
else
:
self
.
inc_hit_count
(
child
)
if
not
child
.
evicted
:
if
not
child
.
evicted
:
value
.
append
(
child
.
value
)
value
.
append
(
child
.
value
)
node
=
child
node
=
child
...
@@ -369,7 +377,7 @@ class HiRadixCache(RadixCache):
...
@@ -369,7 +377,7 @@ class HiRadixCache(RadixCache):
else
:
else
:
new_node
.
value
=
child
.
value
[:
split_len
]
new_node
.
value
=
child
.
value
[:
split_len
]
child
.
value
=
child
.
value
[
split_len
:]
child
.
value
=
child
.
value
[
split_len
:]
if
child
.
host_value
is
not
None
:
if
child
.
backuped
:
new_node
.
host_value
=
child
.
host_value
[:
split_len
]
new_node
.
host_value
=
child
.
host_value
[:
split_len
]
child
.
host_value
=
child
.
host_value
[
split_len
:]
child
.
host_value
=
child
.
host_value
[
split_len
:]
child
.
parent
=
new_node
child
.
parent
=
new_node
...
@@ -426,8 +434,8 @@ class HiRadixCache(RadixCache):
...
@@ -426,8 +434,8 @@ class HiRadixCache(RadixCache):
node
.
children
[
child_key
]
=
new_node
node
.
children
[
child_key
]
=
new_node
self
.
evictable_size_
+=
len
(
value
)
self
.
evictable_size_
+=
len
(
value
)
if
self
.
cache_controller
.
write_policy
=
=
"write_
through
"
:
if
self
.
cache_controller
.
write_policy
!
=
"write_
back
"
:
self
.
write_backup
(
new_node
)
self
.
inc_hit_count
(
new_node
)
return
total_prefix_length
return
total_prefix_length
def
_collect_leaves_device
(
self
):
def
_collect_leaves_device
(
self
):
...
...
python/sglang/srt/mem_cache/memory_pool.py
View file @
70645f4d
...
@@ -624,26 +624,27 @@ class HostKVCache(abc.ABC):
...
@@ -624,26 +624,27 @@ class HostKVCache(abc.ABC):
self
,
self
,
device_pool
:
MHATokenToKVPool
,
device_pool
:
MHATokenToKVPool
,
host_to_device_ratio
:
float
,
host_to_device_ratio
:
float
,
host_size
:
int
,
pin_memory
:
bool
,
pin_memory
:
bool
,
device
:
str
,
device
:
str
,
page_size
:
int
,
page_size
:
int
,
):
):
assert
(
host_to_device_ratio
>=
1
),
"The host memory should be larger than the device memory with the current protocol"
# todo, other ways of configuring the size
self
.
device_pool
=
device_pool
self
.
device_pool
=
device_pool
self
.
host_to_device_ratio
=
host_to_device_ratio
self
.
dtype
=
device_pool
.
store_dtype
self
.
pin_memory
=
pin_memory
self
.
pin_memory
=
pin_memory
self
.
device
=
device
self
.
device
=
device
self
.
page_size
=
page_size
self
.
page_size
=
page_size
self
.
size_per_token
=
self
.
get_size_per_token
()
self
.
size
=
int
(
device_pool
.
size
*
host_to_device_ratio
)
if
host_size
>
0
:
self
.
size
=
int
(
host_size
*
1e9
//
self
.
size_per_token
)
else
:
self
.
size
=
int
(
device_pool
.
size
*
host_to_device_ratio
)
# Align the host memory pool size to the page size
# Align the host memory pool size to the page size
self
.
size
=
self
.
size
-
(
self
.
size
%
self
.
page_size
)
self
.
size
=
self
.
size
-
(
self
.
size
%
self
.
page_size
)
self
.
dtype
=
device_pool
.
store_dtype
self
.
size_per_token
=
self
.
get_size_per_token
()
assert
(
self
.
size
>
device_pool
.
size
),
"The host memory should be larger than the device memory with the current protocol"
# Verify there is enough available host memory.
# Verify there is enough available host memory.
host_mem
=
psutil
.
virtual_memory
()
host_mem
=
psutil
.
virtual_memory
()
...
@@ -795,12 +796,13 @@ class MHATokenToKVPoolHost(HostKVCache):
...
@@ -795,12 +796,13 @@ class MHATokenToKVPoolHost(HostKVCache):
self
,
self
,
device_pool
:
MHATokenToKVPool
,
device_pool
:
MHATokenToKVPool
,
host_to_device_ratio
:
float
,
host_to_device_ratio
:
float
,
host_size
:
int
,
page_size
:
int
,
page_size
:
int
,
pin_memory
:
bool
=
True
,
pin_memory
:
bool
=
True
,
device
:
str
=
"cpu"
,
device
:
str
=
"cpu"
,
):
):
super
().
__init__
(
super
().
__init__
(
device_pool
,
host_to_device_ratio
,
pin_memory
,
device
,
page_size
device_pool
,
host_to_device_ratio
,
host_size
,
pin_memory
,
device
,
page_size
)
)
def
get_size_per_token
(
self
):
def
get_size_per_token
(
self
):
...
@@ -869,12 +871,13 @@ class MLATokenToKVPoolHost(HostKVCache):
...
@@ -869,12 +871,13 @@ class MLATokenToKVPoolHost(HostKVCache):
self
,
self
,
device_pool
:
MLATokenToKVPool
,
device_pool
:
MLATokenToKVPool
,
host_to_device_ratio
:
float
,
host_to_device_ratio
:
float
,
host_size
:
int
,
page_size
:
int
,
page_size
:
int
,
pin_memory
:
bool
=
True
,
pin_memory
:
bool
=
True
,
device
:
str
=
"cpu"
,
device
:
str
=
"cpu"
,
):
):
super
().
__init__
(
super
().
__init__
(
device_pool
,
host_to_device_ratio
,
pin_memory
,
device
,
page_size
device_pool
,
host_to_device_ratio
,
host_size
,
pin_memory
,
device
,
page_size
)
)
def
get_size_per_token
(
self
):
def
get_size_per_token
(
self
):
...
...
python/sglang/srt/server_args.py
View file @
70645f4d
...
@@ -180,6 +180,8 @@ class ServerArgs:
...
@@ -180,6 +180,8 @@ class ServerArgs:
tool_call_parser
:
Optional
[
str
]
=
None
tool_call_parser
:
Optional
[
str
]
=
None
enable_hierarchical_cache
:
bool
=
False
enable_hierarchical_cache
:
bool
=
False
hicache_ratio
:
float
=
2.0
hicache_ratio
:
float
=
2.0
hicache_size
:
int
=
0
hicache_write_policy
:
str
=
"write_through_selective"
flashinfer_mla_disable_ragged
:
bool
=
False
flashinfer_mla_disable_ragged
:
bool
=
False
warmups
:
Optional
[
str
]
=
None
warmups
:
Optional
[
str
]
=
None
moe_dense_tp_size
:
Optional
[
int
]
=
None
moe_dense_tp_size
:
Optional
[
int
]
=
None
...
@@ -1116,10 +1118,22 @@ class ServerArgs:
...
@@ -1116,10 +1118,22 @@ class ServerArgs:
parser
.
add_argument
(
parser
.
add_argument
(
"--hicache-ratio"
,
"--hicache-ratio"
,
type
=
float
,
type
=
float
,
required
=
False
,
default
=
ServerArgs
.
hicache_ratio
,
default
=
ServerArgs
.
hicache_ratio
,
help
=
"The ratio of the size of host KV cache memory pool to the size of device pool."
,
help
=
"The ratio of the size of host KV cache memory pool to the size of device pool."
,
)
)
parser
.
add_argument
(
"--hicache-size"
,
type
=
int
,
default
=
ServerArgs
.
hicache_size
,
help
=
"The size of host KV cache memory pool in gigabytes, which will override the hicache_ratio if set."
,
)
parser
.
add_argument
(
"--hicache-write-policy"
,
type
=
str
,
choices
=
[
"write_back"
,
"write_through"
,
"write_through_selective"
],
default
=
ServerArgs
.
hicache_write_policy
,
help
=
"The write policy of hierarchical cache."
,
)
parser
.
add_argument
(
parser
.
add_argument
(
"--enable-deepep-moe"
,
"--enable-deepep-moe"
,
action
=
"store_true"
,
action
=
"store_true"
,
...
...
test/srt/test_hicache.py
View file @
70645f4d
...
@@ -23,6 +23,10 @@ class TestHiCache(CustomTestCase):
...
@@ -23,6 +23,10 @@ class TestHiCache(CustomTestCase):
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
other_args
=
[
other_args
=
[
"--enable-hierarchical-cache"
,
"--enable-hierarchical-cache"
,
"--mem-fraction-static"
,
0.7
,
"--hicache-size"
,
100
,
],
],
)
)
...
...
test/srt/test_hicache_mla.py
View file @
70645f4d
...
@@ -24,6 +24,8 @@ class TestHierarchicalMLA(CustomTestCase):
...
@@ -24,6 +24,8 @@ class TestHierarchicalMLA(CustomTestCase):
other_args
=
[
other_args
=
[
"--trust-remote-code"
,
"--trust-remote-code"
,
"--enable-hierarchical-cache"
,
"--enable-hierarchical-cache"
,
"--hicache-ratio"
,
2
,
],
],
)
)
...
...
test/srt/test_hicache_page.py
View file @
70645f4d
...
@@ -24,7 +24,9 @@ class TestHiCachePage(CustomTestCase):
...
@@ -24,7 +24,9 @@ class TestHiCachePage(CustomTestCase):
other_args
=
[
other_args
=
[
"--enable-hierarchical-cache"
,
"--enable-hierarchical-cache"
,
"--page-size"
,
"--page-size"
,
"32"
,
32
,
"--hicache-write-policy"
,
"write-back"
,
],
],
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment