"""Abstract interfaces for attention backends."""

from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
from typing import Any, Dict, List, Optional, Tuple, Type

import torch


class AttentionBackend(ABC):
    """Abstract class for attention backends.

    Every method is an abstract static method, so a concrete backend must
    override all of them before it can be instantiated.
    """

    @staticmethod
    @abstractmethod
    def get_impl_cls() -> Type["AttentionImpl"]:
        """Return the ``AttentionImpl`` subclass associated with this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def make_metadata(*args, **kwargs) -> "AttentionMetadata":
        """Build an ``AttentionMetadata`` object.

        The accepted positional and keyword arguments are defined by the
        concrete backend.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache shape, as a tuple of ints, for the given sizes.

        NOTE(review): the ordering and meaning of the returned dimensions are
        chosen by the concrete backend; nothing in this file fixes them.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        """Swap blocks between ``src_kv_cache`` and ``dst_kv_cache``.

        NOTE(review): ``src_to_dst`` presumably maps a source block index to
        a destination block index -- confirm against concrete backends.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: Dict[int, List[int]],
    ) -> None:
        """Copy blocks within each tensor of ``kv_caches``.

        NOTE(review): ``src_to_dists`` presumably maps a source block index
        to the list of destination block indices -- confirm against concrete
        backends. The spelling of the parameter name is kept as-is because
        callers may pass it by keyword.
        """
        raise NotImplementedError


@dataclass
class AttentionMetadata:
    """Base dataclass for attention metadata.

    It declares no fields of its own; subclasses are expected to be
    dataclasses that add them.
    """

    def asdict_zerocopy(self) -> Dict[str, Any]:
        """Similar to dataclasses.asdict, but avoids deepcopying.

        Returns:
            A dict mapping each dataclass field name to the very object
            stored on the instance (no copy is made), in field order.
        """
        # Note that if we add dataclasses as fields, they will need
        # similar handling.
        return {
            field.name: getattr(self, field.name)
            for field in fields(self)
        }


class AttentionImpl(ABC):
    """Abstract base class for attention implementations.

    Both the constructor and ``forward`` are abstract, so subclasses must
    override both before they can be instantiated.
    """

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        """Declare the constructor signature that implementations share.

        ``num_kv_heads``, ``alibi_slopes`` and ``sliding_window`` are
        optional and default to ``None``; how ``None`` is interpreted is up
        to the concrete implementation.
        """
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention over the given inputs and return a tensor.

        NOTE(review): tensor shapes and dtypes are not established in this
        file -- confirm against concrete implementations.
        """
        raise NotImplementedError