Unverified Commit 3d54bdcb authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Optimization] Streamline `InputPreprocessor` (#25702)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 6b0fcbbf
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any, Optional, Union, cast from typing import Any, Optional, Union, cast
...@@ -13,6 +12,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry ...@@ -13,6 +12,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import BaseMultiModalProcessorCache from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalInputs, MultiModalUUIDDict) MultiModalInputs, MultiModalUUIDDict)
from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from .data import (DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt, from .data import (DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt,
...@@ -200,20 +200,6 @@ class InputPreprocessor: ...@@ -200,20 +200,6 @@ class InputPreprocessor:
return tokenizer.encode(prompt, **tokenization_kwargs) return tokenizer.encode(prompt, **tokenization_kwargs)
async def _tokenize_prompt_async(
self,
prompt: str,
tokenization_kwargs: Optional[dict[str, Any]] = None,
) -> list[int]:
"""
Async version of
[`_tokenize_prompt`][vllm.inputs.preprocess.InputPreprocessor._tokenize_prompt].
"""
tokenizer = self.get_tokenizer()
tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs)
return tokenizer.encode(prompt, **tokenization_kwargs)
def _get_mm_tokenizer(self) -> AnyTokenizer: def _get_mm_tokenizer(self) -> AnyTokenizer:
# PrithviGeoSpatialMAE needs to be initialized without a tokenizer # PrithviGeoSpatialMAE needs to be initialized without a tokenizer
# while using also multi-modal input # while using also multi-modal input
...@@ -223,14 +209,17 @@ class InputPreprocessor: ...@@ -223,14 +209,17 @@ class InputPreprocessor:
tokenizer = self.get_tokenizer() tokenizer = self.get_tokenizer()
return tokenizer return tokenizer
async def _get_mm_tokenizer_async(self) -> AnyTokenizer: def _get_mm_processor(self) -> BaseMultiModalProcessor:
# PrithviGeoSpatialMAE needs to be initialized without a tokenizer if not hasattr(self, "_mm_processor"):
# while using also multi-modal input tokenizer = self._get_mm_tokenizer()
if not self.tokenizer:
return cast(AnyTokenizer, object()) # Dummy
tokenizer = self.get_tokenizer() self._mm_processor = self.mm_registry.create_processor(
return tokenizer self.model_config,
tokenizer=tokenizer,
cache=self.mm_processor_cache,
)
return self._mm_processor
def _process_multimodal( def _process_multimodal(
self, self,
...@@ -245,55 +234,7 @@ class InputPreprocessor: ...@@ -245,55 +234,7 @@ class InputPreprocessor:
Apply the model's multi-modal processor to a multi-modal prompt, Apply the model's multi-modal processor to a multi-modal prompt,
returning the corresponding token IDs and metadata. returning the corresponding token IDs and metadata.
""" """
tokenizer = self._get_mm_tokenizer() mm_processor = self._get_mm_processor()
mm_processor = self.mm_registry.create_processor(
self.model_config,
tokenizer=tokenizer,
cache=self.mm_processor_cache,
)
if mm_processor_kwargs is None:
mm_processor_kwargs = {}
mm_input = mm_processor.apply(
prompt,
mm_data,
hf_processor_mm_kwargs=mm_processor_kwargs,
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
)
mm_hashes = mm_input["mm_hashes"]
# Validate that all mm items have a string as their hash
if not contains_only_strings(mm_hashes):
raise ValueError(
f"mm_hashes must contain only strings, got: {mm_hashes}. "
"This is likely due to an incorrect custom implementation of "
"MultiModalProcessor.apply method.")
return mm_input
async def _process_multimodal_async(
self,
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
mm_processor_kwargs: Optional[Mapping[str, object]],
tokenization_kwargs: Optional[dict[str, Any]] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> MultiModalInputs:
"""
Async version of
[`_process_multimodal`][vllm.inputs.preprocess.InputPreprocessor._process_multimodal].
"""
tokenizer = await self._get_mm_tokenizer_async()
mm_processor = self.mm_registry.create_processor(
self.model_config,
tokenizer=tokenizer,
cache=self.mm_processor_cache,
)
if mm_processor_kwargs is None: if mm_processor_kwargs is None:
mm_processor_kwargs = {} mm_processor_kwargs = {}
...@@ -340,12 +281,6 @@ class InputPreprocessor: ...@@ -340,12 +281,6 @@ class InputPreprocessor:
return embeds_inputs(prompt_embeds=prompt_embeds, return embeds_inputs(prompt_embeds=prompt_embeds,
cache_salt=parsed_content.get("cache_salt")) cache_salt=parsed_content.get("cache_salt"))
async def _process_embeds_async(
self,
parsed_content: EmbedsPrompt,
) -> EmbedsInputs:
return self._process_embeds(parsed_content)
def _truncate_inputs( def _truncate_inputs(
self, self,
inputs: list[int], inputs: list[int],
...@@ -389,33 +324,6 @@ class InputPreprocessor: ...@@ -389,33 +324,6 @@ class InputPreprocessor:
return inputs return inputs
async def _process_tokens_async(
self,
parsed_content: TokensPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> Union[TokenInputs, MultiModalInputs]:
prompt_token_ids = self._truncate_inputs(
parsed_content["prompt_token_ids"], tokenization_kwargs)
inputs: Union[TokenInputs, MultiModalInputs]
if multi_modal_data := parsed_content.get("multi_modal_data"):
inputs = await self._process_multimodal_async(
prompt_token_ids,
multi_modal_data,
parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
)
else:
inputs = token_inputs(prompt_token_ids=prompt_token_ids, )
if cache_salt := parsed_content.get("cache_salt"):
inputs["cache_salt"] = cache_salt
return inputs
def _process_text( def _process_text(
self, self,
parsed_content: TextPrompt, parsed_content: TextPrompt,
...@@ -449,39 +357,6 @@ class InputPreprocessor: ...@@ -449,39 +357,6 @@ class InputPreprocessor:
return inputs return inputs
async def _process_text_async(
self,
parsed_content: TextPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> Union[TokenInputs, MultiModalInputs]:
prompt_text = parsed_content["prompt"]
inputs: Union[TokenInputs, MultiModalInputs]
if multi_modal_data := parsed_content.get("multi_modal_data"):
inputs = await self._process_multimodal_async(
prompt_text,
multi_modal_data,
parsed_content.get("mm_processor_kwargs"),
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
)
else:
prompt_token_ids = await self._tokenize_prompt_async(
prompt_text,
tokenization_kwargs=tokenization_kwargs,
)
inputs = token_inputs(
prompt=prompt_text,
prompt_token_ids=prompt_token_ids,
)
if cache_salt := parsed_content.get("cache_salt"):
inputs["cache_salt"] = cache_salt
return inputs
def _prompt_to_llm_inputs( def _prompt_to_llm_inputs(
self, self,
prompt: SingletonPrompt, prompt: SingletonPrompt,
...@@ -524,41 +399,6 @@ class InputPreprocessor: ...@@ -524,41 +399,6 @@ class InputPreprocessor:
assert_never(parsed) assert_never(parsed)
async def _prompt_to_llm_inputs_async(
self,
prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> SingletonInputs:
"""
Async version of
[`_prompt_to_llm_inputs`][vllm.inputs.preprocess.InputPreprocessor._prompt_to_llm_inputs].
"""
parsed = parse_singleton_prompt(prompt)
if parsed["type"] == "embeds":
return await self._process_embeds_async(parsed["content"])
if parsed["type"] == "tokens":
return await self._process_tokens_async(
parsed["content"],
mm_uuids=mm_uuids,
)
if parsed["type"] == "text":
return await self._process_text_async(
parsed["content"],
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
)
if parsed["type"] == "str":
return await self._process_text_async(
TextPrompt(prompt=parsed["content"]),
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
)
assert_never(parsed)
def _build_enc_dec_llm_inputs( def _build_enc_dec_llm_inputs(
self, self,
encoder_inputs: SingletonInputs, encoder_inputs: SingletonInputs,
...@@ -735,62 +575,6 @@ class InputPreprocessor: ...@@ -735,62 +575,6 @@ class InputPreprocessor:
return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs) return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)
async def _process_encoder_decoder_prompt_async(
self,
prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> EncoderDecoderInputs:
"""
Async version of
[`_process_encoder_decoder_prompt`][vllm.inputs.preprocess.InputPreprocessor._process_encoder_decoder_prompt].
"""
encoder_inputs: SingletonInputs
decoder_inputs: Optional[SingletonInputs]
if is_explicit_encoder_decoder_prompt(prompt):
encoder_task = self._prompt_to_llm_inputs_async(
prompt["encoder_prompt"],
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
)
if (decoder_input := prompt["decoder_prompt"]) is None:
encoder_inputs = await encoder_task
decoder_inputs = None
else:
decoder_task = self._prompt_to_llm_inputs_async(
decoder_input,
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
)
encoder_inputs, decoder_inputs = await asyncio.gather(
encoder_task, decoder_task)
# For multimodal model, override decoder prompt from processor
# with explicit decoder prompt.
if self.model_config.is_multimodal_model:
encoder_inputs, decoder_inputs = (
self._split_enc_dec_mm_inputs(encoder_inputs,
decoder_inputs))
else:
inputs = await self._prompt_to_llm_inputs_async(
prompt,
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
)
if self.model_config.is_multimodal_model:
# Encoder-Decoder Multimodal model
encoder_inputs, decoder_inputs = (
self._split_enc_dec_mm_inputs(inputs))
else:
encoder_inputs = inputs
decoder_inputs = None
return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)
def _build_decoder_only_llm_inputs( def _build_decoder_only_llm_inputs(
self, self,
prompt_inputs: DecoderOnlyInputs, prompt_inputs: DecoderOnlyInputs,
...@@ -830,25 +614,6 @@ class InputPreprocessor: ...@@ -830,25 +614,6 @@ class InputPreprocessor:
return self._build_decoder_only_llm_inputs(prompt_comps) return self._build_decoder_only_llm_inputs(prompt_comps)
async def _process_decoder_only_prompt_async(
self,
prompt: SingletonPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> DecoderOnlyInputs:
"""
Async version of
[`_process_decoder_only_prompt`][vllm.inputs.preprocess.InputPreprocessor._process_decoder_only_prompt].
"""
prompt_comps = await self._prompt_to_llm_inputs_async(
prompt,
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
)
return self._build_decoder_only_llm_inputs(prompt_comps)
def preprocess( def preprocess(
self, self,
prompt: PromptType, prompt: PromptType,
...@@ -877,37 +642,6 @@ class InputPreprocessor: ...@@ -877,37 +642,6 @@ class InputPreprocessor:
mm_uuids=mm_uuids, mm_uuids=mm_uuids,
) )
async def preprocess_async(
self,
prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None,
*,
mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> ProcessorInputs:
"""
Async version of
[`preprocess`][vllm.inputs.preprocess.InputPreprocessor.preprocess].
"""
if self.model_config.is_encoder_decoder:
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder.
return await self._process_encoder_decoder_prompt_async(
prompt,
tokenization_kwargs,
mm_uuids=mm_uuids,
)
if is_explicit_encoder_decoder_prompt(prompt):
raise ValueError("Cannot pass encoder-decoder prompt "
"to decoder-only models")
# Decoder-only operation
return await self._process_decoder_only_prompt_async(
prompt,
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
)
def clear_cache(self) -> None: def clear_cache(self) -> None:
if self.mm_processor_cache is not None: if self.mm_processor_cache is not None:
self.mm_processor_cache.clear_cache() self.mm_processor_cache.clear_cache()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment