Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6c9ba48f
Unverified
Commit
6c9ba48f
authored
Sep 29, 2024
by
danieljannai21
Committed by
GitHub
Sep 29, 2024
Browse files
[Frontend] Added support for HF's new `continue_final_message` parameter (#8942)
parent
1fb9c1b0
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
105 additions
and
31 deletions
+105
-31
tests/entrypoints/openai/test_chat_template.py
tests/entrypoints/openai/test_chat_template.py
+23
-7
tests/entrypoints/openai/test_tokenization.py
tests/entrypoints/openai/test_tokenization.py
+34
-22
vllm/entrypoints/chat_utils.py
vllm/entrypoints/chat_utils.py
+8
-0
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+6
-0
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+28
-0
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+4
-2
vllm/entrypoints/openai/serving_tokenization.py
vllm/entrypoints/openai/serving_tokenization.py
+2
-0
No files found.
tests/entrypoints/openai/test_chat_template.py
View file @
6c9ba48f
...
@@ -12,7 +12,7 @@ assert chatml_jinja_path.exists()
...
@@ -12,7 +12,7 @@ assert chatml_jinja_path.exists()
# Define models, templates, and their corresponding expected outputs
# Define models, templates, and their corresponding expected outputs
MODEL_TEMPLATE_GENERATON_OUTPUT
=
[
MODEL_TEMPLATE_GENERATON_OUTPUT
=
[
(
"facebook/opt-125m"
,
chatml_jinja_path
,
True
,
"""<|im_start|>user
(
"facebook/opt-125m"
,
chatml_jinja_path
,
True
,
False
,
"""<|im_start|>user
Hello<|im_end|>
Hello<|im_end|>
<|im_start|>assistant
<|im_start|>assistant
Hi there!<|im_end|>
Hi there!<|im_end|>
...
@@ -20,12 +20,20 @@ Hi there!<|im_end|>
...
@@ -20,12 +20,20 @@ Hi there!<|im_end|>
What is the capital of<|im_end|>
What is the capital of<|im_end|>
<|im_start|>assistant
<|im_start|>assistant
"""
),
"""
),
(
"facebook/opt-125m"
,
chatml_jinja_path
,
False
,
"""<|im_start|>user
(
"facebook/opt-125m"
,
chatml_jinja_path
,
False
,
False
,
"""<|im_start|>user
Hello<|im_end|>
Hello<|im_end|>
<|im_start|>assistant
<|im_start|>assistant
Hi there!<|im_end|>
Hi there!<|im_end|>
<|im_start|>user
<|im_start|>user
What is the capital of"""
)
What is the capital of"""
),
(
"facebook/opt-125m"
,
chatml_jinja_path
,
False
,
True
,
"""<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
Hi there!<|im_end|>
<|im_start|>user
What is the capital of<|im_end|>
<|im_start|>assistant
The capital of"""
),
]
]
TEST_MESSAGES
=
[
TEST_MESSAGES
=
[
...
@@ -42,6 +50,10 @@ TEST_MESSAGES = [
...
@@ -42,6 +50,10 @@ TEST_MESSAGES = [
'content'
:
'What is the capital of'
'content'
:
'What is the capital of'
},
},
]
]
ASSISTANT_MESSAGE_TO_CONTINUE
=
{
'role'
:
'assistant'
,
'content'
:
'The capital of'
}
def
test_load_chat_template
():
def
test_load_chat_template
():
...
@@ -73,10 +85,10 @@ def test_no_load_chat_template_literallike():
...
@@ -73,10 +85,10 @@ def test_no_load_chat_template_literallike():
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"model,template,add_generation_prompt,expected_output"
,
"model,template,add_generation_prompt,
continue_final_message,
expected_output"
,
MODEL_TEMPLATE_GENERATON_OUTPUT
)
MODEL_TEMPLATE_GENERATON_OUTPUT
)
def
test_get_gen_prompt
(
model
,
template
,
add_generation_prompt
,
def
test_get_gen_prompt
(
model
,
template
,
add_generation_prompt
,
expected_output
):
continue_final_message
,
expected_output
):
# Initialize the tokenizer
# Initialize the tokenizer
tokenizer
=
get_tokenizer
(
tokenizer_name
=
model
)
tokenizer
=
get_tokenizer
(
tokenizer_name
=
model
)
template_content
=
load_chat_template
(
chat_template
=
template
)
template_content
=
load_chat_template
(
chat_template
=
template
)
...
@@ -84,8 +96,11 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
...
@@ -84,8 +96,11 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
# Create a mock request object using keyword arguments
# Create a mock request object using keyword arguments
mock_request
=
ChatCompletionRequest
(
mock_request
=
ChatCompletionRequest
(
model
=
model
,
model
=
model
,
messages
=
TEST_MESSAGES
,
messages
=
TEST_MESSAGES
+
[
ASSISTANT_MESSAGE_TO_CONTINUE
]
add_generation_prompt
=
add_generation_prompt
)
if
continue_final_message
else
TEST_MESSAGES
,
add_generation_prompt
=
add_generation_prompt
,
continue_final_message
=
continue_final_message
,
)
# Call the function and get the result
# Call the function and get the result
result
=
apply_hf_chat_template
(
result
=
apply_hf_chat_template
(
...
@@ -93,6 +108,7 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
...
@@ -93,6 +108,7 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
conversation
=
mock_request
.
messages
,
conversation
=
mock_request
.
messages
,
chat_template
=
mock_request
.
chat_template
or
template_content
,
chat_template
=
mock_request
.
chat_template
or
template_content
,
add_generation_prompt
=
mock_request
.
add_generation_prompt
,
add_generation_prompt
=
mock_request
.
add_generation_prompt
,
continue_final_message
=
mock_request
.
continue_final_message
,
)
)
# Test assertion
# Test assertion
...
...
tests/entrypoints/openai/test_tokenization.py
View file @
6c9ba48f
...
@@ -104,17 +104,29 @@ async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str,
...
@@ -104,17 +104,29 @@ async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str,
"role"
:
"user"
,
"role"
:
"user"
,
"content"
:
"Can I ask a question? vllm1"
"content"
:
"Can I ask a question? vllm1"
}]
}]
for
continue_final
in
[
False
,
True
]:
if
add_generation
and
continue_final
:
continue
if
continue_final
:
conversation
.
append
({
"role"
:
"assistant"
,
"content"
:
"Sure,"
})
prompt
=
tokenizer
.
apply_chat_template
(
prompt
=
tokenizer
.
apply_chat_template
(
add_generation_prompt
=
add_generation
,
add_generation_prompt
=
add_generation
,
continue_final_message
=
continue_final
,
conversation
=
conversation
,
conversation
=
conversation
,
tokenize
=
False
)
tokenize
=
False
)
tokens
=
tokenizer
.
encode
(
prompt
,
add_special_tokens
=
add_special
)
tokens
=
tokenizer
.
encode
(
prompt
,
add_special_tokens
=
add_special
)
response
=
requests
.
post
(
base_url
+
"/tokenize"
,
response
=
requests
.
post
(
base_url
+
"/tokenize"
,
json
=
{
json
=
{
"add_generation_prompt"
:
"add_generation_prompt"
:
add_generation
,
add_generation
,
"continue_final_message"
:
continue_final
,
"add_special_tokens"
:
add_special
,
"add_special_tokens"
:
add_special
,
"messages"
:
conversation
,
"messages"
:
conversation
,
"model"
:
model_name
"model"
:
model_name
...
...
vllm/entrypoints/chat_utils.py
View file @
6c9ba48f
...
@@ -542,6 +542,14 @@ def apply_mistral_chat_template(
...
@@ -542,6 +542,14 @@ def apply_mistral_chat_template(
if
chat_template
is
not
None
:
if
chat_template
is
not
None
:
logger
.
warning
(
logger
.
warning
(
"'chat_template' cannot be overridden for mistral tokenizer."
)
"'chat_template' cannot be overridden for mistral tokenizer."
)
if
"add_generation_prompt"
in
kwargs
:
logger
.
warning
(
"'add_generation_prompt' is not supported for mistral tokenizer, "
"so it will be ignored."
)
if
"continue_final_message"
in
kwargs
:
logger
.
warning
(
"'continue_final_message' is not supported for mistral tokenizer, "
"so it will be ignored."
)
return
tokenizer
.
apply_chat_template
(
return
tokenizer
.
apply_chat_template
(
messages
=
messages
,
messages
=
messages
,
...
...
vllm/entrypoints/llm.py
View file @
6c9ba48f
...
@@ -501,6 +501,7 @@ class LLM:
...
@@ -501,6 +501,7 @@ class LLM:
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
chat_template
:
Optional
[
str
]
=
None
,
chat_template
:
Optional
[
str
]
=
None
,
add_generation_prompt
:
bool
=
True
,
add_generation_prompt
:
bool
=
True
,
continue_final_message
:
bool
=
False
,
tools
:
Optional
[
List
[
Dict
[
str
,
Any
]]]
=
None
,
tools
:
Optional
[
List
[
Dict
[
str
,
Any
]]]
=
None
,
)
->
List
[
RequestOutput
]:
)
->
List
[
RequestOutput
]:
"""
"""
...
@@ -528,6 +529,9 @@ class LLM:
...
@@ -528,6 +529,9 @@ class LLM:
If not provided, the model's default chat template will be used.
If not provided, the model's default chat template will be used.
add_generation_prompt: If True, adds a generation template
add_generation_prompt: If True, adds a generation template
to each message.
to each message.
continue_final_message: If True, continues the final message in
the conversation instead of starting a new one. Cannot be `True`
if `add_generation_prompt` is also `True`.
Returns:
Returns:
A list of ``RequestOutput`` objects containing the generated
A list of ``RequestOutput`` objects containing the generated
...
@@ -559,6 +563,7 @@ class LLM:
...
@@ -559,6 +563,7 @@ class LLM:
messages
=
msgs
,
messages
=
msgs
,
chat_template
=
chat_template
,
chat_template
=
chat_template
,
add_generation_prompt
=
add_generation_prompt
,
add_generation_prompt
=
add_generation_prompt
,
continue_final_message
=
continue_final_message
,
tools
=
tools
,
tools
=
tools
,
)
)
else
:
else
:
...
@@ -567,6 +572,7 @@ class LLM:
...
@@ -567,6 +572,7 @@ class LLM:
conversation
=
conversation
,
conversation
=
conversation
,
chat_template
=
chat_template
,
chat_template
=
chat_template
,
add_generation_prompt
=
add_generation_prompt
,
add_generation_prompt
=
add_generation_prompt
,
continue_final_message
=
continue_final_message
,
tools
=
tools
,
tools
=
tools
,
)
)
...
...
vllm/entrypoints/openai/protocol.py
View file @
6c9ba48f
...
@@ -211,6 +211,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
...
@@ -211,6 +211,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
"This is a parameter used by chat template in tokenizer config of the "
"This is a parameter used by chat template in tokenizer config of the "
"model."
),
"model."
),
)
)
continue_final_message
:
bool
=
Field
(
default
=
False
,
description
=
(
"If this is set, the chat will be formatted so that the final "
"message in the chat is open-ended, without any EOS tokens. The "
"model will continue this message rather than starting a new one. "
"This allows you to
\"
prefill
\"
part of the model's response for it. "
"Cannot be used at the same time as `add_generation_prompt`."
),
)
add_special_tokens
:
bool
=
Field
(
add_special_tokens
:
bool
=
Field
(
default
=
False
,
default
=
False
,
description
=
(
description
=
(
...
@@ -431,6 +440,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
...
@@ -431,6 +440,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
" of the specified `tools`"
)
" of the specified `tools`"
)
return
data
return
data
@
model_validator
(
mode
=
"before"
)
@
classmethod
def
check_generation_prompt
(
cls
,
data
):
if
data
.
get
(
"continue_final_message"
)
and
data
.
get
(
"add_generation_prompt"
):
raise
ValueError
(
"Cannot set both `continue_final_message` and "
"`add_generation_prompt` to True."
)
return
data
class
CompletionRequest
(
OpenAIBaseModel
):
class
CompletionRequest
(
OpenAIBaseModel
):
# Ordered by official OpenAI API documentation
# Ordered by official OpenAI API documentation
...
@@ -862,8 +880,18 @@ class TokenizeChatRequest(OpenAIBaseModel):
...
@@ -862,8 +880,18 @@ class TokenizeChatRequest(OpenAIBaseModel):
messages
:
List
[
ChatCompletionMessageParam
]
messages
:
List
[
ChatCompletionMessageParam
]
add_generation_prompt
:
bool
=
Field
(
default
=
True
)
add_generation_prompt
:
bool
=
Field
(
default
=
True
)
continue_final_message
:
bool
=
Field
(
default
=
False
)
add_special_tokens
:
bool
=
Field
(
default
=
False
)
add_special_tokens
:
bool
=
Field
(
default
=
False
)
@
model_validator
(
mode
=
"before"
)
@
classmethod
def
check_generation_prompt
(
cls
,
data
):
if
data
.
get
(
"continue_final_message"
)
and
data
.
get
(
"add_generation_prompt"
):
raise
ValueError
(
"Cannot set both `continue_final_message` and "
"`add_generation_prompt` to True."
)
return
data
TokenizeRequest
=
Union
[
TokenizeCompletionRequest
,
TokenizeChatRequest
]
TokenizeRequest
=
Union
[
TokenizeCompletionRequest
,
TokenizeChatRequest
]
...
...
vllm/entrypoints/openai/serving_chat.py
View file @
6c9ba48f
...
@@ -140,6 +140,7 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -140,6 +140,7 @@ class OpenAIServingChat(OpenAIServing):
messages
=
request
.
messages
,
messages
=
request
.
messages
,
chat_template
=
request
.
chat_template
or
self
.
chat_template
,
chat_template
=
request
.
chat_template
or
self
.
chat_template
,
add_generation_prompt
=
request
.
add_generation_prompt
,
add_generation_prompt
=
request
.
add_generation_prompt
,
continue_final_message
=
request
.
continue_final_message
,
tools
=
tool_dicts
,
tools
=
tool_dicts
,
documents
=
request
.
documents
,
documents
=
request
.
documents
,
**
(
request
.
chat_template_kwargs
or
{}),
**
(
request
.
chat_template_kwargs
or
{}),
...
@@ -150,6 +151,7 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -150,6 +151,7 @@ class OpenAIServingChat(OpenAIServing):
conversation
=
conversation
,
conversation
=
conversation
,
chat_template
=
request
.
chat_template
or
self
.
chat_template
,
chat_template
=
request
.
chat_template
or
self
.
chat_template
,
add_generation_prompt
=
request
.
add_generation_prompt
,
add_generation_prompt
=
request
.
add_generation_prompt
,
continue_final_message
=
request
.
continue_final_message
,
tools
=
tool_dicts
,
tools
=
tool_dicts
,
documents
=
request
.
documents
,
documents
=
request
.
documents
,
**
(
request
.
chat_template_kwargs
or
{}),
**
(
request
.
chat_template_kwargs
or
{}),
...
@@ -361,7 +363,7 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -361,7 +363,7 @@ class OpenAIServingChat(OpenAIServing):
# Send response to echo the input portion of the
# Send response to echo the input portion of the
# last message
# last message
if
request
.
echo
:
if
request
.
echo
or
request
.
continue_final_message
:
last_msg_content
:
str
=
""
last_msg_content
:
str
=
""
if
conversation
and
"content"
in
conversation
[
if
conversation
and
"content"
in
conversation
[
-
1
]
and
conversation
[
-
1
].
get
(
"role"
)
==
role
:
-
1
]
and
conversation
[
-
1
].
get
(
"role"
)
==
role
:
...
@@ -716,7 +718,7 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -716,7 +718,7 @@ class OpenAIServingChat(OpenAIServing):
stop_reason
=
output
.
stop_reason
)
stop_reason
=
output
.
stop_reason
)
choices
.
append
(
choice_data
)
choices
.
append
(
choice_data
)
if
request
.
echo
:
if
request
.
echo
or
request
.
continue_final_message
:
last_msg_content
=
""
last_msg_content
=
""
if
conversation
and
"content"
in
conversation
[
-
1
]
and
conversation
[
if
conversation
and
"content"
in
conversation
[
-
1
]
and
conversation
[
-
1
].
get
(
"role"
)
==
role
:
-
1
].
get
(
"role"
)
==
role
:
...
...
vllm/entrypoints/openai/serving_tokenization.py
View file @
6c9ba48f
...
@@ -87,6 +87,7 @@ class OpenAIServingTokenization(OpenAIServing):
...
@@ -87,6 +87,7 @@ class OpenAIServingTokenization(OpenAIServing):
messages
=
request
.
messages
,
messages
=
request
.
messages
,
chat_template
=
self
.
chat_template
,
chat_template
=
self
.
chat_template
,
add_generation_prompt
=
request
.
add_generation_prompt
,
add_generation_prompt
=
request
.
add_generation_prompt
,
continue_final_message
=
request
.
continue_final_message
,
)
)
else
:
else
:
prompt
=
apply_hf_chat_template
(
prompt
=
apply_hf_chat_template
(
...
@@ -94,6 +95,7 @@ class OpenAIServingTokenization(OpenAIServing):
...
@@ -94,6 +95,7 @@ class OpenAIServingTokenization(OpenAIServing):
conversation
=
conversation
,
conversation
=
conversation
,
chat_template
=
self
.
chat_template
,
chat_template
=
self
.
chat_template
,
add_generation_prompt
=
request
.
add_generation_prompt
,
add_generation_prompt
=
request
.
add_generation_prompt
,
continue_final_message
=
request
.
continue_final_message
,
)
)
else
:
else
:
prompt
=
request
.
prompt
prompt
=
request
.
prompt
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment