Unverified Commit 3b99f23c authored by Zhiqiang Xie's avatar Zhiqiang Xie Committed by GitHub
Browse files

[Bugfix] Retract not releasing enough memory when page size > 1 (#9989)

parent ee0b3c5b
......@@ -1371,21 +1371,28 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# TODO (lianmin): Revisit this. It should be seq_len - 1
self.extend_logprob_start_lens.extend([0] * running_bs)
def new_page_count_next_decode(self):
def new_page_count_next_decode(self, selected_indices: Optional[List[int]] = None):
page_size = self.token_to_kv_pool_allocator.page_size
requests = (
self.reqs
if selected_indices is None
else [self.reqs[i] for i in selected_indices]
)
if page_size == 1:
return len(self.reqs)
return len(requests)
# In the decoding phase, the length of a request's KV cache should be
# the total length of the request minus 1
return (
sum(1 for req in self.reqs if req.seqlen % page_size == 0)
sum(1 for req in requests if req.seqlen % page_size == 0)
if self.enable_overlap
else sum(1 for req in self.reqs if (req.seqlen - 1) % page_size == 0)
else sum(1 for req in requests if (req.seqlen - 1) % page_size == 0)
)
def check_decode_mem(self, buf_multiplier=1):
def check_decode_mem(
self, buf_multiplier=1, selected_indices: Optional[List[int]] = None
):
num_tokens = (
self.new_page_count_next_decode()
self.new_page_count_next_decode(selected_indices)
* buf_multiplier
* self.token_to_kv_pool_allocator.page_size
)
......@@ -1411,34 +1418,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
reverse=True,
)
def get_required_tokens(num_reqs: int):
headroom_for_spec_decode = 0
if server_args.speculative_algorithm:
headroom_for_spec_decode += (
num_reqs
* server_args.speculative_eagle_topk
* server_args.speculative_num_steps
+ num_reqs * server_args.speculative_num_draft_tokens
)
return (
num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
)
def _get_available_size():
if self.is_hybrid:
return min(
self.token_to_kv_pool_allocator.full_available_size(),
self.token_to_kv_pool_allocator.swa_available_size(),
)
else:
return self.token_to_kv_pool_allocator.available_size()
retracted_reqs = []
seq_lens_cpu = self.seq_lens.cpu().numpy()
first_iter = True
while (
_get_available_size() < get_required_tokens(len(sorted_indices))
or first_iter
while first_iter or (
not self.check_decode_mem(selected_indices=sorted_indices)
):
if len(sorted_indices) == 1:
# Corner case: only one request left
......@@ -1492,10 +1476,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
else:
self.tree_cache.dec_lock_ref(req.last_node)
# NOTE(lsyin): we should use the newly evictable memory instantly.
num_tokens = len(sorted_indices) * global_config.retract_decode_steps
self._evict_tree_cache_if_needed(num_tokens)
req.reset_for_retract()
if len(retracted_reqs) == 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment