Unverified Commit d8cf819a authored by Kyle Sayers's avatar Kyle Sayers Committed by GitHub
Browse files

[Core] [Bugfix] [Multimodal] Fix multimodal profiling and generation for SFT/PTQed models (#20058)


Signed-off-by: default avatarKyle Sayers <kylesayrs@gmail.com>
parent 551ef163
......@@ -538,11 +538,13 @@ return a schema of the tensors outputted by the HF processor that are related to
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
processed_outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
)
image_patches = processed_outputs.get("image_patches")
......@@ -566,6 +568,11 @@ return a schema of the tensors outputted by the HF processor that are related to
Our [actual code](gh-file:vllm/model_executor/models/fuyu.py) has special handling
for text-only inputs to prevent unnecessary warnings from HF processor.
!!! note
The `_call_hf_processor` method specifies both `mm_kwargs` and `tok_kwargs` for
processing. `mm_kwargs` is used to both initialize and call the huggingface
processor, whereas `tok_kwargs` is only used to call the huggingface processor.
This lets us override [_get_mm_fields_config][vllm.multimodal.processing.BaseMultiModalProcessor._get_mm_fields_config] as follows:
```python
......
......@@ -1086,6 +1086,7 @@ def test_hf_processor_kwargs(model_id, call_kwargs, expected_kwargs):
prompt="",
mm_data={},
mm_kwargs=call_kwargs,
tok_kwargs={},
)
assert out_kwargs == expected_kwargs
......@@ -481,6 +481,13 @@ class LLM:
# Use default sampling params.
sampling_params = self.get_default_sampling_params()
tokenization_kwargs: dict[str, Any] = {}
truncate_prompt_tokens = None
if isinstance(sampling_params, SamplingParams):
truncate_prompt_tokens = sampling_params.truncate_prompt_tokens
_validate_truncation_size(self.llm_engine.model_config.max_model_len,
truncate_prompt_tokens, tokenization_kwargs)
self._validate_and_add_requests(
prompts=parsed_prompts,
params=sampling_params,
......@@ -488,6 +495,7 @@ class LLM:
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
guided_options=guided_options_request,
tokenization_kwargs=tokenization_kwargs,
priority=priority,
)
......
......@@ -171,6 +171,10 @@ def _validate_truncation_size(
tokenization_kwargs["truncation"] = True
tokenization_kwargs["max_length"] = truncate_prompt_tokens
else:
if tokenization_kwargs is not None:
tokenization_kwargs["truncation"] = False
return truncate_prompt_tokens
......
......@@ -265,7 +265,8 @@ class InputPreprocessor:
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
mm_processor_kwargs: Optional[Mapping[str, object]],
lora_request: Optional[LoRARequest],
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
) -> MultiModalInputs:
"""
......@@ -280,15 +281,19 @@ class InputPreprocessor:
if mm_processor_kwargs is None:
mm_processor_kwargs = {}
return mm_processor.apply(prompt, mm_data, mm_processor_kwargs,
return_mm_hashes)
return mm_processor.apply(prompt,
mm_data,
hf_processor_mm_kwargs=mm_processor_kwargs,
tokenization_kwargs=tokenization_kwargs,
return_mm_hashes=return_mm_hashes)
async def _process_multimodal_async(
self,
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
mm_processor_kwargs: Optional[Mapping[str, object]],
lora_request: Optional[LoRARequest],
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
) -> MultiModalInputs:
"""
......@@ -302,8 +307,11 @@ class InputPreprocessor:
if mm_processor_kwargs is None:
mm_processor_kwargs = {}
return mm_processor.apply(prompt, mm_data, mm_processor_kwargs,
return_mm_hashes)
return mm_processor.apply(prompt,
mm_data,
hf_processor_mm_kwargs=mm_processor_kwargs,
tokenization_kwargs=tokenization_kwargs,
return_mm_hashes=return_mm_hashes)
def _process_embeds(
self,
......@@ -338,6 +346,7 @@ class InputPreprocessor:
def _process_tokens(
self,
parsed_content: TokensPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
) -> Union[TokenInputs, MultiModalInputs]:
......@@ -350,6 +359,7 @@ class InputPreprocessor:
prompt_token_ids,
multi_modal_data,
parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)
......@@ -367,6 +377,7 @@ class InputPreprocessor:
async def _process_tokens_async(
self,
parsed_content: TokensPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
) -> Union[TokenInputs, MultiModalInputs]:
......@@ -379,6 +390,7 @@ class InputPreprocessor:
prompt_token_ids,
multi_modal_data,
parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)
......@@ -408,6 +420,7 @@ class InputPreprocessor:
prompt_text,
multi_modal_data,
parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)
......@@ -442,6 +455,7 @@ class InputPreprocessor:
prompt_text,
multi_modal_data,
parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)
......@@ -860,7 +874,8 @@ class InputPreprocessor:
"returned until they are supported on vLLM V1.")
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder
return self._process_encoder_decoder_prompt(prompt)
return self._process_encoder_decoder_prompt(
prompt, tokenization_kwargs)
if is_explicit_encoder_decoder_prompt(prompt):
raise ValueError("Cannot pass encoder-decoder prompt "
......
......@@ -185,11 +185,13 @@ class AyaVisionMultiModalProcessor(
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
processed_outputs = super()._call_hf_processor(
prompt,
mm_data,
mm_kwargs,
tok_kwargs,
)
hf_processor = self.info.get_hf_processor(**mm_kwargs)
image_processor = hf_processor.image_processor
......
......@@ -454,6 +454,7 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]):
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
if not mm_data:
# HF processor always adds placeholders even when there's no image
......@@ -465,6 +466,7 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]):
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
)
def _get_mm_fields_config(
......
......@@ -107,6 +107,7 @@ class ChameleonMultiModalProcessor(
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
if not mm_data:
prompt_ids = self.info.get_tokenizer().encode(prompt)
......@@ -117,6 +118,7 @@ class ChameleonMultiModalProcessor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
)
def _apply_hf_processor_tokens_only(
......
......@@ -204,12 +204,13 @@ class DeepseekVL2MultiModalProcessor(
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
if mm_data:
processed_outputs = self.info.ctx.call_hf_processor(
self.info.get_hf_processor(**mm_kwargs),
dict(prompt=prompt, **mm_data),
mm_kwargs,
dict(**mm_kwargs, **tok_kwargs),
)
pixel_values = processed_outputs["pixel_values"]
# split pixel values into patches corresponding to each image
......@@ -278,6 +279,7 @@ class DeepseekVL2MultiModalProcessor(
prompt: Union[str, list[int]],
mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
*,
return_mm_hashes: bool,
) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
......@@ -290,6 +292,7 @@ class DeepseekVL2MultiModalProcessor(
prompt=prompt,
mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
return_mm_hashes=return_mm_hashes,
)
......@@ -297,6 +300,7 @@ class DeepseekVL2MultiModalProcessor(
prompt=prompt,
mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
return_mm_hashes=return_mm_hashes,
)
......
......@@ -794,6 +794,7 @@ class Florence2MultiModalProcessor(
prompt_text: str,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
) -> bool:
return False
......@@ -828,10 +829,11 @@ class Florence2MultiModalProcessor(
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
if mm_data:
processed_outputs = super()._call_hf_processor(
prompt, mm_data, mm_kwargs)
prompt, mm_data, mm_kwargs, tok_kwargs)
else:
hf_processor = self.info.get_hf_processor()
tokenizer = hf_processor.tokenizer
......
......@@ -153,6 +153,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
if not mm_data:
# Avoid warning from HF logger for text-only input
......@@ -164,6 +165,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
)
image_patches = processed_outputs.get("image_patches")
......
......@@ -259,11 +259,13 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
processed_outputs = super()._call_hf_processor(
prompt,
mm_data,
mm_kwargs,
tok_kwargs,
)
# HF processor pops the `num_crops` kwarg, which is needed by vLLM
......
......@@ -481,6 +481,7 @@ class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]):
prompt_text: str,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
) -> bool:
return False
......
......@@ -141,6 +141,7 @@ class GraniteSpeechMultiModalProcessor(
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
mm_data = dict(mm_data)
audios = mm_data.pop("audios", [])
......@@ -153,6 +154,7 @@ class GraniteSpeechMultiModalProcessor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
)
if "audio" in mm_data:
......
......@@ -490,6 +490,7 @@ class H2OVLMultiModalProcessor(
prompt: Union[str, list[int]],
mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
*,
return_mm_hashes: bool,
) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
......@@ -502,6 +503,7 @@ class H2OVLMultiModalProcessor(
prompt=prompt,
mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
return_mm_hashes=return_mm_hashes,
)
......@@ -509,6 +511,7 @@ class H2OVLMultiModalProcessor(
prompt=prompt,
mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
return_mm_hashes=return_mm_hashes,
)
......
......@@ -326,6 +326,7 @@ class Idefics3MultiModalProcessor(
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
# Text-only input not supported in composite processor
if not (images := mm_data.get("images", [])):
......@@ -337,6 +338,7 @@ class Idefics3MultiModalProcessor(
prompt,
mm_data,
mm_kwargs,
tok_kwargs,
)
parsed_images = (self._get_data_parser().parse_mm_data({
......
......@@ -758,11 +758,13 @@ class BaseInternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> Mapping[str, NestedTensors]:
processed_outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
)
hf_processor = self.info.get_hf_processor(**mm_kwargs)
......@@ -941,9 +943,10 @@ class InternVLMultiModalProcessor(
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> Mapping[str, NestedTensors]:
processed_outputs = super()._call_hf_processor(prompt, mm_data,
mm_kwargs)
mm_kwargs, tok_kwargs)
hf_processor = self.info.get_hf_processor(**mm_kwargs)
if self.info.supports_video and (
......
......@@ -296,11 +296,13 @@ class PixtralHFMultiModalProcessor(
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
processed_outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
)
pixel_values = processed_outputs.get("pixel_values")
......@@ -797,6 +799,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Optional[Mapping[str, object]] = None,
return_mm_hashes: bool = False,
) -> MultiModalInputs:
hf_config = self.info.get_hf_config()
......@@ -809,7 +812,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
)
result = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
return_mm_hashes)
tokenization_kwargs, return_mm_hashes)
mm_items = self._to_mm_items(mm_data)
mm_item_counts = mm_items.get_all_counts()
......
......@@ -286,6 +286,7 @@ class LlavaOnevisionMultiModalProcessor(
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
mm_data = dict(mm_data)
videos = mm_data.pop("videos", [])
......@@ -296,6 +297,7 @@ class LlavaOnevisionMultiModalProcessor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
)
# LLaVA-OneVision processor doesn't support multiple videos
......@@ -310,6 +312,7 @@ class LlavaOnevisionMultiModalProcessor(
prompt=prompt,
mm_data={},
mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
)
images = mm_data.pop("images", [])
......@@ -319,6 +322,7 @@ class LlavaOnevisionMultiModalProcessor(
prompt=image_token * len(images),
mm_data={"images": images},
mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
)
image_outputs = {
k: v
......@@ -334,6 +338,7 @@ class LlavaOnevisionMultiModalProcessor(
prompt=video_token,
mm_data={"videos": video},
mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
)
pixel_values_videos.append(item_outputs["pixel_values_videos"][0])
......@@ -352,11 +357,13 @@ class LlavaOnevisionMultiModalProcessor(
prompt_text: str,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
) -> bool:
base_result = super()._hf_processor_applies_updates(
prompt_text=prompt_text,
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
return base_result and mm_items.get_count("video", strict=False) == 0
......
......@@ -260,6 +260,7 @@ class MiniCPMOMultiModalProcessor(
self,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> Mapping[str, NestedTensors]:
if (audios := mm_data.get("audios")) is None:
return {}
......@@ -276,9 +277,9 @@ class MiniCPMOMultiModalProcessor(
prompts=[self.info.audio_pattern] * len(parsed_audios),
mm_data={"audios": [[audio] for audio in parsed_audios]},
mm_kwargs={
**mm_kwargs,
"chunk_input": True,
**mm_kwargs, "chunk_input": True
},
tok_kwargs=tok_kwargs,
out_keys={"audio_features", "audio_feature_lens"},
)
......@@ -302,10 +303,11 @@ class MiniCPMOMultiModalProcessor(
self,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> Mapping[str, NestedTensors]:
return {
**super().process_mm_inputs(mm_data, mm_kwargs),
**self.process_audios(mm_data, mm_kwargs),
**super().process_mm_inputs(mm_data, mm_kwargs, tok_kwargs),
**self.process_audios(mm_data, mm_kwargs, tok_kwargs),
}
def _get_prompt_updates(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment