"src/vscode:/vscode.git/clone" did not exist on "fb248b678a7c9b4c065f9daf87ecf5f6fa29ac45"
Unverified Commit 4c0bb411 authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Further fix memory pool leak error (#9298)

parent 968e1818
......@@ -434,15 +434,12 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
device: str,
kvcache: KVCache,
need_sort: bool,
max_num_extend_tokens: int,
):
super().__init__(size, page_size, dtype, device, kvcache, need_sort)
self.num_pages = size // page_size
self.max_num_extend_tokens_next_power_of_2 = next_power_of_2(
max_num_extend_tokens
)
self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
self.seen_max_num_extend_tokens_next_power_of_2 = 1
self.clear()
def alloc(self, need_size: int):
......@@ -480,17 +477,17 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
(last_loc + 1) % self.page_size == prefix_lens % self.page_size
)
self.seen_max_num_extend_tokens_next_power_of_2 = max(
self.seen_max_num_extend_tokens_next_power_of_2,
next_power_of_2(extend_num_tokens),
)
bs = len(prefix_lens)
if self.need_sort and extend_num_tokens // self.page_size + bs + 1 > len(
self.free_pages
):
self.merge_and_sort_free()
assert self.max_num_extend_tokens_next_power_of_2 >= extend_num_tokens, (
f"{self.max_num_extend_tokens_next_power_of_2=} >= {extend_num_tokens=} does not hold. "
f"If this happens in PD, consider letting chunked_prefill_size in D be as large as in P"
)
out_indices = torch.empty(
(extend_num_tokens,), dtype=torch.int64, device=self.device
)
......@@ -503,7 +500,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
self.ret_values,
next_power_of_2(bs),
self.page_size,
self.max_num_extend_tokens_next_power_of_2,
self.seen_max_num_extend_tokens_next_power_of_2,
)
if self.debug_mode:
......
......@@ -1353,11 +1353,6 @@ class ModelRunner:
# Initialize token_to_kv_pool_allocator
need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
max_num_extend_tokens = (
self.server_args.chunked_prefill_size
if self.server_args.chunked_prefill_size > 0
else self.server_args.max_prefill_tokens
)
if self.token_to_kv_pool_allocator is None:
if self.server_args.attention_backend == "ascend":
self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
......@@ -1396,7 +1391,6 @@ class ModelRunner:
device=self.device,
kvcache=self.token_to_kv_pool,
need_sort=need_sort,
max_num_extend_tokens=max_num_extend_tokens,
)
else:
assert self.is_draft_worker
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment