Unverified commit 73391a1b, authored by Cyrus Leung, committed by GitHub
Browse files

[Renderer] Move InputPreprocessor into Renderer (1/2) (#34510)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
parent b3c14229
...@@ -53,6 +53,7 @@ from vllm.multimodal.processing import ( ...@@ -53,6 +53,7 @@ from vllm.multimodal.processing import (
BaseProcessingInfo, BaseProcessingInfo,
PromptReplacement, PromptReplacement,
) )
from vllm.renderers import TokenizeParams
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.processors.ovis import OvisProcessor from vllm.transformers_utils.processors.ovis import OvisProcessor
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
...@@ -264,6 +265,9 @@ class OvisProcessingInfo(BaseProcessingInfo): ...@@ -264,6 +265,9 @@ class OvisProcessingInfo(BaseProcessingInfo):
**kwargs, **kwargs,
) )
def get_default_tok_params(self) -> TokenizeParams:
    """Return the parent's default tokenization params with special tokens off.

    Starts from `super().get_default_tok_params()` and overrides only
    `add_special_tokens`, setting it to `False` through `with_kwargs`.
    """
    base_params = super().get_default_tok_params()
    return base_params.with_kwargs(add_special_tokens=False)
def get_image_segment_len(self) -> int: def get_image_segment_len(self) -> int:
visual_tokenizer_config = self.get_hf_config().visual_tokenizer_config visual_tokenizer_config = self.get_hf_config().visual_tokenizer_config
image_size = visual_tokenizer_config.backbone_config.image_size image_size = visual_tokenizer_config.backbone_config.image_size
......
...@@ -35,6 +35,7 @@ from vllm.multimodal.processing import ( ...@@ -35,6 +35,7 @@ from vllm.multimodal.processing import (
BaseProcessingInfo, BaseProcessingInfo,
PromptReplacement, PromptReplacement,
) )
from vllm.renderers import TokenizeParams
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
...@@ -183,6 +184,9 @@ class Ovis2_5ProcessingInfo(BaseProcessingInfo): ...@@ -183,6 +184,9 @@ class Ovis2_5ProcessingInfo(BaseProcessingInfo):
temporal_patch_size=vit_config.temporal_patch_size, temporal_patch_size=vit_config.temporal_patch_size,
) )
def get_default_tok_params(self) -> TokenizeParams:
    """Return the parent's default tokenization params with special tokens off.

    Starts from `super().get_default_tok_params()` and overrides only
    `add_special_tokens`, setting it to `False` through `with_kwargs`.
    """
    base_params = super().get_default_tok_params()
    return base_params.with_kwargs(add_special_tokens=False)
def get_image_processor(self) -> BaseImageProcessor: def get_image_processor(self) -> BaseImageProcessor:
return self.get_hf_processor().image_processor # type: ignore return self.get_hf_processor().image_processor # type: ignore
......
...@@ -32,6 +32,7 @@ from vllm.multimodal.processing import ( ...@@ -32,6 +32,7 @@ from vllm.multimodal.processing import (
PromptUpdate, PromptUpdate,
PromptUpdateDetails, PromptUpdateDetails,
) )
from vllm.renderers import TokenizeParams
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
...@@ -102,6 +103,9 @@ class PaliGemmaProcessingInfo(BaseProcessingInfo): ...@@ -102,6 +103,9 @@ class PaliGemmaProcessingInfo(BaseProcessingInfo):
def get_vision_encoder_info(self): def get_vision_encoder_info(self):
return get_vision_encoder_info(self.get_hf_config()) return get_vision_encoder_info(self.get_hf_config())
def get_default_tok_params(self) -> TokenizeParams:
    """Return the parent's default tokenization params with special tokens off.

    Starts from `super().get_default_tok_params()` and overrides only
    `add_special_tokens`, setting it to `False` through `with_kwargs`.
    """
    base_params = super().get_default_tok_params()
    return base_params.with_kwargs(add_special_tokens=False)
def get_supported_mm_limits(self) -> Mapping[str, int | None]: def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"image": 1} return {"image": 1}
......
...@@ -194,14 +194,23 @@ class SiglipMultiModalProcessor(BaseMultiModalProcessor[SiglipProcessingInfo]): ...@@ -194,14 +194,23 @@ class SiglipMultiModalProcessor(BaseMultiModalProcessor[SiglipProcessingInfo]):
*, *,
mm_uuids: MultiModalUUIDDict | None = None, mm_uuids: MultiModalUUIDDict | None = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
if prompt and mm_items:
raise ValueError(
"Siglip accepts text-only or image-only inputs, not both! "
"Image-only inputs means passing an image with an empty text "
"prompt."
)
if mm_items: if mm_items:
if isinstance(prompt, str):
if len(prompt) > 0:
raise ValueError(
"SigLIP accepts text-only or image-only inputs, not both! "
"You must pass an image with an empty text prompt."
)
else:
special_tokens = self.info.get_tokenizer().all_special_ids
if all(tok in special_tokens for tok in prompt):
prompt = []
else:
raise ValueError(
"SigLIP accepts text-only or image-only inputs, not both! "
"You must pass an image with an empty token prompt."
)
# For multi-modal data, the prompt after processing should # For multi-modal data, the prompt after processing should
# only contain the image token # only contain the image token
tokenization_kwargs = { tokenization_kwargs = {
......
...@@ -42,6 +42,7 @@ from vllm.multimodal.processing import ( ...@@ -42,6 +42,7 @@ from vllm.multimodal.processing import (
PromptReplacement, PromptReplacement,
PromptUpdate, PromptUpdate,
) )
from vllm.renderers import TokenizeParams
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.ultravox import UltravoxConfig from vllm.transformers_utils.configs.ultravox import UltravoxConfig
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
...@@ -133,6 +134,9 @@ class UltravoxProcessingInfo(BaseProcessingInfo): ...@@ -133,6 +134,9 @@ class UltravoxProcessingInfo(BaseProcessingInfo):
assert isinstance(feature_extractor, WhisperFeatureExtractor) assert isinstance(feature_extractor, WhisperFeatureExtractor)
return feature_extractor return feature_extractor
def get_default_tok_params(self) -> TokenizeParams:
    """Return the parent's default tokenization params with special tokens off.

    Starts from `super().get_default_tok_params()` and overrides only
    `add_special_tokens`, setting it to `False` through `with_kwargs`.
    """
    base_params = super().get_default_tok_params()
    return base_params.with_kwargs(add_special_tokens=False)
def get_data_parser(self): def get_data_parser(self):
feature_extractor = self.get_feature_extractor() feature_extractor = self.get_feature_extractor()
......
...@@ -17,8 +17,9 @@ from mistral_common.tokens.tokenizers.audio import Audio, AudioConfig ...@@ -17,8 +17,9 @@ from mistral_common.tokens.tokenizers.audio import Audio, AudioConfig
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.engine.protocol import StreamingInput
from vllm.envs import VLLM_ENGINE_ITERATION_TIMEOUT_S from vllm.envs import VLLM_ENGINE_ITERATION_TIMEOUT_S
from vllm.inputs.data import PromptType, StreamingInput, TokensPrompt from vllm.inputs.data import PromptType, TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import MultiModalEmbeddings, SupportsRealtime from vllm.model_executor.models.interfaces import MultiModalEmbeddings, SupportsRealtime
from vllm.model_executor.models.voxtral import ( from vllm.model_executor.models.voxtral import (
......
...@@ -55,6 +55,7 @@ from vllm.multimodal.processing import ( ...@@ -55,6 +55,7 @@ from vllm.multimodal.processing import (
PromptReplacement, PromptReplacement,
PromptUpdate, PromptUpdate,
) )
from vllm.renderers import TokenizeParams
from vllm.transformers_utils.processor import cached_processor_from_config from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.utils.jsontree import json_map_leaves from vllm.utils.jsontree import json_map_leaves
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
...@@ -644,6 +645,12 @@ class WhisperProcessingInfo(BaseProcessingInfo): ...@@ -644,6 +645,12 @@ class WhisperProcessingInfo(BaseProcessingInfo):
def get_hf_config(self) -> WhisperConfig: def get_hf_config(self) -> WhisperConfig:
return self.ctx.get_hf_config(WhisperConfig) return self.ctx.get_hf_config(WhisperConfig)
def get_default_tok_params(self) -> TokenizeParams:
    """Return the parent's default tokenization params with special tokens off.

    Starts from `super().get_default_tok_params()` and overrides only
    `add_special_tokens`, setting it to `False` through `with_kwargs`.
    """
    # Special tokens should be provided by the user based on the
    # task and language of their request. Also needed to avoid
    # appending an EOS token to the prompt which disrupts generation.
    base_params = super().get_default_tok_params()
    return base_params.with_kwargs(add_special_tokens=False)
def get_data_parser(self): def get_data_parser(self):
feature_extractor = self.get_feature_extractor() feature_extractor = self.get_feature_extractor()
......
...@@ -21,6 +21,7 @@ from vllm.multimodal.parse import ( ...@@ -21,6 +21,7 @@ from vllm.multimodal.parse import (
MultiModalDataItems, MultiModalDataItems,
MultiModalDataParser, MultiModalDataParser,
) )
from vllm.renderers import TokenizeParams
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.processor import cached_processor_from_config from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.utils.func_utils import get_allowed_kwarg_only_overrides from vllm.utils.func_utils import get_allowed_kwarg_only_overrides
...@@ -93,110 +94,6 @@ class MultiModalProcessorTimingStats: ...@@ -93,110 +94,6 @@ class MultiModalProcessorTimingStats:
} }
def _collect_preprocessing_stats(engine_client: Any) -> dict:
    """Best-effort fetch of per-request preprocessing stats; `{}` if unavailable."""
    try:
        input_preprocessor = engine_client.input_processor.input_preprocessor
        if not hasattr(input_preprocessor, "_get_mm_processor"):
            return {}

        mm_processor = input_preprocessor._get_mm_processor()
        if mm_processor is None or not hasattr(mm_processor, "info"):
            return {}

        return mm_processor.info.ctx.get_all_timing_stats()
    except (AttributeError, RuntimeError):
        # Deliberately best-effort: stats collection must not break callers.
        return {}


def _collect_encoder_stats(engine_client: Any) -> dict:
    """Best-effort aggregation of encoder timing stats across workers."""
    # (metric name, default used when a worker did not report the metric)
    aggregated_metrics = (("encoder_forward_time", 0.0), ("num_encoder_calls", 0))

    encoder_stats: dict = {}
    try:
        if not hasattr(engine_client, "collective_rpc"):
            return encoder_stats

        per_worker_results = engine_client.collective_rpc("get_encoder_timing_stats")
        for worker_stats in per_worker_results or ():
            if not worker_stats:
                continue

            for request_id, stats_dict in worker_stats.items():
                if request_id not in encoder_stats:
                    encoder_stats[request_id] = dict(stats_dict)
                    continue

                # Aggregate timing metrics across workers by keeping the
                # maximum value reported for each metric.
                existing = encoder_stats[request_id]
                for metric, default in aggregated_metrics:
                    existing[metric] = max(
                        existing.get(metric, default),
                        stats_dict.get(metric, default),
                    )
    except (AttributeError, RuntimeError):
        # Deliberately best-effort: keep whatever was aggregated so far.
        pass

    return encoder_stats


def get_timing_stats_from_engine_client(
    engine_client: Any,
) -> dict[str, dict[str, float]]:
    """
    Get all multimodal timing stats from the engine client.

    Collects both preprocessing stats (HF processor, hashing, cache lookup,
    prompt update) and encoder forward pass timing, merged by request_id.

    Args:
        engine_client: The engine client (has input_processor and workers).

    Returns:
        Dictionary mapping request_id to merged stats dict containing
        both preprocessing and encoder timing metrics. An empty dict is
        returned when `enable_mm_processor_stats` is falsy or the
        observability config cannot be read.

    Example:
        {
            'request-123': {
                'hf_processor_time': 0.45,
                'hashing_time': 0.02,
                'cache_lookup_time': 0.01,
                'prompt_update_time': 0.03,
                'preprocessor_total_time': 0.51,
                'encoder_forward_time': 0.23,
                'num_encoder_calls': 1
            }
        }
    """
    try:
        observability_config = engine_client.vllm_config.observability_config
        if not observability_config.enable_mm_processor_stats:
            return {}
    except (AttributeError, RuntimeError):
        return {}

    preprocessing_stats = _collect_preprocessing_stats(engine_client)
    encoder_stats = _collect_encoder_stats(engine_client)

    # Copy the inner dicts so that merging never mutates the collected stats.
    merged_stats = {
        request_id: dict(prep_dict)
        for request_id, prep_dict in preprocessing_stats.items()
    }

    for request_id, enc_dict in encoder_stats.items():
        if request_id in merged_stats:
            merged_stats[request_id].update(enc_dict)
            continue

        # In V1 engine, the request_id in encoder_stats has a suffix
        # appended to the original request_id (which is used in
        # preprocessing_stats).
        # We try to strip the suffix to find the matching request.
        possible_original_id = request_id.rpartition("-")[0]
        if possible_original_id and possible_original_id in merged_stats:
            merged_stats[possible_original_id].update(enc_dict)
        else:
            merged_stats[request_id] = dict(enc_dict)

    return merged_stats
@contextmanager @contextmanager
def timed_preprocessor_operation(ctx: "InputProcessingContext", stage_name: str): def timed_preprocessor_operation(ctx: "InputProcessingContext", stage_name: str):
""" """
...@@ -576,6 +473,21 @@ class BaseProcessingInfo: ...@@ -576,6 +473,21 @@ class BaseProcessingInfo:
""" """
return self.ctx.get_hf_processor(**kwargs) return self.ctx.get_hf_processor(**kwargs)
def get_default_tok_params(self) -> TokenizeParams:
    """Construct the default parameters for tokenization.

    The token budget is the model's `max_model_len`, `do_lower_case` is read
    from the model's `encoder_config` (defaulting to `False`), and special
    tokens are added by default.
    """
    model_config = self.ctx.model_config
    # `encoder_config` may be falsy (e.g. `None`); treat that as no settings.
    encoder_config = model_config.encoder_config or {}

    return TokenizeParams(
        max_total_tokens=model_config.max_model_len,
        do_lower_case=encoder_config.get("do_lower_case", False),
        add_special_tokens=True,
    )
@cached_property
def default_tok_params(self) -> TokenizeParams:
    """Default tokenization params, computed once per instance.

    Delegates to `get_default_tok_params()`; the result is cached by
    `cached_property`, so later accesses return the same object.
    """
    return self.get_default_tok_params()
def _get_expected_hidden_size(self) -> int | None: def _get_expected_hidden_size(self) -> int | None:
""" """
Get expected hidden size for embedding validation if `mm_embeds` are enabled. Get expected hidden size for embedding validation if `mm_embeds` are enabled.
......
...@@ -3,12 +3,17 @@ ...@@ -3,12 +3,17 @@
import asyncio import asyncio
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Sequence from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, overload from functools import cached_property
from typing import TYPE_CHECKING, Any, Generic, overload
from typing_extensions import TypeVar
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.utils.async_utils import AsyncMicrobatchTokenizer from vllm.utils.async_utils import AsyncMicrobatchTokenizer
from vllm.utils.torch_utils import set_default_torch_num_threads
from vllm.v1.metrics.stats import MultiModalCacheStats
from .embed_utils import safe_load_prompt_embeds from .embed_utils import safe_load_prompt_embeds
from .inputs import ( from .inputs import (
...@@ -26,11 +31,16 @@ if TYPE_CHECKING: ...@@ -26,11 +31,16 @@ if TYPE_CHECKING:
ChatCompletionMessageParam, ChatCompletionMessageParam,
ConversationMessage, ConversationMessage,
) )
from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.processing import BaseMultiModalProcessor
logger = init_logger(__name__) logger = init_logger(__name__)
class BaseRenderer(ABC): _T = TypeVar("_T", bound=TokenizerLike, default=TokenizerLike)
class BaseRenderer(ABC, Generic[_T]):
@classmethod @classmethod
@abstractmethod @abstractmethod
def from_config( def from_config(
...@@ -40,20 +50,36 @@ class BaseRenderer(ABC): ...@@ -40,20 +50,36 @@ class BaseRenderer(ABC):
) -> "BaseRenderer": ) -> "BaseRenderer":
raise NotImplementedError raise NotImplementedError
def __init__(self, config: "VllmConfig") -> None: def __init__(self, config: "VllmConfig", tokenizer: _T | None) -> None:
super().__init__() super().__init__()
self.config = config
self.model_config = config.model_config self.model_config = config.model_config
self.tokenizer = tokenizer
# Lazy initialization since offline LLM doesn't use async # Lazy initialization since offline LLM doesn't use async
self._async_tokenizer: AsyncMicrobatchTokenizer | None = None self._async_tokenizer: AsyncMicrobatchTokenizer | None = None
@property self.mm_processor: BaseMultiModalProcessor | None = None
@abstractmethod self._mm_cache_stats: MultiModalCacheStats | None = None
def tokenizer(self) -> TokenizerLike | None: if config.model_config.is_multimodal_model:
raise NotImplementedError from vllm.multimodal import MULTIMODAL_REGISTRY as mm_registry
mm_processor_cache = mm_registry.processor_cache_from_config(config)
with set_default_torch_num_threads():
self.mm_processor = mm_registry.create_processor(
config.model_config,
config.observability_config,
tokenizer=tokenizer,
cache=mm_processor_cache,
)
if mm_processor_cache:
self._mm_cache_stats = MultiModalCacheStats()
def get_tokenizer(self) -> TokenizerLike: def get_tokenizer(self) -> _T:
tokenizer = self.tokenizer tokenizer = self.tokenizer
if tokenizer is None: if tokenizer is None:
raise ValueError("Tokenizer not available when `skip_tokenizer_init=True`") raise ValueError("Tokenizer not available when `skip_tokenizer_init=True`")
...@@ -66,6 +92,49 @@ class BaseRenderer(ABC): ...@@ -66,6 +92,49 @@ class BaseRenderer(ABC):
return self._async_tokenizer return self._async_tokenizer
def get_mm_processor(self) -> "BaseMultiModalProcessor":
    """Return the multi-modal processor.

    Raises:
        ValueError: If `self.mm_processor` is `None`.
    """
    mm_processor = self.mm_processor
    if mm_processor is None:
        raise ValueError("Multi-modal processor not available for text-only models")

    return mm_processor
@property
def mm_processor_cache(self) -> "BaseMultiModalProcessorCache | None":
    """The `cache` attribute of the multi-modal processor, if there is one.

    Returns `None` when `self.mm_processor` is `None`.
    """
    mm_processor = self.mm_processor
    if mm_processor is None:
        return None

    return mm_processor.cache
def stat_mm_cache(self) -> MultiModalCacheStats | None:
    """Hand over the accumulated multi-modal cache stats and start afresh.

    Returns `None` when no stats object is being tracked. Otherwise the
    current stats object is returned and `self._mm_cache_stats` is replaced
    with a new `MultiModalCacheStats()`.
    """
    collected_stats = self._mm_cache_stats
    if collected_stats is None:
        return None

    self._mm_cache_stats = MultiModalCacheStats()

    return collected_stats
def update_mm_cache_stats(self) -> None:
    """Record the processor cache's latest stats delta into the tracked stats.

    Does nothing unless both the processor cache and the stats object are
    present (truthy).
    """
    mm_processor_cache = self.mm_processor_cache
    mm_cache_stats = self._mm_cache_stats

    # NOTE(review): this is a truthiness check, unlike the `is not None`
    # checks used elsewhere; confirm neither object can be falsy while present.
    if mm_processor_cache and mm_cache_stats:
        delta = mm_processor_cache.make_stats(delta=True)
        mm_cache_stats.record(delta.total, delta.hits)
def clear_mm_cache(self) -> None:
    """Clear the multi-modal processor cache and flag the stats for reset.

    Calls `clear_cache()` on the processor cache when one exists, and sets
    `reset = True` on the tracked stats object when one exists.
    """
    mm_processor_cache = self.mm_processor_cache
    if mm_processor_cache is not None:
        mm_processor_cache.clear_cache()

    # NOTE(review): reconstructed as a sibling of the check above, not
    # nested under it; the extracted text had lost its indentation.
    if self._mm_cache_stats is not None:
        self._mm_cache_stats.reset = True
def shutdown(self) -> None:
    """Close the multi-modal processor cache, if one exists."""
    mm_processor_cache = self.mm_processor_cache
    if mm_processor_cache is not None:
        mm_processor_cache.close()
def get_bos_token_id(self) -> int | None: def get_bos_token_id(self) -> int | None:
if self.tokenizer is None: if self.tokenizer is None:
logger.warning_once( logger.warning_once(
...@@ -84,6 +153,36 @@ class BaseRenderer(ABC): ...@@ -84,6 +153,36 @@ class BaseRenderer(ABC):
return self.tokenizer.eos_token_id return self.tokenizer.eos_token_id
@cached_property
def default_cmpl_tok_params(self) -> TokenizeParams:
    """Default tokenization params for completion-style prompts.

    When a multi-modal processor is present, its
    `info.default_tok_params` is used. Otherwise params are built from the
    model config with `add_special_tokens=True`. The result is cached per
    instance.
    """
    mm_processor = self.mm_processor
    if mm_processor is not None:
        return mm_processor.info.default_tok_params

    model_config = self.model_config
    # `encoder_config` may be falsy (e.g. `None`); treat that as no settings.
    encoder_config = model_config.encoder_config or {}

    return TokenizeParams(
        max_total_tokens=model_config.max_model_len,
        do_lower_case=encoder_config.get("do_lower_case", False),
        add_special_tokens=True,
    )
@cached_property
def default_chat_tok_params(self) -> TokenizeParams:
    """Default tokenization params for chat-style prompts.

    When a multi-modal processor is present, its
    `info.default_tok_params` is used. Otherwise params are built from the
    model config with `add_special_tokens=False`. The result is cached per
    instance.
    """
    mm_processor = self.mm_processor
    if mm_processor is not None:
        # NOTE(review): this path returns the processor's defaults as-is,
        # so `add_special_tokens=False` is not forced here; confirm intended.
        return mm_processor.info.default_tok_params

    model_config = self.model_config
    # `encoder_config` may be falsy (e.g. `None`); treat that as no settings.
    encoder_config = model_config.encoder_config or {}

    return TokenizeParams(
        max_total_tokens=model_config.max_model_len,
        do_lower_case=encoder_config.get("do_lower_case", False),
        add_special_tokens=False,
    )
# Step 1: Convert raw inputs to prompts # Step 1: Convert raw inputs to prompts
def render_prompt( def render_prompt(
self, self,
...@@ -317,18 +416,14 @@ class BaseRenderer(ABC): ...@@ -317,18 +416,14 @@ class BaseRenderer(ABC):
def render_cmpl( def render_cmpl(
self, self,
prompts: Sequence[DictPrompt | bytes], prompts: Sequence[DictPrompt | bytes],
tok_params: TokenizeParams, tok_params: TokenizeParams | None = None,
*, *,
prompt_extras: dict[str, Any] | None = None, prompt_extras: dict[str, Any] | None = None,
): ):
dict_prompts = self.render_prompts(prompts) if tok_params is None:
tok_params = self.default_cmpl_tok_params
# NOTE: Some MM models have non-default `add_special_tokens`
# so we handle tokenization in multi-modal processor
if self.model_config.is_multimodal_model:
self._apply_prompt_extras(dict_prompts, prompt_extras)
return dict_prompts
dict_prompts = self.render_prompts(prompts)
tok_prompts = self.tokenize_prompts(dict_prompts, tok_params) tok_prompts = self.tokenize_prompts(dict_prompts, tok_params)
self._apply_prompt_extras(tok_prompts, prompt_extras) self._apply_prompt_extras(tok_prompts, prompt_extras)
...@@ -339,14 +434,14 @@ class BaseRenderer(ABC): ...@@ -339,14 +434,14 @@ class BaseRenderer(ABC):
async def render_cmpl_async( async def render_cmpl_async(
self, self,
prompts: Sequence[DictPrompt | bytes], prompts: Sequence[DictPrompt | bytes],
tok_params: TokenizeParams, tok_params: TokenizeParams | None = None,
*, *,
prompt_extras: dict[str, Any] | None = None, prompt_extras: dict[str, Any] | None = None,
): ):
dict_prompts = await self.render_prompts_async(prompts) if tok_params is None:
tok_params = self.default_cmpl_tok_params
# NOTE: MM data cannot be passed to online Completions API dict_prompts = await self.render_prompts_async(prompts)
# so we don't have the special case that is in the offline version
tok_prompts = await self.tokenize_prompts_async(dict_prompts, tok_params) tok_prompts = await self.tokenize_prompts_async(dict_prompts, tok_params)
self._apply_prompt_extras(tok_prompts, prompt_extras) self._apply_prompt_extras(tok_prompts, prompt_extras)
...@@ -358,10 +453,13 @@ class BaseRenderer(ABC): ...@@ -358,10 +453,13 @@ class BaseRenderer(ABC):
self, self,
conversations: Sequence[list["ChatCompletionMessageParam"]], conversations: Sequence[list["ChatCompletionMessageParam"]],
chat_params: ChatParams, chat_params: ChatParams,
tok_params: TokenizeParams, tok_params: TokenizeParams | None = None,
*, *,
prompt_extras: dict[str, Any] | None = None, prompt_extras: dict[str, Any] | None = None,
): ):
if tok_params is None:
tok_params = self.default_chat_tok_params
rendered = [ rendered = [
self.render_messages(conversation, chat_params) self.render_messages(conversation, chat_params)
for conversation in conversations for conversation in conversations
...@@ -384,10 +482,13 @@ class BaseRenderer(ABC): ...@@ -384,10 +482,13 @@ class BaseRenderer(ABC):
self, self,
conversations: Sequence[list["ChatCompletionMessageParam"]], conversations: Sequence[list["ChatCompletionMessageParam"]],
chat_params: ChatParams, chat_params: ChatParams,
tok_params: TokenizeParams, tok_params: TokenizeParams | None = None,
*, *,
prompt_extras: dict[str, Any] | None = None, prompt_extras: dict[str, Any] | None = None,
): ):
if tok_params is None:
tok_params = self.default_chat_tok_params
rendered = [ rendered = [
self.render_messages_async(conversation, chat_params) self.render_messages_async(conversation, chat_params)
for conversation in conversations for conversation in conversations
......
...@@ -13,7 +13,6 @@ from vllm.logger import init_logger ...@@ -13,7 +13,6 @@ from vllm.logger import init_logger
from vllm.tokenizers import cached_get_tokenizer from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.deepseek_v32 import DeepseekV32Tokenizer from vllm.tokenizers.deepseek_v32 import DeepseekV32Tokenizer
from ..tokenizers.hf import HfTokenizer
from .base import BaseRenderer from .base import BaseRenderer
from .inputs import DictPrompt from .inputs import DictPrompt
from .inputs.preprocess import parse_dec_only_prompt from .inputs.preprocess import parse_dec_only_prompt
...@@ -22,23 +21,14 @@ from .params import ChatParams ...@@ -22,23 +21,14 @@ from .params import ChatParams
logger = init_logger(__name__) logger = init_logger(__name__)
class DeepseekV32Renderer(BaseRenderer): class DeepseekV32Renderer(BaseRenderer[DeepseekV32Tokenizer]):
@classmethod @classmethod
def from_config( def from_config( # type: ignore[override]
cls, cls,
config: VllmConfig, config: VllmConfig,
tokenizer_kwargs: dict[str, Any], tokenizer_kwargs: dict[str, Any],
) -> "BaseRenderer": ) -> "DeepseekV32Renderer":
return cls(config, tokenizer_kwargs) model_config = config.model_config
def __init__(
self,
config: VllmConfig,
tokenizer_kwargs: dict[str, Any],
) -> None:
super().__init__(config)
model_config = self.model_config
if model_config.skip_tokenizer_init: if model_config.skip_tokenizer_init:
tokenizer = None tokenizer = None
else: else:
...@@ -47,18 +37,7 @@ class DeepseekV32Renderer(BaseRenderer): ...@@ -47,18 +37,7 @@ class DeepseekV32Renderer(BaseRenderer):
**tokenizer_kwargs, **tokenizer_kwargs,
) )
self._tokenizer = tokenizer return cls(config, tokenizer)
@property
def tokenizer(self) -> HfTokenizer | None:
return self._tokenizer
def get_tokenizer(self) -> HfTokenizer:
tokenizer = self.tokenizer
if tokenizer is None:
raise ValueError("Tokenizer not available when `skip_tokenizer_init=True`")
return tokenizer
def render_messages( def render_messages(
self, self,
......
...@@ -21,23 +21,14 @@ from .params import ChatParams ...@@ -21,23 +21,14 @@ from .params import ChatParams
logger = init_logger(__name__) logger = init_logger(__name__)
class Grok2Renderer(BaseRenderer): class Grok2Renderer(BaseRenderer[Grok2Tokenizer]):
@classmethod @classmethod
def from_config( def from_config( # type: ignore[override]
cls, cls,
config: VllmConfig, config: VllmConfig,
tokenizer_kwargs: dict[str, Any], tokenizer_kwargs: dict[str, Any],
) -> "BaseRenderer": ) -> "Grok2Renderer":
return cls(config, tokenizer_kwargs) model_config = config.model_config
def __init__(
self,
config: VllmConfig,
tokenizer_kwargs: dict[str, Any],
) -> None:
super().__init__(config)
model_config = self.model_config
if model_config.skip_tokenizer_init: if model_config.skip_tokenizer_init:
tokenizer = None tokenizer = None
else: else:
...@@ -46,18 +37,7 @@ class Grok2Renderer(BaseRenderer): ...@@ -46,18 +37,7 @@ class Grok2Renderer(BaseRenderer):
**tokenizer_kwargs, **tokenizer_kwargs,
) )
self._tokenizer = tokenizer return cls(config, tokenizer)
@property
def tokenizer(self) -> Grok2Tokenizer | None:
return self._tokenizer
def get_tokenizer(self) -> Grok2Tokenizer:
tokenizer = self.tokenizer
if tokenizer is None:
raise ValueError("Tokenizer not available when `skip_tokenizer_init=True`")
return tokenizer
def render_messages( def render_messages(
self, self,
......
...@@ -585,27 +585,14 @@ def replace_vision_chunk_video_placeholder( ...@@ -585,27 +585,14 @@ def replace_vision_chunk_video_placeholder(
return prompt_raw return prompt_raw
class HfRenderer(BaseRenderer): class HfRenderer(BaseRenderer[HfTokenizer]):
@classmethod @classmethod
def from_config( def from_config( # type: ignore[override]
cls, cls,
config: VllmConfig, config: VllmConfig,
tokenizer_kwargs: dict[str, Any], tokenizer_kwargs: dict[str, Any],
) -> "BaseRenderer": ) -> "HfRenderer":
return cls(config, tokenizer_kwargs) model_config = config.model_config
def __init__(
self,
config: VllmConfig,
tokenizer_kwargs: dict[str, Any],
) -> None:
super().__init__(config)
model_config = self.model_config
self.use_unified_vision_chunk = getattr(
model_config.hf_config, "use_unified_vision_chunk", False
)
if model_config.skip_tokenizer_init: if model_config.skip_tokenizer_init:
tokenizer = None tokenizer = None
else: else:
...@@ -617,18 +604,18 @@ class HfRenderer(BaseRenderer): ...@@ -617,18 +604,18 @@ class HfRenderer(BaseRenderer):
), ),
) )
self._tokenizer = tokenizer return cls(config, tokenizer)
@property
def tokenizer(self) -> HfTokenizer | None:
return self._tokenizer
def get_tokenizer(self) -> HfTokenizer: def __init__(
tokenizer = self.tokenizer self,
if tokenizer is None: config: VllmConfig,
raise ValueError("Tokenizer not available when `skip_tokenizer_init=True`") tokenizer: HfTokenizer | None,
) -> None:
super().__init__(config, tokenizer)
return tokenizer self.use_unified_vision_chunk = getattr(
config.model_config.hf_config, "use_unified_vision_chunk", False
)
def render_messages( def render_messages(
self, self,
......
...@@ -50,23 +50,14 @@ def safe_apply_chat_template( ...@@ -50,23 +50,14 @@ def safe_apply_chat_template(
raise ValueError(str(e)) from e raise ValueError(str(e)) from e
class MistralRenderer(BaseRenderer): class MistralRenderer(BaseRenderer[MistralTokenizer]):
@classmethod @classmethod
def from_config( def from_config( # type: ignore[override]
cls, cls,
config: VllmConfig, config: VllmConfig,
tokenizer_kwargs: dict[str, Any], tokenizer_kwargs: dict[str, Any],
) -> "BaseRenderer": ) -> "MistralRenderer":
return cls(config, tokenizer_kwargs) model_config = config.model_config
def __init__(
self,
config: VllmConfig,
tokenizer_kwargs: dict[str, Any],
) -> None:
super().__init__(config)
model_config = self.model_config
if model_config.skip_tokenizer_init: if model_config.skip_tokenizer_init:
tokenizer = None tokenizer = None
else: else:
...@@ -75,24 +66,20 @@ class MistralRenderer(BaseRenderer): ...@@ -75,24 +66,20 @@ class MistralRenderer(BaseRenderer):
**tokenizer_kwargs, **tokenizer_kwargs,
) )
self._tokenizer = tokenizer return cls(config, tokenizer)
def __init__(
self,
config: VllmConfig,
tokenizer: MistralTokenizer | None,
) -> None:
super().__init__(config, tokenizer)
self._apply_chat_template_executor = ThreadPoolExecutor(max_workers=1) self._apply_chat_template_executor = ThreadPoolExecutor(max_workers=1)
self._apply_chat_template_async = make_async( self._apply_chat_template_async = make_async(
safe_apply_chat_template, executor=self._apply_chat_template_executor safe_apply_chat_template, executor=self._apply_chat_template_executor
) )
@property
def tokenizer(self) -> MistralTokenizer | None:
return self._tokenizer
def get_tokenizer(self) -> MistralTokenizer:
tokenizer = self.tokenizer
if tokenizer is None:
raise ValueError("Tokenizer not available when `skip_tokenizer_init=True`")
return tokenizer
def render_messages( def render_messages(
self, self,
messages: list[ChatCompletionMessageParam], messages: list[ChatCompletionMessageParam],
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, TypeVar from typing import TYPE_CHECKING, Any, TypeVar
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -12,9 +11,13 @@ from vllm.utils.import_utils import LazyLoader ...@@ -12,9 +11,13 @@ from vllm.utils.import_utils import LazyLoader
if TYPE_CHECKING: if TYPE_CHECKING:
import torch import torch
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
else: else:
torch = LazyLoader("torch", globals(), "torch") torch = LazyLoader("torch", globals(), "torch")
ChatTemplateContentFormatOption = object
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -43,7 +46,7 @@ class ChatParams: ...@@ -43,7 +46,7 @@ class ChatParams:
chat_template: str | None = None chat_template: str | None = None
"""The chat template to apply.""" """The chat template to apply."""
chat_template_content_format: ChatTemplateContentFormatOption = "auto" chat_template_content_format: "ChatTemplateContentFormatOption" = "auto"
"""The format of the chat template.""" """The format of the chat template."""
chat_template_kwargs: dict[str, Any] = field(default_factory=dict) chat_template_kwargs: dict[str, Any] = field(default_factory=dict)
...@@ -163,10 +166,7 @@ class TokenizeParams: ...@@ -163,10 +166,7 @@ class TokenizeParams:
value=truncate_prompt_tokens, value=truncate_prompt_tokens,
) )
def with_kwargs(self, tokenization_kwargs: dict[str, Any] | None): def with_kwargs(self, **tokenization_kwargs: Any):
if tokenization_kwargs is None:
tokenization_kwargs = {}
max_length = tokenization_kwargs.pop("max_length", self.max_input_tokens) max_length = tokenization_kwargs.pop("max_length", self.max_input_tokens)
pad_prompt_tokens = tokenization_kwargs.pop( pad_prompt_tokens = tokenization_kwargs.pop(
"pad_prompt_tokens", self.pad_prompt_tokens "pad_prompt_tokens", self.pad_prompt_tokens
......
...@@ -10,7 +10,6 @@ from vllm.entrypoints.chat_utils import ( ...@@ -10,7 +10,6 @@ from vllm.entrypoints.chat_utils import (
parse_chat_messages_async, parse_chat_messages_async,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from .base import BaseRenderer from .base import BaseRenderer
from .inputs import DictPrompt from .inputs import DictPrompt
...@@ -24,24 +23,14 @@ class TerratorchRenderer(BaseRenderer): ...@@ -24,24 +23,14 @@ class TerratorchRenderer(BaseRenderer):
@classmethod @classmethod
def from_config( def from_config(
cls, cls,
config: VllmConfig, config: VllmConfig, # type: ignore[override]
tokenizer_kwargs: dict[str, Any], tokenizer_kwargs: dict[str, Any],
) -> "BaseRenderer": ) -> "TerratorchRenderer":
return cls(config) model_config = config.model_config
def __init__(self, config: VllmConfig) -> None:
super().__init__(config)
model_config = self.model_config
if not model_config.skip_tokenizer_init: if not model_config.skip_tokenizer_init:
raise ValueError("Terratorch renderer requires `skip_tokenizer_init=True`") raise ValueError("Terratorch renderer requires `skip_tokenizer_init=True`")
@property return cls(config, None)
def tokenizer(self) -> TokenizerLike | None:
return None
def get_tokenizer(self) -> TokenizerLike:
raise ValueError("Tokenizer not available for Terratorch renderer")
def render_messages( def render_messages(
self, self,
......
...@@ -19,8 +19,8 @@ from vllm.distributed.weight_transfer.base import ( ...@@ -19,8 +19,8 @@ from vllm.distributed.weight_transfer.base import (
WeightTransferUpdateRequest, WeightTransferUpdateRequest,
) )
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient, StreamingInput
from vllm.inputs import PromptType, StreamingInput from vllm.inputs import PromptType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
...@@ -268,12 +268,12 @@ class AsyncLLM(EngineClient): ...@@ -268,12 +268,12 @@ class AsyncLLM(EngineClient):
shutdown_prometheus() shutdown_prometheus()
if renderer := getattr(self, "renderer", None):
renderer.shutdown()
if engine_core := getattr(self, "engine_core", None): if engine_core := getattr(self, "engine_core", None):
engine_core.shutdown() engine_core.shutdown()
if input_processor := getattr(self, "input_processor", None):
input_processor.close()
handler = getattr(self, "output_handler", None) handler = getattr(self, "output_handler", None)
if handler is not None: if handler is not None:
cancel_task_threadsafe(handler) cancel_task_threadsafe(handler)
...@@ -654,7 +654,7 @@ class AsyncLLM(EngineClient): ...@@ -654,7 +654,7 @@ class AsyncLLM(EngineClient):
output_processor = self.output_processor output_processor = self.output_processor
log_stats = self.log_stats log_stats = self.log_stats
logger_manager = self.logger_manager logger_manager = self.logger_manager
input_processor = self.input_processor renderer = self.renderer
chunk_size = envs.VLLM_V1_OUTPUT_PROC_CHUNK_SIZE chunk_size = envs.VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
async def output_handler(): async def output_handler():
...@@ -702,7 +702,7 @@ class AsyncLLM(EngineClient): ...@@ -702,7 +702,7 @@ class AsyncLLM(EngineClient):
engine_idx=outputs.engine_index, engine_idx=outputs.engine_index,
scheduler_stats=outputs.scheduler_stats, scheduler_stats=outputs.scheduler_stats,
iteration_stats=iteration_stats, iteration_stats=iteration_stats,
mm_cache_stats=input_processor.stat_mm_cache(), mm_cache_stats=renderer.stat_mm_cache(),
) )
except Exception as e: except Exception as e:
logger.exception("AsyncLLM output_handler failed.") logger.exception("AsyncLLM output_handler failed.")
...@@ -881,7 +881,7 @@ class AsyncLLM(EngineClient): ...@@ -881,7 +881,7 @@ class AsyncLLM(EngineClient):
await asyncio.gather(*coros) await asyncio.gather(*coros)
async def reset_mm_cache(self) -> None: async def reset_mm_cache(self) -> None:
self.input_processor.clear_mm_cache() self.renderer.clear_mm_cache()
await self.engine_core.reset_mm_cache_async() await self.engine_core.reset_mm_cache_async()
async def reset_prefix_cache( async def reset_prefix_cache(
......
...@@ -33,9 +33,9 @@ from vllm.sampling_params import SamplingParams ...@@ -33,9 +33,9 @@ from vllm.sampling_params import SamplingParams
from vllm.tasks import POOLING_TASKS, SupportedTask from vllm.tasks import POOLING_TASKS, SupportedTask
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.utils import length_from_prompt_token_ids_or_embeds, random_uuid from vllm.utils import length_from_prompt_token_ids_or_embeds, random_uuid
from vllm.utils.jsontree import json_iter_leaves
from vllm.utils.torch_utils import set_default_torch_num_threads from vllm.utils.torch_utils import set_default_torch_num_threads
from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine import EngineCoreRequest
from vllm.v1.metrics.stats import MultiModalCacheStats
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -60,8 +60,6 @@ class InputProcessor: ...@@ -60,8 +60,6 @@ class InputProcessor:
self.generation_config_fields = model_config.try_get_generation_config() self.generation_config_fields = model_config.try_get_generation_config()
self.renderer = renderer or renderer_from_config(vllm_config) self.renderer = renderer or renderer_from_config(vllm_config)
self.mm_registry = mm_registry
self.mm_processor_cache = mm_registry.processor_cache_from_config(vllm_config)
self.supports_mm_inputs = mm_registry.supports_multimodal_inputs(model_config) self.supports_mm_inputs = mm_registry.supports_multimodal_inputs(model_config)
self.mm_encoder_cache_size = 0 self.mm_encoder_cache_size = 0
...@@ -78,7 +76,6 @@ class InputProcessor: ...@@ -78,7 +76,6 @@ class InputProcessor:
vllm_config, vllm_config,
renderer=renderer, renderer=renderer,
mm_registry=mm_registry, mm_registry=mm_registry,
mm_processor_cache=self.mm_processor_cache,
) )
@property @property
...@@ -136,7 +133,7 @@ class InputProcessor: ...@@ -136,7 +133,7 @@ class InputProcessor:
) )
def _parse_mm_items(self, mm_data: MultiModalDataDict) -> MultiModalDataItems: def _parse_mm_items(self, mm_data: MultiModalDataDict) -> MultiModalDataItems:
mm_processor = self.input_preprocessor._get_mm_processor() mm_processor = self.renderer.get_mm_processor()
return mm_processor.info.parse_mm_data(mm_data) return mm_processor.info.parse_mm_data(mm_data)
def _validate_singleton_mm_uuids(self, prompt: SingletonPrompt) -> None: def _validate_singleton_mm_uuids(self, prompt: SingletonPrompt) -> None:
...@@ -415,6 +412,15 @@ class InputProcessor: ...@@ -415,6 +412,15 @@ class InputProcessor:
decoder_mm_positions = decoder_inputs["mm_placeholders"] decoder_mm_positions = decoder_inputs["mm_placeholders"]
decoder_mm_hashes = decoder_inputs["mm_hashes"] decoder_mm_hashes = decoder_inputs["mm_hashes"]
if not all(
isinstance(leaf, str) for leaf in json_iter_leaves(decoder_mm_hashes)
):
raise ValueError(
f"mm_hashes must contain only strings, got: {decoder_mm_hashes}. "
"This is likely due to an incorrect custom implementation of "
"MultiModalProcessor.apply method."
)
# Merge and flatten multimodal placeholders, hashes and inputs # Merge and flatten multimodal placeholders, hashes and inputs
# from dictionaries to lists, and sort them by each item's position # from dictionaries to lists, and sort them by each item's position
# in the input sequence. # in the input sequence.
...@@ -562,13 +568,3 @@ class InputProcessor: ...@@ -562,13 +568,3 @@ class InputProcessor:
self._validate_model_input(encoder_inputs, prompt_type="encoder") self._validate_model_input(encoder_inputs, prompt_type="encoder")
self._validate_model_input(decoder_inputs, prompt_type="decoder") self._validate_model_input(decoder_inputs, prompt_type="decoder")
def stat_mm_cache(self) -> MultiModalCacheStats | None:
return self.input_preprocessor.stat_mm_cache()
def clear_mm_cache(self) -> None:
self.input_preprocessor.clear_mm_cache()
def close(self) -> None:
if self.mm_processor_cache is not None:
self.mm_processor_cache.close()
...@@ -320,7 +320,7 @@ class LLMEngine: ...@@ -320,7 +320,7 @@ class LLMEngine:
self.logger_manager.record( self.logger_manager.record(
scheduler_stats=outputs.scheduler_stats, scheduler_stats=outputs.scheduler_stats,
iteration_stats=iteration_stats, iteration_stats=iteration_stats,
mm_cache_stats=self.input_processor.stat_mm_cache(), mm_cache_stats=self.renderer.stat_mm_cache(),
) )
self.do_log_stats_with_interval() self.do_log_stats_with_interval()
...@@ -333,7 +333,7 @@ class LLMEngine: ...@@ -333,7 +333,7 @@ class LLMEngine:
self.engine_core.profile(False) self.engine_core.profile(False)
def reset_mm_cache(self): def reset_mm_cache(self):
self.input_processor.clear_mm_cache() self.renderer.clear_mm_cache()
self.engine_core.reset_mm_cache() self.engine_core.reset_mm_cache()
def reset_prefix_cache( def reset_prefix_cache(
......
...@@ -151,6 +151,12 @@ class MultiModalCacheStats(BaseCacheStats): ...@@ -151,6 +151,12 @@ class MultiModalCacheStats(BaseCacheStats):
that were queried. that were queried.
""" """
def record(self, num_queries: int, num_hits: int) -> None:
"""Aggregate request information into the stats."""
self.requests += 1
self.queries += num_queries
self.hits += num_hits
@dataclass @dataclass
class KVCacheEvictionEvent: class KVCacheEvictionEvent:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment