"""
python3 -m unittest test_json_constrained.TestJSONConstrainedOutlinesBackend.test_json_generate
"""

import json
import unittest
from concurrent.futures import ThreadPoolExecutor

import openai
import requests

from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)


class TestJSONConstrainedOutlinesBackend(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
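        # Schema shared by all tests: an object with a required string `name`
        # (word characters only) and a required integer `population`.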
        cls.json_schema = json.dumps(
            {
                "type": "object",
                "properties": {
                    "name": {"type": "string", "pattern": "^[\\w]+$"},
                    "population": {"type": "integer"},
                },
                "required": ["name", "population"],
            }
        )
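        # Launch the server under test with the "outlines" grammar backend;
        # --max-running-requests 10 caps how many requests decode concurrently.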
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=300,
            other_args=[
                "--max-running-requests",
                "10",
                "--grammar-backend",
                "outlines",
            ],
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def run_decode(self, json_schema, return_logprob=False, top_logprobs_num=0, n=1):
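        """Send one /generate request; if a schema was given, validate the JSON output."""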
        response = requests.post(
            self.base_url + "/generate",
            json={
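                # Constrained decoding is requested via the "json_schema" sampling
                # parameter; temperature is 0 for a single deterministic sample and
                # 0.5 when n > 1 so parallel samples can differ.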
                "text": "The capital of France is",
                "sampling_params": {
                    "temperature": 0 if n == 1 else 0.5,
                    "max_new_tokens": 128,
                    "n": n,
                    "stop_token_ids": [119690],
                    "json_schema": json_schema,
                },
                "stream": False,
                "return_logprob": return_logprob,
                "top_logprobs_num": top_logprobs_num,
                "logprob_start_len": 0,
            },
        )
        ret = response.json()
        print(json.dumps(ret))
        print("=" * 100)

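        # Unconstrained requests (json_schema=None) are only smoke-tested.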
        if not json_schema:
            return

        # Make sure the json output is valid
        try:
            js_obj = json.loads(ret["text"])
        except (TypeError, json.decoder.JSONDecodeError):
            print("JSONDecodeError", ret["text"])
            raise

        self.assertIsInstance(js_obj["name"], str)
        self.assertIsInstance(js_obj["population"], int)

        # Make sure jump forward is triggered
        # NOTE: This is skipped because overlap scheduler does not support jump forward
        # self.assertGreater(
        #     ret["meta_info"]["completion_tokens"],
        #     ret["meta_info"]["completion_tokens_wo_jump_forward"],
        # )

    def test_json_generate(self):
        self.run_decode(json_schema=self.json_schema)

    def test_json_openai(self):
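        # Exercise the OpenAI-compatible /v1 chat API with response_format set
        # to a named json_schema, then validate the returned message content.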
        client = openai.Client(api_key="EMPTY", base_url=f"{self.base_url}/v1")

        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {"role": "user", "content": "Introduce the capital of France."},
            ],
            temperature=0,
            max_tokens=128,
            response_format={
                "type": "json_schema",
                "json_schema": {"name": "foo", "schema": json.loads(self.json_schema)},
            },
        )
        text = response.choices[0].message.content

        try:
            js_obj = json.loads(text)
        except (TypeError, json.decoder.JSONDecodeError):
            print("JSONDecodeError", text)
            raise

        self.assertIsInstance(js_obj["name"], str)
        self.assertIsInstance(js_obj["population"], int)

    def test_mix_json_and_other(self):
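        # Interleave unconstrained (schema=None) and schema-constrained requests
        # and run them concurrently through run_decode.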
        json_schemas = [None, None, self.json_schema, self.json_schema] * 10

        with ThreadPoolExecutor(len(json_schemas)) as executor:
            list(executor.map(self.run_decode, json_schemas))


class TestJSONConstrainedXGrammarBackend(TestJSONConstrainedOutlinesBackend):
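    # Reuse all tests from the Outlines class, but launch the server with the
    # "xgrammar" grammar backend and a schema without the regex "pattern" field.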
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.json_schema = json.dumps(
            {
                "type": "object",
                "properties": {
                    "name": {"type": "string"},
                    "population": {"type": "integer"},
                },
                "required": ["name", "population"],
            }
        )
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=300,
            other_args=[
                "--max-running-requests",
                "10",
                "--grammar-backend",
                "xgrammar",
            ],
        )


if __name__ == "__main__":
    unittest.main()