Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
141e6a05
Unverified
Commit
141e6a05
authored
Oct 29, 2025
by
Lucas Wilkinson
Committed by
GitHub
Oct 28, 2025
Browse files
[Misc] Make reorder batch also separate extends (#27367)
Signed-off-by:
Lucas Wilkinson
<
lwilkins@redhat.com
>
parent
130aa8cb
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
164 additions
and
45 deletions
+164
-45
tests/v1/attention/test_batch_reordering.py
tests/v1/attention/test_batch_reordering.py
+111
-0
vllm/v1/attention/backends/utils.py
vllm/v1/attention/backends/utils.py
+53
-45
No files found.
tests/v1/attention/test_batch_reordering.py
0 → 100644
View file @
141e6a05
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
import
numpy
as
np
import
pytest
from
vllm.v1.attention.backends.utils
import
reorder_batch_to_split_decodes_and_prefills
class MockInputBatch:
    """Lightweight stand-in for the input batch used by the reorder helper.

    It only carries the two pieces of per-request state this test file sets
    up: the request ids and the per-request computed-token counts, both
    indexed by batch position.
    """

    def __init__(self, req_ids, num_computed_tokens_cpu):
        self.req_ids = req_ids
        self.num_computed_tokens_cpu = num_computed_tokens_cpu

    def swap_states(self, i, j):
        """Exchange the state stored at batch positions ``i`` and ``j``."""
        ids = self.req_ids
        # Read both values before writing either one, exactly like a tuple swap.
        id_at_j, id_at_i = ids[j], ids[i]
        ids[i] = id_at_j
        ids[j] = id_at_i

        computed = self.num_computed_tokens_cpu
        computed_at_j, computed_at_i = computed[j], computed[i]
        computed[i] = computed_at_j
        computed[j] = computed_at_i
class MockSchedulerOutput:
    """Lightweight stand-in for the scheduler output.

    Holds only ``num_scheduled_tokens``, which the test populates with a
    mapping from request id to the number of tokens scheduled for it.
    """

    def __init__(self, num_scheduled_tokens):
        self.num_scheduled_tokens = num_scheduled_tokens
@dataclass
class ReorderTestCase:
    """One batch-reordering scenario: the input batch and the expected outcome."""

    # One (num_scheduled_tokens, num_computed_tokens) pair per request,
    # listed in initial batch order.
    requests: list[tuple[int, int]]
    # Original request indices in the order expected after reordering.
    expected_order: list[int]
    # Expected return value of the reorder helper.
    expected_modified: bool
    # Passed through as the helper's ``decode_threshold`` argument.
    decode_threshold: int = 1
# Scenarios for batch reordering, keyed by the id shown in the pytest report.
# Each request is a (num_scheduled_tokens, num_computed_tokens) pair.
REORDER_TEST_CASES = {
    # Every request is at or below the default decode threshold.
    "all_decodes": ReorderTestCase(
        requests=[(1, 10), (1, 20), (1, 30)],
        expected_order=[0, 1, 2],
        expected_modified=False,
    ),
    # No request is a decode; all fall in the same region.
    "all_prefills": ReorderTestCase(
        requests=[(100, 100), (200, 200), (300, 300)],
        expected_order=[0, 1, 2],
        expected_modified=False,
    ),
    # Decodes and prefills alternate.
    "mixed_interleaved": ReorderTestCase(
        requests=[(100, 100), (1, 10), (200, 200), (1, 20)],
        # Only swap 0↔3, keep 1 and 2 in place
        expected_order=[3, 1, 2, 0],
        expected_modified=True,
    ),
    # Decodes already precede prefills.
    "already_ordered": ReorderTestCase(
        requests=[(1, 10), (1, 20), (100, 100), (200, 200)],
        expected_order=[0, 1, 2, 3],
        expected_modified=False,
    ),
    # A batch with a single request.
    "single_request": ReorderTestCase(
        requests=[(1, 10)],
        expected_order=[0],
        expected_modified=False,
    ),
    # A decode threshold of 4 instead of the default 1.
    "higher_threshold": ReorderTestCase(
        requests=[(2, 10), (3, 20), (5, 30), (6, 40)],
        expected_order=[0, 1, 2, 3],
        expected_modified=False,
        decode_threshold=4,
    ),
    # Both decodes start behind both prefills.
    "decodes_at_end": ReorderTestCase(
        requests=[(100, 100), (200, 200), (1, 10), (1, 20)],
        expected_order=[2, 3, 0, 1],
        expected_modified=True,
    ),
    # One request of each kind, initially in reverse of the desired order.
    "decode_extend_prefill": ReorderTestCase(
        requests=[(100, 100), (10, 50), (1, 10)],
        expected_order=[2, 1, 0],
        expected_modified=True,
    ),
    # No decodes; extends and prefills alternate.
    "extend_prefill_only": ReorderTestCase(
        requests=[(100, 100), (10, 50), (200, 200), (20, 75)],
        # Only swap 0↔3, keep 1 and 2 in place
        expected_order=[3, 1, 2, 0],
        expected_modified=True,
    ),
}
@pytest.mark.parametrize(
    "test_case", REORDER_TEST_CASES.values(), ids=REORDER_TEST_CASES.keys()
)
def test_reorder_batch_to_split_decodes_and_prefills(test_case: ReorderTestCase):
    """Check the reorder helper against one scenario from REORDER_TEST_CASES.

    Builds mock batch and scheduler objects from the scenario, runs the
    helper, then checks both its return value and the resulting request order.
    """
    # Requests are named r0, r1, ... after their initial batch position.
    req_ids = []
    computed_counts = []
    num_scheduled_tokens = {}
    for i, r in enumerate(test_case.requests):
        req_id = f"r{i}"
        req_ids.append(req_id)
        num_scheduled_tokens[req_id] = r[0]
        computed_counts.append(r[1])
    num_computed_tokens = np.array(computed_counts, dtype=np.int32)

    input_batch = MockInputBatch(req_ids, num_computed_tokens)
    scheduler_output = MockSchedulerOutput(num_scheduled_tokens)

    # The helper mutates input_batch via swap_states and reports whether it did.
    modified = reorder_batch_to_split_decodes_and_prefills(
        input_batch,
        scheduler_output,
        decode_threshold=test_case.decode_threshold,
    )

    expected_req_ids = [f"r{i}" for i in test_case.expected_order]

    assert modified == test_case.expected_modified, (
        f"Expected modified={test_case.expected_modified}, got {modified}"
    )
    assert input_batch.req_ids == expected_req_ids, (
        f"Expected order {expected_req_ids}, got {input_batch.req_ids}"
    )
vllm/v1/attention/backends/utils.py
View file @
141e6a05
...
@@ -795,51 +795,59 @@ def reorder_batch_to_split_decodes_and_prefills(
...
@@ -795,51 +795,59 @@ def reorder_batch_to_split_decodes_and_prefills(
Returns:
Returns:
True if the batch was modified, False otherwise.
True if the batch was modified, False otherwise.
"""
"""
# We now want to reorder the batch so that the "decode" requests are at
# We now want to reorder the batch into decode → extend → prefill order
# the front and the "prefill" requests are at the back using the least
# where:
# amount of swaps possible. (NOTE for now we loosely use "decode" to mean
# decode: request with num_scheduled_tokens <= decode_threshold
# requests where attention is likely memory-bound and "prefill" to mean
# extend: non-decode request with existing context
# requests where attention is likely compute-bound, TODO(lucas): figure out
# prefill: non-decode request with no existing context
# a better naming here)
# NOTE for now we loosely use "decode" to mean requests where attention is
decodes
=
[]
# likely memory-bound and "prefill" to mean requests where attention is
prefills
=
[]
# likely compute-bound,
num_decode_tokens
=
0
num_reqs
=
len
(
input_batch
.
req_ids
)
num_prefill_tokens
=
0
num_scheduled_tokens
=
[
scheduler_output
.
num_scheduled_tokens
[
id
]
for
id
in
input_batch
.
req_ids
for
i
,
req_id
in
enumerate
(
input_batch
.
req_ids
):
]
num_tokens
=
scheduler_output
.
num_scheduled_tokens
[
req_id
]
num_scheduled_tokens_np
=
np
.
array
(
num_scheduled_tokens
)
if
num_tokens
<=
decode_threshold
:
num_computed_tokens_np
=
input_batch
.
num_computed_tokens_cpu
[:
num_reqs
]
decodes
.
append
(
i
)
num_decode_tokens
+=
num_tokens
is_decode
=
num_scheduled_tokens_np
<=
decode_threshold
else
:
is_extend
=
(
~
is_decode
)
&
(
num_computed_tokens_np
>
num_scheduled_tokens_np
)
prefills
.
append
(
i
)
is_prefill
=
(
~
is_decode
)
&
(
num_computed_tokens_np
==
num_scheduled_tokens_np
)
num_prefill_tokens
+=
num_tokens
# Desired order: decode → extend → prefill
# We hope that this is fairly minimal since decodes
req_regions
=
np
.
zeros
(
is_decode
.
shape
,
dtype
=
np
.
int32
)
# 0 = decode by default
# should be around for a number of iterations so hopefully they are
req_regions
[
is_extend
]
=
1
# relatively stationary (and new request are generally appended to the
req_regions
[
is_prefill
]
=
2
# persistent batch so already should be at the back)
# To achieve this we loop over the decodes in descending order and
num_decodes
=
int
(
is_decode
.
sum
())
# the prefills in ascending order. We swap decodes from the "back"
num_extends
=
int
(
is_extend
.
sum
())
# i.e. past where the last decode should be in the reordered batch with
# prefills from the front of the batch.
target_regions
=
np
.
zeros
(
num_reqs
,
dtype
=
np
.
int32
)
# `decodes` and `prefills` are already in ascending order just based on
target_regions
[
num_decodes
:
num_decodes
+
num_extends
]
=
1
# the above loop
target_regions
[
num_decodes
+
num_extends
:]
=
2
num_decodes
=
len
(
decodes
)
num_prefills
=
len
(
prefills
)
needs_swap
=
req_regions
!=
target_regions
modified_batch
=
False
if
not
needs_swap
.
any
():
for
i
in
range
(
1
,
min
(
num_decodes
,
num_prefills
)
+
1
):
return
False
# If the decode is at the "back" of the batch, i, we can swap it
# with the prefill closest to the front of the batch
# Extract indices that need swapping and sort by target region
decode_idx
=
decodes
[
num_decodes
-
i
]
swap_indices
=
np
.
where
(
needs_swap
)[
0
]
if
decode_idx
<
num_decodes
:
sorted_order
=
np
.
argsort
(
req_regions
[
needs_swap
],
kind
=
"stable"
)
break
dest_indices
=
swap_indices
[
sorted_order
]
input_batch
.
swap_states
(
prefills
[
i
-
1
],
decode_idx
)
src_dest_map
=
{
int
(
src
):
int
(
dst
)
for
src
,
dst
in
zip
(
swap_indices
,
dest_indices
)}
modified_batch
=
True
for
src
in
src_dest_map
:
return
modified_batch
dst
=
src_dest_map
[
src
]
while
src
!=
dst
:
input_batch
.
swap_states
(
src
,
dst
)
# Mark dst as done by updating its destination to itself
next_dst
=
src_dest_map
.
get
(
dst
,
dst
)
src_dest_map
[
dst
]
=
dst
dst
=
next_dst
return
True
def
reshape_query_for_spec_decode
(
query
:
torch
.
Tensor
,
batch_size
:
int
)
->
torch
.
Tensor
:
def
reshape_query_for_spec_decode
(
query
:
torch
.
Tensor
,
batch_size
:
int
)
->
torch
.
Tensor
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment