# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from abc import ABC, abstractmethod
from typing import Generic, Protocol, TypeVar

import torch

from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey


class AttentionType:
    """
    Attention type.
    Use string to be compatible with `torch.compile`.
    """

    DECODER = "decoder"
    """Decoder attention between previous layer Q/K/V."""
    ENCODER = "encoder"
    """Encoder attention between previous layer Q/K/V for encoder-decoder."""
    ENCODER_ONLY = "encoder_only"
    """Encoder attention between previous layer Q/K/V."""
    ENCODER_DECODER = "encoder_decoder"
    """Attention between dec. Q and enc. K/V for encoder-decoder."""


class MultipleOf:
    """Kernel block-size constraint: a block size that is a multiple of ``base``.

    Appears next to plain ``int`` entries in the lists returned by
    ``get_supported_kernel_block_size`` (presumably an ``int`` entry means one
    exact size, while ``MultipleOf`` admits a family of sizes -- confirm
    against the consumers of that list).
    """

    base: int

    def __init__(self, base: int):
        self.base = base

    def __repr__(self) -> str:
        # Makes supported-block-size lists readable in logs and error messages.
        return f"{type(self).__name__}({self.base})"


class AttentionBackend(ABC):
    """Abstract class for attention backends."""

    # For some attention backends, we allocate an output tensor before
    # calling the custom op. When piecewise cudagraph is enabled, this
    # makes sure the output tensor is allocated inside the cudagraph.
    accept_output_buffer: bool = False

    @staticmethod
    @abstractmethod
    def get_name() -> str:
        """Return the name of this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_impl_cls() -> type["AttentionImpl"]:
        """Return the ``AttentionImpl`` subclass that runs this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        """Return the ``AttentionMetadata`` subclass used by this backend."""
        raise NotImplementedError

    @classmethod
    def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]:
        """Return the kernel block sizes supported by this backend.

        Delegates to ``get_impl_cls().get_supported_kernel_block_size()``.
        """
        return cls.get_impl_cls().get_supported_kernel_block_size()

    @classmethod
    def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
        """Construct this backend's metadata class, forwarding all arguments."""
        return cls.get_metadata_cls()(*args, **kwargs)

    @staticmethod
    @abstractmethod
    def get_builder_cls():  # -> type["AttentionMetadataBuilder"]:
        """Return the metadata builder class for this backend.

        Left unannotated because ``AttentionMetadataBuilder`` is not imported
        in this module (presumably to avoid an import cycle -- confirm).
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return the shape of the KV-cache tensor for the given geometry.

        Args:
            num_blocks: Number of cache blocks.
            block_size: Tokens per cache block.
            num_kv_heads: Number of key/value heads.
            head_size: Size of each attention head.
            cache_dtype_str: KV-cache dtype name; ``"auto"`` by default.
        """
        raise NotImplementedError

    @staticmethod
    def get_kv_cache_stride_order() -> tuple[int, ...]:
        """Return the KV-cache stride order, if the backend defines one.

        Optional: this base implementation raises ``NotImplementedError``, so
        callers must be prepared to handle that.
        """
        raise NotImplementedError

    @classmethod
    def full_cls_name(cls) -> tuple[str, str]:
        """Return ``(module, qualified class name)`` identifying this backend."""
        return (cls.__module__, cls.__qualname__)


class AttentionMetadata:
    """Base class for backend-specific attention metadata."""


T = TypeVar("T", bound=AttentionMetadata)


class AttentionLayer(Protocol):
    """Structural type of the ``layer`` argument to ``AttentionImpl.forward``."""

    # Scale tensors and their float-valued counterparts exposed by the layer.
    # NOTE(review): presumably quantization scales for Q/K/V -- confirm.
    _q_scale: torch.Tensor
    _k_scale: torch.Tensor
    _v_scale: torch.Tensor
    _q_scale_float: float
    _k_scale_float: float
    _v_scale_float: float
    _prob_scale: torch.Tensor

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor: ...


class AttentionImpl(ABC, Generic[T]):
    """Abstract attention implementation, parameterized by its metadata type."""

    # Whether the attention impl can return the softmax lse for decode.
    # Some features like decode context parallelism require the softmax lse.
    can_return_lse_for_decode: bool = False

    # some attention backends might not always want to return lse
    # even if they can return lse (for efficiency reasons)
    need_to_return_lse_for_decode: bool = False

    dcp_world_size: int
    dcp_rank: int

    def __new__(cls, *args, **kwargs):
        # use __new__ so that all subclasses will call this
        self = super().__new__(cls)
        try:
            # Lazy import (presumably to avoid an import cycle -- confirm).
            from vllm.distributed.parallel_state import get_dcp_group

            # Resolve the group once instead of once per attribute.
            dcp_group = get_dcp_group()
            dcp_world_size = dcp_group.world_size
            dcp_rank = dcp_group.rank_in_group
        except AssertionError:
            # DCP might not be initialized in testing
            dcp_world_size = 1
            dcp_rank = 0
        self.dcp_world_size = dcp_world_size
        self.dcp_rank = dcp_rank
        # Only request the lse when DCP is actually in use and the impl can
        # produce it.
        self.need_to_return_lse_for_decode = (
            self.dcp_world_size > 1 and self.can_return_lse_for_decode
        )
        return self

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
        alibi_slopes: list[float] | None = None,
        sliding_window: int | None = None,
        kv_cache_dtype: str = "auto",
        logits_soft_cap: float | None = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
    ) -> None:
        raise NotImplementedError

    @staticmethod
    def get_supported_kernel_block_size() -> list[int | MultipleOf]:
        """Return supported kernel block sizes; the default is ``[MultipleOf(1)]``."""
        # TODO: implement this function for all backends.
        return [MultipleOf(1)]

    @abstractmethod
    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run attention and return the output tensor.

        ``output`` is an optional preallocated output buffer (see
        ``AttentionBackend.accept_output_buffer``). ``output_scale`` and
        ``output_block_scale`` are optional; presumably they are used for fused
        output quantization -- confirm per backend.
        """
        raise NotImplementedError

    def fused_output_quant_supported(self, quant_key: QuantKey) -> bool:
        """
        Does this attention implementation support fused output quantization.
        This is used by the AttnFusionPass to only fuse output quantization
        onto implementations that support it.

        :param quant_key: QuantKey object that describes the quantization op
        :return: is fusion supported for this type of quantization
        """
        return False

    def supports_quant_query_input(self) -> bool:
        """
        Check if this attention implementation supports pre-quantized query input.

        When True, the attention layer will quantize queries before passing them
        to this backend, allowing torch.compile to fuse the quantization with
        previous operations. This is typically supported when using FP8 KV cache
        with compatible attention kernels (e.g., TRT-LLM).
        TODO add support to more backends:
        https://github.com/vllm-project/vllm/issues/25584

        Returns:
            bool: True if the implementation can accept pre-quantized queries.
        """
        return False

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Hook for post-load weight processing; a no-op in this base class."""


class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
    """Abstract attention implementation with MLA-specific constructor/forward."""

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        # MLA Specific Arguments
        q_lora_rank: int | None,
        kv_lora_rank: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        qk_head_dim: int,
        v_head_dim: int,
        kv_b_proj: ColumnParallelLinear,
        indexer: object | None = None,
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        layer: AttentionLayer,
        hidden_states_or_cq: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        raise NotImplementedError


def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
    """Return True unless ``kv_cache_dtype`` is the default ``"auto"``."""
    return kv_cache_dtype != "auto"