Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f399182e
Unverified
Commit
f399182e
authored
Sep 02, 2025
by
Chenheli Hua
Committed by
GitHub
Sep 02, 2025
Browse files
Run ruff format on a few files. (#24075)
Signed-off-by:
Chenheli Hua
<
huachenheli@outlook.com
>
parent
1c413105
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
947 additions
and
690 deletions
+947
-690
tests/entrypoints/test_chat_utils.py
tests/entrypoints/test_chat_utils.py
+578
-452
vllm/entrypoints/chat_utils.py
vllm/entrypoints/chat_utils.py
+228
-150
vllm/entrypoints/openai/serving_engine.py
vllm/entrypoints/openai/serving_engine.py
+141
-88
No files found.
tests/entrypoints/test_chat_utils.py
View file @
f399182e
...
@@ -46,23 +46,27 @@ MISTRAL_MODEL_ID = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
...
@@ -46,23 +46,27 @@ MISTRAL_MODEL_ID = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
@
pytest
.
fixture
(
scope
=
"function"
)
@
pytest
.
fixture
(
scope
=
"function"
)
def
phi3v_model_config
():
def
phi3v_model_config
():
return
ModelConfig
(
PHI3V_MODEL_ID
,
return
ModelConfig
(
PHI3V_MODEL_ID
,
runner
=
"generate"
,
runner
=
"generate"
,
trust_remote_code
=
True
,
trust_remote_code
=
True
,
limit_mm_per_prompt
=
{
limit_mm_per_prompt
=
{
"image"
:
2
,
"image"
:
2
,
})
},
)
@
pytest
.
fixture
(
scope
=
"function"
)
@
pytest
.
fixture
(
scope
=
"function"
)
def
phi3v_model_config_mm_interleaved
():
def
phi3v_model_config_mm_interleaved
():
return
ModelConfig
(
PHI3V_MODEL_ID
,
return
ModelConfig
(
PHI3V_MODEL_ID
,
runner
=
"generate"
,
runner
=
"generate"
,
trust_remote_code
=
True
,
trust_remote_code
=
True
,
interleave_mm_strings
=
True
,
interleave_mm_strings
=
True
,
limit_mm_per_prompt
=
{
limit_mm_per_prompt
=
{
"image"
:
2
,
"image"
:
2
,
})
},
)
@
pytest
.
fixture
(
scope
=
"module"
)
@
pytest
.
fixture
(
scope
=
"module"
)
...
@@ -77,14 +81,16 @@ def phi3v_tokenizer():
...
@@ -77,14 +81,16 @@ def phi3v_tokenizer():
@
pytest
.
fixture
(
scope
=
"function"
)
@
pytest
.
fixture
(
scope
=
"function"
)
def
qwen25omni_model_config_mm_interleaved
():
def
qwen25omni_model_config_mm_interleaved
():
return
ModelConfig
(
QWEN25OMNI_MODEL_ID
,
return
ModelConfig
(
QWEN25OMNI_MODEL_ID
,
runner
=
"generate"
,
runner
=
"generate"
,
interleave_mm_strings
=
True
,
interleave_mm_strings
=
True
,
limit_mm_per_prompt
=
{
limit_mm_per_prompt
=
{
"image"
:
2
,
"image"
:
2
,
"audio"
:
1
,
"audio"
:
1
,
"video"
:
1
,
"video"
:
1
,
})
},
)
@
pytest
.
fixture
(
scope
=
"module"
)
@
pytest
.
fixture
(
scope
=
"module"
)
...
@@ -99,11 +105,13 @@ def qwen25omni_tokenizer():
...
@@ -99,11 +105,13 @@ def qwen25omni_tokenizer():
@
pytest
.
fixture
(
scope
=
"module"
)
@
pytest
.
fixture
(
scope
=
"module"
)
def
mllama_model_config
():
def
mllama_model_config
():
return
ModelConfig
(
MLLAMA_MODEL_ID
,
return
ModelConfig
(
MLLAMA_MODEL_ID
,
runner
=
"generate"
,
runner
=
"generate"
,
limit_mm_per_prompt
=
{
limit_mm_per_prompt
=
{
"image"
:
2
,
"image"
:
2
,
})
},
)
@
pytest
.
fixture
(
scope
=
"module"
)
@
pytest
.
fixture
(
scope
=
"module"
)
...
@@ -118,11 +126,13 @@ def mllama_tokenizer():
...
@@ -118,11 +126,13 @@ def mllama_tokenizer():
@
pytest
.
fixture
(
scope
=
"function"
)
@
pytest
.
fixture
(
scope
=
"function"
)
def
mistral_model_config
():
def
mistral_model_config
():
return
ModelConfig
(
MISTRAL_MODEL_ID
,
return
ModelConfig
(
MISTRAL_MODEL_ID
,
runner
=
"generate"
,
runner
=
"generate"
,
limit_mm_per_prompt
=
{
limit_mm_per_prompt
=
{
"image"
:
2
,
"image"
:
2
,
})
},
)
@
pytest
.
fixture
(
scope
=
"module"
)
@
pytest
.
fixture
(
scope
=
"module"
)
...
@@ -137,21 +147,21 @@ def mistral_tokenizer():
...
@@ -137,21 +147,21 @@ def mistral_tokenizer():
@
pytest
.
fixture
(
scope
=
"module"
)
@
pytest
.
fixture
(
scope
=
"module"
)
def
image_url
():
def
image_url
():
image
=
ImageAsset
(
'
cherry_blossom
'
)
image
=
ImageAsset
(
"
cherry_blossom
"
)
base64
=
encode_image_base64
(
image
.
pil_image
)
base64
=
encode_image_base64
(
image
.
pil_image
)
return
f
"data:image/jpeg;base64,
{
base64
}
"
return
f
"data:image/jpeg;base64,
{
base64
}
"
@
pytest
.
fixture
(
scope
=
"module"
)
@
pytest
.
fixture
(
scope
=
"module"
)
def
video_url
():
def
video_url
():
video
=
VideoAsset
(
'
baby_reading
'
,
1
)
video
=
VideoAsset
(
"
baby_reading
"
,
1
)
base64
=
encode_video_base64
(
video
.
np_ndarrays
)
base64
=
encode_video_base64
(
video
.
np_ndarrays
)
return
f
"data:video/jpeg;base64,
{
base64
}
"
return
f
"data:video/jpeg;base64,
{
base64
}
"
@
pytest
.
fixture
(
scope
=
"module"
)
@
pytest
.
fixture
(
scope
=
"module"
)
def
audio_url
():
def
audio_url
():
audio
=
AudioAsset
(
'
mary_had_lamb
'
)
audio
=
AudioAsset
(
"
mary_had_lamb
"
)
base64
=
encode_audio_base64
(
*
audio
.
audio_and_sample_rate
)
base64
=
encode_audio_base64
(
*
audio
.
audio_and_sample_rate
)
return
f
"data:audio/ogg;base64,
{
base64
}
"
return
f
"data:audio/ogg;base64,
{
base64
}
"
...
@@ -195,15 +205,18 @@ def test_parse_chat_messages_single_image(
...
@@ -195,15 +205,18 @@ def test_parse_chat_messages_single_image(
[{
[{
"role"
:
"role"
:
"user"
,
"user"
,
"content"
:
[{
"content"
:
[
{
"type"
:
"image_url"
,
"type"
:
"image_url"
,
"image_url"
:
{
"image_url"
:
{
"url"
:
image_url
"url"
:
image_url
}
}
},
{
},
{
"type"
:
"text"
,
"type"
:
"text"
,
"text"
:
"What's in the image?"
"text"
:
"What's in the image?"
}]
},
],
}],
}],
phi3v_model_config
,
phi3v_model_config
,
phi3v_tokenizer
,
phi3v_tokenizer
,
...
@@ -223,58 +236,69 @@ def test_parse_chat_messages_empty_system(
...
@@ -223,58 +236,69 @@ def test_parse_chat_messages_empty_system(
):
):
# Test string format
# Test string format
conversation
,
_
=
parse_chat_messages
(
conversation
,
_
=
parse_chat_messages
(
[{
[
{
"role"
:
"system"
,
"role"
:
"system"
,
"content"
:
""
"content"
:
""
},
{
},
{
"role"
:
"user"
,
"role"
:
"user"
,
"content"
:
[{
"content"
:
[{
"type"
:
"text"
,
"type"
:
"text"
,
"text"
:
"Who are you?"
"text"
:
"Who are you?"
}]
}],
}],
},
],
mistral_model_config
,
mistral_model_config
,
mistral_tokenizer
,
mistral_tokenizer
,
content_format
=
"string"
,
content_format
=
"string"
,
)
)
assert
conversation
==
[{
assert
conversation
==
[
{
"role"
:
"system"
,
"role"
:
"system"
,
"content"
:
""
"content"
:
""
},
{
},
{
"role"
:
"user"
,
"role"
:
"user"
,
"content"
:
"Who are you?"
"content"
:
"Who are you?"
}]
},
]
# Test openai format
# Test openai format
conversation
,
_
=
parse_chat_messages
(
conversation
,
_
=
parse_chat_messages
(
[{
[
{
"role"
:
"system"
,
"role"
:
"system"
,
"content"
:
""
"content"
:
""
},
{
},
{
"role"
:
"user"
,
"role"
:
"user"
,
"content"
:
[{
"content"
:
[{
"type"
:
"text"
,
"type"
:
"text"
,
"text"
:
"Who are you?"
"text"
:
"Who are you?"
}]
}],
}],
},
],
mistral_model_config
,
mistral_model_config
,
mistral_tokenizer
,
mistral_tokenizer
,
content_format
=
"openai"
,
content_format
=
"openai"
,
)
)
assert
conversation
==
[{
assert
conversation
==
[
{
"role"
:
"system"
,
"role"
:
"system"
,
"content"
:
[{
"content"
:
[{
"type"
:
"text"
,
"type"
:
"text"
,
"text"
:
""
"text"
:
""
}]
}]
},
{
},
"role"
:
{
"user"
,
"role"
:
"user"
,
"content"
:
[{
"content"
:
[{
"type"
:
"text"
,
"type"
:
"text"
,
"text"
:
"Who are you?"
"text"
:
"Who are you?"
}]
}]
}]
},
]
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
asyncio
...
@@ -287,15 +311,18 @@ async def test_parse_chat_messages_single_image_async(
...
@@ -287,15 +311,18 @@ async def test_parse_chat_messages_single_image_async(
[{
[{
"role"
:
"role"
:
"user"
,
"user"
,
"content"
:
[{
"content"
:
[
{
"type"
:
"image_url"
,
"type"
:
"image_url"
,
"image_url"
:
{
"image_url"
:
{
"url"
:
image_url
"url"
:
image_url
}
}
},
{
},
{
"type"
:
"text"
,
"type"
:
"text"
,
"text"
:
"What's in the image?"
"text"
:
"What's in the image?"
}]
},
],
}],
}],
phi3v_model_config
,
phi3v_model_config
,
phi3v_tokenizer
,
phi3v_tokenizer
,
...
@@ -318,18 +345,22 @@ def test_parse_chat_messages_multiple_images(
...
@@ -318,18 +345,22 @@ def test_parse_chat_messages_multiple_images(
[{
[{
"role"
:
"role"
:
"user"
,
"user"
,
"content"
:
[{
"content"
:
[
{
"type"
:
"image_url"
,
"type"
:
"image_url"
,
"image_url"
:
{
"image_url"
:
{
"url"
:
image_url
"url"
:
image_url
}
}
},
{
},
{
"type"
:
"image_pil"
,
"type"
:
"image_pil"
,
"image_pil"
:
ImageAsset
(
'cherry_blossom'
).
pil_image
"image_pil"
:
ImageAsset
(
"cherry_blossom"
).
pil_image
,
},
{
},
{
"type"
:
"text"
,
"type"
:
"text"
,
"text"
:
"What's in these images?"
"text"
:
"What's in these images?"
}]
},
],
}],
}],
phi3v_model_config
,
phi3v_model_config
,
phi3v_tokenizer
,
phi3v_tokenizer
,
...
@@ -340,7 +371,7 @@ def test_parse_chat_messages_multiple_images(
...
@@ -340,7 +371,7 @@ def test_parse_chat_messages_multiple_images(
"role"
:
"role"
:
"user"
,
"user"
,
"content"
:
"content"
:
"<|image_1|>
\n
<|image_2|>
\n
What's in these images?"
"<|image_1|>
\n
<|image_2|>
\n
What's in these images?"
,
}]
}]
_assert_mm_data_is_image_input
(
mm_data
,
2
)
_assert_mm_data_is_image_input
(
mm_data
,
2
)
...
@@ -355,18 +386,22 @@ async def test_parse_chat_messages_multiple_images_async(
...
@@ -355,18 +386,22 @@ async def test_parse_chat_messages_multiple_images_async(
[{
[{
"role"
:
"role"
:
"user"
,
"user"
,
"content"
:
[{
"content"
:
[
{
"type"
:
"image_url"
,
"type"
:
"image_url"
,
"image_url"
:
{
"image_url"
:
{
"url"
:
image_url
"url"
:
image_url
}
}
},
{
},
{
"type"
:
"image_pil"
,
"type"
:
"image_pil"
,
"image_pil"
:
ImageAsset
(
'cherry_blossom'
).
pil_image
"image_pil"
:
ImageAsset
(
"cherry_blossom"
).
pil_image
,
},
{
},
{
"type"
:
"text"
,
"type"
:
"text"
,
"text"
:
"What's in these images?"
"text"
:
"What's in these images?"
}]
},
],
}],
}],
phi3v_model_config
,
phi3v_model_config
,
phi3v_tokenizer
,
phi3v_tokenizer
,
...
@@ -377,7 +412,7 @@ async def test_parse_chat_messages_multiple_images_async(
...
@@ -377,7 +412,7 @@ async def test_parse_chat_messages_multiple_images_async(
"role"
:
"role"
:
"user"
,
"user"
,
"content"
:
"content"
:
"<|image_1|>
\n
<|image_2|>
\n
What's in these images?"
"<|image_1|>
\n
<|image_2|>
\n
What's in these images?"
,
}]
}]
_assert_mm_data_is_image_input
(
await
mm_future
,
2
)
_assert_mm_data_is_image_input
(
await
mm_future
,
2
)
...
@@ -391,22 +426,26 @@ def test_parse_chat_messages_placeholder_already_in_prompt(
...
@@ -391,22 +426,26 @@ def test_parse_chat_messages_placeholder_already_in_prompt(
[{
[{
"role"
:
"role"
:
"user"
,
"user"
,
"content"
:
[{
"content"
:
[
{
"type"
:
"image_url"
,
"type"
:
"image_url"
,
"image_url"
:
{
"image_url"
:
{
"url"
:
image_url
"url"
:
image_url
}
}
},
{
},
{
"type"
:
"image_url"
,
"type"
:
"image_url"
,
"image_url"
:
{
"image_url"
:
{
"url"
:
image_url
"url"
:
image_url
}
}
},
{
},
{
"type"
:
"type"
:
"text"
,
"text"
,
"text"
:
"text"
:
"What's in <|image_1|> and how does it compare to <|image_2|>?"
"What's in <|image_1|> and how does it compare to <|image_2|>?"
,
# noqa: E501
}]
},
],
}],
}],
phi3v_model_config
,
phi3v_model_config
,
phi3v_tokenizer
,
phi3v_tokenizer
,
...
@@ -416,7 +455,7 @@ def test_parse_chat_messages_placeholder_already_in_prompt(
...
@@ -416,7 +455,7 @@ def test_parse_chat_messages_placeholder_already_in_prompt(
"role"
:
"role"
:
"user"
,
"user"
,
"content"
:
"content"
:
"What's in <|image_1|> and how does it compare to <|image_2|>?"
"What's in <|image_1|> and how does it compare to <|image_2|>?"
,
}]
}]
_assert_mm_data_is_image_input
(
mm_data
,
2
)
_assert_mm_data_is_image_input
(
mm_data
,
2
)
...
@@ -447,9 +486,9 @@ def test_parse_chat_messages_placeholder_one_already_in_prompt(
...
@@ -447,9 +486,9 @@ def test_parse_chat_messages_placeholder_one_already_in_prompt(
"type"
:
"type"
:
"text"
,
"text"
,
"text"
:
"text"
:
"What's in <|image_1|> and how does it compare to the other one?"
# noqa: E501
"What's in <|image_1|> and how does it compare to the other one?"
,
# noqa: E501
}
}
,
]
]
,
}],
}],
phi3v_model_config
,
phi3v_model_config
,
phi3v_tokenizer
,
phi3v_tokenizer
,
...
@@ -461,7 +500,7 @@ def test_parse_chat_messages_placeholder_one_already_in_prompt(
...
@@ -461,7 +500,7 @@ def test_parse_chat_messages_placeholder_one_already_in_prompt(
"user"
,
"user"
,
"content"
:
"content"
:
"<|image_2|>
\n
What's in <|image_1|> and how does it compare to the "
"<|image_2|>
\n
What's in <|image_1|> and how does it compare to the "
"other one?"
"other one?"
,
}]
}]
_assert_mm_data_is_image_input
(
mm_data
,
2
)
_assert_mm_data_is_image_input
(
mm_data
,
2
)
...
@@ -472,34 +511,44 @@ def test_parse_chat_messages_multiple_images_across_messages(
...
@@ -472,34 +511,44 @@ def test_parse_chat_messages_multiple_images_across_messages(
image_url
,
image_url
,
):
):
conversation
,
mm_data
=
parse_chat_messages
(
conversation
,
mm_data
=
parse_chat_messages
(
[{
[
{
"role"
:
"role"
:
"user"
,
"user"
,
"content"
:
[{
"content"
:
[
{
"type"
:
"image_url"
,
"type"
:
"image_url"
,
"image_url"
:
{
"image_url"
:
{
"url"
:
image_url
"url"
:
image_url
}
}
},
{
},
{
"type"
:
"text"
,
"type"
:
"text"
,
"text"
:
"What's in this image?"
"text"
:
"What's in this image?"
}]
},
},
{
],
},
{
"role"
:
"assistant"
,
"role"
:
"assistant"
,
"content"
:
"Some stuff."
"content"
:
"Some stuff."
},
{
},
{
"role"
:
"role"
:
"user"
,
"user"
,
"content"
:
[{
"content"
:
[
{
"type"
:
"image_url"
,
"type"
:
"image_url"
,
"image_url"
:
{
"image_url"
:
{
"url"
:
image_url
"url"
:
image_url
}
}
},
{
},
{
"type"
:
"text"
,
"type"
:
"text"
,
"text"
:
"What about this one?"
"text"
:
"What about this one?"
}]
},
}],
],
},
],
phi3v_model_config
,
phi3v_model_config
,
phi3v_tokenizer
,
phi3v_tokenizer
,
content_format
=
"string"
,
content_format
=
"string"
,
...
@@ -527,19 +576,23 @@ def test_parse_chat_messages_context_text_format(
...
@@ -527,19 +576,23 @@ def test_parse_chat_messages_context_text_format(
phi3v_tokenizer
,
phi3v_tokenizer
,
):
):
conversation
,
mm_data
=
parse_chat_messages
(
conversation
,
mm_data
=
parse_chat_messages
(
[{
[
{
"role"
:
"user"
,
"role"
:
"user"
,
"content"
:
[{
"content"
:
[{
"type"
:
"text"
,
"type"
:
"text"
,
"text"
:
"What's in this text?"
"text"
:
"What's in this text?"
}]
}],
},
{
},
{
"role"
:
"assistant"
,
"role"
:
"assistant"
,
"content"
:
"Some stuff."
"content"
:
"Some stuff."
},
{
},
{
"role"
:
"user"
,
"role"
:
"user"
,
"content"
:
"What about this one?"
"content"
:
"What about this one?"
}],
},
],
phi3v_model_config
,
phi3v_model_config
,
phi3v_tokenizer
,
phi3v_tokenizer
,
content_format
=
"openai"
,
content_format
=
"openai"
,
...
@@ -551,21 +604,21 @@ def test_parse_chat_messages_context_text_format(
...
@@ -551,21 +604,21 @@ def test_parse_chat_messages_context_text_format(
"content"
:
[{
"content"
:
[{
"type"
:
"text"
,
"type"
:
"text"
,
"text"
:
"What's in this text?"
"text"
:
"What's in this text?"
}]
}]
,
},
},
{
{
"role"
:
"assistant"
,
"role"
:
"assistant"
,
"content"
:
[{
"content"
:
[{
"type"
:
"text"
,
"type"
:
"text"
,
"text"
:
"Some stuff."
"text"
:
"Some stuff."
}]
}]
,
},
},
{
{
"role"
:
"user"
,
"role"
:
"user"
,
"content"
:
[{
"content"
:
[{
"type"
:
"text"
,
"type"
:
"text"
,
"text"
:
"What about this one?"
"text"
:
"What about this one?"
}]
}]
,
},
},
]
]
...
@@ -578,31 +631,37 @@ def test_parse_chat_messages_rejects_too_many_images_in_one_message(
...
@@ -578,31 +631,37 @@ def test_parse_chat_messages_rejects_too_many_images_in_one_message(
with
warnings
.
catch_warnings
():
with
warnings
.
catch_warnings
():
warnings
.
filterwarnings
(
warnings
.
filterwarnings
(
"ignore"
,
"ignore"
,
message
=
"coroutine 'async_get_and_parse_image' was never awaited"
)
message
=
"coroutine 'async_get_and_parse_image' was never awaited"
,
)
with
pytest
.
raises
(
ValueError
,
match
=
"At most"
):
with
pytest
.
raises
(
ValueError
,
match
=
"At most"
):
parse_chat_messages
(
parse_chat_messages
(
[{
[{
"role"
:
"role"
:
"user"
,
"user"
,
"content"
:
[{
"content"
:
[
{
"type"
:
"image_url"
,
"type"
:
"image_url"
,
"image_url"
:
{
"image_url"
:
{
"url"
:
image_url
"url"
:
image_url
}
},
},
{
},
{
"type"
:
"image_url"
,
"type"
:
"image_url"
,
"image_url"
:
{
"image_url"
:
{
"url"
:
image_url
"url"
:
image_url
}
},
},
{
},
{
"type"
:
"image_url"
,
"type"
:
"image_url"
,
"image_url"
:
{
"image_url"
:
{
"url"
:
image_url
"url"
:
image_url
}
},
},
{
},
{
"type"
:
"text"
,
"type"
:
"text"
,
"text"
:
"What's in these images?"
"text"
:
"What's in these images?"
}]
},
],
}],
}],
phi3v_model_config
,
phi3v_model_config
,
phi3v_tokenizer
,
phi3v_tokenizer
,
...
@@ -618,42 +677,54 @@ def test_parse_chat_messages_rejects_too_many_images_across_messages(
...
@@ -618,42 +677,54 @@ def test_parse_chat_messages_rejects_too_many_images_across_messages(
with
warnings
.
catch_warnings
():
with
warnings
.
catch_warnings
():
warnings
.
filterwarnings
(
warnings
.
filterwarnings
(
"ignore"
,
"ignore"
,
message
=
"coroutine 'async_get_and_parse_image' was never awaited"
)
message
=
"coroutine 'async_get_and_parse_image' was never awaited"
,
)
with
pytest
.
raises
(
ValueError
,
match
=
"At most"
):
with
pytest
.
raises
(
ValueError
,
match
=
"At most"
):
parse_chat_messages
(
parse_chat_messages
(
[{
[
{
"role"
:
"role"
:
"user"
,
"user"
,
"content"
:
[{
"content"
:
[
{
"type"
:
"image_url"
,
"type"
:
"image_url"
,
"image_url"
:
{
"image_url"
:
{
"url"
:
image_url
"url"
:
image_url
}
},
},
{
},
{
"type"
:
"text"
,
"type"
:
"text"
,
"text"
:
"What's in this image?"
"text"
:
"What's in this image?"
}]
},
},
{
],
},
{
"role"
:
"assistant"
,
"role"
:
"assistant"
,
"content"
:
"Some stuff."
"content"
:
"Some stuff."
},
{
},
{
"role"
:
"role"
:
"user"
,
"user"
,
"content"
:
[{
"content"
:
[
{
"type"
:
"image_url"
,
"type"
:
"image_url"
,
"image_url"
:
{
"image_url"
:
{
"url"
:
image_url
"url"
:
image_url
}
},
},
{
},
{
"type"
:
"image_url"
,
"type"
:
"image_url"
,
"image_url"
:
{
"image_url"
:
{
"url"
:
image_url
"url"
:
image_url
}
},
},
{
},
{
"type"
:
"text"
,
"type"
:
"text"
,
"text"
:
"What about these two?"
"text"
:
"What about these two?"
}]
},
}],
],
},
],
phi3v_model_config
,
phi3v_model_config
,
phi3v_tokenizer
,
phi3v_tokenizer
,
content_format
=
"string"
,
content_format
=
"string"
,
...
@@ -670,12 +741,14 @@ def test_parse_chat_messages_multiple_images_uncommon_input(
...
@@ -670,12 +741,14 @@ def test_parse_chat_messages_multiple_images_uncommon_input(
"role"
:
"role"
:
"user"
,
"user"
,
"content"
:
[
"content"
:
[
"What's in these images?"
,
{
"What's in these images?"
,
{
"image_url"
:
image_url
"image_url"
:
image_url
},
{
},
{
"image_url"
:
image_url
"image_url"
:
image_url
}
}
,
]
]
,
}],
}],
phi3v_model_config
,
phi3v_model_config
,
phi3v_tokenizer
,
phi3v_tokenizer
,
...
@@ -686,7 +759,7 @@ def test_parse_chat_messages_multiple_images_uncommon_input(
...
@@ -686,7 +759,7 @@ def test_parse_chat_messages_multiple_images_uncommon_input(
"role"
:
"role"
:
"user"
,
"user"
,
"content"
:
"content"
:
"<|image_1|>
\n
<|image_2|>
\n
What's in these images?"
"<|image_1|>
\n
<|image_2|>
\n
What's in these images?"
,
}]
}]
_assert_mm_data_is_image_input
(
mm_data
,
2
)
_assert_mm_data_is_image_input
(
mm_data
,
2
)
...
@@ -700,26 +773,32 @@ def test_parse_chat_messages_multiple_images_interleave(
...
@@ -700,26 +773,32 @@ def test_parse_chat_messages_multiple_images_interleave(
[{
[{
"role"
:
"role"
:
"user"
,
"user"
,
"content"
:
[{
"content"
:
[
{
"type"
:
"text"
,
"type"
:
"text"
,
"text"
:
"I need you to compare this image"
"text"
:
"I need you to compare this image"
,
},
{
},
{
"type"
:
"image_url"
,
"type"
:
"image_url"
,
"image_url"
:
{
"image_url"
:
{
"url"
:
image_url
"url"
:
image_url
}
}
},
{
},
{
"type"
:
"text"
,
"type"
:
"text"
,
"text"
:
"and this one"
"text"
:
"and this one"
},
{
},
{
"type"
:
"image_url"
,
"type"
:
"image_url"
,
"image_url"
:
{
"image_url"
:
{
"url"
:
image_url
"url"
:
image_url
}
}
},
{
},
{
"type"
:
"text"
,
"type"
:
"text"
,
"text"
:
"Do they have differences?"
"text"
:
"Do they have differences?"
}]
},
],
}],
}],
phi3v_model_config_mm_interleaved
,
phi3v_model_config_mm_interleaved
,
phi3v_tokenizer
,
phi3v_tokenizer
,
...
@@ -731,7 +810,7 @@ def test_parse_chat_messages_multiple_images_interleave(
...
@@ -731,7 +810,7 @@ def test_parse_chat_messages_multiple_images_interleave(
"user"
,
"user"
,
"content"
:
"content"
:
"I need you to compare this image
\n
<|image_1|>
\n
and this one
\n
<|image_2|>
\n
"
# noqa: E501
"I need you to compare this image
\n
<|image_1|>
\n
and this one
\n
<|image_2|>
\n
"
# noqa: E501
"Do they have differences?"
"Do they have differences?"
,
}]
}]
_assert_mm_data_is_image_input
(
mm_data
,
2
)
_assert_mm_data_is_image_input
(
mm_data
,
2
)
...
@@ -746,26 +825,32 @@ async def test_parse_chat_messages_multiple_images_interleave_async(
...
@@ -746,26 +825,32 @@ async def test_parse_chat_messages_multiple_images_interleave_async(
[{
[{
"role"
:
"role"
:
"user"
,
"user"
,
"content"
:
[{
"content"
:
[
{
"type"
:
"text"
,
"type"
:
"text"
,
"text"
:
"I need you to compare this image"
"text"
:
"I need you to compare this image"
,
},
{
},
{
"type"
:
"image_url"
,
"type"
:
"image_url"
,
"image_url"
:
{
"image_url"
:
{
"url"
:
image_url
"url"
:
image_url
}
}
},
{
},
{
"type"
:
"text"
,
"type"
:
"text"
,
"text"
:
"and this one"
"text"
:
"and this one"
},
{
},
{
"type"
:
"image_url"
,
"type"
:
"image_url"
,
"image_url"
:
{
"image_url"
:
{
"url"
:
image_url
"url"
:
image_url
}
}
},
{
},
{
"type"
:
"text"
,
"type"
:
"text"
,
"text"
:
"Do they have differences?"
"text"
:
"Do they have differences?"
}]
},
],
}],
}],
phi3v_model_config_mm_interleaved
,
phi3v_model_config_mm_interleaved
,
phi3v_tokenizer
,
phi3v_tokenizer
,
...
@@ -777,7 +862,7 @@ async def test_parse_chat_messages_multiple_images_interleave_async(
...
@@ -777,7 +862,7 @@ async def test_parse_chat_messages_multiple_images_interleave_async(
"user"
,
"user"
,
"content"
:
"content"
:
"I need you to compare this image
\n
<|image_1|>
\n
and this one
\n
<|image_2|>
\n
"
# noqa: E501
"I need you to compare this image
\n
<|image_1|>
\n
and this one
\n
<|image_2|>
\n
"
# noqa: E501
"Do they have differences?"
"Do they have differences?"
,
}]
}]
_assert_mm_data_is_image_input
(
await
mm_data
,
2
)
_assert_mm_data_is_image_input
(
await
mm_data
,
2
)
...
@@ -788,7 +873,8 @@ def test_parse_chat_messages_multiple_images_multiple_messages_interleave(
...
@@ -788,7 +873,8 @@ def test_parse_chat_messages_multiple_images_multiple_messages_interleave(
image_url
,
image_url
,
):
):
conversation
,
mm_data
=
parse_chat_messages
(
conversation
,
mm_data
=
parse_chat_messages
(
[{
[
{
"role"
:
"role"
:
"user"
,
"user"
,
"content"
:
[
"content"
:
[
...
@@ -806,48 +892,61 @@ def test_parse_chat_messages_multiple_images_multiple_messages_interleave(
...
@@ -806,48 +892,61 @@ def test_parse_chat_messages_multiple_images_multiple_messages_interleave(
"type"
:
"text"
,
"type"
:
"text"
,
"text"
:
"Be accurate."
"text"
:
"Be accurate."
},
},
]
],
},
{
},
{
"role"
:
"assistant"
,
"role"
:
"assistant"
,
"content"
:
"Some stuff."
"content"
:
"Some stuff."
},
{
},
{
"role"
:
"role"
:
"user"
,
"user"
,
"content"
:
[{
"content"
:
[
{
"type"
:
"text"
,
"type"
:
"text"
,
"text"
:
"What's on this image?"
"text"
:
"What's on this image?"
},
{
},
{
"type"
:
"image_url"
,
"type"
:
"image_url"
,
"image_url"
:
{
"image_url"
:
{
"url"
:
image_url
"url"
:
image_url
}
}
}]
},
}],
],
},
],
phi3v_model_config_mm_interleaved
,
phi3v_model_config_mm_interleaved
,
phi3v_tokenizer
,
phi3v_tokenizer
,
content_format
=
"string"
,
content_format
=
"string"
,
)
)
assert
conversation
==
[
{
assert
conversation
==
[
"role"
:
{
"user"
,
"role"
:
"user"
,
"content"
:
"content"
:
"What's on this image?
\n
<|image_1|>
\n
Be accurate."
,
"What's on this image?
\n
<|image_1|>
\n
Be accurate."
},
},
{
{
"role"
:
"assistant"
,
"role"
:
"assistant"
,
"content"
:
"Some stuff."
"content"
:
"Some stuff."
},
{
},
{
"role"
:
"user"
,
"role"
:
"user"
,
"content"
:
"What's on this image?
\n
<|image_2|>"
"content"
:
"What's on this image?
\n
<|image_2|>"
}]
},
]
_assert_mm_data_is_image_input
(
mm_data
,
2
)
_assert_mm_data_is_image_input
(
mm_data
,
2
)
def
test_parse_chat_messages_multiple_modals_multiple_messages_interleave
(
def
test_parse_chat_messages_multiple_modals_multiple_messages_interleave
(
qwen25omni_model_config_mm_interleaved
,
qwen25omni_tokenizer
,
qwen25omni_model_config_mm_interleaved
,
image_url
,
video_url
,
audio_url
):
qwen25omni_tokenizer
,
image_url
,
video_url
,
audio_url
,
):
conversation
,
mm_data
=
parse_chat_messages
(
conversation
,
mm_data
=
parse_chat_messages
(
[{
[
{
"role"
:
"role"
:
"user"
,
"user"
,
"content"
:
[
"content"
:
[
...
@@ -871,52 +970,64 @@ def test_parse_chat_messages_multiple_modals_multiple_messages_interleave(
...
@@ -871,52 +970,64 @@ def test_parse_chat_messages_multiple_modals_multiple_messages_interleave(
"url"
:
audio_url
"url"
:
audio_url
}
}
},
},
]
],
},
{
},
{
"role"
:
"assistant"
,
"role"
:
"assistant"
,
"content"
:
"Some stuff."
"content"
:
"Some stuff."
},
{
},
{
"role"
:
"role"
:
"user"
,
"user"
,
"content"
:
[{
"content"
:
[
{
"type"
:
"text"
,
"type"
:
"text"
,
"text"
:
"What's on this image?"
"text"
:
"What's on this image?"
},
{
},
{
"type"
:
"image_url"
,
"type"
:
"image_url"
,
"image_url"
:
{
"image_url"
:
{
"url"
:
image_url
"url"
:
image_url
}
}
},
{
},
{
"type"
:
"text"
,
"type"
:
"text"
,
"text"
:
"And what's in the video?"
"text"
:
"And what's in the video?"
},
{
},
{
"type"
:
"video_url"
,
"type"
:
"video_url"
,
"video_url"
:
{
"video_url"
:
{
"url"
:
video_url
"url"
:
video_url
}
}
}]
},
}],
],
},
],
qwen25omni_model_config_mm_interleaved
,
qwen25omni_model_config_mm_interleaved
,
qwen25omni_tokenizer
,
qwen25omni_tokenizer
,
content_format
=
"string"
,
content_format
=
"string"
,
)
)
assert
conversation
==
[{
assert
conversation
==
[
{
"role"
:
"role"
:
"user"
,
"user"
,
"content"
:
"content"
:
"What's on this image?
\n
<|vision_start|><|IMAGE|><|vision_end|>
\n
"
"What's on this image?
\n
<|vision_start|><|IMAGE|><|vision_end|>
\n
"
"Now listen to this audio
\n
Audio 1: <|audio_bos|><|AUDIO|><|audio_eos|>"
"Now listen to this audio
\n
Audio 1: <|audio_bos|><|AUDIO|><|audio_eos|>"
,
# noqa: E501
},
{
},
{
"role"
:
"assistant"
,
"role"
:
"assistant"
,
"content"
:
"Some stuff."
"content"
:
"Some stuff."
},
{
},
{
"role"
:
"role"
:
"user"
,
"user"
,
"content"
:
"content"
:
"What's on this image?
\n
<|vision_start|><|IMAGE|><|vision_end|>
\n
"
"What's on this image?
\n
<|vision_start|><|IMAGE|><|vision_end|>
\n
"
"And what's in the video?
\n
<|vision_start|><|VIDEO|><|vision_end|>"
"And what's in the video?
\n
<|vision_start|><|VIDEO|><|vision_end|>"
,
}]
},
]
_assert_mm_data_inputs
(
mm_data
,
{
"image"
:
2
,
"video"
:
1
,
"audio"
:
1
})
_assert_mm_data_inputs
(
mm_data
,
{
"image"
:
2
,
"video"
:
1
,
"audio"
:
1
})
...
@@ -929,7 +1040,8 @@ def test_parse_chat_messages_multiple_images_interleave_with_placeholders(
...
@@ -929,7 +1040,8 @@ def test_parse_chat_messages_multiple_images_interleave_with_placeholders(
with
pytest
.
raises
(
with
pytest
.
raises
(
ValueError
,
ValueError
,
match
=
r
"Found more '<|image_1|>' placeholders in input prompt "
match
=
r
"Found more '<|image_1|>' placeholders in input prompt "
"than actual multimodal data items."
):
"than actual multimodal data items."
,
):
parse_chat_messages
(
parse_chat_messages
(
[{
[{
"role"
:
"role"
:
...
@@ -952,9 +1064,9 @@ def test_parse_chat_messages_multiple_images_interleave_with_placeholders(
...
@@ -952,9 +1064,9 @@ def test_parse_chat_messages_multiple_images_interleave_with_placeholders(
"text"
,
"text"
,
"text"
:
"text"
:
"I need you to compare this image
\n
<|image_1|>
\n
and this one
\n
<|image_2|>
\n
"
# noqa: E501
"I need you to compare this image
\n
<|image_1|>
\n
and this one
\n
<|image_2|>
\n
"
# noqa: E501
"Do they have differences?"
"Do they have differences?"
,
},
},
]
]
,
}],
}],
phi3v_model_config_mm_interleaved
,
phi3v_model_config_mm_interleaved
,
phi3v_tokenizer
,
phi3v_tokenizer
,
...
@@ -973,12 +1085,15 @@ def test_mllama_single_image(
...
@@ -973,12 +1085,15 @@ def test_mllama_single_image(
[{
[{
"role"
:
"role"
:
"user"
,
"user"
,
"content"
:
[{
"content"
:
[
'type'
:
'text'
,
{
'text'
:
'The content of this image is:'
"type"
:
"text"
,
},
{
"text"
:
"The content of this image is:"
},
{
"image_url"
:
image_url
"image_url"
:
image_url
}]
},
],
}],
}],
mllama_model_config
,
mllama_model_config
,
mllama_tokenizer
,
mllama_tokenizer
,
...
@@ -986,14 +1101,17 @@ def test_mllama_single_image(
...
@@ -986,14 +1101,17 @@ def test_mllama_single_image(
)
)
_assert_mm_data_is_image_input
(
mm_data
,
1
)
_assert_mm_data_is_image_input
(
mm_data
,
1
)
assert
conversation
==
[{
assert
conversation
==
[{
'role'
:
"role"
:
'user'
,
"user"
,
'content'
:
[{
"content"
:
[
'type'
:
'text'
,
{
'text'
:
'The content of this image is:'
"type"
:
"text"
,
},
{
"text"
:
"The content of this image is:"
'type'
:
'image'
},
}]
{
"type"
:
"image"
},
],
}]
}]
...
@@ -1009,20 +1127,20 @@ def test_mllama_interleaved_images(
...
@@ -1009,20 +1127,20 @@ def test_mllama_interleaved_images(
"user"
,
"user"
,
"content"
:
[
"content"
:
[
{
{
'
type
'
:
'
text
'
,
"
type
"
:
"
text
"
,
'
text
'
:
'
The content of the first image is:
'
"
text
"
:
"
The content of the first image is:
"
,
},
},
{
{
"image_url"
:
image_url
"image_url"
:
image_url
},
},
{
{
'
type
'
:
'
text
'
,
"
type
"
:
"
text
"
,
'
text
'
:
'
The content of the second image is:
'
"
text
"
:
"
The content of the second image is:
"
,
},
},
{
{
"image_url"
:
image_url
"image_url"
:
image_url
},
},
]
]
,
}],
}],
mllama_model_config
,
mllama_model_config
,
mllama_tokenizer
,
mllama_tokenizer
,
...
@@ -1030,19 +1148,24 @@ def test_mllama_interleaved_images(
...
@@ -1030,19 +1148,24 @@ def test_mllama_interleaved_images(
)
)
_assert_mm_data_is_image_input
(
mm_data
,
2
)
_assert_mm_data_is_image_input
(
mm_data
,
2
)
assert
conversation
==
[{
assert
conversation
==
[{
'role'
:
"role"
:
'user'
,
"user"
,
'content'
:
[{
"content"
:
[
'type'
:
'text'
,
{
'text'
:
'The content of the first image is:'
"type"
:
"text"
,
},
{
"text"
:
"The content of the first image is:"
'type'
:
'image'
},
},
{
{
'type'
:
'text'
,
"type"
:
"image"
'text'
:
'The content of the second image is:'
},
},
{
{
'type'
:
'image'
"type"
:
"text"
,
}]
"text"
:
"The content of the second image is:"
},
{
"type"
:
"image"
},
],
}]
}]
...
@@ -1053,34 +1176,36 @@ def test_multimodal_image_parsing_matches_hf(model, image_url):
...
@@ -1053,34 +1176,36 @@ def test_multimodal_image_parsing_matches_hf(model, image_url):
def
get_conversation
(
is_hf
:
bool
):
def
get_conversation
(
is_hf
:
bool
):
img_part
=
{
"type"
:
"image_url"
,
"image_url"
:
{
"url"
:
image_url
}}
img_part
=
{
"type"
:
"image_url"
,
"image_url"
:
{
"url"
:
image_url
}}
if
is_hf
:
if
is_hf
:
img_part
=
{
'
type
'
:
'
image
'
}
img_part
=
{
"
type
"
:
"
image
"
}
return
[{
return
[{
'
role
'
:
"
role
"
:
'
user
'
,
"
user
"
,
'
content
'
:
[
"
content
"
:
[
{
{
'
type
'
:
'
text
'
,
"
type
"
:
"
text
"
,
'
text
'
:
'
The content of the first image is:
'
"
text
"
:
"
The content of the first image is:
"
,
},
},
img_part
,
img_part
,
{
{
'
type
'
:
'
text
'
,
"
type
"
:
"
text
"
,
'
text
'
:
'
The content of the second image is:
'
"
text
"
:
"
The content of the second image is:
"
,
},
},
img_part
,
img_part
,
{
{
'
type
'
:
'
text
'
,
"
type
"
:
"
text
"
,
'
text
'
:
'
What animal is in the first image?
'
"
text
"
:
"
What animal is in the first image?
"
,
},
},
]
]
,
}]
}]
# Build a config for the model
# Build a config for the model
model_config
=
ModelConfig
(
model
,
model_config
=
ModelConfig
(
model
,
runner
=
"generate"
,
runner
=
"generate"
,
limit_mm_per_prompt
=
{
limit_mm_per_prompt
=
{
"image"
:
2
,
"image"
:
2
,
})
},
)
# Build the tokenizer group and grab the underlying tokenizer
# Build the tokenizer group and grab the underlying tokenizer
tokenizer_group
=
TokenizerGroup
(
tokenizer_group
=
TokenizerGroup
(
...
@@ -1126,7 +1251,8 @@ def test_multimodal_image_parsing_matches_hf(model, image_url):
...
@@ -1126,7 +1251,8 @@ def test_multimodal_image_parsing_matches_hf(model, image_url):
[
[
QWEN2VL_MODEL_ID
,
# tokenizer.chat_template is of type str
QWEN2VL_MODEL_ID
,
# tokenizer.chat_template is of type str
HERMES_MODEL_ID
,
# tokenizer.chat_template is of type dict
HERMES_MODEL_ID
,
# tokenizer.chat_template is of type dict
])
],
)
@
pytest
.
mark
.
parametrize
(
"use_tools"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_tools"
,
[
True
,
False
])
def
test_resolve_hf_chat_template
(
sample_json_schema
,
model
,
use_tools
):
def
test_resolve_hf_chat_template
(
sample_json_schema
,
model
,
use_tools
):
"""checks that chat_template is a dict type for HF models."""
"""checks that chat_template is a dict type for HF models."""
...
@@ -1152,14 +1278,14 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
...
@@ -1152,14 +1278,14 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
)
)
tokenizer
=
tokenizer_group
.
tokenizer
tokenizer
=
tokenizer_group
.
tokenizer
tools
=
[{
tools
=
(
[{
"type"
:
"function"
,
"type"
:
"function"
,
"function"
:
{
"function"
:
{
"name"
:
"dummy_function_name"
,
"name"
:
"dummy_function_name"
,
"description"
:
"This is a dummy function"
,
"description"
:
"This is a dummy function"
,
"parameters"
:
sample_json_schema
"parameters"
:
sample_json_schema
,
}
}
,
}]
if
use_tools
else
None
}]
if
use_tools
else
None
)
# Test detecting the tokenizer's chat_template
# Test detecting the tokenizer's chat_template
chat_template
=
resolve_hf_chat_template
(
chat_template
=
resolve_hf_chat_template
(
...
...
vllm/entrypoints/chat_utils.py
View file @
f399182e
...
@@ -103,6 +103,7 @@ class PILImage(BaseModel):
...
@@ -103,6 +103,7 @@ class PILImage(BaseModel):
"""
"""
A PIL.Image.Image object.
A PIL.Image.Image object.
"""
"""
image_pil
:
Image
.
Image
image_pil
:
Image
.
Image
model_config
=
ConfigDict
(
arbitrary_types_allowed
=
True
)
model_config
=
ConfigDict
(
arbitrary_types_allowed
=
True
)
...
@@ -115,6 +116,7 @@ class CustomChatCompletionContentPILImageParam(TypedDict, total=False):
...
@@ -115,6 +116,7 @@ class CustomChatCompletionContentPILImageParam(TypedDict, total=False):
"image_pil": ImageAsset('cherry_blossom').pil_image
"image_pil": ImageAsset('cherry_blossom').pil_image
}
}
"""
"""
image_pil
:
Required
[
PILImage
]
image_pil
:
Required
[
PILImage
]
...
@@ -127,6 +129,7 @@ class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):
...
@@ -127,6 +129,7 @@ class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):
"image_url": "https://example.com/image.jpg"
"image_url": "https://example.com/image.jpg"
}
}
"""
"""
image_url
:
Required
[
str
]
image_url
:
Required
[
str
]
...
@@ -138,6 +141,7 @@ class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False):
...
@@ -138,6 +141,7 @@ class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False):
"audio_url": "https://example.com/audio.mp3"
"audio_url": "https://example.com/audio.mp3"
}
}
"""
"""
audio_url
:
Required
[
str
]
audio_url
:
Required
[
str
]
...
@@ -149,6 +153,7 @@ class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
...
@@ -149,6 +153,7 @@ class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
"video_url": "https://example.com/video.mp4"
"video_url": "https://example.com/video.mp4"
}
}
"""
"""
video_url
:
Required
[
str
]
video_url
:
Required
[
str
]
...
@@ -174,19 +179,24 @@ class CustomThinkCompletionContentParam(TypedDict, total=False):
...
@@ -174,19 +179,24 @@ class CustomThinkCompletionContentParam(TypedDict, total=False):
ChatCompletionContentPartParam
:
TypeAlias
=
Union
[
ChatCompletionContentPartParam
:
TypeAlias
=
Union
[
OpenAIChatCompletionContentPartParam
,
ChatCompletionContentPartAudioParam
,
OpenAIChatCompletionContentPartParam
,
ChatCompletionContentPartAudioParam
,
ChatCompletionContentPartInputAudioParam
,
ChatCompletionContentPartInputAudioParam
,
ChatCompletionContentPartVideoParam
,
ChatCompletionContentPartRefusalParam
,
ChatCompletionContentPartVideoParam
,
ChatCompletionContentPartRefusalParam
,
CustomChatCompletionContentPILImageParam
,
CustomChatCompletionContentPILImageParam
,
CustomChatCompletionContentSimpleImageParam
,
CustomChatCompletionContentSimpleImageParam
,
ChatCompletionContentPartImageEmbedsParam
,
ChatCompletionContentPartImageEmbedsParam
,
CustomChatCompletionContentSimpleAudioParam
,
CustomChatCompletionContentSimpleAudioParam
,
CustomChatCompletionContentSimpleVideoParam
,
str
,
CustomChatCompletionContentSimpleVideoParam
,
CustomThinkCompletionContentParam
]
str
,
CustomThinkCompletionContentParam
,
]
class
CustomChatCompletionMessageParam
(
TypedDict
,
total
=
False
):
class
CustomChatCompletionMessageParam
(
TypedDict
,
total
=
False
):
"""Enables custom roles in the Chat Completion API."""
"""Enables custom roles in the Chat Completion API."""
role
:
Required
[
str
]
role
:
Required
[
str
]
"""The role of the message's author."""
"""The role of the message's author."""
...
@@ -207,9 +217,11 @@ class CustomChatCompletionMessageParam(TypedDict, total=False):
...
@@ -207,9 +217,11 @@ class CustomChatCompletionMessageParam(TypedDict, total=False):
"""The tool calls generated by the model, such as function calls."""
"""The tool calls generated by the model, such as function calls."""
ChatCompletionMessageParam
=
Union
[
OpenAIChatCompletionMessageParam
,
ChatCompletionMessageParam
=
Union
[
OpenAIChatCompletionMessageParam
,
CustomChatCompletionMessageParam
,
CustomChatCompletionMessageParam
,
OpenAIHarmonyMessage
]
OpenAIHarmonyMessage
,
]
# TODO: Make fields ReadOnly once mypy supports it
# TODO: Make fields ReadOnly once mypy supports it
...
@@ -262,13 +274,13 @@ def _is_var_or_elems_access(
...
@@ -262,13 +274,13 @@ def _is_var_or_elems_access(
key
:
Optional
[
str
]
=
None
,
key
:
Optional
[
str
]
=
None
,
)
->
bool
:
)
->
bool
:
if
isinstance
(
node
,
jinja2
.
nodes
.
Filter
):
if
isinstance
(
node
,
jinja2
.
nodes
.
Filter
):
return
(
node
.
node
is
not
None
return
node
.
node
is
not
None
and
_is_var_or_elems_access
(
and
_is_var_or_elems_access
(
node
.
node
,
varname
,
key
)
)
node
.
node
,
varname
,
key
)
if
isinstance
(
node
,
jinja2
.
nodes
.
Test
):
if
isinstance
(
node
,
jinja2
.
nodes
.
Test
):
return
_is_var_or_elems_access
(
node
.
node
,
varname
,
key
)
return
_is_var_or_elems_access
(
node
.
node
,
varname
,
key
)
if
(
isinstance
(
node
,
jinja2
.
nodes
.
Getitem
)
if
isinstance
(
node
,
jinja2
.
nodes
.
Getitem
)
and
isinstance
(
and
isinstance
(
node
.
arg
,
jinja2
.
nodes
.
Slice
)
)
:
node
.
arg
,
jinja2
.
nodes
.
Slice
):
return
_is_var_or_elems_access
(
node
.
node
,
varname
,
key
)
return
_is_var_or_elems_access
(
node
.
node
,
varname
,
key
)
# yapf: disable
# yapf: disable
...
@@ -373,15 +385,18 @@ def resolve_mistral_chat_template(
...
@@ -373,15 +385,18 @@ def resolve_mistral_chat_template(
)
->
Optional
[
str
]:
)
->
Optional
[
str
]:
if
chat_template
is
not
None
:
if
chat_template
is
not
None
:
logger
.
warning_once
(
logger
.
warning_once
(
"'chat_template' cannot be overridden for mistral tokenizer."
)
"'chat_template' cannot be overridden for mistral tokenizer."
)
if
"add_generation_prompt"
in
kwargs
:
if
"add_generation_prompt"
in
kwargs
:
logger
.
warning_once
(
logger
.
warning_once
(
"'add_generation_prompt' is not supported for mistral tokenizer, "
"'add_generation_prompt' is not supported for mistral tokenizer, "
"so it will be ignored."
)
"so it will be ignored."
)
if
"continue_final_message"
in
kwargs
:
if
"continue_final_message"
in
kwargs
:
logger
.
warning_once
(
logger
.
warning_once
(
"'continue_final_message' is not supported for mistral tokenizer, "
"'continue_final_message' is not supported for mistral tokenizer, "
"so it will be ignored."
)
"so it will be ignored."
)
return
None
return
None
...
@@ -401,23 +416,35 @@ def resolve_hf_chat_template(
...
@@ -401,23 +416,35 @@ def resolve_hf_chat_template(
try
:
try
:
processor
=
cached_get_processor
(
processor
=
cached_get_processor
(
tokenizer
.
name_or_path
,
tokenizer
.
name_or_path
,
processor_cls
=
(
PreTrainedTokenizer
,
PreTrainedTokenizerFast
,
processor_cls
=
(
ProcessorMixin
),
PreTrainedTokenizer
,
PreTrainedTokenizerFast
,
ProcessorMixin
,
),
trust_remote_code
=
model_config
.
trust_remote_code
,
trust_remote_code
=
model_config
.
trust_remote_code
,
)
)
if
isinstance
(
processor
,
ProcessorMixin
)
and
\
if
(
hasattr
(
processor
,
'chat_template'
)
and
\
isinstance
(
processor
,
ProcessorMixin
)
processor
.
chat_template
is
not
None
:
and
hasattr
(
processor
,
"chat_template"
)
and
processor
.
chat_template
is
not
None
):
return
processor
.
chat_template
return
processor
.
chat_template
except
Exception
:
except
Exception
:
logger
.
debug
(
"Failed to load AutoProcessor chat template for %s"
,
tokenizer
.
name_or_path
,
exc_info
=
True
)
# noqa: E501
logger
.
debug
(
"Failed to load AutoProcessor chat template for %s"
,
tokenizer
.
name_or_path
,
exc_info
=
True
,
)
# noqa: E501
# 3rd priority: AutoTokenizer chat template
# 3rd priority: AutoTokenizer chat template
try
:
try
:
return
tokenizer
.
get_chat_template
(
chat_template
,
tools
=
tools
)
return
tokenizer
.
get_chat_template
(
chat_template
,
tools
=
tools
)
except
Exception
:
except
Exception
:
logger
.
debug
(
"Failed to load AutoTokenizer chat template for %s"
,
logger
.
debug
(
tokenizer
.
name_or_path
,
exc_info
=
True
)
"Failed to load AutoTokenizer chat template for %s"
,
tokenizer
.
name_or_path
,
exc_info
=
True
,
)
# 4th priority: Predefined fallbacks
# 4th priority: Predefined fallbacks
path
=
get_chat_template_fallback_path
(
path
=
get_chat_template_fallback_path
(
...
@@ -425,12 +452,16 @@ def resolve_hf_chat_template(
...
@@ -425,12 +452,16 @@ def resolve_hf_chat_template(
tokenizer_name_or_path
=
model_config
.
tokenizer
,
tokenizer_name_or_path
=
model_config
.
tokenizer
,
)
)
if
path
is
not
None
:
if
path
is
not
None
:
logger
.
info
(
"Loading chat template fallback for %s as there isn't one "
logger
.
info
(
"defined on HF Hub."
,
tokenizer
.
name_or_path
)
"Loading chat template fallback for %s as there isn't one "
"defined on HF Hub."
,
tokenizer
.
name_or_path
,
)
chat_template
=
load_chat_template
(
path
)
chat_template
=
load_chat_template
(
path
)
else
:
else
:
logger
.
debug
(
"There is no chat template fallback for %s"
,
logger
.
debug
(
tokenizer
.
name_or_path
)
"There is no chat template fallback for %s"
,
tokenizer
.
name_or_path
)
return
chat_template
return
chat_template
...
@@ -452,11 +483,17 @@ def _resolve_chat_template_content_format(
...
@@ -452,11 +483,17 @@ def _resolve_chat_template_content_format(
else
:
else
:
hf_chat_template
=
None
hf_chat_template
=
None
jinja_text
=
(
hf_chat_template
if
isinstance
(
hf_chat_template
,
str
)
jinja_text
=
(
else
load_chat_template
(
chat_template
,
is_literal
=
True
))
hf_chat_template
if
isinstance
(
hf_chat_template
,
str
)
else
load_chat_template
(
chat_template
,
is_literal
=
True
)
)
detected_format
=
(
"string"
if
jinja_text
is
None
else
detected_format
=
(
_detect_content_format
(
jinja_text
,
default
=
"string"
))
"string"
if
jinja_text
is
None
else
_detect_content_format
(
jinja_text
,
default
=
"string"
)
)
return
detected_format
return
detected_format
...
@@ -512,7 +549,6 @@ def resolve_chat_template_content_format(
...
@@ -512,7 +549,6 @@ def resolve_chat_template_content_format(
return
detected_format
return
detected_format
ModalityStr
=
Literal
[
"image"
,
"audio"
,
"video"
,
"image_embeds"
]
ModalityStr
=
Literal
[
"image"
,
"audio"
,
"video"
,
"image_embeds"
]
_T
=
TypeVar
(
"_T"
)
_T
=
TypeVar
(
"_T"
)
...
@@ -539,6 +575,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
...
@@ -539,6 +575,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
@
cached_property
@
cached_property
def
model_cls
(
self
)
->
type
[
SupportsMultiModal
]:
def
model_cls
(
self
)
->
type
[
SupportsMultiModal
]:
from
vllm.model_executor.model_loader
import
get_model_cls
from
vllm.model_executor.model_loader
import
get_model_cls
model_cls
=
get_model_cls
(
self
.
model_config
)
model_cls
=
get_model_cls
(
self
.
model_config
)
return
cast
(
type
[
SupportsMultiModal
],
model_cls
)
return
cast
(
type
[
SupportsMultiModal
],
model_cls
)
...
@@ -574,21 +611,22 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
...
@@ -574,21 +611,22 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
class
MultiModalItemTracker
(
BaseMultiModalItemTracker
[
object
]):
class
MultiModalItemTracker
(
BaseMultiModalItemTracker
[
object
]):
def
all_mm_data
(
self
)
->
Optional
[
MultiModalDataDict
]:
def
all_mm_data
(
self
)
->
Optional
[
MultiModalDataDict
]:
if
not
self
.
_items_by_modality
:
if
not
self
.
_items_by_modality
:
return
None
return
None
mm_inputs
=
{}
mm_inputs
=
{}
items_by_modality
=
dict
(
self
.
_items_by_modality
)
items_by_modality
=
dict
(
self
.
_items_by_modality
)
if
"image"
in
items_by_modality
and
"image_embeds"
in
items_by_modality
:
if
"image"
in
items_by_modality
and
"image_embeds"
in
items_by_modality
:
raise
ValueError
(
\
raise
ValueError
(
"Mixing raw image and embedding inputs is not allowed"
)
"Mixing raw image and embedding inputs is not allowed"
)
if
"image_embeds"
in
items_by_modality
:
if
"image_embeds"
in
items_by_modality
:
image_embeds_lst
=
items_by_modality
[
"image_embeds"
]
image_embeds_lst
=
items_by_modality
[
"image_embeds"
]
if
len
(
image_embeds_lst
)
>
1
:
if
len
(
image_embeds_lst
)
>
1
:
raise
ValueError
(
\
raise
ValueError
(
"Only one message can have {'type': 'image_embeds'}"
)
"Only one message can have {'type': 'image_embeds'}"
)
mm_inputs
[
"image"
]
=
image_embeds_lst
[
0
]
mm_inputs
[
"image"
]
=
image_embeds_lst
[
0
]
if
"image"
in
items_by_modality
:
if
"image"
in
items_by_modality
:
mm_inputs
[
"image"
]
=
items_by_modality
[
"image"
]
# A list of images
mm_inputs
[
"image"
]
=
items_by_modality
[
"image"
]
# A list of images
...
@@ -603,7 +641,6 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
...
@@ -603,7 +641,6 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
class
AsyncMultiModalItemTracker
(
BaseMultiModalItemTracker
[
Awaitable
[
object
]]):
class
AsyncMultiModalItemTracker
(
BaseMultiModalItemTracker
[
Awaitable
[
object
]]):
async
def
all_mm_data
(
self
)
->
Optional
[
MultiModalDataDict
]:
async
def
all_mm_data
(
self
)
->
Optional
[
MultiModalDataDict
]:
if
not
self
.
_items_by_modality
:
if
not
self
.
_items_by_modality
:
return
None
return
None
...
@@ -615,13 +652,15 @@ class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
...
@@ -615,13 +652,15 @@ class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
if
"image"
in
items_by_modality
and
"image_embeds"
in
items_by_modality
:
if
"image"
in
items_by_modality
and
"image_embeds"
in
items_by_modality
:
raise
ValueError
(
raise
ValueError
(
"Mixing raw image and embedding inputs is not allowed"
)
"Mixing raw image and embedding inputs is not allowed"
)
if
"image_embeds"
in
items_by_modality
:
if
"image_embeds"
in
items_by_modality
:
image_embeds_lst
=
items_by_modality
[
"image_embeds"
]
image_embeds_lst
=
items_by_modality
[
"image_embeds"
]
if
len
(
image_embeds_lst
)
>
1
:
if
len
(
image_embeds_lst
)
>
1
:
raise
ValueError
(
raise
ValueError
(
"Only one message can have {'type': 'image_embeds'}"
)
"Only one message can have {'type': 'image_embeds'}"
)
mm_inputs
[
"image"
]
=
image_embeds_lst
[
0
]
mm_inputs
[
"image"
]
=
image_embeds_lst
[
0
]
if
"image"
in
items_by_modality
:
if
"image"
in
items_by_modality
:
mm_inputs
[
"image"
]
=
items_by_modality
[
"image"
]
# A list of images
mm_inputs
[
"image"
]
=
items_by_modality
[
"image"
]
# A list of images
...
@@ -636,7 +675,6 @@ class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
...
@@ -636,7 +675,6 @@ class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
class
BaseMultiModalContentParser
(
ABC
):
class
BaseMultiModalContentParser
(
ABC
):
def
__init__
(
self
)
->
None
:
def
__init__
(
self
)
->
None
:
super
().
__init__
()
super
().
__init__
()
...
@@ -648,8 +686,9 @@ class BaseMultiModalContentParser(ABC):
...
@@ -648,8 +686,9 @@ class BaseMultiModalContentParser(ABC):
# }
# }
self
.
_placeholder_storage
:
dict
[
str
,
list
]
=
defaultdict
(
list
)
self
.
_placeholder_storage
:
dict
[
str
,
list
]
=
defaultdict
(
list
)
def
_add_placeholder
(
self
,
modality
:
ModalityStr
,
def
_add_placeholder
(
placeholder
:
Optional
[
str
]):
self
,
modality
:
ModalityStr
,
placeholder
:
Optional
[
str
]
):
mod_placeholder
=
MODALITY_PLACEHOLDERS_MAP
[
modality
]
mod_placeholder
=
MODALITY_PLACEHOLDERS_MAP
[
modality
]
if
placeholder
:
if
placeholder
:
self
.
_placeholder_storage
[
mod_placeholder
].
append
(
placeholder
)
self
.
_placeholder_storage
[
mod_placeholder
].
append
(
placeholder
)
...
@@ -662,8 +701,9 @@ class BaseMultiModalContentParser(ABC):
...
@@ -662,8 +701,9 @@ class BaseMultiModalContentParser(ABC):
raise
NotImplementedError
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
parse_image_embeds
(
self
,
def
parse_image_embeds
(
image_embeds
:
Union
[
str
,
dict
[
str
,
str
]])
->
None
:
self
,
image_embeds
:
Union
[
str
,
dict
[
str
,
str
]]
)
->
None
:
raise
NotImplementedError
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
...
@@ -684,7 +724,6 @@ class BaseMultiModalContentParser(ABC):
...
@@ -684,7 +724,6 @@ class BaseMultiModalContentParser(ABC):
class
MultiModalContentParser
(
BaseMultiModalContentParser
):
class
MultiModalContentParser
(
BaseMultiModalContentParser
):
def
__init__
(
self
,
tracker
:
MultiModalItemTracker
)
->
None
:
def
__init__
(
self
,
tracker
:
MultiModalItemTracker
)
->
None
:
super
().
__init__
()
super
().
__init__
()
...
@@ -701,8 +740,9 @@ class MultiModalContentParser(BaseMultiModalContentParser):
...
@@ -701,8 +740,9 @@ class MultiModalContentParser(BaseMultiModalContentParser):
placeholder
=
self
.
_tracker
.
add
(
"image"
,
image
)
placeholder
=
self
.
_tracker
.
add
(
"image"
,
image
)
self
.
_add_placeholder
(
"image"
,
placeholder
)
self
.
_add_placeholder
(
"image"
,
placeholder
)
def
parse_image_embeds
(
self
,
def
parse_image_embeds
(
image_embeds
:
Union
[
str
,
dict
[
str
,
str
]])
->
None
:
self
,
image_embeds
:
Union
[
str
,
dict
[
str
,
str
]]
)
->
None
:
if
isinstance
(
image_embeds
,
dict
):
if
isinstance
(
image_embeds
,
dict
):
embeds
=
{
embeds
=
{
k
:
self
.
_connector
.
fetch_image_embedding
(
v
)
k
:
self
.
_connector
.
fetch_image_embedding
(
v
)
...
@@ -741,14 +781,13 @@ class MultiModalContentParser(BaseMultiModalContentParser):
...
@@ -741,14 +781,13 @@ class MultiModalContentParser(BaseMultiModalContentParser):
class
AsyncMultiModalContentParser
(
BaseMultiModalContentParser
):
class
AsyncMultiModalContentParser
(
BaseMultiModalContentParser
):
def
__init__
(
self
,
tracker
:
AsyncMultiModalItemTracker
)
->
None
:
def
__init__
(
self
,
tracker
:
AsyncMultiModalItemTracker
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
_tracker
=
tracker
self
.
_tracker
=
tracker
self
.
_connector
=
MediaConnector
(
self
.
_connector
=
MediaConnector
(
media_io_kwargs
=
self
.
_tracker
.
_model_config
.
media_io_kwargs
,
media_io_kwargs
=
self
.
_tracker
.
_model_config
.
media_io_kwargs
,
allowed_local_media_path
=
tracker
.
allowed_local_media_path
allowed_local_media_path
=
tracker
.
allowed_local_media_path
,
)
)
def
parse_image
(
self
,
image_url
:
str
)
->
None
:
def
parse_image
(
self
,
image_url
:
str
)
->
None
:
...
@@ -757,8 +796,9 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
...
@@ -757,8 +796,9 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
placeholder
=
self
.
_tracker
.
add
(
"image"
,
image_coro
)
placeholder
=
self
.
_tracker
.
add
(
"image"
,
image_coro
)
self
.
_add_placeholder
(
"image"
,
placeholder
)
self
.
_add_placeholder
(
"image"
,
placeholder
)
def
parse_image_embeds
(
self
,
def
parse_image_embeds
(
image_embeds
:
Union
[
str
,
dict
[
str
,
str
]])
->
None
:
self
,
image_embeds
:
Union
[
str
,
dict
[
str
,
str
]]
)
->
None
:
future
:
asyncio
.
Future
[
Union
[
str
,
dict
[
str
,
str
]]]
=
asyncio
.
Future
()
future
:
asyncio
.
Future
[
Union
[
str
,
dict
[
str
,
str
]]]
=
asyncio
.
Future
()
if
isinstance
(
image_embeds
,
dict
):
if
isinstance
(
image_embeds
,
dict
):
...
@@ -769,8 +809,7 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
...
@@ -769,8 +809,7 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
future
.
set_result
(
embeds
)
future
.
set_result
(
embeds
)
if
isinstance
(
image_embeds
,
str
):
if
isinstance
(
image_embeds
,
str
):
embedding
=
self
.
_connector
.
\
embedding
=
self
.
_connector
.
fetch_image_embedding
(
image_embeds
)
fetch_image_embedding
(
image_embeds
)
future
.
set_result
(
embedding
)
future
.
set_result
(
embedding
)
placeholder
=
self
.
_tracker
.
add
(
"image_embeds"
,
future
)
placeholder
=
self
.
_tracker
.
add
(
"image_embeds"
,
future
)
...
@@ -809,20 +848,23 @@ def validate_chat_template(chat_template: Optional[Union[Path, str]]):
...
@@ -809,20 +848,23 @@ def validate_chat_template(chat_template: Optional[Union[Path, str]]):
return
return
elif
isinstance
(
chat_template
,
Path
)
and
not
chat_template
.
exists
():
elif
isinstance
(
chat_template
,
Path
)
and
not
chat_template
.
exists
():
raise
FileNotFoundError
(
raise
FileNotFoundError
(
"the supplied chat template path doesn't exist"
)
"the supplied chat template path doesn't exist"
)
elif
isinstance
(
chat_template
,
str
):
elif
isinstance
(
chat_template
,
str
):
JINJA_CHARS
=
"{}
\n
"
JINJA_CHARS
=
"{}
\n
"
if
not
any
(
c
in
chat_template
if
(
for
c
in
JINJA_CHARS
)
and
not
Path
(
chat_template
).
exists
():
not
any
(
c
in
chat_template
for
c
in
JINJA_CHARS
)
and
not
Path
(
chat_template
).
exists
()
):
raise
ValueError
(
raise
ValueError
(
f
"The supplied chat template string (
{
chat_template
}
) "
f
"The supplied chat template string (
{
chat_template
}
) "
f
"appears path-like, but doesn't exist!"
)
f
"appears path-like, but doesn't exist!"
)
else
:
else
:
raise
TypeError
(
raise
TypeError
(
f
"
{
type
(
chat_template
)
}
is not a valid chat template type"
)
f
"
{
type
(
chat_template
)
}
is not a valid chat template type"
)
def
_load_chat_template
(
def
_load_chat_template
(
...
@@ -835,8 +877,9 @@ def _load_chat_template(
...
@@ -835,8 +877,9 @@ def _load_chat_template(
if
is_literal
:
if
is_literal
:
if
isinstance
(
chat_template
,
Path
):
if
isinstance
(
chat_template
,
Path
):
raise
TypeError
(
"chat_template is expected to be read directly "
raise
TypeError
(
"from its value"
)
"chat_template is expected to be read directly from its value"
)
return
chat_template
return
chat_template
...
@@ -849,9 +892,11 @@ def _load_chat_template(
...
@@ -849,9 +892,11 @@ def _load_chat_template(
JINJA_CHARS
=
"{}
\n
"
JINJA_CHARS
=
"{}
\n
"
if
not
any
(
c
in
chat_template
for
c
in
JINJA_CHARS
):
if
not
any
(
c
in
chat_template
for
c
in
JINJA_CHARS
):
msg
=
(
f
"The supplied chat template (
{
chat_template
}
) "
msg
=
(
f
"The supplied chat template (
{
chat_template
}
) "
f
"looks like a file path, but it failed to be "
f
"looks like a file path, but it failed to be "
f
"opened. Reason:
{
e
}
"
)
f
"opened. Reason:
{
e
}
"
)
raise
ValueError
(
msg
)
from
e
raise
ValueError
(
msg
)
from
e
# If opening a file fails, set chat template to be args to
# If opening a file fails, set chat template to be args to
...
@@ -870,8 +915,9 @@ def load_chat_template(
...
@@ -870,8 +915,9 @@ def load_chat_template(
return
_cached_load_chat_template
(
chat_template
,
is_literal
=
is_literal
)
return
_cached_load_chat_template
(
chat_template
,
is_literal
=
is_literal
)
def
_get_interleaved_text_prompt
(
placeholder_storage
:
dict
[
str
,
list
],
def
_get_interleaved_text_prompt
(
texts
:
list
[
str
])
->
str
:
placeholder_storage
:
dict
[
str
,
list
],
texts
:
list
[
str
]
)
->
str
:
for
idx
,
elem
in
enumerate
(
texts
):
for
idx
,
elem
in
enumerate
(
texts
):
if
elem
in
placeholder_storage
:
if
elem
in
placeholder_storage
:
texts
[
idx
]
=
placeholder_storage
[
elem
].
pop
(
0
)
texts
[
idx
]
=
placeholder_storage
[
elem
].
pop
(
0
)
...
@@ -881,10 +927,11 @@ def _get_interleaved_text_prompt(placeholder_storage: dict[str, list],
...
@@ -881,10 +927,11 @@ def _get_interleaved_text_prompt(placeholder_storage: dict[str, list],
# TODO: Let user specify how to insert multimodal tokens into prompt
# TODO: Let user specify how to insert multimodal tokens into prompt
# (similar to chat template)
# (similar to chat template)
def
_get_full_multimodal_text_prompt
(
placeholder_storage
:
dict
[
str
,
list
],
def
_get_full_multimodal_text_prompt
(
placeholder_storage
:
dict
[
str
,
list
],
texts
:
list
[
str
],
texts
:
list
[
str
],
interleave_strings
:
bool
interleave_strings
:
bool
,
)
->
str
:
)
->
str
:
"""Combine multimodal prompts for a multimodal language model."""
"""Combine multimodal prompts for a multimodal language model."""
# flatten storage to make it looks like
# flatten storage to make it looks like
...
@@ -907,7 +954,6 @@ def _get_full_multimodal_text_prompt(placeholder_storage: dict[str, list],
...
@@ -907,7 +954,6 @@ def _get_full_multimodal_text_prompt(placeholder_storage: dict[str, list],
# Look through the text prompt to check for missing placeholders
# Look through the text prompt to check for missing placeholders
missing_placeholders
:
list
[
str
]
=
[]
missing_placeholders
:
list
[
str
]
=
[]
for
placeholder
in
placeholder_counts
:
for
placeholder
in
placeholder_counts
:
# For any existing placeholder in the text prompt, we leave it as is
# For any existing placeholder in the text prompt, we leave it as is
placeholder_counts
[
placeholder
]
-=
text_prompt
.
count
(
placeholder
)
placeholder_counts
[
placeholder
]
-=
text_prompt
.
count
(
placeholder
)
...
@@ -916,15 +962,18 @@ def _get_full_multimodal_text_prompt(placeholder_storage: dict[str, list],
...
@@ -916,15 +962,18 @@ def _get_full_multimodal_text_prompt(placeholder_storage: dict[str, list],
"Placeholder count is negative! "
"Placeholder count is negative! "
"Ensure that the 'interleave_strings' flag is disabled "
"Ensure that the 'interleave_strings' flag is disabled "
"(current value: %s) "
"(current value: %s) "
"when manually placing image placeholders."
,
interleave_strings
"when manually placing image placeholders."
,
interleave_strings
,
)
)
logger
.
debug
(
"Input prompt: %s"
,
text_prompt
)
logger
.
debug
(
"Input prompt: %s"
,
text_prompt
)
raise
ValueError
(
raise
ValueError
(
f
"Found more '
{
placeholder
}
' placeholders in input prompt than "
f
"Found more '
{
placeholder
}
' placeholders in input prompt than "
"actual multimodal data items."
)
"actual multimodal data items."
)
missing_placeholders
.
extend
([
placeholder
]
*
missing_placeholders
.
extend
(
placeholder_counts
[
placeholder
])
[
placeholder
]
*
placeholder_counts
[
placeholder
]
)
# NOTE: Default behaviour: we always add missing placeholders
# NOTE: Default behaviour: we always add missing placeholders
# at the front of the prompt, if interleave_strings=False
# at the front of the prompt, if interleave_strings=False
...
@@ -944,7 +993,8 @@ _AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python
...
@@ -944,7 +993,8 @@ _AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python
_VideoParser
=
TypeAdapter
(
ChatCompletionContentPartVideoParam
).
validate_python
_VideoParser
=
TypeAdapter
(
ChatCompletionContentPartVideoParam
).
validate_python
_ResponsesInputImageParser
=
TypeAdapter
(
_ResponsesInputImageParser
=
TypeAdapter
(
ResponseInputImageParam
).
validate_python
ResponseInputImageParam
).
validate_python
_ContentPart
:
TypeAlias
=
Union
[
str
,
dict
[
str
,
str
],
InputAudio
,
PILImage
]
_ContentPart
:
TypeAlias
=
Union
[
str
,
dict
[
str
,
str
],
InputAudio
,
PILImage
]
# Define a mapping from part types to their corresponding parsing functions.
# Define a mapping from part types to their corresponding parsing functions.
...
@@ -952,32 +1002,35 @@ MM_PARSER_MAP: dict[
...
@@ -952,32 +1002,35 @@ MM_PARSER_MAP: dict[
str
,
str
,
Callable
[[
ChatCompletionContentPartParam
],
_ContentPart
],
Callable
[[
ChatCompletionContentPartParam
],
_ContentPart
],
]
=
{
]
=
{
"text"
:
"text"
:
lambda
part
:
_TextParser
(
part
).
get
(
"text"
,
None
),
lambda
part
:
_T
ext
Parser
(
part
).
get
(
"t
ext
"
,
None
),
"thinking"
:
lambda
part
:
_T
hink
Parser
(
part
).
get
(
"t
hinking
"
,
None
),
"
thinking"
:
"
input_text"
:
lambda
part
:
_TextParser
(
part
).
get
(
"text"
,
None
),
lambda
part
:
_ThinkParser
(
part
).
get
(
"thinking"
,
None
),
"input_image"
:
lambda
part
:
_ResponsesInputImageParser
(
part
).
get
(
"i
nput_text"
:
"i
mage_url"
,
None
lambda
part
:
_TextParser
(
part
).
get
(
"text"
,
None
),
),
"i
nput_image"
:
"i
mage_url"
:
lambda
part
:
_ImageParser
(
part
)
lambda
part
:
_ResponsesInputImageParser
(
part
)
.
get
(
"image_url"
,
None
),
.
get
(
"image_url"
,
{})
"image_url"
:
.
get
(
"url"
,
None
),
lambda
part
:
_ImageParser
(
part
).
get
(
"image_url"
,
{}).
get
(
"url"
,
None
),
"image_embeds"
:
lambda
part
:
_Image
Embeds
Parser
(
part
).
get
(
"image_embeds"
:
"image_embeds"
,
None
lambda
part
:
_ImageEmbedsParser
(
part
).
get
(
"image_embeds"
,
None
),
),
"image_pil"
:
lambda
part
:
_PILImageParser
(
part
).
get
(
"image_pil"
,
None
),
"image_pil"
:
lambda
part
:
_PILImageParser
(
part
).
get
(
"image_pil"
,
None
),
"audio_url"
:
"audio_url"
:
lambda
part
:
_AudioParser
(
part
)
lambda
part
:
_AudioParser
(
part
).
get
(
"audio_url"
,
{}).
get
(
"url"
,
None
),
.
get
(
"audio_url"
,
{})
"input_audio"
:
.
get
(
"url"
,
None
),
lambda
part
:
_InputAudioParser
(
part
).
get
(
"input_audio"
,
None
),
"input_audio"
:
lambda
part
:
_InputAudioParser
(
part
).
get
(
"refusal"
:
"input_audio"
,
None
lambda
part
:
_RefusalParser
(
part
).
get
(
"refusal"
,
None
),
),
"video_url"
:
"refusal"
:
lambda
part
:
_RefusalParser
(
part
).
get
(
"refusal"
,
None
),
lambda
part
:
_VideoParser
(
part
).
get
(
"video_url"
,
{}).
get
(
"url"
,
None
),
"video_url"
:
lambda
part
:
_VideoParser
(
part
)
.
get
(
"video_url"
,
{})
.
get
(
"url"
,
None
),
}
}
def
_parse_chat_message_content_mm_part
(
def
_parse_chat_message_content_mm_part
(
part
:
ChatCompletionContentPartParam
)
->
tuple
[
str
,
_ContentPart
]:
part
:
ChatCompletionContentPartParam
,
)
->
tuple
[
str
,
_ContentPart
]:
"""
"""
Parses a given multi-modal content part based on its type.
Parses a given multi-modal content part based on its type.
...
@@ -993,7 +1046,8 @@ def _parse_chat_message_content_mm_part(
...
@@ -993,7 +1046,8 @@ def _parse_chat_message_content_mm_part(
ValueError: If the 'type' field is missing and no direct URL is found.
ValueError: If the 'type' field is missing and no direct URL is found.
"""
"""
assert
isinstance
(
assert
isinstance
(
part
,
dict
)
# This is needed to avoid mypy errors: part.get() from str
part
,
dict
)
# This is needed to avoid mypy errors: part.get() from str
part_type
=
part
.
get
(
"type"
,
None
)
part_type
=
part
.
get
(
"type"
,
None
)
if
isinstance
(
part_type
,
str
)
and
part_type
in
MM_PARSER_MAP
:
if
isinstance
(
part_type
,
str
)
and
part_type
in
MM_PARSER_MAP
:
...
@@ -1002,8 +1056,10 @@ def _parse_chat_message_content_mm_part(
...
@@ -1002,8 +1056,10 @@ def _parse_chat_message_content_mm_part(
# Special case for 'image_url.detail'
# Special case for 'image_url.detail'
# We only support 'auto', which is the default
# We only support 'auto', which is the default
if
part_type
==
"image_url"
and
part
.
get
(
"detail"
,
"auto"
)
!=
"auto"
:
if
part_type
==
"image_url"
and
part
.
get
(
"detail"
,
"auto"
)
!=
"auto"
:
logger
.
warning
(
"'image_url.detail' is currently not supported "
logger
.
warning
(
"and will be ignored."
)
"'image_url.detail' is currently not supported "
"and will be ignored."
)
return
part_type
,
content
return
part_type
,
content
...
@@ -1011,19 +1067,22 @@ def _parse_chat_message_content_mm_part(
...
@@ -1011,19 +1067,22 @@ def _parse_chat_message_content_mm_part(
# 'type' is required field by pydantic
# 'type' is required field by pydantic
if
part_type
is
None
:
if
part_type
is
None
:
if
part
.
get
(
"image_url"
)
is
not
None
:
if
part
.
get
(
"image_url"
)
is
not
None
:
image_params
=
cast
(
CustomChatCompletionContentSimpleImageParam
,
image_params
=
cast
(
part
)
CustomChatCompletionContentSimpleImageParam
,
part
)
return
"image_url"
,
image_params
.
get
(
"image_url"
,
""
)
return
"image_url"
,
image_params
.
get
(
"image_url"
,
""
)
if
part
.
get
(
"audio_url"
)
is
not
None
:
if
part
.
get
(
"audio_url"
)
is
not
None
:
audio_params
=
cast
(
CustomChatCompletionContentSimpleAudioParam
,
audio_params
=
cast
(
part
)
CustomChatCompletionContentSimpleAudioParam
,
part
)
return
"audio_url"
,
audio_params
.
get
(
"audio_url"
,
""
)
return
"audio_url"
,
audio_params
.
get
(
"audio_url"
,
""
)
if
part
.
get
(
"input_audio"
)
is
not
None
:
if
part
.
get
(
"input_audio"
)
is
not
None
:
input_audio_params
=
cast
(
dict
[
str
,
str
],
part
)
input_audio_params
=
cast
(
dict
[
str
,
str
],
part
)
return
"input_audio"
,
input_audio_params
return
"input_audio"
,
input_audio_params
if
part
.
get
(
"video_url"
)
is
not
None
:
if
part
.
get
(
"video_url"
)
is
not
None
:
video_params
=
cast
(
CustomChatCompletionContentSimpleVideoParam
,
video_params
=
cast
(
part
)
CustomChatCompletionContentSimpleVideoParam
,
part
)
return
"video_url"
,
video_params
.
get
(
"video_url"
,
""
)
return
"video_url"
,
video_params
.
get
(
"video_url"
,
""
)
# Raise an error if no 'type' or direct URL is found.
# Raise an error if no 'type' or direct URL is found.
raise
ValueError
(
"Missing 'type' field in multimodal part."
)
raise
ValueError
(
"Missing 'type' field in multimodal part."
)
...
@@ -1033,9 +1092,16 @@ def _parse_chat_message_content_mm_part(
...
@@ -1033,9 +1092,16 @@ def _parse_chat_message_content_mm_part(
return
part_type
,
"unknown part_type content"
return
part_type
,
"unknown part_type content"
VALID_MESSAGE_CONTENT_MM_PART_TYPES
=
(
"text"
,
"refusal"
,
"image_url"
,
VALID_MESSAGE_CONTENT_MM_PART_TYPES
=
(
"image_embeds"
,
"image_pil"
,
"text"
,
"audio_url"
,
"input_audio"
,
"video_url"
)
"refusal"
,
"image_url"
,
"image_embeds"
,
"image_pil"
,
"audio_url"
,
"input_audio"
,
"video_url"
,
)
def
_parse_chat_message_content_parts
(
def
_parse_chat_message_content_parts
(
...
@@ -1055,21 +1121,20 @@ def _parse_chat_message_content_parts(
...
@@ -1055,21 +1121,20 @@ def _parse_chat_message_content_parts(
part
,
part
,
mm_parser
,
mm_parser
,
wrap_dicts
=
wrap_dicts
,
wrap_dicts
=
wrap_dicts
,
interleave_strings
=
interleave_strings
interleave_strings
=
interleave_strings
,
)
)
if
parse_res
:
if
parse_res
:
content
.
append
(
parse_res
)
content
.
append
(
parse_res
)
if
wrap_dicts
:
if
wrap_dicts
:
# Parsing wraps images and texts as interleaved dictionaries
# Parsing wraps images and texts as interleaved dictionaries
return
[
ConversationMessage
(
role
=
role
,
return
[
ConversationMessage
(
role
=
role
,
content
=
content
)]
# type: ignore
content
=
content
)]
# type: ignore
texts
=
cast
(
list
[
str
],
content
)
texts
=
cast
(
list
[
str
],
content
)
mm_placeholder_storage
=
mm_parser
.
mm_placeholder_storage
()
mm_placeholder_storage
=
mm_parser
.
mm_placeholder_storage
()
if
mm_placeholder_storage
:
if
mm_placeholder_storage
:
text_prompt
=
_get_full_multimodal_text_prompt
(
mm_placeholder_storage
,
text_prompt
=
_get_full_multimodal_text_prompt
(
texts
,
mm_placeholder_storage
,
texts
,
interleave_strings
interleave_strings
)
)
else
:
else
:
text_prompt
=
"
\n
"
.
join
(
texts
)
text_prompt
=
"
\n
"
.
join
(
texts
)
...
@@ -1099,13 +1164,16 @@ def _parse_chat_message_content_part(
...
@@ -1099,13 +1164,16 @@ def _parse_chat_message_content_part(
if
part_type
in
VALID_MESSAGE_CONTENT_MM_PART_TYPES
and
content
is
None
:
if
part_type
in
VALID_MESSAGE_CONTENT_MM_PART_TYPES
and
content
is
None
:
logger
.
warning
(
logger
.
warning
(
"Skipping multimodal part '%s' (type: '%s') "
"Skipping multimodal part '%s' (type: '%s') "
"with empty / unparsable content."
,
part
,
part_type
)
"with empty / unparsable content."
,
part
,
part_type
,
)
return
None
return
None
if
part_type
in
(
"text"
,
"input_text"
,
"refusal"
,
"thinking"
):
if
part_type
in
(
"text"
,
"input_text"
,
"refusal"
,
"thinking"
):
str_content
=
cast
(
str
,
content
)
str_content
=
cast
(
str
,
content
)
if
wrap_dicts
:
if
wrap_dicts
:
return
{
'
type
'
:
'
text
'
,
'
text
'
:
str_content
}
return
{
"
type
"
:
"
text
"
,
"
text
"
:
str_content
}
else
:
else
:
return
str_content
return
str_content
...
@@ -1137,9 +1205,13 @@ def _parse_chat_message_content_part(
...
@@ -1137,9 +1205,13 @@ def _parse_chat_message_content_part(
else
:
else
:
raise
NotImplementedError
(
f
"Unknown part type:
{
part_type
}
"
)
raise
NotImplementedError
(
f
"Unknown part type:
{
part_type
}
"
)
return
{
'type'
:
modality
}
if
wrap_dicts
else
(
return
(
{
"type"
:
modality
}
if
wrap_dicts
else
(
MODALITY_PLACEHOLDERS_MAP
[
modality
]
if
interleave_strings
else
None
MODALITY_PLACEHOLDERS_MAP
[
modality
]
if
interleave_strings
else
None
)
)
)
# No need to validate using Pydantic again
# No need to validate using Pydantic again
...
@@ -1171,14 +1243,16 @@ def _parse_chat_message_content(
...
@@ -1171,14 +1243,16 @@ def _parse_chat_message_content(
)
)
for
result_msg
in
result
:
for
result_msg
in
result
:
if
role
==
'
assistant
'
:
if
role
==
"
assistant
"
:
parsed_msg
=
_AssistantParser
(
message
)
parsed_msg
=
_AssistantParser
(
message
)
# The 'tool_calls' is not None check ensures compatibility.
# The 'tool_calls' is not None check ensures compatibility.
# It's needed only if downstream code doesn't strictly
# It's needed only if downstream code doesn't strictly
# follow the OpenAI spec.
# follow the OpenAI spec.
if
(
"tool_calls"
in
parsed_msg
if
(
and
parsed_msg
[
"tool_calls"
]
is
not
None
):
"tool_calls"
in
parsed_msg
and
parsed_msg
[
"tool_calls"
]
is
not
None
):
result_msg
[
"tool_calls"
]
=
list
(
parsed_msg
[
"tool_calls"
])
result_msg
[
"tool_calls"
]
=
list
(
parsed_msg
[
"tool_calls"
])
elif
role
==
"tool"
:
elif
role
==
"tool"
:
parsed_msg
=
_ToolParser
(
message
)
parsed_msg
=
_ToolParser
(
message
)
...
@@ -1198,12 +1272,15 @@ def _postprocess_messages(messages: list[ConversationMessage]) -> None:
...
@@ -1198,12 +1272,15 @@ def _postprocess_messages(messages: list[ConversationMessage]) -> None:
# so, for messages that have tool_calls, parse the string (which we get
# so, for messages that have tool_calls, parse the string (which we get
# from openAI format) to dict
# from openAI format) to dict
for
message
in
messages
:
for
message
in
messages
:
if
(
message
[
"role"
]
==
"assistant"
and
"tool_calls"
in
message
if
(
and
isinstance
(
message
[
"tool_calls"
],
list
)):
message
[
"role"
]
==
"assistant"
and
"tool_calls"
in
message
and
isinstance
(
message
[
"tool_calls"
],
list
)
):
for
item
in
message
[
"tool_calls"
]:
for
item
in
message
[
"tool_calls"
]:
item
[
"function"
][
"arguments"
]
=
json
.
loads
(
item
[
"function"
][
"arguments"
]
=
json
.
loads
(
item
[
"function"
][
"arguments"
])
item
[
"function"
][
"arguments"
]
)
def
parse_chat_messages
(
def
parse_chat_messages
(
...
@@ -1224,7 +1301,7 @@ def parse_chat_messages(
...
@@ -1224,7 +1301,7 @@ def parse_chat_messages(
content_format
==
"string"
content_format
==
"string"
and
model_config
.
multimodal_config
is
not
None
and
model_config
.
multimodal_config
is
not
None
and
model_config
.
multimodal_config
.
interleave_mm_strings
and
model_config
.
multimodal_config
.
interleave_mm_strings
)
)
,
)
)
conversation
.
extend
(
sub_messages
)
conversation
.
extend
(
sub_messages
)
...
@@ -1252,7 +1329,7 @@ def parse_chat_messages_futures(
...
@@ -1252,7 +1329,7 @@ def parse_chat_messages_futures(
content_format
==
"string"
content_format
==
"string"
and
model_config
.
multimodal_config
is
not
None
and
model_config
.
multimodal_config
is
not
None
and
model_config
.
multimodal_config
.
interleave_mm_strings
and
model_config
.
multimodal_config
.
interleave_mm_strings
)
)
,
)
)
conversation
.
extend
(
sub_messages
)
conversation
.
extend
(
sub_messages
)
...
@@ -1283,10 +1360,10 @@ def apply_hf_chat_template(
...
@@ -1283,10 +1360,10 @@ def apply_hf_chat_template(
raise
ValueError
(
raise
ValueError
(
"As of transformers v4.44, default chat template is no longer "
"As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer "
"allowed, so you must provide a chat template if the tokenizer "
"does not define one."
)
"does not define one."
)
try
:
try
:
return
tokenizer
.
apply_chat_template
(
return
tokenizer
.
apply_chat_template
(
conversation
=
conversation
,
# type: ignore[arg-type]
conversation
=
conversation
,
# type: ignore[arg-type]
tools
=
tools
,
# type: ignore[arg-type]
tools
=
tools
,
# type: ignore[arg-type]
...
@@ -1298,13 +1375,14 @@ def apply_hf_chat_template(
...
@@ -1298,13 +1375,14 @@ def apply_hf_chat_template(
# External library exceptions can sometimes occur despite the framework's
# External library exceptions can sometimes occur despite the framework's
# internal exception management capabilities.
# internal exception management capabilities.
except
Exception
as
e
:
except
Exception
as
e
:
# Log and report any library-related exceptions for further
# Log and report any library-related exceptions for further
# investigation.
# investigation.
logger
.
exception
(
logger
.
exception
(
"An error occurred in `transformers` while applying chat template"
)
"An error occurred in `transformers` while applying chat template"
)
raise
ValueError
(
str
(
e
))
from
e
raise
ValueError
(
str
(
e
))
from
e
def
apply_mistral_chat_template
(
def
apply_mistral_chat_template
(
tokenizer
:
MistralTokenizer
,
tokenizer
:
MistralTokenizer
,
messages
:
list
[
ChatCompletionMessageParam
],
messages
:
list
[
ChatCompletionMessageParam
],
...
@@ -1337,26 +1415,26 @@ def apply_mistral_chat_template(
...
@@ -1337,26 +1415,26 @@ def apply_mistral_chat_template(
# External library exceptions can sometimes occur despite the framework's
# External library exceptions can sometimes occur despite the framework's
# internal exception management capabilities.
# internal exception management capabilities.
except
Exception
as
e
:
except
Exception
as
e
:
# Log and report any library-related exceptions for further
# Log and report any library-related exceptions for further
# investigation.
# investigation.
logger
.
exception
(
logger
.
exception
(
"An error occurred in `mistral_common` while applying chat "
"An error occurred in `mistral_common` while applying chat
template
"
"template"
)
)
raise
ValueError
(
str
(
e
))
from
e
raise
ValueError
(
str
(
e
))
from
e
def
get_history_tool_calls_cnt
(
conversation
:
list
[
ConversationMessage
]):
def
get_history_tool_calls_cnt
(
conversation
:
list
[
ConversationMessage
]):
idx
=
0
idx
=
0
for
msg
in
conversation
:
for
msg
in
conversation
:
if
msg
[
'
role
'
]
==
'
assistant
'
:
if
msg
[
"
role
"
]
==
"
assistant
"
:
tool_calls
=
msg
.
get
(
'
tool_calls
'
)
tool_calls
=
msg
.
get
(
"
tool_calls
"
)
idx
+=
len
(
list
(
tool_calls
))
if
tool_calls
is
not
None
else
0
# noqa
idx
+=
len
(
list
(
tool_calls
))
if
tool_calls
is
not
None
else
0
# noqa
return
idx
return
idx
def
make_tool_call_id
(
id_type
:
str
=
'random'
,
func_name
=
None
,
idx
=
None
):
if
id_type
==
'kimi_k2'
:
def
make_tool_call_id
(
id_type
:
str
=
"random"
,
func_name
=
None
,
idx
=
None
):
return
f
'functions.
{
func_name
}
:
{
idx
}
'
if
id_type
==
"kimi_k2"
:
return
f
"functions.
{
func_name
}
:
{
idx
}
"
else
:
else
:
# by default return random
# by default return random
return
f
"chatcmpl-tool-
{
random_uuid
()
}
"
return
f
"chatcmpl-tool-
{
random_uuid
()
}
"
vllm/entrypoints/openai/serving_engine.py
View file @
f399182e
...
@@ -82,16 +82,26 @@ from vllm.utils import (AsyncMicrobatchTokenizer, is_list_of,
...
@@ -82,16 +82,26 @@ from vllm.utils import (AsyncMicrobatchTokenizer, is_list_of,
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
CompletionLikeRequest
=
Union
[
CompletionRequest
,
DetokenizeRequest
,
CompletionLikeRequest
=
Union
[
EmbeddingCompletionRequest
,
RerankRequest
,
CompletionRequest
,
ClassificationRequest
,
ScoreRequest
,
DetokenizeRequest
,
TokenizeCompletionRequest
]
EmbeddingCompletionRequest
,
RerankRequest
,
ClassificationRequest
,
ScoreRequest
,
TokenizeCompletionRequest
,
]
ChatLikeRequest
=
Union
[
ChatCompletionRequest
,
EmbeddingChatRequest
,
ChatLikeRequest
=
Union
[
ChatCompletionRequest
,
EmbeddingChatRequest
,
TokenizeChatRequest
]
TokenizeChatRequest
]
SpeechToTextRequest
=
Union
[
TranscriptionRequest
,
TranslationRequest
]
SpeechToTextRequest
=
Union
[
TranscriptionRequest
,
TranslationRequest
]
AnyRequest
=
Union
[
CompletionLikeRequest
,
ChatLikeRequest
,
SpeechToTextRequest
,
AnyRequest
=
Union
[
ResponsesRequest
,
IOProcessorRequest
]
CompletionLikeRequest
,
ChatLikeRequest
,
SpeechToTextRequest
,
ResponsesRequest
,
IOProcessorRequest
,
]
AnyResponse
=
Union
[
AnyResponse
=
Union
[
CompletionResponse
,
CompletionResponse
,
...
@@ -135,6 +145,7 @@ class RequestProcessingMixin(BaseModel):
...
@@ -135,6 +145,7 @@ class RequestProcessingMixin(BaseModel):
Mixin for request processing,
Mixin for request processing,
handling prompt preparation and engine input.
handling prompt preparation and engine input.
"""
"""
request_prompts
:
Optional
[
Sequence
[
RequestPrompt
]]
=
[]
request_prompts
:
Optional
[
Sequence
[
RequestPrompt
]]
=
[]
engine_prompts
:
Optional
[
Union
[
list
[
EngineTokensPrompt
],
engine_prompts
:
Optional
[
Union
[
list
[
EngineTokensPrompt
],
list
[
EngineEmbedsPrompt
]]]
=
[]
list
[
EngineEmbedsPrompt
]]]
=
[]
...
@@ -147,6 +158,7 @@ class ResponseGenerationMixin(BaseModel):
...
@@ -147,6 +158,7 @@ class ResponseGenerationMixin(BaseModel):
Mixin for response generation,
Mixin for response generation,
managing result generators and final batch results.
managing result generators and final batch results.
"""
"""
result_generator
:
Optional
[
AsyncGenerator
[
tuple
[
int
,
Union
[
result_generator
:
Optional
[
AsyncGenerator
[
tuple
[
int
,
Union
[
RequestOutput
,
PoolingRequestOutput
]],
None
]]
=
None
RequestOutput
,
PoolingRequestOutput
]],
None
]]
=
None
final_res_batch
:
list
[
Union
[
RequestOutput
,
PoolingRequestOutput
]]
=
Field
(
final_res_batch
:
list
[
Union
[
RequestOutput
,
PoolingRequestOutput
]]
=
Field
(
...
@@ -155,8 +167,12 @@ class ResponseGenerationMixin(BaseModel):
...
@@ -155,8 +167,12 @@ class ResponseGenerationMixin(BaseModel):
model_config
=
ConfigDict
(
arbitrary_types_allowed
=
True
)
model_config
=
ConfigDict
(
arbitrary_types_allowed
=
True
)
class
ServeContext
(
RequestProcessingMixin
,
ResponseGenerationMixin
,
BaseModel
,
class
ServeContext
(
Generic
[
RequestT
]):
RequestProcessingMixin
,
ResponseGenerationMixin
,
BaseModel
,
Generic
[
RequestT
],
):
# Shared across all requests
# Shared across all requests
request
:
RequestT
request
:
RequestT
raw_request
:
Optional
[
Request
]
=
None
raw_request
:
Optional
[
Request
]
=
None
...
@@ -298,8 +314,8 @@ class OpenAIServing:
...
@@ -298,8 +314,8 @@ class OpenAIServing:
truncate_prompt_tokens
=
getattr
(
ctx
.
request
,
"truncate_prompt_tokens"
,
truncate_prompt_tokens
=
getattr
(
ctx
.
request
,
"truncate_prompt_tokens"
,
None
)
None
)
if
truncate_prompt_tokens
is
not
None
and
\
if
(
truncate_prompt_tokens
is
not
None
truncate_prompt_tokens
>
self
.
max_model_len
:
and
truncate_prompt_tokens
>
self
.
max_model_len
)
:
return
self
.
create_error_response
(
return
self
.
create_error_response
(
"truncate_prompt_tokens value is "
"truncate_prompt_tokens value is "
"greater than max_model_len."
"greater than max_model_len."
...
@@ -344,10 +360,12 @@ class OpenAIServing:
...
@@ -344,10 +360,12 @@ class OpenAIServing:
return
self
.
create_error_response
(
return
self
.
create_error_response
(
"Request prompts not available"
)
"Request prompts not available"
)
self
.
_log_inputs
(
request_id_item
,
self
.
_log_inputs
(
request_id_item
,
ctx
.
request_prompts
[
i
],
ctx
.
request_prompts
[
i
],
params
=
pooling_params
,
params
=
pooling_params
,
lora_request
=
ctx
.
lora_request
)
lora_request
=
ctx
.
lora_request
,
)
# Mypy has an existing bug related to inferring the variance of
# Mypy has an existing bug related to inferring the variance of
# TypedDicts with `builtins.enumerate`:
# TypedDicts with `builtins.enumerate`:
...
@@ -413,7 +431,8 @@ class OpenAIServing:
...
@@ -413,7 +431,8 @@ class OpenAIServing:
self
,
self
,
message
:
str
,
message
:
str
,
err_type
:
str
=
"BadRequestError"
,
err_type
:
str
=
"BadRequestError"
,
status_code
:
HTTPStatus
=
HTTPStatus
.
BAD_REQUEST
)
->
ErrorResponse
:
status_code
:
HTTPStatus
=
HTTPStatus
.
BAD_REQUEST
,
)
->
ErrorResponse
:
if
self
.
log_error_stack
:
if
self
.
log_error_stack
:
exc_type
,
_
,
_
=
sys
.
exc_info
()
exc_type
,
_
,
_
=
sys
.
exc_info
()
if
exc_type
is
not
None
:
if
exc_type
is
not
None
:
...
@@ -427,7 +446,8 @@ class OpenAIServing:
...
@@ -427,7 +446,8 @@ class OpenAIServing:
self
,
self
,
message
:
str
,
message
:
str
,
err_type
:
str
=
"BadRequestError"
,
err_type
:
str
=
"BadRequestError"
,
status_code
:
HTTPStatus
=
HTTPStatus
.
BAD_REQUEST
)
->
str
:
status_code
:
HTTPStatus
=
HTTPStatus
.
BAD_REQUEST
,
)
->
str
:
json_str
=
json
.
dumps
(
json_str
=
json
.
dumps
(
self
.
create_error_response
(
message
=
message
,
self
.
create_error_response
(
message
=
message
,
err_type
=
err_type
,
err_type
=
err_type
,
...
@@ -438,25 +458,25 @@ class OpenAIServing:
...
@@ -438,25 +458,25 @@ class OpenAIServing:
self
,
self
,
request
:
AnyRequest
,
request
:
AnyRequest
,
)
->
Optional
[
ErrorResponse
]:
)
->
Optional
[
ErrorResponse
]:
error_response
=
None
error_response
=
None
if
self
.
_is_model_supported
(
request
.
model
):
if
self
.
_is_model_supported
(
request
.
model
):
return
None
return
None
if
request
.
model
in
self
.
models
.
lora_requests
:
if
request
.
model
in
self
.
models
.
lora_requests
:
return
None
return
None
if
envs
.
VLLM_ALLOW_RUNTIME_LORA_UPDATING
and
request
.
model
and
(
if
(
envs
.
VLLM_ALLOW_RUNTIME_LORA_UPDATING
and
request
.
model
and
load_result
:
=
await
self
.
models
.
resolve_lora
(
request
.
model
)):
(
load_result
:
=
await
self
.
models
.
resolve_lora
(
request
.
model
))
)
:
if
isinstance
(
load_result
,
LoRARequest
):
if
isinstance
(
load_result
,
LoRARequest
):
return
None
return
None
if
isinstance
(
load_result
,
ErrorResponse
)
and
\
if
(
isinstance
(
load_result
,
ErrorResponse
)
and
load_result
.
error
.
code
==
HTTPStatus
.
BAD_REQUEST
.
value
:
load_result
.
error
.
code
==
HTTPStatus
.
BAD_REQUEST
.
value
)
:
error_response
=
load_result
error_response
=
load_result
return
error_response
or
self
.
create_error_response
(
return
error_response
or
self
.
create_error_response
(
message
=
f
"The model `
{
request
.
model
}
` does not exist."
,
message
=
f
"The model `
{
request
.
model
}
` does not exist."
,
err_type
=
"NotFoundError"
,
err_type
=
"NotFoundError"
,
status_code
=
HTTPStatus
.
NOT_FOUND
)
status_code
=
HTTPStatus
.
NOT_FOUND
,
)
def
_get_active_default_mm_loras
(
def
_get_active_default_mm_loras
(
self
,
request
:
AnyRequest
)
->
Optional
[
LoRARequest
]:
self
,
request
:
AnyRequest
)
->
Optional
[
LoRARequest
]:
...
@@ -487,7 +507,6 @@ class OpenAIServing:
...
@@ -487,7 +507,6 @@ class OpenAIServing:
request
:
AnyRequest
,
request
:
AnyRequest
,
supports_default_mm_loras
:
bool
=
False
,
supports_default_mm_loras
:
bool
=
False
,
)
->
Optional
[
LoRARequest
]:
)
->
Optional
[
LoRARequest
]:
if
request
.
model
in
self
.
models
.
lora_requests
:
if
request
.
model
in
self
.
models
.
lora_requests
:
return
self
.
models
.
lora_requests
[
request
.
model
]
return
self
.
models
.
lora_requests
[
request
.
model
]
...
@@ -548,13 +567,15 @@ class OpenAIServing:
...
@@ -548,13 +567,15 @@ class OpenAIServing:
prompt
,
prompt
,
add_special_tokens
=
add_special_tokens
,
add_special_tokens
=
add_special_tokens
,
truncation
=
True
,
truncation
=
True
,
max_length
=
self
.
max_model_len
)
max_length
=
self
.
max_model_len
,
)
else
:
else
:
encoded
=
await
async_tokenizer
(
encoded
=
await
async_tokenizer
(
prompt
,
prompt
,
add_special_tokens
=
add_special_tokens
,
add_special_tokens
=
add_special_tokens
,
truncation
=
True
,
truncation
=
True
,
max_length
=
truncate_prompt_tokens
)
max_length
=
truncate_prompt_tokens
,
)
input_ids
=
encoded
.
input_ids
input_ids
=
encoded
.
input_ids
input_text
=
prompt
input_text
=
prompt
...
@@ -595,16 +616,22 @@ class OpenAIServing:
...
@@ -595,16 +616,22 @@ class OpenAIServing:
# Note: EmbeddingRequest, ClassificationRequest,
# Note: EmbeddingRequest, ClassificationRequest,
# and ScoreRequest doesn't have max_tokens
# and ScoreRequest doesn't have max_tokens
if
isinstance
(
request
,
if
isinstance
(
(
EmbeddingChatRequest
,
EmbeddingCompletionRequest
,
request
,
ScoreRequest
,
RerankRequest
,
ClassificationRequest
)):
(
EmbeddingChatRequest
,
EmbeddingCompletionRequest
,
ScoreRequest
,
RerankRequest
,
ClassificationRequest
,
),
):
# Note: input length can be up to the entire model context length
# Note: input length can be up to the entire model context length
# since these requests don't generate tokens.
# since these requests don't generate tokens.
if
token_num
>
self
.
max_model_len
:
if
token_num
>
self
.
max_model_len
:
operations
:
dict
[
type
[
AnyRequest
],
str
]
=
{
operations
:
dict
[
type
[
AnyRequest
],
str
]
=
{
ScoreRequest
:
"score"
,
ScoreRequest
:
"score"
,
ClassificationRequest
:
"classification"
ClassificationRequest
:
"classification"
,
}
}
operation
=
operations
.
get
(
type
(
request
),
operation
=
operations
.
get
(
type
(
request
),
"embedding generation"
)
"embedding generation"
)
...
@@ -618,8 +645,11 @@ class OpenAIServing:
...
@@ -618,8 +645,11 @@ class OpenAIServing:
# Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
# Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
# and does not require model context length validation
# and does not require model context length validation
if
isinstance
(
request
,
(
TokenizeCompletionRequest
,
TokenizeChatRequest
,
if
isinstance
(
DetokenizeRequest
)):
request
,
(
TokenizeCompletionRequest
,
TokenizeChatRequest
,
DetokenizeRequest
),
):
return
TextTokensPrompt
(
prompt
=
input_text
,
return
TextTokensPrompt
(
prompt
=
input_text
,
prompt_token_ids
=
input_ids
)
prompt_token_ids
=
input_ids
)
...
@@ -639,8 +669,8 @@ class OpenAIServing:
...
@@ -639,8 +669,8 @@ class OpenAIServing:
f
"
{
token_num
}
input tokens. Please reduce the length of "
f
"
{
token_num
}
input tokens. Please reduce the length of "
"the input messages."
)
"the input messages."
)
if
max_tokens
is
not
None
and
\
if
(
max_tokens
is
not
None
token_num
+
max_tokens
>
self
.
max_model_len
:
and
token_num
+
max_tokens
>
self
.
max_model_len
)
:
raise
ValueError
(
raise
ValueError
(
"'max_tokens' or 'max_completion_tokens' is too large: "
"'max_tokens' or 'max_completion_tokens' is too large: "
f
"
{
max_tokens
}
. This model's maximum context length is "
f
"
{
max_tokens
}
. This model's maximum context length is "
...
@@ -745,13 +775,14 @@ class OpenAIServing:
...
@@ -745,13 +775,14 @@ class OpenAIServing:
tasks
=
[]
tasks
=
[]
for
prompt_input
in
batch_inputs
:
for
prompt_input
in
batch_inputs
:
if
prompt_input
[
"is_tokens"
]
is
False
:
if
prompt_input
[
"is_tokens"
]
is
False
:
assert
tokenizer
is
not
None
,
\
assert
tokenizer
is
not
None
,
(
"Tokenizer is required for text prompts"
"Tokenizer is required for text prompts"
)
task
=
self
.
_normalize_prompt_text_to_input
(
task
=
self
.
_normalize_prompt_text_to_input
(
request
,
request
,
prompt_input
[
"content"
],
prompt_input
[
"content"
],
tokenizer
=
tokenizer
,
tokenizer
=
tokenizer
,
add_special_tokens
=
add_special_tokens
)
add_special_tokens
=
add_special_tokens
,
)
else
:
else
:
task
=
self
.
_normalize_prompt_tokens_to_input
(
task
=
self
.
_normalize_prompt_tokens_to_input
(
request
,
prompt_input
[
"content"
],
tokenizer
=
tokenizer
)
request
,
prompt_input
[
"content"
],
tokenizer
=
tokenizer
)
...
@@ -766,9 +797,14 @@ class OpenAIServing:
...
@@ -766,9 +797,14 @@ class OpenAIServing:
@
overload
@
overload
async
def
_preprocess_completion
(
async
def
_preprocess_completion
(
self
,
self
,
request
:
Union
[
DetokenizeRequest
,
EmbeddingCompletionRequest
,
request
:
Union
[
RerankRequest
,
ClassificationRequest
,
ScoreRequest
,
DetokenizeRequest
,
TokenizeCompletionRequest
],
EmbeddingCompletionRequest
,
RerankRequest
,
ClassificationRequest
,
ScoreRequest
,
TokenizeCompletionRequest
,
],
tokenizer
:
Optional
[
AnyTokenizer
],
tokenizer
:
Optional
[
AnyTokenizer
],
input_or_inputs
:
Union
[
str
,
list
[
str
],
list
[
int
],
list
[
list
[
int
]]],
input_or_inputs
:
Union
[
str
,
list
[
str
],
list
[
int
],
list
[
list
[
int
]]],
add_special_tokens
:
bool
=
...,
add_special_tokens
:
bool
=
...,
...
@@ -783,8 +819,10 @@ class OpenAIServing:
...
@@ -783,8 +819,10 @@ class OpenAIServing:
input_or_inputs
:
Optional
[
Union
[
str
,
list
[
str
],
list
[
int
],
input_or_inputs
:
Optional
[
Union
[
str
,
list
[
str
],
list
[
int
],
list
[
list
[
int
]]]],
list
[
list
[
int
]]]],
add_special_tokens
:
bool
=
...,
add_special_tokens
:
bool
=
...,
)
->
tuple
[
list
[
Union
[
TextTokensPrompt
,
EmbedsPrompt
]],
list
[
Union
[
)
->
tuple
[
EngineTokensPrompt
,
EngineEmbedsPrompt
]]]:
list
[
Union
[
TextTokensPrompt
,
EmbedsPrompt
]],
list
[
Union
[
EngineTokensPrompt
,
EngineEmbedsPrompt
]],
]:
...
...
async
def
_preprocess_completion
(
async
def
_preprocess_completion
(
...
@@ -794,17 +832,23 @@ class OpenAIServing:
...
@@ -794,17 +832,23 @@ class OpenAIServing:
input_or_inputs
:
Optional
[
Union
[
str
,
list
[
str
],
list
[
int
],
input_or_inputs
:
Optional
[
Union
[
str
,
list
[
str
],
list
[
int
],
list
[
list
[
int
]]]],
list
[
list
[
int
]]]],
add_special_tokens
:
bool
=
True
,
add_special_tokens
:
bool
=
True
,
)
->
tuple
[
Union
[
list
[
TextTokensPrompt
],
list
[
Union
[
)
->
tuple
[
TextTokensPrompt
,
EmbedsPrompt
]]],
Union
[
Union
[
list
[
TextTokensPrompt
],
list
[
Union
[
TextTokensPrompt
,
list
[
EngineTokensPrompt
],
list
[
Union
[
EngineTokensPrompt
,
EmbedsPrompt
]]],
EngineEmbedsPrompt
]]]]:
Union
[
if
not
isinstance
(
request
,
list
[
EngineTokensPrompt
],
CompletionRequest
)
and
input_or_inputs
is
None
:
list
[
Union
[
EngineTokensPrompt
,
EngineEmbedsPrompt
]],
],
]:
if
(
not
isinstance
(
request
,
CompletionRequest
)
and
input_or_inputs
is
None
):
raise
ValueError
(
raise
ValueError
(
"Prompt embeds with non-completion requests is not"
"Prompt embeds with non-completion requests is not"
" currently supported."
)
" currently supported."
)
(
request_prompts_text
,
request_prompts_embeds
(
request_prompts_text
,
request_prompts_embeds
,
)
=
await
self
.
_tokenize_prompt_input_or_inputs_async
(
)
=
await
self
.
_tokenize_prompt_input_or_inputs_async
(
request
,
request
,
tokenizer
,
tokenizer
,
...
@@ -817,9 +861,9 @@ class OpenAIServing:
...
@@ -817,9 +861,9 @@ class OpenAIServing:
prompt_token_ids
=
request_prompt_text
[
"prompt_token_ids"
])
prompt_token_ids
=
request_prompt_text
[
"prompt_token_ids"
])
for
request_prompt_text
in
request_prompts_text
for
request_prompt_text
in
request_prompts_text
]
]
cache_salt
=
request
.
cache_salt
if
(
cache_salt
=
(
request
.
cache_salt
if
hasattr
(
request
,
"cache_salt"
)
(
hasattr
(
request
,
"cache_salt"
)
and
request
.
cache_salt
is
not
None
)
else
None
and
request
.
cache_salt
is
not
None
)
else
None
)
if
cache_salt
:
if
cache_salt
:
for
prompt_text
in
engine_prompts_text
:
for
prompt_text
in
engine_prompts_text
:
prompt_text
[
"cache_salt"
]
=
cache_salt
prompt_text
[
"cache_salt"
]
=
cache_salt
...
@@ -831,8 +875,8 @@ class OpenAIServing:
...
@@ -831,8 +875,8 @@ class OpenAIServing:
# non-completion requests and if we don't add the overload here,
# non-completion requests and if we don't add the overload here,
# everywhere this function is used outside of serving_completion will
# everywhere this function is used outside of serving_completion will
# need logic asserting that only text prompts are in the request.
# need logic asserting that only text prompts are in the request.
if
not
isinstance
(
request
,
if
(
not
isinstance
(
request
,
CompletionRequest
)
CompletionRequest
)
and
input_or_inputs
is
not
None
:
and
input_or_inputs
is
not
None
)
:
return
request_prompts_text
,
engine_prompts_text
return
request_prompts_text
,
engine_prompts_text
engine_prompts_embeds
=
[
engine_prompts_embeds
=
[
...
@@ -862,8 +906,11 @@ class OpenAIServing:
...
@@ -862,8 +906,11 @@ class OpenAIServing:
chat_template_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
chat_template_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
tool_parser
:
Optional
[
Callable
[[
AnyTokenizer
],
ToolParser
]]
=
None
,
tool_parser
:
Optional
[
Callable
[[
AnyTokenizer
],
ToolParser
]]
=
None
,
add_special_tokens
:
bool
=
False
,
add_special_tokens
:
bool
=
False
,
)
->
tuple
[
list
[
ConversationMessage
],
Sequence
[
RequestPrompt
],
)
->
tuple
[
list
[
EngineTokensPrompt
]]:
list
[
ConversationMessage
],
Sequence
[
RequestPrompt
],
list
[
EngineTokensPrompt
],
]:
model_config
=
self
.
model_config
model_config
=
self
.
model_config
resolved_content_format
=
resolve_chat_template_content_format
(
resolved_content_format
=
resolve_chat_template_content_format
(
...
@@ -925,8 +972,8 @@ class OpenAIServing:
...
@@ -925,8 +972,8 @@ class OpenAIServing:
if
tokenizer
is
None
:
if
tokenizer
is
None
:
assert
isinstance
(
request_prompt
,
str
),
(
assert
isinstance
(
request_prompt
,
str
),
(
"Prompt has to be a string"
,
\
"Prompt has to be a string"
,
"when the tokenizer is not initialised"
"when the tokenizer is not initialised"
,
)
)
prompt_inputs
=
TextTokensPrompt
(
prompt
=
request_prompt
,
prompt_inputs
=
TextTokensPrompt
(
prompt
=
request_prompt
,
prompt_token_ids
=
[
1
])
prompt_token_ids
=
[
1
])
...
@@ -943,7 +990,8 @@ class OpenAIServing:
...
@@ -943,7 +990,8 @@ class OpenAIServing:
"Prompt has to be either a string or a list of token ids"
)
"Prompt has to be either a string or a list of token ids"
)
prompt_inputs
=
TextTokensPrompt
(
prompt_inputs
=
TextTokensPrompt
(
prompt
=
tokenizer
.
decode
(
request_prompt
),
prompt
=
tokenizer
.
decode
(
request_prompt
),
prompt_token_ids
=
request_prompt
)
prompt_token_ids
=
request_prompt
,
)
engine_prompt
=
EngineTokensPrompt
(
engine_prompt
=
EngineTokensPrompt
(
prompt_token_ids
=
prompt_inputs
[
"prompt_token_ids"
])
prompt_token_ids
=
prompt_inputs
[
"prompt_token_ids"
])
...
@@ -1007,22 +1055,23 @@ class OpenAIServing:
...
@@ -1007,22 +1055,23 @@ class OpenAIServing:
prompt_token_ids
=
prompt_token_ids
)
prompt_token_ids
=
prompt_token_ids
)
request_prompt
=
prompt_token_ids
request_prompt
=
prompt_token_ids
# Update the sampling params.
# Update the sampling params.
sampling_params
.
max_tokens
=
(
self
.
max_model_len
-
sampling_params
.
max_tokens
=
self
.
max_model_len
-
len
(
len
(
prompt_token_ids
)
)
prompt_token_ids
)
# OPTIMIZATION
# OPTIMIZATION
priority
=
orig_priority
-
1
priority
=
orig_priority
-
1
@
staticmethod
@
staticmethod
def
_load_prompt_embeds
(
def
_load_prompt_embeds
(
prompt_embeds
:
Optional
[
Union
[
bytes
,
list
[
bytes
]]],
prompt_embeds
:
Optional
[
Union
[
bytes
,
list
[
bytes
]]],
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=
1
)]]
=
None
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=
1
)]]
=
None
,
)
->
list
[
EmbedsPrompt
]:
)
->
list
[
EmbedsPrompt
]:
def
_load_and_validate_embed
(
embed
:
bytes
)
->
EmbedsPrompt
:
def
_load_and_validate_embed
(
embed
:
bytes
)
->
EmbedsPrompt
:
tensor
=
torch
.
load
(
io
.
BytesIO
(
tensor
=
torch
.
load
(
pybase64
.
b64decode
(
embed
,
validate
=
True
)),
io
.
BytesIO
(
pybase64
.
b64decode
(
embed
,
validate
=
True
)),
weights_only
=
True
,
weights_only
=
True
,
map_location
=
torch
.
device
(
"cpu"
))
map_location
=
torch
.
device
(
"cpu"
),
)
assert
isinstance
(
tensor
,
torch
.
Tensor
)
and
tensor
.
dtype
in
(
assert
isinstance
(
tensor
,
torch
.
Tensor
)
and
tensor
.
dtype
in
(
torch
.
float32
,
torch
.
float32
,
torch
.
bfloat16
,
torch
.
bfloat16
,
...
@@ -1061,7 +1110,7 @@ class OpenAIServing:
...
@@ -1061,7 +1110,7 @@ class OpenAIServing:
prompt
=
inputs
prompt
=
inputs
elif
isinstance
(
inputs
,
list
):
elif
isinstance
(
inputs
,
list
):
prompt_token_ids
=
inputs
prompt_token_ids
=
inputs
elif
'
prompt_embeds
'
in
inputs
:
elif
"
prompt_embeds
"
in
inputs
:
prompt_embeds
=
inputs
.
get
(
"prompt_embeds"
)
prompt_embeds
=
inputs
.
get
(
"prompt_embeds"
)
else
:
else
:
prompt
=
inputs
[
"prompt"
]
prompt
=
inputs
[
"prompt"
]
...
@@ -1101,10 +1150,12 @@ class OpenAIServing:
...
@@ -1101,10 +1150,12 @@ class OpenAIServing:
return
raw_request
.
headers
.
get
(
"X-Request-Id"
,
default
)
return
raw_request
.
headers
.
get
(
"X-Request-Id"
,
default
)
@
staticmethod
@
staticmethod
def
_get_decoded_token
(
logprob
:
Logprob
,
def
_get_decoded_token
(
logprob
:
Logprob
,
token_id
:
int
,
token_id
:
int
,
tokenizer
:
AnyTokenizer
,
tokenizer
:
AnyTokenizer
,
return_as_token_id
:
bool
=
False
)
->
str
:
return_as_token_id
:
bool
=
False
,
)
->
str
:
if
return_as_token_id
:
if
return_as_token_id
:
return
f
"token_id:
{
token_id
}
"
return
f
"token_id:
{
token_id
}
"
...
@@ -1117,9 +1168,11 @@ class OpenAIServing:
...
@@ -1117,9 +1168,11 @@ class OpenAIServing:
return
True
return
True
return
self
.
models
.
is_base_model
(
model_name
)
return
self
.
models
.
is_base_model
(
model_name
)
def
_get_model_name
(
self
,
def
_get_model_name
(
self
,
model_name
:
Optional
[
str
]
=
None
,
model_name
:
Optional
[
str
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
)
->
str
:
lora_request
:
Optional
[
LoRARequest
]
=
None
,
)
->
str
:
if
lora_request
:
if
lora_request
:
return
lora_request
.
lora_name
return
lora_request
.
lora_name
if
not
model_name
:
if
not
model_name
:
...
@@ -1129,7 +1182,7 @@ class OpenAIServing:
...
@@ -1129,7 +1182,7 @@ class OpenAIServing:
def
clamp_prompt_logprobs
(
def
clamp_prompt_logprobs
(
prompt_logprobs
:
Union
[
PromptLogprobs
,
prompt_logprobs
:
Union
[
PromptLogprobs
,
None
])
->
Union
[
PromptLogprobs
,
None
]:
None
]
,
)
->
Union
[
PromptLogprobs
,
None
]:
if
prompt_logprobs
is
None
:
if
prompt_logprobs
is
None
:
return
prompt_logprobs
return
prompt_logprobs
...
@@ -1137,6 +1190,6 @@ def clamp_prompt_logprobs(
...
@@ -1137,6 +1190,6 @@ def clamp_prompt_logprobs(
if
logprob_dict
is
None
:
if
logprob_dict
is
None
:
continue
continue
for
logprob_values
in
logprob_dict
.
values
():
for
logprob_values
in
logprob_dict
.
values
():
if
logprob_values
.
logprob
==
float
(
'
-inf
'
):
if
logprob_values
.
logprob
==
float
(
"
-inf
"
):
logprob_values
.
logprob
=
-
9999.0
logprob_values
.
logprob
=
-
9999.0
return
prompt_logprobs
return
prompt_logprobs
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment