Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4727a8af
Unverified
Commit
4727a8af
authored
Oct 06, 2025
by
Matthew Bonanni
Committed by
GitHub
Oct 06, 2025
Browse files
[Attention] Remove unused reorder_batch method (#24463)
Signed-off-by:
Matthew Bonanni
<
mbonanni@redhat.com
>
parent
b8f603ce
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
8 additions
and
64 deletions
+8
-64
tests/v1/logits_processors/test_correctness.py
tests/v1/logits_processors/test_correctness.py
+1
-1
vllm/v1/attention/backends/flashinfer.py
vllm/v1/attention/backends/flashinfer.py
+2
-4
vllm/v1/attention/backends/flex_attention.py
vllm/v1/attention/backends/flex_attention.py
+1
-10
vllm/v1/attention/backends/tree_attn.py
vllm/v1/attention/backends/tree_attn.py
+3
-14
vllm/v1/attention/backends/utils.py
vllm/v1/attention/backends/utils.py
+0
-22
vllm/v1/attention/backends/xformers.py
vllm/v1/attention/backends/xformers.py
+1
-13
No files found.
tests/v1/logits_processors/test_correctness.py
View file @
4727a8af
...
...
@@ -581,7 +581,7 @@ def _generate_fake_step_update(
persistent_batch
[:]
=
persistent_batch
[
0
:
condensed_batch_size
]
if
condensed_batch_size
>
1
:
# Simulate arbitrary
reorder_batch()
in the kernel backend
# Simulate arbitrary
batch ordering
in the kernel backend
# Generate a random number k of non-overlapping swap tuples
k
=
random
.
randint
(
0
,
condensed_batch_size
//
2
)
idxs
=
list
(
range
(
condensed_batch_size
))
...
...
vllm/v1/attention/backends/flashinfer.py
View file @
4727a8af
...
...
@@ -602,8 +602,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
)
else
:
# Regular attention (common case).
# Decodes are at the front and prefills are at the back,
# according to reorder_batch()
# Decodes are at the front and prefills are at the back.
num_prefills
=
attn_metadata
.
num_prefills
num_decodes
=
attn_metadata
.
num_decodes
if
num_prefills
>
0
:
...
...
@@ -925,8 +924,7 @@ class FlashInferImpl(AttentionImpl):
stride_order
=
FlashInferBackend
.
get_kv_cache_stride_order
()
kv_cache_permute
=
kv_cache
.
permute
(
*
stride_order
)
# Regular attention (common case).
# Decodes are at the front and prefills are at the back,
# according to reorder_batch()
# Decodes are at the front and prefills are at the back.
if
num_prefill_tokens
>
0
:
prefill_wrapper
=
attn_metadata
.
prefill_wrapper
prefill_query
=
query
[
num_decode_tokens
:]
...
...
vllm/v1/attention/backends/flex_attention.py
View file @
4727a8af
...
...
@@ -3,7 +3,7 @@
"""Attention layer with FlexAttention."""
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Optional
,
Union
from
typing
import
Optional
,
Union
import
torch
import
torch._dynamo.decorators
...
...
@@ -38,10 +38,6 @@ from vllm.v1.kv_cache_interface import AttentionSpec
logger
=
init_logger
(
__name__
)
if
TYPE_CHECKING
:
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
create_block_mask_compiled
=
torch
.
compile
(
create_block_mask
,
fullgraph
=
True
,
mode
=
"reduce-overhead"
)
...
...
@@ -600,11 +596,6 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat
self
.
q_block_size
:
int
=
16
if
is_torch_equal_or_newer
(
"2.9.0.dev0"
)
else
128
self
.
kv_block_size
:
int
=
16
if
is_torch_equal_or_newer
(
"2.9.0.dev0"
)
else
128
def
reorder_batch
(
self
,
input_batch
:
"InputBatch"
,
scheduler_output
:
"SchedulerOutput"
)
->
bool
:
return
False
def
build
(
self
,
common_prefix_len
:
int
,
...
...
vllm/v1/attention/backends/tree_attn.py
View file @
4727a8af
...
...
@@ -4,10 +4,11 @@
import
ast
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Optional
from
typing
import
Optional
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
...
...
@@ -20,17 +21,10 @@ from vllm.logger import init_logger
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
reorder_batch_to_split_decodes_and_prefills
,
split_decodes_and_prefills
,
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
if
TYPE_CHECKING
:
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
from
vllm
import
_custom_ops
as
ops
logger
=
init_logger
(
__name__
)
...
...
@@ -189,12 +183,7 @@ class TreeAttentionMetadataBuilder(AttentionMetadataBuilder[TreeAttentionMetadat
device
=
device
,
)
def
reorder_batch
(
self
,
input_batch
:
"InputBatch"
,
scheduler_output
:
"SchedulerOutput"
)
->
bool
:
return
reorder_batch_to_split_decodes_and_prefills
(
input_batch
,
scheduler_output
,
decode_threshold
=
self
.
tree_attn_bias
.
shape
[
0
]
)
self
.
reorder_batch_threshold
=
self
.
tree_attn_bias
.
shape
[
0
]
def
build
(
self
,
...
...
vllm/v1/attention/backends/utils.py
View file @
4727a8af
...
...
@@ -299,24 +299,6 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
"""
raise
NotImplementedError
def
reorder_batch
(
self
,
input_batch
:
"InputBatch"
,
scheduler_output
:
"SchedulerOutput"
)
->
bool
:
"""
Update the order of requests in the batch based on the attention
backend's needs. For example, some attention backends (namely MLA) may
want to separate requests based on if the attention computation will be
compute-bound or memory-bound.
Args:
input_batch: input batch
scheduler_output: scheduler output.
Returns:
True if the batch was modified, False otherwise.
"""
raise
NotImplementedError
def
build_for_cudagraph_capture
(
self
,
common_attn_metadata
:
CommonAttentionMetadata
)
->
M
:
...
...
@@ -828,10 +810,6 @@ def reorder_batch_to_split_decodes_and_prefills(
for
i
,
req_id
in
enumerate
(
input_batch
.
req_ids
):
num_tokens
=
scheduler_output
.
num_scheduled_tokens
[
req_id
]
# for now treat 1 scheduled token as "decode" even if it's not,
# we should update this to something like < 8 in the future but
# currently the TritonMLA._forward_decode only supports
# num_tokens = 1
if
num_tokens
<=
decode_threshold
:
decodes
.
append
(
i
)
num_decode_tokens
+=
num_tokens
...
...
vllm/v1/attention/backends/xformers.py
View file @
4727a8af
...
...
@@ -3,7 +3,7 @@
"""Attention layer with XFormersAttention."""
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Optional
from
typing
import
Optional
import
torch
...
...
@@ -19,7 +19,6 @@ from vllm.logger import init_logger
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
reorder_batch_to_split_decodes_and_prefills
,
split_decodes_and_prefills
,
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
...
...
@@ -35,10 +34,6 @@ try:
except
ImportError
:
XFORMERS_AVAILABLE
=
False
if
TYPE_CHECKING
:
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
from
vllm
import
_custom_ops
as
ops
logger
=
init_logger
(
__name__
)
...
...
@@ -223,13 +218,6 @@ class XFormersAttentionMetadataBuilder(
self
.
_num_decodes
=
0
self
.
_num_decode_tokens
=
0
def
reorder_batch
(
self
,
input_batch
:
"InputBatch"
,
scheduler_output
:
"SchedulerOutput"
)
->
bool
:
return
reorder_batch_to_split_decodes_and_prefills
(
input_batch
,
scheduler_output
,
decode_threshold
=
self
.
reorder_batch_threshold
)
def
build
(
self
,
common_prefix_len
:
int
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment