layer.py 37.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Attention layer."""
4

5
6
from collections.abc import Callable
from typing import cast
7
8
9

import torch
import torch.nn as nn
10
import torch.nn.functional as F
11

12
import vllm.envs as envs
13
from vllm.attention import AttentionType
14
from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl
15
16
from vllm.attention.backends.registry import _Backend, backend_name_to_enum
from vllm.attention.selector import get_attn_backend
17
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
18
from vllm.config import CacheConfig, get_current_vllm_config
19
from vllm.config.multimodal import MultiModalConfig
20
from vllm.config.vllm import VllmConfig
21
22
23
24
25
from vllm.distributed.kv_transfer import (
    get_kv_transfer_group,
    has_kv_transfer_group,
    is_v1_kv_transfer_group,
)
26
from vllm.forward_context import ForwardContext, get_forward_context
27
from vllm.logger import init_logger
28
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
29
30
31
32
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    UnquantizedLinearMethod,
)
33
from vllm.model_executor.layers.quantization import QuantizationConfig
34
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
35
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
36
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
37
from vllm.model_executor.models.vision import get_vit_attn_backend
38
from vllm.platforms import current_platform
39
from vllm.utils.torch_utils import (
40
41
42
43
44
45
46
47
48
    direct_register_custom_op,
    kv_cache_dtype_str_to_dtype,
)
from vllm.v1.kv_cache_interface import (
    FullAttentionSpec,
    KVCacheSpec,
    MLAAttentionSpec,
    SlidingWindowSpec,
)
49

50
logger = init_logger(__name__)

# FP8 dtype reported by the current platform.
FP8_DTYPE = current_platform.fp8_dtype()

# Cached result of check_xformers_availability(); None means "not probed yet".
USE_XFORMERS_OPS: bool | None = None


def check_xformers_availability():
    """Return whether the xformers ops can be used on this platform.

    The result is computed once and cached in the module-level
    ``USE_XFORMERS_OPS`` flag; later calls return the cached value, so the
    fallback warning below is emitted at most once per process.

    Returns:
        False on CUDA devices for which ``has_device_capability(100)`` holds
        (the B200 case), otherwise whether ``xformers.ops`` can be located.
    """
    global USE_XFORMERS_OPS
    if USE_XFORMERS_OPS is not None:
        return USE_XFORMERS_OPS

    if current_platform.is_cuda() and current_platform.has_device_capability(100):
        # Xformers FA is not compatible with B200
        USE_XFORMERS_OPS = False
    else:
        from importlib.util import find_spec

        try:
            # find_spec() returns None when the submodule cannot be found and
            # raises ModuleNotFoundError (an ImportError) when the parent
            # package is missing; both mean xformers ops are unavailable.
            USE_XFORMERS_OPS = find_spec("xformers.ops") is not None
        except ImportError:
            USE_XFORMERS_OPS = False

    # the warning only needs to be shown once
    if not USE_XFORMERS_OPS:
        logger.warning("Xformers is not available, falling back.")

    return USE_XFORMERS_OPS

78

79
def check_upstream_fa_availability(dtype: torch.dtype):
    """Report whether upstream (non-vLLM) flash-attention can be used.

    Args:
        dtype: Activation dtype the attention kernels would run in.

    Returns:
        On CUDA devices passing ``has_device_capability(80)`` with an
        fp16/bf16 ``dtype``: the result of transformers'
        ``is_flash_attn_2_available()``. On ROCm: whether the ``flash_attn``
        package can be found. Otherwise: False.
    """
    half_precision = dtype in (torch.float16, torch.bfloat16)
    if half_precision and current_platform.is_cuda():
        if current_platform.has_device_capability(80):
            from transformers.utils import is_flash_attn_2_available

            return is_flash_attn_2_available()

    if not current_platform.is_rocm():
        return False

    from importlib.util import find_spec

    return find_spec("flash_attn") is not None


95
def maybe_get_vit_flash_attn_backend(
    attn_backend: _Backend,
    use_upstream_fa: bool,
    attn_backend_override: _Backend | None = None,
) -> tuple[_Backend, Callable | None]:
    """Resolve the ViT attention backend and its varlen flash-attn kernel.

    Args:
        attn_backend: Backend selected so far.
        use_upstream_fa: Whether FLASH_ATTN should use the upstream
            ``flash_attn`` package instead of ``vllm.vllm_flash_attn``.
        attn_backend_override: User-requested backend. When set, the
            automatic upgrade to upstream flash-attention is disabled.

    Returns:
        ``(backend, flash_attn_varlen_func)``. The function is None unless
        the resolved backend is FLASH_ATTN or ROCM_AITER_FA.
    """
    # Only auto-selected, non-flash backends are upgraded to upstream
    # flash-attention. The override test comes first so the availability
    # probe (which may import ``transformers``) is skipped when it cannot
    # change the outcome.
    if (
        attn_backend_override is None
        and attn_backend != _Backend.FLASH_ATTN
        and attn_backend != _Backend.ROCM_AITER_FA
        and check_upstream_fa_availability(torch.get_default_dtype())
    ):
        attn_backend = _Backend.FLASH_ATTN
        use_upstream_fa = True

    # On ROCm, FLASH_ATTN always goes through the upstream flash_attn package.
    if current_platform.is_rocm() and attn_backend == _Backend.FLASH_ATTN:
        use_upstream_fa = True

    if attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
        if attn_backend == _Backend.ROCM_AITER_FA:
            from aiter import flash_attn_varlen_func
        else:
            if use_upstream_fa:
                from flash_attn import flash_attn_varlen_func
            else:
                from vllm.vllm_flash_attn import flash_attn_varlen_func
    else:
        flash_attn_varlen_func = None

    return attn_backend, flash_attn_varlen_func


126
class Attention(nn.Module, AttentionLayerBase):
    """Attention layer.

    This class takes query, key, and value tensors as input. The input tensors
    can either contain prompt tokens or generation tokens.
    The class does the following:

    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
        alibi_slopes: list[float] | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        logits_soft_cap: float | None = None,
        per_layer_sliding_window: int | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
        attn_backend: type[AttentionBackend] | None = None,
        **extra_impl_args,
    ) -> None:
        """
        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each attention head.
            scale: Softmax scale forwarded to the backend implementation.
            num_kv_heads: Number of key/value heads; defaults to
                ``num_heads``. ``num_heads`` must be divisible by it.
            alibi_slopes: Optional ALiBi slopes forwarded to the backend.
            cache_config: Supplies the KV-cache dtype, block size, the
                ``calculate_kv_scales`` flag and the model-level sliding
                window. When None, dtype "auto", block size 16 and no scale
                calculation are used.
            quant_config: Optional quantization config. A quant method other
                than ``UnquantizedLinearMethod`` must be a
                ``BaseKVCacheMethod`` and creates the scale weights.
            logits_soft_cap: Optional soft cap forwarded to the backend.
            per_layer_sliding_window: When not None, overrides the
                model-level sliding window from ``cache_config``.
            prefix: Unique layer name, registered in the compilation
                config's ``static_forward_context``.
            attn_type: Attention type; defaults to ``AttentionType.DECODER``.
            kv_sharing_target_layer_name: Name of a layer whose KV cache
                this layer shares; validated against registered layers.
            attn_backend: Explicit backend class; chosen with
                ``get_attn_backend`` when None.
            **extra_impl_args: Forwarded to the backend implementation; a
                non-None ``sinks`` entry sets ``has_sink``.

        Raises:
            ValueError: if a quantized checkpoint is combined with an
                ``fp8_e5m2`` KV cache, or if ``prefix`` is already registered.
        """
        super().__init__()
        if per_layer_sliding_window is not None:
            # per-layer sliding window
            sliding_window = per_layer_sliding_window
        elif cache_config is not None:
            # model-level sliding window
            sliding_window = cache_config.sliding_window
        else:
            sliding_window = None

        vllm_config = get_current_vllm_config()
        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            calculate_kv_scales = cache_config.calculate_kv_scales
        else:
            kv_cache_dtype = "auto"
            block_size = 16
            calculate_kv_scales = False
        self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
            kv_cache_dtype, vllm_config.model_config
        )
        if num_kv_heads is None:
            num_kv_heads = num_heads
        assert num_heads % num_kv_heads == 0, (
            f"num_heads ({num_heads}) is not divisible by num_kv_heads ({num_kv_heads})"
        )

        # The default k/v_scale is set to 1.0. This is ignored
        # when kv-cache is not fp8, and should be used with
        # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
        # expect the pre-quantized k/v_scale to be loaded along
        # with the model weights.
        self.kv_cache_dtype = kv_cache_dtype
        self.calculate_kv_scales = calculate_kv_scales
        self._k_scale = torch.tensor(1.0, dtype=torch.float32)
        self._v_scale = torch.tensor(1.0, dtype=torch.float32)
        # FlashAttn doesn't support quantizing the kv-cache only
        # but requires q to be quantized as well.
        self._q_scale = torch.tensor(1.0, dtype=torch.float32)
        self._prob_scale = torch.tensor(1.0, dtype=torch.float32)

        # We also keep q/k/v_scale on host (cpu) memory for attention
        # backends that require the scales to be on host instead of on device.
        # e.g. Flashinfer
        self._q_scale_float = 1.0
        self._k_scale_float = 1.0
        self._v_scale_float = 1.0

        # The output scale on host memory. This should be the input scale of
        # the quant op after this attention layer.
        self._o_scale_float: float | None = None

        self.num_heads = num_heads
        self.head_size = head_size
        self.num_kv_heads = num_kv_heads
        self.sliding_window = sliding_window
        self.has_sink = extra_impl_args.get("sinks") is not None

        quant_method = (
            quant_config.get_quant_method(self, prefix=prefix) if quant_config else None
        )
        if quant_method is not None and not isinstance(
            quant_method, UnquantizedLinearMethod
        ):
            assert isinstance(quant_method, BaseKVCacheMethod)
            # TODO (mgoin): kv cache dtype should be specified in the FP8
            # checkpoint config and become the "auto" behavior
            if self.kv_cache_dtype == "fp8_e5m2":
                raise ValueError(
                    "fp8_e5m2 kv-cache is not supported with fp8 checkpoints."
                )
            # If quantization is enabled, we make "k_scale" and "v_scale"
            # parameters so that it can be loaded from the model checkpoint.
            # The k/v_scale will then be converted back to native float32
            # values after weight loading.
            self.quant_method = quant_method
            self.quant_method.create_weights(self)

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()
        if attn_backend is None:
            self.attn_backend = get_attn_backend(
                head_size,
                dtype,
                kv_cache_dtype,
                block_size,
                use_mla=False,
                has_sink=self.has_sink,
            )
        else:
            self.attn_backend = attn_backend

        impl_cls = self.attn_backend.get_impl_cls()
        self.impl = impl_cls(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            sliding_window,
            kv_cache_dtype,
            logits_soft_cap,
            attn_type,
            kv_sharing_target_layer_name,
            **extra_impl_args,
        )
        self.backend = backend_name_to_enum(self.attn_backend.get_name())
        self.dtype = dtype

        # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
        # torch.compile works by registering the attention as one giant
        # opaque custom op. For other platforms, we directly call them
        # and let torch.compile handle them.
        self.use_direct_call = not current_platform.opaque_attention_op()

        self.use_output = self.attn_backend.accept_output_buffer
        compilation_config = vllm_config.compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self
        self.layer_name = prefix
        self.attn_type = attn_type

        if kv_sharing_target_layer_name is not None:
            validate_kv_sharing_target(
                prefix,
                kv_sharing_target_layer_name,
                compilation_config.static_forward_context,
            )
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        # use a placeholder kv cache tensor during init, which will be replaced
        # by bind_kv_cache
        # this variable will not be accessed if use_direct_call is True
        self.kv_cache = [
            torch.tensor([])
            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
        ]

        # Initialize q/k/v range constants.
        self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
        self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
        self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)

        # for attn backends supporting query quantization
        self.query_quant = None
        if (
            self.kv_cache_dtype.startswith("fp8")
            and self.impl.supports_quant_query_input()
        ):
            self.query_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        # For some alternate attention backends like MLA the attention output
        # shape does not match the query shape, so we optionally let the model
        # definition specify the output tensor shape.
        output_shape: torch.Size | None = None,
    ) -> torch.Tensor:
        """
        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        Attention metadata (`attn_metadata`) is set using a context manager in
        the model runner's `execute_model` method. It is accessed via forward
        context using
        `vllm.forward_context.get_forward_context().attn_metadata`.

        Returns:
            When the backend accepts an output buffer: a tensor viewed as
            ``(-1, output_shape[-1])`` (``output_shape`` defaults to
            ``query.shape``) in the query's original dtype. Otherwise:
            whatever the backend implementation / custom op returns.
        """
        if self.calculate_kv_scales:
            torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name)

        # Captured before any fp8 query quantization so the output keeps the
        # activation dtype.
        output_dtype = query.dtype
        if self.query_quant is not None:
            # quantizing with a simple torch operation enables
            # torch.compile to fuse this into previous ops
            # which reduces overheads during decoding.
            # Otherwise queries are quantized using custom ops
            # which causes decoding overheads
            assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}

            # check if query quantization is supported
            if self.impl.supports_quant_query_input():
                query, _ = self.query_quant(query, self._q_scale)

        if self.use_output:
            output_shape = output_shape if output_shape is not None else query.shape
            output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
            hidden_size = output_shape[-1]
            # Reshape the query, key, and value tensors.
            # NOTE(woosuk): We do this outside the custom op to minimize the
            # CPU overheads from the non-CUDA-graph regions.
            query = query.view(-1, self.num_heads, self.head_size)
            output = output.view(-1, self.num_heads, self.head_size)
            if key is not None:
                key = key.view(-1, self.num_kv_heads, self.head_size)
            if value is not None:
                value = value.view(-1, self.num_kv_heads, self.head_size)
            if self.use_direct_call:
                forward_context: ForwardContext = get_forward_context()
                attn_metadata = forward_context.attn_metadata
                if isinstance(attn_metadata, dict):
                    attn_metadata = attn_metadata[self.layer_name]
                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                self.impl.forward(
                    self, query, key, value, self_kv_cache, attn_metadata, output=output
                )
            else:
                torch.ops.vllm.unified_attention_with_output(
                    query, key, value, output, self.layer_name
                )
            return output.view(-1, hidden_size)
        else:
            if self.use_direct_call:
                forward_context = get_forward_context()
                attn_metadata = forward_context.attn_metadata
                if isinstance(attn_metadata, dict):
                    attn_metadata = attn_metadata[self.layer_name]
                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                return self.impl.forward(
                    self, query, key, value, self_kv_cache, attn_metadata
                )
            else:
                return torch.ops.vllm.unified_attention(
                    query, key, value, self.layer_name
                )

    def calc_kv_scales(self, query, key, value):
        """Calibrate the q/k/v scales once from the given inputs.

        Each scale is ``abs(x).max()`` divided by the matching range constant
        (``envs.{Q,K,V}_SCALE_CONSTANT``). The results are copied into the
        scale tensors, mirrored into the host-side floats, and further
        calibration is disabled.
        """
        self._q_scale.copy_(torch.abs(query).max() / self.q_range)
        self._k_scale.copy_(torch.abs(key).max() / self.k_range)
        self._v_scale.copy_(torch.abs(value).max() / self.v_range)
        self._q_scale_float = self._q_scale.item()
        self._k_scale_float = self._k_scale.item()
        self._v_scale_float = self._v_scale.item()
        # We only calculate the scales once
        self.calculate_kv_scales = False

    def extra_repr(self) -> str:
        """Summarize head geometry, scale and the backend impl class."""
        s = f"head_size={self.impl.head_size}"  # type: ignore
        s += f", num_heads={self.impl.num_heads}"  # type: ignore
        s += f", num_kv_heads={self.impl.num_kv_heads}"  # type: ignore
        s += f", scale={self.impl.scale}"  # type: ignore
        s += f", backend={self.impl.__class__.__name__}"
        return s

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Delegate post-load weight processing to the backend impl."""
        self.impl.process_weights_after_loading(act_dtype)

    def get_attn_backend(self) -> type[AttentionBackend]:
        """Return the attention backend class used by this layer."""
        return self.attn_backend

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        """Build the KV-cache spec for this decoder layer.

        Returns a ``SlidingWindowSpec`` when a sliding window is configured,
        otherwise a ``FullAttentionSpec``. Only valid for
        ``AttentionType.DECODER`` layers.
        """
        # Block size may get updated after model loading, refresh it
        block_size = vllm_config.cache_config.block_size
        # Should not be called for enc-dec or encoder-only attention.
        assert self.attn_type == AttentionType.DECODER
        if self.sliding_window is not None:
            assert not vllm_config.model_config.use_mla, (
                "MLA is not supported for slidingwindow"
            )
            return SlidingWindowSpec(
                block_size=block_size,
                num_kv_heads=self.num_kv_heads,
                head_size=self.head_size,
                dtype=self.kv_cache_torch_dtype,
                sliding_window=self.sliding_window,
            )
        else:
            return FullAttentionSpec(
                block_size=block_size,
                num_kv_heads=self.num_kv_heads,
                head_size=self.head_size,
                dtype=self.kv_cache_torch_dtype,
            )

437

438
439
440
441
442
443
444
445
class MultiHeadAttention(nn.Module):
    """Multi-headed attention without any cache, used for ViT."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
        # This has no effect, it is only here to make it easier to swap
        # between Attention and MultiHeadAttention
        prefix: str = "",
        multimodal_config: MultiModalConfig | None = None,
    ) -> None:
        """Select a ViT attention backend and store the head geometry.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each attention head.
            scale: Softmax scale passed to every backend kernel.
            num_kv_heads: Number of key/value heads; defaults to
                ``num_heads``. ``num_heads`` must be divisible by it.
            prefix: Stored as ``layer_name``; otherwise unused.
            multimodal_config: When given, its ``mm_encoder_attn_backend``
                is used as an explicit backend override.
        """
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.layer_name = prefix

        assert self.num_heads % self.num_kv_heads == 0, (
            f"num_heads ({self.num_heads}) is not "
            f"divisible by num_kv_heads ({self.num_kv_heads})"
        )
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()

        # Determine the attention backend
        attn_backend_override = None
        if multimodal_config is not None:
            attn_backend_override = multimodal_config.mm_encoder_attn_backend
        backend = get_vit_attn_backend(
            head_size=head_size,
            dtype=dtype,
            attn_backend_override=attn_backend_override,
        )

        # Some auto-selected backends can be upgraded
        # to upstream flash attention if available.
        # If vllm native fa is selected, we use it directly.
        use_upstream_fa = False

        if current_platform.is_xpu():
            # currently, only torch_sdpa is supported on xpu
            self.attn_backend = _Backend.TORCH_SDPA
        else:
            self.attn_backend = (
                backend
                if backend
                in {
                    _Backend.TORCH_SDPA,
                    _Backend.XFORMERS,
                    _Backend.PALLAS,
                    _Backend.ROCM_AITER_FA,
                    _Backend.FLASH_ATTN,
                }
                else _Backend.TORCH_SDPA
            )

        # Remember the pre-upgrade choice so the upstream-FA decision made
        # inside maybe_get_vit_flash_attn_backend can be reported below.
        selected_backend = self.attn_backend
        self.attn_backend, self._flash_attn_varlen_func = (
            maybe_get_vit_flash_attn_backend(
                self.attn_backend,
                use_upstream_fa,
                attn_backend_override=attn_backend_override,
            )
        )

        if self.attn_backend == _Backend.XFORMERS and not check_xformers_availability():
            self.attn_backend = _Backend.TORCH_SDPA

        self.is_flash_attn_backend = self.attn_backend in {
            _Backend.FLASH_ATTN,
            _Backend.ROCM_AITER_FA,
        }

        # Mirror the upstream-FA decision so the log is accurate: a
        # non-flash backend that was upgraded to FLASH_ATTN uses upstream
        # flash_attn, and on ROCm FLASH_ATTN always does.
        if self.attn_backend == _Backend.FLASH_ATTN and (
            selected_backend != _Backend.FLASH_ATTN or current_platform.is_rocm()
        ):
            use_upstream_fa = True

        logger.info_once(
            f"MultiHeadAttention attn_backend: {self.attn_backend}, "
            f"use_upstream_fa: {use_upstream_fa}"
        )

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """Input shape:
        (batch_size x seq_len x hidden_size) or
        (batch_size x seq_len x num_heads x head_size)

        Returns:
            The attention output reshaped to ``(batch_size, q_len, -1)``.

        Raises:
            NotImplementedError: if the selected backend has no ViT path.
        """
        bsz, q_len = query.size()[:2]
        kv_len = key.size(1)

        query = query.view(bsz, q_len, self.num_heads, self.head_size)
        key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
        value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)

        if (num_repeat := self.num_queries_per_kv) > 1:
            # Handle MQA and GQA
            key = torch.repeat_interleave(key, num_repeat, dim=2)
            value = torch.repeat_interleave(value, num_repeat, dim=2)

        if self.is_flash_attn_backend:
            # Every sequence in the batch has the same length, so the
            # cumulative offsets are an arithmetic progression.
            cu_seqlens_q = torch.arange(
                0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=query.device
            )
            cu_seqlens_k = torch.arange(
                0, (bsz + 1) * kv_len, step=kv_len, dtype=torch.int32, device=key.device
            )

            out = self._flash_attn_varlen_func(
                query.flatten(0, 1),
                key.flatten(0, 1),
                value.flatten(0, 1),
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=q_len,
                max_seqlen_k=kv_len,
                softmax_scale=self.scale,
            )
        elif self.attn_backend == _Backend.XFORMERS:
            from xformers import ops as xops

            out = xops.memory_efficient_attention_forward(
                query, key, value, scale=self.scale
            )
        elif self.attn_backend == _Backend.TORCH_SDPA:
            query, key, value = (x.transpose(1, 2) for x in (query, key, value))
            out = F.scaled_dot_product_attention(query, key, value, scale=self.scale)
            out = out.transpose(1, 2)
        elif self.attn_backend == _Backend.PALLAS:
            query, key, value = (x.transpose(1, 2) for x in (query, key, value))
            from torch_xla.experimental.custom_kernel import flash_attention

            out = flash_attention(query, key, value, sm_scale=self.scale)
            out = out.transpose(1, 2)
        else:
            # ViT attention hasn't supported this backend yet
            raise NotImplementedError(
                f"ViT attention hasn't supported {self.attn_backend} backend yet."
            )

        return out.reshape(bsz, q_len, -1)
590
591


592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
class MLAAttention(nn.Module, AttentionLayerBase):
    """Multi-Head Latent Attention layer.

    This class takes query, and compressed key/value tensors as input.
    The class does the following:

    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        scale: float,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int | None,
        kv_lora_rank: int,
        kv_b_proj: ColumnParallelLinear,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_sparse: bool = False,
        indexer: object | None = None,
        **extra_impl_args,
    ):
        """Build the MLA backend implementation and register the layer.

        The cached head size is ``kv_lora_rank + qk_rope_head_dim`` with a
        single KV head. ``prefix`` must be unique: it is registered in the
        compilation config's ``static_forward_context``.

        Raises:
            ValueError: if ``prefix`` is already registered.
        """
        super().__init__()
        self.num_heads = num_heads
        self.scale = scale
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.head_size = kv_lora_rank + qk_rope_head_dim
        self.layer_name = prefix

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            calculate_kv_scales = cache_config.calculate_kv_scales
        else:
            kv_cache_dtype = "auto"
            block_size = 16
            calculate_kv_scales = False
        self.kv_cache_dtype = kv_cache_dtype

        dtype = torch.get_default_dtype()
        self.attn_backend = get_attn_backend(
            self.head_size,
            dtype,
            kv_cache_dtype,
            block_size,
            use_mla=True,
            use_sparse=use_sparse,
        )
        impl_cls = cast(type[MLAAttentionImpl], self.attn_backend.get_impl_cls())
        self.impl = impl_cls(
            num_heads=self.num_heads,
            head_size=self.head_size,
            scale=self.scale,
            num_kv_heads=1,
            alibi_slopes=None,
            sliding_window=None,
            kv_cache_dtype=self.kv_cache_dtype,
            logits_soft_cap=None,
            attn_type=AttentionType.DECODER,
            kv_sharing_target_layer_name=None,
            # MLA Args
            q_lora_rank=self.q_lora_rank,
            kv_lora_rank=self.kv_lora_rank,
            qk_nope_head_dim=self.qk_nope_head_dim,
            qk_rope_head_dim=self.qk_rope_head_dim,
            qk_head_dim=self.qk_nope_head_dim + self.qk_rope_head_dim,
            v_head_dim=self.v_head_dim,
            kv_b_proj=kv_b_proj,
            indexer=indexer,
            **extra_impl_args,
        )

        # Same meaning as in Attention: opaque custom op vs. direct call.
        self.use_direct_call = not current_platform.opaque_attention_op()

        vllm_config = get_current_vllm_config()
        compilation_config = vllm_config.compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self

        # Placeholder caches, one per pipeline-parallel virtual engine.
        self.kv_cache = [
            torch.tensor([])
            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
        ]

        # Align with Attention's scale attributes for MLA backends.
        self.calculate_kv_scales = calculate_kv_scales
        self._k_scale = torch.tensor(1.0, dtype=torch.float32)
        self._v_scale = torch.tensor(1.0, dtype=torch.float32)
        self._q_scale = torch.tensor(1.0, dtype=torch.float32)
        self._prob_scale = torch.tensor(1.0, dtype=torch.float32)

        # Host-side mirrors used by some attention backends
        self._q_scale_float = 1.0
        self._k_scale_float = 1.0
        self._v_scale_float = 1.0
        self._o_scale_float: float | None = None

        self.use_sparse = use_sparse

        # Initialize q/k/v range constants.
        self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
        self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
        self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)

    def _direct_call_state(self) -> tuple[object, torch.Tensor]:
        """Return this layer's attention metadata and KV cache for eager calls."""
        forward_context: ForwardContext = get_forward_context()
        attn_metadata = forward_context.attn_metadata
        if isinstance(attn_metadata, dict):
            attn_metadata = attn_metadata[self.layer_name]
        return attn_metadata, self.kv_cache[forward_context.virtual_engine]

    def forward(
        self,
        q: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        output_shape: torch.Size | None = None,
    ) -> torch.Tensor:
        """Run MLA attention for this layer.

        Args:
            q: Query tensor.
            kv_c_normed: Compressed KV representation.
            k_pe: Key tensor passed alongside ``kv_c_normed`` (presumably
                the rotary-embedded part; confirm against callers).
            output_shape: Shape of the output buffer. Required when the
                backend accepts an output buffer; ignored otherwise.

        Returns:
            The attention output, in ``q``'s dtype when an output buffer
            is used.

        Raises:
            ValueError: if the backend needs an output buffer and
                ``output_shape`` is None.
        """
        if self.calculate_kv_scales:
            # Same entry point as Attention.forward. The op checks this
            # layer's ``enable_kv_scales_calculation`` metadata flag before
            # calling ``calc_kv_scales``, and it runs whichever
            # call/output-buffer path is taken below.
            torch.ops.vllm.maybe_calc_kv_scales(q, kv_c_normed, k_pe, self.layer_name)

        if self.attn_backend.accept_output_buffer:
            if output_shape is None:
                raise ValueError(
                    "MLAAttention.forward requires `output_shape` when the "
                    "attention backend accepts an output buffer."
                )
            output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
            if self.use_direct_call:
                attn_metadata, self_kv_cache = self._direct_call_state()
                self.impl.forward(
                    self,
                    q,
                    kv_c_normed,
                    k_pe,
                    self_kv_cache,
                    attn_metadata,
                    output=output,
                )
            else:
                torch.ops.vllm.unified_mla_attention_with_output(
                    q,
                    kv_c_normed,
                    k_pe,
                    output,
                    self.layer_name,
                )
            return output

        if self.use_direct_call:
            attn_metadata, self_kv_cache = self._direct_call_state()
            return self.impl.forward(
                self, q, kv_c_normed, k_pe, self_kv_cache, attn_metadata
            )
        return torch.ops.vllm.unified_mla_attention(
            q,
            kv_c_normed,
            k_pe,
            self.layer_name,
        )

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Delegate post-load weight processing to the impl, if it has one."""
        if hasattr(self.impl, "process_weights_after_loading"):
            self.impl.process_weights_after_loading(act_dtype)

    def calc_kv_scales(
        self, q: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor
    ) -> None:
        """Optional scale calculation for MLA inputs.

        Mirrors Attention.calc_kv_scales. Not all MLA backends require this.
        The k and v scales are both derived from ``kv_c_normed``; ``k_pe``
        is not used. Calibration runs once and then disables itself.
        """
        # Use safe defaults if ranges are not present
        q_range = getattr(self, "q_range", torch.tensor(1.0))
        k_range = getattr(self, "k_range", torch.tensor(1.0))
        v_range = getattr(self, "v_range", torch.tensor(1.0))

        self._q_scale.copy_(torch.abs(q).max() / q_range)
        # kv_c_normed is the compressed KV representation; use it for k/v
        kv_abs_max = torch.abs(kv_c_normed).max()
        self._k_scale.copy_(kv_abs_max / k_range)
        self._v_scale.copy_(kv_abs_max / v_range)
        self._q_scale_float = self._q_scale.item()
        self._k_scale_float = self._k_scale.item()
        self._v_scale_float = self._v_scale.item()
        self.calculate_kv_scales = False

    def get_attn_backend(self) -> type[AttentionBackend]:
        """Return the attention backend class used by this layer."""
        return self.attn_backend

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        """Build the single-KV-head ``MLAAttentionSpec`` for this layer."""
        kv_cache_dtype = kv_cache_dtype_str_to_dtype(
            self.kv_cache_dtype, vllm_config.model_config
        )
        return MLAAttentionSpec(
            block_size=vllm_config.cache_config.block_size,
            num_kv_heads=1,
            head_size=self.head_size,
            dtype=kv_cache_dtype,
            cache_dtype_str=vllm_config.cache_config.cache_dtype,
        )

813

814
815
816
817
818
819
820
821
822
823
def wait_for_kv_layer_from_connector(layer_name: str) -> None:
    """Ask the v1 KV connector to wait for ``layer_name``'s KV load.

    This is a no-op when no KV transfer group is configured, when the group
    is not a v1 connector, or when the current forward context carries no
    attention metadata.

    Args:
        layer_name: Name of the attention layer about to run.
    """
    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
        return

    connector = get_kv_transfer_group()

    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if attn_metadata is None:
        return
    # The metadata is only sanity-checked here; it is not passed to the
    # connector. A per-layer dict is expected on this path.
    assert isinstance(attn_metadata, dict)
    connector.wait_for_layer_load(layer_name)


def maybe_save_kv_layer_to_connector(
    layer_name: str,
    kv_cache_layer: list[torch.Tensor],
) -> None:
    """Hand a layer's KV cache to the v1 KV connector for saving.

    This is a no-op when no KV transfer group is configured, when the group
    is not a v1 connector, or when the current forward context carries no
    attention metadata.

    Args:
        layer_name: Name of the attention layer that just ran; also the key
            used to pick this layer's entry from the metadata dict.
        kv_cache_layer: The layer's KV cache, forwarded unchanged to
            ``connector.save_kv_layer``.
    """
    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
        return

    connector = get_kv_transfer_group()

    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if attn_metadata is None:
        return
    # A per-layer dict is expected on this path; the connector receives only
    # this layer's metadata entry.
    assert isinstance(attn_metadata, dict)
    connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata[layer_name])


def maybe_calc_kv_scales(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> None:
    """Call ``calc_kv_scales`` on ``layer_name`` when the metadata requests it.

    The layer's attention metadata is taken from the forward context; if it
    is a dict, the entry for ``layer_name`` is used. Nothing happens when
    that metadata is ``None`` or lacks a truthy
    ``enable_kv_scales_calculation`` attribute.

    Args:
        query, key, value: Passed positionally to the layer's
            ``calc_kv_scales`` (for MLA layers these land in the ``q``,
            ``kv_c_normed`` and ``k_pe`` parameters).
        layer_name: Key into ``forward_context.no_compile_layers``.
    """
    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata

    if isinstance(attn_metadata, dict):
        attn_metadata = attn_metadata[layer_name]

    if attn_metadata is None or not getattr(
        attn_metadata, "enable_kv_scales_calculation", False
    ):
        return

    self = forward_context.no_compile_layers[layer_name]
    self.calc_kv_scales(query, key, value)


def maybe_calc_kv_scales_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> None:
    """Fake (meta) implementation of ``maybe_calc_kv_scales``.

    The real op produces no output tensor, so tracing only needs a function
    with the same signature that does nothing.
    """
    return None


# NOTE(review): maybe_calc_kv_scales does not write query/key/value in its own
# body; declaring them as mutated presumably stops the compiler from treating
# this None-returning op as side-effect free and dropping it — confirm before
# changing.
direct_register_custom_op(
    op_name="maybe_calc_kv_scales",
    op_func=maybe_calc_kv_scales,
    mutates_args=["query", "key", "value"],
    fake_impl=maybe_calc_kv_scales_fake,
)


def unified_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Run the attention layer registered as ``layer_name`` and return its output.

    Steps:
    1. Wait on the KV connector for this layer's KV load (no-op without one).
    2. Resolve this layer's attention metadata from the forward context.
    3. Call the layer's backend ``impl.forward`` with the KV cache selected
       by ``forward_context.virtual_engine``.
    4. Hand the KV cache to the KV connector for saving (no-op without one).

    Args:
        query, key, value: Attention inputs, passed through to the backend.
        layer_name: Key into ``forward_context.no_compile_layers`` and, when
            the metadata is a dict, into the attention metadata.

    Returns:
        Whatever the backend's ``impl.forward`` returns.
    """
    wait_for_kv_layer_from_connector(layer_name)

    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if isinstance(attn_metadata, dict):
        attn_metadata = attn_metadata[layer_name]
    self = forward_context.no_compile_layers[layer_name]
    kv_cache = self.kv_cache[forward_context.virtual_engine]
    output = self.impl.forward(self, query, key, value, kv_cache, attn_metadata)

    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
    return output


def unified_attention_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Fake (meta) implementation of ``unified_attention``.

    Returns an uninitialized, contiguous tensor with the same shape, dtype
    and device as ``query``, which is all tracing needs from this op.
    """
    placeholder = torch.empty_like(query)
    return placeholder.contiguous()


# Expose unified_attention as a custom op; no mutated arguments are declared,
# since the op returns its result.
direct_register_custom_op(
    op_name="unified_attention",
    op_func=unified_attention,
    fake_impl=unified_attention_fake,
)


def unified_attention_with_output(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
) -> None:
    """Run the attention layer ``layer_name`` into a caller-provided buffer.

    This follows the same flow as ``unified_attention``, with two differences:
    the backend ``impl.forward`` receives ``output``, ``output_scale`` and
    ``output_block_scale`` as keyword arguments, and its return value is
    discarded (this op returns ``None``).

    Args:
        query, key, value: Attention inputs, passed through to the backend.
        output: Destination buffer handed to the backend.
        layer_name: Key into ``forward_context.no_compile_layers`` and, when
            the metadata is a dict, into the attention metadata.
        output_scale: Optional tensor forwarded to the backend unchanged.
        output_block_scale: Optional tensor forwarded to the backend
            unchanged.
    """
    wait_for_kv_layer_from_connector(layer_name)
    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if isinstance(attn_metadata, dict):
        attn_metadata = attn_metadata[layer_name]
    self = forward_context.no_compile_layers[layer_name]
    kv_cache = self.kv_cache[forward_context.virtual_engine]
    self.impl.forward(
        self,
        query,
        key,
        value,
        kv_cache,
        attn_metadata,
        output=output,
        output_scale=output_scale,
        output_block_scale=output_block_scale,
    )

    maybe_save_kv_layer_to_connector(layer_name, kv_cache)


def unified_attention_with_output_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
956
957
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
958
959
960
961
962
963
964
) -> None:
    return


# The op returns None, so the backend is expected to write its result into
# the caller-provided ``output`` (and ``output_block_scale`` when given).
# Both are therefore declared as mutated.
direct_register_custom_op(
    op_name="unified_attention_with_output",
    op_func=unified_attention_with_output,
    mutates_args=["output", "output_block_scale"],
    fake_impl=unified_attention_with_output_fake,
)


def unified_mla_attention(
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Run the MLA layer registered as ``layer_name`` and return its output.

    This is the MLA counterpart of ``unified_attention``. It waits on the KV
    connector, resolves this layer's metadata, calls the layer's backend
    ``impl.forward`` with the KV cache for the current virtual engine, then
    hands that cache back to the connector for saving.

    Args:
        q: Query tensor.
        kv_c_normed: Compressed KV representation.
        k_pe: Passed through to the backend alongside ``kv_c_normed``.
        layer_name: Key into ``no_compile_layers`` and the metadata dict.

    Returns:
        Whatever the backend's ``impl.forward`` returns.
    """
    wait_for_kv_layer_from_connector(layer_name)

    ctx: ForwardContext = get_forward_context()
    metadata = ctx.attn_metadata
    if isinstance(metadata, dict):
        metadata = metadata[layer_name]

    layer: MLAAttention = ctx.no_compile_layers[layer_name]
    cache = layer.kv_cache[ctx.virtual_engine]
    result = layer.impl.forward(layer, q, kv_c_normed, k_pe, cache, metadata)

    maybe_save_kv_layer_to_connector(layer_name, cache)
    return result


def unified_mla_attention_fake(
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Fake (meta) implementation of ``unified_mla_attention``.

    Produces an uninitialized, contiguous tensor shaped and typed like ``q``.
    """
    shaped_like_q = torch.empty_like(q)
    return shaped_like_q.contiguous()


# Expose unified_mla_attention as torch.ops.vllm.unified_mla_attention. The op
# returns its result, so no arguments are declared as mutated.
direct_register_custom_op(
    op_name="unified_mla_attention",
    op_func=unified_mla_attention,
    fake_impl=unified_mla_attention_fake,
    mutates_args=[],
    dispatch_key=current_platform.dispatch_key,
)


def unified_mla_attention_with_output(
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
) -> None:
    """Run the MLA layer ``layer_name`` into a caller-provided buffer.

    This is the MLA counterpart of ``unified_attention_with_output``. The
    backend ``impl.forward`` receives ``output``, ``output_scale`` and
    ``output_block_scale`` as keyword arguments, and its return value is
    discarded.

    Args:
        q: Query tensor.
        kv_c_normed: Compressed KV representation.
        k_pe: Passed through to the backend alongside ``kv_c_normed``.
        output: Destination buffer handed to the backend.
        layer_name: Key into ``no_compile_layers`` and the metadata dict.
        output_scale: Optional tensor forwarded to the backend unchanged.
        output_block_scale: Optional tensor forwarded to the backend
            unchanged.
    """
    wait_for_kv_layer_from_connector(layer_name)
    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if isinstance(attn_metadata, dict):
        attn_metadata = attn_metadata[layer_name]
    self: MLAAttention = forward_context.no_compile_layers[layer_name]
    kv_cache = self.kv_cache[forward_context.virtual_engine]
    self.impl.forward(
        self,
        q,
        kv_c_normed,
        k_pe,
        kv_cache,
        attn_metadata,
        output=output,
        output_scale=output_scale,
        output_block_scale=output_block_scale,
    )

    maybe_save_kv_layer_to_connector(layer_name, kv_cache)


def unified_mla_attention_with_output_fake(
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
1045
1046
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
) -> None:
    return


# Register unified_mla_attention_with_output as a custom op. It returns None,
# so the caller-provided ``output`` / ``output_block_scale`` buffers are
# declared as mutated.
direct_register_custom_op(
    op_name="unified_mla_attention_with_output",
    op_func=unified_mla_attention_with_output,
    fake_impl=unified_mla_attention_with_output_fake,
    mutates_args=["output", "output_block_scale"],
    dispatch_key=current_platform.dispatch_key,
)