generate_attention_backend_docs.py 52.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Generates documentation table for attention backends showing feature support.

This script parses all registered attention backends using AST (no imports needed)
and generates a markdown table showing what features each backend supports,
based on the checks in AttentionBackend.validate_configuration().

This approach avoids requiring CUDA/ROCm/GPU libraries to be installed.

When used as a pre-commit hook, this script receives filenames as arguments
and only runs the check if any of the relevant files were modified.
"""

import argparse
import ast
import fnmatch
import sys
20
from collections.abc import Callable
21
22
23
from pathlib import Path
from typing import Any


# ---------------------------------------------------------------------------
# Constants and file paths
# ---------------------------------------------------------------------------

# Repository root: this script lives at tools/pre_commit/<script>.py.
REPO_ROOT = Path(__file__).parent.parent.parent

# Repo-relative paths whose modification should trigger the docs check.
# Matched with fnmatch, whose "*" also matches "/", so "backends/*.py" already
# covers nested modules; the "**" entry is kept for readability.
RELEVANT_PATTERNS = [
    "vllm/v1/attention/backends/*.py",
    "vllm/v1/attention/backends/**/*.py",
    "vllm/v1/attention/backends/fa_utils.py",
    "vllm/model_executor/layers/attention/mla_attention.py",
    "vllm/platforms/cuda.py",
    # Both of these are parsed by this script (FLASHINFER_UTILS_FILE and the
    # FA4 compute-capability fallback), so edits there can change the table.
    "vllm/utils/flashinfer.py",
    "vllm/vllm_flash_attn/flash_attn_interface.py",
    "tools/pre_commit/generate_attention_backend_docs.py",
    "docs/design/attention_backends.md",
]

BACKENDS_DIR = REPO_ROOT / "vllm" / "v1" / "attention" / "backends"
REGISTRY_FILE = BACKENDS_DIR / "registry.py"
CUDA_PLATFORM_FILE = REPO_ROOT / "vllm" / "platforms" / "cuda.py"
FA_UTILS_FILE = BACKENDS_DIR / "fa_utils.py"
FLASHINFER_UTILS_FILE = REPO_ROOT / "vllm" / "utils" / "flashinfer.py"
MLA_ATTENTION_FILE = (
    REPO_ROOT / "vllm" / "model_executor" / "layers" / "attention" / "mla_attention.py"
)

# Backends to skip during doc generation
SKIP_BACKENDS = {"CUSTOM", "TORCH_SDPA"}


def is_relevant_file(filepath: str) -> bool:
    """Return True when *filepath* matches one of ``RELEVANT_PATTERNS``.

    Absolute paths are first made relative to ``REPO_ROOT``. An absolute path
    outside the repository is never considered relevant.
    """
    candidate = Path(filepath)
    if candidate.is_absolute():
        try:
            candidate = candidate.relative_to(REPO_ROOT)
        except ValueError:
            # Not located inside the repository checkout.
            return False
    relative = str(candidate)
    for pattern in RELEVANT_PATTERNS:
        if fnmatch.fnmatch(relative, pattern):
            return True
    return False


# ---------------------------------------------------------------------------
# AST utility helpers
# ---------------------------------------------------------------------------


71
72
def find_class_in_ast(tree: ast.AST, class_name: str) -> ast.ClassDef | None:
    """Find a class definition in an AST."""
73
    for node in ast.walk(tree):
74
75
76
        if isinstance(node, ast.ClassDef) and node.name == class_name:
            return node
    return None
77
78


79
80
def find_method(node: ast.ClassDef, method_name: str) -> ast.FunctionDef | None:
    """Find a method in a class definition."""
81
    for item in node.body:
82
83
84
        if isinstance(item, ast.FunctionDef) and item.name == method_name:
            return item
    return None
85
86


87
88
89
90
91
92
93
94
95
96
97
98
def method_returns_true(method: ast.FunctionDef | None) -> bool:
    """Check if a method simply returns True."""
    if method is None:
        return False
    for node in ast.walk(method):
        if (
            isinstance(node, ast.Return)
            and isinstance(node.value, ast.Constant)
            and node.value.value is True
        ):
            return True
    return False
99
100


def check_method_overrides(node: ast.ClassDef, method_name: str) -> bool:
    """Return True if *node* itself defines *method_name* with a ``return True``."""
    overridden = find_method(node, method_name)
    return method_returns_true(overridden)


106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
def _find_bool_class_var(class_node: ast.ClassDef, var_name: str) -> bool | None:
    """Find a bool class variable in a class definition. Returns None if not found."""
    for item in class_node.body:
        # Check for annotated assignment: attr: bool = True/False
        if (
            isinstance(item, ast.AnnAssign)
            and isinstance(item.target, ast.Name)
            and item.target.id == var_name
            and isinstance(item.value, ast.Constant)
            and isinstance(item.value.value, bool)
        ):
            return item.value.value
        # Check for plain assignment: attr = True/False
        if isinstance(item, ast.Assign):
            for target in item.targets:
121
                if (
122
123
124
125
                    isinstance(target, ast.Name)
                    and target.id == var_name
                    and isinstance(item.value, ast.Constant)
                    and isinstance(item.value.value, bool)
126
                ):
127
128
                    return item.value.value
    return None
129
130


131
132
133
134
def _parse_list_class_var(node: ast.ClassDef, var_name: str) -> list[str] | None:
    """Parse a list-type class variable, returning None if not found."""
    for item in node.body:
        if not isinstance(item, ast.AnnAssign):
135
            continue
136
137
138
139
140
141
142
143
144
145
146
147
148
149
        if not isinstance(item.target, ast.Name):
            continue
        if item.target.id != var_name:
            continue
        if not (item.value and isinstance(item.value, ast.List)):
            continue
        result = []
        for elt in item.value.elts:
            if isinstance(elt, ast.Attribute):
                result.append(elt.attr)
            elif isinstance(elt, ast.Constant):
                result.append(str(elt.value))
        return result
    return None
150
151


152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
def _parse_return_list(
    method: ast.FunctionDef | None, handle_multiple_of: bool = False
) -> list[str]:
    """Extract list items from a method's return statement."""
    if method is None:
        return []
    for stmt in ast.walk(method):
        if not isinstance(stmt, ast.Return):
            continue
        if not isinstance(stmt.value, ast.List):
            continue
        sizes = []
        for elt in stmt.value.elts:
            if isinstance(elt, ast.Constant):
                sizes.append(str(elt.value))
            elif (
                handle_multiple_of
                and isinstance(elt, ast.Call)
                and isinstance(elt.func, ast.Name)
                and elt.func.id == "MultipleOf"
                and elt.args
                and isinstance(elt.args[0], ast.Constant)
            ):
                sizes.append(f"%{elt.args[0].value}")
        if sizes:
            return sizes
    return []
179
180


181
182
def _get_parent_class_name(class_node: ast.ClassDef) -> str | None:
    """Get the first parent class name (simple name only).
183

184
185
    Handles both simple inheritance (class Foo(Bar)) and generic
    inheritance (class Foo(Bar[T])).
186
    """
187
188
189
190
191
192
193
194
    if not class_node.bases:
        return None
    base = class_node.bases[0]
    if isinstance(base, ast.Name):
        return base.id
    if isinstance(base, ast.Subscript) and isinstance(base.value, ast.Name):
        return base.value.id
    return None
195
196


197
198
199
200
def _resolve_import_to_file(
    tree: ast.AST, class_name: str, source_file: Path | None = None
) -> Path | None:
    """Try to resolve a class name to its source file via imports in the AST.
201

202
203
204
    Handles both absolute imports (from vllm.foo import Bar) and relative
    imports (from .foo import Bar) when source_file is provided.
    """
205
    for node in ast.walk(tree):
206
        if not isinstance(node, ast.ImportFrom):
207
            continue
208
209
210
211
212
213
        for alias in node.names:
            actual_name = alias.asname or alias.name
            if actual_name != class_name:
                continue
            if not node.module:
                continue
214

215
216
217
218
219
220
221
222
223
224
225
            if node.level and node.level > 0 and source_file:
                # Relative import: resolve from the source file's directory
                base_dir = source_file.parent
                for _ in range(node.level - 1):
                    base_dir = base_dir.parent
                module_path = node.module.replace(".", "/")
                py_file = base_dir / f"{module_path}.py"
            else:
                # Absolute import
                module_path = node.module.replace(".", "/")
                py_file = REPO_ROOT / f"{module_path}.py"
226

227
228
229
            if py_file.exists():
                return py_file
    return None
230
231


232
233
def _find_cc_in_function(tree: ast.AST, func_name: str) -> str | None:
    """Find a compute capability from is_device_capability_family() calls in a function.
234

235
236
237
    Looks for the pattern: current_platform.is_device_capability_family(N)
    and converts N (e.g. 100) to a CC string (e.g. "10.x").
    """
238
    for node in ast.walk(tree):
239
240
241
242
243
244
245
246
247
248
249
250
        if not isinstance(node, ast.FunctionDef) or node.name != func_name:
            continue
        for n in ast.walk(node):
            if (
                isinstance(n, ast.Call)
                and isinstance(n.func, ast.Attribute)
                and n.func.attr == "is_device_capability_family"
                and n.args
                and isinstance(n.args[0], ast.Constant)
                and isinstance(n.args[0].value, int)
            ):
                return f"{n.args[0].value // 10}.x"
251
252
253
    return None


# ---------------------------------------------------------------------------
# Registry and file resolution
# ---------------------------------------------------------------------------


def parse_registry() -> dict[str, str]:
    """Map backend enum names to the class paths declared in registry.py.

    Reads ``REGISTRY_FILE`` and extracts the members of the first
    ``AttentionBackendEnum`` class found. Returns an empty dict if that class
    is absent.
    """
    registry_tree = ast.parse(REGISTRY_FILE.read_text())
    enum_class = next(
        (
            node
            for node in ast.walk(registry_tree)
            if isinstance(node, ast.ClassDef) and node.name == "AttentionBackendEnum"
        ),
        None,
    )
    if enum_class is None:
        return {}
    return _extract_enum_values(enum_class)


def _extract_enum_values(node: ast.ClassDef) -> dict[str, str]:
    """Extract an enum's member-name -> string-value mapping.

    Only simple ``NAME = "literal"`` assignments in the class body are used.
    Members whose value is not a non-empty string literal (e.g. ``""``,
    ``auto()``, or a number) are skipped, so every value really is a ``str``
    as the return type promises.
    """
    result: dict[str, str] = {}
    for item in node.body:
        if not isinstance(item, ast.Assign):
            continue
        value = item.value
        # Values are presumably consumed as dotted class paths (see
        # get_file_from_class_path), so non-str constants are dropped.
        if not (isinstance(value, ast.Constant) and isinstance(value.value, str)):
            continue
        if not value.value:
            continue
        for target in item.targets:
            if isinstance(target, ast.Name):
                result[target.id] = value.value
    return result


def get_file_from_class_path(class_path: str) -> Path | None:
    """Convert a class path to a file path."""
    if not class_path:
        return None
    module_path = class_path.rsplit(".", 1)[0].replace(".", "/")
    py_file = REPO_ROOT / f"{module_path}.py"
    return py_file if py_file.exists() else None


# ---------------------------------------------------------------------------
# Backend feature extraction from AST
# ---------------------------------------------------------------------------


def parse_supported_dtypes(node: ast.ClassDef) -> str:
    """Render a backend's ``supported_dtypes`` class variable for the table.

    torch dtype names are abbreviated (float16 -> fp16, bfloat16 -> bf16,
    float32 -> fp32); unknown names pass through unchanged. When the class
    does not declare the variable, "fp16, bf16" is reported.
    """
    declared = _parse_list_class_var(node, "supported_dtypes")
    if declared is None:
        return "fp16, bf16"
    short_names = {"float16": "fp16", "bfloat16": "bf16", "float32": "fp32"}
    return ", ".join([short_names.get(name, name) for name in declared])


def parse_kv_cache_dtypes(node: ast.ClassDef) -> str:
    """Render the KV-cache dtypes a backend supports as a comma-joined string.

    Sources, in priority order:
      1. a non-empty ``supported_kv_cache_dtypes`` class variable;
      2. a ``kv_cache_dtype in [...]`` membership test inside
         ``supports_kv_cache_dtype``, where the container is a list, tuple or
         set literal of strings.
    Falls back to "auto" when neither yields any dtype.
    """
    # First try the class variable
    dtypes = _parse_list_class_var(node, "supported_kv_cache_dtypes")
    if dtypes:
        return ", ".join(dtypes)

    # Fall back to parsing the supports_kv_cache_dtype method, looking for a
    # `kv_cache_dtype in ["auto", "bfloat16"]`-style test. Tuple and set
    # literals are accepted as well as lists.
    method = find_method(node, "supports_kv_cache_dtype")
    if method is None:
        return "auto"
    for n in ast.walk(method):
        if not (
            isinstance(n, ast.Compare)
            and len(n.ops) == 1
            and isinstance(n.ops[0], ast.In)
            and len(n.comparators) == 1
            and isinstance(n.comparators[0], ast.List | ast.Tuple | ast.Set)
        ):
            continue
        found = [
            e.value
            for e in n.comparators[0].elts
            if isinstance(e, ast.Constant) and isinstance(e.value, str)
        ]
        if found:
            return ", ".join(found)

    return "auto"


def parse_block_sizes(node: ast.ClassDef) -> str:
    """Render the kernel block sizes from ``get_supported_kernel_block_sizes``.

    ``MultipleOf(N)`` entries appear as ``%N``. "Any" is reported when the
    method is missing or returns no recognizable list.
    """
    sizes = _parse_return_list(
        find_method(node, "get_supported_kernel_block_sizes"),
        handle_multiple_of=True,
    )
    if not sizes:
        return "Any"
    return ", ".join(sizes)


def parse_head_sizes(node: ast.ClassDef) -> str:
    """Render the head sizes returned by ``get_supported_head_sizes``.

    "Any" is reported when the method is missing or returns no recognizable
    list of constants.
    """
    sizes = _parse_return_list(find_method(node, "get_supported_head_sizes"))
    if not sizes:
        return "Any"
    return ", ".join(sizes)


def parse_compute_capability(node: ast.ClassDef) -> str:
    """Summarize ``supports_compute_capability`` as a compute-capability range.

    Recognized comparisons inside the method:
      * ``capability >= DeviceCapability(M[, m])`` -- lower bound
      * ``capability <= DeviceCapability(M[, m])`` -- upper bound
      * ``capability.major == N`` / ``capability.major in [N, ...]``

    Returns:
        "N.x" or "N.x-M.x" when explicit majors are found (these take
        precedence); "A.x-B.x" for a two-sided bound; "≥A.b" or "≤A.b" for a
        one-sided bound; "Any" when the method is absent or nothing is
        recognized.
    """
    method = find_method(node, "supports_compute_capability")
    if method is None:
        return "Any"

    min_cap: tuple[int, int] | None = None
    max_cap: tuple[int, int] | None = None
    major_list: list[int] = []

    for n in ast.walk(method):
        if not isinstance(n, ast.Compare):
            continue

        # Handle `capability >= DeviceCapability(...)` or `capability <= ...`
        for op, comp in zip(n.ops, n.comparators):
            if not (
                isinstance(comp, ast.Call)
                and isinstance(comp.func, ast.Name)
                and comp.func.id == "DeviceCapability"
                and comp.args
                and isinstance(comp.args[0], ast.Constant)
            ):
                continue
            major = comp.args[0].value
            minor = 0
            if len(comp.args) > 1 and isinstance(comp.args[1], ast.Constant):
                minor = comp.args[1].value
            if isinstance(op, ast.GtE):
                min_cap = (major, minor)
            elif isinstance(op, ast.LtE):
                max_cap = (major, minor)

        # Handle `capability.major == N` or `capability.major in [N, M]`
        if (
            isinstance(n.left, ast.Attribute)
            and n.left.attr == "major"
            and len(n.ops) == 1
            and len(n.comparators) == 1
        ):
            comp = n.comparators[0]
            if isinstance(n.ops[0], ast.Eq) and isinstance(comp, ast.Constant):
                major_list.append(comp.value)
            elif isinstance(n.ops[0], ast.In) and isinstance(comp, ast.List):
                major_list.extend(
                    e.value
                    for e in comp.elts
                    if isinstance(e, ast.Constant) and isinstance(e.value, int)
                )

    if major_list:
        major_list.sort()
        if len(major_list) == 1:
            return f"{major_list[0]}.x"
        return f"{major_list[0]}.x-{major_list[-1]}.x"

    if min_cap:
        if max_cap:
            return f"{min_cap[0]}.x-{max_cap[0]}.x"
        return f"≥{min_cap[0]}.{min_cap[1]}"

    if max_cap:
        # Upper bound only: report it instead of claiming "Any".
        return f"≤{max_cap[0]}.{max_cap[1]}"

    return "Any"


def parse_attention_types(node: ast.ClassDef) -> str:
    """Summarize which attention types ``supports_attn_type`` accepts.

    Recognizes membership tests such as
    ``attn_type in (AttentionType.DECODER, ...)`` where the container is a
    tuple, set or list literal. "Decoder" is reported when the method is not
    overridden or no type is recognized. Three or more recognized types are
    summarized as "All".
    """
    method = find_method(node, "supports_attn_type")
    if method is None:
        return "Decoder"

    type_map = {
        "DECODER": "Decoder",
        "ENCODER": "Encoder",
        "ENCODER_ONLY": "Encoder Only",
        "ENCODER_DECODER": "Enc-Dec",
    }
    types: set[str] = set()

    for n in ast.walk(method):
        # Handle `attn_type in (AttentionType.DECODER, ...)`; list literals
        # are accepted as well as tuples and sets.
        if not (
            isinstance(n, ast.Compare)
            and len(n.ops) == 1
            and isinstance(n.ops[0], ast.In)
            and len(n.comparators) == 1
            and isinstance(n.comparators[0], ast.Tuple | ast.Set | ast.List)
        ):
            continue

        for elt in n.comparators[0].elts:
            if isinstance(elt, ast.Attribute) and elt.attr in type_map:
                types.add(type_map[elt.attr])

    if not types:
        return "Decoder"
    # NOTE(review): type_map has four entries but three already count as
    # "All"; confirm this threshold is intentional.
    return "All" if len(types) >= 3 else ", ".join(sorted(types))


def parse_impl_bool_attr(
    tree: ast.AST,
    class_name: str,
    attr_name: str,
    default: bool = False,
    source_file: Path | None = None,
    _visited: set[str] | None = None,
) -> bool:
    """Parse a boolean class attribute from an impl class, following inheritance.

    Walks up the first-base inheritance chain within the same file and across
    files (by resolving imports) to find the attribute value.

    Args:
        tree: Parsed module expected to contain *class_name*.
        class_name: Impl class to inspect.
        attr_name: Class attribute holding a ``True``/``False`` literal.
        default: Value returned when the attribute is never found.
        source_file: Path of *tree*'s module, used to resolve relative imports.
        _visited: Internal cycle guard (class names already inspected).

    Returns:
        The first literal value found along the chain, else *default*.
    """
    if _visited is None:
        _visited = set()
    if class_name in _visited:
        return default
    _visited.add(class_name)

    class_node = find_class_in_ast(tree, class_name)
    if class_node is None:
        return default

    # Check directly on this class
    value = _find_bool_class_var(class_node, attr_name)
    if value is not None:
        return value

    parent_name = _get_parent_class_name(class_node)
    if not parent_name:
        return default

    # Try parent in same file first
    if find_class_in_ast(tree, parent_name) is not None:
        return parse_impl_bool_attr(
            tree, parent_name, attr_name, default, source_file, _visited
        )

    # Then try resolving a cross-file import of the parent
    parent_file = _resolve_import_to_file(tree, parent_name, source_file)
    if parent_file is None:
        return default
    try:
        parent_tree = ast.parse(parent_file.read_text())
    except (OSError, SyntaxError, ValueError) as e:
        # Best effort: fall back to the default, but report why. Only the
        # read/parse is guarded so bugs in the recursion are not hidden.
        print(f"  Warning: Could not parse {parent_file}: {e}", file=sys.stderr)
        return default
    return parse_impl_bool_attr(
        parent_tree, parent_name, attr_name, default, parent_file, _visited
    )


def analyze_backend(backend_name: str, class_path: str) -> dict[str, Any] | None:
    """Collect the feature summary for one registered backend class.

    Args:
        backend_name: Registry enum name of the backend (e.g. "FLASH_ATTN").
        class_path: Dotted path of the backend class.

    Returns:
        A dict of rendered per-feature values for the table, or None when the
        module file cannot be found or parsed, or the class is not defined in
        it.
    """
    file_path = get_file_from_class_path(class_path)
    if file_path is None:
        return None
    try:
        tree = ast.parse(file_path.read_text())
    except Exception as e:
        print(f"  Warning: Could not parse {file_path}: {e}", file=sys.stderr)
        return None

    backend_cls = find_class_in_ast(tree, class_path.rsplit(".", 1)[1])
    if backend_cls is None:
        return None

    # A backend counts as MLA if it derives from a known MLA base class or if
    # its class path / registry name mentions MLA.
    known_mla_bases = {"MLACommonBackend", "FlashMLABackend", "FlashMLASparseBackend"}
    named_as_mla = (
        _get_parent_class_name(backend_cls) in known_mla_bases
        or ".mla." in class_path.lower()
        or "_mla" in backend_name.lower()
    )

    # The compute-capability column is reported as "N/A" for CPU/ROCm backends.
    if backend_name.startswith(("CPU_", "ROCM_")):
        compute_cap = "N/A"
    else:
        compute_cap = parse_compute_capability(backend_cls)

    # DCP support comes from `can_return_lse_for_decode` on the impl class
    # that get_impl_cls() returns by name.
    impl_class_name: str | None = None
    get_impl_cls = find_method(backend_cls, "get_impl_cls")
    if get_impl_cls is not None:
        for stmt in ast.walk(get_impl_cls):
            if isinstance(stmt, ast.Return) and isinstance(stmt.value, ast.Name):
                impl_class_name = stmt.value.id
                break

    supports_dcp = (
        parse_impl_bool_attr(
            tree, impl_class_name, "can_return_lse_for_decode", False, file_path
        )
        if impl_class_name
        else False
    )

    return {
        "name": backend_name,
        "dtypes": parse_supported_dtypes(backend_cls),
        "kv_cache_dtypes": parse_kv_cache_dtypes(backend_cls),
        "block_sizes": parse_block_sizes(backend_cls),
        "head_sizes": parse_head_sizes(backend_cls),
        "attn_types": parse_attention_types(backend_cls),
        "compute_capability": compute_cap,
        "is_mla": named_as_mla or check_method_overrides(backend_cls, "is_mla"),
        "supports_sink": check_method_overrides(backend_cls, "supports_sink"),
        "is_sparse": check_method_overrides(backend_cls, "is_sparse"),
        "supports_mm_prefix": check_method_overrides(
            backend_cls, "supports_mm_prefix"
        ),
        "supports_dcp": supports_dcp,
    }


# ---------------------------------------------------------------------------
# Special backend variant parsers (FA2/FA3/FA4, FlashInfer TRTLLM, MLA prefill)
# ---------------------------------------------------------------------------


def _parse_fa4_supported_caps() -> str | None:
    """Parse flash_attn_interface.py for FA4 supported compute capabilities.

    Looks inside ``_is_fa4_supported()`` for a ``cc not in [9, 10, 11]``
    test, where the container is a list, tuple or set literal of ints. The
    values are rendered as a range such as "9.x-11.x", or as "10.x" when there
    is a single value, the same format parse_compute_capability uses for
    explicit majors.

    Returns:
        The rendered range, or None if the file is missing or unparsable, or
        contains no such test.
    """
    fa_interface_file = (
        REPO_ROOT / "vllm" / "vllm_flash_attn" / "flash_attn_interface.py"
    )
    if not fa_interface_file.exists():
        return None

    try:
        tree = ast.parse(fa_interface_file.read_text())
    except Exception:
        # Best effort: this file is only a fallback source for the FA4 CC.
        return None

    for node in ast.walk(tree):
        if not isinstance(node, ast.FunctionDef) or node.name != "_is_fa4_supported":
            continue
        for n in ast.walk(node):
            if not (
                isinstance(n, ast.Compare)
                and len(n.ops) == 1
                and isinstance(n.ops[0], ast.NotIn)
                and isinstance(n.comparators[0], ast.List | ast.Tuple | ast.Set)
            ):
                continue
            caps = sorted(
                e.value
                for e in n.comparators[0].elts
                if isinstance(e, ast.Constant) and isinstance(e.value, int)
            )
            if not caps:
                continue
            if len(caps) == 1:
                # Avoid the degenerate "10.x-10.x" rendering.
                return f"{caps[0]}.x"
            return f"{caps[0]}.x-{caps[-1]}.x"

    return None


def parse_flash_attn_features() -> dict[str, dict[str, Any]]:
    """Parse fa_utils.py to detect FA2 vs FA3 vs FA4 feature differences.

    Returns a dict with 'fa2', 'fa3', and 'fa4' keys containing their
    respective feature overrides: compute capability (FA3/FA4 only), FP8
    KV-cache support (``supports_fp8``), and sink support. Returns an empty
    dict when fa_utils.py is missing or cannot be parsed.
    """
    if not FA_UTILS_FILE.exists():
        return {}

    try:
        tree = ast.parse(FA_UTILS_FILE.read_text())
    except Exception:
        return {}

    def _compares_fa_version(func: ast.FunctionDef) -> bool:
        # True if the function contains a comparison whose left side is a
        # `get_flash_attn_version()` call, e.g. `get_flash_attn_version() == 3`.
        return any(
            isinstance(n, ast.Compare)
            and isinstance(n.left, ast.Call)
            and isinstance(n.left.func, ast.Name)
            and n.left.func.id == "get_flash_attn_version"
            for n in ast.walk(func)
        )

    # Analyze the functions to determine FA3-specific features
    fa3_supports_fp8 = False
    fa3_supports_sinks = False
    fa3_compute_cap: str | None = None
    fa4_compute_cap: str | None = None

    for node in ast.walk(tree):
        if not isinstance(node, ast.FunctionDef):
            continue

        # flash_attn_supports_fp8 / flash_attn_supports_sinks gate on the FA
        # version; any such version comparison is read as "FA3 supports it".
        if node.name == "flash_attn_supports_fp8" and _compares_fa_version(node):
            fa3_supports_fp8 = True
        if node.name == "flash_attn_supports_sinks" and _compares_fa_version(node):
            fa3_supports_sinks = True

        # Check get_flash_attn_version for FA3/FA4 compute capability
        if node.name != "get_flash_attn_version":
            continue
        for n in ast.walk(node):
            # Handle IfExp (ternary) with `device_capability.major == 9`.
            # A ternary match overwrites any earlier FA3 value.
            if isinstance(n, ast.IfExp) and isinstance(n.test, ast.BoolOp):
                for val in n.test.values:
                    if (
                        isinstance(val, ast.Compare)
                        and isinstance(val.left, ast.Attribute)
                        and val.left.attr == "major"
                        and val.comparators
                        and isinstance(val.comparators[0], ast.Constant)
                    ):
                        fa3_compute_cap = f"{val.comparators[0].value}.x"
                        break

            # Handle If statements for FA3/FA4 detection (first match wins)
            # e.g. `if device_capability.major == 9` -> FA3
            #      `elif device_capability.major >= 10` -> FA4
            if isinstance(n, ast.If):
                test = n.test
                comparisons = (
                    [v for v in test.values if isinstance(v, ast.Compare)]
                    if isinstance(test, ast.BoolOp)
                    else [test]
                    if isinstance(test, ast.Compare)
                    else []
                )
                for comp in comparisons:
                    if not (
                        isinstance(comp.left, ast.Attribute)
                        and comp.left.attr == "major"
                        and comp.comparators
                        and isinstance(comp.comparators[0], ast.Constant)
                        and isinstance(comp.comparators[0].value, int)
                    ):
                        continue
                    op = comp.ops[0]
                    major = comp.comparators[0].value
                    if isinstance(op, ast.Eq) and fa3_compute_cap is None:
                        fa3_compute_cap = f"{major}.x"
                    elif isinstance(op, ast.GtE) and fa4_compute_cap is None:
                        fa4_compute_cap = f"≥{major}.0"

    # Fallback: try to parse FA4 compute caps from flash_attn_interface.py
    if fa4_compute_cap is None:
        fa4_compute_cap = _parse_fa4_supported_caps()

    return {
        "fa2": {
            "supports_fp8": False,
            "supports_sink": False,
        },
        "fa3": {
            "compute_capability": fa3_compute_cap,
            "supports_fp8": fa3_supports_fp8,
            "supports_sink": fa3_supports_sinks,
        },
        "fa4": {
            "compute_capability": fa4_compute_cap,
            "supports_fp8": False,
            "supports_sink": False,
        },
    }


def parse_flashinfer_trtllm_features() -> dict[str, dict[str, Any]]:
    """Parse flashinfer.py to detect TRTLLM-specific features.

    FLASHINFER uses TRTLLM attention on SM100 (Blackwell), which has different
    capabilities (e.g., sink support) than native FlashInfer on earlier GPUs.
    The TRTLLM compute capability is read from the
    ``is_device_capability_family(N)`` check in ``supports_trtllm_attention``.

    Returns:
        ``{"native": {...}, "trtllm": {...}}`` feature overrides, or an empty
        dict if the utils file is missing or unparsable, or has no such check.
    """
    if not FLASHINFER_UTILS_FILE.exists():
        return {}

    try:
        tree = ast.parse(FLASHINFER_UTILS_FILE.read_text())
    except Exception:
        return {}

    trtllm_compute_cap = _find_cc_in_function(tree, "supports_trtllm_attention")
    if not trtllm_compute_cap:
        return {}

    return {
        "native": {
            # Native FlashInfer: everything except SM100
            "supports_sink": False,
        },
        "trtllm": {
            # TRTLLM pathway on Blackwell
            "compute_capability": trtllm_compute_cap,
            "supports_sink": True,
        },
    }


def parse_mla_prefill_backends() -> list[dict[str, Any]]:
    """Parse MLA prefill backend options from mla_attention.py.

    MLA uses different backends for prefill vs decode. The decode backends are
    registered in the registry, but prefill backends are selected at runtime
    based on conditions in MLACommonImpl.__init__.

    Each optional backend is listed only when its ``use_*`` gate function in
    mla_attention.py contains an ``is_device_capability_family(N)`` check. The
    FlashAttention fallback is always appended last.

    Returns a list of prefill backend info dicts with their requirements, or
    an empty list if mla_attention.py is missing or cannot be parsed.
    """
    if not MLA_ATTENTION_FILE.exists():
        return []

    try:
        tree = ast.parse(MLA_ATTENTION_FILE.read_text())
    except Exception:
        return []

    # Find compute capability requirements by parsing use_* functions
    trtllm_cc = _find_cc_in_function(tree, "use_trtllm_ragged_deepseek_prefill")
    flashinfer_cc = _find_cc_in_function(tree, "use_flashinfer_prefill")
    cudnn_cc = _find_cc_in_function(tree, "use_cudnn_prefill")

    # Build prefill backend list based on what we found
    # Order matches the priority in MLACommonImpl.__init__
    prefill_backends: list[dict[str, Any]] = []

    # TRT-LLM Ragged (highest priority if available)
    if trtllm_cc:
        prefill_backends.append(
            {
                "name": "TRT-LLM Ragged‡",
                "description": "TensorRT-LLM ragged attention",
                "compute_capability": trtllm_cc,
                "enable": "Default on SM100",
                "disable": "`-ac.use_trtllm_ragged_deepseek_prefill=0`",
                "notes": "DeepSeek R1 dims only",
            }
        )

    # FlashInfer prefill
    if flashinfer_cc:
        prefill_backends.append(
            {
                "name": "FlashInfer",
                "description": "FlashInfer CUTLASS backend",
                "compute_capability": flashinfer_cc,
                "enable": "`-ac.disable_flashinfer_prefill=0`",
                "disable": "`-ac.disable_flashinfer_prefill=1`",
                "notes": "DeepSeek R1 dims only",
            }
        )

    # cuDNN prefill
    if cudnn_cc:
        prefill_backends.append(
            {
                "name": "cuDNN",
                "description": "cuDNN-based attention",
                "compute_capability": cudnn_cc,
                "enable": "`-ac.use_cudnn_prefill=1`",
                "disable": "`-ac.use_cudnn_prefill=0`",
                "notes": "",
            }
        )

    # FlashAttention is always available as fallback
    prefill_backends.append(
        {
            "name": "FlashAttention",
            "description": "FlashAttention varlen (FA2/FA3)",
            "compute_capability": "Any",
            "enable": "Default fallback",
            "disable": "Use other backends",
            "notes": "FA3 on SM90, FA2 otherwise",
        }
    )

    return prefill_backends


# ---------------------------------------------------------------------------
# Backend variant expansion (FA2/FA3/FA4, FlashInfer native/TRTLLM)
# ---------------------------------------------------------------------------


842
843
844
845
def _expand_flash_attn_variants(
    all_backends: list[dict[str, Any]],
    fa_features: dict[str, dict[str, Any]],
) -> list[dict[str, Any]]:
    """Expand FLASH_ATTN into FA2, FA3, and FA4 variants.

    The single ``FLASH_ATTN`` row is replaced by one row per FlashAttention
    version. All of them share the ``FLASH_ATTN`` sort key so they render
    next to each other:

    * ``FA2*`` keeps the base backend's compute capability.
    * ``FA3*`` uses ``fa_features["fa3"]["compute_capability"]`` when it is
      non-empty. When ``supports_fp8`` is set, it also appends whichever
      fp8 KV-cache dtypes are not already listed.
    * ``FA4*`` is emitted only when ``fa_features`` has an ``"fa4"`` entry.

    Every other backend is passed through as a shallow copy, with default
    ``_sort_key`` / ``_sort_order`` / ``version`` fields filled in so that
    all rows can be sorted and rendered uniformly.

    Args:
        all_backends: Backend info dicts. They are not modified.
        fa_features: Per-version feature dicts keyed by ``"fa2"``, ``"fa3"``
            and optionally ``"fa4"``. Each must provide ``supports_sink``.

    Returns:
        A new list in the original order, with the FLASH_ATTN entry replaced
        by its version rows.
    """
    fp8_kv_dtypes = ("fp8", "fp8_e4m3", "fp8_e5m2")
    expanded: list[dict[str, Any]] = []
    for backend in all_backends:
        if backend["name"] != "FLASH_ATTN":
            # Copy so filling in defaults never mutates the caller's dicts.
            other = backend.copy()
            other.setdefault("_sort_key", backend["name"])
            other.setdefault("_sort_order", 0)
            other.setdefault("version", "")
            expanded.append(other)
            continue

        # FA2 entry (keeps base backend's compute_capability)
        fa2 = {
            **backend,
            "version": "FA2*",
            "_sort_key": "FLASH_ATTN",
            "_sort_order": 0,
            "supports_sink": fa_features["fa2"]["supports_sink"],
        }

        # FA3 entry (uses parsed compute_capability from fa_utils if present)
        fa3_info = fa_features["fa3"]
        fa3 = {
            **backend,
            "version": "FA3*",
            "_sort_key": "FLASH_ATTN",
            "_sort_order": 1,
            "supports_sink": fa3_info["supports_sink"],
        }
        if fa3_info.get("compute_capability"):
            fa3["compute_capability"] = fa3_info["compute_capability"]
        if fa3_info.get("supports_fp8"):
            # Split on "," and strip each item so spacing differences or an
            # empty dtype string do not produce malformed/empty entries.
            base_dtypes = [
                d.strip() for d in backend["kv_cache_dtypes"].split(",") if d.strip()
            ]
            new_dtypes = [d for d in fp8_kv_dtypes if d not in base_dtypes]
            fa3["kv_cache_dtypes"] = ", ".join(base_dtypes + new_dtypes)

        expanded.append(fa2)
        expanded.append(fa3)

        # FA4 entry only when FA4 features were parsed
        if "fa4" in fa_features:
            fa4_info = fa_features["fa4"]
            fa4 = {
                **backend,
                "version": "FA4*",
                "_sort_key": "FLASH_ATTN",
                "_sort_order": 2,
                "supports_sink": fa4_info["supports_sink"],
            }
            if fa4_info.get("compute_capability"):
                fa4["compute_capability"] = fa4_info["compute_capability"]
            expanded.append(fa4)

    return expanded


def _expand_flashinfer_variants(
    all_backends: list[dict[str, Any]],
    fi_features: dict[str, dict[str, Any]],
) -> list[dict[str, Any]]:
    """Expand FLASHINFER into native and TRTLLM variants.

    The ``FLASHINFER`` row becomes two rows sharing the ``FLASHINFER`` sort
    key:

    * ``Native†`` spans from the backend's minimum compute capability up to
      9.x, with ``supports_sink`` from ``fi_features["native"]``.
    * ``TRTLLM†`` takes both ``compute_capability`` and ``supports_sink``
      from ``fi_features["trtllm"]``.

    Other backends are passed through unchanged (the same dict objects).

    Args:
        all_backends: Backend info dicts. The FLASHINFER entry's
            ``compute_capability`` is expected to look like ``"7.x-12.x"``.
        fi_features: Feature dicts keyed by ``"native"`` and ``"trtllm"``.

    Returns:
        A new list with the FLASHINFER entry replaced by its two variants.
    """
    expanded: list[dict[str, Any]] = []
    for backend in all_backends:
        if backend["name"] != "FLASHINFER":
            expanded.append(backend)
            continue

        # The minimum CC is the lower bound of a range such as "7.x-12.x".
        # str.split() always yields at least one element, so validate the
        # value itself. Fall back to "7" when it is empty or non-numeric
        # (e.g. "" or "Any").
        lower = backend["compute_capability"].split("-")[0].strip()
        min_cc = lower.removesuffix(".x")
        if not min_cc[:1].isdigit():
            min_cc = "7"
        # Only a bare major version gets ".x" appended ("7" -> "7.x"). A value
        # that already carries a minor part ("7.5") is used as-is.
        native_cc = f"{min_cc}-9.x" if "." in min_cc else f"{min_cc}.x-9.x"

        # Native entry (pre-Blackwell GPUs)
        native = {
            **backend,
            "version": "Native†",
            "_sort_key": "FLASHINFER",
            "_sort_order": 0,
            "supports_sink": fi_features["native"]["supports_sink"],
            "compute_capability": native_cc,
        }

        # TRTLLM entry
        trtllm = {
            **backend,
            "version": "TRTLLM†",
            "_sort_key": "FLASHINFER",
            "_sort_order": 1,
            "compute_capability": fi_features["trtllm"]["compute_capability"],
            "supports_sink": fi_features["trtllm"]["supports_sink"],
        }

        expanded.append(native)
        expanded.append(trtllm)
    return expanded
930
931


932
933
934
# ---------------------------------------------------------------------------
# CUDA priority list parsing
# ---------------------------------------------------------------------------
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990


def parse_cuda_priority_lists() -> dict[str, list[str]]:
    """Parse priority lists from cuda.py using AST.

    The structure of _get_backend_priorities is:
        if use_mla:
            if device_capability.major == 10:
                return [MLA list for SM100]
            else:
                return [MLA list for default]
        else:
            if device_capability.major == 10:
                return [Standard list for SM100]
            else:
                return [Standard list for default]

    Returns:
        A dict containing whichever of ``mla_sm100``, ``mla_default``,
        ``standard_sm100`` and ``standard_default`` could be extracted. Each
        maps to backend names in priority order. Returns ``{}`` when cuda.py
        is missing, cannot be decoded, or is not valid Python.
    """
    try:
        # Decode explicitly as UTF-8 so parsing does not depend on the
        # platform's locale encoding.
        source = CUDA_PLATFORM_FILE.read_text(encoding="utf-8")
        tree = ast.parse(source)
    except (OSError, SyntaxError, ValueError):
        # Best effort: a missing file (FileNotFoundError is an OSError), an
        # undecodable file (UnicodeDecodeError is a ValueError) or an
        # unparsable one yields no priority data instead of aborting.
        return {}

    priorities: dict[str, list[str]] = {}

    # Find the _get_backend_priorities function
    for node in ast.walk(tree):
        if not isinstance(node, ast.FunctionDef):
            continue
        if node.name != "_get_backend_priorities":
            continue

        # Process the function body directly
        for stmt in node.body:
            if not isinstance(stmt, ast.If):
                continue

            # Check if this is the "if use_mla:" branch
            is_mla_branch = (
                isinstance(stmt.test, ast.Name) and stmt.test.id == "use_mla"
            )

            if is_mla_branch:
                _extract_priorities(stmt.body, priorities, "mla")
                if stmt.orelse:
                    _extract_priorities(stmt.orelse, priorities, "standard")
            else:
                _extract_priorities([stmt], priorities, "standard")

    return priorities


def _get_backends_from_return(stmts: list) -> list[str]:
    """Extract backend names from return statements in a list of statements.

    Returns the attribute names (``X.NAME`` -> ``"NAME"``) from the first
    top-level ``return [...]`` in *stmts*, or ``[]`` if there is none.
    Elements that are not attributes are skipped. The exception is starred
    names such as ``*sparse_backends``, which are expanded from list
    assignments found in the same statement list.

    Assignments are collected from plain (``x = [...]``) and annotated
    (``x: list[...] = [...]``) statements, including inside nested
    ``if``/``elif``/``else`` branches. Branches are visited body-first. So
    when a variable is conditionally assigned, the value from the final
    ``else`` branch is used as the representative default. Starred names
    with no matching assignment are dropped.
    """
    var_assigns: dict[str, list[str]] = {}
    _collect_list_assigns(stmts, var_assigns)

    for stmt in stmts:
        if not (isinstance(stmt, ast.Return) and isinstance(stmt.value, ast.List)):
            continue
        backends: list[str] = []
        for elt in stmt.value.elts:
            if isinstance(elt, ast.Attribute):
                backends.append(elt.attr)
            elif (
                isinstance(elt, ast.Starred)
                and isinstance(elt.value, ast.Name)
                and elt.value.id in var_assigns
            ):
                backends.extend(var_assigns[elt.value.id])
        return backends
    return []


def _collect_list_assigns(stmts: list, out: dict[str, list[str]]) -> None:
    """Record ``name = [X.A, X.B, ...]`` assignments from *stmts* into *out*.

    Later assignments overwrite earlier ones. ``if`` statements are walked
    recursively (body first, then orelse), so the last branch of an
    ``if``/``elif``/``else`` chain wins.
    """
    for stmt in stmts:
        if isinstance(stmt, ast.If):
            _collect_list_assigns(stmt.body, out)
            _collect_list_assigns(stmt.orelse, out)
            continue
        if isinstance(stmt, ast.Assign):
            targets = stmt.targets
        elif isinstance(stmt, ast.AnnAssign):
            targets = [stmt.target]
        else:
            continue
        # AnnAssign without a value has stmt.value None and is skipped here.
        if not isinstance(stmt.value, ast.List):
            continue
        names = [e.attr for e in stmt.value.elts if isinstance(e, ast.Attribute)]
        for target in targets:
            if isinstance(target, ast.Name):
                out[target.id] = list(names)


def _is_sm100_check(test: ast.expr) -> bool:
    """Return True iff *test* is exactly ``<expr>.major == 10``.

    The comparison must be a single ``==`` whose left side is an attribute
    named ``major`` and whose right side is a constant equal to 10.
    """
    if not isinstance(test, ast.Compare):
        return False
    lhs = test.left
    if not isinstance(lhs, ast.Attribute) or lhs.attr != "major":
        return False
    if len(test.ops) != 1 or not isinstance(test.ops[0], ast.Eq):
        return False
    if len(test.comparators) != 1:
        return False
    rhs = test.comparators[0]
    if not isinstance(rhs, ast.Constant):
        return False
    return rhs.value == 10


def _extract_priorities(body: list, priorities: dict[str, list[str]], prefix: str):
    """Extract priority lists from if/else statement body.

    Writes into *priorities* (in place) under ``f"{prefix}_sm100"`` and
    ``f"{prefix}_default"``:

    * For an ``if`` whose test is ``<x>.major == 10``, the ``if`` branch's
      list goes to ``_sm100`` and the ``else`` branch's list to ``_default``.
      Any other ``if`` test is treated as the inverse check (``if`` ->
      ``_default``, ``else`` -> ``_sm100``). Branches yielding an empty list
      are not recorded.
    * A ``return [...]`` directly in *body* is recorded under ``_default``.
    """
    for idx, stmt in enumerate(body):
        if isinstance(stmt, ast.If):
            is_sm100 = _is_sm100_check(stmt.test)
            if_key = f"{prefix}_sm100" if is_sm100 else f"{prefix}_default"
            else_key = f"{prefix}_default" if is_sm100 else f"{prefix}_sm100"

            if backends := _get_backends_from_return(stmt.body):
                priorities[if_key] = backends
            if backends := _get_backends_from_return(stmt.orelse):
                priorities[else_key] = backends

        elif isinstance(stmt, ast.Return) and isinstance(stmt.value, ast.List):
            # Use the shared helper, as the branch case does, so that starred
            # entries like ``*sparse_backends`` are expanded from the list
            # assignments preceding this return rather than being dropped.
            # Earlier returns are excluded so the helper resolves this one.
            preceding = [s for s in body[:idx] if not isinstance(s, ast.Return)]
            priorities[f"{prefix}_default"] = _get_backends_from_return(
                [*preceding, stmt]
            )


1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
# ---------------------------------------------------------------------------
# Data-driven table rendering
#
# Each column is a (header, formatter) pair. The formatter takes a backend
# info dict and returns the cell string. Tables are assembled by selecting
# which columns to include, then calling _render_table().
# ---------------------------------------------------------------------------

# Column type alias for readability: (header text, cell formatter).
TableColumn = tuple[str, Callable[[dict[str, Any]], str]]

# Shared column definitions -- order here matches the output table order.
# Every formatter except Version indexes the info dict directly, so a backend
# dict missing one of these keys raises KeyError when a row is rendered.
# add_literal_quotes() and bool_to_emoji() are defined further down; the
# lambdas only resolve those names when called, so the forward reference is
# safe at import time.
_COL_BACKEND: TableColumn = ("Backend", lambda b: f"`{b['name']}`")
# Rows without a variant label render an empty Version cell.
_COL_VERSION: TableColumn = ("Version", lambda b: b.get("version", ""))
_COL_DTYPES: TableColumn = ("Dtypes", lambda b: b["dtypes"])
_COL_KV_DTYPES: TableColumn = (
    "KV Dtypes",
    lambda b: add_literal_quotes(b["kv_cache_dtypes"]),
)
_COL_BLOCK_SIZES: TableColumn = ("Block Sizes", lambda b: b["block_sizes"])
_COL_HEAD_SIZES: TableColumn = ("Head Sizes", lambda b: b["head_sizes"])
# Boolean feature flags are rendered as ✅ / ❌ via bool_to_emoji().
_COL_SINK: TableColumn = ("Sink", lambda b: bool_to_emoji(b["supports_sink"]))
# Only included in MLA tables (see _build_columns).
_COL_SPARSE: TableColumn = ("Sparse", lambda b: bool_to_emoji(b["is_sparse"]))
_COL_MM_PREFIX: TableColumn = (
    "MM Prefix",
    lambda b: bool_to_emoji(b["supports_mm_prefix"]),
)
_COL_DCP: TableColumn = ("DCP", lambda b: bool_to_emoji(b["supports_dcp"]))
_COL_ATTN_TYPES: TableColumn = ("Attention Types", lambda b: b["attn_types"])
_COL_COMPUTE_CAP: TableColumn = ("Compute Cap.", lambda b: b["compute_capability"])


def add_literal_quotes(value: str) -> str:
    """Wrap each comma-separated item of *value* in backticks.

    Items are stripped of surrounding whitespace and re-joined with ", ".
    Empty items (from an empty string or stray commas) are dropped so they
    do not render as a bare "``" in the markdown output.

    >>> add_literal_quotes("auto, fp8")
    '`auto`, `fp8`'
    """
    items = (item.strip() for item in value.split(","))
    return ", ".join(f"`{item}`" for item in items if item)


def bool_to_emoji(value: bool) -> str:
    """Render a truthy value as "✅" and a falsy one as "❌"."""
    # Index a (false, true) pair with the value's truthiness.
    return ("❌", "✅")[bool(value)]


def _build_columns(is_mla: bool, has_versions: bool) -> list[TableColumn]:
    """Select the columns of a backend feature table, in display order.

    ``Version`` is placed right after ``Backend`` only when *has_versions*
    is truthy (some row is a version variant). ``Sparse`` is placed right
    after ``Sink`` only for MLA tables (*is_mla* truthy).
    """
    version_cols = [_COL_VERSION] if has_versions else []
    sparse_cols = [_COL_SPARSE] if is_mla else []
    return [
        _COL_BACKEND,
        *version_cols,
        _COL_DTYPES,
        _COL_KV_DTYPES,
        _COL_BLOCK_SIZES,
        _COL_HEAD_SIZES,
        _COL_SINK,
        *sparse_cols,
        _COL_MM_PREFIX,
        _COL_DCP,
        _COL_ATTN_TYPES,
        _COL_COMPUTE_CAP,
    ]


def _sort_key(x: dict[str, Any]) -> tuple[str, int]:
    """Return ``(group, position)`` so variant rows sort next to each other.

    ``group`` is the row's ``_sort_key``, falling back to its ``name``.
    ``position`` is its ``_sort_order``, defaulting to 0.
    """
    name = x["name"]  # looked up unconditionally as the fallback group
    group = x.get("_sort_key", name)
    position = x.get("_sort_order", 0)
    return group, position


def _render_table(
    columns: list[TableColumn],
    backends: list[dict[str, Any]],
) -> list[str]:
    """Render *backends* as markdown table lines.

    Produces a header row, a dashed separator row (each dash run is two
    longer than its header), then one row per backend ordered by
    ``_sort_key``. Each cell comes from the column's formatter.
    """
    headers = [title for title, _ in columns]
    formatters = [fmt for _, fmt in columns]
    rendered = [
        "| " + " | ".join(headers) + " |",
        "|" + "|".join("-" * (len(title) + 2) for title in headers) + "|",
    ]
    rendered.extend(
        "| " + " | ".join(fmt(entry) for fmt in formatters) + " |"
        for entry in sorted(backends, key=_sort_key)
    )
    return rendered


def generate_markdown_table(
    backends: list[dict[str, Any]], title: str, is_mla_table: bool = False
) -> str:
    """Render a ``## title`` section containing a backend feature table.

    The Version column is shown only if at least one backend has a non-empty
    ``version``. With no backends, a short "No backends found." note is
    returned instead of a table.
    """
    if not backends:
        return f"## {title}\n\nNo backends found.\n"
    show_version = any(entry.get("version") for entry in backends)
    table = _render_table(_build_columns(is_mla_table, show_version), backends)
    return "\n".join([f"## {title}", "", *table, ""])


# ---------------------------------------------------------------------------
# Markdown section generators (usage, priority, legend, MLA)
# ---------------------------------------------------------------------------


1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
def generate_usage_section() -> str:
    """Return the static "Setting the Attention Backend" markdown section.

    The section documents the ``--attention-backend`` and
    ``--attention-config`` / ``-ac`` CLI options, the ``AttentionConfig``
    Python API, and how manual versus automatic backend selection behaves.
    This is fixed text, not parsed from the source tree. It must be updated
    by hand when those user-facing options change.
    """
    return """## Setting the Attention Backend

### Command Line

There are two ways to specify the backend from the command line:

**Option 1: Using `--attention-backend` (simple)**

```bash
vllm serve <model> --attention-backend FLASH_ATTN
```

**Option 2: Using `--attention-config.backend` / `-ac.backend` (structured config)**

```bash
# Dot notation
vllm serve <model> --attention-config.backend FLASH_ATTN
vllm serve <model> -ac.backend FLASH_ATTN

# JSON format
vllm serve <model> --attention-config '{"backend": "FLASH_ATTN"}'
vllm serve <model> -ac '{"backend": "FLASH_ATTN"}'
```

> **Note:** `--attention-backend` and `--attention-config.backend` are mutually
> exclusive. Use one or the other, not both.

### Python API

Use `AttentionConfig` with the `LLM` class:

```python
from vllm import LLM
from vllm.config import AttentionConfig
from vllm.v1.attention.backends.registry import AttentionBackendEnum

# Method 1: Using AttentionConfig with enum
llm = LLM(
    model="Qwen/Qwen3-0.6B",
    attention_config=AttentionConfig(backend=AttentionBackendEnum.FLASH_ATTN),
)

# Method 2: Using attention_backend parameter with string
llm = LLM(
    model="Qwen/Qwen3-0.6B",
    attention_backend="FLASH_ATTN",
)
```

## Backend Selection Behavior

### Manual Selection

When you explicitly set a backend via `--attention-backend` or `AttentionConfig`:

1. The backend is **validated** against your configuration (model dtype, head
   size, compute capability, etc.)
2. If the backend **doesn't support** your configuration, an error is raised
   with the specific reason
3. If valid, the backend is used

Example error when selecting an incompatible backend:

```text
ValueError: Selected backend FLASHMLA is not valid for this configuration.
Reason: ['compute capability not supported']
```

### Automatic Selection

When no backend is specified (the default):

1. vLLM iterates through backends in **priority order** (see tables below)
2. Each backend is validated against your configuration
3. The **first compatible backend** is selected
4. If no backend is compatible, an error is raised listing all backends and
   their incompatibility reasons
"""


def _priority_table(title: str, backends: list[str]) -> list[str]:
    """Render a bold caption plus a numbered (1-based) priority table.

    Returns markdown lines, ending with a blank line so that consecutive
    tables stay separated.
    """
    rows = [f"**{title}:**", "", "| Priority | Backend |", "|----------|---------|"]
    for rank, name in enumerate(backends, start=1):
        rows.append(f"| {rank} | `{name}` |")
    rows.append("")
    return rows


def generate_priority_section(priorities: dict[str, list[str]]) -> str:
    """Render the "Backend Priority (CUDA)" markdown section.

    For the standard and MLA groups, a priority table is emitted for each
    of the ``<group>_sm100`` and ``<group>_default`` keys present in
    *priorities*, labelled Blackwell and Ampere/Hopper respectively. Missing
    keys are silently omitted; the group headings are always emitted.
    """
    groups = (
        ("### Standard Attention (MHA, MQA, GQA)", "standard"),
        ("### MLA Attention (DeepSeek-style)", "mla"),
    )
    tiers = (
        ("sm100", "Blackwell (SM 10.x)"),
        ("default", "Ampere/Hopper (SM 8.x-9.x)"),
    )

    out = [
        "## Backend Priority (CUDA)",
        "",
        "When no backend is explicitly selected, vLLM chooses the first",
        "compatible backend from these priority-ordered lists.",
        "",
        "Priority is **1 = highest** (tried first).",
        "",
    ]
    for heading, group in groups:
        out += [heading, ""]
        for tier, label in tiers:
            key = f"{group}_{tier}"
            if key in priorities:
                out += _priority_table(label, priorities[key])

    out.append(
        "> **Note:** ROCm and CPU platforms have their own selection logic. "
        "See the platform-specific documentation for details."
    )
    out.append("")
    return "\n".join(out)


1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
def generate_legend() -> str:
    """Return the static "Legend" section describing the feature-table columns.

    This is fixed text. When a column is added to or removed from the
    feature tables (the ``_COL_*`` definitions), update the matching row
    here by hand.
    """
    return """## Legend

| Column | Description |
|--------|-------------|
| **Dtypes** | Supported model data types (fp16, bf16, fp32) |
| **KV Dtypes** | Supported KV cache data types (`auto`, `fp8`, `fp8_e4m3`, etc.) |
| **Block Sizes** | Supported KV cache block sizes (%N means multiples of N) |
| **Head Sizes** | Supported attention head sizes |
| **Sink** | Attention sink support (for StreamingLLM) |
| **Sparse** | Sparse attention support (MLA only) |
| **MM Prefix** | Multimodal prefix full attention support |
| **DCP** | Decode Context Parallelism support (`--decode-context-parallel-size`) |
| **Attention Types** | Supported attention patterns (Decoder, Encoder, Enc-Dec) |
| **Compute Cap.** | Required CUDA compute capability (N/A for non-CUDA backends) |

**Symbols:** ✅ = Supported, ❌ = Not supported
"""


1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
def generate_mla_section(
    prefill_backends: list[dict[str, Any]], decode_backends: list[dict[str, Any]]
) -> str:
    """Generate the complete MLA section with prefill and decode tables.

    Args:
        prefill_backends: Dicts with ``name``, ``description``,
            ``compute_capability``, ``enable``, ``disable`` and an optional
            ``notes`` key. They are rendered in the given order.
        decode_backends: MLA backend info dicts, rendered with the shared
            feature-table columns (including Sparse). The Version column is
            added only if some backend has a non-empty ``version``.

    Returns:
        The markdown for the section, ending with a newline.
    """
    lines = [
        "## MLA (Multi-head Latent Attention) Backends",
        "",
        "MLA uses separate backends for prefill and decode phases.",
        "",
        "### Prefill Backends",
        "",
        "The prefill backend is selected at runtime based on hardware and",
        "configuration.",
        "",
        "| Backend | Description | Compute Cap. | Enable | Disable | Notes |",
        "|---------|-------------|--------------|--------|---------|-------|",
    ]

    for backend in prefill_backends:
        cells = (
            backend["name"],
            backend["description"],
            backend["compute_capability"],
            backend["enable"],
            backend["disable"],
            backend.get("notes", ""),
        )
        lines.append("| " + " | ".join(str(cell) for cell in cells) + " |")

    lines.append("")
    # Only explain the ‡ marker when some prefill row actually carries it,
    # so the footnote never dangles if that row is absent.
    if any("‡" in str(b["name"]) for b in prefill_backends):
        lines.extend(
            [
                "> **‡** TRT-LLM Ragged is the default on Blackwell (SM100).",
                "> On other GPUs, FlashAttention is used as the default.",
                "",
            ]
        )
    lines.extend(["### Decode Backends", ""])

    if decode_backends:
        # Same column selection rule as generate_markdown_table.
        has_versions = any(b.get("version") for b in decode_backends)
        columns = _build_columns(is_mla=True, has_versions=has_versions)
        lines.extend(_render_table(columns, decode_backends))
    else:
        lines.append("No backends found.")

    lines.append("")
    return "\n".join(lines)


1370
1371
1372
# ---------------------------------------------------------------------------
# Top-level orchestration
# ---------------------------------------------------------------------------
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393


def generate_docs() -> str:
    """Assemble the complete attention-backend markdown document.

    Collects backend data from the registry and the various source parsers.
    It expands the FlashAttention / FlashInfer rows into their version
    variants, then concatenates the header, usage, priority, legend,
    standard-attention table (plus footnotes) and MLA sections.
    """
    # Gather raw inputs (same call order as before: registry first).
    registry = parse_registry()
    priorities = parse_cuda_priority_lists()
    fa_features = parse_flash_attn_features()
    fi_features = parse_flashinfer_trtllm_features()
    mla_prefill_backends = parse_mla_prefill_backends()

    backends = []
    for name, class_path in registry.items():
        if name in SKIP_BACKENDS:
            continue
        info = analyze_backend(name, class_path)
        if info:
            backends.append(info)

    # Replace FLASH_ATTN / FLASHINFER rows with their version variants.
    if fa_features:
        backends = _expand_flash_attn_variants(backends, fa_features)
    if fi_features:
        backends = _expand_flashinfer_variants(backends, fi_features)

    mla_backends = []
    standard_backends = []
    for entry in backends:
        if entry["is_mla"]:
            mla_backends.append(entry)
        else:
            standard_backends.append(entry)

    # Footnotes for the version markers, in table order (FLASHINFER sorts
    # before FLASH_ATTN).
    notes = []
    if fi_features:
        notes.append(
            "> **†** FlashInfer uses TRTLLM attention on Blackwell (SM100), which "
            "supports sinks. Disable via `--attention-config.use_trtllm_attention=0`."
        )
    if fa_features:
        notes.append(
            "> **\\*** Specify the FlashAttention version via "
            "`--attention-config.flash_attn_version=2`, `3`, or `4`. "
            "Default is FA4 on SM100+ (Blackwell), FA3 on SM90 (Hopper), "
            "FA2 otherwise."
        )

    script_path = "tools/pre_commit/generate_attention_backend_docs.py"
    sections = [
        "# Attention Backend Feature Support",
        "",
        f"This document is auto-generated by `{script_path}`.",
        "It shows the feature support for each registered attention backend",
        "based on the checks in `AttentionBackend.validate_configuration()`.",
        "",
        "**Do not edit this file manually.** Run the following command to",
        "regenerate it:",
        "",
        "```bash",
        f"python {script_path}",
        "```",
        "",
        generate_usage_section(),
        generate_priority_section(priorities),
        generate_legend(),
        generate_markdown_table(
            standard_backends,
            "Standard Attention (MHA, MQA, GQA) Backends",
            is_mla_table=False,
        ),
    ]
    if notes:
        sections.append("\n>\n".join(notes) + "\n")
    sections.append(generate_mla_section(mla_prefill_backends, mla_backends))

    return "\n".join(sections)


def main():
    """Command-line / pre-commit entry point for regenerating the docs.

    When pre-commit passes filenames and none of them is relevant according
    to ``is_relevant_file``, the script exits 0 without doing anything.

    With ``--check``, the output file is rewritten only if it is missing or
    stale. In that case the script exits with status 1 so pre-commit reports
    the change; otherwise it exits 0. Without ``--check``, the file is always
    (re)written.
    """
    parser = argparse.ArgumentParser(
        description="Generate attention backend documentation table"
    )
    parser.add_argument(
        "--output",
        "-o",
        type=str,
        default=str(REPO_ROOT / "docs" / "design" / "attention_backends.md"),
        help="Output file path (default: docs/design/attention_backends.md)",
    )
    parser.add_argument(
        "--check",
        action="store_true",
        help="Check if the documentation is up to date (for pre-commit)",
    )
    parser.add_argument(
        "files",
        nargs="*",
        help="Files to check (passed by pre-commit). If none are relevant, skip.",
    )
    args = parser.parse_args()

    if args.files and not any(is_relevant_file(f) for f in args.files):
        sys.exit(0)

    output_path = Path(args.output)
    new_content = generate_docs()

    # Read/write explicitly as UTF-8. The generated markdown contains
    # characters such as ✅ and ❌ that locale encodings like cp1252 cannot
    # encode, which would make write_text() fail on those systems.
    if args.check:
        needs_update = (
            not output_path.exists()
            or output_path.read_text(encoding="utf-8") != new_content
        )
        if needs_update:
            output_path.parent.mkdir(parents=True, exist_ok=True)
            output_path.write_text(new_content, encoding="utf-8")
            print(f"🔄 Regenerated: {output_path}")
            sys.exit(1)
        print(f"✅ Up to date: {output_path}")
        sys.exit(0)

    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(new_content, encoding="utf-8")
    print(f"Generated: {output_path}")


if __name__ == "__main__":
    main()