Unverified commit 5b681074, authored by Nicolò Lucchesi, committed by GitHub
Browse files

[Misc][PD] Fix `get_attn_backend` usage in transfer connectors (#31988)


Signed-off-by: NickLucche <nlucches@redhat.com>
parent 8fb2c135
...@@ -1439,7 +1439,7 @@ def test_register_kv_caches(default_vllm_config, dist_init, attn_backend): ...@@ -1439,7 +1439,7 @@ def test_register_kv_caches(default_vllm_config, dist_init, attn_backend):
patch(f"{nixl_module}.NixlWrapper") as mock_nixl_wrapper, patch(f"{nixl_module}.NixlWrapper") as mock_nixl_wrapper,
patch(f"{nixl_module}.threading.Event"), patch(f"{nixl_module}.threading.Event"),
patch(f"{nixl_module}.threading.Thread") as mock_thread, patch(f"{nixl_module}.threading.Thread") as mock_thread,
patch(f"{nixl_module}.get_attn_backend") as mock_get_attn_backend, patch(f"{nixl_module}.get_current_attn_backend") as mock_get_attn_backend,
): ):
# Ensure get_attn_backend returns the correct value due to # Ensure get_attn_backend returns the correct value due to
# _cached_get_attn_backend returning the backend from previous # _cached_get_attn_backend returning the backend from previous
......
...@@ -6,13 +6,14 @@ KV cache helper for store. ...@@ -6,13 +6,14 @@ KV cache helper for store.
from collections.abc import Iterator from collections.abc import Iterator
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Literal from typing import TYPE_CHECKING, Any, Literal, cast
import torch import torch
from vllm.config import get_current_vllm_config from vllm.config import VllmConfig, get_current_vllm_config, get_layers_from_vllm_config
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.v1.attention.backend import AttentionBackend from vllm.v1.attention.backend import AttentionBackend
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
...@@ -433,3 +434,26 @@ class TpKVTopology: ...@@ -433,3 +434,26 @@ class TpKVTopology:
) -> list[int]: ) -> list[int]:
remote_tp_size = self.remote_tp_size[remote_engine_id] remote_tp_size = self.remote_tp_size[remote_engine_id]
return self.get_target_remote_ranks(remote_tp_size) return self.get_target_remote_ranks(remote_tp_size)
def get_current_attn_backend(vllm_config: VllmConfig):
    """Return the attention backend in use for the given vLLM config.

    The backend is taken from the first attention layer that
    ``get_layers_from_vllm_config`` reports for ``vllm_config``. When no
    layers are reported (e.g. in tests, where the static forward context
    is empty), the backend is resolved through the attention selector
    using the model and cache settings from ``vllm_config``.

    Args:
        vllm_config: The vLLM configuration to inspect.

    Returns:
        Whatever the first layer's ``get_attn_backend()`` returns, or the
        result of the selector's ``get_attn_backend(...)`` in the fallback.
    """
    attn_layers = get_layers_from_vllm_config(
        vllm_config, cast(type[Any], AttentionLayerBase), None
    )
    if attn_layers:
        # Only the first layer is consulted; all layers are assumed to
        # share a backend.
        first_layer = next(iter(attn_layers.values()))
        return first_layer.get_attn_backend()

    # Fallback for tests, when static_forward_context is empty.
    logger.debug(
        "No layers found in the vLLM config. "
        "Falling back to default attention backend."
    )
    # Imported lazily so the selector is only loaded on the fallback path.
    from vllm.v1.attention.selector import get_attn_backend

    model_config = vllm_config.model_config
    cache_config = vllm_config.cache_config
    return get_attn_backend(
        head_size=model_config.get_head_size(),
        dtype=model_config.dtype,
        kv_cache_dtype=cache_config.cache_dtype,
        block_size=cache_config.block_size,
        use_mla=model_config.use_mla,
    )
...@@ -16,7 +16,10 @@ import zmq.asyncio ...@@ -16,7 +16,10 @@ import zmq.asyncio
from vllm import envs from vllm import envs
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.utils import TpKVTopology from vllm.distributed.kv_transfer.kv_connector.utils import (
TpKVTopology,
get_current_attn_backend,
)
from vllm.distributed.kv_transfer.kv_connector.v1.base import ( from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorBase_V1,
KVConnectorMetadata, KVConnectorMetadata,
...@@ -32,7 +35,6 @@ from vllm.logger import init_logger ...@@ -32,7 +35,6 @@ from vllm.logger import init_logger
from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket
from vllm.v1.attention.backend import AttentionMetadata from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.utils import get_kv_cache_layout from vllm.v1.attention.backends.utils import get_kv_cache_layout
from vllm.v1.attention.selector import get_attn_backend
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import RequestStatus from vllm.v1.request import RequestStatus
...@@ -468,13 +470,9 @@ class MooncakeConnectorWorker: ...@@ -468,13 +470,9 @@ class MooncakeConnectorWorker:
self.cache_config = vllm_config.cache_config self.cache_config = vllm_config.cache_config
self.use_mla = self.model_config.use_mla self.use_mla = self.model_config.use_mla
backend = get_attn_backend( # Get the attention backend from the first layer
self.model_config.get_head_size(), # NOTE (NickLucche) models with multiple backends are not supported yet
self.model_config.dtype, backend = get_current_attn_backend(vllm_config)
self.cache_config.cache_dtype,
self.block_size,
use_mla=self.use_mla,
)
self.backend_name = backend.get_name() self.backend_name = backend.get_name()
self.kv_cache_layout = get_kv_cache_layout() self.kv_cache_layout = get_kv_cache_layout()
logger.debug("Detected attention backend %s", self.backend_name) logger.debug("Detected attention backend %s", self.backend_name)
......
...@@ -24,6 +24,7 @@ from vllm.config import VllmConfig ...@@ -24,6 +24,7 @@ from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.utils import ( from vllm.distributed.kv_transfer.kv_connector.utils import (
EngineId, EngineId,
TpKVTopology, TpKVTopology,
get_current_attn_backend,
kv_postprocess_blksize_and_layout_on_receive, kv_postprocess_blksize_and_layout_on_receive,
kv_postprocess_blksize_on_receive, kv_postprocess_blksize_on_receive,
kv_postprocess_layout_on_receive, kv_postprocess_layout_on_receive,
...@@ -53,7 +54,6 @@ from vllm.platforms import current_platform ...@@ -53,7 +54,6 @@ from vllm.platforms import current_platform
from vllm.utils.network_utils import make_zmq_path, make_zmq_socket from vllm.utils.network_utils import make_zmq_path, make_zmq_socket
from vllm.v1.attention.backend import AttentionMetadata from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.utils import get_kv_cache_layout from vllm.v1.attention.backends.utils import get_kv_cache_layout
from vllm.v1.attention.selector import get_attn_backend
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.block_table import BlockTable
...@@ -957,13 +957,10 @@ class NixlConnectorWorker: ...@@ -957,13 +957,10 @@ class NixlConnectorWorker:
self.block_window_per_layer: list[int | None] = [] self.block_window_per_layer: list[int | None] = []
self.use_mla = self.model_config.use_mla self.use_mla = self.model_config.use_mla
backend = get_attn_backend( # Get the attention backend from the first layer
self.model_config.get_head_size(), # NOTE (NickLucche) models with multiple backends are not supported yet
self.model_config.dtype, backend = get_current_attn_backend(vllm_config)
self.cache_config.cache_dtype,
self.block_size,
use_mla=self.use_mla,
)
self.backend_name = backend.get_name() self.backend_name = backend.get_name()
self.kv_cache_layout = get_kv_cache_layout() self.kv_cache_layout = get_kv_cache_layout()
self.host_buffer_kv_cache_layout = self.kv_cache_layout self.host_buffer_kv_cache_layout = self.kv_cache_layout
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment