Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7329ff54
Unverified
Commit
7329ff54
authored
Mar 28, 2025
by
Russell Bryant
Committed by
GitHub
Mar 28, 2025
Browse files
[V1] Support disable_any_whtespace for guidance backend (#15584)
Signed-off-by:
Russell Bryant
<
rbryant@redhat.com
>
parent
541d1df4
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
44 additions
and
117 deletions
+44
-117
tests/entrypoints/llm/test_guided_generate.py
tests/entrypoints/llm/test_guided_generate.py
+7
-55
tests/v1/entrypoints/llm/test_struct_output_generate.py
tests/v1/entrypoints/llm/test_struct_output_generate.py
+7
-47
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+2
-1
vllm/model_executor/guided_decoding/guidance_decoding.py
vllm/model_executor/guided_decoding/guidance_decoding.py
+10
-2
vllm/v1/engine/processor.py
vllm/v1/engine/processor.py
+6
-5
vllm/v1/structured_output/backend_guidance.py
vllm/v1/structured_output/backend_guidance.py
+12
-7
No files found.
tests/entrypoints/llm/test_guided_generate.py
View file @
7329ff54
...
@@ -6,7 +6,6 @@ import weakref
...
@@ -6,7 +6,6 @@ import weakref
import
jsonschema
import
jsonschema
import
pytest
import
pytest
from
pydantic
import
BaseModel
from
vllm.distributed
import
cleanup_dist_env_and_memory
from
vllm.distributed
import
cleanup_dist_env_and_memory
from
vllm.entrypoints.llm
import
LLM
from
vllm.entrypoints.llm
import
LLM
...
@@ -15,7 +14,10 @@ from vllm.sampling_params import GuidedDecodingParams, SamplingParams
...
@@ -15,7 +14,10 @@ from vllm.sampling_params import GuidedDecodingParams, SamplingParams
MODEL_NAME
=
"Qwen/Qwen2.5-1.5B-Instruct"
MODEL_NAME
=
"Qwen/Qwen2.5-1.5B-Instruct"
GUIDED_DECODING_BACKENDS
=
[
GUIDED_DECODING_BACKENDS
=
[
"outlines"
,
"lm-format-enforcer"
,
"xgrammar"
,
"guidance"
"outlines"
,
"lm-format-enforcer"
,
"xgrammar:disable-any-whitespace"
,
"guidance:disable-any-whitespace"
,
]
]
...
@@ -322,59 +324,9 @@ def test_guided_json_object(llm, guided_decoding_backend: str):
...
@@ -322,59 +324,9 @@ def test_guided_json_object(llm, guided_decoding_backend: str):
print
(
generated_text
)
print
(
generated_text
)
assert
generated_text
is
not
None
assert
generated_text
is
not
None
# Parse to verify it is valid JSON
if
'disable-any-whitespace'
in
guided_decoding_backend
:
parsed_json
=
json
.
loads
(
generated_text
)
assert
isinstance
(
parsed_json
,
dict
)
@
pytest
.
mark
.
skip_global_cleanup
def
test_json_with_any_whitespace_disabled
(
llm
):
class
ResponseSchema
(
BaseModel
):
clarifying_question
:
str
cost_per_serving
:
str
calories
:
str
type_dish_ids
:
str
type_meal_ids
:
str
product_ids
:
list
[
str
]
exclude_product_ids
:
list
[
str
]
allergen_ids
:
list
[
str
]
total_cooking_time
:
str
kitchen_ids
:
str
holiday_ids
:
str
# Note: Without this setting, the response is sometimes full of `\n`
# for some models. This option prevents that.
guided_decoding_backend
=
'xgrammar:disable-any-whitespace'
schema
=
ResponseSchema
.
model_json_schema
()
guided_params
=
GuidedDecodingParams
(
json
=
schema
,
backend
=
\
guided_decoding_backend
)
sampling_params
=
SamplingParams
(
max_tokens
=
2000
,
frequency_penalty
=
0
,
presence_penalty
=-
1.1
,
repetition_penalty
=
1.3
,
guided_decoding
=
guided_params
)
prompt
=
(
"<|im_start|>system
\n
You are Qwen, created by Alibaba Cloud. You"
"are a helpful assistant.<|im_end|>
\n
<|im_start|>user
\n
I want a "
"quick launch fast with $10.<|im_end|>
\n
<|im_start|>assistant
\n
"
)
outputs
=
llm
.
generate
(
prompts
=
prompt
,
sampling_params
=
sampling_params
,
use_tqdm
=
True
)
assert
outputs
is
not
None
for
output
in
outputs
:
assert
output
is
not
None
assert
isinstance
(
output
,
RequestOutput
)
generated_text
=
output
.
outputs
[
0
].
text
assert
generated_text
is
not
None
assert
"
\n
"
not
in
generated_text
assert
"
\n
"
not
in
generated_text
# Parse to verify it is valid JSON
# Parse to verify it is valid JSON
parsed_json
=
json
.
loads
(
generated_text
)
parsed_json
=
json
.
loads
(
generated_text
)
assert
isinstance
(
parsed_json
,
dict
)
assert
isinstance
(
parsed_json
,
dict
)
jsonschema
.
validate
(
instance
=
parsed_json
,
schema
=
schema
)
tests/v1/entrypoints/llm/test_struct_output_generate.py
View file @
7329ff54
...
@@ -15,7 +15,9 @@ from vllm.entrypoints.llm import LLM
...
@@ -15,7 +15,9 @@ from vllm.entrypoints.llm import LLM
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
RequestOutput
from
vllm.sampling_params
import
GuidedDecodingParams
,
SamplingParams
from
vllm.sampling_params
import
GuidedDecodingParams
,
SamplingParams
GUIDED_DECODING_BACKENDS_V1
=
[
"xgrammar"
,
"guidance"
]
GUIDED_DECODING_BACKENDS_V1
=
[
"xgrammar:disable-any-whitespace"
,
"guidance:disable-any-whitespace"
]
MODELS_TO_TEST
=
[
MODELS_TO_TEST
=
[
"Qwen/Qwen2.5-1.5B-Instruct"
,
"mistralai/Ministral-8B-Instruct-2410"
"Qwen/Qwen2.5-1.5B-Instruct"
,
"mistralai/Ministral-8B-Instruct-2410"
]
]
...
@@ -55,49 +57,7 @@ def test_guided_json_completion(
...
@@ -55,49 +57,7 @@ def test_guided_json_completion(
generated_text
=
output
.
outputs
[
0
].
text
generated_text
=
output
.
outputs
[
0
].
text
assert
generated_text
is
not
None
assert
generated_text
is
not
None
print
(
f
"Prompt:
{
prompt
!
r
}
, Generated text:
{
generated_text
!
r
}
"
)
if
'disable-any-whitespace'
in
guided_decoding_backend
:
output_json
=
json
.
loads
(
generated_text
)
jsonschema
.
validate
(
instance
=
output_json
,
schema
=
sample_json_schema
)
@
pytest
.
mark
.
skip_global_cleanup
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
GUIDED_DECODING_BACKENDS_V1
)
@
pytest
.
mark
.
parametrize
(
"model_name"
,
MODELS_TO_TEST
)
def
test_guided_json_completion_disable_any_whitespace
(
monkeypatch
:
pytest
.
MonkeyPatch
,
sample_json_schema
:
dict
[
str
,
Any
],
guided_decoding_backend
:
str
,
model_name
:
str
,
):
if
guided_decoding_backend
!=
"xgrammar"
:
pytest
.
skip
(
"disable-any-whitespace is only supported for xgrammar."
)
guided_decoding_backend
=
'xgrammar:disable-any-whitespace'
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
llm
=
LLM
(
model
=
model_name
,
max_model_len
=
1024
,
guided_decoding_backend
=
guided_decoding_backend
)
sampling_params
=
SamplingParams
(
temperature
=
1.0
,
max_tokens
=
1000
,
guided_decoding
=
GuidedDecodingParams
(
json
=
sample_json_schema
))
outputs
=
llm
.
generate
(
prompts
=
[
f
"Give an example JSON for an employee profile "
f
"that fits this schema:
{
sample_json_schema
}
"
]
*
2
,
sampling_params
=
sampling_params
,
use_tqdm
=
True
)
assert
outputs
is
not
None
for
output
in
outputs
:
assert
output
is
not
None
assert
isinstance
(
output
,
RequestOutput
)
prompt
=
output
.
prompt
generated_text
=
output
.
outputs
[
0
].
text
assert
generated_text
is
not
None
assert
"
\n
"
not
in
generated_text
assert
"
\n
"
not
in
generated_text
print
(
f
"Prompt:
{
prompt
!
r
}
, Generated text:
{
generated_text
!
r
}
"
)
print
(
f
"Prompt:
{
prompt
!
r
}
, Generated text:
{
generated_text
!
r
}
"
)
output_json
=
json
.
loads
(
generated_text
)
output_json
=
json
.
loads
(
generated_text
)
...
@@ -142,7 +102,7 @@ def test_guided_json_object(
...
@@ -142,7 +102,7 @@ def test_guided_json_object(
# Parse to verify it is valid JSON
# Parse to verify it is valid JSON
parsed_json
=
json
.
loads
(
generated_text
)
parsed_json
=
json
.
loads
(
generated_text
)
allowed_types
:
tuple
[
type
,
...]
=
(
dict
,
)
allowed_types
:
tuple
[
type
,
...]
=
(
dict
,
)
if
guided_decoding_backend
==
"xgrammar"
:
if
guided_decoding_backend
.
startswith
(
"xgrammar"
)
:
# TODO - we are currently too permissive with xgrammar and
# TODO - we are currently too permissive with xgrammar and
# allow # any valid json (typically comes back as a list or
# allow # any valid json (typically comes back as a list or
# object). We can fix this by specifying a jsonschema of
# object). We can fix this by specifying a jsonschema of
...
@@ -170,7 +130,7 @@ def test_guided_json_unsupported_schema(
...
@@ -170,7 +130,7 @@ def test_guided_json_unsupported_schema(
temperature
=
1.0
,
temperature
=
1.0
,
max_tokens
=
1000
,
max_tokens
=
1000
,
guided_decoding
=
GuidedDecodingParams
(
json
=
unsupported_json_schema
))
guided_decoding
=
GuidedDecodingParams
(
json
=
unsupported_json_schema
))
if
guided_decoding_backend
==
"xgrammar"
:
if
guided_decoding_backend
.
startswith
(
"xgrammar"
)
:
with
pytest
.
raises
(
ValueError
,
with
pytest
.
raises
(
ValueError
,
match
=
"The provided JSON schema contains features "
match
=
"The provided JSON schema contains features "
"not supported by xgrammar."
):
"not supported by xgrammar."
):
...
...
vllm/engine/arg_utils.py
View file @
7329ff54
...
@@ -1561,7 +1561,8 @@ class EngineArgs:
...
@@ -1561,7 +1561,8 @@ class EngineArgs:
# Xgrammar and Guidance are supported.
# Xgrammar and Guidance are supported.
SUPPORTED_GUIDED_DECODING
=
[
SUPPORTED_GUIDED_DECODING
=
[
"xgrammar"
,
"xgrammar:disable-any-whitespace"
,
"guidance"
,
"auto"
"xgrammar"
,
"xgrammar:disable-any-whitespace"
,
"guidance"
,
"guidance:disable-any-whitespace"
,
"auto"
]
]
if
self
.
guided_decoding_backend
not
in
SUPPORTED_GUIDED_DECODING
:
if
self
.
guided_decoding_backend
not
in
SUPPORTED_GUIDED_DECODING
:
_raise_or_fallback
(
feature_name
=
"--guided-decoding-backend"
,
_raise_or_fallback
(
feature_name
=
"--guided-decoding-backend"
,
...
...
vllm/model_executor/guided_decoding/guidance_decoding.py
View file @
7329ff54
...
@@ -18,14 +18,22 @@ def get_local_guidance_guided_decoding_logits_processor(
...
@@ -18,14 +18,22 @@ def get_local_guidance_guided_decoding_logits_processor(
"""
"""
grm
=
""
grm
=
""
any_whitespace
=
'disable-any-whitespace'
not
in
\
guided_params
.
backend_options
()
if
guided_params
.
json
:
if
guided_params
.
json
:
grm
=
llguidance
.
LLMatcher
.
grammar_from_json_schema
(
grm
=
llguidance
.
LLMatcher
.
grammar_from_json_schema
(
guided_params
.
json
,
guided_params
.
json
,
overrides
=
{
"whitespace_pattern"
:
guided_params
.
whitespace_pattern
})
overrides
=
{
"whitespace_pattern"
:
guided_params
.
whitespace_pattern
},
defaults
=
{
"whitespace_flexible"
:
any_whitespace
,
})
elif
guided_params
.
json_object
:
elif
guided_params
.
json_object
:
grm
=
llguidance
.
LLMatcher
.
grammar_from_json_schema
(
grm
=
llguidance
.
LLMatcher
.
grammar_from_json_schema
(
'{"type": "object"}'
,
'{"type": "object"}'
,
overrides
=
{
"whitespace_pattern"
:
guided_params
.
whitespace_pattern
})
overrides
=
{
"whitespace_pattern"
:
guided_params
.
whitespace_pattern
},
defaults
=
{
"whitespace_flexible"
:
any_whitespace
,
})
elif
guided_params
.
regex
:
elif
guided_params
.
regex
:
grm
=
llguidance
.
grammar_from
(
"regex"
,
guided_params
.
regex
)
grm
=
llguidance
.
grammar_from
(
"regex"
,
guided_params
.
regex
)
elif
guided_params
.
choice
:
elif
guided_params
.
choice
:
...
...
vllm/v1/engine/processor.py
View file @
7329ff54
...
@@ -121,7 +121,8 @@ class Processor:
...
@@ -121,7 +121,8 @@ class Processor:
return
return
supported_backends
=
[
supported_backends
=
[
"xgrammar"
,
"xgrammar:disable-any-whitespace"
,
"guidance"
,
"auto"
"xgrammar"
,
"xgrammar:disable-any-whitespace"
,
"guidance"
,
"guidance:disable-any-whitespace"
,
"auto"
]
]
engine_level_backend
=
self
.
decoding_config
.
guided_decoding_backend
engine_level_backend
=
self
.
decoding_config
.
guided_decoding_backend
if
engine_level_backend
not
in
supported_backends
:
if
engine_level_backend
not
in
supported_backends
:
...
@@ -140,11 +141,10 @@ class Processor:
...
@@ -140,11 +141,10 @@ class Processor:
raise
ValueError
(
"Structured output is not supported on TPU."
)
raise
ValueError
(
"Structured output is not supported on TPU."
)
# Request content validation
# Request content validation
if
engine_level_backend
.
startswith
(
"xgrammar"
):
if
engine_level_backend
==
"xgrammar"
:
# xgrammar with no fallback
# xgrammar with no fallback
validate_structured_output_request_xgrammar
(
params
)
validate_structured_output_request_xgrammar
(
params
)
params
.
guided_decoding
.
backend
=
"xgrammar"
params
.
guided_decoding
.
backend
=
engine_level_backend
elif
engine_level_backend
==
"auto"
:
elif
engine_level_backend
==
"auto"
:
# "auto" is an opt-in to opinionated behavior where we try to
# "auto" is an opt-in to opinionated behavior where we try to
# choose a backend based on request contents. This is not the
# choose a backend based on request contents. This is not the
...
@@ -158,12 +158,13 @@ class Processor:
...
@@ -158,12 +158,13 @@ class Processor:
# are not supported in xgrammar. Fall back to guidance.
# are not supported in xgrammar. Fall back to guidance.
params
.
guided_decoding
.
backend
=
"guidance"
params
.
guided_decoding
.
backend
=
"guidance"
if
params
.
guided_decoding
.
backend
==
"guidance"
:
if
engine_level_backend
.
startswith
(
"guidance"
)
:
# TODO ideally we would have the LLTokenizer here as Lark syntax
# TODO ideally we would have the LLTokenizer here as Lark syntax
# allows <|special_token|> and similar, see
# allows <|special_token|> and similar, see
# https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
# https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
# Without tokenizer these are disallowed in grammars.
# Without tokenizer these are disallowed in grammars.
validate_guidance_grammar
(
params
,
tokenizer
=
None
)
validate_guidance_grammar
(
params
,
tokenizer
=
None
)
params
.
guided_decoding
.
backend
=
engine_level_backend
def
process_inputs
(
def
process_inputs
(
self
,
self
,
...
...
vllm/v1/structured_output/backend_guidance.py
View file @
7329ff54
...
@@ -41,6 +41,9 @@ class GuidanceBackend(StructuredOutputBackend):
...
@@ -41,6 +41,9 @@ class GuidanceBackend(StructuredOutputBackend):
tokenizer_group
.
ping
()
tokenizer_group
.
ping
()
self
.
vllm_config
=
vllm_config
self
.
vllm_config
=
vllm_config
self
.
vocab_size
=
vllm_config
.
model_config
.
get_vocab_size
()
self
.
vocab_size
=
vllm_config
.
model_config
.
get_vocab_size
()
self
.
disable_any_whitespace
=
(
"disable-any-whitespace"
in
vllm_config
.
decoding_config
.
guided_decoding_backend
)
tokenizer
=
tokenizer_group
.
get_lora_tokenizer
(
None
)
tokenizer
=
tokenizer_group
.
get_lora_tokenizer
(
None
)
self
.
ll_tokenizer
=
llguidance_hf
.
from_tokenizer
(
tokenizer
,
None
)
self
.
ll_tokenizer
=
llguidance_hf
.
from_tokenizer
(
tokenizer
,
None
)
...
@@ -48,7 +51,7 @@ class GuidanceBackend(StructuredOutputBackend):
...
@@ -48,7 +51,7 @@ class GuidanceBackend(StructuredOutputBackend):
def
compile_grammar
(
self
,
request_type
:
StructuredOutputOptions
,
def
compile_grammar
(
self
,
request_type
:
StructuredOutputOptions
,
grammar_spec
:
str
)
->
StructuredOutputGrammar
:
grammar_spec
:
str
)
->
StructuredOutputGrammar
:
self
.
serialized_grammar
=
serialize_guidance_grammar
(
self
.
serialized_grammar
=
serialize_guidance_grammar
(
request_type
,
grammar_spec
)
request_type
,
grammar_spec
,
self
.
disable_any_whitespace
)
ll_matcher
=
llguidance
.
LLMatcher
(
ll_matcher
=
llguidance
.
LLMatcher
(
self
.
ll_tokenizer
,
self
.
ll_tokenizer
,
...
@@ -126,17 +129,19 @@ class GuidanceGrammar(StructuredOutputGrammar):
...
@@ -126,17 +129,19 @@ class GuidanceGrammar(StructuredOutputGrammar):
def
serialize_guidance_grammar
(
request_type
:
StructuredOutputOptions
,
def
serialize_guidance_grammar
(
request_type
:
StructuredOutputOptions
,
grammar_spec
:
str
)
->
str
:
grammar_spec
:
str
,
disable_any_whitespace
:
bool
=
False
)
->
str
:
if
request_type
==
StructuredOutputOptions
.
JSON
:
if
request_type
==
StructuredOutputOptions
.
JSON
:
# TODO: make whitespace_flexible configurable
return
llguidance
.
LLMatcher
.
grammar_from_json_schema
(
return
llguidance
.
LLMatcher
.
grammar_from_json_schema
(
grammar_spec
,
defaults
=
{
grammar_spec
,
"whitespace_flexible"
:
True
,
defaults
=
{
"whitespace_flexible"
:
not
disable_any_whitespace
,
})
})
elif
request_type
==
StructuredOutputOptions
.
JSON_OBJECT
:
elif
request_type
==
StructuredOutputOptions
.
JSON_OBJECT
:
return
llguidance
.
LLMatcher
.
grammar_from_json_schema
(
return
llguidance
.
LLMatcher
.
grammar_from_json_schema
(
'{"type": "object"}'
,
defaults
=
{
'{"type": "object"}'
,
"whitespace_flexible"
:
True
,
defaults
=
{
"whitespace_flexible"
:
not
disable_any_whitespace
,
})
})
else
:
else
:
if
request_type
==
StructuredOutputOptions
.
REGEX
:
if
request_type
==
StructuredOutputOptions
.
REGEX
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment