Unverified Commit 24f1c01e authored by leon-seidel's avatar leon-seidel Committed by GitHub
Browse files

[Bugfix][V0] XGrammar structured output supports Enum (#15878)


Signed-off-by: default avatarLeon Seidel <leon.seidel@fau.de>
parent fad6e253
......@@ -3,9 +3,11 @@
import json
import re
import weakref
from enum import Enum
import jsonschema
import pytest
from pydantic import BaseModel
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.entrypoints.llm import LLM
......@@ -330,3 +332,44 @@ def test_guided_json_object(llm, guided_decoding_backend: str):
# Parse to verify it is valid JSON
parsed_json = json.loads(generated_text)
assert isinstance(parsed_json, dict)
class CarType(str, Enum):
sedan = "sedan"
suv = "SUV"
truck = "Truck"
coupe = "Coupe"
class CarDescription(BaseModel):
brand: str
model: str
car_type: CarType
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str):
json_schema = CarDescription.model_json_schema()
sampling_params = SamplingParams(temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
json=json_schema,
backend=guided_decoding_backend))
outputs = llm.generate(
prompts="Generate a JSON with the brand, model and car_type of"
"the most iconic car from the 90's",
sampling_params=sampling_params,
use_tqdm=True)
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt
generated_text = output.outputs[0].text
assert generated_text is not None
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
output_json = json.loads(generated_text)
jsonschema.validate(instance=output_json, schema=json_schema)
\ No newline at end of file
......@@ -14,10 +14,6 @@ def has_xgrammar_unsupported_json_features(schema: dict) -> bool:
if "pattern" in obj:
return True
# Check for enum restrictions
if "enum" in obj:
return True
# Check for numeric ranges
if obj.get("type") in ("integer", "number") and any(
key in obj for key in [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment