Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8c851f6d
Unverified
Commit
8c851f6d
authored
Oct 15, 2025
by
Isotr0py
Committed by
GitHub
Oct 15, 2025
Browse files
[Bugfix] Fix qwen3-omni audio truncation issue (#26815)
Signed-off-by:
Isotr0py
<
mozf@mail2.sysu.edu.cn
>
parent
7cfa420f
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
16 additions
and
2 deletions
+16
-2
vllm/model_executor/models/qwen3_omni_moe_thinker.py
vllm/model_executor/models/qwen3_omni_moe_thinker.py
+16
-2
No files found.
vllm/model_executor/models/qwen3_omni_moe_thinker.py
View file @
8c851f6d
...
@@ -30,7 +30,9 @@ import numpy as np
...
@@ -30,7 +30,9 @@ import numpy as np
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
packaging.version
import
Version
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
from
transformers
import
__version__
as
TRANSFORMERS_VERSION
from
transformers.feature_extraction_utils
import
BatchFeature
from
transformers.feature_extraction_utils
import
BatchFeature
from
transformers.models.qwen3_omni_moe.configuration_qwen3_omni_moe
import
(
from
transformers.models.qwen3_omni_moe.configuration_qwen3_omni_moe
import
(
Qwen3OmniMoeConfig
,
Qwen3OmniMoeConfig
,
...
@@ -711,11 +713,12 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
...
@@ -711,11 +713,12 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
return
x
return
x
# NOTE: WhisperFeatureExtractor cannot handle empty list of audios
# NOTE: WhisperFeatureExtractor cannot handle empty list of audios
feature_extractor
=
self
.
info
.
get_feature_extractor
()
hop_length
=
feature_extractor
.
hop_length
if
audios
:
if
audios
:
# NOTE: Qwen3-Omni processor accept "audio"
# NOTE: Qwen3-Omni processor accept "audio"
# To make sure the cache works with padding=True, we pre-padded
# To make sure the cache works with padding=True, we pre-padded
# the audio to multiple of hop_length.
# the audio to multiple of hop_length.
hop_length
=
self
.
info
.
get_feature_extractor
().
hop_length
mm_data
[
"audio"
]
=
[
mm_data
[
"audio"
]
=
[
pad_to_hop_length
(
audio
,
hop_length
)
pad_to_hop_length
(
audio
,
hop_length
)
if
isinstance
(
audio
,
np
.
ndarray
)
if
isinstance
(
audio
,
np
.
ndarray
)
...
@@ -725,6 +728,14 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
...
@@ -725,6 +728,14 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
mm_kwargs
=
dict
(
mm_kwargs
=
dict
(
**
mm_kwargs
,
**
mm_kwargs
,
)
)
# TODO(Isotr0py): Remove this patch after upstream fix PR
# released and Transformers version update:
# https://github.com/huggingface/transformers/pull/41473
if
(
Version
(
TRANSFORMERS_VERSION
)
<
Version
(
"4.58.0"
)
and
"truncation"
not
in
mm_kwargs
):
mm_kwargs
[
"truncation"
]
=
False
hf_inputs
=
super
().
_call_hf_processor
(
hf_inputs
=
super
().
_call_hf_processor
(
prompt
=
prompt
,
prompt
=
prompt
,
...
@@ -738,7 +749,6 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
...
@@ -738,7 +749,6 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
and
"feature_attention_mask"
in
hf_inputs
and
"feature_attention_mask"
in
hf_inputs
and
(
audios
:
=
mm_data
.
get
(
"audio"
,
[]))
and
(
audios
:
=
mm_data
.
get
(
"audio"
,
[]))
):
):
hop_length
=
self
.
info
.
get_feature_extractor
().
hop_length
audio_num_frames
=
[]
audio_num_frames
=
[]
for
_
,
audio
in
enumerate
(
audios
):
for
_
,
audio
in
enumerate
(
audios
):
audio_length
=
len
(
audio
[
0
])
if
isinstance
(
audio
,
tuple
)
else
len
(
audio
)
audio_length
=
len
(
audio
[
0
])
if
isinstance
(
audio
,
tuple
)
else
len
(
audio
)
...
@@ -747,6 +757,10 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
...
@@ -747,6 +757,10 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
if
audio_length
%
hop_length
==
0
if
audio_length
%
hop_length
==
0
else
(
audio_length
//
hop_length
-
1
)
else
(
audio_length
//
hop_length
-
1
)
)
)
if
mm_kwargs
.
get
(
"truncation"
,
False
):
num_frame
=
min
(
num_frame
,
feature_extractor
.
n_samples
//
hop_length
)
audio_num_frames
.
append
(
num_frame
)
audio_num_frames
.
append
(
num_frame
)
hf_inputs
[
"feature_attention_mask"
]
=
[
hf_inputs
[
"feature_attention_mask"
]
=
[
torch
.
ones
(
num_frame
)
for
num_frame
in
audio_num_frames
torch
.
ones
(
num_frame
)
for
num_frame
in
audio_num_frames
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment