"examples/vscode:/vscode.git/clone" did not exist on "51d7c6a2b23e100cd9e7d85b8e7c0eea656b331e"
Unverified Commit 24718153 authored by Cyrus Leung, committed by GitHub
Browse files

[Misc] Replace `is_encoder_decoder_inputs` with `split_enc_dec_inputs` (#15620)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 07bf813f
...@@ -29,7 +29,7 @@ def test_processor_override( ...@@ -29,7 +29,7 @@ def test_processor_override(
num_imgs: int, num_imgs: int,
kwargs_on_init: bool, kwargs_on_init: bool,
): ):
"""Ensure input_processor_for_idefics3 handles num_crops properly.""" """Ensure Idefics3MultiModalProcessor handles num_crops properly."""
# Same as the previous test - don't initialize mm_processor_kwargs # Same as the previous test - don't initialize mm_processor_kwargs
# in this test and assume that the kwargs will be correctly expanded by # in this test and assume that the kwargs will be correctly expanded by
# the partial when calling the custom input processor. # the partial when calling the custom input processor.
......
...@@ -30,7 +30,7 @@ def test_processor_override( ...@@ -30,7 +30,7 @@ def test_processor_override(
num_imgs: int, num_imgs: int,
kwargs_on_init: bool, kwargs_on_init: bool,
): ):
"""Ensure input_processor_for_phi3v handles num_crops properly.""" """Ensure Phi3VMultiModalProcessor handles num_crops properly."""
# Avoid initializing CUDA early # Avoid initializing CUDA early
from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID
......
...@@ -665,7 +665,7 @@ class EngineArgs: ...@@ -665,7 +665,7 @@ class EngineArgs:
type=nullable_kvs, type=nullable_kvs,
default=EngineArgs.limit_mm_per_prompt, default=EngineArgs.limit_mm_per_prompt,
# The default value is given in # The default value is given in
# MultiModalRegistry.init_mm_limits_per_prompt # MultiModalConfig.get_limit_per_prompt
help=('For each multimodal plugin, limit how many ' help=('For each multimodal plugin, limit how many '
'input instances to allow for each prompt. ' 'input instances to allow for each prompt. '
'Expects a comma-separated list of items, ' 'Expects a comma-separated list of items, '
......
...@@ -30,8 +30,8 @@ from vllm.entrypoints.openai.logits_processors import ( ...@@ -30,8 +30,8 @@ from vllm.entrypoints.openai.logits_processors import (
get_logits_processors as get_openai_logits_processors) get_logits_processors as get_openai_logits_processors)
from vllm.executor.executor_base import ExecutorBase from vllm.executor.executor_base import ExecutorBase
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs, from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
PromptType, SingletonInputsAdapter) PromptType)
from vllm.inputs.parse import is_encoder_decoder_inputs, is_token_prompt from vllm.inputs.parse import is_token_prompt, split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logits_process import get_bad_words_logits_processors from vllm.logits_process import get_bad_words_logits_processors
...@@ -609,12 +609,7 @@ class LLMEngine: ...@@ -609,12 +609,7 @@ class LLMEngine:
seq_id = next(self.seq_counter) seq_id = next(self.seq_counter)
eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request) eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
if is_encoder_decoder_inputs(processed_inputs): encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
decoder_inputs = processed_inputs["decoder"]
encoder_inputs = processed_inputs["encoder"]
else:
decoder_inputs = processed_inputs
encoder_inputs = None
seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id, seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id,
lora_request, prompt_adapter_request) lora_request, prompt_adapter_request)
...@@ -2031,15 +2026,16 @@ class LLMEngine: ...@@ -2031,15 +2026,16 @@ class LLMEngine:
def _validate_model_inputs(self, inputs: ProcessorInputs, def _validate_model_inputs(self, inputs: ProcessorInputs,
lora_request: Optional[LoRARequest]): lora_request: Optional[LoRARequest]):
if is_encoder_decoder_inputs(inputs): encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)
# For encoder-decoder multimodal models, the max_prompt_len # For encoder-decoder multimodal models, the max_prompt_len
# restricts the decoder prompt length # restricts the decoder prompt length
prompt_inputs = inputs["decoder" if self.model_config. if self.model_config.is_multimodal_model:
is_multimodal_model else "encoder"] prompt_inputs = decoder_inputs
else: else:
prompt_inputs = inputs prompt_inputs = encoder_inputs or decoder_inputs
prompt_ids = SingletonInputsAdapter(prompt_inputs).prompt_token_ids prompt_ids = prompt_inputs["prompt_token_ids"]
if prompt_ids is None or len(prompt_ids) == 0: if prompt_ids is None or len(prompt_ids) == 0:
raise ValueError("Prompt cannot be empty") raise ValueError("Prompt cannot be empty")
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from collections.abc import Sequence from collections.abc import Sequence
from typing import Literal, TypedDict, Union, cast, overload from typing import Literal, Optional, TypedDict, Union, cast, overload
from typing_extensions import TypeIs from typing_extensions import TypeIs
from vllm.utils import is_list_of from vllm.utils import is_list_of
from .data import (EncoderDecoderInputs, ExplicitEncoderDecoderPrompt, from .data import (ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType,
ProcessorInputs, PromptType, SingletonPrompt, TextPrompt, SingletonInputs, SingletonPrompt, TextPrompt, TokensPrompt)
TokensPrompt)
class ParsedText(TypedDict): class ParsedText(TypedDict):
...@@ -110,6 +108,14 @@ def is_explicit_encoder_decoder_prompt( ...@@ -110,6 +108,14 @@ def is_explicit_encoder_decoder_prompt(
return isinstance(prompt, dict) and "encoder_prompt" in prompt return isinstance(prompt, dict) and "encoder_prompt" in prompt
def is_encoder_decoder_inputs( def split_enc_dec_inputs(
inputs: ProcessorInputs) -> TypeIs[EncoderDecoderInputs]: inputs: ProcessorInputs,
return "encoder" in inputs and "decoder" in inputs ) -> tuple[Optional[SingletonInputs], SingletonInputs]:
if "encoder" in inputs and "decoder" in inputs:
# NOTE: This passes pyright but not mypy
return (
inputs["encoder"], # type: ignore[typeddict-item]
inputs["decoder"], # type: ignore[typeddict-item]
)
return None, inputs
...@@ -19,7 +19,7 @@ from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides, ...@@ -19,7 +19,7 @@ from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides,
resolve_mm_processor_kwargs) resolve_mm_processor_kwargs)
from .data import ProcessorInputs, SingletonInputs from .data import ProcessorInputs, SingletonInputs
from .parse import is_encoder_decoder_inputs from .parse import split_enc_dec_inputs
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import ModelConfig from vllm.config import ModelConfig
...@@ -462,13 +462,11 @@ class InputRegistry: ...@@ -462,13 +462,11 @@ class InputRegistry:
**mm_processor_kwargs, **mm_processor_kwargs,
) )
if is_encoder_decoder_inputs(processed_inputs): encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
self._ensure_mm_kwargs(processed_inputs["encoder"], if encoder_inputs is not None:
mm_processor_kwargs) self._ensure_mm_kwargs(encoder_inputs, mm_processor_kwargs)
self._ensure_mm_kwargs(processed_inputs["decoder"], if decoder_inputs is not None:
mm_processor_kwargs) self._ensure_mm_kwargs(decoder_inputs, mm_processor_kwargs)
else:
self._ensure_mm_kwargs(processed_inputs, mm_processor_kwargs)
return processed_inputs return processed_inputs
......
...@@ -232,7 +232,7 @@ class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo] ...@@ -232,7 +232,7 @@ class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo]
) )
class Idefics3MultimodalProcessor( class Idefics3MultiModalProcessor(
BaseMultiModalProcessor[Idefics3ProcessingInfo]): BaseMultiModalProcessor[Idefics3ProcessingInfo]):
def _call_hf_processor( def _call_hf_processor(
...@@ -575,7 +575,7 @@ class Idefics3Model(nn.Module): ...@@ -575,7 +575,7 @@ class Idefics3Model(nn.Module):
@MULTIMODAL_REGISTRY.register_processor( @MULTIMODAL_REGISTRY.register_processor(
Idefics3MultimodalProcessor, Idefics3MultiModalProcessor,
info=Idefics3ProcessingInfo, info=Idefics3ProcessingInfo,
dummy_inputs=Idefics3DummyInputsBuilder) dummy_inputs=Idefics3DummyInputsBuilder)
class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
......
...@@ -7,7 +7,7 @@ from typing import Optional, Union ...@@ -7,7 +7,7 @@ from typing import Optional, Union
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs, from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
PromptType, SingletonInputsAdapter) PromptType, SingletonInputsAdapter)
from vllm.inputs.parse import is_encoder_decoder_inputs from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor from vllm.inputs.preprocess import InputPreprocessor
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs, from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
...@@ -209,14 +209,8 @@ class Processor: ...@@ -209,14 +209,8 @@ class Processor:
self._validate_model_inputs(processed_inputs, lora_request) self._validate_model_inputs(processed_inputs, lora_request)
if is_encoder_decoder_inputs(processed_inputs): encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
decoder_inputs = SingletonInputsAdapter( decoder_inputs = SingletonInputsAdapter(decoder_inputs)
processed_inputs["decoder"])
encoder_inputs = SingletonInputsAdapter(
processed_inputs["encoder"])
else:
decoder_inputs = SingletonInputsAdapter(processed_inputs)
encoder_inputs = None
# TODO: Impl encoder-decoder # TODO: Impl encoder-decoder
if encoder_inputs is not None: if encoder_inputs is not None:
...@@ -301,15 +295,16 @@ class Processor: ...@@ -301,15 +295,16 @@ class Processor:
def _validate_model_inputs(self, def _validate_model_inputs(self,
inputs: ProcessorInputs, inputs: ProcessorInputs,
lora_request: Optional[LoRARequest] = None): lora_request: Optional[LoRARequest] = None):
if is_encoder_decoder_inputs(inputs): encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)
# For encoder-decoder multimodal models, the max_prompt_len # For encoder-decoder multimodal models, the max_prompt_len
# restricts the decoder prompt length # restricts the decoder prompt length
prompt_inputs = inputs["decoder" if self.model_config. if self.model_config.is_multimodal_model:
is_multimodal_model else "encoder"] prompt_inputs = decoder_inputs
else: else:
prompt_inputs = inputs prompt_inputs = encoder_inputs or decoder_inputs
prompt_ids = SingletonInputsAdapter(prompt_inputs).prompt_token_ids prompt_ids = prompt_inputs["prompt_token_ids"]
if prompt_ids is None or len(prompt_ids) == 0: if prompt_ids is None or len(prompt_ids) == 0:
raise ValueError("Prompt cannot be empty") raise ValueError("Prompt cannot be empty")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment