# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer."""

5
6
from collections.abc import Callable
from typing import cast
7
8
9

import torch
import torch.nn as nn
10
import torch.nn.functional as F
11

12
import vllm.envs as envs
13
from vllm.attention import AttentionType
14
from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl
15
16
from vllm.attention.backends.registry import _Backend, backend_name_to_enum
from vllm.attention.selector import get_attn_backend
17
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
18
from vllm.config import CacheConfig, get_current_vllm_config
19
20
21
22
23
from vllm.distributed.kv_transfer import (
    get_kv_transfer_group,
    has_kv_transfer_group,
    is_v1_kv_transfer_group,
)
24
from vllm.forward_context import ForwardContext, get_forward_context
25
from vllm.logger import init_logger
26
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
27
28
29
30
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    UnquantizedLinearMethod,
)
31
from vllm.model_executor.layers.quantization import QuantizationConfig
32
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
33
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
34
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
35
from vllm.model_executor.models.vision import get_vit_attn_backend
36
from vllm.platforms import current_platform
37
from vllm.utils import direct_register_custom_op
38

39
# FP8 dtype reported by the current platform.
FP8_DTYPE = current_platform.fp8_dtype()
40
41
42
43
44
45
46
47
48
logger = init_logger(__name__)
# Cached answer of check_xformers_availability(); stays None until the first
# call has probed the platform.
USE_XFORMERS_OPS = None


def check_xformers_availability():
    """Return whether ``xformers.ops`` can be used, caching the answer.

    The result is stored in the module-level ``USE_XFORMERS_OPS`` so the
    probe, and the fallback warning, run at most once per process.

    Returns:
        ``True`` when ``xformers.ops`` is importable on a supported platform,
        ``False`` otherwise.
    """
    global USE_XFORMERS_OPS
    if USE_XFORMERS_OPS is not None:
        return USE_XFORMERS_OPS

    if current_platform.is_cuda() and current_platform.has_device_capability(100):
        # Xformers FA is not compatible with B200
        USE_XFORMERS_OPS = False
    else:
        from importlib.util import find_spec

        try:
            # find_spec() returns None when the parent package exists but the
            # submodule does not, and raises ModuleNotFoundError (an
            # ImportError) when the parent package itself is missing. Both
            # cases must count as "not available".
            USE_XFORMERS_OPS = find_spec("xformers.ops") is not None
        except ImportError:
            USE_XFORMERS_OPS = False

    # the warning only needs to be shown once
    if not USE_XFORMERS_OPS:
        logger.warning("Xformers is not available, falling back.")

    return USE_XFORMERS_OPS

67

68
def check_upstream_fa_availability(dtype: torch.dtype):
    """Report whether the upstream ``flash_attn`` package can be used.

    On CUDA devices with capability >= 80 and a float16/bfloat16 ``dtype``
    the answer comes from ``transformers``' ``is_flash_attn_2_available``.
    On ROCm it is whether a ``flash_attn`` module spec can be found. Every
    other combination yields ``False``.
    """
    half_precision_on_cuda = (
        dtype in (torch.float16, torch.bfloat16)
        and current_platform.is_cuda()
        and current_platform.has_device_capability(80)
    )
    if half_precision_on_cuda:
        from transformers.utils import is_flash_attn_2_available

        return is_flash_attn_2_available()

    if not current_platform.is_rocm():
        return False

    from importlib.util import find_spec

    return find_spec("flash_attn") is not None


84
def maybe_get_vit_flash_attn_backend(
    attn_backend: _Backend, use_upstream_fa: bool
) -> tuple[_Backend, Callable | None]:
    """Resolve the ViT attention backend and its varlen flash-attention kernel.

    A backend that is neither ``FLASH_ATTN`` nor ``ROCM_AITER_FA`` is upgraded
    to ``FLASH_ATTN`` (using the upstream ``flash_attn`` package) when
    ``check_upstream_fa_availability`` succeeds for the current default dtype.

    Args:
        attn_backend: Backend selected by the caller.
        use_upstream_fa: Whether the upstream ``flash_attn`` package should be
            preferred over ``vllm.vllm_flash_attn``. It is forced to ``True``
            after an upgrade and for ``FLASH_ATTN`` on ROCm.

    Returns:
        ``(backend, flash_attn_varlen_func)``. The function is ``None`` when
        the resolved backend is not a flash-attention backend.
    """
    flash_backends = {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}

    if attn_backend not in flash_backends and check_upstream_fa_availability(
        torch.get_default_dtype()
    ):
        attn_backend = _Backend.FLASH_ATTN
        use_upstream_fa = True

    if current_platform.is_rocm() and attn_backend == _Backend.FLASH_ATTN:
        use_upstream_fa = True

    if attn_backend == _Backend.ROCM_AITER_FA:
        from aiter import flash_attn_varlen_func
    elif attn_backend == _Backend.FLASH_ATTN:
        if use_upstream_fa:
            from flash_attn import flash_attn_varlen_func
        else:
            from vllm.vllm_flash_attn import flash_attn_varlen_func
    else:
        # No flash-attention kernel applies to this backend.
        flash_attn_varlen_func = None

    return attn_backend, flash_attn_varlen_func


112
class Attention(nn.Module, AttentionLayerBase):
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
    """Attention layer.

    This class takes query, key, and value tensors as input. The input tensors
    can either contain prompt tokens or generation tokens.
    The class does the following:

    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
        alibi_slopes: list[float] | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        logits_soft_cap: float | None = None,
        per_layer_sliding_window: int | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
        attn_backend: type[AttentionBackend] | None = None,
        **extra_impl_args,
    ) -> None:
        """Create the backend implementation and register this layer.

        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each head.
            scale: Scale value forwarded to the backend implementation.
            num_kv_heads: Number of KV heads; defaults to ``num_heads``.
            alibi_slopes: Forwarded to the backend implementation.
            cache_config: Source of the KV-cache dtype, block size,
                ``calculate_kv_scales`` flag and model-level sliding window.
                When ``None``, ``"auto"``, ``16``, ``False`` and no window
                are used.
            quant_config: When it yields a KV-cache quantization method, that
                method creates the scale weights on this layer.
            logits_soft_cap: Forwarded to the backend implementation.
            per_layer_sliding_window: Overrides the model-level window.
            prefix: Layer name; must be unique in the static forward context.
            attn_type: Forwarded to the backend implementation.
            kv_sharing_target_layer_name: Layer whose KV cache is shared; it
                is validated against the static forward context when given.
            attn_backend: Backend class to use; selected through
                ``get_attn_backend`` when ``None``.
            **extra_impl_args: Forwarded to the backend implementation. A
                non-``None`` ``"sinks"`` entry sets ``has_sink``.

        Raises:
            ValueError: If ``prefix`` is already registered, or if an
                ``fp8_e5m2`` KV cache is combined with a KV-cache quantization
                method.
        """
        super().__init__()

        # A per-layer window takes precedence over the model-level one.
        if per_layer_sliding_window is not None:
            sliding_window = per_layer_sliding_window
        else:
            sliding_window = (
                cache_config.sliding_window if cache_config is not None else None
            )

        if cache_config is None:
            kv_cache_dtype, block_size, calculate_kv_scales = "auto", 16, False
        else:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            calculate_kv_scales = cache_config.calculate_kv_scales

        num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        assert num_heads % num_kv_heads == 0, (
            f"num_heads ({num_heads}) is not divisible by num_kv_heads ({num_kv_heads})"
        )

        # The default k/v_scale is set to 1.0. This is ignored
        # when kv-cache is not fp8, and should be used with
        # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
        # expect the pre-quantized k/v_scale to be loaded along
        # with the model weights.
        self.kv_cache_dtype = kv_cache_dtype
        self.calculate_kv_scales = calculate_kv_scales
        self._k_scale = torch.tensor(1.0, dtype=torch.float32)
        self._v_scale = torch.tensor(1.0, dtype=torch.float32)
        # FlashAttn doesn't support quantizing the kv-cache only
        # but requires q to be quantized as well.
        self._q_scale = torch.tensor(1.0, dtype=torch.float32)
        self._prob_scale = torch.tensor(1.0, dtype=torch.float32)

        # We also keep q/k/v_scale on host (cpu) memory for attention
        # backends that require the scales to be on host instead of on device.
        # e.g. Flashinfer
        self._q_scale_float = 1.0
        self._k_scale_float = 1.0
        self._v_scale_float = 1.0

        # The output scale on host memory. This should be the input scale of
        # the quant op after this attention layer.
        self._o_scale_float: float | None = None

        self.num_heads = num_heads
        self.head_size = head_size
        self.num_kv_heads = num_kv_heads
        self.sliding_window = sliding_window
        self.has_sink = extra_impl_args.get("sinks") is not None

        quant_method = None
        if quant_config:
            quant_method = quant_config.get_quant_method(self, prefix=prefix)
        uses_kv_quant_method = quant_method is not None and not isinstance(
            quant_method, UnquantizedLinearMethod
        )
        if uses_kv_quant_method:
            assert isinstance(quant_method, BaseKVCacheMethod)
            # TODO (mgoin): kv cache dtype should be specified in the FP8
            # checkpoint config and become the "auto" behavior
            if self.kv_cache_dtype == "fp8_e5m2":
                raise ValueError(
                    "fp8_e5m2 kv-cache is not supported with fp8 checkpoints."
                )
            # If quantization is enabled, we make "k_scale" and "v_scale"
            # parameters so that it can be loaded from the model checkpoint.
            # The k/v_scale will then be converted back to native float32
            # values after weight loading.
            self.quant_method = quant_method
            self.quant_method.create_weights(self)

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()
        if attn_backend is not None:
            self.attn_backend = attn_backend
        else:
            self.attn_backend = get_attn_backend(
                head_size,
                dtype,
                kv_cache_dtype,
                block_size,
                use_mla=False,
                has_sink=self.has_sink,
            )

        impl_cls = self.attn_backend.get_impl_cls()
        self.impl = impl_cls(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            sliding_window,
            kv_cache_dtype,
            logits_soft_cap,
            attn_type,
            kv_sharing_target_layer_name,
            **extra_impl_args,
        )
        self.backend = backend_name_to_enum(self.attn_backend.get_name())
        self.dtype = dtype

        # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
        # torch.compile works by registering the attention as one giant
        # opaque custom op. For other platforms, we directly call them
        # and let torch.compile handle them.
        self.use_direct_call = not current_platform.opaque_attention_op()

        self.use_output = self.attn_backend.accept_output_buffer

        # Register under a unique name so the layer can be looked up later.
        compilation_config = get_current_vllm_config().compilation_config
        registered_layers = compilation_config.static_forward_context
        if prefix in registered_layers:
            raise ValueError(f"Duplicate layer name: {prefix}")
        registered_layers[prefix] = self
        self.layer_name = prefix
        self.attn_type = attn_type

        if kv_sharing_target_layer_name is not None:
            validate_kv_sharing_target(
                prefix,
                kv_sharing_target_layer_name,
                compilation_config.static_forward_context,
            )
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        # use a placeholder kv cache tensor during init, which will be replaced
        # by bind_kv_cache
        # this variable will not be accessed if use_direct_call is True
        pp_size = get_current_vllm_config().parallel_config.pipeline_parallel_size
        self.kv_cache = [torch.tensor([]) for _ in range(pp_size)]

        # Initialize q/k/v range constants.
        self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
        self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
        self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)

        # for attn backends supporting query quantization
        self.query_quant = None
        wants_query_quant = (
            self.kv_cache_dtype.startswith("fp8")
            and self.impl.supports_quant_query_input()
        )
        if wants_query_quant:
            self.query_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)
296

297
298
299
300
301
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        # For some alternate attention backends like MLA the attention output
        # shape does not match the query shape, so we optionally let the model
        # definition specify the output tensor shape.
        output_shape: torch.Size | None = None,
    ) -> torch.Tensor:
        """Run attention through the backend implementation or the custom op.

        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        Attention metadata (`attn_metadata`) is set using a context manager in
        the model runner's `execute_model` method. It is accessed via forward
        context using
        `vllm.forward_context.get_forward_context().attn_metadata`.
        """
        layer_name = self.layer_name
        if self.calculate_kv_scales:
            torch.ops.vllm.maybe_calc_kv_scales(query, key, value, layer_name)

        # Captured before any query quantization so the output buffer keeps
        # the dtype of the incoming query.
        output_dtype = query.dtype
        if self.query_quant is not None:
            # quantizing with a simple torch operation enables
            # torch.compile to fuse this into previous ops
            # which reduces overheads during decoding.
            # Otherwise queries are quantized using custom ops
            # which causes decoding overheads
            assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}

            # check if query quantization is supported
            if self.impl.supports_quant_query_input():
                query, _ = self.query_quant(query, self._q_scale)

        if not self.use_output:
            # The backend allocates and returns its own output.
            if not self.use_direct_call:
                return torch.ops.vllm.unified_attention(query, key, value, layer_name)

            forward_context = get_forward_context()
            attn_metadata = forward_context.attn_metadata
            if isinstance(attn_metadata, dict):
                attn_metadata = attn_metadata[layer_name]
            self_kv_cache = self.kv_cache[forward_context.virtual_engine]
            return self.impl.forward(
                self, query, key, value, self_kv_cache, attn_metadata
            )

        # The backend writes into a caller-provided buffer.
        if output_shape is None:
            output_shape = query.shape
        output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
        hidden_size = output_shape[-1]

        # Reshape the query, key, and value tensors.
        # NOTE(woosuk): We do this outside the custom op to minimize the
        # CPU overheads from the non-CUDA-graph regions.
        query = query.view(-1, self.num_heads, self.head_size)
        output = output.view(-1, self.num_heads, self.head_size)
        if key is not None:
            key = key.view(-1, self.num_kv_heads, self.head_size)
        if value is not None:
            value = value.view(-1, self.num_kv_heads, self.head_size)

        if self.use_direct_call:
            forward_context: ForwardContext = get_forward_context()
            attn_metadata = forward_context.attn_metadata
            if isinstance(attn_metadata, dict):
                attn_metadata = attn_metadata[layer_name]
            self_kv_cache = self.kv_cache[forward_context.virtual_engine]
            self.impl.forward(
                self, query, key, value, self_kv_cache, attn_metadata, output=output
            )
        else:
            torch.ops.vllm.unified_attention_with_output(
                query, key, value, output, layer_name
            )
        return output.view(-1, hidden_size)
372

373
374
    def calc_kv_scales(self, query, key, value):
        """Derive q/k/v scales from the given tensors and disable recalculation.

        Each scale becomes the tensor's maximum absolute value divided by the
        matching range constant. The device-side scale tensors are updated in
        place and mirrored into the host-side float attributes.
        """
        scale_inputs = (
            (self._q_scale, query, self.q_range),
            (self._k_scale, key, self.k_range),
            (self._v_scale, value, self.v_range),
        )
        for scale, tensor, value_range in scale_inputs:
            scale.copy_(torch.abs(tensor).max() / value_range)

        self._q_scale_float = self._q_scale.item()
        self._k_scale_float = self._k_scale.item()
        self._v_scale_float = self._v_scale.item()

        # We only calculate the scales once
        self.calculate_kv_scales = False

383
384
385
386
387
    def extra_repr(self) -> str:
        """Summarize the backend implementation's configuration for ``repr``."""
        impl = self.impl
        parts = (
            f"head_size={impl.head_size}",  # type: ignore
            f"num_heads={impl.num_heads}",  # type: ignore
            f"num_kv_heads={impl.num_kv_heads}",  # type: ignore
            f"scale={impl.scale}",  # type: ignore
            f"backend={impl.__class__.__name__}",
        )
        return ", ".join(parts)
390

391
    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Forward the post-weight-loading hook to the backend implementation."""
        backend_impl = self.impl
        backend_impl.process_weights_after_loading(act_dtype)
393

394
395
396
    def get_attn_backend(self) -> type[AttentionBackend]:
        """Return the attention backend class stored on this layer."""
        return self.attn_backend

397

398
399
400
401
402
403
404
405
class MultiHeadAttention(nn.Module):
    """Multi-headed attention without any cache, used for ViT."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
        # This has no effect, it is only here to make it easier to swap
        # between Attention and MultiHeadAttention
        prefix: str = "",
    ) -> None:
        """Pick the ViT attention backend and its flash-attention kernel.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each head.
            scale: Softmax scale passed to the attention kernels.
            num_kv_heads: Number of KV heads; defaults to ``num_heads``.
            prefix: Stored as ``layer_name``; not otherwise used here.
        """
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.layer_name = prefix

        assert self.num_heads % self.num_kv_heads == 0, (
            f"num_heads ({self.num_heads}) is not "
            f"divisible by num_kv_heads ({self.num_kv_heads})"
        )
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()

        # Determine the attention backend
        backend = get_vit_attn_backend(head_size=head_size, dtype=dtype)

        # Some auto-selected backends can be upgraded
        # to upstream flash attention if available.
        # If vllm native fa is selected, we use it directly.
        use_upstream_fa = False

        flash_backends = {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}
        supported_backends = {
            _Backend.TORCH_SDPA,
            _Backend.XFORMERS,
            _Backend.PALLAS,
            _Backend.ROCM_AITER_FA,
            _Backend.FLASH_ATTN,
        }

        if current_platform.is_xpu():
            # currently, only torch_sdpa is supported on xpu
            requested_backend = _Backend.TORCH_SDPA
        elif backend in supported_backends:
            requested_backend = backend
        else:
            requested_backend = _Backend.TORCH_SDPA

        self.attn_backend, self._flash_attn_varlen_func = (
            maybe_get_vit_flash_attn_backend(
                requested_backend,
                use_upstream_fa,
            )
        )

        if self.attn_backend == _Backend.XFORMERS and not check_xformers_availability():
            self.attn_backend = _Backend.TORCH_SDPA

        self.is_flash_attn_backend = self.attn_backend in flash_backends

        # maybe_get_vit_flash_attn_backend() does not return the flag it
        # resolved, so reconstruct it here to keep the log below accurate:
        # upstream flash-attn is used when a non-flash backend was upgraded
        # to FLASH_ATTN, and always for FLASH_ATTN on ROCm.
        upgraded_to_upstream_fa = (
            requested_backend not in flash_backends
            and self.attn_backend == _Backend.FLASH_ATTN
        )
        if upgraded_to_upstream_fa or (
            current_platform.is_rocm() and self.attn_backend == _Backend.FLASH_ATTN
        ):
            use_upstream_fa = True

        logger.info_once(
            f"MultiHeadAttention attn_backend: {self.attn_backend}, "
            f"use_upstream_fa: {use_upstream_fa}"
        )
477

478
479
480
481
482
483
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """Compute cache-free attention with the backend chosen at init time.

        Input shape:
        (batch_size x seq_len x hidden_size) or
        (batch_size x seq_len x num_heads x head_size)

        The result is reshaped to (batch_size x seq_len x -1).
        """
        batch, q_len = query.shape[:2]
        kv_len = key.shape[1]

        query = query.view(batch, q_len, self.num_heads, self.head_size)
        key = key.view(batch, kv_len, self.num_kv_heads, self.head_size)
        value = value.view(batch, kv_len, self.num_kv_heads, self.head_size)

        # Handle MQA and GQA
        group_size = self.num_queries_per_kv
        if group_size > 1:
            key = torch.repeat_interleave(key, group_size, dim=2)
            value = torch.repeat_interleave(value, group_size, dim=2)

        if self.is_flash_attn_backend:
            # Every sequence in the batch has the same length, so the
            # cumulative sequence lengths are evenly spaced offsets.
            cu_seqlens_q = torch.arange(
                0,
                (batch + 1) * q_len,
                step=q_len,
                dtype=torch.int32,
                device=query.device,
            )
            cu_seqlens_k = torch.arange(
                0,
                (batch + 1) * kv_len,
                step=kv_len,
                dtype=torch.int32,
                device=key.device,
            )

            out = self._flash_attn_varlen_func(
                query.flatten(0, 1),
                key.flatten(0, 1),
                value.flatten(0, 1),
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=q_len,
                max_seqlen_k=kv_len,
                softmax_scale=self.scale,
            )
        elif self.attn_backend == _Backend.XFORMERS:
            from xformers import ops as xops

            out = xops.memory_efficient_attention_forward(
                query, key, value, scale=self.scale
            )
        elif self.attn_backend == _Backend.TORCH_SDPA:
            query = query.transpose(1, 2)
            key = key.transpose(1, 2)
            value = value.transpose(1, 2)
            out = F.scaled_dot_product_attention(query, key, value, scale=self.scale)
            out = out.transpose(1, 2)
        elif self.attn_backend == _Backend.PALLAS:
            query = query.transpose(1, 2)
            key = key.transpose(1, 2)
            value = value.transpose(1, 2)
            from torch_xla.experimental.custom_kernel import flash_attention

            out = flash_attention(query, key, value, sm_scale=self.scale)
            out = out.transpose(1, 2)
        else:
            # ViT attention hasn't supported this backend yet
            raise NotImplementedError(
                f"ViT attention hasn't supported {self.attn_backend} backend yet."
            )

        return out.reshape(batch, q_len, -1)
541
542


543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
class MLAAttention(nn.Module, AttentionLayerBase):
    """Multi-Head Latent Attention layer.

    This class takes query, and compressed key/value tensors as input.
    The class does the following:

    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        scale: float,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int | None,
        kv_lora_rank: int,
        kv_b_proj: ColumnParallelLinear,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_sparse: bool = False,
        indexer: object | None = None,
        **extra_impl_args,
    ):
        """Create the MLA backend implementation and register this layer.

        Args:
            num_heads: Number of attention heads.
            scale: Scale value forwarded to the backend implementation.
            qk_nope_head_dim: Forwarded to the implementation; also summed
                with ``qk_rope_head_dim`` to form ``qk_head_dim``.
            qk_rope_head_dim: Forwarded to the implementation; added to
                ``kv_lora_rank`` to form this layer's ``head_size``.
            v_head_dim: Forwarded to the implementation.
            q_lora_rank: Forwarded to the implementation.
            kv_lora_rank: Forwarded to the implementation; part of
                ``head_size``.
            kv_b_proj: Forwarded to the implementation.
            cache_config: Source of the KV-cache dtype, block size and
                ``calculate_kv_scales`` flag. When ``None``, ``"auto"``,
                ``16`` and ``False`` are used.
            quant_config: Accepted but not read by this constructor.
            prefix: Layer name; must be unique in the static forward context.
            use_sparse: Passed to backend selection and stored on the layer.
            indexer: Forwarded to the implementation.
            **extra_impl_args: Forwarded to the implementation.

        Raises:
            ValueError: If ``prefix`` is already registered.
        """
        super().__init__()
        self.num_heads = num_heads
        self.scale = scale
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.head_size = kv_lora_rank + qk_rope_head_dim
        self.layer_name = prefix

        if cache_config is None:
            kv_cache_dtype, block_size, calculate_kv_scales = "auto", 16, False
        else:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            calculate_kv_scales = cache_config.calculate_kv_scales
        self.kv_cache_dtype = kv_cache_dtype

        dtype = torch.get_default_dtype()
        self.attn_backend = get_attn_backend(
            self.head_size,
            dtype,
            kv_cache_dtype,
            block_size,
            use_mla=True,
            use_sparse=use_sparse,
        )
        impl_cls = cast(type[MLAAttentionImpl], self.attn_backend.get_impl_cls())
        qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
        self.impl = impl_cls(
            num_heads=self.num_heads,
            head_size=self.head_size,
            scale=self.scale,
            num_kv_heads=1,
            alibi_slopes=None,
            sliding_window=None,
            kv_cache_dtype=self.kv_cache_dtype,
            logits_soft_cap=None,
            attn_type=AttentionType.DECODER,
            kv_sharing_target_layer_name=None,
            # MLA Args
            q_lora_rank=self.q_lora_rank,
            kv_lora_rank=self.kv_lora_rank,
            qk_nope_head_dim=self.qk_nope_head_dim,
            qk_rope_head_dim=self.qk_rope_head_dim,
            qk_head_dim=qk_head_dim,
            v_head_dim=self.v_head_dim,
            kv_b_proj=kv_b_proj,
            indexer=indexer,
            **extra_impl_args,
        )

        self.use_direct_call = not current_platform.opaque_attention_op()

        # Register under a unique name so the layer can be looked up later.
        compilation_config = get_current_vllm_config().compilation_config
        registered_layers = compilation_config.static_forward_context
        if prefix in registered_layers:
            raise ValueError(f"Duplicate layer name: {prefix}")
        registered_layers[prefix] = self

        # One placeholder KV-cache tensor per pipeline-parallel rank.
        pp_size = get_current_vllm_config().parallel_config.pipeline_parallel_size
        self.kv_cache = [torch.tensor([]) for _ in range(pp_size)]

        # Align with Attention's scale attributes for MLA backends.
        self.calculate_kv_scales = calculate_kv_scales
        self._k_scale = torch.tensor(1.0, dtype=torch.float32)
        self._v_scale = torch.tensor(1.0, dtype=torch.float32)
        self._q_scale = torch.tensor(1.0, dtype=torch.float32)
        self._prob_scale = torch.tensor(1.0, dtype=torch.float32)

        # Host-side mirrors used by some attention backends
        self._q_scale_float = 1.0
        self._k_scale_float = 1.0
        self._v_scale_float = 1.0
        self._o_scale_float: float | None = None

        self.use_sparse = use_sparse

        # Initialize q/k/v range constants.
        self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
        self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
        self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)
659
660
661
662
663
664

    def forward(
        self,
        q: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        output_shape: torch.Size | None = None,
    ) -> torch.Tensor:
        """Run MLA attention through the backend implementation or custom op.

        Args:
            q: Query tensor.
            kv_c_normed: Compressed KV tensor.
            k_pe: Tensor passed to the backend alongside ``kv_c_normed``.
            output_shape: Shape of the output buffer. It is passed straight to
                ``torch.empty`` when the backend accepts an output buffer and
                is otherwise unused.

        Returns:
            The preallocated output buffer when the backend accepts one,
            otherwise whatever the backend (or the custom op) returns.
        """
        # Calculate the scales once, up front, for every execution path. The
        # op itself checks the per-layer metadata flag, which matches how
        # Attention.forward triggers the calculation. Previously the opaque
        # path with an output buffer never calculated the scales.
        if self.calculate_kv_scales:
            torch.ops.vllm.maybe_calc_kv_scales(q, kv_c_normed, k_pe, self.layer_name)

        use_output = self.attn_backend.accept_output_buffer

        if self.use_direct_call:
            forward_context: ForwardContext = get_forward_context()
            attn_metadata = forward_context.attn_metadata
            if isinstance(attn_metadata, dict):
                attn_metadata = attn_metadata[self.layer_name]
            self_kv_cache = self.kv_cache[forward_context.virtual_engine]

            if not use_output:
                return self.impl.forward(
                    self, q, kv_c_normed, k_pe, self_kv_cache, attn_metadata
                )

            output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
            self.impl.forward(
                self,
                q,
                kv_c_normed,
                k_pe,
                self_kv_cache,
                attn_metadata,
                output=output,
            )
            return output

        if not use_output:
            return torch.ops.vllm.unified_mla_attention(
                q,
                kv_c_normed,
                k_pe,
                self.layer_name,
            )

        output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
        torch.ops.vllm.unified_mla_attention_with_output(
            q,
            kv_c_normed,
            k_pe,
            output,
            self.layer_name,
        )
        return output

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Forward the post-weight-loading hook when the implementation has one."""
        if not hasattr(self.impl, "process_weights_after_loading"):
            return
        self.impl.process_weights_after_loading(act_dtype)

    def calc_kv_scales(
        self, q: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor
    ) -> None:
        """Optional scale calculation for MLA inputs.

        Mirrors Attention.calc_kv_scales. Not all MLA backends require this.
        The q scale comes from ``q``; both the k and v scales come from
        ``kv_c_normed``. ``k_pe`` is accepted but not read.
        """
        # Use safe defaults if ranges are not present
        unit_range = torch.tensor(1.0)
        q_range = getattr(self, "q_range", unit_range)
        k_range = getattr(self, "k_range", unit_range)
        v_range = getattr(self, "v_range", unit_range)

        q_abs_max = torch.abs(q).max()
        self._q_scale.copy_(q_abs_max / q_range)

        # kv_c_normed is the compressed KV representation; use it for k/v
        kv_abs_max = torch.abs(kv_c_normed).max()
        for scale, value_range in ((self._k_scale, k_range), (self._v_scale, v_range)):
            scale.copy_(kv_abs_max / value_range)

        self._q_scale_float = self._q_scale.item()
        self._k_scale_float = self._k_scale.item()
        self._v_scale_float = self._v_scale.item()
        # Scales are calculated only once.
        self.calculate_kv_scales = False

    def get_attn_backend(self) -> type[AttentionBackend]:
        """Return the attention backend class stored on this layer."""
        return self.attn_backend


753
754
755
756
757
758
759
760
761
762
def wait_for_kv_layer_from_connector(layer_name: str):
    """Call ``wait_for_layer_load`` on the KV connector for ``layer_name``.

    Returns without doing anything when there is no KV transfer group, when
    the group is not a v1 group, or when the forward context carries no
    attention metadata.

    Args:
        layer_name: Name of the attention layer passed to the connector.
    """
    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
        return

    connector = get_kv_transfer_group()

    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if attn_metadata is None:
        return
    # On this path the metadata is expected to be a per-layer dict.
    assert isinstance(attn_metadata, dict)
    connector.wait_for_layer_load(layer_name)


def maybe_save_kv_layer_to_connector(
    layer_name: str,
    kv_cache_layer: list[torch.Tensor],
):
    """Hand a layer's KV cache to the KV connector via ``save_kv_layer``.

    Returns without doing anything when there is no KV transfer group, when
    the group is not a v1 group, or when the forward context carries no
    attention metadata.

    Args:
        layer_name: Name of the attention layer; also the key used to look
            up that layer's entry in the attention metadata dict.
        kv_cache_layer: KV cache object forwarded to the connector.
    """
    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
        return

    connector = get_kv_transfer_group()

    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if attn_metadata is None:
        return
    # On this path the metadata is expected to be a per-layer dict.
    assert isinstance(attn_metadata, dict)
    connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata[layer_name])


def maybe_calc_kv_scales(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> None:
    """Run the layer's ``calc_kv_scales`` when the metadata asks for it.

    The attention metadata is read from the forward context; when it is a
    dict, the entry for ``layer_name`` is used. Nothing happens if that
    metadata is ``None`` or if its ``enable_kv_scales_calculation``
    attribute is missing or falsy.

    Args:
        query: Forwarded to ``calc_kv_scales``.
        key: Forwarded to ``calc_kv_scales``.
        value: Forwarded to ``calc_kv_scales``.
        layer_name: Key of the layer in ``no_compile_layers`` (and in the
            metadata dict, when the metadata is a dict).
    """
    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata

    if isinstance(attn_metadata, dict):
        attn_metadata = attn_metadata[layer_name]

    if attn_metadata is None or not getattr(
        attn_metadata, "enable_kv_scales_calculation", False
    ):
        return

    self = forward_context.no_compile_layers[layer_name]
    self.calc_kv_scales(query, key, value)


def maybe_calc_kv_scales_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> None:
    """Fake implementation of ``maybe_calc_kv_scales``; does nothing."""
    return None


# Register ``maybe_calc_kv_scales`` as a custom op, with
# ``maybe_calc_kv_scales_fake`` as its fake implementation. query/key/value
# are declared as mutated arguments.
# NOTE(review): presumably the mutation declaration keeps the op from being
# reordered or dropped during compilation -- confirm against
# direct_register_custom_op.
direct_register_custom_op(
    op_name="maybe_calc_kv_scales",
    op_func=maybe_calc_kv_scales,
    mutates_args=["query", "key", "value"],
    fake_impl=maybe_calc_kv_scales_fake,
)


def unified_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Run the attention layer registered under ``layer_name``.

    The layer is looked up in the forward context's ``no_compile_layers``.
    Its ``impl.forward`` is called with the layer's KV cache for the current
    virtual engine and that layer's attention metadata. The KV connector
    hooks are invoked before and after the forward call.

    Args:
        query: Query tensor forwarded to the implementation.
        key: Key tensor forwarded to the implementation.
        value: Value tensor forwarded to the implementation.
        layer_name: Key of the layer in the forward context.

    Returns:
        Whatever ``impl.forward`` returns.
    """
    wait_for_kv_layer_from_connector(layer_name)

    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    # Metadata may be shared or keyed per layer; pick this layer's entry.
    if isinstance(attn_metadata, dict):
        attn_metadata = attn_metadata[layer_name]
    self = forward_context.no_compile_layers[layer_name]
    kv_cache = self.kv_cache[forward_context.virtual_engine]
    output = self.impl.forward(self, query, key, value, kv_cache, attn_metadata)

    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
    return output


def unified_attention_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Fake implementation of ``unified_attention``.

    Returns an uninitialized, contiguous tensor created with
    ``torch.empty_like(query)``; the other arguments are ignored.
    """
    placeholder = torch.empty_like(query)
    return placeholder.contiguous()


# Register ``unified_attention`` as a custom op, with
# ``unified_attention_fake`` as its fake implementation. No mutated
# arguments are declared for this op.
direct_register_custom_op(
    op_name="unified_attention",
    op_func=unified_attention,
    fake_impl=unified_attention_fake,
)


def unified_attention_with_output(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
) -> None:
    """Run the attention layer under ``layer_name`` with a caller-owned output.

    Same flow as ``unified_attention``, except that ``output`` (together
    with the optional ``output_scale`` and ``output_block_scale``) is passed
    to ``impl.forward`` as keyword arguments and nothing is returned.

    Args:
        query: Query tensor forwarded to the implementation.
        key: Key tensor forwarded to the implementation.
        value: Value tensor forwarded to the implementation.
        output: Tensor passed to the implementation as ``output``.
        layer_name: Key of the layer in the forward context.
        output_scale: Optional tensor passed through as ``output_scale``.
        output_block_scale: Optional tensor passed through as
            ``output_block_scale``.
    """
    wait_for_kv_layer_from_connector(layer_name)
    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    # Metadata may be shared or keyed per layer; pick this layer's entry.
    if isinstance(attn_metadata, dict):
        attn_metadata = attn_metadata[layer_name]
    self = forward_context.no_compile_layers[layer_name]
    kv_cache = self.kv_cache[forward_context.virtual_engine]
    self.impl.forward(
        self,
        query,
        key,
        value,
        kv_cache,
        attn_metadata,
        output=output,
        output_scale=output_scale,
        output_block_scale=output_block_scale,
    )

    maybe_save_kv_layer_to_connector(layer_name, kv_cache)


def unified_attention_with_output_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
895
896
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
897
898
899
900
901
902
903
) -> None:
    return


# Register ``unified_attention_with_output`` as a custom op, with
# ``unified_attention_with_output_fake`` as its fake implementation.
# ``output`` and ``output_block_scale`` are declared as mutated arguments.
direct_register_custom_op(
    op_name="unified_attention_with_output",
    op_func=unified_attention_with_output,
    mutates_args=["output", "output_block_scale"],
    fake_impl=unified_attention_with_output_fake,
)


def unified_mla_attention(
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Run the MLA attention layer registered under ``layer_name``.

    Looks the layer up in the forward context's ``no_compile_layers`` and
    calls its ``impl.forward`` with the layer's KV cache for the current
    virtual engine and that layer's attention metadata. The KV connector
    hooks are invoked before and after the forward call.

    Returns:
        Whatever ``impl.forward`` returns.
    """
    wait_for_kv_layer_from_connector(layer_name)

    ctx: ForwardContext = get_forward_context()
    # Metadata may be shared or keyed per layer; pick this layer's entry.
    layer_metadata = ctx.attn_metadata
    if isinstance(layer_metadata, dict):
        layer_metadata = layer_metadata[layer_name]

    layer: MLAAttention = ctx.no_compile_layers[layer_name]
    cache = layer.kv_cache[ctx.virtual_engine]
    result = layer.impl.forward(
        layer, q, kv_c_normed, k_pe, cache, layer_metadata
    )

    maybe_save_kv_layer_to_connector(layer_name, cache)
    return result


def unified_mla_attention_fake(
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Fake implementation of ``unified_mla_attention``.

    Returns an uninitialized, contiguous tensor created with
    ``torch.empty_like(q)``; the other arguments are ignored.
    """
    placeholder = torch.empty_like(q)
    return placeholder.contiguous()


# Register ``unified_mla_attention`` as a custom op, with
# ``unified_mla_attention_fake`` as its fake implementation. The mutated
# argument list is explicitly empty and the dispatch key is taken from the
# current platform.
direct_register_custom_op(
    op_name="unified_mla_attention",
    op_func=unified_mla_attention,
    mutates_args=[],
    fake_impl=unified_mla_attention_fake,
    dispatch_key=current_platform.dispatch_key,
)


def unified_mla_attention_with_output(
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
) -> None:
    """Run the MLA layer under ``layer_name`` with a caller-owned output.

    Same flow as ``unified_mla_attention``, except that ``output`` (together
    with the optional ``output_scale`` and ``output_block_scale``) is passed
    to ``impl.forward`` as keyword arguments and nothing is returned.

    Args:
        q: Query tensor forwarded to the implementation.
        kv_c_normed: Forwarded to the implementation.
        k_pe: Forwarded to the implementation.
        output: Tensor passed to the implementation as ``output``.
        layer_name: Key of the layer in the forward context.
        output_scale: Optional tensor passed through as ``output_scale``.
        output_block_scale: Optional tensor passed through as
            ``output_block_scale``.
    """
    wait_for_kv_layer_from_connector(layer_name)
    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    # Metadata may be shared or keyed per layer; pick this layer's entry.
    if isinstance(attn_metadata, dict):
        attn_metadata = attn_metadata[layer_name]
    self: MLAAttention = forward_context.no_compile_layers[layer_name]
    kv_cache = self.kv_cache[forward_context.virtual_engine]
    self.impl.forward(
        self,
        q,
        kv_c_normed,
        k_pe,
        kv_cache,
        attn_metadata,
        output=output,
        output_scale=output_scale,
        output_block_scale=output_block_scale,
    )

    maybe_save_kv_layer_to_connector(layer_name, kv_cache)


def unified_mla_attention_with_output_fake(
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
984
985
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
986
987
988
989
990
991
992
993
994
995
996
) -> None:
    return


# Register ``unified_mla_attention_with_output`` as a custom op, with
# ``unified_mla_attention_with_output_fake`` as its fake implementation.
# ``output`` and ``output_block_scale`` are declared as mutated arguments and
# the dispatch key is taken from the current platform.
direct_register_custom_op(
    op_name="unified_mla_attention_with_output",
    op_func=unified_mla_attention_with_output,
    mutates_args=["output", "output_block_scale"],
    fake_impl=unified_mla_attention_with_output_fake,
    dispatch_key=current_platform.dispatch_key,
)