Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e8cc53af
Unverified
Commit
e8cc53af
authored
Jul 14, 2025
by
Cyrus Leung
Committed by
GitHub
Jul 14, 2025
Browse files
[Misc] Log the reason for falling back to FlexAttention (#20699)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
a4851cfe
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
104 additions
and
32 deletions
+104
-32
vllm/attention/selector.py
vllm/attention/selector.py
+40
-9
vllm/platforms/cuda.py
vllm/platforms/cuda.py
+35
-22
vllm/reasoning/hunyuan_a13b_reasoning_parser.py
vllm/reasoning/hunyuan_a13b_reasoning_parser.py
+1
-1
vllm/v1/attention/backends/cpu_attn.py
vllm/v1/attention/backends/cpu_attn.py
+4
-0
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+4
-0
vllm/v1/attention/backends/flashinfer.py
vllm/v1/attention/backends/flashinfer.py
+4
-0
vllm/v1/attention/backends/flex_attention.py
vllm/v1/attention/backends/flex_attention.py
+4
-0
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+4
-0
vllm/v1/attention/backends/rocm_aiter_fa.py
vllm/v1/attention/backends/rocm_aiter_fa.py
+4
-0
vllm/v1/attention/backends/triton_attn.py
vllm/v1/attention/backends/triton_attn.py
+4
-0
No files found.
vllm/attention/selector.py
View file @
e8cc53af
...
...
@@ -3,6 +3,7 @@
import
os
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
functools
import
cache
from
typing
import
Generator
,
Optional
,
Union
...
...
@@ -79,32 +80,62 @@ def get_global_forced_attn_backend() -> Optional[_Backend]:
return
forced_attn_backend
def
supports_head_size
(
@
dataclass
(
frozen
=
True
)
class
_IsSupported
:
can_import
:
bool
head_size
:
bool
dtype
:
bool
def
__bool__
(
self
)
->
bool
:
return
self
.
can_import
and
self
.
head_size
and
self
.
dtype
def
is_attn_backend_supported
(
attn_backend
:
Union
[
str
,
type
[
AttentionBackend
]],
head_size
:
int
,
)
->
bool
:
dtype
:
torch
.
dtype
,
*
,
allow_import_error
:
bool
=
True
,
)
->
_IsSupported
:
if
isinstance
(
attn_backend
,
str
):
try
:
attn_backend
=
resolve_obj_by_qualname
(
attn_backend
)
except
ImportError
:
return
False
if
not
allow_import_error
:
raise
return
_IsSupported
(
can_import
=
False
,
head_size
=
False
,
dtype
=
False
)
assert
isinstance
(
attn_backend
,
type
)
# TODO: Update the interface once V0 is removed
if
get_supported_head_sizes
:
=
getattr
(
attn_backend
,
"get_supported_head_sizes"
,
None
):
return
head_size
in
get_supported_head_sizes
()
if
validate_head_size
:
=
getattr
(
attn_backend
,
"validate_head_size"
,
None
):
is_head_size_supported
=
head_size
in
get_supported_head_sizes
()
elif
validate_head_size
:
=
getattr
(
attn_backend
,
"validate_head_size"
,
None
):
try
:
validate_head_size
(
head_size
)
return
True
is_head_size_supported
=
True
except
Exception
:
return
False
is_head_size_supported
=
False
else
:
raise
NotImplementedError
(
f
"
{
attn_backend
.
__name__
}
does not support "
"head size validation"
)
if
get_supported_dtypes
:
=
getattr
(
attn_backend
,
"get_supported_dtypes"
,
None
):
is_dtype_supported
=
dtype
in
get_supported_dtypes
()
else
:
raise
NotImplementedError
(
f
"
{
attn_backend
.
__name__
}
does not support "
"dtype validation"
)
return
_IsSupported
(
can_import
=
True
,
head_size
=
is_head_size_supported
,
dtype
=
is_dtype_supported
,
)
def
get_attn_backend
(
head_size
:
int
,
...
...
vllm/platforms/cuda.py
View file @
e8cc53af
...
...
@@ -259,45 +259,58 @@ class CudaPlatformBase(Platform):
logger
.
info_once
(
"Using Flash Attention backend on V1 engine."
)
return
FLASH_ATTN_V1
from
vllm.attention.selector
import
supports_head_size
from
vllm.attention.selector
import
is_attn_backend_supported
# Default backends for V1 engine
# FP32 is only supported by FlexAttention
if
dtype
not
in
(
torch
.
float16
,
torch
.
bfloat16
):
logger
.
info_once
(
"Using FlexAttention backend for %s on V1 engine."
,
dtype
,
)
return
FLEX_ATTENTION_V1
# Prefer FlashInfer for Blackwell GPUs if installed
if
cls
.
is_device_capability
(
100
)
and
\
supports_head_size
(
FLASHINFER_V1
,
head_size
):
try
:
import
flashinfer
# noqa: F401
if
cls
.
is_device_capability
(
100
):
if
is_default_backend_supported
:
=
is_attn_backend_supported
(
FLASHINFER_V1
,
head_size
,
dtype
):
from
vllm.v1.attention.backends.utils
import
(
set_kv_cache_layout
)
logger
.
info_once
(
"Using FlashInfer backend with HND KV cache layout on "
"V1 engine by default for Blackwell (SM 10.0) GPUs."
)
set_kv_cache_layout
(
"HND"
)
return
FLASHINFER_V1
except
ImportError
:
logger
.
info_once
(
if
not
is_default_backend_supported
.
can_import
:
logger
.
warning_once
(
"FlashInfer failed to import for V1 engine on "
"Blackwell (SM 10.0) GPUs; it is recommended to "
"install FlashInfer for better performance."
)
pass
# FlashAttention is the default for SM 8.0+ GPUs
if
cls
.
has_device_capability
(
80
)
and
\
supports_head_size
(
FLASH_ATTN_V1
,
head_size
):
logger
.
info_once
(
"Using Flash Attention backend on V1 engine."
)
if
cls
.
has_device_capability
(
80
):
if
is_default_backend_supported
:
=
is_attn_backend_supported
(
FLASH_ATTN_V1
,
head_size
,
dtype
,
allow_import_error
=
False
):
logger
.
info_once
(
"Using Flash Attention backend on "
"V1 engine."
)
return
FLASH_ATTN_V1
# FlexAttention is the default for older GPUs
else
:
logger
.
info_once
(
"Using FlexAttention backend on V1 engine."
)
return
FLEX_ATTENTION_V1
assert
not
is_default_backend_supported
use_flex_attention_reason
=
{}
if
not
is_default_backend_supported
.
head_size
:
use_flex_attention_reason
[
"head_size"
]
=
head_size
if
not
is_default_backend_supported
.
dtype
:
use_flex_attention_reason
[
"dtype"
]
=
dtype
logger
.
info_once
(
"Using FlexAttention backend for %s on V1 engine."
,
", "
.
join
(
f
"
{
k
}
=
{
v
}
"
for
k
,
v
in
use_flex_attention_reason
.
items
()),
)
return
FLEX_ATTENTION_V1
# Backends for V0 engine
if
selected_backend
==
_Backend
.
FLASHINFER
:
logger
.
info
(
"Using FlashInfer backend."
)
...
...
vllm/reasoning/hunyuan_a13b_reasoning_parser.py
View file @
e8cc53af
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
re
from
collections.abc
import
Sequence
from
typing
import
Optional
,
Union
import
regex
as
re
from
transformers
import
PreTrainedTokenizerBase
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
...
...
vllm/v1/attention/backends/cpu_attn.py
View file @
e8cc53af
...
...
@@ -37,6 +37,10 @@ logger = init_logger(__name__)
class
TorchSDPABackend
(
AttentionBackend
):
accept_output_buffer
:
bool
=
False
@
classmethod
def
get_supported_dtypes
(
cls
)
->
list
[
torch
.
dtype
]:
return
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
]
@
classmethod
def
validate_head_size
(
cls
,
head_size
:
int
)
->
None
:
attn_impl
=
_get_paged_attn_impl
()
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
e8cc53af
...
...
@@ -44,6 +44,10 @@ class FlashAttentionBackend(AttentionBackend):
accept_output_buffer
:
bool
=
True
@
classmethod
def
get_supported_dtypes
(
cls
)
->
list
[
torch
.
dtype
]:
return
[
torch
.
float16
,
torch
.
bfloat16
]
@
classmethod
def
get_supported_head_sizes
(
cls
)
->
list
[
int
]:
return
[
32
,
64
,
96
,
128
,
160
,
192
,
224
,
256
]
...
...
vllm/v1/attention/backends/flashinfer.py
View file @
e8cc53af
...
...
@@ -42,6 +42,10 @@ class FlashInferBackend(AttentionBackend):
accept_output_buffer
:
bool
=
True
cached_sm100a_supported
:
Optional
[
bool
]
=
None
@
classmethod
def
get_supported_dtypes
(
cls
)
->
list
[
torch
.
dtype
]:
return
[
torch
.
float16
,
torch
.
bfloat16
]
@
classmethod
def
get_supported_head_sizes
(
cls
)
->
list
[
int
]:
# https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
...
...
vllm/v1/attention/backends/flex_attention.py
View file @
e8cc53af
...
...
@@ -42,6 +42,10 @@ def _offsets_to_doc_ids_tensor(offsets: torch.Tensor) -> torch.Tensor:
class
FlexAttentionBackend
(
AttentionBackend
):
accept_output_buffer
:
bool
=
True
@
classmethod
def
get_supported_dtypes
(
cls
)
->
list
[
torch
.
dtype
]:
return
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
]
@
classmethod
def
validate_head_size
(
cls
,
head_size
:
int
)
->
None
:
return
# FlexAttention supports any head size
...
...
vllm/v1/attention/backends/mla/common.py
View file @
e8cc53af
...
...
@@ -262,6 +262,10 @@ class MLACommonBackend(AttentionBackend):
)
->
tuple
[
int
,
...]:
return
(
num_blocks
,
block_size
,
head_size
)
@
classmethod
def
get_supported_dtypes
(
cls
)
->
list
[
torch
.
dtype
]:
return
[
torch
.
float16
,
torch
.
bfloat16
]
@
classmethod
def
get_supported_head_sizes
(
cls
)
->
list
[
int
]:
return
[
576
]
...
...
vllm/v1/attention/backends/rocm_aiter_fa.py
View file @
e8cc53af
...
...
@@ -314,6 +314,10 @@ class AiterFlashAttentionBackend(AttentionBackend):
accept_output_buffer
:
bool
=
True
@
classmethod
def
get_supported_dtypes
(
cls
)
->
list
[
torch
.
dtype
]:
return
[
torch
.
float16
,
torch
.
bfloat16
]
@
classmethod
def
get_supported_head_sizes
(
cls
)
->
list
[
int
]:
return
[
32
,
64
,
96
,
128
,
160
,
192
,
224
,
256
]
...
...
vllm/v1/attention/backends/triton_attn.py
View file @
e8cc53af
...
...
@@ -190,6 +190,10 @@ class TritonAttentionBackend(AttentionBackend):
accept_output_buffer
:
bool
=
True
@
classmethod
def
get_supported_dtypes
(
cls
)
->
list
[
torch
.
dtype
]:
return
[
torch
.
float16
,
torch
.
bfloat16
]
@
classmethod
def
get_supported_head_sizes
(
cls
)
->
list
[
int
]:
return
[
32
,
64
,
96
,
128
,
160
,
192
,
224
,
256
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment