Commit a0e22db9 authored by zhuwenwen
Browse files

Add the VLLM_USE_FLASH_MLA environment variable to select the FlashMLA backend

parent c91e1a7c
...@@ -18,6 +18,7 @@ if TYPE_CHECKING: ...@@ -18,6 +18,7 @@ if TYPE_CHECKING:
LD_LIBRARY_PATH: Optional[str] = None LD_LIBRARY_PATH: Optional[str] = None
VLLM_USE_TRITON_FLASH_ATTN: bool = False VLLM_USE_TRITON_FLASH_ATTN: bool = False
VLLM_USE_TRITON_OPT_MLA: bool = False VLLM_USE_TRITON_OPT_MLA: bool = False
VLLM_USE_FLASH_MLA: bool = False
VLLM_USE_OPT_OP: bool = False VLLM_USE_OPT_OP: bool = False
VLLM_USE_TC_PAGED_ATTN: bool = False VLLM_USE_TC_PAGED_ATTN: bool = False
VLLM_USE_PA_PRINT_PARAM: bool = False VLLM_USE_PA_PRINT_PARAM: bool = False
...@@ -654,6 +655,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -654,6 +655,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_TRITON_OPT_MLA": "VLLM_USE_TRITON_OPT_MLA":
lambda: bool(int(os.getenv("VLLM_USE_TRITON_OPT_MLA", "0"))), lambda: bool(int(os.getenv("VLLM_USE_TRITON_OPT_MLA", "0"))),
# If set to 1 (the default), vLLM will use the FlashMLA attention backend for MLA on ROCm; set to 0 to use the Triton MLA backend instead.
"VLLM_USE_FLASH_MLA":
lambda: bool(int(os.getenv("VLLM_USE_FLASH_MLA", "1"))),
# If set, vLLM will use the Triton implementation of moe_align_block_size, # If set, vLLM will use the Triton implementation of moe_align_block_size,
# i.e. moe_align_block_size_triton in fused_moe.py. # i.e. moe_align_block_size_triton in fused_moe.py.
"VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON": "VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON":
......
...@@ -138,9 +138,6 @@ class RocmPlatform(Platform): ...@@ -138,9 +138,6 @@ class RocmPlatform(Platform):
kv_cache_dtype, block_size, use_v1, kv_cache_dtype, block_size, use_v1,
use_mla) -> str: use_mla) -> str:
if use_mla: if use_mla:
# logger.info("Using Triton MLA backend.")
# return "vllm.attention.backends.triton_mla.TritonMLABackend"
if selected_backend == _Backend.TRITON_MLA or block_size != 64: if selected_backend == _Backend.TRITON_MLA or block_size != 64:
if use_v1: if use_v1:
logger.info_once("Using Triton MLA backend on V1 engine.") logger.info_once("Using Triton MLA backend on V1 engine.")
...@@ -150,6 +147,7 @@ class RocmPlatform(Platform): ...@@ -150,6 +147,7 @@ class RocmPlatform(Platform):
logger.info("Using Triton MLA backend.") logger.info("Using Triton MLA backend.")
return "vllm.attention.backends.triton_mla.TritonMLABackend" return "vllm.attention.backends.triton_mla.TritonMLABackend"
else: else:
if envs.VLLM_USE_FLASH_MLA:
from vllm.attention.backends.flashmla import ( from vllm.attention.backends.flashmla import (
is_flashmla_supported) is_flashmla_supported)
if not is_flashmla_supported()[0]: if not is_flashmla_supported()[0]:
...@@ -171,6 +169,9 @@ class RocmPlatform(Platform): ...@@ -171,6 +169,9 @@ class RocmPlatform(Platform):
logger.info("Using FlashMLA backend.") logger.info("Using FlashMLA backend.")
return ("vllm.attention.backends." return ("vllm.attention.backends."
"flashmla.FlashMLABackend") "flashmla.FlashMLABackend")
else:
logger.info("Using Triton MLA backend (block size 64).")
return "vllm.attention.backends.triton_mla.TritonMLABackend"
selected_backend = (_Backend.ROCM_FLASH if selected_backend selected_backend = (_Backend.ROCM_FLASH if selected_backend
== _Backend.FLASH_ATTN else selected_backend) == _Backend.FLASH_ATTN else selected_backend)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment