Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c6e7404c
Unverified
Commit
c6e7404c
authored
Jan 29, 2026
by
Cyrus Leung
Committed by
GitHub
Jan 29, 2026
Browse files
[Multimodal] Simplify MM input definitions (#33331)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
17b17c06
Changes
17
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
142 additions
and
164 deletions
+142
-164
tests/distributed/test_shm_storage.py
tests/distributed/test_shm_storage.py
+5
-7
tests/models/multimodal/processing/test_tensor_schema.py
tests/models/multimodal/processing/test_tensor_schema.py
+5
-1
tests/multimodal/test_cache.py
tests/multimodal/test_cache.py
+15
-21
tests/v1/core/test_kv_cache_utils.py
tests/v1/core/test_kv_cache_utils.py
+1
-1
tests/v1/core/test_prefix_caching.py
tests/v1/core/test_prefix_caching.py
+1
-1
tests/v1/core/test_priority_scheduler_random.py
tests/v1/core/test_priority_scheduler_random.py
+1
-1
tests/v1/core/test_scheduler.py
tests/v1/core/test_scheduler.py
+1
-1
tests/v1/core/utils.py
tests/v1/core/utils.py
+1
-1
tests/v1/streaming_input/test_gpu_model_runner_streaming.py
tests/v1/streaming_input/test_gpu_model_runner_streaming.py
+2
-2
tests/v1/streaming_input/test_scheduler_streaming.py
tests/v1/streaming_input/test_scheduler_streaming.py
+2
-2
tests/v1/test_serial_utils.py
tests/v1/test_serial_utils.py
+9
-14
vllm/multimodal/cache.py
vllm/multimodal/cache.py
+14
-17
vllm/multimodal/inputs.py
vllm/multimodal/inputs.py
+65
-72
vllm/multimodal/utils.py
vllm/multimodal/utils.py
+4
-4
vllm/v1/serial_utils.py
vllm/v1/serial_utils.py
+5
-7
vllm/v1/worker/gpu/mm/encoder_runner.py
vllm/v1/worker/gpu/mm/encoder_runner.py
+5
-4
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+6
-8
No files found.
tests/distributed/test_shm_storage.py
View file @
c6e7404c
...
@@ -23,18 +23,16 @@ from vllm.multimodal.inputs import (
...
@@ -23,18 +23,16 @@ from vllm.multimodal.inputs import (
)
)
def
_dummy_elem
(
modality
:
str
,
key
:
str
,
size
:
int
):
def
_dummy_elem
(
size
:
int
):
return
MultiModalFieldElem
(
return
MultiModalFieldElem
(
modality
=
modality
,
key
=
key
,
data
=
torch
.
empty
((
size
,),
dtype
=
torch
.
int8
),
data
=
torch
.
empty
((
size
,),
dtype
=
torch
.
int8
),
field
=
MultiModalSharedField
(
batch_size
=
1
),
field
=
MultiModalSharedField
(
batch_size
=
1
),
)
)
def
_dummy_item
(
modality
:
str
,
size_by_key
:
dict
[
str
,
int
]):
def
_dummy_item
(
size_by_key
:
dict
[
str
,
int
]):
return
MultiModalKwargsItem
.
from_elems
(
return
MultiModalKwargsItem
(
[
_dummy_elem
(
modality
,
key
,
size
)
for
key
,
size
in
size_by_key
.
items
()
]
{
key
:
_dummy_elem
(
size
)
for
key
,
size
in
size_by_key
.
items
()
}
)
)
...
@@ -61,7 +59,7 @@ class TestSingleWriterShmObjectStorage(unittest.TestCase):
...
@@ -61,7 +59,7 @@ class TestSingleWriterShmObjectStorage(unittest.TestCase):
def
test_minimal_put_get_cycle
(
self
):
def
test_minimal_put_get_cycle
(
self
):
"""Test basic put and get operations."""
"""Test basic put and get operations."""
key
=
"test_key"
key
=
"test_key"
value
=
_dummy_item
(
"text"
,
{
"field1"
:
10
,
"field2"
:
20
})
value
=
_dummy_item
({
"field1"
:
10
,
"field2"
:
20
})
# Put operation
# Put operation
address
,
monotonic_id
=
self
.
storage
.
put
(
key
,
value
)
address
,
monotonic_id
=
self
.
storage
.
put
(
key
,
value
)
...
...
tests/models/multimodal/processing/test_tensor_schema.py
View file @
c6e7404c
...
@@ -119,7 +119,11 @@ def create_batched_mm_kwargs(
...
@@ -119,7 +119,11 @@ def create_batched_mm_kwargs(
)[
"mm_kwargs"
].
require_data
()
)[
"mm_kwargs"
].
require_data
()
return
group_mm_kwargs_by_modality
(
return
group_mm_kwargs_by_modality
(
[
item
for
modality
in
supported_mm_limits
for
item
in
mm_kwargs
[
modality
]]
[
(
modality
,
item
)
for
modality
in
supported_mm_limits
for
item
in
mm_kwargs
[
modality
]
]
)
)
...
...
tests/multimodal/test_cache.py
View file @
c6e7404c
...
@@ -36,8 +36,6 @@ pytestmark = pytest.mark.cpu_test
...
@@ -36,8 +36,6 @@ pytestmark = pytest.mark.cpu_test
def
_dummy_elem
(
def
_dummy_elem
(
modality
:
str
,
key
:
str
,
size
:
int
,
size
:
int
,
*
,
*
,
rng
:
np
.
random
.
RandomState
|
None
=
None
,
rng
:
np
.
random
.
RandomState
|
None
=
None
,
...
@@ -48,21 +46,18 @@ def _dummy_elem(
...
@@ -48,21 +46,18 @@ def _dummy_elem(
data
=
torch
.
from_numpy
(
rng
.
randint
(
4
,
size
=
(
size
,),
dtype
=
np
.
int8
))
data
=
torch
.
from_numpy
(
rng
.
randint
(
4
,
size
=
(
size
,),
dtype
=
np
.
int8
))
return
MultiModalFieldElem
(
return
MultiModalFieldElem
(
modality
=
modality
,
key
=
key
,
data
=
data
,
data
=
data
,
field
=
MultiModalSharedField
(
batch_size
=
1
),
field
=
MultiModalSharedField
(
batch_size
=
1
),
)
)
def
_dummy_item
(
def
_dummy_item
(
modality
:
str
,
size_by_key
:
dict
[
str
,
int
],
size_by_key
:
dict
[
str
,
int
],
*
,
*
,
rng
:
np
.
random
.
RandomState
|
None
=
None
,
rng
:
np
.
random
.
RandomState
|
None
=
None
,
):
):
return
MultiModalKwargsItem
.
from_elems
(
return
MultiModalKwargsItem
(
[
_dummy_elem
(
modality
,
key
,
size
,
rng
=
rng
)
for
key
,
size
in
size_by_key
.
items
()
]
{
key
:
_dummy_elem
(
size
,
rng
=
rng
)
for
key
,
size
in
size_by_key
.
items
()
}
)
)
...
@@ -71,19 +66,19 @@ def _dummy_items(
...
@@ -71,19 +66,19 @@ def _dummy_items(
*
,
*
,
rng
:
np
.
random
.
RandomState
|
None
=
None
,
rng
:
np
.
random
.
RandomState
|
None
=
None
,
):
):
return
MultiModalKwargsItems
.
from_seq
(
return
MultiModalKwargsItems
(
[
{
_dummy_item
(
modality
,
size_by_key
,
rng
=
rng
)
modality
:
[
_dummy_item
(
size_by_key
,
rng
=
rng
)
]
for
modality
,
size_by_key
in
size_by_key_modality
.
items
()
for
modality
,
size_by_key
in
size_by_key_modality
.
items
()
]
}
)
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
(
"item"
,
"expected_size"
),
(
"item"
,
"expected_size"
),
[
[
(
_dummy_item
(
"a"
,
{
"a1"
:
100
}),
100
),
(
_dummy_item
({
"a1"
:
100
}),
100
),
(
_dummy_item
(
"a"
,
{
"a1"
:
100
,
"a2"
:
110
}),
210
),
(
_dummy_item
({
"a1"
:
100
,
"a2"
:
110
}),
210
),
(
_dummy_items
({
"a"
:
{
"a1"
:
100
,
"a2"
:
110
},
"b"
:
{
"b1"
:
120
,
"b2"
:
130
}}),
460
),
# noqa: E501
(
_dummy_items
({
"a"
:
{
"a1"
:
100
,
"a2"
:
110
},
"b"
:
{
"b1"
:
120
,
"b2"
:
130
}}),
460
),
# noqa: E501
],
],
)
)
...
@@ -143,7 +138,7 @@ def _compare_caches(
...
@@ -143,7 +138,7 @@ def _compare_caches(
rng
=
np
.
random
.
RandomState
(
seed
)
rng
=
np
.
random
.
RandomState
(
seed
)
all_items
=
[
all_items
=
[
_dummy_item
(
"item"
,
{
"key"
:
item_size_gb
},
rng
=
rng
)
_dummy_item
({
"key"
:
item_size_gb
},
rng
=
rng
)
for
_
in
range
(
int
(
item_capacity
/
hit_rate
))
for
_
in
range
(
int
(
item_capacity
/
hit_rate
))
]
]
all_hashes
=
[
all_hashes
=
[
...
@@ -245,13 +240,13 @@ def _run_test_cache_eviction_lru(
...
@@ -245,13 +240,13 @@ def _run_test_cache_eviction_lru(
"image_C"
,
"image_C"
,
]
]
request1_items
=
{
request1_items
=
{
h
:
MultiModalKwargsItem
.
dummy
(
h
,
nbytes
=
2
*
base_item_size
)
h
:
MultiModalKwargsItem
.
dummy
(
nbytes
=
2
*
base_item_size
)
for
h
in
request1_hashes
for
h
in
request1_hashes
}
}
request2_hashes
=
[
"image_D"
,
"image_E"
,
"image_A"
,
"image_C"
]
request2_hashes
=
[
"image_D"
,
"image_E"
,
"image_A"
,
"image_C"
]
request2_items
=
{
request2_items
=
{
h
:
MultiModalKwargsItem
.
dummy
(
h
,
nbytes
=
1
*
base_item_size
)
h
:
MultiModalKwargsItem
.
dummy
(
nbytes
=
1
*
base_item_size
)
for
h
in
request2_hashes
for
h
in
request2_hashes
}
}
...
@@ -356,15 +351,14 @@ def _run_test_cache_eviction_shm(
...
@@ -356,15 +351,14 @@ def _run_test_cache_eviction_shm(
):
):
request1_hashes
=
[
"image_A"
,
"image_B"
,
"image_C"
]
request1_hashes
=
[
"image_A"
,
"image_B"
,
"image_C"
]
request1_items
=
{
request1_items
=
{
h
:
MultiModalKwargsItem
.
dummy
(
h
,
nbytes
=
5
*
base_item_size
)
h
:
MultiModalKwargsItem
.
dummy
(
5
*
base_item_size
)
for
h
in
request1_hashes
for
h
in
request1_hashes
}
}
request1_items_p0_result
=
[]
request1_items_p0_result
=
[]
request2_hashes
=
[
"image_G"
,
"image_A"
]
request2_hashes
=
[
"image_G"
,
"image_A"
]
request2_items
=
{
request2_items
=
{
h
:
MultiModalKwargsItem
.
dummy
(
h
:
MultiModalKwargsItem
.
dummy
(
h
,
nbytes
=
(
5
if
h
in
request1_hashes
else
2
)
*
base_item_size
(
5
if
h
in
request1_hashes
else
2
)
*
base_item_size
)
)
for
h
in
request2_hashes
for
h
in
request2_hashes
}
}
...
@@ -373,7 +367,7 @@ def _run_test_cache_eviction_shm(
...
@@ -373,7 +367,7 @@ def _run_test_cache_eviction_shm(
request3_hashes
=
[
"image_G"
,
"image_H"
,
"image_I"
,
"image_B"
]
request3_hashes
=
[
"image_G"
,
"image_H"
,
"image_I"
,
"image_B"
]
request3_items
=
{
request3_items
=
{
h
:
MultiModalKwargsItem
.
dummy
(
h
:
MultiModalKwargsItem
.
dummy
(
h
,
nbytes
=
(
5
if
h
in
request1_hashes
else
2
)
*
base_item_size
(
5
if
h
in
request1_hashes
else
2
)
*
base_item_size
)
)
for
h
in
request3_hashes
for
h
in
request3_hashes
}
}
...
@@ -532,7 +526,7 @@ def test_processor_cache_shared_across_loras():
...
@@ -532,7 +526,7 @@ def test_processor_cache_shared_across_loras():
lora_a_identifier
=
f
"12345:
{
base_mm_hash
}
"
lora_a_identifier
=
f
"12345:
{
base_mm_hash
}
"
lora_b_identifier
=
f
"67890:
{
base_mm_hash
}
"
lora_b_identifier
=
f
"67890:
{
base_mm_hash
}
"
item_data
=
MultiModalKwargsItem
.
dummy
(
"test_image"
,
nbytes
=
1024
)
item_data
=
MultiModalKwargsItem
.
dummy
(
1024
)
feature_lora_a
=
MultiModalFeatureSpec
(
feature_lora_a
=
MultiModalFeatureSpec
(
data
=
item_data
,
data
=
item_data
,
...
...
tests/v1/core/test_kv_cache_utils.py
View file @
c6e7404c
...
@@ -77,7 +77,7 @@ def make_request(
...
@@ -77,7 +77,7 @@ def make_request(
for
j
,
position
in
enumerate
(
mm_positions
):
for
j
,
position
in
enumerate
(
mm_positions
):
identifier
=
mm_hashes
[
j
]
if
mm_hashes
else
f
"hash_
{
j
}
"
identifier
=
mm_hashes
[
j
]
if
mm_hashes
else
f
"hash_
{
j
}
"
mm_feature
=
MultiModalFeatureSpec
(
mm_feature
=
MultiModalFeatureSpec
(
data
=
MultiModalKwargsItem
.
dummy
(
"dummy_m"
),
data
=
MultiModalKwargsItem
.
dummy
(),
mm_position
=
position
,
mm_position
=
position
,
identifier
=
identifier
,
identifier
=
identifier
,
modality
=
"image"
,
modality
=
"image"
,
...
...
tests/v1/core/test_prefix_caching.py
View file @
c6e7404c
...
@@ -68,7 +68,7 @@ def make_request(
...
@@ -68,7 +68,7 @@ def make_request(
for
j
,
position
in
enumerate
(
mm_positions
):
for
j
,
position
in
enumerate
(
mm_positions
):
identifier
=
mm_hashes
[
j
]
if
mm_hashes
else
f
"hash_
{
j
}
"
identifier
=
mm_hashes
[
j
]
if
mm_hashes
else
f
"hash_
{
j
}
"
mm_feature
=
MultiModalFeatureSpec
(
mm_feature
=
MultiModalFeatureSpec
(
data
=
MultiModalKwargsItem
.
dummy
(
"dummy_m"
),
data
=
MultiModalKwargsItem
.
dummy
(),
mm_position
=
position
,
mm_position
=
position
,
identifier
=
identifier
,
identifier
=
identifier
,
modality
=
"image"
,
modality
=
"image"
,
...
...
tests/v1/core/test_priority_scheduler_random.py
View file @
c6e7404c
...
@@ -56,7 +56,7 @@ def _create_random_request(
...
@@ -56,7 +56,7 @@ def _create_random_request(
for
j
,
position
in
enumerate
(
mm_positions
):
for
j
,
position
in
enumerate
(
mm_positions
):
identifier
=
f
"
{
request_id
}
_hash_
{
j
}
"
identifier
=
f
"
{
request_id
}
_hash_
{
j
}
"
mm_feature
=
MultiModalFeatureSpec
(
mm_feature
=
MultiModalFeatureSpec
(
data
=
MultiModalKwargsItem
.
dummy
(
"dummy_m"
),
data
=
MultiModalKwargsItem
.
dummy
(),
mm_position
=
position
,
mm_position
=
position
,
identifier
=
identifier
,
identifier
=
identifier
,
modality
=
"image"
,
modality
=
"image"
,
...
...
tests/v1/core/test_scheduler.py
View file @
c6e7404c
...
@@ -1707,7 +1707,7 @@ def create_requests_with_priority(
...
@@ -1707,7 +1707,7 @@ def create_requests_with_priority(
# Unique dummy hash for each mm item
# Unique dummy hash for each mm item
identifier
=
f
"hash
{
i
}
_
{
j
}
"
identifier
=
f
"hash
{
i
}
_
{
j
}
"
mm_feature
=
MultiModalFeatureSpec
(
mm_feature
=
MultiModalFeatureSpec
(
data
=
MultiModalKwargsItem
.
dummy
(
"dummy_m"
),
data
=
MultiModalKwargsItem
.
dummy
(),
mm_position
=
position
,
mm_position
=
position
,
identifier
=
identifier
,
identifier
=
identifier
,
modality
=
"image"
,
modality
=
"image"
,
...
...
tests/v1/core/utils.py
View file @
c6e7404c
...
@@ -236,7 +236,7 @@ def create_requests(
...
@@ -236,7 +236,7 @@ def create_requests(
# Unique dummy hash for each mm item
# Unique dummy hash for each mm item
identifier
=
f
"hash
{
i
}
_
{
j
}
"
identifier
=
f
"hash
{
i
}
_
{
j
}
"
mm_feature
=
MultiModalFeatureSpec
(
mm_feature
=
MultiModalFeatureSpec
(
data
=
MultiModalKwargsItem
.
dummy
(
"dummy_m"
),
data
=
MultiModalKwargsItem
.
dummy
(),
mm_position
=
position
,
mm_position
=
position
,
identifier
=
identifier
,
identifier
=
identifier
,
modality
=
"image"
,
modality
=
"image"
,
...
...
tests/v1/streaming_input/test_gpu_model_runner_streaming.py
View file @
c6e7404c
...
@@ -131,7 +131,7 @@ def test_e2e_streaming_with_multimodal_features(mock_model_runner_with_input_bat
...
@@ -131,7 +131,7 @@ def test_e2e_streaming_with_multimodal_features(mock_model_runner_with_input_bat
# Step 1: Create initial request state with one multimodal feature
# Step 1: Create initial request state with one multimodal feature
mm_feature_1
=
MultiModalFeatureSpec
(
mm_feature_1
=
MultiModalFeatureSpec
(
data
=
MultiModalKwargsItem
.
dummy
(
"audio"
),
data
=
MultiModalKwargsItem
.
dummy
(),
modality
=
"audio"
,
modality
=
"audio"
,
identifier
=
"audio_1"
,
identifier
=
"audio_1"
,
mm_position
=
PlaceholderRange
(
offset
=
2
,
length
=
10
),
mm_position
=
PlaceholderRange
(
offset
=
2
,
length
=
10
),
...
@@ -158,7 +158,7 @@ def test_e2e_streaming_with_multimodal_features(mock_model_runner_with_input_bat
...
@@ -158,7 +158,7 @@ def test_e2e_streaming_with_multimodal_features(mock_model_runner_with_input_bat
# The scheduler has already set prompt_token_ids to the full sequence
# The scheduler has already set prompt_token_ids to the full sequence
# (original prompt + intermediate outputs + new prompt with new multimodal feature)
# (original prompt + intermediate outputs + new prompt with new multimodal feature)
mm_feature_2
=
MultiModalFeatureSpec
(
mm_feature_2
=
MultiModalFeatureSpec
(
data
=
MultiModalKwargsItem
.
dummy
(
"audio"
),
data
=
MultiModalKwargsItem
.
dummy
(),
modality
=
"audio"
,
modality
=
"audio"
,
identifier
=
"audio_2"
,
identifier
=
"audio_2"
,
mm_position
=
PlaceholderRange
(
offset
=
15
,
length
=
5
),
mm_position
=
PlaceholderRange
(
offset
=
15
,
length
=
5
),
...
...
tests/v1/streaming_input/test_scheduler_streaming.py
View file @
c6e7404c
...
@@ -174,7 +174,7 @@ class TestStreamingScheduler(unittest.TestCase):
...
@@ -174,7 +174,7 @@ class TestStreamingScheduler(unittest.TestCase):
scheduler
=
create_scheduler
()
scheduler
=
create_scheduler
()
mm_feature
=
MultiModalFeatureSpec
(
mm_feature
=
MultiModalFeatureSpec
(
data
=
MultiModalKwargsItem
.
dummy
(
"audio"
),
data
=
MultiModalKwargsItem
.
dummy
(),
modality
=
"audio"
,
modality
=
"audio"
,
identifier
=
""
,
identifier
=
""
,
mm_position
=
PlaceholderRange
(
offset
=
1
,
length
=
1
),
mm_position
=
PlaceholderRange
(
offset
=
1
,
length
=
1
),
...
@@ -187,7 +187,7 @@ class TestStreamingScheduler(unittest.TestCase):
...
@@ -187,7 +187,7 @@ class TestStreamingScheduler(unittest.TestCase):
session
.
num_computed_tokens
=
len
(
session
.
prompt_token_ids
)
session
.
num_computed_tokens
=
len
(
session
.
prompt_token_ids
)
mm_feature
=
MultiModalFeatureSpec
(
mm_feature
=
MultiModalFeatureSpec
(
data
=
MultiModalKwargsItem
.
dummy
(
"audio"
),
data
=
MultiModalKwargsItem
.
dummy
(),
modality
=
"audio"
,
modality
=
"audio"
,
identifier
=
""
,
identifier
=
""
,
mm_position
=
PlaceholderRange
(
offset
=
2
,
length
=
1
),
mm_position
=
PlaceholderRange
(
offset
=
2
,
length
=
1
),
...
...
tests/v1/test_serial_utils.py
View file @
c6e7404c
...
@@ -104,14 +104,10 @@ class MyRequest(msgspec.Struct):
...
@@ -104,14 +104,10 @@ class MyRequest(msgspec.Struct):
def
test_multimodal_kwargs
():
def
test_multimodal_kwargs
():
e1
=
MultiModalFieldElem
(
e1
=
MultiModalFieldElem
(
"audio"
,
"a0"
,
torch
.
zeros
(
1000
,
dtype
=
torch
.
bfloat16
),
torch
.
zeros
(
1000
,
dtype
=
torch
.
bfloat16
),
MultiModalBatchedField
(),
MultiModalBatchedField
(),
)
)
e2
=
MultiModalFieldElem
(
e2
=
MultiModalFieldElem
(
"video"
,
"v0"
,
[
torch
.
zeros
(
1000
,
dtype
=
torch
.
int8
)
for
_
in
range
(
4
)],
[
torch
.
zeros
(
1000
,
dtype
=
torch
.
int8
)
for
_
in
range
(
4
)],
MultiModalFlatField
(
MultiModalFlatField
(
slices
=
[[
slice
(
1
,
2
,
3
),
slice
(
4
,
5
,
6
)],
[
slice
(
None
,
2
)]],
slices
=
[[
slice
(
1
,
2
,
3
),
slice
(
4
,
5
,
6
)],
[
slice
(
None
,
2
)]],
...
@@ -119,21 +115,20 @@ def test_multimodal_kwargs():
...
@@ -119,21 +115,20 @@ def test_multimodal_kwargs():
),
),
)
)
e3
=
MultiModalFieldElem
(
e3
=
MultiModalFieldElem
(
"image"
,
"i0"
,
torch
.
zeros
(
1000
,
dtype
=
torch
.
int32
),
torch
.
zeros
(
1000
,
dtype
=
torch
.
int32
),
MultiModalSharedField
(
batch_size
=
4
),
MultiModalSharedField
(
batch_size
=
4
),
)
)
e4
=
MultiModalFieldElem
(
e4
=
MultiModalFieldElem
(
"image"
,
"i1"
,
torch
.
zeros
(
1000
,
dtype
=
torch
.
int32
),
torch
.
zeros
(
1000
,
dtype
=
torch
.
int32
),
MultiModalFlatField
(
slices
=
[
slice
(
1
,
2
,
3
),
slice
(
4
,
5
,
6
)],
dim
=
2
),
MultiModalFlatField
(
slices
=
[
slice
(
1
,
2
,
3
),
slice
(
4
,
5
,
6
)],
dim
=
2
),
)
)
audio
=
MultiModalKwargsItem
.
from_elems
([
e1
])
mm
=
MultiModalKwargsItems
(
video
=
MultiModalKwargsItem
.
from_elems
([
e2
])
{
image
=
MultiModalKwargsItem
.
from_elems
([
e3
,
e4
])
"audio"
:
[
MultiModalKwargsItem
({
"a0"
:
e1
})],
mm
=
MultiModalKwargsItems
.
from_seq
([
audio
,
video
,
image
])
"video"
:
[
MultiModalKwargsItem
({
"v0"
:
e2
})],
"image"
:
[
MultiModalKwargsItem
({
"i0"
:
e3
,
"i1"
:
e4
})],
}
)
# pack mm kwargs into a mock request so that it can be decoded properly
# pack mm kwargs into a mock request so that it can be decoded properly
req
=
MyRequest
([
mm
])
req
=
MyRequest
([
mm
])
...
@@ -147,8 +142,8 @@ def test_multimodal_kwargs():
...
@@ -147,8 +142,8 @@ def test_multimodal_kwargs():
total_len
=
sum
(
memoryview
(
x
).
cast
(
"B"
).
nbytes
for
x
in
encoded
)
total_len
=
sum
(
memoryview
(
x
).
cast
(
"B"
).
nbytes
for
x
in
encoded
)
# expected total encoding length, should be 1439
5
, +-20 for minor changes
# expected total encoding length, should be 143
1
9, +-20 for minor changes
assert
143
75
<=
total_len
<=
14
425
assert
143
00
<=
total_len
<=
14
340
decoded
=
decoder
.
decode
(
encoded
).
mm
[
0
]
decoded
=
decoder
.
decode
(
encoded
).
mm
[
0
]
assert
isinstance
(
decoded
,
MultiModalKwargsItems
)
assert
isinstance
(
decoded
,
MultiModalKwargsItems
)
...
...
vllm/multimodal/cache.py
View file @
c6e7404c
...
@@ -463,8 +463,8 @@ class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
...
@@ -463,8 +463,8 @@ class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
ring_buffer
=
ring_buffer
,
ring_buffer
=
ring_buffer
,
serde_class
=
MsgpackSerde
,
serde_class
=
MsgpackSerde
,
)
)
# cache
(
prompt_updates
, modality)
for P0 only
# cache prompt_updates for P0 only
self
.
_p0_cache
:
dict
[
str
,
tuple
[
Sequence
[
ResolvedPromptUpdate
]
,
str
]
]
=
{}
self
.
_p0_cache
:
dict
[
str
,
Sequence
[
ResolvedPromptUpdate
]]
=
{}
self
.
_hits
=
0
self
.
_hits
=
0
self
.
_total
=
0
self
.
_total
=
0
...
@@ -495,23 +495,22 @@ class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
...
@@ -495,23 +495,22 @@ class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
self
.
_total
+=
1
self
.
_total
+=
1
address
,
monotonic_id
=
self
.
_shm_cache
.
get_cached
(
mm_hash
)
address
,
monotonic_id
=
self
.
_shm_cache
.
get_cached
(
mm_hash
)
prompt_updates
,
modality
=
self
.
_p0_cache
[
mm_hash
]
prompt_updates
=
self
.
_p0_cache
[
mm_hash
]
return
self
.
address_as_item
(
address
,
monotonic_id
,
modality
),
prompt_updates
return
self
.
address_as_item
(
address
,
monotonic_id
),
prompt_updates
assert
mm_item
is
not
None
,
f
"Expected a cached item for
{
mm_hash
=
}
"
assert
mm_item
is
not
None
,
f
"Expected a cached item for
{
mm_hash
=
}
"
item
,
prompt_updates
=
mm_item
self
.
_total
+=
1
self
.
_total
+=
1
try
:
try
:
address
,
monotonic_id
=
self
.
_shm_cache
.
put
(
mm_hash
,
mm_
item
[
0
]
)
address
,
monotonic_id
=
self
.
_shm_cache
.
put
(
mm_hash
,
item
)
# Try to remove dangling items if p0 cache is too large.
# Try to remove dangling items if p0 cache is too large.
if
len
(
self
.
_p0_cache
)
>=
2
*
len
(
self
.
_shm_cache
.
key_index
):
if
len
(
self
.
_p0_cache
)
>=
2
*
len
(
self
.
_shm_cache
.
key_index
):
self
.
remove_dangling_items
()
self
.
remove_dangling_items
()
self
.
_p0_cache
[
mm_hash
]
=
mm_item
[
1
],
mm_item
[
0
].
modality
address_item
=
self
.
address_as_item
(
self
.
_p0_cache
[
mm_hash
]
=
prompt_updates
address
,
monotonic_id
,
mm_item
[
0
].
modality
return
self
.
address_as_item
(
address
,
monotonic_id
),
prompt_updates
)
return
address_item
,
mm_item
[
1
]
except
(
ValueError
,
MemoryError
)
as
e
:
except
(
ValueError
,
MemoryError
)
as
e
:
# put may fail if the object is too large or
# put may fail if the object is too large or
# the cache is full.
# the cache is full.
...
@@ -550,22 +549,20 @@ class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
...
@@ -550,22 +549,20 @@ class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
del
self
.
_p0_cache
[
mm_hash
]
del
self
.
_p0_cache
[
mm_hash
]
def
address_as_item
(
def
address_as_item
(
self
,
address
:
int
,
monotonic_id
:
int
,
modality
:
str
self
,
address
:
int
,
monotonic_id
:
int
,
)
->
MultiModalKwargsItem
:
)
->
MultiModalKwargsItem
:
addr_elem
=
MultiModalFieldElem
(
addr_elem
=
MultiModalFieldElem
(
modality
=
modality
,
key
=
"address"
,
data
=
address
,
data
=
address
,
field
=
MultiModalBatchedField
(),
field
=
MultiModalBatchedField
(),
)
)
id_elem
=
MultiModalFieldElem
(
id_elem
=
MultiModalFieldElem
(
modality
=
modality
,
key
=
"monotonic_id"
,
data
=
monotonic_id
,
data
=
monotonic_id
,
field
=
MultiModalBatchedField
(),
field
=
MultiModalBatchedField
(),
)
)
mm_item
=
MultiModalKwargsItem
.
from_elems
([
addr_elem
,
id_elem
])
return
mm_item
return
MultiModalKwargsItem
({
"address"
:
addr_elem
,
"monotonic_id"
:
id_elem
})
class
BaseMultiModalReceiverCache
(
class
BaseMultiModalReceiverCache
(
...
...
vllm/multimodal/inputs.py
View file @
c6e7404c
...
@@ -23,7 +23,7 @@ import numpy as np
...
@@ -23,7 +23,7 @@ import numpy as np
from
PIL.Image
import
Image
from
PIL.Image
import
Image
from
typing_extensions
import
NotRequired
,
TypeVar
from
typing_extensions
import
NotRequired
,
TypeVar
from
vllm.utils.collection_utils
import
full_groupby
,
is_list_of
from
vllm.utils.collection_utils
import
is_list_of
from
vllm.utils.import_utils
import
LazyLoader
from
vllm.utils.import_utils
import
LazyLoader
from
vllm.utils.jsontree
import
json_map_leaves
from
vllm.utils.jsontree
import
json_map_leaves
...
@@ -336,25 +336,33 @@ class MultiModalFeatureSpec:
...
@@ -336,25 +336,33 @@ class MultiModalFeatureSpec:
"""
"""
Represents a single multimodal input with its processed data and metadata.
Represents a single multimodal input with its processed data and metadata.
Used
by the V1 engine
to track multimodal data through processing and
Used to track multimodal data through processing and
caching.
caching.
A request containing multiple multimodal items will have one
A request containing multiple multimodal items will have one
MultiModalFeatureSpec per item.
`
MultiModalFeatureSpec
`
per item.
"""
"""
data
:
Optional
[
"MultiModalKwargsItem"
]
data
:
Optional
[
"MultiModalKwargsItem"
]
"""Multimodal data for this feature"""
"""
Represents multimodal data for this feature.
Can be `None` if the item is cached, to skip IPC between API server
and engine core processes.
"""
modality
:
str
modality
:
str
"""
Based on the input
, e.g., "image", "audio", "video"."""
"""
The input modality
, e.g.,
`
"image"
`
,
`
"audio"
`
,
`
"video"
`
."""
identifier
:
str
identifier
:
str
"""
mm_hash or uuid
for caching encoder outputs."""
"""
The hash
for caching encoder outputs
(with LoRA prefix if applicable)
."""
mm_position
:
PlaceholderRange
mm_position
:
PlaceholderRange
"""e.g., PlaceholderRange(offset=2, length=336)"""
"""
The location of the `modality` tokens corresponding to this item
in the prompt, e.g., `PlaceholderRange(offset=2, length=336)`.
"""
mm_hash
:
str
|
None
=
None
mm_hash
:
str
|
None
=
None
"""
Base mm_
hash for processor
cache
(without LoRA prefix)."""
"""
The
hash for
caching
processor
outputs
(without LoRA prefix)."""
@
staticmethod
@
staticmethod
def
gather_kwargs
(
features
:
list
[
"MultiModalFeatureSpec"
],
keys
:
set
[
str
]):
def
gather_kwargs
(
features
:
list
[
"MultiModalFeatureSpec"
],
keys
:
set
[
str
]):
...
@@ -373,23 +381,10 @@ class MultiModalFeatureSpec:
...
@@ -373,23 +381,10 @@ class MultiModalFeatureSpec:
@
dataclass
@
dataclass
class
MultiModalFieldElem
:
class
MultiModalFieldElem
:
"""
"""
Represents a keyword argument
inside
a
Represents a
processed
keyword argument
to pass to a model for
a
[`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem].
[`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem].
"""
"""
modality
:
str
"""
The modality of the corresponding multi-modal item.
Each multi-modal item can consist of multiple keyword arguments.
"""
key
:
str
"""
The key of this field in
[`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem],
i.e. the name of the keyword argument to be passed to the model.
"""
data
:
NestedTensors
data
:
NestedTensors
"""
"""
The tensor data of this field in
The tensor data of this field in
...
@@ -417,11 +412,7 @@ class MultiModalFieldElem:
...
@@ -417,11 +412,7 @@ class MultiModalFieldElem:
else
:
else
:
data_equal
=
nested_tensors_equal
(
self
.
data
,
other
.
data
)
data_equal
=
nested_tensors_equal
(
self
.
data
,
other
.
data
)
return
(
return
data_equal
and
type
(
self
.
field
)
is
type
(
other
.
field
)
# noqa: E721
(
self
.
modality
,
self
.
key
)
==
(
other
.
modality
,
other
.
key
)
and
data_equal
and
type
(
self
.
field
)
is
type
(
other
.
field
)
)
# noqa: E721
@
dataclass
(
frozen
=
True
,
kw_only
=
True
)
@
dataclass
(
frozen
=
True
,
kw_only
=
True
)
...
@@ -438,13 +429,8 @@ class BaseMultiModalField(ABC):
...
@@ -438,13 +429,8 @@ class BaseMultiModalField(ABC):
when `MultiModalKwargsItems.get_data()` is called to batch the data.
when `MultiModalKwargsItems.get_data()` is called to batch the data.
"""
"""
def
_field_factory
(
self
,
*
,
modality
:
str
,
key
:
str
):
def
_field_factory
(
self
):
f
=
partial
(
f
=
partial
(
MultiModalFieldElem
,
field
=
self
)
MultiModalFieldElem
,
modality
=
modality
,
key
=
key
,
field
=
self
,
)
# Allow passing data as positional argument
# Allow passing data as positional argument
def
factory
(
data
:
NestedTensors
)
->
MultiModalFieldElem
:
def
factory
(
data
:
NestedTensors
)
->
MultiModalFieldElem
:
...
@@ -519,7 +505,7 @@ class MultiModalBatchedField(BaseMultiModalField):
...
@@ -519,7 +505,7 @@ class MultiModalBatchedField(BaseMultiModalField):
key
:
str
,
key
:
str
,
data
:
NestedTensors
,
data
:
NestedTensors
,
)
->
Sequence
[
MultiModalFieldElem
]:
)
->
Sequence
[
MultiModalFieldElem
]:
field_factory
=
self
.
_field_factory
(
modality
=
modality
,
key
=
key
)
field_factory
=
self
.
_field_factory
()
return
[
field_factory
(
item
)
for
item
in
data
]
return
[
field_factory
(
item
)
for
item
in
data
]
def
_reduce_data
(
def
_reduce_data
(
...
@@ -565,7 +551,7 @@ class MultiModalFlatField(BaseMultiModalField):
...
@@ -565,7 +551,7 @@ class MultiModalFlatField(BaseMultiModalField):
key
:
str
,
key
:
str
,
data
:
NestedTensors
,
data
:
NestedTensors
,
)
->
Sequence
[
MultiModalFieldElem
]:
)
->
Sequence
[
MultiModalFieldElem
]:
field_factory
=
self
.
_field_factory
(
modality
=
modality
,
key
=
key
)
field_factory
=
self
.
_field_factory
()
if
not
is_list_of
(
self
.
slices
,
slice
,
check
=
"all"
):
if
not
is_list_of
(
self
.
slices
,
slice
,
check
=
"all"
):
assert
isinstance
(
data
,
torch
.
Tensor
),
(
assert
isinstance
(
data
,
torch
.
Tensor
),
(
"torch.Tensor is required for multiple slices"
"torch.Tensor is required for multiple slices"
...
@@ -664,7 +650,7 @@ class MultiModalSharedField(BaseMultiModalField):
...
@@ -664,7 +650,7 @@ class MultiModalSharedField(BaseMultiModalField):
key
:
str
,
key
:
str
,
data
:
NestedTensors
,
data
:
NestedTensors
,
)
->
Sequence
[
MultiModalFieldElem
]:
)
->
Sequence
[
MultiModalFieldElem
]:
field_factory
=
self
.
_field_factory
(
modality
=
modality
,
key
=
key
)
field_factory
=
self
.
_field_factory
()
return
[
field_factory
(
data
)]
*
self
.
batch_size
return
[
field_factory
(
data
)]
*
self
.
batch_size
def
_reduce_data
(
def
_reduce_data
(
...
@@ -899,37 +885,19 @@ class MultiModalFieldConfig:
...
@@ -899,37 +885,19 @@ class MultiModalFieldConfig:
class
MultiModalKwargsItem
(
UserDict
[
str
,
MultiModalFieldElem
]):
class
MultiModalKwargsItem
(
UserDict
[
str
,
MultiModalFieldElem
]):
"""
"""
A collection of
A dictionary of processed keyword arguments to pass to the model,
[`MultiModalFieldElem`][vllm.multimodal.inputs.MultiModalFieldElem]
corresponding to a single item in
corresponding to a data item in
[`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].
[`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].
"""
"""
@
staticmethod
@
staticmethod
def
dummy
(
modality
:
str
,
nbytes
:
int
=
1
):
def
dummy
(
nbytes
:
int
=
1
):
"""Convenience class for testing."""
"""Convenience class for testing."""
mm_elem
=
MultiModalFieldElem
(
mm_elem
=
MultiModalFieldElem
(
modality
=
modality
,
key
=
"dummy"
,
data
=
torch
.
empty
(
nbytes
,
dtype
=
torch
.
uint8
),
data
=
torch
.
empty
(
nbytes
,
dtype
=
torch
.
uint8
),
field
=
MultiModalSharedField
(
batch_size
=
1
),
field
=
MultiModalSharedField
(
batch_size
=
1
),
)
)
return
MultiModalKwargsItem
.
from_elems
([
mm_elem
])
return
MultiModalKwargsItem
({
"dummy"
:
mm_elem
})
@
staticmethod
def
from_elems
(
elems
:
Sequence
[
MultiModalFieldElem
]):
return
MultiModalKwargsItem
({
elem
.
key
:
elem
for
elem
in
elems
})
def
__init__
(
self
,
data
:
Mapping
[
str
,
MultiModalFieldElem
]
=
{})
->
None
:
super
().
__init__
(
data
)
modalities
=
{
elem
.
modality
for
elem
in
self
.
values
()}
assert
len
(
modalities
)
==
1
,
f
"Found different modalities=
{
modalities
}
"
self
.
_modality
=
next
(
iter
(
modalities
))
@
property
def
modality
(
self
)
->
str
:
return
self
.
_modality
def
get_data
(
self
)
->
dict
[
str
,
NestedTensors
]:
def
get_data
(
self
)
->
dict
[
str
,
NestedTensors
]:
return
{
key
:
elem
.
data
for
key
,
elem
in
self
.
items
()}
return
{
key
:
elem
.
data
for
key
,
elem
in
self
.
items
()}
...
@@ -945,9 +913,38 @@ _I = TypeVar(
...
@@ -945,9 +913,38 @@ _I = TypeVar(
class
MultiModalKwargsItems
(
UserDict
[
str
,
Sequence
[
_I
]]):
class
MultiModalKwargsItems
(
UserDict
[
str
,
Sequence
[
_I
]]):
"""
"""
A dictionary of
A dictionary of processed multi-modal inputs by modality.
[`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem]s
by modality.
For example, given a processor that processes
images into `pixel_values` and `image_grid_thw`,
and audios into `input_audio_features`,
a prompt with 2 images and 1 audio will be processed
into a `MultiModalKwargsItems` with the following structure:
```python
MultiModalKwargsItems(
{
"image": [
# For the first image
MultiModalKwargsItem({"pixel_values": ..., "image_grid_thw": ...}),
# For the second imgae
MultiModalKwargsItem({"pixel_values": ..., "image_grid_thw": ...}),
],
"audio": [
# For the first audio
MultiModalKwargsItem({"input_audio_features": ...}),
],
}
)
```
Unlike HF processing which returns all items
in a single dictionary with batched keyword arguments,
we split up the items because some of them may already be cached.
Also, items from multiple requests may be batched together to improve throughput,
using the logic defined by the
[`BaseMultiModalField`][vllm.multimodal.inputs.BaseMultiModalField]
for each keyword argument.
"""
"""
@
staticmethod
@
staticmethod
...
@@ -967,7 +964,7 @@ class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
...
@@ -967,7 +964,7 @@ class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
elems_by_key
[
key
]
=
elems
elems_by_key
[
key
]
=
elems
keys_by_modality
[
config
.
modality
].
add
(
key
)
keys_by_modality
[
config
.
modality
].
add
(
key
)
items
=
list
[
MultiModalKwargsItem
]()
items
_by_modality
=
dict
[
str
,
list
[
MultiModalKwargsItem
]
]
()
for
modality
,
keys
in
keys_by_modality
.
items
():
for
modality
,
keys
in
keys_by_modality
.
items
():
elems_in_modality
=
{
k
:
elems_by_key
[
k
]
for
k
in
keys
}
elems_in_modality
=
{
k
:
elems_by_key
[
k
]
for
k
in
keys
}
batch_sizes
=
{
k
:
len
(
v
)
for
k
,
v
in
elems_in_modality
.
items
()}
batch_sizes
=
{
k
:
len
(
v
)
for
k
,
v
in
elems_in_modality
.
items
()}
...
@@ -979,15 +976,11 @@ class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
...
@@ -979,15 +976,11 @@ class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
)
)
batch_size
=
next
(
iter
(
batch_sizes
.
values
()))
batch_size
=
next
(
iter
(
batch_sizes
.
values
()))
for
item_idx
in
range
(
batch_size
):
items_by_modality
[
modality
]
=
[
elems
=
[
v
[
item_idx
]
for
v
in
elems_in_modality
.
values
()]
MultiModalKwargsItem
({
k
:
v
[
i
]
for
k
,
v
in
elems_in_modality
.
items
()})
items
.
append
(
MultiModalKwargsItem
.
from_elems
(
elems
))
for
i
in
range
(
batch_size
)
]
return
MultiModalKwargsItems
.
from_seq
(
items
)
@
staticmethod
def
from_seq
(
items
:
Sequence
[
MultiModalKwargsItem
]):
items_by_modality
=
full_groupby
(
items
,
key
=
lambda
x
:
x
.
modality
)
return
MultiModalKwargsItems
(
items_by_modality
)
return
MultiModalKwargsItems
(
items_by_modality
)
def
__getitem__
(
self
,
modality
:
str
)
->
Sequence
[
_I
]:
def
__getitem__
(
self
,
modality
:
str
)
->
Sequence
[
_I
]:
...
...
vllm/multimodal/utils.py
View file @
c6e7404c
...
@@ -467,7 +467,7 @@ def argsort_mm_positions(
...
@@ -467,7 +467,7 @@ def argsort_mm_positions(
def
group_mm_kwargs_by_modality
(
def
group_mm_kwargs_by_modality
(
mm_kwargs
:
list
[
MultiModalKwargsItem
],
mm_kwargs
:
list
[
tuple
[
str
,
MultiModalKwargsItem
]
]
,
*
,
*
,
device
:
torch
.
types
.
Device
=
None
,
device
:
torch
.
types
.
Device
=
None
,
pin_memory
:
bool
=
False
,
pin_memory
:
bool
=
False
,
...
@@ -485,9 +485,9 @@ def group_mm_kwargs_by_modality(
...
@@ -485,9 +485,9 @@ def group_mm_kwargs_by_modality(
"""
"""
from
vllm.multimodal.inputs
import
MultiModalKwargsItems
from
vllm.multimodal.inputs
import
MultiModalKwargsItems
for
modality
,
items
in
groupby
(
mm_kwargs
,
key
=
lambda
item
:
item
.
modality
):
for
modality
,
group
in
groupby
(
mm_kwargs
,
key
=
lambda
x
:
x
[
0
]
):
items_lst
=
list
(
items
)
items_lst
=
[
item
for
_
,
item
in
group
]
mm_kwargs_items
=
MultiModalKwargsItems
.
from_seq
(
items_lst
)
mm_kwargs_items
=
MultiModalKwargsItems
({
modality
:
items_lst
}
)
mm_kwargs_data
=
mm_kwargs_items
.
get_data
(
mm_kwargs_data
=
mm_kwargs_items
.
get_data
(
device
=
device
,
device
=
device
,
pin_memory
=
pin_memory
,
pin_memory
=
pin_memory
,
...
...
vllm/v1/serial_utils.py
View file @
c6e7404c
...
@@ -242,13 +242,11 @@ class MsgpackEncoder:
...
@@ -242,13 +242,11 @@ class MsgpackEncoder:
for
modality
,
itemlist
in
items
.
items
()
for
modality
,
itemlist
in
items
.
items
()
}
}
def
_encode_mm_item
(
self
,
item
:
MultiModalKwargsItem
)
->
list
[
dict
[
str
,
Any
]
]
:
def
_encode_mm_item
(
self
,
item
:
MultiModalKwargsItem
)
->
dict
[
str
,
Any
]:
return
[
self
.
_encode_mm_field_elem
(
elem
)
for
elem
in
item
.
value
s
()
]
return
{
key
:
self
.
_encode_mm_field_elem
(
elem
)
for
key
,
elem
in
item
.
item
s
()
}
def
_encode_mm_field_elem
(
self
,
elem
:
MultiModalFieldElem
)
->
dict
[
str
,
Any
]:
def
_encode_mm_field_elem
(
self
,
elem
:
MultiModalFieldElem
)
->
dict
[
str
,
Any
]:
return
{
return
{
"modality"
:
elem
.
modality
,
"key"
:
elem
.
key
,
"data"
:
(
"data"
:
(
None
if
elem
.
data
is
None
else
self
.
_encode_nested_tensors
(
elem
.
data
)
None
if
elem
.
data
is
None
else
self
.
_encode_nested_tensors
(
elem
.
data
)
),
),
...
@@ -383,9 +381,9 @@ class MsgpackDecoder:
...
@@ -383,9 +381,9 @@ class MsgpackDecoder:
}
}
)
)
def
_decode_mm_item
(
self
,
obj
:
list
[
Any
])
->
MultiModalKwargsItem
:
def
_decode_mm_item
(
self
,
obj
:
dict
[
str
,
Any
])
->
MultiModalKwargsItem
:
return
MultiModalKwargsItem
.
from_elems
(
return
MultiModalKwargsItem
(
[
self
.
_decode_mm_field_elem
(
v
)
for
v
in
obj
]
{
key
:
self
.
_decode_mm_field_elem
(
elem
)
for
key
,
elem
in
obj
.
items
()}
)
)
def
_decode_mm_field_elem
(
self
,
obj
:
dict
[
str
,
Any
])
->
MultiModalFieldElem
:
def
_decode_mm_field_elem
(
self
,
obj
:
dict
[
str
,
Any
])
->
MultiModalFieldElem
:
...
...
vllm/v1/worker/gpu/mm/encoder_runner.py
View file @
c6e7404c
...
@@ -43,9 +43,9 @@ class EncoderRunner:
...
@@ -43,9 +43,9 @@ class EncoderRunner:
def
prepare_mm_inputs
(
def
prepare_mm_inputs
(
self
,
self
,
scheduled_encoder_inputs
:
dict
[
str
,
list
[
int
]],
scheduled_encoder_inputs
:
dict
[
str
,
list
[
int
]],
)
->
tuple
[
list
[
str
],
list
[
MultiModalKwargsItem
]]:
)
->
tuple
[
list
[
str
],
list
[
tuple
[
str
,
MultiModalKwargsItem
]]
]
:
mm_hashes
:
list
[
str
]
=
[]
mm_hashes
:
list
[
str
]
=
[]
mm_kwargs
:
list
[
MultiModalKwargsItem
]
=
[]
mm_kwargs
:
list
[
tuple
[
str
,
MultiModalKwargsItem
]
]
=
[]
for
req_id
,
encoder_input_ids
in
scheduled_encoder_inputs
.
items
():
for
req_id
,
encoder_input_ids
in
scheduled_encoder_inputs
.
items
():
mm_features
=
self
.
req_id_to_mm_features
[
req_id
]
mm_features
=
self
.
req_id_to_mm_features
[
req_id
]
for
mm_input_id
in
encoder_input_ids
:
for
mm_input_id
in
encoder_input_ids
:
...
@@ -53,7 +53,8 @@ class EncoderRunner:
...
@@ -53,7 +53,8 @@ class EncoderRunner:
if
mm_feature
.
data
is
None
:
if
mm_feature
.
data
is
None
:
continue
continue
mm_hashes
.
append
(
mm_feature
.
identifier
)
mm_hashes
.
append
(
mm_feature
.
identifier
)
mm_kwargs
.
append
(
mm_feature
.
data
)
mm_kwargs
.
append
((
mm_feature
.
modality
,
mm_feature
.
data
))
return
mm_hashes
,
mm_kwargs
return
mm_hashes
,
mm_kwargs
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
...
@@ -61,7 +62,7 @@ class EncoderRunner:
...
@@ -61,7 +62,7 @@ class EncoderRunner:
self
,
self
,
model
:
SupportsMultiModal
,
model
:
SupportsMultiModal
,
mm_hashes
:
list
[
str
],
mm_hashes
:
list
[
str
],
mm_kwargs
:
list
[
MultiModalKwargsItem
],
mm_kwargs
:
list
[
tuple
[
str
,
MultiModalKwargsItem
]
]
,
)
->
list
[
torch
.
Tensor
]:
)
->
list
[
torch
.
Tensor
]:
if
not
mm_hashes
:
if
not
mm_hashes
:
return
[]
return
[]
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
c6e7404c
...
@@ -1217,11 +1217,11 @@ class GPUModelRunner(
...
@@ -1217,11 +1217,11 @@ class GPUModelRunner(
if
not
scheduler_output
or
not
self
.
is_multimodal_raw_input_only_model
:
if
not
scheduler_output
or
not
self
.
is_multimodal_raw_input_only_model
:
return
{}
return
{}
mm_kwargs
=
list
[
MultiModalKwargsItem
]()
mm_kwargs
=
list
[
tuple
[
str
,
MultiModalKwargsItem
]
]
()
for
req
in
scheduler_output
.
scheduled_new_reqs
:
for
req
in
scheduler_output
.
scheduled_new_reqs
:
for
feature
in
req
.
mm_features
:
for
feature
in
req
.
mm_features
:
if
feature
.
data
is
not
None
:
if
feature
.
data
is
not
None
:
mm_kwargs
.
append
(
feature
.
data
)
mm_kwargs
.
append
(
(
feature
.
modality
,
feature
.
data
)
)
# Input all modalities at once
# Input all modalities at once
mm_kwargs_combined
:
BatchedTensorInputs
=
{}
mm_kwargs_combined
:
BatchedTensorInputs
=
{}
...
@@ -2219,7 +2219,7 @@ class GPUModelRunner(
...
@@ -2219,7 +2219,7 @@ class GPUModelRunner(
scheduler_output
:
"SchedulerOutput"
,
scheduler_output
:
"SchedulerOutput"
,
)
->
tuple
[
)
->
tuple
[
list
[
str
],
list
[
str
],
list
[
MultiModalKwargsItem
],
list
[
tuple
[
str
,
MultiModalKwargsItem
]
]
,
list
[
tuple
[
str
,
PlaceholderRange
]],
list
[
tuple
[
str
,
PlaceholderRange
]],
]:
]:
"""Batch multimodal inputs from scheduled encoder inputs.
"""Batch multimodal inputs from scheduled encoder inputs.
...
@@ -2239,7 +2239,7 @@ class GPUModelRunner(
...
@@ -2239,7 +2239,7 @@ class GPUModelRunner(
return
[],
[],
[]
return
[],
[],
[]
mm_hashes
=
list
[
str
]()
mm_hashes
=
list
[
str
]()
mm_kwargs
=
list
[
MultiModalKwargsItem
]()
mm_kwargs
=
list
[
tuple
[
str
,
MultiModalKwargsItem
]
]
()
# Multimodal LoRA reference info to map each multimodal item
# Multimodal LoRA reference info to map each multimodal item
# back to its request & position
# back to its request & position
mm_lora_refs
=
list
[
tuple
[
str
,
PlaceholderRange
]]()
mm_lora_refs
=
list
[
tuple
[
str
,
PlaceholderRange
]]()
...
@@ -2252,7 +2252,7 @@ class GPUModelRunner(
...
@@ -2252,7 +2252,7 @@ class GPUModelRunner(
continue
continue
mm_hashes
.
append
(
mm_feature
.
identifier
)
mm_hashes
.
append
(
mm_feature
.
identifier
)
mm_kwargs
.
append
(
mm_feature
.
data
)
mm_kwargs
.
append
(
(
mm_feature
.
modality
,
mm_feature
.
data
)
)
mm_lora_refs
.
append
((
req_id
,
mm_feature
.
mm_position
))
mm_lora_refs
.
append
((
req_id
,
mm_feature
.
mm_position
))
return
mm_hashes
,
mm_kwargs
,
mm_lora_refs
return
mm_hashes
,
mm_kwargs
,
mm_lora_refs
...
@@ -4475,12 +4475,10 @@ class GPUModelRunner(
...
@@ -4475,12 +4475,10 @@ class GPUModelRunner(
# but not read from the cache
# but not read from the cache
assert
dummy_mm_item
is
not
None
,
"Item should not already be cached"
assert
dummy_mm_item
is
not
None
,
"Item should not already be cached"
dummy_mm_items
=
[
dummy_mm_item
]
*
max_items_per_batch
return
next
(
return
next
(
mm_kwargs_group
mm_kwargs_group
for
_
,
_
,
mm_kwargs_group
in
group_mm_kwargs_by_modality
(
for
_
,
_
,
mm_kwargs_group
in
group_mm_kwargs_by_modality
(
dummy_mm_item
s
,
[(
modality
,
dummy_mm_item
)]
*
max_items_per_batch
,
device
=
self
.
device
,
device
=
self
.
device
,
pin_memory
=
self
.
pin_memory
,
pin_memory
=
self
.
pin_memory
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment