You need to sign in or sign up before continuing.
Unverified Commit 731146f6 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Fix mixed chunked prefill in overlap mode (#2158)

parent fa271613
...@@ -50,7 +50,7 @@ jobs: ...@@ -50,7 +50,7 @@ jobs:
timeout-minutes: 25 timeout-minutes: 25
run: | run: |
cd test/srt cd test/srt
python3 run_suite.py --suite minimal --range-begin 0 --range-end 6 python3 run_suite.py --suite minimal --range-begin 0 --range-end 7
unit-test-backend-part-2: unit-test-backend-part-2:
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
...@@ -67,7 +67,7 @@ jobs: ...@@ -67,7 +67,7 @@ jobs:
timeout-minutes: 25 timeout-minutes: 25
run: | run: |
cd test/srt cd test/srt
python3 run_suite.py --suite minimal --range-begin 6 --range-end 14 python3 run_suite.py --suite minimal --range-begin 7 --range-end 14
unit-test-backend-part-3: unit-test-backend-part-3:
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
......
...@@ -729,10 +729,13 @@ class ScheduleBatch: ...@@ -729,10 +729,13 @@ class ScheduleBatch:
self.input_ids = input_ids self.input_ids = input_ids
self.out_cache_loc = out_cache_loc self.out_cache_loc = out_cache_loc
# For overlap scheduler, the output_ids has one step delay
delta = 0 if self.enable_overlap else -1
# NOTE: prefix_indices is what has been cached, but we don't cache each decode step # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
self.prefix_lens.extend( self.prefix_lens.extend(
[ [
len(r.origin_input_ids) + len(r.output_ids) - 1 len(r.origin_input_ids) + len(r.output_ids) + delta
for r in running_batch.reqs for r in running_batch.reqs
] ]
) )
......
...@@ -848,7 +848,12 @@ class Scheduler: ...@@ -848,7 +848,12 @@ class Scheduler:
new_batch.prepare_for_extend() new_batch.prepare_for_extend()
# Mixed-style chunked prefill # Mixed-style chunked prefill
if self.is_mixed_chunk and self.running_batch is not None: if (
self.is_mixed_chunk
and self.running_batch is not None
and not (new_batch.return_logprob or self.running_batch.return_logprob)
):
# TODO (lianmin): support return_logprob + mixed chunked prefill
self.running_batch.filter_batch() self.running_batch.filter_batch()
if not self.running_batch.is_empty(): if not self.running_batch.is_empty():
self.running_batch.prepare_for_decode() self.running_batch.prepare_for_decode()
...@@ -979,7 +984,10 @@ class Scheduler: ...@@ -979,7 +984,10 @@ class Scheduler:
continue continue
if self.is_mixed_chunk and self.enable_overlap and req.finished(): if self.is_mixed_chunk and self.enable_overlap and req.finished():
raise ValueError("Unhandled error!") # Free the one delayed token for the mixed decode batch
j = len(batch.out_cache_loc) - len(batch.reqs) + i
self.token_to_kv_pool.free(batch.out_cache_loc[j : j + 1])
continue
if req.is_being_chunked <= 0: if req.is_being_chunked <= 0:
req.completion_tokens_wo_jump_forward += 1 req.completion_tokens_wo_jump_forward += 1
...@@ -992,7 +1000,6 @@ class Scheduler: ...@@ -992,7 +1000,6 @@ class Scheduler:
self.tree_cache.cache_unfinished_req(req) self.tree_cache.cache_unfinished_req(req)
if req.return_logprob: if req.return_logprob:
# TODO (lianmin): need to think the case w/ mixed chunked prefill
logprob_pt += self.add_logprob_return_values( logprob_pt += self.add_logprob_return_values(
i, req, logprob_pt, next_token_ids, logits_output i, req, logprob_pt, next_token_ids, logits_output
) )
......
...@@ -199,12 +199,6 @@ class ServerArgs: ...@@ -199,12 +199,6 @@ class ServerArgs:
"Overlap schedule is disabled." "Overlap schedule is disabled."
) )
if self.enable_mixed_chunk:
logger.info(
"Overlap schedule is disabled because mixed-style chunked prefill is enabled."
)
self.disable_overlap_schedule = True
@staticmethod @staticmethod
def add_cli_args(parser: argparse.ArgumentParser): def add_cli_args(parser: argparse.ArgumentParser):
# Model and port args # Model and port args
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment