"tests/vscode:/vscode.git/clone" did not exist on "9f14c9224d3d6664e2f5a2e7fecd012fd048fcb1"
layer.py 17.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
"""Attention layer."""
3
from typing import Any, Dict, List, Optional
4
5
6

import torch
import torch.nn as nn
7
import torch.nn.functional as F
8

9
import vllm.envs as envs
10
from vllm.attention import AttentionType
11
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
12
from vllm.config import CacheConfig, get_current_vllm_config
13
14
15
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                          has_kv_transfer_group,
                                          is_v1_kv_transfer_group)
16
from vllm.forward_context import ForwardContext, get_forward_context
17
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
18
19
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
20
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
21
from vllm.platforms import _Backend, current_platform
22
from vllm.utils import direct_register_custom_op
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43


class Attention(nn.Module):
    """Attention layer.

    This class takes query, key, and value tensors as input. The input tensors
    can either contain prompt tokens or generation tokens.
    The class does the following:

    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        per_layer_sliding_window: Optional[int] = None,
        use_mla: bool = False,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        **extra_impl_args,
    ) -> None:
        """
        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        Args:
            num_heads: Number of query heads; forwarded to the backend impl.
            head_size: Size of each head; forwarded to the backend selector
                and the backend impl.
            scale: Attention scale; forwarded to the backend impl.
            num_kv_heads: Number of key/value heads. Defaults to `num_heads`.
            alibi_slopes: Forwarded unchanged to the backend impl.
            cache_config: Source of the KV-cache dtype, block size,
                attention-free flag, kv-scale calculation flag and the
                model-level sliding window. When None, the defaults
                "auto", 16, False, False and no sliding window are used.
            quant_config: If given, its quant method for this layer is looked
                up; any method other than `UnquantizedLinearMethod` must be a
                `BaseKVCacheMethod` and is used to create the scale weights.
            blocksparse_params: Forwarded to the backend impl; whether it is
                set is also passed to the backend selector.
            logits_soft_cap: Forwarded unchanged to the backend impl.
            per_layer_sliding_window: Takes precedence over
                `cache_config.sliding_window` when not None.
            use_mla: Passed to the backend selector; when True, `forward`
                skips reshaping query/key/value.
            prefix: Layer name. Used as the key under which this layer is
                registered in the compilation config's static forward context.
            attn_type: Forwarded unchanged to the backend impl.
            **extra_impl_args: Extra keyword arguments forwarded to the
                backend impl constructor.

        Raises:
            ValueError: If a KV-cache quant method is in use together with an
                "fp8_e5m2" KV-cache dtype, or if `prefix` is already
                registered in the static forward context.
        """
        super().__init__()
        if per_layer_sliding_window is not None:
            # per-layer sliding window
            sliding_window = per_layer_sliding_window
        elif cache_config is not None:
            # model-level sliding window
            sliding_window = cache_config.sliding_window
        else:
            sliding_window = None

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            is_attention_free = cache_config.is_attention_free
            calculate_kv_scales = cache_config.calculate_kv_scales
        else:
            kv_cache_dtype = "auto"
            block_size = 16
            is_attention_free = False
            calculate_kv_scales = False
        if num_kv_heads is None:
            num_kv_heads = num_heads

        # The default k/v_scale is set to 1.0. This is ignored
        # when kv-cache is not fp8, and should be used with
        # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
        # expect the pre-quantized k/v_scale to be loaded along
        # with the model weights.
        self.kv_cache_dtype = kv_cache_dtype
        self.calculate_kv_scales = calculate_kv_scales
        self._k_scale = torch.tensor(1.0, dtype=torch.float32)
        self._v_scale = torch.tensor(1.0, dtype=torch.float32)
        # FlashAttn doesn't support quantizing the kv-cache only
        # but requires q to be quantized as well.
        self._q_scale = torch.tensor(1.0, dtype=torch.float32)

        # We also keep the float32 versions of k/v_scale for attention
        # backends that don't support tensors (Flashinfer)
        self._k_scale_float = 1.0
        self._v_scale_float = 1.0

        self.use_mla = use_mla
        self.num_heads = num_heads
        self.head_size = head_size
        self.num_kv_heads = num_kv_heads
        self.sliding_window = sliding_window

        quant_method = quant_config.get_quant_method(
            self, prefix=prefix) if quant_config else None
        if quant_method is not None and not isinstance(
                quant_method, UnquantizedLinearMethod):
            assert isinstance(quant_method, BaseKVCacheMethod)
            # TODO (mgoin): kv cache dtype should be specified in the FP8
            # checkpoint config and become the "auto" behavior
            if self.kv_cache_dtype == "fp8_e5m2":
                raise ValueError("fp8_e5m2 kv-cache is not supported with "
                                 "fp8 checkpoints.")
            # If quantization is enabled, we make "k_scale" and "v_scale"
            # parameters so that it can be loaded from the model checkpoint.
            # The k/v_scale will then be converted back to native float32
            # values after weight loading.
            self.quant_method = quant_method
            self.quant_method.create_weights(self)

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()
        attn_backend = get_attn_backend(head_size,
                                        dtype,
                                        kv_cache_dtype,
                                        block_size,
                                        is_attention_free,
                                        blocksparse_params is not None,
                                        use_mla=use_mla)
        impl_cls = attn_backend.get_impl_cls()
        self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                             alibi_slopes, sliding_window, kv_cache_dtype,
                             blocksparse_params, logits_soft_cap, attn_type,
                             **extra_impl_args)
        self.backend = backend_name_to_enum(attn_backend.get_name())
        self.dtype = dtype

        # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
        # torch.compile works by registering the attention as one giant
        # opaque custom op. For other platforms, we directly call them
        # and let torch.compile handle them.
        self.use_direct_call = not current_platform.is_cuda_alike(
        ) and not current_platform.is_cpu()

        self.use_output = attn_backend.accept_output_buffer
        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self
        self.layer_name = prefix
        self.attn_type = attn_type
        # use a placeholder kv cache tensor during init, which will be replaced
        # by bind_kv_cache
        # this variable will not be accessed if use_direct_call is True
        self.kv_cache = [
            torch.tensor([]) for _ in range(get_current_vllm_config(
            ).parallel_config.pipeline_parallel_size)
        ]

        self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
        self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
        self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        # For some alternate attention backends like MLA the attention output
        # shape does not match the query shape, so we optionally let the model
        # definition specify the output tensor shape.
        output_shape: Optional[torch.Size] = None,
    ) -> torch.Tensor:
        """
        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        Attention metadata (`attn_metadata`) is set using a context manager in
        the model runner's `execute_model` method. It is accessed via forward
        context using
        `vllm.forward_context.get_forward_context().attn_metadata`.

        Args:
            query: Query tensor.
            key: Key tensor; may be None in the output-buffer path, in which
                case it is passed through without reshaping.
            value: Value tensor; may be None in the output-buffer path, in
                which case it is passed through without reshaping.
            output_shape: Shape of the output buffer allocated when the
                backend accepts one. Defaults to `query.shape`.

        Returns:
            When the backend accepts an output buffer, that buffer viewed as
            `(-1, output_shape[-1])`. Otherwise, whatever the backend impl
            (or the `unified_attention` custom op) returns.
        """
        if self.calculate_kv_scales:
            attn_metadata = get_forward_context().attn_metadata
            # The forward context may carry no attention metadata; skip the
            # scale calculation in that case instead of dereferencing None.
            if (attn_metadata is not None
                    and attn_metadata.enable_kv_scales_calculation):
                self.calc_kv_scales(query, key, value)
        if self.use_output:
            output_shape = (output_shape
                            if output_shape is not None else query.shape)
            output = torch.empty(output_shape,
                                 dtype=query.dtype,
                                 device=query.device)
            hidden_size = output_shape[-1]
            # We skip reshaping query, key and value tensors for the MLA
            # backend since these tensors have different semantics and are
            # processed differently.
            if not self.use_mla:
                # Reshape the query, key, and value tensors.
                # NOTE(woosuk): We do this outside the custom op to minimize the
                # CPU overheads from the non-CUDA-graph regions.
                query = query.view(-1, self.num_heads, self.head_size)
                output = output.view(-1, self.num_heads, self.head_size)
                if key is not None:
                    key = key.view(-1, self.num_kv_heads, self.head_size)
                if value is not None:
                    value = value.view(-1, self.num_kv_heads, self.head_size)
            if self.use_direct_call:
                forward_context: ForwardContext = get_forward_context()
                attn_metadata = forward_context.attn_metadata
                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                self.impl.forward(self,
                                  query,
                                  key,
                                  value,
                                  self_kv_cache,
                                  attn_metadata,
                                  output=output)
            else:
                torch.ops.vllm.unified_attention_with_output(
                    query, key, value, output, self.layer_name)
            return output.view(-1, hidden_size)
        else:
            if self.use_direct_call:
                forward_context = get_forward_context()
                attn_metadata = forward_context.attn_metadata
                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                return self.impl.forward(self, query, key, value,
                                         self_kv_cache, attn_metadata)
            else:
                return torch.ops.vllm.unified_attention(
                    query, key, value, self.layer_name)

    def calc_kv_scales(self, query, key, value):
        """Compute the q/k/v scales from the given tensors, once.

        Each scale is the maximum absolute value of the tensor divided by the
        matching range constant, copied in place into `_q_scale`, `_k_scale`
        and `_v_scale`. The k/v scales are also cached as Python floats, and
        `calculate_kv_scales` is cleared so `forward` does not call this again.
        """
        self._q_scale.copy_(torch.abs(query).max() / self.q_range)
        self._k_scale.copy_(torch.abs(key).max() / self.k_range)
        self._v_scale.copy_(torch.abs(value).max() / self.v_range)
        self._k_scale_float = self._k_scale.item()
        self._v_scale_float = self._v_scale.item()
        # We only calculate the scales once
        self.calculate_kv_scales = False

    def extra_repr(self) -> str:
        """Return a summary of the backend impl's settings for `repr()`."""
        s = f"head_size={self.impl.head_size}"  # type: ignore
        s += f", num_heads={self.impl.num_heads}"  # type: ignore
        s += f", num_kv_heads={self.impl.num_kv_heads}"  # type: ignore
        s += f", scale={self.impl.scale}"  # type: ignore
        s += f", backend={self.impl.__class__.__name__}"
        return s

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Delegate post-load weight processing to the impl, if it has any."""
        if hasattr(self.impl, "process_weights_after_loading"):
            self.impl.process_weights_after_loading(act_dtype)
255

256

257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
class MultiHeadAttention(nn.Module):
    """Multi-headed attention without any cache, used for ViT."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
    ):
        """Select the attention backend and record the head layout.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each head.
            scale: Scale passed to the attention kernel.
            num_kv_heads: Number of key/value heads. Defaults to `num_heads`;
                must divide `num_heads` evenly.
        """
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        dtype = torch.get_default_dtype()
        attn_backend = get_attn_backend(head_size,
                                        dtype,
                                        kv_cache_dtype=None,
                                        block_size=16,
                                        is_attention_free=False)
        backend = backend_name_to_enum(attn_backend.get_name())
        # Flash-attention backends are mapped to xFormers for this layer.
        if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
            backend = _Backend.XFORMERS

        # Only these three backends are handled in `forward`; anything else
        # falls back to torch SDPA.
        self.attn_backend = backend if backend in {
            _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1
        } else _Backend.TORCH_SDPA

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """Input shape: batch_size x seq_len x hidden_size

        Returns the attention output reshaped to
        batch_size x q_len x (remaining dims flattened).
        """
        # TODO(Isotr0py): Use existing backend implementations and support FA3
        bsz, q_len, _ = query.size()
        kv_len = key.size(1)

        query = query.view(bsz, q_len, self.num_heads, self.head_size)
        key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
        value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)

        if (num_repeat := self.num_queries_per_kv) > 1:
            # Handle MQA and GQA
            key = torch.repeat_interleave(key, num_repeat, dim=2)
            value = torch.repeat_interleave(value, num_repeat, dim=2)

        if self.attn_backend == _Backend.XFORMERS:
            from xformers import ops as xops

            out = xops.memory_efficient_attention_forward(query,
                                                          key,
                                                          value,
                                                          scale=self.scale)
        elif self.attn_backend == _Backend.TORCH_SDPA:
            # SDPA is called with the head and sequence axes swapped.
            query, key, value = (x.transpose(1, 2)
                                 for x in (query, key, value))
            out = F.scaled_dot_product_attention(query,
                                                 key,
                                                 value,
                                                 scale=self.scale)
            out = out.transpose(1, 2)
        elif self.attn_backend == _Backend.PALLAS_VLLM_V1:
            query, key, value = (x.transpose(1, 2)
                                 for x in (query, key, value))
            from torch_xla.experimental.custom_kernel import flash_attention
            out = flash_attention(query, key, value, sm_scale=self.scale)
            out = out.transpose(1, 2)

        return out.reshape(bsz, q_len, -1)
333
334


335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
def wait_for_kv_layer_from_connector(layer_name: str):
    """Call the KV connector's `wait_for_layer_load` for `layer_name`.

    Does nothing unless a KV transfer group exists, it is a v1 group, and
    the current forward context has attention metadata.
    """
    kv_transfer_active = (has_kv_transfer_group()
                          and is_v1_kv_transfer_group())
    if not kv_transfer_active:
        return

    kv_connector = get_kv_transfer_group()

    ctx: ForwardContext = get_forward_context()
    if ctx.attn_metadata is None:
        # No attention metadata on the forward context: nothing to wait for.
        return

    kv_connector.wait_for_layer_load(layer_name)


def maybe_save_kv_layer_to_connector(
    layer_name: str,
    kv_cache_layer: List[torch.Tensor],
):
    """Hand `kv_cache_layer` to the KV connector's `save_kv_layer`.

    Does nothing unless a KV transfer group exists, it is a v1 group, and
    the current forward context has attention metadata. The metadata is
    passed to the connector together with the layer name and cache.
    """
    if not (has_kv_transfer_group() and is_v1_kv_transfer_group()):
        return

    kv_connector = get_kv_transfer_group()

    ctx: ForwardContext = get_forward_context()
    metadata = ctx.attn_metadata
    if metadata is None:
        return

    kv_connector.save_kv_layer(layer_name, kv_cache_layer, metadata)


366
367
368
369
370
371
def unified_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Run attention for the layer registered under `layer_name`.

    The layer is looked up by name in the current forward context, and its
    backend impl is invoked with the KV cache of the current virtual engine.
    The KV connector hooks are called before and after the impl.

    Args:
        query: Query tensor passed to the backend impl.
        key: Key tensor passed to the backend impl.
        value: Value tensor passed to the backend impl.
        layer_name: Key of the attention layer in
            `forward_context.no_compile_layers`.

    Returns:
        The value returned by the backend impl's `forward`.
    """
    wait_for_kv_layer_from_connector(layer_name)

    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    self = forward_context.no_compile_layers[layer_name]
    kv_cache = self.kv_cache[forward_context.virtual_engine]
    output = self.impl.forward(self, query, key, value, kv_cache,
                               attn_metadata)

    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
    return output
383
384
385
386
387
388
389
390
391
392
393
394
395
396


def unified_attention_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Fake implementation of `unified_attention`.

    Returns an uninitialized, contiguous tensor with the same shape, dtype
    and device as `query`. `key`, `value` and `layer_name` are unused.
    """
    placeholder = torch.empty_like(query)
    return placeholder.contiguous()


# Register `unified_attention` as a custom op. It mutates none of its
# arguments; `unified_attention_fake` is supplied as its fake implementation.
direct_register_custom_op(
    op_name="unified_attention",
    op_func=unified_attention,
    mutates_args=[],
    fake_impl=unified_attention_fake,
    dispatch_key=current_platform.dispatch_key,
)
401
402
403
404
405
406
407
408
409


def unified_attention_with_output(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
) -> None:
    """Run attention for `layer_name`, passing `output` to the backend impl.

    The layer is looked up by name in the current forward context, and its
    backend impl is invoked with the KV cache of the current virtual engine
    and `output=output`. The KV connector hooks are called before and after
    the impl. Nothing is returned.

    Args:
        query: Query tensor passed to the backend impl.
        key: Key tensor passed to the backend impl.
        value: Value tensor passed to the backend impl.
        output: Output buffer handed to the backend impl.
        layer_name: Key of the attention layer in
            `forward_context.no_compile_layers`.
    """
    wait_for_kv_layer_from_connector(layer_name)
    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    self = forward_context.no_compile_layers[layer_name]
    kv_cache = self.kv_cache[forward_context.virtual_engine]
    self.impl.forward(self,
                      query,
                      key,
                      value,
                      kv_cache,
                      attn_metadata,
                      output=output)

    maybe_save_kv_layer_to_connector(layer_name, kv_cache)

425
426
427
428
429
430
431
432
433
434
435
436
437
438

def unified_attention_with_output_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
) -> None:
    """Fake implementation of `unified_attention_with_output`.

    Does nothing: no argument is read or modified, and None is returned.
    """
    return None


# Register `unified_attention_with_output` as a custom op. It is declared as
# mutating its `output` argument; `unified_attention_with_output_fake` is
# supplied as its fake implementation.
direct_register_custom_op(
    op_name="unified_attention_with_output",
    op_func=unified_attention_with_output,
    mutates_args=["output"],
    fake_impl=unified_attention_with_output_fake,
    dispatch_key=current_platform.dispatch_key,
)