"vscode:/vscode.git/clone" did not exist on "978ac410973b38114cb514730aa05fb579926440"
Unverified Commit 25549433 authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Fix prefill OOM due to wrong token calculation when page > 1 (#7397)

parent d6dddc19
......@@ -55,6 +55,9 @@ IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD = int(
)
IGNORE_EOS_RESERVE_TOKENS = 1
class CacheAwarePolicy(Enum):
"""Scheduling policies that are aware of the tree cache."""
......@@ -293,6 +296,7 @@ class PrefillAdder:
self.can_run_list = []
self.new_chunked_req = None
self.log_hit_tokens = 0
# TODO(lsyin): report the real input tokens excluding page alignment
self.log_input_tokens = 0
if running_batch is not None:
......@@ -323,6 +327,9 @@ class PrefillAdder:
- self.cur_rem_token_offset
)
def ceil_paged_tokens(self, tokens: int) -> int:
return -(-tokens // self.page_size) * self.page_size
def budget_state(self):
if self.rem_total_tokens <= 0 or self.cur_rem_tokens <= 0:
return AddReqResult.NO_TOKEN
......@@ -334,9 +341,12 @@ class PrefillAdder:
return AddReqResult.CONTINUE
def _prefill_one_req(
def _update_prefill_budget(
self, prefix_len: int, extend_input_len: int, max_new_tokens: int
):
# TODO(lsyin): check this workaround logic, which only ensures the prefill will not out of memory, and may be too conservative
extend_input_len = self.ceil_paged_tokens(extend_input_len)
self.rem_total_token_offset += extend_input_len + max_new_tokens
self.cur_rem_token_offset += extend_input_len
self.rem_input_tokens -= extend_input_len
......@@ -351,7 +361,7 @@ class PrefillAdder:
req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
self.can_run_list.append(req)
self._prefill_one_req(
self._update_prefill_budget(
0,
req.extend_input_len,
(
......@@ -373,6 +383,12 @@ class PrefillAdder:
self.tree_cache.dec_lock_ref(last_node)
def add_one_req_ignore_eos(self, req: Req, has_chunked_req: bool):
# Early exit if no enough tokens for the input tokens
if self.ceil_paged_tokens(req.extend_input_len) > min(
self.cur_rem_tokens, self.rem_total_tokens
):
return AddReqResult.NO_TOKEN
def add_req_state(r, insert_sort=False):
new_token_ratio = (
1.0 if r.sampling_params.ignore_eos else self.new_token_ratio
......@@ -382,15 +398,17 @@ class PrefillAdder:
)
tokens_occupied = len(r.origin_input_ids) + len(r.output_ids)
if tokens_left > 0:
if not insert_sort:
self.req_states.append((tokens_left, tokens_occupied))
else:
i = 0
for i in range(len(self.req_states)):
if tokens_left <= self.req_states[i][0]:
break
self.req_states.insert(i, (tokens_left, tokens_occupied))
if tokens_left <= 0:
return
if not insert_sort:
self.req_states.append((tokens_left, tokens_occupied))
else:
i = 0
for i in range(len(self.req_states)):
if tokens_left <= self.req_states[i][0]:
break
self.req_states.insert(i, (tokens_left, tokens_occupied))
if self.req_states is None:
self.req_states = []
......@@ -407,13 +425,11 @@ class PrefillAdder:
cur_rem_tokens = self.cur_rem_tokens - len(req.origin_input_ids)
tokens_freed = 0
for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
decode_steps = (
self.req_states[i + 1][0]
if i + 1 < len(self.req_states)
else tokens_left
)
# tokens_left gives a reservative calculation as the last token is not stored
bs = len(self.req_states) - i
if cur_rem_tokens + tokens_freed - decode_steps * bs <= 0:
min_free_tokens = cur_rem_tokens + tokens_freed - tokens_left * bs
# reserve tokens for corner cases
if min_free_tokens <= IGNORE_EOS_RESERVE_TOKENS * bs:
return AddReqResult.NO_TOKEN
tokens_freed += tokens_occupied
......@@ -423,7 +439,7 @@ class PrefillAdder:
):
# Non-chunked prefill
self.can_run_list.append(req)
self._prefill_one_req(
self._update_prefill_budget(
0,
req.extend_input_len,
min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION),
......@@ -439,7 +455,7 @@ class PrefillAdder:
req.fill_ids = req.fill_ids[:trunc_len]
self.can_run_list.append(req)
self.new_chunked_req = req
self._prefill_one_req(0, trunc_len, 0)
self._update_prefill_budget(0, trunc_len, 0)
return self.budget_state()
......@@ -453,7 +469,7 @@ class PrefillAdder:
# adjusting the input_tokens based on host_hit_length and page_size
real_input_tokens = req.extend_input_len - req.host_hit_length
real_input_tokens = -(-real_input_tokens // self.page_size) * self.page_size
real_input_tokens = self.ceil_paged_tokens(real_input_tokens)
prefix_len = len(req.prefix_indices)
if total_tokens >= self.rem_total_tokens:
......@@ -475,7 +491,7 @@ class PrefillAdder:
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
prefix_len = len(req.prefix_indices)
input_tokens = -(-req.extend_input_len // self.page_size) * self.page_size
input_tokens = self.ceil_paged_tokens(req.extend_input_len)
if input_tokens >= self.rem_input_tokens and len(self.can_run_list) != 0:
return AddReqResult.OTHER
......@@ -484,7 +500,7 @@ class PrefillAdder:
# Non-chunked prefill
self.can_run_list.append(req)
self.tree_cache.inc_lock_ref(req.last_node)
self._prefill_one_req(
self._update_prefill_budget(
prefix_len,
input_tokens,
min(
......@@ -505,6 +521,6 @@ class PrefillAdder:
self.can_run_list.append(req)
self.new_chunked_req = req
self.tree_cache.inc_lock_ref(req.last_node)
self._prefill_one_req(prefix_len, trunc_len, 0)
self._update_prefill_budget(prefix_len, trunc_len, 0)
return self.budget_state()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment