Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b09806e2
Unverified
Commit
b09806e2
authored
Dec 13, 2025
by
Cyrus Leung
Committed by
GitHub
Dec 13, 2025
Browse files
[Bugfix] Dictionary MM embeddings for online chat (#30507)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
fdc135d7
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
193 additions
and
44 deletions
+193
-44
tests/entrypoints/test_chat_utils.py
tests/entrypoints/test_chat_utils.py
+105
-5
vllm/entrypoints/chat_utils.py
vllm/entrypoints/chat_utils.py
+68
-29
vllm/v1/engine/input_processor.py
vllm/v1/engine/input_processor.py
+20
-10
No files found.
tests/entrypoints/test_chat_utils.py
View file @
b09806e2
...
...
@@ -796,9 +796,13 @@ def test_parse_chat_messages_empty_image_embeds_with_uuid(
"content"
:
"<|image_1|>
\n
What's in this image?"
,
}
]
assert
mm_data
is
not
None
assert
"image"
in
mm_data
assert
mm_data
[
"image"
]
is
None
assert
isinstance
(
mm_data
[
"image"
],
list
)
assert
len
(
mm_data
[
"image"
])
==
1
assert
mm_data
[
"image"
][
0
]
is
None
_assert_mm_uuids
(
mm_uuids
,
1
,
expected_uuids
=
[
uuid
])
...
...
@@ -825,10 +829,11 @@ def test_parse_chat_messages_empty_audio_embeds_with_uuid(
# Should have audio in mm_data as None (UUID provided)
assert
mm_data
is
not
None
assert
"audio"
in
mm_data
assert
mm_data
[
"audio"
]
is
None
assert
isinstance
(
mm_data
[
"audio"
],
list
)
assert
len
(
mm_data
[
"audio"
])
==
1
assert
mm_data
[
"audio"
][
0
]
is
None
# UUID should be recorded
assert
mm_uuids
is
not
None
assert
"audio"
in
mm_uuids
_assert_mm_uuids
(
mm_uuids
,
1
,
modality
=
"audio"
,
expected_uuids
=
[
uuid
])
...
...
@@ -1121,10 +1126,105 @@ async def test_parse_chat_messages_empty_image_embeds_with_uuid_async(
mm_data
=
await
mm_future
assert
mm_data
is
not
None
assert
"image"
in
mm_data
assert
mm_data
[
"image"
]
is
None
assert
isinstance
(
mm_data
[
"image"
],
list
)
assert
len
(
mm_data
[
"image"
])
==
1
assert
mm_data
[
"image"
][
0
]
is
None
_assert_mm_uuids
(
mm_uuids
,
1
,
expected_uuids
=
[
uuid
])
def test_parse_chat_messages_empty_dict_image_embeds(
    phi3v_model_config_image_embeds,
):
    """An empty ``image_embeds`` dictionary must be parsed without raising."""
    user_message = {
        "role": "user",
        "content": [
            {"type": "image_embeds", "image_embeds": {}},
            {"type": "text", "text": "What's in this image?"},
        ],
    }

    parsed_conversation, parsed_mm_data, parsed_mm_uuids = parse_chat_messages(
        [user_message],
        phi3v_model_config_image_embeds,
        content_format="string",
    )

    # With content_format="string" the parts collapse into one string that
    # starts with the image placeholder.
    expected_conversation = [
        {"role": "user", "content": "<|image_1|>\nWhat's in this image?"}
    ]
    assert parsed_conversation == expected_conversation

    # The image entry is present and is an empty dictionary of embeddings.
    assert parsed_mm_data is not None
    assert "image" in parsed_mm_data
    image_data = parsed_mm_data["image"]
    assert isinstance(image_data, dict)
    assert len(image_data) == 0

    # No UUID was supplied, so the single slot is expected to be None.
    _assert_mm_uuids(parsed_mm_uuids, 1, expected_uuids=[None])
def test_parse_chat_messages_multiple_dict_image_embeds(
    phi3v_model_config_image_embeds,
):
    """Test that multiple dictionaries for image_embeds is handled without errors."""
    # Create two sample image embedding tensors; each row along dim 0 is sent
    # as its own ``image_embeds`` content part.
    batch_size = 2
    image_embedding_1 = torch.randn(batch_size, 256, 1024)
    image_embedding_2 = torch.randn(batch_size, 3)

    conversation, mm_data, mm_uuids = parse_chat_messages(
        [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_embeds",
                        "image_embeds": {
                            "image_embedding_1": tensor2base64(p),
                            "image_embedding_2": tensor2base64(i),
                        },
                    }
                    for p, i in zip(image_embedding_1, image_embedding_2)
                ]
                + [
                    {"type": "text", "text": "Describe these two images."},
                ],
            }
        ],
        phi3v_model_config_image_embeds,
        content_format="string",
    )

    # Verify conversation structure
    assert conversation == [
        {
            "role": "user",
            "content": "<|image_1|>\n<|image_2|>\nDescribe these two images.",
        }
    ]

    # Verify mm_data contains a dictionary of multi-embeddings
    assert mm_data is not None
    assert "image" in mm_data
    assert isinstance(mm_data["image"], dict)
    # The dictionary is keyed by embedding name, so check the keys explicitly
    # rather than comparing its length with batch_size (both merely happen
    # to equal 2 in this test).
    assert set(mm_data["image"]) == {"image_embedding_1", "image_embedding_2"}

    # Verify each embedding has the correct shape
    assert isinstance(mm_data["image"]["image_embedding_1"], torch.Tensor)
    assert mm_data["image"]["image_embedding_1"].shape == image_embedding_1.shape
    assert isinstance(mm_data["image"]["image_embedding_2"], torch.Tensor)
    assert mm_data["image"]["image_embedding_2"].shape == image_embedding_2.shape

    # Verify UUIDs (None since we didn't provide any)
    _assert_mm_uuids(mm_uuids, batch_size, expected_uuids=[None, None])
@
pytest
.
mark
.
asyncio
async
def
test_parse_chat_messages_multiple_images_async
(
phi3v_model_config
,
...
...
vllm/entrypoints/chat_utils.py
View file @
b09806e2
...
...
@@ -9,7 +9,7 @@ from collections import Counter, defaultdict, deque
from
collections.abc
import
Awaitable
,
Callable
,
Iterable
from
functools
import
cached_property
,
lru_cache
,
partial
from
pathlib
import
Path
from
typing
import
Any
,
Generic
,
Literal
,
TypeAlias
,
TypeVar
,
cast
from
typing
import
TYPE_CHECKING
,
Any
,
Generic
,
Literal
,
TypeAlias
,
TypeVar
,
cast
import
jinja2
import
jinja2.ext
...
...
@@ -53,7 +53,14 @@ from vllm.tokenizers import MistralTokenizer, TokenizerLike
from
vllm.transformers_utils.chat_templates
import
get_chat_template_fallback_path
from
vllm.transformers_utils.processor
import
cached_get_processor
from
vllm.utils
import
random_uuid
from
vllm.utils.collection_utils
import
is_list_of
from
vllm.utils.func_utils
import
supports_kw
from
vllm.utils.import_utils
import
LazyLoader
if
TYPE_CHECKING
:
import
torch
else
:
torch
=
LazyLoader
(
"torch"
,
globals
(),
"torch"
)
logger
=
init_logger
(
__name__
)
...
...
@@ -620,6 +627,44 @@ ModalityStr = Literal["image", "audio", "video", "image_embeds", "audio_embeds"]
_T
=
TypeVar
(
"_T"
)
def _extract_embeds(tensors: list[torch.Tensor]):
    """Collapse a list of embedding tensors into the most compact form.

    * empty list: returned unchanged
    * one tensor: that tensor, tagged with ``_is_single_item = True``
    * several tensors of identical shape: one tensor from ``torch.stack``
    * several tensors of differing shapes: the list, unchanged
    """
    count = len(tensors)
    if count == 0:
        return tensors

    head = tensors[0]
    if count == 1:
        # To keep backwards compatibility for single item input
        head._is_single_item = True  # type: ignore
        return head

    reference_shape = head.shape
    for tensor in tensors:
        if tensor.shape != reference_shape:
            # Shapes differ, so the items cannot be stacked.
            return tensors

    return torch.stack(tensors)
def _get_embeds_data(items_by_modality: dict[str, list[Any]], modality: str):
    """Return the embedding payload stored under ``f"{modality}_embeds"``.

    Args:
        items_by_modality: Mapping that must contain the key
            ``f"{modality}_embeds"``.
        modality: Modality prefix; the call sites visible here pass
            ``"image"`` or ``"audio"``.

    Returns:
        * the list itself when it is empty;
        * ``_extract_embeds(embeds)`` when ``is_list_of`` reports a list of
          tensors;
        * a dict mapping each key to ``_extract_embeds`` of the per-item
          values when ``is_list_of`` reports a list of dicts;
        * the list unchanged otherwise.

    Raises:
        ValueError: If the dictionaries do not all share the same keys.
    """
    embeds = items_by_modality[f"{modality}_embeds"]

    if not embeds:
        return embeds

    if is_list_of(embeds, torch.Tensor):
        return _extract_embeds(embeds)

    if is_list_of(embeds, dict):
        # `embeds` is known to be non-empty here, so indexing [0] is safe.
        first_item = embeds[0]
        first_keys = set(first_item)
        if any(set(item) != first_keys for item in embeds[1:]):
            raise ValueError(
                "All dictionaries in the list of embeddings must have the same keys."
            )

        # Iterate the first dict rather than the key set so the result keeps
        # the insertion order of the request instead of set order.
        return {
            key: _extract_embeds([item[key] for item in embeds])
            for key in first_item
        }

    return embeds
class
BaseMultiModalItemTracker
(
ABC
,
Generic
[
_T
]):
"""
Tracks multi-modal items in a given request and ensures that the number
...
...
@@ -688,11 +733,14 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
def
all_mm_uuids
(
self
)
->
MultiModalUUIDDict
|
None
:
if
not
self
.
_items_by_modality
:
return
None
mm_uuids
=
{}
uuids_by_modality
=
dict
(
self
.
_uuids_by_modality
)
if
"image"
in
uuids_by_modality
and
"image_embeds"
in
uuids_by_modality
:
raise
ValueError
(
"Mixing raw image and embedding inputs is not allowed"
)
if
"audio"
in
uuids_by_modality
and
"audio_embeds"
in
uuids_by_modality
:
raise
ValueError
(
"Mixing raw audio and embedding inputs is not allowed"
)
mm_uuids
=
{}
if
"image_embeds"
in
uuids_by_modality
:
mm_uuids
[
"image"
]
=
uuids_by_modality
[
"image_embeds"
]
if
"image"
in
uuids_by_modality
:
...
...
@@ -703,6 +751,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
mm_uuids
[
"audio"
]
=
uuids_by_modality
[
"audio"
]
# UUIDs of audios
if
"video"
in
uuids_by_modality
:
mm_uuids
[
"video"
]
=
uuids_by_modality
[
"video"
]
# UUIDs of videos
return
mm_uuids
@
abstractmethod
...
...
@@ -714,29 +763,25 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
def all_mm_data(self) -> MultiModalDataDict | None:
    """Build the multi-modal data mapping from the tracked items.

    Returns:
        ``None`` when nothing has been tracked; otherwise a dict with any of
        the keys ``"image"``, ``"audio"`` and ``"video"``. Embedding inputs
        (``"image_embeds"`` / ``"audio_embeds"``) are converted by
        ``_get_embeds_data`` and stored under ``"image"`` / ``"audio"``.

    Raises:
        ValueError: If raw and embedding inputs are mixed for images or
            for audio.
    """
    if not self._items_by_modality:
        return None

    # Shallow copy into a plain dict before inspecting it.
    items_by_modality = dict(self._items_by_modality)

    if "image" in items_by_modality and "image_embeds" in items_by_modality:
        raise ValueError("Mixing raw image and embedding inputs is not allowed")
    if "audio" in items_by_modality and "audio_embeds" in items_by_modality:
        raise ValueError("Mixing raw audio and embedding inputs is not allowed")

    mm_inputs = {}
    if "image_embeds" in items_by_modality:
        mm_inputs["image"] = _get_embeds_data(items_by_modality, "image")
    if "image" in items_by_modality:
        mm_inputs["image"] = items_by_modality["image"]  # A list of images
    if "audio_embeds" in items_by_modality:
        mm_inputs["audio"] = _get_embeds_data(items_by_modality, "audio")
    if "audio" in items_by_modality:
        mm_inputs["audio"] = items_by_modality["audio"]  # A list of audios
    if "video" in items_by_modality:
        mm_inputs["video"] = items_by_modality["video"]  # A list of videos

    return mm_inputs
def
create_parser
(
self
)
->
"BaseMultiModalContentParser"
:
...
...
@@ -747,38 +792,32 @@ class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
async def all_mm_data(self) -> MultiModalDataDict | None:
    """Await the tracked items and build the multi-modal data mapping.

    Returns:
        ``None`` when nothing has been tracked; otherwise a dict with any of
        the keys ``"image"``, ``"audio"`` and ``"video"``. Embedding inputs
        (``"image_embeds"`` / ``"audio_embeds"``) are converted by
        ``_get_embeds_data`` and stored under ``"image"`` / ``"audio"``.

    Raises:
        ValueError: If raw and embedding inputs are mixed for images or
            for audio.
    """
    if not self._items_by_modality:
        return None

    # A tracked item may be None; substitute `asyncio.sleep(0)`, which
    # resolves to None, so `gather` keeps one result per tracked slot.
    # Each awaitable is gathered exactly once.
    coros_by_modality = {
        modality: [
            item if item is not None else asyncio.sleep(0) for item in items
        ]
        for modality, items in self._items_by_modality.items()
    }
    items_by_modality: dict[str, list[object | None]] = {
        modality: await asyncio.gather(*coros)
        for modality, coros in coros_by_modality.items()
    }

    if "image" in items_by_modality and "image_embeds" in items_by_modality:
        raise ValueError("Mixing raw image and embedding inputs is not allowed")
    if "audio" in items_by_modality and "audio_embeds" in items_by_modality:
        raise ValueError("Mixing raw audio and embedding inputs is not allowed")

    mm_inputs = {}
    if "image_embeds" in items_by_modality:
        mm_inputs["image"] = _get_embeds_data(items_by_modality, "image")
    if "image" in items_by_modality:
        mm_inputs["image"] = items_by_modality["image"]  # A list of images
    if "audio_embeds" in items_by_modality:
        mm_inputs["audio"] = _get_embeds_data(items_by_modality, "audio")
    if "audio" in items_by_modality:
        mm_inputs["audio"] = items_by_modality["audio"]  # A list of audios
    if "video" in items_by_modality:
        mm_inputs["video"] = items_by_modality["video"]  # A list of videos

    return mm_inputs
def
create_parser
(
self
)
->
"BaseMultiModalContentParser"
:
...
...
vllm/v1/engine/input_processor.py
View file @
b09806e2
...
...
@@ -188,29 +188,39 @@ class InputProcessor:
def
_validate_single_prompt
(
single_prompt
:
dict
|
str
)
->
None
:
if
not
isinstance
(
single_prompt
,
dict
):
return
mm_data
=
single_prompt
.
get
(
"multi_modal_data"
)
mm_uuids
=
single_prompt
.
get
(
"multi_modal_uuids"
)
if
not
mm_data
or
not
mm_uuids
:
return
import
torch
def
_get_len
(
items
:
object
):
if
isinstance
(
items
,
dict
):
# Embedding inputs
return
_get_len
(
next
(
iter
(
items
.
values
())))
if
items
else
1
if
isinstance
(
items
,
list
):
return
len
(
items
)
if
isinstance
(
items
,
torch
.
Tensor
):
# To keep backwards compatibility for single item embedding input
return
1
if
getattr
(
items
,
"_is_single_item"
,
False
)
else
len
(
items
)
return
1
for
modality
,
items
in
mm_data
.
items
():
if
modality
in
mm_uuids
:
data_len
=
len
(
items
)
if
isinstance
(
items
,
list
)
else
1
uuid_len
=
(
len
(
mm_uuids
[
modality
])
if
isinstance
(
mm_uuids
[
modality
],
list
)
else
1
)
data_len
=
_get_len
(
items
)
uuid_len
=
_get_len
(
mm_uuids
[
modality
])
if
uuid_len
!=
data_len
:
raise
ValueError
(
f
"multi_modal_uuids for modality
'
{
modality
}
'
"
f
"multi_modal_uuids for modality
{
modality
!
r
}
"
"must have same length as data: got "
f
"
{
uuid_len
}
uuids vs "
f
"
{
data_len
}
items."
f
"
{
uuid_len
}
uuids vs
{
data_len
}
items."
)
else
:
raise
ValueError
(
f
"multi_modal_uuids for modality
'
{
modality
}
'
must "
f
"multi_modal_uuids for modality
{
modality
!
r
}
must "
"be provided if multi_modal_data is provided."
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment