Commit b0906d8b (unverified), authored by Isotr0py, committed by GitHub
Browse files

[MM Encoder] Default to use TORCH_SDPA backend for ViT on Volta/Turing GPU (#36472)


Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
parent aaf5fa9a
......@@ -19,6 +19,7 @@ from vllm.model_executor.layers.attention import MMEncoderAttention
from vllm.platforms import current_platform
from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform
from vllm.platforms.interface import DeviceCapability
from vllm.platforms.rocm import RocmPlatform
from vllm.utils.torch_utils import set_default_torch_dtype, set_random_seed
from vllm.v1.attention.backends.registry import AttentionBackendEnum
......@@ -83,6 +84,20 @@ def test_mha_attn_platform(default_vllm_config, device: str):
attn = MMEncoderAttention(16, 72, scale=1)
assert attn.attn_backend == AttentionBackendEnum.TRITON_ATTN
# Test Turing (pre-Ampere, sm_75): FlashAttention requires sm>=80,
# and Triton no longer supports MMA on Turing, so we expect that
# TORCH_SDPA is used for MMEncoderAttention.
with (
patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()),
patch.object(
CudaPlatform,
"get_device_capability",
return_value=DeviceCapability(major=7, minor=5),
),
):
attn = MMEncoderAttention(16, 64, scale=1)
assert attn.attn_backend == AttentionBackendEnum.TORCH_SDPA
def ref_attention(
query: torch.Tensor,
......
......@@ -413,12 +413,20 @@ class CudaPlatformBase(Platform):
@classmethod
def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
    """Return the ViT attention backends for this platform, in preference order.

    The order matters: ``get_vit_attn_backend`` iterates over this list, so
    backends listed earlier are considered first.

    Returns:
        The same four backends in either case; only the relative position
        of ``TRITON_ATTN`` and ``TORCH_SDPA`` depends on the device
        capability.
    """
    # Capability 80 and above: prefer Triton over the PyTorch SDPA fallback.
    # NOTE(review): 80 presumably denotes compute capability 8.0 (Ampere);
    # confirm against `has_device_capability`.
    if cls.has_device_capability(80):
        return [
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.TRITON_ATTN,
            AttentionBackendEnum.TORCH_SDPA,
            AttentionBackendEnum.FLASHINFER,
        ]

    # Older devices (e.g. Volta/Turing): rank TORCH_SDPA ahead of TRITON_ATTN
    # so that SDPA becomes the default ViT backend there.
    return [
        AttentionBackendEnum.FLASH_ATTN,
        AttentionBackendEnum.TORCH_SDPA,
        AttentionBackendEnum.TRITON_ATTN,
        AttentionBackendEnum.FLASHINFER,
    ]
@classmethod
def get_vit_attn_backend(
......@@ -438,7 +446,7 @@ class CudaPlatformBase(Platform):
cc = cls.get_device_capability()
for vit_attn_backend in cls.get_supported_vit_attn_backends():
if vit_attn_backend == AttentionBackendEnum.TORCH_SDPA:
continue
return vit_attn_backend
try:
backend_class = vit_attn_backend.get_class()
is_backend_supported = backend_class.supports_head_size(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment