import json
import unittest

import openai
import requests

from sglang.srt.utils import kill_child_process
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)


class TestJSONConstrained(unittest.TestCase):
    """Checks that JSON-schema-constrained generation yields schema-shaped JSON.

    A server process is launched once for the whole class and is exercised
    through both the native ``/generate`` endpoint and the OpenAI-compatible
    chat completions endpoint.
    """

    @classmethod
    def setUpClass(cls):
        """Launch the test server and build the JSON schema shared by all tests."""
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        # The schema is kept as a JSON *string*: `/generate` receives it as-is,
        # while the OpenAI path parses it back into a dict before sending.
        cls.json_schema = json.dumps(
            {
                "type": "object",
                "properties": {
                    "name": {"type": "string", "pattern": "^[\\w]+$"},
                    "population": {"type": "integer"},
                },
                "required": ["name", "population"],
            }
        )
        cls.process = popen_launch_server(
            cls.model, cls.base_url, timeout=300, api_key=cls.api_key
        )

    @classmethod
    def tearDownClass(cls):
        """Terminate the server process started in ``setUpClass``."""
        kill_child_process(cls.process.pid)

    def _assert_matches_schema(self, text):
        """Parse ``text`` as JSON and check the fields required by the schema.

        Args:
            text: Raw generated text that is expected to be a JSON object.

        Raises:
            TypeError, json.decoder.JSONDecodeError: If ``text`` cannot be
                parsed; the offending text is printed before re-raising.
            AssertionError: If ``name`` is not a string or ``population`` is
                not an integer.
        """
        try:
            js_obj = json.loads(text)
        except (TypeError, json.decoder.JSONDecodeError):
            # Show what the model actually produced before propagating.
            print("JSONDecodeError", text)
            raise
        self.assertIsInstance(js_obj["name"], str)
        self.assertIsInstance(js_obj["population"], int)

    def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1):
        """Send one constrained ``/generate`` request and validate its output.

        Args:
            return_logprob: Forwarded as the request's ``return_logprob`` field.
            top_logprobs_num: Forwarded as the request's ``top_logprobs_num``.
            n: Number of samples requested; greedy decoding (temperature 0) is
                used for a single sample and temperature 0.5 otherwise.
        """
        headers = {"Authorization": f"Bearer {self.api_key}"}
        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": "The capital of France is",
                "sampling_params": {
                    "temperature": 0 if n == 1 else 0.5,
                    "max_new_tokens": 128,
                    "n": n,
                    "stop_token_ids": [119690],
                    "json_schema": self.json_schema,
                },
                "stream": False,
                "return_logprob": return_logprob,
                "top_logprobs_num": top_logprobs_num,
                "logprob_start_len": 0,
            },
            headers=headers,
        )
        # Decode the body once and reuse it for logging and validation.
        payload = response.json()
        print(json.dumps(payload))
        print("=" * 100)

        # NOTE(review): with n > 1 the server presumably answers with a list of
        # results rather than a single object -- confirm. Either shape is
        # validated here, one result at a time.
        outputs = payload if isinstance(payload, list) else [payload]
        self.assertTrue(outputs, "server returned no outputs")
        for output in outputs:
            self._assert_matches_schema(output["text"])

    def test_json_generate(self):
        """Constrained decoding through the native ``/generate`` endpoint."""
        self.run_decode()

    def test_json_openai(self):
        """Constrained decoding through the OpenAI-compatible chat endpoint."""
        client = openai.Client(api_key=self.api_key, base_url=f"{self.base_url}/v1")

        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {"role": "user", "content": "Introduce the capital of France."},
            ],
            temperature=0,
            max_tokens=128,
            response_format={
                "type": "json_schema",
                "json_schema": {"name": "foo", "schema": json.loads(self.json_schema)},
            },
        )
        text = response.choices[0].message.content

        self._assert_matches_schema(text)


if __name__ == "__main__":
    unittest.main()