Unverified Commit 4727a8af authored by Matthew Bonanni's avatar Matthew Bonanni Committed by GitHub
Browse files

[Attention] Remove unused reorder_batch method (#24463)


Signed-off-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
parent b8f603ce
...@@ -581,7 +581,7 @@ def _generate_fake_step_update( ...@@ -581,7 +581,7 @@ def _generate_fake_step_update(
persistent_batch[:] = persistent_batch[0:condensed_batch_size] persistent_batch[:] = persistent_batch[0:condensed_batch_size]
if condensed_batch_size > 1: if condensed_batch_size > 1:
# Simulate arbitrary reorder_batch() in the kernel backend # Simulate arbitrary batch ordering in the kernel backend
# Generate a random number k of non-overlapping swap tuples # Generate a random number k of non-overlapping swap tuples
k = random.randint(0, condensed_batch_size // 2) k = random.randint(0, condensed_batch_size // 2)
idxs = list(range(condensed_batch_size)) idxs = list(range(condensed_batch_size))
......
...@@ -602,8 +602,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -602,8 +602,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
) )
else: else:
# Regular attention (common case). # Regular attention (common case).
# Decodes are at the front and prefills are at the back, # Decodes are at the front and prefills are at the back.
# according to reorder_batch()
num_prefills = attn_metadata.num_prefills num_prefills = attn_metadata.num_prefills
num_decodes = attn_metadata.num_decodes num_decodes = attn_metadata.num_decodes
if num_prefills > 0: if num_prefills > 0:
...@@ -925,8 +924,7 @@ class FlashInferImpl(AttentionImpl): ...@@ -925,8 +924,7 @@ class FlashInferImpl(AttentionImpl):
stride_order = FlashInferBackend.get_kv_cache_stride_order() stride_order = FlashInferBackend.get_kv_cache_stride_order()
kv_cache_permute = kv_cache.permute(*stride_order) kv_cache_permute = kv_cache.permute(*stride_order)
# Regular attention (common case). # Regular attention (common case).
# Decodes are at the front and prefills are at the back, # Decodes are at the front and prefills are at the back.
# according to reorder_batch()
if num_prefill_tokens > 0: if num_prefill_tokens > 0:
prefill_wrapper = attn_metadata.prefill_wrapper prefill_wrapper = attn_metadata.prefill_wrapper
prefill_query = query[num_decode_tokens:] prefill_query = query[num_decode_tokens:]
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
"""Attention layer with FlexAttention.""" """Attention layer with FlexAttention."""
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Union from typing import Optional, Union
import torch import torch
import torch._dynamo.decorators import torch._dynamo.decorators
...@@ -38,10 +38,6 @@ from vllm.v1.kv_cache_interface import AttentionSpec ...@@ -38,10 +38,6 @@ from vllm.v1.kv_cache_interface import AttentionSpec
logger = init_logger(__name__) logger = init_logger(__name__)
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
create_block_mask_compiled = torch.compile( create_block_mask_compiled = torch.compile(
create_block_mask, fullgraph=True, mode="reduce-overhead" create_block_mask, fullgraph=True, mode="reduce-overhead"
) )
...@@ -600,11 +596,6 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat ...@@ -600,11 +596,6 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat
self.q_block_size: int = 16 if is_torch_equal_or_newer("2.9.0.dev0") else 128 self.q_block_size: int = 16 if is_torch_equal_or_newer("2.9.0.dev0") else 128
self.kv_block_size: int = 16 if is_torch_equal_or_newer("2.9.0.dev0") else 128 self.kv_block_size: int = 16 if is_torch_equal_or_newer("2.9.0.dev0") else 128
def reorder_batch(
self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput"
) -> bool:
return False
def build( def build(
self, self,
common_prefix_len: int, common_prefix_len: int,
......
...@@ -4,10 +4,11 @@ ...@@ -4,10 +4,11 @@
import ast import ast
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional from typing import Optional
import torch import torch
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import ( from vllm.attention.backends.abstract import (
AttentionBackend, AttentionBackend,
AttentionImpl, AttentionImpl,
...@@ -20,17 +21,10 @@ from vllm.logger import init_logger ...@@ -20,17 +21,10 @@ from vllm.logger import init_logger
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, AttentionMetadataBuilder,
CommonAttentionMetadata, CommonAttentionMetadata,
reorder_batch_to_split_decodes_and_prefills,
split_decodes_and_prefills, split_decodes_and_prefills,
) )
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm import _custom_ops as ops
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -189,12 +183,7 @@ class TreeAttentionMetadataBuilder(AttentionMetadataBuilder[TreeAttentionMetadat ...@@ -189,12 +183,7 @@ class TreeAttentionMetadataBuilder(AttentionMetadataBuilder[TreeAttentionMetadat
device=device, device=device,
) )
def reorder_batch( self.reorder_batch_threshold = self.tree_attn_bias.shape[0]
self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput"
) -> bool:
return reorder_batch_to_split_decodes_and_prefills(
input_batch, scheduler_output, decode_threshold=self.tree_attn_bias.shape[0]
)
def build( def build(
self, self,
......
...@@ -299,24 +299,6 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]): ...@@ -299,24 +299,6 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
""" """
raise NotImplementedError raise NotImplementedError
def reorder_batch(
self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput"
) -> bool:
"""
Update the order of requests in the batch based on the attention
backend's needs. For example, some attention backends (namely MLA) may
want to separate requests based on if the attention computation will be
compute-bound or memory-bound.
Args:
input_batch: input batch
scheduler_output: scheduler output.
Returns:
True if the batch was modified, False otherwise.
"""
raise NotImplementedError
def build_for_cudagraph_capture( def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata self, common_attn_metadata: CommonAttentionMetadata
) -> M: ) -> M:
...@@ -828,10 +810,6 @@ def reorder_batch_to_split_decodes_and_prefills( ...@@ -828,10 +810,6 @@ def reorder_batch_to_split_decodes_and_prefills(
for i, req_id in enumerate(input_batch.req_ids): for i, req_id in enumerate(input_batch.req_ids):
num_tokens = scheduler_output.num_scheduled_tokens[req_id] num_tokens = scheduler_output.num_scheduled_tokens[req_id]
# for now treat 1 scheduled token as "decode" even if it's not,
# we should update this to something like < 8 in the future but
# currently the TritonMLA._forward_decode only supports
# num_tokens = 1
if num_tokens <= decode_threshold: if num_tokens <= decode_threshold:
decodes.append(i) decodes.append(i)
num_decode_tokens += num_tokens num_decode_tokens += num_tokens
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
"""Attention layer with XFormersAttention.""" """Attention layer with XFormersAttention."""
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional from typing import Optional
import torch import torch
...@@ -19,7 +19,6 @@ from vllm.logger import init_logger ...@@ -19,7 +19,6 @@ from vllm.logger import init_logger
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, AttentionMetadataBuilder,
CommonAttentionMetadata, CommonAttentionMetadata,
reorder_batch_to_split_decodes_and_prefills,
split_decodes_and_prefills, split_decodes_and_prefills,
) )
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
...@@ -35,10 +34,6 @@ try: ...@@ -35,10 +34,6 @@ try:
except ImportError: except ImportError:
XFORMERS_AVAILABLE = False XFORMERS_AVAILABLE = False
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -223,13 +218,6 @@ class XFormersAttentionMetadataBuilder( ...@@ -223,13 +218,6 @@ class XFormersAttentionMetadataBuilder(
self._num_decodes = 0 self._num_decodes = 0
self._num_decode_tokens = 0 self._num_decode_tokens = 0
def reorder_batch(
self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput"
) -> bool:
return reorder_batch_to_split_decodes_and_prefills(
input_batch, scheduler_output, decode_threshold=self.reorder_batch_threshold
)
def build( def build(
self, self,
common_prefix_len: int, common_prefix_len: int,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment