Unverified commit 14fdd21d, authored by Russell Bryant and committed by GitHub
Browse files

[Core] More fixes to MultiModalEmbeddings type handling (#19715)


Signed-off-by: Russell Bryant <rbryant@redhat.com>
parent 04fefe7c
...@@ -808,7 +808,8 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -808,7 +808,8 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None and len(
multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, input_ids,
inputs_embeds, inputs_embeds,
......
...@@ -1487,7 +1487,8 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, ...@@ -1487,7 +1487,8 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.model.get_input_embeddings(input_ids) inputs_embeds = self.model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
assert self.img_patch_id is not None assert self.img_patch_id is not None
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
......
...@@ -515,7 +515,8 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -515,7 +515,8 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP):
multimodal_embeddings: Optional[MultiModalEmbeddings] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.llm.get_input_embeddings(input_ids) inputs_embeds = self.llm.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings, input_ids, inputs_embeds, multimodal_embeddings,
self.image_pad_token_id) self.image_pad_token_id)
......
...@@ -364,7 +364,8 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -364,7 +364,8 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings, input_ids, inputs_embeds, multimodal_embeddings,
self.config.image_token_index) self.config.image_token_index)
......
...@@ -669,7 +669,8 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -669,7 +669,8 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.embed_tokens(input_ids) inputs_embeds = self.embed_tokens(input_ids)
if multimodal_embeddings: if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings, input_ids, inputs_embeds, multimodal_embeddings,
self.image_token_id) self.image_token_id)
......
...@@ -1148,7 +1148,8 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): ...@@ -1148,7 +1148,8 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
multimodal_embeddings: Optional[MultiModalEmbeddings] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.model.embed_tokens(input_ids) inputs_embeds = self.model.embed_tokens(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None and len(
multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings, input_ids, inputs_embeds, multimodal_embeddings,
[_IMAGE_PLACEHOLDER_TOKEN_ID, _AUDIO_PLACEHOLDER_TOKEN_ID]) [_IMAGE_PLACEHOLDER_TOKEN_ID, _AUDIO_PLACEHOLDER_TOKEN_ID])
......
...@@ -423,7 +423,8 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -423,7 +423,8 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, input_ids,
inputs_embeds, inputs_embeds,
......
...@@ -805,7 +805,8 @@ class Qwen2_5OmniThinkerForConditionalGeneration( ...@@ -805,7 +805,8 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
multimodal_embeddings: Optional[MultiModalEmbeddings] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
# TODO (ywang96): support overlapping modalitiy embeddings so that # TODO (ywang96): support overlapping modalitiy embeddings so that
# `use_audio_in_video` will work on V1. # `use_audio_in_video` will work on V1.
...@@ -845,7 +846,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration( ...@@ -845,7 +846,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
multimodal_embeddings: Optional[NestedTensors] = None, multimodal_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is None: if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
return inputs_embeds return inputs_embeds
for embeddings, modality in multimodal_embeddings: for embeddings, modality in multimodal_embeddings:
......
...@@ -1046,7 +1046,8 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1046,7 +1046,8 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings, input_ids, inputs_embeds, multimodal_embeddings,
[self.config.image_token_id, self.config.video_token_id]) [self.config.image_token_id, self.config.video_token_id])
......
...@@ -364,7 +364,8 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -364,7 +364,8 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings, input_ids, inputs_embeds, multimodal_embeddings,
self.config.audio_token_index) self.config.audio_token_index)
......
...@@ -1289,7 +1289,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1289,7 +1289,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings, input_ids, inputs_embeds, multimodal_embeddings,
[self.config.image_token_id, self.config.video_token_id]) [self.config.image_token_id, self.config.video_token_id])
......
...@@ -754,7 +754,8 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA, ...@@ -754,7 +754,8 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.transformer.get_input_embeddings(input_ids) inputs_embeds = self.transformer.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings, input_ids, inputs_embeds, multimodal_embeddings,
self.transformer.visual.image_pad_id) self.transformer.visual.image_pad_id)
......
...@@ -883,7 +883,8 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -883,7 +883,8 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
multimodal_embeddings: Optional[MultiModalEmbeddings] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
assert self.img_context_token_id is not None assert self.img_context_token_id is not None
self._set_visual_token_mask(input_ids) self._set_visual_token_mask(input_ids)
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
......
...@@ -598,7 +598,8 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -598,7 +598,8 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, input_ids,
inputs_embeds, inputs_embeds,
......
...@@ -560,7 +560,8 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): ...@@ -560,7 +560,8 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
multimodal_embeddings: Optional[MultiModalEmbeddings] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
# TODO(ywang96): remove this block after v0 is deprecated. # TODO(ywang96): remove this block after v0 is deprecated.
if not envs.VLLM_USE_V1: if not envs.VLLM_USE_V1:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment