Unverified Commit a0d8d944 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Renderer] Move MM Hash parsing into Renderer (#34711)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent df3f537a
...@@ -6,6 +6,7 @@ import pytest ...@@ -6,6 +6,7 @@ import pytest
from vllm.assets.image import ImageAsset from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset from vllm.assets.video import VideoAsset
from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.multimodal.parse import parse_mm_uuids
from vllm.renderers.hf import HfRenderer from vllm.renderers.hf import HfRenderer
from vllm.tokenizers.registry import tokenizer_args_from_config from vllm.tokenizers.registry import tokenizer_args_from_config
...@@ -45,10 +46,11 @@ def test_multi_modal_uuids_length_mismatch_raises(): ...@@ -45,10 +46,11 @@ def test_multi_modal_uuids_length_mismatch_raises():
mm_uuids = {"image": ["hash_cherry"]} mm_uuids = {"image": ["hash_cherry"]}
mm_processor = renderer.get_mm_processor() mm_processor = renderer.get_mm_processor()
mm_items = mm_processor.info.parse_mm_data(mm_data) mm_data_items = mm_processor.info.parse_mm_data(mm_data)
mm_uuid_items = parse_mm_uuids(mm_uuids)
with pytest.raises(ValueError, match="must have same length as"): with pytest.raises(ValueError, match="must have same length as"):
renderer._process_mm_uuids(mm_data, mm_items, mm_uuids, "req-1") renderer._process_mm_uuids(mm_data, mm_data_items, mm_uuid_items, "req-1")
def test_multi_modal_uuids_missing_modality_raises(): def test_multi_modal_uuids_missing_modality_raises():
...@@ -63,10 +65,11 @@ def test_multi_modal_uuids_missing_modality_raises(): ...@@ -63,10 +65,11 @@ def test_multi_modal_uuids_missing_modality_raises():
mm_uuids = {"image": ["hash_cherry"]} mm_uuids = {"image": ["hash_cherry"]}
mm_processor = renderer.get_mm_processor() mm_processor = renderer.get_mm_processor()
mm_items = mm_processor.info.parse_mm_data(mm_data) mm_data_items = mm_processor.info.parse_mm_data(mm_data)
mm_uuid_items = parse_mm_uuids(mm_uuids)
with pytest.raises(ValueError, match="is empty but .* is missing"): with pytest.raises(ValueError, match="is empty but .* is missing"):
renderer._process_mm_uuids(mm_data, mm_items, mm_uuids, "req-2") renderer._process_mm_uuids(mm_data, mm_data_items, mm_uuid_items, "req-2")
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -78,7 +81,7 @@ def test_multi_modal_uuids_missing_modality_raises(): ...@@ -78,7 +81,7 @@ def test_multi_modal_uuids_missing_modality_raises():
], ],
) )
def test_multi_modal_uuids_accepts_none_and_passes_through( def test_multi_modal_uuids_accepts_none_and_passes_through(
monkeypatch, mm_cache_gb: float, enable_prefix_caching: bool mm_cache_gb: float, enable_prefix_caching: bool
): ):
renderer = _build_renderer( renderer = _build_renderer(
mm_cache_gb=mm_cache_gb, mm_cache_gb=mm_cache_gb,
...@@ -94,9 +97,11 @@ def test_multi_modal_uuids_accepts_none_and_passes_through( ...@@ -94,9 +97,11 @@ def test_multi_modal_uuids_accepts_none_and_passes_through(
mm_uuids = {"image": [None, "hash_stop"], "video": None} mm_uuids = {"image": [None, "hash_stop"], "video": None}
mm_processor = renderer.get_mm_processor() mm_processor = renderer.get_mm_processor()
mm_items = mm_processor.info.parse_mm_data(mm_data) mm_data_items = mm_processor.info.parse_mm_data(mm_data)
mm_uuid_items = parse_mm_uuids(mm_uuids)
processed_mm_uuids = renderer._process_mm_uuids( processed_mm_uuids = renderer._process_mm_uuids(
mm_data, mm_items, mm_uuids, "req-3" mm_data, mm_data_items, mm_uuid_items, "req-3"
) )
assert processed_mm_uuids == mm_uuids assert processed_mm_uuids == mm_uuids
...@@ -111,7 +116,7 @@ def test_multi_modal_uuids_accepts_none_and_passes_through( ...@@ -111,7 +116,7 @@ def test_multi_modal_uuids_accepts_none_and_passes_through(
], ],
) )
def test_multi_modal_uuids_accepts_empty( def test_multi_modal_uuids_accepts_empty(
monkeypatch, mm_cache_gb: float, enable_prefix_caching: bool mm_cache_gb: float, enable_prefix_caching: bool
): ):
renderer = _build_renderer( renderer = _build_renderer(
mm_cache_gb=mm_cache_gb, mm_cache_gb=mm_cache_gb,
...@@ -124,15 +129,17 @@ def test_multi_modal_uuids_accepts_empty( ...@@ -124,15 +129,17 @@ def test_multi_modal_uuids_accepts_empty(
mm_uuids = {"image": [], "video": None} # type: ignore[var-annotated] mm_uuids = {"image": [], "video": None} # type: ignore[var-annotated]
mm_processor = renderer.get_mm_processor() mm_processor = renderer.get_mm_processor()
mm_items = mm_processor.info.parse_mm_data(mm_data) mm_data_items = mm_processor.info.parse_mm_data(mm_data)
mm_uuid_items = parse_mm_uuids(mm_uuids)
processed_mm_uuids = renderer._process_mm_uuids( processed_mm_uuids = renderer._process_mm_uuids(
mm_data, mm_items, mm_uuids, "req-4" mm_data, mm_data_items, mm_uuid_items, "req-4"
) )
assert processed_mm_uuids == mm_uuids assert processed_mm_uuids == mm_uuids
def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch): def test_multi_modal_uuids_ignored_when_caching_disabled():
# When both processor cache is 0 and prefix caching disabled, the # When both processor cache is 0 and prefix caching disabled, the
# processor builds overrides from request id instead of using user UUIDs. # processor builds overrides from request id instead of using user UUIDs.
renderer = _build_renderer(mm_cache_gb=0.0, enable_prefix_caching=False) renderer = _build_renderer(mm_cache_gb=0.0, enable_prefix_caching=False)
...@@ -145,9 +152,11 @@ def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch): ...@@ -145,9 +152,11 @@ def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch):
mm_uuids = {"image": ["hash_cherry", "hash_stop"], "video": ["hash_video"]} mm_uuids = {"image": ["hash_cherry", "hash_stop"], "video": ["hash_video"]}
mm_processor = renderer.get_mm_processor() mm_processor = renderer.get_mm_processor()
mm_items = mm_processor.info.parse_mm_data(mm_data) mm_data_items = mm_processor.info.parse_mm_data(mm_data)
mm_uuid_items = parse_mm_uuids(mm_uuids)
processed_mm_uuids = renderer._process_mm_uuids( processed_mm_uuids = renderer._process_mm_uuids(
mm_data, mm_items, mm_uuids, request_id mm_data, mm_data_items, mm_uuid_items, request_id
) )
# Expect request-id-based overrides are passed through # Expect request-id-based overrides are passed through
......
...@@ -91,7 +91,7 @@ class InputPreprocessor: ...@@ -91,7 +91,7 @@ class InputPreprocessor:
self, self,
prompt: str | list[int], prompt: str | list[int],
mm_data: MultiModalDataDict, mm_data: MultiModalDataDict,
mm_processor_kwargs: Mapping[str, object] | None, mm_processor_kwargs: Mapping[str, object] | None = None,
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
*, *,
mm_uuids: MultiModalUUIDDict | None = None, mm_uuids: MultiModalUUIDDict | None = None,
...@@ -103,9 +103,9 @@ class InputPreprocessor: ...@@ -103,9 +103,9 @@ class InputPreprocessor:
return self.renderer._process_multimodal( return self.renderer._process_multimodal(
prompt, prompt,
mm_data, mm_data,
mm_uuids=mm_uuids,
mm_processor_kwargs=mm_processor_kwargs, mm_processor_kwargs=mm_processor_kwargs,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
) )
def _process_embeds( def _process_embeds(
...@@ -144,7 +144,7 @@ class InputPreprocessor: ...@@ -144,7 +144,7 @@ class InputPreprocessor:
inputs = self._process_multimodal( inputs = self._process_multimodal(
prompt_token_ids, prompt_token_ids,
multi_modal_data, multi_modal_data,
parsed_content.get("mm_processor_kwargs") or {}, parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_uuids=parsed_content.get("multi_modal_uuids"), mm_uuids=parsed_content.get("multi_modal_uuids"),
) )
......
...@@ -36,9 +36,13 @@ from vllm.multimodal.inputs import ( ...@@ -36,9 +36,13 @@ from vllm.multimodal.inputs import (
MultiModalFieldConfig, MultiModalFieldConfig,
MultiModalInputs, MultiModalInputs,
MultiModalKwargsItems, MultiModalKwargsItems,
MultiModalUUIDDict,
) )
from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems from vllm.multimodal.parse import (
ImageProcessorItems,
ImageSize,
MultiModalDataItems,
MultiModalUUIDItems,
)
from vllm.multimodal.processing import ( from vllm.multimodal.processing import (
BaseDummyInputsBuilder, BaseDummyInputsBuilder,
BaseMultiModalProcessor, BaseMultiModalProcessor,
...@@ -203,10 +207,9 @@ class CLIPMultiModalProcessor(BaseMultiModalProcessor[CLIPProcessingInfo]): ...@@ -203,10 +207,9 @@ class CLIPMultiModalProcessor(BaseMultiModalProcessor[CLIPProcessingInfo]):
self, self,
prompt: str | list[int], prompt: str | list[int],
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], mm_uuid_items: MultiModalUUIDItems | None = None,
hf_processor_mm_kwargs: Mapping[str, object] | None = None,
tokenization_kwargs: Mapping[str, object] | None = None, tokenization_kwargs: Mapping[str, object] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
if mm_items: if mm_items:
if isinstance(prompt, str): if isinstance(prompt, str):
...@@ -235,9 +238,9 @@ class CLIPMultiModalProcessor(BaseMultiModalProcessor[CLIPProcessingInfo]): ...@@ -235,9 +238,9 @@ class CLIPMultiModalProcessor(BaseMultiModalProcessor[CLIPProcessingInfo]):
return super().apply( return super().apply(
prompt=prompt, prompt=prompt,
mm_items=mm_items, mm_items=mm_items,
mm_uuid_items=mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
) )
def _hf_processor_applies_updates( def _hf_processor_applies_updates(
......
...@@ -24,13 +24,13 @@ from vllm.multimodal.inputs import ( ...@@ -24,13 +24,13 @@ from vllm.multimodal.inputs import (
MultiModalDataDict, MultiModalDataDict,
MultiModalFieldConfig, MultiModalFieldConfig,
MultiModalKwargsItems, MultiModalKwargsItems,
MultiModalUUIDDict,
) )
from vllm.multimodal.parse import ( from vllm.multimodal.parse import (
ImageEmbeddingItems, ImageEmbeddingItems,
ImageProcessorItems, ImageProcessorItems,
ImageSize, ImageSize,
MultiModalDataItems, MultiModalDataItems,
MultiModalUUIDItems,
) )
from vllm.multimodal.processing import BaseDummyInputsBuilder from vllm.multimodal.processing import BaseDummyInputsBuilder
from vllm.multimodal.processing.processor import ( from vllm.multimodal.processing.processor import (
...@@ -313,9 +313,9 @@ class DeepseekVL2MultiModalProcessor( ...@@ -313,9 +313,9 @@ class DeepseekVL2MultiModalProcessor(
self, self,
prompt: str | list[int], prompt: str | list[int],
mm_data_items: MultiModalDataItems, mm_data_items: MultiModalDataItems,
mm_uuid_items: MultiModalUUIDItems | None,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object],
mm_uuids: MultiModalUUIDDict | None = None,
) -> tuple[list[int], MultiModalProcessingInfo, bool]: ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
# The processor logic is different for len(images) <= 2 vs > 2 # The processor logic is different for len(images) <= 2 vs > 2
# Since the processing cache assumes that the processor output is # Since the processing cache assumes that the processor output is
...@@ -325,17 +325,17 @@ class DeepseekVL2MultiModalProcessor( ...@@ -325,17 +325,17 @@ class DeepseekVL2MultiModalProcessor(
return self._apply_hf_processor( return self._apply_hf_processor(
prompt=prompt, prompt=prompt,
mm_data_items=mm_data_items, mm_data_items=mm_data_items,
mm_uuid_items=mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
) )
return super()._cached_apply_hf_processor( return super()._cached_apply_hf_processor(
prompt=prompt, prompt=prompt,
mm_data_items=mm_data_items, mm_data_items=mm_data_items,
mm_uuid_items=mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
) )
......
...@@ -16,11 +16,12 @@ from transformers import PretrainedConfig ...@@ -16,11 +16,12 @@ from transformers import PretrainedConfig
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargsItems, MultiModalUUIDDict from vllm.multimodal.inputs import MultiModalKwargsItems
from vllm.multimodal.parse import ( from vllm.multimodal.parse import (
ImageEmbeddingItems, ImageEmbeddingItems,
ImageProcessorItems, ImageProcessorItems,
MultiModalDataItems, MultiModalDataItems,
MultiModalUUIDItems,
) )
from vllm.multimodal.processing.processor import ( from vllm.multimodal.processing.processor import (
MultiModalProcessingInfo, MultiModalProcessingInfo,
...@@ -491,9 +492,9 @@ class H2OVLMultiModalProcessor(BaseInternVLMultiModalProcessor[H2OVLProcessingIn ...@@ -491,9 +492,9 @@ class H2OVLMultiModalProcessor(BaseInternVLMultiModalProcessor[H2OVLProcessingIn
self, self,
prompt: str | list[int], prompt: str | list[int],
mm_data_items: MultiModalDataItems, mm_data_items: MultiModalDataItems,
mm_uuid_items: MultiModalUUIDItems | None,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object],
mm_uuids: MultiModalUUIDDict | None = None,
) -> tuple[list[int], MultiModalProcessingInfo, bool]: ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
# The processor logic is different for len(images) <= 1 vs > 1 # The processor logic is different for len(images) <= 1 vs > 1
# Since the processing cache assumes that the processor output is # Since the processing cache assumes that the processor output is
...@@ -503,17 +504,17 @@ class H2OVLMultiModalProcessor(BaseInternVLMultiModalProcessor[H2OVLProcessingIn ...@@ -503,17 +504,17 @@ class H2OVLMultiModalProcessor(BaseInternVLMultiModalProcessor[H2OVLProcessingIn
return self._apply_hf_processor( return self._apply_hf_processor(
prompt=prompt, prompt=prompt,
mm_data_items=mm_data_items, mm_data_items=mm_data_items,
mm_uuid_items=mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
) )
return super()._cached_apply_hf_processor( return super()._cached_apply_hf_processor(
prompt=prompt, prompt=prompt,
mm_data_items=mm_data_items, mm_data_items=mm_data_items,
mm_uuid_items=mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
) )
......
...@@ -30,7 +30,6 @@ from vllm.multimodal.inputs import ( ...@@ -30,7 +30,6 @@ from vllm.multimodal.inputs import (
MultiModalFieldConfig, MultiModalFieldConfig,
MultiModalInputs, MultiModalInputs,
MultiModalKwargsItems, MultiModalKwargsItems,
MultiModalUUIDDict,
mm_inputs, mm_inputs,
) )
from vllm.multimodal.parse import ( from vllm.multimodal.parse import (
...@@ -38,6 +37,7 @@ from vllm.multimodal.parse import ( ...@@ -38,6 +37,7 @@ from vllm.multimodal.parse import (
ImageProcessorItems, ImageProcessorItems,
ImageSize, ImageSize,
MultiModalDataItems, MultiModalDataItems,
MultiModalUUIDItems,
) )
from vllm.multimodal.processing import ( from vllm.multimodal.processing import (
BaseDummyInputsBuilder, BaseDummyInputsBuilder,
...@@ -773,9 +773,9 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): ...@@ -773,9 +773,9 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
self, self,
prompt: str | list[int], prompt: str | list[int],
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], mm_uuid_items: MultiModalUUIDItems | None = None,
hf_processor_mm_kwargs: Mapping[str, object] | None = None,
tokenization_kwargs: Mapping[str, object] | None = None, tokenization_kwargs: Mapping[str, object] | None = None,
mm_uuids: MultiModalUUIDDict | None = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
hf_config = self.info.get_hf_config() hf_config = self.info.get_hf_config()
image_token_id = hf_config.image_token_index image_token_id = hf_config.image_token_index
...@@ -789,9 +789,9 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): ...@@ -789,9 +789,9 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
result = super().apply( result = super().apply(
prompt, prompt,
mm_items, mm_items,
hf_processor_mm_kwargs, mm_uuid_items,
tokenization_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
mm_uuids=mm_uuids, tokenization_kwargs=tokenization_kwargs,
) )
mm_item_counts = mm_items.get_all_counts() mm_item_counts = mm_items.get_all_counts()
......
...@@ -16,12 +16,12 @@ from vllm.multimodal.inputs import ( ...@@ -16,12 +16,12 @@ from vllm.multimodal.inputs import (
MultiModalFieldConfig, MultiModalFieldConfig,
MultiModalInputs, MultiModalInputs,
MultiModalKwargsItems, MultiModalKwargsItems,
MultiModalUUIDDict,
) )
from vllm.multimodal.parse import ( from vllm.multimodal.parse import (
ImageEmbeddingItems, ImageEmbeddingItems,
ImageProcessorItems, ImageProcessorItems,
MultiModalDataItems, MultiModalDataItems,
MultiModalUUIDItems,
) )
from vllm.multimodal.processing import ( from vllm.multimodal.processing import (
BaseDummyInputsBuilder, BaseDummyInputsBuilder,
...@@ -231,16 +231,16 @@ class PaliGemmaMultiModalProcessor(BaseMultiModalProcessor[PaliGemmaProcessingIn ...@@ -231,16 +231,16 @@ class PaliGemmaMultiModalProcessor(BaseMultiModalProcessor[PaliGemmaProcessingIn
self, self,
prompt: str | list[int], prompt: str | list[int],
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], mm_uuid_items: MultiModalUUIDItems | None = None,
hf_processor_mm_kwargs: Mapping[str, object] | None = None,
tokenization_kwargs: Mapping[str, object] | None = None, tokenization_kwargs: Mapping[str, object] | None = None,
mm_uuids: MultiModalUUIDDict | None = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
mm_inputs = super().apply( mm_inputs = super().apply(
prompt, prompt,
mm_items, mm_items,
hf_processor_mm_kwargs, mm_uuid_items,
tokenization_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
mm_uuids=mm_uuids, tokenization_kwargs=tokenization_kwargs,
) )
prompt_token_ids = mm_inputs["prompt_token_ids"] prompt_token_ids = mm_inputs["prompt_token_ids"]
......
...@@ -44,10 +44,14 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems ...@@ -44,10 +44,14 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems
from vllm.multimodal.inputs import ( from vllm.multimodal.inputs import (
MultiModalDataDict, MultiModalDataDict,
MultiModalFieldConfig, MultiModalFieldConfig,
MultiModalUUIDDict,
NestedTensors, NestedTensors,
) )
from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems from vllm.multimodal.parse import (
ImageProcessorItems,
ImageSize,
MultiModalDataItems,
MultiModalUUIDItems,
)
from vllm.multimodal.processing import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.processing import BaseDummyInputsBuilder, ProcessorInputs
from vllm.multimodal.processing.processor import ( from vllm.multimodal.processing.processor import (
BaseMultiModalProcessor, BaseMultiModalProcessor,
...@@ -344,16 +348,16 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]) ...@@ -344,16 +348,16 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo])
self, self,
prompt: str | list[int], prompt: str | list[int],
mm_data_items: MultiModalDataItems, mm_data_items: MultiModalDataItems,
mm_uuid_items: MultiModalUUIDItems | None,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object],
mm_uuids: MultiModalUUIDDict | None = None,
) -> tuple[list[int], MultiModalProcessingInfo, bool]: ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
prompt_ids, mm_info, _ = super()._cached_apply_hf_processor( prompt_ids, mm_info, _ = super()._cached_apply_hf_processor(
prompt=prompt, prompt=prompt,
mm_data_items=mm_data_items, mm_data_items=mm_data_items,
mm_uuid_items=mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
) )
# NOTE: The tokens are already inserted by the chat template # NOTE: The tokens are already inserted by the chat template
......
...@@ -42,9 +42,13 @@ from vllm.multimodal.inputs import ( ...@@ -42,9 +42,13 @@ from vllm.multimodal.inputs import (
MultiModalFieldConfig, MultiModalFieldConfig,
MultiModalInputs, MultiModalInputs,
MultiModalKwargsItems, MultiModalKwargsItems,
MultiModalUUIDDict,
) )
from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems from vllm.multimodal.parse import (
ImageProcessorItems,
ImageSize,
MultiModalDataItems,
MultiModalUUIDItems,
)
from vllm.multimodal.processing import ( from vllm.multimodal.processing import (
BaseDummyInputsBuilder, BaseDummyInputsBuilder,
BaseMultiModalProcessor, BaseMultiModalProcessor,
...@@ -189,10 +193,9 @@ class SiglipMultiModalProcessor(BaseMultiModalProcessor[SiglipProcessingInfo]): ...@@ -189,10 +193,9 @@ class SiglipMultiModalProcessor(BaseMultiModalProcessor[SiglipProcessingInfo]):
self, self,
prompt: str | list[int], prompt: str | list[int],
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], mm_uuid_items: MultiModalUUIDItems | None = None,
hf_processor_mm_kwargs: Mapping[str, object] | None = None,
tokenization_kwargs: Mapping[str, object] | None = None, tokenization_kwargs: Mapping[str, object] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
if mm_items: if mm_items:
if isinstance(prompt, str): if isinstance(prompt, str):
...@@ -221,9 +224,9 @@ class SiglipMultiModalProcessor(BaseMultiModalProcessor[SiglipProcessingInfo]): ...@@ -221,9 +224,9 @@ class SiglipMultiModalProcessor(BaseMultiModalProcessor[SiglipProcessingInfo]):
return super().apply( return super().apply(
prompt=prompt, prompt=prompt,
mm_items=mm_items, mm_items=mm_items,
mm_uuid_items=mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
) )
def _hf_processor_applies_updates( def _hf_processor_applies_updates(
......
...@@ -46,7 +46,6 @@ from vllm.multimodal.inputs import ( ...@@ -46,7 +46,6 @@ from vllm.multimodal.inputs import (
MultiModalFieldConfig, MultiModalFieldConfig,
MultiModalInputs, MultiModalInputs,
MultiModalKwargsItems, MultiModalKwargsItems,
MultiModalUUIDDict,
PlaceholderRange, PlaceholderRange,
mm_inputs, mm_inputs,
) )
...@@ -55,6 +54,7 @@ from vllm.multimodal.parse import ( ...@@ -55,6 +54,7 @@ from vllm.multimodal.parse import (
ModalityDataItems, ModalityDataItems,
MultiModalDataItems, MultiModalDataItems,
MultiModalDataParser, MultiModalDataParser,
MultiModalUUIDItems,
) )
from vllm.multimodal.processing import ( from vllm.multimodal.processing import (
BaseDummyInputsBuilder, BaseDummyInputsBuilder,
...@@ -196,15 +196,19 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor[TerratorchProcessing ...@@ -196,15 +196,19 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor[TerratorchProcessing
self, self,
prompt: str | list[int], prompt: str | list[int],
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], mm_uuid_items: MultiModalUUIDItems | None = None,
hf_processor_mm_kwargs: Mapping[str, object] | None = None,
tokenization_kwargs: Mapping[str, object] | None = None, tokenization_kwargs: Mapping[str, object] | None = None,
mm_uuids: MultiModalUUIDDict | None = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
if hf_processor_mm_kwargs is None:
hf_processor_mm_kwargs = {}
if tokenization_kwargs is None: if tokenization_kwargs is None:
tokenization_kwargs = {} tokenization_kwargs = {}
mm_hashes = self._hash_mm_items( mm_hashes = self._hash_mm_items(
mm_items, hf_processor_mm_kwargs, tokenization_kwargs, mm_uuids=mm_uuids mm_items,
mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
) )
_, passthrough_data = self._get_hf_mm_data(mm_items) _, passthrough_data = self._get_hf_mm_data(mm_items)
......
...@@ -31,11 +31,14 @@ from vllm.multimodal.inputs import ( ...@@ -31,11 +31,14 @@ from vllm.multimodal.inputs import (
MultiModalFeatureSpec, MultiModalFeatureSpec,
MultiModalFieldConfig, MultiModalFieldConfig,
MultiModalInputs, MultiModalInputs,
MultiModalUUIDDict,
PlaceholderRange, PlaceholderRange,
mm_inputs, mm_inputs,
) )
from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems from vllm.multimodal.parse import (
ImageProcessorItems,
MultiModalDataItems,
MultiModalUUIDItems,
)
from vllm.multimodal.processing import ( from vllm.multimodal.processing import (
BaseDummyInputsBuilder, BaseDummyInputsBuilder,
BaseMultiModalProcessor, BaseMultiModalProcessor,
...@@ -177,9 +180,9 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]): ...@@ -177,9 +180,9 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
self, self,
prompt: str | list[int], prompt: str | list[int],
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], mm_uuid_items: MultiModalUUIDItems | None = None,
hf_processor_mm_kwargs: Mapping[str, object] | None = None,
tokenization_kwargs: Mapping[str, object] | None = None, tokenization_kwargs: Mapping[str, object] | None = None,
mm_uuids: MultiModalUUIDDict | None = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
""" """
Process multi-modal inputs to be used in vLLM. Process multi-modal inputs to be used in vLLM.
...@@ -187,6 +190,8 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]): ...@@ -187,6 +190,8 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
Apply HF Processor on prompt text and multi-modal data together, Apply HF Processor on prompt text and multi-modal data together,
outputting token IDs and processed tensors. outputting token IDs and processed tensors.
""" """
if hf_processor_mm_kwargs is None:
hf_processor_mm_kwargs = {}
if tokenization_kwargs is None: if tokenization_kwargs is None:
tokenization_kwargs = {} tokenization_kwargs = {}
...@@ -258,7 +263,9 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]): ...@@ -258,7 +263,9 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
# Use overrides if provided; fallback to data-dependent hashing. # Use overrides if provided; fallback to data-dependent hashing.
mm_hashes = self._hash_mm_items( mm_hashes = self._hash_mm_items(
mm_items, hf_processor_mm_kwargs, tokenization_kwargs, mm_uuids=mm_uuids mm_items,
mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
) )
return mm_inputs( return mm_inputs(
......
...@@ -41,13 +41,13 @@ from vllm.multimodal.inputs import ( ...@@ -41,13 +41,13 @@ from vllm.multimodal.inputs import (
MultiModalDataDict, MultiModalDataDict,
MultiModalFieldConfig, MultiModalFieldConfig,
MultiModalKwargsItems, MultiModalKwargsItems,
MultiModalUUIDDict,
NestedTensors, NestedTensors,
) )
from vllm.multimodal.parse import ( from vllm.multimodal.parse import (
AudioProcessorItems, AudioProcessorItems,
MultiModalDataItems, MultiModalDataItems,
MultiModalDataParser, MultiModalDataParser,
MultiModalUUIDItems,
) )
from vllm.multimodal.processing import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.processing import BaseDummyInputsBuilder, ProcessorInputs
from vllm.multimodal.processing.processor import ( from vllm.multimodal.processing.processor import (
...@@ -363,16 +363,16 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]) ...@@ -363,16 +363,16 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo])
self, self,
prompt: str | list[int], prompt: str | list[int],
mm_data_items: MultiModalDataItems, mm_data_items: MultiModalDataItems,
mm_uuid_items: MultiModalUUIDItems | None,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object],
mm_uuids: MultiModalUUIDDict | None = None,
) -> tuple[list[int], MultiModalProcessingInfo, bool]: ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
prompt_ids, mm_info, _ = super()._cached_apply_hf_processor( prompt_ids, mm_info, _ = super()._cached_apply_hf_processor(
prompt=prompt, prompt=prompt,
mm_data_items=mm_data_items, mm_data_items=mm_data_items,
mm_uuid_items=mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
) )
# NOTE: The tokens are already inserted by the chat template # NOTE: The tokens are already inserted by the chat template
......
...@@ -155,7 +155,7 @@ The built-in modalities are defined by ...@@ -155,7 +155,7 @@ The built-in modalities are defined by
[`MultiModalDataBuiltins`][vllm.multimodal.inputs.MultiModalDataBuiltins]. [`MultiModalDataBuiltins`][vllm.multimodal.inputs.MultiModalDataBuiltins].
""" """
MultiModalUUIDDict: TypeAlias = Mapping[str, list[str | None] | str] MultiModalUUIDDict: TypeAlias = Mapping[str, Sequence[str | None] | str]
""" """
A dictionary containing user-provided UUIDs for items in each modality. A dictionary containing user-provided UUIDs for items in each modality.
If a UUID for an item is not provided, its entry will be `None` and If a UUID for an item is not provided, its entry will be `None` and
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import UserDict from collections import UserDict
from collections.abc import Callable, Iterator, Mapping, Sequence from collections.abc import Callable, Iterator, Mapping, Sequence, Set
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
...@@ -33,6 +33,7 @@ from .inputs import ( ...@@ -33,6 +33,7 @@ from .inputs import (
MultiModalDataDict, MultiModalDataDict,
MultiModalFieldConfig, MultiModalFieldConfig,
MultiModalKwargsItems, MultiModalKwargsItems,
MultiModalUUIDDict,
VideoItem, VideoItem,
) )
from .media import MediaWithBytes from .media import MediaWithBytes
...@@ -297,14 +298,15 @@ class DictEmbeddingItems( ...@@ -297,14 +298,15 @@ class DictEmbeddingItems(
return self.data return self.data
class AudioProcessorItems(ProcessorBatchItems[HfAudioItem]): class AudioProcessorItems(ProcessorBatchItems[HfAudioItem | None]):
def __init__(self, data: Sequence[HfAudioItem] | None) -> None: def __init__(self, data: Sequence[HfAudioItem | None]) -> None:
if data is None:
data = [None]
super().__init__(data, "audio") super().__init__(data, "audio")
def get_audio_length(self, item_idx: int) -> int: def get_audio_length(self, item_idx: int) -> int:
audio = self.get(item_idx) audio = self.get(item_idx)
if audio is None:
raise ValueError(f"Cannot get length of cached audio at {item_idx}")
return len(audio) return len(audio)
...@@ -322,14 +324,14 @@ class ImageSize(NamedTuple): ...@@ -322,14 +324,14 @@ class ImageSize(NamedTuple):
height: int height: int
class ImageProcessorItems(ProcessorBatchItems[HfImageItem]): class ImageProcessorItems(ProcessorBatchItems[HfImageItem | None]):
def __init__(self, data: Sequence[HfImageItem] | None) -> None: def __init__(self, data: Sequence[HfImageItem | None]) -> None:
if data is None:
data = [None]
super().__init__(data, "image") super().__init__(data, "image")
def get_image_size(self, item_idx: int) -> ImageSize: def get_image_size(self, item_idx: int) -> ImageSize:
image = self.get(item_idx) image = self.get(item_idx)
if image is None:
raise ValueError(f"Cannot get size of cached image at {item_idx}")
if isinstance(image, PILImage.Image): if isinstance(image, PILImage.Image):
return ImageSize(*image.size) return ImageSize(*image.size)
...@@ -349,22 +351,31 @@ class ImageEmbeddingItems(EmbeddingItems): ...@@ -349,22 +351,31 @@ class ImageEmbeddingItems(EmbeddingItems):
super().__init__(data, "image", expected_hidden_size) super().__init__(data, "image", expected_hidden_size)
class VideoProcessorItems(ProcessorBatchItems[HfVideoItem]): class VideoProcessorItems(ProcessorBatchItems[HfVideoItem | None]):
def __init__( def __init__(
self, self,
data: Sequence[HfVideoItem] | None, data: Sequence[HfVideoItem | None],
metadata: dict[str, Any] | list[dict[str, Any] | None] | None = None, metadata: dict[str, Any] | list[dict[str, Any] | None] | None = None,
) -> None: ) -> None:
if data is None:
data = [None]
super().__init__(data, "video") super().__init__(data, "video")
self.metadata = metadata self.metadata = metadata
def get_num_frames(self, item_idx: int) -> int: def get_num_frames(self, item_idx: int) -> int:
return len(self.get(item_idx)) video = self.get(item_idx)
if video is None:
raise ValueError(f"Cannot get length of cached video at {item_idx}")
return len(video)
def get_frame_size(self, item_idx: int) -> ImageSize: def get_frame_size(self, item_idx: int) -> ImageSize:
image = self.get(item_idx)[0] # Assume that the video isn't empty video = self.get(item_idx)
if video is None:
raise ValueError(f"Cannot get size of cached video at {item_idx}")
if len(video) == 0:
raise ValueError(f"Cannot get size of empty video at {item_idx}")
image = video[0]
if isinstance(image, PILImage.Image): if isinstance(image, PILImage.Image):
return ImageSize(*image.size) return ImageSize(*image.size)
...@@ -400,6 +411,15 @@ class MultiModalDataItems(UserDict[str, ModalityDataItems[Any, Any]]): ...@@ -400,6 +411,15 @@ class MultiModalDataItems(UserDict[str, ModalityDataItems[Any, Any]]):
normalized such that each entry corresponds to a list. normalized such that each entry corresponds to a list.
""" """
def select(self, modalities: Set[str]):
"""
Construct a new `MultiModalDataItems` instance containing only the
selected modalities.
"""
return MultiModalDataItems(
{modality: self[modality] for modality in modalities}
)
def get_count(self, modality: str, *, strict: bool = True) -> int: def get_count(self, modality: str, *, strict: bool = True) -> int:
""" """
Get the number of data items belonging to a modality. Get the number of data items belonging to a modality.
...@@ -497,19 +517,11 @@ class MultiModalDataParser: ...@@ -497,19 +517,11 @@ class MultiModalDataParser:
) -> TypeGuard[torch.Tensor | list[torch.Tensor]]: ) -> TypeGuard[torch.Tensor | list[torch.Tensor]]:
if isinstance(data, torch.Tensor): if isinstance(data, torch.Tensor):
return data.ndim == 3 return data.ndim == 3
if is_list_of(data, torch.Tensor): if is_list_of(data, torch.Tensor) and len(data) > 0:
return data[0].ndim == 2 # type: ignore[index] return data[0].ndim == 2 # type: ignore[index]
return False return False
def _is_empty(self, data: object) -> TypeGuard[None]:
if isinstance(data, list):
return len(data) == 0
if isinstance(data, (np.ndarray, torch.Tensor)):
return data.size == 0
return False
def _get_audio_with_sr( def _get_audio_with_sr(
self, self,
audio: AudioItem, audio: AudioItem,
...@@ -545,12 +557,6 @@ class MultiModalDataParser: ...@@ -545,12 +557,6 @@ class MultiModalDataParser:
data: ModalityData[AudioItem], data: ModalityData[AudioItem],
) -> ModalityDataItems[Any, Any] | None: ) -> ModalityDataItems[Any, Any] | None:
if data is None: if data is None:
return AudioProcessorItems(None)
# also check single audio item with sampling rate
if self._is_empty(data) or (
isinstance(data, tuple) and self._is_empty(data[0])
):
return None return None
if self.is_embeddings(data): if self.is_embeddings(data):
...@@ -558,9 +564,8 @@ class MultiModalDataParser: ...@@ -558,9 +564,8 @@ class MultiModalDataParser:
data_items: list[AudioItem] data_items: list[AudioItem]
if ( if (
is_list_of(data, float) (is_list_of(data, float) and len(data) > 0)
or isinstance(data, (np.ndarray, torch.Tensor)) or (isinstance(data, (np.ndarray, torch.Tensor)) and data.ndim == 1)
and data.ndim == 1
or isinstance(data, tuple) or isinstance(data, tuple)
): ):
data_items = [data] data_items = [data]
...@@ -591,18 +596,13 @@ class MultiModalDataParser: ...@@ -591,18 +596,13 @@ class MultiModalDataParser:
data: ModalityData[ImageItem], data: ModalityData[ImageItem],
) -> ModalityDataItems[Any, Any] | None: ) -> ModalityDataItems[Any, Any] | None:
if data is None: if data is None:
return ImageProcessorItems(None)
if self._is_empty(data):
return None return None
if self.is_embeddings(data): if self.is_embeddings(data):
return ImageEmbeddingItems(data, self.expected_hidden_size) return ImageEmbeddingItems(data, self.expected_hidden_size)
if ( if isinstance(data, (PILImage.Image, MediaWithBytes)) or (
isinstance(data, (PILImage.Image, MediaWithBytes)) isinstance(data, (np.ndarray, torch.Tensor)) and data.ndim == 3
or isinstance(data, (np.ndarray, torch.Tensor))
and data.ndim == 3
): ):
data_items = [data] data_items = [data]
elif isinstance(data, (np.ndarray, torch.Tensor)): elif isinstance(data, (np.ndarray, torch.Tensor)):
...@@ -617,19 +617,14 @@ class MultiModalDataParser: ...@@ -617,19 +617,14 @@ class MultiModalDataParser:
data: ModalityData[VideoItem], data: ModalityData[VideoItem],
) -> ModalityDataItems[Any, Any] | None: ) -> ModalityDataItems[Any, Any] | None:
if data is None: if data is None:
return VideoProcessorItems(None)
if self._is_empty(data):
return None return None
if self.is_embeddings(data): if self.is_embeddings(data):
return VideoEmbeddingItems(data, self.expected_hidden_size) return VideoEmbeddingItems(data, self.expected_hidden_size)
data_items: list[VideoItem] data_items: list[VideoItem]
if ( if (is_list_of(data, PILImage.Image) and len(data) > 0) or (
is_list_of(data, PILImage.Image) isinstance(data, (np.ndarray, torch.Tensor)) and data.ndim == 4
or isinstance(data, (np.ndarray, torch.Tensor))
and data.ndim == 4
): ):
data_items = [data] data_items = [data]
elif isinstance(data, (np.ndarray, torch.Tensor)): elif isinstance(data, (np.ndarray, torch.Tensor)):
...@@ -664,12 +659,15 @@ class MultiModalDataParser: ...@@ -664,12 +659,15 @@ class MultiModalDataParser:
data: ModalityData[Any], data: ModalityData[Any],
) -> ModalityDataItems[Any, Any] | None: ) -> ModalityDataItems[Any, Any] | None:
"""Parse vision chunk data (unified image and video chunks).""" """Parse vision chunk data (unified image and video chunks)."""
if data is None or self._is_empty(data): if data is None:
return None return None
if self.is_embeddings(data): if self.is_embeddings(data):
raise ValueError("Do not support embedding data for vision_chunk right now") raise ValueError("Do not support embedding data for vision_chunk right now")
if isinstance(data, dict): if isinstance(data, dict):
data = [data] data = [data]
return VisionChunkProcessorItems(data) return VisionChunkProcessorItems(data)
def _get_subparsers(self) -> Mapping[str, ModalityDataParser]: def _get_subparsers(self) -> Mapping[str, ModalityDataParser]:
...@@ -693,3 +691,20 @@ class MultiModalDataParser: ...@@ -693,3 +691,20 @@ class MultiModalDataParser:
mm_items[k] = parsed_data mm_items[k] = parsed_data
return mm_items return mm_items
MultiModalUUIDItems: TypeAlias = dict[str, Sequence[str | None]]
"""
As [`MultiModalUUIDDict`][vllm.multimodal.inputs.MultiModalUUIDDict], but
normalized such that each entry corresponds to a list.
"""
def parse_mm_uuids(mm_uuids: MultiModalUUIDDict | None) -> MultiModalUUIDItems:
if mm_uuids is None:
return {}
return {
modality: [uuids] if isinstance(uuids, str) else uuids
for modality, uuids in mm_uuids.items()
}
...@@ -32,7 +32,6 @@ from ..inputs import ( ...@@ -32,7 +32,6 @@ from ..inputs import (
MultiModalKwargsItem, MultiModalKwargsItem,
MultiModalKwargsItems, MultiModalKwargsItems,
MultiModalKwargsOptionalItems, MultiModalKwargsOptionalItems,
MultiModalUUIDDict,
PlaceholderRange, PlaceholderRange,
mm_enc_dec_inputs, mm_enc_dec_inputs,
mm_inputs, mm_inputs,
...@@ -41,6 +40,7 @@ from ..parse import ( ...@@ -41,6 +40,7 @@ from ..parse import (
DictEmbeddingItems, DictEmbeddingItems,
EmbeddingItems, EmbeddingItems,
MultiModalDataItems, MultiModalDataItems,
MultiModalUUIDItems,
) )
from .context import ( from .context import (
BaseProcessingInfo, BaseProcessingInfo,
...@@ -1014,11 +1014,15 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1014,11 +1014,15 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
self, self,
prompt: str, prompt: str,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], mm_uuid_items: MultiModalUUIDItems | None = None,
*, hf_processor_mm_kwargs: Mapping[str, object] | None = None,
mm_uuids: MultiModalUUIDDict | None = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
return self.apply(prompt, mm_items, hf_processor_mm_kwargs, mm_uuids=mm_uuids) return self.apply(
prompt,
mm_items,
mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
)
@abstractmethod @abstractmethod
def _get_mm_fields_config( def _get_mm_fields_config(
...@@ -1174,7 +1178,10 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1174,7 +1178,10 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
In addition, return whether prompt updates have been applied. In addition, return whether prompt updates have been applied.
""" """
processor_data, passthrough_data = self._get_hf_mm_data(mm_items) valid_mm_items = mm_items.select(
{k for k, c in mm_items.get_all_counts().items() if c > 0}
)
processor_data, passthrough_data = self._get_hf_mm_data(valid_mm_items)
processed_data = self._call_hf_processor( processed_data = self._call_hf_processor(
prompt=prompt_text, prompt=prompt_text,
...@@ -1301,69 +1308,57 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1301,69 +1308,57 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
def _hash_mm_items( def _hash_mm_items(
self, self,
mm_items: MultiModalDataItems, mm_data_items: MultiModalDataItems,
mm_uuid_items: MultiModalUUIDItems | None,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
*,
mm_uuids: MultiModalUUIDDict | None = None,
) -> MultiModalHashes: ) -> MultiModalHashes:
"""Create MM hashes to be returned.
Note: When overrides are provided via callers of `apply`,
`_hash_mm_items` will be bypassed and the overrides will be used.
"""
model_id = self.info.model_id model_id = self.info.model_id
hashes: MultiModalHashes = {} if mm_uuid_items is None:
mm_uuids = mm_uuids or {} mm_uuid_items = {}
mm_hashes: MultiModalHashes = {}
hasher = MultiModalHasher
for modality, items in mm_items.items(): for modality, data_items in mm_data_items.items():
if modality in mm_uuids: if modality in mm_uuid_items:
mm_uuids_per_modality = mm_uuids[modality] uuid_items = mm_uuid_items[modality]
if isinstance(mm_uuids_per_modality, str):
mm_uuids_per_modality = [mm_uuids_per_modality]
# For None entries, compute a hash; otherwise, use provided ID. # For None entries, compute a hash; otherwise, use provided ID.
computed: list[str] = [] hashes: list[str] = []
for i, item in enumerate(items.get_all_items_for_hash()): for i, item in enumerate(data_items.get_all_items_for_hash()):
item_uuid = mm_uuids_per_modality[i] uuid_item = uuid_items[i]
# NOTE: Even if a item_uuid is provided, we still compute a # NOTE: Even if a uuid_item is provided, we still compute a hash
# hash if `hf_processor_mm_kwargs` or `tokenization_kwargs` # if `hf_processor_mm_kwargs` is provided.
# are provided. This is because the processed multimodal # This is because the processed multimodal inputs can be different
# inputs can be different depending on the processor kwargs. # depending on the processor kwargs.
if ( if uuid_item is None or hf_processor_mm_kwargs:
item_uuid is None
or hf_processor_mm_kwargs
or tokenization_kwargs
):
# NOTE: use provided hash string to hash with kwargs # NOTE: use provided hash string to hash with kwargs
# if available for better performance. # if available for better performance.
item = item_uuid if item_uuid is not None else item item = uuid_item if uuid_item is not None else item
computed.append( hashes.append(
MultiModalHasher.hash_kwargs( hasher.hash_kwargs(
model_id=model_id, model_id=model_id,
**{modality: item}, **{modality: item},
**hf_processor_mm_kwargs, **hf_processor_mm_kwargs,
**tokenization_kwargs,
) )
) )
else: else:
computed.append(item_uuid) hashes.append(uuid_item)
hashes[modality] = computed
mm_hashes[modality] = hashes
else: else:
hashes[modality] = [ mm_hashes[modality] = [
MultiModalHasher.hash_kwargs( hasher.hash_kwargs(
model_id=model_id, model_id=model_id,
**{modality: item}, **{modality: item},
**hf_processor_mm_kwargs, **hf_processor_mm_kwargs,
**tokenization_kwargs,
) )
for item in items for item in data_items
] ]
return hashes return mm_hashes
def _get_cache_missing_items( def _get_cache_missing_items(
self, self,
...@@ -1468,10 +1463,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1468,10 +1463,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
self, self,
prompt: str | list[int], prompt: str | list[int],
mm_data_items: MultiModalDataItems, mm_data_items: MultiModalDataItems,
mm_uuid_items: MultiModalUUIDItems | None,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object],
*,
mm_uuids: MultiModalUUIDDict | None = None,
) -> tuple[list[int], MultiModalProcessingInfo, bool]: ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
( (
prompt_ids, prompt_ids,
...@@ -1494,9 +1488,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1494,9 +1488,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
with timed_preprocessor_operation(self.info.ctx, "hashing"): with timed_preprocessor_operation(self.info.ctx, "hashing"):
mm_hashes = self._hash_mm_items( mm_hashes = self._hash_mm_items(
mm_data_items, mm_data_items,
hf_processor_mm_kwargs, mm_uuid_items,
tokenization_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
mm_uuids=mm_uuids,
) )
mm_prompt_updates = self._get_mm_prompt_updates( mm_prompt_updates = self._get_mm_prompt_updates(
...@@ -1517,10 +1510,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1517,10 +1510,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
self, self,
prompt: str | list[int], prompt: str | list[int],
mm_data_items: MultiModalDataItems, mm_data_items: MultiModalDataItems,
mm_uuid_items: MultiModalUUIDItems | None,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object],
*,
mm_uuids: MultiModalUUIDDict | None = None,
) -> tuple[list[int], MultiModalProcessingInfo, bool]: ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
""" """
Apply the HF processor on the full prompt text, Apply the HF processor on the full prompt text,
...@@ -1533,17 +1525,16 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1533,17 +1525,16 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
return self._apply_hf_processor( return self._apply_hf_processor(
prompt=prompt, prompt=prompt,
mm_data_items=mm_data_items, mm_data_items=mm_data_items,
mm_uuid_items=mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
) )
with timed_preprocessor_operation(self.info.ctx, "hashing"): with timed_preprocessor_operation(self.info.ctx, "hashing"):
mm_hashes = self._hash_mm_items( mm_hashes = self._hash_mm_items(
mm_data_items, mm_data_items,
hf_processor_mm_kwargs, mm_uuid_items,
tokenization_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
mm_uuids=mm_uuids,
) )
with timed_preprocessor_operation(self.info.ctx, "cache_lookup"): with timed_preprocessor_operation(self.info.ctx, "cache_lookup"):
...@@ -1753,10 +1744,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1753,10 +1744,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
self, self,
prompt: str | list[int], prompt: str | list[int],
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], mm_uuid_items: MultiModalUUIDItems | None = None,
hf_processor_mm_kwargs: Mapping[str, object] | None = None,
tokenization_kwargs: Mapping[str, object] | None = None, tokenization_kwargs: Mapping[str, object] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
""" """
Process multi-modal inputs to be used in vLLM. Process multi-modal inputs to be used in vLLM.
...@@ -1775,6 +1765,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1775,6 +1765,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
if request_id is not None: if request_id is not None:
self.info.ctx.create_timing_stats(request_id) self.info.ctx.create_timing_stats(request_id)
if hf_processor_mm_kwargs is None:
hf_processor_mm_kwargs = {}
if tokenization_kwargs is None: if tokenization_kwargs is None:
tokenization_kwargs = {} tokenization_kwargs = {}
...@@ -1785,9 +1777,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1785,9 +1777,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
) = self._cached_apply_hf_processor( ) = self._cached_apply_hf_processor(
prompt, prompt,
mm_items, mm_items,
mm_uuid_items,
hf_processor_mm_kwargs, hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
) )
# NOTE: tokenization_kwargs are not required to init processor # NOTE: tokenization_kwargs are not required to init processor
...@@ -1861,10 +1853,9 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -1861,10 +1853,9 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
self, self,
prompt: str | list[int], prompt: str | list[int],
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], mm_uuid_items: MultiModalUUIDItems | None = None,
hf_processor_mm_kwargs: Mapping[str, object] | None = None,
tokenization_kwargs: Mapping[str, object] | None = None, tokenization_kwargs: Mapping[str, object] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
) -> MultiModalEncDecInputs: ) -> MultiModalEncDecInputs:
""" """
Process multi-modal inputs to be used in vLLM. Process multi-modal inputs to be used in vLLM.
...@@ -1877,9 +1868,9 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -1877,9 +1868,9 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
encoder_inputs = super().apply( encoder_inputs = super().apply(
encoder_prompt, encoder_prompt,
mm_items, mm_items,
hf_processor_mm_kwargs, mm_uuid_items,
tokenization_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
mm_uuids=mm_uuids, tokenization_kwargs=tokenization_kwargs,
) )
return self._get_enc_dec_inputs( return self._get_enc_dec_inputs(
......
...@@ -51,7 +51,7 @@ if TYPE_CHECKING: ...@@ -51,7 +51,7 @@ if TYPE_CHECKING:
MultiModalInputs, MultiModalInputs,
MultiModalUUIDDict, MultiModalUUIDDict,
) )
from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.parse import MultiModalDataItems, MultiModalUUIDItems
from vllm.multimodal.processing import BaseMultiModalProcessor from vllm.multimodal.processing import BaseMultiModalProcessor
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -463,23 +463,25 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -463,23 +463,25 @@ class BaseRenderer(ABC, Generic[_T]):
def _validate_mm_uuids( def _validate_mm_uuids(
self, self,
mm_data: "MultiModalDataDict", mm_data: "MultiModalDataDict",
mm_items: "MultiModalDataItems", mm_data_items: "MultiModalDataItems",
mm_uuids: "MultiModalUUIDDict | None", mm_uuid_items: "MultiModalUUIDItems",
) -> None: ) -> None:
if mm_uuids is None: # NOTE: Keys corresponding to `None` in `mm_data` don't appear in
mm_uuids = {} # `mm_data_items`
modalities = mm_data.keys() | mm_uuid_items.keys()
# NOTE: Keys corresponding to `None` in `mm_data` don't appear in `mm_items`
modalities = mm_data.keys() | mm_uuids.keys()
for modality in modalities: for modality in modalities:
data_items = mm_items.get(modality) or list[Any]() data_items = mm_data_items.get(modality)
uuid_items = mm_uuid_items.get(modality)
uuid_items = mm_uuids.get(modality) or list[str | None]() if data_items is None:
if isinstance(uuid_items, str): if uuid_items is None:
uuid_items = [uuid_items] raise ValueError(
f"multi_modal_data[{modality!r}] is empty but "
f"multi_modal_uuids[{modality!r}] is missing."
)
if len(data_items) > 0: elif uuid_items is not None:
if len(uuid_items) > 0 and len(data_items) != len(uuid_items): if len(uuid_items) > 0 and len(data_items) != len(uuid_items):
raise ValueError( raise ValueError(
f"If given, multi_modal_uuids[{modality!r}] must have " f"If given, multi_modal_uuids[{modality!r}] must have "
...@@ -488,14 +490,7 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -488,14 +490,7 @@ class BaseRenderer(ABC, Generic[_T]):
) )
for i, item in enumerate(data_items): for i, item in enumerate(data_items):
if item is None: if item is None and uuid_items[i] is None:
if not uuid_items:
raise ValueError(
f"multi_modal_data[{modality!r}][{i}] is empty but "
f"multi_modal_uuids[{modality!r}] is missing."
)
if uuid_items[i] is None:
raise ValueError( raise ValueError(
f"multi_modal_data[{modality!r}][{i}] is empty but " f"multi_modal_data[{modality!r}][{i}] is empty but "
f"multi_modal_uuids[{modality!r}][{i}] is missing." f"multi_modal_uuids[{modality!r}][{i}] is missing."
...@@ -504,8 +499,8 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -504,8 +499,8 @@ class BaseRenderer(ABC, Generic[_T]):
def _process_mm_uuids( def _process_mm_uuids(
self, self,
mm_data: "MultiModalDataDict", mm_data: "MultiModalDataDict",
mm_items: "MultiModalDataItems", mm_data_items: "MultiModalDataItems",
mm_uuids: "MultiModalUUIDDict | None", mm_uuid_items: "MultiModalUUIDItems",
mm_req_id: str, mm_req_id: str,
): ):
model_config = self.model_config model_config = self.model_config
...@@ -520,40 +515,45 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -520,40 +515,45 @@ class BaseRenderer(ABC, Generic[_T]):
and model_config.multimodal_config.mm_processor_cache_gb == 0 and model_config.multimodal_config.mm_processor_cache_gb == 0
and not self.config.cache_config.enable_prefix_caching and not self.config.cache_config.enable_prefix_caching
): ):
mm_uuids = { mm_uuid_items = {
modality: [f"{mm_req_id}-{modality}-{i}" for i in range(data_count)] modality: [f"{mm_req_id}-{modality}-{i}" for i in range(data_count)]
for modality, data_count in mm_items.get_all_counts().items() for modality, data_count in mm_data_items.get_all_counts().items()
} }
self._validate_mm_uuids(mm_data, mm_items, mm_uuids) self._validate_mm_uuids(mm_data, mm_data_items, mm_uuid_items)
return mm_uuids return mm_uuid_items
# TODO: Remove str and tokenization_kwargs after deprecating InputPreprocessor # TODO: Remove str and tokenization_kwargs after deprecating InputPreprocessor
def _process_multimodal( def _process_multimodal(
self, self,
prompt: list[int] | str, prompt: list[int] | str,
mm_data: "MultiModalDataDict", mm_data: "MultiModalDataDict",
mm_uuids: "MultiModalUUIDDict | None",
mm_processor_kwargs: Mapping[str, object] | None, mm_processor_kwargs: Mapping[str, object] | None,
tokenization_kwargs: dict[str, Any] | None, tokenization_kwargs: dict[str, Any] | None,
mm_uuids: "MultiModalUUIDDict | None",
) -> "MultiModalInputs": ) -> "MultiModalInputs":
from vllm.multimodal.parse import parse_mm_uuids
from vllm.multimodal.processing.context import set_request_id from vllm.multimodal.processing.context import set_request_id
mm_req_id = f"renderer-mm-{self._mm_req_counter.inc(1)}" mm_req_id = f"renderer-mm-{self._mm_req_counter.inc(1)}"
mm_processor = self.get_mm_processor() mm_processor = self.get_mm_processor()
mm_items = mm_processor.info.parse_mm_data(mm_data) mm_data_items = mm_processor.info.parse_mm_data(mm_data)
mm_uuids = self._process_mm_uuids(mm_data, mm_items, mm_uuids, mm_req_id) mm_uuid_items = parse_mm_uuids(mm_uuids)
mm_uuids = self._process_mm_uuids(
mm_data, mm_data_items, mm_uuid_items, mm_req_id
)
with set_request_id(mm_req_id), set_default_torch_num_threads(): with set_request_id(mm_req_id), set_default_torch_num_threads():
mm_inputs = mm_processor.apply( mm_inputs = mm_processor.apply(
prompt, prompt,
mm_items, mm_data_items,
hf_processor_mm_kwargs=mm_processor_kwargs or {}, mm_uuid_items,
hf_processor_mm_kwargs=mm_processor_kwargs,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
) )
self.update_mm_cache_stats() self.update_mm_cache_stats()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment