Unverified commit fcf0687b, authored by Or Ozeri, committed by GitHub
Browse files

[kv_offload+HMA][0/N]: Support block-level preemption handling (#34805)


Signed-off-by: Or Ozeri <oro@il.ibm.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
parent 86b7e3c9
...@@ -231,10 +231,11 @@ def test_multi_example_connector_consistency(): ...@@ -231,10 +231,11 @@ def test_multi_example_connector_consistency():
] ]
# First three events are from initialization (register_kv_caches, # First three events are from initialization (register_kv_caches,
# set_host_xfer_buffer_ops, get_handshake_metadata), then generate() events. # set_host_xfer_buffer_ops, get_handshake_metadata), then generate() events.
assert events["storage1-WORKER"][:7] == [ assert events["storage1-WORKER"][:8] == [
"register_kv_caches", "register_kv_caches",
"set_host_xfer_buffer_ops", "set_host_xfer_buffer_ops",
"get_handshake_metadata", "get_handshake_metadata",
"handle_preemptions",
"bind_connector_metadata", "bind_connector_metadata",
"start_load_kv", "start_load_kv",
"wait_for_layer_load", "wait_for_layer_load",
...@@ -246,10 +247,11 @@ def test_multi_example_connector_consistency(): ...@@ -246,10 +247,11 @@ def test_multi_example_connector_consistency():
"update_state_after_alloc num_blocks=[0] 0", "update_state_after_alloc num_blocks=[0] 0",
"build_connector_meta", "build_connector_meta",
] ]
assert events["storage2-WORKER"][:7] == [ assert events["storage2-WORKER"][:8] == [
"register_kv_caches", "register_kv_caches",
"set_host_xfer_buffer_ops", "set_host_xfer_buffer_ops",
"get_handshake_metadata", "get_handshake_metadata",
"handle_preemptions",
"bind_connector_metadata", "bind_connector_metadata",
"start_load_kv", "start_load_kv",
"wait_for_layer_load", "wait_for_layer_load",
...@@ -399,8 +401,8 @@ def test_multi_connector_handle_preemptions_integration(): ...@@ -399,8 +401,8 @@ def test_multi_connector_handle_preemptions_integration():
# testing the delegation behavior of MultiConnector here. # testing the delegation behavior of MultiConnector here.
# The connector attribute contains the KV connector. # The connector attribute contains the KV connector.
assert scheduler.connector is not None, "Scheduler should have a connector" assert scheduler.connector is not None, "Scheduler should have a connector"
preempted_req_ids = {"req-1", "req-2", "req-3"} connector_md = scheduler.connector.build_connector_meta(scheduler.schedule())
scheduler.connector.handle_preemptions(preempted_req_ids) scheduler.connector.handle_preemptions(connector_md)
# Verify both connectors received the handle_preemptions call # Verify both connectors received the handle_preemptions call
events = get_connector_events() events = get_connector_events()
......
...@@ -363,10 +363,7 @@ class RequestRunner: ...@@ -363,10 +363,7 @@ class RequestRunner:
assert kv_connector_metadata is not None assert kv_connector_metadata is not None
assert isinstance(kv_connector_metadata, OffloadingConnectorMetadata) assert isinstance(kv_connector_metadata, OffloadingConnectorMetadata)
if scheduler_output.preempted_req_ids: self.worker_connector.handle_preemptions(kv_connector_metadata)
self.worker_connector.handle_preemptions(
scheduler_output.preempted_req_ids
)
self.worker_connector.bind_connector_metadata(kv_connector_metadata) self.worker_connector.bind_connector_metadata(kv_connector_metadata)
self.worker_connector.start_load_kv(self._dummy_ctx) self.worker_connector.start_load_kv(self._dummy_ctx)
......
...@@ -25,8 +25,8 @@ The class provides the following primitives: ...@@ -25,8 +25,8 @@ The class provides the following primitives:
Worker-side: runs in each worker, loads/saves KV cache to/from Worker-side: runs in each worker, loads/saves KV cache to/from
the Connector based on the metadata. the Connector based on the metadata.
handle_preemptions() - called if there are preempted requests, handle_preemptions() - called for handling preempted requests
before their blocks are overwritten or request evicted blocks before they are overwritten
start_load_kv() - starts loading all KVs (maybe async) start_load_kv() - starts loading all KVs (maybe async)
wait_for_layer_load() - blocks until layer i load is done wait_for_layer_load() - blocks until layer i load is done
...@@ -288,9 +288,9 @@ class KVConnectorBase_V1(ABC): ...@@ -288,9 +288,9 @@ class KVConnectorBase_V1(ABC):
""" """
return return
def handle_preemptions(self, preempted_req_ids: set[str]): def handle_preemptions(self, kv_connector_metadata: KVConnectorMetadata):
""" """
Handle preempted requests BEFORE their blocks are overwritten. Handle preempted requests or evicted blocks BEFORE they are overwritten.
Needed for connectors which use async saves (e.g., OffloadingConnector) Needed for connectors which use async saves (e.g., OffloadingConnector)
""" """
return return
......
...@@ -315,10 +315,11 @@ class MultiConnector(KVConnectorBase_V1): ...@@ -315,10 +315,11 @@ class MultiConnector(KVConnectorBase_V1):
for c in self._connectors: for c in self._connectors:
c.set_host_xfer_buffer_ops(copy_operation) c.set_host_xfer_buffer_ops(copy_operation)
def handle_preemptions(self, preempted_req_ids: set[str]): def handle_preemptions(self, kv_connector_metadata: KVConnectorMetadata):
"""Handle preempted requests for all sub-connectors.""" """Handle preempted requests for all sub-connectors."""
for c in self._connectors: assert isinstance(kv_connector_metadata, MultiKVConnectorMetadata)
c.handle_preemptions(preempted_req_ids) for c, cm in zip(self._connectors, kv_connector_metadata.metadata):
c.handle_preemptions(cm)
def get_finished_count(self) -> int | None: def get_finished_count(self) -> int | None:
# TODO(https://github.com/vllm-project/vllm/issues/33400) # TODO(https://github.com/vllm-project/vllm/issues/33400)
......
...@@ -111,6 +111,7 @@ class OffloadingConnectorStats(KVConnectorStats): ...@@ -111,6 +111,7 @@ class OffloadingConnectorStats(KVConnectorStats):
class OffloadingConnectorMetadata(KVConnectorMetadata): class OffloadingConnectorMetadata(KVConnectorMetadata):
reqs_to_load: dict[ReqId, TransferSpec] reqs_to_load: dict[ReqId, TransferSpec]
reqs_to_store: dict[ReqId, TransferSpec] reqs_to_store: dict[ReqId, TransferSpec]
reqs_to_flush: set[str] | None = None
class OffloadingConnector(KVConnectorBase_V1): class OffloadingConnector(KVConnectorBase_V1):
...@@ -146,9 +147,10 @@ class OffloadingConnector(KVConnectorBase_V1): ...@@ -146,9 +147,10 @@ class OffloadingConnector(KVConnectorBase_V1):
assert self.connector_worker is not None assert self.connector_worker is not None
self.connector_worker.register_cross_layers_kv_cache(kv_cache, attn_backend) self.connector_worker.register_cross_layers_kv_cache(kv_cache, attn_backend)
def handle_preemptions(self, preempted_req_ids: set[str]): def handle_preemptions(self, kv_connector_metadata: KVConnectorMetadata):
assert self.connector_worker is not None assert self.connector_worker is not None
self.connector_worker.handle_preemptions(preempted_req_ids) assert isinstance(kv_connector_metadata, OffloadingConnectorMetadata)
self.connector_worker.handle_preemptions(kv_connector_metadata)
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
assert self.connector_worker is not None assert self.connector_worker is not None
...@@ -482,6 +484,7 @@ class OffloadingConnectorScheduler: ...@@ -482,6 +484,7 @@ class OffloadingConnectorScheduler:
meta = OffloadingConnectorMetadata( meta = OffloadingConnectorMetadata(
reqs_to_load=self._reqs_to_load, reqs_to_load=self._reqs_to_load,
reqs_to_store=self._get_reqs_to_store(scheduler_output), reqs_to_store=self._get_reqs_to_store(scheduler_output),
reqs_to_flush=scheduler_output.preempted_req_ids,
) )
self._reqs_to_load = {} self._reqs_to_load = {}
...@@ -619,13 +622,13 @@ class OffloadingConnectorWorker: ...@@ -619,13 +622,13 @@ class OffloadingConnectorWorker:
attn_backends = {cross_layer_name: attn_backend} attn_backends = {cross_layer_name: attn_backend}
self._register_handlers(kv_caches, attn_backends) self._register_handlers(kv_caches, attn_backends)
def handle_preemptions(self, preempted_req_ids: set[str]): def handle_preemptions(self, kv_connector_metadata: OffloadingConnectorMetadata):
for job_id, transfer_spec in self._unsubmitted_store_jobs: for job_id, transfer_spec in self._unsubmitted_store_jobs:
success = self.worker.transfer_async(job_id, transfer_spec) success = self.worker.transfer_async(job_id, transfer_spec)
assert success assert success
self._unsubmitted_store_jobs.clear() self._unsubmitted_store_jobs.clear()
for req_id in preempted_req_ids: for req_id in kv_connector_metadata.reqs_to_flush or ():
job_ids = self._store_jobs.get(req_id) job_ids = self._store_jobs.get(req_id)
if job_ids: if job_ids:
self.worker.wait(job_ids) self.worker.wait(job_ids)
......
...@@ -63,11 +63,10 @@ class ActiveKVConnector(KVConnector): ...@@ -63,11 +63,10 @@ class ActiveKVConnector(KVConnector):
if self._disabled: if self._disabled:
return return
if scheduler_output.preempted_req_ids:
self.kv_connector.handle_preemptions(scheduler_output.preempted_req_ids)
kv_connector_metadata = scheduler_output.kv_connector_metadata kv_connector_metadata = scheduler_output.kv_connector_metadata
assert kv_connector_metadata is not None assert kv_connector_metadata is not None
self.kv_connector.bind_connector_metadata(kv_connector_metadata) self.kv_connector.bind_connector_metadata(kv_connector_metadata)
self.kv_connector.handle_preemptions(kv_connector_metadata)
# TODO: sort out KV Connectors' use of forward_context # TODO: sort out KV Connectors' use of forward_context
if is_forward_context_available(): if is_forward_context_available():
......
...@@ -3594,10 +3594,10 @@ class GPUModelRunner( ...@@ -3594,10 +3594,10 @@ class GPUModelRunner(
scheduled_spec_decode_tokens=spec_decode_tokens_copy, scheduled_spec_decode_tokens=spec_decode_tokens_copy,
) )
if scheduler_output.preempted_req_ids and has_kv_transfer_group(): if has_kv_transfer_group():
get_kv_transfer_group().handle_preemptions( kv_connector_metadata = scheduler_output.kv_connector_metadata
scheduler_output.preempted_req_ids assert kv_connector_metadata is not None
) get_kv_transfer_group().handle_preemptions(kv_connector_metadata)
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
with ( with (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment