Unverified commit e6750d0b, authored by Woosuk Kwon, committed by GitHub
Browse files

[V0 Deprecation] Remove unused classes in attention (#25541)


Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
parent 8c853050
...@@ -2,9 +2,7 @@ ...@@ -2,9 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.attention.backends.abstract import (AttentionBackend, from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata, AttentionMetadata, AttentionType)
AttentionMetadataBuilder,
AttentionState, AttentionType)
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend from vllm.attention.selector import get_attn_backend
...@@ -13,7 +11,5 @@ __all__ = [ ...@@ -13,7 +11,5 @@ __all__ = [
"AttentionBackend", "AttentionBackend",
"AttentionMetadata", "AttentionMetadata",
"AttentionType", "AttentionType",
"AttentionMetadataBuilder",
"AttentionState",
"get_attn_backend", "get_attn_backend",
] ]
...@@ -2,10 +2,7 @@ ...@@ -2,10 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextlib import contextmanager from typing import Generic, List, Optional, Protocol, Tuple, Type, TypeVar
from dataclasses import dataclass, fields
from typing import (Any, Dict, Generic, List, Optional, Protocol, Set, Tuple,
Type, TypeVar)
import torch import torch
...@@ -49,18 +46,13 @@ class AttentionBackend(ABC): ...@@ -49,18 +46,13 @@ class AttentionBackend(ABC):
def get_metadata_cls() -> Type["AttentionMetadata"]: def get_metadata_cls() -> Type["AttentionMetadata"]:
raise NotImplementedError raise NotImplementedError
@staticmethod
@abstractmethod
def get_state_cls() -> Type["AttentionState"]:
raise NotImplementedError
@classmethod @classmethod
def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata": def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
return cls.get_metadata_cls()(*args, **kwargs) return cls.get_metadata_cls()(*args, **kwargs)
@staticmethod @staticmethod
@abstractmethod @abstractmethod
def get_builder_cls() -> Type["AttentionMetadataBuilder"]: def get_builder_cls(): # -> Type["AttentionMetadataBuilder"]:
raise NotImplementedError raise NotImplementedError
@staticmethod @staticmethod
...@@ -77,149 +69,18 @@ class AttentionBackend(ABC): ...@@ -77,149 +69,18 @@ class AttentionBackend(ABC):
def get_kv_cache_stride_order() -> Tuple[int, ...]: def get_kv_cache_stride_order() -> Tuple[int, ...]:
raise NotImplementedError raise NotImplementedError
@staticmethod
@abstractmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: torch.Tensor,
) -> None:
raise NotImplementedError
@staticmethod
@abstractmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
) -> None:
raise NotImplementedError
@classmethod @classmethod
def full_cls_name(cls) -> tuple[str, str]: def full_cls_name(cls) -> tuple[str, str]:
return (cls.__module__, cls.__qualname__) return (cls.__module__, cls.__qualname__)
@dataclass
class AttentionMetadata: class AttentionMetadata:
"""Attention metadata for prefill and decode batched together.""" pass
# Total number of prefill requests.
num_prefills: int
# Number of prefill tokens.
num_prefill_tokens: int
# Number of decode tokens. Note that it is equivalent to the number of
# decode requests.
num_decode_tokens: int
# (num_tokens,). The indices of the token slots that input tokens will be
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor
# Enable/disable KV scales calculation. This is so that we can disable the
# calculation until after prefill and cuda graph capture.
enable_kv_scales_calculation: bool
@property
@abstractmethod
def prefill_metadata(self) -> Optional["AttentionMetadata"]:
"""Return the attention metadata that's required to run prefill
attention."""
pass
@property
@abstractmethod
def decode_metadata(self) -> Optional["AttentionMetadata"]:
"""Return the attention metadata that's required to run decode
attention."""
pass
def asdict_zerocopy(self,
skip_fields: Optional[Set[str]] = None
) -> Dict[str, Any]:
"""Similar to dataclasses.asdict, but avoids deepcopying."""
if skip_fields is None:
skip_fields = set()
# Note that if we add dataclasses as fields, they will need
# similar handling.
return {
field.name: getattr(self, field.name)
for field in fields(self) if field.name not in skip_fields
}
T = TypeVar("T", bound=AttentionMetadata) T = TypeVar("T", bound=AttentionMetadata)
class AttentionState(ABC, Generic[T]):
"""Holds attention backend-specific objects reused during the
lifetime of the model runner."""
@abstractmethod
def __init__(self, runner: Any):
...
@abstractmethod
@contextmanager
def graph_capture(self, max_batch_size: int):
"""Context manager used when capturing CUDA graphs."""
yield
@abstractmethod
def graph_clone(self, batch_size: int) -> "AttentionState[T]":
"""Clone attention state to save in CUDA graph metadata."""
...
@abstractmethod
def graph_capture_get_metadata_for_batch(
self,
batch_size: int,
is_encoder_decoder_model: bool = False) -> T:
"""Get attention metadata for CUDA graph capture of batch_size."""
...
@abstractmethod
def get_graph_input_buffers(
self,
attn_metadata: T,
is_encoder_decoder_model: bool = False) -> Dict[str, Any]:
"""Get attention-specific input buffers for CUDA graph capture."""
...
@abstractmethod
def prepare_graph_input_buffers(
self,
input_buffers: Dict[str, Any],
attn_metadata: T,
is_encoder_decoder_model: bool = False) -> None:
"""In-place modify input buffers dict for CUDA graph replay."""
...
@abstractmethod
def begin_forward(self, model_input) -> None:
"""Prepare state for forward pass."""
...
class AttentionMetadataBuilder(ABC, Generic[T]):
"""Abstract class for attention metadata builders."""
@abstractmethod
def __init__(self, input_builder) -> None:
"""Create the builder, remember some configuration and parameters."""
raise NotImplementedError
@abstractmethod
def prepare(self) -> None:
"""Prepare for one batch."""
raise NotImplementedError
@abstractmethod
def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int) -> T:
"""Build attention metadata with on-device tensors."""
raise NotImplementedError
class AttentionLayer(Protocol): class AttentionLayer(Protocol):
_q_scale: torch.Tensor _q_scale: torch.Tensor
......
This diff is collapsed.
...@@ -11,7 +11,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, ...@@ -11,7 +11,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer, AttentionLayer,
AttentionMetadata, AttentionType, AttentionMetadata, AttentionType,
is_quantized_kv_cache) is_quantized_kv_cache)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
...@@ -65,10 +64,6 @@ class TorchSDPABackend(AttentionBackend): ...@@ -65,10 +64,6 @@ class TorchSDPABackend(AttentionBackend):
def get_metadata_cls() -> type["AttentionMetadata"]: def get_metadata_cls() -> type["AttentionMetadata"]:
return TorchSDPAMetadata return TorchSDPAMetadata
@staticmethod
def get_state_cls() -> type["CommonAttentionState"]:
return CommonAttentionState
@staticmethod @staticmethod
def get_builder_cls() -> type["TorchSDPAMetadataBuilderV1"]: def get_builder_cls() -> type["TorchSDPAMetadataBuilderV1"]:
return TorchSDPAMetadataBuilderV1 return TorchSDPAMetadataBuilderV1
...@@ -835,16 +830,6 @@ class _PagedAttention: ...@@ -835,16 +830,6 @@ class _PagedAttention:
blocksparse_head_sliding_step, blocksparse_head_sliding_step,
) )
@staticmethod
def copy_blocks(
kv_caches: list[torch.Tensor],
src_to_dists: torch.Tensor,
*args,
) -> None:
key_caches = [kv_cache[0] for kv_cache in kv_caches]
value_caches = [kv_cache[1] for kv_cache in kv_caches]
ops.copy_blocks(key_caches, value_caches, src_to_dists)
class _IPEXPagedAttention(_PagedAttention): class _IPEXPagedAttention(_PagedAttention):
......
...@@ -8,7 +8,6 @@ import torch ...@@ -8,7 +8,6 @@ import torch
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer, AttentionType) AttentionLayer, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import cdiv, next_power_of_2 from vllm.utils import cdiv, next_power_of_2
...@@ -97,10 +96,6 @@ class PallasAttentionBackend(AttentionBackend): ...@@ -97,10 +96,6 @@ class PallasAttentionBackend(AttentionBackend):
def get_metadata_cls() -> type["PallasMetadata"]: def get_metadata_cls() -> type["PallasMetadata"]:
return PallasMetadata return PallasMetadata
@staticmethod
def get_state_cls() -> type["CommonAttentionState"]:
return CommonAttentionState
@staticmethod @staticmethod
def get_kv_cache_shape( def get_kv_cache_shape(
num_blocks: int, num_blocks: int,
......
...@@ -9,7 +9,6 @@ import numpy as np ...@@ -9,7 +9,6 @@ import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.attention.backends.abstract import AttentionMetadataBuilder
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.config import (CompilationLevel, VllmConfig, from vllm.config import (CompilationLevel, VllmConfig,
get_layers_from_vllm_config) get_layers_from_vllm_config)
...@@ -25,7 +24,8 @@ from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata ...@@ -25,7 +24,8 @@ from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata, from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata,
TreeAttentionMetadataBuilder) TreeAttentionMetadataBuilder)
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
...@@ -184,8 +184,9 @@ class EagleProposer: ...@@ -184,8 +184,9 @@ class EagleProposer:
builder = (self._get_attention_metadata_builder() builder = (self._get_attention_metadata_builder()
if self.attn_metadata_builder is None else if self.attn_metadata_builder is None else
self.attn_metadata_builder) self.attn_metadata_builder)
attn_metadata = builder.build_for_drafting( attn_metadata = builder.build_for_drafting( # type: ignore
common_attn_metadata=common_attn_metadata, draft_index=0) common_attn_metadata=common_attn_metadata,
draft_index=0)
# At this moment, we assume all eagle layers belong to the same KV # At this moment, we assume all eagle layers belong to the same KV
# cache group, thus using the same attention metadata. # cache group, thus using the same attention metadata.
...@@ -319,7 +320,7 @@ class EagleProposer: ...@@ -319,7 +320,7 @@ class EagleProposer:
exceeds_max_model_len, PADDING_SLOT_ID) exceeds_max_model_len, PADDING_SLOT_ID)
# Rebuild attention metadata # Rebuild attention metadata
attn_metadata = builder.build_for_drafting( attn_metadata = builder.build_for_drafting( # type: ignore
common_attn_metadata=common_attn_metadata, common_attn_metadata=common_attn_metadata,
draft_index=token_index + 1) draft_index=token_index + 1)
for layer_name in self.attn_layer_names: for layer_name in self.attn_layer_names:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment