Unverified Commit 88c3e114 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Refactor] Move MM data parsing outside processor (#33408)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 92924b2d
......@@ -25,7 +25,6 @@ from vllm.utils.collection_utils import flatten_2d_lists, full_groupby
from ..hasher import MultiModalHasher
from ..inputs import (
MultiModalDataDict,
MultiModalEncDecInputs,
MultiModalFieldConfig,
MultiModalHashes,
......@@ -1013,39 +1012,12 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
def __call__(
self,
prompt: str,
mm_data: MultiModalDataDict,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
*,
mm_uuids: MultiModalUUIDDict | None = None,
) -> MultiModalInputs:
return self.apply(prompt, mm_data, hf_processor_mm_kwargs, mm_uuids=mm_uuids)
def _to_mm_items(
self,
mm_data: MultiModalDataDict,
) -> MultiModalDataItems:
"""
Normalize
[`MultiModalDataDict`][vllm.multimodal.inputs.MultiModalDataDict]
to [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems]
before passing them to
[`_get_hf_mm_data`][vllm.multimodal.processing.BaseMultiModalProcessor._get_hf_mm_data].
"""
mm_items = self.data_parser.parse_mm_data(mm_data)
mm_config = self.info.ctx.model_config.get_multimodal_config()
if not mm_config.enable_mm_embeds:
for modality, items in mm_items.items():
if isinstance(items, (EmbeddingItems, DictEmbeddingItems)):
raise ValueError(
f"You must set `--enable-mm-embeds` to input "
f"`{modality}_embeds`"
)
for modality, items in mm_items.items():
self.info.validate_num_items(modality, len(items))
return mm_items
return self.apply(prompt, mm_items, hf_processor_mm_kwargs, mm_uuids=mm_uuids)
@abstractmethod
def _get_mm_fields_config(
......@@ -1409,6 +1381,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
]
for modality, items_is_cached in mm_is_cached.items()
}
mm_missing_data = {}
for modality, idxs in mm_missing_idxs.items():
missing_modality_data = []
......@@ -1423,7 +1396,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
missing_modality_data.append(data)
mm_missing_data[modality] = missing_modality_data
return mm_is_cached, self._to_mm_items(mm_missing_data)
mm_missing_items = self.info.parse_mm_data(mm_missing_data)
return mm_is_cached, mm_missing_items
def _recompute_cached_prompt_update(
self,
......@@ -1774,7 +1749,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
def apply(
self,
prompt: str | list[int],
mm_data: MultiModalDataDict,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object] | None = None,
*,
......@@ -1797,8 +1772,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
if request_id is not None:
self.info.ctx.create_timing_stats(request_id)
mm_items = self._to_mm_items(mm_data)
if tokenization_kwargs is None:
tokenization_kwargs = {}
......@@ -1843,7 +1816,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
def create_encoder_prompt(
self,
prompt: str | list[int],
mm_data: MultiModalDataDict,
mm_items: MultiModalDataItems,
) -> str | list[int]:
"""
Create input prompt for the encoder. HF processor will be applied on
......@@ -1854,7 +1827,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
def create_decoder_prompt(
self,
prompt: str | list[int],
mm_data: MultiModalDataDict,
mm_items: MultiModalDataItems,
) -> str | list[int]:
"""Create input prompt for the decoder."""
return prompt
......@@ -1862,11 +1835,11 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
def _get_enc_dec_inputs(
self,
prompt: str | list[int],
mm_data: MultiModalDataDict,
mm_items: MultiModalDataItems,
encoder_inputs: MultiModalInputs,
):
tokenizer = self.info.get_tokenizer()
decoder_prompt_raw = self.create_decoder_prompt(prompt, mm_data)
decoder_prompt_raw = self.create_decoder_prompt(prompt, mm_items)
if isinstance(decoder_prompt_raw, str):
decoder_prompt_ids = tokenizer.encode(
decoder_prompt_raw, add_special_tokens=False
......@@ -1884,7 +1857,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
def apply(
self,
prompt: str | list[int],
mm_data: MultiModalDataDict,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object] | None = None,
*,
......@@ -1897,10 +1870,10 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
2. Apply the HF processor on encoder prompt.
3. Copy the input prompt text as decoder prompt inputs.
"""
encoder_prompt = self.create_encoder_prompt(prompt, mm_data)
encoder_prompt = self.create_encoder_prompt(prompt, mm_items)
encoder_inputs = super().apply(
encoder_prompt,
mm_data,
mm_items,
hf_processor_mm_kwargs,
tokenization_kwargs,
mm_uuids=mm_uuids,
......@@ -1908,6 +1881,6 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
return self._get_enc_dec_inputs(
prompt=prompt,
mm_data=mm_data,
mm_items=mm_items,
encoder_inputs=encoder_inputs,
)
......@@ -330,7 +330,7 @@ class MultiModalRegistry:
)
mm_inputs = processor.apply(
prompt=processor_inputs.prompt,
mm_data=processor_inputs.mm_data,
mm_items=processor_inputs.mm_items,
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
tokenization_kwargs=processor_inputs.tokenization_kwargs,
)
......
......@@ -212,7 +212,7 @@ class InputProcessor:
def _parse_mm_items(self, mm_data: MultiModalDataDict) -> MultiModalDataItems:
mm_processor = self.input_preprocessor._get_mm_processor()
return mm_processor.data_parser.parse_mm_data(mm_data)
return mm_processor.info.parse_mm_data(mm_data)
def _validate_singleton_mm_uuids(self, prompt: SingletonPrompt) -> None:
if isinstance(prompt, str):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment