abstract.py 5.56 KB
Newer Older
1
2
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
3
from enum import Enum, auto
4
5
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
                    Tuple, Type, TypeVar)
6
7
8

import torch

9
10
11
12
if TYPE_CHECKING:
    from vllm.sequence import SequenceGroupMetadata
    from vllm.worker.model_runner_base import ModelRunnerInputBuilderBase

13

14
15
16
17
18
19
class AttentionType(Enum):
    """Enumerates the attention variants: decoder, encoder, and
    encoder/decoder.

    Member values are assigned by ``auto()`` in declaration order.
    """
    # Decoder attention between previous layer Q/K/V
    DECODER = auto()
    # Encoder attention between previous layer Q/K/V
    ENCODER = auto()
    # Attention between dec. Q and enc. K/V
    ENCODER_DECODER = auto()


20
21
22
class AttentionBackend(ABC):
    """Abstract class for attention backends.

    A backend ties together the implementation, metadata, and
    metadata-builder classes and declares the KV-cache helpers below.
    Everything is exposed through static/class methods, so the factory
    helpers can be called on the backend class itself.
    """

    @staticmethod
    @abstractmethod
    def get_name() -> str:
        """Return the name of this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_impl_cls() -> Type["AttentionImpl"]:
        """Return the `AttentionImpl` subclass used by this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        """Return the `AttentionMetadata` subclass used by this backend."""
        raise NotImplementedError

    @classmethod
    def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
        """Instantiate the class returned by `get_metadata_cls()`.

        All positional and keyword arguments are forwarded unchanged to
        that class's constructor.
        """
        return cls.get_metadata_cls()(*args, **kwargs)

    @staticmethod
    @abstractmethod
    def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
        """Return the `AttentionMetadataBuilder` subclass used by this
        backend."""
        raise NotImplementedError

    @classmethod
    def make_metadata_builder(cls, *args,
                              **kwargs) -> "AttentionMetadataBuilder":
        """Instantiate the class returned by `get_builder_cls()`.

        All positional and keyword arguments are forwarded unchanged to
        that class's constructor.
        """
        return cls.get_builder_cls()(*args, **kwargs)

    @staticmethod
    @abstractmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache tensor shape for the given block count,
        block size, number of KV heads, and head size.

        The dimension layout is decided by each concrete backend.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Swap blocks from `src_kv_cache` into `dst_kv_cache`.

        NOTE(review): `src_to_dst` presumably pairs source block indices
        with destination block indices -- confirm against a concrete
        backend.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Copy blocks within each tensor of `kv_caches`.

        NOTE(review): `src_to_dists` looks like a typo of `src_to_dsts`;
        the name is kept because callers may pass it by keyword.
        """
        raise NotImplementedError


@dataclass
class AttentionMetadata:
    """Attention metadata for prefill and decode batched together."""
    # Total number of prefill requests.
    num_prefills: int
    # Number of prefill tokens.
    num_prefill_tokens: int
    # Number of decode tokens. Note that it is equivalent to the number of
    # decode requests.
    num_decode_tokens: int
    # (num_tokens,). The indices of the token slots that input tokens will be
    # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
    # is 16, the three tokens are stored at offset 3 in block 2, offset 2 in
    # block 0, and offset 1 in block 1, respectively
    # (slot = block * block_size + offset, offsets counted from 0).
    slot_mapping: torch.Tensor

    # NOTE: this class does not derive from ABC, so the @abstractmethod
    # markers below are not enforced at instantiation time; they document
    # that subclasses are expected to override both properties.
    @property
    @abstractmethod
    def prefill_metadata(self) -> Optional["AttentionMetadata"]:
        """Return the attention metadata that's required to run prefill
        attention."""
        pass

    @property
    @abstractmethod
    def decode_metadata(self) -> Optional["AttentionMetadata"]:
        """Return the attention metadata that's required to run decode
        attention."""
        pass

    def asdict_zerocopy(self,
                        skip_fields: Optional[Set[str]] = None
                        ) -> Dict[str, Any]:
        """Similar to dataclasses.asdict, but avoids deepcopying.

        `skip_fields` names the dataclass fields to leave out of the
        result; `None` (the default) skips nothing. The returned dict maps
        each remaining field name to the attribute object itself, in field
        declaration order.
        """
        if skip_fields is None:
            skip_fields = set()
        # Note that if we add dataclasses as fields, they will need
        # similar handling.
        return {
            field.name: getattr(self, field.name)
            for field in fields(self) if field.name not in skip_fields
        }


124
# Type variable for the metadata type used by the generic classes below;
# `bound` restricts it to AttentionMetadata and its subclasses.
T = TypeVar("T", bound=AttentionMetadata)
125
126


127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
class AttentionMetadataBuilder(ABC, Generic[T]):
    """Abstract class for attention metadata builders."""

    @abstractmethod
    def __init__(self, input_builder) -> None:
        """Set up the builder; concrete subclasses must override this."""
        raise NotImplementedError

    @abstractmethod
    def add_seq_group(
        self,
        seq_group_metadata: "SequenceGroupMetadata",
        token_lens: List[int],
        seq_lens: List[int],
        curr_seq_lens: List[int],
        query_lens: List[int],
        context_lens: List[int],
        curr_sliding_window_blocks: List[int],
        prefix_cache_hit: bool,
        chunked_prefill_enabled: bool,
    ):
        """Register one sequence group with the builder.

        Implementations update the corresponding fields, which are kept
        as plain Python objects at this stage.
        """
        raise NotImplementedError

    @abstractmethod
    def build(
        self,
        runner: "ModelRunnerInputBuilderBase",
        seq_lens: List[int],
        query_lens: List[int],
        cuda_graph_pad_size: int,
        batch_size: int,
    ) -> T:
        """Build attention metadata with on-device tensors."""
        raise NotImplementedError


154
class AttentionImpl(ABC, Generic[T]):
    """Abstract class for attention implementations.

    Generic over `T`, the metadata type that `forward` receives as
    `attn_metadata`.
    """

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
        kv_cache_dtype: str = "auto",
        blocksparse_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Configure the attention implementation.

        Concrete subclasses must override this. How the optional
        arguments (`num_kv_heads`, `alibi_slopes`, `sliding_window`,
        `blocksparse_params`) are interpreted when left as `None` is up
        to each implementation.
        """
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Run attention over `query`, `key` and `value` and return the
        result as a tensor.

        `attn_type` defaults to decoder attention.
        NOTE(review): `k_scale` / `v_scale` are presumably scaling factors
        for the cached keys and values -- confirm against a concrete
        implementation. Tensor shapes are backend-specific.
        """
        raise NotImplementedError