Unverified commit fc1d8be3, authored by Matthew Bonanni, committed by GitHub
Browse files

[Attention] Update attention imports (#29540)


Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
parent cd007a53
...@@ -12,6 +12,7 @@ from torch.nn.parameter import Parameter ...@@ -12,6 +12,7 @@ from torch.nn.parameter import Parameter
import vllm.envs as envs import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.attention.layer import Attention
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, FusedMoEQuantConfig,
...@@ -149,8 +150,6 @@ class ModelOptQuantConfigBase(QuantizationConfig): ...@@ -149,8 +150,6 @@ class ModelOptQuantConfigBase(QuantizationConfig):
def get_quant_method( def get_quant_method(
self, layer: torch.nn.Module, prefix: str self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]: ) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
# handle kv-cache first so we can focus only on weight quantization thereafter # handle kv-cache first so we can focus only on weight quantization thereafter
if isinstance(layer, Attention): if isinstance(layer, Attention):
return self.KVCacheMethodCls(self) return self.KVCacheMethodCls(self)
......
...@@ -8,6 +8,7 @@ import torch ...@@ -8,6 +8,7 @@ import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm import envs from vllm import envs
from vllm.attention.layer import Attention
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import ( from vllm.model_executor.layers.fused_moe import (
...@@ -184,8 +185,6 @@ class Mxfp4Config(QuantizationConfig): ...@@ -184,8 +185,6 @@ class Mxfp4Config(QuantizationConfig):
def get_quant_method( def get_quant_method(
self, layer: torch.nn.Module, prefix: str self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]: ) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
if self.ignored_layers and is_layer_skipped( if self.ignored_layers and is_layer_skipped(
prefix=prefix, prefix=prefix,
......
...@@ -8,6 +8,7 @@ import regex as re ...@@ -8,6 +8,7 @@ import regex as re
import torch import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm.attention.layer import Attention
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
LinearBase, LinearBase,
...@@ -159,8 +160,6 @@ class PetitNvFp4Config(QuantizationConfig): ...@@ -159,8 +160,6 @@ class PetitNvFp4Config(QuantizationConfig):
def get_quant_method( def get_quant_method(
self, layer: torch.nn.Module, prefix: str self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]: ) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
exclude = self.require_exclude_modules() exclude = self.require_exclude_modules()
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
......
...@@ -7,6 +7,7 @@ import torch ...@@ -7,6 +7,7 @@ import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.attention.layer import Attention
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization import QuantizationMethods
...@@ -65,8 +66,6 @@ class PTPCFp8Config(Fp8Config): ...@@ -65,8 +66,6 @@ class PTPCFp8Config(Fp8Config):
def get_quant_method( def get_quant_method(
self, layer: torch.nn.Module, prefix: str self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]: ) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
if is_layer_skipped(prefix, self.ignored_layers): if is_layer_skipped(prefix, self.ignored_layers):
return UnquantizedLinearMethod() return UnquantizedLinearMethod()
......
...@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any, Optional, cast ...@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any, Optional, cast
import torch import torch
from vllm.attention.layer import Attention
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
...@@ -102,8 +103,6 @@ class QuarkConfig(QuantizationConfig): ...@@ -102,8 +103,6 @@ class QuarkConfig(QuantizationConfig):
def get_quant_method( def get_quant_method(
self, layer: torch.nn.Module, prefix: str self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]: ) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
# Check if the layer is skipped for quantization. # Check if the layer is skipped for quantization.
exclude_layers = cast(list[str], self.quant_config.get("exclude")) exclude_layers = cast(list[str], self.quant_config.get("exclude"))
if should_ignore_layer( if should_ignore_layer(
......
...@@ -14,6 +14,7 @@ import regex as re ...@@ -14,6 +14,7 @@ import regex as re
import torch import torch
from vllm import envs from vllm import envs
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.logger import init_logger from vllm.logger import init_logger
from .interface import CpuArchEnum, Platform, PlatformEnum from .interface import CpuArchEnum, Platform, PlatformEnum
...@@ -21,10 +22,8 @@ from .interface import CpuArchEnum, Platform, PlatformEnum ...@@ -21,10 +22,8 @@ from .interface import CpuArchEnum, Platform, PlatformEnum
logger = init_logger(__name__) logger = init_logger(__name__)
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig from vllm.config import VllmConfig
else: else:
AttentionBackendEnum = None
VllmConfig = None VllmConfig = None
...@@ -135,8 +134,6 @@ class CpuPlatform(Platform): ...@@ -135,8 +134,6 @@ class CpuPlatform(Platform):
use_sparse: bool, use_sparse: bool,
attn_type: str | None = None, attn_type: str | None = None,
) -> str: ) -> str:
from vllm.attention.backends.registry import AttentionBackendEnum
if selected_backend and selected_backend != AttentionBackendEnum.CPU_ATTN: if selected_backend and selected_backend != AttentionBackendEnum.CPU_ATTN:
logger.info("Cannot use %s backend on CPU.", selected_backend) logger.info("Cannot use %s backend on CPU.", selected_backend)
if use_mla: if use_mla:
......
...@@ -15,6 +15,8 @@ from typing_extensions import ParamSpec ...@@ -15,6 +15,8 @@ from typing_extensions import ParamSpec
# import custom ops, trigger op registration # import custom ops, trigger op registration
import vllm._C # noqa import vllm._C # noqa
import vllm.envs as envs import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils.import_utils import import_pynvml from vllm.utils.import_utils import import_pynvml
from vllm.utils.torch_utils import cuda_device_count_stateless from vllm.utils.torch_utils import cuda_device_count_stateless
...@@ -22,11 +24,9 @@ from vllm.utils.torch_utils import cuda_device_count_stateless ...@@ -22,11 +24,9 @@ from vllm.utils.torch_utils import cuda_device_count_stateless
from .interface import DeviceCapability, Platform, PlatformEnum from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.cache import CacheDType from vllm.config.cache import CacheDType
else: else:
AttentionBackendEnum = None
VllmConfig = None VllmConfig = None
CacheDType = None CacheDType = None
...@@ -48,8 +48,6 @@ def _get_backend_priorities( ...@@ -48,8 +48,6 @@ def _get_backend_priorities(
device_capability: DeviceCapability, device_capability: DeviceCapability,
) -> list[AttentionBackendEnum]: ) -> list[AttentionBackendEnum]:
"""Get backend priorities with lazy import to avoid circular dependency.""" """Get backend priorities with lazy import to avoid circular dependency."""
from vllm.attention.backends.registry import AttentionBackendEnum
if use_mla: if use_mla:
if device_capability.major == 10: if device_capability.major == 10:
return [ return [
...@@ -265,8 +263,6 @@ class CudaPlatformBase(Platform): ...@@ -265,8 +263,6 @@ class CudaPlatformBase(Platform):
def get_vit_attn_backend( def get_vit_attn_backend(
cls, head_size: int, dtype: torch.dtype cls, head_size: int, dtype: torch.dtype
) -> "AttentionBackendEnum": ) -> "AttentionBackendEnum":
from vllm.attention.backends.registry import AttentionBackendEnum
# Try FlashAttention first # Try FlashAttention first
try: try:
backend_class = AttentionBackendEnum.FLASH_ATTN.get_class() backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
...@@ -335,8 +331,6 @@ class CudaPlatformBase(Platform): ...@@ -335,8 +331,6 @@ class CudaPlatformBase(Platform):
use_sparse: bool, use_sparse: bool,
attn_type: str | None = None, attn_type: str | None = None,
) -> str: ) -> str:
from vllm.attention.backends.abstract import AttentionType
if attn_type is None: if attn_type is None:
attn_type = AttentionType.DECODER attn_type = AttentionType.DECODER
......
...@@ -12,12 +12,12 @@ from typing import TYPE_CHECKING, Any, NamedTuple ...@@ -12,12 +12,12 @@ from typing import TYPE_CHECKING, Any, NamedTuple
import numpy as np import numpy as np
import torch import torch
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.logger import init_logger from vllm.logger import init_logger
if TYPE_CHECKING: if TYPE_CHECKING:
from torch.distributed import PrefixStore, ProcessGroup from torch.distributed import PrefixStore, ProcessGroup
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.cache import CacheDType from vllm.config.cache import CacheDType
from vllm.inputs import ProcessorInputs, PromptType from vllm.inputs import ProcessorInputs, PromptType
...@@ -226,9 +226,6 @@ class Platform: ...@@ -226,9 +226,6 @@ class Platform:
def get_vit_attn_backend( def get_vit_attn_backend(
cls, head_size: int, dtype: torch.dtype cls, head_size: int, dtype: torch.dtype
) -> "AttentionBackendEnum": ) -> "AttentionBackendEnum":
# Import AttentionBackendEnum here to avoid circular import.
from vllm.attention.backends.registry import AttentionBackendEnum
return AttentionBackendEnum.TORCH_SDPA return AttentionBackendEnum.TORCH_SDPA
@classmethod @classmethod
......
...@@ -8,16 +8,14 @@ from typing import TYPE_CHECKING ...@@ -8,16 +8,14 @@ from typing import TYPE_CHECKING
import torch import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils.torch_utils import cuda_device_count_stateless from vllm.utils.torch_utils import cuda_device_count_stateless
from .interface import DeviceCapability, Platform, PlatformEnum from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig from vllm.config import VllmConfig
else:
AttentionBackendEnum = None
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -196,7 +194,6 @@ class RocmPlatform(Platform): ...@@ -196,7 +194,6 @@ class RocmPlatform(Platform):
from importlib.util import find_spec from importlib.util import find_spec
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.registry import AttentionBackendEnum
if rocm_aiter_ops.is_mha_enabled(): if rocm_aiter_ops.is_mha_enabled():
# Note: AITER FA is only supported for Qwen-VL models. # Note: AITER FA is only supported for Qwen-VL models.
...@@ -222,7 +219,6 @@ class RocmPlatform(Platform): ...@@ -222,7 +219,6 @@ class RocmPlatform(Platform):
attn_type: str | None = None, attn_type: str | None = None,
) -> str: ) -> str:
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.registry import AttentionBackendEnum
if use_sparse: if use_sparse:
if kv_cache_dtype.startswith("fp8"): if kv_cache_dtype.startswith("fp8"):
......
...@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, cast ...@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, cast
import torch import torch
from tpu_info import device from tpu_info import device
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.inputs import ProcessorInputs, PromptType from vllm.inputs import ProcessorInputs, PromptType
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -15,7 +16,6 @@ from .interface import Platform, PlatformEnum ...@@ -15,7 +16,6 @@ from .interface import Platform, PlatformEnum
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import TypeAlias from typing import TypeAlias
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.cache import BlockSize from vllm.config.cache import BlockSize
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
...@@ -26,7 +26,6 @@ else: ...@@ -26,7 +26,6 @@ else:
BlockSize = None BlockSize = None
VllmConfig = None VllmConfig = None
PoolingParams = None PoolingParams = None
AttentionBackendEnum = None
ParamsType = None ParamsType = None
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -67,8 +66,6 @@ class TpuPlatform(Platform): ...@@ -67,8 +66,6 @@ class TpuPlatform(Platform):
use_sparse, use_sparse,
attn_type: str | None = None, attn_type: str | None = None,
) -> str: ) -> str:
from vllm.attention.backends.registry import AttentionBackendEnum
if use_sparse: if use_sparse:
raise NotImplementedError("Sparse Attention is not supported on TPU.") raise NotImplementedError("Sparse Attention is not supported on TPU.")
if selected_backend != AttentionBackendEnum.PALLAS: if selected_backend != AttentionBackendEnum.PALLAS:
......
...@@ -8,16 +8,15 @@ from typing import TYPE_CHECKING ...@@ -8,16 +8,15 @@ from typing import TYPE_CHECKING
import torch import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.logger import init_logger from vllm.logger import init_logger
from .interface import DeviceCapability, Platform, PlatformEnum from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig from vllm.config import VllmConfig
else: else:
VllmConfig = None VllmConfig = None
AttentionBackendEnum = None
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -60,8 +59,6 @@ class XPUPlatform(Platform): ...@@ -60,8 +59,6 @@ class XPUPlatform(Platform):
"only NHD layout is supported by XPU attention kernels." "only NHD layout is supported by XPU attention kernels."
) )
from vllm.attention.backends.registry import AttentionBackendEnum
if use_sparse: if use_sparse:
raise NotImplementedError("Sparse Attention is not supported on XPU.") raise NotImplementedError("Sparse Attention is not supported on XPU.")
if selected_backend == AttentionBackendEnum.TRITON_ATTN: if selected_backend == AttentionBackendEnum.TRITON_ATTN:
...@@ -116,8 +113,6 @@ class XPUPlatform(Platform): ...@@ -116,8 +113,6 @@ class XPUPlatform(Platform):
def get_vit_attn_backend( def get_vit_attn_backend(
cls, head_size: int, dtype: torch.dtype cls, head_size: int, dtype: torch.dtype
) -> "AttentionBackendEnum": ) -> "AttentionBackendEnum":
from vllm.attention.backends.registry import AttentionBackendEnum
return AttentionBackendEnum.FLASH_ATTN return AttentionBackendEnum.FLASH_ATTN
@classmethod @classmethod
......
...@@ -51,8 +51,6 @@ class CPUAttentionBackend(AttentionBackend): ...@@ -51,8 +51,6 @@ class CPUAttentionBackend(AttentionBackend):
@classmethod @classmethod
def supports_attn_type(cls, attn_type: str) -> bool: def supports_attn_type(cls, attn_type: str) -> bool:
"""CPU attention supports decoder and encoder-only attention.""" """CPU attention supports decoder and encoder-only attention."""
from vllm.attention.backends.abstract import AttentionType
return attn_type in ( return attn_type in (
AttentionType.DECODER, AttentionType.DECODER,
AttentionType.ENCODER, AttentionType.ENCODER,
......
...@@ -84,8 +84,6 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -84,8 +84,6 @@ class FlashAttentionBackend(AttentionBackend):
@classmethod @classmethod
def supports_attn_type(cls, attn_type: str) -> bool: def supports_attn_type(cls, attn_type: str) -> bool:
"""FlashAttention supports all attention types.""" """FlashAttention supports all attention types."""
from vllm.attention.backends.abstract import AttentionType
return attn_type in ( return attn_type in (
AttentionType.DECODER, AttentionType.DECODER,
AttentionType.ENCODER, AttentionType.ENCODER,
......
...@@ -87,8 +87,6 @@ class FlexAttentionBackend(AttentionBackend): ...@@ -87,8 +87,6 @@ class FlexAttentionBackend(AttentionBackend):
@classmethod @classmethod
def supports_attn_type(cls, attn_type: str) -> bool: def supports_attn_type(cls, attn_type: str) -> bool:
"""FlexAttention supports both decoder and encoder-only attention.""" """FlexAttention supports both decoder and encoder-only attention."""
from vllm.attention.backends.abstract import AttentionType
return attn_type in (AttentionType.DECODER, AttentionType.ENCODER_ONLY) return attn_type in (AttentionType.DECODER, AttentionType.ENCODER_ONLY)
@staticmethod @staticmethod
......
...@@ -24,12 +24,15 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config ...@@ -24,12 +24,15 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionImpl
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch from vllm.v1.worker.gpu_input_batch import InputBatch
import vllm.envs as envs import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionImpl,
AttentionMetadata,
)
from vllm.distributed.kv_transfer.kv_connector.utils import ( from vllm.distributed.kv_transfer.kv_connector.utils import (
get_kv_connector_cache_layout, get_kv_connector_cache_layout,
) )
......
...@@ -6,12 +6,12 @@ from typing import TYPE_CHECKING ...@@ -6,12 +6,12 @@ from typing import TYPE_CHECKING
import torch import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.kv_offload.abstract import LoadStoreSpec, OffloadingManager from vllm.v1.kv_offload.abstract import LoadStoreSpec, OffloadingManager
from vllm.v1.kv_offload.worker.worker import OffloadingHandler from vllm.v1.kv_offload.worker.worker import OffloadingHandler
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig from vllm.config import VllmConfig
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -51,7 +51,7 @@ class OffloadingSpec(ABC): ...@@ -51,7 +51,7 @@ class OffloadingSpec(ABC):
def get_handlers( def get_handlers(
self, self,
kv_caches: dict[str, torch.Tensor], kv_caches: dict[str, torch.Tensor],
attn_backends: dict[str, type["AttentionBackend"]], attn_backends: dict[str, type[AttentionBackend]],
) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]: ) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]:
""" """
Get offloading handlers along with their respective src and dst types. Get offloading handlers along with their respective src and dst types.
......
...@@ -8,6 +8,7 @@ import numpy as np ...@@ -8,6 +8,7 @@ import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import ( from vllm.config import (
CompilationMode, CompilationMode,
CUDAGraphMode, CUDAGraphMode,
...@@ -157,8 +158,6 @@ class EagleProposer: ...@@ -157,8 +158,6 @@ class EagleProposer:
) )
# Determine allowed attention backends once during initialization. # Determine allowed attention backends once during initialization.
from vllm.attention.backends.registry import AttentionBackendEnum
self.allowed_attn_types: tuple | None = None self.allowed_attn_types: tuple | None = None
if current_platform.is_rocm(): if current_platform.is_rocm():
rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata] rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata]
......
...@@ -2,11 +2,11 @@ ...@@ -2,11 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TYPE_CHECKING
import torch import torch
from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.layer import Attention
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
from vllm.model_executor.models.interfaces import MultiModalEmbeddings from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.utils import extract_layer_index from vllm.model_executor.models.utils import extract_layer_index
...@@ -17,9 +17,6 @@ from vllm.v1.attention.backends.utils import AttentionMetadataBuilder ...@@ -17,9 +17,6 @@ from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec
if TYPE_CHECKING:
from vllm.attention.layer import Attention
class MultiModalBudget: class MultiModalBudget:
"""Helper class to calculate budget information for multi-modal models.""" """Helper class to calculate budget information for multi-modal models."""
...@@ -278,7 +275,7 @@ def add_kv_sharing_layers_to_kv_cache_groups( ...@@ -278,7 +275,7 @@ def add_kv_sharing_layers_to_kv_cache_groups(
def bind_kv_cache( def bind_kv_cache(
kv_caches: dict[str, torch.Tensor], kv_caches: dict[str, torch.Tensor],
forward_context: dict[str, "Attention"], forward_context: dict[str, Attention],
runner_kv_caches: list[torch.Tensor], runner_kv_caches: list[torch.Tensor],
num_attn_module: int = 1, num_attn_module: int = 1,
) -> None: ) -> None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment