Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4dff91c9
Unverified
Commit
4dff91c9
authored
Aug 16, 2025
by
Cyrus Leung
Committed by
GitHub
Aug 16, 2025
Browse files
[Refactor] Allow optional MultiModalKwargsItem in IPC (#23022)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
de9cb617
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
59 additions
and
108 deletions
+59
-108
tests/v1/core/test_kv_cache_utils.py
tests/v1/core/test_kv_cache_utils.py
+2
-10
tests/v1/core/test_prefix_caching.py
tests/v1/core/test_prefix_caching.py
+2
-10
tests/v1/core/test_scheduler.py
tests/v1/core/test_scheduler.py
+2
-10
tests/v1/core/utils.py
tests/v1/core/utils.py
+2
-10
vllm/multimodal/inputs.py
vllm/multimodal/inputs.py
+17
-45
vllm/v1/engine/__init__.py
vllm/v1/engine/__init__.py
+2
-1
vllm/v1/engine/mm_input_cache.py
vllm/v1/engine/mm_input_cache.py
+18
-15
vllm/v1/engine/processor.py
vllm/v1/engine/processor.py
+7
-3
vllm/v1/request.py
vllm/v1/request.py
+5
-2
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+2
-2
No files found.
tests/v1/core/test_kv_cache_utils.py
View file @
4dff91c9
...
...
@@ -7,9 +7,7 @@ import pytest
import
torch
from
vllm.config
import
ModelConfig
,
SchedulerConfig
,
VllmConfig
from
vllm.multimodal.inputs
import
(
MultiModalBatchedField
,
MultiModalFieldElem
,
MultiModalKwargsItem
,
PlaceholderRange
)
from
vllm.multimodal.inputs
import
MultiModalKwargsItem
,
PlaceholderRange
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils
import
GiB_bytes
,
sha256
,
sha256_cbor_64bit
from
vllm.v1.core.kv_cache_manager
import
KVCacheManager
...
...
@@ -42,13 +40,7 @@ def make_request(
if
mm_positions
is
None
:
mm_kwargs
=
None
else
:
mm_elem
=
MultiModalFieldElem
(
modality
=
"dummy_m"
,
key
=
"dummy_k"
,
data
=
None
,
field
=
MultiModalBatchedField
(),
)
mm_item
=
MultiModalKwargsItem
.
from_elems
([
mm_elem
])
mm_item
=
MultiModalKwargsItem
.
dummy
(
"dummy_m"
)
mm_kwargs
=
[
mm_item
]
*
len
(
mm_positions
)
return
Request
(
request_id
=
request_id
,
...
...
tests/v1/core/test_prefix_caching.py
View file @
4dff91c9
...
...
@@ -9,9 +9,7 @@ import pytest
import
torch
from
vllm.distributed.kv_events
import
AllBlocksCleared
,
BlockRemoved
from
vllm.multimodal.inputs
import
(
MultiModalBatchedField
,
MultiModalFieldElem
,
MultiModalKwargsItem
,
PlaceholderRange
)
from
vllm.multimodal.inputs
import
MultiModalKwargsItem
,
PlaceholderRange
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils
import
sha256
,
sha256_cbor_64bit
from
vllm.v1.core.block_pool
import
BlockPool
...
...
@@ -37,13 +35,7 @@ def make_request(
if
mm_positions
is
None
:
mm_kwargs
=
None
else
:
mm_elem
=
MultiModalFieldElem
(
modality
=
"dummy_m"
,
key
=
"dummy_k"
,
data
=
None
,
field
=
MultiModalBatchedField
(),
)
mm_item
=
MultiModalKwargsItem
.
from_elems
([
mm_elem
])
mm_item
=
MultiModalKwargsItem
.
dummy
(
"dummy_m"
)
mm_kwargs
=
[
mm_item
]
*
len
(
mm_positions
)
return
Request
(
request_id
=
request_id
,
...
...
tests/v1/core/test_scheduler.py
View file @
4dff91c9
...
...
@@ -8,9 +8,7 @@ import torch
from
vllm.config
import
(
CacheConfig
,
KVTransferConfig
,
ModelConfig
,
SchedulerConfig
,
SpeculativeConfig
,
VllmConfig
)
from
vllm.multimodal.inputs
import
(
MultiModalBatchedField
,
MultiModalFieldElem
,
MultiModalKwargsItem
,
PlaceholderRange
)
from
vllm.multimodal.inputs
import
MultiModalKwargsItem
,
PlaceholderRange
from
vllm.sampling_params
import
GuidedDecodingParams
,
SamplingParams
from
vllm.v1.core.sched.output
import
CachedRequestData
,
SchedulerOutput
from
vllm.v1.core.sched.scheduler
import
Scheduler
...
...
@@ -1328,13 +1326,7 @@ def create_requests_with_priority(
for
i
in
range
(
num_requests
):
if
mm_positions
is
not
None
:
mm_position
=
mm_positions
[
i
]
mm_elem
=
MultiModalFieldElem
(
modality
=
"dummy_m"
,
key
=
"dummy_k"
,
data
=
None
,
field
=
MultiModalBatchedField
(),
)
mm_item
=
MultiModalKwargsItem
.
from_elems
([
mm_elem
])
mm_item
=
MultiModalKwargsItem
.
dummy
(
"dummy_m"
)
mm_kwargs
=
[
mm_item
]
*
len
(
mm_position
)
else
:
mm_position
=
None
...
...
tests/v1/core/utils.py
View file @
4dff91c9
...
...
@@ -6,9 +6,7 @@ import torch
from
vllm.config
import
(
CacheConfig
,
KVTransferConfig
,
ModelConfig
,
SchedulerConfig
,
SpeculativeConfig
,
VllmConfig
)
from
vllm.multimodal.inputs
import
(
MultiModalBatchedField
,
MultiModalFieldElem
,
MultiModalKwargsItem
,
PlaceholderRange
)
from
vllm.multimodal.inputs
import
MultiModalKwargsItem
,
PlaceholderRange
from
vllm.sampling_params
import
SamplingParams
from
vllm.v1.core.kv_cache_utils
import
(
get_request_block_hasher
,
init_none_hash
)
...
...
@@ -143,13 +141,7 @@ def create_requests(
for
i
in
range
(
num_requests
):
if
mm_positions
is
not
None
:
mm_position
=
mm_positions
[
i
]
mm_elem
=
MultiModalFieldElem
(
modality
=
"dummy_m"
,
key
=
"dummy_k"
,
data
=
None
,
field
=
MultiModalBatchedField
(),
)
mm_item
=
MultiModalKwargsItem
.
from_elems
([
mm_elem
])
mm_item
=
MultiModalKwargsItem
.
dummy
(
"dummy_m"
)
mm_kwargs
=
[
mm_item
]
*
len
(
mm_position
)
mm_hashes
=
[
"hash"
]
*
len
(
mm_position
)
else
:
...
...
vllm/multimodal/inputs.py
View file @
4dff91c9
...
...
@@ -4,7 +4,7 @@
from
abc
import
ABC
,
abstractmethod
from
collections
import
UserDict
,
defaultdict
from
collections.abc
import
Mapping
,
Sequence
from
dataclasses
import
dataclass
,
replace
from
dataclasses
import
dataclass
from
functools
import
partial
from
itertools
import
accumulate
from
typing
import
(
TYPE_CHECKING
,
Any
,
Literal
,
Optional
,
TypedDict
,
TypeVar
,
...
...
@@ -218,7 +218,7 @@ class MultiModalFieldElem:
i.e. the name of the keyword argument to be passed to the model.
"""
data
:
Optional
[
NestedTensors
]
data
:
NestedTensors
"""
The tensor data of this field in
[`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs],
...
...
@@ -315,13 +315,8 @@ class BaseMultiModalField(ABC):
if
len
(
set
(
field_types
))
>
1
:
raise
ValueError
(
f
"Cannot merge different
{
field_types
=
}
"
)
validated_data
=
list
[
NestedTensors
]()
for
i
,
elem
in
enumerate
(
elems
):
assert
elem
.
data
is
not
None
,
(
f
"Cannot merge with empty `elems[
{
i
}
]`"
)
validated_data
.
append
(
elem
.
data
)
return
self
.
_reduce_data
(
validated_data
,
pin_memory
=
pin_memory
)
batch
=
[
elem
.
data
for
elem
in
elems
]
return
self
.
_reduce_data
(
batch
,
pin_memory
=
pin_memory
)
@
dataclass
(
frozen
=
True
)
...
...
@@ -643,6 +638,17 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
[`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].
"""
@
staticmethod
def
dummy
(
modality
:
str
):
"""Convenience class for testing."""
mm_elem
=
MultiModalFieldElem
(
modality
=
modality
,
key
=
"dummy"
,
data
=
torch
.
empty
(
1
),
field
=
MultiModalSharedField
(
1
),
)
return
MultiModalKwargsItem
.
from_elems
([
mm_elem
])
@
staticmethod
def
from_elems
(
elems
:
Sequence
[
MultiModalFieldElem
]):
return
MultiModalKwargsItem
({
elem
.
key
:
elem
for
elem
in
elems
})
...
...
@@ -654,46 +660,12 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
assert
len
(
modalities
)
==
1
,
f
"Found different modalities=
{
modalities
}
"
self
.
_modality
=
next
(
iter
(
modalities
))
self
.
_is_empty
=
any
(
elem
.
data
is
None
for
elem
in
self
.
values
())
@
property
def
modality
(
self
)
->
str
:
return
self
.
_modality
@
property
def
is_empty
(
self
)
->
bool
:
return
self
.
_is_empty
def
get_data
(
self
)
->
Optional
[
Mapping
[
str
,
NestedTensors
]]:
if
self
.
_is_empty
:
return
None
out_data
=
dict
[
str
,
NestedTensors
]()
for
key
,
elem
in
self
.
items
():
assert
elem
.
data
is
not
None
,
(
f
"Cannot get data of empty `elem[
{
key
!
r
}
]`"
)
out_data
[
key
]
=
elem
.
data
return
out_data
def
require_data
(
self
)
->
Mapping
[
str
,
NestedTensors
]:
if
(
data
:
=
self
.
get_data
())
is
None
:
raise
RuntimeError
(
"Cannot get data of empty item"
)
return
data
# These methods create a new item to avoid mutating cached items in place
def
with_data
(
self
,
data
:
Mapping
[
str
,
NestedTensors
]):
return
MultiModalKwargsItem
({
key
:
replace
(
elem
,
data
=
data
[
key
])
for
key
,
elem
in
self
.
items
()
})
def
without_data
(
self
):
return
MultiModalKwargsItem
({
key
:
replace
(
elem
,
data
=
None
)
for
key
,
elem
in
self
.
items
()
})
def
get_data
(
self
)
->
Mapping
[
str
,
NestedTensors
]:
return
{
key
:
elem
.
data
for
key
,
elem
in
self
.
items
()}
# NOTE: UserDict is for V0 compatibility.
...
...
vllm/v1/engine/__init__.py
View file @
4dff91c9
...
...
@@ -3,6 +3,7 @@
import
enum
import
time
from
collections.abc
import
Sequence
from
typing
import
Any
,
Optional
,
Union
import
msgspec
...
...
@@ -47,7 +48,7 @@ class EngineCoreRequest(
request_id
:
str
prompt_token_ids
:
list
[
int
]
mm_kwargs
:
Optional
[
list
[
MultiModalKwargsItem
]]
mm_kwargs
:
Optional
[
Sequence
[
Optional
[
MultiModalKwargsItem
]]
]
mm_hashes
:
Optional
[
list
[
str
]]
mm_placeholders
:
Optional
[
list
[
PlaceholderRange
]]
sampling_params
:
Optional
[
SamplingParams
]
...
...
vllm/v1/engine/mm_input_cache.py
View file @
4dff91c9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Mapping
from
typing
import
TYPE_CHECKING
from
collections.abc
import
Sequence
from
typing
import
TYPE_CHECKING
,
Optional
from
vllm.multimodal
import
MultiModalRegistry
from
vllm.multimodal.cache
import
MultiModalCache
,
MultiModalCacheItemMetadata
from
vllm.multimodal.inputs
import
MultiModalKwargsItem
,
NestedTensors
from
vllm.multimodal.inputs
import
MultiModalKwargsItem
from
vllm.utils
import
is_list_of
if
TYPE_CHECKING
:
from
vllm.config
import
ModelConfig
...
...
@@ -58,21 +59,21 @@ class MultiModalInputCacheClient:
def
get_and_update
(
self
,
mm_kwargs
:
list
[
MultiModalKwargsItem
],
mm_kwargs
:
Sequence
[
MultiModalKwargsItem
],
mm_hashes
:
list
[
str
],
)
->
list
[
MultiModalKwargsItem
]:
)
->
list
[
Optional
[
MultiModalKwargsItem
]
]
:
if
not
self
.
enabled
:
return
mm_kwargs
return
list
(
mm_kwargs
)
assert
len
(
mm_kwargs
)
==
len
(
mm_hashes
)
out_mm_items
=
list
[
MultiModalKwargsItem
]()
out_mm_items
=
list
[
Optional
[
MultiModalKwargsItem
]
]
()
for
mm_item
,
mm_hash
in
zip
(
mm_kwargs
,
mm_hashes
):
if
self
.
mm_cache
.
get
(
mm_hash
)
is
not
None
:
out_mm_items
.
append
(
mm_item
.
without_data
()
)
out_mm_items
.
append
(
None
)
else
:
self
.
mm_cache
[
mm_hash
]
=
\
MultiModalCacheItemMetadata
.
wraps
(
mm_item
.
require_data
()
)
MultiModalCacheItemMetadata
.
wraps
(
mm_item
)
out_mm_items
.
append
(
mm_item
)
return
out_mm_items
...
...
@@ -91,25 +92,27 @@ class MultiModalInputCacheServer:
self
.
enabled
=
mm_registry
.
enable_mm_input_cache
(
model_config
)
self
.
mm_cache
=
MultiModalCache
.
get_lru_cache
(
model_config
.
get_mm_input_cache_gb
(),
M
apping
[
str
,
NestedTensors
]
,
M
ultiModalKwargsItem
,
)
def
get_and_update
(
self
,
mm_kwargs
:
list
[
MultiModalKwargsItem
],
mm_kwargs
:
Sequence
[
Optional
[
MultiModalKwargsItem
]
]
,
mm_hashes
:
list
[
str
],
)
->
list
[
MultiModalKwargsItem
]:
if
not
self
.
enabled
:
return
mm_kwargs
mm_kwargs_lst
=
list
(
mm_kwargs
)
assert
is_list_of
(
mm_kwargs_lst
,
MultiModalKwargsItem
)
return
mm_kwargs_lst
assert
len
(
mm_kwargs
)
==
len
(
mm_hashes
)
out_mm_items
=
list
[
MultiModalKwargsItem
]()
for
mm_item
,
mm_hash
in
zip
(
mm_kwargs
,
mm_hashes
):
if
(
mm_data
:
=
mm_item
.
get_data
())
is
None
:
out_mm_items
.
append
(
mm_item
.
with_data
(
self
.
mm_cache
[
mm_hash
])
)
if
mm_item
is
None
:
out_mm_items
.
append
(
self
.
mm_cache
[
mm_hash
])
else
:
self
.
mm_cache
[
mm_hash
]
=
mm_
data
self
.
mm_cache
[
mm_hash
]
=
mm_
item
out_mm_items
.
append
(
mm_item
)
return
out_mm_items
...
...
vllm/v1/engine/processor.py
View file @
4dff91c9
...
...
@@ -17,6 +17,7 @@ from vllm.multimodal.utils import argsort_mm_positions
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.transformers_utils.tokenizer_group
import
TokenizerGroup
from
vllm.utils
import
is_list_of
from
vllm.v1.engine
import
EngineCoreRequest
from
vllm.v1.engine.mm_input_cache
import
MultiModalInputCacheClient
from
vllm.v1.structured_output.backend_guidance
import
(
...
...
@@ -295,7 +296,7 @@ class Processor:
pooling_params
=
params
.
clone
()
# Multimodal related.
sorted_mm_inputs
:
Optional
[
list
[
MultiModalKwargsItem
]]
=
None
sorted_mm_inputs
:
Optional
[
list
[
Optional
[
MultiModalKwargsItem
]]
]
=
None
sorted_mm_positions
:
Optional
[
list
[
PlaceholderRange
]]
=
None
sorted_mm_hashes
:
Optional
[
list
[
str
]]
=
None
if
decoder_inputs
[
"type"
]
==
"multimodal"
:
...
...
@@ -308,7 +309,7 @@ class Processor:
# in the input sequence.
sorted_mm_idxs
=
argsort_mm_positions
(
decoder_mm_positions
)
sorted_mm_inputs
=
[
orig_
sorted_mm_inputs
=
[
decoder_mm_inputs
.
get_item
(
modality
,
idx
)
for
modality
,
idx
in
sorted_mm_idxs
]
...
...
@@ -323,9 +324,12 @@ class Processor:
if
sorted_mm_hashes
is
not
None
:
sorted_mm_inputs
=
self
.
mm_input_cache_client
.
get_and_update
(
sorted_mm_inputs
,
orig_
sorted_mm_inputs
,
sorted_mm_hashes
,
)
else
:
assert
is_list_of
(
orig_sorted_mm_inputs
,
MultiModalKwargsItem
)
sorted_mm_inputs
=
orig_sorted_mm_inputs
return
decoder_inputs
.
get
(
"prompt"
),
EngineCoreRequest
(
request_id
=
request_id
,
...
...
vllm/v1/request.py
View file @
4dff91c9
...
...
@@ -125,14 +125,17 @@ class Request:
block_hasher
:
Optional
[
Callable
[[
"Request"
],
list
[
"BlockHash"
]]]
)
->
"Request"
:
if
request
.
mm_kwargs
is
not
None
:
assert
is_list_of
(
request
.
mm_kwargs
,
MultiModalKwargsItem
),
(
mm_kwargs_lst
=
list
(
request
.
mm_kwargs
)
assert
is_list_of
(
mm_kwargs_lst
,
MultiModalKwargsItem
),
(
"mm_kwargs was not updated in EngineCore.add_request"
)
else
:
mm_kwargs_lst
=
None
return
cls
(
request_id
=
request
.
request_id
,
client_index
=
request
.
client_index
,
prompt_token_ids
=
request
.
prompt_token_ids
,
multi_modal_kwargs
=
request
.
mm_kwargs
,
multi_modal_kwargs
=
mm_kwargs
_lst
,
multi_modal_hashes
=
request
.
mm_hashes
,
multi_modal_placeholders
=
request
.
mm_placeholders
,
sampling_params
=
request
.
sampling_params
,
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
4dff91c9
...
...
@@ -500,8 +500,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
second_per_grid_ts
=
[]
audio_feature_lengths
=
[]
use_audio_in_video
=
False
for
item
in
self
.
requests
[
req_id
].
mm_kwargs
:
mm_input
=
item
.
require
_data
()
for
mm_
item
in
self
.
requests
[
req_id
].
mm_kwargs
:
mm_input
=
mm_
item
.
get
_data
()
if
mm_input
.
get
(
"image_grid_thw"
)
is
not
None
:
image_grid_thw
.
append
(
mm_input
[
"image_grid_thw"
].
tolist
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment