layer.py 44.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Attention layer."""
4

5
from typing import cast, Optional
6
7
8
9

import torch
import torch.nn as nn

10
import vllm.envs as envs
11
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
12
from vllm.attention.utils.kv_transfer_utils import maybe_transfer_kv_layer
13
from vllm.config import CacheConfig, get_current_vllm_config
14
from vllm.config.vllm import VllmConfig
15
from vllm.forward_context import ForwardContext, get_forward_context
16
from vllm.logger import init_logger
17
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
18
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
19
20
21
22
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    UnquantizedLinearMethod,
)
23
from vllm.model_executor.layers.quantization import QuantizationConfig
24
from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
25
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
26
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
27
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
28
from vllm.platforms import current_platform
29
from vllm.utils.torch_utils import (
30
31
32
    direct_register_custom_op,
    kv_cache_dtype_str_to_dtype,
)
33
34
35
36
37
38
39
from vllm.v1.attention.backend import (
    AttentionBackend,
    AttentionType,
    MLAAttentionImpl,
)
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.selector import get_attn_backend
40
41
42
43
44
45
from vllm.v1.kv_cache_interface import (
    FullAttentionSpec,
    KVCacheSpec,
    MLAAttentionSpec,
    SlidingWindowSpec,
)
46

47
48
49
logger = init_logger(__name__)


50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
def should_load_quant_weights(quant_method: QuantizeMethodBase | None) -> bool:
    """Tell whether ``quant_method`` expects quantized weights to be loaded.

    Args:
        quant_method: The quantization method resolved for a layer, or None
            when the layer is not quantized at all.

    Returns:
        False when there is no quant method or when it is the plain
        ``UnquantizedLinearMethod``; True for any other method.
    """
    if quant_method is None:
        return False
    # An explicit "unquantized" method means the checkpoint carries no
    # quantization tensors for this layer.
    return not isinstance(quant_method, UnquantizedLinearMethod)


def set_default_quant_scales(layer: nn.Module, register_buffer: bool = False) -> None:
    """Reset the q/k/v/prob quantization scales of ``layer`` to 1.0.

    Args:
        layer: The attention layer whose scales are (re)initialized.
        register_buffer: When True, create ``_k_scale``, ``_v_scale``,
            ``_q_scale`` and ``_prob_scale`` as float32 scalar buffers (so they
            are part of the state dict and follow ``model.to(device)``). When
            False, those buffers must already exist and are filled with 1.0 in
            place.

    The host-side float copies (``_q_scale_float`` etc.) are reset to 1.0, and
    the ``q_range``/``k_range``/``v_range`` tensors are (re)created from
    ``envs.Q_SCALE_CONSTANT``/``K_SCALE_CONSTANT``/``V_SCALE_CONSTANT``.
    """
    # Registration order is kept stable (k, v, q, prob) because it determines
    # the buffer order in the layer's state dict.
    scale_names = ("_k_scale", "_v_scale", "_q_scale", "_prob_scale")
    for name in scale_names:
        if register_buffer:
            layer.register_buffer(name, torch.tensor(1.0, dtype=torch.float32))
        else:
            getattr(layer, name).fill_(1.0)

    # We also keep q/k/v_scale on host (cpu) memory for attention
    # backends that require the scales to be on host instead of on device.
    # e.g. Flashinfer
    layer._q_scale_float = 1.0
    layer._k_scale_float = 1.0
    layer._v_scale_float = 1.0
    layer._prob_scale_float = 1.0

    # Initialize q/k/v range constants used by calc_kv_scales
    layer.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
    layer.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
    layer.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)

83

84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
def _init_kv_cache_quant(
    layer: nn.Module,
    quant_config: QuantizationConfig | None,
    prefix: str,
) -> None:
    """Initializes KV cache scaling factors and quantization method.

    This helper function sets up the KV cache quantization attributes that are
    shared between Attention and MLAAttention layers. It initializes scale
    tensors for query, key, value, and probability, and configures the
    quantization method if applicable.

    Args:
        layer: The attention layer instance to initialize. Its
            ``kv_cache_dtype`` attribute is expected to be set already
            (it is referenced by the disabled fp8_e5m2 guard below).
        quant_config: Optional quantization configuration.
        prefix: Layer name prefix for quantization method lookup.
    """
    # Note [Register q/k/v/prob scales in state dict]
    # When calling model.to(device), only parameters/buffers in state dict are
    # moved. If not registering q/k/v/prob scales in state dict, there would
    # be an IMA error when a cuda kernel (e.g., quant_fp8) accesses the tensor
    # on cpu.
    # Registering in state dict means it interacts with weight loading. One edge
    # case is when quant_method is None, or quant_method is UnquantizedLinearMethod
    # (i.e., should_load_quant_weights(quant_method) == False).
    # In this case, the checkpoint does not have the scales. We need to
    # initialize the scales to 1.0 and update the scales after weight loading.
    # This is especially important when we load dummy weights first (providing
    # wrong scales) and then load real weights (which misses scales and keeps the
    # wrong scales from dummy load).
    set_default_quant_scales(layer, register_buffer=True)

    # The output scale on host memory. This should be the input scale of
    # the quant op after this attention layer.
    layer._o_scale_float = None

    # Resolve the quant method once, after the default scales exist on the
    # layer. (A previous redundant lookup before scale registration was dropped;
    # its result was always overwritten here.)
    quant_method = (
        quant_config.get_quant_method(layer, prefix=prefix) if quant_config else None
    )

    # See Note [Register q/k/v/prob scales in state dict] above.
    if should_load_quant_weights(quant_method):
        assert isinstance(quant_method, BaseKVCacheMethod)
        # TODO (mgoin): kv cache dtype should be specified in the FP8
        # checkpoint config and become the "auto" behavior
        # NOTE(review): the guard below is commented out in this code, so an
        # fp8_e5m2 kv-cache is NOT rejected when a quantized checkpoint
        # provides KV-cache scales. Confirm this is intended.
        # if layer.kv_cache_dtype == "fp8_e5m2":
        #     raise ValueError("fp8_e5m2 kv-cache is not supported with fp8 checkpoints.")

        # If quantization is enabled, we make "k_scale" and "v_scale"
        # parameters so that it can be loaded from the model checkpoint.
        # The k/v_scale will then be converted back to native float32
        # values after weight loading.
        layer.quant_method = quant_method
        layer.quant_method.create_weights(layer)
141
142


143
class Attention(nn.Module, AttentionLayerBase):
    """Attention layer.

    This class takes query, key, and value tensors as input. The input tensors
    can either contain prompt tokens or generation tokens.
    The class does the following:

    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.

    On construction the layer registers itself in
    ``compilation_config.static_forward_context`` under its ``prefix``; a
    duplicate prefix raises ``ValueError``.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
        alibi_slopes: list[float] | None = None,
        use_alibi_sqrt: bool | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        logits_soft_cap: float | None = None,
        per_layer_sliding_window: int | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
        attn_backend: type[AttentionBackend] | None = None,
        head_size_v: int | None = None,
        **extra_impl_args,
    ) -> None:
        """
        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        Args:
            num_heads: Number of query heads.
            head_size: Per-head size of query and key.
            scale: Scale forwarded to the backend implementation.
            num_kv_heads: Number of key/value heads; defaults to ``num_heads``.
                ``num_heads`` must be divisible by it.
            alibi_slopes: Optional ALiBi slopes forwarded to the backend impl.
            use_alibi_sqrt: Request the backend's ``use_alibi_sqrt`` option.
                Raises ``ValueError`` if the backend does not support it.
            cache_config: Source of KV-cache dtype, block size, model-level
                sliding window and KV-scale calibration settings.
            quant_config: Optional quantization config; also drives KV-cache
                quantization setup.
            logits_soft_cap: Forwarded to the backend implementation.
            per_layer_sliding_window: Overrides the model-level sliding window.
            prefix: Unique layer name used for forward-context lookup.
            attn_type: Attention type used for backend selection and the impl.
            kv_sharing_target_layer_name: Optional name of another layer whose
                KV cache this layer shares (checked by
                ``validate_kv_sharing_target``).
            attn_backend: Explicit backend class; chosen via
                ``get_attn_backend`` when None.
            head_size_v: Per-head size of value; defaults to ``head_size``.
            **extra_impl_args: Extra keyword arguments for the backend impl.
        """
        super().__init__()

        # A per-layer sliding window takes precedence over the model-level one.
        if per_layer_sliding_window is not None:
            sliding_window = per_layer_sliding_window
        elif cache_config is not None:
            sliding_window = cache_config.sliding_window
        else:
            sliding_window = None

        vllm_config = get_current_vllm_config()

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            calculate_kv_scales = cache_config.calculate_kv_scales
        else:
            kv_cache_dtype = "auto"
            # Without a cache config the block size defaults to 16, or 64 when
            # both VLLM_USE_FLASH_ATTN_PA and VLLM_USE_FLASH_MLA are enabled.
            block_size = (
                64 if envs.VLLM_USE_FLASH_ATTN_PA and envs.VLLM_USE_FLASH_MLA else 16
            )
            calculate_kv_scales = False
        self.block_size = block_size

        # llm-compressor models need to set cache_dtype to "fp8" manually.
        if getattr(quant_config, "kv_cache_scheme", None) is not None:
            kv_cache_dtype = "fp8"
            calculate_kv_scales = False
            if cache_config is not None:
                cache_config.cache_dtype = "fp8"
                cache_config.calculate_kv_scales = False

        self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
            kv_cache_dtype, vllm_config.model_config
        )
        self.kv_cache_dtype = kv_cache_dtype
        self.calculate_kv_scales = calculate_kv_scales

        if num_kv_heads is None:
            num_kv_heads = num_heads
        assert num_heads % num_kv_heads == 0, (
            f"num_heads ({num_heads}) is not divisible by num_kv_heads ({num_kv_heads})"
        )

        self.quant_config = quant_config
        self.layer_name = prefix

        self.num_heads = num_heads
        self.head_size = head_size
        self.head_size_v = self.head_size if head_size_v is None else head_size_v
        self.num_kv_heads = num_kv_heads
        self.sliding_window = sliding_window
        self.has_sink = extra_impl_args.get("sinks") is not None

        # NOTE: model_config may be None during certain tests
        model_config = vllm_config.model_config
        self.use_mm_prefix = model_config is not None and model_config.is_mm_prefix_lm

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()
        if attn_backend is None:
            self.attn_backend = get_attn_backend(
                head_size,
                dtype,
                kv_cache_dtype,
                block_size,
                use_mla=False,
                has_sink=self.has_sink,
                use_mm_prefix=self.use_mm_prefix,
                attn_type=attn_type,
            )
        else:
            self.attn_backend = attn_backend
        backend_name = self.attn_backend.get_name()

        # `use_alibi_sqrt` is only forwarded to backends that support it;
        # requesting it from any other backend is an error.
        backend_supports_alibi_sqrt = self.attn_backend.supports_alibi_sqrt()
        self.use_alibi_sqrt = bool(use_alibi_sqrt)
        if self.use_alibi_sqrt and not backend_supports_alibi_sqrt:
            raise ValueError(
                f"use_alibi_sqrt is not supported by backend "
                f"{backend_name}."
            )
        if backend_supports_alibi_sqrt:
            extra_impl_args["use_alibi_sqrt"] = self.use_alibi_sqrt

        # prefix caching + batch invariance is currently not supported for
        # FLASHINFER and TRITON_MLA.
        if (
            cache_config is not None
            and cache_config.enable_prefix_caching
            and vllm_is_batch_invariant()
            and backend_name in ("FLASHINFER", "TRITON_MLA")
        ):
            logger.warning_once(
                "Disabling prefix caching for FLASHINFER/TRITON_MLA "
                "with batch invariance, as it is not yet supported.",
                scope="local",
            )
            cache_config.enable_prefix_caching = False

        impl_cls = self.attn_backend.get_impl_cls()
        self.impl = impl_cls(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            sliding_window,
            kv_cache_dtype,
            logits_soft_cap,
            attn_type,
            kv_sharing_target_layer_name,
            **extra_impl_args,
        )
        self.backend = AttentionBackendEnum[backend_name]
        self.dtype = dtype

        # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
        # torch.compile works by registering the attention as one giant
        # opaque custom op. For other platforms, we directly call them
        # and let torch.compile handle them.
        self.use_direct_call = not current_platform.opaque_attention_op()

        self.use_output = self.attn_backend.accept_output_buffer

        # Register this layer by name so the custom ops can look it up.
        compilation_config = vllm_config.compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self
        self.attn_type = attn_type

        if kv_sharing_target_layer_name is not None:
            validate_kv_sharing_target(
                prefix,
                kv_sharing_target_layer_name,
                compilation_config.static_forward_context,
            )
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        # use a placeholder kv cache tensor during init, which will be replaced
        # by bind_kv_cache
        # this variable will not be accessed if use_direct_call is True
        self.kv_cache = [
            torch.tensor([])
            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
        ]

        # Initialize KV cache quantization attributes
        _init_kv_cache_quant(self, quant_config, prefix)

        # For attn backends supporting query quantization. In this code it is
        # additionally opt-in via envs.VLLM_USE_QUERY_QUANT.
        self.query_quant = None
        if (
            envs.VLLM_USE_QUERY_QUANT
            and self.impl.supports_quant_query_input
            and self.kv_cache_dtype.startswith("fp8")
        ):
            # When q_scale holds one value per KV head, quantize in groups of
            # head_size * (num_heads // num_kv_heads) elements, i.e. one group
            # per KV head; otherwise use a single per-tensor scale.
            is_per_head = (
                hasattr(self, "q_scale") and self.q_scale.numel() == self.num_kv_heads
            )
            group_size = self.head_size * self.num_heads // self.num_kv_heads
            self.query_quant = QuantFP8(
                static=True,
                group_shape=GroupShape(-1, group_size)
                if is_per_head
                else GroupShape.PER_TENSOR,
            )

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        # For some alternate attention backends like MLA the attention output
        # shape does not match the query shape, so we optionally let the model
        # definition specify the output tensor shape.
        output_shape: torch.Size | None = None,
    ) -> torch.Tensor:
        """
        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        Attention metadata (`attn_metadata`) is set using a context manager in
        the model runner's `execute_model` method. It is accessed via forward
        context using
        `vllm.forward_context.get_forward_context().attn_metadata`.

        Args:
            query: Query tensor. When the backend accepts an output buffer it
                is viewed as (-1, num_heads, head_size).
            key: Key tensor or None; viewed as (-1, num_kv_heads, head_size).
            value: Value tensor or None; viewed as
                (-1, num_kv_heads, head_size_v).
            output_shape: Optional explicit output shape. Defaults to
                (num_tokens, num_heads * head_size_v).

        Returns:
            When ``use_output`` is set, the attention output viewed as
            (-1, output_shape[-1]); otherwise the result of
            ``unified_attention``.
        """
        if self.calculate_kv_scales:
            torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name)

        # Record the activation dtype before query quantization may change it;
        # the output buffer is allocated in this dtype.
        output_dtype = query.dtype
        if self.query_quant is not None:
            # quantizing with a simple torch operation enables
            # torch.compile to fuse this into previous ops
            # which reduces overheads during decoding.
            # Otherwise queries are quantized using custom ops
            # which causes decoding overheads
            assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}

            # check if query quantization is supported
            if self.impl.supports_quant_query_input:
                query, _ = self.query_quant(query, self._q_scale)

        if not self.use_output:
            assert self.attn_backend.forward_includes_kv_cache_update, (
                "Split KV cache update not supported when output tensor not provided."
            )
            if self.use_direct_call:
                return unified_attention(query, key, value, self.layer_name)
            return torch.ops.vllm.unified_attention(
                query, key, value, self.layer_name
            )

        if output_shape is None:
            # Handle both 2D [num_tokens, hidden] and
            # 3D [num_tokens, heads, head_dim] query
            num_tokens = query.shape[0]
            output_shape = torch.Size((num_tokens, self.num_heads * self.head_size_v))
        output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
        hidden_size = output_shape[-1]

        # Reshape the query, key, and value tensors.
        # NOTE(woosuk): We do this outside the custom op to minimize the
        # CPU overheads from the non-CUDA-graph regions.
        query = query.view(-1, self.num_heads, self.head_size)
        output = output.view(-1, self.num_heads, self.head_size_v)
        if key is not None:
            key = key.view(-1, self.num_kv_heads, self.head_size)
        if value is not None:
            value = value.view(-1, self.num_kv_heads, self.head_size_v)

        # When the backend does not update the KV cache inside its forward,
        # the cache write is issued separately and its result is threaded into
        # the attention call as a dependency.
        kv_cache_dummy_dep = None
        if self.use_direct_call:
            if not self.attn_backend.forward_includes_kv_cache_update:
                kv_cache_dummy_dep = unified_kv_cache_update(
                    key, value, self.layer_name
                )
            unified_attention_with_output(
                query,
                key,
                value,
                output,
                self.layer_name,
                kv_cache_dummy_dep=kv_cache_dummy_dep,
            )
        else:
            if not self.attn_backend.forward_includes_kv_cache_update and (
                # torch can only dispatch custom op if a tensor is passed
                key is not None or value is not None
            ):
                kv_cache_dummy_dep = torch.ops.vllm.unified_kv_cache_update(
                    key, value, self.layer_name
                )
            torch.ops.vllm.unified_attention_with_output(
                query,
                key,
                value,
                output,
                self.layer_name,
                kv_cache_dummy_dep=kv_cache_dummy_dep,
            )
        return output.view(-1, hidden_size)

    def calc_kv_scales(self, query, key, value):
        """Calibrate the q/k/v scales once from the given tensors.

        Each scale becomes ``max(|x|) / <x>_range``. The device scale buffers
        are updated in place, the host-side float copies are refreshed, and
        ``calculate_kv_scales`` is cleared.
        """
        self._q_scale.copy_(torch.abs(query).max() / self.q_range)
        self._k_scale.copy_(torch.abs(key).max() / self.k_range)
        self._v_scale.copy_(torch.abs(value).max() / self.v_range)
        self._q_scale_float = self._q_scale.item()
        self._k_scale_float = self._k_scale.item()
        self._v_scale_float = self._v_scale.item()
        # We only calculate the scales once
        self.calculate_kv_scales = False

    def extra_repr(self) -> str:
        """Summarize head geometry, scale and backend impl for ``repr``."""
        s = f"head_size={self.impl.head_size}"  # type: ignore
        s += f", num_heads={self.impl.num_heads}"  # type: ignore
        s += f", num_kv_heads={self.impl.num_kv_heads}"  # type: ignore
        s += f", scale={self.impl.scale}"  # type: ignore
        s += f", backend={self.impl.__class__.__name__}"
        return s

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Post-process the backend impl and reset scales when unquantized.

        Args:
            act_dtype: Activation dtype forwarded to the backend impl.
        """
        self.impl.process_weights_after_loading(act_dtype)

        # If we should not load quant weights, we initialize the scales to 1.0
        # as the default value. See Note [Register q/k/v/prob scales in state
        # dict] in _init_kv_cache_quant for more details.
        quant_method = (
            self.quant_config.get_quant_method(self, prefix=self.layer_name)
            if self.quant_config
            else None
        )
        if not should_load_quant_weights(quant_method):
            set_default_quant_scales(self, register_buffer=False)

    def get_attn_backend(self) -> type[AttentionBackend]:
        """Return the attention backend class selected for this layer."""
        return self.attn_backend

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        """Describe this decoder layer's KV cache.

        Returns a ``SlidingWindowSpec`` when the layer has a sliding window,
        otherwise a ``FullAttentionSpec``.
        """
        # Block size may get updated after model loading, refresh it
        block_size = vllm_config.cache_config.block_size
        # Should not be called for enc-dec or encoder-only attention.
        assert self.attn_type == AttentionType.DECODER
        if self.sliding_window is not None:
            assert not vllm_config.model_config.use_mla, (
                "MLA is not supported for slidingwindow"
            )
            return SlidingWindowSpec(
                block_size=block_size,
                num_kv_heads=self.num_kv_heads,
                head_size=self.head_size,
                dtype=self.kv_cache_torch_dtype,
                sliding_window=self.sliding_window,
            )
        return FullAttentionSpec(
            block_size=block_size,
            num_kv_heads=self.num_kv_heads,
            head_size=self.head_size,
            head_size_v=self.head_size_v,
            dtype=self.kv_cache_torch_dtype,
        )

499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
class FusedQkvSplitRmsNormRopeAttention(Attention):
    """Attention layer fed by a packed QKV tensor.

    Construction is identical to ``Attention``. ``forward`` takes the packed
    ``qkv`` tensor plus rotary/norm parameters and passes them to the
    ``torch.ops.vllm.fused_qkv_split_rmsnorm_rope_kv_store`` custom op. Judging
    by the op's name and arguments, it splits QKV, applies RMSNorm to q/k and
    rotary embedding, and writes K/V into the cache; verify this against the
    op implementation. The q/k/v it returns go to
    ``unified_attention_with_output``.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
        alibi_slopes: list[float] | None = None,
        use_alibi_sqrt: bool | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        logits_soft_cap: float | None = None,
        per_layer_sliding_window: int | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
        attn_backend: type[AttentionBackend] | None = None,
        head_size_v: int | None = None,
        **extra_impl_args,
    ) -> None:
        """Forward every argument unchanged to ``Attention.__init__``."""
        # Pass by keyword so a reordering of Attention's parameters cannot
        # silently shift values into the wrong slots.
        super().__init__(
            num_heads=num_heads,
            head_size=head_size,
            scale=scale,
            num_kv_heads=num_kv_heads,
            alibi_slopes=alibi_slopes,
            use_alibi_sqrt=use_alibi_sqrt,
            cache_config=cache_config,
            quant_config=quant_config,
            logits_soft_cap=logits_soft_cap,
            per_layer_sliding_window=per_layer_sliding_window,
            prefix=prefix,
            attn_type=attn_type,
            kv_sharing_target_layer_name=kv_sharing_target_layer_name,
            attn_backend=attn_backend,
            head_size_v=head_size_v,
            **extra_impl_args,
        )

    def forward(
        self,
        qkv: torch.Tensor,
        positions: torch.Tensor,
        cos_sin_cache: torch.Tensor,
        weight_q_norm: torch.Tensor,
        weight_k_norm: torch.Tensor,
        epsilon: float,
        # For some alternate attention backends like MLA the attention output
        # shape does not match the query shape, so we optionally let the model
        # definition specify the output tensor shape.
        output_shape: torch.Size | None = None,
        is_neox: bool = False,
    ) -> torch.Tensor:
        """
        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        Attention metadata (`attn_metadata`) is set using a context manager in
        the model runner's `execute_model` method. It is accessed via forward
        context using
        `vllm.forward_context.get_forward_context().attn_metadata`.

        Args:
            qkv: Packed QKV tensor; dim 0 is treated as the token dimension.
            positions, cos_sin_cache, is_neox: Rotary-embedding inputs
                forwarded to the fused op.
            weight_q_norm, weight_k_norm, epsilon: RMSNorm parameters
                forwarded to the fused op.
            output_shape: Optional explicit output shape. Defaults to
                (num_tokens, num_heads * head_size_v).

        Returns:
            The attention output viewed as (-1, output_shape[-1]).

        NOTE(review): unlike ``Attention.forward``, this path ignores
        ``calculate_kv_scales``, ``query_quant``, ``use_output`` and
        ``use_direct_call``. It always calls the custom ops. Confirm this is
        acceptable for every backend/platform this layer is used with.
        """
        num_tokens = qkv.shape[0]
        if output_shape is None:
            output_shape = torch.Size((num_tokens, self.num_heads * self.head_size_v))
        hidden_size = output_shape[-1]
        output = torch.empty(output_shape, dtype=qkv.dtype, device=qkv.device)
        output = output.view(-1, self.num_heads, self.head_size_v)

        # Widths of the q and k/v slices inside the packed qkv tensor.
        # NOTE(review): kv_size uses head_size, not head_size_v; check how the
        # op splits V when head_size_v != head_size.
        q_size = self.num_heads * self.head_size
        kv_size = self.num_kv_heads * self.head_size
        query, key, value = torch.ops.vllm.fused_qkv_split_rmsnorm_rope_kv_store(
            qkv=qkv,
            positions=positions,
            layer_name=self.layer_name,
            kv_cache_dtype=self.kv_cache_dtype,
            cos_sin_cache=cos_sin_cache,
            weight_q_norm=weight_q_norm,
            weight_k_norm=weight_k_norm,
            epsilon=epsilon,
            head_size=self.head_size,
            head_size_v=self.head_size_v,
            q_size=q_size,
            kv_size=kv_size,
            block_size=self.block_size,
            is_neox=is_neox,
        )

        # No separate KV-cache update is issued here, so there is no
        # dependency tensor to thread into the attention op.
        torch.ops.vllm.unified_attention_with_output(
            query,
            key,
            value,
            output,
            self.layer_name,
            kv_cache_dummy_dep=None,
        )
        return output.view(-1, hidden_size)
594

595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
class MLAAttention(nn.Module, AttentionLayerBase):
    """Multi-Head Latent Attention (MLA) layer.

    Takes the query ``q``, the compressed key/value tensor ``kv_c_normed`` and
    the rotary key part ``k_pe`` as input and delegates the work to the backend
    implementation ``self.impl``, which receives the layer's KV cache and the
    attention metadata. The layer:

    1. Stores the input key and value tensors in the KV cache.
    2. Performs (multi-head/multi-query/grouped-query) attention.
    3. Returns the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        scale: float,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int | None,
        kv_lora_rank: int,
        kv_b_proj: ColumnParallelLinear,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_sparse: bool = False,
        indexer: object | None = None,
        **extra_impl_args,
    ):
        """Build the layer, select the MLA backend and instantiate its impl.

        Args:
            num_heads: Number of attention heads handled by this layer.
            scale: Attention scale forwarded to the backend impl.
            qk_nope_head_dim: Per-head query/key dim; summed with
                ``qk_rope_head_dim`` to form ``qk_head_dim`` for the impl.
            qk_rope_head_dim: Per-head query/key dim; also added to
                ``kv_lora_rank`` to form the cached ``head_size``.
            v_head_dim: Per-head value dim (also the per-head output width).
            q_lora_rank: Forwarded to the backend impl (may be None).
            kv_lora_rank: Forwarded to the impl; part of ``head_size``.
            kv_b_proj: Projection layer forwarded to the backend impl.
            cache_config: Cache settings; when None, ``"auto"`` dtype,
                block size 16 and no KV-scale calculation are used.
            quant_config: Quantization config used for KV-cache quantization.
            prefix: Unique layer name, registered in the static forward
                context.
            use_sparse: Forwarded to backend selection.
            indexer: Forwarded to the backend impl.
            **extra_impl_args: Extra keyword arguments for the backend impl.

        Raises:
            ValueError: If ``prefix`` is already registered.
        """
        super().__init__()
        self.num_heads = num_heads
        self.scale = scale
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        # Size of one cached entry: latent KV part plus the rope key part.
        self.head_size = kv_lora_rank + qk_rope_head_dim
        self.layer_name = prefix

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            calculate_kv_scales = cache_config.calculate_kv_scales
        else:
            kv_cache_dtype = "auto"
            block_size = 16
            calculate_kv_scales = False
        self.quant_config = quant_config

        # Initialize KV cache quantization attributes
        self.kv_cache_dtype = kv_cache_dtype
        self.calculate_kv_scales = calculate_kv_scales
        _init_kv_cache_quant(self, quant_config, prefix)

        dtype = torch.get_default_dtype()
        self.attn_backend = get_attn_backend(
            self.head_size,
            dtype,
            kv_cache_dtype,
            block_size,
            use_mla=True,
            use_sparse=use_sparse,
        )

        if (
            cache_config is not None
            and cache_config.enable_prefix_caching
            and vllm_is_batch_invariant()
            and self.attn_backend.get_name() in ("TRITON_MLA", "FLASHINFER")
        ):
            logger.warning_once(
                "Disabling prefix caching for TRITON_MLA / FLASHINFER "
                "with batch invariance, as it is not yet supported.",
                scope="local",
            )
            cache_config.enable_prefix_caching = False

        impl_cls = cast(type[MLAAttentionImpl], self.attn_backend.get_impl_cls())
        self.impl = impl_cls(
            num_heads=self.num_heads,
            head_size=self.head_size,
            scale=self.scale,
            # MLA keeps a single shared latent KV head.
            num_kv_heads=1,
            alibi_slopes=None,
            sliding_window=None,
            kv_cache_dtype=self.kv_cache_dtype,
            logits_soft_cap=None,
            attn_type=AttentionType.DECODER,
            kv_sharing_target_layer_name=None,
            # MLA Args
            q_lora_rank=self.q_lora_rank,
            kv_lora_rank=self.kv_lora_rank,
            qk_nope_head_dim=self.qk_nope_head_dim,
            qk_rope_head_dim=self.qk_rope_head_dim,
            qk_head_dim=self.qk_nope_head_dim + self.qk_rope_head_dim,
            v_head_dim=self.v_head_dim,
            kv_b_proj=kv_b_proj,
            indexer=indexer,
            **extra_impl_args,
        )

        # Platforms without an opaque attention op call the impl directly
        # instead of going through the registered custom ops.
        self.use_direct_call = not current_platform.opaque_attention_op()

        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self

        # One placeholder cache per pipeline-parallel virtual engine; the real
        # tensors are bound later.
        self.kv_cache = [
            torch.tensor([])
            for _ in range(
                get_current_vllm_config().parallel_config.pipeline_parallel_size
            )
        ]

        self.use_sparse = use_sparse

        # Initialize q/k/v range constants.
        self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
        self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
        self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)

    def forward(
        self,
        q: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        output_shape: torch.Size | None = None,
    ) -> torch.Tensor:
        """Run MLA attention for this layer.

        Args:
            q: Query tensor; its first dimension is the token count.
            kv_c_normed: Compressed KV input.
            k_pe: Rotary key input.
            output_shape: Shape of the output buffer for backends that accept
                one. When None, ``(num_tokens, num_heads * v_head_dim)`` is
                used (previously this crashed in ``torch.empty(None)``).

        Returns:
            The attention output tensor.
        """
        if self.calculate_kv_scales:
            torch.ops.vllm.maybe_calc_kv_scales(q, kv_c_normed, k_pe, self.layer_name)

        if self.use_direct_call:
            forward_context: ForwardContext = get_forward_context()
            attn_metadata = forward_context.attn_metadata
            if isinstance(attn_metadata, dict):
                attn_metadata = attn_metadata[self.layer_name]
            self_kv_cache = self.kv_cache[forward_context.virtual_engine]

            if self.attn_backend.accept_output_buffer:
                output = self._allocate_output(q, output_shape)
                self.impl.forward(
                    self,
                    q,
                    kv_c_normed,
                    k_pe,
                    self_kv_cache,
                    attn_metadata,
                    output=output,
                )
                return output
            return self.impl.forward(
                self, q, kv_c_normed, k_pe, self_kv_cache, attn_metadata
            )

        if self.attn_backend.accept_output_buffer:
            output = self._allocate_output(q, output_shape)
            torch.ops.vllm.unified_mla_attention_with_output(
                q,
                kv_c_normed,
                k_pe,
                output,
                self.layer_name,
            )
            return output
        return torch.ops.vllm.unified_mla_attention(
            q,
            kv_c_normed,
            k_pe,
            self.layer_name,
        )

    def _allocate_output(
        self, q: torch.Tensor, output_shape: torch.Size | None
    ) -> torch.Tensor:
        """Allocate an uninitialized output buffer matching ``q``'s dtype/device."""
        if output_shape is None:
            # Same convention as the Attention layer's fallback: one row per
            # token holding num_heads value vectors of v_head_dim each.
            output_shape = torch.Size((q.shape[0], self.num_heads * self.v_head_dim))
        return torch.empty(output_shape, dtype=q.dtype, device=q.device)

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Post-process loaded weights and default scales that were not loaded."""
        if hasattr(self.impl, "process_weights_after_loading"):
            self.impl.process_weights_after_loading(act_dtype)

        # If we should not load quant weights, we initialize the scales to 1.0
        # as the default value. See [Note: Register q/k/v/prob scales in state dict]
        # for more details.
        quant_method = (
            self.quant_config.get_quant_method(self, prefix=self.layer_name)
            if self.quant_config
            else None
        )
        if not should_load_quant_weights(quant_method):
            set_default_quant_scales(self, register_buffer=False)

    def calc_kv_scales(
        self, q: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor
    ) -> None:
        """Optional scale calculation for MLA inputs.

        Mirrors Attention.calc_kv_scales. Not all MLA backends require this.
        Runs once: clears ``calculate_kv_scales`` when done.
        """
        # Use safe defaults if ranges are not present
        q_range = getattr(self, "q_range", torch.tensor(1.0))
        k_range = getattr(self, "k_range", torch.tensor(1.0))
        v_range = getattr(self, "v_range", torch.tensor(1.0))

        self._q_scale.copy_(torch.abs(q).max() / q_range)
        # kv_c_normed is the compressed KV representation; use it for k/v
        kv_abs_max = torch.abs(kv_c_normed).max()
        self._k_scale.copy_(kv_abs_max / k_range)
        self._v_scale.copy_(kv_abs_max / v_range)
        self._q_scale_float = self._q_scale.item()
        self._k_scale_float = self._k_scale.item()
        self._v_scale_float = self._v_scale.item()
        self.calculate_kv_scales = False

    def get_attn_backend(self) -> type[AttentionBackend]:
        """Return the attention backend class selected in ``__init__``."""
        return self.attn_backend

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        """Describe this layer's KV cache (one latent head of ``head_size``)."""
        kv_cache_dtype = kv_cache_dtype_str_to_dtype(
            self.kv_cache_dtype, vllm_config.model_config
        )
        return MLAAttentionSpec(
            block_size=vllm_config.cache_config.block_size,
            num_kv_heads=1,
            head_size=self.head_size,
            dtype=kv_cache_dtype,
            cache_dtype_str=vllm_config.cache_config.cache_dtype,
        )

824

825
826
827
828
829
830
def maybe_calc_kv_scales(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> None:
    """Let layer ``layer_name`` compute its q/k/v scales if it still needs to.

    The layer is looked up in the current forward context. Nothing happens
    once the layer's ``calculate_kv_scales`` flag is False; the layer's
    ``calc_kv_scales`` is expected to clear that flag after the first call
    (``MLAAttention.calc_kv_scales`` does).
    """
    layer = get_forward_context().no_compile_layers[layer_name]
    if layer.calculate_kv_scales:
        layer.calc_kv_scales(query, key, value)
840

841
842
843
844
845

def maybe_calc_kv_scales_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> None:
    """Fake implementation for tracing: the op produces no tensors."""
    return None


# Exposed as torch.ops.vllm.maybe_calc_kv_scales. The inputs are declared as
# mutated although MLAAttention.calc_kv_scales only reads them.
# NOTE(review): presumably this keeps torch.compile from treating the call as
# side-effect free - confirm before changing.
direct_register_custom_op(
    op_name="maybe_calc_kv_scales",
    op_func=maybe_calc_kv_scales,
    fake_impl=maybe_calc_kv_scales_fake,
    mutates_args=["query", "key", "value"],
)


859
def get_attention_context(
    layer_name: str,
) -> tuple[dict | object | None, Attention | MLAAttention, torch.Tensor]:
    """Look up the attention state of one layer in the current forward context.

    Args:
        layer_name: The name/identifier of the attention layer.

    Returns:
        A tuple ``(attn_metadata, attn_layer, kv_cache)``:

        - attn_metadata: When the forward context holds a dict of metadata,
          the entry for ``layer_name``; otherwise the context's metadata as-is
          (which may be None).
        - attn_layer: The layer registered under ``layer_name`` in
          ``no_compile_layers`` (Attention or MLAAttention).
        - kv_cache: ``attn_layer.kv_cache`` for the current virtual engine.
    """
    ctx: ForwardContext = get_forward_context()
    metadata = ctx.attn_metadata
    if isinstance(metadata, dict):
        metadata = metadata[layer_name]
    layer: Attention | MLAAttention = ctx.no_compile_layers[layer_name]
    return metadata, layer, layer.kv_cache[ctx.virtual_engine]
887
888


889
@maybe_transfer_kv_layer
def unified_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Run the backend of layer ``layer_name`` and return the tensor it produces."""
    metadata, layer, cache = get_attention_context(layer_name)
    return layer.impl.forward(layer, query, key, value, cache, metadata)
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915


def unified_attention_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Fake implementation: an uninitialized contiguous tensor shaped like ``query``."""
    placeholder = torch.empty_like(query)
    return placeholder.contiguous()


# Exposed as torch.ops.vllm.unified_attention. The op returns a new tensor;
# mutates_args is left at the helper's default.
direct_register_custom_op(
    op_name="unified_attention",
    fake_impl=unified_attention_fake,
    op_func=unified_attention,
)
916
917


918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
def unified_kv_cache_update(
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """
    Write ``key``/``value`` into the KV cache of ``layer_name`` and return a dummy.

    The cache write happens only when the forward context has a slot mapping for
    this layer, and is delegated to ``attn_layer.impl.do_kv_cache_update``.

    Returns a dummy that is passed to unified_attention to signal a side effect and
    the data dependency between them to ensure torch.compile preserves ordering.

    The dummy is always built from ``key``'s device and dtype, exactly like
    ``unified_kv_cache_update_fake``. Deriving it from the KV cache (as was done
    off ROCm) made the real output's metadata disagree with the fake one whenever
    the cache dtype differs from the activation dtype (e.g. an fp8 cache) or the
    cache is still an empty placeholder.

    Raises:
        AssertionError: If the slot mapping is not a dict, or the backend impl
            has no ``do_kv_cache_update``.
    """
    forward_context = get_forward_context()
    attn_layer = forward_context.no_compile_layers[layer_name]
    kv_cache = attn_layer.kv_cache[forward_context.virtual_engine]

    slot_mapping = forward_context.slot_mapping
    assert isinstance(slot_mapping, dict), (
        f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. "
    )
    layer_slot_mapping = slot_mapping.get(layer_name)
    if layer_slot_mapping is not None:
        assert hasattr(attn_layer.impl, "do_kv_cache_update"), (
            f"{attn_layer.impl.__class__.__name__} does not support kv cache update"
        )
        attn_layer.impl.do_kv_cache_update(
            attn_layer,
            key,
            value,
            kv_cache,
            layer_slot_mapping,
        )

    return torch.empty(0, device=key.device, dtype=key.dtype)
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969


def unified_kv_cache_update_fake(
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Fake implementation: an empty 1-D dummy on ``key``'s device with its dtype."""
    return torch.empty((0,), dtype=key.dtype, device=key.device)


# Exposed as torch.ops.vllm.unified_kv_cache_update. No argument is mutated:
# the cache is written as a side effect, and ordering against attention is
# carried by the returned dummy tensor.
direct_register_custom_op(
    op_name="unified_kv_cache_update",
    mutates_args=[],
    op_func=unified_kv_cache_update,
    fake_impl=unified_kv_cache_update_fake,
)


970
@maybe_transfer_kv_layer
def unified_attention_with_output(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
    kv_cache_dummy_dep: torch.Tensor | None = None,
) -> None:
    """Run the backend of layer ``layer_name``, writing the result into ``output``.

    ``kv_cache_dummy_dep`` is never read. Accepting it creates a data dependency
    so that torch.compile keeps the KV cache update ordered before this
    attention call.
    """
    del kv_cache_dummy_dep
    metadata, layer, cache = get_attention_context(layer_name)
    layer.impl.forward(
        layer,
        query,
        key,
        value,
        cache,
        metadata,
        output=output,
        output_scale=output_scale,
        output_block_scale=output_block_scale,
    )
998

999
1000
1001
1002
1003
1004
1005

def unified_attention_with_output_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
1006
1007
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
1008
    kv_cache_dummy_dep: torch.Tensor | None = None,
1009
1010
1011
1012
1013
1014
1015
) -> None:
    return


# Exposed as torch.ops.vllm.unified_attention_with_output; results are written
# into the caller-provided ``output`` (and ``output_block_scale``) buffers.
direct_register_custom_op(
    op_name="unified_attention_with_output",
    fake_impl=unified_attention_with_output_fake,
    op_func=unified_attention_with_output,
    mutates_args=["output", "output_block_scale"],
)
1019
1020


1021
@maybe_transfer_kv_layer
def unified_mla_attention(
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Run the MLA backend of layer ``layer_name`` and return the tensor it produces."""
    metadata, layer, cache = get_attention_context(layer_name)
    return layer.impl.forward(layer, q, kv_c_normed, k_pe, cache, metadata)


def unified_mla_attention_fake(
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Fake implementation: an uninitialized contiguous tensor shaped like ``q``."""
    placeholder = torch.empty_like(q)
    return placeholder.contiguous()


# Exposed as torch.ops.vllm.unified_mla_attention, dispatched on the current
# platform's key; the op returns a new tensor and mutates nothing.
direct_register_custom_op(
    op_name="unified_mla_attention",
    op_func=unified_mla_attention,
    fake_impl=unified_mla_attention_fake,
    dispatch_key=current_platform.dispatch_key,
    mutates_args=[],
)


1052
@maybe_transfer_kv_layer
def unified_mla_attention_with_output(
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
) -> None:
    """Run the MLA backend of layer ``layer_name``, writing into ``output``."""
    metadata, layer, cache = get_attention_context(layer_name)
    layer.impl.forward(
        layer,
        q,
        kv_c_normed,
        k_pe,
        cache,
        metadata,
        output=output,
        output_scale=output_scale,
        output_block_scale=output_block_scale,
    )


def unified_mla_attention_with_output_fake(
    q: torch.Tensor,
    kv_c_normed: torch.Tensor,
    k_pe: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
1082
1083
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
) -> None:
    return


# Exposed as torch.ops.vllm.unified_mla_attention_with_output, dispatched on
# the current platform's key; results land in the caller-provided buffers.
direct_register_custom_op(
    op_name="unified_mla_attention_with_output",
    op_func=unified_mla_attention_with_output,
    fake_impl=unified_mla_attention_with_output_fake,
    dispatch_key=current_platform.dispatch_key,
    mutates_args=["output", "output_block_scale"],
)
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198

def fused_qkv_split_rmsnorm_rope_kv_store_impl(
    qkv: torch.Tensor,
    positions: torch.Tensor,
    layer_name: str,
    kv_cache_dtype: str,
    cos_sin_cache: torch.Tensor,
    weight_q_norm: torch.Tensor,
    weight_k_norm: torch.Tensor,
    epsilon: float,
    head_size: int,
    head_size_v: int,
    q_size: int,
    kv_size: int,
    block_size: int,
    is_neox: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Split packed ``qkv`` into q/k/v via lightop's fused kernel.

    Calls ``lightop.split_qkv_rms_rotary_embedding_fuse_with_kv_store_quant``.
    Per its name and arguments, the kernel splits ``qkv`` into ``q_size`` /
    ``kv_size`` parts, applies RMSNorm (``weight_q_norm`` / ``weight_k_norm``,
    ``epsilon``) and rotary embedding (``cos_sin_cache``, ``positions``), and
    stores K/V into the layer's cache at the layer's slot mapping.

    When the forward context has no slot mapping for ``layer_name``, empty
    buffers are passed as the cache.

    Returns:
        ``(q, k, v)`` viewed as ``[num_tokens, heads, dim]``. Queries and keys
        use ``head_size`` per head; values use ``head_size_v``.

    Raises:
        AssertionError: If the forward context's slot mapping is not a dict.
    """
    num_tokens = qkv.shape[0]
    forward_context = get_forward_context()

    slot_mapping = forward_context.slot_mapping
    # Validate before calling ``.get`` so a bad value hits this assertion
    # instead of an opaque AttributeError.
    assert isinstance(slot_mapping, dict), (
        f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. "
    )
    layer_slot_mapping = slot_mapping.get(layer_name)

    attn_layer = forward_context.no_compile_layers[layer_name]
    kv_cache = attn_layer.kv_cache[forward_context.virtual_engine]

    # dtype handed to the kernel: the config string, or the concrete fp8
    # torch dtype when an fp8 cache is being written.
    kernel_cache_dtype = kv_cache_dtype
    if layer_slot_mapping is not None:
        if current_platform.is_rocm():
            # On ROCm the per-layer cache is unpacked directly into (K, V).
            key_cache, value_cache = kv_cache
        else:
            # Elsewhere K and V are stacked along dim 0 of one tensor.
            key_cache, value_cache = kv_cache.unbind(0)

        if kv_cache_dtype.startswith("fp8"):
            # queries are quantized in the attention layer
            from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend

            kernel_cache_dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
                kv_cache_dtype
            )
            key_cache = key_cache.view(kernel_cache_dtype)
            value_cache = value_cache.view(kernel_cache_dtype)
    else:
        key_cache = torch.empty([0], device=qkv.device, dtype=qkv.dtype)
        value_cache = torch.empty([0], device=qkv.device, dtype=qkv.dtype)

    from lightop import split_qkv_rms_rotary_embedding_fuse_with_kv_store_quant

    q, k, v = split_qkv_rms_rotary_embedding_fuse_with_kv_store_quant(
        positions,
        qkv.contiguous(),
        q_size,
        kv_size,
        cos_sin_cache,
        head_dim=head_size,
        page_size=block_size,
        k_buffer=key_cache,
        v_buffer=value_cache,
        kv_cache_loc=layer_slot_mapping,
        is_neox=is_neox,
        weight_q=weight_q_norm,
        weight_k=weight_k_norm,
        output_dtype=qkv.dtype,
        kv_cache_dtype=kernel_cache_dtype,
        epsilon=epsilon,
        residual_q=None,
        residual_k=None,
        k_scale=None,
        v_scale=None,
    )
    q = q.contiguous().view(num_tokens, q_size // head_size, head_size)
    # Keys share the query head size; previously they were wrongly viewed with
    # head_size_v, which gives the wrong shape whenever the two differ.
    k = k.contiguous().view(num_tokens, kv_size // head_size, head_size)
    v = v.contiguous().view(num_tokens, kv_size // head_size_v, head_size_v)
    return q, k, v


def fused_qkv_split_rmsnorm_rope_kv_store_fake(
    qkv: torch.Tensor,
    positions: torch.Tensor,
    layer_name: str,
    kv_cache_dtype: str,
    cos_sin_cache: torch.Tensor,
    weight_q_norm: torch.Tensor,
    weight_k_norm: torch.Tensor,
    epsilon: float,
    head_size: int,
    head_size_v: int,
    q_size: int,
    kv_size: int,
    block_size: int,
    is_neox: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Shape-only fake of the fused op.

    Returns uninitialized tensors whose shapes must match the views produced by
    ``fused_qkv_split_rmsnorm_rope_kv_store_impl``.
    """
    num_tokens = qkv.shape[0]
    q = torch.empty(
        (num_tokens, q_size // head_size, head_size),
        device=qkv.device,
        dtype=qkv.dtype,
    )
    k = torch.empty(
        (num_tokens, kv_size // head_size, head_size),
        device=qkv.device,
        dtype=qkv.dtype,
    )
    v = torch.empty(
        (num_tokens, kv_size // head_size_v, head_size_v),
        device=qkv.device,
        dtype=qkv.dtype,
    )
    return q, k, v

# Exposed as torch.ops.vllm.fused_qkv_split_rmsnorm_rope_kv_store. The KV cache
# is written through the forward context, not through an argument.
# NOTE(review): qkv/positions are declared mutated although the Python wrapper
# does not write them - presumably to pin ordering under torch.compile; confirm.
direct_register_custom_op(
    op_name="fused_qkv_split_rmsnorm_rope_kv_store",
    op_func=fused_qkv_split_rmsnorm_rope_kv_store_impl,
    fake_impl=fused_qkv_split_rmsnorm_rope_kv_store_fake,
    tags=(torch.Tag.needs_fixed_stride_order,),
    mutates_args=["qkv", "positions"],
)