Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a30482f0
Unverified
Commit
a30482f0
authored
Dec 18, 2024
by
Michael Goin
Committed by
GitHub
Dec 19, 2024
Browse files
[CI] Expand test_guided_generate to test all backends (#11313)
Signed-off-by:
mgoin
<
michael@neuralmagic.com
>
parent
17ca9642
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
129 additions
and
51 deletions
+129
-51
tests/entrypoints/llm/test_guided_generate.py
tests/entrypoints/llm/test_guided_generate.py
+69
-43
tests/model_executor/test_guided_processors.py
tests/model_executor/test_guided_processors.py
+2
-2
vllm/model_executor/guided_decoding/__init__.py
vllm/model_executor/guided_decoding/__init__.py
+58
-6
No files found.
tests/entrypoints/llm/test_guided_generate.py
View file @
a30482f0
...
...
@@ -10,7 +10,8 @@ from vllm.entrypoints.llm import LLM
from
vllm.outputs
import
RequestOutput
from
vllm.sampling_params
import
GuidedDecodingParams
,
SamplingParams
MODEL_NAME
=
"HuggingFaceH4/zephyr-7b-beta"
MODEL_NAME
=
"Qwen/Qwen2.5-7B-Instruct"
GUIDED_DECODING_BACKENDS
=
[
"outlines"
,
"lm-format-enforcer"
,
"xgrammar"
]
@
pytest
.
fixture
(
scope
=
"module"
)
...
...
@@ -26,11 +27,13 @@ def llm():
@
pytest
.
mark
.
skip_global_cleanup
def
test_guided_regex
(
sample_regex
,
llm
):
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
GUIDED_DECODING_BACKENDS
)
def
test_guided_regex
(
sample_regex
,
llm
,
guided_decoding_backend
:
str
):
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
top_p
=
0.95
,
guided_decoding
=
GuidedDecodingParams
(
regex
=
sample_regex
))
guided_decoding
=
GuidedDecodingParams
(
regex
=
sample_regex
,
backend
=
guided_decoding_backend
))
outputs
=
llm
.
generate
(
prompts
=
[
f
"Give an example IPv4 address with this regex:
{
sample_regex
}
"
]
*
2
,
...
...
@@ -50,11 +53,14 @@ def test_guided_regex(sample_regex, llm):
@
pytest
.
mark
.
skip_global_cleanup
def
test_guided_json_completion
(
sample_json_schema
,
llm
):
sampling_params
=
SamplingParams
(
temperature
=
1.0
,
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
GUIDED_DECODING_BACKENDS
)
def
test_guided_json_completion
(
sample_json_schema
,
llm
,
guided_decoding_backend
:
str
):
sampling_params
=
SamplingParams
(
temperature
=
1.0
,
max_tokens
=
1000
,
guided_decoding
=
GuidedDecodingParams
(
json
=
sample_json_schema
))
guided_decoding
=
GuidedDecodingParams
(
json
=
sample_json_schema
,
backend
=
guided_decoding_backend
))
outputs
=
llm
.
generate
(
prompts
=
[
f
"Give an example JSON for an employee profile "
f
"that fits this schema:
{
sample_json_schema
}
"
...
...
@@ -77,11 +83,14 @@ def test_guided_json_completion(sample_json_schema, llm):
@
pytest
.
mark
.
skip_global_cleanup
def
test_guided_complex_json_completion
(
sample_complex_json_schema
,
llm
):
sampling_params
=
SamplingParams
(
temperature
=
1.0
,
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
GUIDED_DECODING_BACKENDS
)
def
test_guided_complex_json_completion
(
sample_complex_json_schema
,
llm
,
guided_decoding_backend
:
str
):
sampling_params
=
SamplingParams
(
temperature
=
1.0
,
max_tokens
=
1000
,
guided_decoding
=
GuidedDecodingParams
(
json
=
sample_complex_json_schema
))
guided_decoding
=
GuidedDecodingParams
(
json
=
sample_complex_json_schema
,
backend
=
guided_decoding_backend
))
outputs
=
llm
.
generate
(
prompts
=
[
f
"Give an example JSON for an assignment grade "
f
"that fits this schema:
{
sample_complex_json_schema
}
"
...
...
@@ -105,11 +114,14 @@ def test_guided_complex_json_completion(sample_complex_json_schema, llm):
@
pytest
.
mark
.
skip_global_cleanup
def
test_guided_definition_json_completion
(
sample_definition_json_schema
,
llm
):
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
GUIDED_DECODING_BACKENDS
)
def
test_guided_definition_json_completion
(
sample_definition_json_schema
,
llm
,
guided_decoding_backend
:
str
):
sampling_params
=
SamplingParams
(
temperature
=
1.0
,
max_tokens
=
1000
,
guided_decoding
=
GuidedDecodingParams
(
json
=
sample_definition_json_schema
))
json
=
sample_definition_json_schema
,
backend
=
guided_decoding_backend
))
outputs
=
llm
.
generate
(
prompts
=
[
f
"Give an example JSON for solving 8x + 7 = -23 "
f
"that fits this schema:
{
sample_definition_json_schema
}
"
...
...
@@ -133,11 +145,14 @@ def test_guided_definition_json_completion(sample_definition_json_schema, llm):
@
pytest
.
mark
.
skip_global_cleanup
def
test_guided_choice_completion
(
sample_guided_choice
,
llm
):
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
GUIDED_DECODING_BACKENDS
)
def
test_guided_choice_completion
(
sample_guided_choice
,
llm
,
guided_decoding_backend
:
str
):
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
top_p
=
0.95
,
guided_decoding
=
GuidedDecodingParams
(
choice
=
sample_guided_choice
))
guided_decoding
=
GuidedDecodingParams
(
choice
=
sample_guided_choice
,
backend
=
guided_decoding_backend
))
outputs
=
llm
.
generate
(
prompts
=
"The best language for type-safe systems programming is "
,
sampling_params
=
sampling_params
,
...
...
@@ -156,13 +171,20 @@ def test_guided_choice_completion(sample_guided_choice, llm):
@
pytest
.
mark
.
skip_global_cleanup
def
test_guided_grammar
(
sample_sql_statements
,
llm
):
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
GUIDED_DECODING_BACKENDS
)
def
test_guided_grammar
(
sample_sql_statements
,
llm
,
guided_decoding_backend
:
str
):
if
guided_decoding_backend
==
"outlines"
:
pytest
.
skip
(
"Outlines backend fails in this test case with:
\n
"
"AttributeError: Error in model execution: 'ParserConf' "
"object has no attribute 'deterministic'"
)
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
top_p
=
0.95
,
max_tokens
=
1000
,
guided_decoding
=
GuidedDecodingParams
(
grammar
=
sample_sql_statements
))
guided_decoding
=
GuidedDecodingParams
(
grammar
=
sample_sql_statements
,
backend
=
guided_decoding_backend
))
outputs
=
llm
.
generate
(
prompts
=
(
"Generate a sql state that select col_1 from "
"table_1 where it is equals to 1"
),
...
...
@@ -218,15 +240,18 @@ def test_validation_against_both_guided_decoding_options(sample_regex, llm):
@
pytest
.
mark
.
skip_global_cleanup
def
test_guided_json_object
(
llm
):
sampling_params
=
SamplingParams
(
temperature
=
1.0
,
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
GUIDED_DECODING_BACKENDS
)
def
test_guided_json_object
(
llm
,
guided_decoding_backend
:
str
):
sampling_params
=
SamplingParams
(
temperature
=
1.0
,
max_tokens
=
100
,
guided_decoding
=
GuidedDecodingParams
(
json_object
=
True
))
n
=
2
,
guided_decoding
=
GuidedDecodingParams
(
json_object
=
True
,
backend
=
guided_decoding_backend
))
outputs
=
llm
.
generate
(
prompts
=
(
"Generate a JSON object
describing
a person with
name
"
"and age for John Smith who is 31 years old."
),
prompts
=
(
"Generate a JSON object
with curly braces for
a person with "
"
name
and age
fields
for John Smith who is 31 years old."
),
sampling_params
=
sampling_params
,
use_tqdm
=
True
)
...
...
@@ -235,7 +260,8 @@ def test_guided_json_object(llm):
assert
output
is
not
None
assert
isinstance
(
output
,
RequestOutput
)
generated_text
=
output
.
outputs
[
0
].
text
for
i
in
range
(
2
):
generated_text
=
output
.
outputs
[
i
].
text
print
(
generated_text
)
assert
generated_text
is
not
None
...
...
tests/model_executor/test_guided_processors.py
View file @
a30482f0
...
...
@@ -13,6 +13,7 @@ from vllm.model_executor.guided_decoding.outlines_logits_processors import (
from
vllm.sampling_params
import
GuidedDecodingParams
MODEL_NAME
=
'HuggingFaceH4/zephyr-7b-beta'
GUIDED_DECODING_BACKENDS
=
[
"outlines"
,
"lm-format-enforcer"
,
"xgrammar"
]
def
test_guided_logits_processors
(
sample_regex
,
sample_json_schema
):
...
...
@@ -42,8 +43,7 @@ def test_guided_logits_processors(sample_regex, sample_json_schema):
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"outlines"
,
"lm-format-enforcer"
,
"xgrammar"
])
@
pytest
.
mark
.
parametrize
(
"backend"
,
GUIDED_DECODING_BACKENDS
)
@
pytest
.
mark
.
parametrize
(
"is_local"
,
[
True
,
False
])
async
def
test_guided_logits_processor_black_box
(
backend
:
str
,
is_local
:
bool
,
sample_regex
,
...
...
vllm/model_executor/guided_decoding/__init__.py
View file @
a30482f0
...
...
@@ -49,16 +49,61 @@ def has_xgrammar_unsupported_json_features(schema: dict) -> bool:
return
check_object
(
schema
)
def
has_lmf_unsupported_json_features
(
schema
:
dict
)
->
bool
:
"""
Check if JSON schema contains features unsupported
by lm_format_enforcer.
Known issues:
- Regex patterns:
"grade": {
"type": "string",
"pattern": "^[A-D]$" # Regex pattern
},
"""
def
check_object
(
obj
:
dict
)
->
bool
:
if
not
isinstance
(
obj
,
dict
):
return
False
# Check for pattern restrictions
if
"pattern"
in
obj
:
return
True
# Recursively check all nested objects and arrays
for
value
in
obj
.
values
():
if
isinstance
(
value
,
dict
):
if
check_object
(
value
):
return
True
elif
isinstance
(
value
,
list
):
for
item
in
value
:
if
isinstance
(
item
,
dict
)
and
check_object
(
item
):
return
True
return
False
return
check_object
(
schema
)
def
maybe_backend_fallback
(
guided_params
:
GuidedDecodingParams
)
->
GuidedDecodingParams
:
# lm-format-enforce doesn't support grammar, fallback to xgrammar
if
(
guided_params
.
backend
==
"lm-format-enforcer"
and
guided_params
.
grammar
is
not
None
)
:
if
guided_params
.
backend
==
"lm-format-enforcer"
:
if
guided_params
.
grammar
is
not
None
:
logger
.
warning
(
"lm-format-enforcer does not support grammar guided decoding. "
"Falling back to use xgrammar instead."
)
guided_params
.
backend
=
"xgrammar"
# lm-format-enforcer doesn't support some JSON schema features
elif
(
guided_params
.
json
is
not
None
and
has_lmf_unsupported_json_features
(
guided_params
.
json
)):
logger
.
warning
(
"lm-format-enforcer does not support advanced JSON schema "
"features like patterns or numeric ranges. "
"Falling back to use outlines instead."
)
guided_params
.
backend
=
"outlines"
if
guided_params
.
backend
==
"xgrammar"
:
# xgrammar only has x86 wheels for linux, fallback to outlines
if
current_platform
.
get_cpu_architecture
()
is
not
CpuArchEnum
.
X86
:
...
...
@@ -82,6 +127,13 @@ def maybe_backend_fallback(
"Falling back to use outlines instead."
)
guided_params
.
backend
=
"outlines"
if
(
guided_params
.
backend
==
"outlines"
and
guided_params
.
json_object
is
not
None
):
# outlines doesn't support json_object, fallback to xgrammar
logger
.
warning
(
"outlines does not support json_object. "
"Falling back to use xgrammar instead."
)
guided_params
.
backend
=
"xgrammar"
return
guided_params
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment