Unverified commit fc1d8be3, authored by Matthew Bonanni, committed by GitHub
Browse files

[Attention] Update attention imports (#29540)


Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
parent cd007a53
...@@ -139,14 +139,13 @@ def test_standard_attention_backend_selection( ...@@ -139,14 +139,13 @@ def test_standard_attention_backend_selection(
import importlib import importlib
import vllm.envs as envs import vllm.envs as envs
from vllm.attention.backends.registry import _Backend
importlib.reload(envs) importlib.reload(envs)
# Convert string backend to enum if provided # Convert string backend to enum if provided
backend_enum = None backend_enum = None
if selected_backend: if selected_backend:
backend_enum = getattr(_Backend, selected_backend) backend_enum = getattr(AttentionBackendEnum, selected_backend)
# Get the backend class path # Get the backend class path
from vllm.platforms.rocm import RocmPlatform from vllm.platforms.rocm import RocmPlatform
...@@ -253,7 +252,6 @@ def test_mla_backend_selection( ...@@ -253,7 +252,6 @@ def test_mla_backend_selection(
import importlib import importlib
import vllm.envs as envs import vllm.envs as envs
from vllm.attention.backends.registry import _Backend
importlib.reload(envs) importlib.reload(envs)
...@@ -269,7 +267,7 @@ def test_mla_backend_selection( ...@@ -269,7 +267,7 @@ def test_mla_backend_selection(
# Convert string backend to enum if provided # Convert string backend to enum if provided
backend_enum = None backend_enum = None
if selected_backend: if selected_backend:
backend_enum = getattr(_Backend, selected_backend) backend_enum = getattr(AttentionBackendEnum, selected_backend)
from vllm.platforms.rocm import RocmPlatform from vllm.platforms.rocm import RocmPlatform
...@@ -301,7 +299,6 @@ def test_mla_backend_selection( ...@@ -301,7 +299,6 @@ def test_mla_backend_selection(
def test_aiter_fa_requires_gfx9(mock_vllm_config): def test_aiter_fa_requires_gfx9(mock_vllm_config):
"""Test that ROCM_AITER_FA requires gfx9 architecture.""" """Test that ROCM_AITER_FA requires gfx9 architecture."""
from vllm.attention.backends.registry import _Backend
from vllm.platforms.rocm import RocmPlatform from vllm.platforms.rocm import RocmPlatform
# Mock on_gfx9 to return False # Mock on_gfx9 to return False
...@@ -313,7 +310,7 @@ def test_aiter_fa_requires_gfx9(mock_vllm_config): ...@@ -313,7 +310,7 @@ def test_aiter_fa_requires_gfx9(mock_vllm_config):
), ),
): ):
RocmPlatform.get_attn_backend_cls( RocmPlatform.get_attn_backend_cls(
selected_backend=_Backend.ROCM_AITER_FA, selected_backend=AttentionBackendEnum.ROCM_AITER_FA,
head_size=128, head_size=128,
dtype=torch.float16, dtype=torch.float16,
kv_cache_dtype="auto", kv_cache_dtype="auto",
......
...@@ -14,6 +14,7 @@ from unittest.mock import patch ...@@ -14,6 +14,7 @@ from unittest.mock import patch
import pytest import pytest
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_connector.v1 import ( from vllm.distributed.kv_transfer.kv_connector.v1 import (
KVConnectorBase_V1, KVConnectorBase_V1,
...@@ -24,7 +25,6 @@ from vllm.v1.core.sched.output import SchedulerOutput ...@@ -24,7 +25,6 @@ from vllm.v1.core.sched.output import SchedulerOutput
from .utils import create_scheduler, create_vllm_config from .utils import create_scheduler, create_vllm_config
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.kv_cache_manager import KVCacheBlocks
...@@ -68,7 +68,7 @@ class OldStyleTestConnector(KVConnectorBase_V1): ...@@ -68,7 +68,7 @@ class OldStyleTestConnector(KVConnectorBase_V1):
self, self,
layer_name: str, layer_name: str,
kv_layer, kv_layer,
attn_metadata: "AttentionMetadata", attn_metadata: AttentionMetadata,
**kwargs, **kwargs,
) -> None: ) -> None:
pass pass
...@@ -119,7 +119,7 @@ class NewStyleTestConnector(KVConnectorBase_V1): ...@@ -119,7 +119,7 @@ class NewStyleTestConnector(KVConnectorBase_V1):
self, self,
layer_name: str, layer_name: str,
kv_layer, kv_layer,
attn_metadata: "AttentionMetadata", attn_metadata: AttentionMetadata,
**kwargs, **kwargs,
) -> None: ) -> None:
pass pass
......
...@@ -6,11 +6,10 @@ from typing import TYPE_CHECKING, ClassVar, Generic, Protocol, TypeVar, get_args ...@@ -6,11 +6,10 @@ from typing import TYPE_CHECKING, ClassVar, Generic, Protocol, TypeVar, get_args
import torch import torch
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config.cache import CacheDType from vllm.config.cache import CacheDType
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
from vllm.platforms.interface import DeviceCapability from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backends.utils import KVCacheLayoutType from vllm.v1.attention.backends.utils import KVCacheLayoutType
...@@ -178,8 +177,6 @@ class AttentionBackend(ABC): ...@@ -178,8 +177,6 @@ class AttentionBackend(ABC):
By default, only supports decoder attention. By default, only supports decoder attention.
Backends should override this to support other attention types. Backends should override this to support other attention types.
""" """
from vllm.attention.backends.abstract import AttentionType
return attn_type == AttentionType.DECODER return attn_type == AttentionType.DECODER
@classmethod @classmethod
...@@ -360,7 +357,7 @@ class AttentionImpl(ABC, Generic[T]): ...@@ -360,7 +357,7 @@ class AttentionImpl(ABC, Generic[T]):
) -> torch.Tensor: ) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError
def fused_output_quant_supported(self, quant_key: QuantKey): def fused_output_quant_supported(self, quant_key: "QuantKey"):
""" """
Does this attention implementation support fused output quantization. Does this attention implementation support fused output quantization.
This is used by the AttnFusionPass to only fuse output quantization This is used by the AttnFusionPass to only fuse output quantization
...@@ -412,7 +409,7 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]): ...@@ -412,7 +409,7 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
qk_rope_head_dim: int, qk_rope_head_dim: int,
qk_head_dim: int, qk_head_dim: int,
v_head_dim: int, v_head_dim: int,
kv_b_proj: ColumnParallelLinear, kv_b_proj: "ColumnParallelLinear",
indexer: object | None = None, indexer: object | None = None,
) -> None: ) -> None:
raise NotImplementedError raise NotImplementedError
......
...@@ -5,6 +5,7 @@ import functools ...@@ -5,6 +5,7 @@ import functools
import torch import torch
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.config.vllm import VllmConfig from vllm.config.vllm import VllmConfig
...@@ -22,8 +23,6 @@ from vllm.v1.kv_cache_interface import ( ...@@ -22,8 +23,6 @@ from vllm.v1.kv_cache_interface import (
KVCacheSpec, KVCacheSpec,
) )
from ..layer import Attention
@functools.lru_cache @functools.lru_cache
def create_chunked_local_attention_backend( def create_chunked_local_attention_backend(
......
...@@ -14,6 +14,7 @@ from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE ...@@ -14,6 +14,7 @@ from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from transformers.configuration_utils import ALLOWED_LAYER_TYPES from transformers.configuration_utils import ALLOWED_LAYER_TYPES
import vllm.envs as envs import vllm.envs as envs
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config.multimodal import MMCacheType, MMEncoderTPMode, MultiModalConfig from vllm.config.multimodal import MMCacheType, MMEncoderTPMode, MultiModalConfig
from vllm.config.pooler import PoolerConfig from vllm.config.pooler import PoolerConfig
from vllm.config.scheduler import RunnerType from vllm.config.scheduler import RunnerType
...@@ -53,7 +54,6 @@ if TYPE_CHECKING: ...@@ -53,7 +54,6 @@ if TYPE_CHECKING:
import vllm.model_executor.layers.quantization as me_quant import vllm.model_executor.layers.quantization as me_quant
import vllm.model_executor.models as me_models import vllm.model_executor.models as me_models
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config.load import LoadConfig from vllm.config.load import LoadConfig
from vllm.config.parallel import ParallelConfig from vllm.config.parallel import ParallelConfig
from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization import QuantizationMethods
...@@ -61,7 +61,6 @@ if TYPE_CHECKING: ...@@ -61,7 +61,6 @@ if TYPE_CHECKING:
else: else:
PretrainedConfig = Any PretrainedConfig = Any
AttentionBackendEnum = Any
me_quant = LazyLoader( me_quant = LazyLoader(
"model_executor", globals(), "vllm.model_executor.layers.quantization" "model_executor", globals(), "vllm.model_executor.layers.quantization"
) )
......
...@@ -2,19 +2,15 @@ ...@@ -2,19 +2,15 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Mapping from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, Literal, TypeAlias from typing import Any, Literal, TypeAlias
from pydantic import ConfigDict, Field, field_validator, model_validator from pydantic import ConfigDict, Field, field_validator, model_validator
from pydantic.dataclasses import dataclass from pydantic.dataclasses import dataclass
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config.utils import config from vllm.config.utils import config
from vllm.utils.hashing import safe_hash from vllm.utils.hashing import safe_hash
if TYPE_CHECKING:
from vllm.attention.backends.registry import AttentionBackendEnum
else:
AttentionBackendEnum = Any
@dataclass @dataclass
class BaseDummyOptions: class BaseDummyOptions:
...@@ -170,9 +166,6 @@ class MultiModalConfig: ...@@ -170,9 +166,6 @@ class MultiModalConfig:
def _validate_mm_encoder_attn_backend( def _validate_mm_encoder_attn_backend(
cls, value: str | AttentionBackendEnum | None cls, value: str | AttentionBackendEnum | None
) -> AttentionBackendEnum | None: ) -> AttentionBackendEnum | None:
# We need to import the real type here (deferred to avoid circular import).
from vllm.attention.backends.registry import AttentionBackendEnum
if isinstance(value, str) and value.upper() == "XFORMERS": if isinstance(value, str) and value.upper() == "XFORMERS":
raise ValueError( raise ValueError(
"Attention backend 'XFORMERS' has been removed (See PR #29262 for " "Attention backend 'XFORMERS' has been removed (See PR #29262 for "
......
...@@ -42,12 +42,12 @@ from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional ...@@ -42,12 +42,12 @@ from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional
import torch import torch
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput from vllm.v1.outputs import KVConnectorOutput
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_events import KVCacheEvent from vllm.distributed.kv_events import KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
...@@ -239,7 +239,7 @@ class KVConnectorBase_V1(ABC): ...@@ -239,7 +239,7 @@ class KVConnectorBase_V1(ABC):
return return
def register_cross_layers_kv_cache( def register_cross_layers_kv_cache(
self, kv_cache: torch.Tensor, attn_backend: type["AttentionBackend"] self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
): ):
""" """
Initialize with a single KV cache tensor used by all layers. Initialize with a single KV cache tensor used by all layers.
......
...@@ -36,6 +36,7 @@ from typing import TYPE_CHECKING, Any, Optional ...@@ -36,6 +36,7 @@ from typing import TYPE_CHECKING, Any, Optional
import torch import torch
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.distributed.kv_transfer.kv_connector.v1 import ( from vllm.distributed.kv_transfer.kv_connector.v1 import (
KVConnectorBase_V1, KVConnectorBase_V1,
KVConnectorRole, KVConnectorRole,
...@@ -45,7 +46,6 @@ from vllm.logger import init_logger ...@@ -45,7 +46,6 @@ from vllm.logger import init_logger
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.kv_cache_manager import KVCacheBlocks
...@@ -117,7 +117,7 @@ class DecodeBenchConnector(KVConnectorBase_V1): ...@@ -117,7 +117,7 @@ class DecodeBenchConnector(KVConnectorBase_V1):
self, self,
layer_name: str, layer_name: str,
kv_layer: torch.Tensor, kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata", attn_metadata: AttentionMetadata,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
# This connector doesn't save KV cache (benchmarking only) # This connector doesn't save KV cache (benchmarking only)
......
...@@ -7,6 +7,7 @@ from lmcache.integration.vllm.vllm_v1_adapter import ( ...@@ -7,6 +7,7 @@ from lmcache.integration.vllm.vllm_v1_adapter import (
LMCacheConnectorV1Impl as LMCacheConnectorLatestImpl, LMCacheConnectorV1Impl as LMCacheConnectorLatestImpl,
) )
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import ( from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorBase_V1,
...@@ -17,7 +18,6 @@ from vllm.logger import init_logger ...@@ -17,7 +18,6 @@ from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
...@@ -91,7 +91,7 @@ class LMCacheConnectorV1(KVConnectorBase_V1): ...@@ -91,7 +91,7 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
self, self,
layer_name: str, layer_name: str,
kv_layer: torch.Tensor, kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata", attn_metadata: AttentionMetadata,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
""" """
......
...@@ -29,6 +29,7 @@ from lmcache.v1.lookup_client.lmcache_async_lookup_client import ( ...@@ -29,6 +29,7 @@ from lmcache.v1.lookup_client.lmcache_async_lookup_client import (
from lmcache.v1.offload_server.zmq_server import ZMQOffloadServer from lmcache.v1.offload_server.zmq_server import ZMQOffloadServer
from lmcache.v1.plugin.plugin_launcher import PluginLauncher from lmcache.v1.plugin.plugin_launcher import PluginLauncher
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import ( from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorBase_V1,
...@@ -50,7 +51,6 @@ from vllm.v1.core.sched.output import SchedulerOutput ...@@ -50,7 +51,6 @@ from vllm.v1.core.sched.output import SchedulerOutput
from vllm.version import __version__ as VLLM_VERSION from vllm.version import __version__ as VLLM_VERSION
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
from vllm.multimodal.inputs import PlaceholderRange from vllm.multimodal.inputs import PlaceholderRange
from vllm.v1.core.kv_cache_manager import KVCacheManager from vllm.v1.core.kv_cache_manager import KVCacheManager
...@@ -915,7 +915,7 @@ class LMCacheConnectorV1Impl: ...@@ -915,7 +915,7 @@ class LMCacheConnectorV1Impl:
self, self,
layer_name: str, layer_name: str,
kv_layer: torch.Tensor, kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata", attn_metadata: AttentionMetadata,
**kwargs, **kwargs,
) -> None: ) -> None:
"""Start saving the a layer of KV cache from vLLM's paged buffer """Start saving the a layer of KV cache from vLLM's paged buffer
......
...@@ -10,6 +10,7 @@ import zmq ...@@ -10,6 +10,7 @@ import zmq
from lmcache.integration.vllm.utils import mla_enabled from lmcache.integration.vllm.utils import mla_enabled
from lmcache.utils import init_logger as lmcache_init_logger from lmcache.utils import init_logger as lmcache_init_logger
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import ( from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorBase_V1,
...@@ -26,7 +27,6 @@ from vllm.v1.outputs import KVConnectorOutput ...@@ -26,7 +27,6 @@ from vllm.v1.outputs import KVConnectorOutput
from vllm.v1.utils import ConstantList from vllm.v1.utils import ConstantList
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_events import KVCacheEvent from vllm.distributed.kv_events import KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
...@@ -490,7 +490,7 @@ class LMCacheMPConnector(KVConnectorBase_V1): ...@@ -490,7 +490,7 @@ class LMCacheMPConnector(KVConnectorBase_V1):
self, self,
layer_name: str, layer_name: str,
kv_layer: torch.Tensor, kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata", attn_metadata: AttentionMetadata,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
""" """
......
...@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any ...@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any
import torch import torch
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.kv_transfer import KVTransferConfig from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
...@@ -27,7 +28,6 @@ from vllm.v1.core.sched.output import SchedulerOutput ...@@ -27,7 +28,6 @@ from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput from vllm.v1.outputs import KVConnectorOutput
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.distributed.kv_events import KVCacheEvent from vllm.distributed.kv_events import KVCacheEvent
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.kv_cache_manager import KVCacheBlocks
...@@ -216,7 +216,7 @@ class MultiConnector(KVConnectorBase_V1): ...@@ -216,7 +216,7 @@ class MultiConnector(KVConnectorBase_V1):
self, self,
layer_name: str, layer_name: str,
kv_layer: torch.Tensor, kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata", attn_metadata: AttentionMetadata,
**kwargs, **kwargs,
) -> None: ) -> None:
for c in self._connectors: for c in self._connectors:
......
...@@ -20,7 +20,7 @@ import torch ...@@ -20,7 +20,7 @@ import torch
import zmq import zmq
from vllm import envs from vllm import envs
from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.selector import get_attn_backend from vllm.attention.selector import get_attn_backend
from vllm.config import VllmConfig from vllm.config import VllmConfig
...@@ -51,7 +51,6 @@ from vllm.v1.core.sched.output import SchedulerOutput ...@@ -51,7 +51,6 @@ from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.block_table import BlockTable
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request from vllm.v1.request import Request
...@@ -308,7 +307,7 @@ class NixlConnector(KVConnectorBase_V1): ...@@ -308,7 +307,7 @@ class NixlConnector(KVConnectorBase_V1):
self, self,
layer_name: str, layer_name: str,
kv_layer: torch.Tensor, kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata", attn_metadata: AttentionMetadata,
**kwargs, **kwargs,
) -> None: ) -> None:
"""NixlConnector does not save explicitly.""" """NixlConnector does not save explicitly."""
......
...@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Optional ...@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Optional
import regex as re import regex as re
import torch import torch
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import ( from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorBase_V1,
...@@ -22,7 +23,6 @@ from vllm.v1.attention.backends.mla.common import MLACommonMetadata ...@@ -22,7 +23,6 @@ from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
...@@ -243,7 +243,7 @@ class P2pNcclConnector(KVConnectorBase_V1): ...@@ -243,7 +243,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
self, self,
layer_name: str, layer_name: str,
kv_layer: torch.Tensor, kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata", attn_metadata: AttentionMetadata,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Start saving the KV cache of the layer from vLLM's paged buffer """Start saving the KV cache of the layer from vLLM's paged buffer
......
...@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Optional ...@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Optional
import safetensors import safetensors
import torch import torch
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import ( from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorBase_V1,
...@@ -19,7 +20,6 @@ from vllm.v1.attention.backends.mla.common import MLACommonMetadata ...@@ -19,7 +20,6 @@ from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
...@@ -211,7 +211,7 @@ class SharedStorageConnector(KVConnectorBase_V1): ...@@ -211,7 +211,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
self, self,
layer_name: str, layer_name: str,
kv_layer: torch.Tensor, kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata", attn_metadata: AttentionMetadata,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Start saving the KV cache of the layer from vLLM's paged buffer """Start saving the KV cache of the layer from vLLM's paged buffer
......
...@@ -5,19 +5,17 @@ import time ...@@ -5,19 +5,17 @@ import time
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, NamedTuple from typing import Any, NamedTuple
import torch import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
from vllm.v1.worker.ubatch_utils import UBatchSlices from vllm.v1.worker.ubatch_utils import UBatchSlices
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
logger = init_logger(__name__) logger = init_logger(__name__)
track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0 track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0
...@@ -195,7 +193,7 @@ class ForwardContext: ...@@ -195,7 +193,7 @@ class ForwardContext:
for each microbatch. for each microbatch.
Set dynamically for each forward pass Set dynamically for each forward pass
""" """
attn_metadata: dict[str, "AttentionMetadata"] | list[dict[str, "AttentionMetadata"]] attn_metadata: dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]]
# TODO: remove after making all virtual_engines share the same kv cache # TODO: remove after making all virtual_engines share the same kv cache
virtual_engine: int # set dynamically for each forward pass virtual_engine: int # set dynamically for each forward pass
# set dynamically for each forward pass # set dynamically for each forward pass
......
...@@ -3,14 +3,11 @@ ...@@ -3,14 +3,11 @@
"""Base class for attention-like layers.""" """Base class for attention-like layers."""
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.v1.kv_cache_interface import KVCacheSpec from vllm.v1.kv_cache_interface import KVCacheSpec
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
class AttentionLayerBase(ABC): class AttentionLayerBase(ABC):
""" """
...@@ -22,7 +19,7 @@ class AttentionLayerBase(ABC): ...@@ -22,7 +19,7 @@ class AttentionLayerBase(ABC):
""" """
@abstractmethod @abstractmethod
def get_attn_backend(self) -> type["AttentionBackend"]: def get_attn_backend(self) -> type[AttentionBackend]:
"""Get the attention backend class for this layer.""" """Get the attention backend class for this layer."""
pass pass
......
...@@ -2,18 +2,15 @@ ...@@ -2,18 +2,15 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Iterable from collections.abc import Iterable
from typing import TYPE_CHECKING
import torch import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.selector import get_mamba_attn_backend from vllm.attention.selector import get_mamba_attn_backend
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.v1.kv_cache_interface import KVCacheSpec, MambaSpec from vllm.v1.kv_cache_interface import KVCacheSpec, MambaSpec
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
class MambaBase(AttentionLayerBase): class MambaBase(AttentionLayerBase):
""" """
...@@ -66,6 +63,6 @@ class MambaBase(AttentionLayerBase): ...@@ -66,6 +63,6 @@ class MambaBase(AttentionLayerBase):
), ),
) )
def get_attn_backend(self) -> type["AttentionBackend"]: def get_attn_backend(self) -> type[AttentionBackend]:
"""Get the attention backend class for this Mamba layer.""" """Get the attention backend class for this Mamba layer."""
return get_mamba_attn_backend(self.mamba_type) return get_mamba_attn_backend(self.mamba_type)
...@@ -18,6 +18,7 @@ from compressed_tensors.quantization import ( ...@@ -18,6 +18,7 @@ from compressed_tensors.quantization import (
from compressed_tensors.transform import TransformConfig from compressed_tensors.transform import TransformConfig
import vllm.envs as envs import vllm.envs as envs
from vllm.attention.layer import Attention
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
...@@ -131,8 +132,6 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -131,8 +132,6 @@ class CompressedTensorsConfig(QuantizationConfig):
layer: torch.nn.Module, layer: torch.nn.Module,
prefix: str, prefix: str,
) -> Optional["QuantizeMethodBase"]: ) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
# collect schemes # collect schemes
quant_scheme = self.get_scheme(layer=layer, layer_name=prefix) quant_scheme = self.get_scheme(layer=layer, layer_name=prefix)
......
...@@ -14,6 +14,7 @@ import vllm.envs as envs ...@@ -14,6 +14,7 @@ import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.layer import Attention
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import ( from vllm.model_executor.layers.batch_invariant import (
...@@ -277,7 +278,6 @@ class Fp8Config(QuantizationConfig): ...@@ -277,7 +278,6 @@ class Fp8Config(QuantizationConfig):
def get_xpu_quant_method( def get_xpu_quant_method(
self, layer: torch.nn.Module, prefix: str self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]: ) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention
from vllm.model_executor.layers.quantization.ipex_quant import ( from vllm.model_executor.layers.quantization.ipex_quant import (
XPUFp8LinearMethod, XPUFp8LinearMethod,
XPUFp8MoEMethod, XPUFp8MoEMethod,
...@@ -307,8 +307,6 @@ class Fp8Config(QuantizationConfig): ...@@ -307,8 +307,6 @@ class Fp8Config(QuantizationConfig):
def get_quant_method( def get_quant_method(
self, layer: torch.nn.Module, prefix: str self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]: ) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
if current_platform.is_xpu(): if current_platform.is_xpu():
return self.get_xpu_quant_method(layer, prefix) return self.get_xpu_quant_method(layer, prefix)
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment