Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
88ad9ec6
Unverified
Commit
88ad9ec6
authored
Apr 29, 2025
by
Cyrus Leung
Committed by
GitHub
Apr 29, 2025
Browse files
[Frontend] Support `chat_template_kwargs` in `LLM.chat` (#17356)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
40896bdf
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
106 additions
and
24 deletions
+106
-24
tests/entrypoints/llm/test_chat.py
tests/entrypoints/llm/test_chat.py
+93
-16
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+13
-8
No files found.
tests/entrypoints/llm/test_chat.py
View file @
88ad9ec6
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
weakref
import
pytest
import
pytest
from
vllm
import
LLM
from
vllm
import
LLM
from
vllm.distributed
import
cleanup_dist_env_and_memory
from
..openai.test_vision
import
TEST_IMAGE_URLS
from
..openai.test_vision
import
TEST_IMAGE_URLS
def
test_chat
():
@
pytest
.
fixture
(
scope
=
"function"
)
llm
=
LLM
(
model
=
"meta-llama/Llama-3.2-1B-Instruct"
)
def
text_llm
():
# pytest caches the fixture so we use weakref.proxy to
# enable garbage collection
llm
=
LLM
(
model
=
"meta-llama/Llama-3.2-1B-Instruct"
,
enforce_eager
=
True
,
seed
=
0
)
with
llm
.
deprecate_legacy_api
():
yield
weakref
.
proxy
(
llm
)
del
llm
cleanup_dist_env_and_memory
()
def
test_chat
(
text_llm
):
prompt1
=
"Explain the concept of entropy."
prompt1
=
"Explain the concept of entropy."
messages
=
[
messages
=
[
{
{
...
@@ -21,13 +37,11 @@ def test_chat():
...
@@ -21,13 +37,11 @@ def test_chat():
"content"
:
prompt1
"content"
:
prompt1
},
},
]
]
outputs
=
llm
.
chat
(
messages
)
outputs
=
text_
llm
.
chat
(
messages
)
assert
len
(
outputs
)
==
1
assert
len
(
outputs
)
==
1
def
test_multi_chat
():
def
test_multi_chat
(
text_llm
):
llm
=
LLM
(
model
=
"meta-llama/Llama-3.2-1B-Instruct"
)
prompt1
=
"Explain the concept of entropy."
prompt1
=
"Explain the concept of entropy."
prompt2
=
"Explain what among us is."
prompt2
=
"Explain what among us is."
...
@@ -55,13 +69,14 @@ def test_multi_chat():
...
@@ -55,13 +69,14 @@ def test_multi_chat():
messages
=
[
conversation1
,
conversation2
]
messages
=
[
conversation1
,
conversation2
]
outputs
=
llm
.
chat
(
messages
)
outputs
=
text_
llm
.
chat
(
messages
)
assert
len
(
outputs
)
==
2
assert
len
(
outputs
)
==
2
@
pytest
.
mark
.
parametrize
(
"image_urls"
,
@
pytest
.
fixture
(
scope
=
"function"
)
[[
TEST_IMAGE_URLS
[
0
],
TEST_IMAGE_URLS
[
1
]]])
def
vision_llm
():
def
test_chat_multi_image
(
image_urls
:
list
[
str
]):
# pytest caches the fixture so we use weakref.proxy to
# enable garbage collection
llm
=
LLM
(
llm
=
LLM
(
model
=
"microsoft/Phi-3.5-vision-instruct"
,
model
=
"microsoft/Phi-3.5-vision-instruct"
,
max_model_len
=
4096
,
max_model_len
=
4096
,
...
@@ -69,8 +84,20 @@ def test_chat_multi_image(image_urls: list[str]):
...
@@ -69,8 +84,20 @@ def test_chat_multi_image(image_urls: list[str]):
enforce_eager
=
True
,
enforce_eager
=
True
,
trust_remote_code
=
True
,
trust_remote_code
=
True
,
limit_mm_per_prompt
=
{
"image"
:
2
},
limit_mm_per_prompt
=
{
"image"
:
2
},
seed
=
0
,
)
)
with
llm
.
deprecate_legacy_api
():
yield
weakref
.
proxy
(
llm
)
del
llm
cleanup_dist_env_and_memory
()
@
pytest
.
mark
.
parametrize
(
"image_urls"
,
[[
TEST_IMAGE_URLS
[
0
],
TEST_IMAGE_URLS
[
1
]]])
def
test_chat_multi_image
(
vision_llm
,
image_urls
:
list
[
str
]):
messages
=
[{
messages
=
[{
"role"
:
"role"
:
"user"
,
"user"
,
...
@@ -87,16 +114,15 @@ def test_chat_multi_image(image_urls: list[str]):
...
@@ -87,16 +114,15 @@ def test_chat_multi_image(image_urls: list[str]):
},
},
],
],
}]
}]
outputs
=
llm
.
chat
(
messages
)
outputs
=
vision_
llm
.
chat
(
messages
)
assert
len
(
outputs
)
>=
0
assert
len
(
outputs
)
>=
0
def
test_llm_chat_tokenization_no_double_bos
():
def
test_llm_chat_tokenization_no_double_bos
(
text_llm
):
"""
"""
LLM.chat() should not add special tokens when using chat templates.
LLM.chat() should not add special tokens when using chat templates.
Check we get a single BOS token for llama chat.
Check we get a single BOS token for llama chat.
"""
"""
llm
=
LLM
(
model
=
"meta-llama/Llama-3.2-1B-Instruct"
,
enforce_eager
=
True
)
messages
=
[
messages
=
[
{
{
"role"
:
"system"
,
"role"
:
"system"
,
...
@@ -107,13 +133,64 @@ def test_llm_chat_tokenization_no_double_bos():
...
@@ -107,13 +133,64 @@ def test_llm_chat_tokenization_no_double_bos():
"content"
:
"Hello!"
"content"
:
"Hello!"
},
},
]
]
outputs
=
llm
.
chat
(
messages
)
outputs
=
text_
llm
.
chat
(
messages
)
assert
len
(
outputs
)
==
1
assert
len
(
outputs
)
==
1
prompt_token_ids
=
getattr
(
outputs
[
0
],
"prompt_token_ids"
,
None
)
prompt_token_ids
=
outputs
[
0
].
prompt_token_ids
assert
prompt_token_ids
is
not
None
assert
prompt_token_ids
is
not
None
bos_token
=
llm
.
get_tokenizer
().
bos_token_id
bos_token
=
text_
llm
.
get_tokenizer
().
bos_token_id
# Ensure we have a single BOS
# Ensure we have a single BOS
assert
prompt_token_ids
[
0
]
==
bos_token
assert
prompt_token_ids
[
0
]
==
bos_token
assert
prompt_token_ids
[
1
]
!=
bos_token
,
"Double BOS"
assert
prompt_token_ids
[
1
]
!=
bos_token
,
"Double BOS"
@
pytest
.
fixture
(
scope
=
"function"
)
def
thinking_llm
():
# pytest caches the fixture so we use weakref.proxy to
# enable garbage collection
llm
=
LLM
(
model
=
"Qwen/Qwen3-0.6B"
,
max_model_len
=
4096
,
enforce_eager
=
True
,
seed
=
0
,
)
with
llm
.
deprecate_legacy_api
():
yield
weakref
.
proxy
(
llm
)
del
llm
cleanup_dist_env_and_memory
()
@
pytest
.
mark
.
parametrize
(
"enable_thinking"
,
[
True
,
False
])
def
test_chat_extra_kwargs
(
thinking_llm
,
enable_thinking
):
messages
=
[
{
"role"
:
"system"
,
"content"
:
"You are a helpful assistant"
},
{
"role"
:
"user"
,
"content"
:
"What is 1+1?"
},
]
outputs
=
thinking_llm
.
chat
(
messages
,
chat_template_kwargs
=
{
"enable_thinking"
:
enable_thinking
},
)
assert
len
(
outputs
)
==
1
prompt_token_ids
=
outputs
[
0
].
prompt_token_ids
assert
prompt_token_ids
is
not
None
think_id
=
thinking_llm
.
get_tokenizer
().
get_vocab
()[
"<think>"
]
if
enable_thinking
:
assert
think_id
not
in
prompt_token_ids
else
:
# The chat template includes dummy thinking process
assert
think_id
in
prompt_token_ids
vllm/entrypoints/llm.py
View file @
88ad9ec6
...
@@ -656,6 +656,7 @@ class LLM:
...
@@ -656,6 +656,7 @@ class LLM:
add_generation_prompt
:
bool
=
True
,
add_generation_prompt
:
bool
=
True
,
continue_final_message
:
bool
=
False
,
continue_final_message
:
bool
=
False
,
tools
:
Optional
[
list
[
dict
[
str
,
Any
]]]
=
None
,
tools
:
Optional
[
list
[
dict
[
str
,
Any
]]]
=
None
,
chat_template_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
mm_processor_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
mm_processor_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
)
->
list
[
RequestOutput
]:
)
->
list
[
RequestOutput
]:
"""
"""
...
@@ -696,6 +697,8 @@ class LLM:
...
@@ -696,6 +697,8 @@ class LLM:
continue_final_message: If True, continues the final message in
continue_final_message: If True, continues the final message in
the conversation instead of starting a new one. Cannot be
the conversation instead of starting a new one. Cannot be
``True`` if ``add_generation_prompt`` is also ``True``.
``True`` if ``add_generation_prompt`` is also ``True``.
chat_template_kwargs: Additional kwargs to pass to the chat
template.
mm_processor_kwargs: Multimodal processor kwarg overrides for this
mm_processor_kwargs: Multimodal processor kwarg overrides for this
chat request. Only used for offline requests.
chat request. Only used for offline requests.
...
@@ -726,6 +729,14 @@ class LLM:
...
@@ -726,6 +729,14 @@ class LLM:
trust_remote_code
=
model_config
.
trust_remote_code
,
trust_remote_code
=
model_config
.
trust_remote_code
,
)
)
_chat_template_kwargs
:
dict
[
str
,
Any
]
=
dict
(
chat_template
=
chat_template
,
add_generation_prompt
=
add_generation_prompt
,
continue_final_message
=
continue_final_message
,
tools
=
tools
,
)
_chat_template_kwargs
.
update
(
chat_template_kwargs
or
{})
prompts
:
list
[
Union
[
TokensPrompt
,
TextPrompt
]]
=
[]
prompts
:
list
[
Union
[
TokensPrompt
,
TextPrompt
]]
=
[]
for
msgs
in
list_of_messages
:
for
msgs
in
list_of_messages
:
...
@@ -743,20 +754,14 @@ class LLM:
...
@@ -743,20 +754,14 @@ class LLM:
prompt_token_ids
=
apply_mistral_chat_template
(
prompt_token_ids
=
apply_mistral_chat_template
(
tokenizer
,
tokenizer
,
messages
=
msgs
,
messages
=
msgs
,
chat_template
=
chat_template
,
**
_chat_template_kwargs
,
tools
=
tools
,
add_generation_prompt
=
add_generation_prompt
,
continue_final_message
=
continue_final_message
,
)
)
else
:
else
:
prompt_str
=
apply_hf_chat_template
(
prompt_str
=
apply_hf_chat_template
(
tokenizer
,
tokenizer
,
trust_remote_code
=
model_config
.
trust_remote_code
,
trust_remote_code
=
model_config
.
trust_remote_code
,
conversation
=
conversation
,
conversation
=
conversation
,
chat_template
=
chat_template
,
**
_chat_template_kwargs
,
tools
=
tools
,
add_generation_prompt
=
add_generation_prompt
,
continue_final_message
=
continue_final_message
,
)
)
# Special tokens are already included in chat templates so
# Special tokens are already included in chat templates so
# should not be added by the tokenizer in this case.
# should not be added by the tokenizer in this case.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment