Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9c5ee91b
Unverified
Commit
9c5ee91b
authored
Oct 02, 2025
by
TJian
Committed by
GitHub
Oct 02, 2025
Browse files
[ROCm] [VL] [Bugfix] Fix vit flash attn dispatcher logic for ROCm (#26104)
Signed-off-by:
tjtanaa
<
tunjian.tan@embeddedllm.com
>
parent
27edd2ae
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
154 additions
and
141 deletions
+154
-141
vllm/attention/layer.py
vllm/attention/layer.py
+49
-23
vllm/model_executor/models/dots_ocr.py
vllm/model_executor/models/dots_ocr.py
+19
-22
vllm/model_executor/models/ernie45_vl.py
vllm/model_executor/models/ernie45_vl.py
+23
-26
vllm/model_executor/models/glm4_1v.py
vllm/model_executor/models/glm4_1v.py
+17
-14
vllm/model_executor/models/qwen2_5_vl.py
vllm/model_executor/models/qwen2_5_vl.py
+17
-17
vllm/model_executor/models/qwen2_vl.py
vllm/model_executor/models/qwen2_vl.py
+18
-22
vllm/model_executor/models/qwen3_vl.py
vllm/model_executor/models/qwen3_vl.py
+3
-1
vllm/model_executor/models/siglip2navit.py
vllm/model_executor/models/siglip2navit.py
+8
-14
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+0
-2
No files found.
vllm/attention/layer.py
View file @
9c5ee91b
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer."""
from
typing
import
List
,
Optional
from
typing
import
Callable
,
List
,
Optional
import
torch
import
torch.nn
as
nn
...
...
@@ -68,9 +68,39 @@ def check_upstream_fa_availability(dtype: torch.dtype):
)
and
current_platform
.
has_device_capability
(
80
):
from
transformers.utils
import
is_flash_attn_2_available
return
is_flash_attn_2_available
()
if
current_platform
.
is_rocm
():
from
importlib.util
import
find_spec
return
find_spec
(
"flash_attn"
)
is
not
None
return
False
def maybe_get_vit_flash_attn_backend(
        attn_backend: _Backend,
        use_upstream_fa: bool) -> tuple[_Backend, Optional[Callable]]:
    """Resolve the ViT attention backend and its varlen flash-attn kernel.

    Args:
        attn_backend: Backend initially selected for the vision attention.
        use_upstream_fa: Whether the caller already decided to use the
            upstream ``flash_attn`` package instead of
            ``vllm.vllm_flash_attn``.

    Returns:
        A ``(backend, flash_attn_varlen_func)`` tuple. ``backend`` may have
        been switched to ``_Backend.FLASH_ATTN``. The function is imported
        from ``aiter``, ``flash_attn`` or ``vllm.vllm_flash_attn`` depending
        on the resolved backend, and is ``None`` when the resolved backend is
        neither ``FLASH_ATTN`` nor ``ROCM_AITER_FA``.
    """
    flash_backends = {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}

    # Upgrade a non-flash backend to FLASH_ATTN (upstream package) when
    # check_upstream_fa_availability reports it usable for the default dtype.
    # The membership test comes first so the availability check is only
    # evaluated for non-flash backends.
    if (attn_backend not in flash_backends
            and check_upstream_fa_availability(torch.get_default_dtype())):
        attn_backend = _Backend.FLASH_ATTN
        use_upstream_fa = True

    # On ROCm, FLASH_ATTN is always served by the upstream flash_attn
    # package, regardless of what the caller passed in.
    if current_platform.is_rocm() and attn_backend == _Backend.FLASH_ATTN:
        use_upstream_fa = True

    # Imports are kept local so that only the package backing the resolved
    # backend has to be installed.
    if attn_backend == _Backend.ROCM_AITER_FA:
        from aiter import flash_attn_varlen_func
    elif attn_backend == _Backend.FLASH_ATTN:
        if use_upstream_fa:
            from flash_attn import flash_attn_varlen_func
        else:
            from vllm.vllm_flash_attn import flash_attn_varlen_func
    else:
        # No flash-attention kernel applies to this backend.
        flash_attn_varlen_func = None

    return attn_backend, flash_attn_varlen_func
class
Attention
(
nn
.
Module
,
AttentionLayerBase
):
"""Attention layer.
...
...
@@ -410,13 +440,9 @@ class MultiHeadAttention(nn.Module):
# to upstream flash attention if available.
# If vllm native fa is selected, we use it directly.
use_upstream_fa
=
False
if
backend
!=
_Backend
.
FLASH_ATTN
and
check_upstream_fa_availability
(
dtype
):
backend
=
_Backend
.
FLASH_ATTN
use_upstream_fa
=
True
if
current_platform
.
is_rocm
()
or
current_platform
.
is_xpu
():
# currently, only torch_sdpa is supported on
rocm/
xpu
if
current_platform
.
is_xpu
():
# currently, only torch_sdpa is supported on xpu
self
.
attn_backend
=
_Backend
.
TORCH_SDPA
else
:
...
...
@@ -428,17 +454,25 @@ class MultiHeadAttention(nn.Module):
_Backend
.
FLASH_ATTN
,
}
else
_Backend
.
TORCH_SDPA
self
.
attn_backend
,
self
.
_flash_attn_varlen_func
\
=
maybe_get_vit_flash_attn_backend
(
self
.
attn_backend
,
use_upstream_fa
,
)
if
(
self
.
attn_backend
==
_Backend
.
XFORMERS
and
not
check_xformers_availability
()):
self
.
attn_backend
=
_Backend
.
TORCH_SDPA
if
self
.
attn_backend
==
_Backend
.
FLASH_ATTN
:
if
use_upstream_fa
:
from
flash_attn
import
flash_attn_varlen_func
self
.
_flash_attn_varlen_func
=
flash_attn_varlen_func
else
:
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
self
.
_flash_attn_varlen_func
=
flash_attn_varlen_func
self
.
is_flash_attn_backend
=
self
.
attn_backend
in
{
_Backend
.
FLASH_ATTN
,
_Backend
.
ROCM_AITER_FA
}
# this condition is just to make sure that the
# use_upstream_fa in the log is correct
if
current_platform
.
is_rocm
()
\
and
self
.
attn_backend
==
_Backend
.
FLASH_ATTN
:
use_upstream_fa
=
True
logger
.
info_once
(
f
"MultiHeadAttention attn_backend:
{
self
.
attn_backend
}
, "
...
...
@@ -466,7 +500,7 @@ class MultiHeadAttention(nn.Module):
key
=
torch
.
repeat_interleave
(
key
,
num_repeat
,
dim
=
2
)
value
=
torch
.
repeat_interleave
(
value
,
num_repeat
,
dim
=
2
)
if
self
.
attn_backend
==
_Backend
.
FLASH_ATTN
:
if
self
.
is_flash_
attn_backend
:
cu_seqlens_q
=
torch
.
arange
(
0
,
(
bsz
+
1
)
*
q_len
,
step
=
q_len
,
dtype
=
torch
.
int32
,
...
...
@@ -507,14 +541,6 @@ class MultiHeadAttention(nn.Module):
from
torch_xla.experimental.custom_kernel
import
flash_attention
out
=
flash_attention
(
query
,
key
,
value
,
sm_scale
=
self
.
scale
)
out
=
out
.
transpose
(
1
,
2
)
elif
self
.
attn_backend
==
_Backend
.
ROCM_AITER_FA
:
from
aiter
import
flash_attn_varlen_func
# ROCm Flash Attention expects (batch, seq, heads, head_dim)
out
=
flash_attn_varlen_func
(
query
,
key
,
value
,
softmax_scale
=
self
.
scale
)
else
:
# ViT attention hasn't supported this backend yet
raise
NotImplementedError
(
...
...
vllm/model_executor/models/dots_ocr.py
View file @
9c5ee91b
...
...
@@ -10,7 +10,8 @@ from torch.nn import LayerNorm
from
transformers.models.qwen2_vl
import
Qwen2VLProcessor
from
vllm.attention.backends.registry
import
_Backend
from
vllm.attention.layer
import
check_upstream_fa_availability
from
vllm.attention.layer
import
(
check_upstream_fa_availability
,
maybe_get_vit_flash_attn_backend
)
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
utils
as
dist_utils
from
vllm.distributed.parallel_state
import
(
...
...
@@ -267,10 +268,12 @@ class DotsVisionAttention(nn.Module):
self
.
attn_backend
=
get_vit_attn_backend
(
self
.
hidden_size_per_attention_head
,
torch
.
get_default_dtype
())
self
.
use_upstream_fa
=
False
if
self
.
attn_backend
!=
_Backend
.
FLASH_ATTN
and
\
check_upstream_fa_availability
(
torch
.
get_default_dtype
()):
self
.
attn_backend
=
_Backend
.
FLASH_ATTN
self
.
use_upstream_fa
=
True
self
.
attn_backend
,
self
.
flash_attn_varlen_func
\
=
maybe_get_vit_flash_attn_backend
(
self
.
attn_backend
,
self
.
use_upstream_fa
,
)
if
self
.
attn_backend
not
in
{
_Backend
.
FLASH_ATTN
,
_Backend
.
TORCH_SDPA
,
_Backend
.
XFORMERS
,
_Backend
.
ROCM_AITER_FA
...
...
@@ -306,17 +309,10 @@ class DotsVisionAttention(nn.Module):
q
,
k
=
torch
.
chunk
(
qk_rotated
,
2
,
dim
=
0
)
if
self
.
is_flash_attn_backend
:
if
self
.
attn_backend
==
_Backend
.
ROCM_AITER_FA
:
from
aiter
import
flash_attn_varlen_func
else
:
if
self
.
use_upstream_fa
:
from
flash_attn
import
flash_attn_varlen_func
else
:
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
q_
=
q
.
reshape
(
bs
*
q
.
shape
[
1
],
q
.
shape
[
2
],
q
.
shape
[
3
])
k_
=
k
.
reshape
(
bs
*
k
.
shape
[
1
],
k
.
shape
[
2
],
k
.
shape
[
3
])
v_
=
v
.
reshape
(
bs
*
v
.
shape
[
1
],
v
.
shape
[
2
],
v
.
shape
[
3
])
output
=
flash_attn_varlen_func
(
q_
,
output
=
self
.
flash_attn_varlen_func
(
q_
,
k_
,
v_
,
cu_seqlens_q
=
cu_seqlens
,
...
...
@@ -611,7 +607,8 @@ class DotsVisionTransformer(nn.Module):
self
,
cu_seqlens
:
torch
.
Tensor
)
->
tuple
[
Optional
[
int
],
Optional
[
list
[
int
]]]:
max_seqlen
,
seqlens
=
None
,
None
if
self
.
attn_backend
==
_Backend
.
FLASH_ATTN
:
if
(
self
.
attn_backend
==
_Backend
.
FLASH_ATTN
or
self
.
attn_backend
==
_Backend
.
ROCM_AITER_FA
):
max_seqlen
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
max
().
item
()
elif
self
.
attn_backend
==
_Backend
.
XFORMERS
:
seqlens
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
tolist
()
...
...
vllm/model_executor/models/ernie45_vl.py
View file @
9c5ee91b
...
...
@@ -35,7 +35,8 @@ from einops import rearrange, repeat
from
transformers
import
BatchFeature
from
vllm.attention.backends.registry
import
_Backend
from
vllm.attention.layer
import
check_upstream_fa_availability
from
vllm.attention.layer
import
(
check_upstream_fa_availability
,
maybe_get_vit_flash_attn_backend
)
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
parallel_state
from
vllm.distributed
import
utils
as
dist_utils
...
...
@@ -176,14 +177,18 @@ class Ernie4_5_VisionAttention(nn.Module):
dtype
=
torch
.
get_default_dtype
())
self
.
use_upstream_fa
=
False
if
self
.
attn_backend
!=
_Backend
.
FLASH_ATTN
and
\
check_upstream_fa_availability
(
torch
.
get_default_dtype
()):
self
.
attn_backend
=
_Backend
.
FLASH_ATTN
self
.
use_upstream_fa
=
True
self
.
attn_backend
,
self
.
flash_attn_varlen_func
\
=
maybe_get_vit_flash_attn_backend
(
self
.
attn_backend
,
self
.
use_upstream_fa
,
)
if
self
.
attn_backend
not
in
{
_Backend
.
FLASH_ATTN
,
_Backend
.
TORCH_SDPA
,
_Backend
.
XFORMERS
,
_Backend
.
ROCM_AITER_FA
_Backend
.
FLASH_ATTN
,
_Backend
.
TORCH_SDPA
,
_Backend
.
XFORMERS
,
_Backend
.
ROCM_AITER_FA
,
}:
raise
RuntimeError
(
f
"Ernie45-VL does not support
{
self
.
attn_backend
}
backend now."
...
...
@@ -239,19 +244,10 @@ class Ernie4_5_VisionAttention(nn.Module):
q
,
k
=
torch
.
chunk
(
qk_rotated
,
2
,
dim
=
0
)
if
self
.
is_flash_attn_backend
:
# from vllm_flash_attn.flash_attn_interface import (
# flash_attn_varlen_func)
if
self
.
attn_backend
==
_Backend
.
ROCM_AITER_FA
:
from
aiter
import
flash_attn_varlen_func
else
:
if
self
.
use_upstream_fa
:
from
flash_attn
import
flash_attn_varlen_func
else
:
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
q
,
k
,
v
=
(
rearrange
(
x
,
"b s ... -> (b s) ..."
)
for
x
in
[
q
,
k
,
v
])
output
=
flash_attn_varlen_func
(
q
,
output
=
self
.
flash_attn_varlen_func
(
q
,
k
,
v
,
cu_seqlens_q
=
cu_seqlens
,
...
...
@@ -516,7 +512,8 @@ class Ernie4_5_VisionTransformer(nn.Module):
self
,
cu_seqlens
:
torch
.
Tensor
)
->
tuple
[
Optional
[
int
],
Optional
[
list
[
int
]]]:
max_seqlen
,
seqlens
=
None
,
None
if
self
.
attn_backend
==
_Backend
.
FLASH_ATTN
:
if
(
self
.
attn_backend
==
_Backend
.
FLASH_ATTN
or
self
.
attn_backend
==
_Backend
.
ROCM_AITER_FA
):
max_seqlen
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
max
().
item
()
elif
self
.
attn_backend
==
_Backend
.
XFORMERS
:
seqlens
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
tolist
()
...
...
vllm/model_executor/models/glm4_1v.py
View file @
9c5ee91b
...
...
@@ -47,7 +47,8 @@ from transformers.models.glm4v.video_processing_glm4v import (
from
transformers.video_utils
import
VideoMetadata
from
vllm.attention.backends.registry
import
_Backend
from
vllm.attention.layer
import
check_upstream_fa_availability
from
vllm.attention.layer
import
(
check_upstream_fa_availability
,
maybe_get_vit_flash_attn_backend
)
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
(
get_tensor_model_parallel_world_size
,
parallel_state
)
...
...
@@ -263,19 +264,26 @@ class Glm4vVisionAttention(nn.Module):
head_size
=
self
.
hidden_size_per_attention_head
,
dtype
=
torch
.
get_default_dtype
())
self
.
use_upstream_fa
=
False
if
self
.
attn_backend
!=
_Backend
.
FLASH_ATTN
and
\
check_upstream_fa_availability
(
torch
.
get_default_dtype
()):
self
.
attn_backend
=
_Backend
.
FLASH_ATTN
self
.
use_upstream_fa
=
True
self
.
attn_backend
,
self
.
flash_attn_varlen_func
\
=
maybe_get_vit_flash_attn_backend
(
self
.
attn_backend
,
self
.
use_upstream_fa
,
)
if
self
.
attn_backend
not
in
{
_Backend
.
FLASH_ATTN
,
_Backend
.
TORCH_SDPA
,
_Backend
.
XFORMERS
,
_Backend
.
ROCM_AITER_FA
,
}:
raise
RuntimeError
(
f
"GLM-4V does not support
{
self
.
attn_backend
}
backend now."
)
self
.
is_flash_attn_backend
=
self
.
attn_backend
in
{
_Backend
.
FLASH_ATTN
,
_Backend
.
ROCM_AITER_FA
}
def
split_qkv
(
self
,
qkv
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
...]:
# [s, b, 3 * head * head_dim]
seq_len
,
bs
,
_
=
qkv
.
shape
...
...
@@ -316,17 +324,11 @@ class Glm4vVisionAttention(nn.Module):
qk_rotated
=
apply_rotary_pos_emb_vision
(
qk_concat
,
rotary_pos_emb
)
q
,
k
=
torch
.
chunk
(
qk_rotated
,
2
,
dim
=
0
)
if
self
.
attn_backend
==
_Backend
.
FLASH_ATTN
:
# from vllm_flash_attn.flash_attn_interface import (
# flash_attn_varlen_func)
if
self
.
use_upstream_fa
:
from
flash_attn
import
flash_attn_varlen_func
else
:
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
if
self
.
is_flash_attn_backend
:
q
,
k
,
v
=
(
rearrange
(
x
,
"b s ... -> (b s) ..."
)
for
x
in
[
q
,
k
,
v
])
output
=
flash_attn_varlen_func
(
output
=
self
.
flash_attn_varlen_func
(
q
,
k
,
v
,
...
...
@@ -774,7 +776,8 @@ class Glm4vVisionTransformer(nn.Module):
)
->
tuple
[
Optional
[
int
],
Optional
[
list
[
int
]]]:
max_seqlen
,
seqlens
=
None
,
None
seqlens
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
tolist
()
if
self
.
attn_backend
==
_Backend
.
FLASH_ATTN
:
if
(
self
.
attn_backend
==
_Backend
.
FLASH_ATTN
or
self
.
attn_backend
==
_Backend
.
ROCM_AITER_FA
):
max_seqlen
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
max
().
item
()
return
max_seqlen
,
seqlens
...
...
vllm/model_executor/models/qwen2_5_vl.py
View file @
9c5ee91b
...
...
@@ -39,7 +39,8 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
Qwen2_5_VLConfig
,
Qwen2_5_VLVisionConfig
)
from
vllm.attention.backends.registry
import
_Backend
from
vllm.attention.layer
import
check_upstream_fa_availability
from
vllm.attention.layer
import
(
check_upstream_fa_availability
,
maybe_get_vit_flash_attn_backend
)
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
parallel_state
from
vllm.distributed
import
utils
as
dist_utils
...
...
@@ -302,6 +303,11 @@ class Qwen2_5_VisionAttention(nn.Module):
disable_tp
=
use_data_parallel
)
self
.
attn_backend
=
attn_backend
self
.
use_upstream_fa
=
use_upstream_fa
self
.
attn_backend
,
self
.
flash_attn_varlen_func
\
=
maybe_get_vit_flash_attn_backend
(
self
.
attn_backend
,
self
.
use_upstream_fa
,
)
self
.
is_flash_attn_backend
=
self
.
attn_backend
in
{
_Backend
.
FLASH_ATTN
,
_Backend
.
ROCM_AITER_FA
}
...
...
@@ -354,17 +360,10 @@ class Qwen2_5_VisionAttention(nn.Module):
q
,
k
=
torch
.
chunk
(
qk_rotated
,
2
,
dim
=
0
)
if
self
.
is_flash_attn_backend
:
if
self
.
attn_backend
==
_Backend
.
ROCM_AITER_FA
:
from
aiter
import
flash_attn_varlen_func
else
:
if
self
.
use_upstream_fa
:
from
flash_attn
import
flash_attn_varlen_func
else
:
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
q
,
k
,
v
=
(
rearrange
(
x
,
"b s ... -> (b s) ..."
)
for
x
in
[
q
,
k
,
v
])
output
=
flash_attn_varlen_func
(
q
,
output
=
self
.
flash_attn_varlen_func
(
q
,
k
,
v
,
cu_seqlens_q
=
cu_seqlens
,
...
...
@@ -618,6 +617,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
self
.
attn_backend
=
get_vit_attn_backend
(
head_size
=
head_dim
,
dtype
=
torch
.
get_default_dtype
())
if
self
.
attn_backend
!=
_Backend
.
FLASH_ATTN
and
\
self
.
attn_backend
!=
_Backend
.
ROCM_AITER_FA
and
\
check_upstream_fa_availability
(
torch
.
get_default_dtype
()):
self
.
attn_backend
=
_Backend
.
FLASH_ATTN
...
...
vllm/model_executor/models/qwen2_vl.py
View file @
9c5ee91b
...
...
@@ -42,7 +42,8 @@ from transformers.models.qwen2_vl.video_processing_qwen2_vl import (
Qwen2VLVideoProcessor
)
from
vllm.attention.backends.registry
import
_Backend
from
vllm.attention.layer
import
check_upstream_fa_availability
from
vllm.attention.layer
import
(
check_upstream_fa_availability
,
maybe_get_vit_flash_attn_backend
)
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
parallel_state
,
tensor_model_parallel_all_gather
from
vllm.distributed
import
utils
as
dist_utils
...
...
@@ -319,11 +320,12 @@ class Qwen2VisionAttention(nn.Module):
head_size
=
self
.
hidden_size_per_attention_head
,
dtype
=
torch
.
get_default_dtype
())
self
.
use_upstream_fa
=
False
if
self
.
attn_backend
!=
_Backend
.
FLASH_ATTN
and
\
check_upstream_fa_availability
(
torch
.
get_default_dtype
()):
self
.
attn_backend
=
_Backend
.
FLASH_ATTN
self
.
use_upstream_fa
=
True
self
.
attn_backend
,
self
.
flash_attn_varlen_func
\
=
maybe_get_vit_flash_attn_backend
(
self
.
attn_backend
,
self
.
use_upstream_fa
,
)
if
self
.
attn_backend
not
in
{
_Backend
.
FLASH_ATTN
,
_Backend
.
TORCH_SDPA
,
_Backend
.
XFORMERS
,
...
...
@@ -331,6 +333,7 @@ class Qwen2VisionAttention(nn.Module):
}:
raise
RuntimeError
(
f
"Qwen2-VL does not support
{
self
.
attn_backend
}
backend now."
)
self
.
is_flash_attn_backend
=
self
.
attn_backend
in
{
_Backend
.
FLASH_ATTN
,
_Backend
.
ROCM_AITER_FA
}
...
...
@@ -383,17 +386,10 @@ class Qwen2VisionAttention(nn.Module):
q
,
k
=
torch
.
chunk
(
qk_rotated
,
2
,
dim
=
0
)
if
self
.
is_flash_attn_backend
:
if
self
.
attn_backend
==
_Backend
.
ROCM_AITER_FA
:
from
aiter
import
flash_attn_varlen_func
else
:
if
self
.
use_upstream_fa
:
from
flash_attn
import
flash_attn_varlen_func
else
:
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
q
,
k
,
v
=
(
rearrange
(
x
,
"b s ... -> (b s) ..."
)
for
x
in
[
q
,
k
,
v
])
output
=
flash_attn_varlen_func
(
q
,
output
=
self
.
flash_attn_varlen_func
(
q
,
k
,
v
,
cu_seqlens_q
=
cu_seqlens
,
...
...
vllm/model_executor/models/qwen3_vl.py
View file @
9c5ee91b
...
...
@@ -323,6 +323,7 @@ class Qwen3_VisionTransformer(nn.Module):
head_size
=
head_dim
,
dtype
=
torch
.
get_default_dtype
())
use_upstream_fa
=
False
if
self
.
attn_backend
!=
_Backend
.
FLASH_ATTN
and
\
self
.
attn_backend
!=
_Backend
.
ROCM_AITER_FA
and
\
check_upstream_fa_availability
(
torch
.
get_default_dtype
()):
self
.
attn_backend
=
_Backend
.
FLASH_ATTN
...
...
@@ -476,7 +477,8 @@ class Qwen3_VisionTransformer(nn.Module):
cu_seqlens
:
torch
.
Tensor
,
)
->
tuple
[
Optional
[
int
],
Optional
[
list
[
int
]]]:
max_seqlen
,
seqlens
=
None
,
None
if
self
.
attn_backend
==
_Backend
.
FLASH_ATTN
:
if
(
self
.
attn_backend
==
_Backend
.
FLASH_ATTN
or
self
.
attn_backend
==
_Backend
.
ROCM_AITER_FA
):
max_seqlen
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
max
().
item
()
elif
self
.
attn_backend
==
_Backend
.
XFORMERS
:
seqlens
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
tolist
()
...
...
vllm/model_executor/models/siglip2navit.py
View file @
9c5ee91b
...
...
@@ -14,7 +14,7 @@ from transformers import Siglip2VisionConfig
from
transformers.configuration_utils
import
PretrainedConfig
from
vllm.attention.backends.registry
import
_Backend
from
vllm.attention.layer
import
check_upstream_fa_availability
from
vllm.attention.layer
import
maybe_get_vit_flash_attn_backend
from
vllm.distributed
import
divide
,
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
...
...
@@ -240,11 +240,12 @@ class Siglip2Attention(nn.Module):
self
.
attn_backend
=
get_vit_attn_backend
(
head_size
=
self
.
head_dim
,
dtype
=
torch
.
get_default_dtype
())
self
.
use_upstream_fa
=
False
if
self
.
attn_backend
!=
_Backend
.
FLASH_ATTN
and
\
check_upstream_fa_availability
(
torch
.
get_default_dtype
()):
self
.
attn_backend
=
_Backend
.
FLASH_ATTN
self
.
use_upstream_fa
=
True
self
.
attn_backend
,
self
.
flash_attn_varlen_func
\
=
maybe_get_vit_flash_attn_backend
(
self
.
attn_backend
,
self
.
use_upstream_fa
,
)
if
self
.
attn_backend
not
in
{
_Backend
.
FLASH_ATTN
,
_Backend
.
TORCH_SDPA
,
...
...
@@ -286,14 +287,7 @@ class Siglip2Attention(nn.Module):
max_seqlen
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
max
().
item
()
if
self
.
is_flash_attn_backend
:
if
self
.
attn_backend
==
_Backend
.
ROCM_AITER_FA
:
from
aiter
import
flash_attn_varlen_func
else
:
if
self
.
use_upstream_fa
:
from
flash_attn
import
flash_attn_varlen_func
else
:
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
attn_output
=
flash_attn_varlen_func
(
attn_output
=
self
.
flash_attn_varlen_func
(
queries
,
keys
,
values
,
cu_seqlens
,
cu_seqlens
,
max_seqlen
,
max_seqlen
).
reshape
(
seq_length
,
-
1
)
elif
self
.
attn_backend
==
_Backend
.
TORCH_SDPA
:
...
...
vllm/platforms/rocm.py
View file @
9c5ee91b
...
...
@@ -189,8 +189,6 @@ class RocmPlatform(Platform):
from
vllm.attention.backends.registry
import
_Backend
if
(
envs
.
VLLM_ROCM_USE_AITER
and
envs
.
VLLM_ROCM_USE_AITER_MHA
and
on_gfx9
()):
# Note: AITER FA is only supported for Qwen-VL models.
# TODO: Add support for other VL models in their model class.
return
_Backend
.
ROCM_AITER_FA
if
on_gfx9
():
return
_Backend
.
FLASH_ATTN
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment