abstract.py 7.13 KB
Newer Older
1
from abc import ABC, abstractmethod
2
from contextlib import contextmanager
3
from dataclasses import dataclass, fields
4
from enum import Enum, auto
5
6
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
                    Tuple, Type, TypeVar)
7
8
9

import torch

10
if TYPE_CHECKING:
11
12
13
    from vllm.worker.model_runner_base import (ModelRunnerBase,
                                               ModelRunnerInputBase,
                                               ModelRunnerInputBuilderBase)
14

15

16
17
class AttentionType(Enum):
    """Kinds of attention computation an attention layer can perform."""

    # Decoder attention between previous layer Q/K/V.
    DECODER = auto()
    # Encoder attention between previous layer Q/K/V for encoder-decoder.
    ENCODER = auto()
    # Encoder attention between previous layer Q/K/V.
    ENCODER_ONLY = auto()
    # Attention between decoder Q and encoder K/V for encoder-decoder.
    ENCODER_DECODER = auto()
23
24


25
26
27
class AttentionBackend(ABC):
    """Abstract class for attention backends.

    A backend is a bundle of static/class-level hooks: it names the
    implementation, metadata, state and metadata-builder classes to use, and
    provides KV-cache shape and block copy/swap helpers. Subclasses must
    implement every ``@abstractmethod``. ``advance_step`` is optional: it is
    not abstract and raises ``NotImplementedError`` unless overridden.
    """

    @staticmethod
    @abstractmethod
    def get_name() -> str:
        """Return the name of this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_impl_cls() -> Type["AttentionImpl"]:
        """Return the ``AttentionImpl`` subclass used by this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        """Return the ``AttentionMetadata`` subclass used by this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_state_cls() -> Type["AttentionState"]:
        """Return the ``AttentionState`` subclass used by this backend."""
        raise NotImplementedError

    @classmethod
    def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
        """Instantiate ``cls.get_metadata_cls()`` with the given arguments."""
        return cls.get_metadata_cls()(*args, **kwargs)

    @staticmethod
    @abstractmethod
    def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
        """Return the ``AttentionMetadataBuilder`` subclass for this backend."""
        raise NotImplementedError

    @classmethod
    def make_metadata_builder(cls, *args,
                              **kwargs) -> "AttentionMetadataBuilder":
        """Instantiate ``cls.get_builder_cls()`` with the given arguments."""
        return cls.get_builder_cls()(*args, **kwargs)

    @staticmethod
    @abstractmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache tensor shape for the given cache geometry.

        The layout (axis order, whether K and V share a tensor) is
        backend-specific.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Move cache blocks from ``src_kv_cache`` into ``dst_kv_cache``.

        NOTE(review): ``src_to_dst`` presumably holds (src, dst) block index
        pairs — confirm against implementations.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Copy cache blocks within each tensor of ``kv_caches``.

        NOTE(review): ``src_to_dists`` (name kept for caller compatibility)
        presumably holds (src, dst) block index pairs — confirm.
        """
        raise NotImplementedError

    def advance_step(self, model_input: "ModelRunnerInputBase",
                     sampled_token_ids: Optional[torch.Tensor],
                     block_size: int, num_seqs: int, num_queries: int) -> None:
        """Optional hook; backends that support it must override.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError

94
95

@dataclass
class AttentionMetadata:
    """Attention metadata for prefill and decode batched together.

    NOTE: this class does not derive from ``ABC``, so the ``@abstractmethod``
    markers on ``prefill_metadata``/``decode_metadata`` are not enforced at
    instantiation time; subclasses are still expected to override them.
    """
    # Total number of prefill requests.
    num_prefills: int
    # Number of prefill tokens.
    num_prefill_tokens: int
    # Number of decode tokens. Note that it is equivalent to the number of
    # decode requests.
    num_decode_tokens: int
    # (num_tokens,). The indices of the token slots that input tokens will be
    # stored into. Slot ``s`` is offset ``s % block_size`` within block
    # ``s // block_size``. E.g., if `slot_mapping` is [35, 2, 17] and the block
    # size is 16, the three tokens are stored at offset 3 of block 2, offset 2
    # of block 0, and offset 1 of block 1 (all 0-indexed), respectively.
    slot_mapping: torch.Tensor

    @property
    @abstractmethod
    def prefill_metadata(self) -> Optional["AttentionMetadata"]:
        """Return the attention metadata that's required to run prefill
        attention."""
        pass

    @property
    @abstractmethod
    def decode_metadata(self) -> Optional["AttentionMetadata"]:
        """Return the attention metadata that's required to run decode
        attention."""
        pass

    def asdict_zerocopy(self,
                        skip_fields: Optional[Set[str]] = None
                        ) -> Dict[str, Any]:
        """Similar to dataclasses.asdict, but avoids deepcopying.

        Field values are returned as-is, so any tensors in the result are the
        very objects held by this instance (no copies are made).

        Args:
            skip_fields: Names of dataclass fields to omit from the result.
                ``None`` means omit nothing.

        Returns:
            A dict mapping each remaining field name to its current value, in
            field definition order.
        """
        if skip_fields is None:
            skip_fields = set()
        # Note that if we add dataclasses as fields, they will need
        # similar handling.
        return {
            field.name: getattr(self, field.name)
            for field in fields(self) if field.name not in skip_fields
        }


139
# Type variable for the concrete AttentionMetadata subclass that the generic
# AttentionState / AttentionMetadataBuilder / AttentionImpl classes work with.
T = TypeVar("T", bound=AttentionMetadata)
140
141


142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
class AttentionState(ABC, Generic[T]):
    """Holds attention backend-specific objects reused during the
    lifetime of the model runner.

    ``T`` is the backend's ``AttentionMetadata`` subclass.
    """

    @abstractmethod
    def __init__(self, runner: "ModelRunnerBase"):
        """Create the state for the given model runner."""
        ...

    @abstractmethod
    @contextmanager
    def graph_capture(self, max_batch_size: int):
        """Context manager used when capturing CUDA graphs.

        Args:
            max_batch_size: Largest batch size that will be captured.
        """
        yield

    @abstractmethod
    def graph_clone(self, batch_size: int) -> "AttentionState[T]":
        """Clone attention state to save in CUDA graph metadata."""
        ...

    @abstractmethod
    def graph_capture_get_metadata_for_batch(
            self,
            batch_size: int,
            is_encoder_decoder_model: bool = False) -> T:
        """Get attention metadata for CUDA graph capture of batch_size.

        Args:
            batch_size: Batch size being captured.
            is_encoder_decoder_model: Whether the model is encoder-decoder.
        """
        ...

    @abstractmethod
    def get_graph_input_buffers(
            self,
            attn_metadata: T,
            is_encoder_decoder_model: bool = False) -> Dict[str, Any]:
        """Get attention-specific input buffers for CUDA graph capture.

        Args:
            attn_metadata: Metadata produced for the captured batch.
            is_encoder_decoder_model: Whether the model is encoder-decoder.

        Returns:
            A dict of named input buffers.
        """
        ...

    @abstractmethod
    def prepare_graph_input_buffers(
            self,
            input_buffers: Dict[str, Any],
            attn_metadata: T,
            is_encoder_decoder_model: bool = False) -> None:
        """In-place modify input buffers dict for CUDA graph replay.

        Args:
            input_buffers: Buffers to update in place.
            attn_metadata: Metadata for the batch about to be replayed.
            is_encoder_decoder_model: Whether the model is encoder-decoder.
        """
        ...

    @abstractmethod
    def begin_forward(self, model_input: "ModelRunnerInputBase") -> None:
        """Prepare state for forward pass."""
        ...


192
193
194
195
class AttentionMetadataBuilder(ABC, Generic[T]):
    """Abstract class for attention metadata builders.

    ``T`` is the ``AttentionMetadata`` subclass that ``build`` produces.
    """

    @abstractmethod
    def __init__(self, input_builder: "ModelRunnerInputBuilderBase") -> None:
        """Create a builder bound to the given model-runner input builder."""
        raise NotImplementedError

    @abstractmethod
    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int) -> T:
        """Build attention metadata with on-device tensors.

        Args:
            seq_lens: Per-sequence lengths.
            query_lens: Per-sequence query lengths.
            cuda_graph_pad_size: Padding amount used for CUDA graph capture.
            batch_size: Batch size of the metadata to build.

        Returns:
            The backend's attention metadata instance.
        """
        raise NotImplementedError


206
class AttentionImpl(ABC, Generic[T]):
    """Abstract interface for a backend's attention computation.

    ``T`` is the ``AttentionMetadata`` subclass passed to ``forward``.
    """

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
        kv_cache_dtype: str = "auto",
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
    ) -> None:
        """Configure the attention implementation.

        NOTE(review): ``num_kv_heads=None`` presumably means "same as
        ``num_heads``" — confirm against concrete implementations.
        """
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Compute attention for ``query``/``key``/``value``.

        Args:
            query, key, value: Input projections (layout is backend-specific).
            kv_cache: The backend's KV-cache tensor.
            attn_metadata: Batch metadata of type ``T``.
            k_scale, v_scale: Scale factors for keys/values; presumably used
                with quantized KV caches (see ``kv_cache_dtype``) — confirm.
            attn_type: Which kind of attention to compute.

        Returns:
            The attention output tensor.
        """
        raise NotImplementedError