layer.py 26.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Attention layer."""
4
from typing import List, Optional
5
6
7

import torch
import torch.nn as nn
8
import torch.nn.functional as F
9

10
import vllm.envs as envs
11
from vllm.attention import AttentionType
12
from vllm.attention.backends.abstract import AttentionBackend
13
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
14
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
15
from vllm.config import CacheConfig, get_current_vllm_config
16
17
18
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                          has_kv_transfer_group,
                                          is_v1_kv_transfer_group)
19
from vllm.forward_context import ForwardContext, get_forward_context
20
from vllm.logger import init_logger
21
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
22
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
23
from vllm.model_executor.layers.quantization import QuantizationConfig
24
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
25
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
26
27
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape)
28
from vllm.model_executor.models.vision import get_vit_attn_backend
29
from vllm.platforms import _Backend, current_platform
30
from vllm.utils import GiB_bytes, direct_register_custom_op
31

32
33
logger = init_logger(__name__)

# Memoised result of check_xformers_availability(); None = not probed yet.
USE_XFORMERS_OPS = None

# Tags attached to the custom attention ops registered in this module.
# `torch._C.Tag.cudagraph_unsafe` may not exist in every torch build; when it
# is missing the ops are registered with no tags.
if hasattr(torch._C, "Tag") and hasattr(torch._C.Tag, "cudagraph_unsafe"):
    tag_cudagraph_unsafe = (torch._C.Tag.cudagraph_unsafe, )
else:
    tag_cudagraph_unsafe = ()  # type: ignore[assignment]
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63


def check_xformers_availability():
    """Return whether the xformers attention ops can be used.

    The answer is computed on the first call and memoised in the
    module-level ``USE_XFORMERS_OPS`` flag; later calls return the cached
    value without probing or logging again.

    Returns:
        bool: True when ``xformers.ops`` is discoverable and the platform is
        not excluded, otherwise False.
    """
    global USE_XFORMERS_OPS
    if USE_XFORMERS_OPS is not None:
        return USE_XFORMERS_OPS

    if current_platform.is_cuda() and current_platform.has_device_capability(
            100):
        # Xformers FA is not compatible with B200
        USE_XFORMERS_OPS = False
    else:
        from importlib.util import find_spec

        try:
            # find_spec() raises ModuleNotFoundError (an ImportError) when the
            # parent `xformers` package is missing, but merely returns None
            # when the `xformers.ops` submodule cannot be found, so the
            # return value has to be checked as well.
            USE_XFORMERS_OPS = find_spec("xformers.ops") is not None
        except ImportError:
            USE_XFORMERS_OPS = False

    # The result is cached above, so this warning is shown at most once.
    if not USE_XFORMERS_OPS:
        logger.warning("Xformers is not available, falling back.")

    return USE_XFORMERS_OPS

64

65
66
67
68
69
70
71
72
def check_upstream_fa_availability(dtype: torch.dtype):
    """Report whether upstream ``flash_attn`` (FlashAttention-2) can be used.

    All of the following must hold:
      * ``dtype`` is ``torch.float16`` or ``torch.bfloat16``;
      * the platform is CUDA and passes ``has_device_capability(80)``;
      * ``transformers`` reports flash-attn 2 as available.
    """
    if dtype not in (torch.float16, torch.bfloat16):
        return False
    if not (current_platform.is_cuda()
            and current_platform.has_device_capability(80)):
        return False
    # Imported lazily so `transformers` is only touched on this path.
    from transformers.utils import is_flash_attn_2_available
    return is_flash_attn_2_available()


73
class Attention(nn.Module, AttentionLayerBase):
    """Attention layer.

    This class takes query, key, and value tensors as input. The input tensors
    can either contain prompt tokens or generation tokens.
    The class does the following:

    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        logits_soft_cap: Optional[float] = None,
        per_layer_sliding_window: Optional[int] = None,
        use_mla: bool = False,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        attn_backend: Optional[type[AttentionBackend]] = None,
        **extra_impl_args,
    ) -> None:
        """
        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        Raises:
            ValueError: if an fp8 quantized checkpoint is combined with an
                ``fp8_e5m2`` kv-cache, or if ``prefix`` is already registered
                in the compilation config's ``static_forward_context``.
        """
        super().__init__()
        if per_layer_sliding_window is not None:
            # per-layer sliding window
            sliding_window = per_layer_sliding_window
        elif cache_config is not None:
            # model-level sliding window
            sliding_window = cache_config.sliding_window
        else:
            sliding_window = None

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            calculate_kv_scales = cache_config.calculate_kv_scales
        else:
            kv_cache_dtype = "auto"
            block_size = 16
            calculate_kv_scales = False
        if num_kv_heads is None:
            num_kv_heads = num_heads
        assert num_heads % num_kv_heads == 0, \
            f"num_heads ({num_heads}) is not " \
            f"divisible by num_kv_heads ({num_kv_heads})"

        # The default k/v_scale is set to 1.0. This is ignored
        # when kv-cache is not fp8, and should be used with
        # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
        # expect the pre-quantized k/v_scale to be loaded along
        # with the model weights.
        self.kv_cache_dtype = kv_cache_dtype
        self.calculate_kv_scales = calculate_kv_scales
        self._k_scale = torch.tensor(1.0, dtype=torch.float32)
        self._v_scale = torch.tensor(1.0, dtype=torch.float32)
        # FlashAttn doesn't support quantizing the kv-cache only
        # but requires q to be quantized as well.
        self._q_scale = torch.tensor(1.0, dtype=torch.float32)
        self._prob_scale = torch.tensor(1.0, dtype=torch.float32)

        # We also keep q/k/v_scale on host (cpu) memory for attention
        # backends that require the scales to be on host instead of on device.
        # e.g. Flashinfer
        self._q_scale_float = 1.0
        self._k_scale_float = 1.0
        self._v_scale_float = 1.0

        # The output scale on host memory. This should be the input scale of
        # the quant op after this attention layer.
        self._o_scale_float: Optional[float] = None

        self.use_mla = use_mla
        self.num_heads = num_heads
        self.head_size = head_size
        self.num_kv_heads = num_kv_heads
        self.sliding_window = sliding_window
        self.has_sink = extra_impl_args.get("sinks") is not None

        quant_method = quant_config.get_quant_method(
            self, prefix=prefix) if quant_config else None
        if quant_method is not None and not isinstance(
                quant_method, UnquantizedLinearMethod):
            assert isinstance(quant_method, BaseKVCacheMethod)
            # TODO (mgoin): kv cache dtype should be specified in the FP8
            # checkpoint config and become the "auto" behavior
            if self.kv_cache_dtype == "fp8_e5m2":
                raise ValueError("fp8_e5m2 kv-cache is not supported with "
                                 "fp8 checkpoints.")
            # If quantization is enabled, we make "k_scale" and "v_scale"
            # parameters so that it can be loaded from the model checkpoint.
            # The k/v_scale will then be converted back to native float32
            # values after weight loading.
            self.quant_method = quant_method
            self.quant_method.create_weights(self)

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()
        if attn_backend is None:
            self.attn_backend = get_attn_backend(head_size,
                                                 dtype,
                                                 kv_cache_dtype,
                                                 block_size,
                                                 use_mla=use_mla,
                                                 has_sink=self.has_sink)
        else:
            self.attn_backend = attn_backend

        impl_cls = self.attn_backend.get_impl_cls()
        self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                             alibi_slopes, sliding_window, kv_cache_dtype,
                             logits_soft_cap, attn_type,
                             kv_sharing_target_layer_name, **extra_impl_args)
        self.backend = backend_name_to_enum(self.attn_backend.get_name())
        self.dtype = dtype

        # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
        # torch.compile works by registering the attention as one giant
        # opaque custom op. For other platforms, we directly call them
        # and let torch.compile handle them.
        self.use_direct_call = not current_platform.opaque_attention_op()

        self.use_output = self.attn_backend.accept_output_buffer

        # Fetch the active config once; it is used for both layer
        # registration and sizing the placeholder kv cache below.
        vllm_config = get_current_vllm_config()
        compilation_config = vllm_config.compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        # Register this layer by name so the unified_attention* custom ops
        # can look it up again from the forward context at run time.
        compilation_config.static_forward_context[prefix] = self
        self.layer_name = prefix
        self.attn_type = attn_type

        if kv_sharing_target_layer_name is not None:
            validate_kv_sharing_target(
                prefix,
                kv_sharing_target_layer_name,
                compilation_config.static_forward_context,
            )
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        # Use a placeholder kv cache tensor (one per pipeline-parallel
        # virtual engine) during init; it is replaced by bind_kv_cache.
        # It is indexed by `forward_context.virtual_engine` both in the
        # direct-call path of `forward` and inside the unified_attention*
        # custom ops.
        self.kv_cache = [
            torch.tensor([]) for _ in range(
                vllm_config.parallel_config.pipeline_parallel_size)
        ]

        try:
            self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT,
                                        dtype=torch.float32)
            self.k_range = torch.tensor(envs.K_SCALE_CONSTANT,
                                        dtype=torch.float32)
            self.v_range = torch.tensor(envs.V_SCALE_CONSTANT,
                                        dtype=torch.float32)
        except torch.cuda.OutOfMemoryError as e:
            logger.error(
                "Failed to initialize attention q/k/v range constants: %s", e)
            if torch.cuda.is_available():
                logger.debug("CUDA device: %s", torch.cuda.current_device())
                logger.debug("Allocated: %.2f GiB",
                             torch.cuda.memory_allocated() / GiB_bytes)
                logger.debug("Reserved: %.2f GiB",
                             torch.cuda.memory_reserved() / GiB_bytes)
            raise RuntimeError(
                "Failed to initialize q/k/v range constants. "
                "This may be caused by insufficient memory to allocate "
                "kv cache.") from e

        # for attn backends supporting query quantization
        self.query_quant = None
        if self.kv_cache_dtype.startswith(
                "fp8") and self.attn_backend.supports_quant_query_input:
            self.query_quant = QuantFP8(static=True,
                                        group_shape=GroupShape.PER_TENSOR)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        # For some alternate attention backends like MLA the attention output
        # shape does not match the query shape, so we optionally let the model
        # definition specify the output tensor shape.
        output_shape: Optional[torch.Size] = None,
    ) -> torch.Tensor:
        """
        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        Attention metadata (`attn_metadata`) is set using a context manager in
        the model runner's `execute_model` method. It is accessed via forward
        context using
        `vllm.forward_context.get_forward_context().attn_metadata`.
        """
        if self.calculate_kv_scales:
            torch.ops.vllm.maybe_calc_kv_scales(query, key, value,
                                                self.layer_name)

        # Remember the incoming dtype: the query may be quantized below, but
        # the output buffer keeps the original activation dtype.
        output_dtype = query.dtype
        if self.query_quant is not None:
            # quantizing with a simple torch operation enables
            # torch.compile to fuse this into previous ops
            # which reduces overheads during decoding.
            # Otherwise queries are quantized using custom ops
            # which causes decoding overheads
            assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}
            query, _ = self.query_quant(query, self._q_scale)

        if self.use_output:
            output_shape = (output_shape
                            if output_shape is not None else query.shape)
            output = torch.zeros(output_shape,
                                 dtype=output_dtype,
                                 device=query.device)
            hidden_size = output_shape[-1]
            # We skip reshaping query, key and value tensors for the MLA
            # backend since these tensors have different semantics and are
            # processed differently.
            if not self.use_mla:
                # Reshape the query, key, and value tensors.
                # NOTE(woosuk): We do this outside the custom op to minimize the
                # CPU overheads from the non-CUDA-graph regions.
                query = query.view(-1, self.num_heads, self.head_size)
                output = output.view(-1, self.num_heads, self.head_size)
                if key is not None:
                    key = key.view(-1, self.num_kv_heads, self.head_size)
                if value is not None:
                    value = value.view(-1, self.num_kv_heads, self.head_size)
            if self.use_direct_call:
                forward_context: ForwardContext = get_forward_context()
                attn_metadata = forward_context.attn_metadata
                # Metadata may be a per-layer dict keyed by layer name.
                if isinstance(attn_metadata, dict):
                    attn_metadata = attn_metadata[self.layer_name]
                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                self.impl.forward(self,
                                  query,
                                  key,
                                  value,
                                  self_kv_cache,
                                  attn_metadata,
                                  output=output)
            else:
                torch.ops.vllm.unified_attention_with_output(
                    query, key, value, output, self.layer_name)
            return output.view(-1, hidden_size)
        else:
            if self.use_direct_call:
                forward_context = get_forward_context()
                attn_metadata = forward_context.attn_metadata
                # Metadata may be a per-layer dict keyed by layer name.
                if isinstance(attn_metadata, dict):
                    attn_metadata = attn_metadata[self.layer_name]
                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                return self.impl.forward(self, query, key, value,
                                         self_kv_cache, attn_metadata)
            else:
                return torch.ops.vllm.unified_attention(
                    query, key, value, self.layer_name)

    def calc_kv_scales(self, query, key, value):
        """Derive q/k/v scales from the absolute maxima of the inputs.

        Each scale is ``max(|x|) / x_range``. The device tensors are updated
        in place, and host-side float copies are refreshed via ``.item()``.
        Afterwards ``calculate_kv_scales`` is cleared so this runs only once.
        """
        self._q_scale.copy_(torch.abs(query).max() / self.q_range)
        self._k_scale.copy_(torch.abs(key).max() / self.k_range)
        self._v_scale.copy_(torch.abs(value).max() / self.v_range)
        self._q_scale_float = self._q_scale.item()
        self._k_scale_float = self._k_scale.item()
        self._v_scale_float = self._v_scale.item()
        # We only calculate the scales once
        self.calculate_kv_scales = False

    def extra_repr(self) -> str:
        s = f"head_size={self.impl.head_size}"  # type: ignore
        s += f", num_heads={self.impl.num_heads}"  # type: ignore
        s += f", num_kv_heads={self.impl.num_kv_heads}"  # type: ignore
        s += f", scale={self.impl.scale}"  # type: ignore
        s += f", backend={self.impl.__class__.__name__}"
        return s

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Run backend-specific post-load processing on ``self.impl``."""
        if hasattr(self.impl, "process_weights_after_loading"):
            self.impl.process_weights_after_loading(act_dtype)

        # FlashInfer requires attention sinks to be float32
        if (self.backend == _Backend.FLASHINFER
                and hasattr(self.impl, 'sinks')):
            from vllm.v1.attention.backends.flashinfer import FlashInferImpl
            assert isinstance(self.impl, FlashInferImpl)
            if (self.impl.sinks is not None
                    and self.impl.sinks.dtype != torch.float32):
                self.impl.sinks = self.impl.sinks.to(torch.float32)

    def get_attn_backend(self) -> type[AttentionBackend]:
        return self.attn_backend

376

377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
class MultiHeadAttention(nn.Module):
    """Multi-headed attention without any cache, used for ViT."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads

        assert self.num_heads % self.num_kv_heads == 0, \
            f"num_heads ({self.num_heads}) is not " \
            f"divisible by num_kv_heads ({self.num_kv_heads})"
        # How many query heads share each kv head (1 for plain MHA).
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()

        # Determine the attention backend
        backend = get_vit_attn_backend(head_size=head_size, dtype=dtype)

        # Some auto-selected backends can be upgraded
        # to upstream flash attention if available.
        # If vllm native fa is selected, we use it directly.
        use_upstream_fa = False
        if backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
                dtype):
            backend = _Backend.FLASH_ATTN
            use_upstream_fa = True

        if current_platform.is_rocm() or current_platform.is_xpu():
            # currently, only torch_sdpa is supported on rocm/xpu
            self.attn_backend = _Backend.TORCH_SDPA
        else:
            self.attn_backend = backend if backend in {
                _Backend.TORCH_SDPA,
                _Backend.XFORMERS,
                _Backend.PALLAS,
                _Backend.ROCM_AITER_FA,
                _Backend.FLASH_ATTN,
            } else _Backend.TORCH_SDPA

        if (self.attn_backend == _Backend.XFORMERS
                and not check_xformers_availability()):
            self.attn_backend = _Backend.TORCH_SDPA

        if self.attn_backend == _Backend.FLASH_ATTN:
            if use_upstream_fa:
                from flash_attn import flash_attn_varlen_func
                self._flash_attn_varlen_func = flash_attn_varlen_func
            else:
                from vllm.vllm_flash_attn import flash_attn_varlen_func
                self._flash_attn_varlen_func = flash_attn_varlen_func

        logger.info_once(
            f"MultiHeadAttention attn_backend: {self.attn_backend}, "
            f"use_upstream_fa: {use_upstream_fa}")

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """Input shape:
        (batch_size x seq_len x hidden_size) or
        (batch_size x seq_len x num_heads x head_size)

        Returns a tensor of shape (batch_size x q_len x num_heads*head_size).
        """
        bsz, q_len = query.size()[:2]
        kv_len = key.size(1)

        query = query.view(bsz, q_len, self.num_heads, self.head_size)
        key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
        value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)

        if (num_repeat := self.num_queries_per_kv) > 1:
            # Handle MQA and GQA
            key = torch.repeat_interleave(key, num_repeat, dim=2)
            value = torch.repeat_interleave(value, num_repeat, dim=2)

        if self.attn_backend == _Backend.FLASH_ATTN:
            # Every sequence in the batch has the same length, so the
            # cumulative sequence offsets are a simple arithmetic progression.
            cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len,
                                        step=q_len,
                                        dtype=torch.int32,
                                        device=query.device)
            cu_seqlens_k = torch.arange(0, (bsz + 1) * kv_len,
                                        step=kv_len,
                                        dtype=torch.int32,
                                        device=key.device)

            out = self._flash_attn_varlen_func(
                query.flatten(0, 1),
                key.flatten(0, 1),
                value.flatten(0, 1),
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=q_len,
                max_seqlen_k=kv_len,
                softmax_scale=self.scale,
            )
        elif self.attn_backend == _Backend.XFORMERS:
            from xformers import ops as xops

            out = xops.memory_efficient_attention_forward(query,
                                                          key,
                                                          value,
                                                          scale=self.scale)
        elif self.attn_backend == _Backend.TORCH_SDPA:
            # SDPA expects (batch, heads, seq, head_dim).
            query, key, value = (x.transpose(1, 2)
                                 for x in (query, key, value))
            out = F.scaled_dot_product_attention(query,
                                                 key,
                                                 value,
                                                 scale=self.scale)
            out = out.transpose(1, 2)
        elif self.attn_backend == _Backend.PALLAS:
            query, key, value = (x.transpose(1, 2)
                                 for x in (query, key, value))
            from torch_xla.experimental.custom_kernel import flash_attention
            out = flash_attention(query, key, value, sm_scale=self.scale)
            out = out.transpose(1, 2)
        elif self.attn_backend == _Backend.ROCM_AITER_FA:
            # NOTE(review): __init__ forces TORCH_SDPA on ROCm, so this branch
            # looks unreachable there; the varlen call is also made without
            # cu_seqlens on 4-D inputs — confirm against the aiter API.
            from aiter import flash_attn_varlen_func

            # ROCm Flash Attention expects (batch, seq, heads, head_dim)
            out = flash_attn_varlen_func(query,
                                         key,
                                         value,
                                         softmax_scale=self.scale)
        else:
            # ViT attention hasn't supported this backend yet
            raise NotImplementedError(
                f"ViT attention hasn't supported {self.attn_backend} "
                f"backend yet.")

        return out.reshape(bsz, q_len, -1)
521
522


523
524
525
526
527
528
529
530
531
532
def wait_for_kv_layer_from_connector(layer_name: str):
    """Ask the v1 KV-transfer connector to wait for ``layer_name``'s KV load.

    This is a no-op when no v1 KV-transfer group is configured, or when the
    current forward context carries no attention metadata. Otherwise the
    metadata is asserted to be a per-layer dict before delegating to
    ``connector.wait_for_layer_load``.
    """
    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
        return

    connector = get_kv_transfer_group()

    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if attn_metadata is None:
        return
    assert isinstance(attn_metadata, dict)
    connector.wait_for_layer_load(layer_name)


def maybe_save_kv_layer_to_connector(
    layer_name: str,
    kv_cache_layer: torch.Tensor,
):
    """Hand one layer's KV cache to the v1 KV-transfer connector for saving.

    This is a no-op when no v1 KV-transfer group is configured, or when the
    current forward context carries no attention metadata.

    Args:
        layer_name: Name of the attention layer; also the key into the
            per-layer attention-metadata dict.
        kv_cache_layer: The layer's KV cache for the current virtual engine
            (the callers in this module pass ``layer.kv_cache[ve]``, a single
            tensor).
    """
    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
        return

    connector = get_kv_transfer_group()

    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if attn_metadata is None:
        return
    assert isinstance(attn_metadata, dict)
    connector.save_kv_layer(layer_name, kv_cache_layer,
                            attn_metadata[layer_name])
553
554


555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
def maybe_calc_kv_scales(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> None:
    """Compute q/k/v quantization scales for ``layer_name`` when requested.

    Scales are computed only if the layer's attention metadata exists and has
    a truthy ``enable_kv_scales_calculation`` attribute; otherwise nothing
    happens. The work is delegated to the layer's ``calc_kv_scales``.
    """
    ctx = get_forward_context()
    layer_metadata = ctx.attn_metadata
    # Metadata may be a per-layer dict keyed by layer name.
    if isinstance(layer_metadata, dict):
        layer_metadata = layer_metadata[layer_name]

    should_calc = layer_metadata is not None and getattr(
        layer_metadata, 'enable_kv_scales_calculation', False)
    if should_calc:
        layer = ctx.no_compile_layers[layer_name]
        layer.calc_kv_scales(query, key, value)


def maybe_calc_kv_scales_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> None:
    """Fake implementation paired with ``maybe_calc_kv_scales``.

    The real op produces no output, so this stand-in does nothing and
    returns ``None``.
    """
    return None


# Expose maybe_calc_kv_scales as `torch.ops.vllm.maybe_calc_kv_scales`
# (called from Attention.forward).
# NOTE(review): the op itself does not write to query/key/value; listing them
# in mutates_args presumably keeps this output-less, side-effecting op from
# being dropped during graph compilation — confirm before changing.
direct_register_custom_op(
    op_name="maybe_calc_kv_scales",
    op_func=maybe_calc_kv_scales,
    fake_impl=maybe_calc_kv_scales_fake,
    mutates_args=["query", "key", "value"],
)


593
594
595
596
597
598
def unified_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Run attention for ``layer_name`` and return a freshly allocated output.

    Body of the ``torch.ops.vllm.unified_attention`` custom op. The layer and
    its kv cache are looked up from the forward context by name. When a v1 KV
    connector is active, this waits for the layer's KV load beforehand and
    hands the kv cache to the connector afterwards.
    """
    wait_for_kv_layer_from_connector(layer_name)

    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    # Metadata may be a per-layer dict keyed by layer name.
    if isinstance(attn_metadata, dict):
        attn_metadata = attn_metadata[layer_name]
    self = forward_context.no_compile_layers[layer_name]
    kv_cache = self.kv_cache[forward_context.virtual_engine]
    output = self.impl.forward(self, query, key, value, kv_cache,
                               attn_metadata)

    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
    return output
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626


def unified_attention_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Fake implementation paired with ``unified_attention``.

    Returns an uninitialised, contiguous tensor with the shape, dtype and
    device of ``query``; ``key``, ``value`` and ``layer_name`` are unused.
    """
    return torch.empty_like(query, memory_format=torch.contiguous_format)


# Expose unified_attention as `torch.ops.vllm.unified_attention` (used by
# Attention.forward when the backend does not accept an output buffer). The
# op is tagged cudagraph-unsafe when the torch build supports that tag.
direct_register_custom_op(
    op_name="unified_attention",
    op_func=unified_attention,
    fake_impl=unified_attention_fake,
    tags=tag_cudagraph_unsafe,
)
629
630
631
632
633
634
635
636


def unified_attention_with_output(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: Optional[torch.Tensor] = None,
    output_block_scale: Optional[torch.Tensor] = None,
) -> None:
    """Run attention for ``layer_name``, writing the result into ``output``.

    Body of the ``torch.ops.vllm.unified_attention_with_output`` custom op.
    The layer and its kv cache are looked up from the forward context by name.
    ``output_scale`` and ``output_block_scale`` are forwarded unchanged to the
    backend impl. When a v1 KV connector is active, this waits for the layer's
    KV load beforehand and hands the kv cache to the connector afterwards.
    """
    wait_for_kv_layer_from_connector(layer_name)
    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    # Metadata may be a per-layer dict keyed by layer name.
    if isinstance(attn_metadata, dict):
        attn_metadata = attn_metadata[layer_name]
    self = forward_context.no_compile_layers[layer_name]
    kv_cache = self.kv_cache[forward_context.virtual_engine]
    self.impl.forward(self,
                      query,
                      key,
                      value,
                      kv_cache,
                      attn_metadata,
                      output=output,
                      output_scale=output_scale,
                      output_block_scale=output_block_scale)

    maybe_save_kv_layer_to_connector(layer_name, kv_cache)

659
660
661
662
663
664
665

def unified_attention_with_output_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: Optional[torch.Tensor] = None,
    output_block_scale: Optional[torch.Tensor] = None,
) -> None:
    """Fake implementation paired with ``unified_attention_with_output``.

    The real op only writes into caller-provided buffers, so the fake has
    nothing to allocate and returns ``None``.
    """
    return


# Expose unified_attention_with_output as
# `torch.ops.vllm.unified_attention_with_output` (used by Attention.forward
# when the backend accepts an output buffer). `output` and
# `output_block_scale` are declared as mutated because the impl writes into
# them. The op is tagged cudagraph-unsafe when the torch build supports it.
direct_register_custom_op(
    op_name="unified_attention_with_output",
    op_func=unified_attention_with_output,
    mutates_args=["output", "output_block_scale"],
    fake_impl=unified_attention_with_output_fake,
    tags=tag_cudagraph_unsafe,
)