Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7013e9ac
Unverified
Commit
7013e9ac
authored
Jan 21, 2026
by
Or Ozeri
Committed by
GitHub
Jan 21, 2026
Browse files
OffloadingConnector: Prevent redundant loads (#29087)
Signed-off-by:
Or Ozeri
<
oro@il.ibm.com
>
parent
c78ee240
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
109 additions
and
8 deletions
+109
-8
tests/v1/kv_connector/unit/test_offloading_connector.py
tests/v1/kv_connector/unit/test_offloading_connector.py
+67
-1
vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
...buted/kv_transfer/kv_connector/v1/offloading_connector.py
+36
-3
vllm/v1/kv_offload/abstract.py
vllm/v1/kv_offload/abstract.py
+4
-2
vllm/v1/kv_offload/arc_manager.py
vllm/v1/kv_offload/arc_manager.py
+1
-1
vllm/v1/kv_offload/lru_manager.py
vllm/v1/kv_offload/lru_manager.py
+1
-1
No files found.
tests/v1/kv_connector/unit/test_offloading_connector.py
View file @
7013e9ac
...
@@ -213,7 +213,6 @@ class RequestRunner:
...
@@ -213,7 +213,6 @@ class RequestRunner:
)
)
def
new_request
(
self
,
token_ids
:
list
[
int
]):
def
new_request
(
self
,
token_ids
:
list
[
int
]):
assert
not
self
.
scheduler
.
requests
self
.
req_id
+=
1
self
.
req_id
+=
1
req
=
Request
(
req
=
Request
(
...
@@ -338,11 +337,20 @@ class RequestRunner:
...
@@ -338,11 +337,20 @@ class RequestRunner:
token_id
=
token_id
or
0
,
token_id
=
token_id
or
0
,
)
)
prev_token_id
=
token_id
if
self
.
scheduler
.
running
:
if
self
.
scheduler
.
running
:
token_id
=
next
(
tokens_iter
,
None
)
token_id
=
next
(
tokens_iter
,
None
)
self
.
scheduler
.
update_from_output
(
scheduler_output
,
model_runner_output
)
self
.
scheduler
.
update_from_output
(
scheduler_output
,
model_runner_output
)
if
(
prev_token_id
is
EOS_TOKEN_ID
and
prev_token_id
!=
token_id
and
self
.
scheduler
.
requests
):
# continue for one more step to allow offloading to kick off
continue
if
token_id
is
None
:
if
token_id
is
None
:
break
break
...
@@ -651,3 +659,61 @@ def test_request_preemption(request_runner):
...
@@ -651,3 +659,61 @@ def test_request_preemption(request_runner):
decoded_tokens
=
[
EOS_TOKEN_ID
],
decoded_tokens
=
[
EOS_TOKEN_ID
],
expected_stored_gpu_block_indexes
=
(
9
,
10
,
11
),
expected_stored_gpu_block_indexes
=
(
9
,
10
,
11
),
)
)
def test_concurrent_lookups_of_the_same_prefix(request_runner):
    """Two requests hitting the same offloaded prefix must share one load.

    Scenario exercised below:
      1. A request stores one offloaded block (GPU block indexes 0-2).
      2. After resetting the GPU prefix cache, a second request starts
         loading that block, but the transfer is left incomplete.
      3. A third request for the same tokens is run while the load is
         still pending; the handler's transfer specs must be unchanged.
      4. Transfers are completed and the specs are checked once more.
    """
    offloaded_block_size = 12
    gpu_block_size = 4
    num_gpu_blocks = 100

    runner = request_runner(
        offloaded_block_size=offloaded_block_size,
        gpu_block_size=gpu_block_size,
        num_gpu_blocks=num_gpu_blocks,
    )

    def single_block_prompt():
        # A fresh list per request: exactly one offloaded block of zeros.
        return [0] * offloaded_block_size

    def current_transfer_specs():
        # Snapshot of what the handler has been asked to transfer so far.
        return list(runner.offloading_spec.handler.transfer_specs)

    def store_every_block(block_hashes):
        return generate_store_output(block_hashes)

    def store_no_blocks(block_hashes):
        return generate_store_output([])

    # Step 1: store a single offloaded block.
    runner.new_request(token_ids=single_block_prompt())
    runner.manager.prepare_store.side_effect = store_every_block
    runner.run(
        decoded_tokens=[EOS_TOKEN_ID],
        expected_stored_gpu_block_indexes=(0, 1, 2),
    )

    # Step 2: begin loading that block for a new request, but leave the
    # transfer pending.
    runner.scheduler.reset_prefix_cache()
    runner.new_request(token_ids=single_block_prompt())
    runner.manager.lookup.return_value = 1
    runner.run(decoded_tokens=[], complete_transfers=False)

    # The request above must have triggered a load.
    specs_after_first_load = current_transfer_specs()
    assert specs_after_first_load

    # Step 3: another request asks for the very same first block.
    runner.new_request(token_ids=single_block_prompt())
    runner.manager.lookup.return_value = 1
    runner.run(decoded_tokens=[], complete_transfers=False)

    # It must not have triggered an additional load.
    assert specs_after_first_load == current_transfer_specs()

    # Step 4: let the pending transfers complete; nothing new is stored.
    runner.manager.prepare_store.side_effect = store_no_blocks
    runner.run(
        decoded_tokens=[EOS_TOKEN_ID],
        expected_loaded_gpu_block_indexes=(0, 1, 2),
    )

    # The second request is expected to use the GPU prefix cache, so the
    # set of transfers is still unchanged.
    assert specs_after_first_load == current_transfer_specs()
vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
View file @
7013e9ac
...
@@ -107,7 +107,7 @@ class OffloadingConnector(KVConnectorBase_V1):
...
@@ -107,7 +107,7 @@ class OffloadingConnector(KVConnectorBase_V1):
def
get_num_new_matched_tokens
(
def
get_num_new_matched_tokens
(
self
,
request
:
"Request"
,
num_computed_tokens
:
int
self
,
request
:
"Request"
,
num_computed_tokens
:
int
)
->
tuple
[
int
,
bool
]:
)
->
tuple
[
int
|
None
,
bool
]:
assert
self
.
connector_scheduler
is
not
None
assert
self
.
connector_scheduler
is
not
None
return
self
.
connector_scheduler
.
get_num_new_matched_tokens
(
return
self
.
connector_scheduler
.
get_num_new_matched_tokens
(
request
,
num_computed_tokens
request
,
num_computed_tokens
...
@@ -161,6 +161,11 @@ class OffloadingConnectorScheduler:
...
@@ -161,6 +161,11 @@ class OffloadingConnectorScheduler:
# request blocks are stored in order
# request blocks are stored in order
# index of next block (of size offloaded_block_size) to offload
# index of next block (of size offloaded_block_size) to offload
self
.
_next_stored_block_idx
:
dict
[
ReqId
,
int
]
=
{}
self
.
_next_stored_block_idx
:
dict
[
ReqId
,
int
]
=
{}
# if GPU prefix caching is enabled,
# track loaded blocks to avoid redundant loads
self
.
_blocks_being_loaded
:
set
[
BlockHash
]
|
None
=
(
set
()
if
spec
.
vllm_config
.
cache_config
.
enable_prefix_caching
else
None
)
# request ID -> set(block hashes being stored/loaded)
# request ID -> set(block hashes being stored/loaded)
self
.
_reqs_being_stored
=
defaultdict
[
ReqId
,
set
[
BlockHash
]](
set
)
self
.
_reqs_being_stored
=
defaultdict
[
ReqId
,
set
[
BlockHash
]](
set
)
...
@@ -181,7 +186,7 @@ class OffloadingConnectorScheduler:
...
@@ -181,7 +186,7 @@ class OffloadingConnectorScheduler:
def
get_num_new_matched_tokens
(
def
get_num_new_matched_tokens
(
self
,
request
:
Request
,
num_computed_tokens
:
int
self
,
request
:
Request
,
num_computed_tokens
:
int
)
->
tuple
[
int
,
bool
]:
)
->
tuple
[
int
|
None
,
bool
]:
"""
"""
Get number of new tokens that can be loaded beyond the
Get number of new tokens that can be loaded beyond the
num_computed_tokens.
num_computed_tokens.
...
@@ -195,6 +200,9 @@ class OffloadingConnectorScheduler:
...
@@ -195,6 +200,9 @@ class OffloadingConnectorScheduler:
A tuple with the following elements:
A tuple with the following elements:
- The number of tokens that can be loaded beyond what is
- The number of tokens that can be loaded beyond what is
already computed.
already computed.
If None, it means that the connector needs more time to
determine the number of matched tokens, and the scheduler
should query for this request again later.
- `True` if tokens will be loaded asynchronously
- `True` if tokens will be loaded asynchronously
(between scheduler steps).
(between scheduler steps).
"""
"""
...
@@ -214,6 +222,9 @@ class OffloadingConnectorScheduler:
...
@@ -214,6 +222,9 @@ class OffloadingConnectorScheduler:
hits
=
self
.
manager
.
lookup
(
hits
=
self
.
manager
.
lookup
(
self
.
_get_block_hashes
(
request
,
start_idx
=
start_block_idx
)
self
.
_get_block_hashes
(
request
,
start_idx
=
start_block_idx
)
)
)
if
hits
is
None
:
# indicates a lookup that should be tried later
return
None
,
False
if
hits
==
0
:
if
hits
==
0
:
return
0
,
False
return
0
,
False
...
@@ -229,6 +240,22 @@ class OffloadingConnectorScheduler:
...
@@ -229,6 +240,22 @@ class OffloadingConnectorScheduler:
if
num_hit_tokens
<
self
.
offloaded_block_size
:
if
num_hit_tokens
<
self
.
offloaded_block_size
:
return
0
,
False
return
0
,
False
if
self
.
_blocks_being_loaded
:
block_hashes
=
self
.
_get_block_hashes
(
request
,
start_idx
=
start_block_idx
,
end_idx
=
start_block_idx
+
hits
)
if
any
(
block_hash
in
self
.
_blocks_being_loaded
for
block_hash
in
block_hashes
):
# hit blocks are being loaded, delay request
logger
.
debug
(
"Delaying request %s since some of its blocks are already"
" being loaded"
,
request
.
request_id
,
)
return
None
,
False
return
num_hit_tokens
,
True
return
num_hit_tokens
,
True
def
update_state_after_alloc
(
def
update_state_after_alloc
(
...
@@ -270,9 +297,13 @@ class OffloadingConnectorScheduler:
...
@@ -270,9 +297,13 @@ class OffloadingConnectorScheduler:
)
)
self
.
_reqs_to_load
[
request
.
request_id
]
=
(
src_spec
,
dst_spec
)
self
.
_reqs_to_load
[
request
.
request_id
]
=
(
src_spec
,
dst_spec
)
self
.
_reqs_being_loaded
[
request
.
request_id
].
update
(
block_hashes
)
req_blocks_being_loaded
=
self
.
_reqs_being_loaded
[
request
.
request_id
]
req_blocks_being_loaded
.
update
(
block_hashes
)
self
.
_next_stored_block_idx
[
request
.
request_id
]
=
num_blocks
self
.
_next_stored_block_idx
[
request
.
request_id
]
=
num_blocks
if
self
.
_blocks_being_loaded
is
not
None
:
self
.
_blocks_being_loaded
.
update
(
req_blocks_being_loaded
)
def
_get_reqs_to_store
(
self
,
scheduler_output
:
SchedulerOutput
):
def
_get_reqs_to_store
(
self
,
scheduler_output
:
SchedulerOutput
):
reqs_to_store
:
dict
[
ReqId
,
TransferSpec
]
=
{}
reqs_to_store
:
dict
[
ReqId
,
TransferSpec
]
=
{}
# iterate over both new and cached requests
# iterate over both new and cached requests
...
@@ -379,6 +410,8 @@ class OffloadingConnectorScheduler:
...
@@ -379,6 +410,8 @@ class OffloadingConnectorScheduler:
for
req_id
in
connector_output
.
finished_recving
or
[]:
for
req_id
in
connector_output
.
finished_recving
or
[]:
block_hashes
=
self
.
_reqs_being_loaded
.
pop
(
req_id
,
None
)
block_hashes
=
self
.
_reqs_being_loaded
.
pop
(
req_id
,
None
)
if
block_hashes
:
if
block_hashes
:
if
self
.
_blocks_being_loaded
:
self
.
_blocks_being_loaded
.
difference_update
(
block_hashes
)
self
.
manager
.
complete_load
(
block_hashes
)
self
.
manager
.
complete_load
(
block_hashes
)
def
request_finished
(
def
request_finished
(
...
...
vllm/v1/kv_offload/abstract.py
View file @
7013e9ac
...
@@ -68,7 +68,7 @@ class OffloadingEvent:
...
@@ -68,7 +68,7 @@ class OffloadingEvent:
class
OffloadingManager
(
ABC
):
class
OffloadingManager
(
ABC
):
@
abstractmethod
@
abstractmethod
def
lookup
(
self
,
block_hashes
:
Iterable
[
BlockHash
])
->
int
:
def
lookup
(
self
,
block_hashes
:
Iterable
[
BlockHash
])
->
int
|
None
:
"""
"""
Finds the length of the maximal series of blocks, starting from the
Finds the length of the maximal series of blocks, starting from the
first one, that are all offloaded.
first one, that are all offloaded.
...
@@ -78,7 +78,9 @@ class OffloadingManager(ABC):
...
@@ -78,7 +78,9 @@ class OffloadingManager(ABC):
Returns:
Returns:
An integer representing the maximal number of blocks that
An integer representing the maximal number of blocks that
are currently offloaded.
are currently offloaded, or None if the lookup should be retried
later. Returning None will delay the request handling by the vLLM
scheduler.
"""
"""
pass
pass
...
...
vllm/v1/kv_offload/arc_manager.py
View file @
7013e9ac
...
@@ -63,7 +63,7 @@ class ARCOffloadingManager(OffloadingManager):
...
@@ -63,7 +63,7 @@ class ARCOffloadingManager(OffloadingManager):
self
.
events
:
list
[
OffloadingEvent
]
|
None
=
[]
if
enable_events
else
None
self
.
events
:
list
[
OffloadingEvent
]
|
None
=
[]
if
enable_events
else
None
self
.
cache_capacity
:
int
=
self
.
backend
.
get_num_free_blocks
()
self
.
cache_capacity
:
int
=
self
.
backend
.
get_num_free_blocks
()
def
lookup
(
self
,
block_hashes
:
Iterable
[
BlockHash
])
->
int
:
def
lookup
(
self
,
block_hashes
:
Iterable
[
BlockHash
])
->
int
|
None
:
hit_count
=
0
hit_count
=
0
for
block_hash
in
block_hashes
:
for
block_hash
in
block_hashes
:
block
=
self
.
t1
.
get
(
block_hash
)
or
self
.
t2
.
get
(
block_hash
)
block
=
self
.
t1
.
get
(
block_hash
)
or
self
.
t2
.
get
(
block_hash
)
...
...
vllm/v1/kv_offload/lru_manager.py
View file @
7013e9ac
...
@@ -24,7 +24,7 @@ class LRUOffloadingManager(OffloadingManager):
...
@@ -24,7 +24,7 @@ class LRUOffloadingManager(OffloadingManager):
self
.
blocks
:
OrderedDict
[
BlockHash
,
BlockStatus
]
=
OrderedDict
()
self
.
blocks
:
OrderedDict
[
BlockHash
,
BlockStatus
]
=
OrderedDict
()
self
.
events
:
list
[
OffloadingEvent
]
|
None
=
[]
if
enable_events
else
None
self
.
events
:
list
[
OffloadingEvent
]
|
None
=
[]
if
enable_events
else
None
def
lookup
(
self
,
block_hashes
:
Iterable
[
BlockHash
])
->
int
:
def
lookup
(
self
,
block_hashes
:
Iterable
[
BlockHash
])
->
int
|
None
:
hit_count
=
0
hit_count
=
0
for
block_hash
in
block_hashes
:
for
block_hash
in
block_hashes
:
block
=
self
.
blocks
.
get
(
block_hash
)
block
=
self
.
blocks
.
get
(
block_hash
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment