Unverified commit b844b99a, authored by Cyrus Leung, committed by GitHub
Browse files

[VLM] Enable tokenized inputs for merged multi-modal processor (#11900)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent c3cf54dd
...@@ -649,7 +649,7 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid): ...@@ -649,7 +649,7 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
) )
def _test_processing_cache_correctness( def _test_processing_correctness(
model_id: str, model_id: str,
modalities: dict[str, bool], modalities: dict[str, bool],
hit_rate: float, hit_rate: float,
...@@ -691,6 +691,7 @@ def _test_processing_cache_correctness( ...@@ -691,6 +691,7 @@ def _test_processing_cache_correctness(
baseline_processor = factories.build_processor(ctx, cache=None) baseline_processor = factories.build_processor(ctx, cache=None)
cached_processor = factories.build_processor(ctx, cache=cache) cached_processor = factories.build_processor(ctx, cache=cache)
dummy_inputs = baseline_processor.dummy_inputs dummy_inputs = baseline_processor.dummy_inputs
tokenizer = baseline_processor.info.get_tokenizer()
rng = np.random.RandomState(0) rng = np.random.RandomState(0)
...@@ -747,7 +748,25 @@ def _test_processing_cache_correctness( ...@@ -747,7 +748,25 @@ def _test_processing_cache_correctness(
) )
assert baseline_result == cached_result, ( assert baseline_result == cached_result, (
f"Failed ({batch_idx=}, {mm_data=})") f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
baseline_tokenized_result = baseline_processor.apply(
tokenizer.encode(prompt),
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
assert baseline_result == baseline_tokenized_result, (
f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
cached_tokenized_result = cached_processor.apply(
tokenizer.encode(prompt),
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
assert cached_result == cached_tokenized_result, (
f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
# yapf: disable # yapf: disable
...@@ -771,14 +790,14 @@ def _test_processing_cache_correctness( ...@@ -771,14 +790,14 @@ def _test_processing_cache_correctness(
@pytest.mark.parametrize("num_batches", [32]) @pytest.mark.parametrize("num_batches", [32])
@pytest.mark.parametrize("simplify_rate", [1.0]) @pytest.mark.parametrize("simplify_rate", [1.0])
# yapf: enable # yapf: enable
def test_processing_cache_correctness( def test_processing_correctness(
model_id: str, model_id: str,
modalities: dict[str, bool], modalities: dict[str, bool],
hit_rate: float, hit_rate: float,
num_batches: int, num_batches: int,
simplify_rate: float, simplify_rate: float,
): ):
_test_processing_cache_correctness( _test_processing_correctness(
model_id, model_id,
modalities, modalities,
hit_rate=hit_rate, hit_rate=hit_rate,
...@@ -795,7 +814,7 @@ def test_processing_cache_correctness( ...@@ -795,7 +814,7 @@ def test_processing_cache_correctness(
@pytest.mark.parametrize("num_batches", [32]) @pytest.mark.parametrize("num_batches", [32])
@pytest.mark.parametrize("simplify_rate", [1.0]) @pytest.mark.parametrize("simplify_rate", [1.0])
# yapf: enable # yapf: enable
def test_processing_cache_correctness_phi3v( def test_processing_correctness_phi3v(
model_id: str, model_id: str,
modalities: dict[str, bool], modalities: dict[str, bool],
hit_rate: float, hit_rate: float,
...@@ -809,7 +828,7 @@ def test_processing_cache_correctness_phi3v( ...@@ -809,7 +828,7 @@ def test_processing_cache_correctness_phi3v(
AutoImageProcessor.from_pretrained(model_id, trust_remote_code=True) AutoImageProcessor.from_pretrained(model_id, trust_remote_code=True)
_test_processing_cache_correctness( _test_processing_correctness(
model_id, model_id,
modalities, modalities,
hit_rate=hit_rate, hit_rate=hit_rate,
......
...@@ -44,13 +44,13 @@ class TokensPrompt(TypedDict): ...@@ -44,13 +44,13 @@ class TokensPrompt(TypedDict):
multi_modal_data: NotRequired["MultiModalDataDict"] multi_modal_data: NotRequired["MultiModalDataDict"]
""" """
DEPRECATED: Optional multi-modal data to pass to the model, Optional multi-modal data to pass to the model,
if the model supports it. if the model supports it.
""" """
mm_processor_kwargs: NotRequired[Dict[str, Any]] mm_processor_kwargs: NotRequired[Dict[str, Any]]
""" """
DEPRECATED: Optional multi-modal processor kwargs to be forwarded to the Optional multi-modal processor kwargs to be forwarded to the
multimodal input mapper & processor. Note that if multiple modalities multimodal input mapper & processor. Note that if multiple modalities
have registered mappers etc for the model being considered, we attempt have registered mappers etc for the model being considered, we attempt
to pass the mm_processor_kwargs to each of them. to pass the mm_processor_kwargs to each of them.
......
...@@ -279,10 +279,6 @@ class InputPreprocessor: ...@@ -279,10 +279,6 @@ class InputPreprocessor:
mm_processor = self.mm_registry.create_processor( mm_processor = self.mm_registry.create_processor(
self.model_config, tokenizer) self.model_config, tokenizer)
if isinstance(prompt, list):
logger.warning("Passing `multi_modal_data` in TokensPrompt is"
"deprecated and will be removed in a future update")
prompt = tokenizer.decode(prompt)
if mm_processor_kwargs is None: if mm_processor_kwargs is None:
mm_processor_kwargs = {} mm_processor_kwargs = {}
......
...@@ -441,6 +441,24 @@ class Blip2DummyInputsBuilder(BaseDummyInputsBuilder[Blip2ProcessingInfo]): ...@@ -441,6 +441,24 @@ class Blip2DummyInputsBuilder(BaseDummyInputsBuilder[Blip2ProcessingInfo]):
class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]): class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]):
# Run the HF processor for `prompt`, short-circuiting to plain tokenization
# when no multi-modal data is supplied.
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
# Text-only path: an empty/falsy `mm_data` bypasses the parent implementation.
if not mm_data:
# HF processor always adds placeholders even when there's no image
tokenizer = self.info.get_tokenizer()
prompt_ids = tokenizer.encode(prompt)
# The token ids are wrapped in a single-element list under `input_ids`.
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
# Multi-modal path: forward all arguments unchanged to the parent class.
return super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
)
def _get_mm_fields_config( def _get_mm_fields_config(
self, self,
hf_inputs: BatchFeature, hf_inputs: BatchFeature,
...@@ -469,11 +487,11 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]): ...@@ -469,11 +487,11 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]):
def apply( def apply(
self, self,
prompt_text: str, prompt: Union[str, list[int]],
mm_data: MultiModalDataDict, mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalInputsV2: ) -> MultiModalInputsV2:
result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs) result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
# Only <image> tokens should be considered as placeholders, # Only <image> tokens should be considered as placeholders,
# so we ignore the trailing bos_token # so we ignore the trailing bos_token
......
...@@ -99,6 +99,34 @@ class ChameleonDummyInputsBuilder( ...@@ -99,6 +99,34 @@ class ChameleonDummyInputsBuilder(
class ChameleonMultiModalProcessor( class ChameleonMultiModalProcessor(
BaseMultiModalProcessor[ChameleonProcessingInfo]): BaseMultiModalProcessor[ChameleonProcessingInfo]):
# Run the HF processor for `prompt`; when there is no multi-modal data,
# tokenize the prompt directly and post-process the token ids instead.
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
# Text-only path: an empty/falsy `mm_data` bypasses the parent implementation.
if not mm_data:
prompt_ids = self.info.get_tokenizer().encode(prompt)
# Reuse the token-level hook so text and pre-tokenized prompts are
# adjusted by the same code.
prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
# The token ids are wrapped in a single-element list under `input_ids`.
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
# Multi-modal path: forward all arguments unchanged to the parent class.
return super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
)
# Return `prompt_tokens` with the tokenizer's separator token id appended.
def _apply_hf_processor_tokens_only(
self,
prompt_tokens: list[int],
) -> list[int]:
# HF processor adds sep token for chat mode
tokenizer = self.info.get_tokenizer()
# Resolve the id by looking up the tokenizer's `sep_token` in its vocab.
sep_token_id: int = \
tokenizer.vocab[tokenizer.sep_token] # type: ignore
# List concatenation builds a new list; the argument is not modified.
return prompt_tokens + [sep_token_id]
def _get_mm_fields_config( def _get_mm_fields_config(
self, self,
hf_inputs: BatchFeature, hf_inputs: BatchFeature,
...@@ -128,11 +156,11 @@ class ChameleonMultiModalProcessor( ...@@ -128,11 +156,11 @@ class ChameleonMultiModalProcessor(
def apply( def apply(
self, self,
prompt_text: str, prompt: Union[str, list[int]],
mm_data: MultiModalDataDict, mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalInputsV2: ) -> MultiModalInputsV2:
result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs) result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
# Only <image> tokens should be considered as placeholders, # Only <image> tokens should be considered as placeholders,
# so we ignore the image_start_token and image_end_token # so we ignore the image_start_token and image_end_token
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
""" PyTorch Fuyu model.""" """ PyTorch Fuyu model."""
import math import math
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict) TypedDict, Union)
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -149,14 +149,10 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]): ...@@ -149,14 +149,10 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
mm_data: Mapping[str, object], mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object], mm_kwargs: Mapping[str, object],
) -> BatchFeature: ) -> BatchFeature:
if not mm_data: if not mm_data:
# Avoid warning from HF logger for text-only input # Avoid warning from HF logger for text-only input
# Input_ids format: bos_token_id + prompt_token_ids + boa_token_id prompt_ids = self.info.get_tokenizer().encode(prompt)
# Tokenizer won't add boa_token_id by default, we add it manually. prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
tokenizer = self.info.get_tokenizer()
boa_token_id: int = tokenizer.vocab["<0x04>"] # type: ignore
prompt_ids = tokenizer.encode(prompt) + [boa_token_id]
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
processed_outputs = super()._call_hf_processor( processed_outputs = super()._call_hf_processor(
...@@ -181,6 +177,16 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]): ...@@ -181,6 +177,16 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
return processed_outputs return processed_outputs
# Return `prompt_tokens` with the id of the "<0x04>" vocab entry appended.
def _apply_hf_processor_tokens_only(
self,
prompt_tokens: list[int],
) -> list[int]:
# HF processor adds boa_token_id
tokenizer = self.info.get_tokenizer()
# The boa token is looked up by its literal vocab key "<0x04>".
boa_token_id: int = tokenizer.vocab["<0x04>"] # type: ignore
# List concatenation builds a new list; the argument is not modified.
return prompt_tokens + [boa_token_id]
def _get_mm_fields_config( def _get_mm_fields_config(
self, self,
hf_inputs: BatchFeature, hf_inputs: BatchFeature,
...@@ -223,11 +229,11 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]): ...@@ -223,11 +229,11 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
def apply( def apply(
self, self,
prompt_text: str, prompt: Union[str, list[int]],
mm_data: MultiModalDataDict, mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalInputsV2: ) -> MultiModalInputsV2:
result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs) result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
# Only |SPEAKER| (image) tokens should be considered as placeholders, # Only |SPEAKER| (image) tokens should be considered as placeholders,
# so we ignore the trailing bos_token_id # so we ignore the trailing bos_token_id
......
...@@ -39,13 +39,13 @@ class SupportsMultiModal(Protocol): ...@@ -39,13 +39,13 @@ class SupportsMultiModal(Protocol):
The output embeddings must be one of the following formats: The output embeddings must be one of the following formats:
- A list or tuple of 2D tensors, where each tensor corresponds to - A list or tuple of 2D tensors, where each tensor corresponds to
each input multimodal data item (e.g, image). each input multimodal data item (e.g, image).
- A single 3D tensor, with the batch dimension grouping the 2D tensors. - A single 3D tensor, with the batch dimension grouping the 2D tensors.
Note: Note:
The returned multimodal embeddings must be in the same order as The returned multimodal embeddings must be in the same order as
the appearances of their corresponding multimodal data item in the the appearances of their corresponding multimodal data item in the
input prompt. input prompt.
""" """
... ...
......
...@@ -724,7 +724,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): ...@@ -724,7 +724,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
def apply( def apply(
self, self,
prompt_text: str, prompt: Union[str, list[int]],
mm_data: MultiModalDataDict, mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalInputsV2: ) -> MultiModalInputsV2:
...@@ -737,7 +737,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): ...@@ -737,7 +737,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
image_height=-1, image_height=-1,
) )
result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs) result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
mm_items = self._to_mm_items(mm_data) mm_items = self._to_mm_items(mm_data)
mm_item_counts = mm_items.get_all_counts() mm_item_counts = mm_items.get_all_counts()
...@@ -760,7 +760,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): ...@@ -760,7 +760,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
) )
]) ])
prompt_ids, prompt_text, _ = self._apply_prompt_replacements( prompt_ids, prompt, _ = self._apply_prompt_replacements(
result["prompt_token_ids"], result["prompt_token_ids"],
mantis_mm_repls, mantis_mm_repls,
mm_item_counts, mm_item_counts,
...@@ -788,7 +788,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): ...@@ -788,7 +788,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
return MultiModalInputsV2( return MultiModalInputsV2(
type="multimodal", type="multimodal",
prompt=prompt_text, prompt=prompt,
prompt_token_ids=prompt_ids, prompt_token_ids=prompt_ids,
mm_kwargs=mm_kwargs, mm_kwargs=mm_kwargs,
mm_placeholders=mm_placeholder_ranges, mm_placeholders=mm_placeholder_ranges,
......
...@@ -481,11 +481,11 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]): ...@@ -481,11 +481,11 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
def apply( def apply(
self, self,
prompt_text: str, prompt: Union[str, list[int]],
mm_data: MultiModalDataDict, mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalInputsV2: ) -> MultiModalInputsV2:
result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs) result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
# Only <|image|> tokens should be considered as placeholders, # Only <|image|> tokens should be considered as placeholders,
# so we ignore the trailing bos_token_id # so we ignore the trailing bos_token_id
......
...@@ -138,12 +138,8 @@ class UltravoxMultiModalProcessor( ...@@ -138,12 +138,8 @@ class UltravoxMultiModalProcessor(
) -> BatchFeature: ) -> BatchFeature:
# Text-only input not supported in composite processor # Text-only input not supported in composite processor
if not mm_data: if not mm_data:
tokenizer = self.info.get_tokenizer() prompt_ids = self.info.get_tokenizer().encode(prompt)
prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
prompt_ids = tokenizer.encode(
prompt,
add_special_tokens=False, # type: ignore
)
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
mm_data = dict(mm_data) mm_data = dict(mm_data)
...@@ -188,6 +184,16 @@ class UltravoxMultiModalProcessor( ...@@ -188,6 +184,16 @@ class UltravoxMultiModalProcessor(
) )
return BatchFeature(combined_outputs) return BatchFeature(combined_outputs)
# Return `prompt_tokens` without its first element, which is asserted to be
# the tokenizer's `bos_token_id`.
def _apply_hf_processor_tokens_only(
self,
prompt_tokens: list[int],
) -> list[int]:
# HF processor omits bos_token_id by setting add_special_tokens=False
tokenizer = self.info.get_tokenizer()
# NOTE(review): `prompt_tokens[0]` raises IndexError on an empty list, and
# the assert is skipped under `python -O`; confirm callers always pass a
# BOS-prefixed, non-empty prompt.
assert prompt_tokens[0] == tokenizer.bos_token_id
# Slicing builds a new list; the argument is not modified.
return prompt_tokens[1:]
def _get_mm_fields_config( def _get_mm_fields_config(
self, self,
hf_inputs: BatchFeature, hf_inputs: BatchFeature,
......
...@@ -725,15 +725,15 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -725,15 +725,15 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_kwargs, mm_kwargs,
) )
def _apply_hf_processor( def _apply_hf_processor_text_mm(
self, self,
prompt_text: str, prompt_text: str,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
) -> tuple[list[int], MultiModalKwargs]: ) -> tuple[list[int], MultiModalKwargs]:
""" """
Wrapper of :meth:`_call_hf_processor` that applies Apply the HF processor on the prompt text and multi-modal data
additional pre-processing and post-processing. together.
""" """
processor_data, passthrough_data = self._get_hf_mm_data(mm_items) processor_data, passthrough_data = self._get_hf_mm_data(mm_items)
...@@ -753,40 +753,93 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -753,40 +753,93 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
return prompt_ids, mm_kwargs return prompt_ids, mm_kwargs
def _apply_hf_processor_missing( def _apply_hf_processor_text_only(self, prompt_text: str) -> list[int]:
self,
prompt_text: str,
mm_missing_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
):
""" """
Apply the HF processor on the full prompt text, but only on the Apply the HF processor on the prompt text only.
multi-modal data that are missing from the cache.
Note: Since HF processor requires that text and multi-modal items
We pass prompt text and multi-modal data into the HF processor correspond to each other, we create dummy multi-modal items
in separate calls to avoid HF prompt replacement being done for to go along with the text.
cached items; instead, we rely on our own prompt replacement logic
(:meth:`_get_prompt_replacements`) for the full text.
""" """
mm_missing_counts = mm_missing_data_items.get_all_counts() prompt_ids, _ = self._apply_hf_processor_text_mm(
prompt_ids, _ = self._apply_hf_processor(
prompt_text=prompt_text, prompt_text=prompt_text,
mm_items=MultiModalDataItems({}), mm_items=MultiModalDataItems({}),
hf_processor_mm_kwargs={}, hf_processor_mm_kwargs={},
) )
# Some HF processors (e.g. Qwen2-VL) expect corresponding return prompt_ids
# multi-modal tokens to be in the prompt text
# Hook for adjusting pre-tokenized prompts; the default is the identity.
def _apply_hf_processor_tokens_only(
self,
prompt_tokens: list[int],
) -> list[int]:
"""
Apply the HF processor on the prompt tokens only.
Most HF processors accept prompt text but not prompt tokens.
If the HF processor adds or removes tokens that are not related to
multi-modal data, you should override this method so it is consistent
with the output of :meth:`_apply_hf_processor_text_only` on the
corresponding text.
"""
# Default: return the same list object that was passed in, unchanged.
return prompt_tokens
def _apply_hf_processor_mm_only(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalKwargs:
"""
Apply the HF processor on the multi-modal data only.
Since HF processor requires that text and multi-modal items
correspond to each other, we generate dummy text using
:class:`DummyInputsBuilder` to go along with the multi-modal data.
"""
mm_counts = mm_items.get_all_counts()
dummy_inputs = self.dummy_inputs.get_dummy_processor_inputs( dummy_inputs = self.dummy_inputs.get_dummy_processor_inputs(
self.info.ctx.model_config.max_model_len, self.info.ctx.model_config.max_model_len,
mm_missing_counts, mm_counts,
) )
_, mm_missing_kwargs = self._apply_hf_processor( _, mm_kwargs = self._apply_hf_processor_text_mm(
prompt_text=dummy_inputs.prompt_text, prompt_text=dummy_inputs.prompt_text,
mm_items=mm_missing_data_items, mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
)
return mm_kwargs
def _apply_hf_processor_main(
self,
prompt: Union[str, list[int]],
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
*,
enable_hf_prompt_replacement: bool,
) -> tuple[list[int], MultiModalKwargs]:
"""
Apply the HF processor on the prompt text and multi-modal data.
Note:
If :code:`enable_hf_prompt_replacement=False`, the prompt should
correspond to the multi-modal items.
"""
if isinstance(prompt, str):
if enable_hf_prompt_replacement:
return self._apply_hf_processor_text_mm(
prompt_text=prompt,
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
)
prompt_ids = self._apply_hf_processor_text_only(prompt)
else:
prompt_ids = self._apply_hf_processor_tokens_only(prompt)
mm_missing_kwargs = self._apply_hf_processor_mm_only(
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
) )
...@@ -794,7 +847,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -794,7 +847,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
def _cached_apply_hf_processor( def _cached_apply_hf_processor(
self, self,
prompt_text: str, prompt: Union[str, list[int]],
mm_data_items: MultiModalDataItems, mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
) -> tuple[list[int], MultiModalKwargs]: ) -> tuple[list[int], MultiModalKwargs]:
...@@ -807,10 +860,11 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -807,10 +860,11 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
_, passthrough_data = self._get_hf_mm_data(mm_data_items) _, passthrough_data = self._get_hf_mm_data(mm_data_items)
if cache is None or passthrough_data: if cache is None or passthrough_data:
return self._apply_hf_processor( return self._apply_hf_processor_main(
prompt_text=prompt_text, prompt=prompt,
mm_items=mm_data_items, mm_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
enable_hf_prompt_replacement=True,
) )
mm_maybe_cached_kw_items = { mm_maybe_cached_kw_items = {
...@@ -832,10 +886,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -832,10 +886,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
} }
mm_missing_data_items = self._to_mm_items(mm_missing_data) mm_missing_data_items = self._to_mm_items(mm_missing_data)
prompt_ids, mm_missing_kwargs = self._apply_hf_processor_missing( # NOTE: `prompt` does not correspond to `mm_missing_data_items`,
prompt_text=prompt_text, # so we need to pass `enable_hf_prompt_replacement=False`
mm_missing_data_items=mm_missing_data_items, prompt_ids, mm_missing_kwargs = self._apply_hf_processor_main(
prompt=prompt,
mm_items=mm_missing_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
enable_hf_prompt_replacement=False,
) )
mm_missing_next_idx = { mm_missing_next_idx = {
...@@ -1018,7 +1075,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1018,7 +1075,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
def apply( def apply(
self, self,
prompt_text: str, prompt: Union[str, list[int]],
mm_data: MultiModalDataDict, mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalInputsV2: ) -> MultiModalInputsV2:
...@@ -1056,7 +1113,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1056,7 +1113,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_hashes = None mm_hashes = None
prompt_ids, mm_kwargs = self._cached_apply_hf_processor( prompt_ids, mm_kwargs = self._cached_apply_hf_processor(
prompt_text, prompt,
mm_items, mm_items,
hf_processor_mm_kwargs, hf_processor_mm_kwargs,
) )
...@@ -1101,12 +1158,12 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1101,12 +1158,12 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
# there is no need for us to insert them # there is no need for us to insert them
if all(len(repls) == 0 for repls in mm_missing_repls.items()): if all(len(repls) == 0 for repls in mm_missing_repls.items()):
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
prompt_text = decode_tokens(tokenizer, prompt_ids) prompt = decode_tokens(tokenizer, prompt_ids)
mm_placeholders = hf_mm_placeholders mm_placeholders = hf_mm_placeholders
else: else:
( (
prompt_ids, prompt_ids,
prompt_text, prompt,
missing_mm_placeholders, missing_mm_placeholders,
) = self._apply_prompt_replacements( ) = self._apply_prompt_replacements(
prompt_ids, prompt_ids,
...@@ -1125,7 +1182,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1125,7 +1182,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
return MultiModalInputsV2( return MultiModalInputsV2(
type="multimodal", type="multimodal",
prompt=prompt_text, prompt=prompt,
prompt_token_ids=prompt_ids, prompt_token_ids=prompt_ids,
mm_kwargs=mm_kwargs, mm_kwargs=mm_kwargs,
mm_hashes=mm_hashes, mm_hashes=mm_hashes,
......
...@@ -137,7 +137,7 @@ class MultiModalProfiler(Generic[_I]): ...@@ -137,7 +137,7 @@ class MultiModalProfiler(Generic[_I]):
seq_len, mm_counts) seq_len, mm_counts)
return self.processor.apply( return self.processor.apply(
prompt_text=processor_inputs.prompt_text, prompt=processor_inputs.prompt_text,
mm_data=processor_inputs.mm_data, mm_data=processor_inputs.mm_data,
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs, hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment