layer.py 34.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Attention layer."""
4

5
6
from collections.abc import Callable
from typing import cast
7
8
9

import torch
import torch.nn as nn
10
import torch.nn.functional as F
11

12
import vllm.envs as envs
13
14
15
16
17
from vllm.attention.backends.abstract import (
    AttentionBackend,
    AttentionType,
    MLAAttentionImpl,
)
18
from vllm.attention.backends.registry import AttentionBackendEnum
19
from vllm.attention.selector import get_attn_backend
20
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
21
from vllm.attention.utils.kv_transfer_utils import maybe_transfer_kv_layer
22
from vllm.config import CacheConfig, get_current_vllm_config
23
from vllm.config.multimodal import MultiModalConfig
24
from vllm.config.vllm import VllmConfig
25
from vllm.forward_context import ForwardContext, get_forward_context
26
from vllm.logger import init_logger
27
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
28
29
30
31
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    UnquantizedLinearMethod,
)
32
from vllm.model_executor.layers.quantization import QuantizationConfig
33
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
34
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
35
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
36
from vllm.model_executor.models.vision import get_vit_attn_backend
37
from vllm.platforms import current_platform
38
from vllm.utils.torch_utils import (
39
40
41
42
43
44
45
46
47
    direct_register_custom_op,
    kv_cache_dtype_str_to_dtype,
)
from vllm.v1.kv_cache_interface import (
    FullAttentionSpec,
    KVCacheSpec,
    MLAAttentionSpec,
    SlidingWindowSpec,
)
48

49
50
51
52
53
54
# `on_gfx9` is imported from the ROCm platform module only when running on
# ROCm. On every other platform a stub that always answers False is bound to
# the same name, so callers can invoke it without a platform check.
if current_platform.is_rocm():
    from vllm.platforms.rocm import on_gfx9
else:
    on_gfx9 = lambda *args, **kwargs: False


# FP8 dtype as reported by the current platform.
FP8_DTYPE = current_platform.fp8_dtype()

logger = init_logger(__name__)

58

59
def maybe_get_vit_flash_attn_backend(
    attn_backend: AttentionBackendEnum,
    attn_backend_override: AttentionBackendEnum | None = None,
) -> tuple[AttentionBackendEnum, Callable | None]:
    """Resolve the vision-attention backend and its varlen flash-attn kernel.

    Args:
        attn_backend: Backend requested for the vision encoder.
        attn_backend_override: Optional explicit override. It is only
            consulted on ROCm, where FLASH_ATTN is kept only when no
            override was given.

    Returns:
        A ``(backend, flash_attn_varlen_func)`` tuple. The function is
        ``None`` unless the resolved backend is FLASH_ATTN or ROCM_AITER_FA.

    Raises:
        AssertionError: On XPU when ``attn_backend`` is not FLASH_ATTN.
    """
    if current_platform.is_rocm():
        if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
            # AITER MHA is enabled and the device is gfx9: force AITER FA.
            attn_backend = AttentionBackendEnum.ROCM_AITER_FA
        elif (
            attn_backend_override is None
            and on_gfx9()
            and attn_backend == AttentionBackendEnum.FLASH_ATTN
        ):
            # Keep FLASH_ATTN as requested.
            pass
        else:
            return AttentionBackendEnum.TORCH_SDPA, None
    elif current_platform.is_cuda():
        # CUDA keeps whatever backend was requested.
        pass
    elif current_platform.is_xpu():
        assert attn_backend == AttentionBackendEnum.FLASH_ATTN, (
            "XPU platform only supports FLASH_ATTN as vision attention backend."
        )
    else:
        # Any other platform falls back to torch SDPA.
        return AttentionBackendEnum.TORCH_SDPA, None

    if attn_backend in {
        AttentionBackendEnum.FLASH_ATTN,
        AttentionBackendEnum.ROCM_AITER_FA,
    }:
        # Import lazily so the kernel package is only required when used.
        if attn_backend == AttentionBackendEnum.ROCM_AITER_FA:
            from aiter import flash_attn_varlen_func
        else:
            from vllm.attention.utils.fa_utils import flash_attn_varlen_func
    else:
        flash_attn_varlen_func = None

    return attn_backend, flash_attn_varlen_func


98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
def _init_kv_cache_quant(
    layer: nn.Module,
    quant_config: QuantizationConfig | None,
    prefix: str,
    kv_cache_dtype: str,
    calculate_kv_scales: bool,
) -> None:
    """Initializes KV cache scaling factors and quantization method.

    This helper function sets up the KV cache quantization attributes that are
    shared between Attention and MLAAttention layers. It initializes scale
    tensors for query, key, value, and probability, and configures the
    quantization method if applicable.

    Args:
        layer: The attention layer instance to initialize.
        quant_config: Optional quantization configuration.
        prefix: Layer name prefix for quantization method lookup.
        kv_cache_dtype: The KV cache data type string.
        calculate_kv_scales: Whether to calculate KV scales dynamically.
    """
    # The default k/v_scale is set to 1.0. This is ignored
    # when kv-cache is not fp8, and should be used with
    # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
    # expect the pre-quantized k/v_scale to be loaded along
    # with the model weights.
    layer.kv_cache_dtype = kv_cache_dtype
    layer.calculate_kv_scales = calculate_kv_scales
    layer._k_scale = torch.tensor(1.0, dtype=torch.float32)
    layer._v_scale = torch.tensor(1.0, dtype=torch.float32)
    layer._q_scale = torch.tensor(1.0, dtype=torch.float32)
    layer._prob_scale = torch.tensor(1.0, dtype=torch.float32)

    # We also keep q/k/v_scale on host (cpu) memory for attention
    # backends that require the scales to be on host instead of on device.
    # e.g. Flashinfer
    layer._q_scale_float = 1.0
    layer._k_scale_float = 1.0
    layer._v_scale_float = 1.0

    # The output scale on host memory. This should be the input scale of
    # the quant op after this attention layer.
    layer._o_scale_float = None

    quant_method = (
        quant_config.get_quant_method(layer, prefix=prefix) if quant_config else None
    )
    if quant_method is not None and not isinstance(
        quant_method, UnquantizedLinearMethod
    ):
        assert isinstance(quant_method, BaseKVCacheMethod)
        # TODO (mgoin): kv cache dtype should be specified in the FP8
        # checkpoint config and become the "auto" behavior
        if kv_cache_dtype == "fp8_e5m2":
            raise ValueError("fp8_e5m2 kv-cache is not supported with fp8 checkpoints.")
        # If quantization is enabled, we make "k_scale" and "v_scale"
        # parameters so that it can be loaded from the model checkpoint.
        # The k/v_scale will then be converted back to native float32
        # values after weight loading.
        layer.quant_method = quant_method
        layer.quant_method.create_weights(layer)


161
class Attention(nn.Module, AttentionLayerBase):
    """Attention layer.

    This class takes query, key, and value tensors as input. The input tensors
    can either contain prompt tokens or generation tokens.
    The class does the following:

    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
        alibi_slopes: list[float] | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        logits_soft_cap: float | None = None,
        per_layer_sliding_window: int | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
        attn_backend: type[AttentionBackend] | None = None,
        **extra_impl_args,
    ) -> None:
        """
        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each head.
            scale: Attention scale, forwarded to the backend implementation.
            num_kv_heads: Number of KV heads; defaults to `num_heads` and
                must divide it.
            alibi_slopes: Forwarded to the backend implementation.
            cache_config: Supplies the cache dtype, block size,
                `calculate_kv_scales` and the model-level sliding window.
                When None, "auto", 16, False and no sliding window are used.
            quant_config: Used to look up the KV-cache quantization method.
            logits_soft_cap: Forwarded to the backend implementation.
            per_layer_sliding_window: Takes precedence over the sliding
                window from `cache_config`.
            prefix: Layer name; must be unique in the static forward context.
            attn_type: Forwarded to backend selection and the implementation.
            kv_sharing_target_layer_name: When set, validated against the
                layers registered so far and forwarded to the implementation.
            attn_backend: Explicit backend class; when None one is chosen
                via `get_attn_backend`.
            **extra_impl_args: Forwarded to the backend implementation. A
                non-None "sinks" entry marks the layer as having sinks.

        Raises:
            ValueError: If `prefix` is already registered.
        """
        super().__init__()

        if per_layer_sliding_window is not None:
            # per-layer sliding window
            sliding_window = per_layer_sliding_window
        elif cache_config is not None:
            # model-level sliding window
            sliding_window = cache_config.sliding_window
        else:
            sliding_window = None

        vllm_config = get_current_vllm_config()
        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            calculate_kv_scales = cache_config.calculate_kv_scales
        else:
            kv_cache_dtype = "auto"
            block_size = 16
            calculate_kv_scales = False

        self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
            kv_cache_dtype, vllm_config.model_config
        )

        if num_kv_heads is None:
            num_kv_heads = num_heads
        assert num_heads % num_kv_heads == 0, (
            f"num_heads ({num_heads}) is not divisible by num_kv_heads ({num_kv_heads})"
        )

        # Initialize KV cache quantization attributes
        _init_kv_cache_quant(
            self, quant_config, prefix, kv_cache_dtype, calculate_kv_scales
        )

        self.num_heads = num_heads
        self.head_size = head_size
        self.num_kv_heads = num_kv_heads
        self.sliding_window = sliding_window
        self.has_sink = extra_impl_args.get("sinks") is not None

        # NOTE: model_config may be None during certain tests
        model_config = vllm_config.model_config
        self.use_mm_prefix = model_config is not None and model_config.is_mm_prefix_lm

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()
        if attn_backend is None:
            self.attn_backend = get_attn_backend(
                head_size,
                dtype,
                kv_cache_dtype,
                block_size,
                use_mla=False,
                has_sink=self.has_sink,
                use_mm_prefix=self.use_mm_prefix,
                attn_type=attn_type,
            )
        else:
            self.attn_backend = attn_backend

        impl_cls = self.attn_backend.get_impl_cls()
        self.impl = impl_cls(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            sliding_window,
            kv_cache_dtype,
            logits_soft_cap,
            attn_type,
            kv_sharing_target_layer_name,
            **extra_impl_args,
        )
        backend_name = self.attn_backend.get_name()
        # None when the backend name is not a member of AttentionBackendEnum.
        self.backend = AttentionBackendEnum.__members__.get(backend_name)
        self.dtype = dtype

        # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
        # torch.compile works by registering the attention as one giant
        # opaque custom op. For other platforms, we directly call them
        # and let torch.compile handle them.
        self.use_direct_call = not current_platform.opaque_attention_op()

        self.use_output = self.attn_backend.accept_output_buffer
        compilation_config = vllm_config.compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self
        self.layer_name = prefix
        self.attn_type = attn_type

        if kv_sharing_target_layer_name is not None:
            validate_kv_sharing_target(
                prefix,
                kv_sharing_target_layer_name,
                compilation_config.static_forward_context,
            )
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        # Use a placeholder kv cache tensor during init, which will be
        # replaced by bind_kv_cache. One entry per pipeline-parallel rank;
        # the direct-call path in `forward` indexes this list by
        # `forward_context.virtual_engine`.
        self.kv_cache = [
            torch.tensor([])
            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
        ]

        # Initialize q/k/v range constants.
        self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
        self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
        self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)

        # for attn backends supporting query quantization
        self.query_quant = None
        if (
            self.kv_cache_dtype.startswith("fp8")
            and self.impl.supports_quant_query_input
        ):
            self.query_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        # For some alternate attention backends like MLA the attention output
        # shape does not match the query shape, so we optionally let the model
        # definition specify the output tensor shape.
        output_shape: torch.Size | None = None,
    ) -> torch.Tensor:
        """
        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        Attention metadata (`attn_metadata`) is set using a context manager in
        the model runner's `execute_model` method. It is accessed via forward
        context using
        `vllm.forward_context.get_forward_context().attn_metadata`.
        """
        if self.calculate_kv_scales:
            torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name)
        # Captured before the optional query quantization below, so the
        # output buffer keeps the dtype of the incoming query.
        output_dtype = query.dtype
        if self.query_quant is not None:
            # quantizing with a simple torch operation enables
            # torch.compile to fuse this into previous ops
            # which reduces overheads during decoding.
            # Otherwise queries are quantized using custom ops
            # which causes decoding overheads
            assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}

            # check if query quantization is supported
            if self.impl.supports_quant_query_input:
                query, _ = self.query_quant(query, self._q_scale)

        if self.use_output:
            output_shape = output_shape if output_shape is not None else query.shape
            output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
            hidden_size = output_shape[-1]
            # Reshape the query, key, and value tensors.
            # NOTE(woosuk): We do this outside the custom op to minimize the
            # CPU overheads from the non-CUDA-graph regions.
            query = query.view(-1, self.num_heads, self.head_size)
            output = output.view(-1, self.num_heads, self.head_size)
            if key is not None:
                key = key.view(-1, self.num_kv_heads, self.head_size)
            if value is not None:
                value = value.view(-1, self.num_kv_heads, self.head_size)
            if self.use_direct_call:
                forward_context: ForwardContext = get_forward_context()
                attn_metadata = forward_context.attn_metadata
                # Metadata may be keyed per layer name.
                if isinstance(attn_metadata, dict):
                    attn_metadata = attn_metadata[self.layer_name]
                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                self.impl.forward(
                    self, query, key, value, self_kv_cache, attn_metadata, output=output
                )
            else:
                torch.ops.vllm.unified_attention_with_output(
                    query, key, value, output, self.layer_name
                )
            return output.view(-1, hidden_size)
        else:
            if self.use_direct_call:
                forward_context = get_forward_context()
                attn_metadata = forward_context.attn_metadata
                if isinstance(attn_metadata, dict):
                    attn_metadata = attn_metadata[self.layer_name]
                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                return self.impl.forward(
                    self, query, key, value, self_kv_cache, attn_metadata
                )
            else:
                return torch.ops.vllm.unified_attention(
                    query, key, value, self.layer_name
                )

    def calc_kv_scales(self, query, key, value):
        """Derive q/k/v scales from the abs-max of the given tensors.

        Each scale is the tensor's absolute maximum divided by the matching
        range constant. The device tensors are updated in place and mirrored
        into the host-side floats.
        """
        self._q_scale.copy_(torch.abs(query).max() / self.q_range)
        self._k_scale.copy_(torch.abs(key).max() / self.k_range)
        self._v_scale.copy_(torch.abs(value).max() / self.v_range)
        self._q_scale_float = self._q_scale.item()
        self._k_scale_float = self._k_scale.item()
        self._v_scale_float = self._v_scale.item()
        # We only calculate the scales once
        self.calculate_kv_scales = False

    def extra_repr(self) -> str:
        """Return a summary of the implementation's key attributes."""
        s = f"head_size={self.impl.head_size}"  # type: ignore
        s += f", num_heads={self.impl.num_heads}"  # type: ignore
        s += f", num_kv_heads={self.impl.num_kv_heads}"  # type: ignore
        s += f", scale={self.impl.scale}"  # type: ignore
        s += f", backend={self.impl.__class__.__name__}"
        return s

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Delegate post-load weight processing to the implementation."""
        self.impl.process_weights_after_loading(act_dtype)

    def get_attn_backend(self) -> type[AttentionBackend]:
        """Return the attention backend class used by this layer."""
        return self.attn_backend

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        """Build the KV cache spec for this layer.

        Returns a `SlidingWindowSpec` when the layer has a sliding window,
        otherwise a `FullAttentionSpec`.
        """
        # Block size may get updated after model loading, refresh it
        block_size = vllm_config.cache_config.block_size
        # Should not be called for enc-dec or encoder-only attention.
        assert self.attn_type == AttentionType.DECODER
        if self.sliding_window is not None:
            assert not vllm_config.model_config.use_mla, (
                "MLA is not supported for slidingwindow"
            )
            return SlidingWindowSpec(
                block_size=block_size,
                num_kv_heads=self.num_kv_heads,
                head_size=self.head_size,
                dtype=self.kv_cache_torch_dtype,
                sliding_window=self.sliding_window,
            )
        else:
            return FullAttentionSpec(
                block_size=block_size,
                num_kv_heads=self.num_kv_heads,
                head_size=self.head_size,
                dtype=self.kv_cache_torch_dtype,
            )

439

440
441
442
443
444
445
446
447
class MultiHeadAttention(nn.Module):
    """Multi-headed attention without any cache, used for ViT."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
        # Only stored as `layer_name`; it is accepted mainly to make it
        # easier to swap between Attention and MultiHeadAttention.
        prefix: str = "",
        multimodal_config: MultiModalConfig | None = None,
    ) -> None:
        """Configure head layout and pick the attention backend.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each head.
            scale: Softmax scale passed to the attention kernel.
            num_kv_heads: Number of KV heads; defaults to `num_heads` and
                must divide it.
            prefix: Stored as `layer_name`.
            multimodal_config: When given, its `mm_encoder_attn_backend` is
                used as the backend override.
        """
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.layer_name = prefix

        assert self.num_heads % self.num_kv_heads == 0, (
            f"num_heads ({self.num_heads}) is not "
            f"divisible by num_kv_heads ({self.num_kv_heads})"
        )
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()

        # Determine the attention backend
        attn_backend_override = None
        if multimodal_config is not None:
            attn_backend_override = multimodal_config.mm_encoder_attn_backend
        backend = get_vit_attn_backend(
            head_size=head_size,
            dtype=dtype,
            attn_backend_override=attn_backend_override,
        )

        # Backends not handled by `forward` are replaced by torch SDPA.
        self.attn_backend = (
            backend
            if backend
            in {
                AttentionBackendEnum.TORCH_SDPA,
                AttentionBackendEnum.PALLAS,
                AttentionBackendEnum.ROCM_AITER_FA,
                AttentionBackendEnum.FLASH_ATTN,
            }
            else AttentionBackendEnum.TORCH_SDPA
        )

        # Platform-specific resolution; may change the backend again and
        # returns the varlen flash-attention function when one applies.
        self.attn_backend, self._flash_attn_varlen_func = (
            maybe_get_vit_flash_attn_backend(
                self.attn_backend,
                attn_backend_override=attn_backend_override,
            )
        )

        self.is_flash_attn_backend = self.attn_backend in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
        }

        logger.info_once(
            f"Using {self.attn_backend} for MultiHeadAttention in multimodal encoder."
        )

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """Input shape:
        (batch_size x seq_len x hidden_size) or
        (batch_size x seq_len x num_heads x head_size)

        Returns a tensor reshaped to (batch_size, q_len, -1).
        """
        bsz, q_len = query.size()[:2]
        kv_len = key.size(1)

        query = query.view(bsz, q_len, self.num_heads, self.head_size)
        key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
        value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)

        if (num_repeat := self.num_queries_per_kv) > 1:
            # Handle MQA and GQA
            key = torch.repeat_interleave(key, num_repeat, dim=2)
            value = torch.repeat_interleave(value, num_repeat, dim=2)

        if self.is_flash_attn_backend:
            assert self._flash_attn_varlen_func is not None
            # All sequences in the batch have equal length, so the
            # cumulative offsets are multiples of q_len / kv_len with
            # bsz + 1 entries each.
            cu_seqlens_q = torch.arange(
                0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=query.device
            )
            cu_seqlens_k = torch.arange(
                0, (bsz + 1) * kv_len, step=kv_len, dtype=torch.int32, device=key.device
            )

            # The varlen kernel takes batch and sequence dims flattened.
            out = self._flash_attn_varlen_func(
                query.flatten(0, 1),
                key.flatten(0, 1),
                value.flatten(0, 1),
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=q_len,
                max_seqlen_k=kv_len,
                softmax_scale=self.scale,
            )
        elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
            # Swap the sequence and head dims for SDPA, then swap back.
            query, key, value = (x.transpose(1, 2) for x in (query, key, value))
            out = F.scaled_dot_product_attention(query, key, value, scale=self.scale)
            out = out.transpose(1, 2)
        elif self.attn_backend == AttentionBackendEnum.PALLAS:
            query, key, value = (x.transpose(1, 2) for x in (query, key, value))
            from torch_xla.experimental.custom_kernel import flash_attention

            out = flash_attention(query, key, value, sm_scale=self.scale)
            out = out.transpose(1, 2)
        else:
            # ViT attention hasn't supported this backend yet
            raise NotImplementedError(
                f"ViT attention hasn't supported {self.attn_backend} backend yet."
            )

        return out.reshape(bsz, q_len, -1)
567
568


569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
class MLAAttention(nn.Module, AttentionLayerBase):
    """Multi-Head Latent Attention layer.

    This class takes query, and compressed key/value tensors as input.
    The class does the following:

    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        scale: float,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int | None,
        kv_lora_rank: int,
        kv_b_proj: ColumnParallelLinear,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_sparse: bool = False,
        indexer: object | None = None,
        **extra_impl_args,
    ) -> None:
        """Build the MLA backend implementation and register the layer.

        Args:
            num_heads: Number of attention heads.
            scale: Attention scale, forwarded to the implementation.
            qk_nope_head_dim: Forwarded to the implementation; also summed
                with `qk_rope_head_dim` to form `qk_head_dim`.
            qk_rope_head_dim: Forwarded to the implementation; also added to
                `kv_lora_rank` to form this layer's `head_size`.
            v_head_dim: Forwarded to the implementation.
            q_lora_rank: Forwarded to the implementation.
            kv_lora_rank: Forwarded to the implementation.
            kv_b_proj: Forwarded to the implementation.
            cache_config: Supplies the cache dtype, block size and
                `calculate_kv_scales`. When None, "auto", 16 and False are
                used.
            quant_config: Used to look up the KV-cache quantization method.
            prefix: Layer name; must be unique in the static forward context.
            use_sparse: Forwarded to backend selection.
            indexer: Forwarded to the implementation.
            **extra_impl_args: Forwarded to the implementation.

        Raises:
            ValueError: If `prefix` is already registered.
        """
        super().__init__()
        self.num_heads = num_heads
        self.scale = scale
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.head_size = kv_lora_rank + qk_rope_head_dim
        self.layer_name = prefix

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            calculate_kv_scales = cache_config.calculate_kv_scales
        else:
            kv_cache_dtype = "auto"
            block_size = 16
            calculate_kv_scales = False

        # Initialize KV cache quantization attributes
        _init_kv_cache_quant(
            self, quant_config, prefix, kv_cache_dtype, calculate_kv_scales
        )

        dtype = torch.get_default_dtype()
        self.attn_backend = get_attn_backend(
            self.head_size,
            dtype,
            kv_cache_dtype,
            block_size,
            use_mla=True,
            use_sparse=use_sparse,
        )
        impl_cls = cast(type[MLAAttentionImpl], self.attn_backend.get_impl_cls())
        self.impl = impl_cls(
            num_heads=self.num_heads,
            head_size=self.head_size,
            scale=self.scale,
            num_kv_heads=1,
            alibi_slopes=None,
            sliding_window=None,
            kv_cache_dtype=self.kv_cache_dtype,
            logits_soft_cap=None,
            attn_type=AttentionType.DECODER,
            kv_sharing_target_layer_name=None,
            # MLA Args
            q_lora_rank=self.q_lora_rank,
            kv_lora_rank=self.kv_lora_rank,
            qk_nope_head_dim=self.qk_nope_head_dim,
            qk_rope_head_dim=self.qk_rope_head_dim,
            qk_head_dim=self.qk_nope_head_dim + self.qk_rope_head_dim,
            v_head_dim=self.v_head_dim,
            kv_b_proj=kv_b_proj,
            indexer=indexer,
            **extra_impl_args,
        )

        self.use_direct_call = not current_platform.opaque_attention_op()

        # Look the config up once and reuse it for both registration and the
        # placeholder cache.
        vllm_config = get_current_vllm_config()
        compilation_config = vllm_config.compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self

        # Placeholder KV cache, one entry per pipeline-parallel rank; indexed
        # by `forward_context.virtual_engine` in `forward`.
        self.kv_cache = [
            torch.tensor([])
            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
        ]

        self.use_sparse = use_sparse

        # Initialize q/k/v range constants.
        self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
        self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
        self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)

    def forward(
        self,
        q: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        output_shape: torch.Size | None = None,
    ) -> torch.Tensor:
        """Run MLA attention directly or through the registered custom op.

        Args:
            q: Query tensor.
            kv_c_normed: Compressed KV tensor.
            k_pe: Tensor forwarded to the implementation alongside
                `kv_c_normed`.
            output_shape: Shape of the output buffer that is allocated when
                the backend accepts an output buffer.

        Returns:
            The attention output tensor.
        """
        if self.calculate_kv_scales:
            torch.ops.vllm.maybe_calc_kv_scales(q, kv_c_normed, k_pe, self.layer_name)

        if self.use_direct_call:
            forward_context: ForwardContext = get_forward_context()
            attn_metadata = forward_context.attn_metadata
            # Metadata may be keyed per layer name.
            if isinstance(attn_metadata, dict):
                attn_metadata = attn_metadata[self.layer_name]
            self_kv_cache = self.kv_cache[forward_context.virtual_engine]

            if self.attn_backend.accept_output_buffer:
                # NOTE(review): output_shape is passed to torch.empty as-is;
                # callers appear to be expected to supply it for backends
                # that accept an output buffer - confirm.
                output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
                self.impl.forward(
                    self,
                    q,
                    kv_c_normed,
                    k_pe,
                    self_kv_cache,
                    attn_metadata,
                    output=output,
                )
                return output
            else:
                return self.impl.forward(
                    self, q, kv_c_normed, k_pe, self_kv_cache, attn_metadata
                )
        else:
            if self.attn_backend.accept_output_buffer:
                # NOTE(review): same output_shape expectation as above.
                output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
                torch.ops.vllm.unified_mla_attention_with_output(
                    q,
                    kv_c_normed,
                    k_pe,
                    output,
                    self.layer_name,
                )
                return output
            else:
                return torch.ops.vllm.unified_mla_attention(
                    q,
                    kv_c_normed,
                    k_pe,
                    self.layer_name,
                )

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Delegate post-load processing when the implementation offers it."""
        if hasattr(self.impl, "process_weights_after_loading"):
            self.impl.process_weights_after_loading(act_dtype)

    def calc_kv_scales(
        self, q: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor
    ) -> None:
        """Optional scale calculation for MLA inputs.

        Mirrors Attention.calc_kv_scales. Not all MLA backends require this.
        `k_pe` is accepted but not used in the calculation.
        """
        # Use safe defaults if ranges are not present
        q_range = getattr(self, "q_range", torch.tensor(1.0))
        k_range = getattr(self, "k_range", torch.tensor(1.0))
        v_range = getattr(self, "v_range", torch.tensor(1.0))

        self._q_scale.copy_(torch.abs(q).max() / q_range)
        # kv_c_normed is the compressed KV representation; use it for k/v
        kv_abs_max = torch.abs(kv_c_normed).max()
        self._k_scale.copy_(kv_abs_max / k_range)
        self._v_scale.copy_(kv_abs_max / v_range)
        self._q_scale_float = self._q_scale.item()
        self._k_scale_float = self._k_scale.item()
        self._v_scale_float = self._v_scale.item()
        # Scales are only calculated once.
        self.calculate_kv_scales = False

    def get_attn_backend(self) -> type[AttentionBackend]:
        """Return the attention backend class used by this layer."""
        return self.attn_backend

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        """Build the `MLAAttentionSpec` describing this layer's KV cache."""
        kv_cache_dtype = kv_cache_dtype_str_to_dtype(
            self.kv_cache_dtype, vllm_config.model_config
        )
        return MLAAttentionSpec(
            block_size=vllm_config.cache_config.block_size,
            num_kv_heads=1,
            head_size=self.head_size,
            dtype=kv_cache_dtype,
            cache_dtype_str=vllm_config.cache_config.cache_dtype,
        )

def maybe_calc_kv_scales(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> None:
    """Run the named layer's ``calc_kv_scales`` if it is still pending.

    The layer is looked up by name in the current forward context's
    ``no_compile_layers`` mapping.

    Args:
        query: Forwarded as the first argument of ``calc_kv_scales``.
        key: Forwarded as the second argument of ``calc_kv_scales``.
        value: Forwarded as the third argument of ``calc_kv_scales``.
        layer_name: Key of the layer in ``no_compile_layers``.
    """
    forward_context: ForwardContext = get_forward_context()
    self = forward_context.no_compile_layers[layer_name]

    # Only calculate if the layer's calculate_kv_scales flag is True
    # This flag gets set to False after the first forward pass
    if not self.calculate_kv_scales:
        return

    self.calc_kv_scales(query, key, value)


def maybe_calc_kv_scales_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> None:
    """Fake counterpart of ``maybe_calc_kv_scales``.

    Takes the same arguments, does no work, and returns ``None``.
    """
    return None


# Register maybe_calc_kv_scales as a custom op, paired with its no-op fake
# implementation.
# NOTE(review): query/key/value are declared as mutated although the
# calc_kv_scales visible in this file only reads them — presumably this keeps
# the op from being eliminated or reordered; confirm before changing.
direct_register_custom_op(
    op_name="maybe_calc_kv_scales",
    op_func=maybe_calc_kv_scales,
    mutates_args=["query", "key", "value"],
    fake_impl=maybe_calc_kv_scales_fake,
)


def get_attention_context(
    layer_name: str,
) -> tuple[dict | object | None, Attention | MLAAttention, torch.Tensor]:
    """Extract attention context for a given layer.

    This helper function extracts the attention metadata, attention layer
    instance, and KV cache tensor for a specific layer.

    Args:
        layer_name: The name/identifier of the attention layer.

    Returns:
        A tuple containing:
        - attn_metadata: Attention metadata for this specific layer, or None if
            no metadata available
        - attn_layer: The attention layer instance (Attention or MLAAttention)
        - kv_cache: The KV cache tensor for current virtual engine

        Note: attn_metadata may be None, but attn_layer and kv_cache are always
        extracted from the forward context.
    """
    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    # When the metadata is a dict it is keyed by layer name; pick this layer's
    # entry. Any other value (including None) is passed through unchanged.
    if isinstance(attn_metadata, dict):
        attn_metadata = attn_metadata[layer_name]
    attn_layer: Attention | MLAAttention = forward_context.no_compile_layers[layer_name]
    # The layer's kv_cache is indexed by the context's virtual engine.
    kv_cache = attn_layer.kv_cache[forward_context.virtual_engine]
    return attn_metadata, attn_layer, kv_cache


@maybe_transfer_kv_layer
def unified_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Run the named layer's attention implementation and return its output.

    The layer, its KV cache and its attention metadata are resolved from the
    current forward context via ``get_attention_context``.

    Args:
        query: Query tensor forwarded to ``impl.forward``.
        key: Key tensor forwarded to ``impl.forward``.
        value: Value tensor forwarded to ``impl.forward``.
        layer_name: Name used to look the layer up in the forward context.

    Returns:
        Whatever ``self.impl.forward`` returns for these inputs.
    """
    attn_metadata, self, kv_cache = get_attention_context(layer_name)
    output = self.impl.forward(self, query, key, value, kv_cache, attn_metadata)

    return output


def unified_attention_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Fake counterpart of ``unified_attention``.

    Returns an uninitialized, contiguous tensor with the same shape and dtype
    as ``query``; no attention is computed.
    """
    placeholder = torch.empty_like(query)
    return placeholder.contiguous()


# Register unified_attention as a custom op, paired with the fake
# implementation that only produces an output tensor shaped like the query.
direct_register_custom_op(
    op_name="unified_attention",
    op_func=unified_attention,
    fake_impl=unified_attention_fake,
)


@maybe_transfer_kv_layer
def unified_attention_with_output(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
) -> None:
    """Run the named layer's attention implementation with a caller buffer.

    The layer, its KV cache and its attention metadata are resolved from the
    current forward context via ``get_attention_context``. Nothing is
    returned; ``output`` is handed to ``impl.forward`` as the destination.

    Args:
        query: Query tensor forwarded to ``impl.forward``.
        key: Key tensor forwarded to ``impl.forward``.
        value: Value tensor forwarded to ``impl.forward``.
        output: Caller-provided tensor passed as ``output=``.
        layer_name: Name used to look the layer up in the forward context.
        output_scale: Optional tensor passed through as ``output_scale=``.
        output_block_scale: Optional tensor passed through as
            ``output_block_scale=``.
    """
    attn_metadata, self, kv_cache = get_attention_context(layer_name)
    self.impl.forward(
        self,
        query,
        key,
        value,
        kv_cache,
        attn_metadata,
        output=output,
        output_scale=output_scale,
        output_block_scale=output_block_scale,
    )


def unified_attention_with_output_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
893
894
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
895
896
897
898
899
900
901
) -> None:
    return


# Register unified_attention_with_output as a custom op, paired with its
# no-op fake implementation. `output` and `output_block_scale` are the
# arguments declared as mutated.
direct_register_custom_op(
    op_name="unified_attention_with_output",
    op_func=unified_attention_with_output,
    mutates_args=["output", "output_block_scale"],
    fake_impl=unified_attention_with_output_fake,
)


@maybe_transfer_kv_layer
def unified_mla_attention(
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Run the named MLA layer's attention implementation and return its output.

    The layer, its KV cache and its attention metadata are resolved from the
    current forward context via ``get_attention_context``.

    Args:
        q: Query tensor forwarded to ``impl.forward``.
        kv_c_normed: Second tensor argument forwarded to ``impl.forward``.
        k_pe: Third tensor argument forwarded to ``impl.forward``.
        layer_name: Name used to look the layer up in the forward context.

    Returns:
        Whatever ``self.impl.forward`` returns for these inputs.
    """
    attn_metadata, self, kv_cache = get_attention_context(layer_name)
    output = self.impl.forward(self, q, kv_c_normed, k_pe, kv_cache, attn_metadata)

    return output


def unified_mla_attention_fake(
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Fake counterpart of ``unified_mla_attention``.

    Returns an uninitialized, contiguous tensor with the same shape and dtype
    as ``q``; no attention is computed.
    """
    placeholder = torch.empty_like(q)
    return placeholder.contiguous()


# Register unified_mla_attention as a custom op, paired with its fake
# implementation. No arguments are declared as mutated, and the dispatch key
# is taken from the current platform.
direct_register_custom_op(
    op_name="unified_mla_attention",
    op_func=unified_mla_attention,
    mutates_args=[],
    fake_impl=unified_mla_attention_fake,
    dispatch_key=current_platform.dispatch_key,
)


@maybe_transfer_kv_layer
def unified_mla_attention_with_output(
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
) -> None:
    """Run the named MLA layer's attention implementation with a caller buffer.

    The layer, its KV cache and its attention metadata are resolved from the
    current forward context via ``get_attention_context``. Nothing is
    returned; ``output`` is handed to ``impl.forward`` as the destination.

    Args:
        q: Query tensor forwarded to ``impl.forward``.
        kv_c_normed: Second tensor argument forwarded to ``impl.forward``.
        k_pe: Third tensor argument forwarded to ``impl.forward``.
        output: Caller-provided tensor passed as ``output=``.
        layer_name: Name used to look the layer up in the forward context.
        output_scale: Optional tensor passed through as ``output_scale=``.
        output_block_scale: Optional tensor passed through as
            ``output_block_scale=``.
    """
    attn_metadata, self, kv_cache = get_attention_context(layer_name)
    self.impl.forward(
        self,
        q,
        kv_c_normed,
        k_pe,
        kv_cache,
        attn_metadata,
        output=output,
        output_scale=output_scale,
        output_block_scale=output_block_scale,
    )


def unified_mla_attention_with_output_fake(
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
968
969
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
970
971
972
973
974
975
976
977
978
979
980
) -> None:
    return


# Register unified_mla_attention_with_output as a custom op, paired with its
# no-op fake implementation. `output` and `output_block_scale` are the
# arguments declared as mutated; the dispatch key is taken from the current
# platform.
direct_register_custom_op(
    op_name="unified_mla_attention_with_output",
    op_func=unified_mla_attention_with_output,
    mutates_args=["output", "output_block_scale"],
    fake_impl=unified_mla_attention_with_output_fake,
    dispatch_key=current_platform.dispatch_key,
)