Unverified Commit 36078fb2 authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

fix schedule bug (#1450)

parent b3710d2c
...@@ -119,31 +119,21 @@ class PrefillAdder: ...@@ -119,31 +119,21 @@ class PrefillAdder:
self.running_batch = running_batch self.running_batch = running_batch
self.new_token_ratio = new_token_ratio self.new_token_ratio = new_token_ratio
self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
self.rem_total_tokens_ = self.rem_total_tokens
self.total_tokens = rem_total_tokens
self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
self.rem_chunk_tokens = rem_chunk_tokens self.rem_chunk_tokens = rem_chunk_tokens
if self.rem_chunk_tokens is not None: if self.rem_chunk_tokens is not None:
self.rem_chunk_tokens -= mixed_with_decode_tokens self.rem_chunk_tokens -= mixed_with_decode_tokens
self.cur_rem_tokens = rem_total_tokens - mixed_with_decode_tokens
self.req_states = None self.req_states = None
self.can_run_list = [] self.can_run_list = []
self.new_inflight_req = None self.new_inflight_req = None
self.log_hit_tokens = 0 self.log_hit_tokens = 0
self.log_input_tokens = 0 self.log_input_tokens = 0
def no_remaining_tokens(self): if running_batch is not None:
return ( # Pre-remove the tokens which will be occupied by the running requests
self.rem_total_tokens <= 0
or self.rem_input_tokens <= 0
or (
self.rem_chunk_tokens <= 0
if self.rem_chunk_tokens is not None
else False
)
)
def remove_running_tokens(self, running_batch: ScheduleBatch):
self.rem_total_tokens -= sum( self.rem_total_tokens -= sum(
[ [
min( min(
...@@ -154,18 +144,24 @@ class PrefillAdder: ...@@ -154,18 +144,24 @@ class PrefillAdder:
for r in running_batch.reqs for r in running_batch.reqs
] ]
) )
self.rem_total_tokens_ -= sum(
[ def no_remaining_tokens(self):
r.sampling_params.max_new_tokens - len(r.output_ids) return (
for r in running_batch.reqs self.rem_total_tokens <= 0
] or self.rem_input_tokens <= 0
or (
self.rem_chunk_tokens <= 0
if self.rem_chunk_tokens is not None
else False
)
or self.cur_rem_tokens <= 0
) )
def _prefill_one_req( def _prefill_one_req(
self, prefix_len: int, extend_input_len: int, max_new_tokens: int self, prefix_len: int, extend_input_len: int, max_new_tokens: int
): ):
self.rem_total_tokens -= extend_input_len + max_new_tokens self.rem_total_tokens -= extend_input_len + max_new_tokens
self.rem_total_tokens_ -= extend_input_len + max_new_tokens self.cur_rem_tokens -= extend_input_len
self.rem_input_tokens -= extend_input_len self.rem_input_tokens -= extend_input_len
if self.rem_chunk_tokens is not None: if self.rem_chunk_tokens is not None:
self.rem_chunk_tokens -= extend_input_len self.rem_chunk_tokens -= extend_input_len
...@@ -173,29 +169,7 @@ class PrefillAdder: ...@@ -173,29 +169,7 @@ class PrefillAdder:
self.log_hit_tokens += prefix_len self.log_hit_tokens += prefix_len
self.log_input_tokens += extend_input_len self.log_input_tokens += extend_input_len
def add_inflight_req_ignore_eos(self, req: Req):
truncated = req.extend_input_len > self.rem_chunk_tokens
req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
self.can_run_list.append(req)
self._prefill_one_req(
0,
req.extend_input_len,
(
min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
if not truncated
else 0
),
)
# Return if chunked prefill not finished
return req if truncated else None
def add_inflight_req(self, req: Req): def add_inflight_req(self, req: Req):
if req.sampling_params.ignore_eos:
return self.add_inflight_req_ignore_eos(req)
truncated = req.extend_input_len > self.rem_chunk_tokens truncated = req.extend_input_len > self.rem_chunk_tokens
req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens) req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len] req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
...@@ -225,7 +199,7 @@ class PrefillAdder: ...@@ -225,7 +199,7 @@ class PrefillAdder:
self.rem_total_tokens += delta self.rem_total_tokens += delta
def add_one_req_ignore_eos(self, req: Req): def add_one_req_ignore_eos(self, req: Req):
def get_req_state(r): def add_req_state(r, insert_sort=False):
new_token_ratio = ( new_token_ratio = (
1.0 if r.sampling_params.ignore_eos else self.new_token_ratio 1.0 if r.sampling_params.ignore_eos else self.new_token_ratio
) )
...@@ -235,44 +209,25 @@ class PrefillAdder: ...@@ -235,44 +209,25 @@ class PrefillAdder:
tokens_occupied = len(r.origin_input_ids) + len(r.output_ids) tokens_occupied = len(r.origin_input_ids) + len(r.output_ids)
if tokens_left > 0: if tokens_left > 0:
return (tokens_left, tokens_occupied) if not insert_sort:
self.req_states.append((tokens_left, tokens_occupied))
return None else:
for i in range(len(self.req_states)):
# Quick Check if tokens_left <= self.req_states[i][0]:
can_run = False break
if ( self.req_states.insert(i, (tokens_left, tokens_occupied))
req.extend_input_len + req.sampling_params.max_new_tokens
<= self.rem_total_tokens
):
can_run = True
if not can_run:
if self.req_states is None: if self.req_states is None:
self.req_states = [] self.req_states = []
add_req_state(req)
if self.running_batch is not None: if self.running_batch is not None:
for r in self.running_batch.reqs: for r in self.running_batch.reqs:
state = get_req_state(r) add_req_state(r)
if state is not None:
self.req_states.append(state)
for r in self.can_run_list: for r in self.can_run_list:
state = get_req_state(r) add_req_state(r)
if state is not None:
self.req_states.append(state)
state = get_req_state(req)
if state is not None:
self.req_states.append(state)
self.req_states.sort(key=lambda x: x[0]) self.req_states.sort(key=lambda x: x[0])
else: else:
state = get_req_state(req) add_req_state(req, insert_sort=True)
if state is not None:
for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
if tokens_left >= state[0]:
self.req_states.insert(i, state)
break
else:
self.req_states.append(state)
tokens_freed = 0 tokens_freed = 0
for i, (tokens_left, tokens_occupied) in enumerate(self.req_states): for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
...@@ -282,7 +237,7 @@ class PrefillAdder: ...@@ -282,7 +237,7 @@ class PrefillAdder:
else tokens_left else tokens_left
) )
bs = len(self.req_states) - i bs = len(self.req_states) - i
if self.total_tokens + tokens_freed - decode_steps * bs <= 0: if self.cur_rem_tokens + tokens_freed - decode_steps * bs <= 0:
return False return False
tokens_freed += tokens_occupied tokens_freed += tokens_occupied
......
...@@ -445,9 +445,6 @@ class ModelTpServer: ...@@ -445,9 +445,6 @@ class ModelTpServer:
num_mixed_running, num_mixed_running,
) )
if self.running_batch is not None:
adder.remove_running_tokens(self.running_batch)
has_inflight = self.current_inflight_req is not None has_inflight = self.current_inflight_req is not None
if self.current_inflight_req is not None: if self.current_inflight_req is not None:
self.current_inflight_req.init_next_round_input( self.current_inflight_req.init_next_round_input(
...@@ -465,9 +462,6 @@ class ModelTpServer: ...@@ -465,9 +462,6 @@ class ModelTpServer:
) )
for req in self.waiting_queue: for req in self.waiting_queue:
if adder.no_remaining_tokens():
break
req.init_next_round_input(None if prefix_computed else self.tree_cache)
if ( if (
self.lora_paths is not None self.lora_paths is not None
and len( and len(
...@@ -478,6 +472,10 @@ class ModelTpServer: ...@@ -478,6 +472,10 @@ class ModelTpServer:
> self.max_loras_per_batch > self.max_loras_per_batch
): ):
break break
if adder.no_remaining_tokens():
break
req.init_next_round_input(None if prefix_computed else self.tree_cache)
res = adder.add_one_req(req) res = adder.add_one_req(req)
if ( if (
not res not res
...@@ -507,6 +505,11 @@ class ModelTpServer: ...@@ -507,6 +505,11 @@ class ModelTpServer:
else: else:
tree_cache_hit_rate = 0.0 tree_cache_hit_rate = 0.0
num_used = self.max_total_num_tokens - (
self.token_to_kv_pool.available_size()
+ self.tree_cache.evictable_size()
)
if num_mixed_running > 0: if num_mixed_running > 0:
logger.info( logger.info(
f"Prefill batch" f"Prefill batch"
...@@ -515,6 +518,7 @@ class ModelTpServer: ...@@ -515,6 +518,7 @@ class ModelTpServer:
f"#new-token: {adder.log_input_tokens}, " f"#new-token: {adder.log_input_tokens}, "
f"#cached-token: {adder.log_hit_tokens}, " f"#cached-token: {adder.log_hit_tokens}, "
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, " f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}" f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
) )
else: else:
...@@ -524,6 +528,7 @@ class ModelTpServer: ...@@ -524,6 +528,7 @@ class ModelTpServer:
f"#new-token: {adder.log_input_tokens}, " f"#new-token: {adder.log_input_tokens}, "
f"#cached-token: {adder.log_hit_tokens}, " f"#cached-token: {adder.log_hit_tokens}, "
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, " f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"#running-req: {running_bs}, " f"#running-req: {running_bs}, "
f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}" f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment