Large page size aligned hierarchical caching (#4581)

Unverified commit e119f042, authored Apr 01, 2025 by Zhiqiang Xie, committed by GitHub on Apr 01, 2025.
Parent: 9eb49e87

Showing 8 changed files with 242 additions and 71 deletions (+242 / -71):
python/sglang/srt/managers/cache_controller.py   +34  -11
python/sglang/srt/managers/scheduler.py          +1   -1
python/sglang/srt/mem_cache/hiradix_cache.py     +60  -50
python/sglang/srt/mem_cache/memory_pool.py       +66  -6
python/sglang/srt/mem_cache/paged_allocator.py   +24  -0
test/srt/test_hicache.py                         +4   -2
test/srt/test_hicache_mla.py                     +4   -1
test/srt/test_hicache_page.py                    +49  -0
python/sglang/srt/managers/cache_controller.py

@@ -149,6 +149,7 @@ class HiCacheController:
         self,
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
         mem_pool_host: HostKVCache,
+        page_size: int,
         load_cache_event: threading.Event = None,
         write_policy: str = "write_through_selective",
     ):
@@ -156,6 +157,7 @@ class HiCacheController:
         self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
         self.mem_pool_host = mem_pool_host
         self.write_policy = write_policy
+        self.page_size = page_size
         self.load_cache_event = load_cache_event
         self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
@@ -184,7 +186,12 @@ class HiCacheController:
         self.load_stream = torch.cuda.Stream()

         self.write_thread = threading.Thread(
-            target=self.write_thread_func_buffer, daemon=True
+            target=(
+                self.write_thread_func_buffer
+                if self.page_size == 1
+                else self.write_thread_func_direct
+            ),
+            daemon=True,
         )
         self.load_thread = threading.Thread(
             target=self.load_thread_func_layer_by_layer, daemon=True
@@ -205,7 +212,12 @@ class HiCacheController:
         self.ack_load_queue.queue.clear()

         self.write_thread = threading.Thread(
-            target=self.write_thread_func_buffer, daemon=True
+            target=(
+                self.write_thread_func_buffer
+                if self.page_size == 1
+                else self.write_thread_func_direct
+            ),
+            daemon=True,
         )
         self.load_thread = threading.Thread(
             target=self.load_thread_func_layer_by_layer, daemon=True
@@ -260,10 +272,12 @@ class HiCacheController:
         while not self.stop_event.is_set():
             try:
                 operation = self.write_queue.get(block=True, timeout=1)
-                operation.data = self.mem_pool_device.get_flat_data(
-                    operation.device_indices
-                )
-                self.mem_pool_host.transfer(operation.host_indices, operation.data)
+                self.mem_pool_host.write_page_all_layers(
+                    operation.host_indices,
+                    operation.device_indices,
+                    self.mem_pool_device,
+                )
+                self.write_stream.synchronize()
                 self.mem_pool_host.complete_io(operation.host_indices)
                 for node_id in operation.node_ids:
                     if node_id != 0:
@@ -320,12 +334,21 @@ class HiCacheController:
             self.layer_done_counter.reset()
             for i in range(self.mem_pool_host.layer_num):
-                flat_data = self.mem_pool_host.get_flat_data_by_layer(
-                    batch_operation.host_indices, i
-                )
-                self.mem_pool_device.transfer_per_layer(
-                    batch_operation.device_indices, flat_data, i
-                )
+                if self.page_size == 1:
+                    flat_data = self.mem_pool_host.get_flat_data_by_layer(
+                        batch_operation.host_indices, i
+                    )
+                    self.mem_pool_device.transfer_per_layer(
+                        batch_operation.device_indices, flat_data, i
+                    )
+                else:
+                    self.mem_pool_host.load_page_per_layer(
+                        batch_operation.host_indices,
+                        batch_operation.device_indices,
+                        self.mem_pool_device,
+                        i,
+                    )
+                    self.load_stream.synchronize()
                 self.layer_done_counter.increment()
             self.mem_pool_host.complete_io(batch_operation.host_indices)
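Note on cache_controller.py: with page_size > 1 the controller dispatches its write worker to the direct page-wise path instead of the double-buffered path. The following standalone sketch (illustrative only, not sglang code; SimpleController and its method names are invented) shows the same pattern of selecting the thread target by page size:

import queue
import threading
import time


class SimpleController:
    def __init__(self, page_size: int):
        self.page_size = page_size
        self.write_queue = queue.Queue()
        self.stop_event = threading.Event()
        # page_size == 1 keeps the buffered writer; larger pages use the direct writer.
        self.write_thread = threading.Thread(
            target=(
                self._write_buffered if self.page_size == 1 else self._write_direct
            ),
            daemon=True,
        )
        self.write_thread.start()

    def _write_buffered(self):
        while not self.stop_event.is_set():
            try:
                op = self.write_queue.get(block=True, timeout=1)
            except queue.Empty:
                continue
            print("buffered write:", op)

    def _write_direct(self):
        while not self.stop_event.is_set():
            try:
                op = self.write_queue.get(block=True, timeout=1)
            except queue.Empty:
                continue
            print("page-wise direct write:", op)


if __name__ == "__main__":
    ctrl = SimpleController(page_size=32)
    ctrl.write_queue.put("dummy-operation")
    time.sleep(1)  # give the daemon worker a chance to drain the queue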
python/sglang/srt/managers/scheduler.py

@@ -1282,7 +1282,7 @@ class Scheduler(
         ]

         if self.enable_hierarchical_cache:
-            self.tree_cache.read_to_load_cache()
+            self.tree_cache.ready_to_load_cache()

         if adder.new_chunked_req is not None:
             assert self.chunked_req is None
python/sglang/srt/mem_cache/hiradix_cache.py

@@ -16,7 +16,6 @@ from sglang.srt.mem_cache.memory_pool import (
     TokenToKVPoolAllocator,
 )
 from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
-from sglang.srt.mem_cache.radix_cache import _key_match_page_size1 as _key_match

 logger = logging.getLogger(__name__)
@@ -31,29 +30,25 @@ class HiRadixCache(RadixCache):
         page_size: int,
         hicache_ratio: float,
     ):
-        if page_size != 1:
-            raise ValueError(
-                "Page size larger than 1 is not yet supported in HiRadixCache."
-            )
         self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
         if isinstance(self.kv_cache, MHATokenToKVPool):
             self.token_to_kv_pool_host = MHATokenToKVPoolHost(
-                self.kv_cache, hicache_ratio
+                self.kv_cache, hicache_ratio, page_size
             )
         elif isinstance(self.kv_cache, MLATokenToKVPool):
             self.token_to_kv_pool_host = MLATokenToKVPoolHost(
-                self.kv_cache, hicache_ratio
+                self.kv_cache, hicache_ratio, page_size
             )
         else:
-            raise ValueError(f"Only MHA and MLA supports swap kv_cache to host.")
+            raise ValueError(f"HiRadixCache only supports MHA and MLA yet")

         self.tp_group = tp_cache_group
-        self.page_size = page_size
         self.load_cache_event = threading.Event()
         self.cache_controller = HiCacheController(
             token_to_kv_pool_allocator,
             self.token_to_kv_pool_host,
+            page_size,
             load_cache_event=self.load_cache_event,
         )
@@ -65,7 +60,7 @@ class HiRadixCache(RadixCache):
         self.write_through_threshold = 1
         self.load_back_threshold = 10
         super().__init__(
-            req_to_token_pool, token_to_kv_pool_allocator, self.page_size, disable=False
+            req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
         )

     def reset(self):
@@ -299,18 +294,26 @@ class HiRadixCache(RadixCache):
         return last_node, prefix_indices

-    def read_to_load_cache(self):
+    def ready_to_load_cache(self):
         self.load_cache_event.set()

     def match_prefix(self, key: List[int], include_evicted=False, **kwargs):
-        if self.disable:
-            return [], self.root_node
+        empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
+        if self.disable or len(key) == 0:
+            if include_evicted:
+                return empty_value, self.root_node, self.root_node
+            else:
+                return empty_value, self.root_node
+
+        if self.page_size != 1:
+            page_aligned_len = len(key) // self.page_size * self.page_size
+            key = key[:page_aligned_len]

         value, last_node = self._match_prefix_helper(self.root_node, key)
         if value:
             value = torch.cat(value)
         else:
-            value = torch.tensor([], dtype=torch.int64)
+            value = empty_value

         last_node_global = last_node
         while last_node.evicted:
@@ -323,11 +326,13 @@ class HiRadixCache(RadixCache):
     def _match_prefix_helper(self, node: TreeNode, key: List):
         node.last_access_time = time.time()
+        child_key = self.get_child_key_fn(key)

         value = []
-        while len(key) > 0 and key[0] in node.children.keys():
-            child = node.children[key[0]]
+        while len(key) > 0 and child_key in node.children.keys():
+            child = node.children[child_key]
             child.last_access_time = time.time()
-            prefix_len = _key_match(child.key, key)
+            prefix_len = self.key_match_fn(child.key, key)
             if prefix_len < len(child.key):
                 new_node = self._split_node(child.key, child, prefix_len)
                 if not new_node.evicted:
@@ -339,12 +344,16 @@ class HiRadixCache(RadixCache):
                 value.append(child.value)
             node = child
             key = key[prefix_len:]
+            if len(key):
+                child_key = self.get_child_key_fn(key)

         return value, node

     def _split_node(self, key, child: TreeNode, split_len: int):
         # child node split into new_node -> child
         new_node = TreeNode()
-        new_node.children = {key[split_len]: child}
+        new_node.children = {self.get_child_key_fn(key[split_len:]): child}
         new_node.parent = child.parent
         new_node.lock_ref = child.lock_ref
         new_node.key = child.key[:split_len]
@@ -361,7 +370,7 @@ class HiRadixCache(RadixCache):
             child.host_value = child.host_value[split_len:]
         child.parent = new_node
         child.key = child.key[split_len:]
-        new_node.parent.children[key[0]] = new_node
+        new_node.parent.children[self.get_child_key_fn(key)] = new_node

         return new_node

     def _insert_helper(self, node: TreeNode, key: List, value):
@@ -369,52 +378,53 @@ class HiRadixCache(RadixCache):
         if len(key) == 0:
             return 0

-        if key[0] in node.children.keys():
-            child = node.children[key[0]]
-            prefix_len = _key_match(child.key, key)
-            if prefix_len == len(child.key):
-                if child.evicted:
-                    # change the reference if the node is evicted
-                    # this often happens in the case of KV cache recomputation
-                    child.value = value[:prefix_len]
-                    self.token_to_kv_pool_host.update_synced(child.host_value)
-                    self.evictable_size_ += len(value[:prefix_len])
-                    return self._insert_helper(child, key[prefix_len:], value[prefix_len:])
-                else:
-                    self.inc_hit_count(child)
-                    return prefix_len + self._insert_helper(
-                        child, key[prefix_len:], value[prefix_len:]
-                    )
-
-            # partial match, split the node
-            new_node = self._split_node(child.key, child, prefix_len)
-            if new_node.evicted:
-                new_node.value = value[:prefix_len]
-                self.token_to_kv_pool_host.update_synced(new_node.host_value)
-                self.evictable_size_ += len(new_node.value)
-                return self._insert_helper(new_node, key[prefix_len:], value[prefix_len:])
-            else:
-                self.inc_hit_count(new_node)
-                return prefix_len + self._insert_helper(
-                    new_node, key[prefix_len:], value[prefix_len:]
-                )
+        child_key = self.get_child_key_fn(key)
+        total_prefix_length = 0
+        while len(key) > 0 and child_key in node.children.keys():
+            node = node.children[child_key]
+            node.last_access_time = time.time()
+            prefix_len = self.key_match_fn(node.key, key)
+            if prefix_len == len(node.key):
+                if node.evicted:
+                    # change the reference if the node is evicted
+                    # this often happens in the case of KV cache recomputation
+                    node.value = value[:prefix_len]
+                    self.token_to_kv_pool_host.update_synced(node.host_value)
+                    self.evictable_size_ += len(node.value)
+                else:
+                    self.inc_hit_count(node)
+                    total_prefix_length += prefix_len
+            else:
+                # partial match, split the node
+                new_node = self._split_node(node.key, node, prefix_len)
+                if new_node.evicted:
+                    new_node.value = value[:prefix_len]
+                    self.token_to_kv_pool_host.update_synced(new_node.host_value)
+                    self.evictable_size_ += len(new_node.value)
+                else:
+                    self.inc_hit_count(new_node)
+                    total_prefix_length += prefix_len
+                node = new_node
+            key = key[prefix_len:]
+            value = value[prefix_len:]
+            if len(key):
+                child_key = self.get_child_key_fn(key)

         if len(key):
             new_node = TreeNode()
             new_node.parent = node
             new_node.key = key
             new_node.value = value
-            node.children[key[0]] = new_node
+            node.children[child_key] = new_node
             self.evictable_size_ += len(value)

             if self.cache_controller.write_policy == "write_through":
                 self.write_backup(new_node)
-        return 0
+        return total_prefix_length

     def _collect_leaves_device(self):
         def is_leaf(node):
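A small self-contained illustration (hypothetical token ids, not sglang code) of the page-alignment step that match_prefix gains in this file: with page_size > 1 the lookup key is truncated to a whole number of pages before the radix tree is walked, so only fully cached pages are reused.

def page_align(key, page_size):
    # Same arithmetic as the new match_prefix: keep only complete pages.
    page_aligned_len = len(key) // page_size * page_size
    return key[:page_aligned_len]


tokens = list(range(70))            # pretend these are 70 token ids
print(len(page_align(tokens, 1)))   # 70: page_size 1 keeps every token
print(len(page_align(tokens, 32)))  # 64: only two whole 32-token pages match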
python/sglang/srt/mem_cache/memory_pool.py

@@ -608,8 +608,9 @@ class HostKVCache(abc.ABC):
         self,
         device_pool: MHATokenToKVPool,
         host_to_device_ratio: float,
-        pin_memory: bool = False,  # no need to use pin memory with the double buffering
-        device: str = "cpu",
+        pin_memory: bool,
+        device: str,
+        page_size: int,
     ):
         assert (
             host_to_device_ratio >= 1
@@ -620,8 +621,11 @@ class HostKVCache(abc.ABC):
         self.host_to_device_ratio = host_to_device_ratio
         self.pin_memory = pin_memory
         self.device = device
+        self.page_size = page_size

         self.size = int(device_pool.size * host_to_device_ratio)
+        # Align the host memory pool size to the page size
+        self.size = self.size - (self.size % self.page_size)
         self.dtype = device_pool.store_dtype
         self.size_per_token = self.get_size_per_token()
@@ -775,10 +779,13 @@ class MHATokenToKVPoolHost(HostKVCache):
         self,
         device_pool: MHATokenToKVPool,
         host_to_device_ratio: float,
-        pin_memory: bool = False,  # no need to use pin memory with the double buffering
+        page_size: int,
+        pin_memory: bool = True,
         device: str = "cpu",
     ):
-        super().__init__(device_pool, host_to_device_ratio, pin_memory, device)
+        super().__init__(
+            device_pool, host_to_device_ratio, pin_memory, device, page_size
+        )

     def get_size_per_token(self):
         self.head_num = self.device_pool.head_num
@@ -811,16 +818,48 @@ class MHATokenToKVPoolHost(HostKVCache):
     def assign_flat_data(self, indices, flat_data):
         self.kv_buffer[:, :, indices] = flat_data

+    def write_page_all_layers(self, host_indices, device_indices, device_pool):
+        device_indices_cpu = device_indices[:: self.page_size].cpu()
+        for i in range(len(device_indices_cpu)):
+            h_index = host_indices[i * self.page_size]
+            d_index = device_indices_cpu[i]
+            for j in range(self.layer_num):
+                self.kv_buffer[0, j, h_index : h_index + self.page_size].copy_(
+                    device_pool.k_buffer[j][d_index : d_index + self.page_size],
+                    non_blocking=True,
+                )
+                self.kv_buffer[1, j, h_index : h_index + self.page_size].copy_(
+                    device_pool.v_buffer[j][d_index : d_index + self.page_size],
+                    non_blocking=True,
+                )
+
+    def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
+        device_indices_cpu = device_indices[:: self.page_size].cpu()
+        for i in range(len(device_indices_cpu)):
+            h_index = host_indices[i * self.page_size]
+            d_index = device_indices_cpu[i]
+            device_pool.k_buffer[layer_id][d_index : d_index + self.page_size].copy_(
+                self.kv_buffer[0, layer_id, h_index : h_index + self.page_size],
+                non_blocking=True,
+            )
+            device_pool.v_buffer[layer_id][d_index : d_index + self.page_size].copy_(
+                self.kv_buffer[1, layer_id, h_index : h_index + self.page_size],
+                non_blocking=True,
+            )
+

 class MLATokenToKVPoolHost(HostKVCache):
     def __init__(
         self,
         device_pool: MLATokenToKVPool,
         host_to_device_ratio: float,
-        pin_memory: bool = False,  # no need to use pin memory with the double buffering
+        page_size: int,
+        pin_memory: bool = True,
         device: str = "cpu",
     ):
-        super().__init__(device_pool, host_to_device_ratio, pin_memory, device)
+        super().__init__(
+            device_pool, host_to_device_ratio, pin_memory, device, page_size
+        )

     def get_size_per_token(self):
         self.kv_lora_rank = self.device_pool.kv_lora_rank
@@ -857,3 +896,24 @@ class MLATokenToKVPoolHost(HostKVCache):
     def assign_flat_data(self, indices, flat_data):
         self.kv_buffer[:, indices] = flat_data
+
+    def write_page_all_layers(self, host_indices, device_indices, device_pool):
+        device_indices_cpu = device_indices[:: self.page_size].cpu()
+        for i in range(len(device_indices_cpu)):
+            h_index = host_indices[i * self.page_size]
+            d_index = device_indices_cpu[i]
+            for j in range(self.layer_num):
+                self.kv_buffer[j, h_index : h_index + self.page_size].copy_(
+                    device_pool.kv_buffer[j][d_index : d_index + self.page_size],
+                    non_blocking=True,
+                )
+
+    def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
+        device_indices_cpu = device_indices[:: self.page_size].cpu()
+        for i in range(len(device_indices_cpu)):
+            h_index = host_indices[i * self.page_size]
+            d_index = device_indices_cpu[i]
+            device_pool.kv_buffer[layer_id][d_index : d_index + self.page_size].copy_(
+                self.kv_buffer[layer_id, h_index : h_index + self.page_size],
+                non_blocking=True,
+            )
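The new host-pool methods above move KV data one page at a time. A hedged sketch of the underlying index pattern with toy tensors (not the real sglang buffers): device_indices[::page_size] yields the first slot of each page, and each page is then moved with a single contiguous slice copy; the same page_size is also used above to trim the host pool size with size - (size % page_size).

import torch

page_size = 4
# Two pages worth of indices on each side (hypothetical values).
device_indices = torch.tensor([8, 9, 10, 11, 20, 21, 22, 23])
host_indices = torch.tensor([0, 1, 2, 3, 12, 13, 14, 15])

host_buf = torch.zeros(32)
device_buf = torch.arange(64, dtype=torch.float32)

# Every page_size-th entry is the first index of a page, so one slice copy
# moves a whole page instead of page_size per-token copies.
device_starts = device_indices[::page_size]
for i, d_start in enumerate(device_starts):
    h_start = host_indices[i * page_size]
    host_buf[h_start : h_start + page_size].copy_(
        device_buf[d_start : d_start + page_size]
    )

print(host_buf[:16])  # pages copied to host offsets 0..3 and 12..15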
python/sglang/srt/mem_cache/paged_allocator.py

@@ -190,6 +190,30 @@ class PagedTokenToKVPoolAllocator:
     def available_size(self):
         return len(self.free_pages) * self.page_size

+    def get_kvcache(self):
+        return self._kvcache
+
+    def alloc(self, need_size: int):
+        # page-aligned allocation, returning contiguous indices of pages
+        if self.debug_mode:
+            assert (
+                need_size % self.page_size == 0
+            ), "The allocation size should be page-aligned"
+
+        num_pages = need_size // self.page_size
+        if num_pages > len(self.free_pages):
+            return None
+
+        out_pages = self.free_pages[:num_pages]
+        self.free_pages = self.free_pages[num_pages:]
+
+        out_indices = (
+            out_pages[:, None] * self.page_size
+            + torch.arange(self.page_size, device=self.device)
+        ).reshape(-1)
+        return out_indices
+
     def alloc_extend(
         self,
         prefix_lens: torch.Tensor,
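A brief demonstration (hypothetical free-page ids, not sglang state) of the index arithmetic in the new page-aligned alloc() above: each allocated page id expands into page_size consecutive token slots, so the returned indices are contiguous within every page.

import torch

page_size = 4
free_pages = torch.tensor([3, 7, 2])

out_pages = free_pages[:2]  # allocate the first two free pages
out_indices = (
    out_pages[:, None] * page_size + torch.arange(page_size)
).reshape(-1)

print(out_indices)  # tensor([12, 13, 14, 15, 28, 29, 30, 31])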
test/srt/test_hicache.py

@@ -12,7 +12,7 @@ from sglang.test.test_utils import (
 )

-class TestPageSize(CustomTestCase):
+class TestHiCache(CustomTestCase):
     @classmethod
     def setUpClass(cls):
         cls.model = DEFAULT_MODEL_NAME_FOR_TEST
@@ -21,7 +21,9 @@ class TestPageSize(CustomTestCase):
             cls.model,
             cls.base_url,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-            other_args=["--enable-hierarchical-cache"],
+            other_args=[
+                "--enable-hierarchical-cache",
+            ],
         )

     @classmethod
test/srt/test_hicache_mla.py

@@ -21,7 +21,10 @@ class TestHierarchicalMLA(CustomTestCase):
             cls.model,
             cls.base_url,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-            other_args=["--trust-remote-code", "--enable-hierarchical-cache"],
+            other_args=[
+                "--trust-remote-code",
+                "--enable-hierarchical-cache",
+            ],
         )

     @classmethod
test/srt/test_hicache_page.py (new file, mode 100644)

+import unittest
+from types import SimpleNamespace
+
+from sglang.srt.utils import kill_process_tree
+from sglang.test.run_eval import run_eval
+from sglang.test.test_utils import (
+    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    CustomTestCase,
+    popen_launch_server,
+)
+
+
+class TestHiCachePage(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--enable-hierarchical-cache",
+                "--page-size",
+                "32",
+            ],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_mmlu(self):
+        args = SimpleNamespace(
+            base_url=self.base_url,
+            model=self.model,
+            eval_name="mmlu",
+            num_examples=64,
+            num_threads=32,
+        )
+        metrics = run_eval(args)
+        self.assertGreaterEqual(metrics["score"], 0.65)
+
+
+if __name__ == "__main__":
+    unittest.main()
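Assuming a working sglang test environment, one generic way (not part of this commit) to run only the new page-size test is to load it by module name with the standard unittest runner from the test/srt directory:

import unittest

# Loads test/srt/test_hicache_page.py by module name; run from test/srt.
suite = unittest.defaultTestLoader.loadTestsFromName("test_hicache_page")
unittest.TextTestRunner(verbosity=2).run(suite)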