Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
72fc8aa4
Unverified
Commit
72fc8aa4
authored
Sep 12, 2025
by
Wenlong Wang
Committed by
GitHub
Sep 12, 2025
Browse files
[Multi Modal] Add FA3 in VIT (#24347)
Signed-off-by:
wwl2755
<
wangwenlong2755@gmail.com
>
parent
fdb09c77
Changes
13
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
247 additions
and
66 deletions
+247
-66
tests/entrypoints/openai/test_vision.py
tests/entrypoints/openai/test_vision.py
+2
-2
tests/kernels/attention/test_mha_attn.py
tests/kernels/attention/test_mha_attn.py
+35
-14
vllm/attention/layer.py
vllm/attention/layer.py
+67
-8
vllm/model_executor/models/ernie45_vl.py
vllm/model_executor/models/ernie45_vl.py
+20
-3
vllm/model_executor/models/glm4_1v.py
vllm/model_executor/models/glm4_1v.py
+19
-3
vllm/model_executor/models/keye.py
vllm/model_executor/models/keye.py
+15
-2
vllm/model_executor/models/qwen2_5_vl.py
vllm/model_executor/models/qwen2_5_vl.py
+21
-3
vllm/model_executor/models/qwen2_vl.py
vllm/model_executor/models/qwen2_vl.py
+21
-3
vllm/model_executor/models/siglip2navit.py
vllm/model_executor/models/siglip2navit.py
+14
-2
vllm/model_executor/models/vision.py
vllm/model_executor/models/vision.py
+5
-5
vllm/platforms/cuda.py
vllm/platforms/cuda.py
+17
-11
vllm/platforms/interface.py
vllm/platforms/interface.py
+2
-1
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+9
-9
No files found.
tests/entrypoints/openai/test_vision.py
View file @
72fc8aa4
...
@@ -34,11 +34,11 @@ EXPECTED_MM_BEAM_SEARCH_RES = [
...
@@ -34,11 +34,11 @@ EXPECTED_MM_BEAM_SEARCH_RES = [
],
],
[
[
"The image shows a Venn diagram with three over"
,
"The image shows a Venn diagram with three over"
,
"Th
e
image shows a Venn diagram with three
intersect
"
,
"Th
is
image shows a Venn diagram with three
over
"
,
],
],
[
[
"This image displays a gradient of colors ranging from"
,
"This image displays a gradient of colors ranging from"
,
"Th
e
image displays a gradient of colors
ranging fro
m"
,
"Th
is
image displays a gradient of colors
forming a spectru
m"
,
],
],
]
]
...
...
tests/kernels/attention/test_mha_attn.py
View file @
72fc8aa4
...
@@ -36,31 +36,52 @@ def test_mha_attn_platform(device: str):
...
@@ -36,31 +36,52 @@ def test_mha_attn_platform(device: str):
torch
.
set_default_dtype
(
torch
.
float16
)
torch
.
set_default_dtype
(
torch
.
float16
)
if
device
==
"cpu"
:
if
device
==
"cpu"
:
with
patch
(
"vllm.attention.
selecto
r.current_platform"
,
with
patch
(
"vllm.attention.
laye
r.current_platform"
,
CpuPlatform
()),
\
CpuP
latform
()),
\
patch
(
"vllm.model_executor.models.vision.current_p
latform
"
,
patch
(
"vllm.platforms.current_platform"
,
CpuPlatform
()):
CpuPlatform
()):
attn
=
MultiHeadAttention
(
16
,
64
,
scale
=
1
)
attn
=
MultiHeadAttention
(
16
,
64
,
scale
=
1
)
assert
attn
.
attn_backend
==
_Backend
.
TORCH_SDPA
_VLLM_V1
assert
attn
.
attn_backend
==
_Backend
.
TORCH_SDPA
elif
device
==
"hip"
:
elif
device
==
"hip"
:
with
patch
(
"vllm.attention.selector.current_platform"
,
with
patch
(
"vllm.attention.layer.current_platform"
,
RocmPlatform
()),
\
RocmPlatform
()),
\
patch
(
"vllm.model_executor.models.vision.current_platform"
,
patch
(
"vllm.platforms.current_platform"
,
RocmPlatform
()),
\
RocmPlatform
()):
patch
(
"vllm.attention.layer.current_platform"
,
RocmPlatform
()):
attn
=
MultiHeadAttention
(
16
,
64
,
scale
=
1
)
attn
=
MultiHeadAttention
(
16
,
64
,
scale
=
1
)
assert
attn
.
attn_backend
==
_Backend
.
TORCH_SDPA
assert
attn
.
attn_backend
==
_Backend
.
TORCH_SDPA
else
:
else
:
with
patch
(
"vllm.attention.selector.current_platform"
,
# Test CUDA with head_size=64 (divisible by 32)
CudaPlatform
()),
\
# - should use vLLM's FlashAttention
patch
(
"vllm.platforms.current_platform"
,
CudaPlatform
()):
with
patch
(
"vllm.attention.layer.current_platform"
,
CudaPlatform
()),
\
patch
(
"vllm.model_executor.models.vision.current_platform"
,
CudaPlatform
()):
attn
=
MultiHeadAttention
(
16
,
64
,
scale
=
1
)
attn
=
MultiHeadAttention
(
16
,
64
,
scale
=
1
)
assert
attn
.
attn_backend
==
_Backend
.
XFORMERS
assert
attn
.
attn_backend
==
_Backend
.
FLASH_ATTN
with
patch
(
"vllm.attention.selector.current_platform"
,
# Test CUDA with head_size=72 (not divisible by 32)
# - with upstream FA not available
# - should use xformers
with
patch
(
"vllm.attention.layer.current_platform"
,
CudaPlatform
()),
\
patch
(
"vllm.model_executor.models.vision.current_platform"
,
CudaPlatform
()),
\
CudaPlatform
()),
\
patch
(
"vllm.platforms.current_platform"
,
CudaPlatform
()):
patch
(
"vllm.attention.layer.check_upstream_fa_availability"
,
return_value
=
False
):
attn
=
MultiHeadAttention
(
16
,
72
,
scale
=
1
)
attn
=
MultiHeadAttention
(
16
,
72
,
scale
=
1
)
assert
attn
.
attn_backend
==
_Backend
.
XFORMERS
assert
attn
.
attn_backend
==
_Backend
.
XFORMERS
# Test CUDA with head_size=72 (not divisible by 32)
# - with upstream FA available
# - should use upstream FA
with
patch
(
"vllm.attention.layer.current_platform"
,
CudaPlatform
()),
\
patch
(
"vllm.model_executor.models.vision.current_platform"
,
CudaPlatform
()),
\
patch
(
"vllm.attention.layer.check_upstream_fa_availability"
,
return_value
=
True
),
\
patch
.
dict
(
'sys.modules'
,
{
'flash_attn'
:
type
(
'MockFlashAttn'
,
(),
{
'flash_attn_varlen_func'
:
lambda
*
args
,
**
kwargs
:
None
})()}):
attn
=
MultiHeadAttention
(
16
,
72
,
scale
=
1
)
assert
attn
.
attn_backend
==
_Backend
.
FLASH_ATTN
def
ref_attention
(
def
ref_attention
(
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
...
...
vllm/attention/layer.py
View file @
72fc8aa4
...
@@ -23,6 +23,7 @@ from vllm.model_executor.layers.linear import UnquantizedLinearMethod
...
@@ -23,6 +23,7 @@ from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.model_executor.models.vision
import
get_vit_attn_backend
from
vllm.platforms
import
_Backend
,
current_platform
from
vllm.platforms
import
_Backend
,
current_platform
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils
import
direct_register_custom_op
...
@@ -55,6 +56,14 @@ def check_xformers_availability():
...
@@ -55,6 +56,14 @@ def check_xformers_availability():
return
USE_XFORMERS_OPS
return
USE_XFORMERS_OPS
def
check_upstream_fa_availability
(
dtype
:
torch
.
dtype
):
if
dtype
in
(
torch
.
float16
,
torch
.
bfloat16
)
and
current_platform
.
is_cuda
(
)
and
current_platform
.
has_device_capability
(
80
):
from
transformers.utils
import
is_flash_attn_2_available
return
is_flash_attn_2_available
()
return
False
class
Attention
(
nn
.
Module
,
AttentionLayerBase
):
class
Attention
(
nn
.
Module
,
AttentionLayerBase
):
"""Attention layer.
"""Attention layer.
...
@@ -349,29 +358,55 @@ class MultiHeadAttention(nn.Module):
...
@@ -349,29 +358,55 @@ class MultiHeadAttention(nn.Module):
f
"divisible by num_kv_heads (
{
self
.
num_kv_heads
}
)"
f
"divisible by num_kv_heads (
{
self
.
num_kv_heads
}
)"
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
# During model initialization, the default dtype is set as the model
# weight and activation dtype.
dtype
=
torch
.
get_default_dtype
()
dtype
=
torch
.
get_default_dtype
()
attn_backend
=
get_attn_backend
(
head_size
,
dtype
,
# Determine the attention backend
kv_cache_dtype
=
None
,
backend
=
get_vit_attn_backend
(
head_size
=
head_size
,
dtype
=
dtype
)
block_size
=
16
,
is_attention_free
=
False
)
# Some auto-selected backends can be upgraded
backend
=
backend_name_to_enum
(
attn_backend
.
get_name
())
# to upstream flash attention if available.
# If vllm native fa is selected, we use it directly.
use_upstream_fa
=
False
if
backend
!=
_Backend
.
FLASH_ATTN
and
check_upstream_fa_availability
(
dtype
):
backend
=
_Backend
.
FLASH_ATTN
use_upstream_fa
=
True
if
current_platform
.
is_rocm
():
if
current_platform
.
is_rocm
():
# currently, only torch_sdpa is supported on rocm
# currently, only torch_sdpa is supported on rocm
self
.
attn_backend
=
_Backend
.
TORCH_SDPA
self
.
attn_backend
=
_Backend
.
TORCH_SDPA
else
:
else
:
self
.
attn_backend
=
backend
if
backend
in
{
self
.
attn_backend
=
backend
if
backend
in
{
_Backend
.
TORCH_SDPA
,
_Backend
.
TORCH_SDPA
,
_Backend
.
TORCH_SDPA_VLLM_V1
,
_Backend
.
TORCH_SDPA_VLLM_V1
,
_Backend
.
XFORMERS
,
_Backend
.
XFORMERS
,
_Backend
.
PALLAS_VLLM_V1
,
_Backend
.
PALLAS_VLLM_V1
,
_Backend
.
ROCM_AITER_FA
,
_Backend
.
ROCM_AITER_FA
,
}
else
current_platform
.
get_vit_attn_backend
()
_Backend
.
FLASH_ATTN
,
_Backend
.
FLASH_ATTN_VLLM_V1
,
}
else
_Backend
.
TORCH_SDPA
if
(
self
.
attn_backend
==
_Backend
.
XFORMERS
if
(
self
.
attn_backend
==
_Backend
.
XFORMERS
and
not
check_xformers_availability
()):
and
not
check_xformers_availability
()):
self
.
attn_backend
=
_Backend
.
TORCH_SDPA
self
.
attn_backend
=
_Backend
.
TORCH_SDPA
if
self
.
attn_backend
in
{
_Backend
.
FLASH_ATTN
,
_Backend
.
FLASH_ATTN_VLLM_V1
}:
if
use_upstream_fa
:
from
flash_attn
import
flash_attn_varlen_func
self
.
_flash_attn_varlen_func
=
flash_attn_varlen_func
else
:
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
self
.
_flash_attn_varlen_func
=
flash_attn_varlen_func
logger
.
info_once
(
f
"MultiHeadAttention attn_backend:
{
self
.
attn_backend
}
, "
f
"use_upstream_fa:
{
use_upstream_fa
}
"
)
def
forward
(
def
forward
(
self
,
self
,
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
...
@@ -392,7 +427,31 @@ class MultiHeadAttention(nn.Module):
...
@@ -392,7 +427,31 @@ class MultiHeadAttention(nn.Module):
key
=
torch
.
repeat_interleave
(
key
,
num_repeat
,
dim
=
2
)
key
=
torch
.
repeat_interleave
(
key
,
num_repeat
,
dim
=
2
)
value
=
torch
.
repeat_interleave
(
value
,
num_repeat
,
dim
=
2
)
value
=
torch
.
repeat_interleave
(
value
,
num_repeat
,
dim
=
2
)
if
self
.
attn_backend
==
_Backend
.
XFORMERS
:
if
self
.
attn_backend
in
{
_Backend
.
FLASH_ATTN
,
_Backend
.
FLASH_ATTN_VLLM_V1
,
}:
cu_seqlens_q
=
torch
.
arange
(
0
,
(
bsz
+
1
)
*
q_len
,
step
=
q_len
,
dtype
=
torch
.
int32
,
device
=
query
.
device
)
cu_seqlens_k
=
torch
.
arange
(
0
,
(
bsz
+
1
)
*
kv_len
,
step
=
kv_len
,
dtype
=
torch
.
int32
,
device
=
key
.
device
)
out
=
self
.
_flash_attn_varlen_func
(
query
.
flatten
(
0
,
1
),
key
.
flatten
(
0
,
1
),
value
.
flatten
(
0
,
1
),
cu_seqlens_q
=
cu_seqlens_q
,
cu_seqlens_k
=
cu_seqlens_k
,
max_seqlen_q
=
q_len
,
max_seqlen_k
=
kv_len
,
softmax_scale
=
self
.
scale
,
)
elif
self
.
attn_backend
==
_Backend
.
XFORMERS
:
from
xformers
import
ops
as
xops
from
xformers
import
ops
as
xops
out
=
xops
.
memory_efficient_attention_forward
(
query
,
out
=
xops
.
memory_efficient_attention_forward
(
query
,
...
...
vllm/model_executor/models/ernie45_vl.py
View file @
72fc8aa4
...
@@ -34,6 +34,7 @@ import torch.nn.functional as F
...
@@ -34,6 +34,7 @@ import torch.nn.functional as F
from
einops
import
rearrange
,
repeat
from
einops
import
rearrange
,
repeat
from
transformers
import
BatchFeature
from
transformers
import
BatchFeature
from
vllm.attention.layer
import
check_upstream_fa_availability
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
parallel_state
from
vllm.distributed
import
parallel_state
from
vllm.distributed
import
utils
as
dist_utils
from
vllm.distributed
import
utils
as
dist_utils
...
@@ -170,7 +171,16 @@ class Ernie4_5_VisionAttention(nn.Module):
...
@@ -170,7 +171,16 @@ class Ernie4_5_VisionAttention(nn.Module):
prefix
=
f
"
{
prefix
}
.proj"
)
prefix
=
f
"
{
prefix
}
.proj"
)
# Detect attention implementation.
# Detect attention implementation.
self
.
attn_backend
:
_Backend
=
get_vit_attn_backend
(
support_fa
=
True
)
self
.
attn_backend
=
get_vit_attn_backend
(
head_size
=
self
.
hidden_size_per_attention_head
,
dtype
=
torch
.
get_default_dtype
())
self
.
use_upstream_fa
=
False
if
self
.
attn_backend
!=
_Backend
.
FLASH_ATTN
and
\
check_upstream_fa_availability
(
torch
.
get_default_dtype
()):
self
.
attn_backend
=
_Backend
.
FLASH_ATTN
self
.
use_upstream_fa
=
True
if
self
.
attn_backend
not
in
{
if
self
.
attn_backend
not
in
{
_Backend
.
FLASH_ATTN
,
_Backend
.
TORCH_SDPA
,
_Backend
.
XFORMERS
,
_Backend
.
FLASH_ATTN
,
_Backend
.
TORCH_SDPA
,
_Backend
.
XFORMERS
,
_Backend
.
ROCM_AITER_FA
_Backend
.
ROCM_AITER_FA
...
@@ -233,7 +243,10 @@ class Ernie4_5_VisionAttention(nn.Module):
...
@@ -233,7 +243,10 @@ class Ernie4_5_VisionAttention(nn.Module):
if
self
.
attn_backend
==
_Backend
.
ROCM_AITER_FA
:
if
self
.
attn_backend
==
_Backend
.
ROCM_AITER_FA
:
from
aiter
import
flash_attn_varlen_func
from
aiter
import
flash_attn_varlen_func
else
:
else
:
if
self
.
use_upstream_fa
:
from
flash_attn
import
flash_attn_varlen_func
from
flash_attn
import
flash_attn_varlen_func
else
:
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
q
,
k
,
v
=
(
rearrange
(
x
,
"b s ... -> (b s) ..."
)
for
x
in
[
q
,
k
,
v
])
q
,
k
,
v
=
(
rearrange
(
x
,
"b s ... -> (b s) ..."
)
for
x
in
[
q
,
k
,
v
])
...
@@ -457,7 +470,11 @@ class Ernie4_5_VisionTransformer(nn.Module):
...
@@ -457,7 +470,11 @@ class Ernie4_5_VisionTransformer(nn.Module):
),
"vit's config.hidden must be equal to config.embed_dim"
),
"vit's config.hidden must be equal to config.embed_dim"
self
.
ln
=
nn
.
LayerNorm
(
hidden_size
,
eps
=
1e-6
)
self
.
ln
=
nn
.
LayerNorm
(
hidden_size
,
eps
=
1e-6
)
self
.
attn_backend
:
_Backend
=
get_vit_attn_backend
(
support_fa
=
True
)
self
.
attn_backend
=
get_vit_attn_backend
(
head_size
=
head_dim
,
dtype
=
torch
.
get_default_dtype
())
if
self
.
attn_backend
!=
_Backend
.
FLASH_ATTN
and
\
check_upstream_fa_availability
(
torch
.
get_default_dtype
()):
self
.
attn_backend
=
_Backend
.
FLASH_ATTN
@
property
@
property
def
dtype
(
self
)
->
torch
.
dtype
:
def
dtype
(
self
)
->
torch
.
dtype
:
...
...
vllm/model_executor/models/glm4_1v.py
View file @
72fc8aa4
...
@@ -44,6 +44,7 @@ from transformers.models.glm4v.video_processing_glm4v import (
...
@@ -44,6 +44,7 @@ from transformers.models.glm4v.video_processing_glm4v import (
Glm4vVideoProcessor
)
Glm4vVideoProcessor
)
from
transformers.video_utils
import
VideoMetadata
from
transformers.video_utils
import
VideoMetadata
from
vllm.attention.layer
import
check_upstream_fa_availability
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
(
get_tensor_model_parallel_world_size
,
from
vllm.distributed
import
(
get_tensor_model_parallel_world_size
,
parallel_state
)
parallel_state
)
...
@@ -260,7 +261,15 @@ class Glm4vVisionAttention(nn.Module):
...
@@ -260,7 +261,15 @@ class Glm4vVisionAttention(nn.Module):
)
)
# Detect attention implementation.
# Detect attention implementation.
self
.
attn_backend
:
_Backend
=
get_vit_attn_backend
(
support_fa
=
True
)
self
.
attn_backend
=
get_vit_attn_backend
(
head_size
=
self
.
hidden_size_per_attention_head
,
dtype
=
torch
.
get_default_dtype
())
self
.
use_upstream_fa
=
False
if
self
.
attn_backend
!=
_Backend
.
FLASH_ATTN
and
\
check_upstream_fa_availability
(
torch
.
get_default_dtype
()):
self
.
attn_backend
=
_Backend
.
FLASH_ATTN
self
.
use_upstream_fa
=
True
if
self
.
attn_backend
not
in
{
if
self
.
attn_backend
not
in
{
_Backend
.
FLASH_ATTN
,
_Backend
.
FLASH_ATTN
,
_Backend
.
TORCH_SDPA
,
_Backend
.
TORCH_SDPA
,
...
@@ -310,7 +319,10 @@ class Glm4vVisionAttention(nn.Module):
...
@@ -310,7 +319,10 @@ class Glm4vVisionAttention(nn.Module):
if
self
.
attn_backend
==
_Backend
.
FLASH_ATTN
:
if
self
.
attn_backend
==
_Backend
.
FLASH_ATTN
:
# from vllm_flash_attn.flash_attn_interface import (
# from vllm_flash_attn.flash_attn_interface import (
# flash_attn_varlen_func)
# flash_attn_varlen_func)
if
self
.
use_upstream_fa
:
from
flash_attn
import
flash_attn_varlen_func
from
flash_attn
import
flash_attn_varlen_func
else
:
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
q
,
k
,
v
=
(
rearrange
(
x
,
"b s ... -> (b s) ..."
)
for
x
in
[
q
,
k
,
v
])
q
,
k
,
v
=
(
rearrange
(
x
,
"b s ... -> (b s) ..."
)
for
x
in
[
q
,
k
,
v
])
...
@@ -715,7 +727,11 @@ class Glm4vVisionTransformer(nn.Module):
...
@@ -715,7 +727,11 @@ class Glm4vVisionTransformer(nn.Module):
self
.
post_layernorm
=
RMSNorm
(
vision_config
.
hidden_size
,
self
.
post_layernorm
=
RMSNorm
(
vision_config
.
hidden_size
,
eps
=
vision_config
.
rms_norm_eps
)
eps
=
vision_config
.
rms_norm_eps
)
self
.
attn_backend
:
_Backend
=
get_vit_attn_backend
(
support_fa
=
True
)
self
.
attn_backend
=
get_vit_attn_backend
(
head_size
=
head_dim
,
dtype
=
torch
.
get_default_dtype
())
if
self
.
attn_backend
!=
_Backend
.
FLASH_ATTN
and
\
check_upstream_fa_availability
(
torch
.
get_default_dtype
()):
self
.
attn_backend
=
_Backend
.
FLASH_ATTN
@
property
@
property
def
dtype
(
self
)
->
torch
.
dtype
:
def
dtype
(
self
)
->
torch
.
dtype
:
...
...
vllm/model_executor/models/keye.py
View file @
72fc8aa4
...
@@ -17,6 +17,7 @@ from transformers.modeling_outputs import (BaseModelOutput,
...
@@ -17,6 +17,7 @@ from transformers.modeling_outputs import (BaseModelOutput,
BaseModelOutputWithPooling
)
BaseModelOutputWithPooling
)
from
transformers.utils
import
torch_int
from
transformers.utils
import
torch_int
from
vllm.attention.layer
import
check_upstream_fa_availability
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -374,7 +375,16 @@ class KeyeSiglipAttention(nn.Module):
...
@@ -374,7 +375,16 @@ class KeyeSiglipAttention(nn.Module):
)
)
# Detect attention implementation.
# Detect attention implementation.
self
.
attn_backend
:
_Backend
=
get_vit_attn_backend
(
support_fa
=
True
)
self
.
attn_backend
=
get_vit_attn_backend
(
head_size
=
self
.
head_dim
,
dtype
=
torch
.
get_default_dtype
())
self
.
use_upstream_fa
=
False
if
self
.
attn_backend
!=
_Backend
.
FLASH_ATTN
and
\
check_upstream_fa_availability
(
torch
.
get_default_dtype
()):
self
.
attn_backend
=
_Backend
.
FLASH_ATTN
self
.
use_upstream_fa
=
True
if
self
.
attn_backend
not
in
{
_Backend
.
FLASH_ATTN
,
_Backend
.
XFORMERS
}:
if
self
.
attn_backend
not
in
{
_Backend
.
FLASH_ATTN
,
_Backend
.
XFORMERS
}:
raise
RuntimeError
(
raise
RuntimeError
(
f
"Keye-VL does not support
{
self
.
attn_backend
}
backend now."
)
f
"Keye-VL does not support
{
self
.
attn_backend
}
backend now."
)
...
@@ -428,7 +438,10 @@ class KeyeSiglipAttention(nn.Module):
...
@@ -428,7 +438,10 @@ class KeyeSiglipAttention(nn.Module):
)
)
if
self
.
attn_backend
==
_Backend
.
FLASH_ATTN
:
if
self
.
attn_backend
==
_Backend
.
FLASH_ATTN
:
if
self
.
use_upstream_fa
:
from
flash_attn
import
flash_attn_varlen_func
from
flash_attn
import
flash_attn_varlen_func
else
:
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
q
,
k
,
v
=
(
rearrange
(
x
,
"b s ... -> (b s) ..."
)
for
x
in
[
q
,
k
,
v
])
q
,
k
,
v
=
(
rearrange
(
x
,
"b s ... -> (b s) ..."
)
for
x
in
[
q
,
k
,
v
])
...
...
vllm/model_executor/models/qwen2_5_vl.py
View file @
72fc8aa4
...
@@ -38,6 +38,7 @@ from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
...
@@ -38,6 +38,7 @@ from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
from
transformers.models.qwen2_5_vl.configuration_qwen2_5_vl
import
(
from
transformers.models.qwen2_5_vl.configuration_qwen2_5_vl
import
(
Qwen2_5_VLConfig
,
Qwen2_5_VLVisionConfig
)
Qwen2_5_VLConfig
,
Qwen2_5_VLVisionConfig
)
from
vllm.attention.layer
import
check_upstream_fa_availability
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
parallel_state
from
vllm.distributed
import
parallel_state
from
vllm.distributed
import
utils
as
dist_utils
from
vllm.distributed
import
utils
as
dist_utils
...
@@ -298,7 +299,16 @@ class Qwen2_5_VisionAttention(nn.Module):
...
@@ -298,7 +299,16 @@ class Qwen2_5_VisionAttention(nn.Module):
disable_tp
=
use_data_parallel
)
disable_tp
=
use_data_parallel
)
# Detect attention implementation.
# Detect attention implementation.
self
.
attn_backend
:
_Backend
=
get_vit_attn_backend
(
support_fa
=
True
)
self
.
attn_backend
=
get_vit_attn_backend
(
head_size
=
self
.
hidden_size_per_attention_head
,
dtype
=
torch
.
get_default_dtype
())
self
.
use_upstream_fa
=
False
if
self
.
attn_backend
!=
_Backend
.
FLASH_ATTN
and
\
check_upstream_fa_availability
(
torch
.
get_default_dtype
()):
self
.
attn_backend
=
_Backend
.
FLASH_ATTN
self
.
use_upstream_fa
=
True
if
self
.
attn_backend
not
in
{
if
self
.
attn_backend
not
in
{
_Backend
.
FLASH_ATTN
,
_Backend
.
TORCH_SDPA
,
_Backend
.
XFORMERS
,
_Backend
.
FLASH_ATTN
,
_Backend
.
TORCH_SDPA
,
_Backend
.
XFORMERS
,
_Backend
.
ROCM_AITER_FA
_Backend
.
ROCM_AITER_FA
...
@@ -359,7 +369,10 @@ class Qwen2_5_VisionAttention(nn.Module):
...
@@ -359,7 +369,10 @@ class Qwen2_5_VisionAttention(nn.Module):
if
self
.
attn_backend
==
_Backend
.
ROCM_AITER_FA
:
if
self
.
attn_backend
==
_Backend
.
ROCM_AITER_FA
:
from
aiter
import
flash_attn_varlen_func
from
aiter
import
flash_attn_varlen_func
else
:
else
:
if
self
.
use_upstream_fa
:
from
flash_attn
import
flash_attn_varlen_func
from
flash_attn
import
flash_attn_varlen_func
else
:
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
q
,
k
,
v
=
(
rearrange
(
x
,
"b s ... -> (b s) ..."
)
for
x
in
[
q
,
k
,
v
])
q
,
k
,
v
=
(
rearrange
(
x
,
"b s ... -> (b s) ..."
)
for
x
in
[
q
,
k
,
v
])
...
@@ -628,7 +641,12 @@ class Qwen2_5_VisionTransformer(nn.Module):
...
@@ -628,7 +641,12 @@ class Qwen2_5_VisionTransformer(nn.Module):
prefix
=
f
"
{
prefix
}
.merger"
,
prefix
=
f
"
{
prefix
}
.merger"
,
use_data_parallel
=
use_data_parallel
,
use_data_parallel
=
use_data_parallel
,
)
)
self
.
attn_backend
:
_Backend
=
get_vit_attn_backend
(
support_fa
=
True
)
self
.
attn_backend
=
get_vit_attn_backend
(
head_size
=
head_dim
,
dtype
=
torch
.
get_default_dtype
())
if
self
.
attn_backend
!=
_Backend
.
FLASH_ATTN
and
\
check_upstream_fa_availability
(
torch
.
get_default_dtype
()):
self
.
attn_backend
=
_Backend
.
FLASH_ATTN
@
property
@
property
def
dtype
(
self
)
->
torch
.
dtype
:
def
dtype
(
self
)
->
torch
.
dtype
:
...
...
vllm/model_executor/models/qwen2_vl.py
View file @
72fc8aa4
...
@@ -41,6 +41,7 @@ from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
...
@@ -41,6 +41,7 @@ from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
from
transformers.models.qwen2_vl.video_processing_qwen2_vl
import
(
from
transformers.models.qwen2_vl.video_processing_qwen2_vl
import
(
Qwen2VLVideoProcessor
)
Qwen2VLVideoProcessor
)
from
vllm.attention.layer
import
check_upstream_fa_availability
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
parallel_state
,
tensor_model_parallel_all_gather
from
vllm.distributed
import
parallel_state
,
tensor_model_parallel_all_gather
from
vllm.distributed
import
utils
as
dist_utils
from
vllm.distributed
import
utils
as
dist_utils
...
@@ -314,7 +315,16 @@ class Qwen2VisionAttention(nn.Module):
...
@@ -314,7 +315,16 @@ class Qwen2VisionAttention(nn.Module):
prefix
=
f
"
{
prefix
}
.proj"
)
prefix
=
f
"
{
prefix
}
.proj"
)
# Detect attention implementation.
# Detect attention implementation.
self
.
attn_backend
:
_Backend
=
get_vit_attn_backend
(
support_fa
=
True
)
self
.
attn_backend
=
get_vit_attn_backend
(
head_size
=
self
.
hidden_size_per_attention_head
,
dtype
=
torch
.
get_default_dtype
())
self
.
use_upstream_fa
=
False
if
self
.
attn_backend
!=
_Backend
.
FLASH_ATTN
and
\
check_upstream_fa_availability
(
torch
.
get_default_dtype
()):
self
.
attn_backend
=
_Backend
.
FLASH_ATTN
self
.
use_upstream_fa
=
True
if
self
.
attn_backend
not
in
{
if
self
.
attn_backend
not
in
{
_Backend
.
FLASH_ATTN
,
_Backend
.
TORCH_SDPA
,
_Backend
.
XFORMERS
,
_Backend
.
FLASH_ATTN
,
_Backend
.
TORCH_SDPA
,
_Backend
.
XFORMERS
,
_Backend
.
ROCM_AITER_FA
_Backend
.
ROCM_AITER_FA
...
@@ -374,7 +384,10 @@ class Qwen2VisionAttention(nn.Module):
...
@@ -374,7 +384,10 @@ class Qwen2VisionAttention(nn.Module):
if
self
.
attn_backend
==
_Backend
.
ROCM_AITER_FA
:
if
self
.
attn_backend
==
_Backend
.
ROCM_AITER_FA
:
from
aiter
import
flash_attn_varlen_func
from
aiter
import
flash_attn_varlen_func
else
:
else
:
if
self
.
use_upstream_fa
:
from
flash_attn
import
flash_attn_varlen_func
from
flash_attn
import
flash_attn_varlen_func
else
:
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
q
,
k
,
v
=
(
rearrange
(
x
,
"b s ... -> (b s) ..."
)
for
x
in
[
q
,
k
,
v
])
q
,
k
,
v
=
(
rearrange
(
x
,
"b s ... -> (b s) ..."
)
for
x
in
[
q
,
k
,
v
])
...
@@ -628,7 +641,12 @@ class Qwen2VisionTransformer(nn.Module):
...
@@ -628,7 +641,12 @@ class Qwen2VisionTransformer(nn.Module):
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.merger"
,
prefix
=
f
"
{
prefix
}
.merger"
,
)
)
self
.
attn_backend
:
_Backend
=
get_vit_attn_backend
(
support_fa
=
True
)
self
.
attn_backend
=
get_vit_attn_backend
(
head_size
=
head_dim
,
dtype
=
torch
.
get_default_dtype
())
if
self
.
attn_backend
!=
_Backend
.
FLASH_ATTN
and
\
check_upstream_fa_availability
(
torch
.
get_default_dtype
()):
self
.
attn_backend
=
_Backend
.
FLASH_ATTN
@
property
@
property
def
dtype
(
self
)
->
torch
.
dtype
:
def
dtype
(
self
)
->
torch
.
dtype
:
...
...
vllm/model_executor/models/siglip2navit.py
View file @
72fc8aa4
...
@@ -13,6 +13,7 @@ from torch.nn import functional as F
...
@@ -13,6 +13,7 @@ from torch.nn import functional as F
from
transformers
import
Siglip2VisionConfig
from
transformers
import
Siglip2VisionConfig
from
transformers.configuration_utils
import
PretrainedConfig
from
transformers.configuration_utils
import
PretrainedConfig
from
vllm.attention.layer
import
check_upstream_fa_availability
from
vllm.config
import
QuantizationConfig
from
vllm.config
import
QuantizationConfig
from
vllm.distributed
import
divide
,
get_tensor_model_parallel_world_size
from
vllm.distributed
import
divide
,
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.activation
import
get_act_fn
...
@@ -236,7 +237,15 @@ class Siglip2Attention(nn.Module):
...
@@ -236,7 +237,15 @@ class Siglip2Attention(nn.Module):
self
.
use_rope
=
config
.
use_rope
self
.
use_rope
=
config
.
use_rope
# Detect attention implementation.
# Detect attention implementation.
self
.
attn_backend
:
_Backend
=
get_vit_attn_backend
(
support_fa
=
True
)
self
.
attn_backend
=
get_vit_attn_backend
(
head_size
=
self
.
head_dim
,
dtype
=
torch
.
get_default_dtype
())
self
.
use_upstream_fa
=
False
if
self
.
attn_backend
!=
_Backend
.
FLASH_ATTN
and
\
check_upstream_fa_availability
(
torch
.
get_default_dtype
()):
self
.
attn_backend
=
_Backend
.
FLASH_ATTN
self
.
use_upstream_fa
=
True
if
self
.
attn_backend
not
in
{
if
self
.
attn_backend
not
in
{
_Backend
.
FLASH_ATTN
,
_Backend
.
TORCH_SDPA
,
_Backend
.
FLASH_ATTN
,
_Backend
.
TORCH_SDPA
,
_Backend
.
ROCM_AITER_FA
_Backend
.
ROCM_AITER_FA
...
@@ -280,7 +289,10 @@ class Siglip2Attention(nn.Module):
...
@@ -280,7 +289,10 @@ class Siglip2Attention(nn.Module):
if
self
.
attn_backend
==
_Backend
.
ROCM_AITER_FA
:
if
self
.
attn_backend
==
_Backend
.
ROCM_AITER_FA
:
from
aiter
import
flash_attn_varlen_func
from
aiter
import
flash_attn_varlen_func
else
:
else
:
if
self
.
use_upstream_fa
:
from
flash_attn
import
flash_attn_varlen_func
from
flash_attn
import
flash_attn_varlen_func
else
:
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
attn_output
=
flash_attn_varlen_func
(
attn_output
=
flash_attn_varlen_func
(
queries
,
keys
,
values
,
cu_seqlens
,
cu_seqlens
,
max_seqlen
,
queries
,
keys
,
values
,
cu_seqlens
,
cu_seqlens
,
max_seqlen
,
max_seqlen
).
reshape
(
seq_length
,
-
1
)
max_seqlen
).
reshape
(
seq_length
,
-
1
)
...
...
vllm/model_executor/models/vision.py
View file @
72fc8aa4
...
@@ -7,7 +7,6 @@ from typing import Final, Generic, Optional, Protocol, TypeVar, Union
...
@@ -7,7 +7,6 @@ from typing import Final, Generic, Optional, Protocol, TypeVar, Union
import
torch
import
torch
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
from
vllm.attention.selector
import
get_env_variable_attn_backend
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
_Backend
,
current_platform
from
vllm.platforms
import
_Backend
,
current_platform
...
@@ -68,17 +67,18 @@ def get_vision_encoder_info(
...
@@ -68,17 +67,18 @@ def get_vision_encoder_info(
raise
NotImplementedError
(
msg
)
raise
NotImplementedError
(
msg
)
def
get_vit_attn_backend
(
support_fa
:
bool
=
Fals
e
)
->
_Backend
:
def
get_vit_attn_backend
(
head_size
:
int
,
dtype
:
torch
.
dtyp
e
)
->
_Backend
:
"""
"""
Get the available attention backend for Vision Transformer.
Get the available attention backend for Vision Transformer.
"""
"""
# TODO(Isotr0py): Remove `support_fa` after support FA for all ViTs attn.
# Lazy import to avoid circular dependency
from
vllm.attention.selector
import
get_env_variable_attn_backend
selected_backend
:
Optional
[
_Backend
]
=
get_env_variable_attn_backend
()
selected_backend
:
Optional
[
_Backend
]
=
get_env_variable_attn_backend
()
if
selected_backend
is
not
None
:
if
selected_backend
is
not
None
:
return
selected_backend
return
selected_backend
return
current_platform
.
get_vit_attn_backend
(
support_fa
)
return
current_platform
.
get_vit_attn_backend
(
head_size
,
dtype
)
def
resolve_visual_encoder_outputs
(
def
resolve_visual_encoder_outputs
(
...
...
vllm/platforms/cuda.py
View file @
72fc8aa4
...
@@ -209,16 +209,22 @@ class CudaPlatformBase(Platform):
...
@@ -209,16 +209,22 @@ class CudaPlatformBase(Platform):
return
torch
.
cuda
.
max_memory_allocated
(
device
)
return
torch
.
cuda
.
max_memory_allocated
(
device
)
@
classmethod
@
classmethod
def
get_vit_attn_backend
(
cls
,
support_fa
:
bool
=
False
)
->
_Backend
:
def
get_vit_attn_backend
(
cls
,
head_size
:
int
,
if
cls
.
has_device_capability
(
80
)
and
support_fa
:
dtype
:
torch
.
dtype
)
->
_Backend
:
from
transformers.utils
import
is_flash_attn_2_available
if
dtype
not
in
(
torch
.
float16
,
torch
.
bfloat16
):
if
is_flash_attn_2_available
():
return
_Backend
.
XFORMERS
if
cls
.
has_device_capability
(
80
):
FLASH_ATTN_V1
=
"vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
# noqa: E501
from
vllm.attention.selector
import
is_attn_backend_supported
is_default_fa_supported
=
is_attn_backend_supported
(
FLASH_ATTN_V1
,
head_size
,
dtype
,
allow_import_error
=
False
)
if
is_default_fa_supported
:
return
_Backend
.
FLASH_ATTN
return
_Backend
.
FLASH_ATTN
logger
.
warning_once
(
else
:
"Current `vllm-flash-attn` has a bug inside vision "
# Fallback to XFORMERS
"module, so we use xformers backend instead. You can "
return
_Backend
.
XFORMERS
"run `pip install flash-attn` to use flash-attention "
else
:
"backend."
)
# Fallback for Volta/Turing GPUs or FA not supported
# Fallback for Volta/Turing GPUs or FA not supported
return
_Backend
.
XFORMERS
return
_Backend
.
XFORMERS
...
...
vllm/platforms/interface.py
View file @
72fc8aa4
...
@@ -192,7 +192,8 @@ class Platform:
...
@@ -192,7 +192,8 @@ class Platform:
return
device_id
return
device_id
@
classmethod
@
classmethod
def
get_vit_attn_backend
(
cls
,
support_fa
:
bool
=
False
)
->
_Backend
:
def
get_vit_attn_backend
(
cls
,
head_size
:
int
,
dtype
:
torch
.
dtype
)
->
_Backend
:
return
_Backend
.
TORCH_SDPA
return
_Backend
.
TORCH_SDPA
@
classmethod
@
classmethod
...
...
vllm/platforms/rocm.py
View file @
72fc8aa4
...
@@ -175,8 +175,8 @@ class RocmPlatform(Platform):
...
@@ -175,8 +175,8 @@ class RocmPlatform(Platform):
]
]
@
classmethod
@
classmethod
def
get_vit_attn_backend
(
cls
,
support_fa
:
bool
=
False
)
->
_Backend
:
def
get_vit_attn_backend
(
cls
,
head_size
:
int
,
if
support_fa
:
dtype
:
torch
.
dtype
)
->
_Backend
:
if
(
envs
.
VLLM_ROCM_USE_AITER
and
envs
.
VLLM_ROCM_USE_AITER_MHA
if
(
envs
.
VLLM_ROCM_USE_AITER
and
envs
.
VLLM_ROCM_USE_AITER_MHA
and
on_gfx9
()):
and
on_gfx9
()):
# Note: AITER FA is only supported for Qwen-VL models.
# Note: AITER FA is only supported for Qwen-VL models.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment