Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b5d70751
Unverified
Commit
b5d70751
authored
Oct 30, 2025
by
Lucas Wilkinson
Committed by
GitHub
Oct 29, 2025
Browse files
[BugFix] Reordering extend logic fix (#27739)
Signed-off-by:
Lucas Wilkinson
<
lwilkins@redhat.com
>
parent
b8c48c5d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
23 additions
and
8 deletions
+23
-8
tests/v1/attention/test_batch_reordering.py
tests/v1/attention/test_batch_reordering.py
+18
-3
vllm/v1/attention/backends/utils.py
vllm/v1/attention/backends/utils.py
+5
-5
No files found.
tests/v1/attention/test_batch_reordering.py
View file @
b5d70751
...
@@ -53,7 +53,7 @@ REORDER_TEST_CASES = {
...
@@ -53,7 +53,7 @@ REORDER_TEST_CASES = {
expected_modified
=
True
,
expected_modified
=
True
,
),
),
"already_ordered"
:
ReorderTestCase
(
"already_ordered"
:
ReorderTestCase
(
requests
=
[(
1
,
10
),
(
1
,
20
),
(
100
,
100
),
(
200
,
20
0
)],
requests
=
[(
1
,
10
),
(
1
,
20
),
(
100
,
100
),
(
200
,
0
)],
expected_order
=
[
0
,
1
,
2
,
3
],
expected_order
=
[
0
,
1
,
2
,
3
],
expected_modified
=
False
,
expected_modified
=
False
,
),
),
...
@@ -74,15 +74,30 @@ REORDER_TEST_CASES = {
...
@@ -74,15 +74,30 @@ REORDER_TEST_CASES = {
expected_modified
=
True
,
expected_modified
=
True
,
),
),
"decode_extend_prefill"
:
ReorderTestCase
(
"decode_extend_prefill"
:
ReorderTestCase
(
requests
=
[(
100
,
10
0
),
(
10
,
50
),
(
1
,
10
)],
requests
=
[(
100
,
0
),
(
10
,
50
),
(
1
,
10
)],
expected_order
=
[
2
,
1
,
0
],
expected_order
=
[
2
,
1
,
0
],
expected_modified
=
True
,
expected_modified
=
True
,
),
),
"extend_prefill_only"
:
ReorderTestCase
(
"extend_prefill_only"
:
ReorderTestCase
(
requests
=
[(
100
,
10
0
),
(
10
,
50
),
(
200
,
20
0
),
(
20
,
75
)],
requests
=
[(
100
,
0
),
(
10
,
50
),
(
200
,
0
),
(
20
,
75
)],
expected_order
=
[
3
,
1
,
2
,
0
],
# Only swap 0↔3, keep 1 and 2 in place
expected_order
=
[
3
,
1
,
2
,
0
],
# Only swap 0↔3, keep 1 and 2 in place
expected_modified
=
True
,
expected_modified
=
True
,
),
),
"complicated_mixed_interleaved"
:
ReorderTestCase
(
requests
=
[
(
1
,
20
),
(
1
,
50
),
(
374
,
0
),
(
300
,
20
),
(
1
,
20
),
(
256
,
0
),
(
1
,
5
),
(
27
,
0
),
(
1
,
4
),
],
expected_order
=
[
0
,
1
,
6
,
8
,
4
,
3
,
2
,
7
,
5
],
expected_modified
=
True
,
),
}
}
...
...
vllm/v1/attention/backends/utils.py
View file @
b5d70751
...
@@ -811,8 +811,8 @@ def reorder_batch_to_split_decodes_and_prefills(
...
@@ -811,8 +811,8 @@ def reorder_batch_to_split_decodes_and_prefills(
num_computed_tokens_np
=
input_batch
.
num_computed_tokens_cpu
[:
num_reqs
]
num_computed_tokens_np
=
input_batch
.
num_computed_tokens_cpu
[:
num_reqs
]
is_decode
=
num_scheduled_tokens_np
<=
decode_threshold
is_decode
=
num_scheduled_tokens_np
<=
decode_threshold
is_extend
=
(
~
is_decode
)
&
(
num_computed_tokens_np
>
num_scheduled_tokens_np
)
is_extend
=
(
~
is_decode
)
&
(
num_computed_tokens_np
>
0
)
is_prefill
=
(
~
is_decode
)
&
(
num_computed_tokens_np
==
num_scheduled_tokens_np
)
is_prefill
=
(
~
is_decode
)
&
(
num_computed_tokens_np
==
0
)
# Desired order: decode → extend → prefill
# Desired order: decode → extend → prefill
req_regions
=
np
.
zeros
(
is_decode
.
shape
,
dtype
=
np
.
int32
)
# 0 = decode by default
req_regions
=
np
.
zeros
(
is_decode
.
shape
,
dtype
=
np
.
int32
)
# 0 = decode by default
...
@@ -832,11 +832,11 @@ def reorder_batch_to_split_decodes_and_prefills(
...
@@ -832,11 +832,11 @@ def reorder_batch_to_split_decodes_and_prefills(
return
False
return
False
# Extract indices that need swapping and sort by target region
# Extract indices that need swapping and sort by target region
swap
_indices
=
np
.
where
(
needs_swap
)[
0
]
orig
_indices
=
np
.
where
(
needs_swap
)[
0
]
sorted_order
=
np
.
argsort
(
req_regions
[
needs_swap
],
kind
=
"stable"
)
sorted_order
=
np
.
argsort
(
req_regions
[
needs_swap
],
kind
=
"stable"
)
dest
_indices
=
swap
_indices
[
sorted_order
]
src
_indices
=
orig
_indices
[
sorted_order
]
src_dest_map
=
{
int
(
src
):
int
(
dst
)
for
src
,
dst
in
zip
(
s
wap
_indices
,
dest
_indices
)}
src_dest_map
=
{
int
(
src
):
int
(
dst
)
for
src
,
dst
in
zip
(
s
rc
_indices
,
orig
_indices
)}
for
src
in
src_dest_map
:
for
src
in
src_dest_map
:
dst
=
src_dest_map
[
src
]
dst
=
src_dest_map
[
src
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment