Unverified Commit 14fdd21d authored by Russell Bryant's avatar Russell Bryant Committed by GitHub
Browse files

[Core] More fixes to MultiModalEmbeddings type handling (#19715)


Signed-off-by: default avatarRussell Bryant <rbryant@redhat.com>
parent 04fefe7c
......@@ -620,7 +620,8 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.config.image_token_index)
......
......@@ -430,7 +430,8 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
......
......@@ -641,7 +641,8 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
_IMAGE_TOKEN_ID)
......
......@@ -1005,7 +1005,8 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
) -> torch.Tensor:
inputs_embeds = self.model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.model.vocabulary_mapping.image_token_id)
......
......@@ -600,7 +600,8 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.image_token_id)
......
......@@ -1046,7 +1046,8 @@ class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.pad_token_id)
......
......@@ -345,7 +345,8 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
......
......@@ -592,7 +592,8 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
......
......@@ -609,7 +609,8 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
) -> torch.Tensor:
inputs_embeds = self.transformer.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
......
......@@ -721,7 +721,8 @@ class GraniteSpeechForConditionalGeneration(
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
"""Compute the merged LLM / audio embeddings."""
if multimodal_embeddings is None:
if multimodal_embeddings is None \
or len(multimodal_embeddings) == 0:
return self.language_model.get_input_embeddings(input_ids)
inputs_embeds = embed_multimodal(
......
......@@ -720,7 +720,8 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
......
......@@ -1336,7 +1336,8 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
context_token_ids = [
token_id for token_id in (self.img_context_token_id,
self.video_context_token_id)
......
......@@ -393,7 +393,8 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal):
# model as one of the requirements of basic vLLM model implementation.
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
if multimodal_embeddings is not None and len(
multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
......
......@@ -683,7 +683,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
......
......@@ -502,7 +502,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
if not multimodal_embeddings:
if multimodal_embeddings is None \
or len(multimodal_embeddings) == 0:
return self.language_model.get_input_embeddings(input_ids)
inputs_embeds = embed_multimodal(
......
......@@ -426,7 +426,8 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.config.video_token_index)
......
......@@ -881,7 +881,8 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
[self.config.image_token_index, self.config.video_token_index])
......
......@@ -892,7 +892,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.llm.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
assert len(self.mm_token_ids) > 0
inputs_embeds = merge_multimodal_embeddings(
input_ids,
......
......@@ -201,7 +201,8 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
......
......@@ -521,7 +521,8 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
if multimodal_embeddings is not None \
and len(multimodal_embeddings) != 0:
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment