layer.py 34.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Attention layer."""
4

5
import functools
6
from typing import cast
7
8
9

import torch
import torch.nn as nn
10
import torch.nn.functional as F
11

12
import vllm.envs as envs
13
14
15
16
17
from vllm.attention.backends.abstract import (
    AttentionBackend,
    AttentionType,
    MLAAttentionImpl,
)
18
from vllm.attention.backends.registry import AttentionBackendEnum
19
from vllm.attention.layers.mm_encoder_attention import maybe_get_vit_flash_attn_backend
20
from vllm.attention.selector import get_attn_backend
21
from vllm.attention.utils.fa_utils import get_flash_attn_version
22
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
23
from vllm.attention.utils.kv_transfer_utils import maybe_transfer_kv_layer
24
from vllm.config import CacheConfig, get_current_vllm_config
25
from vllm.config.multimodal import MultiModalConfig
26
from vllm.config.vllm import VllmConfig
27
from vllm.forward_context import ForwardContext, get_forward_context
28
from vllm.logger import init_logger
29
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
30
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
31
32
33
34
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    UnquantizedLinearMethod,
)
35
from vllm.model_executor.layers.quantization import QuantizationConfig
36
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
37
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
38
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
39
from vllm.model_executor.models.vision import get_vit_attn_backend
40
from vllm.platforms import current_platform
41
from vllm.utils.torch_utils import (
42
43
44
45
46
47
48
49
50
    direct_register_custom_op,
    kv_cache_dtype_str_to_dtype,
)
from vllm.v1.kv_cache_interface import (
    FullAttentionSpec,
    KVCacheSpec,
    MLAAttentionSpec,
    SlidingWindowSpec,
)
51

52
53
54
# Module-level logger; used below for the one-time prefix-caching warnings
# and the multimodal-encoder backend info message.
logger = init_logger(__name__)


55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
def _init_kv_cache_quant(
    layer: nn.Module,
    quant_config: QuantizationConfig | None,
    prefix: str,
    kv_cache_dtype: str,
    calculate_kv_scales: bool,
) -> None:
    """Attach KV-cache quantization state to an attention layer.

    Shared setup for the Attention and MLAAttention layers: records the cache
    dtype and the dynamic-scale flag, creates unit q/k/v/prob scale tensors
    together with their host-side float mirrors, and, when the quantization
    config provides a KV-cache method for this layer, installs it and lets it
    create its weights.

    Args:
        layer: The attention layer instance to initialize.
        quant_config: Optional quantization configuration.
        prefix: Layer name prefix used for the quantization method lookup.
        kv_cache_dtype: The KV cache data type string.
        calculate_kv_scales: Whether KV scales are calculated dynamically.

    Raises:
        ValueError: If a KV-cache quantization method applies to this layer
            and `kv_cache_dtype` is "fp8_e5m2".
    """
    layer.kv_cache_dtype = kv_cache_dtype
    layer.calculate_kv_scales = calculate_kv_scales

    # Scales default to 1.0. They are ignored when the kv-cache is not fp8
    # and are the values to use with an fp8_e5m2 kv-cache. For fp8_e4m3 the
    # pre-quantized k/v scales are expected to be loaded together with the
    # model weights.
    for scale_attr in ("_k_scale", "_v_scale", "_q_scale", "_prob_scale"):
        setattr(layer, scale_attr, torch.tensor(1.0, dtype=torch.float32))

    # Host (cpu) copies of q/k/v scales, for attention backends that need
    # the scales on host rather than on device (e.g. Flashinfer).
    for float_attr in ("_q_scale_float", "_k_scale_float", "_v_scale_float"):
        setattr(layer, float_attr, 1.0)

    # Host-side output scale: should be the input scale of the quant op that
    # follows this attention layer.
    layer._o_scale_float = None

    if not quant_config:
        return

    method = quant_config.get_quant_method(layer, prefix=prefix)
    if method is None or isinstance(method, UnquantizedLinearMethod):
        return

    assert isinstance(method, BaseKVCacheMethod)
    # TODO (mgoin): kv cache dtype should be specified in the FP8
    # checkpoint config and become the "auto" behavior
    if kv_cache_dtype == "fp8_e5m2":
        raise ValueError("fp8_e5m2 kv-cache is not supported with fp8 checkpoints.")

    # With quantization enabled, "k_scale" and "v_scale" become parameters so
    # they can be loaded from the model checkpoint; they are converted back
    # to native float32 values after weight loading.
    layer.quant_method = method
    layer.quant_method.create_weights(layer)
116
117


118
class Attention(nn.Module, AttentionLayerBase):
    """Attention layer.

    This class takes query, key, and value tensors as input. The input tensors
    can either contain prompt tokens or generation tokens.
    The class does the following:

    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
        alibi_slopes: list[float] | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        logits_soft_cap: float | None = None,
        per_layer_sliding_window: int | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
        attn_backend: type[AttentionBackend] | None = None,
        **extra_impl_args,
    ) -> None:
        """
        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each attention head.
            scale: Scale value forwarded to the backend implementation.
            num_kv_heads: Number of key/value heads; defaults to `num_heads`.
            alibi_slopes: Optional ALiBi slopes forwarded to the backend.
            cache_config: Optional cache configuration (cache dtype, block
                size, sliding window, prefix caching, dynamic kv scales).
            quant_config: Optional quantization configuration.
            logits_soft_cap: Optional soft cap forwarded to the backend.
            per_layer_sliding_window: Sliding window for this layer; takes
                precedence over `cache_config.sliding_window`.
            prefix: Unique layer name; used as the key under which the layer
                is registered in the static forward context.
            attn_type: Attention type forwarded to the backend.
            kv_sharing_target_layer_name: Optional name of the layer whose
                KV cache this layer shares.
            attn_backend: Optional backend class; when omitted the backend is
                chosen via `get_attn_backend`.
            **extra_impl_args: Extra keyword arguments forwarded to the
                backend implementation (a non-None "sinks" entry marks the
                layer as having attention sinks).

        Raises:
            ValueError: If `prefix` is already registered as a layer name.
        """
        super().__init__()

        # A per-layer sliding window overrides the model-level one.
        if per_layer_sliding_window is not None:
            sliding_window = per_layer_sliding_window
        elif cache_config is not None:
            sliding_window = cache_config.sliding_window
        else:
            sliding_window = None

        vllm_config = get_current_vllm_config()
        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            calculate_kv_scales = cache_config.calculate_kv_scales
        else:
            kv_cache_dtype = "auto"
            # Without a cache config, the block size falls back to 64 when
            # both env switches are set and to 16 otherwise.
            block_size = (
                64 if envs.VLLM_USE_FLASH_ATTN_PA and envs.VLLM_USE_FLASH_MLA else 16
            )
            calculate_kv_scales = False
        self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
            kv_cache_dtype, vllm_config.model_config
        )

        if num_kv_heads is None:
            num_kv_heads = num_heads
        assert num_heads % num_kv_heads == 0, (
            f"num_heads ({num_heads}) is not divisible by num_kv_heads ({num_kv_heads})"
        )

        # Initialize KV cache quantization attributes
        _init_kv_cache_quant(
            self, quant_config, prefix, kv_cache_dtype, calculate_kv_scales
        )

        self.num_heads = num_heads
        self.head_size = head_size
        self.num_kv_heads = num_kv_heads
        self.sliding_window = sliding_window
        self.has_sink = extra_impl_args.get("sinks") is not None

        # NOTE: model_config may be None during certain tests
        model_config = vllm_config.model_config
        self.use_mm_prefix = model_config is not None and model_config.is_mm_prefix_lm

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()
        if attn_backend is None:
            self.attn_backend = get_attn_backend(
                head_size,
                dtype,
                kv_cache_dtype,
                block_size,
                use_mla=False,
                has_sink=self.has_sink,
                use_mm_prefix=self.use_mm_prefix,
                attn_type=attn_type,
            )
        else:
            self.attn_backend = attn_backend

        # Looked up once; reused for the prefix-caching check and the enum
        # lookup below.
        backend_name = self.attn_backend.get_name()

        # prefix caching + batch invariance is currently not supported for
        # FLASHINFER and TRITON_MLA.
        if (
            cache_config is not None
            and cache_config.enable_prefix_caching
            and vllm_is_batch_invariant()
            and backend_name in ("FLASHINFER", "TRITON_MLA")
        ):
            logger.warning_once(
                "Disabling prefix caching for FLASHINFER/TRITON_MLA "
                "with batch invariance, as it is not yet supported.",
                scope="local",
            )
            cache_config.enable_prefix_caching = False

        impl_cls = self.attn_backend.get_impl_cls()
        self.impl = impl_cls(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            sliding_window,
            kv_cache_dtype,
            logits_soft_cap,
            attn_type,
            kv_sharing_target_layer_name,
            **extra_impl_args,
        )
        # None when the backend name is not a member of AttentionBackendEnum.
        self.backend = AttentionBackendEnum.__members__.get(backend_name)
        self.dtype = dtype

        # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
        # torch.compile works by registering the attention as one giant
        # opaque custom op. For other platforms, we directly call them
        # and let torch.compile handle them.
        self.use_direct_call = not current_platform.opaque_attention_op()

        self.use_output = self.attn_backend.accept_output_buffer
        compilation_config = vllm_config.compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self
        self.layer_name = prefix
        self.attn_type = attn_type

        if kv_sharing_target_layer_name is not None:
            validate_kv_sharing_target(
                prefix,
                kv_sharing_target_layer_name,
                compilation_config.static_forward_context,
            )
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        # use a placeholder kv cache tensor during init, which will be replaced
        # by bind_kv_cache
        # this variable will not be accessed if use_direct_call is True
        self.kv_cache = [
            torch.tensor([])
            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
        ]

        # Initialize q/k/v range constants.
        self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
        self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
        self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)

        # for attn backends supporting query quantization
        self.query_quant = None
        if (
            self.kv_cache_dtype.startswith("fp8")
            and self.impl.supports_quant_query_input
        ):
            self.query_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        # For some alternate attention backends like MLA the attention output
        # shape does not match the query shape, so we optionally let the model
        # definition specify the output tensor shape.
        output_shape: torch.Size | None = None,
    ) -> torch.Tensor:
        """
        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        Attention metadata (`attn_metadata`) is set using a context manager in
        the model runner's `execute_model` method. It is accessed via forward
        context using
        `vllm.forward_context.get_forward_context().attn_metadata`.

        Args:
            query: Query tensor.
            key: Key tensor; may be None on the output-buffer path.
            value: Value tensor; may be None on the output-buffer path.
            output_shape: Shape of the output buffer to allocate when the
                backend accepts one; defaults to the query shape.

        Returns:
            When the backend accepts an output buffer, that buffer viewed as
            `(-1, output_shape[-1])`; otherwise whatever the backend
            implementation (or the registered custom op) returns.
        """
        if self.calculate_kv_scales:
            torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name)

        # Captured before the optional quantization below so the output keeps
        # the dtype of the incoming query.
        output_dtype = query.dtype
        if self.query_quant is not None:
            # quantizing with a simple torch operation enables
            # torch.compile to fuse this into previous ops
            # which reduces overheads during decoding.
            # Otherwise queries are quantized using custom ops
            # which causes decoding overheads
            assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}

            # check if query quantization is supported
            if self.impl.supports_quant_query_input:
                query, _ = self.query_quant(query, self._q_scale)

        if self.use_output:
            output_shape = output_shape if output_shape is not None else query.shape
            output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
            hidden_size = output_shape[-1]
            # Reshape the query, key, and value tensors.
            # NOTE(woosuk): We do this outside the custom op to minimize the
            # CPU overheads from the non-CUDA-graph regions.
            query = query.view(-1, self.num_heads, self.head_size)
            output = output.view(-1, self.num_heads, self.head_size)
            if key is not None:
                key = key.view(-1, self.num_kv_heads, self.head_size)
            if value is not None:
                value = value.view(-1, self.num_kv_heads, self.head_size)
            if self.use_direct_call:
                forward_context: ForwardContext = get_forward_context()
                attn_metadata = forward_context.attn_metadata
                # Metadata may be keyed per layer name.
                if isinstance(attn_metadata, dict):
                    attn_metadata = attn_metadata[self.layer_name]
                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                self.impl.forward(
                    self, query, key, value, self_kv_cache, attn_metadata, output=output
                )
            else:
                torch.ops.vllm.unified_attention_with_output(
                    query, key, value, output, self.layer_name
                )
            return output.view(-1, hidden_size)
        else:
            if self.use_direct_call:
                forward_context = get_forward_context()
                attn_metadata = forward_context.attn_metadata
                # Metadata may be keyed per layer name.
                if isinstance(attn_metadata, dict):
                    attn_metadata = attn_metadata[self.layer_name]
                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                return self.impl.forward(
                    self, query, key, value, self_kv_cache, attn_metadata
                )
            else:
                return torch.ops.vllm.unified_attention(
                    query, key, value, self.layer_name
                )

    def calc_kv_scales(self, query, key, value):
        """Compute q/k/v scales from the given tensors, once.

        Each scale is set in place to the tensor's absolute maximum divided
        by the matching range constant, the host-side float mirrors are
        refreshed, and `calculate_kv_scales` is cleared.
        """
        self._q_scale.copy_(torch.abs(query).max() / self.q_range)
        self._k_scale.copy_(torch.abs(key).max() / self.k_range)
        self._v_scale.copy_(torch.abs(value).max() / self.v_range)
        self._q_scale_float = self._q_scale.item()
        self._k_scale_float = self._k_scale.item()
        self._v_scale_float = self._v_scale.item()
        # We only calculate the scales once
        self.calculate_kv_scales = False

    def extra_repr(self) -> str:
        """Return a summary of the backend implementation's settings."""
        s = f"head_size={self.impl.head_size}"  # type: ignore
        s += f", num_heads={self.impl.num_heads}"  # type: ignore
        s += f", num_kv_heads={self.impl.num_kv_heads}"  # type: ignore
        s += f", scale={self.impl.scale}"  # type: ignore
        s += f", backend={self.impl.__class__.__name__}"
        return s

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Delegate post-load weight processing to the backend implementation."""
        self.impl.process_weights_after_loading(act_dtype)

    def get_attn_backend(self) -> type[AttentionBackend]:
        """Return the attention backend class selected for this layer."""
        return self.attn_backend

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        """Build the KV cache spec for this layer.

        Returns a `SlidingWindowSpec` when the layer has a sliding window and
        a `FullAttentionSpec` otherwise. Only valid for decoder attention.
        """
        # Block size may get updated after model loading, refresh it
        block_size = vllm_config.cache_config.block_size
        # Should not be called for enc-dec or encoder-only attention.
        assert self.attn_type == AttentionType.DECODER
        if self.sliding_window is not None:
            assert not vllm_config.model_config.use_mla, (
                "MLA is not supported for slidingwindow"
            )
            return SlidingWindowSpec(
                block_size=block_size,
                num_kv_heads=self.num_kv_heads,
                head_size=self.head_size,
                dtype=self.kv_cache_torch_dtype,
                sliding_window=self.sliding_window,
            )
        else:
            return FullAttentionSpec(
                block_size=block_size,
                num_kv_heads=self.num_kv_heads,
                head_size=self.head_size,
                dtype=self.kv_cache_torch_dtype,
            )

414

415
416
417
418
419
420
421
422
class MultiHeadAttention(nn.Module):
    """Multi-headed attention without any cache, used for ViT."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
        # Only stored as `layer_name`; accepted so that it is easier to swap
        # between Attention and MultiHeadAttention
        prefix: str = "",
        multimodal_config: MultiModalConfig | None = None,
    ) -> None:
        """Select the ViT attention backend and record head geometry.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each attention head.
            scale: Softmax scale passed to the attention kernels.
            num_kv_heads: Number of key/value heads; defaults to `num_heads`.
            prefix: Layer name, stored as `layer_name`.
            multimodal_config: Optional config whose
                `mm_encoder_attn_backend` overrides the backend choice.
        """
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.layer_name = prefix

        assert self.num_heads % self.num_kv_heads == 0, (
            f"num_heads ({self.num_heads}) is not "
            f"divisible by num_kv_heads ({self.num_kv_heads})"
        )
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()

        # Determine the attention backend
        attn_backend_override = None
        if multimodal_config is not None:
            attn_backend_override = multimodal_config.mm_encoder_attn_backend

        self.attn_backend = get_vit_attn_backend(
            head_size=head_size,
            dtype=dtype,
            attn_backend_override=attn_backend_override,
        )

        self._flash_attn_varlen_func = maybe_get_vit_flash_attn_backend(
            self.attn_backend,
        )

        self.is_flash_attn_backend = self.attn_backend in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
        }

        self.fa_version = None
        if (
            self.attn_backend == AttentionBackendEnum.FLASH_ATTN
            and current_platform.is_cuda()
        ):
            # Bind the detected flash-attn version so forward() does not have
            # to pass it on every call.
            self.fa_version = get_flash_attn_version()
            assert self._flash_attn_varlen_func is not None
            self._flash_attn_varlen_func = functools.partial(
                self._flash_attn_varlen_func, fa_version=self.fa_version
            )

        logger.info_once(
            f"Using {self.attn_backend} for MultiHeadAttention in multimodal encoder."
        )

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """Input shape:
        (batch_size x seq_len x hidden_size) or
        (batch_size x seq_len x num_heads x head_size)

        Returns the attention output reshaped to
        (batch_size x q_len x -1).

        Raises:
            NotImplementedError: If the selected backend is not one of the
                flash-attention, TORCH_SDPA or PALLAS backends.
        """
        bsz, q_len = query.size()[:2]
        kv_len = key.size(1)

        query = query.view(bsz, q_len, self.num_heads, self.head_size)
        key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
        value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)

        if (num_repeat := self.num_queries_per_kv) > 1:
            # Handle MQA and GQA
            key = torch.repeat_interleave(key, num_repeat, dim=2)
            value = torch.repeat_interleave(value, num_repeat, dim=2)

        if self.is_flash_attn_backend:
            assert self._flash_attn_varlen_func is not None
            # Every sequence in the batch has the same length, so the
            # cumulative sequence lengths are evenly spaced offsets.
            cu_seqlens_q = torch.arange(
                0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=query.device
            )
            cu_seqlens_k = torch.arange(
                0, (bsz + 1) * kv_len, step=kv_len, dtype=torch.int32, device=key.device
            )

            out = self._flash_attn_varlen_func(
                query.flatten(0, 1),
                key.flatten(0, 1),
                value.flatten(0, 1),
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=q_len,
                max_seqlen_k=kv_len,
                softmax_scale=self.scale,
            )
        elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
            # Swap the sequence and head axes for the SDPA call, then swap
            # them back on the result.
            query, key, value = (x.transpose(1, 2) for x in (query, key, value))
            out = F.scaled_dot_product_attention(query, key, value, scale=self.scale)
            out = out.transpose(1, 2)
        elif self.attn_backend == AttentionBackendEnum.PALLAS:
            query, key, value = (x.transpose(1, 2) for x in (query, key, value))
            # Imported lazily so torch_xla is only needed for this backend.
            from torch_xla.experimental.custom_kernel import flash_attention

            out = flash_attention(query, key, value, sm_scale=self.scale)
            out = out.transpose(1, 2)
        else:
            # ViT attention hasn't supported this backend yet
            raise NotImplementedError(
                f"ViT attention hasn't supported {self.attn_backend} backend yet."
            )

        return out.reshape(bsz, q_len, -1)
539
540


541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
class MLAAttention(nn.Module, AttentionLayerBase):
    """Multi-Head Latent Attention layer.

    This class takes query, and compressed key/value tensors as input.
    The class does the following:

    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        scale: float,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int | None,
        kv_lora_rank: int,
        kv_b_proj: ColumnParallelLinear,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_sparse: bool = False,
        indexer: object | None = None,
        **extra_impl_args,
    ) -> None:
        """Select the MLA backend, build its implementation, register the layer.

        Args:
            num_heads: Number of attention heads.
            scale: Scale value forwarded to the backend implementation.
            qk_nope_head_dim: Forwarded to the implementation; summed with
                `qk_rope_head_dim` to form `qk_head_dim`.
            qk_rope_head_dim: Forwarded to the implementation; also added to
                `kv_lora_rank` to form this layer's `head_size`.
            v_head_dim: Forwarded to the implementation.
            q_lora_rank: Forwarded to the implementation; may be None.
            kv_lora_rank: Forwarded to the implementation; part of `head_size`.
            kv_b_proj: Projection layer forwarded to the implementation.
            cache_config: Optional cache configuration.
            quant_config: Optional quantization configuration.
            prefix: Unique layer name; key in the static forward context.
            use_sparse: Forwarded to backend selection.
            indexer: Optional object forwarded to the implementation.
            **extra_impl_args: Extra keyword arguments for the implementation.

        Raises:
            ValueError: If `prefix` is already registered as a layer name.
        """
        super().__init__()
        self.num_heads = num_heads
        self.scale = scale
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.head_size = kv_lora_rank + qk_rope_head_dim
        self.layer_name = prefix

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            calculate_kv_scales = cache_config.calculate_kv_scales
        else:
            kv_cache_dtype = "auto"
            block_size = 16
            calculate_kv_scales = False

        # Initialize KV cache quantization attributes
        _init_kv_cache_quant(
            self, quant_config, prefix, kv_cache_dtype, calculate_kv_scales
        )

        dtype = torch.get_default_dtype()
        self.attn_backend = get_attn_backend(
            self.head_size,
            dtype,
            kv_cache_dtype,
            block_size,
            use_mla=True,
            use_sparse=use_sparse,
        )

        # prefix caching + batch invariance is not yet supported for these
        # backends, so prefix caching is switched off on the shared config.
        if (
            cache_config is not None
            and cache_config.enable_prefix_caching
            and vllm_is_batch_invariant()
            and self.attn_backend.get_name() in ("TRITON_MLA", "FLASHINFER")
        ):
            logger.warning_once(
                "Disabling prefix caching for TRITON_MLA / FLASHINFER "
                "with batch invariance, as it is not yet supported.",
                scope="local",
            )
            cache_config.enable_prefix_caching = False

        impl_cls = cast(type[MLAAttentionImpl], self.attn_backend.get_impl_cls())
        self.impl = impl_cls(
            num_heads=self.num_heads,
            head_size=self.head_size,
            scale=self.scale,
            num_kv_heads=1,
            alibi_slopes=None,
            sliding_window=None,
            kv_cache_dtype=self.kv_cache_dtype,
            logits_soft_cap=None,
            attn_type=AttentionType.DECODER,
            kv_sharing_target_layer_name=None,
            # MLA Args
            q_lora_rank=self.q_lora_rank,
            kv_lora_rank=self.kv_lora_rank,
            qk_nope_head_dim=self.qk_nope_head_dim,
            qk_rope_head_dim=self.qk_rope_head_dim,
            qk_head_dim=self.qk_nope_head_dim + self.qk_rope_head_dim,
            v_head_dim=self.v_head_dim,
            kv_b_proj=kv_b_proj,
            indexer=indexer,
            **extra_impl_args,
        )

        self.use_direct_call = not current_platform.opaque_attention_op()

        # Fetch the current config once; it supplies both the layer registry
        # and the pipeline-parallel size below.
        vllm_config = get_current_vllm_config()
        compilation_config = vllm_config.compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self

        # One empty placeholder tensor per pipeline-parallel stage.
        self.kv_cache = [
            torch.tensor([])
            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
        ]

        self.use_sparse = use_sparse

        # Initialize q/k/v range constants.
        self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
        self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
        self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)

    def forward(
        self,
        q: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        output_shape: torch.Size | None = None,
    ) -> torch.Tensor:
        """Run MLA attention through the backend or the registered custom op.

        Args:
            q: Query tensor.
            kv_c_normed: Compressed key/value tensor.
            k_pe: Third input forwarded unchanged to the backend.
            output_shape: Shape of the output buffer. Required when the
                backend accepts an output buffer; unused otherwise.

        Returns:
            The filled output buffer when the backend accepts one, otherwise
            the value returned by the backend implementation or custom op.

        Raises:
            ValueError: If the backend accepts an output buffer but
                `output_shape` is None.
        """
        if self.calculate_kv_scales:
            torch.ops.vllm.maybe_calc_kv_scales(q, kv_c_normed, k_pe, self.layer_name)

        output = None
        if self.attn_backend.accept_output_buffer:
            # There is no default shape to fall back to here, so fail with a
            # clear message instead of letting torch.empty(None) raise an
            # opaque argument error.
            if output_shape is None:
                raise ValueError(
                    "output_shape must be provided when the attention backend "
                    "accepts an output buffer."
                )
            output = torch.empty(output_shape, dtype=q.dtype, device=q.device)

        if self.use_direct_call:
            forward_context: ForwardContext = get_forward_context()
            attn_metadata = forward_context.attn_metadata
            # Metadata may be keyed per layer name.
            if isinstance(attn_metadata, dict):
                attn_metadata = attn_metadata[self.layer_name]
            self_kv_cache = self.kv_cache[forward_context.virtual_engine]

            if output is None:
                return self.impl.forward(
                    self, q, kv_c_normed, k_pe, self_kv_cache, attn_metadata
                )
            self.impl.forward(
                self,
                q,
                kv_c_normed,
                k_pe,
                self_kv_cache,
                attn_metadata,
                output=output,
            )
            return output

        if output is None:
            return torch.ops.vllm.unified_mla_attention(
                q,
                kv_c_normed,
                k_pe,
                self.layer_name,
            )
        torch.ops.vllm.unified_mla_attention_with_output(
            q,
            kv_c_normed,
            k_pe,
            output,
            self.layer_name,
        )
        return output

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Delegate post-load processing to the implementation if it has any."""
        if hasattr(self.impl, "process_weights_after_loading"):
            self.impl.process_weights_after_loading(act_dtype)

    def calc_kv_scales(
        self, q: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor
    ) -> None:
        """Optional scale calculation for MLA inputs.

        Mirrors Attention.calc_kv_scales. Not all MLA backends require this.
        `k_pe` is accepted for signature parity but is not used: both the
        k and v scales are derived from `kv_c_normed`.
        """
        # Use safe defaults if ranges are not present
        q_range = getattr(self, "q_range", torch.tensor(1.0))
        k_range = getattr(self, "k_range", torch.tensor(1.0))
        v_range = getattr(self, "v_range", torch.tensor(1.0))

        self._q_scale.copy_(torch.abs(q).max() / q_range)
        # kv_c_normed is the compressed KV representation; use it for k/v
        kv_abs_max = torch.abs(kv_c_normed).max()
        self._k_scale.copy_(kv_abs_max / k_range)
        self._v_scale.copy_(kv_abs_max / v_range)
        self._q_scale_float = self._q_scale.item()
        self._k_scale_float = self._k_scale.item()
        self._v_scale_float = self._v_scale.item()
        # Scales are only calculated once.
        self.calculate_kv_scales = False

    def get_attn_backend(self) -> type[AttentionBackend]:
        """Return the attention backend class selected for this layer."""
        return self.attn_backend

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        """Build the `MLAAttentionSpec` for this layer from the given config."""
        kv_cache_dtype = kv_cache_dtype_str_to_dtype(
            self.kv_cache_dtype, vllm_config.model_config
        )
        return MLAAttentionSpec(
            block_size=vllm_config.cache_config.block_size,
            num_kv_heads=1,
            head_size=self.head_size,
            dtype=kv_cache_dtype,
            cache_dtype_str=vllm_config.cache_config.cache_dtype,
        )

758

759
760
761
762
763
764
def maybe_calc_kv_scales(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> None:
    """Run a layer's KV-scale calculation if it is still pending.

    The layer is looked up by name in the current forward context's
    ``no_compile_layers``. When its ``calculate_kv_scales`` flag is falsy
    nothing happens; otherwise ``calc_kv_scales(query, key, value)`` is
    invoked on it.

    Args:
        query: Forwarded unchanged to the layer's ``calc_kv_scales``.
        key: Forwarded unchanged to the layer's ``calc_kv_scales``.
        value: Forwarded unchanged to the layer's ``calc_kv_scales``.
        layer_name: Key of the layer in ``no_compile_layers``.
    """
    forward_context: ForwardContext = get_forward_context()
    self = forward_context.no_compile_layers[layer_name]

    # Only calculate if the layer's calculate_kv_scales flag is True
    # This flag gets set to False after the first forward pass
    if not self.calculate_kv_scales:
        return

    self.calc_kv_scales(query, key, value)

def maybe_calc_kv_scales_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> None:
    """Fake implementation paired with ``maybe_calc_kv_scales``.

    It is passed as ``fake_impl`` when the op is registered. It takes the
    same arguments as the real function, ignores all of them and returns
    None.
    """
    return


# Register maybe_calc_kv_scales as the custom op "maybe_calc_kv_scales".
# query, key and value are declared as mutated arguments, and
# maybe_calc_kv_scales_fake is supplied as the fake implementation.
direct_register_custom_op(
    op_name="maybe_calc_kv_scales",
    op_func=maybe_calc_kv_scales,
    mutates_args=["query", "key", "value"],
    fake_impl=maybe_calc_kv_scales_fake,
)


def get_attention_context(
    layer_name: str,
) -> tuple[dict | object | None, Attention | MLAAttention, torch.Tensor]:
    """Extract attention context for a given layer.

    This helper function extracts the attention metadata, attention layer
    instance, and KV cache tensor for a specific layer.

    Args:
        layer_name: The name/identifier of the attention layer.

    Returns:
        A tuple containing:
        - attn_metadata: Attention metadata for this specific layer, or None if
            no metadata available
        - attn_layer: The attention layer instance (Attention or MLAAttention)
        - kv_cache: The KV cache tensor for current virtual engine

        Note: attn_metadata may be None, but attn_layer and kv_cache are always
        extracted from the forward context.
    """
    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    # When the metadata is a dict it is keyed by layer name; select this
    # layer's entry. Any other value is passed through unchanged.
    if isinstance(attn_metadata, dict):
        attn_metadata = attn_metadata[layer_name]
    attn_layer: Attention | MLAAttention = forward_context.no_compile_layers[layer_name]
    # The layer's kv_cache is indexed by the context's virtual engine.
    kv_cache = attn_layer.kv_cache[forward_context.virtual_engine]
    return attn_metadata, attn_layer, kv_cache


@maybe_transfer_kv_layer
def unified_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Run the named layer's attention implementation and return its result.

    The attention metadata, layer instance and KV cache are fetched with
    ``get_attention_context(layer_name)`` and handed, together with the
    inputs, to ``self.impl.forward``.

    Args:
        query: Forwarded to ``impl.forward``.
        key: Forwarded to ``impl.forward``.
        value: Forwarded to ``impl.forward``.
        layer_name: Name used to look up the layer's context.

    Returns:
        Whatever ``impl.forward`` returns.
    """
    attn_metadata, self, kv_cache = get_attention_context(layer_name)
    output = self.impl.forward(self, query, key, value, kv_cache, attn_metadata)

    return output


def unified_attention_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Fake implementation paired with ``unified_attention``.

    Only ``query`` is consulted: the result is a new, uninitialized,
    contiguous tensor with the same shape and dtype as ``query``. The
    remaining arguments are ignored.
    """
    placeholder = torch.empty_like(query)
    return placeholder.contiguous()


# Register unified_attention as the custom op "unified_attention", with
# unified_attention_fake supplied as its fake implementation.
direct_register_custom_op(
    op_name="unified_attention",
    op_func=unified_attention,
    fake_impl=unified_attention_fake,
)


@maybe_transfer_kv_layer
def unified_attention_with_output(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
) -> None:
    """Run the named layer's attention implementation with a caller-owned output.

    The attention metadata, layer instance and KV cache are fetched with
    ``get_attention_context(layer_name)``; ``self.impl.forward`` is then
    called with ``output``, ``output_scale`` and ``output_block_scale``
    passed as keyword arguments. Its return value is discarded.

    Args:
        query: Forwarded to ``impl.forward``.
        key: Forwarded to ``impl.forward``.
        value: Forwarded to ``impl.forward``.
        output: Tensor handed to ``impl.forward`` as ``output``; presumably
            the impl writes its result into it (the op registration lists it
            as a mutated argument).
        layer_name: Name used to look up the layer's context.
        output_scale: Optional tensor forwarded as ``output_scale``.
        output_block_scale: Optional tensor forwarded as
            ``output_block_scale``.
    """
    attn_metadata, self, kv_cache = get_attention_context(layer_name)
    self.impl.forward(
        self,
        query,
        key,
        value,
        kv_cache,
        attn_metadata,
        output=output,
        output_scale=output_scale,
        output_block_scale=output_block_scale,
    )

def unified_attention_with_output_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
882
883
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
884
885
886
887
888
889
890
) -> None:
    return


# Register unified_attention_with_output as the custom op
# "unified_attention_with_output". output and output_block_scale are declared
# as mutated arguments, and unified_attention_with_output_fake is supplied as
# the fake implementation.
direct_register_custom_op(
    op_name="unified_attention_with_output",
    op_func=unified_attention_with_output,
    mutates_args=["output", "output_block_scale"],
    fake_impl=unified_attention_with_output_fake,
)


@maybe_transfer_kv_layer
def unified_mla_attention(
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Run the named MLA layer's attention implementation and return its result.

    The attention metadata, layer instance and KV cache are fetched with
    ``get_attention_context(layer_name)`` and handed, together with the
    inputs, to ``self.impl.forward``.

    Args:
        q: Forwarded to ``impl.forward``.
        kv_c_normed: Forwarded to ``impl.forward``.
        k_pe: Forwarded to ``impl.forward``.
        layer_name: Name used to look up the layer's context.

    Returns:
        Whatever ``impl.forward`` returns.
    """
    attn_metadata, self, kv_cache = get_attention_context(layer_name)
    output = self.impl.forward(self, q, kv_c_normed, k_pe, kv_cache, attn_metadata)

    return output


def unified_mla_attention_fake(
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Fake implementation paired with ``unified_mla_attention``.

    Only ``q`` is consulted: the result is a new, uninitialized, contiguous
    tensor with the same shape and dtype as ``q``. The remaining arguments
    are ignored.
    """
    placeholder = torch.empty_like(q)
    return placeholder.contiguous()


# Register unified_mla_attention as the custom op "unified_mla_attention"
# (invoked elsewhere in this file as torch.ops.vllm.unified_mla_attention).
# No arguments are declared as mutated, unified_mla_attention_fake is the
# fake implementation, and the dispatch key comes from the current platform.
direct_register_custom_op(
    op_name="unified_mla_attention",
    op_func=unified_mla_attention,
    mutates_args=[],
    fake_impl=unified_mla_attention_fake,
    dispatch_key=current_platform.dispatch_key,
)


@maybe_transfer_kv_layer
def unified_mla_attention_with_output(
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
) -> None:
    """Run the named MLA layer's implementation with a caller-owned output.

    The attention metadata, layer instance and KV cache are fetched with
    ``get_attention_context(layer_name)``; ``self.impl.forward`` is then
    called with ``output``, ``output_scale`` and ``output_block_scale``
    passed as keyword arguments. Its return value is discarded.

    Args:
        q: Forwarded to ``impl.forward``.
        kv_c_normed: Forwarded to ``impl.forward``.
        k_pe: Forwarded to ``impl.forward``.
        output: Tensor handed to ``impl.forward`` as ``output``; presumably
            the impl writes its result into it (the op registration lists it
            as a mutated argument).
        layer_name: Name used to look up the layer's context.
        output_scale: Optional tensor forwarded as ``output_scale``.
        output_block_scale: Optional tensor forwarded as
            ``output_block_scale``.
    """
    attn_metadata, self, kv_cache = get_attention_context(layer_name)
    self.impl.forward(
        self,
        q,
        kv_c_normed,
        k_pe,
        kv_cache,
        attn_metadata,
        output=output,
        output_scale=output_scale,
        output_block_scale=output_block_scale,
    )


def unified_mla_attention_with_output_fake(
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
957
958
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
959
960
961
962
963
964
965
966
967
968
) -> None:
    return


# Register unified_mla_attention_with_output as the custom op
# "unified_mla_attention_with_output". output and output_block_scale are
# declared as mutated arguments, unified_mla_attention_with_output_fake is the
# fake implementation, and the dispatch key comes from the current platform.
direct_register_custom_op(
    op_name="unified_mla_attention_with_output",
    op_func=unified_mla_attention_with_output,
    mutates_args=["output", "output_block_scale"],
    fake_impl=unified_mla_attention_with_output_fake,
    dispatch_key=current_platform.dispatch_key,
)