Unverified Commit 71cd8926 authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

[MM Encoder] Add Triton ViT attention backend (#32183)


Signed-off-by: default avatarIsotr0py <mozf@mail2.sysu.edu.cn>
parent 19fab441
...@@ -17,7 +17,7 @@ from vllm.platforms import current_platform ...@@ -17,7 +17,7 @@ from vllm.platforms import current_platform
from vllm.platforms.cpu import CpuPlatform from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform from vllm.platforms.cuda import CudaPlatform
from vllm.platforms.rocm import RocmPlatform from vllm.platforms.rocm import RocmPlatform
from vllm.utils.torch_utils import set_random_seed from vllm.utils.torch_utils import set_default_torch_dtype, set_random_seed
from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.selector import _cached_get_attn_backend from vllm.v1.attention.selector import _cached_get_attn_backend
...@@ -71,6 +71,15 @@ def test_mha_attn_platform(default_vllm_config, device: str): ...@@ -71,6 +71,15 @@ def test_mha_attn_platform(default_vllm_config, device: str):
attn = MMEncoderAttention(16, 72, scale=1) attn = MMEncoderAttention(16, 72, scale=1)
assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN assert attn.attn_backend == AttentionBackendEnum.FLASH_ATTN
# Test CUDA with head_size=72 (not divisible by 32)
# - should use vLLM's FlashAttention
with (
patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()),
set_default_torch_dtype(torch.float32),
):
attn = MMEncoderAttention(16, 72, scale=1)
assert attn.attn_backend == AttentionBackendEnum.TRITON_ATTN
def ref_attention( def ref_attention(
query: torch.Tensor, query: torch.Tensor,
...@@ -153,7 +162,12 @@ def test_mha_attn_forward( ...@@ -153,7 +162,12 @@ def test_mha_attn_forward(
v, v,
scale=scale, scale=scale,
).reshape(batch_size, seq_len, num_heads * head_size) ).reshape(batch_size, seq_len, num_heads * head_size)
torch.testing.assert_close(output, ref_output) tol_kwargs = (
dict(rtol=1e-3, atol=1e-3)
if attn.attn_backend == AttentionBackendEnum.TRITON_ATTN
else {}
)
torch.testing.assert_close(output, ref_output, **tol_kwargs)
@pytest.mark.parametrize("var_seq_len", VAR_SEQ_LENS) @pytest.mark.parametrize("var_seq_len", VAR_SEQ_LENS)
......
...@@ -12,6 +12,7 @@ from vllm.v1.attention.backends.registry import AttentionBackendEnum ...@@ -12,6 +12,7 @@ from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.ops.vit_attn_wrappers import ( from vllm.v1.attention.ops.vit_attn_wrappers import (
vit_flash_attn_wrapper, vit_flash_attn_wrapper,
vit_torch_sdpa_wrapper, vit_torch_sdpa_wrapper,
vit_triton_attn_wrapper,
) )
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -165,6 +166,41 @@ class MMEncoderAttention(CustomOp): ...@@ -165,6 +166,41 @@ class MMEncoderAttention(CustomOp):
output = output.reshape(bsz, q_len, -1) output = output.reshape(bsz, q_len, -1)
return output return output
def _forward_triton(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor:
"""Input shape:
(batch_size x seq_len x hidden_size) or
(batch_size x seq_len x num_heads x head_size)
"""
assert (cu_seqlens is not None and max_seqlen is not None) or (
cu_seqlens is None and max_seqlen is None
), "cu_seqlens and max_seqlen should be both set or both None."
bsz, q_len = query.size()[:2]
kv_len = key.size(1)
is_reshaped = query.dim() != 4
query, key, value = self.view_qkv_to_4d(query, key, value, bsz, q_len, kv_len)
output = vit_triton_attn_wrapper(
q=query,
k=key,
v=value,
batch_size=bsz,
scale=self.scale,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
if is_reshaped:
output = output.reshape(bsz, q_len, -1)
return output
def forward_native( def forward_native(
self, self,
query: torch.Tensor, query: torch.Tensor,
...@@ -185,6 +221,8 @@ class MMEncoderAttention(CustomOp): ...@@ -185,6 +221,8 @@ class MMEncoderAttention(CustomOp):
) -> torch.Tensor: ) -> torch.Tensor:
if self.is_flash_attn_backend: if self.is_flash_attn_backend:
return self._forward_fa(query, key, value, cu_seqlens, max_seqlen) return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
elif self.attn_backend == AttentionBackendEnum.TRITON_ATTN:
return self._forward_triton(query, key, value, cu_seqlens, max_seqlen)
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
return self._forward_sdpa(query, key, value, cu_seqlens) return self._forward_sdpa(query, key, value, cu_seqlens)
else: else:
......
...@@ -573,10 +573,11 @@ class DotsVisionTransformer(nn.Module): ...@@ -573,10 +573,11 @@ class DotsVisionTransformer(nn.Module):
def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> int | None: def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> int | None:
max_seqlen = None max_seqlen = None
if ( if self.attn_backend in {
self.attn_backend == AttentionBackendEnum.FLASH_ATTN AttentionBackendEnum.FLASH_ATTN,
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA AttentionBackendEnum.ROCM_AITER_FA,
): AttentionBackendEnum.TRITON_ATTN,
}:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
return max_seqlen return max_seqlen
......
...@@ -446,10 +446,11 @@ class Ernie4_5_VisionTransformer(nn.Module): ...@@ -446,10 +446,11 @@ class Ernie4_5_VisionTransformer(nn.Module):
def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> torch.Tensor | None: def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> torch.Tensor | None:
max_seqlen = None max_seqlen = None
if ( if self.attn_backend in {
self.attn_backend == AttentionBackendEnum.FLASH_ATTN AttentionBackendEnum.FLASH_ATTN,
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA AttentionBackendEnum.ROCM_AITER_FA,
): AttentionBackendEnum.TRITON_ATTN,
}:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
return max_seqlen return max_seqlen
......
...@@ -723,10 +723,11 @@ class Glm4vVisionTransformer(nn.Module): ...@@ -723,10 +723,11 @@ class Glm4vVisionTransformer(nn.Module):
cu_seqlens: torch.Tensor, cu_seqlens: torch.Tensor,
) -> torch.Tensor | None: ) -> torch.Tensor | None:
max_seqlen = None max_seqlen = None
if ( if self.attn_backend in {
self.attn_backend == AttentionBackendEnum.FLASH_ATTN AttentionBackendEnum.FLASH_ATTN,
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA AttentionBackendEnum.ROCM_AITER_FA,
): AttentionBackendEnum.TRITON_ATTN,
}:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
return max_seqlen return max_seqlen
......
...@@ -730,14 +730,7 @@ class SiglipEncoder(nn.Module): ...@@ -730,14 +730,7 @@ class SiglipEncoder(nn.Module):
head_size=head_dim, head_size=head_dim,
dtype=torch.get_default_dtype(), dtype=torch.get_default_dtype(),
) )
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
f"PaddleOCR-VL does not support {self.attn_backend} backend now."
)
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[ [
SiglipEncoderLayer( SiglipEncoderLayer(
...@@ -805,6 +798,7 @@ class SiglipEncoder(nn.Module): ...@@ -805,6 +798,7 @@ class SiglipEncoder(nn.Module):
if self.attn_backend in { if self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
AttentionBackendEnum.TRITON_ATTN,
}: }:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
......
...@@ -607,15 +607,6 @@ class Qwen2_5_VisionTransformer(nn.Module): ...@@ -607,15 +607,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
dtype=torch.get_default_dtype(), dtype=torch.get_default_dtype(),
) )
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
f"Qwen2.5-VL does not support {self.attn_backend} backend now."
)
with set_model_tag("Qwen2_5_VisionBlock", is_encoder=True): with set_model_tag("Qwen2_5_VisionBlock", is_encoder=True):
self.blocks = nn.ModuleList( self.blocks = nn.ModuleList(
[ [
...@@ -761,6 +752,7 @@ class Qwen2_5_VisionTransformer(nn.Module): ...@@ -761,6 +752,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
if self.attn_backend in { if self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
AttentionBackendEnum.TRITON_ATTN,
}: }:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
return max_seqlen return max_seqlen
......
...@@ -642,6 +642,7 @@ class Qwen2VisionTransformer(nn.Module): ...@@ -642,6 +642,7 @@ class Qwen2VisionTransformer(nn.Module):
if self.attn_backend in { if self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
AttentionBackendEnum.TRITON_ATTN,
}: }:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
return max_seqlen return max_seqlen
......
...@@ -391,6 +391,7 @@ class Qwen3OmniMoeAudioEncoder(nn.Module): ...@@ -391,6 +391,7 @@ class Qwen3OmniMoeAudioEncoder(nn.Module):
if self.attn_backend in { if self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
AttentionBackendEnum.TRITON_ATTN,
}: }:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
return max_seqlen return max_seqlen
...@@ -919,6 +920,7 @@ class Qwen3Omni_VisionTransformer(nn.Module): ...@@ -919,6 +920,7 @@ class Qwen3Omni_VisionTransformer(nn.Module):
if self.attn_backend in { if self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
AttentionBackendEnum.TRITON_ATTN,
}: }:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
return max_seqlen return max_seqlen
......
...@@ -385,14 +385,6 @@ class Qwen3_VisionTransformer(nn.Module): ...@@ -385,14 +385,6 @@ class Qwen3_VisionTransformer(nn.Module):
dtype=torch.get_default_dtype(), dtype=torch.get_default_dtype(),
) )
if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.ROCM_AITER_FA,
}:
raise RuntimeError(
f"Qwen3-VL does not support {self.attn_backend} backend now."
)
self.blocks = nn.ModuleList( self.blocks = nn.ModuleList(
[ [
Qwen3_VisionBlock( Qwen3_VisionBlock(
...@@ -526,9 +518,10 @@ class Qwen3_VisionTransformer(nn.Module): ...@@ -526,9 +518,10 @@ class Qwen3_VisionTransformer(nn.Module):
cu_seqlens: torch.Tensor, cu_seqlens: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
max_seqlen = torch.zeros([], device=cu_seqlens.device) max_seqlen = torch.zeros([], device=cu_seqlens.device)
if ( if self.attn_backend in (
self.attn_backend == AttentionBackendEnum.FLASH_ATTN AttentionBackendEnum.FLASH_ATTN,
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA AttentionBackendEnum.ROCM_AITER_FA,
AttentionBackendEnum.TRITON_ATTN,
): ):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
return max_seqlen return max_seqlen
......
...@@ -108,7 +108,7 @@ def get_vit_attn_backend( ...@@ -108,7 +108,7 @@ def get_vit_attn_backend(
multimodal_config: MultiModalConfig | None = ( multimodal_config: MultiModalConfig | None = (
model_config.multimodal_config if model_config is not None else None model_config.multimodal_config if model_config is not None else None
) )
except AssertionError: except (AssertionError, AttributeError):
multimodal_config = None multimodal_config = None
attn_backend_override = ( attn_backend_override = (
...@@ -134,7 +134,7 @@ def is_vit_use_data_parallel(): ...@@ -134,7 +134,7 @@ def is_vit_use_data_parallel():
multimodal_config: MultiModalConfig | None = ( multimodal_config: MultiModalConfig | None = (
model_config.multimodal_config if model_config is not None else None model_config.multimodal_config if model_config is not None else None
) )
except AssertionError: except (AssertionError, AttributeError):
multimodal_config = None multimodal_config = None
mm_encoder_tp_mode = ( mm_encoder_tp_mode = (
......
...@@ -411,8 +411,9 @@ class CudaPlatformBase(Platform): ...@@ -411,8 +411,9 @@ class CudaPlatformBase(Platform):
@classmethod @classmethod
def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]: def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
return [ return [
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TRITON_ATTN,
AttentionBackendEnum.TORCH_SDPA,
] ]
@classmethod @classmethod
...@@ -430,14 +431,25 @@ class CudaPlatformBase(Platform): ...@@ -430,14 +431,25 @@ class CudaPlatformBase(Platform):
logger.info_once(f"Using backend {backend} for vit attention") logger.info_once(f"Using backend {backend} for vit attention")
return backend return backend
# Try FlashAttention first cc = cls.get_device_capability()
if (cc := cls.get_device_capability()) and cc.major >= 8: for vit_attn_backend in cls.get_supported_vit_attn_backends():
if vit_attn_backend == AttentionBackendEnum.TORCH_SDPA:
continue
try: try:
backend_class = AttentionBackendEnum.FLASH_ATTN.get_class() backend_class = vit_attn_backend.get_class()
if backend_class.supports_head_size( is_backend_supported = backend_class.supports_head_size(
head_size head_size
) and backend_class.supports_dtype(dtype): ) and backend_class.supports_dtype(dtype)
return AttentionBackendEnum.FLASH_ATTN if cc is not None:
is_backend_supported = (
is_backend_supported
and backend_class.supports_compute_capability(cc)
)
if is_backend_supported:
logger.info_once(
f"Using backend {vit_attn_backend} for vit attention"
)
return vit_attn_backend
except ImportError: except ImportError:
pass pass
......
...@@ -384,6 +384,7 @@ class RocmPlatform(Platform): ...@@ -384,6 +384,7 @@ class RocmPlatform(Platform):
return [ return [
AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA, AttentionBackendEnum.ROCM_AITER_FA,
AttentionBackendEnum.TRITON_ATTN,
AttentionBackendEnum.TORCH_SDPA, AttentionBackendEnum.TORCH_SDPA,
] ]
......
...@@ -110,6 +110,83 @@ def vit_flash_attn_wrapper( ...@@ -110,6 +110,83 @@ def vit_flash_attn_wrapper(
) )
def triton_attn_wrapper(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
batch_size: int,
scale: float | None = None,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None,
) -> torch.Tensor:
from vllm.v1.attention.ops.triton_prefill_attention import context_attention_fwd
q_len = q.size(1)
if cu_seqlens is None:
cu_seqlens = torch.arange(
0, (batch_size + 1) * q_len, step=q_len, dtype=torch.int32, device=q.device
)
max_seqlen = q_len if max_seqlen is None else max_seqlen.item()
q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
output = torch.empty_like(q)
context_attention_fwd(
q,
k,
v,
output,
b_start_loc=cu_seqlens[:-1],
b_seq_len=cu_seqlens[1:] - cu_seqlens[:-1],
max_input_len=max_seqlen,
is_causal=False,
sliding_window_q=None,
sliding_window_k=None,
softmax_scale=scale,
)
context_layer = einops.rearrange(output, "(b s) h d -> b s h d", b=batch_size)
return context_layer
def triton_attn_wrapper_fake(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
batch_size: int,
scale: float | None = None,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None,
) -> torch.Tensor:
return torch.empty_like(q)
direct_register_custom_op(
op_name="triton_attn_wrapper",
op_func=triton_attn_wrapper,
fake_impl=triton_attn_wrapper_fake,
)
def vit_triton_attn_wrapper(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
batch_size: int,
scale: float | None = None,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None,
) -> torch.Tensor:
return torch.ops.vllm.triton_attn_wrapper(
q,
k,
v,
batch_size,
scale,
cu_seqlens,
max_seqlen,
)
def apply_sdpa( def apply_sdpa(
q: torch.Tensor, q: torch.Tensor,
k: torch.Tensor, k: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment