Commit 4353c9cb (unverified), authored by omerpaz95, committed by GitHub
Browse files

[KV Offload] Pass request context (#39185)


Signed-off-by: omerpaz95 <omerpaz95@gmail.com>
parent 4b7f5ea1
...@@ -32,8 +32,8 @@ def test_offloading_connector(request_runner, async_scheduling: bool): ...@@ -32,8 +32,8 @@ def test_offloading_connector(request_runner, async_scheduling: bool):
# 3 blocks, store just the middle block (skip first and last) # 3 blocks, store just the middle block (skip first and last)
# blocks = [0, 1, 2], [3, 4, 5], [6, 7, 8] # blocks = [0, 1, 2], [3, 4, 5], [6, 7, 8]
runner.new_request(token_ids=[0] * offloaded_block_size * 3) runner.new_request(token_ids=[0] * offloaded_block_size * 3)
runner.manager.prepare_store.side_effect = lambda keys: generate_store_output( runner.manager.prepare_store.side_effect = (
list(keys)[1:2] lambda keys, req_context: generate_store_output(list(keys)[1:2])
) )
runner.run(decoded_tokens=[0]) runner.run(decoded_tokens=[0])
...@@ -45,18 +45,22 @@ def test_offloading_connector(request_runner, async_scheduling: bool): ...@@ -45,18 +45,22 @@ def test_offloading_connector(request_runner, async_scheduling: bool):
runner.manager.prepare_store.assert_not_called() runner.manager.prepare_store.assert_not_called()
# +1 token -> single block, fail prepare_store # +1 token -> single block, fail prepare_store
runner.manager.prepare_store.side_effect = lambda keys: None runner.manager.prepare_store.side_effect = lambda keys, req_context: None
runner.run(decoded_tokens=[0]) runner.run(decoded_tokens=[0])
runner.manager.prepare_store.assert_called() runner.manager.prepare_store.assert_called()
# 1 more block (+ token for async scheduling) # 1 more block (+ token for async scheduling)
# now set block_hashes_to_store = [] # now set block_hashes_to_store = []
runner.manager.prepare_store.side_effect = lambda keys: generate_store_output([]) runner.manager.prepare_store.side_effect = (
lambda keys, req_context: generate_store_output([])
)
runner.run(decoded_tokens=[0] * (offloaded_block_size + 1)) runner.run(decoded_tokens=[0] * (offloaded_block_size + 1))
# 1 more block (+ token for kicking off offloading) # 1 more block (+ token for kicking off offloading)
# now check touch was called with all 6 blocks # now check touch was called with all 6 blocks
runner.manager.prepare_store.side_effect = lambda keys: generate_store_output(keys) runner.manager.prepare_store.side_effect = (
lambda keys, req_context: generate_store_output(keys)
)
runner.run( runner.run(
decoded_tokens=[0] * (offloaded_block_size + 1), decoded_tokens=[0] * (offloaded_block_size + 1),
expected_stored_gpu_block_indexes=(15, 16, 17), expected_stored_gpu_block_indexes=(15, 16, 17),
...@@ -89,13 +93,17 @@ def test_offloading_connector(request_runner, async_scheduling: bool): ...@@ -89,13 +93,17 @@ def test_offloading_connector(request_runner, async_scheduling: bool):
runner.new_request( runner.new_request(
token_ids=[0] * gpu_block_size + [1] * (offloaded_block_size - gpu_block_size) token_ids=[0] * gpu_block_size + [1] * (offloaded_block_size - gpu_block_size)
) )
runner.manager.prepare_store.side_effect = lambda keys: generate_store_output([]) runner.manager.prepare_store.side_effect = (
lambda keys, req_context: generate_store_output([])
)
runner.run(decoded_tokens=[EOS_TOKEN_ID]) runner.run(decoded_tokens=[EOS_TOKEN_ID])
runner.manager.lookup.assert_not_called() runner.manager.lookup.assert_not_called()
# single block lookup with no hits # single block lookup with no hits
runner.new_request(token_ids=[1] * offloaded_block_size) runner.new_request(token_ids=[1] * offloaded_block_size)
runner.manager.prepare_store.side_effect = lambda keys: generate_store_output([]) runner.manager.prepare_store.side_effect = (
lambda keys, req_context: generate_store_output([])
)
runner.run(decoded_tokens=[EOS_TOKEN_ID]) runner.run(decoded_tokens=[EOS_TOKEN_ID])
runner.manager.lookup.assert_called() runner.manager.lookup.assert_called()
assert len(list(runner.manager.lookup.call_args.args[0])) == 1 assert len(list(runner.manager.lookup.call_args.args[0])) == 1
...@@ -103,7 +111,9 @@ def test_offloading_connector(request_runner, async_scheduling: bool): ...@@ -103,7 +111,9 @@ def test_offloading_connector(request_runner, async_scheduling: bool):
# single block lookup with a hit # single block lookup with a hit
runner.scheduler.reset_prefix_cache() runner.scheduler.reset_prefix_cache()
runner.new_request(token_ids=[0] * offloaded_block_size) runner.new_request(token_ids=[0] * offloaded_block_size)
runner.manager.prepare_store.side_effect = lambda keys: generate_store_output([]) runner.manager.prepare_store.side_effect = (
lambda keys, req_context: generate_store_output([])
)
runner.manager.lookup.return_value = 1 runner.manager.lookup.return_value = 1
runner.run( runner.run(
decoded_tokens=[EOS_TOKEN_ID], expected_loaded_gpu_block_indexes=(0, 1, 2) decoded_tokens=[EOS_TOKEN_ID], expected_loaded_gpu_block_indexes=(0, 1, 2)
...@@ -113,7 +123,9 @@ def test_offloading_connector(request_runner, async_scheduling: bool): ...@@ -113,7 +123,9 @@ def test_offloading_connector(request_runner, async_scheduling: bool):
runner.new_request( runner.new_request(
token_ids=[0] * offloaded_block_size * 2 + [1] * offloaded_block_size token_ids=[0] * offloaded_block_size * 2 + [1] * offloaded_block_size
) )
runner.manager.prepare_store.side_effect = lambda keys: generate_store_output([]) runner.manager.prepare_store.side_effect = (
lambda keys, req_context: generate_store_output([])
)
runner.manager.lookup.return_value = 1 runner.manager.lookup.return_value = 1
runner.run( runner.run(
decoded_tokens=[EOS_TOKEN_ID], expected_loaded_gpu_block_indexes=(3, 4, 5) decoded_tokens=[EOS_TOKEN_ID], expected_loaded_gpu_block_indexes=(3, 4, 5)
...@@ -164,14 +176,18 @@ def test_request_preemption(request_runner, async_scheduling: bool): ...@@ -164,14 +176,18 @@ def test_request_preemption(request_runner, async_scheduling: bool):
# 2 blocks, store all, without flushing # 2 blocks, store all, without flushing
# blocks = [0, 1, 2], [3, 4, 5] # blocks = [0, 1, 2], [3, 4, 5]
runner.new_request(token_ids=[0] * offloaded_block_size * 2) runner.new_request(token_ids=[0] * offloaded_block_size * 2)
runner.manager.prepare_store.side_effect = lambda keys: generate_store_output(keys) runner.manager.prepare_store.side_effect = (
lambda keys, req_context: generate_store_output(keys)
)
runner.run( runner.run(
decoded_tokens=[0], decoded_tokens=[0],
complete_transfers=False, complete_transfers=False,
) )
# decode 2 more blocks - 1 gpu block, storing [6, 7, 8] (no flush) # decode 2 more blocks - 1 gpu block, storing [6, 7, 8] (no flush)
runner.manager.prepare_store.side_effect = lambda keys: generate_store_output(keys) runner.manager.prepare_store.side_effect = (
lambda keys, req_context: generate_store_output(keys)
)
runner.run( runner.run(
decoded_tokens=[0] * (2 * offloaded_block_size - gpu_block_size), decoded_tokens=[0] * (2 * offloaded_block_size - gpu_block_size),
complete_transfers=False, complete_transfers=False,
...@@ -195,7 +211,9 @@ def test_request_preemption(request_runner, async_scheduling: bool): ...@@ -195,7 +211,9 @@ def test_request_preemption(request_runner, async_scheduling: bool):
# request should now return from preemption # request should now return from preemption
# re-load [0, ..., 8] from the CPU and store [9, 10, 11] # re-load [0, ..., 8] from the CPU and store [9, 10, 11]
runner.manager.lookup.return_value = 3 runner.manager.lookup.return_value = 3
runner.manager.prepare_store.side_effect = lambda keys: generate_store_output(keys) runner.manager.prepare_store.side_effect = (
lambda keys, req_context: generate_store_output(keys)
)
runner.run( runner.run(
decoded_tokens=[0] * gpu_block_size, decoded_tokens=[0] * gpu_block_size,
expected_loaded_gpu_block_indexes=(0, 1, 2, 3, 4, 5, 6, 7, 8), expected_loaded_gpu_block_indexes=(0, 1, 2, 3, 4, 5, 6, 7, 8),
...@@ -222,7 +240,9 @@ def test_concurrent_lookups_of_the_same_prefix(request_runner, async_scheduling: ...@@ -222,7 +240,9 @@ def test_concurrent_lookups_of_the_same_prefix(request_runner, async_scheduling:
# store 1 blocks # store 1 blocks
runner.new_request(token_ids=[0] * offloaded_block_size) runner.new_request(token_ids=[0] * offloaded_block_size)
runner.manager.prepare_store.side_effect = lambda keys: generate_store_output(keys) runner.manager.prepare_store.side_effect = (
lambda keys, req_context: generate_store_output(keys)
)
runner.run( runner.run(
decoded_tokens=[EOS_TOKEN_ID], decoded_tokens=[EOS_TOKEN_ID],
expected_stored_gpu_block_indexes=(0, 1, 2), expected_stored_gpu_block_indexes=(0, 1, 2),
...@@ -253,7 +273,9 @@ def test_concurrent_lookups_of_the_same_prefix(request_runner, async_scheduling: ...@@ -253,7 +273,9 @@ def test_concurrent_lookups_of_the_same_prefix(request_runner, async_scheduling:
assert transfer_jobs == list(runner.offloading_spec.handler.transfer_specs) assert transfer_jobs == list(runner.offloading_spec.handler.transfer_specs)
# complete transfers # complete transfers
runner.manager.prepare_store.side_effect = lambda keys: generate_store_output([]) runner.manager.prepare_store.side_effect = (
lambda keys, req_context: generate_store_output([])
)
runner.run( runner.run(
decoded_tokens=[EOS_TOKEN_ID], decoded_tokens=[EOS_TOKEN_ID],
expected_loaded_gpu_block_indexes=(0, 1, 2), expected_loaded_gpu_block_indexes=(0, 1, 2),
...@@ -278,7 +300,9 @@ def test_abort_loading_requests(request_runner, async_scheduling: bool): ...@@ -278,7 +300,9 @@ def test_abort_loading_requests(request_runner, async_scheduling: bool):
# store 1 blocks # store 1 blocks
runner.new_request(token_ids=[0] * offloaded_block_size) runner.new_request(token_ids=[0] * offloaded_block_size)
runner.manager.prepare_store.side_effect = lambda keys: generate_store_output(keys) runner.manager.prepare_store.side_effect = (
lambda keys, req_context: generate_store_output(keys)
)
runner.run( runner.run(
decoded_tokens=[EOS_TOKEN_ID], decoded_tokens=[EOS_TOKEN_ID],
expected_stored_gpu_block_indexes=(0, 1, 2), expected_stored_gpu_block_indexes=(0, 1, 2),
......
...@@ -115,7 +115,7 @@ class MockOffloadingSpec(OffloadingSpec): ...@@ -115,7 +115,7 @@ class MockOffloadingSpec(OffloadingSpec):
self.manager = MagicMock(spec=OffloadingManager) self.manager = MagicMock(spec=OffloadingManager)
self.manager.lookup.return_value = 0 self.manager.lookup.return_value = 0
self.manager.prepare_load = lambda keys: MockLoadStoreSpec(keys) self.manager.prepare_load = lambda keys, req_context: MockLoadStoreSpec(keys)
self.handler = MockOffloadingHandler() self.handler = MockOffloadingHandler()
def get_manager(self) -> OffloadingManager: def get_manager(self) -> OffloadingManager:
......
...@@ -11,6 +11,7 @@ from vllm.v1.kv_offload.abstract import ( ...@@ -11,6 +11,7 @@ from vllm.v1.kv_offload.abstract import (
OffloadingEvent, OffloadingEvent,
OffloadKey, OffloadKey,
PrepareStoreOutput, PrepareStoreOutput,
ReqContext,
make_offload_key, make_offload_key,
) )
from vllm.v1.kv_offload.cpu.manager import CPUOffloadingManager from vllm.v1.kv_offload.cpu.manager import CPUOffloadingManager
...@@ -19,6 +20,14 @@ from vllm.v1.kv_offload.mediums import CPULoadStoreSpec ...@@ -19,6 +20,14 @@ from vllm.v1.kv_offload.mediums import CPULoadStoreSpec
from vllm.v1.kv_offload.reuse_manager import FilterReusedOffloadingManager from vllm.v1.kv_offload.reuse_manager import FilterReusedOffloadingManager
def make_req_context(kv_transfer_params: dict | None = None) -> ReqContext:
    """Build a ReqContext for use in tests.

    Args:
        kv_transfer_params: Optional KV-transfer parameters, forwarded
            unchanged to ``ReqContext`` as the keyword of the same name.
            Defaults to ``None``.

    Returns:
        A new ``ReqContext`` wrapping ``kv_transfer_params``.
    """
    # Construct via keyword argument so the helper stays aligned with the
    # ReqContext field name rather than its positional layout.
    req_context = ReqContext(kv_transfer_params=kv_transfer_params)
    return req_context
# Module-level request context built with the default kv_transfer_params=None.
_EMPTY_REQ_CTX = make_req_context()
@dataclass @dataclass
class ExpectedPrepareStoreOutput: class ExpectedPrepareStoreOutput:
keys_to_store: list[int] keys_to_store: list[int]
...@@ -103,7 +112,7 @@ def test_already_stored_block_not_evicted_during_prepare_store(eviction_policy): ...@@ -103,7 +112,7 @@ def test_already_stored_block_not_evicted_during_prepare_store(eviction_policy):
) )
# store [1, 2] and complete # store [1, 2] and complete
manager.prepare_store(to_keys([1, 2])) manager.prepare_store(to_keys([1, 2]), _EMPTY_REQ_CTX)
manager.complete_store(to_keys([1, 2])) manager.complete_store(to_keys([1, 2]))
# touch [1] to make block 2 the LRU candidate # touch [1] to make block 2 the LRU candidate
...@@ -113,7 +122,7 @@ def test_already_stored_block_not_evicted_during_prepare_store(eviction_policy): ...@@ -113,7 +122,7 @@ def test_already_stored_block_not_evicted_during_prepare_store(eviction_policy):
# - block 2 is already stored -> filtered out of keys_to_store # - block 2 is already stored -> filtered out of keys_to_store
# - block 2 must NOT be evicted even though it is the LRU candidate # - block 2 must NOT be evicted even though it is the LRU candidate
# - block 1 (ID 0) is evicted instead; new blocks [3,4,5] get IDs 2,3,0 # - block 1 (ID 0) is evicted instead; new blocks [3,4,5] get IDs 2,3,0
prepare_store_output = manager.prepare_store(to_keys([2, 3, 4, 5])) prepare_store_output = manager.prepare_store(to_keys([2, 3, 4, 5]), _EMPTY_REQ_CTX)
verify_store_output( verify_store_output(
prepare_store_output, prepare_store_output,
ExpectedPrepareStoreOutput( ExpectedPrepareStoreOutput(
...@@ -127,7 +136,7 @@ def test_already_stored_block_not_evicted_during_prepare_store(eviction_policy): ...@@ -127,7 +136,7 @@ def test_already_stored_block_not_evicted_during_prepare_store(eviction_policy):
manager.complete_store(to_keys([2, 3, 4, 5])) manager.complete_store(to_keys([2, 3, 4, 5]))
# block 2 must still be present in the cache # block 2 must still be present in the cache
assert manager.lookup(to_keys([2])) == 1 assert manager.lookup(to_keys([2]), _EMPTY_REQ_CTX) == 1
def test_cpu_manager(): def test_cpu_manager():
...@@ -140,7 +149,7 @@ def test_cpu_manager(): ...@@ -140,7 +149,7 @@ def test_cpu_manager():
) )
# prepare store [1, 2] # prepare store [1, 2]
prepare_store_output = cpu_manager.prepare_store(to_keys([1, 2])) prepare_store_output = cpu_manager.prepare_store(to_keys([1, 2]), _EMPTY_REQ_CTX)
verify_store_output( verify_store_output(
prepare_store_output, prepare_store_output,
ExpectedPrepareStoreOutput( ExpectedPrepareStoreOutput(
...@@ -151,7 +160,7 @@ def test_cpu_manager(): ...@@ -151,7 +160,7 @@ def test_cpu_manager():
) )
# lookup [1, 2] -> not ready # lookup [1, 2] -> not ready
assert cpu_manager.lookup(to_keys([1, 2])) == 0 assert cpu_manager.lookup(to_keys([1, 2]), _EMPTY_REQ_CTX) == 0
# no events so far # no events so far
assert list(cpu_manager.take_events()) == [] assert list(cpu_manager.take_events()) == []
...@@ -161,12 +170,14 @@ def test_cpu_manager(): ...@@ -161,12 +170,14 @@ def test_cpu_manager():
verify_events(cpu_manager.take_events(), expected_stores=({1, 2},)) verify_events(cpu_manager.take_events(), expected_stores=({1, 2},))
# lookup [1, 2] # lookup [1, 2]
assert cpu_manager.lookup(to_keys([1])) == 1 assert cpu_manager.lookup(to_keys([1]), _EMPTY_REQ_CTX) == 1
assert cpu_manager.lookup(to_keys([1, 2])) == 2 assert cpu_manager.lookup(to_keys([1, 2]), _EMPTY_REQ_CTX) == 2
assert cpu_manager.lookup(to_keys([1, 2, 3])) == 2 assert cpu_manager.lookup(to_keys([1, 2, 3]), _EMPTY_REQ_CTX) == 2
# prepare store [2, 3, 4, 5] -> evicts [1] # prepare store [2, 3, 4, 5] -> evicts [1]
prepare_store_output = cpu_manager.prepare_store(to_keys([2, 3, 4, 5])) prepare_store_output = cpu_manager.prepare_store(
to_keys([2, 3, 4, 5]), _EMPTY_REQ_CTX
)
verify_store_output( verify_store_output(
prepare_store_output, prepare_store_output,
ExpectedPrepareStoreOutput( ExpectedPrepareStoreOutput(
...@@ -180,23 +191,23 @@ def test_cpu_manager(): ...@@ -180,23 +191,23 @@ def test_cpu_manager():
verify_events(cpu_manager.take_events(), expected_evictions=({1},)) verify_events(cpu_manager.take_events(), expected_evictions=({1},))
# prepare store with no space # prepare store with no space
assert cpu_manager.prepare_store(to_keys([1, 6])) is None assert cpu_manager.prepare_store(to_keys([1, 6]), _EMPTY_REQ_CTX) is None
# complete store [2, 3, 4, 5] # complete store [2, 3, 4, 5]
cpu_manager.complete_store(to_keys([2, 3, 4, 5])) cpu_manager.complete_store(to_keys([2, 3, 4, 5]))
# prepare load [2, 3] # prepare load [2, 3]
prepare_load_output = cpu_manager.prepare_load(to_keys([2, 3])) prepare_load_output = cpu_manager.prepare_load(to_keys([2, 3]), _EMPTY_REQ_CTX)
verify_load_output(prepare_load_output, [1, 2]) verify_load_output(prepare_load_output, [1, 2])
# prepare store with no space ([2, 3] is being loaded) # prepare store with no space ([2, 3] is being loaded)
assert cpu_manager.prepare_store(to_keys([6, 7, 8])) is None assert cpu_manager.prepare_store(to_keys([6, 7, 8]), _EMPTY_REQ_CTX) is None
# complete load [2, 3] # complete load [2, 3]
cpu_manager.complete_load(to_keys([2, 3])) cpu_manager.complete_load(to_keys([2, 3]))
# prepare store [6, 7, 8] -> evicts [2, 3, 4] (oldest) # prepare store [6, 7, 8] -> evicts [2, 3, 4] (oldest)
prepare_store_output = cpu_manager.prepare_store(to_keys([6, 7, 8])) prepare_store_output = cpu_manager.prepare_store(to_keys([6, 7, 8]), _EMPTY_REQ_CTX)
verify_store_output( verify_store_output(
prepare_store_output, prepare_store_output,
ExpectedPrepareStoreOutput( ExpectedPrepareStoreOutput(
...@@ -213,7 +224,7 @@ def test_cpu_manager(): ...@@ -213,7 +224,7 @@ def test_cpu_manager():
cpu_manager.touch(to_keys([5, 6, 7])) cpu_manager.touch(to_keys([5, 6, 7]))
# prepare store [7, 9] -> evicts [8] (oldest following previous touch) # prepare store [7, 9] -> evicts [8] (oldest following previous touch)
prepare_store_output = cpu_manager.prepare_store(to_keys([9])) prepare_store_output = cpu_manager.prepare_store(to_keys([9]), _EMPTY_REQ_CTX)
verify_store_output( verify_store_output(
prepare_store_output, prepare_store_output,
ExpectedPrepareStoreOutput( ExpectedPrepareStoreOutput(
...@@ -227,8 +238,8 @@ def test_cpu_manager(): ...@@ -227,8 +238,8 @@ def test_cpu_manager():
cpu_manager.complete_store(to_keys([7, 9]), success=False) cpu_manager.complete_store(to_keys([7, 9]), success=False)
# assert [7] is still stored, but [9] is not # assert [7] is still stored, but [9] is not
assert cpu_manager.lookup(to_keys([7])) == 1 assert cpu_manager.lookup(to_keys([7]), _EMPTY_REQ_CTX) == 1
assert cpu_manager.lookup(to_keys([9])) == 0 assert cpu_manager.lookup(to_keys([9]), _EMPTY_REQ_CTX) == 0
verify_events( verify_events(
cpu_manager.take_events(), cpu_manager.take_events(),
...@@ -260,7 +271,9 @@ class TestARCPolicy: ...@@ -260,7 +271,9 @@ class TestARCPolicy:
cpu_manager, arc_policy = self._make_manager() cpu_manager, arc_policy = self._make_manager()
# prepare store [1, 2] # prepare store [1, 2]
prepare_store_output = cpu_manager.prepare_store(to_keys([1, 2])) prepare_store_output = cpu_manager.prepare_store(
to_keys([1, 2]), _EMPTY_REQ_CTX
)
verify_store_output( verify_store_output(
prepare_store_output, prepare_store_output,
ExpectedPrepareStoreOutput( ExpectedPrepareStoreOutput(
...@@ -271,7 +284,7 @@ class TestARCPolicy: ...@@ -271,7 +284,7 @@ class TestARCPolicy:
) )
# lookup [1, 2] -> not ready # lookup [1, 2] -> not ready
assert cpu_manager.lookup(to_keys([1, 2])) == 0 assert cpu_manager.lookup(to_keys([1, 2]), _EMPTY_REQ_CTX) == 0
# no events so far # no events so far
assert list(cpu_manager.take_events()) == [] assert list(cpu_manager.take_events()) == []
...@@ -281,9 +294,9 @@ class TestARCPolicy: ...@@ -281,9 +294,9 @@ class TestARCPolicy:
verify_events(cpu_manager.take_events(), expected_stores=({1, 2},)) verify_events(cpu_manager.take_events(), expected_stores=({1, 2},))
# lookup [1, 2] # lookup [1, 2]
assert cpu_manager.lookup(to_keys([1])) == 1 assert cpu_manager.lookup(to_keys([1]), _EMPTY_REQ_CTX) == 1
assert cpu_manager.lookup(to_keys([1, 2])) == 2 assert cpu_manager.lookup(to_keys([1, 2]), _EMPTY_REQ_CTX) == 2
assert cpu_manager.lookup(to_keys([1, 2, 3])) == 2 assert cpu_manager.lookup(to_keys([1, 2, 3]), _EMPTY_REQ_CTX) == 2
# blocks should be in T1 (recent) # blocks should be in T1 (recent)
assert len(arc_policy.t1) == 2 assert len(arc_policy.t1) == 2
...@@ -297,7 +310,7 @@ class TestARCPolicy: ...@@ -297,7 +310,7 @@ class TestARCPolicy:
cpu_manager, arc_policy = self._make_manager(enable_events=False) cpu_manager, arc_policy = self._make_manager(enable_events=False)
# store and complete block 1 # store and complete block 1
cpu_manager.prepare_store(to_keys([1])) cpu_manager.prepare_store(to_keys([1]), _EMPTY_REQ_CTX)
cpu_manager.complete_store(to_keys([1])) cpu_manager.complete_store(to_keys([1]))
# block 1 starts in T1 (recent) # block 1 starts in T1 (recent)
...@@ -319,7 +332,9 @@ class TestARCPolicy: ...@@ -319,7 +332,9 @@ class TestARCPolicy:
cpu_manager, _ = self._make_manager() cpu_manager, _ = self._make_manager()
# prepare and complete store [1, 2, 3, 4] # prepare and complete store [1, 2, 3, 4]
prepare_store_output = cpu_manager.prepare_store(to_keys([1, 2, 3, 4])) prepare_store_output = cpu_manager.prepare_store(
to_keys([1, 2, 3, 4]), _EMPTY_REQ_CTX
)
verify_store_output( verify_store_output(
prepare_store_output, prepare_store_output,
ExpectedPrepareStoreOutput( ExpectedPrepareStoreOutput(
...@@ -331,19 +346,21 @@ class TestARCPolicy: ...@@ -331,19 +346,21 @@ class TestARCPolicy:
cpu_manager.complete_store(to_keys([1, 2, 3, 4])) cpu_manager.complete_store(to_keys([1, 2, 3, 4]))
# prepare load [2, 3] (increases ref_cnt) # prepare load [2, 3] (increases ref_cnt)
prepare_load_output = cpu_manager.prepare_load(to_keys([2, 3])) prepare_load_output = cpu_manager.prepare_load(to_keys([2, 3]), _EMPTY_REQ_CTX)
verify_load_output(prepare_load_output, [1, 2]) verify_load_output(prepare_load_output, [1, 2])
# prepare store [5, 6, 7] with [2, 3] being loaded # prepare store [5, 6, 7] with [2, 3] being loaded
# should fail because [2, 3] have ref_cnt > 0 # should fail because [2, 3] have ref_cnt > 0
assert cpu_manager.prepare_store(to_keys([5, 6, 7])) is None assert cpu_manager.prepare_store(to_keys([5, 6, 7]), _EMPTY_REQ_CTX) is None
# complete load [2, 3] # complete load [2, 3]
cpu_manager.complete_load(to_keys([2, 3])) cpu_manager.complete_load(to_keys([2, 3]))
# now prepare store [5, 6, 7] should succeed # now prepare store [5, 6, 7] should succeed
# ARC will evict blocks one at a time from T1 as needed # ARC will evict blocks one at a time from T1 as needed
prepare_store_output = cpu_manager.prepare_store(to_keys([5, 6, 7])) prepare_store_output = cpu_manager.prepare_store(
to_keys([5, 6, 7]), _EMPTY_REQ_CTX
)
assert prepare_store_output is not None assert prepare_store_output is not None
# Should successfully evict enough blocks to make room (at least 1) # Should successfully evict enough blocks to make room (at least 1)
assert len(prepare_store_output.evicted_keys) >= 1 assert len(prepare_store_output.evicted_keys) >= 1
...@@ -357,13 +374,13 @@ class TestARCPolicy: ...@@ -357,13 +374,13 @@ class TestARCPolicy:
cpu_manager, arc_policy = self._make_manager(num_blocks=2, enable_events=False) cpu_manager, arc_policy = self._make_manager(num_blocks=2, enable_events=False)
# store blocks 1, 2 (fills cache) # store blocks 1, 2 (fills cache)
cpu_manager.prepare_store(to_keys([1, 2])) cpu_manager.prepare_store(to_keys([1, 2]), _EMPTY_REQ_CTX)
cpu_manager.complete_store(to_keys([1, 2])) cpu_manager.complete_store(to_keys([1, 2]))
initial_target = arc_policy.target_t1_size initial_target = arc_policy.target_t1_size
# store block 3, evicting block 1 (moves to B1 ghost list) # store block 3, evicting block 1 (moves to B1 ghost list)
cpu_manager.prepare_store(to_keys([3])) cpu_manager.prepare_store(to_keys([3]), _EMPTY_REQ_CTX)
cpu_manager.complete_store(to_keys([3])) cpu_manager.complete_store(to_keys([3]))
# block 1 should be in B1 (ghost list) # block 1 should be in B1 (ghost list)
...@@ -384,7 +401,7 @@ class TestARCPolicy: ...@@ -384,7 +401,7 @@ class TestARCPolicy:
cpu_manager, arc_policy = self._make_manager(enable_events=False) cpu_manager, arc_policy = self._make_manager(enable_events=False)
# store blocks 1, 2, 3, 4 # store blocks 1, 2, 3, 4
cpu_manager.prepare_store(to_keys([1, 2, 3, 4])) cpu_manager.prepare_store(to_keys([1, 2, 3, 4]), _EMPTY_REQ_CTX)
cpu_manager.complete_store(to_keys([1, 2, 3, 4])) cpu_manager.complete_store(to_keys([1, 2, 3, 4]))
# promote blocks 3, 4 to T2 by touching them # promote blocks 3, 4 to T2 by touching them
...@@ -399,7 +416,7 @@ class TestARCPolicy: ...@@ -399,7 +416,7 @@ class TestARCPolicy:
arc_policy.target_t1_size = 1 arc_policy.target_t1_size = 1
# store block 5, should evict from T1 (block 1, LRU in T1) # store block 5, should evict from T1 (block 1, LRU in T1)
output = cpu_manager.prepare_store(to_keys([5])) output = cpu_manager.prepare_store(to_keys([5]), _EMPTY_REQ_CTX)
assert output is not None assert output is not None
assert to_keys([1]) == output.evicted_keys assert to_keys([1]) == output.evicted_keys
...@@ -418,12 +435,12 @@ class TestARCPolicy: ...@@ -418,12 +435,12 @@ class TestARCPolicy:
cpu_manager, arc_policy = self._make_manager(num_blocks=2, enable_events=False) cpu_manager, arc_policy = self._make_manager(num_blocks=2, enable_events=False)
# fill cache with blocks 1, 2 # fill cache with blocks 1, 2
cpu_manager.prepare_store(to_keys([1, 2])) cpu_manager.prepare_store(to_keys([1, 2]), _EMPTY_REQ_CTX)
cpu_manager.complete_store(to_keys([1, 2])) cpu_manager.complete_store(to_keys([1, 2]))
# store many blocks to fill ghost lists # store many blocks to fill ghost lists
for i in range(3, 20): for i in range(3, 20):
cpu_manager.prepare_store(to_keys([i])) cpu_manager.prepare_store(to_keys([i]), _EMPTY_REQ_CTX)
cpu_manager.complete_store(to_keys([i])) cpu_manager.complete_store(to_keys([i]))
# ghost lists should not exceed cache_capacity # ghost lists should not exceed cache_capacity
...@@ -438,7 +455,7 @@ class TestARCPolicy: ...@@ -438,7 +455,7 @@ class TestARCPolicy:
cpu_manager, arc_policy = self._make_manager() cpu_manager, arc_policy = self._make_manager()
# store blocks 1, 2, 3, 4 # store blocks 1, 2, 3, 4
cpu_manager.prepare_store(to_keys([1, 2, 3, 4])) cpu_manager.prepare_store(to_keys([1, 2, 3, 4]), _EMPTY_REQ_CTX)
cpu_manager.complete_store(to_keys([1, 2, 3, 4])) cpu_manager.complete_store(to_keys([1, 2, 3, 4]))
# promote 3, 4 to T2 # promote 3, 4 to T2
...@@ -453,7 +470,7 @@ class TestARCPolicy: ...@@ -453,7 +470,7 @@ class TestARCPolicy:
assert len(arc_policy.t2) == 3 assert len(arc_policy.t2) == 3
# store block 5, should evict from T1 (block 2, only one in T1) # store block 5, should evict from T1 (block 2, only one in T1)
prepare_store_output = cpu_manager.prepare_store(to_keys([5])) prepare_store_output = cpu_manager.prepare_store(to_keys([5]), _EMPTY_REQ_CTX)
verify_store_output( verify_store_output(
prepare_store_output, prepare_store_output,
ExpectedPrepareStoreOutput( ExpectedPrepareStoreOutput(
...@@ -471,11 +488,11 @@ class TestARCPolicy: ...@@ -471,11 +488,11 @@ class TestARCPolicy:
cpu_manager, arc_policy = self._make_manager() cpu_manager, arc_policy = self._make_manager()
# store blocks 1, 2, 3, 4 # store blocks 1, 2, 3, 4
cpu_manager.prepare_store(to_keys([1, 2, 3, 4])) cpu_manager.prepare_store(to_keys([1, 2, 3, 4]), _EMPTY_REQ_CTX)
cpu_manager.complete_store(to_keys([1, 2, 3, 4])) cpu_manager.complete_store(to_keys([1, 2, 3, 4]))
# prepare store block 5 (will evict block 1) # prepare store block 5 (will evict block 1)
prepare_store_output = cpu_manager.prepare_store(to_keys([5])) prepare_store_output = cpu_manager.prepare_store(to_keys([5]), _EMPTY_REQ_CTX)
assert prepare_store_output is not None assert prepare_store_output is not None
assert len(prepare_store_output.evicted_keys) == 1 assert len(prepare_store_output.evicted_keys) == 1
...@@ -483,7 +500,7 @@ class TestARCPolicy: ...@@ -483,7 +500,7 @@ class TestARCPolicy:
cpu_manager.complete_store(to_keys([5]), success=False) cpu_manager.complete_store(to_keys([5]), success=False)
# block 5 should not be in cache # block 5 should not be in cache
assert cpu_manager.lookup(to_keys([5])) == 0 assert cpu_manager.lookup(to_keys([5]), _EMPTY_REQ_CTX) == 0
# block 5 should not be in T1 or T2 # block 5 should not be in T1 or T2
assert to_keys([5])[0] not in arc_policy.t1 assert to_keys([5])[0] not in arc_policy.t1
assert to_keys([5])[0] not in arc_policy.t2 assert to_keys([5])[0] not in arc_policy.t2
...@@ -500,11 +517,13 @@ class TestARCPolicy: ...@@ -500,11 +517,13 @@ class TestARCPolicy:
cpu_manager, arc_policy = self._make_manager() cpu_manager, arc_policy = self._make_manager()
# store [1, 2] # store [1, 2]
cpu_manager.prepare_store(to_keys([1, 2])) cpu_manager.prepare_store(to_keys([1, 2]), _EMPTY_REQ_CTX)
cpu_manager.complete_store(to_keys([1, 2])) cpu_manager.complete_store(to_keys([1, 2]))
# store [3, 4, 5] -> evicts [1] # store [3, 4, 5] -> evicts [1]
prepare_store_output = cpu_manager.prepare_store(to_keys([3, 4, 5])) prepare_store_output = cpu_manager.prepare_store(
to_keys([3, 4, 5]), _EMPTY_REQ_CTX
)
assert prepare_store_output is not None assert prepare_store_output is not None
assert len(prepare_store_output.evicted_keys) == 1 assert len(prepare_store_output.evicted_keys) == 1
cpu_manager.complete_store(to_keys([3, 4, 5])) cpu_manager.complete_store(to_keys([3, 4, 5]))
...@@ -517,13 +536,13 @@ class TestARCPolicy: ...@@ -517,13 +536,13 @@ class TestARCPolicy:
assert len(arc_policy.t2) == 2 assert len(arc_policy.t2) == 2
# store [6] -> should evict from T1 (4 is oldest in T1) # store [6] -> should evict from T1 (4 is oldest in T1)
prepare_store_output = cpu_manager.prepare_store(to_keys([6])) prepare_store_output = cpu_manager.prepare_store(to_keys([6]), _EMPTY_REQ_CTX)
assert prepare_store_output is not None assert prepare_store_output is not None
cpu_manager.complete_store(to_keys([6])) cpu_manager.complete_store(to_keys([6]))
# verify blocks 2, 3 (in T2) are still present # verify blocks 2, 3 (in T2) are still present
assert cpu_manager.lookup(to_keys([2])) == 1 assert cpu_manager.lookup(to_keys([2]), _EMPTY_REQ_CTX) == 1
assert cpu_manager.lookup(to_keys([3])) == 1 assert cpu_manager.lookup(to_keys([3]), _EMPTY_REQ_CTX) == 1
# verify events # verify events
events = list(cpu_manager.take_events()) events = list(cpu_manager.take_events())
...@@ -543,34 +562,34 @@ def test_filter_reused_manager(): ...@@ -543,34 +562,34 @@ def test_filter_reused_manager():
) )
# Lookup [1, 2] -> 1st time, added to tracker but not eligible for store yet # Lookup [1, 2] -> 1st time, added to tracker but not eligible for store yet
assert manager.lookup(to_keys([1, 2])) == 0 assert manager.lookup(to_keys([1, 2]), _EMPTY_REQ_CTX) == 0
# prepare store [1, 2] -> should be filtered # prepare store [1, 2] -> should be filtered
prepare_store_output = manager.prepare_store(to_keys([1, 2])) prepare_store_output = manager.prepare_store(to_keys([1, 2]), _EMPTY_REQ_CTX)
assert prepare_store_output is not None assert prepare_store_output is not None
assert prepare_store_output.keys_to_store == [] assert prepare_store_output.keys_to_store == []
# Lookup [1] -> 2nd time, eligible now # Lookup [1] -> 2nd time, eligible now
assert manager.lookup(to_keys([1])) == 0 assert manager.lookup(to_keys([1]), _EMPTY_REQ_CTX) == 0
# prepare store [1, 2] -> [1] should be eligible, [2] should be filtered # prepare store [1, 2] -> [1] should be eligible, [2] should be filtered
prepare_store_output = manager.prepare_store(to_keys([1, 2])) prepare_store_output = manager.prepare_store(to_keys([1, 2]), _EMPTY_REQ_CTX)
assert prepare_store_output is not None assert prepare_store_output is not None
assert prepare_store_output.keys_to_store == to_keys([1]) assert prepare_store_output.keys_to_store == to_keys([1])
# Lookup [3, 4] -> 1st time # Lookup [3, 4] -> 1st time
# (evicts [2] from tracker since max_size is 3 and tracker has [1]) # (evicts [2] from tracker since max_size is 3 and tracker has [1])
assert manager.lookup(to_keys([3, 4])) == 0 assert manager.lookup(to_keys([3, 4]), _EMPTY_REQ_CTX) == 0
# Verify [2] was evicted from the tracker (tracker now has: [1], [3], [4]) # Verify [2] was evicted from the tracker (tracker now has: [1], [3], [4])
assert to_keys([2])[0] not in manager.counts assert to_keys([2])[0] not in manager.counts
# Lookup [2] again -> (this adds [2] back to the tracker as 1st time) # Lookup [2] again -> (this adds [2] back to the tracker as 1st time)
assert manager.lookup(to_keys([2])) == 0 assert manager.lookup(to_keys([2]), _EMPTY_REQ_CTX) == 0
# Verify [2] was re-added with count=1 (not eligible yet) # Verify [2] was re-added with count=1 (not eligible yet)
assert manager.counts.get(to_keys([2])[0]) == 1 assert manager.counts.get(to_keys([2])[0]) == 1
# prepare store [2] -> should still be filtered out since count was reset # prepare store [2] -> should still be filtered out since count was reset
prepare_store_output = manager.prepare_store(to_keys([2])) prepare_store_output = manager.prepare_store(to_keys([2]), _EMPTY_REQ_CTX)
assert prepare_store_output is not None assert prepare_store_output is not None
assert prepare_store_output.keys_to_store == [] assert prepare_store_output.keys_to_store == []
......
...@@ -19,6 +19,7 @@ from vllm.v1.core.sched.output import SchedulerOutput ...@@ -19,6 +19,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_offload.abstract import ( from vllm.v1.kv_offload.abstract import (
OffloadingManager, OffloadingManager,
OffloadKey, OffloadKey,
ReqContext,
get_offload_block_hash, get_offload_block_hash,
make_offload_key, make_offload_key,
) )
...@@ -74,6 +75,7 @@ class RequestOffloadState: ...@@ -74,6 +75,7 @@ class RequestOffloadState:
config: SchedulerOffloadConfig config: SchedulerOffloadConfig
req: Request req: Request
group_states: tuple[RequestGroupState, ...] = field(init=False) group_states: tuple[RequestGroupState, ...] = field(init=False)
req_context: ReqContext = field(init=False)
# number of hits in the GPU cache # number of hits in the GPU cache
num_locally_computed_tokens: int = 0 num_locally_computed_tokens: int = 0
...@@ -81,6 +83,7 @@ class RequestOffloadState: ...@@ -81,6 +83,7 @@ class RequestOffloadState:
self.group_states = tuple( self.group_states = tuple(
RequestGroupState() for _ in self.config.kv_group_configs RequestGroupState() for _ in self.config.kv_group_configs
) )
self.req_context = ReqContext(kv_transfer_params=self.req.kv_transfer_params)
def update_offload_keys(self) -> None: def update_offload_keys(self) -> None:
for group_config, group_state in zip( for group_config, group_state in zip(
...@@ -181,7 +184,10 @@ class OffloadingConnectorScheduler: ...@@ -181,7 +184,10 @@ class OffloadingConnectorScheduler:
return 0, False return 0, False
start_block_idx = num_computed_tokens // group_config.offloaded_block_size start_block_idx = num_computed_tokens // group_config.offloaded_block_size
hits = self.manager.lookup(offload_keys[start_block_idx:]) hits = self.manager.lookup(
offload_keys[start_block_idx:],
req_status.req_context,
)
if hits is None: if hits is None:
# indicates a lookup that should be tried later # indicates a lookup that should be tried later
return None, False return None, False
...@@ -249,7 +255,7 @@ class OffloadingConnectorScheduler: ...@@ -249,7 +255,7 @@ class OffloadingConnectorScheduler:
assert len(request.block_hashes) // self.config.block_size_factor >= num_blocks assert len(request.block_hashes) // self.config.block_size_factor >= num_blocks
offload_keys = group_state.offload_keys[start_block_idx:num_blocks] offload_keys = group_state.offload_keys[start_block_idx:num_blocks]
src_spec = self.manager.prepare_load(offload_keys) src_spec = self.manager.prepare_load(offload_keys, req_status.req_context)
dst_spec = GPULoadStoreSpec( dst_spec = GPULoadStoreSpec(
block_ids[num_computed_gpu_blocks:], block_ids[num_computed_gpu_blocks:],
group_sizes=(num_pending_gpu_blocks,), group_sizes=(num_pending_gpu_blocks,),
...@@ -304,7 +310,9 @@ class OffloadingConnectorScheduler: ...@@ -304,7 +310,9 @@ class OffloadingConnectorScheduler:
assert len(req.block_hashes) >= num_gpu_blocks assert len(req.block_hashes) >= num_gpu_blocks
new_offload_keys = group_state.offload_keys[start_block_idx:num_blocks] new_offload_keys = group_state.offload_keys[start_block_idx:num_blocks]
store_output = self.manager.prepare_store(new_offload_keys) store_output = self.manager.prepare_store(
new_offload_keys, req_status.req_context
)
if store_output is None: if store_output is None:
logger.warning( logger.warning(
"Request %s: cannot store %s blocks", req_id, num_new_blocks "Request %s: cannot store %s blocks", req_id, num_new_blocks
......
...@@ -30,7 +30,7 @@ The class provides the following primitives: ...@@ -30,7 +30,7 @@ The class provides the following primitives:
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Iterable from collections.abc import Iterable
from dataclasses import dataclass from dataclasses import dataclass
from typing import NewType from typing import Any, NewType
# `OffloadKey` identifies an offloaded block. It combines a block hash with # `OffloadKey` identifies an offloaded block. It combines a block hash with
# its KV cache group index, encoded as raw bytes to avoid tuple GC overhead. # its KV cache group index, encoded as raw bytes to avoid tuple GC overhead.
...@@ -53,6 +53,11 @@ def get_offload_group_idx(key: OffloadKey) -> int: ...@@ -53,6 +53,11 @@ def get_offload_group_idx(key: OffloadKey) -> int:
return int.from_bytes(key[-4:], "big", signed=False) return int.from_bytes(key[-4:], "big", signed=False)
@dataclass
class ReqContext:
kv_transfer_params: dict[str, Any] | None = None
class LoadStoreSpec(ABC): class LoadStoreSpec(ABC):
""" """
Abstract metadata that encapsulates information allowing a worker Abstract metadata that encapsulates information allowing a worker
...@@ -86,13 +91,18 @@ class OffloadingEvent: ...@@ -86,13 +91,18 @@ class OffloadingEvent:
class OffloadingManager(ABC): class OffloadingManager(ABC):
@abstractmethod @abstractmethod
def lookup(self, keys: Iterable[OffloadKey]) -> int | None: def lookup(
self,
keys: Iterable[OffloadKey],
req_context: ReqContext,
) -> int | None:
""" """
Finds the length of the maximal series of blocks, starting from the Finds the length of the maximal series of blocks, starting from the
first one, that are all offloaded. first one, that are all offloaded.
Args: Args:
keys: the keys identifying the blocks to lookup. keys: the keys identifying the blocks to lookup.
req_context: per-request context (e.g. kv_transfer_params).
Returns: Returns:
An integer representing the maximal number of blocks that An integer representing the maximal number of blocks that
...@@ -103,7 +113,11 @@ class OffloadingManager(ABC): ...@@ -103,7 +113,11 @@ class OffloadingManager(ABC):
pass pass
@abstractmethod @abstractmethod
def prepare_load(self, keys: Iterable[OffloadKey]) -> LoadStoreSpec: def prepare_load(
self,
keys: Iterable[OffloadKey],
req_context: ReqContext,
) -> LoadStoreSpec:
""" """
Prepare the given blocks to be read. Prepare the given blocks to be read.
The given blocks will be protected from eviction until The given blocks will be protected from eviction until
...@@ -112,6 +126,7 @@ class OffloadingManager(ABC): ...@@ -112,6 +126,7 @@ class OffloadingManager(ABC):
Args: Args:
keys: the keys identifying the blocks. keys: the keys identifying the blocks.
req_context: per-request context (e.g. kv_transfer_params).
Returns: Returns:
A LoadStoreSpec that can be used by a worker to locate and load A LoadStoreSpec that can be used by a worker to locate and load
...@@ -139,7 +154,11 @@ class OffloadingManager(ABC): ...@@ -139,7 +154,11 @@ class OffloadingManager(ABC):
return return
@abstractmethod @abstractmethod
def prepare_store(self, keys: Iterable[OffloadKey]) -> PrepareStoreOutput | None: def prepare_store(
self,
keys: Iterable[OffloadKey],
req_context: ReqContext,
) -> PrepareStoreOutput | None:
""" """
Prepare the given blocks to be offloaded. Prepare the given blocks to be offloaded.
The given blocks will be protected from eviction until The given blocks will be protected from eviction until
...@@ -147,6 +166,7 @@ class OffloadingManager(ABC): ...@@ -147,6 +166,7 @@ class OffloadingManager(ABC):
Args: Args:
keys: the keys identifying the blocks. keys: the keys identifying the blocks.
req_context: per-request context (e.g. kv_transfer_params).
Returns: Returns:
A PrepareStoreOutput indicating which blocks need storing, A PrepareStoreOutput indicating which blocks need storing,
......
...@@ -9,6 +9,7 @@ from vllm.v1.kv_offload.abstract import ( ...@@ -9,6 +9,7 @@ from vllm.v1.kv_offload.abstract import (
OffloadingManager, OffloadingManager,
OffloadKey, OffloadKey,
PrepareStoreOutput, PrepareStoreOutput,
ReqContext,
) )
from vllm.v1.kv_offload.cpu.policies.abstract import BlockStatus, CachePolicy from vllm.v1.kv_offload.cpu.policies.abstract import BlockStatus, CachePolicy
from vllm.v1.kv_offload.cpu.policies.arc import ARCCachePolicy from vllm.v1.kv_offload.cpu.policies.arc import ARCCachePolicy
...@@ -83,7 +84,11 @@ class CPUOffloadingManager(OffloadingManager): ...@@ -83,7 +84,11 @@ class CPUOffloadingManager(OffloadingManager):
# --- OffloadingManager interface --- # --- OffloadingManager interface ---
def lookup(self, keys: Iterable[OffloadKey]) -> int | None: def lookup(
self,
keys: Iterable[OffloadKey],
req_context: ReqContext,
) -> int | None:
hit_count = 0 hit_count = 0
for key in keys: for key in keys:
block = self._policy.get(key) block = self._policy.get(key)
...@@ -92,7 +97,11 @@ class CPUOffloadingManager(OffloadingManager): ...@@ -92,7 +97,11 @@ class CPUOffloadingManager(OffloadingManager):
hit_count += 1 hit_count += 1
return hit_count return hit_count
def prepare_load(self, keys: Iterable[OffloadKey]) -> LoadStoreSpec: def prepare_load(
self,
keys: Iterable[OffloadKey],
req_context: ReqContext,
) -> LoadStoreSpec:
blocks = [] blocks = []
for key in keys: for key in keys:
block = self._policy.get(key) block = self._policy.get(key)
...@@ -112,7 +121,11 @@ class CPUOffloadingManager(OffloadingManager): ...@@ -112,7 +121,11 @@ class CPUOffloadingManager(OffloadingManager):
assert block.ref_cnt > 0, f"Block {key!r} ref_cnt is already 0" assert block.ref_cnt > 0, f"Block {key!r} ref_cnt is already 0"
block.ref_cnt -= 1 block.ref_cnt -= 1
def prepare_store(self, keys: Iterable[OffloadKey]) -> PrepareStoreOutput | None: def prepare_store(
self,
keys: Iterable[OffloadKey],
req_context: ReqContext,
) -> PrepareStoreOutput | None:
keys_list = list(keys) keys_list = list(keys)
# filter out blocks that are already stored # filter out blocks that are already stored
......
...@@ -16,6 +16,7 @@ from vllm.v1.kv_offload.abstract import ( ...@@ -16,6 +16,7 @@ from vllm.v1.kv_offload.abstract import (
OffloadingManager, OffloadingManager,
OffloadKey, OffloadKey,
PrepareStoreOutput, PrepareStoreOutput,
ReqContext,
) )
...@@ -65,7 +66,7 @@ class FilterReusedOffloadingManager(OffloadingManager): ...@@ -65,7 +66,7 @@ class FilterReusedOffloadingManager(OffloadingManager):
# Intercepted methods # Intercepted methods
# ------------------------------------------------------------------ # ------------------------------------------------------------------
def lookup(self, keys: Iterable[OffloadKey]) -> int | None: def lookup(self, keys: Iterable[OffloadKey], req_context: ReqContext) -> int | None:
"""Record each key, then delegate lookup to backing manager.""" """Record each key, then delegate lookup to backing manager."""
keys = list(keys) keys = list(keys)
for key in keys: for key in keys:
...@@ -76,9 +77,11 @@ class FilterReusedOffloadingManager(OffloadingManager): ...@@ -76,9 +77,11 @@ class FilterReusedOffloadingManager(OffloadingManager):
if len(self.counts) >= self.max_tracker_size: if len(self.counts) >= self.max_tracker_size:
self.counts.popitem(last=False) # evict LRU self.counts.popitem(last=False) # evict LRU
self.counts[key] = 1 self.counts[key] = 1
return self._backing.lookup(keys) return self._backing.lookup(keys, req_context)
def prepare_store(self, keys: Iterable[OffloadKey]) -> PrepareStoreOutput | None: def prepare_store(
self, keys: Iterable[OffloadKey], req_context: ReqContext
) -> PrepareStoreOutput | None:
"""Filter out blocks below threshold, then delegate to backing. """Filter out blocks below threshold, then delegate to backing.
Filtering is evaluated *before* calling the backing manager's Filtering is evaluated *before* calling the backing manager's
...@@ -93,14 +96,16 @@ class FilterReusedOffloadingManager(OffloadingManager): ...@@ -93,14 +96,16 @@ class FilterReusedOffloadingManager(OffloadingManager):
# Passing an empty list is intentional and safe — CPUOffloadingManager # Passing an empty list is intentional and safe — CPUOffloadingManager
# handles it correctly, returning a PrepareStoreOutput with empty lists. # handles it correctly, returning a PrepareStoreOutput with empty lists.
# Delegate to the backing manager with only the eligible keys. # Delegate to the backing manager with only the eligible keys.
return self._backing.prepare_store(eligible) return self._backing.prepare_store(eligible, req_context)
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Delegated methods # Delegated methods
# ------------------------------------------------------------------ # ------------------------------------------------------------------
def prepare_load(
    self, keys: Iterable[OffloadKey], req_context: ReqContext
) -> LoadStoreSpec:
    """Pass-through to the backing manager's ``prepare_load``.

    Args:
        keys: the keys identifying the blocks.
        req_context: per-request context (e.g. kv_transfer_params),
            forwarded to the backing manager unchanged.

    Returns:
        Whatever the backing manager's ``prepare_load`` returns.
    """
    # A delegated method: no filtering or bookkeeping happens here.
    return self._backing.prepare_load(keys, req_context)
def touch(self, keys: Iterable[OffloadKey]) -> None:
    """Forward ``keys`` to the backing manager's ``touch`` unchanged.

    Args:
        keys: the keys identifying the blocks.
    """
    # A delegated method: the backing call's result is handed back as-is.
    return self._backing.touch(keys)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment