Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
5d09ca57
"src/runtime/vscode:/vscode.git/clone" did not exist on "62b5f50a5ce7509eae87f1cde7265bed05cfe973"
Unverified
Commit
5d09ca57
authored
Oct 11, 2024
by
Lianmin Zheng
Committed by
GitHub
Oct 11, 2024
Browse files
Fix constrained decoding (#1634)
parent
81c33274
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
22 additions
and
4 deletions
+22
-4
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+2
-0
test/srt/test_json_constrained.py
test/srt/test_json_constrained.py
+20
-4
No files found.
python/sglang/srt/managers/schedule_batch.py
View file @
5d09ca57
...
...
@@ -810,6 +810,8 @@ class ScheduleBatch:
self
.
sampling_info
.
regex_fsm_states
=
[
req
.
regex_fsm_state
for
req
in
self
.
reqs
]
else
:
self
.
sampling_info
.
regex_fsms
=
None
return
ModelWorkerBatch
(
forward_mode
=
self
.
forward_mode
,
...
...
test/srt/test_json_constrained.py
View file @
5d09ca57
import
json
import
unittest
from
concurrent.futures
import
ThreadPoolExecutor
import
openai
import
requests
...
...
@@ -27,13 +28,18 @@ class TestJSONConstrained(unittest.TestCase):
"required"
:
[
"name"
,
"population"
],
}
)
cls
.
process
=
popen_launch_server
(
cls
.
model
,
cls
.
base_url
,
timeout
=
300
)
cls
.
process
=
popen_launch_server
(
cls
.
model
,
cls
.
base_url
,
timeout
=
300
,
other_args
=
[
"--max-running-requests"
,
"10"
],
)
@
classmethod
def
tearDownClass
(
cls
):
kill_child_process
(
cls
.
process
.
pid
)
def
run_decode
(
self
,
return_logprob
=
False
,
top_logprobs_num
=
0
,
n
=
1
):
def
run_decode
(
self
,
json_schema
,
return_logprob
=
False
,
top_logprobs_num
=
0
,
n
=
1
):
response
=
requests
.
post
(
self
.
base_url
+
"/generate"
,
json
=
{
...
...
@@ -43,7 +49,7 @@ class TestJSONConstrained(unittest.TestCase):
"max_new_tokens"
:
128
,
"n"
:
n
,
"stop_token_ids"
:
[
119690
],
"json_schema"
:
self
.
json_schema
,
"json_schema"
:
json_schema
,
},
"stream"
:
False
,
"return_logprob"
:
return_logprob
,
...
...
@@ -53,6 +59,10 @@ class TestJSONConstrained(unittest.TestCase):
)
print
(
json
.
dumps
(
response
.
json
()))
print
(
"="
*
100
)
if
not
json_schema
:
return
try
:
js_obj
=
json
.
loads
(
response
.
json
()[
"text"
])
except
(
TypeError
,
json
.
decoder
.
JSONDecodeError
):
...
...
@@ -61,7 +71,7 @@ class TestJSONConstrained(unittest.TestCase):
assert
isinstance
(
js_obj
[
"population"
],
int
)
def
test_json_generate
(
self
):
self
.
run_decode
()
self
.
run_decode
(
json_schema
=
self
.
json_schema
)
def
test_json_openai
(
self
):
client
=
openai
.
Client
(
api_key
=
"EMPTY"
,
base_url
=
f
"
{
self
.
base_url
}
/v1"
)
...
...
@@ -89,6 +99,12 @@ class TestJSONConstrained(unittest.TestCase):
assert
isinstance
(
js_obj
[
"name"
],
str
)
assert
isinstance
(
js_obj
[
"population"
],
int
)
def
test_mix_json_and_other
(
self
):
json_schemas
=
[
None
,
None
,
self
.
json_schema
,
self
.
json_schema
]
*
10
with
ThreadPoolExecutor
(
len
(
json_schemas
))
as
executor
:
list
(
executor
.
map
(
self
.
run_decode
,
json_schemas
))
if
__name__
==
"__main__"
:
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment