Unverified Commit f8516a1a authored by Yueqian Lin's avatar Yueqian Lin Committed by GitHub
Browse files

[Bugfix][Model] Fix audio-in-video support for Qwen2.5-Omni and Qwen3-Omni (#33605)


Signed-off-by: default avatarlinyueqian <linyueqian@outlook.com>
Signed-off-by: default avatarRoger Wang <hey@rogerw.io>
Co-authored-by: default avatarRoger Wang <hey@rogerw.io>
parent 82405807
......@@ -113,6 +113,95 @@ except (ImportError, ModuleNotFoundError):
logger = init_logger(__name__)
def check_interleaved_audio_video(
is_video: torch.Tensor,
is_audio: torch.Tensor,
num_video: int,
num_audio: int,
) -> bool:
"""
Check if video and audio positions are interleaved in the multimodal region.
Returns:
True if video and audio tokens are interleaved, False otherwise.
"""
if num_video == 0 or num_audio == 0:
return False
video_pos = is_video.nonzero(as_tuple=True)[0]
audio_pos = is_audio.nonzero(as_tuple=True)[0]
return (
video_pos[0].item() < audio_pos[-1].item()
and audio_pos[0].item() < video_pos[-1].item()
)
def merge_interleaved_embeddings(
inputs_embeds: torch.Tensor,
multimodal_embeddings: "MultiModalEmbeddings",
is_video: torch.Tensor,
is_audio: torch.Tensor,
is_multimodal: torch.Tensor,
num_video: int,
num_audio: int,
) -> torch.Tensor:
"""
Merge embeddings for interleaved audio-in-video sequences.
When use_audio_in_video=True, video and audio tokens are interleaved in
the token sequence, but embeddings are provided as separate contiguous
tensors (video first, then audio). This function reorders video and audio
embeddings to match sequence position order and scatters them efficiently.
Args:
inputs_embeds: The input embeddings tensor to merge into.
multimodal_embeddings: List of embedding tensors (video, audio, other).
is_video: Boolean mask for video token positions.
is_audio: Boolean mask for audio token positions.
is_multimodal: Boolean mask for all multimodal token positions.
num_video: Total count of video tokens.
num_audio: Total count of audio tokens.
Returns:
The merged inputs_embeds tensor with multimodal embeddings scattered
to their correct positions.
"""
# Categorize embeddings by modality based on token counts.
# Embeddings come grouped by modality but order varies (e.g., image, video, audio
# or video, audio depending on input kwargs order).
video_embeds: list[torch.Tensor] = []
audio_embeds: list[torch.Tensor] = []
other_embeds: list[torch.Tensor] = []
video_remaining = num_video
audio_remaining = num_audio
for emb in multimodal_embeddings:
n = emb.shape[0]
if video_remaining > 0 and n <= video_remaining:
video_embeds.append(emb)
video_remaining -= n
elif audio_remaining > 0 and n <= audio_remaining:
audio_embeds.append(emb)
audio_remaining -= n
else:
other_embeds.append(emb)
# Scatter each modality to its positions
if video_embeds:
video_positions = is_video.nonzero(as_tuple=True)[0]
inputs_embeds[video_positions] = torch.cat(video_embeds, dim=0)
if audio_embeds:
audio_positions = is_audio.nonzero(as_tuple=True)[0]
inputs_embeds[audio_positions] = torch.cat(audio_embeds, dim=0)
if other_embeds:
other_mask = is_multimodal & ~is_video & ~is_audio
other_positions = other_mask.nonzero(as_tuple=True)[0]
inputs_embeds[other_positions] = torch.cat(other_embeds, dim=0)
return inputs_embeds
class Qwen2_5OmniAudioFeatureInputs(TensorSchema):
"""
Dimensions:
......@@ -1286,17 +1375,48 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
# This is to satisfy the type checker for each overload
from .utils import _merge_multimodal_embeddings
if multimodal_embeddings is None or is_multimodal is None:
return super().embed_input_ids(input_ids)
return super().embed_input_ids(
inputs_embeds = self._embed_text_input_ids(
input_ids,
multimodal_embeddings=multimodal_embeddings,
self.get_language_model().embed_input_ids,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
if len(multimodal_embeddings) == 0:
return inputs_embeds
# Check for audio-in-video: interleaved video and audio tokens
# in the multimodal region.
video_token_id = self.config.video_token_index
audio_token_id = self.config.audio_token_index
is_video = is_multimodal & (input_ids == video_token_id)
is_audio = is_multimodal & (input_ids == audio_token_id)
num_video = is_video.sum().item()
num_audio = is_audio.sum().item()
if check_interleaved_audio_video(is_video, is_audio, num_video, num_audio):
return merge_interleaved_embeddings(
inputs_embeds,
multimodal_embeddings,
is_video,
is_audio,
is_multimodal,
num_video,
num_audio,
)
# Default: standard merge (no interleaving)
return _merge_multimodal_embeddings(
inputs_embeds, multimodal_embeddings, is_multimodal
)
def forward(
self,
input_ids: torch.Tensor | None,
......
......@@ -92,6 +92,8 @@ from .qwen2_5_omni_thinker import (
Qwen2_5OmniConditionalGenerationMixin,
Qwen2_5OmniThinkerDummyInputsBuilder,
Qwen2_5OmniThinkerMultiModalProcessor,
check_interleaved_audio_video,
merge_interleaved_embeddings,
)
from .qwen2_5_vl import (
Qwen2_5_VisionAttention,
......@@ -1780,6 +1782,19 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
return inputs_embeds
# Detect interleaved audio-in-video early, since it affects
# both the deepstack path and the final embedding merge.
video_token_id = self.config.video_token_id
audio_token_id = self.config.audio_token_id
is_video = is_multimodal & (input_ids == video_token_id)
is_audio = is_multimodal & (input_ids == audio_token_id)
num_video = is_video.sum().item()
num_audio = is_audio.sum().item()
is_interleaved = check_interleaved_audio_video(
is_video, is_audio, num_video, num_audio
)
deepstack_input_embeds = None
# split the feat dim to obtain multi-scale visual feature
has_vision_embeddings = [
......@@ -1791,14 +1806,18 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
):
multiscale_len = len(self.visual.deepstack_visual_indexes)
multimodal_embeddings_multiscale = []
if is_interleaved:
# Use input_ids-based mask for correct vision positions
# when audio and video tokens are interleaved.
is_vision = is_video.clone()
else:
is_vision = torch.zeros_like(is_multimodal)
mm_positions = torch.nonzero(is_multimodal, as_tuple=True)[0]
mm_position_idx = 0
for index, embeddings in enumerate(multimodal_embeddings):
num_tokens = embeddings.shape[0]
current_positions = mm_positions[
mm_position_idx : mm_position_idx + num_tokens
]
# Vision embeddings
if embeddings.shape[-1] != self.config.text_config.hidden_size:
......@@ -1809,12 +1828,21 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
)
multimodal_embeddings[index] = embeddings_main
multimodal_embeddings_multiscale.append(embeddings_multiscale)
if not is_interleaved:
current_positions = mm_positions[
mm_position_idx : mm_position_idx + num_tokens
]
is_vision[current_positions] = True
# Audio embeddings
else:
if not is_interleaved:
current_positions = mm_positions[
mm_position_idx : mm_position_idx + num_tokens
]
is_vision[current_positions] = False
if not is_interleaved:
mm_position_idx += num_tokens
deepstack_input_embeds = inputs_embeds.new_zeros(
......@@ -1834,6 +1862,18 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
)
self._set_deepstack_input_embeds(deepstack_input_embeds)
if is_interleaved:
return merge_interleaved_embeddings(
inputs_embeds,
multimodal_embeddings,
is_video,
is_audio,
is_multimodal,
num_video,
num_audio,
)
# Default: standard merge (no interleaving)
inputs_embeds = _merge_multimodal_embeddings(
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment