Unverified commit 574ad60d, authored by Nick Hill, committed by GitHub
Browse files

[KVConnector] Always call connector `clear_metadata()` at end of step (#20756)


Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: David Ben-David <sdavidbd@gmail.com>
parent fdadb6f4
...@@ -57,7 +57,7 @@ class KVConnectorRole(enum.Enum): ...@@ -57,7 +57,7 @@ class KVConnectorRole(enum.Enum):
WORKER = 1 WORKER = 1
class KVConnectorMetadata: class KVConnectorMetadata(ABC): # noqa: B024
""" """
Abstract Metadata used to communicate between the Abstract Metadata used to communicate between the
Scheduler KVConnector and Worker KVConnector. Scheduler KVConnector and Worker KVConnector.
...@@ -71,7 +71,7 @@ class KVConnectorBase_V1(ABC): ...@@ -71,7 +71,7 @@ class KVConnectorBase_V1(ABC):
logger.warning( logger.warning(
"Initializing KVConnectorBase_V1. This API is experimental and " "Initializing KVConnectorBase_V1. This API is experimental and "
"subject to change in the future as we iterate the design.") "subject to change in the future as we iterate the design.")
self._connector_metadata = KVConnectorMetadata() self._connector_metadata: Optional[KVConnectorMetadata] = None
self._vllm_config = vllm_config self._vllm_config = vllm_config
self._role = role self._role = role
...@@ -102,7 +102,7 @@ class KVConnectorBase_V1(ABC): ...@@ -102,7 +102,7 @@ class KVConnectorBase_V1(ABC):
This function should be called by the model runner every time This function should be called by the model runner every time
after the model execution. after the model execution.
""" """
self._connector_metadata = KVConnectorMetadata() self._connector_metadata = None
def _get_connector_metadata(self) -> KVConnectorMetadata: def _get_connector_metadata(self) -> KVConnectorMetadata:
"""Get the connector metadata. """Get the connector metadata.
...@@ -112,6 +112,9 @@ class KVConnectorBase_V1(ABC): ...@@ -112,6 +112,9 @@ class KVConnectorBase_V1(ABC):
Returns: Returns:
ConnectorMetadata: the connector metadata. ConnectorMetadata: the connector metadata.
""" """
# Should only be called while set to valid metadata.
assert self._connector_metadata is not None
return self._connector_metadata return self._connector_metadata
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
......
...@@ -250,28 +250,24 @@ class MultiprocExecutor(Executor): ...@@ -250,28 +250,24 @@ class MultiprocExecutor(Executor):
self, outputs: list[ModelRunnerOutput]) -> ModelRunnerOutput: self, outputs: list[ModelRunnerOutput]) -> ModelRunnerOutput:
# aggregate finished_sending, finished_recving from all workers # aggregate finished_sending, finished_recving from all workers
finished_sending = set[str]() def update_finished_set(req_ids: Optional[set[str]],
finished_recving = set[str]() remaining_count_dict: dict[str, int],
for output in outputs: finished_set: set[str]) -> None:
# update finished_sending for req_id in req_ids or ():
for req_id in output.finished_sending or []: new_count = remaining_count_dict[req_id] - 1
new_count = self._send_remaining_count[req_id] - 1
if new_count == 0: if new_count == 0:
# got response from all workers, report back to scheduler finished_set.add(req_id)
finished_sending.add(req_id) del remaining_count_dict[req_id]
del self._send_remaining_count[req_id]
else: else:
self._send_remaining_count[req_id] = new_count remaining_count_dict[req_id] = new_count
# update finished_recving finished_sending = set[str]()
for req_id in output.finished_recving or []: finished_recving = set[str]()
new_count = self._recv_remaining_count[req_id] - 1 for output in outputs:
if new_count == 0: update_finished_set(output.finished_sending,
# got response from all workers, report back to scheduler self._send_remaining_count, finished_sending)
finished_recving.add(req_id) update_finished_set(output.finished_recving,
del self._recv_remaining_count[req_id] self._recv_remaining_count, finished_recving)
else:
self._recv_remaining_count[req_id] = new_count
# select output of the worker specified by output_rank # select output of the worker specified by output_rank
output = outputs[self.output_rank] output = outputs[self.output_rank]
......
...@@ -1539,10 +1539,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1539,10 +1539,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
attn_metadata, attn_metadata,
) )
# Clear KVConnector state after all KVs are generated.
if has_kv_transfer_group():
get_kv_transfer_group().clear_connector_metadata()
self.eplb_step() self.eplb_step()
return ModelRunnerOutput( return ModelRunnerOutput(
......
...@@ -338,6 +338,10 @@ class Worker(WorkerBase): ...@@ -338,6 +338,10 @@ class Worker(WorkerBase):
output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
output.finished_sending = finished_sending output.finished_sending = finished_sending
output.finished_recving = finished_recving output.finished_recving = finished_recving
# Clear KVConnector state for this step.
get_kv_transfer_group().clear_connector_metadata()
# with a connector, the scheduler expects output from all workers # with a connector, the scheduler expects output from all workers
return output return output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment