Commit 1e9788a3 authored by zhuwenwen's avatar zhuwenwen
Browse files

update flashmla support

parent eb38edbc
...@@ -154,8 +154,8 @@ def flash_mla_with_kvcache( ...@@ -154,8 +154,8 @@ def flash_mla_with_kvcache(
else: else:
if current_platform.is_rocm(): if current_platform.is_rocm():
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla( out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
q, k_cache, block_table, cache_seqlens, head_dim_v, tile_scheduler_metadata, q, k_cache, None, head_dim_v, cache_seqlens, block_table, softmax_scale,
num_splits, softmax_scale, causal) causal, tile_scheduler_metadata, num_splits)
else: else:
out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla( out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale, q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale,
......
...@@ -208,33 +208,48 @@ class RocmPlatform(Platform): ...@@ -208,33 +208,48 @@ class RocmPlatform(Platform):
# from vllm.v1.attention.backends.mla.rocm_aiter_mla import ( # from vllm.v1.attention.backends.mla.rocm_aiter_mla import (
# is_aiter_mla_enabled) # is_aiter_mla_enabled)
if envs.VLLM_USE_FLASH_MLA:
if selected_backend is None: from vllm.attention.ops.flashmla import is_flashmla_supported
# selected_backend = (_Backend.ROCM_AITER_MLA if use_flashmla = selected_backend == _Backend.FLASHMLA or (
# is_aiter_mla_enabled() or block_size == 1 selected_backend is None and is_flashmla_supported()[0])
# else _Backend.TRITON_MLA)
selected_backend = _Backend.TRITON_MLA if use_flashmla:
if block_size != 64:
if selected_backend == _Backend.TRITON_MLA: logger.warning(
if block_size != 1: "FlashMLA backend is not supported for block size %d"
logger.info_once("Using Triton MLA backend on V1 engine.") " (currently only supports block size 64).",
return ("vllm.v1.attention.backends.mla." block_size)
"triton_mla.TritonMLABackend") else:
logger.info_once("Using FlashMLA backend on V1 engine.")
return ("vllm.v1.attention.backends.mla."
"flashmla.FlashMLABackend")
else:
if selected_backend is None:
# selected_backend = (_Backend.ROCM_AITER_MLA if
# is_aiter_mla_enabled() or block_size == 1
# else _Backend.TRITON_MLA)
selected_backend = _Backend.TRITON_MLA
if selected_backend == _Backend.TRITON_MLA:
if block_size != 1:
logger.info_once("Using Triton MLA backend on V1 engine.")
return ("vllm.v1.attention.backends.mla."
"triton_mla.TritonMLABackend")
raise ValueError(
f" The selected backend, {selected_backend.name},"
f"does not support block size {block_size}.")
# if selected_backend == _Backend.ROCM_AITER_MLA:
# if block_size == 1:
# logger.info("Using AITER MLA backend on V1 engine.")
# return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501
# raise ValueError(
# f" The selected backend, {selected_backend.name},"
# f"does not support block size {block_size}."
# "(currently only supports block size 1)")
raise ValueError( raise ValueError(
f" The selected backend, {selected_backend.name}," f" The selected backend, {selected_backend.name},"
f"does not support block size {block_size}.") f"is not MLA type while requested for MLA backend.")
# if selected_backend == _Backend.ROCM_AITER_MLA:
# if block_size == 1:
# logger.info("Using AITER MLA backend on V1 engine.")
# return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501
# raise ValueError(
# f" The selected backend, {selected_backend.name},"
# f"does not support block size {block_size}."
# "(currently only supports block size 1)")
raise ValueError(
f" The selected backend, {selected_backend.name},"
f"is not MLA type while requested for MLA backend.")
if envs.VLLM_USE_V1: if envs.VLLM_USE_V1:
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA \ if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA \
...@@ -251,21 +266,6 @@ class RocmPlatform(Platform): ...@@ -251,21 +266,6 @@ class RocmPlatform(Platform):
logger.info("Using Rocm/Aiter Attention backend on V1 engine.") logger.info("Using Rocm/Aiter Attention backend on V1 engine.")
return ("vllm.v1.attention.backends." return ("vllm.v1.attention.backends."
"rocm_attn.RocmAttentionBackend") "rocm_attn.RocmAttentionBackend")
if envs.VLLM_USE_FLASH_MLA:
from vllm.attention.ops.flashmla import is_flashmla_supported
use_flashmla = selected_backend == _Backend.FLASHMLA or (
selected_backend is None and is_flashmla_supported()[0])
if use_flashmla:
if block_size != 64:
logger.warning(
"FlashMLA backend is not supported for block size %d"
" (currently only supports block size 64).",
block_size)
else:
logger.info_once("Using FlashMLA backend on V1 engine.")
return ("vllm.v1.attention.backends.mla."
"flashmla.FlashMLABackend")
else: else:
# default case, using triton unified attention # default case, using triton unified attention
logger.info("Using Triton Attention backend on V1 engine.") logger.info("Using Triton Attention backend on V1 engine.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment