Unverified Commit 731146f6 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Fix mixed chunked prefill in overlap mode (#2158)

parent fa271613
...@@ -50,7 +50,7 @@ jobs: ...@@ -50,7 +50,7 @@ jobs:
timeout-minutes: 25 timeout-minutes: 25
run: | run: |
cd test/srt cd test/srt
python3 run_suite.py --suite minimal --range-begin 0 --range-end 6 python3 run_suite.py --suite minimal --range-begin 0 --range-end 7
unit-test-backend-part-2: unit-test-backend-part-2:
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
...@@ -67,7 +67,7 @@ jobs: ...@@ -67,7 +67,7 @@ jobs:
timeout-minutes: 25 timeout-minutes: 25
run: | run: |
cd test/srt cd test/srt
python3 run_suite.py --suite minimal --range-begin 6 --range-end 14 python3 run_suite.py --suite minimal --range-begin 7 --range-end 14
unit-test-backend-part-3: unit-test-backend-part-3:
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
......
...@@ -729,10 +729,13 @@ class ScheduleBatch: ...@@ -729,10 +729,13 @@ class ScheduleBatch:
self.input_ids = input_ids self.input_ids = input_ids
self.out_cache_loc = out_cache_loc self.out_cache_loc = out_cache_loc
# For overlap scheduler, the output_ids has one step delay
delta = 0 if self.enable_overlap else -1
# NOTE: prefix_indices is what has been cached, but we don't cache each decode step # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
self.prefix_lens.extend( self.prefix_lens.extend(
[ [
len(r.origin_input_ids) + len(r.output_ids) - 1 len(r.origin_input_ids) + len(r.output_ids) + delta
for r in running_batch.reqs for r in running_batch.reqs
] ]
) )
......
...@@ -848,7 +848,12 @@ class Scheduler: ...@@ -848,7 +848,12 @@ class Scheduler:
new_batch.prepare_for_extend() new_batch.prepare_for_extend()
# Mixed-style chunked prefill # Mixed-style chunked prefill
if self.is_mixed_chunk and self.running_batch is not None: if (
self.is_mixed_chunk
and self.running_batch is not None
and not (new_batch.return_logprob or self.running_batch.return_logprob)
):
# TODO (lianmin): support return_logprob + mixed chunked prefill
self.running_batch.filter_batch() self.running_batch.filter_batch()
if not self.running_batch.is_empty(): if not self.running_batch.is_empty():
self.running_batch.prepare_for_decode() self.running_batch.prepare_for_decode()
...@@ -979,7 +984,10 @@ class Scheduler: ...@@ -979,7 +984,10 @@ class Scheduler:
continue continue
if self.is_mixed_chunk and self.enable_overlap and req.finished(): if self.is_mixed_chunk and self.enable_overlap and req.finished():
raise ValueError("Unhandled error!") # Free the one delayed token for the mixed decode batch
j = len(batch.out_cache_loc) - len(batch.reqs) + i
self.token_to_kv_pool.free(batch.out_cache_loc[j : j + 1])
continue
if req.is_being_chunked <= 0: if req.is_being_chunked <= 0:
req.completion_tokens_wo_jump_forward += 1 req.completion_tokens_wo_jump_forward += 1
...@@ -992,7 +1000,6 @@ class Scheduler: ...@@ -992,7 +1000,6 @@ class Scheduler:
self.tree_cache.cache_unfinished_req(req) self.tree_cache.cache_unfinished_req(req)
if req.return_logprob: if req.return_logprob:
# TODO (lianmin): need to think the case w/ mixed chunked prefill
logprob_pt += self.add_logprob_return_values( logprob_pt += self.add_logprob_return_values(
i, req, logprob_pt, next_token_ids, logits_output i, req, logprob_pt, next_token_ids, logits_output
) )
......
...@@ -199,12 +199,6 @@ class ServerArgs: ...@@ -199,12 +199,6 @@ class ServerArgs:
"Overlap schedule is disabled." "Overlap schedule is disabled."
) )
if self.enable_mixed_chunk:
logger.info(
"Overlap schedule is disabled because mixed-style chunked prefill is enabled."
)
self.disable_overlap_schedule = True
@staticmethod @staticmethod
def add_cli_args(parser: argparse.ArgumentParser): def add_cli_args(parser: argparse.ArgumentParser):
# Model and port args # Model and port args
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment