# layer.py (21.1 KB)
# Newer / Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Attention layer."""
4
from typing import List, Optional
5
6
7

import torch
import torch.nn as nn
8
import torch.nn.functional as F
9

10
import vllm.envs as envs
11
from vllm.attention import AttentionType
12
from vllm.attention.backends.abstract import AttentionBackend
13
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
14
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
15
from vllm.config import CacheConfig, get_current_vllm_config
16
17
18
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                          has_kv_transfer_group,
                                          is_v1_kv_transfer_group)
19
from vllm.forward_context import ForwardContext, get_forward_context
20
from vllm.logger import init_logger
21
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
22
23
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
24
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
25
from vllm.platforms import _Backend, current_platform
26
from vllm.utils import direct_register_custom_op
27

28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
logger = init_logger(__name__)

# Cache for check_xformers_availability(): None until the first call decides,
# afterwards True/False.
USE_XFORMERS_OPS: Optional[bool] = None


def check_xformers_availability():
    """Return whether the xformers ops can be used on this platform.

    The answer is computed on the first call and cached in the module-level
    ``USE_XFORMERS_OPS`` flag; later calls return the cached value early, so
    the fallback warning below is logged at most once.

    Returns:
        True if ``xformers.ops`` can be located and the device is not
        excluded, False otherwise.
    """
    global USE_XFORMERS_OPS
    if USE_XFORMERS_OPS is not None:
        return USE_XFORMERS_OPS

    if current_platform.is_cuda() and current_platform.has_device_capability(
            100):
        # Xformers FA is not compatible with B200
        USE_XFORMERS_OPS = False
    else:
        from importlib.util import find_spec

        try:
            # find_spec() returns None (it does not raise) when the submodule
            # cannot be found, so the result itself must be checked. It
            # raises ModuleNotFoundError (an ImportError) when the parent
            # package is missing, and ValueError when the module is already
            # imported with a missing __spec__.
            USE_XFORMERS_OPS = find_spec("xformers.ops") is not None
        except (ImportError, ValueError):
            USE_XFORMERS_OPS = False

    # the warning only needs to be shown once
    if not USE_XFORMERS_OPS:
        logger.warning("Xformers is not available, falling back.")

    return USE_XFORMERS_OPS

56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75

class Attention(nn.Module):
    """Attention layer.

    This class takes query, key, and value tensors as input. The input tensors
    can either contain prompt tokens or generation tokens.
    The class does the following:

    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        logits_soft_cap: Optional[float] = None,
        per_layer_sliding_window: Optional[int] = None,
        use_mla: bool = False,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        attn_backend: Optional[type[AttentionBackend]] = None,
        **extra_impl_args,
    ) -> None:
        """
        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each attention head.
            scale: Attention scale forwarded to the backend implementation.
            num_kv_heads: Number of key/value heads. Defaults to
                ``num_heads`` and must divide ``num_heads`` evenly.
            alibi_slopes: Optional ALiBi slopes forwarded to the backend.
            cache_config: Supplies the KV-cache dtype, block size,
                attention-free flag, kv-scale calculation flag and the
                model-level sliding window. When None, "auto", 16, False,
                False and no sliding window are used.
            quant_config: Optional quantization config. A quant method that
                is not ``UnquantizedLinearMethod`` must be a
                ``BaseKVCacheMethod``; it creates its weights on this layer.
            logits_soft_cap: Optional soft cap forwarded to the backend.
            per_layer_sliding_window: When not None, overrides the
                model-level sliding window from ``cache_config``.
            use_mla: Whether the layer uses MLA. MLA query/key/value tensors
                are not reshaped per head in ``forward``.
            prefix: Layer name. It is registered in the compilation config's
                ``static_forward_context`` and must be unique.
            attn_type: Attention type forwarded to the backend.
            kv_sharing_target_layer_name: Name of the layer whose KV cache
                this layer shares; validated against already-registered
                layers.
            attn_backend: Backend class to use. Selected with
                ``get_attn_backend`` when None.
            **extra_impl_args: Extra keyword arguments forwarded to the
                backend implementation. A non-None ``sinks`` entry marks the
                layer as having attention sinks.

        Raises:
            ValueError: If ``prefix`` is already registered, or if an
                fp8_e5m2 KV cache is combined with a KV-cache quant method.
        """
        super().__init__()
        if per_layer_sliding_window is not None:
            # per-layer sliding window
            sliding_window = per_layer_sliding_window
        elif cache_config is not None:
            # model-level sliding window
            sliding_window = cache_config.sliding_window
        else:
            sliding_window = None

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            is_attention_free = cache_config.is_attention_free
            calculate_kv_scales = cache_config.calculate_kv_scales
        else:
            kv_cache_dtype = "auto"
            block_size = 16
            is_attention_free = False
            calculate_kv_scales = False
        if num_kv_heads is None:
            num_kv_heads = num_heads
        assert num_heads % num_kv_heads == 0, \
            f"num_heads ({num_heads}) is not " \
            f"divisible by num_kv_heads ({num_kv_heads})"

        # The default k/v_scale is set to 1.0. This is ignored
        # when kv-cache is not fp8, and should be used with
        # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
        # expect the pre-quantized k/v_scale to be loaded along
        # with the model weights.
        self.kv_cache_dtype = kv_cache_dtype
        self.calculate_kv_scales = calculate_kv_scales
        self._k_scale = torch.tensor(1.0, dtype=torch.float32)
        self._v_scale = torch.tensor(1.0, dtype=torch.float32)
        # FlashAttn doesn't support quantizing the kv-cache only
        # but requires q to be quantized as well.
        self._q_scale = torch.tensor(1.0, dtype=torch.float32)
        self._prob_scale = torch.tensor(1.0, dtype=torch.float32)

        # We also keep the float32 versions of k/v_scale for attention
        # backends that don't support tensors (Flashinfer)
        self._k_scale_float = 1.0
        self._v_scale_float = 1.0

        self.use_mla = use_mla
        self.num_heads = num_heads
        self.head_size = head_size
        self.num_kv_heads = num_kv_heads
        self.sliding_window = sliding_window
        self.has_sink = extra_impl_args.get("sinks") is not None

        quant_method = quant_config.get_quant_method(
            self, prefix=prefix) if quant_config else None
        if quant_method is not None and not isinstance(
                quant_method, UnquantizedLinearMethod):
            assert isinstance(quant_method, BaseKVCacheMethod)
            # TODO (mgoin): kv cache dtype should be specified in the FP8
            # checkpoint config and become the "auto" behavior
            if self.kv_cache_dtype == "fp8_e5m2":
                raise ValueError("fp8_e5m2 kv-cache is not supported with "
                                 "fp8 checkpoints.")
            # If quantization is enabled, we make "k_scale" and "v_scale"
            # parameters so that it can be loaded from the model checkpoint.
            # The k/v_scale will then be converted back to native float32
            # values after weight loading.
            self.quant_method = quant_method
            self.quant_method.create_weights(self)

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()
        if attn_backend is None:
            self.attn_backend = get_attn_backend(head_size,
                                                 dtype,
                                                 kv_cache_dtype,
                                                 block_size,
                                                 is_attention_free,
                                                 use_mla=use_mla,
                                                 has_sink=self.has_sink)
        else:
            self.attn_backend = attn_backend

        impl_cls = self.attn_backend.get_impl_cls()
        self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                             alibi_slopes, sliding_window, kv_cache_dtype,
                             logits_soft_cap, attn_type,
                             kv_sharing_target_layer_name, **extra_impl_args)
        self.backend = backend_name_to_enum(self.attn_backend.get_name())
        self.dtype = dtype

        # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
        # torch.compile works by registering the attention as one giant
        # opaque custom op. For other platforms, we directly call them
        # and let torch.compile handle them.
        self.use_direct_call = not current_platform.is_cuda_alike(
        ) and not current_platform.is_cpu()

        self.use_output = self.attn_backend.accept_output_buffer
        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self
        self.layer_name = prefix
        self.attn_type = attn_type

        if kv_sharing_target_layer_name is not None:
            validate_kv_sharing_target(
                prefix,
                kv_sharing_target_layer_name,
                compilation_config.static_forward_context,
            )
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        # use a placeholder kv cache tensor during init, which will be replaced
        # by bind_kv_cache
        # this variable will not be accessed if use_direct_call is True
        self.kv_cache = [
            torch.tensor([]) for _ in range(get_current_vllm_config(
            ).parallel_config.pipeline_parallel_size)
        ]

        self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
        self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
        self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        # For some alternate attention backends like MLA the attention output
        # shape does not match the query shape, so we optionally let the model
        # definition specify the output tensor shape.
        output_shape: Optional[torch.Size] = None,
    ) -> torch.Tensor:
        """
        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        Attention metadata (`attn_metadata`) is set using a context manager in
        the model runner's `execute_model` method. It is accessed via forward
        context using
        `vllm.forward_context.get_forward_context().attn_metadata`.

        Returns:
            When the backend accepts an output buffer, the zero-initialized
            buffer of shape ``output_shape`` (default ``query.shape``)
            filled by the backend and viewed as ``(-1, output_shape[-1])``.
            Otherwise, whatever the backend implementation (or the
            ``unified_attention`` custom op) returns.
        """
        if self.calculate_kv_scales:
            attn_metadata = self._get_layer_attn_metadata(
                get_forward_context())
            # The metadata may be a per-layer dict (resolved above) or None,
            # so only calculate scales when the flag is present and set.
            if getattr(attn_metadata, "enable_kv_scales_calculation", False):
                self.calc_kv_scales(query, key, value)
        if self.use_output:
            output_shape = (output_shape
                            if output_shape is not None else query.shape)
            output = torch.zeros(output_shape,
                                 dtype=query.dtype,
                                 device=query.device)
            hidden_size = output_shape[-1]
            # We skip reshaping query, key and value tensors for the MLA
            # backend since these tensors have different semantics and are
            # processed differently.
            if not self.use_mla:
                # Reshape the query, key, and value tensors.
                # NOTE(woosuk): We do this outside the custom op to minimize the
                # CPU overheads from the non-CUDA-graph regions.
                query = query.view(-1, self.num_heads, self.head_size)
                output = output.view(-1, self.num_heads, self.head_size)
                if key is not None:
                    key = key.view(-1, self.num_kv_heads, self.head_size)
                if value is not None:
                    value = value.view(-1, self.num_kv_heads, self.head_size)
            if self.use_direct_call:
                forward_context: ForwardContext = get_forward_context()
                attn_metadata = self._get_layer_attn_metadata(forward_context)
                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                self.impl.forward(self,
                                  query,
                                  key,
                                  value,
                                  self_kv_cache,
                                  attn_metadata,
                                  output=output)
            else:
                torch.ops.vllm.unified_attention_with_output(
                    query, key, value, output, self.layer_name)
            return output.view(-1, hidden_size)
        else:
            if self.use_direct_call:
                forward_context = get_forward_context()
                attn_metadata = self._get_layer_attn_metadata(forward_context)
                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                return self.impl.forward(self, query, key, value,
                                         self_kv_cache, attn_metadata)
            else:
                return torch.ops.vllm.unified_attention(
                    query, key, value, self.layer_name)

    def _get_layer_attn_metadata(self, forward_context: ForwardContext):
        """Return this layer's attention metadata from ``forward_context``.

        ``forward_context.attn_metadata`` is either a single object (returned
        as-is, including None) or a dict keyed by layer name, in which case
        the entry for ``self.layer_name`` is returned.
        """
        attn_metadata = forward_context.attn_metadata
        if isinstance(attn_metadata, dict):
            attn_metadata = attn_metadata[self.layer_name]
        return attn_metadata

    def calc_kv_scales(self, query, key, value):
        """Set q/k/v scales from the abs-max of the given tensors.

        Each scale is ``max(|x|)`` divided by the matching range constant
        (``q_range``/``k_range``/``v_range``). The float mirrors of the k/v
        scales are refreshed, and calculation is then disabled for later
        calls.
        """
        self._q_scale.copy_(torch.abs(query).max() / self.q_range)
        self._k_scale.copy_(torch.abs(key).max() / self.k_range)
        self._v_scale.copy_(torch.abs(value).max() / self.v_range)
        self._k_scale_float = self._k_scale.item()
        self._v_scale_float = self._v_scale.item()
        # We only calculate the scales once
        self.calculate_kv_scales = False

    def extra_repr(self) -> str:
        s = f"head_size={self.impl.head_size}"  # type: ignore
        s += f", num_heads={self.impl.num_heads}"  # type: ignore
        s += f", num_kv_heads={self.impl.num_kv_heads}"  # type: ignore
        s += f", scale={self.impl.scale}"  # type: ignore
        s += f", backend={self.impl.__class__.__name__}"
        return s

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Run backend post-load processing and normalize FlashInfer sinks.

        Args:
            act_dtype: Activation dtype forwarded to the implementation's own
                ``process_weights_after_loading`` when it defines one.
        """
        if hasattr(self.impl, "process_weights_after_loading"):
            self.impl.process_weights_after_loading(act_dtype)

        # FlashInfer requires attention sinks to be float32
        if (self.backend == _Backend.FLASHINFER_VLLM_V1
                and hasattr(self.impl, 'sinks')):
            from vllm.v1.attention.backends.flashinfer import FlashInferImpl
            assert isinstance(self.impl, FlashInferImpl)
            if (self.impl.sinks is not None
                    and self.impl.sinks.dtype != torch.float32):
                self.impl.sinks = self.impl.sinks.to(torch.float32)

    def get_attn_backend(self) -> type[AttentionBackend]:
        """Return the attention backend class used by this layer."""
        return self.attn_backend

323

324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
class MultiHeadAttention(nn.Module):
    """Multi-headed attention without any cache, used for ViT.

    When ``num_kv_heads`` is smaller than ``num_heads`` (MQA/GQA), the key
    and value heads are repeated to match the query heads before the kernel
    runs.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        if num_kv_heads is None:
            num_kv_heads = num_heads
        self.num_kv_heads = num_kv_heads

        assert self.num_heads % self.num_kv_heads == 0, (
            f"num_heads ({self.num_heads}) is not "
            f"divisible by num_kv_heads ({self.num_kv_heads})")
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        # Ask the generic selector first (it is always consulted), then map
        # its answer onto the few kernels this cache-free layer implements.
        model_dtype = torch.get_default_dtype()
        selected_cls = get_attn_backend(head_size,
                                        model_dtype,
                                        kv_cache_dtype=None,
                                        block_size=16,
                                        is_attention_free=False)
        selected = backend_name_to_enum(selected_cls.get_name())

        if current_platform.is_rocm():
            # currently, only torch_sdpa is supported on rocm
            chosen = _Backend.TORCH_SDPA
        elif selected in (_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1,
                          _Backend.FLEX_ATTENTION):
            # These selections are served by xformers in this layer.
            chosen = _Backend.XFORMERS
        elif selected in (_Backend.TORCH_SDPA, _Backend.XFORMERS,
                          _Backend.PALLAS_VLLM_V1):
            chosen = selected
        else:
            chosen = _Backend.TORCH_SDPA

        # Drop back to SDPA when xformers turns out to be unusable.
        if chosen == _Backend.XFORMERS and not check_xformers_availability():
            chosen = _Backend.TORCH_SDPA
        self.attn_backend = chosen

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """Input shape: batch_size x seq_len x hidden_size"""
        # TODO(Isotr0py): Use existing backend implementations and support FA3
        batch, q_len, _ = query.size()
        kv_len = key.size(1)

        query = query.view(batch, q_len, self.num_heads, self.head_size)
        key = key.view(batch, kv_len, self.num_kv_heads, self.head_size)
        value = value.view(batch, kv_len, self.num_kv_heads, self.head_size)

        repeats = self.num_queries_per_kv
        if repeats > 1:
            # Handle MQA and GQA
            key = key.repeat_interleave(repeats, dim=2)
            value = value.repeat_interleave(repeats, dim=2)

        backend = self.attn_backend
        if backend == _Backend.XFORMERS:
            from xformers import ops as xops

            out = xops.memory_efficient_attention_forward(query,
                                                          key,
                                                          value,
                                                          scale=self.scale)
        elif backend in (_Backend.TORCH_SDPA, _Backend.PALLAS_VLLM_V1):
            # Both paths run on (batch, heads, seq, head_size) tensors.
            q_t, k_t, v_t = (t.transpose(1, 2) for t in (query, key, value))
            if backend == _Backend.TORCH_SDPA:
                out = F.scaled_dot_product_attention(q_t,
                                                     k_t,
                                                     v_t,
                                                     scale=self.scale)
            else:
                from torch_xla.experimental.custom_kernel import (
                    flash_attention)
                out = flash_attention(q_t, k_t, v_t, sm_scale=self.scale)
            out = out.transpose(1, 2)

        return out.reshape(batch, q_len, -1)
411
412


413
414
415
416
417
418
419
420
421
422
def wait_for_kv_layer_from_connector(layer_name: str):
    """Ask the v1 KV connector to wait for ``layer_name``'s KV load.

    Does nothing unless a v1 KV-transfer group is active and the current
    forward context carries attention metadata. That metadata is asserted to
    be a dict.
    """
    if not (has_kv_transfer_group() and is_v1_kv_transfer_group()):
        return

    kv_connector = get_kv_transfer_group()

    ctx: ForwardContext = get_forward_context()
    metadata = ctx.attn_metadata
    if metadata is None:
        return
    assert isinstance(metadata, dict)
    kv_connector.wait_for_layer_load(layer_name)


def maybe_save_kv_layer_to_connector(
    layer_name: str,
    kv_cache_layer: torch.Tensor,
):
    """Hand ``layer_name``'s KV cache and metadata to the v1 KV connector.

    Args:
        layer_name: Name of the attention layer; also the key used to look
            up its entry in the per-layer attention-metadata dict.
        kv_cache_layer: The layer's KV cache for the current virtual engine.
            Callers pass one element of ``Attention.kv_cache`` (a tensor),
            not the whole list.

    Does nothing unless a v1 KV-transfer group is active and the current
    forward context carries attention metadata. That metadata is asserted to
    be a dict keyed by layer name.
    """
    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
        return

    connector = get_kv_transfer_group()

    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if attn_metadata is None:
        return
    assert isinstance(attn_metadata, dict)
    connector.save_kv_layer(layer_name, kv_cache_layer,
                            attn_metadata[layer_name])
443
444


445
446
447
448
449
450
def unified_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Run attention for ``layer_name`` and return the backend's result.

    The layer is looked up in the forward context's ``no_compile_layers``.
    If the context's attention metadata is a dict, this layer's entry is
    used. KV-connector hooks run before and after the backend call.
    """
    wait_for_kv_layer_from_connector(layer_name)

    ctx: ForwardContext = get_forward_context()
    metadata = ctx.attn_metadata
    if isinstance(metadata, dict):
        metadata = metadata[layer_name]
    layer = ctx.no_compile_layers[layer_name]
    layer_kv_cache = layer.kv_cache[ctx.virtual_engine]

    result = layer.impl.forward(layer, query, key, value, layer_kv_cache,
                                metadata)

    maybe_save_kv_layer_to_connector(layer_name, layer_kv_cache)
    return result
464
465
466
467
468
469
470
471
472
473
474
475
476
477


def unified_attention_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Fake implementation of ``unified_attention``.

    Returns an uninitialized, contiguous tensor with the shape, dtype and
    device of ``query``. ``key``, ``value`` and ``layer_name`` are unused.
    """
    return torch.empty_like(query, memory_format=torch.contiguous_format)


# Register unified_attention as the custom op torch.ops.vllm.unified_attention,
# declaring that it mutates none of its arguments.
direct_register_custom_op(
    op_name="unified_attention",
    op_func=unified_attention,
    fake_impl=unified_attention_fake,
    mutates_args=[],
    dispatch_key=current_platform.dispatch_key,
)
482
483
484
485
486
487
488
489


def unified_attention_with_output(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: Optional[torch.Tensor] = None,
) -> None:
    """Run attention for ``layer_name``, writing the result into ``output``.

    Same lookup as ``unified_attention``. The backend receives the
    preallocated ``output`` buffer and the optional ``output_scale``, and
    nothing is returned. KV-connector hooks run before and after the backend
    call.
    """
    wait_for_kv_layer_from_connector(layer_name)

    ctx: ForwardContext = get_forward_context()
    metadata = ctx.attn_metadata
    if isinstance(metadata, dict):
        metadata = metadata[layer_name]
    layer = ctx.no_compile_layers[layer_name]
    layer_kv_cache = layer.kv_cache[ctx.virtual_engine]

    layer.impl.forward(layer,
                       query,
                       key,
                       value,
                       layer_kv_cache,
                       metadata,
                       output=output,
                       output_scale=output_scale)

    maybe_save_kv_layer_to_connector(layer_name, layer_kv_cache)

510
511
512
513
514
515
516

def unified_attention_with_output_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: Optional[torch.Tensor] = None,
) -> None:
    """Fake implementation of ``unified_attention_with_output``.

    The real op writes into ``output`` in place and returns nothing, so the
    fake has nothing to produce.
    """
    return None


# Register unified_attention_with_output as the custom op
# torch.ops.vllm.unified_attention_with_output, declaring that it writes its
# result into the ``output`` argument in place.
direct_register_custom_op(
    op_name="unified_attention_with_output",
    op_func=unified_attention_with_output,
    fake_impl=unified_attention_with_output_fake,
    mutates_args=["output"],
    dispatch_key=current_platform.dispatch_key,
)