Commit 9935f97b (unverified)
Authored Aug 26, 2024 by havetc; committed by GitHub on Aug 26, 2024
Parent: c5fe11a8

[FEAT] JSON constrained support (#1125)

Co-authored-by: Yineng Zhang <me@zhyncs.com>

Showing 10 changed files with 147 additions and 3 deletions (+147, -3):
docs/en/sampling_params.md                      +3   -0
python/sglang/srt/constrained/fsm_cache.py      +11  -2
python/sglang/srt/constrained/jump_forward.py   +1   -0
python/sglang/srt/managers/schedule_batch.py    +7   -0
python/sglang/srt/managers/tp_worker.py         +20  -1
python/sglang/srt/openai_api/adapter.py         +2   -0
python/sglang/srt/openai_api/protocol.py        +2   -0
python/sglang/srt/sampling/sampling_params.py   +4   -0
test/srt/run_suite.py                           +1   -0
test/srt/test_json_constrained.py               +96  -0
docs/en/sampling_params.md

```diff
@@ -60,6 +60,9 @@ spaces_between_special_tokens: bool = True,
 regex: Optional[str] = None,
 # Do parallel sampling and return `n` outputs.
 n: int = 1,
+# Constrains the output to follow a given JSON schema.
+# `regex` and `json_schema` cannot be set at the same time.
+json_schema: Optional[str] = None,
 ## Penalties. See [Performance Implications on Penalties] section below for more informations.
```
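To illustrate the documented parameter end to end, here is a minimal client-side sketch that mirrors the payload shape used by the new test at the bottom of this commit; the server URL is a placeholder and assumes a locally launched sglang server.

```python
import json

import requests

BASE_URL = "http://localhost:30000"  # placeholder; assumes a running sglang server

schema = json.dumps(
    {
        "type": "object",
        "properties": {
            "name": {"type": "string"},
            "population": {"type": "integer"},
        },
        "required": ["name", "population"],
    }
)

response = requests.post(
    BASE_URL + "/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 128,
            # `json_schema` and `regex` cannot be set at the same time.
            "json_schema": schema,
        },
    },
)
print(response.json()["text"])
```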
python/sglang/srt/constrained/fsm_cache.py

```diff
@@ -15,6 +15,8 @@ limitations under the License.
 
 """Cache for the compressed finite state machine."""
 
+from outlines.fsm.json_schema import build_regex_from_schema
+
 from sglang.srt.constrained import RegexGuide, TransformerTokenizer
 from sglang.srt.constrained.base_tool_cache import BaseToolCache
@@ -26,9 +28,12 @@ class FSMCache(BaseToolCache):
         tokenizer_args_dict,
         enable=True,
         skip_tokenizer_init=False,
+        json_schema_mode=False,
     ):
         super().__init__(enable=enable)
+        self.json_schema_mode = json_schema_mode
         if (
             skip_tokenizer_init
             or tokenizer_path.endswith(".json")
@@ -72,5 +77,9 @@ class FSMCache(BaseToolCache):
             tokenizer_path, **tokenizer_args_dict
         )
 
-    def init_value(self, regex):
-        return RegexGuide(regex, self.outlines_tokenizer)
+    def init_value(self, value):
+        if self.json_schema_mode:
+            regex = build_regex_from_schema(value)
+            return RegexGuide(regex, self.outlines_tokenizer), regex
+        else:
+            return RegexGuide(value, self.outlines_tokenizer)
```
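The schema-to-regex compilation is delegated to outlines. A standalone sketch of what `init_value` now does in `json_schema_mode`, outside of the cache machinery:

```python
import json

from outlines.fsm.json_schema import build_regex_from_schema

# The schema string is compiled to a regular expression; that regex then
# drives a RegexGuide exactly as a user-supplied regex would, which is why
# init_value can return both the guide and the regex for later reuse.
schema = json.dumps(
    {"type": "object", "properties": {"name": {"type": "string"}}}
)
regex = build_regex_from_schema(schema)
print(regex)  # a (long) regex matching JSON documents valid under the schema
```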
python/sglang/srt/constrained/jump_forward.py

```diff
@@ -23,6 +23,7 @@ from collections import defaultdict
 
 import interegular
 import outlines.caching
+from outlines.fsm.json_schema import build_regex_from_schema
 
 from sglang.srt.constrained import (
     FSMInfo,
```
python/sglang/srt/managers/schedule_batch.py

```diff
@@ -268,7 +268,14 @@ class Req:
 
         all_text = self.origin_input_text + self.decoded_text + jump_forward_str
         all_ids = self.tokenizer.encode(all_text)
+        if not all_ids:
+            warnings.warn("Encoded all_text resulted in empty all_ids")
+            return False
+
         prompt_tokens = len(self.origin_input_ids_unpadded)
+        if prompt_tokens > len(all_ids):
+            warnings.warn("prompt_tokens is larger than encoded all_ids")
+            return False
 
         if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
             # TODO(lsyin): fix token fusion
```
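These two guards protect the jump-forward path against retokenization surprises: after a jump-forward the full text is re-encoded, and the new ids are only usable if the prompt prefix survived. A toy, self-contained sketch of the invariant being enforced (the ids are made up; a real tokenizer can merge tokens across the prompt boundary):

```python
# Hypothetical token ids, standing in for a real tokenizer's output.
origin_input_ids_unpadded = [101, 2054]  # ids of the original prompt
all_ids = [101]  # re-encoding of prompt + decoded text + jump_forward_str

prompt_tokens = len(origin_input_ids_unpadded)
if not all_ids:
    print("abort jump-forward: re-encoding produced no ids")
elif prompt_tokens > len(all_ids):
    print("abort jump-forward: re-encoding is shorter than the prompt")
elif all_ids[prompt_tokens - 1] != origin_input_ids_unpadded[-1]:
    print("abort jump-forward: token fusion at the prompt boundary")
```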
python/sglang/srt/managers/tp_worker.py

```diff
@@ -197,6 +197,16 @@ class ModelTpServer:
                 "trust_remote_code": server_args.trust_remote_code,
             },
             skip_tokenizer_init=server_args.skip_tokenizer_init,
+            json_schema_mode=False,
+        )
+        self.json_fsm_cache = FSMCache(
+            server_args.tokenizer_path,
+            {
+                "tokenizer_mode": server_args.tokenizer_mode,
+                "trust_remote_code": server_args.trust_remote_code,
+            },
+            skip_tokenizer_init=server_args.skip_tokenizer_init,
+            json_schema_mode=True,
         )
         self.jump_forward_cache = JumpForwardCache()
@@ -349,8 +359,17 @@ class ModelTpServer:
             req.top_logprobs_num = recv_req.top_logprobs_num
             req.stream = recv_req.stream
 
+            # Init regex fsm from json
+            if req.sampling_params.json_schema is not None:
+                req.regex_fsm, computed_regex_string = self.json_fsm_cache.query(
+                    req.sampling_params.json_schema
+                )
+                if not self.disable_regex_jump_forward:
+                    req.jump_forward_map = self.jump_forward_cache.query(
+                        computed_regex_string
+                    )
             # Init regex fsm
-            if req.sampling_params.regex is not None:
+            elif req.sampling_params.regex is not None:
                 req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
                 if not self.disable_regex_jump_forward:
                     req.jump_forward_map = self.jump_forward_cache.query(
```
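The schema cache is simply a second `FSMCache` instance constructed with `json_schema_mode=True`, so schema strings and regex strings are memoized in separate namespaces, and a schema query also yields the compiled regex for the jump-forward map. A hedged standalone sketch; the tokenizer path is a placeholder and the empty tokenizer-args dict is an assumption for illustration:

```python
from sglang.srt.constrained.fsm_cache import FSMCache

tokenizer_path = "some-org/some-model"  # placeholder model path

regex_cache = FSMCache(tokenizer_path, {}, json_schema_mode=False)
json_cache = FSMCache(tokenizer_path, {}, json_schema_mode=True)

# Plain regex mode: query returns just the guide.
guide = regex_cache.query(r"(Paris|London)")

# JSON-schema mode: query returns (guide, compiled_regex); the regex is then
# fed to the JumpForwardCache, as in the dispatch above.
guide, computed_regex = json_cache.query('{"type": "object"}')
```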
python/sglang/srt/openai_api/adapter.py

```diff
@@ -434,6 +434,7 @@ def v1_generate_request(all_requests: List[CompletionRequest]):
             "frequency_penalty": request.frequency_penalty,
             "repetition_penalty": request.repetition_penalty,
             "regex": request.regex,
+            "json_schema": request.json_schema,
             "n": request.n,
             "ignore_eos": request.ignore_eos,
         }
@@ -802,6 +803,7 @@ def v1_chat_generate_request(
             "frequency_penalty": request.frequency_penalty,
             "repetition_penalty": request.repetition_penalty,
             "regex": request.regex,
+            "json_schema": request.json_schema,
             "n": request.n,
         }
     )
```
python/sglang/srt/openai_api/protocol.py

```diff
@@ -161,6 +161,7 @@ class CompletionRequest(BaseModel):
     # Extra parameters for SRT backend only and will be ignored by OpenAI models.
     regex: Optional[str] = None
+    json_schema: Optional[str] = None
     ignore_eos: Optional[bool] = False
     min_tokens: Optional[int] = 0
     repetition_penalty: Optional[float] = 1.0
@@ -262,6 +263,7 @@ class ChatCompletionRequest(BaseModel):
     # Extra parameters for SRT backend only and will be ignored by OpenAI models.
     regex: Optional[str] = None
+    json_schema: Optional[str] = None
     min_tokens: Optional[int] = 0
     repetition_penalty: Optional[float] = 1.0
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
```
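Because `json_schema` is declared on both request models, it passes straight through the OpenAI-compatible endpoints. A sketch of a raw `/v1/completions` payload; the URL, API key, and model name are placeholders:

```python
import json

import requests

payload = {
    "model": "default",  # placeholder model name
    "prompt": "The capital of France is",
    "temperature": 0,
    "max_tokens": 128,
    # SRT-only extra field; ignored by actual OpenAI models.
    "json_schema": json.dumps(
        {"type": "object", "properties": {"name": {"type": "string"}}}
    ),
}
r = requests.post(
    "http://localhost:30000/v1/completions",
    json=payload,
    headers={"Authorization": "Bearer sk-123456"},  # placeholder key
)
print(r.json()["choices"][0]["text"])
```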
python/sglang/srt/sampling/sampling_params.py

```diff
@@ -39,6 +39,7 @@ class SamplingParams:
         spaces_between_special_tokens: bool = True,
         regex: Optional[str] = None,
         n: int = 1,
+        json_schema: Optional[str] = None,
     ) -> None:
         self.temperature = temperature
         self.top_p = top_p
@@ -56,6 +57,7 @@ class SamplingParams:
         self.spaces_between_special_tokens = spaces_between_special_tokens
         self.regex = regex
         self.n = n
+        self.json_schema = json_schema
 
         # Process some special cases
         if self.temperature < _SAMPLING_EPS:
@@ -106,6 +108,8 @@ class SamplingParams:
                 f"min_new_tokens must be in (0, max_new_tokens({self.max_new_tokens})], got "
                 f"{self.min_new_tokens}."
             )
+        if self.regex is not None and self.json_schema is not None:
+            raise ValueError("regex and json_schema cannot be both set.")
 
     def normalize(self, tokenizer):
         # Process stop strings
```
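A quick sketch of the new mutual-exclusion check. This assumes the validation method being extended above (the one raising the `min_new_tokens` error, just before `normalize`) is `SamplingParams.verify()`:

```python
from sglang.srt.sampling.sampling_params import SamplingParams

params = SamplingParams(
    regex=r"(yes|no)",
    json_schema='{"type": "string"}',
)
try:
    params.verify()  # assumed entry point for these validation checks
except ValueError as e:
    print(e)  # regex and json_schema cannot be both set.
```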
test/srt/run_suite.py

```diff
@@ -13,6 +13,7 @@ suites = {
         "test_eval_accuracy_mini.py",
         "test_large_max_new_tokens.py",
         "test_openai_server.py",
+        "test_json_constrained.py",
         "test_skip_tokenizer_init.py",
         "test_torch_compile.py",
         "test_triton_attn_backend.py",
```
test/srt/test_json_constrained.py (new file, 0 → 100644)

```python
import json
import unittest

import openai
import requests

from sglang.srt.utils import kill_child_process
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)


class TestJSONConstrained(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        cls.json_schema = json.dumps(
            {
                "type": "object",
                "properties": {
                    "name": {"type": "string", "pattern": "^[\\w]+$"},
                    "population": {"type": "integer"},
                },
                "required": ["name", "population"],
            }
        )
        cls.process = popen_launch_server(
            cls.model, cls.base_url, timeout=300, api_key=cls.api_key
        )

    @classmethod
    def tearDownClass(cls):
        kill_child_process(cls.process.pid)

    def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1):
        headers = {"Authorization": f"Bearer {self.api_key}"}
        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": "The capital of France is",
                "sampling_params": {
                    "temperature": 0 if n == 1 else 0.5,
                    "max_new_tokens": 128,
                    "n": n,
                    "stop_token_ids": [119690],
                    "json_schema": self.json_schema,
                },
                "stream": False,
                "return_logprob": return_logprob,
                "top_logprobs_num": top_logprobs_num,
                "logprob_start_len": 0,
            },
            headers=headers,
        )
        print(json.dumps(response.json()))
        print("=" * 100)
        try:
            js_obj = json.loads(response.json()["text"])
        except (TypeError, json.decoder.JSONDecodeError):
            raise
        assert isinstance(js_obj["name"], str)
        assert isinstance(js_obj["population"], int)

    def test_json_generate(self):
        self.run_decode()

    def test_json_openai(self):
        client = openai.Client(api_key=self.api_key, base_url=f"{self.base_url}/v1")

        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful AI assistant"},
                {"role": "user", "content": "Introduce the capital of France."},
            ],
            temperature=0,
            max_tokens=128,
            extra_body={"json_schema": self.json_schema},
        )
        text = response.choices[0].message.content

        try:
            js_obj = json.loads(text)
        except (TypeError, json.decoder.JSONDecodeError):
            print("JSONDecodeError", text)
            raise
        assert isinstance(js_obj["name"], str)
        assert isinstance(js_obj["population"], int)


if __name__ == "__main__":
    unittest.main()
```
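The test file is self-contained: thanks to the `unittest.main()` entry point it can be run directly with `python3 test/srt/test_json_constrained.py`, or picked up by the suite through the `run_suite.py` entry added above. `test_json_generate` exercises the schema through the native `/generate` endpoint, while `test_json_openai` passes the same schema through the OpenAI-compatible chat endpoint via `extra_body`.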