abstract.py 10 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import ABC, abstractmethod
5
from contextlib import contextmanager
6
from dataclasses import dataclass, fields
7
8
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional,
                    Protocol, Set, Tuple, Type, TypeVar)
9
10
11

import torch

12
13
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape)
14
15
from vllm.multimodal import MultiModalPlaceholderMap

16
if TYPE_CHECKING:
17
18
19
    from vllm.worker.model_runner_base import (ModelRunnerBase,
                                               ModelRunnerInputBase,
                                               ModelRunnerInputBuilderBase)
20

21

class AttentionType:
    """
    Attention type tags.

    The values are plain strings (not an ``enum.Enum``) so that they remain
    compatible with `torch.compile`.
    """
    # Decoder self-attention over the previous layer's Q/K/V.
    DECODER = "decoder"
    # Encoder self-attention over the previous layer's Q/K/V, as used by the
    # encoder half of an encoder-decoder model.
    ENCODER = "encoder"
    # Encoder self-attention over the previous layer's Q/K/V (encoder-only).
    ENCODER_ONLY = "encoder_only"
    # Cross-attention in an encoder-decoder model: decoder queries attend to
    # encoder keys/values.
    ENCODER_DECODER = "encoder_decoder"


class AttentionBackend(ABC):
    """Abstract class for attention backends.

    A backend names the classes that make up one attention implementation
    (impl, metadata, state, metadata builder) and exposes KV-cache layout
    and block-movement helpers. Almost every hook is a static method, so
    callers use the backend class itself rather than an instance.
    """

    # For some attention backends, we allocate an output tensor before
    # calling the custom op. When piecewise cudagraph is enabled, this
    # makes sure the output tensor is allocated inside the cudagraph.
    accept_output_buffer: bool = False

    @staticmethod
    @abstractmethod
    def get_name() -> str:
        """Return the name of this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_impl_cls() -> Type["AttentionImpl"]:
        """Return the `AttentionImpl` subclass used by this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        """Return the `AttentionMetadata` subclass used by this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_state_cls() -> Type["AttentionState"]:
        """Return the `AttentionState` subclass used by this backend."""
        raise NotImplementedError

    @classmethod
    def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
        """Instantiate this backend's metadata class.

        All positional and keyword arguments are forwarded unchanged to the
        class returned by `get_metadata_cls()`.
        """
        return cls.get_metadata_cls()(*args, **kwargs)

    @staticmethod
    @abstractmethod
    def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
        """Return the `AttentionMetadataBuilder` subclass for this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache tensor shape for the given cache geometry.

        :param num_blocks: number of cache blocks.
        :param block_size: number of token slots per block.
        :param num_kv_heads: number of key/value heads.
        :param head_size: size of each attention head.
        """
        raise NotImplementedError

    @staticmethod
    def get_kv_cache_stride_order() -> Tuple[int, ...]:
        """Optional hook describing the KV-cache stride order.

        This method is intentionally not abstract. Backends that do not
        override it raise NotImplementedError, so callers must handle that
        case.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Swap KV-cache blocks from `src_kv_cache` into `dst_kv_cache`.

        NOTE(review): `src_to_dst` presumably holds (src, dst) block-index
        pairs. Confirm against concrete backends.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Copy KV-cache blocks within the tensors in `kv_caches`.

        NOTE(review): `src_to_dists` presumably maps source block indices to
        destination block indices. Confirm against concrete backends.
        """
        raise NotImplementedError

    def advance_step(self, model_input: "ModelRunnerInputBase",
                     sampled_token_ids: Optional[torch.Tensor],
                     block_size: int, num_seqs: int, num_queries: int) -> None:
        """Optional per-step hook for advancing `model_input`.

        This method is not abstract. The base implementation raises
        NotImplementedError for backends that do not support it.
        """
        raise NotImplementedError

@dataclass
class AttentionMetadata:
    """Attention metadata for prefill and decode batched together."""
    # Total number of prefill requests.
    num_prefills: int
    # Number of prefill tokens.
    num_prefill_tokens: int
    # Number of decode tokens. Note that it is equivalent to the number of
    # decode requests.
    num_decode_tokens: int
    # (num_tokens,). The indices of the token slots that input tokens will be
    # stored into. Slot s lives in block s // block_size at offset
    # s % block_size. E.g., if `slot_mapping` is [35, 2, 17] and the block
    # size is 16, the three tokens are stored at offset 3 in block 2, offset
    # 2 in block 0, and offset 1 in block 1, respectively.
    slot_mapping: torch.Tensor

    # The index maps that relate multi-modal embeddings to the corresponding
    # placeholders.
    #
    # N.B. These aren't really related to attention and don't belong on this
    # type -- this is just a temporary solution to make them available to
    # `model_executable`.
    # The value type is written as a forward reference so that building the
    # dataclass does not need to evaluate the multimodal type.
    multi_modal_placeholder_index_maps: Optional[Dict[
        str, "MultiModalPlaceholderMap.IndexMap"]]

    # Enable/disable KV scales calculation. This is so that we can disable the
    # calculation until after prefill and cuda graph capture.
    enable_kv_scales_calculation: bool

    # NOTE: this dataclass does not derive from ABC, so the @abstractmethod
    # markers below document the contract but are NOT enforced when the
    # class is instantiated. The base versions simply return None.
    @property
    @abstractmethod
    def prefill_metadata(self) -> Optional["AttentionMetadata"]:
        """Return the attention metadata that's required to run prefill
        attention."""
        pass

    @property
    @abstractmethod
    def decode_metadata(self) -> Optional["AttentionMetadata"]:
        """Return the attention metadata that's required to run decode
        attention."""
        pass

    def asdict_zerocopy(self,
                        skip_fields: Optional[Set[str]] = None
                        ) -> Dict[str, Any]:
        """Similar to dataclasses.asdict, but avoids deepcopying.

        :param skip_fields: names of dataclass fields to leave out. None
            means keep every field.
        :return: a ``{field_name: value}`` dict in field-declaration order.
            Values are the attribute objects themselves (e.g. the same
            tensor instances), not copies.
        """
        if skip_fields is None:
            skip_fields = set()
        # Note that if we add dataclasses as fields, they will need
        # similar handling.
        return {
            field.name: getattr(self, field.name)
            for field in fields(self) if field.name not in skip_fields
        }


# Type variable ranging over AttentionMetadata subclasses; it parameterizes
# AttentionState, AttentionMetadataBuilder and AttentionImpl below.
T = TypeVar(
    "T",
    bound=AttentionMetadata,
)


class AttentionState(ABC, Generic[T]):
    """Holds attention backend-specific objects reused during the
    lifetime of the model runner.

    Concrete subclasses are what `AttentionBackend.get_state_cls()`
    returns. Most hooks below deal with CUDA graph capture and replay.
    """

    @abstractmethod
    def __init__(self, runner: "ModelRunnerBase"):
        """Create the state for a model runner.

        :param runner: the model runner that owns this state.
        """
        ...

    @abstractmethod
    @contextmanager
    def graph_capture(self, max_batch_size: int):
        """Context manager used when capturing CUDA graphs.

        :param max_batch_size: the largest batch size to be captured (per
            the parameter name; confirm with implementations).
        """
        yield

    @abstractmethod
    def graph_clone(self, batch_size: int) -> "AttentionState[T]":
        """Clone attention state to save in CUDA graph metadata.

        :param batch_size: batch size associated with the clone.
        """
        ...

    @abstractmethod
    def graph_capture_get_metadata_for_batch(
            self,
            batch_size: int,
            is_encoder_decoder_model: bool = False) -> T:
        """Get attention metadata for CUDA graph capture of batch_size.

        :param batch_size: batch size being captured.
        :param is_encoder_decoder_model: whether the model is an
            encoder-decoder model. Defaults to False.
        """
        ...

    @abstractmethod
    def get_graph_input_buffers(
            self,
            attn_metadata: T,
            is_encoder_decoder_model: bool = False) -> Dict[str, Any]:
        """Get attention-specific input buffers for CUDA graph capture.

        :param attn_metadata: metadata for the batch being captured.
        :param is_encoder_decoder_model: whether the model is an
            encoder-decoder model. Defaults to False.
        :return: a dict of named buffers.
        """
        ...

    @abstractmethod
    def prepare_graph_input_buffers(
            self,
            input_buffers: Dict[str, Any],
            attn_metadata: T,
            is_encoder_decoder_model: bool = False) -> None:
        """In-place modify input buffers dict for CUDA graph replay.

        :param input_buffers: buffers to update in place. Presumably these
            come from `get_graph_input_buffers` (confirm).
        :param attn_metadata: metadata for the batch being replayed.
        :param is_encoder_decoder_model: whether the model is an
            encoder-decoder model. Defaults to False.
        """
        ...

    @abstractmethod
    def begin_forward(self, model_input: "ModelRunnerInputBase") -> None:
        """Prepare state for forward pass.

        :param model_input: the model runner input for this forward pass.
        """
        ...


class AttentionMetadataBuilder(ABC, Generic[T]):
    """Abstract class for attention metadata builders.

    Concrete subclasses are what `AttentionBackend.get_builder_cls()`
    returns. `build` produces backend-specific metadata of type ``T``.
    """

    @abstractmethod
    def __init__(self, input_builder: "ModelRunnerInputBuilderBase") -> None:
        """Create the builder, remember some configuration and parameters.

        :param input_builder: the model runner input builder this metadata
            builder is attached to.
        """
        raise NotImplementedError

    @abstractmethod
    def prepare(self) -> None:
        """Prepare for one batch."""
        raise NotImplementedError

    @abstractmethod
    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int) -> T:
        """Build attention metadata with on-device tensors.

        :param seq_lens: per-sequence lengths.
        :param query_lens: per-sequence query lengths.
        :param cuda_graph_pad_size: padding used for CUDA graph batches
            (semantics defined by implementations; confirm).
        :param batch_size: number of sequences in the batch.
        :return: the backend's metadata object.
        """
        raise NotImplementedError


class AttentionLayer(Protocol):
    """Structural interface of the layer object passed to
    `AttentionImpl.forward` and `MLAAttentionImpl.forward`.

    Because this is a `typing.Protocol`, any object exposing these
    attributes and a compatible `forward` satisfies it without inheriting
    from it.
    """

    # Quantization scale tensors for query, key and value.
    # NOTE(review): semantics inferred from attribute names; confirm with the
    # concrete attention layer.
    _q_scale: torch.Tensor
    _k_scale: torch.Tensor
    _v_scale: torch.Tensor
    # Python-float counterparts of _k_scale / _v_scale (presumably kept to
    # avoid reading device tensors; confirm).
    _k_scale_float: float
    _v_scale_float: float
    # Scale tensor for attention probabilities (per the name; confirm).
    _prob_scale: torch.Tensor

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention for this layer and return the output tensor."""
        ...


class AttentionImpl(ABC, Generic[T]):
    """Abstract attention implementation for one backend.

    Concrete subclasses are what `AttentionBackend.get_impl_cls()` returns.
    They consume metadata of type ``T`` in `forward`.
    """

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
        kv_cache_dtype: str = "auto",
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
    ) -> None:
        """Configure the implementation.

        :param num_heads: number of query heads.
        :param head_size: size of each head.
        :param scale: attention scale factor.
        :param num_kv_heads: number of key/value heads. None leaves the
            default to the implementation.
        :param alibi_slopes: optional ALiBi slopes.
        :param sliding_window: optional sliding-window size.
        :param kv_cache_dtype: KV-cache dtype string. "auto" means not
            quantized (see `is_quantized_kv_cache`).
        :param logits_soft_cap: optional soft cap on attention logits.
        :param attn_type: one of the `AttentionType` string constants.
        :param kv_sharing_target_layer_name: optional name of a layer whose
            KV cache this layer shares (per the name; confirm).
        """
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run attention and return the output tensor.

        :param layer: the calling attention layer (provides quant scales).
        :param output: optional preallocated output buffer. See
            `AttentionBackend.accept_output_buffer`.
        :param output_scale: optional scale for fused output quantization.
            See `fused_output_quant_supported`.
        """
        raise NotImplementedError

    def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
                                     group_shape: GroupShape) -> bool:
        """
        Does this attention implementation support fused output quantization.
        This is used by the AttnFusionPass to only fuse output quantization
        onto implementations that support it.

        TODO(luka) merge parameters into QuantDescriptor
        :param dtype: quantized dtype
        :param static: static or dynamic quantization
        :param group_shape: quant group shape.
        :return: is fusion supported for this type of quantization
        """
        # Conservative default: implementations opt in by overriding.
        return False

class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
    """Attention implementation variant for MLA (presumably multi-head
    latent attention).

    Its `forward` takes `hidden_states_or_cq`, `kv_c_normed` and `k_pe`
    instead of separate query/key/value tensors.
    """

    @abstractmethod
    def forward(
        self,
        layer: AttentionLayer,
        hidden_states_or_cq: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run MLA attention and return the output tensor.

        NOTE(review): the exact meaning and shapes of `hidden_states_or_cq`,
        `kv_c_normed` and `k_pe` are defined by concrete implementations.
        Confirm there.
        """
        raise NotImplementedError


def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
    """Report whether a KV-cache dtype string selects a quantized cache.

    Only the exact string ``"auto"`` counts as unquantized. Every other
    value, including different casing, is treated as quantized.
    """
    uses_model_dtype = kv_cache_dtype == "auto"
    return not uses_model_dtype