Unverified commit 05bea688 authored by Liangsheng Yin, committed by GitHub

Fix some online scheduling delay (#1345)

parent ab4a83b2
@@ -119,6 +119,7 @@ class PrefillAdder:
         self.running_batch = running_batch
         self.new_token_ratio = new_token_ratio
         self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
+        self.rem_total_tokens_ = self.rem_total_tokens
         self.total_tokens = rem_total_tokens
         self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
         self.rem_chunk_tokens = rem_chunk_tokens
@@ -153,11 +154,18 @@ class PrefillAdder:
                     for r in running_batch.reqs
                 ]
             )
+            self.rem_total_tokens_ -= sum(
+                [
+                    r.sampling_params.max_new_tokens - len(r.output_ids)
+                    for r in running_batch.reqs
+                ]
+            )

     def _prefill_one_req(
         self, prefix_len: int, extend_input_len: int, max_new_tokens: int
     ):
         self.rem_total_tokens -= extend_input_len + max_new_tokens
+        self.rem_total_tokens_ -= extend_input_len + max_new_tokens
         self.rem_input_tokens -= extend_input_len
         if self.rem_chunk_tokens is not None:
             self.rem_chunk_tokens -= extend_input_len
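Reading the two hunks above, rem_total_tokens remains the budget consulted for admission decisions, while the new rem_total_tokens_ appears to act as a stricter shadow budget that additionally reserves every token the running requests may still decode; both are then charged in _prefill_one_req when a new request is admitted. The following is only a minimal sketch of that two-counter bookkeeping, not the real class; RunningReq and TokenBudget are hypothetical stand-ins for the actual request and adder objects.

from dataclasses import dataclass, field
from typing import List


@dataclass
class RunningReq:
    """Hypothetical stand-in for a request that is already decoding."""

    max_new_tokens: int
    output_ids: List[int] = field(default_factory=list)


class TokenBudget:
    """Sketch of the two-counter bookkeeping suggested by the diff above."""

    def __init__(self, total_tokens: int, running: List[RunningReq]):
        # Primary budget consulted when admitting new prefill requests.
        self.rem_total_tokens = total_tokens
        # Shadow budget: additionally reserve the undecoded tail
        # (max_new_tokens - tokens already produced) of every running request.
        self.rem_total_tokens_ = total_tokens - sum(
            r.max_new_tokens - len(r.output_ids) for r in running
        )

    def prefill_one_req(self, extend_input_len: int, max_new_tokens: int) -> None:
        # Both counters are charged when a prefill request is admitted,
        # mirroring _prefill_one_req in the diff.
        cost = extend_input_len + max_new_tokens
        self.rem_total_tokens -= cost
        self.rem_total_tokens_ -= cost


if __name__ == "__main__":
    running = [RunningReq(max_new_tokens=128, output_ids=[0] * 40)]
    budget = TokenBudget(total_tokens=4096, running=running)
    budget.prefill_one_req(extend_input_len=512, max_new_tokens=256)
    print(budget.rem_total_tokens, budget.rem_total_tokens_)  # 3328 3240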
@@ -231,6 +239,15 @@ class PrefillAdder:
             return None

-        if self.req_states is None:
-            self.req_states = []
-            if self.running_batch is not None:
+        # Quick Check
+        can_run = False
+        if (
+            req.extend_input_len + req.sampling_params.max_new_tokens
+            <= self.rem_total_tokens
+        ):
+            can_run = True
+
+        if not can_run:
+            if self.req_states is None:
+                self.req_states = []
+                if self.running_batch is not None:
...
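The "Quick Check" added above reads as a cheap admission test: if the request's prefill length plus its full decode budget fits in the remaining total-token budget, it can be admitted without building the more expensive per-request state (req_states). A minimal, self-contained sketch of that fast-path idea follows; IncomingReq and quick_check are hypothetical names, not part of the repository.

from dataclasses import dataclass


@dataclass
class IncomingReq:
    """Hypothetical incoming prefill request."""

    extend_input_len: int
    max_new_tokens: int


def quick_check(req: IncomingReq, rem_total_tokens: int) -> bool:
    """Cheap test mirroring the 'Quick Check' condition in the diff.

    Returns True when the request's prefill tokens plus its full decode
    budget fit in the remaining total-token budget, so the caller can skip
    the detailed per-request bookkeeping on the common case.
    """
    return req.extend_input_len + req.max_new_tokens <= rem_total_tokens


if __name__ == "__main__":
    req = IncomingReq(extend_input_len=512, max_new_tokens=128)
    print(quick_check(req, rem_total_tokens=2048))  # True -> fast path
    print(quick_check(req, rem_total_tokens=300))   # False -> fall back to the slow path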
@@ -231,6 +231,7 @@ class ModelTpServer:
                 recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
             ):
                 self.handle_generate_request(recv_req)
+                self.do_not_get_new_batch = False
             elif isinstance(recv_req, FlushCacheReq):
                 self.flush_cache()
             elif isinstance(recv_req, AbortReq):
@@ -254,12 +255,10 @@ class ModelTpServer:
     @torch.inference_mode()
     def forward_step(self):
-        if self.current_inflight_req is not None:
-            self.do_not_get_new_batch = False
-
-        new_batch = (
-            self.get_new_prefill_batch() if not self.do_not_get_new_batch else None
-        )
+        if self.do_not_get_new_batch and self.current_inflight_req is None:
+            new_batch = None
+        else:
+            new_batch = self.get_new_prefill_batch()
         self.do_not_get_new_batch = False

         if new_batch is not None:
...
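Combined with the handle_generate_request hunk above, do_not_get_new_batch behaves like a one-shot hint to skip prefill batch formation: an arriving request clears it immediately, so the next forward_step schedules it without waiting an extra decode iteration, and the hint is consumed after a single step either way. The sketch below only illustrates that control flow under those assumptions; Scheduler and its trivial get_new_prefill_batch are hypothetical stand-ins for ModelTpServer and its real batch builder.

from typing import List, Optional


class Scheduler:
    """Sketch of the one-shot skip flag used by forward_step in the diff."""

    def __init__(self) -> None:
        self.do_not_get_new_batch = False
        self.current_inflight_req: Optional[object] = None

    def handle_generate_request(self, req: object) -> None:
        # A newly arrived request clears the skip hint so it is considered
        # for scheduling on the very next step.
        self.do_not_get_new_batch = False

    def get_new_prefill_batch(self) -> Optional[List[str]]:
        # Placeholder for the real (comparatively expensive) batch builder.
        return ["<prefill batch>"]

    def forward_step(self) -> Optional[List[str]]:
        # Skip batch formation only when the hint is set AND no chunked
        # prefill request is still in flight, matching the new condition.
        if self.do_not_get_new_batch and self.current_inflight_req is None:
            new_batch = None
        else:
            new_batch = self.get_new_prefill_batch()
        self.do_not_get_new_batch = False  # the hint is consumed either way
        return new_batch


if __name__ == "__main__":
    s = Scheduler()
    s.do_not_get_new_batch = True
    print(s.forward_step())  # None: prefill scheduling skipped once
    s.handle_generate_request(object())
    print(s.forward_step())  # ['<prefill batch>']: new request handled promptly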