Unverified Commit 601bd326 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Misc] Clean up type annotation for `SupportsMultiModal` (#14794)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 09269b31
...@@ -34,7 +34,8 @@ Further update the model as follows: ...@@ -34,7 +34,8 @@ Further update the model as follows:
image_features = self.vision_encoder(image_input) image_features = self.vision_encoder(image_input)
return self.multi_modal_projector(image_features) return self.multi_modal_projector(image_features)
def get_multimodal_embeddings(self, **kwargs: object) -> Optional[NestedTensors]: def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
# Validate the multimodal input keyword arguments # Validate the multimodal input keyword arguments
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
...@@ -61,7 +62,7 @@ Further update the model as follows: ...@@ -61,7 +62,7 @@ Further update the model as follows:
def get_input_embeddings( def get_input_embeddings(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
# `get_input_embeddings` should already be implemented for the language # `get_input_embeddings` should already be implemented for the language
......
...@@ -214,7 +214,7 @@ MULTIMODAL_MODELS = { ...@@ -214,7 +214,7 @@ MULTIMODAL_MODELS = {
"llava-hf/llava-onevision-qwen2-0.5b-ov-hf": PPTestSettings.fast(), "llava-hf/llava-onevision-qwen2-0.5b-ov-hf": PPTestSettings.fast(),
"openbmb/MiniCPM-Llama3-V-2_5": PPTestSettings.fast(), "openbmb/MiniCPM-Llama3-V-2_5": PPTestSettings.fast(),
"allenai/Molmo-7B-D-0924": PPTestSettings.fast(), "allenai/Molmo-7B-D-0924": PPTestSettings.fast(),
"microsoft/Phi-3-vision-128k-instruct": PPTestSettings.fast(), "microsoft/Phi-3.5-vision-instruct": PPTestSettings.fast(),
"mistralai/Pixtral-12B-2409": PPTestSettings.fast(load_format="dummy"), "mistralai/Pixtral-12B-2409": PPTestSettings.fast(load_format="dummy"),
"Qwen/Qwen-VL-Chat": PPTestSettings.fast(), "Qwen/Qwen-VL-Chat": PPTestSettings.fast(),
"Qwen/Qwen2-Audio-7B-Instruct": PPTestSettings.fast(), "Qwen/Qwen2-Audio-7B-Instruct": PPTestSettings.fast(),
...@@ -237,7 +237,7 @@ TEST_MODELS = [ ...@@ -237,7 +237,7 @@ TEST_MODELS = [
"BAAI/bge-multilingual-gemma2", "BAAI/bge-multilingual-gemma2",
# [MULTIMODAL GENERATION] # [MULTIMODAL GENERATION]
"OpenGVLab/InternVL2-1B", "OpenGVLab/InternVL2-1B",
"microsoft/Phi-3-vision-128k-instruct", "microsoft/Phi-3.5-vision-instruct",
"fixie-ai/ultravox-v0_5-llama-3_2-1b", "fixie-ai/ultravox-v0_5-llama-3_2-1b",
# [LANGUAGE GENERATION - HYBRID ARCH] # [LANGUAGE GENERATION - HYBRID ARCH]
"ai21labs/Jamba-tiny-dev", "ai21labs/Jamba-tiny-dev",
......
...@@ -21,8 +21,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead ...@@ -21,8 +21,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name) default_weight_loader, maybe_remap_kv_scale_name)
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
NestedTensors)
from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement, BaseProcessingInfo, PromptReplacement,
...@@ -35,7 +34,7 @@ from .idefics2_vision_model import Idefics2VisionConfig ...@@ -35,7 +34,7 @@ from .idefics2_vision_model import Idefics2VisionConfig
from .idefics2_vision_model import ( from .idefics2_vision_model import (
Idefics2VisionTransformer as Idefics3VisionTransformer) Idefics2VisionTransformer as Idefics3VisionTransformer)
# yapf: enable # yapf: enable
from .interfaces import SupportsMultiModal, SupportsQuant from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsQuant
from .llama import LlamaDecoderLayer, LlamaMLP, LlamaModel from .llama import LlamaDecoderLayer, LlamaMLP, LlamaModel
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
is_pp_missing_parameter, maybe_prefix, is_pp_missing_parameter, maybe_prefix,
...@@ -607,8 +606,7 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -607,8 +606,7 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
return self.multi_modal_projector(image_outputs, image_attn_mask) return self.multi_modal_projector(image_outputs, image_attn_mask)
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return None
...@@ -618,7 +616,7 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -618,7 +616,7 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
def get_input_embeddings( def get_input_embeddings(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None:
......
...@@ -15,8 +15,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig ...@@ -15,8 +15,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
NestedTensors)
from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptIndexTargets, BaseProcessingInfo, PromptIndexTargets,
...@@ -25,7 +24,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs ...@@ -25,7 +24,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .blip import BlipVisionModel from .blip import BlipVisionModel
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, init_vllm_registered_model, from .utils import (AutoWeightsLoader, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings) maybe_prefix, merge_multimodal_embeddings)
...@@ -629,8 +628,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -629,8 +628,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
return self.language_projection(query_output) return self.language_projection(query_output)
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return None
...@@ -640,7 +638,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -640,7 +638,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
def get_input_embeddings( def get_input_embeddings(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None:
......
...@@ -30,8 +30,7 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -30,8 +30,7 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
NestedTensors)
from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement, BaseProcessingInfo, PromptReplacement,
...@@ -39,7 +38,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, ...@@ -39,7 +38,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (is_pp_missing_parameter, from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix, merge_multimodal_embeddings) maybe_prefix, merge_multimodal_embeddings)
...@@ -986,8 +985,7 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -986,8 +985,7 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
) )
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return None
...@@ -1000,7 +998,7 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1000,7 +998,7 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
def get_input_embeddings( def get_input_embeddings(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.model.get_input_embeddings(input_ids) inputs_embeds = self.model.get_input_embeddings(input_ids)
......
...@@ -36,7 +36,7 @@ from vllm.transformers_utils.processors.deepseek_vl2 import ( ...@@ -36,7 +36,7 @@ from vllm.transformers_utils.processors.deepseek_vl2 import (
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from vllm.utils import is_list_of from vllm.utils import is_list_of
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix, init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings) merge_multimodal_embeddings)
...@@ -605,8 +605,7 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -605,8 +605,7 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
pixel_values=pixel_values, images_spatial_crop=images_spatial_crop) pixel_values=pixel_values, images_spatial_crop=images_spatial_crop)
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs: object self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return None
...@@ -616,7 +615,7 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -616,7 +615,7 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def get_input_embeddings( def get_input_embeddings(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None:
......
...@@ -20,7 +20,7 @@ from vllm.model_executor.models.bart import (BartDecoder, BartEncoder, ...@@ -20,7 +20,7 @@ from vllm.model_executor.models.bart import (BartDecoder, BartEncoder,
BartParallelLMHead, BartParallelLMHead,
BartScaledWordEmbedding) BartScaledWordEmbedding)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal.parse import MultiModalDataDict, MultiModalDataItems from vllm.multimodal.parse import MultiModalDataDict, MultiModalDataItems
from vllm.multimodal.processing import (BaseProcessingInfo, from vllm.multimodal.processing import (BaseProcessingInfo,
...@@ -30,7 +30,8 @@ from vllm.multimodal.processing import (BaseProcessingInfo, ...@@ -30,7 +30,8 @@ from vllm.multimodal.processing import (BaseProcessingInfo,
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsMultiModal, SupportsV0Only from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
SupportsV0Only)
from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings
...@@ -1037,8 +1038,7 @@ class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -1037,8 +1038,7 @@ class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal):
return self._encode_image(pixel_values) return self._encode_image(pixel_values)
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs: object self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return None
...@@ -1048,7 +1048,7 @@ class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -1048,7 +1048,7 @@ class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal):
def get_input_embeddings( def get_input_embeddings(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None:
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
""" PyTorch Fuyu model.""" """ PyTorch Fuyu model."""
import math import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union from typing import List, Literal, Optional, Set, Tuple, TypedDict
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -41,7 +41,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, ...@@ -41,7 +41,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix, from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
merge_multimodal_embeddings) merge_multimodal_embeddings)
...@@ -327,8 +327,7 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -327,8 +327,7 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return vision_embeddings_flat.split(patches_per_image, dim=0) return vision_embeddings_flat.split(patches_per_image, dim=0)
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return None
...@@ -338,7 +337,7 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -338,7 +337,7 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def get_input_embeddings( def get_input_embeddings(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None:
......
...@@ -14,8 +14,7 @@ from vllm.model_executor.layers.layernorm import GemmaRMSNorm ...@@ -14,8 +14,7 @@ from vllm.model_executor.layers.layernorm import GemmaRMSNorm
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
NestedTensors)
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems) MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
...@@ -24,7 +23,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, ...@@ -24,7 +23,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings) maybe_prefix, merge_multimodal_embeddings)
...@@ -481,7 +480,8 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -481,7 +480,8 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal,
) )
return self.multi_modal_projector(vision_outputs) return self.multi_modal_projector(vision_outputs)
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return None
...@@ -491,7 +491,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -491,7 +491,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal,
def get_input_embeddings( def get_input_embeddings(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if multimodal_embeddings is None: if multimodal_embeddings is None:
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
......
...@@ -28,7 +28,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -28,7 +28,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors from vllm.multimodal.inputs import MultiModalKwargs
from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, BatchFeature, BaseProcessingInfo, BatchFeature,
...@@ -39,7 +39,8 @@ from vllm.sequence import IntermediateTensors ...@@ -39,7 +39,8 @@ from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import ChatGLMConfig from vllm.transformers_utils.configs import ChatGLMConfig
from .chatglm import ChatGLMBaseModel, ChatGLMModel from .chatglm import ChatGLMBaseModel, ChatGLMModel
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .utils import flatten_bn, merge_multimodal_embeddings from .utils import flatten_bn, merge_multimodal_embeddings
...@@ -596,8 +597,7 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, ...@@ -596,8 +597,7 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
return self.transformer.vision(pixel_values) return self.transformer.vision(pixel_values)
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return None
...@@ -608,7 +608,7 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, ...@@ -608,7 +608,7 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
def get_input_embeddings( def get_input_embeddings(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.transformer.get_input_embeddings(input_ids) inputs_embeds = self.transformer.get_input_embeddings(input_ids)
......
...@@ -49,7 +49,7 @@ from vllm.sequence import IntermediateTensors ...@@ -49,7 +49,7 @@ from vllm.sequence import IntermediateTensors
from .idefics2_vision_model import ( from .idefics2_vision_model import (
Idefics2VisionTransformer as Idefics3VisionTransformer) Idefics2VisionTransformer as Idefics3VisionTransformer)
# yapf: enable # yapf: enable
from .interfaces import SupportsLoRA, SupportsMultiModal from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
from .llama import LlamaModel from .llama import LlamaModel
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix, from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
merge_multimodal_embeddings) merge_multimodal_embeddings)
...@@ -617,8 +617,7 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -617,8 +617,7 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
self.sampler = get_sampler() self.sampler = get_sampler()
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self.model._parse_and_validate_image_input(**kwargs) image_input = self.model._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return None
...@@ -628,7 +627,7 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -628,7 +627,7 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
def get_input_embeddings( def get_input_embeddings(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.model.get_input_embeddings(input_ids) inputs_embeds = self.model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None:
......
...@@ -5,7 +5,7 @@ from typing import (TYPE_CHECKING, ClassVar, Dict, List, Literal, Optional, ...@@ -5,7 +5,7 @@ from typing import (TYPE_CHECKING, ClassVar, Dict, List, Literal, Optional,
import torch import torch
from torch import Tensor from torch import Tensor
from typing_extensions import TypeIs, TypeVar from typing_extensions import TypeIs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
...@@ -20,7 +20,14 @@ if TYPE_CHECKING: ...@@ -20,7 +20,14 @@ if TYPE_CHECKING:
logger = init_logger(__name__) logger = init_logger(__name__)
T = TypeVar("T", default=Union[list[Tensor], Tensor, tuple[Tensor, ...]]) MultiModalEmbeddings = Union[list[Tensor], Tensor, tuple[Tensor, ...]]
"""
The output embeddings must be one of the following formats:
- A list or tuple of 2D tensors, where each tensor corresponds to
each input multimodal data item (e.g, image).
- A single 3D tensor, with the batch dimension grouping the 2D tensors.
"""
@runtime_checkable @runtime_checkable
...@@ -36,17 +43,12 @@ class SupportsMultiModal(Protocol): ...@@ -36,17 +43,12 @@ class SupportsMultiModal(Protocol):
MRO of your model class. MRO of your model class.
""" """
def get_multimodal_embeddings(self, **kwargs) -> T: def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
""" """
Returns multimodal embeddings generated from multimodal kwargs Returns multimodal embeddings generated from multimodal kwargs
to be merged with text embeddings. to be merged with text embeddings.
The output embeddings must be one of the following formats:
- A list or tuple of 2D tensors, where each tensor corresponds to
each input multimodal data item (e.g, image).
- A single 3D tensor, with the batch dimension grouping the 2D tensors.
Note: Note:
The returned multimodal embeddings must be in the same order as The returned multimodal embeddings must be in the same order as
the appearances of their corresponding multimodal data item in the the appearances of their corresponding multimodal data item in the
...@@ -60,7 +62,7 @@ class SupportsMultiModal(Protocol): ...@@ -60,7 +62,7 @@ class SupportsMultiModal(Protocol):
def get_input_embeddings( def get_input_embeddings(
self, self,
input_ids: Tensor, input_ids: Tensor,
multimodal_embeddings: Optional[T] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
attn_metadata: Optional["AttentionMetadata"] = None, attn_metadata: Optional["AttentionMetadata"] = None,
) -> Tensor: ) -> Tensor:
... ...
...@@ -69,7 +71,7 @@ class SupportsMultiModal(Protocol): ...@@ -69,7 +71,7 @@ class SupportsMultiModal(Protocol):
def get_input_embeddings( def get_input_embeddings(
self, self,
input_ids: Tensor, input_ids: Tensor,
multimodal_embeddings: Optional[T] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> Tensor: ) -> Tensor:
""" """
Returns the input embeddings merged from the text embeddings from Returns the input embeddings merged from the text embeddings from
......
...@@ -37,7 +37,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs ...@@ -37,7 +37,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings) maybe_prefix, merge_multimodal_embeddings)
...@@ -905,8 +905,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -905,8 +905,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
self.visual_token_mask = None self.visual_token_mask = None
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return None
...@@ -916,7 +915,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -916,7 +915,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
def get_input_embeddings( def get_input_embeddings(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None:
......
...@@ -38,7 +38,7 @@ from vllm.sequence import IntermediateTensors ...@@ -38,7 +38,7 @@ from vllm.sequence import IntermediateTensors
from vllm.utils import JSONTree, flatten_2d_lists, json_map_leaves from vllm.utils import JSONTree, flatten_2d_lists, json_map_leaves
from .clip import CLIPVisionModel from .clip import CLIPVisionModel
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .pixtral import (PixtralHFVisionModel, from .pixtral import (PixtralHFVisionModel,
get_pixtral_hf_image_feature_grid_size) get_pixtral_hf_image_feature_grid_size)
from .siglip import SiglipVisionModel from .siglip import SiglipVisionModel
...@@ -778,7 +778,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -778,7 +778,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
return embeds_in_batch return embeds_in_batch
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return None
...@@ -800,7 +801,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -800,7 +801,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
def get_input_embeddings( def get_input_embeddings(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None:
......
...@@ -16,12 +16,12 @@ from vllm.config import VllmConfig ...@@ -16,12 +16,12 @@ from vllm.config import VllmConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig, NestedTensors from vllm.multimodal.inputs import MultiModalFieldConfig
from vllm.multimodal.parse import ImageSize from vllm.multimodal.parse import ImageSize
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .clip import CLIPVisionModel from .clip import CLIPVisionModel
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .llava import (BaseLlavaMultiModalProcessor, BaseLlavaProcessingInfo, from .llava import (BaseLlavaMultiModalProcessor, BaseLlavaProcessingInfo,
LlavaDummyInputsBuilder, LlavaLikeConfig, LlavaDummyInputsBuilder, LlavaLikeConfig,
LlavaMultiModalProjector, init_vision_tower_for_llava) LlavaMultiModalProjector, init_vision_tower_for_llava)
...@@ -480,8 +480,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -480,8 +480,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
] ]
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return None
...@@ -491,7 +490,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -491,7 +490,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
def get_input_embeddings( def get_input_embeddings(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if multimodal_embeddings is None: if multimodal_embeddings is None:
......
...@@ -16,8 +16,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler ...@@ -16,8 +16,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
NestedTensors)
from vllm.multimodal.parse import (ImageSize, MultiModalDataItems, from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
VideoEmbeddingItems, VideoProcessorItems) VideoEmbeddingItems, VideoProcessorItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
...@@ -27,7 +26,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs ...@@ -27,7 +26,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of from vllm.utils import is_list_of
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .llava import init_vision_tower_for_llava from .llava import init_vision_tower_for_llava
from .siglip import SiglipVisionModel from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, init_vllm_registered_model, from .utils import (AutoWeightsLoader, init_vllm_registered_model,
...@@ -421,8 +420,7 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -421,8 +420,7 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
f"Unsupported type of video input {type(video_pixels)}") f"Unsupported type of video input {type(video_pixels)}")
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
video_input = self._parse_and_validate_video_input(**kwargs) video_input = self._parse_and_validate_video_input(**kwargs)
if video_input is None: if video_input is None:
return None return None
...@@ -432,7 +430,7 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -432,7 +430,7 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
def get_input_embeddings( def get_input_embeddings(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None:
......
...@@ -19,8 +19,7 @@ from vllm.model_executor.layers.activation import get_act_fn ...@@ -19,8 +19,7 @@ from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
NestedTensors)
from vllm.multimodal.parse import (ImageSize, MultiModalDataItems, from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
VideoEmbeddingItems, VideoProcessorItems) VideoEmbeddingItems, VideoProcessorItems)
from vllm.multimodal.processing import PromptReplacement, PromptUpdate from vllm.multimodal.processing import PromptReplacement, PromptUpdate
...@@ -29,7 +28,7 @@ from vllm.sequence import IntermediateTensors ...@@ -29,7 +28,7 @@ from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of from vllm.utils import is_list_of
from .clip import CLIPVisionModel from .clip import CLIPVisionModel
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .llava import LlavaDummyInputsBuilder, init_vision_tower_for_llava from .llava import LlavaDummyInputsBuilder, init_vision_tower_for_llava
from .llava_next import (BaseLlavaNextMultiModalProcessor, LlavaNextLikeConfig, from .llava_next import (BaseLlavaNextMultiModalProcessor, LlavaNextLikeConfig,
LlavaNextProcessingInfo) LlavaNextProcessingInfo)
...@@ -856,7 +855,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -856,7 +855,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
return image_feature return image_feature
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs) -> Optional[tuple[torch.Tensor, ...]]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs) modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities: if not modalities:
return None return None
...@@ -882,7 +881,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -882,7 +881,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
def get_input_embeddings( def get_input_embeddings(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: Optional[tuple[torch.Tensor, ...]] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None:
...@@ -894,10 +893,9 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -894,10 +893,9 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
def get_input_embeddings_v0( def get_input_embeddings_v0(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
image_input: Optional[NestedTensors] = None, image_input: Optional[LlavaOnevisionImagePixelInputs] = None,
video_input: Optional[NestedTensors] = None, video_input: Optional[LlavaOnevisionVideoPixelInputs] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.get_input_embeddings(input_ids) inputs_embeds = self.get_input_embeddings(input_ids)
if image_input is not None: if image_input is not None:
image_embeds = self._process_image_input(image_input) image_embeds = self._process_image_input(image_input)
......
...@@ -52,8 +52,8 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs ...@@ -52,8 +52,8 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import JSONTree, flatten_2d_lists, json_map_leaves from vllm.utils import JSONTree, flatten_2d_lists, json_map_leaves
from .interfaces import (SupportsLoRA, SupportsMultiModal, SupportsPP, from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsQuant) SupportsMultiModal, SupportsPP, SupportsQuant)
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
is_pp_missing_parameter, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
...@@ -1577,8 +1577,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, ...@@ -1577,8 +1577,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
return embeds_in_batch return embeds_in_batch
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return None
...@@ -1598,7 +1597,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, ...@@ -1598,7 +1597,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
def get_input_embeddings( def get_input_embeddings(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.model.get_input_embeddings(input_ids) inputs_embeds = self.model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None:
......
...@@ -13,8 +13,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput ...@@ -13,8 +13,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputs, MultiModalKwargs, MultiModalInputs, MultiModalKwargs)
NestedTensors)
from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptIndexTargets, BaseProcessingInfo, PromptIndexTargets,
...@@ -23,7 +22,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, ...@@ -23,7 +22,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, init_vllm_registered_model, from .utils import (AutoWeightsLoader, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings) maybe_prefix, merge_multimodal_embeddings)
...@@ -328,8 +327,7 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -328,8 +327,7 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
return self.multi_modal_projector(image_features) return self.multi_modal_projector(image_features)
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return None
...@@ -341,7 +339,7 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -341,7 +339,7 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
def get_input_embeddings( def get_input_embeddings(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None:
......
...@@ -31,8 +31,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -31,8 +31,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
NestedTensors)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems) ImageSize, MultiModalDataItems)
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
...@@ -48,7 +47,8 @@ from vllm.sequence import IntermediateTensors ...@@ -48,7 +47,8 @@ from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of from vllm.utils import is_list_of
from .clip import CLIPVisionModel from .clip import CLIPVisionModel
from .interfaces import SupportsMultiModal, SupportsPP, SupportsQuant from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP,
SupportsQuant)
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix, init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings) merge_multimodal_embeddings)
...@@ -649,8 +649,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -649,8 +649,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
return image_embeds return image_embeds
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return None
...@@ -660,7 +659,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -660,7 +659,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
def get_input_embeddings( def get_input_embeddings(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.embed_tokens(input_ids) inputs_embeds = self.embed_tokens(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment