change / sglang / Commits / 9d33fcfb

Commit 9d33fcfb (unverified), authored Jul 18, 2025 by Zhiqiang Xie, committed via GitHub on Jul 18, 2025.

Hicache Storage Layer Prototype (#7704)

Parent: 7891bac1

Showing 9 changed files with 714 additions and 4 deletions:

  python/sglang/srt/managers/cache_controller.py   +241  -0
  python/sglang/srt/managers/scheduler.py           +14  -0
  python/sglang/srt/mem_cache/hicache_storage.py   +152  -0
  python/sglang/srt/mem_cache/hiradix_cache.py     +179  -4
  python/sglang/srt/mem_cache/memory_pool_host.py   +38  -0
  python/sglang/srt/mem_cache/radix_cache.py        +26  -0
  python/sglang/srt/server_args.py                   +8  -0
  test/srt/run_suite.py                              +1  -0
  test/srt/test_hicache_storage.py                  +55  -0
python/sglang/srt/managers/cache_controller.py

@@ -25,6 +25,8 @@ if TYPE_CHECKING:
     from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
     from sglang.srt.mem_cache.memory_pool_host import HostKVCache
 
+from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
+
 logger = logging.getLogger(__name__)

@@ -159,6 +161,57 @@ class TransferBuffer:
         self.buffers.queue.clear()
 
 
+class StorageOperation:
+    counter = 0
+
+    def __init__(
+        self,
+        host_indices: torch.Tensor,
+        token_ids: List[int],
+        last_hash: Optional[str] = None,
+    ):
+        self.host_indices = host_indices
+        self.token_ids = token_ids
+        self.last_hash = last_hash
+        self.completed_tokens = 0
+        self.hash_value = []
+
+        self.id = StorageOperation.counter
+        StorageOperation.counter += 1
+
+    def __lt__(self, other: "StorageOperation"):
+        return self.id < other.id
+
+
+class PrefetchOperation(StorageOperation):
+    def __init__(
+        self,
+        request_id: str,
+        host_indices: torch.Tensor,
+        token_ids: List[int],
+        last_hash: Optional[str] = None,
+    ):
+        self.request_id = request_id
+
+        self._done_flag = False
+        self._lock = threading.Lock()
+
+        super().__init__(host_indices, token_ids, last_hash)
+
+    def increment(self, num_tokens: int):
+        with self._lock:
+            if self._done_flag:
+                return
+            self.completed_tokens += num_tokens
+
+    def mark_done(self):
+        with self._lock:
+            self._done_flag = True
+
+    def is_done(self) -> bool:
+        return self._done_flag
+
+
 class HiCacheController:
 
     def __init__(

@@ -169,6 +222,8 @@ class HiCacheController:
         load_cache_event: threading.Event = None,
         write_policy: str = "write_through_selective",
         io_backend: str = "",
+        storage_backend: Optional[str] = None,
+        prefetch_threshold: int = 256,
     ):
         self.mem_pool_device_allocator = token_to_kv_pool_allocator
         self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()

@@ -186,6 +241,19 @@ class HiCacheController:
         else:
             self.io_backend = io_backend
 
+        self.enable_storage = False
+        # todo: move backend initialization to storage backend module
+        if storage_backend is not None:
+            if storage_backend == "file":
+                self.storage_backend = HiCacheFile()
+                self.enable_storage = True
+                # todo: threshold policy for prefetching
+                self.prefetch_threshold = prefetch_threshold
+            else:
+                raise NotImplementedError(
+                    f"Unsupported storage backend: {storage_backend}"
+                )
+
         self.load_cache_event = load_cache_event
         self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
         self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)

@@ -218,9 +286,26 @@ class HiCacheController:
         self.load_thread = threading.Thread(
             target=self.load_thread_func_layer_by_layer, daemon=True
         )
         self.write_thread.start()
         self.load_thread.start()
 
+        if self.enable_storage:
+            self.prefetch_thread = threading.Thread(
+                target=self.prefetch_thread_func, daemon=True
+            )
+            self.backup_thread = threading.Thread(
+                target=self.backup_thread_func, daemon=True
+            )
+            self.prefetch_queue = Queue()
+            self.backup_queue = Queue()
+
+            self.prefetch_revoke_queue = Queue()
+            self.ack_backup_queue = Queue()
+
+            self.prefetch_thread.start()
+            self.backup_thread.start()
+
     def reset(self):
         self.stop_event.set()
         self.write_thread.join()

@@ -232,6 +317,13 @@ class HiCacheController:
         self.load_buffer.clear()
         self.ack_write_queue.queue.clear()
         self.ack_load_queue.queue.clear()
+        if self.enable_storage:
+            self.prefetch_thread.join()
+            self.backup_thread.join()
+            self.prefetch_queue.queue.clear()
+            self.backup_queue.queue.clear()
+            self.prefetch_revoke_queue.queue.clear()
+            self.ack_backup_queue.queue.clear()
 
         self.write_thread = threading.Thread(
             target=self.write_thread_func_direct, daemon=True

@@ -243,6 +335,16 @@ class HiCacheController:
         self.write_thread.start()
         self.load_thread.start()
 
+        if self.enable_storage:
+            self.prefetch_thread = threading.Thread(
+                target=self.prefetch_thread_func, daemon=True
+            )
+            self.backup_thread = threading.Thread(
+                target=self.backup_thread_func, daemon=True
+            )
+            self.prefetch_thread.start()
+            self.backup_thread.start()
+
     def write(
         self,
         device_indices: torch.Tensor,

@@ -383,3 +485,142 @@ class HiCacheController:
                 raise ValueError(
                     f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
                 )
+
+    def prefetch(
+        self,
+        request_id: str,
+        host_indices: torch.Tensor,
+        new_input_tokens: List[int],
+        last_hash: Optional[str] = None,
+    ) -> PrefetchOperation:
+        """
+        Prefetch KV caches from storage backend to host memory.
+        """
+        operation = PrefetchOperation(
+            request_id, host_indices, new_input_tokens, last_hash
+        )
+        self.prefetch_queue.put(operation)
+        return operation
+
+    def terminate_prefetch(self, operation):
+        operation.mark_done()
+        return operation.completed_tokens, operation.hash_value
+
+    def prefetch_io_aux_func(self):
+        """
+        Auxiliary function conducting IO operations for prefetching.
+        """
+        while not self.stop_event.is_set():
+            try:
+                operation = self.prefetch_buffer.get(block=True, timeout=1)
+                for h in operation.hash_value:
+                    page_data = self.storage_backend.get(h)
+                    if page_data is None:
+                        logger.warning(
+                            f"Prefetch operation {operation.request_id} failed to retrieve page {h}."
+                        )
+                        break
+                    self.mem_pool_host.set_from_flat_data_page(
+                        operation.host_indices[operation.completed_tokens],
+                        page_data,
+                    )
+                    operation.increment(self.page_size)
+                    if operation.is_done():
+                        # operation terminated by controller, release pre-allocated memory
+                        self.mem_pool_host.free(
+                            operation.host_indices[operation.completed_tokens :]
+                        )
+                        break
+            except Empty:
+                continue
+
+    def prefetch_thread_func(self):
+        """
+        Manage prefetching operations from storage backend to host memory.
+        """
+        self.prefetch_buffer = Queue()
+        aux_thread = threading.Thread(target=self.prefetch_io_aux_func, daemon=True)
+        aux_thread.start()
+        while (not self.stop_event.is_set()) or not self.prefetch_queue.empty():
+            try:
+                operation = self.prefetch_queue.get(block=True, timeout=1)
+                if operation is None:
+                    continue
+
+                last_hash = operation.last_hash
+                tokens_to_fetch = operation.token_ids
+
+                storage_hit_count = 0
+                remaining_tokens = len(tokens_to_fetch)
+                hash_value = []
+                while remaining_tokens >= self.page_size:
+                    last_hash = get_hash_str(
+                        tokens_to_fetch[
+                            storage_hit_count : storage_hit_count + self.page_size
+                        ],
+                        last_hash,
+                    )
+                    if self.storage_backend.exists(last_hash):
+                        storage_hit_count += self.page_size
+                        hash_value.append(last_hash)
+                        remaining_tokens -= self.page_size
+                    else:
+                        break
+
+                if storage_hit_count < self.prefetch_threshold:
+                    # do not prefetch if the benefit is too small
+                    self.prefetch_revoke_queue.put(operation.request_id)
+                else:
+                    operation.hash_value = hash_value
+                    logger.debug(
+                        f"Prefetching {len(hash_value)} pages for request {operation.request_id}."
+                    )
+                    self.prefetch_buffer.put(operation)
+            except Empty:
+                continue
+
+    def write_storage(
+        self,
+        host_indices: torch.Tensor,
+        token_ids: List[int],
+        last_hash: Optional[str] = None,
+    ) -> int:
+        """
+        Write KV caches from host memory to storage backend.
+        """
+        operation = StorageOperation(host_indices, token_ids, last_hash)
+        self.backup_queue.put(operation)
+        return operation.id
+
+    def backup_thread_func(self):
+        """
+        Manage backup operations from host memory to storage backend.
+        """
+        while not self.stop_event.is_set():
+            try:
+                operation = self.backup_queue.get(block=True, timeout=1)
+                if operation is None:
+                    continue
+
+                last_hash = operation.last_hash
+                tokens_to_backup = operation.token_ids
+
+                for i in range(0, len(tokens_to_backup), self.page_size):
+                    last_hash = get_hash_str(
+                        tokens_to_backup[i : i + self.page_size], last_hash
+                    )
+                    # todo, handle failures in storage backend
+                    self.storage_backend.set(
+                        last_hash,
+                        self.mem_pool_host.get_flat_data_page(
+                            operation.host_indices[i]
+                        ),
+                    )
+                    operation.completed_tokens += self.page_size
+                    operation.hash_value.append(last_hash)
+
+                self.ack_backup_queue.put((operation.id, operation.hash_value))
+            except Empty:
+                continue
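The increment/mark_done pair in PrefetchOperation is the only synchronization point between the prefetch IO thread and the controller terminating an operation. A minimal standalone sketch (not part of the commit) of that contract:

import threading

class Op:
    """Toy stand-in for PrefetchOperation's termination protocol."""

    def __init__(self):
        self.completed_tokens = 0
        self._done_flag = False
        self._lock = threading.Lock()

    def increment(self, num_tokens: int):
        # IO thread: advance progress, but never after termination
        with self._lock:
            if self._done_flag:
                return
            self.completed_tokens += num_tokens

    def mark_done(self):
        # controller: freeze completed_tokens at its current value
        with self._lock:
            self._done_flag = True

op = Op()
io_thread = threading.Thread(target=lambda: [op.increment(64) for _ in range(1000)])
io_thread.start()
op.mark_done()  # from here on, completed_tokens cannot advance
io_thread.join()
assert op.completed_tokens % 64 == 0  # always a whole number of pages

Because terminate_prefetch reads completed_tokens after mark_done, the controller can safely free all host slots past that index, exactly as prefetch_io_aux_func does above.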
python/sglang/srt/managers/scheduler.py

@@ -262,6 +262,7 @@ class Scheduler(
         )
         self.gpu_id = gpu_id
         self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
+        self.enable_hicache_storage = server_args.hicache_storage_backend is not None
         self.page_size = server_args.page_size
         self.dp_size = server_args.dp_size
         self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (

@@ -614,6 +615,7 @@ class Scheduler(
                     == "fa3"  # hot fix for incompatibility
                     else server_args.hicache_io_backend
                 ),
+                hicache_storage_backend=server_args.hicache_storage_backend,
             )
             self.tp_worker.register_hicache_layer_transfer_counter(
                 self.tree_cache.cache_controller.layer_done_counter

@@ -1258,6 +1260,15 @@ class Scheduler(
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
             self.disagg_decode_prealloc_queue.add(req)
         else:
+            if self.enable_hicache_storage:
+                req.init_next_round_input(self.tree_cache)
+                last_hash = req.last_host_node.get_last_hash_value()
+                matched_len = len(req.prefix_indices) + req.host_hit_length
+                if (matched_len > 0 and last_hash is not None) or matched_len == 0:
+                    new_input_tokens = req.fill_ids[matched_len:]
+                    self.tree_cache.prefetch_from_storage(
+                        req.rid, req.last_host_node, new_input_tokens, last_hash
+                    )
             self.waiting_queue.append(req)
 
     def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):

@@ -1731,6 +1742,9 @@ class Scheduler(
                 self.running_batch.batch_is_full = True
                 break
 
+            if self.enable_hicache_storage:
+                self.tree_cache.check_prefetch_progress(req.rid)
+
             req.init_next_round_input(self.tree_cache)
             res = adder.add_one_req(req, has_chunked_req=(self.chunked_req is not None))
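The condition guarding the prefetch call above is worth spelling out: a storage prefetch is issued either on a cold start (nothing matched on device or host) or when the matched prefix ends at a page whose hash is known, so the storage key chain can be extended from it. A standalone sketch (not part of the commit; should_prefetch is a hypothetical helper) of the same logic:

from typing import Optional

def should_prefetch(matched_len: int, last_hash: Optional[str]) -> bool:
    # mirrors: (matched_len > 0 and last_hash is not None) or matched_len == 0
    return (matched_len > 0 and last_hash is not None) or matched_len == 0

assert should_prefetch(0, None)           # cold start: fetch from the root
assert should_prefetch(128, "ab12cd")     # continue an existing hash chain
assert not should_prefetch(128, None)     # matched prefix has no hash chain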
python/sglang/srt/mem_cache/hicache_storage.py (new file, mode 100644)

import hashlib
import logging
import os
from abc import ABC, abstractmethod
from typing import List, Optional

import torch

logger = logging.getLogger(__name__)


def get_hash_str(token_ids: List[int], prior_hash: Optional[str] = None) -> str:
    hasher = hashlib.sha256()

    if prior_hash:
        hasher.update(bytes.fromhex(prior_hash))

    for t in token_ids:
        hasher.update(t.to_bytes(4, byteorder="little", signed=False))

    return hasher.hexdigest()


class HiCacheStorage(ABC):
    """
    HiCacheStorage is a class that provides a generic key-value interface for storing and retrieving KV cache.
    It abstracts the underlying storage mechanism, allowing different implementations to be used.
    """

    # todo, translate tensor object access for different TP ranks
    # potentially pass model and TP configs into storage backend
    # todo, the page size of the storage backend does not have to match that of the host memory pool

    @abstractmethod
    def get(
        self, key: str, target_location: Optional[torch.Tensor] = None
    ) -> torch.Tensor | None:
        """
        Retrieve the value associated with the given key.
        Returns None if the key does not exist.
        """
        pass

    @abstractmethod
    def batch_get(
        self,
        keys: List[str],
        target_locations: Optional[List[torch.Tensor]] = None,
    ) -> List[torch.Tensor | None]:
        """
        Retrieve values for multiple keys.
        Returns a list of tensors or None for each key.
        """
        pass

    @abstractmethod
    def set(self, key, value) -> bool:
        """
        Store the value associated with the given key.
        Returns True if the operation was successful, False otherwise.
        """
        pass

    @abstractmethod
    def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
        """
        Store multiple key-value pairs.
        Returns True if all operations were successful, False otherwise.
        """
        pass

    @abstractmethod
    def exists(self, key: str) -> bool:
        """
        Check if the key exists in the storage.
        Returns True if the key exists, False otherwise.
        """
        pass


class HiCacheFile(HiCacheStorage):

    def __init__(self, file_path: str = "/tmp/hicache"):
        self.file_path = file_path
        if not os.path.exists(self.file_path):
            os.makedirs(self.file_path)
            logger.info(f"Created HiCacheFile storage directory at {self.file_path}")

    def get(
        self, key: str, target_location: Optional[torch.Tensor] = None
    ) -> torch.Tensor | None:
        tensor_path = os.path.join(self.file_path, f"{key}.bin")
        try:
            # todo: fixing the target_location logic to enable in-place loading
            loaded_tensor = torch.load(tensor_path)
            if isinstance(loaded_tensor, torch.Tensor):
                return loaded_tensor
            else:
                logger.error(f"Loaded data for key {key} is not a tensor.")
                return None
        except FileNotFoundError:
            return None

    def batch_get(
        self,
        keys: List[str],
        target_locations: Optional[List[torch.Tensor]] = None,
    ) -> List[torch.Tensor | None]:
        return [
            self.get(key, target_location)
            for key, target_location in zip(
                keys, target_locations or [None] * len(keys)
            )
        ]

    def set(self, key: str, value: torch.Tensor) -> bool:
        tensor_path = os.path.join(self.file_path, f"{key}.bin")
        if self.exists(key):
            logger.debug(f"Key {key} already exists. Skipped.")
            return True
        try:
            torch.save(value, tensor_path)
            return True
        except Exception as e:
            logger.error(f"Failed to save tensor {key}: {e}")
            return False

    def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
        for key, value in zip(keys, values):
            if not self.set(key, value):
                return False
        return True

    def exists(self, key: str) -> bool:
        tensor_path = os.path.join(self.file_path, f"{key}.bin")
        return os.path.exists(tensor_path)

    def delete(self, key: str) -> None:
        tensor_path = os.path.join(self.file_path, f"{key}.bin")
        try:
            os.remove(tensor_path)
        except FileNotFoundError:
            logger.warning(f"Key {key} does not exist. Cannot delete.")
            return

    def clear(self) -> None:
        try:
            for filename in os.listdir(self.file_path):
                file_path = os.path.join(self.file_path, filename)
                if os.path.isfile(file_path):
                    os.remove(file_path)
            logger.info("Cleared all entries in HiCacheFile storage.")
        except Exception as e:
            logger.error(f"Failed to clear HiCacheFile storage: {e}")
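A standalone usage sketch (not part of the commit) of the new backend: page keys are chained SHA-256 digests, so each key commits to every preceding token and a key hit is an implicit prefix match. The directory path and token values are arbitrary examples:

import torch
from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str

storage = HiCacheFile("/tmp/hicache_demo")  # example path

page_size = 4
tokens = [11, 12, 13, 14, 21, 22, 23, 24]
last_hash = None
for i in range(0, len(tokens), page_size):
    # each page's key is chained off the previous page's key
    last_hash = get_hash_str(tokens[i : i + page_size], last_hash)
    # store a dummy page tensor under the chained key
    storage.set(last_hash, torch.zeros(16, dtype=torch.float16))

assert storage.exists(last_hash)
page = storage.get(last_hash)  # returns the saved tensor, or None on a miss
storage.clear()

Changing any token in the first page changes the second page's key as well, which is what makes exists() a safe prefix-hit test in prefetch_thread_func.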
python/sglang/srt/mem_cache/hiradix_cache.py

@@ -35,6 +35,7 @@ class HiRadixCache(RadixCache):
         hicache_size: int,
         hicache_write_policy: str,
         hicache_io_backend: str,
+        hicache_storage_backend: Optional[str] = None,
     ):
         self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
         if isinstance(self.kv_cache, MHATokenToKVPool):

@@ -49,6 +50,9 @@ class HiRadixCache(RadixCache):
             raise ValueError("HiRadixCache only supports MHA and MLA for now")
         self.tp_group = tp_cache_group
+        self.enable_storage = hicache_storage_backend is not None
+        # todo: customizable storage prefetch threshold
+        self.prefetch_threshold = 256
 
         self.load_cache_event = threading.Event()
         self.cache_controller = HiCacheController(

@@ -58,16 +62,22 @@ class HiRadixCache(RadixCache):
             load_cache_event=self.load_cache_event,
             write_policy=hicache_write_policy,
             io_backend=hicache_io_backend,
+            storage_backend=hicache_storage_backend,
+            prefetch_threshold=self.prefetch_threshold,
         )
 
         # record the nodes with ongoing write through
         self.ongoing_write_through = {}
         # record the node segments with ongoing load back
         self.ongoing_load_back = {}
+        # record the ongoing prefetch requests
+        self.ongoing_prefetch = {}
+        self.ongoing_backup = {}
         # todo: dynamically adjust the threshold
         self.write_through_threshold = (
             1 if hicache_write_policy == "write_through" else 3
         )
+        self.write_through_threshold_storage = 3
         self.load_back_threshold = 10
         super().__init__(
             req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False

@@ -108,13 +118,30 @@ class HiRadixCache(RadixCache):
         return len(host_indices)
 
+    def write_backup_storage(self, node: TreeNode):
+        operation_id = self.cache_controller.write_storage(
+            node.host_value, node.key, node.parent.get_last_hash_value()
+        )
+        self.ongoing_backup[operation_id] = node
+        node.protect_host()
+
     def inc_hit_count(self, node: TreeNode):
-        if node.backuped or self.cache_controller.write_policy == "write_back":
+        if self.cache_controller.write_policy == "write_back":
             return
         node.hit_count += 1
-        if node.hit_count >= self.write_through_threshold:
-            self.write_backup(node)
-            node.hit_count = 0
+
+        if not node.backuped:
+            if node.hit_count >= self.write_through_threshold:
+                # write to host if the node is not backuped
+                self.write_backup(node)
+        else:
+            if (
+                self.enable_storage
+                and (not node.backuped_storage)
+                and node.hit_count >= self.write_through_threshold_storage
+            ):
+                # if the node is backuped on host memory but not on storage
+                self.write_backup_storage(node)
 
     def writing_check(self, write_back=False):
         if write_back:

@@ -221,6 +248,10 @@ class HiRadixCache(RadixCache):
             if not x.evicted:
                 continue
 
+            # node is protected from eviction as it has an ongoing prefetch or backup to storage
+            if x.host_ref_counter > 0:
+                continue
+
             num_evicted += self.cache_controller.evict_host(x.host_value)
 
             for k, v in x.parent.children.items():

@@ -314,6 +345,85 @@ class HiRadixCache(RadixCache):
     def check_hicache_events(self):
         self.writing_check()
         self.loading_check()
+        if self.enable_storage:
+            self.check_revoked_prefetch()
+            self.check_backup_progress()
+
+    def check_revoked_prefetch(self):
+        queue_size = torch.tensor(
+            self.cache_controller.prefetch_revoke_queue.qsize(), dtype=torch.int
+        )
+        if torch.distributed.get_world_size(group=self.tp_group) > 1:
+            # synchronize TP workers to make the same update to the hiradix cache
+            torch.distributed.all_reduce(
+                queue_size,
+                op=torch.distributed.ReduceOp.MIN,
+                group=self.tp_group,
+            )
+        for _ in range(queue_size.item()):
+            req_id = self.cache_controller.prefetch_revoke_queue.get()
+            if req_id in self.ongoing_prefetch:
+                last_host_node, _, host_indices, _ = self.ongoing_prefetch[req_id]
+                last_host_node.release_host()
+                self.cache_controller.mem_pool_host.free(host_indices)
+                del self.ongoing_prefetch[req_id]
+
+    def check_backup_progress(self):
+        queue_size = torch.tensor(
+            self.cache_controller.ack_backup_queue.qsize(), dtype=torch.int
+        )
+        if torch.distributed.get_world_size(group=self.tp_group) > 1:
+            # synchronize TP workers to make the same update to the hiradix cache
+            torch.distributed.all_reduce(
+                queue_size,
+                op=torch.distributed.ReduceOp.MIN,
+                group=self.tp_group,
+            )
+        for _ in range(queue_size.item()):
+            ack_id, hash_value = self.cache_controller.ack_backup_queue.get()
+            self.ongoing_backup[ack_id].hash_value = hash_value
+            self.ongoing_backup[ack_id].release_host()
+            del self.ongoing_backup[ack_id]
+
+    def check_prefetch_progress(self, req_id: str):
+        if req_id not in self.ongoing_prefetch:
+            # there is no ongoing prefetch for this request, or it has been revoked
+            return
+
+        # todo: more policies for prefetch progress, such as timeout
+        # the current policy is to prefetch with best effort and terminate when queuing is over
+        last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[
+            req_id
+        ]
+        completed_tokens, hash_value = self.cache_controller.terminate_prefetch(
+            operation
+        )
+        logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")
+
+        min_completed_tokens = torch.tensor(completed_tokens, dtype=torch.int)
+        if torch.distributed.get_world_size(group=self.tp_group) > 1:
+            # synchronize TP workers to make the same update to the hiradix cache
+            torch.distributed.all_reduce(
+                min_completed_tokens,
+                op=torch.distributed.ReduceOp.MIN,
+                group=self.tp_group,
+            )
+        min_completed_tokens = min_completed_tokens.item()
+        fetched_token_ids = token_ids[:min_completed_tokens]
+        written_indices = host_indices[:min_completed_tokens]
+        matched_length = self._insert_helper_host(
+            last_host_node,
+            fetched_token_ids,
+            written_indices,
+            hash_value[:min_completed_tokens],
+        )
+
+        self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
+        self.cache_controller.mem_pool_host.free(
+            host_indices[min_completed_tokens:completed_tokens]
+        )
+        last_host_node.release_host()
+        del self.ongoing_prefetch[req_id]
+
     def match_prefix(self, key: List[int], **kwargs):
         empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)

@@ -348,6 +458,71 @@ class HiRadixCache(RadixCache):
             host_hit_length=host_hit_length,
         )
 
+    def prefetch_from_storage(
+        self,
+        req_id: str,
+        last_host_node: TreeNode,
+        new_input_tokens: List[int],
+        last_hash: Optional[str] = None,
+    ):
+        if not self.enable_storage or len(new_input_tokens) < self.prefetch_threshold:
+            return
+
+        last_host_node.protect_host()
+        host_indices = self.cache_controller.mem_pool_host.alloc(len(new_input_tokens))
+        if host_indices is None:
+            self.evict_host(len(new_input_tokens))
+            host_indices = self.cache_controller.mem_pool_host.alloc(
+                len(new_input_tokens)
+            )
+        if host_indices is None:
+            last_host_node.release_host()
+            # insufficient host memory to prefetch
+            return
+        operation = self.cache_controller.prefetch(
+            req_id, host_indices, new_input_tokens, last_hash
+        )
+        self.ongoing_prefetch[req_id] = (
+            last_host_node,
+            new_input_tokens,
+            host_indices,
+            operation,
+        )
+
+    def _insert_helper_host(self, node: TreeNode, key: List, host_value, hash_value):
+        node.last_access_time = time.monotonic()
+        if len(key) == 0:
+            return 0
+
+        child_key = self.get_child_key_fn(key)
+
+        matched_length = 0
+        while len(key) > 0 and child_key in node.children.keys():
+            node = node.children[child_key]
+            node.last_access_time = time.monotonic()
+            prefix_len = self.key_match_fn(node.key, key)
+            key = key[prefix_len:]
+            host_value = host_value[prefix_len:]
+            hash_value = hash_value[prefix_len:]
+            matched_length += prefix_len
+
+            if prefix_len < len(node.key):
+                new_node = self._split_node(node.key, node, prefix_len)
+                node = new_node
+
+            if len(key):
+                child_key = self.get_child_key_fn(key)
+
+        if len(key):
+            new_node = TreeNode()
+            new_node.parent = node
+            new_node.key = key
+            new_node.value = None
+            new_node.host_value = host_value
+            new_node.hash_value = hash_value
+            node.children[child_key] = new_node
+        return matched_length
+
     def _match_prefix_helper(self, node: TreeNode, key: List):
         node.last_access_time = time.monotonic()
         child_key = self.get_child_key_fn(key)
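check_revoked_prefetch, check_backup_progress, and check_prefetch_progress all rely on the same MIN all-reduce pattern: every TP rank processes only as many queue entries (or prefetched tokens) as the slowest rank reports, so all ranks apply identical updates to their radix trees. A standalone sketch (not part of the commit; agree_on_min is a hypothetical helper) of that consensus step:

import torch
import torch.distributed as dist

def agree_on_min(local_count: int, group=None) -> int:
    """Return the minimum of local_count across all ranks in the group."""
    value = torch.tensor(local_count, dtype=torch.int)
    if dist.is_initialized() and dist.get_world_size(group=group) > 1:
        dist.all_reduce(value, op=dist.ReduceOp.MIN, group=group)
    return value.item()

# e.g. if rank A has prefetched 512 tokens and rank B only 448, both
# insert 448 tokens into the tree and free the host memory for the rest.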
python/sglang/srt/mem_cache/memory_pool_host.py

@@ -99,6 +99,20 @@ class HostKVCache(abc.ABC):
     def init_kv_buffer(self):
         raise NotImplementedError()
 
+    @abc.abstractmethod
+    def get_flat_data_page(self, index) -> torch.Tensor:
+        """
+        Get a flat data page from the host memory pool.
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
+        """
+        Set a flat data page to the host memory pool.
+        """
+        raise NotImplementedError()
+
     @synchronized()
     def clear(self):
         # Initialize memory states and tracking structures.

@@ -227,6 +241,19 @@ class MHATokenToKVPoolHost(HostKVCache):
             pin_memory=self.pin_memory,
         )
 
+    # todo, page-first memory layout
+    def get_flat_data_page(self, index) -> torch.Tensor:
+        return self.kv_buffer[:, :, index : index + self.page_size, :, :].flatten()
+
+    def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
+        self.kv_buffer[:, :, index : index + self.page_size, :, :] = data_page.reshape(
+            2,
+            self.layer_num,
+            self.page_size,
+            self.head_num,
+            self.head_dim,
+        )
+
     @property
     def k_buffer(self):
         return self.kv_buffer[0]

@@ -276,3 +303,14 @@ class MLATokenToKVPoolHost(HostKVCache):
             device=self.device,
             pin_memory=self.pin_memory,
         )
+
+    def get_flat_data_page(self, index) -> torch.Tensor:
+        return self.kv_buffer[:, index : index + self.page_size, :, :].flatten()
+
+    def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
+        self.kv_buffer[:, index : index + self.page_size, :, :] = data_page.reshape(
+            self.layer_num,
+            self.page_size,
+            1,
+            self.kv_lora_rank + self.qk_rope_head_dim,
+        )
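A standalone sketch (not part of the commit) of the round trip between get_flat_data_page and set_from_flat_data_page for the MHA host layout [2, layer_num, tokens, head_num, head_dim], with toy sizes:

import torch

layer_num, page_size, head_num, head_dim = 2, 4, 8, 16
# K and V stacked along the leading dimension of size 2
kv_buffer = torch.randn(2, layer_num, 64, head_num, head_dim)

index = 32  # page-aligned token index
flat = kv_buffer[:, :, index : index + page_size, :, :].flatten()
# ... `flat` is what the storage backend stores and later returns ...
restored = flat.reshape(2, layer_num, page_size, head_num, head_dim)
assert torch.equal(kv_buffer[:, :, index : index + page_size, :, :], restored)

Flattening keeps each page a single contiguous tensor, which is what HiCacheFile saves and loads as one object per key.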
python/sglang/srt/mem_cache/radix_cache.py

@@ -55,8 +55,13 @@ class TreeNode:
         self.hit_count = 0
         # indicating the node is loading KV cache from host
         self.loading = False
+        # indicating the node is locked to protect from eviction
+        # incremented when the node is referenced by a storage operation
+        self.host_ref_counter = 0
         # store the host indices of KV cache
         self.host_value: Optional[torch.Tensor] = None
+        # store the hash value of each page
+        self.hash_value: Optional[List[str]] = None
 
         self.id = TreeNode.counter if id is None else id
         TreeNode.counter += 1

@@ -69,6 +74,27 @@ class TreeNode:
     def backuped(self):
         return self.host_value is not None
 
+    @property
+    def backuped_storage(self):
+        return self.hash_value is not None and len(self.hash_value) > 0
+
+    def protect_host(self):
+        """Protect the host value from eviction."""
+        self.host_ref_counter += 1
+
+    def release_host(self):
+        """Release the host value, allowing it to be evicted."""
+        if self.host_ref_counter > 0:
+            self.host_ref_counter -= 1
+        else:
+            raise RuntimeError("Host reference counter is already zero.")
+
+    def get_last_hash_value(self) -> Optional[str]:
+        """Return the hash value of the last page in this node."""
+        if self.hash_value is None or len(self.hash_value) == 0:
+            return None
+        return self.hash_value[-1]
+
     def __lt__(self, other: "TreeNode"):
         return self.last_access_time < other.last_access_time
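get_last_hash_value is what lets HiRadixCache.write_backup_storage chain a child node's first page hash off its parent's last one, so keys stay prefix-aware across node boundaries. A standalone sketch (not part of the commit) with arbitrary token values:

from sglang.srt.mem_cache.hicache_storage import get_hash_str

page_size = 2
parent_pages = [[1, 2], [3, 4]]  # pages hashed when the parent was backed up

last_hash = None
for page in parent_pages:
    last_hash = get_hash_str(page, last_hash)
parent_last_hash = last_hash  # what parent.get_last_hash_value() would return

# the child's first page continues the chain, mirroring
# write_storage(node.host_value, node.key, node.parent.get_last_hash_value())
child_first_hash = get_hash_str([5, 6], parent_last_hash)
# identical to hashing [1, 2], [3, 4], [5, 6] page by page from the root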
python/sglang/srt/server_args.py

@@ -222,6 +222,7 @@ class ServerArgs:
     hicache_size: int = 0
     hicache_write_policy: str = "write_through_selective"
     hicache_io_backend: str = ""
+    hicache_storage_backend: Optional[str] = None
     flashinfer_mla_disable_ragged: bool = False
     disable_shared_experts_fusion: bool = False
     disable_chunked_prefix_cache: bool = False

@@ -1604,6 +1605,13 @@ class ServerArgs:
             default=ServerArgs.hicache_io_backend,
             help="The IO backend for KV cache transfer between CPU and GPU",
         )
+        parser.add_argument(
+            "--hicache-storage-backend",
+            type=str,
+            choices=["file"],  # todo, mooncake
+            default=ServerArgs.hicache_storage_backend,
+            help="The storage backend for hierarchical KV cache.",
+        )
         parser.add_argument(
             "--flashinfer-mla-disable-ragged",
             action="store_true",
test/srt/run_suite.py

@@ -64,6 +64,7 @@ suites = {
         TestFile("test_fused_moe.py", 30),
         TestFile("test_hicache.py", 116),
         TestFile("test_hicache_mla.py", 127),
+        TestFile("test_hicache_storage.py", 127),
         TestFile("test_hidden_states.py", 55),
         TestFile("test_int8_kernel.py", 8),
         TestFile("test_input_embeddings.py", 38),
test/srt/test_hicache_storage.py (new file, mode 100644)

import unittest
from types import SimpleNamespace

from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    popen_launch_server,
)


class TestHiCache(CustomTestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--enable-hierarchical-cache",
                "--mem-fraction-static",
                "0.7",
                "--hicache-size",
                "100",
                "--page-size",
                "64",
                "--hicache-storage-backend",
                "file",
            ],
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_mmlu(self):
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )

        metrics = run_eval(args)
        self.assertGreaterEqual(metrics["score"], 0.65)


if __name__ == "__main__":
    unittest.main()