Unverified commit 7bd42e60, authored by Cyrus Leung and committed by GitHub
Browse files

[Refactor] Clean up input preprocessing (#33687)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent a2522839
...@@ -12,11 +12,13 @@ from vllm.sampling_params import SamplingParams ...@@ -12,11 +12,13 @@ from vllm.sampling_params import SamplingParams
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.multimodal.inputs import ( from vllm.multimodal.inputs import (
MultiModalDataDict, MultiModalDataDict,
MultiModalEncDecInputs,
MultiModalInputs, MultiModalInputs,
MultiModalUUIDDict, MultiModalUUIDDict,
) )
else: else:
MultiModalDataDict = object MultiModalDataDict = object
MultiModalEncDecInputs = object
MultiModalInputs = object MultiModalInputs = object
MultiModalUUIDDict = object MultiModalUUIDDict = object
...@@ -241,7 +243,7 @@ class EncoderDecoderInputs(TypedDict): ...@@ -241,7 +243,7 @@ class EncoderDecoderInputs(TypedDict):
This specifies the required data for encoder-decoder models. This specifies the required data for encoder-decoder models.
""" """
encoder: TokenInputs | MultiModalInputs encoder: TokenInputs | MultiModalEncDecInputs
"""The inputs for the encoder portion.""" """The inputs for the encoder portion."""
decoder: TokenInputs | MultiModalInputs decoder: TokenInputs | MultiModalInputs
......
...@@ -69,6 +69,22 @@ def is_explicit_encoder_decoder_prompt( ...@@ -69,6 +69,22 @@ def is_explicit_encoder_decoder_prompt(
return isinstance(prompt, dict) and "encoder_prompt" in prompt return isinstance(prompt, dict) and "encoder_prompt" in prompt
def split_enc_dec_prompt(
    prompt: PromptType,
) -> tuple[SingletonPrompt, SingletonPrompt | None]:
    """Split a prompt into its encoder part and its (optional) decoder part.

    Args:
        prompt: Either a singleton prompt (a plain string or a prompt dict)
            or an explicit encoder/decoder prompt, i.e. a dict carrying an
            ``"encoder_prompt"`` key.

    Returns:
        An ``(encoder_prompt, decoder_prompt)`` tuple. For a singleton
        prompt the prompt itself is returned as the encoder part and the
        decoder part is ``None``. For an explicit encoder/decoder prompt the
        two sub-prompts are returned; the decoder part is ``None`` when
        ``"decoder_prompt"`` is ``None`` or absent.
    """
    # Strings must be handled first: `in` on a `str` is a substring test,
    # so the key check below would give a meaningless answer for them.
    if isinstance(prompt, str):
        return prompt, None

    # Key on "encoder_prompt" only, the same test that
    # `is_explicit_encoder_decoder_prompt` uses, so that a dict without a
    # "decoder_prompt" entry is still split instead of being passed through
    # whole as if it were a singleton prompt.
    if "encoder_prompt" in prompt:
        # NOTE: This passes pyright but not mypy
        return (
            prompt["encoder_prompt"],  # type: ignore[typeddict-item]
            prompt.get("decoder_prompt"),  # type: ignore[typeddict-item]
        )

    return prompt, None
def split_enc_dec_inputs( def split_enc_dec_inputs(
inputs: ProcessorInputs, inputs: ProcessorInputs,
) -> tuple[SingletonInputs | None, SingletonInputs]: ) -> tuple[SingletonInputs | None, SingletonInputs]:
......
...@@ -2,11 +2,12 @@ ...@@ -2,11 +2,12 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any, cast from typing import Any
from typing_extensions import assert_never from typing_extensions import assert_never
from vllm.config import ModelConfig, ObservabilityConfig from vllm.config import ModelConfig, ObservabilityConfig
from vllm.inputs.parse import split_enc_dec_prompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import BaseMultiModalProcessorCache from vllm.multimodal.cache import BaseMultiModalProcessorCache
...@@ -27,7 +28,6 @@ from .data import ( ...@@ -27,7 +28,6 @@ from .data import (
EmbedsInputs, EmbedsInputs,
EmbedsPrompt, EmbedsPrompt,
EncoderDecoderInputs, EncoderDecoderInputs,
ExplicitEncoderDecoderPrompt,
ProcessorInputs, ProcessorInputs,
PromptType, PromptType,
SingletonInputs, SingletonInputs,
...@@ -86,30 +86,15 @@ class InputPreprocessor: ...@@ -86,30 +86,15 @@ class InputPreprocessor:
return self.tokenizer.eos_token_id return self.tokenizer.eos_token_id
def get_decoder_start_token_id(self) -> int | None: def get_decoder_start_token_id(self) -> int:
""" """
Obtain the decoder start token id employed by an encoder/decoder Obtain the decoder start token id employed by an encoder/decoder
model. Returns None for non-encoder/decoder models or if the model. Raises an error if it is not available.
model config is unavailable.
""" """
if not self.model_config.is_encoder_decoder:
logger.warning_once(
"Using None for decoder start token id because "
"this is not an encoder/decoder model."
)
return None
if self.model_config is None or self.model_config.hf_config is None:
logger.warning_once(
"Using None for decoder start token id because "
"model config is not available."
)
return None
dec_start_token_id = getattr( dec_start_token_id = getattr(
self.model_config.hf_config, "decoder_start_token_id", None self.model_config.hf_config, "decoder_start_token_id", None
) )
if dec_start_token_id is None: if dec_start_token_id is None:
logger.warning_once( logger.warning_once(
"Falling back on <BOS> for decoder start token " "Falling back on <BOS> for decoder start token "
...@@ -118,48 +103,12 @@ class InputPreprocessor: ...@@ -118,48 +103,12 @@ class InputPreprocessor:
) )
dec_start_token_id = self.get_bos_token_id() dec_start_token_id = self.get_bos_token_id()
return dec_start_token_id if dec_start_token_id is None:
raise RuntimeError("Cannot find decoder start token id or <BOS>")
def _get_default_enc_dec_decoder_prompt(self) -> list[int]:
"""
Specifically for encoder/decoder models:
generate a default decoder prompt for when
the user specifies only the encoder prompt.
Encoder/decoder models utilize the decoder
prompt in different ways; as new models are
added, it is intended that this function
will be extended to produce differing
default decoder prompts, depending on the
model variety.
Absent a special case, the default behavior
of this method is to mirror the behavior of
the HuggingFace (HF) GenerationMixin for a None
decoder prompt, which is to employ a logit processor
setting to force the first decoded token to be <BOS>.
Here, this behavior is approximated by having the
"default" decoder prompt be <BOS>.
However, it is possible that in the future
other models may have different or more
complex logic for the default decoder prompt.
This motivates having a special helper method
for default decoder prompts.
Returns:
* prompt_token_ids
"""
bos_token_id = self.get_bos_token_id() return dec_start_token_id
assert bos_token_id is not None
return [bos_token_id]
def _prepare_decoder_input_ids_for_generation( def _prepare_decoder_input_ids(self, decoder_input_ids: list[int]) -> list[int]:
self,
decoder_input_ids: list[int] | None,
) -> list[int]:
""" """
Prepares `decoder_input_ids` for generation with encoder-decoder models. Prepares `decoder_input_ids` for generation with encoder-decoder models.
...@@ -176,14 +125,7 @@ class InputPreprocessor: ...@@ -176,14 +125,7 @@ class InputPreprocessor:
* Processed token list * Processed token list
""" """
decoder_start_token_id = self.get_decoder_start_token_id() decoder_start_token_id = self.get_decoder_start_token_id()
assert decoder_start_token_id is not None
if decoder_input_ids is None:
# no decoder prompt input ->
# use decoder_start_token_id as decoder_input_ids
decoder_input_ids = self._get_default_enc_dec_decoder_prompt()
if ( if (
len(decoder_input_ids) == 0 len(decoder_input_ids) == 0
...@@ -428,111 +370,70 @@ class InputPreprocessor: ...@@ -428,111 +370,70 @@ class InputPreprocessor:
assert_never(parsed) assert_never(parsed)
def _build_enc_dec_llm_inputs( def _validate_enc_inputs(
self, self,
encoder_inputs: SingletonInputs, inputs: SingletonInputs,
decoder_inputs: SingletonInputs | None, ) -> TokenInputs | MultiModalEncDecInputs:
) -> EncoderDecoderInputs: if inputs["type"] == "embeds":
if (
encoder_inputs["type"] == "embeds"
or decoder_inputs
and decoder_inputs["type"] == "embeds"
):
raise ValueError( raise ValueError(
"Embedding inputs are not supported for encoder-decoder models" "Embedding inputs are not supported for encoder-decoder models"
) )
# Needed for mypy if inputs["type"] == "multimodal" and "encoder_prompt_token_ids" not in inputs:
encoder_inputs = cast(TokenInputs | MultiModalInputs, encoder_inputs) raise RuntimeError(
decoder_inputs = cast(TokenInputs | MultiModalInputs | None, decoder_inputs) "You should register an encoder-decoder "
"multi-modal processor for encoder-decoder models."
if decoder_inputs is None:
if self.model_config.hf_config.model_type == "whisper":
# For Whisper models, the text prompt should go to the decoder.
# If no explicit encoder/decoder inputs, then copy the prompt
# from the encoder to the decoder. The encoder tokens are later
# overridden by the audio features.
dec_token_ids = encoder_inputs["prompt_token_ids"].copy()
else:
dec_token_ids = self._prepare_decoder_input_ids_for_generation(None)
decoder_inputs = token_inputs(dec_token_ids)
else:
if "multi_modal_data" in decoder_inputs:
raise ValueError(
"Multi-modal decoder inputs of encoder-"
"decoder models are not supported yet"
)
dec_token_ids = self._prepare_decoder_input_ids_for_generation(
decoder_inputs["prompt_token_ids"]
) )
decoder_inputs["prompt_token_ids"] = dec_token_ids
return EncoderDecoderInputs( return inputs # type: ignore[return-value]
encoder=encoder_inputs,
decoder=decoder_inputs,
)
def _split_enc_dec_mm_inputs( def _validate_dec_inputs(
self, self,
inputs: SingletonInputs | MultiModalEncDecInputs, inputs: SingletonInputs,
decoder_inputs_to_override: SingletonInputs | None = None, ) -> TokenInputs | MultiModalInputs:
) -> tuple[SingletonInputs, SingletonInputs]: if inputs["type"] == "embeds":
"""
For encoder/decoder models only:
Separate Encoder/Decoder inputs from a MultiModalEncDecInputs
"""
if (
inputs["type"] == "embeds"
or decoder_inputs_to_override
and decoder_inputs_to_override["type"] == "embeds"
):
raise ValueError( raise ValueError(
"Embedding inputs are not supported for encoder-decoder models" "Embedding inputs are not supported for encoder-decoder models"
) )
# Needed for mypy return inputs
inputs = cast(
TokenInputs | MultiModalInputs | MultiModalEncDecInputs,
inputs,
)
decoder_inputs_to_override = cast(
TokenInputs | MultiModalInputs | None,
decoder_inputs_to_override,
)
encoder_inputs: SingletonInputs def _build_enc_dec_inputs(
decoder_inputs: SingletonInputs self,
encoder_inputs: SingletonInputs,
decoder_inputs: SingletonInputs | None = None,
) -> EncoderDecoderInputs:
if decoder_inputs is None:
decoder_inputs = encoder_inputs
if inputs["type"] == "multimodal": # Multimodal data inputs enc_inputs = self._validate_enc_inputs(encoder_inputs)
if "encoder_prompt_token_ids" not in inputs: dec_inputs = self._validate_dec_inputs(decoder_inputs)
raise RuntimeError(
"You should register an encoder-decoder "
"multi-modal processor for encoder-decoder "
"models."
)
inputs = cast(MultiModalEncDecInputs, inputs)
encoder_inputs = token_inputs(inputs["encoder_prompt_token_ids"]) enc_inputs_new: TokenInputs | MultiModalEncDecInputs
dec_inputs_new: TokenInputs | MultiModalInputs
decoder_prompt_inputs = decoder_inputs_to_override or inputs if enc_inputs["type"] == "multimodal":
decoder_inputs = MultiModalInputs( enc_inputs_new = token_inputs(enc_inputs["encoder_prompt_token_ids"])
dec_inputs_new = MultiModalInputs(
type="multimodal", type="multimodal",
prompt_token_ids=decoder_prompt_inputs["prompt_token_ids"], prompt_token_ids=dec_inputs["prompt_token_ids"],
mm_kwargs=inputs["mm_kwargs"], mm_kwargs=enc_inputs["mm_kwargs"],
mm_hashes=inputs["mm_hashes"], mm_hashes=enc_inputs["mm_hashes"],
mm_placeholders=inputs["mm_placeholders"], mm_placeholders=enc_inputs["mm_placeholders"],
) )
if cache_salt := inputs.get("cache_salt"): elif enc_inputs["type"] == "token":
decoder_inputs["cache_salt"] = cache_salt enc_inputs_new = token_inputs(prompt_token_ids=[])
dec_inputs_new = dec_inputs
elif inputs["type"] == "token": # Text-only inputs
encoder_inputs = token_inputs(prompt_token_ids=[])
decoder_inputs = decoder_inputs_to_override or inputs
else: else:
assert_never(inputs) # type: ignore[arg-type] assert_never(enc_inputs)
dec_inputs_new["prompt_token_ids"] = self._prepare_decoder_input_ids(
dec_inputs_new["prompt_token_ids"]
)
if cache_salt := enc_inputs.get("cache_salt"):
dec_inputs_new["cache_salt"] = cache_salt
return encoder_inputs, decoder_inputs return EncoderDecoderInputs(encoder=enc_inputs_new, decoder=dec_inputs_new)
def _process_encoder_decoder_prompt( def _process_encoder_decoder_prompt(
self, self,
...@@ -574,54 +475,23 @@ class InputPreprocessor: ...@@ -574,54 +475,23 @@ class InputPreprocessor:
* [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs] * [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
instance instance
""" """
encoder_inputs: SingletonInputs encoder_prompt, decoder_prompt = split_enc_dec_prompt(prompt)
decoder_inputs: SingletonInputs | None
if is_explicit_encoder_decoder_prompt(prompt): return self._build_enc_dec_inputs(
# `cast` is needed for mypy, but not pyright encoder_inputs=self._prompt_to_llm_inputs(
prompt_ = cast(ExplicitEncoderDecoderPrompt, prompt) encoder_prompt,
encoder_inputs = self._prompt_to_llm_inputs(
prompt_["encoder_prompt"],
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids, mm_uuids=mm_uuids,
) ),
if (decoder_input := prompt_["decoder_prompt"]) is None: decoder_inputs=(
decoder_inputs = None None
else: if decoder_prompt is None
decoder_inputs = self._prompt_to_llm_inputs( else self._prompt_to_llm_inputs(
decoder_input, tokenization_kwargs=tokenization_kwargs decoder_prompt,
tokenization_kwargs=tokenization_kwargs,
) )
# For multimodal model, override decoder prompt from processor ),
# with explicit decoder prompt. )
if self.model_config.is_multimodal_model:
encoder_inputs, decoder_inputs = self._split_enc_dec_mm_inputs(
encoder_inputs, decoder_inputs
)
else:
# `cast` is needed for mypy, but not pyright
inputs = self._prompt_to_llm_inputs(
cast(SingletonPrompt, prompt),
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
)
if self.model_config.is_multimodal_model:
# Encoder-Decoder Multimodal model
encoder_inputs, decoder_inputs = self._split_enc_dec_mm_inputs(inputs)
else:
encoder_inputs = inputs
decoder_inputs = None
return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)
def _build_decoder_only_llm_inputs(
    self,
    prompt_inputs: DecoderOnlyInputs,
) -> DecoderOnlyInputs:
    """Return the decoder-only inputs as they were given.

    The inputs are passed through untouched; the only thing done here is a
    static type narrowing (via `cast`) when the inputs carry token ids.
    """
    carries_token_ids = "prompt_token_ids" in prompt_inputs
    if not carries_token_ids:
        return prompt_inputs

    # Needed for mypy
    narrowed_inputs = cast(TokenInputs | MultiModalInputs, prompt_inputs)
    return narrowed_inputs
def _process_decoder_only_prompt( def _process_decoder_only_prompt(
self, self,
...@@ -643,15 +513,12 @@ class InputPreprocessor: ...@@ -643,15 +513,12 @@ class InputPreprocessor:
* [`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] instance * [`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] instance
""" """
return self._prompt_to_llm_inputs(
prompt_comps = self._prompt_to_llm_inputs(
prompt, prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids, mm_uuids=mm_uuids,
) )
return self._build_decoder_only_llm_inputs(prompt_comps)
def _preprocess( def _preprocess(
self, self,
prompt: PromptType, prompt: PromptType,
...@@ -673,10 +540,8 @@ class InputPreprocessor: ...@@ -673,10 +540,8 @@ class InputPreprocessor:
"Cannot pass encoder-decoder prompt to decoder-only models" "Cannot pass encoder-decoder prompt to decoder-only models"
) )
# Decoder-only operation
# `cast` is needed for mypy, but not pyright
return self._process_decoder_only_prompt( return self._process_decoder_only_prompt(
cast(SingletonPrompt, prompt), prompt,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids, mm_uuids=mm_uuids,
) )
......
...@@ -1083,6 +1083,10 @@ class MultiModalEncDecInputs(MultiModalInputs): ...@@ -1083,6 +1083,10 @@ class MultiModalEncDecInputs(MultiModalInputs):
Represents the outputs of Represents the outputs of
[`EncDecMultiModalProcessor`][vllm.multimodal.processing.EncDecMultiModalProcessor] [`EncDecMultiModalProcessor`][vllm.multimodal.processing.EncDecMultiModalProcessor]
ready to be passed to vLLM internals. ready to be passed to vLLM internals.
Note: Even text-only encoder-decoder models are currently implemented
as multi-modal models for convenience.
(Example: https://github.com/neuralmagic/bart-plugin)
""" """
encoder_prompt_token_ids: list[int] encoder_prompt_token_ids: list[int]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment