abstract.py 7.61 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import ABC, abstractmethod
5
from typing import Generic, Protocol, TypeVar
6
7
8

import torch

9
from vllm.model_executor.layers.linear import ColumnParallelLinear
10
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
11

12

13
14
15
16
17
class AttentionType:
    """
    Attention type identifiers.

    Plain string constants (rather than an Enum) are used so the values stay
    compatible with `torch.compile`.
    """

    DECODER = "decoder"
    """Decoder attention between previous layer Q/K/V."""
    ENCODER = "encoder"
    """Encoder attention between previous layer Q/K/V for encoder-decoder."""
    ENCODER_ONLY = "encoder_only"
    """Encoder attention between previous layer Q/K/V."""
    ENCODER_DECODER = "encoder_decoder"
    """Attention between decoder Q and encoder K/V for encoder-decoder."""
27
28


29
30
31
32
33
34
35
class MultipleOf:
    """Wrapper around an integer ``base``.

    Appears next to plain ints in the ``list[int | MultipleOf]`` returned by
    ``get_supported_kernel_block_size``; by its name it presumably means
    "any block size that is a multiple of ``base``" (``MultipleOf(1)`` being
    the catch-all default) — confirm against consumers.
    """

    base: int

    def __init__(self, base: int):
        self.base = base

    def __repr__(self) -> str:
        # The default object repr hides ``base``, which makes supported
        # block-size lists unreadable in logs and error messages.
        return f"{type(self).__name__}({self.base!r})"


36
37
class AttentionBackend(ABC):
    """Abstract class for attention backends.

    A backend names the ``AttentionImpl`` class, the ``AttentionMetadata``
    class and the metadata builder class it uses, and describes the KV-cache
    layout it expects.
    """

    # For some attention backends, we allocate an output tensor before
    # calling the custom op. When piecewise cudagraph is enabled, this
    # makes sure the output tensor is allocated inside the cudagraph.
    accept_output_buffer: bool = False

    @staticmethod
    @abstractmethod
    def get_name() -> str:
        """Return the name of this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_impl_cls() -> type["AttentionImpl"]:
        """Return the ``AttentionImpl`` subclass used by this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        """Return the ``AttentionMetadata`` subclass used by this backend."""
        raise NotImplementedError

    @classmethod
    def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]:
        """Return the kernel block sizes this backend supports.

        Delegates to the implementation class returned by ``get_impl_cls``.
        """
        return cls.get_impl_cls().get_supported_kernel_block_size()

    @classmethod
    def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
        """Instantiate this backend's metadata class, forwarding all arguments."""
        return cls.get_metadata_cls()(*args, **kwargs)

    @staticmethod
    @abstractmethod
    def get_builder_cls():  # -> Type["AttentionMetadataBuilder"]:
        """Return the metadata builder class used by this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return the KV-cache tensor shape for the given configuration.

        ``cache_dtype_str`` defaults to ``"auto"``.
        """
        raise NotImplementedError

    @staticmethod
    def get_kv_cache_stride_order() -> tuple[int, ...]:
        """Optional hook returning the KV-cache stride order.

        Not abstract: the base implementation raises ``NotImplementedError``.
        Presumably callers treat that as "use the default layout" — confirm
        at the call sites.
        """
        raise NotImplementedError

    @classmethod
    def full_cls_name(cls) -> tuple[str, str]:
        """Return ``(module, qualified name)`` identifying this backend class."""
        return (cls.__module__, cls.__qualname__)

91

92
class AttentionMetadata:
    """Base class for backend-specific attention metadata.

    It declares no fields of its own. Backends return a subclass of it from
    ``AttentionBackend.get_metadata_cls``, and it is the bound of the ``T``
    type variable that parameterizes ``AttentionImpl``.
    """
94
95


96
# Type variable for the metadata type an AttentionImpl's forward() consumes;
# bound to AttentionMetadata so only AttentionMetadata (sub)classes fit.
T = TypeVar("T", bound=AttentionMetadata)
97
98


99
class AttentionLayer(Protocol):
    """Structural interface of the layer object passed to ``AttentionImpl.forward``.

    Declares the scale attributes an attention implementation may read from
    the layer, plus the layer's own ``forward`` signature.
    """

    # Tensor and Python-float copies of the Q/K/V scales (presumably
    # quantization scales — confirm against the concrete layer).
    _q_scale: torch.Tensor
    _k_scale: torch.Tensor
    _v_scale: torch.Tensor
    _q_scale_float: float
    _k_scale_float: float
    _v_scale_float: float
    _prob_scale: torch.Tensor

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor: ...
116
117


118
class AttentionImpl(ABC, Generic[T]):
    """Abstract base class for attention implementations.

    Subclasses implement ``__init__`` and ``forward``. ``__new__`` runs for
    every subclass: it records the DCP group's world size and rank, falling
    back to ``1``/``0`` when the group is unavailable. It also decides whether
    the softmax LSE must be returned for decode.
    """

    # Whether the attention impl can return the softmax lse for decode.
    # Some features like decode context parallelism require the softmax lse.
    can_return_lse_for_decode: bool = False

    # some attention backends might not always want to return lse
    # even if they can return lse (for efficiency reasons)
    need_to_return_lse_for_decode: bool = False

    dcp_world_size: int
    dcp_rank: int

    def __new__(cls, *args, **kwargs):
        # use __new__ so that all subclasses will call this
        self = super().__new__(cls)
        try:
            from vllm.distributed.parallel_state import get_dcp_group

            # Look the group up once rather than once per attribute.
            dcp_group = get_dcp_group()
            self.dcp_world_size = dcp_group.world_size
            self.dcp_rank = dcp_group.rank_in_group
        except AssertionError:
            # DCP might not be initialized in testing
            self.dcp_world_size = 1
            self.dcp_rank = 0

        # Only request the LSE when DCP is actually in use (world size > 1)
        # and this implementation is able to produce it.
        self.need_to_return_lse_for_decode = (
            self.dcp_world_size > 1 and self.can_return_lse_for_decode
        )
        return self

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
        alibi_slopes: list[float] | None = None,
        sliding_window: int | None = None,
        kv_cache_dtype: str = "auto",
        logits_soft_cap: float | None = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
    ) -> None:
        raise NotImplementedError

    @staticmethod
    def get_supported_kernel_block_size() -> list[int | MultipleOf]:
        """Return the kernel block sizes this implementation supports.

        The default is ``[MultipleOf(1)]``.
        """
        # TODO: implement this function for all backends.
        return [MultipleOf(1)]

    @abstractmethod
    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run attention for ``layer`` and return the result tensor.

        ``output`` is presumably a preallocated destination buffer (see
        ``AttentionBackend.accept_output_buffer``). ``output_scale`` and
        ``output_block_scale`` are presumably used by fused output
        quantization. Confirm both in concrete implementations.
        """
        raise NotImplementedError

    def fused_output_quant_supported(self, quant_key: QuantKey) -> bool:
        """
        Does this attention implementation support fused output quantization.
        This is used by the AttnFusionPass to only fuse output quantization
        onto implementations that support it.

        :param quant_key: QuantKey object that describes the quantization op
        :return: is fusion supported for this type of quantization
        """
        return False

    def supports_quant_query_input(self) -> bool:
        """
        Check if this attention implementation supports pre-quantized query input.

        When True, the attention layer will quantize queries before passing them
        to this backend, allowing torch.compile to fuse the quantization with
        previous operations. This is typically supported when using FP8 KV cache
        with compatible attention kernels (e.g., TRT-LLM).
        TODO add support to more backends:
        https://github.com/vllm-project/vllm/issues/25584

        Returns:
            bool: True if the implementation can accept pre-quantized queries.
        """
        return False

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Post-weight-loading hook; the base implementation does nothing."""
        pass

213
214

class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
    """Abstract base class for MLA attention implementations.

    Compared with ``AttentionImpl``, the constructor makes every common
    argument required and adds MLA-specific ones (LoRA ranks, per-part head
    dims, the ``kv_b_proj`` layer and an optional ``indexer``). ``forward``
    takes ``hidden_states_or_cq``, ``kv_c_normed`` and ``k_pe`` in place of
    query/key/value.
    """

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        # MLA Specific Arguments
        q_lora_rank: int | None,
        kv_lora_rank: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        qk_head_dim: int,
        v_head_dim: int,
        kv_b_proj: ColumnParallelLinear,
        indexer: object | None = None,
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        layer: AttentionLayer,
        hidden_states_or_cq: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run MLA attention for ``layer`` and return the result tensor."""
        raise NotImplementedError
254
255
256
257


def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
    """Return True unless ``kv_cache_dtype`` is the ``"auto"`` setting."""
    uses_auto_dtype = kv_cache_dtype == "auto"
    return not uses_auto_dtype