"vscode:/vscode.git/clone" did not exist on "80433e225ee0f91a7f6a082bf4e136df75ab4746"
Unverified Commit 90f9c2eb, authored by Russell Bryant, committed by GitHub
Browse files

[V1] Change return type on get_multimodal_embeddings() (#19446)


Signed-off-by: Russell Bryant <rbryant@redhat.com>
parent 387bdf0a
...@@ -601,11 +601,11 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -601,11 +601,11 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.language_model return self.language_model
def get_multimodal_embeddings( def get_multimodal_embeddings(self,
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return []
multimodal_embeddings = self._process_image_input(image_input) multimodal_embeddings = self._process_image_input(image_input)
return multimodal_embeddings return multimodal_embeddings
......
...@@ -406,11 +406,11 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -406,11 +406,11 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.language_model return self.language_model
def get_multimodal_embeddings( def get_multimodal_embeddings(self,
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return []
return self._process_image_input(image_input, **kwargs) return self._process_image_input(image_input, **kwargs)
......
...@@ -627,11 +627,11 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -627,11 +627,11 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.language_model return self.language_model
def get_multimodal_embeddings( def get_multimodal_embeddings(self,
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return []
vision_embeddings = self._process_image_input(image_input) vision_embeddings = self._process_image_input(image_input)
return vision_embeddings return vision_embeddings
......
...@@ -987,11 +987,11 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -987,11 +987,11 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.model return self.model
def get_multimodal_embeddings( def get_multimodal_embeddings(self,
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return []
assert self.model.vqmodel is not None assert self.model.vqmodel is not None
image_tokens = self.model.get_image_tokens(image_input["data"].to( image_tokens = self.model.get_image_tokens(image_input["data"].to(
self.config.torch_dtype)) self.config.torch_dtype))
......
...@@ -586,11 +586,11 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -586,11 +586,11 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.language_model return self.language_model
def get_multimodal_embeddings( def get_multimodal_embeddings(self,
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return []
vision_embeddings = self._process_image_input(image_input) vision_embeddings = self._process_image_input(image_input)
return vision_embeddings return vision_embeddings
......
...@@ -1032,11 +1032,11 @@ class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1032,11 +1032,11 @@ class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal,
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.language_model return self.language_model
def get_multimodal_embeddings( def get_multimodal_embeddings(self,
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return []
vision_embeddings = self._process_image_input(image_input) vision_embeddings = self._process_image_input(image_input)
return vision_embeddings return vision_embeddings
......
...@@ -324,11 +324,11 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -324,11 +324,11 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.language_model return self.language_model
def get_multimodal_embeddings( def get_multimodal_embeddings(self,
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return []
return self._process_image_input(image_input) return self._process_image_input(image_input)
......
...@@ -568,11 +568,11 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -568,11 +568,11 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.language_model return self.language_model
def get_multimodal_embeddings( def get_multimodal_embeddings(self,
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return []
return self._process_image_input(image_input) return self._process_image_input(image_input)
......
...@@ -593,11 +593,11 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, ...@@ -593,11 +593,11 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.transformer return self.transformer
def get_multimodal_embeddings( def get_multimodal_embeddings(self,
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return []
vision_embeddings = self._process_image_input(image_input) vision_embeddings = self._process_image_input(image_input)
return vision_embeddings return vision_embeddings
......
...@@ -706,10 +706,11 @@ class GraniteSpeechForConditionalGeneration( ...@@ -706,10 +706,11 @@ class GraniteSpeechForConditionalGeneration(
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, self,
**kwargs: object, **kwargs: object,
) -> Optional[MultiModalEmbeddings]: ) -> MultiModalEmbeddings:
"""Compute the audio embeddings if audio inputs are present.""" """Compute the audio embeddings if audio inputs are present."""
audio_input = self._parse_and_validate_audio_input(**kwargs) audio_input = self._parse_and_validate_audio_input(**kwargs)
if audio_input is None: if audio_input is None:
return []
return None return None
audio_features = self._process_audio_input(audio_input) audio_features = self._process_audio_input(audio_input)
return audio_features return audio_features
......
...@@ -706,11 +706,11 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -706,11 +706,11 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.model return self.model
def get_multimodal_embeddings( def get_multimodal_embeddings(self,
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return []
return self._process_image_input(image_input) return self._process_image_input(image_input)
......
...@@ -44,8 +44,8 @@ class SupportsMultiModal(Protocol): ...@@ -44,8 +44,8 @@ class SupportsMultiModal(Protocol):
MRO of your model class. MRO of your model class.
""" """
def get_multimodal_embeddings( def get_multimodal_embeddings(self,
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: **kwargs: object) -> MultiModalEmbeddings:
""" """
Returns multimodal embeddings generated from multimodal kwargs Returns multimodal embeddings generated from multimodal kwargs
to be merged with text embeddings. to be merged with text embeddings.
......
...@@ -1304,11 +1304,12 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -1304,11 +1304,12 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.language_model return self.language_model
def get_multimodal_embeddings( def get_multimodal_embeddings(self,
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: **kwargs: object) -> MultiModalEmbeddings:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs) modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities: if not modalities:
return []
return None return None
# The result multimodal_embeddings is tuple of tensors, with each # The result multimodal_embeddings is tuple of tensors, with each
......
...@@ -659,11 +659,11 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -659,11 +659,11 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.language_model return self.language_model
def get_multimodal_embeddings( def get_multimodal_embeddings(self,
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return []
return self._process_image_input(image_input) return self._process_image_input(image_input)
......
...@@ -478,11 +478,11 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -478,11 +478,11 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.language_model return self.language_model
def get_multimodal_embeddings( def get_multimodal_embeddings(self,
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return []
vision_embeddings = self._process_image_input(image_input) vision_embeddings = self._process_image_input(image_input)
return vision_embeddings return vision_embeddings
...@@ -492,7 +492,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -492,7 +492,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if multimodal_embeddings is None: if not multimodal_embeddings:
return self.language_model.get_input_embeddings(input_ids) return self.language_model.get_input_embeddings(input_ids)
inputs_embeds = embed_multimodal( inputs_embeds = embed_multimodal(
......
...@@ -401,11 +401,11 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -401,11 +401,11 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.language_model return self.language_model
def get_multimodal_embeddings( def get_multimodal_embeddings(self,
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: **kwargs: object) -> MultiModalEmbeddings:
video_input = self._parse_and_validate_video_input(**kwargs) video_input = self._parse_and_validate_video_input(**kwargs)
if video_input is None: if video_input is None:
return None return []
vision_embeddings = self._process_video_pixels(video_input) vision_embeddings = self._process_video_pixels(video_input)
return vision_embeddings return vision_embeddings
......
...@@ -839,11 +839,12 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -839,11 +839,12 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.language_model return self.language_model
def get_multimodal_embeddings( def get_multimodal_embeddings(self,
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: **kwargs: object) -> MultiModalEmbeddings:
mm_input_by_modality = self._parse_and_validate_multimodal_inputs( mm_input_by_modality = self._parse_and_validate_multimodal_inputs(
**kwargs) **kwargs)
if not mm_input_by_modality: if not mm_input_by_modality:
return []
return None return None
# The result multimodal_embeddings is tuple of tensors, with each # The result multimodal_embeddings is tuple of tensors, with each
......
...@@ -878,11 +878,11 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -878,11 +878,11 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.llm return self.llm
def get_multimodal_embeddings( def get_multimodal_embeddings(self,
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: **kwargs: object) -> MultiModalEmbeddings:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs) modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities: if not modalities:
return None return []
return self._process_multimodal_inputs(modalities) return self._process_multimodal_inputs(modalities)
......
...@@ -318,11 +318,11 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -318,11 +318,11 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
raise AssertionError("This line should be unreachable.") raise AssertionError("This line should be unreachable.")
def get_multimodal_embeddings( def get_multimodal_embeddings(self,
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return []
return self._process_image_input(image_input) return self._process_image_input(image_input)
......
...@@ -495,11 +495,11 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA, ...@@ -495,11 +495,11 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsLoRA,
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.language_model return self.language_model
def get_multimodal_embeddings( def get_multimodal_embeddings(self,
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return []
vision_embeddings = self._process_image_input(image_input) vision_embeddings = self._process_image_input(image_input)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment