from typing import Optional, Tuple import torch import flash_mla_sm90 import flash_mla_sm100 def get_mla_metadata( cache_seqlens: torch.Tensor, num_heads_per_head_k: int, num_heads_k: int, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Arguments: cache_seqlens: (batch_size), dtype torch.int32. num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k. num_heads_k: num_heads_k. Returns: tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. num_splits: (batch_size + 1), dtype torch.int32. """ return flash_mla_sm90.get_mla_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k) def flash_mla_with_kvcache_sm90( q: torch.Tensor, k_cache: torch.Tensor, block_table: torch.Tensor, cache_seqlens: torch.Tensor, head_dim_v: int, tile_scheduler_metadata: torch.Tensor, num_splits: torch.Tensor, softmax_scale: Optional[float] = None, causal: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Arguments: q: (batch_size, seq_len_q, num_heads_q, head_dim). k_cache: (num_blocks, page_block_size, num_heads_k, head_dim). block_table: (batch_size, max_num_blocks_per_seq), torch.int32. cache_seqlens: (batch_size), torch.int32. head_dim_v: Head dimension of v. tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata. num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata. softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim). causal: bool. Whether to apply causal attention mask. Returns: out: (batch_size, seq_len_q, num_heads_q, head_dim_v). softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. """ if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) out, softmax_lse = flash_mla_sm90.fwd_kvcache_mla( q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale, causal, tile_scheduler_metadata, num_splits, ) return out, softmax_lse def _flash_attn_varlen_forward( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cu_seqlens_qo: torch.Tensor, cu_seqlens_kv: torch.Tensor, max_seqlen_qo: int, max_seqlen_kv: int, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, causal: bool = False, softmax_scale: Optional[float] = None, is_varlen: bool = True, ) -> Tuple[torch.Tensor, torch.Tensor]: qo_total_len, num_qo_heads, head_dim_qk = q.shape kv_total_len, num_kv_heads, head_dim_vo = v.shape mask_mode_code = 1 if causal else 0 if softmax_scale is None: softmax_scale = head_dim_qk ** (-0.5) if out is None: out = torch.empty(qo_total_len, num_qo_heads, head_dim_vo, device=q.device, dtype=q.dtype) if lse is None: # Make lse contiguous on seqlen dim lse = torch.empty(num_qo_heads, qo_total_len, device=q.device, dtype=torch.float32).T workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.uint8, device=q.device) flash_mla_sm100.fwd( workspace_buffer, q, k, v, cu_seqlens_qo, cu_seqlens_kv, out, lse, mask_mode_code, softmax_scale, max_seqlen_qo, max_seqlen_kv, is_varlen, ) return out, lse def _flash_attn_varlen_backward( do: torch.Tensor, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor, lse: torch.Tensor, cu_seqlens_qo: torch.Tensor, cu_seqlens_kv: torch.Tensor, max_seqlen_qo: int, max_seqlen_kv: int, dq: Optional[torch.Tensor] = None, dk: Optional[torch.Tensor] = None, dv: Optional[torch.Tensor] = None, causal: bool = False, softmax_scale: Optional[float] = None, is_varlen: bool = True, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: qo_total_len, num_qo_heads, head_dim_qk = q.shape kv_total_len, num_kv_heads, head_dim_vo = v.shape # TODO: fix bwd GQA if num_qo_heads != num_kv_heads: raise ValueError(f"SM100 bwd doesn't support GQA now. num_qo_heads: {num_qo_heads}, num_kv_heads: {num_kv_heads}.") mask_mode_code = 1 if causal else 0 if softmax_scale is None: softmax_scale = head_dim_qk ** (-0.5) if dq is None: dq = torch.empty(qo_total_len, num_qo_heads, head_dim_qk, device=q.device, dtype=q.dtype) if dk is None: dk = torch.empty(kv_total_len, num_kv_heads, head_dim_qk, device=q.device, dtype=q.dtype) if dv is None: dv = torch.empty(kv_total_len, num_kv_heads, head_dim_vo, device=q.device, dtype=q.dtype) max_seqlen_qo_aligned = (max_seqlen_qo + 7) // 8 * 8 bs = cu_seqlens_qo.shape[0] - 1 workspace_bytes = 0 workspace_bytes += 4 * bs * max_seqlen_qo_aligned * num_qo_heads * head_dim_qk # dQ_acc workspace_bytes += 4 * max_seqlen_qo_aligned * bs * num_qo_heads * 2 # sum_OdO and scaled_lse if num_qo_heads != num_kv_heads: workspace_bytes += 2 * kv_total_len * num_qo_heads * (head_dim_qk + head_dim_vo) # dKV_acc workspace_buffer = torch.empty(workspace_bytes, dtype=torch.uint8, device=q.device) flash_mla_sm100.bwd( workspace_buffer, do, q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv, dq, dk, dv, mask_mode_code, softmax_scale, max_seqlen_qo, max_seqlen_kv, is_varlen, ) return dq, dk, dv class FlashAttnVarlenFunc(torch.autograd.Function): def forward( ctx, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cu_seqlens_qo: torch.Tensor, cu_seqlens_kv: torch.Tensor, max_seqlen_qo: int, max_seqlen_kv: int, causal: bool = False, softmax_scale: Optional[float] = None, is_varlen: bool = True, ): out, lse = _flash_attn_varlen_forward( q, k, v, cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, causal=causal, softmax_scale=softmax_scale, is_varlen=is_varlen, ) ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv) ctx.max_seqlen_qo = max_seqlen_qo ctx.max_seqlen_kv = max_seqlen_kv ctx.causal = causal ctx.softmax_scale = softmax_scale ctx.is_varlen = is_varlen return out, lse def backward( ctx, do: torch.Tensor, dlse: torch.Tensor, ): del dlse # LSE doesn't support backward currently q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv = ctx.saved_tensors dq, dk, dv = _flash_attn_varlen_backward( do, q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv, ctx.max_seqlen_qo, ctx.max_seqlen_kv, causal=ctx.causal, softmax_scale=ctx.softmax_scale, is_varlen=ctx.is_varlen, ) return dq, dk, dv, None, None, None, None, None, None, None def flash_attn_varlen_func( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cu_seqlens_qo: torch.Tensor, cu_seqlens_kv: torch.Tensor, max_seqlen_qo: int, max_seqlen_kv: int, dropout_p: float = 0.0, softmax_scale: Optional[float] = None, causal: bool = False, deterministic: bool = False, is_varlen: bool = True, ) -> Tuple[torch.Tensor, torch.Tensor]: assert dropout_p == 0.0 assert not deterministic return FlashAttnVarlenFunc.apply( q, k, v, cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, causal, softmax_scale, is_varlen, ) def flash_attn_varlen_qkvpacked_func( qkv: torch.Tensor, cu_seqlens: torch.Tensor, max_seqlen: int, head_dim_qk: int, dropout_p: float = 0.0, softmax_scale: Optional[float] = None, causal: bool = False, deterministic: bool = False, is_varlen: bool = True, ) -> Tuple[torch.Tensor, torch.Tensor]: assert dropout_p == 0.0 assert not deterministic return FlashAttnVarlenFunc.apply( qkv[:, :, :head_dim_qk], qkv[:, :, head_dim_qk:head_dim_qk * 2], qkv[:, :, head_dim_qk * 2:], cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, causal, softmax_scale, is_varlen, ) def flash_attn_varlen_kvpacked_func( q: torch.Tensor, kv: torch.Tensor, cu_seqlens_qo: torch.Tensor, cu_seqlens_kv: torch.Tensor, max_seqlen_qo: int, max_seqlen_kv: int, head_dim_qk: int, dropout_p: float = 0.0, softmax_scale: Optional[float] = None, causal: bool = False, deterministic: bool = False, is_varlen: bool = True, ) -> Tuple[torch.Tensor, torch.Tensor]: assert dropout_p == 0.0 assert not deterministic return FlashAttnVarlenFunc.apply( q, kv[:, :, :head_dim_qk], kv[:, :, head_dim_qk:], cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, causal, softmax_scale, is_varlen, ) def flash_mla_with_kvcache_sm100( q: torch.Tensor, k_cache: torch.Tensor, block_table: torch.Tensor, cache_seqlens: torch.Tensor, head_dim_v: int, softmax_scale: Optional[float] = None, causal: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: # TODO pass def flash_mla_with_kvcache( q: torch.Tensor, k_cache: torch.Tensor, block_table: torch.Tensor, cache_seqlens: torch.Tensor, head_dim_v: int, tile_scheduler_metadata: Optional[torch.Tensor] = None, num_splits: Optional[torch.Tensor] = None, softmax_scale: Optional[float] = None, causal: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: capability = torch.cuda.get_device_capability(q.device.index) if capability == (9, 0): return flash_mla_with_kvcache_sm90( q, k_cache, block_table, cache_seqlens, head_dim_v, tile_scheduler_metadata, num_splits, softmax_scale, causal, ) elif capability == (10, 0): raise ValueError(f"Unsupported device capability: {capability}") else: raise ValueError(f"Unsupported device capability: {capability}")