layer.py 12.7 KB
Newer Older
1
"""Attention layer."""
2
from typing import Any, Dict, List, Optional
3
4
5

import torch
import torch.nn as nn
6
import torch.nn.functional as F
7

8
import vllm.envs as envs
9
from vllm.attention import AttentionMetadata, AttentionType
10
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
11
from vllm.config import CacheConfig, get_current_vllm_config
12
from vllm.forward_context import ForwardContext, get_forward_context
13
14
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
15
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
16
from vllm.platforms import _Backend, current_platform
17
from vllm.utils import direct_register_custom_op
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38


class Attention(nn.Module):
    """Attention layer.

    This class takes query, key, and value tensors as input. The input tensors
    can either contain prompt tokens or generation tokens.
    The class does the following:

    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        per_layer_sliding_window: Optional[int] = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        """Build the layer and instantiate its backend implementation.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each attention head.
            scale: Scale factor passed to the backend implementation.
            num_kv_heads: Number of key/value heads; defaults to
                ``num_heads``.
            alibi_slopes: Passed through to the backend implementation.
            cache_config: Source of the model-level sliding window, KV-cache
                dtype, block size, attention-free flag and
                ``calculate_kv_scales``. When ``None`` the defaults are: no
                sliding window, ``"auto"`` dtype, block size 16,
                ``is_attention_free=False`` and
                ``calculate_kv_scales=False``.
            quant_config: Optional quantization config. When it yields a
                quant method for this layer, that method must be a
                ``BaseKVCacheMethod`` and its ``create_weights`` is called
                on this layer.
            blocksparse_params: Passed through to the backend
                implementation; whether it is ``None`` is also passed to
                ``get_attn_backend``.
            logits_soft_cap: Passed through to the backend implementation.
            per_layer_sliding_window: Overrides the model-level
                ``cache_config.sliding_window`` when not ``None``.
            prefix: Unique layer name. It is the key under which this layer
                is stored in ``static_forward_context`` and becomes
                ``self.layer_name``.
            attn_type: Passed through to the backend implementation.

        Raises:
            ValueError: If ``prefix`` is already registered, or if an
                ``fp8_e5m2`` KV cache is combined with a KV-cache quant
                method.
        """
        super().__init__()
        if per_layer_sliding_window is not None:
            # per-layer sliding window
            sliding_window = per_layer_sliding_window
        elif cache_config is not None:
            # model-level sliding window
            sliding_window = cache_config.sliding_window
        else:
            sliding_window = None

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            is_attention_free = cache_config.is_attention_free
            calculate_kv_scales = cache_config.calculate_kv_scales
        else:
            kv_cache_dtype = "auto"
            block_size = 16
            is_attention_free = False
            calculate_kv_scales = False
        if num_kv_heads is None:
            num_kv_heads = num_heads

        # The default k/v_scale is set to 1.0. This is ignored
        # when kv-cache is not fp8, and should be used with
        # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
        # expect the pre-quantized k/v_scale to be loaded along
        # with the model weights.
        self.kv_cache_dtype = kv_cache_dtype
        self.calculate_kv_scales = calculate_kv_scales
        self._k_scale = torch.tensor(1.0, dtype=torch.float32)
        self._v_scale = torch.tensor(1.0, dtype=torch.float32)

        # We also keep the float32 versions of k/v_scale for attention
        # backends that don't support tensors (Flashinfer)
        self._k_scale_float = 1.0
        self._v_scale_float = 1.0

        quant_method = quant_config.get_quant_method(
            self, prefix=prefix) if quant_config else None
        if quant_method is not None:
            assert isinstance(quant_method, BaseKVCacheMethod)
            # TODO (mgoin): kv cache dtype should be specified in the FP8
            # checkpoint config and become the "auto" behavior
            if self.kv_cache_dtype == "fp8_e5m2":
                raise ValueError("fp8_e5m2 kv-cache is not supported with "
                                 "fp8 checkpoints.")
            # If quantization is enabled, we make "k_scale" and "v_scale"
            # parameters so that it can be loaded from the model checkpoint.
            # The k/v_scale will then be converted back to native float32
            # values after weight loading.
            self.quant_method = quant_method
            self.quant_method.create_weights(self)

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()
        attn_backend = get_attn_backend(head_size, dtype, kv_cache_dtype,
                                        block_size, is_attention_free,
                                        blocksparse_params is not None)
        impl_cls = attn_backend.get_impl_cls()
        self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                             alibi_slopes, sliding_window, kv_cache_dtype,
                             blocksparse_params, logits_soft_cap, attn_type)
        self.num_heads = num_heads
        self.head_size = head_size
        self.num_kv_heads = num_kv_heads
        self.sliding_window = sliding_window
        self.backend = backend_name_to_enum(attn_backend.get_name())
        self.dtype = dtype

        # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
        # torch.compile works by registering the attention as one giant
        # opaque custom op. For other platforms, we directly call them
        # and let torch.compile handle them.
        self.use_direct_call = not current_platform.is_cuda_alike(
        ) and not current_platform.is_cpu()

        # When the backend accepts an output buffer, forward() preallocates
        # it and dispatches to the "_with_output" variant.
        self.use_output = attn_backend.accept_output_buffer

        # Fetch the global config once; it is read for both the compilation
        # registry and the pipeline-parallel size below.
        vllm_config = get_current_vllm_config()
        compilation_config = vllm_config.compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self
        self.layer_name = prefix
        self.attn_type = attn_type

        # Placeholder KV cache used during init, with one entry per
        # pipeline-parallel virtual engine; it is replaced by bind_kv_cache.
        # NOTE: it IS read on every path: unified_attention and
        # unified_attention_with_output index it with
        # forward_context.virtual_engine, both when called directly
        # (use_direct_call) and through the torch.ops custom op.
        self.kv_cache = [
            torch.tensor([]) for _ in range(
                vllm_config.parallel_config.pipeline_parallel_size)
        ]

        # Divisors used by calc_kv_scales to turn |k|/|v| maxima into scales.
        self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
        self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Compute attention for this layer.

        Args:
            query: Query tensor. On the output-buffer path it is viewed as
                ``(-1, num_heads, head_size)``.
            key: Key tensor. On the output-buffer path it may be ``None``;
                otherwise it is viewed as ``(-1, num_kv_heads, head_size)``.
            value: Value tensor; handled like ``key``.
            kv_cache: Not read here. The cache the backend receives is
                ``self.kv_cache[forward_context.virtual_engine]``, looked up
                by ``unified_attention[_with_output]``.
            attn_metadata: Only ``enable_kv_scales_calculation`` is read
                here. The backend receives ``forward_context.attn_metadata``
                instead.

        Returns:
            On the output-buffer path, a tensor viewed as
            ``(-1, query.size(-1))``. Otherwise, whatever the backend's
            ``forward`` returns.
        """
        if self.calculate_kv_scales and \
            attn_metadata.enable_kv_scales_calculation:
            self.calc_kv_scales(key, value)
        if self.use_output:
            output = torch.empty_like(query)
            hidden_size = query.size(-1)
            # Reshape the query, key, and value tensors.
            # NOTE(woosuk): We do this outside the custom op to minimize the
            # CPU overheads from the non-CUDA-graph regions.
            query = query.view(-1, self.num_heads, self.head_size)
            output = output.view(-1, self.num_heads, self.head_size)
            if key is not None:
                key = key.view(-1, self.num_kv_heads, self.head_size)
            if value is not None:
                value = value.view(-1, self.num_kv_heads, self.head_size)
            if self.use_direct_call:
                unified_attention_with_output(query, key, value, output,
                                              self.layer_name)
            else:
                torch.ops.vllm.unified_attention_with_output(
                    query, key, value, output, self.layer_name)
            return output.view(-1, hidden_size)
        else:
            if self.use_direct_call:
                return unified_attention(query, key, value, self.layer_name)
            else:
                return torch.ops.vllm.unified_attention(
                    query, key, value, self.layer_name)

    def calc_kv_scales(self, key, value):
        """Derive k/v scales from the absolute maxima of ``key``/``value``.

        Each scale is ``max(|x|) / range``. The ranges are ``self.k_range``
        and ``self.v_range``, which come from ``envs.K_SCALE_CONSTANT`` and
        ``envs.V_SCALE_CONSTANT``. Both the tensor and the Python-float
        copies are updated. ``calculate_kv_scales`` is then cleared, so this
        runs only once.
        """
        self._k_scale.copy_(torch.abs(key).max() / self.k_range)
        self._v_scale.copy_(torch.abs(value).max() / self.v_range)
        self._k_scale_float = self._k_scale.item()
        self._v_scale_float = self._v_scale.item()
        # We only calculate the scales once
        self.calculate_kv_scales = False

    def extra_repr(self) -> str:
        """Summarize the backend implementation's configuration."""
        s = f"head_size={self.impl.head_size}"  # type: ignore
        s += f", num_heads={self.impl.num_heads}"  # type: ignore
        s += f", num_kv_heads={self.impl.num_kv_heads}"  # type: ignore
        s += f", scale={self.impl.scale}"  # type: ignore
        s += f", backend={self.impl.__class__.__name__}"
        return s
195
196


197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
class MultiHeadAttention(nn.Module):
    """Multi-headed attention without any cache, used for ViT.

    Grouped-query and multi-query attention are supported. When
    ``num_kv_heads`` is smaller than ``num_heads``, each key/value head is
    repeated. Every consecutive group of ``num_heads // num_kv_heads`` query
    heads then attends to one shared key/value head.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
    ):
        """Configure head layout and pick an attention backend.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each attention head.
            scale: Scale passed to the attention call in ``forward``.
            num_kv_heads: Number of key/value heads; defaults to
                ``num_heads``.

        Raises:
            ValueError: If ``num_kv_heads`` is not positive or does not
                divide ``num_heads``.
        """
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads

        # Validate before the backend lookup so a bad head layout fails here
        # rather than deep inside the attention kernel at forward time.
        if self.num_kv_heads <= 0 or self.num_heads % self.num_kv_heads != 0:
            raise ValueError(
                f"num_kv_heads ({self.num_kv_heads}) must be positive and "
                f"divide num_heads ({self.num_heads}).")
        # Query heads served by each key/value head (1 means plain MHA).
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        dtype = torch.get_default_dtype()
        attn_backend = get_attn_backend(head_size,
                                        dtype,
                                        kv_cache_dtype=None,
                                        block_size=16,
                                        is_attention_free=False)
        backend = backend_name_to_enum(attn_backend.get_name())
        # FlashAttention backends are not used by this layer; use xformers.
        if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
            backend = _Backend.XFORMERS

        # forward() only implements xformers and torch SDPA; anything else
        # falls back to torch SDPA.
        self.attn_backend = backend if backend in {
            _Backend.TORCH_SDPA,
            _Backend.XFORMERS,
        } else _Backend.TORCH_SDPA

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """Input shape: batch_size x seq_len x hidden_size

        ``query`` must have ``num_heads * head_size`` features, and
        ``key``/``value`` must have ``num_kv_heads * head_size`` features.
        The ``view`` calls below enforce this. The result has shape
        ``(batch_size, q_len, num_heads * head_size)``.
        """
        # TODO(Isotr0py): Use existing backend implementations and support FA3
        bsz, q_len, _ = query.size()
        kv_len = key.size(1)

        query = query.view(bsz, q_len, self.num_heads, self.head_size)
        key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
        value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)

        num_repeat = self.num_queries_per_kv
        if num_repeat > 1:
            # GQA/MQA: neither attention call below is told about the head
            # grouping, so expand key/value to one head per query head.
            # repeat_interleave keeps each kv head's copies adjacent, which
            # maps query head h to kv head h // num_repeat.
            key = torch.repeat_interleave(key, num_repeat, dim=2)
            value = torch.repeat_interleave(value, num_repeat, dim=2)

        if self.attn_backend == _Backend.XFORMERS:
            from xformers import ops as xops

            out = xops.memory_efficient_attention_forward(query,
                                                          key,
                                                          value,
                                                          scale=self.scale)
        elif self.attn_backend == _Backend.TORCH_SDPA:
            # SDPA expects (batch, heads, seq, head_size).
            query, key, value = (x.transpose(1, 2)
                                 for x in (query, key, value))
            out = F.scaled_dot_product_attention(query,
                                                 key,
                                                 value,
                                                 scale=self.scale)
            out = out.transpose(1, 2)
        return out.reshape(bsz, q_len, -1)
259
260


261
262
263
264
265
266
267
def unified_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Run the attention layer named ``layer_name`` and return its output.

    This is the eager implementation registered as the
    ``unified_attention`` custom op. The layer, its attention metadata, and
    the KV cache for the current virtual engine are all looked up from the
    active forward context. The layer is passed by name, presumably so that
    only tensors and strings cross the custom-op boundary.
    """
    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    layer = forward_context.attn_layers[layer_name]
    kv_cache = layer.kv_cache[forward_context.virtual_engine]
    return layer.impl.forward(layer, query, key, value, kv_cache,
                              attn_metadata)
272
273
274
275
276
277
278
279
280
281
282
283
284
285


def unified_attention_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Shape-only stand-in registered as the ``fake_impl`` of the
    ``unified_attention`` custom op.

    Returns an uninitialized, contiguous tensor with the same shape, dtype
    and device as ``query``. ``key``, ``value`` and ``layer_name`` are
    ignored.
    """
    return torch.empty_like(query, memory_format=torch.contiguous_format)


# Expose unified_attention as torch.ops.vllm.unified_attention (the op that
# Attention.forward calls when use_direct_call is False). It declares no
# mutated arguments, and unified_attention_fake supplies its fake (shape-only)
# implementation.
direct_register_custom_op(
    op_name="unified_attention",
    op_func=unified_attention,
    mutates_args=[],
    fake_impl=unified_attention_fake,
    dispatch_key=current_platform.dispatch_key,
)
290
291
292
293
294
295
296
297
298
299


def unified_attention_with_output(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
) -> None:
    """Run the attention layer named ``layer_name``, writing into ``output``.

    This is the eager implementation registered as the
    ``unified_attention_with_output`` custom op. The layer, its attention
    metadata, and the KV cache for the current virtual engine are looked up
    from the active forward context. The backend is given ``output`` as its
    destination buffer, and nothing is returned.
    """
    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    layer = forward_context.attn_layers[layer_name]
    kv_cache = layer.kv_cache[forward_context.virtual_engine]
    layer.impl.forward(layer,
                       query,
                       key,
                       value,
                       kv_cache,
                       attn_metadata,
                       output=output)


def unified_attention_with_output_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
) -> None:
    """No-op stand-in registered as the ``fake_impl`` of the
    ``unified_attention_with_output`` custom op.

    The real op writes its result into ``output`` and returns nothing, so
    there is nothing to allocate or compute here.
    """
    del query, key, value, output, layer_name  # intentionally unused


# Expose unified_attention_with_output as
# torch.ops.vllm.unified_attention_with_output (called by Attention.forward on
# the output-buffer path when use_direct_call is False). "output" is declared
# as mutated because the op writes its result into that buffer in place.
direct_register_custom_op(
    op_name="unified_attention_with_output",
    op_func=unified_attention_with_output,
    mutates_args=["output"],
    fake_impl=unified_attention_with_output_fake,
    dispatch_key=current_platform.dispatch_key,
)