Unverified Commit 14fdd21d authored by Russell Bryant's avatar Russell Bryant Committed by GitHub
Browse files

[Core] More fixes to MultiModalEmbeddings type handling (#19715)


Signed-off-by: default avatarRussell Bryant <rbryant@redhat.com>
parent 04fefe7c
......@@ -808,7 +808,8 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
if multimodal_embeddings is not None and len(
multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
......
......@@ -1487,7 +1487,8 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
assert self.img_patch_id is not None
inputs_embeds = merge_multimodal_embeddings(
......
......@@ -515,7 +515,8 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP):
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.llm.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.image_pad_token_id)
......
......@@ -364,7 +364,8 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.config.image_token_index)
......
......@@ -669,7 +669,8 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.embed_tokens(input_ids)
if multimodal_embeddings:
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.image_token_id)
......
......@@ -1148,7 +1148,8 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.model.embed_tokens(input_ids)
if multimodal_embeddings is not None:
if multimodal_embeddings is not None and len(
multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
[_IMAGE_PLACEHOLDER_TOKEN_ID, _AUDIO_PLACEHOLDER_TOKEN_ID])
......
......@@ -423,7 +423,8 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
......
......@@ -805,7 +805,8 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
# TODO (ywang96): support overlapping modalitiy embeddings so that
# `use_audio_in_video` will work on V1.
......@@ -845,7 +846,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
multimodal_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is None:
if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
return inputs_embeds
for embeddings, modality in multimodal_embeddings:
......
......@@ -1046,7 +1046,8 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
[self.config.image_token_id, self.config.video_token_id])
......
......@@ -364,7 +364,8 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.config.audio_token_index)
......
......@@ -1289,7 +1289,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
[self.config.image_token_id, self.config.video_token_id])
......
......@@ -754,7 +754,8 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
) -> torch.Tensor:
inputs_embeds = self.transformer.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.transformer.visual.image_pad_id)
......
......@@ -883,7 +883,8 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
assert self.img_context_token_id is not None
self._set_visual_token_mask(input_ids)
inputs_embeds = merge_multimodal_embeddings(
......
......@@ -598,7 +598,8 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
......
......@@ -560,7 +560,8 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
# TODO(ywang96): remove this block after v0 is deprecated.
if not envs.VLLM_USE_V1:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment