Unverified Commit e087fbc3 authored by Nicolò Lucchesi's avatar Nicolò Lucchesi Committed by GitHub
Browse files

[MM] Pass FA version in ViT Attn (#30756)


Signed-off-by: default avatarNickLucche <nlucches@redhat.com>
Co-authored-by: default avatarCyrus Leung <tlleungac@connect.ust.hk>
parent e80455ca
...@@ -10,6 +10,7 @@ from vllm.attention.ops.vit_attn_wrappers import ( ...@@ -10,6 +10,7 @@ from vllm.attention.ops.vit_attn_wrappers import (
vit_flash_attn_wrapper, vit_flash_attn_wrapper,
vit_torch_sdpa_wrapper, vit_torch_sdpa_wrapper,
) )
from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.config import MultiModalConfig from vllm.config import MultiModalConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
...@@ -101,6 +102,10 @@ class MMEncoderAttention(CustomOp): ...@@ -101,6 +102,10 @@ class MMEncoderAttention(CustomOp):
self.attn_backend, self.attn_backend,
) )
if self.is_flash_attn_backend:
assert self.flash_attn_varlen_func is not None
self._fa_version = get_flash_attn_version()
logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.") logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.")
@classmethod @classmethod
...@@ -204,6 +209,7 @@ class MMEncoderAttention(CustomOp): ...@@ -204,6 +209,7 @@ class MMEncoderAttention(CustomOp):
max_seqlen=max_seqlen, max_seqlen=max_seqlen,
batch_size=bsz, batch_size=bsz,
is_rocm_aiter=(self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA), is_rocm_aiter=(self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA),
fa_version=self._fa_version,
) )
return output return output
......
...@@ -28,11 +28,15 @@ def flash_attn_maxseqlen_wrapper( ...@@ -28,11 +28,15 @@ def flash_attn_maxseqlen_wrapper(
max_seqlen: torch.Tensor, max_seqlen: torch.Tensor,
batch_size: int, batch_size: int,
is_rocm_aiter: bool, is_rocm_aiter: bool,
fa_version: int,
) -> torch.Tensor: ) -> torch.Tensor:
kwargs = {}
if is_rocm_aiter: if is_rocm_aiter:
from aiter import flash_attn_varlen_func from aiter import flash_attn_varlen_func
else: else:
from vllm.attention.utils.fa_utils import flash_attn_varlen_func from vllm.attention.utils.fa_utils import flash_attn_varlen_func
kwargs["fa_version"] = fa_version
q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) q, k, v = (einops.rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
output = flash_attn_varlen_func( output = flash_attn_varlen_func(
q, q,
...@@ -44,6 +48,7 @@ def flash_attn_maxseqlen_wrapper( ...@@ -44,6 +48,7 @@ def flash_attn_maxseqlen_wrapper(
max_seqlen_k=max_seqlen.item(), max_seqlen_k=max_seqlen.item(),
dropout_p=0.0, dropout_p=0.0,
causal=False, causal=False,
**kwargs,
) )
context_layer = einops.rearrange(output, "(b s) h d -> b s h d", b=batch_size) context_layer = einops.rearrange(output, "(b s) h d -> b s h d", b=batch_size)
return context_layer return context_layer
...@@ -57,6 +62,7 @@ def flash_attn_maxseqlen_wrapper_fake( ...@@ -57,6 +62,7 @@ def flash_attn_maxseqlen_wrapper_fake(
max_seqlen: torch.Tensor, max_seqlen: torch.Tensor,
batch_size: int, batch_size: int,
is_rocm_aiter: bool, is_rocm_aiter: bool,
fa_version: int,
) -> torch.Tensor: ) -> torch.Tensor:
return torch.empty_like(q) return torch.empty_like(q)
...@@ -76,9 +82,10 @@ def vit_flash_attn_wrapper( ...@@ -76,9 +82,10 @@ def vit_flash_attn_wrapper(
max_seqlen: torch.Tensor, max_seqlen: torch.Tensor,
batch_size: int, batch_size: int,
is_rocm_aiter: bool, is_rocm_aiter: bool,
fa_version: int,
) -> torch.Tensor: ) -> torch.Tensor:
return torch.ops.vllm.flash_attn_maxseqlen_wrapper( return torch.ops.vllm.flash_attn_maxseqlen_wrapper(
q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter, fa_version
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment