layer.py 33.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Attention layer."""
4

5
from typing import cast
6
7
8
9

import torch
import torch.nn as nn

10
import vllm.envs as envs
11
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
12
from vllm.attention.utils.kv_transfer_utils import maybe_transfer_kv_layer
13
from vllm.config import CacheConfig, get_current_vllm_config
14
from vllm.config.vllm import VllmConfig
15
from vllm.forward_context import ForwardContext, get_forward_context
16
from vllm.logger import init_logger
17
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
18
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
19
20
21
22
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    UnquantizedLinearMethod,
)
23
from vllm.model_executor.layers.quantization import QuantizationConfig
24
from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
25
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
26
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
27
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
28
from vllm.platforms import current_platform
29
from vllm.utils.torch_utils import (
30
31
32
    direct_register_custom_op,
    kv_cache_dtype_str_to_dtype,
)
33
34
35
36
37
38
39
from vllm.v1.attention.backend import (
    AttentionBackend,
    AttentionType,
    MLAAttentionImpl,
)
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.selector import get_attn_backend
40
41
42
43
44
45
from vllm.v1.kv_cache_interface import (
    FullAttentionSpec,
    KVCacheSpec,
    MLAAttentionSpec,
    SlidingWindowSpec,
)
46

47
48
# Module-level logger; used below for one-time warnings such as disabling
# prefix caching under batch invariance.
logger = init_logger(__name__)

49

50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
def should_load_quant_weights(quant_method: QuantizeMethodBase | None) -> bool:
    """Tell whether ``quant_method`` expects quantized weights from the checkpoint.

    Args:
        quant_method: The quantization method resolved for a layer, or None.

    Returns:
        False when there is no method or the method is the unquantized
        linear method; True for any other method.
    """
    if quant_method is None:
        return False
    # The unquantized method is a placeholder: nothing quantized to load.
    return not isinstance(quant_method, UnquantizedLinearMethod)


def set_default_quant_scales(layer: nn.Module, register_buffer: bool = False) -> None:
    """Reset a layer's attention quantization scales to the identity value 1.0.

    Args:
        layer: Attention layer whose scales are (re)initialized.
        register_buffer: When True, create the ``_k_scale``, ``_v_scale``,
            ``_q_scale`` and ``_prob_scale`` float32 tensors as registered
            buffers. When False, those buffers must already exist and are
            overwritten in place with 1.0.

    Besides the scale tensors, this also resets the host-side float copies
    (``_*_scale_float``) and recreates the ``q/k/v_range`` constants that
    ``calc_kv_scales`` divides by.
    """
    # Registration order (k, v, q, prob) is kept stable on purpose.
    for buffer_name in ("_k_scale", "_v_scale", "_q_scale", "_prob_scale"):
        if register_buffer:
            layer.register_buffer(
                buffer_name, torch.tensor(1.0, dtype=torch.float32)
            )
        else:
            getattr(layer, buffer_name).fill_(1.0)

    # Plain Python floats kept on the host (CPU) for attention backends that
    # need the scales on host instead of on device, e.g. Flashinfer.
    for float_name in (
        "_q_scale_float",
        "_k_scale_float",
        "_v_scale_float",
        "_prob_scale_float",
    ):
        setattr(layer, float_name, 1.0)

    # Range constants that calc_kv_scales divides the observed abs-max by.
    layer.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
    layer.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
    layer.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)

83

84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
def _init_kv_cache_quant(
    layer: nn.Module,
    quant_config: QuantizationConfig | None,
    prefix: str,
) -> None:
    """Initializes KV cache scaling factors and quantization method.

    This helper sets up the KV cache quantization attributes shared between
    the Attention and MLAAttention layers. It registers the query, key, value
    and probability scale buffers (initialized to 1.0), resets the host-side
    output scale, and, when the quantization config supplies a method that
    loads quantized weights, attaches that method to the layer and lets it
    create its checkpoint-loadable weights.

    Args:
        layer: The attention layer instance to initialize. Its
            ``kv_cache_dtype`` attribute is read when a quantization method
            that loads weights is found.
        quant_config: Optional quantization configuration.
        prefix: Layer name prefix for quantization method lookup.

    Raises:
        ValueError: If a weight-loading quantization method is combined with
            an ``fp8_e5m2`` KV cache.
    """
    # [Note: Register q/k/v/prob scales in state dict]
    # When calling model.to(device), only parameters/buffers in state dict are
    # moved. If not registering q/k/v/prob scales in state dict, there would
    # be an IMA error when a cuda kernel (e.g., quant_fp8) accesses the tensor
    # on cpu.
    # Registering in state dict means it interacts with weight loading. One edge
    # case is when quant_method is None, or quant_method is UnquantizedLinearMethod
    # (i.e., should_load_quant_weights(quant_method) == False).
    # In this case, the checkpoint does not have the scales. We need to
    # initialize the scales to 1.0 and update the scales after weight loading.
    # This is especially important when we load dummy weights first (providing
    # wrong scales) and then load real weights (which misses scales and keeps the
    # wrong scales from dummy load).
    set_default_quant_scales(layer, register_buffer=True)

    # The output scale on host memory. This should be the input scale of
    # the quant op after this attention layer.
    layer._o_scale_float = None

    # Resolve the quantization method exactly once (it was previously also
    # queried, and discarded, before the scales were registered).
    quant_method = (
        quant_config.get_quant_method(layer, prefix=prefix) if quant_config else None
    )

    # See [Note: Register q/k/v/prob scales in state dict]
    if should_load_quant_weights(quant_method):
        assert isinstance(quant_method, BaseKVCacheMethod)
        # TODO (mgoin): kv cache dtype should be specified in the FP8
        # checkpoint config and become the "auto" behavior
        if layer.kv_cache_dtype == "fp8_e5m2":
            raise ValueError("fp8_e5m2 kv-cache is not supported with fp8 checkpoints.")
        # If quantization is enabled, we make "k_scale" and "v_scale"
        # parameters so that it can be loaded from the model checkpoint.
        # The k/v_scale will then be converted back to native float32
        # values after weight loading.
        layer.quant_method = quant_method
        layer.quant_method.create_weights(layer)


143
class Attention(nn.Module, AttentionLayerBase):
    """Attention layer.

    This class takes query, key, and value tensors as input. The input tensors
    can either contain prompt tokens or generation tokens.
    The class does the following:

    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
        alibi_slopes: list[float] | None = None,
        use_alibi_sqrt: bool | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        logits_soft_cap: float | None = None,
        per_layer_sliding_window: int | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
        attn_backend: type[AttentionBackend] | None = None,
        head_size_v: int | None = None,
        **extra_impl_args,
    ) -> None:
        """Build the layer, select its backend and construct the backend impl.

        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        Args:
            num_heads: Number of query heads.
            head_size: Per-head size of queries and keys.
            scale: Attention scale forwarded to the backend impl.
            num_kv_heads: Number of KV heads; defaults to `num_heads`. Must
                divide `num_heads`.
            alibi_slopes: Optional ALiBi slopes forwarded to the impl.
            use_alibi_sqrt: Request the backend's ALiBi-sqrt variant; `None`
                is treated as `False`.
            cache_config: Cache settings (dtype, block size, sliding window,
                prefix caching). When None, "auto" dtype and block size 16
                are used.
            quant_config: Optional quantization config. If it defines a
                `kv_cache_scheme`, the KV cache dtype is forced to "fp8" (and
                `cache_config` is updated to match).
            logits_soft_cap: Optional soft cap forwarded to the impl.
            per_layer_sliding_window: Overrides the model-level sliding window.
            prefix: Unique layer name; used as the key in the compilation
                config's `static_forward_context`.
            attn_type: Attention type forwarded to backend selection and impl.
            kv_sharing_target_layer_name: Optional layer whose KV cache this
                layer shares.
            attn_backend: Explicit backend class; selected automatically when
                None.
            head_size_v: Per-head size of values; defaults to `head_size`.
            **extra_impl_args: Extra keyword arguments for the backend impl.

        Raises:
            ValueError: If `use_alibi_sqrt` is requested but the backend does
                not support it, or if `prefix` is already registered.
        """
        super().__init__()
        if per_layer_sliding_window is not None:
            # per-layer sliding window
            sliding_window = per_layer_sliding_window
        elif cache_config is not None:
            # model-level sliding window
            sliding_window = cache_config.sliding_window
        else:
            sliding_window = None

        vllm_config = get_current_vllm_config()
        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            calculate_kv_scales = cache_config.calculate_kv_scales
        else:
            kv_cache_dtype = "auto"
            block_size = 16
            calculate_kv_scales = False

        # llm-compressor mdls need to set cache_dtype to "fp8" manually.
        if getattr(quant_config, "kv_cache_scheme", None) is not None:
            kv_cache_dtype = "fp8"
            calculate_kv_scales = False
            if cache_config is not None:
                cache_config.cache_dtype = "fp8"
                cache_config.calculate_kv_scales = False

        self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
            kv_cache_dtype, vllm_config.model_config
        )
        self.kv_cache_dtype = kv_cache_dtype
        self.calculate_kv_scales = calculate_kv_scales
        if num_kv_heads is None:
            num_kv_heads = num_heads
        assert num_heads % num_kv_heads == 0, (
            f"num_heads ({num_heads}) is not divisible by num_kv_heads ({num_kv_heads})"
        )
        self.quant_config = quant_config
        self.layer_name = prefix

        self.num_heads = num_heads
        self.head_size = head_size
        self.head_size_v = self.head_size if head_size_v is None else head_size_v
        self.num_kv_heads = num_kv_heads
        self.sliding_window = sliding_window
        self.has_sink = extra_impl_args.get("sinks") is not None

        # NOTE: model_config may be None during certain tests
        model_config = vllm_config.model_config
        self.use_mm_prefix = model_config is not None and model_config.is_mm_prefix_lm

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()
        if attn_backend is None:
            self.attn_backend = get_attn_backend(
                head_size,
                dtype,
                kv_cache_dtype,
                block_size,
                use_mla=False,
                has_sink=self.has_sink,
                use_mm_prefix=self.use_mm_prefix,
                attn_type=attn_type,
            )
        else:
            self.attn_backend = attn_backend

        backend_supports_alibi_sqrt = self.attn_backend.supports_alibi_sqrt()
        # `None` (the default) behaves exactly like `False`.
        self.use_alibi_sqrt = bool(use_alibi_sqrt)
        if self.use_alibi_sqrt and not backend_supports_alibi_sqrt:
            raise ValueError(
                f"use_alibi_sqrt is not supported by backend "
                f"{self.attn_backend.get_name()}."
            )
        # Only backends that understand the flag receive it.
        if backend_supports_alibi_sqrt:
            extra_impl_args["use_alibi_sqrt"] = self.use_alibi_sqrt

        # prefix caching + batch invariance is currently not supported for
        # FLASHINFER and TRITON_MLA.
        if (
            cache_config is not None
            and cache_config.enable_prefix_caching
            and vllm_is_batch_invariant()
            and self.attn_backend.get_name() in ("FLASHINFER", "TRITON_MLA")
        ):
            logger.warning_once(
                "Disabling prefix caching for FLASHINFER/TRITON_MLA "
                "with batch invariance, as it is not yet supported.",
                scope="local",
            )
            cache_config.enable_prefix_caching = False

        impl_cls = self.attn_backend.get_impl_cls()
        self.impl = impl_cls(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            sliding_window,
            kv_cache_dtype,
            logits_soft_cap,
            attn_type,
            kv_sharing_target_layer_name,
            **extra_impl_args,
        )
        self.backend = AttentionBackendEnum[self.attn_backend.get_name()]
        self.dtype = dtype

        # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
        # torch.compile works by registering the attention as one giant
        # opaque custom op. For other platforms, we directly call them
        # and let torch.compile handle them.
        self.use_direct_call = not current_platform.opaque_attention_op()

        self.use_output = self.attn_backend.accept_output_buffer
        compilation_config = vllm_config.compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self
        self.attn_type = attn_type

        if kv_sharing_target_layer_name is not None:
            validate_kv_sharing_target(
                prefix,
                kv_sharing_target_layer_name,
                compilation_config.static_forward_context,
            )
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        # Use placeholder KV cache tensors (one per pipeline-parallel virtual
        # engine) during init; they are replaced by bind_kv_cache. On the
        # direct-call path, `forward` indexes this list with
        # `forward_context.virtual_engine`.
        self.kv_cache = [
            torch.tensor([])
            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
        ]

        # Initialize KV cache quantization attributes
        _init_kv_cache_quant(self, quant_config, prefix)

        # for attn backends supporting query quantization
        self.query_quant = None
        if self.impl.supports_quant_query_input and self.kv_cache_dtype.startswith(
            "fp8"
        ):
            is_per_head = (
                hasattr(self, "q_scale") and self.q_scale.numel() == self.num_kv_heads
            )
            # With per-head scales (one per KV head), each quantization group
            # covers the query heads that share one KV head. Named separately
            # so it does not shadow the KV-cache `block_size` above.
            query_quant_group_size = (
                self.head_size * self.num_heads // self.num_kv_heads
            )
            self.query_quant = QuantFP8(
                static=True,
                group_shape=GroupShape(-1, query_quant_group_size)
                if is_per_head
                else GroupShape.PER_TENSOR,
            )

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        # For some alternate attention backends like MLA the attention output
        # shape does not match the query shape, so we optionally let the model
        # definition specify the output tensor shape.
        output_shape: torch.Size | None = None,
    ) -> torch.Tensor:
        """
        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        Attention metadata (`attn_metadata`) is set using a context manager in
        the model runner's `execute_model` method. It is accessed via forward
        context using
        `vllm.forward_context.get_forward_context().attn_metadata`.

        When the backend accepts an output buffer, the result is returned as a
        2D view `(-1, output_shape[-1])`; `output_shape` defaults to
        `(num_tokens, num_heads * head_size_v)`.
        """
        if self.calculate_kv_scales:
            torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name)

        output_dtype = query.dtype
        if self.query_quant is not None:
            # quantizing with a simple torch operation enables
            # torch.compile to fuse this into previous ops
            # which reduces overheads during decoding.
            # Otherwise queries are quantized using custom ops
            # which causes decoding overheads
            # NOTE(review): __init__ creates query_quant for any kv_cache_dtype
            # starting with "fp8", but this assert only accepts "fp8" and
            # "fp8_e4m3" (e.g. "fp8_e5m2" would trip it) — confirm no backend
            # with supports_quant_query_input accepts fp8_e5m2.
            assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}

            # check if query quantization is supported
            if self.impl.supports_quant_query_input:
                query, _ = self.query_quant(query, self._q_scale)

        if self.use_output:
            if output_shape is None:
                # Handle both 2D [num_tokens, hidden] and
                # 3D [num_tokens, heads, head_dim] query
                num_tokens = query.shape[0]
                output_shape = torch.Size(
                    (num_tokens, self.num_heads * self.head_size_v)
                )
            output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
            hidden_size = output_shape[-1]
            # Reshape the query, key, and value tensors.
            # NOTE(woosuk): We do this outside the custom op to minimize the
            # CPU overheads from the non-CUDA-graph regions.
            query = query.view(-1, self.num_heads, self.head_size)
            output = output.view(-1, self.num_heads, self.head_size_v)
            if key is not None:
                key = key.view(-1, self.num_kv_heads, self.head_size)
            if value is not None:
                value = value.view(-1, self.num_kv_heads, self.head_size_v)
            if self.use_direct_call:
                forward_context: ForwardContext = get_forward_context()
                attn_metadata = forward_context.attn_metadata
                if isinstance(attn_metadata, dict):
                    attn_metadata = attn_metadata[self.layer_name]
                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                self.impl.forward(
                    self, query, key, value, self_kv_cache, attn_metadata, output=output
                )
            else:
                torch.ops.vllm.unified_attention_with_output(
                    query, key, value, output, self.layer_name
                )
            return output.view(-1, hidden_size)
        else:
            if self.use_direct_call:
                forward_context = get_forward_context()
                attn_metadata = forward_context.attn_metadata
                if isinstance(attn_metadata, dict):
                    attn_metadata = attn_metadata[self.layer_name]
                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                return self.impl.forward(
                    self, query, key, value, self_kv_cache, attn_metadata
                )
            else:
                return torch.ops.vllm.unified_attention(
                    query, key, value, self.layer_name
                )

    def calc_kv_scales(self, query, key, value):
        """Set q/k/v scales once from the abs-max of the given tensors.

        Each scale is `abs(x).max() / <x>_range`; the device tensors are updated
        in place, the host float copies are refreshed, and
        `calculate_kv_scales` is cleared so this runs only once.
        """
        self._q_scale.copy_(torch.abs(query).max() / self.q_range)
        self._k_scale.copy_(torch.abs(key).max() / self.k_range)
        self._v_scale.copy_(torch.abs(value).max() / self.v_range)
        self._q_scale_float = self._q_scale.item()
        self._k_scale_float = self._k_scale.item()
        self._v_scale_float = self._v_scale.item()
        # We only calculate the scales once
        self.calculate_kv_scales = False

    def extra_repr(self) -> str:
        s = f"head_size={self.impl.head_size}"  # type: ignore
        s += f", num_heads={self.impl.num_heads}"  # type: ignore
        s += f", num_kv_heads={self.impl.num_kv_heads}"  # type: ignore
        s += f", scale={self.impl.scale}"  # type: ignore
        s += f", backend={self.impl.__class__.__name__}"
        return s

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Let the impl post-process weights, then fix up default scales."""
        self.impl.process_weights_after_loading(act_dtype)

        # If we should not load quant weights, we initialize the scales to 1.0
        # as the default value. See [Note: Register q/k/v/prob scales in state dict]
        # for more details.
        quant_method = (
            self.quant_config.get_quant_method(self, prefix=self.layer_name)
            if self.quant_config
            else None
        )
        if not should_load_quant_weights(quant_method):
            set_default_quant_scales(self, register_buffer=False)

    def get_attn_backend(self) -> type[AttentionBackend]:
        return self.attn_backend

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        """Describe this layer's KV cache (sliding-window or full attention)."""
        # Block size may get updated after model loading, refresh it
        block_size = vllm_config.cache_config.block_size
        # Should not be called for enc-dec or encoder-only attention.
        assert self.attn_type == AttentionType.DECODER
        if self.sliding_window is not None:
            assert not vllm_config.model_config.use_mla, (
                "MLA is not supported for sliding window"
            )
            # NOTE(review): unlike FullAttentionSpec below, head_size_v is not
            # passed here — confirm sliding-window layers never have
            # head_size_v != head_size.
            return SlidingWindowSpec(
                block_size=block_size,
                num_kv_heads=self.num_kv_heads,
                head_size=self.head_size,
                dtype=self.kv_cache_torch_dtype,
                sliding_window=self.sliding_window,
            )
        else:
            return FullAttentionSpec(
                block_size=block_size,
                num_kv_heads=self.num_kv_heads,
                head_size=self.head_size,
                head_size_v=self.head_size_v,
                dtype=self.kv_cache_torch_dtype,
            )

481

482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
class MLAAttention(nn.Module, AttentionLayerBase):
    """Multi-Head Latent Attention layer.

    This class takes the query and the compressed key/value tensors as input.
    The class does the following:

    1. Store the compressed key/value inputs in the KV cache.
    2. Perform multi-head latent attention.
    3. Return the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        scale: float,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int | None,
        kv_lora_rank: int,
        kv_b_proj: ColumnParallelLinear,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_sparse: bool = False,
        indexer: object | None = None,
        **extra_impl_args,
    ):
        """Build the MLA layer, select an MLA backend and construct its impl.

        The KV cache head size is `kv_lora_rank + qk_rope_head_dim` with a
        single KV head.

        Args:
            num_heads: Number of query heads.
            scale: Attention scale forwarded to the impl.
            qk_nope_head_dim: Query/key head dim without rotary embedding.
            qk_rope_head_dim: Query/key head dim with rotary embedding.
            v_head_dim: Value head dim; also sizes the default output.
            q_lora_rank: Optional query low-rank dimension.
            kv_lora_rank: Rank of the compressed KV representation.
            kv_b_proj: Projection forwarded to the impl.
            cache_config: Cache settings; "auto" dtype and block size 16 when
                None.
            quant_config: Optional quantization config.
            prefix: Unique layer name used to register this layer.
            use_sparse: Select a sparse MLA backend.
            indexer: Optional object forwarded to the impl.
            **extra_impl_args: Extra keyword arguments for the impl.

        Raises:
            ValueError: If `prefix` is already registered.
        """
        super().__init__()
        self.num_heads = num_heads
        self.scale = scale
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.head_size = kv_lora_rank + qk_rope_head_dim
        self.layer_name = prefix

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            calculate_kv_scales = cache_config.calculate_kv_scales
        else:
            kv_cache_dtype = "auto"
            block_size = 16
            calculate_kv_scales = False
        self.quant_config = quant_config

        # Initialize KV cache quantization attributes. `_init_kv_cache_quant`
        # reads `self.kv_cache_dtype`, so it must be set first. This also
        # creates the q/k/v range constants used by `calc_kv_scales`.
        self.kv_cache_dtype = kv_cache_dtype
        self.calculate_kv_scales = calculate_kv_scales
        _init_kv_cache_quant(self, quant_config, prefix)

        dtype = torch.get_default_dtype()
        self.attn_backend = get_attn_backend(
            self.head_size,
            dtype,
            kv_cache_dtype,
            block_size,
            use_mla=True,
            use_sparse=use_sparse,
        )

        if (
            cache_config is not None
            and cache_config.enable_prefix_caching
            and vllm_is_batch_invariant()
            and self.attn_backend.get_name() in ("TRITON_MLA", "FLASHINFER")
        ):
            logger.warning_once(
                "Disabling prefix caching for TRITON_MLA / FLASHINFER "
                "with batch invariance, as it is not yet supported.",
                scope="local",
            )
            cache_config.enable_prefix_caching = False

        impl_cls = cast(type[MLAAttentionImpl], self.attn_backend.get_impl_cls())
        self.impl = impl_cls(
            num_heads=self.num_heads,
            head_size=self.head_size,
            scale=self.scale,
            num_kv_heads=1,
            alibi_slopes=None,
            sliding_window=None,
            kv_cache_dtype=self.kv_cache_dtype,
            logits_soft_cap=None,
            attn_type=AttentionType.DECODER,
            kv_sharing_target_layer_name=None,
            # MLA Args
            q_lora_rank=self.q_lora_rank,
            kv_lora_rank=self.kv_lora_rank,
            qk_nope_head_dim=self.qk_nope_head_dim,
            qk_rope_head_dim=self.qk_rope_head_dim,
            qk_head_dim=self.qk_nope_head_dim + self.qk_rope_head_dim,
            v_head_dim=self.v_head_dim,
            kv_b_proj=kv_b_proj,
            indexer=indexer,
            **extra_impl_args,
        )

        self.use_direct_call = not current_platform.opaque_attention_op()

        vllm_config = get_current_vllm_config()
        compilation_config = vllm_config.compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self

        # Placeholder KV cache tensors, one per pipeline-parallel virtual
        # engine; `forward` indexes them by `forward_context.virtual_engine`.
        self.kv_cache = [
            torch.tensor([])
            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
        ]

        self.use_sparse = use_sparse

    def forward(
        self,
        q: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        output_shape: torch.Size | None = None,
    ) -> torch.Tensor:
        """Run MLA attention for this layer.

        Args:
            q: Query tensor; `q.shape[0]` is taken as the number of tokens.
            kv_c_normed: Compressed KV representation.
            k_pe: Key input passed alongside `kv_c_normed` (presumably the
                rotary-embedded key part — TODO confirm).
            output_shape: Shape of the preallocated output, used only when the
                backend accepts an output buffer. Defaults to
                `(num_tokens, num_heads * v_head_dim)`.

        Returns:
            The attention output tensor.
        """
        if self.calculate_kv_scales:
            torch.ops.vllm.maybe_calc_kv_scales(q, kv_c_normed, k_pe, self.layer_name)

        accept_output = self.attn_backend.accept_output_buffer
        if accept_output and output_shape is None:
            # torch.empty cannot take None; mirror Attention.forward and
            # default to a 2D [num_tokens, num_heads * v_head_dim] buffer.
            output_shape = torch.Size((q.shape[0], self.num_heads * self.v_head_dim))

        if self.use_direct_call:
            forward_context: ForwardContext = get_forward_context()
            attn_metadata = forward_context.attn_metadata
            if isinstance(attn_metadata, dict):
                attn_metadata = attn_metadata[self.layer_name]
            self_kv_cache = self.kv_cache[forward_context.virtual_engine]

            if accept_output:
                output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
                self.impl.forward(
                    self,
                    q,
                    kv_c_normed,
                    k_pe,
                    self_kv_cache,
                    attn_metadata,
                    output=output,
                )
                return output
            else:
                return self.impl.forward(
                    self, q, kv_c_normed, k_pe, self_kv_cache, attn_metadata
                )
        else:
            if accept_output:
                output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
                torch.ops.vllm.unified_mla_attention_with_output(
                    q,
                    kv_c_normed,
                    k_pe,
                    output,
                    self.layer_name,
                )
                return output
            else:
                return torch.ops.vllm.unified_mla_attention(
                    q,
                    kv_c_normed,
                    k_pe,
                    self.layer_name,
                )

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Let the impl post-process weights (if it can), then fix up scales."""
        if hasattr(self.impl, "process_weights_after_loading"):
            self.impl.process_weights_after_loading(act_dtype)

        # If we should not load quant weights, we initialize the scales to 1.0
        # as the default value. See [Note: Register q/k/v/prob scales in state dict]
        # for more details.
        quant_method = (
            self.quant_config.get_quant_method(self, prefix=self.layer_name)
            if self.quant_config
            else None
        )
        if not should_load_quant_weights(quant_method):
            set_default_quant_scales(self, register_buffer=False)

    def calc_kv_scales(
        self, q: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor
    ) -> None:
        """Optional scale calculation for MLA inputs.

        Mirrors Attention.calc_kv_scales. Not all MLA backends require this.
        Both the k and v scales are derived from `kv_c_normed`; `k_pe` is not
        used.
        """
        # Use safe defaults if ranges are not present
        q_range = getattr(self, "q_range", torch.tensor(1.0))
        k_range = getattr(self, "k_range", torch.tensor(1.0))
        v_range = getattr(self, "v_range", torch.tensor(1.0))

        self._q_scale.copy_(torch.abs(q).max() / q_range)
        # kv_c_normed is the compressed KV representation; use it for k/v
        kv_abs_max = torch.abs(kv_c_normed).max()
        self._k_scale.copy_(kv_abs_max / k_range)
        self._v_scale.copy_(kv_abs_max / v_range)
        self._q_scale_float = self._q_scale.item()
        self._k_scale_float = self._k_scale.item()
        self._v_scale_float = self._v_scale.item()
        self.calculate_kv_scales = False

    def get_attn_backend(self) -> type[AttentionBackend]:
        return self.attn_backend

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        """Describe this layer's single-head MLA KV cache."""
        kv_cache_dtype = kv_cache_dtype_str_to_dtype(
            self.kv_cache_dtype, vllm_config.model_config
        )
        return MLAAttentionSpec(
            block_size=vllm_config.cache_config.block_size,
            num_kv_heads=1,
            head_size=self.head_size,
            dtype=kv_cache_dtype,
            cache_dtype_str=vllm_config.cache_config.cache_dtype,
        )

711

712
713
714
715
716
717
718
def maybe_calc_kv_scales(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> None:
    """Compute the named layer's KV scales if it still needs them.

    Looks up ``layer_name`` in the current forward context's
    ``no_compile_layers``. It then calls the layer's
    ``calc_kv_scales(query, key, value)``, but only while the layer's
    ``calculate_kv_scales`` flag is set.

    Args:
        query: Query tensor forwarded to ``calc_kv_scales``.
        key: Key tensor forwarded to ``calc_kv_scales``.
        value: Value tensor forwarded to ``calc_kv_scales``.
        layer_name: Key of the layer in ``no_compile_layers``.
    """
    forward_context: ForwardContext = get_forward_context()
    self = forward_context.no_compile_layers[layer_name]

    # Only calculate if the layer's calculate_kv_scales flag is True.
    # calc_kv_scales clears the flag, so this runs once per layer.
    if not self.calculate_kv_scales:
        return

    self.calc_kv_scales(query, key, value)


def maybe_calc_kv_scales_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> None:
    """No-op fake implementation registered for ``maybe_calc_kv_scales``.

    The real op returns nothing (it only updates layer state), so there is
    no output to fabricate.
    """
    return None


# Register maybe_calc_kv_scales as a custom op with its no-op fake.
# NOTE(review): the visible calc_kv_scales only reads its inputs. Listing
# query/key/value as mutated is presumably deliberate (e.g. to keep the op
# from being eliminated or reordered); confirm before changing.
direct_register_custom_op(
    op_name="maybe_calc_kv_scales",
    op_func=maybe_calc_kv_scales,
    fake_impl=maybe_calc_kv_scales_fake,
    mutates_args=["query", "key", "value"],
)


746
def get_attention_context(
    layer_name: str,
) -> tuple[dict | object | None, Attention | MLAAttention, torch.Tensor]:
    """Extract attention context for a given layer.

    Looks up the attention metadata, the attention layer instance, and the
    KV cache tensor for ``layer_name`` from the current forward context.

    Args:
        layer_name: The name/identifier of the attention layer.

    Returns:
        A tuple ``(attn_metadata, attn_layer, kv_cache)``:
        - attn_metadata: Metadata for this layer. If the forward context holds
          a dict, this is the entry for ``layer_name``. Otherwise it is the
          context's metadata as-is, which may be None.
        - attn_layer: The attention layer instance (Attention or MLAAttention).
        - kv_cache: The layer's KV cache for the current virtual engine.

    Raises:
        KeyError: If ``layer_name`` is missing from the per-layer metadata
            dict or from ``no_compile_layers``.
    """
    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    # Metadata may be keyed per layer; pick out this layer's entry.
    if isinstance(attn_metadata, dict):
        attn_metadata = attn_metadata[layer_name]
    attn_layer: Attention | MLAAttention = forward_context.no_compile_layers[layer_name]
    kv_cache = attn_layer.kv_cache[forward_context.virtual_engine]
    return attn_metadata, attn_layer, kv_cache


@maybe_transfer_kv_layer
def unified_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Run the attention layer named ``layer_name`` and return its output.

    The layer, its metadata and its KV cache are resolved through
    ``get_attention_context``. The computation is delegated to the layer's
    ``impl.forward``.
    """
    attn_metadata, self, kv_cache = get_attention_context(layer_name)
    output = self.impl.forward(self, query, key, value, kv_cache, attn_metadata)

    return output
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802


def unified_attention_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Fake implementation registered for ``unified_attention``.

    Returns an uninitialized contiguous tensor with the same shape, dtype and
    device as ``query``.
    """
    placeholder = torch.empty_like(query, memory_format=torch.contiguous_format)
    return placeholder.contiguous()


# Register unified_attention as a custom op; unified_attention_fake is the
# fake implementation that only fabricates an output tensor.
direct_register_custom_op(
    fake_impl=unified_attention_fake,
    op_func=unified_attention,
    op_name="unified_attention",
)
803
804


805
@maybe_transfer_kv_layer
def unified_attention_with_output(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
) -> None:
    """Run the attention layer named ``layer_name``, writing into ``output``.

    The layer, its metadata and its KV cache are resolved through
    ``get_attention_context``. The layer's ``impl.forward`` receives
    ``output`` together with the optional ``output_scale`` and
    ``output_block_scale``. These buffers are declared mutated when the op is
    registered, and nothing is returned.
    """
    attn_metadata, self, kv_cache = get_attention_context(layer_name)
    self.impl.forward(
        self,
        query,
        key,
        value,
        kv_cache,
        attn_metadata,
        output=output,
        output_scale=output_scale,
        output_block_scale=output_block_scale,
    )
828
829
830
831
832
833
834
835


def unified_attention_with_output_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
836
837
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
838
839
840
841
842
843
844
) -> None:
    return


# Register unified_attention_with_output as a custom op. The output buffers
# it fills are declared as mutated arguments.
direct_register_custom_op(
    op_name="unified_attention_with_output",
    op_func=unified_attention_with_output,
    mutates_args=["output", "output_block_scale"],
    fake_impl=unified_attention_with_output_fake,
)
848
849


850
@maybe_transfer_kv_layer
def unified_mla_attention(
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Run the MLA layer named ``layer_name`` and return its output.

    The layer, its metadata and its KV cache are resolved through
    ``get_attention_context``. ``q``, ``kv_c_normed`` and ``k_pe`` are passed
    straight to the layer's ``impl.forward``.
    """
    attn_metadata, self, kv_cache = get_attention_context(layer_name)
    output = self.impl.forward(self, q, kv_c_normed, k_pe, kv_cache, attn_metadata)

    return output


def unified_mla_attention_fake(
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Fake implementation registered for ``unified_mla_attention``.

    Returns an uninitialized contiguous tensor with the same shape, dtype and
    device as ``q``.
    """
    placeholder = torch.empty_like(q, memory_format=torch.contiguous_format)
    return placeholder.contiguous()


# Register unified_mla_attention as a custom op. It declares no mutated
# arguments and is dispatched under the current platform's dispatch key.
direct_register_custom_op(
    op_name="unified_mla_attention",
    op_func=unified_mla_attention,
    fake_impl=unified_mla_attention_fake,
    mutates_args=[],
    dispatch_key=current_platform.dispatch_key,
)


881
@maybe_transfer_kv_layer
def unified_mla_attention_with_output(
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
) -> None:
    """Run the MLA layer named ``layer_name``, writing into ``output``.

    The layer, its metadata and its KV cache are resolved through
    ``get_attention_context``. The layer's ``impl.forward`` receives
    ``output`` together with the optional ``output_scale`` and
    ``output_block_scale``. These buffers are declared mutated when the op is
    registered, and nothing is returned.
    """
    attn_metadata, self, kv_cache = get_attention_context(layer_name)
    self.impl.forward(
        self,
        q,
        kv_c_normed,
        k_pe,
        kv_cache,
        attn_metadata,
        output=output,
        output_scale=output_scale,
        output_block_scale=output_block_scale,
    )


def unified_mla_attention_with_output_fake(
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
911
912
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
913
914
915
916
917
918
919
920
921
922
923
) -> None:
    return


# Register unified_mla_attention_with_output as a custom op. The output
# buffers it fills are declared mutated, and it is dispatched under the
# current platform's dispatch key.
direct_register_custom_op(
    op_name="unified_mla_attention_with_output",
    op_func=unified_mla_attention_with_output,
    fake_impl=unified_mla_attention_with_output_fake,
    dispatch_key=current_platform.dispatch_key,
    mutates_args=["output", "output_block_scale"],
)