# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer."""
4
from typing import Any, Dict, List, Optional
5
6
7

import torch
import torch.nn as nn
8
import torch.nn.functional as F
9

10
from vllm.two_batch_overlap.v1.two_batch_overlap_v1 import tbo_maybe_save_kv_layer_to_connector
11
import vllm.envs as envs
12
from vllm.attention import AttentionType
13
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
14
from vllm.config import CacheConfig, get_current_vllm_config
15
16
17
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                          has_kv_transfer_group,
                                          is_v1_kv_transfer_group)
18
from vllm.forward_context import ForwardContext, get_forward_context
19
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
20
21
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
22
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
23
from vllm.platforms import _Backend, current_platform
24
from vllm.utils import direct_register_custom_op
25
from vllm.v1.attention.backends.utils import validate_kv_sharing_target
26

# Module-level flag. It is not read by the code in this module as shown;
# it is kept for any importer that references it.
USE_XFORMERS_OPS = None

# Tags attached to custom ops registered in this module (see the
# `maybe_calc_kv_scales` registration). `torch._C.Tag.cudagraph_unsafe` may
# be absent in some PyTorch versions; in that case no tag is applied.
try:
    _cudagraph_unsafe_tag = torch._C.Tag.cudagraph_unsafe
except AttributeError:
    tag_cudagraph_unsafe = ()  # type: ignore[assignment]
else:
    tag_cudagraph_unsafe = (_cudagraph_unsafe_tag, )


class Attention(nn.Module):
    """Attention layer.

    This class takes query, key, and value tensors as input. The input tensors
    can either contain prompt tokens or generation tokens.
    The class does the following:

    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        per_layer_sliding_window: Optional[int] = None,
        use_mla: bool = False,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        **extra_impl_args,
    ) -> None:
        """
        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each attention head.
            scale: Attention scale, forwarded to the backend impl.
            num_kv_heads: Number of key/value heads. Defaults to
                ``num_heads`` and must divide it evenly.
            alibi_slopes: Optional ALiBi slopes, forwarded to the backend.
            cache_config: Source of the KV-cache dtype, block size, sliding
                window, attention-free flag and KV-scale setting. Built-in
                defaults are used when None.
            quant_config: If given, queried for a KV-cache quant method
                that creates the k/v scale parameters on this layer.
            blocksparse_params: Block-sparse settings, forwarded to the
                backend. Whether it is None also feeds backend selection.
            logits_soft_cap: Forwarded to the backend impl.
            per_layer_sliding_window: Overrides the model-level sliding
                window taken from ``cache_config``.
            use_mla: Select an MLA backend and skip q/k/v reshaping in
                ``forward``.
            prefix: Unique layer name, registered in the compilation
                config's static forward context.
            attn_type: Attention type, forwarded to the backend impl.
            kv_sharing_target_layer_name: Layer whose KV cache this layer
                shares (V1 only).
            **extra_impl_args: Extra keyword arguments for the backend impl.

        Raises:
            ValueError: If ``prefix`` is already registered.
            NotImplementedError: If KV sharing is requested without V1.
        """
        super().__init__()
        if per_layer_sliding_window is not None:
            # per-layer sliding window
            sliding_window = per_layer_sliding_window
        elif cache_config is not None:
            # model-level sliding window
            sliding_window = cache_config.sliding_window
        else:
            sliding_window = None

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            is_attention_free = cache_config.is_attention_free
            calculate_kv_scales = cache_config.calculate_kv_scales
        else:
            kv_cache_dtype = "auto"
            # Default block size is 64 when both VLLM_USE_FLASH_ATTN_PA and
            # VLLM_USE_FLASH_MLA are set, else 16 (presumably a kernel
            # requirement - confirm).
            block_size = (64 if envs.VLLM_USE_FLASH_ATTN_PA
                          and envs.VLLM_USE_FLASH_MLA else 16)
            is_attention_free = False
            calculate_kv_scales = False

        if num_kv_heads is None:
            num_kv_heads = num_heads
        assert num_heads % num_kv_heads == 0, \
            f"num_heads ({num_heads}) is not " \
            f"divisible by num_kv_heads ({num_kv_heads})"

        # The default k/v_scale is set to 1.0. This is ignored
        # when kv-cache is not fp8, and should be used with
        # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
        # expect the pre-quantized k/v_scale to be loaded along
        # with the model weights.
        self.kv_cache_dtype = kv_cache_dtype
        self.calculate_kv_scales = calculate_kv_scales
        self._k_scale = torch.tensor(1.0, dtype=torch.float32)
        self._v_scale = torch.tensor(1.0, dtype=torch.float32)
        # FlashAttn doesn't support quantizing the kv-cache only
        # but requires q to be quantized as well.
        self._q_scale = torch.tensor(1.0, dtype=torch.float32)
        self._prob_scale = torch.tensor(1.0, dtype=torch.float32)

        # We also keep the float32 versions of k/v_scale for attention
        # backends that don't support tensors (Flashinfer)
        self._k_scale_float = 1.0
        self._v_scale_float = 1.0

        self.use_mla = use_mla
        self.num_heads = num_heads
        self.head_size = head_size
        self.num_kv_heads = num_kv_heads
        self.sliding_window = sliding_window

        quant_method = quant_config.get_quant_method(
            self, prefix=prefix) if quant_config else None
        if quant_method is not None and not isinstance(
                quant_method, UnquantizedLinearMethod):
            assert isinstance(quant_method, BaseKVCacheMethod)
            # TODO (mgoin): kv cache dtype should be specified in the FP8
            # checkpoint config and become the "auto" behavior
            # NOTE: the ValueError that rejected fp8_e5m2 kv-cache together
            # with fp8 checkpoints is disabled in this build.

            # If quantization is enabled, we make "k_scale" and "v_scale"
            # parameters so that it can be loaded from the model checkpoint.
            # The k/v_scale will then be converted back to native float32
            # values after weight loading.
            self.quant_method = quant_method
            self.quant_method.create_weights(self)

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()
        attn_backend = get_attn_backend(head_size,
                                        dtype,
                                        kv_cache_dtype,
                                        block_size,
                                        is_attention_free,
                                        blocksparse_params is not None,
                                        use_mla=use_mla)
        impl_cls = attn_backend.get_impl_cls()
        self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                             alibi_slopes, sliding_window, kv_cache_dtype,
                             blocksparse_params, logits_soft_cap, attn_type,
                             kv_sharing_target_layer_name, **extra_impl_args)
        self.backend = backend_name_to_enum(attn_backend.get_name())
        self.dtype = dtype

        # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
        # torch.compile works by registering the attention as one giant
        # opaque custom op. For other platforms, we directly call them
        # and let torch.compile handle them.
        self.use_direct_call = not current_platform.is_cuda_alike(
        ) and not current_platform.is_cpu()

        self.use_output = attn_backend.accept_output_buffer
        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self
        self.layer_name = prefix
        self.attn_type = attn_type

        if kv_sharing_target_layer_name is not None:
            if not envs.VLLM_USE_V1:
                raise NotImplementedError(
                    "Cross-layer KV sharing is not supported in V0.")

            validate_kv_sharing_target(
                prefix,
                kv_sharing_target_layer_name,
                compilation_config.static_forward_context,
            )
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        # Use a placeholder kv cache tensor (one per pipeline-parallel
        # virtual engine) during init; it is replaced by bind_kv_cache.
        # It is indexed by virtual engine both on the direct-call path in
        # `forward` and inside the unified_attention custom ops.
        self.kv_cache = [
            torch.tensor([]) for _ in range(get_current_vllm_config(
            ).parallel_config.pipeline_parallel_size)
        ]

        # Divisors used by calc_kv_scales: scale = max(|x|) / range.
        self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
        self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
        self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        # For some alternate attention backends like MLA the attention output
        # shape does not match the query shape, so we optionally let the model
        # definition specify the output tensor shape.
        output_shape: Optional[torch.Size] = None,
        # Extra inputs. They are forwarded (via unified_attention_with_output)
        # only when use_output is set, the direct-call path is not taken and
        # VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT is enabled; otherwise ignored.
        q_ori: Optional[torch.Tensor] = None,
        key_normed: Optional[torch.Tensor] = None,
        positions: Optional[torch.Tensor] = None,
        weight: Optional[torch.Tensor] = None,
        cos_sin_cache: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        Attention metadata (`attn_metadata`) is set using a context manager in
        the model runner's `execute_model` method. It is accessed via forward
        context using
        `vllm.forward_context.get_forward_context().attn_metadata`.

        ``key`` and ``value`` may be None. KV-scale calculation is skipped
        for such calls and deferred to the next call that provides both.

        Returns:
            When the backend accepts an output buffer, a tensor viewed as
            ``(-1, output_shape[-1])``. Otherwise, whatever the backend
            impl returns.
        """
        if (self.calculate_kv_scales and key is not None
                and value is not None):
            # The `enable_kv_scales_calculation` metadata gate is disabled in
            # this build. Scales are computed on the first call that has both
            # key and value; calc_kv_scales then clears the flag.
            torch.ops.vllm.maybe_calc_kv_scales(query, key, value,
                                                self.layer_name)
        if self.use_output:
            output_shape = (output_shape
                            if output_shape is not None else query.shape)
            if envs.VLLM_USE_OPT_ZEROS:
                output = torch.empty(output_shape,
                                     dtype=query.dtype,
                                     device=query.device)
            else:
                output = torch.zeros(output_shape,
                                     dtype=query.dtype,
                                     device=query.device)
            hidden_size = output_shape[-1]
            # We skip reshaping query, key and value tensors for the MLA
            # backend since these tensors have different semantics and are
            # processed differently.
            if not self.use_mla:
                # Reshape the query, key, and value tensors.
                # NOTE(woosuk): We do this outside the custom op to minimize the
                # CPU overheads from the non-CUDA-graph regions.
                query = query.view(-1, self.num_heads, self.head_size)
                output = output.view(-1, self.num_heads, self.head_size)
                if key is not None:
                    key = key.view(-1, self.num_kv_heads, self.head_size)
                if value is not None:
                    value = value.view(-1, self.num_kv_heads, self.head_size)
            if self.use_direct_call:
                forward_context: ForwardContext = get_forward_context()
                attn_metadata = forward_context.attn_metadata
                if isinstance(attn_metadata, dict):
                    attn_metadata = attn_metadata[self.layer_name]
                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                self.impl.forward(self,
                                  query,
                                  key,
                                  value,
                                  self_kv_cache,
                                  attn_metadata,
                                  output=output)
            else:
                if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
                    torch.ops.vllm.unified_attention_with_output(
                        query, key, value, output, self.layer_name)
                else:
                    torch.ops.vllm.unified_attention_with_output(
                        query, key, value, output, self.layer_name, None,
                        q_ori, key_normed, positions, weight, cos_sin_cache)
            return output.view(-1, hidden_size)
        else:
            if self.use_direct_call:
                forward_context = get_forward_context()
                attn_metadata = forward_context.attn_metadata
                if isinstance(attn_metadata, dict):
                    attn_metadata = attn_metadata[self.layer_name]
                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                return self.impl.forward(self, query, key, value,
                                         self_kv_cache, attn_metadata)
            else:
                return torch.ops.vllm.unified_attention(
                    query, key, value, self.layer_name)

    def calc_kv_scales(self, query, key, value):
        """Set the q/k/v scales from the absolute maximum of the inputs.

        Each scale is ``max(|x|) / range`` with ``q_range``/``k_range``/
        ``v_range`` as the ranges. The float copies of the k/v scales are
        refreshed with ``.item()``. Runs once: ``calculate_kv_scales`` is
        cleared afterwards.
        """
        self._q_scale.copy_(torch.abs(query).max() / self.q_range)
        self._k_scale.copy_(torch.abs(key).max() / self.k_range)
        self._v_scale.copy_(torch.abs(value).max() / self.v_range)
        self._k_scale_float = self._k_scale.item()
        self._v_scale_float = self._v_scale.item()
        # We only calculate the scales once
        self.calculate_kv_scales = False

    def extra_repr(self) -> str:
        """Summarize the backend impl's geometry, scale and class name."""
        s = f"head_size={self.impl.head_size}"  # type: ignore
        s += f", num_heads={self.impl.num_heads}"  # type: ignore
        s += f", num_kv_heads={self.impl.num_kv_heads}"  # type: ignore
        s += f", scale={self.impl.scale}"  # type: ignore
        s += f", backend={self.impl.__class__.__name__}"
        return s

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Delegate to the backend impl's hook, if it defines one."""
        if hasattr(self.impl, "process_weights_after_loading"):
            self.impl.process_weights_after_loading(act_dtype)


class MultiHeadAttention(nn.Module):
    """Multi-headed attention without any cache, used for ViT."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
    ):
        """
        Args:
            num_heads: Number of query heads.
            head_size: Size of each attention head.
            scale: Attention scale passed to the selected kernel.
            num_kv_heads: Number of key/value heads. Defaults to
                ``num_heads`` and must divide it evenly.
        """
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads

        assert self.num_heads % self.num_kv_heads == 0, \
            f"num_heads ({self.num_heads}) is not " \
            f"divisible by num_kv_heads ({self.num_kv_heads})"
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        dtype = torch.get_default_dtype()
        attn_backend = get_attn_backend(
            head_size,
            dtype,
            kv_cache_dtype=None,
            block_size=(64 if envs.VLLM_USE_FLASH_ATTN_PA
                        and envs.VLLM_USE_FLASH_MLA else 16),
            is_attention_free=False)
        backend = backend_name_to_enum(attn_backend.get_name())
        if current_platform.is_rocm():
            # currently, only torch_sdpa is supported on rocm
            self.attn_backend = _Backend.TORCH_SDPA
        else:
            # `forward` only implements xformers, torch SDPA and Pallas, so
            # flash/flex selections are mapped to xformers and anything else
            # falls back to torch SDPA.
            if backend in (_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1,
                           _Backend.FLEX_ATTENTION):
                backend = _Backend.XFORMERS

            self.attn_backend = backend if backend in {
                _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1
            } else _Backend.TORCH_SDPA

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """Input shape: batch_size x seq_len x hidden_size

        ``query`` is viewed as (bsz, q_len, num_heads, head_size) and
        ``key``/``value`` as (bsz, kv_len, num_kv_heads, head_size). KV heads
        are repeated to match the query heads for MQA/GQA.

        Returns:
            Tensor of shape (bsz, q_len, num_heads * head_size).
        """
        # TODO(Isotr0py): Use existing backend implementations and support FA3
        bsz, q_len, _ = query.size()
        kv_len = key.size(1)

        query = query.view(bsz, q_len, self.num_heads, self.head_size)
        key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
        value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)

        if (num_repeat := self.num_queries_per_kv) > 1:
            # Handle MQA and GQA
            key = torch.repeat_interleave(key, num_repeat, dim=2)
            value = torch.repeat_interleave(value, num_repeat, dim=2)

        if self.attn_backend == _Backend.XFORMERS:
            from xformers import ops as xops

            out = xops.memory_efficient_attention_forward(query,
                                                          key,
                                                          value,
                                                          scale=self.scale)
        elif self.attn_backend == _Backend.TORCH_SDPA:
            # SDPA expects (bsz, heads, seq, head_size).
            query, key, value = (x.transpose(1, 2)
                                 for x in (query, key, value))
            out = F.scaled_dot_product_attention(query,
                                                 key,
                                                 value,
                                                 scale=self.scale)
            out = out.transpose(1, 2)
        elif self.attn_backend == _Backend.PALLAS_VLLM_V1:
            query, key, value = (x.transpose(1, 2)
                                 for x in (query, key, value))
            from torch_xla.experimental.custom_kernel import flash_attention
            out = flash_attention(query, key, value, sm_scale=self.scale)
            out = out.transpose(1, 2)

        return out.reshape(bsz, q_len, -1)


def wait_for_kv_layer_from_connector(layer_name: str):
    """Wait, via ``connector.wait_for_layer_load``, for ``layer_name``'s KV.

    This is a no-op unless a V1 KV-transfer group is configured and the
    current forward context carries attention metadata.
    """
    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
        return

    connector = get_kv_transfer_group()

    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if attn_metadata is None:
        return
    # Metadata is expected to be a per-layer dict (indexed by layer name
    # elsewhere in this module).
    assert isinstance(attn_metadata, dict)
    connector.wait_for_layer_load(layer_name)


def maybe_save_kv_layer_to_connector(
    layer_name: str,
    kv_cache_layer: List[torch.Tensor],
):
    """Hand a layer's KV cache to the KV connector via ``save_kv_layer``.

    This is a no-op unless a V1 KV-transfer group is configured and the
    current forward context carries attention metadata.

    Args:
        layer_name: Attention layer name. It is also the key into the
            per-layer attention-metadata dict.
        kv_cache_layer: The layer's KV cache. Callers in this module pass
            ``layer.kv_cache[virtual_engine]``.
    """
    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
        return

    connector = get_kv_transfer_group()

    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if attn_metadata is None:
        return
    assert isinstance(attn_metadata, dict)
    connector.save_kv_layer(layer_name, kv_cache_layer,
                            attn_metadata[layer_name])


def maybe_calc_kv_scales(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> None:
    """Compute the q/k/v scales of ``layer_name`` from the given inputs.

    This backs the ``torch.ops.vllm.maybe_calc_kv_scales`` custom op. The
    layer is looked up in the forward context's ``no_compile_layers`` and its
    ``calc_kv_scales`` is invoked.

    NOTE: the ``attn_metadata.enable_kv_scales_calculation`` gate is disabled
    in this build, so scales are computed whenever this op runs.
    ``Attention.forward`` only calls it while ``calculate_kv_scales`` is set.
    """
    forward_context: ForwardContext = get_forward_context()
    layer = forward_context.no_compile_layers[layer_name]
    layer.calc_kv_scales(query, key, value)


def maybe_calc_kv_scales_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> None:
    """Fake implementation registered for ``maybe_calc_kv_scales``.

    The real op only updates layer state and returns nothing, so there is
    no output to fabricate.
    """
    return None


# Expose `maybe_calc_kv_scales` as `torch.ops.vllm.maybe_calc_kv_scales`
# (called from `Attention.forward`). It is tagged cudagraph-unsafe when the
# running PyTorch defines that tag.
direct_register_custom_op(
    op_name="maybe_calc_kv_scales",
    op_func=maybe_calc_kv_scales,
    fake_impl=maybe_calc_kv_scales_fake,
    mutates_args=[],
    dispatch_key=current_platform.dispatch_key,
    tags=tag_cudagraph_unsafe,
)


def unified_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Run attention for ``layer_name`` and return the backend's output.

    This backs ``torch.ops.vllm.unified_attention``, which ``Attention.forward``
    uses when the backend does not accept an output buffer and the
    direct-call path is off. The layer and its KV cache for the current
    virtual engine come from the forward context.

    Before the backend runs, it waits for any KV-connector load. Afterwards,
    the layer's KV cache goes to the connector: through
    ``tbo_maybe_save_kv_layer_to_connector`` when ``VLLM_ENABLE_TBO`` is set,
    otherwise through ``maybe_save_kv_layer_to_connector``.
    """
    wait_for_kv_layer_from_connector(layer_name)

    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if isinstance(attn_metadata, dict):
        attn_metadata = attn_metadata[layer_name]
    layer = forward_context.no_compile_layers[layer_name]
    kv_cache = layer.kv_cache[forward_context.virtual_engine]
    output = layer.impl.forward(layer, query, key, value, kv_cache,
                                attn_metadata)

    if envs.VLLM_ENABLE_TBO:
        tbo_maybe_save_kv_layer_to_connector(layer_name, kv_cache)
    else:
        maybe_save_kv_layer_to_connector(layer_name, kv_cache)
    return output


def unified_attention_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Fake implementation of ``unified_attention``.

    It returns an uninitialized, contiguous tensor with the shape, dtype and
    device of ``query``.
    """
    placeholder = torch.empty_like(query)
    return placeholder.contiguous()


# Expose `unified_attention` as `torch.ops.vllm.unified_attention`.
direct_register_custom_op(
    op_name="unified_attention",
    op_func=unified_attention,
    mutates_args=[],
    fake_impl=unified_attention_fake,
    dispatch_key=current_platform.dispatch_key,
)


def unified_attention_with_output(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: Optional[torch.Tensor] = None,
    q_ori: Optional[torch.Tensor] = None,
    key_normed: Optional[torch.Tensor] = None,
    positions: Optional[torch.Tensor] = None,
    weight: Optional[torch.Tensor] = None,
    cos_sin_cache: Optional[torch.Tensor] = None,
) -> None:
    """Run attention for ``layer_name``, writing the result into ``output``.

    This backs ``torch.ops.vllm.unified_attention_with_output``. The layer and
    its KV cache for the current virtual engine come from the forward
    context. Before the backend runs, it waits for any KV-connector load.
    Afterwards, the KV cache goes to the connector (the TBO variant is used
    when ``VLLM_ENABLE_TBO`` is set).

    ``q_ori``, ``key_normed``, ``positions``, ``weight`` and
    ``cos_sin_cache`` are forwarded to the backend only when
    ``VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT`` is enabled. Otherwise they are
    ignored.
    """
    wait_for_kv_layer_from_connector(layer_name)
    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if isinstance(attn_metadata, dict):
        attn_metadata = attn_metadata[layer_name]
    layer = forward_context.no_compile_layers[layer_name]
    kv_cache = layer.kv_cache[forward_context.virtual_engine]

    # Only pass the extra inputs when the flag is on, so backends without
    # support for them are called with the plain signature.
    extra_kwargs = {}
    if envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
        extra_kwargs = dict(q_ori=q_ori,
                            key_normed=key_normed,
                            positions=positions,
                            weight=weight,
                            cos_sin_cache=cos_sin_cache)
    layer.impl.forward(layer,
                       query,
                       key,
                       value,
                       kv_cache,
                       attn_metadata,
                       output=output,
                       output_scale=output_scale,
                       **extra_kwargs)

    if envs.VLLM_ENABLE_TBO:
        tbo_maybe_save_kv_layer_to_connector(layer_name, kv_cache)
    else:
        maybe_save_kv_layer_to_connector(layer_name, kv_cache)


def unified_attention_with_output_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: Optional[torch.Tensor] = None,
    q_ori: Optional[torch.Tensor] = None,
    key_normed: Optional[torch.Tensor] = None,
    positions: Optional[torch.Tensor] = None,
    weight: Optional[torch.Tensor] = None,
    cos_sin_cache: Optional[torch.Tensor] = None,
) -> None:
    """Fake implementation of ``unified_attention_with_output``.

    The real op writes into ``output`` in place and returns nothing, so the
    fake produces nothing. Its signature mirrors the real op's full
    signature unconditionally. The real op always accepts the extra inputs,
    whatever the value of ``VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT``.
    """
    return


# Expose `unified_attention_with_output` as
# `torch.ops.vllm.unified_attention_with_output`; the op writes its result
# into `output` in place.
direct_register_custom_op(
    op_name="unified_attention_with_output",
    op_func=unified_attention_with_output,
    mutates_args=["output"],
    fake_impl=unified_attention_with_output_fake,
    dispatch_key=current_platform.dispatch_key,
)