Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
069d3bd8
Unverified
Commit
069d3bd8
authored
Oct 08, 2024
by
Alex Brooks
Committed by
GitHub
Oct 08, 2024
Browse files
[Frontend] Add Early Validation For Chat Template / Tool Call Parser (#9151)
Signed-off-by:
Alex-Brooks
<
Alex.Brooks@ibm.com
>
parent
a3691b6b
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
155 additions
and
72 deletions
+155
-72
tests/entrypoints/openai/test_cli_args.py
tests/entrypoints/openai/test_cli_args.py
+109
-69
vllm/entrypoints/chat_utils.py
vllm/entrypoints/chat_utils.py
+22
-0
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+3
-1
vllm/entrypoints/openai/cli_args.py
vllm/entrypoints/openai/cli_args.py
+15
-0
vllm/scripts.py
vllm/scripts.py
+6
-2
No files found.
tests/entrypoints/openai/test_cli_args.py
View file @
069d3bd8
import
json
import
unittest
from
vllm.entrypoints.openai.cli_args
import
make_arg_parser
import
pytest
from
vllm.entrypoints.openai.cli_args
import
(
make_arg_parser
,
validate_parsed_serve_args
)
from
vllm.entrypoints.openai.serving_engine
import
LoRAModulePath
from
vllm.utils
import
FlexibleArgumentParser
from
...utils
import
VLLM_PATH
LORA_MODULE
=
{
"name"
:
"module2"
,
"path"
:
"/path/to/module2"
,
"base_model_name"
:
"llama"
}
CHATML_JINJA_PATH
=
VLLM_PATH
/
"examples/template_chatml.jinja"
assert
CHATML_JINJA_PATH
.
exists
()
class
TestLoraParserAction
(
unittest
.
TestCase
):
@pytest.fixture
def serve_parser():
    """Build an argument parser carrying the OpenAI server CLI options.

    A bare ``FlexibleArgumentParser`` is created and handed to
    ``make_arg_parser``; whatever that call returns is the fixture value.
    """
    base_parser = FlexibleArgumentParser(
        description="vLLM's remote OpenAI server.")
    configured_parser = make_arg_parser(base_parser)
    return configured_parser
def
setUp
(
self
):
# Setting up argparse parser for tests
parser
=
FlexibleArgumentParser
(
description
=
"vLLM's remote OpenAI server."
)
self
.
parser
=
make_arg_parser
(
parser
)
def
test_valid_key_value_format
(
self
):
# Test old format: name=path
args
=
self
.
parser
.
parse_args
([
'--lora-modules'
,
'module1=/path/to/module1'
,
### Tests for Lora module parsing
def test_valid_key_value_format(serve_parser):
    """A LoRA module given in the old ``name=path`` form parses to one entry."""
    cli_tokens = ['--lora-modules', 'module1=/path/to/module1']
    parsed = serve_parser.parse_args(cli_tokens)

    wanted = [LoRAModulePath(name='module1', path='/path/to/module1')]
    assert parsed.lora_modules == wanted
def test_valid_json_format(serve_parser):
    """A LoRA module given as a JSON object parses to one entry."""
    # LORA_MODULE is serialized so the parser receives a JSON string.
    json_spec = json.dumps(LORA_MODULE)
    parsed = serve_parser.parse_args(['--lora-modules', json_spec])

    wanted = [
        LoRAModulePath(
            name='module2',
            path='/path/to/module2',
            base_model_name='llama',
        )
    ]
    assert parsed.lora_modules == wanted
def test_invalid_json_format(serve_parser):
    """Malformed JSON for ``--lora-modules`` makes parsing exit."""
    # The closing brace is deliberately missing from this JSON text.
    truncated_json = '{"name": "module3", "path": "/path/to/module3"'
    with pytest.raises(SystemExit):
        serve_parser.parse_args(['--lora-modules', truncated_json])
expected
=
[
LoRAModulePath
(
name
=
'module1'
,
path
=
'/path/to/module1'
)]
self
.
assertEqual
(
args
.
lora_modules
,
expected
)
def
test_valid_json_format
(
self
):
# Test valid JSON format input
args
=
self
.
parser
.
parse_args
([
def
test_invalid_type_error
(
serve_parser
):
# Test type error when values are not JSON or key=value
with
pytest
.
raises
(
SystemExit
):
serve_parser
.
parse_args
([
'--lora-modules'
,
json
.
dumps
(
LORA_MODULE
),
'invalid_format'
# This is not JSON or key=value format
])
expected
=
[
LoRAModulePath
(
name
=
'module2'
,
path
=
'/path/to/module2'
,
base_model_name
=
'llama'
)
]
self
.
assertEqual
(
args
.
lora_modules
,
expected
)
def
test_invalid_json_format
(
self
):
# Test invalid JSON format input, missing closing brace
with
self
.
assertRaises
(
SystemExit
):
self
.
parser
.
parse_args
([
'--lora-modules'
,
'{"name": "module3", "path": "/path/to/module3"'
])
def
test_invalid_type_error
(
self
):
# Test type error when values are not JSON or key=value
with
self
.
assertRaises
(
SystemExit
):
self
.
parser
.
parse_args
([
'--lora-modules'
,
'invalid_format'
# This is not JSON or key=value format
])
def
test_invalid_json_field
(
self
):
# Test valid JSON format but missing required fields
with
self
.
assertRaises
(
SystemExit
):
self
.
parser
.
parse_args
([
'--lora-modules'
,
'{"name": "module4"}'
# Missing required 'path' field
])
def
test_empty_values
(
self
):
# Test when no LoRA modules are provided
args
=
self
.
parser
.
parse_args
([
'--lora-modules'
,
''
])
self
.
assertEqual
(
args
.
lora_modules
,
[])
def
test_multiple_valid_inputs
(
self
):
# Test multiple valid inputs (both old and JSON format)
args
=
self
.
parser
.
parse_args
([
def
test_invalid_json_field
(
serve_parser
):
# Test valid JSON format but missing required fields
with
pytest
.
raises
(
SystemExit
):
serve_parser
.
parse_args
([
'--lora-modules'
,
'module1=/path/to/module1'
,
json
.
dumps
(
LORA_MODULE
),
'{"name": "module4"}'
# Missing required 'path' field
])
expected
=
[
LoRAModulePath
(
name
=
'module1'
,
path
=
'/path/to/module1'
),
LoRAModulePath
(
name
=
'module2'
,
path
=
'/path/to/module2'
,
base_model_name
=
'llama'
)
]
self
.
assertEqual
(
args
.
lora_modules
,
expected
)
if
__name__
==
'__main__'
:
unittest
.
main
()
def test_empty_values(serve_parser):
    """An empty ``--lora-modules`` value yields an empty module list."""
    parsed = serve_parser.parse_args(['--lora-modules', ''])
    # No LoRA modules were supplied, so nothing should be registered.
    assert parsed.lora_modules == []
def test_multiple_valid_inputs(serve_parser):
    """Old-style and JSON module specs can be mixed in one invocation."""
    cli_tokens = [
        '--lora-modules',
        'module1=/path/to/module1',  # old name=path format
        json.dumps(LORA_MODULE),  # JSON format
    ]
    parsed = serve_parser.parse_args(cli_tokens)

    # Entries are expected in the same order they were given.
    first = LoRAModulePath(name='module1', path='/path/to/module1')
    second = LoRAModulePath(
        name='module2',
        path='/path/to/module2',
        base_model_name='llama',
    )
    wanted = [first, second]
    assert parsed.lora_modules == wanted
### Tests for serve argument validation that run prior to loading
def test_enable_auto_choice_passes_without_tool_call_parser(serve_parser):
    """Auto tool choice without a tool call parser must fail validation.

    Despite the "passes" in the test name, the expectation here is that
    ``validate_parsed_serve_args`` raises ``TypeError``.
    """
    parsed = serve_parser.parse_args(args=["--enable-auto-tool-choice"])
    # --enable-auto-tool-choice was given without --tool-call-parser.
    with pytest.raises(TypeError):
        validate_parsed_serve_args(parsed)
def test_enable_auto_choice_passes_with_tool_call_parser(serve_parser):
    """Auto tool choice together with a tool call parser validates cleanly."""
    cli_tokens = [
        "--enable-auto-tool-choice",
        "--tool-call-parser",
        "mistral",
    ]
    parsed = serve_parser.parse_args(args=cli_tokens)
    # Passing means no exception is raised by the validation step.
    validate_parsed_serve_args(parsed)
def test_chat_template_validation_for_happy_paths(serve_parser):
    """Validation accepts a chat template path that exists on disk."""
    # The template location is passed as an absolute POSIX-style string.
    template_location = CHATML_JINJA_PATH.absolute().as_posix()
    parsed = serve_parser.parse_args(
        args=["--chat-template", template_location])
    validate_parsed_serve_args(parsed)
def test_chat_template_validation_for_sad_paths(serve_parser):
    """Validation rejects a chat template argument naming a missing file."""
    missing_template = "does/not/exist"
    parsed = serve_parser.parse_args(
        args=["--chat-template", missing_template])
    with pytest.raises(ValueError):
        validate_parsed_serve_args(parsed)
vllm/entrypoints/chat_utils.py
View file @
069d3bd8
...
...
@@ -303,6 +303,28 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
self
.
_add_placeholder
(
placeholder
)
def validate_chat_template(chat_template: Optional[Union[Path, str]]):
    """Raise if the provided chat template appears invalid.

    Args:
        chat_template: ``None`` (nothing to check), a ``Path`` to a template
            file, or a ``str`` holding either an inline template or a path.

    Raises:
        FileNotFoundError: ``chat_template`` is a ``Path`` that does not exist.
        ValueError: ``chat_template`` is a ``str`` that contains none of the
            characters ``{``, ``}`` or newline and does not name an existing
            path.
        TypeError: ``chat_template`` is neither ``None``, ``Path`` nor ``str``.
    """
    if chat_template is None:
        return

    if isinstance(chat_template, Path):
        if not chat_template.exists():
            raise FileNotFoundError(
                "the supplied chat template path doesn't exist")
        # An existing Path is valid. Returning here keeps it from falling
        # through to the unsupported-type error below.
        return

    if isinstance(chat_template, str):
        # A string containing any of these characters is treated as an inline
        # template; otherwise it is treated as a path and must exist.
        JINJA_CHARS = "{}\n"
        looks_inline = any(c in chat_template for c in JINJA_CHARS)
        if not looks_inline and not Path(chat_template).exists():
            raise ValueError(
                f"The supplied chat template string ({chat_template}) "
                f"appears path-like, but doesn't exist!")
        return

    raise TypeError(
        f"{type(chat_template)} is not a valid chat template type")
def
load_chat_template
(
chat_template
:
Optional
[
Union
[
Path
,
str
]])
->
Optional
[
str
]:
if
chat_template
is
None
:
...
...
vllm/entrypoints/openai/api_server.py
View file @
069d3bd8
...
...
@@ -31,7 +31,8 @@ from vllm.engine.multiprocessing.engine import run_mp_engine
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.launcher
import
serve_http
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.openai.cli_args
import
make_arg_parser
from
vllm.entrypoints.openai.cli_args
import
(
make_arg_parser
,
validate_parsed_serve_args
)
# yapf conflicts with isort for this block
# yapf: disable
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
...
...
@@ -577,5 +578,6 @@ if __name__ == "__main__":
description
=
"vLLM OpenAI-Compatible RESTful API server."
)
parser
=
make_arg_parser
(
parser
)
args
=
parser
.
parse_args
()
validate_parsed_serve_args
(
args
)
uvloop
.
run
(
run_server
(
args
))
vllm/entrypoints/openai/cli_args.py
View file @
069d3bd8
...
...
@@ -10,6 +10,7 @@ import ssl
from
typing
import
List
,
Optional
,
Sequence
,
Union
from
vllm.engine.arg_utils
import
AsyncEngineArgs
,
nullable_str
from
vllm.entrypoints.chat_utils
import
validate_chat_template
from
vllm.entrypoints.openai.serving_engine
import
(
LoRAModulePath
,
PromptAdapterPath
)
from
vllm.entrypoints.openai.tool_parsers
import
ToolParserManager
...
...
@@ -231,6 +232,20 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
return
parser
def validate_parsed_serve_args(args: argparse.Namespace):
    """Run cheap sanity checks on serve arguments before anything is loaded.

    Nothing is checked when ``args`` has a ``subparser`` attribute whose
    value is not ``"serve"``. Otherwise the chat template is validated, and
    auto tool choice is required to come with a tool call parser.

    Raises:
        TypeError: ``--enable-auto-tool-choice`` is set but no
            ``--tool-call-parser`` was given.
        Exception: whatever ``validate_chat_template`` raises for
            ``args.chat_template``.
    """
    # A namespace without a "subparser" attribute is treated like "serve".
    selected_command = getattr(args, "subparser", "serve")
    if selected_command != "serve":
        return

    # Raises if the chat template is likely invalid.
    validate_chat_template(args.chat_template)

    # Auto tool choice is only valid together with a tool call parser.
    if args.enable_auto_tool_choice:
        if not args.tool_call_parser:
            raise TypeError("Error: --enable-auto-tool-choice requires "
                            "--tool-call-parser")
def
create_parser_for_docs
()
->
FlexibleArgumentParser
:
parser_for_docs
=
FlexibleArgumentParser
(
prog
=
"-m vllm.entrypoints.openai.api_server"
)
...
...
vllm/scripts.py
View file @
069d3bd8
...
...
@@ -11,7 +11,8 @@ from openai.types.chat import ChatCompletionMessageParam
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.entrypoints.openai.api_server
import
run_server
from
vllm.entrypoints.openai.cli_args
import
make_arg_parser
from
vllm.entrypoints.openai.cli_args
import
(
make_arg_parser
,
validate_parsed_serve_args
)
from
vllm.logger
import
init_logger
from
vllm.utils
import
FlexibleArgumentParser
...
...
@@ -142,7 +143,7 @@ def main():
env_setup
()
parser
=
FlexibleArgumentParser
(
description
=
"vLLM CLI"
)
subparsers
=
parser
.
add_subparsers
(
required
=
True
)
subparsers
=
parser
.
add_subparsers
(
required
=
True
,
dest
=
"subparser"
)
serve_parser
=
subparsers
.
add_parser
(
"serve"
,
...
...
@@ -186,6 +187,9 @@ def main():
chat_parser
.
set_defaults
(
dispatch_function
=
interactive_cli
,
command
=
"chat"
)
args
=
parser
.
parse_args
()
if
args
.
subparser
==
"serve"
:
validate_parsed_serve_args
(
args
)
# One of the sub commands should be executed.
if
hasattr
(
args
,
"dispatch_function"
):
args
.
dispatch_function
(
args
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment