# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.ops.vit_attn_wrappers import (
    vit_flash_attn_wrapper,
    vit_torch_sdpa_wrapper,
)
from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.config import MultiModalConfig
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.models.vision import get_vit_attn_backend

logger = init_logger(__name__)


@CustomOp.register("mm_encoder_attn")
class MMEncoderAttention(CustomOp):
    """Multi-headed attention without any cache, used for multimodal encoder.

    The concrete kernel is chosen once at construction time from the
    device-specific vision attention backend (optionally overridden through
    ``MultiModalConfig.mm_encoder_attn_backend``); the per-platform
    ``forward_*`` methods then dispatch to SDPA, flash attention, or PALLAS.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float | None = None,
        num_kv_heads: int | None = None,
        prefix: str = "",
        multimodal_config: MultiModalConfig | None = None,
    ) -> None:
        """
        Args:
            num_heads: number of attention heads per partition.
            head_size: hidden_size per attention head.
            scale: softmax scale factor. Only consumed by the TPU (PALLAS)
                path; when ``None`` that path uses ``head_size ** -0.5``.
                The SDPA and flash-attention paths in this class do not
                forward it to their kernels.
            num_kv_heads: number of kv heads; defaults to ``num_heads``.
                Must evenly divide ``num_heads`` (MQA/GQA).
            prefix: Stored as ``layer_name``; not otherwise used by this
                class. It is only here to make it easier to swap between
                Attention and MultiHeadAttention.
            multimodal_config: configs for multi-modal; its
                ``mm_encoder_attn_backend`` overrides the backend choice.
        """
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.layer_name = prefix

        assert self.num_heads % self.num_kv_heads == 0, (
            f"num_heads ({self.num_heads}) is not "
            f"divisible by num_kv_heads ({self.num_kv_heads})"
        )
        # How many query heads share each key/value head (1 for plain MHA).
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()

        # Try to get vision attention backend from multimodal_config.
        attn_backend_override = None
        if multimodal_config is not None:
            attn_backend_override = multimodal_config.mm_encoder_attn_backend

        # Get device-specific vision attention backend.
        self.attn_backend = get_vit_attn_backend(
            head_size=head_size,
            dtype=dtype,
            attn_backend_override=attn_backend_override,
        )

        self.is_flash_attn_backend = self.attn_backend in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
        }

        # Only query the flash-attention version when it will be used.
        self._fa_version = (
            get_flash_attn_version() if self.is_flash_attn_backend else None
        )

        logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.")

    @classmethod
    def enabled(cls) -> bool:
        """Report this op as enabled (unconditionally returns True)."""
        return True

    def maybe_reshape_qkv_to_4d(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        bsz: int,
        q_len: int,
        kv_len: int,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Reshape query, key, value to 4D tensors:
        (batch_size, seq_len, num_heads, head_size)

        For MQA/GQA (``num_kv_heads < num_heads``) the key/value heads are
        repeated along the head dim so all three tensors have ``num_heads``
        heads.
        """
        query = query.view(bsz, q_len, self.num_heads, self.head_size)
        key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
        value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)

        if (num_repeat := self.num_queries_per_kv) > 1:
            # Handle MQA and GQA
            key = torch.repeat_interleave(key, num_repeat, dim=2)
            value = torch.repeat_interleave(value, num_repeat, dim=2)

        return query, key, value

    def _forward_sdpa(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Input shape:
        (batch_size x seq_len x hidden_size) or
        (batch_size x seq_len x num_heads x head_size)

        A 3D input yields a 3D ``(batch_size, seq_len, -1)`` output; a 4D
        input yields the wrapper's output unchanged.
        """
        bsz, q_len = query.size()[:2]
        kv_len = key.size(1)
        is_reshaped = query.dim() != 4

        query, key, value = self.maybe_reshape_qkv_to_4d(
            query, key, value, bsz, q_len, kv_len
        )

        output = vit_torch_sdpa_wrapper(
            q=query,
            k=key,
            v=value,
            cu_seqlens=cu_seqlens,
        )
        if is_reshaped:
            output = output.reshape(bsz, q_len, -1)
        return output

    def _forward_fa(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        """Input shape:
        (batch_size x seq_len x hidden_size) or
        (batch_size x seq_len x num_heads x head_size)

        ``cu_seqlens`` and ``max_seqlen`` must be given together or not at
        all. A 3D input yields a 3D ``(batch_size, seq_len, -1)`` output.
        """
        assert (cu_seqlens is None) == (max_seqlen is None), (
            "cu_seqlens and max_seqlen should be both set or both None."
        )

        bsz, q_len = query.size()[:2]
        kv_len = key.size(1)
        is_reshaped = query.dim() != 4

        query, key, value = self.maybe_reshape_qkv_to_4d(
            query, key, value, bsz, q_len, kv_len
        )

        output = vit_flash_attn_wrapper(
            q=query,
            k=key,
            v=value,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            batch_size=bsz,
            is_rocm_aiter=(self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA),
            fa_version=self._fa_version,
        )
        if is_reshaped:
            output = output.reshape(bsz, q_len, -1)
        return output

    def forward_native(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        """Platform-agnostic path: always uses SDPA."""
        return self._forward_sdpa(query, key, value, cu_seqlens)

    def forward_cuda(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        """CUDA path: flash attention or SDPA, per the selected backend.

        Raises:
            ValueError: if the selected backend is neither a flash-attention
                backend nor TORCH_SDPA.
        """
        if self.is_flash_attn_backend:
            return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
        elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
            return self._forward_sdpa(query, key, value, cu_seqlens)
        else:
            raise ValueError(
                f"Unsupported multi-modal encoder attention backend for CUDA: "
                f"{self.attn_backend}."
            )

    def forward_cpu(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        """CPU path: always uses SDPA."""
        return self._forward_sdpa(query, key, value, cu_seqlens)

    def forward_xpu(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        """XPU path: requires a flash-attention backend."""
        assert self.is_flash_attn_backend, (
            "XPU only supports FLASH_ATTN for vision attention."
        )
        return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)

    def forward_tpu(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        """TPU path: PALLAS flash attention, or SDPA when ``cu_seqlens`` is set.

        Accepts the same 3D/4D input layouts as the other paths, and returns
        a 3D ``(batch_size, seq_len, -1)`` output for 3D inputs.
        """
        assert self.attn_backend == AttentionBackendEnum.PALLAS, (
            f"MMEncoderAttention on TPU only supports PALLAS backend, "
            f"but got {self.attn_backend}."
        )
        if cu_seqlens is None:
            from torch_xla.experimental.custom_kernel import flash_attention

            bsz, q_len = query.size()[:2]
            kv_len = key.size(1)
            is_reshaped = query.dim() != 4

            # Normalize to (batch, seq, heads, head_size) and expand KV heads
            # for MQA/GQA, exactly as the SDPA / flash-attention paths do.
            query, key, value = self.maybe_reshape_qkv_to_4d(
                query, key, value, bsz, q_len, kv_len
            )
            # The kernel is fed (batch, heads, seq, head_size).
            query, key, value = (x.transpose(1, 2) for x in (query, key, value))

            # The kernel takes a float scale; fall back to the conventional
            # 1/sqrt(head_size) (torch SDPA's default) when none was given.
            sm_scale = self.scale if self.scale is not None else self.head_size**-0.5
            out = flash_attention(query, key, value, sm_scale=sm_scale)
            out = out.transpose(1, 2)
            if is_reshaped:
                out = out.reshape(bsz, q_len, -1)
            return out

        # A single message string: a second positional argument would be
        # treated as a %-format arg and break log formatting.
        logger.warning_once(
            "PALLAS backend with cu_seqlens is not supported for ViT yet. "
            "Falling back to SDPA implementation."
        )
        return self._forward_sdpa(query, key, value, cu_seqlens)