Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5e75a14a
Unverified
Commit
5e75a14a
authored
Feb 09, 2026
by
Michael Goin
Committed by
GitHub
Feb 09, 2026
Browse files
[Doc] Add DCP support to attention backend doc (#33936)
parent
e7e52781
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
769 additions
and
644 deletions
+769
-644
docs/design/attention_backends.md
docs/design/attention_backends.md
+26
-25
tools/pre_commit/generate_attention_backend_docs.py
tools/pre_commit/generate_attention_backend_docs.py
+743
-619
No files found.
docs/design/attention_backends.md
View file @
5e75a14a
...
@@ -152,6 +152,7 @@ Priority is **1 = highest** (tried first).
...
@@ -152,6 +152,7 @@ Priority is **1 = highest** (tried first).
|
**Sink**
| Attention sink support (for StreamingLLM) |
|
**Sink**
| Attention sink support (for StreamingLLM) |
|
**Sparse**
| Sparse attention support (MLA only) |
|
**Sparse**
| Sparse attention support (MLA only) |
|
**MM Prefix**
| Multimodal prefix full attention support |
|
**MM Prefix**
| Multimodal prefix full attention support |
|
**DCP**
| Decode Context Parallelism support (
`--decode-context-parallel-size`
) |
|
**Attention Types**
| Supported attention patterns (Decoder, Encoder, Enc-Dec) |
|
**Attention Types**
| Supported attention patterns (Decoder, Encoder, Enc-Dec) |
|
**Compute Cap.**
| Required CUDA compute capability (N/A for non-CUDA backends) |
|
**Compute Cap.**
| Required CUDA compute capability (N/A for non-CUDA backends) |
...
@@ -159,20 +160,20 @@ Priority is **1 = highest** (tried first).
...
@@ -159,20 +160,20 @@ Priority is **1 = highest** (tried first).
## Standard Attention (MHA, MQA, GQA) Backends
## Standard Attention (MHA, MQA, GQA) Backends
| Backend | Version | Dtypes | KV Dtypes | Block Sizes | Head Sizes | Sink | MM Prefix | Attention Types | Compute Cap. |
| Backend | Version | Dtypes | KV Dtypes | Block Sizes | Head Sizes | Sink | MM Prefix |
DCP |
Attention Types | Compute Cap. |
|---------|---------|--------|-----------|-------------|------------|------|-----------|-----------------|--------------|
|---------|---------|--------|-----------|-------------|------------|------|-----------|-----
|-----
------------|--------------|
|
`CPU_ATTN`
| | fp16, bf16, fp32 |
`auto`
| Any | 32, 64, 80, 96, 112, 128, 160, 192, 224, 256 | ❌ | ❌ | All | N/A |
|
`CPU_ATTN`
| | fp16, bf16, fp32 |
`auto`
| Any | 32, 64, 80, 96, 112, 128, 160, 192, 224, 256 | ❌ | ❌ |
❌ |
All | N/A |
|
`FLASHINFER`
| Native† | fp16, bf16 |
`auto`
,
`bfloat16`
,
`fp8`
,
`fp8_e4m3`
,
`fp8_e5m2`
| 16, 32, 64 | 64, 128, 256 | ❌ | ❌ | Decoder | 7.x-9.x |
|
`FLASHINFER`
| Native† | fp16, bf16 |
`auto`
,
`bfloat16`
,
`fp8`
,
`fp8_e4m3`
,
`fp8_e5m2`
| 16, 32, 64 | 64, 128, 256 | ❌ | ❌ |
✅ |
Decoder | 7.x-9.x |
|
`FLASHINFER`
| TRTLLM† | fp16, bf16 |
`auto`
,
`bfloat16`
,
`fp8`
,
`fp8_e4m3`
,
`fp8_e5m2`
| 16, 32, 64 | 64, 128, 256 | ✅ | ❌ | Decoder | 10.x |
|
`FLASHINFER`
| TRTLLM† | fp16, bf16 |
`auto`
,
`bfloat16`
,
`fp8`
,
`fp8_e4m3`
,
`fp8_e5m2`
| 16, 32, 64 | 64, 128, 256 | ✅ | ❌ |
✅ |
Decoder | 10.x |
|
`FLASH_ATTN`
| FA2
*
| fp16, bf16 |
`auto`
,
`bfloat16`
| %16 | Any | ❌ | ❌ | All | ≥8.0 |
|
`FLASH_ATTN`
| FA2
*
| fp16, bf16 |
`auto`
,
`bfloat16`
| %16 | Any | ❌ | ❌ |
✅ |
All | ≥8.0 |
|
`FLASH_ATTN`
| FA3
*
| fp16, bf16 |
`auto`
,
`bfloat16`
,
`fp8`
,
`fp8_e4m3`
,
`fp8_e5m2`
| %16 | Any | ✅ | ❌ | All | 9.x |
|
`FLASH_ATTN`
| FA3
*
| fp16, bf16 |
`auto`
,
`bfloat16`
,
`fp8`
,
`fp8_e4m3`
,
`fp8_e5m2`
| %16 | Any | ✅ | ❌ |
✅ |
All | 9.x |
|
`FLASH_ATTN_DIFFKV`
| | fp16, bf16 |
`auto`
| Any | Any | ❌ | ❌ | Decoder | Any |
|
`FLASH_ATTN_DIFFKV`
| | fp16, bf16 |
`auto`
| Any | Any | ❌ | ❌ |
✅ |
Decoder | Any |
|
`FLEX_ATTENTION`
| | fp16, bf16, fp32 |
`auto`
,
`bfloat16`
| Any | Any | ❌ | ✅ | Decoder, Encoder Only | Any |
|
`FLEX_ATTENTION`
| | fp16, bf16, fp32 |
`auto`
,
`bfloat16`
| Any | Any | ❌ | ✅ |
❌ |
Decoder, Encoder Only | Any |
|
`ROCM_AITER_FA`
| | fp16, bf16 |
`auto`
| 16, 32 | 64, 128, 256 | ❌ | ❌ | Decoder | N/A |
|
`ROCM_AITER_FA`
| | fp16, bf16 |
`auto`
| 16, 32 | 64, 128, 256 | ❌ | ❌ |
❌ |
Decoder | N/A |
|
`ROCM_AITER_UNIFIED_ATTN`
| | fp16, bf16 |
`auto`
| Any | Any | ❌ | ❌ | Decoder | N/A |
|
`ROCM_AITER_UNIFIED_ATTN`
| | fp16, bf16 |
`auto`
| Any | Any | ❌ | ❌ |
❌ |
Decoder | N/A |
|
`ROCM_ATTN`
| | fp16, bf16, fp32 |
`auto`
| 16, 32, 544 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | Decoder | N/A |
|
`ROCM_ATTN`
| | fp16, bf16, fp32 |
`auto`
| 16, 32, 544 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ |
❌ |
Decoder | N/A |
|
`TREE_ATTN`
| | fp16, bf16 |
`auto`
| %16 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | Decoder | Any |
|
`TREE_ATTN`
| | fp16, bf16 |
`auto`
| %16 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ |
❌ |
Decoder | Any |
|
`TRITON_ATTN`
| | fp16, bf16, fp32 |
`auto`
,
`bfloat16`
,
`fp8`
,
`fp8_e4m3`
,
`fp8_e5m2`
| %16 | Any | ✅ | ✅ | All | Any |
|
`TRITON_ATTN`
| | fp16, bf16, fp32 |
`auto`
,
`bfloat16`
,
`fp8`
,
`fp8_e4m3`
,
`fp8_e5m2`
| %16 | Any | ✅ | ✅ |
❌ |
All | Any |
> **†** FlashInfer uses TRTLLM attention on Blackwell (SM100), which supports sinks. Disable via `--attention-config.use_trtllm_attention=0`.
> **†** FlashInfer uses TRTLLM attention on Blackwell (SM100), which supports sinks. Disable via `--attention-config.use_trtllm_attention=0`.
>
>
...
@@ -199,14 +200,14 @@ configuration.
...
@@ -199,14 +200,14 @@ configuration.
### Decode Backends
### Decode Backends
| Backend | Dtypes | KV Dtypes | Block Sizes | Head Sizes | Sink | Sparse | MM Prefix | Attention Types | Compute Cap. |
| Backend | Dtypes | KV Dtypes | Block Sizes | Head Sizes | Sink | Sparse | MM Prefix |
DCP |
Attention Types | Compute Cap. |
|---------|--------|-----------|-------------|------------|------|--------|-----------|-----------------|--------------|
|---------|--------|-----------|-------------|------------|------|--------|-----------|-----
|-----
------------|--------------|
|
`CUTLASS_MLA`
| fp16, bf16 |
`auto`
,
`bfloat16`
,
`fp8`
,
`fp8_e4m3`
| 128 | Any | ❌ | ❌ | ❌ | Decoder | 10.x |
|
`CUTLASS_MLA`
| fp16, bf16 |
`auto`
,
`bfloat16`
,
`fp8`
,
`fp8_e4m3`
| 128 | Any | ❌ | ❌ | ❌ |
✅ |
Decoder | 10.x |
|
`FLASHINFER_MLA`
| fp16, bf16 |
`auto`
,
`bfloat16`
,
`fp8`
,
`fp8_e4m3`
| 32, 64 | Any | ❌ | ❌ | ❌ | Decoder | 10.x |
|
`FLASHINFER_MLA`
| fp16, bf16 |
`auto`
,
`bfloat16`
,
`fp8`
,
`fp8_e4m3`
| 32, 64 | Any | ❌ | ❌ | ❌ |
❌ |
Decoder | 10.x |
|
`FLASHMLA`
| fp16, bf16 |
`auto`
,
`bfloat16`
,
`fp8`
,
`fp8_e4m3`
| 64 | Any | ❌ | ❌ | ❌ | Decoder | 9.x-10.x |
|
`FLASHMLA`
| fp16, bf16 |
`auto`
,
`bfloat16`
,
`fp8`
,
`fp8_e4m3`
| 64 | Any | ❌ | ❌ | ❌ |
✅ |
Decoder | 9.x-10.x |
|
`FLASHMLA_SPARSE`
| bf16 |
`auto`
,
`bfloat16`
,
`fp8_ds_mla`
| 64 | 576 | ❌ | ✅ | ❌ | Decoder | 9.x-10.x |
|
`FLASHMLA_SPARSE`
| bf16 |
`auto`
,
`bfloat16`
,
`fp8_ds_mla`
| 64 | 576 | ❌ | ✅ | ❌ |
❌ |
Decoder | 9.x-10.x |
|
`FLASH_ATTN_MLA`
| fp16, bf16 |
`auto`
,
`bfloat16`
| %16 | Any | ❌ | ❌ | ❌ | Decoder | 9.x |
|
`FLASH_ATTN_MLA`
| fp16, bf16 |
`auto`
,
`bfloat16`
| %16 | Any | ❌ | ❌ | ❌ |
✅ |
Decoder | 9.x |
|
`ROCM_AITER_MLA`
| fp16, bf16 |
`auto`
| 1 | Any | ❌ | ❌ | ❌ | Decoder | N/A |
|
`ROCM_AITER_MLA`
| fp16, bf16 |
`auto`
| 1 | Any | ❌ | ❌ | ❌ |
❌ |
Decoder | N/A |
|
`ROCM_AITER_MLA_SPARSE`
| fp16, bf16 |
`auto`
| Any | 576 | ❌ | ❌ | ❌ | Decoder | N/A |
|
`ROCM_AITER_MLA_SPARSE`
| fp16, bf16 |
`auto`
| Any | 576 | ❌ | ❌ | ❌ |
❌ |
Decoder | N/A |
|
`ROCM_AITER_TRITON_MLA`
| fp16, bf16 |
`auto`
| Any | Any | ❌ | ❌ | ❌ | Decoder | N/A |
|
`ROCM_AITER_TRITON_MLA`
| fp16, bf16 |
`auto`
| Any | Any | ❌ | ❌ | ❌ |
❌ |
Decoder | N/A |
|
`TRITON_MLA`
| fp16, bf16 |
`auto`
,
`bfloat16`
| Any | Any | ❌ | ❌ | ❌ | Decoder | Any |
|
`TRITON_MLA`
| fp16, bf16 |
`auto`
,
`bfloat16`
| Any | Any | ❌ | ❌ | ❌ |
✅ |
Decoder | Any |
tools/pre_commit/generate_attention_backend_docs.py
View file @
5e75a14a
...
@@ -17,9 +17,14 @@ import argparse
...
@@ -17,9 +17,14 @@ import argparse
import
ast
import
ast
import
fnmatch
import
fnmatch
import
sys
import
sys
from
collections.abc
import
Callable
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Any
from
typing
import
Any
# ---------------------------------------------------------------------------
# Constants and file paths
# ---------------------------------------------------------------------------
REPO_ROOT
=
Path
(
__file__
).
parent
.
parent
.
parent
REPO_ROOT
=
Path
(
__file__
).
parent
.
parent
.
parent
RELEVANT_PATTERNS
=
[
RELEVANT_PATTERNS
=
[
...
@@ -32,6 +37,18 @@ RELEVANT_PATTERNS = [
...
@@ -32,6 +37,18 @@ RELEVANT_PATTERNS = [
"docs/design/attention_backends.md"
,
"docs/design/attention_backends.md"
,
]
]
BACKENDS_DIR
=
REPO_ROOT
/
"vllm"
/
"v1"
/
"attention"
/
"backends"
REGISTRY_FILE
=
BACKENDS_DIR
/
"registry.py"
CUDA_PLATFORM_FILE
=
REPO_ROOT
/
"vllm"
/
"platforms"
/
"cuda.py"
FA_UTILS_FILE
=
BACKENDS_DIR
/
"fa_utils.py"
FLASHINFER_UTILS_FILE
=
REPO_ROOT
/
"vllm"
/
"utils"
/
"flashinfer.py"
MLA_ATTENTION_FILE
=
(
REPO_ROOT
/
"vllm"
/
"model_executor"
/
"layers"
/
"attention"
/
"mla_attention.py"
)
# Backends to skip during doc generation
SKIP_BACKENDS
=
{
"CUSTOM"
,
"TORCH_SDPA"
}
def
is_relevant_file
(
filepath
:
str
)
->
bool
:
def
is_relevant_file
(
filepath
:
str
)
->
bool
:
"""Check if a file matches any of the relevant patterns."""
"""Check if a file matches any of the relevant patterns."""
...
@@ -46,351 +63,234 @@ def is_relevant_file(filepath: str) -> bool:
...
@@ -46,351 +63,234 @@ def is_relevant_file(filepath: str) -> bool:
return
any
(
fnmatch
.
fnmatch
(
path_str
,
pattern
)
for
pattern
in
RELEVANT_PATTERNS
)
return
any
(
fnmatch
.
fnmatch
(
path_str
,
pattern
)
for
pattern
in
RELEVANT_PATTERNS
)
BACKENDS_DIR
=
REPO_ROOT
/
"vllm"
/
"v1"
/
"attention"
/
"backends"
# ---------------------------------------------------------------------------
REGISTRY_FILE
=
BACKENDS_DIR
/
"registry.py"
# AST utility helpers
CUDA_PLATFORM_FILE
=
REPO_ROOT
/
"vllm"
/
"platforms"
/
"cuda.py"
# ---------------------------------------------------------------------------
FA_UTILS_FILE
=
BACKENDS_DIR
/
"fa_utils.py"
FLASHINFER_UTILS_FILE
=
REPO_ROOT
/
"vllm"
/
"utils"
/
"flashinfer.py"
MLA_ATTENTION_FILE
=
(
REPO_ROOT
/
"vllm"
/
"model_executor"
/
"layers"
/
"attention"
/
"mla_attention.py"
)
def
parse_registry
()
->
dict
[
str
,
str
]:
def
find_class_in_ast
(
tree
:
ast
.
AST
,
class_name
:
str
)
->
ast
.
ClassDef
|
None
:
"""Parse the registry.py file to get backend names and their class paths."""
"""Find a class definition in an AST."""
tree
=
ast
.
parse
(
REGISTRY_FILE
.
read_text
())
for
node
in
ast
.
walk
(
tree
):
for
node
in
ast
.
walk
(
tree
):
if
isinstance
(
node
,
ast
.
ClassDef
)
and
node
.
name
==
"AttentionBackendEnum"
:
if
isinstance
(
node
,
ast
.
ClassDef
)
and
node
.
name
==
class_name
:
return
_extract_enum_values
(
node
)
return
node
return
{}
return
None
def
_extract_enum_values
(
node
:
ast
.
ClassDef
)
->
dict
[
str
,
str
]:
def
find_method
(
node
:
ast
.
ClassDef
,
method_name
:
str
)
->
ast
.
FunctionDef
|
None
:
"""Extract enum name -> value mapping from a class definition."""
"""Find a method in a class definition."""
result
:
dict
[
str
,
str
]
=
{}
for
item
in
node
.
body
:
for
item
in
node
.
body
:
if
not
isinstance
(
item
,
ast
.
Assign
):
if
isinstance
(
item
,
ast
.
FunctionDef
)
and
item
.
name
==
method_name
:
continue
return
item
for
target
in
item
.
targets
:
return
None
if
not
isinstance
(
target
,
ast
.
Name
):
continue
if
isinstance
(
item
.
value
,
ast
.
Constant
)
and
item
.
value
.
value
:
result
[
target
.
id
]
=
item
.
value
.
value
return
result
def
get_file_from_class_path
(
class_path
:
str
)
->
Path
|
None
:
"""Convert a class path to a file path."""
if
not
class_path
:
return
None
module_path
=
class_path
.
rsplit
(
"."
,
1
)[
0
].
replace
(
"."
,
"/"
)
py_file
=
REPO_ROOT
/
f
"
{
module_path
}
.py"
return
py_file
if
py_file
.
exists
()
else
None
def
parse_flash_attn_features
()
->
dict
[
str
,
dict
[
str
,
Any
]]:
"""Parse fa_utils.py to detect FA2 vs FA3 feature differences.
Returns a dict with 'fa2' and 'fa3' keys containing their respective
feature overrides for compute capability, KV cache dtypes, and sink support.
"""
if
not
FA_UTILS_FILE
.
exists
():
return
{}
try
:
def
method_returns_true
(
method
:
ast
.
FunctionDef
|
None
)
->
bool
:
tree
=
ast
.
parse
(
FA_UTILS_FILE
.
read_text
())
"""Check if a method simply returns True."""
except
Exception
:
if
method
is
None
:
return
{}
return
False
for
node
in
ast
.
walk
(
method
):
if
(
isinstance
(
node
,
ast
.
Return
)
and
isinstance
(
node
.
value
,
ast
.
Constant
)
and
node
.
value
.
value
is
True
):
return
True
return
False
# Analyze the functions to determine FA3-specific features
fa3_supports_fp8
=
False
fa3_supports_sinks
=
False
fa3_compute_cap
:
str
|
None
=
None
for
node
in
ast
.
walk
(
tree
)
:
def
check_method_overrides
(
node
:
ast
.
ClassDef
,
method_name
:
str
)
->
bool
:
if
not
isinstance
(
node
,
ast
.
FunctionDef
):
"""Check if a method is overridden and returns True."""
continue
return
method_returns_true
(
find_method
(
node
,
method_name
))
# Check flash_attn_supports_fp8 - looks for `get_flash_attn_version() == 3`
if
node
.
name
==
"flash_attn_supports_fp8"
:
for
n
in
ast
.
walk
(
node
):
if
(
isinstance
(
n
,
ast
.
Compare
)
and
isinstance
(
n
.
left
,
ast
.
Call
)
and
isinstance
(
n
.
left
.
func
,
ast
.
Name
)
and
n
.
left
.
func
.
id
==
"get_flash_attn_version"
):
fa3_supports_fp8
=
True
break
# Check flash_attn_supports_sinks - looks for `get_flash_attn_version() == 3`
def
_find_bool_class_var
(
class_node
:
ast
.
ClassDef
,
var_name
:
str
)
->
bool
|
None
:
if
node
.
name
==
"flash_attn_supports_sinks"
:
"""Find a bool class variable in a class definition. Returns None if not found."""
for
n
in
ast
.
walk
(
node
):
for
item
in
class_node
.
body
:
# Check for annotated assignment: attr: bool = True/False
if
(
isinstance
(
item
,
ast
.
AnnAssign
)
and
isinstance
(
item
.
target
,
ast
.
Name
)
and
item
.
target
.
id
==
var_name
and
isinstance
(
item
.
value
,
ast
.
Constant
)
and
isinstance
(
item
.
value
.
value
,
bool
)
):
return
item
.
value
.
value
# Check for plain assignment: attr = True/False
if
isinstance
(
item
,
ast
.
Assign
):
for
target
in
item
.
targets
:
if
(
if
(
isinstance
(
n
,
ast
.
Compar
e
)
isinstance
(
target
,
ast
.
Nam
e
)
and
isinstance
(
n
.
left
,
ast
.
Call
)
and
target
.
id
==
var_name
and
isinstance
(
n
.
left
.
func
,
ast
.
Name
)
and
isinstance
(
item
.
value
,
ast
.
Constant
)
and
n
.
left
.
func
.
id
==
"get_flash_attn_version"
and
isinstance
(
item
.
value
.
value
,
bool
)
):
):
fa3_supports_sinks
=
True
return
item
.
value
.
value
break
return
None
# Check get_flash_attn_version for FA3 compute capability
# Look for the ternary: 3 if (device_capability.major == 9 ...) else 2
if
node
.
name
==
"get_flash_attn_version"
:
for
n
in
ast
.
walk
(
node
):
# Look for IfExp (ternary) with `device_capability.major == 9`
if
isinstance
(
n
,
ast
.
IfExp
):
test
=
n
.
test
# Check if test is a BoolOp (and) containing the major check
if
isinstance
(
test
,
ast
.
BoolOp
):
for
val
in
test
.
values
:
if
(
isinstance
(
val
,
ast
.
Compare
)
and
isinstance
(
val
.
left
,
ast
.
Attribute
)
and
val
.
left
.
attr
==
"major"
and
val
.
comparators
and
isinstance
(
val
.
comparators
[
0
],
ast
.
Constant
)
):
fa3_compute_cap
=
f
"
{
val
.
comparators
[
0
].
value
}
.x"
break
return
{
"fa2"
:
{
"supports_fp8"
:
False
,
"supports_sink"
:
False
,
},
"fa3"
:
{
"compute_capability"
:
fa3_compute_cap
,
"supports_fp8"
:
fa3_supports_fp8
,
"supports_sink"
:
fa3_supports_sinks
,
},
}
def
parse_flashinfer_trtllm_features
()
->
dict
[
str
,
dict
[
str
,
Any
]]:
"""Parse flashinfer.py to detect TRTLLM-specific features.
FLASHINFER uses TRTLLM attention on SM100 (Blackwell), which has different
capabilities (e.g., sink support) than native FlashInfer on earlier GPUs.
"""
if
not
FLASHINFER_UTILS_FILE
.
exists
():
return
{}
try
:
tree
=
ast
.
parse
(
FLASHINFER_UTILS_FILE
.
read_text
())
except
Exception
:
return
{}
trtllm_compute_cap
:
str
|
None
=
None
for
node
in
ast
.
walk
(
tree
):
def
_parse_list_class_var
(
node
:
ast
.
ClassDef
,
var_name
:
str
)
->
list
[
str
]
|
None
:
if
not
isinstance
(
node
,
ast
.
FunctionDef
):
"""Parse a list-type class variable, returning None if not found."""
for
item
in
node
.
body
:
if
not
isinstance
(
item
,
ast
.
AnnAssign
):
continue
continue
if
not
isinstance
(
item
.
target
,
ast
.
Name
):
continue
if
item
.
target
.
id
!=
var_name
:
continue
if
not
(
item
.
value
and
isinstance
(
item
.
value
,
ast
.
List
)):
continue
result
=
[]
for
elt
in
item
.
value
.
elts
:
if
isinstance
(
elt
,
ast
.
Attribute
):
result
.
append
(
elt
.
attr
)
elif
isinstance
(
elt
,
ast
.
Constant
):
result
.
append
(
str
(
elt
.
value
))
return
result
return
None
# Parse supports_trtllm_attention for compute capability
# Look for: current_platform.is_device_capability_family(100)
if
node
.
name
==
"supports_trtllm_attention"
:
for
n
in
ast
.
walk
(
node
):
if
(
isinstance
(
n
,
ast
.
Call
)
and
isinstance
(
n
.
func
,
ast
.
Attribute
)
and
n
.
func
.
attr
==
"is_device_capability_family"
and
n
.
args
and
isinstance
(
n
.
args
[
0
],
ast
.
Constant
)
and
isinstance
(
n
.
args
[
0
].
value
,
int
)
):
cap
=
n
.
args
[
0
].
value
# Convert 100 -> "10.x"
trtllm_compute_cap
=
f
"
{
cap
//
10
}
.x"
break
if
not
trtllm_compute_cap
:
return
{}
return
{
"native"
:
{
# Native FlashInfer: everything except SM100
"supports_sink"
:
False
,
},
"trtllm"
:
{
# TRTLLM pathway on Blackwell
"compute_capability"
:
trtllm_compute_cap
,
"supports_sink"
:
True
,
},
}
def
_parse_return_list
(
method
:
ast
.
FunctionDef
|
None
,
handle_multiple_of
:
bool
=
False
)
->
list
[
str
]:
"""Extract list items from a method's return statement."""
if
method
is
None
:
return
[]
for
stmt
in
ast
.
walk
(
method
):
if
not
isinstance
(
stmt
,
ast
.
Return
):
continue
if
not
isinstance
(
stmt
.
value
,
ast
.
List
):
continue
sizes
=
[]
for
elt
in
stmt
.
value
.
elts
:
if
isinstance
(
elt
,
ast
.
Constant
):
sizes
.
append
(
str
(
elt
.
value
))
elif
(
handle_multiple_of
and
isinstance
(
elt
,
ast
.
Call
)
and
isinstance
(
elt
.
func
,
ast
.
Name
)
and
elt
.
func
.
id
==
"MultipleOf"
and
elt
.
args
and
isinstance
(
elt
.
args
[
0
],
ast
.
Constant
)
):
sizes
.
append
(
f
"%
{
elt
.
args
[
0
].
value
}
"
)
if
sizes
:
return
sizes
return
[]
def
parse_mla_prefill_backends
()
->
list
[
dict
[
str
,
Any
]]:
"""Parse MLA prefill backend options from mla_attention.py.
MLA uses different backends for prefill vs decode. The decode backends are
def
_get_parent_class_name
(
class_node
:
ast
.
ClassDef
)
->
str
|
None
:
registered in the registry, but prefill backends are selected at runtime
"""Get the first parent class name (simple name only).
based on conditions in MLACommonImpl.__init__.
Returns a list of prefill backend info dicts with their requirements.
Handles both simple inheritance (class Foo(Bar)) and generic
inheritance (class Foo(Bar[T])).
"""
"""
if
not
MLA_ATTENTION_FILE
.
exists
():
if
not
class_node
.
bases
:
return
[]
return
None
base
=
class_node
.
bases
[
0
]
if
isinstance
(
base
,
ast
.
Name
):
return
base
.
id
if
isinstance
(
base
,
ast
.
Subscript
)
and
isinstance
(
base
.
value
,
ast
.
Name
):
return
base
.
value
.
id
return
None
try
:
tree
=
ast
.
parse
(
MLA_ATTENTION_FILE
.
read_text
())
except
Exception
:
return
[]
# Find compute capability requirements by parsing use_* functions
def
_resolve_import_to_file
(
flashinfer_cc
:
str
|
None
=
None
tree
:
ast
.
AST
,
class_name
:
str
,
source_file
:
Path
|
None
=
None
cudnn_cc
:
str
|
None
=
None
)
->
Path
|
None
:
trtllm_cc
:
str
|
None
=
None
"""Try to resolve a class name to its source file via imports in the AST.
Handles both absolute imports (from vllm.foo import Bar) and relative
imports (from .foo import Bar) when source_file is provided.
"""
for
node
in
ast
.
walk
(
tree
):
for
node
in
ast
.
walk
(
tree
):
if
not
isinstance
(
node
,
ast
.
FunctionDef
):
if
not
isinstance
(
node
,
ast
.
ImportFrom
):
continue
continue
for
alias
in
node
.
names
:
actual_name
=
alias
.
asname
or
alias
.
name
if
actual_name
!=
class_name
:
continue
if
not
node
.
module
:
continue
# Parse use_flashinfer_prefill for compute capability (SM100)
if
node
.
level
and
node
.
level
>
0
and
source_file
:
if
node
.
name
==
"use_flashinfer_prefill"
:
# Relative import: resolve from the source file's directory
for
n
in
ast
.
walk
(
node
):
base_dir
=
source_file
.
parent
if
(
for
_
in
range
(
node
.
level
-
1
):
isinstance
(
n
,
ast
.
Call
)
base_dir
=
base_dir
.
parent
and
isinstance
(
n
.
func
,
ast
.
Attribute
)
module_path
=
node
.
module
.
replace
(
"."
,
"/"
)
and
n
.
func
.
attr
==
"is_device_capability_family"
py_file
=
base_dir
/
f
"
{
module_path
}
.py"
and
n
.
args
else
:
and
isinstance
(
n
.
args
[
0
],
ast
.
Constant
)
# Absolute import
and
isinstance
(
n
.
args
[
0
].
value
,
int
)
module_path
=
node
.
module
.
replace
(
"."
,
"/"
)
):
py_file
=
REPO_ROOT
/
f
"
{
module_path
}
.py"
flashinfer_cc
=
f
"
{
n
.
args
[
0
].
value
//
10
}
.x"
# Parse use_cudnn_prefill for compute capability (SM100)
if
node
.
name
==
"use_cudnn_prefill"
:
for
n
in
ast
.
walk
(
node
):
if
(
isinstance
(
n
,
ast
.
Call
)
and
isinstance
(
n
.
func
,
ast
.
Attribute
)
and
n
.
func
.
attr
==
"is_device_capability_family"
and
n
.
args
and
isinstance
(
n
.
args
[
0
],
ast
.
Constant
)
and
isinstance
(
n
.
args
[
0
].
value
,
int
)
):
cudnn_cc
=
f
"
{
n
.
args
[
0
].
value
//
10
}
.x"
# Parse use_trtllm_ragged_deepseek_prefill for compute capability
if
node
.
name
==
"use_trtllm_ragged_deepseek_prefill"
:
for
n
in
ast
.
walk
(
node
):
if
(
isinstance
(
n
,
ast
.
Call
)
and
isinstance
(
n
.
func
,
ast
.
Attribute
)
and
n
.
func
.
attr
==
"is_device_capability_family"
and
n
.
args
and
isinstance
(
n
.
args
[
0
],
ast
.
Constant
)
and
isinstance
(
n
.
args
[
0
].
value
,
int
)
):
trtllm_cc
=
f
"
{
n
.
args
[
0
].
value
//
10
}
.x"
# Build prefill backend list based on what we found
# Order matches the priority in MLACommonImpl.__init__
prefill_backends
:
list
[
dict
[
str
,
Any
]]
=
[]
# TRT-LLM Ragged (highest priority if available)
if
trtllm_cc
:
prefill_backends
.
append
(
{
"name"
:
"TRT-LLM Ragged‡"
,
"description"
:
"TensorRT-LLM ragged attention"
,
"compute_capability"
:
trtllm_cc
,
"enable"
:
"Default on SM100"
,
"disable"
:
"`-ac.use_trtllm_ragged_deepseek_prefill=0`"
,
"notes"
:
"DeepSeek R1 dims only"
,
}
)
# FlashInfer prefill
if
flashinfer_cc
:
prefill_backends
.
append
(
{
"name"
:
"FlashInfer"
,
"description"
:
"FlashInfer CUTLASS backend"
,
"compute_capability"
:
flashinfer_cc
,
"enable"
:
"`-ac.disable_flashinfer_prefill=0`"
,
"disable"
:
"`-ac.disable_flashinfer_prefill=1`"
,
"notes"
:
"DeepSeek R1 dims only"
,
}
)
# cuDNN prefill
if
cudnn_cc
:
prefill_backends
.
append
(
{
"name"
:
"cuDNN"
,
"description"
:
"cuDNN-based attention"
,
"compute_capability"
:
cudnn_cc
,
"enable"
:
"`-ac.use_cudnn_prefill=1`"
,
"disable"
:
"`-ac.use_cudnn_prefill=0`"
,
"notes"
:
""
,
}
)
# FlashAttention is always available as fallback
if
py_file
.
exists
():
prefill_backends
.
append
(
return
py_file
{
return
None
"name"
:
"FlashAttention"
,
"description"
:
"FlashAttention varlen (FA2/FA3)"
,
"compute_capability"
:
"Any"
,
"enable"
:
"Default fallback"
,
"disable"
:
"Use other backends"
,
"notes"
:
"FA3 on SM90, FA2 otherwise"
,
}
)
return
prefill_backends
def
_find_cc_in_function
(
tree
:
ast
.
AST
,
func_name
:
str
)
->
str
|
None
:
"""Find a compute capability from is_device_capability_family() calls in a function.
def
find_class_in_ast
(
tree
:
ast
.
AST
,
class_name
:
str
)
->
ast
.
ClassDef
|
None
:
Looks for the pattern: current_platform.is_device_capability_family(N)
"""Find a class definition in an AST."""
and converts N (e.g. 100) to a CC string (e.g. "10.x").
"""
for
node
in
ast
.
walk
(
tree
):
for
node
in
ast
.
walk
(
tree
):
if
isinstance
(
node
,
ast
.
ClassDef
)
and
node
.
name
==
class_name
:
if
not
isinstance
(
node
,
ast
.
FunctionDef
)
or
node
.
name
!=
func_name
:
return
node
continue
for
n
in
ast
.
walk
(
node
):
if
(
isinstance
(
n
,
ast
.
Call
)
and
isinstance
(
n
.
func
,
ast
.
Attribute
)
and
n
.
func
.
attr
==
"is_device_capability_family"
and
n
.
args
and
isinstance
(
n
.
args
[
0
],
ast
.
Constant
)
and
isinstance
(
n
.
args
[
0
].
value
,
int
)
):
return
f
"
{
n
.
args
[
0
].
value
//
10
}
.x"
return
None
return
None
def
find_method
(
node
:
ast
.
ClassDef
,
method_name
:
str
)
->
ast
.
FunctionDef
|
None
:
# ---------------------------------------------------------------------------
"""Find a method in a class definition."""
# Registry and file resolution
for
item
in
node
.
body
:
# ---------------------------------------------------------------------------
if
isinstance
(
item
,
ast
.
FunctionDef
)
and
item
.
name
==
method_name
:
return
item
return
None
def
method_returns_true
(
method
:
ast
.
FunctionDef
|
None
)
->
bool
:
def
parse_registry
()
->
dict
[
str
,
str
]:
"""Check if a method simply returns True."""
"""Parse the registry.py file to get backend names and their class paths."""
if
method
is
None
:
tree
=
ast
.
parse
(
REGISTRY_FILE
.
read_text
())
return
False
for
node
in
ast
.
walk
(
tree
):
for
node
in
ast
.
walk
(
method
):
if
isinstance
(
node
,
ast
.
ClassDef
)
and
node
.
name
==
"AttentionBackendEnum"
:
if
not
isinstance
(
node
,
ast
.
Return
):
return
_extract_enum_values
(
node
)
continue
return
{}
if
isinstance
(
node
.
value
,
ast
.
Constant
)
and
node
.
value
.
value
is
True
:
return
True
return
False
def
_parse_list_class_var
(
node
:
ast
.
ClassDef
,
var_name
:
str
)
->
list
[
str
]
|
None
:
def
_extract_enum_values
(
node
:
ast
.
ClassDef
)
->
dict
[
str
,
str
]:
"""Parse a list-type class variable, returning None if not found."""
"""Extract enum name -> value mapping from a class definition."""
result
:
dict
[
str
,
str
]
=
{}
for
item
in
node
.
body
:
for
item
in
node
.
body
:
if
not
isinstance
(
item
,
ast
.
AnnAssign
):
if
not
isinstance
(
item
,
ast
.
Assign
):
continue
if
not
isinstance
(
item
.
target
,
ast
.
Name
):
continue
if
item
.
target
.
id
!=
var_name
:
continue
if
not
(
item
.
value
and
isinstance
(
item
.
value
,
ast
.
List
)):
continue
continue
result
=
[]
for
target
in
item
.
targets
:
for
elt
in
item
.
value
.
elts
:
if
not
isinstance
(
target
,
ast
.
Name
):
if
isinstance
(
elt
,
ast
.
Attribute
):
continue
result
.
append
(
elt
.
attr
)
if
isinstance
(
item
.
value
,
ast
.
Constant
)
and
item
.
value
.
value
:
elif
isinstance
(
elt
,
ast
.
Constant
):
result
[
target
.
id
]
=
item
.
value
.
value
result
.
append
(
str
(
elt
.
value
))
return
result
return
result
return
None
def
get_file_from_class_path
(
class_path
:
str
)
->
Path
|
None
:
"""Convert a class path to a file path."""
if
not
class_path
:
return
None
module_path
=
class_path
.
rsplit
(
"."
,
1
)[
0
].
replace
(
"."
,
"/"
)
py_file
=
REPO_ROOT
/
f
"
{
module_path
}
.py"
return
py_file
if
py_file
.
exists
()
else
None
# ---------------------------------------------------------------------------
# Backend feature extraction from AST
# ---------------------------------------------------------------------------
def
parse_supported_dtypes
(
node
:
ast
.
ClassDef
)
->
str
:
def
parse_supported_dtypes
(
node
:
ast
.
ClassDef
)
->
str
:
...
@@ -432,35 +332,6 @@ def parse_kv_cache_dtypes(node: ast.ClassDef) -> str:
...
@@ -432,35 +332,6 @@ def parse_kv_cache_dtypes(node: ast.ClassDef) -> str:
return
"auto"
return
"auto"
def
_parse_return_list
(
method
:
ast
.
FunctionDef
|
None
,
handle_multiple_of
:
bool
=
False
)
->
list
[
str
]:
"""Extract list items from a method's return statement."""
if
method
is
None
:
return
[]
for
stmt
in
ast
.
walk
(
method
):
if
not
isinstance
(
stmt
,
ast
.
Return
):
continue
if
not
isinstance
(
stmt
.
value
,
ast
.
List
):
continue
sizes
=
[]
for
elt
in
stmt
.
value
.
elts
:
if
isinstance
(
elt
,
ast
.
Constant
):
sizes
.
append
(
str
(
elt
.
value
))
elif
(
handle_multiple_of
and
isinstance
(
elt
,
ast
.
Call
)
and
isinstance
(
elt
.
func
,
ast
.
Name
)
and
elt
.
func
.
id
==
"MultipleOf"
and
elt
.
args
and
isinstance
(
elt
.
args
[
0
],
ast
.
Constant
)
):
sizes
.
append
(
f
"%
{
elt
.
args
[
0
].
value
}
"
)
if
sizes
:
return
sizes
return
[]
def
parse_block_sizes
(
node
:
ast
.
ClassDef
)
->
str
:
def
parse_block_sizes
(
node
:
ast
.
ClassDef
)
->
str
:
"""Parse get_supported_kernel_block_sizes method."""
"""Parse get_supported_kernel_block_sizes method."""
method
=
find_method
(
node
,
"get_supported_kernel_block_sizes"
)
method
=
find_method
(
node
,
"get_supported_kernel_block_sizes"
)
...
@@ -536,202 +407,444 @@ def parse_compute_capability(node: ast.ClassDef) -> str:
...
@@ -536,202 +407,444 @@ def parse_compute_capability(node: ast.ClassDef) -> str:
return
f
"
{
min_cap
[
0
]
}
.x-
{
max_cap
[
0
]
}
.x"
return
f
"
{
min_cap
[
0
]
}
.x-
{
max_cap
[
0
]
}
.x"
return
f
"≥
{
min_cap
[
0
]
}
.
{
min_cap
[
1
]
}
"
return
f
"≥
{
min_cap
[
0
]
}
.
{
min_cap
[
1
]
}
"
return
"Any"
return
"Any"
def
parse_attention_types
(
node
:
ast
.
ClassDef
)
->
str
:
"""Parse supports_attn_type method."""
method
=
find_method
(
node
,
"supports_attn_type"
)
if
method
is
None
:
return
"Decoder"
type_map
=
{
"DECODER"
:
"Decoder"
,
"ENCODER"
:
"Encoder"
,
"ENCODER_ONLY"
:
"Encoder Only"
,
"ENCODER_DECODER"
:
"Enc-Dec"
,
}
types
:
set
[
str
]
=
set
()
for
n
in
ast
.
walk
(
method
):
# Handle `attn_type in (AttentionType.DECODER, ...)`
if
not
(
isinstance
(
n
,
ast
.
Compare
)
and
len
(
n
.
ops
)
==
1
and
isinstance
(
n
.
ops
[
0
],
ast
.
In
)
and
len
(
n
.
comparators
)
==
1
and
isinstance
(
n
.
comparators
[
0
],
ast
.
Tuple
|
ast
.
Set
)
):
continue
for
elt
in
n
.
comparators
[
0
].
elts
:
if
isinstance
(
elt
,
ast
.
Attribute
)
and
elt
.
attr
in
type_map
:
types
.
add
(
type_map
[
elt
.
attr
])
if
not
types
:
return
"Decoder"
return
"All"
if
len
(
types
)
>=
3
else
", "
.
join
(
sorted
(
types
))
def
parse_impl_bool_attr
(
tree
:
ast
.
AST
,
class_name
:
str
,
attr_name
:
str
,
default
:
bool
=
False
,
source_file
:
Path
|
None
=
None
,
_visited
:
set
[
str
]
|
None
=
None
,
)
->
bool
:
"""Parse a boolean class attribute from an impl class, following inheritance.
Walks up the inheritance chain within the same file and across files
(by resolving imports) to find the attribute value.
"""
if
_visited
is
None
:
_visited
=
set
()
if
class_name
in
_visited
:
return
default
_visited
.
add
(
class_name
)
class_node
=
find_class_in_ast
(
tree
,
class_name
)
if
class_node
is
None
:
return
default
# Check directly on this class
value
=
_find_bool_class_var
(
class_node
,
attr_name
)
if
value
is
not
None
:
return
value
# Check parent class
parent_name
=
_get_parent_class_name
(
class_node
)
if
parent_name
:
# Try parent in same file first
parent_node
=
find_class_in_ast
(
tree
,
parent_name
)
if
parent_node
is
not
None
:
return
parse_impl_bool_attr
(
tree
,
parent_name
,
attr_name
,
default
,
source_file
,
_visited
)
# Try resolving cross-file import
parent_file
=
_resolve_import_to_file
(
tree
,
parent_name
,
source_file
)
if
parent_file
:
try
:
parent_tree
=
ast
.
parse
(
parent_file
.
read_text
())
return
parse_impl_bool_attr
(
parent_tree
,
parent_name
,
attr_name
,
default
,
parent_file
,
_visited
,
)
except
Exception
:
pass
return
default
def
analyze_backend
(
backend_name
:
str
,
class_path
:
str
)
->
dict
[
str
,
Any
]
|
None
:
"""Analyze a backend class and extract feature information."""
file_path
=
get_file_from_class_path
(
class_path
)
if
file_path
is
None
:
return
None
try
:
tree
=
ast
.
parse
(
file_path
.
read_text
())
except
Exception
as
e
:
print
(
f
" Warning: Could not parse
{
file_path
}
:
{
e
}
"
,
file
=
sys
.
stderr
)
return
None
class_name
=
class_path
.
rsplit
(
"."
,
1
)[
1
]
class_node
=
find_class_in_ast
(
tree
,
class_name
)
if
class_node
is
None
:
return
None
# Check if this is an MLA backend by parent class or naming
parent
=
_get_parent_class_name
(
class_node
)
mla_parents
=
{
"MLACommonBackend"
,
"FlashMLABackend"
,
"FlashMLASparseBackend"
}
is_mla_backend
=
(
parent
in
mla_parents
or
".mla."
in
class_path
.
lower
()
or
"_mla"
in
backend_name
.
lower
()
)
# Determine compute capability - use N/A for non-CUDA backends
is_non_cuda
=
backend_name
.
startswith
((
"CPU_"
,
"ROCM_"
))
compute_cap
=
"N/A"
if
is_non_cuda
else
parse_compute_capability
(
class_node
)
# Parse impl class features (DCP support)
impl_method
=
find_method
(
class_node
,
"get_impl_cls"
)
impl_class_name
=
None
if
impl_method
:
for
stmt
in
ast
.
walk
(
impl_method
):
if
isinstance
(
stmt
,
ast
.
Return
)
and
isinstance
(
stmt
.
value
,
ast
.
Name
):
impl_class_name
=
stmt
.
value
.
id
break
supports_dcp
=
False
if
impl_class_name
:
supports_dcp
=
parse_impl_bool_attr
(
tree
,
impl_class_name
,
"can_return_lse_for_decode"
,
False
,
file_path
)
return
{
"name"
:
backend_name
,
"dtypes"
:
parse_supported_dtypes
(
class_node
),
"kv_cache_dtypes"
:
parse_kv_cache_dtypes
(
class_node
),
"block_sizes"
:
parse_block_sizes
(
class_node
),
"head_sizes"
:
parse_head_sizes
(
class_node
),
"attn_types"
:
parse_attention_types
(
class_node
),
"compute_capability"
:
compute_cap
,
"is_mla"
:
is_mla_backend
or
check_method_overrides
(
class_node
,
"is_mla"
),
"supports_sink"
:
check_method_overrides
(
class_node
,
"supports_sink"
),
"is_sparse"
:
check_method_overrides
(
class_node
,
"is_sparse"
),
"supports_mm_prefix"
:
check_method_overrides
(
class_node
,
"supports_mm_prefix"
),
"supports_dcp"
:
supports_dcp
,
}
# ---------------------------------------------------------------------------
# Special backend variant parsers (FA2/FA3, FlashInfer TRTLLM, MLA prefill)
# ---------------------------------------------------------------------------
def
parse_flash_attn_features
()
->
dict
[
str
,
dict
[
str
,
Any
]]:
"""Parse fa_utils.py to detect FA2 vs FA3 feature differences.
Returns a dict with 'fa2' and 'fa3' keys containing their respective
feature overrides for compute capability, KV cache dtypes, and sink support.
"""
if
not
FA_UTILS_FILE
.
exists
():
return
{}
try
:
tree
=
ast
.
parse
(
FA_UTILS_FILE
.
read_text
())
except
Exception
:
return
{}
# Analyze the functions to determine FA3-specific features
fa3_supports_fp8
=
False
fa3_supports_sinks
=
False
fa3_compute_cap
:
str
|
None
=
None
for
node
in
ast
.
walk
(
tree
):
if
not
isinstance
(
node
,
ast
.
FunctionDef
):
continue
# Check flash_attn_supports_fp8 - looks for `get_flash_attn_version() == 3`
if
node
.
name
==
"flash_attn_supports_fp8"
:
for
n
in
ast
.
walk
(
node
):
if
(
isinstance
(
n
,
ast
.
Compare
)
and
isinstance
(
n
.
left
,
ast
.
Call
)
and
isinstance
(
n
.
left
.
func
,
ast
.
Name
)
and
n
.
left
.
func
.
id
==
"get_flash_attn_version"
):
fa3_supports_fp8
=
True
break
# Check flash_attn_supports_sinks - looks for `get_flash_attn_version() == 3`
if
node
.
name
==
"flash_attn_supports_sinks"
:
for
n
in
ast
.
walk
(
node
):
if
(
isinstance
(
n
,
ast
.
Compare
)
and
isinstance
(
n
.
left
,
ast
.
Call
)
and
isinstance
(
n
.
left
.
func
,
ast
.
Name
)
and
n
.
left
.
func
.
id
==
"get_flash_attn_version"
):
fa3_supports_sinks
=
True
break
# Check get_flash_attn_version for FA3 compute capability
# Look for the ternary: 3 if (device_capability.major == 9 ...) else 2
if
node
.
name
==
"get_flash_attn_version"
:
for
n
in
ast
.
walk
(
node
):
# Look for IfExp (ternary) with `device_capability.major == 9`
if
isinstance
(
n
,
ast
.
IfExp
):
test
=
n
.
test
# Check if test is a BoolOp (and) containing the major check
if
isinstance
(
test
,
ast
.
BoolOp
):
for
val
in
test
.
values
:
if
(
isinstance
(
val
,
ast
.
Compare
)
and
isinstance
(
val
.
left
,
ast
.
Attribute
)
and
val
.
left
.
attr
==
"major"
and
val
.
comparators
and
isinstance
(
val
.
comparators
[
0
],
ast
.
Constant
)
):
fa3_compute_cap
=
f
"
{
val
.
comparators
[
0
].
value
}
.x"
break
return
{
"fa2"
:
{
"supports_fp8"
:
False
,
"supports_sink"
:
False
,
},
"fa3"
:
{
"compute_capability"
:
fa3_compute_cap
,
"supports_fp8"
:
fa3_supports_fp8
,
"supports_sink"
:
fa3_supports_sinks
,
},
}
def
parse_flashinfer_trtllm_features
()
->
dict
[
str
,
dict
[
str
,
Any
]]:
"""Parse flashinfer.py to detect TRTLLM-specific features.
FLASHINFER uses TRTLLM attention on SM100 (Blackwell), which has different
capabilities (e.g., sink support) than native FlashInfer on earlier GPUs.
"""
if
not
FLASHINFER_UTILS_FILE
.
exists
():
return
{}
try
:
tree
=
ast
.
parse
(
FLASHINFER_UTILS_FILE
.
read_text
())
except
Exception
:
return
{}
trtllm_compute_cap
=
_find_cc_in_function
(
tree
,
"supports_trtllm_attention"
)
def
parse_attention_types
(
node
:
ast
.
ClassDef
)
->
str
:
if
not
trtllm_compute_cap
:
"""Parse supports_attn_type method."""
return
{}
method
=
find_method
(
node
,
"supports_attn_type"
)
if
method
is
None
:
return
"Decoder"
type_map
=
{
return
{
"DECODER"
:
"Decoder"
,
"native"
:
{
"ENCODER"
:
"Encoder"
,
# Native FlashInfer: everything except SM100
"ENCODER_ONLY"
:
"Encoder Only"
,
"supports_sink"
:
False
,
"ENCODER_DECODER"
:
"Enc-Dec"
,
},
"trtllm"
:
{
# TRTLLM pathway on Blackwell
"compute_capability"
:
trtllm_compute_cap
,
"supports_sink"
:
True
,
},
}
}
types
:
set
[
str
]
=
set
()
for
n
in
ast
.
walk
(
method
):
# Handle `attn_type in (AttentionType.DECODER, ...)`
if
not
(
isinstance
(
n
,
ast
.
Compare
)
and
len
(
n
.
ops
)
==
1
and
isinstance
(
n
.
ops
[
0
],
ast
.
In
)
and
len
(
n
.
comparators
)
==
1
and
isinstance
(
n
.
comparators
[
0
],
ast
.
Tuple
|
ast
.
Set
)
):
continue
for
elt
in
n
.
comparators
[
0
].
elts
:
def
parse_mla_prefill_backends
()
->
list
[
dict
[
str
,
Any
]]:
if
isinstance
(
elt
,
ast
.
Attribute
)
and
elt
.
attr
in
type_map
:
"""Parse MLA prefill backend options from mla_attention.py.
types
.
add
(
type_map
[
elt
.
attr
])
if
not
types
:
MLA uses different backends for prefill vs decode. The decode backends are
return
"Decoder"
registered in the registry, but prefill backends are selected at runtime
return
"All"
if
len
(
types
)
>=
3
else
", "
.
join
(
sorted
(
types
))
based on conditions in MLACommonImpl.__init__.
Returns a list of prefill backend info dicts with their requirements.
"""
if
not
MLA_ATTENTION_FILE
.
exists
():
return
[]
def
check_method_overrides
(
node
:
ast
.
ClassDef
,
method_name
:
str
)
->
bool
:
try
:
"""Check if a method is overridden and returns True."""
tree
=
ast
.
parse
(
MLA_ATTENTION_FILE
.
read_text
())
method
=
find_method
(
node
,
method_name
)
except
Exception
:
return
method_returns_true
(
method
)
return
[]
# Find compute capability requirements by parsing use_* functions
trtllm_cc
=
_find_cc_in_function
(
tree
,
"use_trtllm_ragged_deepseek_prefill"
)
flashinfer_cc
=
_find_cc_in_function
(
tree
,
"use_flashinfer_prefill"
)
cudnn_cc
=
_find_cc_in_function
(
tree
,
"use_cudnn_prefill"
)
def
analyze_backend
(
backend_name
:
str
,
class_path
:
str
)
->
dict
[
str
,
Any
]
|
None
:
# Build prefill backend list based on what we found
"""Analyze a backend class and extract feature information."""
# Order matches the priority in MLACommonImpl.__init__
file_path
=
get_file_from_class_path
(
class_path
)
prefill_backends
:
list
[
dict
[
str
,
Any
]]
=
[]
if
file_path
is
None
:
return
None
try
:
# TRT-LLM Ragged (highest priority if available)
tree
=
ast
.
parse
(
file_path
.
read_text
())
if
trtllm_cc
:
except
Exception
as
e
:
prefill_backends
.
append
(
print
(
f
" Warning: Could not parse
{
file_path
}
:
{
e
}
"
,
file
=
sys
.
stderr
)
{
return
None
"name"
:
"TRT-LLM Ragged‡"
,
"description"
:
"TensorRT-LLM ragged attention"
,
"compute_capability"
:
trtllm_cc
,
"enable"
:
"Default on SM100"
,
"disable"
:
"`-ac.use_trtllm_ragged_deepseek_prefill=0`"
,
"notes"
:
"DeepSeek R1 dims only"
,
}
)
class_name
=
class_path
.
rsplit
(
"."
,
1
)[
1
]
# FlashInfer prefill
class_node
=
find_class_in_ast
(
tree
,
class_name
)
if
flashinfer_cc
:
if
class_node
is
None
:
prefill_backends
.
append
(
return
None
{
"name"
:
"FlashInfer"
,
"description"
:
"FlashInfer CUTLASS backend"
,
"compute_capability"
:
flashinfer_cc
,
"enable"
:
"`-ac.disable_flashinfer_prefill=0`"
,
"disable"
:
"`-ac.disable_flashinfer_prefill=1`"
,
"notes"
:
"DeepSeek R1 dims only"
,
}
)
# Check if this is an MLA backend by parent class or naming
# cuDNN prefill
parent
=
None
if
cudnn_cc
:
if
class_node
.
bases
:
prefill_backends
.
append
(
base
=
class_node
.
bases
[
0
]
{
parent
=
base
.
id
if
isinstance
(
base
,
ast
.
Name
)
else
None
"name"
:
"cuDNN"
,
mla_parents
=
{
"MLACommonBackend"
,
"FlashMLABackend"
,
"FlashMLASparseBackend"
}
"description"
:
"cuDNN-based attention"
,
is_mla_backend
=
(
"compute_capability"
:
cudnn_cc
,
parent
in
mla_parents
"enable"
:
"`-ac.use_cudnn_prefill=1`"
,
or
".mla."
in
class_path
.
lower
()
"disable"
:
"`-ac.use_cudnn_prefill=0`"
,
or
"_mla"
in
backend_name
.
lower
()
"notes"
:
""
,
}
)
# FlashAttention is always available as fallback
prefill_backends
.
append
(
{
"name"
:
"FlashAttention"
,
"description"
:
"FlashAttention varlen (FA2/FA3)"
,
"compute_capability"
:
"Any"
,
"enable"
:
"Default fallback"
,
"disable"
:
"Use other backends"
,
"notes"
:
"FA3 on SM90, FA2 otherwise"
,
}
)
)
# Determine compute capability - use N/A for non-CUDA backends
return
prefill_backends
is_non_cuda
=
backend_name
.
startswith
((
"CPU_"
,
"ROCM_"
))
compute_cap
=
"N/A"
if
is_non_cuda
else
parse_compute_capability
(
class_node
)
return
{
"name"
:
backend_name
,
"dtypes"
:
parse_supported_dtypes
(
class_node
),
"kv_cache_dtypes"
:
parse_kv_cache_dtypes
(
class_node
),
"block_sizes"
:
parse_block_sizes
(
class_node
),
"head_sizes"
:
parse_head_sizes
(
class_node
),
"attn_types"
:
parse_attention_types
(
class_node
),
"compute_capability"
:
compute_cap
,
"is_mla"
:
is_mla_backend
or
check_method_overrides
(
class_node
,
"is_mla"
),
"supports_sink"
:
check_method_overrides
(
class_node
,
"supports_sink"
),
"is_sparse"
:
check_method_overrides
(
class_node
,
"is_sparse"
),
"supports_mm_prefix"
:
check_method_overrides
(
class_node
,
"supports_mm_prefix"
),
}
# ---------------------------------------------------------------------------
# Backend variant expansion (FA2/FA3, FlashInfer native/TRTLLM)
# ---------------------------------------------------------------------------
def
add_literal_quotes
(
value
:
str
)
->
str
:
"""Add literal backticks around all comma-separated items in a string."""
items
=
[
item
.
strip
()
for
item
in
value
.
split
(
","
)]
quoted_items
=
[
f
"`
{
item
}
`"
for
item
in
items
]
return
", "
.
join
(
quoted_items
)
def
_expand_flash_attn_variants
(
all_backends
:
list
[
dict
[
str
,
Any
]],
fa_features
:
dict
[
str
,
dict
[
str
,
Any
]],
)
->
list
[
dict
[
str
,
Any
]]:
"""Expand FLASH_ATTN into FA2 and FA3 variants with different capabilities."""
expanded
=
[]
for
backend
in
all_backends
:
if
backend
[
"name"
]
!=
"FLASH_ATTN"
:
backend
.
setdefault
(
"_sort_key"
,
backend
[
"name"
])
backend
.
setdefault
(
"_sort_order"
,
0
)
backend
.
setdefault
(
"version"
,
""
)
expanded
.
append
(
backend
)
continue
def
bool_to_emoji
(
value
:
bool
)
->
str
:
# Create FA2 entry (keeps base backend's compute_capability)
"""Convert a boolean to a checkmark or X emoji."""
fa2
=
backend
.
copy
()
return
"✅"
if
value
else
"❌"
fa2
[
"version"
]
=
"FA2*"
fa2
[
"_sort_key"
]
=
"FLASH_ATTN"
fa2
[
"_sort_order"
]
=
0
fa2
[
"supports_sink"
]
=
fa_features
[
"fa2"
][
"supports_sink"
]
# Create FA3 entry (uses parsed compute_capability from fa_utils)
fa3
=
backend
.
copy
()
fa3
[
"version"
]
=
"FA3*"
fa3
[
"_sort_key"
]
=
"FLASH_ATTN"
fa3
[
"_sort_order"
]
=
1
if
fa_features
[
"fa3"
][
"compute_capability"
]:
fa3
[
"compute_capability"
]
=
fa_features
[
"fa3"
][
"compute_capability"
]
fa3
[
"supports_sink"
]
=
fa_features
[
"fa3"
][
"supports_sink"
]
if
fa_features
[
"fa3"
][
"supports_fp8"
]:
base_dtypes
=
backend
[
"kv_cache_dtypes"
].
split
(
", "
)
fp8_dtypes
=
[
"fp8"
,
"fp8_e4m3"
,
"fp8_e5m2"
]
new_dtypes
=
[
d
for
d
in
fp8_dtypes
if
d
not
in
base_dtypes
]
fa3
[
"kv_cache_dtypes"
]
=
", "
.
join
(
base_dtypes
+
new_dtypes
)
expanded
.
append
(
fa2
)
expanded
.
append
(
fa3
)
return
expanded
def
_expand_flashinfer_variants
(
all_backends
:
list
[
dict
[
str
,
Any
]],
fi_features
:
dict
[
str
,
dict
[
str
,
Any
]],
)
->
list
[
dict
[
str
,
Any
]]:
"""Expand FLASHINFER into native and TRTLLM variants."""
expanded
=
[]
for
backend
in
all_backends
:
if
backend
[
"name"
]
!=
"FLASHINFER"
:
expanded
.
append
(
backend
)
continue
# Parse original compute capability to get min CC
orig_cap
=
backend
[
"compute_capability"
]
parts
=
orig_cap
.
replace
(
".x"
,
""
).
split
(
"-"
)
min_cc
=
parts
[
0
]
if
parts
else
"7"
trtllm_cc
=
fi_features
[
"trtllm"
][
"compute_capability"
]
def
generate_markdown_table
(
# Create native entry (pre-Blackwell GPUs)
backends
:
list
[
dict
[
str
,
Any
]],
title
:
str
,
is_mla_table
:
bool
=
False
native
=
backend
.
copy
()
)
->
str
:
native
[
"version"
]
=
"Native†"
"""Generate a markdown table from backend info.
native
[
"_sort_key"
]
=
"FLASHINFER"
native
[
"_sort_order"
]
=
0
native
[
"supports_sink"
]
=
fi_features
[
"native"
][
"supports_sink"
]
native
[
"compute_capability"
]
=
f
"
{
min_cc
}
.x-9.x"
Args:
# Create TRTLLM entry
backends: List of backend info dictionaries.
trtllm
=
backend
.
copy
()
title: Table title.
trtllm
[
"version"
]
=
"TRTLLM†"
is_mla_table: If True, include MLA and Sparse columns (for MLA table).
trtllm
[
"_sort_key"
]
=
"FLASHINFER"
If False, exclude them (for standard attention table).
trtllm
[
"_sort_order"
]
=
1
"""
trtllm
[
"compute_capability"
]
=
trtllm_cc
if
not
backends
:
trtllm
[
"supports_sink"
]
=
fi_features
[
"trtllm"
][
"supports_sink"
]
return
f
"##
{
title
}
\n\n
No backends found.
\n
"
# Check if any backend has a version (for FA2/FA3 split)
expanded
.
append
(
native
)
has_versions
=
any
(
b
.
get
(
"version"
)
for
b
in
backends
)
expanded
.
append
(
trtllm
)
return
expanded
if
is_mla_table
:
header
=
(
"| Backend | Dtypes | KV Dtypes | Block Sizes | Head Sizes "
"| Sink | Sparse | MM Prefix | Attention Types | Compute Cap. |"
)
separator
=
(
"|---------|--------|-----------|-------------|------------"
"|------|--------|-----------|-----------------|--------------|"
)
elif
has_versions
:
header
=
(
"| Backend | Version | Dtypes | KV Dtypes | Block Sizes "
"| Head Sizes | Sink | MM Prefix | Attention Types | Compute Cap. |"
)
separator
=
(
"|---------|---------|--------|-----------|-------------"
"|------------|------|-----------|-----------------|--------------|"
)
else
:
header
=
(
"| Backend | Dtypes | KV Dtypes | Block Sizes | Head Sizes "
"| Sink | MM Prefix | Attention Types | Compute Cap. |"
)
separator
=
(
"|---------|--------|-----------|-------------|------------"
"|------|-----------|-----------------|--------------|"
)
lines
=
[
f
"##
{
title
}
"
,
""
,
header
,
separator
]
def
sort_key
(
x
:
dict
[
str
,
Any
])
->
tuple
[
str
,
int
]:
"""Sort key that keeps parent/child rows together in order."""
return
(
x
.
get
(
"_sort_key"
,
x
[
"name"
]),
x
.
get
(
"_sort_order"
,
0
))
for
info
in
sorted
(
backends
,
key
=
sort_key
):
if
is_mla_table
:
row
=
"| `{}` | {} | {} | {} | {} | {} | {} | {} | {} | {} |"
.
format
(
info
[
"name"
],
info
[
"dtypes"
],
add_literal_quotes
(
info
[
"kv_cache_dtypes"
]),
info
[
"block_sizes"
],
info
[
"head_sizes"
],
bool_to_emoji
(
info
[
"supports_sink"
]),
bool_to_emoji
(
info
[
"is_sparse"
]),
bool_to_emoji
(
info
[
"supports_mm_prefix"
]),
info
[
"attn_types"
],
info
[
"compute_capability"
],
)
elif
has_versions
:
row
=
"| `{}` | {} | {} | {} | {} | {} | {} | {} | {} | {} |"
.
format
(
info
[
"name"
],
info
.
get
(
"version"
,
""
),
info
[
"dtypes"
],
add_literal_quotes
(
info
[
"kv_cache_dtypes"
]),
info
[
"block_sizes"
],
info
[
"head_sizes"
],
bool_to_emoji
(
info
[
"supports_sink"
]),
bool_to_emoji
(
info
[
"supports_mm_prefix"
]),
info
[
"attn_types"
],
info
[
"compute_capability"
],
)
else
:
row
=
"| `{}` | {} | {} | {} | {} | {} | {} | {} | {} |"
.
format
(
info
[
"name"
],
info
[
"dtypes"
],
add_literal_quotes
(
info
[
"kv_cache_dtypes"
]),
info
[
"block_sizes"
],
info
[
"head_sizes"
],
bool_to_emoji
(
info
[
"supports_sink"
]),
bool_to_emoji
(
info
[
"supports_mm_prefix"
]),
info
[
"attn_types"
],
info
[
"compute_capability"
],
)
lines
.
append
(
row
)
lines
.
append
(
""
)
# ---------------------------------------------------------------------------
return
"
\n
"
.
join
(
lines
)
# CUDA priority list parsing
# ---------------------------------------------------------------------------
def
parse_cuda_priority_lists
()
->
dict
[
str
,
list
[
str
]]:
def
parse_cuda_priority_lists
()
->
dict
[
str
,
list
[
str
]]:
...
@@ -827,6 +940,105 @@ def _extract_priorities(body: list, priorities: dict[str, list[str]], prefix: st
...
@@ -827,6 +940,105 @@ def _extract_priorities(body: list, priorities: dict[str, list[str]], prefix: st
priorities
[
f
"
{
prefix
}
_default"
]
=
backends
priorities
[
f
"
{
prefix
}
_default"
]
=
backends
# ---------------------------------------------------------------------------
# Data-driven table rendering
#
# Each column is a (header, formatter) pair. The formatter takes a backend
# info dict and returns the cell string. Tables are assembled by selecting
# which columns to include, then calling _render_table().
# ---------------------------------------------------------------------------
# Column type alias for readability
TableColumn
=
tuple
[
str
,
Callable
[[
dict
[
str
,
Any
]],
str
]]
# Shared column definitions -- order here matches the output table order
_COL_BACKEND
:
TableColumn
=
(
"Backend"
,
lambda
b
:
f
"`
{
b
[
'name'
]
}
`"
)
_COL_VERSION
:
TableColumn
=
(
"Version"
,
lambda
b
:
b
.
get
(
"version"
,
""
))
_COL_DTYPES
:
TableColumn
=
(
"Dtypes"
,
lambda
b
:
b
[
"dtypes"
])
_COL_KV_DTYPES
:
TableColumn
=
(
"KV Dtypes"
,
lambda
b
:
add_literal_quotes
(
b
[
"kv_cache_dtypes"
]),
)
_COL_BLOCK_SIZES
:
TableColumn
=
(
"Block Sizes"
,
lambda
b
:
b
[
"block_sizes"
])
_COL_HEAD_SIZES
:
TableColumn
=
(
"Head Sizes"
,
lambda
b
:
b
[
"head_sizes"
])
_COL_SINK
:
TableColumn
=
(
"Sink"
,
lambda
b
:
bool_to_emoji
(
b
[
"supports_sink"
]))
_COL_SPARSE
:
TableColumn
=
(
"Sparse"
,
lambda
b
:
bool_to_emoji
(
b
[
"is_sparse"
]))
_COL_MM_PREFIX
:
TableColumn
=
(
"MM Prefix"
,
lambda
b
:
bool_to_emoji
(
b
[
"supports_mm_prefix"
]),
)
_COL_DCP
:
TableColumn
=
(
"DCP"
,
lambda
b
:
bool_to_emoji
(
b
[
"supports_dcp"
]))
_COL_ATTN_TYPES
:
TableColumn
=
(
"Attention Types"
,
lambda
b
:
b
[
"attn_types"
])
_COL_COMPUTE_CAP
:
TableColumn
=
(
"Compute Cap."
,
lambda
b
:
b
[
"compute_capability"
])
def
add_literal_quotes
(
value
:
str
)
->
str
:
"""Add literal backticks around all comma-separated items in a string."""
items
=
[
item
.
strip
()
for
item
in
value
.
split
(
","
)]
return
", "
.
join
(
f
"`
{
item
}
`"
for
item
in
items
)
def
bool_to_emoji
(
value
:
bool
)
->
str
:
"""Convert a boolean to a checkmark or X emoji."""
return
"✅"
if
value
else
"❌"
def
_build_columns
(
is_mla
:
bool
,
has_versions
:
bool
)
->
list
[
TableColumn
]:
"""Build the column list for a backend feature table.
The column selection depends on whether it's an MLA table (includes
Sparse column) and whether any backend has version variants (includes
Version column).
"""
cols
:
list
[
TableColumn
]
=
[
_COL_BACKEND
]
if
has_versions
:
cols
.
append
(
_COL_VERSION
)
cols
.
extend
([
_COL_DTYPES
,
_COL_KV_DTYPES
,
_COL_BLOCK_SIZES
,
_COL_HEAD_SIZES
])
cols
.
append
(
_COL_SINK
)
if
is_mla
:
cols
.
append
(
_COL_SPARSE
)
cols
.
extend
([
_COL_MM_PREFIX
,
_COL_DCP
,
_COL_ATTN_TYPES
,
_COL_COMPUTE_CAP
])
return
cols
def
_sort_key
(
x
:
dict
[
str
,
Any
])
->
tuple
[
str
,
int
]:
"""Sort key that keeps parent/child rows together in order."""
return
(
x
.
get
(
"_sort_key"
,
x
[
"name"
]),
x
.
get
(
"_sort_order"
,
0
))
def
_render_table
(
columns
:
list
[
TableColumn
],
backends
:
list
[
dict
[
str
,
Any
]],
)
->
list
[
str
]:
"""Render a markdown table from column specs and backend data."""
header
=
"| "
+
" | "
.
join
(
name
for
name
,
_
in
columns
)
+
" |"
sep
=
"|"
+
"|"
.
join
(
"-"
*
(
len
(
name
)
+
2
)
for
name
,
_
in
columns
)
+
"|"
lines
=
[
header
,
sep
]
for
info
in
sorted
(
backends
,
key
=
_sort_key
):
row
=
"| "
+
" | "
.
join
(
fmt
(
info
)
for
_
,
fmt
in
columns
)
+
" |"
lines
.
append
(
row
)
return
lines
def
generate_markdown_table
(
backends
:
list
[
dict
[
str
,
Any
]],
title
:
str
,
is_mla_table
:
bool
=
False
)
->
str
:
"""Generate a titled markdown table from backend info."""
if
not
backends
:
return
f
"##
{
title
}
\n\n
No backends found.
\n
"
has_versions
=
any
(
b
.
get
(
"version"
)
for
b
in
backends
)
columns
=
_build_columns
(
is_mla_table
,
has_versions
)
lines
=
[
f
"##
{
title
}
"
,
""
]
lines
.
extend
(
_render_table
(
columns
,
backends
))
lines
.
append
(
""
)
return
"
\n
"
.
join
(
lines
)
# ---------------------------------------------------------------------------
# Markdown section generators (usage, priority, legend, MLA)
# ---------------------------------------------------------------------------
def
generate_usage_section
()
->
str
:
def
generate_usage_section
()
->
str
:
"""Generate the usage documentation section."""
"""Generate the usage documentation section."""
return
"""## Setting the Attention Backend
return
"""## Setting the Attention Backend
...
@@ -959,6 +1171,27 @@ def generate_priority_section(priorities: dict[str, list[str]]) -> str:
...
@@ -959,6 +1171,27 @@ def generate_priority_section(priorities: dict[str, list[str]]) -> str:
return
"
\n
"
.
join
(
lines
)
return
"
\n
"
.
join
(
lines
)
def
generate_legend
()
->
str
:
"""Generate a legend explaining the table columns."""
return
"""## Legend
| Column | Description |
|--------|-------------|
| **Dtypes** | Supported model data types (fp16, bf16, fp32) |
| **KV Dtypes** | Supported KV cache data types (`auto`, `fp8`, `fp8_e4m3`, etc.) |
| **Block Sizes** | Supported KV cache block sizes (%N means multiples of N) |
| **Head Sizes** | Supported attention head sizes |
| **Sink** | Attention sink support (for StreamingLLM) |
| **Sparse** | Sparse attention support (MLA only) |
| **MM Prefix** | Multimodal prefix full attention support |
| **DCP** | Decode Context Parallelism support (`--decode-context-parallel-size`) |
| **Attention Types** | Supported attention patterns (Decoder, Encoder, Enc-Dec) |
| **Compute Cap.** | Required CUDA compute capability (N/A for non-CUDA backends) |
**Symbols:** ✅ = Supported, ❌ = Not supported
"""
def
generate_mla_section
(
def
generate_mla_section
(
prefill_backends
:
list
[
dict
[
str
,
Any
]],
decode_backends
:
list
[
dict
[
str
,
Any
]]
prefill_backends
:
list
[
dict
[
str
,
Any
]],
decode_backends
:
list
[
dict
[
str
,
Any
]]
)
->
str
:
)
->
str
:
...
@@ -999,57 +1232,17 @@ def generate_mla_section(
...
@@ -999,57 +1232,17 @@ def generate_mla_section(
]
]
)
)
# Generate decode backends table
# Reuse data-driven table rendering for decode backends
header
=
(
columns
=
_build_columns
(
is_mla
=
True
,
has_versions
=
False
)
"| Backend | Dtypes | KV Dtypes | Block Sizes | Head Sizes "
lines
.
extend
(
_render_table
(
columns
,
decode_backends
))
"| Sink | Sparse | MM Prefix | Attention Types | Compute Cap. |"
)
separator
=
(
"|---------|--------|-----------|-------------|------------"
"|------|--------|-----------|-----------------|--------------|"
)
lines
.
extend
([
header
,
separator
])
def
sort_key
(
x
:
dict
[
str
,
Any
])
->
tuple
[
str
,
int
]:
return
(
x
.
get
(
"_sort_key"
,
x
[
"name"
]),
x
.
get
(
"_sort_order"
,
0
))
for
info
in
sorted
(
decode_backends
,
key
=
sort_key
):
row
=
"| `{}` | {} | {} | {} | {} | {} | {} | {} | {} | {} |"
.
format
(
info
[
"name"
],
info
[
"dtypes"
],
add_literal_quotes
(
info
[
"kv_cache_dtypes"
]),
info
[
"block_sizes"
],
info
[
"head_sizes"
],
bool_to_emoji
(
info
[
"supports_sink"
]),
bool_to_emoji
(
info
[
"is_sparse"
]),
bool_to_emoji
(
info
[
"supports_mm_prefix"
]),
info
[
"attn_types"
],
info
[
"compute_capability"
],
)
lines
.
append
(
row
)
lines
.
append
(
""
)
lines
.
append
(
""
)
return
"
\n
"
.
join
(
lines
)
return
"
\n
"
.
join
(
lines
)
def
generate_legend
()
->
str
:
# ---------------------------------------------------------------------------
"""Generate a legend explaining the table columns."""
# Top-level orchestration
return
"""## Legend
# ---------------------------------------------------------------------------
| Column | Description |
|--------|-------------|
| **Dtypes** | Supported model data types (fp16, bf16, fp32) |
| **KV Dtypes** | Supported KV cache data types (`auto`, `fp8`, `fp8_e4m3`, etc.) |
| **Block Sizes** | Supported KV cache block sizes (%N means multiples of N) |
| **Head Sizes** | Supported attention head sizes |
| **Sink** | Attention sink support (for StreamingLLM) |
| **Sparse** | Sparse attention support (MLA only) |
| **MM Prefix** | Multimodal prefix full attention support |
| **Attention Types** | Supported attention patterns (Decoder, Encoder, Enc-Dec) |
| **Compute Cap.** | Required CUDA compute capability (N/A for non-CUDA backends) |
**Symbols:** ✅ = Supported, ❌ = Not supported
"""
def
generate_docs
()
->
str
:
def
generate_docs
()
->
str
:
...
@@ -1071,86 +1264,17 @@ def generate_docs() -> str:
...
@@ -1071,86 +1264,17 @@ def generate_docs() -> str:
# Collect backend info
# Collect backend info
all_backends
=
[]
all_backends
=
[]
for
backend_name
,
class_path
in
attention_backends_map
.
items
():
for
backend_name
,
class_path
in
attention_backends_map
.
items
():
if
backend_name
in
(
"CUSTOM"
,
"TORCH_SDPA"
)
:
if
backend_name
in
SKIP_BACKENDS
:
continue
continue
info
=
analyze_backend
(
backend_name
,
class_path
)
info
=
analyze_backend
(
backend_name
,
class_path
)
if
info
:
if
info
:
all_backends
.
append
(
info
)
all_backends
.
append
(
info
)
# Expand
FLASH_ATTN into FA2 and FA3 variants with different capabilitie
s
# Expand
backends into version variant
s
if
fa_features
:
if
fa_features
:
expanded_backends
=
[]
all_backends
=
_expand_flash_attn_variants
(
all_backends
,
fa_features
)
for
backend
in
all_backends
:
if
backend
[
"name"
]
==
"FLASH_ATTN"
:
# Create FA2 entry (keeps base backend's compute_capability)
fa2
=
backend
.
copy
()
fa2
[
"name"
]
=
"FLASH_ATTN"
fa2
[
"version"
]
=
"FA2*"
fa2
[
"_sort_key"
]
=
"FLASH_ATTN"
fa2
[
"_sort_order"
]
=
0
fa2
[
"supports_sink"
]
=
fa_features
[
"fa2"
][
"supports_sink"
]
# Create FA3 entry (uses parsed compute_capability from fa_utils)
fa3
=
backend
.
copy
()
fa3
[
"name"
]
=
"FLASH_ATTN"
fa3
[
"version"
]
=
"FA3*"
fa3
[
"_sort_key"
]
=
"FLASH_ATTN"
fa3
[
"_sort_order"
]
=
1
if
fa_features
[
"fa3"
][
"compute_capability"
]:
fa3
[
"compute_capability"
]
=
fa_features
[
"fa3"
][
"compute_capability"
]
fa3
[
"supports_sink"
]
=
fa_features
[
"fa3"
][
"supports_sink"
]
if
fa_features
[
"fa3"
][
"supports_fp8"
]:
# Add fp8 dtypes to the base backend's kv_cache_dtypes
base_dtypes
=
backend
[
"kv_cache_dtypes"
].
split
(
", "
)
fp8_dtypes
=
[
"fp8"
,
"fp8_e4m3"
,
"fp8_e5m2"
]
new_dtypes
=
[
d
for
d
in
fp8_dtypes
if
d
not
in
base_dtypes
]
fa3
[
"kv_cache_dtypes"
]
=
", "
.
join
(
base_dtypes
+
new_dtypes
)
# Add FA2 first, then FA3
expanded_backends
.
append
(
fa2
)
expanded_backends
.
append
(
fa3
)
else
:
backend
[
"_sort_key"
]
=
backend
[
"name"
]
backend
[
"_sort_order"
]
=
0
backend
[
"version"
]
=
""
# No version for other backends
expanded_backends
.
append
(
backend
)
all_backends
=
expanded_backends
# Expand FLASHINFER into native and TRTLLM variants
if
fi_features
:
if
fi_features
:
expanded_backends
=
[]
all_backends
=
_expand_flashinfer_variants
(
all_backends
,
fi_features
)
for
backend
in
all_backends
:
if
backend
[
"name"
]
==
"FLASHINFER"
:
# Parse original compute capability to get min CC
orig_cap
=
backend
[
"compute_capability"
]
parts
=
orig_cap
.
replace
(
".x"
,
""
).
split
(
"-"
)
min_cc
=
parts
[
0
]
if
parts
else
"7"
trtllm_cc
=
fi_features
[
"trtllm"
][
"compute_capability"
]
# Create native entry (pre-Blackwell GPUs)
native
=
backend
.
copy
()
native
[
"name"
]
=
"FLASHINFER"
native
[
"version"
]
=
"Native†"
native
[
"_sort_key"
]
=
"FLASHINFER"
native
[
"_sort_order"
]
=
0
native
[
"supports_sink"
]
=
fi_features
[
"native"
][
"supports_sink"
]
# Native FlashInfer is used on GPUs before SM100 (Blackwell)
native
[
"compute_capability"
]
=
f
"
{
min_cc
}
.x-9.x"
# Create TRTLLM entry
trtllm
=
backend
.
copy
()
trtllm
[
"name"
]
=
"FLASHINFER"
trtllm
[
"version"
]
=
"TRTLLM†"
trtllm
[
"_sort_key"
]
=
"FLASHINFER"
trtllm
[
"_sort_order"
]
=
1
trtllm
[
"compute_capability"
]
=
trtllm_cc
trtllm
[
"supports_sink"
]
=
fi_features
[
"trtllm"
][
"supports_sink"
]
expanded_backends
.
append
(
native
)
expanded_backends
.
append
(
trtllm
)
else
:
expanded_backends
.
append
(
backend
)
all_backends
=
expanded_backends
# Split into MLA and non-MLA
# Split into MLA and non-MLA
mla_backends
=
[
b
for
b
in
all_backends
if
b
[
"is_mla"
]]
mla_backends
=
[
b
for
b
in
all_backends
if
b
[
"is_mla"
]]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment