layer.py 33.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Attention layer."""
4

5
from typing import cast
6
7
8
9

import torch
import torch.nn as nn

10
import vllm.envs as envs
11
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
12
from vllm.attention.utils.kv_transfer_utils import maybe_transfer_kv_layer
13
from vllm.config import CacheConfig, get_current_vllm_config
14
from vllm.config.vllm import VllmConfig
15
from vllm.forward_context import ForwardContext, get_forward_context
16
from vllm.logger import init_logger
17
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
18
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
19
20
21
22
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    UnquantizedLinearMethod,
)
23
from vllm.model_executor.layers.quantization import QuantizationConfig
24
from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
25
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
26
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
27
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
28
from vllm.platforms import current_platform
29
from vllm.utils.torch_utils import (
30
31
32
    direct_register_custom_op,
    kv_cache_dtype_str_to_dtype,
)
33
34
35
36
37
38
39
from vllm.v1.attention.backend import (
    AttentionBackend,
    AttentionType,
    MLAAttentionImpl,
)
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.selector import get_attn_backend
40
41
42
43
44
45
from vllm.v1.kv_cache_interface import (
    FullAttentionSpec,
    KVCacheSpec,
    MLAAttentionSpec,
    SlidingWindowSpec,
)
46

47
48
logger = init_logger(__name__)

49

50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
def should_load_quant_weights(quant_method: QuantizeMethodBase | None) -> bool:
    """Tell whether ``quant_method`` expects quantized weights to be loaded.

    Returns ``False`` when there is no quantization method at all or when the
    method is an ``UnquantizedLinearMethod``; ``True`` for anything else.
    """
    if quant_method is None:
        return False
    return not isinstance(quant_method, UnquantizedLinearMethod)


def set_default_quant_scales(layer: nn.Module, register_buffer: bool = False) -> None:
    """Reset the q/k/v/prob quantization scales of ``layer`` to 1.0.

    Args:
        layer: Module that owns (or will own) the scale tensors.
        register_buffer: When True, create the ``_k_scale``, ``_v_scale``,
            ``_q_scale`` and ``_prob_scale`` float32 buffers on ``layer``.
            When False, the tensors must already exist and are overwritten
            in place.
    """
    # Order matters for the buffer registration order: k, v, q, prob.
    for scale_name in ("_k_scale", "_v_scale", "_q_scale", "_prob_scale"):
        if register_buffer:
            layer.register_buffer(scale_name, torch.tensor(1.0, dtype=torch.float32))
        else:
            getattr(layer, scale_name).fill_(1.0)

    # Plain Python float copies of the scales are kept alongside the tensors
    # for attention backends that need the scales on host (cpu) memory rather
    # than on device, e.g. Flashinfer.
    for host_name in (
        "_q_scale_float",
        "_k_scale_float",
        "_v_scale_float",
        "_prob_scale_float",
    ):
        setattr(layer, host_name, 1.0)


79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
def _init_kv_cache_quant(
    layer: nn.Module,
    quant_config: QuantizationConfig | None,
    prefix: str,
    kv_cache_dtype: str,
    calculate_kv_scales: bool,
) -> None:
    """Set up the KV cache scales and quantization method on ``layer``.

    Shared by Attention and MLAAttention. Stores the cache dtype and the
    dynamic-scale flag on the layer, registers the q/k/v/prob scale buffers
    with a default of 1.0, and, when the quantization method loads quantized
    weights, attaches that method and lets it create its weights.

    Args:
        layer: The attention layer instance to initialize.
        quant_config: Optional quantization configuration.
        prefix: Layer name prefix for quantization method lookup.
        kv_cache_dtype: The KV cache data type string.
        calculate_kv_scales: Whether to calculate KV scales dynamically.

    Raises:
        ValueError: If quantized weights are to be loaded while
            ``kv_cache_dtype`` is ``"fp8_e5m2"``.
    """
    # The default k/v_scale is 1.0. It is ignored when the kv-cache is not
    # fp8 and should be used with a kv-cache in fp8_e5m2. For a kv-cache in
    # fp8_e4m3 the pre-quantized k/v_scale is expected to be loaded along
    # with the model weights.
    layer.kv_cache_dtype = kv_cache_dtype
    layer.calculate_kv_scales = calculate_kv_scales

    # Note [Register q/k/v/prob scales in state dict]
    # model.to(device) only moves parameters/buffers that live in the state
    # dict. Scales that are not registered there stay on cpu, and a cuda
    # kernel (e.g., quant_fp8) touching them fails with an IMA error.
    # Being in the state dict also means the scales take part in weight
    # loading. The edge case is quant_method being None or an
    # UnquantizedLinearMethod (should_load_quant_weights(quant_method) is
    # False): the checkpoint then carries no scales, so they are initialized
    # to 1.0 here and reset again after weight loading. This is especially
    # important when dummy weights are loaded first (providing wrong scales)
    # and real weights afterwards (which have no scales and would otherwise
    # keep the wrong ones from the dummy load).
    set_default_quant_scales(layer, register_buffer=True)

    # The output scale on host memory. This should be the input scale of
    # the quant op after this attention layer.
    layer._o_scale_float = None

    if quant_config:
        quant_method = quant_config.get_quant_method(layer, prefix=prefix)
    else:
        quant_method = None

    # See Note [Register q/k/v/prob scales in state dict]: nothing more to do
    # when no quantized weights will be loaded.
    if not should_load_quant_weights(quant_method):
        return

    assert isinstance(quant_method, BaseKVCacheMethod)
    # TODO (mgoin): kv cache dtype should be specified in the FP8
    # checkpoint config and become the "auto" behavior
    if kv_cache_dtype == "fp8_e5m2":
        raise ValueError("fp8_e5m2 kv-cache is not supported with fp8 checkpoints.")
    # With quantization enabled, "k_scale" and "v_scale" become parameters so
    # they can be loaded from the model checkpoint. They are converted back
    # to native float32 values after weight loading.
    layer.quant_method = quant_method
    quant_method.create_weights(layer)


146
class Attention(nn.Module, AttentionLayerBase):
    """Attention layer.

    Takes query, key, and value tensors as input; they may hold prompt tokens
    or generation tokens. The layer:

    1. Stores the input key and value tensors in the KV cache.
    2. Performs (multi-head/multi-query/grouped-query) attention.
    3. Returns the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
        alibi_slopes: list[float] | None = None,
        use_alibi_sqrt: bool | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        logits_soft_cap: float | None = None,
        per_layer_sliding_window: int | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
        attn_backend: type[AttentionBackend] | None = None,
        head_size_v: int | None = None,
        **extra_impl_args,
    ) -> None:
        """Build the layer and its backend implementation.

        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        Args:
            num_heads: Number of query heads; forwarded to the backend impl.
            head_size: Per-head size of query/key; forwarded to the impl.
            scale: Attention scale; forwarded to the impl.
            num_kv_heads: Number of KV heads. Defaults to ``num_heads``, which
                must be divisible by it.
            alibi_slopes: Forwarded to the impl.
            use_alibi_sqrt: Falsy values (including None) mean "off". The
                flag is passed to the impl only when the backend reports
                support for it.
            cache_config: Supplies cache dtype, block size, the
                ``calculate_kv_scales`` flag and the model-level sliding
                window. Without it the defaults are "auto", 16, False and no
                sliding window.
            quant_config: Optional quantization configuration.
            logits_soft_cap: Forwarded to the impl.
            per_layer_sliding_window: Takes precedence over
                ``cache_config.sliding_window`` when given.
            prefix: Layer name; it is the key under which the layer is
                registered in the static forward context.
            attn_type: Attention type; forwarded to backend selection and to
                the impl.
            kv_sharing_target_layer_name: If given, validated against the
                static forward context and forwarded to the impl.
            attn_backend: Backend class to use; when None one is selected via
                ``get_attn_backend``.
            head_size_v: Per-head size of value. Defaults to ``head_size``.
            **extra_impl_args: Extra keyword arguments forwarded to the impl.
                A non-None ``"sinks"`` entry sets ``self.has_sink``.

        Raises:
            ValueError: If ``use_alibi_sqrt`` is requested but unsupported by
                the backend, or if ``prefix`` is already registered.
        """
        super().__init__()

        # A per-layer sliding window wins over the model-level one.
        sliding_window = (
            per_layer_sliding_window
            if per_layer_sliding_window is not None
            else (cache_config.sliding_window if cache_config is not None else None)
        )

        vllm_config = get_current_vllm_config()

        kv_cache_dtype = "auto"
        block_size = 16
        calculate_kv_scales = False
        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            calculate_kv_scales = cache_config.calculate_kv_scales

        self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
            kv_cache_dtype, vllm_config.model_config
        )

        num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        assert num_heads % num_kv_heads == 0, (
            f"num_heads ({num_heads}) is not divisible by num_kv_heads ({num_kv_heads})"
        )

        self.quant_config = quant_config
        self.layer_name = prefix

        # Initialize KV cache quantization attributes
        _init_kv_cache_quant(
            self,
            self.quant_config,
            self.layer_name,
            kv_cache_dtype,
            calculate_kv_scales,
        )

        self.num_heads = num_heads
        self.head_size = head_size
        self.head_size_v = head_size_v if head_size_v is not None else self.head_size
        self.num_kv_heads = num_kv_heads
        self.sliding_window = sliding_window
        self.has_sink = extra_impl_args.get("sinks") is not None

        # NOTE: model_config may be None during certain tests
        model_config = vllm_config.model_config
        self.use_mm_prefix = model_config is not None and model_config.is_mm_prefix_lm

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()
        if attn_backend is not None:
            self.attn_backend = attn_backend
        else:
            self.attn_backend = get_attn_backend(
                head_size,
                dtype,
                kv_cache_dtype,
                block_size,
                use_mla=False,
                has_sink=self.has_sink,
                use_mm_prefix=self.use_mm_prefix,
                attn_type=attn_type,
            )

        alibi_sqrt_supported = self.attn_backend.supports_alibi_sqrt()
        alibi_sqrt_requested = bool(use_alibi_sqrt)
        if alibi_sqrt_requested and not alibi_sqrt_supported:
            raise ValueError(
                f"use_alibi_sqrt is not supported by backend "
                f"{self.attn_backend.get_name()}."
            )
        self.use_alibi_sqrt = alibi_sqrt_requested
        if alibi_sqrt_supported:
            extra_impl_args["use_alibi_sqrt"] = self.use_alibi_sqrt

        # prefix caching + batch invariance is currently not supported for
        # FLASHINFER and TRITON_MLA.
        if (
            cache_config is not None
            and cache_config.enable_prefix_caching
            and vllm_is_batch_invariant()
            and self.attn_backend.get_name() in ("FLASHINFER", "TRITON_MLA")
        ):
            logger.warning_once(
                "Disabling prefix caching for FLASHINFER/TRITON_MLA "
                "with batch invariance, as it is not yet supported.",
                scope="local",
            )
            cache_config.enable_prefix_caching = False

        impl_cls = self.attn_backend.get_impl_cls()
        self.impl = impl_cls(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            sliding_window,
            kv_cache_dtype,
            logits_soft_cap,
            attn_type,
            kv_sharing_target_layer_name,
            **extra_impl_args,
        )
        self.backend = AttentionBackendEnum[self.attn_backend.get_name()]
        self.dtype = dtype

        # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
        # torch.compile works by registering the attention as one giant
        # opaque custom op. For other platforms, we directly call them
        # and let torch.compile handle them.
        self.use_direct_call = not current_platform.opaque_attention_op()

        self.use_output = self.attn_backend.accept_output_buffer

        # Register this layer under its name; names must be unique.
        compilation_config = vllm_config.compilation_config
        static_forward_context = compilation_config.static_forward_context
        if prefix in static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        static_forward_context[prefix] = self
        self.attn_type = attn_type

        if kv_sharing_target_layer_name is not None:
            validate_kv_sharing_target(
                prefix,
                kv_sharing_target_layer_name,
                compilation_config.static_forward_context,
            )
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        # Placeholder kv cache tensors (one per pipeline-parallel stage) used
        # during init; they are meant to be replaced later by bind_kv_cache.
        pp_size = vllm_config.parallel_config.pipeline_parallel_size
        self.kv_cache = [torch.tensor([]) for _ in range(pp_size)]

        # Initialize q/k/v range constants.
        self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
        self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
        self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)

        # for attn backends supporting query quantization
        self.query_quant = None
        if (
            self.kv_cache_dtype.startswith("fp8")
            and self.impl.supports_quant_query_input
        ):
            self.query_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        # For some alternate attention backends like MLA the attention output
        # shape does not match the query shape, so we optionally let the model
        # definition specify the output tensor shape.
        output_shape: torch.Size | None = None,
    ) -> torch.Tensor:
        """Run attention for ``query`` against ``key``/``value``.

        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        Attention metadata (`attn_metadata`) is set using a context manager in
        the model runner's `execute_model` method. It is accessed via forward
        context using
        `vllm.forward_context.get_forward_context().attn_metadata`.

        Args:
            query: Query tensor; viewed as ``(-1, num_heads, head_size)`` when
                the backend accepts an output buffer.
            key: Key tensor, or None; reshaped only when not None.
            value: Value tensor, or None; reshaped only when not None.
            output_shape: Shape of the output buffer. When None and an output
                buffer is used, it defaults to
                ``(query.shape[0], num_heads * head_size_v)``.

        Returns:
            With an output buffer: the buffer viewed as
            ``(-1, output_shape[-1])``. Otherwise whatever the backend impl or
            the unified attention op returns.
        """
        if self.calculate_kv_scales:
            torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name)

        # Captured before any quantization so the output keeps the input dtype.
        output_dtype = query.dtype
        if self.query_quant is not None:
            # quantizing with a simple torch operation enables
            # torch.compile to fuse this into previous ops
            # which reduces overheads during decoding.
            # Otherwise queries are quantized using custom ops
            # which causes decoding overheads
            assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}

            # check if query quantization is supported
            if self.impl.supports_quant_query_input:
                query, _ = self.query_quant(query, self._q_scale)

        if not self.use_output:
            # The backend allocates and returns its own output.
            if not self.use_direct_call:
                return torch.ops.vllm.unified_attention(
                    query, key, value, self.layer_name
                )
            forward_context = get_forward_context()
            attn_metadata = forward_context.attn_metadata
            if isinstance(attn_metadata, dict):
                attn_metadata = attn_metadata[self.layer_name]
            self_kv_cache = self.kv_cache[forward_context.virtual_engine]
            return self.impl.forward(
                self, query, key, value, self_kv_cache, attn_metadata
            )

        if output_shape is None:
            # Handle both 2D [num_tokens, hidden] and
            # 3D [num_tokens, heads, head_dim] query
            output_shape = torch.Size(
                (query.shape[0], self.num_heads * self.head_size_v)
            )
        output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
        hidden_size = output_shape[-1]

        # Reshape the query, key, and value tensors.
        # NOTE(woosuk): We do this outside the custom op to minimize the
        # CPU overheads from the non-CUDA-graph regions.
        query = query.view(-1, self.num_heads, self.head_size)
        output = output.view(-1, self.num_heads, self.head_size_v)
        if key is not None:
            key = key.view(-1, self.num_kv_heads, self.head_size)
        if value is not None:
            value = value.view(-1, self.num_kv_heads, self.head_size_v)

        if self.use_direct_call:
            forward_context: ForwardContext = get_forward_context()
            attn_metadata = forward_context.attn_metadata
            if isinstance(attn_metadata, dict):
                attn_metadata = attn_metadata[self.layer_name]
            self_kv_cache = self.kv_cache[forward_context.virtual_engine]
            self.impl.forward(
                self, query, key, value, self_kv_cache, attn_metadata, output=output
            )
        else:
            torch.ops.vllm.unified_attention_with_output(
                query, key, value, output, self.layer_name
            )
        return output.view(-1, hidden_size)

    def calc_kv_scales(self, query, key, value):
        """Derive the q/k/v scales from the given tensors, once.

        Each scale becomes ``abs(tensor).max() / range``; the Python float
        copies are refreshed and ``calculate_kv_scales`` is switched off.
        """
        for scale, tensor, value_range in (
            (self._q_scale, query, self.q_range),
            (self._k_scale, key, self.k_range),
            (self._v_scale, value, self.v_range),
        ):
            scale.copy_(torch.abs(tensor).max() / value_range)

        self._q_scale_float = self._q_scale.item()
        self._k_scale_float = self._k_scale.item()
        self._v_scale_float = self._v_scale.item()
        # We only calculate the scales once
        self.calculate_kv_scales = False

    def extra_repr(self) -> str:
        """Summarize the impl's head geometry, scale and class name."""
        impl = self.impl
        return (
            f"head_size={impl.head_size}"  # type: ignore
            f", num_heads={impl.num_heads}"  # type: ignore
            f", num_kv_heads={impl.num_kv_heads}"  # type: ignore
            f", scale={impl.scale}"  # type: ignore
            f", backend={impl.__class__.__name__}"
        )

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Post-process the impl's weights and, if needed, reset the scales."""
        self.impl.process_weights_after_loading(act_dtype)

        # When no quantized weights are loaded, the scales go back to their
        # default of 1.0. See the note "Register q/k/v/prob scales in state
        # dict" in _init_kv_cache_quant for more details.
        if self.quant_config:
            quant_method = self.quant_config.get_quant_method(
                self, prefix=self.layer_name
            )
        else:
            quant_method = None
        if not should_load_quant_weights(quant_method):
            set_default_quant_scales(self, register_buffer=False)

    def get_attn_backend(self) -> type[AttentionBackend]:
        """Return the attention backend class used by this layer."""
        return self.attn_backend

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        """Describe the KV cache this layer needs.

        Returns a ``SlidingWindowSpec`` when the layer has a sliding window
        and a ``FullAttentionSpec`` otherwise.
        """
        # Block size may get updated after model loading, refresh it
        block_size = vllm_config.cache_config.block_size
        # Should not be called for enc-dec or encoder-only attention.
        assert self.attn_type == AttentionType.DECODER

        if self.sliding_window is None:
            return FullAttentionSpec(
                block_size=block_size,
                num_kv_heads=self.num_kv_heads,
                head_size=self.head_size,
                head_size_v=self.head_size_v,
                dtype=self.kv_cache_torch_dtype,
            )

        assert not vllm_config.model_config.use_mla, (
            "MLA is not supported for slidingwindow"
        )
        return SlidingWindowSpec(
            block_size=block_size,
            num_kv_heads=self.num_kv_heads,
            head_size=self.head_size,
            dtype=self.kv_cache_torch_dtype,
            sliding_window=self.sliding_window,
        )

476

477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
class MLAAttention(nn.Module, AttentionLayerBase):
    """Multi-Head Latent Attention layer.

    Takes the query and the compressed key/value tensors as input. The layer:

    1. Stores the input key and value tensors in the KV cache.
    2. Performs (multi-head/multi-query/grouped-query) attention.
    3. Returns the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        scale: float,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int | None,
        kv_lora_rank: int,
        kv_b_proj: ColumnParallelLinear,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_sparse: bool = False,
        indexer: object | None = None,
        **extra_impl_args,
    ):
        """Build the layer and its MLA backend implementation.

        Args:
            num_heads: Number of attention heads.
            scale: Attention scale; forwarded to the impl.
            qk_nope_head_dim: Forwarded to the impl; added to
                ``qk_rope_head_dim`` to form the impl's ``qk_head_dim``.
            qk_rope_head_dim: Forwarded to the impl; also added to
                ``kv_lora_rank`` to form ``self.head_size``.
            v_head_dim: Forwarded to the impl.
            q_lora_rank: Forwarded to the impl; may be None.
            kv_lora_rank: Forwarded to the impl; part of ``self.head_size``.
            kv_b_proj: Projection layer handed to the impl.
            cache_config: Supplies cache dtype, block size and the
                ``calculate_kv_scales`` flag. Without it the defaults are
                "auto", 16 and False.
            quant_config: Optional quantization configuration.
            prefix: Layer name; it is the key under which the layer is
                registered in the static forward context.
            use_sparse: Forwarded to backend selection and stored on the
                layer.
            indexer: Forwarded to the impl.
            **extra_impl_args: Extra keyword arguments forwarded to the impl.

        Raises:
            ValueError: If ``prefix`` is already registered.
        """
        super().__init__()
        self.num_heads = num_heads
        self.scale = scale
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.head_size = kv_lora_rank + qk_rope_head_dim
        self.layer_name = prefix

        kv_cache_dtype = "auto"
        block_size = 16
        calculate_kv_scales = False
        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            calculate_kv_scales = cache_config.calculate_kv_scales
        self.quant_config = quant_config

        # Initialize KV cache quantization attributes
        _init_kv_cache_quant(
            self,
            self.quant_config,
            self.layer_name,
            kv_cache_dtype,
            calculate_kv_scales,
        )

        dtype = torch.get_default_dtype()
        self.attn_backend = get_attn_backend(
            self.head_size,
            dtype,
            kv_cache_dtype,
            block_size,
            use_mla=True,
            use_sparse=use_sparse,
        )

        # Prefix caching is switched off for these backends when batch
        # invariance is on.
        if (
            cache_config is not None
            and cache_config.enable_prefix_caching
            and vllm_is_batch_invariant()
            and self.attn_backend.get_name() in ("TRITON_MLA", "FLASHINFER")
        ):
            logger.warning_once(
                "Disabling prefix caching for TRITON_MLA / FLASHINFER "
                "with batch invariance, as it is not yet supported.",
                scope="local",
            )
            cache_config.enable_prefix_caching = False

        impl_cls = cast(type[MLAAttentionImpl], self.attn_backend.get_impl_cls())
        self.impl = impl_cls(
            num_heads=self.num_heads,
            head_size=self.head_size,
            scale=self.scale,
            num_kv_heads=1,
            alibi_slopes=None,
            sliding_window=None,
            kv_cache_dtype=self.kv_cache_dtype,
            logits_soft_cap=None,
            attn_type=AttentionType.DECODER,
            kv_sharing_target_layer_name=None,
            # MLA Args
            q_lora_rank=self.q_lora_rank,
            kv_lora_rank=self.kv_lora_rank,
            qk_nope_head_dim=self.qk_nope_head_dim,
            qk_rope_head_dim=self.qk_rope_head_dim,
            qk_head_dim=self.qk_nope_head_dim + self.qk_rope_head_dim,
            v_head_dim=self.v_head_dim,
            kv_b_proj=kv_b_proj,
            indexer=indexer,
            **extra_impl_args,
        )

        self.use_direct_call = not current_platform.opaque_attention_op()

        # Register this layer under its name; names must be unique.
        static_forward_context = (
            get_current_vllm_config().compilation_config.static_forward_context
        )
        if prefix in static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        static_forward_context[prefix] = self

        # One placeholder kv cache tensor per pipeline-parallel stage.
        pp_size = get_current_vllm_config().parallel_config.pipeline_parallel_size
        self.kv_cache = [torch.tensor([]) for _ in range(pp_size)]

        self.use_sparse = use_sparse

        # Initialize q/k/v range constants.
        self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
        self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
        self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)

    def forward(
        self,
        q: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        output_shape: torch.Size | None = None,
    ) -> torch.Tensor:
        """Run MLA attention.

        Args:
            q: Query tensor.
            kv_c_normed: Compressed key/value tensor.
            k_pe: Third input forwarded to the impl / unified op.
            output_shape: Shape of the output buffer. It is passed straight
                to ``torch.empty`` whenever the backend accepts an output
                buffer, so callers must supply it in that case.

        Returns:
            The filled output buffer when the backend accepts one; otherwise
            whatever the impl or the unified MLA op returns.
        """
        if self.calculate_kv_scales:
            torch.ops.vllm.maybe_calc_kv_scales(q, kv_c_normed, k_pe, self.layer_name)

        if not self.use_direct_call:
            # Go through the registered unified MLA ops.
            if not self.attn_backend.accept_output_buffer:
                return torch.ops.vllm.unified_mla_attention(
                    q,
                    kv_c_normed,
                    k_pe,
                    self.layer_name,
                )
            output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
            torch.ops.vllm.unified_mla_attention_with_output(
                q,
                kv_c_normed,
                k_pe,
                output,
                self.layer_name,
            )
            return output

        forward_context: ForwardContext = get_forward_context()
        attn_metadata = forward_context.attn_metadata
        if isinstance(attn_metadata, dict):
            attn_metadata = attn_metadata[self.layer_name]
        self_kv_cache = self.kv_cache[forward_context.virtual_engine]

        if not self.attn_backend.accept_output_buffer:
            return self.impl.forward(
                self, q, kv_c_normed, k_pe, self_kv_cache, attn_metadata
            )
        output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
        self.impl.forward(
            self,
            q,
            kv_c_normed,
            k_pe,
            self_kv_cache,
            attn_metadata,
            output=output,
        )
        return output

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Post-process the impl's weights (if it supports that) and scales."""
        if hasattr(self.impl, "process_weights_after_loading"):
            self.impl.process_weights_after_loading(act_dtype)

        # When no quantized weights are loaded, the scales go back to their
        # default of 1.0. See the note "Register q/k/v/prob scales in state
        # dict" in _init_kv_cache_quant for more details.
        if self.quant_config:
            quant_method = self.quant_config.get_quant_method(
                self, prefix=self.layer_name
            )
        else:
            quant_method = None
        if not should_load_quant_weights(quant_method):
            set_default_quant_scales(self, register_buffer=False)

    def calc_kv_scales(
        self, q: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor
    ) -> None:
        """Optional scale calculation for MLA inputs.

        Mirrors Attention.calc_kv_scales. Not all MLA backends require this.
        ``k_pe`` is accepted but not used here.
        """
        # Use safe defaults if ranges are not present
        q_range = getattr(self, "q_range", torch.tensor(1.0))
        k_range = getattr(self, "k_range", torch.tensor(1.0))
        v_range = getattr(self, "v_range", torch.tensor(1.0))

        self._q_scale.copy_(torch.abs(q).max() / q_range)
        # kv_c_normed is the compressed KV representation; its absolute
        # maximum drives both the k and the v scale.
        kv_abs_max = torch.abs(kv_c_normed).max()
        for scale, value_range in ((self._k_scale, k_range), (self._v_scale, v_range)):
            scale.copy_(kv_abs_max / value_range)

        self._q_scale_float = self._q_scale.item()
        self._k_scale_float = self._k_scale.item()
        self._v_scale_float = self._v_scale.item()
        # Computed once; later forward passes skip the calculation.
        self.calculate_kv_scales = False

    def get_attn_backend(self) -> type[AttentionBackend]:
        """Return the attention backend class used by this layer."""
        return self.attn_backend

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        """Describe the KV cache this layer needs as an ``MLAAttentionSpec``."""
        cache_config = vllm_config.cache_config
        torch_dtype = kv_cache_dtype_str_to_dtype(
            self.kv_cache_dtype, vllm_config.model_config
        )
        return MLAAttentionSpec(
            block_size=cache_config.block_size,
            num_kv_heads=1,
            head_size=self.head_size,
            dtype=torch_dtype,
            cache_dtype_str=cache_config.cache_dtype,
        )

710

711
712
713
714
715
716
717
def maybe_calc_kv_scales(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> None:
    """Compute KV scales for a layer if that layer still requests it.

    The layer is looked up by name in the current forward context's
    ``no_compile_layers`` mapping. Nothing happens unless the layer's
    ``calculate_kv_scales`` flag is truthy; otherwise the work is delegated
    to the layer's own ``calc_kv_scales`` method.

    Args:
        query: Forwarded as the first argument of ``calc_kv_scales``.
        key: Forwarded as the second argument of ``calc_kv_scales``.
        value: Forwarded as the third argument of ``calc_kv_scales``.
        layer_name: Key of the target layer in ``no_compile_layers``.
    """
    forward_context: ForwardContext = get_forward_context()
    layer = forward_context.no_compile_layers[layer_name]

    # Only calculate while the layer's calculate_kv_scales flag is set;
    # the layer clears the flag once its scales have been computed.
    if not layer.calculate_kv_scales:
        return

    layer.calc_kv_scales(query, key, value)


def maybe_calc_kv_scales_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> None:
    """Fake implementation of ``maybe_calc_kv_scales``; does nothing.

    Takes the same arguments as the real op and ignores all of them.
    """
    return None


# Register maybe_calc_kv_scales as a custom op, with
# maybe_calc_kv_scales_fake as its fake implementation.
# NOTE(review): query/key/value are declared as mutated although the op
# body in this file never writes to them; presumably this keeps the call
# from being treated as side-effect free - confirm before changing.
direct_register_custom_op(
    op_name="maybe_calc_kv_scales",
    op_func=maybe_calc_kv_scales,
    mutates_args=["query", "key", "value"],
    fake_impl=maybe_calc_kv_scales_fake,
)


def get_attention_context(
    layer_name: str,
) -> tuple[dict | object | None, Attention | MLAAttention, torch.Tensor]:
    """Extract attention context for a given layer.

    This helper function extracts the attention metadata, attention layer
    instance, and KV cache tensor for a specific layer from the current
    forward context.

    Args:
        layer_name: The name/identifier of the attention layer.

    Returns:
        A tuple containing:
        - attn_metadata: Attention metadata for this specific layer, or None if
            no metadata available
        - attn_layer: The attention layer instance (Attention or MLAAttention)
        - kv_cache: The KV cache tensor for current virtual engine

        Note: attn_metadata may be None, but attn_layer and kv_cache are always
        extracted from the forward context.
    """
    forward_context: ForwardContext = get_forward_context()

    # The context's metadata is either a per-layer dict, in which case this
    # layer's entry is selected, or a single object that is passed through
    # unchanged.
    attn_metadata = forward_context.attn_metadata
    if isinstance(attn_metadata, dict):
        attn_metadata = attn_metadata[layer_name]

    attn_layer: Attention | MLAAttention = forward_context.no_compile_layers[layer_name]
    # The layer's kv_cache is indexed by the context's virtual engine.
    kv_cache = attn_layer.kv_cache[forward_context.virtual_engine]
    return attn_metadata, attn_layer, kv_cache


@maybe_transfer_kv_layer
def unified_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Run the named layer's attention implementation and return its output.

    The attention metadata, layer instance and KV cache are resolved from
    the current forward context via ``get_attention_context``; the call is
    then delegated to ``impl.forward`` on that layer.

    Args:
        query: Query tensor handed to the implementation.
        key: Key tensor handed to the implementation.
        value: Value tensor handed to the implementation.
        layer_name: Name of the attention layer to look up.

    Returns:
        Whatever the layer's ``impl.forward`` returns.
    """
    attn_metadata, attn_layer, kv_cache = get_attention_context(layer_name)
    return attn_layer.impl.forward(
        attn_layer, query, key, value, kv_cache, attn_metadata
    )


def unified_attention_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Fake implementation of ``unified_attention``.

    Performs no attention computation. Only ``query`` is consulted: the
    result is an uninitialized, contiguous tensor shaped and typed like it.
    """
    placeholder = torch.empty_like(query)
    return placeholder.contiguous()


# Register unified_attention as a custom op, with unified_attention_fake
# as its fake implementation. No mutates_args is passed here, so whatever
# default direct_register_custom_op applies is used (not shown in this file).
direct_register_custom_op(
    op_name="unified_attention",
    op_func=unified_attention,
    fake_impl=unified_attention_fake,
)


@maybe_transfer_kv_layer
def unified_attention_with_output(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
) -> None:
    """Run the named layer's attention implementation with a caller buffer.

    The attention metadata, layer instance and KV cache are resolved from
    the current forward context via ``get_attention_context``; the call is
    then delegated to ``impl.forward`` on that layer. Nothing is returned;
    ``output`` is handed to the implementation, which is presumably
    expected to fill it in place (confirm against the backend).

    Args:
        query: Query tensor handed to the implementation.
        key: Key tensor handed to the implementation.
        value: Value tensor handed to the implementation.
        output: Passed to ``impl.forward`` as ``output=``.
        layer_name: Name of the attention layer to look up.
        output_scale: Optional tensor forwarded as ``output_scale=``.
        output_block_scale: Optional tensor forwarded as
            ``output_block_scale=``.
    """
    attn_metadata, attn_layer, kv_cache = get_attention_context(layer_name)
    attn_layer.impl.forward(
        attn_layer,
        query,
        key,
        value,
        kv_cache,
        attn_metadata,
        output=output,
        output_scale=output_scale,
        output_block_scale=output_block_scale,
    )


def unified_attention_with_output_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
835
836
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
837
838
839
840
841
842
843
) -> None:
    return


# Register unified_attention_with_output as a custom op, with
# unified_attention_with_output_fake as its fake implementation.
# "output" and "output_block_scale" are the arguments declared as mutated;
# "output_scale" is not.
direct_register_custom_op(
    op_name="unified_attention_with_output",
    op_func=unified_attention_with_output,
    mutates_args=["output", "output_block_scale"],
    fake_impl=unified_attention_with_output_fake,
)


@maybe_transfer_kv_layer
def unified_mla_attention(
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Run the named MLA layer's implementation and return its output.

    The attention metadata, layer instance and KV cache are resolved from
    the current forward context via ``get_attention_context``; the call is
    then delegated to ``impl.forward`` on that layer.

    Args:
        q: Query tensor handed to the implementation.
        kv_c_normed: Handed to the implementation as its second tensor input.
        k_pe: Handed to the implementation as its third tensor input.
        layer_name: Name of the attention layer to look up.

    Returns:
        Whatever the layer's ``impl.forward`` returns.
    """
    attn_metadata, attn_layer, kv_cache = get_attention_context(layer_name)
    return attn_layer.impl.forward(
        attn_layer, q, kv_c_normed, k_pe, kv_cache, attn_metadata
    )


def unified_mla_attention_fake(
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Fake implementation of ``unified_mla_attention``.

    Performs no attention computation. Only ``q`` is consulted: the result
    is an uninitialized, contiguous tensor shaped and typed like it.
    """
    placeholder = torch.empty_like(q)
    return placeholder.contiguous()


# Register unified_mla_attention as a custom op, with
# unified_mla_attention_fake as its fake implementation. No argument is
# declared as mutated, and the dispatch key is taken from the current
# platform.
direct_register_custom_op(
    op_name="unified_mla_attention",
    op_func=unified_mla_attention,
    mutates_args=[],
    fake_impl=unified_mla_attention_fake,
    dispatch_key=current_platform.dispatch_key,
)


@maybe_transfer_kv_layer
def unified_mla_attention_with_output(
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
) -> None:
    """Run the named MLA layer's implementation with a caller buffer.

    The attention metadata, layer instance and KV cache are resolved from
    the current forward context via ``get_attention_context``; the call is
    then delegated to ``impl.forward`` on that layer. Nothing is returned;
    ``output`` is handed to the implementation, which is presumably
    expected to fill it in place (confirm against the backend).

    Args:
        q: Query tensor handed to the implementation.
        kv_c_normed: Handed to the implementation as its second tensor input.
        k_pe: Handed to the implementation as its third tensor input.
        output: Passed to ``impl.forward`` as ``output=``.
        layer_name: Name of the attention layer to look up.
        output_scale: Optional tensor forwarded as ``output_scale=``.
        output_block_scale: Optional tensor forwarded as
            ``output_block_scale=``.
    """
    attn_metadata, attn_layer, kv_cache = get_attention_context(layer_name)
    attn_layer.impl.forward(
        attn_layer,
        q,
        kv_c_normed,
        k_pe,
        kv_cache,
        attn_metadata,
        output=output,
        output_scale=output_scale,
        output_block_scale=output_block_scale,
    )


def unified_mla_attention_with_output_fake(
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
910
911
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
912
913
914
915
916
917
918
919
920
921
922
) -> None:
    return


# Register unified_mla_attention_with_output as a custom op, with
# unified_mla_attention_with_output_fake as its fake implementation.
# "output" and "output_block_scale" are the arguments declared as mutated;
# the dispatch key is taken from the current platform.
direct_register_custom_op(
    op_name="unified_mla_attention_with_output",
    op_func=unified_mla_attention_with_output,
    mutates_args=["output", "output_block_scale"],
    fake_impl=unified_mla_attention_with_output_fake,
    dispatch_key=current_platform.dispatch_key,
)