Unverified Commit 5e75a14a authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

[Doc] Add DCP support to attention backend doc (#33936)

parent e7e52781
...@@ -152,6 +152,7 @@ Priority is **1 = highest** (tried first). ...@@ -152,6 +152,7 @@ Priority is **1 = highest** (tried first).
| **Sink** | Attention sink support (for StreamingLLM) | | **Sink** | Attention sink support (for StreamingLLM) |
| **Sparse** | Sparse attention support (MLA only) | | **Sparse** | Sparse attention support (MLA only) |
| **MM Prefix** | Multimodal prefix full attention support | | **MM Prefix** | Multimodal prefix full attention support |
| **DCP** | Decode Context Parallelism support (`--decode-context-parallel-size`) |
| **Attention Types** | Supported attention patterns (Decoder, Encoder, Enc-Dec) | | **Attention Types** | Supported attention patterns (Decoder, Encoder, Enc-Dec) |
| **Compute Cap.** | Required CUDA compute capability (N/A for non-CUDA backends) | | **Compute Cap.** | Required CUDA compute capability (N/A for non-CUDA backends) |
...@@ -159,20 +160,20 @@ Priority is **1 = highest** (tried first). ...@@ -159,20 +160,20 @@ Priority is **1 = highest** (tried first).
## Standard Attention (MHA, MQA, GQA) Backends ## Standard Attention (MHA, MQA, GQA) Backends
| Backend | Version | Dtypes | KV Dtypes | Block Sizes | Head Sizes | Sink | MM Prefix | Attention Types | Compute Cap. | | Backend | Version | Dtypes | KV Dtypes | Block Sizes | Head Sizes | Sink | MM Prefix | DCP | Attention Types | Compute Cap. |
|---------|---------|--------|-----------|-------------|------------|------|-----------|-----------------|--------------| |---------|---------|--------|-----------|-------------|------------|------|-----------|-----|-----------------|--------------|
| `CPU_ATTN` | | fp16, bf16, fp32 | `auto` | Any | 32, 64, 80, 96, 112, 128, 160, 192, 224, 256 | ❌ | ❌ | All | N/A | | `CPU_ATTN` | | fp16, bf16, fp32 | `auto` | Any | 32, 64, 80, 96, 112, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | All | N/A |
| `FLASHINFER` | Native† | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ❌ | ❌ | Decoder | 7.x-9.x | | `FLASHINFER` | Native† | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ❌ | ❌ | ✅ | Decoder | 7.x-9.x |
| `FLASHINFER` | TRTLLM† | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ✅ | ❌ | Decoder | 10.x | | `FLASHINFER` | TRTLLM† | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ✅ | ❌ | ✅ | Decoder | 10.x |
| `FLASH_ATTN` | FA2* | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | All | ≥8.0 | | `FLASH_ATTN` | FA2* | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥8.0 |
| `FLASH_ATTN` | FA3* | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ❌ | All | 9.x | | `FLASH_ATTN` | FA3* | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ❌ | ✅ | All | 9.x |
| `FLASH_ATTN_DIFFKV` | | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | Decoder | Any | | `FLASH_ATTN_DIFFKV` | | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ✅ | Decoder | Any |
| `FLEX_ATTENTION` | | fp16, bf16, fp32 | `auto`, `bfloat16` | Any | Any | ❌ | ✅ | Decoder, Encoder Only | Any | | `FLEX_ATTENTION` | | fp16, bf16, fp32 | `auto`, `bfloat16` | Any | Any | ❌ | ✅ | ❌ | Decoder, Encoder Only | Any |
| `ROCM_AITER_FA` | | fp16, bf16 | `auto` | 16, 32 | 64, 128, 256 | ❌ | ❌ | Decoder | N/A | | `ROCM_AITER_FA` | | fp16, bf16 | `auto` | 16, 32 | 64, 128, 256 | ❌ | ❌ | ❌ | Decoder | N/A |
| `ROCM_AITER_UNIFIED_ATTN` | | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | Decoder | N/A | | `ROCM_AITER_UNIFIED_ATTN` | | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | Decoder | N/A |
| `ROCM_ATTN` | | fp16, bf16, fp32 | `auto` | 16, 32, 544 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | Decoder | N/A | | `ROCM_ATTN` | | fp16, bf16, fp32 | `auto` | 16, 32, 544 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | Decoder | N/A |
| `TREE_ATTN` | | fp16, bf16 | `auto` | %16 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | Decoder | Any | | `TREE_ATTN` | | fp16, bf16 | `auto` | %16 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | Decoder | Any |
| `TRITON_ATTN` | | fp16, bf16, fp32 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ✅ | All | Any | | `TRITON_ATTN` | | fp16, bf16, fp32 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ✅ | ❌ | All | Any |
> **†** FlashInfer uses TRTLLM attention on Blackwell (SM100), which supports sinks. Disable via `--attention-config.use_trtllm_attention=0`. > **†** FlashInfer uses TRTLLM attention on Blackwell (SM100), which supports sinks. Disable via `--attention-config.use_trtllm_attention=0`.
> >
...@@ -199,14 +200,14 @@ configuration. ...@@ -199,14 +200,14 @@ configuration.
### Decode Backends ### Decode Backends
| Backend | Dtypes | KV Dtypes | Block Sizes | Head Sizes | Sink | Sparse | MM Prefix | Attention Types | Compute Cap. | | Backend | Dtypes | KV Dtypes | Block Sizes | Head Sizes | Sink | Sparse | MM Prefix | DCP | Attention Types | Compute Cap. |
|---------|--------|-----------|-------------|------------|------|--------|-----------|-----------------|--------------| |---------|--------|-----------|-------------|------------|------|--------|-----------|-----|-----------------|--------------|
| `CUTLASS_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 128 | Any | ❌ | ❌ | ❌ | Decoder | 10.x | | `CUTLASS_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 128 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 10.x |
| `FLASHINFER_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 32, 64 | Any | ❌ | ❌ | ❌ | Decoder | 10.x | | `FLASHINFER_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 32, 64 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | 10.x |
| `FLASHMLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 64 | Any | ❌ | ❌ | ❌ | Decoder | 9.x-10.x | | `FLASHMLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 64 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x-10.x |
| `FLASHMLA_SPARSE` | bf16 | `auto`, `bfloat16`, `fp8_ds_mla` | 64 | 576 | ❌ | ✅ | ❌ | Decoder | 9.x-10.x | | `FLASHMLA_SPARSE` | bf16 | `auto`, `bfloat16`, `fp8_ds_mla` | 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 9.x-10.x |
| `FLASH_ATTN_MLA` | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ❌ | Decoder | 9.x | | `FLASH_ATTN_MLA` | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x |
| `ROCM_AITER_MLA` | fp16, bf16 | `auto` | 1 | Any | ❌ | ❌ | ❌ | Decoder | N/A | | `ROCM_AITER_MLA` | fp16, bf16 | `auto` | 1 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
| `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto` | Any | 576 | ❌ | ❌ | ❌ | Decoder | N/A | | `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto` | Any | 576 | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
| `ROCM_AITER_TRITON_MLA` | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | Decoder | N/A | | `ROCM_AITER_TRITON_MLA` | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
| `TRITON_MLA` | fp16, bf16 | `auto`, `bfloat16` | Any | Any | ❌ | ❌ | ❌ | Decoder | Any | | `TRITON_MLA` | fp16, bf16 | `auto`, `bfloat16` | Any | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any |
...@@ -17,9 +17,14 @@ import argparse ...@@ -17,9 +17,14 @@ import argparse
import ast import ast
import fnmatch import fnmatch
import sys import sys
from collections.abc import Callable
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
# ---------------------------------------------------------------------------
# Constants and file paths
# ---------------------------------------------------------------------------
REPO_ROOT = Path(__file__).parent.parent.parent REPO_ROOT = Path(__file__).parent.parent.parent
RELEVANT_PATTERNS = [ RELEVANT_PATTERNS = [
...@@ -32,6 +37,18 @@ RELEVANT_PATTERNS = [ ...@@ -32,6 +37,18 @@ RELEVANT_PATTERNS = [
"docs/design/attention_backends.md", "docs/design/attention_backends.md",
] ]
BACKENDS_DIR = REPO_ROOT / "vllm" / "v1" / "attention" / "backends"
REGISTRY_FILE = BACKENDS_DIR / "registry.py"
CUDA_PLATFORM_FILE = REPO_ROOT / "vllm" / "platforms" / "cuda.py"
FA_UTILS_FILE = BACKENDS_DIR / "fa_utils.py"
FLASHINFER_UTILS_FILE = REPO_ROOT / "vllm" / "utils" / "flashinfer.py"
MLA_ATTENTION_FILE = (
REPO_ROOT / "vllm" / "model_executor" / "layers" / "attention" / "mla_attention.py"
)
# Backends to skip during doc generation
SKIP_BACKENDS = {"CUSTOM", "TORCH_SDPA"}
def is_relevant_file(filepath: str) -> bool: def is_relevant_file(filepath: str) -> bool:
"""Check if a file matches any of the relevant patterns.""" """Check if a file matches any of the relevant patterns."""
...@@ -46,351 +63,234 @@ def is_relevant_file(filepath: str) -> bool: ...@@ -46,351 +63,234 @@ def is_relevant_file(filepath: str) -> bool:
return any(fnmatch.fnmatch(path_str, pattern) for pattern in RELEVANT_PATTERNS) return any(fnmatch.fnmatch(path_str, pattern) for pattern in RELEVANT_PATTERNS)
BACKENDS_DIR = REPO_ROOT / "vllm" / "v1" / "attention" / "backends" # ---------------------------------------------------------------------------
REGISTRY_FILE = BACKENDS_DIR / "registry.py" # AST utility helpers
CUDA_PLATFORM_FILE = REPO_ROOT / "vllm" / "platforms" / "cuda.py" # ---------------------------------------------------------------------------
FA_UTILS_FILE = BACKENDS_DIR / "fa_utils.py"
FLASHINFER_UTILS_FILE = REPO_ROOT / "vllm" / "utils" / "flashinfer.py"
MLA_ATTENTION_FILE = (
REPO_ROOT / "vllm" / "model_executor" / "layers" / "attention" / "mla_attention.py"
)
def parse_registry() -> dict[str, str]: def find_class_in_ast(tree: ast.AST, class_name: str) -> ast.ClassDef | None:
"""Parse the registry.py file to get backend names and their class paths.""" """Find a class definition in an AST."""
tree = ast.parse(REGISTRY_FILE.read_text())
for node in ast.walk(tree): for node in ast.walk(tree):
if isinstance(node, ast.ClassDef) and node.name == "AttentionBackendEnum": if isinstance(node, ast.ClassDef) and node.name == class_name:
return _extract_enum_values(node) return node
return {} return None
def _extract_enum_values(node: ast.ClassDef) -> dict[str, str]: def find_method(node: ast.ClassDef, method_name: str) -> ast.FunctionDef | None:
"""Extract enum name -> value mapping from a class definition.""" """Find a method in a class definition."""
result: dict[str, str] = {}
for item in node.body: for item in node.body:
if not isinstance(item, ast.Assign): if isinstance(item, ast.FunctionDef) and item.name == method_name:
continue return item
for target in item.targets: return None
if not isinstance(target, ast.Name):
continue
if isinstance(item.value, ast.Constant) and item.value.value:
result[target.id] = item.value.value
return result
def get_file_from_class_path(class_path: str) -> Path | None:
"""Convert a class path to a file path."""
if not class_path:
return None
module_path = class_path.rsplit(".", 1)[0].replace(".", "/")
py_file = REPO_ROOT / f"{module_path}.py"
return py_file if py_file.exists() else None
def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
"""Parse fa_utils.py to detect FA2 vs FA3 feature differences.
Returns a dict with 'fa2' and 'fa3' keys containing their respective
feature overrides for compute capability, KV cache dtypes, and sink support.
"""
if not FA_UTILS_FILE.exists():
return {}
try: def method_returns_true(method: ast.FunctionDef | None) -> bool:
tree = ast.parse(FA_UTILS_FILE.read_text()) """Check if a method simply returns True."""
except Exception: if method is None:
return {} return False
for node in ast.walk(method):
if (
isinstance(node, ast.Return)
and isinstance(node.value, ast.Constant)
and node.value.value is True
):
return True
return False
# Analyze the functions to determine FA3-specific features
fa3_supports_fp8 = False
fa3_supports_sinks = False
fa3_compute_cap: str | None = None
for node in ast.walk(tree): def check_method_overrides(node: ast.ClassDef, method_name: str) -> bool:
if not isinstance(node, ast.FunctionDef): """Check if a method is overridden and returns True."""
continue return method_returns_true(find_method(node, method_name))
# Check flash_attn_supports_fp8 - looks for `get_flash_attn_version() == 3`
if node.name == "flash_attn_supports_fp8":
for n in ast.walk(node):
if (
isinstance(n, ast.Compare)
and isinstance(n.left, ast.Call)
and isinstance(n.left.func, ast.Name)
and n.left.func.id == "get_flash_attn_version"
):
fa3_supports_fp8 = True
break
# Check flash_attn_supports_sinks - looks for `get_flash_attn_version() == 3` def _find_bool_class_var(class_node: ast.ClassDef, var_name: str) -> bool | None:
if node.name == "flash_attn_supports_sinks": """Find a bool class variable in a class definition. Returns None if not found."""
for n in ast.walk(node): for item in class_node.body:
# Check for annotated assignment: attr: bool = True/False
if (
isinstance(item, ast.AnnAssign)
and isinstance(item.target, ast.Name)
and item.target.id == var_name
and isinstance(item.value, ast.Constant)
and isinstance(item.value.value, bool)
):
return item.value.value
# Check for plain assignment: attr = True/False
if isinstance(item, ast.Assign):
for target in item.targets:
if ( if (
isinstance(n, ast.Compare) isinstance(target, ast.Name)
and isinstance(n.left, ast.Call) and target.id == var_name
and isinstance(n.left.func, ast.Name) and isinstance(item.value, ast.Constant)
and n.left.func.id == "get_flash_attn_version" and isinstance(item.value.value, bool)
): ):
fa3_supports_sinks = True return item.value.value
break return None
# Check get_flash_attn_version for FA3 compute capability
# Look for the ternary: 3 if (device_capability.major == 9 ...) else 2
if node.name == "get_flash_attn_version":
for n in ast.walk(node):
# Look for IfExp (ternary) with `device_capability.major == 9`
if isinstance(n, ast.IfExp):
test = n.test
# Check if test is a BoolOp (and) containing the major check
if isinstance(test, ast.BoolOp):
for val in test.values:
if (
isinstance(val, ast.Compare)
and isinstance(val.left, ast.Attribute)
and val.left.attr == "major"
and val.comparators
and isinstance(val.comparators[0], ast.Constant)
):
fa3_compute_cap = f"{val.comparators[0].value}.x"
break
return {
"fa2": {
"supports_fp8": False,
"supports_sink": False,
},
"fa3": {
"compute_capability": fa3_compute_cap,
"supports_fp8": fa3_supports_fp8,
"supports_sink": fa3_supports_sinks,
},
}
def parse_flashinfer_trtllm_features() -> dict[str, dict[str, Any]]:
"""Parse flashinfer.py to detect TRTLLM-specific features.
FLASHINFER uses TRTLLM attention on SM100 (Blackwell), which has different
capabilities (e.g., sink support) than native FlashInfer on earlier GPUs.
"""
if not FLASHINFER_UTILS_FILE.exists():
return {}
try:
tree = ast.parse(FLASHINFER_UTILS_FILE.read_text())
except Exception:
return {}
trtllm_compute_cap: str | None = None
for node in ast.walk(tree): def _parse_list_class_var(node: ast.ClassDef, var_name: str) -> list[str] | None:
if not isinstance(node, ast.FunctionDef): """Parse a list-type class variable, returning None if not found."""
for item in node.body:
if not isinstance(item, ast.AnnAssign):
continue continue
if not isinstance(item.target, ast.Name):
continue
if item.target.id != var_name:
continue
if not (item.value and isinstance(item.value, ast.List)):
continue
result = []
for elt in item.value.elts:
if isinstance(elt, ast.Attribute):
result.append(elt.attr)
elif isinstance(elt, ast.Constant):
result.append(str(elt.value))
return result
return None
# Parse supports_trtllm_attention for compute capability
# Look for: current_platform.is_device_capability_family(100)
if node.name == "supports_trtllm_attention":
for n in ast.walk(node):
if (
isinstance(n, ast.Call)
and isinstance(n.func, ast.Attribute)
and n.func.attr == "is_device_capability_family"
and n.args
and isinstance(n.args[0], ast.Constant)
and isinstance(n.args[0].value, int)
):
cap = n.args[0].value
# Convert 100 -> "10.x"
trtllm_compute_cap = f"{cap // 10}.x"
break
if not trtllm_compute_cap:
return {}
return {
"native": {
# Native FlashInfer: everything except SM100
"supports_sink": False,
},
"trtllm": {
# TRTLLM pathway on Blackwell
"compute_capability": trtllm_compute_cap,
"supports_sink": True,
},
}
def _parse_return_list(
method: ast.FunctionDef | None, handle_multiple_of: bool = False
) -> list[str]:
"""Extract list items from a method's return statement."""
if method is None:
return []
for stmt in ast.walk(method):
if not isinstance(stmt, ast.Return):
continue
if not isinstance(stmt.value, ast.List):
continue
sizes = []
for elt in stmt.value.elts:
if isinstance(elt, ast.Constant):
sizes.append(str(elt.value))
elif (
handle_multiple_of
and isinstance(elt, ast.Call)
and isinstance(elt.func, ast.Name)
and elt.func.id == "MultipleOf"
and elt.args
and isinstance(elt.args[0], ast.Constant)
):
sizes.append(f"%{elt.args[0].value}")
if sizes:
return sizes
return []
def parse_mla_prefill_backends() -> list[dict[str, Any]]:
"""Parse MLA prefill backend options from mla_attention.py.
MLA uses different backends for prefill vs decode. The decode backends are def _get_parent_class_name(class_node: ast.ClassDef) -> str | None:
registered in the registry, but prefill backends are selected at runtime """Get the first parent class name (simple name only).
based on conditions in MLACommonImpl.__init__.
Returns a list of prefill backend info dicts with their requirements. Handles both simple inheritance (class Foo(Bar)) and generic
inheritance (class Foo(Bar[T])).
""" """
if not MLA_ATTENTION_FILE.exists(): if not class_node.bases:
return [] return None
base = class_node.bases[0]
if isinstance(base, ast.Name):
return base.id
if isinstance(base, ast.Subscript) and isinstance(base.value, ast.Name):
return base.value.id
return None
try:
tree = ast.parse(MLA_ATTENTION_FILE.read_text())
except Exception:
return []
# Find compute capability requirements by parsing use_* functions def _resolve_import_to_file(
flashinfer_cc: str | None = None tree: ast.AST, class_name: str, source_file: Path | None = None
cudnn_cc: str | None = None ) -> Path | None:
trtllm_cc: str | None = None """Try to resolve a class name to its source file via imports in the AST.
Handles both absolute imports (from vllm.foo import Bar) and relative
imports (from .foo import Bar) when source_file is provided.
"""
for node in ast.walk(tree): for node in ast.walk(tree):
if not isinstance(node, ast.FunctionDef): if not isinstance(node, ast.ImportFrom):
continue continue
for alias in node.names:
actual_name = alias.asname or alias.name
if actual_name != class_name:
continue
if not node.module:
continue
# Parse use_flashinfer_prefill for compute capability (SM100) if node.level and node.level > 0 and source_file:
if node.name == "use_flashinfer_prefill": # Relative import: resolve from the source file's directory
for n in ast.walk(node): base_dir = source_file.parent
if ( for _ in range(node.level - 1):
isinstance(n, ast.Call) base_dir = base_dir.parent
and isinstance(n.func, ast.Attribute) module_path = node.module.replace(".", "/")
and n.func.attr == "is_device_capability_family" py_file = base_dir / f"{module_path}.py"
and n.args else:
and isinstance(n.args[0], ast.Constant) # Absolute import
and isinstance(n.args[0].value, int) module_path = node.module.replace(".", "/")
): py_file = REPO_ROOT / f"{module_path}.py"
flashinfer_cc = f"{n.args[0].value // 10}.x"
# Parse use_cudnn_prefill for compute capability (SM100)
if node.name == "use_cudnn_prefill":
for n in ast.walk(node):
if (
isinstance(n, ast.Call)
and isinstance(n.func, ast.Attribute)
and n.func.attr == "is_device_capability_family"
and n.args
and isinstance(n.args[0], ast.Constant)
and isinstance(n.args[0].value, int)
):
cudnn_cc = f"{n.args[0].value // 10}.x"
# Parse use_trtllm_ragged_deepseek_prefill for compute capability
if node.name == "use_trtllm_ragged_deepseek_prefill":
for n in ast.walk(node):
if (
isinstance(n, ast.Call)
and isinstance(n.func, ast.Attribute)
and n.func.attr == "is_device_capability_family"
and n.args
and isinstance(n.args[0], ast.Constant)
and isinstance(n.args[0].value, int)
):
trtllm_cc = f"{n.args[0].value // 10}.x"
# Build prefill backend list based on what we found
# Order matches the priority in MLACommonImpl.__init__
prefill_backends: list[dict[str, Any]] = []
# TRT-LLM Ragged (highest priority if available)
if trtllm_cc:
prefill_backends.append(
{
"name": "TRT-LLM Ragged‡",
"description": "TensorRT-LLM ragged attention",
"compute_capability": trtllm_cc,
"enable": "Default on SM100",
"disable": "`-ac.use_trtllm_ragged_deepseek_prefill=0`",
"notes": "DeepSeek R1 dims only",
}
)
# FlashInfer prefill
if flashinfer_cc:
prefill_backends.append(
{
"name": "FlashInfer",
"description": "FlashInfer CUTLASS backend",
"compute_capability": flashinfer_cc,
"enable": "`-ac.disable_flashinfer_prefill=0`",
"disable": "`-ac.disable_flashinfer_prefill=1`",
"notes": "DeepSeek R1 dims only",
}
)
# cuDNN prefill
if cudnn_cc:
prefill_backends.append(
{
"name": "cuDNN",
"description": "cuDNN-based attention",
"compute_capability": cudnn_cc,
"enable": "`-ac.use_cudnn_prefill=1`",
"disable": "`-ac.use_cudnn_prefill=0`",
"notes": "",
}
)
# FlashAttention is always available as fallback if py_file.exists():
prefill_backends.append( return py_file
{ return None
"name": "FlashAttention",
"description": "FlashAttention varlen (FA2/FA3)",
"compute_capability": "Any",
"enable": "Default fallback",
"disable": "Use other backends",
"notes": "FA3 on SM90, FA2 otherwise",
}
)
return prefill_backends
def _find_cc_in_function(tree: ast.AST, func_name: str) -> str | None:
"""Find a compute capability from is_device_capability_family() calls in a function.
def find_class_in_ast(tree: ast.AST, class_name: str) -> ast.ClassDef | None: Looks for the pattern: current_platform.is_device_capability_family(N)
"""Find a class definition in an AST.""" and converts N (e.g. 100) to a CC string (e.g. "10.x").
"""
for node in ast.walk(tree): for node in ast.walk(tree):
if isinstance(node, ast.ClassDef) and node.name == class_name: if not isinstance(node, ast.FunctionDef) or node.name != func_name:
return node continue
for n in ast.walk(node):
if (
isinstance(n, ast.Call)
and isinstance(n.func, ast.Attribute)
and n.func.attr == "is_device_capability_family"
and n.args
and isinstance(n.args[0], ast.Constant)
and isinstance(n.args[0].value, int)
):
return f"{n.args[0].value // 10}.x"
return None return None
def find_method(node: ast.ClassDef, method_name: str) -> ast.FunctionDef | None: # ---------------------------------------------------------------------------
"""Find a method in a class definition.""" # Registry and file resolution
for item in node.body: # ---------------------------------------------------------------------------
if isinstance(item, ast.FunctionDef) and item.name == method_name:
return item
return None
def method_returns_true(method: ast.FunctionDef | None) -> bool: def parse_registry() -> dict[str, str]:
"""Check if a method simply returns True.""" """Parse the registry.py file to get backend names and their class paths."""
if method is None: tree = ast.parse(REGISTRY_FILE.read_text())
return False for node in ast.walk(tree):
for node in ast.walk(method): if isinstance(node, ast.ClassDef) and node.name == "AttentionBackendEnum":
if not isinstance(node, ast.Return): return _extract_enum_values(node)
continue return {}
if isinstance(node.value, ast.Constant) and node.value.value is True:
return True
return False
def _parse_list_class_var(node: ast.ClassDef, var_name: str) -> list[str] | None: def _extract_enum_values(node: ast.ClassDef) -> dict[str, str]:
"""Parse a list-type class variable, returning None if not found.""" """Extract enum name -> value mapping from a class definition."""
result: dict[str, str] = {}
for item in node.body: for item in node.body:
if not isinstance(item, ast.AnnAssign): if not isinstance(item, ast.Assign):
continue
if not isinstance(item.target, ast.Name):
continue
if item.target.id != var_name:
continue
if not (item.value and isinstance(item.value, ast.List)):
continue continue
result = [] for target in item.targets:
for elt in item.value.elts: if not isinstance(target, ast.Name):
if isinstance(elt, ast.Attribute): continue
result.append(elt.attr) if isinstance(item.value, ast.Constant) and item.value.value:
elif isinstance(elt, ast.Constant): result[target.id] = item.value.value
result.append(str(elt.value)) return result
return result
return None
def get_file_from_class_path(class_path: str) -> Path | None:
"""Convert a class path to a file path."""
if not class_path:
return None
module_path = class_path.rsplit(".", 1)[0].replace(".", "/")
py_file = REPO_ROOT / f"{module_path}.py"
return py_file if py_file.exists() else None
# ---------------------------------------------------------------------------
# Backend feature extraction from AST
# ---------------------------------------------------------------------------
def parse_supported_dtypes(node: ast.ClassDef) -> str: def parse_supported_dtypes(node: ast.ClassDef) -> str:
...@@ -432,35 +332,6 @@ def parse_kv_cache_dtypes(node: ast.ClassDef) -> str: ...@@ -432,35 +332,6 @@ def parse_kv_cache_dtypes(node: ast.ClassDef) -> str:
return "auto" return "auto"
def _parse_return_list(
method: ast.FunctionDef | None, handle_multiple_of: bool = False
) -> list[str]:
"""Extract list items from a method's return statement."""
if method is None:
return []
for stmt in ast.walk(method):
if not isinstance(stmt, ast.Return):
continue
if not isinstance(stmt.value, ast.List):
continue
sizes = []
for elt in stmt.value.elts:
if isinstance(elt, ast.Constant):
sizes.append(str(elt.value))
elif (
handle_multiple_of
and isinstance(elt, ast.Call)
and isinstance(elt.func, ast.Name)
and elt.func.id == "MultipleOf"
and elt.args
and isinstance(elt.args[0], ast.Constant)
):
sizes.append(f"%{elt.args[0].value}")
if sizes:
return sizes
return []
def parse_block_sizes(node: ast.ClassDef) -> str: def parse_block_sizes(node: ast.ClassDef) -> str:
"""Parse get_supported_kernel_block_sizes method.""" """Parse get_supported_kernel_block_sizes method."""
method = find_method(node, "get_supported_kernel_block_sizes") method = find_method(node, "get_supported_kernel_block_sizes")
...@@ -536,202 +407,444 @@ def parse_compute_capability(node: ast.ClassDef) -> str: ...@@ -536,202 +407,444 @@ def parse_compute_capability(node: ast.ClassDef) -> str:
return f"{min_cap[0]}.x-{max_cap[0]}.x" return f"{min_cap[0]}.x-{max_cap[0]}.x"
return f"≥{min_cap[0]}.{min_cap[1]}" return f"≥{min_cap[0]}.{min_cap[1]}"
return "Any" return "Any"
def parse_attention_types(node: ast.ClassDef) -> str:
"""Parse supports_attn_type method."""
method = find_method(node, "supports_attn_type")
if method is None:
return "Decoder"
type_map = {
"DECODER": "Decoder",
"ENCODER": "Encoder",
"ENCODER_ONLY": "Encoder Only",
"ENCODER_DECODER": "Enc-Dec",
}
types: set[str] = set()
for n in ast.walk(method):
# Handle `attn_type in (AttentionType.DECODER, ...)`
if not (
isinstance(n, ast.Compare)
and len(n.ops) == 1
and isinstance(n.ops[0], ast.In)
and len(n.comparators) == 1
and isinstance(n.comparators[0], ast.Tuple | ast.Set)
):
continue
for elt in n.comparators[0].elts:
if isinstance(elt, ast.Attribute) and elt.attr in type_map:
types.add(type_map[elt.attr])
if not types:
return "Decoder"
return "All" if len(types) >= 3 else ", ".join(sorted(types))
def parse_impl_bool_attr(
tree: ast.AST,
class_name: str,
attr_name: str,
default: bool = False,
source_file: Path | None = None,
_visited: set[str] | None = None,
) -> bool:
"""Parse a boolean class attribute from an impl class, following inheritance.
Walks up the inheritance chain within the same file and across files
(by resolving imports) to find the attribute value.
"""
if _visited is None:
_visited = set()
if class_name in _visited:
return default
_visited.add(class_name)
class_node = find_class_in_ast(tree, class_name)
if class_node is None:
return default
# Check directly on this class
value = _find_bool_class_var(class_node, attr_name)
if value is not None:
return value
# Check parent class
parent_name = _get_parent_class_name(class_node)
if parent_name:
# Try parent in same file first
parent_node = find_class_in_ast(tree, parent_name)
if parent_node is not None:
return parse_impl_bool_attr(
tree, parent_name, attr_name, default, source_file, _visited
)
# Try resolving cross-file import
parent_file = _resolve_import_to_file(tree, parent_name, source_file)
if parent_file:
try:
parent_tree = ast.parse(parent_file.read_text())
return parse_impl_bool_attr(
parent_tree,
parent_name,
attr_name,
default,
parent_file,
_visited,
)
except Exception:
pass
return default
def analyze_backend(backend_name: str, class_path: str) -> dict[str, Any] | None:
"""Analyze a backend class and extract feature information."""
file_path = get_file_from_class_path(class_path)
if file_path is None:
return None
try:
tree = ast.parse(file_path.read_text())
except Exception as e:
print(f" Warning: Could not parse {file_path}: {e}", file=sys.stderr)
return None
class_name = class_path.rsplit(".", 1)[1]
class_node = find_class_in_ast(tree, class_name)
if class_node is None:
return None
# Check if this is an MLA backend by parent class or naming
parent = _get_parent_class_name(class_node)
mla_parents = {"MLACommonBackend", "FlashMLABackend", "FlashMLASparseBackend"}
is_mla_backend = (
parent in mla_parents
or ".mla." in class_path.lower()
or "_mla" in backend_name.lower()
)
# Determine compute capability - use N/A for non-CUDA backends
is_non_cuda = backend_name.startswith(("CPU_", "ROCM_"))
compute_cap = "N/A" if is_non_cuda else parse_compute_capability(class_node)
# Parse impl class features (DCP support)
impl_method = find_method(class_node, "get_impl_cls")
impl_class_name = None
if impl_method:
for stmt in ast.walk(impl_method):
if isinstance(stmt, ast.Return) and isinstance(stmt.value, ast.Name):
impl_class_name = stmt.value.id
break
supports_dcp = False
if impl_class_name:
supports_dcp = parse_impl_bool_attr(
tree, impl_class_name, "can_return_lse_for_decode", False, file_path
)
return {
"name": backend_name,
"dtypes": parse_supported_dtypes(class_node),
"kv_cache_dtypes": parse_kv_cache_dtypes(class_node),
"block_sizes": parse_block_sizes(class_node),
"head_sizes": parse_head_sizes(class_node),
"attn_types": parse_attention_types(class_node),
"compute_capability": compute_cap,
"is_mla": is_mla_backend or check_method_overrides(class_node, "is_mla"),
"supports_sink": check_method_overrides(class_node, "supports_sink"),
"is_sparse": check_method_overrides(class_node, "is_sparse"),
"supports_mm_prefix": check_method_overrides(class_node, "supports_mm_prefix"),
"supports_dcp": supports_dcp,
}
# ---------------------------------------------------------------------------
# Special backend variant parsers (FA2/FA3, FlashInfer TRTLLM, MLA prefill)
# ---------------------------------------------------------------------------
def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
"""Parse fa_utils.py to detect FA2 vs FA3 feature differences.
Returns a dict with 'fa2' and 'fa3' keys containing their respective
feature overrides for compute capability, KV cache dtypes, and sink support.
"""
if not FA_UTILS_FILE.exists():
return {}
try:
tree = ast.parse(FA_UTILS_FILE.read_text())
except Exception:
return {}
# Analyze the functions to determine FA3-specific features
fa3_supports_fp8 = False
fa3_supports_sinks = False
fa3_compute_cap: str | None = None
for node in ast.walk(tree):
if not isinstance(node, ast.FunctionDef):
continue
# Check flash_attn_supports_fp8 - looks for `get_flash_attn_version() == 3`
if node.name == "flash_attn_supports_fp8":
for n in ast.walk(node):
if (
isinstance(n, ast.Compare)
and isinstance(n.left, ast.Call)
and isinstance(n.left.func, ast.Name)
and n.left.func.id == "get_flash_attn_version"
):
fa3_supports_fp8 = True
break
# Check flash_attn_supports_sinks - looks for `get_flash_attn_version() == 3`
if node.name == "flash_attn_supports_sinks":
for n in ast.walk(node):
if (
isinstance(n, ast.Compare)
and isinstance(n.left, ast.Call)
and isinstance(n.left.func, ast.Name)
and n.left.func.id == "get_flash_attn_version"
):
fa3_supports_sinks = True
break
# Check get_flash_attn_version for FA3 compute capability
# Look for the ternary: 3 if (device_capability.major == 9 ...) else 2
if node.name == "get_flash_attn_version":
for n in ast.walk(node):
# Look for IfExp (ternary) with `device_capability.major == 9`
if isinstance(n, ast.IfExp):
test = n.test
# Check if test is a BoolOp (and) containing the major check
if isinstance(test, ast.BoolOp):
for val in test.values:
if (
isinstance(val, ast.Compare)
and isinstance(val.left, ast.Attribute)
and val.left.attr == "major"
and val.comparators
and isinstance(val.comparators[0], ast.Constant)
):
fa3_compute_cap = f"{val.comparators[0].value}.x"
break
return {
"fa2": {
"supports_fp8": False,
"supports_sink": False,
},
"fa3": {
"compute_capability": fa3_compute_cap,
"supports_fp8": fa3_supports_fp8,
"supports_sink": fa3_supports_sinks,
},
}
def parse_flashinfer_trtllm_features() -> dict[str, dict[str, Any]]:
"""Parse flashinfer.py to detect TRTLLM-specific features.
FLASHINFER uses TRTLLM attention on SM100 (Blackwell), which has different
capabilities (e.g., sink support) than native FlashInfer on earlier GPUs.
"""
if not FLASHINFER_UTILS_FILE.exists():
return {}
try:
tree = ast.parse(FLASHINFER_UTILS_FILE.read_text())
except Exception:
return {}
trtllm_compute_cap = _find_cc_in_function(tree, "supports_trtllm_attention")
def parse_attention_types(node: ast.ClassDef) -> str: if not trtllm_compute_cap:
"""Parse supports_attn_type method.""" return {}
method = find_method(node, "supports_attn_type")
if method is None:
return "Decoder"
type_map = { return {
"DECODER": "Decoder", "native": {
"ENCODER": "Encoder", # Native FlashInfer: everything except SM100
"ENCODER_ONLY": "Encoder Only", "supports_sink": False,
"ENCODER_DECODER": "Enc-Dec", },
"trtllm": {
# TRTLLM pathway on Blackwell
"compute_capability": trtllm_compute_cap,
"supports_sink": True,
},
} }
types: set[str] = set()
for n in ast.walk(method):
# Handle `attn_type in (AttentionType.DECODER, ...)`
if not (
isinstance(n, ast.Compare)
and len(n.ops) == 1
and isinstance(n.ops[0], ast.In)
and len(n.comparators) == 1
and isinstance(n.comparators[0], ast.Tuple | ast.Set)
):
continue
for elt in n.comparators[0].elts: def parse_mla_prefill_backends() -> list[dict[str, Any]]:
if isinstance(elt, ast.Attribute) and elt.attr in type_map: """Parse MLA prefill backend options from mla_attention.py.
types.add(type_map[elt.attr])
if not types: MLA uses different backends for prefill vs decode. The decode backends are
return "Decoder" registered in the registry, but prefill backends are selected at runtime
return "All" if len(types) >= 3 else ", ".join(sorted(types)) based on conditions in MLACommonImpl.__init__.
Returns a list of prefill backend info dicts with their requirements.
"""
if not MLA_ATTENTION_FILE.exists():
return []
def check_method_overrides(node: ast.ClassDef, method_name: str) -> bool: try:
"""Check if a method is overridden and returns True.""" tree = ast.parse(MLA_ATTENTION_FILE.read_text())
method = find_method(node, method_name) except Exception:
return method_returns_true(method) return []
# Find compute capability requirements by parsing use_* functions
trtllm_cc = _find_cc_in_function(tree, "use_trtllm_ragged_deepseek_prefill")
flashinfer_cc = _find_cc_in_function(tree, "use_flashinfer_prefill")
cudnn_cc = _find_cc_in_function(tree, "use_cudnn_prefill")
def analyze_backend(backend_name: str, class_path: str) -> dict[str, Any] | None: # Build prefill backend list based on what we found
"""Analyze a backend class and extract feature information.""" # Order matches the priority in MLACommonImpl.__init__
file_path = get_file_from_class_path(class_path) prefill_backends: list[dict[str, Any]] = []
if file_path is None:
return None
try: # TRT-LLM Ragged (highest priority if available)
tree = ast.parse(file_path.read_text()) if trtllm_cc:
except Exception as e: prefill_backends.append(
print(f" Warning: Could not parse {file_path}: {e}", file=sys.stderr) {
return None "name": "TRT-LLM Ragged‡",
"description": "TensorRT-LLM ragged attention",
"compute_capability": trtllm_cc,
"enable": "Default on SM100",
"disable": "`-ac.use_trtllm_ragged_deepseek_prefill=0`",
"notes": "DeepSeek R1 dims only",
}
)
class_name = class_path.rsplit(".", 1)[1] # FlashInfer prefill
class_node = find_class_in_ast(tree, class_name) if flashinfer_cc:
if class_node is None: prefill_backends.append(
return None {
"name": "FlashInfer",
"description": "FlashInfer CUTLASS backend",
"compute_capability": flashinfer_cc,
"enable": "`-ac.disable_flashinfer_prefill=0`",
"disable": "`-ac.disable_flashinfer_prefill=1`",
"notes": "DeepSeek R1 dims only",
}
)
# Check if this is an MLA backend by parent class or naming # cuDNN prefill
parent = None if cudnn_cc:
if class_node.bases: prefill_backends.append(
base = class_node.bases[0] {
parent = base.id if isinstance(base, ast.Name) else None "name": "cuDNN",
mla_parents = {"MLACommonBackend", "FlashMLABackend", "FlashMLASparseBackend"} "description": "cuDNN-based attention",
is_mla_backend = ( "compute_capability": cudnn_cc,
parent in mla_parents "enable": "`-ac.use_cudnn_prefill=1`",
or ".mla." in class_path.lower() "disable": "`-ac.use_cudnn_prefill=0`",
or "_mla" in backend_name.lower() "notes": "",
}
)
# FlashAttention is always available as fallback
prefill_backends.append(
{
"name": "FlashAttention",
"description": "FlashAttention varlen (FA2/FA3)",
"compute_capability": "Any",
"enable": "Default fallback",
"disable": "Use other backends",
"notes": "FA3 on SM90, FA2 otherwise",
}
) )
# Determine compute capability - use N/A for non-CUDA backends return prefill_backends
is_non_cuda = backend_name.startswith(("CPU_", "ROCM_"))
compute_cap = "N/A" if is_non_cuda else parse_compute_capability(class_node)
return {
"name": backend_name,
"dtypes": parse_supported_dtypes(class_node),
"kv_cache_dtypes": parse_kv_cache_dtypes(class_node),
"block_sizes": parse_block_sizes(class_node),
"head_sizes": parse_head_sizes(class_node),
"attn_types": parse_attention_types(class_node),
"compute_capability": compute_cap,
"is_mla": is_mla_backend or check_method_overrides(class_node, "is_mla"),
"supports_sink": check_method_overrides(class_node, "supports_sink"),
"is_sparse": check_method_overrides(class_node, "is_sparse"),
"supports_mm_prefix": check_method_overrides(class_node, "supports_mm_prefix"),
}
# ---------------------------------------------------------------------------
# Backend variant expansion (FA2/FA3, FlashInfer native/TRTLLM)
# ---------------------------------------------------------------------------
def add_literal_quotes(value: str) -> str:
"""Add literal backticks around all comma-separated items in a string."""
items = [item.strip() for item in value.split(",")]
quoted_items = [f"`{item}`" for item in items]
return ", ".join(quoted_items)
def _expand_flash_attn_variants(
all_backends: list[dict[str, Any]],
fa_features: dict[str, dict[str, Any]],
) -> list[dict[str, Any]]:
"""Expand FLASH_ATTN into FA2 and FA3 variants with different capabilities."""
expanded = []
for backend in all_backends:
if backend["name"] != "FLASH_ATTN":
backend.setdefault("_sort_key", backend["name"])
backend.setdefault("_sort_order", 0)
backend.setdefault("version", "")
expanded.append(backend)
continue
def bool_to_emoji(value: bool) -> str: # Create FA2 entry (keeps base backend's compute_capability)
"""Convert a boolean to a checkmark or X emoji.""" fa2 = backend.copy()
return "✅" if value else "❌" fa2["version"] = "FA2*"
fa2["_sort_key"] = "FLASH_ATTN"
fa2["_sort_order"] = 0
fa2["supports_sink"] = fa_features["fa2"]["supports_sink"]
# Create FA3 entry (uses parsed compute_capability from fa_utils)
fa3 = backend.copy()
fa3["version"] = "FA3*"
fa3["_sort_key"] = "FLASH_ATTN"
fa3["_sort_order"] = 1
if fa_features["fa3"]["compute_capability"]:
fa3["compute_capability"] = fa_features["fa3"]["compute_capability"]
fa3["supports_sink"] = fa_features["fa3"]["supports_sink"]
if fa_features["fa3"]["supports_fp8"]:
base_dtypes = backend["kv_cache_dtypes"].split(", ")
fp8_dtypes = ["fp8", "fp8_e4m3", "fp8_e5m2"]
new_dtypes = [d for d in fp8_dtypes if d not in base_dtypes]
fa3["kv_cache_dtypes"] = ", ".join(base_dtypes + new_dtypes)
expanded.append(fa2)
expanded.append(fa3)
return expanded
def _expand_flashinfer_variants(
all_backends: list[dict[str, Any]],
fi_features: dict[str, dict[str, Any]],
) -> list[dict[str, Any]]:
"""Expand FLASHINFER into native and TRTLLM variants."""
expanded = []
for backend in all_backends:
if backend["name"] != "FLASHINFER":
expanded.append(backend)
continue
# Parse original compute capability to get min CC
orig_cap = backend["compute_capability"]
parts = orig_cap.replace(".x", "").split("-")
min_cc = parts[0] if parts else "7"
trtllm_cc = fi_features["trtllm"]["compute_capability"]
def generate_markdown_table( # Create native entry (pre-Blackwell GPUs)
backends: list[dict[str, Any]], title: str, is_mla_table: bool = False native = backend.copy()
) -> str: native["version"] = "Native†"
"""Generate a markdown table from backend info. native["_sort_key"] = "FLASHINFER"
native["_sort_order"] = 0
native["supports_sink"] = fi_features["native"]["supports_sink"]
native["compute_capability"] = f"{min_cc}.x-9.x"
Args: # Create TRTLLM entry
backends: List of backend info dictionaries. trtllm = backend.copy()
title: Table title. trtllm["version"] = "TRTLLM†"
is_mla_table: If True, include MLA and Sparse columns (for MLA table). trtllm["_sort_key"] = "FLASHINFER"
If False, exclude them (for standard attention table). trtllm["_sort_order"] = 1
""" trtllm["compute_capability"] = trtllm_cc
if not backends: trtllm["supports_sink"] = fi_features["trtllm"]["supports_sink"]
return f"## {title}\n\nNo backends found.\n"
# Check if any backend has a version (for FA2/FA3 split) expanded.append(native)
has_versions = any(b.get("version") for b in backends) expanded.append(trtllm)
return expanded
if is_mla_table:
header = (
"| Backend | Dtypes | KV Dtypes | Block Sizes | Head Sizes "
"| Sink | Sparse | MM Prefix | Attention Types | Compute Cap. |"
)
separator = (
"|---------|--------|-----------|-------------|------------"
"|------|--------|-----------|-----------------|--------------|"
)
elif has_versions:
header = (
"| Backend | Version | Dtypes | KV Dtypes | Block Sizes "
"| Head Sizes | Sink | MM Prefix | Attention Types | Compute Cap. |"
)
separator = (
"|---------|---------|--------|-----------|-------------"
"|------------|------|-----------|-----------------|--------------|"
)
else:
header = (
"| Backend | Dtypes | KV Dtypes | Block Sizes | Head Sizes "
"| Sink | MM Prefix | Attention Types | Compute Cap. |"
)
separator = (
"|---------|--------|-----------|-------------|------------"
"|------|-----------|-----------------|--------------|"
)
lines = [f"## {title}", "", header, separator]
def sort_key(x: dict[str, Any]) -> tuple[str, int]:
"""Sort key that keeps parent/child rows together in order."""
return (x.get("_sort_key", x["name"]), x.get("_sort_order", 0))
for info in sorted(backends, key=sort_key):
if is_mla_table:
row = "| `{}` | {} | {} | {} | {} | {} | {} | {} | {} | {} |".format(
info["name"],
info["dtypes"],
add_literal_quotes(info["kv_cache_dtypes"]),
info["block_sizes"],
info["head_sizes"],
bool_to_emoji(info["supports_sink"]),
bool_to_emoji(info["is_sparse"]),
bool_to_emoji(info["supports_mm_prefix"]),
info["attn_types"],
info["compute_capability"],
)
elif has_versions:
row = "| `{}` | {} | {} | {} | {} | {} | {} | {} | {} | {} |".format(
info["name"],
info.get("version", ""),
info["dtypes"],
add_literal_quotes(info["kv_cache_dtypes"]),
info["block_sizes"],
info["head_sizes"],
bool_to_emoji(info["supports_sink"]),
bool_to_emoji(info["supports_mm_prefix"]),
info["attn_types"],
info["compute_capability"],
)
else:
row = "| `{}` | {} | {} | {} | {} | {} | {} | {} | {} |".format(
info["name"],
info["dtypes"],
add_literal_quotes(info["kv_cache_dtypes"]),
info["block_sizes"],
info["head_sizes"],
bool_to_emoji(info["supports_sink"]),
bool_to_emoji(info["supports_mm_prefix"]),
info["attn_types"],
info["compute_capability"],
)
lines.append(row)
lines.append("") # ---------------------------------------------------------------------------
return "\n".join(lines) # CUDA priority list parsing
# ---------------------------------------------------------------------------
def parse_cuda_priority_lists() -> dict[str, list[str]]: def parse_cuda_priority_lists() -> dict[str, list[str]]:
...@@ -827,6 +940,105 @@ def _extract_priorities(body: list, priorities: dict[str, list[str]], prefix: st ...@@ -827,6 +940,105 @@ def _extract_priorities(body: list, priorities: dict[str, list[str]], prefix: st
priorities[f"{prefix}_default"] = backends priorities[f"{prefix}_default"] = backends
# ---------------------------------------------------------------------------
# Data-driven table rendering
#
# Each column is a (header, formatter) pair. The formatter takes a backend
# info dict and returns the cell string. Tables are assembled by selecting
# which columns to include, then calling _render_table().
# ---------------------------------------------------------------------------
# Column type alias for readability
TableColumn = tuple[str, Callable[[dict[str, Any]], str]]
# Shared column definitions -- order here matches the output table order
_COL_BACKEND: TableColumn = ("Backend", lambda b: f"`{b['name']}`")
_COL_VERSION: TableColumn = ("Version", lambda b: b.get("version", ""))
_COL_DTYPES: TableColumn = ("Dtypes", lambda b: b["dtypes"])
_COL_KV_DTYPES: TableColumn = (
"KV Dtypes",
lambda b: add_literal_quotes(b["kv_cache_dtypes"]),
)
_COL_BLOCK_SIZES: TableColumn = ("Block Sizes", lambda b: b["block_sizes"])
_COL_HEAD_SIZES: TableColumn = ("Head Sizes", lambda b: b["head_sizes"])
_COL_SINK: TableColumn = ("Sink", lambda b: bool_to_emoji(b["supports_sink"]))
_COL_SPARSE: TableColumn = ("Sparse", lambda b: bool_to_emoji(b["is_sparse"]))
_COL_MM_PREFIX: TableColumn = (
"MM Prefix",
lambda b: bool_to_emoji(b["supports_mm_prefix"]),
)
_COL_DCP: TableColumn = ("DCP", lambda b: bool_to_emoji(b["supports_dcp"]))
_COL_ATTN_TYPES: TableColumn = ("Attention Types", lambda b: b["attn_types"])
_COL_COMPUTE_CAP: TableColumn = ("Compute Cap.", lambda b: b["compute_capability"])
def add_literal_quotes(value: str) -> str:
"""Add literal backticks around all comma-separated items in a string."""
items = [item.strip() for item in value.split(",")]
return ", ".join(f"`{item}`" for item in items)
def bool_to_emoji(value: bool) -> str:
"""Convert a boolean to a checkmark or X emoji."""
return "✅" if value else "❌"
def _build_columns(is_mla: bool, has_versions: bool) -> list[TableColumn]:
"""Build the column list for a backend feature table.
The column selection depends on whether it's an MLA table (includes
Sparse column) and whether any backend has version variants (includes
Version column).
"""
cols: list[TableColumn] = [_COL_BACKEND]
if has_versions:
cols.append(_COL_VERSION)
cols.extend([_COL_DTYPES, _COL_KV_DTYPES, _COL_BLOCK_SIZES, _COL_HEAD_SIZES])
cols.append(_COL_SINK)
if is_mla:
cols.append(_COL_SPARSE)
cols.extend([_COL_MM_PREFIX, _COL_DCP, _COL_ATTN_TYPES, _COL_COMPUTE_CAP])
return cols
def _sort_key(x: dict[str, Any]) -> tuple[str, int]:
"""Sort key that keeps parent/child rows together in order."""
return (x.get("_sort_key", x["name"]), x.get("_sort_order", 0))
def _render_table(
columns: list[TableColumn],
backends: list[dict[str, Any]],
) -> list[str]:
"""Render a markdown table from column specs and backend data."""
header = "| " + " | ".join(name for name, _ in columns) + " |"
sep = "|" + "|".join("-" * (len(name) + 2) for name, _ in columns) + "|"
lines = [header, sep]
for info in sorted(backends, key=_sort_key):
row = "| " + " | ".join(fmt(info) for _, fmt in columns) + " |"
lines.append(row)
return lines
def generate_markdown_table(
backends: list[dict[str, Any]], title: str, is_mla_table: bool = False
) -> str:
"""Generate a titled markdown table from backend info."""
if not backends:
return f"## {title}\n\nNo backends found.\n"
has_versions = any(b.get("version") for b in backends)
columns = _build_columns(is_mla_table, has_versions)
lines = [f"## {title}", ""]
lines.extend(_render_table(columns, backends))
lines.append("")
return "\n".join(lines)
# ---------------------------------------------------------------------------
# Markdown section generators (usage, priority, legend, MLA)
# ---------------------------------------------------------------------------
def generate_usage_section() -> str: def generate_usage_section() -> str:
"""Generate the usage documentation section.""" """Generate the usage documentation section."""
return """## Setting the Attention Backend return """## Setting the Attention Backend
...@@ -959,6 +1171,27 @@ def generate_priority_section(priorities: dict[str, list[str]]) -> str: ...@@ -959,6 +1171,27 @@ def generate_priority_section(priorities: dict[str, list[str]]) -> str:
return "\n".join(lines) return "\n".join(lines)
def generate_legend() -> str:
"""Generate a legend explaining the table columns."""
return """## Legend
| Column | Description |
|--------|-------------|
| **Dtypes** | Supported model data types (fp16, bf16, fp32) |
| **KV Dtypes** | Supported KV cache data types (`auto`, `fp8`, `fp8_e4m3`, etc.) |
| **Block Sizes** | Supported KV cache block sizes (%N means multiples of N) |
| **Head Sizes** | Supported attention head sizes |
| **Sink** | Attention sink support (for StreamingLLM) |
| **Sparse** | Sparse attention support (MLA only) |
| **MM Prefix** | Multimodal prefix full attention support |
| **DCP** | Decode Context Parallelism support (`--decode-context-parallel-size`) |
| **Attention Types** | Supported attention patterns (Decoder, Encoder, Enc-Dec) |
| **Compute Cap.** | Required CUDA compute capability (N/A for non-CUDA backends) |
**Symbols:** ✅ = Supported, ❌ = Not supported
"""
def generate_mla_section( def generate_mla_section(
prefill_backends: list[dict[str, Any]], decode_backends: list[dict[str, Any]] prefill_backends: list[dict[str, Any]], decode_backends: list[dict[str, Any]]
) -> str: ) -> str:
...@@ -999,57 +1232,17 @@ def generate_mla_section( ...@@ -999,57 +1232,17 @@ def generate_mla_section(
] ]
) )
# Generate decode backends table # Reuse data-driven table rendering for decode backends
header = ( columns = _build_columns(is_mla=True, has_versions=False)
"| Backend | Dtypes | KV Dtypes | Block Sizes | Head Sizes " lines.extend(_render_table(columns, decode_backends))
"| Sink | Sparse | MM Prefix | Attention Types | Compute Cap. |"
)
separator = (
"|---------|--------|-----------|-------------|------------"
"|------|--------|-----------|-----------------|--------------|"
)
lines.extend([header, separator])
def sort_key(x: dict[str, Any]) -> tuple[str, int]:
return (x.get("_sort_key", x["name"]), x.get("_sort_order", 0))
for info in sorted(decode_backends, key=sort_key):
row = "| `{}` | {} | {} | {} | {} | {} | {} | {} | {} | {} |".format(
info["name"],
info["dtypes"],
add_literal_quotes(info["kv_cache_dtypes"]),
info["block_sizes"],
info["head_sizes"],
bool_to_emoji(info["supports_sink"]),
bool_to_emoji(info["is_sparse"]),
bool_to_emoji(info["supports_mm_prefix"]),
info["attn_types"],
info["compute_capability"],
)
lines.append(row)
lines.append("") lines.append("")
return "\n".join(lines) return "\n".join(lines)
def generate_legend() -> str: # ---------------------------------------------------------------------------
"""Generate a legend explaining the table columns.""" # Top-level orchestration
return """## Legend # ---------------------------------------------------------------------------
| Column | Description |
|--------|-------------|
| **Dtypes** | Supported model data types (fp16, bf16, fp32) |
| **KV Dtypes** | Supported KV cache data types (`auto`, `fp8`, `fp8_e4m3`, etc.) |
| **Block Sizes** | Supported KV cache block sizes (%N means multiples of N) |
| **Head Sizes** | Supported attention head sizes |
| **Sink** | Attention sink support (for StreamingLLM) |
| **Sparse** | Sparse attention support (MLA only) |
| **MM Prefix** | Multimodal prefix full attention support |
| **Attention Types** | Supported attention patterns (Decoder, Encoder, Enc-Dec) |
| **Compute Cap.** | Required CUDA compute capability (N/A for non-CUDA backends) |
**Symbols:** ✅ = Supported, ❌ = Not supported
"""
def generate_docs() -> str: def generate_docs() -> str:
...@@ -1071,86 +1264,17 @@ def generate_docs() -> str: ...@@ -1071,86 +1264,17 @@ def generate_docs() -> str:
# Collect backend info # Collect backend info
all_backends = [] all_backends = []
for backend_name, class_path in attention_backends_map.items(): for backend_name, class_path in attention_backends_map.items():
if backend_name in ("CUSTOM", "TORCH_SDPA"): if backend_name in SKIP_BACKENDS:
continue continue
info = analyze_backend(backend_name, class_path) info = analyze_backend(backend_name, class_path)
if info: if info:
all_backends.append(info) all_backends.append(info)
# Expand FLASH_ATTN into FA2 and FA3 variants with different capabilities # Expand backends into version variants
if fa_features: if fa_features:
expanded_backends = [] all_backends = _expand_flash_attn_variants(all_backends, fa_features)
for backend in all_backends:
if backend["name"] == "FLASH_ATTN":
# Create FA2 entry (keeps base backend's compute_capability)
fa2 = backend.copy()
fa2["name"] = "FLASH_ATTN"
fa2["version"] = "FA2*"
fa2["_sort_key"] = "FLASH_ATTN"
fa2["_sort_order"] = 0
fa2["supports_sink"] = fa_features["fa2"]["supports_sink"]
# Create FA3 entry (uses parsed compute_capability from fa_utils)
fa3 = backend.copy()
fa3["name"] = "FLASH_ATTN"
fa3["version"] = "FA3*"
fa3["_sort_key"] = "FLASH_ATTN"
fa3["_sort_order"] = 1
if fa_features["fa3"]["compute_capability"]:
fa3["compute_capability"] = fa_features["fa3"]["compute_capability"]
fa3["supports_sink"] = fa_features["fa3"]["supports_sink"]
if fa_features["fa3"]["supports_fp8"]:
# Add fp8 dtypes to the base backend's kv_cache_dtypes
base_dtypes = backend["kv_cache_dtypes"].split(", ")
fp8_dtypes = ["fp8", "fp8_e4m3", "fp8_e5m2"]
new_dtypes = [d for d in fp8_dtypes if d not in base_dtypes]
fa3["kv_cache_dtypes"] = ", ".join(base_dtypes + new_dtypes)
# Add FA2 first, then FA3
expanded_backends.append(fa2)
expanded_backends.append(fa3)
else:
backend["_sort_key"] = backend["name"]
backend["_sort_order"] = 0
backend["version"] = "" # No version for other backends
expanded_backends.append(backend)
all_backends = expanded_backends
# Expand FLASHINFER into native and TRTLLM variants
if fi_features: if fi_features:
expanded_backends = [] all_backends = _expand_flashinfer_variants(all_backends, fi_features)
for backend in all_backends:
if backend["name"] == "FLASHINFER":
# Parse original compute capability to get min CC
orig_cap = backend["compute_capability"]
parts = orig_cap.replace(".x", "").split("-")
min_cc = parts[0] if parts else "7"
trtllm_cc = fi_features["trtllm"]["compute_capability"]
# Create native entry (pre-Blackwell GPUs)
native = backend.copy()
native["name"] = "FLASHINFER"
native["version"] = "Native†"
native["_sort_key"] = "FLASHINFER"
native["_sort_order"] = 0
native["supports_sink"] = fi_features["native"]["supports_sink"]
# Native FlashInfer is used on GPUs before SM100 (Blackwell)
native["compute_capability"] = f"{min_cc}.x-9.x"
# Create TRTLLM entry
trtllm = backend.copy()
trtllm["name"] = "FLASHINFER"
trtllm["version"] = "TRTLLM†"
trtllm["_sort_key"] = "FLASHINFER"
trtllm["_sort_order"] = 1
trtllm["compute_capability"] = trtllm_cc
trtllm["supports_sink"] = fi_features["trtllm"]["supports_sink"]
expanded_backends.append(native)
expanded_backends.append(trtllm)
else:
expanded_backends.append(backend)
all_backends = expanded_backends
# Split into MLA and non-MLA # Split into MLA and non-MLA
mla_backends = [b for b in all_backends if b["is_mla"]] mla_backends = [b for b in all_backends if b["is_mla"]]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment