Unverified commit 14fdd21d, authored by Russell Bryant, committed by GitHub
Browse files

[Core] More fixes to MultiModalEmbeddings type handling (#19715)


Signed-off-by: Russell Bryant <rbryant@redhat.com>
parent 04fefe7c
...@@ -620,7 +620,8 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -620,7 +620,8 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
multimodal_embeddings: Optional[MultiModalEmbeddings] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings, input_ids, inputs_embeds, multimodal_embeddings,
self.config.image_token_index) self.config.image_token_index)
......
...@@ -430,7 +430,8 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -430,7 +430,8 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids=input_ids, input_ids=input_ids,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
......
...@@ -641,7 +641,8 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -641,7 +641,8 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings, input_ids, inputs_embeds, multimodal_embeddings,
_IMAGE_TOKEN_ID) _IMAGE_TOKEN_ID)
......
...@@ -1005,7 +1005,8 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1005,7 +1005,8 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.model.get_input_embeddings(input_ids) inputs_embeds = self.model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings, input_ids, inputs_embeds, multimodal_embeddings,
self.model.vocabulary_mapping.image_token_id) self.model.vocabulary_mapping.image_token_id)
......
...@@ -600,7 +600,8 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -600,7 +600,8 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
multimodal_embeddings: Optional[MultiModalEmbeddings] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings, input_ids, inputs_embeds, multimodal_embeddings,
self.image_token_id) self.image_token_id)
......
...@@ -1046,7 +1046,8 @@ class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1046,7 +1046,8 @@ class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings, input_ids, inputs_embeds, multimodal_embeddings,
self.pad_token_id) self.pad_token_id)
......
...@@ -345,7 +345,8 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -345,7 +345,8 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
multimodal_embeddings: Optional[MultiModalEmbeddings] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, input_ids,
inputs_embeds, inputs_embeds,
......
...@@ -592,7 +592,8 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -592,7 +592,8 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, input_ids,
inputs_embeds, inputs_embeds,
......
...@@ -609,7 +609,8 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, ...@@ -609,7 +609,8 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.transformer.get_input_embeddings(input_ids) inputs_embeds = self.transformer.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids=input_ids, input_ids=input_ids,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
......
...@@ -721,7 +721,8 @@ class GraniteSpeechForConditionalGeneration( ...@@ -721,7 +721,8 @@ class GraniteSpeechForConditionalGeneration(
multimodal_embeddings: Optional[MultiModalEmbeddings] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Compute the merged LLM / audio embeddings.""" """Compute the merged LLM / audio embeddings."""
if multimodal_embeddings is None: if multimodal_embeddings is None \
or len(multimodal_embeddings) == 0:
return self.language_model.get_input_embeddings(input_ids) return self.language_model.get_input_embeddings(input_ids)
inputs_embeds = embed_multimodal( inputs_embeds = embed_multimodal(
......
...@@ -720,7 +720,8 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -720,7 +720,8 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.model.get_input_embeddings(input_ids) inputs_embeds = self.model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, input_ids,
inputs_embeds, inputs_embeds,
......
...@@ -1336,7 +1336,8 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -1336,7 +1336,8 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
context_token_ids = [ context_token_ids = [
token_id for token_id in (self.img_context_token_id, token_id for token_id in (self.img_context_token_id,
self.video_context_token_id) self.video_context_token_id)
......
...@@ -393,7 +393,8 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -393,7 +393,8 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal):
# model as one of the requirements of basic vLLM model implementation. # model as one of the requirements of basic vLLM model implementation.
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None and len(
multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids=input_ids, input_ids=input_ids,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
......
...@@ -683,7 +683,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -683,7 +683,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
multimodal_embeddings: Optional[MultiModalEmbeddings] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, input_ids,
inputs_embeds, inputs_embeds,
......
...@@ -502,7 +502,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -502,7 +502,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if not multimodal_embeddings: if multimodal_embeddings is None \
or len(multimodal_embeddings) == 0:
return self.language_model.get_input_embeddings(input_ids) return self.language_model.get_input_embeddings(input_ids)
inputs_embeds = embed_multimodal( inputs_embeds = embed_multimodal(
......
...@@ -426,7 +426,8 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -426,7 +426,8 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings, input_ids, inputs_embeds, multimodal_embeddings,
self.config.video_token_index) self.config.video_token_index)
......
...@@ -881,7 +881,8 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -881,7 +881,8 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings, input_ids, inputs_embeds, multimodal_embeddings,
[self.config.image_token_index, self.config.video_token_index]) [self.config.image_token_index, self.config.video_token_index])
......
...@@ -892,7 +892,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -892,7 +892,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
multimodal_embeddings: Optional[MultiModalEmbeddings] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.llm.get_input_embeddings(input_ids) inputs_embeds = self.llm.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
assert len(self.mm_token_ids) > 0 assert len(self.mm_token_ids) > 0
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, input_ids,
......
...@@ -201,7 +201,8 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -201,7 +201,8 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, input_ids,
inputs_embeds, inputs_embeds,
......
...@@ -521,7 +521,8 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA, ...@@ -521,7 +521,8 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, input_ids,
inputs_embeds, inputs_embeds,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment