layer.py 26.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Attention layer."""
4
from typing import List, Optional
5
6
7

import torch
import torch.nn as nn
8
import torch.nn.functional as F
9

10
import vllm.envs as envs
11
from vllm.attention import AttentionType
12
from vllm.attention.backends.abstract import AttentionBackend
13
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
14
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
15
from vllm.config import CacheConfig, get_current_vllm_config
16
17
18
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                          has_kv_transfer_group,
                                          is_v1_kv_transfer_group)
19
from vllm.forward_context import ForwardContext, get_forward_context
20
from vllm.logger import init_logger
21
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
22
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
23
24
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
25
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
26
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
27
28
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape)
29
from vllm.model_executor.models.vision import get_vit_attn_backend
30
from vllm.platforms import _Backend, current_platform
31
from vllm.utils import GiB_bytes, direct_register_custom_op
32

33
34
logger = init_logger(__name__)

# Memoized answer of check_xformers_availability(); None means "not probed".
USE_XFORMERS_OPS = None

# Tags attached to the attention custom ops registered in this module. If this
# torch build exposes no `cudagraph_unsafe` op tag, the ops get no tags.
try:
    _unsafe_tag = torch._C.Tag.cudagraph_unsafe
except AttributeError:
    tag_cudagraph_unsafe = ()
else:
    tag_cudagraph_unsafe = (_unsafe_tag, )  # type: ignore[assignment]
    del _unsafe_tag
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64


def check_xformers_availability():
    """Return whether xformers ops may be used, memoizing the answer.

    The result is cached in the module-level ``USE_XFORMERS_OPS``. Later calls
    return the cached value without re-probing and without re-logging the
    fallback warning.

    Returns:
        False on a CUDA platform for which ``has_device_capability(100)`` holds
        (see the B200 note below). Otherwise True if and only if the
        ``xformers.ops`` module can be located.
    """
    global USE_XFORMERS_OPS
    if USE_XFORMERS_OPS is not None:
        return USE_XFORMERS_OPS

    if current_platform.is_cuda() and current_platform.has_device_capability(
            100):
        # Xformers FA is not compatible with B200
        USE_XFORMERS_OPS = False
    else:
        from importlib.util import find_spec
        try:
            # For a dotted name, find_spec imports the parent package and
            # raises ModuleNotFoundError (an ImportError) if it is missing.
            # When the parent exists but the submodule does not, it returns
            # None instead of raising, so the result must be checked too.
            USE_XFORMERS_OPS = find_spec("xformers.ops") is not None
        except ImportError:
            USE_XFORMERS_OPS = False

    # the warning only needs to be shown once
    if not USE_XFORMERS_OPS:
        logger.warning("Xformers is not available, falling back.")

    return USE_XFORMERS_OPS

65

66
67
68
69
70
71
72
73
def check_upstream_fa_availability(dtype: torch.dtype):
    """Report whether upstream flash-attn 2 can be used for ``dtype``.

    Three checks must pass, in this order:
    - ``dtype`` is float16 or bfloat16;
    - the platform is CUDA;
    - the platform reports ``has_device_capability(80)``.

    Only then is ``transformers`` consulted, via
    ``is_flash_attn_2_available``. Otherwise the answer is False.
    """
    if dtype not in (torch.float16, torch.bfloat16):
        return False
    if not current_platform.is_cuda():
        return False
    if not current_platform.has_device_capability(80):
        return False
    # Imported lazily so transformers is only needed when FA2 is a candidate.
    from transformers.utils import is_flash_attn_2_available
    return is_flash_attn_2_available()


74
class Attention(nn.Module, AttentionLayerBase):
    """Attention layer.

    This class takes query, key, and value tensors as input. The input tensors
    can either contain prompt tokens or generation tokens.
    The class does the following:

    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.

    The computation itself is delegated to ``self.impl``, an instance of the
    selected attention backend's implementation class. Each layer registers
    itself in the compilation config's ``static_forward_context`` under its
    ``prefix``. That prefix is also ``self.layer_name``, the name passed to
    the ``torch.ops.vllm.*`` attention custom ops.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        logits_soft_cap: Optional[float] = None,
        per_layer_sliding_window: Optional[int] = None,
        use_mla: bool = False,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        attn_backend: Optional[type[AttentionBackend]] = None,
        **extra_impl_args,
    ) -> None:
        """
        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each head.
            scale: Attention scale forwarded to the backend implementation.
            num_kv_heads: Number of key/value heads. Defaults to `num_heads`;
                `num_heads` must be divisible by it.
            alibi_slopes: Forwarded to the backend implementation.
            cache_config: Source of the KV cache dtype, block size,
                `calculate_kv_scales` flag and model-level sliding window.
                When absent, the defaults are "auto", 16 and False.
            quant_config: If given, asked for a quant method for this layer.
            logits_soft_cap: Forwarded to the backend implementation.
            per_layer_sliding_window: Overrides the model-level window.
            use_mla: Whether the layer uses MLA. If so, q/k/v are not
                reshaped in `forward`.
            prefix: Unique layer name; registered in the forward context.
            attn_type: Forwarded to the backend implementation.
            kv_sharing_target_layer_name: Name of a layer whose KV cache this
                layer shares; validated against already-registered layers.
            attn_backend: Explicit backend. If None, one is selected with
                `get_attn_backend`.
            **extra_impl_args: Forwarded to the backend implementation. A
                non-None "sinks" entry marks the layer as having sinks.

        Raises:
            ValueError: on a duplicate `prefix`, or an fp8_e5m2 KV cache
                combined with a quantized (non-unquantized) quant method.
            RuntimeError: if the q/k/v range constants cannot be allocated.
        """
        super().__init__()

        # A per-layer sliding window takes precedence over the model-level one.
        if per_layer_sliding_window is not None:
            sliding_window = per_layer_sliding_window
        elif cache_config is not None:
            sliding_window = cache_config.sliding_window
        else:
            sliding_window = None

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            calculate_kv_scales = cache_config.calculate_kv_scales
        else:
            kv_cache_dtype = "auto"
            block_size = 16
            calculate_kv_scales = False
        if num_kv_heads is None:
            num_kv_heads = num_heads
        assert num_heads % num_kv_heads == 0, \
            f"num_heads ({num_heads}) is not " \
            f"divisible by num_kv_heads ({num_kv_heads})"

        # The default k/v_scale is set to 1.0. This is ignored
        # when kv-cache is not fp8, and should be used with
        # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
        # expect the pre-quantized k/v_scale to be loaded along
        # with the model weights.
        self.kv_cache_dtype = kv_cache_dtype
        self.calculate_kv_scales = calculate_kv_scales
        self._k_scale = torch.tensor(1.0, dtype=torch.float32)
        self._v_scale = torch.tensor(1.0, dtype=torch.float32)
        # FlashAttn doesn't support quantizing the kv-cache only
        # but requires q to be quantized as well.
        self._q_scale = torch.tensor(1.0, dtype=torch.float32)
        self._prob_scale = torch.tensor(1.0, dtype=torch.float32)

        # We also keep q/k/v_scale on host (cpu) memory for attention
        # backends that require the scales to be on host instead of on device.
        # e.g. Flashinfer
        self._q_scale_float = 1.0
        self._k_scale_float = 1.0
        self._v_scale_float = 1.0

        # The output scale on host memory. This should be the input scale of
        # the quant op after this attention layer.
        self._o_scale_float: Optional[float] = None

        self.use_mla = use_mla
        self.num_heads = num_heads
        self.head_size = head_size
        self.num_kv_heads = num_kv_heads
        self.sliding_window = sliding_window
        self.has_sink = extra_impl_args.get("sinks") is not None

        quant_method = quant_config.get_quant_method(
            self, prefix=prefix) if quant_config else None
        if quant_method is not None and not isinstance(
                quant_method, UnquantizedLinearMethod):
            assert isinstance(quant_method, BaseKVCacheMethod)
            # TODO (mgoin): kv cache dtype should be specified in the FP8
            # checkpoint config and become the "auto" behavior
            if self.kv_cache_dtype == "fp8_e5m2":
                raise ValueError("fp8_e5m2 kv-cache is not supported with "
                                 "fp8 checkpoints.")
            # If quantization is enabled, we make "k_scale" and "v_scale"
            # parameters so that it can be loaded from the model checkpoint.
            # The k/v_scale will then be converted back to native float32
            # values after weight loading.
            self.quant_method = quant_method
            self.quant_method.create_weights(self)

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()
        if attn_backend is None:
            self.attn_backend = get_attn_backend(head_size,
                                                 dtype,
                                                 kv_cache_dtype,
                                                 block_size,
                                                 use_mla=use_mla,
                                                 has_sink=self.has_sink)
        else:
            self.attn_backend = attn_backend

        impl_cls = self.attn_backend.get_impl_cls()
        self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                             alibi_slopes, sliding_window, kv_cache_dtype,
                             logits_soft_cap, attn_type,
                             kv_sharing_target_layer_name, **extra_impl_args)
        self.backend = backend_name_to_enum(self.attn_backend.get_name())
        self.dtype = dtype

        # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
        # torch.compile works by registering the attention as one giant
        # opaque custom op. For other platforms, we directly call them
        # and let torch.compile handle them.
        self.use_direct_call = not current_platform.opaque_attention_op()

        self.use_output = self.attn_backend.accept_output_buffer
        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self
        self.layer_name = prefix
        self.attn_type = attn_type

        if kv_sharing_target_layer_name is not None:
            validate_kv_sharing_target(
                prefix,
                kv_sharing_target_layer_name,
                compilation_config.static_forward_context,
            )
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        # Use a placeholder kv cache tensor (one per pipeline-parallel
        # virtual engine) during init; it is replaced by bind_kv_cache.
        # It is indexed by forward_context.virtual_engine both in forward's
        # direct-call path and inside the unified_attention* custom ops.
        self.kv_cache = [
            torch.tensor([]) for _ in range(get_current_vllm_config(
            ).parallel_config.pipeline_parallel_size)
        ]

        try:
            self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT,
                                        dtype=torch.float32)
            self.k_range = torch.tensor(envs.K_SCALE_CONSTANT,
                                        dtype=torch.float32)
            self.v_range = torch.tensor(envs.V_SCALE_CONSTANT,
                                        dtype=torch.float32)
        except torch.cuda.OutOfMemoryError as e:
            logger.error(
                "Failed to initialize attention q/k/v range constants: %s", e)
            if torch.cuda.is_available():
                logger.debug("CUDA device: %s", torch.cuda.current_device())
                logger.debug("Allocated: %.2f GiB",
                             torch.cuda.memory_allocated() / GiB_bytes)
                logger.debug("Reserved: %.2f GiB",
                             torch.cuda.memory_reserved() / GiB_bytes)
            raise RuntimeError(
                "Failed to initialize q/k/v range constants. "
                "This may be caused by insufficient memory to allocate "
                "kv cache.") from e

        # for attn backends supporting query quantization
        self.query_quant = None
        if self.kv_cache_dtype.startswith(
                "fp8") and self.attn_backend.supports_quant_query_input:
            self.query_quant = QuantFP8(static=True,
                                        group_shape=GroupShape.PER_TENSOR)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        # For some alternate attention backends like MLA the attention output
        # shape does not match the query shape, so we optionally let the model
        # definition specify the output tensor shape.
        output_shape: Optional[torch.Size] = None,
    ) -> torch.Tensor:
        """
        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        Attention metadata (`attn_metadata`) is set using a context manager in
        the model runner's `execute_model` method. It is accessed via forward
        context using
        `vllm.forward_context.get_forward_context().attn_metadata`.

        When the backend accepts an output buffer (`self.use_output`), a
        zero-filled buffer of `output_shape` (default: `query.shape`) is
        allocated. The buffer uses the query's original (pre-quantization)
        dtype and is returned viewed as `(-1, output_shape[-1])`. Otherwise
        the backend's own return value is passed through.
        """
        if self.calculate_kv_scales:
            torch.ops.vllm.maybe_calc_kv_scales(query, key, value,
                                                self.layer_name)

        # Captured before optional query quantization changes query.dtype.
        output_dtype = query.dtype
        if self.query_quant is not None:
            # quantizing with a simple torch operation enables
            # torch.compile to fuse this into previous ops
            # which reduces overheads during decoding.
            # Otherwise queries are quantized using custom ops
            # which causes decoding overheads
            assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}
            query, _ = self.query_quant(query, self._q_scale)

        if self.use_output:
            output_shape = (output_shape
                            if output_shape is not None else query.shape)
            output = torch.zeros(output_shape,
                                 dtype=output_dtype,
                                 device=query.device)
            hidden_size = output_shape[-1]
            # We skip reshaping query, key and value tensors for the MLA
            # backend since these tensors have different semantics and are
            # processed differently.
            if not self.use_mla:
                # Reshape the query, key, and value tensors.
                # NOTE(woosuk): We do this outside the custom op to minimize the
                # CPU overheads from the non-CUDA-graph regions.
                query = query.view(-1, self.num_heads, self.head_size)
                output = output.view(-1, self.num_heads, self.head_size)
                if key is not None:
                    key = key.view(-1, self.num_kv_heads, self.head_size)
                if value is not None:
                    value = value.view(-1, self.num_kv_heads, self.head_size)
            if self.use_direct_call:
                forward_context: ForwardContext = get_forward_context()
                attn_metadata = forward_context.attn_metadata
                if isinstance(attn_metadata, dict):
                    attn_metadata = attn_metadata[self.layer_name]
                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                self.impl.forward(self,
                                  query,
                                  key,
                                  value,
                                  self_kv_cache,
                                  attn_metadata,
                                  output=output)
            else:
                torch.ops.vllm.unified_attention_with_output(
                    query, key, value, output, self.layer_name)
            return output.view(-1, hidden_size)
        else:
            if self.use_direct_call:
                forward_context = get_forward_context()
                attn_metadata = forward_context.attn_metadata
                if isinstance(attn_metadata, dict):
                    attn_metadata = attn_metadata[self.layer_name]
                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                return self.impl.forward(self, query, key, value,
                                         self_kv_cache, attn_metadata)
            else:
                return torch.ops.vllm.unified_attention(
                    query, key, value, self.layer_name)

    def calc_kv_scales(self, query, key, value):
        """Derive q/k/v scales once from the given tensors.

        Each scale is the tensor's absolute maximum divided by the matching
        `*_range` constant. It is written in place into the `_q/_k/_v_scale`
        tensors and mirrored into the `*_scale_float` host attributes via
        `.item()`. Further calculation is then disabled.
        """
        self._q_scale.copy_(torch.abs(query).max() / self.q_range)
        self._k_scale.copy_(torch.abs(key).max() / self.k_range)
        self._v_scale.copy_(torch.abs(value).max() / self.v_range)
        self._q_scale_float = self._q_scale.item()
        self._k_scale_float = self._k_scale.item()
        self._v_scale_float = self._v_scale.item()
        # We only calculate the scales once
        self.calculate_kv_scales = False

    def extra_repr(self) -> str:
        """Summarize head geometry, scale and implementation class."""
        s = f"head_size={self.impl.head_size}"  # type: ignore
        s += f", num_heads={self.impl.num_heads}"  # type: ignore
        s += f", num_kv_heads={self.impl.num_kv_heads}"  # type: ignore
        s += f", scale={self.impl.scale}"  # type: ignore
        s += f", backend={self.impl.__class__.__name__}"
        return s

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Run backend post-load hooks and normalize FlashInfer sinks."""
        if hasattr(self.impl, "process_weights_after_loading"):
            self.impl.process_weights_after_loading(act_dtype)

        # FlashInfer requires attention sinks to be float32
        if (self.backend == _Backend.FLASHINFER
                and hasattr(self.impl, 'sinks')):
            from vllm.v1.attention.backends.flashinfer import FlashInferImpl
            assert isinstance(self.impl, FlashInferImpl)
            if (self.impl.sinks is not None
                    and self.impl.sinks.dtype != torch.float32):
                self.impl.sinks = self.impl.sinks.to(torch.float32)

    def get_attn_backend(self) -> type[AttentionBackend]:
        """Return the backend class chosen (or given) at construction."""
        return self.attn_backend

377

378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
class MultiHeadAttention(nn.Module):
    """Multi-headed attention without any cache, used for ViT."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
    ):
        """Select a ViT attention backend for this layer.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each head.
            scale: Softmax scale passed to the attention kernel.
            num_kv_heads: Number of key/value heads (defaults to `num_heads`).
                `num_heads` must be divisible by it. For MQA/GQA the k/v heads
                are repeated in `forward`.
        """
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads

        assert self.num_heads % self.num_kv_heads == 0, \
            f"num_heads ({self.num_heads}) is not " \
            f"divisible by num_kv_heads ({self.num_kv_heads})"
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()

        # Determine the attention backend
        backend = get_vit_attn_backend(head_size=head_size, dtype=dtype)

        # Some auto-selected backends can be upgraded
        # to upstream flash attention if available.
        # If vllm native fa is selected, we use it directly.
        use_upstream_fa = False
        if backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
                dtype):
            backend = _Backend.FLASH_ATTN
            use_upstream_fa = True

        if current_platform.is_rocm() or current_platform.is_xpu():
            # currently, only torch_sdpa is supported on rocm/xpu
            self.attn_backend = _Backend.TORCH_SDPA
        else:
            self.attn_backend = backend if backend in {
                _Backend.TORCH_SDPA,
                _Backend.XFORMERS,
                _Backend.PALLAS,
                _Backend.ROCM_AITER_FA,
                _Backend.FLASH_ATTN,
            } else _Backend.TORCH_SDPA

        if (self.attn_backend == _Backend.XFORMERS
                and not check_xformers_availability()):
            self.attn_backend = _Backend.TORCH_SDPA

        # Both flash-style backends are driven through the same varlen entry
        # point in `forward`; only the provider of the function differs.
        if self.attn_backend == _Backend.FLASH_ATTN:
            if use_upstream_fa:
                from flash_attn import flash_attn_varlen_func
            else:
                from vllm.vllm_flash_attn import flash_attn_varlen_func
            self._flash_attn_varlen_func = flash_attn_varlen_func
        elif self.attn_backend == _Backend.ROCM_AITER_FA:
            from aiter import flash_attn_varlen_func
            self._flash_attn_varlen_func = flash_attn_varlen_func

        logger.info_once(
            f"MultiHeadAttention attn_backend: {self.attn_backend}, "
            f"use_upstream_fa: {use_upstream_fa}")

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """Input shape:
        (batch_size x seq_len x hidden_size) or
        (batch_size x seq_len x num_heads x head_size)

        Returns a tensor of shape (batch_size, q_len, -1), i.e. the per-head
        outputs flattened back together.
        """
        bsz, q_len = query.size()[:2]
        kv_len = key.size(1)

        query = query.view(bsz, q_len, self.num_heads, self.head_size)
        key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
        value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)

        if (num_repeat := self.num_queries_per_kv) > 1:
            # Handle MQA and GQA
            key = torch.repeat_interleave(key, num_repeat, dim=2)
            value = torch.repeat_interleave(value, num_repeat, dim=2)

        if self.attn_backend in (_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA):
            # Varlen kernels take the batch packed as (bsz * len, heads,
            # head_size) plus cumulative sequence offsets. Every sequence here
            # has the same length, so the offsets are an arithmetic range.
            cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len,
                                        step=q_len,
                                        dtype=torch.int32,
                                        device=query.device)
            cu_seqlens_k = torch.arange(0, (bsz + 1) * kv_len,
                                        step=kv_len,
                                        dtype=torch.int32,
                                        device=key.device)

            out = self._flash_attn_varlen_func(
                query.flatten(0, 1),
                key.flatten(0, 1),
                value.flatten(0, 1),
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=q_len,
                max_seqlen_k=kv_len,
                softmax_scale=self.scale,
            )
        elif self.attn_backend == _Backend.XFORMERS:
            from xformers import ops as xops

            out = xops.memory_efficient_attention_forward(query,
                                                          key,
                                                          value,
                                                          scale=self.scale)
        elif self.attn_backend == _Backend.TORCH_SDPA:
            # SDPA expects (batch, heads, seq, head_size).
            query, key, value = (x.transpose(1, 2)
                                 for x in (query, key, value))
            out = F.scaled_dot_product_attention(query,
                                                 key,
                                                 value,
                                                 scale=self.scale)
            out = out.transpose(1, 2)
        elif self.attn_backend == _Backend.PALLAS:
            query, key, value = (x.transpose(1, 2)
                                 for x in (query, key, value))
            from torch_xla.experimental.custom_kernel import flash_attention
            out = flash_attention(query, key, value, sm_scale=self.scale)
            out = out.transpose(1, 2)
        else:
            # ViT attention hasn't supported this backend yet
            raise NotImplementedError(
                f"ViT attention hasn't supported {self.attn_backend} "
                f"backend yet.")

        return out.reshape(bsz, q_len, -1)
522
523


524
525
526
527
528
529
530
531
532
533
def wait_for_kv_layer_from_connector(layer_name: str):
    """Have the v1 KV connector wait for ``layer_name``'s KV load.

    Does nothing when no v1 KV-transfer group is configured, or when the
    current forward context carries no attention metadata.
    """
    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
        return

    connector = get_kv_transfer_group()

    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if attn_metadata is None:
        return
    # With a v1 connector the metadata is expected to be a per-layer dict.
    assert isinstance(attn_metadata, dict)
    connector.wait_for_layer_load(layer_name)


def maybe_save_kv_layer_to_connector(
    layer_name: str,
    kv_cache_layer: List[torch.Tensor],
):
    """Hand ``layer_name``'s KV cache to the v1 KV connector for saving.

    Does nothing when no v1 KV-transfer group is configured, or when the
    current forward context carries no attention metadata. Otherwise passes
    the layer's own metadata entry to ``connector.save_kv_layer``.
    """
    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
        return

    connector = get_kv_transfer_group()

    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if attn_metadata is None:
        return
    # With a v1 connector the metadata is a dict keyed by layer name.
    assert isinstance(attn_metadata, dict)
    connector.save_kv_layer(layer_name, kv_cache_layer,
                            attn_metadata[layer_name])
554
555


556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
def maybe_calc_kv_scales(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> None:
    """Custom-op body: compute the q/k/v scales of ``layer_name`` on request.

    The layer's attention metadata is read from the current forward context.
    Scales are computed only when that metadata exists and its
    ``enable_kv_scales_calculation`` attribute is truthy. The work is then
    delegated to the layer's ``calc_kv_scales``.
    """
    ctx: ForwardContext = get_forward_context()
    metadata = ctx.attn_metadata
    if isinstance(metadata, dict):
        # Per-layer metadata: pick this layer's entry.
        metadata = metadata[layer_name]

    requested = metadata is not None and getattr(
        metadata, 'enable_kv_scales_calculation', False)
    if not requested:
        return

    layer = ctx.no_compile_layers[layer_name]
    layer.calc_kv_scales(query, key, value)


def maybe_calc_kv_scales_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> None:
    """Fake implementation registered for ``maybe_calc_kv_scales``: a no-op."""
    del query, key, value, layer_name  # intentionally unused
    return None


# Expose maybe_calc_kv_scales as torch.ops.vllm.maybe_calc_kv_scales, the name
# Attention.forward calls; tracing uses the no-op fake implementation.
direct_register_custom_op(
    op_name="maybe_calc_kv_scales",
    op_func=maybe_calc_kv_scales,
    fake_impl=maybe_calc_kv_scales_fake,
    mutates_args=["query", "key", "value"],
)


594
595
596
597
598
599
def unified_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Custom-op body: run attention for ``layer_name`` and return the output.

    The steps are:
    1. Wait on the KV connector, if one is configured.
    2. Resolve the layer and its attention metadata from the forward context.
    3. Call the layer's backend implementation with the KV cache of the
       current virtual engine.
    4. Offer that cache to the KV connector for saving.
    """
    wait_for_kv_layer_from_connector(layer_name)

    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if isinstance(attn_metadata, dict):
        attn_metadata = attn_metadata[layer_name]
    self = forward_context.no_compile_layers[layer_name]
    kv_cache = self.kv_cache[forward_context.virtual_engine]
    output = self.impl.forward(self, query, key, value, kv_cache,
                               attn_metadata)

    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
    return output
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627


def unified_attention_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Fake implementation of ``unified_attention`` used for tracing.

    Returns an uninitialized contiguous tensor shaped like ``query`` (same
    dtype and device). ``key``, ``value`` and ``layer_name`` are ignored.
    """
    placeholder = torch.empty_like(query)
    return placeholder.contiguous()


# Expose unified_attention as torch.ops.vllm.unified_attention, the name used
# by Attention.forward when it does not call the backend directly.
direct_register_custom_op(
    op_name="unified_attention",
    op_func=unified_attention,
    fake_impl=unified_attention_fake,
    tags=tag_cudagraph_unsafe,
)
630
631
632
633
634
635
636
637


def unified_attention_with_output(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: Optional[torch.Tensor] = None,
    output_block_scale: Optional[torch.Tensor] = None,
) -> None:
    """Custom-op body: run attention for ``layer_name`` into ``output``.

    Follows the same flow as ``unified_attention``: wait on the connector,
    resolve layer and metadata, run the backend, offer the KV cache for
    saving. The difference is that the backend writes into the caller's
    ``output`` buffer. ``output_scale`` and ``output_block_scale`` are passed
    through unchanged.
    """
    wait_for_kv_layer_from_connector(layer_name)
    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if isinstance(attn_metadata, dict):
        attn_metadata = attn_metadata[layer_name]
    self = forward_context.no_compile_layers[layer_name]
    kv_cache = self.kv_cache[forward_context.virtual_engine]
    self.impl.forward(self,
                      query,
                      key,
                      value,
                      kv_cache,
                      attn_metadata,
                      output=output,
                      output_scale=output_scale,
                      output_block_scale=output_block_scale)

    maybe_save_kv_layer_to_connector(layer_name, kv_cache)

660
661
662
663
664
665
666

def unified_attention_with_output_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: Optional[torch.Tensor] = None,
    output_block_scale: Optional[torch.Tensor] = None,
) -> None:
    """Fake implementation of ``unified_attention_with_output``: a no-op.

    The real op only mutates its ``output`` argument and returns nothing, so
    there is no result to fabricate here.
    """
    return


# Expose unified_attention_with_output as
# torch.ops.vllm.unified_attention_with_output. It writes into `output` (and
# may write `output_block_scale`), hence the declared mutated arguments.
direct_register_custom_op(
    op_name="unified_attention_with_output",
    op_func=unified_attention_with_output,
    mutates_args=["output", "output_block_scale"],
    fake_impl=unified_attention_with_output_fake,
    tags=tag_cudagraph_unsafe,
)