Unverified commit 39264545, authored by Cyrus Leung, committed by GitHub
Browse files

[Refactor] Decouple TimingContext from InputProcessingContext (#35083)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 1e8438a8
......@@ -41,15 +41,16 @@ from vllm.multimodal.parse import (
ImageProcessorItems,
ImageSize,
MultiModalDataItems,
MultiModalUUIDItems,
)
from vllm.multimodal.processing import (
BaseDummyInputsBuilder,
BaseMultiModalProcessor,
BaseProcessingInfo,
ProcessorInputs,
PromptIndexTargets,
PromptReplacement,
PromptUpdate,
TimingContext,
)
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
......@@ -204,23 +205,20 @@ class CLIPMultiModalProcessor(BaseMultiModalProcessor[CLIPProcessingInfo]):
def apply(
self,
prompt: str | list[int],
mm_items: MultiModalDataItems,
mm_uuid_items: MultiModalUUIDItems | None = None,
hf_processor_mm_kwargs: Mapping[str, object] | None = None,
tokenization_kwargs: Mapping[str, object] | None = None,
inputs: ProcessorInputs,
timing_ctx: TimingContext,
) -> MultiModalInputs:
if mm_items:
if isinstance(prompt, str):
if len(prompt) > 0:
if inputs.mm_data_items:
if isinstance(inputs.prompt, str):
if len(inputs.prompt) > 0:
raise ValueError(
"CLIP accepts text-only or image-only inputs, not both! "
"You must pass an image with an empty text prompt."
)
else:
special_tokens = self.info.get_tokenizer().all_special_ids
if all(tok in special_tokens for tok in prompt):
prompt = []
if all(tok in special_tokens for tok in inputs.prompt):
inputs.prompt = []
else:
raise ValueError(
"CLIP accepts text-only or image-only inputs, not both! "
......@@ -229,18 +227,12 @@ class CLIPMultiModalProcessor(BaseMultiModalProcessor[CLIPProcessingInfo]):
# For multi-modal data, the prompt after processing should
# only contain the dummy image tokens
tokenization_kwargs = {
**(tokenization_kwargs or {}),
inputs.tokenization_kwargs = {
**inputs.tokenization_kwargs,
"add_special_tokens": False,
}
return super().apply(
prompt=prompt,
mm_items=mm_items,
mm_uuid_items=mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
return super().apply(inputs, timing_ctx)
def _hf_processor_applies_updates(
self,
......
......@@ -30,15 +30,16 @@ from vllm.multimodal.parse import (
ImageProcessorItems,
ImageSize,
MultiModalDataItems,
MultiModalUUIDItems,
)
from vllm.multimodal.processing import BaseDummyInputsBuilder
from vllm.multimodal.processing.processor import (
BaseMultiModalProcessor,
BaseProcessingInfo,
MultiModalProcessingInfo,
ProcessorInputs,
PromptReplacement,
PromptUpdate,
TimingContext,
)
from vllm.sequence import IntermediateTensors
from vllm.tokenizers import cached_tokenizer_from_config
......@@ -310,32 +311,17 @@ class DeepseekVL2MultiModalProcessor(
def _cached_apply_hf_processor(
self,
prompt: str | list[int],
mm_data_items: MultiModalDataItems,
mm_uuid_items: MultiModalUUIDItems | None,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
inputs: ProcessorInputs,
timing_ctx: TimingContext,
) -> tuple[list[int], MultiModalProcessingInfo, bool]:
# The processor logic is different for len(images) <= 2 vs > 2
# Since the processing cache assumes that the processor output is
# invariant of how many images are passed per prompt, we only
# perform caching for the most common case
if mm_data_items.get_count("image", strict=False) > 2:
return self._apply_hf_processor(
prompt=prompt,
mm_data_items=mm_data_items,
mm_uuid_items=mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
if inputs.mm_data_items.get_count("image", strict=False) > 2:
return self._apply_hf_processor(inputs, timing_ctx)
return super()._cached_apply_hf_processor(
prompt=prompt,
mm_data_items=mm_data_items,
mm_uuid_items=mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
return super()._cached_apply_hf_processor(inputs, timing_ctx)
@MULTIMODAL_REGISTRY.register_processor(
......
......@@ -21,13 +21,14 @@ from vllm.multimodal.parse import (
ImageEmbeddingItems,
ImageProcessorItems,
MultiModalDataItems,
MultiModalUUIDItems,
)
from vllm.multimodal.processing.processor import (
MultiModalProcessingInfo,
ProcessorInputs,
PromptReplacement,
PromptUpdate,
PromptUpdateDetails,
TimingContext,
)
from vllm.tokenizers import TokenizerLike
......@@ -490,32 +491,17 @@ class H2OVLMultiModalProcessor(BaseInternVLMultiModalProcessor[H2OVLProcessingIn
def _cached_apply_hf_processor(
self,
prompt: str | list[int],
mm_data_items: MultiModalDataItems,
mm_uuid_items: MultiModalUUIDItems | None,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
inputs: ProcessorInputs,
timing_ctx: TimingContext,
) -> tuple[list[int], MultiModalProcessingInfo, bool]:
# The processor logic is different for len(images) <= 1 vs > 1
# Since the processing cache assumes that the processor output is
# invariant of how many images are passed per prompt, we only
# perform caching for the most common case
if mm_data_items.get_count("image", strict=False) > 1:
return self._apply_hf_processor(
prompt=prompt,
mm_data_items=mm_data_items,
mm_uuid_items=mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
if inputs.mm_data_items.get_count("image", strict=False) > 1:
return self._apply_hf_processor(inputs, timing_ctx)
return super()._cached_apply_hf_processor(
prompt=prompt,
mm_data_items=mm_data_items,
mm_uuid_items=mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
return super()._cached_apply_hf_processor(inputs, timing_ctx)
@MULTIMODAL_REGISTRY.register_processor(
......
......@@ -37,16 +37,17 @@ from vllm.multimodal.parse import (
ImageProcessorItems,
ImageSize,
MultiModalDataItems,
MultiModalUUIDItems,
)
from vllm.multimodal.processing import (
BaseDummyInputsBuilder,
BaseMultiModalProcessor,
BaseProcessingInfo,
InputProcessingContext,
ProcessorInputs,
PromptReplacement,
PromptUpdate,
PromptUpdateDetails,
TimingContext,
)
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
......@@ -770,11 +771,8 @@ class MantisProcessingInfo(LlavaProcessingInfo):
class MantisMultiModalProcessor(LlavaMultiModalProcessor):
def apply(
self,
prompt: str | list[int],
mm_items: MultiModalDataItems,
mm_uuid_items: MultiModalUUIDItems | None = None,
hf_processor_mm_kwargs: Mapping[str, object] | None = None,
tokenization_kwargs: Mapping[str, object] | None = None,
inputs: ProcessorInputs,
timing_ctx: TimingContext,
) -> MultiModalInputs:
hf_config = self.info.get_hf_config()
image_token_id = hf_config.image_token_index
......@@ -785,15 +783,9 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
image_height=-1,
)
result = super().apply(
prompt,
mm_items,
mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
result = super().apply(inputs, timing_ctx)
mm_item_counts = mm_items.get_all_counts()
mm_item_counts = inputs.mm_data_items.get_all_counts()
mm_kwargs = result["mm_kwargs"]
mm_hashes = result["mm_hashes"]
......@@ -825,8 +817,8 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
)
orig_repls = self._get_mm_prompt_updates(
mm_items,
hf_processor_mm_kwargs,
inputs.mm_data_items,
inputs.hf_processor_mm_kwargs,
mm_kwargs,
)
mm_placeholders = self._find_mm_placeholders(prompt_ids, orig_repls)
......
......@@ -21,16 +21,17 @@ from vllm.multimodal.parse import (
ImageEmbeddingItems,
ImageProcessorItems,
MultiModalDataItems,
MultiModalUUIDItems,
)
from vllm.multimodal.processing import (
BaseDummyInputsBuilder,
BaseMultiModalProcessor,
BaseProcessingInfo,
ProcessorInputs,
PromptIndexTargets,
PromptInsertion,
PromptUpdate,
PromptUpdateDetails,
TimingContext,
)
from vllm.renderers import TokenizeParams
from vllm.sequence import IntermediateTensors
......@@ -228,19 +229,10 @@ class PaliGemmaMultiModalProcessor(BaseMultiModalProcessor[PaliGemmaProcessingIn
def apply(
self,
prompt: str | list[int],
mm_items: MultiModalDataItems,
mm_uuid_items: MultiModalUUIDItems | None = None,
hf_processor_mm_kwargs: Mapping[str, object] | None = None,
tokenization_kwargs: Mapping[str, object] | None = None,
inputs: ProcessorInputs,
timing_ctx: TimingContext,
) -> MultiModalInputs:
mm_inputs = super().apply(
prompt,
mm_items,
mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
mm_inputs = super().apply(inputs, timing_ctx)
prompt_token_ids = mm_inputs["prompt_token_ids"]
tokenizer = self.info.get_tokenizer()
......
......@@ -50,16 +50,17 @@ from vllm.multimodal.parse import (
ImageProcessorItems,
ImageSize,
MultiModalDataItems,
MultiModalUUIDItems,
)
from vllm.multimodal.processing import BaseDummyInputsBuilder, ProcessorInputs
from vllm.multimodal.processing import BaseDummyInputsBuilder
from vllm.multimodal.processing.processor import (
BaseMultiModalProcessor,
BaseProcessingInfo,
MultiModalProcessingInfo,
ProcessorInputs,
PromptReplacement,
PromptUpdate,
PromptUpdateDetails,
TimingContext,
)
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
......@@ -277,7 +278,6 @@ class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]):
dummy_text = self.get_dummy_text(mm_counts)
dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options)
dummy_images = dummy_mm_data.get("image", [])
tokenization_kwargs = {"truncation": False}
request = ChatCompletionRequest(
messages=[
......@@ -294,11 +294,7 @@ class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]):
dummy_mm_items = self.info.parse_mm_data(dummy_mm_data)
return ProcessorInputs(
prompt=dummy_tokens,
mm_items=dummy_mm_items,
tokenization_kwargs=tokenization_kwargs,
)
return ProcessorInputs(prompt=dummy_tokens, mm_data_items=dummy_mm_items)
class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]):
......@@ -344,19 +340,10 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo])
def _cached_apply_hf_processor(
self,
prompt: str | list[int],
mm_data_items: MultiModalDataItems,
mm_uuid_items: MultiModalUUIDItems | None,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
inputs: ProcessorInputs,
timing_ctx: TimingContext,
) -> tuple[list[int], MultiModalProcessingInfo, bool]:
prompt_ids, mm_info, _ = super()._cached_apply_hf_processor(
prompt=prompt,
mm_data_items=mm_data_items,
mm_uuid_items=mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
prompt_ids, mm_info, _ = super()._cached_apply_hf_processor(inputs, timing_ctx)
# NOTE: The tokens are already inserted by the chat template
return prompt_ids, mm_info, True
......
......@@ -47,15 +47,16 @@ from vllm.multimodal.parse import (
ImageProcessorItems,
ImageSize,
MultiModalDataItems,
MultiModalUUIDItems,
)
from vllm.multimodal.processing import (
BaseDummyInputsBuilder,
BaseMultiModalProcessor,
BaseProcessingInfo,
ProcessorInputs,
PromptIndexTargets,
PromptReplacement,
PromptUpdate,
TimingContext,
)
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
......@@ -190,23 +191,20 @@ class SiglipMultiModalProcessor(BaseMultiModalProcessor[SiglipProcessingInfo]):
def apply(
self,
prompt: str | list[int],
mm_items: MultiModalDataItems,
mm_uuid_items: MultiModalUUIDItems | None = None,
hf_processor_mm_kwargs: Mapping[str, object] | None = None,
tokenization_kwargs: Mapping[str, object] | None = None,
inputs: ProcessorInputs,
timing_ctx: TimingContext,
) -> MultiModalInputs:
if mm_items:
if isinstance(prompt, str):
if len(prompt) > 0:
if inputs.mm_data_items:
if isinstance(inputs.prompt, str):
if len(inputs.prompt) > 0:
raise ValueError(
"SigLIP accepts text-only or image-only inputs, not both! "
"You must pass an image with an empty text prompt."
)
else:
special_tokens = self.info.get_tokenizer().all_special_ids
if all(tok in special_tokens for tok in prompt):
prompt = []
if all(tok in special_tokens for tok in inputs.prompt):
inputs.prompt = []
else:
raise ValueError(
"SigLIP accepts text-only or image-only inputs, not both! "
......@@ -214,19 +212,13 @@ class SiglipMultiModalProcessor(BaseMultiModalProcessor[SiglipProcessingInfo]):
)
# For multi-modal data, the prompt after processing should
# only contain the image token
tokenization_kwargs = {
**(tokenization_kwargs or {}),
# only contain the dummy image tokens
inputs.tokenization_kwargs = {
**inputs.tokenization_kwargs,
"add_special_tokens": False,
}
return super().apply(
prompt=prompt,
mm_items=mm_items,
mm_uuid_items=mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
return super().apply(inputs, timing_ctx)
def _hf_processor_applies_updates(
self,
......
......@@ -54,13 +54,14 @@ from vllm.multimodal.parse import (
ModalityDataItems,
MultiModalDataItems,
MultiModalDataParser,
MultiModalUUIDItems,
)
from vllm.multimodal.processing import (
BaseDummyInputsBuilder,
BaseMultiModalProcessor,
BaseProcessingInfo,
ProcessorInputs,
PromptUpdate,
TimingContext,
)
from vllm.sequence import IntermediateTensors
......@@ -193,29 +194,21 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor[TerratorchProcessing
def apply(
self,
prompt: str | list[int],
mm_items: MultiModalDataItems,
mm_uuid_items: MultiModalUUIDItems | None = None,
hf_processor_mm_kwargs: Mapping[str, object] | None = None,
tokenization_kwargs: Mapping[str, object] | None = None,
inputs: ProcessorInputs,
timing_ctx: TimingContext,
) -> MultiModalInputs:
if hf_processor_mm_kwargs is None:
hf_processor_mm_kwargs = {}
if tokenization_kwargs is None:
tokenization_kwargs = {}
mm_hashes = self._hash_mm_items(
mm_items,
mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
)
_, passthrough_data = self._get_hf_mm_data(mm_items)
mm_processed_data = BatchFeature(
{k: torch.as_tensor(v).unsqueeze(0) for k, v in passthrough_data.items()},
tensor_type="pt",
)
mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]}
mm_items = inputs.mm_data_items
hf_processor_mm_kwargs = inputs.hf_processor_mm_kwargs
with timing_ctx.record("apply_hf_processor"):
_, passthrough_data = self._get_hf_mm_data(mm_items)
mm_processed_data = BatchFeature(
{
k: torch.as_tensor(v).unsqueeze(0)
for k, v in passthrough_data.items()
},
tensor_type="pt",
)
mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
mm_processed_data,
......@@ -226,6 +219,11 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor[TerratorchProcessing
),
)
with timing_ctx.record("get_mm_hashes"):
mm_hashes = inputs.get_mm_hashes(self.info.model_id)
mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]}
return mm_inputs(
prompt_token_ids=[1],
mm_kwargs=mm_kwargs,
......
......@@ -37,12 +37,13 @@ from vllm.multimodal.inputs import (
from vllm.multimodal.parse import (
ImageProcessorItems,
MultiModalDataItems,
MultiModalUUIDItems,
)
from vllm.multimodal.processing import (
BaseDummyInputsBuilder,
BaseMultiModalProcessor,
BaseProcessingInfo,
ProcessorInputs,
TimingContext,
)
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
......@@ -177,11 +178,8 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
def apply(
self,
prompt: str | list[int],
mm_items: MultiModalDataItems,
mm_uuid_items: MultiModalUUIDItems | None = None,
hf_processor_mm_kwargs: Mapping[str, object] | None = None,
tokenization_kwargs: Mapping[str, object] | None = None,
inputs: ProcessorInputs,
timing_ctx: TimingContext,
) -> MultiModalInputs:
"""
Process multi-modal inputs to be used in vLLM.
......@@ -189,29 +187,30 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
Apply HF Processor on prompt text and multi-modal data together,
outputting token IDs and processed tensors.
"""
if hf_processor_mm_kwargs is None:
hf_processor_mm_kwargs = {}
if tokenization_kwargs is None:
tokenization_kwargs = {}
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
if not isinstance(prompt, str):
# the prompt is the tokenized ids which is not supported
# by the hf_processor, which is why we would need to decode the ids
# into string
prompt = hf_processor.decode(prompt)
# Bypass cached processor and always apply to the full set of mm inputs
# NOTE: we can't just set caching=False because base class method
# transforms outputs to `MultiModalKwargs` which is not going to
# work for Transformers. We have a lot of logic tied to
# `mm_tokens_per_modality` below
prompt_ids, processed_data, _ = self._apply_hf_processor_text_mm(
prompt_text=prompt,
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
prompt = inputs.prompt
mm_items = inputs.mm_data_items
hf_processor_mm_kwargs = inputs.hf_processor_mm_kwargs
tokenization_kwargs = inputs.tokenization_kwargs
with timing_ctx.record("apply_hf_processor"):
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
if not isinstance(prompt, str):
# the prompt is the tokenized ids which is not supported
# by the hf_processor, which is why we would need to decode the ids
# into string
prompt = hf_processor.decode(prompt)
# Bypass cached processor and always apply to the full set of mm inputs
# NOTE: we can't just set caching=False because base class method
# transforms outputs to `MultiModalKwargs` which is not going to
# work for Transformers. We have a lot of logic tied to
# `mm_tokens_per_modality` below
prompt_ids, processed_data, _ = self._apply_hf_processor_text_mm(
prompt_text=prompt,
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
# For gemma3 we check `token_type_ids` as the key
token_type_key = (
......@@ -225,15 +224,14 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
# it for each input `mm_data`.
mm_positions = torch.where(mm_token_type_ids == 1)[1]
images = mm_items.get_items("image", ImageProcessorItems)
multimodal_config = self.info.ctx.model_config.multimodal_config
mm_processor_kwargs = multimodal_config.mm_processor_kwargs or {}
image_sizes = []
for item_idx in range(len(images)):
image_size = images.get_image_size(item_idx)
image_sizes.append((image_size.height, image_size.width))
mm_tokens_per_modality = hf_processor._get_num_multimodal_tokens(
image_sizes=image_sizes, **mm_processor_kwargs
image_sizes=image_sizes,
**self.info.ctx.get_merged_mm_kwargs({}),
)
mm_placeholders = {}
......@@ -261,11 +259,8 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
)
# Use overrides if provided; fallback to data-dependent hashing.
mm_hashes = self._hash_mm_items(
mm_items,
mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
)
with timing_ctx.record("get_mm_hashes"):
mm_hashes = inputs.get_mm_hashes(self.info.model_id)
return mm_inputs(
prompt_token_ids=prompt_ids,
......
......@@ -47,16 +47,17 @@ from vllm.multimodal.parse import (
AudioProcessorItems,
MultiModalDataItems,
MultiModalDataParser,
MultiModalUUIDItems,
)
from vllm.multimodal.processing import BaseDummyInputsBuilder, ProcessorInputs
from vllm.multimodal.processing import BaseDummyInputsBuilder
from vllm.multimodal.processing.processor import (
BaseMultiModalProcessor,
BaseProcessingInfo,
MultiModalProcessingInfo,
PlaceholderFeaturesInfo,
ProcessorInputs,
PromptReplacement,
PromptUpdate,
TimingContext,
)
from vllm.sequence import IntermediateTensors
from vllm.tokenizers import cached_tokenizer_from_config
......@@ -265,13 +266,13 @@ class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]):
res = tokenizer.mistral.encode_chat_completion(request)
dummy_tokens = res.tokens
dummy_mm_inputs = self.info.parse_mm_data(
dummy_mm_items = self.info.parse_mm_data(
# whixtral tokenizer adds padding to the audio
# so we need to update the audio arrays
{**dummy_mm_data, "audio": [a.audio_array for a in res.audios]},
)
return ProcessorInputs(prompt=dummy_tokens, mm_items=dummy_mm_inputs)
return ProcessorInputs(prompt=dummy_tokens, mm_data_items=dummy_mm_items)
class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]):
......@@ -361,19 +362,10 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo])
def _cached_apply_hf_processor(
self,
prompt: str | list[int],
mm_data_items: MultiModalDataItems,
mm_uuid_items: MultiModalUUIDItems | None,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
inputs: ProcessorInputs,
timing_ctx: TimingContext,
) -> tuple[list[int], MultiModalProcessingInfo, bool]:
prompt_ids, mm_info, _ = super()._cached_apply_hf_processor(
prompt=prompt,
mm_data_items=mm_data_items,
mm_uuid_items=mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
prompt_ids, mm_info, _ = super()._cached_apply_hf_processor(inputs, timing_ctx)
# NOTE: The tokens are already inserted by the chat template
return prompt_ids, mm_info, True
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .context import BaseProcessingInfo, InputProcessingContext
from .dummy_inputs import BaseDummyInputsBuilder, ProcessorInputs
from .context import BaseProcessingInfo, InputProcessingContext, TimingContext
from .dummy_inputs import BaseDummyInputsBuilder
from .inputs import ProcessorInputs
from .processor import (
BaseMultiModalProcessor,
EncDecMultiModalProcessor,
......@@ -15,6 +16,7 @@ from .processor import (
__all__ = [
"BaseProcessingInfo",
"InputProcessingContext",
"TimingContext",
"BaseDummyInputsBuilder",
"ProcessorInputs",
"BaseMultiModalProcessor",
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextvars
import threading
import time
from abc import abstractmethod
from collections.abc import Generator, Mapping
from collections.abc import Mapping
from contextlib import contextmanager
from dataclasses import dataclass, field
from functools import cached_property
......@@ -33,104 +31,53 @@ if TYPE_CHECKING:
from transformers.feature_extraction_utils import BatchFeature
from transformers.processing_utils import ProcessorMixin
from vllm.config import ModelConfig, ObservabilityConfig
from vllm.config import ModelConfig
else:
PretrainedConfig = object
BatchFeature = object
ProcessorMixin = object
ModelConfig = object
ObservabilityConfig = object
logger = init_logger(__name__)
_request_id_context: contextvars.ContextVar[str | None] = contextvars.ContextVar(
"_request_id_context", default=None
)
def get_current_request_id() -> str | None:
"""Get the current request_id from the context, if available."""
return _request_id_context.get()
@contextmanager
def set_request_id(request_id: str) -> Generator[None, None, None]:
"""Context manager to set the request_id for the current context."""
token = _request_id_context.set(request_id)
try:
yield
finally:
_request_id_context.reset(token)
@dataclass
class MultiModalProcessorTimingStats:
"""Per-request timing statistics for multimodal processor stages."""
hf_processor_time: float = 0.0
"""Time spent in HuggingFace processor calls (seconds)."""
class TimingContext:
"""Helper class to record execution times during multi-modal processing."""
hashing_time: float = 0.0
"""Time spent computing multimodal item hashes (seconds)."""
enabled: bool = True
"""If disabled, `TimingContext.record` becomes a no-op."""
cache_lookup_time: float = 0.0
"""Time spent in cache lookups and merges (seconds)."""
stage_secs: dict[str, float] = field(default_factory=dict)
"""The execution time (in seconds) for each processing stage."""
prompt_update_time: float = 0.0
"""Time spent applying prompt updates and finding placeholders (seconds)."""
@property
def total_secs(self) -> float:
return sum(self.stage_secs.values())
preprocessor_total_time: float = 0.0
"""Total preprocessing time (seconds)."""
@contextmanager
def record(self, stage: str):
"""Record the execution time for a processing stage."""
if not self.enabled:
yield
return
def to_dict(self) -> dict[str, float]:
"""Convert stats to a dictionary for JSON serialization."""
return {
"hf_processor_time": self.hf_processor_time,
"hashing_time": self.hashing_time,
"cache_lookup_time": self.cache_lookup_time,
"prompt_update_time": self.prompt_update_time,
"preprocessor_total_time": self.preprocessor_total_time,
start_time = time.perf_counter()
try:
yield
finally:
elapsed = time.perf_counter() - start_time
self.stage_secs.setdefault(stage, 0.0)
self.stage_secs[stage] += elapsed
def get_stats_dict(self):
stats_dict = {
f"{stage}_secs": time_s for stage, time_s in self.stage_secs.items()
}
stats_dict["preprocessor_total_secs"] = self.total_secs
@contextmanager
def timed_preprocessor_operation(ctx: "InputProcessingContext", stage_name: str):
"""
Context manager to time an operation using the context's timing stats.
The request_id is automatically retrieved from the context variable,
so it doesn't need to be passed as a parameter.
Args:
ctx: The InputProcessingContext containing the timing stats registry.
stage_name: Name of the stage being timed.
"""
request_id = get_current_request_id()
if ctx is None or request_id is None:
yield
return
stats = ctx.get_timing_stats(request_id)
if stats is None:
yield
return
start_time = time.perf_counter()
try:
yield
finally:
elapsed = time.perf_counter() - start_time
if stage_name == "hf_processor":
stats.hf_processor_time += elapsed
elif stage_name == "hashing":
stats.hashing_time += elapsed
elif stage_name == "cache_lookup":
stats.cache_lookup_time += elapsed
elif stage_name == "prompt_update":
stats.prompt_update_time += elapsed
stats.preprocessor_total_time += elapsed
return stats_dict
_T = TypeVar("_T")
......@@ -151,21 +98,6 @@ class InputProcessingContext:
tokenizer: TokenizerLike | None
"""The tokenizer used to tokenize the inputs."""
observability_config: "ObservabilityConfig | None" = field(
default=None, compare=False, repr=False
)
"""Configuration for observability features."""
timing_stats_registry: dict[str, MultiModalProcessorTimingStats] = field(
default_factory=dict, compare=False, repr=False
)
"""Registry for storing timing stats keyed by request_id."""
_timing_stats_registry_lock: threading.Lock = field(
default_factory=threading.Lock, compare=False, repr=False
)
"""Lock for thread-safe access to timing_stats_registry."""
def get_tokenizer(self) -> TokenizerLike:
if self.tokenizer is None:
raise ValueError(
......@@ -379,71 +311,6 @@ class InputProcessingContext:
return self._postprocess_output(output)
def get_timing_stats(
self, request_id: str
) -> MultiModalProcessorTimingStats | None:
"""
Get timing stats for a request.
"""
if (
self.observability_config is None
or not self.observability_config.enable_mm_processor_stats
):
return None
with self._timing_stats_registry_lock:
return self.timing_stats_registry.get(request_id)
def create_timing_stats(self, request_id: str) -> MultiModalProcessorTimingStats:
"""
Create and store timing stats in the registry for a request.
This should be called at the start of processing for a request.
The stats object is created immediately and stored in the registry.
"""
if (
self.observability_config is None
or not self.observability_config.enable_mm_processor_stats
):
return MultiModalProcessorTimingStats()
with self._timing_stats_registry_lock:
if request_id in self.timing_stats_registry:
raise ValueError(
f"Timing stats already exist for request_id: {request_id}"
)
stats = MultiModalProcessorTimingStats()
self.timing_stats_registry[request_id] = stats
return stats
def clear_timing_stats_registry(self) -> int:
"""
Clear all stats from the registry. Returns the number of stats cleared.
"""
if (
self.observability_config is None
or not self.observability_config.enable_mm_processor_stats
):
return 0
with self._timing_stats_registry_lock:
count = len(self.timing_stats_registry)
self.timing_stats_registry.clear()
return count
def get_all_timing_stats(self) -> dict[str, dict[str, float]]:
"""
Get all timing stats as a dictionary for API endpoints.
"""
if (
self.observability_config is None
or not self.observability_config.enable_mm_processor_stats
):
return {}
with self._timing_stats_registry_lock:
return {
rid: stats.to_dict()
for rid, stats in self.timing_stats_registry.items()
}
class BaseProcessingInfo:
"""Base class to provide the information necessary for data processing."""
......
......@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Mapping
from dataclasses import dataclass, field
from typing import Generic, TypeVar
import numpy as np
......@@ -18,27 +17,14 @@ from vllm.config.multimodal import (
from vllm.logger import init_logger
from ..inputs import MultiModalDataDict
from ..parse import MultiModalDataItems
from .context import BaseProcessingInfo
from .inputs import ProcessorInputs
_I = TypeVar("_I", bound=BaseProcessingInfo)
logger = init_logger(__name__)
@dataclass
class ProcessorInputs:
"""
Represents the keyword arguments to
[`vllm.multimodal.processing.BaseMultiModalProcessor.apply`][].
"""
prompt: str | list[int]
mm_items: MultiModalDataItems
hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)
tokenization_kwargs: Mapping[str, object] = field(default_factory=dict)
class BaseDummyInputsBuilder(ABC, Generic[_I]):
"""
Abstract base class that constructs the dummy data to profile
......@@ -101,7 +87,7 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
return ProcessorInputs(
prompt=dummy_text,
mm_items=dummy_mm_items,
mm_data_items=dummy_mm_items,
tokenization_kwargs=tokenization_kwargs,
)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Mapping
from dataclasses import dataclass, field
from ..hasher import MultiModalHasher
from ..inputs import MultiModalHashes
from ..parse import MultiModalDataItems, MultiModalUUIDItems
@dataclass
class ProcessorInputs:
    """
    Represents the keyword arguments to
    [`vllm.multimodal.processing.BaseMultiModalProcessor.apply`][].
    """

    # Prompt given either as text or as already-tokenized IDs.
    prompt: str | list[int]
    # Parsed multi-modal data; `.items()` yields (modality, data items) pairs.
    mm_data_items: MultiModalDataItems
    # Optional user-provided IDs, looked up by modality and then by item index.
    # An individual entry may be None, meaning "no ID given for this item".
    mm_uuid_items: MultiModalUUIDItems | None = None
    # Keyword arguments for the HF processor; they are also mixed into every
    # hash computed by `get_mm_hashes`.
    hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)
    # Keyword arguments for tokenization (not consulted by `get_mm_hashes`).
    tokenization_kwargs: Mapping[str, object] = field(default_factory=dict)

    def get_mm_hashes(self, model_id: str) -> MultiModalHashes:
        """
        Build the per-modality list of hashes for the multi-modal items.

        For each item, the resulting entry is:

        - the provided UUID as-is, when a UUID is given and
          `hf_processor_mm_kwargs` is empty;
        - a hash of the provided UUID together with `model_id` and
          `hf_processor_mm_kwargs`, when a UUID is given and the kwargs are
          non-empty;
        - a hash of the item itself together with `model_id` and
          `hf_processor_mm_kwargs`, when no UUID is given.

        Args:
            model_id: Identifier of the model, included in every computed hash.

        Returns:
            A mapping from modality to one hash (or UUID) per item.
        """
        uuids_by_modality = self.mm_uuid_items or {}
        extra_kwargs = self.hf_processor_mm_kwargs

        def _digest(modality: str, payload: object) -> str:
            # Single place where the hasher is invoked, so that every branch
            # below hashes with the same set of keyword arguments.
            return MultiModalHasher.hash_kwargs(
                model_id=model_id,
                **{modality: payload},
                **extra_kwargs,
            )

        result: MultiModalHashes = {}

        for modality, data_items in self.mm_data_items.items():
            if modality not in uuids_by_modality:
                # No UUIDs for this modality: hash every item directly.
                # NOTE(review): this iterates `data_items` itself, whereas the
                # UUID path below uses `get_all_items_for_hash()` -- confirm
                # that the difference is intentional.
                result[modality] = [_digest(modality, item) for item in data_items]
                continue

            provided = uuids_by_modality[modality]
            per_item: list[str] = []

            for idx, item in enumerate(data_items.get_all_items_for_hash()):
                uuid_item = provided[idx]

                if uuid_item is None:
                    # Nothing provided for this item, so hash its contents.
                    per_item.append(_digest(modality, item))
                elif extra_kwargs:
                    # NOTE: Even if a uuid_item is provided, a hash is still
                    # computed when `hf_processor_mm_kwargs` is non-empty,
                    # because the processed multimodal inputs can differ
                    # depending on the processor kwargs. The provided string
                    # is hashed instead of the raw item for better performance.
                    per_item.append(_digest(modality, uuid_item))
                else:
                    per_item.append(uuid_item)

            result[modality] = per_item

        return result
......@@ -23,7 +23,6 @@ from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.utils.collection_utils import flatten_2d_lists, full_groupby
from ..hasher import MultiModalHasher
from ..inputs import (
MultiModalEncDecInputs,
MultiModalFieldConfig,
......@@ -42,12 +41,9 @@ from ..parse import (
MultiModalDataItems,
MultiModalUUIDItems,
)
from .context import (
BaseProcessingInfo,
get_current_request_id,
timed_preprocessor_operation,
)
from .context import BaseProcessingInfo, TimingContext
from .dummy_inputs import BaseDummyInputsBuilder
from .inputs import ProcessorInputs
if TYPE_CHECKING:
from transformers.feature_extraction_utils import BatchFeature
......@@ -1017,13 +1013,15 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_uuid_items: MultiModalUUIDItems | None = None,
hf_processor_mm_kwargs: Mapping[str, object] | None = None,
) -> MultiModalInputs:
return self.apply(
processor_inputs = ProcessorInputs(
prompt,
mm_items,
mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
hf_processor_mm_kwargs=hf_processor_mm_kwargs or {},
)
return self.apply(processor_inputs, TimingContext(enabled=False))
@abstractmethod
def _get_mm_fields_config(
self,
......@@ -1139,12 +1137,11 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
Call the HF processor on the prompt text and
associated multi-modal data.
"""
with timed_preprocessor_operation(self.info.ctx, "hf_processor"):
return self.info.ctx.call_hf_processor(
self.info.get_hf_processor(**mm_kwargs),
dict(text=prompt, **mm_data),
dict(**mm_kwargs, **tok_kwargs),
)
return self.info.ctx.call_hf_processor(
self.info.get_hf_processor(**mm_kwargs),
dict(text=prompt, **mm_data),
dict(**mm_kwargs, **tok_kwargs),
)
def _hf_processor_applies_updates(
self,
......@@ -1306,60 +1303,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
return prompt_ids, mm_processed_data, False
def _hash_mm_items(
self,
mm_data_items: MultiModalDataItems,
mm_uuid_items: MultiModalUUIDItems | None,
hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalHashes:
model_id = self.info.model_id
if mm_uuid_items is None:
mm_uuid_items = {}
mm_hashes: MultiModalHashes = {}
hasher = MultiModalHasher
for modality, data_items in mm_data_items.items():
if modality in mm_uuid_items:
uuid_items = mm_uuid_items[modality]
# For None entries, compute a hash; otherwise, use provided ID.
hashes: list[str] = []
for i, item in enumerate(data_items.get_all_items_for_hash()):
uuid_item = uuid_items[i]
# NOTE: Even if a uuid_item is provided, we still compute a hash
# if `hf_processor_mm_kwargs` is provided.
# This is because the processed multimodal inputs can be different
# depending on the processor kwargs.
if uuid_item is None or hf_processor_mm_kwargs:
# NOTE: use provided hash string to hash with kwargs
# if available for better performance.
item = uuid_item if uuid_item is not None else item
hashes.append(
hasher.hash_kwargs(
model_id=model_id,
**{modality: item},
**hf_processor_mm_kwargs,
)
)
else:
hashes.append(uuid_item)
mm_hashes[modality] = hashes
else:
mm_hashes[modality] = [
hasher.hash_kwargs(
model_id=model_id,
**{modality: item},
**hf_processor_mm_kwargs,
)
for item in data_items
]
return mm_hashes
def _get_cache_missing_items(
self,
cache: BaseMultiModalProcessorCache,
......@@ -1461,40 +1404,36 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
def _apply_hf_processor(
self,
prompt: str | list[int],
mm_data_items: MultiModalDataItems,
mm_uuid_items: MultiModalUUIDItems | None,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
inputs: ProcessorInputs,
timing_ctx: TimingContext,
) -> tuple[list[int], MultiModalProcessingInfo, bool]:
(
prompt_ids,
mm_processed_data,
is_update_applied,
) = self._apply_hf_processor_main(
prompt=prompt,
mm_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
enable_hf_prompt_update=True,
)
with timing_ctx.record("apply_hf_processor"):
(
prompt_ids,
mm_processed_data,
is_update_applied,
) = self._apply_hf_processor_main(
prompt=inputs.prompt,
mm_items=inputs.mm_data_items,
hf_processor_mm_kwargs=inputs.hf_processor_mm_kwargs,
tokenization_kwargs=inputs.tokenization_kwargs,
enable_hf_prompt_update=True,
)
mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
mm_processed_data,
self._get_mm_fields_config(mm_processed_data, hf_processor_mm_kwargs),
self._get_mm_fields_config(
mm_processed_data, inputs.hf_processor_mm_kwargs
),
)
# Use overrides if provided; fallback to data-dependent hashing.
with timed_preprocessor_operation(self.info.ctx, "hashing"):
mm_hashes = self._hash_mm_items(
mm_data_items,
mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
)
with timing_ctx.record("get_mm_hashes"):
mm_hashes = inputs.get_mm_hashes(self.info.model_id)
mm_prompt_updates = self._get_mm_prompt_updates(
mm_data_items,
hf_processor_mm_kwargs,
inputs.mm_data_items,
inputs.hf_processor_mm_kwargs,
mm_kwargs,
)
......@@ -1508,11 +1447,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
def _cached_apply_hf_processor(
self,
prompt: str | list[int],
mm_data_items: MultiModalDataItems,
mm_uuid_items: MultiModalUUIDItems | None,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
inputs: ProcessorInputs,
timing_ctx: TimingContext,
) -> tuple[list[int], MultiModalProcessingInfo, bool]:
"""
Apply the HF processor on the full prompt text,
......@@ -1520,59 +1456,50 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
"""
cache = self.cache
_, passthrough_data = self._get_hf_mm_data(mm_data_items)
_, passthrough_data = self._get_hf_mm_data(inputs.mm_data_items)
if cache is None or passthrough_data:
return self._apply_hf_processor(
prompt=prompt,
mm_data_items=mm_data_items,
mm_uuid_items=mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
return self._apply_hf_processor(inputs, timing_ctx)
with timed_preprocessor_operation(self.info.ctx, "hashing"):
mm_hashes = self._hash_mm_items(
mm_data_items,
mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
)
with timing_ctx.record("get_mm_hashes"):
mm_hashes = inputs.get_mm_hashes(self.info.model_id)
with timed_preprocessor_operation(self.info.ctx, "cache_lookup"):
with timing_ctx.record("get_cache_missing_items"):
mm_is_cached, mm_missing_data_items = self._get_cache_missing_items(
cache=cache,
mm_data_items=mm_data_items,
mm_data_items=inputs.mm_data_items,
mm_hashes=mm_hashes,
)
# NOTE: `prompt` does not correspond to `mm_missing_data_items`,
# so we can't apply prompt updates until the new multimodal
# items are combined with the cached multimodal items
(
prompt_ids,
mm_missing_processed_data,
is_update_applied,
) = self._apply_hf_processor_main(
prompt=prompt,
mm_items=mm_missing_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
enable_hf_prompt_update=False,
)
with timing_ctx.record("apply_hf_processor"):
(
prompt_ids,
mm_missing_processed_data,
is_update_applied,
) = self._apply_hf_processor_main(
prompt=inputs.prompt,
mm_items=mm_missing_data_items,
hf_processor_mm_kwargs=inputs.hf_processor_mm_kwargs,
tokenization_kwargs=inputs.tokenization_kwargs,
enable_hf_prompt_update=False,
)
mm_missing_kwargs = MultiModalKwargsItems.from_hf_inputs(
mm_missing_processed_data,
self._get_mm_fields_config(
mm_missing_processed_data, hf_processor_mm_kwargs
mm_missing_processed_data, inputs.hf_processor_mm_kwargs
),
)
mm_missing_prompt_updates = self._get_mm_prompt_updates(
mm_missing_data_items,
hf_processor_mm_kwargs,
inputs.hf_processor_mm_kwargs,
mm_missing_kwargs,
)
with timed_preprocessor_operation(self.info.ctx, "cache_lookup"):
with timing_ctx.record("merge_mm_kwargs"):
mm_kwargs, mm_prompt_updates = self._merge_mm_kwargs(
cache,
mm_hashes=mm_hashes,
......@@ -1742,11 +1669,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
def apply(
self,
prompt: str | list[int],
mm_items: MultiModalDataItems,
mm_uuid_items: MultiModalUUIDItems | None = None,
hf_processor_mm_kwargs: Mapping[str, object] | None = None,
tokenization_kwargs: Mapping[str, object] | None = None,
inputs: ProcessorInputs,
timing_ctx: TimingContext,
) -> MultiModalInputs:
"""
Process multi-modal inputs to be used in vLLM.
......@@ -1761,31 +1685,16 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
3. Extract information about the placeholder tokens from the
processed token IDs.
"""
request_id = get_current_request_id()
if request_id is not None:
self.info.ctx.create_timing_stats(request_id)
if hf_processor_mm_kwargs is None:
hf_processor_mm_kwargs = {}
if tokenization_kwargs is None:
tokenization_kwargs = {}
(
prompt_ids,
mm_info,
is_update_applied,
) = self._cached_apply_hf_processor(
prompt,
mm_items,
mm_uuid_items,
hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
) = self._cached_apply_hf_processor(inputs, timing_ctx)
# NOTE: tokenization_kwargs are not required to init processor
with timed_preprocessor_operation(self.info.ctx, "prompt_update"):
with timing_ctx.record("apply_prompt_updates"):
prompt_ids, mm_placeholders = self._maybe_apply_prompt_updates(
mm_items=mm_items,
mm_items=inputs.mm_data_items,
prompt_ids=prompt_ids,
mm_kwargs=mm_info.kwargs,
mm_prompt_updates=mm_info.prompt_updates,
......@@ -1851,11 +1760,8 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
def apply(
self,
prompt: str | list[int],
mm_items: MultiModalDataItems,
mm_uuid_items: MultiModalUUIDItems | None = None,
hf_processor_mm_kwargs: Mapping[str, object] | None = None,
tokenization_kwargs: Mapping[str, object] | None = None,
inputs: ProcessorInputs,
timing_ctx: TimingContext,
) -> MultiModalEncDecInputs:
"""
Process multi-modal inputs to be used in vLLM.
......@@ -1864,17 +1770,22 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
2. Apply the HF processor on encoder prompt.
3. Copy the input prompt text as decoder prompt inputs.
"""
encoder_prompt = self.create_encoder_prompt(prompt, mm_items)
encoder_inputs = super().apply(
encoder_prompt = self.create_encoder_prompt(
inputs.prompt,
inputs.mm_data_items,
)
encoder_processor_inputs = ProcessorInputs(
encoder_prompt,
mm_items,
mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
inputs.mm_data_items,
inputs.mm_uuid_items,
hf_processor_mm_kwargs=inputs.hf_processor_mm_kwargs,
tokenization_kwargs=inputs.tokenization_kwargs,
)
encoder_inputs = super().apply(encoder_processor_inputs, timing_ctx)
return self._get_enc_dec_inputs(
prompt=prompt,
mm_items=mm_items,
prompt=inputs.prompt,
mm_items=inputs.mm_data_items,
encoder_inputs=encoder_inputs,
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading
from collections import defaultdict
from collections.abc import Mapping
from dataclasses import dataclass
from multiprocessing.synchronize import Lock as LockType
from typing import TYPE_CHECKING, Generic, Literal, Protocol, TypeVar, cast
from vllm.config.observability import ObservabilityConfig
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
......@@ -24,6 +25,7 @@ from .processing import (
BaseMultiModalProcessor,
BaseProcessingInfo,
InputProcessingContext,
TimingContext,
)
if TYPE_CHECKING:
......@@ -174,32 +176,26 @@ class MultiModalRegistry:
def _create_processing_ctx(
self,
model_config: "ModelConfig",
observability_config: "ObservabilityConfig | None" = None,
tokenizer: TokenizerLike | None = None,
) -> InputProcessingContext:
if tokenizer is None:
tokenizer = cached_tokenizer_from_config(model_config)
return InputProcessingContext(
model_config, tokenizer, observability_config=observability_config
)
return InputProcessingContext(model_config, tokenizer)
def _create_processing_info(
self,
model_config: "ModelConfig",
observability_config: "ObservabilityConfig | None" = None,
*,
tokenizer: TokenizerLike | None = None,
) -> BaseProcessingInfo:
model_cls = self._get_model_cls(model_config)
factories = model_cls._processor_factory
ctx = self._create_processing_ctx(model_config, observability_config, tokenizer)
ctx = self._create_processing_ctx(model_config, tokenizer)
return factories.info(ctx)
def create_processor(
self,
model_config: "ModelConfig",
observability_config: "ObservabilityConfig | None" = None,
*,
tokenizer: TokenizerLike | None = None,
cache: BaseMultiModalProcessorCache | None = None,
......@@ -213,7 +209,7 @@ class MultiModalRegistry:
model_cls = self._get_model_cls(model_config)
factories = model_cls._processor_factory
ctx = self._create_processing_ctx(model_config, observability_config, tokenizer)
ctx = self._create_processing_ctx(model_config, tokenizer)
return factories.build_processor(ctx, cache=cache)
......@@ -242,10 +238,8 @@ class MultiModalRegistry:
mm_options=mm_config.limit_per_prompt,
)
mm_inputs = processor.apply(
prompt=processor_inputs.prompt,
mm_items=processor_inputs.mm_items,
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
tokenization_kwargs=processor_inputs.tokenization_kwargs,
processor_inputs,
timing_ctx=TimingContext(enabled=False),
)
prompt_token_ids = mm_inputs["prompt_token_ids"]
......@@ -335,3 +329,34 @@ class MultiModalRegistry:
return ShmObjectStoreReceiverCache(vllm_config, shared_worker_lock)
else:
raise ValueError(f"Unknown cache type: {cache_type!r}")
class MultiModalTimingRegistry:
    """
    Keeps one `TimingContext` per request ID and reports their stats.

    Contexts are only stored when `observability_config` is provided and its
    `enable_mm_processor_stats` flag is truthy.
    """

    def __init__(self, observability_config: "ObservabilityConfig | None") -> None:
        super().__init__()

        if not (
            observability_config and observability_config.enable_mm_processor_stats
        ):
            # Disabled: the lock and the per-request map are never created.
            self._enabled = False
            return

        self._lock = threading.Lock()
        self._ctx_by_request_id = defaultdict[str, TimingContext](TimingContext)
        self._enabled = True

    def get(self, request_id: str) -> TimingContext:
        """
        Return the timing context for `request_id`.

        When enabled, the context is looked up under the lock and is created
        by the `defaultdict` on first access. When disabled, a new
        `TimingContext(enabled=False)` is returned and nothing is stored.
        """
        if self._enabled:
            with self._lock:
                return self._ctx_by_request_id[request_id]

        return TimingContext(enabled=False)

    def stat(self) -> dict[str, dict[str, float]]:
        """
        Return the stats dict of every stored context, keyed by request ID,
        then clear the stored contexts. Returns `{}` when disabled.
        """
        if not self._enabled:
            return {}

        snapshot: dict[str, dict[str, float]] = {}
        with self._lock:
            for request_id, timing_ctx in self._ctx_by_request_id.items():
                snapshot[request_id] = timing_ctx.get_stats_dict()
            # Forget the contexts that were just reported.
            self._ctx_by_request_id.clear()

        return snapshot
......@@ -85,13 +85,13 @@ class BaseRenderer(ABC, Generic[_T]):
self._mm_cache_stats: MultiModalCacheStats | None = None
if config.model_config.is_multimodal_model:
from vllm.multimodal import MULTIMODAL_REGISTRY as mm_registry
from vllm.multimodal.registry import MultiModalTimingRegistry
mm_processor_cache = mm_registry.processor_cache_from_config(config)
with set_default_torch_num_threads():
self.mm_processor = mm_registry.create_processor(
config.model_config,
config.observability_config,
tokenizer=tokenizer,
cache=mm_processor_cache,
)
......@@ -102,6 +102,9 @@ class BaseRenderer(ABC, Generic[_T]):
# This is used to generate internal request ID for MM processing
# It has no relation to the request ID for engine core
self._mm_req_counter = AtomicCounter()
self._mm_timing_registry = MultiModalTimingRegistry(
config.observability_config
)
def get_tokenizer(self) -> _T:
tokenizer = self.tokenizer
......@@ -534,7 +537,7 @@ class BaseRenderer(ABC, Generic[_T]):
tokenization_kwargs: dict[str, Any] | None,
) -> "MultiModalInputs":
from vllm.multimodal.parse import parse_mm_uuids
from vllm.multimodal.processing.context import set_request_id
from vllm.multimodal.processing import ProcessorInputs as MMProcessorInputs
mm_req_id = f"renderer-mm-{self._mm_req_counter.inc(1)}"
......@@ -543,18 +546,21 @@ class BaseRenderer(ABC, Generic[_T]):
mm_data_items = mm_processor.info.parse_mm_data(mm_data)
mm_uuid_items = parse_mm_uuids(mm_uuids)
mm_uuids = self._process_mm_uuids(
mm_uuid_items = self._process_mm_uuids(
mm_data, mm_data_items, mm_uuid_items, mm_req_id
)
with set_request_id(mm_req_id), set_default_torch_num_threads():
mm_inputs = mm_processor.apply(
prompt,
mm_data_items,
mm_uuid_items,
hf_processor_mm_kwargs=mm_processor_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
mm_processor_inputs = MMProcessorInputs(
prompt,
mm_data_items,
mm_uuid_items,
hf_processor_mm_kwargs=mm_processor_kwargs or {},
tokenization_kwargs=tokenization_kwargs or {},
)
mm_timing_ctx = self._mm_timing_registry.get(mm_req_id)
with set_default_torch_num_threads():
mm_inputs = mm_processor.apply(mm_processor_inputs, mm_timing_ctx)
self.update_mm_cache_stats()
......
......@@ -6272,7 +6272,7 @@ class GPUModelRunner(
self.encoder_timing_registry[req_id] = EncoderTimingStats()
stats = self.encoder_timing_registry[req_id]
stats.encoder_forward_time += per_request_time
stats.encoder_forward_secs += per_request_time
stats.num_encoder_calls += 1
......@@ -6280,7 +6280,7 @@ class GPUModelRunner(
class EncoderTimingStats:
"""Per-request timing statistics for encoder forward pass."""
encoder_forward_time: float = 0.0
encoder_forward_secs: float = 0.0
"""Time spent in vision encoder forward pass (seconds)."""
num_encoder_calls: int = 0
......@@ -6288,6 +6288,6 @@ class EncoderTimingStats:
def to_dict(self) -> dict[str, float | int]:
return {
"encoder_forward_time": self.encoder_forward_time,
"encoder_forward_secs": self.encoder_forward_secs,
"num_encoder_calls": self.num_encoder_calls,
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment