Unverified commit b1b5e045, authored by Yan Ma, committed by GitHub
Browse files

[XPU] allow TORCH_SDPA/TRITON_ATTN as XPU vit Backend (#35010)


Signed-off-by: Yan Ma <yan.ma@intel.com>
parent 5f68464f
@@ -249,7 +249,14 @@ class MMEncoderAttention(CustomOp):
         cu_seqlens: torch.Tensor | None = None,
         max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
     ) -> torch.Tensor:
-        assert self.is_flash_attn_backend, (
-            "XPU only supports FLASH_ATTN for vision attention."
-        )
-        return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
+        if self.attn_backend == AttentionBackendEnum.FLASH_ATTN:
+            return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
+        elif self.attn_backend == AttentionBackendEnum.TRITON_ATTN:
+            return self._forward_triton(query, key, value, cu_seqlens, max_seqlen)
+        elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
+            return self._forward_sdpa(query, key, value, cu_seqlens)
+        else:
+            raise ValueError(
+                f"Unsupported multi-modal encoder attention backend for XPU: "
+                f"{self.attn_backend}."
+            )
@@ -89,6 +89,7 @@ class XPUPlatform(Platform):
     def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
         return [
             AttentionBackendEnum.FLASH_ATTN,
+            AttentionBackendEnum.TRITON_ATTN,
             AttentionBackendEnum.TORCH_SDPA,
         ]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment