Unverified Commit 4c16ba61 authored by Or Ozeri's avatar Or Ozeri Committed by GitHub
Browse files

[KVConnector] OffloadingConnector: Fix bug in handling of preemptions (#29870)


Signed-off-by: default avatarOr Ozeri <oro@il.ibm.com>
parent bde57ab2
...@@ -64,8 +64,11 @@ class MockLoadStoreSpec(LoadStoreSpec): ...@@ -64,8 +64,11 @@ class MockLoadStoreSpec(LoadStoreSpec):
class MockOffloadingHandler(OffloadingHandler): class MockOffloadingHandler(OffloadingHandler):
def __init__(self): def __init__(self):
self.transfer_specs: dict[int, TransferSpec] = {}
self.completed_transfers: list[TransferResult] = [] self.completed_transfers: list[TransferResult] = []
self.completed_specs: list[TransferSpec] = [] self.waiting_jobs: set[int] = set()
self.completed_jobs: list[int] = []
self.flushed_jobs: set[int] = set()
def get_finished(self) -> list[TransferResult]: def get_finished(self) -> list[TransferResult]:
finished = self.completed_transfers finished = self.completed_transfers
...@@ -73,10 +76,21 @@ class MockOffloadingHandler(OffloadingHandler): ...@@ -73,10 +76,21 @@ class MockOffloadingHandler(OffloadingHandler):
return finished return finished
def transfer_async(self, job_id: int, spec: TransferSpec) -> bool: def transfer_async(self, job_id: int, spec: TransferSpec) -> bool:
self.completed_specs.append(spec) self.transfer_specs[job_id] = spec
self.completed_transfers.append((job_id, True)) self.waiting_jobs.add(job_id)
return True return True
def complete_jobs(self, job_ids: set[int]) -> None:
for job_id in job_ids:
if job_id in self.waiting_jobs:
self.waiting_jobs.remove(job_id)
self.completed_jobs.append(job_id)
self.completed_transfers.append((job_id, True))
def wait(self, job_ids: set[int]) -> None:
self.flushed_jobs |= job_ids
self.complete_jobs(job_ids)
class MockOffloadingSpec(OffloadingSpec): class MockOffloadingSpec(OffloadingSpec):
def __init__(self, vllm_config: VllmConfig): def __init__(self, vllm_config: VllmConfig):
...@@ -98,9 +112,22 @@ class MockOffloadingSpec(OffloadingSpec): ...@@ -98,9 +112,22 @@ class MockOffloadingSpec(OffloadingSpec):
yield GPULoadStoreSpec, MockLoadStoreSpec, self.handler yield GPULoadStoreSpec, MockLoadStoreSpec, self.handler
yield MockLoadStoreSpec, GPULoadStoreSpec, self.handler yield MockLoadStoreSpec, GPULoadStoreSpec, self.handler
def complete_transfers(self):
self.handler.complete_jobs(self.handler.waiting_jobs.copy())
def get_completed_transfers(self) -> list[TransferSpec]: def get_completed_transfers(self) -> list[TransferSpec]:
specs = self.handler.completed_specs specs = [
self.handler.completed_specs = [] self.handler.transfer_specs[job_id]
for job_id in self.handler.completed_jobs
]
self.handler.completed_jobs.clear()
return specs
def get_flushed_transfers(self):
specs = [
self.handler.transfer_specs[job_id] for job_id in self.handler.flushed_jobs
]
self.handler.flushed_jobs.clear()
return specs return specs
...@@ -170,12 +197,9 @@ class RequestRunner: ...@@ -170,12 +197,9 @@ class RequestRunner:
# mapping (offloading address) -> gpu_block_index # mapping (offloading address) -> gpu_block_index
self.offloaded: dict[Any, int] = {} self.offloaded: dict[Any, int] = {}
self.pending_loads_count: int = 0
self.pending_stores_count: int = 0
self.unsubmitted_stores_count = 0
self.completed_loads: list[TransferSummary] = [] self.completed_loads: list[TransferSummary] = []
self.completed_stores: list[TransferSummary] = [] self.completed_stores: list[TransferSummary] = []
self.flushed_gpu_block_indexes: set[int] = set()
# maps {block_id: block_offset} # maps {block_id: block_offset}
self.gpu_block_index: dict[int, int] = {} self.gpu_block_index: dict[int, int] = {}
...@@ -202,10 +226,18 @@ class RequestRunner: ...@@ -202,10 +226,18 @@ class RequestRunner:
self.scheduler.add_request(req) self.scheduler.add_request(req)
def _wait_for_transfers(self): def _parse_transfers(self):
for transfer_spec in self.offloading_spec.get_flushed_transfers():
src_spec, dst_spec = transfer_spec
assert isinstance(src_spec, GPULoadStoreSpec)
for block_id in src_spec.block_ids:
self.flushed_gpu_block_indexes.add(
self.gpu_block_index[block_id.item()]
)
block_size_factor = self.offloaded_block_size // self.gpu_block_size block_size_factor = self.offloaded_block_size // self.gpu_block_size
while self.pending_loads_count or self.pending_stores_count:
for transfer_spec in self.offloading_spec.get_completed_transfers(): for transfer_spec in self.offloading_spec.get_completed_transfers():
src_spec, dst_spec = transfer_spec src_spec, dst_spec = transfer_spec
...@@ -237,7 +269,6 @@ class RequestRunner: ...@@ -237,7 +269,6 @@ class RequestRunner:
self.completed_stores.append( self.completed_stores.append(
TransferSummary(gpu_block_indices, offload_addresses) TransferSummary(gpu_block_indices, offload_addresses)
) )
self.pending_stores_count -= 1
else: else:
remainder_sub_block_count = len(offload_addresses) - len( remainder_sub_block_count = len(offload_addresses) - len(
gpu_block_indices gpu_block_indices
...@@ -249,7 +280,6 @@ class RequestRunner: ...@@ -249,7 +280,6 @@ class RequestRunner:
self.completed_loads.append( self.completed_loads.append(
TransferSummary(gpu_block_indices, offload_addresses) TransferSummary(gpu_block_indices, offload_addresses)
) )
self.pending_loads_count -= 1
def _update_gpu_block_idx(self): def _update_gpu_block_idx(self):
for blocks in self.scheduler.kv_cache_manager.coordinator.single_type_managers[ for blocks in self.scheduler.kv_cache_manager.coordinator.single_type_managers[
...@@ -258,18 +288,19 @@ class RequestRunner: ...@@ -258,18 +288,19 @@ class RequestRunner:
for block_idx, block in enumerate(blocks): for block_idx, block in enumerate(blocks):
self.gpu_block_index[block.block_id] = block_idx self.gpu_block_index[block.block_id] = block_idx
def _run(self, decoded_tokens: list[int]): def _run(self, decoded_tokens: list[int], complete_transfers: bool):
""" """
Runs multiple engine (scheduler + worker) steps. Runs multiple engine (scheduler + worker) steps.
Assumes a single request is running. Assumes a single request is running.
Args: Args:
decoded_tokens: the tokens to yield at each step. decoded_tokens: the tokens to yield at each step.
complete_transfers: complete transfers immediately
""" """
tokens_iter = iter(decoded_tokens) tokens_iter = iter(decoded_tokens)
token_id = next(tokens_iter, None) token_id = next(tokens_iter, None)
while token_id is not None: while True:
assert self.scheduler.requests assert self.scheduler.requests
scheduler_output = self.scheduler.schedule() scheduler_output = self.scheduler.schedule()
...@@ -279,10 +310,10 @@ class RequestRunner: ...@@ -279,10 +310,10 @@ class RequestRunner:
assert kv_connector_metadata is not None assert kv_connector_metadata is not None
assert isinstance(kv_connector_metadata, OffloadingConnectorMetadata) assert isinstance(kv_connector_metadata, OffloadingConnectorMetadata)
self.pending_loads_count += len(kv_connector_metadata.reqs_to_load) if scheduler_output.preempted_req_ids:
self.worker_connector.handle_preemptions(
self.pending_stores_count += self.unsubmitted_stores_count scheduler_output.preempted_req_ids
self.unsubmitted_stores_count = len(kv_connector_metadata.reqs_to_store) )
self.worker_connector.bind_connector_metadata(kv_connector_metadata) self.worker_connector.bind_connector_metadata(kv_connector_metadata)
self.worker_connector.start_load_kv(self._dummy_ctx) self.worker_connector.start_load_kv(self._dummy_ctx)
...@@ -290,6 +321,9 @@ class RequestRunner: ...@@ -290,6 +321,9 @@ class RequestRunner:
if scheduler_output.total_num_scheduled_tokens > 0: if scheduler_output.total_num_scheduled_tokens > 0:
self.worker_connector.wait_for_save() self.worker_connector.wait_for_save()
if complete_transfers:
self.offloading_spec.complete_transfers()
finished_sending, finished_recving = self.worker_connector.get_finished( finished_sending, finished_recving = self.worker_connector.get_finished(
scheduler_output.finished_req_ids scheduler_output.finished_req_ids
) )
...@@ -300,7 +334,7 @@ class RequestRunner: ...@@ -300,7 +334,7 @@ class RequestRunner:
reqs=self.scheduler.running, reqs=self.scheduler.running,
finished_sending=finished_sending, finished_sending=finished_sending,
finished_recving=finished_recving, finished_recving=finished_recving,
token_id=token_id, token_id=token_id or 0,
) )
if self.scheduler.running: if self.scheduler.running:
...@@ -308,7 +342,10 @@ class RequestRunner: ...@@ -308,7 +342,10 @@ class RequestRunner:
self.scheduler.update_from_output(scheduler_output, model_runner_output) self.scheduler.update_from_output(scheduler_output, model_runner_output)
self._wait_for_transfers() if token_id is None:
break
self._parse_transfers()
# run one more step to update finished stored # run one more step to update finished stored
if EOS_TOKEN_ID in decoded_tokens: if EOS_TOKEN_ID in decoded_tokens:
...@@ -333,8 +370,10 @@ class RequestRunner: ...@@ -333,8 +370,10 @@ class RequestRunner:
def run( def run(
self, self,
decoded_tokens: list[int], decoded_tokens: list[int],
complete_transfers: bool = True,
expected_stored_gpu_block_indexes: tuple[int, ...] = (), expected_stored_gpu_block_indexes: tuple[int, ...] = (),
expected_loaded_gpu_block_indexes: tuple[int, ...] = (), expected_loaded_gpu_block_indexes: tuple[int, ...] = (),
expected_flushed_gpu_block_indexes: tuple[int, ...] = (),
): ):
""" """
Runs multiple engine (scheduler + worker) steps. Runs multiple engine (scheduler + worker) steps.
...@@ -342,14 +381,17 @@ class RequestRunner: ...@@ -342,14 +381,17 @@ class RequestRunner:
Args: Args:
decoded_tokens: the tokens to yield at each step. decoded_tokens: the tokens to yield at each step.
complete_transfers: complete transfers immediately
expected_stored_gpu_block_indexes: GPU block indexes expected_stored_gpu_block_indexes: GPU block indexes
that are expected to be written during the run. that are expected to be written during the run.
expected_loaded_gpu_block_indexes: GPU block indexes expected_loaded_gpu_block_indexes: GPU block indexes
that are expected to be loaded during the run. that are expected to be loaded during the run.
expected_flushed_gpu_block_indexes: GPU block indexes
that are expected to be flushed during the run.
""" """
self.manager.reset_mock() self.manager.reset_mock()
self._run(decoded_tokens) self._run(decoded_tokens, complete_transfers)
loaded_gpu_block_indexes: set[int] = set() loaded_gpu_block_indexes: set[int] = set()
for transfer in self.completed_loads: for transfer in self.completed_loads:
...@@ -373,6 +415,9 @@ class RequestRunner: ...@@ -373,6 +415,9 @@ class RequestRunner:
assert set(expected_stored_gpu_block_indexes) == stored_gpu_block_indexes assert set(expected_stored_gpu_block_indexes) == stored_gpu_block_indexes
self.completed_stores.clear() self.completed_stores.clear()
assert set(expected_flushed_gpu_block_indexes) == self.flushed_gpu_block_indexes
self.flushed_gpu_block_indexes.clear()
@pytest.fixture @pytest.fixture
def request_runner(): def request_runner():
...@@ -539,3 +584,69 @@ def test_offloading_connector(request_runner): ...@@ -539,3 +584,69 @@ def test_offloading_connector(request_runner):
assert isinstance(event, BlockRemoved) assert isinstance(event, BlockRemoved)
assert event.block_hashes == to_hashes([4, 5, 6]) assert event.block_hashes == to_hashes([4, 5, 6])
assert event.medium == "B" assert event.medium == "B"
def test_request_preemption(request_runner):
offloaded_block_size = 12
gpu_block_size = 4
num_gpu_blocks = 100
runner = request_runner(
offloaded_block_size=offloaded_block_size,
gpu_block_size=gpu_block_size,
num_gpu_blocks=num_gpu_blocks,
)
free_block_queue = runner.scheduler.kv_cache_manager.block_pool.free_block_queue
num_free_blocks_empty = free_block_queue.num_free_blocks
# 2 blocks, store all, without flushing
# blocks = [0, 1, 2], [3, 4, 5]
runner.new_request(token_ids=[0] * offloaded_block_size * 2)
runner.manager.prepare_store.side_effect = (
lambda block_hashes: generate_store_output(block_hashes)
)
runner.run(
decoded_tokens=[0],
complete_transfers=False,
)
# decode 2 more blocks - 1 gpu block, storing [6, 7, 8] (no flush)
runner.manager.prepare_store.side_effect = (
lambda block_hashes: generate_store_output(block_hashes)
)
runner.run(
decoded_tokens=[0] * (2 * offloaded_block_size - gpu_block_size),
complete_transfers=False,
)
# simulate KV cache running out of space
free_block_queue.num_free_blocks = 0
# request should be preempted now
runner.run(
decoded_tokens=[],
complete_transfers=False,
expected_flushed_gpu_block_indexes=(0, 1, 2, 3, 4, 5, 6, 7, 8),
expected_stored_gpu_block_indexes=(0, 1, 2, 3, 4, 5, 6, 7, 8),
)
# restore KV cache space and reset GPU prefix cache
free_block_queue.num_free_blocks = num_free_blocks_empty
runner.scheduler.reset_prefix_cache()
# request should now return from preemption
# re-load [0, ..., 8] from the CPU and store [9, 10, 11]
runner.manager.lookup.return_value = 3
runner.manager.prepare_store.side_effect = (
lambda block_hashes: generate_store_output(block_hashes)
)
runner.run(
decoded_tokens=[0] * gpu_block_size,
expected_loaded_gpu_block_indexes=(0, 1, 2, 3, 4, 5, 6, 7, 8),
)
runner.run(
decoded_tokens=[EOS_TOKEN_ID],
expected_stored_gpu_block_indexes=(9, 10, 11),
)
...@@ -63,6 +63,12 @@ class OffloadingHandler1To2(OffloadingHandler): ...@@ -63,6 +63,12 @@ class OffloadingHandler1To2(OffloadingHandler):
del self.transfers[job_id] del self.transfers[job_id]
return finished return finished
def wait(self, job_ids: set[int]) -> None:
for job_id in job_ids:
spec = self.transfers.get(job_id)
if spec:
assert spec.finished
class OffloadingHandler2To1(OffloadingHandler): class OffloadingHandler2To1(OffloadingHandler):
def __init__(self): def __init__(self):
...@@ -84,6 +90,12 @@ class OffloadingHandler2To1(OffloadingHandler): ...@@ -84,6 +90,12 @@ class OffloadingHandler2To1(OffloadingHandler):
del self.transfers[job_id] del self.transfers[job_id]
return finished return finished
def wait(self, job_ids: set[int]) -> None:
for job_id in job_ids:
spec = self.transfers.get(job_id)
if spec:
assert spec.finished
def test_offloading_worker(): def test_offloading_worker():
""" """
......
...@@ -25,6 +25,9 @@ The class provides the following primitives: ...@@ -25,6 +25,9 @@ The class provides the following primitives:
Worker-side: runs in each worker, loads/saves KV cache to/from Worker-side: runs in each worker, loads/saves KV cache to/from
the Connector based on the metadata. the Connector based on the metadata.
handle_preemptions() - called if there are preempted requests,
before their blocks are overwritten
start_load_kv() - starts loading all KVs (maybe async) start_load_kv() - starts loading all KVs (maybe async)
wait_for_layer_load() - blocks until layer i load is done wait_for_layer_load() - blocks until layer i load is done
...@@ -262,6 +265,13 @@ class KVConnectorBase_V1(ABC): ...@@ -262,6 +265,13 @@ class KVConnectorBase_V1(ABC):
""" """
return return
def handle_preemptions(self, preempted_req_ids: set[str]):
"""
Handle preempted requests BEFORE their blocks are overwritten.
Needed for connectors which use async saves (e.g., OffloadingConnector)
"""
return
@abstractmethod @abstractmethod
def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None: def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
""" """
......
...@@ -75,6 +75,10 @@ class OffloadingConnector(KVConnectorBase_V1): ...@@ -75,6 +75,10 @@ class OffloadingConnector(KVConnectorBase_V1):
assert self.connector_worker is not None assert self.connector_worker is not None
self.connector_worker.register_cross_layers_kv_cache(kv_cache, attn_backend) self.connector_worker.register_cross_layers_kv_cache(kv_cache, attn_backend)
def handle_preemptions(self, preempted_req_ids: set[str]):
assert self.connector_worker is not None
self.connector_worker.handle_preemptions(preempted_req_ids)
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
assert self.connector_worker is not None assert self.connector_worker is not None
assert isinstance(self._connector_metadata, OffloadingConnectorMetadata) assert isinstance(self._connector_metadata, OffloadingConnectorMetadata)
...@@ -348,6 +352,15 @@ class OffloadingConnectorScheduler: ...@@ -348,6 +352,15 @@ class OffloadingConnectorScheduler:
reqs_to_store=self._get_reqs_to_store(scheduler_output), reqs_to_store=self._get_reqs_to_store(scheduler_output),
) )
self._reqs_to_load = {} self._reqs_to_load = {}
# NOTE (orozery): we should move this logic to update_connector_output
# once KVConnectorOutput allows us to report completed transfers
for req_id in scheduler_output.preempted_req_ids or ():
block_hashes = self._reqs_being_stored.get(req_id)
if block_hashes:
self.manager.complete_store(block_hashes)
block_hashes.clear()
return meta return meta
def update_connector_output(self, connector_output: KVConnectorOutput): def update_connector_output(self, connector_output: KVConnectorOutput):
...@@ -466,6 +479,17 @@ class OffloadingConnectorWorker: ...@@ -466,6 +479,17 @@ class OffloadingConnectorWorker:
attn_backends = {cross_layer_name: attn_backend} attn_backends = {cross_layer_name: attn_backend}
self._register_handlers(kv_caches, attn_backends) self._register_handlers(kv_caches, attn_backends)
def handle_preemptions(self, preempted_req_ids: set[str]):
for job_id, transfer_spec in self._unsubmitted_store_jobs:
success = self.worker.transfer_async(job_id, transfer_spec)
assert success
self._unsubmitted_store_jobs.clear()
for req_id in preempted_req_ids:
job_ids = self._store_jobs.get(req_id)
if job_ids:
self.worker.wait(job_ids)
def start_kv_transfers(self, metadata: OffloadingConnectorMetadata): def start_kv_transfers(self, metadata: OffloadingConnectorMetadata):
for job_id, transfer_spec in self._unsubmitted_store_jobs: for job_id, transfer_spec in self._unsubmitted_store_jobs:
success = self.worker.transfer_async(job_id, transfer_spec) success = self.worker.transfer_async(job_id, transfer_spec)
......
...@@ -96,6 +96,8 @@ class SingleDirectionOffloadingHandler(OffloadingHandler): ...@@ -96,6 +96,8 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
assert len(src_tensors) > 0 assert len(src_tensors) > 0
self.gpu_to_cpu: bool = self.src_tensors[0].is_cuda self.gpu_to_cpu: bool = self.src_tensors[0].is_cuda
# job_id -> event
self._transfer_events: dict[int, torch.Event] = {}
# queue of transfers (job_id, stream, event) # queue of transfers (job_id, stream, event)
self._transfers: deque[tuple[int, torch.cuda.Stream, torch.Event]] = deque() self._transfers: deque[tuple[int, torch.cuda.Stream, torch.Event]] = deque()
# list of CUDA streams available for re-use # list of CUDA streams available for re-use
...@@ -152,6 +154,7 @@ class SingleDirectionOffloadingHandler(OffloadingHandler): ...@@ -152,6 +154,7 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
ops.swap_blocks(src_tensor, dst_tensor, src_to_dst_tensor) ops.swap_blocks(src_tensor, dst_tensor, src_to_dst_tensor)
event.record(stream) event.record(stream)
self._transfer_events[job_id] = event
self._transfers.append((job_id, stream, event)) self._transfers.append((job_id, stream, event))
# success # success
...@@ -164,8 +167,15 @@ class SingleDirectionOffloadingHandler(OffloadingHandler): ...@@ -164,8 +167,15 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
results.append((job_id, True)) results.append((job_id, True))
self._stream_pool.append(stream) self._stream_pool.append(stream)
self._event_pool.append(event) self._event_pool.append(event)
del self._transfer_events[job_id]
return results return results
def wait(self, job_ids: set[int]):
for job_id in job_ids:
event = self._transfer_events.get(job_id)
if event is not None:
event.synchronize()
class CpuGpuOffloadingHandlers: class CpuGpuOffloadingHandlers:
def __init__( def __init__(
......
...@@ -53,6 +53,15 @@ class OffloadingHandler(ABC): ...@@ -53,6 +53,15 @@ class OffloadingHandler(ABC):
""" """
pass pass
@abstractmethod
def wait(self, job_ids: set[int]) -> None:
"""
Wait for jobs to finish (blocking).
Args:
job_ids: The set of job IDs to wait for.
"""
class OffloadingWorker: class OffloadingWorker:
""" """
...@@ -142,3 +151,13 @@ class OffloadingWorker: ...@@ -142,3 +151,13 @@ class OffloadingWorker:
for handler in self.handlers: for handler in self.handlers:
finished.extend(handler.get_finished()) finished.extend(handler.get_finished())
return finished return finished
def wait(self, job_ids: set[int]) -> None:
"""
Wait for jobs to finish (blocking).
Args:
job_ids: The set of job IDs to wait for.
"""
for handler in self.handlers:
handler.wait(job_ids)
...@@ -3112,6 +3112,11 @@ class GPUModelRunner( ...@@ -3112,6 +3112,11 @@ class GPUModelRunner(
"after execute_model() returns None." "after execute_model() returns None."
) )
if scheduler_output.preempted_req_ids and has_kv_transfer_group():
get_kv_transfer_group().handle_preemptions(
scheduler_output.preempted_req_ids
)
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
with ( with (
record_function_or_nullcontext("gpu_model_runner: preprocess"), record_function_or_nullcontext("gpu_model_runner: preprocess"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment