sglang · Commits · 145482f4

HiCache Storage TP Refinement (#8307)

Unverified commit 145482f4, authored Jul 24, 2025 by Zhiqiang Xie, committed via GitHub on Jul 25, 2025.
Co-authored-by: pansicheng <sicheng.pan.chn@gmail.com>
Parent: 39fe1e88
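This change makes the HiCache storage tier tensor-parallel aware: the cache controller builds a dedicated gloo communication group over the TP ranks and uses MIN all-reduces so all workers agree on storage hit counts (prefetch) and completed tokens (backup); HiCacheFile shards its on-disk keys with a `_{tp_rank}_{tp_size}` suffix; and HiRadixCache aligns prefetch lengths to the page size and splits radix nodes when a backup only partially succeeds.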
Showing 4 changed files with 102 additions and 23 deletions (+102 −23):

- python/sglang/srt/managers/cache_controller.py (+52 −6)
- python/sglang/srt/mem_cache/hicache_storage.py (+17 −1)
- python/sglang/srt/mem_cache/hiradix_cache.py (+30 −16)
- python/sglang/srt/mem_cache/memory_pool_host.py (+3 −0)
python/sglang/srt/managers/cache_controller.py
```diff
@@ -219,6 +219,7 @@ class HiCacheController:
         token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
         mem_pool_host: HostKVCache,
         page_size: int,
+        tp_group: torch.distributed.ProcessGroup,
         load_cache_event: threading.Event = None,
         write_policy: str = "write_through_selective",
         io_backend: str = "",
```
```diff
@@ -244,11 +245,17 @@ class HiCacheController:
         self.enable_storage = False
         # todo: move backend initialization to storage backend module
         if storage_backend is not None:
+            # create a new communication group for synchronizing storage operations across TP workers
+            self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
+            if self.tp_world_size > 1:
+                group_ranks = torch.distributed.get_process_group_ranks(tp_group)
+                self.tp_group = torch.distributed.new_group(group_ranks, backend="gloo")
+
             if storage_backend == "file":
                 self.storage_backend = HiCacheFile()
                 self.enable_storage = True
                 # todo: threshold policy for prefetching
-                self.prefetch_threshold = prefetch_threshold
+                self.prefetch_threshold = max(prefetch_threshold, self.page_size)
             else:
                 raise NotImplementedError(
                     f"Unsupported storage backend: {storage_backend}"
```
```diff
@@ -568,13 +575,32 @@ class HiCacheController:
             else:
                 break
 
+        if self.tp_world_size > 1:
+            storage_hit_count_tensor = torch.tensor(
+                storage_hit_count, dtype=torch.int
+            )
+            torch.distributed.all_reduce(
+                storage_hit_count_tensor,
+                op=torch.distributed.ReduceOp.MIN,
+                group=self.tp_group,
+            )
+            storage_hit_count = storage_hit_count_tensor.item()
+
         if storage_hit_count < self.prefetch_threshold:
             # not to prefetch if not enough benefits
             self.prefetch_revoke_queue.put(operation.request_id)
+            logger.debug(
+                f"Revoking prefetch for request {operation.request_id} "
+                f"due to insufficient hits ({storage_hit_count})."
+            )
         else:
-            operation.hash_value = hash_value
+            operation.hash_value = hash_value[
+                : (storage_hit_count // self.page_size)
+            ]
+            # free the pre-allocated memory for pages that are not hit
+            self.mem_pool_host.free(operation.host_indices[storage_hit_count:])
+            operation.host_indices = operation.host_indices[:storage_hit_count]
             logger.debug(
-                f"Prefetching {len(hash_value)} pages for request {operation.request_id}."
+                f"Prefetching {len(operation.hash_value)} pages for request {operation.request_id}."
             )
             self.prefetch_buffer.put(operation)
```
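The MIN all-reduce above is what keeps TP workers consistent: each contributes its local storage hit count and all adopt the smallest, so they either all prefetch the same number of pages or all revoke. A self-contained sketch of the pattern; `agree_on_hit_count` is a hypothetical helper, and the demo assumes launch via e.g. `torchrun --nproc_per_node=2 demo.py`:

```python
import torch
import torch.distributed as dist


def agree_on_hit_count(local_hits: int, group: dist.ProcessGroup = None) -> int:
    # every rank contributes its local count; MIN makes all ranks adopt the
    # smallest value, so the subsequent decision is identical everywhere
    t = torch.tensor(local_hits, dtype=torch.int)
    dist.all_reduce(t, op=dist.ReduceOp.MIN, group=group)
    return int(t.item())


if __name__ == "__main__":
    dist.init_process_group(backend="gloo")  # reads env vars set by torchrun
    local = 128 * (dist.get_rank() + 1)      # pretend ranks hit different amounts
    agreed = agree_on_hit_count(local)
    print(f"rank {dist.get_rank()}: local={local} agreed={agreed}")
```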
```diff
@@ -611,17 +637,37 @@ class HiCacheController:
                     last_hash = get_hash_str(
                         tokens_to_backup[i : i + self.page_size], last_hash
                     )
                     # todo, handle failures in storage backend
-                    self.storage_backend.set(
+                    success = self.storage_backend.set(
                         last_hash,
                         self.mem_pool_host.get_flat_data_page(
                             operation.host_indices[i]
                         ),
                     )
+                    if not success:
+                        logger.warning(f"Failed to write page {last_hash} to storage.")
+                        break
                     operation.completed_tokens += self.page_size
                     operation.hash_value.append(last_hash)
-                self.ack_backup_queue.put((operation.id, operation.hash_value))
+
+                min_completed_tokens = operation.completed_tokens
+                if self.tp_world_size > 1:
+                    completed_tokens_tensor = torch.tensor(
+                        min_completed_tokens, dtype=torch.int
+                    )
+                    torch.distributed.all_reduce(
+                        completed_tokens_tensor,
+                        op=torch.distributed.ReduceOp.MIN,
+                        group=self.tp_group,
+                    )
+                    min_completed_tokens = completed_tokens_tensor.item()
+
+                self.ack_backup_queue.put(
+                    (
+                        operation.id,
+                        operation.hash_value[: min_completed_tokens // self.page_size],
+                        min_completed_tokens,
+                    )
+                )
             except Empty:
                 continue
```
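The new ack tuple truncates the hash list to the number of pages every worker managed to back up. The arithmetic, with made-up numbers (page_size in sglang comes from the server config):

```python
page_size = 64
hash_value = ["h0", "h1", "h2", "h3"]  # this rank backed up 4 full pages
min_completed_tokens = 130             # slowest TP rank only finished 130 tokens
# integer division floors to whole pages: 130 // 64 == 2
synced = hash_value[: min_completed_tokens // page_size]
assert synced == ["h0", "h1"]          # all ranks ack the same 2-page prefix
```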
python/sglang/srt/mem_cache/hicache_storage.py
```diff
@@ -9,6 +9,12 @@ import torch
 
 logger = logging.getLogger(__name__)
 
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
+
 
 def get_hash_str(token_ids: List[int], prior_hash: Optional[str] = None) -> str:
     hasher = hashlib.sha256()
```
```diff
@@ -80,13 +86,20 @@ class HiCacheFile(HiCacheStorage):
 
     def __init__(self, file_path: str = "/tmp/hicache"):
         self.file_path = file_path
-        if not os.path.exists(self.file_path):
+        tp_rank = get_tensor_model_parallel_rank()
+        tp_size = get_tensor_model_parallel_world_size()
+        self.tp_suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 else ""
+        if not os.path.exists(self.file_path) and tp_rank == 0:
             os.makedirs(self.file_path)
             logger.info(f"Created HiCacheFile storage directory at {self.file_path}")
 
+    def _get_suffixed_key(self, key: str) -> str:
+        return key + self.tp_suffix
+
     def get(
         self, key: str, target_location: Optional[torch.Tensor] = None
     ) -> torch.Tensor | None:
+        key = self._get_suffixed_key(key)
         tensor_path = os.path.join(self.file_path, f"{key}.bin")
         try:
             # todo: fixing the target_location logic to enable in-place loading
```
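Since each TP worker holds a different shard of the KV data, the same logical page must map to a distinct file per rank. How the suffix composes, with illustrative values:

```python
key, tp_rank, tp_size = "3f2a9c", 1, 4        # made-up page hash and TP layout
tp_suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 else ""
print(f"/tmp/hicache/{key + tp_suffix}.bin")  # -> /tmp/hicache/3f2a9c_1_4.bin
```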
```diff
@@ -112,6 +125,7 @@ class HiCacheFile(HiCacheStorage):
         ]
 
     def set(self, key: str, value: torch.Tensor) -> bool:
+        key = self._get_suffixed_key(key)
         tensor_path = os.path.join(self.file_path, f"{key}.bin")
         if self.exists(key):
             logger.debug(f"Key {key} already exists. Skipped.")
```
```diff
@@ -130,10 +144,12 @@ class HiCacheFile(HiCacheStorage):
         return True
 
     def exists(self, key: str) -> bool:
+        key = self._get_suffixed_key(key)
         tensor_path = os.path.join(self.file_path, f"{key}.bin")
         return os.path.exists(tensor_path)
 
     def delete(self, key: str) -> None:
+        key = self._get_suffixed_key(key)
         tensor_path = os.path.join(self.file_path, f"{key}.bin")
         try:
             os.remove(tensor_path)
```
python/sglang/srt/mem_cache/hiradix_cache.py
```diff
@@ -50,6 +50,7 @@ class HiRadixCache(RadixCache):
             raise ValueError(f"HiRadixCache only supports MHA and MLA yet")
 
         self.tp_group = tp_cache_group
+        self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
         self.enable_storage = hicache_storage_backend is not None
         # todo: customizable storage prefetch threshold
         self.prefetch_threshold = 256
```
```diff
@@ -59,6 +60,7 @@ class HiRadixCache(RadixCache):
             token_to_kv_pool_allocator,
             self.token_to_kv_pool_host,
             page_size,
+            self.tp_group,
             load_cache_event=self.load_cache_event,
             write_policy=hicache_write_policy,
             io_backend=hicache_io_backend,
```
```diff
@@ -153,7 +155,7 @@ class HiRadixCache(RadixCache):
         queue_size = torch.tensor(
             self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
         )
-        if torch.distributed.get_world_size(group=self.tp_group) > 1:
+        if self.tp_world_size > 1:
             # synchronize TP workers to make the same update to radix cache
             torch.distributed.all_reduce(
                 queue_size,
```
```diff
@@ -353,7 +355,7 @@ class HiRadixCache(RadixCache):
         queue_size = torch.tensor(
             self.cache_controller.prefetch_revoke_queue.qsize(), dtype=torch.int
         )
-        if torch.distributed.get_world_size(group=self.tp_group) > 1:
+        if self.tp_world_size > 1:
             # synchronize TP workers to make the same update to hiradix cache
             torch.distributed.all_reduce(
                 queue_size,
```
```diff
@@ -372,7 +374,7 @@ class HiRadixCache(RadixCache):
         queue_size = torch.tensor(
             self.cache_controller.ack_backup_queue.qsize(), dtype=torch.int
         )
-        if torch.distributed.get_world_size(group=self.tp_group) > 1:
+        if self.tp_world_size > 1:
             # synchronize TP workers to make the same update to hiradix cache
             torch.distributed.all_reduce(
                 queue_size,
```
```diff
@@ -380,9 +382,15 @@ class HiRadixCache(RadixCache):
                 group=self.tp_group,
             )
         for _ in range(queue_size.item()):
-            ack_id, hash_value = self.cache_controller.ack_backup_queue.get()
-            self.ongoing_backup[ack_id].hash_value = hash_value
-            self.ongoing_backup[ack_id].release_host()
+            ack_id, hash_value, completed_tokens = (
+                self.cache_controller.ack_backup_queue.get()
+            )
+            host_node = self.ongoing_backup[ack_id]
+            if completed_tokens < len(host_node.key):
+                # backup is only partially successful, split the node
+                new_node = self._split_node(host_node.key, host_node, completed_tokens)
+                new_node.hash_value = hash_value
+            host_node.release_host()
             del self.ongoing_backup[ack_id]
 
     def check_prefetch_progress(self, req_id: str):
```
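When a backup only partially succeeds, the node is split so that the saved prefix keeps the hash list while the unsaved suffix keeps none. A toy sketch of the splitting idea; `ToyNode` and `split_node` are stand-ins, and the real `_split_node` additionally maintains the radix-tree links:

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyNode:
    key: List[int]                       # token ids covered by this node
    hash_value: List[str] = field(default_factory=list)


def split_node(node: ToyNode, at: int) -> ToyNode:
    # carve the backed-up prefix out into its own node
    prefix = ToyNode(key=node.key[:at])
    node.key = node.key[at:]
    return prefix


node = ToyNode(key=list(range(256)))
completed_tokens = 128                   # only the first 128 tokens were saved
if completed_tokens < len(node.key):
    prefix = split_node(node, completed_tokens)
    prefix.hash_value = ["h0", "h1"]     # hashes exist only for the saved pages
```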
```diff
@@ -400,15 +408,18 @@ class HiRadixCache(RadixCache):
         )
         logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")
 
-        min_completed_tokens = torch.tensor(completed_tokens, dtype=torch.int)
-        if torch.distributed.get_world_size(group=self.tp_group) > 1:
+        min_completed_tokens = completed_tokens
+        if self.tp_world_size > 1:
             # synchronize TP workers to make the same update to hiradix cache
+            completed_tokens_tensor = torch.tensor(min_completed_tokens, dtype=torch.int)
             torch.distributed.all_reduce(
-                min_completed_tokens,
+                completed_tokens_tensor,
                 op=torch.distributed.ReduceOp.MIN,
                 group=self.tp_group,
             )
-            min_completed_tokens = min_completed_tokens.item()
+            min_completed_tokens = completed_tokens_tensor.item()
 
         fetched_token_ids = token_ids[:min_completed_tokens]
         written_indices = host_indices[:min_completed_tokens]
         matched_length = self._insert_helper_host(
```
```diff
@@ -465,16 +476,19 @@ class HiRadixCache(RadixCache):
         new_input_tokens: List[int],
         last_hash: Optional[str] = None,
     ):
-        if not self.enable_storage or len(new_input_tokens) < self.prefetch_threshold:
+        # align the number of fetching tokens to the page size
+        prefetch_length = len(new_input_tokens) - (
+            len(new_input_tokens) % self.page_size
+        )
+        new_input_tokens = new_input_tokens[:prefetch_length]
+        if not self.enable_storage or prefetch_length < self.prefetch_threshold:
             return
 
         last_host_node.protect_host()
-        host_indices = self.cache_controller.mem_pool_host.alloc(len(new_input_tokens))
+        host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
         if host_indices is None:
-            self.evict_host(len(new_input_tokens))
-            host_indices = self.cache_controller.mem_pool_host.alloc(
-                len(new_input_tokens)
-            )
+            self.evict_host(prefetch_length)
+            host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
         if host_indices is None:
             last_host_node.release_host()
             # no sufficient host memory to prefetch
```
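The alignment rounds a prefetch request down to a whole number of pages, matching the page-multiple contract that the new assertion in memory_pool_host.py enforces. With made-up numbers:

```python
page_size = 64
new_input_tokens = list(range(200))    # 200 tokens requested
prefetch_length = len(new_input_tokens) - (len(new_input_tokens) % page_size)
assert prefetch_length == 192          # 3 whole pages; the trailing 8 tokens are dropped
new_input_tokens = new_input_tokens[:prefetch_length]
```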
python/sglang/srt/mem_cache/memory_pool_host.py
```diff
@@ -126,6 +126,9 @@ class HostKVCache(abc.ABC):
     @synchronized()
     def alloc(self, need_size: int) -> torch.Tensor:
+        assert (
+            need_size % self.page_size == 0
+        ), "The requested size should be a multiple of the page size."
         if need_size > self.available_size():
             return None
```
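A minimal illustration of the precondition the assertion establishes, using a toy pool rather than sglang's:

```python
class ToyHostPool:
    page_size = 64

    def alloc(self, need_size: int):
        assert (
            need_size % self.page_size == 0
        ), "The requested size should be a multiple of the page size."
        return list(range(need_size))  # stand-in for real index allocation


pool = ToyHostPool()
pool.alloc(128)    # fine: exactly two pages
# pool.alloc(130)  # would raise AssertionError: callers must align first
```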