change / sglang · Commits · b36afed4

Unverified commit b36afed4, authored Oct 10, 2025 by cctry, committed by GitHub on Oct 10, 2025

Separate allocation logic from scheduler (#11313)

parent 9aa4502d

Showing 7 changed files with 544 additions and 398 deletions (+544 −398):

    python/sglang/bench_one_batch.py                +9    -1
    python/sglang/srt/managers/schedule_batch.py    +14   -377
    python/sglang/srt/mem_cache/common.py           +479  -0
    python/sglang/srt/speculative/eagle_info.py     +11   -6
    python/sglang/srt/speculative/eagle_worker.py   +12   -8
    python/sglang/srt/speculative/ngram_info.py     +10   -5
    test/srt/test_forward_split_prefill.py          +9    -1
python/sglang/bench_one_batch.py  (view file @ b36afed4)

@@ -51,6 +51,7 @@ import logging
 import multiprocessing
 import os
 import time
+from types import SimpleNamespace
 from typing import Tuple

 import numpy as np

@@ -257,11 +258,18 @@ def prepare_synthetic_inputs_for_latency_test(
 @torch.no_grad
 def extend(reqs, model_runner):
+    # Create dummy tree_cache for benchmarks (no prefix caching, just allocation)
+    dummy_tree_cache = SimpleNamespace(
+        page_size=1,
+        device=model_runner.device,
+        token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator,
+    )
+
     batch = ScheduleBatch.init_new(
         reqs=reqs,
         req_to_token_pool=model_runner.req_to_token_pool,
         token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator,
-        tree_cache=None,
+        tree_cache=dummy_tree_cache,
         model_config=model_runner.model_config,
         enable_overlap=False,
         spec_algorithm=SpeculativeAlgorithm.NONE,
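The benchmark can no longer pass tree_cache=None, because the new allocation helpers in python/sglang/srt/mem_cache/common.py (shown later in this diff) read a few attributes from the tree cache. A SimpleNamespace stand-in is enough on the allocation-only path. A minimal sketch of that contract; the helper name below is not from the commit, and it assumes the KV pool never runs low (otherwise evict_from_tree_cache() would also call tree_cache.evict(), which this stand-in does not provide):

from types import SimpleNamespace

def make_dummy_tree_cache(model_runner):
    # Only the attributes the allocation helpers actually touch are provided.
    return SimpleNamespace(
        page_size=1,  # page_size == 1 keeps allocation on the non-paged branch
        device=model_runner.device,  # used when building last_loc on the paged branch
        token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator,
    )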
python/sglang/srt/managers/schedule_batch.py  (view file @ b36afed4)

@@ -45,8 +45,6 @@ from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union
 import numpy as np
 import torch
-import triton
-import triton.language as tl
 from sglang.global_config import global_config
 from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject

@@ -62,6 +60,7 @@ from sglang.srt.mem_cache.allocator import (
 )
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
+from sglang.srt.mem_cache.common import alloc_for_decode, alloc_for_extend
 from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
 from sglang.srt.mem_cache.radix_cache import RadixKey
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache

@@ -70,7 +69,7 @@ from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, Forw
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import flatten_nested_list, support_triton
+from sglang.srt.utils import flatten_nested_list

 if TYPE_CHECKING:
     from sglang.srt.configs.model_config import ModelConfig

@@ -1001,158 +1000,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     def is_empty(self):
         return len(self.reqs) == 0

-    def alloc_req_slots(self, num_reqs: int, reqs: Optional[List[Req]] = None):
-        if isinstance(self.req_to_token_pool, HybridReqToTokenPool):
-            req_pool_indices = self.req_to_token_pool.alloc(num_reqs, reqs)
-        else:
-            req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
-        if req_pool_indices is None:
-            raise RuntimeError(
-                "alloc_req_slots runs out of memory. "
-                "Please set a smaller number for `--max-running-requests`. "
-                f"{self.req_to_token_pool.available_size()=}, "
-                f"{num_reqs=}, "
-            )
-        return req_pool_indices
-
-    def alloc_token_slots(self, num_tokens: int, backup_state: bool = False):
-        self._evict_tree_cache_if_needed(num_tokens)
-
-        if backup_state:
-            state = self.token_to_kv_pool_allocator.backup_state()
-
-        out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
-        if out_cache_loc is None:
-            phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
-            error_msg = (
-                f"{phase_str} out of memory. Try to lower your batch size.\n"
-                f"Try to allocate {num_tokens} tokens.\n"
-                f"{self._available_and_evictable_str()}"
-            )
-            logger.error(error_msg)
-            if self.tree_cache is not None:
-                self.tree_cache.pretty_print()
-            raise RuntimeError(error_msg)
-
-        if backup_state:
-            return out_cache_loc, state
-        else:
-            return out_cache_loc
-
-    def alloc_paged_token_slots_extend(
-        self,
-        prefix_lens: torch.Tensor,
-        prefix_lens_cpu: torch.Tensor,
-        seq_lens: torch.Tensor,
-        seq_lens_cpu: torch.Tensor,
-        last_loc: torch.Tensor,
-        extend_num_tokens: int,
-        backup_state: bool = False,
-    ):
-        # Over estimate the number of tokens: assume each request needs a new page.
-        num_tokens = (
-            extend_num_tokens
-            + len(seq_lens_cpu) * self.token_to_kv_pool_allocator.page_size
-        )
-        self._evict_tree_cache_if_needed(num_tokens)
-
-        if backup_state:
-            state = self.token_to_kv_pool_allocator.backup_state()
-
-        out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend(
-            prefix_lens,
-            prefix_lens_cpu,
-            seq_lens,
-            seq_lens_cpu,
-            last_loc,
-            extend_num_tokens,
-        )
-        if out_cache_loc is None:
-            error_msg = (
-                f"Prefill out of memory. Try to lower your batch size.\n"
-                f"Try to allocate {extend_num_tokens} tokens.\n"
-                f"{self._available_and_evictable_str()}"
-            )
-            logger.error(error_msg)
-            raise RuntimeError(error_msg)
-
-        if backup_state:
-            return out_cache_loc, state
-        else:
-            return out_cache_loc
-
-    def alloc_paged_token_slots_decode(
-        self,
-        seq_lens: torch.Tensor,
-        seq_lens_cpu: torch.Tensor,
-        last_loc: torch.Tensor,
-        backup_state: bool = False,
-    ):
-        # Over estimate the number of tokens: assume each request needs a new page.
-        num_tokens = len(seq_lens) * self.token_to_kv_pool_allocator.page_size
-        self._evict_tree_cache_if_needed(num_tokens)
-
-        if backup_state:
-            state = self.token_to_kv_pool_allocator.backup_state()
-
-        out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(
-            seq_lens, seq_lens_cpu, last_loc
-        )
-        if out_cache_loc is None:
-            error_msg = (
-                f"Decode out of memory. Try to lower your batch size.\n"
-                f"Try to allocate {len(seq_lens)} tokens.\n"
-                f"{self._available_and_evictable_str()}"
-            )
-            logger.error(error_msg)
-            raise RuntimeError(error_msg)
-
-        if backup_state:
-            return out_cache_loc, state
-        else:
-            return out_cache_loc
-
-    def write_cache_indices(
-        self,
-        req_pool_indices: List[int],
-        prefix_lens: List[int],
-        seq_lens: List[int],
-        extend_lens: List[int],
-        out_cache_loc: torch.Tensor,
-        req_pool_indices_tensor: torch.Tensor,
-        prefix_lens_tensor: torch.Tensor,
-        seq_lens_tensor: torch.Tensor,
-        extend_lens_tensor: torch.Tensor,
-        prefix_tensors: list[torch.Tensor],
-    ):
-        if support_triton(global_server_args_dict.get("attention_backend")):
-            prefix_pointers = torch.tensor(
-                [t.data_ptr() for t in prefix_tensors], device=self.device
-            )
-            # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
-            write_req_to_token_pool_triton[(len(req_pool_indices),)](
-                self.req_to_token_pool.req_to_token,
-                req_pool_indices_tensor,
-                prefix_pointers,
-                prefix_lens_tensor,
-                seq_lens_tensor,
-                extend_lens_tensor,
-                out_cache_loc,
-                self.req_to_token_pool.req_to_token.shape[1],
-            )
-        else:
-            pt = 0
-            for i in range(len(req_pool_indices)):
-                self.req_to_token_pool.write(
-                    (req_pool_indices[i], slice(0, prefix_lens[i])),
-                    prefix_tensors[i],
-                )
-                self.req_to_token_pool.write(
-                    (req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])),
-                    out_cache_loc[pt : pt + extend_lens[i]],
-                )
-                pt += extend_lens[i]
-
     def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
         self.encoder_lens_cpu = []
         self.encoder_cached = []

@@ -1253,10 +1100,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to(
             self.device, non_blocking=True
         )
-        prefix_lens_tensor = torch.tensor(
-            prefix_lens, dtype=torch.int64, device=self.device
-        )
-        prefix_lens_cpu_tensor = torch.tensor(prefix_lens, dtype=torch.int64)

         token_type_ids_tensor = None
         if len(token_type_ids) > 0:

@@ -1264,48 +1107,16 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 sum(token_type_ids, []), dtype=torch.int64
             ).to(self.device, non_blocking=True)

-        extend_lens_tensor = seq_lens_tensor - prefix_lens_tensor
-
-        # Allocate req slots
-        bs = len(self.reqs)
-        req_pool_indices = self.alloc_req_slots(bs, self.reqs)
-        req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
-            self.device, non_blocking=True
-        )
+        # Set batch fields needed by alloc_for_extend
+        self.prefix_lens = prefix_lens
+        self.extend_lens = extend_lens
+        self.seq_lens = seq_lens_tensor
+        self.seq_lens_cpu = seq_lens_cpu
+        self.extend_num_tokens = extend_num_tokens

-        # Allocate memory
-        if self.token_to_kv_pool_allocator.page_size == 1:
-            out_cache_loc = self.alloc_token_slots(extend_num_tokens)
-        else:
-            last_loc = [
-                (
-                    r.prefix_indices[-1:]
-                    if len(r.prefix_indices) > 0
-                    else torch.tensor([-1], device=self.device)
-                )
-                for r in self.reqs
-            ]
-            out_cache_loc = self.alloc_paged_token_slots_extend(
-                prefix_lens_tensor,
-                prefix_lens_cpu_tensor,
-                seq_lens_tensor,
-                seq_lens_cpu,
-                torch.cat(last_loc),
-                extend_num_tokens,
-            )
-
-        # Write allocated tokens to req_to_token_pool
-        self.write_cache_indices(
-            req_pool_indices,
-            prefix_lens,
-            seq_lens,
-            extend_lens,
-            out_cache_loc,
-            req_pool_indices_tensor,
-            prefix_lens_tensor,
-            seq_lens_tensor,
-            extend_lens_tensor,
-            [r.prefix_indices for r in reqs],
-        )
+        out_cache_loc, req_pool_indices_tensor, req_pool_indices = alloc_for_extend(
+            self
+        )

         # Set fields

@@ -1317,12 +1128,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             req.req_pool_idx = req_pool_indices[i]
             assert seq_len - pre_len == req.extend_input_len

-            if pre_len > 0:
-                if isinstance(self.tree_cache, SWAChunkCache):
-                    self.tree_cache.evict_swa(
-                        req, pre_len, self.model_config.attention_chunk_size
-                    )
-
             # If input_embeds are available, store them
             if req.input_embeds is not None:
                 # If req.input_embeds is already a list, append its content directly

@@ -1414,8 +1219,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.input_ids = input_ids_tensor
         self.req_pool_indices = req_pool_indices_tensor
-        self.seq_lens = seq_lens_tensor
-        self.seq_lens_cpu = seq_lens_cpu
         self.orig_seq_lens = orig_seq_lens_tensor
         self.out_cache_loc = out_cache_loc
         self.input_embeds = (

@@ -1439,9 +1242,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]
         self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
-        self.extend_num_tokens = extend_num_tokens
-        self.prefix_lens = prefix_lens
-        self.extend_lens = extend_lens
         self.extend_input_logprob_token_ids = extend_input_logprob_token_ids

         if self.model_config.is_encoder_decoder:

@@ -1681,11 +1481,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.output_ids = None

         if self.model_config.is_encoder_decoder:
             locs = self.encoder_lens + self.seq_lens
             self.prepare_encoder_info_decode()
         else:
             locs = self.seq_lens.clone()

         # Allocate memory
+        self.out_cache_loc = alloc_for_decode(self, token_per_req=1)

         # Update seq_lens after allocation
         if self.enable_overlap:
             # Do not use in-place operations in the overlap mode
             self.seq_lens = self.seq_lens + 1

@@ -1698,28 +1499,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             self.orig_seq_lens.add_(1)
         self.seq_lens_sum += bs

-        # free memory
-        if isinstance(self.tree_cache, SWAChunkCache):
-            for req in self.reqs:
-                self.tree_cache.evict_swa(
-                    req, req.seqlen - 1, self.model_config.attention_chunk_size
-                )
-
-        # Allocate memory
-        if self.token_to_kv_pool_allocator.page_size == 1:
-            self.out_cache_loc = self.alloc_token_slots(bs)
-        else:
-            last_loc = self.req_to_token_pool.req_to_token[
-                self.req_pool_indices, self.seq_lens - 2
-            ]
-            self.out_cache_loc = self.alloc_paged_token_slots_decode(
-                self.seq_lens, self.seq_lens_cpu, last_loc
-            )
-
-        self.req_to_token_pool.write(
-            (self.req_pool_indices, locs), self.out_cache_loc.to(torch.int32)
-        )
-
     def filter_batch(
         self,
         chunked_req_to_exclude: Optional[Union[Req, List[Req]]] = None,

@@ -1940,23 +1719,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         else:
             return self.token_to_kv_pool_allocator.available_size() >= num_tokens

-    def _available_and_evictable_str(self) -> str:
-        if self.is_hybrid:
-            full_available_size = self.token_to_kv_pool_allocator.full_available_size()
-            swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()
-            full_evictable_size = self.tree_cache.full_evictable_size()
-            swa_evictable_size = self.tree_cache.swa_evictable_size()
-            return (
-                f"Available full tokens: {full_available_size + full_evictable_size} ({full_available_size=} + {full_evictable_size=})\n"
-                f"Available swa tokens: {swa_available_size + swa_evictable_size} ({swa_available_size=} + {swa_evictable_size=})\n"
-                f"Full LRU list evictable size: {self.tree_cache.full_lru_list_evictable_size()}\n"
-                f"SWA LRU list evictable size: {self.tree_cache.swa_lru_list_evictable_size()}\n"
-            )
-        else:
-            available_size = self.token_to_kv_pool_allocator.available_size()
-            evictable_size = self.tree_cache.evictable_size()
-            return f"Available tokens: {available_size + evictable_size} ({available_size=} + {evictable_size=})\n"
-
     def __str__(self):
         return (
             f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, "

@@ -2038,128 +1800,3 @@ class ModelWorkerBatch:
     # Whether this batch is prefill-only (no token generation needed)
     is_prefill_only: bool = False

 (Removed below this point: the module-level helpers write_req_to_token_pool_triton,
 get_last_loc, get_last_loc_torch, get_last_loc_kernel, and get_last_loc_triton.
 They move verbatim into python/sglang/srt/mem_cache/common.py, reproduced in full
 in the next file of this diff.)
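After this change the scheduler no longer owns any allocation code: prepare_for_extend sets the batch fields the helper needs and then calls alloc_for_extend, while prepare_for_decode calls alloc_for_decode, which also writes the new slots into req_to_token_pool. A condensed sketch of the new control flow (illustrative only; the function names wrapping the calls are not from the commit):

from sglang.srt.mem_cache.common import alloc_for_decode, alloc_for_extend

def prepare_extend_allocation(batch):
    # By this point the batch already carries the fields alloc_for_extend reads:
    # reqs, prefix_lens, extend_lens, seq_lens, seq_lens_cpu, extend_num_tokens,
    # req_to_token_pool, tree_cache, device, and model_config.
    out_cache_loc, req_pool_indices_tensor, req_pool_indices = alloc_for_extend(batch)
    batch.out_cache_loc = out_cache_loc
    batch.req_pool_indices = req_pool_indices_tensor
    return req_pool_indices  # plain list; used to set req.req_pool_idx per request

def prepare_decode_allocation(batch):
    # alloc_for_decode also writes the new slots into req_to_token_pool;
    # the scheduler increments seq_lens afterwards.
    batch.out_cache_loc = alloc_for_decode(batch, token_per_req=1)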
python/sglang/srt/mem_cache/common.py  (new file, 0 → 100644, view file @ b36afed4)

from __future__ import annotations

import logging
from typing import TYPE_CHECKING

import torch
import triton
import triton.language as tl

from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import support_triton

if TYPE_CHECKING:
    from sglang.srt.managers.schedule_batch import Req, ScheduleBatch

logger = logging.getLogger(__name__)

GLOBAL_SERVER_ARGS_KEYS = ["attention_backend"]
global_server_args_dict = {k: getattr(ServerArgs, k) for k in GLOBAL_SERVER_ARGS_KEYS}


@triton.jit
def write_req_to_token_pool_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices,
    prefix_tensors,
    pre_lens,
    seq_lens,
    extend_lens,
    out_cache_loc,
    req_to_token_ptr_stride: tl.constexpr,
):
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(0)

    req_pool_index = tl.load(req_pool_indices + pid)
    pre_len = tl.load(pre_lens + pid)
    seq_len = tl.load(seq_lens + pid)

    prefix_tensor = tl.load(prefix_tensors + pid).to(tl.pointer_type(tl.int64))

    # write prefix
    num_loop = tl.cdiv(pre_len, BLOCK_SIZE)
    for i in range(num_loop):
        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = offset < pre_len
        value = tl.load(prefix_tensor + offset, mask=mask)
        tl.store(
            req_to_token_ptr + req_pool_index * req_to_token_ptr_stride + offset,
            value,
            mask=mask,
        )

    # NOTE: This can be slow for large bs
    cumsum_start = tl.cast(0, tl.int64)
    for i in range(pid):
        cumsum_start += tl.load(extend_lens + i)

    num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
    for i in range(num_loop):
        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = offset < (seq_len - pre_len)
        value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
        tl.store(
            req_to_token_ptr
            + req_pool_index * req_to_token_ptr_stride
            + offset
            + pre_len,
            value,
            mask=mask,
        )


def write_cache_indices(
    out_cache_loc: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    req_pool_indices_cpu: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
    prefix_lens_cpu: torch.Tensor,
    seq_lens_tensor: torch.Tensor,
    seq_lens_cpu: torch.Tensor,
    extend_lens_tensor: torch.Tensor,
    extend_lens_cpu: torch.Tensor,
    prefix_tensors: list[torch.Tensor],
    req_to_token_pool: ReqToTokenPool,
):
    if support_triton(global_server_args_dict.get("attention_backend")):
        prefix_pointers = torch.tensor(
            [t.data_ptr() for t in prefix_tensors],
            device=req_to_token_pool.device,
        )
        # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
        write_req_to_token_pool_triton[(req_pool_indices_tensor.shape[0],)](
            req_to_token_pool.req_to_token,
            req_pool_indices_tensor,
            prefix_pointers,
            prefix_lens_tensor,
            seq_lens_tensor,
            extend_lens_tensor,
            out_cache_loc,
            req_to_token_pool.req_to_token.shape[1],
        )
    else:
        pt = 0
        for i in range(req_pool_indices_cpu.shape[0]):
            req_idx = req_pool_indices_cpu[i].item()
            prefix_len = prefix_lens_cpu[i].item()
            seq_len = seq_lens_cpu[i].item()
            extend_len = extend_lens_cpu[i].item()

            req_to_token_pool.write(
                (req_idx, slice(0, prefix_len)),
                prefix_tensors[i],
            )
            req_to_token_pool.write(
                (req_idx, slice(prefix_len, seq_len)),
                out_cache_loc[pt : pt + extend_len],
            )
            pt += extend_len


def get_last_loc(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    if (
        global_server_args_dict["attention_backend"] != "ascend"
        and global_server_args_dict["attention_backend"] != "torch_native"
    ):
        impl = get_last_loc_triton
    else:
        impl = get_last_loc_torch

    return impl(req_to_token, req_pool_indices_tensor, prefix_lens_tensor)


def get_last_loc_torch(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    return torch.where(
        prefix_lens_tensor > 0,
        req_to_token[req_pool_indices_tensor, prefix_lens_tensor - 1],
        torch.full_like(prefix_lens_tensor, -1),
    )


@triton.jit
def get_last_loc_kernel(
    req_to_token,
    req_pool_indices_tensor,
    prefix_lens_tensor,
    result,
    num_tokens,
    req_to_token_stride,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    offset = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
    mask = offset < num_tokens

    prefix_lens = tl.load(prefix_lens_tensor + offset, mask=mask, other=0)
    req_pool_indices = tl.load(req_pool_indices_tensor + offset, mask=mask, other=0)

    token_mask = prefix_lens > 0
    token_index = req_pool_indices * req_to_token_stride + (prefix_lens - 1)
    tokens = tl.load(req_to_token + token_index, mask=token_mask, other=-1)

    tl.store(result + offset, tokens, mask=mask)


def get_last_loc_triton(
    req_to_token: torch.Tensor,
    req_pool_indices_tensor: torch.Tensor,
    prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
    BLOCK_SIZE = 256
    num_tokens = prefix_lens_tensor.shape[0]
    result = torch.empty_like(prefix_lens_tensor)
    grid = (triton.cdiv(num_tokens, BLOCK_SIZE),)

    get_last_loc_kernel[grid](
        req_to_token,
        req_pool_indices_tensor,
        prefix_lens_tensor,
        result,
        num_tokens,
        req_to_token.stride(0),
        BLOCK_SIZE,
    )
    return result


def alloc_token_slots(
    tree_cache: BasePrefixCache,
    num_tokens: int,
    backup_state: bool = False,
):
    allocator = tree_cache.token_to_kv_pool_allocator
    evict_from_tree_cache(tree_cache, num_tokens)

    state = None
    if backup_state:
        state = allocator.backup_state()

    out_cache_loc = allocator.alloc(num_tokens)
    if out_cache_loc is None:
        error_msg = (
            f"Out of memory. Try to lower your batch size.\n"
            f"Try to allocate {num_tokens} tokens.\n"
            f"{available_and_evictable_str(tree_cache)}"
        )
        logger.error(error_msg)
        if tree_cache is not None:
            tree_cache.pretty_print()
        raise RuntimeError(error_msg)

    return (out_cache_loc, state) if backup_state else out_cache_loc


def evict_from_tree_cache(tree_cache: BasePrefixCache | None, num_tokens: int):
    if tree_cache is None:
        return
    if isinstance(tree_cache, (SWAChunkCache, ChunkCache)):
        return

    allocator = tree_cache.token_to_kv_pool_allocator

    # Check if this is a hybrid allocator
    if hasattr(allocator, "full_available_size"):
        # Hybrid allocator
        full_available_size = allocator.full_available_size()
        swa_available_size = allocator.swa_available_size()

        if full_available_size < num_tokens or swa_available_size < num_tokens:
            full_num_tokens = max(0, num_tokens - full_available_size)
            swa_num_tokens = max(0, num_tokens - swa_available_size)
            tree_cache.evict(full_num_tokens, swa_num_tokens)
    else:
        # Standard allocator
        if allocator.available_size() < num_tokens:
            tree_cache.evict(num_tokens)


def alloc_paged_token_slots_extend(
    tree_cache: BasePrefixCache,
    prefix_lens: torch.Tensor,
    prefix_lens_cpu: torch.Tensor,
    seq_lens: torch.Tensor,
    seq_lens_cpu: torch.Tensor,
    last_loc: torch.Tensor,
    extend_num_tokens: int,
    backup_state: bool = False,
):
    # Over estimate the number of tokens: assume each request needs a new page.
    allocator = tree_cache.token_to_kv_pool_allocator
    num_tokens = extend_num_tokens + len(seq_lens_cpu) * allocator.page_size
    evict_from_tree_cache(tree_cache, num_tokens)

    state = None
    if backup_state:
        state = allocator.backup_state()

    out_cache_loc = allocator.alloc_extend(
        prefix_lens,
        prefix_lens_cpu,
        seq_lens,
        seq_lens_cpu,
        last_loc,
        extend_num_tokens,
    )
    if out_cache_loc is None:
        error_msg = (
            f"Prefill out of memory. Try to lower your batch size.\n"
            f"Try to allocate {extend_num_tokens} tokens.\n"
            f"{available_and_evictable_str(tree_cache)}"
        )
        logger.error(error_msg)
        if tree_cache is not None:
            tree_cache.pretty_print()
        raise RuntimeError(error_msg)

    return (out_cache_loc, state) if backup_state else out_cache_loc


def alloc_req_slots(
    req_to_token_pool: ReqToTokenPool,
    num_reqs: int,
    reqs: list[Req] | None,
) -> list[int]:
    """Allocate request slots from the pool."""
    if isinstance(req_to_token_pool, HybridReqToTokenPool):
        req_pool_indices = req_to_token_pool.alloc(num_reqs, reqs)
    else:
        req_pool_indices = req_to_token_pool.alloc(num_reqs)
    if req_pool_indices is None:
        raise RuntimeError(
            "alloc_req_slots runs out of memory. "
            "Please set a smaller number for `--max-running-requests`. "
            f"{req_to_token_pool.available_size()=}, "
            f"{num_reqs=}, "
        )
    return req_pool_indices


def alloc_for_extend(
    batch: ScheduleBatch,
) -> tuple[torch.Tensor, torch.Tensor, list[int]]:
    """
    Allocate KV cache for extend batch and write to req_to_token_pool.

    Returns:
        out_cache_loc: allocated cache locations
        req_pool_indices_device: request pool indices as a device tensor
        req_pool_indices: request pool indices as a list
    """
    # free out-of-window swa tokens
    if isinstance(batch.tree_cache, SWAChunkCache):
        for req, pre_len in zip(batch.reqs, batch.prefix_lens):
            batch.tree_cache.evict_swa(
                req, pre_len, batch.model_config.attention_chunk_size
            )

    bs = len(batch.reqs)
    prefix_tensors = [r.prefix_indices for r in batch.reqs]

    # Create tensors for allocation
    prefix_lens_cpu = torch.tensor(batch.prefix_lens, dtype=torch.int64)
    extend_lens_cpu = torch.tensor(batch.extend_lens, dtype=torch.int64)
    prefix_lens_device = prefix_lens_cpu.to(batch.device, non_blocking=True)
    extend_lens_device = extend_lens_cpu.to(batch.device, non_blocking=True)

    # Allocate req slots
    req_pool_indices = alloc_req_slots(batch.req_to_token_pool, bs, batch.reqs)
    req_pool_indices_cpu = torch.tensor(req_pool_indices, dtype=torch.int64)
    req_pool_indices_device = req_pool_indices_cpu.to(batch.device, non_blocking=True)

    # Allocate KV cache (throws exception on failure)
    if batch.tree_cache.page_size == 1:
        out_cache_loc = alloc_token_slots(batch.tree_cache, batch.extend_num_tokens)
    else:
        # Paged allocation - build last_loc
        last_loc = [
            (
                t[-1:]
                if len(t) > 0
                else torch.tensor([-1], device=batch.tree_cache.device)
            )
            for t in prefix_tensors
        ]
        out_cache_loc = alloc_paged_token_slots_extend(
            tree_cache=batch.tree_cache,
            prefix_lens=prefix_lens_device,
            prefix_lens_cpu=prefix_lens_cpu,
            seq_lens=batch.seq_lens,
            seq_lens_cpu=batch.seq_lens_cpu,
            last_loc=torch.cat(last_loc),
            extend_num_tokens=batch.extend_num_tokens,
        )

    # Write to req_to_token_pool
    write_cache_indices(
        out_cache_loc,
        req_pool_indices_device,
        req_pool_indices_cpu,
        prefix_lens_device,
        prefix_lens_cpu,
        batch.seq_lens,
        batch.seq_lens_cpu,
        extend_lens_device,
        extend_lens_cpu,
        prefix_tensors,
        batch.req_to_token_pool,
    )

    return out_cache_loc, req_pool_indices_device, req_pool_indices


def alloc_paged_token_slots_decode(
    tree_cache: BasePrefixCache,
    seq_lens: torch.Tensor,
    seq_lens_cpu: torch.Tensor,
    last_loc: torch.Tensor,
    token_per_req: int = 1,
) -> torch.Tensor:
    """Allocate paged KV cache for decode batch."""
    allocator = tree_cache.token_to_kv_pool_allocator
    # Over estimate the number of tokens: assume each request needs a new page.
    num_tokens = len(seq_lens) * allocator.page_size
    evict_from_tree_cache(tree_cache, num_tokens)

    out_cache_loc = allocator.alloc_decode(seq_lens, seq_lens_cpu, last_loc)
    if out_cache_loc is None:
        error_msg = (
            f"Decode out of memory. Try to lower your batch size.\n"
            f"Try to allocate {len(seq_lens) * token_per_req} tokens.\n"
            f"{available_and_evictable_str(tree_cache)}"
        )
        logger.error(error_msg)
        if tree_cache is not None:
            tree_cache.pretty_print()
        raise RuntimeError(error_msg)
    return out_cache_loc


def alloc_for_decode(batch: ScheduleBatch, token_per_req: int) -> torch.Tensor:
    """
    Allocate KV cache for decode batch and write to req_to_token_pool.

    Returns:
        out_cache_loc: allocated cache locations
    """
    if isinstance(batch.tree_cache, SWAChunkCache):
        for req in batch.reqs:
            batch.tree_cache.evict_swa(
                req, req.seqlen - 1, batch.model_config.attention_chunk_size
            )

    bs = batch.seq_lens.shape[0]
    if batch.tree_cache.page_size == 1:
        # Non-paged allocation
        out_cache_loc = alloc_token_slots(batch.tree_cache, bs * token_per_req)
    else:
        # Paged allocation
        last_loc = batch.req_to_token_pool.req_to_token[
            batch.req_pool_indices, batch.seq_lens - 1
        ]
        seq_lens_next = batch.seq_lens + token_per_req
        out_cache_loc = alloc_paged_token_slots_decode(
            tree_cache=batch.tree_cache,
            seq_lens=seq_lens_next,
            seq_lens_cpu=batch.seq_lens_cpu + token_per_req,
            last_loc=last_loc,
            token_per_req=token_per_req,
        )

    # Write to req_to_token_pool
    if batch.model_config.is_encoder_decoder:
        locs = batch.encoder_lens + batch.seq_lens
    else:
        locs = batch.seq_lens.clone()

    batch.req_to_token_pool.write(
        (batch.req_pool_indices, locs), out_cache_loc.to(torch.int32)
    )
    return out_cache_loc


def available_and_evictable_str(tree_cache) -> str:
    token_to_kv_pool_allocator = tree_cache.token_to_kv_pool_allocator
    if isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator):
        full_available_size = token_to_kv_pool_allocator.full_available_size()
        swa_available_size = token_to_kv_pool_allocator.swa_available_size()
        full_evictable_size = tree_cache.full_evictable_size()
        swa_evictable_size = tree_cache.swa_evictable_size()
        return (
            f"Available full tokens: {full_available_size + full_evictable_size} ({full_available_size=} + {full_evictable_size=})\n"
            f"Available swa tokens: {swa_available_size + swa_evictable_size} ({swa_available_size=} + {swa_evictable_size=})\n"
            f"Full LRU list evictable size: {tree_cache.full_lru_list_evictable_size()}\n"
            f"SWA LRU list evictable size: {tree_cache.swa_lru_list_evictable_size()}\n"
        )
    else:
        available_size = token_to_kv_pool_allocator.available_size()
        evictable_size = tree_cache.evictable_size()
        return f"Available tokens: {available_size + evictable_size} ({available_size=} + {evictable_size=})\n"
python/sglang/srt/speculative/eagle_info.py  (view file @ b36afed4)

@@ -10,12 +10,13 @@ from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import apply_custom_logit_processor
-from sglang.srt.managers.schedule_batch import (
-    ScheduleBatch,
-    get_last_loc,
-    global_server_args_dict,
-)
-from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+from sglang.srt.mem_cache.common import (
+    alloc_paged_token_slots_extend,
+    alloc_token_slots,
+    get_last_loc,
+)
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
 from sglang.srt.speculative.spec_info import SpecInput, SpecInputType
 from sglang.srt.speculative.spec_utils import (

@@ -100,7 +101,10 @@ class EagleVerifyInput(SpecInput):
         batch.input_ids = self.draft_token

         if page_size == 1:
-            batch.out_cache_loc = batch.alloc_token_slots(len(batch.input_ids))
+            batch.out_cache_loc = alloc_token_slots(
+                batch.tree_cache,
+                len(batch.input_ids),
+            )
             end_offset = batch.seq_lens + self.draft_token_num
         else:
             prefix_lens = batch.seq_lens

@@ -112,7 +116,8 @@ class EagleVerifyInput(SpecInput):
                 batch.req_pool_indices,
                 prefix_lens,
             )
-            batch.out_cache_loc = batch.alloc_paged_token_slots_extend(
+            batch.out_cache_loc = alloc_paged_token_slots_extend(
+                batch.tree_cache,
                 prefix_lens,
                 prefix_lens_cpu,
                 end_offset,
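For page_size > 1, the verify path first looks up each request's last prefix slot and then allocates room for the draft tokens. The diff above truncates the call, so the sketch below only illustrates the full alloc_paged_token_slots_extend signature defined in common.py; the trailing arguments and the wrapper function name are assumptions, not copied from eagle_info.py:

from sglang.srt.mem_cache.common import alloc_paged_token_slots_extend, get_last_loc

def allocate_verify_slots(batch, prefix_lens, prefix_lens_cpu, draft_token_num):
    # KV index of each request's last prefix token (-1 when there is no prefix).
    last_loc = get_last_loc(
        batch.req_to_token_pool.req_to_token,
        batch.req_pool_indices,
        prefix_lens,
    )
    end_offset = batch.seq_lens + draft_token_num  # sequence lengths after the drafts
    end_offset_cpu = batch.seq_lens_cpu + draft_token_num
    extend_num_tokens = len(batch.reqs) * draft_token_num  # draft tokens to verify
    return alloc_paged_token_slots_extend(
        batch.tree_cache,
        prefix_lens,
        prefix_lens_cpu,
        end_offset,
        end_offset_cpu,
        last_loc,
        extend_num_tokens,
    )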
python/sglang/srt/speculative/eagle_worker.py  (view file @ b36afed4)

@@ -14,13 +14,14 @@ from sglang.srt.distributed import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
-from sglang.srt.managers.schedule_batch import (
-    ScheduleBatch,
-    get_last_loc,
-    global_server_args_dict,
-)
+from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.managers.scheduler import GenerationBatchResult
 from sglang.srt.managers.tp_worker import TpModelWorker
+from sglang.srt.mem_cache.common import (
+    alloc_paged_token_slots_extend,
+    alloc_token_slots,
+    get_last_loc,
+)
 from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
     ForwardBatch,

@@ -541,8 +542,10 @@ class EAGLEWorker(TpModelWorker):
         # [ topk 0 ] [ topk 1 ]
         # [iter=0, iter=1, iter=2] [iter=0, iter=1, iter=2]
         if self.page_size == 1:
-            out_cache_loc, token_to_kv_pool_state_backup = batch.alloc_token_slots(
-                num_seqs * self.speculative_num_steps * self.topk, backup_state=True
+            out_cache_loc, token_to_kv_pool_state_backup = alloc_token_slots(
+                batch.tree_cache,
+                num_seqs * self.speculative_num_steps * self.topk,
+                backup_state=True,
             )
         else:
             if self.topk == 1:

@@ -601,7 +604,8 @@ class EAGLEWorker(TpModelWorker):
         extend_num_tokens = torch.sum((seq_lens_cpu - prefix_lens_cpu)).item()
         out_cache_loc, token_to_kv_pool_state_backup = (
-            batch.alloc_paged_token_slots_extend(
+            alloc_paged_token_slots_extend(
+                batch.tree_cache,
                 prefix_lens,
                 prefix_lens_cpu,
                 seq_lens,
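The draft phase over-allocates one KV slot per (sequence, speculative step, top-k candidate) and, via backup_state=True, also receives a snapshot of the allocator state. A short sketch of that pattern; the wrapper function name is illustrative, and the later rollback of the snapshot is outside this diff:

from sglang.srt.mem_cache.common import alloc_token_slots

def allocate_draft_slots(batch, num_seqs, speculative_num_steps, topk):
    out_cache_loc, allocator_state = alloc_token_slots(
        batch.tree_cache,
        num_seqs * speculative_num_steps * topk,
        backup_state=True,  # returns (slots, allocator state snapshot)
    )
    # The snapshot lets the worker roll the KV allocator back once verification
    # decides which draft tokens are accepted (restore step not shown in this commit).
    return out_cache_loc, allocator_state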
python/sglang/srt/speculative/ngram_info.py  (view file @ b36afed4)

@@ -16,10 +16,11 @@ import torch.nn.functional as F
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import apply_custom_logit_processor
-from sglang.srt.managers.schedule_batch import (
-    ScheduleBatch,
-    get_last_loc,
-    global_server_args_dict,
-)
+from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
+from sglang.srt.mem_cache.common import (
+    alloc_paged_token_slots_extend,
+    alloc_token_slots,
+    get_last_loc,
+)
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.speculative.spec_info import SpecInput, SpecInputType

@@ -74,7 +75,10 @@ class NgramVerifyInput(SpecInput):
         batch.input_ids = self.draft_token

         if page_size == 1:
-            batch.out_cache_loc = batch.alloc_token_slots(len(batch.input_ids))
+            batch.out_cache_loc = alloc_token_slots(
+                batch.tree_cache,
+                len(batch.input_ids),
+            )
             end_offset = batch.seq_lens + self.draft_token_num
         else:
             # TODO(lsyin): add prefix lens cpu here to support page size > 1

@@ -87,7 +91,8 @@ class NgramVerifyInput(SpecInput):
                 batch.req_pool_indices,
                 prefix_lens,
             )
-            batch.out_cache_loc = batch.alloc_paged_token_slots_extend(
+            batch.out_cache_loc = alloc_paged_token_slots_extend(
+                batch.tree_cache,
                 prefix_lens,
                 prefix_lens_cpu,
                 end_offset,
test/srt/test_forward_split_prefill.py  (view file @ b36afed4)

@@ -8,6 +8,7 @@ python3 test_forward_split_prefill.py
 """

 import unittest
+from types import SimpleNamespace

 import numpy as np
 import torch

@@ -95,11 +96,18 @@ class TestForwardSplitPrefill(CustomTestCase):
             req.logprob_start_len = len(req.origin_input_ids) - 1
             reqs.append(req)

+        # Create dummy tree_cache for tests (no prefix caching, just allocation)
+        dummy_tree_cache = SimpleNamespace(
+            page_size=1,
+            device=self.model_runner.device,
+            token_to_kv_pool_allocator=self.model_runner.token_to_kv_pool_allocator,
+        )
+
         batch = ScheduleBatch.init_new(
             reqs=reqs,
             req_to_token_pool=self.model_runner.req_to_token_pool,
             token_to_kv_pool_allocator=self.model_runner.token_to_kv_pool_allocator,
-            tree_cache=None,
+            tree_cache=dummy_tree_cache,
             model_config=self.model_config,
             enable_overlap=False,
             spec_algorithm=SpeculativeAlgorithm.NONE,