Unverified commit 39264545, authored by Cyrus Leung, committed by GitHub
Browse files

[Refactor] Decouple TimingContext from InputProcessingContext (#35083)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 1e8438a8
...@@ -41,15 +41,16 @@ from vllm.multimodal.parse import ( ...@@ -41,15 +41,16 @@ from vllm.multimodal.parse import (
ImageProcessorItems, ImageProcessorItems,
ImageSize, ImageSize,
MultiModalDataItems, MultiModalDataItems,
MultiModalUUIDItems,
) )
from vllm.multimodal.processing import ( from vllm.multimodal.processing import (
BaseDummyInputsBuilder, BaseDummyInputsBuilder,
BaseMultiModalProcessor, BaseMultiModalProcessor,
BaseProcessingInfo, BaseProcessingInfo,
ProcessorInputs,
PromptIndexTargets, PromptIndexTargets,
PromptReplacement, PromptReplacement,
PromptUpdate, PromptUpdate,
TimingContext,
) )
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
...@@ -204,23 +205,20 @@ class CLIPMultiModalProcessor(BaseMultiModalProcessor[CLIPProcessingInfo]): ...@@ -204,23 +205,20 @@ class CLIPMultiModalProcessor(BaseMultiModalProcessor[CLIPProcessingInfo]):
def apply( def apply(
self, self,
prompt: str | list[int], inputs: ProcessorInputs,
mm_items: MultiModalDataItems, timing_ctx: TimingContext,
mm_uuid_items: MultiModalUUIDItems | None = None,
hf_processor_mm_kwargs: Mapping[str, object] | None = None,
tokenization_kwargs: Mapping[str, object] | None = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
if mm_items: if inputs.mm_data_items:
if isinstance(prompt, str): if isinstance(inputs.prompt, str):
if len(prompt) > 0: if len(inputs.prompt) > 0:
raise ValueError( raise ValueError(
"CLIP accepts text-only or image-only inputs, not both! " "CLIP accepts text-only or image-only inputs, not both! "
"You must pass an image with an empty text prompt." "You must pass an image with an empty text prompt."
) )
else: else:
special_tokens = self.info.get_tokenizer().all_special_ids special_tokens = self.info.get_tokenizer().all_special_ids
if all(tok in special_tokens for tok in prompt): if all(tok in special_tokens for tok in inputs.prompt):
prompt = [] inputs.prompt = []
else: else:
raise ValueError( raise ValueError(
"CLIP accepts text-only or image-only inputs, not both! " "CLIP accepts text-only or image-only inputs, not both! "
...@@ -229,18 +227,12 @@ class CLIPMultiModalProcessor(BaseMultiModalProcessor[CLIPProcessingInfo]): ...@@ -229,18 +227,12 @@ class CLIPMultiModalProcessor(BaseMultiModalProcessor[CLIPProcessingInfo]):
# For multi-modal data, the prompt after processing should # For multi-modal data, the prompt after processing should
# only contain the dummy image tokens # only contain the dummy image tokens
tokenization_kwargs = { inputs.tokenization_kwargs = {
**(tokenization_kwargs or {}), **inputs.tokenization_kwargs,
"add_special_tokens": False, "add_special_tokens": False,
} }
return super().apply( return super().apply(inputs, timing_ctx)
prompt=prompt,
mm_items=mm_items,
mm_uuid_items=mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
def _hf_processor_applies_updates( def _hf_processor_applies_updates(
self, self,
......
...@@ -30,15 +30,16 @@ from vllm.multimodal.parse import ( ...@@ -30,15 +30,16 @@ from vllm.multimodal.parse import (
ImageProcessorItems, ImageProcessorItems,
ImageSize, ImageSize,
MultiModalDataItems, MultiModalDataItems,
MultiModalUUIDItems,
) )
from vllm.multimodal.processing import BaseDummyInputsBuilder from vllm.multimodal.processing import BaseDummyInputsBuilder
from vllm.multimodal.processing.processor import ( from vllm.multimodal.processing.processor import (
BaseMultiModalProcessor, BaseMultiModalProcessor,
BaseProcessingInfo, BaseProcessingInfo,
MultiModalProcessingInfo, MultiModalProcessingInfo,
ProcessorInputs,
PromptReplacement, PromptReplacement,
PromptUpdate, PromptUpdate,
TimingContext,
) )
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.tokenizers import cached_tokenizer_from_config from vllm.tokenizers import cached_tokenizer_from_config
...@@ -310,32 +311,17 @@ class DeepseekVL2MultiModalProcessor( ...@@ -310,32 +311,17 @@ class DeepseekVL2MultiModalProcessor(
def _cached_apply_hf_processor( def _cached_apply_hf_processor(
self, self,
prompt: str | list[int], inputs: ProcessorInputs,
mm_data_items: MultiModalDataItems, timing_ctx: TimingContext,
mm_uuid_items: MultiModalUUIDItems | None,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
) -> tuple[list[int], MultiModalProcessingInfo, bool]: ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
# The processor logic is different for len(images) <= 2 vs > 2 # The processor logic is different for len(images) <= 2 vs > 2
# Since the processing cache assumes that the processor output is # Since the processing cache assumes that the processor output is
# invariant of how many images are passed per prompt, we only # invariant of how many images are passed per prompt, we only
# perform caching for the most common case # perform caching for the most common case
if mm_data_items.get_count("image", strict=False) > 2: if inputs.mm_data_items.get_count("image", strict=False) > 2:
return self._apply_hf_processor( return self._apply_hf_processor(inputs, timing_ctx)
prompt=prompt,
mm_data_items=mm_data_items,
mm_uuid_items=mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
return super()._cached_apply_hf_processor( return super()._cached_apply_hf_processor(inputs, timing_ctx)
prompt=prompt,
mm_data_items=mm_data_items,
mm_uuid_items=mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
@MULTIMODAL_REGISTRY.register_processor( @MULTIMODAL_REGISTRY.register_processor(
......
...@@ -21,13 +21,14 @@ from vllm.multimodal.parse import ( ...@@ -21,13 +21,14 @@ from vllm.multimodal.parse import (
ImageEmbeddingItems, ImageEmbeddingItems,
ImageProcessorItems, ImageProcessorItems,
MultiModalDataItems, MultiModalDataItems,
MultiModalUUIDItems,
) )
from vllm.multimodal.processing.processor import ( from vllm.multimodal.processing.processor import (
MultiModalProcessingInfo, MultiModalProcessingInfo,
ProcessorInputs,
PromptReplacement, PromptReplacement,
PromptUpdate, PromptUpdate,
PromptUpdateDetails, PromptUpdateDetails,
TimingContext,
) )
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
...@@ -490,32 +491,17 @@ class H2OVLMultiModalProcessor(BaseInternVLMultiModalProcessor[H2OVLProcessingIn ...@@ -490,32 +491,17 @@ class H2OVLMultiModalProcessor(BaseInternVLMultiModalProcessor[H2OVLProcessingIn
def _cached_apply_hf_processor( def _cached_apply_hf_processor(
self, self,
prompt: str | list[int], inputs: ProcessorInputs,
mm_data_items: MultiModalDataItems, timing_ctx: TimingContext,
mm_uuid_items: MultiModalUUIDItems | None,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
) -> tuple[list[int], MultiModalProcessingInfo, bool]: ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
# The processor logic is different for len(images) <= 1 vs > 1 # The processor logic is different for len(images) <= 1 vs > 1
# Since the processing cache assumes that the processor output is # Since the processing cache assumes that the processor output is
# invariant of how many images are passed per prompt, we only # invariant of how many images are passed per prompt, we only
# perform caching for the most common case # perform caching for the most common case
if mm_data_items.get_count("image", strict=False) > 1: if inputs.mm_data_items.get_count("image", strict=False) > 1:
return self._apply_hf_processor( return self._apply_hf_processor(inputs, timing_ctx)
prompt=prompt,
mm_data_items=mm_data_items,
mm_uuid_items=mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
return super()._cached_apply_hf_processor( return super()._cached_apply_hf_processor(inputs, timing_ctx)
prompt=prompt,
mm_data_items=mm_data_items,
mm_uuid_items=mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
@MULTIMODAL_REGISTRY.register_processor( @MULTIMODAL_REGISTRY.register_processor(
......
...@@ -37,16 +37,17 @@ from vllm.multimodal.parse import ( ...@@ -37,16 +37,17 @@ from vllm.multimodal.parse import (
ImageProcessorItems, ImageProcessorItems,
ImageSize, ImageSize,
MultiModalDataItems, MultiModalDataItems,
MultiModalUUIDItems,
) )
from vllm.multimodal.processing import ( from vllm.multimodal.processing import (
BaseDummyInputsBuilder, BaseDummyInputsBuilder,
BaseMultiModalProcessor, BaseMultiModalProcessor,
BaseProcessingInfo, BaseProcessingInfo,
InputProcessingContext, InputProcessingContext,
ProcessorInputs,
PromptReplacement, PromptReplacement,
PromptUpdate, PromptUpdate,
PromptUpdateDetails, PromptUpdateDetails,
TimingContext,
) )
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
...@@ -770,11 +771,8 @@ class MantisProcessingInfo(LlavaProcessingInfo): ...@@ -770,11 +771,8 @@ class MantisProcessingInfo(LlavaProcessingInfo):
class MantisMultiModalProcessor(LlavaMultiModalProcessor): class MantisMultiModalProcessor(LlavaMultiModalProcessor):
def apply( def apply(
self, self,
prompt: str | list[int], inputs: ProcessorInputs,
mm_items: MultiModalDataItems, timing_ctx: TimingContext,
mm_uuid_items: MultiModalUUIDItems | None = None,
hf_processor_mm_kwargs: Mapping[str, object] | None = None,
tokenization_kwargs: Mapping[str, object] | None = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
hf_config = self.info.get_hf_config() hf_config = self.info.get_hf_config()
image_token_id = hf_config.image_token_index image_token_id = hf_config.image_token_index
...@@ -785,15 +783,9 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): ...@@ -785,15 +783,9 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
image_height=-1, image_height=-1,
) )
result = super().apply( result = super().apply(inputs, timing_ctx)
prompt,
mm_items,
mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
mm_item_counts = mm_items.get_all_counts() mm_item_counts = inputs.mm_data_items.get_all_counts()
mm_kwargs = result["mm_kwargs"] mm_kwargs = result["mm_kwargs"]
mm_hashes = result["mm_hashes"] mm_hashes = result["mm_hashes"]
...@@ -825,8 +817,8 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): ...@@ -825,8 +817,8 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
) )
orig_repls = self._get_mm_prompt_updates( orig_repls = self._get_mm_prompt_updates(
mm_items, inputs.mm_data_items,
hf_processor_mm_kwargs, inputs.hf_processor_mm_kwargs,
mm_kwargs, mm_kwargs,
) )
mm_placeholders = self._find_mm_placeholders(prompt_ids, orig_repls) mm_placeholders = self._find_mm_placeholders(prompt_ids, orig_repls)
......
...@@ -21,16 +21,17 @@ from vllm.multimodal.parse import ( ...@@ -21,16 +21,17 @@ from vllm.multimodal.parse import (
ImageEmbeddingItems, ImageEmbeddingItems,
ImageProcessorItems, ImageProcessorItems,
MultiModalDataItems, MultiModalDataItems,
MultiModalUUIDItems,
) )
from vllm.multimodal.processing import ( from vllm.multimodal.processing import (
BaseDummyInputsBuilder, BaseDummyInputsBuilder,
BaseMultiModalProcessor, BaseMultiModalProcessor,
BaseProcessingInfo, BaseProcessingInfo,
ProcessorInputs,
PromptIndexTargets, PromptIndexTargets,
PromptInsertion, PromptInsertion,
PromptUpdate, PromptUpdate,
PromptUpdateDetails, PromptUpdateDetails,
TimingContext,
) )
from vllm.renderers import TokenizeParams from vllm.renderers import TokenizeParams
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -228,19 +229,10 @@ class PaliGemmaMultiModalProcessor(BaseMultiModalProcessor[PaliGemmaProcessingIn ...@@ -228,19 +229,10 @@ class PaliGemmaMultiModalProcessor(BaseMultiModalProcessor[PaliGemmaProcessingIn
def apply( def apply(
self, self,
prompt: str | list[int], inputs: ProcessorInputs,
mm_items: MultiModalDataItems, timing_ctx: TimingContext,
mm_uuid_items: MultiModalUUIDItems | None = None,
hf_processor_mm_kwargs: Mapping[str, object] | None = None,
tokenization_kwargs: Mapping[str, object] | None = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
mm_inputs = super().apply( mm_inputs = super().apply(inputs, timing_ctx)
prompt,
mm_items,
mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
prompt_token_ids = mm_inputs["prompt_token_ids"] prompt_token_ids = mm_inputs["prompt_token_ids"]
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
......
...@@ -50,16 +50,17 @@ from vllm.multimodal.parse import ( ...@@ -50,16 +50,17 @@ from vllm.multimodal.parse import (
ImageProcessorItems, ImageProcessorItems,
ImageSize, ImageSize,
MultiModalDataItems, MultiModalDataItems,
MultiModalUUIDItems,
) )
from vllm.multimodal.processing import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.processing import BaseDummyInputsBuilder
from vllm.multimodal.processing.processor import ( from vllm.multimodal.processing.processor import (
BaseMultiModalProcessor, BaseMultiModalProcessor,
BaseProcessingInfo, BaseProcessingInfo,
MultiModalProcessingInfo, MultiModalProcessingInfo,
ProcessorInputs,
PromptReplacement, PromptReplacement,
PromptUpdate, PromptUpdate,
PromptUpdateDetails, PromptUpdateDetails,
TimingContext,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -277,7 +278,6 @@ class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]): ...@@ -277,7 +278,6 @@ class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]):
dummy_text = self.get_dummy_text(mm_counts) dummy_text = self.get_dummy_text(mm_counts)
dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options) dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options)
dummy_images = dummy_mm_data.get("image", []) dummy_images = dummy_mm_data.get("image", [])
tokenization_kwargs = {"truncation": False}
request = ChatCompletionRequest( request = ChatCompletionRequest(
messages=[ messages=[
...@@ -294,11 +294,7 @@ class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]): ...@@ -294,11 +294,7 @@ class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]):
dummy_mm_items = self.info.parse_mm_data(dummy_mm_data) dummy_mm_items = self.info.parse_mm_data(dummy_mm_data)
return ProcessorInputs( return ProcessorInputs(prompt=dummy_tokens, mm_data_items=dummy_mm_items)
prompt=dummy_tokens,
mm_items=dummy_mm_items,
tokenization_kwargs=tokenization_kwargs,
)
class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]): class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]):
...@@ -344,19 +340,10 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]) ...@@ -344,19 +340,10 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo])
def _cached_apply_hf_processor( def _cached_apply_hf_processor(
self, self,
prompt: str | list[int], inputs: ProcessorInputs,
mm_data_items: MultiModalDataItems, timing_ctx: TimingContext,
mm_uuid_items: MultiModalUUIDItems | None,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
) -> tuple[list[int], MultiModalProcessingInfo, bool]: ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
prompt_ids, mm_info, _ = super()._cached_apply_hf_processor( prompt_ids, mm_info, _ = super()._cached_apply_hf_processor(inputs, timing_ctx)
prompt=prompt,
mm_data_items=mm_data_items,
mm_uuid_items=mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
# NOTE: The tokens are already inserted by the chat template # NOTE: The tokens are already inserted by the chat template
return prompt_ids, mm_info, True return prompt_ids, mm_info, True
......
...@@ -47,15 +47,16 @@ from vllm.multimodal.parse import ( ...@@ -47,15 +47,16 @@ from vllm.multimodal.parse import (
ImageProcessorItems, ImageProcessorItems,
ImageSize, ImageSize,
MultiModalDataItems, MultiModalDataItems,
MultiModalUUIDItems,
) )
from vllm.multimodal.processing import ( from vllm.multimodal.processing import (
BaseDummyInputsBuilder, BaseDummyInputsBuilder,
BaseMultiModalProcessor, BaseMultiModalProcessor,
BaseProcessingInfo, BaseProcessingInfo,
ProcessorInputs,
PromptIndexTargets, PromptIndexTargets,
PromptReplacement, PromptReplacement,
PromptUpdate, PromptUpdate,
TimingContext,
) )
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
...@@ -190,23 +191,20 @@ class SiglipMultiModalProcessor(BaseMultiModalProcessor[SiglipProcessingInfo]): ...@@ -190,23 +191,20 @@ class SiglipMultiModalProcessor(BaseMultiModalProcessor[SiglipProcessingInfo]):
def apply( def apply(
self, self,
prompt: str | list[int], inputs: ProcessorInputs,
mm_items: MultiModalDataItems, timing_ctx: TimingContext,
mm_uuid_items: MultiModalUUIDItems | None = None,
hf_processor_mm_kwargs: Mapping[str, object] | None = None,
tokenization_kwargs: Mapping[str, object] | None = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
if mm_items: if inputs.mm_data_items:
if isinstance(prompt, str): if isinstance(inputs.prompt, str):
if len(prompt) > 0: if len(inputs.prompt) > 0:
raise ValueError( raise ValueError(
"SigLIP accepts text-only or image-only inputs, not both! " "SigLIP accepts text-only or image-only inputs, not both! "
"You must pass an image with an empty text prompt." "You must pass an image with an empty text prompt."
) )
else: else:
special_tokens = self.info.get_tokenizer().all_special_ids special_tokens = self.info.get_tokenizer().all_special_ids
if all(tok in special_tokens for tok in prompt): if all(tok in special_tokens for tok in inputs.prompt):
prompt = [] inputs.prompt = []
else: else:
raise ValueError( raise ValueError(
"SigLIP accepts text-only or image-only inputs, not both! " "SigLIP accepts text-only or image-only inputs, not both! "
...@@ -214,19 +212,13 @@ class SiglipMultiModalProcessor(BaseMultiModalProcessor[SiglipProcessingInfo]): ...@@ -214,19 +212,13 @@ class SiglipMultiModalProcessor(BaseMultiModalProcessor[SiglipProcessingInfo]):
) )
# For multi-modal data, the prompt after processing should # For multi-modal data, the prompt after processing should
# only contain the image token # only contain the dummy image tokens
tokenization_kwargs = { inputs.tokenization_kwargs = {
**(tokenization_kwargs or {}), **inputs.tokenization_kwargs,
"add_special_tokens": False, "add_special_tokens": False,
} }
return super().apply( return super().apply(inputs, timing_ctx)
prompt=prompt,
mm_items=mm_items,
mm_uuid_items=mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
def _hf_processor_applies_updates( def _hf_processor_applies_updates(
self, self,
......
...@@ -54,13 +54,14 @@ from vllm.multimodal.parse import ( ...@@ -54,13 +54,14 @@ from vllm.multimodal.parse import (
ModalityDataItems, ModalityDataItems,
MultiModalDataItems, MultiModalDataItems,
MultiModalDataParser, MultiModalDataParser,
MultiModalUUIDItems,
) )
from vllm.multimodal.processing import ( from vllm.multimodal.processing import (
BaseDummyInputsBuilder, BaseDummyInputsBuilder,
BaseMultiModalProcessor, BaseMultiModalProcessor,
BaseProcessingInfo, BaseProcessingInfo,
ProcessorInputs,
PromptUpdate, PromptUpdate,
TimingContext,
) )
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -193,29 +194,21 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor[TerratorchProcessing ...@@ -193,29 +194,21 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor[TerratorchProcessing
def apply( def apply(
self, self,
prompt: str | list[int], inputs: ProcessorInputs,
mm_items: MultiModalDataItems, timing_ctx: TimingContext,
mm_uuid_items: MultiModalUUIDItems | None = None,
hf_processor_mm_kwargs: Mapping[str, object] | None = None,
tokenization_kwargs: Mapping[str, object] | None = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
if hf_processor_mm_kwargs is None: mm_items = inputs.mm_data_items
hf_processor_mm_kwargs = {} hf_processor_mm_kwargs = inputs.hf_processor_mm_kwargs
if tokenization_kwargs is None:
tokenization_kwargs = {} with timing_ctx.record("apply_hf_processor"):
_, passthrough_data = self._get_hf_mm_data(mm_items)
mm_hashes = self._hash_mm_items( mm_processed_data = BatchFeature(
mm_items, {
mm_uuid_items, k: torch.as_tensor(v).unsqueeze(0)
hf_processor_mm_kwargs=hf_processor_mm_kwargs, for k, v in passthrough_data.items()
) },
tensor_type="pt",
_, passthrough_data = self._get_hf_mm_data(mm_items) )
mm_processed_data = BatchFeature(
{k: torch.as_tensor(v).unsqueeze(0) for k, v in passthrough_data.items()},
tensor_type="pt",
)
mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]}
mm_kwargs = MultiModalKwargsItems.from_hf_inputs( mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
mm_processed_data, mm_processed_data,
...@@ -226,6 +219,11 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor[TerratorchProcessing ...@@ -226,6 +219,11 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor[TerratorchProcessing
), ),
) )
with timing_ctx.record("get_mm_hashes"):
mm_hashes = inputs.get_mm_hashes(self.info.model_id)
mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]}
return mm_inputs( return mm_inputs(
prompt_token_ids=[1], prompt_token_ids=[1],
mm_kwargs=mm_kwargs, mm_kwargs=mm_kwargs,
......
...@@ -37,12 +37,13 @@ from vllm.multimodal.inputs import ( ...@@ -37,12 +37,13 @@ from vllm.multimodal.inputs import (
from vllm.multimodal.parse import ( from vllm.multimodal.parse import (
ImageProcessorItems, ImageProcessorItems,
MultiModalDataItems, MultiModalDataItems,
MultiModalUUIDItems,
) )
from vllm.multimodal.processing import ( from vllm.multimodal.processing import (
BaseDummyInputsBuilder, BaseDummyInputsBuilder,
BaseMultiModalProcessor, BaseMultiModalProcessor,
BaseProcessingInfo, BaseProcessingInfo,
ProcessorInputs,
TimingContext,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -177,11 +178,8 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]): ...@@ -177,11 +178,8 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
def apply( def apply(
self, self,
prompt: str | list[int], inputs: ProcessorInputs,
mm_items: MultiModalDataItems, timing_ctx: TimingContext,
mm_uuid_items: MultiModalUUIDItems | None = None,
hf_processor_mm_kwargs: Mapping[str, object] | None = None,
tokenization_kwargs: Mapping[str, object] | None = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
""" """
Process multi-modal inputs to be used in vLLM. Process multi-modal inputs to be used in vLLM.
...@@ -189,29 +187,30 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]): ...@@ -189,29 +187,30 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
Apply HF Processor on prompt text and multi-modal data together, Apply HF Processor on prompt text and multi-modal data together,
outputting token IDs and processed tensors. outputting token IDs and processed tensors.
""" """
if hf_processor_mm_kwargs is None: prompt = inputs.prompt
hf_processor_mm_kwargs = {} mm_items = inputs.mm_data_items
if tokenization_kwargs is None: hf_processor_mm_kwargs = inputs.hf_processor_mm_kwargs
tokenization_kwargs = {} tokenization_kwargs = inputs.tokenization_kwargs
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) with timing_ctx.record("apply_hf_processor"):
if not isinstance(prompt, str): hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
# the prompt is the tokenized ids which is not supported if not isinstance(prompt, str):
# by the hf_processor, which is why we would need to decode the ids # the prompt is the tokenized ids which is not supported
# into string # by the hf_processor, which is why we would need to decode the ids
prompt = hf_processor.decode(prompt) # into string
prompt = hf_processor.decode(prompt)
# Bypass cached processor and always apply to the full set of mm inputs
# NOTE: we can't just set caching=False because base class method # Bypass cached processor and always apply to the full set of mm inputs
# transforms outputs to `MultiModalKwargs` which is not going to # NOTE: we can't just set caching=False because base class method
# work for Transformers. We have a lot of logic tied to # transforms outputs to `MultiModalKwargs` which is not going to
# `mm_tokens_per_modality` below # work for Transformers. We have a lot of logic tied to
prompt_ids, processed_data, _ = self._apply_hf_processor_text_mm( # `mm_tokens_per_modality` below
prompt_text=prompt, prompt_ids, processed_data, _ = self._apply_hf_processor_text_mm(
mm_items=mm_items, prompt_text=prompt,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, mm_items=mm_items,
tokenization_kwargs=tokenization_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
) tokenization_kwargs=tokenization_kwargs,
)
# For gemma3 we check `token_type_ids` as the key # For gemma3 we check `token_type_ids` as the key
token_type_key = ( token_type_key = (
...@@ -225,15 +224,14 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]): ...@@ -225,15 +224,14 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
# it for each input `mm_data`. # it for each input `mm_data`.
mm_positions = torch.where(mm_token_type_ids == 1)[1] mm_positions = torch.where(mm_token_type_ids == 1)[1]
images = mm_items.get_items("image", ImageProcessorItems) images = mm_items.get_items("image", ImageProcessorItems)
multimodal_config = self.info.ctx.model_config.multimodal_config
mm_processor_kwargs = multimodal_config.mm_processor_kwargs or {}
image_sizes = [] image_sizes = []
for item_idx in range(len(images)): for item_idx in range(len(images)):
image_size = images.get_image_size(item_idx) image_size = images.get_image_size(item_idx)
image_sizes.append((image_size.height, image_size.width)) image_sizes.append((image_size.height, image_size.width))
mm_tokens_per_modality = hf_processor._get_num_multimodal_tokens( mm_tokens_per_modality = hf_processor._get_num_multimodal_tokens(
image_sizes=image_sizes, **mm_processor_kwargs image_sizes=image_sizes,
**self.info.ctx.get_merged_mm_kwargs({}),
) )
mm_placeholders = {} mm_placeholders = {}
...@@ -261,11 +259,8 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]): ...@@ -261,11 +259,8 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
) )
# Use overrides if provided; fallback to data-dependent hashing. # Use overrides if provided; fallback to data-dependent hashing.
mm_hashes = self._hash_mm_items( with timing_ctx.record("get_mm_hashes"):
mm_items, mm_hashes = inputs.get_mm_hashes(self.info.model_id)
mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
)
return mm_inputs( return mm_inputs(
prompt_token_ids=prompt_ids, prompt_token_ids=prompt_ids,
......
...@@ -47,16 +47,17 @@ from vllm.multimodal.parse import ( ...@@ -47,16 +47,17 @@ from vllm.multimodal.parse import (
AudioProcessorItems, AudioProcessorItems,
MultiModalDataItems, MultiModalDataItems,
MultiModalDataParser, MultiModalDataParser,
MultiModalUUIDItems,
) )
from vllm.multimodal.processing import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.processing import BaseDummyInputsBuilder
from vllm.multimodal.processing.processor import ( from vllm.multimodal.processing.processor import (
BaseMultiModalProcessor, BaseMultiModalProcessor,
BaseProcessingInfo, BaseProcessingInfo,
MultiModalProcessingInfo, MultiModalProcessingInfo,
PlaceholderFeaturesInfo, PlaceholderFeaturesInfo,
ProcessorInputs,
PromptReplacement, PromptReplacement,
PromptUpdate, PromptUpdate,
TimingContext,
) )
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.tokenizers import cached_tokenizer_from_config from vllm.tokenizers import cached_tokenizer_from_config
...@@ -265,13 +266,13 @@ class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]): ...@@ -265,13 +266,13 @@ class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]):
res = tokenizer.mistral.encode_chat_completion(request) res = tokenizer.mistral.encode_chat_completion(request)
dummy_tokens = res.tokens dummy_tokens = res.tokens
dummy_mm_inputs = self.info.parse_mm_data( dummy_mm_items = self.info.parse_mm_data(
# whixtral tokenizer adds padding to the audio # whixtral tokenizer adds padding to the audio
# so we need to update the audio arrays # so we need to update the audio arrays
{**dummy_mm_data, "audio": [a.audio_array for a in res.audios]}, {**dummy_mm_data, "audio": [a.audio_array for a in res.audios]},
) )
return ProcessorInputs(prompt=dummy_tokens, mm_items=dummy_mm_inputs) return ProcessorInputs(prompt=dummy_tokens, mm_data_items=dummy_mm_items)
class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]): class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]):
...@@ -361,19 +362,10 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]) ...@@ -361,19 +362,10 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo])
def _cached_apply_hf_processor( def _cached_apply_hf_processor(
self, self,
prompt: str | list[int], inputs: ProcessorInputs,
mm_data_items: MultiModalDataItems, timing_ctx: TimingContext,
mm_uuid_items: MultiModalUUIDItems | None,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
) -> tuple[list[int], MultiModalProcessingInfo, bool]: ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
prompt_ids, mm_info, _ = super()._cached_apply_hf_processor( prompt_ids, mm_info, _ = super()._cached_apply_hf_processor(inputs, timing_ctx)
prompt=prompt,
mm_data_items=mm_data_items,
mm_uuid_items=mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
# NOTE: The tokens are already inserted by the chat template # NOTE: The tokens are already inserted by the chat template
return prompt_ids, mm_info, True return prompt_ids, mm_info, True
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .context import BaseProcessingInfo, InputProcessingContext from .context import BaseProcessingInfo, InputProcessingContext, TimingContext
from .dummy_inputs import BaseDummyInputsBuilder, ProcessorInputs from .dummy_inputs import BaseDummyInputsBuilder
from .inputs import ProcessorInputs
from .processor import ( from .processor import (
BaseMultiModalProcessor, BaseMultiModalProcessor,
EncDecMultiModalProcessor, EncDecMultiModalProcessor,
...@@ -15,6 +16,7 @@ from .processor import ( ...@@ -15,6 +16,7 @@ from .processor import (
__all__ = [ __all__ = [
"BaseProcessingInfo", "BaseProcessingInfo",
"InputProcessingContext", "InputProcessingContext",
"TimingContext",
"BaseDummyInputsBuilder", "BaseDummyInputsBuilder",
"ProcessorInputs", "ProcessorInputs",
"BaseMultiModalProcessor", "BaseMultiModalProcessor",
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextvars
import threading
import time import time
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Generator, Mapping from collections.abc import Mapping
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass, field from dataclasses import dataclass, field
from functools import cached_property from functools import cached_property
...@@ -33,104 +31,53 @@ if TYPE_CHECKING: ...@@ -33,104 +31,53 @@ if TYPE_CHECKING:
from transformers.feature_extraction_utils import BatchFeature from transformers.feature_extraction_utils import BatchFeature
from transformers.processing_utils import ProcessorMixin from transformers.processing_utils import ProcessorMixin
from vllm.config import ModelConfig, ObservabilityConfig from vllm.config import ModelConfig
else: else:
PretrainedConfig = object PretrainedConfig = object
BatchFeature = object BatchFeature = object
ProcessorMixin = object ProcessorMixin = object
ModelConfig = object ModelConfig = object
ObservabilityConfig = object
logger = init_logger(__name__) logger = init_logger(__name__)
# Context variable holding the request ID for the current execution context.
# It defaults to `None`, i.e. no request ID has been set.
_request_id_context: contextvars.ContextVar[str | None] = contextvars.ContextVar(
"_request_id_context", default=None
)
def get_current_request_id() -> str | None:
    """Return the request ID bound to the current context, or `None` if unset."""
    current_id = _request_id_context.get()
    return current_id
@contextmanager
def set_request_id(request_id: str) -> Generator[None, None, None]:
    """Bind `request_id` to the current context for the duration of the block."""
    reset_token = _request_id_context.set(request_id)
    try:
        yield
    finally:
        # Restore whatever value was bound before this block was entered,
        # even if the body raised.
        _request_id_context.reset(reset_token)
@dataclass @dataclass
class MultiModalProcessorTimingStats: class TimingContext:
"""Per-request timing statistics for multimodal processor stages.""" """Helper class to record execution times during multi-modal processing."""
hf_processor_time: float = 0.0
"""Time spent in HuggingFace processor calls (seconds)."""
hashing_time: float = 0.0 enabled: bool = True
"""Time spent computing multimodal item hashes (seconds).""" """If disabled, `TimingContext.record` becomes a no-op."""
cache_lookup_time: float = 0.0 stage_secs: dict[str, float] = field(default_factory=dict)
"""Time spent in cache lookups and merges (seconds).""" """The execution time (in seconds) for each processing stage."""
prompt_update_time: float = 0.0 @property
"""Time spent applying prompt updates and finding placeholders (seconds).""" def total_secs(self) -> float:
return sum(self.stage_secs.values())
preprocessor_total_time: float = 0.0 @contextmanager
"""Total preprocessing time (seconds).""" def record(self, stage: str):
"""Record the execution time for a processing stage."""
if not self.enabled:
yield
return
def to_dict(self) -> dict[str, float]: start_time = time.perf_counter()
"""Convert stats to a dictionary for JSON serialization.""" try:
return { yield
"hf_processor_time": self.hf_processor_time, finally:
"hashing_time": self.hashing_time, elapsed = time.perf_counter() - start_time
"cache_lookup_time": self.cache_lookup_time, self.stage_secs.setdefault(stage, 0.0)
"prompt_update_time": self.prompt_update_time, self.stage_secs[stage] += elapsed
"preprocessor_total_time": self.preprocessor_total_time,
def get_stats_dict(self):
stats_dict = {
f"{stage}_secs": time_s for stage, time_s in self.stage_secs.items()
} }
stats_dict["preprocessor_total_secs"] = self.total_secs
return stats_dict
@contextmanager
def timed_preprocessor_operation(ctx: "InputProcessingContext", stage_name: str):
    """
    Time the wrapped block and add the elapsed seconds to the timing stats
    of the current request.

    The request ID is read from the context variable rather than being
    passed in. Nothing is recorded when `ctx` is `None`, when no request ID
    is set, or when `ctx` has no timing stats for that request.

    Args:
        ctx: The InputProcessingContext containing the timing stats registry.
        stage_name: Name of the stage being timed.
    """
    request_id = get_current_request_id()

    stats = None
    if ctx is not None and request_id is not None:
        stats = ctx.get_timing_stats(request_id)

    if stats is None:
        yield
        return

    started_at = time.perf_counter()
    try:
        yield
    finally:
        elapsed = time.perf_counter() - started_at

        # Each known stage accumulates into the `<stage_name>_time` field.
        known_stages = ("hf_processor", "hashing", "cache_lookup", "prompt_update")
        if stage_name in known_stages:
            stage_attr = f"{stage_name}_time"
            setattr(stats, stage_attr, getattr(stats, stage_attr) + elapsed)

        # The total is updated for every stage name, including unknown ones.
        stats.preprocessor_total_time += elapsed
_T = TypeVar("_T") _T = TypeVar("_T")
...@@ -151,21 +98,6 @@ class InputProcessingContext: ...@@ -151,21 +98,6 @@ class InputProcessingContext:
tokenizer: TokenizerLike | None tokenizer: TokenizerLike | None
"""The tokenizer used to tokenize the inputs.""" """The tokenizer used to tokenize the inputs."""
observability_config: "ObservabilityConfig | None" = field(
default=None, compare=False, repr=False
)
"""Configuration for observability features."""
timing_stats_registry: dict[str, MultiModalProcessorTimingStats] = field(
default_factory=dict, compare=False, repr=False
)
"""Registry for storing timing stats keyed by request_id."""
_timing_stats_registry_lock: threading.Lock = field(
default_factory=threading.Lock, compare=False, repr=False
)
"""Lock for thread-safe access to timing_stats_registry."""
def get_tokenizer(self) -> TokenizerLike: def get_tokenizer(self) -> TokenizerLike:
if self.tokenizer is None: if self.tokenizer is None:
raise ValueError( raise ValueError(
...@@ -379,71 +311,6 @@ class InputProcessingContext: ...@@ -379,71 +311,6 @@ class InputProcessingContext:
return self._postprocess_output(output) return self._postprocess_output(output)
def get_timing_stats(
    self, request_id: str
) -> MultiModalProcessorTimingStats | None:
    """
    Look up the timing stats registered for `request_id`.

    Returns `None` when multi-modal processor stats are not enabled in the
    observability config, or when no stats exist for the request.
    """
    obs_config = self.observability_config
    if obs_config is None or not obs_config.enable_mm_processor_stats:
        return None

    with self._timing_stats_registry_lock:
        return self.timing_stats_registry.get(request_id)
def create_timing_stats(self, request_id: str) -> MultiModalProcessorTimingStats:
    """
    Create a fresh stats object for `request_id` and store it in the registry.

    Intended to be called at the start of processing for a request.
    When multi-modal processor stats are not enabled, a new stats object
    is returned without being registered.

    Raises:
        ValueError: If stats are already registered for `request_id`.
    """
    obs_config = self.observability_config
    stats_enabled = obs_config is not None and obs_config.enable_mm_processor_stats
    if not stats_enabled:
        return MultiModalProcessorTimingStats()

    with self._timing_stats_registry_lock:
        registry = self.timing_stats_registry
        if request_id in registry:
            raise ValueError(
                f"Timing stats already exist for request_id: {request_id}"
            )

        new_stats = MultiModalProcessorTimingStats()
        registry[request_id] = new_stats
        return new_stats
def clear_timing_stats_registry(self) -> int:
    """
    Remove every entry from the timing stats registry.

    Returns:
        The number of entries removed, or 0 when multi-modal processor
        stats are not enabled.
    """
    obs_config = self.observability_config
    if obs_config is None or not obs_config.enable_mm_processor_stats:
        return 0

    with self._timing_stats_registry_lock:
        registry = self.timing_stats_registry
        num_cleared = len(registry)
        registry.clear()

    return num_cleared
def get_all_timing_stats(self) -> dict[str, dict[str, float]]:
    """
    Snapshot every registered stats object as a plain dictionary,
    keyed by request ID (for API endpoints).

    Returns an empty dictionary when multi-modal processor stats are
    not enabled.
    """
    obs_config = self.observability_config
    if obs_config is None or not obs_config.enable_mm_processor_stats:
        return {}

    snapshot: dict[str, dict[str, float]] = {}
    with self._timing_stats_registry_lock:
        for rid, stats in self.timing_stats_registry.items():
            snapshot[rid] = stats.to_dict()

    return snapshot
class BaseProcessingInfo: class BaseProcessingInfo:
"""Base class to provide the information necessary for data processing.""" """Base class to provide the information necessary for data processing."""
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Mapping from collections.abc import Mapping
from dataclasses import dataclass, field
from typing import Generic, TypeVar from typing import Generic, TypeVar
import numpy as np import numpy as np
...@@ -18,27 +17,14 @@ from vllm.config.multimodal import ( ...@@ -18,27 +17,14 @@ from vllm.config.multimodal import (
from vllm.logger import init_logger from vllm.logger import init_logger
from ..inputs import MultiModalDataDict from ..inputs import MultiModalDataDict
from ..parse import MultiModalDataItems
from .context import BaseProcessingInfo from .context import BaseProcessingInfo
from .inputs import ProcessorInputs
_I = TypeVar("_I", bound=BaseProcessingInfo) _I = TypeVar("_I", bound=BaseProcessingInfo)
logger = init_logger(__name__) logger = init_logger(__name__)
@dataclass
class ProcessorInputs:
    """
    Bundle of the keyword arguments passed to
    [`vllm.multimodal.processing.BaseMultiModalProcessor.apply`][].
    """

    # The prompt, either as text or as token IDs.
    prompt: str | list[int]

    # The parsed multi-modal data items.
    mm_items: MultiModalDataItems

    # Both mappings default to a fresh empty dict per instance.
    hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)
    tokenization_kwargs: Mapping[str, object] = field(default_factory=dict)
class BaseDummyInputsBuilder(ABC, Generic[_I]): class BaseDummyInputsBuilder(ABC, Generic[_I]):
""" """
Abstract base class that constructs the dummy data to profile Abstract base class that constructs the dummy data to profile
...@@ -101,7 +87,7 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]): ...@@ -101,7 +87,7 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
return ProcessorInputs( return ProcessorInputs(
prompt=dummy_text, prompt=dummy_text,
mm_items=dummy_mm_items, mm_data_items=dummy_mm_items,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
) )
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Mapping
from dataclasses import dataclass, field
from ..hasher import MultiModalHasher
from ..inputs import MultiModalHashes
from ..parse import MultiModalDataItems, MultiModalUUIDItems
@dataclass
class ProcessorInputs:
    """
    Bundle of the inputs consumed by
    [`vllm.multimodal.processing.BaseMultiModalProcessor.apply`][].
    """

    # The prompt, either as text or as token IDs.
    prompt: str | list[int]

    # The parsed multi-modal data items.
    mm_data_items: MultiModalDataItems

    # Optional user-provided identifiers for the multi-modal items.
    mm_uuid_items: MultiModalUUIDItems | None = None

    # Both mappings default to a fresh empty dict per instance.
    hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)
    tokenization_kwargs: Mapping[str, object] = field(default_factory=dict)

    def get_mm_hashes(self, model_id: str) -> MultiModalHashes:
        """
        Compute one hash per multi-modal item, grouped by modality.

        A provided UUID is used verbatim only when no
        `hf_processor_mm_kwargs` are set; in every other case a hash is
        computed from `model_id`, the item (or its UUID) and the kwargs.
        """
        uuids_by_modality = self.mm_uuid_items or {}
        mm_kwargs = self.hf_processor_mm_kwargs

        def compute_hash(modality: str, payload: object) -> str:
            return MultiModalHasher.hash_kwargs(
                model_id=model_id,
                **{modality: payload},
                **mm_kwargs,
            )

        result: MultiModalHashes = {}
        for modality, data_items in self.mm_data_items.items():
            if modality not in uuids_by_modality:
                # No overrides for this modality: hash every item's data.
                result[modality] = [
                    compute_hash(modality, item) for item in data_items
                ]
                continue

            uuid_items = uuids_by_modality[modality]
            hashes: list[str] = []
            for idx, item in enumerate(data_items.get_all_items_for_hash()):
                uuid_item = uuid_items[idx]

                if uuid_item is not None and not mm_kwargs:
                    # The provided ID can be used as-is.
                    hashes.append(uuid_item)
                    continue

                # NOTE: Even if a uuid_item is provided, a hash is still
                # computed when `hf_processor_mm_kwargs` is provided, because
                # the processed multimodal inputs can differ depending on the
                # processor kwargs. The provided ID is hashed in place of the
                # item data when available, for better performance.
                payload = item if uuid_item is None else uuid_item
                hashes.append(compute_hash(modality, payload))

            result[modality] = hashes

        return result
...@@ -23,7 +23,6 @@ from vllm.logger import init_logger ...@@ -23,7 +23,6 @@ from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.utils.collection_utils import flatten_2d_lists, full_groupby from vllm.utils.collection_utils import flatten_2d_lists, full_groupby
from ..hasher import MultiModalHasher
from ..inputs import ( from ..inputs import (
MultiModalEncDecInputs, MultiModalEncDecInputs,
MultiModalFieldConfig, MultiModalFieldConfig,
...@@ -42,12 +41,9 @@ from ..parse import ( ...@@ -42,12 +41,9 @@ from ..parse import (
MultiModalDataItems, MultiModalDataItems,
MultiModalUUIDItems, MultiModalUUIDItems,
) )
from .context import ( from .context import BaseProcessingInfo, TimingContext
BaseProcessingInfo,
get_current_request_id,
timed_preprocessor_operation,
)
from .dummy_inputs import BaseDummyInputsBuilder from .dummy_inputs import BaseDummyInputsBuilder
from .inputs import ProcessorInputs
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers.feature_extraction_utils import BatchFeature from transformers.feature_extraction_utils import BatchFeature
...@@ -1017,13 +1013,15 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1017,13 +1013,15 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_uuid_items: MultiModalUUIDItems | None = None, mm_uuid_items: MultiModalUUIDItems | None = None,
hf_processor_mm_kwargs: Mapping[str, object] | None = None, hf_processor_mm_kwargs: Mapping[str, object] | None = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
return self.apply( processor_inputs = ProcessorInputs(
prompt, prompt,
mm_items, mm_items,
mm_uuid_items, mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs or {},
) )
return self.apply(processor_inputs, TimingContext(enabled=False))
@abstractmethod @abstractmethod
def _get_mm_fields_config( def _get_mm_fields_config(
self, self,
...@@ -1139,12 +1137,11 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1139,12 +1137,11 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
Call the HF processor on the prompt text and Call the HF processor on the prompt text and
associated multi-modal data. associated multi-modal data.
""" """
with timed_preprocessor_operation(self.info.ctx, "hf_processor"): return self.info.ctx.call_hf_processor(
return self.info.ctx.call_hf_processor( self.info.get_hf_processor(**mm_kwargs),
self.info.get_hf_processor(**mm_kwargs), dict(text=prompt, **mm_data),
dict(text=prompt, **mm_data), dict(**mm_kwargs, **tok_kwargs),
dict(**mm_kwargs, **tok_kwargs), )
)
def _hf_processor_applies_updates( def _hf_processor_applies_updates(
self, self,
...@@ -1306,60 +1303,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1306,60 +1303,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
return prompt_ids, mm_processed_data, False return prompt_ids, mm_processed_data, False
def _hash_mm_items(
    self,
    mm_data_items: MultiModalDataItems,
    mm_uuid_items: MultiModalUUIDItems | None,
    hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalHashes:
    """
    Compute one hash per multi-modal item, grouped by modality.

    A provided UUID is used verbatim only when no
    `hf_processor_mm_kwargs` are set; in every other case a hash is
    computed from the model ID, the item (or its UUID) and the kwargs.
    """
    model_id = self.info.model_id
    uuids_by_modality = {} if mm_uuid_items is None else mm_uuid_items

    def compute_hash(modality: str, payload: object) -> str:
        return MultiModalHasher.hash_kwargs(
            model_id=model_id,
            **{modality: payload},
            **hf_processor_mm_kwargs,
        )

    result: MultiModalHashes = {}
    for modality, data_items in mm_data_items.items():
        if modality not in uuids_by_modality:
            # No overrides for this modality: hash every item's data.
            result[modality] = [
                compute_hash(modality, item) for item in data_items
            ]
            continue

        uuid_items = uuids_by_modality[modality]
        hashes: list[str] = []
        for idx, item in enumerate(data_items.get_all_items_for_hash()):
            uuid_item = uuid_items[idx]

            if uuid_item is not None and not hf_processor_mm_kwargs:
                # The provided ID can be used as-is.
                hashes.append(uuid_item)
                continue

            # NOTE: Even if a uuid_item is provided, a hash is still
            # computed when `hf_processor_mm_kwargs` is provided, because
            # the processed multimodal inputs can differ depending on the
            # processor kwargs. The provided ID is hashed in place of the
            # item data when available, for better performance.
            payload = item if uuid_item is None else uuid_item
            hashes.append(compute_hash(modality, payload))

        result[modality] = hashes

    return result
def _get_cache_missing_items( def _get_cache_missing_items(
self, self,
cache: BaseMultiModalProcessorCache, cache: BaseMultiModalProcessorCache,
...@@ -1461,40 +1404,36 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1461,40 +1404,36 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
def _apply_hf_processor( def _apply_hf_processor(
self, self,
prompt: str | list[int], inputs: ProcessorInputs,
mm_data_items: MultiModalDataItems, timing_ctx: TimingContext,
mm_uuid_items: MultiModalUUIDItems | None,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
) -> tuple[list[int], MultiModalProcessingInfo, bool]: ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
( with timing_ctx.record("apply_hf_processor"):
prompt_ids, (
mm_processed_data, prompt_ids,
is_update_applied, mm_processed_data,
) = self._apply_hf_processor_main( is_update_applied,
prompt=prompt, ) = self._apply_hf_processor_main(
mm_items=mm_data_items, prompt=inputs.prompt,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, mm_items=inputs.mm_data_items,
tokenization_kwargs=tokenization_kwargs, hf_processor_mm_kwargs=inputs.hf_processor_mm_kwargs,
enable_hf_prompt_update=True, tokenization_kwargs=inputs.tokenization_kwargs,
) enable_hf_prompt_update=True,
)
mm_kwargs = MultiModalKwargsItems.from_hf_inputs( mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
mm_processed_data, mm_processed_data,
self._get_mm_fields_config(mm_processed_data, hf_processor_mm_kwargs), self._get_mm_fields_config(
mm_processed_data, inputs.hf_processor_mm_kwargs
),
) )
# Use overrides if provided; fallback to data-dependent hashing. # Use overrides if provided; fallback to data-dependent hashing.
with timed_preprocessor_operation(self.info.ctx, "hashing"): with timing_ctx.record("get_mm_hashes"):
mm_hashes = self._hash_mm_items( mm_hashes = inputs.get_mm_hashes(self.info.model_id)
mm_data_items,
mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
)
mm_prompt_updates = self._get_mm_prompt_updates( mm_prompt_updates = self._get_mm_prompt_updates(
mm_data_items, inputs.mm_data_items,
hf_processor_mm_kwargs, inputs.hf_processor_mm_kwargs,
mm_kwargs, mm_kwargs,
) )
...@@ -1508,11 +1447,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1508,11 +1447,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
def _cached_apply_hf_processor( def _cached_apply_hf_processor(
self, self,
prompt: str | list[int], inputs: ProcessorInputs,
mm_data_items: MultiModalDataItems, timing_ctx: TimingContext,
mm_uuid_items: MultiModalUUIDItems | None,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
) -> tuple[list[int], MultiModalProcessingInfo, bool]: ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
""" """
Apply the HF processor on the full prompt text, Apply the HF processor on the full prompt text,
...@@ -1520,59 +1456,50 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1520,59 +1456,50 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
""" """
cache = self.cache cache = self.cache
_, passthrough_data = self._get_hf_mm_data(mm_data_items) _, passthrough_data = self._get_hf_mm_data(inputs.mm_data_items)
if cache is None or passthrough_data: if cache is None or passthrough_data:
return self._apply_hf_processor( return self._apply_hf_processor(inputs, timing_ctx)
prompt=prompt,
mm_data_items=mm_data_items,
mm_uuid_items=mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
with timed_preprocessor_operation(self.info.ctx, "hashing"): with timing_ctx.record("get_mm_hashes"):
mm_hashes = self._hash_mm_items( mm_hashes = inputs.get_mm_hashes(self.info.model_id)
mm_data_items,
mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
)
with timed_preprocessor_operation(self.info.ctx, "cache_lookup"): with timing_ctx.record("get_cache_missing_items"):
mm_is_cached, mm_missing_data_items = self._get_cache_missing_items( mm_is_cached, mm_missing_data_items = self._get_cache_missing_items(
cache=cache, cache=cache,
mm_data_items=mm_data_items, mm_data_items=inputs.mm_data_items,
mm_hashes=mm_hashes, mm_hashes=mm_hashes,
) )
# NOTE: `prompt` does not correspond to `mm_missing_data_items`, # NOTE: `prompt` does not correspond to `mm_missing_data_items`,
# so we can't apply prompt updates until the new multimodal # so we can't apply prompt updates until the new multimodal
# items are combined with the cached multimodal items # items are combined with the cached multimodal items
( with timing_ctx.record("apply_hf_processor"):
prompt_ids, (
mm_missing_processed_data, prompt_ids,
is_update_applied, mm_missing_processed_data,
) = self._apply_hf_processor_main( is_update_applied,
prompt=prompt, ) = self._apply_hf_processor_main(
mm_items=mm_missing_data_items, prompt=inputs.prompt,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, mm_items=mm_missing_data_items,
tokenization_kwargs=tokenization_kwargs, hf_processor_mm_kwargs=inputs.hf_processor_mm_kwargs,
enable_hf_prompt_update=False, tokenization_kwargs=inputs.tokenization_kwargs,
) enable_hf_prompt_update=False,
)
mm_missing_kwargs = MultiModalKwargsItems.from_hf_inputs( mm_missing_kwargs = MultiModalKwargsItems.from_hf_inputs(
mm_missing_processed_data, mm_missing_processed_data,
self._get_mm_fields_config( self._get_mm_fields_config(
mm_missing_processed_data, hf_processor_mm_kwargs mm_missing_processed_data, inputs.hf_processor_mm_kwargs
), ),
) )
mm_missing_prompt_updates = self._get_mm_prompt_updates( mm_missing_prompt_updates = self._get_mm_prompt_updates(
mm_missing_data_items, mm_missing_data_items,
hf_processor_mm_kwargs, inputs.hf_processor_mm_kwargs,
mm_missing_kwargs, mm_missing_kwargs,
) )
with timed_preprocessor_operation(self.info.ctx, "cache_lookup"): with timing_ctx.record("merge_mm_kwargs"):
mm_kwargs, mm_prompt_updates = self._merge_mm_kwargs( mm_kwargs, mm_prompt_updates = self._merge_mm_kwargs(
cache, cache,
mm_hashes=mm_hashes, mm_hashes=mm_hashes,
...@@ -1742,11 +1669,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1742,11 +1669,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
def apply( def apply(
self, self,
prompt: str | list[int], inputs: ProcessorInputs,
mm_items: MultiModalDataItems, timing_ctx: TimingContext,
mm_uuid_items: MultiModalUUIDItems | None = None,
hf_processor_mm_kwargs: Mapping[str, object] | None = None,
tokenization_kwargs: Mapping[str, object] | None = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
""" """
Process multi-modal inputs to be used in vLLM. Process multi-modal inputs to be used in vLLM.
...@@ -1761,31 +1685,16 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1761,31 +1685,16 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
3. Extract information about the placeholder tokens from the 3. Extract information about the placeholder tokens from the
processed token IDs. processed token IDs.
""" """
request_id = get_current_request_id()
if request_id is not None:
self.info.ctx.create_timing_stats(request_id)
if hf_processor_mm_kwargs is None:
hf_processor_mm_kwargs = {}
if tokenization_kwargs is None:
tokenization_kwargs = {}
( (
prompt_ids, prompt_ids,
mm_info, mm_info,
is_update_applied, is_update_applied,
) = self._cached_apply_hf_processor( ) = self._cached_apply_hf_processor(inputs, timing_ctx)
prompt,
mm_items,
mm_uuid_items,
hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
# NOTE: tokenization_kwargs are not required to init processor # NOTE: tokenization_kwargs are not required to init processor
with timed_preprocessor_operation(self.info.ctx, "prompt_update"): with timing_ctx.record("apply_prompt_updates"):
prompt_ids, mm_placeholders = self._maybe_apply_prompt_updates( prompt_ids, mm_placeholders = self._maybe_apply_prompt_updates(
mm_items=mm_items, mm_items=inputs.mm_data_items,
prompt_ids=prompt_ids, prompt_ids=prompt_ids,
mm_kwargs=mm_info.kwargs, mm_kwargs=mm_info.kwargs,
mm_prompt_updates=mm_info.prompt_updates, mm_prompt_updates=mm_info.prompt_updates,
...@@ -1851,11 +1760,8 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -1851,11 +1760,8 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
def apply( def apply(
self, self,
prompt: str | list[int], inputs: ProcessorInputs,
mm_items: MultiModalDataItems, timing_ctx: TimingContext,
mm_uuid_items: MultiModalUUIDItems | None = None,
hf_processor_mm_kwargs: Mapping[str, object] | None = None,
tokenization_kwargs: Mapping[str, object] | None = None,
) -> MultiModalEncDecInputs: ) -> MultiModalEncDecInputs:
""" """
Process multi-modal inputs to be used in vLLM. Process multi-modal inputs to be used in vLLM.
...@@ -1864,17 +1770,22 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -1864,17 +1770,22 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
2. Apply the HF processor on encoder prompt. 2. Apply the HF processor on encoder prompt.
3. Copy the input prompt text as decoder prompt inputs. 3. Copy the input prompt text as decoder prompt inputs.
""" """
encoder_prompt = self.create_encoder_prompt(prompt, mm_items) encoder_prompt = self.create_encoder_prompt(
encoder_inputs = super().apply( inputs.prompt,
inputs.mm_data_items,
)
encoder_processor_inputs = ProcessorInputs(
encoder_prompt, encoder_prompt,
mm_items, inputs.mm_data_items,
mm_uuid_items, inputs.mm_uuid_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=inputs.hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=inputs.tokenization_kwargs,
) )
encoder_inputs = super().apply(encoder_processor_inputs, timing_ctx)
return self._get_enc_dec_inputs( return self._get_enc_dec_inputs(
prompt=prompt, prompt=inputs.prompt,
mm_items=mm_items, mm_items=inputs.mm_data_items,
encoder_inputs=encoder_inputs, encoder_inputs=encoder_inputs,
) )
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading
from collections import defaultdict
from collections.abc import Mapping from collections.abc import Mapping
from dataclasses import dataclass from dataclasses import dataclass
from multiprocessing.synchronize import Lock as LockType from multiprocessing.synchronize import Lock as LockType
from typing import TYPE_CHECKING, Generic, Literal, Protocol, TypeVar, cast from typing import TYPE_CHECKING, Generic, Literal, Protocol, TypeVar, cast
from vllm.config.observability import ObservabilityConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
...@@ -24,6 +25,7 @@ from .processing import ( ...@@ -24,6 +25,7 @@ from .processing import (
BaseMultiModalProcessor, BaseMultiModalProcessor,
BaseProcessingInfo, BaseProcessingInfo,
InputProcessingContext, InputProcessingContext,
TimingContext,
) )
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -174,32 +176,26 @@ class MultiModalRegistry: ...@@ -174,32 +176,26 @@ class MultiModalRegistry:
def _create_processing_ctx( def _create_processing_ctx(
self, self,
model_config: "ModelConfig", model_config: "ModelConfig",
observability_config: "ObservabilityConfig | None" = None,
tokenizer: TokenizerLike | None = None, tokenizer: TokenizerLike | None = None,
) -> InputProcessingContext: ) -> InputProcessingContext:
if tokenizer is None: if tokenizer is None:
tokenizer = cached_tokenizer_from_config(model_config) tokenizer = cached_tokenizer_from_config(model_config)
return InputProcessingContext( return InputProcessingContext(model_config, tokenizer)
model_config, tokenizer, observability_config=observability_config
)
def _create_processing_info( def _create_processing_info(
self, self,
model_config: "ModelConfig", model_config: "ModelConfig",
observability_config: "ObservabilityConfig | None" = None,
*,
tokenizer: TokenizerLike | None = None, tokenizer: TokenizerLike | None = None,
) -> BaseProcessingInfo: ) -> BaseProcessingInfo:
model_cls = self._get_model_cls(model_config) model_cls = self._get_model_cls(model_config)
factories = model_cls._processor_factory factories = model_cls._processor_factory
ctx = self._create_processing_ctx(model_config, observability_config, tokenizer) ctx = self._create_processing_ctx(model_config, tokenizer)
return factories.info(ctx) return factories.info(ctx)
def create_processor( def create_processor(
self, self,
model_config: "ModelConfig", model_config: "ModelConfig",
observability_config: "ObservabilityConfig | None" = None,
*, *,
tokenizer: TokenizerLike | None = None, tokenizer: TokenizerLike | None = None,
cache: BaseMultiModalProcessorCache | None = None, cache: BaseMultiModalProcessorCache | None = None,
...@@ -213,7 +209,7 @@ class MultiModalRegistry: ...@@ -213,7 +209,7 @@ class MultiModalRegistry:
model_cls = self._get_model_cls(model_config) model_cls = self._get_model_cls(model_config)
factories = model_cls._processor_factory factories = model_cls._processor_factory
ctx = self._create_processing_ctx(model_config, observability_config, tokenizer) ctx = self._create_processing_ctx(model_config, tokenizer)
return factories.build_processor(ctx, cache=cache) return factories.build_processor(ctx, cache=cache)
...@@ -242,10 +238,8 @@ class MultiModalRegistry: ...@@ -242,10 +238,8 @@ class MultiModalRegistry:
mm_options=mm_config.limit_per_prompt, mm_options=mm_config.limit_per_prompt,
) )
mm_inputs = processor.apply( mm_inputs = processor.apply(
prompt=processor_inputs.prompt, processor_inputs,
mm_items=processor_inputs.mm_items, timing_ctx=TimingContext(enabled=False),
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
tokenization_kwargs=processor_inputs.tokenization_kwargs,
) )
prompt_token_ids = mm_inputs["prompt_token_ids"] prompt_token_ids = mm_inputs["prompt_token_ids"]
...@@ -335,3 +329,34 @@ class MultiModalRegistry: ...@@ -335,3 +329,34 @@ class MultiModalRegistry:
return ShmObjectStoreReceiverCache(vllm_config, shared_worker_lock) return ShmObjectStoreReceiverCache(vllm_config, shared_worker_lock)
else: else:
raise ValueError(f"Unknown cache type: {cache_type!r}") raise ValueError(f"Unknown cache type: {cache_type!r}")
class MultiModalTimingRegistry:
    """
    Hands out one `TimingContext` per request ID and reports their stats.

    The registry is only active when `enable_mm_processor_stats` is set on
    the observability config; otherwise it hands out disabled contexts and
    reports nothing.
    """

    def __init__(self, observability_config: "ObservabilityConfig | None") -> None:
        super().__init__()

        self._enabled = bool(
            observability_config and observability_config.enable_mm_processor_stats
        )
        if self._enabled:
            # The lock and the per-request mapping only exist when enabled.
            self._lock = threading.Lock()
            self._ctx_by_request_id = defaultdict[str, TimingContext](TimingContext)

    def get(self, request_id: str) -> TimingContext:
        """
        Return the timing context for `request_id`, creating it on first use.

        When the registry is disabled, a new disabled context is returned
        and nothing is stored.
        """
        if self._enabled:
            with self._lock:
                return self._ctx_by_request_id[request_id]

        return TimingContext(enabled=False)

    def stat(self) -> dict[str, dict[str, float]]:
        """
        Return the stats of every tracked request, keyed by request ID,
        and then clear the registry.

        Returns an empty dictionary when the registry is disabled.
        """
        if not self._enabled:
            return {}

        snapshot: dict[str, dict[str, float]] = {}
        with self._lock:
            for req_id, ctx in self._ctx_by_request_id.items():
                snapshot[req_id] = ctx.get_stats_dict()

            self._ctx_by_request_id.clear()

        return snapshot
...@@ -85,13 +85,13 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -85,13 +85,13 @@ class BaseRenderer(ABC, Generic[_T]):
self._mm_cache_stats: MultiModalCacheStats | None = None self._mm_cache_stats: MultiModalCacheStats | None = None
if config.model_config.is_multimodal_model: if config.model_config.is_multimodal_model:
from vllm.multimodal import MULTIMODAL_REGISTRY as mm_registry from vllm.multimodal import MULTIMODAL_REGISTRY as mm_registry
from vllm.multimodal.registry import MultiModalTimingRegistry
mm_processor_cache = mm_registry.processor_cache_from_config(config) mm_processor_cache = mm_registry.processor_cache_from_config(config)
with set_default_torch_num_threads(): with set_default_torch_num_threads():
self.mm_processor = mm_registry.create_processor( self.mm_processor = mm_registry.create_processor(
config.model_config, config.model_config,
config.observability_config,
tokenizer=tokenizer, tokenizer=tokenizer,
cache=mm_processor_cache, cache=mm_processor_cache,
) )
...@@ -102,6 +102,9 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -102,6 +102,9 @@ class BaseRenderer(ABC, Generic[_T]):
# This is used to generate internal request ID for MM processing # This is used to generate internal request ID for MM processing
# It has no relation to the request ID for engine core # It has no relation to the request ID for engine core
self._mm_req_counter = AtomicCounter() self._mm_req_counter = AtomicCounter()
self._mm_timing_registry = MultiModalTimingRegistry(
config.observability_config
)
def get_tokenizer(self) -> _T: def get_tokenizer(self) -> _T:
tokenizer = self.tokenizer tokenizer = self.tokenizer
...@@ -534,7 +537,7 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -534,7 +537,7 @@ class BaseRenderer(ABC, Generic[_T]):
tokenization_kwargs: dict[str, Any] | None, tokenization_kwargs: dict[str, Any] | None,
) -> "MultiModalInputs": ) -> "MultiModalInputs":
from vllm.multimodal.parse import parse_mm_uuids from vllm.multimodal.parse import parse_mm_uuids
from vllm.multimodal.processing.context import set_request_id from vllm.multimodal.processing import ProcessorInputs as MMProcessorInputs
mm_req_id = f"renderer-mm-{self._mm_req_counter.inc(1)}" mm_req_id = f"renderer-mm-{self._mm_req_counter.inc(1)}"
...@@ -543,18 +546,21 @@ class BaseRenderer(ABC, Generic[_T]): ...@@ -543,18 +546,21 @@ class BaseRenderer(ABC, Generic[_T]):
mm_data_items = mm_processor.info.parse_mm_data(mm_data) mm_data_items = mm_processor.info.parse_mm_data(mm_data)
mm_uuid_items = parse_mm_uuids(mm_uuids) mm_uuid_items = parse_mm_uuids(mm_uuids)
mm_uuids = self._process_mm_uuids( mm_uuid_items = self._process_mm_uuids(
mm_data, mm_data_items, mm_uuid_items, mm_req_id mm_data, mm_data_items, mm_uuid_items, mm_req_id
) )
with set_request_id(mm_req_id), set_default_torch_num_threads(): mm_processor_inputs = MMProcessorInputs(
mm_inputs = mm_processor.apply( prompt,
prompt, mm_data_items,
mm_data_items, mm_uuid_items,
mm_uuid_items, hf_processor_mm_kwargs=mm_processor_kwargs or {},
hf_processor_mm_kwargs=mm_processor_kwargs, tokenization_kwargs=tokenization_kwargs or {},
tokenization_kwargs=tokenization_kwargs, )
) mm_timing_ctx = self._mm_timing_registry.get(mm_req_id)
with set_default_torch_num_threads():
mm_inputs = mm_processor.apply(mm_processor_inputs, mm_timing_ctx)
self.update_mm_cache_stats() self.update_mm_cache_stats()
......
...@@ -6272,7 +6272,7 @@ class GPUModelRunner( ...@@ -6272,7 +6272,7 @@ class GPUModelRunner(
self.encoder_timing_registry[req_id] = EncoderTimingStats() self.encoder_timing_registry[req_id] = EncoderTimingStats()
stats = self.encoder_timing_registry[req_id] stats = self.encoder_timing_registry[req_id]
stats.encoder_forward_time += per_request_time stats.encoder_forward_secs += per_request_time
stats.num_encoder_calls += 1 stats.num_encoder_calls += 1
...@@ -6280,7 +6280,7 @@ class GPUModelRunner( ...@@ -6280,7 +6280,7 @@ class GPUModelRunner(
class EncoderTimingStats: class EncoderTimingStats:
"""Per-request timing statistics for encoder forward pass.""" """Per-request timing statistics for encoder forward pass."""
encoder_forward_time: float = 0.0 encoder_forward_secs: float = 0.0
"""Time spent in vision encoder forward pass (seconds).""" """Time spent in vision encoder forward pass (seconds)."""
num_encoder_calls: int = 0 num_encoder_calls: int = 0
...@@ -6288,6 +6288,6 @@ class EncoderTimingStats: ...@@ -6288,6 +6288,6 @@ class EncoderTimingStats:
def to_dict(self) -> dict[str, float | int]: def to_dict(self) -> dict[str, float | int]:
return { return {
"encoder_forward_time": self.encoder_forward_time, "encoder_forward_secs": self.encoder_forward_secs,
"num_encoder_calls": self.num_encoder_calls, "num_encoder_calls": self.num_encoder_calls,
} }
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment