Unverified Commit 601bd326 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Misc] Clean up type annotation for `SupportsMultiModal` (#14794)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 09269b31
...@@ -30,12 +30,12 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler ...@@ -30,12 +30,12 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import NestedTensors, PlaceholderRange from vllm.multimodal.inputs import PlaceholderRange
from vllm.multimodal.utils import consecutive_placeholder_ranges from vllm.multimodal.utils import consecutive_placeholder_ranges
from vllm.sequence import IntermediateTensors, SequenceData from vllm.sequence import IntermediateTensors, SequenceData
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (init_vllm_registered_model, maybe_prefix, from .utils import (init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings) merge_multimodal_embeddings)
from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
...@@ -221,8 +221,7 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -221,8 +221,7 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
return get_sampler() return get_sampler()
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input, image_tokens = self._parse_and_validate_image_input( image_input, image_tokens = self._parse_and_validate_image_input(
**kwargs) **kwargs)
if image_input is None: if image_input is None:
...@@ -255,7 +254,7 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -255,7 +254,7 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
def get_input_embeddings( def get_input_embeddings(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None:
......
...@@ -59,7 +59,8 @@ from vllm.platforms import _Backend ...@@ -59,7 +59,8 @@ from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import uses_mrope from vllm.transformers_utils.config import uses_mrope
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .qwen2_vl import Qwen2VLDummyInputsBuilder as Qwen2_5_VLDummyInputsBuilder from .qwen2_vl import Qwen2VLDummyInputsBuilder as Qwen2_5_VLDummyInputsBuilder
from .qwen2_vl import (Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo, from .qwen2_vl import (Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo,
apply_rotary_pos_emb_vision) apply_rotary_pos_emb_vision)
...@@ -952,7 +953,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -952,7 +953,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
return modalities return modalities
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs) -> Optional[tuple[torch.Tensor, ...]]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs) modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities: if not modalities:
...@@ -978,7 +979,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -978,7 +979,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
def get_input_embeddings( def get_input_embeddings(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: Optional[tuple[torch.Tensor, ...]] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None:
...@@ -990,10 +991,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -990,10 +991,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
def get_input_embeddings_v0( def get_input_embeddings_v0(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
image_input: Optional[tuple[torch.Tensor, ...]] = None, image_input: Optional[Qwen2_5_VLImageInputs] = None,
video_input: Optional[tuple[torch.Tensor, ...]] = None, video_input: Optional[Qwen2_5_VLVideoInputs] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.get_input_embeddings(input_ids) inputs_embeds = self.get_input_embeddings(input_ids)
if image_input is not None: if image_input is not None:
image_embeds = self._process_image_input(image_input) image_embeds = self._process_image_input(image_input)
......
...@@ -37,8 +37,7 @@ from vllm.config import VllmConfig ...@@ -37,8 +37,7 @@ from vllm.config import VllmConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
NestedTensors)
from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems, from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems,
MultiModalDataParser) MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
...@@ -47,7 +46,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, ...@@ -47,7 +46,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, init_vllm_registered_model, from .utils import (AutoWeightsLoader, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings) maybe_prefix, merge_multimodal_embeddings)
...@@ -357,8 +356,7 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -357,8 +356,7 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
audio_output_lengths.flatten().tolist()) audio_output_lengths.flatten().tolist())
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
audio_input = self._parse_and_validate_audio_input(**kwargs) audio_input = self._parse_and_validate_audio_input(**kwargs)
if audio_input is None: if audio_input is None:
return None return None
...@@ -368,7 +366,7 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -368,7 +366,7 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
def get_input_embeddings( def get_input_embeddings(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None:
......
...@@ -71,7 +71,8 @@ from vllm.transformers_utils.config import uses_mrope ...@@ -71,7 +71,8 @@ from vllm.transformers_utils.config import uses_mrope
from vllm.transformers_utils.processor import ( from vllm.transformers_utils.processor import (
cached_image_processor_from_config) cached_image_processor_from_config)
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .utils import (AutoWeightsLoader, WeightsMapper, from .utils import (AutoWeightsLoader, WeightsMapper,
init_vllm_registered_model, maybe_prefix, init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings) merge_multimodal_embeddings)
...@@ -1262,7 +1263,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1262,7 +1263,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
return modalities return modalities
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs) -> Optional[tuple[torch.Tensor, ...]]: self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs) modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities: if not modalities:
...@@ -1289,7 +1290,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1289,7 +1290,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
def get_input_embeddings( def get_input_embeddings(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: Optional[tuple[torch.Tensor, ...]] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None:
...@@ -1301,10 +1302,9 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1301,10 +1302,9 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
def get_input_embeddings_v0( def get_input_embeddings_v0(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
image_input: Optional[tuple[torch.Tensor, ...]] = None, image_input: Optional[Qwen2VLImagePixelInputs] = None,
video_input: Optional[tuple[torch.Tensor, ...]] = None, video_input: Optional[Qwen2VLVideoPixelInputs] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.get_input_embeddings(input_ids) inputs_embeds = self.get_input_embeddings(input_ids)
if image_input is not None: if image_input is not None:
image_embeds = self._process_image_input(image_input) image_embeds = self._process_image_input(image_input)
......
...@@ -32,8 +32,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig ...@@ -32,8 +32,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.resampler import Resampler2, get_abs_pos from vllm.model_executor.layers.resampler import Resampler2, get_abs_pos
from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
NestedTensors)
from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement, BaseProcessingInfo, PromptReplacement,
...@@ -41,7 +40,8 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, ...@@ -41,7 +40,8 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .qwen import QWenBaseModel, QWenModel from .qwen import QWenBaseModel, QWenModel
from .utils import flatten_bn, merge_multimodal_embeddings from .utils import flatten_bn, merge_multimodal_embeddings
...@@ -741,8 +741,7 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA, ...@@ -741,8 +741,7 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
return self.transformer.visual(image_input["data"]) return self.transformer.visual(image_input["data"])
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return None
...@@ -753,7 +752,7 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA, ...@@ -753,7 +752,7 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
def get_input_embeddings( def get_input_embeddings(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.transformer.get_input_embeddings(input_ids) inputs_embeds = self.transformer.get_input_embeddings(input_ids)
......
...@@ -35,7 +35,8 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs ...@@ -35,7 +35,8 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.ultravox import UltravoxConfig from vllm.transformers_utils.configs.ultravox import UltravoxConfig
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix, init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings, merge_multimodal_embeddings,
...@@ -555,8 +556,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): ...@@ -555,8 +556,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
return flattened_embeddings return flattened_embeddings
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
audio_input = self._parse_and_validate_audio_input(**kwargs) audio_input = self._parse_and_validate_audio_input(**kwargs)
if audio_input is None: if audio_input is None:
return None return None
...@@ -566,7 +566,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): ...@@ -566,7 +566,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
def get_input_embeddings( def get_input_embeddings(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None, multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None:
......
...@@ -34,8 +34,8 @@ from vllm.multimodal.processing import (BaseProcessingInfo, ...@@ -34,8 +34,8 @@ from vllm.multimodal.processing import (BaseProcessingInfo,
PromptReplacement, PromptUpdate) PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from .interfaces import (SupportsMultiModal, SupportsTranscription, from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
SupportsV0Only) SupportsTranscription, SupportsV0Only)
from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors, from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors,
make_layers) make_layers)
...@@ -689,8 +689,7 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription, ...@@ -689,8 +689,7 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
return decoder_outputs return decoder_outputs
def get_multimodal_embeddings( def get_multimodal_embeddings(
self, **kwargs self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
# TODO: This method does not obey the interface for SupportsMultiModal. # TODO: This method does not obey the interface for SupportsMultiModal.
# Refactor this once encoder/decoder support is implemented in V1. # Refactor this once encoder/decoder support is implemented in V1.
audio_input = self._parse_and_validate_audio_input(**kwargs) audio_input = self._parse_and_validate_audio_input(**kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment