Unverified commit 4727a8af, authored by Matthew Bonanni and committed by GitHub
Browse files

[Attention] Remove unused reorder_batch method (#24463)


Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
parent b8f603ce
......@@ -581,7 +581,7 @@ def _generate_fake_step_update(
persistent_batch[:] = persistent_batch[0:condensed_batch_size]
if condensed_batch_size > 1:
# Simulate arbitrary reorder_batch() in the kernel backend
# Simulate arbitrary batch ordering in the kernel backend
# Generate a random number k of non-overlapping swap tuples
k = random.randint(0, condensed_batch_size // 2)
idxs = list(range(condensed_batch_size))
......
......@@ -602,8 +602,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
)
else:
# Regular attention (common case).
# Decodes are at the front and prefills are at the back,
# according to reorder_batch()
# Decodes are at the front and prefills are at the back.
num_prefills = attn_metadata.num_prefills
num_decodes = attn_metadata.num_decodes
if num_prefills > 0:
......@@ -925,8 +924,7 @@ class FlashInferImpl(AttentionImpl):
stride_order = FlashInferBackend.get_kv_cache_stride_order()
kv_cache_permute = kv_cache.permute(*stride_order)
# Regular attention (common case).
# Decodes are at the front and prefills are at the back,
# according to reorder_batch()
# Decodes are at the front and prefills are at the back.
if num_prefill_tokens > 0:
prefill_wrapper = attn_metadata.prefill_wrapper
prefill_query = query[num_decode_tokens:]
......
......@@ -3,7 +3,7 @@
"""Attention layer with FlexAttention."""
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Union
from typing import Optional, Union
import torch
import torch._dynamo.decorators
......@@ -38,10 +38,6 @@ from vllm.v1.kv_cache_interface import AttentionSpec
logger = init_logger(__name__)
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
create_block_mask_compiled = torch.compile(
create_block_mask, fullgraph=True, mode="reduce-overhead"
)
......@@ -600,11 +596,6 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat
self.q_block_size: int = 16 if is_torch_equal_or_newer("2.9.0.dev0") else 128
self.kv_block_size: int = 16 if is_torch_equal_or_newer("2.9.0.dev0") else 128
def reorder_batch(
self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput"
) -> bool:
return False
def build(
self,
common_prefix_len: int,
......
......@@ -4,10 +4,11 @@
import ast
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
from typing import Optional
import torch
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionImpl,
......@@ -20,17 +21,10 @@ from vllm.logger import init_logger
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder,
CommonAttentionMetadata,
reorder_batch_to_split_decodes_and_prefills,
split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm import _custom_ops as ops
logger = init_logger(__name__)
......@@ -189,12 +183,7 @@ class TreeAttentionMetadataBuilder(AttentionMetadataBuilder[TreeAttentionMetadat
device=device,
)
def reorder_batch(
self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput"
) -> bool:
return reorder_batch_to_split_decodes_and_prefills(
input_batch, scheduler_output, decode_threshold=self.tree_attn_bias.shape[0]
)
self.reorder_batch_threshold = self.tree_attn_bias.shape[0]
def build(
self,
......
......@@ -299,24 +299,6 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
"""
raise NotImplementedError
def reorder_batch(
self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput"
) -> bool:
"""
Update the order of requests in the batch based on the attention
backend's needs. For example, some attention backends (namely MLA) may
want to separate requests based on if the attention computation will be
compute-bound or memory-bound.
Args:
input_batch: input batch
scheduler_output: scheduler output.
Returns:
True if the batch was modified, False otherwise.
"""
raise NotImplementedError
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata
) -> M:
......@@ -828,10 +810,6 @@ def reorder_batch_to_split_decodes_and_prefills(
for i, req_id in enumerate(input_batch.req_ids):
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
# for now treat 1 scheduled token as "decode" even if it's not,
# we should update this to something like < 8 in the future but
# currently the TritonMLA._forward_decode only supports
# num_tokens = 1
if num_tokens <= decode_threshold:
decodes.append(i)
num_decode_tokens += num_tokens
......
......@@ -3,7 +3,7 @@
"""Attention layer with XFormersAttention."""
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
from typing import Optional
import torch
......@@ -19,7 +19,6 @@ from vllm.logger import init_logger
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder,
CommonAttentionMetadata,
reorder_batch_to_split_decodes_and_prefills,
split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec
......@@ -35,10 +34,6 @@ try:
except ImportError:
XFORMERS_AVAILABLE = False
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm import _custom_ops as ops
logger = init_logger(__name__)
......@@ -223,13 +218,6 @@ class XFormersAttentionMetadataBuilder(
self._num_decodes = 0
self._num_decode_tokens = 0
def reorder_batch(
self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput"
) -> bool:
return reorder_batch_to_split_decodes_and_prefills(
input_batch, scheduler_output, decode_threshold=self.reorder_batch_threshold
)
def build(
self,
common_prefix_len: int,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment