Unverified Commit b3cf368d authored by lkchen's avatar lkchen Committed by GitHub
Browse files

[V1][Molmo] Fix get_multimodal_embeddings() in molmo.py (#14161)

parent c8525f06
......@@ -602,7 +602,9 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
return self.multi_modal_projector(image_outputs, image_attn_mask)
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
......
......@@ -628,7 +628,9 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
return self.language_projection(query_output)
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
......
......@@ -986,7 +986,9 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
data=self._validate_pixel_values(pixel_values),
)
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
......
......@@ -606,7 +606,9 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return self._pixel_values_to_embedding(
pixel_values=pixel_values, images_spatial_crop=images_spatial_crop)
def get_multimodal_embeddings(self, **kwargs: object) -> torch.Tensor:
def get_multimodal_embeddings(
self, **kwargs: object
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
......
......@@ -1037,7 +1037,9 @@ class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal):
pixel_values = image_input["data"]
return self._encode_image(pixel_values)
def get_multimodal_embeddings(self, **kwargs: object) -> torch.Tensor:
def get_multimodal_embeddings(
self, **kwargs: object
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
......
......@@ -18,7 +18,7 @@
""" PyTorch Fuyu model."""
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import List, Literal, Optional, Set, Tuple, TypedDict
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
import torch
import torch.nn as nn
......@@ -327,7 +327,9 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
image_patches_flat)
return vision_embeddings_flat.split(patches_per_image, dim=0)
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
......
......@@ -595,7 +595,9 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
return self.transformer.vision(pixel_values)
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
......
......@@ -617,7 +617,9 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
self.logits_processor = LogitsProcessor(config.text_config.vocab_size)
self.sampler = get_sampler()
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self.model._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
......
......@@ -4,6 +4,7 @@ from typing import (TYPE_CHECKING, ClassVar, Dict, List, Literal, Optional,
Protocol, Type, Union, overload, runtime_checkable)
import torch
from torch import Tensor
from typing_extensions import TypeIs, TypeVar
from vllm.logger import init_logger
......@@ -15,12 +16,11 @@ from .interfaces_base import is_pooling_model
if TYPE_CHECKING:
from vllm.attention import AttentionMetadata
from vllm.multimodal.inputs import NestedTensors # noqa: F401
from vllm.sequence import IntermediateTensors
logger = init_logger(__name__)
T = TypeVar("T", default="NestedTensors")
T = TypeVar("T", default=Union[list[Tensor], Tensor, tuple[Tensor, ...]])
@runtime_checkable
......@@ -36,7 +36,7 @@ class SupportsMultiModal(Protocol):
MRO of your model class.
"""
def get_multimodal_embeddings(self, **kwargs) -> Optional[T]:
def get_multimodal_embeddings(self, **kwargs) -> T:
"""
Returns multimodal embeddings generated from multimodal kwargs
to be merged with text embeddings.
......@@ -59,18 +59,18 @@ class SupportsMultiModal(Protocol):
@overload
def get_input_embeddings(
self,
input_ids: torch.Tensor,
input_ids: Tensor,
multimodal_embeddings: Optional[T] = None,
attn_metadata: Optional["AttentionMetadata"] = None,
) -> torch.Tensor:
) -> Tensor:
...
@overload
def get_input_embeddings(
self,
input_ids: torch.Tensor,
input_ids: Tensor,
multimodal_embeddings: Optional[T] = None,
) -> torch.Tensor:
) -> Tensor:
"""
Returns the input embeddings merged from the text embeddings from
input_ids and the multimodal embeddings generated from multimodal
......@@ -210,7 +210,7 @@ class SupportsPP(Protocol):
self,
*,
intermediate_tensors: Optional["IntermediateTensors"],
) -> Union[torch.Tensor, "IntermediateTensors"]:
) -> Union[Tensor, "IntermediateTensors"]:
"""
Accept :class:`IntermediateTensors` when PP rank > 0.
......@@ -237,7 +237,7 @@ class _SupportsPPType(Protocol):
self,
*,
intermediate_tensors: Optional["IntermediateTensors"],
) -> Union[torch.Tensor, "IntermediateTensors"]:
) -> Union[Tensor, "IntermediateTensors"]:
...
......
......@@ -904,7 +904,9 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
else:
self.visual_token_mask = None
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
......
......@@ -635,7 +635,9 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
image_features = self._process_image_pixels(image_input)
return self.multi_modal_projector(image_features)
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
......
......@@ -479,7 +479,9 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
for i, patch_features_batch in enumerate(patch_embeddings)
]
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
......
......@@ -420,7 +420,9 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
raise ValueError(
f"Unsupported type of video input {type(video_pixels)}")
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
video_input = self._parse_and_validate_video_input(**kwargs)
if video_input is None:
return None
......
......@@ -50,7 +50,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptInsertion, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import JSONTree, json_map_leaves
from vllm.utils import JSONTree, flatten_2d_lists, json_map_leaves
from .interfaces import (SupportsLoRA, SupportsMultiModal, SupportsPP,
SupportsQuant)
......@@ -1576,14 +1576,16 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
return embeds_in_batch
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
image_features = self._process_image_input(image_input)
return [
nested_embeds = [
self._get_mm_embeds(*args) for args in zip(
image_features,
image_input["feat_is_patch"],
......@@ -1591,6 +1593,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
image_input["embed_is_patch"],
)
]
return flatten_2d_lists(nested_embeds)
def get_input_embeddings(
self,
......
......@@ -263,7 +263,9 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
return self.multi_modal_projector(image_features)
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
......
......@@ -648,7 +648,9 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
return image_embeds
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
......
......@@ -220,7 +220,9 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
return get_sampler()
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input, image_tokens = self._parse_and_validate_image_input(
**kwargs)
if image_input is None:
......
......@@ -356,7 +356,9 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
return torch.split(masked_audio_features,
audio_output_lengths.flatten().tolist())
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
audio_input = self._parse_and_validate_audio_input(**kwargs)
if audio_input is None:
return None
......
......@@ -740,7 +740,9 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
return self.transformer.visual(image_input["data"])
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment