"vscode:/vscode.git/clone" did not exist on "c3f8dad55cd36d96668a93d9387e5d3d05c6df75"
Unverified Commit 02bc9579 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Simplify chunked prefill (#1667)

parent 24f3e151
...@@ -456,7 +456,12 @@ class Scheduler: ...@@ -456,7 +456,12 @@ class Scheduler:
and not self.last_batch.is_empty() and not self.last_batch.is_empty()
): ):
if self.current_inflight_req: if self.current_inflight_req:
self.last_batch.filter_batch(self.current_inflight_req) self.last_batch.filter_batch(
current_inflight_req=self.current_inflight_req
)
self.tree_cache.cache_unfinished_req(self.current_inflight_req)
# Inflight request keeps its rid but will get a new req_pool_idx.
self.req_to_token_pool.free(self.current_inflight_req.req_pool_idx)
self.batch_is_full = False self.batch_is_full = False
if not self.last_batch.is_empty(): if not self.last_batch.is_empty():
if self.running_batch is None: if self.running_batch is None:
...@@ -728,26 +733,23 @@ class Scheduler: ...@@ -728,26 +733,23 @@ class Scheduler:
# Check finish conditions # Check finish conditions
logprob_pt = 0 logprob_pt = 0
for i, req in enumerate(batch.reqs): for i, req in enumerate(batch.reqs):
if req is not self.current_inflight_req: if req.is_inflight_req > 0:
req.is_inflight_req -= 1
else:
# Inflight reqs' prefill is not finished # Inflight reqs' prefill is not finished
req.completion_tokens_wo_jump_forward += 1 req.completion_tokens_wo_jump_forward += 1
req.output_ids.append(next_token_ids[i]) req.output_ids.append(next_token_ids[i])
req.check_finished() req.check_finished()
if req.regex_fsm is not None:
req.regex_fsm_state = req.regex_fsm.get_next_state(
req.regex_fsm_state, next_token_ids[i]
)
if req.finished(): if req.finished():
self.tree_cache.cache_finished_req(req) self.tree_cache.cache_finished_req(req)
elif not batch.decoding_reqs or req not in batch.decoding_reqs: elif not batch.decoding_reqs or req not in batch.decoding_reqs:
self.tree_cache.cache_unfinished_req(req) self.tree_cache.cache_unfinished_req(req)
if req.is_inflight_req > 0: if req.regex_fsm is not None:
# Inflight request would get a new req idx req.regex_fsm_state = req.regex_fsm.get_next_state(
req.is_inflight_req -= 1 req.regex_fsm_state, next_token_ids[i]
self.req_to_token_pool.free(req.req_pool_idx) )
if req.return_logprob: if req.return_logprob:
logprob_pt += self.add_logprob_return_values( logprob_pt += self.add_logprob_return_values(
...@@ -760,7 +762,9 @@ class Scheduler: ...@@ -760,7 +762,9 @@ class Scheduler:
# Check finish conditions # Check finish conditions
for i, req in enumerate(batch.reqs): for i, req in enumerate(batch.reqs):
req.embedding = embeddings[i] req.embedding = embeddings[i]
if req is not self.current_inflight_req: if req.is_inflight_req > 0:
req.is_inflight_req -= 1
else:
# Inflight reqs' prefill is not finished # Inflight reqs' prefill is not finished
# dummy output token for embedding models # dummy output token for embedding models
req.output_ids.append(0) req.output_ids.append(0)
...@@ -771,11 +775,6 @@ class Scheduler: ...@@ -771,11 +775,6 @@ class Scheduler:
else: else:
self.tree_cache.cache_unfinished_req(req) self.tree_cache.cache_unfinished_req(req)
if req.is_inflight_req > 0:
# Inflight request would get a new req idx
req.is_inflight_req -= 1
self.req_to_token_pool.free(req.req_pool_idx)
self.stream_output(batch) self.stream_output(batch)
def process_batch_result_decode(self, batch: ScheduleBatch, result): def process_batch_result_decode(self, batch: ScheduleBatch, result):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment