OpenDAS / vllm_cscc
Commit 20d7454c (Unverified)
Authored Feb 06, 2026 by Rabi Mishra
Committed by GitHub, Feb 06, 2026

fix(ROCm): Make flash_attn import optional in MLA attention (#33511)

Signed-off-by: rabi <ramishra@redhat.com>

Parent: 5819ca89

Changes: 1 changed file with 19 additions and 3 deletions

vllm/model_executor/layers/attention/mla_attention.py (+19 -3) @ 20d7454c
...
@@ -919,10 +919,20 @@ try:
    is_vllm_fa = True
except ImportError:
    # For rocm use upstream flash attention
    is_vllm_fa = False
    flash_attn_varlen_func = None  # type: ignore[assignment]
    # On ROCm, vllm_flash_attn is not available, try upstream flash_attn instead.
    # On CUDA, vllm_flash_attn should always be available (built with vLLM),
    # so we don't attempt the fallback there.
    if current_platform.is_rocm():
        try:
            from flash_attn import flash_attn_varlen_func  # type: ignore[no-redef]
            is_vllm_fa = False
        except ImportError:
            logger.debug(
                "flash_attn not available on ROCm; "
                "MLA models using TRITON_MLA will require flash_attn. "
                "AITER_MLA backends use aiter kernels instead."
            )


def dynamic_per_batched_tensor_quant(
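
The fallback order in this hunk (preferred module first, then an optional second module only on one platform, otherwise `None`) can be sketched in isolation. This is a minimal illustration and not the code from the commit: `optional_attr` and `resolve_varlen_func` are hypothetical helpers, and the module and attribute names are parameters so the sketch runs without vLLM or flash_attn installed. Unlike a plain `from x import y`, the sketch also returns `None` when the module exists but lacks the attribute.

```python
import importlib
import logging
from typing import Any, Optional, Tuple

logger = logging.getLogger(__name__)


def optional_attr(module_name: str, attr: str) -> Optional[Any]:
    """Return module_name.attr, or None if the module cannot be imported."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        # Missing optional dependency: log quietly instead of failing at import time.
        logger.debug("%s not available; continuing without it", module_name)
        return None
    return getattr(module, attr, None)


def resolve_varlen_func(
    is_rocm: bool,
    primary: str = "vllm.vllm_flash_attn",
    fallback: str = "flash_attn",
    attr: str = "flash_attn_varlen_func",
) -> Tuple[Optional[Any], bool]:
    """Return (function_or_None, is_vllm_fa), following the order in the hunk."""
    func = optional_attr(primary, attr)
    if func is not None:
        return func, True
    if is_rocm:
        # Only the ROCm path tries the fallback module.
        return optional_attr(fallback, attr), False
    return None, False
```

Returning `None` rather than raising keeps module import cheap and defers the error to the point where the function is actually needed.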
...
@@ -1917,6 +1927,12 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
            self._run_prefill_new_tokens = self._run_prefill_new_tokens_cudnn
            self._pad_v = False
        else:  # Use FlashAttention
            if flash_attn_varlen_func is None:
                raise RuntimeError(
                    "MLA attention requires FlashAttention but it is not "
                    "available. Please install flash_attn or use "
                    "--attention-backend ROCM_AITER_MLA."
                )
            logger.info_once("Using FlashAttention prefill for MLA", scope="local")
            self._run_prefill_context_chunk = self._run_prefill_context_chunk_fa
            self._run_prefill_new_tokens = self._run_prefill_new_tokens_fa
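
The guard in this hunk turns a missing optional dependency into an explicit error at construction time, and only on the code path that needs it. A minimal standalone sketch of that pattern follows; `PrefillSelector`, `use_cudnn`, and the `_run_*` methods here are hypothetical stand-ins, not the `MLACommonImpl` API.

```python
from typing import Any, Callable, Optional


class PrefillSelector:
    """Pick a prefill implementation, failing fast if its dependency is missing."""

    def __init__(
        self,
        varlen_func: Optional[Callable[[Any], Any]],
        use_cudnn: bool = False,
    ) -> None:
        if use_cudnn:
            # This path does not need the optional function at all.
            self.run_prefill = self._run_cudnn
        else:
            if varlen_func is None:
                # Raise here rather than hitting "'NoneType' is not callable" later.
                raise RuntimeError(
                    "MLA attention requires FlashAttention but it is not available."
                )
            self._varlen_func = varlen_func
            self.run_prefill = self._run_fa

    def _run_cudnn(self, q: Any) -> Any:
        return ("cudnn", q)

    def _run_fa(self, q: Any) -> Any:
        return self._varlen_func(q)
```

Because the check sits inside the `else` branch, configurations that select another backend still construct successfully without the optional package.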
...