Unverified Commit d55244df authored by Nicolò Lucchesi's avatar Nicolò Lucchesi Committed by GitHub
Browse files

[Model] Add `SupportsMultiModal.get_language_model` interface (#16007)


Signed-off-by: default avatarNickLucche <nlucches@redhat.com>
parent 04149cce
......@@ -79,6 +79,17 @@ Further update the model as follows:
return inputs_embeds
```
- Implement {meth}`~vllm.model_executor.models.interfaces.SupportsMultiModal.get_language_model` getter to provide stable access to the underlying language model.
```python
class YourModelForImage2Seq(nn.Module):
...
def get_language_model(self) -> torch.nn.Module:
# Change `language_model` according to your implementation.
return self.language_model
```
- Once the above steps are done, update the model class with the {class}`~vllm.model_executor.models.interfaces.SupportsMultiModal` interface.
```diff
......
......@@ -605,6 +605,9 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
return self.multi_modal_projector(image_outputs, image_attn_mask)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
......
......@@ -424,6 +424,9 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
num_patches=num_patches,
)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
......
......@@ -627,6 +627,9 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
return self.language_projection(query_output)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
......
......@@ -988,6 +988,9 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
data=self._validate_pixel_values(pixel_values),
)
def get_language_model(self) -> torch.nn.Module:
return self.model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
......
......@@ -604,6 +604,9 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return self._pixel_values_to_embedding(
pixel_values=pixel_values, images_spatial_crop=images_spatial_crop)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
......
......@@ -1050,6 +1050,9 @@ class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal,
pixel_values = image_input["data"]
return self._encode_image(pixel_values)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
......
......@@ -341,6 +341,9 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return vision_embeddings_flat.split(patches_per_image, dim=0)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
......
......@@ -591,6 +591,9 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
e.flatten(0, 1) for e in image_embeds.split(num_patches.tolist())
]
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
......
......@@ -596,6 +596,9 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
return self.transformer.vision(pixel_values)
def get_language_model(self) -> torch.nn.Module:
return self.transformer
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
......
......@@ -710,6 +710,9 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
e.flatten(0, 1) for e in image_features.split(num_patches.tolist())
]
def get_language_model(self) -> torch.nn.Module:
return self.model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
......
......@@ -56,6 +56,18 @@ class SupportsMultiModal(Protocol):
"""
...
def get_language_model(self) -> torch.nn.Module:
"""
Returns the underlying language model used for text generation.
This is typically the `torch.nn.Module` instance responsible for
processing the merged multimodal embeddings and producing hidden states
Returns:
torch.nn.Module: The core language model component.
"""
...
# Only for models that support v0 chunked prefill
# TODO(ywang96): Remove this overload once v0 is deprecated
@overload
......
......@@ -884,6 +884,9 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
else:
self.visual_token_mask = None
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
......
......@@ -674,6 +674,9 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
image_embeds = torch.split(image_embeds, feature_sizes)
return image_embeds
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
......
......@@ -480,6 +480,9 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
for i, patch_features_batch in enumerate(patch_embeddings)
]
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
......
......@@ -421,6 +421,9 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
return [e.flatten(0, 1) for e in embeds]
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
video_input = self._parse_and_validate_video_input(**kwargs)
......
......@@ -852,6 +852,9 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
image_feature = image_feature.view(batch_frames, -1, dim)
return image_feature
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
......
......@@ -892,6 +892,9 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
return multimodal_embeddings
def get_language_model(self) -> torch.nn.Module:
return self.llm
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
......
......@@ -514,6 +514,9 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
image_embeds = (image_embeds, )
return image_embeds
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
......
......@@ -1325,6 +1325,9 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
cross_attention_states = cross_attention_states_flat
return cross_attention_states
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def get_cross_attention_states(
self,
image_inputs: MllamaImagePixelInputs,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment