Unverified Commit 2d831c6e authored by Byron Hsu's avatar Byron Hsu Committed by GitHub
Browse files

[PD] Support structured output (#6560)

parent ed0c3035
......@@ -45,19 +45,16 @@ from sglang.srt.disaggregation.utils import (
poll_and_all_reduce,
prepare_abort,
)
from sglang.srt.managers.schedule_batch import FINISH_ABORT
from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.model_executor.forward_batch_info import ForwardMode
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.managers.schedule_batch import Req
from sglang.srt.managers.scheduler import Scheduler
from sglang.srt.server_args import ServerArgs
@dataclass
......@@ -531,7 +528,18 @@ class SchedulerDisaggregationDecodeMixin:
self.prepare_dp_attn_batch(batch)
result = self.run_batch(batch)
result_queue.append((batch.copy(), result))
if (self.last_batch is None) or (not self.last_batch_in_queue):
# Create a dummy first batch to start the pipeline for overlap schedule.
# It is now used for triggering the sampling_info_done event.
tmp_batch = ScheduleBatch(
reqs=None,
forward_mode=ForwardMode.DUMMY_FIRST,
next_batch_sampling_info=self.tp_worker.cur_sampling_info,
)
self.set_next_batch_sampling_info_done(tmp_batch)
last_batch_in_queue = True
elif prepare_dp_attn_flag:
batch, result = self._prepare_idle_batch_and_run(
None, delay_process=True
......@@ -543,6 +551,9 @@ class SchedulerDisaggregationDecodeMixin:
# Process the results of the previous batch but skip if the last batch is extend
if self.last_batch and self.last_batch_in_queue:
tmp_batch, tmp_result = result_queue.popleft()
tmp_batch.next_batch_sampling_info = (
self.tp_worker.cur_sampling_info if batch else None
)
self.process_batch_result(tmp_batch, tmp_result)
if batch is None and (
......@@ -591,6 +602,9 @@ class SchedulerDisaggregationDecodeMixin:
def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]:
"""Create a schedulebatch for fake completed prefill"""
if self.grammar_queue:
self.move_ready_grammar_requests()
if len(self.waiting_queue) == 0:
return None
......@@ -616,8 +630,6 @@ class SchedulerDisaggregationDecodeMixin:
self.waiting_queue = waiting_queue
if len(can_run_list) == 0:
return None
# local import to avoid circular import
from sglang.srt.managers.schedule_batch import ScheduleBatch
# construct a schedule batch with those requests and mark as decode
new_batch = ScheduleBatch.init_new(
......
......@@ -101,6 +101,9 @@ class ScheduleBatchDisaggregationDecodeMixin:
for req in self.reqs:
self.output_ids.append(req.output_ids[-1])
self.tree_cache.cache_unfinished_req(req)
if req.grammar is not None:
req.grammar.accept_token(req.output_ids[-1])
req.grammar.finished = req.finished()
self.output_ids = torch.tensor(self.output_ids, device=self.device)
# Simulate the eagle run. We add mock data to hidden states for the
......
......@@ -43,6 +43,7 @@ from sglang.srt.disaggregation.utils import (
prepare_abort,
)
from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
from sglang.srt.model_executor.forward_batch_info import ForwardMode
if TYPE_CHECKING:
from torch.distributed import ProcessGroup
......@@ -143,6 +144,10 @@ class PrefillBootstrapQueue:
self._process_req(req)
self.queue.append(req)
def extend(self, reqs: List[Req]) -> None:
for req in reqs:
self.add(req)
def _process_req(self, req: Req) -> None:
"""
Set max_new_tokens = 1, so PrefillAdder memory estimation is accurate
......@@ -269,6 +274,16 @@ class SchedulerDisaggregationPrefillMixin:
result = self.run_batch(batch)
self.result_queue.append((batch.copy(), result))
if self.last_batch is None:
# Create a dummy first batch to start the pipeline for overlap schedule.
# It is now used for triggering the sampling_info_done event.
tmp_batch = ScheduleBatch(
reqs=None,
forward_mode=ForwardMode.DUMMY_FIRST,
next_batch_sampling_info=self.tp_worker.cur_sampling_info,
)
self.set_next_batch_sampling_info_done(tmp_batch)
if self.last_batch:
tmp_batch, tmp_result = self.result_queue.popleft()
self.process_batch_result_disagg_prefill(tmp_batch, tmp_result)
......
......@@ -1065,8 +1065,11 @@ class Scheduler(
else:
self.waiting_queue.append(req)
def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
if self.disaggregation_mode == DisaggregationMode.DECODE:
def _extend_requests_to_queue(self, reqs: List[Req]):
if self.disaggregation_mode == DisaggregationMode.PREFILL:
self.disagg_prefill_bootstrap_queue.extend(reqs)
elif self.disaggregation_mode == DisaggregationMode.DECODE:
# If this is a decode server, we put the request to the decode pending prealloc queue
self.disagg_decode_prealloc_queue.extend(reqs)
else:
self.waiting_queue.extend(reqs)
......
import json
import requests
port = 8000
json_schema = json.dumps(
{
"type": "object",
"properties": {
"name": {"type": "string", "pattern": "^[\\w]+$"},
"population": {"type": "integer"},
},
"required": ["name", "population"],
}
)
# JSON
response = requests.post(
f"http://localhost:{port}/generate",
json={
"text": "Here is the information of the capital of France in the JSON format.\n",
"sampling_params": {
"temperature": 0,
"max_new_tokens": 64,
"json_schema": json_schema,
},
},
)
print(response.json())
# python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --trust-remote-code --disaggregation-mode prefill --tp 2 --disaggregation-ib-device mlx5_roce0,mlx5_roce1 --speculative-algorithm EAGLE --speculative-draft-model-path lmsys/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 3 --speculative-eagle-topk 4 --speculative-num-draft-tokens 16 --cuda-graph-max-bs 8 --host 127.0.0.1 --port 8100
import json
import os
import subprocess
import time
......@@ -17,12 +18,9 @@ from sglang.test.test_utils import (
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_pd_server,
run_with_timeout,
)
# skip the test because we have different_tp test
@unittest.skip("skip the test because we have different_tp test")
class TestDisaggregationAccuracy(CustomTestCase):
@classmethod
def setUpClass(cls):
......@@ -172,6 +170,34 @@ class TestDisaggregationAccuracy(CustomTestCase):
len(input_logprobs) > 0
), f"input_logprobs should have at least one token, but got {len(input_logprobs)}"
def test_structured_output(self):
json_schema = json.dumps(
{
"type": "object",
"properties": {
"name": {"type": "string", "pattern": "^[\\w]+$"},
"population": {"type": "integer"},
},
"required": ["name", "population"],
}
)
# JSON
response = requests.post(
f"{self.lb_url}/generate",
json={
"text": "Here is the information of the capital of France in the JSON format.\n",
"sampling_params": {
"temperature": 0,
"max_new_tokens": 64,
"json_schema": json_schema,
},
},
)
output = response.json()["text"]
# ensure the output is a valid JSON
json.loads(output)
class TestDisaggregationMooncakeFailure(CustomTestCase):
@classmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment