Unverified Commit 009d689b authored by Chenheli Hua's avatar Chenheli Hua Committed by GitHub
Browse files

[Core] Simplify and unify mm uuid handling & auto-generated mm hash overrides processing. (#24271)


Signed-off-by: default avatarChenheli Hua <huachenheli@outlook.com>
parent 0efdb5c3
...@@ -152,8 +152,8 @@ def test_multi_modal_uuids_accepts_none_and_passes_through( ...@@ -152,8 +152,8 @@ def test_multi_modal_uuids_accepts_none_and_passes_through(
*, *,
tokenization_kwargs=None, tokenization_kwargs=None,
lora_request=None, lora_request=None,
mm_hash_overrides=None): mm_uuids=None):
captured["mm_hash_overrides"] = mm_hash_overrides captured["mm_uuids"] = mm_uuids
# Minimal processed inputs for decoder-only flow # Minimal processed inputs for decoder-only flow
return {"type": "token", "prompt_token_ids": [1]} return {"type": "token", "prompt_token_ids": [1]}
...@@ -180,7 +180,7 @@ def test_multi_modal_uuids_accepts_none_and_passes_through( ...@@ -180,7 +180,7 @@ def test_multi_modal_uuids_accepts_none_and_passes_through(
params=SamplingParams(), params=SamplingParams(),
) )
assert captured["mm_hash_overrides"] == mm_uuids assert captured["mm_uuids"] == mm_uuids
def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch): def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch):
...@@ -196,8 +196,8 @@ def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch): ...@@ -196,8 +196,8 @@ def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch):
*, *,
tokenization_kwargs=None, tokenization_kwargs=None,
lora_request=None, lora_request=None,
mm_hash_overrides=None): mm_uuids=None):
captured["mm_hash_overrides"] = mm_hash_overrides captured["mm_uuids"] = mm_uuids
return {"type": "token", "prompt_token_ids": [1]} return {"type": "token", "prompt_token_ids": [1]}
monkeypatch.setattr(processor.input_preprocessor, monkeypatch.setattr(processor.input_preprocessor,
...@@ -223,7 +223,7 @@ def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch): ...@@ -223,7 +223,7 @@ def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch):
) )
# Expect request-id-based overrides are passed through # Expect request-id-based overrides are passed through
assert captured["mm_hash_overrides"] == { assert captured["mm_uuids"] == {
"image": [f"{request_id}-image-0", f"{request_id}-image-1"], "image": [f"{request_id}-image-0", f"{request_id}-image-1"],
"video": [f"{request_id}-video-0"], "video": [f"{request_id}-video-0"],
} }
...@@ -258,8 +258,7 @@ class InputPreprocessor: ...@@ -258,8 +258,7 @@ class InputPreprocessor:
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
*, *,
mm_hash_overrides: Optional[Union[dict[str, list[str]], mm_uuids: Optional[MultiModalUUIDDict] = None,
MultiModalUUIDDict]] = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
""" """
Apply the model's multi-modal processor to a multi-modal prompt, Apply the model's multi-modal processor to a multi-modal prompt,
...@@ -281,7 +280,7 @@ class InputPreprocessor: ...@@ -281,7 +280,7 @@ class InputPreprocessor:
mm_data, mm_data,
hf_processor_mm_kwargs=mm_processor_kwargs, hf_processor_mm_kwargs=mm_processor_kwargs,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
mm_hashes = mm_input["mm_hashes"] mm_hashes = mm_input["mm_hashes"]
...@@ -302,8 +301,7 @@ class InputPreprocessor: ...@@ -302,8 +301,7 @@ class InputPreprocessor:
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
*, *,
mm_hash_overrides: Optional[Union[dict[str, list[str]], mm_uuids: Optional[MultiModalUUIDDict] = None,
MultiModalUUIDDict]] = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
""" """
Async version of Async version of
...@@ -325,7 +323,7 @@ class InputPreprocessor: ...@@ -325,7 +323,7 @@ class InputPreprocessor:
mm_data, mm_data,
hf_processor_mm_kwargs=mm_processor_kwargs, hf_processor_mm_kwargs=mm_processor_kwargs,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
mm_hashes = mm_input["mm_hashes"] mm_hashes = mm_input["mm_hashes"]
...@@ -390,8 +388,7 @@ class InputPreprocessor: ...@@ -390,8 +388,7 @@ class InputPreprocessor:
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
*, *,
mm_hash_overrides: Optional[Union[dict[str, list[str]], mm_uuids: Optional[MultiModalUUIDDict] = None,
MultiModalUUIDDict]] = None,
) -> Union[TokenInputs, MultiModalInputs]: ) -> Union[TokenInputs, MultiModalInputs]:
prompt_token_ids = self._truncate_inputs( prompt_token_ids = self._truncate_inputs(
parsed_content["prompt_token_ids"], tokenization_kwargs) parsed_content["prompt_token_ids"], tokenization_kwargs)
...@@ -404,7 +401,7 @@ class InputPreprocessor: ...@@ -404,7 +401,7 @@ class InputPreprocessor:
parsed_content.get("mm_processor_kwargs"), parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
else: else:
inputs = token_inputs(prompt_token_ids=prompt_token_ids) inputs = token_inputs(prompt_token_ids=prompt_token_ids)
...@@ -420,8 +417,7 @@ class InputPreprocessor: ...@@ -420,8 +417,7 @@ class InputPreprocessor:
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
*, *,
mm_hash_overrides: Optional[Union[dict[str, list[str]], mm_uuids: Optional[MultiModalUUIDDict] = None,
MultiModalUUIDDict]] = None,
) -> Union[TokenInputs, MultiModalInputs]: ) -> Union[TokenInputs, MultiModalInputs]:
prompt_token_ids = self._truncate_inputs( prompt_token_ids = self._truncate_inputs(
parsed_content["prompt_token_ids"], tokenization_kwargs) parsed_content["prompt_token_ids"], tokenization_kwargs)
...@@ -434,7 +430,7 @@ class InputPreprocessor: ...@@ -434,7 +430,7 @@ class InputPreprocessor:
parsed_content.get("mm_processor_kwargs"), parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
else: else:
inputs = token_inputs(prompt_token_ids=prompt_token_ids, ) inputs = token_inputs(prompt_token_ids=prompt_token_ids, )
...@@ -450,8 +446,7 @@ class InputPreprocessor: ...@@ -450,8 +446,7 @@ class InputPreprocessor:
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
*, *,
mm_hash_overrides: Optional[Union[dict[str, list[str]], mm_uuids: Optional[MultiModalUUIDDict] = None,
MultiModalUUIDDict]] = None,
) -> Union[TokenInputs, MultiModalInputs]: ) -> Union[TokenInputs, MultiModalInputs]:
prompt_text = parsed_content["prompt"] prompt_text = parsed_content["prompt"]
...@@ -463,7 +458,7 @@ class InputPreprocessor: ...@@ -463,7 +458,7 @@ class InputPreprocessor:
parsed_content.get("mm_processor_kwargs"), parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
else: else:
prompt_token_ids = self._tokenize_prompt( prompt_token_ids = self._tokenize_prompt(
...@@ -487,8 +482,7 @@ class InputPreprocessor: ...@@ -487,8 +482,7 @@ class InputPreprocessor:
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
*, *,
mm_hash_overrides: Optional[Union[dict[str, list[str]], mm_uuids: Optional[MultiModalUUIDDict] = None,
MultiModalUUIDDict]] = None,
) -> Union[TokenInputs, MultiModalInputs]: ) -> Union[TokenInputs, MultiModalInputs]:
prompt_text = parsed_content["prompt"] prompt_text = parsed_content["prompt"]
...@@ -500,7 +494,7 @@ class InputPreprocessor: ...@@ -500,7 +494,7 @@ class InputPreprocessor:
parsed_content.get("mm_processor_kwargs"), parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
else: else:
prompt_token_ids = await self._tokenize_prompt_async( prompt_token_ids = await self._tokenize_prompt_async(
...@@ -524,8 +518,7 @@ class InputPreprocessor: ...@@ -524,8 +518,7 @@ class InputPreprocessor:
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
*, *,
mm_hash_overrides: Optional[Union[dict[str, list[str]], mm_uuids: Optional[MultiModalUUIDDict] = None,
MultiModalUUIDDict]] = None,
) -> SingletonInputs: ) -> SingletonInputs:
""" """
Extract the singleton inputs from a prompt. Extract the singleton inputs from a prompt.
...@@ -547,21 +540,21 @@ class InputPreprocessor: ...@@ -547,21 +540,21 @@ class InputPreprocessor:
return self._process_tokens( return self._process_tokens(
parsed["content"], parsed["content"],
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
if parsed["type"] == "text": if parsed["type"] == "text":
return self._process_text( return self._process_text(
parsed["content"], parsed["content"],
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
if parsed["type"] == "str": if parsed["type"] == "str":
return self._process_text( return self._process_text(
TextPrompt(prompt=parsed["content"]), TextPrompt(prompt=parsed["content"]),
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
assert_never(parsed) assert_never(parsed)
...@@ -572,8 +565,7 @@ class InputPreprocessor: ...@@ -572,8 +565,7 @@ class InputPreprocessor:
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
*, *,
mm_hash_overrides: Optional[Union[dict[str, list[str]], mm_uuids: Optional[MultiModalUUIDDict] = None,
MultiModalUUIDDict]] = None,
) -> SingletonInputs: ) -> SingletonInputs:
""" """
Async version of Async version of
...@@ -587,21 +579,21 @@ class InputPreprocessor: ...@@ -587,21 +579,21 @@ class InputPreprocessor:
return await self._process_tokens_async( return await self._process_tokens_async(
parsed["content"], parsed["content"],
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
if parsed["type"] == "text": if parsed["type"] == "text":
return await self._process_text_async( return await self._process_text_async(
parsed["content"], parsed["content"],
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
if parsed["type"] == "str": if parsed["type"] == "str":
return await self._process_text_async( return await self._process_text_async(
TextPrompt(prompt=parsed["content"]), TextPrompt(prompt=parsed["content"]),
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
assert_never(parsed) assert_never(parsed)
...@@ -712,8 +704,7 @@ class InputPreprocessor: ...@@ -712,8 +704,7 @@ class InputPreprocessor:
prompt: PromptType, prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
*, *,
mm_hash_overrides: Optional[Union[dict[str, list[str]], mm_uuids: Optional[MultiModalUUIDDict] = None,
MultiModalUUIDDict]] = None,
) -> EncoderDecoderInputs: ) -> EncoderDecoderInputs:
""" """
For encoder/decoder models only: For encoder/decoder models only:
...@@ -755,7 +746,7 @@ class InputPreprocessor: ...@@ -755,7 +746,7 @@ class InputPreprocessor:
encoder_inputs = self._prompt_to_llm_inputs( encoder_inputs = self._prompt_to_llm_inputs(
prompt["encoder_prompt"], prompt["encoder_prompt"],
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
if (decoder_input := prompt["decoder_prompt"]) is None: if (decoder_input := prompt["decoder_prompt"]) is None:
decoder_inputs = None decoder_inputs = None
...@@ -771,7 +762,7 @@ class InputPreprocessor: ...@@ -771,7 +762,7 @@ class InputPreprocessor:
inputs = self._prompt_to_llm_inputs( inputs = self._prompt_to_llm_inputs(
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
if self.model_config.is_multimodal_model: if self.model_config.is_multimodal_model:
# Encoder-Decoder Multimodal model # Encoder-Decoder Multimodal model
...@@ -788,8 +779,7 @@ class InputPreprocessor: ...@@ -788,8 +779,7 @@ class InputPreprocessor:
prompt: PromptType, prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
*, *,
mm_hash_overrides: Optional[Union[dict[str, list[str]], mm_uuids: Optional[MultiModalUUIDDict] = None,
MultiModalUUIDDict]] = None,
) -> EncoderDecoderInputs: ) -> EncoderDecoderInputs:
""" """
Async version of Async version of
...@@ -802,7 +792,7 @@ class InputPreprocessor: ...@@ -802,7 +792,7 @@ class InputPreprocessor:
encoder_task = self._prompt_to_llm_inputs_async( encoder_task = self._prompt_to_llm_inputs_async(
prompt["encoder_prompt"], prompt["encoder_prompt"],
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
if (decoder_input := prompt["decoder_prompt"]) is None: if (decoder_input := prompt["decoder_prompt"]) is None:
...@@ -812,7 +802,7 @@ class InputPreprocessor: ...@@ -812,7 +802,7 @@ class InputPreprocessor:
decoder_task = self._prompt_to_llm_inputs_async( decoder_task = self._prompt_to_llm_inputs_async(
decoder_input, decoder_input,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
encoder_inputs, decoder_inputs = await asyncio.gather( encoder_inputs, decoder_inputs = await asyncio.gather(
...@@ -828,7 +818,7 @@ class InputPreprocessor: ...@@ -828,7 +818,7 @@ class InputPreprocessor:
inputs = await self._prompt_to_llm_inputs_async( inputs = await self._prompt_to_llm_inputs_async(
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
if self.model_config.is_multimodal_model: if self.model_config.is_multimodal_model:
# Encoder-Decoder Multimodal model # Encoder-Decoder Multimodal model
...@@ -856,8 +846,7 @@ class InputPreprocessor: ...@@ -856,8 +846,7 @@ class InputPreprocessor:
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
*, *,
mm_hash_overrides: Optional[Union[dict[str, list[str]], mm_uuids: Optional[MultiModalUUIDDict] = None,
MultiModalUUIDDict]] = None,
) -> DecoderOnlyInputs: ) -> DecoderOnlyInputs:
""" """
For decoder-only models: For decoder-only models:
...@@ -878,7 +867,7 @@ class InputPreprocessor: ...@@ -878,7 +867,7 @@ class InputPreprocessor:
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
return self._build_decoder_only_llm_inputs(prompt_comps) return self._build_decoder_only_llm_inputs(prompt_comps)
...@@ -889,8 +878,7 @@ class InputPreprocessor: ...@@ -889,8 +878,7 @@ class InputPreprocessor:
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
*, *,
mm_hash_overrides: Optional[Union[dict[str, list[str]], mm_uuids: Optional[MultiModalUUIDDict] = None,
MultiModalUUIDDict]] = None,
) -> DecoderOnlyInputs: ) -> DecoderOnlyInputs:
""" """
Async version of Async version of
...@@ -900,7 +888,7 @@ class InputPreprocessor: ...@@ -900,7 +888,7 @@ class InputPreprocessor:
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
return self._build_decoder_only_llm_inputs(prompt_comps) return self._build_decoder_only_llm_inputs(prompt_comps)
...@@ -911,8 +899,7 @@ class InputPreprocessor: ...@@ -911,8 +899,7 @@ class InputPreprocessor:
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
*, *,
mm_hash_overrides: Optional[Union[dict[str, list[str]], mm_uuids: Optional[MultiModalUUIDDict] = None,
MultiModalUUIDDict]] = None,
) -> ProcessorInputs: ) -> ProcessorInputs:
"""Preprocess the input prompt.""" """Preprocess the input prompt."""
if self.model_config.is_encoder_decoder: if self.model_config.is_encoder_decoder:
...@@ -921,7 +908,7 @@ class InputPreprocessor: ...@@ -921,7 +908,7 @@ class InputPreprocessor:
return self._process_encoder_decoder_prompt( return self._process_encoder_decoder_prompt(
prompt, prompt,
tokenization_kwargs, tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
if is_explicit_encoder_decoder_prompt(prompt): if is_explicit_encoder_decoder_prompt(prompt):
...@@ -933,7 +920,7 @@ class InputPreprocessor: ...@@ -933,7 +920,7 @@ class InputPreprocessor:
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
async def preprocess_async( async def preprocess_async(
...@@ -942,8 +929,7 @@ class InputPreprocessor: ...@@ -942,8 +929,7 @@ class InputPreprocessor:
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
*, *,
mm_hash_overrides: Optional[Union[dict[str, list[str]], mm_uuids: Optional[MultiModalUUIDDict] = None,
MultiModalUUIDDict]] = None,
) -> ProcessorInputs: ) -> ProcessorInputs:
""" """
Async version of Async version of
...@@ -955,7 +941,7 @@ class InputPreprocessor: ...@@ -955,7 +941,7 @@ class InputPreprocessor:
return await self._process_encoder_decoder_prompt_async( return await self._process_encoder_decoder_prompt_async(
prompt, prompt,
tokenization_kwargs, tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
if is_explicit_encoder_decoder_prompt(prompt): if is_explicit_encoder_decoder_prompt(prompt):
...@@ -967,7 +953,7 @@ class InputPreprocessor: ...@@ -967,7 +953,7 @@ class InputPreprocessor:
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
def clear_cache(self) -> None: def clear_cache(self) -> None:
......
...@@ -21,7 +21,8 @@ from vllm.model_executor.model_loader.utils import set_default_torch_dtype ...@@ -21,7 +21,8 @@ from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.models.transformers import replace_linear_class from vllm.model_executor.models.transformers import replace_linear_class
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems, NestedTensors) MultiModalKwargsItems, MultiModalUUIDDict,
NestedTensors)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems) ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
...@@ -290,7 +291,7 @@ class DeepseekVL2MultiModalProcessor( ...@@ -290,7 +291,7 @@ class DeepseekVL2MultiModalProcessor(
mm_data_items: MultiModalDataItems, mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object],
mm_hash_overrides: Optional[dict[str, list[str]]] = None, mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> tuple[list[int], MultiModalProcessingInfo, bool]: ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
# The processor logic is different for len(images) <= 2 vs > 2 # The processor logic is different for len(images) <= 2 vs > 2
# Since the processing cache assumes that the processor output is # Since the processing cache assumes that the processor output is
...@@ -302,7 +303,7 @@ class DeepseekVL2MultiModalProcessor( ...@@ -302,7 +303,7 @@ class DeepseekVL2MultiModalProcessor(
mm_data_items=mm_data_items, mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
return super()._cached_apply_hf_processor( return super()._cached_apply_hf_processor(
...@@ -310,7 +311,7 @@ class DeepseekVL2MultiModalProcessor( ...@@ -310,7 +311,7 @@ class DeepseekVL2MultiModalProcessor(
mm_data_items=mm_data_items, mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
......
...@@ -17,7 +17,7 @@ from transformers import PretrainedConfig ...@@ -17,7 +17,7 @@ from transformers import PretrainedConfig
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargsItems from vllm.multimodal.inputs import MultiModalKwargsItems, MultiModalUUIDDict
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
MultiModalDataItems) MultiModalDataItems)
from vllm.multimodal.processing import (MultiModalProcessingInfo, from vllm.multimodal.processing import (MultiModalProcessingInfo,
...@@ -479,7 +479,7 @@ class H2OVLMultiModalProcessor( ...@@ -479,7 +479,7 @@ class H2OVLMultiModalProcessor(
mm_data_items: MultiModalDataItems, mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object],
mm_hash_overrides: Optional[dict[str, list[str]]] = None, mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> tuple[list[int], MultiModalProcessingInfo, bool]: ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
# The processor logic is different for len(images) <= 1 vs > 1 # The processor logic is different for len(images) <= 1 vs > 1
# Since the processing cache assumes that the processor output is # Since the processing cache assumes that the processor output is
...@@ -491,7 +491,7 @@ class H2OVLMultiModalProcessor( ...@@ -491,7 +491,7 @@ class H2OVLMultiModalProcessor(
mm_data_items=mm_data_items, mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
return super()._cached_apply_hf_processor( return super()._cached_apply_hf_processor(
...@@ -499,7 +499,7 @@ class H2OVLMultiModalProcessor( ...@@ -499,7 +499,7 @@ class H2OVLMultiModalProcessor(
mm_data_items=mm_data_items, mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
......
...@@ -24,7 +24,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -24,7 +24,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import BaseMultiModalProcessorCache from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputs, MultiModalKwargsItems) MultiModalInputs, MultiModalKwargsItems,
MultiModalUUIDDict)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems) ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
...@@ -795,7 +796,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): ...@@ -795,7 +796,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
mm_data: MultiModalDataDict, mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Optional[Mapping[str, object]] = None, tokenization_kwargs: Optional[Mapping[str, object]] = None,
mm_hash_overrides: Optional[dict[str, list[str]]] = None, mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
hf_config = self.info.get_hf_config() hf_config = self.info.get_hf_config()
image_token_id = hf_config.image_token_index image_token_id = hf_config.image_token_index
...@@ -810,7 +811,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): ...@@ -810,7 +811,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
mm_data, mm_data,
hf_processor_mm_kwargs, hf_processor_mm_kwargs,
tokenization_kwargs, tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides) mm_uuids=mm_uuids)
mm_items = self._to_mm_items(mm_data) mm_items = self._to_mm_items(mm_data)
mm_item_counts = mm_items.get_all_counts() mm_item_counts = mm_items.get_all_counts()
......
...@@ -57,7 +57,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -57,7 +57,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalFieldConfig, MultiModalFieldConfig,
MultiModalKwargsItems) MultiModalKwargsItems, MultiModalUUIDDict)
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems) MultiModalDataItems)
from vllm.multimodal.processing import (BaseProcessingInfo, from vllm.multimodal.processing import (BaseProcessingInfo,
...@@ -184,13 +184,13 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo] ...@@ -184,13 +184,13 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo]
mm_data: MultiModalDataDict, mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Optional[Mapping[str, object]] = None, tokenization_kwargs: Optional[Mapping[str, object]] = None,
mm_hash_overrides: Optional[dict[str, list[str]]] = None, mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> MultiModalEncDecInputs: ) -> MultiModalEncDecInputs:
mm_inputs = super().apply(prompt, mm_inputs = super().apply(prompt,
mm_data, mm_data,
hf_processor_mm_kwargs, hf_processor_mm_kwargs,
tokenization_kwargs, tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides) mm_uuids=mm_uuids)
image_token_id = self.info.get_hf_config().image_token_index image_token_id = self.info.get_hf_config().image_token_index
# Check that the number of image tokens in the decoder prompt matches # Check that the number of image tokens in the decoder prompt matches
......
...@@ -12,7 +12,8 @@ from vllm.logger import init_logger ...@@ -12,7 +12,8 @@ from vllm.logger import init_logger
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputs, MultiModalKwargsItems) MultiModalInputs, MultiModalKwargsItems,
MultiModalUUIDDict)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
MultiModalDataItems) MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
...@@ -203,13 +204,13 @@ class PaliGemmaMultiModalProcessor( ...@@ -203,13 +204,13 @@ class PaliGemmaMultiModalProcessor(
mm_data: MultiModalDataDict, mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Optional[Mapping[str, object]] = None, tokenization_kwargs: Optional[Mapping[str, object]] = None,
mm_hash_overrides: Optional[dict[str, list[str]]] = None, mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
mm_inputs = super().apply(prompt, mm_inputs = super().apply(prompt,
mm_data, mm_data,
hf_processor_mm_kwargs, hf_processor_mm_kwargs,
tokenization_kwargs, tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides) mm_uuids=mm_uuids)
prompt_token_ids = mm_inputs["prompt_token_ids"] prompt_token_ids = mm_inputs["prompt_token_ids"]
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
......
...@@ -35,7 +35,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -35,7 +35,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
NestedTensors) MultiModalUUIDDict, NestedTensors)
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems) MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
...@@ -316,14 +316,14 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo] ...@@ -316,14 +316,14 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
mm_data_items: MultiModalDataItems, mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object],
mm_hash_overrides: Optional[dict[str, list[str]]] = None, mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> tuple[list[int], MultiModalProcessingInfo, bool]: ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
prompt_ids, mm_info, _ = super()._cached_apply_hf_processor( prompt_ids, mm_info, _ = super()._cached_apply_hf_processor(
prompt=prompt, prompt=prompt,
mm_data_items=mm_data_items, mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
# NOTE: The tokens are already inserted by the chat template # NOTE: The tokens are already inserted by the chat template
......
...@@ -36,7 +36,7 @@ from vllm.multimodal.cache import MultiModalProcessorOnlyCache ...@@ -36,7 +36,7 @@ from vllm.multimodal.cache import MultiModalProcessorOnlyCache
from vllm.multimodal.inputs import (ImageItem, ModalityData, from vllm.multimodal.inputs import (ImageItem, ModalityData,
MultiModalDataDict, MultiModalFieldConfig, MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputs, MultiModalKwargsItems, MultiModalInputs, MultiModalKwargsItems,
PlaceholderRange) MultiModalUUIDDict, PlaceholderRange)
from vllm.multimodal.parse import (DictEmbeddingItems, ModalityDataItems, from vllm.multimodal.parse import (DictEmbeddingItems, ModalityDataItems,
MultiModalDataItems, MultiModalDataParser) MultiModalDataItems, MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
...@@ -164,7 +164,7 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor): ...@@ -164,7 +164,7 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor):
mm_data: MultiModalDataDict, mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Optional[Mapping[str, object]] = None, tokenization_kwargs: Optional[Mapping[str, object]] = None,
mm_hash_overrides: Optional[dict[str, list[str]]] = None, mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
if "image" in mm_data: if "image" in mm_data:
image_data = mm_data["image"] image_data = mm_data["image"]
...@@ -177,7 +177,7 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor): ...@@ -177,7 +177,7 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor):
mm_hashes = self._hash_mm_items(mm_items, mm_hashes = self._hash_mm_items(mm_items,
hf_processor_mm_kwargs, hf_processor_mm_kwargs,
tokenization_kwargs, tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides) mm_uuids=mm_uuids)
mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]} mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]}
mm_processed_data = BatchFeature(image_data) mm_processed_data = BatchFeature(image_data)
......
...@@ -44,7 +44,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -44,7 +44,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputs, PlaceholderRange) MultiModalInputs, MultiModalUUIDDict,
PlaceholderRange)
from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo) BaseProcessingInfo)
...@@ -347,7 +348,7 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]): ...@@ -347,7 +348,7 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
mm_data: MultiModalDataDict, mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Optional[Mapping[str, object]] = None, tokenization_kwargs: Optional[Mapping[str, object]] = None,
mm_hash_overrides: Optional[dict[str, list[str]]] = None, mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
""" """
Process multi-modal inputs to be used in vLLM. Process multi-modal inputs to be used in vLLM.
...@@ -415,9 +416,8 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]): ...@@ -415,9 +416,8 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
num_image_patches), num_image_patches),
) )
# Use overrides if provided; fallback to data-dependent hashing. # Use overrides if provided; fallback to data-dependent hashing.
mm_hashes = (mm_hash_overrides if mm_hash_overrides is not None else mm_hashes = (mm_uuids if mm_uuids is not None else self._hash_mm_items(
self._hash_mm_items(mm_items, hf_processor_mm_kwargs, mm_items, hf_processor_mm_kwargs, tokenization_kwargs))
tokenization_kwargs))
return MultiModalInputs( return MultiModalInputs(
type="multimodal", type="multimodal",
......
...@@ -31,7 +31,8 @@ from vllm.model_executor.models.whisper import WhisperEncoder ...@@ -31,7 +31,8 @@ from vllm.model_executor.models.whisper import WhisperEncoder
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems, NestedTensors) MultiModalKwargsItems, MultiModalUUIDDict,
NestedTensors)
from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems, from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems,
MultiModalDataParser) MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
...@@ -290,14 +291,14 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo] ...@@ -290,14 +291,14 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]
mm_data_items: MultiModalDataItems, mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object],
mm_hash_overrides: Optional[dict[str, list[str]]] = None, mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> tuple[list[int], MultiModalProcessingInfo, bool]: ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
prompt_ids, mm_info, _ = super()._cached_apply_hf_processor( prompt_ids, mm_info, _ = super()._cached_apply_hf_processor(
prompt=prompt, prompt=prompt,
mm_data_items=mm_data_items, mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
# NOTE: The tokens are already inserted by the chat template # NOTE: The tokens are already inserted by the chat template
......
...@@ -1022,13 +1022,12 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1022,13 +1022,12 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_data: MultiModalDataDict, mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
*, *,
mm_hash_overrides: Optional[Union[dict[str, list[str]], mm_uuids: Optional[MultiModalUUIDDict] = None,
MultiModalUUIDDict]] = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
return self.apply(prompt, return self.apply(prompt,
mm_data, mm_data,
hf_processor_mm_kwargs, hf_processor_mm_kwargs,
mm_hash_overrides=mm_hash_overrides) mm_uuids=mm_uuids)
def _get_data_parser(self) -> MultiModalDataParser: def _get_data_parser(self) -> MultiModalDataParser:
""" """
...@@ -1364,8 +1363,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1364,8 +1363,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object],
*, *,
mm_hash_overrides: Optional[Union[dict[str, list[str]], mm_uuids: Optional[MultiModalUUIDDict] = None,
MultiModalUUIDDict]] = None,
) -> MultiModalHashes: ) -> MultiModalHashes:
"""Create MM hashes to be returned (only used in V1). """Create MM hashes to be returned (only used in V1).
...@@ -1376,30 +1374,30 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1376,30 +1374,30 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
model_id = self.info.model_id model_id = self.info.model_id
hashes: MultiModalHashes = {} hashes: MultiModalHashes = {}
mm_hash_overrides = mm_hash_overrides or {} mm_uuids = mm_uuids or {}
for modality, items in mm_items.items(): for modality, items in mm_items.items():
if modality in mm_hash_overrides: if modality in mm_uuids:
mm_hashes = mm_hash_overrides[modality] mm_uuids_per_modality = mm_uuids[modality]
if isinstance(mm_hashes, str): if isinstance(mm_uuids_per_modality, str):
mm_hashes = [mm_hashes] mm_uuids_per_modality = [mm_uuids_per_modality]
# For None entries, compute a hash; otherwise, use provided ID. # For None entries, compute a hash; otherwise, use provided ID.
computed: list[str] = [] computed: list[str] = []
for i, item in enumerate(items): for i, item in enumerate(items):
mm_hash = mm_hashes[i] item_uuid = mm_uuids_per_modality[i]
# NOTE: Even if a mm_hash is provided, we still compute a # NOTE: Even if a item_uuid is provided, we still compute a
# hash if `hf_processor_mm_kwargs` or `tokenization_kwargs` # hash if `hf_processor_mm_kwargs` or `tokenization_kwargs`
# are provided. This is because the processed multimodal # are provided. This is because the processed multimodal
# inputs can be different depending on the processor kwargs. # inputs can be different depending on the processor kwargs.
if mm_hash is None or \ if item_uuid is None or \
hf_processor_mm_kwargs or \ hf_processor_mm_kwargs or \
tokenization_kwargs: tokenization_kwargs:
# NOTE: use provided hash string to hash with kwargs # NOTE: use provided hash string to hash with kwargs
# if available for better performance. # if available for better performance.
item = mm_hash if mm_hash is not None else item item = item_uuid if item_uuid is not None else item
computed.append( computed.append(
MultiModalHasher.hash_kwargs( MultiModalHasher.hash_kwargs(
model_id=model_id, model_id=model_id,
...@@ -1407,7 +1405,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1407,7 +1405,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
**hf_processor_mm_kwargs, **hf_processor_mm_kwargs,
**tokenization_kwargs)) **tokenization_kwargs))
else: else:
computed.append(mm_hash) computed.append(item_uuid)
hashes[modality] = computed hashes[modality] = computed
else: else:
hashes[modality] = [ hashes[modality] = [
...@@ -1514,8 +1512,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1514,8 +1512,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object],
*, *,
mm_hash_overrides: Optional[Union[dict[str, list[str]], mm_uuids: Optional[MultiModalUUIDDict] = None,
MultiModalUUIDDict]] = None,
) -> tuple[list[int], MultiModalProcessingInfo, bool]: ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
( (
prompt_ids, prompt_ids,
...@@ -1539,7 +1536,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1539,7 +1536,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_hashes = self._hash_mm_items(mm_data_items, mm_hashes = self._hash_mm_items(mm_data_items,
hf_processor_mm_kwargs, hf_processor_mm_kwargs,
tokenization_kwargs, tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides) mm_uuids=mm_uuids)
mm_prompt_updates = self._get_mm_prompt_updates( mm_prompt_updates = self._get_mm_prompt_updates(
mm_data_items, mm_data_items,
...@@ -1562,8 +1559,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1562,8 +1559,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object],
*, *,
mm_hash_overrides: Optional[Union[dict[str, list[str]], mm_uuids: Optional[MultiModalUUIDDict] = None,
MultiModalUUIDDict]] = None,
) -> tuple[list[int], MultiModalProcessingInfo, bool]: ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
""" """
Apply the HF processor on the full prompt text, Apply the HF processor on the full prompt text,
...@@ -1578,13 +1574,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1578,13 +1574,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_data_items=mm_data_items, mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
mm_hashes = self._hash_mm_items(mm_data_items, mm_hashes = self._hash_mm_items(mm_data_items,
hf_processor_mm_kwargs, hf_processor_mm_kwargs,
tokenization_kwargs, tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides) mm_uuids=mm_uuids)
mm_missing_data_items = self._get_cache_missing_items( mm_missing_data_items = self._get_cache_missing_items(
cache=cache, cache=cache,
...@@ -1785,8 +1781,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1785,8 +1781,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Optional[Mapping[str, object]] = None, tokenization_kwargs: Optional[Mapping[str, object]] = None,
*, *,
mm_hash_overrides: Optional[Union[dict[str, list[str]], mm_uuids: Optional[MultiModalUUIDDict] = None,
MultiModalUUIDDict]] = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
""" """
Process multi-modal inputs to be used in vLLM. Process multi-modal inputs to be used in vLLM.
...@@ -1815,7 +1810,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1815,7 +1810,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_items, mm_items,
hf_processor_mm_kwargs, hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
# NOTE: tokenization_kwargs are not required to init processor # NOTE: tokenization_kwargs are not required to init processor
...@@ -1901,8 +1896,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -1901,8 +1896,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Optional[Mapping[str, object]] = None, tokenization_kwargs: Optional[Mapping[str, object]] = None,
*, *,
mm_hash_overrides: Optional[Union[dict[str, list[str]], mm_uuids: Optional[MultiModalUUIDDict] = None,
MultiModalUUIDDict]] = None,
) -> MultiModalEncDecInputs: ) -> MultiModalEncDecInputs:
""" """
Process multi-modal inputs to be used in vLLM. Process multi-modal inputs to be used in vLLM.
...@@ -1917,7 +1911,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -1917,7 +1911,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
mm_data, mm_data,
hf_processor_mm_kwargs, hf_processor_mm_kwargs,
tokenization_kwargs, tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
return self._get_enc_dec_inputs( return self._get_enc_dec_inputs(
......
...@@ -12,7 +12,7 @@ from vllm.inputs.preprocess import InputPreprocessor ...@@ -12,7 +12,7 @@ from vllm.inputs.preprocess import InputPreprocessor
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import processor_cache_from_config from vllm.multimodal.cache import processor_cache_from_config
from vllm.multimodal.inputs import MultiModalFeatureSpec from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict
from vllm.multimodal.processing import EncDecMultiModalProcessor from vllm.multimodal.processing import EncDecMultiModalProcessor
from vllm.multimodal.utils import argsort_mm_positions from vllm.multimodal.utils import argsort_mm_positions
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
...@@ -276,11 +276,11 @@ class Processor: ...@@ -276,11 +276,11 @@ class Processor:
# Remember that this backend was set automatically # Remember that this backend was set automatically
params.guided_decoding.backend_was_auto = True params.guided_decoding.backend_was_auto = True
def _maybe_build_mm_hash_overrides( def _maybe_build_mm_uuids(
self, self,
request_id: str, request_id: str,
prompt: PromptType, prompt: PromptType,
) -> Optional[dict[str, list[str]]]: ) -> Optional[MultiModalUUIDDict]:
"""Build per-item multimodal hash overrides when enabled. In this case, """Build per-item multimodal hash overrides when enabled. In this case,
multimodal data items are identified by their request id, modality and multimodal data items are identified by their request id, modality and
index rather than their content. index rather than their content.
...@@ -303,13 +303,13 @@ class Processor: ...@@ -303,13 +303,13 @@ class Processor:
if not mm_data: if not mm_data:
return None return None
overrides: dict[str, list[str]] = {} mm_uuids: MultiModalUUIDDict = {}
for modality, data in mm_data.items(): for modality, data in mm_data.items():
n = len(data) if isinstance(data, list) else 1 n = len(data) if isinstance(data, list) else 1
overrides[modality] = [ mm_uuids[modality] = [
f"{request_id}-{modality}-{i}" for i in range(n) f"{request_id}-{modality}-{i}" for i in range(n)
] ]
return overrides return mm_uuids
def process_inputs( def process_inputs(
self, self,
...@@ -351,16 +351,15 @@ class Processor: ...@@ -351,16 +351,15 @@ class Processor:
if (self.model_config.multimodal_config and if (self.model_config.multimodal_config and
self.model_config.multimodal_config.mm_processor_cache_gb == 0 self.model_config.multimodal_config.mm_processor_cache_gb == 0
and not self.cache_config.enable_prefix_caching): and not self.cache_config.enable_prefix_caching):
mm_hash_overrides = self._maybe_build_mm_hash_overrides( mm_uuids = self._maybe_build_mm_uuids(request_id, prompt)
request_id, prompt)
else: else:
# Otherwise, use user-provided uuids as multimodal hash overrides # Otherwise, use user-provided uuids as multimodal hash overrides
# if provided. # if provided.
self._validate_multi_modal_uuids(prompt) self._validate_multi_modal_uuids(prompt)
if isinstance(prompt, dict): if isinstance(prompt, dict):
mm_hash_overrides = prompt.get("multi_modal_uuids") mm_uuids = prompt.get("multi_modal_uuids")
else: else:
mm_hash_overrides = None mm_uuids = None
# Process inputs, which includes: # Process inputs, which includes:
# 1. Tokenize text prompt, with LoRA request if one exists. # 1. Tokenize text prompt, with LoRA request if one exists.
...@@ -370,7 +369,7 @@ class Processor: ...@@ -370,7 +369,7 @@ class Processor:
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides, mm_uuids=mm_uuids,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
current_platform.validate_request( current_platform.validate_request(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment