Unverified Commit f774ba02 authored by Or Ozeri's avatar Or Ozeri Committed by GitHub
Browse files

[kv_offload+HMA][4/N]: Support sliding window lookup (#36645)


Signed-off-by: Or Ozeri <oro@il.ibm.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
parent 2aab9acf
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable from collections.abc import Iterable
from unittest.mock import MagicMock
import pytest import pytest
...@@ -10,8 +11,16 @@ from tests.v1.kv_connector.unit.offloading_connector.utils import ( ...@@ -10,8 +11,16 @@ from tests.v1.kv_connector.unit.offloading_connector.utils import (
) )
from tests.v1.kv_connector.unit.utils import EOS_TOKEN_ID from tests.v1.kv_connector.unit.utils import EOS_TOKEN_ID
from vllm.distributed.kv_events import BlockRemoved, BlockStored from vllm.distributed.kv_events import BlockRemoved, BlockStored
from vllm.distributed.kv_transfer.kv_connector.v1.offloading.scheduler import (
OffloadingConnectorScheduler,
)
from vllm.v1.core.kv_cache_utils import BlockHash from vllm.v1.core.kv_cache_utils import BlockHash
from vllm.v1.kv_offload.abstract import OffloadingEvent from vllm.v1.kv_offload.abstract import (
OffloadingEvent,
OffloadingManager,
ReqContext,
get_offload_block_hash,
)
from vllm.v1.request import RequestStatus from vllm.v1.request import RequestStatus
...@@ -105,8 +114,7 @@ def test_offloading_connector(request_runner, async_scheduling: bool): ...@@ -105,8 +114,7 @@ def test_offloading_connector(request_runner, async_scheduling: bool):
lambda keys, req_context: generate_store_output([]) lambda keys, req_context: generate_store_output([])
) )
runner.run(decoded_tokens=[EOS_TOKEN_ID]) runner.run(decoded_tokens=[EOS_TOKEN_ID])
runner.manager.lookup.assert_called() runner.manager.lookup.assert_called_once()
assert len(list(runner.manager.lookup.call_args.args[0])) == 1
# single block lookup with a hit # single block lookup with a hit
runner.scheduler.reset_prefix_cache() runner.scheduler.reset_prefix_cache()
...@@ -114,7 +122,7 @@ def test_offloading_connector(request_runner, async_scheduling: bool): ...@@ -114,7 +122,7 @@ def test_offloading_connector(request_runner, async_scheduling: bool):
runner.manager.prepare_store.side_effect = ( runner.manager.prepare_store.side_effect = (
lambda keys, req_context: generate_store_output([]) lambda keys, req_context: generate_store_output([])
) )
runner.manager.lookup.return_value = 1 runner.connector_scheduler._maximal_prefix_lookup = lambda key, req_context: 1
runner.run( runner.run(
decoded_tokens=[EOS_TOKEN_ID], expected_loaded_gpu_block_indexes=(0, 1, 2) decoded_tokens=[EOS_TOKEN_ID], expected_loaded_gpu_block_indexes=(0, 1, 2)
) )
...@@ -126,7 +134,7 @@ def test_offloading_connector(request_runner, async_scheduling: bool): ...@@ -126,7 +134,7 @@ def test_offloading_connector(request_runner, async_scheduling: bool):
runner.manager.prepare_store.side_effect = ( runner.manager.prepare_store.side_effect = (
lambda keys, req_context: generate_store_output([]) lambda keys, req_context: generate_store_output([])
) )
runner.manager.lookup.return_value = 1 runner.connector_scheduler._maximal_prefix_lookup = lambda key, req_context: 1
runner.run( runner.run(
decoded_tokens=[EOS_TOKEN_ID], expected_loaded_gpu_block_indexes=(3, 4, 5) decoded_tokens=[EOS_TOKEN_ID], expected_loaded_gpu_block_indexes=(3, 4, 5)
) )
...@@ -210,7 +218,7 @@ def test_request_preemption(request_runner, async_scheduling: bool): ...@@ -210,7 +218,7 @@ def test_request_preemption(request_runner, async_scheduling: bool):
# request should now return from preemption # request should now return from preemption
# re-load [0, ..., 8] from the CPU and store [9, 10, 11] # re-load [0, ..., 8] from the CPU and store [9, 10, 11]
runner.manager.lookup.return_value = 3 runner.connector_scheduler._maximal_prefix_lookup = lambda key, req_context: 3
runner.manager.prepare_store.side_effect = ( runner.manager.prepare_store.side_effect = (
lambda keys, req_context: generate_store_output(keys) lambda keys, req_context: generate_store_output(keys)
) )
...@@ -251,7 +259,7 @@ def test_concurrent_lookups_of_the_same_prefix(request_runner, async_scheduling: ...@@ -251,7 +259,7 @@ def test_concurrent_lookups_of_the_same_prefix(request_runner, async_scheduling:
# start a request to load the first block, but don't complete # start a request to load the first block, but don't complete
runner.scheduler.reset_prefix_cache() runner.scheduler.reset_prefix_cache()
runner.new_request(token_ids=[0] * offloaded_block_size) runner.new_request(token_ids=[0] * offloaded_block_size)
runner.manager.lookup.return_value = 1 runner.connector_scheduler._maximal_prefix_lookup = lambda key, req_context: 1
runner.run( runner.run(
decoded_tokens=[], decoded_tokens=[],
complete_transfers=False, complete_transfers=False,
...@@ -263,7 +271,7 @@ def test_concurrent_lookups_of_the_same_prefix(request_runner, async_scheduling: ...@@ -263,7 +271,7 @@ def test_concurrent_lookups_of_the_same_prefix(request_runner, async_scheduling:
# start a new request to load the same first block # start a new request to load the same first block
runner.new_request(token_ids=[0] * offloaded_block_size) runner.new_request(token_ids=[0] * offloaded_block_size)
runner.manager.lookup.return_value = 1 runner.connector_scheduler._maximal_prefix_lookup = lambda key, req_context: 1
runner.run( runner.run(
decoded_tokens=[], decoded_tokens=[],
complete_transfers=False, complete_transfers=False,
...@@ -311,7 +319,7 @@ def test_abort_loading_requests(request_runner, async_scheduling: bool): ...@@ -311,7 +319,7 @@ def test_abort_loading_requests(request_runner, async_scheduling: bool):
# start a request to load the first block, but don't complete # start a request to load the first block, but don't complete
runner.scheduler.reset_prefix_cache() runner.scheduler.reset_prefix_cache()
runner.new_request(token_ids=[0] * offloaded_block_size) runner.new_request(token_ids=[0] * offloaded_block_size)
runner.manager.lookup.return_value = 1 runner.connector_scheduler._maximal_prefix_lookup = lambda key, req_context: 1
runner.run( runner.run(
decoded_tokens=[], decoded_tokens=[],
complete_transfers=False, complete_transfers=False,
...@@ -336,3 +344,137 @@ def test_abort_loading_requests(request_runner, async_scheduling: bool): ...@@ -336,3 +344,137 @@ def test_abort_loading_requests(request_runner, async_scheduling: bool):
# assert request is deleted # assert request is deleted
assert req_id not in runner.scheduler.requests assert req_id not in runner.scheduler.requests
# ---------------------------------------------------------------------------
# Unit tests for _maximal_prefix_lookup / _sliding_window_lookup
# ---------------------------------------------------------------------------
def _make_scheduler_with_lookup(
    lookup_results: dict[int, bool | None],
) -> OffloadingConnectorScheduler:
    """Build a bare OffloadingConnectorScheduler backed by a mocked manager.

    Args:
        lookup_results: maps an integer block id to the value the mocked
            ``manager.lookup`` should return for that block. Block ids
            missing from the mapping are reported as ``False``.

    Returns:
        A scheduler instance whose only assigned attribute is ``manager``.
    """

    def _fake_lookup(key, req_context):
        # The block hash carried by the key is the decimal string of the id.
        block_id = int(get_offload_block_hash(key).decode())
        return lookup_results.get(block_id, False)

    mock_manager = MagicMock(spec=OffloadingManager)
    mock_manager.lookup.side_effect = _fake_lookup

    # Allocate the instance without running __init__; only `manager` is set.
    sched = object.__new__(OffloadingConnectorScheduler)
    sched.manager = mock_manager
    return sched
# Default-constructed request context passed to every lookup in the tests below.
_EMPTY_REQ_CTX = ReqContext()
class TestMaximalPrefixLookup:
    """Unit tests for OffloadingConnectorScheduler._maximal_prefix_lookup."""

    def test_all_hit(self):
        scheduler = _make_scheduler_with_lookup({1: True, 2: True})
        hits = scheduler._maximal_prefix_lookup(to_keys([1, 2]), _EMPTY_REQ_CTX)
        assert hits == 2

    def test_all_miss(self):
        scheduler = _make_scheduler_with_lookup({})
        hits = scheduler._maximal_prefix_lookup(to_keys([1, 2]), _EMPTY_REQ_CTX)
        assert hits == 0

    def test_partial_prefix(self):
        # block 3 is unknown to the mock, so the prefix ends after block 2
        scheduler = _make_scheduler_with_lookup({1: True, 2: True})
        hits = scheduler._maximal_prefix_lookup(to_keys([1, 2, 3]), _EMPTY_REQ_CTX)
        assert hits == 2

    def test_miss_then_hit(self):
        # a miss on the first block ends the prefix immediately
        scheduler = _make_scheduler_with_lookup({2: True})
        hits = scheduler._maximal_prefix_lookup(to_keys([1, 2]), _EMPTY_REQ_CTX)
        assert hits == 0

    def test_single_hit(self):
        scheduler = _make_scheduler_with_lookup({1: True})
        hits = scheduler._maximal_prefix_lookup(to_keys([1]), _EMPTY_REQ_CTX)
        assert hits == 1

    def test_empty(self):
        scheduler = _make_scheduler_with_lookup({})
        hits = scheduler._maximal_prefix_lookup([], _EMPTY_REQ_CTX)
        assert hits == 0

    def test_none_defers(self):
        scheduler = _make_scheduler_with_lookup({1: None, 2: True})
        hits = scheduler._maximal_prefix_lookup(to_keys([1, 2]), _EMPTY_REQ_CTX)
        assert hits is None

    def test_none_after_hit_defers(self):
        scheduler = _make_scheduler_with_lookup({1: True, 2: None})
        hits = scheduler._maximal_prefix_lookup(to_keys([1, 2]), _EMPTY_REQ_CTX)
        assert hits is None

    def test_none_stops_at_miss(self):
        """None is treated as hit for iteration, but miss stops the scan."""
        scheduler = _make_scheduler_with_lookup({1: None, 2: False, 3: True})
        hits = scheduler._maximal_prefix_lookup(to_keys([1, 2, 3]), _EMPTY_REQ_CTX)
        assert hits is None
        # only blocks 1 and 2 are looked up; the miss on 2 ends the scan
        assert scheduler.manager.lookup.call_count == 2
class TestSlidingWindowLookup:
    """Unit tests for OffloadingConnectorScheduler._sliding_window_lookup."""

    def test_all_hit_exact_window(self):
        scheduler = _make_scheduler_with_lookup({1: True, 2: True})
        end = scheduler._sliding_window_lookup(to_keys([1, 2]), 2, _EMPTY_REQ_CTX)
        assert end == 2

    def test_all_miss(self):
        scheduler = _make_scheduler_with_lookup({})
        end = scheduler._sliding_window_lookup(to_keys([1, 2, 3]), 1, _EMPTY_REQ_CTX)
        assert end == 0

    def test_window_at_end(self):
        scheduler = _make_scheduler_with_lookup({2: True, 3: True})
        end = scheduler._sliding_window_lookup(to_keys([1, 2, 3]), 2, _EMPTY_REQ_CTX)
        assert end == 3

    def test_window_in_middle(self):
        scheduler = _make_scheduler_with_lookup({2: True, 3: True})
        keys = to_keys([1, 2, 3, 4])
        end = scheduler._sliding_window_lookup(keys, 2, _EMPTY_REQ_CTX)
        assert end == 3

    def test_no_full_window_falls_back_to_prefix(self):
        scheduler = _make_scheduler_with_lookup({1: True, 2: True})
        end = scheduler._sliding_window_lookup(to_keys([1, 2, 3]), 3, _EMPTY_REQ_CTX)
        assert end == 2

    def test_single_block_window(self):
        scheduler = _make_scheduler_with_lookup({2: True, 3: True})
        end = scheduler._sliding_window_lookup(to_keys([1, 2, 3]), 1, _EMPTY_REQ_CTX)
        assert end == 3

    def test_gap_resets_consecutive(self):
        scheduler = _make_scheduler_with_lookup({2: True, 3: True, 4: True})
        # [1, 2, 3, 0, 4]: the miss on 0 resets the run; a window of 2 is
        # then found at [2, 3]
        keys = to_keys([1, 2, 3, 0, 4])
        end = scheduler._sliding_window_lookup(keys, 2, _EMPTY_REQ_CTX)
        assert end == 3

    def test_window_prefers_rightmost(self):
        scheduler = _make_scheduler_with_lookup(
            {1: True, 2: True, 4: True, 5: True}
        )
        # two valid windows exist: [1, 2] at positions 0-1 and [4, 5] at
        # positions 3-4; the right-to-left scan reaches [4, 5] first
        keys = to_keys([1, 2, 3, 4, 5])
        end = scheduler._sliding_window_lookup(keys, 2, _EMPTY_REQ_CTX)
        assert end == 5

    def test_prefix_fallback_with_gap(self):
        scheduler = _make_scheduler_with_lookup(
            {2: True, 3: True, 4: True, 5: True}
        )
        # no contiguous window of 4 (block 1 is a gap)
        keys = to_keys([2, 1, 3, 4, 5])
        end = scheduler._sliding_window_lookup(keys, 4, _EMPTY_REQ_CTX)
        assert end == 1

    def test_empty(self):
        scheduler = _make_scheduler_with_lookup({})
        end = scheduler._sliding_window_lookup([], 1, _EMPTY_REQ_CTX)
        assert end == 0

    def test_none_defers(self):
        scheduler = _make_scheduler_with_lookup({1: True, 2: None})
        end = scheduler._sliding_window_lookup(to_keys([1, 2]), 2, _EMPTY_REQ_CTX)
        assert end is None

    def test_none_with_full_window_still_defers(self):
        """Even if a real window is found after a None, result is deferred."""
        # Right-to-left scan: 4 (True), 3 (None) resets the run, then
        # 2 (True) and 1 (True) form a window, but the None seen on block 3
        # already marked the lookup as deferred.
        scheduler = _make_scheduler_with_lookup(
            {1: True, 2: True, 3: None, 4: True}
        )
        keys = to_keys([1, 2, 3, 4])
        end = scheduler._sliding_window_lookup(keys, 2, _EMPTY_REQ_CTX)
        assert end is None
...@@ -56,8 +56,12 @@ from vllm.v1.request import Request ...@@ -56,8 +56,12 @@ from vllm.v1.request import Request
from vllm.v1.structured_output import StructuredOutputManager from vllm.v1.structured_output import StructuredOutputManager
def to_keys(int_ids: list[int]) -> list[OffloadKey]: def to_key(int_hash: int) -> OffloadKey:
return [make_offload_key(str(i).encode(), 0) for i in int_ids] return make_offload_key(str(int_hash).encode(), 0)
def to_keys(int_hashes: list[int]) -> list[OffloadKey]:
return [to_key(i) for i in int_hashes]
class MockLoadStoreSpec(LoadStoreSpec): class MockLoadStoreSpec(LoadStoreSpec):
...@@ -116,6 +120,7 @@ class MockOffloadingSpec(OffloadingSpec): ...@@ -116,6 +120,7 @@ class MockOffloadingSpec(OffloadingSpec):
self.manager = MagicMock(spec=OffloadingManager) self.manager = MagicMock(spec=OffloadingManager)
self.manager.lookup.return_value = 0 self.manager.lookup.return_value = 0
self.manager.prepare_load = lambda keys, req_context: MockLoadStoreSpec(keys) self.manager.prepare_load = lambda keys, req_context: MockLoadStoreSpec(keys)
self.manager.lookup.return_value = False
self.handler = MockOffloadingHandler() self.handler = MockOffloadingHandler()
def get_manager(self) -> OffloadingManager: def get_manager(self) -> OffloadingManager:
...@@ -228,14 +233,14 @@ class RequestRunner: ...@@ -228,14 +233,14 @@ class RequestRunner:
self.scheduler_connector: OffloadingConnector = scheduler_connector self.scheduler_connector: OffloadingConnector = scheduler_connector
# extract mocked OffloadingManager of scheduler connector # extract mocked OffloadingManager of scheduler connector
connector_scheduler = scheduler_connector.connector_scheduler self.connector_scheduler = scheduler_connector.connector_scheduler
assert connector_scheduler is not None assert self.connector_scheduler is not None
manager = connector_scheduler.manager manager = self.connector_scheduler.manager
assert isinstance(manager, MagicMock) assert isinstance(manager, MagicMock)
self.manager: MagicMock = manager self.manager: MagicMock = manager
assert len(connector_scheduler.config.kv_group_configs) == 1 assert len(self.connector_scheduler.config.kv_group_configs) == 1
kv_group_config = connector_scheduler.config.kv_group_configs[0] kv_group_config = self.connector_scheduler.config.kv_group_configs[0]
assert kv_group_config.gpu_block_size == gpu_block_size assert kv_group_config.gpu_block_size == gpu_block_size
assert kv_group_config.offloaded_block_size == offloaded_block_size assert kv_group_config.offloaded_block_size == offloaded_block_size
......
...@@ -35,8 +35,12 @@ class ExpectedPrepareStoreOutput: ...@@ -35,8 +35,12 @@ class ExpectedPrepareStoreOutput:
evicted_keys: list[int] evicted_keys: list[int]
def to_keys(int_ids: list[int]) -> list[OffloadKey]: def to_key(int_hash: int) -> OffloadKey:
return [make_offload_key(str(i).encode(), 0) for i in int_ids] return make_offload_key(str(int_hash).encode(), 0)
def to_keys(int_hashes: list[int]) -> list[OffloadKey]:
return [to_key(i) for i in int_hashes]
def verify_store_output( def verify_store_output(
...@@ -136,7 +140,7 @@ def test_already_stored_block_not_evicted_during_prepare_store(eviction_policy): ...@@ -136,7 +140,7 @@ def test_already_stored_block_not_evicted_during_prepare_store(eviction_policy):
manager.complete_store(to_keys([2, 3, 4, 5])) manager.complete_store(to_keys([2, 3, 4, 5]))
# block 2 must still be present in the cache # block 2 must still be present in the cache
assert manager.lookup(to_keys([2]), _EMPTY_REQ_CTX) == 1 assert manager.lookup(to_key(2), _EMPTY_REQ_CTX) is True
def test_cpu_manager(): def test_cpu_manager():
...@@ -160,7 +164,8 @@ def test_cpu_manager(): ...@@ -160,7 +164,8 @@ def test_cpu_manager():
) )
# lookup [1, 2] -> not ready # lookup [1, 2] -> not ready
assert cpu_manager.lookup(to_keys([1, 2]), _EMPTY_REQ_CTX) == 0 assert cpu_manager.lookup(to_key(1), _EMPTY_REQ_CTX) is False
assert cpu_manager.lookup(to_key(2), _EMPTY_REQ_CTX) is False
# no events so far # no events so far
assert list(cpu_manager.take_events()) == [] assert list(cpu_manager.take_events()) == []
...@@ -170,9 +175,9 @@ def test_cpu_manager(): ...@@ -170,9 +175,9 @@ def test_cpu_manager():
verify_events(cpu_manager.take_events(), expected_stores=({1, 2},)) verify_events(cpu_manager.take_events(), expected_stores=({1, 2},))
# lookup [1, 2] # lookup [1, 2]
assert cpu_manager.lookup(to_keys([1]), _EMPTY_REQ_CTX) == 1 assert cpu_manager.lookup(to_key(1), _EMPTY_REQ_CTX) is True
assert cpu_manager.lookup(to_keys([1, 2]), _EMPTY_REQ_CTX) == 2 assert cpu_manager.lookup(to_key(2), _EMPTY_REQ_CTX) is True
assert cpu_manager.lookup(to_keys([1, 2, 3]), _EMPTY_REQ_CTX) == 2 assert cpu_manager.lookup(to_key(3), _EMPTY_REQ_CTX) is False
# prepare store [2, 3, 4, 5] -> evicts [1] # prepare store [2, 3, 4, 5] -> evicts [1]
prepare_store_output = cpu_manager.prepare_store( prepare_store_output = cpu_manager.prepare_store(
...@@ -196,6 +201,14 @@ def test_cpu_manager(): ...@@ -196,6 +201,14 @@ def test_cpu_manager():
# complete store [2, 3, 4, 5] # complete store [2, 3, 4, 5]
cpu_manager.complete_store(to_keys([2, 3, 4, 5])) cpu_manager.complete_store(to_keys([2, 3, 4, 5]))
# lookup (now that we have [2, 3, 4, 5])
assert cpu_manager.lookup(to_key(1), _EMPTY_REQ_CTX) is False
assert cpu_manager.lookup(to_key(2), _EMPTY_REQ_CTX) is True
assert cpu_manager.lookup(to_key(3), _EMPTY_REQ_CTX) is True
assert cpu_manager.lookup(to_key(4), _EMPTY_REQ_CTX) is True
assert cpu_manager.lookup(to_key(5), _EMPTY_REQ_CTX) is True
assert cpu_manager.lookup(to_key(0), _EMPTY_REQ_CTX) is False
# prepare load [2, 3] # prepare load [2, 3]
prepare_load_output = cpu_manager.prepare_load(to_keys([2, 3]), _EMPTY_REQ_CTX) prepare_load_output = cpu_manager.prepare_load(to_keys([2, 3]), _EMPTY_REQ_CTX)
verify_load_output(prepare_load_output, [1, 2]) verify_load_output(prepare_load_output, [1, 2])
...@@ -238,8 +251,8 @@ def test_cpu_manager(): ...@@ -238,8 +251,8 @@ def test_cpu_manager():
cpu_manager.complete_store(to_keys([7, 9]), success=False) cpu_manager.complete_store(to_keys([7, 9]), success=False)
# assert [7] is still stored, but [9] is not # assert [7] is still stored, but [9] is not
assert cpu_manager.lookup(to_keys([7]), _EMPTY_REQ_CTX) == 1 assert cpu_manager.lookup(to_key(7), _EMPTY_REQ_CTX) is True
assert cpu_manager.lookup(to_keys([9]), _EMPTY_REQ_CTX) == 0 assert cpu_manager.lookup(to_key(9), _EMPTY_REQ_CTX) is False
verify_events( verify_events(
cpu_manager.take_events(), cpu_manager.take_events(),
...@@ -284,7 +297,8 @@ class TestARCPolicy: ...@@ -284,7 +297,8 @@ class TestARCPolicy:
) )
# lookup [1, 2] -> not ready # lookup [1, 2] -> not ready
assert cpu_manager.lookup(to_keys([1, 2]), _EMPTY_REQ_CTX) == 0 assert cpu_manager.lookup(to_key(1), _EMPTY_REQ_CTX) is False
assert cpu_manager.lookup(to_key(2), _EMPTY_REQ_CTX) is False
# no events so far # no events so far
assert list(cpu_manager.take_events()) == [] assert list(cpu_manager.take_events()) == []
...@@ -294,9 +308,9 @@ class TestARCPolicy: ...@@ -294,9 +308,9 @@ class TestARCPolicy:
verify_events(cpu_manager.take_events(), expected_stores=({1, 2},)) verify_events(cpu_manager.take_events(), expected_stores=({1, 2},))
# lookup [1, 2] # lookup [1, 2]
assert cpu_manager.lookup(to_keys([1]), _EMPTY_REQ_CTX) == 1 assert cpu_manager.lookup(to_key(1), _EMPTY_REQ_CTX) is True
assert cpu_manager.lookup(to_keys([1, 2]), _EMPTY_REQ_CTX) == 2 assert cpu_manager.lookup(to_key(2), _EMPTY_REQ_CTX) is True
assert cpu_manager.lookup(to_keys([1, 2, 3]), _EMPTY_REQ_CTX) == 2 assert cpu_manager.lookup(to_key(3), _EMPTY_REQ_CTX) is False
# blocks should be in T1 (recent) # blocks should be in T1 (recent)
assert len(arc_policy.t1) == 2 assert len(arc_policy.t1) == 2
...@@ -500,7 +514,7 @@ class TestARCPolicy: ...@@ -500,7 +514,7 @@ class TestARCPolicy:
cpu_manager.complete_store(to_keys([5]), success=False) cpu_manager.complete_store(to_keys([5]), success=False)
# block 5 should not be in cache # block 5 should not be in cache
assert cpu_manager.lookup(to_keys([5]), _EMPTY_REQ_CTX) == 0 assert cpu_manager.lookup(to_key(5), _EMPTY_REQ_CTX) is False
# block 5 should not be in T1 or T2 # block 5 should not be in T1 or T2
assert to_keys([5])[0] not in arc_policy.t1 assert to_keys([5])[0] not in arc_policy.t1
assert to_keys([5])[0] not in arc_policy.t2 assert to_keys([5])[0] not in arc_policy.t2
...@@ -541,8 +555,8 @@ class TestARCPolicy: ...@@ -541,8 +555,8 @@ class TestARCPolicy:
cpu_manager.complete_store(to_keys([6])) cpu_manager.complete_store(to_keys([6]))
# verify blocks 2, 3 (in T2) are still present # verify blocks 2, 3 (in T2) are still present
assert cpu_manager.lookup(to_keys([2]), _EMPTY_REQ_CTX) == 1 assert cpu_manager.lookup(to_key(2), _EMPTY_REQ_CTX) is True
assert cpu_manager.lookup(to_keys([3]), _EMPTY_REQ_CTX) == 1 assert cpu_manager.lookup(to_key(3), _EMPTY_REQ_CTX) is True
# verify events # verify events
events = list(cpu_manager.take_events()) events = list(cpu_manager.take_events())
...@@ -562,7 +576,8 @@ def test_filter_reused_manager(): ...@@ -562,7 +576,8 @@ def test_filter_reused_manager():
) )
# Lookup [1, 2] -> 1st time, added to tracker but not eligible for store yet # Lookup [1, 2] -> 1st time, added to tracker but not eligible for store yet
assert manager.lookup(to_keys([1, 2]), _EMPTY_REQ_CTX) == 0 assert manager.lookup(to_key(1), _EMPTY_REQ_CTX) is False
assert manager.lookup(to_key(2), _EMPTY_REQ_CTX) is False
# prepare store [1, 2] -> should be filtered # prepare store [1, 2] -> should be filtered
prepare_store_output = manager.prepare_store(to_keys([1, 2]), _EMPTY_REQ_CTX) prepare_store_output = manager.prepare_store(to_keys([1, 2]), _EMPTY_REQ_CTX)
...@@ -570,7 +585,7 @@ def test_filter_reused_manager(): ...@@ -570,7 +585,7 @@ def test_filter_reused_manager():
assert prepare_store_output.keys_to_store == [] assert prepare_store_output.keys_to_store == []
# Lookup [1] -> 2nd time, eligible now # Lookup [1] -> 2nd time, eligible now
assert manager.lookup(to_keys([1]), _EMPTY_REQ_CTX) == 0 assert manager.lookup(to_key(1), _EMPTY_REQ_CTX) is False
# prepare store [1, 2] -> [1] should be eligible, [2] should be filtered # prepare store [1, 2] -> [1] should be eligible, [2] should be filtered
prepare_store_output = manager.prepare_store(to_keys([1, 2]), _EMPTY_REQ_CTX) prepare_store_output = manager.prepare_store(to_keys([1, 2]), _EMPTY_REQ_CTX)
...@@ -579,12 +594,13 @@ def test_filter_reused_manager(): ...@@ -579,12 +594,13 @@ def test_filter_reused_manager():
# Lookup [3, 4] -> 1st time # Lookup [3, 4] -> 1st time
# (evicts [2] from tracker since max_size is 3 and tracker has [1]) # (evicts [2] from tracker since max_size is 3 and tracker has [1])
assert manager.lookup(to_keys([3, 4]), _EMPTY_REQ_CTX) == 0 assert manager.lookup(to_key(3), _EMPTY_REQ_CTX) is False
assert manager.lookup(to_key(4), _EMPTY_REQ_CTX) is False
# Verify [2] was evicted from the tracker (tracker now has: [1], [3], [4]) # Verify [2] was evicted from the tracker (tracker now has: [1], [3], [4])
assert to_keys([2])[0] not in manager.counts assert to_keys([2])[0] not in manager.counts
# Lookup [2] again -> (this adds [2] back to the tracker as 1st time) # Lookup [2] again -> (this adds [2] back to the tracker as 1st time)
assert manager.lookup(to_keys([2]), _EMPTY_REQ_CTX) == 0 assert manager.lookup(to_key(2), _EMPTY_REQ_CTX) is False
# Verify [2] was re-added with count=1 (not eligible yet) # Verify [2] was re-added with count=1 (not eligible yet)
assert manager.counts.get(to_keys([2])[0]) == 1 assert manager.counts.get(to_keys([2])[0]) == 1
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import defaultdict from collections import defaultdict
from collections.abc import Iterable from collections.abc import Iterable, Sequence
from dataclasses import dataclass, field from dataclasses import dataclass, field
from itertools import islice from itertools import islice
from typing import Any, NamedTuple from typing import Any, NamedTuple
...@@ -132,6 +132,49 @@ class OffloadingConnectorScheduler: ...@@ -132,6 +132,49 @@ class OffloadingConnectorScheduler:
self._reqs_being_stored = defaultdict[ReqId, set[OffloadKey]](set) self._reqs_being_stored = defaultdict[ReqId, set[OffloadKey]](set)
self._reqs_being_loaded = defaultdict[ReqId, set[OffloadKey]](set) self._reqs_being_loaded = defaultdict[ReqId, set[OffloadKey]](set)
def _maximal_prefix_lookup(
    self, keys: Iterable[OffloadKey], req_context: ReqContext
) -> int | None:
    """Find the length of the maximal prefix of offloaded blocks.

    Keys are looked up one at a time, in order, through
    ``self.manager.lookup``. The scan stops at the first key reported
    as a miss.

    Args:
        keys: the keys identifying the blocks, in prefix order.
        req_context: per-request context forwarded to the manager.

    Returns:
        The number of leading keys that were hits, or None if any lookup
        performed before the scan stopped returned None (meaning the
        lookup should be retried later).
    """
    hit_count = 0
    defer_lookup = False
    for key in keys:
        result = self.manager.lookup(key, req_context)
        if result is None:
            # Undecided block: remember to defer, but count it as a hit for
            # iteration purposes so the manager can kick off async lookups
            # for the following blocks too (until a miss is detected).
            defer_lookup = True
        elif not result:
            break
        hit_count += 1
    return None if defer_lookup else hit_count
def _sliding_window_lookup(
    self,
    keys: Sequence[OffloadKey],
    sliding_window_size: int,
    req_context: ReqContext,
) -> int | None:
    """Find the maximal ending position of consecutive offloaded blocks
    within a sliding window.

    Keys are looked up from the last one backwards through
    ``self.manager.lookup``, tracking the current run of consecutive hits.

    Args:
        keys: the keys identifying the blocks, in order.
        sliding_window_size: number of consecutive hit blocks that make up
            a full window.
            NOTE(review): assumed to be >= 1 - confirm with callers.
        req_context: per-request context forwarded to the manager.

    Returns:
        None if any performed lookup returned None (retry later).
        Otherwise, if a run of ``sliding_window_size`` consecutive hits is
        found, the index just past the last block of the right-most such
        run. Otherwise, the number of consecutive hits at the start of
        ``keys``.
    """
    defer_lookup = False
    consecutive_hits = 0
    for idx in reversed(range(len(keys))):
        result = self.manager.lookup(keys[idx], req_context)
        if result is None:
            # Undecided block: remember to defer and count it as a miss,
            # but keep scanning so the manager can kick off async lookups
            # for the remaining blocks (until a full window is found).
            defer_lookup = True
            result = False
        if not result:
            consecutive_hits = 0
            continue
        consecutive_hits += 1
        if consecutive_hits == sliding_window_size:
            # The run starts at idx, so it ends right before idx + size.
            return None if defer_lookup else idx + sliding_window_size
    # No full window: what is left is the run of hits starting at index 0.
    return None if defer_lookup else consecutive_hits
def get_num_new_matched_tokens( def get_num_new_matched_tokens(
self, request: Request, num_computed_tokens: int self, request: Request, num_computed_tokens: int
) -> tuple[int | None, bool]: ) -> tuple[int | None, bool]:
...@@ -184,9 +227,10 @@ class OffloadingConnectorScheduler: ...@@ -184,9 +227,10 @@ class OffloadingConnectorScheduler:
return 0, False return 0, False
start_block_idx = num_computed_tokens // group_config.offloaded_block_size start_block_idx = num_computed_tokens // group_config.offloaded_block_size
hits = self.manager.lookup( # Full attention relies on all previous KV cache blocks.
offload_keys[start_block_idx:], # Thus, we search for a maximal prefix of KV cache blocks which are all cached.
req_status.req_context, hits = self._maximal_prefix_lookup(
offload_keys[start_block_idx:], req_status.req_context
) )
if hits is None: if hits is None:
# indicates a lookup that should be tried later # indicates a lookup that should be tried later
......
...@@ -7,8 +7,7 @@ This class runs in the scheduler, tracks which blocks are offloaded ...@@ -7,8 +7,7 @@ This class runs in the scheduler, tracks which blocks are offloaded
and their address. and their address.
The class provides the following primitives: The class provides the following primitives:
lookup() - find the length of the maximal series of blocks, lookup() - check whether a single block is offloaded and ready.
starting from the first one, that are all offloaded.
prepare_load() - prepare given blocks to be read. prepare_load() - prepare given blocks to be read.
The given blocks will be protected from eviction. The given blocks will be protected from eviction.
This function returns a LoadSpec which encapsulates This function returns a LoadSpec which encapsulates
...@@ -91,23 +90,18 @@ class OffloadingEvent: ...@@ -91,23 +90,18 @@ class OffloadingEvent:
class OffloadingManager(ABC): class OffloadingManager(ABC):
@abstractmethod @abstractmethod
def lookup( def lookup(self, key: OffloadKey, req_context: ReqContext) -> bool | None:
self,
keys: Iterable[OffloadKey],
req_context: ReqContext,
) -> int | None:
""" """
Finds the length of the maximal series of blocks, starting from the Checks whether a single block is offloaded and ready to be read.
first one, that are all offloaded.
Args: Args:
keys: the keys identifying the blocks to lookup. key: the key identifying the block to lookup.
req_context: per-request context (e.g. kv_transfer_params). req_context: per-request context (e.g. kv_transfer_params).
Returns: Returns:
An integer representing the maximal number of blocks that True if the block is offloaded and ready, False if not,
are currently offloaded, or None if the lookup should be retried or None if the lookup should be retried later.
later. Returning None will delay the request handling by the vLLM Returning None will delay the request handling by the vLLM
scheduler. scheduler.
""" """
pass pass
......
...@@ -84,18 +84,9 @@ class CPUOffloadingManager(OffloadingManager): ...@@ -84,18 +84,9 @@ class CPUOffloadingManager(OffloadingManager):
# --- OffloadingManager interface --- # --- OffloadingManager interface ---
def lookup( def lookup(self, key: OffloadKey, req_context: ReqContext) -> bool | None:
self, block = self._policy.get(key)
keys: Iterable[OffloadKey], return block is not None and block.is_ready
req_context: ReqContext,
) -> int | None:
hit_count = 0
for key in keys:
block = self._policy.get(key)
if block is None or not block.is_ready:
break
hit_count += 1
return hit_count
def prepare_load( def prepare_load(
self, self,
......
...@@ -27,8 +27,9 @@ class FilterReusedOffloadingManager(OffloadingManager): ...@@ -27,8 +27,9 @@ class FilterReusedOffloadingManager(OffloadingManager):
All methods are delegated to the *backing* manager. Two methods are All methods are delegated to the *backing* manager. Two methods are
intercepted: intercepted:
* ``lookup`` — records each visited key in an internal LRU counter.
* ``prepare_store`` — filters out keys that have not yet * ``prepare_store`` — filters out keys that have not yet
* ``lookup`` — records the visited key in an internal LRU
counter, then delegates to the backing manager.
crossed the threshold *before* calling the backing crossed the threshold *before* calling the backing
``prepare_store``. ``prepare_store``.
...@@ -66,18 +67,16 @@ class FilterReusedOffloadingManager(OffloadingManager): ...@@ -66,18 +67,16 @@ class FilterReusedOffloadingManager(OffloadingManager):
# Intercepted methods # Intercepted methods
# ------------------------------------------------------------------ # ------------------------------------------------------------------
def lookup(self, keys: Iterable[OffloadKey], req_context: ReqContext) -> int | None: def lookup(self, key: OffloadKey, req_context: ReqContext) -> bool | None:
"""Record each key, then delegate lookup to backing manager.""" """Record the key, then delegate lookup to backing manager."""
keys = list(keys) if key in self.counts:
for key in keys: self.counts.move_to_end(key)
if key in self.counts: self.counts[key] += 1
self.counts.move_to_end(key) else:
self.counts[key] += 1 if len(self.counts) >= self.max_tracker_size:
else: self.counts.popitem(last=False) # evict LRU
if len(self.counts) >= self.max_tracker_size: self.counts[key] = 1
self.counts.popitem(last=False) # evict LRU return self._backing.lookup(key, req_context)
self.counts[key] = 1
return self._backing.lookup(keys, req_context)
def prepare_store( def prepare_store(
self, keys: Iterable[OffloadKey], req_context: ReqContext self, keys: Iterable[OffloadKey], req_context: ReqContext
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment