Unverified Commit a2486eb5 authored by Lianmin Zheng, committed by GitHub

Fix a bug with logprob streaming + chunked prefill (#2403)

parent 61dec545
@@ -321,6 +321,8 @@ async def async_request_sglang_generate(
             },
             "stream": not args.disable_stream,
             "lora_path": request_func_input.lora_name,
+            "return_logprob": args.return_logprob,
+            "logprob_start_len": -1,
             **request_func_input.extra_request_body,
         }
         headers = {}
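
For reference, a minimal sketch of the request body the benchmark client builds once --return-logprob is passed. Only the stream, lora_path, return_logprob, and logprob_start_len fields come from the hunk above; the prompt, sampling parameters, and server URL are placeholders, and the reading of logprob_start_len=-1 as "start logprobs at the end of the prompt, i.e. output tokens only" is an assumption about the server's convention.

import json
import urllib.request

# Hypothetical payload; only the last four fields mirror the diff above.
payload = {
    "text": "The capital of France is",           # placeholder prompt
    "sampling_params": {"max_new_tokens": 32},    # placeholder params
    "stream": True,                               # "stream": not args.disable_stream
    "lora_path": None,                            # "lora_path": request_func_input.lora_name
    "return_logprob": True,                       # new: request per-token logprobs
    "logprob_start_len": -1,                      # new: assumed to mean "output tokens only"
}
request = urllib.request.Request(
    "http://localhost:30000/generate",            # placeholder endpoint
    data=json.dumps(payload).encode(),
    headers={"Content-Type": "application/json"},
)
# body = urllib.request.urlopen(request).read()   # run only against a live server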
@@ -911,7 +913,7 @@ async def benchmark(
         prompt=test_prompt,
         api_url=api_url,
         prompt_len=test_prompt_len,
-        output_len=test_output_len,
+        output_len=min(test_output_len, 32),
         lora_name=lora_name,
         extra_request_body=extra_request_body,
     )
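
The warm-up request sent before the benchmark proper now caps its output length at 32 tokens; presumably this keeps the pre-flight check short now that it can also exercise logprob streaming. A trivial illustration of the cap:

# Effect of min(test_output_len, 32) on the warm-up request.
for test_output_len in (8, 32, 1024):
    assert min(test_output_len, 32) == (test_output_len if test_output_len <= 32 else 32)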
@@ -1413,6 +1415,11 @@ if __name__ == "__main__":
         action="store_true",
         help="Disable ignoring EOS.",
     )
+    parser.add_argument(
+        "--return-logprob",
+        action="store_true",
+        help="Return logprob.",
+    )
     parser.add_argument(
         "--extra-request-body",
         metavar='{"key1": "value1", "key2": "value2"}',
...
@@ -440,16 +440,11 @@ class Scheduler:
         if self.tp_rank == 0 or self.server_args.enable_dp_attention:
             recv_reqs = []
 
-            if self.last_batch is None:
-                recv_req = self.recv_from_tokenizer.recv_pyobj()
-                recv_reqs.append(recv_req)
-            else:
-                while True:
-                    try:
-                        recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
-                    except zmq.ZMQError:
-                        break
-                    recv_reqs.append(recv_req)
+            while True:
+                try:
+                    recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
+                except zmq.ZMQError:
+                    break
+                recv_reqs.append(recv_req)
         else:
             recv_reqs = None
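
The old special case for the first batch did one blocking recv_pyobj(); the rewrite always drains the tokenizer socket with non-blocking reads, returning an empty list when nothing is pending so the scheduler can keep stepping a chunked prefill instead of stalling on a blocking receive. A minimal sketch of the pattern, assuming a pyzmq socket (the helper name is hypothetical):

import zmq

def drain_pending(socket: zmq.Socket) -> list:
    """Collect every message already queued on the socket without blocking."""
    reqs = []
    while True:
        try:
            reqs.append(socket.recv_pyobj(zmq.NOBLOCK))
        except zmq.ZMQError:
            break  # nothing left to read right now (EAGAIN)
    return reqs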
@@ -949,6 +944,7 @@ class Scheduler:
                 batch.next_batch_sampling_info.sampling_info_done.set()
 
     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
+        skip_stream_req = None
         if self.is_generation:
             logits_output, next_token_ids, bid = result
@@ -1005,6 +1001,10 @@ class Scheduler:
             else:
                 # being chunked reqs' prefill is not finished
                 req.is_being_chunked -= 1
+                # There is at most one request being chunked at a time.
+                # Since this request has not finished prefill,
+                # we don't want to stream it yet.
+                skip_stream_req = req
 
         if batch.next_batch_sampling_info:
             batch.next_batch_sampling_info.update_regex_vocab_mask()
@@ -1034,7 +1034,7 @@ class Scheduler:
                 # being chunked reqs' prefill is not finished
                 req.is_being_chunked -= 1
 
-        self.stream_output(batch.reqs)
+        self.stream_output(batch.reqs, skip_stream_req)
 
     def process_batch_result_decode(self, batch: ScheduleBatch, result):
         logits_output, next_token_ids, bid = result
@@ -1179,7 +1179,7 @@ class Scheduler:
         return num_input_logprobs
 
-    def stream_output(self, reqs: List[Req]):
+    def stream_output(self, reqs: List[Req], skip_req: Optional[Req] = None):
         """Stream the output to detokenizer."""
         output_rids = []
         output_meta_info: List[dict] = []
@@ -1199,6 +1199,9 @@ class Scheduler:
         is_stream_iter = self.forward_ct_decode % self.stream_interval == 0
 
         for req in reqs:
+            if req is skip_req:
+                continue
+
             # TODO(lianmin): revisit this for overlap + retract + stream
             if req.finished() or (
                 req.stream and (is_stream_iter or len(req.output_ids) == 1)
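
Taken together, the skip_stream_req changes withhold exactly one request from streaming: the one whose prefill is still mid-chunk. Note the comparison is by identity (is), so only that exact in-flight object is skipped even if another request compares equal. A toy illustration with a hypothetical Req stand-in:

# Identity-based skip, as in stream_output above.
class Req:  # hypothetical stand-in for the scheduler's Req
    def __init__(self, rid):
        self.rid = rid

reqs = [Req("a"), Req("b"), Req("c")]
skip_req = reqs[1]  # the request still being chunked

streamed = [r for r in reqs if r is not skip_req]
assert [r.rid for r in streamed] == ["a", "c"]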
...
@@ -568,6 +568,7 @@ def run_bench_serving(
         disable_tqdm=False,
         disable_stream=disable_stream,
         disable_ignore_eos=False,
+        return_logprob=False,
         lora_name=None,
         extra_request_body=None,
         profile=None,
...