change / sglang / Commits / f3764c26

Commit f3764c26 (unverified), authored Oct 07, 2025 by cctry, committed by GitHub on Oct 07, 2025

Clean match_prefix and prepare_for_extend for mem cache V2 (#11200)

parent 7ba3de0e

Showing 5 changed files with 112 additions and 84 deletions (+112 -84):

- python/sglang/bench_one_batch.py (+0 -2)
- python/sglang/srt/managers/schedule_batch.py (+110 -79)
- python/sglang/srt/managers/schedule_policy.py (+1 -1)
- python/sglang/srt/mem_cache/chunk_cache.py (+1 -1)
- test/srt/test_forward_split_prefill.py (+0 -1)
python/sglang/bench_one_batch.py (view file @ f3764c26)

@@ -204,7 +204,6 @@ def prepare_inputs_for_correctness_test(bench_args, tokenizer, custom_prompts):
             origin_input_ids=tmp_input_ids,
             sampling_params=sampling_params,
         )
-        req.prefix_indices = []
         req.fill_ids = req.origin_input_ids
         req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
         req.logprob_start_len = len(req.origin_input_ids) - 1

@@ -248,7 +247,6 @@ def prepare_synthetic_inputs_for_latency_test(
             origin_input_ids=list(input_ids[i]),
             sampling_params=sampling_params,
         )
-        req.prefix_indices = []
         req.fill_ids = req.origin_input_ids
         req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
         req.logprob_start_len = len(req.origin_input_ids) - 1
python/sglang/srt/managers/schedule_batch.py (view file @ f3764c26)

@@ -539,7 +539,7 @@ class Req:
         # Prefix info
         # The indices to kv cache for the shared prefix.
-        self.prefix_indices: torch.Tensor = []
+        self.prefix_indices: torch.Tensor = torch.empty((0,), dtype=torch.int64)
         # Number of tokens to run prefill.
         self.extend_input_len = 0
         # The relative logprob_start_len in an extend batch
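With this change `prefix_indices` is always an int64 tensor, so downstream code no longer has to special-case the empty Python list; that is also why the manual `req.prefix_indices = []` resets were dropped from the benchmark and test helpers above and below. A minimal sketch of the invariant, assuming nothing beyond PyTorch:

```python
import torch

# Empty prefix as Req.__init__ now initializes it.
prefix_indices = torch.empty((0,), dtype=torch.int64)

# len() and slicing behave like the old empty list ...
assert len(prefix_indices) == 0
assert prefix_indices[-1:].numel() == 0

# ... but tensor-only operations now work without an isinstance check.
new_indices = torch.arange(4, dtype=torch.int64)
merged = torch.cat([prefix_indices, new_indices])
print(merged)  # tensor([0, 1, 2, 3])
```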
@@ -691,11 +691,16 @@ class Req:
         # Whether request reached finished condition
         return self.finished_reason is not None

-    def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
+    def init_next_round_input(
+        self,
+        tree_cache: Optional[BasePrefixCache] = None,
+    ):
         self.fill_ids = self.origin_input_ids + self.output_ids
+        input_len = len(self.fill_ids)
+        # NOTE: the matched length is at most 1 less than the input length to enable logprob computation
+        max_prefix_len = input_len - 1
+        if self.return_logprob:
+            max_prefix_len = min(max_prefix_len, self.logprob_start_len)
+        max_prefix_len = max(max_prefix_len, 0)
+        token_ids = self.fill_ids[:max_prefix_len]
         if tree_cache is not None:
             (
                 self.prefix_indices,
@@ -703,31 +708,11 @@ class Req:
                 self.last_host_node,
                 self.host_hit_length,
             ) = tree_cache.match_prefix(
-                key=RadixKey(
-                    token_ids=self.adjust_max_prefix_ids(), extra_key=self.extra_key
-                ),
+                key=RadixKey(token_ids=token_ids, extra_key=self.extra_key)
             )
             self.last_matched_prefix_len = len(self.prefix_indices)
         self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

-    def adjust_max_prefix_ids(self):
-        self.fill_ids = self.origin_input_ids + self.output_ids
-        input_len = len(self.fill_ids)
-
-        # FIXME: To work around some bugs in logprob computation, we need to ensure each
-        # request has at least one token. Later, we can relax this requirement and use `input_len`.
-        max_prefix_len = input_len - 1
-
-        if self.sampling_params.max_new_tokens > 0:
-            # Need at least one token to compute logits
-            max_prefix_len = min(max_prefix_len, input_len - 1)
-
-        if self.return_logprob:
-            max_prefix_len = min(max_prefix_len, self.logprob_start_len)
-
-        max_prefix_len = max(max_prefix_len, 0)
-
-        return self.fill_ids[:max_prefix_len]
-
     # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
     def init_incremental_detokenize(self):
         first_iter = self.surr_offset is None or self.read_offset is None
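After this cleanup the prefix-length clamping lives directly in `init_next_round_input`, and the `adjust_max_prefix_ids` helper is gone (its `max_new_tokens` branch was effectively a no-op, since it re-applied `input_len - 1`). A standalone sketch of the clamping rule, with illustrative values only:

```python
def max_matchable_prefix_len(input_len: int, return_logprob: bool, logprob_start_len: int) -> int:
    # Match at most input_len - 1 tokens so at least one token remains
    # for computing logits / logprobs.
    max_prefix_len = input_len - 1
    if return_logprob:
        # Never reuse cache past the point where logprobs must start.
        max_prefix_len = min(max_prefix_len, logprob_start_len)
    return max(max_prefix_len, 0)

print(max_matchable_prefix_len(8, False, 0))  # 7
print(max_matchable_prefix_len(8, True, 3))   # 3
print(max_matchable_prefix_len(1, True, 0))   # 0
```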
@@ -808,7 +793,7 @@ class Req:
         return

     def reset_for_retract(self):
-        self.prefix_indices = []
+        self.prefix_indices = torch.empty((0,), dtype=torch.int64)
         self.last_node = None
         self.swa_uuid_for_lock = None
         self.extend_input_len = 0
@@ -1124,6 +1109,47 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         else:
             return out_cache_loc

+    def write_cache_indices(
+        self,
+        req_pool_indices: List[int],
+        prefix_lens: List[int],
+        seq_lens: List[int],
+        extend_lens: List[int],
+        out_cache_loc: torch.Tensor,
+        req_pool_indices_tensor: torch.Tensor,
+        prefix_lens_tensor: torch.Tensor,
+        seq_lens_tensor: torch.Tensor,
+        extend_lens_tensor: torch.Tensor,
+        prefix_tensors: list[torch.Tensor],
+    ):
+        if support_triton(global_server_args_dict.get("attention_backend")):
+            prefix_pointers = torch.tensor(
+                [t.data_ptr() for t in prefix_tensors], device=self.device
+            )
+            # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
+            write_req_to_token_pool_triton[(len(req_pool_indices),)](
+                self.req_to_token_pool.req_to_token,
+                req_pool_indices_tensor,
+                prefix_pointers,
+                prefix_lens_tensor,
+                seq_lens_tensor,
+                extend_lens_tensor,
+                out_cache_loc,
+                self.req_to_token_pool.req_to_token.shape[1],
+            )
+        else:
+            pt = 0
+            for i in range(len(req_pool_indices)):
+                self.req_to_token_pool.write(
+                    (req_pool_indices[i], slice(0, prefix_lens[i])),
+                    prefix_tensors[i],
+                )
+                self.req_to_token_pool.write(
+                    (req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])),
+                    out_cache_loc[pt : pt + extend_lens[i]],
+                )
+                pt += extend_lens[i]
+
     def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
         self.encoder_lens_cpu = []
         self.encoder_cached = []
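The new `write_cache_indices` helper gathers what `prepare_for_extend` previously did inline (see the removed blocks in the hunks below): it writes both the matched prefix indices and the freshly allocated extension slots into the req-to-token map, either through the Triton kernel or a plain Python loop. A minimal sketch of the fallback path's semantics against a toy table (names and shapes here are illustrative, not sglang's actual API):

```python
import torch

# Toy stand-in for req_to_token_pool.req_to_token:
# one row per request slot, one column per token position.
req_to_token = torch.zeros((2, 16), dtype=torch.int64)

def write_cache_indices_fallback(req_to_token, req_pool_indices, prefix_lens,
                                 seq_lens, extend_lens, out_cache_loc, prefix_tensors):
    pt = 0
    for i, slot in enumerate(req_pool_indices):
        # Shared prefix: KV indices already matched in the radix/chunk cache.
        req_to_token[slot, : prefix_lens[i]] = prefix_tensors[i]
        # Extension: freshly allocated KV slots for the tokens being prefilled.
        req_to_token[slot, prefix_lens[i] : seq_lens[i]] = out_cache_loc[pt : pt + extend_lens[i]]
        pt += extend_lens[i]

write_cache_indices_fallback(
    req_to_token,
    req_pool_indices=[0, 1],
    prefix_lens=[2, 0],
    seq_lens=[5, 3],
    extend_lens=[3, 3],
    out_cache_loc=torch.tensor([100, 101, 102, 200, 201, 202]),
    prefix_tensors=[torch.tensor([10, 11]), torch.empty((0,), dtype=torch.int64)],
)
print(req_to_token[0, :5])  # tensor([ 10,  11, 100, 101, 102])
print(req_to_token[1, :3])  # tensor([200, 201, 202])
```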
@@ -1201,10 +1227,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     def prepare_for_extend(self):
         self.forward_mode = ForwardMode.EXTEND

-        # Allocate req slots
-        bs = len(self.reqs)
-        req_pool_indices = self.alloc_req_slots(bs, self.reqs)
-
         # Init tensors
         reqs = self.reqs
         input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
@@ -1218,9 +1240,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             r.token_type_ids for r in reqs if r.token_type_ids is not None
         ]

-        req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
-            self.device, non_blocking=True
-        )
         input_ids_tensor = torch.tensor(
             list(chain.from_iterable(input_ids)), dtype=torch.int64
         ).to(self.device, non_blocking=True)
@@ -1244,7 +1263,49 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         extend_lens_tensor = seq_lens_tensor - prefix_lens_tensor

-        # Copy prefix and do some basic check
+        # Allocate req slots
+        bs = len(self.reqs)
+        req_pool_indices = self.alloc_req_slots(bs, self.reqs)
+        req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
+            self.device, non_blocking=True
+        )
+
+        # Allocate memory
+        if self.token_to_kv_pool_allocator.page_size == 1:
+            out_cache_loc = self.alloc_token_slots(extend_num_tokens)
+        else:
+            last_loc = [
+                (
+                    r.prefix_indices[-1:]
+                    if len(r.prefix_indices) > 0
+                    else torch.tensor([-1], device=self.device)
+                )
+                for r in self.reqs
+            ]
+            out_cache_loc = self.alloc_paged_token_slots_extend(
+                prefix_lens_tensor,
+                prefix_lens_cpu_tensor,
+                seq_lens_tensor,
+                seq_lens_cpu,
+                torch.cat(last_loc),
+                extend_num_tokens,
+            )
+
+        # Write allocated tokens to req_to_token_pool
+        self.write_cache_indices(
+            req_pool_indices,
+            prefix_lens,
+            seq_lens,
+            extend_lens,
+            out_cache_loc,
+            req_pool_indices_tensor,
+            prefix_lens_tensor,
+            seq_lens_tensor,
+            extend_lens_tensor,
+            [r.prefix_indices for r in reqs],
+        )
+
+        # Set fields
         input_embeds = []
         extend_input_logprob_token_ids = []
         multimodal_inputs = []
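For paged allocation, the extend path now derives each request's last occupied KV location directly from `prefix_indices` instead of reading it back from the req-to-token table with `get_last_loc` (removed further down), which lets allocation happen before anything is written to the pool. A small sketch of that per-request rule, with illustrative values rather than the real allocator:

```python
import torch

def last_loc_from_prefix(prefix_indices: torch.Tensor) -> torch.Tensor:
    # Last KV index of the cached prefix, or -1 when there is no prefix,
    # mirroring the list comprehension added in prepare_for_extend.
    return prefix_indices[-1:] if len(prefix_indices) > 0 else torch.tensor([-1])

reqs = [torch.tensor([4, 5, 6], dtype=torch.int64),
        torch.empty((0,), dtype=torch.int64)]
last_loc = torch.cat([last_loc_from_prefix(p) for p in reqs])
print(last_loc)  # tensor([ 6, -1])
```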
@@ -1254,9 +1315,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             assert seq_len - pre_len == req.extend_input_len

             if pre_len > 0:
-                self.req_to_token_pool.write(
-                    (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
-                )
                 if isinstance(self.tree_cache, SWAChunkCache):
                     self.tree_cache.evict_swa(
                         req, pre_len, self.model_config.attention_chunk_size
@@ -1351,25 +1409,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         else:
             extend_input_logprob_token_ids = None

-        # Allocate memory
-        if self.token_to_kv_pool_allocator.page_size == 1:
-            out_cache_loc = self.alloc_token_slots(extend_num_tokens)
-        else:
-            last_loc = get_last_loc(
-                self.req_to_token_pool.req_to_token,
-                req_pool_indices_tensor,
-                prefix_lens_tensor,
-            )
-            out_cache_loc = self.alloc_paged_token_slots_extend(
-                prefix_lens_tensor,
-                prefix_lens_cpu_tensor,
-                seq_lens_tensor,
-                seq_lens_cpu,
-                last_loc,
-                extend_num_tokens,
-            )
-
-        # Set fields
         self.input_ids = input_ids_tensor
         self.req_pool_indices = req_pool_indices_tensor
         self.seq_lens = seq_lens_tensor
@@ -1402,28 +1441,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.extend_lens = extend_lens
         self.extend_input_logprob_token_ids = extend_input_logprob_token_ids

-        # Write to req_to_token_pool
-        if support_triton(global_server_args_dict.get("attention_backend")):
-            # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
-            write_req_to_token_pool_triton[(bs,)](
-                self.req_to_token_pool.req_to_token,
-                req_pool_indices_tensor,
-                prefix_lens_tensor,
-                seq_lens_tensor,
-                extend_lens_tensor,
-                out_cache_loc,
-                self.req_to_token_pool.req_to_token.shape[1],
-            )
-        else:
-            pt = 0
-            for i in range(bs):
-                self.req_to_token_pool.write(
-                    (req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])),
-                    out_cache_loc[pt : pt + extend_lens[i]],
-                )
-                pt += extend_lens[i]
-
         if self.model_config.is_encoder_decoder:
             self.prepare_encoder_info_extend(input_ids, seq_lens)
@@ -2024,6 +2041,7 @@ class ModelWorkerBatch:
 def write_req_to_token_pool_triton(
     req_to_token_ptr,  # [max_batch, max_context_len]
     req_pool_indices,
+    prefix_tensors,
     pre_lens,
     seq_lens,
     extend_lens,
@@ -2036,6 +2054,19 @@ def write_req_to_token_pool_triton(
     req_pool_index = tl.load(req_pool_indices + pid)
     pre_len = tl.load(pre_lens + pid)
     seq_len = tl.load(seq_lens + pid)
+    prefix_tensor = tl.load(prefix_tensors + pid).to(tl.pointer_type(tl.int64))
+
+    # write prefix
+    num_loop = tl.cdiv(pre_len, BLOCK_SIZE)
+    for i in range(num_loop):
+        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+        mask = offset < pre_len
+        value = tl.load(prefix_tensor + offset, mask=mask)
+        tl.store(
+            req_to_token_ptr + req_pool_index * req_to_token_ptr_stride + offset,
+            value,
+            mask=mask,
+        )

     # NOTE: This can be slow for large bs
     cumsum_start = tl.cast(0, tl.int64)
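The Triton kernel now receives a per-request pointer to its prefix tensor and copies the cached prefix indices into `req_to_token` itself, block by block, instead of relying on the host to have written the prefix beforehand (the host-side write was removed in the `prepare_for_extend` hunks above). A rough PyTorch rendering of what the added loop computes for one request; the pointer arithmetic and masking are paraphrased, not the actual kernel:

```python
import torch

def write_prefix_rowwise(req_to_token_row: torch.Tensor,
                         prefix_tensor: torch.Tensor,
                         pre_len: int,
                         block_size: int = 512) -> None:
    # Equivalent effect of the kernel's "write prefix" loop for one request:
    # copy pre_len cached KV indices into the request's row, block by block.
    for start in range(0, pre_len, block_size):
        end = min(start + block_size, pre_len)
        req_to_token_row[start:end] = prefix_tensor[start:end]

row = torch.zeros(8, dtype=torch.int64)
write_prefix_rowwise(row, torch.tensor([7, 8, 9], dtype=torch.int64), pre_len=3, block_size=2)
print(row)  # tensor([7, 8, 9, 0, 0, 0, 0, 0])
```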
python/sglang/srt/managers/schedule_policy.py (view file @ f3764c26)

@@ -174,7 +174,7 @@ class SchedulePolicy:
         self.waiting_queue_radix_tree.reset()

         for r in waiting_queue:
-            prefix_ids = r.adjust_max_prefix_ids()
+            prefix_ids = r.origin_input_ids + r.output_ids
             extra_key = r.extra_key

             # NOTE: the prefix_indices must always be aligned with last_node
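Since `adjust_max_prefix_ids` was removed from `Req`, the scheduler now keys its waiting-queue radix tree on the raw concatenation of input and output ids; presumably the one-token clamp only matters at actual match time in `init_next_round_input`, not for this ordering heuristic. A trivial sketch of the new key, with illustrative token ids:

```python
# Hypothetical request state; the scheduler's key is simply the concatenation.
origin_input_ids = [101, 7592, 2088]
output_ids = [999]
prefix_ids = origin_input_ids + output_ids  # [101, 7592, 2088, 999]
```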
python/sglang/srt/mem_cache/chunk_cache.py (view file @ f3764c26)

@@ -60,7 +60,7 @@ class ChunkCache(BasePrefixCache):
         ]

         # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
-        req.prefix_indices = kv_indices
+        req.prefix_indices = kv_indices.to(dtype=torch.int64, copy=True)

     def evict(self, num_tokens: int):
         pass
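The `copy=True` cast presumably keeps `req.prefix_indices` as an independent int64 tensor rather than a view into the pool's index buffer, matching the dtype `Req` now uses by default. A small illustration of the difference, using generic tensors rather than the real KV pool:

```python
import torch

kv_indices = torch.arange(4, dtype=torch.int32)

view_like = kv_indices                                   # shares storage, stays int32
detached = kv_indices.to(dtype=torch.int64, copy=True)   # independent int64 copy

kv_indices[0] = 99
print(view_like[0].item())  # 99 -- follows the original buffer
print(detached[0].item())   # 0  -- unaffected by later mutation
print(detached.dtype)       # torch.int64
```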
test/srt/test_forward_split_prefill.py (view file @ f3764c26)

@@ -90,7 +90,6 @@ class TestForwardSplitPrefill(CustomTestCase):
                 origin_input_ids=list(input_ids[i]),
                 sampling_params=sampling_params,
             )
-            req.prefix_indices = []
             req.fill_ids = req.origin_input_ids
             req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
             req.logprob_start_len = len(req.origin_input_ids) - 1