test_struct_output_generate.py 33 KB
Newer Older
1
# ruff: noqa: E501
2
# SPDX-License-Identifier: Apache-2.0
3
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
4
5

import json
6
from enum import Enum
7
from typing import Any
8
9
10

import jsonschema
import pytest
11
import regex as re
12
import torch
13
from pydantic import BaseModel
14

15
from tests.reasoning.utils import run_reasoning_extraction
16
from vllm.config import StructuredOutputsConfig
17
from vllm.distributed import cleanup_dist_env_and_memory
18
19
from vllm.entrypoints.llm import LLM
from vllm.outputs import RequestOutput
20
from vllm.platforms import current_platform
21
from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager
22
23
24
25
from vllm.sampling_params import (
    SamplingParams,
    StructuredOutputsParams,
)
26

27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
# Dotted-quad pattern: three "<octet>." groups followed by a final octet,
# where every octet is limited to 0-255 by the alternation below.
SAMPLE_REGEX = r"({octet}\.){{3}}{octet}".format(
    octet=r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)"
)

# Employee-profile schema used by the JSON generation scenarios.
# Note: Ensure this only uses attributes compatible with xgrammar
SAMPLE_JSON_SCHEMA = {
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "age": {"type": "integer"},
        "skills": {"type": "array", "items": {"type": "string"}},
        # Regex pattern: a single letter from A to D.
        "grade": {"type": "string", "pattern": "^[A-D]$"},
        "email": {
            "type": "string",
            "pattern": r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$",
        },
        "work_history": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "company": {"type": "string"},
                    # Numeric range; bounds stay floats so the schema's
                    # string form (embedded in prompts) is unchanged.
                    "duration": {
                        "type": "number",
                        "minimum": 0.0,
                        "maximum": 100.0,
                    },
                    "position": {"type": "string"},
                },
                "required": ["company", "duration", "position"],
                "additionalProperties": False,
            },
            "minItems": 0,
            "maxItems": 3,
        },
    },
    "required": ["name", "age", "skills", "grade", "email", "work_history"],
    "additionalProperties": False,
    "minProperties": 1,
    "maxProperties": 10,
}

# A schema unsupported by xgrammar
UNSUPPORTED_JSON_SCHEMA = {
    "type": "object",
    "properties": {
        # Numeric multiple
        "score": {"type": "integer", "multipleOf": 5},
        "tags": {
            "type": "array",
            "items": {
                "type": "string",
                "minLength": 10,
                "maxLength": 20,
            },
        },
    },
    "required": ["score", "tags"],
    "additionalProperties": False,
    "patternProperties": {"^score$": {"type": "integer"}},
}

# Closed set of answers for the `choice` structured-output scenario.
SAMPLE_STRUCTURED_OUTPUTS_CHOICES = (
    "Python Java JavaScript C++ C# PHP TypeScript Ruby Swift Kotlin".split()
)

SAMPLE_SQL_EBNF = """
root ::= select_statement
select_statement ::= "SELECT" column "from" table "where" condition
column ::= "col_1" | "col_2"
table ::= "table_1" | "table_2"
condition ::= column "=" number
number ::= "1" | "2"
"""

SAMPLE_SQL_LARK = """
start: select_statement
select_statement: "SELECT" column "from" table "where" condition
column: "col_1" | "col_2"
table: "table_1" | "table_2"
condition: column "=" number
number: "1" | "2"
"""

129
130
131
132
133
134
135
136
137
138
139
140
141
# Speculative-decoding settings for the "[ngram]" model entry, passed to
# LLM(speculative_config=...) by the parametrized cases.
NGRAM_SPEC_CONFIG = dict(
    model="[ngram]",
    num_speculative_tokens=5,
    prompt_lookup_max=5,
    prompt_lookup_min=1,
)

# Speculative-decoding settings using the "eagle" method with a separate
# draft model.
EAGLE_SPEC_CONFIG = dict(
    method="eagle",
    model="yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
    num_speculative_tokens=5,
)

142
# Parametrization matrix for `test_structured_output`. Each entry is
# (model_name, backend, tokenizer_mode, speculative_config); a
# speculative_config of None runs the case without speculative decoding.
PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [
    ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "auto", None),
    # FIXME: Since "auto" will use Mistral tokenizer and these backends do not support
    # it, we skip these tests for now.
    # ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None),
    # ("mistralai/Ministral-8B-Instruct-2410", "lm-format-enforcer", "auto", None),
    ("mistralai/Ministral-8B-Instruct-2410", "guidance", "hf", None),
    pytest.param(
        "mistralai/Ministral-8B-Instruct-2410",
        "lm-format-enforcer",
        "hf",
        None,
        marks=pytest.mark.skip(
            # Adjacent literals are concatenated verbatim, so each fragment
            # carries its own trailing space.
            reason=(
                "Flaky: lm-format-enforcer intermittently returns "
                "incomplete JSON. "
                "See https://github.com/noamgat/lm-format-enforcer/issues/169"
            )
        ),
    ),
    ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral", None),
    ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None),
    pytest.param(
        "Qwen/Qwen2.5-1.5B-Instruct",
        "lm-format-enforcer",
        "auto",
        None,
        marks=pytest.mark.skip(
            reason=(
                "Flaky: lm-format-enforcer intermittently returns "
                "incomplete JSON. "
                "See https://github.com/noamgat/lm-format-enforcer/issues/169"
            )
        ),
    ),
    # FIXME: These tests are flaky on CI and thus disabled. Tracking in Issue #24402
    # ("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto", None),
    # ("mistralai/Ministral-8B-Instruct-2410", "outlines", "mistral", None),
    # ("Qwen/Qwen2.5-1.5B-Instruct", "guidance", "auto", None),
    ("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto", NGRAM_SPEC_CONFIG),
    ("mistralai/Ministral-8B-Instruct-2410", "guidance", "hf", NGRAM_SPEC_CONFIG),
    ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", NGRAM_SPEC_CONFIG),
    ("meta-llama/Meta-Llama-3.1-8B-Instruct", "xgrammar", "auto", EAGLE_SPEC_CONFIG),
]
186
187
188
189

# Parametrization matrix for `test_structured_output_auto_mode`. Each entry
# is (model_name, tokenizer_mode).
PARAMS_MODELS_TOKENIZER_MODE = [
    ("mistralai/Ministral-8B-Instruct-2410", "auto"),
    ("Qwen/Qwen2.5-1.5B-Instruct", "auto"),
]
191

192
193
194
195
# Extra keyword arguments splatted into LLM(...) by test_structured_output:
# on ROCm async scheduling is switched off, elsewhere nothing is overridden.
platform_args = {"async_scheduling": False} if current_platform.is_rocm() else {}

196

197
198
199
200
201
202
203
204
205
206
207
208
209
# String-valued enumeration of car body types, used as the type of
# `CarDescription.car_type`. Member values are the exact strings shown
# (note that "sedan" is lower-case while the others are capitalised).
# Kept as comments rather than a docstring — NOTE(review): a docstring may
# be picked up as the schema `description` by pydantic; confirm before adding.
class CarType(str, Enum):
    sedan = "sedan"
    suv = "SUV"
    truck = "Truck"
    coupe = "Coupe"


# Pydantic model whose JSON schema (via `model_json_schema()`) drives the
# "Pydantic model with an enum" scenario in test_structured_output.
# Kept as comments rather than a docstring — NOTE(review): pydantic may copy
# a class docstring into the generated schema's `description`, which would
# change the schema sent to the backend; confirm before adding one.
class CarDescription(BaseModel):
    brand: str
    model: str
    # Restricted to the CarType enum values.
    car_type: CarType


210
@pytest.mark.parametrize(
    "model_name, backend, tokenizer_mode, speculative_config",
    PARAMS_MODELS_BACKENDS_TOKENIZER_MODE,
)
def test_structured_output(
    backend: str,
    tokenizer_mode: str,
    model_name: str,
    speculative_config: dict[str, Any] | None,
):
    """Run the structured-output scenarios against one shared LLM instance.

    Scenarios (numbered in the body): JSON with a schema, schema-less JSON
    object, a schema xgrammar rejects, EBNF and Lark grammars, an invalid
    grammar, regex, choice, a Pydantic-derived schema, minLength/maxLength,
    and structural tags. Some scenarios are skipped for specific backends.

    Args:
        backend: Structured-output backend name passed to the engine config.
        tokenizer_mode: Tokenizer mode passed to ``LLM``.
        model_name: Model identifier to load.
        speculative_config: Speculative-decoding settings, or ``None`` to
            run without speculative decoding.
    """
    sample_json_schema = SAMPLE_JSON_SCHEMA
    unsupported_json_schema = UNSUPPORTED_JSON_SCHEMA
    sample_sql_ebnf = SAMPLE_SQL_EBNF
    sample_sql_lark = SAMPLE_SQL_LARK
    sample_regex = SAMPLE_REGEX
    sample_structured_outputs_choices = SAMPLE_STRUCTURED_OUTPUTS_CHOICES
    if current_platform.is_tpu() and speculative_config:
        pytest.skip("TPU does not support speculative decoding")

    # Use a single LLM instance for several scenarios to
    # speed up the test suite.
    llm = LLM(
        model=model_name,
        enforce_eager=True,
        max_model_len=1024,
        structured_outputs_config=dict(
            backend=backend, disable_any_whitespace=backend in {"xgrammar", "guidance"}
        ),
        seed=120,
        tokenizer_mode=tokenizer_mode,
        load_format="auto" if not model_name.startswith("mistralai/") else "hf",
        config_format="auto" if not model_name.startswith("mistralai/") else "hf",
        speculative_config=speculative_config,
        **platform_args,
    )

    #
    # Test 1: Generate JSON output based on a provided schema
    #
    sampling_params = SamplingParams(
        temperature=1.0,
        max_tokens=4096,
        structured_outputs=StructuredOutputsParams(json=sample_json_schema),
    )

    prompt = (
        "Give an example JSON for an employee profile that fits this "
        "schema. Make the response as short as possible. Schema: "
        f"{sample_json_schema}"
    )
    outputs = llm.generate(
        [prompt] * 2,
        sampling_params=sampling_params,
        use_tqdm=True,
    )

    assert outputs is not None

    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt

        generated_text = output.outputs[0].text
        assert generated_text is not None
        # lm-format-enforcer is the one backend exempt from the
        # no-newline check.
        if backend != "lm-format-enforcer":
            assert "\n" not in generated_text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
        try:
            output_json = json.loads(generated_text)
        except json.JSONDecodeError as e:
            pytest.fail(
                f"Invalid JSON from backend={backend}: {generated_text!r}\n"
                f"Schema: {sample_json_schema}\nError: {e}"
            )
        jsonschema.validate(instance=output_json, schema=sample_json_schema)

    #
    # Test 2: Generate JSON object without a schema
    #
    if backend != "outlines":
        sampling_params = SamplingParams(
            temperature=1.0,
            max_tokens=4096,
            n=2,
            structured_outputs=StructuredOutputsParams(json_object=True),
        )

        outputs = llm.generate(
            prompts=(
                "Generate a JSON object with curly braces for a person with "
                "name and age fields for John Smith who is 31 years old. "
                "Make the response as short as possible."
            ),
            sampling_params=sampling_params,
            use_tqdm=True,
        )

        assert outputs is not None
        for output in outputs:
            assert output is not None
            assert isinstance(output, RequestOutput)

            # n=2 above, so both completions of the request are checked.
            for i in range(2):
                generated_text = output.outputs[i].text
                print(generated_text)
                assert generated_text is not None

                # Parse to verify it is a valid JSON object
                parsed_json = json.loads(generated_text)
                assert isinstance(parsed_json, dict)

    #
    # Test 3: test a jsonschema incompatible with xgrammar
    #
    sampling_params = SamplingParams(
        temperature=1.0,
        max_tokens=4096,
        structured_outputs=StructuredOutputsParams(json=unsupported_json_schema),
    )
    if backend.startswith("xgrammar"):
        with pytest.raises(
            ValueError,
            match="The provided JSON schema contains features "
            "not supported by xgrammar.",
        ):
            prompt = (
                f"Give an example JSON for an employee profile that "
                f"fits this schema: {unsupported_json_schema}. "
                f"Make the response as short as possible."
            )
            llm.generate(
                [prompt] * 2,
                sampling_params=sampling_params,
                use_tqdm=True,
            )
    else:
        prompt = (
            f"Give an example JSON object for a grade that "
            f"fits this schema: {unsupported_json_schema}. "
            f"Make the response as short as possible."
        )
        outputs = llm.generate(
            prompt,
            sampling_params=sampling_params,
            use_tqdm=True,
        )
        assert outputs is not None
        for output in outputs:
            assert output is not None
            assert isinstance(output, RequestOutput)
            generated_text = output.outputs[0].text
            assert generated_text is not None
            print(generated_text)

            # Parse to verify it is valid JSON
            parsed_json = json.loads(generated_text)
            assert isinstance(parsed_json, dict)

    if backend not in ["outlines", "lm-format-enforcer"]:
        #
        # Test 4: Generate SQL statement using EBNF grammar
        #
        sampling_params = SamplingParams(
            temperature=0.8,
            top_p=0.95,
            max_tokens=1000,
            structured_outputs=StructuredOutputsParams(grammar=sample_sql_ebnf),
        )
        outputs = llm.generate(
            (
                "Generate a sql statement that selects col_1 from "
                "table_1 where it is equal to 1. Make the response as short as "
                "possible."
            ),
            sampling_params=sampling_params,
            use_tqdm=True,
        )

        assert outputs is not None
        for output in outputs:
            assert output is not None
            assert isinstance(output, RequestOutput)
            prompt = output.prompt

            generated_text = output.outputs[0].text
            assert generated_text is not None

            # remove spaces for comparison b/c we removed them in the grammar
            ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(" ", "")

            assert generated_text.strip() == ground_truth

            print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

        #
        # Test 5: Generate SQL statement using Lark grammar
        #
        sampling_params = SamplingParams(
            temperature=0.8,
            top_p=0.95,
            max_tokens=1000,
            structured_outputs=StructuredOutputsParams(grammar=sample_sql_lark),
        )
        outputs = llm.generate(
            (
                "Generate a sql statement that selects col_1 from "
                "table_1 where it is equal to 1. Make the response as short as "
                "possible."
            ),
            sampling_params=sampling_params,
            use_tqdm=True,
        )

        # Imported lazily so lark is only needed for backends that run this
        # scenario; the parser is built once because it does not depend on
        # the individual outputs.
        from lark import Lark

        parser = Lark(sample_sql_lark)

        assert outputs is not None
        for output in outputs:
            assert output is not None
            assert isinstance(output, RequestOutput)
            prompt = output.prompt

            generated_text = output.outputs[0].text
            assert generated_text is not None

            # use Lark to parse the output, and make sure it's a valid parse tree
            parser.parse(generated_text)

            # remove spaces for comparison b/c we removed them in the grammar
            ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(" ", "")

            assert generated_text.strip() == ground_truth

            print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

        #
        # Test 6: Test invalid grammar input
        #
        sampling_params = SamplingParams(
            temperature=0.8,
            top_p=0.95,
            max_tokens=1000,
            structured_outputs=StructuredOutputsParams(grammar="not a grammar"),
        )
        with pytest.raises(ValueError, match="Failed to convert the grammar "):
            llm.generate(
                (
                    "Generate a sql statement that selects col_1 from "
                    "table_1 where it is equal to 1. Make the response as short "
                    "as possible."
                ),
                sampling_params=sampling_params,
                use_tqdm=True,
            )

    #
    # Test 7: Generate text based on a regex pattern
    #
    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
        structured_outputs=StructuredOutputsParams(regex=sample_regex),
    )

    prompt = (
        f"Give an example IPv4 address with this regex: {sample_regex}. "
        f"Make the response as short as possible."
    )
    outputs = llm.generate(
        [prompt] * 2,
        sampling_params=sampling_params,
        use_tqdm=True,
    )

    assert outputs is not None
    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(generated_text)
        assert generated_text is not None
        assert re.fullmatch(sample_regex, generated_text) is not None
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

    #
    # Test 8: Generate text based on a choices
    #
    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
        structured_outputs=StructuredOutputsParams(
            choice=sample_structured_outputs_choices
        ),
    )

    outputs = llm.generate(
        (
            "The best language for type-safe systems programming is "
            "(Make the response as short as possible.) "
        ),
        sampling_params=sampling_params,
        use_tqdm=True,
    )
    assert outputs is not None
    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(generated_text)
        assert generated_text is not None
        assert generated_text in sample_structured_outputs_choices
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

    #
    # Test 9: Generate structured output using a Pydantic model with an enum
    #
    json_schema = CarDescription.model_json_schema()
    sampling_params = SamplingParams(
        temperature=1.0,
        max_tokens=1000,
        structured_outputs=StructuredOutputsParams(json=json_schema),
    )

    outputs = llm.generate(
        (
            "Generate a JSON with the brand, model and car_type of the most "
            "iconic car from the 90's. Make the response as short as "
            "possible."
        ),
        sampling_params=sampling_params,
        use_tqdm=True,
    )

    assert outputs is not None

    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt

        generated_text = output.outputs[0].text
        assert generated_text is not None
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
        try:
            output_json = json.loads(generated_text)
        except json.JSONDecodeError as e:
            pytest.fail(
                f"Invalid JSON from backend={backend}: {generated_text!r}\n"
                f"Schema: {json_schema}\nError: {e}"
            )
        jsonschema.validate(instance=output_json, schema=json_schema)

    #
    # Test 10: Generate structured with minLength and maxLength
    #
    min_length = 50
    max_length = 50
    json_schema = {
        "type": "object",
        "properties": {
            "description": {
                "type": "string",
                "maxLength": max_length,
                "minLength": min_length,
            }
        },
        "required": ["description"],
        "additionalProperties": False,
    }

    sampling_params = SamplingParams(
        temperature=1.0,
        max_tokens=4096,
        structured_outputs=StructuredOutputsParams(json=json_schema),
    )

    outputs = llm.generate(
        (
            "Generate a description of a frog using 50 characters. "
            "Make the response as short as possible."
        ),
        sampling_params=sampling_params,
        use_tqdm=True,
    )

    assert outputs is not None

    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt

        generated_text = output.outputs[0].text
        assert generated_text is not None
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
        try:
            output_json = json.loads(generated_text)
        except json.JSONDecodeError as e:
            pytest.fail(
                f"Invalid JSON from backend={backend}: {generated_text!r}\n"
                f"Schema: {json_schema}\nError: {e}"
            )
        jsonschema.validate(instance=output_json, schema=json_schema)

    if backend not in ["outlines", "lm-format-enforcer"]:
        #
        # Test 11: Generate structured output using structural_tag format
        #
        structural_tag_config = {
            "type": "structural_tag",
            "structures": [
                {
                    "begin": "<function=get_weather>",
                    "schema": {
                        "type": "object",
                        "properties": {"city": {"type": "string"}},
                        "additionalProperties": False,
                    },
                    "end": "</function>",
                }
            ],
            "triggers": ["<function="],
        }

        sampling_params = SamplingParams(
            temperature=0.0,
            max_tokens=4096,
            structured_outputs=StructuredOutputsParams(
                structural_tag=json.dumps(structural_tag_config)
            ),
        )

        prompt = """
You have access to the following function to retrieve the weather in a city:

    {
        "name": "get_weather",
        "parameters": {
            "city": {
                "param_type": "string",
                "description": "The city to get the weather for",
                "required": True
            }
        }
    }

If a you choose to call a function ONLY reply in the following format:
<{start_tag}={function_name}>{parameters}{end_tag}
where

start_tag => `<function`
parameters => a JSON dict with the function argument name
            as key and function argument value as value.
end_tag => `</function>`

Here is an example,
<function=example_function_name>{"example_name": "example_value"}</function>

Reminder:
- Function calls MUST follow the specified format
- Required parameters MUST be specified
- Only call one function at a time
- Put the entire function call reply on one line
- Always add your sources when using search results to answer the user query

You are a helpful assistant.

Given the previous instructions, what is the weather in New York City? \
Make the response as short as possible.
"""

        # Change this once other backends support structural_tag
        outputs = llm.generate(prompt, sampling_params=sampling_params, use_tqdm=True)
        assert outputs is not None

        for output in outputs:
            assert output is not None
            assert isinstance(output, RequestOutput)
            generated_text = output.outputs[0].text
            assert generated_text is not None

            # Search for function call pattern in the response
            function_call_pattern = r"<function=get_weather>(.*?)</function>"
            matches = re.findall(function_call_pattern, generated_text)

            # A response without any function call is only reported, not
            # treated as a failure.
            if not matches:
                print(
                    f"Warning: No function calls found in response: {generated_text!r}"
                )
                continue

            # Take the first function call if multiple are found
            json_str = matches[0]
            try:
                json_content = json.loads(json_str)
                assert "city" in json_content
                assert isinstance(json_content["city"], str)
                print(f"Found valid function call: {generated_text!r}")
            except (json.JSONDecodeError, AssertionError) as e:
                pytest.fail(
                    f"Invalid function call format: {generated_text!r}\nError: {str(e)}"
                )
715

716

717
@pytest.mark.parametrize(
    "model_name, backend, tokenizer_mode, reasoning_parser, speculative_config, async_scheduling",  # noqa: E501
    [
        (
            "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
            "xgrammar",
            "auto",
            "deepseek_r1",
            NGRAM_SPEC_CONFIG,
            False,
        ),
        ("Qwen/Qwen3-1.7B", "xgrammar", "auto", "deepseek_r1", None, False),
        ("Qwen/Qwen3-1.7B", "xgrammar", "auto", "deepseek_r1", None, True),
    ],
)
def test_structured_output_with_reasoning_matrices(
    backend: str,
    tokenizer_mode: str,
    reasoning_parser: str,
    model_name: str,
    speculative_config: dict[str, Any] | None,
    async_scheduling: bool,
):
    """Check JSON-constrained generation with a reasoning parser configured.

    The model is asked to reason about a math problem and answer with a JSON
    object; the output is split into reasoning and content, and any content
    is validated against the schema.

    Args:
        backend: Structured-output backend name passed to the engine config.
        tokenizer_mode: Tokenizer mode passed to ``LLM``.
        reasoning_parser: Name of the reasoning parser to configure and to
            use for splitting the generated text.
        model_name: Model identifier to load.
        speculative_config: Speculative-decoding settings, or ``None``.
        async_scheduling: Value forwarded to ``LLM(async_scheduling=...)``.
    """
    if current_platform.is_tpu() and speculative_config:
        pytest.skip("TPU does not support speculative decoding")

    # Use a single LLM instance for several scenarios to
    # speed up the test suite.
    llm = LLM(
        model=model_name,
        # Don't use eager execution on TPUs because we want to test for no
        # recompilation at runtime
        enforce_eager=not current_platform.is_tpu(),
        max_model_len=1024,
        max_num_seqs=16,
        structured_outputs_config=dict(
            backend=backend,
            disable_any_whitespace=backend in {"xgrammar", "guidance"},
            reasoning_parser=reasoning_parser,
        ),
        tokenizer_mode=tokenizer_mode,
        speculative_config=speculative_config,
        async_scheduling=async_scheduling,
    )
    tokenizer = llm.get_tokenizer()
    reasoner = ReasoningParserManager.get_reasoning_parser(reasoning_parser)(
        tokenizer=tokenizer
    )

    reasoning_prompt = "Solve the following math problem step-by-step, then provide the final answer as JSON object with a single key 'result'. Make sure to correct your reasoning if there are any issue should it arise.\nProblem: What is 5 * 8 + 2?"  # noqa: E501
    reasoning_schema = {
        "type": "object",
        "properties": {"result": {"type": "integer"}},
        "required": ["result"],
        "additionalProperties": False,
    }
    # For Qwen3 models the prompt is extended with an opening think tag.
    if "Qwen3" in model_name:
        reasoning_prompt += "<think>\n"

    sampling_params = SamplingParams(
        temperature=0.1,
        max_tokens=8192,
        structured_outputs=StructuredOutputsParams(json=reasoning_schema),
    )
    outputs = llm.generate(
        [reasoning_prompt],
        sampling_params=sampling_params,
        use_tqdm=True,
    )

    assert outputs is not None
    output = outputs[0]
    assert output is not None and isinstance(output, RequestOutput)
    prompt = output.prompt
    generated_text = output.outputs[0].text
    reasoning, content = run_reasoning_extraction(reasoner, [generated_text])
    print(f"Prompt: {prompt!r}\nReasoning: {reasoning!r}\nContent: {content!r}")

    # Content is mandatory only for Qwen3; for other models it is validated
    # when present.
    if "Qwen3" in model_name:
        assert content is not None

    assert reasoning is not None

    if content is not None:
        output_json = json.loads(content)
        jsonschema.validate(instance=output_json, schema=reasoning_schema)
803
804


805
@pytest.mark.parametrize("model_name, tokenizer_mode", PARAMS_MODELS_TOKENIZER_MODE)
def test_structured_output_auto_mode(
    model_name: str,
    tokenizer_mode: str,
):
    """Check that the "auto" backend handles a schema xgrammar rejects.

    The same prompt and the same ``SamplingParams`` object are submitted
    twice to confirm the parameters can be reused after backend selection.

    Args:
        model_name: Model identifier to load.
        tokenizer_mode: Tokenizer mode passed to ``LLM``.
    """
    unsupported_json_schema = UNSUPPORTED_JSON_SCHEMA
    llm = LLM(
        model=model_name,
        max_model_len=1024,
        structured_outputs_config=dict(backend="auto"),
        tokenizer_mode=tokenizer_mode,
        load_format="auto",
        config_format="auto",
    )

    sampling_params = SamplingParams(
        temperature=1.0,
        max_tokens=1000,
        structured_outputs=StructuredOutputsParams(json=unsupported_json_schema),
    )

    prompts = (
        "Give an example JSON object for a grade "
        "that fits this schema: "
        f"{unsupported_json_schema}. Make the response as short as possible."
    )
    # This would fail with the default of "xgrammar", but in "auto"
    # we will handle fallback automatically.
    outputs = llm.generate(prompts, sampling_params=sampling_params, use_tqdm=True)
    # Checked before use: `.extend` below would otherwise fail first with a
    # less helpful AttributeError.
    assert outputs is not None
    # Make sure `auto` backend handling doesn't mess up sampling_params
    # and that we can reuse it without error.
    outputs.extend(
        llm.generate(prompts, sampling_params=sampling_params, use_tqdm=True)
    )

    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        generated_text = output.outputs[0].text
        assert generated_text is not None
        print(generated_text)

        # Parse to verify it is valid JSON
        parsed_json = json.loads(generated_text)
        assert isinstance(parsed_json, dict)
851
852


853
def test_guidance_no_additional_properties():
    """Check that guidance honours ``disable_additional_properties``.

    The schema declares exactly three required string keys (``a1``..``a3``)
    while the prompt asks the model for twenty (``a1``..``a20``). With
    ``disable_additional_properties=True`` the generated object must contain
    the three declared keys and none of the extra ones checked below.
    """
    llm = LLM(
        model="Qwen/Qwen2.5-1.5B-Instruct",
        max_model_len=1024,
        structured_outputs_config=dict(
            backend="guidance",
            disable_any_whitespace=True,
            disable_additional_properties=True,
        ),
    )

    schema = {
        "type": "object",
        "properties": {
            "a1": {"type": "string"},
            "a2": {"type": "string"},
            "a3": {"type": "string"},
        },
        "required": ["a1", "a2", "a3"],
    }

    # The prompt deliberately requests more keys than the schema allows.
    prompt = (
        "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a "
        "helpful assistant.<|im_end|>\n<|im_start|>user\nPlease generate a "
        "large JSON object with key-value pairs a1=b1, a2=b2, ..., a20=b20. "
        "Make the response as short as possible."
        "<|im_end|>\n<|im_start|>assistant\n"
    )

    def generate_with_backend(backend):
        """Generate once with ``backend`` and return the schema-valid dict."""
        structured_outputs_params = StructuredOutputsParams(
            json=schema,
            backend=backend,
            disable_any_whitespace=True,
            disable_additional_properties=True,
        )
        sampling_params = SamplingParams(
            temperature=0, max_tokens=256, structured_outputs=structured_outputs_params
        )

        outputs = llm.generate(prompt, sampling_params=sampling_params)
        assert outputs is not None
        generated_text = outputs[0].outputs[0].text
        assert generated_text is not None
        parsed_json = json.loads(generated_text)
        assert isinstance(parsed_json, dict)
        jsonschema.validate(instance=parsed_json, schema=schema)
        return parsed_json

    generated = generate_with_backend("guidance")
    assert "a1" in generated
    assert "a2" in generated
    assert "a3" in generated
    assert "a4" not in generated
    assert "a5" not in generated
    assert "a6" not in generated
909
910


911
912
913
@pytest.mark.parametrize("backend", ["guidance", "xgrammar", "outlines"])
def test_structured_output_batched_with_non_structured_outputs_requests(
    backend: str,
):
    """Batch one structured-output request with one plain request.

    The first prompt carries a JSON schema constraint and must yield a
    single-line JSON document that validates against ``SAMPLE_JSON_SCHEMA``.
    The second prompt is unconstrained; its completion must mention
    ``12,742`` and must not parse as JSON.

    Args:
        backend: Structured-outputs backend name passed to
            ``StructuredOutputsConfig``.
    """
    sample_json_schema = SAMPLE_JSON_SCHEMA

    # Don't use eager execution on TPUs because we want to test for no
    # recompilation at runtime
    enforce_eager = not current_platform.is_tpu()

    llm = LLM(
        model="meta-llama/Meta-Llama-3.1-8B-Instruct",
        enforce_eager=enforce_eager,
        max_model_len=1024,
        structured_outputs_config=StructuredOutputsConfig(
            backend=backend,
            disable_any_whitespace=backend in {"xgrammar", "guidance"},
        ),
    )

    structured_outputs_prompt = (
        "Give an example JSON for an employee profile that fits this "
        "schema. Make the response as short as possible. Schema: "
        f"{sample_json_schema}"
    )

    non_structured_outputs_prompt = "The diameter of the Earth in kilometers is "

    prompts = [structured_outputs_prompt, non_structured_outputs_prompt]
    sampling_params = [
        SamplingParams(
            temperature=1.0,
            max_tokens=400,
            structured_outputs=StructuredOutputsParams(json=sample_json_schema),
        ),
        # No max tokens, temp=0 to assert on contents
        SamplingParams(
            seed=42,
            temperature=0,
            top_p=1.0,
        ),
    ]

    outputs = llm.generate(
        prompts=prompts, sampling_params=sampling_params, use_tqdm=True
    )

    assert outputs is not None

    # Free memory as soon as possible as failed assertions
    # will short circuit and not free up memory
    del llm
    torch.accelerator.empty_cache()
    cleanup_dist_env_and_memory()

    for index, output in enumerate(outputs):
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt

        generated_text = output.outputs[0].text
        assert generated_text is not None
        print(f"Prompt:\n{prompt!r}\nGenerated text:\n{generated_text!r}")

        if index == 0:
            # First prompt is structured outputs, expect valid JSON
            assert "\n" not in generated_text
            output_json = json.loads(generated_text)
            jsonschema.validate(instance=output_json, schema=sample_json_schema)
        else:
            # Second prompt is not structured outputs, expect valid output
            # Cannot assert on exact output, but we can expect it to be factual
            assert "12,742" in generated_text

            # non-structured outputs requests should not return a valid JSON here
            # (json.JSONDecodeError is a ValueError subclass)
            with pytest.raises(ValueError):
                json.loads(generated_text)
987
988


989
990
@pytest.mark.parametrize("backend", ["xgrammar"])
def test_structured_output_with_structural_tag(backend: str):
    """Check that a structural-tag constraint shapes the generated text.

    A ``triggered_tags`` structural tag is configured with trigger ``hello``
    and a tag beginning with ``hello_flag``. The prompt asks the model to
    repeat "hello", and every completion must contain ``hello_flag``.

    Args:
        backend: Structured-outputs backend name passed to
            ``StructuredOutputsConfig``.
    """
    llm = LLM(
        model="Qwen/Qwen2.5-1.5B-Instruct",
        structured_outputs_config=StructuredOutputsConfig(backend=backend),
    )

    structural_tag_config = {
        "type": "structural_tag",
        "format": {
            "type": "triggered_tags",
            "tags": [
                {"begin": "hello_flag", "content": {"type": "any_text"}, "end": "hello"}
            ],
            "triggers": ["hello"],
            "stop_after_first": False,
        },
    }

    sampling_params = SamplingParams(
        temperature=0.0,
        max_tokens=500,
        structured_outputs=StructuredOutputsParams(
            # The structural tag is passed as a JSON-encoded string.
            structural_tag=json.dumps(structural_tag_config)
        ),
    )

    prompt = "Hello and repeat hello 10 times, do not say anything else. Only say hello hello hello, now start"
    outputs = llm.generate(prompt, sampling_params=sampling_params, use_tqdm=True)
    assert outputs is not None
    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        generated_text = output.outputs[0].text
        assert generated_text is not None
        assert "hello_flag" in generated_text, (
            f"Expected 'hello_flag' to be in generated text, but got: {generated_text}"
        )