# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, ClassVar, Generic, Protocol, TypeVar, get_args

import torch

# The names below appear only inside quoted annotations in this module, so
# they are imported for static type checkers and never loaded at runtime.
if TYPE_CHECKING:
    from vllm.config.cache import CacheDType
    from vllm.model_executor.layers.linear import ColumnParallelLinear
    from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
    from vllm.platforms.interface import DeviceCapability
    from vllm.v1.attention.backends.utils import KVCacheLayoutType

16

17
18
19
20
21
class AttentionType:
    """
    Attention type.

    The members are plain strings (not enum members) so that they stay
    compatible with `torch.compile`.
    """

    DECODER = "decoder"
    """Decoder attention between previous layer Q/K/V."""
    ENCODER = "encoder"
    """Encoder attention between previous layer Q/K/V for encoder-decoder."""
    ENCODER_ONLY = "encoder_only"
    """Encoder attention between previous layer Q/K/V."""
    ENCODER_DECODER = "encoder_decoder"
    """Attention between dec. Q and enc. K/V for encoder-decoder."""


33
34
35
36
37
38
39
class MultipleOf:
    """Block-size constraint meaning "any multiple of ``base``".

    Instances are placed in the list returned by
    ``AttentionBackend.get_supported_kernel_block_sizes``;
    ``AttentionBackend.supports_block_size`` unwraps ``base`` and accepts a
    block size when it is divisible by it.

    Args:
        base: The divisor a block size has to be a multiple of.
    """

    base: int

    def __init__(self, base: int):
        self.base = base

    def __repr__(self) -> str:
        # Readable form for logs and debugging, e.g. when a list of
        # supported block sizes is printed.
        return f"{type(self).__name__}(base={self.base!r})"


40
41
class AttentionBackend(ABC):
    """Abstract class for attention backends.

    A concrete backend must implement ``get_name``, ``get_impl_cls``,
    ``get_builder_cls`` and ``get_kv_cache_shape``. Every ``supports_*`` /
    ``is_*`` classmethod has a default answer that a backend may override;
    ``validate_configuration`` combines them into a list of human-readable
    reasons why a configuration is rejected.
    """

    # For some attention backends, we allocate an output tensor before
    # calling the custom op. When piecewise cudagraph is enabled, this
    # makes sure the output tensor is allocated inside the cudagraph.
    accept_output_buffer: bool = False

    # Dtypes accepted by `supports_dtype`.
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
    # KV cache dtypes accepted by `supports_kv_cache_dtype`; an empty list
    # is treated there as "no restriction".
    supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = ["auto"]

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        """Return the block sizes the kernel supports.

        Each entry is either an ``int`` or a ``MultipleOf``. The default,
        ``[MultipleOf(1)]``, is satisfied by every block size.
        """
        return [MultipleOf(1)]

    @staticmethod
    @abstractmethod
    def get_name() -> str:
        """Return the name of this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_impl_cls() -> type["AttentionImpl"]:
        """Return the ``AttentionImpl`` subclass used by this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_builder_cls():  # -> Type["AttentionMetadataBuilder"]:
        """Return the metadata builder class used by this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return the logical shape of the KV cache for the given sizes."""
        raise NotImplementedError

    @staticmethod
    def get_kv_cache_stride_order(
        include_num_layers_dimension: bool = False,
    ) -> tuple[int, ...]:
        """
        Get the physical (memory layout) ordering of the kv cache dimensions.
        e.g. if the KV cache shape is
        [2, num_blocks, block_size, num_heads, head_size],
        and get_kv_cache_stride_order returns (1, 3, 0, 2, 4) then the physical
        ordering of dimensions is
        [num_blocks, num_heads, 2, block_size, head_size].

        If this function is unimplemented / raises NotImplementedError,
        the physical layout of the KV cache will match the logical shape.

        Args:
            include_num_layers_dimension: if True, includes an additional
                num_layers dimension, which is assumed to be prepended
                to the logical KV cache shape.
                With the above example, a return value (2, 4, 0, 1, 3, 5)
                corresponds to
                [num_blocks, num_heads, num_layers, 2, block_size, head_size].

                If an additional dimension is NOT included in the returned
                tuple, the physical layout will not include a layers dimension.

        Returns:
            A tuple of ints which is a permutation of range(len(shape)).
        """
        raise NotImplementedError

    @classmethod
    def full_cls_name(cls) -> tuple[str, str]:
        """Return ``(module, qualified name)`` identifying this class."""
        return (cls.__module__, cls.__qualname__)

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return the supported head sizes; empty means no restriction."""
        return []

    @classmethod
    def supports_head_size(cls, head_size: int) -> bool:
        """Return True if ``head_size`` is listed, or if nothing is listed."""
        supported_head_sizes = cls.get_supported_head_sizes()
        return (not supported_head_sizes) or head_size in supported_head_sizes

    @classmethod
    def supports_dtype(cls, dtype: torch.dtype) -> bool:
        """Return True if ``dtype`` is in ``supported_dtypes``."""
        return dtype in cls.supported_dtypes

    @classmethod
    def supports_kv_cache_dtype(cls, kv_cache_dtype: "CacheDType | None") -> bool:
        """Return True if ``kv_cache_dtype`` is acceptable.

        ``None`` is always accepted, and an empty
        ``supported_kv_cache_dtypes`` accepts every value.
        """
        if kv_cache_dtype is None:
            return True
        return (not cls.supported_kv_cache_dtypes) or (
            kv_cache_dtype in cls.supported_kv_cache_dtypes
        )

    @classmethod
    def supports_block_size(cls, block_size: int | None) -> bool:
        """Return True if ``block_size`` can be used with this backend.

        ``None`` is always accepted. Otherwise the size must be one of the
        values allowed by ``vllm.config.cache.BlockSize`` and must be a
        multiple of at least one entry of
        ``get_supported_kernel_block_sizes()`` (an empty list accepts
        everything).
        """
        if block_size is None:
            return True

        # Function-scope import, placed after the guard above so that it is
        # skipped entirely when there is no block size to check.
        from vllm.config.cache import BlockSize

        valid_sizes = get_args(BlockSize)
        if block_size not in valid_sizes:
            return False

        supported_kernel_block_sizes = cls.get_supported_kernel_block_sizes()
        if not supported_kernel_block_sizes:
            return True

        for supported_size in supported_kernel_block_sizes:
            if isinstance(supported_size, MultipleOf):
                supported_size = supported_size.base
            # With hybrid_blocks feature, the framework-level block size
            # only needs to be a multiple of the kernel's requirement,
            # even if the kernel requires a fixed block_size.
            if block_size % supported_size == 0:
                return True
        return False

    @classmethod
    def is_mla(cls) -> bool:
        """Return whether this is an MLA backend (default: False)."""
        return False

    @classmethod
    def supports_sink(cls) -> bool:
        """Return whether the sink setting is supported (default: False)."""
        return False

    @classmethod
    def supports_mm_prefix(cls) -> bool:
        """Return whether partial multimodal token full attention is
        supported (default: False)."""
        return False

    @classmethod
    def is_sparse(cls) -> bool:
        """Return whether this is a sparse backend (default: False)."""
        return False

    @classmethod
    def supports_attn_type(cls, attn_type: str) -> bool:
        """Check if backend supports a given attention type.

        By default, only supports decoder attention.
        Backends should override this to support other attention types.
        """
        return attn_type == AttentionType.DECODER

    @classmethod
    def supports_compute_capability(cls, capability: "DeviceCapability") -> bool:
        """Return whether the device capability is supported (default: True)."""
        return True

    @classmethod
    def supports_combination(
        cls,
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: "CacheDType | None",
        block_size: int | None,
        use_mla: bool,
        has_sink: bool,
        use_sparse: bool,
        device_capability: "DeviceCapability",
    ) -> str | None:
        """Hook for checks that depend on several parameters at once.

        Returns:
            ``None`` when the combination is acceptable (the default),
            otherwise a reason string that ``validate_configuration``
            appends to its result.
        """
        return None

    @classmethod
    def validate_configuration(
        cls,
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: "CacheDType | None",
        block_size: int | None,
        use_mla: bool,
        has_sink: bool,
        use_sparse: bool,
        use_mm_prefix: bool,
        device_capability: "DeviceCapability",
        attn_type: str,
    ) -> list[str]:
        """Collect every reason this backend rejects the configuration.

        Runs each individual ``supports_*`` / ``is_*`` check and then
        ``supports_combination``.

        Returns:
            A list of reason strings; empty when the configuration is valid.
        """
        invalid_reasons: list[str] = []
        if not cls.supports_head_size(head_size):
            invalid_reasons.append("head_size not supported")
        if not cls.supports_dtype(dtype):
            invalid_reasons.append("dtype not supported")
        if not cls.supports_kv_cache_dtype(kv_cache_dtype):
            invalid_reasons.append("kv_cache_dtype not supported")
        if not cls.supports_block_size(block_size):
            invalid_reasons.append("block_size not supported")
        if use_mm_prefix and not cls.supports_mm_prefix():
            invalid_reasons.append(
                "partial multimodal token full attention not supported"
            )
        # MLA and sparse must match exactly: a backend is rejected both when
        # the feature is requested but missing and when it is built in but
        # not requested.
        if use_mla != cls.is_mla():
            if use_mla:
                invalid_reasons.append("MLA not supported")
            else:
                invalid_reasons.append("non-MLA not supported")
        if has_sink and not cls.supports_sink():
            invalid_reasons.append("sink setting not supported")
        if use_sparse != cls.is_sparse():
            if use_sparse:
                invalid_reasons.append("sparse not supported")
            else:
                invalid_reasons.append("non-sparse not supported")
        if not cls.supports_compute_capability(device_capability):
            invalid_reasons.append("compute capability not supported")
        if not cls.supports_attn_type(attn_type):
            invalid_reasons.append(f"attention type {attn_type} not supported")
        combination_reason = cls.supports_combination(
            head_size,
            dtype,
            kv_cache_dtype,
            block_size,
            use_mla,
            has_sink,
            use_sparse,
            device_capability,
        )
        if combination_reason is not None:
            invalid_reasons.append(combination_reason)
        return invalid_reasons

    @classmethod
    def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None":
        """Return the KV cache layout this backend requires, or ``None``
        (the default) when it has no requirement."""
        return None

265

266
class AttentionMetadata:
    """Base type for attention metadata objects.

    It declares no fields of its own. In this module it is the upper bound
    of the ``T`` type variable and the annotation of the ``attn_metadata``
    argument of ``AttentionLayer.forward``.
    """

    pass


270
# Type parameter of the generic attention implementations in this module:
# the AttentionMetadata subtype their `forward` accepts as `attn_metadata`.
T = TypeVar("T", bound=AttentionMetadata)
271
272


273
class AttentionLayer(Protocol):
    """Structural type of the ``layer`` argument of ``AttentionImpl.forward``.

    Any object exposing the attributes and the ``forward`` method below
    satisfies this protocol.
    """

    # Scales for q/k/v, available both as tensors and as plain floats.
    # NOTE(review): presumably quantization scales - confirm with the
    # concrete layer implementation.
    _q_scale: torch.Tensor
    _k_scale: torch.Tensor
    _v_scale: torch.Tensor
    _q_scale_float: float
    _k_scale_float: float
    _v_scale_float: float
    _prob_scale: torch.Tensor

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor: ...


292
class AttentionImpl(ABC, Generic[T]):
    """Abstract base class for attention implementations.

    Generic over ``T``, the metadata type accepted by ``forward``.
    ``__new__`` populates the context-parallel attributes
    (``dcp_*``, ``pcp_*``, ``total_cp_*``) on every instance before the
    subclass ``__init__`` runs.
    """

    # Whether the attention impl can return the softmax lse for decode.
    # Some features like decode context parallelism require the softmax lse.
    can_return_lse_for_decode: bool = False

    # some attention backends might not always want to return lse
    # even if they can return lse (for efficiency reasons)
    need_to_return_lse_for_decode: bool = False

    # Whether this attention implementation supports pre-quantized query input.
    # When True, the attention layer will quantize queries before passing them
    # to this backend, allowing torch.compile to fuse the quantization with
    # previous operations. This is typically supported when using FP8 KV cache
    # with compatible attention kernels (e.g., TRT-LLM).
    # Subclasses should set this in __init__.
    # TODO add support to more backends:
    # https://github.com/vllm-project/vllm/issues/25584
    supports_quant_query_input: bool = False

    # World size / rank within the DCP group (set in __new__).
    dcp_world_size: int
    dcp_rank: int

    # World size / rank within the PCP group (set in __new__).
    pcp_world_size: int
    pcp_rank: int

    # Combined values derived from the PCP and DCP ones (set in __new__).
    total_cp_world_size: int
    total_cp_rank: int

    def __new__(cls, *args, **kwargs):
        """Create the instance and fill in its context-parallel attributes.

        Falls back to world size 1 / rank 0 for a group whose lookup raises
        ``AssertionError``.
        """
        # use __new__ so that all subclasses will call this
        self = super().__new__(cls)
        try:
            from vllm.distributed.parallel_state import get_dcp_group

            dcp_group = get_dcp_group()
            self.dcp_world_size = dcp_group.world_size
            self.dcp_rank = dcp_group.rank_in_group
        except AssertionError:
            # DCP might not be initialized in testing
            self.dcp_world_size = 1
            self.dcp_rank = 0
        try:
            from vllm.distributed.parallel_state import get_pcp_group

            pcp_group = get_pcp_group()
            self.pcp_world_size = pcp_group.world_size
            self.pcp_rank = pcp_group.rank_in_group
        except AssertionError:
            # Same fallback as for the DCP group above.
            self.pcp_world_size = 1
            self.pcp_rank = 0
        self.total_cp_world_size = self.pcp_world_size * self.dcp_world_size
        # Flattened rank: PCP rank is the major index, DCP rank the minor one.
        self.total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank

        # The lse is only requested when DCP is actually in use and the
        # implementation is able to return it.
        self.need_to_return_lse_for_decode = (
            self.dcp_world_size > 1 and self.can_return_lse_for_decode
        )
        return self

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
        alibi_slopes: list[float] | None = None,
        sliding_window: int | None = None,
        kv_cache_dtype: str = "auto",
        logits_soft_cap: float | None = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
    ) -> None:
        """Constructor signature that concrete implementations must provide."""
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run attention; must be implemented by concrete subclasses."""
        raise NotImplementedError

    def fused_output_quant_supported(self, quant_key: "QuantKey") -> bool:
        """
        Does this attention implementation support fused output quantization.
        This is used by the AttnFusionPass to only fuse output quantization
        onto implementations that support it.

        :param quant_key: QuantKey object that describes the quantization op
        :return: is fusion supported for this type of quantization
        """
        return False

    def process_weights_after_loading(self, act_dtype: torch.dtype) -> None:
        """Hook taking ``act_dtype``; the default implementation does nothing."""
        pass

393
394

class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
    """Abstract base class for MLA attention implementations.

    Extends the ``AttentionImpl`` constructor with MLA-specific arguments
    and replaces the ``query``/``key``/``value`` inputs of ``forward`` with
    ``hidden_states_or_cq``/``kv_c_normed``/``k_pe``.
    """

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        # MLA Specific Arguments
        q_lora_rank: int | None,
        kv_lora_rank: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        qk_head_dim: int,
        v_head_dim: int,
        kv_b_proj: "ColumnParallelLinear",
        indexer: object | None = None,
    ) -> None:
        """Constructor signature that concrete MLA implementations must
        provide. Unlike the base class, only ``indexer`` has a default."""
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        layer: AttentionLayer,
        hidden_states_or_cq: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run MLA attention; must be implemented by concrete subclasses."""
        raise NotImplementedError


def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
    """Report whether a KV cache dtype string denotes a quantized cache.

    Every value other than ``"auto"`` is reported as quantized.
    """
    uses_auto_dtype = kv_cache_dtype == "auto"
    return not uses_auto_dtype