Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4c23690f
Unverified
Commit
4c23690f
authored
Nov 18, 2025
by
Matthew Bonanni
Committed by
GitHub
Nov 18, 2025
Browse files
[Attention] FlashAttention ViT support, make default backend (#28763)
Signed-off-by:
Matthew Bonanni
<
mbonanni@redhat.com
>
parent
814843e0
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
15 additions
and
46 deletions
+15
-46
cmake/external_projects/vllm_flash_attn.cmake
cmake/external_projects/vllm_flash_attn.cmake
+1
-1
tests/kernels/attention/test_flash_attn.py
tests/kernels/attention/test_flash_attn.py
+2
-2
tests/kernels/attention/test_mha_attn.py
tests/kernels/attention/test_mha_attn.py
+1
-29
vllm/platforms/cuda.py
vllm/platforms/cuda.py
+9
-12
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+2
-2
No files found.
cmake/external_projects/vllm_flash_attn.cmake
View file @
4c23690f
...
...
@@ -38,7 +38,7 @@ else()
FetchContent_Declare
(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG
58e0626a692f09241182582659e3bf8f16472659
GIT_TAG
71bb26f6295449be880344b93b51791cc009237d
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
BINARY_DIR
${
CMAKE_BINARY_DIR
}
/vllm-flash-attn
...
...
tests/kernels/attention/test_flash_attn.py
View file @
4c23690f
...
...
@@ -13,14 +13,14 @@ from vllm.vllm_flash_attn import (
)
NUM_HEADS
=
[(
4
,
4
),
(
8
,
2
)]
HEAD_SIZES
=
[
128
,
256
]
HEAD_SIZES
=
[
40
,
72
,
80
,
128
,
256
]
BLOCK_SIZES
=
[
16
]
DTYPES
=
[
torch
.
bfloat16
]
QDTYPES
=
[
None
,
torch
.
float8_e4m3fn
]
# one value large enough to test overflow in index calculation.
# one value small enough to test the schema op check
NUM_BLOCKS
=
[
32768
,
2048
]
SOFT_CAPS
=
[
None
,
50.0
]
SOFT_CAPS
=
[
None
]
SLIDING_WINDOWS
=
[
None
,
256
]
...
...
tests/kernels/attention/test_mha_attn.py
View file @
4c23690f
...
...
@@ -62,38 +62,10 @@ def test_mha_attn_platform(device: str):
assert
attn
.
attn_backend
==
AttentionBackendEnum
.
FLASH_ATTN
# Test CUDA with head_size=72 (not divisible by 32)
# - with upstream FA not available
# - should use xformers
with
(
patch
(
"vllm.attention.layer.current_platform"
,
CudaPlatform
()),
patch
(
"vllm.model_executor.models.vision.current_platform"
,
CudaPlatform
()),
patch
(
"vllm.attention.layer.check_upstream_fa_availability"
,
return_value
=
False
,
),
):
attn
=
MultiHeadAttention
(
16
,
72
,
scale
=
1
)
assert
attn
.
attn_backend
==
AttentionBackendEnum
.
XFORMERS
# Test CUDA with head_size=72 (not divisible by 32)
# - with upstream FA available
# - should use upstream FA
# - should use vLLM's FlashAttention
with
(
patch
(
"vllm.attention.layer.current_platform"
,
CudaPlatform
()),
patch
(
"vllm.model_executor.models.vision.current_platform"
,
CudaPlatform
()),
patch
(
"vllm.attention.layer.check_upstream_fa_availability"
,
return_value
=
True
),
patch
.
dict
(
"sys.modules"
,
{
"flash_attn"
:
type
(
"MockFlashAttn"
,
(),
{
"flash_attn_varlen_func"
:
lambda
*
args
,
**
kwargs
:
None
},
)()
},
),
):
attn
=
MultiHeadAttention
(
16
,
72
,
scale
=
1
)
assert
attn
.
attn_backend
==
AttentionBackendEnum
.
FLASH_ATTN
...
...
vllm/platforms/cuda.py
View file @
4c23690f
...
...
@@ -267,25 +267,22 @@ class CudaPlatformBase(Platform):
)
->
"AttentionBackendEnum"
:
from
vllm.attention.backends.registry
import
AttentionBackendEnum
# For Blackwell GPUs, force TORCH_SDPA for now.
# See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501
if
cls
.
has_device_capability
(
100
):
return
AttentionBackendEnum
.
TORCH_SDPA
if
dtype
not
in
(
torch
.
float16
,
torch
.
bfloat16
):
return
AttentionBackendEnum
.
XFORMERS
if
cls
.
has_device_capability
(
80
):
# Try FlashAttention first
try
:
backend_class
=
AttentionBackendEnum
.
FLASH_ATTN
.
get_class
()
if
backend_class
.
supports_head_size
(
head_size
)
and
backend_class
.
supports_dtype
(
dtype
):
return
AttentionBackendEnum
.
FLASH_ATTN
except
ImportError
:
pass
if
cls
.
has_device_capability
(
100
):
# xFormers doesn't support Blackwell, fall back to SDPA
# See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501
return
AttentionBackendEnum
.
TORCH_SDPA
else
:
return
AttentionBackendEnum
.
XFORMERS
else
:
# Fallback for Volta/Turing GPUs or FA not supported
return
AttentionBackendEnum
.
XFORMERS
@
classmethod
def
get_valid_backends
(
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
4c23690f
...
...
@@ -119,8 +119,8 @@ class FlashAttentionBackend(AttentionBackend):
raise
ValueError
(
f
"Unrecognized FP8 dtype:
{
kv_cache_dtype
}
"
)
@
classmethod
def
get_
support
ed
_head_size
s
(
cls
)
->
list
[
int
]
:
return
[
32
,
64
,
96
,
128
,
160
,
192
,
224
,
256
]
def
support
s
_head_size
(
cls
,
head_size
:
int
)
->
bool
:
return
head_size
%
8
==
0
and
head_size
<=
256
@
classmethod
def
supports_kv_cache_dtype
(
cls
,
kv_cache_dtype
:
CacheDType
|
None
)
->
bool
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment