Unverified Commit 25549433 authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Fix prefill OOM due to wrong token calculation when page > 1 (#7397)

parent d6dddc19
...@@ -55,6 +55,9 @@ IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD = int( ...@@ -55,6 +55,9 @@ IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD = int(
) )
IGNORE_EOS_RESERVE_TOKENS = 1
class CacheAwarePolicy(Enum): class CacheAwarePolicy(Enum):
"""Scheduling policies that are aware of the tree cache.""" """Scheduling policies that are aware of the tree cache."""
...@@ -293,6 +296,7 @@ class PrefillAdder: ...@@ -293,6 +296,7 @@ class PrefillAdder:
self.can_run_list = [] self.can_run_list = []
self.new_chunked_req = None self.new_chunked_req = None
self.log_hit_tokens = 0 self.log_hit_tokens = 0
# TODO(lsyin): report the real input tokens excluding page alignment
self.log_input_tokens = 0 self.log_input_tokens = 0
if running_batch is not None: if running_batch is not None:
...@@ -323,6 +327,9 @@ class PrefillAdder: ...@@ -323,6 +327,9 @@ class PrefillAdder:
- self.cur_rem_token_offset - self.cur_rem_token_offset
) )
def ceil_paged_tokens(self, tokens: int) -> int:
return -(-tokens // self.page_size) * self.page_size
def budget_state(self): def budget_state(self):
if self.rem_total_tokens <= 0 or self.cur_rem_tokens <= 0: if self.rem_total_tokens <= 0 or self.cur_rem_tokens <= 0:
return AddReqResult.NO_TOKEN return AddReqResult.NO_TOKEN
...@@ -334,9 +341,12 @@ class PrefillAdder: ...@@ -334,9 +341,12 @@ class PrefillAdder:
return AddReqResult.CONTINUE return AddReqResult.CONTINUE
def _prefill_one_req( def _update_prefill_budget(
self, prefix_len: int, extend_input_len: int, max_new_tokens: int self, prefix_len: int, extend_input_len: int, max_new_tokens: int
): ):
# TODO(lsyin): check this workaround logic, which only ensures the prefill will not out of memory, and may be too conservative
extend_input_len = self.ceil_paged_tokens(extend_input_len)
self.rem_total_token_offset += extend_input_len + max_new_tokens self.rem_total_token_offset += extend_input_len + max_new_tokens
self.cur_rem_token_offset += extend_input_len self.cur_rem_token_offset += extend_input_len
self.rem_input_tokens -= extend_input_len self.rem_input_tokens -= extend_input_len
...@@ -351,7 +361,7 @@ class PrefillAdder: ...@@ -351,7 +361,7 @@ class PrefillAdder:
req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens) req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len] req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
self.can_run_list.append(req) self.can_run_list.append(req)
self._prefill_one_req( self._update_prefill_budget(
0, 0,
req.extend_input_len, req.extend_input_len,
( (
...@@ -373,6 +383,12 @@ class PrefillAdder: ...@@ -373,6 +383,12 @@ class PrefillAdder:
self.tree_cache.dec_lock_ref(last_node) self.tree_cache.dec_lock_ref(last_node)
def add_one_req_ignore_eos(self, req: Req, has_chunked_req: bool): def add_one_req_ignore_eos(self, req: Req, has_chunked_req: bool):
# Early exit if no enough tokens for the input tokens
if self.ceil_paged_tokens(req.extend_input_len) > min(
self.cur_rem_tokens, self.rem_total_tokens
):
return AddReqResult.NO_TOKEN
def add_req_state(r, insert_sort=False): def add_req_state(r, insert_sort=False):
new_token_ratio = ( new_token_ratio = (
1.0 if r.sampling_params.ignore_eos else self.new_token_ratio 1.0 if r.sampling_params.ignore_eos else self.new_token_ratio
...@@ -382,15 +398,17 @@ class PrefillAdder: ...@@ -382,15 +398,17 @@ class PrefillAdder:
) )
tokens_occupied = len(r.origin_input_ids) + len(r.output_ids) tokens_occupied = len(r.origin_input_ids) + len(r.output_ids)
if tokens_left > 0: if tokens_left <= 0:
if not insert_sort: return
self.req_states.append((tokens_left, tokens_occupied))
else: if not insert_sort:
i = 0 self.req_states.append((tokens_left, tokens_occupied))
for i in range(len(self.req_states)): else:
if tokens_left <= self.req_states[i][0]: i = 0
break for i in range(len(self.req_states)):
self.req_states.insert(i, (tokens_left, tokens_occupied)) if tokens_left <= self.req_states[i][0]:
break
self.req_states.insert(i, (tokens_left, tokens_occupied))
if self.req_states is None: if self.req_states is None:
self.req_states = [] self.req_states = []
...@@ -407,13 +425,11 @@ class PrefillAdder: ...@@ -407,13 +425,11 @@ class PrefillAdder:
cur_rem_tokens = self.cur_rem_tokens - len(req.origin_input_ids) cur_rem_tokens = self.cur_rem_tokens - len(req.origin_input_ids)
tokens_freed = 0 tokens_freed = 0
for i, (tokens_left, tokens_occupied) in enumerate(self.req_states): for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
decode_steps = ( # tokens_left gives a reservative calculation as the last token is not stored
self.req_states[i + 1][0]
if i + 1 < len(self.req_states)
else tokens_left
)
bs = len(self.req_states) - i bs = len(self.req_states) - i
if cur_rem_tokens + tokens_freed - decode_steps * bs <= 0: min_free_tokens = cur_rem_tokens + tokens_freed - tokens_left * bs
# reserve tokens for corner cases
if min_free_tokens <= IGNORE_EOS_RESERVE_TOKENS * bs:
return AddReqResult.NO_TOKEN return AddReqResult.NO_TOKEN
tokens_freed += tokens_occupied tokens_freed += tokens_occupied
...@@ -423,7 +439,7 @@ class PrefillAdder: ...@@ -423,7 +439,7 @@ class PrefillAdder:
): ):
# Non-chunked prefill # Non-chunked prefill
self.can_run_list.append(req) self.can_run_list.append(req)
self._prefill_one_req( self._update_prefill_budget(
0, 0,
req.extend_input_len, req.extend_input_len,
min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION), min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION),
...@@ -439,7 +455,7 @@ class PrefillAdder: ...@@ -439,7 +455,7 @@ class PrefillAdder:
req.fill_ids = req.fill_ids[:trunc_len] req.fill_ids = req.fill_ids[:trunc_len]
self.can_run_list.append(req) self.can_run_list.append(req)
self.new_chunked_req = req self.new_chunked_req = req
self._prefill_one_req(0, trunc_len, 0) self._update_prefill_budget(0, trunc_len, 0)
return self.budget_state() return self.budget_state()
...@@ -453,7 +469,7 @@ class PrefillAdder: ...@@ -453,7 +469,7 @@ class PrefillAdder:
# adjusting the input_tokens based on host_hit_length and page_size # adjusting the input_tokens based on host_hit_length and page_size
real_input_tokens = req.extend_input_len - req.host_hit_length real_input_tokens = req.extend_input_len - req.host_hit_length
real_input_tokens = -(-real_input_tokens // self.page_size) * self.page_size real_input_tokens = self.ceil_paged_tokens(real_input_tokens)
prefix_len = len(req.prefix_indices) prefix_len = len(req.prefix_indices)
if total_tokens >= self.rem_total_tokens: if total_tokens >= self.rem_total_tokens:
...@@ -475,7 +491,7 @@ class PrefillAdder: ...@@ -475,7 +491,7 @@ class PrefillAdder:
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices) req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
prefix_len = len(req.prefix_indices) prefix_len = len(req.prefix_indices)
input_tokens = -(-req.extend_input_len // self.page_size) * self.page_size input_tokens = self.ceil_paged_tokens(req.extend_input_len)
if input_tokens >= self.rem_input_tokens and len(self.can_run_list) != 0: if input_tokens >= self.rem_input_tokens and len(self.can_run_list) != 0:
return AddReqResult.OTHER return AddReqResult.OTHER
...@@ -484,7 +500,7 @@ class PrefillAdder: ...@@ -484,7 +500,7 @@ class PrefillAdder:
# Non-chunked prefill # Non-chunked prefill
self.can_run_list.append(req) self.can_run_list.append(req)
self.tree_cache.inc_lock_ref(req.last_node) self.tree_cache.inc_lock_ref(req.last_node)
self._prefill_one_req( self._update_prefill_budget(
prefix_len, prefix_len,
input_tokens, input_tokens,
min( min(
...@@ -505,6 +521,6 @@ class PrefillAdder: ...@@ -505,6 +521,6 @@ class PrefillAdder:
self.can_run_list.append(req) self.can_run_list.append(req)
self.new_chunked_req = req self.new_chunked_req = req
self.tree_cache.inc_lock_ref(req.last_node) self.tree_cache.inc_lock_ref(req.last_node)
self._prefill_one_req(prefix_len, trunc_len, 0) self._update_prefill_budget(prefix_len, trunc_len, 0)
return self.budget_state() return self.budget_state()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment