layer.py 20 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Attention layer."""
4
from typing import Any, Dict, List, Optional
5
6
7

import torch
import torch.nn as nn
8
import torch.nn.functional as F
9

10
from vllm.two_batch_overlap.v1.two_batch_overlap_v1 import tbo_maybe_save_kv_layer_to_connector
11
import vllm.envs as envs
12
from vllm.attention import AttentionType
13
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
14
from vllm.config import CacheConfig, get_current_vllm_config
15
16
17
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                          has_kv_transfer_group,
                                          is_v1_kv_transfer_group)
18
from vllm.forward_context import ForwardContext, get_forward_context
19
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
20
21
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
22
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
23
from vllm.platforms import _Backend, current_platform
24
from vllm.utils import direct_register_custom_op
25
from vllm.v1.attention.backends.utils import validate_kv_sharing_target
26

27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46

class Attention(nn.Module):
    """Attention layer.

    This class takes query, key, and value tensors as input. The input tensors
    can either contain prompt tokens or generation tokens.
    The class does the following:

    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        per_layer_sliding_window: Optional[int] = None,
        use_mla: bool = False,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        **extra_impl_args,
    ) -> None:
        """
        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        The layer registers itself in the current vLLM config's
        `compilation_config.static_forward_context` under `prefix`, so
        `prefix` must be unique per layer.

        Raises:
            ValueError: if `prefix` is already registered, or if an fp8
                checkpoint quantization method is combined with an
                `fp8_e5m2` KV cache.
            NotImplementedError: if `kv_sharing_target_layer_name` is given
                while `envs.VLLM_USE_V1` is falsy.
        """
        super().__init__()
        if per_layer_sliding_window is not None:
            # per-layer sliding window
            sliding_window = per_layer_sliding_window
        elif cache_config is not None:
            # model-level sliding window
            sliding_window = cache_config.sliding_window
        else:
            sliding_window = None

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            is_attention_free = cache_config.is_attention_free
            calculate_kv_scales = cache_config.calculate_kv_scales
        else:
            # No cache config: fall back to defaults.
            kv_cache_dtype = "auto"
            block_size = (64 if envs.VLLM_USE_FLASH_ATTN_PA
                          and envs.VLLM_USE_FLASH_MLA else 16)
            is_attention_free = False
            calculate_kv_scales = False
        if num_kv_heads is None:
            num_kv_heads = num_heads
        assert num_heads % num_kv_heads == 0, \
            f"num_heads ({num_heads}) is not " \
            f"divisible by num_kv_heads ({num_kv_heads})"

        # The default k/v_scale is set to 1.0. This is ignored
        # when kv-cache is not fp8, and should be used with
        # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
        # expect the pre-quantized k/v_scale to be loaded along
        # with the model weights.
        self.kv_cache_dtype = kv_cache_dtype
        self.calculate_kv_scales = calculate_kv_scales
        self._k_scale = torch.tensor(1.0, dtype=torch.float32)
        self._v_scale = torch.tensor(1.0, dtype=torch.float32)
        # FlashAttn doesn't support quantizing the kv-cache only
        # but requires q to be quantized as well.
        self._q_scale = torch.tensor(1.0, dtype=torch.float32)
        self._prob_scale = torch.tensor(1.0, dtype=torch.float32)

        # We also keep the float32 versions of k/v_scale for attention
        # backends that don't support tensors (Flashinfer)
        self._k_scale_float = 1.0
        self._v_scale_float = 1.0

        self.use_mla = use_mla
        self.num_heads = num_heads
        self.head_size = head_size
        self.num_kv_heads = num_kv_heads
        self.sliding_window = sliding_window

        quant_method = quant_config.get_quant_method(
            self, prefix=prefix) if quant_config else None
        if quant_method is not None and not isinstance(
                quant_method, UnquantizedLinearMethod):
            assert isinstance(quant_method, BaseKVCacheMethod)
            # TODO (mgoin): kv cache dtype should be specified in the FP8
            # checkpoint config and become the "auto" behavior
            if self.kv_cache_dtype == "fp8_e5m2":
                raise ValueError("fp8_e5m2 kv-cache is not supported with "
                                 "fp8 checkpoints.")
            # If quantization is enabled, we make "k_scale" and "v_scale"
            # parameters so that it can be loaded from the model checkpoint.
            # The k/v_scale will then be converted back to native float32
            # values after weight loading.
            self.quant_method = quant_method
            self.quant_method.create_weights(self)

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()
        attn_backend = get_attn_backend(head_size,
                                        dtype,
                                        kv_cache_dtype,
                                        block_size,
                                        is_attention_free,
                                        blocksparse_params is not None,
                                        use_mla=use_mla)
        impl_cls = attn_backend.get_impl_cls()
        self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                             alibi_slopes, sliding_window, kv_cache_dtype,
                             blocksparse_params, logits_soft_cap, attn_type,
                             kv_sharing_target_layer_name, **extra_impl_args)
        self.backend = backend_name_to_enum(attn_backend.get_name())
        self.dtype = dtype

        # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
        # torch.compile works by registering the attention as one giant
        # opaque custom op. For other platforms, we directly call them
        # and let torch.compile handle them.
        self.use_direct_call = not current_platform.is_cuda_alike(
        ) and not current_platform.is_cpu()

        self.use_output = attn_backend.accept_output_buffer
        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self
        self.layer_name = prefix
        self.attn_type = attn_type

        if kv_sharing_target_layer_name is not None:
            if not envs.VLLM_USE_V1:
                raise NotImplementedError(
                    "Cross-layer KV sharing is not supported in V0.")

            validate_kv_sharing_target(
                prefix,
                kv_sharing_target_layer_name,
                compilation_config.static_forward_context,
            )
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        # use a placeholder kv cache tensor during init, which will be replaced
        # by bind_kv_cache
        # this variable will not be accessed if use_direct_call is True
        self.kv_cache = [
            torch.tensor([]) for _ in range(get_current_vllm_config(
            ).parallel_config.pipeline_parallel_size)
        ]

        self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
        self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
        self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        # For some alternate attention backends like MLA the attention output
        # shape does not match the query shape, so we optionally let the model
        # definition specify the output tensor shape.
        output_shape: Optional[torch.Size] = None,
    ) -> torch.Tensor:
        """
        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        Attention metadata (`attn_metadata`) is set using a context manager in
        the model runner's `execute_model` method. It is accessed via forward
        context using
        `vllm.forward_context.get_forward_context().attn_metadata`.

        Returns:
            When the backend accepts an output buffer, a zero-initialised
            buffer filled by the backend and viewed as
            `(-1, output_shape[-1])`; otherwise whatever the backend
            implementation (or the `unified_attention` custom op) returns.
        """
        if self.calculate_kv_scales:
            attn_metadata = get_forward_context().attn_metadata
            # The forward context may carry one metadata object per layer
            # (a dict keyed by layer name). Resolve this layer's entry before
            # reading the flag; otherwise the flag lookup on the dict itself
            # is always False and scales are never calculated.
            if isinstance(attn_metadata, dict):
                attn_metadata = attn_metadata.get(self.layer_name)
            if attn_metadata is not None and getattr(
                    attn_metadata, "enable_kv_scales_calculation", False):
                self.calc_kv_scales(query, key, value)
        if self.use_output:
            output_shape = (output_shape
                            if output_shape is not None else query.shape)
            output = torch.zeros(output_shape,
                                 dtype=query.dtype,
                                 device=query.device)
            hidden_size = output_shape[-1]
            # We skip reshaping query, key and value tensors for the MLA
            # backend since these tensors have different semantics and are
            # processed differently.
            if not self.use_mla:
                # Reshape the query, key, and value tensors.
                # NOTE(woosuk): We do this outside the custom op to minimize the
                # CPU overheads from the non-CUDA-graph regions.
                query = query.view(-1, self.num_heads, self.head_size)
                output = output.view(-1, self.num_heads, self.head_size)
                if key is not None:
                    key = key.view(-1, self.num_kv_heads, self.head_size)
                if value is not None:
                    value = value.view(-1, self.num_kv_heads, self.head_size)
            if self.use_direct_call:
                forward_context: ForwardContext = get_forward_context()
                attn_metadata = forward_context.attn_metadata
                if isinstance(attn_metadata, dict):
                    attn_metadata = attn_metadata[self.layer_name]
                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                self.impl.forward(self,
                                  query,
                                  key,
                                  value,
                                  self_kv_cache,
                                  attn_metadata,
                                  output=output)
            else:
                torch.ops.vllm.unified_attention_with_output(
                    query, key, value, output, self.layer_name)
            return output.view(-1, hidden_size)
        else:
            if self.use_direct_call:
                forward_context = get_forward_context()
                attn_metadata = forward_context.attn_metadata
                if isinstance(attn_metadata, dict):
                    attn_metadata = attn_metadata[self.layer_name]
                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                return self.impl.forward(self, query, key, value,
                                         self_kv_cache, attn_metadata)
            else:
                return torch.ops.vllm.unified_attention(
                    query, key, value, self.layer_name)

    def calc_kv_scales(self, query, key, value):
        """Set q/k/v scales from the max absolute value of each input.

        Each scale is `abs(x).max()` divided by the matching `*_range`
        constant. The float copies of the k/v scales are refreshed and
        `calculate_kv_scales` is cleared so this runs only once.
        """
        self._q_scale.copy_(torch.abs(query).max() / self.q_range)
        self._k_scale.copy_(torch.abs(key).max() / self.k_range)
        self._v_scale.copy_(torch.abs(value).max() / self.v_range)
        self._k_scale_float = self._k_scale.item()
        self._v_scale_float = self._v_scale.item()
        # We only calculate the scales once
        self.calculate_kv_scales = False

    def extra_repr(self) -> str:
        """Return a summary of the backend implementation's configuration."""
        s = f"head_size={self.impl.head_size}"  # type: ignore
        s += f", num_heads={self.impl.num_heads}"  # type: ignore
        s += f", num_kv_heads={self.impl.num_kv_heads}"  # type: ignore
        s += f", scale={self.impl.scale}"  # type: ignore
        s += f", backend={self.impl.__class__.__name__}"
        return s

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Forward the post-load hook to the backend impl if it defines one."""
        if hasattr(self.impl, "process_weights_after_loading"):
            self.impl.process_weights_after_loading(act_dtype)
281

282

283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
class MultiHeadAttention(nn.Module):
    """Multi-headed attention without any cache, used for ViT."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
    ):
        """Record the head layout and pick the attention backend.

        `num_kv_heads` defaults to `num_heads`; `num_heads` must be divisible
        by it. The chosen backend is always one of TORCH_SDPA, XFORMERS or
        PALLAS_VLLM_V1.
        """
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads

        assert self.num_heads % self.num_kv_heads == 0, \
            f"num_heads ({self.num_heads}) is not " \
            f"divisible by num_kv_heads ({self.num_kv_heads})"
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        dtype = torch.get_default_dtype()
        attn_backend = get_attn_backend(
            head_size,
            dtype,
            kv_cache_dtype=None,
            block_size=(64 if envs.VLLM_USE_FLASH_ATTN_PA
                        and envs.VLLM_USE_FLASH_MLA else 16),
            is_attention_free=False)
        backend = backend_name_to_enum(attn_backend.get_name())
        if current_platform.is_rocm():
            # currently, only torch_sdpa is supported on rocm
            self.attn_backend = _Backend.TORCH_SDPA
        else:
            # Flash-attention style backends are mapped to xformers here.
            if backend in (_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1,
                           _Backend.FLEX_ATTENTION):
                backend = _Backend.XFORMERS

            # Anything outside the supported set falls back to torch SDPA.
            self.attn_backend = backend if backend in {
                _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1
            } else _Backend.TORCH_SDPA

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """Input shape: batch_size x seq_len x hidden_size

        Returns a tensor reshaped to `(batch_size, q_len, -1)`.

        Raises:
            NotImplementedError: if `self.attn_backend` is not one of the
                backends handled below.
        """
        # TODO(Isotr0py): Use existing backend implementations and support FA3
        bsz, q_len, _ = query.size()
        kv_len = key.size(1)

        query = query.view(bsz, q_len, self.num_heads, self.head_size)
        key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
        value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)

        if (num_repeat := self.num_queries_per_kv) > 1:
            # Handle MQA and GQA
            key = torch.repeat_interleave(key, num_repeat, dim=2)
            value = torch.repeat_interleave(value, num_repeat, dim=2)

        if self.attn_backend == _Backend.XFORMERS:
            from xformers import ops as xops

            out = xops.memory_efficient_attention_forward(query,
                                                          key,
                                                          value,
                                                          scale=self.scale)
        elif self.attn_backend == _Backend.TORCH_SDPA:
            # Swap the sequence and head dims for the call, then swap back.
            query, key, value = (x.transpose(1, 2)
                                 for x in (query, key, value))
            out = F.scaled_dot_product_attention(query,
                                                 key,
                                                 value,
                                                 scale=self.scale)
            out = out.transpose(1, 2)
        elif self.attn_backend == _Backend.PALLAS_VLLM_V1:
            query, key, value = (x.transpose(1, 2)
                                 for x in (query, key, value))
            from torch_xla.experimental.custom_kernel import flash_attention
            out = flash_attention(query, key, value, sm_scale=self.scale)
            out = out.transpose(1, 2)
        else:
            # __init__ only selects the backends above; fail loudly instead
            # of hitting an unbound `out` if that ever changes.
            raise NotImplementedError(
                f"Unsupported attention backend: {self.attn_backend}")

        return out.reshape(bsz, q_len, -1)
366
367


368
369
370
371
372
373
374
375
376
377
def wait_for_kv_layer_from_connector(layer_name: str):
    """Wait on the KV connector until `layer_name`'s load completes.

    Does nothing when there is no KV transfer group, when the group is not a
    V1 group, or when the forward context has no attention metadata.
    """
    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
        return

    connector = get_kv_transfer_group()

    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if attn_metadata is None:
        return
    # On this path the metadata is expected to be a per-layer dict.
    assert isinstance(attn_metadata, dict)
    connector.wait_for_layer_load(layer_name)


def maybe_save_kv_layer_to_connector(
    layer_name: str,
    kv_cache_layer: List[torch.Tensor],
):
    """Hand `layer_name`'s KV cache to the KV connector for saving.

    Does nothing when there is no KV transfer group, when the group is not a
    V1 group, or when the forward context has no attention metadata.
    Otherwise the connector receives the layer name, the cache and that
    layer's entry of the per-layer attention metadata dict.
    """
    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
        return

    connector = get_kv_transfer_group()

    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if attn_metadata is None:
        return
    # On this path the metadata is expected to be a per-layer dict.
    assert isinstance(attn_metadata, dict)
    connector.save_kv_layer(layer_name, kv_cache_layer,
                            attn_metadata[layer_name])
398

399

400
401
402
403
404
405
def unified_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Run the attention layer registered as `layer_name`; return its output.

    The layer object, its KV cache and its attention metadata are looked up
    from the current forward context. The KV connector is waited on before
    the backend call and offered the layer's KV cache afterwards.
    """
    wait_for_kv_layer_from_connector(layer_name)

    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    # Metadata may be a single object or a dict keyed by layer name.
    if isinstance(attn_metadata, dict):
        attn_metadata = attn_metadata[layer_name]
    self = forward_context.no_compile_layers[layer_name]
    kv_cache = self.kv_cache[forward_context.virtual_engine]
    output = self.impl.forward(self, query, key, value, kv_cache,
                               attn_metadata)

    # With VLLM_ENABLE_TBO set, the TBO variant of the save hook is used.
    if envs.VLLM_ENABLE_TBO:
        tbo_maybe_save_kv_layer_to_connector(layer_name, kv_cache)
    else:
        maybe_save_kv_layer_to_connector(layer_name, kv_cache)
    return output
422
423
424
425
426
427
428
429
430
431
432
433
434
435


def unified_attention_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Fake implementation of `unified_attention`.

    Performs no attention: returns an uninitialised, contiguous tensor with
    the shape and dtype of `query`. `key`, `value` and `layer_name` are
    unused.
    """
    placeholder = torch.empty_like(query)
    return placeholder.contiguous()


# Register `unified_attention` as a custom op, paired with its fake
# implementation. It declares no mutated arguments.
direct_register_custom_op(
    op_name="unified_attention",
    op_func=unified_attention,
    mutates_args=[],
    fake_impl=unified_attention_fake,
    dispatch_key=current_platform.dispatch_key,
)
440
441
442
443
444
445
446
447


def unified_attention_with_output(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: Optional[torch.Tensor] = None,
) -> None:
    """Run the attention layer `layer_name`, writing into `output`.

    The layer object, its KV cache and its attention metadata are looked up
    from the current forward context. `output` and `output_scale` are passed
    through to the backend implementation; nothing is returned. The KV
    connector is waited on before the backend call and offered the layer's
    KV cache afterwards.
    """
    wait_for_kv_layer_from_connector(layer_name)
    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    # Metadata may be a single object or a dict keyed by layer name.
    if isinstance(attn_metadata, dict):
        attn_metadata = attn_metadata[layer_name]
    self = forward_context.no_compile_layers[layer_name]
    kv_cache = self.kv_cache[forward_context.virtual_engine]
    self.impl.forward(self,
                      query,
                      key,
                      value,
                      kv_cache,
                      attn_metadata,
                      output=output,
                      output_scale=output_scale)

    # With VLLM_ENABLE_TBO set, the TBO variant of the save hook is used.
    if envs.VLLM_ENABLE_TBO:
        tbo_maybe_save_kv_layer_to_connector(layer_name, kv_cache)
    else:
        maybe_save_kv_layer_to_connector(layer_name, kv_cache)
469

470
471
472
473
474
475
476

def unified_attention_with_output_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: Optional[torch.Tensor] = None,
) -> None:
    """Fake implementation of `unified_attention_with_output`.

    Takes the same arguments as the real op, does no work, leaves `output`
    untouched and returns `None`.
    """
    return


# Register `unified_attention_with_output` as a custom op, paired with its
# fake implementation. `output` is declared as mutated because the op
# writes its result into that buffer.
direct_register_custom_op(
    op_name="unified_attention_with_output",
    op_func=unified_attention_with_output,
    mutates_args=["output"],
    fake_impl=unified_attention_with_output_fake,
    dispatch_key=current_platform.dispatch_key,
)