Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
7d004799
Unverified
Commit
7d004799
authored
Oct 02, 2025
by
Liangsheng Yin
Committed by
GitHub
Oct 02, 2025
Browse files
Clean up ascend allocator (#11152)
parent
083629c2
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
29 additions
and
20 deletions
+29
-20
python/sglang/srt/mem_cache/allocator.py
python/sglang/srt/mem_cache/allocator.py
+8
-2
python/sglang/srt/mem_cache/allocator_ascend.py
python/sglang/srt/mem_cache/allocator_ascend.py
+11
-15
python/sglang/srt/utils.py
python/sglang/srt/utils.py
+10
-3
No files found.
python/sglang/srt/mem_cache/allocator.py
View file @
7d004799
...
@@ -493,7 +493,11 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
...
@@ -493,7 +493,11 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
if
self
.
debug_mode
:
if
self
.
debug_mode
:
assert
len
(
torch
.
unique
(
out_indices
))
==
len
(
out_indices
)
assert
len
(
torch
.
unique
(
out_indices
))
==
len
(
out_indices
)
num_new_pages
=
get_num_new_pages
(
prefix_lens_cpu
,
seq_lens_cpu
,
self
.
page_size
)
num_new_pages
=
get_num_new_pages
(
seq_lens
=
seq_lens_cpu
,
page_size
=
self
.
page_size
,
prefix_lens
=
prefix_lens_cpu
,
)
if
num_new_pages
>
len
(
self
.
free_pages
):
if
num_new_pages
>
len
(
self
.
free_pages
):
return
None
return
None
...
@@ -529,7 +533,9 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
...
@@ -529,7 +533,9 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
assert
len
(
torch
.
unique
(
out_indices
))
==
len
(
out_indices
)
assert
len
(
torch
.
unique
(
out_indices
))
==
len
(
out_indices
)
num_new_pages
=
get_num_new_pages
(
num_new_pages
=
get_num_new_pages
(
seq_lens_cpu
-
1
,
seq_lens_cpu
,
self
.
page_size
,
decode
=
True
seq_lens
=
seq_lens_cpu
,
page_size
=
self
.
page_size
,
decode
=
True
,
)
)
if
num_new_pages
>
len
(
self
.
free_pages
):
if
num_new_pages
>
len
(
self
.
free_pages
):
return
None
return
None
...
...
python/sglang/srt/mem_cache/allocator_ascend.py
View file @
7d004799
from
__future__
import
annotations
from
__future__
import
annotations
from
typing
import
TYPE_CHECKING
import
torch
import
torch
from
sglang.srt.mem_cache.allocator
import
PagedTokenToKVPoolAllocator
from
sglang.srt.mem_cache.allocator
import
PagedTokenToKVPoolAllocator
from
sglang.srt.utils
import
get_num_new_pages
if
TYPE_CHECKING
:
from
sglang.srt.mem_cache.memory_pool
import
KVCache
def
alloc_extend_kernel_ascend
(
def
alloc_extend_kernel_ascend
(
...
@@ -80,13 +76,10 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
...
@@ -80,13 +76,10 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
(
last_loc
+
1
)
%
self
.
page_size
==
prefix_lens
%
self
.
page_size
(
last_loc
+
1
)
%
self
.
page_size
==
prefix_lens
%
self
.
page_size
)
)
num_new_pages
=
(
num_new_pages
=
get_num_new_pages
(
(
seq_lens
=
seq_lens_cpu
,
(
seq_lens_cpu
+
self
.
page_size
-
1
)
//
self
.
page_size
page_size
=
self
.
page_size
,
-
(
prefix_lens_cpu
+
self
.
page_size
-
1
)
//
self
.
page_size
prefix_lens
=
prefix_lens_cpu
,
)
.
sum
()
.
item
()
)
)
if
self
.
need_sort
and
num_new_pages
>
len
(
self
.
free_pages
):
if
self
.
need_sort
and
num_new_pages
>
len
(
self
.
free_pages
):
self
.
merge_and_sort_free
()
self
.
merge_and_sort_free
()
...
@@ -125,9 +118,11 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
...
@@ -125,9 +118,11 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
(
last_loc
+
2
)
%
self
.
page_size
==
seq_lens
%
self
.
page_size
(
last_loc
+
2
)
%
self
.
page_size
==
seq_lens
%
self
.
page_size
)
)
need_new_pages
=
(
seq_lens
%
self
.
page_size
==
1
).
int
()
num_new_pages
=
get_num_new_pages
(
need_new_pages_cpu
=
(
seq_lens_cpu
%
self
.
page_size
==
1
).
int
()
seq_lens
=
seq_lens_cpu
,
num_new_pages
=
need_new_pages_cpu
.
sum
().
item
()
page_size
=
self
.
page_size
,
decode
=
True
,
)
if
num_new_pages
>
len
(
self
.
free_pages
):
if
num_new_pages
>
len
(
self
.
free_pages
):
self
.
merge_and_sort_free
()
self
.
merge_and_sort_free
()
...
@@ -135,6 +130,7 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
...
@@ -135,6 +130,7 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
if
num_new_pages
>
len
(
self
.
free_pages
):
if
num_new_pages
>
len
(
self
.
free_pages
):
return
None
return
None
need_new_pages
=
(
seq_lens
%
self
.
page_size
==
1
).
int
()
end_new_pages
=
torch
.
cumsum
(
need_new_pages
,
0
)
end_new_pages
=
torch
.
cumsum
(
need_new_pages
,
0
)
start_new_pages
=
end_new_pages
-
need_new_pages
start_new_pages
=
end_new_pages
-
need_new_pages
if
num_new_pages
==
0
:
if
num_new_pages
==
0
:
...
...
python/sglang/srt/utils.py
View file @
7d004799
...
@@ -3251,17 +3251,24 @@ def get_extend_input_len_swa_limit(
...
@@ -3251,17 +3251,24 @@ def get_extend_input_len_swa_limit(
def
get_num_new_pages
(
def
get_num_new_pages
(
prefix_lens
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
page_size
:
int
,
page_size
:
int
,
prefix_lens
:
Optional
[
torch
.
Tensor
]
=
None
,
decode
:
bool
=
False
,
decode
:
bool
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
"""
Get the number of new pages for the given prefix and sequence lengths. We use cpu tensors to avoid blocking kernel launch.
Get the number of new pages for the given prefix and sequence lengths.
We use cpu tensors to avoid blocking kernel launch.
"""
"""
cpu_device
=
torch
.
device
(
"cpu"
)
cpu_device
=
torch
.
device
(
"cpu"
)
assert
prefix_lens
.
device
==
cpu_device
assert
seq_lens
.
device
==
cpu_device
assert
seq_lens
.
device
==
cpu_device
if
prefix_lens
is
None
or
decode
:
# NOTE: Special case for handling decode, which prefix lens is `seq_lens - 1`.
assert
decode
return
(
seq_lens
%
page_size
==
1
).
int
().
sum
().
item
()
assert
prefix_lens
.
device
==
cpu_device
num_pages_after
=
(
seq_lens
+
page_size
-
1
)
//
page_size
num_pages_after
=
(
seq_lens
+
page_size
-
1
)
//
page_size
num_pages_before
=
(
prefix_lens
+
page_size
-
1
)
//
page_size
num_pages_before
=
(
prefix_lens
+
page_size
-
1
)
//
page_size
num_new_pages
=
num_pages_after
-
num_pages_before
num_new_pages
=
num_pages_after
-
num_pages_before
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment