Unverified commit b36adfa3, authored by Wei Zhao, committed by GitHub
Browse files

[Perf] Set Flashinfer sparse MLA as default backend for FP8 kv cache (#37252)


Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
parent e78821b4
...@@ -127,8 +127,8 @@ Priority is **1 = highest** (tried first). ...@@ -127,8 +127,8 @@ Priority is **1 = highest** (tried first).
| 3 | `FLASH_ATTN_MLA` | | 3 | `FLASH_ATTN_MLA` |
| 4 | `FLASHMLA` | | 4 | `FLASHMLA` |
| 5 | `TRITON_MLA` | | 5 | `TRITON_MLA` |
| 6 | `FLASHMLA_SPARSE` | | 6 | `FLASHINFER_MLA_SPARSE`**\*** |
| 7 | `FLASHINFER_MLA_SPARSE` | | 7 | `FLASHMLA_SPARSE` |
**Ampere/Hopper (SM 8.x-9.x):** **Ampere/Hopper (SM 8.x-9.x):**
...@@ -140,6 +140,8 @@ Priority is **1 = highest** (tried first). ...@@ -140,6 +140,8 @@ Priority is **1 = highest** (tried first).
| 4 | `TRITON_MLA` | | 4 | `TRITON_MLA` |
| 5 | `FLASHMLA_SPARSE` | | 5 | `FLASHMLA_SPARSE` |
> **\*** For sparse MLA, FP8 KV cache always prefers `FLASHINFER_MLA_SPARSE`. With BF16 KV cache, `FLASHINFER_MLA_SPARSE` is preferred for low query-head counts (<= 16), while `FLASHMLA_SPARSE` is preferred otherwise.
>
> **Note:** ROCm and CPU platforms have their own selection logic. See the platform-specific documentation for details. > **Note:** ROCm and CPU platforms have their own selection logic. See the platform-specific documentation for details.
## Legend ## Legend
......
...@@ -1262,14 +1262,23 @@ When no backend is specified (the default): ...@@ -1262,14 +1262,23 @@ When no backend is specified (the default):
""" """
def _priority_table(title: str, backends: list[str]) -> list[str]: def _priority_table(
title: str,
backends: list[str],
annotations: dict[str, str] | None = None,
) -> list[str]:
"""Generate a priority table for a list of backends.""" """Generate a priority table for a list of backends."""
def _fmt(b: str) -> str:
suffix = annotations.get(b, "") if annotations else ""
return f"`{b}`{suffix}"
return [ return [
f"**{title}:**", f"**{title}:**",
"", "",
"| Priority | Backend |", "| Priority | Backend |",
"| -------- | ------- |", "| -------- | ------- |",
*[f"| {i} | `{b}` |" for i, b in enumerate(backends, 1)], *[f"| {i} | {_fmt(b)} |" for i, b in enumerate(backends, 1)],
"", "",
] ]
...@@ -1298,11 +1307,25 @@ def generate_priority_section(priorities: dict[str, list[str]]) -> str: ...@@ -1298,11 +1307,25 @@ def generate_priority_section(priorities: dict[str, list[str]]) -> str:
lines.extend(["### MLA Attention (DeepSeek-style)", ""]) lines.extend(["### MLA Attention (DeepSeek-style)", ""])
mla_sm100_annotations = {
"FLASHINFER_MLA_SPARSE": "**\\***",
}
if "mla_sm100" in priorities: if "mla_sm100" in priorities:
lines.extend(_priority_table(sm100, priorities["mla_sm100"])) lines.extend(
_priority_table(sm100, priorities["mla_sm100"], mla_sm100_annotations)
)
if "mla_default" in priorities: if "mla_default" in priorities:
lines.extend(_priority_table(ampere, priorities["mla_default"])) lines.extend(_priority_table(ampere, priorities["mla_default"]))
if "mla_sm100" in priorities:
lines.append(
"> **\\*** For sparse MLA, FP8 KV cache always prefers "
"`FLASHINFER_MLA_SPARSE`. With BF16 KV cache, `FLASHINFER_MLA_SPARSE` "
"is preferred for low query-head counts (<= 16), while "
"`FLASHMLA_SPARSE` is preferred otherwise."
)
lines.append(">")
lines.append( lines.append(
"> **Note:** ROCm and CPU platforms have their own selection logic. " "> **Note:** ROCm and CPU platforms have their own selection logic. "
"See the platform-specific documentation for details." "See the platform-specific documentation for details."
......
...@@ -4,6 +4,8 @@ ...@@ -4,6 +4,8 @@
pynvml. However, it should not initialize cuda context. pynvml. However, it should not initialize cuda context.
""" """
from __future__ import annotations
import os import os
from collections.abc import Callable from collections.abc import Callable
from datetime import timedelta from datetime import timedelta
...@@ -49,21 +51,34 @@ def _get_backend_priorities( ...@@ -49,21 +51,34 @@ def _get_backend_priorities(
use_mla: bool, use_mla: bool,
device_capability: DeviceCapability, device_capability: DeviceCapability,
num_heads: int | None = None, num_heads: int | None = None,
kv_cache_dtype: CacheDType | None = None,
) -> list[AttentionBackendEnum]: ) -> list[AttentionBackendEnum]:
"""Get backend priorities with lazy import to avoid circular dependency.""" """Get backend priorities with lazy import to avoid circular dependency."""
if use_mla: if use_mla:
if device_capability.major == 10: if device_capability.major == 10:
# Prefer FlashInfer at low head counts (FlashMLA uses padding) # Sparse MLA backend priorities
if num_heads is not None and num_heads <= 16: # See https://github.com/vllm-project/vllm/issues/35807 for
# benchmark results
if kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
# Prefer FlashInfer for fp8 kv cache
sparse_backends = [ sparse_backends = [
AttentionBackendEnum.FLASHINFER_MLA_SPARSE, AttentionBackendEnum.FLASHINFER_MLA_SPARSE,
AttentionBackendEnum.FLASHMLA_SPARSE, AttentionBackendEnum.FLASHMLA_SPARSE,
] ]
else: else:
sparse_backends = [ # BF16 KV Cache
AttentionBackendEnum.FLASHMLA_SPARSE, # Prefer FlashInfer at low head counts (FlashMLA uses padding)
AttentionBackendEnum.FLASHINFER_MLA_SPARSE, if num_heads is not None and num_heads <= 16:
] sparse_backends = [
AttentionBackendEnum.FLASHINFER_MLA_SPARSE,
AttentionBackendEnum.FLASHMLA_SPARSE,
]
else:
sparse_backends = [
AttentionBackendEnum.FLASHMLA_SPARSE,
AttentionBackendEnum.FLASHINFER_MLA_SPARSE,
]
return [ return [
AttentionBackendEnum.FLASHINFER_MLA, AttentionBackendEnum.FLASHINFER_MLA,
AttentionBackendEnum.CUTLASS_MLA, AttentionBackendEnum.CUTLASS_MLA,
...@@ -165,7 +180,7 @@ class CudaPlatformBase(Platform): ...@@ -165,7 +180,7 @@ class CudaPlatformBase(Platform):
pass pass
@classmethod @classmethod
def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
model_config = vllm_config.model_config model_config = vllm_config.model_config
...@@ -198,11 +213,11 @@ class CudaPlatformBase(Platform): ...@@ -198,11 +213,11 @@ class CudaPlatformBase(Platform):
def get_valid_backends( def get_valid_backends(
cls, cls,
device_capability: DeviceCapability, device_capability: DeviceCapability,
attn_selector_config: "AttentionSelectorConfig", attn_selector_config: AttentionSelectorConfig,
num_heads: int | None = None, num_heads: int | None = None,
) -> tuple[ ) -> tuple[
list[tuple["AttentionBackendEnum", int]], list[tuple[AttentionBackendEnum, int]],
dict["AttentionBackendEnum", tuple[int, list[str]]], dict[AttentionBackendEnum, tuple[int, list[str]]],
]: ]:
valid_backends_priorities = [] valid_backends_priorities = []
invalid_reasons: dict[AttentionBackendEnum, tuple[int, list[str]]] = {} invalid_reasons: dict[AttentionBackendEnum, tuple[int, list[str]]] = {}
...@@ -211,6 +226,7 @@ class CudaPlatformBase(Platform): ...@@ -211,6 +226,7 @@ class CudaPlatformBase(Platform):
attn_selector_config.use_mla, attn_selector_config.use_mla,
device_capability, device_capability,
num_heads, num_heads,
attn_selector_config.kv_cache_dtype,
) )
for priority, backend in enumerate(backend_priorities): for priority, backend in enumerate(backend_priorities):
try: try:
...@@ -231,8 +247,8 @@ class CudaPlatformBase(Platform): ...@@ -231,8 +247,8 @@ class CudaPlatformBase(Platform):
@classmethod @classmethod
def get_attn_backend_cls( def get_attn_backend_cls(
cls, cls,
selected_backend: "AttentionBackendEnum | None", selected_backend: AttentionBackendEnum | None,
attn_selector_config: "AttentionSelectorConfig", attn_selector_config: AttentionSelectorConfig,
num_heads: int | None = None, num_heads: int | None = None,
) -> str: ) -> str:
device_capability = cls.get_device_capability() device_capability = cls.get_device_capability()
...@@ -324,7 +340,7 @@ class CudaPlatformBase(Platform): ...@@ -324,7 +340,7 @@ class CudaPlatformBase(Platform):
return selected_backend.get_path() return selected_backend.get_path()
@classmethod @classmethod
def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]: def get_supported_vit_attn_backends(cls) -> list[AttentionBackendEnum]:
if cls.has_device_capability(80): if cls.has_device_capability(80):
return [ return [
AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
...@@ -345,8 +361,8 @@ class CudaPlatformBase(Platform): ...@@ -345,8 +361,8 @@ class CudaPlatformBase(Platform):
cls, cls,
head_size: int, head_size: int,
dtype: torch.dtype, dtype: torch.dtype,
backend: "AttentionBackendEnum | None" = None, backend: AttentionBackendEnum | None = None,
) -> "AttentionBackendEnum": ) -> AttentionBackendEnum:
if backend is not None: if backend is not None:
assert backend in cls.get_supported_vit_attn_backends(), ( assert backend in cls.get_supported_vit_attn_backends(), (
f"Backend {backend} is not supported for vit attention. " f"Backend {backend} is not supported for vit attention. "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment