Unverified Commit 9935f97b authored by havetc, committed by GitHub

[FEAT] JSON constrained support (#1125)


Co-authored-by: Yineng Zhang <me@zhyncs.com>
parent c5fe11a8
@@ -60,6 +60,9 @@ spaces_between_special_tokens: bool = True,
 regex: Optional[str] = None,
 # Do parallel sampling and return `n` outputs.
 n: int = 1,
+# Constrains the output to follow a given JSON schema.
+# `regex` and `json_schema` cannot be set at the same time.
+json_schema: Optional[str] = None,
 ## Penalties. See [Performance Implications on Penalties] section below for more information.
...
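For illustration, a minimal client-side sketch of passing the new parameter to a running server through the native /generate endpoint (the URL and prompt are placeholders; the request shape mirrors the test added by this commit):

```python
import json

import requests

# Placeholder URL; point this at your running sglang server.
schema = json.dumps(
    {
        "type": "object",
        "properties": {
            "name": {"type": "string"},
            "population": {"type": "integer"},
        },
        "required": ["name", "population"],
    }
)
response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "Give me information about Paris.",
        "sampling_params": {
            "max_new_tokens": 128,
            # Mutually exclusive with `regex`.
            "json_schema": schema,
        },
    },
)
print(response.json()["text"])  # should parse under the schema above
```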
@@ -15,6 +15,8 @@ limitations under the License.
 """Cache for the compressed finite state machine."""

+from outlines.fsm.json_schema import build_regex_from_schema
+
 from sglang.srt.constrained import RegexGuide, TransformerTokenizer
 from sglang.srt.constrained.base_tool_cache import BaseToolCache

@@ -26,9 +28,12 @@ class FSMCache(BaseToolCache):
         tokenizer_args_dict,
         enable=True,
         skip_tokenizer_init=False,
+        json_schema_mode=False,
     ):
         super().__init__(enable=enable)
+        self.json_schema_mode = json_schema_mode
+
         if (
             skip_tokenizer_init
             or tokenizer_path.endswith(".json")

@@ -72,5 +77,9 @@ class FSMCache(BaseToolCache):
             tokenizer_path, **tokenizer_args_dict
         )

-    def init_value(self, regex):
-        return RegexGuide(regex, self.outlines_tokenizer)
+    def init_value(self, value):
+        if self.json_schema_mode:
+            regex = build_regex_from_schema(value)
+            return RegexGuide(regex, self.outlines_tokenizer), regex
+        else:
+            return RegexGuide(value, self.outlines_tokenizer)
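Note that json_schema_mode does not introduce a separate constraint engine: the schema is lowered to an equivalent regular expression by outlines and then drives the same RegexGuide as a plain regex constraint. A small sketch of that lowering step (schema contents are illustrative):

```python
import json

from outlines.fsm.json_schema import build_regex_from_schema

schema = json.dumps(
    {
        "type": "object",
        "properties": {"population": {"type": "integer"}},
        "required": ["population"],
    }
)
# Lower the JSON schema to a regular expression that accepts exactly
# the JSON documents satisfying it.
regex = build_regex_from_schema(schema)
print(regex)
```

Returning the computed regex alongside the guide lets the caller hand it to the jump-forward cache, which the scheduler change below relies on.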
@@ -23,6 +23,7 @@ from collections import defaultdict
 import interegular
 import outlines.caching
+from outlines.fsm.json_schema import build_regex_from_schema

 from sglang.srt.constrained import (
     FSMInfo,
...
@@ -268,7 +268,14 @@ class Req:
         all_text = self.origin_input_text + self.decoded_text + jump_forward_str
         all_ids = self.tokenizer.encode(all_text)
+        if not all_ids:
+            warnings.warn("Encoded all_text resulted in empty all_ids")
+            return False
+
         prompt_tokens = len(self.origin_input_ids_unpadded)
+        if prompt_tokens > len(all_ids):
+            warnings.warn("prompt_tokens is larger than encoded all_ids")
+            return False
+
         if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
             # TODO(lsyin): fix token fusion
...
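The two new guards defend against degenerate re-tokenizations before the pre-existing token-fusion check runs. A standalone sketch of the combined predicate (hypothetical helper; names and messages mirror the diff):

```python
import warnings
from typing import List


def jump_forward_is_safe(all_ids: List[int], prompt_ids: List[int]) -> bool:
    """Hypothetical helper mirroring the new guards in Req: refuse a
    jump-forward when re-encoding produced no tokens, produced fewer
    tokens than the prompt, or fused a token across the prompt/output
    boundary (the TODO case already present)."""
    if not all_ids:
        warnings.warn("Encoded all_text resulted in empty all_ids")
        return False
    prompt_tokens = len(prompt_ids)
    if prompt_tokens > len(all_ids):
        warnings.warn("prompt_tokens is larger than encoded all_ids")
        return False
    # Token fusion: the last prompt token must survive re-encoding of
    # prompt + decoded text + jump_forward_str unchanged.
    return all_ids[prompt_tokens - 1] == prompt_ids[-1]
```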
@@ -197,6 +197,16 @@ class ModelTpServer:
                 "trust_remote_code": server_args.trust_remote_code,
             },
             skip_tokenizer_init=server_args.skip_tokenizer_init,
+            json_schema_mode=False,
+        )
+        self.json_fsm_cache = FSMCache(
+            server_args.tokenizer_path,
+            {
+                "tokenizer_mode": server_args.tokenizer_mode,
+                "trust_remote_code": server_args.trust_remote_code,
+            },
+            skip_tokenizer_init=server_args.skip_tokenizer_init,
+            json_schema_mode=True,
         )
         self.jump_forward_cache = JumpForwardCache()
@@ -349,8 +359,17 @@ class ModelTpServer:
         req.top_logprobs_num = recv_req.top_logprobs_num
         req.stream = recv_req.stream

+        # Init regex fsm from json
+        if req.sampling_params.json_schema is not None:
+            req.regex_fsm, computed_regex_string = self.json_fsm_cache.query(
+                req.sampling_params.json_schema
+            )
+            if not self.disable_regex_jump_forward:
+                req.jump_forward_map = self.jump_forward_cache.query(
+                    computed_regex_string
+                )
         # Init regex fsm
-        if req.sampling_params.regex is not None:
+        elif req.sampling_params.regex is not None:
             req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
             if not self.disable_regex_jump_forward:
                 req.jump_forward_map = self.jump_forward_cache.query(
...
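Both caches are instances of the same FSMCache class, differing only in json_schema_mode, so schema compilations and plain-regex compilations are memoized independently of each other. A toy stand-in for the assumed caching behavior behind query():

```python
from typing import Callable, Dict


class MiniToolCache:
    """Toy stand-in for BaseToolCache (behavior assumed, not the real
    implementation): memoize an expensive init_value() keyed by the raw
    query string, so repeated requests with the same schema or regex
    reuse the compiled guide instead of rebuilding the FSM."""

    def __init__(self, init_value: Callable[[str], object]):
        self._init_value = init_value
        self._cache: Dict[str, object] = {}

    def query(self, key: str) -> object:
        if key not in self._cache:
            self._cache[key] = self._init_value(key)
        return self._cache[key]
```

Since SamplingParams rejects requests that set both fields, the if/elif above is pure precedence bookkeeping rather than a silent override.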
@@ -434,6 +434,7 @@ def v1_generate_request(all_requests: List[CompletionRequest]):
             "frequency_penalty": request.frequency_penalty,
             "repetition_penalty": request.repetition_penalty,
             "regex": request.regex,
+            "json_schema": request.json_schema,
             "n": request.n,
             "ignore_eos": request.ignore_eos,
         }

@@ -802,6 +803,7 @@ def v1_chat_generate_request(
                 "frequency_penalty": request.frequency_penalty,
                 "repetition_penalty": request.repetition_penalty,
                 "regex": request.regex,
+                "json_schema": request.json_schema,
                 "n": request.n,
             }
         )
...
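Because the OpenAI-compatible adapter now forwards the field, the constraint can also be supplied through /v1/completions or /v1/chat/completions. A sketch against the completions endpoint (URL, model name, and API key are placeholders):

```python
import requests

response = requests.post(
    "http://localhost:30000/v1/completions",  # placeholder server URL
    headers={"Authorization": "Bearer sk-123456"},  # placeholder key
    json={
        "model": "default",
        "prompt": "The capital of France is",
        "max_tokens": 64,
        # SRT-only extra parameter; an OpenAI model would ignore it.
        "json_schema": '{"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]}',
    },
)
print(response.json()["choices"][0]["text"])
```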
@@ -161,6 +161,7 @@ class CompletionRequest(BaseModel):
     # Extra parameters for SRT backend only and will be ignored by OpenAI models.
     regex: Optional[str] = None
+    json_schema: Optional[str] = None
     ignore_eos: Optional[bool] = False
     min_tokens: Optional[int] = 0
     repetition_penalty: Optional[float] = 1.0

@@ -262,6 +263,7 @@ class ChatCompletionRequest(BaseModel):
     # Extra parameters for SRT backend only and will be ignored by OpenAI models.
     regex: Optional[str] = None
+    json_schema: Optional[str] = None
     min_tokens: Optional[int] = 0
     repetition_penalty: Optional[float] = 1.0
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
...
@@ -39,6 +39,7 @@ class SamplingParams:
         spaces_between_special_tokens: bool = True,
         regex: Optional[str] = None,
         n: int = 1,
+        json_schema: Optional[str] = None,
     ) -> None:
         self.temperature = temperature
         self.top_p = top_p

@@ -56,6 +57,7 @@ class SamplingParams:
         self.spaces_between_special_tokens = spaces_between_special_tokens
         self.regex = regex
         self.n = n
+        self.json_schema = json_schema

         # Process some special cases
         if self.temperature < _SAMPLING_EPS:

@@ -106,6 +108,8 @@ class SamplingParams:
                 f"min_new_tokens must be in (0, max_new_tokens({self.max_new_tokens})], got "
                 f"{self.min_new_tokens}."
             )
+        if self.regex is not None and self.json_schema is not None:
+            raise ValueError("regex and json_schema cannot both be set.")

     def normalize(self, tokenizer):
         # Process stop strings
...
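The new mutual-exclusion check sits among the other request validations; judging by the surrounding hunk context, it lives in SamplingParams.verify(), which the server invokes per request. A hedged demo, assuming the import path shown:

```python
from sglang.srt.sampling_params import SamplingParams  # import path assumed

params = SamplingParams(regex=r"\d+", json_schema='{"type": "object"}')
try:
    params.verify()  # the validation hunk above appears to live here
except ValueError as err:
    print(err)  # regex and json_schema cannot both be set.
```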
@@ -13,6 +13,7 @@ suites = {
     "test_eval_accuracy_mini.py",
     "test_large_max_new_tokens.py",
     "test_openai_server.py",
+    "test_json_constrained.py",
     "test_skip_tokenizer_init.py",
     "test_torch_compile.py",
     "test_triton_attn_backend.py",
...
New file: test_json_constrained.py

import json
import unittest

import openai
import requests

from sglang.srt.utils import kill_child_process
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)


class TestJSONConstrained(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        cls.json_schema = json.dumps(
            {
                "type": "object",
                "properties": {
                    "name": {"type": "string", "pattern": "^[\\w]+$"},
                    "population": {"type": "integer"},
                },
                "required": ["name", "population"],
            }
        )
        cls.process = popen_launch_server(
            cls.model, cls.base_url, timeout=300, api_key=cls.api_key
        )

    @classmethod
    def tearDownClass(cls):
        kill_child_process(cls.process.pid)

    def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1):
        headers = {"Authorization": f"Bearer {self.api_key}"}
        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": "The capital of France is",
                "sampling_params": {
                    "temperature": 0 if n == 1 else 0.5,
                    "max_new_tokens": 128,
                    "n": n,
                    "stop_token_ids": [119690],
                    "json_schema": self.json_schema,
                },
                "stream": False,
                "return_logprob": return_logprob,
                "top_logprobs_num": top_logprobs_num,
                "logprob_start_len": 0,
            },
            headers=headers,
        )
        print(json.dumps(response.json()))
        print("=" * 100)
        try:
            js_obj = json.loads(response.json()["text"])
        except (TypeError, json.decoder.JSONDecodeError):
            print("JSONDecodeError", response.json()["text"])
            raise
        assert isinstance(js_obj["name"], str)
        assert isinstance(js_obj["population"], int)

    def test_json_generate(self):
        self.run_decode()

    def test_json_openai(self):
        client = openai.Client(api_key=self.api_key, base_url=f"{self.base_url}/v1")
        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {"role": "user", "content": "Introduce the capital of France."},
            ],
            temperature=0,
            max_tokens=128,
            extra_body={"json_schema": self.json_schema},
        )
        text = response.choices[0].message.content
        try:
            js_obj = json.loads(text)
        except (TypeError, json.decoder.JSONDecodeError):
            print("JSONDecodeError", text)
            raise
        assert isinstance(js_obj["name"], str)
        assert isinstance(js_obj["population"], int)


if __name__ == "__main__":
    unittest.main()