# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer."""
from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from vllm.two_batch_overlap.v1.two_batch_overlap_v1 import tbo_maybe_save_kv_layer_to_connector
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorRole
import vllm.envs as envs
from vllm.attention import AttentionType
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
from vllm.config import CacheConfig, get_current_vllm_config
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                          get_lmcache_connector,
                                          has_kv_transfer_group,
                                          is_v1_kv_transfer_group)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.platforms import _Backend, current_platform
from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.utils import validate_kv_sharing_target
from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector import LMCacheConnectorV1


class Attention(nn.Module):
    """Attention layer.

    This class takes query, key, and value tensors as input. The input tensors
    can either contain prompt tokens or generation tokens.
    The class does the following:

    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        per_layer_sliding_window: Optional[int] = None,
        use_mla: bool = False,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        **extra_impl_args,
    ) -> None:
        """
        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each attention head.
            scale: Attention scale forwarded to the backend implementation.
            num_kv_heads: Number of key/value heads. Defaults to
                ``num_heads``; must evenly divide ``num_heads``.
            alibi_slopes: Optional ALiBi slopes forwarded to the backend.
            cache_config: Source of the KV-cache dtype, block size,
                model-level sliding window, attention-free flag and the
                kv-scale calculation flag. When ``None``, dtype is "auto" and
                block size is 64 if ``VLLM_USE_FLASH_ATTN_PA`` or
                ``VLLM_USE_FLASH_MLA`` is set, else 16.
            quant_config: Optional quantization config. A quant method other
                than ``UnquantizedLinearMethod`` must be a
                ``BaseKVCacheMethod``; it creates the loadable scale weights.
            blocksparse_params: Block-sparse parameters forwarded to the
                backend; whether they are given also affects backend
                selection.
            logits_soft_cap: Forwarded to the backend implementation.
            per_layer_sliding_window: Overrides the model-level sliding
                window taken from ``cache_config``.
            use_mla: Request an MLA backend; also disables the q/k/v
                reshaping in ``forward``.
            prefix: Unique layer name, used as the key in the compilation
                config's ``static_forward_context``.
            attn_type: Attention type forwarded to the backend.
            kv_sharing_target_layer_name: Name of another layer whose KV
                cache this layer shares (V1 engine only).
            **extra_impl_args: Extra keyword arguments forwarded to the
                backend implementation class.

        Raises:
            ValueError: If an fp8_e5m2 KV cache is combined with a quantized
                checkpoint, or if ``prefix`` is already registered.
            NotImplementedError: If KV sharing is requested on the V0 engine.
        """
        super().__init__()
        if per_layer_sliding_window is not None:
            # per-layer sliding window
            sliding_window = per_layer_sliding_window
        elif cache_config is not None:
            # model-level sliding window
            sliding_window = cache_config.sliding_window
        else:
            sliding_window = None

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            is_attention_free = cache_config.is_attention_free
            calculate_kv_scales = cache_config.calculate_kv_scales
        else:
            kv_cache_dtype = "auto"
            # 64-token blocks when either env flag is set, otherwise 16
            # (presumably a kernel requirement of those paths — confirm).
            block_size = 64 if envs.VLLM_USE_FLASH_ATTN_PA or envs.VLLM_USE_FLASH_MLA else 16
            is_attention_free = False
            calculate_kv_scales = False
        if num_kv_heads is None:
            num_kv_heads = num_heads
        assert num_heads % num_kv_heads == 0, \
            f"num_heads ({num_heads}) is not " \
            f"divisible by num_kv_heads ({num_kv_heads})"

        # The default k/v_scale is set to 1.0. This is ignored
        # when kv-cache is not fp8, and should be used with
        # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
        # expect the pre-quantized k/v_scale to be loaded along
        # with the model weights.
        self.kv_cache_dtype = kv_cache_dtype
        self.calculate_kv_scales = calculate_kv_scales
        self._k_scale = torch.tensor(1.0, dtype=torch.float32)
        self._v_scale = torch.tensor(1.0, dtype=torch.float32)
        # FlashAttn doesn't support quantizing the kv-cache only
        # but requires q to be quantized as well.
        self._q_scale = torch.tensor(1.0, dtype=torch.float32)
        self._prob_scale = torch.tensor(1.0, dtype=torch.float32)

        # We also keep the float32 versions of k/v_scale for attention
        # backends that don't support tensors (Flashinfer)
        self._k_scale_float = 1.0
        self._v_scale_float = 1.0

        self.use_mla = use_mla
        self.num_heads = num_heads
        self.head_size = head_size
        self.num_kv_heads = num_kv_heads
        self.sliding_window = sliding_window

        quant_method = quant_config.get_quant_method(
            self, prefix=prefix) if quant_config else None
        if quant_method is not None and not isinstance(
                quant_method, UnquantizedLinearMethod):
            assert isinstance(quant_method, BaseKVCacheMethod)
            # TODO (mgoin): kv cache dtype should be specified in the FP8
            # checkpoint config and become the "auto" behavior
            if self.kv_cache_dtype == "fp8_e5m2":
                raise ValueError("fp8_e5m2 kv-cache is not supported with "
                                 "fp8 checkpoints.")
            # If quantization is enabled, we make "k_scale" and "v_scale"
            # parameters so that it can be loaded from the model checkpoint.
            # The k/v_scale will then be converted back to native float32
            # values after weight loading.
            self.quant_method = quant_method
            self.quant_method.create_weights(self)

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()
        attn_backend = get_attn_backend(head_size,
                                        dtype,
                                        kv_cache_dtype,
                                        block_size,
                                        is_attention_free,
                                        blocksparse_params is not None,
                                        use_mla=use_mla)
        impl_cls = attn_backend.get_impl_cls()
        self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                             alibi_slopes, sliding_window, kv_cache_dtype,
                             blocksparse_params, logits_soft_cap, attn_type,
                             kv_sharing_target_layer_name, **extra_impl_args)
        self.backend = backend_name_to_enum(attn_backend.get_name())
        self.dtype = dtype

        # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
        # torch.compile works by registering the attention as one giant
        # opaque custom op. For other platforms, we directly call them
        # and let torch.compile handle them.
        self.use_direct_call = not current_platform.is_cuda_alike(
        ) and not current_platform.is_cpu()

        self.use_output = attn_backend.accept_output_buffer
        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self
        self.layer_name = prefix
        self.attn_type = attn_type

        if kv_sharing_target_layer_name is not None:
            if not envs.VLLM_USE_V1:
                raise NotImplementedError(
                    "Cross-layer KV sharing is not supported in V0.")

            validate_kv_sharing_target(
                prefix,
                kv_sharing_target_layer_name,
                compilation_config.static_forward_context,
            )
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        # One placeholder KV-cache tensor per pipeline-parallel virtual engine,
        # replaced later by bind_kv_cache. It is indexed by
        # forward_context.virtual_engine in BOTH the direct-call path of
        # forward() and the custom-op path (unified_attention*).
        self.kv_cache = [
            torch.tensor([]) for _ in range(get_current_vllm_config(
            ).parallel_config.pipeline_parallel_size)
        ]

        # Divisors used by calc_kv_scales to turn abs-max values into scales.
        self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
        self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
        self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        # For some alternate attention backends like MLA the attention output
        # shape does not match the query shape, so we optionally let the model
        # definition specify the output tensor shape.
        output_shape: Optional[torch.Size] = None,
    ) -> torch.Tensor:
        """
        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        Attention metadata (`attn_metadata`) is set using a context manager in
        the model runner's `execute_model` method. It is accessed via forward
        context using
        `vllm.forward_context.get_forward_context().attn_metadata`.

        Returns:
            When the backend accepts an output buffer: a zero-initialised
            buffer of ``output_shape`` (default ``query.shape``) filled by the
            backend and viewed as ``(-1, output_shape[-1])``. Otherwise:
            whatever the backend ``forward`` / custom op returns.
        """
        if self.calculate_kv_scales:
            attn_metadata = get_forward_context().attn_metadata
            # NOTE(review): when attn_metadata is a per-layer dict (handled
            # further below), getattr yields False and scales are never
            # computed on that step — confirm this is intended.
            if (attn_metadata is not None and getattr(
                    attn_metadata, "enable_kv_scales_calculation", False)):
                self.calc_kv_scales(query, key, value)
        if self.use_output:
            output_shape = (output_shape
                            if output_shape is not None else query.shape)
            output = torch.zeros(output_shape,
                                 dtype=query.dtype,
                                 device=query.device)
            hidden_size = output_shape[-1]
            # We skip reshaping query, key and value tensors for the MLA
            # backend since these tensors have different semantics and are
            # processed differently.
            if not self.use_mla:
                # Reshape the query, key, and value tensors.
                # NOTE(woosuk): We do this outside the custom op to minimize the
                # CPU overheads from the non-CUDA-graph regions.
                query = query.view(-1, self.num_heads, self.head_size)
                output = output.view(-1, self.num_heads, self.head_size)
                if key is not None:
                    key = key.view(-1, self.num_kv_heads, self.head_size)
                if value is not None:
                    value = value.view(-1, self.num_kv_heads, self.head_size)
            if self.use_direct_call:
                # NOTE(review): unlike unified_attention_with_output, this path
                # does not call the KV-connector wait/save hooks — confirm.
                forward_context: ForwardContext = get_forward_context()
                attn_metadata = forward_context.attn_metadata
                if isinstance(attn_metadata, dict):
                    attn_metadata = attn_metadata[self.layer_name]
                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                self.impl.forward(self,
                                  query,
                                  key,
                                  value,
                                  self_kv_cache,
                                  attn_metadata,
                                  output=output)
            else:
                torch.ops.vllm.unified_attention_with_output(
                    query, key, value, output, self.layer_name)
            return output.view(-1, hidden_size)
        else:
            if self.use_direct_call:
                forward_context = get_forward_context()
                attn_metadata = forward_context.attn_metadata
                if isinstance(attn_metadata, dict):
                    attn_metadata = attn_metadata[self.layer_name]
                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                return self.impl.forward(self, query, key, value,
                                         self_kv_cache, attn_metadata)
            else:
                return torch.ops.vllm.unified_attention(
                    query, key, value, self.layer_name)

    def calc_kv_scales(self, query, key, value):
        """Set q/k/v scales from the inputs' absolute maxima, once.

        Each scale is ``abs(x).max() / <x>_range``, copied in place into the
        existing scale tensor. The float copies of the k/v scales are
        refreshed too. ``calculate_kv_scales`` is then cleared so the
        computation does not repeat. ``key`` and ``value`` must be tensors
        here: ``torch.abs(None)`` raises.
        """
        self._q_scale.copy_(torch.abs(query).max() / self.q_range)
        self._k_scale.copy_(torch.abs(key).max() / self.k_range)
        self._v_scale.copy_(torch.abs(value).max() / self.v_range)
        self._k_scale_float = self._k_scale.item()
        self._v_scale_float = self._v_scale.item()
        # We only calculate the scales once
        self.calculate_kv_scales = False

    def extra_repr(self) -> str:
        """Summarise head geometry, scale and backend impl for ``repr``."""
        s = f"head_size={self.impl.head_size}"  # type: ignore
        s += f", num_heads={self.impl.num_heads}"  # type: ignore
        s += f", num_kv_heads={self.impl.num_kv_heads}"  # type: ignore
        s += f", scale={self.impl.scale}"  # type: ignore
        s += f", backend={self.impl.__class__.__name__}"
        return s

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Delegate post-load weight processing to the backend impl, if it
        defines such a hook."""
        if hasattr(self.impl, "process_weights_after_loading"):
            self.impl.process_weights_after_loading(act_dtype)
283

284

285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
class MultiHeadAttention(nn.Module):
    """Multi-headed attention without any cache, used for ViT."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
    ):
        """
        Args:
            num_heads: Number of query heads.
            head_size: Size of each attention head.
            scale: Softmax scale passed to the attention kernel.
            num_kv_heads: Number of key/value heads; defaults to
                ``num_heads`` and must evenly divide it.

        The kernel is chosen as follows:
        - On ROCm, TORCH_SDPA is always used.
        - Elsewhere, FLASH_ATTN, FLASH_ATTN_VLLM_V1 and FLEX_ATTENTION are
          mapped to XFORMERS.
        - Any backend outside {TORCH_SDPA, XFORMERS, PALLAS_VLLM_V1} falls
          back to TORCH_SDPA.
        """
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads

        assert self.num_heads % self.num_kv_heads == 0, \
            f"num_heads ({self.num_heads}) is not " \
            f"divisible by num_kv_heads ({self.num_kv_heads})"
        # Number of query heads served by each key/value head (GQA/MQA).
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        dtype = torch.get_default_dtype()
        attn_backend = get_attn_backend(
            head_size,
            dtype,
            kv_cache_dtype=None,
            block_size=64
            if envs.VLLM_USE_FLASH_ATTN_PA or envs.VLLM_USE_FLASH_MLA else 16,
            is_attention_free=False)
        backend = backend_name_to_enum(attn_backend.get_name())
        if current_platform.is_rocm():
            # currently, only torch_sdpa is supported on rocm
            self.attn_backend = _Backend.TORCH_SDPA
        else:
            if backend in (_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1,
                           _Backend.FLEX_ATTENTION):
                backend = _Backend.XFORMERS

            self.attn_backend = backend if backend in {
                _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1
            } else _Backend.TORCH_SDPA

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """Input shape: batch_size x seq_len x hidden_size

        ``query`` is viewed as (bsz, q_len, num_heads, head_size).
        ``key``/``value`` are viewed as (bsz, kv_len, num_kv_heads,
        head_size). When num_kv_heads < num_heads, key/value heads are
        repeated to match. The result is reshaped to (bsz, q_len, -1).
        """
        # TODO(Isotr0py): Use existing backend implementations and support FA3
        bsz, q_len, _ = query.size()
        kv_len = key.size(1)

        query = query.view(bsz, q_len, self.num_heads, self.head_size)
        key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
        value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)

        if (num_repeat := self.num_queries_per_kv) > 1:
            # Handle MQA and GQA
            key = torch.repeat_interleave(key, num_repeat, dim=2)
            value = torch.repeat_interleave(value, num_repeat, dim=2)

        if self.attn_backend == _Backend.XFORMERS:
            from xformers import ops as xops

            out = xops.memory_efficient_attention_forward(query,
                                                          key,
                                                          value,
                                                          scale=self.scale)
        elif self.attn_backend == _Backend.TORCH_SDPA:
            # SDPA expects the head dimension before the sequence dimension.
            query, key, value = (x.transpose(1, 2)
                                 for x in (query, key, value))
            out = F.scaled_dot_product_attention(query,
                                                 key,
                                                 value,
                                                 scale=self.scale)
            out = out.transpose(1, 2)
        elif self.attn_backend == _Backend.PALLAS_VLLM_V1:
            query, key, value = (x.transpose(1, 2)
                                 for x in (query, key, value))
            from torch_xla.experimental.custom_kernel import flash_attention
            out = flash_attention(query, key, value, sm_scale=self.scale)
            out = out.transpose(1, 2)

        return out.reshape(bsz, q_len, -1)
368
369


370
371
372
373
374
375
376
377
378
379
def wait_for_kv_layer_from_connector(layer_name: str):
    """Wait on the KV-transfer connector(s) for ``layer_name``'s KV load.

    The function returns without doing anything in two cases:
    - no V1 KV-transfer group is configured;
    - the current forward context has no attention metadata.

    Otherwise it calls ``wait_for_layer_load`` on the KV-transfer group's
    connector, and then on the LMCache connector when one is available.
    """
    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
        return

    connector = get_kv_transfer_group()

    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if attn_metadata is None:
        return
    # The V1 connector path expects per-layer metadata keyed by layer name.
    assert isinstance(attn_metadata, dict)
    connector.wait_for_layer_load(layer_name)

    lmcache_connector = get_lmcache_connector()
    if lmcache_connector is not None:
        lmcache_connector.wait_for_layer_load(layer_name)
386
387
388
389
390
391
392
393
394
395
396
397
398
399

def maybe_save_kv_layer_to_connector(
    layer_name: str,
    kv_cache_layer: List[torch.Tensor],
):
    """Hand ``layer_name``'s KV cache to the KV-transfer connector(s).

    The function returns without doing anything in two cases:
    - no V1 KV-transfer group is configured;
    - the current forward context has no attention metadata.

    Otherwise it calls ``save_kv_layer`` with the layer's own metadata
    entry, first on the KV-transfer group's connector and then on the
    LMCache connector when one is available.

    NOTE(review): callers in this module pass a single element of
    ``Attention.kv_cache`` (one entry per virtual engine), so the
    ``List[torch.Tensor]`` annotation may be inaccurate — confirm.
    """
    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
        return

    connector = get_kv_transfer_group()

    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if attn_metadata is None:
        return
    assert isinstance(attn_metadata, dict)
    # Look up this layer's metadata once; both connectors receive the same
    # entry.
    layer_attn_metadata = attn_metadata[layer_name]
    connector.save_kv_layer(layer_name, kv_cache_layer, layer_attn_metadata)

    lmcache_connector = get_lmcache_connector()
    if lmcache_connector is not None:
        lmcache_connector.save_kv_layer(layer_name, kv_cache_layer,
                                        layer_attn_metadata)
408

409
410
411
412
413
414
def unified_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Run attention for the layer registered as ``layer_name``.

    This backs the ``unified_attention`` custom op called from
    ``Attention.forward``. The steps are:
    1. Wait for any connector KV load for the layer.
    2. Resolve the layer's attention metadata: a per-layer entry when the
       metadata is a dict.
    3. Call the layer's backend ``forward`` with its KV cache for the
       current virtual engine.
    4. Pass that KV cache to the connector save path. The two-batch-overlap
       variant is used when ``VLLM_ENABLE_TBO`` is set.

    Returns:
        The backend ``forward`` result.
    """
    wait_for_kv_layer_from_connector(layer_name)

    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if isinstance(attn_metadata, dict):
        attn_metadata = attn_metadata[layer_name]
    self = forward_context.no_compile_layers[layer_name]
    kv_cache = self.kv_cache[forward_context.virtual_engine]
    output = self.impl.forward(self, query, key, value, kv_cache,
                               attn_metadata)

    if envs.VLLM_ENABLE_TBO:
        tbo_maybe_save_kv_layer_to_connector(layer_name, kv_cache)
    else:
        maybe_save_kv_layer_to_connector(layer_name, kv_cache)
    return output
431
432
433
434
435
436
437
438
439
440
441
442
443
444


def unified_attention_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Shape-only stand-in registered as ``fake_impl`` for the
    ``unified_attention`` custom op.

    Returns an uninitialised, contiguous tensor with the same shape, dtype
    and device as ``query``. ``key``, ``value`` and ``layer_name`` exist only
    to mirror the real op's signature and are not read.
    """
    placeholder = torch.empty_like(query)
    return placeholder.contiguous()


# Register unified_attention under the op name "unified_attention"; it is
# invoked as torch.ops.vllm.unified_attention from Attention.forward.
# mutates_args=[] declares that no tensor argument is written in place.
direct_register_custom_op(
    op_name="unified_attention",
    op_func=unified_attention,
    mutates_args=[],
    fake_impl=unified_attention_fake,
    dispatch_key=current_platform.dispatch_key,
)
449
450
451
452
453
454
455
456


def unified_attention_with_output(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: Optional[torch.Tensor] = None,
) -> None:
    """Run attention for ``layer_name`` and write the result into ``output``.

    This backs the ``unified_attention_with_output`` custom op, which is
    registered with ``mutates_args=["output"]``. The steps are:
    1. Wait for any connector KV load for the layer.
    2. Resolve the layer's attention metadata: a per-layer entry when the
       metadata is a dict.
    3. Call the layer's backend ``forward`` with ``output=`` and
       ``output_scale=``.
    4. Pass the KV cache to the connector save path. The two-batch-overlap
       variant is used when ``VLLM_ENABLE_TBO`` is set.

    Returns:
        None; the result is delivered through ``output``.
    """
    wait_for_kv_layer_from_connector(layer_name)
    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if isinstance(attn_metadata, dict):
        attn_metadata = attn_metadata[layer_name]
    self = forward_context.no_compile_layers[layer_name]
    kv_cache = self.kv_cache[forward_context.virtual_engine]
    self.impl.forward(self,
                      query,
                      key,
                      value,
                      kv_cache,
                      attn_metadata,
                      output=output,
                      output_scale=output_scale)

    if envs.VLLM_ENABLE_TBO:
        tbo_maybe_save_kv_layer_to_connector(layer_name, kv_cache)
    else:
        maybe_save_kv_layer_to_connector(layer_name, kv_cache)
478

479
480
481
482
483
484
485

def unified_attention_with_output_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: Optional[torch.Tensor] = None,
) -> None:
    """No-op stand-in registered as ``fake_impl`` for the
    ``unified_attention_with_output`` custom op.

    The real op returns nothing and writes into ``output``, so the fake does
    nothing and returns ``None``; no argument is read or modified.
    """
    return


# Register unified_attention_with_output under the op name
# "unified_attention_with_output"; it is invoked as
# torch.ops.vllm.unified_attention_with_output from Attention.forward.
# mutates_args=["output"] declares that the op writes into `output` in place.
direct_register_custom_op(
    op_name="unified_attention_with_output",
    op_func=unified_attention_with_output,
    mutates_args=["output"],
    fake_impl=unified_attention_with_output_fake,
    dispatch_key=current_platform.dispatch_key,
)