change / sglang / Commits / 145482f4

Commit 145482f4 (unverified), authored Jul 24, 2025 by Zhiqiang Xie; committed by GitHub on Jul 25, 2025.
HiCache Storage TP Refinement (#8307)
Co-authored-by: pansicheng <sicheng.pan.chn@gmail.com>

Parent: 39fe1e88
Showing 4 changed files with 102 additions and 23 deletions (+102 −23). The refinement makes HiCache's storage tier safe under tensor parallelism: TP workers now agree, via MIN all-reduces over a dedicated gloo group, on how many tokens were actually hit or backed up, and the file backend namespaces its keys per TP rank.
- python/sglang/srt/managers/cache_controller.py (+52 −6)
- python/sglang/srt/mem_cache/hicache_storage.py (+17 −1)
- python/sglang/srt/mem_cache/hiradix_cache.py (+30 −16)
- python/sglang/srt/mem_cache/memory_pool_host.py (+3 −0)
python/sglang/srt/managers/cache_controller.py
```diff
@@ -219,6 +219,7 @@ class HiCacheController:
         token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
         mem_pool_host: HostKVCache,
         page_size: int,
+        tp_group: torch.distributed.ProcessGroup,
         load_cache_event: threading.Event = None,
         write_policy: str = "write_through_selective",
         io_backend: str = "",
```
```diff
@@ -244,11 +245,17 @@ class HiCacheController:
         self.enable_storage = False
         # todo: move backend initialization to storage backend module
         if storage_backend is not None:
+            # create a new communication group for synchronizing storage operations across TP workers
+            self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
+            if self.tp_world_size > 1:
+                group_ranks = torch.distributed.get_process_group_ranks(tp_group)
+                self.tp_group = torch.distributed.new_group(group_ranks, backend="gloo")
             if storage_backend == "file":
                 self.storage_backend = HiCacheFile()
                 self.enable_storage = True
                 # todo: threshold policy for prefetching
-                self.prefetch_threshold = prefetch_threshold
+                self.prefetch_threshold = max(prefetch_threshold, self.page_size)
             else:
                 raise NotImplementedError(
                     f"Unsupported storage backend: {storage_backend}"
```
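Why a second process group: the storage threads all-reduce small CPU integer tensors, and the gloo backend handles those without touching the GPU stream used by the main (typically NCCL) TP group. A minimal standalone sketch of the pattern, using a hypothetical single-rank world purely for illustration:

```python
# Sketch of the dedicated-gloo-subgroup pattern above (not sglang code);
# a single-rank world is used purely for illustration.
import torch
import torch.distributed as dist

dist.init_process_group(
    backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
)

tp_group = dist.group.WORLD
# Same ranks as the TP group, but a separate gloo group, so CPU-side storage
# synchronization never contends with GPU collectives.
group_ranks = dist.get_process_group_ranks(tp_group)
cpu_group = dist.new_group(group_ranks, backend="gloo")

# The storage paths all-reduce small CPU int tensors with MIN:
hits = torch.tensor(3, dtype=torch.int)
dist.all_reduce(hits, op=dist.ReduceOp.MIN, group=cpu_group)
print(hits.item())  # 3 with a single rank; the global minimum otherwise

dist.destroy_process_group()
```
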
```diff
@@ -568,13 +575,32 @@ class HiCacheController:
                 else:
                     break
+            if self.tp_world_size > 1:
+                storage_hit_count_tensor = torch.tensor(storage_hit_count, dtype=torch.int)
+                torch.distributed.all_reduce(
+                    storage_hit_count_tensor,
+                    op=torch.distributed.ReduceOp.MIN,
+                    group=self.tp_group,
+                )
+                storage_hit_count = storage_hit_count_tensor.item()
 
             if storage_hit_count < self.prefetch_threshold:
                 # not to prefetch if not enough benefits
                 self.prefetch_revoke_queue.put(operation.request_id)
+                logger.debug(
+                    f"Revoking prefetch for request {operation.request_id} due to insufficient hits ({storage_hit_count})."
+                )
             else:
-                operation.hash_value = hash_value
+                operation.hash_value = hash_value[
+                    : (storage_hit_count // self.page_size)
+                ]
+                # free the pre-allocated memory for pages that are not hit
+                self.mem_pool_host.free(operation.host_indices[storage_hit_count:])
+                operation.host_indices = operation.host_indices[:storage_hit_count]
                 logger.debug(
-                    f"Prefetching {len(hash_value)} pages for request {operation.request_id}."
+                    f"Prefetching {len(operation.hash_value)} pages for request {operation.request_id}."
                 )
                 self.prefetch_buffer.put(operation)
```
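Each TP worker queries its own shard, so per-rank hit counts can differ; the MIN reduction picks the longest prefix every rank can serve, and the slices above then trim the operation to whole pages. A small plain-Python walkthrough with made-up numbers:

```python
# Made-up numbers illustrating the trimming arithmetic in the hunk above.
page_size = 64
hash_value = [f"page-{i}" for i in range(8)]  # 8 candidate pages (512 tokens)
per_rank_hits = [512, 448, 512, 384]          # tokens hit on each of 4 TP ranks

storage_hit_count = min(per_rank_hits)        # what the MIN all-reduce yields
hash_value = hash_value[: storage_hit_count // page_size]
print(storage_hit_count, len(hash_value))     # 384 tokens -> 6 whole pages kept
# operation.host_indices[384:] goes back to the host pool via mem_pool_host.free
```
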
```diff
@@ -611,17 +637,37 @@ class HiCacheController:
                     last_hash = get_hash_str(
                         tokens_to_backup[i : i + self.page_size], last_hash
                     )
-                    self.storage_backend.set(
+                    # todo, handle failures in storage backend
+                    success = self.storage_backend.set(
                         last_hash,
                         self.mem_pool_host.get_flat_data_page(
                             operation.host_indices[i]
                         ),
                     )
+                    if not success:
+                        logger.warning(f"Failed to write page {last_hash} to storage.")
+                        break
                     operation.completed_tokens += self.page_size
                     operation.hash_value.append(last_hash)
 
-                self.ack_backup_queue.put((operation.id, operation.hash_value))
+                min_completed_tokens = operation.completed_tokens
+                if self.tp_world_size > 1:
+                    completed_tokens_tensor = torch.tensor(
+                        min_completed_tokens, dtype=torch.int
+                    )
+                    torch.distributed.all_reduce(
+                        completed_tokens_tensor,
+                        op=torch.distributed.ReduceOp.MIN,
+                        group=self.tp_group,
+                    )
+                    min_completed_tokens = completed_tokens_tensor.item()
+                self.ack_backup_queue.put(
+                    (
+                        operation.id,
+                        operation.hash_value[: min_completed_tokens // self.page_size],
+                        min_completed_tokens,
+                    )
+                )
 
             except Empty:
                 continue
```
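The same MIN trick guards the backup path: a rank that stopped early (for example after a failed `set`) caps the acknowledged prefix for everyone, so all ranks record an identical committed length. With hypothetical per-rank progress:

```python
# Hypothetical per-rank completion for one backup operation.
page_size = 64
completed_per_rank = [1024, 1024, 512, 1024]    # rank 2 broke out of the loop early

min_completed_tokens = min(completed_per_rank)  # result of the MIN all-reduce
pages_acked = min_completed_tokens // page_size
print(min_completed_tokens, pages_acked)        # 512 -> 8 pages acknowledged
# The ack tuple becomes (operation.id, hash_value[:8], 512) on every rank.
```
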
python/sglang/srt/mem_cache/hicache_storage.py
```diff
@@ -9,6 +9,12 @@ import torch
 logger = logging.getLogger(__name__)
 
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
+
+
 def get_hash_str(token_ids: List[int], prior_hash: Optional[str] = None) -> str:
     hasher = hashlib.sha256()
```
```diff
@@ -80,13 +86,20 @@ class HiCacheFile(HiCacheStorage):
     def __init__(self, file_path: str = "/tmp/hicache"):
         self.file_path = file_path
-        if not os.path.exists(self.file_path):
+        tp_rank = get_tensor_model_parallel_rank()
+        tp_size = get_tensor_model_parallel_world_size()
+        self.tp_suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 else ""
+        if not os.path.exists(self.file_path) and tp_rank == 0:
             os.makedirs(self.file_path)
             logger.info(f"Created HiCacheFile storage directory at {self.file_path}")
 
+    def _get_suffixed_key(self, key: str) -> str:
+        return key + self.tp_suffix
+
     def get(
         self, key: str, target_location: Optional[torch.Tensor] = None
     ) -> torch.Tensor | None:
+        key = self._get_suffixed_key(key)
         tensor_path = os.path.join(self.file_path, f"{key}.bin")
         try:
             # todo: fixing the target_location logic to enable in-place loading
```
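Each rank stores only its own KV shard, so the file backend now bakes `_{tp_rank}_{tp_size}` into every key; ranks can share one directory without ever reading each other's partial pages. A sketch of the resulting layout (hypothetical helper and key, mirroring the suffix scheme above rather than sglang's API):

```python
# Hypothetical helper mirroring HiCacheFile's suffix scheme (not sglang's API).
import os

def suffixed_path(file_path: str, key: str, tp_rank: int, tp_size: int) -> str:
    tp_suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 else ""
    return os.path.join(file_path, f"{key}{tp_suffix}.bin")

print(suffixed_path("/tmp/hicache", "3fa9c2", tp_rank=2, tp_size=4))
# /tmp/hicache/3fa9c2_2_4.bin   <- rank 2's shard of this page
print(suffixed_path("/tmp/hicache", "3fa9c2", tp_rank=0, tp_size=1))
# /tmp/hicache/3fa9c2.bin       <- no suffix outside tensor parallelism
```

Note also that the directory itself is now created only by rank 0, avoiding a `makedirs` race when all ranks start together.
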
```diff
@@ -112,6 +125,7 @@ class HiCacheFile(HiCacheStorage):
         ]
 
     def set(self, key: str, value: torch.Tensor) -> bool:
+        key = self._get_suffixed_key(key)
         tensor_path = os.path.join(self.file_path, f"{key}.bin")
         if self.exists(key):
             logger.debug(f"Key {key} already exists. Skipped.")
```
```diff
@@ -130,10 +144,12 @@ class HiCacheFile(HiCacheStorage):
         return True
 
     def exists(self, key: str) -> bool:
+        key = self._get_suffixed_key(key)
         tensor_path = os.path.join(self.file_path, f"{key}.bin")
         return os.path.exists(tensor_path)
 
     def delete(self, key: str) -> None:
+        key = self._get_suffixed_key(key)
         tensor_path = os.path.join(self.file_path, f"{key}.bin")
         try:
             os.remove(tensor_path)
```
python/sglang/srt/mem_cache/hiradix_cache.py
```diff
@@ -50,6 +50,7 @@ class HiRadixCache(RadixCache):
             raise ValueError(f"HiRadixCache only supports MHA and MLA yet")
         self.tp_group = tp_cache_group
+        self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
         self.enable_storage = hicache_storage_backend is not None
         # todo: customizable storage prefetch threshold
         self.prefetch_threshold = 256
```
```diff
@@ -59,6 +60,7 @@ class HiRadixCache(RadixCache):
             token_to_kv_pool_allocator,
             self.token_to_kv_pool_host,
             page_size,
+            self.tp_group,
             load_cache_event=self.load_cache_event,
             write_policy=hicache_write_policy,
             io_backend=hicache_io_backend,
```
```diff
@@ -153,7 +155,7 @@ class HiRadixCache(RadixCache):
         queue_size = torch.tensor(
             self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
         )
-        if torch.distributed.get_world_size(group=self.tp_group) > 1:
+        if self.tp_world_size > 1:
             # synchrnoize TP workers to make the same update to radix cache
             torch.distributed.all_reduce(
                 queue_size,
```
```diff
@@ -353,7 +355,7 @@ class HiRadixCache(RadixCache):
         queue_size = torch.tensor(
             self.cache_controller.prefetch_revoke_queue.qsize(), dtype=torch.int
         )
-        if torch.distributed.get_world_size(group=self.tp_group) > 1:
+        if self.tp_world_size > 1:
             # synchrnoize TP workers to make the same update to hiradix cache
             torch.distributed.all_reduce(
                 queue_size,
```
```diff
@@ -372,7 +374,7 @@ class HiRadixCache(RadixCache):
         queue_size = torch.tensor(
             self.cache_controller.ack_backup_queue.qsize(), dtype=torch.int
         )
-        if torch.distributed.get_world_size(group=self.tp_group) > 1:
+        if self.tp_world_size > 1:
             # synchrnoize TP workers to make the same update to hiradix cache
             torch.distributed.all_reduce(
                 queue_size,
```
```diff
@@ -380,9 +382,15 @@ class HiRadixCache(RadixCache):
                 group=self.tp_group,
             )
         for _ in range(queue_size.item()):
-            ack_id, hash_value = self.cache_controller.ack_backup_queue.get()
-            self.ongoing_backup[ack_id].hash_value = hash_value
-            self.ongoing_backup[ack_id].release_host()
+            ack_id, hash_value, completed_tokens = (
+                self.cache_controller.ack_backup_queue.get()
+            )
+            host_node = self.ongoing_backup[ack_id]
+            if completed_tokens < len(host_node.key):
+                # backup is only partially successful, split the node
+                new_node = self._split_node(host_node.key, host_node, completed_tokens)
+                new_node.hash_value = hash_value
+            host_node.release_host()
             del self.ongoing_backup[ack_id]
 
     def check_prefetch_progress(self, req_id: str):
```
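Because the ack can now report fewer tokens than the node covers, the radix tree splits the node at the committed boundary so only the durable prefix keeps a `hash_value`. A simplified stand-in for that branch (`Node` here is a hypothetical minimal type; the real code delegates to `HiRadixCache._split_node`):

```python
# Simplified stand-in for the partial-ack branch above; Node is hypothetical.
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Node:
    key: List[int]                      # token ids covered by this node
    hash_value: Optional[List[str]] = None

def on_backup_ack(node: Node, hash_value: List[str], completed_tokens: int) -> Node:
    if completed_tokens < len(node.key):
        # Partial success: carve off the committed prefix; only it carries the
        # page hashes, mirroring the _split_node call in the hunk above.
        prefix = Node(key=node.key[:completed_tokens], hash_value=hash_value)
        node.key = node.key[completed_tokens:]
        return prefix
    node.hash_value = hash_value        # full success (simplified here)
    return node

node = Node(key=list(range(1024)))
prefix = on_backup_ack(node, [f"p{i}" for i in range(8)], completed_tokens=512)
print(len(prefix.key), len(node.key))   # 512 512
```
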
```diff
@@ -400,15 +408,18 @@ class HiRadixCache(RadixCache):
         )
         logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")
 
-        min_completed_tokens = torch.tensor(completed_tokens, dtype=torch.int)
-        if torch.distributed.get_world_size(group=self.tp_group) > 1:
+        min_completed_tokens = completed_tokens
+        if self.tp_world_size > 1:
             # synchrnoize TP workers to make the same update to hiradix cache
+            completed_tokens_tensor = torch.tensor(min_completed_tokens, dtype=torch.int)
             torch.distributed.all_reduce(
-                min_completed_tokens,
+                completed_tokens_tensor,
                 op=torch.distributed.ReduceOp.MIN,
                 group=self.tp_group,
             )
-            min_completed_tokens = min_completed_tokens.item()
+            min_completed_tokens = completed_tokens_tensor.item()
 
         fetched_token_ids = token_ids[:min_completed_tokens]
         written_indices = host_indices[:min_completed_tokens]
         matched_length = self._insert_helper_host(
```
```diff
@@ -465,16 +476,19 @@ class HiRadixCache(RadixCache):
         new_input_tokens: List[int],
         last_hash: Optional[str] = None,
     ):
-        if not self.enable_storage or len(new_input_tokens) < self.prefetch_threshold:
+        # align the number of fetching tokens to the page size
+        prefetch_length = len(new_input_tokens) - (
+            len(new_input_tokens) % self.page_size
+        )
+        new_input_tokens = new_input_tokens[:prefetch_length]
+        if not self.enable_storage or prefetch_length < self.prefetch_threshold:
             return
 
         last_host_node.protect_host()
-        host_indices = self.cache_controller.mem_pool_host.alloc(len(new_input_tokens))
+        host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
         if host_indices is None:
-            self.evict_host(len(new_input_tokens))
-            host_indices = self.cache_controller.mem_pool_host.alloc(
-                len(new_input_tokens)
-            )
+            self.evict_host(prefetch_length)
+            host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
         if host_indices is None:
             last_host_node.release_host()
             # no sufficient host memory to prefetch
```
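Rounding the prefetch window down to a whole number of pages keeps every downstream `alloc` page-aligned, which the new assertion in `memory_pool_host.py` (next file) then enforces. The arithmetic in isolation:

```python
# The page-alignment arithmetic from the hunk above, in isolation.
page_size = 64
new_input_tokens = list(range(1000))  # 1000 tokens are candidates for prefetch

prefetch_length = len(new_input_tokens) - (len(new_input_tokens) % page_size)
new_input_tokens = new_input_tokens[:prefetch_length]
print(prefetch_length)                # 960: 15 full pages; the trailing 40 dropped
```
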
python/sglang/srt/mem_cache/memory_pool_host.py
```diff
@@ -126,6 +126,9 @@ class HostKVCache(abc.ABC):
     @synchronized()
     def alloc(self, need_size: int) -> torch.Tensor:
+        assert (
+            need_size % self.page_size == 0
+        ), "The requested size should be a multiple of the page size."
         if need_size > self.available_size():
             return None
```