Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
69244e67
Unverified
Commit
69244e67
authored
Aug 27, 2025
by
Cyrus Leung
Committed by
GitHub
Aug 27, 2025
Browse files
[Core] Use key-only cache for `BaseMultiModalProcessor` (#23018)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
8dbf6ed7
Changes
29
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
75 additions
and
203 deletions
+75
-203
vllm/multimodal/registry.py
vllm/multimodal/registry.py
+36
-54
vllm/v1/engine/async_llm.py
vllm/v1/engine/async_llm.py
+1
-2
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+10
-7
vllm/v1/engine/llm_engine.py
vllm/v1/engine/llm_engine.py
+1
-2
vllm/v1/engine/mm_input_cache.py
vllm/v1/engine/mm_input_cache.py
+0
-121
vllm/v1/engine/processor.py
vllm/v1/engine/processor.py
+14
-15
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+3
-0
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+3
-0
vllm/v1/worker/utils.py
vllm/v1/worker/utils.py
+7
-2
No files found.
vllm/multimodal/registry.py
View file @
69244e67
...
@@ -2,7 +2,6 @@
...
@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Mapping
from
collections.abc
import
Mapping
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
functools
import
lru_cache
from
typing
import
TYPE_CHECKING
,
Generic
,
Optional
,
Protocol
,
TypeVar
from
typing
import
TYPE_CHECKING
,
Generic
,
Optional
,
Protocol
,
TypeVar
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -13,8 +12,9 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer,
...
@@ -13,8 +12,9 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer,
cached_tokenizer_from_config
)
cached_tokenizer_from_config
)
from
vllm.utils
import
ClassRegistry
from
vllm.utils
import
ClassRegistry
from
.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
from
.cache
import
(
BaseMultiModalProcessorCache
,
ProcessingCache
)
processor_only_cache_from_config
)
from
.processing
import
BaseMultiModalProcessor
,
BaseProcessingInfo
from
.profiling
import
(
BaseDummyInputsBuilder
,
DummyDecoderData
,
from
.profiling
import
(
BaseDummyInputsBuilder
,
DummyDecoderData
,
DummyEncoderData
,
MultiModalProfiler
)
DummyEncoderData
,
MultiModalProfiler
)
...
@@ -65,7 +65,7 @@ class MultiModalProcessorFactory(Protocol[_I]):
...
@@ -65,7 +65,7 @@ class MultiModalProcessorFactory(Protocol[_I]):
info
:
_I
,
info
:
_I
,
dummy_inputs
:
BaseDummyInputsBuilder
[
_I
],
dummy_inputs
:
BaseDummyInputsBuilder
[
_I
],
*
,
*
,
cache
:
Optional
[
Process
ing
Cache
]
=
None
,
cache
:
Optional
[
BaseMultiModal
Process
or
Cache
]
=
None
,
)
->
BaseMultiModalProcessor
[
_I
]:
)
->
BaseMultiModalProcessor
[
_I
]:
...
...
...
@@ -80,20 +80,13 @@ class _ProcessorFactories(Generic[_I]):
...
@@ -80,20 +80,13 @@ class _ProcessorFactories(Generic[_I]):
self
,
self
,
ctx
:
InputProcessingContext
,
ctx
:
InputProcessingContext
,
*
,
*
,
cache
:
Optional
[
Process
ing
Cache
]
=
None
,
cache
:
Optional
[
BaseMultiModal
Process
or
Cache
]
=
None
,
):
):
info
=
self
.
info
(
ctx
)
info
=
self
.
info
(
ctx
)
dummy_inputs_builder
=
self
.
dummy_inputs
(
info
)
dummy_inputs_builder
=
self
.
dummy_inputs
(
info
)
return
self
.
processor
(
info
,
dummy_inputs_builder
,
cache
=
cache
)
return
self
.
processor
(
info
,
dummy_inputs_builder
,
cache
=
cache
)
# Make sure a different cache is used for each model config
# NOTE: ModelConfig is not hashable so it cannot be passed directly
@
lru_cache
(
maxsize
=
1
)
def
_get_processor_cache
(
model_id
:
str
,
capacity_gb
:
int
):
return
ProcessingCache
(
capacity_gb
)
if
capacity_gb
>
0
else
None
class
MultiModalRegistry
:
class
MultiModalRegistry
:
"""
"""
A registry that dispatches data processing according to the model.
A registry that dispatches data processing according to the model.
...
@@ -103,31 +96,6 @@ class MultiModalRegistry:
...
@@ -103,31 +96,6 @@ class MultiModalRegistry:
self
.
_processor_factories
=
ClassRegistry
[
nn
.
Module
,
self
.
_processor_factories
=
ClassRegistry
[
nn
.
Module
,
_ProcessorFactories
]()
_ProcessorFactories
]()
def
_get_processor_cache
(
self
,
model_config
:
"ModelConfig"
):
model_id
=
model_config
.
model
capacity_gb
=
model_config
.
mm_processor_cache_gb
return
_get_processor_cache
(
model_id
,
capacity_gb
)
def
reset_processor_cache
(
self
,
model_config
:
"ModelConfig"
)
->
bool
:
"""Reset the multi-modal processing cache."""
if
processor_cache
:
=
self
.
_get_processor_cache
(
model_config
):
processor_cache
.
reset
()
return
True
# Success
def
enable_mm_input_cache
(
self
,
model_config
:
"ModelConfig"
)
->
bool
:
"""Whether the multi-modal input cache should be enabled.
NOTE: This is put under MultiModalRegistry on purpose to respect
text-only mode for multimodal models.
"""
if
not
self
.
supports_multimodal_inputs
(
model_config
):
return
False
mm_config
=
model_config
.
get_multimodal_config
()
return
mm_config
.
mm_processor_cache_gb
>
0
def
supports_multimodal_inputs
(
self
,
model_config
:
"ModelConfig"
)
->
bool
:
def
supports_multimodal_inputs
(
self
,
model_config
:
"ModelConfig"
)
->
bool
:
"""
"""
Checks if the model supports multimodal inputs.
Checks if the model supports multimodal inputs.
...
@@ -157,6 +125,8 @@ class MultiModalRegistry:
...
@@ -157,6 +125,8 @@ class MultiModalRegistry:
def
get_max_tokens_per_item_by_modality
(
def
get_max_tokens_per_item_by_modality
(
self
,
self
,
model_config
:
"ModelConfig"
,
model_config
:
"ModelConfig"
,
*
,
cache
:
Optional
[
BaseMultiModalProcessorCache
]
=
None
,
)
->
Mapping
[
str
,
int
]:
)
->
Mapping
[
str
,
int
]:
"""
"""
Get the maximum number of tokens per data item from each modality based
Get the maximum number of tokens per data item from each modality based
...
@@ -165,11 +135,11 @@ class MultiModalRegistry:
...
@@ -165,11 +135,11 @@ class MultiModalRegistry:
if
not
model_config
.
is_multimodal_model
:
if
not
model_config
.
is_multimodal_model
:
return
{}
return
{}
processor
=
self
.
create_processor
(
model_config
,
disable_
cache
=
Fals
e
)
processor
=
self
.
create_processor
(
model_config
,
cache
=
cach
e
)
profiler
=
MultiModalProfiler
(
processor
)
profiler
=
MultiModalProfiler
(
processor
)
seq_len
=
model_config
.
max_model_len
seq_len
=
model_config
.
max_model_len
mm_limits
=
self
.
get_mm_limits_per_prompt
(
model_config
)
mm_limits
=
self
.
get_mm_limits_per_prompt
(
model_config
,
cache
=
cache
)
return
profiler
.
get_mm_max_contiguous_tokens
(
return
profiler
.
get_mm_max_contiguous_tokens
(
seq_len
,
seq_len
,
...
@@ -182,6 +152,8 @@ class MultiModalRegistry:
...
@@ -182,6 +152,8 @@ class MultiModalRegistry:
def
get_max_tokens_per_item_by_nonzero_modality
(
def
get_max_tokens_per_item_by_nonzero_modality
(
self
,
self
,
model_config
:
"ModelConfig"
,
model_config
:
"ModelConfig"
,
*
,
cache
:
Optional
[
BaseMultiModalProcessorCache
]
=
None
,
)
->
Mapping
[
str
,
int
]:
)
->
Mapping
[
str
,
int
]:
"""
"""
Get the maximum number of tokens per data item from each modality based
Get the maximum number of tokens per data item from each modality based
...
@@ -192,15 +164,19 @@ class MultiModalRegistry:
...
@@ -192,15 +164,19 @@ class MultiModalRegistry:
This is currently directly used only in V1 for profiling the memory
This is currently directly used only in V1 for profiling the memory
usage of a model.
usage of a model.
"""
"""
mm_limits
=
self
.
get_mm_limits_per_prompt
(
model_config
)
mm_limits
=
self
.
get_mm_limits_per_prompt
(
model_config
,
cache
=
cache
)
max_tokens_per_item
=
self
.
get_max_tokens_per_item_by_modality
(
model_config
,
cache
=
cache
,
)
return
{
return
{
key
:
max_tokens_per_mm_item
key
:
max_tokens_per_mm_item
for
key
,
max_tokens_per_mm_item
in
for
key
,
max_tokens_per_mm_item
in
max_tokens_per_item
.
items
()
self
.
get_max_tokens_per_item_by_modality
(
model_config
).
items
()
if
mm_limits
[
key
]
>
0
if
mm_limits
[
key
]
>
0
}
}
# TODO: Remove once V0 is gone
def
get_max_tokens_by_modality
(
def
get_max_tokens_by_modality
(
self
,
self
,
model_config
:
"ModelConfig"
,
model_config
:
"ModelConfig"
,
...
@@ -209,14 +185,19 @@ class MultiModalRegistry:
...
@@ -209,14 +185,19 @@ class MultiModalRegistry:
Get the maximum number of tokens from each modality
Get the maximum number of tokens from each modality
for profiling the memory usage of a model.
for profiling the memory usage of a model.
"""
"""
mm_limits
=
self
.
get_mm_limits_per_prompt
(
model_config
)
cache
=
processor_only_cache_from_config
(
model_config
,
self
)
mm_limits
=
self
.
get_mm_limits_per_prompt
(
model_config
,
cache
=
cache
)
max_tokens_per_item
=
self
.
get_max_tokens_per_item_by_modality
(
model_config
,
cache
=
cache
,
)
return
{
return
{
key
:
mm_limits
[
key
]
*
max_tokens_per_mm_item
key
:
mm_limits
[
key
]
*
max_tokens_per_mm_item
for
key
,
max_tokens_per_mm_item
in
for
key
,
max_tokens_per_mm_item
in
max_tokens_per_item
.
items
()
self
.
get_max_tokens_per_item_by_modality
(
model_config
).
items
()
}
}
# TODO: Remove once V0 is gone
def
get_max_multimodal_tokens
(
self
,
model_config
:
"ModelConfig"
)
->
int
:
def
get_max_multimodal_tokens
(
self
,
model_config
:
"ModelConfig"
)
->
int
:
"""
"""
Get the maximum number of multi-modal tokens
Get the maximum number of multi-modal tokens
...
@@ -227,6 +208,8 @@ class MultiModalRegistry:
...
@@ -227,6 +208,8 @@ class MultiModalRegistry:
def
get_mm_limits_per_prompt
(
def
get_mm_limits_per_prompt
(
self
,
self
,
model_config
:
"ModelConfig"
,
model_config
:
"ModelConfig"
,
*
,
cache
:
Optional
[
BaseMultiModalProcessorCache
]
=
None
,
)
->
Mapping
[
str
,
int
]:
)
->
Mapping
[
str
,
int
]:
"""
"""
Get the maximum number of multi-modal input instances for each modality
Get the maximum number of multi-modal input instances for each modality
...
@@ -235,7 +218,7 @@ class MultiModalRegistry:
...
@@ -235,7 +218,7 @@ class MultiModalRegistry:
if
not
model_config
.
is_multimodal_model
:
if
not
model_config
.
is_multimodal_model
:
return
{}
return
{}
processor
=
self
.
create_processor
(
model_config
,
disable_
cache
=
Fals
e
)
processor
=
self
.
create_processor
(
model_config
,
cache
=
cach
e
)
profiler
=
MultiModalProfiler
(
processor
)
profiler
=
MultiModalProfiler
(
processor
)
return
profiler
.
get_mm_limits
()
return
profiler
.
get_mm_limits
()
...
@@ -303,7 +286,7 @@ class MultiModalRegistry:
...
@@ -303,7 +286,7 @@ class MultiModalRegistry:
model_config
:
"ModelConfig"
,
model_config
:
"ModelConfig"
,
*
,
*
,
tokenizer
:
Optional
[
AnyTokenizer
]
=
None
,
tokenizer
:
Optional
[
AnyTokenizer
]
=
None
,
disable_
cache
:
Optional
[
bool
]
=
None
,
cache
:
Optional
[
BaseMultiModalProcessorCache
]
=
None
,
)
->
BaseMultiModalProcessor
[
BaseProcessingInfo
]:
)
->
BaseMultiModalProcessor
[
BaseProcessingInfo
]:
"""
"""
Create a multi-modal processor for a specific model and tokenizer.
Create a multi-modal processor for a specific model and tokenizer.
...
@@ -311,15 +294,10 @@ class MultiModalRegistry:
...
@@ -311,15 +294,10 @@ class MultiModalRegistry:
if
not
model_config
.
is_multimodal_model
:
if
not
model_config
.
is_multimodal_model
:
raise
ValueError
(
f
"
{
model_config
.
model
}
is not a multimodal model"
)
raise
ValueError
(
f
"
{
model_config
.
model
}
is not a multimodal model"
)
if
disable_cache
is
None
:
disable_cache
=
not
model_config
.
enable_mm_processor_cache
model_cls
=
self
.
_get_model_cls
(
model_config
)
model_cls
=
self
.
_get_model_cls
(
model_config
)
factories
=
self
.
_processor_factories
[
model_cls
]
factories
=
self
.
_processor_factories
[
model_cls
]
ctx
=
self
.
_create_processing_ctx
(
model_config
,
tokenizer
)
ctx
=
self
.
_create_processing_ctx
(
model_config
,
tokenizer
)
cache
=
None
if
disable_cache
else
self
.
_get_processor_cache
(
model_config
)
return
factories
.
build_processor
(
ctx
,
cache
=
cache
)
return
factories
.
build_processor
(
ctx
,
cache
=
cache
)
...
@@ -328,13 +306,15 @@ class MultiModalRegistry:
...
@@ -328,13 +306,15 @@ class MultiModalRegistry:
model_config
:
"ModelConfig"
,
model_config
:
"ModelConfig"
,
seq_len
:
int
,
seq_len
:
int
,
mm_counts
:
Optional
[
Mapping
[
str
,
int
]]
=
None
,
mm_counts
:
Optional
[
Mapping
[
str
,
int
]]
=
None
,
*
,
cache
:
Optional
[
BaseMultiModalProcessorCache
]
=
None
,
)
->
DummyDecoderData
:
)
->
DummyDecoderData
:
"""
"""
Create dummy data for profiling the memory usage of a model.
Create dummy data for profiling the memory usage of a model.
The model is identified by ``model_config``.
The model is identified by ``model_config``.
"""
"""
processor
=
self
.
create_processor
(
model_config
,
disable_
cache
=
Fals
e
)
processor
=
self
.
create_processor
(
model_config
,
cache
=
cach
e
)
profiler
=
MultiModalProfiler
(
processor
)
profiler
=
MultiModalProfiler
(
processor
)
dummy_data
=
profiler
.
get_decoder_dummy_data
(
seq_len
,
mm_counts
)
dummy_data
=
profiler
.
get_decoder_dummy_data
(
seq_len
,
mm_counts
)
...
@@ -352,13 +332,15 @@ class MultiModalRegistry:
...
@@ -352,13 +332,15 @@ class MultiModalRegistry:
model_config
:
"ModelConfig"
,
model_config
:
"ModelConfig"
,
seq_len
:
int
,
seq_len
:
int
,
mm_counts
:
Optional
[
Mapping
[
str
,
int
]]
=
None
,
mm_counts
:
Optional
[
Mapping
[
str
,
int
]]
=
None
,
*
,
cache
:
Optional
[
BaseMultiModalProcessorCache
]
=
None
,
)
->
DummyEncoderData
:
)
->
DummyEncoderData
:
"""
"""
Create dummy data for profiling the memory usage of a model.
Create dummy data for profiling the memory usage of a model.
The model is identified by ``model_config``.
The model is identified by ``model_config``.
"""
"""
processor
=
self
.
create_processor
(
model_config
,
disable_
cache
=
Fals
e
)
processor
=
self
.
create_processor
(
model_config
,
cache
=
cach
e
)
profiler
=
MultiModalProfiler
(
processor
)
profiler
=
MultiModalProfiler
(
processor
)
dummy_data
=
profiler
.
get_encoder_dummy_data
(
seq_len
,
mm_counts
)
dummy_data
=
profiler
.
get_encoder_dummy_data
(
seq_len
,
mm_counts
)
...
...
vllm/v1/engine/async_llm.py
View file @
69244e67
...
@@ -597,8 +597,7 @@ class AsyncLLM(EngineClient):
...
@@ -597,8 +597,7 @@ class AsyncLLM(EngineClient):
await
asyncio
.
gather
(
*
coros
)
await
asyncio
.
gather
(
*
coros
)
async
def
reset_mm_cache
(
self
)
->
None
:
async
def
reset_mm_cache
(
self
)
->
None
:
self
.
processor
.
mm_registry
.
reset_processor_cache
(
self
.
model_config
)
self
.
processor
.
clear_cache
()
self
.
processor
.
mm_input_cache_client
.
reset
()
await
self
.
engine_core
.
reset_mm_cache_async
()
await
self
.
engine_core
.
reset_mm_cache_async
()
async
def
reset_prefix_cache
(
self
,
async
def
reset_prefix_cache
(
self
,
...
...
vllm/v1/engine/core.py
View file @
69244e67
...
@@ -22,6 +22,7 @@ from vllm.logger import init_logger
...
@@ -22,6 +22,7 @@ from vllm.logger import init_logger
from
vllm.logging_utils.dump_input
import
dump_engine_exception
from
vllm.logging_utils.dump_input
import
dump_engine_exception
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.cache
import
receiver_cache_from_config
from
vllm.tasks
import
POOLING_TASKS
,
SupportedTask
from
vllm.tasks
import
POOLING_TASKS
,
SupportedTask
from
vllm.transformers_utils.config
import
(
from
vllm.transformers_utils.config
import
(
maybe_register_config_serialize_by_value
)
maybe_register_config_serialize_by_value
)
...
@@ -38,7 +39,6 @@ from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
...
@@ -38,7 +39,6 @@ from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType
,
EngineCoreRequestType
,
ReconfigureDistributedRequest
,
ReconfigureRankType
,
ReconfigureDistributedRequest
,
ReconfigureRankType
,
UtilityOutput
,
UtilityResult
)
UtilityOutput
,
UtilityResult
)
from
vllm.v1.engine.mm_input_cache
import
MultiModalInputCacheServer
from
vllm.v1.engine.utils
import
EngineHandshakeMetadata
,
EngineZmqAddresses
from
vllm.v1.engine.utils
import
EngineHandshakeMetadata
,
EngineZmqAddresses
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
...
@@ -128,8 +128,9 @@ class EngineCore:
...
@@ -128,8 +128,9 @@ class EngineCore:
)
)
self
.
use_spec_decode
=
vllm_config
.
speculative_config
is
not
None
self
.
use_spec_decode
=
vllm_config
.
speculative_config
is
not
None
self
.
mm_input_cache_server
=
MultiModalInputCacheServer
(
self
.
mm_registry
=
mm_registry
=
MULTIMODAL_REGISTRY
vllm_config
.
model_config
,
MULTIMODAL_REGISTRY
)
self
.
mm_receiver_cache
=
receiver_cache_from_config
(
vllm_config
,
mm_registry
)
# Setup batch queue for pipeline parallelism.
# Setup batch queue for pipeline parallelism.
# Batch queue for scheduled batches. This enables us to asynchronously
# Batch queue for scheduled batches. This enables us to asynchronously
...
@@ -370,7 +371,8 @@ class EngineCore:
...
@@ -370,7 +371,8 @@ class EngineCore:
logger
.
warning
(
"Resetting the multi-modal cache when requests are "
logger
.
warning
(
"Resetting the multi-modal cache when requests are "
"in progress may lead to desynced internal caches."
)
"in progress may lead to desynced internal caches."
)
self
.
mm_input_cache_server
.
reset
()
if
self
.
mm_receiver_cache
is
not
None
:
self
.
mm_receiver_cache
.
clear_cache
()
def
reset_prefix_cache
(
self
):
def
reset_prefix_cache
(
self
):
self
.
scheduler
.
reset_prefix_cache
()
self
.
scheduler
.
reset_prefix_cache
()
...
@@ -435,10 +437,11 @@ class EngineCore:
...
@@ -435,10 +437,11 @@ class EngineCore:
assert
request
.
mm_kwargs
is
not
None
assert
request
.
mm_kwargs
is
not
None
# Note on thread safety: no race condition.
# Note on thread safety: no race condition.
# `mm_
input_cache_server
` is reset at the end of LLMEngine init,
# `mm_
receiver_cache
` is reset at the end of LLMEngine init,
# and will only accessed in the input processing thread afterwards.
# and will only accessed in the input processing thread afterwards.
request
.
mm_kwargs
=
self
.
mm_input_cache_server
.
get_and_update
(
if
self
.
mm_receiver_cache
is
not
None
:
request
.
mm_kwargs
,
request
.
mm_hashes
)
request
.
mm_kwargs
=
self
.
mm_receiver_cache
.
get_and_update
(
request
.
mm_kwargs
,
request
.
mm_hashes
)
req
=
Request
.
from_engine_core_request
(
request
,
req
=
Request
.
from_engine_core_request
(
request
,
self
.
request_block_hasher
)
self
.
request_block_hasher
)
...
...
vllm/v1/engine/llm_engine.py
View file @
69244e67
...
@@ -271,8 +271,7 @@ class LLMEngine:
...
@@ -271,8 +271,7 @@ class LLMEngine:
self
.
engine_core
.
profile
(
False
)
self
.
engine_core
.
profile
(
False
)
def
reset_mm_cache
(
self
):
def
reset_mm_cache
(
self
):
self
.
processor
.
mm_registry
.
reset_processor_cache
(
self
.
model_config
)
self
.
processor
.
clear_cache
()
self
.
processor
.
mm_input_cache_client
.
reset
()
self
.
engine_core
.
reset_mm_cache
()
self
.
engine_core
.
reset_mm_cache
()
def
reset_prefix_cache
(
self
,
device
:
Optional
[
Device
]
=
None
):
def
reset_prefix_cache
(
self
,
device
:
Optional
[
Device
]
=
None
):
...
...
vllm/v1/engine/mm_input_cache.py
deleted
100644 → 0
View file @
8dbf6ed7
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Sequence
from
typing
import
TYPE_CHECKING
,
Optional
from
vllm.multimodal
import
MultiModalRegistry
from
vllm.multimodal.cache
import
MultiModalCache
,
MultiModalCacheItemMetadata
from
vllm.multimodal.inputs
import
MultiModalKwargsItem
from
vllm.utils
import
is_list_of
if
TYPE_CHECKING
:
from
vllm.config
import
ModelConfig
# The idea of multimodal input caching is based on having a client and
# a server, where the client executes in the frontend process (=P0) and the
# server in the core process (=P1).
#
# -- P0:
# - BaseMultiModalProcessor calls MultiModalHasher to get the `mm_hash` of
# each input multi-modal item (e.g. image),
# - BaseMultiModalProcessor processes the input items into `mm_kwargs`,
# which are MultiModalKwargsItem instances that each correspond to an
# input multi-modal item.
# - MultiModalInputCacheClient accepts the `mm_kwargs` and corresponding
# `mm_hash` for each item. It stores the `mm_hash` as keys and the size
# of `mm_kwargs`, but not the `mm_kwargs` themselves, to avoid taking
# up additional memory in P0.
# - The `mm_hash` is always sent to P1.
# - The corresponding `mm_kwargs` are only sent to P1 if they are not cached
# in MultiModalInputCacheServer.
#
# -- P1:
# - If the `mm_hash` is cached (i.e. `mm_kwargs` are not sent from P0),
# MultiModalInputCacheServer retrieves the corresponding `mm_kwargs`.
# - If the `mm_hash` is not cached (i.e. `mm_kwargs` are sent from P0),
# MultiModalInputCacheServer stores `mm_kwargs` under the key `mm_hash`.
# - Either way, the `mm_hash` and corresponding `mm_kwargs` are sent to
# the engine for model execution.
#
# Both Client and Server must perform cache update and eviction based on the
# same item size. This ensures that the keys of MultiModalInputCacheClient
# and MultiModalInputCacheServer are mirrored, allowing us to determine in P0
# whether a key is cached in MultiModalInputCacheServer by querying
# MultiModalInputCacheClient without having to communicate with P1.
class
MultiModalInputCacheClient
:
"""Used by P0 to check whether multi-modal kwargs are cached in P1."""
def
__init__
(
self
,
model_config
:
"ModelConfig"
,
mm_registry
:
MultiModalRegistry
)
->
None
:
super
().
__init__
()
self
.
enabled
=
mm_registry
.
enable_mm_input_cache
(
model_config
)
self
.
mm_cache
=
MultiModalCache
.
get_lru_cache
(
model_config
.
get_mm_input_cache_gb
(),
MultiModalCacheItemMetadata
,
)
def
get_and_update
(
self
,
mm_kwargs
:
Sequence
[
MultiModalKwargsItem
],
mm_hashes
:
list
[
str
],
)
->
list
[
Optional
[
MultiModalKwargsItem
]]:
if
not
self
.
enabled
:
return
list
(
mm_kwargs
)
assert
len
(
mm_kwargs
)
==
len
(
mm_hashes
)
out_mm_items
=
list
[
Optional
[
MultiModalKwargsItem
]]()
for
mm_item
,
mm_hash
in
zip
(
mm_kwargs
,
mm_hashes
):
if
self
.
mm_cache
.
get
(
mm_hash
)
is
not
None
:
out_mm_items
.
append
(
None
)
else
:
self
.
mm_cache
[
mm_hash
]
=
\
MultiModalCacheItemMetadata
.
wraps
(
mm_item
)
out_mm_items
.
append
(
mm_item
)
return
out_mm_items
def
reset
(
self
)
->
None
:
self
.
mm_cache
.
clear
()
class
MultiModalInputCacheServer
:
"""Used by P1 to avoid requiring past multi-modal kwargs from P0."""
def
__init__
(
self
,
model_config
:
"ModelConfig"
,
mm_registry
:
MultiModalRegistry
)
->
None
:
super
().
__init__
()
self
.
enabled
=
mm_registry
.
enable_mm_input_cache
(
model_config
)
self
.
mm_cache
=
MultiModalCache
.
get_lru_cache
(
model_config
.
get_mm_input_cache_gb
(),
MultiModalKwargsItem
,
)
def
get_and_update
(
self
,
mm_kwargs
:
Sequence
[
Optional
[
MultiModalKwargsItem
]],
mm_hashes
:
list
[
str
],
)
->
list
[
MultiModalKwargsItem
]:
if
not
self
.
enabled
:
mm_kwargs_lst
=
list
(
mm_kwargs
)
assert
is_list_of
(
mm_kwargs_lst
,
MultiModalKwargsItem
)
return
mm_kwargs_lst
assert
len
(
mm_kwargs
)
==
len
(
mm_hashes
)
out_mm_items
=
list
[
MultiModalKwargsItem
]()
for
mm_item
,
mm_hash
in
zip
(
mm_kwargs
,
mm_hashes
):
if
mm_item
is
None
:
out_mm_items
.
append
(
self
.
mm_cache
[
mm_hash
])
else
:
self
.
mm_cache
[
mm_hash
]
=
mm_item
out_mm_items
.
append
(
mm_item
)
return
out_mm_items
def
reset
(
self
)
->
None
:
self
.
mm_cache
.
clear
()
vllm/v1/engine/processor.py
View file @
69244e67
...
@@ -11,6 +11,7 @@ from vllm.inputs.parse import split_enc_dec_inputs
...
@@ -11,6 +11,7 @@ from vllm.inputs.parse import split_enc_dec_inputs
from
vllm.inputs.preprocess
import
InputPreprocessor
from
vllm.inputs.preprocess
import
InputPreprocessor
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalRegistry
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalRegistry
from
vllm.multimodal.cache
import
processor_cache_from_config
from
vllm.multimodal.inputs
import
MultiModalKwargsItem
,
PlaceholderRange
from
vllm.multimodal.inputs
import
MultiModalKwargsItem
,
PlaceholderRange
from
vllm.multimodal.processing
import
EncDecMultiModalProcessor
from
vllm.multimodal.processing
import
EncDecMultiModalProcessor
from
vllm.multimodal.utils
import
argsort_mm_positions
from
vllm.multimodal.utils
import
argsort_mm_positions
...
@@ -18,7 +19,6 @@ from vllm.pooling_params import PoolingParams
...
@@ -18,7 +19,6 @@ from vllm.pooling_params import PoolingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.transformers_utils.tokenizer_group
import
TokenizerGroup
from
vllm.transformers_utils.tokenizer_group
import
TokenizerGroup
from
vllm.v1.engine
import
EngineCoreRequest
from
vllm.v1.engine
import
EngineCoreRequest
from
vllm.v1.engine.mm_input_cache
import
MultiModalInputCacheClient
from
vllm.v1.structured_output.backend_guidance
import
(
from
vllm.v1.structured_output.backend_guidance
import
(
validate_guidance_grammar
)
validate_guidance_grammar
)
from
vllm.v1.structured_output.backend_lm_format_enforcer
import
(
from
vllm.v1.structured_output.backend_lm_format_enforcer
import
(
...
@@ -47,16 +47,17 @@ class Processor:
...
@@ -47,16 +47,17 @@ class Processor:
self
.
generation_config_fields
=
(
self
.
generation_config_fields
=
(
self
.
model_config
.
try_get_generation_config
())
self
.
model_config
.
try_get_generation_config
())
self
.
input_preprocessor
=
InputPreprocessor
(
self
.
model_config
,
self
.
tokenizer
,
mm_registry
)
self
.
mm_input_cache_client
=
MultiModalInputCacheClient
(
self
.
mm_registry
=
mm_registry
self
.
model_config
,
mm_registry
)
self
.
mm_processor_cache
=
processor_cache_from_config
(
vllm_config
,
mm_registry
)
@
property
self
.
input_preprocessor
=
InputPreprocessor
(
def
mm_registry
(
self
):
self
.
model_config
,
return
self
.
input_preprocessor
.
mm_registry
self
.
tokenizer
,
mm_registry
,
mm_processor_cache
=
self
.
mm_processor_cache
,
)
def
_validate_logprobs
(
def
_validate_logprobs
(
self
,
self
,
...
@@ -310,7 +311,7 @@ class Processor:
...
@@ -310,7 +311,7 @@ class Processor:
# in the input sequence.
# in the input sequence.
sorted_mm_idxs
=
argsort_mm_positions
(
decoder_mm_positions
)
sorted_mm_idxs
=
argsort_mm_positions
(
decoder_mm_positions
)
orig_
sorted_mm_inputs
=
[
sorted_mm_inputs
=
[
decoder_mm_inputs
[
modality
][
idx
]
decoder_mm_inputs
[
modality
][
idx
]
for
modality
,
idx
in
sorted_mm_idxs
for
modality
,
idx
in
sorted_mm_idxs
]
]
...
@@ -323,11 +324,6 @@ class Processor:
...
@@ -323,11 +324,6 @@ class Processor:
for
modality
,
idx
in
sorted_mm_idxs
for
modality
,
idx
in
sorted_mm_idxs
]
]
sorted_mm_inputs
=
self
.
mm_input_cache_client
.
get_and_update
(
orig_sorted_mm_inputs
,
sorted_mm_hashes
,
)
return
decoder_inputs
.
get
(
"prompt"
),
EngineCoreRequest
(
return
decoder_inputs
.
get
(
"prompt"
),
EngineCoreRequest
(
request_id
=
request_id
,
request_id
=
request_id
,
prompt_token_ids
=
decoder_inputs
[
"prompt_token_ids"
],
prompt_token_ids
=
decoder_inputs
[
"prompt_token_ids"
],
...
@@ -415,3 +411,6 @@ class Processor:
...
@@ -415,3 +411,6 @@ class Processor:
# TODO: Find out how many placeholder tokens are there so we can
# TODO: Find out how many placeholder tokens are there so we can
# check that chunked prefill does not truncate them
# check that chunked prefill does not truncate them
# max_batch_len = self.scheduler_config.max_num_batched_tokens
# max_batch_len = self.scheduler_config.max_num_batched_tokens
def
clear_cache
(
self
)
->
None
:
self
.
input_preprocessor
.
clear_cache
()
vllm/v1/worker/gpu_model_runner.py
View file @
69244e67
...
@@ -2186,10 +2186,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -2186,10 +2186,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
max_items_per_batch
:
int
,
max_items_per_batch
:
int
,
)
->
BatchedTensorInputs
:
)
->
BatchedTensorInputs
:
"""Dummy data for profiling and precompiling multimodal models."""
"""Dummy data for profiling and precompiling multimodal models."""
assert
self
.
mm_budget
is
not
None
dummy_decoder_data
=
self
.
mm_registry
.
get_decoder_dummy_data
(
dummy_decoder_data
=
self
.
mm_registry
.
get_decoder_dummy_data
(
model_config
=
self
.
model_config
,
model_config
=
self
.
model_config
,
seq_len
=
self
.
max_num_tokens
,
seq_len
=
self
.
max_num_tokens
,
mm_counts
=
{
modality
:
1
},
mm_counts
=
{
modality
:
1
},
cache
=
self
.
mm_budget
.
cache
,
)
)
dummy_mm_data
=
dummy_decoder_data
.
multi_modal_data
dummy_mm_data
=
dummy_decoder_data
.
multi_modal_data
...
...
vllm/v1/worker/tpu_model_runner.py
View file @
69244e67
...
@@ -1813,10 +1813,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -1813,10 +1813,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
max_items_per_batch
:
int
,
max_items_per_batch
:
int
,
)
->
BatchedTensorInputs
:
)
->
BatchedTensorInputs
:
"""Dummy data for profiling and precompiling multimodal models."""
"""Dummy data for profiling and precompiling multimodal models."""
assert
self
.
mm_budget
is
not
None
dummy_decoder_data
=
self
.
mm_registry
.
get_decoder_dummy_data
(
dummy_decoder_data
=
self
.
mm_registry
.
get_decoder_dummy_data
(
model_config
=
self
.
model_config
,
model_config
=
self
.
model_config
,
seq_len
=
self
.
max_num_tokens
,
seq_len
=
self
.
max_num_tokens
,
mm_counts
=
{
modality
:
1
},
mm_counts
=
{
modality
:
1
},
cache
=
self
.
mm_budget
.
cache
,
)
)
dummy_mm_data
=
dummy_decoder_data
.
multi_modal_data
dummy_mm_data
=
dummy_decoder_data
.
multi_modal_data
...
...
vllm/v1/worker/utils.py
View file @
69244e67
...
@@ -10,6 +10,7 @@ from vllm.attention.backends.abstract import AttentionBackend
...
@@ -10,6 +10,7 @@ from vllm.attention.backends.abstract import AttentionBackend
from
vllm.config
import
ModelConfig
,
SchedulerConfig
from
vllm.config
import
ModelConfig
,
SchedulerConfig
from
vllm.model_executor.models.interfaces
import
MultiModalEmbeddings
from
vllm.model_executor.models.interfaces
import
MultiModalEmbeddings
from
vllm.model_executor.models.utils
import
extract_layer_index
from
vllm.model_executor.models.utils
import
extract_layer_index
from
vllm.multimodal.cache
import
processor_only_cache_from_config
from
vllm.multimodal.registry
import
MultiModalRegistry
from
vllm.multimodal.registry
import
MultiModalRegistry
from
vllm.v1.attention.backends.utils
import
AttentionMetadataBuilder
from
vllm.v1.attention.backends.utils
import
AttentionMetadataBuilder
from
vllm.v1.core.encoder_cache_manager
import
compute_mm_encoder_budget
from
vllm.v1.core.encoder_cache_manager
import
compute_mm_encoder_budget
...
@@ -33,14 +34,18 @@ class MultiModalBudget:
...
@@ -33,14 +34,18 @@ class MultiModalBudget:
self
.
model_config
=
model_config
self
.
model_config
=
model_config
self
.
scheduler_config
=
scheduler_config
self
.
scheduler_config
=
scheduler_config
self
.
mm_registry
=
mm_registry
self
.
mm_registry
=
mm_registry
self
.
cache
=
cache
=
processor_only_cache_from_config
(
model_config
,
mm_registry
)
self
.
max_model_len
=
model_config
.
max_model_len
self
.
max_model_len
=
model_config
.
max_model_len
self
.
max_num_reqs
=
scheduler_config
.
max_num_seqs
self
.
max_num_reqs
=
scheduler_config
.
max_num_seqs
self
.
mm_limits
=
mm_registry
.
get_mm_limits_per_prompt
(
model_config
)
self
.
mm_limits
=
mm_registry
.
get_mm_limits_per_prompt
(
model_config
,
cache
=
cache
)
max_tokens_by_modality
=
mm_registry
\
max_tokens_by_modality
=
mm_registry
\
.
get_max_tokens_per_item_by_nonzero_modality
(
model_config
)
.
get_max_tokens_per_item_by_nonzero_modality
(
model_config
,
cache
=
cache
)
encoder_compute_budget
,
encoder_cache_size
=
compute_mm_encoder_budget
(
encoder_compute_budget
,
encoder_cache_size
=
compute_mm_encoder_budget
(
scheduler_config
,
scheduler_config
,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment