Unverified commit a2486eb5, authored by Lianmin Zheng, committed by GitHub

Fix a bug with logprob streaming + chunked prefill (#2403)

parent 61dec545
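
As I read the diff: with chunked prefill, the prefill of one long request is split across several scheduler iterations, and process_batch_result_prefill() previously streamed every request in the batch, including the one whose prefill was still in progress. With return_logprob enabled, that could emit output for a request whose logprobs were not complete yet. The scheduler now remembers the (at most one) request currently being chunked and skips it in stream_output(), and bench_serving gains a --return-logprob flag so this path can be exercised.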
@@ -321,6 +321,8 @@ async def async_request_sglang_generate(
             },
             "stream": not args.disable_stream,
             "lora_path": request_func_input.lora_name,
+            "return_logprob": args.return_logprob,
+            "logprob_start_len": -1,
             **request_func_input.extra_request_body,
         }
         headers = {}
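
For context, here is a minimal standalone request that mirrors the two payload keys added above. This is a sketch under assumptions: a local SGLang server at 127.0.0.1:30000 with the usual /generate endpoint, my reading that logprob_start_len=-1 asks for logprobs of generated tokens only, and the response field names in the last lines are also assumptions.

import json
import urllib.request

payload = {
    "text": "The capital of France is",
    "sampling_params": {"max_new_tokens": 8, "temperature": 0},
    "stream": False,  # non-streaming keeps the sketch short; the benchmark streams
    "return_logprob": True,
    "logprob_start_len": -1,  # assumed meaning: no prompt logprobs, output tokens only
}
req = urllib.request.Request(
    "http://127.0.0.1:30000/generate",  # assumed local server address
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    out = json.load(resp)
# "meta_info" / "output_token_logprobs" are assumed response field names.
print(out.get("meta_info", {}).get("output_token_logprobs"))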
@@ -911,7 +913,7 @@ async def benchmark(
         prompt=test_prompt,
         api_url=api_url,
         prompt_len=test_prompt_len,
-        output_len=test_output_len,
+        output_len=min(test_output_len, 32),
         lora_name=lora_name,
         extra_request_body=extra_request_body,
     )
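
Capping the warmup request at min(test_output_len, 32) keeps the single pre-benchmark sanity request short; the measured benchmark requests presumably still use the full output lengths.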
@@ -1413,6 +1415,11 @@ if __name__ == "__main__":
         action="store_true",
         help="Disable ignoring EOS.",
     )
+    parser.add_argument(
+        "--return-logprob",
+        action="store_true",
+        help="Return logprob.",
+    )
     parser.add_argument(
         "--extra-request-body",
         metavar='{"key1": "value1", "key2": "value2"}',
...
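With the new flag, a run along these lines (invocation assumed from bench_serving's usual entry point) exercises streaming plus logprobs end to end:

python3 -m sglang.bench_serving --backend sglang --num-prompts 10 --return-logprob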
@@ -440,16 +440,11 @@ class Scheduler:
         if self.tp_rank == 0 or self.server_args.enable_dp_attention:
             recv_reqs = []
 
-            if self.last_batch is None:
-                recv_req = self.recv_from_tokenizer.recv_pyobj()
-                recv_reqs.append(recv_req)
-            else:
-                while True:
-                    try:
-                        recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
-                    except zmq.ZMQError:
-                        break
-                    recv_reqs.append(recv_req)
+            while True:
+                try:
+                    recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
+                except zmq.ZMQError:
+                    break
+                recv_reqs.append(recv_req)
         else:
             recv_reqs = None
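
The refactor above removes the special blocking recv for the idle case, so the scheduler always drains its socket without blocking. A self-contained sketch of the same drain pattern, with a placeholder socket type and address rather than SGLang's actual wiring:

import zmq

ctx = zmq.Context.instance()
sock = ctx.socket(zmq.PULL)  # placeholder; any receiving socket type drains the same way
sock.bind("ipc:///tmp/drain_demo")

def drain_nonblocking(sock):
    """Collect every message already queued; return immediately when none is waiting."""
    msgs = []
    while True:
        try:
            msgs.append(sock.recv_pyobj(zmq.NOBLOCK))
        except zmq.ZMQError:
            # Nothing queued (EAGAIN): stop instead of blocking the event loop.
            break
    return msgs

print(drain_nonblocking(sock))  # [] when no producer has sent anything yet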
@@ -949,6 +944,7 @@ class Scheduler:
             batch.next_batch_sampling_info.sampling_info_done.set()
 
     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
+        skip_stream_req = None
         if self.is_generation:
             logits_output, next_token_ids, bid = result
@@ -1005,6 +1001,10 @@ class Scheduler:
                 else:
                     # being chunked reqs' prefill is not finished
                     req.is_being_chunked -= 1
+                    # There is at most one request being chunked at any time.
+                    # Because its prefill has not finished,
+                    # we don't want to stream this request yet.
+                    skip_stream_req = req
 
             if batch.next_batch_sampling_info:
                 batch.next_batch_sampling_info.update_regex_vocab_mask()
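
A single skip_stream_req variable suffices because chunked prefill, as I understand it, only ever splits the last request admitted to a prefill batch, so at most one request per batch can be mid-chunk.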
@@ -1034,7 +1034,7 @@ class Scheduler:
                 # being chunked reqs' prefill is not finished
                 req.is_being_chunked -= 1
 
-        self.stream_output(batch.reqs)
+        self.stream_output(batch.reqs, skip_stream_req)
 
     def process_batch_result_decode(self, batch: ScheduleBatch, result):
         logits_output, next_token_ids, bid = result
@@ -1179,7 +1179,7 @@ class Scheduler:
         return num_input_logprobs
 
-    def stream_output(self, reqs: List[Req]):
+    def stream_output(self, reqs: List[Req], skip_req: Optional[Req] = None):
         """Stream the output to detokenizer."""
         output_rids = []
         output_meta_info: List[dict] = []
@@ -1199,6 +1199,9 @@ class Scheduler:
         is_stream_iter = self.forward_ct_decode % self.stream_interval == 0
 
         for req in reqs:
+            if req is skip_req:
+                continue
+
             # TODO(lianmin): revisit this for overlap + retract + stream
             if req.finished() or (
                 req.stream and (is_stream_iter or len(req.output_ids) == 1)
...
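The skip uses an identity comparison (req is skip_req), so exactly the in-flight chunked request object is withheld while every other request still streams. A toy illustration with stand-in types (this Req is not the real SGLang class):

from typing import List, Optional

class Req:  # stand-in for sglang's Req
    def __init__(self, rid: str):
        self.rid = rid

def stream_output(reqs: List[Req], skip_req: Optional[Req] = None) -> List[str]:
    """Return the rids that would be streamed, skipping the chunked request."""
    streamed = []
    for req in reqs:
        if req is skip_req:  # identity, not equality: skip exactly this object
            continue
        streamed.append(req.rid)
    return streamed

a, b = Req("a"), Req("b")
assert stream_output([a, b], skip_req=b) == ["a"]
assert stream_output([a, b]) == ["a", "b"]  # default None: nothing skipped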
@@ -568,6 +568,7 @@ def run_bench_serving(
         disable_tqdm=False,
         disable_stream=disable_stream,
         disable_ignore_eos=False,
+        return_logprob=False,
         lora_name=None,
         extra_request_body=None,
         profile=None,
...