Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e9cb2241
Commit
e9cb2241
authored
Oct 28, 2025
by
zhuwenwen
Browse files
Add Qwen3-Omni moe thinker
parent
f007cd06
Changes
8
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
1827 additions
and
17 deletions
+1827
-17
docs/models/supported_models.md
docs/models/supported_models.md
+2
-1
tests/models/multimodal/processing/test_common.py
tests/models/multimodal/processing/test_common.py
+1
-0
tests/models/registry.py
tests/models/registry.py
+3
-0
vllm/model_executor/layers/rotary_embedding/mrope.py
vllm/model_executor/layers/rotary_embedding/mrope.py
+14
-14
vllm/model_executor/models/qwen3_omni_moe_thinker.py
vllm/model_executor/models/qwen3_omni_moe_thinker.py
+1743
-0
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+1
-0
vllm/model_executor/models/vision.py
vllm/model_executor/models/vision.py
+62
-1
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+1
-1
No files found.
docs/models/supported_models.md
View file @
e9cb2241
...
@@ -689,6 +689,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
...
@@ -689,6 +689,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
|
`Qwen2_5OmniThinkerForConditionalGeneration`
| Qwen2.5-Omni | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
+ A
<sup>
+
</sup>
|
`Qwen/Qwen2.5-Omni-3B`
,
`Qwen/Qwen2.5-Omni-7B`
| ✅︎ | ✅︎ | ✅︎ |
|
`Qwen2_5OmniThinkerForConditionalGeneration`
| Qwen2.5-Omni | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
+ A
<sup>
+
</sup>
|
`Qwen/Qwen2.5-Omni-3B`
,
`Qwen/Qwen2.5-Omni-7B`
| ✅︎ | ✅︎ | ✅︎ |
|
`Qwen3VLForConditionalGeneration`
| Qwen3-VL | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
|
`Qwen/Qwen3-VL-4B-Instruct`
, etc. | ✅︎ | ✅︎ | ✅︎ |
|
`Qwen3VLForConditionalGeneration`
| Qwen3-VL | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
|
`Qwen/Qwen3-VL-4B-Instruct`
, etc. | ✅︎ | ✅︎ | ✅︎ |
|
`Qwen3VLMoeForConditionalGeneration`
| Qwen3-VL-MOE | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
|
`Qwen/Qwen3-VL-30B-A3B-Instruct`
, etc. | ✅︎ | ✅︎ | ✅︎ |
|
`Qwen3VLMoeForConditionalGeneration`
| Qwen3-VL-MOE | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
|
`Qwen/Qwen3-VL-30B-A3B-Instruct`
, etc. | ✅︎ | ✅︎ | ✅︎ |
|
`Qwen3OmniMoeThinkerForConditionalGeneration`
| Qwen3-Omni | T + I
<sup>
E+
</sup>
+ V
<sup>
E+
</sup>
+ A
<sup>
+
</sup>
|
`Qwen/Qwen3-Omni-30B-A3B-Instruct`
,
`Qwen/Qwen3-Omni-30B-A3B-Thinking`
| ✅︎ | ✅︎ | ✅︎ |
|
`RForConditionalGeneration`
| R-VL-4B | T + I
<sup>
E+
</sup>
|
`YannQi/R-4B`
| | ✅︎ | ✅︎ |
|
`RForConditionalGeneration`
| R-VL-4B | T + I
<sup>
E+
</sup>
|
`YannQi/R-4B`
| | ✅︎ | ✅︎ |
|
`SkyworkR1VChatModel`
| Skywork-R1V-38B | T + I |
`Skywork/Skywork-R1V-38B`
| | ✅︎ | ✅︎ |
|
`SkyworkR1VChatModel`
| Skywork-R1V-38B | T + I |
`Skywork/Skywork-R1V-38B`
| | ✅︎ | ✅︎ |
|
`SmolVLMForConditionalGeneration`
| SmolVLM2 | T + I |
`SmolVLM2-2.2B-Instruct`
| ✅︎ | | ✅︎ |
|
`SmolVLMForConditionalGeneration`
| SmolVLM2 | T + I |
`SmolVLM2-2.2B-Instruct`
| ✅︎ | | ✅︎ |
...
@@ -779,7 +780,7 @@ Some models are supported only via the [Transformers backend](#transformers). Th
...
@@ -779,7 +780,7 @@ Some models are supported only via the [Transformers backend](#transformers). Th
Our PaliGemma implementations have the same problem as Gemma 3 (see above) for both V0 and V1.
Our PaliGemma implementations have the same problem as Gemma 3 (see above) for both V0 and V1.
!!! note
!!! note
For Qwen2.5-Omni, reading audio from video pre-processing (
`--mm-processor-kwargs '{"use_audio_in_video": true}'`
)
For Qwen2.5-Omni
and Qwen3-Omni
, reading audio from video pre-processing (
`--mm-processor-kwargs '{"use_audio_in_video": true}'`
)
is currently supported on V0 (but not V1), because overlapping modalities is not yet supported in V1.
is currently supported on V0 (but not V1), because overlapping modalities is not yet supported in V1.
#### Transcription
#### Transcription
...
...
tests/models/multimodal/processing/test_common.py
View file @
e9cb2241
...
@@ -362,6 +362,7 @@ def _test_processing_correctness_one(
...
@@ -362,6 +362,7 @@ def _test_processing_correctness_one(
os
.
path
.
join
(
models_path_prefix
,
"Qwen/Qwen2.5-Omni-3B"
),
os
.
path
.
join
(
models_path_prefix
,
"Qwen/Qwen2.5-Omni-3B"
),
os
.
path
.
join
(
models_path_prefix
,
"Qwen/Qwen3-VL-4B-Instruct"
),
os
.
path
.
join
(
models_path_prefix
,
"Qwen/Qwen3-VL-4B-Instruct"
),
os
.
path
.
join
(
models_path_prefix
,
"Qwen/Qwen3-VL-30B-A3B-Instruct"
),
os
.
path
.
join
(
models_path_prefix
,
"Qwen/Qwen3-VL-30B-A3B-Instruct"
),
os
.
path
.
join
(
models_path_prefix
,
"Qwen/Qwen3-Omni-30B-A3B-Instruct"
),
os
.
path
.
join
(
models_path_prefix
,
"YannQi/R-4B"
),
os
.
path
.
join
(
models_path_prefix
,
"YannQi/R-4B"
),
os
.
path
.
join
(
models_path_prefix
,
"Skywork/Skywork-R1V-38B"
),
os
.
path
.
join
(
models_path_prefix
,
"Skywork/Skywork-R1V-38B"
),
os
.
path
.
join
(
models_path_prefix
,
"HuggingFaceTB/SmolVLM2-2.2B-Instruct"
),
os
.
path
.
join
(
models_path_prefix
,
"HuggingFaceTB/SmolVLM2-2.2B-Instruct"
),
...
...
tests/models/registry.py
View file @
e9cb2241
...
@@ -580,6 +580,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
...
@@ -580,6 +580,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
max_model_len
=
4096
,
max_model_len
=
4096
,
min_transformers_version
=
"4.57"
,
min_transformers_version
=
"4.57"
,
is_available_online
=
False
),
is_available_online
=
False
),
"Qwen3OmniMoeForConditionalGeneration"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"Qwen/Qwen3-Omni-30B-A3B-Instruct"
),
max_model_len
=
4096
,
min_transformers_version
=
"4.57"
),
"RForConditionalGeneration"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"YannQi/R-4B"
),
"RForConditionalGeneration"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"YannQi/R-4B"
),
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"SkyworkR1VChatModel"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"Skywork/Skywork-R1V-38B"
),
"SkyworkR1VChatModel"
:
_HfExamplesInfo
(
os
.
path
.
join
(
models_path_prefix
,
"Skywork/Skywork-R1V-38B"
),
...
...
vllm/model_executor/layers/rotary_embedding/mrope.py
View file @
e9cb2241
...
@@ -429,7 +429,7 @@ class MRotaryEmbedding(RotaryEmbedding):
...
@@ -429,7 +429,7 @@ class MRotaryEmbedding(RotaryEmbedding):
use_audio_in_video
:
bool
=
False
,
use_audio_in_video
:
bool
=
False
,
)
->
tuple
[
torch
.
Tensor
,
int
]:
)
->
tuple
[
torch
.
Tensor
,
int
]:
from
vllm.transformers_utils.config
import
thinker_uses_mrope
from
vllm.transformers_utils.config
import
thinker_uses_mrope
if
thinker_uses_mrope
(
hf_config
):
if
thinker_uses_mrope
(
hf_config
)
and
hf_config
.
model_type
==
"qwen2_5_omni"
:
return
cls
.
_omni_get_input_positions_tensor
(
return
cls
.
_omni_get_input_positions_tensor
(
input_tokens
=
input_tokens
,
input_tokens
=
input_tokens
,
hf_config
=
hf_config
,
hf_config
=
hf_config
,
...
@@ -899,7 +899,7 @@ class MRotaryEmbedding(RotaryEmbedding):
...
@@ -899,7 +899,7 @@ class MRotaryEmbedding(RotaryEmbedding):
torch
.
arange
(
text_len
).
view
(
1
,
-
1
).
expand
(
3
,
-
1
)
+
st_idx
)
torch
.
arange
(
text_len
).
view
(
1
,
-
1
).
expand
(
3
,
-
1
)
+
st_idx
)
t_index
=
(
torch
.
arange
(
llm_grid_t
).
view
(
-
1
,
1
).
expand
(
t_index
=
(
torch
.
arange
(
llm_grid_t
).
view
(
-
1
,
1
).
expand
(
-
1
,
llm_grid_h
*
llm_grid_w
)).
long
().
flatten
()
-
1
,
llm_grid_h
*
llm_grid_w
)).
flatten
()
h_index
=
torch
.
arange
(
llm_grid_h
).
view
(
1
,
-
1
,
1
).
expand
(
h_index
=
torch
.
arange
(
llm_grid_h
).
view
(
1
,
-
1
,
1
).
expand
(
llm_grid_t
,
-
1
,
llm_grid_w
).
flatten
()
llm_grid_t
,
-
1
,
llm_grid_w
).
flatten
()
...
@@ -1003,8 +1003,8 @@ class MRotaryEmbedding(RotaryEmbedding):
...
@@ -1003,8 +1003,8 @@ class MRotaryEmbedding(RotaryEmbedding):
torch
.
arange
(
text_len
).
view
(
1
,
-
1
).
expand
(
3
,
-
1
)
+
st_idx
)
torch
.
arange
(
text_len
).
view
(
1
,
-
1
).
expand
(
3
,
-
1
)
+
st_idx
)
t_index
=
(
torch
.
arange
(
llm_grid_t
).
view
(
-
1
,
1
).
expand
(
t_index
=
(
torch
.
arange
(
llm_grid_t
).
view
(
-
1
,
1
).
expand
(
-
1
,
llm_grid_h
*
llm_grid_w
)
*
video_second_per_grid_t
*
-
1
,
llm_grid_h
*
llm_grid_w
)
*
video_second_per_grid_t
*
tokens_per_second
).
long
().
flatten
()
tokens_per_second
).
flatten
()
h_index
=
torch
.
arange
(
llm_grid_h
).
view
(
1
,
-
1
,
1
).
expand
(
h_index
=
torch
.
arange
(
llm_grid_h
).
view
(
1
,
-
1
,
1
).
expand
(
llm_grid_t
,
-
1
,
llm_grid_w
).
flatten
()
llm_grid_t
,
-
1
,
llm_grid_w
).
flatten
()
...
@@ -1061,6 +1061,11 @@ class MRotaryEmbedding(RotaryEmbedding):
...
@@ -1061,6 +1061,11 @@ class MRotaryEmbedding(RotaryEmbedding):
# _vl_get_input_positions_tensor.
# _vl_get_input_positions_tensor.
thinker_config
=
hf_config
.
thinker_config
thinker_config
=
hf_config
.
thinker_config
if
isinstance
(
image_grid_thw
,
list
):
image_grid_thw
=
torch
.
tensor
(
image_grid_thw
)
if
isinstance
(
video_grid_thw
,
list
):
video_grid_thw
=
torch
.
tensor
(
video_grid_thw
)
audio_token_id
=
thinker_config
.
audio_token_index
audio_token_id
=
thinker_config
.
audio_token_index
image_token_id
=
thinker_config
.
image_token_index
image_token_id
=
thinker_config
.
image_token_index
video_token_id
=
thinker_config
.
video_token_index
video_token_id
=
thinker_config
.
video_token_index
...
@@ -1073,11 +1078,6 @@ class MRotaryEmbedding(RotaryEmbedding):
...
@@ -1073,11 +1078,6 @@ class MRotaryEmbedding(RotaryEmbedding):
tokens_per_second
=
getattr
(
thinker_config
.
vision_config
,
tokens_per_second
=
getattr
(
thinker_config
.
vision_config
,
"tokens_per_second"
,
25
)
"tokens_per_second"
,
25
)
if
isinstance
(
image_grid_thw
,
list
):
image_grid_thw
=
torch
.
tensor
(
image_grid_thw
)
if
isinstance
(
video_grid_thw
,
list
):
video_grid_thw
=
torch
.
tensor
(
video_grid_thw
)
src_item
=
input_tokens
src_item
=
input_tokens
audio_seqlens
=
audio_feature_lengths
audio_seqlens
=
audio_feature_lengths
if
not
second_per_grid_ts
:
if
not
second_per_grid_ts
:
...
@@ -1121,7 +1121,7 @@ class MRotaryEmbedding(RotaryEmbedding):
...
@@ -1121,7 +1121,7 @@ class MRotaryEmbedding(RotaryEmbedding):
grid_t
=
image_grid_thw
[
image_idx
][
0
]
grid_t
=
image_grid_thw
[
image_idx
][
0
]
grid_hs
=
image_grid_thw
[:,
1
]
grid_hs
=
image_grid_thw
[:,
1
]
grid_ws
=
image_grid_thw
[:,
2
]
grid_ws
=
image_grid_thw
[:,
2
]
t_index
=
(
torch
.
arange
(
grid_t
)
*
1
*
tokens_per_second
).
long
()
t_index
=
torch
.
arange
(
grid_t
)
*
1
*
tokens_per_second
llm_pos_ids
=
cls
.
_get_llm_pos_ids_for_vision
(
llm_pos_ids
=
cls
.
_get_llm_pos_ids_for_vision
(
start_idx
,
image_idx
,
spatial_merge_size
,
t_index
,
grid_hs
,
start_idx
,
image_idx
,
spatial_merge_size
,
t_index
,
grid_hs
,
grid_ws
)
grid_ws
)
...
@@ -1136,7 +1136,7 @@ class MRotaryEmbedding(RotaryEmbedding):
...
@@ -1136,7 +1136,7 @@ class MRotaryEmbedding(RotaryEmbedding):
grid_ws
=
video_grid_thw
[:,
2
]
grid_ws
=
video_grid_thw
[:,
2
]
t_index
=
(
torch
.
arange
(
grid_t
)
*
t_index
=
(
torch
.
arange
(
grid_t
)
*
second_per_grid_ts
[
video_idx
]
*
second_per_grid_ts
[
video_idx
]
*
tokens_per_second
)
.
long
()
tokens_per_second
)
llm_pos_ids
=
cls
.
_get_llm_pos_ids_for_vision
(
llm_pos_ids
=
cls
.
_get_llm_pos_ids_for_vision
(
start_idx
,
video_idx
,
spatial_merge_size
,
t_index
,
grid_hs
,
start_idx
,
video_idx
,
spatial_merge_size
,
t_index
,
grid_hs
,
grid_ws
)
grid_ws
)
...
@@ -1159,7 +1159,7 @@ class MRotaryEmbedding(RotaryEmbedding):
...
@@ -1159,7 +1159,7 @@ class MRotaryEmbedding(RotaryEmbedding):
t_ntoken_per_chunk
=
int
(
tokens_per_second
*
seconds_per_chunk
)
t_ntoken_per_chunk
=
int
(
tokens_per_second
*
seconds_per_chunk
)
t_index
=
(
torch
.
arange
(
grid_t
)
*
t_index
=
(
torch
.
arange
(
grid_t
)
*
second_per_grid_ts
[
video_idx
]
*
second_per_grid_ts
[
video_idx
]
*
tokens_per_second
)
.
long
()
tokens_per_second
)
t_index_split_chunk
=
cls
.
_split_list_into_ranges
(
t_index_split_chunk
=
cls
.
_split_list_into_ranges
(
t_index
,
t_ntoken_per_chunk
)
t_index
,
t_ntoken_per_chunk
)
place_num
=
(((
audio_seqlen
-
1
)
//
2
+
1
-
2
)
//
2
+
1
)
+
2
place_num
=
(((
audio_seqlen
-
1
)
//
2
+
1
-
2
)
//
2
+
1
)
+
2
...
@@ -1299,7 +1299,7 @@ class MRotaryEmbedding(RotaryEmbedding):
...
@@ -1299,7 +1299,7 @@ class MRotaryEmbedding(RotaryEmbedding):
grid_w
=
video_grid_thw
[
2
]
grid_w
=
video_grid_thw
[
2
]
t_ntoken_per_chunk
=
int
(
tokens_per_second
*
seconds_per_chunk
)
t_ntoken_per_chunk
=
int
(
tokens_per_second
*
seconds_per_chunk
)
t_index
=
(
torch
.
arange
(
grid_t
)
*
video_second_per_grid_t
*
t_index
=
(
torch
.
arange
(
grid_t
)
*
video_second_per_grid_t
*
tokens_per_second
)
.
long
()
tokens_per_second
)
t_index_split_chunk
=
cls
.
_split_list_into_ranges
(
t_index_split_chunk
=
cls
.
_split_list_into_ranges
(
t_index
,
t_ntoken_per_chunk
)
t_index
,
t_ntoken_per_chunk
)
...
...
vllm/model_executor/models/qwen3_omni_moe_thinker.py
0 → 100644
View file @
e9cb2241
This diff is collapsed.
Click to expand it.
vllm/model_executor/models/registry.py
View file @
e9cb2241
...
@@ -270,6 +270,7 @@ _MULTIMODAL_MODELS = {
...
@@ -270,6 +270,7 @@ _MULTIMODAL_MODELS = {
"Qwen2AudioForConditionalGeneration"
:
(
"qwen2_audio"
,
"Qwen2AudioForConditionalGeneration"
),
# noqa: E501
"Qwen2AudioForConditionalGeneration"
:
(
"qwen2_audio"
,
"Qwen2AudioForConditionalGeneration"
),
# noqa: E501
"Qwen2_5OmniModel"
:
(
"qwen2_5_omni_thinker"
,
"Qwen2_5OmniThinkerForConditionalGeneration"
),
# noqa: E501
"Qwen2_5OmniModel"
:
(
"qwen2_5_omni_thinker"
,
"Qwen2_5OmniThinkerForConditionalGeneration"
),
# noqa: E501
"Qwen2_5OmniForConditionalGeneration"
:
(
"qwen2_5_omni_thinker"
,
"Qwen2_5OmniThinkerForConditionalGeneration"
),
# noqa: E501
"Qwen2_5OmniForConditionalGeneration"
:
(
"qwen2_5_omni_thinker"
,
"Qwen2_5OmniThinkerForConditionalGeneration"
),
# noqa: E501
"Qwen3OmniMoeForConditionalGeneration"
:
(
"qwen3_omni_moe_thinker"
,
"Qwen3OmniMoeThinkerForConditionalGeneration"
),
"Qwen3VLForConditionalGeneration"
:
(
"qwen3_vl"
,
"Qwen3VLForConditionalGeneration"
),
# noqa: E501
"Qwen3VLForConditionalGeneration"
:
(
"qwen3_vl"
,
"Qwen3VLForConditionalGeneration"
),
# noqa: E501
"Qwen3VLMoeForConditionalGeneration"
:
(
"qwen3_vl_moe"
,
"Qwen3VLMoeForConditionalGeneration"
),
# noqa: E501
"Qwen3VLMoeForConditionalGeneration"
:
(
"qwen3_vl_moe"
,
"Qwen3VLMoeForConditionalGeneration"
),
# noqa: E501
"SkyworkR1VChatModel"
:
(
"skyworkr1v"
,
"SkyworkR1VChatModel"
),
"SkyworkR1VChatModel"
:
(
"skyworkr1v"
,
"SkyworkR1VChatModel"
),
...
...
vllm/model_executor/models/vision.py
View file @
e9cb2241
...
@@ -72,10 +72,18 @@ def get_vision_encoder_info(
...
@@ -72,10 +72,18 @@ def get_vision_encoder_info(
raise
NotImplementedError
(
msg
)
raise
NotImplementedError
(
msg
)
def
get_vit_attn_backend
(
head_size
:
int
,
dtype
:
torch
.
dtype
)
->
_Backend
:
def
get_vit_attn_backend
(
head_size
:
int
,
dtype
:
torch
.
dtype
,
*
,
attn_backend_override
:
_Backend
|
None
=
None
,
)
->
_Backend
:
"""
"""
Get the available attention backend for Vision Transformer.
Get the available attention backend for Vision Transformer.
"""
"""
if
attn_backend_override
is
not
None
:
return
attn_backend_override
# Lazy import to avoid circular dependency
# Lazy import to avoid circular dependency
from
vllm.attention.selector
import
get_env_variable_attn_backend
from
vllm.attention.selector
import
get_env_variable_attn_backend
...
@@ -402,3 +410,56 @@ def run_dp_sharded_mrope_vision_model(
...
@@ -402,3 +410,56 @@ def run_dp_sharded_mrope_vision_model(
assert
len
(
out_embeddings
)
==
len
(
assert
len
(
out_embeddings
)
==
len
(
original_order_embeddings
),
"Found unassigned embeddings"
original_order_embeddings
),
"Found unassigned embeddings"
return
out_embeddings
return
out_embeddings
def
get_llm_pos_ids_for_vision
(
start_idx
:
int
,
vision_idx
:
int
,
spatial_merge_size
:
int
,
t_index
:
list
[
int
],
grid_hs
:
torch
.
Tensor
,
grid_ws
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
llm_pos_ids_list
=
[]
llm_grid_h
=
grid_hs
[
vision_idx
]
//
spatial_merge_size
llm_grid_w
=
grid_ws
[
vision_idx
]
//
spatial_merge_size
h_index
=
(
torch
.
arange
(
llm_grid_h
)
.
view
(
1
,
-
1
,
1
)
.
expand
(
len
(
t_index
),
-
1
,
llm_grid_w
)
.
flatten
()
)
w_index
=
(
torch
.
arange
(
llm_grid_w
)
.
view
(
1
,
1
,
-
1
)
.
expand
(
len
(
t_index
),
llm_grid_h
,
-
1
)
.
flatten
()
)
t_index_tensor
=
(
torch
.
Tensor
(
t_index
)
.
to
(
llm_grid_h
.
device
)
.
view
(
-
1
,
1
)
.
expand
(
-
1
,
llm_grid_h
*
llm_grid_w
)
.
long
()
.
flatten
()
)
_llm_pos_ids
=
torch
.
stack
([
t_index_tensor
,
h_index
,
w_index
])
llm_pos_ids_list
.
append
(
_llm_pos_ids
+
start_idx
)
llm_pos_ids
=
torch
.
cat
(
llm_pos_ids_list
,
dim
=
1
)
return
llm_pos_ids
# Due to a performance regression with Conv3D in PyTorch2.9, we reshape
# Conv3D weights to Linear weights for better performance.
# See: https://github.com/vllm-project/vllm/issues/27406
# and https://github.com/pytorch/pytorch/issues/166122
# FIXME(Isotr0py): Revert the PR introduces this workaround
# (https://github.com/vllm-project/vllm/pull/27418),
# once the performance issue is resolved in PyTorch.
def
conv3d_to_linear_weight
(
conv3d_weight
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Reshape Conv3D weight to Linear weight. Only work when kernel_size==stride.
"""
out_channels
,
in_channels
,
kt
,
kh
,
kw
=
conv3d_weight
.
shape
linear_weight
=
conv3d_weight
.
reshape
(
out_channels
,
in_channels
*
kt
*
kh
*
kw
)
return
linear_weight
vllm/v1/worker/gpu_model_runner.py
View file @
e9cb2241
...
@@ -752,7 +752,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -752,7 +752,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if
mm_input
.
get
(
"use_audio_in_video"
)
is
True
:
if
mm_input
.
get
(
"use_audio_in_video"
)
is
True
:
use_audio_in_video
=
True
use_audio_in_video
=
True
if
supports_mrope
(
self
.
model
):
if
supports_mrope
(
self
.
get_
model
()
):
req_state
.
mrope_positions
,
req_state
.
mrope_position_delta
=
\
req_state
.
mrope_positions
,
req_state
.
mrope_position_delta
=
\
self
.
model
.
get_mrope_input_positions
(
self
.
model
.
get_mrope_input_positions
(
req_state
.
prompt_token_ids
,
req_state
.
prompt_token_ids
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment