from typing import List, Optional, Type

import torch

from sageattention import sageattn

from fastvideo.v1.attention.backends.abstract import (AttentionBackend,
                                                      AttentionImpl,
                                                      AttentionMetadata)
from fastvideo.v1.logger import init_logger

logger = init_logger(__name__)


class SageAttentionBackend(AttentionBackend):
    """Attention backend that dispatches to the SageAttention fused kernel."""

    # The sageattn kernel can write into a caller-provided output buffer.
    accept_output_buffer: bool = True

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        """Head dimensions this backend can handle."""
        return [32, 64, 96, 128, 160, 192, 224, 256]

    @staticmethod
    def get_name() -> str:
        """Registry name for this backend."""
        return "SAGE_ATTN"

    @staticmethod
    def get_impl_cls() -> Type["SageAttentionImpl"]:
        """Implementation class used to run the attention computation."""
        return SageAttentionImpl

    # NOTE(review): get_metadata_cls is intentionally not overridden here
    # (a FlashAttentionMetadata-based version was sketched and removed);
    # confirm the base-class default is sufficient before relying on
    # attn_metadata-driven features with this backend.


class SageAttentionImpl(AttentionImpl):
    """Runs scaled dot-product attention via the ``sageattn`` kernel.

    Args:
        num_heads: Number of query attention heads.
        head_size: Per-head embedding dimension.
        causal: Whether to apply a causal (lower-triangular) mask.
        softmax_scale: Scale applied to QK^T before softmax; forwarded to
            the kernel as ``sm_scale``.
        num_kv_heads: Number of key/value heads; defaults to ``num_heads``
            when not given (no GQA).
        prefix: Layer-name prefix (unused here, kept for interface parity).
        **extra_impl_args: Extra options; only ``dropout_p`` is read.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        causal: bool,
        softmax_scale: float,
        num_kv_heads: Optional[int] = None,
        prefix: str = "",
        **extra_impl_args,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
        self.causal = causal
        self.softmax_scale = softmax_scale
        # NOTE(review): sageattn exposes no dropout argument, so this value
        # is recorded but not applied — confirm dropout is not expected here.
        self.dropout = extra_impl_args.get("dropout_p", 0.0)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Compute attention over (batch, seq_len, num_heads, head_dim) inputs.

        Returns a tensor with the same layout as ``query``.
        """
        output = sageattn(
            query,
            key,
            value,
            # Inputs are (batch_size, seq_len, head_num, head_dim) tensors.
            tensor_layout="NHD",
            is_causal=self.causal,
            # Forward the configured scale; previously it was stored but
            # never passed, so a non-default softmax_scale was ignored.
            sm_scale=self.softmax_scale)
        return output