Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8d370e91
Unverified
Commit
8d370e91
authored
Dec 04, 2024
by
Michael Goin
Committed by
GitHub
Dec 05, 2024
Browse files
[Bugfix] Fallback to outlines for complex json schemas (#10899)
Signed-off-by:
mgoin
<
michael@neuralmagic.com
>
parent
7883c2bb
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
102 additions
and
0 deletions
+102
-0
tests/entrypoints/conftest.py
tests/entrypoints/conftest.py
+31
-0
tests/entrypoints/llm/test_guided_generate.py
tests/entrypoints/llm/test_guided_generate.py
+28
-0
vllm/model_executor/guided_decoding/__init__.py
vllm/model_executor/guided_decoding/__init__.py
+43
-0
No files found.
tests/entrypoints/conftest.py
View file @
8d370e91
...
...
@@ -69,6 +69,37 @@ def sample_json_schema():
}
@
pytest
.
fixture
def
sample_complex_json_schema
():
return
{
"type"
:
"object"
,
"properties"
:
{
"score"
:
{
"type"
:
"integer"
,
"minimum"
:
0
,
"maximum"
:
100
# Numeric range
},
"grade"
:
{
"type"
:
"string"
,
"pattern"
:
"^[A-D]$"
# Regex pattern
},
"email"
:
{
"type"
:
"string"
,
"pattern"
:
"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+
\\
.[a-zA-Z]{2,}$"
},
"tags"
:
{
"type"
:
"array"
,
"items"
:
{
"type"
:
"string"
,
"pattern"
:
"^[a-z]{1,10}$"
# Combining length and pattern restrictions
}
}
},
"required"
:
[
"score"
,
"grade"
,
"email"
,
"tags"
]
}
@
pytest
.
fixture
def
sample_guided_choice
():
return
[
...
...
tests/entrypoints/llm/test_guided_generate.py
View file @
8d370e91
...
...
@@ -76,6 +76,34 @@ def test_guided_json_completion(sample_json_schema, llm):
jsonschema
.
validate
(
instance
=
output_json
,
schema
=
sample_json_schema
)
@
pytest
.
mark
.
skip_global_cleanup
def
test_guided_complex_json_completion
(
sample_complex_json_schema
,
llm
):
sampling_params
=
SamplingParams
(
temperature
=
1.0
,
max_tokens
=
1000
,
guided_decoding
=
GuidedDecodingParams
(
json
=
sample_complex_json_schema
))
outputs
=
llm
.
generate
(
prompts
=
[
f
"Give an example JSON for an assignment grade "
f
"that fits this schema:
{
sample_complex_json_schema
}
"
]
*
2
,
sampling_params
=
sampling_params
,
use_tqdm
=
True
)
assert
outputs
is
not
None
for
output
in
outputs
:
assert
output
is
not
None
assert
isinstance
(
output
,
RequestOutput
)
prompt
=
output
.
prompt
generated_text
=
output
.
outputs
[
0
].
text
assert
generated_text
is
not
None
print
(
f
"Prompt:
{
prompt
!
r
}
, Generated text:
{
generated_text
!
r
}
"
)
output_json
=
json
.
loads
(
generated_text
)
jsonschema
.
validate
(
instance
=
output_json
,
schema
=
sample_complex_json_schema
)
@
pytest
.
mark
.
skip_global_cleanup
def
test_guided_choice_completion
(
sample_guided_choice
,
llm
):
sampling_params
=
SamplingParams
(
...
...
vllm/model_executor/guided_decoding/__init__.py
View file @
8d370e91
...
...
@@ -15,6 +15,40 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
def
has_xgrammar_unsupported_json_features
(
schema
:
dict
)
->
bool
:
"""Check if JSON schema contains features unsupported by xgrammar."""
def
check_object
(
obj
:
dict
)
->
bool
:
if
not
isinstance
(
obj
,
dict
):
return
False
# Check for pattern restrictions
if
"pattern"
in
obj
:
return
True
# Check for numeric ranges
if
obj
.
get
(
"type"
)
in
(
"integer"
,
"number"
)
and
any
(
key
in
obj
for
key
in
[
"minimum"
,
"maximum"
,
"exclusiveMinimum"
,
"exclusiveMaximum"
,
"multipleOf"
]):
return
True
# Recursively check all nested objects and arrays
for
value
in
obj
.
values
():
if
isinstance
(
value
,
dict
):
if
check_object
(
value
):
return
True
elif
isinstance
(
value
,
list
):
for
item
in
value
:
if
isinstance
(
item
,
dict
)
and
check_object
(
item
):
return
True
return
False
return
check_object
(
schema
)
def
maybe_backend_fallback
(
guided_params
:
GuidedDecodingParams
)
->
GuidedDecodingParams
:
# lm-format-enforce doesn't support grammar, fallback to xgrammar
...
...
@@ -47,6 +81,15 @@ def maybe_backend_fallback(
"Falling back to use outlines instead."
)
guided_params
.
backend
=
"outlines"
# xgrammar doesn't support some JSON schema features
elif
(
guided_params
.
json
is
not
None
and
has_xgrammar_unsupported_json_features
(
guided_params
.
json
)):
logger
.
warning
(
"xgrammar does not support advanced JSON schema features like "
"patterns or numeric ranges. "
"Falling back to use outlines instead."
)
guided_params
.
backend
=
"outlines"
return
guided_params
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment