change / sglang · Commits · 2d831c6e

Commit 2d831c6e (unverified)
[PD] Support structured output (#6560)

Authored May 23, 2025 by Byron Hsu; committed by GitHub on May 23, 2025
Parent: ed0c3035

Showing 6 changed files with 106 additions and 13 deletions (+106 -13)
python/sglang/srt/disaggregation/decode.py                        +20  -8
python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py    +3  -0
python/sglang/srt/disaggregation/prefill.py                       +15  -0
python/sglang/srt/managers/scheduler.py                            +5  -2
scripts/playground/disaggregation/cli-so.py                       +34  -0
test/srt/test_disaggregation.py                                   +29  -3
python/sglang/srt/disaggregation/decode.py

@@ -45,19 +45,16 @@ from sglang.srt.disaggregation.utils import (
     poll_and_all_reduce,
     prepare_abort,
 )
-from sglang.srt.managers.schedule_batch import FINISH_ABORT
+from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
-from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
-from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 
 logger = logging.getLogger(__name__)
 
 if TYPE_CHECKING:
     from sglang.srt.configs.model_config import ModelConfig
-    from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
+    from sglang.srt.managers.schedule_batch import Req
     from sglang.srt.managers.scheduler import Scheduler
     from sglang.srt.server_args import ServerArgs
 
 
 @dataclass

@@ -531,7 +528,18 @@ class SchedulerDisaggregationDecodeMixin:
                         self.prepare_dp_attn_batch(batch)
                     result = self.run_batch(batch)
                     result_queue.append((batch.copy(), result))
+
+                    if (self.last_batch is None) or (not self.last_batch_in_queue):
+                        # Create a dummy first batch to start the pipeline for overlap schedule.
+                        # It is now used for triggering the sampling_info_done event.
+                        tmp_batch = ScheduleBatch(
+                            reqs=None,
+                            forward_mode=ForwardMode.DUMMY_FIRST,
+                            next_batch_sampling_info=self.tp_worker.cur_sampling_info,
+                        )
+                        self.set_next_batch_sampling_info_done(tmp_batch)
                     last_batch_in_queue = True
             elif prepare_dp_attn_flag:
                 batch, result = self._prepare_idle_batch_and_run(
                     None, delay_process=True

@@ -543,6 +551,9 @@ class SchedulerDisaggregationDecodeMixin:
             # Process the results of the previous batch but skip if the last batch is extend
             if self.last_batch and self.last_batch_in_queue:
                 tmp_batch, tmp_result = result_queue.popleft()
+                tmp_batch.next_batch_sampling_info = (
+                    self.tp_worker.cur_sampling_info if batch else None
+                )
                 self.process_batch_result(tmp_batch, tmp_result)
 
             if batch is None and (

@@ -591,6 +602,9 @@ class SchedulerDisaggregationDecodeMixin:
     def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]:
         """Create a schedulebatch for fake completed prefill"""
+        if self.grammar_queue:
+            self.move_ready_grammar_requests()
+
         if len(self.waiting_queue) == 0:
             return None

@@ -616,8 +630,6 @@ class SchedulerDisaggregationDecodeMixin:
         self.waiting_queue = waiting_queue
 
         if len(can_run_list) == 0:
             return None
-        # local import to avoid circular import
-        from sglang.srt.managers.schedule_batch import ScheduleBatch
         # construct a schedule batch with those requests and mark as decode
         new_batch = ScheduleBatch.init_new(
python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py

@@ -101,6 +101,9 @@ class ScheduleBatchDisaggregationDecodeMixin:
         for req in self.reqs:
             self.output_ids.append(req.output_ids[-1])
             self.tree_cache.cache_unfinished_req(req)
+            if req.grammar is not None:
+                req.grammar.accept_token(req.output_ids[-1])
+                req.grammar.finished = req.finished()
         self.output_ids = torch.tensor(self.output_ids, device=self.device)
 
         # Simulate the eagle run. We add mock data to hidden states for the
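The hunk above is where structured output actually gets wired into PD disaggregation on the decode side: when the decode server fabricates the "already prefilled" batch, the token that the prefill server produced is replayed into the request's grammar (req.grammar.accept_token) and the grammar's finished flag is synced with the request. A minimal sketch of why that replay matters, using a toy stand-in grammar rather than sglang's real grammar backend (the ToyGrammar class and its rule are illustrative assumptions, not the commit's API):

# Illustrative only: a toy grammar whose legal next tokens depend on how many
# tokens it has already accepted, mirroring req.grammar in the hunk above.
class ToyGrammar:
    def __init__(self):
        self.n_accepted = 0
        self.finished = False

    def accept_token(self, token_id: int) -> None:
        # Advance the grammar state; a real backend would also update the
        # vocabulary mask used for the next constrained decoding step.
        self.n_accepted += 1

    def allowed_tokens(self) -> set:
        # Stand-in constraint: token 0 is only legal on even states.
        return {0, 1} if self.n_accepted % 2 == 0 else {1}


grammar = ToyGrammar()

# The decode server receives a request whose first output token (say id 1) was
# already generated on the prefill server. As in the hunk above, that token is
# replayed into the grammar before the first real decode step, and finished is
# mirrored from req.finished():
grammar.accept_token(1)
grammar.finished = False

# Without the replay, the mask for the next step would be computed from a stale
# state ({0, 1} instead of {1}) and constrained decoding could emit an illegal token.
print(grammar.allowed_tokens())  # {1}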
python/sglang/srt/disaggregation/prefill.py

@@ -43,6 +43,7 @@ from sglang.srt.disaggregation.utils import (
     prepare_abort,
 )
 from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 
 if TYPE_CHECKING:
     from torch.distributed import ProcessGroup

@@ -143,6 +144,10 @@ class PrefillBootstrapQueue:
         self._process_req(req)
         self.queue.append(req)
 
+    def extend(self, reqs: List[Req]) -> None:
+        for req in reqs:
+            self.add(req)
+
     def _process_req(self, req: Req) -> None:
         """
         Set max_new_tokens = 1, so PrefillAdder memory estimation is accurate

@@ -269,6 +274,16 @@ class SchedulerDisaggregationPrefillMixin:
             result = self.run_batch(batch)
             self.result_queue.append((batch.copy(), result))
+
+            if self.last_batch is None:
+                # Create a dummy first batch to start the pipeline for overlap schedule.
+                # It is now used for triggering the sampling_info_done event.
+                tmp_batch = ScheduleBatch(
+                    reqs=None,
+                    forward_mode=ForwardMode.DUMMY_FIRST,
+                    next_batch_sampling_info=self.tp_worker.cur_sampling_info,
+                )
+                self.set_next_batch_sampling_info_done(tmp_batch)
 
         if self.last_batch:
             tmp_batch, tmp_result = self.result_queue.popleft()
             self.process_batch_result_disagg_prefill(tmp_batch, tmp_result)
python/sglang/srt/managers/scheduler.py

@@ -1065,8 +1065,11 @@ class Scheduler(
         else:
             self.waiting_queue.append(req)
 
-    def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
-        if self.disaggregation_mode == DisaggregationMode.DECODE:
+    def _extend_requests_to_queue(self, reqs: List[Req]):
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            self.disagg_prefill_bootstrap_queue.extend(reqs)
+        elif self.disaggregation_mode == DisaggregationMode.DECODE:
             # If this is a decode server, we put the request to the decode pending prealloc queue
             self.disagg_decode_prealloc_queue.extend(reqs)
         else:
             self.waiting_queue.extend(reqs)
scripts/playground/disaggregation/cli-so.py (new file, mode 100644)

import json

import requests

port = 8000

json_schema = json.dumps(
    {
        "type": "object",
        "properties": {
            "name": {"type": "string", "pattern": "^[\\w]+$"},
            "population": {"type": "integer"},
        },
        "required": ["name", "population"],
    }
)

# JSON
response = requests.post(
    f"http://localhost:{port}/generate",
    json={
        "text": "Here is the information of the capital of France in the JSON format.\n",
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 64,
            "json_schema": json_schema,
        },
    },
)
print(response.json())

# python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --trust-remote-code --disaggregation-mode prefill --tp 2 --disaggregation-ib-device mlx5_roce0,mlx5_roce1 --speculative-algorithm EAGLE --speculative-draft-model-path lmsys/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 3 --speculative-eagle-topk 4 --speculative-num-draft-tokens 16 --cuda-graph-max-bs 8 --host 127.0.0.1 --port 8100
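The trailing comment records the prefill-side launch command used while exercising this script; a matching decode server and a router listening on port 8000 are implied but not shown. As a hypothetical follow-up (not part of the commit), the constrained output could also be validated against the schema that produced it, for example by appending something like this to cli-so.py (the third-party jsonschema package is an assumption):

# Hypothetical addition to the end of cli-so.py: check that the structured
# output actually conforms to the schema used to constrain it.
import jsonschema  # third-party: pip install jsonschema

generated = response.json()["text"]   # text returned by the /generate call above
data = json.loads(generated)          # must be valid JSON at all
jsonschema.validate(instance=data, schema=json.loads(json_schema))  # and match the schema
print("schema check passed:", data)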
test/srt/test_disaggregation.py

+import json
 import os
 import subprocess
 import time

@@ -17,12 +18,9 @@ from sglang.test.test_utils import (
     DEFAULT_URL_FOR_TEST,
     CustomTestCase,
     popen_launch_pd_server,
     run_with_timeout,
 )
 
 
-# skip the test because we have different_tp test
-@unittest.skip("skip the test because we have different_tp test")
 class TestDisaggregationAccuracy(CustomTestCase):
     @classmethod
     def setUpClass(cls):

@@ -172,6 +170,34 @@ class TestDisaggregationAccuracy(CustomTestCase):
             len(input_logprobs) > 0
         ), f"input_logprobs should have at least one token, but got {len(input_logprobs)}"
 
+    def test_structured_output(self):
+        json_schema = json.dumps(
+            {
+                "type": "object",
+                "properties": {
+                    "name": {"type": "string", "pattern": "^[\\w]+$"},
+                    "population": {"type": "integer"},
+                },
+                "required": ["name", "population"],
+            }
+        )
+
+        # JSON
+        response = requests.post(
+            f"{self.lb_url}/generate",
+            json={
+                "text": "Here is the information of the capital of France in the JSON format.\n",
+                "sampling_params": {
+                    "temperature": 0,
+                    "max_new_tokens": 64,
+                    "json_schema": json_schema,
+                },
+            },
+        )
+        output = response.json()["text"]
+
+        # ensure the output is a valid JSON
+        json.loads(output)
+
 
 class TestDisaggregationMooncakeFailure(CustomTestCase):
     @classmethod