Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b0746fae
Unverified
Commit
b0746fae
authored
Mar 10, 2025
by
Chauncey
Committed by
GitHub
Mar 10, 2025
Browse files
[Frontend] support image embeds (#13955)
Signed-off-by:
chaunceyjiang
<
chaunceyjiang@gmail.com
>
parent
60a98b2d
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
201 additions
and
12 deletions
+201
-12
docs/source/serving/multimodal_inputs.md
docs/source/serving/multimodal_inputs.md
+66
-1
vllm/entrypoints/chat_utils.py
vllm/entrypoints/chat_utils.py
+103
-10
vllm/multimodal/image.py
vllm/multimodal/image.py
+19
-0
vllm/multimodal/utils.py
vllm/multimodal/utils.py
+13
-1
No files found.
docs/source/serving/multimodal_inputs.md
View file @
b0746fae
...
...
@@ -462,4 +462,69 @@ export VLLM_AUDIO_FETCH_TIMEOUT=<timeout>
### Embedding Inputs
TBD
To input pre-computed embeddings belonging to a data type (i.e. image, video, or audio) directly to the language model,
pass a tensor of shape `(num_items, feature_size, hidden_size of LM)` to the corresponding field of the multi-modal dictionary.
#### Image Embedding Inputs
For image embeddings, you can pass the base64-encoded tensor to the
`image_embeds`
field.
The following example demonstrates how to pass image embeddings to the OpenAI server:
```
python
def encode_tensor(tensor):
    """Serialize ``tensor`` with ``torch.save`` and return it base64-encoded."""
    buffer = io.BytesIO()
    torch.save(tensor, buffer)
    return base64.b64encode(buffer.getvalue()).decode('utf-8')


image_embedding = torch.load(...)
grid_thw = torch.load(...)  # Required by Qwen/Qwen2-VL-2B-Instruct
image_sizes = torch.load(...)  # Required by openbmb/MiniCPM-V-2_6

# Every tensor sent to the server is a base64-encoded ``torch.save`` payload,
# including the additional per-model tensors.
base64_image_embedding = encode_tensor(image_embedding)
base64_image_grid_thw = encode_tensor(grid_thw)
base64_image_sizes = encode_tensor(image_sizes)

client = OpenAI(
    # defaults to os.environ.get("OPENAI_API_KEY")
    api_key=openai_api_key,
    base_url=openai_api_base,
)

# The three (model, embeds) pairs below are alternatives; keep only one.

# Basic usage - this is equivalent to the LLaVA example for offline inference
model = "llava-hf/llava-1.5-7b-hf"
embeds = {
    "type": "image_embeds",
    "image_embeds": f"{base64_image_embedding}"
}

# Pass additional parameters (available to Qwen2-VL and MiniCPM-V)
model = "Qwen/Qwen2-VL-2B-Instruct"
embeds = {
    "type": "image_embeds",
    "image_embeds": {
        "image_embeds": f"{base64_image_embedding}",  # Required
        "image_grid_thw": f"{base64_image_grid_thw}"  # Required by Qwen/Qwen2-VL-2B-Instruct
    },
}

model = "openbmb/MiniCPM-V-2_6"
embeds = {
    "type": "image_embeds",
    "image_embeds": {
        "image_embeds": f"{base64_image_embedding}",  # Required
        "image_sizes": f"{base64_image_sizes}"  # Required by openbmb/MiniCPM-V-2_6
    },
}

chat_completion = client.chat.completions.create(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's in this image?",
                },
                embeds,
            ],
        },
    ],
    model=model,
)
```
:::{note}
Only one message can contain
`{"type": "image_embeds"}`
.
If used with a model that requires additional parameters, you must also provide a tensor for each of them, e.g.
`image_grid_thw`
,
`image_sizes`
, etc.
:::
vllm/entrypoints/chat_utils.py
View file @
b0746fae
...
...
@@ -56,6 +56,17 @@ class ChatCompletionContentPartAudioParam(TypedDict, total=False):
"""The type of the content part."""
class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
    """Chat content part that carries image embeddings.

    Both keys are marked ``Required`` even though the TypedDict itself is
    declared with ``total=False``.
    """
    image_embeds: Required[Union[str, dict[str, str]]]
    """
    The image embeddings. It can be either:
    - A single base64 string.
    - A dictionary where each value is a base64 string.
    """
    type: Required[Literal["image_embeds"]]
    """The type of the content part."""
class
VideoURL
(
TypedDict
,
total
=
False
):
url
:
Required
[
str
]
"""
...
...
@@ -109,6 +120,7 @@ ChatCompletionContentPartParam: TypeAlias = Union[
ChatCompletionContentPartInputAudioParam
,
ChatCompletionContentPartVideoParam
,
ChatCompletionContentPartRefusalParam
,
CustomChatCompletionContentSimpleImageParam
,
ChatCompletionContentPartImageEmbedsParam
,
CustomChatCompletionContentSimpleAudioParam
,
CustomChatCompletionContentSimpleVideoParam
,
str
]
...
...
@@ -350,7 +362,7 @@ def resolve_chat_template_content_format(
return
detected_format
# Modality keys under which multi-modal items are tracked. "image_embeds"
# identifies pre-computed image embeddings as opposed to raw images.
ModalityStr = Literal["image", "audio", "video", "image_embeds"]
_T = TypeVar("_T")
...
...
@@ -391,7 +403,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
hf_config
=
self
.
_model_config
.
hf_config
model_type
=
hf_config
.
model_type
if
modality
==
"image"
:
if
modality
in
[
"image"
,
"image_embeds"
]
:
if
model_type
==
"phi3_v"
:
# Workaround since this token is not defined in the tokenizer
return
f
"<|image_
{
current_count
}
|>"
...
...
@@ -470,10 +482,27 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
    """Item tracker whose tracked items are used directly (not awaited)."""

    def all_mm_data(self) -> Optional[MultiModalDataDict]:
        """Collect the tracked items into a multi-modal input dictionary.

        Returns:
            ``None`` when nothing has been tracked. Otherwise a dictionary
            with an entry for every tracked modality; image embeddings are
            stored under the ``"image"`` key.

        Raises:
            ValueError: If raw images and image embeddings are both present,
                or if more than one ``image_embeds`` item was tracked.
        """
        if not self._items_by_modality:
            return None

        items_by_modality = dict(self._items_by_modality)
        if ("image" in items_by_modality
                and "image_embeds" in items_by_modality):
            raise ValueError(
                "Mixing raw image and embedding inputs is not allowed")

        mm_inputs = {}
        if "image_embeds" in items_by_modality:
            image_embeds_lst = items_by_modality["image_embeds"]
            if len(image_embeds_lst) > 1:
                raise ValueError(
                    "Only one message can have {'type': 'image_embeds'}")
            mm_inputs["image"] = image_embeds_lst[0]
        # Independent checks rather than an ``elif`` chain: a request that
        # carries several modalities must keep all of them, not just the
        # first one found.
        if "image" in items_by_modality:
            mm_inputs["image"] = items_by_modality["image"]  # A list of images
        if "audio" in items_by_modality:
            mm_inputs["audio"] = items_by_modality["audio"]  # A list of audios
        if "video" in items_by_modality:
            mm_inputs["video"] = items_by_modality["video"]  # A list of videos
        return mm_inputs

    def create_parser(self) -> "BaseMultiModalContentParser":
        """Return a ``MultiModalContentParser`` bound to this tracker."""
        return MultiModalContentParser(self)
...
...
@@ -482,13 +511,31 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
    """Item tracker whose tracked items are awaitables."""

    async def all_mm_data(self) -> Optional[MultiModalDataDict]:
        """Await the tracked items and collect them into a dictionary.

        Returns:
            ``None`` when nothing has been tracked. Otherwise a dictionary
            with an entry for every tracked modality; image embeddings are
            stored under the ``"image"`` key.

        Raises:
            ValueError: If raw images and image embeddings are both present,
                or if more than one ``image_embeds`` item was tracked.
        """
        if not self._items_by_modality:
            return None

        # Resolve the awaitables of each modality before combining them.
        items_by_modality = {
            modality: await asyncio.gather(*items)
            for modality, items in self._items_by_modality.items()
        }
        if ("image" in items_by_modality
                and "image_embeds" in items_by_modality):
            raise ValueError(
                "Mixing raw image and embedding inputs is not allowed")

        mm_inputs = {}
        if "image_embeds" in items_by_modality:
            image_embeds_lst = items_by_modality["image_embeds"]
            if len(image_embeds_lst) > 1:
                raise ValueError(
                    "Only one message can have {'type': 'image_embeds'}")
            mm_inputs["image"] = image_embeds_lst[0]
        # Independent checks rather than an ``elif`` chain: a request that
        # carries several modalities must keep all of them, not just the
        # first one found.
        if "image" in items_by_modality:
            mm_inputs["image"] = items_by_modality["image"]  # A list of images
        if "audio" in items_by_modality:
            mm_inputs["audio"] = items_by_modality["audio"]  # A list of audios
        if "video" in items_by_modality:
            mm_inputs["video"] = items_by_modality["video"]  # A list of videos
        return mm_inputs

    def create_parser(self) -> "BaseMultiModalContentParser":
        """Return an ``AsyncMultiModalContentParser`` bound to this tracker."""
        return AsyncMultiModalContentParser(self)
...
...
@@ -513,6 +560,11 @@ class BaseMultiModalContentParser(ABC):
def parse_image(self, image_url: str) -> None:
    """Handle an image content part given as ``image_url``.

    This base implementation only raises ``NotImplementedError``.
    """
    raise NotImplementedError
@abstractmethod
def parse_image_embeds(self,
                       image_embeds: Union[str, dict[str, str]]) -> None:
    """Handle an image-embeddings content part.

    ``image_embeds`` is either a single string or a dictionary of strings.
    This base implementation only raises ``NotImplementedError``.
    """
    raise NotImplementedError
@abstractmethod
def parse_audio(self, audio_url: str) -> None:
    """Handle an audio content part given as ``audio_url``.

    This base implementation only raises ``NotImplementedError``.
    """
    raise NotImplementedError
...
...
@@ -543,6 +595,21 @@ class MultiModalContentParser(BaseMultiModalContentParser):
placeholder
=
self
.
_tracker
.
add
(
"image"
,
image
)
self
.
_add_placeholder
(
placeholder
)
def parse_image_embeds(self,
                       image_embeds: Union[str, dict[str, str]]) -> None:
    """Decode image embeddings and register them with the tracker.

    Args:
        image_embeds: Either one encoded embedding string, or a dictionary
            whose values are encoded strings. Every string is decoded with
            ``self._connector.fetch_image_embedding``.

    Raises:
        TypeError: If ``image_embeds`` is neither a ``dict`` nor a ``str``.
    """
    if isinstance(image_embeds, dict):
        item = {
            key: self._connector.fetch_image_embedding(value)
            for key, value in image_embeds.items()
        }
    elif isinstance(image_embeds, str):
        item = self._connector.fetch_image_embedding(image_embeds)
    else:
        # Fail with a clear message instead of an unbound-variable error
        # further down.
        raise TypeError("image_embeds must be a str or a dict of str, "
                        f"got {type(image_embeds).__name__}")

    placeholder = self._tracker.add("image_embeds", item)
    self._add_placeholder(placeholder)
def
parse_audio
(
self
,
audio_url
:
str
)
->
None
:
audio
=
self
.
_connector
.
fetch_audio
(
audio_url
)
...
...
@@ -579,6 +646,25 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
placeholder
=
self
.
_tracker
.
add
(
"image"
,
image_coro
)
self
.
_add_placeholder
(
placeholder
)
def parse_image_embeds(self,
                       image_embeds: Union[str, dict[str, str]]) -> None:
    """Decode image embeddings and register them as an awaitable item.

    The decoding itself is synchronous; the result is wrapped in an
    already-resolved ``asyncio.Future`` before being handed to the tracker.

    Args:
        image_embeds: Either one encoded embedding string, or a dictionary
            whose values are encoded strings. Every string is decoded with
            ``self._connector.fetch_image_embedding``.

    Raises:
        TypeError: If ``image_embeds`` is neither a ``dict`` nor a ``str``.
    """
    if isinstance(image_embeds, dict):
        item = {
            key: self._connector.fetch_image_embedding(value)
            for key, value in image_embeds.items()
        }
    elif isinstance(image_embeds, str):
        item = self._connector.fetch_image_embedding(image_embeds)
    else:
        # Without this branch the future would never be resolved and any
        # later await on it would never complete.
        raise TypeError("image_embeds must be a str or a dict of str, "
                        f"got {type(image_embeds).__name__}")

    future: asyncio.Future[object] = asyncio.Future()
    future.set_result(item)

    placeholder = self._tracker.add("image_embeds", future)
    self._add_placeholder(placeholder)
def
parse_audio
(
self
,
audio_url
:
str
)
->
None
:
audio_coro
=
self
.
_connector
.
fetch_audio_async
(
audio_url
)
...
...
@@ -684,6 +770,7 @@ def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int],
# No need to validate using Pydantic again
# Each parser below is `cast` pre-bound (via `partial`) to one content-part
# type, so calling it on a part only re-labels the part for type checking
# (assuming `cast` is `typing.cast` -- TODO confirm against the imports).
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageParser = partial(cast, ChatCompletionContentPartImageParam)
_ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam)
_AudioParser = partial(cast, ChatCompletionContentPartAudioParam)
_InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
...
...
@@ -700,6 +787,8 @@ MM_PARSER_MAP: dict[
lambda
part
:
_TextParser
(
part
).
get
(
"text"
,
""
),
"image_url"
:
lambda
part
:
_ImageParser
(
part
).
get
(
"image_url"
,
{}).
get
(
"url"
,
""
),
"image_embeds"
:
lambda
part
:
_ImageEmbedsParser
(
part
).
get
(
"image_embeds"
,
{}),
"audio_url"
:
lambda
part
:
_AudioParser
(
part
).
get
(
"audio_url"
,
{}).
get
(
"url"
,
""
),
"input_audio"
:
...
...
@@ -769,6 +858,7 @@ def _parse_chat_message_content_mm_part(
# Names of the content-part types treated as valid, including "image_embeds".
# NOTE(review): presumably compared against a part's "type" field by the
# message-parsing code -- confirm at the use site.
VALID_MESSAGE_CONTENT_MM_PART_TYPES = ("text", "refusal", "image_url",
                                       "image_embeds", "audio_url",
                                       "input_audio", "video_url")
...
...
@@ -843,7 +933,10 @@ def _parse_chat_message_content_part(
str_content
=
cast
(
str
,
content
)
mm_parser
.
parse_image
(
str_content
)
return
{
'type'
:
'image'
}
if
wrap_dicts
else
None
if
part_type
==
"image_embeds"
:
content
=
cast
(
Union
[
str
,
dict
[
str
,
str
]],
content
)
mm_parser
.
parse_image_embeds
(
content
)
return
{
'type'
:
'image'
}
if
wrap_dicts
else
None
if
part_type
==
"audio_url"
:
str_content
=
cast
(
str
,
content
)
mm_parser
.
parse_audio
(
str_content
)
...
...
vllm/multimodal/image.py
View file @
b0746fae
...
...
@@ -134,3 +134,22 @@ class ImageMediaIO(MediaIO[Image.Image]):
data
=
buffer
.
getvalue
()
return
base64
.
b64encode
(
data
).
decode
(
'utf-8'
)
class ImageEmbeddingMediaIO(MediaIO[torch.Tensor]):
    """Media I/O for image embeddings stored as ``torch.save`` payloads."""

    def __init__(self) -> None:
        super().__init__()

    def load_bytes(self, data: bytes) -> torch.Tensor:
        """Deserialize a tensor from the raw bytes of a ``torch.save`` payload."""
        buffer = BytesIO(data)
        # weights_only=True restricts what the unpickler may construct,
        # which matters because the payload can come from a request.
        return torch.load(buffer, weights_only=True)

    def load_base64(self, media_type: str, data: str) -> torch.Tensor:
        """Decode base64 ``data`` and load it; ``media_type`` is not used."""
        return self.load_bytes(base64.b64decode(data))

    def load_file(self, filepath: Path) -> torch.Tensor:
        """Load a tensor from ``filepath`` under the same restriction as
        ``load_bytes``."""
        return torch.load(filepath, weights_only=True)

    def encode_base64(self, media: torch.Tensor) -> str:
        """Serialize ``media`` with ``torch.save`` and base64-encode it.

        The output can be read back by ``load_base64``.
        """
        buffer = BytesIO()
        torch.save(media, buffer)
        return base64.b64encode(buffer.getvalue()).decode('utf-8')
vllm/multimodal/utils.py
View file @
b0746fae
...
...
@@ -7,6 +7,7 @@ from urllib.parse import ParseResult, urlparse
import
numpy
as
np
import
numpy.typing
as
npt
import
torch
from
PIL
import
Image
import
vllm.envs
as
envs
...
...
@@ -16,7 +17,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
from
.audio
import
AudioMediaIO
from
.base
import
MediaIO
from
.image
import
ImageMediaIO
from
.image
import
ImageEmbeddingMediaIO
,
ImageMediaIO
from
.inputs
import
PlaceholderRange
from
.video
import
VideoMediaIO
...
...
@@ -245,6 +246,17 @@ class MediaConnector:
fetch_timeout
=
envs
.
VLLM_VIDEO_FETCH_TIMEOUT
,
)
def fetch_image_embedding(
    self,
    data: str,
) -> torch.Tensor:
    """
    Decode a base64-encoded image embedding into a tensor.

    ``data`` is passed to ``ImageEmbeddingMediaIO.load_base64`` together
    with an empty media type.
    """
    return ImageEmbeddingMediaIO().load_base64("", data)
# Module-level connector instance, created once when this module is imported.
global_media_connector = MediaConnector()
"""The global :class:`MediaConnector` instance used by vLLM."""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment