Unverified commit 8d9b6721, authored by Cyrus Leung, committed by GitHub
Browse files

[VLM] Abstract out multi-modal data parsing in merged processor (#11620)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent b12e87f9
...@@ -356,7 +356,7 @@ steps: ...@@ -356,7 +356,7 @@ steps:
- pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model' - pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model'
- pytest -v -s models/embedding/language -m 'not core_model' - pytest -v -s models/embedding/language -m 'not core_model'
- label: Multi-Modal Models Test (Standard) # 28min - label: Multi-Modal Models Test (Standard) # 40min
#mirror_hardwares: [amd] #mirror_hardwares: [amd]
source_file_dependencies: source_file_dependencies:
- vllm/ - vllm/
...@@ -372,7 +372,7 @@ steps: ...@@ -372,7 +372,7 @@ steps:
- pytest -v -s models/encoder_decoder/language -m core_model - pytest -v -s models/encoder_decoder/language -m core_model
- pytest -v -s models/encoder_decoder/vision_language -m core_model - pytest -v -s models/encoder_decoder/vision_language -m core_model
- label: Multi-Modal Models Test (Extended) 1 # 1h16m - label: Multi-Modal Models Test (Extended) 1 # 48m
optional: true optional: true
source_file_dependencies: source_file_dependencies:
- vllm/ - vllm/
......
...@@ -33,7 +33,7 @@ from vllm.model_executor.models.glm4_vision_encoder import EVA2CLIPModel ...@@ -33,7 +33,7 @@ from vllm.model_executor.models.glm4_vision_encoder import EVA2CLIPModel
from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalData, MultiModalKwargs, from vllm.multimodal.inputs import (ModalityData, MultiModalKwargs,
NestedTensors) NestedTensors)
from vllm.multimodal.utils import cached_get_tokenizer from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
...@@ -54,7 +54,7 @@ def calculate_image_placeholder(vision_config): ...@@ -54,7 +54,7 @@ def calculate_image_placeholder(vision_config):
def mm_input_mapper_for_glmv( def mm_input_mapper_for_glmv(
ctx: InputContext, ctx: InputContext,
data: MultiModalData[object], data: ModalityData[object],
) -> Dict: ) -> Dict:
model_config = ctx.model_config model_config = ctx.model_config
tokenizer = cached_get_tokenizer( tokenizer = cached_get_tokenizer(
......
...@@ -20,11 +20,13 @@ from vllm.model_executor.layers.quantization import QuantizationConfig ...@@ -20,11 +20,13 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalDataItems, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalFieldConfig, MultiModalInputsV2, MultiModalInputsV2, MultiModalKwargs,
MultiModalKwargs, NestedTensors) NestedTensors)
from vllm.multimodal.parse import ImageProcessorItems
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
ProcessorInputs, PromptReplacement, MultiModalDataItems, ProcessorInputs,
PromptReplacement,
full_groupby_modality) full_groupby_modality)
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -179,7 +181,9 @@ class LlavaMultiModalProcessor(BaseMultiModalProcessor): ...@@ -179,7 +181,9 @@ class LlavaMultiModalProcessor(BaseMultiModalProcessor):
assert isinstance(vision_config, PixtralVisionConfig) assert isinstance(vision_config, PixtralVisionConfig)
def get_replacement_pixtral(item_idx: int): def get_replacement_pixtral(item_idx: int):
image_size = mm_items.get_image_size(item_idx) images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx)
( (
num_width_tokens, num_width_tokens,
num_height_tokens, num_height_tokens,
...@@ -591,8 +595,8 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): ...@@ -591,8 +595,8 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs) result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
mm_items = self._get_mm_items(mm_data) mm_items = self._to_mm_items(mm_data)
mm_item_counts = mm_items.get_item_counts() mm_item_counts = mm_items.get_all_counts()
mm_kwargs = result["mm_kwargs"] mm_kwargs = result["mm_kwargs"]
# We reimplement the functionality of MLlavaProcessor from # We reimplement the functionality of MLlavaProcessor from
......
...@@ -32,12 +32,13 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -32,12 +32,13 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalDataItems, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalFieldConfig, MultiModalInputsV2, MultiModalInputsV2, MultiModalKwargs,
MultiModalKwargs, NestedTensors, NestedTensors, PlaceholderRange)
PlaceholderRange) from vllm.multimodal.parse import ImageProcessorItems
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
ProcessorInputs, PromptReplacement, MultiModalDataItems, ProcessorInputs,
PromptReplacement,
_BoundPromptReplacement, _BoundPromptReplacement,
_PlaceholderInfo) _PlaceholderInfo)
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -381,7 +382,9 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor): ...@@ -381,7 +382,9 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
assert isinstance(bos_token_id, int) assert isinstance(bos_token_id, int)
def get_replacement_phi3v(item_idx: int): def get_replacement_phi3v(item_idx: int):
image_size = mm_items.get_image_size(item_idx) images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx)
num_tokens = image_processor.calc_num_image_tokens_from_image_size( num_tokens = image_processor.calc_num_image_tokens_from_image_size(
width=image_size.width, width=image_size.width,
height=image_size.height, height=image_size.height,
...@@ -389,12 +392,14 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor): ...@@ -389,12 +392,14 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
return [_IMAGE_TOKEN_ID] * num_tokens + [bos_token_id] return [_IMAGE_TOKEN_ID] * num_tokens + [bos_token_id]
num_images = mm_items.get_count("image", strict=False)
return [ return [
PromptReplacement( PromptReplacement(
modality="image", modality="image",
target=image_token, target=image_token,
replacement=get_replacement_phi3v, replacement=get_replacement_phi3v,
) for image_token in image_tokens[:len(mm_items.images)] ) for image_token in image_tokens[:num_images]
] ]
def _apply_prompt_replacements( def _apply_prompt_replacements(
......
...@@ -20,8 +20,8 @@ ...@@ -20,8 +20,8 @@
# limitations under the License. # limitations under the License.
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights.""" """Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
from functools import cached_property from functools import cached_property
from typing import (Any, Iterable, List, Mapping, Optional, Set, Tuple, from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
TypedDict, Union) Union)
import numpy as np import numpy as np
import torch import torch
...@@ -38,10 +38,12 @@ from vllm.inputs import InputContext ...@@ -38,10 +38,12 @@ from vllm.inputs import InputContext
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataItems, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
MultiModalKwargs, NestedTensors) NestedTensors)
from vllm.multimodal.parse import MultiModalDataParser
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
ProcessorInputs, PromptReplacement) MultiModalDataItems, ProcessorInputs,
PromptReplacement)
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import SupportsMultiModal, SupportsPP
...@@ -99,15 +101,9 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor): ...@@ -99,15 +101,9 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor):
def _get_feature_extractor(self) -> WhisperFeatureExtractor: def _get_feature_extractor(self) -> WhisperFeatureExtractor:
return self._get_hf_processor().feature_extractor # type: ignore return self._get_hf_processor().feature_extractor # type: ignore
def _get_hf_mm_data( def _get_data_parser(self) -> MultiModalDataParser:
self,
mm_items: MultiModalDataItems,
) -> tuple[dict[str, Any], dict[str, Any]]:
# resample audio to the model's sampling rate
feature_extractor = self._get_feature_extractor() feature_extractor = self._get_feature_extractor()
mm_items.resample_audios(feature_extractor.sampling_rate) return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)
return super()._get_hf_mm_data(mm_items)
def _call_hf_processor( def _call_hf_processor(
self, self,
......
...@@ -25,7 +25,6 @@ from functools import cached_property, partial ...@@ -25,7 +25,6 @@ from functools import cached_property, partial
from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional, from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional,
Set, Tuple, Type, TypedDict, Union) Set, Tuple, Type, TypedDict, Union)
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
...@@ -55,15 +54,16 @@ from vllm.model_executor.layers.quantization.gptq_marlin import ( ...@@ -55,15 +54,16 @@ from vllm.model_executor.layers.quantization.gptq_marlin import (
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalDataItems, from vllm.multimodal.inputs import (ImageItem, ModalityData,
MultiModalFieldConfig, MultiModalKwargs, MultiModalFieldConfig, MultiModalKwargs,
NestedTensors) NestedTensors, VideoItem)
from vllm.multimodal.parse import ModalityDataItems, MultiModalDataParser
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
ProcessorInputs, PromptReplacement) MultiModalDataItems, ProcessorInputs,
PromptReplacement)
from vllm.platforms import _Backend from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import uses_mrope from vllm.transformers_utils.config import uses_mrope
from vllm.utils import is_list_of
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, get_vit_attn_backend, from .utils import (AutoWeightsLoader, WeightsMapper, get_vit_attn_backend,
...@@ -719,61 +719,81 @@ get_max_qwen2_vl_video_tokens = partial(get_max_qwen2_vl_mm_tokens, ...@@ -719,61 +719,81 @@ get_max_qwen2_vl_video_tokens = partial(get_max_qwen2_vl_mm_tokens,
data_type_key="video") data_type_key="video")
class Qwen2VLMultiModalDataItems(MultiModalDataItems): class Qwen2EmbeddingItems(ModalityDataItems[dict[str, torch.Tensor],
dict[str, torch.Tensor]]):
@staticmethod def __init__(self, data: dict, modality: str) -> None:
def from_dict(data: MultiModalDataDict) -> "MultiModalDataItems": super().__init__(data)
"""
Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`.
"""
multi_data = Qwen2VLMultiModalDataItems()
for k, v in data.items():
# TODO: Make a separate modality for embedding inputs
# to avoid confusion
# yapf: disable
if k == "video":
# Special case since even a single item can be a list
multi_data[k] = ( # type: ignore[index]
v if (
isinstance(v, (dict, torch.Tensor)) # type: ignore[assignment]
or is_list_of(v, list)
or isinstance(v[0], (np.ndarray, torch.Tensor))
and v[0].ndim == 4
) else [v]
)
elif k in ("image", "audio"):
multi_data[k] = ( # type: ignore[index]
v if isinstance(v, (dict, torch.Tensor, list)) else [v]
)
else:
multi_data[k] = v if isinstance(v, list) else [v] # type: ignore[index]
# yapf: enable
return multi_data self.modality = modality
def get_item_counts(self) -> Mapping[str, int]: grid_thw = data[f"{modality}_grid_thw"]
return { slice_idxs = [0] + grid_thw.prod(-1).cumsum_(0).tolist()
m: ( self._slices = [
len(items[f"{m}_grid_thw"]) # type: ignore slice(slice_idxs[i], slice_idxs[i + 1])
if isinstance(items, dict) else len(items)) for i in range(len(grid_thw))
for m, items in self.items() ]
}
def has_embedding_inputs(self) -> bool: def __repr__(self) -> str:
return any( return (f"{type(self).__name__}(modality={self.modality!r})")
isinstance(items, dict) or any(
isinstance(item, torch.Tensor) for item in items)
for items in self.values())
def get_count(self) -> int:
return len(self.data[f"{self.modality}_grid_thw"])
class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor): def get(self, index: int) -> dict[str, torch.Tensor]:
out = {}
for k, v in self.data.items():
if v != f"{self.modality}_grid_thw":
v = v[self._slices[index]]
out[k] = v
return out
def get_processor_data(self) -> Mapping[str, object]:
return {}
def get_passthrough_data(self) -> Mapping[str, object]:
return self.data
class Qwen2ImageEmbeddingItems(Qwen2EmbeddingItems):
def __init__(self, data: dict) -> None:
super().__init__(data, "image")
class Qwen2VideoEmbeddingItems(Qwen2EmbeddingItems):
def _get_mm_items( def __init__(self, data: dict) -> None:
super().__init__(data, "video")
class Qwen2MultiModalDataParser(MultiModalDataParser):
def _parse_image_data(
self,
data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
) -> ModalityDataItems[Any, Any]:
if isinstance(data, dict):
return Qwen2EmbeddingItems(data, modality="image")
return super()._parse_image_data(data)
def _parse_video_data(
self, self,
mm_data: MultiModalDataDict, data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]],
) -> MultiModalDataItems: ) -> ModalityDataItems[Any, Any]:
return Qwen2VLMultiModalDataItems.from_dict(mm_data) if isinstance(data, dict):
return Qwen2EmbeddingItems(data, modality="video")
return super()._parse_video_data(data)
class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
def _get_data_parser(self) -> MultiModalDataParser:
return Qwen2MultiModalDataParser()
def _get_hf_processor( def _get_hf_processor(
self, self,
...@@ -796,35 +816,6 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor): ...@@ -796,35 +816,6 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor):
return hf_processor return hf_processor
def _get_hf_mm_data(
self,
mm_items: MultiModalDataItems,
) -> tuple[dict[str, Any], dict[str, Any]]:
processor_data = dict[str, Any]()
passthrough_data = dict[str, Any]()
for k, v in mm_items.items():
# TODO: Make a separate modality for embedding inputs
# to avoid confusion
if k in ("image", "video", "audio"):
if isinstance(v, dict):
# Pass through embedding inputs (dict)
passthrough_data.update(v)
elif isinstance(v, torch.Tensor) and v.ndim == 3:
# Pass through embedding inputs (single)
passthrough_data[f"{k}_embeds"] = [v]
elif (is_list_of(v, torch.Tensor) and len(v) > 0
and v[0].ndim == 2):
# Pass through embedding inputs (multi)
passthrough_data[f"{k}_embeds"] = v
elif len(v) > 0:
# Map keys to plural form, e.g.: image -> images
processor_data[f"{k}s"] = v
else:
processor_data[k] = v
return processor_data, passthrough_data
def _get_prompt_replacements( def _get_prompt_replacements(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
......
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
import math import math
from functools import cached_property, lru_cache from functools import cached_property, lru_cache
from typing import (Any, Iterable, List, Literal, Mapping, Optional, Set, from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
Tuple, TypedDict, Union) TypedDict, Union)
import numpy as np import numpy as np
import torch import torch
...@@ -24,10 +24,12 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler ...@@ -24,10 +24,12 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.loader import DefaultModelLoader from vllm.model_executor.model_loader.loader import DefaultModelLoader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataItems, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
MultiModalKwargs, NestedTensors) NestedTensors)
from vllm.multimodal.parse import MultiModalDataParser
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
ProcessorInputs, PromptReplacement) MultiModalDataItems, ProcessorInputs,
PromptReplacement)
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.ultravox import UltravoxConfig from vllm.transformers_utils.configs.ultravox import UltravoxConfig
from vllm.utils import is_list_of from vllm.utils import is_list_of
...@@ -85,15 +87,9 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor): ...@@ -85,15 +87,9 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor):
hf_processor = self._get_hf_processor() hf_processor = self._get_hf_processor()
return hf_processor.audio_processor.feature_extractor # type: ignore return hf_processor.audio_processor.feature_extractor # type: ignore
def _get_hf_mm_data( def _get_data_parser(self) -> MultiModalDataParser:
self,
mm_items: MultiModalDataItems,
) -> tuple[dict[str, Any], dict[str, Any]]:
# resample audio to the model's sampling rate
feature_extractor = self._get_feature_extractor() feature_extractor = self._get_feature_extractor()
mm_items.resample_audios(feature_extractor.sampling_rate) return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)
return super()._get_hf_mm_data(mm_items)
def _call_hf_processor( def _call_hf_processor(
self, self,
......
from .base import MultiModalPlaceholderMap, MultiModalPlugin from .base import MultiModalPlaceholderMap, MultiModalPlugin
from .inputs import (BatchedTensorInputs, MultiModalData, from .inputs import (BatchedTensorInputs, ModalityData, MultiModalDataBuiltins,
MultiModalDataBuiltins, MultiModalDataDict, MultiModalDataDict, MultiModalKwargs,
MultiModalKwargs, MultiModalPlaceholderDict, MultiModalPlaceholderDict, NestedTensors)
NestedTensors)
from .registry import MultiModalRegistry from .registry import MultiModalRegistry
MULTIMODAL_REGISTRY = MultiModalRegistry() MULTIMODAL_REGISTRY = MultiModalRegistry()
...@@ -16,7 +15,7 @@ See also: ...@@ -16,7 +15,7 @@ See also:
__all__ = [ __all__ = [
"BatchedTensorInputs", "BatchedTensorInputs",
"MultiModalData", "ModalityData",
"MultiModalDataBuiltins", "MultiModalDataBuiltins",
"MultiModalDataDict", "MultiModalDataDict",
"MultiModalKwargs", "MultiModalKwargs",
......
...@@ -9,7 +9,7 @@ from vllm.inputs.registry import InputContext ...@@ -9,7 +9,7 @@ from vllm.inputs.registry import InputContext
from vllm.utils import PlaceholderModule from vllm.utils import PlaceholderModule
from .base import MediaIO, MultiModalPlugin from .base import MediaIO, MultiModalPlugin
from .inputs import AudioItem, MultiModalData, MultiModalKwargs from .inputs import AudioItem, ModalityData, MultiModalKwargs
try: try:
import librosa import librosa
...@@ -31,7 +31,7 @@ class AudioPlugin(MultiModalPlugin): ...@@ -31,7 +31,7 @@ class AudioPlugin(MultiModalPlugin):
def _default_input_mapper( def _default_input_mapper(
self, self,
ctx: InputContext, ctx: InputContext,
data: MultiModalData[AudioItem], data: ModalityData[AudioItem],
**mm_processor_kwargs, **mm_processor_kwargs,
) -> MultiModalKwargs: ) -> MultiModalKwargs:
raise NotImplementedError("There is no default audio input mapper") raise NotImplementedError("There is no default audio input mapper")
......
...@@ -15,12 +15,12 @@ if TYPE_CHECKING: ...@@ -15,12 +15,12 @@ if TYPE_CHECKING:
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.sequence import SequenceGroupMetadata from vllm.sequence import SequenceGroupMetadata
from .inputs import (MultiModalData, MultiModalDataDict, MultiModalKwargs, from .inputs import (ModalityData, MultiModalDataDict, MultiModalKwargs,
PlaceholderRange) PlaceholderRange)
logger = init_logger(__name__) logger = init_logger(__name__)
MultiModalInputMapper = Callable[[InputContext, MultiModalData[object]], MultiModalInputMapper = Callable[[InputContext, ModalityData[object]],
MultiModalKwargs] MultiModalKwargs]
""" """
Return a dictionary to be passed as keyword arguments to Return a dictionary to be passed as keyword arguments to
...@@ -69,7 +69,7 @@ class MultiModalPlugin(ABC): ...@@ -69,7 +69,7 @@ class MultiModalPlugin(ABC):
def _default_input_mapper( def _default_input_mapper(
self, self,
ctx: InputContext, ctx: InputContext,
data: MultiModalData[Any], data: ModalityData[Any],
**mm_processor_kwargs, **mm_processor_kwargs,
) -> MultiModalKwargs: ) -> MultiModalKwargs:
""" """
...@@ -118,7 +118,7 @@ class MultiModalPlugin(ABC): ...@@ -118,7 +118,7 @@ class MultiModalPlugin(ABC):
def map_input( def map_input(
self, self,
model_config: "ModelConfig", model_config: "ModelConfig",
data: MultiModalData[Any], data: ModalityData[Any],
mm_processor_kwargs: Optional[dict[str, Any]], mm_processor_kwargs: Optional[dict[str, Any]],
) -> MultiModalKwargs: ) -> MultiModalKwargs:
""" """
......
...@@ -13,7 +13,7 @@ from vllm.transformers_utils.processor import get_image_processor ...@@ -13,7 +13,7 @@ from vllm.transformers_utils.processor import get_image_processor
from vllm.utils import is_list_of from vllm.utils import is_list_of
from .base import MediaIO, MultiModalPlugin from .base import MediaIO, MultiModalPlugin
from .inputs import ImageItem, MultiModalData, MultiModalKwargs from .inputs import ImageItem, ModalityData, MultiModalKwargs
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import ModelConfig from vllm.config import ModelConfig
...@@ -44,7 +44,7 @@ class ImagePlugin(MultiModalPlugin): ...@@ -44,7 +44,7 @@ class ImagePlugin(MultiModalPlugin):
def _default_input_mapper( def _default_input_mapper(
self, self,
ctx: InputContext, ctx: InputContext,
data: MultiModalData[ImageItem], data: ModalityData[ImageItem],
**mm_processor_kwargs, **mm_processor_kwargs,
) -> MultiModalKwargs: ) -> MultiModalKwargs:
model_config = ctx.model_config model_config = ctx.model_config
......
...@@ -2,53 +2,74 @@ from abc import ABC, abstractmethod ...@@ -2,53 +2,74 @@ from abc import ABC, abstractmethod
from collections import UserDict, defaultdict from collections import UserDict, defaultdict
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from dataclasses import dataclass from dataclasses import dataclass
from typing import (Any, Literal, NamedTuple, TypedDict, TypeVar, Union, cast, from typing import Any, Literal, TypedDict, TypeVar, Union, cast, final
final)
import numpy as np import numpy as np
import torch import torch
import torch.types import torch.types
from PIL.Image import Image from PIL.Image import Image
from transformers import BatchFeature from transformers import BatchFeature
from typing_extensions import NotRequired, TypeAlias, assert_never from typing_extensions import NotRequired, TypeAlias
from vllm.utils import JSONTree, is_list_of, json_map_leaves from vllm.utils import JSONTree, is_list_of, json_map_leaves
_T = TypeVar("_T") _T = TypeVar("_T")
# yapf: disable HfImageItem: TypeAlias = Union[Image, np.ndarray, torch.Tensor]
ImageItem: TypeAlias = Union[Image, np.ndarray, torch.Tensor]
""" """
A :class:`transformers.image_utils.ImageInput` representing a single image A :class:`transformers.image_utils.ImageInput` representing a single image
item, which can be passed to a HuggingFace :code:`ImageProcessor`. item, which can be passed to a HuggingFace :code:`ImageProcessor`.
""" """
VideoItem: TypeAlias = Union[ HfVideoItem: TypeAlias = Union[list[Image], np.ndarray, torch.Tensor,
list[Image], list[np.ndarray], list[torch.Tensor]]
np.ndarray,
torch.Tensor,
list[np.ndarray],
list[torch.Tensor],
]
""" """
A :class:`transformers.image_utils.VideoInput` representing a single video A :class:`transformers.image_utils.VideoInput` representing a single video
item, which can be passed to a HuggingFace :code:`VideoProcessor`. item, which can be passed to a HuggingFace :code:`VideoProcessor`.
""" """
AudioItem: TypeAlias = Union[ HfAudioItem: TypeAlias = Union[list[float], np.ndarray, torch.Tensor]
np.ndarray,
list[float],
# `(audio, sampling_rate)`: If the audio's sampling rate is different
# from that expected by the model, we need to resample it.
tuple[np.ndarray, float],
]
""" """
Represents a single audio Represents a single audio
item, which can be passed to a HuggingFace :code:`AudioProcessor`. item, which can be passed to a HuggingFace :code:`AudioProcessor`.
""" """
# yapf: enable
MultiModalData: TypeAlias = Union[_T, list[_T]] ImageItem: TypeAlias = Union[HfImageItem, torch.Tensor]
"""
A :class:`transformers.image_utils.ImageInput` representing a single image
item, which can be passed to a HuggingFace :code:`ImageProcessor`.
Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as image embeddings;
these are directly passed to the model without HF processing.
"""
VideoItem: TypeAlias = Union[HfVideoItem, torch.Tensor]
"""
A :class:`transformers.image_utils.VideoInput` representing a single video
item, which can be passed to a HuggingFace :code:`VideoProcessor`.
Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as video embeddings;
these are directly passed to the model without HF processing.
"""
AudioItem: TypeAlias = Union[HfAudioItem, tuple[np.ndarray, float],
torch.Tensor]
"""
Represents a single audio
item, which can be passed to a HuggingFace :code:`AudioProcessor`.
Alternatively, a tuple `(audio, sampling_rate)`, where the sampling rate
is different from that expected by the model;
these are resampled to the model's sampling rate before being processed by HF.
Alternatively, a 3-D tensor or batch of 2-D tensors,
which are treated as audio embeddings;
these are directly passed to the model without HF processing.
"""
ModalityData: TypeAlias = Union[_T, list[_T]]
""" """
Either a single data item, or a list of data items. Either a single data item, or a list of data items.
...@@ -61,17 +82,17 @@ The number of data items allowed per modality is restricted by ...@@ -61,17 +82,17 @@ The number of data items allowed per modality is restricted by
class MultiModalDataBuiltins(TypedDict, total=False): class MultiModalDataBuiltins(TypedDict, total=False):
"""Type annotations for modality types predefined by vLLM.""" """Type annotations for modality types predefined by vLLM."""
image: MultiModalData[ImageItem] image: ModalityData[ImageItem]
"""The input image(s).""" """The input image(s)."""
video: MultiModalData[VideoItem] video: ModalityData[VideoItem]
"""The input video(s).""" """The input video(s)."""
audio: MultiModalData[AudioItem] audio: ModalityData[AudioItem]
"""The input audio(s).""" """The input audio(s)."""
MultiModalDataDict: TypeAlias = Mapping[str, MultiModalData[Any]] MultiModalDataDict: TypeAlias = Mapping[str, ModalityData[Any]]
""" """
A dictionary containing an entry for each modality type to input. A dictionary containing an entry for each modality type to input.
...@@ -83,123 +104,6 @@ Note: ...@@ -83,123 +104,6 @@ Note:
""" """
class ImageSize(NamedTuple):
width: int
height: int
class MultiModalDataItems(UserDict[str, list[Any]]):
    """
    Normalized form of :class:`MultiModalDataDict` in which every modality
    maps to a list-like collection of items.
    """

    @staticmethod
    def _is_video_collection(value: Any) -> bool:
        """
        Tell whether a ``"video"`` entry already holds several videos.

        A single video may itself be a list, so the check cannot rely on
        ``isinstance(value, list)`` alone. The tests run in this exact
        order because the last one indexes ``value[0]``.
        """
        if isinstance(value, torch.Tensor):
            return True
        if is_list_of(value, list):
            return True
        first = value[0]
        return isinstance(first, (np.ndarray, torch.Tensor)) and first.ndim == 4

    @staticmethod
    def from_dict(data: MultiModalDataDict) -> "MultiModalDataItems":
        """
        Build a :class:`MultiModalDataItems` from a
        :class:`MultiModalDataDict`, wrapping lone items in a list.
        """
        normalized = MultiModalDataItems()

        for modality, value in data.items():
            # TODO: Make a separate modality for embedding inputs
            # to avoid confusion
            if modality == "video":
                is_collection = MultiModalDataItems._is_video_collection(value)
            elif modality in ("image", "audio"):
                is_collection = isinstance(value, (torch.Tensor, list))
            else:
                is_collection = isinstance(value, list)

            normalized[modality] = (  # type: ignore[index]
                value if is_collection else [value])

        return normalized

    # NOTE: When a field (e.g. `images`) doesn't exist, directly appending to
    # `self.images` doesn't update this dictionary, which may be confusing
    # We annotate the getter methods as `Sequence` to prevent others from
    # trying to update the list in this way
    @property
    def images(self) -> Sequence[ImageItem]:
        """Items stored under ``"image"``, or an empty list if absent."""
        return self.get("image", [])

    @property
    def videos(self) -> Sequence[VideoItem]:
        """Items stored under ``"video"``, or an empty list if absent."""
        return self.get("video", [])

    @property
    def audios(self) -> Sequence[AudioItem]:
        """Items stored under ``"audio"``, or an empty list if absent."""
        return self.get("audio", [])

    def get_item_counts(self) -> Mapping[str, int]:
        """Return the number of items held for each modality."""
        counts = {}
        for modality, entries in self.items():
            counts[modality] = len(entries)
        return counts

    def has_embedding_inputs(self) -> bool:
        """Return whether any item of any modality is a ``torch.Tensor``."""
        for entries in self.values():
            for entry in entries:
                if isinstance(entry, torch.Tensor):
                    return True
        return False

    def get_image_size(self, item_idx: int) -> ImageSize:
        """
        Return the size of the image at ``item_idx``.

        Arrays and tensors must have a shape that unpacks into exactly
        three values; the last two are read as height and width.
        """
        picture = self.images[item_idx]

        if isinstance(picture, Image):
            width, height = picture.size
            return ImageSize(width, height)

        if isinstance(picture, (np.ndarray, torch.Tensor)):
            _, height, width = picture.shape
            return ImageSize(width, height)

        assert_never(picture)

    def get_audio_with_sr(
        self,
        item_idx: int,
        *,
        default_sr: float,
    ) -> tuple[np.ndarray, float]:
        """
        Return the audio at ``item_idx`` together with a sampling rate.

        Tuples are returned as-is; lists and arrays carry no rate of their
        own, so ``default_sr`` is paired with them.
        """
        clip = self.audios[item_idx]

        if isinstance(clip, tuple):
            return clip
        if isinstance(clip, list):
            return np.array(clip), default_sr
        if isinstance(clip, np.ndarray):
            return clip, default_sr

        assert_never(clip)

    def resample_audios(self, new_sr: float, *, drop_sr: bool = True) -> None:
        """
        Resample every audio item to ``new_sr`` and store the result back
        under ``"audio"``.

        With :code:`drop_sr=True` the stored items are bare NumPy arrays,
        whose sampling rate is implicitly the model's expected one;
        otherwise they are kept as :code:`(audio, new_sr)` tuples.
        """
        # Avoid circular import
        from .audio import resample_audio

        if not self.audios:
            return

        audio_count = len(self.audios)
        resampled = []
        for position in range(audio_count):
            waveform, source_sr = self.get_audio_with_sr(position,
                                                         default_sr=new_sr)
            waveform = resample_audio(waveform,
                                      orig_sr=source_sr,
                                      target_sr=new_sr)
            if drop_sr:
                resampled.append(waveform)
            else:
                resampled.append((waveform, new_sr))

        self["audio"] = resampled
class PlaceholderRange(TypedDict): class PlaceholderRange(TypedDict):
""" """
Placeholder location information for multi-modal data. Placeholder location information for multi-modal data.
...@@ -436,7 +340,7 @@ class MultiModalKwargs(UserDict[str, NestedTensors]): ...@@ -436,7 +340,7 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
) -> "MultiModalKwargs": ) -> "MultiModalKwargs":
data = { data = {
key: items[0].field.reduce(items).data key: items[0].field.reduce(items).data
for key, items in items_by_key.items() for key, items in items_by_key.items() if len(items) > 0
} }
return MultiModalKwargs(data, return MultiModalKwargs(data,
...@@ -567,6 +471,11 @@ class MultiModalKwargs(UserDict[str, NestedTensors]): ...@@ -567,6 +471,11 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
Get the keyword arguments corresponding to an item identified by Get the keyword arguments corresponding to an item identified by
its modality and index. its modality and index.
""" """
if modality not in self._keys_by_modality:
available_modalities = set(self._keys_by_modality.keys())
raise KeyError(f"Modality {modality!r} not found. "
f"Available modalities: {available_modalities}")
keys_to_gather = self._keys_by_modality[modality] keys_to_gather = self._keys_by_modality[modality]
return { return {
......
from abc import ABC, abstractmethod
from collections import UserDict
from collections.abc import Callable, Iterator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Generic, NamedTuple, Optional, TypeVar
import numpy as np
import torch
from PIL.Image import Image
from typing_extensions import TypeAlias, TypeGuard, assert_never
from vllm.utils import is_list_of
from .audio import resample_audio
from .inputs import (AudioItem, HfAudioItem, HfImageItem, HfVideoItem,
ImageItem, ModalityData, MultiModalDataDict,
NestedTensors, VideoItem)
# In :class:`ModalityDataItems`, the type of the raw ``data`` container;
# :class:`ProcessorBatchItems` reuses it as the type of a single item.
_T = TypeVar("_T")
# The type of a single item returned by :meth:`ModalityDataItems.get`.
_I = TypeVar("_I")
class ModalityDataItems(ABC, Generic[_T, _I]):
    """
    Abstract view over the data items of a single modality.

    ``data`` holds the raw container (type ``_T``); subclasses define how
    to count and index it (items of type ``_I``) and how it is split
    between the HF processor and the model.
    """

    def __init__(self, data: _T) -> None:
        super().__init__()

        self.data = data

    def __len__(self) -> int:
        # Sized protocol simply mirrors `get_count`.
        return self.get_count()

    def __getitem__(self, index: int) -> _I:
        # Indexing simply mirrors `get`.
        return self.get(index)

    if TYPE_CHECKING:
        # Auto-generated
        def __iter__(self) -> Iterator[_I]:
            ...

    @abstractmethod
    def get_count(self) -> int:
        """Return how many data items there are."""
        raise NotImplementedError

    @abstractmethod
    def get(self, index: int) -> _I:
        """Return the data item located at ``index``."""
        raise NotImplementedError

    def get_all(self) -> list[_I]:
        """Return every data item, in index order."""
        return list(map(self.get, range(self.get_count())))

    @abstractmethod
    def get_processor_data(self) -> Mapping[str, object]:
        """Return the data that should be handed to the HF processor."""
        raise NotImplementedError

    @abstractmethod
    def get_passthrough_data(self) -> Mapping[str, object]:
        """Return the data that should go straight to the model."""
        raise NotImplementedError
class ProcessorBatchItems(ModalityDataItems[Sequence[_T], _T]):
    """
    Data items kept as a sequence and forwarded to the HF processor.

    The whole sequence is exposed to the processor under the pluralized
    modality name (``f"{modality}s"``); nothing is passed through to the
    model directly.
    """

    def __init__(self, data: Sequence[_T], modality: str) -> None:
        super().__init__(data)

        self.modality = modality

    def __repr__(self) -> str:
        class_name = type(self).__name__
        return f"{class_name}(modality={self.modality!r})"

    def get_count(self) -> int:
        """Return the length of the underlying sequence."""
        return len(self.data)

    def get(self, index: int) -> _T:
        """Return the element of the underlying sequence at ``index``."""
        return self.data[index]

    def get_processor_data(self) -> Mapping[str, object]:
        """Return the sequence keyed by the plural modality name."""
        processor_key = f"{self.modality}s"
        return {processor_key: self.data}

    def get_passthrough_data(self) -> Mapping[str, object]:
        """Return an empty mapping: nothing bypasses the HF processor."""
        return {}
class EmbeddingItems(ModalityDataItems[NestedTensors, torch.Tensor]):
    """
    Data items that are passed straight to the model.

    The data is exposed to the model under the key
    ``f"{modality}_embeds"``; nothing is handed to the HF processor.
    """

    def __init__(self, data: NestedTensors, modality: str) -> None:
        super().__init__(data)

        self.modality = modality

    def __repr__(self) -> str:
        return (f"{type(self).__name__}(modality={self.modality!r})")

    def get_count(self) -> int:
        """Return the length of the underlying data."""
        return len(self.data)

    # The return annotation matches the item type this class declares in
    # `ModalityDataItems[NestedTensors, torch.Tensor]`; the previous
    # `object` annotation was wider than the base-class contract.
    def get(self, index: int) -> torch.Tensor:
        """Return the entry of the underlying data at ``index``."""
        return self.data[index]

    def get_processor_data(self) -> Mapping[str, object]:
        """Return an empty mapping: nothing goes to the HF processor."""
        return {}

    def get_passthrough_data(self) -> Mapping[str, object]:
        """Return the data keyed by ``f"{modality}_embeds"``."""
        return {f"{self.modality}_embeds": self.data}
class AudioProcessorItems(ProcessorBatchItems[HfAudioItem]):
    """:class:`ProcessorBatchItems` with the modality fixed to ``"audio"``."""

    def __init__(self, data: Sequence[HfAudioItem]) -> None:
        super().__init__(data, modality="audio")
class AudioEmbeddingItems(EmbeddingItems):
    """:class:`EmbeddingItems` with the modality fixed to ``"audio"``."""

    def __init__(self, data: NestedTensors) -> None:
        super().__init__(data, modality="audio")
class ImageSize(NamedTuple):
    """Size of an image, stored as a ``(width, height)`` named tuple."""
    # Horizontal extent of the image.
    width: int
    # Vertical extent of the image.
    height: int
class ImageProcessorItems(ProcessorBatchItems[HfImageItem]):
    """:class:`ProcessorBatchItems` with the modality fixed to ``"image"``."""

    def __init__(self, data: Sequence[HfImageItem]) -> None:
        super().__init__(data, modality="image")

    def get_image_size(self, item_idx: int) -> ImageSize:
        """
        Return the size of the image at ``item_idx``.

        PIL images report their own ``size``. Arrays and tensors must have
        a shape that unpacks into exactly three values; the last two are
        read as height and width.
        """
        item = self.get(item_idx)

        if isinstance(item, Image):
            width, height = item.size
            return ImageSize(width, height)

        if isinstance(item, (np.ndarray, torch.Tensor)):
            _, height, width = item.shape
            return ImageSize(width=width, height=height)

        assert_never(item)
class ImageEmbeddingItems(EmbeddingItems):
    """:class:`EmbeddingItems` with the modality fixed to ``"image"``."""

    def __init__(self, data: NestedTensors) -> None:
        super().__init__(data, modality="image")
class VideoProcessorItems(ProcessorBatchItems[HfVideoItem]):
    """:class:`ProcessorBatchItems` with the modality fixed to ``"video"``."""

    def __init__(self, data: Sequence[HfVideoItem]) -> None:
        super().__init__(data, modality="video")
class VideoEmbeddingItems(EmbeddingItems):
    """:class:`EmbeddingItems` with the modality fixed to ``"video"``."""

    def __init__(self, data: NestedTensors) -> None:
        super().__init__(data, modality="video")
# Any concrete :class:`ModalityDataItems` subtype; lets
# :meth:`MultiModalDataItems.get_items` return the type it was asked for.
_D = TypeVar("_D", bound=ModalityDataItems[Any, Any])
class MultiModalDataItems(UserDict[str, ModalityDataItems[Any, Any]]):
    """
    Parsed counterpart of :class:`MultiModalDataDict`: maps each modality
    name to the :class:`ModalityDataItems` holding its items.
    """

    def get_count(self, modality: str, *, strict: bool = True) -> int:
        """
        Return how many data items the given modality has.

        When the modality is absent, :exc:`KeyError` is raised unless
        `strict=False`, in which case `0` is returned.
        """
        if modality in self:
            return self[modality].get_count()

        if not strict:
            return 0

        available_modalities = set(self.keys())
        raise KeyError(f"Modality {modality!r} not found. "
                       f"Available modalities: {available_modalities}")

    def get_all_counts(self) -> Mapping[str, int]:
        """Return the item count of every modality present."""
        counts = {}
        for name, entries in self.items():
            counts[name] = entries.get_count()
        return counts

    def get_items(
        self,
        modality: str,
        typ: type[_D],
    ) -> _D:
        """
        Return the data items of a modality after checking that they are
        an instance of ``typ``.

        Raises :exc:`KeyError` if the modality is absent and
        :exc:`TypeError` if the stored items have a different type.
        """
        if modality not in self:
            available_modalities = set(self.keys())
            raise KeyError(f"Modality {modality!r} not found. "
                           f"Available modalities: {available_modalities}")

        items = self[modality]
        if isinstance(items, typ):
            return items

        raise TypeError(f"Invalid type of data items for {modality=}. "
                        f"Expected type: {typ}, but "
                        f"found type: {type(items)}")
# Signature of a per-modality subparser: takes the raw data supplied for one
# modality and returns the corresponding :class:`ModalityDataItems`.
ModalityDataParser: TypeAlias = Callable[[ModalityData[Any]],
                                         ModalityDataItems[Any, Any]]
class MultiModalDataParser:
    """
    Converts a :class:`MultiModalDataDict` into
    :class:`MultiModalDataItems`, one subparser per modality.
    """

    def __init__(self, *, target_sr: Optional[float] = None) -> None:
        super().__init__()

        # Sampling rate that audio supplied with its own rate is resampled to.
        self.target_sr = target_sr

    def _is_embeddings(self, data: object) -> TypeGuard[NestedTensors]:
        """
        Decide whether ``data`` should be treated as embeddings: a 3-D
        tensor, or a list of tensors that is empty or starts with a 2-D one.
        """
        if isinstance(data, torch.Tensor):
            return data.ndim == 3
        if not is_list_of(data, torch.Tensor):
            return False

        return len(data) == 0 or data[0].ndim == 2

    def _get_audio_with_sr(
        self,
        audio: AudioItem,
    ) -> tuple[np.ndarray, Optional[float]]:
        """
        Split an audio item into its waveform and sampling rate.

        Tuples are returned unchanged; every other accepted form carries
        no rate, so ``None`` is returned in its place.
        """
        if isinstance(audio, tuple):
            return audio
        if isinstance(audio, list):
            return np.array(audio), None
        if isinstance(audio, np.ndarray):
            return audio, None
        if isinstance(audio, torch.Tensor):
            return audio.numpy(), None

        assert_never(audio)

    def _parse_audio_data(
        self,
        data: ModalityData[AudioItem],
    ) -> ModalityDataItems[Any, Any]:
        """
        Parse the ``"audio"`` entry.

        Embeddings are wrapped as-is. Otherwise the input is normalized to
        a list of items, and every item that comes with a sampling rate is
        resampled to ``target_sr``.
        """
        if self._is_embeddings(data):
            return AudioEmbeddingItems(data)

        is_array = isinstance(data, (np.ndarray, torch.Tensor))
        is_single_item = (is_list_of(data, float)
                          or (is_array and data.ndim == 1)
                          or isinstance(data, tuple))

        if is_single_item:
            data_items = [data]
        elif is_array:
            # Split along the first axis into individual items.
            data_items = list(data)
        else:
            data_items = data

        resampled = list[np.ndarray]()
        for entry in data_items:
            waveform, source_sr = self._get_audio_with_sr(entry)

            if source_sr is not None:
                target_sr = self.target_sr
                if target_sr is None:
                    raise RuntimeError(
                        "Audio resampling is not supported when "
                        "`target_sr` is not provided")

                waveform = resample_audio(waveform,
                                          orig_sr=source_sr,
                                          target_sr=target_sr)

            resampled.append(waveform)

        return AudioProcessorItems(resampled)

    def _parse_image_data(
        self,
        data: ModalityData[ImageItem],
    ) -> ModalityDataItems[Any, Any]:
        """
        Parse the ``"image"`` entry.

        Embeddings are wrapped as-is. A PIL image or a 3-D array counts as
        one item, other arrays are split along the first axis, and
        anything else is used as the item sequence directly.
        """
        if self._is_embeddings(data):
            return ImageEmbeddingItems(data)

        is_array = isinstance(data, (np.ndarray, torch.Tensor))

        if isinstance(data, Image) or (is_array and data.ndim == 3):
            data_items = [data]
        elif is_array:
            data_items = list(data)
        else:
            data_items = data

        return ImageProcessorItems(data_items)

    def _parse_video_data(
        self,
        data: ModalityData[VideoItem],
    ) -> ModalityDataItems[Any, Any]:
        """
        Parse the ``"video"`` entry.

        Embeddings are wrapped as-is. A list of PIL images or a 4-D array
        counts as one item, other arrays are split along the first axis,
        and anything else is used as the item sequence directly.
        """
        if self._is_embeddings(data):
            return VideoEmbeddingItems(data)

        is_array = isinstance(data, (np.ndarray, torch.Tensor))

        if is_list_of(data, Image) or (is_array and data.ndim == 4):
            data_items = [data]
        elif is_array:
            data_items = list(data)
        else:
            data_items = data

        return VideoProcessorItems(data_items)

    def _get_subparsers(self) -> Mapping[str, ModalityDataParser]:
        """Return the subparser to use for each supported modality."""
        return dict(
            audio=self._parse_audio_data,
            image=self._parse_image_data,
            video=self._parse_video_data,
        )

    def parse_mm_data(self,
                      mm_data: MultiModalDataDict) -> MultiModalDataItems:
        """
        Parse every entry of ``mm_data`` with the subparser registered for
        its modality.

        Raises :exc:`ValueError` for a modality without a subparser.
        """
        subparsers = self._get_subparsers()

        parsed = MultiModalDataItems()
        for modality, raw in mm_data.items():
            if modality not in subparsers:
                raise ValueError(f"Unsupported modality: {modality}")

            parse = subparsers[modality]
            parsed[modality] = parse(raw)

        return parsed
...@@ -15,11 +15,12 @@ from transformers import BatchFeature, ProcessorMixin ...@@ -15,11 +15,12 @@ from transformers import BatchFeature, ProcessorMixin
from vllm.inputs import DummyData, InputProcessingContext from vllm.inputs import DummyData, InputProcessingContext
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import LRUCache, flatten_2d_lists, full_groupby, is_list_of from vllm.utils import LRUCache, flatten_2d_lists, full_groupby
from .inputs import (MultiModalDataDict, MultiModalDataItems, from .inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalFieldConfig, MultiModalFieldItem, MultiModalFieldItem, MultiModalInputsV2, MultiModalKwargs,
MultiModalInputsV2, MultiModalKwargs, PlaceholderRange) PlaceholderRange)
from .parse import MultiModalDataItems, MultiModalDataParser
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -621,6 +622,16 @@ class BaseMultiModalProcessor(ABC): ...@@ -621,6 +622,16 @@ class BaseMultiModalProcessor(ABC):
) -> MultiModalInputsV2: ) -> MultiModalInputsV2:
return self.apply(prompt, mm_data, hf_processor_mm_kwargs) return self.apply(prompt, mm_data, hf_processor_mm_kwargs)
def _get_data_parser(self) -> MultiModalDataParser:
    """
    Build the parser that preprocesses multi-modal data items before
    they reach :meth:`_get_hf_mm_data`.

    To handle extra modalities, return a subclass of
    :class:`MultiModalDataParser` that provides additional subparsers.
    """
    parser = MultiModalDataParser()
    return parser
def _get_hf_processor(self) -> ProcessorMixin: def _get_hf_processor(self) -> ProcessorMixin:
""" """
Subclasses can add keyword arguments to this method to accept Subclasses can add keyword arguments to this method to accept
...@@ -631,11 +642,16 @@ class BaseMultiModalProcessor(ABC): ...@@ -631,11 +642,16 @@ class BaseMultiModalProcessor(ABC):
def _get_tokenizer(self) -> AnyTokenizer: def _get_tokenizer(self) -> AnyTokenizer:
return self.ctx.tokenizer return self.ctx.tokenizer
def _get_mm_items( def _to_mm_items(
self, self,
mm_data: MultiModalDataDict, mm_data: MultiModalDataDict,
) -> MultiModalDataItems: ) -> MultiModalDataItems:
return MultiModalDataItems.from_dict(mm_data) """
Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`
before passing them to :meth:`_get_hf_mm_data`.
"""
parser = self._get_data_parser()
return parser.parse_mm_data(mm_data)
@abstractmethod @abstractmethod
def _get_mm_fields_config( def _get_mm_fields_config(
...@@ -680,22 +696,9 @@ class BaseMultiModalProcessor(ABC): ...@@ -680,22 +696,9 @@ class BaseMultiModalProcessor(ABC):
processor_data = dict[str, Any]() processor_data = dict[str, Any]()
passthrough_data = dict[str, Any]() passthrough_data = dict[str, Any]()
for k, v in mm_items.items(): for items in mm_items.values():
# TODO: Make a separate modality for embedding inputs processor_data.update(items.get_processor_data())
# to avoid confusion passthrough_data.update(items.get_passthrough_data())
if k in ("image", "video", "audio"):
if isinstance(v, torch.Tensor) and v.ndim == 3:
# Pass through embedding inputs (single)
passthrough_data[f"{k}_embeds"] = [v]
elif (is_list_of(v, torch.Tensor) and len(v) > 0
and v[0].ndim == 2):
# Pass through embedding inputs (multi)
passthrough_data[f"{k}_embeds"] = v
elif len(v) > 0:
# Map keys to plural form, e.g.: image -> images
processor_data[f"{k}s"] = v
else:
processor_data[k] = v
return processor_data, passthrough_data return processor_data, passthrough_data
...@@ -756,7 +759,7 @@ class BaseMultiModalProcessor(ABC): ...@@ -756,7 +759,7 @@ class BaseMultiModalProcessor(ABC):
cached items; instead, we rely on our own prompt replacement logic cached items; instead, we rely on our own prompt replacement logic
for the full text. for the full text.
""" """
mm_missing_counts = mm_missing_data_items.get_item_counts() mm_missing_counts = mm_missing_data_items.get_all_counts()
prompt_ids, _ = self._apply_hf_processor( prompt_ids, _ = self._apply_hf_processor(
prompt_text=prompt_text, prompt_text=prompt_text,
...@@ -789,7 +792,8 @@ class BaseMultiModalProcessor(ABC): ...@@ -789,7 +792,8 @@ class BaseMultiModalProcessor(ABC):
cache = self.cache cache = self.cache
model_id = self.ctx.model_config.model model_id = self.ctx.model_config.model
if cache is None or mm_data_items.has_embedding_inputs(): _, passthrough_data = self._get_hf_mm_data(mm_data_items)
if cache is None or passthrough_data:
return self._apply_hf_processor( return self._apply_hf_processor(
prompt_text=prompt_text, prompt_text=prompt_text,
mm_items=mm_data_items, mm_items=mm_data_items,
...@@ -812,7 +816,7 @@ class BaseMultiModalProcessor(ABC): ...@@ -812,7 +816,7 @@ class BaseMultiModalProcessor(ABC):
modality: [mm_data_items[modality][idx] for idx in idxs] modality: [mm_data_items[modality][idx] for idx in idxs]
for modality, idxs in mm_missing_idxs.items() for modality, idxs in mm_missing_idxs.items()
} }
mm_missing_data_items = self._get_mm_items(mm_missing_data) mm_missing_data_items = self._to_mm_items(mm_missing_data)
prompt_ids, mm_missing_kwargs = self._apply_hf_processor_missing( prompt_ids, mm_missing_kwargs = self._apply_hf_processor_missing(
prompt_text=prompt_text, prompt_text=prompt_text,
...@@ -852,7 +856,7 @@ class BaseMultiModalProcessor(ABC): ...@@ -852,7 +856,7 @@ class BaseMultiModalProcessor(ABC):
mm_merged_field_items[modality] = merged_modal_items_lst mm_merged_field_items[modality] = merged_modal_items_lst
if self.enable_sanity_checks: if self.enable_sanity_checks:
mm_missing_counts = mm_missing_data_items.get_item_counts() mm_missing_counts = mm_missing_data_items.get_all_counts()
assert all( assert all(
item_count == mm_missing_counts[modality] item_count == mm_missing_counts[modality]
for modality, item_count in mm_missing_next_idx.items()), dict( for modality, item_count in mm_missing_next_idx.items()), dict(
...@@ -865,7 +869,7 @@ class BaseMultiModalProcessor(ABC): ...@@ -865,7 +869,7 @@ class BaseMultiModalProcessor(ABC):
) )
if self.enable_sanity_checks: if self.enable_sanity_checks:
mm_item_counts = mm_data_items.get_item_counts() mm_item_counts = mm_data_items.get_all_counts()
for modality, item_count in mm_item_counts.items(): for modality, item_count in mm_item_counts.items():
for item_idx in range(item_count): for item_idx in range(item_count):
...@@ -958,7 +962,7 @@ class BaseMultiModalProcessor(ABC): ...@@ -958,7 +962,7 @@ class BaseMultiModalProcessor(ABC):
3. Extract information about the placeholder tokens from the 3. Extract information about the placeholder tokens from the
processed token IDs. processed token IDs.
""" """
mm_items = self._get_mm_items(mm_data) mm_items = self._to_mm_items(mm_data)
prompt_ids, mm_kwargs = self._cached_apply_hf_processor( prompt_ids, mm_kwargs = self._cached_apply_hf_processor(
prompt_text, prompt_text,
...@@ -975,7 +979,7 @@ class BaseMultiModalProcessor(ABC): ...@@ -975,7 +979,7 @@ class BaseMultiModalProcessor(ABC):
# If HF processor already inserts placeholder tokens, # If HF processor already inserts placeholder tokens,
# there is no need for us to insert them # there is no need for us to insert them
mm_item_counts = mm_items.get_item_counts() mm_item_counts = mm_items.get_all_counts()
all_placeholders = self._find_placeholders(prompt_repls, prompt_ids, all_placeholders = self._find_placeholders(prompt_repls, prompt_ids,
mm_item_counts) mm_item_counts)
......
...@@ -15,7 +15,7 @@ from vllm.transformers_utils.processor import get_video_processor ...@@ -15,7 +15,7 @@ from vllm.transformers_utils.processor import get_video_processor
from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import PlaceholderModule, is_list_of from vllm.utils import PlaceholderModule, is_list_of
from .base import MediaIO, MultiModalData from .base import MediaIO, ModalityData
from .image import ImageMediaIO, ImagePlugin from .image import ImageMediaIO, ImagePlugin
from .inputs import MultiModalKwargs, VideoItem from .inputs import MultiModalKwargs, VideoItem
...@@ -54,7 +54,7 @@ class VideoPlugin(ImagePlugin): ...@@ -54,7 +54,7 @@ class VideoPlugin(ImagePlugin):
def _default_input_mapper( def _default_input_mapper(
self, self,
ctx: InputContext, ctx: InputContext,
data: MultiModalData[VideoItem], data: ModalityData[VideoItem],
**mm_processor_kwargs, **mm_processor_kwargs,
) -> MultiModalKwargs: ) -> MultiModalKwargs:
model_config = ctx.model_config model_config = ctx.model_config
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment