Unverified Commit a169b9f8 authored by Zhiqiang Xie's avatar Zhiqiang Xie Committed by GitHub
Browse files

Fix oom error for large page size (#4913)


Co-authored-by: default avatarLianmin Zheng <lianminzheng@gmail.com>
parent 4a63bc32
...@@ -814,11 +814,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): ...@@ -814,11 +814,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
last_loc: torch.Tensor, last_loc: torch.Tensor,
backup_state: bool = False, backup_state: bool = False,
): ):
if ( if self.tree_cache is not None:
self.token_to_kv_pool_allocator.available_size() if (
< len(seq_lens) * self.token_to_kv_pool_allocator.page_size self.token_to_kv_pool_allocator.available_size()
): < len(seq_lens) * self.token_to_kv_pool_allocator.page_size
if self.tree_cache is not None: ):
self.tree_cache.evict( self.tree_cache.evict(
len(seq_lens) * self.token_to_kv_pool_allocator.page_size, len(seq_lens) * self.token_to_kv_pool_allocator.page_size,
) )
...@@ -1116,17 +1116,25 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): ...@@ -1116,17 +1116,25 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# TODO (lianmin): Revisit this. It should be seq_len - 1 # TODO (lianmin): Revisit this. It should be seq_len - 1
self.extend_logprob_start_lens.extend([0] * running_bs) self.extend_logprob_start_lens.extend([0] * running_bs)
def check_decode_mem(self, buf_multiplier=1): def new_page_count_next_decode(self):
bs = len(self.reqs) * buf_multiplier page_size = self.token_to_kv_pool_allocator.page_size
if self.token_to_kv_pool_allocator.available_size() >= bs: if page_size == 1:
return True return len(self.reqs)
return sum(1 for req in self.reqs if req.seqlen % page_size == 0)
self.tree_cache.evict(bs) def check_decode_mem(self, buf_multiplier=1):
tokens_required = (
self.new_page_count_next_decode()
* buf_multiplier
* self.token_to_kv_pool_allocator.page_size
)
if self.token_to_kv_pool_allocator.available_size() >= bs: if self.token_to_kv_pool_allocator.available_size() >= tokens_required:
return True return True
return False self.tree_cache.evict(tokens_required)
return self.token_to_kv_pool_allocator.available_size() >= tokens_required
def retract_decode(self, server_args: ServerArgs): def retract_decode(self, server_args: ServerArgs):
"""Retract the decoding requests when there is not enough memory.""" """Retract the decoding requests when there is not enough memory."""
......
...@@ -144,7 +144,7 @@ class TestEAGLEEngine(CustomTestCase): ...@@ -144,7 +144,7 @@ class TestEAGLEEngine(CustomTestCase):
if engine.server_args.model_path == DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST: if engine.server_args.model_path == DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST:
self.assertGreater(acc_length, 3.6) self.assertGreater(acc_length, 3.6)
else: else:
self.assertGreater(acc_length, 2.6) self.assertGreater(acc_length, 2.5)
class TestEAGLEEngineTokenMap(TestEAGLEEngine): class TestEAGLEEngineTokenMap(TestEAGLEEngine):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment