Unverified Commit 3b99f23c authored by Zhiqiang Xie, committed by GitHub

[Bugfix] Retract not releasing enough memory when page size > 1 (#9989)

parent ee0b3c5b
@@ -1371,21 +1371,28 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         # TODO (lianmin): Revisit this. It should be seq_len - 1
         self.extend_logprob_start_lens.extend([0] * running_bs)
 
-    def new_page_count_next_decode(self):
+    def new_page_count_next_decode(self, selected_indices: Optional[List[int]] = None):
         page_size = self.token_to_kv_pool_allocator.page_size
+        requests = (
+            self.reqs
+            if selected_indices is None
+            else [self.reqs[i] for i in selected_indices]
+        )
         if page_size == 1:
-            return len(self.reqs)
+            return len(requests)
         # In the decoding phase, the length of a request's KV cache should be
         # the total length of the request minus 1
         return (
-            sum(1 for req in self.reqs if req.seqlen % page_size == 0)
+            sum(1 for req in requests if req.seqlen % page_size == 0)
             if self.enable_overlap
-            else sum(1 for req in self.reqs if (req.seqlen - 1) % page_size == 0)
+            else sum(1 for req in requests if (req.seqlen - 1) % page_size == 0)
         )
 
-    def check_decode_mem(self, buf_multiplier=1):
+    def check_decode_mem(
+        self, buf_multiplier=1, selected_indices: Optional[List[int]] = None
+    ):
         num_tokens = (
-            self.new_page_count_next_decode()
+            self.new_page_count_next_decode(selected_indices)
            * buf_multiplier
            * self.token_to_kv_pool_allocator.page_size
         )
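The new `selected_indices` parameter lets the caller count upcoming page allocations for only a subset of the batch. As a minimal standalone sketch of the page-boundary arithmetic this hunk relies on (the helper name `count_new_pages` and the example lengths are illustrative, not part of the repository):

    # Illustrative sketch only: count how many requests will need a fresh
    # KV-cache page for their next decoded token.
    from typing import List

    def count_new_pages(seqlens: List[int], page_size: int) -> int:
        if page_size == 1:
            return len(seqlens)  # every request needs exactly one new token slot
        # A request needs a new page when its KV length so far (seqlen - 1)
        # exactly fills its last allocated page.
        return sum(1 for s in seqlens if (s - 1) % page_size == 0)

    # With page_size=4, lengths 5 and 9 sit on a page boundary, 6 does not,
    # so 2 new pages (2 * 4 = 8 token slots) must be available.
    assert count_new_pages([5, 6, 9], page_size=4) == 2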
@@ -1411,34 +1418,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             reverse=True,
         )
 
-        def get_required_tokens(num_reqs: int):
-            headroom_for_spec_decode = 0
-            if server_args.speculative_algorithm:
-                headroom_for_spec_decode += (
-                    num_reqs
-                    * server_args.speculative_eagle_topk
-                    * server_args.speculative_num_steps
-                    + num_reqs * server_args.speculative_num_draft_tokens
-                )
-            return (
-                num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
-            )
-
-        def _get_available_size():
-            if self.is_hybrid:
-                return min(
-                    self.token_to_kv_pool_allocator.full_available_size(),
-                    self.token_to_kv_pool_allocator.swa_available_size(),
-                )
-            else:
-                return self.token_to_kv_pool_allocator.available_size()
-
         retracted_reqs = []
         seq_lens_cpu = self.seq_lens.cpu().numpy()
         first_iter = True
-        while (
-            _get_available_size() < get_required_tokens(len(sorted_indices))
-            or first_iter
+        while first_iter or (
+            not self.check_decode_mem(selected_indices=sorted_indices)
         ):
             if len(sorted_indices) == 1:
                 # Corner case: only one request left
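The old loop compared free token slots against a flat per-request token budget, which ignores page granularity; the new condition instead asks `check_decode_mem` whether every remaining request could open a whole new page if it needs one. A toy comparison of the two budgets (all numbers are made up for illustration and are not the repository defaults):

    # Toy illustration of token-budget vs. page-budget accounting.
    page_size = 16
    num_remaining_reqs = 3
    retract_decode_steps = 2  # hypothetical per-request token budget

    old_required_tokens = num_remaining_reqs * retract_decode_steps  # 6 tokens
    # Worst case for the page-aware check: every remaining request sits on a
    # page boundary and must open a new page for its next decode token.
    new_required_tokens = num_remaining_reqs * page_size             # 48 tokens

    # With, say, 10 free token slots the old check passes and retraction stops,
    # even though the allocator cannot actually hand out 3 whole pages.
    assert old_required_tokens < new_required_tokens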
@@ -1492,10 +1476,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             else:
                 self.tree_cache.dec_lock_ref(req.last_node)
 
-            # NOTE(lsyin): we should use the newly evictable memory instantly.
-            num_tokens = len(sorted_indices) * global_config.retract_decode_steps
-            self._evict_tree_cache_if_needed(num_tokens)
-
             req.reset_for_retract()
 
         if len(retracted_reqs) == 0:
...