"vllm/vscode:/vscode.git/clone" did not exist on "238dfc8ac3f9befb594e1ad2f249616afc68d484"
layer.py 38.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Attention layer."""
4

5
6
from collections.abc import Callable
from typing import cast
7
8
9

import torch
import torch.nn as nn
10
import torch.nn.functional as F
11

12
import vllm.envs as envs
13
from vllm.attention import AttentionType
14
from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl
15
16
from vllm.attention.backends.registry import _Backend, backend_name_to_enum
from vllm.attention.selector import get_attn_backend
17
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
18
from vllm.config import CacheConfig, get_current_vllm_config
19
from vllm.config.multimodal import MultiModalConfig
20
from vllm.config.vllm import VllmConfig
21
22
23
24
25
from vllm.distributed.kv_transfer import (
    get_kv_transfer_group,
    has_kv_transfer_group,
    is_v1_kv_transfer_group,
)
26
from vllm.forward_context import ForwardContext, get_forward_context
27
from vllm.logger import init_logger
28
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
29
30
31
32
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    UnquantizedLinearMethod,
)
33
from vllm.model_executor.layers.quantization import QuantizationConfig
34
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
35
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
36
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
37
from vllm.model_executor.models.vision import get_vit_attn_backend
38
from vllm.platforms import current_platform
39
from vllm.utils.torch_utils import (
40
41
42
43
44
45
46
47
48
    direct_register_custom_op,
    kv_cache_dtype_str_to_dtype,
)
from vllm.v1.kv_cache_interface import (
    FullAttentionSpec,
    KVCacheSpec,
    MLAAttentionSpec,
    SlidingWindowSpec,
)
49

50
51
52
53
54
55
if current_platform.is_rocm():
    from vllm.platforms.rocm import on_gfx9
else:

    def on_gfx9(*args, **kwargs) -> bool:
        """Stand-in for the ROCm-only gfx9 probe on other platforms; always False."""
        return False


# FP8 dtype as reported by the current platform.
FP8_DTYPE = current_platform.fp8_dtype()

logger = init_logger(__name__)

# Cached result of check_xformers_availability(): None until the first probe,
# then True/False for the rest of the process.
USE_XFORMERS_OPS: bool | None = None


def check_xformers_availability():
    """Return whether ``xformers.ops`` can be used for ViT attention.

    The answer is computed on the first call and cached in the module-level
    ``USE_XFORMERS_OPS``; later calls return the cached value without probing
    again, so the fallback warning is logged at most once.

    Returns:
        True if ``xformers.ops`` is locatable and the platform is not excluded
        (CUDA devices with capability >= 100), otherwise False.
    """
    global USE_XFORMERS_OPS
    if USE_XFORMERS_OPS is not None:
        return USE_XFORMERS_OPS

    if current_platform.is_cuda() and current_platform.has_device_capability(100):
        # Xformers FA is not compatible with B200
        USE_XFORMERS_OPS = False
    else:
        from importlib.util import find_spec

        try:
            # For a dotted name, find_spec() imports the parent package and
            # raises ModuleNotFoundError if "xformers" itself is missing, but
            # it returns None (instead of raising) when the parent exists and
            # the "ops" submodule cannot be found -- so the result must be
            # checked, not just the absence of an exception. ValueError is
            # raised when a matching module is loaded with __spec__ = None.
            USE_XFORMERS_OPS = find_spec("xformers.ops") is not None
        except (ImportError, ValueError):
            USE_XFORMERS_OPS = False

    # the warning only needs to be shown once
    if not USE_XFORMERS_OPS:
        logger.warning("Xformers is not available, falling back.")

    return USE_XFORMERS_OPS

84

85
def check_upstream_fa_availability(dtype: torch.dtype):
    """Report whether the upstream ``flash_attn`` package is usable.

    On CUDA devices with compute capability >= 8.0 and a half-precision
    ``dtype`` (float16/bfloat16), defer to transformers'
    ``is_flash_attn_2_available()``. On ROCm, only check that a
    ``flash_attn`` module can be located. Every other case returns False.
    """
    half_precision_on_ampere_plus = (
        dtype in (torch.float16, torch.bfloat16)
        and current_platform.is_cuda()
        and current_platform.has_device_capability(80)
    )
    if half_precision_on_ampere_plus:
        from transformers.utils import is_flash_attn_2_available

        return is_flash_attn_2_available()

    if not current_platform.is_rocm():
        return False

    from importlib.util import find_spec

    return find_spec("flash_attn") is not None


101
def maybe_get_vit_flash_attn_backend(
    attn_backend: _Backend,
    use_upstream_fa: bool,
    attn_backend_override: _Backend | None = None,
) -> tuple[_Backend, Callable | None]:
    """Settle the ViT attention backend and fetch its varlen FA kernel.

    Per platform:
      * ROCm: AITER MHA when both AITER env flags are set on gfx9; otherwise
        upstream ``flash_attn`` when it is available on gfx9 and no backend
        override was requested; otherwise TORCH_SDPA (returned immediately).
      * CUDA: switch to upstream ``flash_attn`` when the backend is not
        already FLASH_ATTN and the package is usable.
      * XPU: the backend must already be FLASH_ATTN; vLLM's FA is used.
      * anything else: TORCH_SDPA (returned immediately).

    Args:
        attn_backend: Backend selected so far.
        use_upstream_fa: Use the upstream ``flash_attn`` package (rather than
            vLLM's bundled FA) when the result is FLASH_ATTN.
        attn_backend_override: User-requested backend, if any; on ROCm a
            non-None value blocks the switch to upstream ``flash_attn``.

    Returns:
        ``(backend, flash_attn_varlen_func)``; the function is None unless
        the backend is FLASH_ATTN or ROCM_AITER_FA.

    Raises:
        AssertionError: On XPU when ``attn_backend`` is not FLASH_ATTN.
    """
    if current_platform.is_rocm():
        aiter_mha_enabled = (
            envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9()
        )
        if aiter_mha_enabled:
            attn_backend = _Backend.ROCM_AITER_FA
        elif not (
            check_upstream_fa_availability(torch.get_default_dtype())
            and on_gfx9()
            and attn_backend_override is None
        ):
            return _Backend.TORCH_SDPA, None
        else:
            attn_backend = _Backend.FLASH_ATTN
            use_upstream_fa = True
    elif current_platform.is_cuda():
        if attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
            torch.get_default_dtype()
        ):
            attn_backend = _Backend.FLASH_ATTN
            use_upstream_fa = True
    elif current_platform.is_xpu():
        assert attn_backend == _Backend.FLASH_ATTN, (
            "XPU platform only supports FLASH_ATTN as vision attention backend."
        )
        use_upstream_fa = False
    else:
        return _Backend.TORCH_SDPA, None

    # Import the varlen kernel that matches the resolved backend.
    if attn_backend == _Backend.ROCM_AITER_FA:
        from aiter import flash_attn_varlen_func
    elif attn_backend == _Backend.FLASH_ATTN:
        if use_upstream_fa:
            from flash_attn import flash_attn_varlen_func
        else:
            from vllm.attention.utils.fa_utils import flash_attn_varlen_func
    else:
        flash_attn_varlen_func = None

    return attn_backend, flash_attn_varlen_func


148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
def _init_kv_cache_quant(
    layer: nn.Module,
    quant_config: QuantizationConfig | None,
    prefix: str,
    kv_cache_dtype: str,
    calculate_kv_scales: bool,
) -> None:
    """Attach KV-cache scale state and, when configured, a KV-cache quant method.

    Shared by the Attention and MLAAttention layers. Sets the cache dtype and
    dynamic-scale flag, creates unit scale tensors for query/key/value/prob
    plus their host-side float copies, and, if ``quant_config`` yields a real
    (non-unquantized) method for this layer, installs it as
    ``layer.quant_method`` and lets it create its weights.

    Args:
        layer: Attention layer being initialized (mutated in place).
        quant_config: Optional quantization configuration.
        prefix: Layer name prefix used for the quant-method lookup.
        kv_cache_dtype: KV cache dtype string.
        calculate_kv_scales: Whether KV scales are computed dynamically.

    Raises:
        ValueError: If a quant method is active and ``kv_cache_dtype`` is
            ``"fp8_e5m2"``.
    """
    layer.kv_cache_dtype = kv_cache_dtype
    layer.calculate_kv_scales = calculate_kv_scales

    # Unit scales by default. They are ignored unless the kv-cache is fp8;
    # fp8_e5m2 uses them as-is, while for fp8_e4m3 the pre-quantized k/v
    # scales are expected to be loaded with the model weights.
    for scale_attr in ("_k_scale", "_v_scale", "_q_scale", "_prob_scale"):
        setattr(layer, scale_attr, torch.tensor(1.0, dtype=torch.float32))

    # Host (CPU) copies of the q/k/v scales for backends that need the scales
    # on the host rather than the device, e.g. Flashinfer.
    layer._q_scale_float = 1.0
    layer._k_scale_float = 1.0
    layer._v_scale_float = 1.0

    # Host-side output scale; meant to be the input scale of the quant op
    # that follows this attention layer.
    layer._o_scale_float = None

    if not quant_config:
        return
    quant_method = quant_config.get_quant_method(layer, prefix=prefix)
    if quant_method is None or isinstance(quant_method, UnquantizedLinearMethod):
        return

    assert isinstance(quant_method, BaseKVCacheMethod)
    # TODO (mgoin): kv cache dtype should be specified in the FP8
    # checkpoint config and become the "auto" behavior
    if kv_cache_dtype == "fp8_e5m2":
        raise ValueError("fp8_e5m2 kv-cache is not supported with fp8 checkpoints.")
    # The quant method turns k/v scales into loadable parameters so they can
    # come from the checkpoint; they are converted back to float32 values
    # after weight loading.
    layer.quant_method = quant_method
    layer.quant_method.create_weights(layer)


211
class Attention(nn.Module, AttentionLayerBase):
    """KV-cached attention layer.

    Accepts query/key/value tensors that may hold prompt or generation
    tokens. Each call writes the incoming keys/values into the layer's KV
    cache, runs multi-head / multi-query / grouped-query attention through
    the selected backend implementation (``self.impl``), and returns the
    attention output.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
        alibi_slopes: list[float] | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        logits_soft_cap: float | None = None,
        per_layer_sliding_window: int | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
        attn_backend: type[AttentionBackend] | None = None,
        **extra_impl_args,
    ) -> None:
        """Build the layer and its backend implementation.

        The KV cache is held on the layer as ``self.kv_cache``: one
        placeholder tensor per pipeline-parallel virtual engine until
        ``bind_kv_cache`` swaps in the real cache. ``extra_impl_args`` are
        forwarded verbatim to the backend impl; a non-None ``sinks`` entry
        marks the layer as having attention sinks.
        """
        super().__init__()

        # A per-layer sliding window wins over the model-level one.
        if per_layer_sliding_window is not None:
            sliding_window = per_layer_sliding_window
        else:
            sliding_window = (
                cache_config.sliding_window if cache_config is not None else None
            )

        vllm_config = get_current_vllm_config()
        if cache_config is None:
            kv_cache_dtype, block_size, calculate_kv_scales = "auto", 16, False
        else:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            calculate_kv_scales = cache_config.calculate_kv_scales
        self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
            kv_cache_dtype, vllm_config.model_config
        )

        num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        assert num_heads % num_kv_heads == 0, (
            f"num_heads ({num_heads}) is not divisible by num_kv_heads ({num_kv_heads})"
        )

        # KV-cache scale tensors and (optional) quant method.
        _init_kv_cache_quant(
            self, quant_config, prefix, kv_cache_dtype, calculate_kv_scales
        )

        self.num_heads = num_heads
        self.head_size = head_size
        self.num_kv_heads = num_kv_heads
        self.sliding_window = sliding_window
        self.has_sink = extra_impl_args.get("sinks") is not None

        # While the model is being constructed, the torch default dtype is
        # the model's weight/activation dtype.
        dtype = torch.get_default_dtype()
        if attn_backend is not None:
            self.attn_backend = attn_backend
        else:
            self.attn_backend = get_attn_backend(
                head_size,
                dtype,
                kv_cache_dtype,
                block_size,
                use_mla=False,
                has_sink=self.has_sink,
            )

        impl_cls = self.attn_backend.get_impl_cls()
        self.impl = impl_cls(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            sliding_window,
            kv_cache_dtype,
            logits_soft_cap,
            attn_type,
            kv_sharing_target_layer_name,
            **extra_impl_args,
        )
        self.backend = backend_name_to_enum(self.attn_backend.get_name())
        self.dtype = dtype

        # Platforms that ask for an opaque attention op (cuda-alike and cpu)
        # expose attention to torch.compile as one big custom op; on the
        # others the impl is called directly and traced by torch.compile.
        self.use_direct_call = not current_platform.opaque_attention_op()
        self.use_output = self.attn_backend.accept_output_buffer

        # Register this layer by name so the custom ops can look it up.
        compilation_config = vllm_config.compilation_config
        forward_ctx = compilation_config.static_forward_context
        if prefix in forward_ctx:
            raise ValueError(f"Duplicate layer name: {prefix}")
        forward_ctx[prefix] = self
        self.layer_name = prefix
        self.attn_type = attn_type

        if kv_sharing_target_layer_name is not None:
            validate_kv_sharing_target(
                prefix,
                kv_sharing_target_layer_name,
                forward_ctx,
            )
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        # Placeholder KV cache tensors; replaced later by bind_kv_cache.
        # Not read on the direct-call path until then.
        pp_size = vllm_config.parallel_config.pipeline_parallel_size
        self.kv_cache = [torch.tensor([]) for _ in range(pp_size)]

        # Range constants used to derive dynamic q/k/v scales.
        self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
        self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
        self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)

        # Static per-tensor FP8 query quantization, only for fp8 kv-caches
        # whose backend accepts quantized queries.
        quantize_query = (
            self.kv_cache_dtype.startswith("fp8")
            and self.impl.supports_quant_query_input()
        )
        self.query_quant = (
            QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)
            if quantize_query
            else None
        )

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        # Some attention variants (e.g. MLA) produce an output whose shape
        # differs from the query's, so the caller may pass it explicitly.
        output_shape: torch.Size | None = None,
    ) -> torch.Tensor:
        """Run attention for this layer.

        The KV cache comes from ``self.kv_cache``. Attention metadata is not
        passed in; it is read from
        ``vllm.forward_context.get_forward_context().attn_metadata``, which
        the model runner's ``execute_model`` sets via a context manager.
        """
        if self.calculate_kv_scales:
            torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name)

        output_dtype = query.dtype
        if self.query_quant is not None:
            # Quantizing with plain torch ops lets torch.compile fuse this
            # into preceding ops, avoiding the decode-time overhead of a
            # separate quantization custom op.
            assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}
            if self.impl.supports_quant_query_input():
                query, _ = self.query_quant(query, self._q_scale)

        if not self.use_output:
            if not self.use_direct_call:
                return torch.ops.vllm.unified_attention(
                    query, key, value, self.layer_name
                )
            forward_context = get_forward_context()
            attn_metadata = forward_context.attn_metadata
            if isinstance(attn_metadata, dict):
                attn_metadata = attn_metadata[self.layer_name]
            self_kv_cache = self.kv_cache[forward_context.virtual_engine]
            return self.impl.forward(
                self, query, key, value, self_kv_cache, attn_metadata
            )

        if output_shape is None:
            output_shape = query.shape
        output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
        hidden_size = output_shape[-1]

        # Reshape to (tokens, heads, head_size) here rather than inside the
        # custom op, keeping CPU work in non-CUDA-graph regions small.
        query = query.view(-1, self.num_heads, self.head_size)
        output = output.view(-1, self.num_heads, self.head_size)
        if key is not None:
            key = key.view(-1, self.num_kv_heads, self.head_size)
        if value is not None:
            value = value.view(-1, self.num_kv_heads, self.head_size)

        if self.use_direct_call:
            forward_context: ForwardContext = get_forward_context()
            attn_metadata = forward_context.attn_metadata
            if isinstance(attn_metadata, dict):
                attn_metadata = attn_metadata[self.layer_name]
            self_kv_cache = self.kv_cache[forward_context.virtual_engine]
            self.impl.forward(
                self, query, key, value, self_kv_cache, attn_metadata, output=output
            )
        else:
            torch.ops.vllm.unified_attention_with_output(
                query, key, value, output, self.layer_name
            )
        return output.view(-1, hidden_size)

    def calc_kv_scales(self, query, key, value):
        """Derive q/k/v scales from the inputs' max magnitudes, once.

        Each scale is ``abs(x).max() / range``; device tensors are updated in
        place, host float copies are refreshed, and dynamic calculation is
        then switched off.
        """
        for scale, tensor, value_range in (
            (self._q_scale, query, self.q_range),
            (self._k_scale, key, self.k_range),
            (self._v_scale, value, self.v_range),
        ):
            scale.copy_(torch.abs(tensor).max() / value_range)
        self._q_scale_float, self._k_scale_float, self._v_scale_float = (
            self._q_scale.item(),
            self._k_scale.item(),
            self._v_scale.item(),
        )
        # Scales are only computed on the first call.
        self.calculate_kv_scales = False

    def extra_repr(self) -> str:
        fields = [
            f"head_size={self.impl.head_size}",  # type: ignore
            f"num_heads={self.impl.num_heads}",  # type: ignore
            f"num_kv_heads={self.impl.num_kv_heads}",  # type: ignore
            f"scale={self.impl.scale}",  # type: ignore
            f"backend={self.impl.__class__.__name__}",
        ]
        return ", ".join(fields)

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Delegate post-load weight processing to the backend impl."""
        self.impl.process_weights_after_loading(act_dtype)

    def get_attn_backend(self) -> type[AttentionBackend]:
        """Return the attention backend class chosen for this layer."""
        return self.attn_backend

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        """Describe this decoder layer's KV cache layout.

        Returns a SlidingWindowSpec when the layer has a sliding window,
        otherwise a FullAttentionSpec.
        """
        # The block size can change after model loading, so re-read it.
        block_size = vllm_config.cache_config.block_size
        # Not meant for encoder-decoder or encoder-only attention.
        assert self.attn_type == AttentionType.DECODER

        if self.sliding_window is None:
            return FullAttentionSpec(
                block_size=block_size,
                num_kv_heads=self.num_kv_heads,
                head_size=self.head_size,
                dtype=self.kv_cache_torch_dtype,
            )

        assert not vllm_config.model_config.use_mla, (
            "MLA is not supported for slidingwindow"
        )
        return SlidingWindowSpec(
            block_size=block_size,
            num_kv_heads=self.num_kv_heads,
            head_size=self.head_size,
            dtype=self.kv_cache_torch_dtype,
            sliding_window=self.sliding_window,
        )

482

483
484
485
486
487
488
489
490
class MultiHeadAttention(nn.Module):
    """Cache-free multi-head attention for vision encoders (ViT).

    Handles MHA as well as MQA/GQA (``num_kv_heads`` < ``num_heads``): key
    and value heads are repeated up to the query head count before the
    backend kernel runs.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
        # Present so this class can be swapped with Attention; here it is
        # only stored as ``layer_name``.
        prefix: str = "",
        multimodal_config: MultiModalConfig | None = None,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
        self.layer_name = prefix

        assert self.num_heads % self.num_kv_heads == 0, (
            f"num_heads ({self.num_heads}) is not "
            f"divisible by num_kv_heads ({self.num_kv_heads})"
        )
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        # During model construction the torch default dtype is the model's
        # weight/activation dtype.
        dtype = torch.get_default_dtype()

        attn_backend_override = (
            None
            if multimodal_config is None
            else multimodal_config.mm_encoder_attn_backend
        )
        candidate = get_vit_attn_backend(
            head_size=head_size,
            dtype=dtype,
            attn_backend_override=attn_backend_override,
        )
        # Backends this layer cannot run fall back to SDPA.
        runnable = {
            _Backend.TORCH_SDPA,
            _Backend.XFORMERS,
            _Backend.PALLAS,
            _Backend.ROCM_AITER_FA,
            _Backend.FLASH_ATTN,
        }
        if candidate not in runnable:
            candidate = _Backend.TORCH_SDPA

        # Start from vLLM's bundled FA (use_upstream_fa=False); the helper may
        # switch to upstream flash_attn or AITER depending on the platform.
        self.attn_backend, self._flash_attn_varlen_func = (
            maybe_get_vit_flash_attn_backend(
                candidate,
                False,
                attn_backend_override=attn_backend_override,
            )
        )

        if self.attn_backend == _Backend.XFORMERS and not check_xformers_availability():
            self.attn_backend = _Backend.TORCH_SDPA

        self.is_flash_attn_backend = self.attn_backend in {
            _Backend.FLASH_ATTN,
            _Backend.ROCM_AITER_FA,
        }

        # Only used for the log line below: on ROCm, FLASH_ATTN is only ever
        # selected together with upstream flash_attn.
        # NOTE(review): on CUDA the helper can also switch to upstream
        # flash_attn, which this flag does not reflect.
        use_upstream_fa = (
            current_platform.is_rocm() and self.attn_backend == _Backend.FLASH_ATTN
        )

        logger.info_once(
            f"MultiHeadAttention attn_backend: {self.attn_backend}, "
            f"use_upstream_fa: {use_upstream_fa}"
        )

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention without any KV cache.

        Inputs may be (batch_size x seq_len x hidden_size) or
        (batch_size x seq_len x num_heads x head_size). The output is
        reshaped to (batch_size, q_len, -1).
        """
        batch, q_len = query.size()[:2]
        kv_len = key.size(1)

        query = query.view(batch, q_len, self.num_heads, self.head_size)
        key = key.view(batch, kv_len, self.num_kv_heads, self.head_size)
        value = value.view(batch, kv_len, self.num_kv_heads, self.head_size)

        # MQA/GQA: give every query head its own copy of the shared KV head.
        repeats = self.num_queries_per_kv
        if repeats > 1:
            key = key.repeat_interleave(repeats, dim=2)
            value = value.repeat_interleave(repeats, dim=2)

        if self.is_flash_attn_backend:
            assert self._flash_attn_varlen_func is not None
            # All sequences share one length, so the cumulative offsets are
            # evenly spaced multiples of that length.
            cu_seqlens_q = torch.arange(
                0, (batch + 1) * q_len, step=q_len, dtype=torch.int32, device=query.device
            )
            cu_seqlens_k = torch.arange(
                0, (batch + 1) * kv_len, step=kv_len, dtype=torch.int32, device=key.device
            )
            out = self._flash_attn_varlen_func(
                query.flatten(0, 1),
                key.flatten(0, 1),
                value.flatten(0, 1),
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=q_len,
                max_seqlen_k=kv_len,
                softmax_scale=self.scale,
            )
        elif self.attn_backend == _Backend.XFORMERS:
            from xformers import ops as xops

            out = xops.memory_efficient_attention_forward(
                query, key, value, scale=self.scale
            )
        elif self.attn_backend == _Backend.TORCH_SDPA:
            # SDPA expects heads before the sequence dimension.
            out = F.scaled_dot_product_attention(
                query.transpose(1, 2),
                key.transpose(1, 2),
                value.transpose(1, 2),
                scale=self.scale,
            ).transpose(1, 2)
        elif self.attn_backend == _Backend.PALLAS:
            from torch_xla.experimental.custom_kernel import flash_attention

            out = flash_attention(
                query.transpose(1, 2),
                key.transpose(1, 2),
                value.transpose(1, 2),
                sm_scale=self.scale,
            ).transpose(1, 2)
        else:
            # ViT attention hasn't supported this backend yet
            raise NotImplementedError(
                f"ViT attention hasn't supported {self.attn_backend} backend yet."
            )

        return out.reshape(batch, q_len, -1)
632
633


634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
class MLAAttention(nn.Module, AttentionLayerBase):
    """Multi-Head Latent Attention layer.

    This class takes query, and compressed key/value tensors as input.
    The class does the following:

    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        scale: float,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
652
        q_lora_rank: int | None,
653
654
        kv_lora_rank: int,
        kv_b_proj: ColumnParallelLinear,
655
656
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
657
658
        prefix: str = "",
        use_sparse: bool = False,
659
        indexer: object | None = None,
660
        **extra_impl_args,
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
    ):
        super().__init__()
        self.num_heads = num_heads
        self.scale = scale
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.head_size = kv_lora_rank + qk_rope_head_dim
        self.layer_name = prefix

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            calculate_kv_scales = cache_config.calculate_kv_scales
        else:
            kv_cache_dtype = "auto"
            block_size = 16
            calculate_kv_scales = False
681
682
683
684
685

        # Initialize KV cache quantization attributes
        _init_kv_cache_quant(
            self, quant_config, prefix, kv_cache_dtype, calculate_kv_scales
        )
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716

        dtype = torch.get_default_dtype()
        self.attn_backend = get_attn_backend(
            self.head_size,
            dtype,
            kv_cache_dtype,
            block_size,
            use_mla=True,
            use_sparse=use_sparse,
        )
        impl_cls = cast(type[MLAAttentionImpl], self.attn_backend.get_impl_cls())
        self.impl = impl_cls(
            num_heads=self.num_heads,
            head_size=self.head_size,
            scale=self.scale,
            num_kv_heads=1,
            alibi_slopes=None,
            sliding_window=None,
            kv_cache_dtype=self.kv_cache_dtype,
            logits_soft_cap=None,
            attn_type=AttentionType.DECODER,
            kv_sharing_target_layer_name=None,
            # MLA Args
            q_lora_rank=self.q_lora_rank,
            kv_lora_rank=self.kv_lora_rank,
            qk_nope_head_dim=self.qk_nope_head_dim,
            qk_rope_head_dim=self.qk_rope_head_dim,
            qk_head_dim=self.qk_nope_head_dim + self.qk_rope_head_dim,
            v_head_dim=self.v_head_dim,
            kv_b_proj=kv_b_proj,
            indexer=indexer,
717
            **extra_impl_args,
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
        )

        self.use_direct_call = not current_platform.opaque_attention_op()

        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self

        self.kv_cache = [
            torch.tensor([])
            for _ in range(
                get_current_vllm_config().parallel_config.pipeline_parallel_size
            )
        ]

        self.use_sparse = use_sparse

        # Initialize q/k/v range constants.
737
738
739
        self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
        self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
        self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)
740
741
742
743
744
745

    def forward(
        self,
        q: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        output_shape: torch.Size | None = None,
    ) -> torch.Tensor:
        """Run MLA attention for this layer.

        When ``self.use_direct_call`` is set, the backend impl is invoked
        eagerly with this layer's KV cache and the attention metadata taken
        from the current forward context. Otherwise the work is routed
        through the registered ``torch.ops.vllm.unified_mla_attention*``
        custom ops.

        Args:
            q: Query input, forwarded unchanged to the backend impl.
            kv_c_normed: Compressed KV input, forwarded to the backend impl
                and used as the k/v statistic by ``calc_kv_scales``.
            k_pe: Positional key input, forwarded to the backend impl.
            output_shape: Shape of the output buffer to allocate. Only used
                when ``self.attn_backend.accept_output_buffer`` is true; in
                that case it is passed straight to ``torch.empty`` and must
                therefore be provided.

        Returns:
            The preallocated output buffer (dtype/device of ``q``) when the
            backend accepts an output buffer, otherwise whatever the impl or
            custom op returns.
        """
        if self.use_direct_call:
            forward_context: ForwardContext = get_forward_context()
            attn_metadata = forward_context.attn_metadata
            if isinstance(attn_metadata, dict):
                attn_metadata = attn_metadata[self.layer_name]
            self_kv_cache = self.kv_cache[forward_context.virtual_engine]

            # Mirror Attention.forward scale calculation path
            if self.calculate_kv_scales and getattr(
                attn_metadata, "enable_kv_scales_calculation", False
            ):
                self.calc_kv_scales(q, kv_c_normed, k_pe)

            if self.attn_backend.accept_output_buffer:
                output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
                self.impl.forward(
                    self,
                    q,
                    kv_c_normed,
                    k_pe,
                    self_kv_cache,
                    attn_metadata,
                    output=output,
                )
                return output
            return self.impl.forward(
                self, q, kv_c_normed, k_pe, self_kv_cache, attn_metadata
            )

        # Opaque custom-op path. The scale check runs before choosing the
        # output mode so scales are also calculated when the backend writes
        # into a preallocated buffer (matching the direct-call path above).
        if self.calculate_kv_scales:
            forward_context = get_forward_context()
            attn_metadata = forward_context.attn_metadata
            if isinstance(attn_metadata, dict):
                attn_metadata = attn_metadata[self.layer_name]
            if getattr(attn_metadata, "enable_kv_scales_calculation", False):
                self.calc_kv_scales(q, kv_c_normed, k_pe)

        if self.attn_backend.accept_output_buffer:
            output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
            torch.ops.vllm.unified_mla_attention_with_output(
                q,
                kv_c_normed,
                k_pe,
                output,
                self.layer_name,
            )
            return output
        return torch.ops.vllm.unified_mla_attention(
            q,
            kv_c_normed,
            k_pe,
            self.layer_name,
        )

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Let the backend impl post-process freshly loaded weights.

        Delegates to ``self.impl.process_weights_after_loading(act_dtype)``
        when the impl defines that attribute; otherwise does nothing.
        """
        impl = self.impl
        if not hasattr(impl, "process_weights_after_loading"):
            return
        impl.process_weights_after_loading(act_dtype)

    def calc_kv_scales(
        self, q: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor
    ) -> None:
        """Derive q/k/v scales from the absolute maxima of the MLA inputs.

        Counterpart of ``Attention.calc_kv_scales``; not every MLA backend
        needs it. The q scale is ``max|q| / q_range``. The k and v scales
        both use ``max|kv_c_normed|`` (the compressed KV input), divided by
        ``k_range`` and ``v_range`` respectively. ``k_pe`` is not used.
        Any missing ``*_range`` attribute falls back to ``torch.tensor(1.0)``.

        Results are copied into ``_q_scale``/``_k_scale``/``_v_scale``,
        mirrored as Python floats in the ``*_float`` attributes, and
        ``calculate_kv_scales`` is then cleared.
        """
        q_range, k_range, v_range = (
            getattr(self, attr, torch.tensor(1.0))
            for attr in ("q_range", "k_range", "v_range")
        )

        self._q_scale.copy_(q.abs().max() / q_range)
        # k and v share one statistic: the compressed KV representation.
        kv_amax = kv_c_normed.abs().max()
        self._k_scale.copy_(kv_amax / k_range)
        self._v_scale.copy_(kv_amax / v_range)

        self._q_scale_float = self._q_scale.item()
        self._k_scale_float = self._k_scale.item()
        self._v_scale_float = self._v_scale.item()
        self.calculate_kv_scales = False

    def get_attn_backend(self) -> type[AttentionBackend]:
        """Return the backend class stored on this layer (``self.attn_backend``)."""
        return self.attn_backend

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        """Describe the KV cache this MLA layer needs.

        Args:
            vllm_config: Engine config. ``model_config`` is used to resolve
                ``self.kv_cache_dtype`` into a torch dtype; ``cache_config``
                supplies the block size and the cache dtype string.

        Returns:
            An ``MLAAttentionSpec`` with ``num_kv_heads=1`` and
            ``head_size=self.head_size``.
        """
        kv_cache_dtype = kv_cache_dtype_str_to_dtype(
            self.kv_cache_dtype, vllm_config.model_config
        )
        return MLAAttentionSpec(
            block_size=vllm_config.cache_config.block_size,
            # Fixed at 1 — presumably MLA caches a single shared latent per
            # token rather than per-head K/V; confirm against MLAAttentionSpec.
            num_kv_heads=1,
            head_size=self.head_size,
            dtype=kv_cache_dtype,
            cache_dtype_str=vllm_config.cache_config.cache_dtype,
        )

855
def wait_for_kv_layer_from_connector(layer_name: str):
    """Call ``wait_for_layer_load(layer_name)`` on the KV-transfer connector.

    This is a no-op when no v1 KV-transfer group is active, or when the
    current forward context carries no attention metadata.

    Args:
        layer_name: Name of the attention layer whose KV load is awaited.
    """
    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
        return

    connector = get_kv_transfer_group()

    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if attn_metadata is None:
        return
    # The connector path only supports per-layer (dict) attention metadata.
    assert isinstance(attn_metadata, dict)
    connector.wait_for_layer_load(layer_name)


def maybe_save_kv_layer_to_connector(
    layer_name: str,
    kv_cache_layer: torch.Tensor,
):
    """Offer one layer's KV cache to the KV-transfer connector for saving.

    This is a no-op when no v1 KV-transfer group is active, or when the
    current forward context carries no attention metadata. Otherwise it calls
    ``connector.save_kv_layer`` with this layer's own metadata entry.

    Args:
        layer_name: Key of this layer in the per-layer metadata dict.
        kv_cache_layer: The layer's KV cache for the current virtual engine
            (callers pass ``layer.kv_cache[virtual_engine]``, a single
            tensor), forwarded to the connector unchanged.
    """
    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
        return

    connector = get_kv_transfer_group()

    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if attn_metadata is None:
        return
    # The connector path only supports per-layer (dict) attention metadata.
    assert isinstance(attn_metadata, dict)
    connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata[layer_name])


def maybe_calc_kv_scales(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> None:
    """Compute a layer's scales if the current batch asks for it.

    Resolves this layer's attention metadata from the current forward
    context. When that metadata has ``enable_kv_scales_calculation`` set, it
    calls ``calc_kv_scales(query, key, value)`` on the layer registered under
    ``layer_name`` in ``forward_context.no_compile_layers``.

    Args:
        query: First input forwarded to ``calc_kv_scales``.
        key: Second input forwarded to ``calc_kv_scales``.
        value: Third input forwarded to ``calc_kv_scales``.
        layer_name: Key used to find the layer and its metadata.
    """
    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata

    if isinstance(attn_metadata, dict):
        attn_metadata = attn_metadata[layer_name]

    if attn_metadata is None or not getattr(
        attn_metadata, "enable_kv_scales_calculation", False
    ):
        return

    self = forward_context.no_compile_layers[layer_name]
    self.calc_kv_scales(query, key, value)


def maybe_calc_kv_scales_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> None:
    """Fake implementation registered for ``maybe_calc_kv_scales``.

    The real op returns nothing, so there is no output to fabricate; every
    argument is ignored.
    """
    return None


# Expose maybe_calc_kv_scales as torch.ops.vllm.maybe_calc_kv_scales.
# NOTE(review): query/key/value are declared as mutated although the scale
# calculation visible in this file only reads them — presumably so the
# side-effecting, output-less op is not eliminated; confirm before changing.
direct_register_custom_op(
    op_name="maybe_calc_kv_scales",
    op_func=maybe_calc_kv_scales,
    mutates_args=["query", "key", "value"],
    fake_impl=maybe_calc_kv_scales_fake,
)


def unified_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Run attention for ``layer_name`` and return the backend's output.

    Calls ``wait_for_kv_layer_from_connector`` first. It then looks up the
    layer, its attention metadata and its KV cache for the current virtual
    engine from the forward context, and runs ``layer.impl.forward``. Finally
    it hands the KV cache to ``maybe_save_kv_layer_to_connector``.
    """
    wait_for_kv_layer_from_connector(layer_name)

    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if isinstance(attn_metadata, dict):
        attn_metadata = attn_metadata[layer_name]
    self = forward_context.no_compile_layers[layer_name]
    kv_cache = self.kv_cache[forward_context.virtual_engine]
    output = self.impl.forward(self, query, key, value, kv_cache, attn_metadata)

    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
    return output


def unified_attention_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Fake implementation registered for ``unified_attention``.

    Returns an uninitialized, contiguous tensor with the shape, dtype and
    device of ``query``; ``key``, ``value`` and ``layer_name`` are ignored.
    """
    placeholder = torch.empty_like(query)
    return placeholder.contiguous()


# Expose unified_attention as torch.ops.vllm.unified_attention.
direct_register_custom_op(
    op_name="unified_attention",
    op_func=unified_attention,
    fake_impl=unified_attention_fake,
)


def unified_attention_with_output(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
) -> None:
    """Run attention for ``layer_name``, writing the result into ``output``.

    Waits on the KV connector, resolves the layer, its attention metadata and
    its KV cache for the current virtual engine from the forward context, and
    calls ``layer.impl.forward`` with the caller-provided ``output`` buffer.
    ``output_scale`` and ``output_block_scale`` are forwarded unchanged.
    Afterwards the KV cache is offered to the connector for saving.
    """
    wait_for_kv_layer_from_connector(layer_name)
    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if isinstance(attn_metadata, dict):
        attn_metadata = attn_metadata[layer_name]
    self = forward_context.no_compile_layers[layer_name]
    kv_cache = self.kv_cache[forward_context.virtual_engine]
    self.impl.forward(
        self,
        query,
        key,
        value,
        kv_cache,
        attn_metadata,
        output=output,
        output_scale=output_scale,
        output_block_scale=output_block_scale,
    )

    maybe_save_kv_layer_to_connector(layer_name, kv_cache)


def unified_attention_with_output_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
988
989
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
990
991
992
993
994
995
996
) -> None:
    return


# Expose unified_attention_with_output as
# torch.ops.vllm.unified_attention_with_output; it writes into `output`
# (and possibly `output_block_scale`) in place.
direct_register_custom_op(
    op_name="unified_attention_with_output",
    op_func=unified_attention_with_output,
    mutates_args=["output", "output_block_scale"],
    fake_impl=unified_attention_with_output_fake,
)


def unified_mla_attention(
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Run MLA attention for ``layer_name`` and return the impl's output.

    Waits on the KV connector, resolves the layer's attention metadata,
    the ``MLAAttention`` layer itself and its KV cache for the current
    virtual engine, runs ``layer.impl.forward``, then offers the KV cache to
    the connector for saving.
    """
    wait_for_kv_layer_from_connector(layer_name)

    ctx: ForwardContext = get_forward_context()
    metadata = ctx.attn_metadata
    if isinstance(metadata, dict):
        metadata = metadata[layer_name]
    layer: MLAAttention = ctx.no_compile_layers[layer_name]
    layer_kv_cache = layer.kv_cache[ctx.virtual_engine]
    result = layer.impl.forward(layer, q, kv_c_normed, k_pe, layer_kv_cache, metadata)

    maybe_save_kv_layer_to_connector(layer_name, layer_kv_cache)
    return result


def unified_mla_attention_fake(
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Fake implementation registered for ``unified_mla_attention``.

    Returns an uninitialized, contiguous tensor with the shape, dtype and
    device of ``q``; the other arguments are ignored.

    NOTE(review): the output-buffer path in ``MLAAttention.forward`` sizes
    its result from ``output_shape`` rather than from ``q`` — confirm that
    backends using this non-buffer op really return a ``q``-shaped tensor.
    """
    placeholder = torch.empty_like(q)
    return placeholder.contiguous()


# Expose unified_mla_attention as torch.ops.vllm.unified_mla_attention on the
# current platform's dispatch key; it mutates none of its inputs.
direct_register_custom_op(
    op_name="unified_mla_attention",
    op_func=unified_mla_attention,
    fake_impl=unified_mla_attention_fake,
    mutates_args=[],
    dispatch_key=current_platform.dispatch_key,
)


def unified_mla_attention_with_output(
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
) -> None:
    """Run MLA attention for ``layer_name``, writing into ``output``.

    Waits on the KV connector, resolves the layer's attention metadata, the
    ``MLAAttention`` layer and its KV cache for the current virtual engine,
    and calls ``layer.impl.forward`` with the caller-provided ``output``
    buffer. ``output_scale`` and ``output_block_scale`` are forwarded
    unchanged. Afterwards the KV cache is offered to the connector for saving.
    """
    wait_for_kv_layer_from_connector(layer_name)
    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if isinstance(attn_metadata, dict):
        attn_metadata = attn_metadata[layer_name]
    self: MLAAttention = forward_context.no_compile_layers[layer_name]
    kv_cache = self.kv_cache[forward_context.virtual_engine]
    self.impl.forward(
        self,
        q,
        kv_c_normed,
        k_pe,
        kv_cache,
        attn_metadata,
        output=output,
        output_scale=output_scale,
        output_block_scale=output_block_scale,
    )

    maybe_save_kv_layer_to_connector(layer_name, kv_cache)


def unified_mla_attention_with_output_fake(
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
1077
1078
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
) -> None:
    return


# Expose unified_mla_attention_with_output as
# torch.ops.vllm.unified_mla_attention_with_output on the current platform's
# dispatch key; it writes into `output` (and possibly `output_block_scale`).
direct_register_custom_op(
    op_name="unified_mla_attention_with_output",
    op_func=unified_mla_attention_with_output,
    fake_impl=unified_mla_attention_with_output_fake,
    mutates_args=["output", "output_block_scale"],
    dispatch_key=current_platform.dispatch_key,
)