Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
dc1b4a6f
Unverified
Commit
dc1b4a6f
authored
Apr 13, 2025
by
Russell Bryant
Committed by
GitHub
Apr 14, 2025
Browse files
[Core][V0] Enable regex support with xgrammar (#13228)
Signed-off-by:
Russell Bryant
<
rbryant@redhat.com
>
parent
63d2705e
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
25 additions
and
9 deletions
+25
-9
tests/entrypoints/llm/test_guided_generate.py
tests/entrypoints/llm/test_guided_generate.py
+13
-2
vllm/model_executor/guided_decoding/__init__.py
vllm/model_executor/guided_decoding/__init__.py
+2
-7
vllm/model_executor/guided_decoding/xgrammar_decoding.py
vllm/model_executor/guided_decoding/xgrammar_decoding.py
+10
-0
No files found.
tests/entrypoints/llm/test_guided_generate.py
View file @
dc1b4a6f
...
@@ -286,15 +286,26 @@ def test_validation_against_both_guided_decoding_options(sample_regex, llm):
...
@@ -286,15 +286,26 @@ def test_validation_against_both_guided_decoding_options(sample_regex, llm):
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
skip_global_cleanup
def
test_disable_guided_decoding_fallback
(
sample_regex
,
llm
):
def
test_disable_guided_decoding_fallback
(
sample_regex
,
llm
):
# see has_xgrammar_unsupported_json_features()
unsupported_json
=
{
"type"
:
"object"
,
"properties"
:
{
"example"
:
{
"type"
:
"string"
,
"minLength"
:
5
# unsupported by xgrammar
}
}
}
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
top_p
=
0.95
,
top_p
=
0.95
,
guided_decoding
=
GuidedDecodingParams
(
guided_decoding
=
GuidedDecodingParams
(
regex
=
sample_regex
,
json
=
unsupported_json
,
backend
=
"xgrammar:no-fallback"
))
backend
=
"xgrammar:no-fallback"
))
with
pytest
.
raises
(
with
pytest
.
raises
(
ValueError
,
ValueError
,
match
=
"xgrammar does not support regex guided decoding"
):
match
=
"xgrammar does not support advanced JSON schema features "
"like enums, patterns or numeric ranges."
):
llm
.
generate
(
prompts
=
"This should fail"
,
llm
.
generate
(
prompts
=
"This should fail"
,
sampling_params
=
sampling_params
,
sampling_params
=
sampling_params
,
use_tqdm
=
True
)
use_tqdm
=
True
)
...
...
vllm/model_executor/guided_decoding/__init__.py
View file @
dc1b4a6f
...
@@ -59,14 +59,9 @@ def maybe_backend_fallback(
...
@@ -59,14 +59,9 @@ def maybe_backend_fallback(
from
vllm.model_executor.guided_decoding.xgrammar_decoding
import
(
from
vllm.model_executor.guided_decoding.xgrammar_decoding
import
(
xgr_installed
)
xgr_installed
)
# xgrammar doesn't support regex, fallback to outlines
if
guided_params
.
regex
is
not
None
:
fallback_or_error
(
guided_params
,
"xgrammar does not support regex guided decoding."
,
"outlines"
)
# xgrammar doesn't support some JSON schema features
# xgrammar doesn't support some JSON schema features
el
if
(
guided_params
.
json
is
not
None
if
(
guided_params
.
json
is
not
None
and
and
has_xgrammar_unsupported_json_features
(
guided_params
.
json
)):
has_xgrammar_unsupported_json_features
(
guided_params
.
json
)):
fallback_or_error
(
fallback_or_error
(
guided_params
,
guided_params
,
"xgrammar does not support advanced JSON schema features like "
"xgrammar does not support advanced JSON schema features like "
...
...
vllm/model_executor/guided_decoding/xgrammar_decoding.py
View file @
dc1b4a6f
...
@@ -152,6 +152,7 @@ class GrammarConfig:
...
@@ -152,6 +152,7 @@ class GrammarConfig:
grammar_str
:
str
|
None
=
None
grammar_str
:
str
|
None
=
None
json_object
:
bool
|
None
=
None
json_object
:
bool
|
None
=
None
any_whitespace
:
bool
=
True
any_whitespace
:
bool
=
True
regex_str
:
str
|
None
=
None
max_threads
:
int
=
8
max_threads
:
int
=
8
@
classmethod
@
classmethod
...
@@ -255,6 +256,13 @@ class GrammarConfig:
...
@@ -255,6 +256,13 @@ class GrammarConfig:
max_threads
=
max_threads
,
max_threads
=
max_threads
,
tokenizer_data
=
tokenizer_data
,
tokenizer_data
=
tokenizer_data
,
)
)
elif
guided_params
.
regex
:
return
cls
(
regex_str
=
guided_params
.
regex
,
tokenizer_hash
=
tokenizer_hash
,
max_threads
=
max_threads
,
tokenizer_data
=
tokenizer_data
,
)
else
:
else
:
raise
ValueError
(
raise
ValueError
(
"Currently only support JSON and EBNF grammar mode for xgrammar"
"Currently only support JSON and EBNF grammar mode for xgrammar"
...
@@ -330,6 +338,8 @@ class XGrammarLogitsProcessor:
...
@@ -330,6 +338,8 @@ class XGrammarLogitsProcessor:
self
.
ctx
=
compiler
\
self
.
ctx
=
compiler
\
.
compile_json_schema
(
'{"type": "object"}'
,
.
compile_json_schema
(
'{"type": "object"}'
,
any_whitespace
=
any_whitespace
)
any_whitespace
=
any_whitespace
)
elif
self
.
config
.
regex_str
:
self
.
ctx
=
compiler
.
compile_regex
(
self
.
config
.
regex_str
)
else
:
else
:
raise
ValueError
(
raise
ValueError
(
"Invalid configuration for xgrammar logits processor"
)
"Invalid configuration for xgrammar logits processor"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment