Unverified Commit 1db4f47f authored by Nick Hill's avatar Nick Hill Committed by GitHub
Browse files

[BugFix] Fix multi async save in MultiConnector (#18246)


Signed-off-by: default avatarNick Hill <nhill@redhat.com>
parent d3d91b6f
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import copy import copy
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any, Optional
import torch import torch
...@@ -21,9 +22,10 @@ if TYPE_CHECKING: ...@@ -21,9 +22,10 @@ if TYPE_CHECKING:
logger = init_logger(__name__) logger = init_logger(__name__)
class MultiKVConnectorMetadata(tuple[KVConnectorMetadata, ...], @dataclass
KVConnectorMetadata): class MultiKVConnectorMetadata(KVConnectorMetadata):
pass metadata: tuple[KVConnectorMetadata, ...]
extra_async_saves: Optional[dict[str, int]] = None
class MultiConnector(KVConnectorBase_V1): class MultiConnector(KVConnectorBase_V1):
...@@ -54,6 +56,7 @@ class MultiConnector(KVConnectorBase_V1): ...@@ -54,6 +56,7 @@ class MultiConnector(KVConnectorBase_V1):
# Keeps track of *additional* remaining async saves (beyond 1) to be # Keeps track of *additional* remaining async saves (beyond 1) to be
# finished per request. Not needed for async loads since we only allow # finished per request. Not needed for async loads since we only allow
# a single connector to load. # a single connector to load.
# Propagated from scheduler to worker side via the connector metadata.
self._extra_async_saves: dict[str, int] = {} self._extra_async_saves: dict[str, int] = {}
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
...@@ -66,7 +69,10 @@ class MultiConnector(KVConnectorBase_V1): ...@@ -66,7 +69,10 @@ class MultiConnector(KVConnectorBase_V1):
def bind_connector_metadata( def bind_connector_metadata(
self, connector_metadata: KVConnectorMetadata) -> None: self, connector_metadata: KVConnectorMetadata) -> None:
assert isinstance(connector_metadata, MultiKVConnectorMetadata) assert isinstance(connector_metadata, MultiKVConnectorMetadata)
for c, cm in zip(self._connectors, connector_metadata): if connector_metadata.extra_async_saves:
self._extra_async_saves.update(
connector_metadata.extra_async_saves)
for c, cm in zip(self._connectors, connector_metadata.metadata):
c.bind_connector_metadata(cm) c.bind_connector_metadata(cm)
def clear_connector_metadata(self) -> None: def clear_connector_metadata(self) -> None:
...@@ -152,8 +158,13 @@ class MultiConnector(KVConnectorBase_V1): ...@@ -152,8 +158,13 @@ class MultiConnector(KVConnectorBase_V1):
def build_connector_meta( def build_connector_meta(
self, self,
scheduler_output: SchedulerOutput) -> MultiKVConnectorMetadata: scheduler_output: SchedulerOutput) -> MultiKVConnectorMetadata:
return MultiKVConnectorMetadata( metadata = MultiKVConnectorMetadata(metadata=tuple(
c.build_connector_meta(scheduler_output) for c in self._connectors) c.build_connector_meta(scheduler_output)
for c in self._connectors))
if self._extra_async_saves:
metadata.extra_async_saves = self._extra_async_saves
self._extra_async_saves = {}
return metadata
def request_finished( def request_finished(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment