"..._static/git@developer.sourcefind.cn:OpenDAS/fairscale.git" did not exist on "ebbd5f643d3006c601183e6f5a111611663754c5"
Unverified commit 05bea688 authored by Liangsheng Yin, committed by GitHub

Fix some online scheduling delay (#1345)

parent ab4a83b2
@@ -119,6 +119,7 @@ class PrefillAdder:
         self.running_batch = running_batch
         self.new_token_ratio = new_token_ratio
         self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
+        self.rem_total_tokens_ = self.rem_total_tokens
         self.total_tokens = rem_total_tokens
         self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
         self.rem_chunk_tokens = rem_chunk_tokens
@@ -153,11 +154,18 @@ class PrefillAdder:
                     for r in running_batch.reqs
                 ]
             )
+            self.rem_total_tokens_ -= sum(
+                [
+                    r.sampling_params.max_new_tokens - len(r.output_ids)
+                    for r in running_batch.reqs
+                ]
+            )

     def _prefill_one_req(
         self, prefix_len: int, extend_input_len: int, max_new_tokens: int
     ):
         self.rem_total_tokens -= extend_input_len + max_new_tokens
+        self.rem_total_tokens_ -= extend_input_len + max_new_tokens
         self.rem_input_tokens -= extend_input_len
         if self.rem_chunk_tokens is not None:
             self.rem_chunk_tokens -= extend_input_len
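For orientation, the bookkeeping added by the two hunks above can be read as a second, worst-case token budget kept alongside rem_total_tokens: running requests reserve their full unfinished output allowance, and each newly admitted prefill reserves its prompt extension plus its full output allowance. A minimal standalone sketch of that idea follows; the class and field names here are illustrative, not the actual sglang PrefillAdder.

from dataclasses import dataclass
from typing import List


@dataclass
class RunningReq:
    # Hypothetical stand-in for a running request: how many new tokens it may
    # still produce and how many it has produced so far.
    max_new_tokens: int
    output_len: int


@dataclass
class WorstCaseBudget:
    """Tracks a worst-case remaining token budget, with no discounting."""

    rem_total_tokens: int

    def reserve_running(self, running: List[RunningReq]) -> None:
        # Mirrors `rem_total_tokens_ -= sum(max_new_tokens - len(output_ids))`:
        # every running request may still claim its full unfinished allowance.
        self.rem_total_tokens -= sum(r.max_new_tokens - r.output_len for r in running)

    def reserve_prefill(self, extend_input_len: int, max_new_tokens: int) -> None:
        # Mirrors the `_prefill_one_req` update: an admitted request reserves
        # its prompt extension plus its full output allowance.
        self.rem_total_tokens -= extend_input_len + max_new_tokens


budget = WorstCaseBudget(rem_total_tokens=16384)
budget.reserve_running([RunningReq(max_new_tokens=256, output_len=100)])
budget.reserve_prefill(extend_input_len=512, max_new_tokens=256)
print(budget.rem_total_tokens)  # 16384 - 156 - 768 = 15460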
@@ -231,43 +239,52 @@ class PrefillAdder:
             return None

-        if self.req_states is None:
-            self.req_states = []
-            if self.running_batch is not None:
-                for r in self.running_batch.reqs:
-                    state = get_req_state(r)
-                    if state is not None:
-                        self.req_states.append(state)
-            for r in self.can_run_list:
-                state = get_req_state(r)
-                if state is not None:
-                    self.req_states.append(state)
-            state = get_req_state(req)
-            if state is not None:
-                self.req_states.append(state)
-            self.req_states.sort(key=lambda x: x[0])
-        else:
-            state = get_req_state(req)
-            if state is not None:
-                for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
-                    if tokens_left >= state[0]:
-                        self.req_states.insert(i, state)
-                        break
-                else:
-                    self.req_states.append(state)
-
-        tokens_freed = 0
-        for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
-            decode_steps = (
-                self.req_states[i + 1][0]
-                if i + 1 < len(self.req_states)
-                else tokens_left
-            )
-            bs = len(self.req_states) - i
-            if self.total_tokens + tokens_freed - decode_steps * bs <= 0:
-                return False
-            tokens_freed += tokens_occupied
+        # Quick Check
+        can_run = False
+        if (
+            req.extend_input_len + req.sampling_params.max_new_tokens
+            <= self.rem_total_tokens
+        ):
+            can_run = True
+
+        if not can_run:
+            if self.req_states is None:
+                self.req_states = []
+                if self.running_batch is not None:
+                    for r in self.running_batch.reqs:
+                        state = get_req_state(r)
+                        if state is not None:
+                            self.req_states.append(state)
+                for r in self.can_run_list:
+                    state = get_req_state(r)
+                    if state is not None:
+                        self.req_states.append(state)
+                state = get_req_state(req)
+                if state is not None:
+                    self.req_states.append(state)
+                self.req_states.sort(key=lambda x: x[0])
+            else:
+                state = get_req_state(req)
+                if state is not None:
+                    for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
+                        if tokens_left >= state[0]:
+                            self.req_states.insert(i, state)
+                            break
+                    else:
+                        self.req_states.append(state)
+
+            tokens_freed = 0
+            for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
+                decode_steps = (
+                    self.req_states[i + 1][0]
+                    if i + 1 < len(self.req_states)
+                    else tokens_left
+                )
+                bs = len(self.req_states) - i
+                if self.total_tokens + tokens_freed - decode_steps * bs <= 0:
+                    return False
+                tokens_freed += tokens_occupied

         if req.extend_input_len <= self.rem_chunk_tokens:
             self.can_run_list.append(req)
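The hunk above is the core of the latency fix: before walking the per-request req_states simulation, the adder first does a cheap feasibility test against the remaining token budget and only falls back to the detailed check when that fails. Below is a simplified, standalone rendering of that fast-path/slow-path split; it is a sketch of the structure shown in the diff, not the actual sglang add_one_req, and the example values are made up.

from typing import List, Tuple


def can_admit(
    extend_input_len: int,
    max_new_tokens: int,
    rem_total_tokens: int,
    total_tokens: int,
    req_states: List[Tuple[int, int]],  # (tokens_left, tokens_occupied), sorted ascending
) -> bool:
    # Quick check: if the request's worst-case demand already fits in the
    # remaining budget, admit it without touching per-request state at all.
    if extend_input_len + max_new_tokens <= rem_total_tokens:
        return True

    # Slow path (simplified from the diff): walk running requests in order of
    # fewest tokens left and verify that, as each one finishes and frees its
    # tokens, the pool still covers the decode steps of the rest.
    tokens_freed = 0
    for i, (tokens_left, tokens_occupied) in enumerate(req_states):
        decode_steps = req_states[i + 1][0] if i + 1 < len(req_states) else tokens_left
        bs = len(req_states) - i
        if total_tokens + tokens_freed - decode_steps * bs <= 0:
            return False
        tokens_freed += tokens_occupied
    return True


# A request that clearly fits takes the fast path and skips the loop entirely.
print(can_admit(128, 256, rem_total_tokens=4096, total_tokens=8192, req_states=[(100, 900)]))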
@@ -231,6 +231,7 @@ class ModelTpServer:
                 recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
             ):
                 self.handle_generate_request(recv_req)
+                self.do_not_get_new_batch = False
             elif isinstance(recv_req, FlushCacheReq):
                 self.flush_cache()
             elif isinstance(recv_req, AbortReq):
@@ -254,12 +255,10 @@ class ModelTpServer:
     @torch.inference_mode()
     def forward_step(self):
-        if self.current_inflight_req is not None:
-            self.do_not_get_new_batch = False
-
-        new_batch = (
-            self.get_new_prefill_batch() if not self.do_not_get_new_batch else None
-        )
+        if self.do_not_get_new_batch and self.current_inflight_req is None:
+            new_batch = None
+        else:
+            new_batch = self.get_new_prefill_batch()
         self.do_not_get_new_batch = False

         if new_batch is not None:
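The two ModelTpServer hunks work as a pair: handle_generate_request clears do_not_get_new_batch so the very next forward_step attempts to build a prefill batch, and forward_step skips that attempt only when the flag is set and no chunked request is still in flight. The toy event loop below illustrates the same flag protocol; the surrounding server, the batch contents, and the policy that sets the flag elsewhere are assumptions for illustration, not part of this diff.

class ToyScheduler:
    def __init__(self):
        self.do_not_get_new_batch = False
        self.current_inflight_req = None  # a partially prefilled (chunked) request, if any
        self.waiting = []

    def handle_generate_request(self, req):
        self.waiting.append(req)
        # New work arrived: make sure the next step attempts a prefill batch.
        self.do_not_get_new_batch = False

    def get_new_prefill_batch(self):
        batch, self.waiting = self.waiting, []
        return batch or None

    def forward_step(self):
        # Mirrors the rewritten forward_step: skip the (possibly expensive)
        # prefill-batch construction only when nothing changed since the last
        # step and no chunked request is still in flight.
        if self.do_not_get_new_batch and self.current_inflight_req is None:
            new_batch = None
        else:
            new_batch = self.get_new_prefill_batch()
        self.do_not_get_new_batch = False
        return new_batch


s = ToyScheduler()
s.do_not_get_new_batch = True          # e.g. set after an idle step (assumption)
print(s.forward_step())                # None: nothing new, prefill is skipped
s.handle_generate_request("req-1")     # an arrival clears the flag
print(s.forward_step())                # ['req-1']: a prefill batch is built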