"docs/vscode:/vscode.git/clone" did not exist on "c5d004aaaf3b2106d33974c673bec0568c18f762"
Unverified commit d4fd2768, authored by Matthew Bonanni, committed by GitHub
Browse files

[Bugfix][Attention] Fix FlashInfer MLA block size logic (#24692)


Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
parent 7a70a718
...@@ -146,6 +146,7 @@ class CudaPlatformBase(Platform): ...@@ -146,6 +146,7 @@ class CudaPlatformBase(Platform):
# required block_size. # required block_size.
use_flashmla = False use_flashmla = False
use_cutlass_mla = False use_cutlass_mla = False
use_flashinfer_mla = False
if envs.VLLM_ATTENTION_BACKEND is None: if envs.VLLM_ATTENTION_BACKEND is None:
# Default case # Default case
...@@ -164,6 +165,8 @@ class CudaPlatformBase(Platform): ...@@ -164,6 +165,8 @@ class CudaPlatformBase(Platform):
use_flashmla = (envs.VLLM_ATTENTION_BACKEND == "FLASHMLA") use_flashmla = (envs.VLLM_ATTENTION_BACKEND == "FLASHMLA")
use_cutlass_mla = ( use_cutlass_mla = (
envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA") envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA")
use_flashinfer_mla = (
envs.VLLM_ATTENTION_BACKEND == "FLASHINFER_MLA")
from vllm.attention.ops.flashmla import is_flashmla_supported from vllm.attention.ops.flashmla import is_flashmla_supported
if use_flashmla and is_flashmla_supported()[0] \ if use_flashmla and is_flashmla_supported()[0] \
...@@ -176,6 +179,11 @@ class CudaPlatformBase(Platform): ...@@ -176,6 +179,11 @@ class CudaPlatformBase(Platform):
cache_config.block_size = 128 cache_config.block_size = 128
logger.info("Forcing kv cache block size to 128 for " logger.info("Forcing kv cache block size to 128 for "
"CUTLASS_MLA backend.") "CUTLASS_MLA backend.")
if use_flashinfer_mla and cache_config.block_size not in [32, 64]:
cache_config.block_size = 64
logger.info(
"Forcing kv cache block size to 64 for FlashInferMLA "
"backend.")
# lazy import to avoid circular import # lazy import to avoid circular import
from vllm.config import CUDAGraphMode from vllm.config import CUDAGraphMode
...@@ -228,8 +236,9 @@ class CudaPlatformBase(Platform): ...@@ -228,8 +236,9 @@ class CudaPlatformBase(Platform):
use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or ( use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or (
selected_backend is None and cls.is_device_capability(100) selected_backend is None and cls.is_device_capability(100)
and block_size == 128) and block_size == 128)
use_flashinfermla = (selected_backend == _Backend.FLASHINFER_MLA use_flashinfermla = selected_backend == _Backend.FLASHINFER_MLA or (
and cls.has_device_capability(100)) selected_backend is None and cls.is_device_capability(100)
and block_size in [32, 64])
use_flashmla = selected_backend in [ use_flashmla = selected_backend in [
_Backend.FLASHMLA, _Backend.FLASHMLA_VLLM_V1 _Backend.FLASHMLA, _Backend.FLASHMLA_VLLM_V1
] or (selected_backend is None and is_flashmla_supported()[0]) ] or (selected_backend is None and is_flashmla_supported()[0])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment