Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
51adca74
Unverified
Commit
51adca74
authored
Apr 24, 2026
by
Or Ozeri
Committed by
GitHub
Apr 24, 2026
Browse files
[kv_offload+HMA][9/N]: Support lookup with multiple KV groups (#39401)
Signed-off-by:
Or Ozeri
<
oro@il.ibm.com
>
parent
e8eb0490
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
72 additions
and
41 deletions
+72
-41
vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py
...buted/kv_transfer/kv_connector/v1/offloading/scheduler.py
+72
-41
No files found.
vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py
View file @
51adca74
...
@@ -120,6 +120,13 @@ class OffloadingConnectorScheduler:
...
@@ -120,6 +120,13 @@ class OffloadingConnectorScheduler:
self
.
config
=
SchedulerOffloadConfig
.
from_spec
(
spec
)
self
.
config
=
SchedulerOffloadConfig
.
from_spec
(
spec
)
self
.
manager
:
OffloadingManager
=
spec
.
get_manager
()
self
.
manager
:
OffloadingManager
=
spec
.
get_manager
()
attention_groups
:
list
[
int
]
=
[]
for
idx
,
_
in
enumerate
(
spec
.
kv_cache_config
.
kv_cache_groups
):
# currently treat all groups as full attention
attention_groups
.
append
(
idx
)
self
.
lookup_groups
=
attention_groups
self
.
_req_status
:
dict
[
ReqId
,
RequestOffloadState
]
=
{}
self
.
_req_status
:
dict
[
ReqId
,
RequestOffloadState
]
=
{}
# requests to load for the current scheduler step
# requests to load for the current scheduler step
self
.
_reqs_to_load
:
dict
[
ReqId
,
TransferSpec
]
=
{}
self
.
_reqs_to_load
:
dict
[
ReqId
,
TransferSpec
]
=
{}
...
@@ -204,64 +211,88 @@ class OffloadingConnectorScheduler:
...
@@ -204,64 +211,88 @@ class OffloadingConnectorScheduler:
group_state
.
block_ids
.
clear
()
group_state
.
block_ids
.
clear
()
else
:
else
:
req_status
=
RequestOffloadState
(
config
=
self
.
config
,
req
=
request
)
req_status
=
RequestOffloadState
(
config
=
self
.
config
,
req
=
request
)
req_status
.
update_offload_keys
()
self
.
_req_status
[
request
.
request_id
]
=
req_status
self
.
_req_status
[
request
.
request_id
]
=
req_status
req_status
.
update_offload_keys
()
req_status
.
num_locally_computed_tokens
=
num_computed_tokens
req_status
.
num_locally_computed_tokens
=
num_computed_tokens
# Below assertions will be removed once this function supports HMA
for
gs
in
req_status
.
group_states
:
assert
len
(
self
.
config
.
kv_group_configs
)
==
1
self
.
manager
.
touch
(
gs
.
offload_keys
)
assert
len
(
req_status
.
group_states
)
==
1
group_config
=
self
.
config
.
kv_group_configs
[
0
]
group_state
=
req_status
.
group_states
[
0
]
num_blocks
=
request
.
num_tokens
//
group_config
.
offloaded_block_size
# Start with the full request size as the maximum loadable
max_hit_size_tokens
:
int
=
req_status
.
req
.
num_tokens
num_hit_tokens
:
int
=
0
defer_lookup
=
False
delay_request
=
False
for
group_idx
in
self
.
lookup_groups
:
group_config
:
GroupOffloadConfig
=
self
.
config
.
kv_group_configs
[
group_idx
]
offloaded_block_size
=
group_config
.
offloaded_block_size
offload_keys
=
req_status
.
group_states
[
group_idx
].
offload_keys
assert
len
(
request
.
block_hashes
)
//
self
.
config
.
block_size_factor
==
num
_block
s
num_blocks
=
max_hit_size_tokens
//
offloaded
_block
_size
offload_keys
=
group_state
.
offload_key
s
assert
len
(
offload_keys
)
>=
num_block
s
self
.
manager
.
touch
(
offload_keys
)
# Constrain to block-aligned boundary for this group
max_hit_size_tokens
=
num_blocks
*
offloaded_block_size
num_hit_tokens
=
max_hit_size_tokens
-
num_computed_tokens
if
num_hit_tokens
<
offloaded_block_size
:
# we can only load less than a block, better skip
return
0
,
False
start_block_idx
=
num_computed_tokens
//
offloaded_block_size
offload_keys
=
offload_keys
[
start_block_idx
:
num_blocks
]
# Full attention relies on all previous KV cache blocks.
# Thus, we search for a maximal prefix of KV cache which are all cached.
block_hits
=
self
.
_maximal_prefix_lookup
(
offload_keys
,
req_status
.
req_context
)
if
block_hits
==
0
:
return
0
,
False
full_block_tokens
=
group_config
.
offloaded_block_size
*
num_blocks
if
block_hits
is
None
:
if
full_block_tokens
-
num_computed_tokens
<
group_config
.
offloaded_block_size
:
defer_lookup
=
True
# we can load less than a block, skip
else
:
return
0
,
False
# Further constrain based on what's actually available by backend
max_hit_size_tokens
=
offloaded_block_size
*
(
start_block_idx
+
block_hits
)
start_block_idx
=
num_computed_tokens
//
group_config
.
offloaded_block_size
num_hit_tokens
=
max_hit_size_tokens
-
num_computed_tokens
# Full attention relays on all previous KV cache blocks.
if
num_hit_tokens
<
offloaded_block_size
:
# Thus, we search for a maximal prefix of KV cache which are all cached.
# we can only load less than a block, better skip
hits
=
self
.
_maximal_prefix_lookup
(
return
0
,
False
offload_keys
[
start_block_idx
:],
req_status
.
req_context
)
if
(
if
hits
is
None
:
block_hits
# indicates a lookup that should be tried later
and
self
.
_blocks_being_loaded
and
any
(
key
in
self
.
_blocks_being_loaded
for
key
in
offload_keys
[:
block_hits
]
)
):
# hit blocks are being loaded, delay request
delay_request
=
True
if
defer_lookup
:
logger
.
debug
(
"Offloading manager delayed request %s as backend requested"
,
req_status
.
req
.
request_id
,
)
return
None
,
False
if
delay_request
:
logger
.
debug
(
"Delaying request %s since some of its blocks are already being loaded"
,
req_status
.
req
.
request_id
,
)
return
None
,
False
return
None
,
False
if
hits
==
0
:
return
0
,
False
num_hit_tokens
=
(
group_config
.
offloaded_block_size
*
(
start_block_idx
+
hits
)
-
num_computed_tokens
)
logger
.
debug
(
logger
.
debug
(
"Request %s hit %s offloaded tokens after %s GPU hit tokens"
,
"Request %s hit %s offloaded tokens after %s GPU hit tokens"
,
request
.
request_id
,
request
.
request_id
,
num_hit_tokens
,
num_hit_tokens
,
num_computed_tokens
,
num_computed_tokens
,
)
)
if
num_hit_tokens
<
group_config
.
offloaded_block_size
:
return
0
,
False
if
self
.
_blocks_being_loaded
and
any
(
key
in
self
.
_blocks_being_loaded
for
key
in
offload_keys
[
start_block_idx
:
start_block_idx
+
hits
]
):
# hit blocks are being loaded, delay request
logger
.
debug
(
"Delaying request %s since some of its blocks are already being loaded"
,
request
.
request_id
,
)
return
None
,
False
return
num_hit_tokens
,
True
return
num_hit_tokens
,
True
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment