Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
af09b3f0
Unverified
Commit
af09b3f0
authored
Jun 12, 2025
by
Michael Goin
Committed by
GitHub
Jun 12, 2025
Browse files
[Bugfix][V1] Allow manual FlashAttention for Blackwell (#19492)
Signed-off-by:
mgoin
<
mgoin64@gmail.com
>
parent
4f6c42fa
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
13 additions
and
4 deletions
+13
-4
vllm/platforms/cuda.py
vllm/platforms/cuda.py
+13
-4
No files found.
vllm/platforms/cuda.py
View file @
af09b3f0
...
@@ -226,15 +226,21 @@ class CudaPlatformBase(Platform):
...
@@ -226,15 +226,21 @@ class CudaPlatformBase(Platform):
if
selected_backend
==
_Backend
.
FLASHINFER
:
if
selected_backend
==
_Backend
.
FLASHINFER
:
logger
.
info_once
(
"Using FlashInfer backend on V1 engine."
)
logger
.
info_once
(
"Using FlashInfer backend on V1 engine."
)
return
"vllm.v1.attention.backends.flashinfer.FlashInferBackend"
return
"vllm.v1.attention.backends.flashinfer.FlashInferBackend"
if
selected_backend
==
_Backend
.
FLEX_ATTENTION
:
el
if
selected_backend
==
_Backend
.
FLEX_ATTENTION
:
logger
.
info
(
"Using FlexAttenion backend on V1 engine."
)
logger
.
info
(
"Using FlexAttenion backend on V1 engine."
)
return
"vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"
# noqa: E501
return
"vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"
# noqa: E501
if
selected_backend
==
_Backend
.
TRITON_ATTN_VLLM_V1
:
el
if
selected_backend
==
_Backend
.
TRITON_ATTN_VLLM_V1
:
logger
.
info_once
(
"Using Triton backend on V1 engine."
)
logger
.
info_once
(
"Using Triton backend on V1 engine."
)
return
(
"vllm.v1.attention.backends."
return
(
"vllm.v1.attention.backends."
"triton_attn.TritonAttentionBackend"
)
"triton_attn.TritonAttentionBackend"
)
elif
selected_backend
==
_Backend
.
FLASH_ATTN
:
logger
.
info_once
(
"Using Flash Attention backend on V1 engine."
)
return
(
"vllm.v1.attention.backends."
"flash_attn.FlashAttentionBackend"
)
# Default backends for V1 engine
# Prefer FlashInfer for Blackwell GPUs if installed
if
cls
.
is_device_capability
(
100
):
if
cls
.
is_device_capability
(
100
):
# Prefer FlashInfer for V1 on Blackwell GPUs if installed
try
:
try
:
import
flashinfer
# noqa: F401
import
flashinfer
# noqa: F401
logger
.
info_once
(
logger
.
info_once
(
...
@@ -248,10 +254,13 @@ class CudaPlatformBase(Platform):
...
@@ -248,10 +254,13 @@ class CudaPlatformBase(Platform):
"Blackwell (SM 10.0) GPUs; it is recommended to "
"Blackwell (SM 10.0) GPUs; it is recommended to "
"install FlashInfer for better performance."
)
"install FlashInfer for better performance."
)
pass
pass
if
cls
.
has_device_capability
(
80
):
# FlashAttention is the default for SM 8.0+ GPUs
elif
cls
.
has_device_capability
(
80
):
logger
.
info_once
(
"Using Flash Attention backend on V1 engine."
)
logger
.
info_once
(
"Using Flash Attention backend on V1 engine."
)
return
(
"vllm.v1.attention.backends."
return
(
"vllm.v1.attention.backends."
"flash_attn.FlashAttentionBackend"
)
"flash_attn.FlashAttentionBackend"
)
# Backends for V0 engine
if
selected_backend
==
_Backend
.
FLASHINFER
:
if
selected_backend
==
_Backend
.
FLASHINFER
:
logger
.
info
(
"Using FlashInfer backend."
)
logger
.
info
(
"Using FlashInfer backend."
)
return
"vllm.attention.backends.flashinfer.FlashInferBackend"
return
"vllm.attention.backends.flashinfer.FlashInferBackend"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment