Unverified Commit b3cf368d authored by lkchen's avatar lkchen Committed by GitHub
Browse files

[V1][Molmo] Fix get_multimodal_embeddings() in molmo.py (#14161)

parent c8525f06
...@@ -602,7 +602,9 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -602,7 +602,9 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
return self.multi_modal_projector(image_outputs, image_attn_mask) return self.multi_modal_projector(image_outputs, image_attn_mask)
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return None
......
...@@ -628,7 +628,9 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -628,7 +628,9 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
return self.language_projection(query_output) return self.language_projection(query_output)
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return None
......
...@@ -986,7 +986,9 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -986,7 +986,9 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
data=self._validate_pixel_values(pixel_values), data=self._validate_pixel_values(pixel_values),
) )
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return None
......
...@@ -606,7 +606,9 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -606,7 +606,9 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return self._pixel_values_to_embedding( return self._pixel_values_to_embedding(
pixel_values=pixel_values, images_spatial_crop=images_spatial_crop) pixel_values=pixel_values, images_spatial_crop=images_spatial_crop)
def get_multimodal_embeddings(self, **kwargs: object) -> torch.Tensor: def get_multimodal_embeddings(
self, **kwargs: object
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return None
......
...@@ -1037,7 +1037,9 @@ class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -1037,7 +1037,9 @@ class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal):
pixel_values = image_input["data"] pixel_values = image_input["data"]
return self._encode_image(pixel_values) return self._encode_image(pixel_values)
def get_multimodal_embeddings(self, **kwargs: object) -> torch.Tensor: def get_multimodal_embeddings(
self, **kwargs: object
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return None
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
""" PyTorch Fuyu model.""" """ PyTorch Fuyu model."""
import math import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import List, Literal, Optional, Set, Tuple, TypedDict from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -327,7 +327,9 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -327,7 +327,9 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
image_patches_flat) image_patches_flat)
return vision_embeddings_flat.split(patches_per_image, dim=0) return vision_embeddings_flat.split(patches_per_image, dim=0)
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return None
......
...@@ -595,7 +595,9 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, ...@@ -595,7 +595,9 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
return self.transformer.vision(pixel_values) return self.transformer.vision(pixel_values)
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return None
......
...@@ -617,7 +617,9 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -617,7 +617,9 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
self.logits_processor = LogitsProcessor(config.text_config.vocab_size) self.logits_processor = LogitsProcessor(config.text_config.vocab_size)
self.sampler = get_sampler() self.sampler = get_sampler()
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self.model._parse_and_validate_image_input(**kwargs) image_input = self.model._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return None
......
...@@ -4,6 +4,7 @@ from typing import (TYPE_CHECKING, ClassVar, Dict, List, Literal, Optional, ...@@ -4,6 +4,7 @@ from typing import (TYPE_CHECKING, ClassVar, Dict, List, Literal, Optional,
Protocol, Type, Union, overload, runtime_checkable) Protocol, Type, Union, overload, runtime_checkable)
import torch import torch
from torch import Tensor
from typing_extensions import TypeIs, TypeVar from typing_extensions import TypeIs, TypeVar
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -15,12 +16,11 @@ from .interfaces_base import is_pooling_model ...@@ -15,12 +16,11 @@ from .interfaces_base import is_pooling_model
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.multimodal.inputs import NestedTensors # noqa: F401
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
logger = init_logger(__name__) logger = init_logger(__name__)
T = TypeVar("T", default="NestedTensors") T = TypeVar("T", default=Union[list[Tensor], Tensor, tuple[Tensor, ...]])
@runtime_checkable @runtime_checkable
...@@ -36,7 +36,7 @@ class SupportsMultiModal(Protocol): ...@@ -36,7 +36,7 @@ class SupportsMultiModal(Protocol):
MRO of your model class. MRO of your model class.
""" """
def get_multimodal_embeddings(self, **kwargs) -> Optional[T]: def get_multimodal_embeddings(self, **kwargs) -> T:
""" """
Returns multimodal embeddings generated from multimodal kwargs Returns multimodal embeddings generated from multimodal kwargs
to be merged with text embeddings. to be merged with text embeddings.
...@@ -59,18 +59,18 @@ class SupportsMultiModal(Protocol): ...@@ -59,18 +59,18 @@ class SupportsMultiModal(Protocol):
@overload @overload
def get_input_embeddings( def get_input_embeddings(
self, self,
input_ids: torch.Tensor, input_ids: Tensor,
multimodal_embeddings: Optional[T] = None, multimodal_embeddings: Optional[T] = None,
attn_metadata: Optional["AttentionMetadata"] = None, attn_metadata: Optional["AttentionMetadata"] = None,
) -> torch.Tensor: ) -> Tensor:
... ...
@overload @overload
def get_input_embeddings( def get_input_embeddings(
self, self,
input_ids: torch.Tensor, input_ids: Tensor,
multimodal_embeddings: Optional[T] = None, multimodal_embeddings: Optional[T] = None,
) -> torch.Tensor: ) -> Tensor:
""" """
Returns the input embeddings merged from the text embeddings from Returns the input embeddings merged from the text embeddings from
input_ids and the multimodal embeddings generated from multimodal input_ids and the multimodal embeddings generated from multimodal
...@@ -210,7 +210,7 @@ class SupportsPP(Protocol): ...@@ -210,7 +210,7 @@ class SupportsPP(Protocol):
self, self,
*, *,
intermediate_tensors: Optional["IntermediateTensors"], intermediate_tensors: Optional["IntermediateTensors"],
) -> Union[torch.Tensor, "IntermediateTensors"]: ) -> Union[Tensor, "IntermediateTensors"]:
""" """
Accept :class:`IntermediateTensors` when PP rank > 0. Accept :class:`IntermediateTensors` when PP rank > 0.
...@@ -237,7 +237,7 @@ class _SupportsPPType(Protocol): ...@@ -237,7 +237,7 @@ class _SupportsPPType(Protocol):
self, self,
*, *,
intermediate_tensors: Optional["IntermediateTensors"], intermediate_tensors: Optional["IntermediateTensors"],
) -> Union[torch.Tensor, "IntermediateTensors"]: ) -> Union[Tensor, "IntermediateTensors"]:
... ...
......
...@@ -904,7 +904,9 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -904,7 +904,9 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
else: else:
self.visual_token_mask = None self.visual_token_mask = None
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return None
......
...@@ -635,7 +635,9 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -635,7 +635,9 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
image_features = self._process_image_pixels(image_input) image_features = self._process_image_pixels(image_input)
return self.multi_modal_projector(image_features) return self.multi_modal_projector(image_features)
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return None
......
...@@ -479,7 +479,9 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -479,7 +479,9 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
for i, patch_features_batch in enumerate(patch_embeddings) for i, patch_features_batch in enumerate(patch_embeddings)
] ]
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return None
......
...@@ -420,7 +420,9 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -420,7 +420,9 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
raise ValueError( raise ValueError(
f"Unsupported type of video input {type(video_pixels)}") f"Unsupported type of video input {type(video_pixels)}")
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
video_input = self._parse_and_validate_video_input(**kwargs) video_input = self._parse_and_validate_video_input(**kwargs)
if video_input is None: if video_input is None:
return None return None
......
...@@ -50,7 +50,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, ...@@ -50,7 +50,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptInsertion, PromptUpdate) PromptInsertion, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import JSONTree, json_map_leaves from vllm.utils import JSONTree, flatten_2d_lists, json_map_leaves
from .interfaces import (SupportsLoRA, SupportsMultiModal, SupportsPP, from .interfaces import (SupportsLoRA, SupportsMultiModal, SupportsPP,
SupportsQuant) SupportsQuant)
...@@ -1576,14 +1576,16 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, ...@@ -1576,14 +1576,16 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
return embeds_in_batch return embeds_in_batch
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return None
image_features = self._process_image_input(image_input) image_features = self._process_image_input(image_input)
return [ nested_embeds = [
self._get_mm_embeds(*args) for args in zip( self._get_mm_embeds(*args) for args in zip(
image_features, image_features,
image_input["feat_is_patch"], image_input["feat_is_patch"],
...@@ -1591,6 +1593,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, ...@@ -1591,6 +1593,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
image_input["embed_is_patch"], image_input["embed_is_patch"],
) )
] ]
return flatten_2d_lists(nested_embeds)
def get_input_embeddings( def get_input_embeddings(
self, self,
......
...@@ -263,7 +263,9 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -263,7 +263,9 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
return self.multi_modal_projector(image_features) return self.multi_modal_projector(image_features)
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return None
......
...@@ -648,7 +648,9 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -648,7 +648,9 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
return image_embeds return image_embeds
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return None
......
...@@ -220,7 +220,9 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -220,7 +220,9 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
return get_sampler() return get_sampler()
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input, image_tokens = self._parse_and_validate_image_input( image_input, image_tokens = self._parse_and_validate_image_input(
**kwargs) **kwargs)
if image_input is None: if image_input is None:
......
...@@ -356,7 +356,9 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -356,7 +356,9 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
return torch.split(masked_audio_features, return torch.split(masked_audio_features,
audio_output_lengths.flatten().tolist()) audio_output_lengths.flatten().tolist())
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
audio_input = self._parse_and_validate_audio_input(**kwargs) audio_input = self._parse_and_validate_audio_input(**kwargs)
if audio_input is None: if audio_input is None:
return None return None
......
...@@ -740,7 +740,9 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA, ...@@ -740,7 +740,9 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
return self.transformer.visual(image_input["data"]) return self.transformer.visual(image_input["data"])
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment