# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py
from typing import Optional, Tuple

import torch

from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm import envs  # noqa: F401

logger = init_logger(__name__)

if current_platform.is_cuda():
    try:
        import vllm._flashmla_C  # noqa: F401
        _flashmla_C_AVAILABLE = True
    except ImportError:
        _flashmla_C_AVAILABLE = False
else:
    _flashmla_C_AVAILABLE = False

if current_platform.is_rocm():
    import flash_mla_cuda
    _flashmla_C_AVAILABLE = True

# KV-cache dtype strings that select the quantized ROCm kernels.
_FP8_KV_CACHE_DTYPES = ("fp8", "fp8_e4m3", "fp8_e5m2")


def _require_flash_mla_cuda(api_name: str) -> None:
    """Fail with a clear error when a ROCm-only entry point is used elsewhere.

    ``flash_mla_cuda`` is only imported when ``current_platform.is_rocm()``
    is true, so without this check the callers below would die with a bare
    ``NameError`` on any other platform.
    """
    if not current_platform.is_rocm():
        raise NotImplementedError(
            f"{api_name} requires the flash_mla_cuda extension, which is "
            "only loaded on ROCm platforms.")


def _quantized_kv_dtype(kv_cache_dtype) -> Optional[str]:
    """Map ``kv_cache_dtype`` to the dtype string for the quantized kernels.

    Returns ``None`` when ``kv_cache_dtype`` is not one of the fp8 variants,
    meaning the non-quantized kernel should be used. The generic ``"fp8"``
    alias resolves to ``"fp8_e4m3"``; the explicit variants pass through.
    """
    if kv_cache_dtype not in _FP8_KV_CACHE_DTYPES:
        return None
    return "fp8_e4m3" if kv_cache_dtype == "fp8" else kv_cache_dtype


def is_flashmla_supported() -> Tuple[bool, Optional[str]]:
    """
    Return: is_supported_flag, unsupported_reason (optional).
    """
    if not (current_platform.is_cuda() or current_platform.is_rocm()):
        return False, "FlashMLA is only supported on CUDA and ROCm devices."
    # NOTE(review): the capability check is applied on ROCm as well;
    # confirm that the targeted ROCm devices report a major version of 9.
    if current_platform.get_device_capability()[0] != 9:
        return False, "FlashMLA is only supported on Hopper devices."
    if not _flashmla_C_AVAILABLE:
        return False, (
            "vllm._flashmla_C is not available, likely was not "
            "compiled due to insufficient nvcc version or a supported arch "
            "(only sm90a currently) was not in the list of target arches to "
            "compile for.")
    return True, None


def get_mla_metadata(
    cache_seqlens: torch.Tensor,
    num_heads_per_head_k: int,
    num_heads_k: int,
    num_heads_q: Optional[int] = None,
    is_fp8_kvcache: bool = False,
    topk: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Arguments:
        cache_seqlens: (batch_size), dtype torch.int32.
        num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k.
        num_heads_k: num_heads_k.
        num_heads_q: Forwarded to the ROCm kernel only; ignored on CUDA.
        is_fp8_kvcache: Forwarded to the ROCm kernel only; ignored on CUDA.
        topk: Forwarded to the ROCm kernel only; ignored on CUDA.

    Return:
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
            dtype torch.int32.
        num_splits: (batch_size + 1), dtype torch.int32.
    """
    if current_platform.is_rocm():
        return flash_mla_cuda.get_mla_metadata(
            cache_seqlens,
            num_heads_per_head_k,
            num_heads_k,
            num_heads_q,
            is_fp8_kvcache,
            topk,
        )
    return torch.ops._flashmla_C.get_mla_metadata(
        cache_seqlens,
        num_heads_per_head_k,
        num_heads_k,
    )


def get_mla_decoding_metadata_dense_fp8(
    cache_seqlens: torch.Tensor,
    num_heads_per_head_k: int,
    num_heads_k: int,
    num_heads_q: int = 16,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    ROCm only: this calls into the flash_mla_cuda extension.

    Arguments:
        cache_seqlens: (batch_size), dtype torch.int32.
        num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k.
        num_heads_k: num_heads_k.
        num_heads_q: Forwarded unchanged to the kernel. Defaults to 16.

    Return:
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
            dtype torch.int32.
        num_splits: (batch_size + 1), dtype torch.int32.

    Raises:
        NotImplementedError: when the current platform is not ROCm.
    """
    _require_flash_mla_cuda("get_mla_decoding_metadata_dense_fp8")
    return flash_mla_cuda.get_mla_decoding_metadata_dense_fp8(
        cache_seqlens,
        num_heads_per_head_k,
        num_heads_k,
        num_heads_q,
    )


def flash_mla_with_kvcache(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    is_fp8_kvcache: bool = False,
    indices: Optional[torch.Tensor] = None,
    k_scale=None,
    kv_cache_dtype: str = "auto",
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Arguments:
        q: (batch_size, seq_len_q, num_heads_q, head_dim).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head_dim of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
            torch.int32, return by get_mla_metadata.
        num_splits: (batch_size + 1), torch.int32, return by get_mla_metadata.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(head_dim).
        causal: bool. Whether to apply causal attention mask.
        is_fp8_kvcache: Passed to the non-quantized ROCm kernel only.
        indices: Passed to the non-quantized ROCm kernel only.
        k_scale: Passed to the quantized ROCm kernel only.
        kv_cache_dtype: On ROCm, "fp8", "fp8_e4m3" or "fp8_e5m2" selects the
            quantized kernel ("fp8" is treated as "fp8_e4m3"); any other
            value selects the non-quantized kernel. Ignored on CUDA.

    Return:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    """
    if softmax_scale is None:
        softmax_scale = q.shape[-1]**(-0.5)

    if not current_platform.is_rocm():
        # The CUDA op takes none of the ROCm-specific extras.
        out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
            q,
            k_cache,
            None,
            head_dim_v,
            cache_seqlens,
            block_table,
            softmax_scale,
            causal,
            tile_scheduler_metadata,
            num_splits,
        )
        return out, softmax_lse

    kv_dtype = _quantized_kv_dtype(kv_cache_dtype)
    if kv_dtype is not None:
        out, softmax_lse = flash_mla_cuda.fwd_kvcache_quantization_mla(
            q,
            k_cache,
            None,
            head_dim_v,
            cache_seqlens,
            block_table,
            softmax_scale,
            causal,
            tile_scheduler_metadata,
            num_splits,
            k_scale,
            kv_dtype,
        )
        return out, softmax_lse

    out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
        q,
        k_cache,
        head_dim_v,
        cache_seqlens,
        block_table,
        softmax_scale,
        causal,
        tile_scheduler_metadata,
        num_splits,
        is_fp8_kvcache,
        indices,
    )
    return out, softmax_lse


def flash_mla_with_kvcache_q_nope_pe(
    q_nope: torch.Tensor,
    q_pe: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    k_scale=None,
    kv_cache_dtype: str = "auto",
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Variant of flash_mla_with_kvcache that takes the query as two tensors,
    q_nope and q_pe. ROCm only: both kernels live in the flash_mla_cuda
    extension.

    Arguments:
        q_nope: First part of the query; its last dimension contributes to
            the default softmax scale.
        q_pe: Second part of the query; its last dimension contributes to
            the default softmax scale.
        k_cache, block_table, cache_seqlens, head_dim_v,
        tile_scheduler_metadata, num_splits, causal: Forwarded unchanged to
            the kernel.
        softmax_scale: float. Defaults to
            1 / sqrt(q_nope.shape[-1] + q_pe.shape[-1]).
        k_scale: Passed to the quantized kernel only.
        kv_cache_dtype: "fp8", "fp8_e4m3" or "fp8_e5m2" selects the
            quantized kernel ("fp8" is treated as "fp8_e4m3"); any other
            value selects the non-quantized kernel.

    Return:
        out, softmax_lse: the two values returned by the selected kernel.

    Raises:
        NotImplementedError: when the current platform is not ROCm.
    """
    _require_flash_mla_cuda("flash_mla_with_kvcache_q_nope_pe")

    if softmax_scale is None:
        softmax_scale = (q_nope.shape[-1] + q_pe.shape[-1]) ** (-0.5)

    kv_dtype = _quantized_kv_dtype(kv_cache_dtype)
    if kv_dtype is not None:
        out, softmax_lse = (
            flash_mla_cuda.fwd_kvcache_quantization_q_nope_pe_mla(
                q_nope,
                q_pe,
                k_cache,
                None,
                head_dim_v,
                cache_seqlens,
                block_table,
                softmax_scale,
                causal,
                tile_scheduler_metadata,
                num_splits,
                k_scale,
                kv_dtype,
            ))
        return out, softmax_lse

    out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla_nope_pe(
        q_nope,
        q_pe,
        k_cache,
        None,
        head_dim_v,
        cache_seqlens,
        block_table,
        softmax_scale,
        causal,
        tile_scheduler_metadata,
        num_splits,
    )
    return out, softmax_lse


def flash_mla_with_kvcache_fp8(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    descale_q: Optional[torch.Tensor] = None,
    descale_k: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    ROCm only: this calls into the flash_mla_cuda extension.

    Arguments:
        q: (batch_size, seq_len_q, num_heads_q, head_dim).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head_dim of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
            torch.int32, return by get_mla_decoding_metadata_dense_fp8.
        num_splits: (batch_size + 1), torch.int32, return by
            get_mla_decoding_metadata_dense_fp8.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(head_dim).
        causal: bool. Whether to apply causal attention mask.
        descale_q: (batch_size), torch.float32. Descaling factors for Q,
            used for fp8 quantization.
        descale_k: (batch_size), torch.float32. Descaling factors for K,
            used for fp8 quantization.

    Return:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.

    Raises:
        NotImplementedError: when the current platform is not ROCm.
    """
    _require_flash_mla_cuda("flash_mla_with_kvcache_fp8")

    if softmax_scale is None:
        softmax_scale = q.shape[-1]**(-0.5)

    out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla_fp8(
        q,
        k_cache,
        None,
        head_dim_v,
        cache_seqlens,
        block_table,
        softmax_scale,
        causal,
        tile_scheduler_metadata,
        num_splits,
        descale_q,
        descale_k,
    )
    return out, softmax_lse


def flash_mla_with_kvcache_fp8_with_cat(
    q_nope: torch.Tensor,
    q_pe: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    descale_q: Optional[torch.Tensor] = None,
    descale_k: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    ROCm only: this calls into the flash_mla_cuda extension.

    Arguments:
        q_nope: (batch_size, seq_len_q, num_heads_q, 512).
        q_pe: (batch_size, seq_len_q, num_heads_q, 64).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head dimension of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
            torch.int32, returned by get_mla_metadata.
        num_splits: (batch_size + 1), torch.int32, returned by
            get_mla_metadata.
        softmax_scale: float. The scale of QK^T before applying softmax.
            Default to 1 / sqrt(head_dim), where head_dim is
            q_nope.shape[-1] + q_pe.shape[-1].
        causal: bool. Whether to apply causal attention mask.
        descale_q: (batch_size), torch.float32. Descaling factors for Q,
            used for fp8 quantization.
        descale_k: (batch_size), torch.float32. Descaling factors for K,
            used for fp8 quantization.

    Returns:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.

    Raises:
        NotImplementedError: when the current platform is not ROCm.
    """
    _require_flash_mla_cuda("flash_mla_with_kvcache_fp8_with_cat")

    if softmax_scale is None:
        softmax_scale = (q_nope.shape[-1] + q_pe.shape[-1]) ** (-0.5)

    out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla_fp8_with_cat(
        q_nope,
        q_pe,
        k_cache,
        None,
        head_dim_v,
        cache_seqlens,
        block_table,
        softmax_scale,
        causal,
        tile_scheduler_metadata,
        num_splits,
        descale_q,
        descale_k,
    )
    return out, softmax_lse


#
# TODO: Add fake functions
#
# @register_fake("_flashmla_C::get_mla_metadata")
# def _get_mla_metadata_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
#     return ....
#
# @register_fake("_flashmla_C::fwd_kvcache_mla")
# def _fwd_kvcache_mla_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
#     return ....
#