Unverified commit d55244df, authored by Nicolò Lucchesi, committed by GitHub
Browse files

[Model] Add `SupportsMultiModal.get_language_model` interface (#16007)


Signed-off-by: NickLucche <nlucches@redhat.com>
parent 04149cce
......@@ -742,6 +742,9 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
for img in vision_embeddings_flat.split(patches_per_image, dim=0)
]
def get_language_model(self) -> torch.nn.Module:
    """Return the module this class exposes as its language model.

    Returns:
        The object stored in ``self.language_model``, unchanged.
    """
    return self.language_model
def get_multimodal_embeddings(self,
**kwargs) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
......
......@@ -1488,6 +1488,9 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
)
]
def get_language_model(self) -> torch.nn.Module:
    """Return the module this class exposes as its language model.

    Returns:
        The object stored in ``self.model``, unchanged.
    """
    return self.model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
......
......@@ -323,6 +323,9 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
return self.multi_modal_projector(image_features)
def get_language_model(self) -> torch.nn.Module:
    """Return the module this class exposes as its language model.

    Returns:
        The object stored in ``self.language_model``, unchanged.
    """
    return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
......
......@@ -674,6 +674,9 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
return image_embeds
def get_language_model(self) -> torch.nn.Module:
    """Return the module this class exposes as its language model.

    Returns:
        The object stored in ``self.language_model``, unchanged.
    """
    return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
......
......@@ -1802,3 +1802,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal,
connector=["audio_projection_for_vision", "audio_projection"],
tower_model=["vision_encoder", "embed_tokens_extend"],
)
def get_language_model(self) -> torch.nn.Module:
    """Return the module this class exposes as its language model.

    Returns:
        The object stored in ``self.model``, unchanged.
    """
    return self.model
......@@ -396,6 +396,9 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
image_embeds = torch.split(image_embeds, feature_sizes)
return image_embeds
def get_language_model(self) -> torch.nn.Module:
    """Return the module this class exposes as its language model.

    Returns:
        The object stored in ``self.language_model``, unchanged.
    """
    return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
......
......@@ -967,6 +967,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
**kwargs)
return modalities
def get_language_model(self) -> torch.nn.Module:
    """Return the module this class exposes as its language model.

    Returns:
        The object stored in ``self.language_model``, unchanged.
    """
    return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
......
......@@ -355,6 +355,9 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
return torch.split(masked_audio_features,
audio_output_lengths.flatten().tolist())
def get_language_model(self) -> torch.nn.Module:
    """Return the module this class exposes as its language model.

    Returns:
        The object stored in ``self.language_model``, unchanged.
    """
    return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
audio_input = self._parse_and_validate_audio_input(**kwargs)
......
......@@ -1276,6 +1276,9 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
return modalities
def get_language_model(self) -> torch.nn.Module:
    """Return the module this class exposes as its language model.

    Returns:
        The object stored in ``self.language_model``, unchanged.
    """
    return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
......
......@@ -740,6 +740,9 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
return self.transformer.visual(image_input["data"])
def get_language_model(self) -> torch.nn.Module:
    """Return the module this class exposes as its language model.

    Returns:
        The object stored in ``self.transformer``, unchanged.
    """
    return self.transformer
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
......
......@@ -889,6 +889,9 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
else:
self.visual_token_mask = None
def get_language_model(self) -> torch.nn.Module:
    """Return the module this class exposes as its language model.

    Returns:
        The object stored in ``self.language_model``, unchanged.
    """
    return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
......
......@@ -563,6 +563,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
]
return flattened_embeddings.split(embed_lens)
def get_language_model(self) -> torch.nn.Module:
    """Return the module this class exposes as its language model.

    Returns:
        The object stored in ``self.language_model``, unchanged.
    """
    return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
audio_input = self._parse_and_validate_audio_input(**kwargs)
......
......@@ -692,6 +692,9 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
)
return decoder_outputs
def get_language_model(self) -> torch.nn.Module:
    """Return the module this class exposes as its language model.

    Returns:
        The ``decoder`` attribute of ``self.model``, unchanged.
    """
    return self.model.decoder
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
# TODO: This method does not obey the interface for SupportsMultiModal.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment