Unverified Commit d2f816d6 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Bugfix] Standardize merging multimodal embeddings (#26771)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 577d4982
...@@ -1645,12 +1645,12 @@ class Ernie4_5_VLMoeForConditionalGeneration( ...@@ -1645,12 +1645,12 @@ class Ernie4_5_VLMoeForConditionalGeneration(
for modality in modalities: for modality in modalities:
if modality == "images": if modality == "images":
image_input = modalities["images"] image_input = modalities["images"]
vision_embeddings = self._process_image_input(image_input) image_embeddings = self._process_image_input(image_input)
multimodal_embeddings += vision_embeddings multimodal_embeddings += tuple(image_embeddings)
if modality == "videos": if modality == "videos":
video_input = modalities["videos"] video_input = modalities["videos"]
video_embeddings = self._process_video_input(video_input) video_embeddings = self._process_video_input(video_input)
multimodal_embeddings += video_embeddings multimodal_embeddings += tuple(video_embeddings)
return multimodal_embeddings return multimodal_embeddings
......
...@@ -1608,11 +1608,11 @@ class Glm4vForConditionalGeneration( ...@@ -1608,11 +1608,11 @@ class Glm4vForConditionalGeneration(
for modality in mm_input_by_modality: for modality in mm_input_by_modality:
multimodal_input = mm_input_by_modality[modality] multimodal_input = mm_input_by_modality[modality]
if modality == "image": if modality == "image":
vision_embeddings = self._process_image_input(multimodal_input) image_embeddings = self._process_image_input(multimodal_input)
multimodal_embeddings += vision_embeddings multimodal_embeddings += tuple(image_embeddings)
if modality == "video": if modality == "video":
video_embeddings = self._process_video_input(multimodal_input) video_embeddings = self._process_video_input(multimodal_input)
multimodal_embeddings += video_embeddings multimodal_embeddings += tuple(video_embeddings)
return multimodal_embeddings return multimodal_embeddings
def forward( def forward(
......
...@@ -749,12 +749,12 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -749,12 +749,12 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
for modality in modalities: for modality in modalities:
if modality == "images": if modality == "images":
image_input = modalities["images"] image_input = modalities["images"]
vision_embeddings = self._process_image_input(image_input) image_embeddings = self._process_image_input(image_input)
multimodal_embeddings += vision_embeddings multimodal_embeddings += tuple(image_embeddings)
if modality == "videos": if modality == "videos":
video_input = modalities["videos"] video_input = modalities["videos"]
video_embeddings = self._process_video_input(video_input) video_embeddings = self._process_video_input(video_input)
multimodal_embeddings += video_embeddings multimodal_embeddings += tuple(video_embeddings)
return multimodal_embeddings return multimodal_embeddings
......
...@@ -753,12 +753,12 @@ class InternS1ForConditionalGeneration( ...@@ -753,12 +753,12 @@ class InternS1ForConditionalGeneration(
for modality in modalities: for modality in modalities:
if modality == "images": if modality == "images":
image_input = modalities["images"] image_input = modalities["images"]
vision_embeddings = self._process_vision_input(image_input) image_embeddings = self._process_vision_input(image_input)
multimodal_embeddings += vision_embeddings multimodal_embeddings += tuple(image_embeddings)
if modality == "videos": if modality == "videos":
video_input = modalities["videos"] video_input = modalities["videos"]
video_embeddings = self._process_vision_input(video_input) video_embeddings = self._process_vision_input(video_input)
multimodal_embeddings += video_embeddings multimodal_embeddings += tuple(video_embeddings)
return multimodal_embeddings return multimodal_embeddings
......
...@@ -1358,12 +1358,12 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA) ...@@ -1358,12 +1358,12 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA)
for modality in modalities: for modality in modalities:
if modality == "images": if modality == "images":
image_input = modalities["images"] image_input = modalities["images"]
vision_embeddings = self._process_vision_input(image_input) image_embeddings = self._process_vision_input(image_input)
multimodal_embeddings += vision_embeddings multimodal_embeddings += tuple(image_embeddings)
if modality == "videos": if modality == "videos":
video_input = modalities["videos"] video_input = modalities["videos"]
video_embeddings = self._process_vision_input(video_input) video_embeddings = self._process_vision_input(video_input)
multimodal_embeddings += video_embeddings multimodal_embeddings += tuple(video_embeddings)
return multimodal_embeddings return multimodal_embeddings
......
...@@ -1459,12 +1459,12 @@ class BaseKeyeModule(nn.Module): ...@@ -1459,12 +1459,12 @@ class BaseKeyeModule(nn.Module):
for modality in modalities: for modality in modalities:
if modality == "images": if modality == "images":
image_input = modalities["images"] image_input = modalities["images"]
vision_embeddings = self._process_image_input(image_input) image_embeddings = self._process_image_input(image_input)
multimodal_embeddings += vision_embeddings multimodal_embeddings += tuple(image_embeddings)
if modality == "videos": if modality == "videos":
video_input = modalities["videos"] video_input = modalities["videos"]
video_embeddings = self._process_video_input(video_input) video_embeddings = self._process_video_input(video_input)
multimodal_embeddings += video_embeddings multimodal_embeddings += tuple(video_embeddings)
return multimodal_embeddings return multimodal_embeddings
def forward( def forward(
......
...@@ -881,8 +881,8 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, Supp ...@@ -881,8 +881,8 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
for modality in mm_input_by_modality: for modality in mm_input_by_modality:
multimodal_input = mm_input_by_modality[modality] multimodal_input = mm_input_by_modality[modality]
if modality == "image": if modality == "image":
vision_embeddings = self._process_image_input(multimodal_input) image_embeddings = self._process_image_input(multimodal_input)
multimodal_embeddings += tuple(vision_embeddings) multimodal_embeddings += tuple(image_embeddings)
if modality == "video": if modality == "video":
video_embeddings = self._process_video_pixels(multimodal_input) video_embeddings = self._process_video_pixels(multimodal_input)
multimodal_embeddings += tuple(video_embeddings) multimodal_embeddings += tuple(video_embeddings)
......
...@@ -762,7 +762,7 @@ class MiniCPMO(MiniCPMV2_6): ...@@ -762,7 +762,7 @@ class MiniCPMO(MiniCPMV2_6):
for modality in modalities: for modality in modalities:
if modality == "audios": if modality == "audios":
audio_input = modalities["audios"] audio_input = modalities["audios"]
audio_features = self._process_audio_input(audio_input) audio_embeddings = self._process_audio_input(audio_input)
multimodal_embeddings += tuple(audio_features) multimodal_embeddings += tuple(audio_embeddings)
return multimodal_embeddings return multimodal_embeddings
...@@ -1129,12 +1129,12 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -1129,12 +1129,12 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
for modality in modalities: for modality in modalities:
if modality == "images": if modality == "images":
image_input = modalities["images"] image_input = modalities["images"]
image_features = self._process_vision_input(image_input) image_embeddings = self._process_vision_input(image_input)
multimodal_embeddings += tuple(image_features) multimodal_embeddings += tuple(image_embeddings)
if modality == "videos": if modality == "videos":
video_input = modalities["videos"] video_input = modalities["videos"]
video_features = self._process_vision_input(video_input) video_embeddings = self._process_vision_input(video_input)
multimodal_embeddings += tuple(video_features) multimodal_embeddings += tuple(video_embeddings)
return multimodal_embeddings return multimodal_embeddings
......
...@@ -1263,12 +1263,12 @@ class NemotronH_Nano_VL_V2( ...@@ -1263,12 +1263,12 @@ class NemotronH_Nano_VL_V2(
for modality in modalities: for modality in modalities:
if modality == "images": if modality == "images":
image_input = modalities["images"] image_input = modalities["images"]
vision_embeddings = self._process_image_input(image_input) image_embeddings = self._process_image_input(image_input)
multimodal_embeddings += vision_embeddings multimodal_embeddings += tuple(image_embeddings)
if modality == "videos": if modality == "videos":
video_input = modalities["videos"] video_input = modalities["videos"]
video_embeddings = self._process_video_input(video_input) video_embeddings = self._process_video_input(video_input)
multimodal_embeddings += video_embeddings multimodal_embeddings += tuple(video_embeddings)
return multimodal_embeddings return multimodal_embeddings
......
...@@ -575,8 +575,8 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor ...@@ -575,8 +575,8 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
for modality in modalities: for modality in modalities:
if modality == "images": if modality == "images":
image_input = modalities["images"] image_input = modalities["images"]
vision_embeddings = self._process_image_input(image_input) image_embeddings = self._process_image_input(image_input)
multimodal_embeddings += vision_embeddings multimodal_embeddings += tuple(image_embeddings)
return multimodal_embeddings return multimodal_embeddings
......
...@@ -616,12 +616,12 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -616,12 +616,12 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
for modality in modalities: for modality in modalities:
if modality == "images": if modality == "images":
image_input = modalities["images"] image_input = modalities["images"]
vision_embeddings = self._process_visual_input(image_input) image_embeddings = self._process_visual_input(image_input)
multimodal_embeddings += vision_embeddings multimodal_embeddings += tuple(image_embeddings)
if modality == "videos": if modality == "videos":
video_input = modalities["videos"] video_input = modalities["videos"]
video_embeddings = self._process_visual_input(video_input) video_embeddings = self._process_visual_input(video_input)
multimodal_embeddings += video_embeddings multimodal_embeddings += tuple(video_embeddings)
return multimodal_embeddings return multimodal_embeddings
......
...@@ -1430,8 +1430,8 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): ...@@ -1430,8 +1430,8 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
if modality == "images": if modality == "images":
audio_projection_mode = "vision" audio_projection_mode = "vision"
image_input = modalities["images"] image_input = modalities["images"]
vision_embeddings = self._process_image_input(image_input) image_embeddings = self._process_image_input(image_input)
multimodal_embeddings += tuple(vision_embeddings) multimodal_embeddings += tuple(image_embeddings)
if modality == "audios": if modality == "audios":
audio_input = modalities["audios"] audio_input = modalities["audios"]
audio_embeddings = self._process_audio_input( audio_embeddings = self._process_audio_input(
......
...@@ -1248,8 +1248,8 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): ...@@ -1248,8 +1248,8 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
if modality == "images": if modality == "images":
audio_projection_mode = "vision" audio_projection_mode = "vision"
image_input = modalities["images"] image_input = modalities["images"]
vision_embeddings = self._process_image_input(image_input) image_embeddings = self._process_image_input(image_input)
multimodal_embeddings += tuple(vision_embeddings) multimodal_embeddings += tuple(image_embeddings)
if modality == "audios": if modality == "audios":
audio_input = modalities["audios"] audio_input = modalities["audios"]
audio_embeddings = self._process_audio_input( audio_embeddings = self._process_audio_input(
......
...@@ -1210,14 +1210,14 @@ class Qwen2_5OmniThinkerForConditionalGeneration( ...@@ -1210,14 +1210,14 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
for modality in mm_input_by_modality: for modality in mm_input_by_modality:
multimodal_input = mm_input_by_modality[modality] multimodal_input = mm_input_by_modality[modality]
if modality == "image": if modality == "image":
vision_embeddings = self._process_image_input(multimodal_input) image_embeddings = self._process_image_input(multimodal_input)
multimodal_embeddings += vision_embeddings multimodal_embeddings += tuple(image_embeddings)
if modality == "video": if modality == "video":
video_embeddings = self._process_video_input(multimodal_input) video_embeddings = self._process_video_input(multimodal_input)
multimodal_embeddings += video_embeddings multimodal_embeddings += tuple(video_embeddings)
if modality == "audio": if modality == "audio":
audio_embeddings = self._process_audio_input(multimodal_input) audio_embeddings = self._process_audio_input(multimodal_input)
multimodal_embeddings += audio_embeddings multimodal_embeddings += tuple(audio_embeddings)
return multimodal_embeddings return multimodal_embeddings
# TODO (ywang96): support overlapping modality embeddings so that # TODO (ywang96): support overlapping modality embeddings so that
......
...@@ -1586,19 +1586,19 @@ class Qwen2_5_VLForConditionalGeneration( ...@@ -1586,19 +1586,19 @@ class Qwen2_5_VLForConditionalGeneration(
for modality in mm_input_by_modality: for modality in mm_input_by_modality:
multimodal_input = mm_input_by_modality[modality] multimodal_input = mm_input_by_modality[modality]
if modality == "image": if modality == "image":
vision_embeddings = self._process_image_input(multimodal_input) image_embeddings = self._process_image_input(multimodal_input)
if self.is_multimodal_pruning_enabled: if self.is_multimodal_pruning_enabled:
vision_embeddings = self._postprocess_image_embeds_evs( image_embeddings = self._postprocess_image_embeds_evs(
vision_embeddings, multimodal_input image_embeddings, multimodal_input
) )
multimodal_embeddings += vision_embeddings multimodal_embeddings += tuple(image_embeddings)
if modality == "video": if modality == "video":
video_embeddings = self._process_video_input(multimodal_input) video_embeddings = self._process_video_input(multimodal_input)
if self.is_multimodal_pruning_enabled: if self.is_multimodal_pruning_enabled:
video_embeddings = self._postprocess_video_embeds_evs( video_embeddings = self._postprocess_video_embeds_evs(
video_embeddings, multimodal_input video_embeddings, multimodal_input
) )
multimodal_embeddings += video_embeddings multimodal_embeddings += tuple(video_embeddings)
return multimodal_embeddings return multimodal_embeddings
def forward( def forward(
......
...@@ -1561,12 +1561,12 @@ class Qwen2VLForConditionalGeneration( ...@@ -1561,12 +1561,12 @@ class Qwen2VLForConditionalGeneration(
for modality in modalities: for modality in modalities:
if modality == "images": if modality == "images":
image_input = modalities["images"] image_input = modalities["images"]
vision_embeddings = self._process_image_input(image_input) image_embeddings = self._process_image_input(image_input)
multimodal_embeddings += vision_embeddings multimodal_embeddings += tuple(image_embeddings)
if modality == "videos": if modality == "videos":
video_input = modalities["videos"] video_input = modalities["videos"]
video_embeddings = self._process_video_input(video_input) video_embeddings = self._process_video_input(video_input)
multimodal_embeddings += video_embeddings multimodal_embeddings += tuple(video_embeddings)
return multimodal_embeddings return multimodal_embeddings
......
...@@ -1260,14 +1260,14 @@ class Qwen3OmniMoeThinkerForConditionalGeneration( ...@@ -1260,14 +1260,14 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
for modality in mm_input_by_modality: for modality in mm_input_by_modality:
multimodal_input = mm_input_by_modality[modality] multimodal_input = mm_input_by_modality[modality]
if modality == "image": if modality == "image":
vision_embeddings = self._process_image_input(multimodal_input) image_embeddings = self._process_image_input(multimodal_input)
multimodal_embeddings += vision_embeddings multimodal_embeddings += tuple(image_embeddings)
if modality == "video": if modality == "video":
video_embeddings = self._process_video_input(multimodal_input) video_embeddings = self._process_video_input(multimodal_input)
multimodal_embeddings += video_embeddings multimodal_embeddings += tuple(video_embeddings)
if modality == "audio": if modality == "audio":
audio_embeddings = self._process_audio_input(multimodal_input) audio_embeddings = self._process_audio_input(multimodal_input)
multimodal_embeddings += audio_embeddings multimodal_embeddings += tuple(audio_embeddings)
return multimodal_embeddings return multimodal_embeddings
def get_input_embeddings( def get_input_embeddings(
......
...@@ -1601,11 +1601,11 @@ class Qwen3VLForConditionalGeneration( ...@@ -1601,11 +1601,11 @@ class Qwen3VLForConditionalGeneration(
for modality in mm_input_by_modality: for modality in mm_input_by_modality:
multimodal_input = mm_input_by_modality[modality] multimodal_input = mm_input_by_modality[modality]
if modality == "image": if modality == "image":
vision_embeddings = self._process_image_input(multimodal_input) image_embeddings = self._process_image_input(multimodal_input)
multimodal_embeddings += vision_embeddings multimodal_embeddings += tuple(image_embeddings)
if modality == "video": if modality == "video":
video_embeddings = self._process_video_input(multimodal_input) video_embeddings = self._process_video_input(multimodal_input)
multimodal_embeddings += video_embeddings multimodal_embeddings += tuple(video_embeddings)
return multimodal_embeddings return multimodal_embeddings
def _compute_deepstack_embeds( def _compute_deepstack_embeds(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment