Unverified Commit e6750d0b authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[V0 Deprecation] Remove unused classes in attention (#25541)


Signed-off-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
Signed-off-by: default avatarWoosuk Kwon <woosuk@thinkingmachines.ai>
parent 8c853050
......@@ -2,9 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata,
AttentionMetadataBuilder,
AttentionState, AttentionType)
AttentionMetadata, AttentionType)
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend
......@@ -13,7 +11,5 @@ __all__ = [
"AttentionBackend",
"AttentionMetadata",
"AttentionType",
"AttentionMetadataBuilder",
"AttentionState",
"get_attn_backend",
]
......@@ -2,10 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass, fields
from typing import (Any, Dict, Generic, List, Optional, Protocol, Set, Tuple,
Type, TypeVar)
from typing import Generic, List, Optional, Protocol, Tuple, Type, TypeVar
import torch
......@@ -49,18 +46,13 @@ class AttentionBackend(ABC):
def get_metadata_cls() -> Type["AttentionMetadata"]:
raise NotImplementedError
@staticmethod
@abstractmethod
def get_state_cls() -> Type["AttentionState"]:
raise NotImplementedError
@classmethod
def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
return cls.get_metadata_cls()(*args, **kwargs)
@staticmethod
@abstractmethod
def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
def get_builder_cls(): # -> Type["AttentionMetadataBuilder"]:
raise NotImplementedError
@staticmethod
......@@ -77,149 +69,18 @@ class AttentionBackend(ABC):
def get_kv_cache_stride_order() -> Tuple[int, ...]:
raise NotImplementedError
@staticmethod
@abstractmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: torch.Tensor,
) -> None:
raise NotImplementedError
@staticmethod
@abstractmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
) -> None:
raise NotImplementedError
@classmethod
def full_cls_name(cls) -> tuple[str, str]:
return (cls.__module__, cls.__qualname__)
@dataclass
class AttentionMetadata:
"""Attention metadata for prefill and decode batched together."""
# Total number of prefill requests.
num_prefills: int
# Number of prefill tokens.
num_prefill_tokens: int
# Number of decode tokens. Note that it is equivalent to the number of
# decode requests.
num_decode_tokens: int
# (num_tokens,). The indices of the token slots that input tokens will be
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor
# Enable/disable KV scales calculation. This is so that we can disable the
# calculation until after prefill and cuda graph capture.
enable_kv_scales_calculation: bool
@property
@abstractmethod
def prefill_metadata(self) -> Optional["AttentionMetadata"]:
"""Return the attention metadata that's required to run prefill
attention."""
pass
@property
@abstractmethod
def decode_metadata(self) -> Optional["AttentionMetadata"]:
"""Return the attention metadata that's required to run decode
attention."""
pass
def asdict_zerocopy(self,
skip_fields: Optional[Set[str]] = None
) -> Dict[str, Any]:
"""Similar to dataclasses.asdict, but avoids deepcopying."""
if skip_fields is None:
skip_fields = set()
# Note that if we add dataclasses as fields, they will need
# similar handling.
return {
field.name: getattr(self, field.name)
for field in fields(self) if field.name not in skip_fields
}
T = TypeVar("T", bound=AttentionMetadata)
class AttentionState(ABC, Generic[T]):
"""Holds attention backend-specific objects reused during the
lifetime of the model runner."""
@abstractmethod
def __init__(self, runner: Any):
...
@abstractmethod
@contextmanager
def graph_capture(self, max_batch_size: int):
"""Context manager used when capturing CUDA graphs."""
yield
@abstractmethod
def graph_clone(self, batch_size: int) -> "AttentionState[T]":
"""Clone attention state to save in CUDA graph metadata."""
...
@abstractmethod
def graph_capture_get_metadata_for_batch(
self,
batch_size: int,
is_encoder_decoder_model: bool = False) -> T:
"""Get attention metadata for CUDA graph capture of batch_size."""
...
@abstractmethod
def get_graph_input_buffers(
self,
attn_metadata: T,
is_encoder_decoder_model: bool = False) -> Dict[str, Any]:
"""Get attention-specific input buffers for CUDA graph capture."""
...
@abstractmethod
def prepare_graph_input_buffers(
self,
input_buffers: Dict[str, Any],
attn_metadata: T,
is_encoder_decoder_model: bool = False) -> None:
"""In-place modify input buffers dict for CUDA graph replay."""
...
@abstractmethod
def begin_forward(self, model_input) -> None:
"""Prepare state for forward pass."""
...
class AttentionMetadataBuilder(ABC, Generic[T]):
"""Abstract class for attention metadata builders."""
@abstractmethod
def __init__(self, input_builder) -> None:
"""Create the builder, remember some configuration and parameters."""
raise NotImplementedError
@abstractmethod
def prepare(self) -> None:
"""Prepare for one batch."""
raise NotImplementedError
@abstractmethod
def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int) -> T:
"""Build attention metadata with on-device tensors."""
raise NotImplementedError
class AttentionLayer(Protocol):
_q_scale: torch.Tensor
......
This diff is collapsed.
......@@ -11,7 +11,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata, AttentionType,
is_quantized_kv_cache)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
......@@ -65,10 +64,6 @@ class TorchSDPABackend(AttentionBackend):
def get_metadata_cls() -> type["AttentionMetadata"]:
return TorchSDPAMetadata
@staticmethod
def get_state_cls() -> type["CommonAttentionState"]:
return CommonAttentionState
@staticmethod
def get_builder_cls() -> type["TorchSDPAMetadataBuilderV1"]:
return TorchSDPAMetadataBuilderV1
......@@ -835,16 +830,6 @@ class _PagedAttention:
blocksparse_head_sliding_step,
)
@staticmethod
def copy_blocks(
kv_caches: list[torch.Tensor],
src_to_dists: torch.Tensor,
*args,
) -> None:
key_caches = [kv_cache[0] for kv_cache in kv_caches]
value_caches = [kv_cache[1] for kv_cache in kv_caches]
ops.copy_blocks(key_caches, value_caches, src_to_dists)
class _IPEXPagedAttention(_PagedAttention):
......
......@@ -8,7 +8,6 @@ import torch
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import cdiv, next_power_of_2
......@@ -97,10 +96,6 @@ class PallasAttentionBackend(AttentionBackend):
def get_metadata_cls() -> type["PallasMetadata"]:
return PallasMetadata
@staticmethod
def get_state_cls() -> type["CommonAttentionState"]:
return CommonAttentionState
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
......
......@@ -9,7 +9,6 @@ import numpy as np
import torch
import torch.nn as nn
from vllm.attention.backends.abstract import AttentionMetadataBuilder
from vllm.attention.layer import Attention
from vllm.config import (CompilationLevel, VllmConfig,
get_layers_from_vllm_config)
......@@ -25,7 +24,8 @@ from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata,
TreeAttentionMetadataBuilder)
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
......@@ -184,8 +184,9 @@ class EagleProposer:
builder = (self._get_attention_metadata_builder()
if self.attn_metadata_builder is None else
self.attn_metadata_builder)
attn_metadata = builder.build_for_drafting(
common_attn_metadata=common_attn_metadata, draft_index=0)
attn_metadata = builder.build_for_drafting( # type: ignore
common_attn_metadata=common_attn_metadata,
draft_index=0)
# At this moment, we assume all eagle layers belong to the same KV
# cache group, thus using the same attention metadata.
......@@ -319,7 +320,7 @@ class EagleProposer:
exceeds_max_model_len, PADDING_SLOT_ID)
# Rebuild attention metadata
attn_metadata = builder.build_for_drafting(
attn_metadata = builder.build_for_drafting( # type: ignore
common_attn_metadata=common_attn_metadata,
draft_index=token_index + 1)
for layer_name in self.attn_layer_names:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment