abstract.py 7.26 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import ABC, abstractmethod
5
from typing import Generic, Protocol, TypeVar
6
7
8

import torch

9
from vllm.model_executor.layers.linear import ColumnParallelLinear
10
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
11

12

13
14
15
16
17
class AttentionType:
    """
    Attention type.
    Use string to be compatible with `torch.compile`.
    """

    DECODER = "decoder"
    """Decoder attention between previous layer Q/K/V."""
    ENCODER = "encoder"
    """Encoder attention between previous layer Q/K/V for encoder-decoder."""
    ENCODER_ONLY = "encoder_only"
    """Encoder attention between previous layer Q/K/V."""
    ENCODER_DECODER = "encoder_decoder"
    """Attention between dec. Q and enc. K/V for encoder-decoder."""
27
28


29
30
31
32
33
34
35
class MultipleOf:
    """Small wrapper around a single integer ``base``.

    Instances appear next to plain ``int`` values in the lists returned by
    ``get_supported_kernel_block_size``.
    """

    # The integer handed to the constructor.
    base: int

    def __init__(self, base: int) -> None:
        """Store ``base`` on the instance unchanged."""
        self.base = base


36
37
class AttentionBackend(ABC):
    """Abstract class for attention backends.

    A concrete backend ties together an implementation class
    (``get_impl_cls``), a metadata class (``get_metadata_cls``) and a
    metadata builder class (``get_builder_cls``), and describes the layout
    of its KV cache.
    """

    # For some attention backends, we allocate an output tensor before
    # calling the custom op. When piecewise cudagraph is enabled, this
    # makes sure the output tensor is allocated inside the cudagraph.
    accept_output_buffer: bool = False

    # Whether this backend supports receiving pre-quantized query input.
    # If True, the attention layer will handle query quantization instead
    # of the backend, allowing torch.compile to fuse quantization with
    # previous operations.
    # Needs to be worked through for all backends
    # https://github.com/vllm-project/vllm/issues/25584
    supports_quant_query_input: bool = False

    @staticmethod
    @abstractmethod
    def get_name() -> str:
        """Return the name of this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_impl_cls() -> type["AttentionImpl"]:
        """Return the ``AttentionImpl`` subclass used by this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        """Return the ``AttentionMetadata`` subclass used by this backend."""
        raise NotImplementedError

    @classmethod
    def get_supported_kernel_block_size(cls) -> list[int | MultipleOf]:
        """Return the supported kernel block sizes.

        Delegates to ``get_supported_kernel_block_size`` on the class
        returned by ``get_impl_cls``.
        """
        return cls.get_impl_cls().get_supported_kernel_block_size()

    @classmethod
    def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
        """Instantiate the backend's metadata class.

        All positional and keyword arguments are forwarded unchanged to the
        class returned by ``get_metadata_cls``.
        """
        return cls.get_metadata_cls()(*args, **kwargs)

    @staticmethod
    @abstractmethod
    def get_builder_cls():  # -> Type["AttentionMetadataBuilder"]:
        """Return the metadata builder class used by this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return the KV cache shape for the given sizes, as a tuple of ints."""
        raise NotImplementedError

    @staticmethod
    def get_kv_cache_stride_order() -> tuple[int, ...]:
        """Return the KV cache stride order as a tuple of ints.

        Unlike its neighbours this method is not marked abstract; the base
        implementation raises ``NotImplementedError``.
        """
        raise NotImplementedError

    @classmethod
    def full_cls_name(cls) -> tuple[str, str]:
        """Return ``(module, qualified name)`` identifying this class."""
        return (cls.__module__, cls.__qualname__)

99

100
class AttentionMetadata:
    """Empty base class for attention metadata.

    It declares no fields of its own; it serves as the bound of the ``T``
    type variable and as the annotated metadata type in this module.
    """

    pass
102
103


104
# Type variable for the metadata type an attention implementation works
# with; bounded by AttentionMetadata.
T = TypeVar("T", bound=AttentionMetadata)
105
106


107
class AttentionLayer(Protocol):
    """Structural type for the ``layer`` argument of ``AttentionImpl.forward``.

    Any object exposing the attributes and ``forward`` signature below
    satisfies this protocol.
    """

    # Scale values held as tensors.
    _q_scale: torch.Tensor
    _k_scale: torch.Tensor
    _v_scale: torch.Tensor
    # Scale values held as Python floats.
    _q_scale_float: float
    _k_scale_float: float
    _v_scale_float: float
    _prob_scale: torch.Tensor

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor: ...
124
125


126
class AttentionImpl(ABC, Generic[T]):
    """Abstract base class for attention implementations.

    ``T`` is the metadata type the implementation's ``forward`` accepts.
    ``__new__`` fills in the DCP attributes (``dcp_world_size``,
    ``dcp_rank``) and ``need_to_return_lse_for_decode`` on every instance
    before ``__init__`` runs.
    """

    # Whether the attention impl can return the softmax lse for decode.
    # Some features like decode context parallelism require the softmax lse.
    can_return_lse_for_decode: bool = False

    # some attention backends might not always want to return lse
    # even if they can return lse (for efficiency reasons)
    need_to_return_lse_for_decode: bool = False

    dcp_world_size: int
    dcp_rank: int

    def __new__(cls, *args, **kwargs):
        """Create the instance and set its DCP-related attributes.

        If looking up the DCP group raises ``AssertionError``, the instance
        falls back to ``dcp_world_size = 1`` and ``dcp_rank = 0``.
        """
        # use __new__ so that all subclasses will call this
        self = super().__new__(cls)
        try:
            from vllm.distributed.parallel_state import get_dcp_group

            # Look the group up once and read both fields from it.
            dcp_group = get_dcp_group()
            self.dcp_world_size = dcp_group.world_size
            self.dcp_rank = dcp_group.rank_in_group
        except AssertionError:
            # DCP might not be initialized in testing
            self.dcp_world_size = 1
            self.dcp_rank = 0
        # The lse is only needed with more than one DCP rank, and only if
        # this implementation is able to produce it.
        self.need_to_return_lse_for_decode = (
            self.dcp_world_size > 1 and self.can_return_lse_for_decode
        )
        return self

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
        alibi_slopes: list[float] | None = None,
        sliding_window: int | None = None,
        kv_cache_dtype: str = "auto",
        logits_soft_cap: float | None = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
    ) -> None:
        """Constructor signature that concrete implementations must provide."""
        raise NotImplementedError

    @staticmethod
    def get_supported_kernel_block_size() -> list[int | MultipleOf]:
        """Return the supported kernel block sizes.

        The default is ``[MultipleOf(1)]``.
        """
        # TODO: implement this function for all backends.
        return [MultipleOf(1)]

    @abstractmethod
    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run attention; must be implemented by subclasses."""
        raise NotImplementedError

    def fused_output_quant_supported(self, quant_key: QuantKey) -> bool:
        """
        Does this attention implementation support fused output quantization.
        This is used by the AttnFusionPass to only fuse output quantization
        onto implementations that support it.

        :param quant_key: QuantKey object that describes the quantization op
        :return: is fusion supported for this type of quantization
        """
        return False

202
203

class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
    """Abstract base class for MLA attention implementations.

    Extends ``AttentionImpl`` with MLA-specific constructor arguments and a
    ``forward`` that takes ``hidden_states_or_cq``, ``kv_c_normed`` and
    ``k_pe`` instead of query/key/value.
    """

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        # MLA Specific Arguments
        q_lora_rank: int | None,
        kv_lora_rank: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        qk_head_dim: int,
        v_head_dim: int,
        kv_b_proj: ColumnParallelLinear,
        indexer: object | None = None,
    ) -> None:
        """Constructor signature that concrete MLA implementations must provide.

        Unlike ``AttentionImpl.__init__``, the shared arguments carry no
        defaults here; only ``indexer`` is optional.
        """
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        layer: AttentionLayer,
        hidden_states_or_cq: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run MLA attention; must be implemented by subclasses."""
        raise NotImplementedError
243
244
245
246


def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
    """Return True for every ``kv_cache_dtype`` other than ``"auto"``."""
    uses_auto_dtype = kv_cache_dtype == "auto"
    return not uses_auto_dtype