Unverified Commit 4c0bb411 authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Further fix memory pool leak error (#9298)

parent 968e1818
...@@ -434,15 +434,12 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator): ...@@ -434,15 +434,12 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
device: str, device: str,
kvcache: KVCache, kvcache: KVCache,
need_sort: bool, need_sort: bool,
max_num_extend_tokens: int,
): ):
super().__init__(size, page_size, dtype, device, kvcache, need_sort) super().__init__(size, page_size, dtype, device, kvcache, need_sort)
self.num_pages = size // page_size self.num_pages = size // page_size
self.max_num_extend_tokens_next_power_of_2 = next_power_of_2(
max_num_extend_tokens
)
self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL") self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
self.ret_values = torch.empty((), dtype=torch.int64, device=self.device) self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
self.seen_max_num_extend_tokens_next_power_of_2 = 1
self.clear() self.clear()
def alloc(self, need_size: int): def alloc(self, need_size: int):
...@@ -480,17 +477,17 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator): ...@@ -480,17 +477,17 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
(last_loc + 1) % self.page_size == prefix_lens % self.page_size (last_loc + 1) % self.page_size == prefix_lens % self.page_size
) )
self.seen_max_num_extend_tokens_next_power_of_2 = max(
self.seen_max_num_extend_tokens_next_power_of_2,
next_power_of_2(extend_num_tokens),
)
bs = len(prefix_lens) bs = len(prefix_lens)
if self.need_sort and extend_num_tokens // self.page_size + bs + 1 > len( if self.need_sort and extend_num_tokens // self.page_size + bs + 1 > len(
self.free_pages self.free_pages
): ):
self.merge_and_sort_free() self.merge_and_sort_free()
assert self.max_num_extend_tokens_next_power_of_2 >= extend_num_tokens, (
f"{self.max_num_extend_tokens_next_power_of_2=} >= {extend_num_tokens=} does not hold. "
f"If this happens in PD, consider letting chunked_prefill_size in D be as large as in P"
)
out_indices = torch.empty( out_indices = torch.empty(
(extend_num_tokens,), dtype=torch.int64, device=self.device (extend_num_tokens,), dtype=torch.int64, device=self.device
) )
...@@ -503,7 +500,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator): ...@@ -503,7 +500,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
self.ret_values, self.ret_values,
next_power_of_2(bs), next_power_of_2(bs),
self.page_size, self.page_size,
self.max_num_extend_tokens_next_power_of_2, self.seen_max_num_extend_tokens_next_power_of_2,
) )
if self.debug_mode: if self.debug_mode:
......
...@@ -1353,11 +1353,6 @@ class ModelRunner: ...@@ -1353,11 +1353,6 @@ class ModelRunner:
# Initialize token_to_kv_pool_allocator # Initialize token_to_kv_pool_allocator
need_sort = self.server_args.disaggregation_mode in ("decode", "prefill") need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
max_num_extend_tokens = (
self.server_args.chunked_prefill_size
if self.server_args.chunked_prefill_size > 0
else self.server_args.max_prefill_tokens
)
if self.token_to_kv_pool_allocator is None: if self.token_to_kv_pool_allocator is None:
if self.server_args.attention_backend == "ascend": if self.server_args.attention_backend == "ascend":
self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator( self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
...@@ -1396,7 +1391,6 @@ class ModelRunner: ...@@ -1396,7 +1391,6 @@ class ModelRunner:
device=self.device, device=self.device,
kvcache=self.token_to_kv_pool, kvcache=self.token_to_kv_pool,
need_sort=need_sort, need_sort=need_sort,
max_num_extend_tokens=max_num_extend_tokens,
) )
else: else:
assert self.is_draft_worker assert self.is_draft_worker
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment