abstract.py 12.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import ABC, abstractmethod
5
from typing import TYPE_CHECKING, ClassVar, Generic, Protocol, TypeVar, get_args
6
7
8

import torch

9
from vllm.model_executor.layers.linear import ColumnParallelLinear
10
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
11

12
13
14
15
16
if TYPE_CHECKING:
    from vllm.config.cache import CacheDType
    from vllm.platforms.interface import DeviceCapability
    from vllm.v1.attention.backends.utils import KVCacheLayoutType

17

18
19
20
21
22
class AttentionType:
    """
    Kinds of attention a layer can perform.

    The values are plain strings (not an Enum) so they stay compatible with
    `torch.compile`.
    """

    DECODER = "decoder"
    """Decoder self-attention over Q/K/V produced by the previous layer."""

    ENCODER = "encoder"
    """Encoder self-attention over previous-layer Q/K/V in an encoder-decoder model."""

    ENCODER_ONLY = "encoder_only"
    """Encoder self-attention over previous-layer Q/K/V (encoder-only models)."""

    ENCODER_DECODER = "encoder_decoder"
    """Cross-attention: decoder queries attending to encoder keys/values."""
32
33


34
35
36
37
38
39
40
class MultipleOf:
    """Marker for "any block size that is a multiple of ``base``".

    Used in ``AttentionBackend.supported_kernel_block_sizes`` alongside plain
    ``int`` entries, which match a block size exactly.
    """

    base: int

    def __init__(self, base: int):
        # A zero base would make ``block_size % base`` raise ZeroDivisionError
        # when the backend checks block sizes; negative bases are meaningless.
        if base <= 0:
            raise ValueError(f"MultipleOf base must be positive, got {base}")
        self.base = base

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.base})"


41
42
class AttentionBackend(ABC):
    """Abstract class for attention backends.

    Subclasses declare what they support through the ``supported_*`` class
    attributes and the ``supports_*`` / ``is_*`` classmethods.
    ``validate_configuration`` combines those checks into a list of
    human-readable reasons why a configuration is rejected.
    """

    # For some attention backends, we allocate an output tensor before
    # calling the custom op. When piecewise cudagraph is enabled, this
    # makes sure the output tensor is allocated inside the cudagraph.
    accept_output_buffer: bool = False

    # Model dtypes accepted by ``supports_dtype``.
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
    # Kernel block sizes: an ``int`` entry matches exactly, a ``MultipleOf``
    # entry matches any multiple of its base. An empty list accepts every
    # otherwise-valid block size.
    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(1)]
    # KV-cache dtypes accepted by ``supports_kv_cache_dtype``; empty accepts any.
    supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = ["auto"]

    @staticmethod
    @abstractmethod
    def get_name() -> str:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_impl_cls() -> type["AttentionImpl"]:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_builder_cls():  # -> Type["AttentionMetadataBuilder"]:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        raise NotImplementedError

    @staticmethod
    def get_kv_cache_stride_order() -> tuple[int, ...]:
        # Optional hook: not abstract, but the base class provides no default.
        raise NotImplementedError

    @classmethod
    def full_cls_name(cls) -> tuple[str, str]:
        """Return ``(module, qualified class name)`` identifying this backend."""
        return (cls.__module__, cls.__qualname__)

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return supported head sizes; an empty list means "any head size"."""
        return []

    @classmethod
    def supports_head_size(cls, head_size: int) -> bool:
        supported_head_sizes = cls.get_supported_head_sizes()
        return (not supported_head_sizes) or head_size in supported_head_sizes

    @classmethod
    def supports_dtype(cls, dtype: torch.dtype) -> bool:
        return dtype in cls.supported_dtypes

    @classmethod
    def supports_kv_cache_dtype(cls, kv_cache_dtype: "CacheDType | None") -> bool:
        """Check the KV-cache dtype; ``None`` (unspecified) is always accepted."""
        if kv_cache_dtype is None:
            return True
        return (not cls.supported_kv_cache_dtypes) or (
            kv_cache_dtype in cls.supported_kv_cache_dtypes
        )

    @classmethod
    def supports_block_size(cls, block_size: int | None) -> bool:
        """Check a block size against the global set and this backend's kernels.

        ``None`` is always accepted. Otherwise the size must be one of the
        values of the ``BlockSize`` type, and must match at least one entry of
        ``supported_kernel_block_sizes`` (exact ``int`` or ``MultipleOf``).
        """
        from vllm.config.cache import BlockSize

        if block_size is None:
            return True

        if block_size not in get_args(BlockSize):
            return False

        if not cls.supported_kernel_block_sizes:
            return True

        return any(
            (isinstance(supported, MultipleOf) and block_size % supported.base == 0)
            or (isinstance(supported, int) and block_size == supported)
            for supported in cls.supported_kernel_block_sizes
        )

    @classmethod
    def is_mla(cls) -> bool:
        return False

    @classmethod
    def supports_sink(cls) -> bool:
        return False

    @classmethod
    def is_sparse(cls) -> bool:
        return False

    @classmethod
    def supports_attn_type(cls, attn_type: str) -> bool:
        """Check if backend supports a given attention type.

        By default, only supports decoder attention.
        Backends should override this to support other attention types.
        """
        # AttentionType is defined in this module; no lazy re-import needed.
        return attn_type == AttentionType.DECODER

    @classmethod
    def supports_compute_capability(cls, capability: "DeviceCapability") -> bool:
        return True

    @classmethod
    def supports_combination(
        cls,
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: "CacheDType | None",
        block_size: int | None,
        use_mla: bool,
        has_sink: bool,
        use_sparse: bool,
        device_capability: "DeviceCapability",
    ) -> str | None:
        """Hook for checks that depend on several settings at once.

        Return a reason string if the combination is unsupported, else ``None``.
        """
        return None

    @classmethod
    def validate_configuration(
        cls,
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: "CacheDType | None",
        block_size: int | None,
        use_mla: bool,
        has_sink: bool,
        use_sparse: bool,
        device_capability: "DeviceCapability",
        attn_type: str,
    ) -> list[str]:
        """Collect every reason this backend cannot serve the configuration.

        Returns:
            A list of human-readable reasons; an empty list means the
            configuration is supported.
        """
        invalid_reasons = []
        if not cls.supports_head_size(head_size):
            invalid_reasons.append("head_size not supported")
        if not cls.supports_dtype(dtype):
            invalid_reasons.append("dtype not supported")
        if not cls.supports_kv_cache_dtype(kv_cache_dtype):
            invalid_reasons.append("kv_cache_dtype not supported")
        if not cls.supports_block_size(block_size):
            invalid_reasons.append("block_size not supported")
        # MLA and sparse must match exactly: an MLA backend rejects non-MLA use.
        if use_mla != cls.is_mla():
            if use_mla:
                invalid_reasons.append("MLA not supported")
            else:
                invalid_reasons.append("non-MLA not supported")
        if has_sink and not cls.supports_sink():
            invalid_reasons.append("sink setting not supported")
        if use_sparse != cls.is_sparse():
            if use_sparse:
                invalid_reasons.append("sparse not supported")
            else:
                invalid_reasons.append("non-sparse not supported")
        if not cls.supports_compute_capability(device_capability):
            invalid_reasons.append("compute capability not supported")
        if not cls.supports_attn_type(attn_type):
            invalid_reasons.append(f"attention type {attn_type} not supported")
        combination_reason = cls.supports_combination(
            head_size,
            dtype,
            kv_cache_dtype,
            block_size,
            use_mla,
            has_sink,
            use_sparse,
            device_capability,
        )
        if combination_reason is not None:
            invalid_reasons.append(combination_reason)
        return invalid_reasons

    @classmethod
    def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None":
        """Return a KV-cache layout this backend requires, or ``None`` if any."""
        return None

230

231
class AttentionMetadata:
    """Base marker type for per-backend attention metadata."""


# Type variable for the concrete metadata class an attention impl consumes.
T = TypeVar("T", bound=AttentionMetadata)
236
237


238
class AttentionLayer(Protocol):
    """Structural interface of the layer object passed to attention impls.

    Exposes per-tensor scale values both as tensors and as Python floats,
    plus a ``forward`` entry point.
    """

    # Query/key/value scales as tensors.
    _q_scale: torch.Tensor
    _k_scale: torch.Tensor
    _v_scale: torch.Tensor
    # The same scales as Python floats.
    _q_scale_float: float
    _k_scale_float: float
    _v_scale_float: float
    _prob_scale: torch.Tensor

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor: ...
255
256


257
class AttentionImpl(ABC, Generic[T]):
    """Abstract base class for a layer's attention computation.

    ``T`` is the ``AttentionMetadata`` subclass consumed by ``forward``.
    ``__new__`` fills in the decode-context-parallel (DCP) world size and rank
    on every instance, so subclasses get them regardless of their ``__init__``.
    """

    # Whether the attention impl can return the softmax lse for decode.
    # Some features like decode context parallelism require the softmax lse.
    can_return_lse_for_decode: bool = False

    # some attention backends might not always want to return lse
    # even if they can return lse (for efficiency reasons)
    need_to_return_lse_for_decode: bool = False

    dcp_world_size: int
    dcp_rank: int

    def __new__(cls, *args, **kwargs):
        # use __new__ so that all subclasses will call this
        self = super().__new__(cls)
        try:
            from vllm.distributed.parallel_state import get_dcp_group

            # Look the group up once rather than once per attribute.
            dcp_group = get_dcp_group()
            self.dcp_world_size = dcp_group.world_size
            self.dcp_rank = dcp_group.rank_in_group
        except AssertionError:
            # DCP might not be initialized in testing
            self.dcp_world_size = 1
            self.dcp_rank = 0
        # Only return lse when DCP is actually in use and the impl can do it.
        self.need_to_return_lse_for_decode = (
            self.dcp_world_size > 1 and self.can_return_lse_for_decode
        )
        return self

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
        alibi_slopes: list[float] | None = None,
        sliding_window: int | None = None,
        kv_cache_dtype: str = "auto",
        logits_soft_cap: float | None = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        raise NotImplementedError

    def fused_output_quant_supported(self, quant_key: QuantKey) -> bool:
        """
        Does this attention implementation support fused output quantization.
        This is used by the AttnFusionPass to only fuse output quantization
        onto implementations that support it.

        :param quant_key: QuantKey object that describes the quantization op
        :return: is fusion supported for this type of quantization
        """
        return False

    def supports_quant_query_input(self) -> bool:
        """
        Check if this attention implementation supports pre-quantized query input.

        When True, the attention layer will quantize queries before passing them
        to this backend, allowing torch.compile to fuse the quantization with
        previous operations. This is typically supported when using FP8 KV cache
        with compatible attention kernels (e.g., TRT-LLM).
        TODO add support to more backends:
        https://github.com/vllm-project/vllm/issues/25584

        Returns:
            bool: True if the implementation can accept pre-quantized queries.
        """
        return False

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Hook for post-weight-loading processing; the default does nothing."""
        pass

347
348

class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
    """Abstract attention impl for multi-head latent attention (MLA).

    Extends the generic constructor with the low-rank projection settings and
    replaces the Q/K/V ``forward`` with one over the compressed KV inputs.
    """

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        # --- arguments specific to MLA ---
        q_lora_rank: int | None,
        kv_lora_rank: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        qk_head_dim: int,
        v_head_dim: int,
        kv_b_proj: ColumnParallelLinear,
        indexer: object | None = None,
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        layer: AttentionLayer,
        hidden_states_or_cq: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        raise NotImplementedError
388
389
390
391


def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
    """Return True for any KV-cache dtype other than the ``"auto"`` default."""
    uses_default_dtype = kv_cache_dtype == "auto"
    return not uses_default_dtype