Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4353c9cb
Unverified
Commit
4353c9cb
authored
Apr 19, 2026
by
omerpaz95
Committed by
GitHub
Apr 19, 2026
Browse files
[KV Offload] Pass request context (#39185)
Signed-off-by:
omerpaz95
<
omerpaz95@gmail.com
>
parent
4b7f5ea1
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
170 additions
and
81 deletions
+170
-81
tests/v1/kv_connector/unit/offloading_connector/test_scheduler.py
.../kv_connector/unit/offloading_connector/test_scheduler.py
+39
-15
tests/v1/kv_connector/unit/offloading_connector/utils.py
tests/v1/kv_connector/unit/offloading_connector/utils.py
+1
-1
tests/v1/kv_offload/test_cpu_manager.py
tests/v1/kv_offload/test_cpu_manager.py
+68
-49
vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py
...buted/kv_transfer/kv_connector/v1/offloading/scheduler.py
+11
-3
vllm/v1/kv_offload/abstract.py
vllm/v1/kv_offload/abstract.py
+24
-4
vllm/v1/kv_offload/cpu/manager.py
vllm/v1/kv_offload/cpu/manager.py
+16
-3
vllm/v1/kv_offload/reuse_manager.py
vllm/v1/kv_offload/reuse_manager.py
+11
-6
No files found.
tests/v1/kv_connector/unit/offloading_connector/test_scheduler.py
View file @
4353c9cb
...
...
@@ -32,8 +32,8 @@ def test_offloading_connector(request_runner, async_scheduling: bool):
# 3 blocks, store just the middle block (skip first and last)
# blocks = [0, 1, 2], [3, 4, 5], [6, 7, 8]
runner
.
new_request
(
token_ids
=
[
0
]
*
offloaded_block_size
*
3
)
runner
.
manager
.
prepare_store
.
side_effect
=
lambda
keys
:
generate_store_output
(
list
(
keys
)[
1
:
2
]
runner
.
manager
.
prepare_store
.
side_effect
=
(
lambda
keys
,
req_context
:
generate_store_output
(
list
(
keys
)[
1
:
2
]
)
)
runner
.
run
(
decoded_tokens
=
[
0
])
...
...
@@ -45,18 +45,22 @@ def test_offloading_connector(request_runner, async_scheduling: bool):
runner
.
manager
.
prepare_store
.
assert_not_called
()
# +1 token -> single block, fail prepare_store
runner
.
manager
.
prepare_store
.
side_effect
=
lambda
keys
:
None
runner
.
manager
.
prepare_store
.
side_effect
=
lambda
keys
,
req_context
:
None
runner
.
run
(
decoded_tokens
=
[
0
])
runner
.
manager
.
prepare_store
.
assert_called
()
# 1 more block (+ token for async scheduling)
# now set block_hashes_to_store = []
runner
.
manager
.
prepare_store
.
side_effect
=
lambda
keys
:
generate_store_output
([])
runner
.
manager
.
prepare_store
.
side_effect
=
(
lambda
keys
,
req_context
:
generate_store_output
([])
)
runner
.
run
(
decoded_tokens
=
[
0
]
*
(
offloaded_block_size
+
1
))
# 1 more block (+ token for kicking off offloading)
# now check touch was called with all 6 blocks
runner
.
manager
.
prepare_store
.
side_effect
=
lambda
keys
:
generate_store_output
(
keys
)
runner
.
manager
.
prepare_store
.
side_effect
=
(
lambda
keys
,
req_context
:
generate_store_output
(
keys
)
)
runner
.
run
(
decoded_tokens
=
[
0
]
*
(
offloaded_block_size
+
1
),
expected_stored_gpu_block_indexes
=
(
15
,
16
,
17
),
...
...
@@ -89,13 +93,17 @@ def test_offloading_connector(request_runner, async_scheduling: bool):
runner
.
new_request
(
token_ids
=
[
0
]
*
gpu_block_size
+
[
1
]
*
(
offloaded_block_size
-
gpu_block_size
)
)
runner
.
manager
.
prepare_store
.
side_effect
=
lambda
keys
:
generate_store_output
([])
runner
.
manager
.
prepare_store
.
side_effect
=
(
lambda
keys
,
req_context
:
generate_store_output
([])
)
runner
.
run
(
decoded_tokens
=
[
EOS_TOKEN_ID
])
runner
.
manager
.
lookup
.
assert_not_called
()
# single block lookup with no hits
runner
.
new_request
(
token_ids
=
[
1
]
*
offloaded_block_size
)
runner
.
manager
.
prepare_store
.
side_effect
=
lambda
keys
:
generate_store_output
([])
runner
.
manager
.
prepare_store
.
side_effect
=
(
lambda
keys
,
req_context
:
generate_store_output
([])
)
runner
.
run
(
decoded_tokens
=
[
EOS_TOKEN_ID
])
runner
.
manager
.
lookup
.
assert_called
()
assert
len
(
list
(
runner
.
manager
.
lookup
.
call_args
.
args
[
0
]))
==
1
...
...
@@ -103,7 +111,9 @@ def test_offloading_connector(request_runner, async_scheduling: bool):
# single block lookup with a hit
runner
.
scheduler
.
reset_prefix_cache
()
runner
.
new_request
(
token_ids
=
[
0
]
*
offloaded_block_size
)
runner
.
manager
.
prepare_store
.
side_effect
=
lambda
keys
:
generate_store_output
([])
runner
.
manager
.
prepare_store
.
side_effect
=
(
lambda
keys
,
req_context
:
generate_store_output
([])
)
runner
.
manager
.
lookup
.
return_value
=
1
runner
.
run
(
decoded_tokens
=
[
EOS_TOKEN_ID
],
expected_loaded_gpu_block_indexes
=
(
0
,
1
,
2
)
...
...
@@ -113,7 +123,9 @@ def test_offloading_connector(request_runner, async_scheduling: bool):
runner
.
new_request
(
token_ids
=
[
0
]
*
offloaded_block_size
*
2
+
[
1
]
*
offloaded_block_size
)
runner
.
manager
.
prepare_store
.
side_effect
=
lambda
keys
:
generate_store_output
([])
runner
.
manager
.
prepare_store
.
side_effect
=
(
lambda
keys
,
req_context
:
generate_store_output
([])
)
runner
.
manager
.
lookup
.
return_value
=
1
runner
.
run
(
decoded_tokens
=
[
EOS_TOKEN_ID
],
expected_loaded_gpu_block_indexes
=
(
3
,
4
,
5
)
...
...
@@ -164,14 +176,18 @@ def test_request_preemption(request_runner, async_scheduling: bool):
# 2 blocks, store all, without flushing
# blocks = [0, 1, 2], [3, 4, 5]
runner
.
new_request
(
token_ids
=
[
0
]
*
offloaded_block_size
*
2
)
runner
.
manager
.
prepare_store
.
side_effect
=
lambda
keys
:
generate_store_output
(
keys
)
runner
.
manager
.
prepare_store
.
side_effect
=
(
lambda
keys
,
req_context
:
generate_store_output
(
keys
)
)
runner
.
run
(
decoded_tokens
=
[
0
],
complete_transfers
=
False
,
)
# decode 2 more blocks - 1 gpu block, storing [6, 7, 8] (no flush)
runner
.
manager
.
prepare_store
.
side_effect
=
lambda
keys
:
generate_store_output
(
keys
)
runner
.
manager
.
prepare_store
.
side_effect
=
(
lambda
keys
,
req_context
:
generate_store_output
(
keys
)
)
runner
.
run
(
decoded_tokens
=
[
0
]
*
(
2
*
offloaded_block_size
-
gpu_block_size
),
complete_transfers
=
False
,
...
...
@@ -195,7 +211,9 @@ def test_request_preemption(request_runner, async_scheduling: bool):
# request should now return from preemption
# re-load [0, ..., 8] from the CPU and store [9, 10, 11]
runner
.
manager
.
lookup
.
return_value
=
3
runner
.
manager
.
prepare_store
.
side_effect
=
lambda
keys
:
generate_store_output
(
keys
)
runner
.
manager
.
prepare_store
.
side_effect
=
(
lambda
keys
,
req_context
:
generate_store_output
(
keys
)
)
runner
.
run
(
decoded_tokens
=
[
0
]
*
gpu_block_size
,
expected_loaded_gpu_block_indexes
=
(
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
),
...
...
@@ -222,7 +240,9 @@ def test_concurrent_lookups_of_the_same_prefix(request_runner, async_scheduling:
# store 1 blocks
runner
.
new_request
(
token_ids
=
[
0
]
*
offloaded_block_size
)
runner
.
manager
.
prepare_store
.
side_effect
=
lambda
keys
:
generate_store_output
(
keys
)
runner
.
manager
.
prepare_store
.
side_effect
=
(
lambda
keys
,
req_context
:
generate_store_output
(
keys
)
)
runner
.
run
(
decoded_tokens
=
[
EOS_TOKEN_ID
],
expected_stored_gpu_block_indexes
=
(
0
,
1
,
2
),
...
...
@@ -253,7 +273,9 @@ def test_concurrent_lookups_of_the_same_prefix(request_runner, async_scheduling:
assert
transfer_jobs
==
list
(
runner
.
offloading_spec
.
handler
.
transfer_specs
)
# complete transfers
runner
.
manager
.
prepare_store
.
side_effect
=
lambda
keys
:
generate_store_output
([])
runner
.
manager
.
prepare_store
.
side_effect
=
(
lambda
keys
,
req_context
:
generate_store_output
([])
)
runner
.
run
(
decoded_tokens
=
[
EOS_TOKEN_ID
],
expected_loaded_gpu_block_indexes
=
(
0
,
1
,
2
),
...
...
@@ -278,7 +300,9 @@ def test_abort_loading_requests(request_runner, async_scheduling: bool):
# store 1 blocks
runner
.
new_request
(
token_ids
=
[
0
]
*
offloaded_block_size
)
runner
.
manager
.
prepare_store
.
side_effect
=
lambda
keys
:
generate_store_output
(
keys
)
runner
.
manager
.
prepare_store
.
side_effect
=
(
lambda
keys
,
req_context
:
generate_store_output
(
keys
)
)
runner
.
run
(
decoded_tokens
=
[
EOS_TOKEN_ID
],
expected_stored_gpu_block_indexes
=
(
0
,
1
,
2
),
...
...
tests/v1/kv_connector/unit/offloading_connector/utils.py
View file @
4353c9cb
...
...
@@ -115,7 +115,7 @@ class MockOffloadingSpec(OffloadingSpec):
self
.
manager
=
MagicMock
(
spec
=
OffloadingManager
)
self
.
manager
.
lookup
.
return_value
=
0
self
.
manager
.
prepare_load
=
lambda
keys
:
MockLoadStoreSpec
(
keys
)
self
.
manager
.
prepare_load
=
lambda
keys
,
req_context
:
MockLoadStoreSpec
(
keys
)
self
.
handler
=
MockOffloadingHandler
()
def
get_manager
(
self
)
->
OffloadingManager
:
...
...
tests/v1/kv_offload/test_cpu_manager.py
View file @
4353c9cb
...
...
@@ -11,6 +11,7 @@ from vllm.v1.kv_offload.abstract import (
OffloadingEvent
,
OffloadKey
,
PrepareStoreOutput
,
ReqContext
,
make_offload_key
,
)
from
vllm.v1.kv_offload.cpu.manager
import
CPUOffloadingManager
...
...
@@ -19,6 +20,14 @@ from vllm.v1.kv_offload.mediums import CPULoadStoreSpec
from
vllm.v1.kv_offload.reuse_manager
import
FilterReusedOffloadingManager
def
make_req_context
(
kv_transfer_params
:
dict
|
None
=
None
)
->
ReqContext
:
"""Create a ReqContext as production code would, from a request's params."""
return
ReqContext
(
kv_transfer_params
=
kv_transfer_params
)
_EMPTY_REQ_CTX
=
make_req_context
()
@
dataclass
class
ExpectedPrepareStoreOutput
:
keys_to_store
:
list
[
int
]
...
...
@@ -103,7 +112,7 @@ def test_already_stored_block_not_evicted_during_prepare_store(eviction_policy):
)
# store [1, 2] and complete
manager
.
prepare_store
(
to_keys
([
1
,
2
]))
manager
.
prepare_store
(
to_keys
([
1
,
2
])
,
_EMPTY_REQ_CTX
)
manager
.
complete_store
(
to_keys
([
1
,
2
]))
# touch [1] to make block 2 the LRU candidate
...
...
@@ -113,7 +122,7 @@ def test_already_stored_block_not_evicted_during_prepare_store(eviction_policy):
# - block 2 is already stored -> filtered out of keys_to_store
# - block 2 must NOT be evicted even though it is the LRU candidate
# - block 1 (ID 0) is evicted instead; new blocks [3,4,5] get IDs 2,3,0
prepare_store_output
=
manager
.
prepare_store
(
to_keys
([
2
,
3
,
4
,
5
]))
prepare_store_output
=
manager
.
prepare_store
(
to_keys
([
2
,
3
,
4
,
5
])
,
_EMPTY_REQ_CTX
)
verify_store_output
(
prepare_store_output
,
ExpectedPrepareStoreOutput
(
...
...
@@ -127,7 +136,7 @@ def test_already_stored_block_not_evicted_during_prepare_store(eviction_policy):
manager
.
complete_store
(
to_keys
([
2
,
3
,
4
,
5
]))
# block 2 must still be present in the cache
assert
manager
.
lookup
(
to_keys
([
2
]))
==
1
assert
manager
.
lookup
(
to_keys
([
2
])
,
_EMPTY_REQ_CTX
)
==
1
def
test_cpu_manager
():
...
...
@@ -140,7 +149,7 @@ def test_cpu_manager():
)
# prepare store [1, 2]
prepare_store_output
=
cpu_manager
.
prepare_store
(
to_keys
([
1
,
2
]))
prepare_store_output
=
cpu_manager
.
prepare_store
(
to_keys
([
1
,
2
])
,
_EMPTY_REQ_CTX
)
verify_store_output
(
prepare_store_output
,
ExpectedPrepareStoreOutput
(
...
...
@@ -151,7 +160,7 @@ def test_cpu_manager():
)
# lookup [1, 2] -> not ready
assert
cpu_manager
.
lookup
(
to_keys
([
1
,
2
]))
==
0
assert
cpu_manager
.
lookup
(
to_keys
([
1
,
2
])
,
_EMPTY_REQ_CTX
)
==
0
# no events so far
assert
list
(
cpu_manager
.
take_events
())
==
[]
...
...
@@ -161,12 +170,14 @@ def test_cpu_manager():
verify_events
(
cpu_manager
.
take_events
(),
expected_stores
=
({
1
,
2
},))
# lookup [1, 2]
assert
cpu_manager
.
lookup
(
to_keys
([
1
]))
==
1
assert
cpu_manager
.
lookup
(
to_keys
([
1
,
2
]))
==
2
assert
cpu_manager
.
lookup
(
to_keys
([
1
,
2
,
3
]))
==
2
assert
cpu_manager
.
lookup
(
to_keys
([
1
])
,
_EMPTY_REQ_CTX
)
==
1
assert
cpu_manager
.
lookup
(
to_keys
([
1
,
2
])
,
_EMPTY_REQ_CTX
)
==
2
assert
cpu_manager
.
lookup
(
to_keys
([
1
,
2
,
3
])
,
_EMPTY_REQ_CTX
)
==
2
# prepare store [2, 3, 4, 5] -> evicts [1]
prepare_store_output
=
cpu_manager
.
prepare_store
(
to_keys
([
2
,
3
,
4
,
5
]))
prepare_store_output
=
cpu_manager
.
prepare_store
(
to_keys
([
2
,
3
,
4
,
5
]),
_EMPTY_REQ_CTX
)
verify_store_output
(
prepare_store_output
,
ExpectedPrepareStoreOutput
(
...
...
@@ -180,23 +191,23 @@ def test_cpu_manager():
verify_events
(
cpu_manager
.
take_events
(),
expected_evictions
=
({
1
},))
# prepare store with no space
assert
cpu_manager
.
prepare_store
(
to_keys
([
1
,
6
]))
is
None
assert
cpu_manager
.
prepare_store
(
to_keys
([
1
,
6
])
,
_EMPTY_REQ_CTX
)
is
None
# complete store [2, 3, 4, 5]
cpu_manager
.
complete_store
(
to_keys
([
2
,
3
,
4
,
5
]))
# prepare load [2, 3]
prepare_load_output
=
cpu_manager
.
prepare_load
(
to_keys
([
2
,
3
]))
prepare_load_output
=
cpu_manager
.
prepare_load
(
to_keys
([
2
,
3
])
,
_EMPTY_REQ_CTX
)
verify_load_output
(
prepare_load_output
,
[
1
,
2
])
# prepare store with no space ([2, 3] is being loaded)
assert
cpu_manager
.
prepare_store
(
to_keys
([
6
,
7
,
8
]))
is
None
assert
cpu_manager
.
prepare_store
(
to_keys
([
6
,
7
,
8
])
,
_EMPTY_REQ_CTX
)
is
None
# complete load [2, 3]
cpu_manager
.
complete_load
(
to_keys
([
2
,
3
]))
# prepare store [6, 7, 8] -> evicts [2, 3, 4] (oldest)
prepare_store_output
=
cpu_manager
.
prepare_store
(
to_keys
([
6
,
7
,
8
]))
prepare_store_output
=
cpu_manager
.
prepare_store
(
to_keys
([
6
,
7
,
8
])
,
_EMPTY_REQ_CTX
)
verify_store_output
(
prepare_store_output
,
ExpectedPrepareStoreOutput
(
...
...
@@ -213,7 +224,7 @@ def test_cpu_manager():
cpu_manager
.
touch
(
to_keys
([
5
,
6
,
7
]))
# prepare store [7, 9] -> evicts [8] (oldest following previous touch)
prepare_store_output
=
cpu_manager
.
prepare_store
(
to_keys
([
9
]))
prepare_store_output
=
cpu_manager
.
prepare_store
(
to_keys
([
9
])
,
_EMPTY_REQ_CTX
)
verify_store_output
(
prepare_store_output
,
ExpectedPrepareStoreOutput
(
...
...
@@ -227,8 +238,8 @@ def test_cpu_manager():
cpu_manager
.
complete_store
(
to_keys
([
7
,
9
]),
success
=
False
)
# assert [7] is still stored, but [9] is not
assert
cpu_manager
.
lookup
(
to_keys
([
7
]))
==
1
assert
cpu_manager
.
lookup
(
to_keys
([
9
]))
==
0
assert
cpu_manager
.
lookup
(
to_keys
([
7
])
,
_EMPTY_REQ_CTX
)
==
1
assert
cpu_manager
.
lookup
(
to_keys
([
9
])
,
_EMPTY_REQ_CTX
)
==
0
verify_events
(
cpu_manager
.
take_events
(),
...
...
@@ -260,7 +271,9 @@ class TestARCPolicy:
cpu_manager
,
arc_policy
=
self
.
_make_manager
()
# prepare store [1, 2]
prepare_store_output
=
cpu_manager
.
prepare_store
(
to_keys
([
1
,
2
]))
prepare_store_output
=
cpu_manager
.
prepare_store
(
to_keys
([
1
,
2
]),
_EMPTY_REQ_CTX
)
verify_store_output
(
prepare_store_output
,
ExpectedPrepareStoreOutput
(
...
...
@@ -271,7 +284,7 @@ class TestARCPolicy:
)
# lookup [1, 2] -> not ready
assert
cpu_manager
.
lookup
(
to_keys
([
1
,
2
]))
==
0
assert
cpu_manager
.
lookup
(
to_keys
([
1
,
2
])
,
_EMPTY_REQ_CTX
)
==
0
# no events so far
assert
list
(
cpu_manager
.
take_events
())
==
[]
...
...
@@ -281,9 +294,9 @@ class TestARCPolicy:
verify_events
(
cpu_manager
.
take_events
(),
expected_stores
=
({
1
,
2
},))
# lookup [1, 2]
assert
cpu_manager
.
lookup
(
to_keys
([
1
]))
==
1
assert
cpu_manager
.
lookup
(
to_keys
([
1
,
2
]))
==
2
assert
cpu_manager
.
lookup
(
to_keys
([
1
,
2
,
3
]))
==
2
assert
cpu_manager
.
lookup
(
to_keys
([
1
])
,
_EMPTY_REQ_CTX
)
==
1
assert
cpu_manager
.
lookup
(
to_keys
([
1
,
2
])
,
_EMPTY_REQ_CTX
)
==
2
assert
cpu_manager
.
lookup
(
to_keys
([
1
,
2
,
3
])
,
_EMPTY_REQ_CTX
)
==
2
# blocks should be in T1 (recent)
assert
len
(
arc_policy
.
t1
)
==
2
...
...
@@ -297,7 +310,7 @@ class TestARCPolicy:
cpu_manager
,
arc_policy
=
self
.
_make_manager
(
enable_events
=
False
)
# store and complete block 1
cpu_manager
.
prepare_store
(
to_keys
([
1
]))
cpu_manager
.
prepare_store
(
to_keys
([
1
])
,
_EMPTY_REQ_CTX
)
cpu_manager
.
complete_store
(
to_keys
([
1
]))
# block 1 starts in T1 (recent)
...
...
@@ -319,7 +332,9 @@ class TestARCPolicy:
cpu_manager
,
_
=
self
.
_make_manager
()
# prepare and complete store [1, 2, 3, 4]
prepare_store_output
=
cpu_manager
.
prepare_store
(
to_keys
([
1
,
2
,
3
,
4
]))
prepare_store_output
=
cpu_manager
.
prepare_store
(
to_keys
([
1
,
2
,
3
,
4
]),
_EMPTY_REQ_CTX
)
verify_store_output
(
prepare_store_output
,
ExpectedPrepareStoreOutput
(
...
...
@@ -331,19 +346,21 @@ class TestARCPolicy:
cpu_manager
.
complete_store
(
to_keys
([
1
,
2
,
3
,
4
]))
# prepare load [2, 3] (increases ref_cnt)
prepare_load_output
=
cpu_manager
.
prepare_load
(
to_keys
([
2
,
3
]))
prepare_load_output
=
cpu_manager
.
prepare_load
(
to_keys
([
2
,
3
])
,
_EMPTY_REQ_CTX
)
verify_load_output
(
prepare_load_output
,
[
1
,
2
])
# prepare store [5, 6, 7] with [2, 3] being loaded
# should fail because [2, 3] have ref_cnt > 0
assert
cpu_manager
.
prepare_store
(
to_keys
([
5
,
6
,
7
]))
is
None
assert
cpu_manager
.
prepare_store
(
to_keys
([
5
,
6
,
7
])
,
_EMPTY_REQ_CTX
)
is
None
# complete load [2, 3]
cpu_manager
.
complete_load
(
to_keys
([
2
,
3
]))
# now prepare store [5, 6, 7] should succeed
# ARC will evict blocks one at a time from T1 as needed
prepare_store_output
=
cpu_manager
.
prepare_store
(
to_keys
([
5
,
6
,
7
]))
prepare_store_output
=
cpu_manager
.
prepare_store
(
to_keys
([
5
,
6
,
7
]),
_EMPTY_REQ_CTX
)
assert
prepare_store_output
is
not
None
# Should successfully evict enough blocks to make room (at least 1)
assert
len
(
prepare_store_output
.
evicted_keys
)
>=
1
...
...
@@ -357,13 +374,13 @@ class TestARCPolicy:
cpu_manager
,
arc_policy
=
self
.
_make_manager
(
num_blocks
=
2
,
enable_events
=
False
)
# store blocks 1, 2 (fills cache)
cpu_manager
.
prepare_store
(
to_keys
([
1
,
2
]))
cpu_manager
.
prepare_store
(
to_keys
([
1
,
2
])
,
_EMPTY_REQ_CTX
)
cpu_manager
.
complete_store
(
to_keys
([
1
,
2
]))
initial_target
=
arc_policy
.
target_t1_size
# store block 3, evicting block 1 (moves to B1 ghost list)
cpu_manager
.
prepare_store
(
to_keys
([
3
]))
cpu_manager
.
prepare_store
(
to_keys
([
3
])
,
_EMPTY_REQ_CTX
)
cpu_manager
.
complete_store
(
to_keys
([
3
]))
# block 1 should be in B1 (ghost list)
...
...
@@ -384,7 +401,7 @@ class TestARCPolicy:
cpu_manager
,
arc_policy
=
self
.
_make_manager
(
enable_events
=
False
)
# store blocks 1, 2, 3, 4
cpu_manager
.
prepare_store
(
to_keys
([
1
,
2
,
3
,
4
]))
cpu_manager
.
prepare_store
(
to_keys
([
1
,
2
,
3
,
4
])
,
_EMPTY_REQ_CTX
)
cpu_manager
.
complete_store
(
to_keys
([
1
,
2
,
3
,
4
]))
# promote blocks 3, 4 to T2 by touching them
...
...
@@ -399,7 +416,7 @@ class TestARCPolicy:
arc_policy
.
target_t1_size
=
1
# store block 5, should evict from T1 (block 1, LRU in T1)
output
=
cpu_manager
.
prepare_store
(
to_keys
([
5
]))
output
=
cpu_manager
.
prepare_store
(
to_keys
([
5
])
,
_EMPTY_REQ_CTX
)
assert
output
is
not
None
assert
to_keys
([
1
])
==
output
.
evicted_keys
...
...
@@ -418,12 +435,12 @@ class TestARCPolicy:
cpu_manager
,
arc_policy
=
self
.
_make_manager
(
num_blocks
=
2
,
enable_events
=
False
)
# fill cache with blocks 1, 2
cpu_manager
.
prepare_store
(
to_keys
([
1
,
2
]))
cpu_manager
.
prepare_store
(
to_keys
([
1
,
2
])
,
_EMPTY_REQ_CTX
)
cpu_manager
.
complete_store
(
to_keys
([
1
,
2
]))
# store many blocks to fill ghost lists
for
i
in
range
(
3
,
20
):
cpu_manager
.
prepare_store
(
to_keys
([
i
]))
cpu_manager
.
prepare_store
(
to_keys
([
i
])
,
_EMPTY_REQ_CTX
)
cpu_manager
.
complete_store
(
to_keys
([
i
]))
# ghost lists should not exceed cache_capacity
...
...
@@ -438,7 +455,7 @@ class TestARCPolicy:
cpu_manager
,
arc_policy
=
self
.
_make_manager
()
# store blocks 1, 2, 3, 4
cpu_manager
.
prepare_store
(
to_keys
([
1
,
2
,
3
,
4
]))
cpu_manager
.
prepare_store
(
to_keys
([
1
,
2
,
3
,
4
])
,
_EMPTY_REQ_CTX
)
cpu_manager
.
complete_store
(
to_keys
([
1
,
2
,
3
,
4
]))
# promote 3, 4 to T2
...
...
@@ -453,7 +470,7 @@ class TestARCPolicy:
assert
len
(
arc_policy
.
t2
)
==
3
# store block 5, should evict from T1 (block 2, only one in T1)
prepare_store_output
=
cpu_manager
.
prepare_store
(
to_keys
([
5
]))
prepare_store_output
=
cpu_manager
.
prepare_store
(
to_keys
([
5
])
,
_EMPTY_REQ_CTX
)
verify_store_output
(
prepare_store_output
,
ExpectedPrepareStoreOutput
(
...
...
@@ -471,11 +488,11 @@ class TestARCPolicy:
cpu_manager
,
arc_policy
=
self
.
_make_manager
()
# store blocks 1, 2, 3, 4
cpu_manager
.
prepare_store
(
to_keys
([
1
,
2
,
3
,
4
]))
cpu_manager
.
prepare_store
(
to_keys
([
1
,
2
,
3
,
4
])
,
_EMPTY_REQ_CTX
)
cpu_manager
.
complete_store
(
to_keys
([
1
,
2
,
3
,
4
]))
# prepare store block 5 (will evict block 1)
prepare_store_output
=
cpu_manager
.
prepare_store
(
to_keys
([
5
]))
prepare_store_output
=
cpu_manager
.
prepare_store
(
to_keys
([
5
])
,
_EMPTY_REQ_CTX
)
assert
prepare_store_output
is
not
None
assert
len
(
prepare_store_output
.
evicted_keys
)
==
1
...
...
@@ -483,7 +500,7 @@ class TestARCPolicy:
cpu_manager
.
complete_store
(
to_keys
([
5
]),
success
=
False
)
# block 5 should not be in cache
assert
cpu_manager
.
lookup
(
to_keys
([
5
]))
==
0
assert
cpu_manager
.
lookup
(
to_keys
([
5
])
,
_EMPTY_REQ_CTX
)
==
0
# block 5 should not be in T1 or T2
assert
to_keys
([
5
])[
0
]
not
in
arc_policy
.
t1
assert
to_keys
([
5
])[
0
]
not
in
arc_policy
.
t2
...
...
@@ -500,11 +517,13 @@ class TestARCPolicy:
cpu_manager
,
arc_policy
=
self
.
_make_manager
()
# store [1, 2]
cpu_manager
.
prepare_store
(
to_keys
([
1
,
2
]))
cpu_manager
.
prepare_store
(
to_keys
([
1
,
2
])
,
_EMPTY_REQ_CTX
)
cpu_manager
.
complete_store
(
to_keys
([
1
,
2
]))
# store [3, 4, 5] -> evicts [1]
prepare_store_output
=
cpu_manager
.
prepare_store
(
to_keys
([
3
,
4
,
5
]))
prepare_store_output
=
cpu_manager
.
prepare_store
(
to_keys
([
3
,
4
,
5
]),
_EMPTY_REQ_CTX
)
assert
prepare_store_output
is
not
None
assert
len
(
prepare_store_output
.
evicted_keys
)
==
1
cpu_manager
.
complete_store
(
to_keys
([
3
,
4
,
5
]))
...
...
@@ -517,13 +536,13 @@ class TestARCPolicy:
assert
len
(
arc_policy
.
t2
)
==
2
# store [6] -> should evict from T1 (4 is oldest in T1)
prepare_store_output
=
cpu_manager
.
prepare_store
(
to_keys
([
6
]))
prepare_store_output
=
cpu_manager
.
prepare_store
(
to_keys
([
6
])
,
_EMPTY_REQ_CTX
)
assert
prepare_store_output
is
not
None
cpu_manager
.
complete_store
(
to_keys
([
6
]))
# verify blocks 2, 3 (in T2) are still present
assert
cpu_manager
.
lookup
(
to_keys
([
2
]))
==
1
assert
cpu_manager
.
lookup
(
to_keys
([
3
]))
==
1
assert
cpu_manager
.
lookup
(
to_keys
([
2
])
,
_EMPTY_REQ_CTX
)
==
1
assert
cpu_manager
.
lookup
(
to_keys
([
3
])
,
_EMPTY_REQ_CTX
)
==
1
# verify events
events
=
list
(
cpu_manager
.
take_events
())
...
...
@@ -543,34 +562,34 @@ def test_filter_reused_manager():
)
# Lookup [1, 2] -> 1st time, added to tracker but not eligible for store yet
assert
manager
.
lookup
(
to_keys
([
1
,
2
]))
==
0
assert
manager
.
lookup
(
to_keys
([
1
,
2
])
,
_EMPTY_REQ_CTX
)
==
0
# prepare store [1, 2] -> should be filtered
prepare_store_output
=
manager
.
prepare_store
(
to_keys
([
1
,
2
]))
prepare_store_output
=
manager
.
prepare_store
(
to_keys
([
1
,
2
])
,
_EMPTY_REQ_CTX
)
assert
prepare_store_output
is
not
None
assert
prepare_store_output
.
keys_to_store
==
[]
# Lookup [1] -> 2nd time, eligible now
assert
manager
.
lookup
(
to_keys
([
1
]))
==
0
assert
manager
.
lookup
(
to_keys
([
1
])
,
_EMPTY_REQ_CTX
)
==
0
# prepare store [1, 2] -> [1] should be eligible, [2] should be filtered
prepare_store_output
=
manager
.
prepare_store
(
to_keys
([
1
,
2
]))
prepare_store_output
=
manager
.
prepare_store
(
to_keys
([
1
,
2
])
,
_EMPTY_REQ_CTX
)
assert
prepare_store_output
is
not
None
assert
prepare_store_output
.
keys_to_store
==
to_keys
([
1
])
# Lookup [3, 4] -> 1st time
# (evicts [2] from tracker since max_size is 3 and tracker has [1])
assert
manager
.
lookup
(
to_keys
([
3
,
4
]))
==
0
assert
manager
.
lookup
(
to_keys
([
3
,
4
])
,
_EMPTY_REQ_CTX
)
==
0
# Verify [2] was evicted from the tracker (tracker now has: [1], [3], [4])
assert
to_keys
([
2
])[
0
]
not
in
manager
.
counts
# Lookup [2] again -> (this adds [2] back to the tracker as 1st time)
assert
manager
.
lookup
(
to_keys
([
2
]))
==
0
assert
manager
.
lookup
(
to_keys
([
2
])
,
_EMPTY_REQ_CTX
)
==
0
# Verify [2] was re-added with count=1 (not eligible yet)
assert
manager
.
counts
.
get
(
to_keys
([
2
])[
0
])
==
1
# prepare store [2] -> should still be filtered out since count was reset
prepare_store_output
=
manager
.
prepare_store
(
to_keys
([
2
]))
prepare_store_output
=
manager
.
prepare_store
(
to_keys
([
2
])
,
_EMPTY_REQ_CTX
)
assert
prepare_store_output
is
not
None
assert
prepare_store_output
.
keys_to_store
==
[]
...
...
vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py
View file @
4353c9cb
...
...
@@ -19,6 +19,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
from
vllm.v1.kv_offload.abstract
import
(
OffloadingManager
,
OffloadKey
,
ReqContext
,
get_offload_block_hash
,
make_offload_key
,
)
...
...
@@ -74,6 +75,7 @@ class RequestOffloadState:
config
:
SchedulerOffloadConfig
req
:
Request
group_states
:
tuple
[
RequestGroupState
,
...]
=
field
(
init
=
False
)
req_context
:
ReqContext
=
field
(
init
=
False
)
# number of hits in the GPU cache
num_locally_computed_tokens
:
int
=
0
...
...
@@ -81,6 +83,7 @@ class RequestOffloadState:
self
.
group_states
=
tuple
(
RequestGroupState
()
for
_
in
self
.
config
.
kv_group_configs
)
self
.
req_context
=
ReqContext
(
kv_transfer_params
=
self
.
req
.
kv_transfer_params
)
def
update_offload_keys
(
self
)
->
None
:
for
group_config
,
group_state
in
zip
(
...
...
@@ -181,7 +184,10 @@ class OffloadingConnectorScheduler:
return
0
,
False
start_block_idx
=
num_computed_tokens
//
group_config
.
offloaded_block_size
hits
=
self
.
manager
.
lookup
(
offload_keys
[
start_block_idx
:])
hits
=
self
.
manager
.
lookup
(
offload_keys
[
start_block_idx
:],
req_status
.
req_context
,
)
if
hits
is
None
:
# indicates a lookup that should be tried later
return
None
,
False
...
...
@@ -249,7 +255,7 @@ class OffloadingConnectorScheduler:
assert
len
(
request
.
block_hashes
)
//
self
.
config
.
block_size_factor
>=
num_blocks
offload_keys
=
group_state
.
offload_keys
[
start_block_idx
:
num_blocks
]
src_spec
=
self
.
manager
.
prepare_load
(
offload_keys
)
src_spec
=
self
.
manager
.
prepare_load
(
offload_keys
,
req_status
.
req_context
)
dst_spec
=
GPULoadStoreSpec
(
block_ids
[
num_computed_gpu_blocks
:],
group_sizes
=
(
num_pending_gpu_blocks
,),
...
...
@@ -304,7 +310,9 @@ class OffloadingConnectorScheduler:
assert
len
(
req
.
block_hashes
)
>=
num_gpu_blocks
new_offload_keys
=
group_state
.
offload_keys
[
start_block_idx
:
num_blocks
]
store_output
=
self
.
manager
.
prepare_store
(
new_offload_keys
)
store_output
=
self
.
manager
.
prepare_store
(
new_offload_keys
,
req_status
.
req_context
)
if
store_output
is
None
:
logger
.
warning
(
"Request %s: cannot store %s blocks"
,
req_id
,
num_new_blocks
...
...
vllm/v1/kv_offload/abstract.py
View file @
4353c9cb
...
...
@@ -30,7 +30,7 @@ The class provides the following primitives:
from
abc
import
ABC
,
abstractmethod
from
collections.abc
import
Iterable
from
dataclasses
import
dataclass
from
typing
import
NewType
from
typing
import
Any
,
NewType
# `OffloadKey` identifies an offloaded block. It combines a block hash with
# its KV cache group index, encoded as raw bytes to avoid tuple GC overhead.
...
...
@@ -53,6 +53,11 @@ def get_offload_group_idx(key: OffloadKey) -> int:
return
int
.
from_bytes
(
key
[
-
4
:],
"big"
,
signed
=
False
)
@
dataclass
class
ReqContext
:
kv_transfer_params
:
dict
[
str
,
Any
]
|
None
=
None
class
LoadStoreSpec
(
ABC
):
"""
Abstract metadata that encapsulates information allowing a worker
...
...
@@ -86,13 +91,18 @@ class OffloadingEvent:
class
OffloadingManager
(
ABC
):
@
abstractmethod
def
lookup
(
self
,
keys
:
Iterable
[
OffloadKey
])
->
int
|
None
:
def
lookup
(
self
,
keys
:
Iterable
[
OffloadKey
],
req_context
:
ReqContext
,
)
->
int
|
None
:
"""
Finds the length of the maximal series of blocks, starting from the
first one, that are all offloaded.
Args:
keys: the keys identifying the blocks to lookup.
req_context: per-request context (e.g. kv_transfer_params).
Returns:
An integer representing the maximal number of blocks that
...
...
@@ -103,7 +113,11 @@ class OffloadingManager(ABC):
pass
@
abstractmethod
def
prepare_load
(
self
,
keys
:
Iterable
[
OffloadKey
])
->
LoadStoreSpec
:
def
prepare_load
(
self
,
keys
:
Iterable
[
OffloadKey
],
req_context
:
ReqContext
,
)
->
LoadStoreSpec
:
"""
Prepare the given blocks to be read.
The given blocks will be protected from eviction until
...
...
@@ -112,6 +126,7 @@ class OffloadingManager(ABC):
Args:
keys: the keys identifying the blocks.
req_context: per-request context (e.g. kv_transfer_params).
Returns:
A LoadStoreSpec that can be used by a worker to locate and load
...
...
@@ -139,7 +154,11 @@ class OffloadingManager(ABC):
return
@
abstractmethod
def
prepare_store
(
self
,
keys
:
Iterable
[
OffloadKey
])
->
PrepareStoreOutput
|
None
:
def
prepare_store
(
self
,
keys
:
Iterable
[
OffloadKey
],
req_context
:
ReqContext
,
)
->
PrepareStoreOutput
|
None
:
"""
Prepare the given blocks to be offloaded.
The given blocks will be protected from eviction until
...
...
@@ -147,6 +166,7 @@ class OffloadingManager(ABC):
Args:
keys: the keys identifying the blocks.
req_context: per-request context (e.g. kv_transfer_params).
Returns:
A PrepareStoreOutput indicating which blocks need storing,
...
...
vllm/v1/kv_offload/cpu/manager.py
View file @
4353c9cb
...
...
@@ -9,6 +9,7 @@ from vllm.v1.kv_offload.abstract import (
OffloadingManager
,
OffloadKey
,
PrepareStoreOutput
,
ReqContext
,
)
from
vllm.v1.kv_offload.cpu.policies.abstract
import
BlockStatus
,
CachePolicy
from
vllm.v1.kv_offload.cpu.policies.arc
import
ARCCachePolicy
...
...
@@ -83,7 +84,11 @@ class CPUOffloadingManager(OffloadingManager):
# --- OffloadingManager interface ---
def
lookup
(
self
,
keys
:
Iterable
[
OffloadKey
])
->
int
|
None
:
def
lookup
(
self
,
keys
:
Iterable
[
OffloadKey
],
req_context
:
ReqContext
,
)
->
int
|
None
:
hit_count
=
0
for
key
in
keys
:
block
=
self
.
_policy
.
get
(
key
)
...
...
@@ -92,7 +97,11 @@ class CPUOffloadingManager(OffloadingManager):
hit_count
+=
1
return
hit_count
def
prepare_load
(
self
,
keys
:
Iterable
[
OffloadKey
])
->
LoadStoreSpec
:
def
prepare_load
(
self
,
keys
:
Iterable
[
OffloadKey
],
req_context
:
ReqContext
,
)
->
LoadStoreSpec
:
blocks
=
[]
for
key
in
keys
:
block
=
self
.
_policy
.
get
(
key
)
...
...
@@ -112,7 +121,11 @@ class CPUOffloadingManager(OffloadingManager):
assert
block
.
ref_cnt
>
0
,
f
"Block
{
key
!
r
}
ref_cnt is already 0"
block
.
ref_cnt
-=
1
def
prepare_store
(
self
,
keys
:
Iterable
[
OffloadKey
])
->
PrepareStoreOutput
|
None
:
def
prepare_store
(
self
,
keys
:
Iterable
[
OffloadKey
],
req_context
:
ReqContext
,
)
->
PrepareStoreOutput
|
None
:
keys_list
=
list
(
keys
)
# filter out blocks that are already stored
...
...
vllm/v1/kv_offload/reuse_manager.py
View file @
4353c9cb
...
...
@@ -16,6 +16,7 @@ from vllm.v1.kv_offload.abstract import (
OffloadingManager
,
OffloadKey
,
PrepareStoreOutput
,
ReqContext
,
)
...
...
@@ -65,7 +66,7 @@ class FilterReusedOffloadingManager(OffloadingManager):
# Intercepted methods
# ------------------------------------------------------------------
def
lookup
(
self
,
keys
:
Iterable
[
OffloadKey
])
->
int
|
None
:
def
lookup
(
self
,
keys
:
Iterable
[
OffloadKey
]
,
req_context
:
ReqContext
)
->
int
|
None
:
"""Record each key, then delegate lookup to backing manager."""
keys
=
list
(
keys
)
for
key
in
keys
:
...
...
@@ -76,9 +77,11 @@ class FilterReusedOffloadingManager(OffloadingManager):
if
len
(
self
.
counts
)
>=
self
.
max_tracker_size
:
self
.
counts
.
popitem
(
last
=
False
)
# evict LRU
self
.
counts
[
key
]
=
1
return
self
.
_backing
.
lookup
(
keys
)
return
self
.
_backing
.
lookup
(
keys
,
req_context
)
def
prepare_store
(
self
,
keys
:
Iterable
[
OffloadKey
])
->
PrepareStoreOutput
|
None
:
def
prepare_store
(
self
,
keys
:
Iterable
[
OffloadKey
],
req_context
:
ReqContext
)
->
PrepareStoreOutput
|
None
:
"""Filter out blocks below threshold, then delegate to backing.
Filtering is evaluated *before* calling the backing manager's
...
...
@@ -93,14 +96,16 @@ class FilterReusedOffloadingManager(OffloadingManager):
# Passing an empty list is intentional and safe — CPUOffloadingManager
# handles it correctly, returning a PrepareStoreOutput with empty lists.
# Delegate to the backing manager with only the eligible keys.
return
self
.
_backing
.
prepare_store
(
eligible
)
return
self
.
_backing
.
prepare_store
(
eligible
,
req_context
)
# ------------------------------------------------------------------
# Delegated methods
# ------------------------------------------------------------------
def
prepare_load
(
self
,
keys
:
Iterable
[
OffloadKey
])
->
LoadStoreSpec
:
return
self
.
_backing
.
prepare_load
(
keys
)
def
prepare_load
(
self
,
keys
:
Iterable
[
OffloadKey
],
req_context
:
ReqContext
)
->
LoadStoreSpec
:
return
self
.
_backing
.
prepare_load
(
keys
,
req_context
)
def
touch
(
self
,
keys
:
Iterable
[
OffloadKey
])
->
None
:
return
self
.
_backing
.
touch
(
keys
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment