Unverified commit 5f12f0e7, authored by Liangsheng Yin, committed by GitHub
Browse files

Fix chunked prefill when ignore eos (#2290)

parent d5b95cbb
...@@ -142,7 +142,7 @@ class PrefillAdder: ...@@ -142,7 +142,7 @@ class PrefillAdder:
self.req_states = None self.req_states = None
self.can_run_list = [] self.can_run_list = []
self.new_inflight_req = None self.new_being_chunked_req = None
self.log_hit_tokens = 0 self.log_hit_tokens = 0
self.log_input_tokens = 0 self.log_input_tokens = 0
...@@ -182,7 +182,7 @@ class PrefillAdder: ...@@ -182,7 +182,7 @@ class PrefillAdder:
self.log_hit_tokens += prefix_len self.log_hit_tokens += prefix_len
self.log_input_tokens += extend_input_len self.log_input_tokens += extend_input_len
def add_inflight_req(self, req: Req): def add_being_chunked_req(self, req: Req):
truncated = req.extend_input_len > self.rem_chunk_tokens truncated = req.extend_input_len > self.rem_chunk_tokens
req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens) req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len] req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
...@@ -269,10 +269,13 @@ class PrefillAdder: ...@@ -269,10 +269,13 @@ class PrefillAdder:
else: else:
# Chunked prefill # Chunked prefill
trunc_len = self.rem_chunk_tokens trunc_len = self.rem_chunk_tokens
if trunc_len == 0:
return AddReqResult.OTHER
req.extend_input_len = trunc_len req.extend_input_len = trunc_len
req.fill_ids = req.fill_ids[:trunc_len] req.fill_ids = req.fill_ids[:trunc_len]
self.can_run_list.append(req) self.can_run_list.append(req)
self.new_inflight_req = req self.new_being_chunked_req = req
self._prefill_one_req(0, trunc_len, 0) self._prefill_one_req(0, trunc_len, 0)
return self.budget_state() return self.budget_state()
...@@ -326,7 +329,7 @@ class PrefillAdder: ...@@ -326,7 +329,7 @@ class PrefillAdder:
req.extend_input_len = trunc_len req.extend_input_len = trunc_len
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len] req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
self.can_run_list.append(req) self.can_run_list.append(req)
self.new_inflight_req = req self.new_being_chunked_req = req
self.tree_cache.inc_lock_ref(req.last_node) self.tree_cache.inc_lock_ref(req.last_node)
self._prefill_one_req(prefix_len, trunc_len, 0) self._prefill_one_req(prefix_len, trunc_len, 0)
......
...@@ -660,7 +660,7 @@ class Scheduler: ...@@ -660,7 +660,7 @@ class Scheduler:
self.waiting_queue.append(req) self.waiting_queue.append(req)
def log_prefill_stats(self, adder, can_run_list, running_bs, has_inflight): def log_prefill_stats(self, adder, can_run_list, running_bs, has_being_chunked):
if isinstance(self.tree_cache, RadixCache): if isinstance(self.tree_cache, RadixCache):
self.tree_cache_metrics["total"] += ( self.tree_cache_metrics["total"] += (
adder.log_input_tokens + adder.log_hit_tokens adder.log_input_tokens + adder.log_hit_tokens
...@@ -684,14 +684,14 @@ class Scheduler: ...@@ -684,14 +684,14 @@ class Scheduler:
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, " f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, " f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"#running-req: {running_bs}, " f"#running-req: {running_bs}, "
f"#queue-req: {len(self.waiting_queue) + has_inflight}" f"#queue-req: {len(self.waiting_queue) + has_being_chunked}"
) )
if self.enable_metrics: if self.enable_metrics:
self.stats.num_running_reqs = running_bs self.stats.num_running_reqs = running_bs
self.stats.num_used_tokens = num_used self.stats.num_used_tokens = num_used
self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2) self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
self.stats.num_queue_reqs = len(self.waiting_queue) + has_inflight self.stats.num_queue_reqs = len(self.waiting_queue) + has_being_chunked
self.stats.cache_hit_rate = tree_cache_hit_rate self.stats.cache_hit_rate = tree_cache_hit_rate
self.metrics_collector.log_stats(self.stats) self.metrics_collector.log_stats(self.stats)
...@@ -752,7 +752,7 @@ class Scheduler: ...@@ -752,7 +752,7 @@ class Scheduler:
# Move the chunked request out of the batch # Move the chunked request out of the batch
self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req) self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req)
self.tree_cache.cache_unfinished_req(self.being_chunked_req) self.tree_cache.cache_unfinished_req(self.being_chunked_req)
# Inflight request keeps its rid but will get a new req_pool_idx # being chunked request keeps its rid but will get a new req_pool_idx
self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx) self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
self.batch_is_full = False self.batch_is_full = False
...@@ -803,10 +803,10 @@ class Scheduler: ...@@ -803,10 +803,10 @@ class Scheduler:
running_bs if self.is_mixed_chunk else 0, running_bs if self.is_mixed_chunk else 0,
) )
has_inflight = self.being_chunked_req is not None has_being_chunked = self.being_chunked_req is not None
if has_inflight: if has_being_chunked:
self.being_chunked_req.init_next_round_input() self.being_chunked_req.init_next_round_input()
self.being_chunked_req = adder.add_inflight_req(self.being_chunked_req) self.being_chunked_req = adder.add_being_chunked_req(self.being_chunked_req)
if self.lora_paths: if self.lora_paths:
lora_set = ( lora_set = (
...@@ -848,16 +848,16 @@ class Scheduler: ...@@ -848,16 +848,16 @@ class Scheduler:
x for x in self.waiting_queue if x not in set(can_run_list) x for x in self.waiting_queue if x not in set(can_run_list)
] ]
if adder.new_inflight_req is not None: if adder.new_being_chunked_req is not None:
assert self.being_chunked_req is None assert self.being_chunked_req is None
self.being_chunked_req = adder.new_inflight_req self.being_chunked_req = adder.new_being_chunked_req
if self.being_chunked_req: if self.being_chunked_req:
self.being_chunked_req.is_being_chunked += 1 self.being_chunked_req.is_being_chunked += 1
# Print stats # Print stats
if self.tp_rank == 0: if self.tp_rank == 0:
self.log_prefill_stats(adder, can_run_list, running_bs, has_inflight) self.log_prefill_stats(adder, can_run_list, running_bs, has_being_chunked)
# Create a new batch # Create a new batch
new_batch = ScheduleBatch.init_new( new_batch = ScheduleBatch.init_new(
...@@ -1030,7 +1030,7 @@ class Scheduler: ...@@ -1030,7 +1030,7 @@ class Scheduler:
if req.grammar is not None: if req.grammar is not None:
req.grammar.accept_token(next_token_id) req.grammar.accept_token(next_token_id)
else: else:
# Inflight reqs' prefill is not finished # being chunked reqs' prefill is not finished
req.is_being_chunked -= 1 req.is_being_chunked -= 1
if batch.next_batch_sampling_info: if batch.next_batch_sampling_info:
...@@ -1058,7 +1058,7 @@ class Scheduler: ...@@ -1058,7 +1058,7 @@ class Scheduler:
else: else:
self.tree_cache.cache_unfinished_req(req) self.tree_cache.cache_unfinished_req(req)
else: else:
# Inflight reqs' prefill is not finished # being chunked reqs' prefill is not finished
req.is_being_chunked -= 1 req.is_being_chunked -= 1
self.stream_output(batch.reqs) self.stream_output(batch.reqs)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment