Unverified Commit 8bf6266a authored by Roger Wang's avatar Roger Wang Committed by GitHub
Browse files

[Multimodal] Generate mm_hash based on request metadata when caching is turned off (#23690)


Signed-off-by: default avatarRoger Wang <hey@rogerw.io>
parent 0585a9e7
...@@ -257,6 +257,8 @@ class InputPreprocessor: ...@@ -257,6 +257,8 @@ class InputPreprocessor:
mm_processor_kwargs: Optional[Mapping[str, object]], mm_processor_kwargs: Optional[Mapping[str, object]],
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
""" """
Apply the model's multi-modal processor to a multi-modal prompt, Apply the model's multi-modal processor to a multi-modal prompt,
...@@ -273,10 +275,13 @@ class InputPreprocessor: ...@@ -273,10 +275,13 @@ class InputPreprocessor:
if mm_processor_kwargs is None: if mm_processor_kwargs is None:
mm_processor_kwargs = {} mm_processor_kwargs = {}
return mm_processor.apply(prompt, return mm_processor.apply(
prompt,
mm_data, mm_data,
hf_processor_mm_kwargs=mm_processor_kwargs, hf_processor_mm_kwargs=mm_processor_kwargs,
tokenization_kwargs=tokenization_kwargs) tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
)
async def _process_multimodal_async( async def _process_multimodal_async(
self, self,
...@@ -285,6 +290,8 @@ class InputPreprocessor: ...@@ -285,6 +290,8 @@ class InputPreprocessor:
mm_processor_kwargs: Optional[Mapping[str, object]], mm_processor_kwargs: Optional[Mapping[str, object]],
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
""" """
Async version of Async version of
...@@ -301,10 +308,13 @@ class InputPreprocessor: ...@@ -301,10 +308,13 @@ class InputPreprocessor:
if mm_processor_kwargs is None: if mm_processor_kwargs is None:
mm_processor_kwargs = {} mm_processor_kwargs = {}
return mm_processor.apply(prompt, return mm_processor.apply(
prompt,
mm_data, mm_data,
hf_processor_mm_kwargs=mm_processor_kwargs, hf_processor_mm_kwargs=mm_processor_kwargs,
tokenization_kwargs=tokenization_kwargs) tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
)
def _process_embeds( def _process_embeds(
self, self,
...@@ -341,6 +351,8 @@ class InputPreprocessor: ...@@ -341,6 +351,8 @@ class InputPreprocessor:
parsed_content: TokensPrompt, parsed_content: TokensPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> Union[TokenInputs, MultiModalInputs]: ) -> Union[TokenInputs, MultiModalInputs]:
prompt_token_ids = parsed_content["prompt_token_ids"] prompt_token_ids = parsed_content["prompt_token_ids"]
token_type_ids = parsed_content.get("token_type_ids") token_type_ids = parsed_content.get("token_type_ids")
...@@ -353,6 +365,7 @@ class InputPreprocessor: ...@@ -353,6 +365,7 @@ class InputPreprocessor:
parsed_content.get("mm_processor_kwargs"), parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
) )
else: else:
inputs = token_inputs( inputs = token_inputs(
...@@ -370,6 +383,8 @@ class InputPreprocessor: ...@@ -370,6 +383,8 @@ class InputPreprocessor:
parsed_content: TokensPrompt, parsed_content: TokensPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> Union[TokenInputs, MultiModalInputs]: ) -> Union[TokenInputs, MultiModalInputs]:
prompt_token_ids = parsed_content["prompt_token_ids"] prompt_token_ids = parsed_content["prompt_token_ids"]
token_type_ids = parsed_content.get("token_type_ids") token_type_ids = parsed_content.get("token_type_ids")
...@@ -382,6 +397,7 @@ class InputPreprocessor: ...@@ -382,6 +397,7 @@ class InputPreprocessor:
parsed_content.get("mm_processor_kwargs"), parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
) )
else: else:
inputs = token_inputs( inputs = token_inputs(
...@@ -399,6 +415,8 @@ class InputPreprocessor: ...@@ -399,6 +415,8 @@ class InputPreprocessor:
parsed_content: TextPrompt, parsed_content: TextPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> Union[TokenInputs, MultiModalInputs]: ) -> Union[TokenInputs, MultiModalInputs]:
prompt_text = parsed_content["prompt"] prompt_text = parsed_content["prompt"]
...@@ -410,6 +428,7 @@ class InputPreprocessor: ...@@ -410,6 +428,7 @@ class InputPreprocessor:
parsed_content.get("mm_processor_kwargs"), parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
) )
else: else:
prompt_token_ids = self._tokenize_prompt( prompt_token_ids = self._tokenize_prompt(
...@@ -432,6 +451,8 @@ class InputPreprocessor: ...@@ -432,6 +451,8 @@ class InputPreprocessor:
parsed_content: TextPrompt, parsed_content: TextPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> Union[TokenInputs, MultiModalInputs]: ) -> Union[TokenInputs, MultiModalInputs]:
prompt_text = parsed_content["prompt"] prompt_text = parsed_content["prompt"]
...@@ -443,6 +464,7 @@ class InputPreprocessor: ...@@ -443,6 +464,7 @@ class InputPreprocessor:
parsed_content.get("mm_processor_kwargs"), parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
) )
else: else:
prompt_token_ids = await self._tokenize_prompt_async( prompt_token_ids = await self._tokenize_prompt_async(
...@@ -465,6 +487,8 @@ class InputPreprocessor: ...@@ -465,6 +487,8 @@ class InputPreprocessor:
prompt: SingletonPrompt, prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> SingletonInputs: ) -> SingletonInputs:
""" """
Extract the singleton inputs from a prompt. Extract the singleton inputs from a prompt.
...@@ -486,18 +510,21 @@ class InputPreprocessor: ...@@ -486,18 +510,21 @@ class InputPreprocessor:
return self._process_tokens( return self._process_tokens(
parsed["content"], parsed["content"],
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
) )
if parsed["type"] == "text": if parsed["type"] == "text":
return self._process_text( return self._process_text(
parsed["content"], parsed["content"],
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
) )
if parsed["type"] == "str": if parsed["type"] == "str":
return self._process_text( return self._process_text(
TextPrompt(prompt=parsed["content"]), TextPrompt(prompt=parsed["content"]),
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
) )
assert_never(parsed) assert_never(parsed)
...@@ -507,6 +534,8 @@ class InputPreprocessor: ...@@ -507,6 +534,8 @@ class InputPreprocessor:
prompt: SingletonPrompt, prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> SingletonInputs: ) -> SingletonInputs:
""" """
Async version of Async version of
...@@ -520,18 +549,21 @@ class InputPreprocessor: ...@@ -520,18 +549,21 @@ class InputPreprocessor:
return await self._process_tokens_async( return await self._process_tokens_async(
parsed["content"], parsed["content"],
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
) )
if parsed["type"] == "text": if parsed["type"] == "text":
return await self._process_text_async( return await self._process_text_async(
parsed["content"], parsed["content"],
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
) )
if parsed["type"] == "str": if parsed["type"] == "str":
return await self._process_text_async( return await self._process_text_async(
TextPrompt(prompt=parsed["content"]), TextPrompt(prompt=parsed["content"]),
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
) )
assert_never(parsed) assert_never(parsed)
...@@ -641,6 +673,8 @@ class InputPreprocessor: ...@@ -641,6 +673,8 @@ class InputPreprocessor:
self, self,
prompt: PromptType, prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> EncoderDecoderInputs: ) -> EncoderDecoderInputs:
""" """
For encoder/decoder models only: For encoder/decoder models only:
...@@ -682,6 +716,7 @@ class InputPreprocessor: ...@@ -682,6 +716,7 @@ class InputPreprocessor:
encoder_inputs = self._prompt_to_llm_inputs( encoder_inputs = self._prompt_to_llm_inputs(
prompt["encoder_prompt"], prompt["encoder_prompt"],
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
) )
if (decoder_input := prompt["decoder_prompt"]) is None: if (decoder_input := prompt["decoder_prompt"]) is None:
decoder_inputs = None decoder_inputs = None
...@@ -697,6 +732,7 @@ class InputPreprocessor: ...@@ -697,6 +732,7 @@ class InputPreprocessor:
inputs = self._prompt_to_llm_inputs( inputs = self._prompt_to_llm_inputs(
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
) )
if self.model_config.is_multimodal_model: if self.model_config.is_multimodal_model:
# Encoder-Decoder Multimodal model # Encoder-Decoder Multimodal model
...@@ -712,6 +748,8 @@ class InputPreprocessor: ...@@ -712,6 +748,8 @@ class InputPreprocessor:
self, self,
prompt: PromptType, prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> EncoderDecoderInputs: ) -> EncoderDecoderInputs:
""" """
Async version of Async version of
...@@ -724,6 +762,7 @@ class InputPreprocessor: ...@@ -724,6 +762,7 @@ class InputPreprocessor:
encoder_task = self._prompt_to_llm_inputs_async( encoder_task = self._prompt_to_llm_inputs_async(
prompt["encoder_prompt"], prompt["encoder_prompt"],
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
) )
if (decoder_input := prompt["decoder_prompt"]) is None: if (decoder_input := prompt["decoder_prompt"]) is None:
...@@ -733,6 +772,7 @@ class InputPreprocessor: ...@@ -733,6 +772,7 @@ class InputPreprocessor:
decoder_task = self._prompt_to_llm_inputs_async( decoder_task = self._prompt_to_llm_inputs_async(
decoder_input, decoder_input,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
) )
encoder_inputs, decoder_inputs = await asyncio.gather( encoder_inputs, decoder_inputs = await asyncio.gather(
...@@ -748,6 +788,7 @@ class InputPreprocessor: ...@@ -748,6 +788,7 @@ class InputPreprocessor:
inputs = await self._prompt_to_llm_inputs_async( inputs = await self._prompt_to_llm_inputs_async(
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
) )
if self.model_config.is_multimodal_model: if self.model_config.is_multimodal_model:
# Encoder-Decoder Multimodal model # Encoder-Decoder Multimodal model
...@@ -774,6 +815,8 @@ class InputPreprocessor: ...@@ -774,6 +815,8 @@ class InputPreprocessor:
prompt: SingletonPrompt, prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> DecoderOnlyInputs: ) -> DecoderOnlyInputs:
""" """
For decoder-only models: For decoder-only models:
...@@ -794,6 +837,7 @@ class InputPreprocessor: ...@@ -794,6 +837,7 @@ class InputPreprocessor:
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
) )
return self._build_decoder_only_llm_inputs(prompt_comps) return self._build_decoder_only_llm_inputs(prompt_comps)
...@@ -803,6 +847,8 @@ class InputPreprocessor: ...@@ -803,6 +847,8 @@ class InputPreprocessor:
prompt: SingletonPrompt, prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> DecoderOnlyInputs: ) -> DecoderOnlyInputs:
""" """
Async version of Async version of
...@@ -812,6 +858,7 @@ class InputPreprocessor: ...@@ -812,6 +858,7 @@ class InputPreprocessor:
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
) )
return self._build_decoder_only_llm_inputs(prompt_comps) return self._build_decoder_only_llm_inputs(prompt_comps)
...@@ -821,6 +868,8 @@ class InputPreprocessor: ...@@ -821,6 +868,8 @@ class InputPreprocessor:
prompt: PromptType, prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> ProcessorInputs: ) -> ProcessorInputs:
"""Preprocess the input prompt.""" """Preprocess the input prompt."""
if self.model_config.is_encoder_decoder: if self.model_config.is_encoder_decoder:
...@@ -829,6 +878,7 @@ class InputPreprocessor: ...@@ -829,6 +878,7 @@ class InputPreprocessor:
return self._process_encoder_decoder_prompt( return self._process_encoder_decoder_prompt(
prompt, prompt,
tokenization_kwargs, tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
) )
if is_explicit_encoder_decoder_prompt(prompt): if is_explicit_encoder_decoder_prompt(prompt):
...@@ -840,6 +890,7 @@ class InputPreprocessor: ...@@ -840,6 +890,7 @@ class InputPreprocessor:
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
) )
async def preprocess_async( async def preprocess_async(
...@@ -847,6 +898,8 @@ class InputPreprocessor: ...@@ -847,6 +898,8 @@ class InputPreprocessor:
prompt: PromptType, prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None, tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> ProcessorInputs: ) -> ProcessorInputs:
""" """
Async version of Async version of
...@@ -858,6 +911,7 @@ class InputPreprocessor: ...@@ -858,6 +911,7 @@ class InputPreprocessor:
return await self._process_encoder_decoder_prompt_async( return await self._process_encoder_decoder_prompt_async(
prompt, prompt,
tokenization_kwargs, tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
) )
if is_explicit_encoder_decoder_prompt(prompt): if is_explicit_encoder_decoder_prompt(prompt):
...@@ -869,6 +923,7 @@ class InputPreprocessor: ...@@ -869,6 +923,7 @@ class InputPreprocessor:
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
) )
def clear_cache(self) -> None: def clear_cache(self) -> None:
......
...@@ -290,6 +290,7 @@ class DeepseekVL2MultiModalProcessor( ...@@ -290,6 +290,7 @@ class DeepseekVL2MultiModalProcessor(
mm_data_items: MultiModalDataItems, mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object],
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> tuple[list[int], MultiModalProcessingInfo, bool]: ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
# The processor logic is different for len(images) <= 2 vs > 2 # The processor logic is different for len(images) <= 2 vs > 2
# Since the processing cache assumes that the processor output is # Since the processing cache assumes that the processor output is
...@@ -301,6 +302,7 @@ class DeepseekVL2MultiModalProcessor( ...@@ -301,6 +302,7 @@ class DeepseekVL2MultiModalProcessor(
mm_data_items=mm_data_items, mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
) )
return super()._cached_apply_hf_processor( return super()._cached_apply_hf_processor(
...@@ -308,6 +310,7 @@ class DeepseekVL2MultiModalProcessor( ...@@ -308,6 +310,7 @@ class DeepseekVL2MultiModalProcessor(
mm_data_items=mm_data_items, mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
) )
......
...@@ -479,6 +479,7 @@ class H2OVLMultiModalProcessor( ...@@ -479,6 +479,7 @@ class H2OVLMultiModalProcessor(
mm_data_items: MultiModalDataItems, mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object],
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> tuple[list[int], MultiModalProcessingInfo, bool]: ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
# The processor logic is different for len(images) <= 1 vs > 1 # The processor logic is different for len(images) <= 1 vs > 1
# Since the processing cache assumes that the processor output is # Since the processing cache assumes that the processor output is
...@@ -490,6 +491,7 @@ class H2OVLMultiModalProcessor( ...@@ -490,6 +491,7 @@ class H2OVLMultiModalProcessor(
mm_data_items=mm_data_items, mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
) )
return super()._cached_apply_hf_processor( return super()._cached_apply_hf_processor(
...@@ -497,6 +499,7 @@ class H2OVLMultiModalProcessor( ...@@ -497,6 +499,7 @@ class H2OVLMultiModalProcessor(
mm_data_items=mm_data_items, mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
) )
......
...@@ -795,6 +795,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): ...@@ -795,6 +795,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
mm_data: MultiModalDataDict, mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Optional[Mapping[str, object]] = None, tokenization_kwargs: Optional[Mapping[str, object]] = None,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
hf_config = self.info.get_hf_config() hf_config = self.info.get_hf_config()
image_token_id = hf_config.image_token_index image_token_id = hf_config.image_token_index
...@@ -805,8 +806,11 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): ...@@ -805,8 +806,11 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
image_height=-1, image_height=-1,
) )
result = super().apply(prompt, mm_data, hf_processor_mm_kwargs, result = super().apply(prompt,
tokenization_kwargs) mm_data,
hf_processor_mm_kwargs,
tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides)
mm_items = self._to_mm_items(mm_data) mm_items = self._to_mm_items(mm_data)
mm_item_counts = mm_items.get_all_counts() mm_item_counts = mm_items.get_all_counts()
......
...@@ -184,9 +184,13 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo] ...@@ -184,9 +184,13 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo]
mm_data: MultiModalDataDict, mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Optional[Mapping[str, object]] = None, tokenization_kwargs: Optional[Mapping[str, object]] = None,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> MultiModalEncDecInputs: ) -> MultiModalEncDecInputs:
mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs, mm_inputs = super().apply(prompt,
tokenization_kwargs) mm_data,
hf_processor_mm_kwargs,
tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides)
image_token_id = self.info.get_hf_config().image_token_index image_token_id = self.info.get_hf_config().image_token_index
# Check that the number of image tokens in the decoder prompt matches # Check that the number of image tokens in the decoder prompt matches
......
...@@ -203,9 +203,13 @@ class PaliGemmaMultiModalProcessor( ...@@ -203,9 +203,13 @@ class PaliGemmaMultiModalProcessor(
mm_data: MultiModalDataDict, mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Optional[Mapping[str, object]] = None, tokenization_kwargs: Optional[Mapping[str, object]] = None,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs, mm_inputs = super().apply(prompt,
tokenization_kwargs) mm_data,
hf_processor_mm_kwargs,
tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides)
prompt_token_ids = mm_inputs["prompt_token_ids"] prompt_token_ids = mm_inputs["prompt_token_ids"]
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
......
...@@ -314,12 +314,14 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo] ...@@ -314,12 +314,14 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
mm_data_items: MultiModalDataItems, mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object],
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> tuple[list[int], MultiModalProcessingInfo, bool]: ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
prompt_ids, mm_info, _ = super()._cached_apply_hf_processor( prompt_ids, mm_info, _ = super()._cached_apply_hf_processor(
prompt=prompt, prompt=prompt,
mm_data_items=mm_data_items, mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
) )
# NOTE: The tokens are already inserted by the chat template # NOTE: The tokens are already inserted by the chat template
......
...@@ -138,6 +138,7 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor): ...@@ -138,6 +138,7 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
mm_data: MultiModalDataDict, mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Optional[Mapping[str, object]] = None, tokenization_kwargs: Optional[Mapping[str, object]] = None,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
if "image" in mm_data: if "image" in mm_data:
image_data = mm_data["image"] image_data = mm_data["image"]
...@@ -146,8 +147,10 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor): ...@@ -146,8 +147,10 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
mm_data = {"image": mm_data} mm_data = {"image": mm_data}
mm_items = self._to_mm_items(mm_data) mm_items = self._to_mm_items(mm_data)
mm_hashes = self._hash_mm_items(mm_items, hf_processor_mm_kwargs, tokenization_kwargs = tokenization_kwargs or {}
tokenization_kwargs or {}) mm_hashes = (mm_hash_overrides if mm_hash_overrides is not None else
self._hash_mm_items(mm_items, hf_processor_mm_kwargs,
tokenization_kwargs))
mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]} mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]}
mm_processed_data = BatchFeature(image_data) mm_processed_data = BatchFeature(image_data)
......
...@@ -327,6 +327,7 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]): ...@@ -327,6 +327,7 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
mm_data: MultiModalDataDict, mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Optional[Mapping[str, object]] = None, tokenization_kwargs: Optional[Mapping[str, object]] = None,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
""" """
Process multi-modal inputs to be used in vLLM. Process multi-modal inputs to be used in vLLM.
...@@ -393,9 +394,11 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]): ...@@ -393,9 +394,11 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs, self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs,
num_image_patches), num_image_patches),
) )
# Use overrides if provided; fallback to data-dependent hashing.
mm_hashes = (mm_hash_overrides if mm_hash_overrides is not None else
self._hash_mm_items(mm_items, hf_processor_mm_kwargs,
tokenization_kwargs))
mm_hashes = self._hash_mm_items(mm_items, hf_processor_mm_kwargs,
tokenization_kwargs)
return MultiModalInputs( return MultiModalInputs(
type="multimodal", type="multimodal",
prompt=prompt, prompt=prompt,
......
...@@ -288,12 +288,14 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo] ...@@ -288,12 +288,14 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]
mm_data_items: MultiModalDataItems, mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object],
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> tuple[list[int], MultiModalProcessingInfo, bool]: ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
prompt_ids, mm_info, _ = super()._cached_apply_hf_processor( prompt_ids, mm_info, _ = super()._cached_apply_hf_processor(
prompt=prompt, prompt=prompt,
mm_data_items=mm_data_items, mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
) )
# NOTE: The tokens are already inserted by the chat template # NOTE: The tokens are already inserted by the chat template
......
...@@ -1020,8 +1020,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1020,8 +1020,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
prompt: str, prompt: str,
mm_data: MultiModalDataDict, mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
*,
mm_hash_overrides: Optional[MultiModalHashes] = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
return self.apply(prompt, mm_data, hf_processor_mm_kwargs) return self.apply(prompt,
mm_data,
hf_processor_mm_kwargs,
mm_hash_overrides=mm_hash_overrides)
def _get_data_parser(self) -> MultiModalDataParser: def _get_data_parser(self) -> MultiModalDataParser:
""" """
...@@ -1357,7 +1362,11 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1357,7 +1362,11 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object],
) -> MultiModalHashes: ) -> MultiModalHashes:
"""Create MM hashes to be returned (only used in V1).""" """Create MM hashes to be returned (only used in V1).
Note: When overrides are provided via callers of `apply`,
`_hash_mm_items` will be bypassed and the overrides will be used.
"""
model_id = self.info.model_id model_id = self.info.model_id
return { return {
...@@ -1464,6 +1473,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1464,6 +1473,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_data_items: MultiModalDataItems, mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object],
*,
mm_hash_overrides: Optional[MultiModalHashes] = None,
) -> tuple[list[int], MultiModalProcessingInfo, bool]: ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
( (
prompt_ids, prompt_ids,
...@@ -1483,8 +1494,10 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1483,8 +1494,10 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
hf_processor_mm_kwargs), hf_processor_mm_kwargs),
) )
mm_hashes = self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs, # Use overrides if provided; fallback to data-dependent hashing.
tokenization_kwargs) mm_hashes = (mm_hash_overrides if mm_hash_overrides is not None else
self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs,
tokenization_kwargs))
mm_prompt_updates = self._get_mm_prompt_updates( mm_prompt_updates = self._get_mm_prompt_updates(
mm_data_items, mm_data_items,
...@@ -1506,6 +1519,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1506,6 +1519,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_data_items: MultiModalDataItems, mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object],
*,
mm_hash_overrides: Optional[MultiModalHashes] = None,
) -> tuple[list[int], MultiModalProcessingInfo, bool]: ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
""" """
Apply the HF processor on the full prompt text, Apply the HF processor on the full prompt text,
...@@ -1520,10 +1535,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1520,10 +1535,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_data_items=mm_data_items, mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
) )
mm_hashes = self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs, # Use overrides if provided; fallback to data-dependent hashing.
tokenization_kwargs) mm_hashes = (mm_hash_overrides if mm_hash_overrides is not None else
self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs,
tokenization_kwargs))
mm_missing_data_items = self._get_cache_missing_items( mm_missing_data_items = self._get_cache_missing_items(
cache=cache, cache=cache,
...@@ -1723,6 +1741,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1723,6 +1741,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_data: MultiModalDataDict, mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Optional[Mapping[str, object]] = None, tokenization_kwargs: Optional[Mapping[str, object]] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
""" """
Process multi-modal inputs to be used in vLLM. Process multi-modal inputs to be used in vLLM.
...@@ -1751,6 +1771,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1751,6 +1771,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_items, mm_items,
hf_processor_mm_kwargs, hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
) )
# NOTE: tokenization_kwargs are not required to init processor # NOTE: tokenization_kwargs are not required to init processor
...@@ -1835,6 +1856,8 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -1835,6 +1856,8 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
mm_data: MultiModalDataDict, mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Optional[Mapping[str, object]] = None, tokenization_kwargs: Optional[Mapping[str, object]] = None,
*,
mm_hash_overrides: Optional[MultiModalHashes] = None,
) -> MultiModalEncDecInputs: ) -> MultiModalEncDecInputs:
""" """
Process multi-modal inputs to be used in vLLM. Process multi-modal inputs to be used in vLLM.
...@@ -1849,6 +1872,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -1849,6 +1872,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
mm_data, mm_data,
hf_processor_mm_kwargs, hf_processor_mm_kwargs,
tokenization_kwargs, tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides,
) )
return self._get_enc_dec_inputs( return self._get_enc_dec_inputs(
......
...@@ -225,6 +225,41 @@ class Processor: ...@@ -225,6 +225,41 @@ class Processor:
# Remember that this backend was set automatically # Remember that this backend was set automatically
params.guided_decoding.backend_was_auto = True params.guided_decoding.backend_was_auto = True
def _maybe_build_mm_hash_overrides(
self,
request_id: str,
prompt: PromptType,
) -> Optional[dict[str, list[str]]]:
"""Build per-item multimodal hash overrides when enabled. In this case,
multimodal data items are identified by their request id, modality and
index rather than their content.
Returns a dictionary of modality -> list[str] of overrides, or None if
disabled or no multimodal data is present.
"""
def _extract_mm_data(p: PromptType):
if isinstance(p, dict) and "encoder_prompt" in p:
enc = p.get("encoder_prompt")
if isinstance(enc, dict):
return enc.get("multi_modal_data")
return None
if isinstance(p, dict):
return p.get("multi_modal_data")
return None
mm_data = _extract_mm_data(prompt)
if not mm_data:
return None
overrides: dict[str, list[str]] = {}
for modality, data in mm_data.items():
n = len(data) if isinstance(data, list) else 1
overrides[modality] = [
f"{request_id}-{modality}-{i}" for i in range(n)
]
return overrides
def process_inputs( def process_inputs(
self, self,
request_id: str, request_id: str,
...@@ -254,6 +289,18 @@ class Processor: ...@@ -254,6 +289,18 @@ class Processor:
if arrival_time is None: if arrival_time is None:
arrival_time = time.time() arrival_time = time.time()
# Optionally generate multimodal hash overrides based on request id.
# NOTE: when users explicitly turn off BOTH prefix caching and input
# processing caching, no multimodal features or embeddings will be
# reused across requests, therefore hashing is no longer necessary.
if (self.model_config.multimodal_config and
self.model_config.multimodal_config.mm_processor_cache_gb == 0
and not self.cache_config.enable_prefix_caching):
mm_hash_overrides = self._maybe_build_mm_hash_overrides(
request_id, prompt)
else:
mm_hash_overrides = None
# Process inputs, which includes: # Process inputs, which includes:
# 1. Tokenize text prompt, with LoRA request if one exists. # 1. Tokenize text prompt, with LoRA request if one exists.
# 2. For multimodal models with a merged preprocessor, preprocess # 2. For multimodal models with a merged preprocessor, preprocess
...@@ -262,6 +309,7 @@ class Processor: ...@@ -262,6 +309,7 @@ class Processor:
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
mm_hash_overrides=mm_hash_overrides,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
current_platform.validate_request( current_platform.validate_request(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment