Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
4c0bb411
Unverified
Commit
4c0bb411
authored
Aug 18, 2025
by
fzyzcjy
Committed by
GitHub
Aug 18, 2025
Browse files
Further fix memory pool leak error (#9298)
parent
968e1818
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
7 additions
and
16 deletions
+7
-16
python/sglang/srt/mem_cache/allocator.py
python/sglang/srt/mem_cache/allocator.py
+7
-10
python/sglang/srt/model_executor/model_runner.py
python/sglang/srt/model_executor/model_runner.py
+0
-6
No files found.
python/sglang/srt/mem_cache/allocator.py
View file @
4c0bb411
...
@@ -434,15 +434,12 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
...
@@ -434,15 +434,12 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
device
:
str
,
device
:
str
,
kvcache
:
KVCache
,
kvcache
:
KVCache
,
need_sort
:
bool
,
need_sort
:
bool
,
max_num_extend_tokens
:
int
,
):
):
super
().
__init__
(
size
,
page_size
,
dtype
,
device
,
kvcache
,
need_sort
)
super
().
__init__
(
size
,
page_size
,
dtype
,
device
,
kvcache
,
need_sort
)
self
.
num_pages
=
size
//
page_size
self
.
num_pages
=
size
//
page_size
self
.
max_num_extend_tokens_next_power_of_2
=
next_power_of_2
(
max_num_extend_tokens
)
self
.
debug_mode
=
get_bool_env_var
(
"SGLANG_DEBUG_MEMORY_POOL"
)
self
.
debug_mode
=
get_bool_env_var
(
"SGLANG_DEBUG_MEMORY_POOL"
)
self
.
ret_values
=
torch
.
empty
((),
dtype
=
torch
.
int64
,
device
=
self
.
device
)
self
.
ret_values
=
torch
.
empty
((),
dtype
=
torch
.
int64
,
device
=
self
.
device
)
self
.
seen_max_num_extend_tokens_next_power_of_2
=
1
self
.
clear
()
self
.
clear
()
def
alloc
(
self
,
need_size
:
int
):
def
alloc
(
self
,
need_size
:
int
):
...
@@ -480,17 +477,17 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
...
@@ -480,17 +477,17 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
(
last_loc
+
1
)
%
self
.
page_size
==
prefix_lens
%
self
.
page_size
(
last_loc
+
1
)
%
self
.
page_size
==
prefix_lens
%
self
.
page_size
)
)
self
.
seen_max_num_extend_tokens_next_power_of_2
=
max
(
self
.
seen_max_num_extend_tokens_next_power_of_2
,
next_power_of_2
(
extend_num_tokens
),
)
bs
=
len
(
prefix_lens
)
bs
=
len
(
prefix_lens
)
if
self
.
need_sort
and
extend_num_tokens
//
self
.
page_size
+
bs
+
1
>
len
(
if
self
.
need_sort
and
extend_num_tokens
//
self
.
page_size
+
bs
+
1
>
len
(
self
.
free_pages
self
.
free_pages
):
):
self
.
merge_and_sort_free
()
self
.
merge_and_sort_free
()
assert
self
.
max_num_extend_tokens_next_power_of_2
>=
extend_num_tokens
,
(
f
"
{
self
.
max_num_extend_tokens_next_power_of_2
=
}
>=
{
extend_num_tokens
=
}
does not hold. "
f
"If this happens in PD, consider letting chunked_prefill_size in D be as large as in P"
)
out_indices
=
torch
.
empty
(
out_indices
=
torch
.
empty
(
(
extend_num_tokens
,),
dtype
=
torch
.
int64
,
device
=
self
.
device
(
extend_num_tokens
,),
dtype
=
torch
.
int64
,
device
=
self
.
device
)
)
...
@@ -503,7 +500,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
...
@@ -503,7 +500,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
self
.
ret_values
,
self
.
ret_values
,
next_power_of_2
(
bs
),
next_power_of_2
(
bs
),
self
.
page_size
,
self
.
page_size
,
self
.
max_num_extend_tokens_next_power_of_2
,
self
.
seen_
max_num_extend_tokens_next_power_of_2
,
)
)
if
self
.
debug_mode
:
if
self
.
debug_mode
:
...
...
python/sglang/srt/model_executor/model_runner.py
View file @
4c0bb411
...
@@ -1353,11 +1353,6 @@ class ModelRunner:
...
@@ -1353,11 +1353,6 @@ class ModelRunner:
# Initialize token_to_kv_pool_allocator
# Initialize token_to_kv_pool_allocator
need_sort
=
self
.
server_args
.
disaggregation_mode
in
(
"decode"
,
"prefill"
)
need_sort
=
self
.
server_args
.
disaggregation_mode
in
(
"decode"
,
"prefill"
)
max_num_extend_tokens
=
(
self
.
server_args
.
chunked_prefill_size
if
self
.
server_args
.
chunked_prefill_size
>
0
else
self
.
server_args
.
max_prefill_tokens
)
if
self
.
token_to_kv_pool_allocator
is
None
:
if
self
.
token_to_kv_pool_allocator
is
None
:
if
self
.
server_args
.
attention_backend
==
"ascend"
:
if
self
.
server_args
.
attention_backend
==
"ascend"
:
self
.
token_to_kv_pool_allocator
=
AscendPagedTokenToKVPoolAllocator
(
self
.
token_to_kv_pool_allocator
=
AscendPagedTokenToKVPoolAllocator
(
...
@@ -1396,7 +1391,6 @@ class ModelRunner:
...
@@ -1396,7 +1391,6 @@ class ModelRunner:
device
=
self
.
device
,
device
=
self
.
device
,
kvcache
=
self
.
token_to_kv_pool
,
kvcache
=
self
.
token_to_kv_pool
,
need_sort
=
need_sort
,
need_sort
=
need_sort
,
max_num_extend_tokens
=
max_num_extend_tokens
,
)
)
else
:
else
:
assert
self
.
is_draft_worker
assert
self
.
is_draft_worker
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment