abstract.py 13.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import ABC, abstractmethod
5
from typing import TYPE_CHECKING, ClassVar, Generic, Protocol, TypeVar, get_args
6
7
8

import torch

9
from vllm.model_executor.layers.linear import ColumnParallelLinear
10
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
11

12
13
14
15
16
if TYPE_CHECKING:
    from vllm.config.cache import CacheDType
    from vllm.platforms.interface import DeviceCapability
    from vllm.v1.attention.backends.utils import KVCacheLayoutType

17

18
19
20
21
22
class AttentionType:
    """
    Kinds of attention a layer can perform.

    The values are plain strings (not an Enum) so they stay compatible with
    `torch.compile`.
    """

    # Decoder self-attention over Q/K/V produced by the previous layer.
    DECODER = "decoder"
    # Encoder self-attention over the previous layer's Q/K/V, in an
    # encoder-decoder model.
    ENCODER = "encoder"
    # Encoder self-attention over the previous layer's Q/K/V, in an
    # encoder-only model.
    ENCODER_ONLY = "encoder_only"
    # Cross-attention: decoder queries attending to encoder keys/values.
    ENCODER_DECODER = "encoder_decoder"
32
33


34
35
36
37
38
39
40
class MultipleOf:
    """Block-size constraint meaning "any multiple of ``base``".

    Instances appear in ``AttentionBackend.supported_kernel_block_sizes``;
    ``AttentionBackend.supports_block_size`` accepts a block size when it is
    divisible by ``base``.
    """

    base: int

    def __init__(self, base: int):
        self.base = base

    def __repr__(self) -> str:
        # Readable form for logs / error messages that print the list of
        # supported kernel block sizes.
        return f"{type(self).__name__}({self.base})"


41
42
class AttentionBackend(ABC):
    """Abstract class for attention backends.

    A backend names its implementation class (``get_impl_cls``) and metadata
    builder class (``get_builder_cls``), describes the KV-cache shape/layout
    it expects, and declares its capabilities through class attributes and
    ``supports_*`` hooks that ``validate_configuration`` checks a requested
    configuration against.
    """

    # For some attention backends, we allocate an output tensor before
    # calling the custom op. When piecewise cudagraph is enabled, this
    # makes sure the output tensor is allocated inside the cudagraph.
    accept_output_buffer: bool = False

    # Model dtypes accepted by ``supports_dtype``.
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
    # Kernel block-size requirements (an int or ``MultipleOf``);
    # ``supports_block_size`` accepts any block size divisible by one of them.
    # An empty list places no kernel-side restriction.
    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(1)]
    # KV-cache dtypes accepted by ``supports_kv_cache_dtype``; an empty list
    # accepts any dtype.
    supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = ["auto"]

    @staticmethod
    @abstractmethod
    def get_name() -> str:
        """Return the backend's name."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_impl_cls() -> type["AttentionImpl"]:
        """Return the ``AttentionImpl`` subclass implementing this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_builder_cls():  # -> Type["AttentionMetadataBuilder"]:
        """Return the metadata builder class used with this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return the logical shape of one layer's KV cache tensor."""
        raise NotImplementedError

    @staticmethod
    def get_kv_cache_stride_order(
        include_num_layers_dimension: bool = False,
    ) -> tuple[int, ...]:
        """
        Get the physical (memory layout) ordering of the kv cache dimensions.
        e.g. if the KV cache shape is
        [2, num_blocks, block_size, num_heads, head_size],
        and get_kv_cache_stride_order returns (1, 3, 0, 2, 4) then the physical
        ordering of dimensions is
        [num_blocks, num_heads, 2, block_size, head_size].

        If this function is unimplemented / raises NotImplementedError,
        the physical layout of the KV cache will match the logical shape.

        Args:
            include_num_layers_dimension: if True, includes an additional
                num_layers dimension, which is assumed to be prepended
                to the logical KV cache shape.
                With the above example, a return value (2, 4, 0, 1, 3, 5)
                corresponds to
                [num_blocks, num_heads, num_layers, 2, block_size, head_size].

                If an additional dimension is NOT included in the returned
                tuple, the physical layout will not include a layers dimension.

        Returns:
            A tuple of ints which is a permutation of range(len(shape)).
        """
        raise NotImplementedError

    @classmethod
    def full_cls_name(cls) -> tuple[str, str]:
        """Return ``(module, qualified class name)`` identifying this backend."""
        return (cls.__module__, cls.__qualname__)

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return the supported head sizes; an empty list means "any size"."""
        return []

    @classmethod
    def supports_head_size(cls, head_size: int) -> bool:
        """Whether ``head_size`` is allowed by ``get_supported_head_sizes``."""
        supported_head_sizes = cls.get_supported_head_sizes()
        return (not supported_head_sizes) or head_size in supported_head_sizes

    @classmethod
    def supports_dtype(cls, dtype: torch.dtype) -> bool:
        """Whether ``dtype`` is listed in ``supported_dtypes``."""
        return dtype in cls.supported_dtypes

    @classmethod
    def supports_kv_cache_dtype(cls, kv_cache_dtype: "CacheDType | None") -> bool:
        """Whether ``kv_cache_dtype`` is allowed.

        ``None`` (not specified) is always accepted, as is any dtype when
        ``supported_kv_cache_dtypes`` is empty.
        """
        if kv_cache_dtype is None:
            return True
        return (not cls.supported_kv_cache_dtypes) or (
            kv_cache_dtype in cls.supported_kv_cache_dtypes
        )

    @classmethod
    def supports_block_size(cls, block_size: int | None) -> bool:
        """Whether ``block_size`` is usable with this backend.

        ``None`` is always accepted. Otherwise the size must be one of the
        framework-level ``BlockSize`` values and divisible by at least one
        entry of ``supported_kernel_block_sizes`` (if that list is non-empty).
        """
        from vllm.config.cache import BlockSize

        if block_size is None:
            return True

        valid_sizes = get_args(BlockSize)
        if block_size not in valid_sizes:
            return False

        if not cls.supported_kernel_block_sizes:
            return True

        for supported_size in cls.supported_kernel_block_sizes:
            if isinstance(supported_size, MultipleOf):
                supported_size = supported_size.base
            # With hybrid_blocks feature, the framework-level block size
            # only needs to be a multiple of the kernel's requirement,
            # even if the kernel requires a fixed block_size.
            if block_size % supported_size == 0:
                return True
        return False

    @classmethod
    def is_mla(cls) -> bool:
        """Whether this backend implements MLA attention (default: False)."""
        return False

    @classmethod
    def supports_sink(cls) -> bool:
        """Whether this backend supports sinks (default: False)."""
        return False

    @classmethod
    def is_sparse(cls) -> bool:
        """Whether this backend implements sparse attention (default: False)."""
        return False

    @classmethod
    def supports_attn_type(cls, attn_type: str) -> bool:
        """Check if backend supports a given attention type.

        By default, only supports decoder attention.
        Backends should override this to support other attention types.
        """
        # AttentionType is defined at the top of this module, so no
        # call-time import from ``vllm.attention`` is needed.
        return attn_type == AttentionType.DECODER

    @classmethod
    def supports_compute_capability(cls, capability: "DeviceCapability") -> bool:
        """Whether the device compute capability is supported (default: all)."""
        return True

    @classmethod
    def supports_combination(
        cls,
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: "CacheDType | None",
        block_size: int | None,
        use_mla: bool,
        has_sink: bool,
        use_sparse: bool,
        device_capability: "DeviceCapability",
    ) -> str | None:
        """Hook for backend-specific checks across several settings at once.

        Returns:
            A human-readable reason if the combination is unsupported, or
            None if it is fine. The base implementation always returns None.
        """
        return None

    @classmethod
    def validate_configuration(
        cls,
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: "CacheDType | None",
        block_size: int | None,
        use_mla: bool,
        has_sink: bool,
        use_sparse: bool,
        device_capability: "DeviceCapability",
        attn_type: str,
    ) -> list[str]:
        """Collect every reason the given configuration is unsupported.

        Returns:
            A list of human-readable reasons; an empty list means the
            configuration is supported by this backend.
        """
        invalid_reasons: list[str] = []
        if not cls.supports_head_size(head_size):
            invalid_reasons.append("head_size not supported")
        if not cls.supports_dtype(dtype):
            invalid_reasons.append("dtype not supported")
        if not cls.supports_kv_cache_dtype(kv_cache_dtype):
            invalid_reasons.append("kv_cache_dtype not supported")
        if not cls.supports_block_size(block_size):
            invalid_reasons.append("block_size not supported")
        # MLA and sparse must match exactly: an MLA backend cannot serve a
        # non-MLA model and vice versa.
        if use_mla != cls.is_mla():
            if use_mla:
                invalid_reasons.append("MLA not supported")
            else:
                invalid_reasons.append("non-MLA not supported")
        if has_sink and not cls.supports_sink():
            invalid_reasons.append("sink setting not supported")
        if use_sparse != cls.is_sparse():
            if use_sparse:
                invalid_reasons.append("sparse not supported")
            else:
                invalid_reasons.append("non-sparse not supported")
        if not cls.supports_compute_capability(device_capability):
            invalid_reasons.append("compute capability not supported")
        if not cls.supports_attn_type(attn_type):
            invalid_reasons.append(f"attention type {attn_type} not supported")
        combination_reason = cls.supports_combination(
            head_size,
            dtype,
            kv_cache_dtype,
            block_size,
            use_mla,
            has_sink,
            use_sparse,
            device_capability,
        )
        if combination_reason is not None:
            invalid_reasons.append(combination_reason)
        return invalid_reasons

    @classmethod
    def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None":
        """KV-cache layout this backend requires, or None if it has no need."""
        return None

255

256
class AttentionMetadata:
    """Base class for the metadata objects passed to attention ``forward``.

    It defines no fields of its own; concrete backends subclass it.
    """


# Metadata type consumed by a given AttentionImpl; restricted to
# AttentionMetadata subclasses.
T = TypeVar("T", bound=AttentionMetadata)
261
262


263
class AttentionLayer(Protocol):
    """Structural interface of the layer object handed to ``AttentionImpl``.

    Any object that exposes the attributes below and a compatible
    ``forward`` satisfies this protocol; no inheritance is required.
    """

    # Per-tensor scales held as tensors (presumably quantization scales
    # for q / k / v -- confirm against the concrete layer).
    _q_scale: torch.Tensor
    _k_scale: torch.Tensor
    _v_scale: torch.Tensor

    # Float-typed counterparts of the scales above.
    _q_scale_float: float
    _k_scale_float: float
    _v_scale_float: float

    # Scale associated with the attention probabilities (semantics defined
    # by the concrete layer).
    _prob_scale: torch.Tensor

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run the layer's attention and return the output tensor."""
        ...
280
281


282
class AttentionImpl(ABC, Generic[T]):
    """Abstract base class for attention kernel implementations.

    Parameterized by ``T``, the ``AttentionMetadata`` subclass that
    ``forward`` consumes. Context-parallel group sizes and ranks are filled
    in by ``__new__``, so every subclass instance has them regardless of
    what the subclass ``__init__`` does.
    """

    # Whether the attention impl can return the softmax lse for decode.
    # Some features like decode context parallelism require the softmax lse.
    can_return_lse_for_decode: bool = False

    # some attention backends might not always want to return lse
    # even if they can return lse (for efficiency reasons)
    need_to_return_lse_for_decode: bool = False

    # DCP group world size and this rank's index in it; set in __new__
    # (1 / 0 when the group is not initialized).
    dcp_world_size: int
    dcp_rank: int

    # Same for the PCP group (from get_pcp_group); set in __new__.
    pcp_world_size: int
    pcp_rank: int

    # Combined context-parallel size and rank across the PCP and DCP groups:
    # total_cp_rank = pcp_rank * dcp_world_size + dcp_rank.
    total_cp_world_size: int
    total_cp_rank: int

    def __new__(cls, *args, **kwargs):
        # use __new__ so that all subclasses will call this: it does not rely
        # on a subclass __init__ calling super().__init__() (the base
        # __init__ below only raises NotImplementedError).
        self = super().__new__(cls)
        try:
            from vllm.distributed.parallel_state import get_dcp_group

            # Look the group up once rather than once per attribute.
            dcp_group = get_dcp_group()
            self.dcp_world_size = dcp_group.world_size
            self.dcp_rank = dcp_group.rank_in_group
        except AssertionError:
            # DCP might not be initialized in testing
            self.dcp_world_size = 1
            self.dcp_rank = 0
        try:
            from vllm.distributed.parallel_state import get_pcp_group

            pcp_group = get_pcp_group()
            self.pcp_world_size = pcp_group.world_size
            self.pcp_rank = pcp_group.rank_in_group
        except AssertionError:
            # Fall back to a single-rank group, as for DCP above.
            self.pcp_world_size = 1
            self.pcp_rank = 0
        self.total_cp_world_size = self.pcp_world_size * self.dcp_world_size
        self.total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank

        # Only return the lse when DCP is actually in use and this
        # implementation is able to produce it.
        self.need_to_return_lse_for_decode = (
            self.dcp_world_size > 1 and self.can_return_lse_for_decode
        )
        return self

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
        alibi_slopes: list[float] | None = None,
        sliding_window: int | None = None,
        kv_cache_dtype: str = "auto",
        logits_soft_cap: float | None = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
    ) -> None:
        """Configure the implementation; must be overridden by subclasses."""
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute attention for ``layer``; must be overridden by subclasses.

        ``output`` is an optional preallocated output tensor (see
        ``AttentionBackend.accept_output_buffer``); ``output_scale`` and
        ``output_block_scale`` are optional scales for quantized output.
        """
        raise NotImplementedError

    def fused_output_quant_supported(self, quant_key: QuantKey) -> bool:
        """
        Does this attention implementation support fused output quantization.
        This is used by the AttnFusionPass to only fuse output quantization
        onto implementations that support it.

        Args:
            quant_key: QuantKey object that describes the quantization op.

        Returns:
            Whether fusion is supported for this type of quantization.
            The base implementation always returns False.
        """
        return False

    def supports_quant_query_input(self) -> bool:
        """
        Check if this attention implementation supports pre-quantized query input.

        When True, the attention layer will quantize queries before passing them
        to this backend, allowing torch.compile to fuse the quantization with
        previous operations. This is typically supported when using FP8 KV cache
        with compatible attention kernels (e.g., TRT-LLM).
        TODO add support to more backends:
        https://github.com/vllm-project/vllm/issues/25584

        Returns:
            bool: True if the implementation can accept pre-quantized queries.
        """
        return False

    def process_weights_after_loading(self, act_dtype: torch.dtype) -> None:
        """Hook run after weights are loaded; the base implementation is a no-op."""
        pass

389
390

class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
    """Abstract interface for MLA attention implementations.

    Extends ``AttentionImpl`` with MLA-specific constructor arguments and a
    ``forward`` that receives MLA-specific inputs in place of plain q/k/v.
    """

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        # MLA Specific Arguments
        q_lora_rank: int | None,
        kv_lora_rank: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        qk_head_dim: int,
        v_head_dim: int,
        kv_b_proj: ColumnParallelLinear,
        indexer: object | None = None,
    ) -> None:
        """Configure the MLA implementation; must be overridden.

        Unlike ``AttentionImpl.__init__``, every argument except ``indexer``
        is required here.
        """
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        layer: AttentionLayer,
        hidden_states_or_cq: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute MLA attention for ``layer``; must be overridden.

        The exact meaning and shapes of ``hidden_states_or_cq``,
        ``kv_c_normed`` and ``k_pe`` are defined by the concrete
        implementation (NOTE(review): not specified in this interface).
        """
        raise NotImplementedError
430
431
432
433


def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
    """Return True when the KV cache dtype is anything other than "auto".

    "auto" is the only setting treated as non-quantized; every explicit
    cache dtype string counts as quantized.
    """
    uses_auto_dtype = kv_cache_dtype == "auto"
    return not uses_auto_dtype