Unverified Commit 3b00ff91 authored by Chauncey's avatar Chauncey Committed by GitHub
Browse files

[Bugfix][v1] xgrammar structured output supports Enum. (#15594)


Signed-off-by: default avatarchaunceyjiang <chaunceyjiang@gmail.com>
parent 91276c57
......@@ -4,10 +4,12 @@ from __future__ import annotations
import json
import re
from enum import Enum
from typing import Any
import jsonschema
import pytest
from pydantic import BaseModel
from vllm.entrypoints.llm import LLM
from vllm.outputs import RequestOutput
......@@ -390,3 +392,54 @@ def test_guided_choice_completion(
assert generated_text is not None
assert generated_text in sample_guided_choice
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
class CarType(str, Enum):
sedan = "sedan"
suv = "SUV"
truck = "Truck"
coupe = "Coupe"
class CarDescription(BaseModel):
brand: str
model: str
car_type: CarType
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
GUIDED_DECODING_BACKENDS_V1)
@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
def test_guided_json_completion_with_enum(
monkeypatch: pytest.MonkeyPatch,
guided_decoding_backend: str,
model_name: str,
):
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=model_name,
max_model_len=1024,
guided_decoding_backend=guided_decoding_backend)
json_schema = CarDescription.model_json_schema()
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(json=json_schema))
outputs = llm.generate(
prompts="Generate a JSON with the brand, model and car_type of"
"the most iconic car from the 90's",
sampling_params=sampling_params,
use_tqdm=True)
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt
generated_text = output.outputs[0].text
assert generated_text is not None
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
output_json = json.loads(generated_text)
jsonschema.validate(instance=output_json, schema=json_schema)
......@@ -26,10 +26,6 @@ def has_xgrammar_unsupported_json_features(schema: dict[str, Any]) -> bool:
if "pattern" in obj:
return True
# Check for enum restrictions
if "enum" in obj:
return True
# Check for numeric ranges
if obj.get("type") in ("integer", "number") and any(
key in obj
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment