Unverified commit 4353c9cb, authored by omerpaz95, committed by GitHub
Browse files

[KV Offload] Pass request context (#39185)


Signed-off-by: omerpaz95 <omerpaz95@gmail.com>
parent 4b7f5ea1
......@@ -32,8 +32,8 @@ def test_offloading_connector(request_runner, async_scheduling: bool):
# 3 blocks, store just the middle block (skip first and last)
# blocks = [0, 1, 2], [3, 4, 5], [6, 7, 8]
runner.new_request(token_ids=[0] * offloaded_block_size * 3)
runner.manager.prepare_store.side_effect = lambda keys: generate_store_output(
list(keys)[1:2]
runner.manager.prepare_store.side_effect = (
lambda keys, req_context: generate_store_output(list(keys)[1:2])
)
runner.run(decoded_tokens=[0])
......@@ -45,18 +45,22 @@ def test_offloading_connector(request_runner, async_scheduling: bool):
runner.manager.prepare_store.assert_not_called()
# +1 token -> single block, fail prepare_store
runner.manager.prepare_store.side_effect = lambda keys: None
runner.manager.prepare_store.side_effect = lambda keys, req_context: None
runner.run(decoded_tokens=[0])
runner.manager.prepare_store.assert_called()
# 1 more block (+ token for async scheduling)
# now set block_hashes_to_store = []
runner.manager.prepare_store.side_effect = lambda keys: generate_store_output([])
runner.manager.prepare_store.side_effect = (
lambda keys, req_context: generate_store_output([])
)
runner.run(decoded_tokens=[0] * (offloaded_block_size + 1))
# 1 more block (+ token for kicking off offloading)
# now check touch was called with all 6 blocks
runner.manager.prepare_store.side_effect = lambda keys: generate_store_output(keys)
runner.manager.prepare_store.side_effect = (
lambda keys, req_context: generate_store_output(keys)
)
runner.run(
decoded_tokens=[0] * (offloaded_block_size + 1),
expected_stored_gpu_block_indexes=(15, 16, 17),
......@@ -89,13 +93,17 @@ def test_offloading_connector(request_runner, async_scheduling: bool):
runner.new_request(
token_ids=[0] * gpu_block_size + [1] * (offloaded_block_size - gpu_block_size)
)
runner.manager.prepare_store.side_effect = lambda keys: generate_store_output([])
runner.manager.prepare_store.side_effect = (
lambda keys, req_context: generate_store_output([])
)
runner.run(decoded_tokens=[EOS_TOKEN_ID])
runner.manager.lookup.assert_not_called()
# single block lookup with no hits
runner.new_request(token_ids=[1] * offloaded_block_size)
runner.manager.prepare_store.side_effect = lambda keys: generate_store_output([])
runner.manager.prepare_store.side_effect = (
lambda keys, req_context: generate_store_output([])
)
runner.run(decoded_tokens=[EOS_TOKEN_ID])
runner.manager.lookup.assert_called()
assert len(list(runner.manager.lookup.call_args.args[0])) == 1
......@@ -103,7 +111,9 @@ def test_offloading_connector(request_runner, async_scheduling: bool):
# single block lookup with a hit
runner.scheduler.reset_prefix_cache()
runner.new_request(token_ids=[0] * offloaded_block_size)
runner.manager.prepare_store.side_effect = lambda keys: generate_store_output([])
runner.manager.prepare_store.side_effect = (
lambda keys, req_context: generate_store_output([])
)
runner.manager.lookup.return_value = 1
runner.run(
decoded_tokens=[EOS_TOKEN_ID], expected_loaded_gpu_block_indexes=(0, 1, 2)
......@@ -113,7 +123,9 @@ def test_offloading_connector(request_runner, async_scheduling: bool):
runner.new_request(
token_ids=[0] * offloaded_block_size * 2 + [1] * offloaded_block_size
)
runner.manager.prepare_store.side_effect = lambda keys: generate_store_output([])
runner.manager.prepare_store.side_effect = (
lambda keys, req_context: generate_store_output([])
)
runner.manager.lookup.return_value = 1
runner.run(
decoded_tokens=[EOS_TOKEN_ID], expected_loaded_gpu_block_indexes=(3, 4, 5)
......@@ -164,14 +176,18 @@ def test_request_preemption(request_runner, async_scheduling: bool):
# 2 blocks, store all, without flushing
# blocks = [0, 1, 2], [3, 4, 5]
runner.new_request(token_ids=[0] * offloaded_block_size * 2)
runner.manager.prepare_store.side_effect = lambda keys: generate_store_output(keys)
runner.manager.prepare_store.side_effect = (
lambda keys, req_context: generate_store_output(keys)
)
runner.run(
decoded_tokens=[0],
complete_transfers=False,
)
# decode 2 more blocks - 1 gpu block, storing [6, 7, 8] (no flush)
runner.manager.prepare_store.side_effect = lambda keys: generate_store_output(keys)
runner.manager.prepare_store.side_effect = (
lambda keys, req_context: generate_store_output(keys)
)
runner.run(
decoded_tokens=[0] * (2 * offloaded_block_size - gpu_block_size),
complete_transfers=False,
......@@ -195,7 +211,9 @@ def test_request_preemption(request_runner, async_scheduling: bool):
# request should now return from preemption
# re-load [0, ..., 8] from the CPU and store [9, 10, 11]
runner.manager.lookup.return_value = 3
runner.manager.prepare_store.side_effect = lambda keys: generate_store_output(keys)
runner.manager.prepare_store.side_effect = (
lambda keys, req_context: generate_store_output(keys)
)
runner.run(
decoded_tokens=[0] * gpu_block_size,
expected_loaded_gpu_block_indexes=(0, 1, 2, 3, 4, 5, 6, 7, 8),
......@@ -222,7 +240,9 @@ def test_concurrent_lookups_of_the_same_prefix(request_runner, async_scheduling:
# store 1 block
runner.new_request(token_ids=[0] * offloaded_block_size)
runner.manager.prepare_store.side_effect = lambda keys: generate_store_output(keys)
runner.manager.prepare_store.side_effect = (
lambda keys, req_context: generate_store_output(keys)
)
runner.run(
decoded_tokens=[EOS_TOKEN_ID],
expected_stored_gpu_block_indexes=(0, 1, 2),
......@@ -253,7 +273,9 @@ def test_concurrent_lookups_of_the_same_prefix(request_runner, async_scheduling:
assert transfer_jobs == list(runner.offloading_spec.handler.transfer_specs)
# complete transfers
runner.manager.prepare_store.side_effect = lambda keys: generate_store_output([])
runner.manager.prepare_store.side_effect = (
lambda keys, req_context: generate_store_output([])
)
runner.run(
decoded_tokens=[EOS_TOKEN_ID],
expected_loaded_gpu_block_indexes=(0, 1, 2),
......@@ -278,7 +300,9 @@ def test_abort_loading_requests(request_runner, async_scheduling: bool):
# store 1 block
runner.new_request(token_ids=[0] * offloaded_block_size)
runner.manager.prepare_store.side_effect = lambda keys: generate_store_output(keys)
runner.manager.prepare_store.side_effect = (
lambda keys, req_context: generate_store_output(keys)
)
runner.run(
decoded_tokens=[EOS_TOKEN_ID],
expected_stored_gpu_block_indexes=(0, 1, 2),
......
......@@ -115,7 +115,7 @@ class MockOffloadingSpec(OffloadingSpec):
self.manager = MagicMock(spec=OffloadingManager)
self.manager.lookup.return_value = 0
self.manager.prepare_load = lambda keys: MockLoadStoreSpec(keys)
self.manager.prepare_load = lambda keys, req_context: MockLoadStoreSpec(keys)
self.handler = MockOffloadingHandler()
def get_manager(self) -> OffloadingManager:
......
......@@ -11,6 +11,7 @@ from vllm.v1.kv_offload.abstract import (
OffloadingEvent,
OffloadKey,
PrepareStoreOutput,
ReqContext,
make_offload_key,
)
from vllm.v1.kv_offload.cpu.manager import CPUOffloadingManager
......@@ -19,6 +20,14 @@ from vllm.v1.kv_offload.mediums import CPULoadStoreSpec
from vllm.v1.kv_offload.reuse_manager import FilterReusedOffloadingManager
def make_req_context(kv_transfer_params: dict | None = None) -> ReqContext:
    """Build the ReqContext handed to manager calls in these tests.

    Args:
        kv_transfer_params: optional per-request transfer parameters to
            wrap; defaults to None.

    Returns:
        A ReqContext whose kv_transfer_params is the given value.
    """
    req_context = ReqContext(kv_transfer_params=kv_transfer_params)
    return req_context
_EMPTY_REQ_CTX = make_req_context()
@dataclass
class ExpectedPrepareStoreOutput:
keys_to_store: list[int]
......@@ -103,7 +112,7 @@ def test_already_stored_block_not_evicted_during_prepare_store(eviction_policy):
)
# store [1, 2] and complete
manager.prepare_store(to_keys([1, 2]))
manager.prepare_store(to_keys([1, 2]), _EMPTY_REQ_CTX)
manager.complete_store(to_keys([1, 2]))
# touch [1] to make block 2 the LRU candidate
......@@ -113,7 +122,7 @@ def test_already_stored_block_not_evicted_during_prepare_store(eviction_policy):
# - block 2 is already stored -> filtered out of keys_to_store
# - block 2 must NOT be evicted even though it is the LRU candidate
# - block 1 (ID 0) is evicted instead; new blocks [3,4,5] get IDs 2,3,0
prepare_store_output = manager.prepare_store(to_keys([2, 3, 4, 5]))
prepare_store_output = manager.prepare_store(to_keys([2, 3, 4, 5]), _EMPTY_REQ_CTX)
verify_store_output(
prepare_store_output,
ExpectedPrepareStoreOutput(
......@@ -127,7 +136,7 @@ def test_already_stored_block_not_evicted_during_prepare_store(eviction_policy):
manager.complete_store(to_keys([2, 3, 4, 5]))
# block 2 must still be present in the cache
assert manager.lookup(to_keys([2])) == 1
assert manager.lookup(to_keys([2]), _EMPTY_REQ_CTX) == 1
def test_cpu_manager():
......@@ -140,7 +149,7 @@ def test_cpu_manager():
)
# prepare store [1, 2]
prepare_store_output = cpu_manager.prepare_store(to_keys([1, 2]))
prepare_store_output = cpu_manager.prepare_store(to_keys([1, 2]), _EMPTY_REQ_CTX)
verify_store_output(
prepare_store_output,
ExpectedPrepareStoreOutput(
......@@ -151,7 +160,7 @@ def test_cpu_manager():
)
# lookup [1, 2] -> not ready
assert cpu_manager.lookup(to_keys([1, 2])) == 0
assert cpu_manager.lookup(to_keys([1, 2]), _EMPTY_REQ_CTX) == 0
# no events so far
assert list(cpu_manager.take_events()) == []
......@@ -161,12 +170,14 @@ def test_cpu_manager():
verify_events(cpu_manager.take_events(), expected_stores=({1, 2},))
# lookup [1, 2]
assert cpu_manager.lookup(to_keys([1])) == 1
assert cpu_manager.lookup(to_keys([1, 2])) == 2
assert cpu_manager.lookup(to_keys([1, 2, 3])) == 2
assert cpu_manager.lookup(to_keys([1]), _EMPTY_REQ_CTX) == 1
assert cpu_manager.lookup(to_keys([1, 2]), _EMPTY_REQ_CTX) == 2
assert cpu_manager.lookup(to_keys([1, 2, 3]), _EMPTY_REQ_CTX) == 2
# prepare store [2, 3, 4, 5] -> evicts [1]
prepare_store_output = cpu_manager.prepare_store(to_keys([2, 3, 4, 5]))
prepare_store_output = cpu_manager.prepare_store(
to_keys([2, 3, 4, 5]), _EMPTY_REQ_CTX
)
verify_store_output(
prepare_store_output,
ExpectedPrepareStoreOutput(
......@@ -180,23 +191,23 @@ def test_cpu_manager():
verify_events(cpu_manager.take_events(), expected_evictions=({1},))
# prepare store with no space
assert cpu_manager.prepare_store(to_keys([1, 6])) is None
assert cpu_manager.prepare_store(to_keys([1, 6]), _EMPTY_REQ_CTX) is None
# complete store [2, 3, 4, 5]
cpu_manager.complete_store(to_keys([2, 3, 4, 5]))
# prepare load [2, 3]
prepare_load_output = cpu_manager.prepare_load(to_keys([2, 3]))
prepare_load_output = cpu_manager.prepare_load(to_keys([2, 3]), _EMPTY_REQ_CTX)
verify_load_output(prepare_load_output, [1, 2])
# prepare store with no space ([2, 3] is being loaded)
assert cpu_manager.prepare_store(to_keys([6, 7, 8])) is None
assert cpu_manager.prepare_store(to_keys([6, 7, 8]), _EMPTY_REQ_CTX) is None
# complete load [2, 3]
cpu_manager.complete_load(to_keys([2, 3]))
# prepare store [6, 7, 8] -> evicts [2, 3, 4] (oldest)
prepare_store_output = cpu_manager.prepare_store(to_keys([6, 7, 8]))
prepare_store_output = cpu_manager.prepare_store(to_keys([6, 7, 8]), _EMPTY_REQ_CTX)
verify_store_output(
prepare_store_output,
ExpectedPrepareStoreOutput(
......@@ -213,7 +224,7 @@ def test_cpu_manager():
cpu_manager.touch(to_keys([5, 6, 7]))
# prepare store [7, 9] -> evicts [8] (oldest following previous touch)
prepare_store_output = cpu_manager.prepare_store(to_keys([9]))
prepare_store_output = cpu_manager.prepare_store(to_keys([9]), _EMPTY_REQ_CTX)
verify_store_output(
prepare_store_output,
ExpectedPrepareStoreOutput(
......@@ -227,8 +238,8 @@ def test_cpu_manager():
cpu_manager.complete_store(to_keys([7, 9]), success=False)
# assert [7] is still stored, but [9] is not
assert cpu_manager.lookup(to_keys([7])) == 1
assert cpu_manager.lookup(to_keys([9])) == 0
assert cpu_manager.lookup(to_keys([7]), _EMPTY_REQ_CTX) == 1
assert cpu_manager.lookup(to_keys([9]), _EMPTY_REQ_CTX) == 0
verify_events(
cpu_manager.take_events(),
......@@ -260,7 +271,9 @@ class TestARCPolicy:
cpu_manager, arc_policy = self._make_manager()
# prepare store [1, 2]
prepare_store_output = cpu_manager.prepare_store(to_keys([1, 2]))
prepare_store_output = cpu_manager.prepare_store(
to_keys([1, 2]), _EMPTY_REQ_CTX
)
verify_store_output(
prepare_store_output,
ExpectedPrepareStoreOutput(
......@@ -271,7 +284,7 @@ class TestARCPolicy:
)
# lookup [1, 2] -> not ready
assert cpu_manager.lookup(to_keys([1, 2])) == 0
assert cpu_manager.lookup(to_keys([1, 2]), _EMPTY_REQ_CTX) == 0
# no events so far
assert list(cpu_manager.take_events()) == []
......@@ -281,9 +294,9 @@ class TestARCPolicy:
verify_events(cpu_manager.take_events(), expected_stores=({1, 2},))
# lookup [1, 2]
assert cpu_manager.lookup(to_keys([1])) == 1
assert cpu_manager.lookup(to_keys([1, 2])) == 2
assert cpu_manager.lookup(to_keys([1, 2, 3])) == 2
assert cpu_manager.lookup(to_keys([1]), _EMPTY_REQ_CTX) == 1
assert cpu_manager.lookup(to_keys([1, 2]), _EMPTY_REQ_CTX) == 2
assert cpu_manager.lookup(to_keys([1, 2, 3]), _EMPTY_REQ_CTX) == 2
# blocks should be in T1 (recent)
assert len(arc_policy.t1) == 2
......@@ -297,7 +310,7 @@ class TestARCPolicy:
cpu_manager, arc_policy = self._make_manager(enable_events=False)
# store and complete block 1
cpu_manager.prepare_store(to_keys([1]))
cpu_manager.prepare_store(to_keys([1]), _EMPTY_REQ_CTX)
cpu_manager.complete_store(to_keys([1]))
# block 1 starts in T1 (recent)
......@@ -319,7 +332,9 @@ class TestARCPolicy:
cpu_manager, _ = self._make_manager()
# prepare and complete store [1, 2, 3, 4]
prepare_store_output = cpu_manager.prepare_store(to_keys([1, 2, 3, 4]))
prepare_store_output = cpu_manager.prepare_store(
to_keys([1, 2, 3, 4]), _EMPTY_REQ_CTX
)
verify_store_output(
prepare_store_output,
ExpectedPrepareStoreOutput(
......@@ -331,19 +346,21 @@ class TestARCPolicy:
cpu_manager.complete_store(to_keys([1, 2, 3, 4]))
# prepare load [2, 3] (increases ref_cnt)
prepare_load_output = cpu_manager.prepare_load(to_keys([2, 3]))
prepare_load_output = cpu_manager.prepare_load(to_keys([2, 3]), _EMPTY_REQ_CTX)
verify_load_output(prepare_load_output, [1, 2])
# prepare store [5, 6, 7] with [2, 3] being loaded
# should fail because [2, 3] have ref_cnt > 0
assert cpu_manager.prepare_store(to_keys([5, 6, 7])) is None
assert cpu_manager.prepare_store(to_keys([5, 6, 7]), _EMPTY_REQ_CTX) is None
# complete load [2, 3]
cpu_manager.complete_load(to_keys([2, 3]))
# now prepare store [5, 6, 7] should succeed
# ARC will evict blocks one at a time from T1 as needed
prepare_store_output = cpu_manager.prepare_store(to_keys([5, 6, 7]))
prepare_store_output = cpu_manager.prepare_store(
to_keys([5, 6, 7]), _EMPTY_REQ_CTX
)
assert prepare_store_output is not None
# Should successfully evict enough blocks to make room (at least 1)
assert len(prepare_store_output.evicted_keys) >= 1
......@@ -357,13 +374,13 @@ class TestARCPolicy:
cpu_manager, arc_policy = self._make_manager(num_blocks=2, enable_events=False)
# store blocks 1, 2 (fills cache)
cpu_manager.prepare_store(to_keys([1, 2]))
cpu_manager.prepare_store(to_keys([1, 2]), _EMPTY_REQ_CTX)
cpu_manager.complete_store(to_keys([1, 2]))
initial_target = arc_policy.target_t1_size
# store block 3, evicting block 1 (moves to B1 ghost list)
cpu_manager.prepare_store(to_keys([3]))
cpu_manager.prepare_store(to_keys([3]), _EMPTY_REQ_CTX)
cpu_manager.complete_store(to_keys([3]))
# block 1 should be in B1 (ghost list)
......@@ -384,7 +401,7 @@ class TestARCPolicy:
cpu_manager, arc_policy = self._make_manager(enable_events=False)
# store blocks 1, 2, 3, 4
cpu_manager.prepare_store(to_keys([1, 2, 3, 4]))
cpu_manager.prepare_store(to_keys([1, 2, 3, 4]), _EMPTY_REQ_CTX)
cpu_manager.complete_store(to_keys([1, 2, 3, 4]))
# promote blocks 3, 4 to T2 by touching them
......@@ -399,7 +416,7 @@ class TestARCPolicy:
arc_policy.target_t1_size = 1
# store block 5, should evict from T1 (block 1, LRU in T1)
output = cpu_manager.prepare_store(to_keys([5]))
output = cpu_manager.prepare_store(to_keys([5]), _EMPTY_REQ_CTX)
assert output is not None
assert to_keys([1]) == output.evicted_keys
......@@ -418,12 +435,12 @@ class TestARCPolicy:
cpu_manager, arc_policy = self._make_manager(num_blocks=2, enable_events=False)
# fill cache with blocks 1, 2
cpu_manager.prepare_store(to_keys([1, 2]))
cpu_manager.prepare_store(to_keys([1, 2]), _EMPTY_REQ_CTX)
cpu_manager.complete_store(to_keys([1, 2]))
# store many blocks to fill ghost lists
for i in range(3, 20):
cpu_manager.prepare_store(to_keys([i]))
cpu_manager.prepare_store(to_keys([i]), _EMPTY_REQ_CTX)
cpu_manager.complete_store(to_keys([i]))
# ghost lists should not exceed cache_capacity
......@@ -438,7 +455,7 @@ class TestARCPolicy:
cpu_manager, arc_policy = self._make_manager()
# store blocks 1, 2, 3, 4
cpu_manager.prepare_store(to_keys([1, 2, 3, 4]))
cpu_manager.prepare_store(to_keys([1, 2, 3, 4]), _EMPTY_REQ_CTX)
cpu_manager.complete_store(to_keys([1, 2, 3, 4]))
# promote 3, 4 to T2
......@@ -453,7 +470,7 @@ class TestARCPolicy:
assert len(arc_policy.t2) == 3
# store block 5, should evict from T1 (block 2, only one in T1)
prepare_store_output = cpu_manager.prepare_store(to_keys([5]))
prepare_store_output = cpu_manager.prepare_store(to_keys([5]), _EMPTY_REQ_CTX)
verify_store_output(
prepare_store_output,
ExpectedPrepareStoreOutput(
......@@ -471,11 +488,11 @@ class TestARCPolicy:
cpu_manager, arc_policy = self._make_manager()
# store blocks 1, 2, 3, 4
cpu_manager.prepare_store(to_keys([1, 2, 3, 4]))
cpu_manager.prepare_store(to_keys([1, 2, 3, 4]), _EMPTY_REQ_CTX)
cpu_manager.complete_store(to_keys([1, 2, 3, 4]))
# prepare store block 5 (will evict block 1)
prepare_store_output = cpu_manager.prepare_store(to_keys([5]))
prepare_store_output = cpu_manager.prepare_store(to_keys([5]), _EMPTY_REQ_CTX)
assert prepare_store_output is not None
assert len(prepare_store_output.evicted_keys) == 1
......@@ -483,7 +500,7 @@ class TestARCPolicy:
cpu_manager.complete_store(to_keys([5]), success=False)
# block 5 should not be in cache
assert cpu_manager.lookup(to_keys([5])) == 0
assert cpu_manager.lookup(to_keys([5]), _EMPTY_REQ_CTX) == 0
# block 5 should not be in T1 or T2
assert to_keys([5])[0] not in arc_policy.t1
assert to_keys([5])[0] not in arc_policy.t2
......@@ -500,11 +517,13 @@ class TestARCPolicy:
cpu_manager, arc_policy = self._make_manager()
# store [1, 2]
cpu_manager.prepare_store(to_keys([1, 2]))
cpu_manager.prepare_store(to_keys([1, 2]), _EMPTY_REQ_CTX)
cpu_manager.complete_store(to_keys([1, 2]))
# store [3, 4, 5] -> evicts [1]
prepare_store_output = cpu_manager.prepare_store(to_keys([3, 4, 5]))
prepare_store_output = cpu_manager.prepare_store(
to_keys([3, 4, 5]), _EMPTY_REQ_CTX
)
assert prepare_store_output is not None
assert len(prepare_store_output.evicted_keys) == 1
cpu_manager.complete_store(to_keys([3, 4, 5]))
......@@ -517,13 +536,13 @@ class TestARCPolicy:
assert len(arc_policy.t2) == 2
# store [6] -> should evict from T1 (4 is oldest in T1)
prepare_store_output = cpu_manager.prepare_store(to_keys([6]))
prepare_store_output = cpu_manager.prepare_store(to_keys([6]), _EMPTY_REQ_CTX)
assert prepare_store_output is not None
cpu_manager.complete_store(to_keys([6]))
# verify blocks 2, 3 (in T2) are still present
assert cpu_manager.lookup(to_keys([2])) == 1
assert cpu_manager.lookup(to_keys([3])) == 1
assert cpu_manager.lookup(to_keys([2]), _EMPTY_REQ_CTX) == 1
assert cpu_manager.lookup(to_keys([3]), _EMPTY_REQ_CTX) == 1
# verify events
events = list(cpu_manager.take_events())
......@@ -543,34 +562,34 @@ def test_filter_reused_manager():
)
# Lookup [1, 2] -> 1st time, added to tracker but not eligible for store yet
assert manager.lookup(to_keys([1, 2])) == 0
assert manager.lookup(to_keys([1, 2]), _EMPTY_REQ_CTX) == 0
# prepare store [1, 2] -> should be filtered
prepare_store_output = manager.prepare_store(to_keys([1, 2]))
prepare_store_output = manager.prepare_store(to_keys([1, 2]), _EMPTY_REQ_CTX)
assert prepare_store_output is not None
assert prepare_store_output.keys_to_store == []
# Lookup [1] -> 2nd time, eligible now
assert manager.lookup(to_keys([1])) == 0
assert manager.lookup(to_keys([1]), _EMPTY_REQ_CTX) == 0
# prepare store [1, 2] -> [1] should be eligible, [2] should be filtered
prepare_store_output = manager.prepare_store(to_keys([1, 2]))
prepare_store_output = manager.prepare_store(to_keys([1, 2]), _EMPTY_REQ_CTX)
assert prepare_store_output is not None
assert prepare_store_output.keys_to_store == to_keys([1])
# Lookup [3, 4] -> 1st time
# (evicts [2] from tracker since max_size is 3 and tracker has [1])
assert manager.lookup(to_keys([3, 4])) == 0
assert manager.lookup(to_keys([3, 4]), _EMPTY_REQ_CTX) == 0
# Verify [2] was evicted from the tracker (tracker now has: [1], [3], [4])
assert to_keys([2])[0] not in manager.counts
# Lookup [2] again -> (this adds [2] back to the tracker as 1st time)
assert manager.lookup(to_keys([2])) == 0
assert manager.lookup(to_keys([2]), _EMPTY_REQ_CTX) == 0
# Verify [2] was re-added with count=1 (not eligible yet)
assert manager.counts.get(to_keys([2])[0]) == 1
# prepare store [2] -> should still be filtered out since count was reset
prepare_store_output = manager.prepare_store(to_keys([2]))
prepare_store_output = manager.prepare_store(to_keys([2]), _EMPTY_REQ_CTX)
assert prepare_store_output is not None
assert prepare_store_output.keys_to_store == []
......
......@@ -19,6 +19,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_offload.abstract import (
OffloadingManager,
OffloadKey,
ReqContext,
get_offload_block_hash,
make_offload_key,
)
......@@ -74,6 +75,7 @@ class RequestOffloadState:
config: SchedulerOffloadConfig
req: Request
group_states: tuple[RequestGroupState, ...] = field(init=False)
req_context: ReqContext = field(init=False)
# number of hits in the GPU cache
num_locally_computed_tokens: int = 0
......@@ -81,6 +83,7 @@ class RequestOffloadState:
self.group_states = tuple(
RequestGroupState() for _ in self.config.kv_group_configs
)
self.req_context = ReqContext(kv_transfer_params=self.req.kv_transfer_params)
def update_offload_keys(self) -> None:
for group_config, group_state in zip(
......@@ -181,7 +184,10 @@ class OffloadingConnectorScheduler:
return 0, False
start_block_idx = num_computed_tokens // group_config.offloaded_block_size
hits = self.manager.lookup(offload_keys[start_block_idx:])
hits = self.manager.lookup(
offload_keys[start_block_idx:],
req_status.req_context,
)
if hits is None:
# indicates a lookup that should be tried later
return None, False
......@@ -249,7 +255,7 @@ class OffloadingConnectorScheduler:
assert len(request.block_hashes) // self.config.block_size_factor >= num_blocks
offload_keys = group_state.offload_keys[start_block_idx:num_blocks]
src_spec = self.manager.prepare_load(offload_keys)
src_spec = self.manager.prepare_load(offload_keys, req_status.req_context)
dst_spec = GPULoadStoreSpec(
block_ids[num_computed_gpu_blocks:],
group_sizes=(num_pending_gpu_blocks,),
......@@ -304,7 +310,9 @@ class OffloadingConnectorScheduler:
assert len(req.block_hashes) >= num_gpu_blocks
new_offload_keys = group_state.offload_keys[start_block_idx:num_blocks]
store_output = self.manager.prepare_store(new_offload_keys)
store_output = self.manager.prepare_store(
new_offload_keys, req_status.req_context
)
if store_output is None:
logger.warning(
"Request %s: cannot store %s blocks", req_id, num_new_blocks
......
......@@ -30,7 +30,7 @@ The class provides the following primitives:
from abc import ABC, abstractmethod
from collections.abc import Iterable
from dataclasses import dataclass
from typing import NewType
from typing import Any, NewType
# `OffloadKey` identifies an offloaded block. It combines a block hash with
# its KV cache group index, encoded as raw bytes to avoid tuple GC overhead.
......@@ -53,6 +53,11 @@ def get_offload_group_idx(key: OffloadKey) -> int:
return int.from_bytes(key[-4:], "big", signed=False)
@dataclass
class ReqContext:
    """Per-request context passed alongside the keys to OffloadingManager
    lookup / prepare_load / prepare_store calls."""

    # The request's kv_transfer_params; None when the request has none.
    kv_transfer_params: dict[str, Any] | None = None
class LoadStoreSpec(ABC):
"""
Abstract metadata that encapsulates information allowing a worker
......@@ -86,13 +91,18 @@ class OffloadingEvent:
class OffloadingManager(ABC):
@abstractmethod
def lookup(self, keys: Iterable[OffloadKey]) -> int | None:
def lookup(
self,
keys: Iterable[OffloadKey],
req_context: ReqContext,
) -> int | None:
"""
Finds the length of the maximal series of blocks, starting from the
first one, that are all offloaded.
Args:
keys: the keys identifying the blocks to lookup.
req_context: per-request context (e.g. kv_transfer_params).
Returns:
An integer representing the maximal number of blocks that
......@@ -103,7 +113,11 @@ class OffloadingManager(ABC):
pass
@abstractmethod
def prepare_load(self, keys: Iterable[OffloadKey]) -> LoadStoreSpec:
def prepare_load(
self,
keys: Iterable[OffloadKey],
req_context: ReqContext,
) -> LoadStoreSpec:
"""
Prepare the given blocks to be read.
The given blocks will be protected from eviction until
......@@ -112,6 +126,7 @@ class OffloadingManager(ABC):
Args:
keys: the keys identifying the blocks.
req_context: per-request context (e.g. kv_transfer_params).
Returns:
A LoadStoreSpec that can be used by a worker to locate and load
......@@ -139,7 +154,11 @@ class OffloadingManager(ABC):
return
@abstractmethod
def prepare_store(self, keys: Iterable[OffloadKey]) -> PrepareStoreOutput | None:
def prepare_store(
self,
keys: Iterable[OffloadKey],
req_context: ReqContext,
) -> PrepareStoreOutput | None:
"""
Prepare the given blocks to be offloaded.
The given blocks will be protected from eviction until
......@@ -147,6 +166,7 @@ class OffloadingManager(ABC):
Args:
keys: the keys identifying the blocks.
req_context: per-request context (e.g. kv_transfer_params).
Returns:
A PrepareStoreOutput indicating which blocks need storing,
......
......@@ -9,6 +9,7 @@ from vllm.v1.kv_offload.abstract import (
OffloadingManager,
OffloadKey,
PrepareStoreOutput,
ReqContext,
)
from vllm.v1.kv_offload.cpu.policies.abstract import BlockStatus, CachePolicy
from vllm.v1.kv_offload.cpu.policies.arc import ARCCachePolicy
......@@ -83,7 +84,11 @@ class CPUOffloadingManager(OffloadingManager):
# --- OffloadingManager interface ---
def lookup(self, keys: Iterable[OffloadKey]) -> int | None:
def lookup(
self,
keys: Iterable[OffloadKey],
req_context: ReqContext,
) -> int | None:
hit_count = 0
for key in keys:
block = self._policy.get(key)
......@@ -92,7 +97,11 @@ class CPUOffloadingManager(OffloadingManager):
hit_count += 1
return hit_count
def prepare_load(self, keys: Iterable[OffloadKey]) -> LoadStoreSpec:
def prepare_load(
self,
keys: Iterable[OffloadKey],
req_context: ReqContext,
) -> LoadStoreSpec:
blocks = []
for key in keys:
block = self._policy.get(key)
......@@ -112,7 +121,11 @@ class CPUOffloadingManager(OffloadingManager):
assert block.ref_cnt > 0, f"Block {key!r} ref_cnt is already 0"
block.ref_cnt -= 1
def prepare_store(self, keys: Iterable[OffloadKey]) -> PrepareStoreOutput | None:
def prepare_store(
self,
keys: Iterable[OffloadKey],
req_context: ReqContext,
) -> PrepareStoreOutput | None:
keys_list = list(keys)
# filter out blocks that are already stored
......
......@@ -16,6 +16,7 @@ from vllm.v1.kv_offload.abstract import (
OffloadingManager,
OffloadKey,
PrepareStoreOutput,
ReqContext,
)
......@@ -65,7 +66,7 @@ class FilterReusedOffloadingManager(OffloadingManager):
# Intercepted methods
# ------------------------------------------------------------------
def lookup(self, keys: Iterable[OffloadKey]) -> int | None:
def lookup(self, keys: Iterable[OffloadKey], req_context: ReqContext) -> int | None:
"""Record each key, then delegate lookup to backing manager."""
keys = list(keys)
for key in keys:
......@@ -76,9 +77,11 @@ class FilterReusedOffloadingManager(OffloadingManager):
if len(self.counts) >= self.max_tracker_size:
self.counts.popitem(last=False) # evict LRU
self.counts[key] = 1
return self._backing.lookup(keys)
return self._backing.lookup(keys, req_context)
def prepare_store(self, keys: Iterable[OffloadKey]) -> PrepareStoreOutput | None:
def prepare_store(
self, keys: Iterable[OffloadKey], req_context: ReqContext
) -> PrepareStoreOutput | None:
"""Filter out blocks below threshold, then delegate to backing.
Filtering is evaluated *before* calling the backing manager's
......@@ -93,14 +96,16 @@ class FilterReusedOffloadingManager(OffloadingManager):
# Passing an empty list is intentional and safe — CPUOffloadingManager
# handles it correctly, returning a PrepareStoreOutput with empty lists.
# Delegate to the backing manager with only the eligible keys.
return self._backing.prepare_store(eligible)
return self._backing.prepare_store(eligible, req_context)
# ------------------------------------------------------------------
# Delegated methods
# ------------------------------------------------------------------
def prepare_load(self, keys: Iterable[OffloadKey]) -> LoadStoreSpec:
return self._backing.prepare_load(keys)
def prepare_load(
    self, keys: Iterable[OffloadKey], req_context: ReqContext
) -> LoadStoreSpec:
    """Pass the load preparation straight through to the backing manager.

    Args:
        keys: the keys identifying the blocks to load.
        req_context: per-request context, forwarded untouched.

    Returns:
        Whatever LoadStoreSpec the backing manager produces.
    """
    backing = self._backing
    return backing.prepare_load(keys, req_context)
def touch(self, keys: Iterable[OffloadKey]) -> None:
return self._backing.touch(keys)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment