"vscode:/vscode.git/clone" did not exist on "904b87ba0564080205c51ca2cb6ba0dc0b5d3bc3"
Unverified commit 88c3e114, authored by Cyrus Leung, committed by GitHub
Browse files

[Refactor] Move MM data parsing outside processor (#33408)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 92924b2d
...@@ -25,7 +25,6 @@ from vllm.utils.collection_utils import flatten_2d_lists, full_groupby ...@@ -25,7 +25,6 @@ from vllm.utils.collection_utils import flatten_2d_lists, full_groupby
from ..hasher import MultiModalHasher from ..hasher import MultiModalHasher
from ..inputs import ( from ..inputs import (
MultiModalDataDict,
MultiModalEncDecInputs, MultiModalEncDecInputs,
MultiModalFieldConfig, MultiModalFieldConfig,
MultiModalHashes, MultiModalHashes,
...@@ -1013,39 +1012,12 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1013,39 +1012,12 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
def __call__( def __call__(
self, self,
prompt: str, prompt: str,
mm_data: MultiModalDataDict, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
*, *,
mm_uuids: MultiModalUUIDDict | None = None, mm_uuids: MultiModalUUIDDict | None = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
return self.apply(prompt, mm_data, hf_processor_mm_kwargs, mm_uuids=mm_uuids) return self.apply(prompt, mm_items, hf_processor_mm_kwargs, mm_uuids=mm_uuids)
def _to_mm_items(
self,
mm_data: MultiModalDataDict,
) -> MultiModalDataItems:
"""
Normalize
[`MultiModalDataDict`][vllm.multimodal.inputs.MultiModalDataDict]
to [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems]
before passing them to
[`_get_hf_mm_data`][vllm.multimodal.processing.BaseMultiModalProcessor._get_hf_mm_data].
"""
mm_items = self.data_parser.parse_mm_data(mm_data)
mm_config = self.info.ctx.model_config.get_multimodal_config()
if not mm_config.enable_mm_embeds:
for modality, items in mm_items.items():
if isinstance(items, (EmbeddingItems, DictEmbeddingItems)):
raise ValueError(
f"You must set `--enable-mm-embeds` to input "
f"`{modality}_embeds`"
)
for modality, items in mm_items.items():
self.info.validate_num_items(modality, len(items))
return mm_items
@abstractmethod @abstractmethod
def _get_mm_fields_config( def _get_mm_fields_config(
...@@ -1409,6 +1381,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1409,6 +1381,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
] ]
for modality, items_is_cached in mm_is_cached.items() for modality, items_is_cached in mm_is_cached.items()
} }
mm_missing_data = {} mm_missing_data = {}
for modality, idxs in mm_missing_idxs.items(): for modality, idxs in mm_missing_idxs.items():
missing_modality_data = [] missing_modality_data = []
...@@ -1423,7 +1396,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1423,7 +1396,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
missing_modality_data.append(data) missing_modality_data.append(data)
mm_missing_data[modality] = missing_modality_data mm_missing_data[modality] = missing_modality_data
return mm_is_cached, self._to_mm_items(mm_missing_data) mm_missing_items = self.info.parse_mm_data(mm_missing_data)
return mm_is_cached, mm_missing_items
def _recompute_cached_prompt_update( def _recompute_cached_prompt_update(
self, self,
...@@ -1774,7 +1749,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1774,7 +1749,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
def apply( def apply(
self, self,
prompt: str | list[int], prompt: str | list[int],
mm_data: MultiModalDataDict, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object] | None = None, tokenization_kwargs: Mapping[str, object] | None = None,
*, *,
...@@ -1797,8 +1772,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1797,8 +1772,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
if request_id is not None: if request_id is not None:
self.info.ctx.create_timing_stats(request_id) self.info.ctx.create_timing_stats(request_id)
mm_items = self._to_mm_items(mm_data)
if tokenization_kwargs is None: if tokenization_kwargs is None:
tokenization_kwargs = {} tokenization_kwargs = {}
...@@ -1843,7 +1816,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -1843,7 +1816,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
def create_encoder_prompt( def create_encoder_prompt(
self, self,
prompt: str | list[int], prompt: str | list[int],
mm_data: MultiModalDataDict, mm_items: MultiModalDataItems,
) -> str | list[int]: ) -> str | list[int]:
""" """
Create input prompt for the encoder. HF processor will be applied on Create input prompt for the encoder. HF processor will be applied on
...@@ -1854,7 +1827,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -1854,7 +1827,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
def create_decoder_prompt( def create_decoder_prompt(
self, self,
prompt: str | list[int], prompt: str | list[int],
mm_data: MultiModalDataDict, mm_items: MultiModalDataItems,
) -> str | list[int]: ) -> str | list[int]:
"""Create input prompt for the decoder.""" """Create input prompt for the decoder."""
return prompt return prompt
...@@ -1862,11 +1835,11 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -1862,11 +1835,11 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
def _get_enc_dec_inputs( def _get_enc_dec_inputs(
self, self,
prompt: str | list[int], prompt: str | list[int],
mm_data: MultiModalDataDict, mm_items: MultiModalDataItems,
encoder_inputs: MultiModalInputs, encoder_inputs: MultiModalInputs,
): ):
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
decoder_prompt_raw = self.create_decoder_prompt(prompt, mm_data) decoder_prompt_raw = self.create_decoder_prompt(prompt, mm_items)
if isinstance(decoder_prompt_raw, str): if isinstance(decoder_prompt_raw, str):
decoder_prompt_ids = tokenizer.encode( decoder_prompt_ids = tokenizer.encode(
decoder_prompt_raw, add_special_tokens=False decoder_prompt_raw, add_special_tokens=False
...@@ -1884,7 +1857,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -1884,7 +1857,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
def apply( def apply(
self, self,
prompt: str | list[int], prompt: str | list[int],
mm_data: MultiModalDataDict, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object] | None = None, tokenization_kwargs: Mapping[str, object] | None = None,
*, *,
...@@ -1897,10 +1870,10 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -1897,10 +1870,10 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
2. Apply the HF processor on encoder prompt. 2. Apply the HF processor on encoder prompt.
3. Copy the input prompt text as decoder prompt inputs. 3. Copy the input prompt text as decoder prompt inputs.
""" """
encoder_prompt = self.create_encoder_prompt(prompt, mm_data) encoder_prompt = self.create_encoder_prompt(prompt, mm_items)
encoder_inputs = super().apply( encoder_inputs = super().apply(
encoder_prompt, encoder_prompt,
mm_data, mm_items,
hf_processor_mm_kwargs, hf_processor_mm_kwargs,
tokenization_kwargs, tokenization_kwargs,
mm_uuids=mm_uuids, mm_uuids=mm_uuids,
...@@ -1908,6 +1881,6 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -1908,6 +1881,6 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
return self._get_enc_dec_inputs( return self._get_enc_dec_inputs(
prompt=prompt, prompt=prompt,
mm_data=mm_data, mm_items=mm_items,
encoder_inputs=encoder_inputs, encoder_inputs=encoder_inputs,
) )
...@@ -330,7 +330,7 @@ class MultiModalRegistry: ...@@ -330,7 +330,7 @@ class MultiModalRegistry:
) )
mm_inputs = processor.apply( mm_inputs = processor.apply(
prompt=processor_inputs.prompt, prompt=processor_inputs.prompt,
mm_data=processor_inputs.mm_data, mm_items=processor_inputs.mm_items,
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs, hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
tokenization_kwargs=processor_inputs.tokenization_kwargs, tokenization_kwargs=processor_inputs.tokenization_kwargs,
) )
......
...@@ -212,7 +212,7 @@ class InputProcessor: ...@@ -212,7 +212,7 @@ class InputProcessor:
def _parse_mm_items(self, mm_data: MultiModalDataDict) -> MultiModalDataItems: def _parse_mm_items(self, mm_data: MultiModalDataDict) -> MultiModalDataItems:
mm_processor = self.input_preprocessor._get_mm_processor() mm_processor = self.input_preprocessor._get_mm_processor()
return mm_processor.data_parser.parse_mm_data(mm_data) return mm_processor.info.parse_mm_data(mm_data)
def _validate_singleton_mm_uuids(self, prompt: SingletonPrompt) -> None: def _validate_singleton_mm_uuids(self, prompt: SingletonPrompt) -> None:
if isinstance(prompt, str): if isinstance(prompt, str):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment