Unverified commit 8e326908, authored by Or Ozeri, committed by GitHub
Browse files

[KV Connector][BugFix] scheduler: Delay freeing blocks of aborted async loads (#32255)



Fixes a not-yet-reported case where it was possible for blocks to be
freed by an abort before an async transfer completed, resulting
in corrupted KV data.
Signed-off-by: Or Ozeri <oro@il.ibm.com>
parent a2084395
......@@ -3473,3 +3473,52 @@ def test_prepend_skipped_requests_order():
# verify waiting order is preserved
assert list(scheduler.waiting) == expected_waiting_reqs
def test_abort_request_waiting_for_remote_kvs():
    """Abort a request whose remote KV receive has not finished yet.

    The scheduler is expected to keep tracking the aborted request until
    the connector reports the receive as finished, and only then drop it.
    """
    scheduler = create_scheduler(use_kv_connector=True)

    # Register exactly one request with the scheduler.
    pending = create_requests(num_requests=1)[0]
    scheduler.add_request(pending)
    req_id = pending.request_id

    # Mark the request as still receiving remote KVs, then abort it.
    pending.status = RequestStatus.WAITING_FOR_REMOTE_KVS
    scheduler.finish_requests((req_id,), RequestStatus.FINISHED_ABORTED)
    assert pending.status == RequestStatus.FINISHED_ABORTED

    # The aborted request must still be known to the scheduler.
    assert req_id in scheduler.requests

    # Run one step in which the connector reports the receive as finished.
    step_output = scheduler.schedule()
    runner_output = ModelRunnerOutput(
        req_ids=[],
        req_id_to_index={},
        kv_connector_output=KVConnectorOutput(finished_recving={req_id}),
    )
    scheduler.update_from_output(step_output, runner_output)

    # Now the request is gone and no stale "finished recving" entry remains.
    assert req_id not in scheduler.requests
    assert not scheduler.finished_recving_kv_req_ids
def test_abort_request_finished_recving():
    """Abort a request whose remote KV receive already finished.

    The request id is placed in ``finished_recving_kv_req_ids`` before the
    abort, so the scheduler is expected to drop the request immediately
    and clear that bookkeeping entry.
    """
    scheduler = create_scheduler(use_kv_connector=True)

    # Register exactly one request with the scheduler.
    received = create_requests(num_requests=1)[0]
    scheduler.add_request(received)
    req_id = received.request_id

    # The receive has completed, but the request status was not updated yet.
    received.status = RequestStatus.WAITING_FOR_REMOTE_KVS
    scheduler.finished_recving_kv_req_ids.add(req_id)

    # Abort the request.
    scheduler.finish_requests((req_id,), RequestStatus.FINISHED_ABORTED)
    assert received.status == RequestStatus.FINISHED_ABORTED

    # The request is removed right away and no finished-recving id is left.
    assert req_id not in scheduler.requests
    assert not scheduler.finished_recving_kv_req_ids
......@@ -42,7 +42,7 @@ from vllm.v1.kv_offload.worker.worker import (
TransferSpec,
)
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput
from vllm.v1.request import Request
from vllm.v1.request import Request, RequestStatus
from .utils import (
EOS_TOKEN_ID,
......@@ -355,7 +355,7 @@ class RequestRunner:
self.scheduler.update_from_output(scheduler_output, model_runner_output)
if (
prev_token_id is EOS_TOKEN_ID
prev_token_id == EOS_TOKEN_ID
and prev_token_id != token_id
and self.scheduler.requests
):
......@@ -730,6 +730,57 @@ def test_concurrent_lookups_of_the_same_prefix(request_runner):
assert transfer_jobs == list(runner.offloading_spec.handler.transfer_specs)
def test_abort_loading_requests(request_runner):
    """Abort a request while its offloaded-block load is still in flight.

    The aborted request must stay in ``scheduler.requests`` until the load
    completes, and be removed once it does.
    """
    offloaded_block_size = 12
    gpu_block_size = 4
    num_gpu_blocks = 100

    runner = request_runner(
        offloaded_block_size=offloaded_block_size,
        gpu_block_size=gpu_block_size,
        num_gpu_blocks=num_gpu_blocks,
    )
    scheduler = runner.scheduler

    # Store a single offloaded block (three GPU blocks' worth of tokens).
    prompt_token_ids = [0] * offloaded_block_size
    runner.new_request(token_ids=prompt_token_ids)
    runner.manager.prepare_store.side_effect = (
        lambda block_hashes: generate_store_output(block_hashes)
    )
    runner.run(
        decoded_tokens=[EOS_TOKEN_ID],
        expected_stored_gpu_block_indexes=(0, 1, 2),
    )

    # Issue a second request that loads that block, leaving the transfer
    # incomplete.
    scheduler.reset_prefix_cache()
    runner.new_request(token_ids=[0] * offloaded_block_size)
    runner.manager.lookup.return_value = 1
    runner.run(
        decoded_tokens=[],
        complete_transfers=False,
    )

    # The second request must have triggered at least one load transfer.
    pending_transfers = list(runner.offloading_spec.handler.transfer_specs)
    assert pending_transfers

    # Abort the request while its load is still pending.
    req_id = str(runner.req_id)
    scheduler.finish_requests((req_id,), RequestStatus.FINISHED_ABORTED)

    # The aborted request must still be tracked by the scheduler.
    assert req_id in scheduler.requests

    # Let the pending load complete.
    runner.run(
        decoded_tokens=[],
        expected_loaded_gpu_block_indexes=(0, 1, 2),
    )

    # Only after the load completes is the request removed.
    assert req_id not in scheduler.requests
class TestOffloadingConnectorStats:
"""Tests for OffloadingConnector stats reconstruction and operations."""
......
......@@ -1670,19 +1670,30 @@ class Scheduler(SchedulerInterface):
# Second pass: set status and free requests
for request in valid_requests:
delay_free_blocks = False
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
delay_free_blocks = (
request.request_id not in self.finished_recving_kv_req_ids
)
self.finished_recving_kv_req_ids.discard(request.request_id)
self.failed_recving_kv_req_ids.discard(request.request_id)
request.status = finished_status
self._free_request(request)
self._free_request(request, delay_free_blocks=delay_free_blocks)
def _free_request(self, request: Request) -> dict[str, Any] | None:
def _free_request(
self, request: Request, delay_free_blocks: bool = False
) -> dict[str, Any] | None:
assert request.is_finished()
delay_free_blocks, kv_xfer_params = self._connector_finished(request)
connector_delay_free_blocks, kv_xfer_params = self._connector_finished(request)
self.encoder_cache_manager.free(request)
request_id = request.request_id
self.finished_req_ids.add(request_id)
if self.finished_req_ids_dict is not None:
self.finished_req_ids_dict[request.client_index].add(request_id)
delay_free_blocks |= connector_delay_free_blocks
if not delay_free_blocks:
self._free_blocks(request)
......@@ -1954,7 +1965,13 @@ class Scheduler(SchedulerInterface):
# KV Connector:: update recv and send status from last step.
for req_id in kv_connector_output.finished_recving or ():
logger.debug("Finished recving KV transfer for request %s", req_id)
self.finished_recving_kv_req_ids.add(req_id)
assert req_id in self.requests
req = self.requests[req_id]
if req.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
self.finished_recving_kv_req_ids.add(req_id)
else:
assert RequestStatus.is_finished(req.status)
self._free_blocks(self.requests[req_id])
for req_id in kv_connector_output.finished_sending or ():
logger.debug("Finished sending KV transfer for request %s", req_id)
assert req_id in self.requests
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment