Commit a3416fe1 authored by zhuwenwen
Browse files

add flashmla support

parent b79e20fe
...@@ -27,14 +27,18 @@ if current_platform.is_cuda(): ...@@ -27,14 +27,18 @@ if current_platform.is_cuda():
_flashmla_extension_C_AVAILABLE = False _flashmla_extension_C_AVAILABLE = False
else: else:
_flashmla_extension_C_AVAILABLE = False _flashmla_extension_C_AVAILABLE = False
if current_platform.is_rocm():
import flash_mla_cuda
_flashmla_C_AVAILABLE = True
def is_flashmla_supported() -> Tuple[bool, Optional[str]]: def is_flashmla_supported() -> Tuple[bool, Optional[str]]:
""" """
Return: is_supported_flag, unsupported_reason (optional). Return: is_supported_flag, unsupported_reason (optional).
""" """
if not current_platform.is_cuda(): if not (current_platform.is_cuda() or current_platform.is_rocm()):
return False, "FlashMLA is only supported on CUDA devices." return False, "FlashMLA is supported on CUDA and ROCM devices."
if current_platform.get_device_capability()[0] != 9: if current_platform.get_device_capability()[0] != 9:
return False, "FlashMLA is only supported on Hopper devices." return False, "FlashMLA is only supported on Hopper devices."
if not _flashmla_C_AVAILABLE: if not _flashmla_C_AVAILABLE:
...@@ -71,11 +75,18 @@ def get_mla_metadata( ...@@ -71,11 +75,18 @@ def get_mla_metadata(
(num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
- num_splits: (batch_size + 1), dtype torch.int32. - num_splits: (batch_size + 1), dtype torch.int32.
""" """
return torch.ops._flashmla_C.get_mla_decoding_metadata(
if current_platform.is_rocm():
return flash_mla_cuda.get_mla_metadata(cache_seqlens,
num_q_tokens_per_head_k,
num_heads_k)
else:
return torch.ops._flashmla_C.get_mla_decoding_metadata(
cache_seqlens, num_q_tokens_per_head_k, num_heads_k, num_heads_q, cache_seqlens, num_q_tokens_per_head_k, num_heads_k, num_heads_q,
is_fp8_kvcache, topk) is_fp8_kvcache, topk)
def flash_mla_with_kvcache( def flash_mla_with_kvcache(
q: torch.Tensor, q: torch.Tensor,
k_cache: torch.Tensor, k_cache: torch.Tensor,
...@@ -141,10 +152,15 @@ def flash_mla_with_kvcache( ...@@ -141,10 +152,15 @@ def flash_mla_with_kvcache(
q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale, q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale,
causal, tile_scheduler_metadata, num_splits, descale_q, descale_k) causal, tile_scheduler_metadata, num_splits, descale_q, descale_k)
else: else:
out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla( if current_platform.is_rocm():
q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale, out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
causal, tile_scheduler_metadata, num_splits, is_fp8_kvcache, q, k_cache, block_table, cache_seqlens, head_dim_v, tile_scheduler_metadata,
indices) num_splits, softmax_scale, causal)
else:
out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale,
causal, tile_scheduler_metadata, num_splits, is_fp8_kvcache,
indices)
return out, softmax_lse return out, softmax_lse
......
...@@ -204,6 +204,7 @@ if TYPE_CHECKING: ...@@ -204,6 +204,7 @@ if TYPE_CHECKING:
VLLM_USE_NCCL_SYMM_MEM: bool = False VLLM_USE_NCCL_SYMM_MEM: bool = False
VLLM_NCCL_INCLUDE_PATH: Optional[str] = None VLLM_NCCL_INCLUDE_PATH: Optional[str] = None
VLLM_USE_FBGEMM: bool = False VLLM_USE_FBGEMM: bool = False
VLLM_USE_FLASH_MLA: bool = False
def get_default_cache_root(): def get_default_cache_root():
...@@ -1469,6 +1470,11 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1469,6 +1470,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: os.environ.get("VLLM_NCCL_INCLUDE_PATH", None), lambda: os.environ.get("VLLM_NCCL_INCLUDE_PATH", None),
# Flag to enable FBGemm kernels on model execution # Flag to enable FBGemm kernels on model execution
"VLLM_USE_FBGEMM": lambda: bool(int(os.getenv("VLLM_USE_FBGEMM", "0"))), "VLLM_USE_FBGEMM": lambda: bool(int(os.getenv("VLLM_USE_FBGEMM", "0"))),
# If set, vLLM will use FLASH MLA attention optimizations.
"VLLM_USE_FLASH_MLA":
lambda: bool(int(os.getenv("VLLM_USE_FLASH_MLA", "0"))),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -136,33 +136,34 @@ def use_rocm_custom_paged_attention( ...@@ -136,33 +136,34 @@ def use_rocm_custom_paged_attention(
alibi_slopes: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None,
sinks: Optional[torch.Tensor] = None) -> bool: sinks: Optional[torch.Tensor] = None) -> bool:
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName # GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"]) # ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"]) # ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])
# custom paged attn always supported on V0. On V1, requires sliding window # custom paged attn always supported on V0. On V1, requires sliding window
# disabled due to observed numerical discrepancy. # disabled due to observed numerical discrepancy.
if ON_GFX9: # if ON_GFX9:
return ((not envs.VLLM_USE_V1 or sliding_window == 0 # return ((not envs.VLLM_USE_V1 or sliding_window == 0
or sliding_window == (-1, -1)) # or sliding_window == (-1, -1))
and (qtype == torch.half or qtype == torch.bfloat16) # and (qtype == torch.half or qtype == torch.bfloat16)
and (head_size == 64 or head_size == 128) # and (head_size == 64 or head_size == 128)
and (block_size == 16 or block_size == 32) # and (block_size == 16 or block_size == 32)
and (gqa_ratio >= 1 and gqa_ratio <= 16) # and (gqa_ratio >= 1 and gqa_ratio <= 16)
and max_seq_len <= 128 * 1024 # and max_seq_len <= 128 * 1024
and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN) # and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN # and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
and envs.VLLM_ROCM_USE_AITER) and sinks is None) # and envs.VLLM_ROCM_USE_AITER) and sinks is None)
else: # else:
return (ON_GFX11_GFX12 and (not envs.VLLM_USE_V1 or sliding_window == 0 # return (ON_GFX11_GFX12 and (not envs.VLLM_USE_V1 or sliding_window == 0
or sliding_window == (-1, -1)) # or sliding_window == (-1, -1))
and (qtype == torch.half or qtype == torch.bfloat16) # and (qtype == torch.half or qtype == torch.bfloat16)
and head_size == 128 and block_size == 16 # and head_size == 128 and block_size == 16
and (gqa_ratio >= 3 and gqa_ratio <= 16) # and (gqa_ratio >= 3 and gqa_ratio <= 16)
and max_seq_len <= 128 * 1024 and alibi_slopes is None # and max_seq_len <= 128 * 1024 and alibi_slopes is None
and kv_cache_dtype == "auto" # and kv_cache_dtype == "auto"
and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN and sinks is None) # and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN and sinks is None)
return False
class RocmPlatform(Platform): class RocmPlatform(Platform):
...@@ -222,14 +223,15 @@ class RocmPlatform(Platform): ...@@ -222,14 +223,15 @@ class RocmPlatform(Platform):
raise ValueError( raise ValueError(
f" The selected backend, {selected_backend.name}," f" The selected backend, {selected_backend.name},"
f"does not support block size {block_size}.") f"does not support block size {block_size}.")
if selected_backend == _Backend.ROCM_AITER_MLA: # if selected_backend == _Backend.ROCM_AITER_MLA:
if block_size == 1: # if block_size == 1:
logger.info("Using AITER MLA backend on V1 engine.") # logger.info("Using AITER MLA backend on V1 engine.")
return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501 # return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501
raise ValueError( # raise ValueError(
f" The selected backend, {selected_backend.name}," # f" The selected backend, {selected_backend.name},"
f"does not support block size {block_size}." # f"does not support block size {block_size}."
"(currently only supports block size 1)") # "(currently only supports block size 1)")
raise ValueError( raise ValueError(
f" The selected backend, {selected_backend.name}," f" The selected backend, {selected_backend.name},"
f"is not MLA type while requested for MLA backend.") f"is not MLA type while requested for MLA backend.")
...@@ -249,6 +251,21 @@ class RocmPlatform(Platform): ...@@ -249,6 +251,21 @@ class RocmPlatform(Platform):
logger.info("Using Rocm/Aiter Attention backend on V1 engine.") logger.info("Using Rocm/Aiter Attention backend on V1 engine.")
return ("vllm.v1.attention.backends." return ("vllm.v1.attention.backends."
"rocm_attn.RocmAttentionBackend") "rocm_attn.RocmAttentionBackend")
if envs.VLLM_USE_FLASH_MLA:
from vllm.attention.ops.flashmla import is_flashmla_supported
use_flashmla = selected_backend == _Backend.FLASHMLA or (
selected_backend is None and is_flashmla_supported()[0])
if use_flashmla:
if block_size != 64:
logger.warning(
"FlashMLA backend is not supported for block size %d"
" (currently only supports block size 64).",
block_size)
else:
logger.info_once("Using FlashMLA backend on V1 engine.")
return ("vllm.v1.attention.backends.mla."
"flashmla.FlashMLABackend")
else: else:
# default case, using triton unified attention # default case, using triton unified attention
logger.info("Using Triton Attention backend on V1 engine.") logger.info("Using Triton Attention backend on V1 engine.")
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment