Unverified commit 7013e9ac, authored by Or Ozeri, committed by GitHub
Browse files

OffloadingConnector: Prevent redundant loads (#29087)


Signed-off-by: Or Ozeri <oro@il.ibm.com>
parent c78ee240
...@@ -213,7 +213,6 @@ class RequestRunner: ...@@ -213,7 +213,6 @@ class RequestRunner:
) )
def new_request(self, token_ids: list[int]): def new_request(self, token_ids: list[int]):
assert not self.scheduler.requests
self.req_id += 1 self.req_id += 1
req = Request( req = Request(
...@@ -338,11 +337,20 @@ class RequestRunner: ...@@ -338,11 +337,20 @@ class RequestRunner:
token_id=token_id or 0, token_id=token_id or 0,
) )
prev_token_id = token_id
if self.scheduler.running: if self.scheduler.running:
token_id = next(tokens_iter, None) token_id = next(tokens_iter, None)
self.scheduler.update_from_output(scheduler_output, model_runner_output) self.scheduler.update_from_output(scheduler_output, model_runner_output)
if (
prev_token_id is EOS_TOKEN_ID
and prev_token_id != token_id
and self.scheduler.requests
):
# continue for one more step to allow offloading to kick off
continue
if token_id is None: if token_id is None:
break break
...@@ -651,3 +659,61 @@ def test_request_preemption(request_runner): ...@@ -651,3 +659,61 @@ def test_request_preemption(request_runner):
decoded_tokens=[EOS_TOKEN_ID], decoded_tokens=[EOS_TOKEN_ID],
expected_stored_gpu_block_indexes=(9, 10, 11), expected_stored_gpu_block_indexes=(9, 10, 11),
) )
def test_concurrent_lookups_of_the_same_prefix(request_runner):
    """
    Check that a second request hitting an offloaded block which is already
    being loaded for another request does not submit a transfer of its own.

    Flow:
      1. Run a request to completion so that one offloaded block is stored.
      2. Reset the GPU prefix cache and start an identical request, leaving
         its transfers incomplete.
      3. Start a third identical request and verify the handler's transfer
         specs are unchanged.
      4. Let the transfers complete and verify the transfer specs are still
         unchanged.
    """
    offloaded_block_size = 12
    runner = request_runner(
        offloaded_block_size=offloaded_block_size,
        gpu_block_size=4,
        num_gpu_blocks=100,
    )

    def submit_prompt():
        # every request uses the same prompt: offloaded_block_size zeros
        runner.new_request(token_ids=[0] * offloaded_block_size)

    def store_all(block_hashes):
        return generate_store_output(block_hashes)

    def store_nothing(block_hashes):
        return generate_store_output([])

    def submitted_transfers():
        # snapshot of the transfer specs currently held by the handler
        return list(runner.offloading_spec.handler.transfer_specs)

    # phase 1: store a single offloaded block
    submit_prompt()
    runner.manager.prepare_store.side_effect = store_all
    runner.run(
        expected_stored_gpu_block_indexes=(0, 1, 2),
        decoded_tokens=[EOS_TOKEN_ID],
    )

    # phase 2: begin loading that block, but leave the transfer incomplete
    runner.scheduler.reset_prefix_cache()
    submit_prompt()
    runner.manager.lookup.return_value = 1
    runner.run(complete_transfers=False, decoded_tokens=[])

    # the first lookup must have produced a load
    transfers_after_first_lookup = submitted_transfers()
    assert transfers_after_first_lookup

    # phase 3: another request asks for the very same block
    submit_prompt()
    runner.manager.lookup.return_value = 1
    runner.run(complete_transfers=False, decoded_tokens=[])

    # no additional load was submitted for it
    assert transfers_after_first_lookup == submitted_transfers()

    # phase 4: let the transfers complete, storing nothing new
    runner.manager.prepare_store.side_effect = store_nothing
    runner.run(
        expected_loaded_gpu_block_indexes=(0, 1, 2),
        decoded_tokens=[EOS_TOKEN_ID],
    )

    # the transfer specs are still unchanged after completion
    # (original author's note: the second request will use the GPU prefix cache)
    assert transfers_after_first_lookup == submitted_transfers()
...@@ -107,7 +107,7 @@ class OffloadingConnector(KVConnectorBase_V1): ...@@ -107,7 +107,7 @@ class OffloadingConnector(KVConnectorBase_V1):
def get_num_new_matched_tokens( def get_num_new_matched_tokens(
self, request: "Request", num_computed_tokens: int self, request: "Request", num_computed_tokens: int
) -> tuple[int, bool]: ) -> tuple[int | None, bool]:
assert self.connector_scheduler is not None assert self.connector_scheduler is not None
return self.connector_scheduler.get_num_new_matched_tokens( return self.connector_scheduler.get_num_new_matched_tokens(
request, num_computed_tokens request, num_computed_tokens
...@@ -161,6 +161,11 @@ class OffloadingConnectorScheduler: ...@@ -161,6 +161,11 @@ class OffloadingConnectorScheduler:
# request blocks are stored in order # request blocks are stored in order
# index of next block (of size offloaded_block_size) to offload # index of next block (of size offloaded_block_size) to offload
self._next_stored_block_idx: dict[ReqId, int] = {} self._next_stored_block_idx: dict[ReqId, int] = {}
# if GPU prefix caching is enabled,
# track loaded blocks to avoid redundant loads
self._blocks_being_loaded: set[BlockHash] | None = (
set() if spec.vllm_config.cache_config.enable_prefix_caching else None
)
# request ID -> set(block hashes being stored/load) # request ID -> set(block hashes being stored/load)
self._reqs_being_stored = defaultdict[ReqId, set[BlockHash]](set) self._reqs_being_stored = defaultdict[ReqId, set[BlockHash]](set)
...@@ -181,7 +186,7 @@ class OffloadingConnectorScheduler: ...@@ -181,7 +186,7 @@ class OffloadingConnectorScheduler:
def get_num_new_matched_tokens( def get_num_new_matched_tokens(
self, request: Request, num_computed_tokens: int self, request: Request, num_computed_tokens: int
) -> tuple[int, bool]: ) -> tuple[int | None, bool]:
""" """
Get number of new tokens that can be loaded beyond the Get number of new tokens that can be loaded beyond the
num_computed_tokens. num_computed_tokens.
...@@ -195,6 +200,9 @@ class OffloadingConnectorScheduler: ...@@ -195,6 +200,9 @@ class OffloadingConnectorScheduler:
A tuple with the following elements: A tuple with the following elements:
- The number of tokens that can be loaded beyond what is - The number of tokens that can be loaded beyond what is
already computed. already computed.
If None, it means that the connector needs more time to
determine the number of matched tokens, and the scheduler
should query for this request again later.
- `True` if tokens will be loaded asynchronously - `True` if tokens will be loaded asynchronously
(between scheduler steps). (between scheduler steps).
""" """
...@@ -214,6 +222,9 @@ class OffloadingConnectorScheduler: ...@@ -214,6 +222,9 @@ class OffloadingConnectorScheduler:
hits = self.manager.lookup( hits = self.manager.lookup(
self._get_block_hashes(request, start_idx=start_block_idx) self._get_block_hashes(request, start_idx=start_block_idx)
) )
if hits is None:
# indicates a lookup that should be tried later
return None, False
if hits == 0: if hits == 0:
return 0, False return 0, False
...@@ -229,6 +240,22 @@ class OffloadingConnectorScheduler: ...@@ -229,6 +240,22 @@ class OffloadingConnectorScheduler:
if num_hit_tokens < self.offloaded_block_size: if num_hit_tokens < self.offloaded_block_size:
return 0, False return 0, False
if self._blocks_being_loaded:
block_hashes = self._get_block_hashes(
request, start_idx=start_block_idx, end_idx=start_block_idx + hits
)
if any(
block_hash in self._blocks_being_loaded for block_hash in block_hashes
):
# hit blocks are being loaded, delay request
logger.debug(
"Delaying request %s since some of its blocks are already"
" being loaded",
request.request_id,
)
return None, False
return num_hit_tokens, True return num_hit_tokens, True
def update_state_after_alloc( def update_state_after_alloc(
...@@ -270,9 +297,13 @@ class OffloadingConnectorScheduler: ...@@ -270,9 +297,13 @@ class OffloadingConnectorScheduler:
) )
self._reqs_to_load[request.request_id] = (src_spec, dst_spec) self._reqs_to_load[request.request_id] = (src_spec, dst_spec)
self._reqs_being_loaded[request.request_id].update(block_hashes) req_blocks_being_loaded = self._reqs_being_loaded[request.request_id]
req_blocks_being_loaded.update(block_hashes)
self._next_stored_block_idx[request.request_id] = num_blocks self._next_stored_block_idx[request.request_id] = num_blocks
if self._blocks_being_loaded is not None:
self._blocks_being_loaded.update(req_blocks_being_loaded)
def _get_reqs_to_store(self, scheduler_output: SchedulerOutput): def _get_reqs_to_store(self, scheduler_output: SchedulerOutput):
reqs_to_store: dict[ReqId, TransferSpec] = {} reqs_to_store: dict[ReqId, TransferSpec] = {}
# iterate over both new and cached requests # iterate over both new and cached requests
...@@ -379,6 +410,8 @@ class OffloadingConnectorScheduler: ...@@ -379,6 +410,8 @@ class OffloadingConnectorScheduler:
for req_id in connector_output.finished_recving or []: for req_id in connector_output.finished_recving or []:
block_hashes = self._reqs_being_loaded.pop(req_id, None) block_hashes = self._reqs_being_loaded.pop(req_id, None)
if block_hashes: if block_hashes:
if self._blocks_being_loaded:
self._blocks_being_loaded.difference_update(block_hashes)
self.manager.complete_load(block_hashes) self.manager.complete_load(block_hashes)
def request_finished( def request_finished(
......
...@@ -68,7 +68,7 @@ class OffloadingEvent: ...@@ -68,7 +68,7 @@ class OffloadingEvent:
class OffloadingManager(ABC): class OffloadingManager(ABC):
@abstractmethod @abstractmethod
def lookup(self, block_hashes: Iterable[BlockHash]) -> int: def lookup(self, block_hashes: Iterable[BlockHash]) -> int | None:
""" """
Finds the length of the maximal series of blocks, starting from the Finds the length of the maximal series of blocks, starting from the
first one, that are all offloaded. first one, that are all offloaded.
...@@ -78,7 +78,9 @@ class OffloadingManager(ABC): ...@@ -78,7 +78,9 @@ class OffloadingManager(ABC):
Returns: Returns:
An integer representing the maximal number of blocks that An integer representing the maximal number of blocks that
are currently offloaded. are currently offloaded, or None if the lookup should be retried
later. Returning None will delay the request handling by the vLLM
scheduler.
""" """
pass pass
......
...@@ -63,7 +63,7 @@ class ARCOffloadingManager(OffloadingManager): ...@@ -63,7 +63,7 @@ class ARCOffloadingManager(OffloadingManager):
self.events: list[OffloadingEvent] | None = [] if enable_events else None self.events: list[OffloadingEvent] | None = [] if enable_events else None
self.cache_capacity: int = self.backend.get_num_free_blocks() self.cache_capacity: int = self.backend.get_num_free_blocks()
def lookup(self, block_hashes: Iterable[BlockHash]) -> int: def lookup(self, block_hashes: Iterable[BlockHash]) -> int | None:
hit_count = 0 hit_count = 0
for block_hash in block_hashes: for block_hash in block_hashes:
block = self.t1.get(block_hash) or self.t2.get(block_hash) block = self.t1.get(block_hash) or self.t2.get(block_hash)
......
...@@ -24,7 +24,7 @@ class LRUOffloadingManager(OffloadingManager): ...@@ -24,7 +24,7 @@ class LRUOffloadingManager(OffloadingManager):
self.blocks: OrderedDict[BlockHash, BlockStatus] = OrderedDict() self.blocks: OrderedDict[BlockHash, BlockStatus] = OrderedDict()
self.events: list[OffloadingEvent] | None = [] if enable_events else None self.events: list[OffloadingEvent] | None = [] if enable_events else None
def lookup(self, block_hashes: Iterable[BlockHash]) -> int: def lookup(self, block_hashes: Iterable[BlockHash]) -> int | None:
hit_count = 0 hit_count = 0
for block_hash in block_hashes: for block_hash in block_hashes:
block = self.blocks.get(block_hash) block = self.blocks.get(block_hash)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment