Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
66d617e3
Unverified
Commit
66d617e3
authored
Aug 07, 2024
by
Cyrus Leung
Committed by
GitHub
Aug 07, 2024
Browse files
[Frontend] Gracefully handle missing chat template and fix CI failure (#7238)
Co-authored-by:
Roger Wang
<
ywang@roblox.com
>
parent
7b261092
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
125 additions
and
69 deletions
+125
-69
tests/async_engine/test_chat_template.py
tests/async_engine/test_chat_template.py
+8
-13
tests/async_engine/test_openapi_server_ray.py
tests/async_engine/test_openapi_server_ray.py
+7
-3
tests/entrypoints/openai/test_oot_registration.py
tests/entrypoints/openai/test_oot_registration.py
+55
-32
tests/utils.py
tests/utils.py
+2
-2
vllm/entrypoints/chat_utils.py
vllm/entrypoints/chat_utils.py
+34
-3
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+3
-2
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+4
-4
vllm/entrypoints/openai/serving_tokenization.py
vllm/entrypoints/openai/serving_tokenization.py
+8
-6
vllm/transformers_utils/tokenizer.py
vllm/transformers_utils/tokenizer.py
+4
-4
No files found.
tests/async_engine/test_chat_template.py
View file @
66d617e3
import
os
import
pathlib
import
pytest
from
vllm.entrypoints.chat_utils
import
load_chat_template
from
vllm.entrypoints.chat_utils
import
apply_chat_template
,
load_chat_template
from
vllm.entrypoints.openai.protocol
import
ChatCompletionRequest
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
chatml_jinja_path
=
pathlib
.
Path
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))).
parent
.
parent
/
"examples/template_chatml.jinja"
from
..utils
import
VLLM_PATH
chatml_jinja_path
=
VLLM_PATH
/
"examples/template_chatml.jinja"
assert
chatml_jinja_path
.
exists
()
# Define models, templates, and their corresponding expected outputs
MODEL_TEMPLATE_GENERATON_OUTPUT
=
[
(
"facebook/opt-125m"
,
None
,
True
,
"Hello</s>Hi there!</s>What is the capital of</s>"
),
(
"facebook/opt-125m"
,
None
,
False
,
"Hello</s>Hi there!</s>What is the capital of</s>"
),
(
"facebook/opt-125m"
,
chatml_jinja_path
,
True
,
"""<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
...
...
@@ -93,11 +87,12 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
add_generation_prompt
=
add_generation_prompt
)
# Call the function and get the result
result
=
tokenizer
.
apply_chat_template
(
result
=
apply_chat_template
(
tokenizer
,
conversation
=
mock_request
.
messages
,
tokenize
=
False
,
chat_template
=
mock_request
.
chat_template
or
template_content
,
add_generation_prompt
=
mock_request
.
add_generation_prompt
,
chat_template
=
mock_request
.
chat_template
or
template_content
)
)
# Test assertion
assert
result
==
expected_output
,
(
...
...
tests/async_engine/test_openapi_server_ray.py
View file @
66d617e3
import
openai
# use the official client for correctness check
import
pytest
from
..utils
import
RemoteOpenAIServer
from
..utils
import
VLLM_PATH
,
RemoteOpenAIServer
# any model with a chat template should work here
MODEL_NAME
=
"facebook/opt-125m"
chatml_jinja_path
=
VLLM_PATH
/
"examples/template_chatml.jinja"
assert
chatml_jinja_path
.
exists
()
@
pytest
.
fixture
(
scope
=
"module"
)
...
...
@@ -16,7 +18,9 @@ def server():
"--max-model-len"
,
"2048"
,
"--enforce-eager"
,
"--engine-use-ray"
"--engine-use-ray"
,
"--chat-template"
,
str
(
chatml_jinja_path
),
]
with
RemoteOpenAIServer
(
MODEL_NAME
,
args
)
as
remote_server
:
...
...
@@ -83,7 +87,7 @@ async def test_single_chat_session(client: openai.AsyncOpenAI):
choice
=
chat_completion
.
choices
[
0
]
assert
choice
.
finish_reason
==
"length"
assert
chat_completion
.
usage
==
openai
.
types
.
CompletionUsage
(
completion_tokens
=
10
,
prompt_tokens
=
13
,
total_tokens
=
23
)
completion_tokens
=
10
,
prompt_tokens
=
55
,
total_tokens
=
65
)
message
=
choice
.
message
assert
message
.
content
is
not
None
and
len
(
message
.
content
)
>=
10
...
...
tests/entrypoints/openai/test_oot_registration.py
View file @
66d617e3
...
...
@@ -9,6 +9,11 @@ from vllm.model_executor.models.opt import OPTForCausalLM
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.utils
import
get_open_port
from
...utils
import
VLLM_PATH
,
RemoteOpenAIServer
chatml_jinja_path
=
VLLM_PATH
/
"examples/template_chatml.jinja"
assert
chatml_jinja_path
.
exists
()
class
MyOPTForCausalLM
(
OPTForCausalLM
):
...
...
@@ -21,12 +26,25 @@ class MyOPTForCausalLM(OPTForCausalLM):
return
logits
def
server_function
(
port
):
def
server_function
(
port
:
int
):
# register our dummy model
ModelRegistry
.
register_model
(
"OPTForCausalLM"
,
MyOPTForCausalLM
)
sys
.
argv
=
[
"placeholder.py"
]
+
\
(
"--model facebook/opt-125m --gpu-memory-utilization 0.10 "
f
"--dtype float32 --api-key token-abc123 --port
{
port
}
"
).
split
()
sys
.
argv
=
[
"placeholder.py"
]
+
[
"--model"
,
"facebook/opt-125m"
,
"--gpu-memory-utilization"
,
"0.10"
,
"--dtype"
,
"float32"
,
"--api-key"
,
"token-abc123"
,
"--port"
,
str
(
port
),
"--chat-template"
,
str
(
chatml_jinja_path
),
]
import
runpy
runpy
.
run_module
(
'vllm.entrypoints.openai.api_server'
,
run_name
=
'__main__'
)
...
...
@@ -36,35 +54,40 @@ def test_oot_registration_for_api_server():
ctx
=
torch
.
multiprocessing
.
get_context
()
server
=
ctx
.
Process
(
target
=
server_function
,
args
=
(
port
,
))
server
.
start
()
MAX_SERVER_START_WAIT_S
=
60
client
=
OpenAI
(
base_url
=
f
"http://localhost:
{
port
}
/v1"
,
api_key
=
"token-abc123"
,
)
now
=
time
.
time
()
while
True
:
try
:
completion
=
client
.
chat
.
completions
.
create
(
model
=
"facebook/opt-125m"
,
messages
=
[{
"role"
:
"system"
,
"content"
:
"You are a helpful assistant."
},
{
"role"
:
"user"
,
"content"
:
"Hello!"
}],
temperature
=
0
,
)
break
except
OpenAIError
as
e
:
if
"Connection error"
in
str
(
e
):
time
.
sleep
(
3
)
if
time
.
time
()
-
now
>
MAX_SERVER_START_WAIT_S
:
raise
RuntimeError
(
"Server did not start in time"
)
from
e
else
:
raise
e
server
.
kill
()
try
:
client
=
OpenAI
(
base_url
=
f
"http://localhost:
{
port
}
/v1"
,
api_key
=
"token-abc123"
,
)
now
=
time
.
time
()
while
True
:
try
:
completion
=
client
.
chat
.
completions
.
create
(
model
=
"facebook/opt-125m"
,
messages
=
[{
"role"
:
"system"
,
"content"
:
"You are a helpful assistant."
},
{
"role"
:
"user"
,
"content"
:
"Hello!"
}],
temperature
=
0
,
)
break
except
OpenAIError
as
e
:
if
"Connection error"
in
str
(
e
):
time
.
sleep
(
3
)
if
time
.
time
()
-
now
>
RemoteOpenAIServer
.
MAX_START_WAIT_S
:
msg
=
"Server did not start in time"
raise
RuntimeError
(
msg
)
from
e
else
:
raise
e
finally
:
server
.
terminate
()
generated_text
=
completion
.
choices
[
0
].
message
.
content
assert
generated_text
is
not
None
# make sure only the first token is generated
rest
=
generated_text
.
replace
(
"<s>"
,
""
)
assert
rest
==
""
tests/utils.py
View file @
66d617e3
...
...
@@ -50,7 +50,7 @@ VLLM_PATH = Path(__file__).parent.parent
class
RemoteOpenAIServer
:
DUMMY_API_KEY
=
"token-abc123"
# vLLM's OpenAI server does not need API key
MAX_
SERVER_
START_WAIT_S
=
120
# wait for server to start for 120 seconds
MAX_START_WAIT_S
=
120
# wait for server to start for 120 seconds
def
__init__
(
self
,
...
...
@@ -85,7 +85,7 @@ class RemoteOpenAIServer:
stdout
=
sys
.
stdout
,
stderr
=
sys
.
stderr
)
self
.
_wait_for_server
(
url
=
self
.
url_for
(
"health"
),
timeout
=
self
.
MAX_
SERVER_
START_WAIT_S
)
timeout
=
self
.
MAX_START_WAIT_S
)
def
__enter__
(
self
):
return
self
...
...
vllm/entrypoints/chat_utils.py
View file @
66d617e3
import
codecs
from
dataclasses
import
dataclass
from
functools
import
lru_cache
from
typing
import
(
Awaitable
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
,
cast
,
final
)
from
pathlib
import
Path
from
typing
import
(
Any
,
Awaitable
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
,
cast
,
final
)
# yapf conflicts with isort for this block
# yapf: disable
...
...
@@ -22,6 +23,7 @@ from vllm.config import ModelConfig
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
MultiModalDataDict
from
vllm.multimodal.utils
import
async_get_and_parse_image
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
logger
=
init_logger
(
__name__
)
...
...
@@ -69,13 +71,17 @@ class ChatMessageParseResult:
mm_futures
:
List
[
Awaitable
[
MultiModalDataDict
]]
def
load_chat_template
(
chat_template
:
Optional
[
str
])
->
Optional
[
str
]:
def
load_chat_template
(
chat_template
:
Optional
[
Union
[
Path
,
str
]])
->
Optional
[
str
]:
if
chat_template
is
None
:
return
None
try
:
with
open
(
chat_template
,
"r"
)
as
f
:
resolved_chat_template
=
f
.
read
()
except
OSError
as
e
:
if
isinstance
(
chat_template
,
Path
):
raise
JINJA_CHARS
=
"{}
\n
"
if
not
any
(
c
in
chat_template
for
c
in
JINJA_CHARS
):
msg
=
(
f
"The supplied chat template (
{
chat_template
}
) "
...
...
@@ -208,3 +214,28 @@ def parse_chat_messages(
mm_futures
.
extend
(
parse_result
.
mm_futures
)
return
conversation
,
mm_futures
def
apply_chat_template
(
tokenizer
:
AnyTokenizer
,
conversation
:
List
[
ConversationMessage
],
chat_template
:
Optional
[
str
],
*
,
tokenize
:
bool
=
False
,
# Different from HF's default
**
kwargs
:
Any
,
)
->
str
:
if
chat_template
is
None
and
tokenizer
.
chat_template
is
None
:
raise
ValueError
(
"As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer "
"does not define one."
)
prompt
=
tokenizer
.
apply_chat_template
(
conversation
=
conversation
,
chat_template
=
chat_template
,
tokenize
=
tokenize
,
**
kwargs
,
)
assert
isinstance
(
prompt
,
str
)
return
prompt
vllm/entrypoints/openai/protocol.py
View file @
66d617e3
...
...
@@ -190,8 +190,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
default
=
None
,
description
=
(
"A Jinja template to use for this conversion. "
"If this is not passed, the model's default chat template will be "
"used instead."
),
"As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer "
"does not define one."
),
)
chat_template_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
Field
(
default
=
None
,
...
...
vllm/entrypoints/openai/serving_chat.py
View file @
66d617e3
...
...
@@ -10,6 +10,7 @@ from transformers import PreTrainedTokenizer
from
vllm.config
import
ModelConfig
from
vllm.engine.protocol
import
AsyncEngineClient
from
vllm.entrypoints.chat_utils
import
(
ConversationMessage
,
apply_chat_template
,
load_chat_template
,
parse_chat_messages
)
from
vllm.entrypoints.logger
import
RequestLogger
...
...
@@ -99,16 +100,15 @@ class OpenAIServingChat(OpenAIServing):
tool
.
model_dump
()
for
tool
in
request
.
tools
]
prompt
=
tokenizer
.
apply_chat_template
(
prompt
=
apply_chat_template
(
tokenizer
,
conversation
=
conversation
,
tokenize
=
Fals
e
,
chat_template
=
request
.
chat_template
or
self
.
chat_templat
e
,
add_generation_prompt
=
request
.
add_generation_prompt
,
tools
=
tool_dicts
,
documents
=
request
.
documents
,
chat_template
=
request
.
chat_template
or
self
.
chat_template
,
**
(
request
.
chat_template_kwargs
or
{}),
)
assert
isinstance
(
prompt
,
str
)
except
Exception
as
e
:
logger
.
error
(
"Error in applying chat template from request: %s"
,
e
)
return
self
.
create_error_response
(
str
(
e
))
...
...
vllm/entrypoints/openai/serving_tokenization.py
View file @
66d617e3
...
...
@@ -2,7 +2,9 @@ from typing import List, Optional, Union
from
vllm.config
import
ModelConfig
from
vllm.engine.protocol
import
AsyncEngineClient
from
vllm.entrypoints.chat_utils
import
load_chat_template
,
parse_chat_messages
from
vllm.entrypoints.chat_utils
import
(
apply_chat_template
,
load_chat_template
,
parse_chat_messages
)
from
vllm.entrypoints.logger
import
RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
...
...
@@ -70,12 +72,12 @@ class OpenAIServingTokenization(OpenAIServing):
logger
.
warning
(
"Multi-modal inputs are ignored during tokenization"
)
prompt
=
tokenizer
.
apply_chat_template
(
add_generation_prompt
=
request
.
add_generation_prompt
,
prompt
=
apply_chat_template
(
tokenizer
,
conversation
=
conversation
,
tokenize
=
Fals
e
,
chat_template
=
self
.
chat_template
)
assert
isinstance
(
prompt
,
str
)
chat_template
=
self
.
chat_templat
e
,
add_generation_prompt
=
request
.
add_generation_prompt
,
)
else
:
prompt
=
request
.
prompt
...
...
vllm/transformers_utils/tokenizer.py
View file @
66d617e3
...
...
@@ -12,12 +12,12 @@ from vllm.lora.request import LoRARequest
from
vllm.transformers_utils.tokenizers
import
BaichuanTokenizer
from
vllm.utils
import
make_async
from
.tokenizer_group
import
AnyTokenizer
logger
=
init_logger
(
__name__
)
def
get_cached_tokenizer
(
tokenizer
:
Union
[
PreTrainedTokenizer
,
PreTrainedTokenizerFast
]
)
->
Union
[
PreTrainedTokenizer
,
PreTrainedTokenizerFast
]:
def
get_cached_tokenizer
(
tokenizer
:
AnyTokenizer
)
->
AnyTokenizer
:
"""Get tokenizer with cached properties.
This will patch the tokenizer object in place.
...
...
@@ -63,7 +63,7 @@ def get_tokenizer(
revision
:
Optional
[
str
]
=
None
,
download_dir
:
Optional
[
str
]
=
None
,
**
kwargs
,
)
->
Union
[
PreTrainedTokenizer
,
PreTrainedTokenizerFast
]
:
)
->
AnyTokenizer
:
"""Gets a tokenizer for the given model name via HuggingFace or ModelScope.
"""
if
VLLM_USE_MODELSCOPE
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment