layer.py 36.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Attention layer."""
4

5
from typing import Callable, Optional, cast
6
7
8

import torch
import torch.nn as nn
9
import torch.nn.functional as F
10

11
import vllm.envs as envs
12
from vllm.attention import AttentionType
13
from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl
14
15
from vllm.attention.backends.registry import _Backend, backend_name_to_enum
from vllm.attention.selector import get_attn_backend
16
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
17
from vllm.config import CacheConfig, get_current_vllm_config
18
19
20
21
22
from vllm.distributed.kv_transfer import (
    get_kv_transfer_group,
    has_kv_transfer_group,
    is_v1_kv_transfer_group,
)
23
from vllm.forward_context import ForwardContext, get_forward_context
24
from vllm.logger import init_logger
25
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
26
27
28
29
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    UnquantizedLinearMethod,
)
30
from vllm.model_executor.layers.quantization import QuantizationConfig
31
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
32
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
33
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
34
from vllm.model_executor.models.vision import get_vit_attn_backend
35
from vllm.platforms import current_platform
36
from vllm.utils import GiB_bytes, direct_register_custom_op
37

38
39
logger = init_logger(__name__)

# Cached result of check_xformers_availability(); None means "not probed yet".
USE_XFORMERS_OPS = None

# `torch._C.Tag.cudagraph_unsafe` is not available in every PyTorch build;
# when it is missing an empty tag tuple is exported instead.
try:
    _cudagraph_unsafe_tag = torch._C.Tag.cudagraph_unsafe
except AttributeError:
    tag_cudagraph_unsafe = ()
else:
    tag_cudagraph_unsafe = (_cudagraph_unsafe_tag,)  # type: ignore[assignment]
44
45
46
47
48
49
50


def check_xformers_availability():
    """Return whether xformers ops can be used for attention.

    The result is computed on the first call and cached in the module-level
    ``USE_XFORMERS_OPS``; later calls return the cached value without
    probing again (so the fallback warning is logged at most once).

    Returns:
        True if ``xformers.ops`` can be located and the device is not one of
        the excluded ones, False otherwise.
    """
    global USE_XFORMERS_OPS
    if USE_XFORMERS_OPS is not None:
        return USE_XFORMERS_OPS

    if current_platform.is_cuda() and current_platform.has_device_capability(100):
        # Xformers FA is not compatible with B200
        USE_XFORMERS_OPS = False
    else:
        from importlib.util import find_spec

        try:
            # find_spec() returns None when the submodule cannot be found; it
            # only raises (ModuleNotFoundError, an ImportError) when importing
            # the parent package `xformers` fails. Both cases mean "missing".
            USE_XFORMERS_OPS = find_spec("xformers.ops") is not None
        except ImportError:
            USE_XFORMERS_OPS = False

    # the warning only needs to be shown once
    if not USE_XFORMERS_OPS:
        logger.warning("Xformers is not available, falling back.")

    return USE_XFORMERS_OPS

69

70
def check_upstream_fa_availability(dtype: torch.dtype):
    """Report whether the upstream (non-vLLM) flash-attention can be used.

    On CUDA this requires a half-precision ``dtype`` (float16 or bfloat16),
    a device passing ``has_device_capability(80)``, and transformers'
    ``is_flash_attn_2_available()`` to return True. On ROCm, regardless of
    ``dtype``, it only requires the ``flash_attn`` package to be locatable.
    Any other platform yields False.
    """
    cuda_fa_capable = (
        dtype in (torch.float16, torch.bfloat16)
        and current_platform.is_cuda()
        and current_platform.has_device_capability(80)
    )
    if cuda_fa_capable:
        from transformers.utils import is_flash_attn_2_available

        return is_flash_attn_2_available()

    if not current_platform.is_rocm():
        return False

    from importlib.util import find_spec

    return find_spec("flash_attn") is not None


86
def maybe_get_vit_flash_attn_backend(
    attn_backend: _Backend, use_upstream_fa: bool
) -> tuple[_Backend, Optional[Callable]]:
    """Possibly upgrade a ViT backend to flash attention and resolve its kernel.

    Args:
        attn_backend: The backend selected so far.
        use_upstream_fa: Whether the upstream ``flash_attn`` package should be
            used instead of vLLM's bundled ``vllm.vllm_flash_attn``. It is
            forced to True when this function upgrades the backend, and on
            ROCm whenever the backend is ``FLASH_ATTN``.

    Returns:
        ``(attn_backend, flash_attn_varlen_func)``: the possibly-upgraded
        backend and the varlen kernel to call for it. The kernel is ``None``
        for any backend other than ``FLASH_ATTN`` / ``ROCM_AITER_FA``.
    """
    # Backends other than the two flash-attention ones are upgraded to
    # upstream FLASH_ATTN when it is usable for the default (model) dtype.
    if attn_backend not in (
        _Backend.FLASH_ATTN,
        _Backend.ROCM_AITER_FA,
    ) and check_upstream_fa_availability(torch.get_default_dtype()):
        attn_backend = _Backend.FLASH_ATTN
        use_upstream_fa = True

    if current_platform.is_rocm() and attn_backend == _Backend.FLASH_ATTN:
        use_upstream_fa = True

    flash_attn_varlen_func: Optional[Callable] = None
    if attn_backend == _Backend.ROCM_AITER_FA:
        from aiter import flash_attn_varlen_func
    elif attn_backend == _Backend.FLASH_ATTN:
        if use_upstream_fa:
            from flash_attn import flash_attn_varlen_func
        else:
            from vllm.vllm_flash_attn import flash_attn_varlen_func

    return attn_backend, flash_attn_varlen_func


114
class Attention(nn.Module, AttentionLayerBase):
    """Attention layer.

    The layer accepts query, key and value tensors, which may hold prompt
    tokens or generation tokens, and:

    1. stores the incoming key and value tensors in the KV cache,
    2. performs multi-head / multi-query / grouped-query attention,
    3. returns the attention output.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[list[float]] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        logits_soft_cap: Optional[float] = None,
        per_layer_sliding_window: Optional[int] = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        attn_backend: Optional[type[AttentionBackend]] = None,
        **extra_impl_args,
    ) -> None:
        """Create the layer and instantiate its backend implementation.

        The KV cache is kept on the module itself and accessed through
        `self.kv_cache`.
        """
        super().__init__()

        # A per-layer sliding window wins over the model-level one carried by
        # the cache config.
        if per_layer_sliding_window is not None:
            sliding_window = per_layer_sliding_window
        else:
            sliding_window = (
                None if cache_config is None else cache_config.sliding_window
            )

        # KV-cache settings; without a cache config fall back to the "auto"
        # dtype, 16-token blocks and no dynamic scale calculation.
        if cache_config is None:
            kv_cache_dtype, block_size, calculate_kv_scales = "auto", 16, False
        else:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            calculate_kv_scales = cache_config.calculate_kv_scales

        if num_kv_heads is None:
            num_kv_heads = num_heads
        assert num_heads % num_kv_heads == 0, (
            f"num_heads ({num_heads}) is not divisible by num_kv_heads ({num_kv_heads})"
        )

        # Device-side scales, all defaulting to 1.0. They are ignored unless
        # the kv-cache is fp8; the 1.0 default is meant for fp8_e5m2, whereas
        # fp8_e4m3 expects pre-quantized k/v scales loaded with the weights.
        self.kv_cache_dtype = kv_cache_dtype
        self.calculate_kv_scales = calculate_kv_scales
        self._k_scale = torch.tensor(1.0, dtype=torch.float32)
        self._v_scale = torch.tensor(1.0, dtype=torch.float32)
        # FlashAttn cannot quantize only the kv-cache; q must be quantized too.
        self._q_scale = torch.tensor(1.0, dtype=torch.float32)
        self._prob_scale = torch.tensor(1.0, dtype=torch.float32)

        # Host (CPU) copies of the q/k/v scales for backends that need the
        # scales on the host instead of on device, e.g. Flashinfer.
        self._q_scale_float = 1.0
        self._k_scale_float = 1.0
        self._v_scale_float = 1.0
        # Host-side output scale: the input scale of the quant op that follows
        # this attention layer, if any.
        self._o_scale_float: Optional[float] = None

        self.num_heads = num_heads
        self.head_size = head_size
        self.num_kv_heads = num_kv_heads
        self.sliding_window = sliding_window
        self.has_sink = extra_impl_args.get("sinks") is not None

        quant_method = (
            quant_config.get_quant_method(self, prefix=prefix) if quant_config else None
        )
        kv_cache_is_quantized = quant_method is not None and not isinstance(
            quant_method, UnquantizedLinearMethod
        )
        if kv_cache_is_quantized:
            assert isinstance(quant_method, BaseKVCacheMethod)
            # TODO (mgoin): kv cache dtype should be specified in the FP8
            # checkpoint config and become the "auto" behavior
            if self.kv_cache_dtype == "fp8_e5m2":
                raise ValueError(
                    "fp8_e5m2 kv-cache is not supported with fp8 checkpoints."
                )
            # With quantization enabled, "k_scale"/"v_scale" become parameters
            # so they can be loaded from the checkpoint; they are converted
            # back to native float32 values after weight loading.
            self.quant_method = quant_method
            self.quant_method.create_weights(self)

        # While the model is being built, the default dtype is the model's
        # weight/activation dtype.
        dtype = torch.get_default_dtype()
        if attn_backend is None:
            attn_backend = get_attn_backend(
                head_size,
                dtype,
                kv_cache_dtype,
                block_size,
                use_mla=False,
                has_sink=self.has_sink,
            )
        self.attn_backend = attn_backend

        backend_impl_cls = self.attn_backend.get_impl_cls()
        self.impl = backend_impl_cls(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            sliding_window,
            kv_cache_dtype,
            logits_soft_cap,
            attn_type,
            kv_sharing_target_layer_name,
            **extra_impl_args,
        )
        self.backend = backend_name_to_enum(self.attn_backend.get_name())
        self.dtype = dtype

        # Platforms reporting opaque_attention_op() (CUDA-alike and CPU) wrap
        # attention in one opaque custom op for torch.compile; the others call
        # the implementation directly and let torch.compile trace it.
        self.use_direct_call = not current_platform.opaque_attention_op()
        self.use_output = self.attn_backend.accept_output_buffer

        # Register this layer under its unique prefix in the static forward
        # context.
        compilation_config = get_current_vllm_config().compilation_config
        forward_ctx_layers = compilation_config.static_forward_context
        if prefix in forward_ctx_layers:
            raise ValueError(f"Duplicate layer name: {prefix}")
        forward_ctx_layers[prefix] = self
        self.layer_name = prefix
        self.attn_type = attn_type

        if kv_sharing_target_layer_name is not None:
            validate_kv_sharing_target(
                prefix,
                kv_sharing_target_layer_name,
                compilation_config.static_forward_context,
            )
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        # One empty placeholder per pipeline-parallel stage, replaced later by
        # bind_kv_cache. Not accessed when use_direct_call is True.
        pp_size = get_current_vllm_config().parallel_config.pipeline_parallel_size
        self.kv_cache = [torch.tensor([]) for _ in range(pp_size)]

        try:
            self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
            self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
            self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)
        except torch.cuda.OutOfMemoryError as err:
            logger.error("Failed to initialize attention q/k/v range constants: %s", err)
            if torch.cuda.is_available():
                logger.debug("CUDA device: %s", torch.cuda.current_device())
                logger.debug(
                    "Allocated: %.2f GiB", torch.cuda.memory_allocated() / GiB_bytes
                )
                logger.debug(
                    "Reserved: %.2f GiB", torch.cuda.memory_reserved() / GiB_bytes
                )
            raise RuntimeError(
                "Failed to initialize q/k/v range constants. "
                "This may be caused by insufficient memory to allocate "
                "kv cache."
            ) from err

        # Static per-tensor fp8 query quantization, only for backends that
        # accept a quantized query and only with an fp8 kv-cache.
        self.query_quant = None
        if (
            self.kv_cache_dtype.startswith("fp8")
            and self.attn_backend.supports_quant_query_input
        ):
            self.query_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)

    def _metadata_and_kv_cache(self):
        """Return (attn_metadata for this layer, KV cache of the active engine)."""
        forward_context: ForwardContext = get_forward_context()
        attn_metadata = forward_context.attn_metadata
        if isinstance(attn_metadata, dict):
            attn_metadata = attn_metadata[self.layer_name]
        return attn_metadata, self.kv_cache[forward_context.virtual_engine]

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        # Some backends (e.g. MLA) produce an output whose shape differs from
        # the query's, so the model may pass the expected output shape.
        output_shape: Optional[torch.Size] = None,
    ) -> torch.Tensor:
        """Run attention for the current forward pass.

        The KV cache lives in `self.kv_cache`. The attention metadata
        (`attn_metadata`) is installed by a context manager in the model
        runner's `execute_model` and read through
        `vllm.forward_context.get_forward_context().attn_metadata`.
        """
        if self.calculate_kv_scales:
            torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name)

        # Remember the dtype before a possible fp8 query quantization.
        output_dtype = query.dtype
        if self.query_quant is not None:
            # Quantizing with a plain torch op lets torch.compile fuse it into
            # the preceding ops, avoiding the decode overhead a custom
            # quantization op would add.
            assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}
            query, _ = self.query_quant(query, self._q_scale)

        if not self.use_output:
            if not self.use_direct_call:
                return torch.ops.vllm.unified_attention(
                    query, key, value, self.layer_name
                )
            attn_metadata, kv_cache = self._metadata_and_kv_cache()
            return self.impl.forward(self, query, key, value, kv_cache, attn_metadata)

        if output_shape is None:
            output_shape = query.shape
        output = torch.zeros(output_shape, dtype=output_dtype, device=query.device)
        hidden_size = output_shape[-1]

        # Reshape to (tokens, heads, head_size) outside the custom op.
        # NOTE(woosuk): this keeps CPU overhead out of the non-CUDA-graph part.
        query = query.view(-1, self.num_heads, self.head_size)
        output = output.view(-1, self.num_heads, self.head_size)
        if key is not None:
            key = key.view(-1, self.num_kv_heads, self.head_size)
        if value is not None:
            value = value.view(-1, self.num_kv_heads, self.head_size)

        if self.use_direct_call:
            attn_metadata, kv_cache = self._metadata_and_kv_cache()
            self.impl.forward(
                self, query, key, value, kv_cache, attn_metadata, output=output
            )
        else:
            torch.ops.vllm.unified_attention_with_output(
                query, key, value, output, self.layer_name
            )
        return output.view(-1, hidden_size)

    def calc_kv_scales(self, query, key, value):
        """Calibrate q/k/v scales once from the absolute maxima of the inputs.

        Each device scale becomes ``abs(x).max() / range``; the host copies
        are refreshed and further calculation is disabled.
        """
        for scale, tensor, value_range in (
            (self._q_scale, query, self.q_range),
            (self._k_scale, key, self.k_range),
            (self._v_scale, value, self.v_range),
        ):
            scale.copy_(torch.abs(tensor).max() / value_range)
        self._q_scale_float = self._q_scale.item()
        self._k_scale_float = self._k_scale.item()
        self._v_scale_float = self._v_scale.item()
        # Scales are only ever calculated once.
        self.calculate_kv_scales = False

    def extra_repr(self) -> str:
        """Summarize head geometry, scale and backend impl class."""
        parts = [
            f"head_size={self.impl.head_size}",  # type: ignore
            f"num_heads={self.impl.num_heads}",  # type: ignore
            f"num_kv_heads={self.impl.num_kv_heads}",  # type: ignore
            f"scale={self.impl.scale}",  # type: ignore
            f"backend={self.impl.__class__.__name__}",
        ]
        return ", ".join(parts)

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Forward post-load processing to the impl and fix FlashInfer sinks."""
        if hasattr(self.impl, "process_weights_after_loading"):
            self.impl.process_weights_after_loading(act_dtype)

        if self.backend != _Backend.FLASHINFER or not hasattr(self.impl, "sinks"):
            return

        # FlashInfer requires attention sinks to be float32.
        from vllm.v1.attention.backends.flashinfer import FlashInferImpl

        assert isinstance(self.impl, FlashInferImpl)
        sinks = self.impl.sinks
        if sinks is not None and sinks.dtype != torch.float32:
            self.impl.sinks = sinks.to(torch.float32)

    def get_attn_backend(self) -> type[AttentionBackend]:
        """Return the attention backend class chosen for this layer."""
        return self.attn_backend

421

422
423
424
425
426
427
428
429
430
class MultiHeadAttention(nn.Module):
    """Multi-headed attention without any cache, used for ViT."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        # Only stored as `self.layer_name`; it exists so that Attention and
        # MultiHeadAttention can be swapped without changing call sites.
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.layer_name = prefix

        assert self.num_heads % self.num_kv_heads == 0, (
            f"num_heads ({self.num_heads}) is not "
            f"divisible by num_kv_heads ({self.num_kv_heads})"
        )
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()

        # Determine the attention backend
        backend = get_vit_attn_backend(head_size=head_size, dtype=dtype)

        # Some auto-selected backends can be upgraded
        # to upstream flash attention if available.
        # If vllm native fa is selected, we use it directly.
        use_upstream_fa = False

        if current_platform.is_xpu():
            # currently, only torch_sdpa is supported on xpu
            self.attn_backend = _Backend.TORCH_SDPA
        else:
            self.attn_backend = (
                backend
                if backend
                in {
                    _Backend.TORCH_SDPA,
                    _Backend.XFORMERS,
                    _Backend.PALLAS,
                    _Backend.ROCM_AITER_FA,
                    _Backend.FLASH_ATTN,
                }
                else _Backend.TORCH_SDPA
            )

        # Backend before any flash-attention upgrade; used below to tell
        # whether the upgrade (which implies upstream flash-attn) happened.
        selected_backend = self.attn_backend
        self.attn_backend, self._flash_attn_varlen_func = (
            maybe_get_vit_flash_attn_backend(
                self.attn_backend,
                use_upstream_fa,
            )
        )

        if self.attn_backend == _Backend.XFORMERS and not check_xformers_availability():
            self.attn_backend = _Backend.TORCH_SDPA

        self.is_flash_attn_backend = self.attn_backend in {
            _Backend.FLASH_ATTN,
            _Backend.ROCM_AITER_FA,
        }

        # maybe_get_vit_flash_attn_backend() does not report whether it chose
        # the upstream flash-attn kernel. Re-derive it so the log is correct:
        # it does when it upgraded a non-FLASH_ATTN backend to FLASH_ATTN,
        # and always for FLASH_ATTN on ROCm.
        if self.attn_backend == _Backend.FLASH_ATTN and (
            selected_backend != _Backend.FLASH_ATTN or current_platform.is_rocm()
        ):
            use_upstream_fa = True

        logger.info_once(
            f"MultiHeadAttention attn_backend: {self.attn_backend}, "
            f"use_upstream_fa: {use_upstream_fa}"
        )

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """Input shape:
        (batch_size x seq_len x hidden_size) or
        (batch_size x seq_len x num_heads x head_size)

        Returns a tensor of shape (batch_size x q_len x num_heads * head_size).
        """
        bsz, q_len = query.size()[:2]
        kv_len = key.size(1)

        query = query.view(bsz, q_len, self.num_heads, self.head_size)
        key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
        value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)

        if (num_repeat := self.num_queries_per_kv) > 1:
            # Handle MQA and GQA
            key = torch.repeat_interleave(key, num_repeat, dim=2)
            value = torch.repeat_interleave(value, num_repeat, dim=2)

        if self.is_flash_attn_backend:
            # Every sequence in the batch has the same length, so the
            # cumulative offsets are an arithmetic progression.
            cu_seqlens_q = torch.arange(
                0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=query.device
            )
            cu_seqlens_k = torch.arange(
                0, (bsz + 1) * kv_len, step=kv_len, dtype=torch.int32, device=key.device
            )

            out = self._flash_attn_varlen_func(
                query.flatten(0, 1),
                key.flatten(0, 1),
                value.flatten(0, 1),
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=q_len,
                max_seqlen_k=kv_len,
                softmax_scale=self.scale,
            )
        elif self.attn_backend == _Backend.XFORMERS:
            from xformers import ops as xops

            out = xops.memory_efficient_attention_forward(
                query, key, value, scale=self.scale
            )
        elif self.attn_backend == _Backend.TORCH_SDPA:
            query, key, value = (x.transpose(1, 2) for x in (query, key, value))
            out = F.scaled_dot_product_attention(query, key, value, scale=self.scale)
            out = out.transpose(1, 2)
        elif self.attn_backend == _Backend.PALLAS:
            query, key, value = (x.transpose(1, 2) for x in (query, key, value))
            from torch_xla.experimental.custom_kernel import flash_attention

            out = flash_attention(query, key, value, sm_scale=self.scale)
            out = out.transpose(1, 2)
        else:
            # ViT attention hasn't supported this backend yet
            raise NotImplementedError(
                f"ViT attention hasn't supported {self.attn_backend} backend yet."
            )

        return out.reshape(bsz, q_len, -1)
565
566


567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
class MLAAttention(nn.Module, AttentionLayerBase):
    """Multi-Head Latent Attention layer.

    This class takes query, and compressed key/value tensors as input.
    The class does the following:

    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        scale: float,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: Optional[int],
        kv_lora_rank: int,
        kv_b_proj: ColumnParallelLinear,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_sparse: bool = False,
        indexer: Optional[object] = None,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.scale = scale
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        # The cached entry per token is the latent plus the rope part.
        self.head_size = kv_lora_rank + qk_rope_head_dim
        self.layer_name = prefix

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            calculate_kv_scales = cache_config.calculate_kv_scales
        else:
            kv_cache_dtype = "auto"
            block_size = 16
            calculate_kv_scales = False
        self.kv_cache_dtype = kv_cache_dtype

        dtype = torch.get_default_dtype()
        self.attn_backend = get_attn_backend(
            self.head_size,
            dtype,
            kv_cache_dtype,
            block_size,
            use_mla=True,
            use_sparse=use_sparse,
        )
        impl_cls = cast(type[MLAAttentionImpl], self.attn_backend.get_impl_cls())
        self.impl = impl_cls(
            num_heads=self.num_heads,
            head_size=self.head_size,
            scale=self.scale,
            num_kv_heads=1,
            alibi_slopes=None,
            sliding_window=None,
            kv_cache_dtype=self.kv_cache_dtype,
            logits_soft_cap=None,
            attn_type=AttentionType.DECODER,
            kv_sharing_target_layer_name=None,
            # MLA Args
            q_lora_rank=self.q_lora_rank,
            kv_lora_rank=self.kv_lora_rank,
            qk_nope_head_dim=self.qk_nope_head_dim,
            qk_rope_head_dim=self.qk_rope_head_dim,
            qk_head_dim=self.qk_nope_head_dim + self.qk_rope_head_dim,
            v_head_dim=self.v_head_dim,
            kv_b_proj=kv_b_proj,
            indexer=indexer,
        )

        self.use_direct_call = not current_platform.opaque_attention_op()

        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self

        # One empty placeholder per pipeline-parallel stage.
        self.kv_cache = [
            torch.tensor([])
            for _ in range(
                get_current_vllm_config().parallel_config.pipeline_parallel_size
            )
        ]

        # Align with Attention's scale attributes for MLA backends.
        self.calculate_kv_scales = calculate_kv_scales
        self._k_scale = torch.tensor(1.0, dtype=torch.float32)
        self._v_scale = torch.tensor(1.0, dtype=torch.float32)
        self._q_scale = torch.tensor(1.0, dtype=torch.float32)
        self._prob_scale = torch.tensor(1.0, dtype=torch.float32)

        # Host-side mirrors used by some attention backends
        self._q_scale_float = 1.0
        self._k_scale_float = 1.0
        self._v_scale_float = 1.0
        self._o_scale_float: Optional[float] = None

        self.use_sparse = use_sparse

        # Initialize q/k/v range constants.
        try:
            self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
            self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
            self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)
        except torch.cuda.OutOfMemoryError as e:
            # Best effort: calc_kv_scales() falls back to a range of 1.0 when
            # these attributes are missing, so initialization continues, but
            # the failure is no longer silent.
            logger.warning(
                "Failed to initialize MLA q/k/v range constants, "
                "falling back to 1.0: %s",
                e,
            )

    def _metadata_and_kv_cache(self):
        """Return (attn_metadata for this layer, KV cache of the active engine)."""
        forward_context: ForwardContext = get_forward_context()
        attn_metadata = forward_context.attn_metadata
        if isinstance(attn_metadata, dict):
            attn_metadata = attn_metadata[self.layer_name]
        return attn_metadata, self.kv_cache[forward_context.virtual_engine]

    def forward(
        self,
        q: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        output_shape: Optional[torch.Size] = None,
    ) -> torch.Tensor:
        """Run MLA for the current forward pass.

        Args:
            q: Query tensor.
            kv_c_normed: Compressed (latent) KV representation.
            k_pe: Rotary-embedded key part.
            output_shape: Shape of the output buffer for backends that accept
                one. Defaults to ``(q.shape[0], num_heads * v_head_dim)``;
                this assumes the first dim of ``q`` is the token count —
                NOTE(review): confirm against callers that pass other shapes.
        """
        # Same path as Attention.forward: the custom op reads the forward
        # context and only calibrates when the metadata enables it. Doing it
        # here covers every dispatch path, including the opaque-op path with
        # an output buffer, which previously never calculated the scales.
        if self.calculate_kv_scales:
            torch.ops.vllm.maybe_calc_kv_scales(q, kv_c_normed, k_pe, self.layer_name)

        if self.attn_backend.accept_output_buffer:
            if output_shape is None:
                output_shape = torch.Size(
                    (q.shape[0], self.num_heads * self.v_head_dim)
                )
            output = torch.zeros(output_shape, dtype=q.dtype, device=q.device)
            if self.use_direct_call:
                attn_metadata, self_kv_cache = self._metadata_and_kv_cache()
                self.impl.forward(
                    self,
                    q,
                    kv_c_normed,
                    k_pe,
                    self_kv_cache,
                    attn_metadata,
                    output=output,
                )
            else:
                torch.ops.vllm.unified_mla_attention_with_output(
                    q,
                    kv_c_normed,
                    k_pe,
                    output,
                    self.layer_name,
                )
            return output

        if self.use_direct_call:
            attn_metadata, self_kv_cache = self._metadata_and_kv_cache()
            return self.impl.forward(
                self, q, kv_c_normed, k_pe, self_kv_cache, attn_metadata
            )
        return torch.ops.vllm.unified_mla_attention(
            q,
            kv_c_normed,
            k_pe,
            self.layer_name,
        )

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Forward post-load weight processing to the impl, if it has any."""
        if hasattr(self.impl, "process_weights_after_loading"):
            self.impl.process_weights_after_loading(act_dtype)

    def calc_kv_scales(
        self, q: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor
    ) -> None:
        """Optional one-shot scale calculation for MLA inputs.

        Mirrors Attention.calc_kv_scales. Not all MLA backends require this.
        Both the k and v scales are derived from ``kv_c_normed``; ``k_pe`` is
        accepted for signature compatibility but not used.
        """
        # Use safe defaults if ranges are not present
        q_range = getattr(self, "q_range", torch.tensor(1.0))
        k_range = getattr(self, "k_range", torch.tensor(1.0))
        v_range = getattr(self, "v_range", torch.tensor(1.0))

        self._q_scale.copy_(torch.abs(q).max() / q_range)
        # kv_c_normed is the compressed KV representation; use it for k/v
        kv_abs_max = torch.abs(kv_c_normed).max()
        self._k_scale.copy_(kv_abs_max / k_range)
        self._v_scale.copy_(kv_abs_max / v_range)
        self._q_scale_float = self._q_scale.item()
        self._k_scale_float = self._k_scale.item()
        self._v_scale_float = self._v_scale.item()
        self.calculate_kv_scales = False

    def get_attn_backend(self) -> type[AttentionBackend]:
        """Return the MLA attention backend class chosen for this layer."""
        return self.attn_backend


def wait_for_kv_layer_from_connector(layer_name: str):
    """Wait on the v1 KV connector for ``layer_name`` before attention runs.

    This is a no-op in two cases: no v1 KV transfer group is configured, or
    the current forward context carries no attention metadata. Otherwise the
    metadata must be a per-layer dict, and the call is delegated to the
    connector's ``wait_for_layer_load``.

    Args:
        layer_name: Name of the attention layer whose KV data is awaited.
    """
    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
        return

    connector = get_kv_transfer_group()

    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if attn_metadata is None:
        return
    # The connector path expects per-layer (dict) attention metadata.
    assert isinstance(attn_metadata, dict)
    connector.wait_for_layer_load(layer_name)


def maybe_save_kv_layer_to_connector(
    layer_name: str,
    kv_cache_layer: list[torch.Tensor],
):
    """Hand this layer's KV cache to the v1 KV connector, if one is active.

    This is a no-op in two cases: no v1 KV transfer group is configured, or
    the current forward context carries no attention metadata. Otherwise the
    metadata must be a per-layer dict, and ``connector.save_kv_layer`` is
    called with this layer's entry.

    Args:
        layer_name: Name of the attention layer being saved.
        kv_cache_layer: The layer's KV cache, passed through to the connector.
    """
    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
        return

    connector = get_kv_transfer_group()

    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if attn_metadata is None:
        return
    # The connector path expects per-layer (dict) attention metadata.
    assert isinstance(attn_metadata, dict)
    connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata[layer_name])
def maybe_calc_kv_scales(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> None:
    """Run the layer's ``calc_kv_scales`` when the current step asks for it.

    Resolves this layer's attention metadata from the forward context; it may
    be a per-layer dict keyed by ``layer_name``. Returns early if the metadata
    is missing or its ``enable_kv_scales_calculation`` flag is absent or
    false. Otherwise it looks the layer up in
    ``forward_context.no_compile_layers`` and calls
    ``calc_kv_scales(query, key, value)`` on it.

    Args:
        query: Query tensor forwarded to ``calc_kv_scales``.
        key: Key tensor forwarded to ``calc_kv_scales``.
        value: Value tensor forwarded to ``calc_kv_scales``.
        layer_name: Name of the layer in the forward context.
    """
    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata

    if isinstance(attn_metadata, dict):
        attn_metadata = attn_metadata[layer_name]

    if attn_metadata is None or not getattr(
        attn_metadata, "enable_kv_scales_calculation", False
    ):
        return

    self = forward_context.no_compile_layers[layer_name]
    self.calc_kv_scales(query, key, value)


def maybe_calc_kv_scales_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> None:
    """Fake implementation paired with ``maybe_calc_kv_scales``.

    It ignores all of its arguments and returns ``None``.
    """
    return None


# Register maybe_calc_kv_scales (with its no-op fake) as a custom op.
direct_register_custom_op(
    op_name="maybe_calc_kv_scales",
    op_func=maybe_calc_kv_scales,
    fake_impl=maybe_calc_kv_scales_fake,
    # NOTE(review): all three inputs are declared as mutated; presumably this
    # keeps the op from being treated as side-effect free -- confirm.
    mutates_args=["query", "key", "value"],
)


def unified_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Eager implementation of the ``unified_attention`` custom op.

    The steps are:

    1. Call ``wait_for_kv_layer_from_connector``.
    2. Resolve this layer's attention metadata; it may be a per-layer dict
       keyed by ``layer_name``.
    3. Fetch the layer from ``forward_context.no_compile_layers`` and run its
       ``impl.forward`` on the KV cache slot for the active virtual engine.
    4. Call ``maybe_save_kv_layer_to_connector``.

    Args:
        query: Query tensor.
        key: Key tensor.
        value: Value tensor.
        layer_name: Name of the attention layer in the forward context.

    Returns:
        Whatever the backend ``impl.forward`` returns.
    """
    wait_for_kv_layer_from_connector(layer_name)

    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if isinstance(attn_metadata, dict):
        attn_metadata = attn_metadata[layer_name]
    self = forward_context.no_compile_layers[layer_name]
    kv_cache = self.kv_cache[forward_context.virtual_engine]
    output = self.impl.forward(self, query, key, value, kv_cache, attn_metadata)

    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
    return output


def unified_attention_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Fake implementation paired with ``unified_attention``.

    Returns an uninitialised, contiguous tensor with the same shape, dtype and
    device as ``query``. ``key``, ``value`` and ``layer_name`` are ignored.
    """
    placeholder = torch.empty_like(query)
    return placeholder.contiguous()


# Register unified_attention (with its shape-only fake) as a custom op,
# tagged with ``tag_cudagraph_unsafe`` (empty when the torch build lacks it).
direct_register_custom_op(
    op_name="unified_attention",
    op_func=unified_attention,
    fake_impl=unified_attention_fake,
    tags=tag_cudagraph_unsafe,
)


def unified_attention_with_output(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: Optional[torch.Tensor] = None,
    output_block_scale: Optional[torch.Tensor] = None,
) -> None:
    """Eager implementation of the ``unified_attention_with_output`` op.

    Same flow as ``unified_attention``, but the caller supplies ``output``,
    which is passed to the backend ``impl.forward`` together with the
    optional output scales. The backend's return value is discarded and this
    function returns ``None``; the backend is expected to write its result
    into ``output``.

    Args:
        query: Query tensor.
        key: Key tensor.
        value: Value tensor.
        output: Preallocated output buffer handed to the backend.
        layer_name: Name of the attention layer in the forward context.
        output_scale: Optional tensor passed through to the backend.
        output_block_scale: Optional tensor passed through to the backend.
    """
    wait_for_kv_layer_from_connector(layer_name)
    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if isinstance(attn_metadata, dict):
        attn_metadata = attn_metadata[layer_name]
    self = forward_context.no_compile_layers[layer_name]
    kv_cache = self.kv_cache[forward_context.virtual_engine]
    self.impl.forward(
        self,
        query,
        key,
        value,
        kv_cache,
        attn_metadata,
        output=output,
        output_scale=output_scale,
        output_block_scale=output_block_scale,
    )

    maybe_save_kv_layer_to_connector(layer_name, kv_cache)

def unified_attention_with_output_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: Optional[torch.Tensor] = None,
    output_block_scale: Optional[torch.Tensor] = None,
) -> None:
    """Fake implementation paired with ``unified_attention_with_output``.

    The real op returns ``None``, so the fake ignores all of its arguments
    and also returns ``None``.
    """
    return


# Register unified_attention_with_output as a custom op. ``output`` and
# ``output_block_scale`` are declared as mutated because the backend writes
# into caller-provided buffers.
direct_register_custom_op(
    op_name="unified_attention_with_output",
    op_func=unified_attention_with_output,
    mutates_args=["output", "output_block_scale"],
    fake_impl=unified_attention_with_output_fake,
    tags=tag_cudagraph_unsafe,
)


def unified_mla_attention(
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Eager implementation of the ``unified_mla_attention`` custom op.

    The MLA inputs are the query, the normalised compressed KV and the
    positional key part. They are handed to the layer's backend
    ``impl.forward`` together with the KV cache slot for the active virtual
    engine, and the backend result is returned. KV-connector load and save
    hooks run before and after the backend call.

    Args:
        q: Query tensor.
        kv_c_normed: Compressed KV input.
        k_pe: Positional key input.
        layer_name: Name of the MLA layer in the forward context.

    Returns:
        Whatever the backend ``impl.forward`` returns.
    """
    wait_for_kv_layer_from_connector(layer_name)

    ctx: ForwardContext = get_forward_context()
    metadata = ctx.attn_metadata
    metadata = metadata[layer_name] if isinstance(metadata, dict) else metadata
    layer: MLAAttention = ctx.no_compile_layers[layer_name]
    cache = layer.kv_cache[ctx.virtual_engine]
    result = layer.impl.forward(layer, q, kv_c_normed, k_pe, cache, metadata)

    maybe_save_kv_layer_to_connector(layer_name, cache)
    return result


def unified_mla_attention_fake(
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Fake implementation paired with ``unified_mla_attention``.

    Returns an uninitialised, contiguous tensor with the same shape, dtype and
    device as ``q``. The remaining arguments are ignored.
    """
    placeholder = torch.empty_like(q)
    return placeholder.contiguous()


# Register unified_mla_attention (reached via torch.ops.vllm) as a custom op
# that mutates none of its inputs, dispatched on the current platform's key.
direct_register_custom_op(
    op_name="unified_mla_attention",
    op_func=unified_mla_attention,
    fake_impl=unified_mla_attention_fake,
    mutates_args=[],
    dispatch_key=current_platform.dispatch_key,
)


def unified_mla_attention_with_output(
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: Optional[torch.Tensor] = None,
    output_block_scale: Optional[torch.Tensor] = None,
) -> None:
    """Eager implementation of ``unified_mla_attention_with_output``.

    Same flow as ``unified_mla_attention``, but the caller supplies
    ``output``, which is forwarded to the backend ``impl.forward`` together
    with the optional output scales. The backend's return value is discarded
    and this function returns ``None``.

    Args:
        q: Query tensor.
        kv_c_normed: Compressed KV input.
        k_pe: Positional key input.
        output: Preallocated output buffer handed to the backend.
        layer_name: Name of the MLA layer in the forward context.
        output_scale: Optional tensor passed through to the backend.
        output_block_scale: Optional tensor passed through to the backend.
    """
    wait_for_kv_layer_from_connector(layer_name)
    ctx: ForwardContext = get_forward_context()
    metadata = ctx.attn_metadata
    metadata = metadata[layer_name] if isinstance(metadata, dict) else metadata
    layer: MLAAttention = ctx.no_compile_layers[layer_name]
    cache = layer.kv_cache[ctx.virtual_engine]
    layer.impl.forward(
        layer,
        q,
        kv_c_normed,
        k_pe,
        cache,
        metadata,
        output=output,
        output_scale=output_scale,
        output_block_scale=output_block_scale,
    )

    maybe_save_kv_layer_to_connector(layer_name, cache)


def unified_mla_attention_with_output_fake(
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: Optional[torch.Tensor] = None,
    output_block_scale: Optional[torch.Tensor] = None,
) -> None:
    """Fake implementation paired with ``unified_mla_attention_with_output``.

    It ignores all of its arguments and returns ``None``, matching the real
    op's return.
    """
    return None


# Register unified_mla_attention_with_output as a custom op. ``output`` and
# ``output_block_scale`` are declared as mutated because the backend writes
# into caller-provided buffers.
direct_register_custom_op(
    op_name="unified_mla_attention_with_output",
    op_func=unified_mla_attention_with_output,
    fake_impl=unified_mla_attention_with_output_fake,
    mutates_args=["output", "output_block_scale"],
    dispatch_key=current_platform.dispatch_key,
)