Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c6e7404c
Unverified
Commit
c6e7404c
authored
Jan 29, 2026
by
Cyrus Leung
Committed by
GitHub
Jan 29, 2026
Browse files
[Multimodal] Simplify MM input definitions (#33331)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
17b17c06
Changes
17
Show whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
142 additions
and
164 deletions
+142
-164
tests/distributed/test_shm_storage.py
tests/distributed/test_shm_storage.py
+5
-7
tests/models/multimodal/processing/test_tensor_schema.py
tests/models/multimodal/processing/test_tensor_schema.py
+5
-1
tests/multimodal/test_cache.py
tests/multimodal/test_cache.py
+15
-21
tests/v1/core/test_kv_cache_utils.py
tests/v1/core/test_kv_cache_utils.py
+1
-1
tests/v1/core/test_prefix_caching.py
tests/v1/core/test_prefix_caching.py
+1
-1
tests/v1/core/test_priority_scheduler_random.py
tests/v1/core/test_priority_scheduler_random.py
+1
-1
tests/v1/core/test_scheduler.py
tests/v1/core/test_scheduler.py
+1
-1
tests/v1/core/utils.py
tests/v1/core/utils.py
+1
-1
tests/v1/streaming_input/test_gpu_model_runner_streaming.py
tests/v1/streaming_input/test_gpu_model_runner_streaming.py
+2
-2
tests/v1/streaming_input/test_scheduler_streaming.py
tests/v1/streaming_input/test_scheduler_streaming.py
+2
-2
tests/v1/test_serial_utils.py
tests/v1/test_serial_utils.py
+9
-14
vllm/multimodal/cache.py
vllm/multimodal/cache.py
+14
-17
vllm/multimodal/inputs.py
vllm/multimodal/inputs.py
+65
-72
vllm/multimodal/utils.py
vllm/multimodal/utils.py
+4
-4
vllm/v1/serial_utils.py
vllm/v1/serial_utils.py
+5
-7
vllm/v1/worker/gpu/mm/encoder_runner.py
vllm/v1/worker/gpu/mm/encoder_runner.py
+5
-4
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+6
-8
No files found.
tests/distributed/test_shm_storage.py
View file @
c6e7404c
...
...
@@ -23,18 +23,16 @@ from vllm.multimodal.inputs import (
)
def
_dummy_elem
(
modality
:
str
,
key
:
str
,
size
:
int
):
def
_dummy_elem
(
size
:
int
):
return
MultiModalFieldElem
(
modality
=
modality
,
key
=
key
,
data
=
torch
.
empty
((
size
,),
dtype
=
torch
.
int8
),
field
=
MultiModalSharedField
(
batch_size
=
1
),
)
def
_dummy_item
(
modality
:
str
,
size_by_key
:
dict
[
str
,
int
]):
return
MultiModalKwargsItem
.
from_elems
(
[
_dummy_elem
(
modality
,
key
,
size
)
for
key
,
size
in
size_by_key
.
items
()
]
def
_dummy_item
(
size_by_key
:
dict
[
str
,
int
]):
return
MultiModalKwargsItem
(
{
key
:
_dummy_elem
(
size
)
for
key
,
size
in
size_by_key
.
items
()
}
)
...
...
@@ -61,7 +59,7 @@ class TestSingleWriterShmObjectStorage(unittest.TestCase):
def
test_minimal_put_get_cycle
(
self
):
"""Test basic put and get operations."""
key
=
"test_key"
value
=
_dummy_item
(
"text"
,
{
"field1"
:
10
,
"field2"
:
20
})
value
=
_dummy_item
({
"field1"
:
10
,
"field2"
:
20
})
# Put operation
address
,
monotonic_id
=
self
.
storage
.
put
(
key
,
value
)
...
...
tests/models/multimodal/processing/test_tensor_schema.py
View file @
c6e7404c
...
...
@@ -119,7 +119,11 @@ def create_batched_mm_kwargs(
)[
"mm_kwargs"
].
require_data
()
return
group_mm_kwargs_by_modality
(
[
item
for
modality
in
supported_mm_limits
for
item
in
mm_kwargs
[
modality
]]
[
(
modality
,
item
)
for
modality
in
supported_mm_limits
for
item
in
mm_kwargs
[
modality
]
]
)
...
...
tests/multimodal/test_cache.py
View file @
c6e7404c
...
...
@@ -36,8 +36,6 @@ pytestmark = pytest.mark.cpu_test
def
_dummy_elem
(
modality
:
str
,
key
:
str
,
size
:
int
,
*
,
rng
:
np
.
random
.
RandomState
|
None
=
None
,
...
...
@@ -48,21 +46,18 @@ def _dummy_elem(
data
=
torch
.
from_numpy
(
rng
.
randint
(
4
,
size
=
(
size
,),
dtype
=
np
.
int8
))
return
MultiModalFieldElem
(
modality
=
modality
,
key
=
key
,
data
=
data
,
field
=
MultiModalSharedField
(
batch_size
=
1
),
)
def
_dummy_item
(
modality
:
str
,
size_by_key
:
dict
[
str
,
int
],
*
,
rng
:
np
.
random
.
RandomState
|
None
=
None
,
):
return
MultiModalKwargsItem
.
from_elems
(
[
_dummy_elem
(
modality
,
key
,
size
,
rng
=
rng
)
for
key
,
size
in
size_by_key
.
items
()
]
return
MultiModalKwargsItem
(
{
key
:
_dummy_elem
(
size
,
rng
=
rng
)
for
key
,
size
in
size_by_key
.
items
()
}
)
...
...
@@ -71,19 +66,19 @@ def _dummy_items(
*
,
rng
:
np
.
random
.
RandomState
|
None
=
None
,
):
return
MultiModalKwargsItems
.
from_seq
(
[
_dummy_item
(
modality
,
size_by_key
,
rng
=
rng
)
return
MultiModalKwargsItems
(
{
modality
:
[
_dummy_item
(
size_by_key
,
rng
=
rng
)
]
for
modality
,
size_by_key
in
size_by_key_modality
.
items
()
]
}
)
@
pytest
.
mark
.
parametrize
(
(
"item"
,
"expected_size"
),
[
(
_dummy_item
(
"a"
,
{
"a1"
:
100
}),
100
),
(
_dummy_item
(
"a"
,
{
"a1"
:
100
,
"a2"
:
110
}),
210
),
(
_dummy_item
({
"a1"
:
100
}),
100
),
(
_dummy_item
({
"a1"
:
100
,
"a2"
:
110
}),
210
),
(
_dummy_items
({
"a"
:
{
"a1"
:
100
,
"a2"
:
110
},
"b"
:
{
"b1"
:
120
,
"b2"
:
130
}}),
460
),
# noqa: E501
],
)
...
...
@@ -143,7 +138,7 @@ def _compare_caches(
rng
=
np
.
random
.
RandomState
(
seed
)
all_items
=
[
_dummy_item
(
"item"
,
{
"key"
:
item_size_gb
},
rng
=
rng
)
_dummy_item
({
"key"
:
item_size_gb
},
rng
=
rng
)
for
_
in
range
(
int
(
item_capacity
/
hit_rate
))
]
all_hashes
=
[
...
...
@@ -245,13 +240,13 @@ def _run_test_cache_eviction_lru(
"image_C"
,
]
request1_items
=
{
h
:
MultiModalKwargsItem
.
dummy
(
h
,
nbytes
=
2
*
base_item_size
)
h
:
MultiModalKwargsItem
.
dummy
(
nbytes
=
2
*
base_item_size
)
for
h
in
request1_hashes
}
request2_hashes
=
[
"image_D"
,
"image_E"
,
"image_A"
,
"image_C"
]
request2_items
=
{
h
:
MultiModalKwargsItem
.
dummy
(
h
,
nbytes
=
1
*
base_item_size
)
h
:
MultiModalKwargsItem
.
dummy
(
nbytes
=
1
*
base_item_size
)
for
h
in
request2_hashes
}
...
...
@@ -356,15 +351,14 @@ def _run_test_cache_eviction_shm(
):
request1_hashes
=
[
"image_A"
,
"image_B"
,
"image_C"
]
request1_items
=
{
h
:
MultiModalKwargsItem
.
dummy
(
h
,
nbytes
=
5
*
base_item_size
)
for
h
in
request1_hashes
h
:
MultiModalKwargsItem
.
dummy
(
5
*
base_item_size
)
for
h
in
request1_hashes
}
request1_items_p0_result
=
[]
request2_hashes
=
[
"image_G"
,
"image_A"
]
request2_items
=
{
h
:
MultiModalKwargsItem
.
dummy
(
h
,
nbytes
=
(
5
if
h
in
request1_hashes
else
2
)
*
base_item_size
(
5
if
h
in
request1_hashes
else
2
)
*
base_item_size
)
for
h
in
request2_hashes
}
...
...
@@ -373,7 +367,7 @@ def _run_test_cache_eviction_shm(
request3_hashes
=
[
"image_G"
,
"image_H"
,
"image_I"
,
"image_B"
]
request3_items
=
{
h
:
MultiModalKwargsItem
.
dummy
(
h
,
nbytes
=
(
5
if
h
in
request1_hashes
else
2
)
*
base_item_size
(
5
if
h
in
request1_hashes
else
2
)
*
base_item_size
)
for
h
in
request3_hashes
}
...
...
@@ -532,7 +526,7 @@ def test_processor_cache_shared_across_loras():
lora_a_identifier
=
f
"12345:
{
base_mm_hash
}
"
lora_b_identifier
=
f
"67890:
{
base_mm_hash
}
"
item_data
=
MultiModalKwargsItem
.
dummy
(
"test_image"
,
nbytes
=
1024
)
item_data
=
MultiModalKwargsItem
.
dummy
(
1024
)
feature_lora_a
=
MultiModalFeatureSpec
(
data
=
item_data
,
...
...
tests/v1/core/test_kv_cache_utils.py
View file @
c6e7404c
...
...
@@ -77,7 +77,7 @@ def make_request(
for
j
,
position
in
enumerate
(
mm_positions
):
identifier
=
mm_hashes
[
j
]
if
mm_hashes
else
f
"hash_
{
j
}
"
mm_feature
=
MultiModalFeatureSpec
(
data
=
MultiModalKwargsItem
.
dummy
(
"dummy_m"
),
data
=
MultiModalKwargsItem
.
dummy
(),
mm_position
=
position
,
identifier
=
identifier
,
modality
=
"image"
,
...
...
tests/v1/core/test_prefix_caching.py
View file @
c6e7404c
...
...
@@ -68,7 +68,7 @@ def make_request(
for
j
,
position
in
enumerate
(
mm_positions
):
identifier
=
mm_hashes
[
j
]
if
mm_hashes
else
f
"hash_
{
j
}
"
mm_feature
=
MultiModalFeatureSpec
(
data
=
MultiModalKwargsItem
.
dummy
(
"dummy_m"
),
data
=
MultiModalKwargsItem
.
dummy
(),
mm_position
=
position
,
identifier
=
identifier
,
modality
=
"image"
,
...
...
tests/v1/core/test_priority_scheduler_random.py
View file @
c6e7404c
...
...
@@ -56,7 +56,7 @@ def _create_random_request(
for
j
,
position
in
enumerate
(
mm_positions
):
identifier
=
f
"
{
request_id
}
_hash_
{
j
}
"
mm_feature
=
MultiModalFeatureSpec
(
data
=
MultiModalKwargsItem
.
dummy
(
"dummy_m"
),
data
=
MultiModalKwargsItem
.
dummy
(),
mm_position
=
position
,
identifier
=
identifier
,
modality
=
"image"
,
...
...
tests/v1/core/test_scheduler.py
View file @
c6e7404c
...
...
@@ -1707,7 +1707,7 @@ def create_requests_with_priority(
# Unique dummy hash for each mm item
identifier
=
f
"hash
{
i
}
_
{
j
}
"
mm_feature
=
MultiModalFeatureSpec
(
data
=
MultiModalKwargsItem
.
dummy
(
"dummy_m"
),
data
=
MultiModalKwargsItem
.
dummy
(),
mm_position
=
position
,
identifier
=
identifier
,
modality
=
"image"
,
...
...
tests/v1/core/utils.py
View file @
c6e7404c
...
...
@@ -236,7 +236,7 @@ def create_requests(
# Unique dummy hash for each mm item
identifier
=
f
"hash
{
i
}
_
{
j
}
"
mm_feature
=
MultiModalFeatureSpec
(
data
=
MultiModalKwargsItem
.
dummy
(
"dummy_m"
),
data
=
MultiModalKwargsItem
.
dummy
(),
mm_position
=
position
,
identifier
=
identifier
,
modality
=
"image"
,
...
...
tests/v1/streaming_input/test_gpu_model_runner_streaming.py
View file @
c6e7404c
...
...
@@ -131,7 +131,7 @@ def test_e2e_streaming_with_multimodal_features(mock_model_runner_with_input_bat
# Step 1: Create initial request state with one multimodal feature
mm_feature_1
=
MultiModalFeatureSpec
(
data
=
MultiModalKwargsItem
.
dummy
(
"audio"
),
data
=
MultiModalKwargsItem
.
dummy
(),
modality
=
"audio"
,
identifier
=
"audio_1"
,
mm_position
=
PlaceholderRange
(
offset
=
2
,
length
=
10
),
...
...
@@ -158,7 +158,7 @@ def test_e2e_streaming_with_multimodal_features(mock_model_runner_with_input_bat
# The scheduler has already set prompt_token_ids to the full sequence
# (original prompt + intermediate outputs + new prompt with new multimodal feature)
mm_feature_2
=
MultiModalFeatureSpec
(
data
=
MultiModalKwargsItem
.
dummy
(
"audio"
),
data
=
MultiModalKwargsItem
.
dummy
(),
modality
=
"audio"
,
identifier
=
"audio_2"
,
mm_position
=
PlaceholderRange
(
offset
=
15
,
length
=
5
),
...
...
tests/v1/streaming_input/test_scheduler_streaming.py
View file @
c6e7404c
...
...
@@ -174,7 +174,7 @@ class TestStreamingScheduler(unittest.TestCase):
scheduler
=
create_scheduler
()
mm_feature
=
MultiModalFeatureSpec
(
data
=
MultiModalKwargsItem
.
dummy
(
"audio"
),
data
=
MultiModalKwargsItem
.
dummy
(),
modality
=
"audio"
,
identifier
=
""
,
mm_position
=
PlaceholderRange
(
offset
=
1
,
length
=
1
),
...
...
@@ -187,7 +187,7 @@ class TestStreamingScheduler(unittest.TestCase):
session
.
num_computed_tokens
=
len
(
session
.
prompt_token_ids
)
mm_feature
=
MultiModalFeatureSpec
(
data
=
MultiModalKwargsItem
.
dummy
(
"audio"
),
data
=
MultiModalKwargsItem
.
dummy
(),
modality
=
"audio"
,
identifier
=
""
,
mm_position
=
PlaceholderRange
(
offset
=
2
,
length
=
1
),
...
...
tests/v1/test_serial_utils.py
View file @
c6e7404c
...
...
@@ -104,14 +104,10 @@ class MyRequest(msgspec.Struct):
def
test_multimodal_kwargs
():
e1
=
MultiModalFieldElem
(
"audio"
,
"a0"
,
torch
.
zeros
(
1000
,
dtype
=
torch
.
bfloat16
),
MultiModalBatchedField
(),
)
e2
=
MultiModalFieldElem
(
"video"
,
"v0"
,
[
torch
.
zeros
(
1000
,
dtype
=
torch
.
int8
)
for
_
in
range
(
4
)],
MultiModalFlatField
(
slices
=
[[
slice
(
1
,
2
,
3
),
slice
(
4
,
5
,
6
)],
[
slice
(
None
,
2
)]],
...
...
@@ -119,21 +115,20 @@ def test_multimodal_kwargs():
),
)
e3
=
MultiModalFieldElem
(
"image"
,
"i0"
,
torch
.
zeros
(
1000
,
dtype
=
torch
.
int32
),
MultiModalSharedField
(
batch_size
=
4
),
)
e4
=
MultiModalFieldElem
(
"image"
,
"i1"
,
torch
.
zeros
(
1000
,
dtype
=
torch
.
int32
),
MultiModalFlatField
(
slices
=
[
slice
(
1
,
2
,
3
),
slice
(
4
,
5
,
6
)],
dim
=
2
),
)
audio
=
MultiModalKwargsItem
.
from_elems
([
e1
])
video
=
MultiModalKwargsItem
.
from_elems
([
e2
])
image
=
MultiModalKwargsItem
.
from_elems
([
e3
,
e4
])
mm
=
MultiModalKwargsItems
.
from_seq
([
audio
,
video
,
image
])
mm
=
MultiModalKwargsItems
(
{
"audio"
:
[
MultiModalKwargsItem
({
"a0"
:
e1
})],
"video"
:
[
MultiModalKwargsItem
({
"v0"
:
e2
})],
"image"
:
[
MultiModalKwargsItem
({
"i0"
:
e3
,
"i1"
:
e4
})],
}
)
# pack mm kwargs into a mock request so that it can be decoded properly
req
=
MyRequest
([
mm
])
...
...
@@ -147,8 +142,8 @@ def test_multimodal_kwargs():
total_len
=
sum
(
memoryview
(
x
).
cast
(
"B"
).
nbytes
for
x
in
encoded
)
# expected total encoding length, should be 14395, +-20 for minor changes
assert
143
75
<=
total_len
<=
14
425
# expected total encoding length, should be 14319, +-20 for minor changes
assert
143
00
<=
total_len
<=
14
340
decoded
=
decoder
.
decode
(
encoded
).
mm
[
0
]
assert
isinstance
(
decoded
,
MultiModalKwargsItems
)
...
...
vllm/multimodal/cache.py
View file @
c6e7404c
...
...
@@ -463,8 +463,8 @@ class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
ring_buffer
=
ring_buffer
,
serde_class
=
MsgpackSerde
,
)
# cache (prompt_updates, modality) for P0 only
self
.
_p0_cache
:
dict
[
str
,
tuple
[
Sequence
[
ResolvedPromptUpdate
]
,
str
]
]
=
{}
# cache prompt_updates for P0 only
self
.
_p0_cache
:
dict
[
str
,
Sequence
[
ResolvedPromptUpdate
]]
=
{}
self
.
_hits
=
0
self
.
_total
=
0
...
...
@@ -495,23 +495,22 @@ class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
self
.
_total
+=
1
address
,
monotonic_id
=
self
.
_shm_cache
.
get_cached
(
mm_hash
)
prompt_updates
,
modality
=
self
.
_p0_cache
[
mm_hash
]
return
self
.
address_as_item
(
address
,
monotonic_id
,
modality
),
prompt_updates
prompt_updates
=
self
.
_p0_cache
[
mm_hash
]
return
self
.
address_as_item
(
address
,
monotonic_id
),
prompt_updates
assert
mm_item
is
not
None
,
f
"Expected a cached item for
{
mm_hash
=
}
"
item
,
prompt_updates
=
mm_item
self
.
_total
+=
1
try
:
address
,
monotonic_id
=
self
.
_shm_cache
.
put
(
mm_hash
,
mm_
item
[
0
]
)
address
,
monotonic_id
=
self
.
_shm_cache
.
put
(
mm_hash
,
item
)
# Try to remove dangling items if p0 cache is too large.
if
len
(
self
.
_p0_cache
)
>=
2
*
len
(
self
.
_shm_cache
.
key_index
):
self
.
remove_dangling_items
()
self
.
_p0_cache
[
mm_hash
]
=
mm_item
[
1
],
mm_item
[
0
].
modality
address_item
=
self
.
address_as_item
(
address
,
monotonic_id
,
mm_item
[
0
].
modality
)
return
address_item
,
mm_item
[
1
]
self
.
_p0_cache
[
mm_hash
]
=
prompt_updates
return
self
.
address_as_item
(
address
,
monotonic_id
),
prompt_updates
except
(
ValueError
,
MemoryError
)
as
e
:
# put may fail if the object is too large or
# the cache is full.
...
...
@@ -550,22 +549,20 @@ class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
del
self
.
_p0_cache
[
mm_hash
]
def
address_as_item
(
self
,
address
:
int
,
monotonic_id
:
int
,
modality
:
str
self
,
address
:
int
,
monotonic_id
:
int
,
)
->
MultiModalKwargsItem
:
addr_elem
=
MultiModalFieldElem
(
modality
=
modality
,
key
=
"address"
,
data
=
address
,
field
=
MultiModalBatchedField
(),
)
id_elem
=
MultiModalFieldElem
(
modality
=
modality
,
key
=
"monotonic_id"
,
data
=
monotonic_id
,
field
=
MultiModalBatchedField
(),
)
mm_item
=
MultiModalKwargsItem
.
from_elems
([
addr_elem
,
id_elem
])
return
mm_item
return
MultiModalKwargsItem
({
"address"
:
addr_elem
,
"monotonic_id"
:
id_elem
})
class
BaseMultiModalReceiverCache
(
...
...
vllm/multimodal/inputs.py
View file @
c6e7404c
...
...
@@ -23,7 +23,7 @@ import numpy as np
from
PIL.Image
import
Image
from
typing_extensions
import
NotRequired
,
TypeVar
from
vllm.utils.collection_utils
import
full_groupby
,
is_list_of
from
vllm.utils.collection_utils
import
is_list_of
from
vllm.utils.import_utils
import
LazyLoader
from
vllm.utils.jsontree
import
json_map_leaves
...
...
@@ -336,25 +336,33 @@ class MultiModalFeatureSpec:
"""
Represents a single multimodal input with its processed data and metadata.
Used
by the V1 engine
to track multimodal data through processing and
caching.
A request containing multiple multimodal items will have one
MultiModalFeatureSpec per item.
Used to track multimodal data through processing and
caching.
A request containing multiple multimodal items will have one
`
MultiModalFeatureSpec
`
per item.
"""
data
:
Optional
[
"MultiModalKwargsItem"
]
"""Multimodal data for this feature"""
"""
Represents multimodal data for this feature.
Can be `None` if the item is cached, to skip IPC between API server
and engine core processes.
"""
modality
:
str
"""
Based on the input
, e.g., "image", "audio", "video"."""
"""
The input modality
, e.g.,
`
"image"
`
,
`
"audio"
`
,
`
"video"
`
."""
identifier
:
str
"""
mm_hash or uuid
for caching encoder outputs."""
"""
The hash
for caching encoder outputs
(with LoRA prefix if applicable)
."""
mm_position
:
PlaceholderRange
"""e.g., PlaceholderRange(offset=2, length=336)"""
"""
The location of the `modality` tokens corresponding to this item
in the prompt, e.g., `PlaceholderRange(offset=2, length=336)`.
"""
mm_hash
:
str
|
None
=
None
"""
Base mm_
hash for processor
cache
(without LoRA prefix)."""
"""
The
hash for
caching
processor
outputs
(without LoRA prefix)."""
@
staticmethod
def
gather_kwargs
(
features
:
list
[
"MultiModalFeatureSpec"
],
keys
:
set
[
str
]):
...
...
@@ -373,23 +381,10 @@ class MultiModalFeatureSpec:
@
dataclass
class
MultiModalFieldElem
:
"""
Represents a keyword argument
inside
a
Represents a
processed
keyword argument
to pass to a model for
a
[`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem].
"""
modality
:
str
"""
The modality of the corresponding multi-modal item.
Each multi-modal item can consist of multiple keyword arguments.
"""
key
:
str
"""
The key of this field in
[`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem],
i.e. the name of the keyword argument to be passed to the model.
"""
data
:
NestedTensors
"""
The tensor data of this field in
...
...
@@ -417,11 +412,7 @@ class MultiModalFieldElem:
else
:
data_equal
=
nested_tensors_equal
(
self
.
data
,
other
.
data
)
return
(
(
self
.
modality
,
self
.
key
)
==
(
other
.
modality
,
other
.
key
)
and
data_equal
and
type
(
self
.
field
)
is
type
(
other
.
field
)
)
# noqa: E721
return
data_equal
and
type
(
self
.
field
)
is
type
(
other
.
field
)
# noqa: E721
@
dataclass
(
frozen
=
True
,
kw_only
=
True
)
...
...
@@ -438,13 +429,8 @@ class BaseMultiModalField(ABC):
when `MultiModalKwargsItems.get_data()` is called to batch the data.
"""
def
_field_factory
(
self
,
*
,
modality
:
str
,
key
:
str
):
f
=
partial
(
MultiModalFieldElem
,
modality
=
modality
,
key
=
key
,
field
=
self
,
)
def
_field_factory
(
self
):
f
=
partial
(
MultiModalFieldElem
,
field
=
self
)
# Allow passing data as positional argument
def
factory
(
data
:
NestedTensors
)
->
MultiModalFieldElem
:
...
...
@@ -519,7 +505,7 @@ class MultiModalBatchedField(BaseMultiModalField):
key
:
str
,
data
:
NestedTensors
,
)
->
Sequence
[
MultiModalFieldElem
]:
field_factory
=
self
.
_field_factory
(
modality
=
modality
,
key
=
key
)
field_factory
=
self
.
_field_factory
()
return
[
field_factory
(
item
)
for
item
in
data
]
def
_reduce_data
(
...
...
@@ -565,7 +551,7 @@ class MultiModalFlatField(BaseMultiModalField):
key
:
str
,
data
:
NestedTensors
,
)
->
Sequence
[
MultiModalFieldElem
]:
field_factory
=
self
.
_field_factory
(
modality
=
modality
,
key
=
key
)
field_factory
=
self
.
_field_factory
()
if
not
is_list_of
(
self
.
slices
,
slice
,
check
=
"all"
):
assert
isinstance
(
data
,
torch
.
Tensor
),
(
"torch.Tensor is required for multiple slices"
...
...
@@ -664,7 +650,7 @@ class MultiModalSharedField(BaseMultiModalField):
key
:
str
,
data
:
NestedTensors
,
)
->
Sequence
[
MultiModalFieldElem
]:
field_factory
=
self
.
_field_factory
(
modality
=
modality
,
key
=
key
)
field_factory
=
self
.
_field_factory
()
return
[
field_factory
(
data
)]
*
self
.
batch_size
def
_reduce_data
(
...
...
@@ -899,37 +885,19 @@ class MultiModalFieldConfig:
class
MultiModalKwargsItem
(
UserDict
[
str
,
MultiModalFieldElem
]):
"""
A collection of
[`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
corresponding to a data item in
A dictionary of processed keyword arguments to pass to the model,
corresponding to a single item in
[`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].
"""
@
staticmethod
def
dummy
(
modality
:
str
,
nbytes
:
int
=
1
):
def
dummy
(
nbytes
:
int
=
1
):
"""Convenience class for testing."""
mm_elem
=
MultiModalFieldElem
(
modality
=
modality
,
key
=
"dummy"
,
data
=
torch
.
empty
(
nbytes
,
dtype
=
torch
.
uint8
),
field
=
MultiModalSharedField
(
batch_size
=
1
),
)
return
MultiModalKwargsItem
.
from_elems
([
mm_elem
])
@
staticmethod
def
from_elems
(
elems
:
Sequence
[
MultiModalFieldElem
]):
return
MultiModalKwargsItem
({
elem
.
key
:
elem
for
elem
in
elems
})
def
__init__
(
self
,
data
:
Mapping
[
str
,
MultiModalFieldElem
]
=
{})
->
None
:
super
().
__init__
(
data
)
modalities
=
{
elem
.
modality
for
elem
in
self
.
values
()}
assert
len
(
modalities
)
==
1
,
f
"Found different modalities=
{
modalities
}
"
self
.
_modality
=
next
(
iter
(
modalities
))
@
property
def
modality
(
self
)
->
str
:
return
self
.
_modality
return
MultiModalKwargsItem
({
"dummy"
:
mm_elem
})
def
get_data
(
self
)
->
dict
[
str
,
NestedTensors
]:
return
{
key
:
elem
.
data
for
key
,
elem
in
self
.
items
()}
...
...
@@ -945,9 +913,38 @@ _I = TypeVar(
class
MultiModalKwargsItems
(
UserDict
[
str
,
Sequence
[
_I
]]):
"""
A dictionary of
[`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem]s
by modality.
A dictionary of processed multi-modal inputs by modality.
For example, given a processor that processes
images into `pixel_values` and `image_grid_thw`,
and audios into `input_audio_features`,
a prompt with 2 images and 1 audio will be processed
into a `MultiModalKwargsItems` with the following structure:
```python
MultiModalKwargsItems(
{
"image": [
# For the first image
MultiModalKwargsItem({"pixel_values": ..., "image_grid_thw": ...}),
# For the second image
MultiModalKwargsItem({"pixel_values": ..., "image_grid_thw": ...}),
],
"audio": [
# For the first audio
MultiModalKwargsItem({"input_audio_features": ...}),
],
}
)
```
Unlike HF processing which returns all items
in a single dictionary with batched keyword arguments,
we split up the items because some of them may already be cached.
Also, items from multiple requests may be batched together to improve throughput,
using the logic defined by the
[`BaseMultiModalField`][vllm.multimodal.inputs.BaseMultiModalField]
for each keyword argument.
"""
@
staticmethod
...
...
@@ -967,7 +964,7 @@ class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
elems_by_key
[
key
]
=
elems
keys_by_modality
[
config
.
modality
].
add
(
key
)
items
=
list
[
MultiModalKwargsItem
]()
items
_by_modality
=
dict
[
str
,
list
[
MultiModalKwargsItem
]
]
()
for
modality
,
keys
in
keys_by_modality
.
items
():
elems_in_modality
=
{
k
:
elems_by_key
[
k
]
for
k
in
keys
}
batch_sizes
=
{
k
:
len
(
v
)
for
k
,
v
in
elems_in_modality
.
items
()}
...
...
@@ -979,15 +976,11 @@ class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
)
batch_size
=
next
(
iter
(
batch_sizes
.
values
()))
for
item_idx
in
range
(
batch_size
):
elems
=
[
v
[
item_idx
]
for
v
in
elems_in_modality
.
values
()]
items
.
append
(
MultiModalKwargsItem
.
from_elems
(
elems
))
return
MultiModalKwargsItems
.
from_seq
(
items
)
items_by_modality
[
modality
]
=
[
MultiModalKwargsItem
({
k
:
v
[
i
]
for
k
,
v
in
elems_in_modality
.
items
()})
for
i
in
range
(
batch_size
)
]
@
staticmethod
def
from_seq
(
items
:
Sequence
[
MultiModalKwargsItem
]):
items_by_modality
=
full_groupby
(
items
,
key
=
lambda
x
:
x
.
modality
)
return
MultiModalKwargsItems
(
items_by_modality
)
def
__getitem__
(
self
,
modality
:
str
)
->
Sequence
[
_I
]:
...
...
vllm/multimodal/utils.py
View file @
c6e7404c
...
...
@@ -467,7 +467,7 @@ def argsort_mm_positions(
def
group_mm_kwargs_by_modality
(
mm_kwargs
:
list
[
MultiModalKwargsItem
],
mm_kwargs
:
list
[
tuple
[
str
,
MultiModalKwargsItem
]
]
,
*
,
device
:
torch
.
types
.
Device
=
None
,
pin_memory
:
bool
=
False
,
...
...
@@ -485,9 +485,9 @@ def group_mm_kwargs_by_modality(
"""
from
vllm.multimodal.inputs
import
MultiModalKwargsItems
for
modality
,
items
in
groupby
(
mm_kwargs
,
key
=
lambda
item
:
item
.
modality
):
items_lst
=
list
(
items
)
mm_kwargs_items
=
MultiModalKwargsItems
.
from_seq
(
items_lst
)
for
modality
,
group
in
groupby
(
mm_kwargs
,
key
=
lambda
x
:
x
[
0
]
):
items_lst
=
[
item
for
_
,
item
in
group
]
mm_kwargs_items
=
MultiModalKwargsItems
({
modality
:
items_lst
}
)
mm_kwargs_data
=
mm_kwargs_items
.
get_data
(
device
=
device
,
pin_memory
=
pin_memory
,
...
...
vllm/v1/serial_utils.py
View file @
c6e7404c
...
...
@@ -242,13 +242,11 @@ class MsgpackEncoder:
for
modality
,
itemlist
in
items
.
items
()
}
def
_encode_mm_item
(
self
,
item
:
MultiModalKwargsItem
)
->
list
[
dict
[
str
,
Any
]
]
:
return
[
self
.
_encode_mm_field_elem
(
elem
)
for
elem
in
item
.
value
s
()
]
def
_encode_mm_item
(
self
,
item
:
MultiModalKwargsItem
)
->
dict
[
str
,
Any
]:
return
{
key
:
self
.
_encode_mm_field_elem
(
elem
)
for
key
,
elem
in
item
.
item
s
()
}
def
_encode_mm_field_elem
(
self
,
elem
:
MultiModalFieldElem
)
->
dict
[
str
,
Any
]:
return
{
"modality"
:
elem
.
modality
,
"key"
:
elem
.
key
,
"data"
:
(
None
if
elem
.
data
is
None
else
self
.
_encode_nested_tensors
(
elem
.
data
)
),
...
...
@@ -383,9 +381,9 @@ class MsgpackDecoder:
}
)
def
_decode_mm_item
(
self
,
obj
:
list
[
Any
])
->
MultiModalKwargsItem
:
return
MultiModalKwargsItem
.
from_elems
(
[
self
.
_decode_mm_field_elem
(
v
)
for
v
in
obj
]
def
_decode_mm_item
(
self
,
obj
:
dict
[
str
,
Any
])
->
MultiModalKwargsItem
:
return
MultiModalKwargsItem
(
{
key
:
self
.
_decode_mm_field_elem
(
elem
)
for
key
,
elem
in
obj
.
items
()}
)
def
_decode_mm_field_elem
(
self
,
obj
:
dict
[
str
,
Any
])
->
MultiModalFieldElem
:
...
...
vllm/v1/worker/gpu/mm/encoder_runner.py
View file @
c6e7404c
...
...
@@ -43,9 +43,9 @@ class EncoderRunner:
def
prepare_mm_inputs
(
self
,
scheduled_encoder_inputs
:
dict
[
str
,
list
[
int
]],
)
->
tuple
[
list
[
str
],
list
[
MultiModalKwargsItem
]]:
)
->
tuple
[
list
[
str
],
list
[
tuple
[
str
,
MultiModalKwargsItem
]]
]
:
mm_hashes
:
list
[
str
]
=
[]
mm_kwargs
:
list
[
MultiModalKwargsItem
]
=
[]
mm_kwargs
:
list
[
tuple
[
str
,
MultiModalKwargsItem
]
]
=
[]
for
req_id
,
encoder_input_ids
in
scheduled_encoder_inputs
.
items
():
mm_features
=
self
.
req_id_to_mm_features
[
req_id
]
for
mm_input_id
in
encoder_input_ids
:
...
...
@@ -53,7 +53,8 @@ class EncoderRunner:
if
mm_feature
.
data
is
None
:
continue
mm_hashes
.
append
(
mm_feature
.
identifier
)
mm_kwargs
.
append
(
mm_feature
.
data
)
mm_kwargs
.
append
((
mm_feature
.
modality
,
mm_feature
.
data
))
return
mm_hashes
,
mm_kwargs
@
torch
.
inference_mode
()
...
...
@@ -61,7 +62,7 @@ class EncoderRunner:
self
,
model
:
SupportsMultiModal
,
mm_hashes
:
list
[
str
],
mm_kwargs
:
list
[
MultiModalKwargsItem
],
mm_kwargs
:
list
[
tuple
[
str
,
MultiModalKwargsItem
]
]
,
)
->
list
[
torch
.
Tensor
]:
if
not
mm_hashes
:
return
[]
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
c6e7404c
...
...
@@ -1217,11 +1217,11 @@ class GPUModelRunner(
if
not
scheduler_output
or
not
self
.
is_multimodal_raw_input_only_model
:
return
{}
mm_kwargs
=
list
[
MultiModalKwargsItem
]()
mm_kwargs
=
list
[
tuple
[
str
,
MultiModalKwargsItem
]
]
()
for
req
in
scheduler_output
.
scheduled_new_reqs
:
for
feature
in
req
.
mm_features
:
if
feature
.
data
is
not
None
:
mm_kwargs
.
append
(
feature
.
data
)
mm_kwargs
.
append
(
(
feature
.
modality
,
feature
.
data
)
)
# Input all modalities at once
mm_kwargs_combined
:
BatchedTensorInputs
=
{}
...
...
@@ -2219,7 +2219,7 @@ class GPUModelRunner(
scheduler_output
:
"SchedulerOutput"
,
)
->
tuple
[
list
[
str
],
list
[
MultiModalKwargsItem
],
list
[
tuple
[
str
,
MultiModalKwargsItem
]
]
,
list
[
tuple
[
str
,
PlaceholderRange
]],
]:
"""Batch multimodal inputs from scheduled encoder inputs.
...
...
@@ -2239,7 +2239,7 @@ class GPUModelRunner(
return
[],
[],
[]
mm_hashes
=
list
[
str
]()
mm_kwargs
=
list
[
MultiModalKwargsItem
]()
mm_kwargs
=
list
[
tuple
[
str
,
MultiModalKwargsItem
]
]
()
# Multimodal LoRA reference info to map each multimodal item
# back to its request & position
mm_lora_refs
=
list
[
tuple
[
str
,
PlaceholderRange
]]()
...
...
@@ -2252,7 +2252,7 @@ class GPUModelRunner(
continue
mm_hashes
.
append
(
mm_feature
.
identifier
)
mm_kwargs
.
append
(
mm_feature
.
data
)
mm_kwargs
.
append
(
(
mm_feature
.
modality
,
mm_feature
.
data
)
)
mm_lora_refs
.
append
((
req_id
,
mm_feature
.
mm_position
))
return
mm_hashes
,
mm_kwargs
,
mm_lora_refs
...
...
@@ -4475,12 +4475,10 @@ class GPUModelRunner(
# but not read from the cache
assert
dummy_mm_item
is
not
None
,
"Item should not already be cached"
dummy_mm_items
=
[
dummy_mm_item
]
*
max_items_per_batch
return
next
(
mm_kwargs_group
for
_
,
_
,
mm_kwargs_group
in
group_mm_kwargs_by_modality
(
dummy_mm_item
s
,
[(
modality
,
dummy_mm_item
)]
*
max_items_per_batch
,
device
=
self
.
device
,
pin_memory
=
self
.
pin_memory
,
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment