Unverified Commit 38658ec6 authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

[Bugfix][MM encoder] Fix ViT attention backend resolving for Turing GPU (#29614)


Signed-off-by: default avatarIsotr0py <mozf@mail2.sysu.edu.cn>
parent a24ea541
...@@ -264,14 +264,15 @@ class CudaPlatformBase(Platform): ...@@ -264,14 +264,15 @@ class CudaPlatformBase(Platform):
cls, head_size: int, dtype: torch.dtype cls, head_size: int, dtype: torch.dtype
) -> "AttentionBackendEnum": ) -> "AttentionBackendEnum":
# Try FlashAttention first # Try FlashAttention first
try: if (cc := cls.get_device_capability()) and cc.major >= 8:
backend_class = AttentionBackendEnum.FLASH_ATTN.get_class() try:
if backend_class.supports_head_size( backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
head_size if backend_class.supports_head_size(
) and backend_class.supports_dtype(dtype): head_size
return AttentionBackendEnum.FLASH_ATTN ) and backend_class.supports_dtype(dtype):
except ImportError: return AttentionBackendEnum.FLASH_ATTN
pass except ImportError:
pass
return AttentionBackendEnum.TORCH_SDPA return AttentionBackendEnum.TORCH_SDPA
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment