Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c3a2c6ac
Unverified
Commit
c3a2c6ac
authored
Oct 21, 2025
by
Roger Wang
Committed by
GitHub
Oct 21, 2025
Browse files
[MM][Core] Decouple ViT backend from LM backend (#27061)
Signed-off-by:
Roger Wang
<
hey@rogerw.io
>
parent
72f431e7
Changes
16
Show whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
230 additions
and
17 deletions
+230
-17
tests/config/test_multimodal_config.py
tests/config/test_multimodal_config.py
+25
-0
vllm/attention/layer.py
vllm/attention/layer.py
+10
-1
vllm/config/model.py
vllm/config/model.py
+5
-0
vllm/config/multimodal.py
vllm/config/multimodal.py
+38
-4
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+9
-0
vllm/model_executor/models/dots_ocr.py
vllm/model_executor/models/dots_ocr.py
+17
-2
vllm/model_executor/models/ernie45_vl.py
vllm/model_executor/models/ernie45_vl.py
+15
-1
vllm/model_executor/models/glm4_1v.py
vllm/model_executor/models/glm4_1v.py
+15
-1
vllm/model_executor/models/keye.py
vllm/model_executor/models/keye.py
+18
-1
vllm/model_executor/models/ovis2_5.py
vllm/model_executor/models/ovis2_5.py
+12
-0
vllm/model_executor/models/qwen2_5_vl.py
vllm/model_executor/models/qwen2_5_vl.py
+10
-1
vllm/model_executor/models/qwen2_vl.py
vllm/model_executor/models/qwen2_vl.py
+15
-1
vllm/model_executor/models/qwen3_omni_moe_thinker.py
vllm/model_executor/models/qwen3_omni_moe_thinker.py
+10
-1
vllm/model_executor/models/qwen3_vl.py
vllm/model_executor/models/qwen3_vl.py
+10
-2
vllm/model_executor/models/siglip2navit.py
vllm/model_executor/models/siglip2navit.py
+12
-1
vllm/model_executor/models/vision.py
vllm/model_executor/models/vision.py
+9
-1
No files found.
tests/config/test_multimodal_config.py
0 → 100644
View file @
c3a2c6ac
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
from
vllm.attention.backends.registry
import
_Backend
from
vllm.config.multimodal
import
MultiModalConfig
def test_mm_encoder_attn_backend_str_conversion():
    """A backend supplied by name ends up as the matching `_Backend` member."""
    mm_config = MultiModalConfig(mm_encoder_attn_backend="FLASH_ATTN")
    resolved_backend = mm_config.mm_encoder_attn_backend
    assert resolved_backend == _Backend.FLASH_ATTN
def test_mm_encoder_attn_backend_invalid():
    """Constructing the config with an unknown backend name raises `ValueError`."""
    bad_kwargs = {"mm_encoder_attn_backend": "not_a_backend"}
    with pytest.raises(ValueError):
        MultiModalConfig(**bad_kwargs)
def test_mm_encoder_attn_backend_hash_updates():
    """Setting the encoder attention backend yields a different config hash."""
    # Hash of a config built with defaults only.
    default_config = MultiModalConfig()
    default_digest = default_config.compute_hash()

    # Hash of a config that overrides the encoder attention backend.
    flash_config = MultiModalConfig(mm_encoder_attn_backend=_Backend.FLASH_ATTN)
    flash_digest = flash_config.compute_hash()

    assert default_digest != flash_digest
vllm/attention/layer.py
View file @
c3a2c6ac
...
...
@@ -16,6 +16,7 @@ from vllm.attention.backends.registry import _Backend, backend_name_to_enum
from
vllm.attention.selector
import
get_attn_backend
from
vllm.attention.utils.kv_sharing_utils
import
validate_kv_sharing_target
from
vllm.config
import
CacheConfig
,
get_current_vllm_config
from
vllm.config.multimodal
import
MultiModalConfig
from
vllm.config.vllm
import
VllmConfig
from
vllm.distributed.kv_transfer
import
(
get_kv_transfer_group
,
...
...
@@ -443,6 +444,7 @@ class MultiHeadAttention(nn.Module):
# This has no effect, it is only here to make it easier to swap
# between Attention and MultiHeadAttention
prefix
:
str
=
""
,
multimodal_config
:
MultiModalConfig
|
None
=
None
,
)
->
None
:
super
().
__init__
()
self
.
num_heads
=
num_heads
...
...
@@ -462,7 +464,14 @@ class MultiHeadAttention(nn.Module):
dtype
=
torch
.
get_default_dtype
()
# Determine the attention backend
backend
=
get_vit_attn_backend
(
head_size
=
head_size
,
dtype
=
dtype
)
attn_backend_override
=
None
if
multimodal_config
is
not
None
:
attn_backend_override
=
multimodal_config
.
mm_encoder_attn_backend
backend
=
get_vit_attn_backend
(
head_size
=
head_size
,
dtype
=
dtype
,
attn_backend_override
=
attn_backend_override
,
)
# Some auto-selected backends can be upgraded
# to upstream flash attention if available.
...
...
vllm/config/model.py
View file @
c3a2c6ac
...
...
@@ -50,6 +50,7 @@ if TYPE_CHECKING:
import
vllm.model_executor.layers.quantization
as
me_quant
import
vllm.model_executor.models
as
me_models
from
vllm.attention.backends.registry
import
_Backend
from
vllm.config.load
import
LoadConfig
from
vllm.config.parallel
import
ParallelConfig
from
vllm.model_executor.layers.quantization
import
QuantizationMethods
...
...
@@ -57,6 +58,7 @@ if TYPE_CHECKING:
else
:
PretrainedConfig
=
Any
_Backend
=
Any
me_quant
=
LazyLoader
(
"model_executor"
,
globals
(),
"vllm.model_executor.layers.quantization"
)
...
...
@@ -307,6 +309,7 @@ class ModelConfig:
mm_processor_cache_type
:
InitVar
[
MMCacheType
|
None
]
=
None
mm_shm_cache_max_object_size_mb
:
InitVar
[
int
|
None
]
=
None
mm_encoder_tp_mode
:
InitVar
[
MMEncoderTPMode
|
None
]
=
None
mm_encoder_attn_backend
:
InitVar
[
_Backend
|
str
|
None
]
=
None
interleave_mm_strings
:
InitVar
[
bool
|
None
]
=
None
skip_mm_profiling
:
InitVar
[
bool
|
None
]
=
None
video_pruning_rate
:
InitVar
[
float
|
None
]
=
None
...
...
@@ -424,6 +427,7 @@ class ModelConfig:
mm_processor_cache_type
:
MMCacheType
|
None
,
mm_shm_cache_max_object_size_mb
:
int
|
None
,
mm_encoder_tp_mode
:
MMEncoderTPMode
|
None
,
mm_encoder_attn_backend
:
_Backend
|
str
|
None
,
interleave_mm_strings
:
bool
|
None
,
skip_mm_profiling
:
bool
|
None
,
video_pruning_rate
:
float
|
None
,
...
...
@@ -733,6 +737,7 @@ class ModelConfig:
mm_processor_cache_type
=
mm_processor_cache_type
,
mm_shm_cache_max_object_size_mb
=
mm_shm_cache_max_object_size_mb
,
mm_encoder_tp_mode
=
mm_encoder_tp_mode
,
mm_encoder_attn_backend
=
mm_encoder_attn_backend
,
interleave_mm_strings
=
interleave_mm_strings
,
skip_mm_profiling
=
skip_mm_profiling
,
video_pruning_rate
=
video_pruning_rate
,
...
...
vllm/config/multimodal.py
View file @
c3a2c6ac
...
...
@@ -3,13 +3,18 @@
import
hashlib
from
collections.abc
import
Mapping
from
typing
import
Any
,
Literal
,
TypeAlias
from
typing
import
TYPE_CHECKING
,
Any
,
Literal
,
TypeAlias
from
pydantic
import
ConfigDict
,
Field
,
field_validator
,
model_validator
from
pydantic.dataclasses
import
dataclass
from
vllm.config.utils
import
config
if
TYPE_CHECKING
:
from
vllm.attention.backends.registry
import
_Backend
else
:
_Backend
=
Any
@
dataclass
class
BaseDummyOptions
:
...
...
@@ -112,6 +117,10 @@ class MultiModalConfig:
DP (which is controlled by `--data-parallel-size`).
This is only supported on a per-model basis and falls back to
`"weights"` if the encoder does not support DP."""
mm_encoder_attn_backend
:
_Backend
|
None
=
None
"""Optional override for the multi-modal encoder attention backend when
using vision transformers. Accepts any value from
`vllm.attention.backends.registry._Backend` (e.g. `FLASH_ATTN`)."""
interleave_mm_strings
:
bool
=
False
"""Enable fully interleaved support for multimodal prompts, while using
--chat-template-content-format=string."""
...
...
@@ -148,6 +157,29 @@ class MultiModalConfig:
value
[
k
]
=
BaseDummyOptions
(
**
v
)
return
value
@field_validator("mm_encoder_attn_backend", mode="before")
@classmethod
def _validate_mm_encoder_attn_backend(cls, value: object) -> _Backend | None:
    """Coerce the raw `mm_encoder_attn_backend` input into a `_Backend`.

    `None` and existing `_Backend` members are returned unchanged. A string
    is upper-cased and looked up with `backend_name_to_enum`. Any other
    input, or a string for which the lookup returns `None`, raises a
    `ValueError` listing the valid backend names.
    """
    # Imported inside the method; at module level `_Backend` is only
    # imported under TYPE_CHECKING.
    from vllm.attention.backends.registry import _Backend as BackendEnum
    from vllm.attention.backends.registry import backend_name_to_enum

    if value is None:
        return value
    if isinstance(value, BackendEnum):
        return value

    resolved = None
    if isinstance(value, str):
        resolved = backend_name_to_enum(value.upper())
    if resolved is not None:
        return resolved

    valid_backends = ", ".join(sorted(BackendEnum.__members__.keys()))
    raise ValueError(
        f"Invalid mm encoder attention backend. Expected one of: {valid_backends}."
    )
@
model_validator
(
mode
=
"after"
)
def
_validate_multimodal_config
(
self
):
if
self
.
mm_processor_cache_type
!=
"shm"
and
(
...
...
@@ -172,9 +204,11 @@ class MultiModalConfig:
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# no factors to consider.
# this config will not affect the computation graph.
factors
:
list
[
Any
]
=
[]
factors
:
list
[
Any
]
=
[
self
.
mm_encoder_attn_backend
.
name
if
self
.
mm_encoder_attn_backend
is
not
None
else
None
]
hash_str
=
hashlib
.
md5
(
str
(
factors
).
encode
(),
usedforsecurity
=
False
).
hexdigest
()
return
hash_str
...
...
vllm/engine/arg_utils.py
View file @
c3a2c6ac
...
...
@@ -32,6 +32,7 @@ from pydantic.fields import FieldInfo
from
typing_extensions
import
TypeIs
,
deprecated
import
vllm.envs
as
envs
from
vllm.attention.backends.registry
import
_Backend
from
vllm.config
import
(
CacheConfig
,
CompilationConfig
,
...
...
@@ -451,6 +452,9 @@ class EngineArgs:
MultiModalConfig
.
mm_shm_cache_max_object_size_mb
)
mm_encoder_tp_mode
:
MMEncoderTPMode
=
MultiModalConfig
.
mm_encoder_tp_mode
mm_encoder_attn_backend
:
_Backend
|
str
|
None
=
(
MultiModalConfig
.
mm_encoder_attn_backend
)
io_processor_plugin
:
str
|
None
=
None
skip_mm_profiling
:
bool
=
MultiModalConfig
.
skip_mm_profiling
video_pruning_rate
:
float
=
MultiModalConfig
.
video_pruning_rate
...
...
@@ -914,6 +918,10 @@ class EngineArgs:
multimodal_group
.
add_argument
(
"--mm-encoder-tp-mode"
,
**
multimodal_kwargs
[
"mm_encoder_tp_mode"
]
)
multimodal_group
.
add_argument
(
"--mm-encoder-attn-backend"
,
**
multimodal_kwargs
[
"mm_encoder_attn_backend"
],
)
multimodal_group
.
add_argument
(
"--interleave-mm-strings"
,
**
multimodal_kwargs
[
"interleave_mm_strings"
]
)
...
...
@@ -1160,6 +1168,7 @@ class EngineArgs:
mm_processor_cache_type
=
self
.
mm_processor_cache_type
,
mm_shm_cache_max_object_size_mb
=
self
.
mm_shm_cache_max_object_size_mb
,
mm_encoder_tp_mode
=
self
.
mm_encoder_tp_mode
,
mm_encoder_attn_backend
=
self
.
mm_encoder_attn_backend
,
pooler_config
=
self
.
pooler_config
,
override_pooler_config
=
self
.
override_pooler_config
,
logits_processor_pattern
=
self
.
logits_processor_pattern
,
...
...
vllm/model_executor/models/dots_ocr.py
View file @
c3a2c6ac
...
...
@@ -256,6 +256,7 @@ class DotsVisionAttention(nn.Module):
quant_config
:
QuantizationConfig
|
None
=
None
,
prefix
:
str
=
""
,
use_data_parallel
:
bool
=
False
,
attn_backend_override
:
_Backend
|
None
=
None
,
)
->
None
:
super
().
__init__
()
...
...
@@ -288,7 +289,9 @@ class DotsVisionAttention(nn.Module):
)
# Select attention backend
self
.
attn_backend
=
get_vit_attn_backend
(
self
.
hidden_size_per_attention_head
,
torch
.
get_default_dtype
()
self
.
hidden_size_per_attention_head
,
torch
.
get_default_dtype
(),
attn_backend_override
=
attn_backend_override
,
)
self
.
use_upstream_fa
=
False
...
...
@@ -510,6 +513,7 @@ class DotsVisionBlock(nn.Module):
quant_config
:
QuantizationConfig
|
None
=
None
,
prefix
:
str
=
""
,
use_data_parallel
:
bool
=
False
,
attn_backend_override
:
_Backend
|
None
=
None
,
):
super
().
__init__
()
...
...
@@ -521,6 +525,7 @@ class DotsVisionBlock(nn.Module):
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.attn"
,
use_data_parallel
=
use_data_parallel
,
attn_backend_override
=
attn_backend_override
,
)
self
.
norm1
=
RMSNorm
(
config
.
embed_dim
,
eps
=
config
.
rms_norm_eps
)
self
.
mlp
=
DotsSwiGLUFFN
(
...
...
@@ -561,6 +566,7 @@ class DotsVisionTransformer(nn.Module):
require_post_norm
:
bool
|
None
=
None
,
prefix
:
str
=
""
,
use_data_parallel
:
bool
=
False
,
attn_backend_override
:
_Backend
|
None
=
None
,
)
->
None
:
super
().
__init__
()
self
.
config
=
config
...
...
@@ -571,7 +577,9 @@ class DotsVisionTransformer(nn.Module):
head_dim
=
config
.
embed_dim
//
config
.
num_attention_heads
self
.
rotary_pos_emb
=
VisionRotaryEmbedding
(
head_dim
//
2
)
self
.
attn_backend
=
get_vit_attn_backend
(
head_size
=
head_dim
,
dtype
=
torch
.
get_default_dtype
()
head_size
=
head_dim
,
dtype
=
torch
.
get_default_dtype
(),
attn_backend_override
=
attn_backend_override
,
)
if
self
.
attn_backend
!=
_Backend
.
FLASH_ATTN
and
check_upstream_fa_availability
(
torch
.
get_default_dtype
()
...
...
@@ -591,6 +599,7 @@ class DotsVisionTransformer(nn.Module):
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.blocks.
{
i
}
"
,
use_data_parallel
=
use_data_parallel
,
attn_backend_override
=
attn_backend_override
,
)
for
i
in
range
(
num_layers
)
]
...
...
@@ -750,11 +759,17 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA
self
.
config
.
vision_config
=
vision_config
else
:
vision_config
=
self
.
config
.
vision_config
attn_backend_override
=
(
multimodal_config
.
mm_encoder_attn_backend
if
multimodal_config
is
not
None
else
None
)
self
.
vision_tower
=
DotsVisionTransformer
(
vision_config
,
quant_config
=
self
.
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"vision_tower"
),
use_data_parallel
=
self
.
use_data_parallel
,
attn_backend_override
=
attn_backend_override
,
)
self
.
language_model
:
Qwen2ForCausalLM
=
init_vllm_registered_model
(
vllm_config
=
vllm_config
,
...
...
vllm/model_executor/models/ernie45_vl.py
View file @
c3a2c6ac
...
...
@@ -164,6 +164,7 @@ class Ernie4_5_VisionAttention(nn.Module):
projection_size
:
int
,
quant_config
:
QuantizationConfig
|
None
=
None
,
prefix
:
str
=
""
,
attn_backend_override
:
_Backend
|
None
=
None
,
)
->
None
:
super
().
__init__
()
# Per attention head and per partition values.
...
...
@@ -196,6 +197,7 @@ class Ernie4_5_VisionAttention(nn.Module):
self
.
attn_backend
=
get_vit_attn_backend
(
head_size
=
self
.
hidden_size_per_attention_head
,
dtype
=
torch
.
get_default_dtype
(),
attn_backend_override
=
attn_backend_override
,
)
self
.
use_upstream_fa
=
False
...
...
@@ -367,6 +369,7 @@ class Ernie4_5_VisionBlock(nn.Module):
norm_layer
:
Callable
[[
int
],
nn
.
Module
]
|
None
=
None
,
quant_config
:
QuantizationConfig
|
None
=
None
,
prefix
:
str
=
""
,
attn_backend_override
:
_Backend
|
None
=
None
,
)
->
None
:
super
().
__init__
()
...
...
@@ -382,6 +385,7 @@ class Ernie4_5_VisionBlock(nn.Module):
projection_size
=
dim
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.attn"
,
attn_backend_override
=
attn_backend_override
,
)
self
.
mlp
=
Ernie4_5_VisionMLP
(
...
...
@@ -458,6 +462,7 @@ class Ernie4_5_VisionTransformer(nn.Module):
norm_eps
:
float
=
1e-6
,
quant_config
:
QuantizationConfig
|
None
=
None
,
prefix
:
str
=
""
,
attn_backend_override
:
_Backend
|
None
=
None
,
)
->
None
:
super
().
__init__
()
patch_size
=
vision_config
.
patch_size
...
...
@@ -493,6 +498,7 @@ class Ernie4_5_VisionTransformer(nn.Module):
norm_layer
=
norm_layer
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.blocks.
{
layer_idx
}
"
,
attn_backend_override
=
attn_backend_override
,
)
for
layer_idx
in
range
(
depth
)
]
...
...
@@ -504,7 +510,9 @@ class Ernie4_5_VisionTransformer(nn.Module):
self
.
ln
=
nn
.
LayerNorm
(
hidden_size
,
eps
=
1e-6
)
self
.
attn_backend
=
get_vit_attn_backend
(
head_size
=
head_dim
,
dtype
=
torch
.
get_default_dtype
()
head_size
=
head_dim
,
dtype
=
torch
.
get_default_dtype
(),
attn_backend_override
=
attn_backend_override
,
)
if
self
.
attn_backend
!=
_Backend
.
FLASH_ATTN
and
check_upstream_fa_availability
(
torch
.
get_default_dtype
()
...
...
@@ -1327,11 +1335,17 @@ class Ernie4_5_VLMoeForConditionalGeneration(
self
.
config
=
config
self
.
multimodal_config
=
multimodal_config
attn_backend_override
=
(
multimodal_config
.
mm_encoder_attn_backend
if
multimodal_config
is
not
None
else
None
)
self
.
vision_model
=
Ernie4_5_VisionTransformer
(
config
.
vision_config
,
norm_eps
=
getattr
(
config
,
"rms_norm_eps"
,
1e-6
),
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"vision_model"
),
attn_backend_override
=
attn_backend_override
,
)
self
.
language_model
=
Ernie4_5_VLMoeForCausalLM
(
...
...
vllm/model_executor/models/glm4_1v.py
View file @
c3a2c6ac
...
...
@@ -247,6 +247,7 @@ class Glm4vVisionAttention(nn.Module):
quant_config
:
QuantizationConfig
|
None
=
None
,
prefix
:
str
=
""
,
use_data_parallel
:
bool
=
False
,
attn_backend_override
:
_Backend
|
None
=
None
,
)
->
None
:
super
().
__init__
()
# Per attention head and per partition values.
...
...
@@ -287,6 +288,7 @@ class Glm4vVisionAttention(nn.Module):
self
.
attn_backend
=
get_vit_attn_backend
(
head_size
=
self
.
hidden_size_per_attention_head
,
dtype
=
torch
.
get_default_dtype
(),
attn_backend_override
=
attn_backend_override
,
)
self
.
use_upstream_fa
=
False
...
...
@@ -417,6 +419,7 @@ class Glm4vVisionBlock(nn.Module):
quant_config
:
QuantizationConfig
|
None
=
None
,
prefix
:
str
=
""
,
use_data_parallel
:
bool
=
False
,
attn_backend_override
:
_Backend
|
None
=
None
,
)
->
None
:
super
().
__init__
()
if
norm_layer
is
None
:
...
...
@@ -430,6 +433,7 @@ class Glm4vVisionBlock(nn.Module):
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.attn"
,
use_data_parallel
=
use_data_parallel
,
attn_backend_override
=
attn_backend_override
,
)
self
.
mlp
=
Glm4vVisionMLP
(
dim
,
...
...
@@ -696,6 +700,7 @@ class Glm4vVisionTransformer(nn.Module):
quant_config
:
QuantizationConfig
|
None
=
None
,
prefix
:
str
=
""
,
use_data_parallel
:
bool
=
False
,
attn_backend_override
:
_Backend
|
None
=
None
,
)
->
None
:
super
().
__init__
()
...
...
@@ -731,6 +736,7 @@ class Glm4vVisionTransformer(nn.Module):
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.blocks.
{
layer_idx
}
"
,
use_data_parallel
=
self
.
use_data_parallel
,
attn_backend_override
=
attn_backend_override
,
)
for
layer_idx
in
range
(
depth
)
]
...
...
@@ -759,7 +765,9 @@ class Glm4vVisionTransformer(nn.Module):
)
self
.
attn_backend
=
get_vit_attn_backend
(
head_size
=
head_dim
,
dtype
=
torch
.
get_default_dtype
()
head_size
=
head_dim
,
dtype
=
torch
.
get_default_dtype
(),
attn_backend_override
=
attn_backend_override
,
)
if
self
.
attn_backend
!=
_Backend
.
FLASH_ATTN
and
check_upstream_fa_availability
(
torch
.
get_default_dtype
()
...
...
@@ -1437,12 +1445,18 @@ class Glm4vForConditionalGeneration(
self
.
multimodal_config
=
multimodal_config
self
.
use_data_parallel
=
multimodal_config
.
mm_encoder_tp_mode
==
"data"
attn_backend_override
=
(
multimodal_config
.
mm_encoder_attn_backend
if
multimodal_config
is
not
None
else
None
)
self
.
visual
=
Glm4vVisionTransformer
(
config
.
vision_config
,
norm_eps
=
getattr
(
config
,
"rms_norm_eps"
,
1e-5
),
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"visual"
),
use_data_parallel
=
self
.
use_data_parallel
,
attn_backend_override
=
attn_backend_override
,
)
if
config
.
model_type
==
"glm4v"
:
...
...
vllm/model_executor/models/keye.py
View file @
c3a2c6ac
...
...
@@ -353,6 +353,7 @@ class KeyeSiglipAttention(nn.Module):
config
:
PretrainedConfig
,
quant_config
:
QuantizationConfig
|
None
=
None
,
prefix
:
str
=
""
,
attn_backend_override
:
_Backend
|
None
=
None
,
):
super
().
__init__
()
self
.
config
=
config
...
...
@@ -392,7 +393,9 @@ class KeyeSiglipAttention(nn.Module):
# Detect attention implementation.
self
.
attn_backend
=
get_vit_attn_backend
(
head_size
=
self
.
head_dim
,
dtype
=
torch
.
get_default_dtype
()
head_size
=
self
.
head_dim
,
dtype
=
torch
.
get_default_dtype
(),
attn_backend_override
=
attn_backend_override
,
)
self
.
use_upstream_fa
=
False
...
...
@@ -521,6 +524,7 @@ class KeyeSiglipEncoderLayer(nn.Module):
config
:
PretrainedConfig
,
quant_config
:
QuantizationConfig
|
None
=
None
,
prefix
:
str
=
""
,
attn_backend_override
:
_Backend
|
None
=
None
,
):
super
().
__init__
()
self
.
embed_dim
=
config
.
hidden_size
...
...
@@ -529,6 +533,7 @@ class KeyeSiglipEncoderLayer(nn.Module):
config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.self_attn"
,
attn_backend_override
=
attn_backend_override
,
)
self
.
layer_norm2
=
nn
.
LayerNorm
(
self
.
embed_dim
,
eps
=
config
.
layer_norm_eps
)
self
.
mlp
=
SiglipMLP
(
...
...
@@ -573,6 +578,7 @@ class KeyeSiglipEncoder(nn.Module):
config
:
PretrainedConfig
,
quant_config
:
QuantizationConfig
|
None
=
None
,
prefix
:
str
=
""
,
attn_backend_override
:
_Backend
|
None
=
None
,
):
super
().
__init__
()
self
.
config
=
config
...
...
@@ -585,6 +591,7 @@ class KeyeSiglipEncoder(nn.Module):
config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.layers.
{
layer_idx
}
"
,
attn_backend_override
=
attn_backend_override
,
)
for
layer_idx
in
range
(
config
.
num_hidden_layers
)
]
...
...
@@ -666,6 +673,7 @@ class KeyeSiglipVisionTransformer(nn.Module):
config
:
PretrainedConfig
,
quant_config
:
QuantizationConfig
|
None
=
None
,
prefix
:
str
=
""
,
attn_backend_override
:
_Backend
|
None
=
None
,
):
super
().
__init__
()
self
.
config
=
config
...
...
@@ -676,6 +684,7 @@ class KeyeSiglipVisionTransformer(nn.Module):
config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.encoder"
,
attn_backend_override
=
attn_backend_override
,
)
self
.
post_layernorm
=
nn
.
LayerNorm
(
embed_dim
,
eps
=
config
.
layer_norm_eps
)
...
...
@@ -747,6 +756,7 @@ class KeyeSiglipVisionModel(nn.Module):
config
:
PretrainedConfig
,
quant_config
:
QuantizationConfig
|
None
=
None
,
prefix
:
str
=
""
,
attn_backend_override
:
_Backend
|
None
=
None
,
):
super
().
__init__
()
...
...
@@ -754,6 +764,7 @@ class KeyeSiglipVisionModel(nn.Module):
config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.vision_model"
,
attn_backend_override
=
attn_backend_override
,
)
self
.
quant_config
=
quant_config
...
...
@@ -1296,10 +1307,16 @@ class BaseKeyeModule(nn.Module):
self
.
config
=
config
self
.
multimodal_config
=
multimodal_config
attn_backend_override
=
(
multimodal_config
.
mm_encoder_attn_backend
if
multimodal_config
is
not
None
else
None
)
self
.
visual
=
KeyeSiglipVisionModel
(
config
.
vision_config
,
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"visual"
),
attn_backend_override
=
attn_backend_override
,
)
self
.
mlp_AR
=
self
.
_build_projector
(
...
...
vllm/model_executor/models/ovis2_5.py
View file @
c3a2c6ac
...
...
@@ -10,6 +10,7 @@ import torch
import
torch.nn
as
nn
from
transformers
import
BaseImageProcessor
,
BatchFeature
,
PretrainedConfig
from
vllm.attention.backends.registry
import
_Backend
from
vllm.config
import
VllmConfig
from
vllm.config.multimodal
import
BaseDummyOptions
from
vllm.model_executor.layers.linear
import
ReplicatedLinear
...
...
@@ -105,6 +106,7 @@ class VisualTokenizer(torch.nn.Module):
quant_config
:
QuantizationConfig
|
None
=
None
,
prefix
:
str
=
""
,
use_data_parallel
:
bool
=
False
,
attn_backend_override
:
_Backend
|
None
=
None
,
):
super
().
__init__
()
self
.
config
=
config
...
...
@@ -113,6 +115,7 @@ class VisualTokenizer(torch.nn.Module):
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.vit"
,
use_data_parallel
=
use_data_parallel
,
attn_backend_override
=
attn_backend_override
,
)
# reserved tokens for INDICATOR_IDS
head_dim
=
visual_vocab_size
-
len
(
INDICATOR_IDS
)
...
...
@@ -132,6 +135,7 @@ class VisualTokenizer(torch.nn.Module):
quant_config
:
QuantizationConfig
|
None
=
None
,
prefix
:
str
=
""
,
use_data_parallel
:
bool
=
False
,
attn_backend_override
:
_Backend
|
None
=
None
,
):
model_type
=
config
.
model_type
if
model_type
==
"siglip2_navit"
:
...
...
@@ -140,6 +144,7 @@ class VisualTokenizer(torch.nn.Module):
quant_config
=
quant_config
,
prefix
=
prefix
,
use_data_parallel
=
use_data_parallel
,
attn_backend_override
=
attn_backend_override
,
)
raise
ValueError
(
f
"Unsupported visual tokenizer model_type:
{
model_type
}
"
)
...
...
@@ -457,6 +462,7 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
multimodal_config
=
vllm_config
.
model_config
.
multimodal_config
self
.
config
:
PretrainedConfig
=
config
self
.
llm
=
init_vllm_registered_model
(
...
...
@@ -464,11 +470,17 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
prefix
=
maybe_prefix
(
prefix
,
"llm"
),
)
attn_backend_override
=
(
multimodal_config
.
mm_encoder_attn_backend
if
multimodal_config
is
not
None
else
None
)
self
.
visual_tokenizer
=
VisualTokenizer
(
config
=
config
.
vit_config
,
visual_vocab_size
=
config
.
visual_vocab_size
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.visual_tokenizer"
,
attn_backend_override
=
attn_backend_override
,
)
self
.
vte
=
VisualEmbedding
(
config
.
visual_vocab_size
,
config
.
hidden_size
)
...
...
vllm/model_executor/models/qwen2_5_vl.py
View file @
c3a2c6ac
...
...
@@ -637,6 +637,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
quant_config
:
QuantizationConfig
|
None
=
None
,
prefix
:
str
=
""
,
use_data_parallel
:
bool
=
False
,
attn_backend_override
:
_Backend
|
None
=
None
,
)
->
None
:
super
().
__init__
()
...
...
@@ -669,7 +670,9 @@ class Qwen2_5_VisionTransformer(nn.Module):
use_upstream_fa
=
False
self
.
attn_backend
=
get_vit_attn_backend
(
head_size
=
head_dim
,
dtype
=
torch
.
get_default_dtype
()
head_size
=
head_dim
,
dtype
=
torch
.
get_default_dtype
(),
attn_backend_override
=
attn_backend_override
,
)
if
(
self
.
attn_backend
!=
_Backend
.
FLASH_ATTN
...
...
@@ -1226,12 +1229,18 @@ class Qwen2_5_VLForConditionalGeneration(
if
multimodal_config
.
get_limit_per_prompt
(
"image"
)
or
multimodal_config
.
get_limit_per_prompt
(
"video"
):
attn_backend_override
=
(
multimodal_config
.
mm_encoder_attn_backend
if
multimodal_config
is
not
None
else
None
)
self
.
visual
=
Qwen2_5_VisionTransformer
(
config
.
vision_config
,
norm_eps
=
getattr
(
config
,
"rms_norm_eps"
,
1e-6
),
quant_config
=
self
.
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"visual"
),
use_data_parallel
=
self
.
use_data_parallel
,
attn_backend_override
=
attn_backend_override
,
)
else
:
self
.
visual
=
None
...
...
vllm/model_executor/models/qwen2_vl.py
View file @
c3a2c6ac
...
...
@@ -320,6 +320,7 @@ class Qwen2VisionAttention(nn.Module):
quant_config
:
QuantizationConfig
|
None
=
None
,
prefix
:
str
=
""
,
use_data_parallel
:
bool
=
False
,
attn_backend_override
:
_Backend
|
None
=
None
,
)
->
None
:
super
().
__init__
()
# Per attention head and per partition values.
...
...
@@ -355,6 +356,7 @@ class Qwen2VisionAttention(nn.Module):
self
.
attn_backend
=
get_vit_attn_backend
(
head_size
=
self
.
hidden_size_per_attention_head
,
dtype
=
torch
.
get_default_dtype
(),
attn_backend_override
=
attn_backend_override
,
)
self
.
use_upstream_fa
=
False
...
...
@@ -497,6 +499,7 @@ class Qwen2VisionBlock(nn.Module):
quant_config
:
QuantizationConfig
|
None
=
None
,
prefix
:
str
=
""
,
use_data_parallel
:
bool
=
False
,
attn_backend_override
:
_Backend
|
None
=
None
,
)
->
None
:
super
().
__init__
()
if
norm_layer
is
None
:
...
...
@@ -512,6 +515,7 @@ class Qwen2VisionBlock(nn.Module):
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.attn"
,
use_data_parallel
=
use_data_parallel
,
attn_backend_override
=
attn_backend_override
,
)
self
.
mlp
=
Qwen2VisionMLP
(
dim
,
...
...
@@ -662,6 +666,7 @@ class Qwen2VisionTransformer(nn.Module):
quant_config
:
QuantizationConfig
|
None
=
None
,
prefix
:
str
=
""
,
use_data_parallel
:
bool
=
False
,
attn_backend_override
:
_Backend
|
None
=
None
,
)
->
None
:
super
().
__init__
()
...
...
@@ -703,6 +708,7 @@ class Qwen2VisionTransformer(nn.Module):
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.blocks.
{
layer_idx
}
"
,
use_data_parallel
=
use_data_parallel
,
attn_backend_override
=
attn_backend_override
,
)
for
layer_idx
in
range
(
depth
)
]
...
...
@@ -716,7 +722,9 @@ class Qwen2VisionTransformer(nn.Module):
use_data_parallel
=
use_data_parallel
,
)
self
.
attn_backend
=
get_vit_attn_backend
(
head_size
=
head_dim
,
dtype
=
torch
.
get_default_dtype
()
head_size
=
head_dim
,
dtype
=
torch
.
get_default_dtype
(),
attn_backend_override
=
attn_backend_override
,
)
if
self
.
attn_backend
!=
_Backend
.
FLASH_ATTN
and
check_upstream_fa_availability
(
torch
.
get_default_dtype
()
...
...
@@ -1356,12 +1364,18 @@ class Qwen2VLForConditionalGeneration(
if
multimodal_config
.
get_limit_per_prompt
(
"image"
)
or
multimodal_config
.
get_limit_per_prompt
(
"video"
):
attn_backend_override
=
(
multimodal_config
.
mm_encoder_attn_backend
if
multimodal_config
is
not
None
else
None
)
self
.
visual
=
Qwen2VisionTransformer
(
config
.
vision_config
,
norm_eps
=
getattr
(
config
,
"rms_norm_eps"
,
1e-6
),
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"visual"
),
use_data_parallel
=
self
.
use_data_parallel
,
attn_backend_override
=
attn_backend_override
,
)
else
:
self
.
visual
=
None
...
...
vllm/model_executor/models/qwen3_omni_moe_thinker.py
View file @
c3a2c6ac
...
...
@@ -296,6 +296,7 @@ class Qwen3Omni_VisionTransformer(nn.Module):
norm_eps
:
float
=
1e-6
,
quant_config
:
QuantizationConfig
|
None
=
None
,
prefix
:
str
=
""
,
attn_backend_override
:
_Backend
|
None
=
None
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
vision_config
.
hidden_size
...
...
@@ -367,7 +368,9 @@ class Qwen3Omni_VisionTransformer(nn.Module):
)
self
.
attn_backend
=
get_vit_attn_backend
(
head_size
=
head_dim
,
dtype
=
torch
.
get_default_dtype
()
head_size
=
head_dim
,
dtype
=
torch
.
get_default_dtype
(),
attn_backend_override
=
attn_backend_override
,
)
if
self
.
attn_backend
!=
_Backend
.
FLASH_ATTN
and
check_upstream_fa_availability
(
torch
.
get_default_dtype
()
...
...
@@ -1144,11 +1147,17 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
self
.
audio_tower
=
Qwen3OmniMoeAudioEncoder
(
thinker_config
.
audio_config
)
attn_backend_override
=
(
multimodal_config
.
mm_encoder_attn_backend
if
multimodal_config
is
not
None
else
None
)
self
.
visual
=
Qwen3Omni_VisionTransformer
(
vision_config
=
thinker_config
.
vision_config
,
norm_eps
=
getattr
(
thinker_config
.
text_config
,
"rms_norm_eps"
,
1e-6
),
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"visual"
),
attn_backend_override
=
attn_backend_override
,
)
self
.
quant_config
=
quant_config
...
...
vllm/model_executor/models/qwen3_vl.py
View file @
c3a2c6ac
...
...
@@ -300,6 +300,7 @@ class Qwen3_VisionTransformer(nn.Module):
quant_config
:
QuantizationConfig
|
None
=
None
,
prefix
:
str
=
""
,
use_data_parallel
:
bool
=
False
,
attn_backend_override
:
_Backend
|
None
=
None
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
vision_config
.
hidden_size
...
...
@@ -359,7 +360,9 @@ class Qwen3_VisionTransformer(nn.Module):
)
self
.
attn_backend
=
get_vit_attn_backend
(
head_size
=
head_dim
,
dtype
=
torch
.
get_default_dtype
()
head_size
=
head_dim
,
dtype
=
torch
.
get_default_dtype
(),
attn_backend_override
=
attn_backend_override
,
)
use_upstream_fa
=
False
if
(
...
...
@@ -379,7 +382,6 @@ class Qwen3_VisionTransformer(nn.Module):
raise
RuntimeError
(
f
"Qwen3-VL does not support
{
self
.
attn_backend
}
backend now."
)
self
.
blocks
=
nn
.
ModuleList
(
[
Qwen3_VisionBlock
(
...
...
@@ -1214,12 +1216,18 @@ class Qwen3VLForConditionalGeneration(
)
and
not
multimodal_config
.
get_limit_per_prompt
(
"video"
):
self
.
visual
=
None
else
:
attn_backend_override
=
(
multimodal_config
.
mm_encoder_attn_backend
if
multimodal_config
is
not
None
else
None
)
self
.
visual
=
Qwen3_VisionTransformer
(
config
.
vision_config
,
norm_eps
=
getattr
(
config
,
"rms_norm_eps"
,
1e-6
),
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"visual"
),
use_data_parallel
=
self
.
use_data_parallel
,
attn_backend_override
=
attn_backend_override
,
)
self
.
language_model
=
Qwen3LLMForCausalLM
(
...
...
vllm/model_executor/models/siglip2navit.py
View file @
c3a2c6ac
...
...
@@ -208,6 +208,7 @@ class Siglip2Attention(nn.Module):
quant_config
:
QuantizationConfig
|
None
=
None
,
prefix
:
str
=
""
,
use_data_parallel
:
bool
=
False
,
attn_backend_override
:
_Backend
|
None
=
None
,
):
super
().
__init__
()
self
.
config
=
config
...
...
@@ -248,7 +249,9 @@ class Siglip2Attention(nn.Module):
# Detect attention implementation.
self
.
attn_backend
=
get_vit_attn_backend
(
head_size
=
self
.
head_dim
,
dtype
=
torch
.
get_default_dtype
()
head_size
=
self
.
head_dim
,
dtype
=
torch
.
get_default_dtype
(),
attn_backend_override
=
attn_backend_override
,
)
self
.
use_upstream_fa
=
False
...
...
@@ -372,6 +375,7 @@ class Siglip2EncoderLayer(nn.Module):
quant_config
:
QuantizationConfig
|
None
=
None
,
prefix
:
str
=
""
,
use_data_parallel
:
bool
=
False
,
attn_backend_override
:
_Backend
|
None
=
None
,
):
super
().
__init__
()
self
.
embed_dim
=
config
.
hidden_size
...
...
@@ -381,6 +385,7 @@ class Siglip2EncoderLayer(nn.Module):
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.self_attn"
,
use_data_parallel
=
use_data_parallel
,
attn_backend_override
=
attn_backend_override
,
)
self
.
layer_norm2
=
nn
.
LayerNorm
(
self
.
embed_dim
,
eps
=
config
.
layer_norm_eps
)
self
.
mlp
=
Siglip2MLP
(
...
...
@@ -434,6 +439,7 @@ class Siglip2Encoder(nn.Module):
quant_config
:
QuantizationConfig
|
None
=
None
,
prefix
:
str
=
""
,
use_data_parallel
:
bool
=
False
,
attn_backend_override
:
_Backend
|
None
=
None
,
):
super
().
__init__
()
self
.
config
=
config
...
...
@@ -444,6 +450,7 @@ class Siglip2Encoder(nn.Module):
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.layers.
{
idx
}
"
,
use_data_parallel
=
use_data_parallel
,
attn_backend_override
=
attn_backend_override
,
)
for
idx
in
range
(
config
.
num_hidden_layers
)
]
...
...
@@ -618,6 +625,7 @@ class Siglip2VisionTransformer(nn.Module):
quant_config
:
QuantizationConfig
|
None
=
None
,
prefix
:
str
=
""
,
use_data_parallel
:
bool
=
False
,
attn_backend_override
:
_Backend
|
None
=
None
,
):
super
().
__init__
()
self
.
config
=
config
...
...
@@ -629,6 +637,7 @@ class Siglip2VisionTransformer(nn.Module):
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.encoder"
,
use_data_parallel
=
use_data_parallel
,
attn_backend_override
=
attn_backend_override
,
)
self
.
post_layernorm
=
nn
.
LayerNorm
(
embed_dim
,
eps
=
config
.
layer_norm_eps
)
...
...
@@ -657,6 +666,7 @@ class Siglip2NavitModel(torch.nn.Module):
quant_config
:
QuantizationConfig
|
None
=
None
,
prefix
:
str
=
""
,
use_data_parallel
:
bool
=
False
,
attn_backend_override
:
_Backend
|
None
=
None
,
):
super
().
__init__
()
...
...
@@ -665,6 +675,7 @@ class Siglip2NavitModel(torch.nn.Module):
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.vision_model"
,
use_data_parallel
=
use_data_parallel
,
attn_backend_override
=
attn_backend_override
,
)
def
forward
(
...
...
vllm/model_executor/models/vision.py
View file @
c3a2c6ac
...
...
@@ -78,10 +78,18 @@ def get_vision_encoder_info(hf_config: VisionLanguageConfig) -> VisionEncoderInf
raise
NotImplementedError
(
msg
)
def
get_vit_attn_backend
(
head_size
:
int
,
dtype
:
torch
.
dtype
)
->
_Backend
:
def
get_vit_attn_backend
(
head_size
:
int
,
dtype
:
torch
.
dtype
,
*
,
attn_backend_override
:
_Backend
|
None
=
None
,
)
->
_Backend
:
"""
Get the available attention backend for Vision Transformer.
"""
if
attn_backend_override
is
not
None
:
return
attn_backend_override
# Lazy import to avoid circular dependency
from
vllm.attention.selector
import
get_env_variable_attn_backend
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment