Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
6b634493
Unverified
Commit
6b634493
authored
Nov 02, 2025
by
hzh0425
Committed by
GitHub
Nov 01, 2025
Browse files
[HICache / PD]: Support offloading incremental KV cache in decode side. (#11966)
parent
756ad9ce
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
71 additions
and
28 deletions
+71
-28
python/sglang/srt/disaggregation/decode_kvcache_offload_manager.py
...lang/srt/disaggregation/decode_kvcache_offload_manager.py
+71
-28
No files found.
python/sglang/srt/disaggregation/decode_kvcache_offload_manager.py
View file @
6b634493
...
@@ -60,6 +60,7 @@ class DecodeKVCacheOffloadManager:
...
@@ -60,6 +60,7 @@ class DecodeKVCacheOffloadManager:
self
.
tp_group
=
tp_group
self
.
tp_group
=
tp_group
self
.
tp_world_size
=
torch
.
distributed
.
get_world_size
(
group
=
self
.
tp_group
)
self
.
tp_world_size
=
torch
.
distributed
.
get_world_size
(
group
=
self
.
tp_group
)
self
.
cache_controller
=
HiCacheController
(
self
.
cache_controller
=
HiCacheController
(
token_to_kv_pool_allocator
=
self
.
token_to_kv_pool_allocator
,
token_to_kv_pool_allocator
=
self
.
token_to_kv_pool_allocator
,
mem_pool_host
=
self
.
decode_host_mem_pool
,
mem_pool_host
=
self
.
decode_host_mem_pool
,
...
@@ -77,41 +78,59 @@ class DecodeKVCacheOffloadManager:
...
@@ -77,41 +78,59 @@ class DecodeKVCacheOffloadManager:
logger
.
info
(
"Enable offload kv cache for decode side"
)
logger
.
info
(
"Enable offload kv cache for decode side"
)
def
offload_kv_cache
(
self
,
req
)
->
bool
:
def
offload_kv_cache
(
self
,
req
)
->
bool
:
"""Offload
a finished request's KV cache to storag
e."""
"""Offload
incremental KV cache for decode sid
e."""
if
self
.
cache_controller
is
None
or
self
.
decode_host_mem_pool
is
None
:
if
self
.
cache_controller
is
None
or
self
.
decode_host_mem_pool
is
None
:
return
False
return
False
if
req
.
req_pool_idx
==
-
1
:
if
req
.
req_pool_idx
==
-
1
or
len
(
req
.
output_ids
)
==
0
:
return
False
return
False
token_indices
=
self
.
req_to_token_pool
.
req_to_token
[
req
.
req_pool_idx
]
token_indices
=
self
.
req_to_token_pool
.
req_to_token
[
req
.
req_pool_idx
]
if
token_indices
.
dim
()
==
0
or
token_indices
.
numel
()
==
0
:
if
token_indices
.
dim
()
==
0
or
token_indices
.
numel
()
==
0
:
logger
.
debug
(
f
"Request
{
req
.
rid
}
has invalid token_indices:
{
token_indices
}
"
)
return
False
return
False
tokens
=
req
.
origin_input_ids
+
req
.
output_ids
# Prefill side offloads page-aligned origin_input_ids, decode side offloads the incremental part
aligned_len
=
(
len
(
tokens
)
//
self
.
page_size
)
*
self
.
page_size
all_tokens
=
req
.
origin_input_ids
+
req
.
output_ids
[:
-
1
]
if
aligned_len
==
0
:
prefill_offloaded_len
=
(
len
(
req
.
origin_input_ids
)
//
self
.
page_size
*
self
.
page_size
)
incremental_len
=
len
(
all_tokens
)
-
prefill_offloaded_len
incremental_aligned_len
=
incremental_len
//
self
.
page_size
*
self
.
page_size
if
incremental_aligned_len
==
0
:
return
False
return
False
token_indices
=
token_indices
[:
aligned_len
]
# Extract incremental tokens and indices
tokens
=
tokens
[:
aligned_len
]
start
,
end
=
(
prefill_offloaded_len
,
prefill_offloaded_len
+
incremental_aligned_len
,
)
incremental_tokens
=
all_tokens
[
start
:
end
]
incremental_indices
=
token_indices
[
start
:
end
]
# Early free prefill-offloaded GPU memory
if
prefill_offloaded_len
>
0
:
self
.
token_to_kv_pool_allocator
.
free
(
token_indices
[:
prefill_offloaded_len
])
# Asynchronously offload KV cache from device to host
by cache controller
# Asynchronously offload
incremental
KV cache from device to host
self
.
request_counter
+=
1
self
.
request_counter
+=
1
ack_id
=
self
.
request_counter
ack_id
=
self
.
request_counter
host_indices
=
self
.
cache_controller
.
write
(
host_indices
=
self
.
cache_controller
.
write
(
device_indices
=
token
_indices
.
long
(),
device_indices
=
incremental
_indices
.
long
(),
node_id
=
ack_id
,
node_id
=
ack_id
,
)
)
if
host_indices
is
None
:
if
host_indices
is
None
:
logger
.
error
(
f
"Not enough host memory for request
{
req
.
rid
}
"
)
logger
.
error
(
f
"Not enough host memory for request
{
req
.
rid
}
"
)
return
False
return
False
self
.
ongoing_offload
[
ack_id
]
=
(
req
,
host_indices
,
tokens
,
time
.
time
())
self
.
ongoing_offload
[
ack_id
]
=
(
req
,
host_indices
,
incremental_tokens
,
time
.
time
(),
prefill_offloaded_len
,
)
return
True
return
True
def
check_offload_progress
(
self
):
def
check_offload_progress
(
self
):
...
@@ -140,14 +159,33 @@ class DecodeKVCacheOffloadManager:
...
@@ -140,14 +159,33 @@ class DecodeKVCacheOffloadManager:
_
,
finish_event
,
ack_list
=
self
.
cache_controller
.
ack_write_queue
.
pop
(
0
)
_
,
finish_event
,
ack_list
=
self
.
cache_controller
.
ack_write_queue
.
pop
(
0
)
finish_event
.
synchronize
()
finish_event
.
synchronize
()
for
ack_id
in
ack_list
:
for
ack_id
in
ack_list
:
req
,
host_indices
,
tokens
,
start_time
=
self
.
ongoing_offload
.
pop
(
ack_id
)
(
req
,
host_indices
,
incremental_tokens
,
start_time
,
prefill_offloaded_len
,
)
=
self
.
ongoing_offload
.
pop
(
ack_id
)
self
.
_release_finished_req
(
req
,
prefill_offloaded_len
)
self
.
_trigger_backup
(
req
,
host_indices
,
incremental_tokens
,
start_time
,
prefill_offloaded_len
,
)
finish_count
-=
1
# Release device
def
_release_finished_req
(
self
,
req
,
prefill_offloaded_len
):
self
.
tree_cache
.
cache_finished_req
(
req
)
kv_indices
=
self
.
req_to_token_pool
.
req_to_token
[
req
.
req_pool_idx
,
:
len
(
req
.
origin_input_ids
)
+
max
(
len
(
req
.
output_ids
)
-
1
,
0
),
]
#
Trigger async backup from host to storage by cache controller
#
Free the incremental part of the request
self
.
_trigger_backup
(
req
.
rid
,
host_indices
,
tokens
,
start_time
)
self
.
token_to_kv_pool_allocator
.
free
(
kv_indices
[
prefill_offloaded_len
:]
)
finish_count
-=
1
self
.
req_to_token_pool
.
free
(
req
.
req_pool_idx
)
def
_check_backup_progress
(
self
,
finish_count
):
def
_check_backup_progress
(
self
,
finish_count
):
"""Check the progress of backup from host to storage."""
"""Check the progress of backup from host to storage."""
...
@@ -159,25 +197,30 @@ class DecodeKVCacheOffloadManager:
...
@@ -159,25 +197,30 @@ class DecodeKVCacheOffloadManager:
# Release host memory
# Release host memory
self
.
decode_host_mem_pool
.
free
(
host_indices
)
self
.
decode_host_mem_pool
.
free
(
host_indices
)
logger
.
debug
(
logger
.
info
(
f
"Finished backup request
{
req_id
}
, free host memory, len:
{
len
(
host_indices
)
}
, cost time:
{
time
.
time
()
-
start_time
:.
2
f
}
seconds."
f
"Finished backup request
{
req_id
}
, free host memory, len:
{
len
(
host_indices
)
}
, cost time:
{
time
.
time
()
-
start_time
:.
2
f
}
seconds."
)
)
def
_trigger_backup
(
self
,
req_id
,
host_indices
,
tokens
,
start_time
):
def
_trigger_backup
(
"""Trigger async backup from host to storage by cache controller."""
self
,
req
,
host_indices
,
incremental_tokens
,
start_time
,
prefill_offloaded_len
):
"""Trigger async backup from host to storage."""
prefill_hashes
=
self
.
_compute_prefix_hash
(
req
.
origin_input_ids
[:
prefill_offloaded_len
]
)
last_prefill_hash
=
prefill_hashes
[
-
1
]
if
prefill_offloaded_len
>
0
else
""
# Generate page hashes and write to storage
page_hashes
=
self
.
_compute_prefix_hash
(
incremental_tokens
,
last_prefill_hash
)
page_hashes
=
self
.
_compute_prefix_hash
(
tokens
)
ack_id
=
self
.
cache_controller
.
write_storage
(
ack_id
=
self
.
cache_controller
.
write_storage
(
host_indices
,
host_indices
,
tokens
,
incremental_
tokens
,
hash_value
=
page_hashes
,
hash_value
=
page_hashes
,
)
)
self
.
ongoing_backup
[
ack_id
]
=
(
req
_
id
,
host_indices
,
start_time
)
self
.
ongoing_backup
[
ack_id
]
=
(
req
.
r
id
,
host_indices
,
start_time
)
def
_compute_prefix_hash
(
self
,
tokens
):
def
_compute_prefix_hash
(
self
,
tokens
,
prior_hash
=
""
):
last_hash
=
""
page_hashes
=
[]
page_hashes
=
[]
last_hash
=
prior_hash
for
offset
in
range
(
0
,
len
(
tokens
),
self
.
page_size
):
for
offset
in
range
(
0
,
len
(
tokens
),
self
.
page_size
):
page_tokens
=
tokens
[
offset
:
offset
+
self
.
page_size
]
page_tokens
=
tokens
[
offset
:
offset
+
self
.
page_size
]
last_hash
=
self
.
cache_controller
.
get_hash_str
(
page_tokens
,
last_hash
)
last_hash
=
self
.
cache_controller
.
get_hash_str
(
page_tokens
,
last_hash
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment