"docker/diffusers-pytorch-xformers-cuda/Dockerfile" did not exist on "bc9a8cef6f258aafcd43ef64ac161218a7eae43a"
Unverified Commit 05bea688 authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Fix some online scheduling delay (#1345)

parent ab4a83b2
......@@ -119,6 +119,7 @@ class PrefillAdder:
self.running_batch = running_batch
self.new_token_ratio = new_token_ratio
self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
self.rem_total_tokens_ = self.rem_total_tokens
self.total_tokens = rem_total_tokens
self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
self.rem_chunk_tokens = rem_chunk_tokens
......@@ -153,11 +154,18 @@ class PrefillAdder:
for r in running_batch.reqs
]
)
self.rem_total_tokens_ -= sum(
[
r.sampling_params.max_new_tokens - len(r.output_ids)
for r in running_batch.reqs
]
)
def _prefill_one_req(
self, prefix_len: int, extend_input_len: int, max_new_tokens: int
):
self.rem_total_tokens -= extend_input_len + max_new_tokens
self.rem_total_tokens_ -= extend_input_len + max_new_tokens
self.rem_input_tokens -= extend_input_len
if self.rem_chunk_tokens is not None:
self.rem_chunk_tokens -= extend_input_len
......@@ -231,6 +239,15 @@ class PrefillAdder:
return None
# Quick Check
can_run = False
if (
req.extend_input_len + req.sampling_params.max_new_tokens
<= self.rem_total_tokens
):
can_run = True
if not can_run:
if self.req_states is None:
self.req_states = []
if self.running_batch is not None:
......
......@@ -231,6 +231,7 @@ class ModelTpServer:
recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
):
self.handle_generate_request(recv_req)
self.do_not_get_new_batch = False
elif isinstance(recv_req, FlushCacheReq):
self.flush_cache()
elif isinstance(recv_req, AbortReq):
......@@ -254,12 +255,10 @@ class ModelTpServer:
@torch.inference_mode()
def forward_step(self):
if self.current_inflight_req is not None:
self.do_not_get_new_batch = False
new_batch = (
self.get_new_prefill_batch() if not self.do_not_get_new_batch else None
)
if self.do_not_get_new_batch and self.current_inflight_req is None:
new_batch = None
else:
new_batch = self.get_new_prefill_batch()
self.do_not_get_new_batch = False
if new_batch is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment