# abstract.py 10.9 KB
# Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import ABC, abstractmethod
5
from contextlib import contextmanager
6
from dataclasses import dataclass, fields
7
8
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional,
                    Protocol, Set, Tuple, Type, TypeVar)
9
10
11

import torch

12
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
13
14
from vllm.multimodal import MultiModalPlaceholderMap

15
if TYPE_CHECKING:
    # These names are only used in (string) type annotations below, so they
    # are imported for static type checkers only and never at runtime
    # (presumably to avoid an import cycle with the worker package).
    from vllm.worker.model_runner_base import (ModelRunnerBase,
                                               ModelRunnerInputBase,
                                               ModelRunnerInputBuilderBase)
19

20

21
22
23
24
25
26
class AttentionType:
    """
    Attention type.

    The values are plain strings (not an Enum) so that they stay
    compatible with `torch.compile`.
    """
    DECODER = "decoder"
    """Decoder attention between previous layer Q/K/V."""
    ENCODER = "encoder"
    """Encoder attention between previous layer Q/K/V for encoder-decoder."""
    ENCODER_ONLY = "encoder_only"
    """Encoder attention between previous layer Q/K/V."""
    ENCODER_DECODER = "encoder_decoder"
    """Attention between decoder Q and encoder K/V for encoder-decoder."""
34
35


36
37
class AttentionBackend(ABC):
    """Abstract class for attention backends.

    A backend ties together the classes used for one attention
    implementation (impl, metadata, state, metadata builder) plus the
    KV-cache layout and block-copy helpers. Concrete backends override the
    abstract static methods below.
    """

    # For some attention backends, we allocate an output tensor before
    # calling the custom op. When piecewise cudagraph is enabled, this
    # makes sure the output tensor is allocated inside the cudagraph.
    accept_output_buffer: bool = False

    @staticmethod
    @abstractmethod
    def get_name() -> str:
        """Return the name of this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_impl_cls() -> Type["AttentionImpl"]:
        """Return the AttentionImpl subclass used by this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        """Return the AttentionMetadata subclass used by this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_state_cls() -> Type["AttentionState"]:
        """Return the AttentionState subclass used by this backend."""
        raise NotImplementedError

    @classmethod
    def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
        """Instantiate ``get_metadata_cls()``, forwarding all arguments."""
        return cls.get_metadata_cls()(*args, **kwargs)

    @staticmethod
    @abstractmethod
    def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
        """Return the AttentionMetadataBuilder subclass for this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache tensor shape for the given sizes."""
        raise NotImplementedError

    @staticmethod
    def get_kv_cache_stride_order() -> Tuple[int, ...]:
        """Return the stride order of the KV-cache tensor.

        Optional hook: it is deliberately not abstract, and this base
        implementation raises NotImplementedError. NOTE(review): callers
        presumably catch that to fall back to the default layout — confirm.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Move KV-cache blocks from ``src_kv_cache`` to ``dst_kv_cache``.

        ``src_to_dst`` describes the block mapping (presumably pairs of
        source/destination block indices — confirm per backend).
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Copy KV-cache blocks within each tensor of ``kv_caches``.

        ``src_to_dists`` describes the block mapping (presumably pairs of
        source/destination block indices — confirm per backend).
        """
        raise NotImplementedError

    @classmethod
    def full_cls_name(cls) -> tuple[str, str]:
        """Return ``(module, qualified name)`` identifying this class."""
        return (cls.__module__, cls.__qualname__)

107
108

@dataclass
class AttentionMetadata:
    """Attention metadata for prefill and decode batched together.

    Subclasses are expected to implement the ``prefill_metadata`` and
    ``decode_metadata`` properties.
    """
    # NOTE: this class does not derive from ABC (its metaclass is not
    # ABCMeta), so the @abstractmethod markers on the properties below are
    # not enforced at instantiation time; they only document the contract.

    # Total number of prefill requests.
    num_prefills: int
    # Number of prefill tokens.
    num_prefill_tokens: int
    # Number of decode tokens. Note that it is equivalent to the number of
    # decode requests.
    num_decode_tokens: int
    # (num_tokens,). The indices of the token slots that input tokens will be
    # stored into, where slot = block_index * block_size + offset. E.g., if
    # `slot_mapping` is [35, 2, 17] and the block size is 16, the three tokens
    # are stored at (0-based) offset 3 in block 2, offset 2 in block 0, and
    # offset 1 in block 1, respectively.
    slot_mapping: torch.Tensor

    # The index maps that relate multi-modal embeddings to the corresponding
    # placeholders.
    #
    # N.B. These aren't really related to attention and don't belong on this
    # type -- this is just a temporary solution to make them available to
    # `model_executable`.
    multi_modal_placeholder_index_maps: Optional[Dict[
        str, MultiModalPlaceholderMap.IndexMap]]

    # Enable/disable KV scales calculation. This is so that we can disable the
    # calculation until after prefill and cuda graph capture.
    enable_kv_scales_calculation: bool

    @property
    @abstractmethod
    def prefill_metadata(self) -> Optional["AttentionMetadata"]:
        """Return the attention metadata that's required to run prefill
        attention."""
        pass

    @property
    @abstractmethod
    def decode_metadata(self) -> Optional["AttentionMetadata"]:
        """Return the attention metadata that's required to run decode
        attention."""
        pass

    def asdict_zerocopy(self,
                        skip_fields: Optional[Set[str]] = None
                        ) -> Dict[str, Any]:
        """Similar to dataclasses.asdict, but avoids deepcopying.

        Args:
            skip_fields: Names of fields to leave out of the result. ``None``
                (the default) keeps every field.

        Returns:
            A shallow dict mapping each remaining dataclass field name to the
            attribute value itself (the same object, not a copy).
        """
        if skip_fields is None:
            skip_fields = set()
        # Note that if we add dataclasses as fields, they will need
        # similar handling.
        return {
            field.name: getattr(self, field.name)
            for field in fields(self) if field.name not in skip_fields
        }


165
# Type variable bound to AttentionMetadata. The generic attention classes in
# this module (AttentionState, AttentionMetadataBuilder, AttentionImpl) are
# parameterized by the concrete metadata subclass they work with.
T = TypeVar("T", bound=AttentionMetadata)
166
167


168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
class AttentionState(ABC, Generic[T]):
    """Holds attention backend-specific objects reused during the
    lifetime of the model runner.

    Apart from ``begin_forward`` (called to prepare a forward pass), the
    hooks here support CUDA graphs: a capture context manager, cloning the
    state for a captured batch size, dummy metadata for capture, and the
    input buffers that a captured graph reads and that are updated in place
    before replay.
    """

    @abstractmethod
    def __init__(self, runner: "ModelRunnerBase"):
        """Create the state for the given model runner."""
        ...

    @abstractmethod
    @contextmanager
    def graph_capture(self, max_batch_size: int):
        """Context manager used when capturing CUDA graphs.

        Args:
            max_batch_size: Upper bound on the batch sizes being captured.
        """
        yield

    @abstractmethod
    def graph_clone(self, batch_size: int) -> "AttentionState[T]":
        """Clone attention state to save in CUDA graph metadata."""
        ...

    @abstractmethod
    def graph_capture_get_metadata_for_batch(
            self,
            batch_size: int,
            is_encoder_decoder_model: bool = False) -> T:
        """Get attention metadata for CUDA graph capture of batch_size.

        Args:
            batch_size: Batch size of the graph being captured.
            is_encoder_decoder_model: Whether the model is encoder-decoder.
        """
        ...

    @abstractmethod
    def get_graph_input_buffers(
            self,
            attn_metadata: T,
            is_encoder_decoder_model: bool = False) -> Dict[str, Any]:
        """Get attention-specific input buffers for CUDA graph capture.

        Returns:
            A dict of named buffers that the captured graph reads.
        """
        ...

    @abstractmethod
    def prepare_graph_input_buffers(
            self,
            input_buffers: Dict[str, Any],
            attn_metadata: T,
            is_encoder_decoder_model: bool = False) -> None:
        """In-place modify input buffers dict for CUDA graph replay.

        Args:
            input_buffers: Buffers previously returned by
                ``get_graph_input_buffers``; updated in place.
            attn_metadata: Metadata for the batch about to be replayed.
            is_encoder_decoder_model: Whether the model is encoder-decoder.
        """
        ...

    @abstractmethod
    def begin_forward(self, model_input: "ModelRunnerInputBase") -> None:
        """Prepare state for forward pass."""
        ...


218
219
220
221
class AttentionMetadataBuilder(ABC, Generic[T]):
    """Abstract class for attention metadata builders.

    Usage, per the method docstrings: construct with the model runner's
    input builder, call ``prepare`` for each batch, then ``build`` to
    produce the backend's metadata object of type ``T``.
    """

    @abstractmethod
    def __init__(self, input_builder: "ModelRunnerInputBuilderBase") -> None:
        """Create the builder, remember some configuration and parameters."""
        raise NotImplementedError

    @abstractmethod
    def prepare(self) -> None:
        """Prepare for one batch."""
        raise NotImplementedError

    @abstractmethod
    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int) -> T:
        """Build attention metadata with on-device tensors.

        Args:
            seq_lens: Per-sequence lengths for the batch.
            query_lens: Per-sequence query lengths for the batch.
            cuda_graph_pad_size: Padding amount used when running under a
                CUDA graph (exact semantics are backend-defined — confirm).
            batch_size: Size of the batch being built.
        """
        raise NotImplementedError


238
239
class AttentionLayer(Protocol):
    """Structural type for the ``layer`` argument of ``AttentionImpl.forward``.

    Declares the scale attributes and the ``forward`` signature that an
    attention layer object is expected to provide.
    """

    # Query/key/value scales, held both as tensors and as Python floats
    # (the ``*_float`` variants).
    _q_scale: torch.Tensor
    _k_scale: torch.Tensor
    _v_scale: torch.Tensor
    _q_scale_float: float
    _k_scale_float: float
    _v_scale_float: float
    # NOTE(review): presumably the scale applied to attention probabilities.
    _prob_scale: torch.Tensor

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        ...


259
class AttentionImpl(ABC, Generic[T]):
    """Abstract base for a backend's attention computation.

    Subclasses implement ``__init__`` and ``forward``. ``__new__`` fills in
    the decode-context-parallel (DCP) attributes on every instance before
    ``__init__`` runs.
    """

    # Whether the attention impl can return the softmax lse for decode.
    # Some features like decode context parallelism require the softmax lse.
    can_return_lse_for_decode: bool = False

    # some attention backends might not always want to return lse
    # even if they can return lse (for efficiency reasons)
    need_to_return_lse_for_decode: bool = False

    # World size of the DCP group and this process's rank within it; set in
    # __new__ (falling back to 1 / 0 when the group is not initialized).
    dcp_world_size: int
    dcp_rank: int

    def __new__(cls, *args, **kwargs):
        # Use __new__ so that all subclasses get these attributes, regardless
        # of what their own __init__ does.
        self = super().__new__(cls)
        try:
            from vllm.distributed.parallel_state import get_dcp_group

            # Look the group up once instead of once per attribute.
            dcp_group = get_dcp_group()
            self.dcp_world_size = dcp_group.world_size
            self.dcp_rank = dcp_group.rank_in_group
        except AssertionError:
            # DCP might not be initialized in testing
            self.dcp_world_size = 1
            self.dcp_rank = 0
        # Return the LSE only when this impl supports it AND it is needed,
        # i.e. decode is split across more than one DCP rank.
        self.need_to_return_lse_for_decode = (
            self.dcp_world_size > 1 and self.can_return_lse_for_decode)
        return self

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
        kv_cache_dtype: str = "auto",
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
    ) -> None:
        """Configure the attention implementation.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each attention head.
            scale: Attention scale factor.
            num_kv_heads: Number of KV heads (``None`` by default; the
                fallback is implementation-defined — confirm).
            alibi_slopes: Optional ALiBi slopes.
            sliding_window: Optional sliding-window size.
            kv_cache_dtype: ``"auto"`` or an explicit KV-cache dtype string.
            logits_soft_cap: Optional soft cap applied to attention logits.
            attn_type: One of the ``AttentionType`` string constants.
            kv_sharing_target_layer_name: Optional name of another layer
                whose KV cache this layer shares (per the parameter name).
        """
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
        output_block_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run attention for one batch and return the output tensor.

        Args:
            layer: The calling attention layer (scales, etc.).
            query, key, value: Input projections for the batch.
            kv_cache: The KV-cache tensor for this layer.
            attn_metadata: Backend-specific metadata for the batch.
            output: Optional pre-allocated output tensor (relevant when the
                backend sets ``accept_output_buffer``).
            output_scale, output_block_scale: Optional scales, presumably for
                fused output quantization — confirm per backend.
        """
        raise NotImplementedError

    def fused_output_quant_supported(self, quant_key: QuantKey) -> bool:
        """
        Does this attention implementation support fused output quantization.
        This is used by the AttnFusionPass to only fuse output quantization
        onto implementations that support it.

        The base implementation always returns ``False``; backends that can
        fuse output quantization override this.

        :param quant_key: QuantKey object that describes the quantization op
        :return: is fusion supported for this type of quantization
        """
        return False

329
330
331
332
333
334
335
336
337
338
339
340
341

class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
    """AttentionImpl variant for MLA backends, with its own forward signature.

    Instead of separate query/key/value tensors, ``forward`` receives
    ``hidden_states_or_cq``, ``kv_c_normed`` and ``k_pe``. NOTE(review): by
    their names these appear to be the hidden states or compressed query,
    the normalized compressed KV, and the positional part of the key —
    confirm against concrete MLA backends.
    """

    @abstractmethod
    def forward(
        self,
        layer: AttentionLayer,
        hidden_states_or_cq: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
        output_block_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run MLA attention for one batch and return the output tensor."""
        raise NotImplementedError
346
347
348
349


def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
    """Report whether ``kv_cache_dtype`` requests a quantized KV cache.

    Returns ``False`` only for the literal ``"auto"``; any other string is
    treated as a quantized KV-cache dtype.
    """
    uses_default_dtype = kv_cache_dtype == "auto"
    return not uses_default_dtype