Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fdcf64d3
"docs/vscode:/vscode.git/clone" did not exist on "8bddb735123204872788a8ffe117321de7550e6c"
Unverified
Commit
fdcf64d3
authored
Feb 13, 2025
by
Roger Wang
Committed by
GitHub
Feb 13, 2025
Browse files
[V1] Clarify input processing and multimodal feature caching logic (#13211)
parent
578087e5
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
45 additions
and
27 deletions
+45
-27
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+8
-8
vllm/v1/engine/mm_input_cache.py
vllm/v1/engine/mm_input_cache.py
+19
-10
vllm/v1/engine/processor.py
vllm/v1/engine/processor.py
+14
-6
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+4
-3
No files found.
vllm/v1/engine/core.py
View file @
fdcf64d3
...
@@ -20,7 +20,7 @@ from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
...
@@ -20,7 +20,7 @@ from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
from
vllm.v1.core.scheduler
import
Scheduler
from
vllm.v1.core.scheduler
import
Scheduler
from
vllm.v1.engine
import
(
EngineCoreOutputs
,
EngineCoreRequest
,
from
vllm.v1.engine
import
(
EngineCoreOutputs
,
EngineCoreRequest
,
EngineCoreRequestType
)
EngineCoreRequestType
)
from
vllm.v1.engine.mm_input_
mapper
import
MMInput
Mapper
Server
from
vllm.v1.engine.mm_input_
cache
import
MMInput
Cache
Server
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.executor.abstract
import
Executor
from
vllm.v1.request
import
Request
,
RequestStatus
from
vllm.v1.request
import
Request
,
RequestStatus
from
vllm.v1.serial_utils
import
MsgpackDecoder
,
MsgpackEncoder
from
vllm.v1.serial_utils
import
MsgpackDecoder
,
MsgpackEncoder
...
@@ -65,7 +65,7 @@ class EngineCore:
...
@@ -65,7 +65,7 @@ class EngineCore:
log_stats
=
self
.
log_stats
,
log_stats
=
self
.
log_stats
,
)
)
self
.
mm_input_
mapper
_server
=
MMInput
Mapper
Server
(
self
.
mm_input_
cache
_server
=
MMInput
Cache
Server
(
vllm_config
.
model_config
)
vllm_config
.
model_config
)
def
_initialize_kv_caches
(
self
,
def
_initialize_kv_caches
(
self
,
...
@@ -102,13 +102,13 @@ class EngineCore:
...
@@ -102,13 +102,13 @@ class EngineCore:
"""Add request to the scheduler."""
"""Add request to the scheduler."""
if
request
.
mm_hashes
is
not
None
:
if
request
.
mm_hashes
is
not
None
:
# Here, if hash exists for a
n image
, then it will be
fetched
# Here, if hash exists for a
multimodal input
, then it will be
# from the cache, else it will be added to the cache.
#
fetched
from the cache, else it will be added to the cache.
# Note that the cache here is mirrored with the client
side of the
# Note that the cache here is mirrored with the client
cache, so
#
MM mapper, so
anything that has a hash must have a HIT cache
# anything that has a hash must have a HIT cache
entry here
#
entry here
as well.
# as well.
assert
request
.
mm_inputs
is
not
None
assert
request
.
mm_inputs
is
not
None
request
.
mm_inputs
=
self
.
mm_input_
mapper
_server
.
process_inputs
(
request
.
mm_inputs
=
self
.
mm_input_
cache
_server
.
get_and_update
(
request
.
mm_inputs
,
request
.
mm_hashes
)
request
.
mm_inputs
,
request
.
mm_hashes
)
req
=
Request
.
from_engine_core_request
(
request
)
req
=
Request
.
from_engine_core_request
(
request
)
...
...
vllm/v1/engine/mm_input_
mapper
.py
→
vllm/v1/engine/mm_input_
cache
.py
View file @
fdcf64d3
...
@@ -10,12 +10,18 @@ from vllm.utils import LRUCache
...
@@ -10,12 +10,18 @@ from vllm.utils import LRUCache
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
# The idea of
MM
preprocess
or
caching is based on having a client and
a server,
# The idea of
multimodal
preprocess
ing
caching is based on having a client and
# where the client executes in the frontend process (=P0) and the
server in the
#
a server,
where the client executes in the frontend process (=P0) and the
# core process (=P1).
#
server in the
core process (=P1).
#
#
# -- Client: Executes the MM mapper and performs caching of the results.
# -- Client:
# -- Server: Performs caching of the results
# - Apply legacy input_mapper (if one exists) to generate MultiModalKwargs.
# - Perform caching of the generated MultiModalKwargs.
# - This client can be deprecated once all multimodal models migrate to use
# merged preprocessor with built-in caching functionality.
#
# -- Server:
# - Perform caching of the received MultiModalKwargs.
#
#
# The caching for both client and server is mirrored/similar, and this allows us
# The caching for both client and server is mirrored/similar, and this allows us
# to avoid the serialization of "mm_inputs" (like pixel values) between
# to avoid the serialization of "mm_inputs" (like pixel values) between
...
@@ -27,7 +33,9 @@ logger = init_logger(__name__)
...
@@ -27,7 +33,9 @@ logger = init_logger(__name__)
MM_CACHE_SIZE
=
256
MM_CACHE_SIZE
=
256
class
MMInputMapperClient
:
# TODO(ywang96): Deprecate this class once all multimodal models migrate to use
# merged preprocessor with built-in caching functionality.
class
MMInputCacheClient
:
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -54,7 +62,8 @@ class MMInputMapperClient:
...
@@ -54,7 +62,8 @@ class MMInputMapperClient:
logger
.
debug
(
"MMInputMapper: cache_hit_ratio = %.2f "
,
logger
.
debug
(
"MMInputMapper: cache_hit_ratio = %.2f "
,
self
.
mm_cache_hits
/
self
.
mm_cache_total
)
self
.
mm_cache_hits
/
self
.
mm_cache_total
)
# TODO: Support modalities beyond image.
# NOTE: process_inputs only supports image inputs since all multimodal
# models with other modalities have migrated to use merged preprocessor.
def
process_inputs
(
def
process_inputs
(
self
,
self
,
mm_data
:
MultiModalDataDict
,
mm_data
:
MultiModalDataDict
,
...
@@ -95,7 +104,7 @@ class MMInputMapperClient:
...
@@ -95,7 +104,7 @@ class MMInputMapperClient:
# Reuse precomputed input (for merged preprocessor)
# Reuse precomputed input (for merged preprocessor)
mm_input
=
precomputed_mm_inputs
[
input_id
]
mm_input
=
precomputed_mm_inputs
[
input_id
]
else
:
else
:
# Apply
MM
mapper
# Apply
legacy input_
mapper
mm_input
=
self
.
multi_modal_input_mapper
(
mm_input
=
self
.
multi_modal_input_mapper
(
{
"image"
:
[
image_inputs
[
input_id
]]},
{
"image"
:
[
image_inputs
[
input_id
]]},
mm_processor_kwargs
=
mm_processor_kwargs
,
mm_processor_kwargs
=
mm_processor_kwargs
,
...
@@ -114,13 +123,13 @@ class MMInputMapperClient:
...
@@ -114,13 +123,13 @@ class MMInputMapperClient:
return
ret_inputs
return
ret_inputs
class
MMInput
Mapper
Server
:
class
MMInput
Cache
Server
:
def
__init__
(
self
,
model_config
):
def
__init__
(
self
,
model_config
):
self
.
use_cache
=
not
model_config
.
disable_mm_preprocessor_cache
self
.
use_cache
=
not
model_config
.
disable_mm_preprocessor_cache
self
.
mm_cache
=
LRUCache
[
str
,
MultiModalKwargs
](
MM_CACHE_SIZE
)
self
.
mm_cache
=
LRUCache
[
str
,
MultiModalKwargs
](
MM_CACHE_SIZE
)
def
process_inputs
(
def
get_and_update
(
self
,
self
,
mm_inputs
:
List
[
Optional
[
MultiModalKwargs
]],
mm_inputs
:
List
[
Optional
[
MultiModalKwargs
]],
mm_hashes
:
List
[
str
],
mm_hashes
:
List
[
str
],
...
...
vllm/v1/engine/processor.py
View file @
fdcf64d3
...
@@ -17,7 +17,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
...
@@ -17,7 +17,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.transformers_utils.tokenizer_group
import
BaseTokenizerGroup
from
vllm.transformers_utils.tokenizer_group
import
BaseTokenizerGroup
from
vllm.v1.engine
import
EngineCoreRequest
from
vllm.v1.engine
import
EngineCoreRequest
from
vllm.v1.engine.mm_input_
mapper
import
MMInput
Mapper
Client
from
vllm.v1.engine.mm_input_
cache
import
MMInput
Cache
Client
class
Processor
:
class
Processor
:
...
@@ -46,7 +46,7 @@ class Processor:
...
@@ -46,7 +46,7 @@ class Processor:
model_config
)
model_config
)
# Multi-modal (huggingface) input mapper
# Multi-modal (huggingface) input mapper
self
.
mm_input_
mapper
_client
=
MMInput
Mapper
Client
(
model_config
)
self
.
mm_input_
cache
_client
=
MMInput
Cache
Client
(
model_config
)
# Multi-modal hasher (for images)
# Multi-modal hasher (for images)
self
.
use_hash
=
(
not
model_config
.
disable_mm_preprocessor_cache
)
or
\
self
.
use_hash
=
(
not
model_config
.
disable_mm_preprocessor_cache
)
or
\
...
@@ -106,16 +106,24 @@ class Processor:
...
@@ -106,16 +106,24 @@ class Processor:
assert
priority
==
0
,
"vLLM V1 does not support priority at the moment."
assert
priority
==
0
,
"vLLM V1 does not support priority at the moment."
assert
trace_headers
is
None
,
"vLLM V1 does not support tracing yet."
assert
trace_headers
is
None
,
"vLLM V1 does not support tracing yet."
# Process inputs.
# Process inputs, which includes:
# 1. Tokenize text prompt, with LoRA request if one exists.
# 2. For multimodal models with a merged preprocessor, preprocess
# multimodal data and expand prompt token ids accordingly.
# 3. Apply prompt adapter to prompt token ids if one exists.
preprocessed_inputs
=
self
.
input_preprocessor
.
preprocess
(
preprocessed_inputs
=
self
.
input_preprocessor
.
preprocess
(
prompt
,
prompt
,
request_id
=
request_id
,
request_id
=
request_id
,
lora_request
=
lora_request
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
prompt_adapter_request
=
prompt_adapter_request
,
)
)
eos_token_id
=
self
.
input_preprocessor
.
get_eos_token_id
(
lora_request
)
# Process prompt and prompt token ids.
# Only applicable to multimodal models with legacy input processor.
processed_inputs
=
self
.
input_processor
(
preprocessed_inputs
)
processed_inputs
=
self
.
input_processor
(
preprocessed_inputs
)
self
.
_validate_model_inputs
(
processed_inputs
)
self
.
_validate_model_inputs
(
processed_inputs
)
eos_token_id
=
self
.
input_preprocessor
.
get_eos_token_id
(
lora_request
)
if
is_encoder_decoder_inputs
(
processed_inputs
):
if
is_encoder_decoder_inputs
(
processed_inputs
):
decoder_inputs
=
SingletonInputsAdapter
(
decoder_inputs
=
SingletonInputsAdapter
(
...
@@ -200,8 +208,8 @@ class Processor:
...
@@ -200,8 +208,8 @@ class Processor:
key
=
lambda
mm_input
:
modality_order_dict
[
list
(
key
=
lambda
mm_input
:
modality_order_dict
[
list
(
mm_input
.
modalities
)[
0
]])
mm_input
.
modalities
)[
0
]])
# Apply mm input cache update
(
and input mapper if ne
cessary)
.
# Apply mm input cache update and
legacy
input mapper if
o
ne
exists
.
sorted_mm_inputs
=
self
.
mm_input_
mapper
_client
.
process_inputs
(
sorted_mm_inputs
=
self
.
mm_input_
cache
_client
.
process_inputs
(
mm_data
=
decoder_mm_data
,
mm_data
=
decoder_mm_data
,
mm_hashes
=
sorted_mm_hashes
,
mm_hashes
=
sorted_mm_hashes
,
mm_processor_kwargs
=
decoder_inputs
.
mm_processor_kwargs
,
mm_processor_kwargs
=
decoder_inputs
.
mm_processor_kwargs
,
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
fdcf64d3
...
@@ -27,7 +27,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
...
@@ -27,7 +27,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
from
vllm.v1.attention.backends.flash_attn
import
(
FlashAttentionBackend
,
from
vllm.v1.attention.backends.flash_attn
import
(
FlashAttentionBackend
,
FlashAttentionMetadata
)
FlashAttentionMetadata
)
from
vllm.v1.core.encoder_cache_manager
import
compute_encoder_budget
from
vllm.v1.core.encoder_cache_manager
import
compute_encoder_budget
from
vllm.v1.engine.mm_input_
mapper
import
MMInput
Mapper
Client
from
vllm.v1.engine.mm_input_
cache
import
MMInput
Cache
Client
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
KVCacheSpec
)
KVCacheSpec
)
from
vllm.v1.outputs
import
LogprobsTensors
,
ModelRunnerOutput
from
vllm.v1.outputs
import
LogprobsTensors
,
ModelRunnerOutput
...
@@ -95,9 +95,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -95,9 +95,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
mm_registry
=
MULTIMODAL_REGISTRY
self
.
mm_registry
=
MULTIMODAL_REGISTRY
self
.
uses_mrope
=
model_config
.
uses_mrope
self
.
uses_mrope
=
model_config
.
uses_mrope
# NOTE: Initialized
input mapper
is only used for processing dummy
# NOTE: Initialized
client
is only used for processing dummy
# multimodal data into multimodal kwargs for GPU memory profiling.
# multimodal data into multimodal kwargs for GPU memory profiling.
self
.
mm_input_mapper_profiling
=
MMInputMapperClient
(
self
.
model_config
)
# Only applicable to multimodal models with legacy input mapper.
self
.
mm_input_mapper_profiling
=
MMInputCacheClient
(
self
.
model_config
)
self
.
mm_input_mapper_profiling
.
use_cache
=
False
self
.
mm_input_mapper_profiling
.
use_cache
=
False
encoder_compute_budget
,
encoder_cache_size
=
compute_encoder_budget
(
encoder_compute_budget
,
encoder_cache_size
=
compute_encoder_budget
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment