Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
60cd878a
Unverified
Commit
60cd878a
authored
Apr 25, 2026
by
Or Ozeri
Committed by
GitHub
Apr 25, 2026
Browse files
[kv_offload+HMA][11/N]: Support store with multiple KV groups (#39403)
Signed-off-by:
Or Ozeri
<
oro@il.ibm.com
>
parent
1e9f19ca
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
87 additions
and
42 deletions
+87
-42
vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py
...buted/kv_transfer/kv_connector/v1/offloading/scheduler.py
+87
-42
No files found.
vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py
View file @
60cd878a
...
@@ -112,6 +112,13 @@ class RequestOffloadState:
...
@@ -112,6 +112,13 @@ class RequestOffloadState:
for
group_state
,
new_blocks
in
zip
(
self
.
group_states
,
new_block_id_groups
):
for
group_state
,
new_blocks
in
zip
(
self
.
group_states
,
new_block_id_groups
):
group_state
.
block_ids
.
extend
(
new_blocks
)
group_state
.
block_ids
.
extend
(
new_blocks
)
def
advance_stored_idx
(
self
,
num_offloadable_tokens
:
int
)
->
None
:
for
group_config
,
group_state
in
zip
(
self
.
config
.
kv_group_configs
,
self
.
group_states
):
num_blocks
=
num_offloadable_tokens
//
group_config
.
offloaded_block_size
group_state
.
next_stored_block_idx
=
num_blocks
class
OffloadingConnectorScheduler
:
class
OffloadingConnectorScheduler
:
"""Implementation of Scheduler side methods"""
"""Implementation of Scheduler side methods"""
...
@@ -367,16 +374,16 @@ class OffloadingConnectorScheduler:
...
@@ -367,16 +374,16 @@ class OffloadingConnectorScheduler:
if
self
.
_blocks_being_loaded
is
not
None
:
if
self
.
_blocks_being_loaded
is
not
None
:
self
.
_blocks_being_loaded
.
update
(
req_blocks_being_loaded
)
self
.
_blocks_being_loaded
.
update
(
req_blocks_being_loaded
)
def
_get_reqs_to_store
(
self
,
scheduler_output
:
SchedulerOutput
):
def
_get_reqs_to_store
(
# Below assertion will be removed once this function supports HMA
self
,
scheduler_output
:
SchedulerOutput
assert
len
(
self
.
config
.
kv_group_configs
)
==
1
)
->
dict
[
ReqId
,
TransferSpec
]:
group_config
=
self
.
config
.
kv_group_configs
[
0
]
block_size_factor
=
self
.
config
.
block_size_factor
reqs_to_store
:
dict
[
ReqId
,
TransferSpec
]
=
{}
reqs_to_store
:
dict
[
ReqId
,
TransferSpec
]
=
{}
# iterate over both new and cached requests
# iterate over both new and cached requests
for
req_id
,
new_block_id_groups
,
preempted
in
yield_req_data
(
scheduler_output
):
for
req_id
,
new_block_id_groups
,
preempted
in
yield_req_data
(
scheduler_output
):
req_status
=
self
.
_req_status
[
req_id
]
req_status
=
self
.
_req_status
[
req_id
]
req_status
.
update_offload_keys
()
req_status
.
update_offload_keys
()
req
=
req_status
.
req
if
preempted
:
if
preempted
:
for
group_state
in
req_status
.
group_states
:
for
group_state
in
req_status
.
group_states
:
...
@@ -385,68 +392,106 @@ class OffloadingConnectorScheduler:
...
@@ -385,68 +392,106 @@ class OffloadingConnectorScheduler:
if
new_block_id_groups
:
if
new_block_id_groups
:
req_status
.
update_block_id_groups
(
new_block_id_groups
)
req_status
.
update_block_id_groups
(
new_block_id_groups
)
# Below assertion will be removed once this function supports HMA
num_scheduled_tokens
=
scheduler_output
.
num_scheduled_tokens
[
req_id
]
assert
len
(
req_status
.
group_states
)
==
1
num_tokens_after_batch
=
req
.
num_computed_tokens
+
num_scheduled_tokens
group_state
=
req_status
.
group_states
[
0
]
block_ids
=
group_state
.
block_ids
req
=
req_status
.
req
new_tokens
=
scheduler_output
.
num_scheduled_tokens
[
req_id
]
expected_tokens
=
req
.
num_computed_tokens
+
new_tokens
# with async scheduling, some tokens may be missing
# with async scheduling, some tokens may be missing
total_tokens
=
min
(
expected_tokens
,
req
.
num_tokens
)
num_offloadable_tokens
=
min
(
num_tokens_after_batch
,
req
.
num_tokens
)
num_blocks
=
total_tokens
//
group_config
.
offloaded_block_size
start_block_idx
=
group_state
.
next_stored_block_idx
num_new_blocks
=
num_blocks
-
start_block_idx
if
num_new_blocks
<=
0
:
# Filter out blocks skipped due to sliding window attention / SSM
new_offload_keys
:
list
[
OffloadKey
]
=
[]
for
group_config
,
group_state
in
zip
(
self
.
config
.
kv_group_configs
,
req_status
.
group_states
):
num_blocks
=
num_offloadable_tokens
//
group_config
.
offloaded_block_size
start_block_idx
=
group_state
.
next_stored_block_idx
if
num_blocks
<=
start_block_idx
:
continue
continue
offload_keys
=
group_state
.
offload_keys
[
start_block_idx
:
num_blocks
]
# For each block to offload, take the last corresponding GPU block.
# e.g. if block size factor is 3 and GPU block IDs are
# 1 5 6 7 2 4 9 3 8 then we'll take blocks 6 4 8.
# We will use these GPU blocks to determine if the block needs
# offloading, or (if the GPU block ID is 0) this block should
# be skipped due to sliding window attention / SSM.
# We know that if a block is skipped, then all the previous blocks
# are skipped as well. This is why we take the last of each block.
offload_block_ids
=
group_state
.
block_ids
[
start_block_idx
*
block_size_factor
+
block_size_factor
-
1
:
num_blocks
*
block_size_factor
:
block_size_factor
]
assert
len
(
offload_keys
)
==
len
(
offload_block_ids
)
num_gpu_blocks
=
num_blocks
*
self
.
config
.
block_size_factor
for
offload_key
,
block_id
in
zip
(
offload_keys
,
offload_block_ids
):
assert
len
(
req
.
block_hashes
)
>=
num_gpu_blocks
if
block_id
!=
0
:
new_offload_keys
.
append
(
offload_key
)
if
not
new_offload_keys
:
req_status
.
advance_stored_idx
(
num_offloadable_tokens
)
continue
new_offload_keys
=
group_state
.
offload_keys
[
start_block_idx
:
num_blocks
]
store_output
=
self
.
manager
.
prepare_store
(
store_output
=
self
.
manager
.
prepare_store
(
new_offload_keys
,
req_status
.
req_context
new_offload_keys
,
req_status
.
req_context
)
)
if
store_output
is
None
:
if
store_output
is
None
:
logger
.
warning
(
logger
.
warning
(
"Request %s: cannot store blocks"
,
req_id
)
"Request %s: cannot store %s blocks"
,
req_id
,
num_new_blocks
)
continue
continue
group_state
.
next_stored_block_idx
=
num_blocks
if
not
store_output
.
keys_to_store
:
if
not
store_output
.
keys_to_store
:
req_status
.
advance_stored_idx
(
num_offloadable_tokens
)
continue
continue
keys_to_store
=
set
(
store_output
.
keys_to_store
)
self
.
manager
.
touch
(
group_state
.
offload_keys
[:
num_blocks
])
for
group_state
in
req_status
.
group_states
:
self
.
manager
.
touch
(
group_state
.
offload_keys
)
dst_spec
=
store_output
.
store_spec
keys_to_store
=
set
(
store_output
.
keys_to_store
)
group_sizes
:
list
[
int
]
=
[]
block_indices
:
list
[
int
]
=
[]
src_block_ids
:
list
[
int
]
=
[]
src_block_ids
:
list
[
int
]
=
[]
for
idx
,
key
in
enumerate
(
new_offload_keys
):
for
group_config
,
group_state
in
zip
(
if
key
not
in
keys_to_store
:
self
.
config
.
kv_group_configs
,
req_status
.
group_states
):
num_blocks
=
num_offloadable_tokens
//
group_config
.
offloaded_block_size
start_block_idx
=
group_state
.
next_stored_block_idx
block_ids
=
group_state
.
block_ids
num_group_blocks
=
0
start_gpu_block_idx
:
int
|
None
=
None
for
idx
,
offload_key
in
enumerate
(
group_state
.
offload_keys
[
start_block_idx
:
num_blocks
]
):
if
offload_key
not
in
keys_to_store
:
continue
continue
offloaded_block_idx
=
start_block_idx
+
idx
offloaded_block_idx
=
start_block_idx
+
idx
gpu_block_idx
=
offloaded_block_idx
*
self
.
config
.
block_size_factor
gpu_block_idx
=
offloaded_block_idx
*
block_size_factor
for
i
in
range
(
self
.
config
.
block_size_factor
):
num_group_blocks
+=
block_size_factor
src_block_ids
.
append
(
block_ids
[
gpu_block_idx
+
i
])
for
i
in
range
(
block_size_factor
):
block_id
=
block_ids
[
gpu_block_idx
+
i
]
if
block_id
==
0
:
# skipped blocks cannot appear after non-skipped blocks
assert
start_gpu_block_idx
is
None
continue
elif
start_gpu_block_idx
is
None
:
start_gpu_block_idx
=
gpu_block_idx
+
i
src_block_ids
.
append
(
block_id
)
group_sizes
.
append
(
num_group_blocks
)
block_indices
.
append
(
start_gpu_block_idx
or
0
)
group_state
.
next_stored_block_idx
=
num_blocks
src_spec
=
GPULoadStoreSpec
(
src_spec
=
GPULoadStoreSpec
(
src_block_ids
,
src_block_ids
,
group_sizes
=
group_sizes
,
block_indices
=
block_indices
group_sizes
=
(
len
(
src_block_ids
),),
block_indices
=
(
0
,),
)
)
dst_spec
=
store_output
.
store_spec
reqs_to_store
[
req_id
]
=
(
src_spec
,
dst_spec
)
reqs_to_store
[
req_id
]
=
(
src_spec
,
dst_spec
)
self
.
_reqs_being_stored
[
req_id
]
|=
keys_to_store
self
.
_reqs_being_stored
[
req_id
]
|=
keys_to_store
logger
.
debug
(
logger
.
debug
(
"Request %s offloading %s blocks
starting from block #%d
"
,
"Request %s offloading %s blocks
upto %d tokens
"
,
req_id
,
req_id
,
len
(
keys_to_store
),
len
(
keys_to_store
),
start_block_idx
,
num_offloadable_tokens
,
)
)
return
reqs_to_store
return
reqs_to_store
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment