"vscode:/vscode.git/clone" did not exist on "dddbff46242a9292085e2ae3309dc559f242cad6"
Commit 1e9788a3 authored by zhuwenwen's avatar zhuwenwen
Browse files

update flashmla support

parent eb38edbc
......@@ -154,8 +154,8 @@ def flash_mla_with_kvcache(
else:
if current_platform.is_rocm():
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
q, k_cache, block_table, cache_seqlens, head_dim_v, tile_scheduler_metadata,
num_splits, softmax_scale, causal)
q, k_cache, None, head_dim_v, cache_seqlens, block_table, softmax_scale,
causal, tile_scheduler_metadata, num_splits)
else:
out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale,
......
......@@ -208,33 +208,48 @@ class RocmPlatform(Platform):
# from vllm.v1.attention.backends.mla.rocm_aiter_mla import (
# is_aiter_mla_enabled)
if selected_backend is None:
# selected_backend = (_Backend.ROCM_AITER_MLA if
# is_aiter_mla_enabled() or block_size == 1
# else _Backend.TRITON_MLA)
selected_backend = _Backend.TRITON_MLA
if selected_backend == _Backend.TRITON_MLA:
if block_size != 1:
logger.info_once("Using Triton MLA backend on V1 engine.")
return ("vllm.v1.attention.backends.mla."
"triton_mla.TritonMLABackend")
if envs.VLLM_USE_FLASH_MLA:
from vllm.attention.ops.flashmla import is_flashmla_supported
use_flashmla = selected_backend == _Backend.FLASHMLA or (
selected_backend is None and is_flashmla_supported()[0])
if use_flashmla:
if block_size != 64:
logger.warning(
"FlashMLA backend is not supported for block size %d"
" (currently only supports block size 64).",
block_size)
else:
logger.info_once("Using FlashMLA backend on V1 engine.")
return ("vllm.v1.attention.backends.mla."
"flashmla.FlashMLABackend")
else:
if selected_backend is None:
# selected_backend = (_Backend.ROCM_AITER_MLA if
# is_aiter_mla_enabled() or block_size == 1
# else _Backend.TRITON_MLA)
selected_backend = _Backend.TRITON_MLA
if selected_backend == _Backend.TRITON_MLA:
if block_size != 1:
logger.info_once("Using Triton MLA backend on V1 engine.")
return ("vllm.v1.attention.backends.mla."
"triton_mla.TritonMLABackend")
raise ValueError(
f" The selected backend, {selected_backend.name},"
f"does not support block size {block_size}.")
# if selected_backend == _Backend.ROCM_AITER_MLA:
# if block_size == 1:
# logger.info("Using AITER MLA backend on V1 engine.")
# return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501
# raise ValueError(
# f" The selected backend, {selected_backend.name},"
# f"does not support block size {block_size}."
# "(currently only supports block size 1)")
raise ValueError(
f" The selected backend, {selected_backend.name},"
f"does not support block size {block_size}.")
# if selected_backend == _Backend.ROCM_AITER_MLA:
# if block_size == 1:
# logger.info("Using AITER MLA backend on V1 engine.")
# return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501
# raise ValueError(
# f" The selected backend, {selected_backend.name},"
# f"does not support block size {block_size}."
# "(currently only supports block size 1)")
raise ValueError(
f" The selected backend, {selected_backend.name},"
f"is not MLA type while requested for MLA backend.")
f"is not MLA type while requested for MLA backend.")
if envs.VLLM_USE_V1:
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA \
......@@ -251,21 +266,6 @@ class RocmPlatform(Platform):
logger.info("Using Rocm/Aiter Attention backend on V1 engine.")
return ("vllm.v1.attention.backends."
"rocm_attn.RocmAttentionBackend")
if envs.VLLM_USE_FLASH_MLA:
from vllm.attention.ops.flashmla import is_flashmla_supported
use_flashmla = selected_backend == _Backend.FLASHMLA or (
selected_backend is None and is_flashmla_supported()[0])
if use_flashmla:
if block_size != 64:
logger.warning(
"FlashMLA backend is not supported for block size %d"
" (currently only supports block size 64).",
block_size)
else:
logger.info_once("Using FlashMLA backend on V1 engine.")
return ("vllm.v1.attention.backends.mla."
"flashmla.FlashMLABackend")
else:
# default case, using triton unified attention
logger.info("Using Triton Attention backend on V1 engine.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment