Unverified commit d8cf819a, authored by Kyle Sayers, committed by GitHub
Browse files

[Core] [Bugfix] [Multimodal] Fix multimodal profiling and generation for SFT/PTQed models (#20058)


Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
parent 551ef163
...@@ -534,6 +534,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -534,6 +534,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
self, self,
mm_data: Mapping[str, object], mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object], mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> Mapping[str, NestedTensors]: ) -> Mapping[str, NestedTensors]:
if (images := mm_data.get("images")) is None: if (images := mm_data.get("images")) is None:
return {} return {}
...@@ -550,6 +551,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -550,6 +551,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
prompts=[self.info.image_pattern] * len(parsed_images), prompts=[self.info.image_pattern] * len(parsed_images),
mm_data={"images": [[image] for image in parsed_images]}, mm_data={"images": [[image] for image in parsed_images]},
mm_kwargs=mm_kwargs, mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
out_keys={"pixel_values", "image_sizes", "tgt_sizes"}, out_keys={"pixel_values", "image_sizes", "tgt_sizes"},
) )
...@@ -563,6 +565,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -563,6 +565,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
self, self,
mm_data: Mapping[str, object], mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object], mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> Mapping[str, NestedTensors]: ) -> Mapping[str, NestedTensors]:
if (videos := mm_data.get("videos")) is None: if (videos := mm_data.get("videos")) is None:
return {} return {}
...@@ -586,6 +589,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -586,6 +589,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
"max_slice_nums": "max_slice_nums":
self.info.get_video_max_slice_num(), self.info.get_video_max_slice_num(),
}, },
tok_kwargs=tok_kwargs,
out_keys={"pixel_values", "image_sizes", "tgt_sizes"}, out_keys={"pixel_values", "image_sizes", "tgt_sizes"},
) )
...@@ -601,10 +605,11 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -601,10 +605,11 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
self, self,
mm_data: Mapping[str, object], mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object], mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> Mapping[str, NestedTensors]: ) -> Mapping[str, NestedTensors]:
return { return {
**self.process_images(mm_data, mm_kwargs), **self.process_images(mm_data, mm_kwargs, tok_kwargs),
**self.process_videos(mm_data, mm_kwargs), **self.process_videos(mm_data, mm_kwargs, tok_kwargs),
} }
def _base_call_hf_processor( def _base_call_hf_processor(
...@@ -612,6 +617,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -612,6 +617,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
prompts: list[str], prompts: list[str],
mm_data: Mapping[str, Sequence[object]], mm_data: Mapping[str, Sequence[object]],
mm_kwargs: Mapping[str, object], mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
*, *,
out_keys: set[str], out_keys: set[str],
) -> dict[str, NestedTensors]: ) -> dict[str, NestedTensors]:
...@@ -621,6 +627,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -621,6 +627,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
prompt=prompts, # type: ignore prompt=prompts, # type: ignore
mm_data=mm_data, mm_data=mm_data,
mm_kwargs=mm_kwargs, mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
) )
else: else:
inputs = defaultdict[str, list[torch.Tensor]](list) inputs = defaultdict[str, list[torch.Tensor]](list)
...@@ -633,6 +640,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -633,6 +640,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
for k, v in mm_data.items() for k, v in mm_data.items()
}, },
mm_kwargs=mm_kwargs, mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
) )
for k, v in inputs_one.items(): for k, v in inputs_one.items():
...@@ -646,11 +654,12 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -646,11 +654,12 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
prompt: str, prompt: str,
mm_data: Mapping[str, object], mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object], mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature: ) -> BatchFeature:
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
input_ids = torch.tensor([tokenizer.encode(prompt)]) input_ids = torch.tensor([tokenizer.encode(prompt, **tok_kwargs)])
mm_inputs = self.process_mm_inputs(mm_data, mm_kwargs) mm_inputs = self.process_mm_inputs(mm_data, mm_kwargs, tok_kwargs)
return BatchFeature({ return BatchFeature({
"input_ids": input_ids, "input_ids": input_ids,
...@@ -662,6 +671,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -662,6 +671,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
prompt_text: str, prompt_text: str,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
) -> bool: ) -> bool:
return False return False
......
...@@ -113,11 +113,13 @@ class MiniMaxVL01MultiModalProcessor( ...@@ -113,11 +113,13 @@ class MiniMaxVL01MultiModalProcessor(
prompt: str, prompt: str,
mm_data: Mapping[str, object], mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object], mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature: ) -> BatchFeature:
processed_outputs = super()._call_hf_processor( processed_outputs = super()._call_hf_processor(
prompt=prompt, prompt=prompt,
mm_data=mm_data, mm_data=mm_data,
mm_kwargs=mm_kwargs, mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
) )
pixel_values = processed_outputs.get("pixel_values") pixel_values = processed_outputs.get("pixel_values")
......
...@@ -228,11 +228,13 @@ class Mistral3MultiModalProcessor( ...@@ -228,11 +228,13 @@ class Mistral3MultiModalProcessor(
prompt: str, prompt: str,
mm_data: Mapping[str, object], mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object], mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature: ) -> BatchFeature:
processed_outputs = super()._call_hf_processor( processed_outputs = super()._call_hf_processor(
prompt=prompt, prompt=prompt,
mm_data=mm_data, mm_data=mm_data,
mm_kwargs=mm_kwargs, mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
) )
pixel_values = processed_outputs.get("pixel_values") pixel_values = processed_outputs.get("pixel_values")
......
...@@ -166,10 +166,11 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo] ...@@ -166,10 +166,11 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo]
prompt: Union[str, list[int]], prompt: Union[str, list[int]],
mm_data: MultiModalDataDict, mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Optional[Mapping[str, object]] = None,
return_mm_hashes: bool = False, return_mm_hashes: bool = False,
) -> MultiModalEncDecInputs: ) -> MultiModalEncDecInputs:
mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs, mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
return_mm_hashes) tokenization_kwargs, return_mm_hashes)
image_token_id = self.info.get_hf_config().image_token_index image_token_id = self.info.get_hf_config().image_token_index
# Check that the number of image tokens in the decoder prompt matches # Check that the number of image tokens in the decoder prompt matches
...@@ -239,6 +240,7 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo] ...@@ -239,6 +240,7 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo]
prompt: str, prompt: str,
mm_data: Mapping[str, object], mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object], mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature: ) -> BatchFeature:
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
if mm_data: if mm_data:
...@@ -247,7 +249,7 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo] ...@@ -247,7 +249,7 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo]
for img in mm_data["images"] for img in mm_data["images"]
] ]
processed_outputs = super()._call_hf_processor( processed_outputs = super()._call_hf_processor(
prompt, mm_data, mm_kwargs) prompt, mm_data, mm_kwargs, tok_kwargs)
processed_outputs["num_tiles"] = torch.tensor(num_tiles) processed_outputs["num_tiles"] = torch.tensor(num_tiles)
for k in ('pixel_values', 'aspect_ratio_ids', "aspect_ratio_mask"): for k in ('pixel_values', 'aspect_ratio_ids', "aspect_ratio_mask"):
processed_outputs[k] = processed_outputs[k].squeeze(0) processed_outputs[k] = processed_outputs[k].squeeze(0)
......
...@@ -574,6 +574,7 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo] ...@@ -574,6 +574,7 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
prompt: str, prompt: str,
mm_data: Mapping[str, object], mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object], mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature: ) -> BatchFeature:
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
...@@ -583,6 +584,7 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo] ...@@ -583,6 +584,7 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
prompt=prompt, prompt=prompt,
mm_data=mm_data, mm_data=mm_data,
mm_kwargs=mm_kwargs, mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
) )
processor = self.info.get_hf_processor(**mm_kwargs) processor = self.info.get_hf_processor(**mm_kwargs)
......
...@@ -335,6 +335,7 @@ class OvisMultiModalProcessor(BaseMultiModalProcessor[OvisProcessingInfo]): ...@@ -335,6 +335,7 @@ class OvisMultiModalProcessor(BaseMultiModalProcessor[OvisProcessingInfo]):
prompt: str, prompt: str,
mm_data: Mapping[str, object], mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object], mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature: ) -> BatchFeature:
if not mm_data: if not mm_data:
# Avoid warning from HF logger for text-only input # Avoid warning from HF logger for text-only input
...@@ -346,6 +347,7 @@ class OvisMultiModalProcessor(BaseMultiModalProcessor[OvisProcessingInfo]): ...@@ -346,6 +347,7 @@ class OvisMultiModalProcessor(BaseMultiModalProcessor[OvisProcessingInfo]):
prompt=prompt, prompt=prompt,
mm_data=mm_data, mm_data=mm_data,
mm_kwargs=mm_kwargs, mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
) )
hf_processor = self.info.get_hf_processor() hf_processor = self.info.get_hf_processor()
......
...@@ -121,6 +121,7 @@ class PaliGemmaMultiModalProcessor( ...@@ -121,6 +121,7 @@ class PaliGemmaMultiModalProcessor(
prompt: str, prompt: str,
mm_data: Mapping[str, object], mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object], mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature: ) -> BatchFeature:
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
if not mm_data: if not mm_data:
...@@ -131,6 +132,7 @@ class PaliGemmaMultiModalProcessor( ...@@ -131,6 +132,7 @@ class PaliGemmaMultiModalProcessor(
prompt=prompt, prompt=prompt,
mm_data=mm_data, mm_data=mm_data,
mm_kwargs=mm_kwargs, mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
) )
def _get_mm_fields_config( def _get_mm_fields_config(
...@@ -191,10 +193,11 @@ class PaliGemmaMultiModalProcessor( ...@@ -191,10 +193,11 @@ class PaliGemmaMultiModalProcessor(
prompt: Union[str, list[int]], prompt: Union[str, list[int]],
mm_data: MultiModalDataDict, mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Optional[Mapping[str, object]] = None,
return_mm_hashes: bool = False, return_mm_hashes: bool = False,
) -> MultiModalInputs: ) -> MultiModalInputs:
mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs, mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
return_mm_hashes) tokenization_kwargs, return_mm_hashes)
prompt_token_ids = mm_inputs["prompt_token_ids"] prompt_token_ids = mm_inputs["prompt_token_ids"]
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
......
...@@ -376,11 +376,13 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]): ...@@ -376,11 +376,13 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
prompt: str, prompt: str,
mm_data: Mapping[str, object], mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object], mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature: ) -> BatchFeature:
processed_outputs = super()._call_hf_processor( processed_outputs = super()._call_hf_processor(
prompt=prompt, prompt=prompt,
mm_data=mm_data, mm_data=mm_data,
mm_kwargs=mm_kwargs, mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
) )
input_ids = processed_outputs["input_ids"] input_ids = processed_outputs["input_ids"]
......
...@@ -762,6 +762,7 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]): ...@@ -762,6 +762,7 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]):
prompt: str, prompt: str,
mm_data: Mapping[str, object], mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object], mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature: ) -> BatchFeature:
if not mm_data: if not mm_data:
prompt_ids = self.info.get_tokenizer().encode(prompt) prompt_ids = self.info.get_tokenizer().encode(prompt)
...@@ -773,7 +774,7 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]): ...@@ -773,7 +774,7 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]):
mm_data['audios'] = [(data, sr) for data in audio_data] mm_data['audios'] = [(data, sr) for data in audio_data]
processed_outputs = super()._call_hf_processor(prompt, mm_data, processed_outputs = super()._call_hf_processor(prompt, mm_data,
mm_kwargs) mm_kwargs, tok_kwargs)
num_img_tokens = [ num_img_tokens = [
self.info.get_num_image_tokens(image_width=img_size[0], self.info.get_num_image_tokens(image_width=img_size[0],
......
...@@ -237,6 +237,7 @@ class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]): ...@@ -237,6 +237,7 @@ class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]):
dummy_text = self.get_dummy_text(mm_counts) dummy_text = self.get_dummy_text(mm_counts)
dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts) dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts)
dummy_images = dummy_mm_data.get("image", []) dummy_images = dummy_mm_data.get("image", [])
tokenization_kwargs = {"truncation": False}
request = ChatCompletionRequest(messages=[ request = ChatCompletionRequest(messages=[
UserMessage(content=[ UserMessage(content=[
...@@ -247,7 +248,9 @@ class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]): ...@@ -247,7 +248,9 @@ class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]):
res = tokenizer.mistral.encode_chat_completion(request) res = tokenizer.mistral.encode_chat_completion(request)
dummy_tokens = res.tokens dummy_tokens = res.tokens
return ProcessorInputs(prompt=dummy_tokens, mm_data=dummy_mm_data) return ProcessorInputs(prompt=dummy_tokens,
mm_data=dummy_mm_data,
tokenization_kwargs=tokenization_kwargs)
class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo] class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
...@@ -297,6 +300,7 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo] ...@@ -297,6 +300,7 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
prompt: Union[str, list[int]], prompt: Union[str, list[int]],
mm_data_items: MultiModalDataItems, mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
*, *,
return_mm_hashes: bool, return_mm_hashes: bool,
) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]: ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
...@@ -309,6 +313,7 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo] ...@@ -309,6 +313,7 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
prompt=prompt, prompt=prompt,
mm_data_items=mm_data_items, mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
return_mm_hashes=return_mm_hashes, return_mm_hashes=return_mm_hashes,
) )
......
...@@ -92,6 +92,7 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor): ...@@ -92,6 +92,7 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
prompt: Union[str, list[int]], prompt: Union[str, list[int]],
mm_data: MultiModalDataDict, mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Optional[Mapping[str, object]] = None,
return_mm_hashes: bool = False, return_mm_hashes: bool = False,
) -> MultiModalInputs: ) -> MultiModalInputs:
mm_kwargs = {} mm_kwargs = {}
......
...@@ -244,6 +244,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor( ...@@ -244,6 +244,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
prompt: str, prompt: str,
mm_data: Mapping[str, object], mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object], mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature: ) -> BatchFeature:
mm_data = dict(mm_data) mm_data = dict(mm_data)
audios = mm_data.pop("audios", []) audios = mm_data.pop("audios", [])
...@@ -258,6 +259,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor( ...@@ -258,6 +259,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
prompt=prompt, prompt=prompt,
mm_data=mm_data, mm_data=mm_data,
mm_kwargs=mm_kwargs, mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
) )
input_features = hf_inputs.pop('input_features', None) input_features = hf_inputs.pop('input_features', None)
...@@ -453,6 +455,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor( ...@@ -453,6 +455,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
prompt: Union[str, list[int]], prompt: Union[str, list[int]],
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
*, *,
enable_hf_prompt_update: bool, enable_hf_prompt_update: bool,
) -> tuple[list[int], MultiModalKwargs, bool]: ) -> tuple[list[int], MultiModalKwargs, bool]:
...@@ -465,6 +468,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor( ...@@ -465,6 +468,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
prompt_text=prompt, prompt_text=prompt,
mm_items=mm_items, mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
) )
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
prompt_ids = encode_tokens(tokenizer, prompt) prompt_ids = encode_tokens(tokenizer, prompt)
...@@ -474,6 +478,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor( ...@@ -474,6 +478,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
mm_kwargs = self._apply_hf_processor_mm_only( mm_kwargs = self._apply_hf_processor_mm_only(
mm_items=mm_items, mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
) )
return prompt_ids, mm_kwargs, False return prompt_ids, mm_kwargs, False
...@@ -482,6 +487,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor( ...@@ -482,6 +487,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
) -> MultiModalKwargs: ) -> MultiModalKwargs:
""" """
Qwen2.5-Omni reimplements this function to handle `use_audio_in_video`. Qwen2.5-Omni reimplements this function to handle `use_audio_in_video`.
...@@ -498,6 +504,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor( ...@@ -498,6 +504,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
prompt_text=self.dummy_inputs.get_dummy_text(mm_counts), prompt_text=self.dummy_inputs.get_dummy_text(mm_counts),
mm_items=mm_items, mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
) )
return mm_kwargs return mm_kwargs
......
...@@ -150,6 +150,7 @@ class Qwen2AudioMultiModalProcessor( ...@@ -150,6 +150,7 @@ class Qwen2AudioMultiModalProcessor(
prompt: str, prompt: str,
mm_data: Mapping[str, object], mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, Any], mm_kwargs: Mapping[str, Any],
tok_kwargs: Mapping[str, object],
) -> BatchFeature: ) -> BatchFeature:
# NOTE - we rename audios -> audio in mm data because transformers has # NOTE - we rename audios -> audio in mm data because transformers has
# deprecated audios for the qwen2audio processor and will remove # deprecated audios for the qwen2audio processor and will remove
...@@ -174,6 +175,7 @@ class Qwen2AudioMultiModalProcessor( ...@@ -174,6 +175,7 @@ class Qwen2AudioMultiModalProcessor(
prompt=prompt, prompt=prompt,
mm_data=mm_data, mm_data=mm_data,
mm_kwargs=mm_kwargs, mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
) )
def _get_mm_fields_config( def _get_mm_fields_config(
......
...@@ -1027,11 +1027,13 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo] ...@@ -1027,11 +1027,13 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
prompt: str, prompt: str,
mm_data: Mapping[str, object], mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object], mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature: ) -> BatchFeature:
mm_kwargs = self.info._get_image_processor_kwargs(**mm_kwargs)
return self.info.ctx.call_hf_processor( return self.info.ctx.call_hf_processor(
self.info.get_hf_processor(**mm_kwargs), self.info.get_hf_processor(**mm_kwargs),
dict(text=prompt, **mm_data), dict(text=prompt, **mm_data),
self.info._get_image_processor_kwargs(**mm_kwargs), dict(**mm_kwargs, **tok_kwargs),
) )
def _get_prompt_updates( def _get_prompt_updates(
......
...@@ -580,6 +580,7 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]): ...@@ -580,6 +580,7 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]):
prompt: str, prompt: str,
mm_data: Mapping[str, object], mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object], mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature: ) -> BatchFeature:
# Drops anything between <img>/</img> tags; encoding with the tokenizer # Drops anything between <img>/</img> tags; encoding with the tokenizer
# will automatically add the image pads for the context. # will automatically add the image pads for the context.
...@@ -600,6 +601,7 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]): ...@@ -600,6 +601,7 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]):
prompt=prompt, prompt=prompt,
mm_data=mm_data, mm_data=mm_data,
mm_kwargs=mm_kwargs, mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
) )
def _hf_processor_applies_updates( def _hf_processor_applies_updates(
...@@ -607,6 +609,7 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]): ...@@ -607,6 +609,7 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]):
prompt_text: str, prompt_text: str,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
) -> bool: ) -> bool:
return False return False
......
...@@ -534,11 +534,13 @@ class SkyworkR1VMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -534,11 +534,13 @@ class SkyworkR1VMultiModalProcessor(BaseMultiModalProcessor[_I]):
prompt: str, prompt: str,
mm_data: Mapping[str, object], mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object], mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> Mapping[str, NestedTensors]: ) -> Mapping[str, NestedTensors]:
processed_outputs = super()._call_hf_processor( processed_outputs = super()._call_hf_processor(
prompt=prompt, prompt=prompt,
mm_data=mm_data, mm_data=mm_data,
mm_kwargs=mm_kwargs, mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
) )
hf_processor = self.info.get_hf_processor(**mm_kwargs) hf_processor = self.info.get_hf_processor(**mm_kwargs)
......
...@@ -144,6 +144,7 @@ class UltravoxMultiModalProcessor( ...@@ -144,6 +144,7 @@ class UltravoxMultiModalProcessor(
prompt: str, prompt: str,
mm_data: Mapping[str, object], mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object], mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature: ) -> BatchFeature:
# Text-only input not supported in composite processor # Text-only input not supported in composite processor
if not mm_data.get("audios", []): if not mm_data.get("audios", []):
...@@ -165,10 +166,15 @@ class UltravoxMultiModalProcessor( ...@@ -165,10 +166,15 @@ class UltravoxMultiModalProcessor(
item_processor_data = dict(**mm_data, audios=audios) item_processor_data = dict(**mm_data, audios=audios)
# some tokenizer kwargs are incompatible with UltravoxProcessor
tok_kwargs.pop("padding", None)
tok_kwargs.pop("truncation", None)
output = super()._call_hf_processor( output = super()._call_hf_processor(
prompt=prompt, prompt=prompt,
mm_data=item_processor_data, mm_data=item_processor_data,
mm_kwargs=mm_kwargs, mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
) )
output['audio_features'] = output.pop('audio_values') output['audio_features'] = output.pop('audio_values')
......
...@@ -700,9 +700,10 @@ class WhisperMultiModalProcessor( ...@@ -700,9 +700,10 @@ class WhisperMultiModalProcessor(
prompt: str, prompt: str,
mm_data: Mapping[str, object], mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object], mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature: ) -> BatchFeature:
if mm_data: if mm_data:
feature_extractor = self.info.get_feature_extractor(**mm_kwargs) feature_extractor = self.info.get_feature_extractor()
mm_data = dict(audio=mm_data.pop("audios")) mm_data = dict(audio=mm_data.pop("audios"))
mm_kwargs = dict( mm_kwargs = dict(
**mm_kwargs, **mm_kwargs,
...@@ -712,6 +713,7 @@ class WhisperMultiModalProcessor( ...@@ -712,6 +713,7 @@ class WhisperMultiModalProcessor(
prompt=prompt, prompt=prompt,
mm_data=mm_data, mm_data=mm_data,
mm_kwargs=mm_kwargs, mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
) )
if "labels" in processed_outputs: if "labels" in processed_outputs:
processed_outputs["input_ids"] = processed_outputs.pop("labels") processed_outputs["input_ids"] = processed_outputs.pop("labels")
......
...@@ -1267,6 +1267,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1267,6 +1267,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
# This refers to the data to be passed to HF processor. # This refers to the data to be passed to HF processor.
mm_data: Mapping[str, object], mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object], mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> "BatchFeature": ) -> "BatchFeature":
""" """
Call the HF processor on the prompt text and Call the HF processor on the prompt text and
...@@ -1275,7 +1276,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1275,7 +1276,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
return self.info.ctx.call_hf_processor( return self.info.ctx.call_hf_processor(
self.info.get_hf_processor(**mm_kwargs), self.info.get_hf_processor(**mm_kwargs),
dict(text=prompt, **mm_data), dict(text=prompt, **mm_data),
mm_kwargs, dict(**mm_kwargs, **tok_kwargs),
) )
def _hf_processor_applies_updates( def _hf_processor_applies_updates(
...@@ -1283,6 +1284,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1283,6 +1284,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
prompt_text: str, prompt_text: str,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
) -> bool: ) -> bool:
""" """
Return whether the HF processor applies prompt updates. Return whether the HF processor applies prompt updates.
...@@ -1300,6 +1302,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1300,6 +1302,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
prompt_text: str, prompt_text: str,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
) -> tuple[list[int], MultiModalKwargs, bool]: ) -> tuple[list[int], MultiModalKwargs, bool]:
""" """
Apply the HF processor on the prompt text and multi-modal data Apply the HF processor on the prompt text and multi-modal data
...@@ -1313,6 +1316,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1313,6 +1316,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
prompt=prompt_text, prompt=prompt_text,
mm_data=processor_data, mm_data=processor_data,
mm_kwargs=hf_processor_mm_kwargs, mm_kwargs=hf_processor_mm_kwargs,
tok_kwargs=tokenization_kwargs,
) )
processed_data.update(passthrough_data) processed_data.update(passthrough_data)
...@@ -1327,11 +1331,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1327,11 +1331,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
prompt_text=prompt_text, prompt_text=prompt_text,
mm_items=mm_items, mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
) )
return prompt_ids, mm_kwargs, is_update_applied return prompt_ids, mm_kwargs, is_update_applied
def _apply_hf_processor_text_only(self, prompt_text: str) -> list[int]: def _apply_hf_processor_text_only(
self, prompt_text: str,
tokenization_kwargs: Mapping[str, object]) -> list[int]:
""" """
Apply the HF processor on the prompt text only. Apply the HF processor on the prompt text only.
...@@ -1343,6 +1350,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1343,6 +1350,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
prompt_text=prompt_text, prompt_text=prompt_text,
mm_items=MultiModalDataItems({}), mm_items=MultiModalDataItems({}),
hf_processor_mm_kwargs={}, hf_processor_mm_kwargs={},
tokenization_kwargs=tokenization_kwargs,
) )
return prompt_ids return prompt_ids
...@@ -1368,6 +1376,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1368,6 +1376,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
) -> MultiModalKwargs: ) -> MultiModalKwargs:
""" """
Apply the HF processor on the multi-modal data only. Apply the HF processor on the multi-modal data only.
...@@ -1383,6 +1392,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1383,6 +1392,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
prompt_text=self.dummy_inputs.get_dummy_text(mm_counts), prompt_text=self.dummy_inputs.get_dummy_text(mm_counts),
mm_items=mm_items, mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
) )
return mm_kwargs return mm_kwargs
...@@ -1392,6 +1402,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1392,6 +1402,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
prompt: Union[str, list[int]], prompt: Union[str, list[int]],
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
*, *,
enable_hf_prompt_update: bool, enable_hf_prompt_update: bool,
) -> tuple[list[int], MultiModalKwargs, bool]: ) -> tuple[list[int], MultiModalKwargs, bool]:
...@@ -1412,15 +1423,18 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1412,15 +1423,18 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
prompt_text=prompt, prompt_text=prompt,
mm_items=mm_items, mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
) )
prompt_ids = self._apply_hf_processor_text_only(prompt) prompt_ids = self._apply_hf_processor_text_only(
prompt, tokenization_kwargs)
else: else:
prompt_ids = self._apply_hf_processor_tokens_only(prompt) prompt_ids = self._apply_hf_processor_tokens_only(prompt)
mm_kwargs = self._apply_hf_processor_mm_only( mm_kwargs = self._apply_hf_processor_mm_only(
mm_items=mm_items, mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
) )
return prompt_ids, mm_kwargs, False return prompt_ids, mm_kwargs, False
...@@ -1430,14 +1444,17 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1430,14 +1444,17 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
cache: ProcessingCache, cache: ProcessingCache,
mm_data_items: MultiModalDataItems, mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
) -> tuple[dict[str, list[ProcessingCacheOptionalItem]], dict[ ) -> tuple[dict[str, list[ProcessingCacheOptionalItem]], dict[
str, list[object]]]: str, list[object]]]:
model_id = self.info.model_id model_id = self.info.model_id
mm_cache_items = { mm_cache_items = {
modality: [ modality: [
cache.get_item(model_id, modality, item, cache.get_item(
hf_processor_mm_kwargs) for item in items model_id, modality, item,
dict(**hf_processor_mm_kwargs, **tokenization_kwargs))
for item in items
] ]
for modality, items in mm_data_items.items() for modality, items in mm_data_items.items()
} }
...@@ -1457,10 +1474,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1457,10 +1474,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
return mm_cache_items, mm_missing_data return mm_cache_items, mm_missing_data
def _hash_mm_items( def _hash_mm_items(
self, self, mm_items: MultiModalDataItems,
mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object],
hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object]) -> MultiModalHashes:
) -> MultiModalHashes:
"""Create MM hashes to be returned (only used in V1).""" """Create MM hashes to be returned (only used in V1)."""
model_id = self.info.model_id model_id = self.info.model_id
...@@ -1468,7 +1484,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1468,7 +1484,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
modality: [ modality: [
MultiModalHasher.hash_kwargs(model_id=model_id, MultiModalHasher.hash_kwargs(model_id=model_id,
**{modality: item}, **{modality: item},
**hf_processor_mm_kwargs) **hf_processor_mm_kwargs,
**tokenization_kwargs)
for item in items for item in items
] ]
for modality, items in mm_items.items() for modality, items in mm_items.items()
...@@ -1513,6 +1530,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1513,6 +1530,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
prompt: Union[str, list[int]], prompt: Union[str, list[int]],
mm_data_items: MultiModalDataItems, mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
*, *,
return_mm_hashes: bool, return_mm_hashes: bool,
) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]: ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
...@@ -1524,10 +1542,12 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1524,10 +1542,12 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
prompt=prompt, prompt=prompt,
mm_items=mm_data_items, mm_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
enable_hf_prompt_update=True, enable_hf_prompt_update=True,
) )
mm_hashes = (self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs) mm_hashes = (self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs,
tokenization_kwargs)
if return_mm_hashes else None) if return_mm_hashes else None)
return prompt_ids, mm_kwargs, mm_hashes, is_update_applied return prompt_ids, mm_kwargs, mm_hashes, is_update_applied
...@@ -1537,6 +1557,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1537,6 +1557,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
prompt: Union[str, list[int]], prompt: Union[str, list[int]],
mm_data_items: MultiModalDataItems, mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
*, *,
return_mm_hashes: bool, return_mm_hashes: bool,
) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]: ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
...@@ -1552,6 +1573,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1552,6 +1573,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
prompt=prompt, prompt=prompt,
mm_data_items=mm_data_items, mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
return_mm_hashes=return_mm_hashes, return_mm_hashes=return_mm_hashes,
) )
...@@ -1562,6 +1584,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1562,6 +1584,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
cache=cache, cache=cache,
mm_data_items=mm_data_items, mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
) )
# NOTE: `prompt` does not correspond to `mm_missing_data_items`, # NOTE: `prompt` does not correspond to `mm_missing_data_items`,
...@@ -1575,6 +1598,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1575,6 +1598,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
prompt=prompt, prompt=prompt,
mm_items=self._to_mm_items(mm_missing_data), mm_items=self._to_mm_items(mm_missing_data),
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
enable_hf_prompt_update=False, enable_hf_prompt_update=False,
) )
...@@ -1783,6 +1807,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1783,6 +1807,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
prompt: Union[str, list[int]], prompt: Union[str, list[int]],
mm_data: MultiModalDataDict, mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Optional[Mapping[str, object]] = None,
return_mm_hashes: bool = False, return_mm_hashes: bool = False,
) -> MultiModalInputs: ) -> MultiModalInputs:
""" """
...@@ -1800,6 +1825,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1800,6 +1825,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
""" """
mm_items = self._to_mm_items(mm_data) mm_items = self._to_mm_items(mm_data)
if tokenization_kwargs is None:
tokenization_kwargs = {}
( (
prompt_ids, prompt_ids,
mm_kwargs, mm_kwargs,
...@@ -1809,9 +1837,11 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1809,9 +1837,11 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
prompt, prompt,
mm_items, mm_items,
hf_processor_mm_kwargs, hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
return_mm_hashes=return_mm_hashes, return_mm_hashes=return_mm_hashes,
) )
# NOTE: tokenization_kwargs are not required to init processor
prompt_ids, prompt, mm_placeholders = self._maybe_apply_prompt_updates( prompt_ids, prompt, mm_placeholders = self._maybe_apply_prompt_updates(
mm_items=mm_items, mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
...@@ -1892,6 +1922,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -1892,6 +1922,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
prompt: Union[str, list[int]], prompt: Union[str, list[int]],
mm_data: MultiModalDataDict, mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Optional[Mapping[str, object]] = None,
return_mm_hashes: bool = False, return_mm_hashes: bool = False,
) -> MultiModalEncDecInputs: ) -> MultiModalEncDecInputs:
""" """
...@@ -1906,6 +1937,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -1906,6 +1937,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
encoder_prompt, encoder_prompt,
mm_data, mm_data,
hf_processor_mm_kwargs, hf_processor_mm_kwargs,
tokenization_kwargs,
return_mm_hashes, return_mm_hashes,
) )
......
...@@ -30,6 +30,7 @@ class ProcessorInputs: ...@@ -30,6 +30,7 @@ class ProcessorInputs:
prompt: Union[str, list[int]] prompt: Union[str, list[int]]
mm_data: MultiModalDataDict mm_data: MultiModalDataDict
hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict) hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)
tokenization_kwargs: Mapping[str, object] = field(default_factory=dict)
class DummyEncoderData(NamedTuple): class DummyEncoderData(NamedTuple):
...@@ -90,8 +91,11 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]): ...@@ -90,8 +91,11 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
""" """
dummy_text = self.get_dummy_text(mm_counts) dummy_text = self.get_dummy_text(mm_counts)
dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts) dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts)
tokenization_kwargs = {"truncation": False}
return ProcessorInputs(prompt=dummy_text, mm_data=dummy_mm_data) return ProcessorInputs(prompt=dummy_text,
mm_data=dummy_mm_data,
tokenization_kwargs=tokenization_kwargs)
def _get_dummy_audios( def _get_dummy_audios(
self, self,
...@@ -170,6 +174,7 @@ class MultiModalProfiler(Generic[_I]): ...@@ -170,6 +174,7 @@ class MultiModalProfiler(Generic[_I]):
prompt=processor_inputs.prompt, prompt=processor_inputs.prompt,
mm_data=processor_inputs.mm_data, mm_data=processor_inputs.mm_data,
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs, hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
tokenization_kwargs=processor_inputs.tokenization_kwargs,
) )
def _get_mm_num_tokens( def _get_mm_num_tokens(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment