sglang · Commit e2fd2b9c (unverified)
Authored Aug 08, 2025 by pansicheng; committed by GitHub on Aug 08, 2025

Simple prefetch policy (#8692)
parent 7490e3f6

Showing 6 changed files with 150 additions and 38 deletions (+150 -38)
benchmark/hicache/bench_multiturn.py                        +29 -2
python/sglang/srt/managers/cache_controller.py              +53 -28
python/sglang/srt/managers/scheduler.py                     +5  -1
python/sglang/srt/mem_cache/hiradix_cache.py                +53 -5
python/sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py   +2  -2
python/sglang/srt/server_args.py                            +8  -0
benchmark/hicache/bench_multiturn.py
@@ -20,6 +20,8 @@ from sglang.bench_serving import (
     sample_random_requests,
 )
 
+AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=20 * 60 * 60)
+
 
 def parse_args():
     parser = argparse.ArgumentParser(
@@ -139,7 +141,7 @@ async def async_request_sglang_generate(
     """
    Sends a streaming request to the server. Gathers text token-by-token.
    """
-    async with aiohttp.ClientSession() as session:
+    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
         headers = {}
         generated_text = ""
         ttft = 0.0
@@ -150,6 +152,8 @@ async def async_request_sglang_generate(
         try:
             async with session.post(url=url, json=payload, headers=headers) as response:
                 if response.status == 200:
+                    prompt_tokens = 0
+                    cached_tokens = 0
                     async for chunk_bytes in response.content:
                         chunk_bytes = chunk_bytes.strip()
                         if not chunk_bytes:
@@ -168,6 +172,12 @@ async def async_request_sglang_generate(
                             if ttft == 0.0:
                                 ttft = time.perf_counter() - st
                                 output.ttft = ttft
+                            prompt_tokens = (data.get("meta_info") or {}).get(
+                                "prompt_tokens", 0
+                            )
+                            cached_tokens = (data.get("meta_info") or {}).get(
+                                "cached_tokens", 0
+                            )
 
                         # Decoding phase
                         else:
@@ -179,6 +189,8 @@ async def async_request_sglang_generate(
                     output.generated_text = generated_text
                     output.success = True
                     output.latency = latency
+                    output.prompt_len = prompt_tokens
+                    output.cached_tokens = cached_tokens
                 else:
                     output.error = response.reason or ""
                     output.success = False
@@ -201,6 +213,7 @@ def gen_payload(prompt, output_len):
             "ignore_eos": True,
         },
         "stream": True,
+        "stream_options": {"include_usage": True},
         "lora_path": "",
         "return_logprob": False,
         "logprob_start_len": -1,
@@ -303,7 +316,12 @@ class WorkloadGenerator:
         self.response_queue = queue.Queue()
         self.pbar = tqdm(total=args.num_clients * args.num_rounds)
-        self.performance_metrics = {"ttft": [], "latency": []}
+        self.performance_metrics = {
+            "ttft": [],
+            "latency": [],
+            "prompt_len": [],
+            "cached_tokens": [],
+        }
 
     async def handle_request(self, item):
         try:
@@ -360,6 +378,8 @@ class WorkloadGenerator:
                 self.client_records[client_id]["round"] += 1
                 self.performance_metrics["ttft"].append(response.ttft)
                 self.performance_metrics["latency"].append(response.latency)
+                self.performance_metrics["prompt_len"].append(response.prompt_len)
+                self.performance_metrics["cached_tokens"].append(response.cached_tokens)
                 self.completed_requests += 1
 
                 if self.client_records[client_id]["round"] < args.num_rounds:
@@ -416,6 +436,12 @@ class WorkloadGenerator:
                     len(self.performance_metrics["latency"]) // 2
                 ],
                 "throughput": self.pbar.total / (self.finished_time - self.start_time),
+                "cache_hit_rate": (
+                    0
+                    if sum(self.performance_metrics["prompt_len"]) == 0
+                    else sum(self.performance_metrics["cached_tokens"])
+                    / sum(self.performance_metrics["prompt_len"])
+                ),
             },
         }
         print("All requests completed")
@@ -434,6 +460,7 @@ class WorkloadGenerator:
         print(
             f"  Throughput: {performance_data['summary']['throughput']:.2f} requests per second"
         )
+        print(f"  Cache Hit Rate: {performance_data['summary']['cache_hit_rate']:.6f}")
 
         log_to_jsonl_file(performance_data, args.log_file, tag=args.tag)
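The new cache_hit_rate summary field is just the total of cached_tokens divided by the total of prompt_len over all completed requests, guarded against an empty run. A minimal standalone sketch of the same calculation, with made-up numbers rather than real benchmark output:

    # Illustrative values only; the benchmark fills these lists from server responses
    prompt_len = [1024, 2048, 512]
    cached_tokens = [512, 2048, 0]

    total_prompt = sum(prompt_len)     # 3584
    total_cached = sum(cached_tokens)  # 2560
    cache_hit_rate = 0 if total_prompt == 0 else total_cached / total_prompt
    print(f"  Cache Hit Rate: {cache_hit_rate:.6f}")  # 0.714286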
python/sglang/srt/managers/cache_controller.py
@@ -16,6 +16,7 @@ limitations under the License.
 
 import logging
 import math
 import threading
+import time
 from queue import Empty, Full, PriorityQueue, Queue
 from typing import TYPE_CHECKING, List, Optional
@@ -195,6 +196,8 @@ class PrefetchOperation(StorageOperation):
         self._done_flag = False
         self._lock = threading.Lock()
 
+        self.start_time = time.monotonic()
+
         super().__init__(host_indices, token_ids, last_hash)
 
     def increment(self, num_tokens: int):
@@ -278,6 +281,12 @@ class HiCacheController:
             self.enable_storage = True
             # todo: threshold policy for prefetching
             self.prefetch_threshold = max(prefetch_threshold, self.page_size)
+            self.prefetch_capacity_limit = int(
+                0.8 * (self.mem_pool_host.size - self.mem_pool_device.size)
+            )
+            # tracking the number of tokens locked in prefetching, updated by the main scheduler thread
+            self.prefetch_tokens_occupied = 0
+
             # create a new communication group for synchronizing storage operations across TP workers
             self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
             if self.tp_world_size > 1:
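The capacity limit above caps how many tokens may be locked by in-flight prefetches: 80% of the difference between the host pool size and the device pool size. A small sketch with hypothetical pool sizes (the real values come from mem_pool_host and mem_pool_device):

    # Hypothetical sizes in tokens, for illustration only
    mem_pool_host_size = 1_000_000
    mem_pool_device_size = 200_000

    prefetch_capacity_limit = int(0.8 * (mem_pool_host_size - mem_pool_device_size))
    # -> 640000 tokens of headroom for prefetching

    # prefetch_rate_limit_check() refuses new prefetches once the occupied count reaches the limit
    prefetch_tokens_occupied = 650_000
    allow_new_prefetch = prefetch_tokens_occupied < prefetch_capacity_limit  # False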
@@ -525,7 +534,7 @@ class HiCacheController:
         host_indices: torch.Tensor,
         new_input_tokens: List[int],
         last_hash: Optional[str] = None,
-    ) -> int:
+    ) -> PrefetchOperation:
         """
        Prefetch KV caches from storage backend to host memory.
        """
@@ -586,11 +595,23 @@ class HiCacheController:
                 operation = self.prefetch_buffer.get(block=True, timeout=1)
                 if self.is_mooncake_backend():
                     self.mooncake_page_transfer(operation)
+                elif self.storage_backend_type == "hf3fs":
+                    self.generic_page_transfer(operation, batch_size=128)
                 else:
                     self.generic_page_transfer(operation)
             except Empty:
                 continue
 
+    def prefetch_rate_limit_check(self) -> bool:
+        """
+        Rate limit the prefetching operations to avoid overwhelming the storage backend.
+        """
+        # cancel prefetch if too much memory is occupied
+        if self.prefetch_tokens_occupied >= self.prefetch_capacity_limit:
+            return False
+
+        # todo: more sophisticated rate limiting based on storage backend performance
+        return True
+
     def prefetch_thread_func(self):
         """
         Manage prefetching operations from storage backend to host memory.
@@ -604,34 +625,36 @@ class HiCacheController:
             if operation is None:
                 continue
 
-            last_hash = operation.last_hash
-            tokens_to_fetch = operation.token_ids
-
             storage_hit_count = 0
-            remaining_tokens = len(tokens_to_fetch)
-            hash_value = []
-            while remaining_tokens >= self.page_size:
-                last_hash = self.get_hash_str(
-                    tokens_to_fetch[
-                        storage_hit_count : storage_hit_count + self.page_size
-                    ],
-                    last_hash,
-                )
-
-                # todo, more unified interface
-                if not self.is_mooncake_backend():
-                    if not self.storage_backend.exists(last_hash):
-                        break
-                hash_value.append(last_hash)
-                storage_hit_count += self.page_size
-                remaining_tokens -= self.page_size
-
-            if self.is_mooncake_backend():
-                # deferring to batch exists for mooncake store
-                exist_result = self.storage_backend.exists(hash_value)
-                storage_hit_count = (
-                    sum(1 for v in exist_result.values() if v != 0) * self.page_size
-                )
+            if self.prefetch_rate_limit_check():
+                last_hash = operation.last_hash
+                tokens_to_fetch = operation.token_ids
+                remaining_tokens = len(tokens_to_fetch)
+                hash_value = []
+                while remaining_tokens >= self.page_size:
+                    last_hash = self.get_hash_str(
+                        tokens_to_fetch[
+                            storage_hit_count : storage_hit_count + self.page_size
+                        ],
+                        last_hash,
+                    )
+
+                    # todo, more unified interface
+                    if not self.is_mooncake_backend():
+                        if not self.storage_backend.exists(last_hash):
+                            break
+                    hash_value.append(last_hash)
+                    storage_hit_count += self.page_size
+                    remaining_tokens -= self.page_size
+
+                if self.is_mooncake_backend():
+                    # deferring to batch exists for mooncake store
+                    exist_result = self.storage_backend.exists(hash_value)
+                    storage_hit_count = (
+                        sum(1 for v in exist_result.values() if v != 0) * self.page_size
+                    )
 
             if self.tp_world_size > 1:
                 storage_hit_count_tensor = torch.tensor(
@@ -750,6 +773,8 @@ class HiCacheController:
                 if self.is_mooncake_backend():
                     self.mooncake_page_backup(operation)
+                elif self.storage_backend_type == "hf3fs":
+                    self.generic_page_backup(operation, batch_size=128)
                 else:
                     self.generic_page_backup(operation)
python/sglang/srt/managers/scheduler.py
@@ -619,6 +619,7 @@ class Scheduler(
                 ),
                 hicache_mem_layout=server_args.hicache_mem_layout,
                 hicache_storage_backend=server_args.hicache_storage_backend,
+                hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
             )
             self.tp_worker.register_hicache_layer_transfer_counter(
                 self.tree_cache.cache_controller.layer_done_counter
@@ -1572,7 +1573,10 @@ class Scheduler(
                 break
 
             if self.enable_hicache_storage:
-                self.tree_cache.check_prefetch_progress(req.rid)
+                prefetch_done = self.tree_cache.check_prefetch_progress(req.rid)
+                if not prefetch_done:
+                    # skip staging requests that are ongoing prefetch
+                    continue
 
             req.init_next_round_input(self.tree_cache)
             res = adder.add_one_req(req, has_chunked_req=(self.chunked_req is not None))
python/sglang/srt/mem_cache/hiradix_cache.py
@@ -2,11 +2,12 @@ import heapq
 import logging
 import threading
 import time
 from queue import Queue
 from typing import List, Optional
 
 import torch
 
-from sglang.srt.managers.cache_controller import HiCacheController
+from sglang.srt.managers.cache_controller import HiCacheController, PrefetchOperation
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import MatchResult
 from sglang.srt.mem_cache.memory_pool import (
@@ -37,6 +38,7 @@ class HiRadixCache(RadixCache):
         hicache_io_backend: str,
         hicache_mem_layout: str,
         hicache_storage_backend: Optional[str] = None,
+        hicache_storage_prefetch_policy: Optional[str] = "best_effort",
     ):
         if hicache_io_backend == "direct":
@@ -85,6 +87,13 @@ class HiRadixCache(RadixCache):
                 prefetch_threshold=self.prefetch_threshold,
             )
+            self.prefetch_stop_policy = hicache_storage_prefetch_policy
+            # todo: customizable storage prefetch timeout
+            self.prefetch_timeout = 3  # seconds
+            logger.info(
+                f"HiCache storage prefetch policy: {hicache_storage_prefetch_policy}"
+            )
+
         # record the nodes with ongoing write through
         self.ongoing_write_through = {}
         # record the node segments with ongoing load back
@@ -385,9 +394,10 @@ class HiRadixCache(RadixCache):
         for _ in range(queue_size.item()):
             req_id = self.cache_controller.prefetch_revoke_queue.get()
             if req_id in self.ongoing_prefetch:
-                last_host_node, _, _, _ = self.ongoing_prefetch[req_id]
+                last_host_node, token_ids, _, _ = self.ongoing_prefetch[req_id]
                 last_host_node.release_host()
                 del self.ongoing_prefetch[req_id]
+                self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
             else:
                 # the revoked operation already got terminated
                 pass
@@ -419,10 +429,41 @@ class HiRadixCache(RadixCache):
             host_node.release_host()
             del self.ongoing_backup[ack_id]
 
-    def check_prefetch_progress(self, req_id: str):
+    def can_terminate_prefetch(self, operation: PrefetchOperation):
+        can_terminate = True
+
+        if self.prefetch_stop_policy == "best_effort":
+            return can_terminate
+
+        completed = (
+            operation.completed_tokens == len(operation.hash_value) * self.page_size
+        )
+
+        if self.prefetch_stop_policy == "wait_complete":
+            can_terminate = completed
+        elif self.prefetch_stop_policy == "timeout":
+            can_terminate = completed or (
+                time.monotonic() - operation.start_time > self.prefetch_timeout
+            )
+        else:
+            # unknown prefetch stop policy, just return True
+            return True
+
+        if self.tp_world_size > 1:
+            can_terminate = torch.tensor(can_terminate, dtype=torch.int)
+            torch.distributed.all_reduce(
+                can_terminate,
+                op=torch.distributed.ReduceOp.MIN,
+                group=self.tp_group,
+            )
+            can_terminate = bool(can_terminate.item())
+
+        return can_terminate
+
+    def check_prefetch_progress(self, req_id: str) -> bool:
         if req_id not in self.ongoing_prefetch:
             # there is no ongoing prefetch for this request or it has been revoked
-            return
+            return True
 
         # todo: more policies for prefetch progress such as timeout
         # the current policy is to prefetch with best effort and terminate when queuing is over
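Read together with the new server argument, can_terminate_prefetch above encodes the three stop policies: best_effort terminates as soon as the scheduler checks progress, wait_complete terminates only once every hashed page has landed in host memory, and timeout behaves like wait_complete but also terminates once prefetch_timeout seconds have passed since the operation started. A standalone sketch of that decision, using hypothetical arguments in place of the PrefetchOperation fields and omitting the TP all-reduce:

    import time

    def can_terminate(policy: str, completed_tokens: int, expected_tokens: int,
                      start_time: float, timeout_s: float = 3.0) -> bool:
        # Sketch of the policy branch only; not the sglang implementation
        if policy == "best_effort":
            return True
        completed = completed_tokens == expected_tokens
        if policy == "wait_complete":
            return completed
        if policy == "timeout":
            return completed or (time.monotonic() - start_time > timeout_s)
        return True  # unknown policy: terminate rather than stall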
@@ -430,13 +471,16 @@ class HiRadixCache(RadixCache):
             req_id
         ]
 
+        if not self.can_terminate_prefetch(operation):
+            return False
+
         completed_tokens, hash_value = self.cache_controller.terminate_prefetch(
             operation
         )
         logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")
 
         min_completed_tokens = completed_tokens
-        if self.tp_world_size > 1:
+        if self.tp_world_size > 1 and self.prefetch_stop_policy != "wait_complete":
             # synchrnoize TP workers to make the same update to hiradix cache
             completed_tokens_tensor = torch.tensor(
                 min_completed_tokens, dtype=torch.int
@@ -464,6 +508,9 @@ class HiRadixCache(RadixCache):
         )
 
         last_host_node.release_host()
         del self.ongoing_prefetch[req_id]
+        self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
+
+        return True
 
     def match_prefix(self, key: List[int], **kwargs):
         empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
@@ -531,6 +578,7 @@ class HiRadixCache(RadixCache):
             host_indices,
             operation,
         )
+        self.cache_controller.prefetch_tokens_occupied += len(new_input_tokens)
 
     def _insert_helper_host(self, node: TreeNode, key: List, host_value, hash_value):
         node.last_access_time = time.monotonic()
python/sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py
@@ -96,6 +96,8 @@ class Hf3fsClient:
         )
         self.iov_r = make_iovec(self.shm_r, self.hf3fs_mount_point)
         self.iov_w = make_iovec(self.shm_w, self.hf3fs_mount_point)
+        self.shm_r.unlink()
+        self.shm_w.unlink()
 
         self.rlock = threading.RLock()
         self.wlock = threading.RLock()
@@ -176,8 +178,6 @@ class Hf3fsClient:
         del self.iov_w
         self.shm_r.close()
         self.shm_w.close()
-        self.shm_r.unlink()
-        self.shm_w.unlink()
 
     def flush(self) -> None:
         os.fsync(self.file)
python/sglang/srt/server_args.py
@@ -203,6 +203,7 @@ class ServerArgs:
     hicache_io_backend: str = "kernel"
     hicache_mem_layout: str = "layer_first"
     hicache_storage_backend: Optional[str] = None
+    hicache_storage_prefetch_policy: str = "best_effort"
 
     # Double Sparsity
     enable_double_sparsity: bool = False
@@ -1626,6 +1627,13 @@ class ServerArgs:
             default=ServerArgs.hicache_storage_backend,
             help="The storage backend for hierarchical KV cache.",
         )
+        parser.add_argument(
+            "--hicache-storage-prefetch-policy",
+            type=str,
+            choices=["best_effort", "wait_complete", "timeout"],
+            default=ServerArgs.hicache_storage_prefetch_policy,
+            help="Control when prefetching from the storage backend should stop.",
+        )
 
         # Double Sparsity
         parser.add_argument(
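Assuming the usual sglang launch entry point, the new policy would be passed like any other hierarchical-cache flag; a hypothetical invocation (the model path and the hf3fs backend choice are placeholders, not taken from this diff):

    python -m sglang.launch_server \
        --model-path <your-model> \
        --hicache-storage-backend hf3fs \
        --hicache-storage-prefetch-policy timeout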