Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c242c980
Unverified
Commit
c242c980
authored
Sep 26, 2025
by
Wentao Ye
Committed by
GitHub
Sep 26, 2025
Browse files
[Bugfix] Allow Only SDPA Backend for ViT on B200 for Qwen3-VL (#25788)
parent
f1d53d15
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
75 additions
and
51 deletions
+75
-51
vllm/model_executor/models/qwen2_5_vl.py
vllm/model_executor/models/qwen2_5_vl.py
+37
-36
vllm/model_executor/models/qwen3_vl.py
vllm/model_executor/models/qwen3_vl.py
+38
-15
No files found.
vllm/model_executor/models/qwen2_5_vl.py
View file @
c242c980
...
@@ -274,6 +274,8 @@ class Qwen2_5_VisionAttention(nn.Module):
...
@@ -274,6 +274,8 @@ class Qwen2_5_VisionAttention(nn.Module):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
use_data_parallel
:
bool
=
False
,
use_data_parallel
:
bool
=
False
,
attn_backend
:
_Backend
=
_Backend
.
TORCH_SDPA
,
use_upstream_fa
:
bool
=
False
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
# Per attention head and per partition values.
# Per attention head and per partition values.
...
@@ -300,25 +302,8 @@ class Qwen2_5_VisionAttention(nn.Module):
...
@@ -300,25 +302,8 @@ class Qwen2_5_VisionAttention(nn.Module):
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.proj"
,
prefix
=
f
"
{
prefix
}
.proj"
,
disable_tp
=
use_data_parallel
)
disable_tp
=
use_data_parallel
)
self
.
attn_backend
=
attn_backend
# Detect attention implementation.
self
.
use_upstream_fa
=
use_upstream_fa
self
.
attn_backend
=
get_vit_attn_backend
(
head_size
=
self
.
hidden_size_per_attention_head
,
dtype
=
torch
.
get_default_dtype
())
self
.
use_upstream_fa
=
False
if
self
.
attn_backend
!=
_Backend
.
FLASH_ATTN
and
\
check_upstream_fa_availability
(
torch
.
get_default_dtype
()):
self
.
attn_backend
=
_Backend
.
FLASH_ATTN
self
.
use_upstream_fa
=
True
if
self
.
attn_backend
not
in
{
_Backend
.
FLASH_ATTN
,
_Backend
.
TORCH_SDPA
,
_Backend
.
XFORMERS
,
_Backend
.
ROCM_AITER_FA
}:
raise
RuntimeError
(
f
"Qwen2.5-VL does not support
{
self
.
attn_backend
}
backend now."
)
self
.
is_flash_attn_backend
=
self
.
attn_backend
in
{
self
.
is_flash_attn_backend
=
self
.
attn_backend
in
{
_Backend
.
FLASH_ATTN
,
_Backend
.
ROCM_AITER_FA
_Backend
.
FLASH_ATTN
,
_Backend
.
ROCM_AITER_FA
}
}
...
@@ -443,6 +428,8 @@ class Qwen2_5_VisionBlock(nn.Module):
...
@@ -443,6 +428,8 @@ class Qwen2_5_VisionBlock(nn.Module):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
use_data_parallel
:
bool
=
False
,
use_data_parallel
:
bool
=
False
,
attn_backend
:
_Backend
=
_Backend
.
TORCH_SDPA
,
use_upstream_fa
:
bool
=
False
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
if
norm_layer
is
None
:
if
norm_layer
is
None
:
...
@@ -455,7 +442,9 @@ class Qwen2_5_VisionBlock(nn.Module):
...
@@ -455,7 +442,9 @@ class Qwen2_5_VisionBlock(nn.Module):
projection_size
=
dim
,
projection_size
=
dim
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.attn"
,
prefix
=
f
"
{
prefix
}
.attn"
,
use_data_parallel
=
use_data_parallel
)
use_data_parallel
=
use_data_parallel
,
attn_backend
=
attn_backend
,
use_upstream_fa
=
use_upstream_fa
)
self
.
mlp
=
Qwen2_5_VisionMLP
(
dim
,
self
.
mlp
=
Qwen2_5_VisionMLP
(
dim
,
mlp_hidden_dim
,
mlp_hidden_dim
,
act_fn
=
act_fn
,
act_fn
=
act_fn
,
...
@@ -627,17 +616,35 @@ class Qwen2_5_VisionTransformer(nn.Module):
...
@@ -627,17 +616,35 @@ class Qwen2_5_VisionTransformer(nn.Module):
head_dim
=
self
.
hidden_size
//
self
.
num_heads
head_dim
=
self
.
hidden_size
//
self
.
num_heads
self
.
rotary_pos_emb
=
Qwen2_5_VisionRotaryEmbedding
(
head_dim
//
2
)
self
.
rotary_pos_emb
=
Qwen2_5_VisionRotaryEmbedding
(
head_dim
//
2
)
use_upstream_fa
=
False
self
.
attn_backend
=
get_vit_attn_backend
(
head_size
=
head_dim
,
dtype
=
torch
.
get_default_dtype
())
if
self
.
attn_backend
!=
_Backend
.
FLASH_ATTN
and
\
check_upstream_fa_availability
(
torch
.
get_default_dtype
()):
self
.
attn_backend
=
_Backend
.
FLASH_ATTN
use_upstream_fa
=
True
if
self
.
attn_backend
not
in
{
_Backend
.
FLASH_ATTN
,
_Backend
.
TORCH_SDPA
,
_Backend
.
XFORMERS
,
_Backend
.
ROCM_AITER_FA
}:
raise
RuntimeError
(
f
"Qwen2.5-VL does not support
{
self
.
attn_backend
}
backend now."
)
self
.
blocks
=
nn
.
ModuleList
([
self
.
blocks
=
nn
.
ModuleList
([
Qwen2_5_VisionBlock
(
dim
=
self
.
hidden_size
,
Qwen2_5_VisionBlock
(
num_heads
=
self
.
num_heads
,
dim
=
self
.
hidden_size
,
mlp_hidden_dim
=
vision_config
.
intermediate_size
,
num_heads
=
self
.
num_heads
,
act_fn
=
get_act_and_mul_fn
(
mlp_hidden_dim
=
vision_config
.
intermediate_size
,
vision_config
.
hidden_act
),
act_fn
=
get_act_and_mul_fn
(
vision_config
.
hidden_act
),
norm_layer
=
norm_layer
,
norm_layer
=
norm_layer
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.blocks.
{
layer_idx
}
"
,
prefix
=
f
"
{
prefix
}
.blocks.
{
layer_idx
}
"
,
use_data_parallel
=
use_data_parallel
)
use_data_parallel
=
use_data_parallel
,
for
layer_idx
in
range
(
depth
)
attn_backend
=
self
.
attn_backend
,
use_upstream_fa
=
use_upstream_fa
)
for
layer_idx
in
range
(
depth
)
])
])
self
.
merger
=
Qwen2_5_VisionPatchMerger
(
self
.
merger
=
Qwen2_5_VisionPatchMerger
(
d_model
=
vision_config
.
out_hidden_size
,
d_model
=
vision_config
.
out_hidden_size
,
...
@@ -648,12 +655,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
...
@@ -648,12 +655,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
prefix
=
f
"
{
prefix
}
.merger"
,
prefix
=
f
"
{
prefix
}
.merger"
,
use_data_parallel
=
use_data_parallel
,
use_data_parallel
=
use_data_parallel
,
)
)
self
.
attn_backend
=
get_vit_attn_backend
(
head_size
=
head_dim
,
dtype
=
torch
.
get_default_dtype
())
if
self
.
attn_backend
!=
_Backend
.
FLASH_ATTN
and
\
check_upstream_fa_availability
(
torch
.
get_default_dtype
()):
self
.
attn_backend
=
_Backend
.
FLASH_ATTN
@
property
@
property
def
dtype
(
self
)
->
torch
.
dtype
:
def
dtype
(
self
)
->
torch
.
dtype
:
...
...
vllm/model_executor/models/qwen3_vl.py
View file @
c242c980
...
@@ -63,7 +63,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
...
@@ -63,7 +63,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptReplacement
,
PromptUpdate
,
PromptReplacement
,
PromptUpdate
,
PromptUpdateDetails
)
PromptUpdateDetails
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.platforms
import
_Backend
from
vllm.platforms
import
_Backend
,
current_platform
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.config
import
uses_mrope
from
vllm.transformers_utils.config
import
uses_mrope
from
vllm.utils
import
is_list_of
from
vllm.utils
import
is_list_of
...
@@ -158,6 +158,8 @@ class Qwen3_VisionBlock(nn.Module):
...
@@ -158,6 +158,8 @@ class Qwen3_VisionBlock(nn.Module):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
use_data_parallel
:
bool
=
False
,
use_data_parallel
:
bool
=
False
,
attn_backend
:
_Backend
=
_Backend
.
TORCH_SDPA
,
use_upstream_fa
:
bool
=
False
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
if
norm_layer
is
None
:
if
norm_layer
is
None
:
...
@@ -170,7 +172,9 @@ class Qwen3_VisionBlock(nn.Module):
...
@@ -170,7 +172,9 @@ class Qwen3_VisionBlock(nn.Module):
projection_size
=
dim
,
projection_size
=
dim
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.attn"
,
prefix
=
f
"
{
prefix
}
.attn"
,
use_data_parallel
=
use_data_parallel
)
use_data_parallel
=
use_data_parallel
,
attn_backend
=
attn_backend
,
use_upstream_fa
=
use_upstream_fa
)
self
.
mlp
=
Qwen3_VisionMLP
(
dim
,
self
.
mlp
=
Qwen3_VisionMLP
(
dim
,
mlp_hidden_dim
,
mlp_hidden_dim
,
act_fn
=
act_fn
,
act_fn
=
act_fn
,
...
@@ -287,19 +291,6 @@ class Qwen3_VisionTransformer(nn.Module):
...
@@ -287,19 +291,6 @@ class Qwen3_VisionTransformer(nn.Module):
head_dim
=
self
.
hidden_size
//
self
.
num_heads
head_dim
=
self
.
hidden_size
//
self
.
num_heads
self
.
rotary_pos_emb
=
Qwen2_5_VisionRotaryEmbedding
(
head_dim
//
2
)
self
.
rotary_pos_emb
=
Qwen2_5_VisionRotaryEmbedding
(
head_dim
//
2
)
self
.
blocks
=
nn
.
ModuleList
([
Qwen3_VisionBlock
(
dim
=
self
.
hidden_size
,
num_heads
=
self
.
num_heads
,
mlp_hidden_dim
=
vision_config
.
intermediate_size
,
act_fn
=
_ACTIVATION_REGISTRY
[
vision_config
.
hidden_act
],
norm_layer
=
norm_layer
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.blocks.
{
layer_idx
}
"
,
use_data_parallel
=
use_data_parallel
)
for
layer_idx
in
range
(
vision_config
.
depth
)
])
self
.
merger
=
Qwen3_VisionPatchMerger
(
self
.
merger
=
Qwen3_VisionPatchMerger
(
d_model
=
vision_config
.
out_hidden_size
,
d_model
=
vision_config
.
out_hidden_size
,
context_dim
=
self
.
hidden_size
,
context_dim
=
self
.
hidden_size
,
...
@@ -325,10 +316,42 @@ class Qwen3_VisionTransformer(nn.Module):
...
@@ -325,10 +316,42 @@ class Qwen3_VisionTransformer(nn.Module):
self
.
attn_backend
=
get_vit_attn_backend
(
self
.
attn_backend
=
get_vit_attn_backend
(
head_size
=
head_dim
,
dtype
=
torch
.
get_default_dtype
())
head_size
=
head_dim
,
dtype
=
torch
.
get_default_dtype
())
use_upstream_fa
=
False
if
self
.
attn_backend
!=
_Backend
.
FLASH_ATTN
and
\
if
self
.
attn_backend
!=
_Backend
.
FLASH_ATTN
and
\
check_upstream_fa_availability
(
check_upstream_fa_availability
(
torch
.
get_default_dtype
()):
torch
.
get_default_dtype
()):
self
.
attn_backend
=
_Backend
.
FLASH_ATTN
self
.
attn_backend
=
_Backend
.
FLASH_ATTN
use_upstream_fa
=
True
if
self
.
attn_backend
not
in
{
_Backend
.
FLASH_ATTN
,
_Backend
.
TORCH_SDPA
,
_Backend
.
XFORMERS
,
_Backend
.
ROCM_AITER_FA
}:
raise
RuntimeError
(
f
"Qwen3-VL does not support
{
self
.
attn_backend
}
backend now."
)
if
current_platform
.
is_device_capability
(
100
)
and
self
.
attn_backend
!=
_Backend
.
TORCH_SDPA
:
# TODO(Roger/Wentao): remove this after FA
# or XFORMERS's issue fixed on Blackwell
logger
.
info_once
(
"Qwen3-VL vision attention does not support "
f
"
{
self
.
attn_backend
}
backend on Blackwell now. "
"Vision attention backend is set to TORCH_SDPA."
)
self
.
attn_backend
=
_Backend
.
TORCH_SDPA
self
.
blocks
=
nn
.
ModuleList
([
Qwen3_VisionBlock
(
dim
=
self
.
hidden_size
,
num_heads
=
self
.
num_heads
,
mlp_hidden_dim
=
vision_config
.
intermediate_size
,
act_fn
=
_ACTIVATION_REGISTRY
[
vision_config
.
hidden_act
],
norm_layer
=
norm_layer
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.blocks.
{
layer_idx
}
"
,
use_data_parallel
=
use_data_parallel
,
attn_backend
=
self
.
attn_backend
,
use_upstream_fa
=
use_upstream_fa
)
for
layer_idx
in
range
(
vision_config
.
depth
)
])
@
property
@
property
def
dtype
(
self
)
->
torch
.
dtype
:
def
dtype
(
self
)
->
torch
.
dtype
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment