Commit 1e9788a3 authored by zhuwenwen
Browse files

update flashmla support

parent eb38edbc
......@@ -154,8 +154,8 @@ def flash_mla_with_kvcache(
else:
if current_platform.is_rocm():
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
q, k_cache, block_table, cache_seqlens, head_dim_v, tile_scheduler_metadata,
num_splits, softmax_scale, causal)
q, k_cache, None, head_dim_v, cache_seqlens, block_table, softmax_scale,
causal, tile_scheduler_metadata, num_splits)
else:
out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale,
......
......@@ -208,7 +208,22 @@ class RocmPlatform(Platform):
# from vllm.v1.attention.backends.mla.rocm_aiter_mla import (
# is_aiter_mla_enabled)
if envs.VLLM_USE_FLASH_MLA:
from vllm.attention.ops.flashmla import is_flashmla_supported
use_flashmla = selected_backend == _Backend.FLASHMLA or (
selected_backend is None and is_flashmla_supported()[0])
if use_flashmla:
if block_size != 64:
logger.warning(
"FlashMLA backend is not supported for block size %d"
" (currently only supports block size 64).",
block_size)
else:
logger.info_once("Using FlashMLA backend on V1 engine.")
return ("vllm.v1.attention.backends.mla."
"flashmla.FlashMLABackend")
else:
if selected_backend is None:
# selected_backend = (_Backend.ROCM_AITER_MLA if
# is_aiter_mla_enabled() or block_size == 1
......@@ -251,21 +266,6 @@ class RocmPlatform(Platform):
logger.info("Using Rocm/Aiter Attention backend on V1 engine.")
return ("vllm.v1.attention.backends."
"rocm_attn.RocmAttentionBackend")
if envs.VLLM_USE_FLASH_MLA:
from vllm.attention.ops.flashmla import is_flashmla_supported
use_flashmla = selected_backend == _Backend.FLASHMLA or (
selected_backend is None and is_flashmla_supported()[0])
if use_flashmla:
if block_size != 64:
logger.warning(
"FlashMLA backend is not supported for block size %d"
" (currently only supports block size 64).",
block_size)
else:
logger.info_once("Using FlashMLA backend on V1 engine.")
return ("vllm.v1.attention.backends.mla."
"flashmla.FlashMLABackend")
else:
# default case, using triton unified attention
logger.info("Using Triton Attention backend on V1 engine.")
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment