Unverified commit e806178d, authored by Zhewen Li, committed by GitHub
Browse files

[BugFix][VL] Fix FA selection on Qwen2.5-VL (#27790)


Signed-off-by: zhewenli <zhewenli@meta.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
parent 5be1bed7
...@@ -318,7 +318,7 @@ steps: ...@@ -318,7 +318,7 @@ steps:
- label: V1 Test entrypoints # 35min - label: V1 Test entrypoints # 35min
timeout_in_minutes: 50 timeout_in_minutes: 50
mirror_hardwares: [amdexperimental] mirror_hardwares: [amdexperimental, amdproduction]
agent_pool: mi325_1 agent_pool: mi325_1
# grade: Blocking # grade: Blocking
source_file_dependencies: source_file_dependencies:
......
...@@ -43,10 +43,7 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( ...@@ -43,10 +43,7 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
) )
from vllm.attention.backends.registry import _Backend from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import ( from vllm.attention.layer import maybe_get_vit_flash_attn_backend
check_upstream_fa_availability,
maybe_get_vit_flash_attn_backend,
)
from vllm.attention.ops.vit_attn_wrappers import ( from vllm.attention.ops.vit_attn_wrappers import (
vit_flash_attn_wrapper, vit_flash_attn_wrapper,
vit_xformers_attn_wrapper, vit_xformers_attn_wrapper,
...@@ -318,6 +315,7 @@ class Qwen2_5_VisionAttention(nn.Module): ...@@ -318,6 +315,7 @@ class Qwen2_5_VisionAttention(nn.Module):
use_data_parallel: bool = False, use_data_parallel: bool = False,
attn_backend: _Backend = _Backend.TORCH_SDPA, attn_backend: _Backend = _Backend.TORCH_SDPA,
use_upstream_fa: bool = False, use_upstream_fa: bool = False,
attn_backend_override: _Backend | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
# Per attention head and per partition values. # Per attention head and per partition values.
...@@ -358,8 +356,14 @@ class Qwen2_5_VisionAttention(nn.Module): ...@@ -358,8 +356,14 @@ class Qwen2_5_VisionAttention(nn.Module):
maybe_get_vit_flash_attn_backend( maybe_get_vit_flash_attn_backend(
self.attn_backend, self.attn_backend,
self.use_upstream_fa, self.use_upstream_fa,
attn_backend_override=attn_backend_override,
) )
) )
# On ROCm with FLASH_ATTN backend, upstream flash_attn is used
from vllm.platforms import current_platform
if current_platform.is_rocm() and self.attn_backend == _Backend.FLASH_ATTN:
self.use_upstream_fa = True
self.is_flash_attn_backend = self.attn_backend in { self.is_flash_attn_backend = self.attn_backend in {
_Backend.FLASH_ATTN, _Backend.FLASH_ATTN,
_Backend.ROCM_AITER_FA, _Backend.ROCM_AITER_FA,
...@@ -484,6 +488,7 @@ class Qwen2_5_VisionBlock(nn.Module): ...@@ -484,6 +488,7 @@ class Qwen2_5_VisionBlock(nn.Module):
use_data_parallel: bool = False, use_data_parallel: bool = False,
attn_backend: _Backend = _Backend.TORCH_SDPA, attn_backend: _Backend = _Backend.TORCH_SDPA,
use_upstream_fa: bool = False, use_upstream_fa: bool = False,
attn_backend_override: _Backend | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
if norm_layer is None: if norm_layer is None:
...@@ -499,6 +504,7 @@ class Qwen2_5_VisionBlock(nn.Module): ...@@ -499,6 +504,7 @@ class Qwen2_5_VisionBlock(nn.Module):
use_data_parallel=use_data_parallel, use_data_parallel=use_data_parallel,
attn_backend=attn_backend, attn_backend=attn_backend,
use_upstream_fa=use_upstream_fa, use_upstream_fa=use_upstream_fa,
attn_backend_override=attn_backend_override,
) )
self.mlp = Qwen2_5_VisionMLP( self.mlp = Qwen2_5_VisionMLP(
dim, dim,
...@@ -698,13 +704,14 @@ class Qwen2_5_VisionTransformer(nn.Module): ...@@ -698,13 +704,14 @@ class Qwen2_5_VisionTransformer(nn.Module):
dtype=torch.get_default_dtype(), dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override, attn_backend_override=attn_backend_override,
) )
if (
self.attn_backend != _Backend.FLASH_ATTN self.attn_backend, self.flash_attn_varlen_func = (
and self.attn_backend != _Backend.ROCM_AITER_FA maybe_get_vit_flash_attn_backend(
and check_upstream_fa_availability(torch.get_default_dtype()) self.attn_backend,
): use_upstream_fa,
self.attn_backend = _Backend.FLASH_ATTN attn_backend_override=attn_backend_override,
use_upstream_fa = True )
)
if self.attn_backend not in { if self.attn_backend not in {
_Backend.FLASH_ATTN, _Backend.FLASH_ATTN,
...@@ -730,6 +737,7 @@ class Qwen2_5_VisionTransformer(nn.Module): ...@@ -730,6 +737,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
use_data_parallel=use_data_parallel, use_data_parallel=use_data_parallel,
attn_backend=self.attn_backend, attn_backend=self.attn_backend,
use_upstream_fa=use_upstream_fa, use_upstream_fa=use_upstream_fa,
attn_backend_override=attn_backend_override,
) )
for layer_idx in range(depth) for layer_idx in range(depth)
] ]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment