# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch

from vllm.logger import init_logger
from vllm_omni.diffusion.attention.backends.abstract import (
    AttentionBackend,
    AttentionImpl,
    AttentionMetadata,
)

logger = init_logger(__name__)

try:
    from sageattention import sageattn
except ImportError as err:
    logger.warning(
        "SageAttentionBackend is not available. You can install SageAttention"
        " with: pip install git+https://github.com/thu-ml/SageAttention.git"
    )
    raise ImportError(
        "sageattention is required for SageAttentionBackend"
    ) from err

# TODO: add a SageAttention3 backend.


class SageAttentionBackend(AttentionBackend):
    accept_output_buffer: bool = True

    @staticmethod
    def get_supported_head_sizes() -> list[int]:
        return [32, 64, 96, 128, 160, 192, 224, 256]

    @staticmethod
    def get_name() -> str:
        return "SAGE_ATTN"

    @staticmethod
    def get_impl_cls() -> type["SageAttentionImpl"]:
        return SageAttentionImpl


class SageAttentionImpl(AttentionImpl):
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        softmax_scale: float,
        causal: bool = False,
        num_kv_heads: int | None = None,
        prefix: str = "",
        **extra_impl_args,
    ) -> None:
        # Reject head sizes the kernel cannot handle up front rather than
        # failing inside sageattn at runtime.
        if head_size not in SageAttentionBackend.get_supported_head_sizes():
            raise ValueError(
                f"Head size {head_size} is not supported by SageAttention. "
                f"Supported head sizes: "
                f"{SageAttentionBackend.get_supported_head_sizes()}"
            )
        self.num_heads = num_heads
        self.head_size = head_size
        self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
        self.causal = causal
        self.softmax_scale = softmax_scale

    def forward_cuda(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: AttentionMetadata | None = None,
    ) -> torch.Tensor:
        # sageattn expects fp16/bf16 inputs; tensor_layout="NHD" means
        # (batch, seq_len, num_heads, head_dim).
        return sageattn(
            query,
            key,
            value,
            tensor_layout="NHD",
            is_causal=self.causal,
            sm_scale=self.softmax_scale,
        )
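

# A minimal smoke-test sketch, not part of the backend itself: it assumes a
# CUDA device, fp16 inputs, and the (batch, seq_len, num_heads, head_dim)
# layout implied by tensor_layout="NHD". The shapes and scale below are
# illustrative choices, not values mandated by vLLM.
if __name__ == "__main__":
    impl = SageAttentionImpl(num_heads=8, head_size=64, softmax_scale=64**-0.5)
    q = torch.randn(1, 128, 8, 64, dtype=torch.float16, device="cuda")
    k = torch.randn(1, 128, 8, 64, dtype=torch.float16, device="cuda")
    v = torch.randn(1, 128, 8, 64, dtype=torch.float16, device="cuda")
    out = impl.forward_cuda(q, k, v)
    # The output keeps the query layout: torch.Size([1, 128, 8, 64]).
    print(out.shape)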