layer.py 22.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Attention layer."""
4
from typing import List, Optional
5
6
7

import torch
import torch.nn as nn
8
import torch.nn.functional as F
9

10
import vllm.envs as envs
11
from vllm.attention import AttentionType
12
from vllm.attention.backends.abstract import AttentionBackend
13
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
14
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
15
from vllm.config import CacheConfig, get_current_vllm_config
16
17
18
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                          has_kv_transfer_group,
                                          is_v1_kv_transfer_group)
19
from vllm.forward_context import ForwardContext, get_forward_context
20
from vllm.logger import init_logger
21
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
22
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
23
24
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
25
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
26
from vllm.platforms import _Backend, current_platform
27
from vllm.utils import direct_register_custom_op
28

29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
logger = init_logger(__name__)
# Cached result of check_xformers_availability(); None until the first call.
USE_XFORMERS_OPS: Optional[bool] = None


def check_xformers_availability() -> bool:
    """Return whether the xformers ops can be used on this platform.

    The answer is computed once and cached in the module-level
    ``USE_XFORMERS_OPS`` flag, so the fallback warning below is emitted at
    most once per process.

    xformers is reported as unavailable on CUDA devices for which
    ``current_platform.has_device_capability(100)`` holds (B200), and
    whenever the ``xformers.ops`` module cannot be located.
    """
    global USE_XFORMERS_OPS
    if USE_XFORMERS_OPS is not None:
        return USE_XFORMERS_OPS

    if current_platform.is_cuda() and current_platform.has_device_capability(
            100):
        # Xformers FA is not compatible with B200
        USE_XFORMERS_OPS = False
    else:
        from importlib.util import find_spec

        # find_spec() returns None (it does not raise) when the submodule
        # is missing; it raises ModuleNotFoundError (an ImportError) only
        # when the parent package ``xformers`` cannot be found, and
        # ValueError when a module in sys.modules has no usable __spec__.
        try:
            USE_XFORMERS_OPS = find_spec("xformers.ops") is not None
        except (ImportError, ValueError):
            USE_XFORMERS_OPS = False

    # the warning only needs to be shown once
    if not USE_XFORMERS_OPS:
        logger.warning("Xformers is not available, falling back.")

    return USE_XFORMERS_OPS

57

58
class Attention(nn.Module, AttentionLayerBase):
    """Attention layer.

    This class takes query, key, and value tensors as input. The input tensors
    can either contain prompt tokens or generation tokens.
    The class does the following:

    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        logits_soft_cap: Optional[float] = None,
        per_layer_sliding_window: Optional[int] = None,
        use_mla: bool = False,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        attn_backend: Optional[type[AttentionBackend]] = None,
        **extra_impl_args,
    ) -> None:
        """
        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each attention head.
            scale: Attention scale, forwarded to the backend implementation.
            num_kv_heads: Number of key/value heads. Defaults to `num_heads`
                and must evenly divide it.
            alibi_slopes: Optional ALiBi slopes, forwarded to the backend.
            cache_config: Supplies the KV-cache dtype, block size, model-level
                sliding window, attention-free flag and whether KV scales are
                calculated at runtime. When None, "auto" dtype, block size 16,
                no sliding window and no runtime scale calculation are used.
            quant_config: Optional quantization config. Any returned method
                that is not `UnquantizedLinearMethod` must be a
                `BaseKVCacheMethod` and creates its weights on this layer.
            logits_soft_cap: Optional soft cap, forwarded to the backend.
            per_layer_sliding_window: When given, overrides the model-level
                sliding window taken from `cache_config`.
            use_mla: Whether this layer uses MLA. In `forward`, MLA inputs
                are not reshaped to (tokens, heads, head_size).
            prefix: Layer name, registered in the compilation config's
                `static_forward_context`. It must be unique.
            attn_type: Attention type, forwarded to the backend.
            kv_sharing_target_layer_name: Optional name of another layer whose
                KV cache this layer shares. It is validated here.
            attn_backend: Explicit backend class. When None, one is selected
                with `get_attn_backend`.
            **extra_impl_args: Forwarded to the backend implementation. A
                non-None `sinks` entry sets `has_sink`.

        Raises:
            ValueError: If an "fp8_e5m2" KV cache is combined with a quantized
                checkpoint, or if `prefix` is already registered.
        """
        super().__init__()

        if per_layer_sliding_window is not None:
            # per-layer sliding window
            sliding_window = per_layer_sliding_window
        elif cache_config is not None:
            # model-level sliding window
            sliding_window = cache_config.sliding_window
        else:
            sliding_window = None

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            is_attention_free = cache_config.is_attention_free
            calculate_kv_scales = cache_config.calculate_kv_scales
        else:
            kv_cache_dtype = "auto"
            block_size = 16
            is_attention_free = False
            calculate_kv_scales = False

        if num_kv_heads is None:
            num_kv_heads = num_heads
        assert num_heads % num_kv_heads == 0, \
            f"num_heads ({num_heads}) is not " \
            f"divisible by num_kv_heads ({num_kv_heads})"

        # The default k/v_scale is set to 1.0. This is ignored
        # when kv-cache is not fp8, and should be used with
        # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
        # expect the pre-quantized k/v_scale to be loaded along
        # with the model weights.
        self.kv_cache_dtype = kv_cache_dtype
        self.calculate_kv_scales = calculate_kv_scales
        self._k_scale = torch.tensor(1.0, dtype=torch.float32)
        self._v_scale = torch.tensor(1.0, dtype=torch.float32)
        # FlashAttn doesn't support quantizing the kv-cache only
        # but requires q to be quantized as well.
        self._q_scale = torch.tensor(1.0, dtype=torch.float32)
        self._prob_scale = torch.tensor(1.0, dtype=torch.float32)

        # We also keep q/k/v_scale on host (cpu) memory for attention
        # backends that require the scales to be on host instead of on device.
        # e.g. Flashinfer
        self._q_scale_float = 1.0
        self._k_scale_float = 1.0
        self._v_scale_float = 1.0

        # The output scale on host memory. This should be the input scale of
        # the quant op after this attention layer.
        self._o_scale_float: Optional[float] = None

        self.use_mla = use_mla
        self.num_heads = num_heads
        self.head_size = head_size
        self.num_kv_heads = num_kv_heads
        self.sliding_window = sliding_window
        self.has_sink = extra_impl_args.get("sinks") is not None

        quant_method = quant_config.get_quant_method(
            self, prefix=prefix) if quant_config else None
        if quant_method is not None and not isinstance(
                quant_method, UnquantizedLinearMethod):
            assert isinstance(quant_method, BaseKVCacheMethod)
            # TODO (mgoin): kv cache dtype should be specified in the FP8
            # checkpoint config and become the "auto" behavior
            if self.kv_cache_dtype == "fp8_e5m2":
                raise ValueError("fp8_e5m2 kv-cache is not supported with "
                                 "fp8 checkpoints.")
            # If quantization is enabled, we make "k_scale" and "v_scale"
            # parameters so that it can be loaded from the model checkpoint.
            # The k/v_scale will then be converted back to native float32
            # values after weight loading.
            self.quant_method = quant_method
            self.quant_method.create_weights(self)

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()
        if attn_backend is None:
            self.attn_backend = get_attn_backend(head_size,
                                                 dtype,
                                                 kv_cache_dtype,
                                                 block_size,
                                                 is_attention_free,
                                                 use_mla=use_mla,
                                                 has_sink=self.has_sink)
        else:
            self.attn_backend = attn_backend

        impl_cls = self.attn_backend.get_impl_cls()
        self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                             alibi_slopes, sliding_window, kv_cache_dtype,
                             logits_soft_cap, attn_type,
                             kv_sharing_target_layer_name, **extra_impl_args)
        self.backend = backend_name_to_enum(self.attn_backend.get_name())
        self.dtype = dtype

        # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
        # torch.compile works by registering the attention as one giant
        # opaque custom op. For other platforms, we directly call them
        # and let torch.compile handle them.
        self.use_direct_call = not current_platform.opaque_attention_op()

        self.use_output = self.attn_backend.accept_output_buffer

        # Fetch the current config once; it is used for both layer
        # registration and the placeholder KV cache below.
        vllm_config = get_current_vllm_config()
        compilation_config = vllm_config.compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self
        self.layer_name = prefix
        self.attn_type = attn_type

        if kv_sharing_target_layer_name is not None:
            validate_kv_sharing_target(
                prefix,
                kv_sharing_target_layer_name,
                compilation_config.static_forward_context,
            )
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        # use a placeholder kv cache tensor during init, which will be replaced
        # by bind_kv_cache
        # this variable will not be accessed if use_direct_call is True
        self.kv_cache = [
            torch.tensor([]) for _ in range(
                vllm_config.parallel_config.pipeline_parallel_size)
        ]

        self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
        self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
        self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        # For some alternate attention backends like MLA the attention output
        # shape does not match the query shape, so we optionally let the model
        # definition specify the output tensor shape.
        output_shape: Optional[torch.Size] = None,
    ) -> torch.Tensor:
        """
        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        Attention metadata (`attn_metadata`) is set using a context manager in
        the model runner's `execute_model` method. It is accessed via forward
        context using
        `vllm.forward_context.get_forward_context().attn_metadata`.

        When the backend accepts an output buffer (`self.use_output`), the
        result is written into a zero-initialised buffer of `output_shape`
        (default: `query.shape`). The buffer is returned flattened to
        (-1, output_shape[-1]).
        """
        if self.calculate_kv_scales:
            self._maybe_calc_kv_scales(query, key, value)

        if self.use_output:
            output_shape = (output_shape
                            if output_shape is not None else query.shape)
            output = torch.zeros(output_shape,
                                 dtype=query.dtype,
                                 device=query.device)
            hidden_size = output_shape[-1]
            # We skip reshaping query, key and value tensors for the MLA
            # backend since these tensors have different semantics and are
            # processed differently.
            if not self.use_mla:
                # Reshape the query, key, and value tensors.
                # NOTE(woosuk): We do this outside the custom op to minimize the
                # CPU overheads from the non-CUDA-graph regions.
                query = query.view(-1, self.num_heads, self.head_size)
                output = output.view(-1, self.num_heads, self.head_size)
                if key is not None:
                    key = key.view(-1, self.num_kv_heads, self.head_size)
                if value is not None:
                    value = value.view(-1, self.num_kv_heads, self.head_size)
            if self.use_direct_call:
                attn_metadata, self_kv_cache = self._direct_call_inputs()
                self.impl.forward(self,
                                  query,
                                  key,
                                  value,
                                  self_kv_cache,
                                  attn_metadata,
                                  output=output)
            else:
                torch.ops.vllm.unified_attention_with_output(
                    query, key, value, output, self.layer_name)
            return output.view(-1, hidden_size)

        if self.use_direct_call:
            attn_metadata, self_kv_cache = self._direct_call_inputs()
            return self.impl.forward(self, query, key, value, self_kv_cache,
                                     attn_metadata)
        return torch.ops.vllm.unified_attention(query, key, value,
                                                self.layer_name)

    def _direct_call_inputs(self):
        """Return (attn_metadata, kv_cache) for this layer from the forward
        context, unwrapping per-layer metadata dicts."""
        forward_context: ForwardContext = get_forward_context()
        attn_metadata = forward_context.attn_metadata
        if isinstance(attn_metadata, dict):
            attn_metadata = attn_metadata[self.layer_name]
        return attn_metadata, self.kv_cache[forward_context.virtual_engine]

    def _maybe_calc_kv_scales(self, query, key, value) -> None:
        """Run `calc_kv_scales` if the current metadata requests it.

        The forward context's metadata may be a per-layer dict (unwrapped
        here, as elsewhere in this class) or None. Metadata without an
        `enable_kv_scales_calculation` attribute is treated as "not
        requested" instead of raising.
        """
        attn_metadata = get_forward_context().attn_metadata
        if isinstance(attn_metadata, dict):
            attn_metadata = attn_metadata[self.layer_name]
        if attn_metadata is None or not getattr(
                attn_metadata, "enable_kv_scales_calculation", False):
            return
        self.calc_kv_scales(query, key, value)

    def calc_kv_scales(self, query, key, value):
        """Set q/k/v scales to abs-max(tensor) / configured range.

        The device scale tensors are updated in place and mirrored to host
        floats. Calculation is then disabled, so it runs only once.
        """
        self._q_scale.copy_(torch.abs(query).max() / self.q_range)
        self._k_scale.copy_(torch.abs(key).max() / self.k_range)
        self._v_scale.copy_(torch.abs(value).max() / self.v_range)
        self._q_scale_float = self._q_scale.item()
        self._k_scale_float = self._k_scale.item()
        self._v_scale_float = self._v_scale.item()
        # We only calculate the scales once
        self.calculate_kv_scales = False

    def extra_repr(self) -> str:
        """Summarise head geometry, scale and impl class for `repr()`."""
        s = f"head_size={self.impl.head_size}"  # type: ignore
        s += f", num_heads={self.impl.num_heads}"  # type: ignore
        s += f", num_kv_heads={self.impl.num_kv_heads}"  # type: ignore
        s += f", scale={self.impl.scale}"  # type: ignore
        s += f", backend={self.impl.__class__.__name__}"
        return s

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Delegate post-load processing to the impl, if it supports it."""
        if hasattr(self.impl, "process_weights_after_loading"):
            self.impl.process_weights_after_loading(act_dtype)

        # FlashInfer requires attention sinks to be float32
        if (self.backend == _Backend.FLASHINFER_VLLM_V1
                and hasattr(self.impl, 'sinks')):
            from vllm.v1.attention.backends.flashinfer import FlashInferImpl
            assert isinstance(self.impl, FlashInferImpl)
            if (self.impl.sinks is not None
                    and self.impl.sinks.dtype != torch.float32):
                self.impl.sinks = self.impl.sinks.to(torch.float32)

    def get_attn_backend(self) -> type[AttentionBackend]:
        """Return the backend class selected or supplied at construction."""
        return self.attn_backend

330

331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
class MultiHeadAttention(nn.Module):
    """Multi-headed attention without any cache, used for ViT."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
    ):
        """
        Args:
            num_heads: Number of query heads.
            head_size: Size of each attention head.
            scale: Attention scale passed to every backend kernel.
            num_kv_heads: Number of key/value heads. Defaults to `num_heads`
                and must evenly divide it. K/V are repeated to match the
                query heads in `forward`.
        """
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads

        assert self.num_heads % self.num_kv_heads == 0, \
            f"num_heads ({self.num_heads}) is not " \
            f"divisible by num_kv_heads ({self.num_kv_heads})"
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        dtype = torch.get_default_dtype()
        attn_backend = get_attn_backend(head_size,
                                        dtype,
                                        kv_cache_dtype=None,
                                        block_size=16,
                                        is_attention_free=False)
        backend = backend_name_to_enum(attn_backend.get_name())
        if current_platform.is_rocm():
            # currently, only torch_sdpa is supported on rocm
            self.attn_backend = _Backend.TORCH_SDPA
        else:
            # Keep the selected backend only if forward() implements it;
            # otherwise defer to the platform's ViT backend choice.
            self.attn_backend = backend if backend in {
                _Backend.TORCH_SDPA,
                _Backend.TORCH_SDPA_VLLM_V1,
                _Backend.XFORMERS,
                _Backend.PALLAS_VLLM_V1,
                _Backend.ROCM_AITER_FA,
            } else current_platform.get_vit_attn_backend()

        if (self.attn_backend == _Backend.XFORMERS
                and not check_xformers_availability()):
            self.attn_backend = _Backend.TORCH_SDPA

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """Input shape: batch_size x seq_len x hidden_size

        Returns a tensor of shape (batch_size, q_len, num_heads * head_size).

        Raises:
            NotImplementedError: If the selected backend has no ViT path.
        """
        # TODO(Isotr0py): Use existing backend implementations and support FA3
        bsz, q_len, _ = query.size()
        kv_len = key.size(1)

        query = query.view(bsz, q_len, self.num_heads, self.head_size)
        key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
        value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)

        if (num_repeat := self.num_queries_per_kv) > 1:
            # Handle MQA and GQA
            key = torch.repeat_interleave(key, num_repeat, dim=2)
            value = torch.repeat_interleave(value, num_repeat, dim=2)

        if self.attn_backend == _Backend.XFORMERS:
            from xformers import ops as xops

            out = xops.memory_efficient_attention_forward(query,
                                                          key,
                                                          value,
                                                          scale=self.scale)
        elif (self.attn_backend == _Backend.TORCH_SDPA
              or self.attn_backend == _Backend.TORCH_SDPA_VLLM_V1):
            query, key, value = (x.transpose(1, 2)
                                 for x in (query, key, value))
            out = F.scaled_dot_product_attention(query,
                                                 key,
                                                 value,
                                                 scale=self.scale)
            out = out.transpose(1, 2)
        elif self.attn_backend == _Backend.PALLAS_VLLM_V1:
            query, key, value = (x.transpose(1, 2)
                                 for x in (query, key, value))
            from torch_xla.experimental.custom_kernel import flash_attention
            out = flash_attention(query, key, value, sm_scale=self.scale)
            out = out.transpose(1, 2)
        elif self.attn_backend == _Backend.ROCM_AITER_FA:
            from aiter import flash_attn_varlen_func

            # The varlen kernel takes packed (total_tokens, heads, head_dim)
            # tensors plus cumulative per-sequence offsets. Every sequence in
            # this batch has the same length, so the offsets are simply
            # 0, len, 2*len, ..., bsz*len.
            cu_seqlens_q = torch.arange(
                bsz + 1, dtype=torch.int32, device=query.device) * q_len
            cu_seqlens_k = torch.arange(
                bsz + 1, dtype=torch.int32, device=query.device) * kv_len
            out = flash_attn_varlen_func(query.flatten(0, 1),
                                         key.flatten(0, 1),
                                         value.flatten(0, 1),
                                         cu_seqlens_q=cu_seqlens_q,
                                         cu_seqlens_k=cu_seqlens_k,
                                         max_seqlen_q=q_len,
                                         max_seqlen_k=kv_len,
                                         softmax_scale=self.scale)
        else:
            # ViT attention hasn't supported this backend yet
            raise NotImplementedError(
                f"ViT attention hasn't supported {self.attn_backend} "
                f"backend yet.")

        return out.reshape(bsz, q_len, -1)
432
433


434
435
436
437
438
439
440
441
442
443
def wait_for_kv_layer_from_connector(layer_name: str):
    """Block until the V1 KV connector reports `layer_name` as loaded.

    Does nothing if no V1 KV-transfer group is active, or if the current
    forward context carries no attention metadata.
    """
    if not (has_kv_transfer_group() and is_v1_kv_transfer_group()):
        return

    kv_connector = get_kv_transfer_group()

    metadata = get_forward_context().attn_metadata
    if metadata is None:
        return
    # Metadata is expected to be a per-layer dict on this path.
    assert isinstance(metadata, dict)

    kv_connector.wait_for_layer_load(layer_name)


def maybe_save_kv_layer_to_connector(
    layer_name: str,
    kv_cache_layer: List[torch.Tensor],
):
    """Hand this layer's KV cache to the V1 KV connector for saving.

    Does nothing if no V1 KV-transfer group is active, or if the current
    forward context carries no attention metadata. Otherwise the metadata
    must be a dict, and the entry for `layer_name` is passed along.
    """
    if not (has_kv_transfer_group() and is_v1_kv_transfer_group()):
        return

    kv_connector = get_kv_transfer_group()

    metadata = get_forward_context().attn_metadata
    if metadata is None:
        return

    assert isinstance(metadata, dict)
    layer_metadata = metadata[layer_name]
    kv_connector.save_kv_layer(layer_name, kv_cache_layer, layer_metadata)
464
465


466
467
468
469
470
471
def unified_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Body of the `vllm::unified_attention` custom op.

    Looks up the layer registered under `layer_name` in the forward context
    and runs its impl against the current virtual engine's KV cache. The
    result is returned. The call is bracketed by the KV-connector load-wait
    and save hooks.
    """
    wait_for_kv_layer_from_connector(layer_name)

    ctx: ForwardContext = get_forward_context()
    metadata = ctx.attn_metadata
    if isinstance(metadata, dict):
        metadata = metadata[layer_name]
    layer = ctx.no_compile_layers[layer_name]
    layer_kv_cache = layer.kv_cache[ctx.virtual_engine]

    result = layer.impl.forward(layer, query, key, value, layer_kv_cache,
                                metadata)

    maybe_save_kv_layer_to_connector(layer_name, layer_kv_cache)
    return result
485
486
487
488
489
490
491
492
493
494
495
496
497
498


def unified_attention_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Fake implementation registered for `unified_attention`.

    Produces an uninitialised contiguous tensor with the same shape, dtype
    and device as `query`. `key`, `value` and `layer_name` are ignored.
    """
    placeholder = torch.empty_like(query)
    return placeholder.contiguous()


# Register `unified_attention` as the `unified_attention` custom op (invoked
# as torch.ops.vllm.unified_attention by Attention.forward). Its fake
# implementation is `unified_attention_fake`. No arguments are declared as
# mutated.
direct_register_custom_op(
    op_name="unified_attention",
    op_func=unified_attention,
    fake_impl=unified_attention_fake,
    mutates_args=[],
    dispatch_key=current_platform.dispatch_key,
)
503
504
505
506
507
508
509
510


def unified_attention_with_output(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: Optional[torch.Tensor] = None,
    output_block_scale: Optional[torch.Tensor] = None,
) -> None:
    """Body of the `vllm::unified_attention_with_output` custom op.

    Like `unified_attention`, but the layer's impl writes into the provided
    `output` tensor instead of returning a result. `output_scale` and
    `output_block_scale` are passed straight through to the impl.
    """
    wait_for_kv_layer_from_connector(layer_name)

    ctx: ForwardContext = get_forward_context()
    metadata = ctx.attn_metadata
    if isinstance(metadata, dict):
        metadata = metadata[layer_name]
    layer = ctx.no_compile_layers[layer_name]
    layer_kv_cache = layer.kv_cache[ctx.virtual_engine]

    layer.impl.forward(
        layer,
        query,
        key,
        value,
        layer_kv_cache,
        metadata,
        output=output,
        output_scale=output_scale,
        output_block_scale=output_block_scale,
    )

    maybe_save_kv_layer_to_connector(layer_name, layer_kv_cache)

533
534
535
536
537
538
539

def unified_attention_with_output_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: Optional[torch.Tensor] = None,
    output_block_scale: Optional[torch.Tensor] = None,
) -> None:
    """Fake implementation registered for `unified_attention_with_output`.

    The real op only writes into caller-provided buffers, so there is nothing
    to allocate or return here.
    """
    return None


# Register `unified_attention_with_output` as the
# `unified_attention_with_output` custom op (invoked as
# torch.ops.vllm.unified_attention_with_output by Attention.forward).
# `output` and `output_block_scale` are declared as mutated in place.
direct_register_custom_op(
    op_name="unified_attention_with_output",
    op_func=unified_attention_with_output,
    fake_impl=unified_attention_with_output_fake,
    mutates_args=["output", "output_block_scale"],
    dispatch_key=current_platform.dispatch_key,
)