Commit 8ebfacaa (unverified), authored by Matthew Bonanni, committed by GitHub
Browse files

[Attention][MLA] Make `FLASHINFER_MLA` the default MLA backend on Blackwell,...


[Attention][MLA] Make `FLASHINFER_MLA` the default MLA backend on Blackwell, and TRTLLM the default prefill (#32339)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
parent b89275d0
...@@ -35,7 +35,7 @@ class AttentionConfig: ...@@ -35,7 +35,7 @@ class AttentionConfig:
use_cudnn_prefill: bool = False use_cudnn_prefill: bool = False
"""Whether to use cudnn prefill.""" """Whether to use cudnn prefill."""
use_trtllm_ragged_deepseek_prefill: bool = False use_trtllm_ragged_deepseek_prefill: bool = True
"""Whether to use TRTLLM ragged deepseek prefill.""" """Whether to use TRTLLM ragged deepseek prefill."""
use_trtllm_attention: bool | None = None use_trtllm_attention: bool | None = None
......
...@@ -450,7 +450,6 @@ def use_flashinfer_prefill() -> bool: ...@@ -450,7 +450,6 @@ def use_flashinfer_prefill() -> bool:
not vllm_config.attention_config.disable_flashinfer_prefill not vllm_config.attention_config.disable_flashinfer_prefill
and flashinfer_available and flashinfer_available
and not vllm_config.attention_config.use_cudnn_prefill and not vllm_config.attention_config.use_cudnn_prefill
and not vllm_config.attention_config.use_trtllm_ragged_deepseek_prefill
and current_platform.is_device_capability_family(100) and current_platform.is_device_capability_family(100)
) )
...@@ -1294,25 +1293,27 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1294,25 +1293,27 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
def __init__(self, *args, **kwargs) -> None: def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
if use_flashinfer_prefill(): if use_trtllm_ragged_deepseek_prefill():
logger.debug_once("Using FlashInfer prefill for MLA") logger.info_once(
self._run_prefill_context_chunk = self._run_prefill_context_chunk_fi "Using TRT-LLM ragged DeepSeek prefill for MLA", scope="local"
self._run_prefill_new_tokens = self._run_prefill_new_tokens_fi )
self._pad_v = False
elif use_trtllm_ragged_deepseek_prefill():
logger.debug_once("Using TRT-LLM ragged DeepSeek prefill for MLA")
self._run_prefill_context_chunk = ( self._run_prefill_context_chunk = (
self._run_prefill_context_chunk_trtllm_ragged self._run_prefill_context_chunk_trtllm_ragged
) )
self._run_prefill_new_tokens = self._run_prefill_new_tokens_trtllm_ragged self._run_prefill_new_tokens = self._run_prefill_new_tokens_trtllm_ragged
self._pad_v = False self._pad_v = False
elif use_flashinfer_prefill():
logger.info_once("Using FlashInfer prefill for MLA")
self._run_prefill_context_chunk = self._run_prefill_context_chunk_fi
self._run_prefill_new_tokens = self._run_prefill_new_tokens_fi
self._pad_v = False
elif use_cudnn_prefill(): elif use_cudnn_prefill():
logger.debug_once("Using CUDNN prefill for MLA") logger.info_once("Using CUDNN prefill for MLA", scope="local")
self._run_prefill_context_chunk = self._run_prefill_context_chunk_cudnn self._run_prefill_context_chunk = self._run_prefill_context_chunk_cudnn
self._run_prefill_new_tokens = self._run_prefill_new_tokens_cudnn self._run_prefill_new_tokens = self._run_prefill_new_tokens_cudnn
self._pad_v = False self._pad_v = False
else: # Use FlashAttention else: # Use FlashAttention
logger.debug_once("Using FlashAttention prefill for MLA") logger.info_once("Using FlashAttention prefill for MLA", scope="local")
self._run_prefill_context_chunk = self._run_prefill_context_chunk_fa self._run_prefill_context_chunk = self._run_prefill_context_chunk_fa
self._run_prefill_new_tokens = self._run_prefill_new_tokens_fa self._run_prefill_new_tokens = self._run_prefill_new_tokens_fa
......
...@@ -50,8 +50,8 @@ def _get_backend_priorities( ...@@ -50,8 +50,8 @@ def _get_backend_priorities(
if use_mla: if use_mla:
if device_capability.major == 10: if device_capability.major == 10:
return [ return [
AttentionBackendEnum.CUTLASS_MLA,
AttentionBackendEnum.FLASHINFER_MLA, AttentionBackendEnum.FLASHINFER_MLA,
AttentionBackendEnum.CUTLASS_MLA,
AttentionBackendEnum.FLASH_ATTN_MLA, AttentionBackendEnum.FLASH_ATTN_MLA,
AttentionBackendEnum.FLASHMLA, AttentionBackendEnum.FLASHMLA,
AttentionBackendEnum.TRITON_MLA, AttentionBackendEnum.TRITON_MLA,
...@@ -183,12 +183,12 @@ class CudaPlatformBase(Platform): ...@@ -183,12 +183,12 @@ class CudaPlatformBase(Platform):
if vllm_config.attention_config.backend is None: if vllm_config.attention_config.backend is None:
# Default case # Default case
if cls.is_device_capability_family(100) and not use_sparse: if cls.is_device_capability_family(100) and not use_sparse:
# Blackwell => Force CutlassMLA (unless sparse, i.e. DSv3.2). # Blackwell => Force FlashInferMLA (unless sparse, i.e. DSv3.2).
use_cutlass_mla = True use_flashinfer_mla = True
# Set the backend in AttentionConfig so it's used during # Set the backend in AttentionConfig so it's used during
# backend selection # backend selection
vllm_config.attention_config.backend = ( vllm_config.attention_config.backend = (
AttentionBackendEnum.CUTLASS_MLA AttentionBackendEnum.FLASHINFER_MLA
) )
else: else:
# Not Blackwell # Not Blackwell
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment