Unverified Commit 04741337 authored by Matthias Gehre's avatar Matthias Gehre Committed by GitHub
Browse files

[Attention][AMD] Make flash-attn optional (#30361)


Signed-off-by: default avatarMatthias Gehre <matthias.gehre@amd.com>
parent 74e4bb1c
......@@ -23,11 +23,13 @@ elif current_platform.is_xpu():
elif current_platform.is_rocm():
try:
from flash_attn import flash_attn_varlen_func # noqa: F401
except ImportError as e:
raise ImportError(
"Rocm platform requires upstream flash-attn "
"to be installed. Please install flash-attn first."
) from e
except ImportError:
def flash_attn_varlen_func(*args, **kwargs):
raise ImportError(
"ROCm platform requires upstream flash-attn "
"to be installed. Please install flash-attn first."
)
def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment