# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json
import os
import weakref
from enum import Enum

import jsonschema
import pytest
import regex as re
from pydantic import BaseModel

from vllm.distributed import cleanup_dist_env_and_memory
from vllm.entrypoints.llm import LLM
from vllm.outputs import RequestOutput
from vllm.sampling_params import GuidedDecodingParams, SamplingParams

from ...utils import models_path_prefix

MODEL_NAME = os.path.join(models_path_prefix, "Qwen2.5-1.5B-Instruct")

GUIDED_DECODING_BACKENDS = [
    # (backend, disable_any_whitespace),
    ("outlines", False),
    ("lm-format-enforcer", False),
    ("xgrammar", True),
    ("guidance", True),
]


@pytest.fixture(scope="module")
def llm():
    # pytest caches the fixture so we use weakref.proxy to
    # enable garbage collection
    llm = LLM(model=MODEL_NAME, max_model_len=1024, seed=0)

    with llm.deprecate_legacy_api():
        yield weakref.proxy(llm)
        del llm
    cleanup_dist_env_and_memory()


def _first_completions(outputs: list[RequestOutput]) -> list[tuple]:
    """Check basic invariants of ``outputs`` and extract first completions.

    Returns one ``(prompt, generated_text)`` pair per request, taken from the
    first completion of each ``RequestOutput``. Fails when ``outputs`` is
    empty so that a test cannot pass vacuously.
    """
    assert outputs, "LLM.generate returned no outputs"
    completions = []
    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        generated_text = output.outputs[0].text
        assert generated_text is not None
        completions.append((output.prompt, generated_text))
    return completions


def _validate_json_outputs(outputs: list[RequestOutput], schema: dict) -> list:
    """Parse every first completion as JSON and validate it against ``schema``.

    Returns the parsed JSON values (one per request) so callers can add
    test-specific assertions on them.
    """
    parsed = []
    for prompt, generated_text in _first_completions(outputs):
        # Print before parsing so the raw text is visible if parsing fails.
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
        output_json = json.loads(generated_text)
        jsonschema.validate(instance=output_json, schema=schema)
        parsed.append(output_json)
    return parsed


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
                         GUIDED_DECODING_BACKENDS)
def test_guided_regex(sample_regex, llm, guided_decoding_backend: str,
                      disable_any_whitespace: bool):
    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
        guided_decoding=GuidedDecodingParams(
            regex=sample_regex,
            backend=guided_decoding_backend,
            disable_any_whitespace=disable_any_whitespace))
    outputs = llm.generate(prompts=[
        f"Give an example IPv4 address with this regex: {sample_regex}"
    ] * 2,
                           sampling_params=sampling_params,
                           use_tqdm=True)

    for prompt, generated_text in _first_completions(outputs):
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
        assert re.fullmatch(sample_regex, generated_text) is not None


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
                         GUIDED_DECODING_BACKENDS)
def test_guided_json_completion(sample_json_schema, llm,
                                guided_decoding_backend: str,
                                disable_any_whitespace: bool):
    sampling_params = SamplingParams(
        temperature=1.0,
        max_tokens=1000,
        guided_decoding=GuidedDecodingParams(
            json=sample_json_schema,
            backend=guided_decoding_backend,
            disable_any_whitespace=disable_any_whitespace))
    outputs = llm.generate(prompts=[
        f"Give an example JSON for an employee profile "
        f"that fits this schema: {sample_json_schema}"
    ] * 2,
                           sampling_params=sampling_params,
                           use_tqdm=True)

    _validate_json_outputs(outputs, sample_json_schema)


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
                         GUIDED_DECODING_BACKENDS)
def test_guided_complex_json_completion(sample_complex_json_schema, llm,
                                        guided_decoding_backend: str,
                                        disable_any_whitespace: bool):
    sampling_params = SamplingParams(
        temperature=1.0,
        max_tokens=1000,
        guided_decoding=GuidedDecodingParams(
            json=sample_complex_json_schema,
            backend=guided_decoding_backend,
            disable_any_whitespace=disable_any_whitespace))
    outputs = llm.generate(prompts=[
        f"Give an example JSON for an assignment grade "
        f"that fits this schema: {sample_complex_json_schema}"
    ] * 2,
                           sampling_params=sampling_params,
                           use_tqdm=True)

    _validate_json_outputs(outputs, sample_complex_json_schema)


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
                         GUIDED_DECODING_BACKENDS)
def test_guided_definition_json_completion(sample_definition_json_schema, llm,
                                           guided_decoding_backend: str,
                                           disable_any_whitespace: bool):
    sampling_params = SamplingParams(
        temperature=1.0,
        max_tokens=1000,
        guided_decoding=GuidedDecodingParams(
            json=sample_definition_json_schema,
            backend=guided_decoding_backend,
            disable_any_whitespace=disable_any_whitespace))
    outputs = llm.generate(prompts=[
        f"Give an example JSON for solving 8x + 7 = -23 "
        f"that fits this schema: {sample_definition_json_schema}"
    ] * 2,
                           sampling_params=sampling_params,
                           use_tqdm=True)

    _validate_json_outputs(outputs, sample_definition_json_schema)


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
                         GUIDED_DECODING_BACKENDS)
def test_guided_enum_json_completion(sample_enum_json_schema, llm,
                                     guided_decoding_backend: str,
                                     disable_any_whitespace: bool):
    sampling_params = SamplingParams(
        temperature=1.0,
        max_tokens=1000,
        guided_decoding=GuidedDecodingParams(
            json=sample_enum_json_schema,
            backend=guided_decoding_backend,
            disable_any_whitespace=disable_any_whitespace))
    outputs = llm.generate(prompts=[
        "Create a bug report JSON that fits this schema: "
        f"{sample_enum_json_schema}. Make it for a high priority critical bug."
    ] * 2,
                           sampling_params=sampling_params,
                           use_tqdm=True)

    for output_json in _validate_json_outputs(outputs,
                                              sample_enum_json_schema):
        # Additional assertions to verify enum values
        assert output_json["status"] in ["active", "inactive", "pending"]
        assert output_json["priority"] in ["low", "medium", "high", "critical"]
        assert output_json["category"]["type"] in [
            "bug", "feature", "improvement"
        ]
        assert output_json["category"]["severity"] in [1, 2, 3, 4, 5]
        for flag in output_json["flags"]:
            assert flag in ["urgent", "blocked", "needs_review", "approved"]


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
                         GUIDED_DECODING_BACKENDS)
def test_guided_choice_completion(sample_guided_choice, llm,
                                  guided_decoding_backend: str,
                                  disable_any_whitespace: bool):
    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
        guided_decoding=GuidedDecodingParams(
            choice=sample_guided_choice,
            backend=guided_decoding_backend,
            disable_any_whitespace=disable_any_whitespace))
    outputs = llm.generate(
        prompts="The best language for type-safe systems programming is ",
        sampling_params=sampling_params,
        use_tqdm=True)

    for prompt, generated_text in _first_completions(outputs):
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
        assert generated_text in sample_guided_choice


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
                         GUIDED_DECODING_BACKENDS)
def test_guided_grammar(sample_sql_statements, llm,
                        guided_decoding_backend: str,
                        disable_any_whitespace: bool):
    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
        max_tokens=1000,
        guided_decoding=GuidedDecodingParams(
            grammar=sample_sql_statements,
            backend=guided_decoding_backend,
            disable_any_whitespace=disable_any_whitespace))
    outputs = llm.generate(
        prompts=("Generate a sql state that select col_1 from "
                 "table_1 where it is equals to 1"),
        sampling_params=sampling_params,
        use_tqdm=True,
    )

    # use Lark to parse the output, and make sure it's a valid parse tree;
    # the parser only depends on the grammar, so build it once.
    from lark import Lark
    parser = Lark(sample_sql_statements)
    # remove spaces for comparison b/c we removed them in the grammar
    ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
        " ", "")

    for prompt, generated_text in _first_completions(outputs):
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
        parser.parse(generated_text)
        assert generated_text.strip() == ground_truth


@pytest.mark.skip_global_cleanup
def test_guided_options_request_deprecation_warning(sample_regex, llm):
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

    with pytest.warns(DeprecationWarning, match="guided_options_request"):
        llm.generate(prompts="This should fail",
                     sampling_params=sampling_params,
                     use_tqdm=True,
                     guided_options_request=dict(guided_regex=sample_regex))


@pytest.mark.skip_global_cleanup
def test_validation_against_both_guided_decoding_options(sample_regex, llm):
    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
        guided_decoding=GuidedDecodingParams(regex=sample_regex))

    with pytest.raises(ValueError, match="Cannot set both"):
        llm.generate(prompts="This should fail",
                     sampling_params=sampling_params,
                     use_tqdm=True,
                     guided_options_request=dict(guided_regex=sample_regex))


@pytest.mark.skip_global_cleanup
def test_disable_guided_decoding_fallback(sample_regex, llm):
    # see has_xgrammar_unsupported_json_features()
    unsupported_json = {
        "type": "object",
        "properties": {
            "example": {
                "type": "string",
                "minLength": 5  # unsupported by xgrammar
            }
        }
    }
    sampling_params = SamplingParams(temperature=0.8,
                                     top_p=0.95,
                                     guided_decoding=GuidedDecodingParams(
                                         json=unsupported_json,
                                         backend="xgrammar",
                                         disable_fallback=True))

    with pytest.raises(
            ValueError,
            match="xgrammar does not support advanced JSON schema features "
            "like string length, item limits, or property bounds."):
        llm.generate(prompts="This should fail",
                     sampling_params=sampling_params,
                     use_tqdm=True)


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
                         GUIDED_DECODING_BACKENDS)
def test_guided_json_object(llm, guided_decoding_backend: str,
                            disable_any_whitespace: bool):
    sampling_params = SamplingParams(
        temperature=1.0,
        max_tokens=100,
        n=2,
        guided_decoding=GuidedDecodingParams(
            json_object=True,
            backend=guided_decoding_backend,
            disable_any_whitespace=disable_any_whitespace))

    outputs = llm.generate(
        prompts=("Generate a JSON object with curly braces for a person with "
                 "name and age fields for John Smith who is 31 years old."),
        sampling_params=sampling_params,
        use_tqdm=True)

    assert outputs, "LLM.generate returned no outputs"
    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        # Every one of the n sampled completions must be checked.
        assert len(output.outputs) == sampling_params.n

        for completion in output.outputs:
            generated_text = completion.text
            print(generated_text)
            assert generated_text is not None

            if disable_any_whitespace:
                assert "\n" not in generated_text

            # Parse to verify it is valid JSON
            parsed_json = json.loads(generated_text)
            assert isinstance(parsed_json, dict)


class CarType(str, Enum):
    sedan = "sedan"
    suv = "SUV"
    truck = "Truck"
    coupe = "Coupe"


class CarDescription(BaseModel):
    brand: str
    model: str
    car_type: CarType


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
                         GUIDED_DECODING_BACKENDS)
def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str,
                                          disable_any_whitespace: bool):
    json_schema = CarDescription.model_json_schema()
    sampling_params = SamplingParams(
        temperature=1.0,
        max_tokens=1000,
        guided_decoding=GuidedDecodingParams(
            json=json_schema,
            backend=guided_decoding_backend,
            disable_any_whitespace=disable_any_whitespace))
    # NOTE: the trailing space after "of" is required; the two literals are
    # concatenated implicitly.
    outputs = llm.generate(
        prompts="Generate a JSON with the brand, model and car_type of "
        "the most iconic car from the 90's",
        sampling_params=sampling_params,
        use_tqdm=True)

    _validate_json_outputs(outputs, json_schema)


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace",
                         GUIDED_DECODING_BACKENDS)
def test_guided_number_range_json_completion(llm, guided_decoding_backend: str,
                                             disable_any_whitespace: bool):
    sample_output_schema = {
        "type": "object",
        "properties": {
            "age": {
                "type": "integer",
                "minimum": 18,
                "maximum": 99
            },
            "score": {
                "type": "number",
                "minimum": 0.0,
                "maximum": 100.0
            },
            "zipcode": {
                "type": "string",
                "pattern": r"^\d{5}(-\d{4})?$"
            },
        },
        "required": ["age", "score", "zipcode"],
    }
    sampling_params = SamplingParams(
        temperature=1.0,
        max_tokens=1000,
        guided_decoding=GuidedDecodingParams(
            json=sample_output_schema,
            backend=guided_decoding_backend,
            disable_any_whitespace=disable_any_whitespace),
    )
    outputs = llm.generate(
        prompts=[
            "Create a JSON object for a user with age, score, and zipcode."
        ] * 2,
        sampling_params=sampling_params,
        use_tqdm=True,
    )

    for output_json in _validate_json_outputs(outputs, sample_output_schema):
        assert 18 <= output_json["age"] <= 99
        assert 0.0 <= output_json["score"] <= 100.0
        assert (re.fullmatch(r"^\d{5}(-\d{4})?$", output_json["zipcode"])
                is not None)


@pytest.mark.skip_global_cleanup
def test_guidance_no_additional_properties(llm):
    schema = {
        'type': 'object',
        'properties': {
            'a1': {
                'type': 'string'
            },
            'a2': {
                'type': 'string'
            },
            'a3': {
                'type': 'string'
            }
        },
        'required': ['a1', 'a2', 'a3'],
    }

    prompt = (
        "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a "
        "helpful assistant.<|im_end|>\n<|im_start|>user\nPlease generate a "
        "large JSON object with key-value pairs a1=b1, a2=b2, ..., a20=b20"
        "<|im_end|>\n<|im_start|>assistant\n")

    def generate_with_backend(backend, disable_additional_properties):
        """Greedily generate one JSON object with ``backend`` and validate it.

        Returns the parsed JSON dict.
        """
        guided_params = GuidedDecodingParams(
            json=schema,
            backend=backend,
            disable_any_whitespace=True,
            disable_additional_properties=disable_additional_properties)
        sampling_params = SamplingParams(temperature=0,
                                         max_tokens=256,
                                         guided_decoding=guided_params)

        outputs = llm.generate(prompts=prompt, sampling_params=sampling_params)
        assert outputs is not None
        generated_text = outputs[0].outputs[0].text
        assert generated_text is not None
        parsed_json = json.loads(generated_text)
        assert isinstance(parsed_json, dict)
        jsonschema.validate(instance=parsed_json, schema=schema)
        return parsed_json

    base_generated = generate_with_backend("guidance", False)
    assert "a1" in base_generated
    assert "a2" in base_generated
    assert "a3" in base_generated
    # by default additional keys are generated
    assert "a4" in base_generated
    assert "a5" in base_generated
    assert "a6" in base_generated

    generated = generate_with_backend("guidance", True)
    assert "a1" in generated
    assert "a2" in generated
    assert "a3" in generated
    assert "a4" not in generated
    assert "a5" not in generated
    assert "a6" not in generated