""" Attention layer with torch scaled_dot_product_attention and PagedAttention.""" from dataclasses import dataclass from typing import Dict, List, Optional, Tuple, Type import torch from torch.nn.functional import scaled_dot_product_attention from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionMetadataPerStage) from vllm.attention.ops.paged_attn import (PagedAttention, PagedAttentionMetadata) class TorchSDPABackend(AttentionBackend): @staticmethod def get_impl_cls() -> Type["TorchSDPABackendImpl"]: return TorchSDPABackendImpl @staticmethod def make_metadata(*args, **kwargs) -> "TorchSDPAMetadata": return TorchSDPAMetadata(*args, **kwargs) @staticmethod def get_kv_cache_shape( num_blocks: int, block_size: int, num_kv_heads: int, head_size: int, ) -> Tuple[int, ...]: return PagedAttention.get_kv_cache_shape(num_blocks, block_size, num_kv_heads, head_size) @staticmethod def swap_blocks( src_kv_cache: torch.Tensor, dst_kv_cache: torch.Tensor, src_to_dst: Dict[int, int], ) -> None: PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) @staticmethod def copy_blocks( kv_caches: List[torch.Tensor], src_to_dists: Dict[int, List[int]], ) -> None: PagedAttention.copy_blocks(kv_caches, src_to_dists) @dataclass class TorchSDPAMetadata(AttentionMetadataPerStage, PagedAttentionMetadata): """Metadata for TorchSDPABackend. """ # Currently, input sequences can only contain all prompts # or all decoding. True if all sequences are prompts. is_prompt: bool prompt_lens: Optional[List[int]] prompt_lens_tensor: Optional[torch.Tensor] max_subquery_len: Optional[int] = None max_prompt_len: Optional[int] = None subquery_start_loc: Optional[torch.Tensor] = None seq_start_loc: Optional[torch.Tensor] = None use_cuda_graph: bool = False def __post_init__(self): # Set during the execution of the first attention op. # It is a list because it is needed to set per prompt # when alibi slopes is used. It is because of the limitation # from xformer API. # will not appear in the __repr__ and __init__ self.attn_bias: Optional[List[torch.Tensor]] = None class TorchSDPABackendImpl(AttentionImpl): def __init__( self, num_heads: int, head_size: int, scale: float, num_kv_heads: Optional[int] = None, alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, ) -> None: self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads self.sliding_window = sliding_window if alibi_slopes is not None: assert len(alibi_slopes) == num_heads alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) self.alibi_slopes = alibi_slopes self.need_mask = (self.alibi_slopes is not None or self.sliding_window is not None) assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads suppored_head_sizes = PagedAttention.get_supported_head_sizes() if head_size not in suppored_head_sizes: raise ValueError( f"Head size {head_size} is not supported by PagedAttention. " f"Supported head sizes are: {suppored_head_sizes}.") def forward( self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, kv_cache: Optional[torch.Tensor], attn_metadata: AttentionMetadata[TorchSDPAMetadata], kv_scale: float, ) -> torch.Tensor: """Forward pass with torch SDPA and PagedAttention. Args: query: shape = [num_tokens, num_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size] kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] """ num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size) value = value.view(-1, self.num_kv_heads, self.head_size) if kv_cache is not None: key_cache, value_cache = PagedAttention.split_kv_cache( kv_cache, self.num_kv_heads, self.head_size) PagedAttention.write_to_paged_cache(key, value, key_cache, value_cache, attn_metadata.slot_mapping, attn_metadata.kv_cache_dtype, kv_scale) num_prefill_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens assert key.shape[0] == num_prefill_tokens + num_decode_tokens assert value.shape[0] == num_prefill_tokens + num_decode_tokens output = torch.empty_like(query) # Query for decode. KV is not needed because it is already cached. decode_query = query[num_prefill_tokens:] # QKV for prefill. query = query[:num_prefill_tokens] key = key[:num_prefill_tokens] value = value[:num_prefill_tokens] assert query.shape[0] == num_prefill_tokens assert decode_query.shape[0] == num_decode_tokens if prefill_meta := attn_metadata.prefill_metadata: if (kv_cache is None or prefill_meta.block_tables.numel() == 0): if self.num_kv_heads != self.num_heads: key = key.repeat_interleave(self.num_queries_per_kv, dim=1) value = value.repeat_interleave(self.num_queries_per_kv, dim=1) if prefill_meta.attn_bias is None: if self.alibi_slopes is not None: att_masks = _make_alibi_bias( self.alibi_slopes, query.dtype, prefill_meta.prompt_lens) # type: ignore elif self.sliding_window is not None: att_masks = _make_sliding_window_bias( prefill_meta.prompt_lens, self.sliding_window, query.dtype) # type: ignore else: att_masks = [None] * len(prefill_meta.prompt_lens) prefill_meta.attn_bias = att_masks query = query.movedim(0, query.dim() - 2) key = key.movedim(0, key.dim() - 2) value = value.movedim(0, value.dim() - 2) start = 0 out = torch.empty((num_tokens, self.num_heads, self.head_size), dtype=query.dtype) for prompt_len, mask in zip(prefill_meta.prompt_lens, prefill_meta.attn_bias): end = start + prompt_len sub_out = scaled_dot_product_attention( query[:, start:end, :], key[:, start:end, :], value[:, start:end, :], attn_mask=mask, dropout_p=0.0, is_causal=not self.need_mask, scale=self.scale).movedim(query.dim() - 2, 0) out[start:end, :, :] = sub_out start = end assert out.shape == output[:num_prefill_tokens].shape output[:num_prefill_tokens] = out else: # prefix-enabled attention raise RuntimeError( "Torch SDPA backend doesn't support prefix decoding.") if decode_meta := attn_metadata.decode_metadata: # Decoding run. out = PagedAttention.forward_decode( decode_query, key_cache, value_cache, decode_meta.block_tables, decode_meta.context_lens, decode_meta.max_context_len, attn_metadata.kv_cache_dtype, self.num_kv_heads, self.scale, self.alibi_slopes, kv_scale, ) assert out.shape == output[num_prefill_tokens:].shape output[num_prefill_tokens:] # Reshape the output tensor. return output.view(-1, self.num_heads * self.head_size) def _make_alibi_bias( alibi_slopes: torch.Tensor, dtype: torch.dtype, prompt_lens: List[int], ) -> List[torch.Tensor]: attn_biases = [] for prompt_len in prompt_lens: bias = torch.arange(prompt_len, dtype=dtype) # NOTE(zhuohan): HF uses # `bias = bias[None, :].repeat(prompt_len, 1)` # here. We find that both biases give the same results, but # the bias below more accurately follows the original ALiBi # paper. bias = bias[None, :] - bias[:, None] num_heads = alibi_slopes.shape[0] bias = bias[None, :].expand(num_heads, prompt_len, prompt_len) bias.mul_(alibi_slopes[:, None, None]) inf_mask = torch.empty( (1, prompt_len, prompt_len), dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1) attn_biases.append((bias + inf_mask).to(dtype)) return attn_biases def _make_sliding_window_bias( prompt_lens: List[int], window_size: Optional[int], dtype: torch.dtype, ) -> List[torch.Tensor]: attn_biases = [] for prompt_len in prompt_lens: tensor = torch.full( (1, prompt_len, prompt_len), dtype=dtype, fill_value=1, ) shift = 0 mask = torch.tril(tensor, diagonal=shift).to(dtype) # type: ignore if window_size is not None: mask = torch.triu(mask, diagonal=shift - window_size + 1) mask = torch.log(mask) attn_biases.append(mask.to(dtype)) return attn_biases