Unverified commit b89443b8, authored by Kfir Toledo, committed by GitHub
Browse files

[KVConnector]: Enable Cross-layers KV cache layout for MultiConnector (#30761)


Signed-off-by: Kfir Toledo <kfir.toledo@ibm.com>
parent 1d9e9ae8
...@@ -49,6 +49,33 @@ class MockConnector(KVConnectorBase_V1): ...@@ -49,6 +49,33 @@ class MockConnector(KVConnectorBase_V1):
) -> KVConnectorStats | None: ) -> KVConnectorStats | None:
return MockConnectorStats(data=data) if data is not None else None return MockConnectorStats(data=data) if data is not None else None
def start_load_kv(self, forward_context, **kwargs):
    """No-op stub: the mock ignores all arguments and returns None."""
def wait_for_layer_load(self, layer_name):
    """No-op stub: the mock ignores ``layer_name`` and returns None."""
def save_kv_layer(self, layer_name, kv_layer, attn_metadata, **kwargs):
    """No-op stub: the mock ignores all arguments and returns None."""
def wait_for_save(self):
    """No-op stub: the mock does nothing and returns None."""
def build_connector_meta(self, scheduler_output):
    """Stub: ignores ``scheduler_output`` and always returns None."""
    meta = None
    return meta
def get_num_new_matched_tokens(self, request, num_computed_tokens):
    """Stub: ignores its arguments and always returns the pair ``(0, False)``."""
    matched_tokens = 0
    second_flag = False
    return matched_tokens, second_flag
def update_state_after_alloc(self, request, blocks, num_tokens) -> None:
    """No-op stub: the mock ignores all arguments and returns None."""
class MockCrossLayerConnector(MockConnector):
    """MockConnector variant that reports a preference for cross-layer blocks."""

    @property
    def prefer_cross_layer_blocks(self) -> bool:
        """Always True for this mock."""
        return True
# Register the mock connector # Register the mock connector
KVConnectorFactory.register_connector("MockConnector", __name__, MockConnector.__name__) KVConnectorFactory.register_connector("MockConnector", __name__, MockConnector.__name__)
...@@ -601,3 +628,21 @@ class TestMultiConnectorStats: ...@@ -601,3 +628,21 @@ class TestMultiConnectorStats:
# One non-empty # One non-empty
stats.data["NixlConnector"].data["transfer_duration"].append(1.0) stats.data["NixlConnector"].data["transfer_duration"].append(1.0)
assert not stats.is_empty() assert not stats.is_empty()
class TestMultiConnectorPreferCrossLayerBlocks:
    """Checks how MultiConnector aggregates ``prefer_cross_layer_blocks``."""

    @staticmethod
    def _multi_connector_of(*connector_classes):
        """Build a MultiConnector wrapping bare instances of the given classes.

        ``__new__`` is used for both the MultiConnector and its children so
        that no ``__init__`` runs; only ``_connectors`` is populated.
        """
        mc = MultiConnector.__new__(MultiConnector)
        mc._connectors = [cls.__new__(cls) for cls in connector_classes]
        return mc

    def test_all_connectors_prefer_cross_layer_blocks(self):
        mc = self._multi_connector_of(
            MockCrossLayerConnector,
            MockCrossLayerConnector,
        )
        assert mc.prefer_cross_layer_blocks is True

    def test_mixed_connectors_do_not_prefer_cross_layer_blocks(self):
        mc = self._multi_connector_of(
            MockCrossLayerConnector,
            MockConnector,  # default False
        )
        assert mc.prefer_cross_layer_blocks is False

    def test_no_connectors_do_not_prefer_cross_layer_blocks(self):
        # all() over an empty list is True, so this pins the explicit
        # empty-list guard in MultiConnector.prefer_cross_layer_blocks.
        mc = self._multi_connector_of()
        assert mc.prefer_cross_layer_blocks is False
...@@ -38,7 +38,7 @@ The class provides the following primitives: ...@@ -38,7 +38,7 @@ The class provides the following primitives:
import enum import enum
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable from collections.abc import Callable, Iterable
from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional from typing import TYPE_CHECKING, Any, Literal, Optional
import torch import torch
...@@ -144,15 +144,15 @@ class KVConnectorMetadata(ABC): # noqa: B024 ...@@ -144,15 +144,15 @@ class KVConnectorMetadata(ABC): # noqa: B024
class KVConnectorBase_V1(ABC): class KVConnectorBase_V1(ABC):
""" """
Base class for KV connectors. Base class for KV connectors.
Attributes:
prefer_cross_layer_blocks (bool): Indicates whether this connector
prefers KV blocks that hold KV data for all layers (for speeding
up KV data transfers).
Defaults to False.
""" """
prefer_cross_layer_blocks: ClassVar[bool] = False @property
def prefer_cross_layer_blocks(self) -> bool:
"""
Indicates whether this connector prefers KV blocks that hold KV data for all
layers, which can speed up KV data transfers. Defaults to False.
"""
return False
def __init__( def __init__(
self, self,
......
...@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any ...@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any
import torch import torch
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.kv_transfer import KVTransferConfig from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
...@@ -138,6 +138,12 @@ class MultiConnector(KVConnectorBase_V1): ...@@ -138,6 +138,12 @@ class MultiConnector(KVConnectorBase_V1):
# Propagated from scheduler to worker side via the connector metadata. # Propagated from scheduler to worker side via the connector metadata.
self._extra_async_saves: dict[str, int] = {} self._extra_async_saves: dict[str, int] = {}
@property
def prefer_cross_layer_blocks(self) -> bool:
    """
    Whether every wrapped connector prefers cross-layer KV blocks.

    Returns True only when ``self._connectors`` is non-empty and each
    connector in it reports ``prefer_cross_layer_blocks``; otherwise False.
    """
    connectors = self._connectors
    # all() on an empty iterable is True, so require at least one connector.
    return bool(connectors) and all(
        connector.prefer_cross_layer_blocks for connector in connectors
    )
@classmethod @classmethod
def _get_connector_classes_and_configs( def _get_connector_classes_and_configs(
cls, vllm_config: "VllmConfig" cls, vllm_config: "VllmConfig"
...@@ -164,6 +170,13 @@ class MultiConnector(KVConnectorBase_V1): ...@@ -164,6 +170,13 @@ class MultiConnector(KVConnectorBase_V1):
) )
return ret return ret
def register_cross_layers_kv_cache(
    self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
):
    """
    Forward the cross-layer KV cache registration to every wrapped connector.

    Each connector in ``self._connectors`` receives the same ``kv_cache``
    and ``attn_backend`` arguments, in list order.
    """
    for connector in self._connectors:
        connector.register_cross_layers_kv_cache(kv_cache, attn_backend)
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
for c in self._connectors: for c in self._connectors:
c.register_kv_caches(kv_caches) c.register_kv_caches(kv_caches)
......
...@@ -4,7 +4,7 @@ from collections import defaultdict ...@@ -4,7 +4,7 @@ from collections import defaultdict
from collections.abc import Iterable from collections.abc import Iterable
from dataclasses import dataclass from dataclasses import dataclass
from itertools import islice from itertools import islice
from typing import Any, ClassVar from typing import Any
import torch import torch
...@@ -44,7 +44,9 @@ class OffloadingConnectorMetadata(KVConnectorMetadata): ...@@ -44,7 +44,9 @@ class OffloadingConnectorMetadata(KVConnectorMetadata):
class OffloadingConnector(KVConnectorBase_V1): class OffloadingConnector(KVConnectorBase_V1):
prefer_cross_layer_blocks: ClassVar[bool] = True @property
def prefer_cross_layer_blocks(self) -> bool:
return True
def __init__( def __init__(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment