Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
0b2aa8a7
Unverified
Commit
0b2aa8a7
authored
Oct 02, 2025
by
Zhang Junda
Committed by
GitHub
Oct 02, 2025
Browse files
Intoduce cpu tensor as metadata to avoid blocking gpu kernel launch (#10720)
Co-authored-by:
hnyls2002
<
lsyincs@gmail.com
>
parent
609f65ba
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
114 additions
and
43 deletions
+114
-43
python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py
.../sglang/srt/disaggregation/decode_schedule_batch_mixin.py
+1
-0
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+32
-10
python/sglang/srt/mem_cache/allocator.py
python/sglang/srt/mem_cache/allocator.py
+8
-20
python/sglang/srt/mem_cache/allocator_ascend.py
python/sglang/srt/mem_cache/allocator_ascend.py
+7
-3
python/sglang/srt/speculative/eagle_info.py
python/sglang/srt/speculative/eagle_info.py
+20
-5
python/sglang/srt/speculative/eagle_worker.py
python/sglang/srt/speculative/eagle_worker.py
+16
-3
python/sglang/srt/speculative/ngram_utils.py
python/sglang/srt/speculative/ngram_utils.py
+6
-2
python/sglang/srt/utils.py
python/sglang/srt/utils.py
+24
-0
No files found.
python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py
View file @
0b2aa8a7
...
@@ -76,6 +76,7 @@ class ScheduleBatchDisaggregationDecodeMixin:
...
@@ -76,6 +76,7 @@ class ScheduleBatchDisaggregationDecodeMixin:
req_pool_indices
,
dtype
=
torch
.
int64
,
device
=
self
.
device
req_pool_indices
,
dtype
=
torch
.
int64
,
device
=
self
.
device
)
)
self
.
seq_lens
=
torch
.
tensor
(
seq_lens
,
dtype
=
torch
.
int64
,
device
=
self
.
device
)
self
.
seq_lens
=
torch
.
tensor
(
seq_lens
,
dtype
=
torch
.
int64
,
device
=
self
.
device
)
self
.
seq_lens_cpu
=
torch
.
tensor
(
seq_lens
,
dtype
=
torch
.
int64
)
self
.
orig_seq_lens
=
torch
.
tensor
(
self
.
orig_seq_lens
=
torch
.
tensor
(
seq_lens
,
dtype
=
torch
.
int32
,
device
=
self
.
device
seq_lens
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
)
...
...
python/sglang/srt/managers/schedule_batch.py
View file @
0b2aa8a7
...
@@ -900,6 +900,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
...
@@ -900,6 +900,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
token_type_ids
:
torch
.
Tensor
=
None
# shape: [b], int64
token_type_ids
:
torch
.
Tensor
=
None
# shape: [b], int64
req_pool_indices
:
torch
.
Tensor
=
None
# shape: [b], int64
req_pool_indices
:
torch
.
Tensor
=
None
# shape: [b], int64
seq_lens
:
torch
.
Tensor
=
None
# shape: [b], int64
seq_lens
:
torch
.
Tensor
=
None
# shape: [b], int64
seq_lens_cpu
:
torch
.
Tensor
=
None
# shape: [b], int64
# The output locations of the KV cache
# The output locations of the KV cache
out_cache_loc
:
torch
.
Tensor
=
None
# shape: [b], int64
out_cache_loc
:
torch
.
Tensor
=
None
# shape: [b], int64
output_ids
:
torch
.
Tensor
=
None
# shape: [b], int64
output_ids
:
torch
.
Tensor
=
None
# shape: [b], int64
...
@@ -1055,7 +1056,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
...
@@ -1055,7 +1056,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
def
alloc_paged_token_slots_extend
(
def
alloc_paged_token_slots_extend
(
self
,
self
,
prefix_lens
:
torch
.
Tensor
,
prefix_lens
:
torch
.
Tensor
,
prefix_lens_cpu
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
seq_lens_cpu
:
torch
.
Tensor
,
last_loc
:
torch
.
Tensor
,
last_loc
:
torch
.
Tensor
,
extend_num_tokens
:
int
,
extend_num_tokens
:
int
,
backup_state
:
bool
=
False
,
backup_state
:
bool
=
False
,
...
@@ -1063,7 +1066,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
...
@@ -1063,7 +1066,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# Over estimate the number of tokens: assume each request needs a new page.
# Over estimate the number of tokens: assume each request needs a new page.
num_tokens
=
(
num_tokens
=
(
extend_num_tokens
extend_num_tokens
+
len
(
seq_lens
)
*
self
.
token_to_kv_pool_allocator
.
page_size
+
len
(
seq_lens
_cpu
)
*
self
.
token_to_kv_pool_allocator
.
page_size
)
)
self
.
_evict_tree_cache_if_needed
(
num_tokens
)
self
.
_evict_tree_cache_if_needed
(
num_tokens
)
...
@@ -1071,7 +1074,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
...
@@ -1071,7 +1074,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
state
=
self
.
token_to_kv_pool_allocator
.
backup_state
()
state
=
self
.
token_to_kv_pool_allocator
.
backup_state
()
out_cache_loc
=
self
.
token_to_kv_pool_allocator
.
alloc_extend
(
out_cache_loc
=
self
.
token_to_kv_pool_allocator
.
alloc_extend
(
prefix_lens
,
seq_lens
,
last_loc
,
extend_num_tokens
prefix_lens
,
prefix_lens_cpu
,
seq_lens
,
seq_lens_cpu
,
last_loc
,
extend_num_tokens
,
)
)
if
out_cache_loc
is
None
:
if
out_cache_loc
is
None
:
error_msg
=
(
error_msg
=
(
...
@@ -1090,6 +1098,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
...
@@ -1090,6 +1098,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
def
alloc_paged_token_slots_decode
(
def
alloc_paged_token_slots_decode
(
self
,
self
,
seq_lens
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
seq_lens_cpu
:
torch
.
Tensor
,
last_loc
:
torch
.
Tensor
,
last_loc
:
torch
.
Tensor
,
backup_state
:
bool
=
False
,
backup_state
:
bool
=
False
,
):
):
...
@@ -1100,7 +1109,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
...
@@ -1100,7 +1109,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
if
backup_state
:
if
backup_state
:
state
=
self
.
token_to_kv_pool_allocator
.
backup_state
()
state
=
self
.
token_to_kv_pool_allocator
.
backup_state
()
out_cache_loc
=
self
.
token_to_kv_pool_allocator
.
alloc_decode
(
seq_lens
,
last_loc
)
out_cache_loc
=
self
.
token_to_kv_pool_allocator
.
alloc_decode
(
seq_lens
,
seq_lens_cpu
,
last_loc
)
if
out_cache_loc
is
None
:
if
out_cache_loc
is
None
:
error_msg
=
(
error_msg
=
(
f
"Decode out of memory. Try to lower your batch size.
\n
"
f
"Decode out of memory. Try to lower your batch size.
\n
"
...
@@ -1169,6 +1180,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
...
@@ -1169,6 +1180,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self
.
seq_lens
=
torch
.
tensor
(
seq_lens
,
dtype
=
torch
.
int64
).
to
(
self
.
seq_lens
=
torch
.
tensor
(
seq_lens
,
dtype
=
torch
.
int64
).
to
(
self
.
device
,
non_blocking
=
True
self
.
device
,
non_blocking
=
True
)
)
self
.
seq_lens_cpu
=
torch
.
tensor
(
seq_lens
,
dtype
=
torch
.
int64
)
if
not
decoder_out_cache_loc
:
if
not
decoder_out_cache_loc
:
self
.
out_cache_loc
=
torch
.
zeros
(
0
,
dtype
=
torch
.
int64
).
to
(
self
.
out_cache_loc
=
torch
.
zeros
(
0
,
dtype
=
torch
.
int64
).
to
(
...
@@ -1217,12 +1229,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
...
@@ -1217,12 +1229,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
seq_lens_tensor
=
torch
.
tensor
(
seq_lens
,
dtype
=
torch
.
int64
).
to
(
seq_lens_tensor
=
torch
.
tensor
(
seq_lens
,
dtype
=
torch
.
int64
).
to
(
self
.
device
,
non_blocking
=
True
self
.
device
,
non_blocking
=
True
)
)
seq_lens_cpu_tensor
=
torch
.
tensor
(
seq_lens
,
dtype
=
torch
.
int64
)
orig_seq_lens_tensor
=
torch
.
tensor
(
orig_seq_lens
,
dtype
=
torch
.
int32
).
to
(
orig_seq_lens_tensor
=
torch
.
tensor
(
orig_seq_lens
,
dtype
=
torch
.
int32
).
to
(
self
.
device
,
non_blocking
=
True
self
.
device
,
non_blocking
=
True
)
)
prefix_lens_tensor
=
torch
.
tensor
(
prefix_lens_tensor
=
torch
.
tensor
(
prefix_lens
,
dtype
=
torch
.
int64
,
device
=
self
.
device
prefix_lens
,
dtype
=
torch
.
int64
,
device
=
self
.
device
)
)
prefix_lens_cpu_tensor
=
torch
.
tensor
(
prefix_lens
,
dtype
=
torch
.
int64
)
token_type_ids_tensor
=
None
token_type_ids_tensor
=
None
if
len
(
token_type_ids
)
>
0
:
if
len
(
token_type_ids
)
>
0
:
...
@@ -1349,13 +1363,19 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
...
@@ -1349,13 +1363,19 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
prefix_lens_tensor
,
prefix_lens_tensor
,
)
)
out_cache_loc
=
self
.
alloc_paged_token_slots_extend
(
out_cache_loc
=
self
.
alloc_paged_token_slots_extend
(
prefix_lens_tensor
,
seq_lens_tensor
,
last_loc
,
extend_num_tokens
prefix_lens_tensor
,
prefix_lens_cpu_tensor
,
seq_lens_tensor
,
seq_lens_cpu_tensor
,
last_loc
,
extend_num_tokens
,
)
)
# Set fields
# Set fields
self
.
input_ids
=
input_ids_tensor
self
.
input_ids
=
input_ids_tensor
self
.
req_pool_indices
=
req_pool_indices_tensor
self
.
req_pool_indices
=
req_pool_indices_tensor
self
.
seq_lens
=
seq_lens_tensor
self
.
seq_lens
=
seq_lens_tensor
self
.
seq_lens_cpu
=
seq_lens_cpu_tensor
self
.
orig_seq_lens
=
orig_seq_lens_tensor
self
.
orig_seq_lens
=
orig_seq_lens_tensor
self
.
out_cache_loc
=
out_cache_loc
self
.
out_cache_loc
=
out_cache_loc
self
.
input_embeds
=
(
self
.
input_embeds
=
(
...
@@ -1498,7 +1518,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
...
@@ -1498,7 +1518,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
)
)
retracted_reqs
=
[]
retracted_reqs
=
[]
seq_lens_cpu
=
self
.
seq_lens
.
cpu
().
numpy
()
first_iter
=
True
first_iter
=
True
while
first_iter
or
(
while
first_iter
or
(
not
self
.
check_decode_mem
(
selected_indices
=
sorted_indices
)
not
self
.
check_decode_mem
(
selected_indices
=
sorted_indices
)
...
@@ -1548,7 +1567,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
...
@@ -1548,7 +1567,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
def
release_req
(
self
,
idx
:
int
,
remaing_req_count
:
int
,
server_args
:
ServerArgs
):
def
release_req
(
self
,
idx
:
int
,
remaing_req_count
:
int
,
server_args
:
ServerArgs
):
req
=
self
.
reqs
[
idx
]
req
=
self
.
reqs
[
idx
]
seq_lens_cpu
=
self
.
seq_lens
.
cpu
()
.
numpy
()
seq_lens_cpu
=
self
.
seq_lens
_
cpu
.
numpy
()
if
server_args
.
disaggregation_mode
==
"decode"
:
if
server_args
.
disaggregation_mode
==
"decode"
:
req
.
offload_kv_cache
(
req
.
offload_kv_cache
(
...
@@ -1592,6 +1611,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
...
@@ -1592,6 +1611,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self
.
forward_mode
=
ForwardMode
.
IDLE
self
.
forward_mode
=
ForwardMode
.
IDLE
self
.
input_ids
=
torch
.
empty
(
0
,
dtype
=
torch
.
int64
,
device
=
self
.
device
)
self
.
input_ids
=
torch
.
empty
(
0
,
dtype
=
torch
.
int64
,
device
=
self
.
device
)
self
.
seq_lens
=
torch
.
empty
(
0
,
dtype
=
torch
.
int64
,
device
=
self
.
device
)
self
.
seq_lens
=
torch
.
empty
(
0
,
dtype
=
torch
.
int64
,
device
=
self
.
device
)
self
.
seq_lens_cpu
=
torch
.
empty
(
0
,
dtype
=
torch
.
int64
)
self
.
orig_seq_lens
=
torch
.
empty
(
0
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
self
.
orig_seq_lens
=
torch
.
empty
(
0
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
self
.
out_cache_loc
=
torch
.
empty
(
0
,
dtype
=
torch
.
int64
,
device
=
self
.
device
)
self
.
out_cache_loc
=
torch
.
empty
(
0
,
dtype
=
torch
.
int64
,
device
=
self
.
device
)
self
.
req_pool_indices
=
torch
.
empty
(
0
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
self
.
req_pool_indices
=
torch
.
empty
(
0
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
...
@@ -1651,10 +1671,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
...
@@ -1651,10 +1671,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
if
self
.
enable_overlap
:
if
self
.
enable_overlap
:
# Do not use in-place operations in the overlap mode
# Do not use in-place operations in the overlap mode
self
.
seq_lens
=
self
.
seq_lens
+
1
self
.
seq_lens
=
self
.
seq_lens
+
1
self
.
seq_lens_cpu
=
self
.
seq_lens_cpu
+
1
self
.
orig_seq_lens
=
self
.
orig_seq_lens
+
1
self
.
orig_seq_lens
=
self
.
orig_seq_lens
+
1
else
:
else
:
# A faster in-place version
# A faster in-place version
self
.
seq_lens
.
add_
(
1
)
self
.
seq_lens
.
add_
(
1
)
self
.
seq_lens_cpu
.
add_
(
1
)
self
.
orig_seq_lens
.
add_
(
1
)
self
.
orig_seq_lens
.
add_
(
1
)
self
.
seq_lens_sum
+=
bs
self
.
seq_lens_sum
+=
bs
...
@@ -1673,7 +1695,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
...
@@ -1673,7 +1695,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self
.
req_pool_indices
,
self
.
seq_lens
-
2
self
.
req_pool_indices
,
self
.
seq_lens
-
2
]
]
self
.
out_cache_loc
=
self
.
alloc_paged_token_slots_decode
(
self
.
out_cache_loc
=
self
.
alloc_paged_token_slots_decode
(
self
.
seq_lens
,
last_loc
self
.
seq_lens
,
self
.
seq_lens_cpu
,
last_loc
)
)
self
.
req_to_token_pool
.
write
(
self
.
req_to_token_pool
.
write
(
...
@@ -1719,6 +1741,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
...
@@ -1719,6 +1741,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self
.
multimodal_inputs
=
[
self
.
multimodal_inputs
[
i
]
for
i
in
keep_indices
]
self
.
multimodal_inputs
=
[
self
.
multimodal_inputs
[
i
]
for
i
in
keep_indices
]
self
.
req_pool_indices
=
self
.
req_pool_indices
[
keep_indices_device
]
self
.
req_pool_indices
=
self
.
req_pool_indices
[
keep_indices_device
]
self
.
seq_lens
=
self
.
seq_lens
[
keep_indices_device
]
self
.
seq_lens
=
self
.
seq_lens
[
keep_indices_device
]
self
.
seq_lens_cpu
=
self
.
seq_lens_cpu
[
keep_indices
]
self
.
orig_seq_lens
=
self
.
orig_seq_lens
[
keep_indices_device
]
self
.
orig_seq_lens
=
self
.
orig_seq_lens
[
keep_indices_device
]
self
.
out_cache_loc
=
None
self
.
out_cache_loc
=
None
self
.
seq_lens_sum
=
self
.
seq_lens
.
sum
().
item
()
self
.
seq_lens_sum
=
self
.
seq_lens
.
sum
().
item
()
...
@@ -1759,6 +1782,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
...
@@ -1759,6 +1782,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
[
self
.
req_pool_indices
,
other
.
req_pool_indices
]
[
self
.
req_pool_indices
,
other
.
req_pool_indices
]
)
)
self
.
seq_lens
=
torch
.
cat
([
self
.
seq_lens
,
other
.
seq_lens
])
self
.
seq_lens
=
torch
.
cat
([
self
.
seq_lens
,
other
.
seq_lens
])
self
.
seq_lens_cpu
=
torch
.
cat
([
self
.
seq_lens_cpu
,
other
.
seq_lens_cpu
])
self
.
orig_seq_lens
=
torch
.
cat
([
self
.
orig_seq_lens
,
other
.
orig_seq_lens
])
self
.
orig_seq_lens
=
torch
.
cat
([
self
.
orig_seq_lens
,
other
.
orig_seq_lens
])
self
.
out_cache_loc
=
None
self
.
out_cache_loc
=
None
self
.
seq_lens_sum
+=
other
.
seq_lens_sum
self
.
seq_lens_sum
+=
other
.
seq_lens_sum
...
@@ -1802,9 +1826,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
...
@@ -1802,9 +1826,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self
.
sampling_info
.
grammars
=
None
self
.
sampling_info
.
grammars
=
None
seq_lens_cpu
=
(
seq_lens_cpu
=
(
seq_lens_cpu_cache
seq_lens_cpu_cache
if
seq_lens_cpu_cache
is
not
None
else
self
.
seq_lens_cpu
if
seq_lens_cpu_cache
is
not
None
else
self
.
seq_lens
.
cpu
()
)
)
global
bid
global
bid
...
...
python/sglang/srt/mem_cache/allocator.py
View file @
0b2aa8a7
...
@@ -27,7 +27,7 @@ import triton
...
@@ -27,7 +27,7 @@ import triton
import
triton.language
as
tl
import
triton.language
as
tl
from
sglang.srt.mem_cache.memory_pool
import
SWAKVPool
from
sglang.srt.mem_cache.memory_pool
import
SWAKVPool
from
sglang.srt.utils
import
get_bool_env_var
,
next_power_of_2
from
sglang.srt.utils
import
get_bool_env_var
,
get_num_new_pages
,
next_power_of_2
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
sglang.srt.mem_cache.memory_pool
import
KVCache
from
sglang.srt.mem_cache.memory_pool
import
KVCache
...
@@ -294,7 +294,6 @@ def alloc_extend_kernel(
...
@@ -294,7 +294,6 @@ def alloc_extend_kernel(
last_loc_ptr
,
last_loc_ptr
,
free_page_ptr
,
free_page_ptr
,
out_indices
,
out_indices
,
ret_values
,
bs_upper
:
tl
.
constexpr
,
bs_upper
:
tl
.
constexpr
,
page_size
:
tl
.
constexpr
,
page_size
:
tl
.
constexpr
,
max_num_extend_tokens
:
tl
.
constexpr
,
max_num_extend_tokens
:
tl
.
constexpr
,
...
@@ -323,13 +322,6 @@ def alloc_extend_kernel(
...
@@ -323,13 +322,6 @@ def alloc_extend_kernel(
sum_num_new_pages
=
tl
.
sum
(
num_new_pages
)
sum_num_new_pages
=
tl
.
sum
(
num_new_pages
)
new_page_start_loc
=
sum_num_new_pages
-
num_page_start_loc_self
new_page_start_loc
=
sum_num_new_pages
-
num_page_start_loc_self
# Return value
if
pid
==
tl
.
num_programs
(
0
)
-
1
:
merged_value
=
(
sum_num_new_pages
.
to
(
tl
.
int64
))
<<
32
|
sum_extend_lens
.
to
(
tl
.
int64
)
tl
.
store
(
ret_values
,
merged_value
)
# Part 1: fill the old partial page
# Part 1: fill the old partial page
last_loc
=
tl
.
load
(
last_loc_ptr
+
pid
)
last_loc
=
tl
.
load
(
last_loc_ptr
+
pid
)
num_part1
=
(
num_part1
=
(
...
@@ -381,7 +373,6 @@ def alloc_decode_kernel(
...
@@ -381,7 +373,6 @@ def alloc_decode_kernel(
last_loc_ptr
,
last_loc_ptr
,
free_page_ptr
,
free_page_ptr
,
out_indices
,
out_indices
,
ret_values
,
bs_upper
:
tl
.
constexpr
,
bs_upper
:
tl
.
constexpr
,
page_size
:
tl
.
constexpr
,
page_size
:
tl
.
constexpr
,
):
):
...
@@ -404,10 +395,6 @@ def alloc_decode_kernel(
...
@@ -404,10 +395,6 @@ def alloc_decode_kernel(
sum_num_new_pages
=
tl
.
sum
(
num_new_pages
)
sum_num_new_pages
=
tl
.
sum
(
num_new_pages
)
new_page_start_loc
=
sum_num_new_pages
-
num_page_start_loc_self
new_page_start_loc
=
sum_num_new_pages
-
num_page_start_loc_self
# Return value
if
pid
==
tl
.
num_programs
(
0
)
-
1
:
tl
.
store
(
ret_values
,
sum_num_new_pages
)
if
num_page_start_loc_self
==
0
:
if
num_page_start_loc_self
==
0
:
last_loc
=
tl
.
load
(
last_loc_ptr
+
pid
)
last_loc
=
tl
.
load
(
last_loc_ptr
+
pid
)
tl
.
store
(
out_indices
+
pid
,
last_loc
+
1
)
tl
.
store
(
out_indices
+
pid
,
last_loc
+
1
)
...
@@ -438,7 +425,6 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
...
@@ -438,7 +425,6 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
super
().
__init__
(
size
,
page_size
,
dtype
,
device
,
kvcache
,
need_sort
)
super
().
__init__
(
size
,
page_size
,
dtype
,
device
,
kvcache
,
need_sort
)
self
.
num_pages
=
size
//
page_size
self
.
num_pages
=
size
//
page_size
self
.
debug_mode
=
get_bool_env_var
(
"SGLANG_DEBUG_MEMORY_POOL"
)
self
.
debug_mode
=
get_bool_env_var
(
"SGLANG_DEBUG_MEMORY_POOL"
)
self
.
ret_values
=
torch
.
empty
((),
dtype
=
torch
.
int64
,
device
=
self
.
device
)
self
.
seen_max_num_extend_tokens_next_power_of_2
=
1
self
.
seen_max_num_extend_tokens_next_power_of_2
=
1
self
.
clear
()
self
.
clear
()
...
@@ -468,7 +454,9 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
...
@@ -468,7 +454,9 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
def
alloc_extend
(
def
alloc_extend
(
self
,
self
,
prefix_lens
:
torch
.
Tensor
,
prefix_lens
:
torch
.
Tensor
,
prefix_lens_cpu
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
seq_lens_cpu
:
torch
.
Tensor
,
last_loc
:
torch
.
Tensor
,
last_loc
:
torch
.
Tensor
,
extend_num_tokens
:
int
,
extend_num_tokens
:
int
,
):
):
...
@@ -497,7 +485,6 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
...
@@ -497,7 +485,6 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
last_loc
,
last_loc
,
self
.
free_pages
,
self
.
free_pages
,
out_indices
,
out_indices
,
self
.
ret_values
,
next_power_of_2
(
bs
),
next_power_of_2
(
bs
),
self
.
page_size
,
self
.
page_size
,
self
.
seen_max_num_extend_tokens_next_power_of_2
,
self
.
seen_max_num_extend_tokens_next_power_of_2
,
...
@@ -506,8 +493,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
...
@@ -506,8 +493,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
if
self
.
debug_mode
:
if
self
.
debug_mode
:
assert
len
(
torch
.
unique
(
out_indices
))
==
len
(
out_indices
)
assert
len
(
torch
.
unique
(
out_indices
))
==
len
(
out_indices
)
merged_value
=
self
.
ret_values
.
item
()
num_new_pages
=
get_num_new_pages
(
prefix_lens_cpu
,
seq_lens_cpu
,
self
.
page_size
)
num_new_pages
=
merged_value
>>
32
if
num_new_pages
>
len
(
self
.
free_pages
):
if
num_new_pages
>
len
(
self
.
free_pages
):
return
None
return
None
...
@@ -517,6 +503,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
...
@@ -517,6 +503,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
def
alloc_decode
(
def
alloc_decode
(
self
,
self
,
seq_lens
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
seq_lens_cpu
:
torch
.
Tensor
,
last_loc
:
torch
.
Tensor
,
last_loc
:
torch
.
Tensor
,
):
):
if
self
.
debug_mode
:
if
self
.
debug_mode
:
...
@@ -534,7 +521,6 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
...
@@ -534,7 +521,6 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
last_loc
,
last_loc
,
self
.
free_pages
,
self
.
free_pages
,
out_indices
,
out_indices
,
self
.
ret_values
,
next_power_of_2
(
bs
),
next_power_of_2
(
bs
),
self
.
page_size
,
self
.
page_size
,
)
)
...
@@ -542,7 +528,9 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
...
@@ -542,7 +528,9 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
if
self
.
debug_mode
:
if
self
.
debug_mode
:
assert
len
(
torch
.
unique
(
out_indices
))
==
len
(
out_indices
)
assert
len
(
torch
.
unique
(
out_indices
))
==
len
(
out_indices
)
num_new_pages
=
self
.
ret_values
.
item
()
num_new_pages
=
get_num_new_pages
(
seq_lens_cpu
-
1
,
seq_lens_cpu
,
self
.
page_size
,
decode
=
True
)
if
num_new_pages
>
len
(
self
.
free_pages
):
if
num_new_pages
>
len
(
self
.
free_pages
):
return
None
return
None
...
...
python/sglang/srt/mem_cache/allocator_ascend.py
View file @
0b2aa8a7
...
@@ -69,7 +69,9 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
...
@@ -69,7 +69,9 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
def
alloc_extend
(
def
alloc_extend
(
self
,
self
,
prefix_lens
:
torch
.
Tensor
,
prefix_lens
:
torch
.
Tensor
,
prefix_lens_cpu
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
seq_lens_cpu
:
torch
.
Tensor
,
last_loc
:
torch
.
Tensor
,
last_loc
:
torch
.
Tensor
,
extend_num_tokens
:
int
,
extend_num_tokens
:
int
,
):
):
...
@@ -80,8 +82,8 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
...
@@ -80,8 +82,8 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
num_new_pages
=
(
num_new_pages
=
(
(
(
(
seq_lens
+
self
.
page_size
-
1
)
//
self
.
page_size
(
seq_lens
_cpu
+
self
.
page_size
-
1
)
//
self
.
page_size
-
(
prefix_lens
+
self
.
page_size
-
1
)
//
self
.
page_size
-
(
prefix_lens
_cpu
+
self
.
page_size
-
1
)
//
self
.
page_size
)
)
.
sum
()
.
sum
()
.
item
()
.
item
()
...
@@ -115,6 +117,7 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
...
@@ -115,6 +117,7 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
def
alloc_decode
(
def
alloc_decode
(
self
,
self
,
seq_lens
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
seq_lens_cpu
:
torch
.
Tensor
,
last_loc
:
torch
.
Tensor
,
last_loc
:
torch
.
Tensor
,
):
):
if
self
.
debug_mode
:
if
self
.
debug_mode
:
...
@@ -123,7 +126,8 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
...
@@ -123,7 +126,8 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
)
)
need_new_pages
=
(
seq_lens
%
self
.
page_size
==
1
).
int
()
need_new_pages
=
(
seq_lens
%
self
.
page_size
==
1
).
int
()
num_new_pages
=
need_new_pages
.
sum
().
item
()
need_new_pages_cpu
=
(
seq_lens_cpu
%
self
.
page_size
==
1
).
int
()
num_new_pages
=
need_new_pages_cpu
.
sum
().
item
()
if
num_new_pages
>
len
(
self
.
free_pages
):
if
num_new_pages
>
len
(
self
.
free_pages
):
self
.
merge_and_sort_free
()
self
.
merge_and_sort_free
()
...
...
python/sglang/srt/speculative/eagle_info.py
View file @
0b2aa8a7
...
@@ -104,14 +104,21 @@ class EagleVerifyInput(SpecInput):
...
@@ -104,14 +104,21 @@ class EagleVerifyInput(SpecInput):
end_offset
=
batch
.
seq_lens
+
self
.
draft_token_num
end_offset
=
batch
.
seq_lens
+
self
.
draft_token_num
else
:
else
:
prefix_lens
=
batch
.
seq_lens
prefix_lens
=
batch
.
seq_lens
prefix_lens_cpu
=
batch
.
seq_lens_cpu
end_offset
=
prefix_lens
+
self
.
draft_token_num
end_offset
=
prefix_lens
+
self
.
draft_token_num
end_offset_cpu
=
prefix_lens_cpu
+
self
.
draft_token_num
last_loc
=
get_last_loc
(
last_loc
=
get_last_loc
(
batch
.
req_to_token_pool
.
req_to_token
,
batch
.
req_to_token_pool
.
req_to_token
,
batch
.
req_pool_indices
,
batch
.
req_pool_indices
,
prefix_lens
,
prefix_lens
,
)
)
batch
.
out_cache_loc
=
batch
.
alloc_paged_token_slots_extend
(
batch
.
out_cache_loc
=
batch
.
alloc_paged_token_slots_extend
(
prefix_lens
,
end_offset
,
last_loc
,
len
(
batch
.
input_ids
)
prefix_lens
,
prefix_lens_cpu
,
end_offset
,
end_offset_cpu
,
last_loc
,
len
(
batch
.
input_ids
),
)
)
self
.
last_loc
=
last_loc
self
.
last_loc
=
last_loc
...
@@ -380,6 +387,8 @@ class EagleVerifyInput(SpecInput):
...
@@ -380,6 +387,8 @@ class EagleVerifyInput(SpecInput):
verified_id
=
predict
[
accept_index
]
verified_id
=
predict
[
accept_index
]
evict_mask
=
torch
.
full_like
(
self
.
draft_token
,
True
,
dtype
=
torch
.
bool
)
evict_mask
=
torch
.
full_like
(
self
.
draft_token
,
True
,
dtype
=
torch
.
bool
)
evict_mask
[
accept_index
]
=
False
evict_mask
[
accept_index
]
=
False
accept_length_cpu
=
accept_length
.
cpu
()
accept_length_list
=
accept_length_cpu
.
tolist
()
if
page_size
==
1
:
if
page_size
==
1
:
# TODO: boolean array index leads to a device sync. Remove it.
# TODO: boolean array index leads to a device sync. Remove it.
...
@@ -456,13 +465,15 @@ class EagleVerifyInput(SpecInput):
...
@@ -456,13 +465,15 @@ class EagleVerifyInput(SpecInput):
else
:
else
:
batch
.
out_cache_loc
=
tgt_cache_loc
batch
.
out_cache_loc
=
tgt_cache_loc
batch
.
seq_lens
.
add_
(
accept_length
+
1
)
batch
.
seq_lens
.
add_
(
accept_length
+
1
)
batch
.
seq_lens_cpu
.
add_
(
accept_length_cpu
+
1
)
draft_input
=
EagleDraftInput
(
draft_input
=
EagleDraftInput
(
hidden_states
=
batch
.
spec_info
.
hidden_states
[
accept_index
],
hidden_states
=
batch
.
spec_info
.
hidden_states
[
accept_index
],
verified_id
=
verified_id
,
verified_id
=
verified_id
,
accept_length
=
accept_length
,
accept_length
=
accept_length
,
accept_length_cpu
=
accept_length
.
to
list
()
,
accept_length_cpu
=
accept_length
_
list
,
seq_lens_for_draft_extend
=
batch
.
seq_lens
,
seq_lens_for_draft_extend
=
batch
.
seq_lens
,
seq_lens_for_draft_extend_cpu
=
batch
.
seq_lens_cpu
,
req_pool_indices_for_draft_extend
=
batch
.
req_pool_indices
,
req_pool_indices_for_draft_extend
=
batch
.
req_pool_indices
,
)
)
...
@@ -485,15 +496,15 @@ class EagleVerifyInput(SpecInput):
...
@@ -485,15 +496,15 @@ class EagleVerifyInput(SpecInput):
next_power_of_2
(
bs
),
next_power_of_2
(
bs
),
)
)
batch
.
seq_lens
.
add_
(
accept_length
+
1
)
batch
.
seq_lens
.
add_
(
accept_length
+
1
)
batch
.
seq_lens_cpu
.
add_
(
accept_length_cpu
+
1
)
accept_length_cpu
=
accept_length
.
tolist
()
if
len
(
unfinished_accept_index
)
>
0
:
if
len
(
unfinished_accept_index
)
>
0
:
unfinished_accept_index
=
torch
.
cat
(
unfinished_accept_index
)
unfinished_accept_index
=
torch
.
cat
(
unfinished_accept_index
)
unfinished_index_device
=
torch
.
tensor
(
unfinished_index_device
=
torch
.
tensor
(
unfinished_index
,
dtype
=
torch
.
int64
,
device
=
predict
.
device
unfinished_index
,
dtype
=
torch
.
int64
,
device
=
predict
.
device
)
)
draft_input_accept_length_cpu
=
[
draft_input_accept_length_cpu
=
[
accept_length_
cpu
[
i
]
for
i
in
unfinished_index
accept_length_
list
[
i
]
for
i
in
unfinished_index
]
]
if
page_size
==
1
or
self
.
topk
==
1
:
if
page_size
==
1
or
self
.
topk
==
1
:
batch
.
out_cache_loc
=
batch
.
out_cache_loc
[
unfinished_accept_index
]
batch
.
out_cache_loc
=
batch
.
out_cache_loc
[
unfinished_accept_index
]
...
@@ -508,6 +519,7 @@ class EagleVerifyInput(SpecInput):
...
@@ -508,6 +519,7 @@ class EagleVerifyInput(SpecInput):
unfinished_index_device
,
unfinished_index_device
,
batch
.
seq_lens
,
batch
.
seq_lens
,
)
)
batch
.
seq_lens_cpu
.
add_
(
accept_length_cpu
+
1
)
filter_finished_cache_loc_kernel
[(
bs
,)](
filter_finished_cache_loc_kernel
[(
bs
,)](
batch
.
out_cache_loc
,
batch
.
out_cache_loc
,
tgt_cache_loc
,
tgt_cache_loc
,
...
@@ -525,6 +537,7 @@ class EagleVerifyInput(SpecInput):
...
@@ -525,6 +537,7 @@ class EagleVerifyInput(SpecInput):
accept_length_cpu
=
draft_input_accept_length_cpu
,
accept_length_cpu
=
draft_input_accept_length_cpu
,
accept_length
=
accept_length
[
unfinished_index_device
],
accept_length
=
accept_length
[
unfinished_index_device
],
seq_lens_for_draft_extend
=
batch
.
seq_lens
[
unfinished_index_device
],
seq_lens_for_draft_extend
=
batch
.
seq_lens
[
unfinished_index_device
],
seq_lens_for_draft_extend_cpu
=
batch
.
seq_lens_cpu
[
unfinished_index
],
req_pool_indices_for_draft_extend
=
batch
.
req_pool_indices
[
req_pool_indices_for_draft_extend
=
batch
.
req_pool_indices
[
unfinished_index_device
unfinished_index_device
],
],
...
@@ -542,7 +555,7 @@ class EagleVerifyInput(SpecInput):
...
@@ -542,7 +555,7 @@ class EagleVerifyInput(SpecInput):
draft_input
=
draft_input
,
draft_input
=
draft_input
,
logits_output
=
logits_output
,
logits_output
=
logits_output
,
verified_id
=
verified_id
,
verified_id
=
verified_id
,
accept_length_per_req_cpu
=
accept_length_
cpu
,
accept_length_per_req_cpu
=
accept_length_
list
,
accepted_indices
=
accept_index
,
accepted_indices
=
accept_index
,
)
)
...
@@ -575,6 +588,7 @@ class EagleDraftInput(SpecInput):
...
@@ -575,6 +588,7 @@ class EagleDraftInput(SpecInput):
# Inputs for draft extend
# Inputs for draft extend
# shape: (b,)
# shape: (b,)
seq_lens_for_draft_extend
:
torch
.
Tensor
=
None
seq_lens_for_draft_extend
:
torch
.
Tensor
=
None
seq_lens_for_draft_extend_cpu
:
torch
.
Tensor
=
None
req_pool_indices_for_draft_extend
:
torch
.
Tensor
=
None
req_pool_indices_for_draft_extend
:
torch
.
Tensor
=
None
def
__post_init__
(
self
):
def
__post_init__
(
self
):
...
@@ -631,6 +645,7 @@ class EagleDraftInput(SpecInput):
...
@@ -631,6 +645,7 @@ class EagleDraftInput(SpecInput):
batch
.
extend_lens
=
[
x
+
1
for
x
in
batch
.
spec_info
.
accept_length_cpu
]
batch
.
extend_lens
=
[
x
+
1
for
x
in
batch
.
spec_info
.
accept_length_cpu
]
batch
.
extend_num_tokens
=
sum
(
batch
.
extend_lens
)
batch
.
extend_num_tokens
=
sum
(
batch
.
extend_lens
)
batch
.
seq_lens
=
batch
.
spec_info
.
seq_lens_for_draft_extend
batch
.
seq_lens
=
batch
.
spec_info
.
seq_lens_for_draft_extend
batch
.
seq_lens_cpu
=
batch
.
spec_info
.
seq_lens_for_draft_extend_cpu
batch
.
req_pool_indices
=
batch
.
spec_info
.
req_pool_indices_for_draft_extend
batch
.
req_pool_indices
=
batch
.
spec_info
.
req_pool_indices_for_draft_extend
batch
.
return_logprob
=
False
batch
.
return_logprob
=
False
batch
.
return_hidden_states
=
False
batch
.
return_hidden_states
=
False
...
...
python/sglang/srt/speculative/eagle_worker.py
View file @
0b2aa8a7
...
@@ -543,6 +543,8 @@ class EAGLEWorker(TpModelWorker):
...
@@ -543,6 +543,8 @@ class EAGLEWorker(TpModelWorker):
batch
.
seq_lens
,
batch
.
seq_lens
,
self
.
speculative_num_steps
,
self
.
speculative_num_steps
,
)
)
prefix_lens_cpu
=
batch
.
seq_lens_cpu
seq_lens_cpu
=
batch
.
seq_lens_cpu
+
self
.
speculative_num_steps
extend_num_tokens
=
num_seqs
*
self
.
speculative_num_steps
extend_num_tokens
=
num_seqs
*
self
.
speculative_num_steps
else
:
else
:
# In this case, the last partial page needs to be duplicated.
# In this case, the last partial page needs to be duplicated.
...
@@ -578,14 +580,23 @@ class EAGLEWorker(TpModelWorker):
...
@@ -578,14 +580,23 @@ class EAGLEWorker(TpModelWorker):
self
.
topk
,
self
.
topk
,
self
.
page_size
,
self
.
page_size
,
)
)
prefix_lens_cpu
=
batch
.
seq_lens_cpu
# TODO(lmzheng): remove this device sync
last_page_lens
=
prefix_lens_cpu
%
self
.
page_size
extend_num_tokens
=
torch
.
sum
(
self
.
extend_lens
).
item
()
num_new_pages_per_topk
=
(
last_page_lens
+
self
.
speculative_num_steps
+
self
.
page_size
-
1
)
//
self
.
page_size
seq_lens_cpu
=
(
prefix_lens_cpu
//
self
.
page_size
*
self
.
page_size
+
num_new_pages_per_topk
*
(
self
.
page_size
*
self
.
topk
)
)
extend_num_tokens
=
torch
.
sum
((
seq_lens_cpu
-
prefix_lens_cpu
)).
item
()
out_cache_loc
,
token_to_kv_pool_state_backup
=
(
out_cache_loc
,
token_to_kv_pool_state_backup
=
(
batch
.
alloc_paged_token_slots_extend
(
batch
.
alloc_paged_token_slots_extend
(
prefix_lens
,
prefix_lens
,
prefix_lens_cpu
,
seq_lens
,
seq_lens
,
seq_lens_cpu
,
last_loc
,
last_loc
,
extend_num_tokens
,
extend_num_tokens
,
backup_state
=
True
,
backup_state
=
True
,
...
@@ -1003,6 +1014,7 @@ class EAGLEWorker(TpModelWorker):
...
@@ -1003,6 +1014,7 @@ class EAGLEWorker(TpModelWorker):
assert
isinstance
(
batch
.
spec_info
,
EagleDraftInput
)
assert
isinstance
(
batch
.
spec_info
,
EagleDraftInput
)
# Backup fields that will be modified in-place
# Backup fields that will be modified in-place
seq_lens_backup
=
batch
.
seq_lens
.
clone
()
seq_lens_backup
=
batch
.
seq_lens
.
clone
()
seq_lens_cpu_backup
=
batch
.
seq_lens_cpu
.
clone
()
req_pool_indices_backup
=
batch
.
req_pool_indices
req_pool_indices_backup
=
batch
.
req_pool_indices
accept_length_backup
=
batch
.
spec_info
.
accept_length
accept_length_backup
=
batch
.
spec_info
.
accept_length
return_logprob_backup
=
batch
.
return_logprob
return_logprob_backup
=
batch
.
return_logprob
...
@@ -1081,6 +1093,7 @@ class EAGLEWorker(TpModelWorker):
...
@@ -1081,6 +1093,7 @@ class EAGLEWorker(TpModelWorker):
ForwardMode
.
DECODE
if
not
input_is_idle
else
ForwardMode
.
IDLE
ForwardMode
.
DECODE
if
not
input_is_idle
else
ForwardMode
.
IDLE
)
)
batch
.
seq_lens
=
seq_lens_backup
batch
.
seq_lens
=
seq_lens_backup
batch
.
seq_lens_cpu
=
seq_lens_cpu_backup
batch
.
req_pool_indices
=
req_pool_indices_backup
batch
.
req_pool_indices
=
req_pool_indices_backup
batch
.
spec_info
.
accept_length
=
accept_length_backup
batch
.
spec_info
.
accept_length
=
accept_length_backup
batch
.
return_logprob
=
return_logprob_backup
batch
.
return_logprob
=
return_logprob_backup
...
...
python/sglang/srt/speculative/ngram_utils.py
View file @
0b2aa8a7
...
@@ -77,6 +77,7 @@ class NgramVerifyInput(SpecInput):
...
@@ -77,6 +77,7 @@ class NgramVerifyInput(SpecInput):
batch
.
out_cache_loc
=
batch
.
alloc_token_slots
(
len
(
batch
.
input_ids
))
batch
.
out_cache_loc
=
batch
.
alloc_token_slots
(
len
(
batch
.
input_ids
))
end_offset
=
batch
.
seq_lens
+
self
.
draft_token_num
end_offset
=
batch
.
seq_lens
+
self
.
draft_token_num
else
:
else
:
# TODO(lsyin): add prefix lens cpu here to support page size > 1
prefix_lens
=
batch
.
seq_lens
prefix_lens
=
batch
.
seq_lens
end_offset
=
prefix_lens
+
self
.
draft_token_num
end_offset
=
prefix_lens
+
self
.
draft_token_num
last_loc
=
get_last_loc
(
last_loc
=
get_last_loc
(
...
@@ -405,10 +406,13 @@ class NgramVerifyInput(SpecInput):
...
@@ -405,10 +406,13 @@ class NgramVerifyInput(SpecInput):
self
.
_fill_requests
(
batch
,
logits_output
)
self
.
_fill_requests
(
batch
,
logits_output
)
self
.
_free_cache
(
batch
,
page_size
)
self
.
_free_cache
(
batch
,
page_size
)
accept_length_cpu
=
self
.
accept_length
.
cpu
()
num_accepted_tokens
=
accept_length_cpu
.
sum
().
item
()
batch
.
seq_lens
.
add_
(
self
.
accept_length
+
1
)
batch
.
seq_lens
.
add_
(
self
.
accept_length
+
1
)
batch
.
seq_lens_
sum
=
torch
.
sum
(
batch
.
seq_lens
).
item
(
)
batch
.
seq_lens_
cpu
.
add_
(
accept_length_cpu
+
1
)
return
logits_output
,
self
.
verified_id
,
self
.
accept
_length
.
sum
().
item
()
return
logits_output
,
self
.
verified_id
,
num_
accept
ed_tokens
def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True):
    """Deliberate no-op — presumably required by the SpecInput filtering
    interface; this class keeps nothing here that needs filtering (verify
    against the base class)."""
    return None
...
...
python/sglang/srt/utils.py
View file @
0b2aa8a7
...
@@ -3250,6 +3250,30 @@ def get_extend_input_len_swa_limit(
...
@@ -3250,6 +3250,30 @@ def get_extend_input_len_swa_limit(
return
page_size
+
2
*
max
(
sliding_window_size
,
chunked_prefill_size
)
return
page_size
+
2
*
max
(
sliding_window_size
,
chunked_prefill_size
)
def get_num_new_pages(
    prefix_lens: torch.Tensor,
    seq_lens: torch.Tensor,
    page_size: int,
    decode: bool = False,
) -> int:
    """Return the total number of new KV-cache pages needed to grow every
    request from ``prefix_lens[i]`` tokens to ``seq_lens[i]`` tokens.

    Both length tensors must already live on the CPU so that reading the
    result does not force a GPU synchronization — keeping this metadata on
    the host is the whole point of the helper (it avoids blocking kernel
    launch).

    Args:
        prefix_lens: Current (already allocated) length of each request.
        seq_lens: Target length of each request after the extension.
        page_size: Number of token slots per KV-cache page.
        decode: Fast-path flag kept for interface compatibility; both
            branches of the original returned the same page count, so it no
            longer affects the result.

    Returns:
        The sum over all requests of the number of additional pages needed.
    """
    cpu_device = torch.device("cpu")
    assert prefix_lens.device == cpu_device
    assert seq_lens.device == cpu_device

    # Ceil-divide lengths into page counts before and after the extension;
    # the elementwise difference is the number of fresh pages per request.
    num_pages_after = (seq_lens + page_size - 1) // page_size
    num_pages_before = (prefix_lens + page_size - 1) // page_size
    num_new_pages = num_pages_after - num_pages_before

    # The original packed sum(num_new_pages) into the high 32 bits of one
    # word alongside sum(extend_lens) and immediately shifted the low half
    # away — equivalent to returning sum(num_new_pages) directly, except
    # that an extend-token total >= 2**32 would carry into the page count.
    # Return the page-count sum directly, which also fixes the (wrong)
    # torch.Tensor return annotation: both paths always returned .item().
    return int(torch.sum(num_new_pages).to(torch.int64).item())
class
CachedKernel
:
class
CachedKernel
:
"""
"""
Wrapper that allows kernel[grid](...) syntax with caching based on a key function.
Wrapper that allows kernel[grid](...) syntax with caching based on a key function.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment