# layer.py 12.2 KB
# Newer Older
1
"""Attention layer."""
2
from typing import Any, Dict, List, Optional
3
4
5

import torch
import torch.nn as nn
6
import torch.nn.functional as F
7

8
from vllm.attention import AttentionMetadata, AttentionType
9
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
10
from vllm.config import CacheConfig, get_current_vllm_config
11
from vllm.forward_context import ForwardContext, get_forward_context
12
13
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
14
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
15
from vllm.platforms import _Backend, current_platform
16
from vllm.utils import direct_register_custom_op
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37


class Attention(nn.Module):
    """Attention layer.

    This class takes query, key, and value tensors as input. The input tensors
    can either contain prompt tokens or generation tokens.
    The class does the following:

    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        per_layer_sliding_window: Optional[int] = None,
        prefix: str = "",
    ) -> None:
        """Select an attention backend, build its impl and register the layer.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each attention head.
            scale: Scale value passed through to the backend implementation.
            num_kv_heads: Number of key/value heads; defaults to ``num_heads``.
            alibi_slopes: Optional ALiBi slopes forwarded to the backend.
            cache_config: Source of the KV-cache dtype, block size,
                attention-free flag and model-level sliding window. When
                ``None``, dtype ``"auto"``, block size 16, attention-free
                ``False`` and no sliding window are used.
            quant_config: If given, its quant method for this layer (if any)
                must be a ``BaseKVCacheMethod``; it creates the k/v scale
                weights on this layer.
            blocksparse_params: Optional block-sparse parameters forwarded to
                the backend; whether they are given also feeds backend
                selection.
            logits_soft_cap: Optional soft cap forwarded to the backend.
            per_layer_sliding_window: Overrides the model-level sliding window
                for this layer when not ``None``.
            prefix: Layer name; used as this layer's key in the compilation
                config's ``static_forward_context``.

        Raises:
            ValueError: If a KV-cache quant method is present and the cache
                dtype is ``fp8_e5m2``, or if ``prefix`` is already registered.
        """
        super().__init__()
        if per_layer_sliding_window is not None:
            # per-layer sliding window
            sliding_window = per_layer_sliding_window
        elif cache_config is not None:
            # model-level sliding window
            sliding_window = cache_config.sliding_window
        else:
            sliding_window = None

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            is_attention_free = cache_config.is_attention_free
        else:
            kv_cache_dtype = "auto"
            block_size = 16
            is_attention_free = False
        if num_kv_heads is None:
            num_kv_heads = num_heads

        # The default k/v_scale is set to 1.0. This is ignored
        # when kv-cache is not fp8, and should be used with
        # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
        # expect the pre-quantized k/v_scale to be loaded along
        # with the model weights.
        self.kv_cache_dtype = kv_cache_dtype
        self._k_scale = 1.0
        self._v_scale = 1.0
        # The quant method is queried after the attributes above are set,
        # since it receives this (partially initialized) layer.
        quant_method = quant_config.get_quant_method(
            self, prefix=prefix) if quant_config else None
        if quant_method is not None:
            assert isinstance(quant_method, BaseKVCacheMethod)
            # TODO (mgoin): kv cache dtype should be specified in the FP8
            # checkpoint config and become the "auto" behavior
            if self.kv_cache_dtype == "fp8_e5m2":
                raise ValueError("fp8_e5m2 kv-cache is not supported with "
                                 "fp8 checkpoints.")
            # If quantization is enabled, we make "k_scale" and "v_scale"
            # parameters so that it can be loaded from the model checkpoint.
            # The k/v_scale will then be converted back to native float32
            # values after weight loading.
            self.quant_method = quant_method
            self.quant_method.create_weights(self)

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()
        attn_backend = get_attn_backend(head_size, dtype, kv_cache_dtype,
                                        block_size, is_attention_free,
                                        blocksparse_params is not None)
        impl_cls = attn_backend.get_impl_cls()
        self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                             alibi_slopes, sliding_window, kv_cache_dtype,
                             blocksparse_params, logits_soft_cap)
        self.num_heads = num_heads
        self.head_size = head_size
        self.num_kv_heads = num_kv_heads
        self.backend = backend_name_to_enum(attn_backend.get_name())

        # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
        # torch.compile works by registering the attention as one giant
        # opaque custom op. For other platforms, we directly call them
        # and let torch.compile handle them.
        self.use_direct_call = not current_platform.is_cuda_alike(
        ) and not current_platform.is_cpu()

        # For some attention backends, we allocate an output tensor before
        # calling the custom op. When piecewise cudagraph is enabled, this
        # makes sure the output tensor is allocated inside the cudagraph.
        self.use_output = self.backend in (_Backend.FLASH_ATTN,
                                           _Backend.FLASH_ATTN_VLLM_V1)

        # Register this layer by name so the custom ops can look it up again
        # from the forward context.
        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self
        self.layer_name = prefix

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        attn_type: str = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Run attention for this layer.

        The dispatch path is chosen by flags set in ``__init__``:

        * ``use_direct_call``: call ``self.impl.forward`` directly.
        * ``use_output``: preallocate the output, view q/k/v as
          ``(-1, heads, head_size)`` and call the
          ``unified_attention_with_output`` custom op, which fills it.
        * otherwise: call the ``unified_attention`` custom op.

        In both custom-op paths ``attn_metadata`` is not passed through; the
        op identifies the layer by ``self.layer_name``.
        """
        if self.use_direct_call:
            return self.impl.forward(query,
                                     key,
                                     value,
                                     kv_cache,
                                     attn_metadata,
                                     self._k_scale,
                                     self._v_scale,
                                     attn_type=attn_type)
        elif self.use_output:
            output = torch.empty_like(query)
            hidden_size = query.size(-1)
            # Reshape the query, key, and value tensors.
            # NOTE(woosuk): We do this outside the custom op to minimize the
            # CPU overheads from the non-CUDA-graph regions.
            query = query.view(-1, self.num_heads, self.head_size)
            output = output.view(-1, self.num_heads, self.head_size)
            if key is not None:
                key = key.view(-1, self.num_kv_heads, self.head_size)
            if value is not None:
                value = value.view(-1, self.num_kv_heads, self.head_size)
            torch.ops.vllm.unified_attention_with_output(
                query, key, value, output, kv_cache, attn_type,
                self.layer_name)
            return output.view(-1, hidden_size)
        else:
            return torch.ops.vllm.unified_attention(query, key, value,
                                                    kv_cache, attn_type,
                                                    self.layer_name)

    def extra_repr(self) -> str:
        """Summarize the backend impl's head config and class for repr()."""
        s = f"head_size={self.impl.head_size}"  # type: ignore
        s += f", num_heads={self.impl.num_heads}"  # type: ignore
        s += f", num_kv_heads={self.impl.num_kv_heads}"  # type: ignore
        s += f", scale={self.impl.scale}"  # type: ignore
        s += f", backend={self.impl.__class__.__name__}"
        return s
class MultiHeadAttention(nn.Module):
    """Multi-headed attention without any cache, used for ViT."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
    ):
        """Configure head layout and pick a cache-free attention backend.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each attention head.
            scale: Scale passed to the attention kernel.
            num_kv_heads: Number of key/value heads; defaults to ``num_heads``.
                Must be positive and divide ``num_heads`` (MQA/GQA).

        Raises:
            ValueError: If ``num_kv_heads`` does not evenly divide
                ``num_heads``.
        """
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads

        if self.num_kv_heads <= 0 or self.num_heads % self.num_kv_heads != 0:
            raise ValueError(
                f"num_heads ({self.num_heads}) must be a multiple of "
                f"num_kv_heads ({self.num_kv_heads}).")
        # How many query heads share each key/value head (1 for plain MHA).
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        dtype = torch.get_default_dtype()
        attn_backend = get_attn_backend(head_size,
                                        dtype,
                                        kv_cache_dtype=None,
                                        block_size=16,
                                        is_attention_free=False)
        attn_backend = backend_name_to_enum(attn_backend.get_name())
        # Flash-attention backends are served by xformers here (see the
        # TODO in forward about FA2 support).
        if attn_backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
            attn_backend = _Backend.XFORMERS

        # forward() only implements xformers and torch SDPA; anything else
        # falls back to SDPA.
        self.attn_backend = attn_backend if attn_backend in {
            _Backend.TORCH_SDPA, _Backend.XFORMERS
        } else _Backend.TORCH_SDPA

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """Input shape: batch_size x seq_len x hidden_size

        ``query`` is viewed as ``(bsz, q_len, num_heads, head_size)`` and
        ``key``/``value`` as ``(bsz, kv_len, num_kv_heads, head_size)``.
        Returns a tensor of shape ``(bsz, q_len, num_heads * head_size)``.
        """
        # TODO(Isotr0py): Use existing backend implementations and support FA2
        bsz, q_len, _ = query.size()
        kv_len = key.size(1)

        query = query.view(bsz, q_len, self.num_heads, self.head_size)
        key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
        value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)

        num_repeat = self.num_queries_per_kv
        if num_repeat > 1:
            # MQA/GQA: the kernel calls below get q and k/v in the same
            # layout, so k/v need as many heads as q. Repeat each KV head for
            # the consecutive query heads that share it.
            key = torch.repeat_interleave(key, num_repeat, dim=2)
            value = torch.repeat_interleave(value, num_repeat, dim=2)

        if self.attn_backend == _Backend.XFORMERS:
            from xformers import ops as xops

            out = xops.memory_efficient_attention_forward(query,
                                                          key,
                                                          value,
                                                          scale=self.scale)
        elif self.attn_backend == _Backend.TORCH_SDPA:
            query, key, value = (x.transpose(1, 2)
                                 for x in (query, key, value))
            out = F.scaled_dot_product_attention(query,
                                                 key,
                                                 value,
                                                 scale=self.scale)
            out = out.transpose(1, 2)
        # reshape, not view: the transposed SDPA output is generally not
        # contiguous, so view() can fail there.
        return out.reshape(bsz, q_len, -1)
def unified_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_type: str,
    layer_name: str,
) -> torch.Tensor:
    """Real implementation of the ``vllm::unified_attention`` custom op.

    Looks up the layer registered under ``layer_name`` in the current forward
    context and returns its backend ``forward`` result. The attention
    metadata comes from the context's ``dynamic_forward_context``.
    """
    ctx: ForwardContext = get_forward_context()
    metadata = ctx.dynamic_forward_context
    layer = ctx.static_forward_context[layer_name]
    backend_impl = layer.impl
    return backend_impl.forward(query, key, value, kv_cache, metadata,
                                layer._k_scale, layer._v_scale,
                                attn_type=attn_type)


def unified_attention_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_type: str,
    layer_name: str,
) -> torch.Tensor:
    """Fake implementation of ``vllm::unified_attention``.

    Returns an uninitialized, contiguous tensor with ``query``'s shape, dtype
    and device. No attention is computed.
    """
    return torch.empty_like(query, memory_format=torch.contiguous_format)


# Register `unified_attention` as the opaque custom op
# torch.ops.vllm.unified_attention (called from Attention.forward). `kv_cache`
# is declared as mutated; `unified_attention_fake` supplies the fake
# implementation, which only allocates the output.
direct_register_custom_op(
    op_name="unified_attention",
    op_func=unified_attention,
    fake_impl=unified_attention_fake,
    mutates_args=["kv_cache"],
    dispatch_key=current_platform.dispatch_key,
)


def unified_attention_with_output(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_type: str,
    layer_name: str,
) -> None:
    """Real implementation of ``vllm::unified_attention_with_output``.

    Same as ``unified_attention``, except that the caller-allocated
    ``output`` tensor is handed to the backend's ``forward``, and nothing
    is returned.
    """
    ctx: ForwardContext = get_forward_context()
    metadata = ctx.dynamic_forward_context
    layer = ctx.static_forward_context[layer_name]
    backend_impl = layer.impl
    backend_impl.forward(query, key, value, kv_cache, metadata,
                         layer._k_scale, layer._v_scale,
                         attn_type=attn_type, output=output)


def unified_attention_with_output_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_type: str,
    layer_name: str,
) -> None:
    """Fake implementation of ``vllm::unified_attention_with_output``.

    The real op returns nothing (it only writes into ``output``), so there
    is nothing to allocate or return here.
    """
    return None


# Register `unified_attention_with_output` as the custom op
# torch.ops.vllm.unified_attention_with_output (called from Attention.forward
# when `use_output` is set). Both `kv_cache` and the preallocated `output`
# are declared as mutated.
direct_register_custom_op(
    op_name="unified_attention_with_output",
    op_func=unified_attention_with_output,
    fake_impl=unified_attention_with_output_fake,
    mutates_args=["kv_cache", "output"],
    dispatch_key=current_platform.dispatch_key,
)