abstract.py 5.72 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import ABC, abstractmethod
5
from typing import Generic, List, Optional, Protocol, Tuple, Type, TypeVar
6
7
8

import torch

9
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
10

11

12
13
14
15
16
17
class AttentionType:
    """
    Attention type identifiers.

    These are plain string class attributes (not an ``Enum``) so that the
    values stay compatible with ``torch.compile``.
    """
    DECODER = "decoder"
    """Decoder attention between previous layer Q/K/V."""
    ENCODER = "encoder"
    """Encoder attention between previous layer Q/K/V for encoder-decoder."""
    ENCODER_ONLY = "encoder_only"
    """Encoder attention between previous layer Q/K/V."""
    ENCODER_DECODER = "encoder_decoder"
    """Attention between dec. Q and enc. K/V for encoder-decoder."""


class AttentionBackend(ABC):
    """Abstract class for attention backends.

    A backend is described entirely by static/class methods that name its
    implementation, metadata and metadata-builder classes and describe the
    layout of its KV cache.
    """

    # For some attention backends, we allocate an output tensor before
    # calling the custom op. When piecewise cudagraph is enabled, this
    # makes sure the output tensor is allocated inside the cudagraph.
    accept_output_buffer: bool = False

    @staticmethod
    @abstractmethod
    def get_name() -> str:
        """Return the name identifying this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_impl_cls() -> Type["AttentionImpl"]:
        """Return the ``AttentionImpl`` subclass used by this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        """Return the ``AttentionMetadata`` subclass used by this backend."""
        raise NotImplementedError

    @classmethod
    def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
        """Instantiate this backend's metadata class.

        All positional and keyword arguments are forwarded unchanged to the
        constructor of the class returned by ``get_metadata_cls()``.
        """
        return cls.get_metadata_cls()(*args, **kwargs)

    @staticmethod
    @abstractmethod
    def get_builder_cls():  # -> Type["AttentionMetadataBuilder"]:
        """Return the metadata-builder class for this backend."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache shape (a tuple of ints) for the given sizes."""
        raise NotImplementedError

    @staticmethod
    def get_kv_cache_stride_order() -> Tuple[int, ...]:
        """Return the stride order of the KV cache.

        Optional hook: unlike the methods above it is not abstract, so a
        backend that does not override it raises ``NotImplementedError``.
        """
        raise NotImplementedError

    @classmethod
    def full_cls_name(cls) -> Tuple[str, str]:
        """Return ``(cls.__module__, cls.__qualname__)`` for this backend."""
        return (cls.__module__, cls.__qualname__)


class AttentionMetadata:
    """Base class for backend-specific attention metadata.

    It defines no fields of its own; concrete backends return a subclass of
    it from ``AttentionBackend.get_metadata_cls``.
    """


# Type variable for the concrete metadata type an attention implementation
# consumes; restricted to subclasses of ``AttentionMetadata``.
T = TypeVar("T", bound=AttentionMetadata)


class AttentionLayer(Protocol):
    """Structural type for the layer object passed to attention impls.

    Any object exposing these attributes and a compatible ``forward``
    satisfies the protocol; no inheritance is required.
    """

    # Scale factors for query / key / value, kept as tensors.
    # NOTE(review): presumably used by quantized (e.g. FP8) paths — confirm.
    _q_scale: torch.Tensor
    _k_scale: torch.Tensor
    _v_scale: torch.Tensor
    # Python-float counterparts of the scales above.
    # NOTE(review): assumed to mirror the tensor values — confirm.
    _q_scale_float: float
    _k_scale_float: float
    _v_scale_float: float
    # Scale for attention probabilities (inferred from the name — confirm).
    _prob_scale: torch.Tensor

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention for this layer and return the output tensor."""
        ...


class AttentionImpl(ABC, Generic[T]):
    """Abstract attention implementation, generic over its metadata type.

    ``__new__`` populates the decode-context-parallel (DCP) attributes for
    every subclass, so concrete ``__init__`` methods need not do so.
    """

    # Whether the attention impl can return the softmax lse for decode.
    # Some features like decode context parallelism require the softmax lse.
    can_return_lse_for_decode: bool = False

    # some attention backends might not always want to return lse
    # even if they can return lse (for efficiency reasons)
    need_to_return_lse_for_decode: bool = False

    # DCP group size and this rank's index within it; set in ``__new__``.
    dcp_world_size: int
    dcp_rank: int

    def __new__(cls, *args, **kwargs):
        # use __new__ so that all subclasses will call this
        self = super().__new__(cls)
        try:
            from vllm.distributed.parallel_state import get_dcp_group
            # Look the group up once and read both fields from it.
            dcp_group = get_dcp_group()
            self.dcp_world_size = dcp_group.world_size
            self.dcp_rank = dcp_group.rank_in_group
        except AssertionError:
            # DCP might not be initialized in testing
            self.dcp_world_size = 1
            self.dcp_rank = 0
        # Return the lse only when DCP spans more than one rank and this
        # implementation is able to produce it.
        self.need_to_return_lse_for_decode = (
            self.dcp_world_size > 1 and self.can_return_lse_for_decode)
        return self

    @abstractmethod
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
        kv_cache_dtype: str = "auto",
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
    ) -> None:
        """Configure the implementation for one attention layer.

        :param num_heads: Number of attention (query) heads.
        :param head_size: Size of each head.
        :param scale: Attention scale factor (presumably the softmax
            scale applied to QK^T — confirm per implementation).
        :param num_kv_heads: Number of KV heads; ``None`` presumably means
            "same as ``num_heads``" — confirm in subclasses.
        :param alibi_slopes: Optional ALiBi slopes.
        :param sliding_window: Optional sliding-window size.
        :param kv_cache_dtype: KV-cache dtype name; ``"auto"`` is the
            value ``is_quantized_kv_cache`` treats as unquantized.
        :param logits_soft_cap: Optional soft cap for attention logits.
        :param attn_type: One of the ``AttentionType`` string constants.
        :param kv_sharing_target_layer_name: Name of another layer whose
            KV cache this layer shares, if any (inferred from the name —
            confirm).
        """
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
        output_block_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute attention for ``layer``.

        :param layer: The calling layer, satisfying ``AttentionLayer``.
        :param query: Query tensor.
        :param key: Key tensor.
        :param value: Value tensor.
        :param kv_cache: KV cache (layout presumably given by the
            backend's ``get_kv_cache_shape`` — confirm).
        :param attn_metadata: Backend-specific metadata of type ``T``.
        :param output: Optional preallocated output buffer; presumably
            supplied for backends with ``accept_output_buffer = True``.
        :param output_scale: Optional scale for fused output quantization
            (assumed; see ``fused_output_quant_supported``).
        :param output_block_scale: Optional block-wise output scale
            (inferred from the name — confirm).
        :return: The attention output tensor.
        """
        raise NotImplementedError

    def fused_output_quant_supported(self, quant_key: QuantKey) -> bool:
        """
        Does this attention implementation support fused output quantization.
        This is used by the AttnFusionPass to only fuse output quantization
        onto implementations that support it.

        :param quant_key: QuantKey object that describes the quantization op
        :return: is fusion supported for this type of quantization
        """
        # Default: no fused output quantization; implementations that
        # support it override this method.
        return False


class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
    """Attention implementation interface for MLA-style backends.

    It inherits everything from ``AttentionImpl`` but replaces ``forward``
    with a signature that takes ``hidden_states_or_cq``, ``kv_c_normed`` and
    ``k_pe`` instead of separate query/key/value tensors.
    """

    @abstractmethod
    def forward(
        self,
        layer: AttentionLayer,
        hidden_states_or_cq: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: T,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
        output_block_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute MLA attention for ``layer``.

        :param layer: The calling layer, satisfying ``AttentionLayer``.
        :param hidden_states_or_cq: Per its name, either hidden states or
            a compressed query — confirm per implementation.
        :param kv_c_normed: Per its name, the normalized compressed KV
            (assumed — confirm).
        :param k_pe: Per its name, the positional-encoding part of the key
            (assumed — confirm).
        :param kv_cache: KV cache tensor.
        :param attn_metadata: Backend-specific metadata of type ``T``.
        :param output: Optional preallocated output buffer.
        :param output_scale: Optional scale for fused output quantization.
        :param output_block_scale: Optional block-wise output scale.
        :return: The attention output tensor.
        """
        raise NotImplementedError


def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
    """Tell whether ``kv_cache_dtype`` requests a quantized KV cache.

    The string ``"auto"`` is the only value treated as unquantized; every
    other dtype name counts as quantized.

    :param kv_cache_dtype: The configured KV-cache dtype name.
    :return: ``False`` for ``"auto"``, ``True`` otherwise.
    """
    unquantized_sentinel = "auto"
    return not (kv_cache_dtype == unquantized_sentinel)