Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b0746fae
Unverified
Commit
b0746fae
authored
Mar 10, 2025
by
Chauncey
Committed by
GitHub
Mar 10, 2025
Browse files
[Frontend] support image embeds (#13955)
Signed-off-by:
chaunceyjiang
<
chaunceyjiang@gmail.com
>
parent
60a98b2d
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
201 additions
and
12 deletions
+201
-12
docs/source/serving/multimodal_inputs.md
docs/source/serving/multimodal_inputs.md
+66
-1
vllm/entrypoints/chat_utils.py
vllm/entrypoints/chat_utils.py
+103
-10
vllm/multimodal/image.py
vllm/multimodal/image.py
+19
-0
vllm/multimodal/utils.py
vllm/multimodal/utils.py
+13
-1
No files found.
docs/source/serving/multimodal_inputs.md
View file @
b0746fae
...
@@ -462,4 +462,69 @@ export VLLM_AUDIO_FETCH_TIMEOUT=<timeout>
...
@@ -462,4 +462,69 @@ export VLLM_AUDIO_FETCH_TIMEOUT=<timeout>
### Embedding Inputs
### Embedding Inputs
TBD
To input pre-computed embeddings belonging to a data type (i.e. image, video, or audio) directly to the language model,
pass a tensor of shape `(num_items, feature_size, hidden_size of LM)` to the corresponding field of the multi-modal dictionary.
#### Image Embedding Inputs
For image embeddings, you can pass the base64-encoded tensor to the
`image_embeds`
field.
The following example demonstrates how to pass image embeddings to the OpenAI server:
```
python
import base64
import io

import torch
from openai import OpenAI

# Address and API key of the running vLLM OpenAI-compatible server.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"


def encode_tensor(tensor: torch.Tensor) -> str:
    """Serialize a tensor with ``torch.save`` and return it as base64 text."""
    buffer = io.BytesIO()
    torch.save(tensor, buffer)
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


image_embedding = torch.load(...)
grid_thw = torch.load(...)  # Required by Qwen/Qwen2-VL-2B-Instruct
image_sizes = torch.load(...)  # Required by openbmb/MiniCPM-V-2_6

# Every tensor sent to the server is encoded the same way.
base64_image_embedding = encode_tensor(image_embedding)
base64_image_grid_thw = encode_tensor(grid_thw)
base64_image_sizes = encode_tensor(image_sizes)

client = OpenAI(
    # defaults to os.environ.get("OPENAI_API_KEY")
    api_key=openai_api_key,
    base_url=openai_api_base,
)

# The three `model` / `embeds` pairs below are alternatives: keep the one
# that matches the model you are serving.

# Basic usage - this is equivalent to the LLaVA example for offline inference
model = "llava-hf/llava-1.5-7b-hf"
embeds = {
    "type": "image_embeds",
    "image_embeds": base64_image_embedding,
}

# Pass additional parameters (available to Qwen2-VL and MiniCPM-V)
model = "Qwen/Qwen2-VL-2B-Instruct"
embeds = {
    "type": "image_embeds",
    "image_embeds": {
        "image_embeds": base64_image_embedding,  # Required
        # Required by Qwen/Qwen2-VL-2B-Instruct
        "image_grid_thw": base64_image_grid_thw,
    },
}

model = "openbmb/MiniCPM-V-2_6"
embeds = {
    "type": "image_embeds",
    "image_embeds": {
        "image_embeds": base64_image_embedding,  # Required
        # Required by openbmb/MiniCPM-V-2_6
        "image_sizes": base64_image_sizes,
    },
}

chat_completion = client.chat.completions.create(
    messages=[
        {
            "role": "system",
            "content": "You are a helpful assistant.",
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this image?",
                },
                embeds,
            ],
        },
    ],
    model=model,
)
```
:::{note}
Only one message can contain
`{"type": "image_embeds"}`
.
If used with a model that requires additional parameters, you must also provide a tensor for each of them, e.g.
`image_grid_thw`
,
`image_sizes`
, etc.
:::
vllm/entrypoints/chat_utils.py
View file @
b0746fae
...
@@ -56,6 +56,17 @@ class ChatCompletionContentPartAudioParam(TypedDict, total=False):
...
@@ -56,6 +56,17 @@ class ChatCompletionContentPartAudioParam(TypedDict, total=False):
"""The type of the content part."""
"""The type of the content part."""
class
ChatCompletionContentPartImageEmbedsParam
(
TypedDict
,
total
=
False
):
image_embeds
:
Required
[
Union
[
str
,
dict
[
str
,
str
]]]
"""
The image embeddings. It can be either:
- A single base64 string.
- A dictionary where each value is a base64 string.
"""
type
:
Required
[
Literal
[
"image_embeds"
]]
"""The type of the content part."""
class
VideoURL
(
TypedDict
,
total
=
False
):
class
VideoURL
(
TypedDict
,
total
=
False
):
url
:
Required
[
str
]
url
:
Required
[
str
]
"""
"""
...
@@ -109,6 +120,7 @@ ChatCompletionContentPartParam: TypeAlias = Union[
...
@@ -109,6 +120,7 @@ ChatCompletionContentPartParam: TypeAlias = Union[
ChatCompletionContentPartInputAudioParam
,
ChatCompletionContentPartInputAudioParam
,
ChatCompletionContentPartVideoParam
,
ChatCompletionContentPartRefusalParam
,
ChatCompletionContentPartVideoParam
,
ChatCompletionContentPartRefusalParam
,
CustomChatCompletionContentSimpleImageParam
,
CustomChatCompletionContentSimpleImageParam
,
ChatCompletionContentPartImageEmbedsParam
,
CustomChatCompletionContentSimpleAudioParam
,
CustomChatCompletionContentSimpleAudioParam
,
CustomChatCompletionContentSimpleVideoParam
,
str
]
CustomChatCompletionContentSimpleVideoParam
,
str
]
...
@@ -350,7 +362,7 @@ def resolve_chat_template_content_format(
...
@@ -350,7 +362,7 @@ def resolve_chat_template_content_format(
return
detected_format
return
detected_format
ModalityStr
=
Literal
[
"image"
,
"audio"
,
"video"
]
ModalityStr
=
Literal
[
"image"
,
"audio"
,
"video"
,
"image_embeds"
]
_T
=
TypeVar
(
"_T"
)
_T
=
TypeVar
(
"_T"
)
...
@@ -391,7 +403,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
...
@@ -391,7 +403,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
hf_config
=
self
.
_model_config
.
hf_config
hf_config
=
self
.
_model_config
.
hf_config
model_type
=
hf_config
.
model_type
model_type
=
hf_config
.
model_type
if
modality
==
"image"
:
if
modality
in
[
"image"
,
"image_embeds"
]
:
if
model_type
==
"phi3_v"
:
if
model_type
==
"phi3_v"
:
# Workaround since this token is not defined in the tokenizer
# Workaround since this token is not defined in the tokenizer
return
f
"<|image_
{
current_count
}
|>"
return
f
"<|image_
{
current_count
}
|>"
...
@@ -470,10 +482,27 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
...
@@ -470,10 +482,27 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
class
MultiModalItemTracker
(
BaseMultiModalItemTracker
[
object
]):
class
MultiModalItemTracker
(
BaseMultiModalItemTracker
[
object
]):
def all_mm_data(self) -> Optional[MultiModalDataDict]:
    """Collect every tracked multi-modal item into one dictionary.

    Returns:
        ``None`` when nothing was tracked; otherwise a dictionary with an
        entry for each tracked modality. Image embeddings are stored under
        the ``"image"`` key (the single tracked embedding item, not a list).

    Raises:
        ValueError: If raw images and image embeddings are both present,
            or if more than one ``image_embeds`` item was tracked.
    """
    if not self._items_by_modality:
        return None

    items_by_modality = dict(self._items_by_modality)
    has_image_embeds = "image_embeds" in items_by_modality

    if has_image_embeds and "image" in items_by_modality:
        raise ValueError(
            "Mixing raw image and embedding inputs is not allowed")

    mm_inputs = {}
    if has_image_embeds:
        image_embeds_lst = items_by_modality["image_embeds"]
        if len(image_embeds_lst) > 1:
            raise ValueError(
                "Only one message can have {'type': 'image_embeds'}")
        mm_inputs["image"] = image_embeds_lst[0]

    # Each modality is checked independently: an `elif` chain would return
    # only the first one found and silently drop the rest (e.g. the audio
    # of a request that carries both images and audio).
    for modality in ("image", "audio", "video"):
        if modality in items_by_modality:
            # A list of items for that modality
            mm_inputs[modality] = items_by_modality[modality]

    return mm_inputs
def
create_parser
(
self
)
->
"BaseMultiModalContentParser"
:
def
create_parser
(
self
)
->
"BaseMultiModalContentParser"
:
return
MultiModalContentParser
(
self
)
return
MultiModalContentParser
(
self
)
...
@@ -482,13 +511,31 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
...
@@ -482,13 +511,31 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
class
AsyncMultiModalItemTracker
(
BaseMultiModalItemTracker
[
Awaitable
[
object
]]):
class
AsyncMultiModalItemTracker
(
BaseMultiModalItemTracker
[
Awaitable
[
object
]]):
async def all_mm_data(self) -> Optional[MultiModalDataDict]:
    """Await every tracked multi-modal item and collect the results.

    Returns:
        ``None`` when nothing was tracked; otherwise a dictionary with an
        entry for each tracked modality. Image embeddings are stored under
        the ``"image"`` key (the single tracked embedding item, not a list).

    Raises:
        ValueError: If raw images and image embeddings are both present,
            or if more than one ``image_embeds`` item was tracked.
    """
    if not self._items_by_modality:
        return None

    # Resolve the awaitables of each modality concurrently.
    items_by_modality = {
        modality: await asyncio.gather(*items)
        for modality, items in self._items_by_modality.items()
    }
    has_image_embeds = "image_embeds" in items_by_modality

    if has_image_embeds and "image" in items_by_modality:
        raise ValueError(
            "Mixing raw image and embedding inputs is not allowed")

    mm_inputs = {}
    if has_image_embeds:
        image_embeds_lst = items_by_modality["image_embeds"]
        if len(image_embeds_lst) > 1:
            raise ValueError(
                "Only one message can have {'type': 'image_embeds'}")
        mm_inputs["image"] = image_embeds_lst[0]

    # Each modality is checked independently: an `elif` chain would return
    # only the first one found and silently drop the rest (e.g. the audio
    # of a request that carries both images and audio).
    for modality in ("image", "audio", "video"):
        if modality in items_by_modality:
            # A list of items for that modality
            mm_inputs[modality] = items_by_modality[modality]

    return mm_inputs
def
create_parser
(
self
)
->
"BaseMultiModalContentParser"
:
def
create_parser
(
self
)
->
"BaseMultiModalContentParser"
:
return
AsyncMultiModalContentParser
(
self
)
return
AsyncMultiModalContentParser
(
self
)
...
@@ -513,6 +560,11 @@ class BaseMultiModalContentParser(ABC):
...
@@ -513,6 +560,11 @@ class BaseMultiModalContentParser(ABC):
def
parse_image
(
self
,
image_url
:
str
)
->
None
:
def
parse_image
(
self
,
image_url
:
str
)
->
None
:
raise
NotImplementedError
raise
NotImplementedError
@abstractmethod
def parse_image_embeds(
    self,
    image_embeds: Union[str, dict[str, str]],
) -> None:
    """Handle an ``image_embeds`` content part; subclasses must override."""
    raise NotImplementedError()
@
abstractmethod
@
abstractmethod
def
parse_audio
(
self
,
audio_url
:
str
)
->
None
:
def
parse_audio
(
self
,
audio_url
:
str
)
->
None
:
raise
NotImplementedError
raise
NotImplementedError
...
@@ -543,6 +595,21 @@ class MultiModalContentParser(BaseMultiModalContentParser):
...
@@ -543,6 +595,21 @@ class MultiModalContentParser(BaseMultiModalContentParser):
placeholder
=
self
.
_tracker
.
add
(
"image"
,
image
)
placeholder
=
self
.
_tracker
.
add
(
"image"
,
image
)
self
.
_add_placeholder
(
placeholder
)
self
.
_add_placeholder
(
placeholder
)
def parse_image_embeds(self,
                       image_embeds: Union[str, dict[str, str]]) -> None:
    """Decode image embeddings, track them, and record their placeholder.

    Args:
        image_embeds: Either one base64 string, or a dictionary whose
            values are base64 strings; each string is decoded through the
            connector's ``fetch_image_embedding``.

    Raises:
        TypeError: If ``image_embeds`` is neither a ``str`` nor a ``dict``.
    """
    fetch = self._connector.fetch_image_embedding

    if isinstance(image_embeds, dict):
        embeds = {k: fetch(v) for k, v in image_embeds.items()}
    elif isinstance(image_embeds, str):
        embeds = fetch(image_embeds)
    else:
        # Fail with a clear message instead of an UnboundLocalError on
        # `placeholder` further down.
        raise TypeError(
            "image_embeds must be a base64 string or a dict of base64 "
            f"strings, got {type(image_embeds).__name__}")

    placeholder = self._tracker.add("image_embeds", embeds)
    self._add_placeholder(placeholder)
def
parse_audio
(
self
,
audio_url
:
str
)
->
None
:
def
parse_audio
(
self
,
audio_url
:
str
)
->
None
:
audio
=
self
.
_connector
.
fetch_audio
(
audio_url
)
audio
=
self
.
_connector
.
fetch_audio
(
audio_url
)
...
@@ -579,6 +646,25 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
...
@@ -579,6 +646,25 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
placeholder
=
self
.
_tracker
.
add
(
"image"
,
image_coro
)
placeholder
=
self
.
_tracker
.
add
(
"image"
,
image_coro
)
self
.
_add_placeholder
(
placeholder
)
self
.
_add_placeholder
(
placeholder
)
def parse_image_embeds(self,
                       image_embeds: Union[str, dict[str, str]]) -> None:
    """Decode image embeddings and track them as an already-resolved future.

    The embeddings are decoded synchronously; the result is wrapped in a
    completed ``asyncio.Future`` before being handed to the tracker.

    Args:
        image_embeds: Either one base64 string, or a dictionary whose
            values are base64 strings; each string is decoded through the
            connector's ``fetch_image_embedding``.

    Raises:
        TypeError: If ``image_embeds`` is neither a ``str`` nor a ``dict``.
    """
    fetch = self._connector.fetch_image_embedding

    if isinstance(image_embeds, dict):
        result: object = {k: fetch(v) for k, v in image_embeds.items()}
    elif isinstance(image_embeds, str):
        result = fetch(image_embeds)
    else:
        # Without this guard the future would be tracked but never
        # resolved, so awaiting the tracked items would never finish.
        raise TypeError(
            "image_embeds must be a base64 string or a dict of base64 "
            f"strings, got {type(image_embeds).__name__}")

    # The future holds the decoded embedding(s), not the base64 input.
    future: asyncio.Future[object] = asyncio.Future()
    future.set_result(result)

    placeholder = self._tracker.add("image_embeds", future)
    self._add_placeholder(placeholder)
def
parse_audio
(
self
,
audio_url
:
str
)
->
None
:
def
parse_audio
(
self
,
audio_url
:
str
)
->
None
:
audio_coro
=
self
.
_connector
.
fetch_audio_async
(
audio_url
)
audio_coro
=
self
.
_connector
.
fetch_audio_async
(
audio_url
)
...
@@ -684,6 +770,7 @@ def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int],
...
@@ -684,6 +770,7 @@ def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int],
# No need to validate using Pydantic again
# No need to validate using Pydantic again
_TextParser
=
partial
(
cast
,
ChatCompletionContentPartTextParam
)
_TextParser
=
partial
(
cast
,
ChatCompletionContentPartTextParam
)
_ImageParser
=
partial
(
cast
,
ChatCompletionContentPartImageParam
)
_ImageParser
=
partial
(
cast
,
ChatCompletionContentPartImageParam
)
_ImageEmbedsParser
=
partial
(
cast
,
ChatCompletionContentPartImageEmbedsParam
)
_AudioParser
=
partial
(
cast
,
ChatCompletionContentPartAudioParam
)
_AudioParser
=
partial
(
cast
,
ChatCompletionContentPartAudioParam
)
_InputAudioParser
=
partial
(
cast
,
ChatCompletionContentPartInputAudioParam
)
_InputAudioParser
=
partial
(
cast
,
ChatCompletionContentPartInputAudioParam
)
_RefusalParser
=
partial
(
cast
,
ChatCompletionContentPartRefusalParam
)
_RefusalParser
=
partial
(
cast
,
ChatCompletionContentPartRefusalParam
)
...
@@ -700,6 +787,8 @@ MM_PARSER_MAP: dict[
...
@@ -700,6 +787,8 @@ MM_PARSER_MAP: dict[
lambda
part
:
_TextParser
(
part
).
get
(
"text"
,
""
),
lambda
part
:
_TextParser
(
part
).
get
(
"text"
,
""
),
"image_url"
:
"image_url"
:
lambda
part
:
_ImageParser
(
part
).
get
(
"image_url"
,
{}).
get
(
"url"
,
""
),
lambda
part
:
_ImageParser
(
part
).
get
(
"image_url"
,
{}).
get
(
"url"
,
""
),
"image_embeds"
:
lambda
part
:
_ImageEmbedsParser
(
part
).
get
(
"image_embeds"
,
{}),
"audio_url"
:
"audio_url"
:
lambda
part
:
_AudioParser
(
part
).
get
(
"audio_url"
,
{}).
get
(
"url"
,
""
),
lambda
part
:
_AudioParser
(
part
).
get
(
"audio_url"
,
{}).
get
(
"url"
,
""
),
"input_audio"
:
"input_audio"
:
...
@@ -769,6 +858,7 @@ def _parse_chat_message_content_mm_part(
...
@@ -769,6 +858,7 @@ def _parse_chat_message_content_mm_part(
VALID_MESSAGE_CONTENT_MM_PART_TYPES
=
(
"text"
,
"refusal"
,
"image_url"
,
VALID_MESSAGE_CONTENT_MM_PART_TYPES
=
(
"text"
,
"refusal"
,
"image_url"
,
"image_embeds"
,
"audio_url"
,
"input_audio"
,
"video_url"
)
"audio_url"
,
"input_audio"
,
"video_url"
)
...
@@ -843,7 +933,10 @@ def _parse_chat_message_content_part(
...
@@ -843,7 +933,10 @@ def _parse_chat_message_content_part(
str_content
=
cast
(
str
,
content
)
str_content
=
cast
(
str
,
content
)
mm_parser
.
parse_image
(
str_content
)
mm_parser
.
parse_image
(
str_content
)
return
{
'type'
:
'image'
}
if
wrap_dicts
else
None
return
{
'type'
:
'image'
}
if
wrap_dicts
else
None
if
part_type
==
"image_embeds"
:
content
=
cast
(
Union
[
str
,
dict
[
str
,
str
]],
content
)
mm_parser
.
parse_image_embeds
(
content
)
return
{
'type'
:
'image'
}
if
wrap_dicts
else
None
if
part_type
==
"audio_url"
:
if
part_type
==
"audio_url"
:
str_content
=
cast
(
str
,
content
)
str_content
=
cast
(
str
,
content
)
mm_parser
.
parse_audio
(
str_content
)
mm_parser
.
parse_audio
(
str_content
)
...
...
vllm/multimodal/image.py
View file @
b0746fae
...
@@ -134,3 +134,22 @@ class ImageMediaIO(MediaIO[Image.Image]):
...
@@ -134,3 +134,22 @@ class ImageMediaIO(MediaIO[Image.Image]):
data
=
buffer
.
getvalue
()
data
=
buffer
.
getvalue
()
return
base64
.
b64encode
(
data
).
decode
(
'utf-8'
)
return
base64
.
b64encode
(
data
).
decode
(
'utf-8'
)
class ImageEmbeddingMediaIO(MediaIO[torch.Tensor]):
    """Media I/O for image embeddings stored as serialized torch tensors."""

    def __init__(self) -> None:
        super().__init__()

    def load_bytes(self, data: bytes) -> torch.Tensor:
        """Deserialize a tensor from raw bytes."""
        buffer = BytesIO(data)
        # weights_only=True restricts what the unpickler may construct;
        # the bytes can originate from a client request.
        return torch.load(buffer, weights_only=True)

    def load_base64(self, media_type: str, data: str) -> torch.Tensor:
        """Decode base64 text and deserialize it; ``media_type`` is unused."""
        return self.load_bytes(base64.b64decode(data))

    def load_file(self, filepath: Path) -> torch.Tensor:
        """Deserialize a tensor from a file."""
        # Same restriction as load_bytes, so every entry point of this
        # class applies the same safe-loading policy.
        return torch.load(filepath, weights_only=True)

    def encode_base64(self, media: torch.Tensor) -> str:
        """Return the base64 text of the tensor's underlying NumPy buffer."""
        # NOTE(review): this encodes the raw array buffer, not the
        # `torch.save` format that `load_base64` reads, so the output
        # presumably cannot be round-tripped - confirm intended usage.
        return base64.b64encode(media.numpy()).decode('utf-8')
vllm/multimodal/utils.py
View file @
b0746fae
...
@@ -7,6 +7,7 @@ from urllib.parse import ParseResult, urlparse
...
@@ -7,6 +7,7 @@ from urllib.parse import ParseResult, urlparse
import
numpy
as
np
import
numpy
as
np
import
numpy.typing
as
npt
import
numpy.typing
as
npt
import
torch
from
PIL
import
Image
from
PIL
import
Image
import
vllm.envs
as
envs
import
vllm.envs
as
envs
...
@@ -16,7 +17,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
...
@@ -16,7 +17,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
from
.audio
import
AudioMediaIO
from
.audio
import
AudioMediaIO
from
.base
import
MediaIO
from
.base
import
MediaIO
from
.image
import
ImageMediaIO
from
.image
import
ImageEmbeddingMediaIO
,
ImageMediaIO
from
.inputs
import
PlaceholderRange
from
.inputs
import
PlaceholderRange
from
.video
import
VideoMediaIO
from
.video
import
VideoMediaIO
...
@@ -245,6 +246,17 @@ class MediaConnector:
...
@@ -245,6 +246,17 @@ class MediaConnector:
fetch_timeout
=
envs
.
VLLM_VIDEO_FETCH_TIMEOUT
,
fetch_timeout
=
envs
.
VLLM_VIDEO_FETCH_TIMEOUT
,
)
)
def fetch_image_embedding(
    self,
    data: str,
) -> torch.Tensor:
    """
    Decode a base64-encoded image embedding into a tensor.
    """
    # The first argument is the media type, passed as an empty string.
    return ImageEmbeddingMediaIO().load_base64("", data)
global_media_connector
=
MediaConnector
()
global_media_connector
=
MediaConnector
()
"""The global :class:`MediaConnector` instance used by vLLM."""
"""The global :class:`MediaConnector` instance used by vLLM."""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment