Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
855c262a
Unverified
Commit
855c262a
authored
Sep 04, 2024
by
Cyrus Leung
Committed by
GitHub
Sep 04, 2024
Browse files
[Frontend] Multimodal support in offline chat (#8098)
parent
2be8ec6e
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
356 additions
and
112 deletions
+356
-112
tests/entrypoints/llm/test_generate.py
tests/entrypoints/llm/test_generate.py
+34
-0
tests/entrypoints/test_chat_utils.py
tests/entrypoints/test_chat_utils.py
+124
-40
vllm/entrypoints/chat_utils.py
vllm/entrypoints/chat_utils.py
+159
-49
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+19
-12
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+3
-6
vllm/entrypoints/openai/serving_tokenization.py
vllm/entrypoints/openai/serving_tokenization.py
+4
-3
vllm/multimodal/utils.py
vllm/multimodal/utils.py
+10
-0
vllm/transformers_utils/tokenizers/mistral.py
vllm/transformers_utils/tokenizers/mistral.py
+3
-2
No files found.
tests/entrypoints/llm/test_generate.py
View file @
855c262a
...
@@ -6,6 +6,7 @@ import pytest
...
@@ -6,6 +6,7 @@ import pytest
from
vllm
import
LLM
,
RequestOutput
,
SamplingParams
from
vllm
import
LLM
,
RequestOutput
,
SamplingParams
from
...conftest
import
cleanup
from
...conftest
import
cleanup
from
..openai.test_vision
import
TEST_IMAGE_URLS
MODEL_NAME
=
"facebook/opt-125m"
MODEL_NAME
=
"facebook/opt-125m"
...
@@ -159,3 +160,36 @@ def test_chat():
...
@@ -159,3 +160,36 @@ def test_chat():
]
]
outputs
=
llm
.
chat
(
messages
)
outputs
=
llm
.
chat
(
messages
)
assert
len
(
outputs
)
==
1
assert
len
(
outputs
)
==
1
@
pytest
.
mark
.
parametrize
(
"image_urls"
,
[[
TEST_IMAGE_URLS
[
0
],
TEST_IMAGE_URLS
[
1
]]])
def
test_chat_multi_image
(
image_urls
:
List
[
str
]):
llm
=
LLM
(
model
=
"microsoft/Phi-3.5-vision-instruct"
,
dtype
=
"bfloat16"
,
max_model_len
=
4096
,
max_num_seqs
=
5
,
enforce_eager
=
True
,
trust_remote_code
=
True
,
limit_mm_per_prompt
=
{
"image"
:
2
},
)
messages
=
[{
"role"
:
"user"
,
"content"
:
[
*
({
"type"
:
"image_url"
,
"image_url"
:
{
"url"
:
image_url
}
}
for
image_url
in
image_urls
),
{
"type"
:
"text"
,
"text"
:
"What's in this image?"
},
],
}]
outputs
=
llm
.
chat
(
messages
)
assert
len
(
outputs
)
>=
0
tests/entrypoints/test_chat_utils.py
View file @
855c262a
import
warnings
import
warnings
from
typing
import
Optional
import
pytest
import
pytest
from
PIL
import
Image
from
PIL
import
Image
from
vllm.assets.image
import
ImageAsset
from
vllm.assets.image
import
ImageAsset
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.entrypoints.chat_utils
import
parse_chat_messages
from
vllm.entrypoints.chat_utils
import
(
parse_chat_messages
,
parse_chat_messages_futures
)
from
vllm.multimodal
import
MultiModalDataDict
from
vllm.multimodal.utils
import
encode_image_base64
from
vllm.multimodal.utils
import
encode_image_base64
from
vllm.transformers_utils.tokenizer_group
import
TokenizerGroup
from
vllm.transformers_utils.tokenizer_group
import
TokenizerGroup
...
@@ -42,10 +45,28 @@ def image_url():
...
@@ -42,10 +45,28 @@ def image_url():
return
f
"data:image/jpeg;base64,
{
base64
}
"
return
f
"data:image/jpeg;base64,
{
base64
}
"
@
pytest
.
mark
.
asyncio
def
_assert_mm_data_is_image_input
(
async
def
test_parse_chat_messages_with_image_url
(
phi3v_model_config
,
mm_data
:
Optional
[
MultiModalDataDict
],
phi3v_tokenizer
,
image_url
):
image_count
:
int
,
conversation
,
mm_future
=
parse_chat_messages
([{
)
->
None
:
assert
mm_data
is
not
None
assert
set
(
mm_data
.
keys
())
==
{
"image"
}
image_data
=
mm_data
.
get
(
"image"
)
assert
image_data
is
not
None
if
image_count
==
1
:
assert
isinstance
(
image_data
,
Image
.
Image
)
else
:
assert
isinstance
(
image_data
,
list
)
and
len
(
image_data
)
==
image_count
def
test_parse_chat_messages_single_image
(
phi3v_model_config
,
phi3v_tokenizer
,
image_url
,
):
conversation
,
mm_data
=
parse_chat_messages
([{
"role"
:
"role"
:
"user"
,
"user"
,
"content"
:
[{
"content"
:
[{
...
@@ -63,15 +84,42 @@ async def test_parse_chat_messages_with_image_url(phi3v_model_config,
...
@@ -63,15 +84,42 @@ async def test_parse_chat_messages_with_image_url(phi3v_model_config,
"role"
:
"user"
,
"role"
:
"user"
,
"content"
:
"<|image_1|>
\n
What's in the image?"
"content"
:
"<|image_1|>
\n
What's in the image?"
}]
}]
mm_data
=
await
mm_future
_assert_mm_data_is_image_input
(
mm_data
,
1
)
assert
set
(
mm_data
.
keys
())
==
{
"image"
}
assert
isinstance
(
mm_data
[
"image"
],
Image
.
Image
)
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
asyncio
async
def
test_parse_chat_messages_multiple_images
(
phi3v_model_config
,
async
def
test_parse_chat_messages_single_image_async
(
phi3v_tokenizer
,
image_url
):
phi3v_model_config
,
conversation
,
mm_future
=
parse_chat_messages
([{
phi3v_tokenizer
,
image_url
,
):
conversation
,
mm_future
=
parse_chat_messages_futures
([{
"role"
:
"user"
,
"content"
:
[{
"type"
:
"image_url"
,
"image_url"
:
{
"url"
:
image_url
}
},
{
"type"
:
"text"
,
"text"
:
"What's in the image?"
}]
}],
phi3v_model_config
,
phi3v_tokenizer
)
assert
conversation
==
[{
"role"
:
"user"
,
"content"
:
"<|image_1|>
\n
What's in the image?"
}]
_assert_mm_data_is_image_input
(
await
mm_future
,
1
)
def
test_parse_chat_messages_multiple_images
(
phi3v_model_config
,
phi3v_tokenizer
,
image_url
,
):
conversation
,
mm_data
=
parse_chat_messages
([{
"role"
:
"role"
:
"user"
,
"user"
,
"content"
:
[{
"content"
:
[{
...
@@ -96,15 +144,49 @@ async def test_parse_chat_messages_multiple_images(phi3v_model_config,
...
@@ -96,15 +144,49 @@ async def test_parse_chat_messages_multiple_images(phi3v_model_config,
"content"
:
"content"
:
"<|image_1|>
\n
<|image_2|>
\n
What's in these images?"
"<|image_1|>
\n
<|image_2|>
\n
What's in these images?"
}]
}]
mm_data
=
await
mm_future
_assert_mm_data_is_image_input
(
mm_data
,
2
)
assert
set
(
mm_data
.
keys
())
==
{
"image"
}
assert
len
(
mm_data
[
"image"
])
==
2
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
asyncio
async
def
test_parse_chat_messages_placeholder_already_in_prompt
(
async
def
test_parse_chat_messages_multiple_images_async
(
phi3v_model_config
,
phi3v_tokenizer
,
image_url
):
phi3v_model_config
,
conversation
,
mm_future
=
parse_chat_messages
([{
phi3v_tokenizer
,
image_url
,
):
conversation
,
mm_future
=
parse_chat_messages_futures
([{
"role"
:
"user"
,
"content"
:
[{
"type"
:
"image_url"
,
"image_url"
:
{
"url"
:
image_url
}
},
{
"type"
:
"image_url"
,
"image_url"
:
{
"url"
:
image_url
}
},
{
"type"
:
"text"
,
"text"
:
"What's in these images?"
}]
}],
phi3v_model_config
,
phi3v_tokenizer
)
assert
conversation
==
[{
"role"
:
"user"
,
"content"
:
"<|image_1|>
\n
<|image_2|>
\n
What's in these images?"
}]
_assert_mm_data_is_image_input
(
await
mm_future
,
2
)
def
test_parse_chat_messages_placeholder_already_in_prompt
(
phi3v_model_config
,
phi3v_tokenizer
,
image_url
,
):
conversation
,
mm_data
=
parse_chat_messages
([{
"role"
:
"role"
:
"user"
,
"user"
,
"content"
:
[{
"content"
:
[{
...
@@ -131,15 +213,15 @@ async def test_parse_chat_messages_placeholder_already_in_prompt(
...
@@ -131,15 +213,15 @@ async def test_parse_chat_messages_placeholder_already_in_prompt(
"content"
:
"content"
:
"What's in <|image_1|> and how does it compare to <|image_2|>?"
"What's in <|image_1|> and how does it compare to <|image_2|>?"
}]
}]
mm_data
=
await
mm_future
_assert_mm_data_is_image_input
(
mm_data
,
2
)
assert
set
(
mm_data
.
keys
())
==
{
"image"
}
assert
len
(
mm_data
[
"image"
])
==
2
@
pytest
.
mark
.
asyncio
def
test_parse_chat_messages_placeholder_one_already_in_prompt
(
async
def
test_parse_chat_messages_placeholder_one_already_in_prompt
(
phi3v_model_config
,
phi3v_model_config
,
phi3v_tokenizer
,
image_url
):
phi3v_tokenizer
,
conversation
,
mm_future
=
parse_chat_messages
([{
image_url
,
):
conversation
,
mm_data
=
parse_chat_messages
([{
"role"
:
"role"
:
"user"
,
"user"
,
"content"
:
[{
"content"
:
[{
...
@@ -167,15 +249,15 @@ async def test_parse_chat_messages_placeholder_one_already_in_prompt(
...
@@ -167,15 +249,15 @@ async def test_parse_chat_messages_placeholder_one_already_in_prompt(
"<|image_2|>
\n
What's in <|image_1|> and how does it compare to the "
"<|image_2|>
\n
What's in <|image_1|> and how does it compare to the "
"other one?"
"other one?"
}]
}]
mm_data
=
await
mm_future
_assert_mm_data_is_image_input
(
mm_data
,
2
)
assert
set
(
mm_data
.
keys
())
==
{
"image"
}
assert
len
(
mm_data
[
"image"
])
==
2
@
pytest
.
mark
.
asyncio
def
test_parse_chat_messages_multiple_images_across_messages
(
async
def
test_parse_chat_messages_multiple_images_across_messages
(
phi3v_model_config
,
phi3v_model_config
,
phi3v_tokenizer
,
image_url
):
phi3v_tokenizer
,
conversation
,
mm_future
=
parse_chat_messages
([{
image_url
,
):
conversation
,
mm_data
=
parse_chat_messages
([{
"role"
:
"role"
:
"user"
,
"user"
,
"content"
:
[{
"content"
:
[{
...
@@ -218,14 +300,14 @@ async def test_parse_chat_messages_multiple_images_across_messages(
...
@@ -218,14 +300,14 @@ async def test_parse_chat_messages_multiple_images_across_messages(
"content"
:
"<|image_2|>
\n
What about this one?"
"content"
:
"<|image_2|>
\n
What about this one?"
},
},
]
]
mm_data
=
await
mm_future
_assert_mm_data_is_image_input
(
mm_data
,
2
)
assert
set
(
mm_data
.
keys
())
==
{
"image"
}
assert
len
(
mm_data
[
"image"
])
==
2
@
pytest
.
mark
.
asyncio
def
test_parse_chat_messages_rejects_too_many_images_in_one_message
(
async
def
test_parse_chat_messages_rejects_too_many_images_in_one_message
(
phi3v_model_config
,
phi3v_model_config
,
phi3v_tokenizer
,
image_url
):
phi3v_tokenizer
,
image_url
,
):
with
warnings
.
catch_warnings
():
with
warnings
.
catch_warnings
():
warnings
.
filterwarnings
(
warnings
.
filterwarnings
(
"ignore"
,
"ignore"
,
...
@@ -259,9 +341,11 @@ async def test_parse_chat_messages_rejects_too_many_images_in_one_message(
...
@@ -259,9 +341,11 @@ async def test_parse_chat_messages_rejects_too_many_images_in_one_message(
}],
phi3v_model_config
,
phi3v_tokenizer
)
}],
phi3v_model_config
,
phi3v_tokenizer
)
@
pytest
.
mark
.
asyncio
def
test_parse_chat_messages_rejects_too_many_images_across_messages
(
async
def
test_parse_chat_messages_rejects_too_many_images_across_messages
(
phi3v_model_config
,
phi3v_model_config
,
phi3v_tokenizer
,
image_url
):
phi3v_tokenizer
,
image_url
,
):
with
warnings
.
catch_warnings
():
with
warnings
.
catch_warnings
():
warnings
.
filterwarnings
(
warnings
.
filterwarnings
(
"ignore"
,
"ignore"
,
...
...
vllm/entrypoints/chat_utils.py
View file @
855c262a
import
asyncio
import
asyncio
import
codecs
import
codecs
from
abc
import
ABC
,
abstractmethod
from
collections
import
defaultdict
from
collections
import
defaultdict
from
functools
import
lru_cache
from
functools
import
lru_cache
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
(
Any
,
Awaitable
,
Dict
,
Iterable
,
List
,
Literal
,
Mapping
,
from
typing
import
(
Any
,
Awaitable
,
Dict
,
Generic
,
Iterable
,
List
,
Literal
,
Optional
,
Tuple
,
Union
)
Mapping
,
Optional
,
Tuple
,
TypeVar
,
Union
)
# yapf conflicts with isort for this block
# yapf conflicts with isort for this block
# yapf: disable
# yapf: disable
...
@@ -23,7 +24,8 @@ from vllm.config import ModelConfig
...
@@ -23,7 +24,8 @@ from vllm.config import ModelConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
MultiModalDataDict
from
vllm.multimodal
import
MultiModalDataDict
from
vllm.multimodal.utils
import
(
async_get_and_parse_audio
,
from
vllm.multimodal.utils
import
(
async_get_and_parse_audio
,
async_get_and_parse_image
)
async_get_and_parse_image
,
get_and_parse_audio
,
get_and_parse_image
)
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -81,7 +83,11 @@ class ConversationMessage(TypedDict):
...
@@ -81,7 +83,11 @@ class ConversationMessage(TypedDict):
content
:
str
content
:
str
class
MultiModalItemTracker
:
ModalityStr
=
Literal
[
"image"
,
"audio"
]
_T
=
TypeVar
(
"_T"
)
class
BaseMultiModalItemTracker
(
ABC
,
Generic
[
_T
]):
"""
"""
Tracks multi-modal items in a given request and ensures that the number
Tracks multi-modal items in a given request and ensures that the number
of multi-modal items in a given request does not exceed the configured
of multi-modal items in a given request does not exceed the configured
...
@@ -89,37 +95,28 @@ class MultiModalItemTracker:
...
@@ -89,37 +95,28 @@ class MultiModalItemTracker:
"""
"""
def
__init__
(
self
,
model_config
:
ModelConfig
,
tokenizer
:
AnyTokenizer
):
def
__init__
(
self
,
model_config
:
ModelConfig
,
tokenizer
:
AnyTokenizer
):
super
().
__init__
()
self
.
_model_config
=
model_config
self
.
_model_config
=
model_config
self
.
_tokenizer
=
tokenizer
self
.
_tokenizer
=
tokenizer
self
.
_allowed_items
=
(
model_config
.
multimodal_config
.
limit_per_prompt
self
.
_allowed_items
=
(
model_config
.
multimodal_config
.
limit_per_prompt
if
model_config
.
multimodal_config
else
{})
if
model_config
.
multimodal_config
else
{})
self
.
_consumed_items
=
{
k
:
0
for
k
in
self
.
_allowed_items
}
self
.
_consumed_items
=
{
k
:
0
for
k
in
self
.
_allowed_items
}
self
.
_futures
:
List
[
Awaitable
[
MultiModalDataDict
]]
=
[]
self
.
_items
:
List
[
_T
]
=
[]
@
staticmethod
@
staticmethod
@
lru_cache
(
maxsize
=
None
)
@
lru_cache
(
maxsize
=
None
)
def
_cached_token_str
(
tokenizer
:
AnyTokenizer
,
token_index
:
int
):
def
_cached_token_str
(
tokenizer
:
AnyTokenizer
,
token_index
:
int
)
->
str
:
return
tokenizer
.
decode
(
token_index
)
return
tokenizer
.
decode
(
token_index
)
def
add
(
self
,
modality
:
Literal
[
"image"
,
"audio"
],
def
_placeholder_str
(
self
,
modality
:
ModalityStr
,
mm_future
:
Awaitable
[
MultiModalDataDict
])
->
Optional
[
str
]:
current_count
:
int
)
->
Optional
[
str
]:
"""
Adds the multi-modal item to the current prompt and returns the
placeholder string to use, if any.
"""
allowed_count
=
self
.
_allowed_items
.
get
(
modality
,
1
)
current_count
=
self
.
_consumed_items
.
get
(
modality
,
0
)
+
1
if
current_count
>
allowed_count
:
raise
ValueError
(
f
"At most
{
allowed_count
}
{
modality
}
(s) may be provided in "
"one request."
)
self
.
_consumed_items
[
modality
]
=
current_count
self
.
_futures
.
append
(
mm_future
)
# TODO: Let user specify how to insert image tokens into prompt
# TODO: Let user specify how to insert image tokens into prompt
# (similar to chat template)
# (similar to chat template)
model_type
=
self
.
_model_config
.
hf_config
.
model_type
hf_config
=
self
.
_model_config
.
hf_config
model_type
=
hf_config
.
model_type
if
modality
==
"image"
:
if
modality
==
"image"
:
if
model_type
==
"phi3_v"
:
if
model_type
==
"phi3_v"
:
# Workaround since this token is not defined in the tokenizer
# Workaround since this token is not defined in the tokenizer
...
@@ -130,9 +127,8 @@ class MultiModalItemTracker:
...
@@ -130,9 +127,8 @@ class MultiModalItemTracker:
# These models do not use image tokens in the prompt
# These models do not use image tokens in the prompt
return
None
return
None
if
model_type
.
startswith
(
"llava"
):
if
model_type
.
startswith
(
"llava"
):
return
MultiModalItemTracker
.
_cached_token_str
(
return
self
.
_cached_token_str
(
self
.
_tokenizer
,
self
.
_tokenizer
,
hf_config
.
image_token_index
)
self
.
_model_config
.
hf_config
.
image_token_index
)
if
model_type
in
(
"chameleon"
,
"internvl_chat"
):
if
model_type
in
(
"chameleon"
,
"internvl_chat"
):
return
"<image>"
return
"<image>"
...
@@ -145,11 +141,11 @@ class MultiModalItemTracker:
...
@@ -145,11 +141,11 @@ class MultiModalItemTracker:
raise
TypeError
(
f
"Unknown modality:
{
modality
}
"
)
raise
TypeError
(
f
"Unknown modality:
{
modality
}
"
)
@
staticmethod
@
staticmethod
async
def
_combine
(
future
s
:
List
[
Awaitable
[
MultiModalDataDict
]])
:
def
_combine
(
item
s
:
List
[
MultiModalDataDict
])
->
MultiModalDataDict
:
mm_lists
:
Mapping
[
str
,
List
[
object
]]
=
defaultdict
(
list
)
mm_lists
:
Mapping
[
str
,
List
[
object
]]
=
defaultdict
(
list
)
# Merge all the multi-modal items
# Merge all the multi-modal items
for
single_mm_data
in
(
await
asyncio
.
gather
(
*
futures
))
:
for
single_mm_data
in
items
:
for
mm_key
,
mm_item
in
single_mm_data
.
items
():
for
mm_key
,
mm_item
in
single_mm_data
.
items
():
if
isinstance
(
mm_item
,
list
):
if
isinstance
(
mm_item
,
list
):
mm_lists
[
mm_key
].
extend
(
mm_item
)
mm_lists
[
mm_key
].
extend
(
mm_item
)
...
@@ -162,9 +158,113 @@ class MultiModalItemTracker:
...
@@ -162,9 +158,113 @@ class MultiModalItemTracker:
for
mm_key
,
mm_list
in
mm_lists
.
items
()
for
mm_key
,
mm_list
in
mm_lists
.
items
()
}
}
def
all_mm_data
(
self
)
->
Optional
[
Awaitable
[
MultiModalDataDict
]]:
def
add
(
self
,
modality
:
ModalityStr
,
item
:
_T
)
->
Optional
[
str
]:
return
MultiModalItemTracker
.
_combine
(
"""
self
.
_futures
)
if
self
.
_futures
else
None
Add a multi-modal item to the current prompt and returns the
placeholder string to use, if any.
"""
allowed_count
=
self
.
_allowed_items
.
get
(
modality
,
1
)
current_count
=
self
.
_consumed_items
.
get
(
modality
,
0
)
+
1
if
current_count
>
allowed_count
:
raise
ValueError
(
f
"At most
{
allowed_count
}
{
modality
}
(s) may be provided in "
"one request."
)
self
.
_consumed_items
[
modality
]
=
current_count
self
.
_items
.
append
(
item
)
return
self
.
_placeholder_str
(
modality
,
current_count
)
@
abstractmethod
def
create_parser
(
self
)
->
"BaseMultiModalContentParser"
:
raise
NotImplementedError
class
MultiModalItemTracker
(
BaseMultiModalItemTracker
[
MultiModalDataDict
]):
def
all_mm_data
(
self
)
->
Optional
[
MultiModalDataDict
]:
return
self
.
_combine
(
self
.
_items
)
if
self
.
_items
else
None
def
create_parser
(
self
)
->
"BaseMultiModalContentParser"
:
return
MultiModalContentParser
(
self
)
class
AsyncMultiModalItemTracker
(
BaseMultiModalItemTracker
[
Awaitable
[
MultiModalDataDict
]]):
async
def
all_mm_data
(
self
)
->
Optional
[
MultiModalDataDict
]:
if
self
.
_items
:
items
=
await
asyncio
.
gather
(
*
self
.
_items
)
return
self
.
_combine
(
items
)
return
None
def
create_parser
(
self
)
->
"BaseMultiModalContentParser"
:
return
AsyncMultiModalContentParser
(
self
)
class
BaseMultiModalContentParser
(
ABC
):
def
__init__
(
self
)
->
None
:
super
().
__init__
()
# multimodal placeholder_string : count
self
.
_placeholder_counts
:
Dict
[
str
,
int
]
=
defaultdict
(
lambda
:
0
)
def
_add_placeholder
(
self
,
placeholder
:
Optional
[
str
]):
if
placeholder
:
self
.
_placeholder_counts
[
placeholder
]
+=
1
def
mm_placeholder_counts
(
self
)
->
Dict
[
str
,
int
]:
return
dict
(
self
.
_placeholder_counts
)
@
abstractmethod
def
parse_image
(
self
,
image_url
:
str
)
->
None
:
raise
NotImplementedError
@
abstractmethod
def
parse_audio
(
self
,
audio_url
:
str
)
->
None
:
raise
NotImplementedError
class
MultiModalContentParser
(
BaseMultiModalContentParser
):
def
__init__
(
self
,
tracker
:
MultiModalItemTracker
)
->
None
:
super
().
__init__
()
self
.
_tracker
=
tracker
def
parse_image
(
self
,
image_url
:
str
)
->
None
:
image
=
get_and_parse_image
(
image_url
)
placeholder
=
self
.
_tracker
.
add
(
"image"
,
image
)
self
.
_add_placeholder
(
placeholder
)
def
parse_audio
(
self
,
audio_url
:
str
)
->
None
:
audio
=
get_and_parse_audio
(
audio_url
)
placeholder
=
self
.
_tracker
.
add
(
"audio"
,
audio
)
self
.
_add_placeholder
(
placeholder
)
class
AsyncMultiModalContentParser
(
BaseMultiModalContentParser
):
def
__init__
(
self
,
tracker
:
AsyncMultiModalItemTracker
)
->
None
:
super
().
__init__
()
self
.
_tracker
=
tracker
def
parse_image
(
self
,
image_url
:
str
)
->
None
:
image_coro
=
async_get_and_parse_image
(
image_url
)
placeholder
=
self
.
_tracker
.
add
(
"image"
,
image_coro
)
self
.
_add_placeholder
(
placeholder
)
def
parse_audio
(
self
,
audio_url
:
str
)
->
None
:
audio_coro
=
async_get_and_parse_audio
(
audio_url
)
placeholder
=
self
.
_tracker
.
add
(
"audio"
,
audio_coro
)
self
.
_add_placeholder
(
placeholder
)
def
load_chat_template
(
def
load_chat_template
(
...
@@ -197,10 +297,10 @@ def load_chat_template(
...
@@ -197,10 +297,10 @@ def load_chat_template(
# (similar to chat template)
# (similar to chat template)
def
_get_full_multimodal_text_prompt
(
placeholder_counts
:
Dict
[
str
,
int
],
def
_get_full_multimodal_text_prompt
(
placeholder_counts
:
Dict
[
str
,
int
],
text_prompt
:
str
)
->
str
:
text_prompt
:
str
)
->
str
:
"""Combine multimodal prompts for a multimodal language model"""
"""Combine multimodal prompts for a multimodal language model
.
"""
# Look through the text prompt to check for missing placeholders
# Look through the text prompt to check for missing placeholders
missing_placeholders
=
[]
missing_placeholders
:
List
[
str
]
=
[]
for
placeholder
in
placeholder_counts
:
for
placeholder
in
placeholder_counts
:
# For any existing placeholder in the text prompt, we leave it as is
# For any existing placeholder in the text prompt, we leave it as is
...
@@ -227,12 +327,11 @@ _AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam)
...
@@ -227,12 +327,11 @@ _AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam)
def
_parse_chat_message_content_parts
(
def
_parse_chat_message_content_parts
(
role
:
str
,
role
:
str
,
parts
:
Iterable
[
ChatCompletionContentPartParam
],
parts
:
Iterable
[
ChatCompletionContentPartParam
],
mm_tracker
:
MultiModalItemTracker
,
mm_tracker
:
Base
MultiModalItemTracker
,
)
->
List
[
ConversationMessage
]:
)
->
List
[
ConversationMessage
]:
texts
:
List
[
str
]
=
[]
texts
:
List
[
str
]
=
[]
# multimodal placeholder_string : count
mm_parser
=
mm_tracker
.
create_parser
()
mm_placeholder_counts
:
Dict
[
str
,
int
]
=
{}
for
part
in
parts
:
for
part
in
parts
:
part_type
=
part
[
"type"
]
part_type
=
part
[
"type"
]
...
@@ -247,22 +346,16 @@ def _parse_chat_message_content_parts(
...
@@ -247,22 +346,16 @@ def _parse_chat_message_content_parts(
"'image_url.detail' is currently not supported and "
"'image_url.detail' is currently not supported and "
"will be ignored."
)
"will be ignored."
)
image_coro
=
async_get_and_parse_image
(
image_url
[
"url"
])
mm_parser
.
parse_image
(
image_url
[
"url"
])
placeholder
=
mm_tracker
.
add
(
"image"
,
image_coro
)
if
placeholder
:
mm_placeholder_counts
[
placeholder
]
=
mm_placeholder_counts
.
get
(
placeholder
,
0
)
+
1
elif
part_type
==
"audio_url"
:
elif
part_type
==
"audio_url"
:
audio_url
=
_AudioParser
.
validate_python
(
part
)[
"audio_url"
]
audio_url
=
_AudioParser
.
validate_python
(
part
)[
"audio_url"
]
audio_coro
=
async_get_and_parse_audio
(
audio_url
[
"url"
])
placeholder
=
mm_tracker
.
add
(
"audio"
,
audio_coro
)
mm_parser
.
parse_audio
(
audio_url
[
"url"
])
if
placeholder
:
mm_placeholder_counts
[
placeholder
]
=
mm_placeholder_counts
.
get
(
placeholder
,
0
)
+
1
else
:
else
:
raise
NotImplementedError
(
f
"Unknown part type:
{
part_type
}
"
)
raise
NotImplementedError
(
f
"Unknown part type:
{
part_type
}
"
)
text_prompt
=
"
\n
"
.
join
(
texts
)
text_prompt
=
"
\n
"
.
join
(
texts
)
mm_placeholder_counts
=
mm_parser
.
mm_placeholder_counts
()
if
mm_placeholder_counts
:
if
mm_placeholder_counts
:
text_prompt
=
_get_full_multimodal_text_prompt
(
mm_placeholder_counts
,
text_prompt
=
_get_full_multimodal_text_prompt
(
mm_placeholder_counts
,
text_prompt
)
text_prompt
)
...
@@ -271,8 +364,9 @@ def _parse_chat_message_content_parts(
...
@@ -271,8 +364,9 @@ def _parse_chat_message_content_parts(
def
_parse_chat_message_content
(
def
_parse_chat_message_content
(
message
:
ChatCompletionMessageParam
,
message
:
ChatCompletionMessageParam
,
mm_tracker
:
MultiModalItemTracker
)
->
List
[
ConversationMessage
]:
mm_tracker
:
BaseMultiModalItemTracker
,
)
->
List
[
ConversationMessage
]:
role
=
message
[
"role"
]
role
=
message
[
"role"
]
content
=
message
.
get
(
"content"
)
content
=
message
.
get
(
"content"
)
...
@@ -292,7 +386,7 @@ def parse_chat_messages(
...
@@ -292,7 +386,7 @@ def parse_chat_messages(
messages
:
List
[
ChatCompletionMessageParam
],
messages
:
List
[
ChatCompletionMessageParam
],
model_config
:
ModelConfig
,
model_config
:
ModelConfig
,
tokenizer
:
AnyTokenizer
,
tokenizer
:
AnyTokenizer
,
)
->
Tuple
[
List
[
ConversationMessage
],
Optional
[
Awaitable
[
MultiModalDataDict
]]
]
:
)
->
Tuple
[
List
[
ConversationMessage
],
Optional
[
MultiModalDataDict
]]:
conversation
:
List
[
ConversationMessage
]
=
[]
conversation
:
List
[
ConversationMessage
]
=
[]
mm_tracker
=
MultiModalItemTracker
(
model_config
,
tokenizer
)
mm_tracker
=
MultiModalItemTracker
(
model_config
,
tokenizer
)
...
@@ -304,6 +398,22 @@ def parse_chat_messages(
...
@@ -304,6 +398,22 @@ def parse_chat_messages(
return
conversation
,
mm_tracker
.
all_mm_data
()
return
conversation
,
mm_tracker
.
all_mm_data
()
def
parse_chat_messages_futures
(
messages
:
List
[
ChatCompletionMessageParam
],
model_config
:
ModelConfig
,
tokenizer
:
AnyTokenizer
,
)
->
Tuple
[
List
[
ConversationMessage
],
Awaitable
[
Optional
[
MultiModalDataDict
]]]:
conversation
:
List
[
ConversationMessage
]
=
[]
mm_tracker
=
AsyncMultiModalItemTracker
(
model_config
,
tokenizer
)
for
msg
in
messages
:
sub_messages
=
_parse_chat_message_content
(
msg
,
mm_tracker
)
conversation
.
extend
(
sub_messages
)
return
conversation
,
mm_tracker
.
all_mm_data
()
def
apply_chat_template
(
def
apply_chat_template
(
tokenizer
:
AnyTokenizer
,
tokenizer
:
AnyTokenizer
,
conversation
:
List
[
ConversationMessage
],
conversation
:
List
[
ConversationMessage
],
...
...
vllm/entrypoints/llm.py
View file @
855c262a
...
@@ -23,7 +23,7 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer,
...
@@ -23,7 +23,7 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer,
get_cached_tokenizer
)
get_cached_tokenizer
)
from
vllm.transformers_utils.tokenizer_group
import
TokenizerGroup
from
vllm.transformers_utils.tokenizer_group
import
TokenizerGroup
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
Counter
,
deprecate_kwargs
from
vllm.utils
import
Counter
,
deprecate_kwargs
,
is_list_of
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -358,15 +358,18 @@ class LLM:
...
@@ -358,15 +358,18 @@ class LLM:
add_generation_prompt
:
bool
=
True
,
add_generation_prompt
:
bool
=
True
,
)
->
List
[
RequestOutput
]:
)
->
List
[
RequestOutput
]:
"""
"""
Generate
s
responses for chat
messages
.
Generate responses for
a
chat
conversation
.
Converts the messages to prompts using the tokenizer and calls
The chat conversation is converted into a text prompt using the
the :meth:`generate` method to generate the responses.
tokenizer and calls the :meth:`generate` method to generate the
responses.
Multi-modal inputs can be passed in the same way you would pass them
to the OpenAI API.
Args:
Args:
messages: A list of messages to generate responses for. Each
messages: A single conversation represented as a list of messages.
message is a list of dictionaries with 'role' and 'content'
Each message is a dictionary with 'role' and 'content' keys.
keys.
sampling_params: The sampling parameters for text generation.
sampling_params: The sampling parameters for text generation.
If None, we use the default sampling parameters. When it
If None, we use the default sampling parameters. When it
is a single value, it is applied to every prompt. When it
is a single value, it is applied to every prompt. When it
...
@@ -387,21 +390,25 @@ class LLM:
...
@@ -387,21 +390,25 @@ class LLM:
tokenizer
=
self
.
get_tokenizer
()
tokenizer
=
self
.
get_tokenizer
()
model_config
=
self
.
llm_engine
.
get_model_config
()
model_config
=
self
.
llm_engine
.
get_model_config
()
conversation
s
,
_
=
parse_chat_messages
(
messages
,
model_config
,
conversation
,
mm_data
=
parse_chat_messages
(
messages
,
model_config
,
tokenizer
)
tokenizer
)
prompt
=
apply_chat_template
(
prompt
=
apply_chat_template
(
tokenizer
,
tokenizer
,
conversation
s
,
conversation
,
chat_template
=
chat_template
,
chat_template
=
chat_template
,
add_generation_prompt
=
add_generation_prompt
)
add_generation_prompt
=
add_generation_prompt
,
)
inputs
:
PromptInputs
inputs
:
PromptInputs
if
is
instance
(
prompt
,
list
)
and
isinstance
(
prompt
[
0
]
,
int
):
if
is
_list_of
(
prompt
,
int
):
inputs
=
TokensPrompt
(
prompt_token_ids
=
prompt
)
inputs
=
TokensPrompt
(
prompt_token_ids
=
prompt
)
else
:
else
:
inputs
=
TextPrompt
(
prompt
=
prompt
)
inputs
=
TextPrompt
(
prompt
=
prompt
)
if
mm_data
is
not
None
:
inputs
[
"multi_modal_data"
]
=
mm_data
return
self
.
generate
(
return
self
.
generate
(
inputs
,
inputs
,
sampling_params
=
sampling_params
,
sampling_params
=
sampling_params
,
...
...
vllm/entrypoints/openai/serving_chat.py
View file @
855c262a
...
@@ -11,7 +11,7 @@ from vllm.engine.protocol import AsyncEngineClient
...
@@ -11,7 +11,7 @@ from vllm.engine.protocol import AsyncEngineClient
from
vllm.entrypoints.chat_utils
import
(
ConversationMessage
,
from
vllm.entrypoints.chat_utils
import
(
ConversationMessage
,
apply_chat_template
,
apply_chat_template
,
load_chat_template
,
load_chat_template
,
parse_chat_messages
)
parse_chat_messages
_futures
)
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.openai.protocol
import
(
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionLogProb
,
ChatCompletionLogProbs
,
ChatCompletionLogProb
,
ChatCompletionLogProbs
,
...
@@ -26,7 +26,6 @@ from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
...
@@ -26,7 +26,6 @@ from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
TextTokensPrompt
)
TextTokensPrompt
)
from
vllm.inputs
import
TokensPrompt
from
vllm.inputs
import
TokensPrompt
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
MultiModalDataDict
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
RequestOutput
from
vllm.sequence
import
Logprob
from
vllm.sequence
import
Logprob
from
vllm.tracing
import
(
contains_trace_headers
,
extract_trace_headers
,
from
vllm.tracing
import
(
contains_trace_headers
,
extract_trace_headers
,
...
@@ -94,7 +93,7 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -94,7 +93,7 @@ class OpenAIServingChat(OpenAIServing):
tokenizer
=
await
self
.
async_engine_client
.
get_tokenizer
(
tokenizer
=
await
self
.
async_engine_client
.
get_tokenizer
(
lora_request
)
lora_request
)
conversation
,
mm_data_future
=
parse_chat_messages
(
conversation
,
mm_data_future
=
parse_chat_messages
_futures
(
request
.
messages
,
model_config
,
tokenizer
)
request
.
messages
,
model_config
,
tokenizer
)
tool_dicts
=
None
if
request
.
tools
is
None
else
[
tool_dicts
=
None
if
request
.
tools
is
None
else
[
...
@@ -114,10 +113,8 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -114,10 +113,8 @@ class OpenAIServingChat(OpenAIServing):
logger
.
error
(
"Error in applying chat template from request: %s"
,
e
)
logger
.
error
(
"Error in applying chat template from request: %s"
,
e
)
return
self
.
create_error_response
(
str
(
e
))
return
self
.
create_error_response
(
str
(
e
))
mm_data
:
Optional
[
MultiModalDataDict
]
=
None
try
:
try
:
if
mm_data_future
:
mm_data
=
await
mm_data_future
mm_data
=
await
mm_data_future
except
Exception
as
e
:
except
Exception
as
e
:
logger
.
error
(
"Error in loading multi-modal data: %s"
,
e
)
logger
.
error
(
"Error in loading multi-modal data: %s"
,
e
)
return
self
.
create_error_response
(
str
(
e
))
return
self
.
create_error_response
(
str
(
e
))
...
...
vllm/entrypoints/openai/serving_tokenization.py
View file @
855c262a
...
@@ -4,7 +4,7 @@ from vllm.config import ModelConfig
...
@@ -4,7 +4,7 @@ from vllm.config import ModelConfig
from
vllm.engine.protocol
import
AsyncEngineClient
from
vllm.engine.protocol
import
AsyncEngineClient
from
vllm.entrypoints.chat_utils
import
(
apply_chat_template
,
from
vllm.entrypoints.chat_utils
import
(
apply_chat_template
,
load_chat_template
,
load_chat_template
,
parse_chat_messages
)
parse_chat_messages
_futures
)
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.logger
import
RequestLogger
# yapf conflicts with isort for this block
# yapf conflicts with isort for this block
# yapf: disable
# yapf: disable
...
@@ -65,10 +65,11 @@ class OpenAIServingTokenization(OpenAIServing):
...
@@ -65,10 +65,11 @@ class OpenAIServingTokenization(OpenAIServing):
if
isinstance
(
request
,
TokenizeChatRequest
):
if
isinstance
(
request
,
TokenizeChatRequest
):
model_config
=
self
.
model_config
model_config
=
self
.
model_config
conversation
,
mm_data_future
=
parse_chat_messages
(
conversation
,
mm_data_future
=
parse_chat_messages
_futures
(
request
.
messages
,
model_config
,
tokenizer
)
request
.
messages
,
model_config
,
tokenizer
)
if
mm_data_future
:
mm_data
=
await
mm_data_future
if
mm_data
:
logger
.
warning
(
logger
.
warning
(
"Multi-modal inputs are ignored during tokenization"
)
"Multi-modal inputs are ignored during tokenization"
)
...
...
vllm/multimodal/utils.py
View file @
855c262a
...
@@ -120,6 +120,16 @@ async def async_fetch_audio(
...
@@ -120,6 +120,16 @@ async def async_fetch_audio(
return
librosa
.
load
(
BytesIO
(
audio_bytes
),
sr
=
None
)
return
librosa
.
load
(
BytesIO
(
audio_bytes
),
sr
=
None
)
def
get_and_parse_audio
(
audio_url
:
str
)
->
MultiModalDataDict
:
audio
,
sr
=
fetch_audio
(
audio_url
)
return
{
"audio"
:
(
audio
,
sr
)}
def
get_and_parse_image
(
image_url
:
str
)
->
MultiModalDataDict
:
image
=
fetch_image
(
image_url
)
return
{
"image"
:
image
}
async
def
async_get_and_parse_audio
(
audio_url
:
str
)
->
MultiModalDataDict
:
async
def
async_get_and_parse_audio
(
audio_url
:
str
)
->
MultiModalDataDict
:
audio
,
sr
=
await
async_fetch_audio
(
audio_url
)
audio
,
sr
=
await
async_fetch_audio
(
audio_url
)
return
{
"audio"
:
(
audio
,
sr
)}
return
{
"audio"
:
(
audio
,
sr
)}
...
...
vllm/transformers_utils/tokenizers/mistral.py
View file @
855c262a
...
@@ -52,12 +52,13 @@ class MistralTokenizer:
...
@@ -52,12 +52,13 @@ class MistralTokenizer:
assert
isinstance
(
self
.
tokenizer
,
assert
isinstance
(
self
.
tokenizer
,
(
Tekkenizer
,
SentencePieceTokenizer
)),
type
(
(
Tekkenizer
,
SentencePieceTokenizer
)),
type
(
self
.
tokenizer
)
self
.
tokenizer
)
self
.
_is_tekken
=
isinstance
(
self
.
tokenizer
,
Tekkenizer
)
if
self
.
_is_tekken
:
if
(
is_tekken
:
=
isinstance
(
self
.
tokenizer
,
Tekkenizer
))
:
# Make sure special tokens will not raise
# Make sure special tokens will not raise
self
.
tokenizer
.
special_token_policy
=
SpecialTokenPolicy
.
IGNORE
self
.
tokenizer
.
special_token_policy
=
SpecialTokenPolicy
.
IGNORE
self
.
_is_tekken
=
is_tekken
# the following attributes are set to fit VLLM's design
# the following attributes are set to fit VLLM's design
self
.
is_fast
=
True
self
.
is_fast
=
True
self
.
chat_template
=
True
self
.
chat_template
=
True
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment