Unverified commit 570c3e1c, authored by Bradley D, committed by GitHub
Browse files

[Bugfix] Honor --mm_encoder_attn_backend when used (#27124)


Co-authored-by: Bradley D <4551889+bradleyhd@users.noreply.github.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
parent 3a4255c7
@@ -93,12 +93,15 @@ def check_upstream_fa_availability(dtype: torch.dtype):
 def maybe_get_vit_flash_attn_backend(
-    attn_backend: _Backend, use_upstream_fa: bool
+    attn_backend: _Backend,
+    use_upstream_fa: bool,
+    attn_backend_override: _Backend | None = None,
 ) -> tuple[_Backend, Callable]:
     if (
         attn_backend != _Backend.FLASH_ATTN
         and attn_backend != _Backend.ROCM_AITER_FA
         and check_upstream_fa_availability(torch.get_default_dtype())
+        and attn_backend_override is None
     ):
         attn_backend = _Backend.FLASH_ATTN
         use_upstream_fa = True
@@ -499,6 +502,7 @@ class MultiHeadAttention(nn.Module):
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
                 use_upstream_fa,
+                attn_backend_override=attn_backend_override,
             )
         )
......
@@ -299,6 +299,7 @@ class DotsVisionAttention(nn.Module):
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
                 self.use_upstream_fa,
+                attn_backend_override=attn_backend_override,
             )
         )
         if self.attn_backend not in {
......
@@ -206,6 +206,7 @@ class Ernie4_5_VisionAttention(nn.Module):
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
                 self.use_upstream_fa,
+                attn_backend_override=attn_backend_override,
             )
         )
......
@@ -296,6 +296,7 @@ class Glm4vVisionAttention(nn.Module):
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
                 self.use_upstream_fa,
+                attn_backend_override=attn_backend_override,
             )
         )
......
@@ -364,6 +364,7 @@ class Qwen2VisionAttention(nn.Module):
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
                 self.use_upstream_fa,
+                attn_backend_override=attn_backend_override,
             )
         )
......
@@ -259,6 +259,7 @@ class Siglip2Attention(nn.Module):
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
                 self.use_upstream_fa,
+                attn_backend_override=attn_backend_override,
             )
         )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment