Commit bb1d0df8 authored by zhuwenwen's avatar zhuwenwen
Browse files

support flashmla backend

parent 23607ca0
...@@ -18,13 +18,16 @@ if current_platform.is_cuda(): ...@@ -18,13 +18,16 @@ if current_platform.is_cuda():
else: else:
_flashmla_C_AVAILABLE = False _flashmla_C_AVAILABLE = False
if current_platform.is_rocm():
import flash_mla_cuda
_flashmla_C_AVAILABLE = True
def is_flashmla_supported() -> Tuple[bool, Optional[str]]: def is_flashmla_supported() -> Tuple[bool, Optional[str]]:
""" """
Return: is_supported_flag, unsupported_reason (optional). Return: is_supported_flag, unsupported_reason (optional).
""" """
if not current_platform.is_cuda(): if not (current_platform.is_cuda() or current_platform.is_rocm()):
return False, "FlashMLA is only supported on CUDA devices." return False, "FlashMLA is supported on CUDA and ROCM devices."
if current_platform.get_device_capability()[0] != 9: if current_platform.get_device_capability()[0] != 9:
return False, "FlashMLA is only supported on Hopper devices." return False, "FlashMLA is only supported on Hopper devices."
if not _flashmla_C_AVAILABLE: if not _flashmla_C_AVAILABLE:
...@@ -51,9 +54,14 @@ def get_mla_metadata( ...@@ -51,9 +54,14 @@ def get_mla_metadata(
dtype torch.int32. dtype torch.int32.
num_splits: (batch_size + 1), dtype torch.int32. num_splits: (batch_size + 1), dtype torch.int32.
""" """
return torch.ops._flashmla_C.get_mla_metadata(cache_seqlens, if current_platform.is_rocm():
num_heads_per_head_k, return flash_mla_cuda.get_mla_metadata(cache_seqlens,
num_heads_k) num_heads_per_head_k,
num_heads_k)
else:
return torch.ops._flashmla_C.get_mla_metadata(cache_seqlens,
num_heads_per_head_k,
num_heads_k)
def flash_mla_with_kvcache( def flash_mla_with_kvcache(
...@@ -87,18 +95,32 @@ def flash_mla_with_kvcache( ...@@ -87,18 +95,32 @@ def flash_mla_with_kvcache(
""" """
if softmax_scale is None: if softmax_scale is None:
softmax_scale = q.shape[-1]**(-0.5) softmax_scale = q.shape[-1]**(-0.5)
out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla( if current_platform.is_rocm():
q, out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
k_cache, q,
None, k_cache,
head_dim_v, None,
cache_seqlens, head_dim_v,
block_table, cache_seqlens,
softmax_scale, block_table,
causal, softmax_scale,
tile_scheduler_metadata, causal,
num_splits, tile_scheduler_metadata,
) num_splits,
)
else:
out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
q,
k_cache,
None,
head_dim_v,
cache_seqlens,
block_table,
softmax_scale,
causal,
tile_scheduler_metadata,
num_splits,
)
return out, softmax_lse return out, softmax_lse
...@@ -112,4 +134,4 @@ def flash_mla_with_kvcache( ...@@ -112,4 +134,4 @@ def flash_mla_with_kvcache(
# @register_fake("_flashmla_C::fwd_kvcache_mla") # @register_fake("_flashmla_C::fwd_kvcache_mla")
# def _fwd_kvcache_mla_fake(....) -> Tuple[torch.Tensor, torch.Tensor]: # def _fwd_kvcache_mla_fake(....) -> Tuple[torch.Tensor, torch.Tensor]:
# return .... # return ....
# #
\ No newline at end of file
...@@ -138,8 +138,40 @@ class RocmPlatform(Platform): ...@@ -138,8 +138,40 @@ class RocmPlatform(Platform):
kv_cache_dtype, block_size, use_v1, kv_cache_dtype, block_size, use_v1,
use_mla) -> str: use_mla) -> str:
if use_mla: if use_mla:
logger.info("Using Triton MLA backend.") # logger.info("Using Triton MLA backend.")
return "vllm.attention.backends.triton_mla.TritonMLABackend" # return "vllm.attention.backends.triton_mla.TritonMLABackend"
if selected_backend == _Backend.TRITON_MLA or block_size != 64:
if use_v1:
logger.info_once("Using Triton MLA backend on V1 engine.")
return ("vllm.v1.attention.backends.mla."
"triton_mla.TritonMLABackend")
else:
logger.info("Using Triton MLA backend.")
return "vllm.attention.backends.triton_mla.TritonMLABackend"
else:
from vllm.attention.backends.flashmla import (
is_flashmla_supported)
if not is_flashmla_supported()[0]:
logger.warning(
"FlashMLA backend is not supported due to %s",
is_flashmla_supported()[1])
elif block_size != 64:
logger.warning(
"FlashMLA backend is not supported for block size %d"
" (currently only supports block size 64).",
block_size)
else:
if use_v1:
logger.info_once(
"Using FlashMLA backend on V1 engine.")
return ("vllm.v1.attention.backends.mla."
"flashmla.FlashMLABackend")
else:
logger.info("Using FlashMLA backend.")
return ("vllm.attention.backends."
"flashmla.FlashMLABackend")
selected_backend = (_Backend.ROCM_FLASH if selected_backend selected_backend = (_Backend.ROCM_FLASH if selected_backend
== _Backend.FLASH_ATTN else selected_backend) == _Backend.FLASH_ATTN else selected_backend)
if envs.VLLM_USE_V1: if envs.VLLM_USE_V1:
...@@ -311,4 +343,4 @@ class RocmPlatform(Platform): ...@@ -311,4 +343,4 @@ class RocmPlatform(Platform):
# We only enable custom allreduce for MI300 series # We only enable custom allreduce for MI300 series
gcn_arch = torch.cuda.get_device_properties(0).gcnArchName gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
supported_archs = ['gfx94'] supported_archs = ['gfx94']
return any(gfx in gcn_arch for gfx in supported_archs) return any(gfx in gcn_arch for gfx in supported_archs)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment