Unverified Commit b0906d8b authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

[MM Encoder] Default to use TORCH_SDPA backend for ViT on Volta/Turing GPU (#36472)


Signed-off-by: default avatarIsotr0py <mozf@mail2.sysu.edu.cn>
parent aaf5fa9a
...@@ -19,6 +19,7 @@ from vllm.model_executor.layers.attention import MMEncoderAttention ...@@ -19,6 +19,7 @@ from vllm.model_executor.layers.attention import MMEncoderAttention
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.cpu import CpuPlatform from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform from vllm.platforms.cuda import CudaPlatform
from vllm.platforms.interface import DeviceCapability
from vllm.platforms.rocm import RocmPlatform from vllm.platforms.rocm import RocmPlatform
from vllm.utils.torch_utils import set_default_torch_dtype, set_random_seed from vllm.utils.torch_utils import set_default_torch_dtype, set_random_seed
from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.attention.backends.registry import AttentionBackendEnum
...@@ -83,6 +84,20 @@ def test_mha_attn_platform(default_vllm_config, device: str): ...@@ -83,6 +84,20 @@ def test_mha_attn_platform(default_vllm_config, device: str):
attn = MMEncoderAttention(16, 72, scale=1) attn = MMEncoderAttention(16, 72, scale=1)
assert attn.attn_backend == AttentionBackendEnum.TRITON_ATTN assert attn.attn_backend == AttentionBackendEnum.TRITON_ATTN
# Test Turing (pre-Ampere, sm_75): FlashAttention requires sm>=80,
# and Triton no longer supports MMA on Turing, so we expect that
# TORCH_SDPA is used for MMEncoderAttention.
with (
patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()),
patch.object(
CudaPlatform,
"get_device_capability",
return_value=DeviceCapability(major=7, minor=5),
),
):
attn = MMEncoderAttention(16, 64, scale=1)
assert attn.attn_backend == AttentionBackendEnum.TORCH_SDPA
def ref_attention( def ref_attention(
query: torch.Tensor, query: torch.Tensor,
......
...@@ -413,12 +413,20 @@ class CudaPlatformBase(Platform): ...@@ -413,12 +413,20 @@ class CudaPlatformBase(Platform):
@classmethod @classmethod
def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]: def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
if cls.has_device_capability(80):
return [ return [
AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TRITON_ATTN, AttentionBackendEnum.TRITON_ATTN,
AttentionBackendEnum.TORCH_SDPA, AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.FLASHINFER, AttentionBackendEnum.FLASHINFER,
] ]
else:
return [
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.TRITON_ATTN,
AttentionBackendEnum.FLASHINFER,
]
@classmethod @classmethod
def get_vit_attn_backend( def get_vit_attn_backend(
...@@ -438,7 +446,7 @@ class CudaPlatformBase(Platform): ...@@ -438,7 +446,7 @@ class CudaPlatformBase(Platform):
cc = cls.get_device_capability() cc = cls.get_device_capability()
for vit_attn_backend in cls.get_supported_vit_attn_backends(): for vit_attn_backend in cls.get_supported_vit_attn_backends():
if vit_attn_backend == AttentionBackendEnum.TORCH_SDPA: if vit_attn_backend == AttentionBackendEnum.TORCH_SDPA:
continue return vit_attn_backend
try: try:
backend_class = vit_attn_backend.get_class() backend_class = vit_attn_backend.get_class()
is_backend_supported = backend_class.supports_head_size( is_backend_supported = backend_class.supports_head_size(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment