Unverified commit b0d54194, authored by Lucas Wilkinson, committed by GitHub
Browse files

[Attention] Default to FlashMLA backend for MLA (#14451)


Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
parent 5f0b53c6
...@@ -112,6 +112,7 @@ class CudaPlatformBase(Platform): ...@@ -112,6 +112,7 @@ class CudaPlatformBase(Platform):
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
scheduler_config = vllm_config.scheduler_config scheduler_config = vllm_config.scheduler_config
compilation_config = vllm_config.compilation_config compilation_config = vllm_config.compilation_config
model_config = vllm_config.model_config
if parallel_config.worker_cls == "auto": if parallel_config.worker_cls == "auto":
if scheduler_config.is_multi_step: if scheduler_config.is_multi_step:
...@@ -142,14 +143,21 @@ class CudaPlatformBase(Platform): ...@@ -142,14 +143,21 @@ class CudaPlatformBase(Platform):
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
if cache_config and cache_config.block_size is None: if cache_config and cache_config.block_size is None:
cache_config.block_size = 16 cache_config.block_size = 16
# TODO(lucas): handle this more gracefully # TODO(lucas): handle this more gracefully
if envs.VLLM_ATTENTION_BACKEND is not None \ # Note: model_config may be None during testing
and envs.VLLM_ATTENTION_BACKEND == "FLASHMLA" \ if model_config is not None and model_config.use_mla:
# if `VLLM_ATTENTION_BACKEND` is not set and we are using MLA, then
# we default to FlashMLA backend, so we need to force the blocksize
# here
use_flashmla = (envs.VLLM_ATTENTION_BACKEND is None \
or envs.VLLM_ATTENTION_BACKEND == "FLASHMLA")
from vllm.attention.backends.flashmla import is_flashmla_supported
if use_flashmla and is_flashmla_supported()[0] \
and cache_config.block_size != 64: and cache_config.block_size != 64:
cache_config.block_size = 64 cache_config.block_size = 64
logger.info( logger.info(
"FlashMLA: Forcing kv cache block size to 64 since this" "Forcing kv cache block size to 64 for FlashMLA backend.")
" is currently the only block size supported by the kernel.")
if (parallel_config.data_parallel_size > 1 if (parallel_config.data_parallel_size > 1
and compilation_config.use_cudagraph): and compilation_config.use_cudagraph):
...@@ -173,7 +181,15 @@ class CudaPlatformBase(Platform): ...@@ -173,7 +181,15 @@ class CudaPlatformBase(Platform):
if use_mla: if use_mla:
# TODO(lucas): refactor to be more concise # TODO(lucas): refactor to be more concise
# we should probably consider factoring out V1 here # we should probably consider factoring out V1 here
if selected_backend == _Backend.FLASHMLA: if selected_backend == _Backend.TRITON_MLA or block_size != 64:
if use_v1:
logger.info_once("Using Triton MLA backend on V1 engine.")
return ("vllm.v1.attention.backends.mla."
"triton_mla.TritonMLABackend")
else:
logger.info("Using Triton MLA backend.")
return "vllm.attention.backends.triton_mla.TritonMLABackend"
else:
from vllm.attention.backends.flashmla import ( from vllm.attention.backends.flashmla import (
is_flashmla_supported) is_flashmla_supported)
if not is_flashmla_supported()[0]: if not is_flashmla_supported()[0]:
...@@ -195,14 +211,6 @@ class CudaPlatformBase(Platform): ...@@ -195,14 +211,6 @@ class CudaPlatformBase(Platform):
logger.info("Using FlashMLA backend.") logger.info("Using FlashMLA backend.")
return ("vllm.attention.backends." return ("vllm.attention.backends."
"flashmla.FlashMLABackend") "flashmla.FlashMLABackend")
if use_v1:
logger.info_once("Using Triton MLA backend on V1 engine.")
return ("vllm.v1.attention.backends.mla."
"triton_mla.TritonMLABackend")
else:
logger.info("Using Triton MLA backend.")
return "vllm.attention.backends.triton_mla.TritonMLABackend"
if use_v1: if use_v1:
logger.info_once("Using Flash Attention backend on V1 engine.") logger.info_once("Using Flash Attention backend on V1 engine.")
return ("vllm.v1.attention.backends.flash_attn." return ("vllm.v1.attention.backends.flash_attn."
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment