import torch
import triton
import triton.language as tl
import math
from typing import Optional, Tuple


# Simple implementation of flash attention for gfx926
def flash_mla_with_kvcache_triton(
    q: torch.Tensor,              # batch_size x seqlen_q x num_heads_q x head_size_k
    k_cache: torch.Tensor,        # num_blocks x page_block_size x num_heads_k x head_size_k
    v_cache: torch.Tensor,        # num_blocks x page_block_size x num_heads_k x head_size_v
    block_table: torch.Tensor,    # batch_size x max_num_blocks_per_seq
    cache_seqlens: torch.Tensor,  # batch_size
    head_dim_v: int,
    softmax_scale: Optional[float] = None,
    causal: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Implementation of flash attention with KV cache for gfx926 architecture
    """
    # Check inputs
    assert q.dim() == 4, "q must be 4-dimensional"
    assert k_cache.dim() == 4, "k_cache must be 4-dimensional"
    assert v_cache.dim() == 4, "v_cache must be 4-dimensional"
    assert block_table.dim() == 2, "block_table must be 2-dimensional"
    assert cache_seqlens.dim() == 1, "cache_seqlens must be 1-dimensional"

    # Get dimensions
    batch_size, seqlen_q, num_heads_q, head_size_k = q.shape
    num_blocks, page_block_size, num_heads_k, _ = k_cache.shape
    max_num_blocks_per_seq = block_table.shape[1]

    # Check head dimensions
    assert head_size_k == 576 or head_size_k == 512, "Only head_size_k == 576 or 512 is supported"
    assert head_dim_v == 512, "Only head_dim_v == 512 is supported"
    assert num_heads_q % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"
    assert page_block_size == 64, "Currently page_block_size must be 64"

    # Set default softmax scale
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(head_size_k)

    # Create output tensors
    out = torch.empty((batch_size, seqlen_q, num_heads_q, head_dim_v), dtype=q.dtype, device=q.device)
    lse = torch.empty((batch_size, num_heads_q, seqlen_q), dtype=torch.float32, device=q.device)

    # Use a simplified implementation that works on all architectures
    for b in range(batch_size):
        seq_len_k = cache_seqlens[b].item()

        # Get query for this batch
        q_batch = q[b]  # seqlen_q x num_heads_q x head_size_k

        # Calculate attention scores using the provided k_cache and block_table.
        # For gfx926, we'll use a simplified approach.

        # Get the relevant blocks from the block table
        num_k_blocks = (seq_len_k + page_block_size - 1) // page_block_size
        blocks = block_table[b, :num_k_blocks].long()

        # Ensure blocks are within bounds
        blocks = blocks % num_blocks

        # Gather the relevant key and value blocks
        k = k_cache[blocks].reshape(-1, num_heads_k, head_size_k)[:seq_len_k]
        v = v_cache[blocks].reshape(-1, num_heads_k, head_dim_v)[:seq_len_k]

        # Handle NaN values
        k[k != k] = 0.0
        v[v != v] = 0.0

        # Expand k and v if needed
        if num_heads_k < num_heads_q:
            k = k.repeat_interleave(num_heads_q // num_heads_k, dim=1)
            v = v.repeat_interleave(num_heads_q // num_heads_k, dim=1)

        # Calculate attention scores
        # Reshape k for correct matrix multiplication
        k_reshaped = k.permute(1, 0, 2)  # num_heads_q x seq_len_k x head_size_k
        scores = torch.einsum('qhd,hkd->qhk', q_batch, k_reshaped)  # seqlen_q x num_heads_q x seq_len_k
        scores *= softmax_scale

        # Apply causal mask if needed
        if causal and seqlen_q > 1:
            mask = torch.ones(seqlen_q, seq_len_k, device=q.device, dtype=torch.bool)
            mask = mask.tril(diagonal=seq_len_k - seqlen_q)
            scores = scores.masked_fill(mask.logical_not().unsqueeze(1), -float('inf'))

        # Apply softmax (numerically stable: subtract the row max first)
        max_scores = scores.max(dim=-1, keepdim=True)[0]
        exp_scores = torch.exp(scores - max_scores)
        sum_exp = exp_scores.sum(dim=-1, keepdim=True)

        # Calculate lse (log-sum-exp)
        current_lse = torch.log(sum_exp.squeeze(-1)) + max_scores.squeeze(-1)
        lse[b] = current_lse.transpose(0, 1)

        # Calculate attention weights
        attention = exp_scores / sum_exp
        attention = attention.to(torch.float32)

        # Calculate output
        # Reshape v for correct matrix multiplication
        v_reshaped = v.permute(1, 0, 2)  # num_heads_q x seq_len_k x head_dim_v
        v_reshaped = v_reshaped.to(torch.float32)
        out[b] = torch.einsum('qhk,hkd->qhd', attention, v_reshaped)  # seqlen_q x num_heads_q x head_dim_v
        out[b] = out[b].to(q.dtype)

        # Correct for q tokens which have no attendable k
        lonely_q_mask = (current_lse == -float('inf'))
        out[b][lonely_q_mask.unsqueeze(-1).broadcast_to(out[b].shape)] = 0.0
        lse[b][lonely_q_mask.transpose(0, 1)] = float('inf')

    return out, lse


def flash_mla_with_kvcache(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata=None,
    num_splits=None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    is_fp8_kvcache: bool = False,
    indices: Optional[torch.Tensor] = None,
    attn_sink: Optional[torch.Tensor] = None,
    extra_k_cache: Optional[torch.Tensor] = None,
    extra_indices_in_kvcache: Optional[torch.Tensor] = None,
    topk_length: Optional[torch.Tensor] = None,
    extra_topk_length: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Wrapper function to match the original flash_mla interface
    """
    # For dense attention (no indices provided)
    if indices is None:
        # Use the first head_dim_v dimensions of k_cache as v_cache.
        # This matches the reference implementation.
        head_dim_v_int = head_dim_v.item() if isinstance(head_dim_v, torch.Tensor) else head_dim_v
        v_cache = k_cache[..., :head_dim_v_int]

        out, lse = flash_mla_with_kvcache_triton(
            q, k_cache, v_cache, block_table, cache_seqlens, head_dim_v_int,
            softmax_scale, causal
        )
        return out, lse
    else:
        # Sparse attention not implemented yet
        raise NotImplementedError("Sparse attention is not implemented in Triton version")


def get_mla_metadata(*args, **kwargs) -> Tuple[dict, None]:
    """
    Returns a dummy metadata object to match the original interface
    """
    return {}, None
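

# A minimal, hypothetical smoke test (not part of the original interface): it exercises
# the dense path of flash_mla_with_kvcache with small synthetic float32 tensors to
# illustrate the expected call signature and output shapes. The tensor sizes, dtypes,
# and block layout below are illustrative assumptions, not requirements of the kernel.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    batch_size, seqlen_q, num_heads_q, num_heads_k = 2, 1, 16, 1
    head_size_k, head_dim_v, page_block_size, num_blocks = 576, 512, 64, 8

    q = torch.randn(batch_size, seqlen_q, num_heads_q, head_size_k, device=device)
    k_cache = torch.randn(num_blocks, page_block_size, num_heads_k, head_size_k, device=device)
    # Each sequence owns a contiguous run of blocks; cache_seqlens stays within those blocks
    block_table = torch.arange(num_blocks, dtype=torch.int32, device=device).reshape(batch_size, -1)
    cache_seqlens = torch.tensor([100, 150], dtype=torch.int32, device=device)

    out, lse = flash_mla_with_kvcache(q, k_cache, block_table, cache_seqlens, head_dim_v, causal=True)
    print(out.shape)  # expected: (2, 1, 16, 512)
    print(lse.shape)  # expected: (2, 16, 1)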