Unverified Commit ab4a83b2 authored by Liangsheng Yin, committed by GitHub

Optimize schedule (#1339)

parent 62f15eea
@@ -108,18 +108,24 @@ class PrefillAdder:
     def __init__(
         self,
         tree_cache: BasePrefixCache,
+        running_batch: ScheduleBatch,
+        new_token_ratio: float,
         rem_total_tokens: int,
         rem_input_tokens: int,
         rem_chunk_tokens: Optional[int],
         mixed_with_decode_tokens: int = 0,
     ):
         self.tree_cache = tree_cache
+        self.running_batch = running_batch
+        self.new_token_ratio = new_token_ratio
         self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
+        self.total_tokens = rem_total_tokens
         self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
         self.rem_chunk_tokens = rem_chunk_tokens
         if self.rem_chunk_tokens is not None:
             self.rem_chunk_tokens -= mixed_with_decode_tokens

+        self.req_states = None
         self.can_run_list = []
         self.new_inflight_req = None
         self.log_hit_tokens = 0
@@ -136,16 +142,14 @@ class PrefillAdder:
             )
         )

-    def remove_running_tokens(
-        self, running_batch: ScheduleBatch, new_token_ratio: float
-    ):
+    def remove_running_tokens(self, running_batch: ScheduleBatch):
         self.rem_total_tokens -= sum(
             [
                 min(
                     (r.sampling_params.max_new_tokens - len(r.output_ids)),
                     CLIP_MAX_NEW_TOKENS,
                 )
-                * new_token_ratio
+                * self.new_token_ratio
                 for r in running_batch.reqs
             ]
         )
@@ -161,7 +165,29 @@ class PrefillAdder:
         self.log_hit_tokens += prefix_len
         self.log_input_tokens += extend_input_len

+    def add_inflight_req_ignore_eos(self, req: Req):
+        truncated = req.extend_input_len > self.rem_chunk_tokens
+        req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
+        req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
+        self.can_run_list.append(req)
+
+        self._prefill_one_req(
+            0,
+            req.extend_input_len,
+            (
+                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
+                if not truncated
+                else 0
+            ),
+        )
+
+        # Return if chunked prefill not finished
+        return req if truncated else None
+
     def add_inflight_req(self, req: Req):
+        if req.sampling_params.ignore_eos:
+            return self.add_inflight_req_ignore_eos(req)
+
         truncated = req.extend_input_len > self.rem_chunk_tokens
         req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
         req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
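As an illustration of the contract above (hypothetical ToyReq/ToyAdder stand-ins, not code from this commit): add_inflight_req_ignore_eos returns the request itself while its input is still truncated by rem_chunk_tokens, and None once the last chunk fits, so the caller keeps feeding the same request back as the current inflight request until its prefill completes.

    from dataclasses import dataclass
    from typing import Optional

    # Hypothetical stand-ins, only to show the "return the request while it is
    # still truncated, None when its prefill is finished" contract.
    @dataclass
    class ToyReq:
        extend_input_len: int  # input tokens still waiting to be prefilled

    class ToyAdder:
        def __init__(self, rem_chunk_tokens: int):
            self.rem_chunk_tokens = rem_chunk_tokens

        def add_inflight_req(self, req: ToyReq) -> Optional[ToyReq]:
            truncated = req.extend_input_len > self.rem_chunk_tokens
            chunk = min(req.extend_input_len, self.rem_chunk_tokens)
            req.extend_input_len -= chunk  # toy bookkeeping, not the real fields
            return req if truncated else None

    # A 10-token prompt with a 4-token chunk budget needs three scheduling rounds.
    inflight = ToyReq(extend_input_len=10)
    rounds = 0
    while inflight is not None:
        inflight = ToyAdder(rem_chunk_tokens=4).add_inflight_req(inflight)
        rounds += 1
    print(rounds)  # 3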
@@ -190,7 +216,81 @@ class PrefillAdder:
             delta = self.tree_cache.dec_lock_ref(last_node)
             self.rem_total_tokens += delta

+    def add_one_req_ignore_eos(self, req: Req):
+        def get_req_state(r):
+            new_token_ratio = (
+                1.0 if r.sampling_params.ignore_eos else self.new_token_ratio
+            )
+            tokens_left = r.sampling_params.max_new_tokens * new_token_ratio - len(
+                r.output_ids
+            )
+            tokens_occupied = len(r.origin_input_ids) + len(r.output_ids)
+
+            if tokens_left > 0:
+                return (tokens_left, tokens_occupied)
+
+            return None
+
+        if self.req_states is None:
+            self.req_states = []
+            if self.running_batch is not None:
+                for r in self.running_batch.reqs:
+                    state = get_req_state(r)
+                    if state is not None:
+                        self.req_states.append(state)
+            for r in self.can_run_list:
+                state = get_req_state(r)
+                if state is not None:
+                    self.req_states.append(state)
+            state = get_req_state(req)
+            if state is not None:
+                self.req_states.append(state)
+
+            self.req_states.sort(key=lambda x: x[0])
+        else:
+            state = get_req_state(req)
+            if state is not None:
+                for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
+                    if tokens_left >= state[0]:
+                        self.req_states.insert(i, state)
+                        break
+                else:
+                    self.req_states.append(state)
+
+        tokens_freed = 0
+        for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
+            decode_steps = (
+                self.req_states[i + 1][0]
+                if i + 1 < len(self.req_states)
+                else tokens_left
+            )
+            bs = len(self.req_states) - i
+            if self.total_tokens + tokens_freed - decode_steps * bs <= 0:
+                return False
+            tokens_freed += tokens_occupied
+
+        if req.extend_input_len <= self.rem_chunk_tokens:
+            self.can_run_list.append(req)
+            self._prefill_one_req(
+                0,
+                req.extend_input_len,
+                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
+            )
+        else:
+            # Chunked prefill
+            trunc_len = self.rem_chunk_tokens
+            req.extend_input_len = trunc_len
+            req.fill_ids = req.fill_ids[:trunc_len]
+            self.can_run_list.append(req)
+            self.new_inflight_req = req
+            self._prefill_one_req(0, trunc_len, 0)
+
+        return True
+
     def add_one_req(self, req: Req):
+        if req.sampling_params.ignore_eos and self.tree_cache.disable:
+            return self.add_one_req_ignore_eos(req)
+
         total_tokens = req.extend_input_len + min(
             req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS
         )
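The loop over req_states above is the heart of this hunk: requests with ignore_eos will decode their full max_new_tokens (other requests are discounted by new_token_ratio), so the adder can estimate future token demand directly. Requests are ordered by remaining decode tokens, and at each point the pool must still cover decode_steps tokens for every request that has not finished, crediting back what finished requests release. A standalone sketch of that feasibility test (hypothetical can_admit helper, same arithmetic as the diff):

    # req_states: (tokens_left, tokens_occupied) per request, sorted ascending
    # by tokens_left; total_tokens is the size of the KV token pool.
    def can_admit(req_states, total_tokens):
        tokens_freed = 0
        for i, (tokens_left, tokens_occupied) in enumerate(req_states):
            # Decode steps charged at this point: until the next request
            # (in ascending tokens_left order) would finish.
            decode_steps = (
                req_states[i + 1][0] if i + 1 < len(req_states) else tokens_left
            )
            bs = len(req_states) - i  # requests still decoding at this point
            if total_tokens + tokens_freed - decode_steps * bs <= 0:
                return False  # the pool would run out before this point
            # When this request finishes, its prompt + output tokens are freed.
            tokens_freed += tokens_occupied
        return True

    # Example with a 1000-token pool:
    print(can_admit([(10, 50), (20, 300), (400, 600)], 1000))  # True
    print(can_admit([(10, 50), (900, 300)], 1000))             # False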
@@ -233,4 +333,4 @@ class PrefillAdder:
             self.tree_cache.inc_lock_ref(req.last_node)
             self._prefill_one_req(prefix_len, trunc_len, 0)

-        return True
+        return True and not self.no_remaining_tokens()
@@ -221,6 +221,7 @@ class ModelTpServer:
         )
         self.new_token_ratio = self.min_new_token_ratio
         self.new_token_ratio_decay = global_config.new_token_ratio_decay
+        self.do_not_get_new_batch = False

     def exposed_step(self, recv_reqs: List):
         try:
@@ -253,7 +254,13 @@ class ModelTpServer:
     @torch.inference_mode()
     def forward_step(self):
-        new_batch = self.get_new_prefill_batch()
+        if self.current_inflight_req is not None:
+            self.do_not_get_new_batch = False
+
+        new_batch = (
+            self.get_new_prefill_batch() if not self.do_not_get_new_batch else None
+        )
+        self.do_not_get_new_batch = False

         if new_batch is not None:
             # Run a new prefill batch
@@ -409,6 +416,8 @@ class ModelTpServer:
         adder = PrefillAdder(
             self.tree_cache,
+            self.running_batch,
+            self.new_token_ratio,
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
             self.max_prefill_tokens,
             self.chunked_prefill_size,
@@ -416,7 +425,7 @@ class ModelTpServer:
         )

         if self.running_batch is not None:
-            adder.remove_running_tokens(self.running_batch, self.new_token_ratio)
+            adder.remove_running_tokens(self.running_batch)

         has_inflight = self.current_inflight_req is not None
         if self.current_inflight_req is not None:
@@ -428,11 +437,12 @@ class ModelTpServer:
         )

         for req in self.waiting_queue:
+            if adder.no_remaining_tokens():
+                break
             req.init_next_round_input(None if prefix_computed else self.tree_cache)
             res = adder.add_one_req(req)
             if (
                 not res
-                or adder.no_remaining_tokens()
                 or running_bs + len(adder.can_run_list) >= self.max_running_requests
             ):
                 break
@@ -700,6 +710,7 @@ class ModelTpServer:
         next_token_ids = next_token_ids.tolist()

         # Check finish condition
+        has_finished = False
         for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
             req.completion_tokens_wo_jump_forward += 1
             req.output_ids.append(next_token_id)
@@ -712,6 +723,7 @@ class ModelTpServer:
             if req.finished():
                 self.tree_cache.cache_finished_req(req)
+                has_finished = True

             if req.return_logprob:
                 req.output_token_logprobs.append(
@@ -720,6 +732,9 @@ class ModelTpServer:
             if req.top_logprobs_num > 0:
                 req.output_top_logprobs.append(logits_output.output_top_logprobs[i])

+        if not has_finished:
+            self.do_not_get_new_batch = True
+
         self.handle_finished_requests(batch)

     def handle_finished_requests(self, batch: ScheduleBatch):
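Taken together with the forward_step change earlier in this commit, these last hunks form a small optimization: if a decode step finishes no request, no KV memory is freed, so the scheduler skips one attempt to build a new prefill batch rather than redoing the admission pass that would almost certainly fail again. A toy model of the flag's lifecycle (hypothetical TinyScheduler, names follow the diff):

    # Hypothetical TinyScheduler: only the transitions added in this commit.
    class TinyScheduler:
        def __init__(self):
            self.do_not_get_new_batch = False
            self.current_inflight_req = None

        def forward_step(self, get_new_prefill_batch):
            if self.current_inflight_req is not None:
                # A chunked prefill must always be continued.
                self.do_not_get_new_batch = False
            new_batch = (
                get_new_prefill_batch() if not self.do_not_get_new_batch else None
            )
            self.do_not_get_new_batch = False  # the skip lasts one step only
            return new_batch

        def after_decode(self, has_finished: bool):
            if not has_finished:
                # Nothing finished, nothing freed: skip the next prefill attempt.
                self.do_not_get_new_batch = True

    s = TinyScheduler()
    s.after_decode(has_finished=False)
    print(s.forward_step(lambda: "prefill batch"))  # None -> attempt skipped
    print(s.forward_step(lambda: "prefill batch"))  # prefill batch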