layer.py 24.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Attention layer."""
4
from typing import List, Optional
5
6
7

import torch
import torch.nn as nn
8
import torch.nn.functional as F
9

10
import vllm.envs as envs
11
from vllm.attention import AttentionType
12
from vllm.attention.backends.abstract import AttentionBackend
13
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
14
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
15
from vllm.config import CacheConfig, get_current_vllm_config
16
17
18
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                          has_kv_transfer_group,
                                          is_v1_kv_transfer_group)
19
from vllm.forward_context import ForwardContext, get_forward_context
20
from vllm.logger import init_logger
21
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
22
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
23
24
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
25
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
26
from vllm.model_executor.models.vision import get_vit_attn_backend
27
from vllm.platforms import _Backend, current_platform
28
from vllm.utils import direct_register_custom_op
29

30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
# Module-level logger, named after this module.
logger = init_logger(__name__)
# Cached answer of check_xformers_availability(); stays None until that
# function has been called once.
USE_XFORMERS_OPS = None


def check_xformers_availability():
    """Report whether the xformers attention ops may be used.

    The answer is computed once and cached in the module-level
    ``USE_XFORMERS_OPS`` flag; later calls return the cached value.

    Returns:
        bool: ``False`` on CUDA devices with capability >= 100, or when the
        ``xformers.ops`` module cannot be located; ``True`` otherwise.
    """
    global USE_XFORMERS_OPS
    if USE_XFORMERS_OPS is not None:
        return USE_XFORMERS_OPS

    if current_platform.is_cuda() and current_platform.has_device_capability(
            100):
        # Xformers FA is not compatible with B200
        USE_XFORMERS_OPS = False
    else:
        from importlib.util import find_spec

        try:
            # find_spec() imports the parent package (raising
            # ModuleNotFoundError, an ImportError, if it is missing) but
            # returns None -- without raising -- when only the submodule is
            # absent, so the return value has to be checked as well.
            USE_XFORMERS_OPS = find_spec("xformers.ops") is not None
        except ImportError:
            USE_XFORMERS_OPS = False

    # the warning only needs to be shown once
    if not USE_XFORMERS_OPS:
        logger.warning("Xformers is not available, falling back.")

    return USE_XFORMERS_OPS

58

59
60
61
62
63
64
65
66
def check_upstream_fa_availability(dtype: torch.dtype):
    """Tell whether the upstream flash-attention 2 package can be used.

    Returns ``False`` unless ``dtype`` is float16/bfloat16 and the current
    platform is CUDA with device capability >= 80; in that case the result
    of ``transformers.utils.is_flash_attn_2_available()`` is returned.
    """
    # The checks run in this order so that the platform is only queried for
    # half-precision dtypes.
    if dtype not in (torch.float16, torch.bfloat16):
        return False
    if not current_platform.is_cuda():
        return False
    if not current_platform.has_device_capability(80):
        return False

    from transformers.utils import is_flash_attn_2_available
    return is_flash_attn_2_available()


67
class Attention(nn.Module, AttentionLayerBase):
    """Attention layer.

    This class takes query, key, and value tensors as input. The input tensors
    can either contain prompt tokens or generation tokens.
    The class does the following:

    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        logits_soft_cap: Optional[float] = None,
        per_layer_sliding_window: Optional[int] = None,
        use_mla: bool = False,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        attn_backend: Optional[type[AttentionBackend]] = None,
        **extra_impl_args,
    ) -> None:
        """Select a backend, build its implementation and register the layer.

        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each head.
            scale: Scale forwarded to the backend implementation.
            num_kv_heads: Number of key/value heads; defaults to `num_heads`
                and must divide it.
            alibi_slopes: Forwarded to the backend implementation.
            cache_config: Source of the kv-cache dtype, block size,
                attention-free flag, kv-scale calculation flag and the
                model-level sliding window. Defaults are used when None.
            quant_config: Optional quantization config queried for a
                kv-cache quantization method.
            logits_soft_cap: Forwarded to the backend implementation.
            per_layer_sliding_window: Overrides the model-level sliding
                window when given.
            use_mla: Forwarded to backend selection; also makes `forward`
                skip reshaping query/key/value.
            prefix: Layer name; used as the key under which this layer is
                registered in the static forward context.
            attn_type: Forwarded to the backend implementation.
            kv_sharing_target_layer_name: Optional name of the layer whose
                KV cache this layer shares; validated when given.
            attn_backend: Backend class to use instead of auto-selection.
            **extra_impl_args: Extra keyword arguments forwarded to the
                backend implementation. A non-None "sinks" entry sets
                `self.has_sink`.

        Raises:
            ValueError: If `prefix` is already registered, or if an fp8
                checkpoint is combined with an "fp8_e5m2" kv-cache.
        """
        super().__init__()
        if per_layer_sliding_window is not None:
            # per-layer sliding window
            sliding_window = per_layer_sliding_window
        elif cache_config is not None:
            # model-level sliding window
            sliding_window = cache_config.sliding_window
        else:
            sliding_window = None

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            is_attention_free = cache_config.is_attention_free
            calculate_kv_scales = cache_config.calculate_kv_scales
        else:
            kv_cache_dtype = "auto"
            block_size = 16
            is_attention_free = False
            calculate_kv_scales = False
        if num_kv_heads is None:
            num_kv_heads = num_heads
        assert num_heads % num_kv_heads == 0, \
            f"num_heads ({num_heads}) is not " \
            f"divisible by num_kv_heads ({num_kv_heads})"

        # The default k/v_scale is set to 1.0. This is ignored
        # when kv-cache is not fp8, and should be used with
        # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
        # expect the pre-quantized k/v_scale to be loaded along
        # with the model weights.
        self.kv_cache_dtype = kv_cache_dtype
        self.calculate_kv_scales = calculate_kv_scales
        self._k_scale = torch.tensor(1.0, dtype=torch.float32)
        self._v_scale = torch.tensor(1.0, dtype=torch.float32)
        # FlashAttn doesn't support quantizing the kv-cache only
        # but requires q to be quantized as well.
        self._q_scale = torch.tensor(1.0, dtype=torch.float32)
        self._prob_scale = torch.tensor(1.0, dtype=torch.float32)

        # We also keep q/k/v_scale on host (cpu) memory for attention
        # backends that require the scales to be on host instead of on device.
        # e.g. Flashinfer
        self._q_scale_float = 1.0
        self._k_scale_float = 1.0
        self._v_scale_float = 1.0

        # The output scale on host memory. This should be the input scale of
        # the quant op after this attention layer.
        self._o_scale_float: Optional[float] = None

        self.use_mla = use_mla
        self.num_heads = num_heads
        self.head_size = head_size
        self.num_kv_heads = num_kv_heads
        self.sliding_window = sliding_window
        self.has_sink = extra_impl_args.get("sinks") is not None

        quant_method = quant_config.get_quant_method(
            self, prefix=prefix) if quant_config else None
        if quant_method is not None and not isinstance(
                quant_method, UnquantizedLinearMethod):
            assert isinstance(quant_method, BaseKVCacheMethod)
            # TODO (mgoin): kv cache dtype should be specified in the FP8
            # checkpoint config and become the "auto" behavior
            if self.kv_cache_dtype == "fp8_e5m2":
                raise ValueError("fp8_e5m2 kv-cache is not supported with "
                                 "fp8 checkpoints.")
            # If quantization is enabled, we make "k_scale" and "v_scale"
            # parameters so that it can be loaded from the model checkpoint.
            # The k/v_scale will then be converted back to native float32
            # values after weight loading.
            self.quant_method = quant_method
            self.quant_method.create_weights(self)

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()
        if attn_backend is None:
            self.attn_backend = get_attn_backend(head_size,
                                                 dtype,
                                                 kv_cache_dtype,
                                                 block_size,
                                                 is_attention_free,
                                                 use_mla=use_mla,
                                                 has_sink=self.has_sink)
        else:
            self.attn_backend = attn_backend

        impl_cls = self.attn_backend.get_impl_cls()
        self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                             alibi_slopes, sliding_window, kv_cache_dtype,
                             logits_soft_cap, attn_type,
                             kv_sharing_target_layer_name, **extra_impl_args)
        self.backend = backend_name_to_enum(self.attn_backend.get_name())
        self.dtype = dtype

        # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
        # torch.compile works by registering the attention as one giant
        # opaque custom op. For other platforms, we directly call them
        # and let torch.compile handle them.
        self.use_direct_call = not current_platform.opaque_attention_op()

        self.use_output = self.attn_backend.accept_output_buffer
        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self
        self.layer_name = prefix
        self.attn_type = attn_type

        if kv_sharing_target_layer_name is not None:
            validate_kv_sharing_target(
                prefix,
                kv_sharing_target_layer_name,
                compilation_config.static_forward_context,
            )
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        # use a placeholder kv cache tensor during init, which will be replaced
        # by bind_kv_cache
        # this variable will not be accessed if use_direct_call is True
        self.kv_cache = [
            torch.tensor([]) for _ in range(get_current_vllm_config(
            ).parallel_config.pipeline_parallel_size)
        ]

        self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
        self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
        self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        # For some alternate attention backends like MLA the attention output
        # shape does not match the query shape, so we optionally let the model
        # definition specify the output tensor shape.
        output_shape: Optional[torch.Size] = None,
    ) -> torch.Tensor:
        """Run attention through the selected backend.

        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        Attention metadata (`attn_metadata`) is set using a context manager in
        the model runner's `execute_model` method. It is accessed via forward
        context using
        `vllm.forward_context.get_forward_context().attn_metadata`.

        Args:
            query: Query tensor.
            key: Key tensor; may be None (it is then passed through
                unreshaped).
            value: Value tensor; may be None (same as `key`).
            output_shape: Shape of the output buffer when the backend writes
                into one; defaults to `query.shape`.

        Returns:
            The attention output. When the backend accepts an output buffer
            the result is viewed as `(-1, output_shape[-1])`.
        """
        if self.calculate_kv_scales:
            attn_metadata = get_forward_context().attn_metadata
            # The forward context may carry a per-layer dict of metadata, or
            # None; resolve it the same way the direct-call path below does
            # instead of assuming a single metadata object.
            if isinstance(attn_metadata, dict):
                attn_metadata = attn_metadata[self.layer_name]
            # NOTE(review): metadata objects without the flag are treated as
            # "calculation disabled" rather than raising AttributeError.
            if attn_metadata is not None and getattr(
                    attn_metadata, "enable_kv_scales_calculation", False):
                self.calc_kv_scales(query, key, value)
        if self.use_output:
            output_shape = (output_shape
                            if output_shape is not None else query.shape)
            output = torch.zeros(output_shape,
                                 dtype=query.dtype,
                                 device=query.device)
            hidden_size = output_shape[-1]
            # We skip reshaping query, key and value tensors for the MLA
            # backend since these tensors have different semantics and are
            # processed differently.
            if not self.use_mla:
                # Reshape the query, key, and value tensors.
                # NOTE(woosuk): We do this outside the custom op to minimize the
                # CPU overheads from the non-CUDA-graph regions.
                query = query.view(-1, self.num_heads, self.head_size)
                output = output.view(-1, self.num_heads, self.head_size)
                if key is not None:
                    key = key.view(-1, self.num_kv_heads, self.head_size)
                if value is not None:
                    value = value.view(-1, self.num_kv_heads, self.head_size)
            if self.use_direct_call:
                forward_context: ForwardContext = get_forward_context()
                attn_metadata = forward_context.attn_metadata
                if isinstance(attn_metadata, dict):
                    attn_metadata = attn_metadata[self.layer_name]
                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                self.impl.forward(self,
                                  query,
                                  key,
                                  value,
                                  self_kv_cache,
                                  attn_metadata,
                                  output=output)
            else:
                torch.ops.vllm.unified_attention_with_output(
                    query, key, value, output, self.layer_name)
            return output.view(-1, hidden_size)
        else:
            if self.use_direct_call:
                forward_context = get_forward_context()
                attn_metadata = forward_context.attn_metadata
                if isinstance(attn_metadata, dict):
                    attn_metadata = attn_metadata[self.layer_name]
                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                return self.impl.forward(self, query, key, value,
                                         self_kv_cache, attn_metadata)
            else:
                return torch.ops.vllm.unified_attention(
                    query, key, value, self.layer_name)

    def calc_kv_scales(self, query, key, value):
        """Derive q/k/v scales from the given tensors, once.

        Each scale is the maximum absolute value of the tensor divided by the
        matching `*_range` constant. The device tensors are updated in place,
        the host-side float copies are refreshed, and
        `self.calculate_kv_scales` is cleared so this runs a single time.
        """
        self._q_scale.copy_(torch.abs(query).max() / self.q_range)
        self._k_scale.copy_(torch.abs(key).max() / self.k_range)
        self._v_scale.copy_(torch.abs(value).max() / self.v_range)
        self._q_scale_float = self._q_scale.item()
        self._k_scale_float = self._k_scale.item()
        self._v_scale_float = self._v_scale.item()
        # We only calculate the scales once
        self.calculate_kv_scales = False

    def extra_repr(self) -> str:
        """Return the head/scale/backend summary shown in the module repr."""
        s = f"head_size={self.impl.head_size}"  # type: ignore
        s += f", num_heads={self.impl.num_heads}"  # type: ignore
        s += f", num_kv_heads={self.impl.num_kv_heads}"  # type: ignore
        s += f", scale={self.impl.scale}"  # type: ignore
        s += f", backend={self.impl.__class__.__name__}"
        return s

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Post-load hook: delegate to the impl and fix up FlashInfer sinks.

        Args:
            act_dtype: Passed through to the backend implementation's own
                `process_weights_after_loading`, when it defines one.
        """
        if hasattr(self.impl, "process_weights_after_loading"):
            self.impl.process_weights_after_loading(act_dtype)

        # FlashInfer requires attention sinks to be float32
        if (self.backend == _Backend.FLASHINFER_VLLM_V1
                and hasattr(self.impl, 'sinks')):
            from vllm.v1.attention.backends.flashinfer import FlashInferImpl
            assert isinstance(self.impl, FlashInferImpl)
            if (self.impl.sinks is not None
                    and self.impl.sinks.dtype != torch.float32):
                self.impl.sinks = self.impl.sinks.to(torch.float32)

    def get_attn_backend(self) -> type[AttentionBackend]:
        """Return the attention backend class used by this layer."""
        return self.attn_backend

339

340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
class MultiHeadAttention(nn.Module):
    """Cache-free multi-headed attention, used for ViT."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
    ):
        """Record the head layout and pick the attention backend.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each head.
            scale: Softmax scale handed to the backend kernel.
            num_kv_heads: Number of key/value heads; defaults to
                `num_heads` and must divide it.
        """
        super().__init__()
        if num_kv_heads is None:
            num_kv_heads = num_heads
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = num_kv_heads

        assert self.num_heads % self.num_kv_heads == 0, (
            f"num_heads ({self.num_heads}) is not "
            f"divisible by num_kv_heads ({self.num_kv_heads})")
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        # The default dtype during model initialization is the model's
        # weight and activation dtype.
        model_dtype = torch.get_default_dtype()

        # Start from the auto-selected ViT backend.
        selected = get_vit_attn_backend(head_size=head_size,
                                        dtype=model_dtype)

        # Anything other than vLLM's native flash attention is upgraded to
        # upstream flash attention when that package is usable.
        use_upstream_fa = False
        if selected != _Backend.FLASH_ATTN:
            if check_upstream_fa_availability(model_dtype):
                selected = _Backend.FLASH_ATTN
                use_upstream_fa = True

        if current_platform.is_rocm():
            # currently, only torch_sdpa is supported on rocm
            selected = _Backend.TORCH_SDPA
        else:
            vit_backends = {
                _Backend.TORCH_SDPA,
                _Backend.TORCH_SDPA_VLLM_V1,
                _Backend.XFORMERS,
                _Backend.PALLAS_VLLM_V1,
                _Backend.ROCM_AITER_FA,
                _Backend.FLASH_ATTN,
                _Backend.FLASH_ATTN_VLLM_V1,
            }
            if selected not in vit_backends:
                selected = _Backend.TORCH_SDPA

        # Fall back to SDPA when xformers was chosen but cannot be used.
        if selected == _Backend.XFORMERS:
            if not check_xformers_availability():
                selected = _Backend.TORCH_SDPA

        self.attn_backend = selected

        if self.attn_backend in {
                _Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1
        }:
            # Bind the varlen kernel once so forward() does not re-import it.
            if use_upstream_fa:
                from flash_attn import flash_attn_varlen_func
            else:
                from vllm.vllm_flash_attn import flash_attn_varlen_func
            self._flash_attn_varlen_func = flash_attn_varlen_func

        logger.info_once(
            f"MultiHeadAttention attn_backend: {self.attn_backend}, "
            f"use_upstream_fa: {use_upstream_fa}")

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """Input shape: batch_size x seq_len x hidden_size"""
        # TODO(Isotr0py): Use existing backend implementations and support FA3
        batch, q_len, _ = query.size()
        kv_len = key.size(1)

        # Split the hidden dimension into (heads, head_size).
        query = query.view(batch, q_len, self.num_heads, self.head_size)
        key = key.view(batch, kv_len, self.num_kv_heads, self.head_size)
        value = value.view(batch, kv_len, self.num_kv_heads, self.head_size)

        group_size = self.num_queries_per_kv
        if group_size > 1:
            # Handle MQA and GQA
            key = torch.repeat_interleave(key, group_size, dim=2)
            value = torch.repeat_interleave(value, group_size, dim=2)

        backend = self.attn_backend
        if backend in {
                _Backend.FLASH_ATTN,
                _Backend.FLASH_ATTN_VLLM_V1,
        }:
            # Every sequence in the batch has the same length, so the
            # cumulative offsets are evenly spaced.
            q_offsets = torch.arange(0, (batch + 1) * q_len,
                                     step=q_len,
                                     dtype=torch.int32,
                                     device=query.device)
            kv_offsets = torch.arange(0, (batch + 1) * kv_len,
                                      step=kv_len,
                                      dtype=torch.int32,
                                      device=key.device)

            flat_q = query.flatten(0, 1)
            flat_k = key.flatten(0, 1)
            flat_v = value.flatten(0, 1)
            out = self._flash_attn_varlen_func(
                flat_q,
                flat_k,
                flat_v,
                cu_seqlens_q=q_offsets,
                cu_seqlens_k=kv_offsets,
                max_seqlen_q=q_len,
                max_seqlen_k=kv_len,
                softmax_scale=self.scale,
            )
        elif backend == _Backend.XFORMERS:
            from xformers import ops as xops

            out = xops.memory_efficient_attention_forward(query,
                                                          key,
                                                          value,
                                                          scale=self.scale)
        elif backend in (_Backend.TORCH_SDPA, _Backend.TORCH_SDPA_VLLM_V1):
            # Swap dims 1 and 2 for the SDPA call and swap them back after.
            query = query.transpose(1, 2)
            key = key.transpose(1, 2)
            value = value.transpose(1, 2)
            attended = F.scaled_dot_product_attention(query,
                                                      key,
                                                      value,
                                                      scale=self.scale)
            out = attended.transpose(1, 2)
        elif backend == _Backend.PALLAS_VLLM_V1:
            query = query.transpose(1, 2)
            key = key.transpose(1, 2)
            value = value.transpose(1, 2)
            from torch_xla.experimental.custom_kernel import flash_attention
            attended = flash_attention(query, key, value, sm_scale=self.scale)
            out = attended.transpose(1, 2)
        elif backend == _Backend.ROCM_AITER_FA:
            from aiter import flash_attn_varlen_func

            # ROCm Flash Attention expects (batch, seq, heads, head_dim)
            out = flash_attn_varlen_func(query,
                                         key,
                                         value,
                                         softmax_scale=self.scale)
        else:
            # ViT attention hasn't supported this backend yet
            raise NotImplementedError(
                f"ViT attention hasn't supported {self.attn_backend} "
                f"backend yet.")

        # Merge the head dimensions back into a single hidden dimension.
        return out.reshape(batch, q_len, -1)
491
492


493
494
495
496
497
498
499
500
501
502
def wait_for_kv_layer_from_connector(layer_name: str):
    """Wait on the KV connector until `layer_name` has been loaded.

    Does nothing when there is no v1 KV transfer group, or when the forward
    context carries no attention metadata.
    """
    if not (has_kv_transfer_group() and is_v1_kv_transfer_group()):
        return

    kv_connector = get_kv_transfer_group()

    ctx: ForwardContext = get_forward_context()
    metadata = ctx.attn_metadata
    if metadata is None:
        return
    # On this path the metadata is expected to be a dict.
    assert isinstance(metadata, dict)
    kv_connector.wait_for_layer_load(layer_name)


def maybe_save_kv_layer_to_connector(
    layer_name: str,
    kv_cache_layer: List[torch.Tensor],
):
    """Pass a layer's KV cache to the KV connector for saving, if applicable.

    Does nothing when there is no v1 KV transfer group, or when the forward
    context carries no attention metadata. Otherwise the connector receives
    the layer name, its KV cache and that layer's entry of the metadata dict.
    """
    if not (has_kv_transfer_group() and is_v1_kv_transfer_group()):
        return

    kv_connector = get_kv_transfer_group()

    ctx: ForwardContext = get_forward_context()
    metadata = ctx.attn_metadata
    if metadata is None:
        return
    # On this path the metadata is expected to be a dict keyed by layer name.
    assert isinstance(metadata, dict)
    layer_metadata = metadata[layer_name]
    kv_connector.save_kv_layer(layer_name, kv_cache_layer, layer_metadata)
523
524


525
526
527
528
529
530
def unified_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Run the named layer's attention impl and return its output.

    The layer is looked up by `layer_name` in the forward context. The KV
    connector hooks run before (wait for load) and after (save) the call.
    """
    wait_for_kv_layer_from_connector(layer_name)

    ctx: ForwardContext = get_forward_context()
    layer_metadata = ctx.attn_metadata
    if isinstance(layer_metadata, dict):
        # Per-layer metadata: pick this layer's entry.
        layer_metadata = layer_metadata[layer_name]
    attn_layer = ctx.no_compile_layers[layer_name]
    layer_kv_cache = attn_layer.kv_cache[ctx.virtual_engine]
    result = attn_layer.impl.forward(attn_layer, query, key, value,
                                     layer_kv_cache, layer_metadata)

    maybe_save_kv_layer_to_connector(layer_name, layer_kv_cache)
    return result
544
545
546
547
548
549
550
551
552
553
554
555
556
557


def unified_attention_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Fake impl of `unified_attention`: an uninitialized, contiguous tensor
    with the shape, dtype and device of `query`."""
    placeholder = torch.empty_like(query)
    return placeholder.contiguous()


# Register `unified_attention` as a custom op, with `unified_attention_fake`
# as its fake implementation; the op mutates none of its arguments.
direct_register_custom_op(
    op_name="unified_attention",
    op_func=unified_attention,
    fake_impl=unified_attention_fake,
    mutates_args=[],
    dispatch_key=current_platform.dispatch_key,
)
562
563
564
565
566
567
568
569


def unified_attention_with_output(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: Optional[torch.Tensor] = None,
    output_block_scale: Optional[torch.Tensor] = None,
) -> None:
    """Run the named layer's attention impl, passing `output` as its buffer.

    The layer is looked up by `layer_name` in the forward context. The KV
    connector hooks run before (wait for load) and after (save) the call.
    Nothing is returned.
    """
    wait_for_kv_layer_from_connector(layer_name)
    ctx: ForwardContext = get_forward_context()
    layer_metadata = ctx.attn_metadata
    if isinstance(layer_metadata, dict):
        # Per-layer metadata: pick this layer's entry.
        layer_metadata = layer_metadata[layer_name]
    attn_layer = ctx.no_compile_layers[layer_name]
    layer_kv_cache = attn_layer.kv_cache[ctx.virtual_engine]
    attn_layer.impl.forward(
        attn_layer,
        query,
        key,
        value,
        layer_kv_cache,
        layer_metadata,
        output=output,
        output_scale=output_scale,
        output_block_scale=output_block_scale,
    )

    maybe_save_kv_layer_to_connector(layer_name, layer_kv_cache)

592
593
594
595
596
597
598

def unified_attention_with_output_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: Optional[torch.Tensor] = None,
    output_block_scale: Optional[torch.Tensor] = None,
) -> None:
    """Fake impl of `unified_attention_with_output`: does nothing."""
    return None


# Register `unified_attention_with_output` as a custom op, with
# `unified_attention_with_output_fake` as its fake implementation. The op is
# declared as mutating its `output` and `output_block_scale` arguments.
direct_register_custom_op(
    op_name="unified_attention_with_output",
    op_func=unified_attention_with_output,
    fake_impl=unified_attention_with_output_fake,
    mutates_args=["output", "output_block_scale"],
    dispatch_key=current_platform.dispatch_key,
)