# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py
from typing import Optional, Tuple

import torch

from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.platforms.rocm import get_gcn_arch_name

logger = init_logger(__name__)

# Probe the compiled CUDA extensions. Each flag records whether the
# corresponding extension module could be imported on this platform.
if current_platform.is_cuda():
    try:
        import vllm._flashmla_C  # noqa: F401

        _flashmla_C_AVAILABLE = True
    except ImportError:
        _flashmla_C_AVAILABLE = False
else:
    _flashmla_C_AVAILABLE = False

if current_platform.is_cuda():
    try:
        import vllm._flashmla_extension_C  # noqa: F401

        _flashmla_extension_C_AVAILABLE = True
    except ImportError:
        _flashmla_extension_C_AVAILABLE = False
else:
    _flashmla_extension_C_AVAILABLE = False

if current_platform.is_rocm():
    # On ROCm the kernels come from the external ``flash_mla`` package rather
    # than from vLLM's own compiled extensions, so both flags are forced on.
    # NOTE(review): unlike the CUDA probes above, this import is not guarded;
    # a missing ``flash_mla`` package fails at module import time.
    # import flash_mla.cuda as flash_mla_cuda
    from flash_mla.flash_mla_interface import flash_mla_cuda

    _flashmla_C_AVAILABLE = True
    _flashmla_extension_C_AVAILABLE = True


def _is_flashmla_available() -> tuple[bool, str | None]:
    """Report whether the FlashMLA kernels could be loaded.

    Returns:
        ``(True, None)`` when both availability flags are set, otherwise
        ``(False, reason)`` where ``reason`` names the missing extension.
    """
    if not _flashmla_C_AVAILABLE:
        return (
            False,
            "vllm._flashmla_C is not available, likely was not "
            "compiled due to insufficient nvcc version or a supported arch "
            "was not in the list of target arches to compile for.",
        )
    # Only the flag is checked here. The previous additional
    # ``or not current_platform.is_rocm()`` clause was true on every non-ROCm
    # platform, so CUDA builds were always reported as unavailable (with a
    # misleading "build error" reason) even when the extension had loaded.
    # On ROCm the flag is forced to True above, so behaviour there is the same.
    if not _flashmla_extension_C_AVAILABLE:
        return (
            False,
            "vllm._flashmla_extension_C is not available, likely "
            "was not compiled due to a build error.",
        )
    return True, None


def is_flashmla_dense_supported() -> tuple[bool, str | None]:
    """Check whether the dense FlashMLA kernels can be used.

    Returns:
        ``(is_supported, unsupported_reason)``; the reason is ``None`` when
        supported. Requires the kernels to be available and the device to be
        in capability family 90.
    """
    is_available, maybe_reason = _is_flashmla_available()
    if not is_available:
        return False, maybe_reason
    if not current_platform.is_device_capability_family(90):
        return False, "FlashMLA Dense is only supported on Hopper devices."
    return True, None


def is_flashmla_sparse_supported() -> tuple[bool, str | None]:
    """Check whether the sparse FlashMLA kernels can be used.

    Returns:
        ``(is_supported, unsupported_reason)``; the reason is ``None`` when
        supported. Requires the kernels to be available and the device to be
        in capability family 90 or 100.
    """
    is_available, maybe_reason = _is_flashmla_available()
    if not is_available:
        return False, maybe_reason
    if not (
        current_platform.is_device_capability_family(90)
        or current_platform.is_device_capability_family(100)
    ):
        return (
            False,
            "FlashMLA Sparse is only supported on Hopper and Blackwell devices.",
        )
    return True, None


def _raise_flashmla_unavailable(*_args, **_kwargs):
    """Raise ``RuntimeError`` with the reason FlashMLA is unavailable.

    Accepts and ignores any arguments so it can stand in for the real kernel
    entry points when they could not be imported.
    """
    _, reason = _is_flashmla_available()
    raise RuntimeError(reason or "FlashMLA is not available")


if _is_flashmla_available()[0]:
    if current_platform.is_rocm():
        # The three ``flash_attn_varlen_*`` entry points are not imported on
        # ROCm, so those names are not bound in this module on that platform.
        from flash_mla.flash_mla_interface import (  # noqa: F401
            FlashMLASchedMeta,
            # flash_attn_varlen_func,
            # flash_attn_varlen_kvpacked_func,
            # flash_attn_varlen_qkvpacked_func,
            flash_mla_sparse_fwd,
            flash_mla_with_kvcache,
            get_mla_metadata,
        )
    else:
        from vllm.third_party.flashmla.flash_mla_interface import (  # noqa: F401
            FlashMLASchedMeta,
            flash_attn_varlen_func,
            flash_attn_varlen_kvpacked_func,
            flash_attn_varlen_qkvpacked_func,
            flash_mla_sparse_fwd,
            flash_mla_with_kvcache,
            get_mla_metadata,
        )
else:
    # Fallback stubs: importing these names always succeeds, and calling any
    # of the functions raises a RuntimeError explaining why FlashMLA is
    # unavailable.

    class FlashMLASchedMeta:  # type: ignore[no-redef]
        pass

    flash_attn_varlen_func = _raise_flashmla_unavailable  # type: ignore[assignment]
    flash_attn_varlen_kvpacked_func = _raise_flashmla_unavailable  # type: ignore[assignment]
    flash_attn_varlen_qkvpacked_func = _raise_flashmla_unavailable  # type: ignore[assignment]
    flash_mla_sparse_fwd = _raise_flashmla_unavailable  # type: ignore[assignment]
    flash_mla_with_kvcache = _raise_flashmla_unavailable  # type: ignore[assignment]
    get_mla_metadata = _raise_flashmla_unavailable  # type: ignore[assignment]


def get_mla_metadata_dense_fp8(
    cache_seqlens: torch.Tensor,
    num_q_tokens_per_head_k: int,
    num_heads_k: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute decoding metadata for the dense FP8 MLA kernel.

    Dispatches to ``flash_mla_cuda`` on ROCm and to the
    ``_flashmla_extension_C`` torch op otherwise; the arguments are forwarded
    unchanged.

    Returns:
        The two tensors produced by the backend's
        ``get_mla_decoding_metadata_dense_fp8``.

    Raises:
        RuntimeError: If FlashMLA is not available.
    """
    if not _is_flashmla_available()[0]:
        _raise_flashmla_unavailable()

    if current_platform.is_rocm():
        return flash_mla_cuda.get_mla_decoding_metadata_dense_fp8(
            cache_seqlens,
            num_q_tokens_per_head_k,
            num_heads_k,
        )
    else:
        return torch.ops._flashmla_extension_C.get_mla_decoding_metadata_dense_fp8(
            cache_seqlens,
            num_q_tokens_per_head_k,
            num_heads_k,
        )


def flash_mla_with_kvcache_fp8(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: float | None = None,
    causal: bool = False,
    descale_q: torch.Tensor | None = None,
    descale_k: torch.Tensor | None = None,
    kv_cache_dtype: str | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run the FP8 MLA decode kernel against a paged KV cache.

    Arguments:
        q: Query tensor; its last dimension is used for the default scale.
        k_cache: KV cache tensor.
        block_table: Block table for the paged cache.
        cache_seqlens: Per-sequence cache lengths.
        head_dim_v: Head dimension of v.
        tile_scheduler_metadata: Scheduler metadata, presumably as returned
            by get_mla_metadata_dense_fp8 (confirm against callers).
        num_splits: Split counts, presumably from the same metadata call.
        softmax_scale: Scale applied to QK^T. Defaults to
            ``q.shape[-1] ** -0.5``.
        causal: Whether to apply a causal attention mask.
        descale_q: Descaling factors for Q. Not forwarded on the ROCm
            non-gfx938 path.
        descale_k: Descaling factors for K.
        kv_cache_dtype: Only forwarded on the ROCm non-gfx938 path.

    Returns:
        ``(out, softmax_lse)`` as produced by the selected backend kernel.

    Raises:
        RuntimeError: If FlashMLA is not available.
    """
    if not _is_flashmla_available()[0]:
        _raise_flashmla_unavailable()

    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)

    if current_platform.is_rocm():
        # gfx938 uses the FP8 kernel (takes descale_q); every other ROCm arch
        # uses the quantization kernel (takes kv_cache_dtype instead).
        if get_gcn_arch_name() == "gfx938":
            out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla_fp8(
                q,
                k_cache,
                None,
                head_dim_v,
                cache_seqlens,
                block_table,
                softmax_scale,
                causal,
                tile_scheduler_metadata,
                num_splits,
                descale_q,
                descale_k,
            )
        else:
            out, softmax_lse = flash_mla_cuda.fwd_kvcache_quantization_mla(
                q,
                k_cache,
                None,
                head_dim_v,
                cache_seqlens,
                block_table,
                softmax_scale,
                causal,
                tile_scheduler_metadata,
                num_splits,
                descale_k,
                kv_cache_dtype,
            )
    else:
        out, softmax_lse = torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8(
            q,
            k_cache,
            head_dim_v,
            cache_seqlens,
            block_table,
            softmax_scale,
            causal,
            tile_scheduler_metadata,
            num_splits,
            descale_q,
            descale_k,
        )
    return out, softmax_lse


def flash_mla_with_kvcache_fp8_with_cat(
    q_nope: torch.Tensor,
    q_pe: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    descale_q: Optional[torch.Tensor] = None,
    descale_k: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    FP8 MLA decode taking the query as separate ``q_nope`` / ``q_pe`` parts.

    Only implemented through the ROCm ``flash_mla_cuda`` backend.

    Arguments:
        q_nope: (batch_size, seq_len_q, num_heads_q, 512).
        q_pe: (batch_size, seq_len_q, num_heads_q, 64).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head dimension of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
            torch.int32, returned by get_mla_metadata.
        num_splits: (batch_size + 1), torch.int32, returned by
            get_mla_metadata.
        softmax_scale: float. The scale of QK^T before applying softmax.
            Defaults to 1 / sqrt(q_nope.shape[-1] + q_pe.shape[-1]).
        causal: bool. Whether to apply causal attention mask.
        descale_q: (batch_size), torch.float32. Descaling factors for Q, used
            for fp8 quantization.
        descale_k: (batch_size), torch.float32. Descaling factors for K, used
            for fp8 quantization.

    Returns:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.

    Raises:
        RuntimeError: If FlashMLA is not available.
        NotImplementedError: If the current platform is not ROCm.
    """
    # Same availability guard as the sibling wrappers in this module.
    if not _is_flashmla_available()[0]:
        _raise_flashmla_unavailable()
    # ``flash_mla_cuda`` is only imported on ROCm; fail with a clear error
    # instead of a NameError on other platforms.
    if not current_platform.is_rocm():
        raise NotImplementedError(
            "flash_mla_with_kvcache_fp8_with_cat is only supported on ROCm."
        )

    if softmax_scale is None:
        softmax_scale = (q_nope.shape[-1] + q_pe.shape[-1]) ** (-0.5)
    out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla_fp8_with_cat(
        q_nope,
        q_pe,
        k_cache,
        None,
        head_dim_v,
        cache_seqlens,
        block_table,
        softmax_scale,
        causal,
        tile_scheduler_metadata,
        num_splits,
        descale_q,
        descale_k,
    )
    return out, softmax_lse


#
# TODO: Add fake functions
#
# @register_fake("_flashmla_C::get_mla_metadata")
# def _get_mla_metadata_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
#     return ....
#
# @register_fake("_flashmla_C::fwd_kvcache_mla")
# def _fwd_kvcache_mla_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
#     return ....
#