Unverified commit 0b8bb86b, authored by Cyrus Leung and committed by GitHub
Browse files

[1/N] Initial prototype for multi-modal processor (#10044)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent bb7991aa
...@@ -66,7 +66,7 @@ A default mapper is available for each modality in the core vLLM library. This i ...@@ -66,7 +66,7 @@ A default mapper is available for each modality in the core vLLM library. This i
3. Register maximum number of multi-modal tokens 3. Register maximum number of multi-modal tokens
------------------------------------------------ ------------------------------------------------
For each modality type that the model accepts as input, calculate the maximum possible number of tokens per data instance For each modality type that the model accepts as input, calculate the maximum possible number of tokens per data item
and register it via :meth:`MULTIMODAL_REGISTRY.register_max_multimodal_tokens <vllm.multimodal.MultiModalRegistry.register_max_multimodal_tokens>`. and register it via :meth:`MULTIMODAL_REGISTRY.register_max_multimodal_tokens <vllm.multimodal.MultiModalRegistry.register_max_multimodal_tokens>`.
.. code-block:: diff .. code-block:: diff
......
...@@ -6,7 +6,7 @@ import torch ...@@ -6,7 +6,7 @@ import torch
from PIL.Image import Image from PIL.Image import Image
from vllm.inputs import InputContext, token_inputs from vllm.inputs import InputContext, token_inputs
from vllm.multimodal.base import MultiModalKwargs from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.utils import cached_get_tokenizer from vllm.multimodal.utils import cached_get_tokenizer
from .....conftest import IMAGE_ASSETS from .....conftest import IMAGE_ASSETS
......
import torch import torch
from vllm.multimodal.base import MultiModalKwargs, NestedTensors from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
def assert_nested_tensors_equal(expected: NestedTensors, def assert_nested_tensors_equal(expected: NestedTensors,
......
from array import array from array import array
from typing import Mapping from typing import Callable, Dict, Mapping, Optional
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
import torch import torch
from vllm.inputs import (DecoderOnlyInputs, DummyData, InputContext, from vllm.inputs import (DecoderOnlyInputs, DummyData, InputContext,
InputRegistry, token_inputs) InputRegistry, ProcessorInputs, token_inputs)
from vllm.multimodal import MultiModalRegistry from vllm.multimodal import MultiModalRegistry
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
...@@ -34,10 +34,9 @@ def use_processor_mock(): ...@@ -34,10 +34,9 @@ def use_processor_mock():
inputs: DecoderOnlyInputs, inputs: DecoderOnlyInputs,
*, *,
num_crops=DEFAULT_NUM_CROPS): num_crops=DEFAULT_NUM_CROPS):
# For testing purposes, we don't worry about the llm inputs / return # For testing purposes, we don't worry about the prompt
# type validation, and just return the value of the kwarg that we return token_inputs(prompt_token_ids=[],
# clobber. mm_processor_kwargs={"num_crops": num_crops})
return num_crops
with patch("vllm.inputs.registry.InputRegistry._get_model_input_processor", with patch("vllm.inputs.registry.InputRegistry._get_model_input_processor",
return_value=custom_processor): return_value=custom_processor):
...@@ -109,6 +108,21 @@ def _get_num_crops_info(init_num_crops: int, inference_num_crops: int): ...@@ -109,6 +108,21 @@ def _get_num_crops_info(init_num_crops: int, inference_num_crops: int):
return init_kwargs, inference_kwargs, expected_seq_count return init_kwargs, inference_kwargs, expected_seq_count
def _get_processed_num_crops(
    processor: Callable[[ProcessorInputs], ProcessorInputs],
    inference_kwargs: Optional[Dict[str, int]],
) -> int:
    """Run ``processor`` on an empty token prompt and return ``num_crops``.

    Args:
        processor: The input processor under test.
        inference_kwargs: Request-level ``mm_processor_kwargs`` to attach to
            the dummy prompt (may be ``None``).

    Returns:
        The ``num_crops`` entry of the ``mm_processor_kwargs`` found in the
        processed inputs.
    """
    dummy_inputs = token_inputs(
        prompt_token_ids=[],
        prompt="",
        mm_processor_kwargs=inference_kwargs,
    )
    result = processor(dummy_inputs)

    # The processed output must be token inputs that carry the kwargs.
    assert "type" in result
    assert result["type"] == "token"
    assert "mm_processor_kwargs" in result

    used_kwargs = result["mm_processor_kwargs"]
    return used_kwargs["num_crops"]
@pytest.mark.parametrize("init_num_crops,inference_num_crops", [ @pytest.mark.parametrize("init_num_crops,inference_num_crops", [
(None, None), (None, None),
(NUM_CROPS_OVERRIDE, None), (NUM_CROPS_OVERRIDE, None),
...@@ -124,10 +138,8 @@ def test_input_processor_kwargs(use_processor_mock, init_num_crops, ...@@ -124,10 +138,8 @@ def test_input_processor_kwargs(use_processor_mock, init_num_crops,
ctx = build_model_context(DUMMY_MODEL_ID, mm_processor_kwargs=init_kwargs) ctx = build_model_context(DUMMY_MODEL_ID, mm_processor_kwargs=init_kwargs)
processor = dummy_registry.create_input_processor(ctx.model_config) processor = dummy_registry.create_input_processor(ctx.model_config)
num_crops_val = processor( num_crops_val = _get_processed_num_crops(processor, inference_kwargs)
token_inputs(prompt_token_ids=[],
prompt="",
mm_processor_kwargs=inference_kwargs))
assert num_crops_val == expected_seq_count assert num_crops_val == expected_seq_count
...@@ -153,10 +165,7 @@ def test_processor_with_sad_kwarg_overrides(use_processor_mock, ...@@ -153,10 +165,7 @@ def test_processor_with_sad_kwarg_overrides(use_processor_mock,
processor = dummy_registry.create_input_processor(ctx.model_config) processor = dummy_registry.create_input_processor(ctx.model_config)
# Should filter out the inference time kwargs # Should filter out the inference time kwargs
num_crops_val = processor( num_crops_val = _get_processed_num_crops(processor, mm_processor_kwargs)
token_inputs(prompt_token_ids=[],
prompt="",
mm_processor_kwargs=mm_processor_kwargs))
assert num_crops_val == DEFAULT_NUM_CROPS assert num_crops_val == DEFAULT_NUM_CROPS
......
"""Compare the with and without prefix caching.""" """Compare the with and without prefix caching."""
from vllm.inputs import DecoderOnlyInputs from vllm.inputs import token_inputs
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
from vllm.v1.core.kv_cache_utils import hash_block_tokens from vllm.v1.core.kv_cache_utils import hash_block_tokens
...@@ -8,7 +8,7 @@ from vllm.v1.core.kv_cache_utils import hash_block_tokens ...@@ -8,7 +8,7 @@ from vllm.v1.core.kv_cache_utils import hash_block_tokens
def make_request(request_id, prompt_token_ids): def make_request(request_id, prompt_token_ids):
return Request( return Request(
request_id=request_id, request_id=request_id,
inputs=DecoderOnlyInputs(prompt_token_ids=prompt_token_ids), inputs=token_inputs(prompt_token_ids=prompt_token_ids),
sampling_params=SamplingParams(max_tokens=17), sampling_params=SamplingParams(max_tokens=17),
eos_token_id=100, eos_token_id=100,
arrival_time=0, arrival_time=0,
......
...@@ -107,7 +107,7 @@ class ModelConfig: ...@@ -107,7 +107,7 @@ class ModelConfig:
matches the model name exposed via the APIs. If multiple model matches the model name exposed via the APIs. If multiple model
names provided, the first name will be used. If not specified, names provided, the first name will be used. If not specified,
the model name will be the same as `model`. the model name will be the same as `model`.
limit_mm_per_prompt: Maximum number of data instances per modality limit_mm_per_prompt: Maximum number of data items per modality
per prompt. Only applicable for multimodal models. per prompt. Only applicable for multimodal models.
override_neuron_config: Initialize non default neuron config or override_neuron_config: Initialize non default neuron config or
override default neuron config that are specific to Neuron devices, override default neuron config that are specific to Neuron devices,
......
...@@ -19,6 +19,7 @@ from vllm.executor.executor_base import ExecutorAsyncBase ...@@ -19,6 +19,7 @@ from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutorAsync from vllm.executor.gpu_executor import GPUExecutorAsync
from vllm.executor.ray_utils import initialize_ray_cluster from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import PromptType from vllm.inputs import PromptType
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import ( from vllm.model_executor.guided_decoding import (
...@@ -729,6 +730,9 @@ class AsyncLLMEngine(EngineClient): ...@@ -729,6 +730,9 @@ class AsyncLLMEngine(EngineClient):
self.set_errored(exc) self.set_errored(exc)
self._request_tracker.propagate_exception(exc) self._request_tracker.propagate_exception(exc)
async def get_input_preprocessor(self) -> InputPreprocessor:
    """Return the input preprocessor of the wrapped engine."""
    engine = self.engine
    return engine.input_preprocessor
async def get_tokenizer( async def get_tokenizer(
self, self,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
......
...@@ -30,7 +30,7 @@ from vllm.executor.executor_base import ExecutorBase ...@@ -30,7 +30,7 @@ from vllm.executor.executor_base import ExecutorBase
from vllm.executor.gpu_executor import GPUExecutor from vllm.executor.gpu_executor import GPUExecutor
from vllm.executor.ray_utils import initialize_ray_cluster from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs, from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
PromptType) PromptType, SingletonInputsAdapter)
from vllm.inputs.parse import is_encoder_decoder_inputs, is_token_prompt from vllm.inputs.parse import is_encoder_decoder_inputs, is_token_prompt
from vllm.inputs.preprocess import InputPreprocessor from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -39,6 +39,7 @@ from vllm.lora.request import LoRARequest ...@@ -39,6 +39,7 @@ from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import ( from vllm.model_executor.guided_decoding import (
get_local_guided_decoding_logits_processor) get_local_guided_decoding_logits_processor)
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
RequestOutputFactory) RequestOutputFactory)
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
...@@ -226,6 +227,7 @@ class LLMEngine: ...@@ -226,6 +227,7 @@ class LLMEngine:
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
input_registry: InputRegistry = INPUT_REGISTRY, input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
use_cached_outputs: bool = False, use_cached_outputs: bool = False,
) -> None: ) -> None:
...@@ -335,7 +337,8 @@ class LLMEngine: ...@@ -335,7 +337,8 @@ class LLMEngine:
model_config) model_config)
self.input_preprocessor = InputPreprocessor(model_config, self.input_preprocessor = InputPreprocessor(model_config,
self.tokenizer) self.tokenizer,
mm_registry)
self.input_registry = input_registry self.input_registry = input_registry
self.input_processor = input_registry.create_input_processor( self.input_processor = input_registry.create_input_processor(
...@@ -851,13 +854,6 @@ class LLMEngine: ...@@ -851,13 +854,6 @@ class LLMEngine:
) )
processed_inputs = self.input_processor(preprocessed_inputs) processed_inputs = self.input_processor(preprocessed_inputs)
# This is a bit of a hack - copy the mm_processor_kwargs that were
# used in the input processor to the processed output, since these
# kwargs are presumed to be immutable and the values should be aligned
# between the input processor (here) and the input mapper.
processed_inputs["mm_processor_kwargs"] = preprocessed_inputs.get(
"mm_processor_kwargs")
self._add_processed_request( self._add_processed_request(
request_id=request_id, request_id=request_id,
processed_inputs=processed_inputs, processed_inputs=processed_inputs,
...@@ -2019,7 +2015,7 @@ class LLMEngine: ...@@ -2019,7 +2015,7 @@ class LLMEngine:
else: else:
prompt_inputs = inputs prompt_inputs = inputs
prompt_ids = prompt_inputs.get("prompt_token_ids") prompt_ids = SingletonInputsAdapter(prompt_inputs).prompt_token_ids
if prompt_ids is None or len(prompt_ids) == 0: if prompt_ids is None or len(prompt_ids) == 0:
raise ValueError("Prompt cannot be empty") raise ValueError("Prompt cannot be empty")
......
...@@ -31,6 +31,7 @@ from vllm.engine.protocol import EngineClient ...@@ -31,6 +31,7 @@ from vllm.engine.protocol import EngineClient
# yapf: enable # yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT from vllm.envs import VLLM_RPC_TIMEOUT
from vllm.inputs import PromptType from vllm.inputs import PromptType
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
...@@ -94,6 +95,8 @@ class MQLLMEngineClient(EngineClient): ...@@ -94,6 +95,8 @@ class MQLLMEngineClient(EngineClient):
parallel_config=engine_config.parallel_config, parallel_config=engine_config.parallel_config,
enable_lora=bool(engine_config.lora_config), enable_lora=bool(engine_config.lora_config),
) )
self.input_preprocessor = InputPreprocessor(self.model_config,
self.tokenizer)
# Send RPCGenerateRequest to the MQLLMEngine. # Send RPCGenerateRequest to the MQLLMEngine.
self.input_socket: Socket = self.context.socket(zmq.constants.PUSH) self.input_socket: Socket = self.context.socket(zmq.constants.PUSH)
...@@ -345,6 +348,9 @@ class MQLLMEngineClient(EngineClient): ...@@ -345,6 +348,9 @@ class MQLLMEngineClient(EngineClient):
or response != VLLM_RPC_SUCCESS_STR): or response != VLLM_RPC_SUCCESS_STR):
raise ValueError(error_message) raise ValueError(error_message)
async def get_input_preprocessor(self) -> InputPreprocessor:
    """Return the input preprocessor held by this client."""
    preprocessor = self.input_preprocessor
    return preprocessor
async def get_tokenizer(self, lora_request: Optional[LoRARequest] = None): async def get_tokenizer(self, lora_request: Optional[LoRARequest] = None):
return await self.tokenizer.get_lora_tokenizer_async(lora_request) return await self.tokenizer.get_lora_tokenizer_async(lora_request)
......
...@@ -62,7 +62,6 @@ class EngineClient(ABC): ...@@ -62,7 +62,6 @@ class EngineClient(ABC):
async def beam_search( async def beam_search(
self, self,
prompt: PromptType, prompt: PromptType,
model_config: ModelConfig,
request_id: str, request_id: str,
params: BeamSearchParams, params: BeamSearchParams,
) -> AsyncGenerator[RequestOutput, None]: ) -> AsyncGenerator[RequestOutput, None]:
...@@ -74,13 +73,14 @@ class EngineClient(ABC): ...@@ -74,13 +73,14 @@ class EngineClient(ABC):
length_penalty = params.length_penalty length_penalty = params.length_penalty
include_stop_str_in_output = params.include_stop_str_in_output include_stop_str_in_output = params.include_stop_str_in_output
tokenizer = await self.get_tokenizer() preprocessor = await self.get_input_preprocessor()
input_preprocessor = InputPreprocessor(model_config, tokenizer) tokenizer_group = preprocessor.get_tokenizer_group()
tokenizer = await tokenizer_group.get_lora_tokenizer_async()
if is_explicit_encoder_decoder_prompt(prompt): if is_explicit_encoder_decoder_prompt(prompt):
raise NotImplementedError raise NotImplementedError
else: else:
processed_inputs = input_preprocessor._prompt_to_llm_inputs( processed_inputs = preprocessor._prompt_to_llm_inputs(
prompt, prompt,
request_id=request_id, request_id=request_id,
) )
...@@ -220,6 +220,7 @@ class EngineClient(ABC): ...@@ -220,6 +220,7 @@ class EngineClient(ABC):
Args: Args:
request_id: The unique id of the request. request_id: The unique id of the request.
""" """
...
@abstractmethod @abstractmethod
async def get_model_config(self) -> ModelConfig: async def get_model_config(self) -> ModelConfig:
...@@ -228,8 +229,13 @@ class EngineClient(ABC): ...@@ -228,8 +229,13 @@ class EngineClient(ABC):
@abstractmethod @abstractmethod
async def get_decoding_config(self) -> DecodingConfig: async def get_decoding_config(self) -> DecodingConfig:
...
"""Get the decoding configuration of the vLLM engine.""" """Get the decoding configuration of the vLLM engine."""
...
@abstractmethod
async def get_input_preprocessor(self) -> InputPreprocessor:
    """Get the input preprocessor of the vLLM engine."""
    ...
@abstractmethod @abstractmethod
async def get_tokenizer( async def get_tokenizer(
......
...@@ -190,7 +190,6 @@ class OpenAIServingChat(OpenAIServing): ...@@ -190,7 +190,6 @@ class OpenAIServingChat(OpenAIServing):
if isinstance(sampling_params, BeamSearchParams): if isinstance(sampling_params, BeamSearchParams):
generator = self.engine_client.beam_search( generator = self.engine_client.beam_search(
prompt=engine_prompt, prompt=engine_prompt,
model_config=self.model_config,
request_id=request_id, request_id=request_id,
params=sampling_params, params=sampling_params,
) )
......
...@@ -140,7 +140,6 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -140,7 +140,6 @@ class OpenAIServingCompletion(OpenAIServing):
if isinstance(sampling_params, BeamSearchParams): if isinstance(sampling_params, BeamSearchParams):
generator = self.engine_client.beam_search( generator = self.engine_client.beam_search(
prompt=engine_prompt, prompt=engine_prompt,
model_config=self.model_config,
request_id=request_id, request_id=request_id,
params=sampling_params, params=sampling_params,
) )
......
from .data import (DecoderOnlyInputs, EncoderDecoderInputs, from .data import (DecoderOnlyInputs, EncoderDecoderInputs,
ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType, ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType,
SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs, SingletonInputs, SingletonInputsAdapter, SingletonPrompt,
TokensPrompt, build_explicit_enc_dec_prompt, TextPrompt, TokenInputs, TokensPrompt,
to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts) build_explicit_enc_dec_prompt, to_enc_dec_tuple_list,
from .registry import DummyData, InputContext, InputRegistry token_inputs, zip_enc_dec_prompts)
from .registry import (DummyData, InputContext, InputProcessingContext,
InputRegistry)
INPUT_REGISTRY = InputRegistry() INPUT_REGISTRY = InputRegistry()
""" """
...@@ -26,12 +28,14 @@ __all__ = [ ...@@ -26,12 +28,14 @@ __all__ = [
"EncoderDecoderInputs", "EncoderDecoderInputs",
"ProcessorInputs", "ProcessorInputs",
"SingletonInputs", "SingletonInputs",
"SingletonInputsAdapter",
"build_explicit_enc_dec_prompt", "build_explicit_enc_dec_prompt",
"to_enc_dec_tuple_list", "to_enc_dec_tuple_list",
"zip_enc_dec_prompts", "zip_enc_dec_prompts",
"INPUT_REGISTRY", "INPUT_REGISTRY",
"DummyData", "DummyData",
"InputContext", "InputContext",
"InputProcessingContext",
"InputRegistry", "InputRegistry",
] ]
......
from dataclasses import dataclass
from functools import cached_property
from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List, Literal, from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List, Literal,
Optional, Tuple, Union, cast) Optional, Tuple, Union, cast)
from typing_extensions import NotRequired, TypedDict, TypeVar import torch
from typing_extensions import NotRequired, TypedDict, TypeVar, assert_never
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict
from vllm.multimodal.inputs import MultiModalInputsV2
class TextPrompt(TypedDict): class TextPrompt(TypedDict):
...@@ -36,13 +40,13 @@ class TokensPrompt(TypedDict): ...@@ -36,13 +40,13 @@ class TokensPrompt(TypedDict):
multi_modal_data: NotRequired["MultiModalDataDict"] multi_modal_data: NotRequired["MultiModalDataDict"]
""" """
Optional multi-modal data to pass to the model, DEPRECATED: Optional multi-modal data to pass to the model,
if the model supports it. if the model supports it.
""" """
mm_processor_kwargs: NotRequired[Dict[str, Any]] mm_processor_kwargs: NotRequired[Dict[str, Any]]
""" """
Optional multi-modal processor kwargs to be forwarded to the DEPRECATED: Optional multi-modal processor kwargs to be forwarded to the
multimodal input mapper & processor. Note that if multiple modalities multimodal input mapper & processor. Note that if multiple modalities
have registered mappers etc for the model being considered, we attempt have registered mappers etc for the model being considered, we attempt
to pass the mm_processor_kwargs to each of them. to pass the mm_processor_kwargs to each of them.
...@@ -176,7 +180,7 @@ def token_inputs( ...@@ -176,7 +180,7 @@ def token_inputs(
return inputs return inputs
DecoderOnlyInputs = TokenInputs DecoderOnlyInputs = Union[TokenInputs, "MultiModalInputsV2"]
""" """
The inputs in :class:`~vllm.LLMEngine` before they are The inputs in :class:`~vllm.LLMEngine` before they are
passed to the model executor. passed to the model executor.
...@@ -191,19 +195,91 @@ class EncoderDecoderInputs(TypedDict): ...@@ -191,19 +195,91 @@ class EncoderDecoderInputs(TypedDict):
This specifies the required data for encoder-decoder models. This specifies the required data for encoder-decoder models.
""" """
encoder: TokenInputs encoder: Union[TokenInputs, "MultiModalInputsV2"]
"""The inputs for the encoder portion.""" """The inputs for the encoder portion."""
decoder: TokenInputs decoder: Union[TokenInputs, "MultiModalInputsV2"]
"""The inputs for the decoder portion.""" """The inputs for the decoder portion."""
SingletonInputs = TokenInputs SingletonInputs = Union[TokenInputs, "MultiModalInputsV2"]
""" """
A processed :class:`SingletonPrompt` which can be passed to A processed :class:`SingletonPrompt` which can be passed to
:class:`vllm.sequence.Sequence`. :class:`vllm.sequence.Sequence`.
""" """
@dataclass
class SingletonInputsAdapter:
    """
    Uniform accessor layer over a :class:`SingletonInputs` dictionary.

    Every property dispatches on ``inputs["type"]`` (``"token"`` or
    ``"multimodal"``) and is cached on the instance after first access.
    """
    inputs: SingletonInputs

    @cached_property
    def prompt(self) -> Optional[str]:
        """The ``"prompt"`` entry, or ``None`` if it is absent."""
        data = self.inputs

        if data["type"] == "token":
            return data.get("prompt")
        if data["type"] == "multimodal":
            return data.get("prompt")

        assert_never(data)

    @cached_property
    def prompt_token_ids(self) -> List[int]:
        """The ``"prompt_token_ids"`` entry, or ``[]`` if it is absent."""
        data = self.inputs

        if data["type"] == "token":
            return data.get("prompt_token_ids", [])
        if data["type"] == "multimodal":
            return data.get("prompt_token_ids", [])

        assert_never(data)

    @cached_property
    def prompt_embeds(self) -> Optional[torch.Tensor]:
        """Always ``None`` for both currently handled input types."""
        data = self.inputs

        if data["type"] != "token" and data["type"] != "multimodal":
            assert_never(data)

        return None

    @cached_property
    def multi_modal_data(self) -> "MultiModalDataDict":
        """``multi_modal_data`` for token inputs, ``mm_kwargs`` for
        multi-modal inputs; ``{}`` when the key is absent."""
        data = self.inputs

        if data["type"] == "token":
            return data.get("multi_modal_data", {})
        elif data["type"] == "multimodal":
            return data.get("mm_kwargs", {})
        else:
            assert_never(data)

    @cached_property
    def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict":
        """``multi_modal_placeholders`` for token inputs, ``mm_placeholders``
        for multi-modal inputs; ``{}`` when the key is absent."""
        data = self.inputs

        if data["type"] == "token":
            return data.get("multi_modal_placeholders", {})
        elif data["type"] == "multimodal":
            return data.get("mm_placeholders", {})
        else:
            assert_never(data)

    @cached_property
    def mm_processor_kwargs(self) -> Dict[str, Any]:
        """``mm_processor_kwargs`` for token inputs (``{}`` when absent);
        always ``{}`` for multi-modal inputs."""
        data = self.inputs

        if data["type"] == "token":
            return data.get("mm_processor_kwargs", {})
        elif data["type"] == "multimodal":
            return {}
        else:
            assert_never(data)
ProcessorInputs = Union[DecoderOnlyInputs, EncoderDecoderInputs] ProcessorInputs = Union[DecoderOnlyInputs, EncoderDecoderInputs]
""" """
The inputs to :data:`vllm.inputs.InputProcessor`. The inputs to :data:`vllm.inputs.InputProcessor`.
...@@ -234,10 +310,11 @@ def zip_enc_dec_prompts( ...@@ -234,10 +310,11 @@ def zip_enc_dec_prompts(
) -> List[ExplicitEncoderDecoderPrompt[_T1, _T2]]: ) -> List[ExplicitEncoderDecoderPrompt[_T1, _T2]]:
""" """
Zip encoder and decoder prompts together into a list of Zip encoder and decoder prompts together into a list of
:class:`ExplicitEncoderDecoderPrompt` instances. mm_processor_kwargs :class:`ExplicitEncoderDecoderPrompt` instances.
may also be provided; if a dict is passed, the same dictionary will be
used for every encoder/decoder prompt. If an iterable is provided, it will ``mm_processor_kwargs`` may also be provided; if a dict is passed, the same
be zipped with the encoder/decoder prompts. dictionary will be used for every encoder/decoder prompt. If an iterable is
provided, it will be zipped with the encoder/decoder prompts.
""" """
if mm_processor_kwargs is None: if mm_processor_kwargs is None:
mm_processor_kwargs = cast(Dict[str, Any], {}) mm_processor_kwargs = cast(Dict[str, Any], {})
......
import asyncio import asyncio
from typing import List, Optional from typing import List, Mapping, Optional, Union
from typing_extensions import assert_never from typing_extensions import assert_never
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.processing import MultiModalDataDict, MultiModalInputsV2
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from vllm.utils import print_warning_once from vllm.utils import print_warning_once
...@@ -23,11 +25,13 @@ class InputPreprocessor: ...@@ -23,11 +25,13 @@ class InputPreprocessor:
self, self,
model_config: ModelConfig, model_config: ModelConfig,
tokenizer: Optional[BaseTokenizerGroup], tokenizer: Optional[BaseTokenizerGroup],
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
) -> None: ) -> None:
super().__init__() super().__init__()
self.model_config = model_config self.model_config = model_config
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.mm_registry = mm_registry
def get_tokenizer_group(self) -> BaseTokenizerGroup: def get_tokenizer_group(self) -> BaseTokenizerGroup:
if self.tokenizer is None: if self.tokenizer is None:
...@@ -198,14 +202,79 @@ class InputPreprocessor: ...@@ -198,14 +202,79 @@ class InputPreprocessor:
prompt=prompt, prompt=prompt,
lora_request=lora_request) lora_request=lora_request)
def _can_process_multimodal(self) -> bool:
    """
    Report whether the new multi-modal processor is registered for the model.

    Returns:
        The result of ``mm_registry.has_processor`` for this model config.
        When it is falsy, an informational message about the legacy input
        pipeline is logged before returning.

    Raises:
        ValueError: If the model config is not flagged as a multi-modal
            model.
    """
    config = self.model_config

    if not config.is_multimodal_model:
        raise ValueError("Your model does not support multi-modal inputs")

    # Interim measure so we can handle models that have yet to be
    # updated to use the new multi-modal processor
    has_new_processor = self.mm_registry.has_processor(config)
    if has_new_processor:
        return has_new_processor

    logger.info(
        "Your model uses the legacy input pipeline instead of the new "
        "multi-modal processor. Please note that the legacy pipeline "
        "will be removed in a future release. For more details, see: "
        "https://github.com/vllm-project/vllm/issues/10114")
    return has_new_processor
def _process_multimodal(
    self,
    prompt: Union[str, List[int]],
    mm_data: MultiModalDataDict,
    mm_processor_kwargs: Optional[Mapping[str, object]],
    lora_request: Optional[LoRARequest],
) -> MultiModalInputsV2:
    """
    Run the model's multi-modal processor on a multi-modal prompt.

    Args:
        prompt: Prompt text, or a list of token IDs that is decoded back
            to text with the request's tokenizer before processing.
        mm_data: The multi-modal data passed to the processor.
        mm_processor_kwargs: Extra processor kwargs; ``None`` is replaced
            by an empty dict.
        lora_request: Used to select the tokenizer.

    Returns:
        Whatever the processor's ``apply`` call returns.
    """
    tokenizer = self.get_tokenizer_group().get_lora_tokenizer(lora_request)
    processor = self.mm_registry.create_processor(self.model_config,
                                                  tokenizer)

    # The processor is handed text, so token IDs are decoded first.
    prompt_text = (tokenizer.decode(prompt)
                   if isinstance(prompt, list) else prompt)
    processor_kwargs = ({} if mm_processor_kwargs is None else
                        mm_processor_kwargs)

    return processor.apply(prompt_text, mm_data, processor_kwargs)
async def _process_multimodal_async(
    self,
    prompt: Union[str, List[int]],
    mm_data: MultiModalDataDict,
    mm_processor_kwargs: Optional[Mapping[str, object]],
    lora_request: Optional[LoRARequest],
) -> MultiModalInputsV2:
    """
    Async version of :meth:`_process_multimodal`.

    Args:
        prompt: Prompt text, or a list of token IDs. Token IDs trigger a
            deprecation warning and are decoded back to text.
        mm_data: The multi-modal data passed to the processor.
        mm_processor_kwargs: Extra processor kwargs; ``None`` is replaced
            by an empty dict.
        lora_request: Used to select the tokenizer.

    Returns:
        Whatever the processor's ``apply`` call returns.
    """
    tokenizer_group = self.get_tokenizer_group()
    tokenizer = await tokenizer_group.get_lora_tokenizer_async(
        lora_request)

    mm_processor = self.mm_registry.create_processor(
        self.model_config, tokenizer)

    if isinstance(prompt, list):
        # Adjacent literals are concatenated verbatim, so the first part
        # must end with a space to keep "is deprecated" as two words.
        logger.warning("Passing `multi_modal_data` in TokensPrompt is "
                       "deprecated and will be removed in a future update")
        prompt = tokenizer.decode(prompt)
    if mm_processor_kwargs is None:
        mm_processor_kwargs = {}

    return mm_processor.apply(prompt, mm_data, mm_processor_kwargs)
def _prompt_to_llm_inputs( def _prompt_to_llm_inputs(
self, self,
prompt: SingletonPrompt, prompt: SingletonPrompt,
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
) -> SingletonInputs: ) -> SingletonInputs:
''' """
Extract the components of any single encoder or decoder input prompt. Extract the singleton inputs from a prompt.
Arguments: Arguments:
...@@ -215,12 +284,8 @@ class InputPreprocessor: ...@@ -215,12 +284,8 @@ class InputPreprocessor:
Returns: Returns:
* prompt * :class:`SingletonInputs` instance
* prompt_token_ids """
* multi_modal_data
* mm_processor_kwargs (request-level input processor/mapper overrides)
'''
parsed = parse_singleton_prompt(prompt) parsed = parse_singleton_prompt(prompt)
if parsed["type"] == "str": if parsed["type"] == "str":
...@@ -243,6 +308,14 @@ class InputPreprocessor: ...@@ -243,6 +308,14 @@ class InputPreprocessor:
multi_modal_data = tokens_content.get("multi_modal_data") multi_modal_data = tokens_content.get("multi_modal_data")
mm_processor_kwargs = tokens_content.get("mm_processor_kwargs") mm_processor_kwargs = tokens_content.get("mm_processor_kwargs")
if multi_modal_data is not None and self._can_process_multimodal():
return self._process_multimodal(
prompt_token_ids,
multi_modal_data,
mm_processor_kwargs,
lora_request=lora_request,
)
return token_inputs( return token_inputs(
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
multi_modal_data=multi_modal_data, multi_modal_data=multi_modal_data,
...@@ -253,13 +326,22 @@ class InputPreprocessor: ...@@ -253,13 +326,22 @@ class InputPreprocessor:
text_content = parsed["content"] text_content = parsed["content"]
prompt_text = text_content["prompt"] prompt_text = text_content["prompt"]
multi_modal_data = text_content.get("multi_modal_data")
mm_processor_kwargs = text_content.get("mm_processor_kwargs")
if multi_modal_data is not None and self._can_process_multimodal():
return self._process_multimodal(
prompt_text,
multi_modal_data,
mm_processor_kwargs,
lora_request=lora_request,
)
prompt_token_ids = self._tokenize_prompt( prompt_token_ids = self._tokenize_prompt(
prompt_text, prompt_text,
request_id=request_id, request_id=request_id,
lora_request=lora_request, lora_request=lora_request,
) )
multi_modal_data = text_content.get("multi_modal_data")
mm_processor_kwargs = text_content.get("mm_processor_kwargs")
return token_inputs( return token_inputs(
prompt=prompt_text, prompt=prompt_text,
...@@ -299,6 +381,14 @@ class InputPreprocessor: ...@@ -299,6 +381,14 @@ class InputPreprocessor:
multi_modal_data = tokens_content.get("multi_modal_data") multi_modal_data = tokens_content.get("multi_modal_data")
mm_processor_kwargs = tokens_content.get("mm_processor_kwargs") mm_processor_kwargs = tokens_content.get("mm_processor_kwargs")
if multi_modal_data is not None and self._can_process_multimodal():
return await self._process_multimodal_async(
prompt_token_ids,
multi_modal_data,
mm_processor_kwargs,
lora_request=lora_request,
)
return token_inputs( return token_inputs(
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
multi_modal_data=multi_modal_data, multi_modal_data=multi_modal_data,
...@@ -309,13 +399,22 @@ class InputPreprocessor: ...@@ -309,13 +399,22 @@ class InputPreprocessor:
text_content = parsed["content"] text_content = parsed["content"]
prompt_text = text_content["prompt"] prompt_text = text_content["prompt"]
multi_modal_data = text_content.get("multi_modal_data")
mm_processor_kwargs = text_content.get("mm_processor_kwargs")
if multi_modal_data is not None and self._can_process_multimodal():
return await self._process_multimodal_async(
prompt_text,
multi_modal_data,
mm_processor_kwargs,
lora_request=lora_request,
)
prompt_token_ids = await self._tokenize_prompt_async( prompt_token_ids = await self._tokenize_prompt_async(
prompt_text, prompt_text,
request_id=request_id, request_id=request_id,
lora_request=lora_request, lora_request=lora_request,
) )
multi_modal_data = text_content.get("multi_modal_data")
mm_processor_kwargs = text_content.get("mm_processor_kwargs")
return token_inputs( return token_inputs(
prompt=prompt_text, prompt=prompt_text,
...@@ -331,7 +430,8 @@ class InputPreprocessor: ...@@ -331,7 +430,8 @@ class InputPreprocessor:
encoder_inputs: SingletonInputs, encoder_inputs: SingletonInputs,
decoder_inputs: Optional[SingletonInputs], decoder_inputs: Optional[SingletonInputs],
) -> EncoderDecoderInputs: ) -> EncoderDecoderInputs:
if encoder_inputs["type"] == "token": if (encoder_inputs["type"] == "token"
or encoder_inputs["type"] == "multimodal"):
pass pass
else: else:
assert_never(encoder_inputs) assert_never(encoder_inputs)
...@@ -340,7 +440,8 @@ class InputPreprocessor: ...@@ -340,7 +440,8 @@ class InputPreprocessor:
dec_token_ids = self._prepare_decoder_input_ids_for_generation( dec_token_ids = self._prepare_decoder_input_ids_for_generation(
None) None)
decoder_inputs = token_inputs(dec_token_ids) decoder_inputs = token_inputs(dec_token_ids)
elif decoder_inputs["type"] == "token": elif (decoder_inputs["type"] == "token"
or decoder_inputs["type"] == "multimodal"):
dec_token_ids = self._prepare_decoder_input_ids_for_generation( dec_token_ids = self._prepare_decoder_input_ids_for_generation(
decoder_inputs["prompt_token_ids"]) decoder_inputs["prompt_token_ids"])
decoder_inputs["prompt_token_ids"] = dec_token_ids decoder_inputs["prompt_token_ids"] = dec_token_ids
...@@ -361,7 +462,7 @@ class InputPreprocessor: ...@@ -361,7 +462,7 @@ class InputPreprocessor:
prompt: PromptType, prompt: PromptType,
request_id: str, request_id: str,
) -> EncoderDecoderInputs: ) -> EncoderDecoderInputs:
''' """
For encoder/decoder models only: For encoder/decoder models only:
Process an input prompt into an :class:`EncoderDecoderInputs` instance. Process an input prompt into an :class:`EncoderDecoderInputs` instance.
...@@ -391,8 +492,7 @@ class InputPreprocessor: ...@@ -391,8 +492,7 @@ class InputPreprocessor:
Returns: Returns:
* :class:`EncoderDecoderInputs` instance * :class:`EncoderDecoderInputs` instance
''' """
encoder_inputs: SingletonInputs encoder_inputs: SingletonInputs
decoder_inputs: Optional[SingletonInputs] decoder_inputs: Optional[SingletonInputs]
...@@ -460,7 +560,8 @@ class InputPreprocessor: ...@@ -460,7 +560,8 @@ class InputPreprocessor:
prompt_inputs: DecoderOnlyInputs, prompt_inputs: DecoderOnlyInputs,
prompt_adapter_request: Optional[PromptAdapterRequest], prompt_adapter_request: Optional[PromptAdapterRequest],
) -> DecoderOnlyInputs: ) -> DecoderOnlyInputs:
if prompt_inputs["type"] == "token": if (prompt_inputs["type"] == "token"
or prompt_inputs["type"] == "multimodal"):
prompt_inputs["prompt_token_ids"] = self._apply_prompt_adapter( prompt_inputs["prompt_token_ids"] = self._apply_prompt_adapter(
prompt_inputs["prompt_token_ids"], prompt_inputs["prompt_token_ids"],
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
...@@ -477,7 +578,7 @@ class InputPreprocessor: ...@@ -477,7 +578,7 @@ class InputPreprocessor:
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> DecoderOnlyInputs: ) -> DecoderOnlyInputs:
''' """
For decoder-only models: For decoder-only models:
Process an input prompt into an :class:`DecoderOnlyInputs` instance. Process an input prompt into an :class:`DecoderOnlyInputs` instance.
...@@ -491,7 +592,7 @@ class InputPreprocessor: ...@@ -491,7 +592,7 @@ class InputPreprocessor:
Returns: Returns:
* :class:`DecoderOnlyInputs` instance * :class:`DecoderOnlyInputs` instance
''' """
prompt_comps = self._prompt_to_llm_inputs( prompt_comps = self._prompt_to_llm_inputs(
prompt, prompt,
......
...@@ -5,14 +5,17 @@ from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, NamedTuple, ...@@ -5,14 +5,17 @@ from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, NamedTuple,
Optional, Protocol, Type, cast) Optional, Protocol, Type, cast)
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig, ProcessorMixin
from typing_extensions import TypeVar from typing_extensions import TypeVar, assert_never
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.processor import cached_get_processor
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import (get_allowed_kwarg_only_overrides, print_warning_once, from vllm.utils import (get_allowed_kwarg_only_overrides, print_warning_once,
resolve_mm_processor_kwargs) resolve_mm_processor_kwargs)
from .data import ProcessorInputs from .data import ProcessorInputs, SingletonInputs
from .parse import is_encoder_decoder_inputs
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import ModelConfig from vllm.config import ModelConfig
...@@ -61,6 +64,19 @@ class InputContext: ...@@ -61,6 +64,19 @@ class InputContext:
return self.model_config.hf_image_processor_config return self.model_config.hf_image_processor_config
@dataclass(frozen=True)
class InputProcessingContext(InputContext):
    """An :class:`InputContext` that additionally carries a tokenizer."""

    tokenizer: AnyTokenizer
    """The tokenizer used to tokenize the inputs."""

    def get_hf_processor(self) -> ProcessorMixin:
        """
        Return the result of :func:`cached_get_processor`, called with
        ``self.model_config.tokenizer`` as the first argument, this
        context's own tokenizer, and the model config's
        ``trust_remote_code`` setting.
        """
        model_config = self.model_config
        return cached_get_processor(
            model_config.tokenizer,
            # Override the tokenizer with ours
            tokenizer=self.tokenizer,
            trust_remote_code=model_config.trust_remote_code,
        )
N = TypeVar("N", bound=Type[nn.Module]) N = TypeVar("N", bound=Type[nn.Module])
...@@ -94,7 +110,7 @@ class DummyDataFactory(Protocol): ...@@ -94,7 +110,7 @@ class DummyDataFactory(Protocol):
... ...
class _MultiModalCounts(UserDict): class _MultiModalCounts(UserDict[str, int]):
""" """
Wraps `mm_counts` for a more informative error message Wraps `mm_counts` for a more informative error message
when attempting to access a plugin that does not exist. when attempting to access a plugin that does not exist.
...@@ -287,6 +303,21 @@ class InputRegistry: ...@@ -287,6 +303,21 @@ class InputRegistry:
return self._input_processors_by_model_type \ return self._input_processors_by_model_type \
.get(model_cls, self._default_input_processor) .get(model_cls, self._default_input_processor)
def _ensure_mm_kwargs(
    self,
    inputs: SingletonInputs,
    mm_processor_kwargs: Dict[str, Any],
):
    """
    Check or backfill the multi-modal keyword arguments on ``inputs``.

    For ``"token"`` inputs, ``mm_processor_kwargs`` is stored under the
    ``"mm_processor_kwargs"`` key only when that key is absent. For
    ``"multimodal"`` inputs, the ``"mm_kwargs"`` key must already be
    present. ``inputs`` is modified in place and nothing is returned.
    """
    if inputs["type"] == "multimodal":
        # Be more strict in V2
        assert "mm_kwargs" in inputs
    elif inputs["type"] == "token":
        # In case the input processor for that model fails to set it
        inputs.setdefault("mm_processor_kwargs", mm_processor_kwargs)
    else:
        assert_never(inputs["type"])
def process_input(self, model_config: "ModelConfig", def process_input(self, model_config: "ModelConfig",
inputs: ProcessorInputs) -> ProcessorInputs: inputs: ProcessorInputs) -> ProcessorInputs:
""" """
...@@ -312,8 +343,21 @@ class InputRegistry: ...@@ -312,8 +343,21 @@ class InputRegistry:
processor, processor,
) )
return processor(InputContext(model_config), inputs, processed_inputs = processor(
**mm_processor_kwargs) InputContext(model_config),
inputs,
**mm_processor_kwargs,
)
if is_encoder_decoder_inputs(processed_inputs):
self._ensure_mm_kwargs(processed_inputs["encoder"],
mm_processor_kwargs)
self._ensure_mm_kwargs(processed_inputs["decoder"],
mm_processor_kwargs)
else:
self._ensure_mm_kwargs(processed_inputs, mm_processor_kwargs)
return processed_inputs
def create_input_processor(self, model_config: "ModelConfig"): def create_input_processor(self, model_config: "ModelConfig"):
""" """
......
...@@ -30,8 +30,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -30,8 +30,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.glm4_vision_encoder import EVA2CLIPModel from vllm.model_executor.models.glm4_vision_encoder import EVA2CLIPModel
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalData from vllm.multimodal.inputs import MultiModalData, MultiModalKwargs
from vllm.multimodal.utils import cached_get_tokenizer from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SequenceData) SequenceData)
......
...@@ -32,8 +32,7 @@ from vllm.model_executor.layers.linear import ColumnParallelLinear ...@@ -32,8 +32,7 @@ from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.models.persimmon import PersimmonForCausalLM from vllm.model_executor.models.persimmon import PersimmonForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.base import MultiModalKwargs
from vllm.multimodal.image import cached_get_image_processor from vllm.multimodal.image import cached_get_image_processor
from vllm.multimodal.utils import (cached_get_tokenizer, from vllm.multimodal.utils import (cached_get_tokenizer,
consecutive_placeholder_ranges) consecutive_placeholder_ranges)
......
...@@ -15,8 +15,7 @@ from transformers import PretrainedConfig ...@@ -15,8 +15,7 @@ from transformers import PretrainedConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs) token_inputs)
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.base import MultiModalKwargs
from vllm.multimodal.utils import cached_get_tokenizer from vllm.multimodal.utils import cached_get_tokenizer
from vllm.utils import is_list_of from vllm.utils import is_list_of
......
...@@ -25,8 +25,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler ...@@ -25,8 +25,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.models.intern_vit import (InternVisionModel, from vllm.model_executor.models.intern_vit import (InternVisionModel,
InternVisionPatchModel) InternVisionPatchModel)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.base import MultiModalKwargs
from vllm.multimodal.utils import cached_get_tokenizer from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of from vllm.utils import is_list_of
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment