Unverified Commit fc1d8be3 authored by Matthew Bonanni's avatar Matthew Bonanni Committed by GitHub
Browse files

[Attention] Update attention imports (#29540)


Signed-off-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
parent cd007a53
......@@ -139,14 +139,13 @@ def test_standard_attention_backend_selection(
import importlib
import vllm.envs as envs
from vllm.attention.backends.registry import _Backend
importlib.reload(envs)
# Convert string backend to enum if provided
backend_enum = None
if selected_backend:
backend_enum = getattr(_Backend, selected_backend)
backend_enum = getattr(AttentionBackendEnum, selected_backend)
# Get the backend class path
from vllm.platforms.rocm import RocmPlatform
......@@ -253,7 +252,6 @@ def test_mla_backend_selection(
import importlib
import vllm.envs as envs
from vllm.attention.backends.registry import _Backend
importlib.reload(envs)
......@@ -269,7 +267,7 @@ def test_mla_backend_selection(
# Convert string backend to enum if provided
backend_enum = None
if selected_backend:
backend_enum = getattr(_Backend, selected_backend)
backend_enum = getattr(AttentionBackendEnum, selected_backend)
from vllm.platforms.rocm import RocmPlatform
......@@ -301,7 +299,6 @@ def test_mla_backend_selection(
def test_aiter_fa_requires_gfx9(mock_vllm_config):
"""Test that ROCM_AITER_FA requires gfx9 architecture."""
from vllm.attention.backends.registry import _Backend
from vllm.platforms.rocm import RocmPlatform
# Mock on_gfx9 to return False
......@@ -313,7 +310,7 @@ def test_aiter_fa_requires_gfx9(mock_vllm_config):
),
):
RocmPlatform.get_attn_backend_cls(
selected_backend=_Backend.ROCM_AITER_FA,
selected_backend=AttentionBackendEnum.ROCM_AITER_FA,
head_size=128,
dtype=torch.float16,
kv_cache_dtype="auto",
......
......@@ -14,6 +14,7 @@ from unittest.mock import patch
import pytest
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_connector.v1 import (
KVConnectorBase_V1,
......@@ -24,7 +25,6 @@ from vllm.v1.core.sched.output import SchedulerOutput
from .utils import create_scheduler, create_vllm_config
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
......@@ -68,7 +68,7 @@ class OldStyleTestConnector(KVConnectorBase_V1):
self,
layer_name: str,
kv_layer,
attn_metadata: "AttentionMetadata",
attn_metadata: AttentionMetadata,
**kwargs,
) -> None:
pass
......@@ -119,7 +119,7 @@ class NewStyleTestConnector(KVConnectorBase_V1):
self,
layer_name: str,
kv_layer,
attn_metadata: "AttentionMetadata",
attn_metadata: AttentionMetadata,
**kwargs,
) -> None:
pass
......
......@@ -6,11 +6,10 @@ from typing import TYPE_CHECKING, ClassVar, Generic, Protocol, TypeVar, get_args
import torch
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
if TYPE_CHECKING:
from vllm.config.cache import CacheDType
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backends.utils import KVCacheLayoutType
......@@ -178,8 +177,6 @@ class AttentionBackend(ABC):
By default, only supports decoder attention.
Backends should override this to support other attention types.
"""
from vllm.attention.backends.abstract import AttentionType
return attn_type == AttentionType.DECODER
@classmethod
......@@ -360,7 +357,7 @@ class AttentionImpl(ABC, Generic[T]):
) -> torch.Tensor:
raise NotImplementedError
def fused_output_quant_supported(self, quant_key: QuantKey):
def fused_output_quant_supported(self, quant_key: "QuantKey"):
"""
Does this attention implementation support fused output quantization.
This is used by the AttnFusionPass to only fuse output quantization
......@@ -412,7 +409,7 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
qk_rope_head_dim: int,
qk_head_dim: int,
v_head_dim: int,
kv_b_proj: ColumnParallelLinear,
kv_b_proj: "ColumnParallelLinear",
indexer: object | None = None,
) -> None:
raise NotImplementedError
......
......@@ -5,6 +5,7 @@ import functools
import torch
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig
from vllm.config.vllm import VllmConfig
......@@ -22,8 +23,6 @@ from vllm.v1.kv_cache_interface import (
KVCacheSpec,
)
from ..layer import Attention
@functools.lru_cache
def create_chunked_local_attention_backend(
......
......@@ -14,6 +14,7 @@ from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from transformers.configuration_utils import ALLOWED_LAYER_TYPES
import vllm.envs as envs
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config.multimodal import MMCacheType, MMEncoderTPMode, MultiModalConfig
from vllm.config.pooler import PoolerConfig
from vllm.config.scheduler import RunnerType
......@@ -53,7 +54,6 @@ if TYPE_CHECKING:
import vllm.model_executor.layers.quantization as me_quant
import vllm.model_executor.models as me_models
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config.load import LoadConfig
from vllm.config.parallel import ParallelConfig
from vllm.model_executor.layers.quantization import QuantizationMethods
......@@ -61,7 +61,6 @@ if TYPE_CHECKING:
else:
PretrainedConfig = Any
AttentionBackendEnum = Any
me_quant = LazyLoader(
"model_executor", globals(), "vllm.model_executor.layers.quantization"
)
......
......@@ -2,19 +2,15 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, Literal, TypeAlias
from typing import Any, Literal, TypeAlias
from pydantic import ConfigDict, Field, field_validator, model_validator
from pydantic.dataclasses import dataclass
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config.utils import config
from vllm.utils.hashing import safe_hash
if TYPE_CHECKING:
from vllm.attention.backends.registry import AttentionBackendEnum
else:
AttentionBackendEnum = Any
@dataclass
class BaseDummyOptions:
......@@ -170,9 +166,6 @@ class MultiModalConfig:
def _validate_mm_encoder_attn_backend(
cls, value: str | AttentionBackendEnum | None
) -> AttentionBackendEnum | None:
# We need to import the real type here (deferred to avoid circular import).
from vllm.attention.backends.registry import AttentionBackendEnum
if isinstance(value, str) and value.upper() == "XFORMERS":
raise ValueError(
"Attention backend 'XFORMERS' has been removed (See PR #29262 for "
......
......@@ -42,12 +42,12 @@ from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional
import torch
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_events import KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
......@@ -239,7 +239,7 @@ class KVConnectorBase_V1(ABC):
return
def register_cross_layers_kv_cache(
self, kv_cache: torch.Tensor, attn_backend: type["AttentionBackend"]
self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
):
"""
Initialize with a single KV cache tensor used by all layers.
......
......@@ -36,6 +36,7 @@ from typing import TYPE_CHECKING, Any, Optional
import torch
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.distributed.kv_transfer.kv_connector.v1 import (
KVConnectorBase_V1,
KVConnectorRole,
......@@ -45,7 +46,6 @@ from vllm.logger import init_logger
from vllm.utils.math_utils import cdiv
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
......@@ -117,7 +117,7 @@ class DecodeBenchConnector(KVConnectorBase_V1):
self,
layer_name: str,
kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata",
attn_metadata: AttentionMetadata,
**kwargs: Any,
) -> None:
# This connector doesn't save KV cache (benchmarking only)
......
......@@ -7,6 +7,7 @@ from lmcache.integration.vllm.vllm_v1_adapter import (
LMCacheConnectorV1Impl as LMCacheConnectorLatestImpl,
)
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
......@@ -17,7 +18,6 @@ from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
......@@ -91,7 +91,7 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
self,
layer_name: str,
kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata",
attn_metadata: AttentionMetadata,
**kwargs: Any,
) -> None:
"""
......
......@@ -29,6 +29,7 @@ from lmcache.v1.lookup_client.lmcache_async_lookup_client import (
from lmcache.v1.offload_server.zmq_server import ZMQOffloadServer
from lmcache.v1.plugin.plugin_launcher import PluginLauncher
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
......@@ -50,7 +51,6 @@ from vllm.v1.core.sched.output import SchedulerOutput
from vllm.version import __version__ as VLLM_VERSION
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.multimodal.inputs import PlaceholderRange
from vllm.v1.core.kv_cache_manager import KVCacheManager
......@@ -915,7 +915,7 @@ class LMCacheConnectorV1Impl:
self,
layer_name: str,
kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata",
attn_metadata: AttentionMetadata,
**kwargs,
) -> None:
"""Start saving the a layer of KV cache from vLLM's paged buffer
......
......@@ -10,6 +10,7 @@ import zmq
from lmcache.integration.vllm.utils import mla_enabled
from lmcache.utils import init_logger as lmcache_init_logger
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
......@@ -26,7 +27,6 @@ from vllm.v1.outputs import KVConnectorOutput
from vllm.v1.utils import ConstantList
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_events import KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
......@@ -490,7 +490,7 @@ class LMCacheMPConnector(KVConnectorBase_V1):
self,
layer_name: str,
kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata",
attn_metadata: AttentionMetadata,
**kwargs: Any,
) -> None:
"""
......
......@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any
import torch
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
......@@ -27,7 +28,6 @@ from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.distributed.kv_events import KVCacheEvent
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
......@@ -216,7 +216,7 @@ class MultiConnector(KVConnectorBase_V1):
self,
layer_name: str,
kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata",
attn_metadata: AttentionMetadata,
**kwargs,
) -> None:
for c in self._connectors:
......
......@@ -20,7 +20,7 @@ import torch
import zmq
from vllm import envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.selector import get_attn_backend
from vllm.config import VllmConfig
......@@ -51,7 +51,6 @@ from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.block_table import BlockTable
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
......@@ -308,7 +307,7 @@ class NixlConnector(KVConnectorBase_V1):
self,
layer_name: str,
kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata",
attn_metadata: AttentionMetadata,
**kwargs,
) -> None:
"""NixlConnector does not save explicitly."""
......
......@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Optional
import regex as re
import torch
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
......@@ -22,7 +23,6 @@ from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
......@@ -243,7 +243,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
self,
layer_name: str,
kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata",
attn_metadata: AttentionMetadata,
**kwargs: Any,
) -> None:
"""Start saving the KV cache of the layer from vLLM's paged buffer
......
......@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Optional
import safetensors
import torch
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
......@@ -19,7 +20,6 @@ from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
......@@ -211,7 +211,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
self,
layer_name: str,
kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata",
attn_metadata: AttentionMetadata,
**kwargs: Any,
) -> None:
"""Start saving the KV cache of the layer from vLLM's paged buffer
......
......@@ -5,19 +5,17 @@ import time
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, NamedTuple
from typing import Any, NamedTuple
import torch
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
from vllm.v1.worker.ubatch_utils import UBatchSlices
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
logger = init_logger(__name__)
track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0
......@@ -195,7 +193,7 @@ class ForwardContext:
for each microbatch.
Set dynamically for each forward pass
"""
attn_metadata: dict[str, "AttentionMetadata"] | list[dict[str, "AttentionMetadata"]]
attn_metadata: dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]]
# TODO: remove after making all virtual_engines share the same kv cache
virtual_engine: int # set dynamically for each forward pass
# set dynamically for each forward pass
......
......@@ -3,14 +3,11 @@
"""Base class for attention-like layers."""
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig
from vllm.v1.kv_cache_interface import KVCacheSpec
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
class AttentionLayerBase(ABC):
"""
......@@ -22,7 +19,7 @@ class AttentionLayerBase(ABC):
"""
@abstractmethod
def get_attn_backend(self) -> type["AttentionBackend"]:
def get_attn_backend(self) -> type[AttentionBackend]:
"""Get the attention backend class for this layer."""
pass
......
......@@ -2,18 +2,15 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import abstractmethod
from collections.abc import Iterable
from typing import TYPE_CHECKING
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.selector import get_mamba_attn_backend
from vllm.config import VllmConfig
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.v1.kv_cache_interface import KVCacheSpec, MambaSpec
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
class MambaBase(AttentionLayerBase):
"""
......@@ -66,6 +63,6 @@ class MambaBase(AttentionLayerBase):
),
)
def get_attn_backend(self) -> type["AttentionBackend"]:
def get_attn_backend(self) -> type[AttentionBackend]:
"""Get the attention backend class for this Mamba layer."""
return get_mamba_attn_backend(self.mamba_type)
......@@ -18,6 +18,7 @@ from compressed_tensors.quantization import (
from compressed_tensors.transform import TransformConfig
import vllm.envs as envs
from vllm.attention.layer import Attention
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (
......@@ -131,8 +132,6 @@ class CompressedTensorsConfig(QuantizationConfig):
layer: torch.nn.Module,
prefix: str,
) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
if isinstance(layer, LinearBase):
# collect schemes
quant_scheme = self.get_scheme(layer=layer, layer_name=prefix)
......
......@@ -14,6 +14,7 @@ import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.layer import Attention
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
......@@ -277,7 +278,6 @@ class Fp8Config(QuantizationConfig):
def get_xpu_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention
from vllm.model_executor.layers.quantization.ipex_quant import (
XPUFp8LinearMethod,
XPUFp8MoEMethod,
......@@ -307,8 +307,6 @@ class Fp8Config(QuantizationConfig):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
if current_platform.is_xpu():
return self.get_xpu_quant_method(layer, prefix)
if isinstance(layer, LinearBase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment