abstract.py 10.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import ABC, abstractmethod
5
from contextlib import contextmanager
6
from dataclasses import dataclass, fields
7
8
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional,
                    Protocol, Set, Tuple, Type, TypeVar)
9
10
11

import torch

12
13
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape)
14
15
from vllm.multimodal import MultiModalPlaceholderMap

16
if TYPE_CHECKING:
17
18
19
    from vllm.worker.model_runner_base import (ModelRunnerBase,
                                               ModelRunnerInputBase,
                                               ModelRunnerInputBuilderBase)
20

21

22
23
24
25
26
27
28
29
30
31
32
33
34
class AttentionType:
    """
    Identifiers for the kinds of attention a layer can perform.

    The values are plain strings so that they stay compatible with
    `torch.compile`.
    """
    # Decoder self-attention over the previous layer's Q/K/V.
    DECODER = "decoder"
    # Encoder self-attention over the previous layer's Q/K/V, as used inside
    # an encoder-decoder model.
    ENCODER = "encoder"
    # Encoder self-attention over the previous layer's Q/K/V, for
    # encoder-only models.
    ENCODER_ONLY = "encoder_only"
    # Cross-attention: decoder queries attend to encoder keys/values in an
    # encoder-decoder model.
    ENCODER_DECODER = "encoder_decoder"
35
36


37
38
class AttentionBackend(ABC):
    """Abstract class for attention backends.

    Subclasses expose, mostly through static methods, the classes that make
    up the backend (implementation, metadata, state and metadata builder)
    and the helpers that describe and manipulate its KV cache.
    """

    # For some attention backends, we allocate an output tensor before
    # calling the custom op. When piecewise cudagraph is enabled, this
    # makes sure the output tensor is allocated inside the cudagraph.
    accept_output_buffer: bool = False

    @staticmethod
    @abstractmethod
    def get_name() -> str:
        """Return the name of this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_impl_cls() -> Type["AttentionImpl"]:
        """Return the AttentionImpl subclass used by this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        """Return the AttentionMetadata subclass used by this backend.

        `make_metadata` instantiates the class returned here.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_state_cls() -> Type["AttentionState"]:
        """Return the AttentionState subclass used by this backend."""
        raise NotImplementedError

    @classmethod
    def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
        """Instantiate this backend's metadata class.

        All positional and keyword arguments are forwarded unchanged to the
        class returned by `get_metadata_cls()`.
        """
        return cls.get_metadata_cls()(*args, **kwargs)

    @staticmethod
    @abstractmethod
    def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
        """Return the AttentionMetadataBuilder subclass used by this
        backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache tensor shape for the given cache geometry."""
        raise NotImplementedError

    @staticmethod
    def get_kv_cache_stride_order() -> Tuple[int, ...]:
        """Return the stride order of the KV-cache tensor.

        Optional hook: unlike the abstract methods it has a default body,
        and that default raises `NotImplementedError`. Callers must be
        prepared for that exception when a backend does not override it.
        (Presumably a permutation of the dimensions returned by
        `get_kv_cache_shape` -- TODO confirm against concrete backends.)
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Swap blocks from `src_kv_cache` into `dst_kv_cache`.

        NOTE(review): `src_to_dst` is assumed to hold (source, destination)
        block-index pairs -- confirm against concrete backends.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Copy KV-cache blocks across the tensors in `kv_caches`.

        The parameter name `src_to_dists` is kept for keyword callers.
        NOTE(review): it presumably holds the source-to-destination block
        mapping -- confirm against concrete backends.
        """
        raise NotImplementedError

    def advance_step(self, model_input: "ModelRunnerInputBase",
                     sampled_token_ids: Optional[torch.Tensor],
                     block_size: int, num_seqs: int, num_queries: int) -> None:
        """Optional per-step update hook.

        Unlike the other hooks, this is an instance method and is not
        abstract. The base implementation raises `NotImplementedError`.
        """
        raise NotImplementedError

    @classmethod
    def full_cls_name(cls) -> tuple[str, str]:
        """Return `(module, qualified name)` identifying this class."""
        return (cls.__module__, cls.__qualname__)

113
114

@dataclass
class AttentionMetadata:
    """Attention metadata for prefill and decode batched together.

    Backend-specific metadata types are expected to subclass this dataclass
    and implement the `prefill_metadata` / `decode_metadata` properties.
    """
    # Total number of prefill requests.
    num_prefills: int
    # Number of prefill tokens.
    num_prefill_tokens: int
    # Number of decode tokens. Note that it is equivalent to the number of
    # decode requests.
    num_decode_tokens: int
    # (num_tokens,). The indices of the token slots that input tokens will be
    # stored into. Slot index `s` refers to block `s // block_size` at
    # offset `s % block_size` (both 0-indexed). E.g., if `slot_mapping` is
    # [35, 2, 17] and the block size is 16, the three tokens are stored at
    # offset 3 of block 2, offset 2 of block 0, and offset 1 of block 1,
    # respectively.
    slot_mapping: torch.Tensor

    # The index maps that relate multi-modal embeddings to the corresponding
    # placeholders.
    #
    # N.B. These aren't really related to attention and don't belong on this
    # type -- this is just a temporary solution to make them available to
    # `model_executable`.
    multi_modal_placeholder_index_maps: Optional[Dict[
        str, MultiModalPlaceholderMap.IndexMap]]

    # Enable/disable KV scales calculation. This is so that we can disable the
    # calculation until after prefill and cuda graph capture.
    enable_kv_scales_calculation: bool

    # NOTE: this class does not derive from ABC, so the @abstractmethod
    # markers below are not enforced when the class is instantiated; they
    # document which members subclasses are expected to override.
    @property
    @abstractmethod
    def prefill_metadata(self) -> Optional["AttentionMetadata"]:
        """Return the attention metadata that's required to run prefill
        attention."""
        pass

    @property
    @abstractmethod
    def decode_metadata(self) -> Optional["AttentionMetadata"]:
        """Return the attention metadata that's required to run decode
        attention."""
        pass

    def asdict_zerocopy(self,
                        skip_fields: Optional[Set[str]] = None
                        ) -> Dict[str, Any]:
        """Similar to dataclasses.asdict, but avoids deepcopying.

        Field values are returned as the very same objects held by this
        instance (no copies), so tensors in the result alias this object's
        tensors.

        :param skip_fields: names of dataclass fields to omit; `None` means
            keep every field.
        :return: mapping from field name to field value.
        """
        if skip_fields is None:
            skip_fields = set()
        # Note that if we add dataclasses as fields, they will need
        # similar handling.
        return {
            field.name: getattr(self, field.name)
            for field in fields(self) if field.name not in skip_fields
        }


171
# Type variable for the concrete AttentionMetadata subclass that the
# generic AttentionState / AttentionMetadataBuilder / AttentionImpl
# classes below are parameterized over.
T = TypeVar("T", bound=AttentionMetadata)
172
173


174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
class AttentionState(ABC, Generic[T]):
    """Holds attention backend-specific objects reused during the
    lifetime of the model runner.

    Every hook is abstract. Most of them support CUDA graph capture and
    replay; `begin_forward` prepares the state before a forward pass.
    `T` is the backend's attention metadata type.
    """

    @abstractmethod
    def __init__(self, runner: "ModelRunnerBase"):
        """Create the state for the given model runner."""
        ...

    # @abstractmethod is applied on top of @contextmanager so that the
    # generator-based context manager is still marked abstract.
    @abstractmethod
    @contextmanager
    def graph_capture(self, max_batch_size: int):
        """Context manager used when capturing CUDA graphs.

        :param max_batch_size: maximum batch size for the capture (per the
            name -- confirm semantics in concrete states).
        """
        yield

    @abstractmethod
    def graph_clone(self, batch_size: int) -> "AttentionState[T]":
        """Clone attention state to save in CUDA graph metadata."""
        ...

    @abstractmethod
    def graph_capture_get_metadata_for_batch(
            self,
            batch_size: int,
            is_encoder_decoder_model: bool = False) -> T:
        """Get attention metadata for CUDA graph capture of batch_size."""
        ...

    @abstractmethod
    def get_graph_input_buffers(
            self,
            attn_metadata: T,
            is_encoder_decoder_model: bool = False) -> Dict[str, Any]:
        """Get attention-specific input buffers for CUDA graph capture."""
        ...

    @abstractmethod
    def prepare_graph_input_buffers(
            self,
            input_buffers: Dict[str, Any],
            attn_metadata: T,
            is_encoder_decoder_model: bool = False) -> None:
        """In-place modify input buffers dict for CUDA graph replay.

        :param input_buffers: dict to update in place (presumably the one
            returned by `get_graph_input_buffers` -- confirm).
        :param attn_metadata: metadata for the batch being replayed.
        """
        ...

    @abstractmethod
    def begin_forward(self, model_input: "ModelRunnerInputBase") -> None:
        """Prepare state for forward pass."""
        ...


224
225
226
227
class AttentionMetadataBuilder(ABC, Generic[T]):
    """Abstract class for attention metadata builders.

    A builder is prepared once per batch (`prepare`) and then produces the
    backend's attention metadata of type `T` (`build`).
    """

    @abstractmethod
    def __init__(self, input_builder: "ModelRunnerInputBuilderBase") -> None:
        """Create the builder, remember some configuration and parameters."""
        raise NotImplementedError

    @abstractmethod
    def prepare(self) -> None:
        """Prepare for one batch."""
        raise NotImplementedError

    @abstractmethod
    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int) -> T:
        """Build attention metadata with on-device tensors.

        :param seq_lens: per-sequence lengths for the batch.
        :param query_lens: per-sequence query lengths for the batch.
        :param cuda_graph_pad_size: padding for CUDA graph batches
            (per the name -- confirm semantics in concrete builders).
        :param batch_size: size of the batch being built.
        :return: the backend-specific attention metadata.
        """
        raise NotImplementedError


244
245
class AttentionLayer(Protocol):
    """Structural type of the layer object passed to `AttentionImpl.forward`.

    Any object that has these attributes and a compatible `forward`
    satisfies the protocol; no inheritance is required.
    """

    # Scale tensors for query, key, value and attention probabilities.
    # NOTE(review): presumably quantization scales -- confirm.
    _q_scale: torch.Tensor
    _k_scale: torch.Tensor
    _v_scale: torch.Tensor
    # Python-float counterparts of _k_scale / _v_scale (per the names --
    # TODO confirm they are kept in sync with the tensors).
    _k_scale_float: float
    _v_scale_float: float
    _prob_scale: torch.Tensor

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        ...


264
class AttentionImpl(ABC, Generic[T]):
    """Abstract attention implementation, parameterized by the backend's
    attention metadata type `T`."""

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
        kv_cache_dtype: str = "auto",
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
    ) -> None:
        """Configure the implementation.

        :param kv_cache_dtype: `"auto"` or a named cache dtype;
            `is_quantized_kv_cache` treats every value other than `"auto"`
            as quantized.
        :param attn_type: one of the `AttentionType` string constants
            (defaults to `AttentionType.DECODER`).
        """
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run attention and return the result tensor.

        :param layer: the calling layer (see `AttentionLayer`).
        :param output: optional preallocated output tensor (presumably used
            by backends with `accept_output_buffer = True` -- confirm).
        :param output_scale: optional scale for quantized output
            (per the name -- confirm).
        """
        raise NotImplementedError

    def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
                                     group_shape: GroupShape) -> bool:
        """
        Return whether this implementation supports fused output quantization.

        This is used by the AttnFusionPass to only fuse output quantization
        onto implementations that support it. The base implementation always
        returns False; subclasses that support fusion override it.

        TODO(luka) merge parameters into QuantDescriptor
        :param dtype: quantized dtype
        :param static: static or dynamic quantization
        :param group_shape: quant group shape.
        :return: is fusion supported for this type of quantization
        """
        return False

311
312
313
314
315
316
317
318
319
320
321
322
323

class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
    """AttentionImpl variant (MLA, per the class name) whose `forward`
    takes `hidden_states_or_cq`, `kv_c_normed` and `k_pe` in place of
    separate `query` / `key` / `value` tensors."""

    @abstractmethod
    def forward(
        self,
        layer: AttentionLayer,
        hidden_states_or_cq: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run MLA attention and return the result tensor.

        NOTE(review): the meaning of `hidden_states_or_cq`, `kv_c_normed`
        and `k_pe` (e.g. which projections they hold) is not visible here --
        confirm against concrete implementations.
        """
        raise NotImplementedError
327
328
329
330


def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
    """Report whether `kv_cache_dtype` selects a quantized KV cache.

    Only the exact string `"auto"` is treated as non-quantized; every other
    value counts as quantized.
    """
    is_auto = kv_cache_dtype == "auto"
    return not is_auto