Unverified commit 9ea07b41, authored by Cyrus Leung and committed by GitHub
Browse files

[1/N] Reorganize multimodal processing code (#32327)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 552b2629
...@@ -35,13 +35,13 @@ from vllm.multimodal.inputs import ( ...@@ -35,13 +35,13 @@ from vllm.multimodal.inputs import (
) )
from vllm.multimodal.parse import ImageSize, MultiModalDataItems from vllm.multimodal.parse import ImageSize, MultiModalDataItems
from vllm.multimodal.processing import ( from vllm.multimodal.processing import (
BaseDummyInputsBuilder,
BaseMultiModalProcessor, BaseMultiModalProcessor,
BaseProcessingInfo, BaseProcessingInfo,
PromptReplacement, PromptReplacement,
PromptUpdate, PromptUpdate,
PromptUpdateDetails, PromptUpdateDetails,
) )
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.configs import Step3VisionEncoderConfig from vllm.transformers_utils.configs import Step3VisionEncoderConfig
......
...@@ -34,13 +34,13 @@ from vllm.multimodal.parse import ( ...@@ -34,13 +34,13 @@ from vllm.multimodal.parse import (
MultiModalDataItems, MultiModalDataItems,
) )
from vllm.multimodal.processing import ( from vllm.multimodal.processing import (
BaseDummyInputsBuilder,
BaseMultiModalProcessor, BaseMultiModalProcessor,
BaseProcessingInfo, BaseProcessingInfo,
InputProcessingContext, InputProcessingContext,
PromptReplacement, PromptReplacement,
PromptUpdate, PromptUpdate,
) )
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
......
...@@ -56,11 +56,11 @@ from vllm.multimodal.parse import ( ...@@ -56,11 +56,11 @@ from vllm.multimodal.parse import (
MultiModalDataParser, MultiModalDataParser,
) )
from vllm.multimodal.processing import ( from vllm.multimodal.processing import (
BaseDummyInputsBuilder,
BaseMultiModalProcessor, BaseMultiModalProcessor,
BaseProcessingInfo, BaseProcessingInfo,
PromptUpdate, PromptUpdate,
) )
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import IsAttentionFree, MultiModalEmbeddings, SupportsMultiModal from .interfaces import IsAttentionFree, MultiModalEmbeddings, SupportsMultiModal
......
...@@ -35,8 +35,11 @@ from vllm.multimodal.inputs import ( ...@@ -35,8 +35,11 @@ from vllm.multimodal.inputs import (
PlaceholderRange, PlaceholderRange,
) )
from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems
from vllm.multimodal.processing import BaseMultiModalProcessor, BaseProcessingInfo from vllm.multimodal.processing import (
from vllm.multimodal.profiling import BaseDummyInputsBuilder BaseDummyInputsBuilder,
BaseMultiModalProcessor,
BaseProcessingInfo,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
......
...@@ -36,12 +36,12 @@ from vllm.multimodal.inputs import ( ...@@ -36,12 +36,12 @@ from vllm.multimodal.inputs import (
) )
from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
from vllm.multimodal.processing import ( from vllm.multimodal.processing import (
BaseDummyInputsBuilder,
BaseMultiModalProcessor, BaseMultiModalProcessor,
BaseProcessingInfo, BaseProcessingInfo,
PromptReplacement, PromptReplacement,
PromptUpdate, PromptUpdate,
) )
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.ultravox import UltravoxConfig from vllm.transformers_utils.configs.ultravox import UltravoxConfig
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
......
...@@ -47,14 +47,14 @@ from vllm.multimodal.parse import ( ...@@ -47,14 +47,14 @@ from vllm.multimodal.parse import (
MultiModalDataItems, MultiModalDataItems,
MultiModalDataParser, MultiModalDataParser,
) )
from vllm.multimodal.processing import ( from vllm.multimodal.processing import BaseDummyInputsBuilder, ProcessorInputs
from vllm.multimodal.processing.processor import (
BaseMultiModalProcessor, BaseMultiModalProcessor,
BaseProcessingInfo, BaseProcessingInfo,
MultiModalProcessingInfo, MultiModalProcessingInfo,
PromptReplacement, PromptReplacement,
PromptUpdate, PromptUpdate,
) )
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.tokenizers import cached_tokenizer_from_config from vllm.tokenizers import cached_tokenizer_from_config
from vllm.tokenizers.mistral import MistralTokenizer from vllm.tokenizers.mistral import MistralTokenizer
......
...@@ -30,11 +30,11 @@ from vllm.multimodal.inputs import ( ...@@ -30,11 +30,11 @@ from vllm.multimodal.inputs import (
MultiModalKwargsOptionalItems, MultiModalKwargsOptionalItems,
) )
from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import ( from vllm.multimodal.processing import BaseDummyInputsBuilder
from vllm.multimodal.processing.processor import (
MultiModalPromptUpdates, MultiModalPromptUpdates,
PlaceholderFeaturesInfo, PlaceholderFeaturesInfo,
) )
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.tokenizers import cached_tokenizer_from_config from vllm.tokenizers import cached_tokenizer_from_config
......
...@@ -49,12 +49,12 @@ from vllm.multimodal.inputs import ( ...@@ -49,12 +49,12 @@ from vllm.multimodal.inputs import (
) )
from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
from vllm.multimodal.processing import ( from vllm.multimodal.processing import (
BaseDummyInputsBuilder,
BaseProcessingInfo, BaseProcessingInfo,
EncDecMultiModalProcessor, EncDecMultiModalProcessor,
PromptReplacement, PromptReplacement,
PromptUpdate, PromptUpdate,
) )
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.transformers_utils.processor import cached_processor_from_config from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.utils.jsontree import json_map_leaves from vllm.utils.jsontree import json_map_leaves
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
......
...@@ -34,7 +34,7 @@ from .inputs import ( ...@@ -34,7 +34,7 @@ from .inputs import (
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import ModelConfig, VllmConfig from vllm.config import ModelConfig, VllmConfig
from .processing import ResolvedPromptUpdate from .processing.processor import ResolvedPromptUpdate
from .registry import MultiModalRegistry from .registry import MultiModalRegistry
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -33,8 +33,6 @@ if TYPE_CHECKING: ...@@ -33,8 +33,6 @@ if TYPE_CHECKING:
from transformers.feature_extraction_utils import BatchFeature from transformers.feature_extraction_utils import BatchFeature
from .base import MediaWithBytes from .base import MediaWithBytes
from .processing import MultiModalHashes
else: else:
torch = LazyLoader("torch", globals(), "torch") torch = LazyLoader("torch", globals(), "torch")
...@@ -979,9 +977,15 @@ MultiModalKwargsOptionalItems: TypeAlias = ( ...@@ -979,9 +977,15 @@ MultiModalKwargsOptionalItems: TypeAlias = (
) )
MultiModalHashes = dict[str, list[str]]
"""
A dictionary containing per-item hashes for each modality.
"""
MultiModalPlaceholderDict: TypeAlias = Mapping[str, Sequence[PlaceholderRange]] MultiModalPlaceholderDict: TypeAlias = Mapping[str, Sequence[PlaceholderRange]]
""" """
A dictionary containing placeholder ranges for each modality. A dictionary containing per-item placeholder ranges for each modality.
""" """
...@@ -1001,10 +1005,10 @@ class MultiModalInputs(TypedDict): ...@@ -1001,10 +1005,10 @@ class MultiModalInputs(TypedDict):
mm_kwargs: MultiModalKwargsOptionalItems mm_kwargs: MultiModalKwargsOptionalItems
"""Keyword arguments to be directly passed to the model after batching.""" """Keyword arguments to be directly passed to the model after batching."""
mm_hashes: "MultiModalHashes" mm_hashes: MultiModalHashes
"""The hashes of the multi-modal data.""" """The hashes of the multi-modal data."""
mm_placeholders: "MultiModalPlaceholderDict" mm_placeholders: MultiModalPlaceholderDict
""" """
For each modality, information about the placeholder tokens in For each modality, information about the placeholder tokens in
`prompt_token_ids`. `prompt_token_ids`.
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .context import BaseProcessingInfo, InputProcessingContext
from .dummy_inputs import BaseDummyInputsBuilder, ProcessorInputs
from .processor import (
BaseMultiModalProcessor,
EncDecMultiModalProcessor,
PromptIndexTargets,
PromptInsertion,
PromptReplacement,
PromptUpdate,
PromptUpdateDetails,
)
# Public API of the `processing` package: every name below is re-exported
# from the `.context`, `.dummy_inputs` or `.processor` submodule imported above.
__all__ = [
"BaseProcessingInfo",
"InputProcessingContext",
"BaseDummyInputsBuilder",
"ProcessorInputs",
"BaseMultiModalProcessor",
"EncDecMultiModalProcessor",
"PromptUpdate",
"PromptIndexTargets",
"PromptUpdateDetails",
"PromptInsertion",
"PromptReplacement",
]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextvars
import threading
import time
from abc import abstractmethod
from collections.abc import Generator, Mapping
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import (
TYPE_CHECKING,
Any,
overload,
)
import torch
from typing_extensions import TypeVar
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.utils.func_utils import get_allowed_kwarg_only_overrides
from vllm.utils.jsontree import JSONTree, json_map_leaves
if TYPE_CHECKING:
from transformers.configuration_utils import PretrainedConfig
from transformers.feature_extraction_utils import BatchFeature
from transformers.processing_utils import ProcessorMixin
from vllm.config import ModelConfig, ObservabilityConfig
else:
PretrainedConfig = object
BatchFeature = object
ProcessorMixin = object
ModelConfig = object
ObservabilityConfig = object
logger = init_logger(__name__)
# Context variable holding the request ID for the current execution context.
# It defaults to `None`; it is read by `get_current_request_id()` and
# set/reset by the `set_request_id()` context manager.
_request_id_context: contextvars.ContextVar[str | None] = contextvars.ContextVar(
"_request_id_context", default=None
)
def get_current_request_id() -> str | None:
    """
    Return the request ID bound to the current context.

    Returns:
        The value stored in `_request_id_context`, or `None` when no
        request ID has been set.
    """
    current_request_id = _request_id_context.get()
    return current_request_id
@contextmanager
def set_request_id(request_id: str) -> Generator[None, None, None]:
    """
    Bind `request_id` to the current context for the duration of a
    `with` block.

    The previous value of the context variable is restored on exit,
    including when the body of the `with` block raises.

    Args:
        request_id: The request ID to expose via `get_current_request_id()`.
    """
    reset_token = _request_id_context.set(request_id)
    try:
        yield
    finally:
        # Undo exactly this `set()` call, restoring whatever was there before.
        _request_id_context.reset(reset_token)
@dataclass
class MultiModalProcessorTimingStats:
    """Timing statistics, per request, for each multimodal processor stage."""

    hf_processor_time: float = 0.0
    """Seconds spent inside HuggingFace processor calls."""

    hashing_time: float = 0.0
    """Seconds spent computing hashes of multimodal items."""

    cache_lookup_time: float = 0.0
    """Seconds spent on cache lookups and merges."""

    prompt_update_time: float = 0.0
    """Seconds spent applying prompt updates and locating placeholders."""

    total_time: float = 0.0
    """Total processing time in seconds."""

    def to_dict(self) -> dict[str, float]:
        """Return the stats as a plain dictionary suitable for JSON output."""
        exported_fields = (
            "hf_processor_time",
            "hashing_time",
            "cache_lookup_time",
            "prompt_update_time",
            "total_time",
        )
        return {name: getattr(self, name) for name in exported_fields}
def get_timing_stats_from_engine_client(
    engine_client: Any,
) -> dict[str, dict[str, float]]:
    """
    Collect every timing-stats entry reachable from an engine client.

    The lookup is best-effort: any `AttributeError` or `RuntimeError` raised
    while walking the client's attributes results in an empty dictionary.

    Args:
        engine_client: The engine client that has input_processor.

    Returns:
        A dictionary mapping request_id to stats dict, or `{}` when stats
        are disabled or the multimodal processor context cannot be reached.
    """
    tolerated_errors = (AttributeError, RuntimeError)

    try:
        observability_config = engine_client.vllm_config.observability_config
        stats_enabled = bool(observability_config.enable_mm_processor_stats)
    except tolerated_errors:
        return {}
    if not stats_enabled:
        return {}

    try:
        input_preprocessor = engine_client.input_processor.input_preprocessor
        if not hasattr(input_preprocessor, "_get_mm_processor"):
            return {}

        mm_processor = input_preprocessor._get_mm_processor()
        if mm_processor is None or not hasattr(mm_processor, "info"):
            return {}

        return mm_processor.info.ctx.get_all_timing_stats()
    except tolerated_errors:
        return {}
@contextmanager
def timed_operation(ctx: "InputProcessingContext", stage_name: str):
    """
    Time the body of a `with` block and record it in the context's stats.

    The request ID is read from the context variable via
    `get_current_request_id()`, so callers do not pass it in. When there is
    no context, no current request ID, or no stats entry for the request,
    the body runs untimed.

    Args:
        ctx: The InputProcessingContext containing the timing stats registry.
        stage_name: Name of the stage being timed.
    """
    request_id = get_current_request_id()

    stats = None
    if ctx is not None and request_id is not None:
        stats = ctx.get_timing_stats(request_id)

    if stats is None:
        # Nothing to record into; just run the wrapped code.
        yield
        return

    began_at = time.perf_counter()
    try:
        yield
    finally:
        elapsed = time.perf_counter() - began_at

        stage_to_attr = {
            "hf_processor": "hf_processor_time",
            "hashing": "hashing_time",
            "cache_lookup": "cache_lookup_time",
            "prompt_update": "prompt_update_time",
        }
        stage_attr = stage_to_attr.get(stage_name)
        if stage_attr is not None:
            setattr(stats, stage_attr, getattr(stats, stage_attr) + elapsed)

        # The total is accumulated for every stage name, known or not.
        stats.total_time += elapsed
# Type variables for the generic accessors on `InputProcessingContext`.
# `_T` is unconstrained; `_C` and `_P` are bound to, and default to,
# `PretrainedConfig` and `ProcessorMixin` respectively.
_T = TypeVar("_T")
_C = TypeVar("_C", bound=PretrainedConfig, default=PretrainedConfig)
_P = TypeVar("_P", bound=ProcessorMixin, default=ProcessorMixin)
@dataclass(frozen=True)
class InputProcessingContext:
    """
    Contains information about the model which may be used to
    modify the inputs.

    Besides the model configuration and tokenizer, the context holds a
    registry of per-request multimodal processor timing stats, which is
    accessed under a lock.
    """

    model_config: ModelConfig
    """The configuration of the model."""

    tokenizer: TokenizerLike | None
    """The tokenizer used to tokenize the inputs."""

    observability_config: "ObservabilityConfig | None" = field(
        default=None, compare=False, repr=False
    )
    """Configuration for observability features."""

    timing_stats_registry: dict[str, MultiModalProcessorTimingStats] = field(
        default_factory=dict, compare=False, repr=False
    )
    """Registry for storing timing stats keyed by request_id."""

    _timing_stats_registry_lock: threading.Lock = field(
        default_factory=threading.Lock, compare=False, repr=False
    )
    """Lock for thread-safe access to timing_stats_registry."""

    def get_tokenizer(self) -> TokenizerLike:
        """
        Return the tokenizer of this context.

        Raises:
            ValueError: If no tokenizer was provided.
        """
        if self.tokenizer is None:
            raise ValueError(
                "You cannot pass text prompts when `skip_tokenizer_init=True`"
            )

        return self.tokenizer

    @overload
    def get_hf_config(self, /) -> PretrainedConfig: ...

    @overload
    def get_hf_config(
        self,
        typ: type[_C] | tuple[type[_C], ...],
        /,
    ) -> _C: ...

    def get_hf_config(
        self,
        typ: type[Any] | tuple[type[Any], ...] | None = None,
        /,
    ) -> Any:
        """
        Get the HuggingFace configuration
        (`transformers.PretrainedConfig`) of the model,
        additionally checking its type.

        Args:
            typ: The expected type (or tuple of types) of the configuration.
                Defaults to `transformers.PretrainedConfig`.

        Raises:
            TypeError: If the configuration is not of the specified type.
        """
        if typ is None:
            # Imported lazily; at module level the name is only bound to the
            # real class under `TYPE_CHECKING`.
            from transformers.configuration_utils import PretrainedConfig

            typ = PretrainedConfig

        hf_config = self.model_config.hf_config
        if not isinstance(hf_config, typ):
            raise TypeError(
                "Invalid type of HuggingFace config. "
                f"Expected type: {typ}, but "
                f"found type: {type(hf_config)}"
            )

        return hf_config

    def get_hf_image_processor_config(self) -> dict[str, Any]:
        """
        Get the HuggingFace image processor configuration of the model.
        """
        return self.model_config.hf_image_processor_config

    def get_mm_config(self):
        """
        Get the multimodal config of the model.

        Raises:
            RuntimeError: If the model is not a multimodal model.
        """
        mm_config = self.model_config.multimodal_config
        if mm_config is None:
            raise RuntimeError("Not a multimodal model")

        return mm_config

    @overload
    def get_hf_processor(self, /, **kwargs: object) -> ProcessorMixin: ...

    @overload
    def get_hf_processor(
        self,
        typ: type[_P] | tuple[type[_P], ...],
        /,
        **kwargs: object,
    ) -> _P: ...

    def get_hf_processor(
        self,
        typ: type[Any] | tuple[type[Any], ...] | None = None,
        /,
        **kwargs: object,
    ) -> Any:
        """
        Get the HuggingFace processor
        (`transformers.ProcessorMixin`) of the model,
        additionally checking its type.

        Args:
            typ: The expected type (or tuple of types) of the processor.
                Defaults to `transformers.ProcessorMixin`.
            **kwargs: Forwarded to `cached_processor_from_config`.

        Raises:
            TypeError: If the processor is not of the specified type.
        """
        if typ is None:
            from transformers.processing_utils import ProcessorMixin

            typ = ProcessorMixin

        from vllm.tokenizers.mistral import MistralTokenizer

        tokenizer = self.tokenizer
        if isinstance(tokenizer, MistralTokenizer):
            # Hand the wrapped `transformers` tokenizer to the processor.
            tokenizer = tokenizer.transformers_tokenizer

        return cached_processor_from_config(
            self.model_config,
            processor_cls=typ,
            tokenizer=tokenizer,
            **kwargs,
        )

    def init_processor(
        self,
        typ: type[_T],
        /,
        **kwargs: object,
    ) -> _T:
        """
        Initialize a HuggingFace-like processor class, merging the
        keyword arguments with those in the model's configuration.

        Explicitly passed `kwargs` take precedence over the
        `mm_processor_kwargs` of the model's multimodal config.
        """
        mm_config = self.model_config.get_multimodal_config()
        base_kwargs = mm_config.mm_processor_kwargs
        if base_kwargs is None:
            base_kwargs = {}

        merged_kwargs = {**base_kwargs, **kwargs}

        return typ(**merged_kwargs)

    def _postprocess_output(
        self,
        output: JSONTree,
    ) -> JSONTree:
        """
        Cast every floating-point tensor leaf of `output` to the model dtype.

        Non-tensor leaves and non-floating-point tensors are returned as-is.
        """

        def _postprocess_one(x: object):
            if isinstance(x, torch.Tensor):  # noqa: SIM102
                # This mimics the behavior of transformers.BatchFeature
                if x.is_floating_point():
                    x = x.to(dtype=self.model_config.dtype)

            return x

        return json_map_leaves(_postprocess_one, output)

    def call_hf_processor(
        self,
        hf_processor: ProcessorMixin,
        data: Mapping[str, object],
        kwargs: Mapping[str, object] = {},
        *,
        num_tries: int = 1,
        max_tries: int = 5,
    ) -> BatchFeature | JSONTree:
        """
        Call `hf_processor` on the prompt `data`
        (text, image, audio...) with configurable options `kwargs`.

        Args:
            hf_processor: The callable processor to apply.
            data: Keyword arguments carrying the prompt data.
            kwargs: Processor options; merged with the model's multimodal
                processor kwargs and filtered to those `hf_processor` accepts.
            num_tries: The number of the current attempt (1-based).
            max_tries: The maximum number of attempts made when the tokenizer
                reports "Already borrowed".

        Returns:
            The processor output with floating-point tensors cast to the
            model dtype. A `BatchFeature` is returned when the processor
            produced one.

        Raises:
            ValueError: If the processor raises and the call is not retried.
        """
        assert callable(hf_processor)

        mm_config = self.model_config.get_multimodal_config()
        merged_kwargs = mm_config.merge_mm_processor_kwargs(kwargs)

        allowed_kwargs = get_allowed_kwarg_only_overrides(
            hf_processor,
            merged_kwargs,
            requires_kw_only=False,
            allow_var_kwargs=True,
        )

        try:
            output = hf_processor(**data, **allowed_kwargs, return_tensors="pt")
        except Exception as exc:
            # See https://github.com/huggingface/tokenizers/issues/537
            # NOTE: `exc.args` must be checked for emptiness before indexing;
            # an exception instance itself is always truthy, so a bare
            # `RuntimeError()` would otherwise raise `IndexError` here.
            if (
                isinstance(exc, RuntimeError)
                and exc.args
                and exc.args[0] == "Already borrowed"
                and num_tries < max_tries
            ):
                logger.warning(
                    "Failed to acquire tokenizer in current thread. "
                    "Retrying (%d/%d)...",
                    num_tries,
                    max_tries,
                )
                time.sleep(0.5)
                return self.call_hf_processor(
                    hf_processor,
                    data,
                    kwargs,
                    num_tries=num_tries + 1,
                    max_tries=max_tries,
                )

            msg = (
                f"Failed to apply {type(hf_processor).__name__} "
                f"on data={data} with kwargs={allowed_kwargs}"
            )

            raise ValueError(msg) from exc

        # this emulates output.to(dtype=self.model_config.dtype)
        from transformers.feature_extraction_utils import BatchFeature

        if isinstance(output, BatchFeature):
            output_ = self._postprocess_output(output.data)
            return BatchFeature(output_)

        logger.warning_once(
            "%s did not return `BatchFeature`. "
            "Make sure to match the behaviour of `ProcessorMixin` when "
            "implementing custom processors.",
            type(hf_processor).__name__,
        )

        return self._postprocess_output(output)

    def _timing_stats_enabled(self) -> bool:
        """
        Return whether multimodal processor timing stats are enabled, i.e.
        an observability config is set and its `enable_mm_processor_stats`
        flag is truthy.
        """
        observability_config = self.observability_config
        return bool(
            observability_config is not None
            and observability_config.enable_mm_processor_stats
        )

    def get_timing_stats(
        self, request_id: str
    ) -> MultiModalProcessorTimingStats | None:
        """
        Get timing stats for a request.

        Returns:
            The registered stats object, or `None` when stats are disabled
            or no entry exists for `request_id`.
        """
        if not self._timing_stats_enabled():
            return None

        with self._timing_stats_registry_lock:
            return self.timing_stats_registry.get(request_id)

    def create_timing_stats(self, request_id: str) -> MultiModalProcessorTimingStats:
        """
        Create and store timing stats in the registry for a request.

        This should be called at the start of processing for a request.
        The stats object is created immediately and stored in the registry.
        When stats are disabled, a fresh stats object is returned without
        being stored.

        Raises:
            ValueError: If stats already exist for `request_id`.
        """
        if not self._timing_stats_enabled():
            return MultiModalProcessorTimingStats()

        with self._timing_stats_registry_lock:
            if request_id in self.timing_stats_registry:
                raise ValueError(
                    f"Timing stats already exist for request_id: {request_id}"
                )

            stats = MultiModalProcessorTimingStats()
            self.timing_stats_registry[request_id] = stats
            return stats

    def clear_timing_stats_registry(self) -> int:
        """
        Clear all stats from the registry. Returns the number of stats cleared.

        Returns 0 without touching the registry when stats are disabled.
        """
        if not self._timing_stats_enabled():
            return 0

        with self._timing_stats_registry_lock:
            count = len(self.timing_stats_registry)
            self.timing_stats_registry.clear()
            return count

    def get_all_timing_stats(self) -> dict[str, dict[str, float]]:
        """
        Get all timing stats as a dictionary for API endpoints.

        Returns:
            A mapping from request ID to that request's stats dictionary,
            or `{}` when stats are disabled.
        """
        if not self._timing_stats_enabled():
            return {}

        with self._timing_stats_registry_lock:
            return {
                rid: stats.to_dict()
                for rid, stats in self.timing_stats_registry.items()
            }
class BaseProcessingInfo:
    """Base class supplying the information required for data processing."""

    def __init__(self, ctx: InputProcessingContext) -> None:
        super().__init__()

        self.ctx = ctx

    @property
    def model_id(self) -> str:
        """The `model` field of the context's model configuration."""
        model_config = self.ctx.model_config
        return model_config.model

    def get_tokenizer(self) -> TokenizerLike:
        """Return the tokenizer provided by the processing context."""
        return self.ctx.get_tokenizer()

    def get_hf_config(self) -> PretrainedConfig:
        """Return the HuggingFace configuration from the processing context."""
        return self.ctx.get_hf_config()

    def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
        """
        Return the HuggingFace processor from the processing context.

        Subclasses may override this method to handle specific kwargs
        coming from the model config or from user inputs.
        """
        return self.ctx.get_hf_processor(**kwargs)

    @property
    def skip_prompt_length_check(self) -> bool:
        """Always `False` in this base class."""
        return False

    @abstractmethod
    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """
        Return the maximum supported number of items for each modality.

        A value of `None` means an unlimited number of items.

        A modality that is omitted from the returned dictionary is
        not supported at all.
        """
        raise NotImplementedError

    def get_allowed_mm_limits(self) -> Mapping[str, int]:
        """
        Return the maximum allowed number of items for each modality.

        For each supported modality this is the user-configured limit,
        capped by the supported limit when that limit is not `None`.
        """
        supported_limits = self.get_supported_mm_limits()
        mm_config = self.ctx.get_mm_config()

        limits: dict[str, int] = {}
        for modality, supported_cap in supported_limits.items():
            allowed = mm_config.get_limit_per_prompt(modality)
            if supported_cap is not None:
                allowed = min(allowed, supported_cap)
            limits[modality] = allowed

        return limits

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int] | None:
        """
        Return the maximum number of tokens per item for each modality.

        When `None` (the default) is returned, vLLM will generate dummy inputs
        (images/videos) at maximum possible sizes and process them to determine
        the maximum token count per modality.

        This approach works but can be very slow for certain models (e.g.,
        Qwen2.5-VL), leading to very long startup time. For better performance,
        each model can override this method to return pre-computed maximum token
        counts, avoiding the need for dummy input generation and processing.

        Note:
            The maximum number of tokens per item of each modality returned
            from this function should respect the model's maximum sequence
            length and the maximum number of items of each modality allowed,
            and agree with dummy inputs (images/videos) at maximum possible
            sizes.
        """
        return None
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Mapping from collections.abc import Mapping
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Generic from typing import Generic, TypeVar
import numpy as np import numpy as np
import numpy.typing as npt import numpy.typing as npt
...@@ -17,14 +17,10 @@ from vllm.config.multimodal import ( ...@@ -17,14 +17,10 @@ from vllm.config.multimodal import (
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from .inputs import MultiModalDataDict from ..inputs import MultiModalDataDict
from .context import BaseProcessingInfo
if TYPE_CHECKING: _I = TypeVar("_I", bound=BaseProcessingInfo)
from .processing import _I
else:
from typing import TypeVar
_I = TypeVar("_I")
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -12,11 +12,11 @@ from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config ...@@ -12,11 +12,11 @@ from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
from .cache import BaseMultiModalProcessorCache from .cache import BaseMultiModalProcessorCache
from .inputs import MultiModalInputs from .inputs import MultiModalInputs
from .processing import ( from .processing import (
BaseDummyInputsBuilder,
BaseMultiModalProcessor, BaseMultiModalProcessor,
BaseProcessingInfo, BaseProcessingInfo,
InputProcessingContext, InputProcessingContext,
) )
from .profiling import BaseDummyInputsBuilder
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import ModelConfig, ObservabilityConfig from vllm.config import ModelConfig, ObservabilityConfig
...@@ -45,7 +45,7 @@ class ProcessingInfoFactory(Protocol[_I_co]): ...@@ -45,7 +45,7 @@ class ProcessingInfoFactory(Protocol[_I_co]):
class DummyInputsBuilderFactory(Protocol[_I]): # type: ignore[misc] class DummyInputsBuilderFactory(Protocol[_I]): # type: ignore[misc]
""" """
Constructs a Constructs a
[`BaseDummyInputsBuilder`][vllm.multimodal.profiling.BaseDummyInputsBuilder] [`BaseDummyInputsBuilder`][vllm.multimodal.processing.BaseDummyInputsBuilder]
instance from the context. instance from the context.
""" """
......
...@@ -17,7 +17,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry ...@@ -17,7 +17,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import processor_cache_from_config from vllm.multimodal.cache import processor_cache_from_config
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict
from vllm.multimodal.parse import MultiModalDataParser from vllm.multimodal.parse import MultiModalDataParser
from vllm.multimodal.processing import set_request_id from vllm.multimodal.processing.context import set_request_id
from vllm.multimodal.utils import argsort_mm_positions from vllm.multimodal.utils import argsort_mm_positions
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import _SAMPLING_EPS, SamplingParams from vllm.sampling_params import _SAMPLING_EPS, SamplingParams
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment