Unverified Commit 24f1c01e authored by leon-seidel's avatar leon-seidel Committed by GitHub
Browse files

[Bugfix][V0] XGrammar structured output supports Enum (#15878)


Signed-off-by: default avatarLeon Seidel <leon.seidel@fau.de>
parent fad6e253
...@@ -3,9 +3,11 @@ ...@@ -3,9 +3,11 @@
import json import json
import re import re
import weakref import weakref
from enum import Enum
import jsonschema import jsonschema
import pytest import pytest
from pydantic import BaseModel
from vllm.distributed import cleanup_dist_env_and_memory from vllm.distributed import cleanup_dist_env_and_memory
from vllm.entrypoints.llm import LLM from vllm.entrypoints.llm import LLM
...@@ -330,3 +332,44 @@ def test_guided_json_object(llm, guided_decoding_backend: str): ...@@ -330,3 +332,44 @@ def test_guided_json_object(llm, guided_decoding_backend: str):
# Parse to verify it is valid JSON # Parse to verify it is valid JSON
parsed_json = json.loads(generated_text) parsed_json = json.loads(generated_text)
assert isinstance(parsed_json, dict) assert isinstance(parsed_json, dict)
class CarType(str, Enum):
sedan = "sedan"
suv = "SUV"
truck = "Truck"
coupe = "Coupe"
class CarDescription(BaseModel):
brand: str
model: str
car_type: CarType
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str):
json_schema = CarDescription.model_json_schema()
sampling_params = SamplingParams(temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
json=json_schema,
backend=guided_decoding_backend))
outputs = llm.generate(
prompts="Generate a JSON with the brand, model and car_type of"
"the most iconic car from the 90's",
sampling_params=sampling_params,
use_tqdm=True)
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt
generated_text = output.outputs[0].text
assert generated_text is not None
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
output_json = json.loads(generated_text)
jsonschema.validate(instance=output_json, schema=json_schema)
\ No newline at end of file
...@@ -14,10 +14,6 @@ def has_xgrammar_unsupported_json_features(schema: dict) -> bool: ...@@ -14,10 +14,6 @@ def has_xgrammar_unsupported_json_features(schema: dict) -> bool:
if "pattern" in obj: if "pattern" in obj:
return True return True
# Check for enum restrictions
if "enum" in obj:
return True
# Check for numeric ranges # Check for numeric ranges
if obj.get("type") in ("integer", "number") and any( if obj.get("type") in ("integer", "number") and any(
key in obj for key in [ key in obj for key in [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment