Unverified commit cc079763, authored by David Ben-David, committed by GitHub
Browse files

[BugFix] Avoid calling KV connector layer APIs when metadata is unset (#28253)


Signed-off-by: David Ben-David <davidb@pliops.com>
Co-authored-by: David Ben-David <davidb@pliops.com>
Co-authored-by: Mark McLoughlin <markmc@redhat.com>
parent a7adbc6c
...@@ -837,6 +837,8 @@ def wait_for_kv_layer_from_connector(layer_name: str): ...@@ -837,6 +837,8 @@ def wait_for_kv_layer_from_connector(layer_name: str):
return return
connector = get_kv_transfer_group() connector = get_kv_transfer_group()
if not connector.has_connector_metadata():
return
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata attn_metadata = forward_context.attn_metadata
...@@ -854,6 +856,8 @@ def maybe_save_kv_layer_to_connector( ...@@ -854,6 +856,8 @@ def maybe_save_kv_layer_to_connector(
return return
connector = get_kv_transfer_group() connector = get_kv_transfer_group()
if not connector.has_connector_metadata():
return
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata attn_metadata = forward_context.attn_metadata
......
...@@ -204,11 +204,18 @@ class KVConnectorBase_V1(ABC): ...@@ -204,11 +204,18 @@ class KVConnectorBase_V1(ABC):
Returns: Returns:
ConnectorMetadata: the connector metadata. ConnectorMetadata: the connector metadata.
""" """
# Should only be called while set to valid metadata. # Should only be called while set to valid metadata.
assert self._connector_metadata is not None assert self._connector_metadata is not None
return self._connector_metadata return self._connector_metadata
def has_connector_metadata(self) -> bool:
"""Check whether the connector metadata is currently set.
Returns:
bool: True if connector metadata exists, False otherwise.
"""
return self._connector_metadata is not None
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
""" """
Initialize with the KV caches. Useful for pre-registering the Initialize with the KV caches. Useful for pre-registering the
......
...@@ -171,16 +171,22 @@ class MultiConnector(KVConnectorBase_V1): ...@@ -171,16 +171,22 @@ class MultiConnector(KVConnectorBase_V1):
# We must override the base class method here because we need to bind # We must override the base class method here because we need to bind
# the metadata to each connector in the order of the connectors in the # the metadata to each connector in the order of the connectors in the
# MultiKVConnectorMetadata. # MultiKVConnectorMetadata.
#
# Note: Call the base class method to ensure metadata is also set on the
# MultiConnector instance itself; otherwise, `has_connector_metadata()` will
# always return False.
def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None: def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None:
assert isinstance(connector_metadata, MultiKVConnectorMetadata) assert isinstance(connector_metadata, MultiKVConnectorMetadata)
if connector_metadata.extra_async_saves: if connector_metadata.extra_async_saves:
self._extra_async_saves.update(connector_metadata.extra_async_saves) self._extra_async_saves.update(connector_metadata.extra_async_saves)
for c, cm in zip(self._connectors, connector_metadata.metadata): for c, cm in zip(self._connectors, connector_metadata.metadata):
c.bind_connector_metadata(cm) c.bind_connector_metadata(cm)
super().bind_connector_metadata(connector_metadata)
def clear_connector_metadata(self) -> None: def clear_connector_metadata(self) -> None:
for c in self._connectors: for c in self._connectors:
c.clear_connector_metadata() c.clear_connector_metadata()
super().clear_connector_metadata()
def shutdown(self): def shutdown(self):
exception: Exception | None = None exception: Exception | None = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment