sglang · commit 9d33fcfb

Hicache Storage Layer Prototype (#7704)

Unverified commit, authored by Zhiqiang Xie on Jul 18, 2025; committed via GitHub on Jul 18, 2025.
Parent: 7891bac1
Changes: 9 files, +714 additions, −4 deletions

python/sglang/srt/managers/cache_controller.py    +241 −0
python/sglang/srt/managers/scheduler.py           +14  −0
python/sglang/srt/mem_cache/hicache_storage.py    +152 −0
python/sglang/srt/mem_cache/hiradix_cache.py      +179 −4
python/sglang/srt/mem_cache/memory_pool_host.py   +38  −0
python/sglang/srt/mem_cache/radix_cache.py        +26  −0
python/sglang/srt/server_args.py                  +8   −0
test/srt/run_suite.py                             +1   −0
test/srt/test_hicache_storage.py                  +55  −0
python/sglang/srt/managers/cache_controller.py

@@ -25,6 +25,8 @@ if TYPE_CHECKING:
     from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
     from sglang.srt.mem_cache.memory_pool_host import HostKVCache
+from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
+
 logger = logging.getLogger(__name__)

@@ -159,6 +161,57 @@ class TransferBuffer:
         self.buffers.queue.clear()


+class StorageOperation:
+    counter = 0
+
+    def __init__(
+        self,
+        host_indices: torch.Tensor,
+        token_ids: List[int],
+        last_hash: Optional[str] = None,
+    ):
+        self.host_indices = host_indices
+        self.token_ids = token_ids
+        self.last_hash = last_hash
+        self.completed_tokens = 0
+        self.hash_value = []
+
+        self.id = StorageOperation.counter
+        StorageOperation.counter += 1
+
+    def __lt__(self, other: "StorageOperation"):
+        return self.id < other.id
+
+
+class PrefetchOperation(StorageOperation):
+    def __init__(
+        self,
+        request_id: str,
+        host_indices: torch.Tensor,
+        token_ids: List[int],
+        last_hash: Optional[str] = None,
+    ):
+        self.request_id = request_id
+
+        self._done_flag = False
+        self._lock = threading.Lock()
+
+        super().__init__(host_indices, token_ids, last_hash)
+
+    def increment(self, num_tokens: int):
+        with self._lock:
+            if self._done_flag:
+                return
+            self.completed_tokens += num_tokens
+
+    def mark_done(self):
+        with self._lock:
+            self._done_flag = True
+
+    def is_done(self) -> bool:
+        return self._done_flag
+
+
 class HiCacheController:

     def __init__(

@@ -169,6 +222,8 @@ class HiCacheController:
         load_cache_event: threading.Event = None,
         write_policy: str = "write_through_selective",
         io_backend: str = "",
+        storage_backend: Optional[str] = None,
+        prefetch_threshold: int = 256,
     ):
         self.mem_pool_device_allocator = token_to_kv_pool_allocator
         self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()

@@ -186,6 +241,19 @@ class HiCacheController:
         else:
             self.io_backend = io_backend

+        self.enable_storage = False
+        # todo: move backend initialization to storage backend module
+        if storage_backend is not None:
+            if storage_backend == "file":
+                self.storage_backend = HiCacheFile()
+                self.enable_storage = True
+                # todo: threshold policy for prefetching
+                self.prefetch_threshold = prefetch_threshold
+            else:
+                raise NotImplementedError(
+                    f"Unsupported storage backend: {storage_backend}"
+                )
+
         self.load_cache_event = load_cache_event
         self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
         self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)

@@ -218,9 +286,26 @@ class HiCacheController:
         self.load_thread = threading.Thread(
             target=self.load_thread_func_layer_by_layer, daemon=True
         )
         self.write_thread.start()
         self.load_thread.start()

+        if self.enable_storage:
+            self.prefetch_thread = threading.Thread(
+                target=self.prefetch_thread_func, daemon=True
+            )
+            self.backup_thread = threading.Thread(
+                target=self.backup_thread_func, daemon=True
+            )
+            self.prefetch_queue = Queue()
+            self.backup_queue = Queue()
+
+            self.prefetch_revoke_queue = Queue()
+            self.ack_backup_queue = Queue()
+
+            self.prefetch_thread.start()
+            self.backup_thread.start()
+
     def reset(self):
         self.stop_event.set()
         self.write_thread.join()

@@ -232,6 +317,13 @@ class HiCacheController:
         self.load_buffer.clear()
         self.ack_write_queue.queue.clear()
         self.ack_load_queue.queue.clear()
+        if self.enable_storage:
+            self.prefetch_thread.join()
+            self.backup_thread.join()
+            self.prefetch_queue.queue.clear()
+            self.backup_queue.queue.clear()
+            self.prefetch_revoke_queue.queue.clear()
+            self.ack_backup_queue.queue.clear()

         self.write_thread = threading.Thread(
             target=self.write_thread_func_direct, daemon=True

@@ -243,6 +335,16 @@ class HiCacheController:
         self.write_thread.start()
         self.load_thread.start()

+        if self.enable_storage:
+            self.prefetch_thread = threading.Thread(
+                target=self.prefetch_thread_func, daemon=True
+            )
+            self.backup_thread = threading.Thread(
+                target=self.backup_thread_func, daemon=True
+            )
+            self.prefetch_thread.start()
+            self.backup_thread.start()
+
     def write(
         self,
         device_indices: torch.Tensor,

@@ -383,3 +485,142 @@ class HiCacheController:
             raise ValueError(
                 f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
             )

+    def prefetch(
+        self,
+        request_id: str,
+        host_indices: torch.Tensor,
+        new_input_tokens: List[int],
+        last_hash: Optional[str] = None,
+    ) -> PrefetchOperation:
+        """
+        Prefetch KV caches from storage backend to host memory.
+        """
+        operation = PrefetchOperation(
+            request_id, host_indices, new_input_tokens, last_hash
+        )
+        self.prefetch_queue.put(operation)
+        return operation
+
+    def terminate_prefetch(self, operation):
+        operation.mark_done()
+        return operation.completed_tokens, operation.hash_value
+
+    def prefetch_io_aux_func(self):
+        """
+        Auxiliary function conducting IO operations for prefetching.
+        """
+        while not self.stop_event.is_set():
+            try:
+                operation = self.prefetch_buffer.get(block=True, timeout=1)
+                for h in operation.hash_value:
+                    page_data = self.storage_backend.get(h)
+                    if page_data is None:
+                        logger.warning(
+                            f"Prefetch operation {operation.request_id} failed to retrieve page {h}."
+                        )
+                        break
+                    self.mem_pool_host.set_from_flat_data_page(
+                        operation.host_indices[operation.completed_tokens],
+                        page_data,
+                    )
+                    operation.increment(self.page_size)
+                    if operation.is_done():
+                        # operation terminated by controller, release pre-allocated memory
+                        self.mem_pool_host.free(
+                            operation.host_indices[operation.completed_tokens :]
+                        )
+                        break
+            except Empty:
+                continue
+
+    def prefetch_thread_func(self):
+        """
+        Manage prefetching operations from storage backend to host memory.
+        """
+        self.prefetch_buffer = Queue()
+        aux_thread = threading.Thread(target=self.prefetch_io_aux_func, daemon=True)
+        aux_thread.start()
+        while (not self.stop_event.is_set()) or not self.prefetch_queue.empty():
+            try:
+                operation = self.prefetch_queue.get(block=True, timeout=1)
+                if operation is None:
+                    continue
+
+                last_hash = operation.last_hash
+                tokens_to_fetch = operation.token_ids
+
+                storage_hit_count = 0
+                remaining_tokens = len(tokens_to_fetch)
+                hash_value = []
+                while remaining_tokens >= self.page_size:
+                    last_hash = get_hash_str(
+                        tokens_to_fetch[
+                            storage_hit_count : storage_hit_count + self.page_size
+                        ],
+                        last_hash,
+                    )
+                    if self.storage_backend.exists(last_hash):
+                        storage_hit_count += self.page_size
+                        hash_value.append(last_hash)
+                        remaining_tokens -= self.page_size
+                    else:
+                        break
+
+                if storage_hit_count < self.prefetch_threshold:
+                    # not to prefetch if not enough benefits
+                    self.prefetch_revoke_queue.put(operation.request_id)
+                else:
+                    operation.hash_value = hash_value
+                    logger.debug(
+                        f"Prefetching {len(hash_value)} pages for request {operation.request_id}."
+                    )
+                    self.prefetch_buffer.put(operation)
+
+            except Empty:
+                continue
+
+    def write_storage(
+        self,
+        host_indices: torch.Tensor,
+        token_ids: List[int],
+        last_hash: Optional[str] = None,
+    ) -> int:
+        """
+        Write KV caches from host memory to storage backend.
+        """
+        operation = StorageOperation(host_indices, token_ids, last_hash)
+        self.backup_queue.put(operation)
+        return operation.id
+
+    def backup_thread_func(self):
+        """
+        Manage backup operations from host memory to storage backend.
+        """
+        while not self.stop_event.is_set():
+            try:
+                operation = self.backup_queue.get(block=True, timeout=1)
+                if operation is None:
+                    continue
+
+                last_hash = operation.last_hash
+                tokens_to_backup = operation.token_ids
+
+                for i in range(0, len(tokens_to_backup), self.page_size):
+                    last_hash = get_hash_str(
+                        tokens_to_backup[i : i + self.page_size], last_hash
+                    )
+                    # todo, handle failures in storage backend
+                    self.storage_backend.set(
+                        last_hash,
+                        self.mem_pool_host.get_flat_data_page(
+                            operation.host_indices[i]
+                        ),
+                    )
+                    operation.completed_tokens += self.page_size
+                    operation.hash_value.append(last_hash)
+
+                self.ack_backup_queue.put((operation.id, operation.hash_value))
+
+            except Empty:
+                continue
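
Note on the design above: every page written to storage is keyed by a hash that chains through all preceding pages, so a prefix lookup can walk page by page until the first miss. A minimal, self-contained sketch of that chaining (the page size and token values are illustrative, not from this commit):

    import hashlib
    from typing import List, Optional

    def get_hash_str(token_ids: List[int], prior_hash: Optional[str] = None) -> str:
        # Same scheme as hicache_storage.get_hash_str: each page key
        # commits to all preceding pages through the prior hash.
        hasher = hashlib.sha256()
        if prior_hash:
            hasher.update(bytes.fromhex(prior_hash))
        for t in token_ids:
            hasher.update(t.to_bytes(4, byteorder="little", signed=False))
        return hasher.hexdigest()

    page_size = 4
    tokens = list(range(10))  # two full pages plus two leftover tokens
    last_hash, keys = None, []
    for i in range(0, len(tokens) - len(tokens) % page_size, page_size):
        last_hash = get_hash_str(tokens[i : i + page_size], last_hash)
        keys.append(last_hash)
    # keys[0] identifies tokens[0:4]; keys[1] identifies tokens[0:8];
    # the trailing partial page is never hashed or stored.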
python/sglang/srt/managers/scheduler.py

@@ -262,6 +262,7 @@ class Scheduler(
         )
         self.gpu_id = gpu_id
         self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
+        self.enable_hicache_storage = server_args.hicache_storage_backend is not None
         self.page_size = server_args.page_size
         self.dp_size = server_args.dp_size
         self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (

@@ -614,6 +615,7 @@ class Scheduler(
                     == "fa3"  # hot fix for incompatibility
                     else server_args.hicache_io_backend
                 ),
+                hicache_storage_backend=server_args.hicache_storage_backend,
             )
             self.tp_worker.register_hicache_layer_transfer_counter(
                 self.tree_cache.cache_controller.layer_done_counter

@@ -1258,6 +1260,15 @@ class Scheduler(
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
             self.disagg_decode_prealloc_queue.add(req)
         else:
+            if self.enable_hicache_storage:
+                req.init_next_round_input(self.tree_cache)
+                last_hash = req.last_host_node.get_last_hash_value()
+                matched_len = len(req.prefix_indices) + req.host_hit_length
+                if (matched_len > 0 and last_hash is not None) or matched_len == 0:
+                    new_input_tokens = req.fill_ids[matched_len:]
+                    self.tree_cache.prefetch_from_storage(
+                        req.rid, req.last_host_node, new_input_tokens, last_hash
+                    )
             self.waiting_queue.append(req)

     def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):

@@ -1731,6 +1742,9 @@ class Scheduler(
                 self.running_batch.batch_is_full = True
                 break

+            if self.enable_hicache_storage:
+                self.tree_cache.check_prefetch_progress(req.rid)
+
             req.init_next_round_input(self.tree_cache)
             res = adder.add_one_req(req, has_chunked_req=(self.chunked_req is not None))
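
The prefetch trigger above submits only the part of the prompt that neither the device radix tree nor host memory already covers. A toy walk-through with assumed numbers (not taken from the commit):

    # illustrative values only
    prefix_indices_len = 192          # tokens matched in the device radix tree
    host_hit_length = 64              # additional tokens matched in host memory
    matched_len = prefix_indices_len + host_hit_length  # 256
    fill_ids = list(range(512))       # stand-in for the request's prompt ids
    new_input_tokens = fill_ids[matched_len:]  # only the uncovered suffix
    assert len(new_input_tokens) == 256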
python/sglang/srt/mem_cache/hicache_storage.py (new file, 0 → 100644)

import hashlib
import logging
import os
from abc import ABC, abstractmethod
from typing import List, Optional

import torch

logger = logging.getLogger(__name__)


def get_hash_str(token_ids: List[int], prior_hash: Optional[str] = None) -> str:
    hasher = hashlib.sha256()

    if prior_hash:
        hasher.update(bytes.fromhex(prior_hash))

    for t in token_ids:
        hasher.update(t.to_bytes(4, byteorder="little", signed=False))

    return hasher.hexdigest()


class HiCacheStorage(ABC):
    """
    HiCacheStorage is a class that provides a generic key-value interface for storing and retrieving KV cache.
    It abstracts the underlying storage mechanism, allowing different implementations to be used.
    """

    # todo, translate tensor object access for different TP ranks
    # potentially pass model and TP configs into storage backend
    # todo, the page size of the storage backend does not have to be the same as the host memory pool

    @abstractmethod
    def get(
        self, key: str, target_location: Optional[torch.Tensor] = None
    ) -> torch.Tensor | None:
        """
        Retrieve the value associated with the given key.
        Returns None if the key does not exist.
        """
        pass

    @abstractmethod
    def batch_get(
        self,
        keys: List[str],
        target_locations: Optional[List[torch.Tensor]] = None,
    ) -> List[torch.Tensor | None]:
        """
        Retrieve values for multiple keys.
        Returns a list of tensors or None for each key.
        """
        pass

    @abstractmethod
    def set(self, key, value) -> bool:
        """
        Store the value associated with the given key.
        Returns True if the operation was successful, False otherwise.
        """
        pass

    @abstractmethod
    def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
        """
        Store multiple key-value pairs.
        Returns True if all operations were successful, False otherwise.
        """
        pass

    @abstractmethod
    def exists(self, key: str) -> bool:
        """
        Check if the key exists in the storage.
        Returns True if the key exists, False otherwise.
        """
        pass


class HiCacheFile(HiCacheStorage):

    def __init__(self, file_path: str = "/tmp/hicache"):
        self.file_path = file_path
        if not os.path.exists(self.file_path):
            os.makedirs(self.file_path)
            logger.info(f"Created HiCacheFile storage directory at {self.file_path}")

    def get(
        self, key: str, target_location: Optional[torch.Tensor] = None
    ) -> torch.Tensor | None:
        tensor_path = os.path.join(self.file_path, f"{key}.bin")
        try:
            # todo: fixing the target_location logic to enable in-place loading
            loaded_tensor = torch.load(tensor_path)
            if isinstance(loaded_tensor, torch.Tensor):
                return loaded_tensor
            else:
                logger.error(f"Loaded data for key {key} is not a tensor.")
                return None
        except FileNotFoundError:
            return None

    def batch_get(
        self,
        keys: List[str],
        target_locations: Optional[List[torch.Tensor]] = None,
    ) -> List[torch.Tensor | None]:
        return [
            self.get(key, target_location)
            for key, target_location in zip(
                keys, target_locations or [None] * len(keys)
            )
        ]

    def set(self, key: str, value: torch.Tensor) -> bool:
        tensor_path = os.path.join(self.file_path, f"{key}.bin")
        if self.exists(key):
            logger.debug(f"Key {key} already exists. Skipped.")
            return True
        try:
            torch.save(value, tensor_path)
            return True
        except Exception as e:
            logger.error(f"Failed to save tensor {key}: {e}")
            return False

    def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
        for key, value in zip(keys, values):
            if not self.set(key, value):
                return False
        return True

    def exists(self, key: str) -> bool:
        tensor_path = os.path.join(self.file_path, f"{key}.bin")
        return os.path.exists(tensor_path)

    def delete(self, key: str) -> None:
        tensor_path = os.path.join(self.file_path, f"{key}.bin")
        try:
            os.remove(tensor_path)
        except FileNotFoundError:
            logger.warning(f"Key {key} does not exist. Cannot delete.")
            return

    def clear(self) -> None:
        try:
            for filename in os.listdir(self.file_path):
                file_path = os.path.join(self.file_path, filename)
                if os.path.isfile(file_path):
                    os.remove(file_path)
            logger.info("Cleared all entries in HiCacheFile storage.")
        except Exception as e:
            logger.error(f"Failed to clear HiCacheFile storage: {e}")
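
A quick usage sketch of the file backend above (the directory is a hypothetical example and the tensor an arbitrary stand-in for a flattened KV page):

    import torch
    from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str

    storage = HiCacheFile(file_path="/tmp/hicache-demo")
    key = get_hash_str([1, 2, 3, 4])       # page key chained from its tokens
    page = torch.randn(8)
    assert storage.set(key, page)          # persisted as /tmp/hicache-demo/<key>.bin
    assert storage.exists(key)
    assert torch.equal(storage.get(key), page)
    storage.clear()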
python/sglang/srt/mem_cache/hiradix_cache.py

@@ -35,6 +35,7 @@ class HiRadixCache(RadixCache):
         hicache_size: int,
         hicache_write_policy: str,
         hicache_io_backend: str,
+        hicache_storage_backend: Optional[str] = None,
     ):
         self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
         if isinstance(self.kv_cache, MHATokenToKVPool):

@@ -49,6 +50,9 @@ class HiRadixCache(RadixCache):
             raise ValueError(f"HiRadixCache only supports MHA and MLA yet")
         self.tp_group = tp_cache_group
+        self.enable_storage = hicache_storage_backend is not None
+        # todo: customizable storage prefetch threshold
+        self.prefetch_threshold = 256

         self.load_cache_event = threading.Event()
         self.cache_controller = HiCacheController(

@@ -58,16 +62,22 @@ class HiRadixCache(RadixCache):
             load_cache_event=self.load_cache_event,
             write_policy=hicache_write_policy,
             io_backend=hicache_io_backend,
+            storage_backend=hicache_storage_backend,
+            prefetch_threshold=self.prefetch_threshold,
         )

         # record the nodes with ongoing write through
         self.ongoing_write_through = {}
         # record the node segments with ongoing load back
         self.ongoing_load_back = {}
+        # record the ongoing prefetch requests
+        self.ongoing_prefetch = {}
+        self.ongoing_backup = {}
         # todo: dynamically adjust the threshold
         self.write_through_threshold = (
             1 if hicache_write_policy == "write_through" else 3
         )
+        self.write_through_threshold_storage = 3
         self.load_back_threshold = 10
         super().__init__(
             req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False

@@ -108,13 +118,30 @@ class HiRadixCache(RadixCache):
         return len(host_indices)

+    def write_backup_storage(self, node: TreeNode):
+        operation_id = self.cache_controller.write_storage(
+            node.host_value, node.key, node.parent.get_last_hash_value()
+        )
+        self.ongoing_backup[operation_id] = node
+        node.protect_host()
+
     def inc_hit_count(self, node: TreeNode):
-        if node.backuped or self.cache_controller.write_policy == "write_back":
+        if self.cache_controller.write_policy == "write_back":
             return
         node.hit_count += 1
-        if node.hit_count >= self.write_through_threshold:
-            self.write_backup(node)
-            node.hit_count = 0
+
+        if not node.backuped:
+            if node.hit_count >= self.write_through_threshold:
+                # write to host if the node is not backuped
+                self.write_backup(node)
+        else:
+            if (
+                self.enable_storage
+                and (not node.backuped_storage)
+                and node.hit_count >= self.write_through_threshold_storage
+            ):
+                # if the node is backuped on host memory but not on storage
+                self.write_backup_storage(node)

     def writing_check(self, write_back=False):
         if write_back:

@@ -221,6 +248,10 @@ class HiRadixCache(RadixCache):
             if not x.evicted:
                 continue

+            # node is protected from eviction as it has ongoing prefetch or backup to storage
+            if x.host_ref_counter > 0:
+                continue
+
             num_evicted += self.cache_controller.evict_host(x.host_value)

             for k, v in x.parent.children.items():

@@ -314,6 +345,85 @@ class HiRadixCache(RadixCache):
     def check_hicache_events(self):
         self.writing_check()
         self.loading_check()
+        if self.enable_storage:
+            self.check_revoked_prefetch()
+            self.check_backup_progress()
+
+    def check_revoked_prefetch(self):
+        queue_size = torch.tensor(
+            self.cache_controller.prefetch_revoke_queue.qsize(), dtype=torch.int
+        )
+        if torch.distributed.get_world_size(group=self.tp_group) > 1:
+            # synchronize TP workers to make the same update to hiradix cache
+            torch.distributed.all_reduce(
+                queue_size,
+                op=torch.distributed.ReduceOp.MIN,
+                group=self.tp_group,
+            )
+        for _ in range(queue_size.item()):
+            req_id = self.cache_controller.prefetch_revoke_queue.get()
+            if req_id in self.ongoing_prefetch:
+                last_host_node, _, host_indices, _ = self.ongoing_prefetch[req_id]
+                last_host_node.release_host()
+                self.cache_controller.mem_pool_host.free(host_indices)
+                del self.ongoing_prefetch[req_id]
+
+    def check_backup_progress(self):
+        queue_size = torch.tensor(
+            self.cache_controller.ack_backup_queue.qsize(), dtype=torch.int
+        )
+        if torch.distributed.get_world_size(group=self.tp_group) > 1:
+            # synchronize TP workers to make the same update to hiradix cache
+            torch.distributed.all_reduce(
+                queue_size,
+                op=torch.distributed.ReduceOp.MIN,
+                group=self.tp_group,
+            )
+        for _ in range(queue_size.item()):
+            ack_id, hash_value = self.cache_controller.ack_backup_queue.get()
+            self.ongoing_backup[ack_id].hash_value = hash_value
+            self.ongoing_backup[ack_id].release_host()
+            del self.ongoing_backup[ack_id]
+
+    def check_prefetch_progress(self, req_id: str):
+        if req_id not in self.ongoing_prefetch:
+            # there is no ongoing prefetch for this request or it has been revoked
+            return
+
+        # todo: more policies for prefetch progress such as timeout
+        # the current policy is to prefetch with best effort and terminate when queuing is over
+        last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[
+            req_id
+        ]
+        completed_tokens, hash_value = self.cache_controller.terminate_prefetch(
+            operation
+        )
+        logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")
+
+        min_completed_tokens = torch.tensor(completed_tokens, dtype=torch.int)
+        if torch.distributed.get_world_size(group=self.tp_group) > 1:
+            # synchronize TP workers to make the same update to hiradix cache
+            torch.distributed.all_reduce(
+                min_completed_tokens,
+                op=torch.distributed.ReduceOp.MIN,
+                group=self.tp_group,
+            )
+        min_completed_tokens = min_completed_tokens.item()
+        fetched_token_ids = token_ids[:min_completed_tokens]
+        written_indices = host_indices[:min_completed_tokens]
+        matched_length = self._insert_helper_host(
+            last_host_node,
+            fetched_token_ids,
+            written_indices,
+            hash_value[:min_completed_tokens],
+        )
+
+        self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
+        self.cache_controller.mem_pool_host.free(
+            host_indices[min_completed_tokens:completed_tokens]
+        )
+        last_host_node.release_host()
+        del self.ongoing_prefetch[req_id]
+
     def match_prefix(self, key: List[int], **kwargs):
         empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)

@@ -348,6 +458,71 @@ class HiRadixCache(RadixCache):
                 host_hit_length=host_hit_length,
             )

+    def prefetch_from_storage(
+        self,
+        req_id: str,
+        last_host_node: TreeNode,
+        new_input_tokens: List[int],
+        last_hash: Optional[str] = None,
+    ):
+        if not self.enable_storage or len(new_input_tokens) < self.prefetch_threshold:
+            return
+
+        last_host_node.protect_host()
+        host_indices = self.cache_controller.mem_pool_host.alloc(len(new_input_tokens))
+        if host_indices is None:
+            self.evict_host(len(new_input_tokens))
+            host_indices = self.cache_controller.mem_pool_host.alloc(
+                len(new_input_tokens)
+            )
+        if host_indices is None:
+            last_host_node.release_host()
+            # no sufficient host memory to prefetch
+            return
+        operation = self.cache_controller.prefetch(
+            req_id, host_indices, new_input_tokens, last_hash
+        )
+        self.ongoing_prefetch[req_id] = (
+            last_host_node,
+            new_input_tokens,
+            host_indices,
+            operation,
+        )
+
+    def _insert_helper_host(self, node: TreeNode, key: List, host_value, hash_value):
+        node.last_access_time = time.monotonic()
+        if len(key) == 0:
+            return 0
+
+        child_key = self.get_child_key_fn(key)
+
+        matched_length = 0
+        while len(key) > 0 and child_key in node.children.keys():
+            node = node.children[child_key]
+            node.last_access_time = time.monotonic()
+            prefix_len = self.key_match_fn(node.key, key)
+            key = key[prefix_len:]
+            host_value = host_value[prefix_len:]
+            hash_value = hash_value[prefix_len:]
+            matched_length += prefix_len
+
+            if prefix_len < len(node.key):
+                new_node = self._split_node(node.key, node, prefix_len)
+                node = new_node
+
+            if len(key):
+                child_key = self.get_child_key_fn(key)
+
+        if len(key):
+            new_node = TreeNode()
+            new_node.parent = node
+            new_node.key = key
+            new_node.value = None
+            new_node.host_value = host_value
+            new_node.hash_value = hash_value
+            node.children[child_key] = new_node
+        return matched_length
+
     def _match_prefix_helper(self, node: TreeNode, key: List):
         node.last_access_time = time.monotonic()
         child_key = self.get_child_key_fn(key)
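
The all_reduce with ReduceOp.MIN in the three check functions above is what keeps TP workers consistent: each rank only applies as much progress as every rank has reached. A single-process illustration of the invariant (plain Python, no process group; the numbers are made up):

    # per-rank prefetch progress in tokens, as each rank would report it
    completed_per_rank = [448, 512, 384]
    min_completed = min(completed_per_rank)  # what all_reduce(op=MIN) yields
    # every rank inserts exactly min_completed tokens into its radix tree, so
    # the trees stay identical across ranks; on each rank the tail
    # host_indices[min_completed:completed] is freed rather than inserted.
    assert min_completed == 384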
python/sglang/srt/mem_cache/memory_pool_host.py

@@ -99,6 +99,20 @@ class HostKVCache(abc.ABC):
     def init_kv_buffer(self):
         raise NotImplementedError()

+    @abc.abstractmethod
+    def get_flat_data_page(self, index) -> torch.Tensor:
+        """
+        Get a flat data page from the host memory pool.
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
+        """
+        Set a flat data page to the host memory pool.
+        """
+        raise NotImplementedError()
+
     @synchronized()
     def clear(self):
         # Initialize memory states and tracking structures.

@@ -227,6 +241,19 @@ class MHATokenToKVPoolHost(HostKVCache):
             pin_memory=self.pin_memory,
         )

+    # todo, page first memory layout
+    def get_flat_data_page(self, index) -> torch.Tensor:
+        return self.kv_buffer[:, :, index : index + self.page_size, :, :].flatten()
+
+    def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
+        self.kv_buffer[:, :, index : index + self.page_size, :, :] = data_page.reshape(
+            2,
+            self.layer_num,
+            self.page_size,
+            self.head_num,
+            self.head_dim,
+        )
+
     @property
     def k_buffer(self):
         return self.kv_buffer[0]

@@ -276,3 +303,14 @@ class MLATokenToKVPoolHost(HostKVCache):
             device=self.device,
             pin_memory=self.pin_memory,
         )
+
+    def get_flat_data_page(self, index) -> torch.Tensor:
+        return self.kv_buffer[:, index : index + self.page_size, :, :].flatten()
+
+    def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
+        self.kv_buffer[:, index : index + self.page_size, :, :] = data_page.reshape(
+            self.layer_num,
+            self.page_size,
+            1,
+            self.kv_lora_rank + self.qk_rope_head_dim,
+        )
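
A standalone round-trip check of the flat-page layout implemented above, using toy dimensions and a bare tensor in place of the real MHA host pool (shape 2 x layers x tokens x heads x head_dim; all sizes are illustrative):

    import torch

    layer_num, page_size, head_num, head_dim = 2, 4, 3, 8
    kv_buffer = torch.zeros(2, layer_num, 64, head_num, head_dim)

    def get_flat_data_page(index: int) -> torch.Tensor:
        # mirror of MHATokenToKVPoolHost.get_flat_data_page
        return kv_buffer[:, :, index : index + page_size, :, :].flatten()

    def set_from_flat_data_page(index: int, data_page: torch.Tensor) -> None:
        # mirror of MHATokenToKVPoolHost.set_from_flat_data_page
        kv_buffer[:, :, index : index + page_size, :, :] = data_page.reshape(
            2, layer_num, page_size, head_num, head_dim
        )

    page = torch.randn(2 * layer_num * page_size * head_num * head_dim)
    set_from_flat_data_page(8, page)
    assert torch.equal(get_flat_data_page(8), page)  # flatten/reshape round-trips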
python/sglang/srt/mem_cache/radix_cache.py

@@ -55,8 +55,13 @@ class TreeNode:
         self.hit_count = 0
         # indicating the node is loading KV cache from host
         self.loading = False
+        # indicating the node is locked to protect from eviction
+        # incremented when the node is referenced by a storage operation
+        self.host_ref_counter = 0
         # store the host indices of KV cache
         self.host_value: Optional[torch.Tensor] = None
+        # store hash values of each page
+        self.hash_value: Optional[List[str]] = None

         self.id = TreeNode.counter if id is None else id
         TreeNode.counter += 1

@@ -69,6 +74,27 @@ class TreeNode:
     def backuped(self):
         return self.host_value is not None

+    @property
+    def backuped_storage(self):
+        return self.hash_value is not None and len(self.hash_value) > 0
+
+    def protect_host(self):
+        """Protect the host value from eviction."""
+        self.host_ref_counter += 1
+
+    def release_host(self):
+        """Release the host value, allowing it to be evicted."""
+        if self.host_ref_counter > 0:
+            self.host_ref_counter -= 1
+        else:
+            raise RuntimeError("Host reference counter is already zero.")
+
+    def get_last_hash_value(self) -> Optional[str]:
+        """Returns the hash value of the last page in this node."""
+        if self.hash_value is None or len(self.hash_value) == 0:
+            return None
+        return self.hash_value[-1]
+
     def __lt__(self, other: "TreeNode"):
         return self.last_access_time < other.last_access_time
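
The host_ref_counter introduced above is a plain reference count on the node's host copy; a minimal sketch of the protect/release discipline that storage operations follow (a stub class, not the real TreeNode):

    class _Node:
        def __init__(self):
            self.host_ref_counter = 0

        def protect_host(self):
            self.host_ref_counter += 1

        def release_host(self):
            if self.host_ref_counter > 0:
                self.host_ref_counter -= 1
            else:
                raise RuntimeError("Host reference counter is already zero.")

    node = _Node()
    node.protect_host()                # taken while a prefetch or backup is in flight
    assert node.host_ref_counter > 0   # the eviction loop skips such nodes
    node.release_host()                # dropped once the operation completes or is revoked
    assert node.host_ref_counter == 0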
python/sglang/srt/server_args.py

@@ -222,6 +222,7 @@ class ServerArgs:
     hicache_size: int = 0
     hicache_write_policy: str = "write_through_selective"
     hicache_io_backend: str = ""
+    hicache_storage_backend: Optional[str] = None
     flashinfer_mla_disable_ragged: bool = False
     disable_shared_experts_fusion: bool = False
     disable_chunked_prefix_cache: bool = False

@@ -1604,6 +1605,13 @@ class ServerArgs:
             default=ServerArgs.hicache_io_backend,
             help="The IO backend for KV cache transfer between CPU and GPU",
         )
+        parser.add_argument(
+            "--hicache-storage-backend",
+            type=str,
+            choices=["file"],  # todo, mooncake
+            default=ServerArgs.hicache_storage_backend,
+            help="The storage backend for hierarchical KV cache.",
+        )
         parser.add_argument(
             "--flashinfer-mla-disable-ragged",
             action="store_true",
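
With the flag registered above, the prototype is enabled at launch together with the hierarchical cache, e.g. `python -m sglang.launch_server --model-path <model> --enable-hierarchical-cache --hicache-storage-backend file` (the model path is a placeholder, and "file" is currently the only accepted backend value).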
test/srt/run_suite.py

@@ -64,6 +64,7 @@ suites = {
         TestFile("test_fused_moe.py", 30),
         TestFile("test_hicache.py", 116),
         TestFile("test_hicache_mla.py", 127),
+        TestFile("test_hicache_storage.py", 127),
         TestFile("test_hidden_states.py", 55),
         TestFile("test_int8_kernel.py", 8),
         TestFile("test_input_embeddings.py", 38),
test/srt/test_hicache_storage.py (new file, 0 → 100644)

import unittest
from types import SimpleNamespace

from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    popen_launch_server,
)


class TestHiCache(CustomTestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--enable-hierarchical-cache",
                "--mem-fraction-static",
                0.7,
                "--hicache-size",
                100,
                "--page-size",
                "64",
                "--hicache-storage-backend",
                "file",
            ],
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_mmlu(self):
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )

        metrics = run_eval(args)
        self.assertGreaterEqual(metrics["score"], 0.65)


if __name__ == "__main__":
    unittest.main()
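
The test launches a server with the file backend enabled and checks MMLU accuracy end to end; assuming a standard sglang checkout, it can be run directly with `python3 test/srt/test_hicache_storage.py`, or through the suite runner now that it is registered in test/srt/run_suite.py.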