Unverified commit 73391a1b, authored by Cyrus Leung, committed by GitHub
Browse files

[Renderer] Move InputPreprocessor into Renderer (1/2) (#34510)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
parent b3c14229
...@@ -54,6 +54,7 @@ class MockModelConfig: ...@@ -54,6 +54,7 @@ class MockModelConfig:
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict) media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
skip_tokenizer_init = False skip_tokenizer_init = False
is_encoder_decoder: bool = False is_encoder_decoder: bool = False
is_multimodal_model: bool = False
def get_diff_sampling_param(self): def get_diff_sampling_param(self):
return self.diff_sampling_param or {} return self.diff_sampling_param or {}
...@@ -67,7 +68,7 @@ class MockVllmConfig: ...@@ -67,7 +68,7 @@ class MockVllmConfig:
def _build_renderer(model_config: MockModelConfig): def _build_renderer(model_config: MockModelConfig):
_, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config) _, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)
return HfRenderer( return HfRenderer.from_config(
MockVllmConfig(model_config), MockVllmConfig(model_config),
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name}, tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
) )
......
...@@ -53,6 +53,7 @@ class MockModelConfig: ...@@ -53,6 +53,7 @@ class MockModelConfig:
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict) media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
skip_tokenizer_init = False skip_tokenizer_init = False
is_encoder_decoder: bool = False is_encoder_decoder: bool = False
is_multimodal_model: bool = False
def get_diff_sampling_param(self): def get_diff_sampling_param(self):
return self.diff_sampling_param or {} return self.diff_sampling_param or {}
...@@ -78,7 +79,7 @@ def _build_serving_completion(engine: AsyncLLM) -> OpenAIServingCompletion: ...@@ -78,7 +79,7 @@ def _build_serving_completion(engine: AsyncLLM) -> OpenAIServingCompletion:
def _build_renderer(model_config: MockModelConfig): def _build_renderer(model_config: MockModelConfig):
_, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config) _, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)
return HfRenderer( return HfRenderer.from_config(
MockVllmConfig(model_config), MockVllmConfig(model_config),
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name}, tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
) )
......
...@@ -52,6 +52,7 @@ class MockModelConfig: ...@@ -52,6 +52,7 @@ class MockModelConfig:
generation_config: str = "auto" generation_config: str = "auto"
skip_tokenizer_init: bool = False skip_tokenizer_init: bool = False
is_encoder_decoder: bool = False is_encoder_decoder: bool = False
is_multimodal_model: bool = False
def get_diff_sampling_param(self): def get_diff_sampling_param(self):
return self.diff_sampling_param or {} return self.diff_sampling_param or {}
...@@ -95,7 +96,7 @@ def register_mock_resolver(): ...@@ -95,7 +96,7 @@ def register_mock_resolver():
def _build_renderer(model_config: MockModelConfig): def _build_renderer(model_config: MockModelConfig):
_, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config) _, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)
return HfRenderer( return HfRenderer.from_config(
MockVllmConfig(model_config), MockVllmConfig(model_config),
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name}, tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
) )
......
...@@ -529,6 +529,7 @@ class MockModelConfig: ...@@ -529,6 +529,7 @@ class MockModelConfig:
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict) media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
skip_tokenizer_init: bool = False skip_tokenizer_init: bool = False
is_encoder_decoder: bool = False is_encoder_decoder: bool = False
is_multimodal_model: bool = False
def get_diff_sampling_param(self): def get_diff_sampling_param(self):
return self.diff_sampling_param or {} return self.diff_sampling_param or {}
...@@ -542,7 +543,7 @@ class MockVllmConfig: ...@@ -542,7 +543,7 @@ class MockVllmConfig:
def _build_renderer(model_config: MockModelConfig): def _build_renderer(model_config: MockModelConfig):
_, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config) _, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)
return HfRenderer( return HfRenderer.from_config(
MockVllmConfig(model_config), MockVllmConfig(model_config),
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name}, tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
) )
...@@ -756,9 +757,8 @@ async def test_serving_chat_mistral_token_ids_prompt_is_validated(): ...@@ -756,9 +757,8 @@ async def test_serving_chat_mistral_token_ids_prompt_is_validated():
mock_tokenizer = MagicMock(spec=MistralTokenizer) mock_tokenizer = MagicMock(spec=MistralTokenizer)
mock_renderer = MistralRenderer( mock_renderer = MistralRenderer(
MockVllmConfig(mock_engine.model_config), MockVllmConfig(mock_engine.model_config),
tokenizer_kwargs={}, tokenizer=mock_tokenizer,
) )
mock_renderer._tokenizer = mock_tokenizer
# Force the Mistral chat template renderer to return token IDs. # Force the Mistral chat template renderer to return token IDs.
# Choose a prompt length that is < max_model_len, but large enough that # Choose a prompt length that is < max_model_len, but large enough that
# adding max_tokens should exceed the model context window. # adding max_tokens should exceed the model context window.
...@@ -798,9 +798,8 @@ async def test_serving_chat_mistral_token_ids_prompt_too_long_is_rejected(): ...@@ -798,9 +798,8 @@ async def test_serving_chat_mistral_token_ids_prompt_too_long_is_rejected():
mock_tokenizer = MagicMock(spec=MistralTokenizer) mock_tokenizer = MagicMock(spec=MistralTokenizer)
mock_renderer = MistralRenderer( mock_renderer = MistralRenderer(
MockVllmConfig(mock_engine.model_config), MockVllmConfig(mock_engine.model_config),
tokenizer_kwargs={}, tokenizer=mock_tokenizer,
) )
mock_renderer._tokenizer = mock_tokenizer
# prompt_token_ids length == max_model_len should be rejected for # prompt_token_ids length == max_model_len should be rejected for
# completion-like requests (ChatCompletionRequest). # completion-like requests (ChatCompletionRequest).
mock_renderer.render_messages_async = AsyncMock( mock_renderer.render_messages_async = AsyncMock(
......
...@@ -38,6 +38,7 @@ class MockModelConfig: ...@@ -38,6 +38,7 @@ class MockModelConfig:
enable_prompt_embeds: bool = True enable_prompt_embeds: bool = True
skip_tokenizer_init: bool = False skip_tokenizer_init: bool = False
is_encoder_decoder: bool = False is_encoder_decoder: bool = False
is_multimodal_model: bool = False
@dataclass @dataclass
...@@ -78,15 +79,16 @@ def _build_renderer( ...@@ -78,15 +79,16 @@ def _build_renderer(
renderer = HfRenderer( renderer = HfRenderer(
MockVllmConfig(model_config), MockVllmConfig(model_config),
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name}, tokenizer=(
None
if model_config.skip_tokenizer_init
else DummyTokenizer(
truncation_side=truncation_side,
max_chars_per_token=max_chars_per_token,
)
),
) )
if not model_config.skip_tokenizer_init:
renderer._tokenizer = DummyTokenizer(
truncation_side=truncation_side,
max_chars_per_token=max_chars_per_token,
)
return renderer return renderer
...@@ -277,7 +279,7 @@ class TestRenderPrompt: ...@@ -277,7 +279,7 @@ class TestRenderPrompt:
) )
# Should not even attempt tokenization # Should not even attempt tokenization
assert renderer._tokenizer._captured_encode_kwargs == {} assert renderer.tokenizer._captured_encode_kwargs == {}
def test_text_max_length_exceeded_nonobvious(self): def test_text_max_length_exceeded_nonobvious(self):
renderer = _build_renderer(MockModelConfig(), max_chars_per_token=2) renderer = _build_renderer(MockModelConfig(), max_chars_per_token=2)
...@@ -298,8 +300,8 @@ class TestRenderPrompt: ...@@ -298,8 +300,8 @@ class TestRenderPrompt:
) )
# Should only tokenize the first max_total_tokens + 1 tokens # Should only tokenize the first max_total_tokens + 1 tokens
assert renderer._tokenizer._captured_encode_kwargs["truncation"] is True assert renderer.tokenizer._captured_encode_kwargs["truncation"] is True
assert renderer._tokenizer._captured_encode_kwargs["max_length"] == 101 assert renderer.tokenizer._captured_encode_kwargs["max_length"] == 101
def test_token_max_length_exceeded(self): def test_token_max_length_exceeded(self):
renderer = _build_renderer(MockModelConfig()) renderer = _build_renderer(MockModelConfig())
......
...@@ -36,6 +36,7 @@ class MockModelConfig: ...@@ -36,6 +36,7 @@ class MockModelConfig:
enable_prompt_embeds: bool = True enable_prompt_embeds: bool = True
skip_tokenizer_init: bool = False skip_tokenizer_init: bool = False
is_encoder_decoder: bool = False is_encoder_decoder: bool = False
is_multimodal_model: bool = False
@dataclass @dataclass
...@@ -57,9 +58,8 @@ async def test_async_mistral_tokenizer_does_not_block_event_loop(): ...@@ -57,9 +58,8 @@ async def test_async_mistral_tokenizer_does_not_block_event_loop():
mock_tokenizer.apply_chat_template = mocked_apply_chat_template mock_tokenizer.apply_chat_template = mocked_apply_chat_template
mock_renderer = MistralRenderer( mock_renderer = MistralRenderer(
MockVllmConfig(mock_model_config), MockVllmConfig(mock_model_config),
tokenizer_kwargs={}, tokenizer=mock_tokenizer,
) )
mock_renderer._tokenizer = mock_tokenizer
task = mock_renderer.render_messages_async([], ChatParams()) task = mock_renderer.render_messages_async([], ChatParams())
......
...@@ -19,7 +19,7 @@ import pytest ...@@ -19,7 +19,7 @@ import pytest
import pytest_asyncio import pytest_asyncio
from vllm import SamplingParams from vllm import SamplingParams
from vllm.inputs import StreamingInput from vllm.engine.protocol import StreamingInput
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sampling_params import RequestOutputKind from vllm.sampling_params import RequestOutputKind
......
...@@ -7,7 +7,7 @@ from unittest.mock import AsyncMock, MagicMock ...@@ -7,7 +7,7 @@ from unittest.mock import AsyncMock, MagicMock
import pytest import pytest
from vllm.inputs import StreamingInput from vllm.engine.protocol import StreamingInput
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.async_llm import AsyncLLM
......
...@@ -18,7 +18,7 @@ import dataclasses ...@@ -18,7 +18,7 @@ import dataclasses
import json import json
import time import time
from datetime import datetime from datetime import datetime
from typing import Any from typing import TYPE_CHECKING, Any
import numpy as np import numpy as np
...@@ -28,9 +28,6 @@ from vllm.benchmarks.datasets import ( ...@@ -28,9 +28,6 @@ from vllm.benchmarks.datasets import (
) )
from vllm.benchmarks.throughput import get_requests from vllm.benchmarks.throughput import get_requests
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.multimodal.processing.context import (
get_timing_stats_from_engine_client,
)
from vllm.utils.gc_utils import freeze_gc_heap from vllm.utils.gc_utils import freeze_gc_heap
from vllm.utils.import_utils import PlaceholderModule from vllm.utils.import_utils import PlaceholderModule
...@@ -39,16 +36,103 @@ try: ...@@ -39,16 +36,103 @@ try:
except ImportError: except ImportError:
pd = PlaceholderModule("pandas") pd = PlaceholderModule("pandas")
if TYPE_CHECKING:  # Avoid having to mock during docs build
    from vllm.v1.engine.llm_engine import LLMEngine
else:
    LLMEngine = object


def get_timing_stats_from_engine(llm_engine: LLMEngine) -> dict[str, dict[str, float]]:
    """
    Get all multimodal timing stats from the LLM engine.

    Collects both preprocessing stats (HF processor, hashing, cache lookup,
    prompt update) and encoder forward pass timing, merged by request_id.

    Args:
        llm_engine: The LLM engine. This function reads its
            `vllm_config.observability_config`, its `renderer` (for the
            multimodal processor) and calls its `collective_rpc`.

    Returns:
        Dictionary mapping request_id to merged stats dict containing
        both preprocessing and encoder timing metrics. Empty if the
        observability config is missing or has
        `enable_mm_processor_stats` disabled.

    Example:
        {
            'request-123': {
                'hf_processor_time': 0.45,
                'hashing_time': 0.02,
                'cache_lookup_time': 0.01,
                'prompt_update_time': 0.03,
                'preprocessor_total_time': 0.51,
                'encoder_forward_time': 0.23,
                'num_encoder_calls': 1
            }
        }
    """
    observability_config = llm_engine.vllm_config.observability_config
    if not observability_config or not observability_config.enable_mm_processor_stats:
        return {}

    renderer = llm_engine.renderer
    mm_processor = renderer.get_mm_processor()
    preprocessing_stats = mm_processor.info.ctx.get_all_timing_stats()

    encoder_stats = dict[str, dict[str, float]]()
    for worker_stats in llm_engine.collective_rpc("get_encoder_timing_stats"):
        if not worker_stats:
            continue

        for request_id, stats_dict in worker_stats.items():
            if request_id not in encoder_stats:
                # First report for this request: copy so the worker's dict
                # is never mutated by later aggregation.
                encoder_stats[request_id] = dict(stats_dict)
                continue

            # Aggregate timing metrics across workers by keeping the
            # largest value reported for each metric.
            aggregated = encoder_stats[request_id]
            for key, default in (
                ("encoder_forward_time", 0.0),
                ("num_encoder_calls", 0),
            ):
                aggregated[key] = max(
                    aggregated.get(key, default),
                    stats_dict.get(key, default),
                )

    merged_stats = dict[str, dict[str, float]]()
    for request_id, prep_dict in preprocessing_stats.items():
        merged_stats[request_id] = dict(prep_dict)

    for request_id, enc_dict in encoder_stats.items():
        if request_id in merged_stats:
            merged_stats[request_id].update(enc_dict)
            continue

        # In V1 engine, the request_id in encoder_stats has a suffix
        # appended to the original request_id (which is used in
        # preprocessing_stats).
        # We try to strip the suffix to find the matching request.
        possible_original_id = request_id.rpartition("-")[0]
        if possible_original_id and possible_original_id in merged_stats:
            merged_stats[possible_original_id].update(enc_dict)
        else:
            # No preprocessing entry matches: keep the encoder stats
            # under their own request_id.
            merged_stats[request_id] = dict(enc_dict)

    return merged_stats
def collect_mm_processor_stats( def collect_mm_processor_stats(
llm_engine: Any, llm_engine: LLMEngine,
num_warmup_reqs: int = 0, num_warmup_reqs: int = 0,
) -> dict[str, list[float]]: ) -> dict[str, list[float]]:
""" """
Collect multimodal processor timing stats. Collect multimodal processor timing stats.
Returns a dictionary mapping stage names to lists of timing values (in seconds). Returns a dictionary mapping stage names to lists of timing values (in seconds).
""" """
all_stats = get_timing_stats_from_engine_client(llm_engine) all_stats = get_timing_stats_from_engine(llm_engine)
stat_keys = [ stat_keys = [
"hf_processor_time", "hf_processor_time",
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Iterable, Mapping from collections.abc import AsyncGenerator, Iterable, Mapping
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from vllm.config import ModelConfig, VllmConfig from vllm.config import ModelConfig, VllmConfig
...@@ -10,7 +11,7 @@ from vllm.distributed.weight_transfer.base import ( ...@@ -10,7 +11,7 @@ from vllm.distributed.weight_transfer.base import (
WeightTransferInitRequest, WeightTransferInitRequest,
WeightTransferUpdateRequest, WeightTransferUpdateRequest,
) )
from vllm.inputs.data import PromptType, StreamingInput from vllm.inputs.data import PromptType
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import IOProcessor from vllm.plugins.io_processors import IOProcessor
...@@ -26,6 +27,18 @@ if TYPE_CHECKING: ...@@ -26,6 +27,18 @@ if TYPE_CHECKING:
from vllm.v1.engine import PauseMode from vllm.v1.engine import PauseMode
@dataclass
class StreamingInput:
    """Input data for a streaming generation request.

    This is used with generate() to support multi-turn streaming sessions
    where inputs are provided via an async generator.
    """

    # The prompt carried by this streaming input.
    prompt: PromptType
    # Optional sampling parameters; defaults to None.
    # NOTE(review): presumably None means "use the session's defaults" —
    # confirm against the generate() implementation.
    sampling_params: SamplingParams | None = None
class EngineClient(ABC): class EngineClient(ABC):
"""Protocol class for Clients to Engine""" """Protocol class for Clients to Engine"""
......
...@@ -72,7 +72,7 @@ from vllm.outputs import ( ...@@ -72,7 +72,7 @@ from vllm.outputs import (
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs from vllm.renderers import ChatParams, merge_kwargs
from vllm.renderers.inputs import DictPrompt, TokPrompt from vllm.renderers.inputs import DictPrompt, TokPrompt
from vllm.renderers.inputs.preprocess import ( from vllm.renderers.inputs.preprocess import (
conversation_to_seq, conversation_to_seq,
...@@ -384,7 +384,7 @@ class LLM: ...@@ -384,7 +384,7 @@ class LLM:
return parallel_config.world_size return parallel_config.world_size
def reset_mm_cache(self) -> None: def reset_mm_cache(self) -> None:
self.input_processor.clear_mm_cache() self.renderer.clear_mm_cache()
self.llm_engine.reset_mm_cache() self.llm_engine.reset_mm_cache()
def get_default_sampling_params(self) -> SamplingParams: def get_default_sampling_params(self) -> SamplingParams:
...@@ -876,19 +876,6 @@ class LLM: ...@@ -876,19 +876,6 @@ class LLM:
return outputs return outputs
def _get_cmpl_tok_params(self, tokenization_kwargs: dict[str, Any] | None):
model_config = self.model_config
encoder_config = model_config.encoder_config or {}
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
do_lower_case=encoder_config.get("do_lower_case", False),
# For Whisper, special tokens should be provided by the user based
# on the task and language of their request. Also needed to avoid
# appending an EOS token to the prompt which disrupts generation.
add_special_tokens=not model_config.is_encoder_decoder,
).with_kwargs(tokenization_kwargs)
def _preprocess_cmpl( def _preprocess_cmpl(
self, self,
prompts: Sequence[PromptType], prompts: Sequence[PromptType],
...@@ -910,20 +897,12 @@ class LLM: ...@@ -910,20 +897,12 @@ class LLM:
parsed_prompts = [ parsed_prompts = [
parse_model_prompt(model_config, prompt) for prompt in prompts parse_model_prompt(model_config, prompt) for prompt in prompts
] ]
tok_params = self._get_cmpl_tok_params(tokenization_kwargs) tok_params = renderer.default_cmpl_tok_params.with_kwargs(
**(tokenization_kwargs or {})
)
return renderer.render_cmpl(parsed_prompts, tok_params) return renderer.render_cmpl(parsed_prompts, tok_params)
def _get_chat_tok_params(self, tokenization_kwargs: dict[str, Any] | None):
model_config = self.model_config
encoder_config = model_config.encoder_config or {}
return TokenizeParams(
max_total_tokens=model_config.max_model_len,
do_lower_case=encoder_config.get("do_lower_case", False),
add_special_tokens=False,
).with_kwargs(tokenization_kwargs)
def _preprocess_chat( def _preprocess_chat(
self, self,
conversations: Sequence[list[ChatCompletionMessageParam]], conversations: Sequence[list[ChatCompletionMessageParam]],
...@@ -961,7 +940,9 @@ class LLM: ...@@ -961,7 +940,9 @@ class LLM:
), ),
), ),
) )
tok_params = self._get_chat_tok_params(tokenization_kwargs) tok_params = renderer.default_chat_tok_params.with_kwargs(
**(tokenization_kwargs or {})
)
_, engine_prompts = renderer.render_chat( _, engine_prompts = renderer.render_chat(
conversations, conversations,
...@@ -1653,7 +1634,10 @@ class LLM: ...@@ -1653,7 +1634,10 @@ class LLM:
architecture=architecture, architecture=architecture,
) )
tok_params = self._get_cmpl_tok_params(tokenization_kwargs) renderer = self.renderer
tok_params = renderer.default_cmpl_tok_params.with_kwargs(
**(tokenization_kwargs or {})
)
encode_kwargs = tok_params.get_encode_kwargs() encode_kwargs = tok_params.get_encode_kwargs()
if model_config.is_cross_encoder: if model_config.is_cross_encoder:
...@@ -1970,7 +1954,10 @@ class LLM: ...@@ -1970,7 +1954,10 @@ class LLM:
dict(truncate_prompt_tokens=params.truncate_prompt_tokens), dict(truncate_prompt_tokens=params.truncate_prompt_tokens),
) )
tok_params = self._get_cmpl_tok_params(tokenization_kwargs) renderer = self.renderer
tok_params = renderer.default_cmpl_tok_params.with_kwargs(
**(tokenization_kwargs or {})
)
tokenization_kwargs = tok_params.get_encode_kwargs() tokenization_kwargs = tok_params.get_encode_kwargs()
engine_request = self.input_processor.process_inputs( engine_request = self.input_processor.process_inputs(
......
...@@ -8,11 +8,11 @@ from typing import Literal, cast ...@@ -8,11 +8,11 @@ from typing import Literal, cast
import numpy as np import numpy as np
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient, StreamingInput
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.serving import OpenAIServing from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.inputs.data import PromptType, StreamingInput from vllm.inputs.data import PromptType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import SupportsRealtime from vllm.model_executor.models.interfaces import SupportsRealtime
......
...@@ -12,7 +12,6 @@ from .data import ( ...@@ -12,7 +12,6 @@ from .data import (
PromptType, PromptType,
SingletonInputs, SingletonInputs,
SingletonPrompt, SingletonPrompt,
StreamingInput,
TextPrompt, TextPrompt,
TokenInputs, TokenInputs,
TokensPrompt, TokensPrompt,
...@@ -36,5 +35,4 @@ __all__ = [ ...@@ -36,5 +35,4 @@ __all__ = [
"EncoderDecoderInputs", "EncoderDecoderInputs",
"ProcessorInputs", "ProcessorInputs",
"SingletonInputs", "SingletonInputs",
"StreamingInput",
] ]
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal, TypeAlias from typing import TYPE_CHECKING, Any, Literal, TypeAlias
import torch import torch
from typing_extensions import NotRequired, TypedDict from typing_extensions import NotRequired, TypedDict
from vllm.sampling_params import SamplingParams
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.multimodal.inputs import ( from vllm.multimodal.inputs import (
MultiModalDataDict, MultiModalDataDict,
...@@ -299,15 +296,3 @@ which can be passed to ...@@ -299,15 +296,3 @@ which can be passed to
SingletonInputs: TypeAlias = DecoderOnlyInputs | MultiModalEncDecInputs SingletonInputs: TypeAlias = DecoderOnlyInputs | MultiModalEncDecInputs
"""The inputs for a single encoder/decoder prompt.""" """The inputs for a single encoder/decoder prompt."""
@dataclass
class StreamingInput:
"""Input data for a streaming generation request.
This is used with generate() to support multi-turn streaming sessions
where inputs are provided via an async generator.
"""
prompt: PromptType
sampling_params: SamplingParams | None = None
...@@ -9,13 +9,11 @@ from typing_extensions import assert_never ...@@ -9,13 +9,11 @@ from typing_extensions import assert_never
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.inputs import ( from vllm.multimodal.inputs import (
MultiModalDataDict, MultiModalDataDict,
MultiModalInputs, MultiModalInputs,
MultiModalUUIDDict, MultiModalUUIDDict,
) )
from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.renderers import BaseRenderer, renderer_from_config from vllm.renderers import BaseRenderer, renderer_from_config
from vllm.renderers.inputs import ( from vllm.renderers.inputs import (
DecoderDictPrompt, DecoderDictPrompt,
...@@ -28,8 +26,6 @@ from vllm.renderers.inputs import ( ...@@ -28,8 +26,6 @@ from vllm.renderers.inputs import (
) )
from vllm.renderers.inputs.preprocess import parse_dec_only_prompt, parse_enc_dec_prompt from vllm.renderers.inputs.preprocess import parse_dec_only_prompt, parse_enc_dec_prompt
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.utils.jsontree import json_iter_leaves
from vllm.v1.metrics.stats import MultiModalCacheStats
from .data import ( from .data import (
DecoderInputs, DecoderInputs,
...@@ -57,17 +53,12 @@ class InputPreprocessor: ...@@ -57,17 +53,12 @@ class InputPreprocessor:
vllm_config: VllmConfig, vllm_config: VllmConfig,
renderer: BaseRenderer | None = None, renderer: BaseRenderer | None = None,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
mm_processor_cache: BaseMultiModalProcessorCache | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.model_config = vllm_config.model_config self.model_config = vllm_config.model_config
self.observability_config = vllm_config.observability_config
self.renderer = renderer or renderer_from_config(vllm_config) self.renderer = renderer or renderer_from_config(vllm_config)
self.mm_registry = mm_registry self.mm_registry = mm_registry
self.mm_processor_cache = mm_processor_cache
self.mm_cache_stats = MultiModalCacheStats() if mm_processor_cache else None
@property @property
def tokenizer(self) -> TokenizerLike | None: def tokenizer(self) -> TokenizerLike | None:
...@@ -124,23 +115,6 @@ class InputPreprocessor: ...@@ -124,23 +115,6 @@ class InputPreprocessor:
return decoder_input_ids return decoder_input_ids
def _get_tokenization_kw(
self,
overrides: dict[str, Any] | None = None,
) -> dict[str, Any]:
kwargs = dict[str, Any]()
if self.model_config.is_encoder_decoder:
# For Whisper, special tokens should be provided by the user based
# on the task and language of their request. Also needed to avoid
# appending an EOS token to the prompt which disrupts generation.
kwargs["add_special_tokens"] = False
if overrides:
kwargs.update(overrides)
return kwargs
def _tokenize_prompt( def _tokenize_prompt(
self, self,
prompt: str, prompt: str,
...@@ -150,26 +124,18 @@ class InputPreprocessor: ...@@ -150,26 +124,18 @@ class InputPreprocessor:
Apply the model's tokenizer to a text prompt, returning the Apply the model's tokenizer to a text prompt, returning the
corresponding token IDs. corresponding token IDs.
""" """
tokenizer = self.get_tokenizer() renderer = self.renderer
tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs)
encoder_config = self.model_config.encoder_config
if encoder_config and encoder_config.get("do_lower_case", False): tok_params = renderer.default_cmpl_tok_params.with_kwargs(
prompt = prompt.lower() **(tokenization_kwargs or {})
)
return tokenizer.encode(prompt, **tokenization_kwargs)
def _get_mm_processor(self) -> BaseMultiModalProcessor: tok_prompt = renderer.tokenize_prompt(
if not hasattr(self, "_mm_processor"): TextPrompt(prompt=prompt),
self._mm_processor = self.mm_registry.create_processor( tok_params,
self.model_config, )
self.observability_config,
tokenizer=self.tokenizer,
cache=self.mm_processor_cache,
)
return self._mm_processor return tok_prompt["prompt_token_ids"]
def _process_multimodal( def _process_multimodal(
self, self,
...@@ -184,33 +150,20 @@ class InputPreprocessor: ...@@ -184,33 +150,20 @@ class InputPreprocessor:
Apply the model's multi-modal processor to a multi-modal prompt, Apply the model's multi-modal processor to a multi-modal prompt,
returning the corresponding token IDs and metadata. returning the corresponding token IDs and metadata.
""" """
mm_processor = self._get_mm_processor() mm_processor = self.renderer.get_mm_processor()
if mm_processor_kwargs is None: if mm_processor_kwargs is None:
mm_processor_kwargs = {} mm_processor_kwargs = {}
mm_items = mm_processor.info.parse_mm_data(mm_data) mm_items = mm_processor.info.parse_mm_data(mm_data)
mm_input = mm_processor.apply(
return mm_processor.apply(
prompt, prompt,
mm_items, mm_items,
hf_processor_mm_kwargs=mm_processor_kwargs, hf_processor_mm_kwargs=mm_processor_kwargs,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids, mm_uuids=mm_uuids,
) )
mm_hashes = mm_input["mm_hashes"]
# Validate that all mm items have a string as their hash
contains_only_strings = all(
isinstance(leaf, str) for leaf in json_iter_leaves(mm_hashes)
)
if not contains_only_strings:
raise ValueError(
f"mm_hashes must contain only strings, got: {mm_hashes}. "
"This is likely due to an incorrect custom implementation of "
"MultiModalProcessor.apply method."
)
return mm_input
def _process_embeds( def _process_embeds(
self, self,
...@@ -245,19 +198,18 @@ class InputPreprocessor: ...@@ -245,19 +198,18 @@ class InputPreprocessor:
def _truncate_inputs( def _truncate_inputs(
self, inputs: list[int], tokenization_kwargs: dict[str, Any] | None = None self, inputs: list[int], tokenization_kwargs: dict[str, Any] | None = None
) -> list[int]: ) -> list[int]:
if ( renderer = self.renderer
not tokenization_kwargs
or "truncation" not in tokenization_kwargs
or self.tokenizer is None
):
return inputs
max_length = tokenization_kwargs["max_length"] tok_params = renderer.default_cmpl_tok_params.with_kwargs(
**(tokenization_kwargs or {})
)
if self.tokenizer.truncation_side == "left": tok_prompt = renderer.tokenize_prompt(
return inputs[-max_length:] TokensPrompt(prompt_token_ids=inputs),
else: tok_params,
return inputs[:max_length] )
return tok_prompt["prompt_token_ids"]
def _process_tokens( def _process_tokens(
self, self,
...@@ -539,26 +491,6 @@ class InputPreprocessor: ...@@ -539,26 +491,6 @@ class InputPreprocessor:
"""Preprocess the input prompt.""" """Preprocess the input prompt."""
res = self._preprocess(prompt, tokenization_kwargs, mm_uuids=mm_uuids) res = self._preprocess(prompt, tokenization_kwargs, mm_uuids=mm_uuids)
if self.mm_processor_cache and self.mm_cache_stats is not None: self.renderer.update_mm_cache_stats()
delta = self.mm_processor_cache.make_stats(delta=True)
self.mm_cache_stats.requests += 1
self.mm_cache_stats.queries += delta.total
self.mm_cache_stats.hits += delta.hits
return res return res
def stat_mm_cache(self) -> MultiModalCacheStats | None:
mm_cache_stats = self.mm_cache_stats
if mm_cache_stats is None:
return None
self.mm_cache_stats = MultiModalCacheStats()
return mm_cache_stats
def clear_mm_cache(self) -> None:
    """Clear the multimodal processor cache and flag its stats as reset.

    Both ``self.mm_processor_cache`` and ``self.mm_cache_stats`` are
    optional; whichever one is ``None`` is simply skipped.
    """
    cache = self.mm_processor_cache
    if cache is not None:
        cache.clear_cache()

    stats = self.mm_cache_stats
    if stats is not None:
        # Only the flag is set here; the stats object itself is kept.
        stats.reset = True
...@@ -208,14 +208,23 @@ class CLIPMultiModalProcessor(BaseMultiModalProcessor[CLIPProcessingInfo]): ...@@ -208,14 +208,23 @@ class CLIPMultiModalProcessor(BaseMultiModalProcessor[CLIPProcessingInfo]):
*, *,
mm_uuids: MultiModalUUIDDict | None = None, mm_uuids: MultiModalUUIDDict | None = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
if prompt and mm_items:
raise ValueError(
"CLIP accepts text-only or image-only inputs, not both! "
"Image-only inputs means passing an image with an empty text "
"prompt."
)
if mm_items: if mm_items:
if isinstance(prompt, str):
if len(prompt) > 0:
raise ValueError(
"CLIP accepts text-only or image-only inputs, not both! "
"You must pass an image with an empty text prompt."
)
else:
special_tokens = self.info.get_tokenizer().all_special_ids
if all(tok in special_tokens for tok in prompt):
prompt = []
else:
raise ValueError(
"CLIP accepts text-only or image-only inputs, not both! "
"You must pass an image with an empty token prompt."
)
# For multi-modal data, the prompt after processing should # For multi-modal data, the prompt after processing should
# only contain the dummy image tokens # only contain the dummy image tokens
tokenization_kwargs = { tokenization_kwargs = {
......
...@@ -42,6 +42,7 @@ from vllm.multimodal.processing import ( ...@@ -42,6 +42,7 @@ from vllm.multimodal.processing import (
PromptReplacement, PromptReplacement,
PromptUpdateDetails, PromptUpdateDetails,
) )
from vllm.renderers import TokenizeParams
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
...@@ -90,6 +91,9 @@ class Lfm2VLProcessingInfo(BaseProcessingInfo): ...@@ -90,6 +91,9 @@ class Lfm2VLProcessingInfo(BaseProcessingInfo):
def get_image_processor(self, **kwargs: object) -> Lfm2VlImageProcessorFast:
    """Return the ``image_processor`` attribute of the HF processor.

    Args:
        **kwargs: Forwarded unchanged to ``self.get_hf_processor``.
    """
    hf_processor = self.get_hf_processor(**kwargs)
    return hf_processor.image_processor
def get_default_tok_params(self) -> TokenizeParams:
    """Return the inherited default tokenization params with special tokens off.

    Takes whatever the parent class's ``get_default_tok_params`` returns and
    applies ``with_kwargs(add_special_tokens=False)`` to it.
    """
    # NOTE(review): presumably the prompt already carries its special tokens
    # by the time it is tokenized for this model -- confirm.
    inherited = super().get_default_tok_params()
    return inherited.with_kwargs(add_special_tokens=False)
def get_supported_mm_limits(self) -> Mapping[str, int | None]: def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"image": None} return {"image": None}
......
...@@ -66,6 +66,7 @@ from vllm.multimodal.processing import ( ...@@ -66,6 +66,7 @@ from vllm.multimodal.processing import (
PromptUpdate, PromptUpdate,
PromptUpdateDetails, PromptUpdateDetails,
) )
from vllm.renderers import TokenizeParams
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
...@@ -554,6 +555,9 @@ class Mllama4ProcessingInfo(BaseProcessingInfo): ...@@ -554,6 +555,9 @@ class Mllama4ProcessingInfo(BaseProcessingInfo):
Llama4Processor, use_fast=kwargs.pop("use_fast", True), **kwargs Llama4Processor, use_fast=kwargs.pop("use_fast", True), **kwargs
) )
def get_default_tok_params(self) -> TokenizeParams:
    """Return the inherited default tokenization params with special tokens off.

    Takes whatever the parent class's ``get_default_tok_params`` returns and
    applies ``with_kwargs(add_special_tokens=False)`` to it.
    """
    # NOTE(review): presumably the prompt already carries its special tokens
    # by the time it is tokenized for this model -- confirm.
    inherited = super().get_default_tok_params()
    return inherited.with_kwargs(add_special_tokens=False)
def get_supported_mm_limits(self) -> Mapping[str, int | None]: def get_supported_mm_limits(self) -> Mapping[str, int | None]:
# Although vLLM can support more images from an infra capability # Although vLLM can support more images from an infra capability
# perspective, we do not recommend using >10 images in practice. # perspective, we do not recommend using >10 images in practice.
......
...@@ -76,6 +76,7 @@ from vllm.multimodal.processing.processor import ( ...@@ -76,6 +76,7 @@ from vllm.multimodal.processing.processor import (
PromptUpdateDetails, PromptUpdateDetails,
_seq2tokens, _seq2tokens,
) )
from vllm.renderers import TokenizeParams
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
from vllm.transformers_utils.configs.radio import RadioConfig from vllm.transformers_utils.configs.radio import RadioConfig
...@@ -1093,6 +1094,9 @@ class BaseNanoNemotronVLProcessingInfo(BaseProcessingInfo): ...@@ -1093,6 +1094,9 @@ class BaseNanoNemotronVLProcessingInfo(BaseProcessingInfo):
) -> BaseNanoNemotronVLProcessor: ) -> BaseNanoNemotronVLProcessor:
raise NotImplementedError raise NotImplementedError
def get_default_tok_params(self) -> TokenizeParams:
    """Return the inherited default tokenization params with special tokens off.

    Takes whatever the parent class's ``get_default_tok_params`` returns and
    applies ``with_kwargs(add_special_tokens=False)`` to it.
    """
    # NOTE(review): presumably the prompt already carries its special tokens
    # by the time it is tokenized for this model -- confirm.
    inherited = super().get_default_tok_params()
    return inherited.with_kwargs(add_special_tokens=False)
def get_supported_mm_limits(self) -> Mapping[str, int | None]: def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"image": None} return {"image": None}
......
...@@ -58,6 +58,7 @@ from vllm.multimodal.processing import ( ...@@ -58,6 +58,7 @@ from vllm.multimodal.processing import (
PromptReplacement, PromptReplacement,
PromptUpdate, PromptUpdate,
) )
from vllm.renderers import TokenizeParams
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.configs.radio import RadioConfig from vllm.transformers_utils.configs.radio import RadioConfig
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
...@@ -608,6 +609,9 @@ class NemotronParseProcessingInfo(BaseProcessingInfo): ...@@ -608,6 +609,9 @@ class NemotronParseProcessingInfo(BaseProcessingInfo):
**kwargs, **kwargs,
) )
def get_default_tok_params(self) -> TokenizeParams:
    """Return the inherited default tokenization params with special tokens off.

    Takes whatever the parent class's ``get_default_tok_params`` returns and
    applies ``with_kwargs(add_special_tokens=False)`` to it.
    """
    # NOTE(review): presumably the prompt already carries its special tokens
    # by the time it is tokenized for this model -- confirm.
    inherited = super().get_default_tok_params()
    return inherited.with_kwargs(add_special_tokens=False)
@property
def skip_prompt_length_check(self) -> bool:
    """Whether the prompt length check is skipped; always ``True`` here."""
    # Because the encoder prompt is padded
    return True
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment