Unverified commit 20d7454c, authored by Rabi Mishra, committed by GitHub
Browse files

fix(ROCm): Make flash_attn import optional in MLA attention (#33511)


Signed-off-by: rabi <ramishra@redhat.com>
parent 5819ca89
...@@ -919,10 +919,20 @@ try: ...@@ -919,10 +919,20 @@ try:
is_vllm_fa = True is_vllm_fa = True
except ImportError: except ImportError:
# For rocm use upstream flash attention is_vllm_fa = False
flash_attn_varlen_func = None # type: ignore[assignment]
# On ROCm, vllm_flash_attn is not available, try upstream flash_attn instead.
# On CUDA, vllm_flash_attn should always be available (built with vLLM),
# so we don't attempt the fallback there.
if current_platform.is_rocm(): if current_platform.is_rocm():
try:
from flash_attn import flash_attn_varlen_func # type: ignore[no-redef] from flash_attn import flash_attn_varlen_func # type: ignore[no-redef]
is_vllm_fa = False except ImportError:
logger.debug(
"flash_attn not available on ROCm; "
"MLA models using TRITON_MLA will require flash_attn. "
"AITER_MLA backends use aiter kernels instead."
)
def dynamic_per_batched_tensor_quant( def dynamic_per_batched_tensor_quant(
...@@ -1917,6 +1927,12 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1917,6 +1927,12 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
self._run_prefill_new_tokens = self._run_prefill_new_tokens_cudnn self._run_prefill_new_tokens = self._run_prefill_new_tokens_cudnn
self._pad_v = False self._pad_v = False
else: # Use FlashAttention else: # Use FlashAttention
if flash_attn_varlen_func is None:
raise RuntimeError(
"MLA attention requires FlashAttention but it is not "
"available. Please install flash_attn or use "
"--attention-backend ROCM_AITER_MLA."
)
logger.info_once("Using FlashAttention prefill for MLA", scope="local") logger.info_once("Using FlashAttention prefill for MLA", scope="local")
self._run_prefill_context_chunk = self._run_prefill_context_chunk_fa self._run_prefill_context_chunk = self._run_prefill_context_chunk_fa
self._run_prefill_new_tokens = self._run_prefill_new_tokens_fa self._run_prefill_new_tokens = self._run_prefill_new_tokens_fa
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment