Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9ea07b41
Unverified
Commit
9ea07b41
authored
Jan 14, 2026
by
Cyrus Leung
Committed by
GitHub
Jan 14, 2026
Browse files
[1/N] Reorganize multimodal processing code (#32327)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
552b2629
Changes
76
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
632 additions
and
584 deletions
+632
-584
vllm/model_executor/models/step3_vl.py
vllm/model_executor/models/step3_vl.py
+1
-1
vllm/model_executor/models/tarsier.py
vllm/model_executor/models/tarsier.py
+1
-1
vllm/model_executor/models/terratorch.py
vllm/model_executor/models/terratorch.py
+1
-1
vllm/model_executor/models/transformers/multimodal.py
vllm/model_executor/models/transformers/multimodal.py
+5
-2
vllm/model_executor/models/ultravox.py
vllm/model_executor/models/ultravox.py
+1
-1
vllm/model_executor/models/voxtral.py
vllm/model_executor/models/voxtral.py
+2
-2
vllm/model_executor/models/voxtral_streaming.py
vllm/model_executor/models/voxtral_streaming.py
+2
-2
vllm/model_executor/models/whisper.py
vllm/model_executor/models/whisper.py
+1
-1
vllm/multimodal/cache.py
vllm/multimodal/cache.py
+1
-1
vllm/multimodal/inputs.py
vllm/multimodal/inputs.py
+9
-5
vllm/multimodal/processing/__init__.py
vllm/multimodal/processing/__init__.py
+27
-0
vllm/multimodal/processing/context.py
vllm/multimodal/processing/context.py
+558
-0
vllm/multimodal/processing/dummy_inputs.py
vllm/multimodal/processing/dummy_inputs.py
+4
-8
vllm/multimodal/processing/processor.py
vllm/multimodal/processing/processor.py
+16
-556
vllm/multimodal/registry.py
vllm/multimodal/registry.py
+2
-2
vllm/v1/engine/input_processor.py
vllm/v1/engine/input_processor.py
+1
-1
No files found.
vllm/model_executor/models/step3_vl.py
View file @
9ea07b41
...
@@ -35,13 +35,13 @@ from vllm.multimodal.inputs import (
...
@@ -35,13 +35,13 @@ from vllm.multimodal.inputs import (
)
)
from
vllm.multimodal.parse
import
ImageSize
,
MultiModalDataItems
from
vllm.multimodal.parse
import
ImageSize
,
MultiModalDataItems
from
vllm.multimodal.processing
import
(
from
vllm.multimodal.processing
import
(
BaseDummyInputsBuilder
,
BaseMultiModalProcessor
,
BaseMultiModalProcessor
,
BaseProcessingInfo
,
BaseProcessingInfo
,
PromptReplacement
,
PromptReplacement
,
PromptUpdate
,
PromptUpdate
,
PromptUpdateDetails
,
PromptUpdateDetails
,
)
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.tokenizers
import
TokenizerLike
from
vllm.tokenizers
import
TokenizerLike
from
vllm.transformers_utils.configs
import
Step3VisionEncoderConfig
from
vllm.transformers_utils.configs
import
Step3VisionEncoderConfig
...
...
vllm/model_executor/models/tarsier.py
View file @
9ea07b41
...
@@ -34,13 +34,13 @@ from vllm.multimodal.parse import (
...
@@ -34,13 +34,13 @@ from vllm.multimodal.parse import (
MultiModalDataItems
,
MultiModalDataItems
,
)
)
from
vllm.multimodal.processing
import
(
from
vllm.multimodal.processing
import
(
BaseDummyInputsBuilder
,
BaseMultiModalProcessor
,
BaseMultiModalProcessor
,
BaseProcessingInfo
,
BaseProcessingInfo
,
InputProcessingContext
,
InputProcessingContext
,
PromptReplacement
,
PromptReplacement
,
PromptUpdate
,
PromptUpdate
,
)
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
...
...
vllm/model_executor/models/terratorch.py
View file @
9ea07b41
...
@@ -56,11 +56,11 @@ from vllm.multimodal.parse import (
...
@@ -56,11 +56,11 @@ from vllm.multimodal.parse import (
MultiModalDataParser
,
MultiModalDataParser
,
)
)
from
vllm.multimodal.processing
import
(
from
vllm.multimodal.processing
import
(
BaseDummyInputsBuilder
,
BaseMultiModalProcessor
,
BaseMultiModalProcessor
,
BaseProcessingInfo
,
BaseProcessingInfo
,
PromptUpdate
,
PromptUpdate
,
)
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
IsAttentionFree
,
MultiModalEmbeddings
,
SupportsMultiModal
from
.interfaces
import
IsAttentionFree
,
MultiModalEmbeddings
,
SupportsMultiModal
...
...
vllm/model_executor/models/transformers/multimodal.py
View file @
9ea07b41
...
@@ -35,8 +35,11 @@ from vllm.multimodal.inputs import (
...
@@ -35,8 +35,11 @@ from vllm.multimodal.inputs import (
PlaceholderRange
,
PlaceholderRange
,
)
)
from
vllm.multimodal.parse
import
ImageProcessorItems
,
MultiModalDataItems
from
vllm.multimodal.parse
import
ImageProcessorItems
,
MultiModalDataItems
from
vllm.multimodal.processing
import
BaseMultiModalProcessor
,
BaseProcessingInfo
from
vllm.multimodal.processing
import
(
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
BaseDummyInputsBuilder
,
BaseMultiModalProcessor
,
BaseProcessingInfo
,
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
...
...
vllm/model_executor/models/ultravox.py
View file @
9ea07b41
...
@@ -36,12 +36,12 @@ from vllm.multimodal.inputs import (
...
@@ -36,12 +36,12 @@ from vllm.multimodal.inputs import (
)
)
from
vllm.multimodal.parse
import
MultiModalDataItems
,
MultiModalDataParser
from
vllm.multimodal.parse
import
MultiModalDataItems
,
MultiModalDataParser
from
vllm.multimodal.processing
import
(
from
vllm.multimodal.processing
import
(
BaseDummyInputsBuilder
,
BaseMultiModalProcessor
,
BaseMultiModalProcessor
,
BaseProcessingInfo
,
BaseProcessingInfo
,
PromptReplacement
,
PromptReplacement
,
PromptUpdate
,
PromptUpdate
,
)
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.configs.ultravox
import
UltravoxConfig
from
vllm.transformers_utils.configs.ultravox
import
UltravoxConfig
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
...
...
vllm/model_executor/models/voxtral.py
View file @
9ea07b41
...
@@ -47,14 +47,14 @@ from vllm.multimodal.parse import (
...
@@ -47,14 +47,14 @@ from vllm.multimodal.parse import (
MultiModalDataItems
,
MultiModalDataItems
,
MultiModalDataParser
,
MultiModalDataParser
,
)
)
from
vllm.multimodal.processing
import
(
from
vllm.multimodal.processing
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.multimodal.processing.processor
import
(
BaseMultiModalProcessor
,
BaseMultiModalProcessor
,
BaseProcessingInfo
,
BaseProcessingInfo
,
MultiModalProcessingInfo
,
MultiModalProcessingInfo
,
PromptReplacement
,
PromptReplacement
,
PromptUpdate
,
PromptUpdate
,
)
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.tokenizers
import
cached_tokenizer_from_config
from
vllm.tokenizers
import
cached_tokenizer_from_config
from
vllm.tokenizers.mistral
import
MistralTokenizer
from
vllm.tokenizers.mistral
import
MistralTokenizer
...
...
vllm/model_executor/models/voxtral_streaming.py
View file @
9ea07b41
...
@@ -30,11 +30,11 @@ from vllm.multimodal.inputs import (
...
@@ -30,11 +30,11 @@ from vllm.multimodal.inputs import (
MultiModalKwargsOptionalItems
,
MultiModalKwargsOptionalItems
,
)
)
from
vllm.multimodal.parse
import
MultiModalDataItems
from
vllm.multimodal.parse
import
MultiModalDataItems
from
vllm.multimodal.processing
import
(
from
vllm.multimodal.processing
import
BaseDummyInputsBuilder
from
vllm.multimodal.processing.processor
import
(
MultiModalPromptUpdates
,
MultiModalPromptUpdates
,
PlaceholderFeaturesInfo
,
PlaceholderFeaturesInfo
,
)
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.tokenizers
import
cached_tokenizer_from_config
from
vllm.tokenizers
import
cached_tokenizer_from_config
...
...
vllm/model_executor/models/whisper.py
View file @
9ea07b41
...
@@ -49,12 +49,12 @@ from vllm.multimodal.inputs import (
...
@@ -49,12 +49,12 @@ from vllm.multimodal.inputs import (
)
)
from
vllm.multimodal.parse
import
MultiModalDataItems
,
MultiModalDataParser
from
vllm.multimodal.parse
import
MultiModalDataItems
,
MultiModalDataParser
from
vllm.multimodal.processing
import
(
from
vllm.multimodal.processing
import
(
BaseDummyInputsBuilder
,
BaseProcessingInfo
,
BaseProcessingInfo
,
EncDecMultiModalProcessor
,
EncDecMultiModalProcessor
,
PromptReplacement
,
PromptReplacement
,
PromptUpdate
,
PromptUpdate
,
)
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.transformers_utils.processor
import
cached_processor_from_config
from
vllm.transformers_utils.processor
import
cached_processor_from_config
from
vllm.utils.jsontree
import
json_map_leaves
from
vllm.utils.jsontree
import
json_map_leaves
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
...
...
vllm/multimodal/cache.py
View file @
9ea07b41
...
@@ -34,7 +34,7 @@ from .inputs import (
...
@@ -34,7 +34,7 @@ from .inputs import (
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.config
import
ModelConfig
,
VllmConfig
from
vllm.config
import
ModelConfig
,
VllmConfig
from
.processing
import
ResolvedPromptUpdate
from
.processing
.processor
import
ResolvedPromptUpdate
from
.registry
import
MultiModalRegistry
from
.registry
import
MultiModalRegistry
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
...
vllm/multimodal/inputs.py
View file @
9ea07b41
...
@@ -33,8 +33,6 @@ if TYPE_CHECKING:
...
@@ -33,8 +33,6 @@ if TYPE_CHECKING:
from
transformers.feature_extraction_utils
import
BatchFeature
from
transformers.feature_extraction_utils
import
BatchFeature
from
.base
import
MediaWithBytes
from
.base
import
MediaWithBytes
from
.processing
import
MultiModalHashes
else
:
else
:
torch
=
LazyLoader
(
"torch"
,
globals
(),
"torch"
)
torch
=
LazyLoader
(
"torch"
,
globals
(),
"torch"
)
...
@@ -979,9 +977,15 @@ MultiModalKwargsOptionalItems: TypeAlias = (
...
@@ -979,9 +977,15 @@ MultiModalKwargsOptionalItems: TypeAlias = (
)
)
MultiModalHashes
=
dict
[
str
,
list
[
str
]]
"""
A dictionary containing per-item hashes for each modality.
"""
MultiModalPlaceholderDict
:
TypeAlias
=
Mapping
[
str
,
Sequence
[
PlaceholderRange
]]
MultiModalPlaceholderDict
:
TypeAlias
=
Mapping
[
str
,
Sequence
[
PlaceholderRange
]]
"""
"""
A dictionary containing placeholder ranges for each modality.
A dictionary containing
per-item
placeholder ranges for each modality.
"""
"""
...
@@ -1001,10 +1005,10 @@ class MultiModalInputs(TypedDict):
...
@@ -1001,10 +1005,10 @@ class MultiModalInputs(TypedDict):
mm_kwargs
:
MultiModalKwargsOptionalItems
mm_kwargs
:
MultiModalKwargsOptionalItems
"""Keyword arguments to be directly passed to the model after batching."""
"""Keyword arguments to be directly passed to the model after batching."""
mm_hashes
:
"
MultiModalHashes
"
mm_hashes
:
MultiModalHashes
"""The hashes of the multi-modal data."""
"""The hashes of the multi-modal data."""
mm_placeholders
:
"
MultiModalPlaceholderDict
"
mm_placeholders
:
MultiModalPlaceholderDict
"""
"""
For each modality, information about the placeholder tokens in
For each modality, information about the placeholder tokens in
`prompt_token_ids`.
`prompt_token_ids`.
...
...
vllm/multimodal/processing/__init__.py
0 → 100644
View file @
9ea07b41
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
.context
import
BaseProcessingInfo
,
InputProcessingContext
from
.dummy_inputs
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
.processor
import
(
BaseMultiModalProcessor
,
EncDecMultiModalProcessor
,
PromptIndexTargets
,
PromptInsertion
,
PromptReplacement
,
PromptUpdate
,
PromptUpdateDetails
,
)
__all__
=
[
"BaseProcessingInfo"
,
"InputProcessingContext"
,
"BaseDummyInputsBuilder"
,
"ProcessorInputs"
,
"BaseMultiModalProcessor"
,
"EncDecMultiModalProcessor"
,
"PromptUpdate"
,
"PromptIndexTargets"
,
"PromptUpdateDetails"
,
"PromptInsertion"
,
"PromptReplacement"
,
]
vllm/multimodal/processing/context.py
0 → 100644
View file @
9ea07b41
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
contextvars
import
threading
import
time
from
abc
import
abstractmethod
from
collections.abc
import
Generator
,
Mapping
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
,
field
from
typing
import
(
TYPE_CHECKING
,
Any
,
overload
,
)
import
torch
from
typing_extensions
import
TypeVar
from
vllm.logger
import
init_logger
from
vllm.tokenizers
import
TokenizerLike
from
vllm.transformers_utils.processor
import
cached_processor_from_config
from
vllm.utils.func_utils
import
get_allowed_kwarg_only_overrides
from
vllm.utils.jsontree
import
JSONTree
,
json_map_leaves
if
TYPE_CHECKING
:
from
transformers.configuration_utils
import
PretrainedConfig
from
transformers.feature_extraction_utils
import
BatchFeature
from
transformers.processing_utils
import
ProcessorMixin
from
vllm.config
import
ModelConfig
,
ObservabilityConfig
else
:
PretrainedConfig
=
object
BatchFeature
=
object
ProcessorMixin
=
object
ModelConfig
=
object
ObservabilityConfig
=
object
logger
=
init_logger
(
__name__
)
_request_id_context
:
contextvars
.
ContextVar
[
str
|
None
]
=
contextvars
.
ContextVar
(
"_request_id_context"
,
default
=
None
)
def
get_current_request_id
()
->
str
|
None
:
"""Get the current request_id from the context, if available."""
return
_request_id_context
.
get
()
@
contextmanager
def
set_request_id
(
request_id
:
str
)
->
Generator
[
None
,
None
,
None
]:
"""Context manager to set the request_id for the current context."""
token
=
_request_id_context
.
set
(
request_id
)
try
:
yield
finally
:
_request_id_context
.
reset
(
token
)
@
dataclass
class
MultiModalProcessorTimingStats
:
"""Per-request timing statistics for multimodal processor stages."""
hf_processor_time
:
float
=
0.0
"""Time spent in HuggingFace processor calls (seconds)."""
hashing_time
:
float
=
0.0
"""Time spent computing multimodal item hashes (seconds)."""
cache_lookup_time
:
float
=
0.0
"""Time spent in cache lookups and merges (seconds)."""
prompt_update_time
:
float
=
0.0
"""Time spent applying prompt updates and finding placeholders (seconds)."""
total_time
:
float
=
0.0
"""Total processing time (seconds)."""
def
to_dict
(
self
)
->
dict
[
str
,
float
]:
"""Convert stats to a dictionary for JSON serialization."""
return
{
"hf_processor_time"
:
self
.
hf_processor_time
,
"hashing_time"
:
self
.
hashing_time
,
"cache_lookup_time"
:
self
.
cache_lookup_time
,
"prompt_update_time"
:
self
.
prompt_update_time
,
"total_time"
:
self
.
total_time
,
}
def
get_timing_stats_from_engine_client
(
engine_client
:
Any
,
)
->
dict
[
str
,
dict
[
str
,
float
]]:
"""
Get all timing stats from the context associated with the engine client.
Args:
engine_client: The engine client that has input_processor.
Returns:
A dictionary mapping request_id to stats dict.
"""
try
:
if
not
engine_client
.
vllm_config
.
observability_config
.
enable_mm_processor_stats
:
return
{}
except
(
AttributeError
,
RuntimeError
):
return
{}
try
:
input_processor
=
engine_client
.
input_processor
input_preprocessor
=
input_processor
.
input_preprocessor
if
hasattr
(
input_preprocessor
,
"_get_mm_processor"
):
mm_processor
=
input_preprocessor
.
_get_mm_processor
()
if
mm_processor
is
not
None
and
hasattr
(
mm_processor
,
"info"
):
ctx
=
mm_processor
.
info
.
ctx
return
ctx
.
get_all_timing_stats
()
except
(
AttributeError
,
RuntimeError
):
pass
return
{}
@
contextmanager
def
timed_operation
(
ctx
:
"InputProcessingContext"
,
stage_name
:
str
):
"""
Context manager to time an operation using the context's timing stats.
The request_id is automatically retrieved from the context variable,
so it doesn't need to be passed as a parameter.
Args:
ctx: The InputProcessingContext containing the timing stats registry.
stage_name: Name of the stage being timed.
"""
request_id
=
get_current_request_id
()
if
ctx
is
None
or
request_id
is
None
:
yield
return
stats
=
ctx
.
get_timing_stats
(
request_id
)
if
stats
is
None
:
yield
return
start_time
=
time
.
perf_counter
()
try
:
yield
finally
:
elapsed
=
time
.
perf_counter
()
-
start_time
if
stage_name
==
"hf_processor"
:
stats
.
hf_processor_time
+=
elapsed
elif
stage_name
==
"hashing"
:
stats
.
hashing_time
+=
elapsed
elif
stage_name
==
"cache_lookup"
:
stats
.
cache_lookup_time
+=
elapsed
elif
stage_name
==
"prompt_update"
:
stats
.
prompt_update_time
+=
elapsed
stats
.
total_time
+=
elapsed
_T
=
TypeVar
(
"_T"
)
_C
=
TypeVar
(
"_C"
,
bound
=
PretrainedConfig
,
default
=
PretrainedConfig
)
_P
=
TypeVar
(
"_P"
,
bound
=
ProcessorMixin
,
default
=
ProcessorMixin
)
@
dataclass
(
frozen
=
True
)
class
InputProcessingContext
:
"""
Contains information about the model which may be used to
modify the inputs.
"""
model_config
:
ModelConfig
"""The configuration of the model."""
tokenizer
:
TokenizerLike
|
None
"""The tokenizer used to tokenize the inputs."""
observability_config
:
"ObservabilityConfig | None"
=
field
(
default
=
None
,
compare
=
False
,
repr
=
False
)
"""Configuration for observability features."""
timing_stats_registry
:
dict
[
str
,
MultiModalProcessorTimingStats
]
=
field
(
default_factory
=
dict
,
compare
=
False
,
repr
=
False
)
"""Registry for storing timing stats keyed by request_id."""
_timing_stats_registry_lock
:
threading
.
Lock
=
field
(
default_factory
=
threading
.
Lock
,
compare
=
False
,
repr
=
False
)
"""Lock for thread-safe access to timing_stats_registry."""
def
get_tokenizer
(
self
)
->
TokenizerLike
:
if
self
.
tokenizer
is
None
:
raise
ValueError
(
"You cannot pass text prompts when `skip_tokenizer_init=True`"
)
return
self
.
tokenizer
@
overload
def
get_hf_config
(
self
,
/
)
->
PretrainedConfig
:
...
@
overload
def
get_hf_config
(
self
,
typ
:
type
[
_C
]
|
tuple
[
type
[
_C
],
...],
/
,
)
->
_C
:
...
def
get_hf_config
(
self
,
typ
:
type
[
Any
]
|
tuple
[
type
[
Any
],
...]
|
None
=
None
,
/
,
)
->
Any
:
"""
Get the HuggingFace configuration
(`transformers.PretrainedConfig`) of the model,
additionally checking its type.
Raises:
TypeError: If the configuration is not of the specified type.
"""
if
typ
is
None
:
from
transformers.configuration_utils
import
PretrainedConfig
typ
=
PretrainedConfig
hf_config
=
self
.
model_config
.
hf_config
if
not
isinstance
(
hf_config
,
typ
):
raise
TypeError
(
"Invalid type of HuggingFace config. "
f
"Expected type:
{
typ
}
, but "
f
"found type:
{
type
(
hf_config
)
}
"
)
return
hf_config
def
get_hf_image_processor_config
(
self
)
->
dict
[
str
,
Any
]:
"""
Get the HuggingFace image processor configuration of the model.
"""
return
self
.
model_config
.
hf_image_processor_config
def
get_mm_config
(
self
):
"""
Get the multimodal config of the model.
Raises:
RuntimeError: If the model is not a multimodal model.
"""
mm_config
=
self
.
model_config
.
multimodal_config
if
mm_config
is
None
:
raise
RuntimeError
(
"Not a multimodal model"
)
return
mm_config
@
overload
def
get_hf_processor
(
self
,
/
,
**
kwargs
:
object
)
->
ProcessorMixin
:
...
@
overload
def
get_hf_processor
(
self
,
typ
:
type
[
_P
]
|
tuple
[
type
[
_P
],
...],
/
,
**
kwargs
:
object
,
)
->
_P
:
...
def
get_hf_processor
(
self
,
typ
:
type
[
Any
]
|
tuple
[
type
[
Any
],
...]
|
None
=
None
,
/
,
**
kwargs
:
object
,
)
->
Any
:
"""
Get the HuggingFace processor
(`transformers.ProcessorMixin`) of the model,
additionally checking its type.
Raises:
TypeError: If the processor is not of the specified type.
"""
if
typ
is
None
:
from
transformers.processing_utils
import
ProcessorMixin
typ
=
ProcessorMixin
from
vllm.tokenizers.mistral
import
MistralTokenizer
tokenizer
=
self
.
tokenizer
if
isinstance
(
tokenizer
,
MistralTokenizer
):
tokenizer
=
tokenizer
.
transformers_tokenizer
return
cached_processor_from_config
(
self
.
model_config
,
processor_cls
=
typ
,
tokenizer
=
tokenizer
,
**
kwargs
,
)
def
init_processor
(
self
,
typ
:
type
[
_T
],
/
,
**
kwargs
:
object
,
)
->
_T
:
"""
Initialize a HuggingFace-like processor class, merging the
keyword arguments with those in the model's configuration.
"""
mm_config
=
self
.
model_config
.
get_multimodal_config
()
base_kwargs
=
mm_config
.
mm_processor_kwargs
if
base_kwargs
is
None
:
base_kwargs
=
{}
merged_kwargs
=
{
**
base_kwargs
,
**
kwargs
}
return
typ
(
**
merged_kwargs
)
def
_postprocess_output
(
self
,
output
:
JSONTree
,
)
->
JSONTree
:
def
_postprocess_one
(
x
:
object
):
if
isinstance
(
x
,
torch
.
Tensor
):
# noqa: SIM102
# This mimics the behavior of transformers.BatchFeature
if
x
.
is_floating_point
():
x
=
x
.
to
(
dtype
=
self
.
model_config
.
dtype
)
return
x
return
json_map_leaves
(
_postprocess_one
,
output
)
def
call_hf_processor
(
self
,
hf_processor
:
ProcessorMixin
,
data
:
Mapping
[
str
,
object
],
kwargs
:
Mapping
[
str
,
object
]
=
{},
*
,
num_tries
:
int
=
1
,
max_tries
:
int
=
5
,
)
->
BatchFeature
|
JSONTree
:
"""
Call `hf_processor` on the prompt `data`
(text, image, audio...) with configurable options `kwargs`.
"""
assert
callable
(
hf_processor
)
mm_config
=
self
.
model_config
.
get_multimodal_config
()
merged_kwargs
=
mm_config
.
merge_mm_processor_kwargs
(
kwargs
)
allowed_kwargs
=
get_allowed_kwarg_only_overrides
(
hf_processor
,
merged_kwargs
,
requires_kw_only
=
False
,
allow_var_kwargs
=
True
,
)
try
:
output
=
hf_processor
(
**
data
,
**
allowed_kwargs
,
return_tensors
=
"pt"
)
except
Exception
as
exc
:
# See https://github.com/huggingface/tokenizers/issues/537
if
(
isinstance
(
exc
,
RuntimeError
)
and
exc
and
exc
.
args
[
0
]
==
"Already borrowed"
and
num_tries
<
max_tries
):
logger
.
warning
(
"Failed to acquire tokenizer in current thread. "
"Retrying (%d/%d)..."
,
num_tries
,
max_tries
,
)
time
.
sleep
(
0.5
)
return
self
.
call_hf_processor
(
hf_processor
,
data
,
kwargs
,
num_tries
=
num_tries
+
1
,
max_tries
=
max_tries
,
)
msg
=
(
f
"Failed to apply
{
type
(
hf_processor
).
__name__
}
"
f
"on data=
{
data
}
with kwargs=
{
allowed_kwargs
}
"
)
raise
ValueError
(
msg
)
from
exc
# this emulates output.to(dtype=self.model_config.dtype)
from
transformers.feature_extraction_utils
import
BatchFeature
if
isinstance
(
output
,
BatchFeature
):
output_
=
self
.
_postprocess_output
(
output
.
data
)
return
BatchFeature
(
output_
)
logger
.
warning_once
(
"%s did not return `BatchFeature`. "
"Make sure to match the behaviour of `ProcessorMixin` when "
"implementing custom processors."
,
type
(
hf_processor
).
__name__
,
)
return
self
.
_postprocess_output
(
output
)
def
get_timing_stats
(
self
,
request_id
:
str
)
->
MultiModalProcessorTimingStats
|
None
:
"""
Get timing stats for a request.
"""
if
(
self
.
observability_config
is
None
or
not
self
.
observability_config
.
enable_mm_processor_stats
):
return
None
with
self
.
_timing_stats_registry_lock
:
return
self
.
timing_stats_registry
.
get
(
request_id
)
def
create_timing_stats
(
self
,
request_id
:
str
)
->
MultiModalProcessorTimingStats
:
"""
Create and store timing stats in the registry for a request.
This should be called at the start of processing for a request.
The stats object is created immediately and stored in the registry.
"""
if
(
self
.
observability_config
is
None
or
not
self
.
observability_config
.
enable_mm_processor_stats
):
return
MultiModalProcessorTimingStats
()
with
self
.
_timing_stats_registry_lock
:
if
request_id
in
self
.
timing_stats_registry
:
raise
ValueError
(
f
"Timing stats already exist for request_id:
{
request_id
}
"
)
stats
=
MultiModalProcessorTimingStats
()
self
.
timing_stats_registry
[
request_id
]
=
stats
return
stats
def
clear_timing_stats_registry
(
self
)
->
int
:
"""
Clear all stats from the registry. Returns the number of stats cleared.
"""
if
(
self
.
observability_config
is
None
or
not
self
.
observability_config
.
enable_mm_processor_stats
):
return
0
with
self
.
_timing_stats_registry_lock
:
count
=
len
(
self
.
timing_stats_registry
)
self
.
timing_stats_registry
.
clear
()
return
count
def
get_all_timing_stats
(
self
)
->
dict
[
str
,
dict
[
str
,
float
]]:
"""
Get all timing stats as a dictionary for API endpoints.
"""
if
(
self
.
observability_config
is
None
or
not
self
.
observability_config
.
enable_mm_processor_stats
):
return
{}
with
self
.
_timing_stats_registry_lock
:
return
{
rid
:
stats
.
to_dict
()
for
rid
,
stats
in
self
.
timing_stats_registry
.
items
()
}
class
BaseProcessingInfo
:
"""Base class to provide the information necessary for data processing."""
def
__init__
(
self
,
ctx
:
InputProcessingContext
)
->
None
:
super
().
__init__
()
self
.
ctx
=
ctx
@
property
def
model_id
(
self
)
->
str
:
return
self
.
ctx
.
model_config
.
model
def
get_tokenizer
(
self
)
->
TokenizerLike
:
return
self
.
ctx
.
get_tokenizer
()
def
get_hf_config
(
self
)
->
PretrainedConfig
:
return
self
.
ctx
.
get_hf_config
()
def
get_hf_processor
(
self
,
**
kwargs
:
object
)
->
ProcessorMixin
:
"""
Subclasses can override this method to handle
specific kwargs from model config or user inputs.
"""
return
self
.
ctx
.
get_hf_processor
(
**
kwargs
)
@
property
def
skip_prompt_length_check
(
self
)
->
bool
:
return
False
@
abstractmethod
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
int
|
None
]:
"""
Return the maximum supported number of items for each modality.
A value of `None` means unlimited number of items.
Omitting a modality from the returned dictionary means that
it is not supported at all.
"""
raise
NotImplementedError
def
get_allowed_mm_limits
(
self
)
->
Mapping
[
str
,
int
]:
"""Return the maximum allowed number of items for each modality."""
supported_mm_limits
=
self
.
get_supported_mm_limits
()
mm_config
=
self
.
ctx
.
get_mm_config
()
allowed_limits
=
dict
[
str
,
int
]()
for
modality
,
supported_limit
in
supported_mm_limits
.
items
():
user_limit
=
mm_config
.
get_limit_per_prompt
(
modality
)
allowed_limits
[
modality
]
=
(
user_limit
if
supported_limit
is
None
else
min
(
user_limit
,
supported_limit
)
)
return
allowed_limits
def
get_mm_max_tokens_per_item
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
Mapping
[
str
,
int
]
|
None
:
"""
Return the maximum number of tokens per item of for each modality.
When `None` (the default) is returned, vLLM will generate dummy inputs
(images/videos) at maximum possible sizes and process them to determine
the maximum token count per modality.
This approach works but can be very slow for certain models (e.g.,
Qwen2.5-VL), leading to very long startup time. For better performance,
each model can override this method to return pre-computed maximum token
counts, avoiding the need for dummy input generation and processing.
Note:
The maximum number of tokens per item of each modality returned
from this function should respect the model's maximum sequence
length and the maximum number of items of each modality allowed,
and agree with dummy inputs (images/videos) at maximum possible
sizes.
"""
return
None
vllm/multimodal/pro
filing
.py
→
vllm/multimodal/pro
cessing/dummy_inputs
.py
View file @
9ea07b41
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
collections.abc
import
Mapping
from
collections.abc
import
Mapping
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
from
typing
import
TYPE_CHECKING
,
Generic
from
typing
import
Generic
,
TypeVar
import
numpy
as
np
import
numpy
as
np
import
numpy.typing
as
npt
import
numpy.typing
as
npt
...
@@ -17,14 +17,10 @@ from vllm.config.multimodal import (
...
@@ -17,14 +17,10 @@ from vllm.config.multimodal import (
)
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
.inputs
import
MultiModalDataDict
from
..inputs
import
MultiModalDataDict
from
.context
import
BaseProcessingInfo
if
TYPE_CHECKING
:
_I
=
TypeVar
(
"_I"
,
bound
=
BaseProcessingInfo
)
from
.processing
import
_I
else
:
from
typing
import
TypeVar
_I
=
TypeVar
(
"_I"
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
...
vllm/multimodal/processing.py
→
vllm/multimodal/processing
/processor
.py
View file @
9ea07b41
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
contextvars
import
threading
import
time
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
collections
import
defaultdict
from
collections
import
defaultdict
from
collections.abc
import
Callable
,
Generator
,
ItemsView
,
Iterable
,
Mapping
,
Sequence
from
collections.abc
import
Callable
,
Generator
,
ItemsView
,
Iterable
,
Mapping
,
Sequence
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
,
field
,
replace
from
dataclasses
import
dataclass
,
field
,
replace
from
enum
import
Enum
from
enum
import
Enum
from
functools
import
lru_cache
from
functools
import
lru_cache
from
typing
import
(
from
typing
import
(
TYPE_CHECKING
,
TYPE_CHECKING
,
Any
,
Generic
,
Generic
,
NamedTuple
,
NamedTuple
,
Protocol
,
Protocol
,
TypeAlias
,
TypeAlias
,
cast
,
cast
,
overload
,
)
)
import
regex
as
re
import
regex
as
re
...
@@ -27,16 +21,14 @@ from typing_extensions import TypeVar, assert_never
...
@@ -27,16 +21,14 @@ from typing_extensions import TypeVar, assert_never
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.tokenizers
import
TokenizerLike
from
vllm.tokenizers
import
TokenizerLike
from
vllm.transformers_utils.processor
import
cached_processor_from_config
from
vllm.utils.collection_utils
import
flatten_2d_lists
,
full_groupby
from
vllm.utils.collection_utils
import
flatten_2d_lists
,
full_groupby
from
vllm.utils.func_utils
import
get_allowed_kwarg_only_overrides
from
vllm.utils.jsontree
import
JSONTree
,
json_map_leaves
from
.hasher
import
MultiModalHasher
from
.
.hasher
import
MultiModalHasher
from
.inputs
import
(
from
.
.inputs
import
(
MultiModalDataDict
,
MultiModalDataDict
,
MultiModalEncDecInputs
,
MultiModalEncDecInputs
,
MultiModalFieldConfig
,
MultiModalFieldConfig
,
MultiModalHashes
,
MultiModalInputs
,
MultiModalInputs
,
MultiModalKwargsItem
,
MultiModalKwargsItem
,
MultiModalKwargsItems
,
MultiModalKwargsItems
,
...
@@ -44,29 +36,21 @@ from .inputs import (
...
@@ -44,29 +36,21 @@ from .inputs import (
MultiModalUUIDDict
,
MultiModalUUIDDict
,
PlaceholderRange
,
PlaceholderRange
,
)
)
from
.parse
import
(
from
.
.parse
import
(
DictEmbeddingItems
,
DictEmbeddingItems
,
EmbeddingItems
,
EmbeddingItems
,
MultiModalDataItems
,
MultiModalDataItems
,
MultiModalDataParser
,
MultiModalDataParser
,
)
)
from
.profiling
import
BaseDummyInputsBuilder
from
.context
import
BaseProcessingInfo
,
get_current_request_id
,
timed_operation
from
.dummy_inputs
import
BaseDummyInputsBuilder
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
transformers.configuration_utils
import
PretrainedConfig
from
transformers.feature_extraction_utils
import
BatchFeature
from
transformers.feature_extraction_utils
import
BatchFeature
from
transformers.processing_utils
import
ProcessorMixin
from
vllm.config
import
ModelConfig
,
ObservabilityConfig
from
..cache
import
BaseMultiModalProcessorCache
from
.cache
import
BaseMultiModalProcessorCache
else
:
else
:
PretrainedConfig
=
object
BatchFeature
=
object
BatchFeature
=
object
ProcessorMixin
=
object
ModelConfig
=
object
ObservabilityConfig
=
object
BaseMultiModalProcessorCache
=
object
BaseMultiModalProcessorCache
=
object
...
@@ -74,126 +58,6 @@ logger = init_logger(__name__)
...
@@ -74,126 +58,6 @@ logger = init_logger(__name__)
_S
=
TypeVar
(
"_S"
,
str
,
list
[
int
])
_S
=
TypeVar
(
"_S"
,
str
,
list
[
int
])
_request_id_context
:
contextvars
.
ContextVar
[
str
|
None
]
=
contextvars
.
ContextVar
(
"_request_id_context"
,
default
=
None
)
def
get_current_request_id
()
->
str
|
None
:
"""Get the current request_id from the context, if available."""
return
_request_id_context
.
get
()
@
contextmanager
def
set_request_id
(
request_id
:
str
)
->
Generator
[
None
,
None
,
None
]:
"""Context manager to set the request_id for the current context."""
token
=
_request_id_context
.
set
(
request_id
)
try
:
yield
finally
:
_request_id_context
.
reset
(
token
)
@
dataclass
class
MultiModalProcessorTimingStats
:
"""Per-request timing statistics for multimodal processor stages."""
hf_processor_time
:
float
=
0.0
"""Time spent in HuggingFace processor calls (seconds)."""
hashing_time
:
float
=
0.0
"""Time spent computing multimodal item hashes (seconds)."""
cache_lookup_time
:
float
=
0.0
"""Time spent in cache lookups and merges (seconds)."""
prompt_update_time
:
float
=
0.0
"""Time spent applying prompt updates and finding placeholders (seconds)."""
total_time
:
float
=
0.0
"""Total processing time (seconds)."""
def
to_dict
(
self
)
->
dict
[
str
,
float
]:
"""Convert stats to a dictionary for JSON serialization."""
return
{
"hf_processor_time"
:
self
.
hf_processor_time
,
"hashing_time"
:
self
.
hashing_time
,
"cache_lookup_time"
:
self
.
cache_lookup_time
,
"prompt_update_time"
:
self
.
prompt_update_time
,
"total_time"
:
self
.
total_time
,
}
def
get_timing_stats_from_engine_client
(
engine_client
:
Any
,
)
->
dict
[
str
,
dict
[
str
,
float
]]:
"""
Get all timing stats from the context associated with the engine client.
Args:
engine_client: The engine client that has input_processor.
Returns:
A dictionary mapping request_id to stats dict.
"""
try
:
if
not
engine_client
.
vllm_config
.
observability_config
.
enable_mm_processor_stats
:
return
{}
except
(
AttributeError
,
RuntimeError
):
return
{}
try
:
input_processor
=
engine_client
.
input_processor
input_preprocessor
=
input_processor
.
input_preprocessor
if
hasattr
(
input_preprocessor
,
"_get_mm_processor"
):
mm_processor
=
input_preprocessor
.
_get_mm_processor
()
if
mm_processor
is
not
None
and
hasattr
(
mm_processor
,
"info"
):
ctx
=
mm_processor
.
info
.
ctx
return
ctx
.
get_all_timing_stats
()
except
(
AttributeError
,
RuntimeError
):
pass
return
{}
@
contextmanager
def
_timed_operation
(
ctx
:
"InputProcessingContext"
,
stage_name
:
str
):
"""
Context manager to time an operation using the context's timing stats.
The request_id is automatically retrieved from the context variable,
so it doesn't need to be passed as a parameter.
Args:
ctx: The InputProcessingContext containing the timing stats registry.
stage_name: Name of the stage being timed.
"""
request_id
=
get_current_request_id
()
if
ctx
is
None
or
request_id
is
None
:
yield
return
stats
=
ctx
.
get_timing_stats
(
request_id
)
if
stats
is
None
:
yield
return
start_time
=
time
.
perf_counter
()
try
:
yield
finally
:
elapsed
=
time
.
perf_counter
()
-
start_time
if
stage_name
==
"hf_processor"
:
stats
.
hf_processor_time
+=
elapsed
elif
stage_name
==
"hashing"
:
stats
.
hashing_time
+=
elapsed
elif
stage_name
==
"cache_lookup"
:
stats
.
cache_lookup_time
+=
elapsed
elif
stage_name
==
"prompt_update"
:
stats
.
prompt_update_time
+=
elapsed
stats
.
total_time
+=
elapsed
PromptSeq
:
TypeAlias
=
str
|
list
[
int
]
PromptSeq
:
TypeAlias
=
str
|
list
[
int
]
"""A token sequence (list of token IDs) or text."""
"""A token sequence (list of token IDs) or text."""
...
@@ -1073,412 +937,6 @@ def find_mm_placeholders(
...
@@ -1073,412 +937,6 @@ def find_mm_placeholders(
return
dict
(
full_groupby_modality
(
it
))
return
dict
(
full_groupby_modality
(
it
))
_T
=
TypeVar
(
"_T"
)
_C
=
TypeVar
(
"_C"
,
bound
=
PretrainedConfig
,
default
=
PretrainedConfig
)
_P
=
TypeVar
(
"_P"
,
bound
=
ProcessorMixin
,
default
=
ProcessorMixin
)
@
dataclass
(
frozen
=
True
)
class
InputProcessingContext
:
"""
Contains information about the model which may be used to
modify the inputs.
"""
model_config
:
ModelConfig
"""The configuration of the model."""
tokenizer
:
TokenizerLike
|
None
"""The tokenizer used to tokenize the inputs."""
observability_config
:
"ObservabilityConfig | None"
=
field
(
default
=
None
,
compare
=
False
,
repr
=
False
)
"""Configuration for observability features."""
timing_stats_registry
:
dict
[
str
,
MultiModalProcessorTimingStats
]
=
field
(
default_factory
=
dict
,
compare
=
False
,
repr
=
False
)
"""Registry for storing timing stats keyed by request_id."""
_timing_stats_registry_lock
:
threading
.
Lock
=
field
(
default_factory
=
threading
.
Lock
,
compare
=
False
,
repr
=
False
)
"""Lock for thread-safe access to timing_stats_registry."""
def
get_tokenizer
(
self
)
->
TokenizerLike
:
if
self
.
tokenizer
is
None
:
raise
ValueError
(
"You cannot pass text prompts when `skip_tokenizer_init=True`"
)
return
self
.
tokenizer
@
overload
def
get_hf_config
(
self
,
/
)
->
PretrainedConfig
:
...
@
overload
def
get_hf_config
(
self
,
typ
:
type
[
_C
]
|
tuple
[
type
[
_C
],
...],
/
,
)
->
_C
:
...
def
get_hf_config
(
self
,
typ
:
type
[
Any
]
|
tuple
[
type
[
Any
],
...]
|
None
=
None
,
/
,
)
->
Any
:
"""
Get the HuggingFace configuration
(`transformers.PretrainedConfig`) of the model,
additionally checking its type.
Raises:
TypeError: If the configuration is not of the specified type.
"""
if
typ
is
None
:
from
transformers.configuration_utils
import
PretrainedConfig
typ
=
PretrainedConfig
hf_config
=
self
.
model_config
.
hf_config
if
not
isinstance
(
hf_config
,
typ
):
raise
TypeError
(
"Invalid type of HuggingFace config. "
f
"Expected type:
{
typ
}
, but "
f
"found type:
{
type
(
hf_config
)
}
"
)
return
hf_config
def
get_hf_image_processor_config
(
self
)
->
dict
[
str
,
Any
]:
"""
Get the HuggingFace image processor configuration of the model.
"""
return
self
.
model_config
.
hf_image_processor_config
def
get_mm_config
(
self
):
"""
Get the multimodal config of the model.
Raises:
RuntimeError: If the model is not a multimodal model.
"""
mm_config
=
self
.
model_config
.
multimodal_config
if
mm_config
is
None
:
raise
RuntimeError
(
"Not a multimodal model"
)
return
mm_config
@
overload
def
get_hf_processor
(
self
,
/
,
**
kwargs
:
object
)
->
ProcessorMixin
:
...
@
overload
def
get_hf_processor
(
self
,
typ
:
type
[
_P
]
|
tuple
[
type
[
_P
],
...],
/
,
**
kwargs
:
object
,
)
->
_P
:
...
def
get_hf_processor
(
self
,
typ
:
type
[
Any
]
|
tuple
[
type
[
Any
],
...]
|
None
=
None
,
/
,
**
kwargs
:
object
,
)
->
Any
:
"""
Get the HuggingFace processor
(`transformers.ProcessorMixin`) of the model,
additionally checking its type.
Raises:
TypeError: If the processor is not of the specified type.
"""
if
typ
is
None
:
from
transformers.processing_utils
import
ProcessorMixin
typ
=
ProcessorMixin
from
vllm.tokenizers.mistral
import
MistralTokenizer
tokenizer
=
self
.
tokenizer
if
isinstance
(
tokenizer
,
MistralTokenizer
):
tokenizer
=
tokenizer
.
transformers_tokenizer
return
cached_processor_from_config
(
self
.
model_config
,
processor_cls
=
typ
,
tokenizer
=
tokenizer
,
**
kwargs
,
)
def
init_processor
(
self
,
typ
:
type
[
_T
],
/
,
**
kwargs
:
object
,
)
->
_T
:
"""
Initialize a HuggingFace-like processor class, merging the
keyword arguments with those in the model's configuration.
"""
mm_config
=
self
.
model_config
.
get_multimodal_config
()
base_kwargs
=
mm_config
.
mm_processor_kwargs
if
base_kwargs
is
None
:
base_kwargs
=
{}
merged_kwargs
=
{
**
base_kwargs
,
**
kwargs
}
return
typ
(
**
merged_kwargs
)
def
_postprocess_output
(
self
,
output
:
JSONTree
,
)
->
JSONTree
:
def
_postprocess_one
(
x
:
object
):
if
isinstance
(
x
,
torch
.
Tensor
):
# noqa: SIM102
# This mimics the behavior of transformers.BatchFeature
if
x
.
is_floating_point
():
x
=
x
.
to
(
dtype
=
self
.
model_config
.
dtype
)
return
x
return
json_map_leaves
(
_postprocess_one
,
output
)
def
call_hf_processor
(
self
,
hf_processor
:
ProcessorMixin
,
data
:
Mapping
[
str
,
object
],
kwargs
:
Mapping
[
str
,
object
]
=
{},
*
,
num_tries
:
int
=
1
,
max_tries
:
int
=
5
,
)
->
BatchFeature
|
JSONTree
:
"""
Call `hf_processor` on the prompt `data`
(text, image, audio...) with configurable options `kwargs`.
"""
assert
callable
(
hf_processor
)
mm_config
=
self
.
model_config
.
get_multimodal_config
()
merged_kwargs
=
mm_config
.
merge_mm_processor_kwargs
(
kwargs
)
allowed_kwargs
=
get_allowed_kwarg_only_overrides
(
hf_processor
,
merged_kwargs
,
requires_kw_only
=
False
,
allow_var_kwargs
=
True
,
)
try
:
output
=
hf_processor
(
**
data
,
**
allowed_kwargs
,
return_tensors
=
"pt"
)
except
Exception
as
exc
:
# See https://github.com/huggingface/tokenizers/issues/537
if
(
isinstance
(
exc
,
RuntimeError
)
and
exc
and
exc
.
args
[
0
]
==
"Already borrowed"
and
num_tries
<
max_tries
):
logger
.
warning
(
"Failed to acquire tokenizer in current thread. "
"Retrying (%d/%d)..."
,
num_tries
,
max_tries
,
)
time
.
sleep
(
0.5
)
return
self
.
call_hf_processor
(
hf_processor
,
data
,
kwargs
,
num_tries
=
num_tries
+
1
,
max_tries
=
max_tries
,
)
msg
=
(
f
"Failed to apply
{
type
(
hf_processor
).
__name__
}
"
f
"on data=
{
data
}
with kwargs=
{
allowed_kwargs
}
"
)
raise
ValueError
(
msg
)
from
exc
# this emulates output.to(dtype=self.model_config.dtype)
from
transformers.feature_extraction_utils
import
BatchFeature
if
isinstance
(
output
,
BatchFeature
):
output_
=
self
.
_postprocess_output
(
output
.
data
)
return
BatchFeature
(
output_
)
logger
.
warning_once
(
"%s did not return `BatchFeature`. "
"Make sure to match the behaviour of `ProcessorMixin` when "
"implementing custom processors."
,
type
(
hf_processor
).
__name__
,
)
return
self
.
_postprocess_output
(
output
)
def
get_timing_stats
(
self
,
request_id
:
str
)
->
MultiModalProcessorTimingStats
|
None
:
"""
Get timing stats for a request.
"""
if
(
self
.
observability_config
is
None
or
not
self
.
observability_config
.
enable_mm_processor_stats
):
return
None
with
self
.
_timing_stats_registry_lock
:
return
self
.
timing_stats_registry
.
get
(
request_id
)
def
create_timing_stats
(
self
,
request_id
:
str
)
->
MultiModalProcessorTimingStats
:
"""
Create and store timing stats in the registry for a request.
This should be called at the start of processing for a request.
The stats object is created immediately and stored in the registry.
"""
if
(
self
.
observability_config
is
None
or
not
self
.
observability_config
.
enable_mm_processor_stats
):
return
MultiModalProcessorTimingStats
()
with
self
.
_timing_stats_registry_lock
:
if
request_id
in
self
.
timing_stats_registry
:
raise
ValueError
(
f
"Timing stats already exist for request_id:
{
request_id
}
"
)
stats
=
MultiModalProcessorTimingStats
()
self
.
timing_stats_registry
[
request_id
]
=
stats
return
stats
def
clear_timing_stats_registry
(
self
)
->
int
:
"""
Clear all stats from the registry. Returns the number of stats cleared.
"""
if
(
self
.
observability_config
is
None
or
not
self
.
observability_config
.
enable_mm_processor_stats
):
return
0
with
self
.
_timing_stats_registry_lock
:
count
=
len
(
self
.
timing_stats_registry
)
self
.
timing_stats_registry
.
clear
()
return
count
def
get_all_timing_stats
(
self
)
->
dict
[
str
,
dict
[
str
,
float
]]:
"""
Get all timing stats as a dictionary for API endpoints.
"""
if
(
self
.
observability_config
is
None
or
not
self
.
observability_config
.
enable_mm_processor_stats
):
return
{}
with
self
.
_timing_stats_registry_lock
:
return
{
rid
:
stats
.
to_dict
()
for
rid
,
stats
in
self
.
timing_stats_registry
.
items
()
}
class
BaseProcessingInfo
:
"""Base class to provide the information necessary for data processing."""
def
__init__
(
self
,
ctx
:
InputProcessingContext
)
->
None
:
super
().
__init__
()
self
.
ctx
=
ctx
@
property
def
model_id
(
self
)
->
str
:
return
self
.
ctx
.
model_config
.
model
def
get_tokenizer
(
self
)
->
TokenizerLike
:
return
self
.
ctx
.
get_tokenizer
()
def
get_hf_config
(
self
)
->
PretrainedConfig
:
return
self
.
ctx
.
get_hf_config
()
def
get_hf_processor
(
self
,
**
kwargs
:
object
)
->
ProcessorMixin
:
"""
Subclasses can override this method to handle
specific kwargs from model config or user inputs.
"""
return
self
.
ctx
.
get_hf_processor
(
**
kwargs
)
@
property
def
skip_prompt_length_check
(
self
)
->
bool
:
return
False
@
abstractmethod
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
int
|
None
]:
"""
Return the maximum supported number of items for each modality.
A value of `None` means unlimited number of items.
Omitting a modality from the returned dictionary means that
it is not supported at all.
"""
raise
NotImplementedError
def
get_allowed_mm_limits
(
self
)
->
Mapping
[
str
,
int
]:
"""Return the maximum allowed number of items for each modality."""
supported_mm_limits
=
self
.
get_supported_mm_limits
()
mm_config
=
self
.
ctx
.
get_mm_config
()
allowed_limits
=
dict
[
str
,
int
]()
for
modality
,
supported_limit
in
supported_mm_limits
.
items
():
user_limit
=
mm_config
.
get_limit_per_prompt
(
modality
)
allowed_limits
[
modality
]
=
(
user_limit
if
supported_limit
is
None
else
min
(
user_limit
,
supported_limit
)
)
return
allowed_limits
def
get_mm_max_tokens_per_item
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
Mapping
[
str
,
int
]
|
None
:
"""
Return the maximum number of tokens per item of for each modality.
When `None` (the default) is returned, vLLM will generate dummy inputs
(images/videos) at maximum possible sizes and process them to determine
the maximum token count per modality.
This approach works but can be very slow for certain models (e.g.,
Qwen2.5-VL), leading to very long startup time. For better performance,
each model can override this method to return pre-computed maximum token
counts, avoiding the need for dummy input generation and processing.
Note:
The maximum number of tokens per item of each modality returned
from this function should respect the model's maximum sequence
length and the maximum number of items of each modality allowed,
and agree with dummy inputs (images/videos) at maximum possible
sizes.
"""
return
None
_I
=
TypeVar
(
"_I"
,
bound
=
BaseProcessingInfo
)
MultiModalHashes
=
dict
[
str
,
list
[
str
]]
"""
A collection of the multi-modal hash for each item, with a similar structure as
[`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems].
"""
MultiModalIsCached
=
dict
[
str
,
list
[
bool
]]
MultiModalIsCached
=
dict
[
str
,
list
[
bool
]]
"""
"""
A collection of the `is_cached` flag for each item, with a similar structure as
A collection of the `is_cached` flag for each item, with a similar structure as
...
@@ -1499,6 +957,8 @@ For an item `MultiModalPromptUpdates[k][i]`,
...
@@ -1499,6 +957,8 @@ For an item `MultiModalPromptUpdates[k][i]`,
`ResolvedPromptUpdate` instances have been applied.
`ResolvedPromptUpdate` instances have been applied.
"""
"""
_I
=
TypeVar
(
"_I"
,
bound
=
BaseProcessingInfo
)
class
MultiModalProcessingInfo
(
NamedTuple
):
class
MultiModalProcessingInfo
(
NamedTuple
):
kwargs
:
MultiModalKwargsOptionalItems
kwargs
:
MultiModalKwargsOptionalItems
...
@@ -1732,7 +1192,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
...
@@ -1732,7 +1192,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
Call the HF processor on the prompt text and
Call the HF processor on the prompt text and
associated multi-modal data.
associated multi-modal data.
"""
"""
with
_
timed_operation
(
self
.
info
.
ctx
,
"hf_processor"
):
with
timed_operation
(
self
.
info
.
ctx
,
"hf_processor"
):
return
self
.
info
.
ctx
.
call_hf_processor
(
return
self
.
info
.
ctx
.
call_hf_processor
(
self
.
info
.
get_hf_processor
(
**
mm_kwargs
),
self
.
info
.
get_hf_processor
(
**
mm_kwargs
),
dict
(
text
=
prompt
,
**
mm_data
),
dict
(
text
=
prompt
,
**
mm_data
),
...
@@ -1841,7 +1301,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
...
@@ -1841,7 +1301,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
Since HF processor requires that text and multi-modal items
Since HF processor requires that text and multi-modal items
correspond to each other, we generate dummy text using
correspond to each other, we generate dummy text using
[`DummyInputsBuilder`][vllm.multimodal.pro
fil
ing.BaseDummyInputsBuilder]
[`DummyInputsBuilder`][vllm.multimodal.pro
cess
ing.BaseDummyInputsBuilder]
to go along with the multi-modal data.
to go along with the multi-modal data.
"""
"""
mm_counts
=
mm_items
.
get_all_counts
()
mm_counts
=
mm_items
.
get_all_counts
()
...
@@ -2085,7 +1545,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
...
@@ -2085,7 +1545,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
)
)
# Use overrides if provided; fallback to data-dependent hashing.
# Use overrides if provided; fallback to data-dependent hashing.
with
_
timed_operation
(
self
.
info
.
ctx
,
"hashing"
):
with
timed_operation
(
self
.
info
.
ctx
,
"hashing"
):
mm_hashes
=
self
.
_hash_mm_items
(
mm_hashes
=
self
.
_hash_mm_items
(
mm_data_items
,
mm_data_items
,
hf_processor_mm_kwargs
,
hf_processor_mm_kwargs
,
...
@@ -2132,7 +1592,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
...
@@ -2132,7 +1592,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_uuids
=
mm_uuids
,
mm_uuids
=
mm_uuids
,
)
)
with
_
timed_operation
(
self
.
info
.
ctx
,
"hashing"
):
with
timed_operation
(
self
.
info
.
ctx
,
"hashing"
):
mm_hashes
=
self
.
_hash_mm_items
(
mm_hashes
=
self
.
_hash_mm_items
(
mm_data_items
,
mm_data_items
,
hf_processor_mm_kwargs
,
hf_processor_mm_kwargs
,
...
@@ -2140,7 +1600,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
...
@@ -2140,7 +1600,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_uuids
=
mm_uuids
,
mm_uuids
=
mm_uuids
,
)
)
with
_
timed_operation
(
self
.
info
.
ctx
,
"cache_lookup"
):
with
timed_operation
(
self
.
info
.
ctx
,
"cache_lookup"
):
mm_is_cached
,
mm_missing_data_items
=
self
.
_get_cache_missing_items
(
mm_is_cached
,
mm_missing_data_items
=
self
.
_get_cache_missing_items
(
cache
=
cache
,
cache
=
cache
,
mm_data_items
=
mm_data_items
,
mm_data_items
=
mm_data_items
,
...
@@ -2175,7 +1635,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
...
@@ -2175,7 +1635,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_missing_kwargs
,
mm_missing_kwargs
,
)
)
with
_
timed_operation
(
self
.
info
.
ctx
,
"cache_lookup"
):
with
timed_operation
(
self
.
info
.
ctx
,
"cache_lookup"
):
mm_kwargs
,
mm_prompt_updates
=
self
.
_merge_mm_kwargs
(
mm_kwargs
,
mm_prompt_updates
=
self
.
_merge_mm_kwargs
(
cache
,
cache
,
mm_hashes
=
mm_hashes
,
mm_hashes
=
mm_hashes
,
...
@@ -2386,7 +1846,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
...
@@ -2386,7 +1846,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
)
)
# NOTE: tokenization_kwargs are not required to init processor
# NOTE: tokenization_kwargs are not required to init processor
with
_
timed_operation
(
self
.
info
.
ctx
,
"prompt_update"
):
with
timed_operation
(
self
.
info
.
ctx
,
"prompt_update"
):
prompt_ids
,
mm_placeholders
=
self
.
_maybe_apply_prompt_updates
(
prompt_ids
,
mm_placeholders
=
self
.
_maybe_apply_prompt_updates
(
mm_items
=
mm_items
,
mm_items
=
mm_items
,
prompt_ids
=
prompt_ids
,
prompt_ids
=
prompt_ids
,
...
...
vllm/multimodal/registry.py
View file @
9ea07b41
...
@@ -12,11 +12,11 @@ from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
...
@@ -12,11 +12,11 @@ from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
from
.cache
import
BaseMultiModalProcessorCache
from
.cache
import
BaseMultiModalProcessorCache
from
.inputs
import
MultiModalInputs
from
.inputs
import
MultiModalInputs
from
.processing
import
(
from
.processing
import
(
BaseDummyInputsBuilder
,
BaseMultiModalProcessor
,
BaseMultiModalProcessor
,
BaseProcessingInfo
,
BaseProcessingInfo
,
InputProcessingContext
,
InputProcessingContext
,
)
)
from
.profiling
import
BaseDummyInputsBuilder
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.config
import
ModelConfig
,
ObservabilityConfig
from
vllm.config
import
ModelConfig
,
ObservabilityConfig
...
@@ -45,7 +45,7 @@ class ProcessingInfoFactory(Protocol[_I_co]):
...
@@ -45,7 +45,7 @@ class ProcessingInfoFactory(Protocol[_I_co]):
class
DummyInputsBuilderFactory
(
Protocol
[
_I
]):
# type: ignore[misc]
class
DummyInputsBuilderFactory
(
Protocol
[
_I
]):
# type: ignore[misc]
"""
"""
Constructs a
Constructs a
[`BaseDummyInputsBuilder`][vllm.multimodal.pro
fil
ing.BaseDummyInputsBuilder]
[`BaseDummyInputsBuilder`][vllm.multimodal.pro
cess
ing.BaseDummyInputsBuilder]
instance from the context.
instance from the context.
"""
"""
...
...
vllm/v1/engine/input_processor.py
View file @
9ea07b41
...
@@ -17,7 +17,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
...
@@ -17,7 +17,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from
vllm.multimodal.cache
import
processor_cache_from_config
from
vllm.multimodal.cache
import
processor_cache_from_config
from
vllm.multimodal.inputs
import
MultiModalFeatureSpec
,
MultiModalUUIDDict
from
vllm.multimodal.inputs
import
MultiModalFeatureSpec
,
MultiModalUUIDDict
from
vllm.multimodal.parse
import
MultiModalDataParser
from
vllm.multimodal.parse
import
MultiModalDataParser
from
vllm.multimodal.processing
import
set_request_id
from
vllm.multimodal.processing
.context
import
set_request_id
from
vllm.multimodal.utils
import
argsort_mm_positions
from
vllm.multimodal.utils
import
argsort_mm_positions
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
_SAMPLING_EPS
,
SamplingParams
from
vllm.sampling_params
import
_SAMPLING_EPS
,
SamplingParams
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment