layer.py 33 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Attention layer."""
4

5
from typing import cast
6
7
8
9

import torch
import torch.nn as nn

10
import vllm.envs as envs
11
12
13
14
15
from vllm.attention.backends.abstract import (
    AttentionBackend,
    AttentionType,
    MLAAttentionImpl,
)
16
from vllm.attention.backends.registry import AttentionBackendEnum
17
from vllm.attention.selector import get_attn_backend
18
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
19
from vllm.attention.utils.kv_transfer_utils import maybe_transfer_kv_layer
20
from vllm.config import CacheConfig, get_current_vllm_config
21
from vllm.config.vllm import VllmConfig
22
from vllm.forward_context import ForwardContext, get_forward_context
23
from vllm.logger import init_logger
24
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
25
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
26
27
28
29
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    UnquantizedLinearMethod,
)
30
from vllm.model_executor.layers.quantization import QuantizationConfig
31
from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
32
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
33
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
34
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
35
from vllm.platforms import current_platform
36
from vllm.utils.torch_utils import (
37
38
39
40
41
42
43
44
45
    direct_register_custom_op,
    kv_cache_dtype_str_to_dtype,
)
from vllm.v1.kv_cache_interface import (
    FullAttentionSpec,
    KVCacheSpec,
    MLAAttentionSpec,
    SlidingWindowSpec,
)
46

47
48
logger = init_logger(__name__)

49

50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
def should_load_quant_weights(quant_method: QuantizeMethodBase | None) -> bool:
    """Tell whether quantized weights should be loaded for ``quant_method``.

    Returns ``False`` when there is no quantization method at all, or when the
    method is an ``UnquantizedLinearMethod``; returns ``True`` otherwise.
    """
    if quant_method is None:
        return False
    return not isinstance(quant_method, UnquantizedLinearMethod)


def set_default_quant_scales(layer: nn.Module, register_buffer: bool = False) -> None:
    """Reset the q/k/v/prob quantization scales of ``layer`` to 1.0.

    Args:
        layer: Module that owns the scale tensors.
        register_buffer: When true, each scale is created as a new float32
            scalar buffer on the module. When false, the scale tensors must
            already exist on the module and are overwritten in place.
    """
    # Device-side scales, handled in the order k, v, q, prob.
    for scale_name in ("_k_scale", "_v_scale", "_q_scale", "_prob_scale"):
        if register_buffer:
            layer.register_buffer(scale_name, torch.tensor(1.0, dtype=torch.float32))
        else:
            getattr(layer, scale_name).fill_(1.0)

    # We also keep the scales as plain floats on host (cpu) memory for
    # attention backends that require host-side scales, e.g. Flashinfer.
    for float_name in (
        "_q_scale_float",
        "_k_scale_float",
        "_v_scale_float",
        "_prob_scale_float",
    ):
        setattr(layer, float_name, 1.0)


79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
def _init_kv_cache_quant(
    layer: nn.Module,
    quant_config: QuantizationConfig | None,
    prefix: str,
    kv_cache_dtype: str,
    calculate_kv_scales: bool,
) -> None:
    """Set up KV cache scales and the KV cache quantization method on a layer.

    Shared by Attention and MLAAttention. Stores the cache dtype and the
    dynamic-scale flag on the layer, registers the q/k/v/prob scale buffers
    (all 1.0), and, when the quantization config yields a method that loads
    quantized weights, attaches that method and lets it create its weights.

    Args:
        layer: The attention layer instance to initialize.
        quant_config: Optional quantization configuration.
        prefix: Layer name prefix for quantization method lookup.
        kv_cache_dtype: The KV cache data type string.
        calculate_kv_scales: Whether to calculate KV scales dynamically.

    Raises:
        ValueError: If quantized weights are to be loaded while the KV cache
            dtype is "fp8_e5m2".
    """
    # The default k/v_scale is 1.0. It is ignored when the kv-cache is not
    # fp8 and should be used with a kv-cache in fp8_e5m2. For a kv-cache in
    # fp8_e4m3 the pre-quantized k/v_scale is expected to be loaded along
    # with the model weights.
    layer.kv_cache_dtype = kv_cache_dtype
    layer.calculate_kv_scales = calculate_kv_scales

    # Note [Register q/k/v/prob scales in state dict]
    # model.to(device) only moves parameters/buffers that live in the state
    # dict. If the q/k/v/prob scales were not registered there, a cuda kernel
    # (e.g., quant_fp8) touching a scale left on cpu would trigger an IMA
    # error.
    # Being in the state dict means the scales take part in weight loading.
    # One edge case is quant_method being None or an UnquantizedLinearMethod
    # (i.e., should_load_quant_weights(quant_method) == False): the checkpoint
    # then carries no scales, so they are initialized to 1.0 here and reset
    # again after weight loading. This is especially important when dummy
    # weights are loaded first (providing wrong scales) and real weights
    # afterwards (which lack scales and would keep the wrong dummy values).
    set_default_quant_scales(layer, register_buffer=True)

    # The output scale on host memory. This should be the input scale of
    # the quant op after this attention layer.
    layer._o_scale_float = None

    if quant_config:
        kv_quant_method = quant_config.get_quant_method(layer, prefix=prefix)
    else:
        kv_quant_method = None

    # See Note [Register q/k/v/prob scales in state dict]
    if not should_load_quant_weights(kv_quant_method):
        return

    assert isinstance(kv_quant_method, BaseKVCacheMethod)
    # TODO (mgoin): kv cache dtype should be specified in the FP8
    # checkpoint config and become the "auto" behavior
    if kv_cache_dtype == "fp8_e5m2":
        raise ValueError("fp8_e5m2 kv-cache is not supported with fp8 checkpoints.")
    # With quantization enabled, "k_scale" and "v_scale" become parameters
    # so that they can be loaded from the model checkpoint. They are
    # converted back to native float32 values after weight loading.
    layer.quant_method = kv_quant_method
    layer.quant_method.create_weights(layer)


146
class Attention(nn.Module, AttentionLayerBase):
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
    """Attention layer.

    This class takes query, key, and value tensors as input. The input tensors
    can either contain prompt tokens or generation tokens.
    The class does the following:

    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
        alibi_slopes: list[float] | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        logits_soft_cap: float | None = None,
        per_layer_sliding_window: int | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
        attn_backend: type[AttentionBackend] | None = None,
        head_size_v: int | None = None,
        **extra_impl_args,
    ) -> None:
        """Build the attention layer and register it in the forward context.

        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        Args:
            num_heads: Number of query heads; forwarded to the backend impl.
            head_size: Query/key head size; forwarded to backend selection
                and to the backend impl.
            scale: Forwarded unchanged to the backend impl.
            num_kv_heads: Number of KV heads. Defaults to `num_heads`;
                `num_heads` must be divisible by it.
            alibi_slopes: Forwarded unchanged to the backend impl.
            cache_config: Source of the KV cache dtype, block size,
                `calculate_kv_scales` flag and model-level sliding window.
                When None, "auto", 16, False and no sliding window are used.
            quant_config: Quantization config used to initialize the KV
                cache quantization attributes.
            logits_soft_cap: Forwarded unchanged to the backend impl.
            per_layer_sliding_window: Takes precedence over
                `cache_config.sliding_window` when not None.
            prefix: Layer name; used as the key in
                `compilation_config.static_forward_context`.
            attn_type: Forwarded to backend selection and to the impl.
            kv_sharing_target_layer_name: When not None, validated against
                the already registered layers and forwarded to the impl.
            attn_backend: Backend class to use; when None the backend is
                chosen with `get_attn_backend`.
            head_size_v: Value head size. Defaults to `head_size`.
            **extra_impl_args: Forwarded to the backend impl. A non-None
                "sinks" entry sets `self.has_sink`.

        Raises:
            ValueError: If `prefix` is already registered in the static
                forward context.
        """
        super().__init__()
        if per_layer_sliding_window is not None:
            # per-layer sliding window
            sliding_window = per_layer_sliding_window
        elif cache_config is not None:
            # model-level sliding window
            sliding_window = cache_config.sliding_window
        else:
            sliding_window = None

        vllm_config = get_current_vllm_config()
        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            calculate_kv_scales = cache_config.calculate_kv_scales
        else:
            # Fallbacks used when no cache config is supplied.
            kv_cache_dtype = "auto"
            block_size = 16
            calculate_kv_scales = False
        # Torch dtype of the KV cache; reused later by get_kv_cache_spec.
        self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
            kv_cache_dtype, vllm_config.model_config
        )
        if num_kv_heads is None:
            num_kv_heads = num_heads
        assert num_heads % num_kv_heads == 0, (
            f"num_heads ({num_heads}) is not divisible by num_kv_heads ({num_kv_heads})"
        )
        self.quant_config = quant_config
        self.layer_name = prefix

        # Initialize KV cache quantization attributes
        # (this also sets self.kv_cache_dtype and self.calculate_kv_scales).
        _init_kv_cache_quant(
            self,
            self.quant_config,
            self.layer_name,
            kv_cache_dtype,
            calculate_kv_scales,
        )

        self.num_heads = num_heads
        self.head_size = head_size
        self.head_size_v = self.head_size if head_size_v is None else head_size_v
        self.num_kv_heads = num_kv_heads
        self.sliding_window = sliding_window
        self.has_sink = extra_impl_args.get("sinks") is not None

        # NOTE: model_config may be None during certain tests
        model_config = vllm_config.model_config
        self.use_mm_prefix = model_config is not None and model_config.is_mm_prefix_lm

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()
        if attn_backend is None:
            self.attn_backend = get_attn_backend(
                head_size,
                dtype,
                kv_cache_dtype,
                block_size,
                use_mla=False,
                has_sink=self.has_sink,
                use_mm_prefix=self.use_mm_prefix,
                attn_type=attn_type,
            )
        else:
            # The caller supplied the backend explicitly; skip selection.
            self.attn_backend = attn_backend

        # prefix caching + batch invariance is currently not supported for
        # FLASHINFER and TRITON_MLA.
        # Note that this mutates the cache_config object passed by the caller.
        if (
            cache_config is not None
            and cache_config.enable_prefix_caching
            and vllm_is_batch_invariant()
            and (
                self.attn_backend.get_name() == "FLASHINFER"
                or self.attn_backend.get_name() == "TRITON_MLA"
            )
        ):
            logger.warning_once(
                "Disabling prefix caching for FLASHINFER/TRITON_MLA "
                "with batch invariance, as it is not yet supported.",
                scope="local",
            )
            cache_config.enable_prefix_caching = False

        impl_cls = self.attn_backend.get_impl_cls()
        self.impl = impl_cls(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            sliding_window,
            kv_cache_dtype,
            logits_soft_cap,
            attn_type,
            kv_sharing_target_layer_name,
            **extra_impl_args,
        )
        self.backend = AttentionBackendEnum[self.attn_backend.get_name()]
        self.dtype = dtype

        # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
        # torch.compile works by registering the attention as one giant
        # opaque custom op. For other platforms, we directly call them
        # and let torch.compile handle them.
        self.use_direct_call = not current_platform.opaque_attention_op()

        # Whether forward() preallocates the output tensor for the backend.
        self.use_output = self.attn_backend.accept_output_buffer
        compilation_config = vllm_config.compilation_config
        # Register this layer under its name; names must be unique.
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self
        self.attn_type = attn_type

        if kv_sharing_target_layer_name is not None:
            validate_kv_sharing_target(
                prefix,
                kv_sharing_target_layer_name,
                compilation_config.static_forward_context,
            )
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        # use a placeholder kv cache tensor during init, which will be replaced
        # by bind_kv_cache
        # One placeholder is created per pipeline-parallel stage.
        # NOTE(review): an earlier remark here said this variable is not
        # accessed when use_direct_call is True, but forward() indexes
        # self.kv_cache in its direct-call path, so that remark looks stale.
        self.kv_cache = [
            torch.tensor([])
            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
        ]

        # Initialize q/k/v range constants.
        # calc_kv_scales divides the observed absolute maxima by these.
        self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
        self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
        self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)

        # for attn backends supporting query quantization
        self.query_quant = None
        if (
            self.kv_cache_dtype.startswith("fp8")
            and self.impl.supports_quant_query_input
        ):
            self.query_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)
323

324
325
326
327
328
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        # For some alternate attention backends like MLA the attention output
        # shape does not match the query shape, so we optionally let the model
        # definition specify the output tensor shape.
        output_shape: torch.Size | None = None,
    ) -> torch.Tensor:
        """Run attention through the selected backend.

        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        Attention metadata (`attn_metadata`) is set using a context manager in
        the model runner's `execute_model` method. It is accessed via forward
        context using
        `vllm.forward_context.get_forward_context().attn_metadata`.

        Args:
            query: Query tensor. Quantized in place of the original when
                query quantization is enabled for this layer.
            key: Key tensor; may be None, in which case it is passed on
                unchanged.
            value: Value tensor; may be None, in which case it is passed on
                unchanged.
            output_shape: Shape of the output buffer when the backend accepts
                one. Defaults to
                `(query.shape[0], num_heads * head_size_v)`.

        Returns:
            When the backend accepts an output buffer, the filled buffer
            viewed as `(-1, output_shape[-1])`, with the dtype the query had
            before any quantization. Otherwise whatever the backend returns.
        """
        if self.calculate_kv_scales:
            torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name)
        # Capture the dtype before the query is possibly quantized below, so
        # that the output buffer keeps the original activation dtype.
        output_dtype = query.dtype
        if self.query_quant is not None:
            # quantizing with a simple torch operation enables
            # torch.compile to fuse this into previous ops
            # which reduces overheads during decoding.
            # Otherwise queries are quantized using custom ops
            # which causes decoding overheads
            assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}

            # check if query quantization is supported
            if self.impl.supports_quant_query_input:
                query, _ = self.query_quant(query, self._q_scale)

        if self.use_output:
            if output_shape is None:
                # Handle both 2D [num_tokens, hidden] and
                # 3D [num_tokens, heads, head_dim] query
                num_tokens = query.shape[0]
                output_shape = torch.Size(
                    (num_tokens, self.num_heads * self.head_size_v)
                )
            # output_shape is always set at this point, so no further
            # fallback to query.shape is needed.
            output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
            hidden_size = output_shape[-1]
            # Reshape the query, key, and value tensors.
            # NOTE(woosuk): We do this outside the custom op to minimize the
            # CPU overheads from the non-CUDA-graph regions.
            query = query.view(-1, self.num_heads, self.head_size)
            output = output.view(-1, self.num_heads, self.head_size_v)
            if key is not None:
                key = key.view(-1, self.num_kv_heads, self.head_size)
            if value is not None:
                value = value.view(-1, self.num_kv_heads, self.head_size_v)
            if self.use_direct_call:
                forward_context: ForwardContext = get_forward_context()
                attn_metadata = forward_context.attn_metadata
                # Metadata may be keyed by layer name.
                if isinstance(attn_metadata, dict):
                    attn_metadata = attn_metadata[self.layer_name]
                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                self.impl.forward(
                    self, query, key, value, self_kv_cache, attn_metadata, output=output
                )
            else:
                torch.ops.vllm.unified_attention_with_output(
                    query, key, value, output, self.layer_name
                )
            return output.view(-1, hidden_size)
        else:
            if self.use_direct_call:
                forward_context = get_forward_context()
                attn_metadata = forward_context.attn_metadata
                # Metadata may be keyed by layer name.
                if isinstance(attn_metadata, dict):
                    attn_metadata = attn_metadata[self.layer_name]
                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                return self.impl.forward(
                    self, query, key, value, self_kv_cache, attn_metadata
                )
            else:
                return torch.ops.vllm.unified_attention(
                    query, key, value, self.layer_name
                )
406

407
408
    def calc_kv_scales(self, query, key, value):
        """Compute the q/k/v scales from the given tensors, once.

        Each scale becomes the absolute maximum of its tensor divided by the
        matching range constant. The host-side float copies are refreshed and
        `calculate_kv_scales` is cleared so the computation is not repeated.
        """
        q_abs_max = torch.abs(query).max()
        self._q_scale.copy_(q_abs_max / self.q_range)
        k_abs_max = torch.abs(key).max()
        self._k_scale.copy_(k_abs_max / self.k_range)
        v_abs_max = torch.abs(value).max()
        self._v_scale.copy_(v_abs_max / self.v_range)

        # Mirror the device-side scales as plain Python floats.
        self._q_scale_float = self._q_scale.item()
        self._k_scale_float = self._k_scale.item()
        self._v_scale_float = self._v_scale.item()

        # We only calculate the scales once
        self.calculate_kv_scales = False

417
418
419
420
421
    def extra_repr(self) -> str:
        """Describe the backend impl's key settings for the module repr."""
        impl = self.impl
        fields = (
            f"head_size={impl.head_size}",  # type: ignore
            f"num_heads={impl.num_heads}",  # type: ignore
            f"num_kv_heads={impl.num_kv_heads}",  # type: ignore
            f"scale={impl.scale}",  # type: ignore
            f"backend={impl.__class__.__name__}",
        )
        return ", ".join(fields)
424

425
    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Finalize the layer after weights are loaded.

        Delegates to the backend impl first. Then, if no quantized weights
        are loaded for this layer, the q/k/v/prob scales are reset to their
        default of 1.0 in place.
        """
        self.impl.process_weights_after_loading(act_dtype)

        quant_method = None
        if self.quant_config:
            quant_method = self.quant_config.get_quant_method(
                self, prefix=self.layer_name
            )

        # When quantized weights are loaded the checkpoint provides the
        # scales, so they must be left alone.
        if should_load_quant_weights(quant_method):
            return

        # Otherwise the checkpoint carries no scales: restore the defaults so
        # that values left over from an earlier (e.g. dummy) load do not
        # survive.
        set_default_quant_scales(self, register_buffer=False)

439
440
441
    def get_attn_backend(self) -> type[AttentionBackend]:
        """Return the attention backend class chosen for this layer."""
        return self.attn_backend

442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        """Describe the KV cache this layer needs.

        Returns a `SlidingWindowSpec` when the layer has a sliding window and
        a `FullAttentionSpec` otherwise.
        """
        # Block size may get updated after model loading, refresh it
        current_block_size = vllm_config.cache_config.block_size
        # Should not be called for enc-dec or encoder-only attention.
        assert self.attn_type == AttentionType.DECODER

        if self.sliding_window is None:
            return FullAttentionSpec(
                block_size=current_block_size,
                num_kv_heads=self.num_kv_heads,
                head_size=self.head_size,
                head_size_v=self.head_size_v,
                dtype=self.kv_cache_torch_dtype,
            )

        assert not vllm_config.model_config.use_mla, (
            "MLA is not supported for slidingwindow"
        )
        return SlidingWindowSpec(
            block_size=current_block_size,
            num_kv_heads=self.num_kv_heads,
            head_size=self.head_size,
            dtype=self.kv_cache_torch_dtype,
            sliding_window=self.sliding_window,
        )

467

468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
class MLAAttention(nn.Module, AttentionLayerBase):
    """Multi-Head Latent Attention layer.

    This class takes query, and compressed key/value tensors as input.
    The class does the following:

    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        scale: float,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int | None,
        kv_lora_rank: int,
        kv_b_proj: ColumnParallelLinear,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_sparse: bool = False,
        indexer: object | None = None,
        **extra_impl_args,
    ):
        """Build the MLA layer and register it in the forward context.

        Args:
            num_heads: Number of attention heads; forwarded to the impl.
            scale: Forwarded unchanged to the impl.
            qk_nope_head_dim: Forwarded to the impl; also summed with
                `qk_rope_head_dim` to form the impl's `qk_head_dim`.
            qk_rope_head_dim: Forwarded to the impl; also part of
                `head_size` and `qk_head_dim`.
            v_head_dim: Forwarded unchanged to the impl.
            q_lora_rank: Forwarded unchanged to the impl.
            kv_lora_rank: Forwarded to the impl; also part of `head_size`.
            kv_b_proj: Forwarded unchanged to the impl.
            cache_config: Source of the KV cache dtype, block size and
                `calculate_kv_scales` flag. When None, "auto", 16 and False
                are used.
            quant_config: Quantization config used to initialize the KV
                cache quantization attributes.
            prefix: Layer name; used as the key in
                `compilation_config.static_forward_context`.
            use_sparse: Forwarded to backend selection and stored.
            indexer: Forwarded unchanged to the impl.
            **extra_impl_args: Forwarded to the impl.

        Raises:
            ValueError: If `prefix` is already registered in the static
                forward context.
        """
        super().__init__()
        self.num_heads = num_heads
        self.scale = scale
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        # Head size used for backend selection, the impl and the KV cache spec.
        self.head_size = kv_lora_rank + qk_rope_head_dim
        self.layer_name = prefix

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            calculate_kv_scales = cache_config.calculate_kv_scales
        else:
            # Fallbacks used when no cache config is supplied.
            kv_cache_dtype = "auto"
            block_size = 16
            calculate_kv_scales = False
        self.quant_config = quant_config

        # Initialize KV cache quantization attributes
        # (this also sets self.kv_cache_dtype and self.calculate_kv_scales).
        _init_kv_cache_quant(
            self,
            self.quant_config,
            self.layer_name,
            kv_cache_dtype,
            calculate_kv_scales,
        )

        dtype = torch.get_default_dtype()
        self.attn_backend = get_attn_backend(
            self.head_size,
            dtype,
            kv_cache_dtype,
            block_size,
            use_mla=True,
            use_sparse=use_sparse,
        )

        # Prefix caching is switched off for these backends when batch
        # invariance is enabled; the warning below states it is not yet
        # supported. Note that this mutates the caller's cache_config.
        if (
            cache_config is not None
            and cache_config.enable_prefix_caching
            and vllm_is_batch_invariant()
            and (
                self.attn_backend.get_name() == "TRITON_MLA"
                or self.attn_backend.get_name() == "FLASHINFER"
            )
        ):
            logger.warning_once(
                "Disabling prefix caching for TRITON_MLA / FLASHINFER "
                "with batch invariance, as it is not yet supported.",
                scope="local",
            )
            cache_config.enable_prefix_caching = False

        impl_cls = cast(type[MLAAttentionImpl], self.attn_backend.get_impl_cls())
        self.impl = impl_cls(
            num_heads=self.num_heads,
            head_size=self.head_size,
            scale=self.scale,
            num_kv_heads=1,
            alibi_slopes=None,
            sliding_window=None,
            kv_cache_dtype=self.kv_cache_dtype,
            logits_soft_cap=None,
            attn_type=AttentionType.DECODER,
            kv_sharing_target_layer_name=None,
            # MLA Args
            q_lora_rank=self.q_lora_rank,
            kv_lora_rank=self.kv_lora_rank,
            qk_nope_head_dim=self.qk_nope_head_dim,
            qk_rope_head_dim=self.qk_rope_head_dim,
            qk_head_dim=self.qk_nope_head_dim + self.qk_rope_head_dim,
            v_head_dim=self.v_head_dim,
            kv_b_proj=kv_b_proj,
            indexer=indexer,
            **extra_impl_args,
        )

        # When true, forward() calls the impl directly instead of going
        # through the registered torch.ops.vllm custom ops.
        self.use_direct_call = not current_platform.opaque_attention_op()

        # Register this layer under its name; names must be unique.
        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self

        # Empty placeholder KV cache tensors, one per pipeline-parallel stage.
        self.kv_cache = [
            torch.tensor([])
            for _ in range(
                get_current_vllm_config().parallel_config.pipeline_parallel_size
            )
        ]

        self.use_sparse = use_sparse

        # Initialize q/k/v range constants.
        # calc_kv_scales divides the observed absolute maxima by these.
        self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
        self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
        self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)
596
597
598
599
600
601

    def forward(
        self,
        q: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        output_shape: torch.Size | None = None,
    ) -> torch.Tensor:
        """Run MLA attention through the selected backend.

        When the backend accepts an output buffer, a buffer of shape
        `output_shape` (dtype and device of `q`) is allocated, filled and
        returned; otherwise the backend's own result is returned. The backend
        is reached either by calling the impl directly or through the
        `torch.ops.vllm` custom ops, depending on `self.use_direct_call`.
        """
        if self.calculate_kv_scales:
            torch.ops.vllm.maybe_calc_kv_scales(q, kv_c_normed, k_pe, self.layer_name)

        if not self.use_direct_call:
            # Custom-op path: the op looks the layer up by name.
            if not self.attn_backend.accept_output_buffer:
                return torch.ops.vllm.unified_mla_attention(
                    q,
                    kv_c_normed,
                    k_pe,
                    self.layer_name,
                )
            out_buf = torch.empty(output_shape, dtype=q.dtype, device=q.device)
            torch.ops.vllm.unified_mla_attention_with_output(
                q,
                kv_c_normed,
                k_pe,
                out_buf,
                self.layer_name,
            )
            return out_buf

        # Direct-call path: fetch metadata and KV cache from the context.
        ctx: ForwardContext = get_forward_context()
        metadata = ctx.attn_metadata
        if isinstance(metadata, dict):
            metadata = metadata[self.layer_name]
        layer_kv_cache = self.kv_cache[ctx.virtual_engine]

        if not self.attn_backend.accept_output_buffer:
            return self.impl.forward(
                self, q, kv_c_normed, k_pe, layer_kv_cache, metadata
            )

        out_buf = torch.empty(output_shape, dtype=q.dtype, device=q.device)
        self.impl.forward(
            self,
            q,
            kv_c_normed,
            k_pe,
            layer_kv_cache,
            metadata,
            output=out_buf,
        )
        return out_buf

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Finalize the layer after weights are loaded.

        Delegates to the backend impl when it defines the hook. Then, if no
        quantized weights are loaded for this layer, the q/k/v/prob scales
        are reset to their default of 1.0 in place.
        """
        if hasattr(self.impl, "process_weights_after_loading"):
            self.impl.process_weights_after_loading(act_dtype)

        quant_method = None
        if self.quant_config:
            quant_method = self.quant_config.get_quant_method(
                self, prefix=self.layer_name
            )

        # When quantized weights are loaded the checkpoint provides the
        # scales, so they must be left alone.
        if should_load_quant_weights(quant_method):
            return

        # Otherwise the checkpoint carries no scales: restore the defaults so
        # that values left over from an earlier (e.g. dummy) load do not
        # survive.
        set_default_quant_scales(self, register_buffer=False)

664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
    def calc_kv_scales(
        self, q: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor
    ) -> None:
        """Compute the q/k/v scales for MLA inputs, once.

        Mirrors Attention.calc_kv_scales. Not all MLA backends require this.
        The q scale comes from `q`; both the k and v scales come from
        `kv_c_normed`. `k_pe` is accepted but not used.
        """
        # Fall back to a range of 1.0 for any range attribute that is absent.
        q_range, k_range, v_range = (
            getattr(self, range_attr, torch.tensor(1.0))
            for range_attr in ("q_range", "k_range", "v_range")
        )

        q_abs_max = torch.abs(q).max()
        self._q_scale.copy_(q_abs_max / q_range)

        # kv_c_normed is the compressed KV representation; use it for k/v
        kv_abs_max = torch.abs(kv_c_normed).max()
        self._k_scale.copy_(kv_abs_max / k_range)
        self._v_scale.copy_(kv_abs_max / v_range)

        # Mirror the device-side scales as plain Python floats.
        self._q_scale_float = self._q_scale.item()
        self._k_scale_float = self._k_scale.item()
        self._v_scale_float = self._v_scale.item()

        # The scales are only calculated once.
        self.calculate_kv_scales = False

    def get_attn_backend(self) -> type[AttentionBackend]:
        """Return the attention backend class chosen for this layer."""
        return self.attn_backend

689
690
691
692
693
694
695
696
697
698
699
700
    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        """Describe the KV cache this MLA layer needs.

        The block size and cache dtype string are read from
        `vllm_config.cache_config`; the torch dtype is derived from this
        layer's `kv_cache_dtype` and the model config.
        """
        torch_dtype = kv_cache_dtype_str_to_dtype(
            self.kv_cache_dtype, vllm_config.model_config
        )
        cache_cfg = vllm_config.cache_config
        spec = MLAAttentionSpec(
            block_size=cache_cfg.block_size,
            num_kv_heads=1,
            head_size=self.head_size,
            dtype=torch_dtype,
            cache_dtype_str=cache_cfg.cache_dtype,
        )
        return spec

701

702
703
704
705
706
707
708
def maybe_calc_kv_scales(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> None:
    """Compute KV scales for the named layer if it still asks for them.

    The layer is looked up in the forward context's `no_compile_layers`.
    Nothing happens when its `calculate_kv_scales` flag is already False.
    """
    ctx: ForwardContext = get_forward_context()
    layer = ctx.no_compile_layers[layer_name]

    # Only calculate while the layer's calculate_kv_scales flag is True;
    # the layer's calc_kv_scales clears it after the first computation.
    if layer.calculate_kv_scales:
        layer.calc_kv_scales(query, key, value)


def maybe_calc_kv_scales_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> None:
    """Fake implementation of maybe_calc_kv_scales: does nothing."""
    return None


# Register maybe_calc_kv_scales as a custom op, with
# maybe_calc_kv_scales_fake as its fake implementation. The layers in this
# file invoke it as torch.ops.vllm.maybe_calc_kv_scales.
# NOTE(review): the op body visible here does not write to query/key/value;
# they are presumably listed in mutates_args so that this None-returning op
# is not dropped as dead code by the compiler - confirm.
direct_register_custom_op(
    op_name="maybe_calc_kv_scales",
    op_func=maybe_calc_kv_scales,
    mutates_args=["query", "key", "value"],
    fake_impl=maybe_calc_kv_scales_fake,
)


736
def get_attention_context(
    layer_name: str,
) -> tuple[dict | object | None, Attention | MLAAttention, torch.Tensor]:
    """Extract attention context for a given layer.

    This helper function extracts the attention metadata, attention layer
    instance, and KV cache tensor for a specific layer.

    Args:
        layer_name: The name/identifier of the attention layer.

    Returns:
        A tuple containing:
        - attn_metadata: Attention metadata for this specific layer, or None if
            no metadata available
        - attn_layer: The attention layer instance (Attention or MLAAttention)
        - kv_cache: The KV cache tensor for current virtual engine

        Note: attn_metadata may be None, but attn_layer and kv_cache are always
        extracted from the forward context.
    """
    forward_context: ForwardContext = get_forward_context()

    # Metadata may be stored per layer (a dict keyed by layer name) or as a
    # single object used as-is.
    attn_metadata = forward_context.attn_metadata
    if isinstance(attn_metadata, dict):
        attn_metadata = attn_metadata[layer_name]

    attn_layer: Attention | MLAAttention = forward_context.no_compile_layers[layer_name]
    # Each layer holds one KV cache entry per virtual engine.
    kv_cache = attn_layer.kv_cache[forward_context.virtual_engine]
    return attn_metadata, attn_layer, kv_cache


@maybe_transfer_kv_layer
def unified_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Run attention for the named layer and return the result.

    The layer, its metadata and its KV cache are resolved through
    ``get_attention_context``; the computation itself is delegated to the
    layer's ``impl.forward``.

    Args:
        query: Query tensor passed to the implementation.
        key: Key tensor passed to the implementation.
        value: Value tensor passed to the implementation.
        layer_name: Name used to look the layer up in the forward context.

    Returns:
        Whatever ``impl.forward`` returns for this layer.
    """
    attn_metadata, attn_layer, kv_cache = get_attention_context(layer_name)
    output = attn_layer.impl.forward(
        attn_layer, query, key, value, kv_cache, attn_metadata
    )

    return output
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792


def unified_attention_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Fake implementation of ``unified_attention``.

    Produces an uninitialized, contiguous tensor shaped like ``query``;
    ``key``, ``value`` and ``layer_name`` are ignored.
    """
    placeholder = torch.empty_like(query)
    return placeholder.contiguous()


# Register ``unified_attention`` as a custom op, paired with its fake
# implementation. No ``mutates_args`` is passed here: the op returns its
# result as a new tensor rather than writing into an argument.
direct_register_custom_op(
    op_name="unified_attention",
    op_func=unified_attention,
    fake_impl=unified_attention_fake,
)
793
794


795
@maybe_transfer_kv_layer
def unified_attention_with_output(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
) -> None:
    """Run attention for the named layer, passing a caller-provided output.

    The layer, its metadata and its KV cache are resolved through
    ``get_attention_context``; the computation is delegated to the layer's
    ``impl.forward`` and nothing is returned.

    Args:
        query: Query tensor passed to the implementation.
        key: Key tensor passed to the implementation.
        value: Value tensor passed to the implementation.
        output: Tensor handed to the implementation as ``output``.
        layer_name: Name used to look the layer up in the forward context.
        output_scale: Optional tensor forwarded as ``output_scale``.
        output_block_scale: Optional tensor forwarded as
            ``output_block_scale``.
    """
    attn_metadata, attn_layer, kv_cache = get_attention_context(layer_name)

    attn_layer.impl.forward(
        attn_layer,
        query,
        key,
        value,
        kv_cache,
        attn_metadata,
        output=output,
        output_scale=output_scale,
        output_block_scale=output_block_scale,
    )
818
819
820
821
822
823
824
825


def unified_attention_with_output_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
826
827
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
828
829
830
831
832
833
834
) -> None:
    return


# Register ``unified_attention_with_output`` as a custom op, paired with its
# no-op fake implementation. ``output`` and ``output_block_scale`` are
# declared as the arguments the op may write to.
direct_register_custom_op(
    op_name="unified_attention_with_output",
    op_func=unified_attention_with_output,
    mutates_args=["output", "output_block_scale"],
    fake_impl=unified_attention_with_output_fake,
)
838
839


840
@maybe_transfer_kv_layer
def unified_mla_attention(
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Run MLA attention for the named layer and return the result.

    The layer, its metadata and its KV cache are resolved through
    ``get_attention_context``; the computation itself is delegated to the
    layer's ``impl.forward``.

    Args:
        q: Query tensor passed to the implementation.
        kv_c_normed: Passed to the implementation as its second tensor
            argument.
        k_pe: Passed to the implementation as its third tensor argument.
        layer_name: Name used to look the layer up in the forward context.

    Returns:
        Whatever ``impl.forward`` returns for this layer.
    """
    attn_metadata, attn_layer, kv_cache = get_attention_context(layer_name)
    output = attn_layer.impl.forward(
        attn_layer, q, kv_c_normed, k_pe, kv_cache, attn_metadata
    )

    return output


def unified_mla_attention_fake(
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Fake implementation of ``unified_mla_attention``.

    Produces an uninitialized, contiguous tensor shaped like ``q``; the other
    arguments are ignored.
    """
    placeholder = torch.empty_like(q)
    return placeholder.contiguous()


# Register ``unified_mla_attention`` as a custom op, paired with its fake
# implementation. No arguments are declared as mutated, and the dispatch key
# is taken from the current platform.
direct_register_custom_op(
    op_name="unified_mla_attention",
    op_func=unified_mla_attention,
    mutates_args=[],
    fake_impl=unified_mla_attention_fake,
    dispatch_key=current_platform.dispatch_key,
)


871
@maybe_transfer_kv_layer
def unified_mla_attention_with_output(
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
) -> None:
    """Run MLA attention for the named layer, passing a caller-provided output.

    The layer, its metadata and its KV cache are resolved through
    ``get_attention_context``; the computation is delegated to the layer's
    ``impl.forward`` and nothing is returned.

    Args:
        q: Query tensor passed to the implementation.
        kv_c_normed: Passed to the implementation as its second tensor
            argument.
        k_pe: Passed to the implementation as its third tensor argument.
        output: Tensor handed to the implementation as ``output``.
        layer_name: Name used to look the layer up in the forward context.
        output_scale: Optional tensor forwarded as ``output_scale``.
        output_block_scale: Optional tensor forwarded as
            ``output_block_scale``.
    """
    attn_metadata, attn_layer, kv_cache = get_attention_context(layer_name)

    attn_layer.impl.forward(
        attn_layer,
        q,
        kv_c_normed,
        k_pe,
        kv_cache,
        attn_metadata,
        output=output,
        output_scale=output_scale,
        output_block_scale=output_block_scale,
    )


def unified_mla_attention_with_output_fake(
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
901
902
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
903
904
905
906
907
908
909
910
911
912
913
) -> None:
    return


# Register ``unified_mla_attention_with_output`` as a custom op, paired with
# its no-op fake implementation. ``output`` and ``output_block_scale`` are
# declared as the arguments the op may write to; the dispatch key is taken
# from the current platform.
direct_register_custom_op(
    op_name="unified_mla_attention_with_output",
    op_func=unified_mla_attention_with_output,
    mutates_args=["output", "output_block_scale"],
    fake_impl=unified_mla_attention_with_output_fake,
    dispatch_key=current_platform.dispatch_key,
)