Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5ef33ab2
Unverified
Commit
5ef33ab2
authored
Apr 23, 2026
by
Or Ozeri
Committed by
GitHub
Apr 23, 2026
Browse files
[kv_offload+HMA][10/N]: Support load with multiple KV groups (#39402)
Signed-off-by:
Or Ozeri
<
oro@il.ibm.com
>
parent
1c2c1eb8
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
51 additions
and
29 deletions
+51
-29
vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py
...buted/kv_transfer/kv_connector/v1/offloading/scheduler.py
+51
-29
No files found.
vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py
View file @
5ef33ab2
...
...
@@ -14,6 +14,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.offloading.common import (
ReqId
,
)
from
vllm.logger
import
init_logger
from
vllm.utils.math_utils
import
cdiv
from
vllm.v1.core.kv_cache_manager
import
KVCacheBlocks
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.kv_offload.abstract
import
(
...
...
@@ -271,45 +272,66 @@ class OffloadingConnectorScheduler:
return
req_status
=
self
.
_req_status
[
request
.
request_id
]
block_groups
=
blocks
.
get_block_ids
()
# Below assertions will be removed once this function supports HMA
assert
len
(
self
.
config
.
kv_group_configs
)
==
1
assert
len
(
req_status
.
group_states
)
==
1
assert
len
(
block_groups
)
==
1
block_ids
=
block_groups
[
0
]
group_config
=
self
.
config
.
kv_group_configs
[
0
]
group_state
=
req_status
.
group_states
[
0
]
num_computed_gpu_blocks
=
sum
(
block
.
block_hash
is
not
None
for
block
in
blocks
.
blocks
[
0
]
)
num_computed_tokens
=
num_computed_gpu_blocks
*
group_config
.
gpu_block_size
full_block_tokens
=
num_computed_tokens
+
num_external_tokens
assert
full_block_tokens
%
group_config
.
offloaded_block_size
==
0
num_locally_computed_tokens
=
req_status
.
num_locally_computed_tokens
num_cached_tokens
=
num_locally_computed_tokens
+
num_external_tokens
keys_to_load
:
list
[
OffloadKey
]
=
[]
dst_block_ids
:
list
[
int
]
=
[]
# per group
group_sizes
:
list
[
int
]
=
[]
block_indices
:
list
[
int
]
=
[]
for
group_config
,
group_state
,
group_blocks
in
zip
(
self
.
config
.
kv_group_configs
,
req_status
.
group_states
,
blocks
.
blocks
,
):
gpu_block_size
=
group_config
.
gpu_block_size
offloaded_block_size
=
group_config
.
offloaded_block_size
offload_keys
=
group_state
.
offload_keys
num_gpu_blocks
=
cdiv
(
num_cached_tokens
,
gpu_block_size
)
assert
len
(
group_blocks
)
>=
num_gpu_blocks
num_locally_computed_gpu_blocks
=
num_gpu_blocks
# Skip null placeholder blocks (used for sliding window or mamba padding).
for
i
,
block
in
enumerate
(
group_blocks
[:
num_gpu_blocks
]):
if
not
block
.
is_null
and
block
.
block_hash
is
None
:
num_locally_computed_gpu_blocks
=
i
break
assert
(
num_locally_computed_tokens
<=
num_locally_computed_gpu_blocks
*
gpu_block_size
)
num_pending_gpu_blocks
=
num_gpu_blocks
-
num_locally_computed_gpu_blocks
num_pending_gpu_blocks
=
len
(
block_ids
)
-
num_computed_gpu_blocks
assert
(
num_external_tokens
==
num_pending_gpu_blocks
*
group_config
.
gpu_block_size
)
num_blocks
=
cdiv
(
num_cached_tokens
,
offloaded_block_size
)
assert
len
(
offload_keys
)
>=
num_blocks
if
num_pending_gpu_blocks
:
start_block_idx
=
(
num_locally_computed_gpu_blocks
//
self
.
config
.
block_size_factor
)
keys_to_load
.
extend
(
offload_keys
[
start_block_idx
:
num_blocks
])
start_block_idx
=
num_computed_tokens
//
group_config
.
offloaded_block_size
num_blocks
=
full_block_tokens
//
group_config
.
offloaded_block_size
dst_block_ids
.
extend
(
block
.
block_id
for
block
in
group_blocks
[
num_locally_computed_gpu_blocks
:
num_gpu_blocks
]
)
group_sizes
.
append
(
num_pending_gpu_blocks
)
block_indices
.
append
(
num_locally_computed_gpu_blocks
)
assert
len
(
request
.
block_hashes
)
//
self
.
config
.
block_size_factor
>=
num_blocks
offload_keys
=
group_state
.
offload_keys
[
start_block_idx
:
num_blocks
]
group_state
.
next_stored_block_idx
=
num_blocks
src_spec
=
self
.
manager
.
prepare_load
(
offload_keys
,
req_status
.
req_context
)
src_spec
=
self
.
manager
.
prepare_load
(
keys_to_load
,
req_status
.
req_context
)
dst_spec
=
GPULoadStoreSpec
(
block_ids
[
num_computed_gpu_blocks
:],
group_sizes
=
(
num_pending_gpu_blocks
,),
block_indices
=
(
num_computed_gpu_blocks
,),
dst_block_ids
,
group_sizes
=
group_sizes
,
block_indices
=
block_indices
)
self
.
_reqs_to_load
[
request
.
request_id
]
=
(
src_spec
,
dst_spec
)
req_blocks_being_loaded
=
self
.
_reqs_being_loaded
[
request
.
request_id
]
req_blocks_being_loaded
.
update
(
offload_keys
)
group_state
.
next_stored_block_idx
=
num_blocks
req_blocks_being_loaded
.
update
(
keys_to_load
)
if
self
.
_blocks_being_loaded
is
not
None
:
self
.
_blocks_being_loaded
.
update
(
req_blocks_being_loaded
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment