"docker/Dockerfile.cpu" did not exist on "ca77dd7a44f2bc103c668560818918ac0335835a"
Unverified commit fc1d8be3, authored by Matthew Bonanni, committed by GitHub
Browse files

[Attention] Update attention imports (#29540)


Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
parent cd007a53
......@@ -12,6 +12,7 @@ from torch.nn.parameter import Parameter
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.attention.layer import Attention
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
......@@ -149,8 +150,6 @@ class ModelOptQuantConfigBase(QuantizationConfig):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
# handle kv-cache first so we can focus only on weight quantization thereafter
if isinstance(layer, Attention):
return self.KVCacheMethodCls(self)
......
......@@ -8,6 +8,7 @@ import torch
from torch.nn.parameter import Parameter
from vllm import envs
from vllm.attention.layer import Attention
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (
......@@ -184,8 +185,6 @@ class Mxfp4Config(QuantizationConfig):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
if isinstance(layer, LinearBase):
if self.ignored_layers and is_layer_skipped(
prefix=prefix,
......
......@@ -8,6 +8,7 @@ import regex as re
import torch
from torch.nn.parameter import Parameter
from vllm.attention.layer import Attention
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (
LinearBase,
......@@ -159,8 +160,6 @@ class PetitNvFp4Config(QuantizationConfig):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
exclude = self.require_exclude_modules()
if isinstance(layer, LinearBase):
......
......@@ -7,6 +7,7 @@ import torch
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.attention.layer import Attention
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
from vllm.model_executor.layers.quantization import QuantizationMethods
......@@ -65,8 +66,6 @@ class PTPCFp8Config(Fp8Config):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
if isinstance(layer, LinearBase):
if is_layer_skipped(prefix, self.ignored_layers):
return UnquantizedLinearMethod()
......
......@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any, Optional, cast
import torch
from vllm.attention.layer import Attention
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (
......@@ -102,8 +103,6 @@ class QuarkConfig(QuantizationConfig):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
# Check if the layer is skipped for quantization.
exclude_layers = cast(list[str], self.quant_config.get("exclude"))
if should_ignore_layer(
......
......@@ -14,6 +14,7 @@ import regex as re
import torch
from vllm import envs
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.logger import init_logger
from .interface import CpuArchEnum, Platform, PlatformEnum
......@@ -21,10 +22,8 @@ from .interface import CpuArchEnum, Platform, PlatformEnum
logger = init_logger(__name__)
if TYPE_CHECKING:
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig
else:
AttentionBackendEnum = None
VllmConfig = None
......@@ -135,8 +134,6 @@ class CpuPlatform(Platform):
use_sparse: bool,
attn_type: str | None = None,
) -> str:
from vllm.attention.backends.registry import AttentionBackendEnum
if selected_backend and selected_backend != AttentionBackendEnum.CPU_ATTN:
logger.info("Cannot use %s backend on CPU.", selected_backend)
if use_mla:
......
......@@ -15,6 +15,8 @@ from typing_extensions import ParamSpec
# import custom ops, trigger op registration
import vllm._C # noqa
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.logger import init_logger
from vllm.utils.import_utils import import_pynvml
from vllm.utils.torch_utils import cuda_device_count_stateless
......@@ -22,11 +24,9 @@ from vllm.utils.torch_utils import cuda_device_count_stateless
from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING:
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
else:
AttentionBackendEnum = None
VllmConfig = None
CacheDType = None
......@@ -48,8 +48,6 @@ def _get_backend_priorities(
device_capability: DeviceCapability,
) -> list[AttentionBackendEnum]:
"""Get backend priorities with lazy import to avoid circular dependency."""
from vllm.attention.backends.registry import AttentionBackendEnum
if use_mla:
if device_capability.major == 10:
return [
......@@ -265,8 +263,6 @@ class CudaPlatformBase(Platform):
def get_vit_attn_backend(
cls, head_size: int, dtype: torch.dtype
) -> "AttentionBackendEnum":
from vllm.attention.backends.registry import AttentionBackendEnum
# Try FlashAttention first
try:
backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
......@@ -335,8 +331,6 @@ class CudaPlatformBase(Platform):
use_sparse: bool,
attn_type: str | None = None,
) -> str:
from vllm.attention.backends.abstract import AttentionType
if attn_type is None:
attn_type = AttentionType.DECODER
......
......@@ -12,12 +12,12 @@ from typing import TYPE_CHECKING, Any, NamedTuple
import numpy as np
import torch
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.logger import init_logger
if TYPE_CHECKING:
from torch.distributed import PrefixStore, ProcessGroup
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.inputs import ProcessorInputs, PromptType
......@@ -226,9 +226,6 @@ class Platform:
def get_vit_attn_backend(
cls, head_size: int, dtype: torch.dtype
) -> "AttentionBackendEnum":
# Import AttentionBackendEnum here to avoid circular import.
from vllm.attention.backends.registry import AttentionBackendEnum
return AttentionBackendEnum.TORCH_SDPA
@classmethod
......
......@@ -8,16 +8,14 @@ from typing import TYPE_CHECKING
import torch
import vllm.envs as envs
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.logger import init_logger
from vllm.utils.torch_utils import cuda_device_count_stateless
from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING:
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig
else:
AttentionBackendEnum = None
logger = init_logger(__name__)
......@@ -196,7 +194,6 @@ class RocmPlatform(Platform):
from importlib.util import find_spec
from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.registry import AttentionBackendEnum
if rocm_aiter_ops.is_mha_enabled():
# Note: AITER FA is only supported for Qwen-VL models.
......@@ -222,7 +219,6 @@ class RocmPlatform(Platform):
attn_type: str | None = None,
) -> str:
from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.registry import AttentionBackendEnum
if use_sparse:
if kv_cache_dtype.startswith("fp8"):
......
......@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, cast
import torch
from tpu_info import device
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.inputs import ProcessorInputs, PromptType
from vllm.logger import init_logger
......@@ -15,7 +16,6 @@ from .interface import Platform, PlatformEnum
if TYPE_CHECKING:
from typing import TypeAlias
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig
from vllm.config.cache import BlockSize
from vllm.pooling_params import PoolingParams
......@@ -26,7 +26,6 @@ else:
BlockSize = None
VllmConfig = None
PoolingParams = None
AttentionBackendEnum = None
ParamsType = None
logger = init_logger(__name__)
......@@ -67,8 +66,6 @@ class TpuPlatform(Platform):
use_sparse,
attn_type: str | None = None,
) -> str:
from vllm.attention.backends.registry import AttentionBackendEnum
if use_sparse:
raise NotImplementedError("Sparse Attention is not supported on TPU.")
if selected_backend != AttentionBackendEnum.PALLAS:
......
......@@ -8,16 +8,15 @@ from typing import TYPE_CHECKING
import torch
import vllm.envs as envs
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.logger import init_logger
from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING:
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig
else:
VllmConfig = None
AttentionBackendEnum = None
logger = init_logger(__name__)
......@@ -60,8 +59,6 @@ class XPUPlatform(Platform):
"only NHD layout is supported by XPU attention kernels."
)
from vllm.attention.backends.registry import AttentionBackendEnum
if use_sparse:
raise NotImplementedError("Sparse Attention is not supported on XPU.")
if selected_backend == AttentionBackendEnum.TRITON_ATTN:
......@@ -116,8 +113,6 @@ class XPUPlatform(Platform):
def get_vit_attn_backend(
cls, head_size: int, dtype: torch.dtype
) -> "AttentionBackendEnum":
from vllm.attention.backends.registry import AttentionBackendEnum
return AttentionBackendEnum.FLASH_ATTN
@classmethod
......
......@@ -51,8 +51,6 @@ class CPUAttentionBackend(AttentionBackend):
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
"""CPU attention supports decoder and encoder-only attention."""
from vllm.attention.backends.abstract import AttentionType
return attn_type in (
AttentionType.DECODER,
AttentionType.ENCODER,
......
......@@ -84,8 +84,6 @@ class FlashAttentionBackend(AttentionBackend):
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
"""FlashAttention supports all attention types."""
from vllm.attention.backends.abstract import AttentionType
return attn_type in (
AttentionType.DECODER,
AttentionType.ENCODER,
......
......@@ -87,8 +87,6 @@ class FlexAttentionBackend(AttentionBackend):
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
"""FlexAttention supports both decoder and encoder-only attention."""
from vllm.attention.backends.abstract import AttentionType
return attn_type in (AttentionType.DECODER, AttentionType.ENCODER_ONLY)
@staticmethod
......
......@@ -24,12 +24,15 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.utils.math_utils import cdiv
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionImpl
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionImpl,
AttentionMetadata,
)
from vllm.distributed.kv_transfer.kv_connector.utils import (
get_kv_connector_cache_layout,
)
......
......@@ -6,12 +6,12 @@ from typing import TYPE_CHECKING
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger
from vllm.v1.kv_offload.abstract import LoadStoreSpec, OffloadingManager
from vllm.v1.kv_offload.worker.worker import OffloadingHandler
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig
logger = init_logger(__name__)
......@@ -51,7 +51,7 @@ class OffloadingSpec(ABC):
def get_handlers(
self,
kv_caches: dict[str, torch.Tensor],
attn_backends: dict[str, type["AttentionBackend"]],
attn_backends: dict[str, type[AttentionBackend]],
) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]:
"""
Get offloading handlers along with their respective src and dst types.
......
......@@ -8,6 +8,7 @@ import numpy as np
import torch
import torch.nn as nn
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import (
CompilationMode,
CUDAGraphMode,
......@@ -157,8 +158,6 @@ class EagleProposer:
)
# Determine allowed attention backends once during initialization.
from vllm.attention.backends.registry import AttentionBackendEnum
self.allowed_attn_types: tuple | None = None
if current_platform.is_rocm():
rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata]
......
......@@ -2,11 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import defaultdict
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.layer import Attention
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.utils import extract_layer_index
......@@ -17,9 +17,6 @@ from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec
if TYPE_CHECKING:
from vllm.attention.layer import Attention
class MultiModalBudget:
"""Helper class to calculate budget information for multi-modal models."""
......@@ -278,7 +275,7 @@ def add_kv_sharing_layers_to_kv_cache_groups(
def bind_kv_cache(
kv_caches: dict[str, torch.Tensor],
forward_context: dict[str, "Attention"],
forward_context: dict[str, Attention],
runner_kv_caches: list[torch.Tensor],
num_attn_module: int = 1,
) -> None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment