Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e9cb2241
Commit
e9cb2241
authored
Oct 28, 2025
by
zhuwenwen
Browse files
Add Qwen3-Omni moe thinker
parent
f007cd06
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
1827 additions
and
17 deletions
+1827
-17
docs/models/supported_models.md
docs/models/supported_models.md
+2
-1
tests/models/multimodal/processing/test_common.py
tests/models/multimodal/processing/test_common.py
+1
-0
tests/models/registry.py
tests/models/registry.py
+3
-0
vllm/model_executor/layers/rotary_embedding/mrope.py
vllm/model_executor/layers/rotary_embedding/mrope.py
+14
-14
vllm/model_executor/models/qwen3_omni_moe_thinker.py
vllm/model_executor/models/qwen3_omni_moe_thinker.py
+1743
-0
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+1
-0
vllm/model_executor/models/vision.py
vllm/model_executor/models/vision.py
+62
-1
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+1
-1
No files found.
docs/models/supported_models.md
View file @
e9cb2241
...
@@ -689,6 +689,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
...
@@ -689,6 +689,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
|
`Qwen2_5OmniThinkerForConditionalGeneration`
| Qwen2.5-Omni | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
+ A
<sup>
+
</sup>
|
`Qwen/Qwen2.5-Omni-3B`
,
`Qwen/Qwen2.5-Omni-7B`
| ✅︎ | ✅︎ | ✅︎ |
|
`Qwen2_5OmniThinkerForConditionalGeneration`
| Qwen2.5-Omni | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
+ A
<sup>
+
</sup>
|
`Qwen/Qwen2.5-Omni-3B`
,
`Qwen/Qwen2.5-Omni-7B`
| ✅︎ | ✅︎ | ✅︎ |
|
`Qwen3VLForConditionalGeneration`
| Qwen3-VL | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
|
`Qwen/Qwen3-VL-4B-Instruct`
, etc. | ✅︎ | ✅︎ | ✅︎ |
|
`Qwen3VLForConditionalGeneration`
| Qwen3-VL | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
|
`Qwen/Qwen3-VL-4B-Instruct`
, etc. | ✅︎ | ✅︎ | ✅︎ |
|
`Qwen3VLMoeForConditionalGeneration`
| Qwen3-VL-MOE | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
|
`Qwen/Qwen3-VL-30B-A3B-Instruct`
, etc. | ✅︎ | ✅︎ | ✅︎ |
|
`Qwen3VLMoeForConditionalGeneration`
| Qwen3-VL-MOE | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
|
`Qwen/Qwen3-VL-30B-A3B-Instruct`
, etc. | ✅︎ | ✅︎ | ✅︎ |
|
`Qwen3OmniMoeThinkerForConditionalGeneration`
| Qwen3-Omni | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
+ A
<sup>
+
</sup>
|
`Qwen/Qwen3-Omni-30B-A3B-Instruct`
,
`Qwen/Qwen3-Omni-30B-A3B-Thinking`
| ✅︎ | ✅︎ | ✅︎ |
|
`RForConditionalGeneration`
| R-VL-4B | T + I
<sup>
E+
</sup>
|
`YannQi/R-4B`
| | ✅︎ | ✅︎ |
|
`RForConditionalGeneration`
| R-VL-4B | T + I
<sup>
E+
</sup>
|
`YannQi/R-4B`
| | ✅︎ | ✅︎ |
|
`SkyworkR1VChatModel`
| Skywork-R1V-38B | T + I |
`Skywork/Skywork-R1V-38B`
| | ✅︎ | ✅︎ |
|
`SkyworkR1VChatModel`
| Skywork-R1V-38B | T + I |
`Skywork/Skywork-R1V-38B`
| | ✅︎ | ✅︎ |
|
`SmolVLMForConditionalGeneration`
| SmolVLM2 | T + I |
`SmolVLM2-2.2B-Instruct`
| ✅︎ | | ✅︎ |
|
`SmolVLMForConditionalGeneration`
| SmolVLM2 | T + I |
`SmolVLM2-2.2B-Instruct`
| ✅︎ | | ✅︎ |
...
@@ -779,7 +780,7 @@ Some models are supported only via the [Transformers backend](#transformers). Th
...
@@ -779,7 +780,7 @@ Some models are supported only via the [Transformers backend](#transformers). Th
Our PaliGemma implementations have the same problem as Gemma 3 (see above) for both V0 and V1.
Our PaliGemma implementations have the same problem as Gemma 3 (see above) for both V0 and V1.
!!! note
!!! note
For Qwen2.5-Omni, reading audio from video pre-processing (
`--mm-processor-kwargs '{"use_audio_in_video": true}'`
)
For Qwen2.5-Omni
and Qwen3-Omni
, reading audio from video pre-processing (
`--mm-processor-kwargs '{"use_audio_in_video": true}'`
)
is currently supported on V0 (but not V1), because overlapping modalities is not yet supported in V1.
is currently supported on V0 (but not V1), because overlapping modalities is not yet supported in V1.
#### Transcription
#### Transcription
...
...
tests/models/multimodal/processing/test_common.py
View file @
e9cb2241
...
@@ -362,6 +362,7 @@ def _test_processing_correctness_one(
...
@@ -362,6 +362,7 @@ def _test_processing_correctness_one(
os
.
path
.
join
(
models_path_prefix
,
"Qwen/Qwen2.5-Omni-3B"
),
os
.
path
.
join
(
models_path_prefix
,
"Qwen/Qwen2.5-Omni-3B"
),
os
.
path
.
join
(
models_path_prefix
,
"Qwen/Qwen3-VL-4B-Instruct"
),
os
.
path
.
join
(
models_path_prefix
,
"Qwen/Qwen3-VL-4B-Instruct"
),
os
.
path
.
join
(
models_path_prefix
,
"Qwen/Qwen3-VL-30B-A3B-Instruct"
),
os
.
path
.
join
(
models_path_prefix
,
"Qwen/Qwen3-VL-30B-A3B-Instruct"
),
os
.
path
.
join
(
models_path_prefix
,
"Qwen/Qwen3-Omni-30B-A3B-Instruct"
),
os
.
path
.
join
(
models_path_prefix
,
"YannQi/R-4B"
),
os
.
path
.
join
(
models_path_prefix
,
"YannQi/R-4B"
),
os
.
path
.
join
(
models_path_prefix
,
"Skywork/Skywork-R1V-38B"
),
os
.
path
.
join
(
models_path_prefix
,
"Skywork/Skywork-R1V-38B"
),
os
.
path
.
join
(
models_path_prefix
,
"HuggingFaceTB/SmolVLM2-2.2B-Instruct"
),
os
.
path
.
join
(
models_path_prefix
,
"HuggingFaceTB/SmolVLM2-2.2B-Instruct"
),
...
...
tests/models/registry.py
View file @
e9cb2241
...
@@ -580,6 +580,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
...
@@ -580,6 +580,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
max_model_len
=
4096
,
max_model_len
=
4096
,
min_transformers_version
=
"4.57"
,
min_transformers_version
=
"4.57"
,
is_available_online
=
False
),
is_available_online
=
False
),
"Qwen3OmniMoeForConditionalGeneration"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"Qwen/Qwen3-Omni-30B-A3B-Instruct"
),
max_model_len
=
4096
,
min_transformers_version
=
"4.57"
),
"RForConditionalGeneration"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"YannQi/R-4B"
),
"RForConditionalGeneration"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"YannQi/R-4B"
),
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"SkyworkR1VChatModel"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"Skywork/Skywork-R1V-38B"
),
"SkyworkR1VChatModel"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"Skywork/Skywork-R1V-38B"
),
...
...
vllm/model_executor/layers/rotary_embedding/mrope.py
View file @
e9cb2241
...
@@ -429,7 +429,7 @@ class MRotaryEmbedding(RotaryEmbedding):
...
@@ -429,7 +429,7 @@ class MRotaryEmbedding(RotaryEmbedding):
use_audio_in_video
:
bool
=
False
,
use_audio_in_video
:
bool
=
False
,
)
->
tuple
[
torch
.
Tensor
,
int
]:
)
->
tuple
[
torch
.
Tensor
,
int
]:
from
vllm.transformers_utils.config
import
thinker_uses_mrope
from
vllm.transformers_utils.config
import
thinker_uses_mrope
if
thinker_uses_mrope
(
hf_config
):
if
thinker_uses_mrope
(
hf_config
)
and
hf_config
.
model_type
==
"qwen2_5_omni"
:
return
cls
.
_omni_get_input_positions_tensor
(
return
cls
.
_omni_get_input_positions_tensor
(
input_tokens
=
input_tokens
,
input_tokens
=
input_tokens
,
hf_config
=
hf_config
,
hf_config
=
hf_config
,
...
@@ -899,7 +899,7 @@ class MRotaryEmbedding(RotaryEmbedding):
...
@@ -899,7 +899,7 @@ class MRotaryEmbedding(RotaryEmbedding):
torch
.
arange
(
text_len
).
view
(
1
,
-
1
).
expand
(
3
,
-
1
)
+
st_idx
)
torch
.
arange
(
text_len
).
view
(
1
,
-
1
).
expand
(
3
,
-
1
)
+
st_idx
)
t_index
=
(
torch
.
arange
(
llm_grid_t
).
view
(
-
1
,
1
).
expand
(
t_index
=
(
torch
.
arange
(
llm_grid_t
).
view
(
-
1
,
1
).
expand
(
-
1
,
llm_grid_h
*
llm_grid_w
)).
long
().
flatten
()
-
1
,
llm_grid_h
*
llm_grid_w
)).
flatten
()
h_index
=
torch
.
arange
(
llm_grid_h
).
view
(
1
,
-
1
,
1
).
expand
(
h_index
=
torch
.
arange
(
llm_grid_h
).
view
(
1
,
-
1
,
1
).
expand
(
llm_grid_t
,
-
1
,
llm_grid_w
).
flatten
()
llm_grid_t
,
-
1
,
llm_grid_w
).
flatten
()
...
@@ -1003,8 +1003,8 @@ class MRotaryEmbedding(RotaryEmbedding):
...
@@ -1003,8 +1003,8 @@ class MRotaryEmbedding(RotaryEmbedding):
torch
.
arange
(
text_len
).
view
(
1
,
-
1
).
expand
(
3
,
-
1
)
+
st_idx
)
torch
.
arange
(
text_len
).
view
(
1
,
-
1
).
expand
(
3
,
-
1
)
+
st_idx
)
t_index
=
(
torch
.
arange
(
llm_grid_t
).
view
(
-
1
,
1
).
expand
(
t_index
=
(
torch
.
arange
(
llm_grid_t
).
view
(
-
1
,
1
).
expand
(
-
1
,
llm_grid_h
*
llm_grid_w
)
*
video_second_per_grid_t
*
-
1
,
llm_grid_h
*
llm_grid_w
)
*
video_second_per_grid_t
*
tokens_per_second
).
long
().
flatten
()
tokens_per_second
).
flatten
()
h_index
=
torch
.
arange
(
llm_grid_h
).
view
(
1
,
-
1
,
1
).
expand
(
h_index
=
torch
.
arange
(
llm_grid_h
).
view
(
1
,
-
1
,
1
).
expand
(
llm_grid_t
,
-
1
,
llm_grid_w
).
flatten
()
llm_grid_t
,
-
1
,
llm_grid_w
).
flatten
()
...
@@ -1027,7 +1027,7 @@ class MRotaryEmbedding(RotaryEmbedding):
...
@@ -1027,7 +1027,7 @@ class MRotaryEmbedding(RotaryEmbedding):
llm_positions
=
llm_positions
[:,
context_len
:
seq_len
]
llm_positions
=
llm_positions
[:,
context_len
:
seq_len
]
return
llm_positions
,
mrope_position_delta
return
llm_positions
,
mrope_position_delta
@
classmethod
@
classmethod
def
_omni_get_input_positions_tensor
(
def
_omni_get_input_positions_tensor
(
cls
,
cls
,
...
@@ -1061,6 +1061,11 @@ class MRotaryEmbedding(RotaryEmbedding):
...
@@ -1061,6 +1061,11 @@ class MRotaryEmbedding(RotaryEmbedding):
# _vl_get_input_positions_tensor.
# _vl_get_input_positions_tensor.
thinker_config
=
hf_config
.
thinker_config
thinker_config
=
hf_config
.
thinker_config
if
isinstance
(
image_grid_thw
,
list
):
image_grid_thw
=
torch
.
tensor
(
image_grid_thw
)
if
isinstance
(
video_grid_thw
,
list
):
video_grid_thw
=
torch
.
tensor
(
video_grid_thw
)
audio_token_id
=
thinker_config
.
audio_token_index
audio_token_id
=
thinker_config
.
audio_token_index
image_token_id
=
thinker_config
.
image_token_index
image_token_id
=
thinker_config
.
image_token_index
video_token_id
=
thinker_config
.
video_token_index
video_token_id
=
thinker_config
.
video_token_index
...
@@ -1073,11 +1078,6 @@ class MRotaryEmbedding(RotaryEmbedding):
...
@@ -1073,11 +1078,6 @@ class MRotaryEmbedding(RotaryEmbedding):
tokens_per_second
=
getattr
(
thinker_config
.
vision_config
,
tokens_per_second
=
getattr
(
thinker_config
.
vision_config
,
"tokens_per_second"
,
25
)
"tokens_per_second"
,
25
)
if
isinstance
(
image_grid_thw
,
list
):
image_grid_thw
=
torch
.
tensor
(
image_grid_thw
)
if
isinstance
(
video_grid_thw
,
list
):
video_grid_thw
=
torch
.
tensor
(
video_grid_thw
)
src_item
=
input_tokens
src_item
=
input_tokens
audio_seqlens
=
audio_feature_lengths
audio_seqlens
=
audio_feature_lengths
if
not
second_per_grid_ts
:
if
not
second_per_grid_ts
:
...
@@ -1121,7 +1121,7 @@ class MRotaryEmbedding(RotaryEmbedding):
...
@@ -1121,7 +1121,7 @@ class MRotaryEmbedding(RotaryEmbedding):
grid_t
=
image_grid_thw
[
image_idx
][
0
]
grid_t
=
image_grid_thw
[
image_idx
][
0
]
grid_hs
=
image_grid_thw
[:,
1
]
grid_hs
=
image_grid_thw
[:,
1
]
grid_ws
=
image_grid_thw
[:,
2
]
grid_ws
=
image_grid_thw
[:,
2
]
t_index
=
(
torch
.
arange
(
grid_t
)
*
1
*
tokens_per_second
).
long
()
t_index
=
torch
.
arange
(
grid_t
)
*
1
*
tokens_per_second
llm_pos_ids
=
cls
.
_get_llm_pos_ids_for_vision
(
llm_pos_ids
=
cls
.
_get_llm_pos_ids_for_vision
(
start_idx
,
image_idx
,
spatial_merge_size
,
t_index
,
grid_hs
,
start_idx
,
image_idx
,
spatial_merge_size
,
t_index
,
grid_hs
,
grid_ws
)
grid_ws
)
...
@@ -1136,7 +1136,7 @@ class MRotaryEmbedding(RotaryEmbedding):
...
@@ -1136,7 +1136,7 @@ class MRotaryEmbedding(RotaryEmbedding):
grid_ws
=
video_grid_thw
[:,
2
]
grid_ws
=
video_grid_thw
[:,
2
]
t_index
=
(
torch
.
arange
(
grid_t
)
*
t_index
=
(
torch
.
arange
(
grid_t
)
*
second_per_grid_ts
[
video_idx
]
*
second_per_grid_ts
[
video_idx
]
*
tokens_per_second
)
.
long
()
tokens_per_second
)
llm_pos_ids
=
cls
.
_get_llm_pos_ids_for_vision
(
llm_pos_ids
=
cls
.
_get_llm_pos_ids_for_vision
(
start_idx
,
video_idx
,
spatial_merge_size
,
t_index
,
grid_hs
,
start_idx
,
video_idx
,
spatial_merge_size
,
t_index
,
grid_hs
,
grid_ws
)
grid_ws
)
...
@@ -1159,7 +1159,7 @@ class MRotaryEmbedding(RotaryEmbedding):
...
@@ -1159,7 +1159,7 @@ class MRotaryEmbedding(RotaryEmbedding):
t_ntoken_per_chunk
=
int
(
tokens_per_second
*
seconds_per_chunk
)
t_ntoken_per_chunk
=
int
(
tokens_per_second
*
seconds_per_chunk
)
t_index
=
(
torch
.
arange
(
grid_t
)
*
t_index
=
(
torch
.
arange
(
grid_t
)
*
second_per_grid_ts
[
video_idx
]
*
second_per_grid_ts
[
video_idx
]
*
tokens_per_second
)
.
long
()
tokens_per_second
)
t_index_split_chunk
=
cls
.
_split_list_into_ranges
(
t_index_split_chunk
=
cls
.
_split_list_into_ranges
(
t_index
,
t_ntoken_per_chunk
)
t_index
,
t_ntoken_per_chunk
)
place_num
=
(((
audio_seqlen
-
1
)
//
2
+
1
-
2
)
//
2
+
1
)
+
2
place_num
=
(((
audio_seqlen
-
1
)
//
2
+
1
-
2
)
//
2
+
1
)
+
2
...
@@ -1299,7 +1299,7 @@ class MRotaryEmbedding(RotaryEmbedding):
...
@@ -1299,7 +1299,7 @@ class MRotaryEmbedding(RotaryEmbedding):
grid_w
=
video_grid_thw
[
2
]
grid_w
=
video_grid_thw
[
2
]
t_ntoken_per_chunk
=
int
(
tokens_per_second
*
seconds_per_chunk
)
t_ntoken_per_chunk
=
int
(
tokens_per_second
*
seconds_per_chunk
)
t_index
=
(
torch
.
arange
(
grid_t
)
*
video_second_per_grid_t
*
t_index
=
(
torch
.
arange
(
grid_t
)
*
video_second_per_grid_t
*
tokens_per_second
)
.
long
()
tokens_per_second
)
t_index_split_chunk
=
cls
.
_split_list_into_ranges
(
t_index_split_chunk
=
cls
.
_split_list_into_ranges
(
t_index
,
t_ntoken_per_chunk
)
t_index
,
t_ntoken_per_chunk
)
...
...
vllm/model_executor/models/qwen3_omni_moe_thinker.py
0 → 100644
View file @
e9cb2241
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen3-Omni-Moe model (thinker part)."""
import
os
import
math
from
collections.abc
import
Callable
,
Iterable
,
Mapping
,
Sequence
from
functools
import
partial
from
typing
import
Any
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
packaging.version
import
Version
from
transformers
import
PretrainedConfig
from
transformers
import
__version__
as
TRANSFORMERS_VERSION
from
transformers.feature_extraction_utils
import
BatchFeature
from
transformers.models.qwen3_omni_moe.configuration_qwen3_omni_moe
import
(
Qwen3OmniMoeConfig
,
Qwen3OmniMoeThinkerConfig
,
)
from
transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe
import
(
Qwen3OmniMoeAudioEncoder
,
)
from
transformers.models.qwen3_omni_moe.processing_qwen3_omni_moe
import
(
Qwen3OmniMoeProcessor
,
)
from
transformers.models.whisper
import
WhisperFeatureExtractor
from
vllm.attention.backends.registry
import
_Backend
from
vllm.attention.layer
import
check_upstream_fa_availability
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
get_pp_group
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
_ACTIVATION_REGISTRY
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
,
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.qwen2_audio
import
(
Qwen2AudioFeatureInputs
,
Qwen2AudioProcessingInfo
,
)
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
MultiModalKwargsItems
from
vllm.multimodal.parse
import
AudioProcessorItems
,
MultiModalDataItems
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
MultiModalPromptUpdates
,
PlaceholderFeaturesInfo
,
PromptReplacement
,
PromptUpdate
,
)
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
(
MultiModalEmbeddings
,
SupportsMRoPE
,
SupportsMultiModal
,
SupportsPP
,
)
from
.qwen2_5_omni_thinker
import
(
Qwen2_5OmniConditionalGenerationMixin
,
Qwen2_5OmniThinkerDummyInputsBuilder
,
Qwen2_5OmniThinkerMultiModalProcessor
,
Qwen2_5OmniThinkerProcessingInfo
,
)
from
.qwen2_5_vl
import
(
Qwen2_5_VisionAttention
,
Qwen2_5_VisionRotaryEmbedding
,
Qwen2_5_VLProcessingInfo
,
)
from
.qwen3_moe
import
Qwen3MoeForCausalLM
,
Qwen3MoeModel
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
_merge_multimodal_embeddings
,
maybe_prefix
,
)
from
.vision
import
(
conv3d_to_linear_weight
,
get_llm_pos_ids_for_vision
,
get_vit_attn_backend
,
)
try
:
import
flash_attn
except
(
ImportError
,
ModuleNotFoundError
):
flash_attn
=
None
logger
=
init_logger
(
__name__
)
def
_get_feat_extract_output_lengths
(
input_lengths
:
torch
.
Tensor
):
input_lengths_leave
=
input_lengths
%
100
feat_lengths
=
(
input_lengths_leave
-
1
)
//
2
+
1
output_lengths
=
(
((
feat_lengths
-
1
)
//
2
+
1
-
1
)
//
2
+
1
+
(
input_lengths
//
100
)
*
13
)
return
feat_lengths
,
output_lengths
class
Qwen3_VisionPatchEmbed
(
nn
.
Module
):
def
__init__
(
self
,
patch_size
:
int
=
14
,
temporal_patch_size
:
int
=
2
,
in_channels
:
int
=
3
,
hidden_size
:
int
=
1152
,
)
->
None
:
super
().
__init__
()
self
.
patch_size
=
patch_size
self
.
temporal_patch_size
=
temporal_patch_size
self
.
hidden_size
=
hidden_size
kernel_size
=
(
temporal_patch_size
,
patch_size
,
patch_size
)
self
.
proj
=
ReplicatedLinear
(
in_channels
*
math
.
prod
(
kernel_size
),
hidden_size
,
bias
=
True
,
return_bias
=
False
,
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
L
,
C
=
x
.
shape
if
os
.
environ
.
get
(
'PYTORCH_MIOPEN_SUGGEST_NDHWC'
)
==
'1'
:
x
=
x
.
to
(
memory_format
=
torch
.
channels_last_3d
)
x
=
self
.
proj
(
x
)
return
x
class
Qwen3_VisionMLP
(
nn
.
Module
):
def
__init__
(
self
,
in_features
:
int
,
hidden_features
:
int
,
bias
:
bool
=
False
,
act_fn
:
Callable
[[
torch
.
Tensor
],
torch
.
Tensor
]
=
F
.
silu
,
quant_config
:
QuantizationConfig
|
None
=
None
,
prefix
:
str
=
""
,
):
super
().
__init__
()
self
.
linear_fc1
=
ColumnParallelLinear
(
in_features
,
hidden_features
,
bias
=
bias
,
quant_config
=
quant_config
,
return_bias
=
False
,
prefix
=
f
"
{
prefix
}
.linear_fc1"
,
)
self
.
linear_fc2
=
RowParallelLinear
(
hidden_features
,
in_features
,
bias
=
bias
,
quant_config
=
quant_config
,
return_bias
=
False
,
prefix
=
f
"
{
prefix
}
.linear_fc2"
,
)
self
.
act_fn
=
act_fn
def
forward
(
self
,
x
:
torch
.
Tensor
):
mlp_output
=
self
.
linear_fc2
(
self
.
act_fn
(
self
.
linear_fc1
(
x
)))
return
mlp_output
class
Qwen3_VisionBlock
(
nn
.
Module
):
def
__init__
(
self
,
dim
:
int
,
num_heads
:
int
,
mlp_hidden_dim
:
int
,
act_fn
:
Callable
[[
torch
.
Tensor
],
torch
.
Tensor
]
=
F
.
silu
,
norm_layer
:
Callable
[[
int
],
nn
.
Module
]
|
None
=
None
,
quant_config
:
QuantizationConfig
|
None
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
if
norm_layer
is
None
:
norm_layer
=
partial
(
nn
.
LayerNorm
,
eps
=
1e-6
)
self
.
norm1
=
norm_layer
(
dim
)
self
.
norm2
=
norm_layer
(
dim
)
self
.
attn
=
Qwen2_5_VisionAttention
(
embed_dim
=
dim
,
num_heads
=
num_heads
,
projection_size
=
dim
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.attn"
,
)
self
.
mlp
=
Qwen3_VisionMLP
(
dim
,
mlp_hidden_dim
,
act_fn
=
act_fn
,
bias
=
True
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp"
,
)
def
forward
(
self
,
x
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
rotary_pos_emb
:
torch
.
Tensor
,
max_seqlen
:
int
|
None
=
None
,
# Only used for Flash Attention
seqlens
:
list
[
int
]
|
None
=
None
,
# Only used for xFormers
)
->
torch
.
Tensor
:
x
=
x
+
self
.
attn
(
self
.
norm1
(
x
),
cu_seqlens
=
cu_seqlens
,
rotary_pos_emb
=
rotary_pos_emb
,
max_seqlen
=
max_seqlen
,
seqlens
=
seqlens
,
)
x
=
x
+
self
.
mlp
(
self
.
norm2
(
x
))
return
x
class
Qwen3_VisionPatchMerger
(
nn
.
Module
):
def
__init__
(
self
,
d_model
:
int
,
context_dim
:
int
,
norm_layer
:
Callable
[[
int
],
nn
.
Module
]
|
None
=
None
,
spatial_merge_size
:
int
=
2
,
use_postshuffle_norm
:
bool
=
False
,
quant_config
:
QuantizationConfig
|
None
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
context_dim
*
(
spatial_merge_size
**
2
)
self
.
use_postshuffle_norm
=
use_postshuffle_norm
if
self
.
use_postshuffle_norm
:
context_dim
=
self
.
hidden_size
if
norm_layer
is
None
:
norm_layer
=
partial
(
nn
.
LayerNorm
,
eps
=
1e-6
)
self
.
use_postshuffle_norm
=
use_postshuffle_norm
self
.
ln_q
=
norm_layer
(
self
.
hidden_size
if
use_postshuffle_norm
else
context_dim
)
self
.
mlp
=
nn
.
ModuleList
(
[
ColumnParallelLinear
(
self
.
hidden_size
,
self
.
hidden_size
,
bias
=
True
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp.0"
,
),
nn
.
GELU
(),
RowParallelLinear
(
self
.
hidden_size
,
d_model
,
bias
=
True
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.mlp.2"
,
),
]
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
self
.
use_postshuffle_norm
:
x
=
self
.
ln_q
(
x
.
view
(
-
1
,
self
.
hidden_size
))
else
:
x
=
self
.
ln_q
(
x
).
view
(
-
1
,
self
.
hidden_size
)
mlp_fc1
,
mlp_act
,
mlp_fc2
=
self
.
mlp
x_parallel
,
_
=
mlp_fc1
(
x
)
x_parallel
=
mlp_act
(
x_parallel
)
out
,
_
=
mlp_fc2
(
x_parallel
)
return
out
class
Qwen3Omni_VisionTransformer
(
nn
.
Module
):
def
__init__
(
self
,
vision_config
,
norm_eps
:
float
=
1e-6
,
quant_config
:
QuantizationConfig
|
None
=
None
,
prefix
:
str
=
""
,
attn_backend_override
:
_Backend
|
None
=
None
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
vision_config
.
hidden_size
self
.
num_heads
=
vision_config
.
num_heads
self
.
image_size
=
vision_config
.
image_size
self
.
patch_size
=
vision_config
.
patch_size
self
.
spatial_merge_size
=
vision_config
.
spatial_merge_size
self
.
spatial_merge_unit
=
self
.
spatial_merge_size
**
2
self
.
temporal_patch_size
=
vision_config
.
temporal_patch_size
self
.
num_grid_per_side
=
self
.
image_size
//
self
.
patch_size
self
.
apply_vit_abs_pos_embed
=
vision_config
.
apply_vit_abs_pos_embed
self
.
deepstack_visual_indexes
=
vision_config
.
deepstack_visual_indexes
self
.
patch_embed
=
Qwen3_VisionPatchEmbed
(
patch_size
=
self
.
patch_size
,
temporal_patch_size
=
self
.
temporal_patch_size
,
in_channels
=
vision_config
.
in_channels
,
hidden_size
=
self
.
hidden_size
,
)
# vit pos embeding, TODO: spatial_patch_size vs patch_size
if
self
.
apply_vit_abs_pos_embed
:
self
.
pos_embed
=
nn
.
Embedding
(
self
.
num_grid_per_side
**
2
,
self
.
hidden_size
)
else
:
self
.
pos_embed
=
nn
.
Parameter
(
torch
.
empty
([
1
,
self
.
num_grid_per_side
**
2
,
self
.
hidden_size
])
)
norm_layer
=
partial
(
nn
.
LayerNorm
,
eps
=
norm_eps
)
head_dim
=
self
.
hidden_size
//
self
.
num_heads
self
.
rotary_pos_emb
=
Qwen2_5_VisionRotaryEmbedding
(
head_dim
//
2
)
self
.
blocks
=
nn
.
ModuleList
(
[
Qwen3_VisionBlock
(
dim
=
self
.
hidden_size
,
num_heads
=
self
.
num_heads
,
mlp_hidden_dim
=
vision_config
.
intermediate_size
,
act_fn
=
_ACTIVATION_REGISTRY
[
vision_config
.
hidden_act
],
norm_layer
=
norm_layer
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.blocks.
{
layer_idx
}
"
,
)
for
layer_idx
in
range
(
vision_config
.
depth
)
]
)
self
.
merger
=
Qwen3_VisionPatchMerger
(
d_model
=
vision_config
.
out_hidden_size
,
context_dim
=
self
.
hidden_size
,
norm_layer
=
norm_layer
,
spatial_merge_size
=
self
.
spatial_merge_size
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.merger"
,
)
if
self
.
deepstack_visual_indexes
is
not
None
:
self
.
merger_list
=
nn
.
ModuleList
(
[
Qwen3_VisionPatchMerger
(
d_model
=
vision_config
.
out_hidden_size
,
context_dim
=
self
.
hidden_size
,
spatial_merge_size
=
self
.
spatial_merge_size
,
use_postshuffle_norm
=
True
,
norm_layer
=
norm_layer
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.merger_list.
{
layer_idx
}
"
,
)
for
layer_idx
in
range
(
len
(
self
.
deepstack_visual_indexes
))
]
)
self
.
attn_backend
=
get_vit_attn_backend
(
head_size
=
head_dim
,
dtype
=
torch
.
get_default_dtype
(),
attn_backend_override
=
attn_backend_override
,
)
if
self
.
attn_backend
!=
_Backend
.
FLASH_ATTN
and
check_upstream_fa_availability
(
torch
.
get_default_dtype
()
):
self
.
attn_backend
=
_Backend
.
FLASH_ATTN
@
property
def
dtype
(
self
)
->
torch
.
dtype
:
return
self
.
patch_embed
.
proj
.
weight
.
dtype
@
property
def
device
(
self
)
->
torch
.
device
:
return
self
.
patch_embed
.
proj
.
weight
.
device
def
rot_pos_emb
(
self
,
grid_thw
):
pos_ids
=
[]
for
t
,
h
,
w
in
grid_thw
:
hpos_ids
=
torch
.
arange
(
h
).
unsqueeze
(
1
).
expand
(
-
1
,
w
)
hpos_ids
=
hpos_ids
.
reshape
(
h
//
self
.
spatial_merge_size
,
self
.
spatial_merge_size
,
w
//
self
.
spatial_merge_size
,
self
.
spatial_merge_size
,
)
hpos_ids
=
hpos_ids
.
permute
(
0
,
2
,
1
,
3
)
hpos_ids
=
hpos_ids
.
flatten
()
wpos_ids
=
torch
.
arange
(
w
).
unsqueeze
(
0
).
expand
(
h
,
-
1
)
wpos_ids
=
wpos_ids
.
reshape
(
h
//
self
.
spatial_merge_size
,
self
.
spatial_merge_size
,
w
//
self
.
spatial_merge_size
,
self
.
spatial_merge_size
,
)
wpos_ids
=
wpos_ids
.
permute
(
0
,
2
,
1
,
3
)
wpos_ids
=
wpos_ids
.
flatten
()
pos_ids
.
append
(
torch
.
stack
([
hpos_ids
,
wpos_ids
],
dim
=-
1
).
repeat
(
t
,
1
))
pos_ids
=
torch
.
cat
(
pos_ids
,
dim
=
0
)
max_grid_size
=
grid_thw
[:,
1
:].
max
()
rotary_pos_emb_full
=
self
.
rotary_pos_emb
(
max_grid_size
)
rotary_pos_emb
=
rotary_pos_emb_full
[
pos_ids
].
flatten
(
1
)
return
rotary_pos_emb
def
fast_pos_embed_interpolate
(
self
,
grid_thw
:
list
[
list
[
int
]])
->
torch
.
Tensor
:
num_grid_per_side
=
self
.
num_grid_per_side
m_size
=
self
.
spatial_merge_size
hidden_dim
=
self
.
pos_embed
.
embedding_dim
outputs
=
[]
for
t
,
h
,
w
in
grid_thw
:
h_idxs
=
torch
.
linspace
(
0
,
num_grid_per_side
-
1
,
h
,
dtype
=
torch
.
float32
,
device
=
self
.
device
)
w_idxs
=
torch
.
linspace
(
0
,
num_grid_per_side
-
1
,
w
,
dtype
=
torch
.
float32
,
device
=
self
.
device
)
h_floor
=
h_idxs
.
to
(
torch
.
long
)
w_floor
=
w_idxs
.
to
(
torch
.
long
)
h_ceil
=
torch
.
clamp
(
h_floor
+
1
,
max
=
num_grid_per_side
-
1
)
w_ceil
=
torch
.
clamp
(
w_floor
+
1
,
max
=
num_grid_per_side
-
1
)
dh
=
h_idxs
-
h_floor
dw
=
w_idxs
-
w_floor
# Create meshgrid view for all h, w vars
dh_grid
,
dw_grid
=
torch
.
meshgrid
(
dh
,
dw
,
indexing
=
"ij"
)
h_floor_grid
,
w_floor_grid
=
torch
.
meshgrid
(
h_floor
,
w_floor
,
indexing
=
"ij"
)
h_ceil_grid
,
w_ceil_grid
=
torch
.
meshgrid
(
h_ceil
,
w_ceil
,
indexing
=
"ij"
)
h_floor_grid_idx
=
h_floor_grid
*
num_grid_per_side
h_ceil_grid_idx
=
h_ceil_grid
*
num_grid_per_side
# original computation of weights
# w00 = (1 - dh_grid) * (1 - dw_grid)
# w01 = (1 - dh_grid) * dw_grid
# w10 = dh_grid * (1 - dw_grid)
# w11 = dh_grid * dw_grid
# we reuse w11 here to avoid duplicate
# dh_grid * dw_grid computation
w11
=
dh_grid
*
dw_grid
w10
=
dh_grid
-
w11
w01
=
dw_grid
-
w11
w00
=
1
-
dh_grid
-
dw_grid
+
w11
idx00
=
h_floor_grid_idx
+
w_floor_grid
idx01
=
h_floor_grid_idx
+
w_ceil_grid
idx10
=
h_ceil_grid_idx
+
w_floor_grid
idx11
=
h_ceil_grid_idx
+
w_ceil_grid
indices
=
torch
.
stack
([
idx00
,
idx01
,
idx10
,
idx11
],
dim
=
0
).
reshape
(
4
,
-
1
)
weights
=
torch
.
stack
([
w00
,
w01
,
w10
,
w11
],
dim
=
0
).
reshape
(
4
,
-
1
,
1
)
weights
=
weights
.
to
(
dtype
=
self
.
dtype
,
device
=
self
.
device
)
embeds
=
self
.
pos_embed
(
indices
)
weighted_embeds
=
embeds
*
weights
p0
,
p1
,
p2
,
p3
=
weighted_embeds
.
unbind
(
dim
=
0
)
combined
=
p0
+
p1
+
p2
+
p3
combined
=
combined
.
view
(
h
*
w
,
hidden_dim
)
repeated
=
combined
.
unsqueeze
(
0
).
expand
(
t
,
-
1
,
-
1
).
contiguous
()
repeated
=
repeated
.
view
(
t
,
h
//
m_size
,
m_size
,
w
//
m_size
,
m_size
,
hidden_dim
)
repeated
=
repeated
.
permute
(
0
,
1
,
3
,
2
,
4
,
5
).
reshape
(
-
1
,
hidden_dim
)
outputs
.
append
(
repeated
)
return
torch
.
cat
(
outputs
,
dim
=
0
)
def
compute_attn_mask_seqlen
(
self
,
cu_seqlens
:
torch
.
Tensor
,
)
->
tuple
[
int
|
None
,
list
[
int
]
|
None
]:
max_seqlen
,
seqlens
=
None
,
None
if
self
.
attn_backend
==
_Backend
.
FLASH_ATTN
:
max_seqlen
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
max
().
item
()
elif
self
.
attn_backend
==
_Backend
.
XFORMERS
:
seqlens
=
(
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]).
tolist
()
return
max_seqlen
,
seqlens
def
forward
(
self
,
x
:
torch
.
Tensor
,
grid_thw
:
list
[
list
[
int
]],
)
->
torch
.
Tensor
:
hidden_states
=
x
.
to
(
device
=
self
.
device
,
dtype
=
self
.
dtype
)
hidden_states
=
self
.
patch_embed
(
hidden_states
)
if
self
.
apply_vit_abs_pos_embed
:
pos_embeds
=
self
.
fast_pos_embed_interpolate
(
grid_thw
)
hidden_states
=
hidden_states
+
pos_embeds
rotary_pos_emb
=
self
.
rot_pos_emb
(
grid_thw
)
cu_seqlens
=
torch
.
repeat_interleave
(
grid_thw
[:,
1
]
*
grid_thw
[:,
2
],
grid_thw
[:,
0
]
).
cumsum
(
dim
=
0
,
dtype
=
grid_thw
.
dtype
if
torch
.
jit
.
is_tracing
()
else
torch
.
int32
,
)
cu_seqlens
=
F
.
pad
(
cu_seqlens
,
(
1
,
0
),
value
=
0
)
hidden_states
=
hidden_states
.
unsqueeze
(
1
)
rotary_pos_emb
=
rotary_pos_emb
.
to
(
hidden_states
.
device
)
max_seqlen
,
seqlens
=
self
.
compute_attn_mask_seqlen
(
cu_seqlens
)
hidden_states_list
=
[]
deepstack_visual_indexes
=
self
.
deepstack_visual_indexes
for
layer_num
,
blk
in
enumerate
(
self
.
blocks
):
hidden_states
=
blk
(
hidden_states
,
cu_seqlens
=
cu_seqlens
,
rotary_pos_emb
=
rotary_pos_emb
,
max_seqlen
=
max_seqlen
,
seqlens
=
seqlens
,
)
if
(
deepstack_visual_indexes
is
not
None
and
layer_num
in
deepstack_visual_indexes
):
hidden_states_list
.
append
(
hidden_states
)
hidden_states
=
self
.
merger
(
hidden_states
)
# processing deepstack
if
deepstack_visual_indexes
is
not
None
:
processed_hidden_states_list
=
[
hidden_states
]
for
idx
,
x
in
enumerate
(
hidden_states_list
):
x
=
self
.
merger_list
[
idx
](
x
)
processed_hidden_states_list
.
append
(
x
)
# we cat the original visual features and deepstack features
# along the feature dim
hidden_states
=
torch
.
cat
(
processed_hidden_states_list
,
dim
=
1
)
# [seq_len, hidden_size * (1 + depth_of_deepstack)]
return
hidden_states
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"attn.qkv."
,
"attn.q."
,
"q"
),
(
"attn.qkv."
,
"attn.k."
,
"k"
),
(
"attn.qkv."
,
"attn.v."
,
"v"
),
]
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
loaded_params
:
set
[
str
]
=
set
()
for
name
,
loaded_weight
in
weights
:
if
name
.
endswith
(
"patch_embed.proj.weight"
):
loaded_weight
=
conv3d_to_linear_weight
(
loaded_weight
)
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
return
loaded_params
@
support_torch_compile
(
dynamic_arg_dims
=
{
"input_ids"
:
0
,
"positions"
:
-
1
,
"intermediate_tensors"
:
0
,
"inputs_embeds"
:
0
,
"deepstack_input_embeds"
:
0
,
}
)
class
Qwen3MoeLLMModel
(
Qwen3MoeModel
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
self
.
deepstack_multiscale_layer_start
=
1
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
deepstack_input_embeds
:
IntermediateTensors
|
None
=
None
,
)
->
torch
.
Tensor
|
IntermediateTensors
:
if
get_pp_group
().
is_first_rank
:
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
else
:
hidden_states
=
self
.
get_input_embeddings
(
input_ids
)
residual
=
None
else
:
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
for
layer_idx
,
layer
in
enumerate
(
self
.
layers
[
self
.
start_layer
:
self
.
end_layer
]
):
layer_idx
=
layer_idx
+
self
.
start_layer
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
residual
,
)
if
deepstack_input_embeds
is
not
None
and
layer_idx
in
range
(
0
,
len
(
deepstack_input_embeds
)
):
hidden_states
=
(
hidden_states
+
deepstack_input_embeds
[
f
"deepstack_input_embeds_
{
layer_idx
}
"
]
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
(
{
"hidden_states"
:
hidden_states
,
"residual"
:
residual
}
)
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
class
Qwen3MoeLLMForCausalLM
(
Qwen3MoeForCausalLM
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
(
Qwen3MoeForCausalLM
,
self
).
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
model
=
Qwen3MoeLLMModel
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
)
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
if
self
.
config
.
tie_word_embeddings
:
self
.
lm_head
.
weight
=
self
.
model
.
embed_tokens
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
class
Qwen3OmniMoeThinkerProcessingInfo
(
Qwen2AudioProcessingInfo
,
Qwen2_5_VLProcessingInfo
):
def
get_hf_config
(
self
):
return
self
.
ctx
.
get_hf_config
(
Qwen3OmniMoeConfig
).
thinker_config
def
get_hf_processor
(
self
,
**
kwargs
:
object
)
->
Qwen3OmniMoeProcessor
:
processor
=
self
.
ctx
.
get_hf_processor
(
Qwen3OmniMoeProcessor
,
use_fast
=
kwargs
.
pop
(
"use_fast"
,
True
),
**
kwargs
,
)
if
not
hasattr
(
processor
,
"audio_token"
):
processor
.
audio_token
=
"<|audio_pad|>"
if
not
hasattr
(
processor
,
"image_token"
):
processor
.
image_token
=
"<|image_pad|>"
if
not
hasattr
(
processor
,
"video_token"
):
processor
.
video_token
=
"<|video_pad|>"
return
processor
def
get_feature_extractor
(
self
,
**
kwargs
:
object
):
hf_processor
=
self
.
get_hf_processor
(
**
kwargs
)
feature_extractor
=
hf_processor
.
feature_extractor
# type: ignore
assert
isinstance
(
feature_extractor
,
WhisperFeatureExtractor
)
return
feature_extractor
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
int
|
None
]:
return
{
"audio"
:
None
,
"image"
:
None
,
"video"
:
None
}
Qwen3OmniMoeThinkerDummyInputsBuilder
=
Qwen2_5OmniThinkerDummyInputsBuilder
class
Qwen3OmniMoeThinkerMultiModalProcessor
(
Qwen2_5OmniThinkerMultiModalProcessor
,
):
def
_call_hf_processor
(
self
,
prompt
:
str
,
mm_data
:
Mapping
[
str
,
object
],
mm_kwargs
:
Mapping
[
str
,
object
],
tok_kwargs
:
Mapping
[
str
,
object
],
)
->
BatchFeature
:
mm_data
=
dict
(
mm_data
)
audios
=
mm_data
.
pop
(
"audios"
,
[])
def
pad_to_hop_length
(
x
:
np
.
ndarray
,
hop_length
:
int
)
->
np
.
ndarray
:
length
=
x
.
shape
[
-
1
]
if
length
%
hop_length
!=
0
:
pad_length
=
hop_length
-
(
length
%
hop_length
)
x
=
np
.
pad
(
x
,
(
0
,
pad_length
),
mode
=
"constant"
,
constant_values
=
0
)
return
x
# NOTE: WhisperFeatureExtractor cannot handle empty list of audios
feature_extractor
=
self
.
info
.
get_feature_extractor
()
hop_length
=
feature_extractor
.
hop_length
if
audios
:
# NOTE: Qwen3-Omni processor accept "audio"
# To make sure the cache works with padding=True, we pre-padded
# the audio to multiple of hop_length.
mm_data
[
"audio"
]
=
[
pad_to_hop_length
(
audio
,
hop_length
)
if
isinstance
(
audio
,
np
.
ndarray
)
else
(
pad_to_hop_length
(
audio
[
0
],
hop_length
),
audio
[
1
])
for
audio
in
audios
]
mm_kwargs
=
dict
(
**
mm_kwargs
,
)
# TODO(Isotr0py): Remove this patch after upstream fix PR
# released and Transformers version update:
# https://github.com/huggingface/transformers/pull/41473
if
(
Version
(
TRANSFORMERS_VERSION
)
<
Version
(
"4.58.0"
)
and
"truncation"
not
in
mm_kwargs
):
mm_kwargs
[
"truncation"
]
=
False
hf_inputs
=
super
().
_call_hf_processor
(
prompt
=
prompt
,
mm_data
=
mm_data
,
mm_kwargs
=
mm_kwargs
,
tok_kwargs
=
tok_kwargs
,
)
if
(
"audio_feature_lengths"
in
hf_inputs
and
"feature_attention_mask"
in
hf_inputs
and
(
audios
:
=
mm_data
.
get
(
"audio"
,
[]))
):
audio_num_frames
=
[]
for
_
,
audio
in
enumerate
(
audios
):
audio_length
=
len
(
audio
[
0
])
if
isinstance
(
audio
,
tuple
)
else
len
(
audio
)
num_frame
=
(
(
audio_length
//
hop_length
)
if
audio_length
%
hop_length
==
0
else
(
audio_length
//
hop_length
-
1
)
)
if
mm_kwargs
.
get
(
"truncation"
,
False
):
num_frame
=
min
(
num_frame
,
feature_extractor
.
n_samples
//
hop_length
)
audio_num_frames
.
append
(
num_frame
)
hf_inputs
[
"feature_attention_mask"
]
=
[
torch
.
ones
(
num_frame
)
for
num_frame
in
audio_num_frames
]
hf_inputs
[
"audio_feature_lengths"
]
=
torch
.
tensor
(
audio_num_frames
)
return
hf_inputs
def
_maybe_apply_prompt_updates
(
self
,
mm_items
:
MultiModalDataItems
,
prompt_ids
:
list
[
int
],
mm_kwargs
:
MultiModalKwargsItems
,
mm_prompt_updates
:
MultiModalPromptUpdates
,
is_update_applied
:
bool
,
)
->
tuple
[
list
[
int
],
str
,
Mapping
[
str
,
list
[
PlaceholderFeaturesInfo
]]]:
"""
Qwen3-Omni reimplements this function to handle `use_audio_in_video`.
"""
mm_item_counts
=
mm_items
.
get_all_counts
()
self
.
_validate_mm_kwargs
(
mm_kwargs
,
mm_item_counts
)
use_audio_in_video
=
False
if
"video"
in
mm_kwargs
:
for
item
in
mm_kwargs
[
"video"
]:
if
item
and
item
[
"use_audio_in_video"
].
data
:
use_audio_in_video
=
True
else
:
use_audio_in_video
=
False
if
use_audio_in_video
and
"video"
in
mm_item_counts
:
assert
"audio"
in
mm_item_counts
mm_item_counts
[
"audio"
]
-=
mm_item_counts
[
"video"
]
# Special case with `use_audio_in_video=True`
if
use_audio_in_video
:
if
is_update_applied
:
prompt_ids
=
self
.
_get_raw_input_ids
(
prompt_ids
,
use_audio_in_video
)
(
prompt_ids
,
mm_placeholders
,
)
=
self
.
_apply_prompt_updates
(
prompt_ids
,
mm_prompt_updates
,
)
self
.
_validate_mm_placeholders
(
mm_placeholders
,
mm_item_counts
)
# normal case with `use_audio_in_video=False`
elif
is_update_applied
:
mm_placeholders
=
self
.
_find_mm_placeholders
(
prompt_ids
,
mm_prompt_updates
,
)
self
.
_validate_mm_placeholders
(
mm_placeholders
,
mm_item_counts
,
)
else
:
prompt_ids
,
mm_placeholders
=
self
.
_apply_prompt_updates
(
prompt_ids
,
mm_prompt_updates
,
)
self
.
_validate_mm_placeholders
(
mm_placeholders
,
mm_item_counts
,
)
return
prompt_ids
,
mm_placeholders
def
get_updates_use_audio_in_video
(
self
,
thinker_config
:
PretrainedConfig
,
audio_len
:
int
,
video_grid_thw
:
list
[
int
]
|
torch
.
Tensor
,
video_second_per_grid_t
:
float
,
)
->
list
[
int
]:
shift
=
0
audio_token_id
=
thinker_config
.
audio_token_id
video_token_id
=
thinker_config
.
video_token_id
audio_start_token_id
=
thinker_config
.
audio_start_token_id
audio_end_token_id
=
thinker_config
.
audio_end_token_id
spatial_merge_size
=
thinker_config
.
vision_config
.
spatial_merge_size
position_id_per_seconds
=
thinker_config
.
position_id_per_seconds
audio_token_indices
=
np
.
arange
(
next
(
iter
([
audio_len
])))
curr_video_grid_thw
=
next
(
iter
([
video_grid_thw
]))
height
=
curr_video_grid_thw
[
1
]
//
spatial_merge_size
width
=
curr_video_grid_thw
[
2
]
//
spatial_merge_size
video_token_indices
=
np
.
arange
(
curr_video_grid_thw
[
0
]).
reshape
(
-
1
,
1
,
1
)
video_token_indices
=
np
.
broadcast_to
(
video_token_indices
,
(
video_token_indices
.
shape
[
0
],
height
,
width
)
).
reshape
(
-
1
)
video_token_indices
=
(
(
video_token_indices
+
shift
)
*
next
(
iter
([
video_second_per_grid_t
]))
*
position_id_per_seconds
)
video_data_index
,
audio_data_index
=
0
,
0
updates
=
[
audio_start_token_id
]
while
video_data_index
<
len
(
video_token_indices
)
and
audio_data_index
<
len
(
audio_token_indices
):
if
(
video_token_indices
[
video_data_index
]
<=
audio_token_indices
[
audio_data_index
]
):
updates
+=
[
video_token_id
]
video_data_index
+=
1
else
:
updates
+=
[
audio_token_id
]
audio_data_index
+=
1
if
video_data_index
<
len
(
video_token_indices
):
updates
+=
[
video_token_id
]
*
(
len
(
video_token_indices
)
-
video_data_index
)
if
audio_data_index
<
len
(
audio_token_indices
):
updates
+=
[
audio_token_id
]
*
(
len
(
audio_token_indices
)
-
audio_data_index
)
updates
+=
[
audio_end_token_id
]
return
updates
def
_get_prompt_updates
(
self
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
Any
],
out_mm_kwargs
:
MultiModalKwargsItems
,
)
->
Sequence
[
PromptUpdate
]:
processor
=
self
.
info
.
get_hf_processor
(
**
hf_processor_mm_kwargs
)
tokenizer
=
self
.
info
.
get_tokenizer
()
image_processor
=
self
.
info
.
get_image_processor
(
**
hf_processor_mm_kwargs
)
vocab
=
tokenizer
.
get_vocab
()
audio_token
=
processor
.
audio_token
image_token
=
processor
.
image_token
video_token
=
processor
.
video_token
audio_token_id
=
vocab
[
audio_token
]
image_token_id
=
vocab
[
image_token
]
video_token_id
=
vocab
[
video_token
]
out_mm_data
=
out_mm_kwargs
.
get_data
()
audio_feature_lengths
=
out_mm_data
.
get
(
"audio_feature_lengths"
)
feature_attention_mask
=
out_mm_data
.
get
(
"feature_attention_mask"
)
if
audio_feature_lengths
is
None
and
feature_attention_mask
is
None
:
audio_output_lengths
=
[]
elif
audio_feature_lengths
is
not
None
:
_
,
audio_output_lens
=
_get_feat_extract_output_lengths
(
audio_feature_lengths
)
audio_output_lengths
=
audio_output_lens
.
tolist
()
elif
feature_attention_mask
is
not
None
:
assert
isinstance
(
feature_attention_mask
,
torch
.
Tensor
)
_
,
audio_output_lens
=
_get_feat_extract_output_lengths
(
feature_attention_mask
.
sum
(
-
1
)
)
audio_output_lengths
=
audio_output_lens
.
tolist
()
# number of audios read from video.
audio_in_video_item_idx
=
0
audio_item_idx
=
0
def
get_replacement_qwen2_audio
(
item_idx
:
int
):
nonlocal
audio_item_idx
item_idx
+=
audio_in_video_item_idx
audio_item_idx
+=
1
num_features
=
audio_output_lengths
[
item_idx
]
if
num_features
==
0
:
audios
=
mm_items
.
get_items
(
"audio"
,
AudioProcessorItems
)
audio
=
audios
.
get
(
item_idx
)
raise
ValueError
(
f
"The audio
{
audio
}
(len=
{
len
(
audio
)
}
) is too short "
"to be represented inside the model"
)
return
[
audio_token_id
]
*
num_features
def
get_replacement_qwen2_vision
(
item_idx
:
int
,
modality
:
str
):
grid_thw
=
out_mm_data
[
f
"
{
modality
}
_grid_thw"
][
item_idx
]
assert
isinstance
(
grid_thw
,
torch
.
Tensor
)
merge_length
=
image_processor
.
merge_size
**
2
token_id
=
image_token_id
if
modality
==
"image"
else
video_token_id
return
[
token_id
]
*
(
int
(
grid_thw
.
prod
())
//
merge_length
)
use_audio_in_video
=
hf_processor_mm_kwargs
.
get
(
"use_audio_in_video"
,
False
)
thinker_config
=
self
.
info
.
get_hf_config
()
def
get_replacement_qwen2_use_audio_in_video
(
item_idx
:
int
):
nonlocal
audio_in_video_item_idx
audio_num_features
=
audio_output_lengths
[
audio_item_idx
+
item_idx
]
video_grid_thw
=
out_mm_data
[
"video_grid_thw"
][
item_idx
]
audio_in_video_item_idx
+=
1
second_per_grid_ts
=
hf_processor_mm_kwargs
.
get
(
"second_per_grid_ts"
,
None
)
if
second_per_grid_ts
:
video_second_per_grid_t
=
second_per_grid_ts
[
item_idx
]
else
:
video_second_per_grid_t
=
1.0
return
self
.
get_updates_use_audio_in_video
(
thinker_config
=
thinker_config
,
audio_len
=
audio_num_features
,
video_grid_thw
=
video_grid_thw
,
video_second_per_grid_t
=
video_second_per_grid_t
,
)
video_replacement_fn
=
(
get_replacement_qwen2_use_audio_in_video
if
use_audio_in_video
else
partial
(
get_replacement_qwen2_vision
,
modality
=
"video"
)
)
return
[
PromptReplacement
(
modality
=
"audio"
,
target
=
audio_token
,
replacement
=
get_replacement_qwen2_audio
,
),
PromptReplacement
(
modality
=
"image"
,
target
=
image_token
,
replacement
=
partial
(
get_replacement_qwen2_vision
,
modality
=
"image"
),
),
PromptReplacement
(
modality
=
"video"
,
target
=
video_token
,
replacement
=
video_replacement_fn
,
),
]
def
_validate_mm_placeholders
(
self
,
mm_placeholders
:
Mapping
[
str
,
list
[
PlaceholderFeaturesInfo
]],
mm_item_counts
:
Mapping
[
str
,
int
],
)
->
None
:
BaseMultiModalProcessor
[
Qwen2_5OmniThinkerProcessingInfo
].
_validate_mm_placeholders
(
self
,
mm_placeholders
,
mm_item_counts
)
def
_get_raw_input_ids
(
self
,
token_ids
:
list
[
int
],
use_audio_in_video
:
bool
=
False
,
)
->
list
[
int
]:
tokenizer
=
self
.
info
.
get_tokenizer
()
vision_bos_token
=
tokenizer
.
encode
(
tokenizer
.
vision_bos_token
)[
0
]
vision_eos_token
=
tokenizer
.
encode
(
tokenizer
.
vision_eos_token
)[
0
]
audio_bos_token
=
tokenizer
.
encode
(
tokenizer
.
audio_bos_token
)[
0
]
audio_eos_token
=
tokenizer
.
encode
(
tokenizer
.
audio_eos_token
)[
0
]
audio_token
=
tokenizer
.
encode
(
"<|audio_pad|>"
)[
0
]
image_token
=
tokenizer
.
encode
(
"<|image_pad|>"
)[
0
]
video_token
=
tokenizer
.
encode
(
"<|video_pad|>"
)[
0
]
result
=
token_ids
[:]
if
use_audio_in_video
:
while
True
:
start
=
None
for
i
in
range
(
len
(
result
)
-
1
):
if
result
[
i
:
i
+
2
]
==
[
vision_bos_token
,
audio_bos_token
]:
start
=
i
break
if
start
is
not
None
:
end
=
None
for
i
in
range
(
start
+
2
,
len
(
result
)
-
1
):
if
result
[
i
:
i
+
2
]
==
[
audio_eos_token
,
vision_eos_token
]:
end
=
i
break
if
end
is
not
None
:
result
=
(
result
[:
start
]
+
[
vision_bos_token
,
video_token
,
vision_eos_token
]
+
result
[
end
+
2
:]
)
else
:
break
for
mm_token
in
[
audio_token
,
image_token
,
video_token
]:
compressed
=
[]
for
x
in
result
:
if
x
!=
mm_token
or
(
not
compressed
or
compressed
[
-
1
]
!=
mm_token
):
compressed
.
append
(
x
)
result
=
compressed
return
result
class
Qwen3OmniMoeConditionalGenerationMixin
(
Qwen2_5OmniConditionalGenerationMixin
):
def
_validate_and_reshape_mm_tensor
(
self
,
mm_input
:
object
,
name
:
str
,
dim
:
int
=
0
)
->
torch
.
Tensor
:
if
not
isinstance
(
mm_input
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
f
"Incorrect type of
{
name
}
. Got type:
{
type
(
mm_input
)
}
"
)
if
name
==
"feature_attention_mask"
:
dim
=
-
1
if
isinstance
(
mm_input
,
torch
.
Tensor
):
return
torch
.
concat
(
list
(
mm_input
),
dim
=
dim
)
else
:
if
isinstance
(
mm_input
[
0
],
list
):
return
torch
.
concat
(
[
torch
.
concat
(
mm_input
[
i
],
dim
=
dim
)
for
i
in
range
(
len
(
mm_input
))],
dim
=
dim
,
)
else
:
return
torch
.
concat
(
mm_input
,
dim
=
dim
)
def
_process_audio_input
(
self
,
audio_input
:
Qwen2AudioFeatureInputs
,
audio_hashes
:
list
[
str
]
=
None
,
cached_audio_features
:
torch
.
Tensor
=
None
,
)
->
torch
.
Tensor
:
input_features
=
audio_input
[
"input_features"
]
audio_feature_lengths
=
audio_input
[
"audio_feature_lengths"
]
if
input_features
.
ndim
==
3
:
assert
input_features
.
shape
[
0
]
==
1
input_features
=
input_features
.
squeeze
(
0
)
if
not
isinstance
(
audio_feature_lengths
,
torch
.
Tensor
):
audio_feature_lengths
=
torch
.
cat
(
audio_feature_lengths
)
if
audio_feature_lengths
.
ndim
==
2
:
audio_feature_lengths
=
audio_feature_lengths
.
reshape
(
-
1
)
audio_feat_lengths
,
audio_output_lengths
=
(
_get_feat_extract_output_lengths
(
audio_feature_lengths
)
)
audio_outputs
=
self
.
audio_tower
(
input_features
.
to
(
self
.
audio_tower
.
dtype
),
feature_lens
=
audio_feature_lengths
,
aftercnn_lens
=
audio_feat_lengths
,
)
audio_features
=
audio_outputs
.
last_hidden_state
return
audio_features
.
split
(
audio_output_lengths
.
tolist
())
@
MULTIMODAL_REGISTRY
.
register_processor
(
Qwen3OmniMoeThinkerMultiModalProcessor
,
info
=
Qwen3OmniMoeThinkerProcessingInfo
,
dummy_inputs
=
Qwen3OmniMoeThinkerDummyInputsBuilder
,
)
class
Qwen3OmniMoeThinkerForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
,
SupportsMRoPE
,
Qwen3OmniMoeConditionalGenerationMixin
,
):
hf_to_vllm_mapper
=
WeightsMapper
(
orig_to_new_prefix
=
{
"thinker.lm_head."
:
"language_model.lm_head."
,
"thinker.model."
:
"language_model.model."
,
"thinker."
:
""
,
}
)
@
classmethod
def
get_placeholder_str
(
cls
,
modality
:
str
,
i
:
int
)
->
str
|
None
:
if
modality
.
startswith
(
"image"
):
return
"<|vision_start|><|image_pad|><|vision_end|>"
if
modality
.
startswith
(
"video"
):
return
"<|vision_start|><|video_pad|><|vision_end|>"
if
modality
.
startswith
(
"audio"
):
return
"<|audio_start|><|audio_pad|><|audio_end|>"
raise
ValueError
(
"Only image, video or audio modality is supported"
)
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
thinker_config
:
Qwen3OmniMoeThinkerConfig
=
(
vllm_config
.
model_config
.
hf_config
.
thinker_config
)
quant_config
=
vllm_config
.
quant_config
multimodal_config
=
vllm_config
.
model_config
.
multimodal_config
self
.
config
=
thinker_config
self
.
multimodal_config
=
multimodal_config
# force "use_flash_attention_2=True" to audio tower to align
# the results.
if
flash_attn
is
not
None
:
audio_config
=
thinker_config
.
audio_config
audio_config
.
_attn_implementation_autoset
=
True
audio_config
.
_attn_implementation
=
"flash_attention_2"
else
:
logger
.
warning
(
"flash_attn is not available, the model may not yield the "
"exactly same result as the transformers implementation "
"in the audio tower part."
)
self
.
audio_tower
=
Qwen3OmniMoeAudioEncoder
(
thinker_config
.
audio_config
)
attn_backend_override
=
(
multimodal_config
.
mm_encoder_attn_backend
if
multimodal_config
is
not
None
else
None
)
self
.
visual
=
Qwen3Omni_VisionTransformer
(
vision_config
=
thinker_config
.
vision_config
,
norm_eps
=
getattr
(
thinker_config
.
text_config
,
"rms_norm_eps"
,
1e-6
),
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"visual"
),
attn_backend_override
=
attn_backend_override
,
)
self
.
quant_config
=
quant_config
self
.
language_model
=
Qwen3MoeLLMForCausalLM
(
vllm_config
=
vllm_config
.
with_hf_config
(
thinker_config
.
text_config
,
architectures
=
[
"Qwen3MoeForCausalLM"
]
),
prefix
=
maybe_prefix
(
prefix
,
"language_model"
),
)
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
)
self
.
use_deepstack
=
hasattr
(
thinker_config
.
vision_config
,
"deepstack_visual_indexes"
)
self
.
deepstack_num_level
=
(
len
(
thinker_config
.
vision_config
.
deepstack_visual_indexes
)
if
self
.
use_deepstack
else
0
)
# register buffer for deepstack
self
.
deepstack_input_embeds
=
(
[
torch
.
zeros
(
vllm_config
.
scheduler_config
.
max_num_batched_tokens
,
thinker_config
.
text_config
.
hidden_size
,
)
for
_
in
range
(
self
.
deepstack_num_level
)
]
if
self
.
use_deepstack
else
None
)
self
.
visual_dim
=
thinker_config
.
vision_config
.
out_hidden_size
self
.
multiscale_dim
=
self
.
visual_dim
*
self
.
deepstack_num_level
def
_get_deepstack_input_embeds
(
self
,
num_tokens
:
int
)
->
IntermediateTensors
:
# get deepstack_input_embeds from buffer, and clear the buffer
return
IntermediateTensors
(
{
f
"deepstack_input_embeds_
{
idx
}
"
:
self
.
deepstack_input_embeds
[
idx
][
:
num_tokens
]
for
idx
in
range
(
self
.
deepstack_num_level
)
}
)
def
_set_deepstack_input_embeds
(
self
,
deepstack_input_embeds
:
torch
.
Tensor
)
->
None
:
# set deepstack_input_embeds to buffer
num_tokens
=
deepstack_input_embeds
.
size
(
1
)
if
num_tokens
>
self
.
deepstack_input_embeds
[
0
].
size
(
0
):
self
.
deepstack_input_embeds
=
[
torch
.
zeros
(
num_tokens
,
self
.
config
.
text_config
.
hidden_size
,
device
=
self
.
deepstack_input_embeds
[
0
].
device
,
dtype
=
self
.
deepstack_input_embeds
[
0
].
dtype
,
)
for
_
in
range
(
self
.
deepstack_num_level
)
]
for
idx
in
range
(
self
.
deepstack_num_level
):
self
.
deepstack_input_embeds
[
idx
][:
num_tokens
].
copy_
(
deepstack_input_embeds
[
idx
]
)
def
_clear_deepstack_input_embeds
(
self
,
num_tokens
:
int
)
->
None
:
# clear deepstack_input_embeds in buffer
if
num_tokens
>
0
:
for
idx
in
range
(
self
.
deepstack_num_level
):
self
.
deepstack_input_embeds
[
idx
][:
num_tokens
].
zero_
()
def
_parse_and_validate_multimodal_inputs
(
self
,
**
kwargs
:
object
)
->
dict
:
mm_input_by_modality
=
{}
# Preserve the order of modalities if there are multiple of them
# from the order of kwargs.
for
input_key
in
kwargs
:
if
(
input_key
in
(
"pixel_values"
,
"image_embeds"
)
and
"image"
not
in
mm_input_by_modality
):
mm_input_by_modality
[
"image"
]
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
(
input_key
in
(
"pixel_values_videos"
,
"video_embeds"
)
and
"video"
not
in
mm_input_by_modality
):
mm_input_by_modality
[
"video"
]
=
self
.
_parse_and_validate_video_input
(
**
kwargs
)
if
(
input_key
in
(
"input_audio_features"
)
and
"audio"
not
in
mm_input_by_modality
):
mm_input_by_modality
[
"audio"
]
=
self
.
_parse_and_validate_audio_input
(
**
kwargs
)
return
mm_input_by_modality
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
def
get_multimodal_embeddings
(
self
,
**
kwargs
:
object
)
->
MultiModalEmbeddings
|
None
:
mm_input_by_modality
=
self
.
_parse_and_validate_multimodal_inputs
(
**
kwargs
)
if
not
mm_input_by_modality
:
return
[]
# The result multimodal_embeddings is tuple of tensors, with each
# tensor correspoending to a multimodal data item (image or video).
multimodal_embeddings
:
tuple
[
torch
.
Tensor
,
...]
=
()
# NOTE: It is important to iterate over the keys in this dictionary
# to preserve the order of the modalities.
for
modality
in
mm_input_by_modality
:
multimodal_input
=
mm_input_by_modality
[
modality
]
if
modality
==
"image"
:
image_embeddings
=
self
.
_process_image_input
(
multimodal_input
)
multimodal_embeddings
+=
tuple
(
image_embeddings
)
if
modality
==
"video"
:
video_embeddings
=
self
.
_process_video_input
(
multimodal_input
)
multimodal_embeddings
+=
tuple
(
video_embeddings
)
if
modality
==
"audio"
:
audio_embeddings
=
self
.
_process_audio_input
(
multimodal_input
)
multimodal_embeddings
+=
tuple
(
audio_embeddings
)
return
multimodal_embeddings
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
,
multimodal_embeddings
:
MultiModalEmbeddings
|
None
=
None
,
*
,
is_multimodal
:
torch
.
Tensor
|
None
=
None
,
handle_oov_mm_token
:
bool
=
False
,
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
_get_text_embeddings
(
input_ids
,
self
.
language_model
.
get_input_embeddings
,
is_multimodal
=
is_multimodal
,
handle_oov_mm_token
=
handle_oov_mm_token
,
)
if
multimodal_embeddings
is
None
or
len
(
multimodal_embeddings
)
==
0
:
return
inputs_embeds
deepstack_input_embeds
=
None
# TODO (ywang96): support overlapping modalitiy embeddings so that
# `use_audio_in_video` will work on V1.
# split the feat dim to obtain multi-scale visual feature
has_vision_embeddings
=
[
embeddings
.
shape
[
-
1
]
!=
self
.
config
.
text_config
.
hidden_size
for
embeddings
in
multimodal_embeddings
]
if
self
.
visual
.
deepstack_visual_indexes
is
not
None
and
any
(
has_vision_embeddings
):
multiscale_len
=
len
(
self
.
visual
.
deepstack_visual_indexes
)
multimodal_embeddings_multiscale
=
[]
is_vision
=
torch
.
zeros_like
(
is_multimodal
)
mm_positions
=
torch
.
nonzero
(
is_multimodal
,
as_tuple
=
True
)[
0
]
mm_position_idx
=
0
for
index
,
embeddings
in
enumerate
(
multimodal_embeddings
):
num_tokens
=
embeddings
.
shape
[
0
]
current_positions
=
mm_positions
[
mm_position_idx
:
mm_position_idx
+
num_tokens
]
# Vision embeddings
if
embeddings
.
shape
[
-
1
]
!=
self
.
config
.
text_config
.
hidden_size
:
visual_dim
=
embeddings
.
shape
[
-
1
]
//
(
multiscale_len
+
1
)
multi_dim
=
visual_dim
*
multiscale_len
embeddings_main
,
embeddings_multiscale
=
torch
.
split
(
embeddings
,
[
visual_dim
,
multi_dim
],
dim
=-
1
)
multimodal_embeddings
[
index
]
=
embeddings_main
multimodal_embeddings_multiscale
.
append
(
embeddings_multiscale
)
is_vision
[
current_positions
]
=
True
# Audio embeddings
else
:
is_vision
[
current_positions
]
=
False
mm_position_idx
+=
num_tokens
deepstack_input_embeds
=
inputs_embeds
.
new_zeros
(
inputs_embeds
.
size
(
0
),
multiscale_len
*
inputs_embeds
.
size
(
1
)
)
deepstack_input_embeds
=
_merge_multimodal_embeddings
(
inputs_embeds
=
deepstack_input_embeds
,
multimodal_embeddings
=
multimodal_embeddings_multiscale
,
is_multimodal
=
is_vision
,
)
deepstack_input_embeds
=
(
deepstack_input_embeds
.
view
(
inputs_embeds
.
shape
[
0
],
multiscale_len
,
visual_dim
)
.
permute
(
1
,
0
,
2
)
.
contiguous
()
)
self
.
_set_deepstack_input_embeds
(
deepstack_input_embeds
)
inputs_embeds
=
_merge_multimodal_embeddings
(
inputs_embeds
=
inputs_embeds
,
multimodal_embeddings
=
multimodal_embeddings
,
is_multimodal
=
is_multimodal
,
)
return
inputs_embeds
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
inputs_embeds
:
torch
.
Tensor
|
None
=
None
,
**
kwargs
:
object
,
)
->
torch
.
Tensor
|
IntermediateTensors
:
if
intermediate_tensors
is
not
None
:
inputs_embeds
=
None
if
(
self
.
use_deepstack
and
inputs_embeds
is
not
None
and
get_pp_group
().
is_first_rank
):
deepstack_input_embeds
=
self
.
_get_deepstack_input_embeds
(
inputs_embeds
.
size
(
0
)
)
else
:
deepstack_input_embeds
=
None
hidden_states
=
self
.
language_model
.
model
(
input_ids
,
positions
,
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
,
# args for deepstack
deepstack_input_embeds
=
deepstack_input_embeds
,
)
if
inputs_embeds
is
not
None
and
get_pp_group
().
is_first_rank
:
self
.
_clear_deepstack_input_embeds
(
inputs_embeds
.
size
(
0
))
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
|
None
:
return
self
.
language_model
.
compute_logits
(
hidden_states
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
loader
=
AutoWeightsLoader
(
self
,
skip_prefixes
=
[
"talker."
,
"code2wav."
],
)
loaded_weights
=
loader
.
load_weights
(
weights
,
mapper
=
self
.
hf_to_vllm_mapper
)
return
loaded_weights
def
get_mrope_input_positions
(
self
,
input_tokens
:
list
[
int
],
hf_config
:
PretrainedConfig
,
image_grid_thw
:
list
[
list
[
int
]]
|
torch
.
Tensor
|
None
,
video_grid_thw
:
list
[
list
[
int
]]
|
torch
.
Tensor
|
None
,
second_per_grid_ts
:
list
[
float
]
|
None
=
None
,
context_len
:
int
=
0
,
seq_len
:
int
|
None
=
None
,
audio_feature_lengths
:
torch
.
Tensor
|
None
=
None
,
use_audio_in_video
:
bool
=
False
,
)
->
tuple
[
torch
.
Tensor
,
int
]:
config
=
hf_config
.
thinker_config
if
isinstance
(
image_grid_thw
,
list
):
image_grid_thw
=
torch
.
tensor
(
image_grid_thw
)
if
isinstance
(
video_grid_thw
,
list
):
video_grid_thw
=
torch
.
tensor
(
video_grid_thw
)
input_ids
=
torch
.
tensor
(
input_tokens
)
if
input_ids
is
None
or
input_ids
.
ndim
!=
1
:
raise
ValueError
(
"_omni3_get_input_positions_tensor expects 1D input_ids"
)
seq_len
=
input_ids
.
shape
[
0
]
if
audio_feature_lengths
is
not
None
and
not
isinstance
(
audio_feature_lengths
,
torch
.
Tensor
):
audio_feature_lengths
=
torch
.
as_tensor
(
audio_feature_lengths
,
dtype
=
torch
.
long
)
if
second_per_grid_ts
is
None
:
if
video_grid_thw
is
not
None
and
video_grid_thw
.
numel
()
>
0
:
second_per_grids
=
torch
.
ones
(
video_grid_thw
.
shape
[
0
],
dtype
=
torch
.
float32
)
else
:
second_per_grids
=
torch
.
tensor
([],
dtype
=
torch
.
float32
)
else
:
second_per_grids
=
torch
.
tensor
(
second_per_grid_ts
,
dtype
=
torch
.
float32
)
spatial_merge_size
=
config
.
vision_config
.
spatial_merge_size
image_token_id
=
config
.
image_token_id
video_token_id
=
config
.
video_token_id
audio_token_id
=
config
.
audio_token_id
vision_start_token_id
=
config
.
vision_start_token_id
audio_start_token_id
=
config
.
audio_start_token_id
position_id_per_seconds
=
config
.
position_id_per_seconds
vision_start_indices
=
torch
.
argwhere
(
input_ids
==
vision_start_token_id
).
squeeze
(
1
)
if
vision_start_indices
.
numel
()
>
0
:
vision_tokens
=
input_ids
[
vision_start_indices
+
1
]
else
:
vision_tokens
=
input_ids
.
new_empty
((
0
,),
dtype
=
input_ids
.
dtype
)
audio_nums
=
torch
.
sum
(
input_ids
==
audio_start_token_id
)
image_nums
=
(
vision_tokens
==
image_token_id
).
sum
()
video_nums
=
(
(
vision_tokens
==
audio_start_token_id
).
sum
()
if
use_audio_in_video
else
(
vision_tokens
==
video_token_id
).
sum
()
)
llm_pos_ids_list
:
list
[
torch
.
Tensor
]
=
[]
st
=
0
image_idx
=
0
video_idx
=
0
audio_idx
=
0
remain_images
,
remain_videos
,
remain_audios
=
image_nums
,
video_nums
,
audio_nums
# noqa: E501
multimodal_nums
=
(
image_nums
+
audio_nums
if
use_audio_in_video
else
image_nums
+
video_nums
+
audio_nums
)
# noqa: E501
for
_
in
range
(
multimodal_nums
):
st_idx
=
llm_pos_ids_list
[
-
1
].
max
()
+
1
if
llm_pos_ids_list
else
0
if
(
image_token_id
in
input_tokens
or
video_token_id
in
input_tokens
)
and
(
remain_videos
>
0
or
remain_images
>
0
):
ed_vision_start
=
input_tokens
.
index
(
vision_start_token_id
,
st
)
else
:
ed_vision_start
=
len
(
input_tokens
)
+
1
if
audio_token_id
in
input_tokens
and
remain_audios
>
0
:
ed_audio_start
=
input_tokens
.
index
(
audio_start_token_id
,
st
)
else
:
ed_audio_start
=
len
(
input_tokens
)
+
1
min_ed
=
min
(
ed_vision_start
,
ed_audio_start
)
if
min_ed
==
ed_audio_start
:
text_len
=
min_ed
-
st
if
text_len
!=
0
:
st_idx
=
llm_pos_ids_list
[
-
1
].
max
()
+
1
if
llm_pos_ids_list
else
0
llm_pos_ids_list
.
append
(
torch
.
arange
(
text_len
,
dtype
=
torch
.
long
)
.
view
(
1
,
-
1
)
.
expand
(
3
,
-
1
)
+
st_idx
)
st_idx
=
llm_pos_ids_list
[
-
1
].
max
()
+
1
if
llm_pos_ids_list
else
0
bos_len
=
1
llm_pos_ids_list
.
append
(
torch
.
arange
(
bos_len
,
dtype
=
torch
.
long
).
view
(
1
,
-
1
).
expand
(
3
,
-
1
)
+
st_idx
)
st_idx
=
llm_pos_ids_list
[
-
1
].
max
()
+
1
if
llm_pos_ids_list
else
0
_
,
audio_len
=
_get_feat_extract_output_lengths
(
audio_feature_lengths
[
audio_idx
]
)
llm_pos_ids
=
(
torch
.
arange
(
audio_len
,
dtype
=
torch
.
long
).
view
(
1
,
-
1
).
expand
(
3
,
-
1
)
+
st_idx
)
llm_pos_ids_list
.
append
(
llm_pos_ids
)
st_idx
=
llm_pos_ids_list
[
-
1
].
max
()
+
1
if
llm_pos_ids_list
else
0
eos_len
=
1
llm_pos_ids_list
.
append
(
torch
.
arange
(
eos_len
,
dtype
=
torch
.
long
).
view
(
1
,
-
1
).
expand
(
3
,
-
1
)
+
st_idx
)
st
+=
text_len
+
bos_len
+
audio_len
+
eos_len
audio_idx
+=
1
remain_audios
-=
1
elif
(
min_ed
==
ed_vision_start
and
input_ids
[
ed_vision_start
+
1
]
==
image_token_id
):
text_len
=
min_ed
-
st
if
text_len
!=
0
:
st_idx
=
llm_pos_ids_list
[
-
1
].
max
()
+
1
if
llm_pos_ids_list
else
0
llm_pos_ids_list
.
append
(
torch
.
arange
(
text_len
,
dtype
=
torch
.
long
)
.
view
(
1
,
-
1
)
.
expand
(
3
,
-
1
)
+
st_idx
)
st_idx
=
llm_pos_ids_list
[
-
1
].
max
()
+
1
if
llm_pos_ids_list
else
0
bos_len
=
1
llm_pos_ids_list
.
append
(
torch
.
arange
(
bos_len
,
dtype
=
torch
.
long
).
view
(
1
,
-
1
).
expand
(
3
,
-
1
)
+
st_idx
)
st_idx
=
llm_pos_ids_list
[
-
1
].
max
()
+
1
if
llm_pos_ids_list
else
0
grid_t
=
image_grid_thw
[
image_idx
][
0
]
grid_hs
=
image_grid_thw
[:,
1
]
grid_ws
=
image_grid_thw
[:,
2
]
t_index
=
torch
.
arange
(
grid_t
)
*
position_id_per_seconds
llm_pos_ids
=
get_llm_pos_ids_for_vision
(
st_idx
,
image_idx
,
spatial_merge_size
,
t_index
,
grid_hs
,
grid_ws
)
image_len
=
image_grid_thw
[
image_idx
].
prod
()
//
(
spatial_merge_size
**
2
)
llm_pos_ids_list
.
append
(
llm_pos_ids
)
st_idx
=
llm_pos_ids_list
[
-
1
].
max
()
+
1
if
llm_pos_ids_list
else
0
eos_len
=
1
llm_pos_ids_list
.
append
(
torch
.
arange
(
eos_len
,
dtype
=
torch
.
long
).
view
(
1
,
-
1
).
expand
(
3
,
-
1
)
+
st_idx
)
st
+=
text_len
+
bos_len
+
image_len
+
eos_len
image_idx
+=
1
remain_images
-=
1
elif
(
min_ed
==
ed_vision_start
and
input_ids
[
ed_vision_start
+
1
]
==
video_token_id
and
not
use_audio_in_video
):
text_len
=
min_ed
-
st
if
text_len
!=
0
:
st_idx
=
llm_pos_ids_list
[
-
1
].
max
()
+
1
if
llm_pos_ids_list
else
0
llm_pos_ids_list
.
append
(
torch
.
arange
(
text_len
,
dtype
=
torch
.
long
)
.
view
(
1
,
-
1
)
.
expand
(
3
,
-
1
)
+
st_idx
)
st_idx
=
llm_pos_ids_list
[
-
1
].
max
()
+
1
if
llm_pos_ids_list
else
0
bos_len
=
1
llm_pos_ids_list
.
append
(
torch
.
arange
(
bos_len
,
dtype
=
torch
.
long
).
view
(
1
,
-
1
).
expand
(
3
,
-
1
)
+
st_idx
)
st_idx
=
llm_pos_ids_list
[
-
1
].
max
()
+
1
if
llm_pos_ids_list
else
0
grid_t
=
video_grid_thw
[
video_idx
][
0
]
grid_hs
=
video_grid_thw
[:,
1
]
grid_ws
=
video_grid_thw
[:,
2
]
t_index
=
(
torch
.
arange
(
grid_t
)
*
float
(
second_per_grids
[
video_idx
].
item
())
*
position_id_per_seconds
)
llm_pos_ids
=
get_llm_pos_ids_for_vision
(
st_idx
,
video_idx
,
spatial_merge_size
,
t_index
,
grid_hs
,
grid_ws
)
video_len
=
video_grid_thw
[
video_idx
].
prod
()
//
(
spatial_merge_size
**
2
)
llm_pos_ids_list
.
append
(
llm_pos_ids
)
st_idx
=
llm_pos_ids_list
[
-
1
].
max
()
+
1
if
llm_pos_ids_list
else
0
eos_len
=
1
llm_pos_ids_list
.
append
(
torch
.
arange
(
eos_len
,
dtype
=
torch
.
long
).
view
(
1
,
-
1
).
expand
(
3
,
-
1
)
+
st_idx
)
st
+=
text_len
+
bos_len
+
video_len
+
eos_len
video_idx
+=
1
remain_videos
-=
1
elif
(
min_ed
==
ed_vision_start
and
ed_vision_start
+
1
==
ed_audio_start
and
use_audio_in_video
):
text_len
=
min_ed
-
st
if
text_len
!=
0
:
st_idx
=
llm_pos_ids_list
[
-
1
].
max
()
+
1
if
llm_pos_ids_list
else
0
llm_pos_ids_list
.
append
(
torch
.
arange
(
text_len
,
dtype
=
torch
.
long
)
.
view
(
1
,
-
1
)
.
expand
(
3
,
-
1
)
+
st_idx
)
st_idx
=
llm_pos_ids_list
[
-
1
].
max
()
+
1
if
llm_pos_ids_list
else
0
bos_len
=
1
bos_block
=
(
torch
.
arange
(
bos_len
,
dtype
=
torch
.
long
).
view
(
1
,
-
1
).
expand
(
3
,
-
1
)
+
st_idx
)
llm_pos_ids_list
.
append
(
bos_block
)
llm_pos_ids_list
.
append
(
bos_block
)
st_idx
=
llm_pos_ids_list
[
-
1
].
max
()
+
1
if
llm_pos_ids_list
else
0
_
,
audio_len
=
_get_feat_extract_output_lengths
(
audio_feature_lengths
[
audio_idx
]
)
audio_llm_pos_ids
=
(
torch
.
arange
(
audio_len
,
dtype
=
torch
.
long
).
view
(
1
,
-
1
).
expand
(
3
,
-
1
)
+
st_idx
)
grid_t
=
video_grid_thw
[
video_idx
][
0
]
grid_hs
=
video_grid_thw
[:,
1
]
grid_ws
=
video_grid_thw
[:,
2
]
t_index
=
(
torch
.
arange
(
grid_t
)
*
float
(
second_per_grids
[
video_idx
].
item
())
*
position_id_per_seconds
)
video_llm_pos_ids
=
get_llm_pos_ids_for_vision
(
st_idx
,
video_idx
,
spatial_merge_size
,
t_index
,
grid_hs
,
grid_ws
)
video_data_index
,
audio_data_index
=
0
,
0
while
(
video_data_index
<
video_llm_pos_ids
.
shape
[
-
1
]
and
audio_data_index
<
audio_llm_pos_ids
.
shape
[
-
1
]
):
if
(
video_llm_pos_ids
[
0
][
video_data_index
]
<=
audio_llm_pos_ids
[
0
][
audio_data_index
]
):
llm_pos_ids_list
.
append
(
video_llm_pos_ids
[
:,
video_data_index
:
video_data_index
+
1
]
)
video_data_index
+=
1
else
:
llm_pos_ids_list
.
append
(
audio_llm_pos_ids
[
:,
audio_data_index
:
audio_data_index
+
1
]
)
audio_data_index
+=
1
if
video_data_index
<
video_llm_pos_ids
.
shape
[
-
1
]:
llm_pos_ids_list
.
append
(
video_llm_pos_ids
[
:,
video_data_index
:
video_llm_pos_ids
.
shape
[
-
1
]
]
)
if
audio_data_index
<
audio_llm_pos_ids
.
shape
[
-
1
]:
llm_pos_ids_list
.
append
(
audio_llm_pos_ids
[
:,
audio_data_index
:
audio_llm_pos_ids
.
shape
[
-
1
]
]
)
video_len
=
video_grid_thw
[
video_idx
].
prod
()
//
(
spatial_merge_size
**
2
)
st_idx
=
llm_pos_ids_list
[
-
1
].
max
()
+
1
if
llm_pos_ids_list
else
0
eos_len
=
1
eos_block
=
(
torch
.
arange
(
eos_len
,
dtype
=
torch
.
long
).
view
(
1
,
-
1
).
expand
(
3
,
-
1
)
+
st_idx
)
llm_pos_ids_list
.
append
(
eos_block
)
llm_pos_ids_list
.
append
(
eos_block
)
st
+=
text_len
+
bos_len
*
2
+
audio_len
+
video_len
+
eos_len
*
2
# noqa: E501
audio_idx
+=
1
video_idx
+=
1
remain_videos
-=
1
remain_audios
-=
1
if
st
<
len
(
input_tokens
):
st_idx
=
llm_pos_ids_list
[
-
1
].
max
()
+
1
if
llm_pos_ids_list
else
0
text_len
=
len
(
input_tokens
)
-
st
llm_pos_ids_list
.
append
(
torch
.
arange
(
text_len
,
dtype
=
torch
.
long
).
view
(
1
,
-
1
).
expand
(
3
,
-
1
)
+
st_idx
)
llm_positions
=
torch
.
cat
(
llm_pos_ids_list
,
dim
=
1
).
reshape
(
3
,
-
1
)
if
llm_positions
.
shape
[
1
]
!=
seq_len
:
raise
RuntimeError
(
"Position ids length mismatch with input ids length"
)
mrope_position_delta
=
llm_positions
.
max
()
+
1
-
seq_len
return
llm_positions
,
mrope_position_delta
\ No newline at end of file
vllm/model_executor/models/registry.py
View file @
e9cb2241
...
@@ -270,6 +270,7 @@ _MULTIMODAL_MODELS = {
...
@@ -270,6 +270,7 @@ _MULTIMODAL_MODELS = {
"Qwen2AudioForConditionalGeneration"
:
(
"qwen2_audio"
,
"Qwen2AudioForConditionalGeneration"
),
# noqa: E501
"Qwen2AudioForConditionalGeneration"
:
(
"qwen2_audio"
,
"Qwen2AudioForConditionalGeneration"
),
# noqa: E501
"Qwen2_5OmniModel"
:
(
"qwen2_5_omni_thinker"
,
"Qwen2_5OmniThinkerForConditionalGeneration"
),
# noqa: E501
"Qwen2_5OmniModel"
:
(
"qwen2_5_omni_thinker"
,
"Qwen2_5OmniThinkerForConditionalGeneration"
),
# noqa: E501
"Qwen2_5OmniForConditionalGeneration"
:
(
"qwen2_5_omni_thinker"
,
"Qwen2_5OmniThinkerForConditionalGeneration"
),
# noqa: E501
"Qwen2_5OmniForConditionalGeneration"
:
(
"qwen2_5_omni_thinker"
,
"Qwen2_5OmniThinkerForConditionalGeneration"
),
# noqa: E501
"Qwen3OmniMoeForConditionalGeneration"
:
(
"qwen3_omni_moe_thinker"
,
"Qwen3OmniMoeThinkerForConditionalGeneration"
),
"Qwen3VLForConditionalGeneration"
:
(
"qwen3_vl"
,
"Qwen3VLForConditionalGeneration"
),
# noqa: E501
"Qwen3VLForConditionalGeneration"
:
(
"qwen3_vl"
,
"Qwen3VLForConditionalGeneration"
),
# noqa: E501
"Qwen3VLMoeForConditionalGeneration"
:
(
"qwen3_vl_moe"
,
"Qwen3VLMoeForConditionalGeneration"
),
# noqa: E501
"Qwen3VLMoeForConditionalGeneration"
:
(
"qwen3_vl_moe"
,
"Qwen3VLMoeForConditionalGeneration"
),
# noqa: E501
"SkyworkR1VChatModel"
:
(
"skyworkr1v"
,
"SkyworkR1VChatModel"
),
"SkyworkR1VChatModel"
:
(
"skyworkr1v"
,
"SkyworkR1VChatModel"
),
...
...
vllm/model_executor/models/vision.py
View file @
e9cb2241
...
@@ -72,10 +72,18 @@ def get_vision_encoder_info(
...
@@ -72,10 +72,18 @@ def get_vision_encoder_info(
raise
NotImplementedError
(
msg
)
raise
NotImplementedError
(
msg
)
def
get_vit_attn_backend
(
head_size
:
int
,
dtype
:
torch
.
dtype
)
->
_Backend
:
def
get_vit_attn_backend
(
head_size
:
int
,
dtype
:
torch
.
dtype
,
*
,
attn_backend_override
:
_Backend
|
None
=
None
,
)
->
_Backend
:
"""
"""
Get the available attention backend for Vision Transformer.
Get the available attention backend for Vision Transformer.
"""
"""
if
attn_backend_override
is
not
None
:
return
attn_backend_override
# Lazy import to avoid circular dependency
# Lazy import to avoid circular dependency
from
vllm.attention.selector
import
get_env_variable_attn_backend
from
vllm.attention.selector
import
get_env_variable_attn_backend
...
@@ -402,3 +410,56 @@ def run_dp_sharded_mrope_vision_model(
...
@@ -402,3 +410,56 @@ def run_dp_sharded_mrope_vision_model(
assert
len
(
out_embeddings
)
==
len
(
assert
len
(
out_embeddings
)
==
len
(
original_order_embeddings
),
"Found unassigned embeddings"
original_order_embeddings
),
"Found unassigned embeddings"
return
out_embeddings
return
out_embeddings
def
get_llm_pos_ids_for_vision
(
start_idx
:
int
,
vision_idx
:
int
,
spatial_merge_size
:
int
,
t_index
:
list
[
int
],
grid_hs
:
torch
.
Tensor
,
grid_ws
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
llm_pos_ids_list
=
[]
llm_grid_h
=
grid_hs
[
vision_idx
]
//
spatial_merge_size
llm_grid_w
=
grid_ws
[
vision_idx
]
//
spatial_merge_size
h_index
=
(
torch
.
arange
(
llm_grid_h
)
.
view
(
1
,
-
1
,
1
)
.
expand
(
len
(
t_index
),
-
1
,
llm_grid_w
)
.
flatten
()
)
w_index
=
(
torch
.
arange
(
llm_grid_w
)
.
view
(
1
,
1
,
-
1
)
.
expand
(
len
(
t_index
),
llm_grid_h
,
-
1
)
.
flatten
()
)
t_index_tensor
=
(
torch
.
Tensor
(
t_index
)
.
to
(
llm_grid_h
.
device
)
.
view
(
-
1
,
1
)
.
expand
(
-
1
,
llm_grid_h
*
llm_grid_w
)
.
long
()
.
flatten
()
)
_llm_pos_ids
=
torch
.
stack
([
t_index_tensor
,
h_index
,
w_index
])
llm_pos_ids_list
.
append
(
_llm_pos_ids
+
start_idx
)
llm_pos_ids
=
torch
.
cat
(
llm_pos_ids_list
,
dim
=
1
)
return
llm_pos_ids
# Due to a performance regression with Conv3D in PyTorch2.9, we reshape
# Conv3D weights to Linear weights for better performance.
# See: https://github.com/vllm-project/vllm/issues/27406
# and https://github.com/pytorch/pytorch/issues/166122
# FIXME(Isotr0py): Revert the PR introduces this workaround
# (https://github.com/vllm-project/vllm/pull/27418),
# once the performance issue is resolved in PyTorch.
def
conv3d_to_linear_weight
(
conv3d_weight
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Reshape Conv3D weight to Linear weight. Only work when kernel_size==stride.
"""
out_channels
,
in_channels
,
kt
,
kh
,
kw
=
conv3d_weight
.
shape
linear_weight
=
conv3d_weight
.
reshape
(
out_channels
,
in_channels
*
kt
*
kh
*
kw
)
return
linear_weight
vllm/v1/worker/gpu_model_runner.py
View file @
e9cb2241
...
@@ -752,7 +752,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -752,7 +752,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if
mm_input
.
get
(
"use_audio_in_video"
)
is
True
:
if
mm_input
.
get
(
"use_audio_in_video"
)
is
True
:
use_audio_in_video
=
True
use_audio_in_video
=
True
if
supports_mrope
(
self
.
model
):
if
supports_mrope
(
self
.
get_
model
()
):
req_state
.
mrope_positions
,
req_state
.
mrope_position_delta
=
\
req_state
.
mrope_positions
,
req_state
.
mrope_position_delta
=
\
self
.
model
.
get_mrope_input_positions
(
self
.
model
.
get_mrope_input_positions
(
req_state
.
prompt_token_ids
,
req_state
.
prompt_token_ids
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment