attention.py 28.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from typing import TYPE_CHECKING, Any
5
6
7
8

import torch
import torch.nn as nn

9
import vllm.envs as envs
10
from vllm.config import CacheConfig, get_current_vllm_config
11
from vllm.config.vllm import VllmConfig
12
from vllm.forward_context import ForwardContext, get_forward_context
13
from vllm.logger import init_logger
14
15
16
from vllm.model_executor.layers.attention.kv_transfer_utils import (
    maybe_transfer_kv_layer,
)
17
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
18
19
20
from vllm.model_executor.layers.linear import (
    UnquantizedLinearMethod,
)
21
from vllm.model_executor.layers.quantization import QuantizationConfig
22
from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
23
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
24
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
25
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
26
from vllm.platforms import current_platform
27
from vllm.utils.torch_utils import (
28
29
30
    direct_register_custom_op,
    kv_cache_dtype_str_to_dtype,
)
31
32
33
34
35
36
from vllm.v1.attention.backend import (
    AttentionBackend,
    AttentionType,
)
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.selector import get_attn_backend
37
38
39
40
from vllm.v1.kv_cache_interface import (
    FullAttentionSpec,
    KVCacheSpec,
    SlidingWindowSpec,
41
    get_kv_quant_mode,
42
)
43

44
45
46
if TYPE_CHECKING:
    from vllm.model_executor.layers.attention import MLAAttention

47
48
# Module-level logger; used below for per-layer KV cache dtype info and the
# prefix-caching / batch-invariance warning.
logger = init_logger(__name__)

49

50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
def validate_kv_sharing_target(
    current_layer_name, target_layer_name, static_forward_context
):
    """Check that ``target_layer_name`` may serve as the KV-sharing source.

    Args:
        current_layer_name: Name of the layer that wants to reuse a KV cache.
            Must already be registered in ``static_forward_context``.
        target_layer_name: Name of the layer whose KV cache would be reused.
        static_forward_context: Mapping from layer name to attention layer;
            each value must expose ``attn_type``.

    Raises:
        ValueError: If the target is the current layer itself, is not
            registered (either it does not precede the current layer or it is
            not an attention layer), or has a different ``attn_type``.
    """
    msg_prefix = (
        f"Specified KV sharing target layer for {current_layer_name} "
        f"is not valid: target layer {target_layer_name} "
    )

    if target_layer_name == current_layer_name:
        raise ValueError(msg_prefix + "cannot be the same as the current layer.")

    if target_layer_name not in static_forward_context:
        from vllm.model_executor.models.utils import extract_layer_index

        # Layers are registered in construction order, so an unregistered
        # target either does not come BEFORE the current layer or is not an
        # attention layer of this model at all.
        target_not_earlier = extract_layer_index(
            current_layer_name
        ) <= extract_layer_index(target_layer_name)
        reason = (
            "must come before the current layer."
            if target_not_earlier
            else "is not a valid Attention layer in the model."
        )
        raise ValueError(msg_prefix + reason)

    # KV sharing is currently only supported between layers of the same type.
    target_type = static_forward_context[target_layer_name].attn_type
    current_type = static_forward_context[current_layer_name].attn_type
    if target_type != current_type:
        raise ValueError(
            msg_prefix + f"must be the same type as the current layer ({current_type})."
        )


83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
def should_load_quant_weights(quant_method: QuantizeMethodBase | None) -> bool:
    """Tell whether ``quant_method`` expects quantization weights in the checkpoint.

    Returns False when there is no quant method at all, or when it is the
    plain ``UnquantizedLinearMethod``; True for any other quant method.
    """
    if quant_method is None:
        return False
    return not isinstance(quant_method, UnquantizedLinearMethod)


def set_default_quant_scales(layer: nn.Module, register_buffer: bool = False) -> None:
    """Reset the q/k/v/prob quantization scales of ``layer`` to 1.0.

    Args:
        layer: Attention layer whose scale attributes are (re)initialized.
        register_buffer: If True, create ``_k_scale``, ``_v_scale``,
            ``_q_scale`` and ``_prob_scale`` as float32 scalar buffers on
            ``layer``. If False, those buffers must already exist and are
            overwritten in place with 1.0.

    Besides the tensor scales, this also resets the host-side float copies
    (``_q/_k/_v/_prob_scale_float``) and recreates the ``q_range``,
    ``k_range`` and ``v_range`` tensors from ``envs.*_SCALE_CONSTANT``.
    """
    for buf_name in ("_k_scale", "_v_scale", "_q_scale", "_prob_scale"):
        if register_buffer:
            layer.register_buffer(buf_name, torch.tensor(1.0, dtype=torch.float32))
        else:
            # In-place fill keeps the already-registered buffer object.
            getattr(layer, buf_name).fill_(1.0)

    # Host (CPU) copies of the scales as Python floats, for attention backends
    # that need the scales on host instead of on device, e.g. FlashInfer.
    for float_name in (
        "_q_scale_float",
        "_k_scale_float",
        "_v_scale_float",
        "_prob_scale_float",
    ):
        setattr(layer, float_name, 1.0)

    # Divisors used by calc_kv_scales to turn an abs-max into a scale.
    layer.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
    layer.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
    layer.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)

116

117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
def _init_kv_cache_quant(
    layer: nn.Module,
    quant_config: QuantizationConfig | None,
    prefix: str,
) -> None:
    """Set up KV cache scale tensors and, if applicable, the KV quant method.

    Shared by Attention and MLAAttention layers. Registers q/k/v/prob scale
    buffers initialized to 1.0, clears the host-side output scale and, when
    ``quant_config`` yields a quant method that loads quantized weights,
    attaches it as ``layer.quant_method`` and lets it create its weights.

    Args:
        layer: The attention layer to initialize. Must already have a
            ``kv_cache_dtype`` attribute.
        quant_config: Optional quantization configuration.
        prefix: Layer name prefix used for the quant-method lookup.

    Raises:
        ValueError: If quantized weights are loaded and the KV cache dtype is
            ``fp8_e5m2``.
    """
    # Note: Register q/k/v/prob scales in state dict
    # model.to(device) only moves parameters/buffers that are in the state
    # dict. If the q/k/v/prob scales were not registered there, a CUDA kernel
    # (e.g. quant_fp8) would read them from CPU memory and hit an IMA error.
    # Being in the state dict means the scales interact with weight loading.
    # The edge case is quant_method being None or UnquantizedLinearMethod
    # (should_load_quant_weights(quant_method) == False): the checkpoint then
    # has no scales, so they are initialized to 1.0 here and reset again after
    # weight loading. This matters especially when dummy weights are loaded
    # first (setting wrong scales) and real weights second (which lack scales
    # and would otherwise keep the wrong dummy values).
    set_default_quant_scales(layer, register_buffer=True)

    # Output scale on host memory; meant to be the input scale of the quant
    # op that follows this attention layer. None means "not set".
    layer._o_scale_float = None

    quant_method = None
    if quant_config:
        quant_method = quant_config.get_quant_method(layer, prefix=prefix)

    # See [Note: Register q/k/v/prob scales in state dict]: nothing more to do
    # when the checkpoint carries no scales.
    if not should_load_quant_weights(quant_method):
        return

    assert isinstance(quant_method, BaseKVCacheMethod)
    # TODO (mgoin): kv cache dtype should be specified in the FP8
    # checkpoint config and become the "auto" behavior
    if layer.kv_cache_dtype == "fp8_e5m2":
        raise ValueError("fp8_e5m2 kv-cache is not supported with fp8 checkpoints.")
    # With quantization enabled, the quant method creates k/v scale
    # parameters so they can be loaded from the checkpoint; they are turned
    # back into native float32 values after weight loading.
    layer.quant_method = quant_method
    layer.quant_method.create_weights(layer)


173
class Attention(nn.Module, AttentionLayerBase):
    """Attention layer.

    This class takes query, key, and value tensors as input. The input tensors
    can either contain prompt tokens or generation tokens.
    The class does the following:

    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
        alibi_slopes: list[float] | None = None,
        use_alibi_sqrt: bool | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        logits_soft_cap: float | None = None,
        per_layer_sliding_window: int | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
        attn_backend: type[AttentionBackend] | None = None,
        head_size_v: int | None = None,
        **extra_impl_args,
    ) -> None:
        """
        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        Args:
            num_heads: Number of query heads.
            head_size: Per-head size of query and key.
            scale: Attention scale factor, forwarded to the backend impl.
            num_kv_heads: Number of KV heads; defaults to ``num_heads`` and
                must divide it.
            alibi_slopes: Optional ALiBi slopes, forwarded to the backend impl.
            use_alibi_sqrt: Request the ALiBi-sqrt variant; only allowed when
                the selected backend supports it.
            cache_config: Source of the model-level sliding window, KV cache
                dtype, kv-scale calculation and prefix-caching settings.
            quant_config: Optional quantization config, also used for KV cache
                quantization.
            logits_soft_cap: Forwarded to the backend impl.
            per_layer_sliding_window: Overrides the model-level sliding window.
            prefix: Unique layer name; key in the static forward context.
            attn_type: Attention type, used for backend selection and impl.
            kv_sharing_target_layer_name: Earlier layer whose KV cache this
                layer reuses, if any.
            attn_backend: Explicit backend class; auto-selected when None.
            head_size_v: Per-head size of value/output; defaults to
                ``head_size``.
            **extra_impl_args: Extra keyword arguments for the backend impl.

        Raises:
            ValueError: On a duplicate ``prefix``, an invalid KV sharing
                target, or ``use_alibi_sqrt`` on a backend without support.
        """
        super().__init__()

        if per_layer_sliding_window is not None:
            # per-layer sliding window
            sliding_window = per_layer_sliding_window
        elif cache_config is not None:
            # model-level sliding window
            sliding_window = cache_config.sliding_window
        else:
            sliding_window = None

        vllm_config = get_current_vllm_config()
        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            calculate_kv_scales = cache_config.calculate_kv_scales
        else:
            kv_cache_dtype = "auto"
            calculate_kv_scales = False

        # llm-compressor models need to set cache_dtype to "fp8" manually.
        kv_cache_scheme = getattr(quant_config, "kv_cache_scheme", None)
        if kv_cache_scheme is not None:
            kv_cache_dtype = "fp8"
            calculate_kv_scales = False
            if cache_config is not None:
                cache_config.cache_dtype = "fp8"
                cache_config.calculate_kv_scales = False

        # Check if per-head quant scales are required based on kv_cache_scheme
        use_per_head_quant_scales = (
            kv_cache_scheme is not None
            and kv_cache_scheme.get("strategy") == "attn_head"
        )

        # Skip KV cache quantization for specified layers. Entries of
        # kv_cache_dtype_skip_layers are either "sliding_window" (skip every
        # sliding-window layer) or a layer index given as a string.
        if cache_config is not None and cache_config.kv_cache_dtype_skip_layers:
            from vllm.model_executor.models.utils import extract_layer_index

            skip = False
            # Check attention type
            if (
                sliding_window is not None
                and "sliding_window" in cache_config.kv_cache_dtype_skip_layers
            ):
                skip = True
            # Check layer index
            layer_idx = extract_layer_index(prefix)
            if str(layer_idx) in cache_config.kv_cache_dtype_skip_layers:
                skip = True
            if skip:
                kv_cache_dtype = "auto"
                calculate_kv_scales = False
            logger.info(
                "Layer %s: kv_cache_dtype=%s, sliding_window=%s",
                prefix,
                kv_cache_dtype,
                sliding_window,
            )

        self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
            kv_cache_dtype, vllm_config.model_config
        )
        self.kv_cache_dtype = kv_cache_dtype
        self.calculate_kv_scales = calculate_kv_scales
        if num_kv_heads is None:
            num_kv_heads = num_heads
        assert num_heads % num_kv_heads == 0, (
            f"num_heads ({num_heads}) is not divisible by num_kv_heads ({num_kv_heads})"
        )
        self.quant_config = quant_config
        self.layer_name = prefix

        self.num_heads = num_heads
        self.head_size = head_size
        self.head_size_v = self.head_size if head_size_v is None else head_size_v
        self.num_kv_heads = num_kv_heads
        self.sliding_window = sliding_window
        self.has_sink = extra_impl_args.get("sinks") is not None

        # NOTE: model_config may be None during certain tests
        model_config = vllm_config.model_config
        self.use_mm_prefix = model_config is not None and model_config.is_mm_prefix_lm

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()
        if attn_backend is None:
            self.attn_backend = get_attn_backend(
                head_size,
                dtype,
                kv_cache_dtype,
                use_mla=False,
                has_sink=self.has_sink,
                use_mm_prefix=self.use_mm_prefix,
                use_per_head_quant_scales=use_per_head_quant_scales,
                attn_type=attn_type,
            )
        else:
            self.attn_backend = attn_backend

        # None and False both mean "not requested"; normalize once.
        use_alibi_sqrt = bool(use_alibi_sqrt)
        backend_supports_alibi_sqrt = self.attn_backend.supports_alibi_sqrt()
        if use_alibi_sqrt and not backend_supports_alibi_sqrt:
            raise ValueError(
                f"use_alibi_sqrt is not supported by backend "
                f"{self.attn_backend.get_name()}."
            )
        self.use_alibi_sqrt = use_alibi_sqrt
        # Only backends that support the option receive the kwarg.
        if backend_supports_alibi_sqrt:
            extra_impl_args["use_alibi_sqrt"] = self.use_alibi_sqrt

        # prefix caching + batch invariance is currently not supported for
        # FLASHINFER and TRITON_MLA.
        if (
            cache_config is not None
            and cache_config.enable_prefix_caching
            and envs.VLLM_BATCH_INVARIANT
            and (
                self.attn_backend.get_name() == "FLASHINFER"
                or self.attn_backend.get_name() == "TRITON_MLA"
            )
        ):
            logger.warning_once(
                "Disabling prefix caching for FLASHINFER/TRITON_MLA "
                "with batch invariance, as it is not yet supported.",
                scope="local",
            )
            cache_config.enable_prefix_caching = False

        impl_cls = self.attn_backend.get_impl_cls()
        self.impl = impl_cls(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            sliding_window,
            kv_cache_dtype,
            logits_soft_cap,
            attn_type,
            kv_sharing_target_layer_name,
            **extra_impl_args,
        )
        self.backend = AttentionBackendEnum[self.attn_backend.get_name()]
        self.dtype = dtype

        # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
        # torch.compile works by registering the attention as one giant
        # opaque custom op. For other platforms, we directly call them
        # and let torch.compile handle them.
        self.use_direct_call = not current_platform.opaque_attention_op()

        self.use_output = self.attn_backend.accept_output_buffer
        compilation_config = vllm_config.compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self
        # attn_type must be set before validate_kv_sharing_target, which reads
        # it back from the static forward context for this layer.
        self.attn_type = attn_type

        if kv_sharing_target_layer_name is not None:
            validate_kv_sharing_target(
                prefix,
                kv_sharing_target_layer_name,
                compilation_config.static_forward_context,
            )
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        # Placeholder KV cache used during init; it is replaced later by
        # bind_kv_cache. The custom-op functions read it via
        # get_attention_context.
        self.kv_cache = torch.tensor([])

        # Initialize KV cache quantization attributes
        _init_kv_cache_quant(self, quant_config, prefix)

        # for attn backends supporting query quantization
        self.query_quant = None
        if (
            self.impl.supports_quant_query_input
            and self.kv_cache_dtype.startswith("fp8")
            and not self.kv_cache_dtype.endswith("per_token_head")
        ):
            # Use one static q scale per KV head when q_scale has
            # num_kv_heads elements, otherwise a single per-tensor scale.
            is_per_head = (
                hasattr(self, "q_scale") and self.q_scale.numel() == self.num_kv_heads
            )
            block_size = self.head_size * self.num_heads // self.num_kv_heads
            self.query_quant = QuantFP8(
                static=True,
                group_shape=GroupShape(-1, block_size)
                if is_per_head
                else GroupShape.PER_TENSOR,
            )

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        # For some alternate attention backends like MLA the attention output
        # shape does not match the query shape, so we optionally let the model
        # definition specify the output tensor shape.
        output_shape: torch.Size | None = None,
    ) -> torch.Tensor:
        """
        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        Attention metadata (`attn_metadata`) is set using a context manager in
        the model runner's `execute_model` method. It is accessed via forward
        context using
        `vllm.forward_context.get_forward_context().attn_metadata`.

        Returns:
            With an output buffer (``self.use_output``): a tensor of shape
            ``(-1, output_shape[-1])`` in the query's original dtype.
            Otherwise: whatever the ``unified_attention`` op returns.
        """
        if self.calculate_kv_scales:
            torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name)

        # Record the dtype before query quantization may change it.
        output_dtype = query.dtype
        if self.query_quant is not None:
            # quantizing with a simple torch operation enables
            # torch.compile to fuse this into previous ops
            # which reduces overheads during decoding.
            # Otherwise queries are quantized using custom ops
            # which causes decoding overheads
            assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}

            # check if query quantization is supported
            if self.impl.supports_quant_query_input:
                query, _ = self.query_quant(query, self._q_scale)

        if not self.use_output:
            assert self.attn_backend.forward_includes_kv_cache_update, (
                "Split KV cache update not supported when output tensor not provided."
            )
            if self.use_direct_call:
                return unified_attention(query, key, value, self.layer_name)
            return torch.ops.vllm.unified_attention(
                query, key, value, self.layer_name
            )

        if output_shape is None:
            # Handle both 2D [num_tokens, hidden] and
            # 3D [num_tokens, heads, head_dim] query
            num_tokens = query.shape[0]
            output_shape = torch.Size(
                (num_tokens, self.num_heads * self.head_size_v)
            )
        output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
        hidden_size = output_shape[-1]
        # Reshape the query, key, and value tensors.
        # NOTE(woosuk): We do this outside the custom op to minimize the
        # CPU overheads from the non-CUDA-graph regions.
        query = query.view(-1, self.num_heads, self.head_size)
        output = output.view(-1, self.num_heads, self.head_size_v)
        if key is not None:
            key = key.view(-1, self.num_kv_heads, self.head_size)
        if value is not None:
            value = value.view(-1, self.num_kv_heads, self.head_size_v)

        # A separate KV cache write is needed unless the backend writes the
        # cache inside its forward, this layer shares an earlier layer's KV
        # cache, or key/value are absent.
        needs_kv_cache_update = (
            not self.attn_backend.forward_includes_kv_cache_update
            and self.kv_sharing_target_layer_name is None
            and key is not None
            and value is not None
        )
        # The dummy returned by the cache update is passed to the attention op
        # purely to express the ordering dependency between the two.
        kv_cache_dummy_dep = None
        if self.use_direct_call:
            if needs_kv_cache_update:
                kv_cache_dummy_dep = unified_kv_cache_update(
                    key, value, self.layer_name
                )
            unified_attention_with_output(
                query,
                key,
                value,
                output,
                self.layer_name,
                kv_cache_dummy_dep=kv_cache_dummy_dep,
            )
        else:
            if needs_kv_cache_update:
                kv_cache_dummy_dep = torch.ops.vllm.unified_kv_cache_update(
                    key, value, self.layer_name
                )
            torch.ops.vllm.unified_attention_with_output(
                query,
                key,
                value,
                output,
                self.layer_name,
                kv_cache_dummy_dep=kv_cache_dummy_dep,
            )
        return output.view(-1, hidden_size)

    def calc_kv_scales(self, query, key, value):
        """Derive q/k/v scales from the abs-max of the given tensors (once).

        Each scale is ``abs(x).max() / x_range``; the tensor scales are
        updated in place, mirrored to host floats, and further calculation is
        disabled by clearing ``calculate_kv_scales``.
        """
        self._q_scale.copy_(torch.abs(query).max() / self.q_range)
        self._k_scale.copy_(torch.abs(key).max() / self.k_range)
        self._v_scale.copy_(torch.abs(value).max() / self.v_range)
        self._q_scale_float = self._q_scale.item()
        self._k_scale_float = self._k_scale.item()
        self._v_scale_float = self._v_scale.item()
        # We only calculate the scales once
        self.calculate_kv_scales = False

    def extra_repr(self) -> str:
        """Summarize the backend impl's head geometry, scale and class name."""
        s = f"head_size={self.impl.head_size}"  # type: ignore
        s += f", num_heads={self.impl.num_heads}"  # type: ignore
        s += f", num_kv_heads={self.impl.num_kv_heads}"  # type: ignore
        s += f", scale={self.impl.scale}"  # type: ignore
        s += f", backend={self.impl.__class__.__name__}"
        return s

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Run the impl's post-load hook, then reset scales absent from the checkpoint."""
        self.impl.process_weights_after_loading(act_dtype)

        # If we should not load quant weights, we initialize the scales to 1.0
        # as the default value. See [Note: Register q/k/v/prob scales in state dict]
        # for more details.
        quant_method = (
            self.quant_config.get_quant_method(self, prefix=self.layer_name)
            if self.quant_config
            else None
        )
        if not should_load_quant_weights(quant_method):
            set_default_quant_scales(self, register_buffer=False)

    def get_attn_backend(self) -> type[AttentionBackend]:
        """Return the attention backend class selected for this layer."""
        return self.attn_backend

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        """Describe this layer's KV cache layout for the KV cache manager.

        Returns a SlidingWindowSpec when the layer has a sliding window and a
        FullAttentionSpec otherwise. Only valid for decoder attention.
        """
        # Block size may get updated after model loading, refresh it
        block_size = vllm_config.cache_config.block_size
        # Should not be called for enc-dec or encoder-only attention.
        assert self.attn_type == AttentionType.DECODER
        quant_mode = get_kv_quant_mode(self.kv_cache_dtype)
        if self.sliding_window is not None:
            assert not vllm_config.model_config.use_mla, (
                "MLA is not supported for slidingwindow"
            )
            return SlidingWindowSpec(
                block_size=block_size,
                num_kv_heads=self.num_kv_heads,
                head_size=self.head_size,
                dtype=self.kv_cache_torch_dtype,
                kv_quant_mode=quant_mode,
                sliding_window=self.sliding_window,
            )
        else:
            return FullAttentionSpec(
                block_size=block_size,
                num_kv_heads=self.num_kv_heads,
                head_size=self.head_size,
                head_size_v=self.head_size_v,
                dtype=self.kv_cache_torch_dtype,
                kv_quant_mode=quant_mode,
            )

565

566
567
568
569
570
571
572
def maybe_calc_kv_scales(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> None:
    """Compute the q/k/v scales of layer ``layer_name`` if still pending.

    The layer is looked up in the current forward context. Its
    ``calc_kv_scales`` clears ``calculate_kv_scales``, so only the first call
    after that flag is set does any work.
    """
    layer = get_forward_context().no_compile_layers[layer_name]
    if layer.calculate_kv_scales:
        layer.calc_kv_scales(query, key, value)


def maybe_calc_kv_scales_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> None:
    """Fake implementation of ``maybe_calc_kv_scales``: does nothing."""
    return None


# Expose maybe_calc_kv_scales as torch.ops.vllm.maybe_calc_kv_scales.
# NOTE(review): the op does not write to query/key/value itself (it updates the
# layer's scales); presumably they are listed as mutated so the compiler keeps
# the call in place — confirm before changing.
direct_register_custom_op(
    op_name="maybe_calc_kv_scales",
    op_func=maybe_calc_kv_scales,
    fake_impl=maybe_calc_kv_scales_fake,
    mutates_args=["query", "key", "value"],
)


600
def get_attention_context(
601
    layer_name: str,
602
) -> tuple[Any, "Attention | MLAAttention", torch.Tensor, torch.Tensor]:
603
604
605
    """Extract attention context for a given layer.

    This helper function extracts the attention metadata, attention layer
606
    instance, KV cache tensor, and slot mapping for a specific layer.
607
608
609

    Args:
        layer_name: The name/identifier of the attention layer.
610

611
612
613
614
615
    Returns:
        A tuple containing:
        - attn_metadata: Attention metadata for this specific layer, or None if
            no metadata available
        - attn_layer: The attention layer instance (Attention or MLAAttention)
616
        - kv_cache: The KV cache tensor for current forward pass
617
        - slot_mapping: The slot mapping for this specific layer
618
619
620
621

        Note: attn_metadata may be None, but attn_layer and kv_cache are always
        extracted from the forward context.
    """
622
    forward_context: ForwardContext = get_forward_context()
623
    attn_metadata = forward_context.attn_metadata
624
625
    if isinstance(attn_metadata, dict):
        attn_metadata = attn_metadata[layer_name]
626
    attn_layer: Attention | MLAAttention = forward_context.no_compile_layers[layer_name]
627
    kv_cache = attn_layer.kv_cache
628
629
630
631
632
633
    slot_mapping = forward_context.slot_mapping
    assert isinstance(slot_mapping, dict), (
        f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. "
    )
    layer_slot_mapping = slot_mapping.get(layer_name)
    return attn_metadata, attn_layer, kv_cache, layer_slot_mapping
634
635
636
637
638
639
640
641
642


@maybe_transfer_kv_layer
def unified_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Run attention for ``layer_name`` and return the backend's output.

    The layer, its KV cache and the attention metadata come from the current
    forward context; the work is delegated to ``layer.impl.forward``.
    """
    attn_metadata, layer, kv_cache, _ = get_attention_context(layer_name)
    return layer.impl.forward(layer, query, key, value, kv_cache, attn_metadata)
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662


def unified_attention_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Fake implementation: an uninitialized contiguous tensor shaped like ``query``."""
    return torch.empty_like(query, memory_format=torch.contiguous_format)


# Expose unified_attention as torch.ops.vllm.unified_attention (no mutated
# arguments: it returns a new output tensor).
direct_register_custom_op(
    op_name="unified_attention",
    fake_impl=unified_attention_fake,
    op_func=unified_attention,
)
663
664


665
666
667
668
669
670
671
672
673
def unified_kv_cache_update(
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """
    Returns a dummy that is passed to unified_attention to signal a side effect and
    the data dependency between them to ensure torch.compile preserves ordering.

    When the forward context has a slot mapping for ``layer_name``, ``key`` and
    ``value`` are written into the layer's KV cache through
    ``attn_layer.impl.do_kv_cache_update``; otherwise nothing is written.

    The dummy is a zero-element tensor with ``key``'s device and dtype, so that
    its metadata matches what ``unified_kv_cache_update_fake`` reports to the
    compiler. Previously it used the KV cache's dtype and device, which can
    differ from key's (e.g. a quantized cache).

    Raises:
        AssertionError: If a slot mapping exists but the layer's impl has no
            ``do_kv_cache_update``.
    """
    _, attn_layer, kv_cache, layer_slot_mapping = get_attention_context(layer_name)
    if layer_slot_mapping is not None:
        assert hasattr(attn_layer.impl, "do_kv_cache_update"), (
            f"{attn_layer.impl.__class__.__name__} does not support kv cache update"
        )
        attn_layer.impl.do_kv_cache_update(
            attn_layer,
            key,
            value,
            kv_cache,
            layer_slot_mapping,
        )

    # Keep real and fake output metadata identical (see docstring).
    return torch.empty(0, device=key.device, dtype=key.dtype)


def unified_kv_cache_update_fake(
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Fake implementation: a zero-element dummy on ``key``'s device and dtype."""
    return key.new_empty(0)


# Expose unified_kv_cache_update as torch.ops.vllm.unified_kv_cache_update.
# It declares no mutated arguments: the cache it writes is the layer's
# kv_cache, which is not an op argument. Ordering against the attention op is
# expressed through the returned dummy tensor instead.
direct_register_custom_op(
    op_name="unified_kv_cache_update",
    mutates_args=[],
    op_func=unified_kv_cache_update,
    fake_impl=unified_kv_cache_update_fake,
)


706
@maybe_transfer_kv_layer
def unified_attention_with_output(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
    kv_cache_dummy_dep: torch.Tensor | None = None,
) -> None:
    """Run attention for ``layer_name``, writing the result into ``output``.

    The layer, its KV cache and the attention metadata come from the current
    forward context; ``output``, ``output_scale`` and ``output_block_scale``
    are passed through to ``layer.impl.forward``.
    """
    # The dummy's value is irrelevant; accepting it as an input only creates a
    # data dependency so torch.compile keeps the KV cache update ordered
    # before this attention call.
    del kv_cache_dummy_dep
    attn_metadata, layer, kv_cache, _ = get_attention_context(layer_name)
    layer.impl.forward(
        layer,
        query,
        key,
        value,
        kv_cache,
        attn_metadata,
        output=output,
        output_scale=output_scale,
        output_block_scale=output_block_scale,
    )
734
735
736
737
738
739
740
741


def unified_attention_with_output_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
742
743
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
744
    kv_cache_dummy_dep: torch.Tensor | None = None,
745
746
747
748
749
750
751
) -> None:
    return


# Expose unified_attention_with_output as
# torch.ops.vllm.unified_attention_with_output. The op writes its result into
# `output` (and presumably `output_block_scale` when one is given), so both are
# declared as mutated arguments.
direct_register_custom_op(
    op_name="unified_attention_with_output",
    op_func=unified_attention_with_output,
    fake_impl=unified_attention_with_output_fake,
    mutates_args=["output", "output_block_scale"],
)