Unverified Commit d8cf819a authored by Kyle Sayers's avatar Kyle Sayers Committed by GitHub
Browse files

[Core] [Bugfix] [Multimodal] Fix multimodal profiling and generation for SFT/PTQed models (#20058)


Signed-off-by: default avatarKyle Sayers <kylesayrs@gmail.com>
parent 551ef163
......@@ -534,6 +534,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
self,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> Mapping[str, NestedTensors]:
if (images := mm_data.get("images")) is None:
return {}
......@@ -550,6 +551,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
prompts=[self.info.image_pattern] * len(parsed_images),
mm_data={"images": [[image] for image in parsed_images]},
mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
out_keys={"pixel_values", "image_sizes", "tgt_sizes"},
)
......@@ -563,6 +565,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
self,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> Mapping[str, NestedTensors]:
if (videos := mm_data.get("videos")) is None:
return {}
......@@ -586,6 +589,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
"max_slice_nums":
self.info.get_video_max_slice_num(),
},
tok_kwargs=tok_kwargs,
out_keys={"pixel_values", "image_sizes", "tgt_sizes"},
)
......@@ -601,10 +605,11 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
self,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> Mapping[str, NestedTensors]:
return {
**self.process_images(mm_data, mm_kwargs),
**self.process_videos(mm_data, mm_kwargs),
**self.process_images(mm_data, mm_kwargs, tok_kwargs),
**self.process_videos(mm_data, mm_kwargs, tok_kwargs),
}
def _base_call_hf_processor(
......@@ -612,6 +617,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
prompts: list[str],
mm_data: Mapping[str, Sequence[object]],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
*,
out_keys: set[str],
) -> dict[str, NestedTensors]:
......@@ -621,6 +627,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
prompt=prompts, # type: ignore
mm_data=mm_data,
mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
)
else:
inputs = defaultdict[str, list[torch.Tensor]](list)
......@@ -633,6 +640,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
for k, v in mm_data.items()
},
mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
)
for k, v in inputs_one.items():
......@@ -646,11 +654,12 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
tokenizer = self.info.get_tokenizer()
input_ids = torch.tensor([tokenizer.encode(prompt)])
mm_inputs = self.process_mm_inputs(mm_data, mm_kwargs)
input_ids = torch.tensor([tokenizer.encode(prompt, **tok_kwargs)])
mm_inputs = self.process_mm_inputs(mm_data, mm_kwargs, tok_kwargs)
return BatchFeature({
"input_ids": input_ids,
......@@ -662,6 +671,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
prompt_text: str,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
) -> bool:
return False
......
......@@ -113,11 +113,13 @@ class MiniMaxVL01MultiModalProcessor(
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
processed_outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
)
pixel_values = processed_outputs.get("pixel_values")
......
......@@ -228,11 +228,13 @@ class Mistral3MultiModalProcessor(
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
processed_outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
)
pixel_values = processed_outputs.get("pixel_values")
......
......@@ -166,10 +166,11 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo]
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Optional[Mapping[str, object]] = None,
return_mm_hashes: bool = False,
) -> MultiModalEncDecInputs:
mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
return_mm_hashes)
tokenization_kwargs, return_mm_hashes)
image_token_id = self.info.get_hf_config().image_token_index
# Check that the number of image tokens in the decoder prompt matches
......@@ -239,6 +240,7 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo]
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
tokenizer = self.info.get_tokenizer()
if mm_data:
......@@ -247,7 +249,7 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo]
for img in mm_data["images"]
]
processed_outputs = super()._call_hf_processor(
prompt, mm_data, mm_kwargs)
prompt, mm_data, mm_kwargs, tok_kwargs)
processed_outputs["num_tiles"] = torch.tensor(num_tiles)
for k in ('pixel_values', 'aspect_ratio_ids', "aspect_ratio_mask"):
processed_outputs[k] = processed_outputs[k].squeeze(0)
......
......@@ -574,6 +574,7 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
tokenizer = self.info.get_tokenizer()
......@@ -583,6 +584,7 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
)
processor = self.info.get_hf_processor(**mm_kwargs)
......
......@@ -335,6 +335,7 @@ class OvisMultiModalProcessor(BaseMultiModalProcessor[OvisProcessingInfo]):
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
if not mm_data:
# Avoid warning from HF logger for text-only input
......@@ -346,6 +347,7 @@ class OvisMultiModalProcessor(BaseMultiModalProcessor[OvisProcessingInfo]):
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
)
hf_processor = self.info.get_hf_processor()
......
......@@ -121,6 +121,7 @@ class PaliGemmaMultiModalProcessor(
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
tokenizer = self.info.get_tokenizer()
if not mm_data:
......@@ -131,6 +132,7 @@ class PaliGemmaMultiModalProcessor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
)
def _get_mm_fields_config(
......@@ -191,10 +193,11 @@ class PaliGemmaMultiModalProcessor(
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Optional[Mapping[str, object]] = None,
return_mm_hashes: bool = False,
) -> MultiModalInputs:
mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
return_mm_hashes)
tokenization_kwargs, return_mm_hashes)
prompt_token_ids = mm_inputs["prompt_token_ids"]
tokenizer = self.info.get_tokenizer()
......
......@@ -376,11 +376,13 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
processed_outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
)
input_ids = processed_outputs["input_ids"]
......
......@@ -762,6 +762,7 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]):
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
if not mm_data:
prompt_ids = self.info.get_tokenizer().encode(prompt)
......@@ -773,7 +774,7 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]):
mm_data['audios'] = [(data, sr) for data in audio_data]
processed_outputs = super()._call_hf_processor(prompt, mm_data,
mm_kwargs)
mm_kwargs, tok_kwargs)
num_img_tokens = [
self.info.get_num_image_tokens(image_width=img_size[0],
......
......@@ -237,6 +237,7 @@ class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]):
dummy_text = self.get_dummy_text(mm_counts)
dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts)
dummy_images = dummy_mm_data.get("image", [])
tokenization_kwargs = {"truncation": False}
request = ChatCompletionRequest(messages=[
UserMessage(content=[
......@@ -247,7 +248,9 @@ class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]):
res = tokenizer.mistral.encode_chat_completion(request)
dummy_tokens = res.tokens
return ProcessorInputs(prompt=dummy_tokens, mm_data=dummy_mm_data)
return ProcessorInputs(prompt=dummy_tokens,
mm_data=dummy_mm_data,
tokenization_kwargs=tokenization_kwargs)
class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
......@@ -297,6 +300,7 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
prompt: Union[str, list[int]],
mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
*,
return_mm_hashes: bool,
) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
......@@ -309,6 +313,7 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
prompt=prompt,
mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
return_mm_hashes=return_mm_hashes,
)
......
......@@ -92,6 +92,7 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Optional[Mapping[str, object]] = None,
return_mm_hashes: bool = False,
) -> MultiModalInputs:
mm_kwargs = {}
......
......@@ -244,6 +244,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
mm_data = dict(mm_data)
audios = mm_data.pop("audios", [])
......@@ -258,6 +259,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
)
input_features = hf_inputs.pop('input_features', None)
......@@ -453,6 +455,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
prompt: Union[str, list[int]],
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
*,
enable_hf_prompt_update: bool,
) -> tuple[list[int], MultiModalKwargs, bool]:
......@@ -465,6 +468,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
prompt_text=prompt,
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
tokenizer = self.info.get_tokenizer()
prompt_ids = encode_tokens(tokenizer, prompt)
......@@ -474,6 +478,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
mm_kwargs = self._apply_hf_processor_mm_only(
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
return prompt_ids, mm_kwargs, False
......@@ -482,6 +487,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
) -> MultiModalKwargs:
"""
Qwen2.5-Omni reimplements this function to handle `use_audio_in_video`.
......@@ -498,6 +504,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
prompt_text=self.dummy_inputs.get_dummy_text(mm_counts),
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
return mm_kwargs
......
......@@ -150,6 +150,7 @@ class Qwen2AudioMultiModalProcessor(
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, Any],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
# NOTE - we rename audios -> audio in mm data because transformers has
# deprecated audios for the qwen2audio processor and will remove
......@@ -174,6 +175,7 @@ class Qwen2AudioMultiModalProcessor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
)
def _get_mm_fields_config(
......
......@@ -1027,11 +1027,13 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
mm_kwargs = self.info._get_image_processor_kwargs(**mm_kwargs)
return self.info.ctx.call_hf_processor(
self.info.get_hf_processor(**mm_kwargs),
dict(text=prompt, **mm_data),
self.info._get_image_processor_kwargs(**mm_kwargs),
dict(**mm_kwargs, **tok_kwargs),
)
def _get_prompt_updates(
......
......@@ -580,6 +580,7 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]):
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
# Drops anything between <img>/</img> tags; encoding with the tokenizer
# will automatically add the image pads for the context.
......@@ -600,6 +601,7 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]):
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
)
def _hf_processor_applies_updates(
......@@ -607,6 +609,7 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]):
prompt_text: str,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
) -> bool:
return False
......
......@@ -534,11 +534,13 @@ class SkyworkR1VMultiModalProcessor(BaseMultiModalProcessor[_I]):
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> Mapping[str, NestedTensors]:
processed_outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
)
hf_processor = self.info.get_hf_processor(**mm_kwargs)
......
......@@ -144,6 +144,7 @@ class UltravoxMultiModalProcessor(
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
# Text-only input not supported in composite processor
if not mm_data.get("audios", []):
......@@ -165,10 +166,15 @@ class UltravoxMultiModalProcessor(
item_processor_data = dict(**mm_data, audios=audios)
# some tokenizer kwargs are incompatible with UltravoxProcessor
tok_kwargs.pop("padding", None)
tok_kwargs.pop("truncation", None)
output = super()._call_hf_processor(
prompt=prompt,
mm_data=item_processor_data,
mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
)
output['audio_features'] = output.pop('audio_values')
......
......@@ -700,9 +700,10 @@ class WhisperMultiModalProcessor(
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
if mm_data:
feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
feature_extractor = self.info.get_feature_extractor()
mm_data = dict(audio=mm_data.pop("audios"))
mm_kwargs = dict(
**mm_kwargs,
......@@ -712,6 +713,7 @@ class WhisperMultiModalProcessor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
)
if "labels" in processed_outputs:
processed_outputs["input_ids"] = processed_outputs.pop("labels")
......
......@@ -1267,6 +1267,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
# This refers to the data to be passed to HF processor.
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> "BatchFeature":
"""
Call the HF processor on the prompt text and
......@@ -1275,7 +1276,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
return self.info.ctx.call_hf_processor(
self.info.get_hf_processor(**mm_kwargs),
dict(text=prompt, **mm_data),
mm_kwargs,
dict(**mm_kwargs, **tok_kwargs),
)
def _hf_processor_applies_updates(
......@@ -1283,6 +1284,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
prompt_text: str,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
) -> bool:
"""
Return whether the HF processor applies prompt updates.
......@@ -1300,6 +1302,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
prompt_text: str,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
) -> tuple[list[int], MultiModalKwargs, bool]:
"""
Apply the HF processor on the prompt text and multi-modal data
......@@ -1313,6 +1316,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
prompt=prompt_text,
mm_data=processor_data,
mm_kwargs=hf_processor_mm_kwargs,
tok_kwargs=tokenization_kwargs,
)
processed_data.update(passthrough_data)
......@@ -1327,11 +1331,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
prompt_text=prompt_text,
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
return prompt_ids, mm_kwargs, is_update_applied
def _apply_hf_processor_text_only(self, prompt_text: str) -> list[int]:
def _apply_hf_processor_text_only(
self, prompt_text: str,
tokenization_kwargs: Mapping[str, object]) -> list[int]:
"""
Apply the HF processor on the prompt text only.
......@@ -1343,6 +1350,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
prompt_text=prompt_text,
mm_items=MultiModalDataItems({}),
hf_processor_mm_kwargs={},
tokenization_kwargs=tokenization_kwargs,
)
return prompt_ids
......@@ -1368,6 +1376,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
) -> MultiModalKwargs:
"""
Apply the HF processor on the multi-modal data only.
......@@ -1383,6 +1392,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
prompt_text=self.dummy_inputs.get_dummy_text(mm_counts),
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
return mm_kwargs
......@@ -1392,6 +1402,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
prompt: Union[str, list[int]],
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
*,
enable_hf_prompt_update: bool,
) -> tuple[list[int], MultiModalKwargs, bool]:
......@@ -1412,15 +1423,18 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
prompt_text=prompt,
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
prompt_ids = self._apply_hf_processor_text_only(prompt)
prompt_ids = self._apply_hf_processor_text_only(
prompt, tokenization_kwargs)
else:
prompt_ids = self._apply_hf_processor_tokens_only(prompt)
mm_kwargs = self._apply_hf_processor_mm_only(
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
return prompt_ids, mm_kwargs, False
......@@ -1430,14 +1444,17 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
cache: ProcessingCache,
mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
) -> tuple[dict[str, list[ProcessingCacheOptionalItem]], dict[
str, list[object]]]:
model_id = self.info.model_id
mm_cache_items = {
modality: [
cache.get_item(model_id, modality, item,
hf_processor_mm_kwargs) for item in items
cache.get_item(
model_id, modality, item,
dict(**hf_processor_mm_kwargs, **tokenization_kwargs))
for item in items
]
for modality, items in mm_data_items.items()
}
......@@ -1457,10 +1474,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
return mm_cache_items, mm_missing_data
def _hash_mm_items(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalHashes:
self, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object]) -> MultiModalHashes:
"""Create MM hashes to be returned (only used in V1)."""
model_id = self.info.model_id
......@@ -1468,7 +1484,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
modality: [
MultiModalHasher.hash_kwargs(model_id=model_id,
**{modality: item},
**hf_processor_mm_kwargs)
**hf_processor_mm_kwargs,
**tokenization_kwargs)
for item in items
]
for modality, items in mm_items.items()
......@@ -1513,6 +1530,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
prompt: Union[str, list[int]],
mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
*,
return_mm_hashes: bool,
) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
......@@ -1524,10 +1542,12 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
prompt=prompt,
mm_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
enable_hf_prompt_update=True,
)
mm_hashes = (self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs)
mm_hashes = (self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs,
tokenization_kwargs)
if return_mm_hashes else None)
return prompt_ids, mm_kwargs, mm_hashes, is_update_applied
......@@ -1537,6 +1557,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
prompt: Union[str, list[int]],
mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
*,
return_mm_hashes: bool,
) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
......@@ -1552,6 +1573,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
prompt=prompt,
mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
return_mm_hashes=return_mm_hashes,
)
......@@ -1562,6 +1584,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
cache=cache,
mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
# NOTE: `prompt` does not correspond to `mm_missing_data_items`,
......@@ -1575,6 +1598,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
prompt=prompt,
mm_items=self._to_mm_items(mm_missing_data),
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
enable_hf_prompt_update=False,
)
......@@ -1783,6 +1807,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Optional[Mapping[str, object]] = None,
return_mm_hashes: bool = False,
) -> MultiModalInputs:
"""
......@@ -1800,6 +1825,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
"""
mm_items = self._to_mm_items(mm_data)
if tokenization_kwargs is None:
tokenization_kwargs = {}
(
prompt_ids,
mm_kwargs,
......@@ -1809,9 +1837,11 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
prompt,
mm_items,
hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
return_mm_hashes=return_mm_hashes,
)
# NOTE: tokenization_kwargs are not required to init processor
prompt_ids, prompt, mm_placeholders = self._maybe_apply_prompt_updates(
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
......@@ -1892,6 +1922,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Optional[Mapping[str, object]] = None,
return_mm_hashes: bool = False,
) -> MultiModalEncDecInputs:
"""
......@@ -1906,6 +1937,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
encoder_prompt,
mm_data,
hf_processor_mm_kwargs,
tokenization_kwargs,
return_mm_hashes,
)
......
......@@ -30,6 +30,7 @@ class ProcessorInputs:
prompt: Union[str, list[int]]
mm_data: MultiModalDataDict
hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)
tokenization_kwargs: Mapping[str, object] = field(default_factory=dict)
class DummyEncoderData(NamedTuple):
......@@ -90,8 +91,11 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
"""
dummy_text = self.get_dummy_text(mm_counts)
dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts)
tokenization_kwargs = {"truncation": False}
return ProcessorInputs(prompt=dummy_text, mm_data=dummy_mm_data)
return ProcessorInputs(prompt=dummy_text,
mm_data=dummy_mm_data,
tokenization_kwargs=tokenization_kwargs)
def _get_dummy_audios(
self,
......@@ -170,6 +174,7 @@ class MultiModalProfiler(Generic[_I]):
prompt=processor_inputs.prompt,
mm_data=processor_inputs.mm_data,
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
tokenization_kwargs=processor_inputs.tokenization_kwargs,
)
def _get_mm_num_tokens(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment