Unverified commit 97ef11dd, authored by monajafi-amd and committed by GitHub
Browse files

[ROCm][ViT] Enable Flash Attention Triton backend on RDNA3/RDNA4 (#32944)


Signed-off-by: mohammad najafi <mohammad.najafi@amd.com>
parent ecc3dd66
......@@ -163,6 +163,28 @@ def use_rocm_custom_paged_attention(
)
@cache
def flash_attn_triton_available() -> bool:
    """Report whether the Flash Attention Triton AMD backend may be used.

    The answer is ``True`` only when all of the following hold:

    * ``on_gfx1x()`` is truthy;
    * both ``flash_attn`` and ``flash_attn.flash_attn_triton_amd`` can be
      located by the import system;
    * the environment variable ``FLASH_ATTENTION_TRITON_AMD_ENABLE`` is
      exactly ``"TRUE"`` (case-sensitive).

    The result is memoized by ``@cache``, so the probe (and the environment
    lookup) runs at most once per process.

    Returns:
        ``True`` if the backend is installed and explicitly enabled,
        ``False`` otherwise.
    """
    if not on_gfx1x():
        return False

    from importlib.util import find_spec

    required_modules = ("flash_attn", "flash_attn.flash_attn_triton_amd")
    try:
        # Looking up a dotted name imports its parent package, which is why
        # an ImportError can surface here; treat it as "not available".
        # ``all`` short-circuits, so the submodule is only probed when the
        # top-level package was found.
        modules_found = all(
            find_spec(module_name) is not None
            for module_name in required_modules
        )
    except ImportError:
        return False

    if not modules_found:
        return False

    if os.environ.get("FLASH_ATTENTION_TRITON_AMD_ENABLE") == "TRUE":
        return True

    # The package is installed but the opt-in switch is missing: tell the
    # user how to turn the backend on.
    logger.info_once(
        "Set FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE to enable "
        "Flash Attention Triton backend on RDNA."
    )
    return False
class RocmPlatform(Platform):
_enum = PlatformEnum.ROCM
device_name: str = "rocm"
......@@ -348,7 +370,7 @@ class RocmPlatform(Platform):
from vllm._aiter_ops import rocm_aiter_ops
if rocm_aiter_ops.is_enabled():
if rocm_aiter_ops.is_enabled() and on_gfx9():
logger.info_once("Using AITER Flash Attention backend for ViT model.")
return AttentionBackendEnum.ROCM_AITER_FA
......@@ -360,6 +382,17 @@ class RocmPlatform(Platform):
logger.info_once("Using Flash Attention backend for ViT model.")
return AttentionBackendEnum.FLASH_ATTN
# RDNA3/RDNA4 (gfx11xx/gfx12xx): Use Flash Attention Triton backend
if (
on_gfx1x()
and flash_attn_triton_available()
and (dtype == torch.float16 or dtype == torch.bfloat16)
):
logger.info_once(
"Using Flash Attention (Triton backend) for ViT model on RDNA."
)
return AttentionBackendEnum.FLASH_ATTN
logger.info_once("Using Torch SDPA backend for ViT model.")
return AttentionBackendEnum.TORCH_SDPA
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment