Commit a3416fe1 authored by zhuwenwen
Browse files

Add FlashMLA support on ROCm

parent b79e20fe
......@@ -28,13 +28,17 @@ if current_platform.is_cuda():
else:
_flashmla_extension_C_AVAILABLE = False
if current_platform.is_rocm():
import flash_mla_cuda
_flashmla_C_AVAILABLE = True
def is_flashmla_supported() -> Tuple[bool, Optional[str]]:
"""
Return: is_supported_flag, unsupported_reason (optional).
"""
if not current_platform.is_cuda():
return False, "FlashMLA is only supported on CUDA devices."
if not (current_platform.is_cuda() or current_platform.is_rocm()):
return False, "FlashMLA is supported on CUDA and ROCM devices."
if current_platform.get_device_capability()[0] != 9:
return False, "FlashMLA is only supported on Hopper devices."
if not _flashmla_C_AVAILABLE:
......@@ -71,11 +75,18 @@ def get_mla_metadata(
(num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
- num_splits: (batch_size + 1), dtype torch.int32.
"""
if current_platform.is_rocm():
return flash_mla_cuda.get_mla_metadata(cache_seqlens,
num_q_tokens_per_head_k,
num_heads_k)
else:
return torch.ops._flashmla_C.get_mla_decoding_metadata(
cache_seqlens, num_q_tokens_per_head_k, num_heads_k, num_heads_q,
is_fp8_kvcache, topk)
def flash_mla_with_kvcache(
q: torch.Tensor,
k_cache: torch.Tensor,
......@@ -140,6 +151,11 @@ def flash_mla_with_kvcache(
out, softmax_lse = torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8(
q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale,
causal, tile_scheduler_metadata, num_splits, descale_q, descale_k)
else:
if current_platform.is_rocm():
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
q, k_cache, block_table, cache_seqlens, head_dim_v, tile_scheduler_metadata,
num_splits, softmax_scale, causal)
else:
out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale,
......
......@@ -204,6 +204,7 @@ if TYPE_CHECKING:
VLLM_USE_NCCL_SYMM_MEM: bool = False
VLLM_NCCL_INCLUDE_PATH: Optional[str] = None
VLLM_USE_FBGEMM: bool = False
VLLM_USE_FLASH_MLA: bool = False
def get_default_cache_root():
......@@ -1469,6 +1470,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: os.environ.get("VLLM_NCCL_INCLUDE_PATH", None),
# Flag to enable FBGemm kernels on model execution
"VLLM_USE_FBGEMM": lambda: bool(int(os.getenv("VLLM_USE_FBGEMM", "0"))),
# If set to a non-zero integer (e.g. 1), vLLM will try to use the FlashMLA
# attention backend. Defaults to 0 (disabled).
"VLLM_USE_FLASH_MLA":
lambda: bool(int(os.getenv("VLLM_USE_FLASH_MLA", "0"))),
}
# --8<-- [end:env-vars-definition]
......
......@@ -136,33 +136,34 @@ def use_rocm_custom_paged_attention(
alibi_slopes: Optional[torch.Tensor] = None,
sinks: Optional[torch.Tensor] = None) -> bool:
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])
# GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
# ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
# ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])
# custom paged attn always supported on V0. On V1, requires sliding window
# disabled due to observed numerical discrepancy.
if ON_GFX9:
return ((not envs.VLLM_USE_V1 or sliding_window == 0
or sliding_window == (-1, -1))
and (qtype == torch.half or qtype == torch.bfloat16)
and (head_size == 64 or head_size == 128)
and (block_size == 16 or block_size == 32)
and (gqa_ratio >= 1 and gqa_ratio <= 16)
and max_seq_len <= 128 * 1024
and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
and envs.VLLM_ROCM_USE_AITER) and sinks is None)
else:
return (ON_GFX11_GFX12 and (not envs.VLLM_USE_V1 or sliding_window == 0
or sliding_window == (-1, -1))
and (qtype == torch.half or qtype == torch.bfloat16)
and head_size == 128 and block_size == 16
and (gqa_ratio >= 3 and gqa_ratio <= 16)
and max_seq_len <= 128 * 1024 and alibi_slopes is None
and kv_cache_dtype == "auto"
and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN and sinks is None)
# if ON_GFX9:
# return ((not envs.VLLM_USE_V1 or sliding_window == 0
# or sliding_window == (-1, -1))
# and (qtype == torch.half or qtype == torch.bfloat16)
# and (head_size == 64 or head_size == 128)
# and (block_size == 16 or block_size == 32)
# and (gqa_ratio >= 1 and gqa_ratio <= 16)
# and max_seq_len <= 128 * 1024
# and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
# and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
# and envs.VLLM_ROCM_USE_AITER) and sinks is None)
# else:
# return (ON_GFX11_GFX12 and (not envs.VLLM_USE_V1 or sliding_window == 0
# or sliding_window == (-1, -1))
# and (qtype == torch.half or qtype == torch.bfloat16)
# and head_size == 128 and block_size == 16
# and (gqa_ratio >= 3 and gqa_ratio <= 16)
# and max_seq_len <= 128 * 1024 and alibi_slopes is None
# and kv_cache_dtype == "auto"
# and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN and sinks is None)
return False
class RocmPlatform(Platform):
......@@ -222,14 +223,15 @@ class RocmPlatform(Platform):
raise ValueError(
f" The selected backend, {selected_backend.name},"
f"does not support block size {block_size}.")
if selected_backend == _Backend.ROCM_AITER_MLA:
if block_size == 1:
logger.info("Using AITER MLA backend on V1 engine.")
return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501
raise ValueError(
f" The selected backend, {selected_backend.name},"
f"does not support block size {block_size}."
"(currently only supports block size 1)")
# if selected_backend == _Backend.ROCM_AITER_MLA:
# if block_size == 1:
# logger.info("Using AITER MLA backend on V1 engine.")
# return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501
# raise ValueError(
# f" The selected backend, {selected_backend.name},"
# f"does not support block size {block_size}."
# "(currently only supports block size 1)")
raise ValueError(
f" The selected backend, {selected_backend.name},"
f"is not MLA type while requested for MLA backend.")
......@@ -249,6 +251,21 @@ class RocmPlatform(Platform):
logger.info("Using Rocm/Aiter Attention backend on V1 engine.")
return ("vllm.v1.attention.backends."
"rocm_attn.RocmAttentionBackend")
if envs.VLLM_USE_FLASH_MLA:
from vllm.attention.ops.flashmla import is_flashmla_supported
use_flashmla = selected_backend == _Backend.FLASHMLA or (
selected_backend is None and is_flashmla_supported()[0])
if use_flashmla:
if block_size != 64:
logger.warning(
"FlashMLA backend is not supported for block size %d"
" (currently only supports block size 64).",
block_size)
else:
logger.info_once("Using FlashMLA backend on V1 engine.")
return ("vllm.v1.attention.backends.mla."
"flashmla.FlashMLABackend")
else:
# default case, using triton unified attention
logger.info("Using Triton Attention backend on V1 engine.")
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment