abstract.py 6.13 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import ABC, abstractmethod
5
from typing import Generic, Optional, Protocol, TypeVar
6
7
8

import torch

9
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
10

11

12
13
14
15
16
class AttentionType:
    """
    String constants naming the supported kinds of attention.

    Plain strings are used (rather than an enum) to stay compatible with
    `torch.compile`.
    """

    DECODER = "decoder"
    """Decoder attention between previous layer Q/K/V."""
    ENCODER = "encoder"
    """Encoder attention between previous layer Q/K/V for encoder-decoder."""
    ENCODER_ONLY = "encoder_only"
    """Encoder attention between previous layer Q/K/V."""
    ENCODER_DECODER = "encoder_decoder"
    """Attention between dec. Q and enc. K/V for encoder-decoder."""
26
27


28
29
class AttentionBackend(ABC):
    """Abstract class for attention backends.

    A backend exposes, through static accessors, the implementation class,
    the metadata class and the metadata builder class that belong together,
    plus the shape of the KV cache it expects.
    """

    # For some attention backends, we allocate an output tensor before
    # calling the custom op. When piecewise cudagraph is enabled, this
    # makes sure the output tensor is allocated inside the cudagraph.
    accept_output_buffer: bool = False

    # Whether this backend supports receiving pre-quantized query input.
    # If True, the attention layer will handle query quantization instead
    # of the backend, allowing torch.compile to fuse quantization with
    # previous operations.
    # Needs to be worked through for all backends
    # https://github.com/vllm-project/vllm/issues/25584
    supports_quant_query_input: bool = False

    @staticmethod
    @abstractmethod
    def get_name() -> str:
        """Return the name of this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_impl_cls() -> type["AttentionImpl"]:
        """Return the `AttentionImpl` subclass used by this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        """Return the `AttentionMetadata` subclass used by this backend."""
        raise NotImplementedError

    @classmethod
    def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
        """Instantiate the class returned by `get_metadata_cls()`.

        All positional and keyword arguments are forwarded unchanged to
        that class's constructor.
        """
        return cls.get_metadata_cls()(*args, **kwargs)

    @staticmethod
    @abstractmethod
    def get_builder_cls():  # -> Type["AttentionMetadataBuilder"]:
        """Return the metadata builder class used by this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return the KV cache shape for the given sizing parameters.

        The dimension layout is decided by each concrete backend.
        """
        raise NotImplementedError

    @staticmethod
    def get_kv_cache_stride_order() -> tuple[int, ...]:
        """Return the stride order of the KV cache dimensions.

        Unlike the accessors above this is not abstract: the base
        implementation raises `NotImplementedError`, so a backend may
        simply not override it.
        NOTE(review): presumably callers treat `NotImplementedError` as
        "no custom stride order" - confirm against call sites.
        """
        raise NotImplementedError

    @classmethod
    def full_cls_name(cls) -> tuple[str, str]:
        """Return `(module name, qualified class name)` for this class."""
        return (cls.__module__, cls.__qualname__)

87

88
class AttentionMetadata:
    """Empty base class for the metadata objects of attention backends."""

    pass
90
91


92
# Type variable for an attention-metadata type; bound to AttentionMetadata,
# so it can only be AttentionMetadata itself or one of its subclasses.
T = TypeVar("T", bound=AttentionMetadata)
93
94


95
class AttentionLayer(Protocol):
    """Structural (duck-typed) interface of an attention layer object.

    Any object providing the attributes and the `forward` method below
    satisfies this protocol; no inheritance is required.
    """

    # Scale values for query / key / value, held as tensors.
    # NOTE(review): presumably quantization scales - confirm with the layer
    # implementation.
    _q_scale: torch.Tensor
    _k_scale: torch.Tensor
    _v_scale: torch.Tensor
    # Plain-float counterparts of the three scales above.
    _q_scale_float: float
    _k_scale_float: float
    _v_scale_float: float
    # NOTE(review): presumably the scale applied to attention
    # probabilities - confirm.
    _prob_scale: torch.Tensor

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor: ...
112
113


114
class AttentionImpl(ABC, Generic[T]):
    """Abstract base class for attention implementations.

    Generic over `T`, the metadata type that `forward` accepts.
    """

    # Whether the attention impl can return the softmax lse for decode.
    # Some features like decode context parallelism require the softmax lse.
    can_return_lse_for_decode: bool = False

    # some attention backends might not always want to return lse
    # even if they can return lse (for efficiency reasons)
    need_to_return_lse_for_decode: bool = False

    # Set on every instance by __new__ (from the DCP group, or 1 / 0 when
    # the group is not initialized).
    dcp_world_size: int
    dcp_rank: int

    def __new__(cls, *args, **kwargs):
        """Create the instance and fill in its DCP-related attributes.

        Constructor arguments are accepted but not used here; they are
        consumed by the subclass `__init__`.
        """
        # use __new__ so that all subclasses will call this
        self = super().__new__(cls)
        try:
            # Imported lazily, inside the try, exactly where it is needed.
            from vllm.distributed.parallel_state import get_dcp_group

            dcp_group = get_dcp_group()
            self.dcp_world_size = dcp_group.world_size
            self.dcp_rank = dcp_group.rank_in_group
        except AssertionError:
            # DCP might not be initialized in testing
            self.dcp_world_size = 1
            self.dcp_rank = 0
        # LSE is only required when DCP is actually in use and the
        # implementation is able to produce it.
        self.need_to_return_lse_for_decode = (
            self.dcp_world_size > 1 and self.can_return_lse_for_decode
        )
        return self

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[list[float]] = None,
        sliding_window: Optional[int] = None,
        kv_cache_dtype: str = "auto",
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
    ) -> None:
        """Constructor signature that concrete implementations must provide.

        The base version always raises `NotImplementedError`.
        """
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
        output_block_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run attention; must be overridden by concrete implementations.

        The base version always raises `NotImplementedError`.
        """
        raise NotImplementedError

    def fused_output_quant_supported(self, quant_key: QuantKey) -> bool:
        """
        Does this attention implementation support fused output quantization.
        This is used by the AttnFusionPass to only fuse output quantization
        onto implementations that support it.

        :param quant_key: QuantKey object that describes the quantization op
        :return: is fusion supported for this type of quantization
            (the base implementation always answers False)
        """
        return False

185
186
187
188
189
190
191
192
193
194
195
196

class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
    """Abstract attention implementation with an MLA-style `forward`.

    It differs from `AttentionImpl` only in the `forward` signature, which
    takes `hidden_states_or_cq`, `kv_c_normed` and `k_pe` in place of
    query / key / value.
    """

    @abstractmethod
    def forward(
        self,
        layer: AttentionLayer,
        hidden_states_or_cq: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
        output_block_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run attention; must be overridden by concrete implementations.

        The base version always raises `NotImplementedError`.
        """
        raise NotImplementedError
201
202
203
204


def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
    """Tell whether a KV cache dtype string denotes a quantized cache.

    Every value other than the exact string "auto" counts as quantized;
    the comparison is case-sensitive.
    """
    unquantized_dtype = "auto"
    return kv_cache_dtype != unquantized_dtype