"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "d45cbe70f5bf25bb2f490f4152c256e9acb2a62b"
Unverified commit 853a8eb5, authored by Cyrus Leung and committed by GitHub
Browse files

[Bugfix] Fix Qwen Omni audio inference (#27920)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 758ea2e9
...@@ -130,6 +130,8 @@ class Qwen2_5OmniAudioFeatureInputs(TensorSchema): ...@@ -130,6 +130,8 @@ class Qwen2_5OmniAudioFeatureInputs(TensorSchema):
TensorShape("nmb", "tsl", dynamic_dims={"tsl"}), TensorShape("nmb", "tsl", dynamic_dims={"tsl"}),
] ]
audio_feature_lengths: Annotated[torch.Tensor, TensorShape("na")]
feature_attention_mask: Annotated[ feature_attention_mask: Annotated[
torch.Tensor | list[torch.Tensor], torch.Tensor | list[torch.Tensor],
TensorShape("na", "msl", dynamic_dims={"msl"}), TensorShape("na", "msl", dynamic_dims={"msl"}),
...@@ -732,13 +734,6 @@ class Qwen2_5OmniConditionalGenerationMixin: ...@@ -732,13 +734,6 @@ class Qwen2_5OmniConditionalGenerationMixin:
input_features = audio_input["input_features"] input_features = audio_input["input_features"]
audio_feature_lengths = audio_input["audio_feature_lengths"] audio_feature_lengths = audio_input["audio_feature_lengths"]
if audio_feature_lengths.shape[0] == 1:
audio_feature_lengths = audio_feature_lengths.squeeze(0)
elif audio_feature_lengths.shape[1] == 1:
audio_feature_lengths = audio_feature_lengths.squeeze(1)
else:
raise AssertionError(audio_feature_lengths.shape)
audio_feat_lengths, audio_output_lengths = ( audio_feat_lengths, audio_output_lengths = (
self.audio_tower._get_feat_extract_output_lengths(audio_feature_lengths) self.audio_tower._get_feat_extract_output_lengths(audio_feature_lengths)
) )
......
...@@ -99,7 +99,6 @@ from .utils import ( ...@@ -99,7 +99,6 @@ from .utils import (
AutoWeightsLoader, AutoWeightsLoader,
WeightsMapper, WeightsMapper,
_merge_multimodal_embeddings, _merge_multimodal_embeddings,
flatten_bn,
maybe_prefix, maybe_prefix,
) )
from .vision import ( from .vision import (
...@@ -1065,8 +1064,6 @@ class Qwen3OmniMoeConditionalGenerationMixin(Qwen2_5OmniConditionalGenerationMix ...@@ -1065,8 +1064,6 @@ class Qwen3OmniMoeConditionalGenerationMixin(Qwen2_5OmniConditionalGenerationMix
input_features = audio_input["input_features"] input_features = audio_input["input_features"]
audio_feature_lengths = audio_input["audio_feature_lengths"] audio_feature_lengths = audio_input["audio_feature_lengths"]
audio_feature_lengths = flatten_bn(audio_feature_lengths, concat=True)
audio_feat_lengths, audio_output_lengths = _get_feat_extract_output_lengths( audio_feat_lengths, audio_output_lengths = _get_feat_extract_output_lengths(
audio_feature_lengths audio_feature_lengths
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment