Unverified Commit d55244df authored by Nicolò Lucchesi's avatar Nicolò Lucchesi Committed by GitHub
Browse files

[Model] Add `SupportsMultiModal.get_language_model` interface (#16007)


Signed-off-by: default avatarNickLucche <nlucches@redhat.com>
parent 04149cce
...@@ -79,6 +79,17 @@ Further update the model as follows: ...@@ -79,6 +79,17 @@ Further update the model as follows:
return inputs_embeds return inputs_embeds
``` ```
- Implement {meth}`~vllm.model_executor.models.interfaces.SupportsMultiModal.get_language_model` getter to provide stable access to the underlying language model.
```python
class YourModelForImage2Seq(nn.Module):
...
def get_language_model(self) -> torch.nn.Module:
# Change `language_model` according to your implementation.
return self.language_model
```
- Once the above steps are done, update the model class with the {class}`~vllm.model_executor.models.interfaces.SupportsMultiModal` interface. - Once the above steps are done, update the model class with the {class}`~vllm.model_executor.models.interfaces.SupportsMultiModal` interface.
```diff ```diff
......
...@@ -605,6 +605,9 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -605,6 +605,9 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
return self.multi_modal_projector(image_outputs, image_attn_mask) return self.multi_modal_projector(image_outputs, image_attn_mask)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
......
...@@ -424,6 +424,9 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -424,6 +424,9 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
num_patches=num_patches, num_patches=num_patches,
) )
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
......
...@@ -627,6 +627,9 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -627,6 +627,9 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
return self.language_projection(query_output) return self.language_projection(query_output)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
......
...@@ -988,6 +988,9 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -988,6 +988,9 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
data=self._validate_pixel_values(pixel_values), data=self._validate_pixel_values(pixel_values),
) )
def get_language_model(self) -> torch.nn.Module:
return self.model
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
......
...@@ -604,6 +604,9 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -604,6 +604,9 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return self._pixel_values_to_embedding( return self._pixel_values_to_embedding(
pixel_values=pixel_values, images_spatial_crop=images_spatial_crop) pixel_values=pixel_values, images_spatial_crop=images_spatial_crop)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
......
...@@ -1050,6 +1050,9 @@ class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1050,6 +1050,9 @@ class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal,
pixel_values = image_input["data"] pixel_values = image_input["data"]
return self._encode_image(pixel_values) return self._encode_image(pixel_values)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
......
...@@ -341,6 +341,9 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -341,6 +341,9 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return vision_embeddings_flat.split(patches_per_image, dim=0) return vision_embeddings_flat.split(patches_per_image, dim=0)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
......
...@@ -591,6 +591,9 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -591,6 +591,9 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
e.flatten(0, 1) for e in image_embeds.split(num_patches.tolist()) e.flatten(0, 1) for e in image_embeds.split(num_patches.tolist())
] ]
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
......
...@@ -596,6 +596,9 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, ...@@ -596,6 +596,9 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
return self.transformer.vision(pixel_values) return self.transformer.vision(pixel_values)
def get_language_model(self) -> torch.nn.Module:
return self.transformer
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
......
...@@ -710,6 +710,9 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -710,6 +710,9 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
e.flatten(0, 1) for e in image_features.split(num_patches.tolist()) e.flatten(0, 1) for e in image_features.split(num_patches.tolist())
] ]
def get_language_model(self) -> torch.nn.Module:
return self.model
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
......
...@@ -56,6 +56,18 @@ class SupportsMultiModal(Protocol): ...@@ -56,6 +56,18 @@ class SupportsMultiModal(Protocol):
""" """
... ...
def get_language_model(self) -> torch.nn.Module:
"""
Returns the underlying language model used for text generation.
This is typically the `torch.nn.Module` instance responsible for
processing the merged multimodal embeddings and producing hidden states
Returns:
torch.nn.Module: The core language model component.
"""
...
# Only for models that support v0 chunked prefill # Only for models that support v0 chunked prefill
# TODO(ywang96): Remove this overload once v0 is deprecated # TODO(ywang96): Remove this overload once v0 is deprecated
@overload @overload
......
...@@ -884,6 +884,9 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -884,6 +884,9 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
else: else:
self.visual_token_mask = None self.visual_token_mask = None
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
......
...@@ -674,6 +674,9 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -674,6 +674,9 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
image_embeds = torch.split(image_embeds, feature_sizes) image_embeds = torch.split(image_embeds, feature_sizes)
return image_embeds return image_embeds
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
......
...@@ -480,6 +480,9 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -480,6 +480,9 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
for i, patch_features_batch in enumerate(patch_embeddings) for i, patch_features_batch in enumerate(patch_embeddings)
] ]
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
......
...@@ -421,6 +421,9 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -421,6 +421,9 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
return [e.flatten(0, 1) for e in embeds] return [e.flatten(0, 1) for e in embeds]
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
video_input = self._parse_and_validate_video_input(**kwargs) video_input = self._parse_and_validate_video_input(**kwargs)
......
...@@ -852,6 +852,9 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -852,6 +852,9 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
image_feature = image_feature.view(batch_frames, -1, dim) image_feature = image_feature.view(batch_frames, -1, dim)
return image_feature return image_feature
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs) modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
......
...@@ -892,6 +892,9 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -892,6 +892,9 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
return multimodal_embeddings return multimodal_embeddings
def get_language_model(self) -> torch.nn.Module:
return self.llm
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs) modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
......
...@@ -514,6 +514,9 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -514,6 +514,9 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
image_embeds = (image_embeds, ) image_embeds = (image_embeds, )
return image_embeds return image_embeds
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
......
...@@ -1325,6 +1325,9 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1325,6 +1325,9 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
cross_attention_states = cross_attention_states_flat cross_attention_states = cross_attention_states_flat
return cross_attention_states return cross_attention_states
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_cross_attention_states( def get_cross_attention_states(
self, self,
image_inputs: MllamaImagePixelInputs, image_inputs: MllamaImagePixelInputs,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment