Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
bfa27438
Unverified
Commit
bfa27438
authored
Oct 01, 2025
by
ykwd
Committed by
GitHub
Oct 01, 2025
Browse files
[HiCache] Configurable and Dynamic Prefetch Timeout (#10512)
Co-authored-by:
Zhiqiang Xie
<
xiezhq@stanford.edu
>
parent
86cb4db0
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
75 additions
and
21 deletions
+75
-21
python/sglang/srt/managers/cache_controller.py
python/sglang/srt/managers/cache_controller.py
+3
-13
python/sglang/srt/mem_cache/hiradix_cache.py
python/sglang/srt/mem_cache/hiradix_cache.py
+72
-8
No files found.
python/sglang/srt/managers/cache_controller.py
View file @
bfa27438
...
@@ -250,7 +250,7 @@ class HiCacheController:
...
@@ -250,7 +250,7 @@ class HiCacheController:
storage_backend
:
Optional
[
str
]
=
None
,
storage_backend
:
Optional
[
str
]
=
None
,
prefetch_threshold
:
int
=
256
,
prefetch_threshold
:
int
=
256
,
model_name
:
Optional
[
str
]
=
None
,
model_name
:
Optional
[
str
]
=
None
,
storage_backend_extra_config
:
Optional
[
str
]
=
None
,
storage_backend_extra_config
:
Optional
[
dict
]
=
None
,
):
):
self
.
mem_pool_device_allocator
=
token_to_kv_pool_allocator
self
.
mem_pool_device_allocator
=
token_to_kv_pool_allocator
self
.
mem_pool_device
=
token_to_kv_pool_allocator
.
get_kvcache
()
self
.
mem_pool_device
=
token_to_kv_pool_allocator
.
get_kvcache
()
...
@@ -361,7 +361,7 @@ class HiCacheController:
...
@@ -361,7 +361,7 @@ class HiCacheController:
def
_generate_storage_config
(
def
_generate_storage_config
(
self
,
self
,
model_name
:
Optional
[
str
]
=
None
,
model_name
:
Optional
[
str
]
=
None
,
storage_backend_extra_config
:
Optional
[
str
]
=
None
,
storage_backend_extra_config
:
Optional
[
dict
]
=
None
,
):
):
if
is_dp_attention_enabled
():
if
is_dp_attention_enabled
():
...
@@ -376,23 +376,13 @@ class HiCacheController:
...
@@ -376,23 +376,13 @@ class HiCacheController:
# Currently, AscendMLAPagedTokenToKVPool is the subclass of MLATokenToKVPool.
# Currently, AscendMLAPagedTokenToKVPool is the subclass of MLATokenToKVPool.
is_mla_backend
=
isinstance
(
self
.
mem_pool_device
,
MLATokenToKVPool
)
is_mla_backend
=
isinstance
(
self
.
mem_pool_device
,
MLATokenToKVPool
)
# Parse extra config JSON if provided
extra_config
=
None
if
storage_backend_extra_config
:
try
:
import
json
extra_config
=
json
.
loads
(
storage_backend_extra_config
)
except
Exception
as
e
:
logger
.
error
(
f
"Invalid backend extra config JSON:
{
e
}
"
)
return
HiCacheStorageConfig
(
return
HiCacheStorageConfig
(
tp_rank
=
self
.
tp_rank
,
tp_rank
=
self
.
tp_rank
,
tp_size
=
self
.
tp_size
,
tp_size
=
self
.
tp_size
,
is_mla_model
=
is_mla_backend
,
is_mla_model
=
is_mla_backend
,
is_page_first_layout
=
self
.
mem_pool_host
.
layout
==
"page_first"
,
is_page_first_layout
=
self
.
mem_pool_host
.
layout
==
"page_first"
,
model_name
=
model_name
,
model_name
=
model_name
,
extra_config
=
extra_config
,
extra_config
=
storage_backend_
extra_config
,
)
)
def
reset
(
self
):
def
reset
(
self
):
...
...
python/sglang/srt/mem_cache/hiradix_cache.py
View file @
bfa27438
import
heapq
import
heapq
import
json
import
logging
import
logging
import
threading
import
threading
import
time
import
time
from
queue
import
Queue
from
typing
import
List
,
Optional
from
typing
import
List
,
Optional
import
torch
import
torch
...
@@ -78,9 +78,19 @@ class HiRadixCache(RadixCache):
...
@@ -78,9 +78,19 @@ class HiRadixCache(RadixCache):
self
.
enable_storage
=
hicache_storage_backend
is
not
None
self
.
enable_storage
=
hicache_storage_backend
is
not
None
self
.
enable_storage_metrics
=
self
.
enable_storage
and
enable_metrics
self
.
enable_storage_metrics
=
self
.
enable_storage
and
enable_metrics
# todo: customizable storage prefetch threshold and timeout
(
self
.
prefetch_threshold
=
256
extra_config
,
self
.
prefetch_timeout
=
3
# seconds
prefetch_threshold
,
prefetch_timeout_base
,
prefetch_timeout_per_ki_token
,
)
=
self
.
_parse_storage_backend_extra_config
(
storage_backend_extra_config
)
self
.
prefetch_threshold
=
prefetch_threshold
self
.
prefetch_timeout_base
=
prefetch_timeout_base
self
.
prefetch_timeout_per_page
=
(
page_size
/
1024
*
prefetch_timeout_per_ki_token
)
# TODO: support more timeout check functions
self
.
is_prefetch_timeout
=
self
.
_prefetch_timeout_check_linear_func
self
.
prefetch_stop_policy
=
hicache_storage_prefetch_policy
self
.
prefetch_stop_policy
=
hicache_storage_prefetch_policy
self
.
load_cache_event
=
threading
.
Event
()
self
.
load_cache_event
=
threading
.
Event
()
...
@@ -95,7 +105,7 @@ class HiRadixCache(RadixCache):
...
@@ -95,7 +105,7 @@ class HiRadixCache(RadixCache):
storage_backend
=
hicache_storage_backend
,
storage_backend
=
hicache_storage_backend
,
prefetch_threshold
=
self
.
prefetch_threshold
,
prefetch_threshold
=
self
.
prefetch_threshold
,
model_name
=
model_name
,
model_name
=
model_name
,
storage_backend_extra_config
=
storage_backend_
extra_config
,
storage_backend_extra_config
=
extra_config
,
)
)
if
self
.
enable_storage_metrics
:
if
self
.
enable_storage_metrics
:
# TODO: support pp
# TODO: support pp
...
@@ -127,6 +137,53 @@ class HiRadixCache(RadixCache):
...
@@ -127,6 +137,53 @@ class HiRadixCache(RadixCache):
eviction_policy
=
eviction_policy
,
eviction_policy
=
eviction_policy
,
)
)
def
_parse_storage_backend_extra_config
(
self
,
storage_backend_extra_config
:
Optional
[
str
]
):
"""
Parse storage backend extra config JSON and extract specific parameters.
Args:
storage_backend_extra_config: JSON string containing extra configuration
Returns:
tuple: (extra_config_dict, prefetch_threshold, prefetch_timeout_base, prefetch_timeout_per_ki_token)
"""
# Parse extra config JSON if provided
extra_config
=
{}
if
storage_backend_extra_config
:
try
:
extra_config
=
json
.
loads
(
storage_backend_extra_config
)
except
Exception
as
e
:
logger
.
error
(
f
"Invalid backend extra config JSON:
{
e
}
"
)
raise
e
prefetch_threshold
=
extra_config
.
pop
(
"prefetch_threshold"
,
256
)
# tokens
prefetch_timeout_base
=
extra_config
.
pop
(
"prefetch_timeout_base"
,
1
)
# seconds
prefetch_timeout_per_ki_token
=
extra_config
.
pop
(
"prefetch_timeout_per_ki_token"
,
0.25
)
# seconds per 1024 tokens
if
not
isinstance
(
prefetch_threshold
,
int
):
raise
ValueError
(
f
"prefetch_threshold must be int, got
{
type
(
prefetch_threshold
).
__name__
}
"
)
if
not
isinstance
(
prefetch_timeout_base
,
(
int
,
float
)):
raise
ValueError
(
f
"prefetch_timeout_base must be number, got
{
type
(
prefetch_timeout_base
).
__name__
}
"
)
if
not
isinstance
(
prefetch_timeout_per_ki_token
,
(
int
,
float
)):
raise
ValueError
(
f
"prefetch_timeout_per_ki_token must be number, got
{
type
(
prefetch_timeout_per_ki_token
).
__name__
}
"
)
return
(
extra_config
,
prefetch_threshold
,
float
(
prefetch_timeout_base
),
float
(
prefetch_timeout_per_ki_token
),
)
def
reset
(
self
):
def
reset
(
self
):
TreeNode
.
counter
=
0
TreeNode
.
counter
=
0
self
.
cache_controller
.
reset
()
self
.
cache_controller
.
reset
()
...
@@ -490,6 +547,15 @@ class HiRadixCache(RadixCache):
...
@@ -490,6 +547,15 @@ class HiRadixCache(RadixCache):
host_indices
=
torch
.
cat
(
host_indices_list
,
dim
=
0
)
host_indices
=
torch
.
cat
(
host_indices_list
,
dim
=
0
)
cc
.
mem_pool_host
.
free
(
host_indices
)
cc
.
mem_pool_host
.
free
(
host_indices
)
# Timeout is linearly increasing with the number of pages
def
_prefetch_timeout_check_linear_func
(
self
,
operation
:
PrefetchOperation
):
# If hash_value has not been computed in timeout_base seconds, terminate it.
return
(
time
.
monotonic
()
-
operation
.
start_time
>
self
.
prefetch_timeout_base
+
len
(
operation
.
hash_value
)
*
self
.
prefetch_timeout_per_page
)
def
can_terminate_prefetch
(
self
,
operation
:
PrefetchOperation
):
def
can_terminate_prefetch
(
self
,
operation
:
PrefetchOperation
):
can_terminate
=
True
can_terminate
=
True
...
@@ -506,9 +572,7 @@ class HiRadixCache(RadixCache):
...
@@ -506,9 +572,7 @@ class HiRadixCache(RadixCache):
if
self
.
prefetch_stop_policy
==
"wait_complete"
:
if
self
.
prefetch_stop_policy
==
"wait_complete"
:
can_terminate
=
completed
can_terminate
=
completed
elif
self
.
prefetch_stop_policy
==
"timeout"
:
elif
self
.
prefetch_stop_policy
==
"timeout"
:
can_terminate
=
completed
or
(
can_terminate
=
completed
or
self
.
is_prefetch_timeout
(
operation
)
time
.
monotonic
()
-
operation
.
start_time
>
self
.
prefetch_timeout
)
else
:
else
:
# unknown prefetch stop policy, just return True
# unknown prefetch stop policy, just return True
return
True
return
True
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment