layer.py 26.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Attention layer."""
4
from typing import List, Optional
5
6
7

import torch
import torch.nn as nn
8
import torch.nn.functional as F
9

10
import vllm.envs as envs
11
from vllm.attention import AttentionType
12
from vllm.attention.backends.abstract import AttentionBackend
13
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
14
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
15
from vllm.config import CacheConfig, get_current_vllm_config
16
17
18
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                          has_kv_transfer_group,
                                          is_v1_kv_transfer_group)
19
from vllm.forward_context import ForwardContext, get_forward_context
20
from vllm.logger import init_logger
21
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
22
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
23
24
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
25
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
26
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
27
28
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape)
29
from vllm.model_executor.models.vision import get_vit_attn_backend
30
from vllm.platforms import _Backend, current_platform
31
from vllm.utils import GiB_bytes, direct_register_custom_op
32

33
34
logger = init_logger(__name__)

# Cached answer of check_xformers_availability(): None until the first call,
# then True/False for the rest of the process.
USE_XFORMERS_OPS = None

# Tags attached to the attention custom ops registered at the bottom of this
# module. Torch builds that do not expose `torch._C.Tag.cudagraph_unsafe`
# get an empty tag tuple instead.
try:
    tag_cudagraph_unsafe = (torch._C.Tag.cudagraph_unsafe, )
except AttributeError:
    tag_cudagraph_unsafe = ()  # type: ignore[assignment]
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63


def check_xformers_availability():
    """Return whether xformers ops can be used, caching the answer.

    The result is stored in the module-level ``USE_XFORMERS_OPS``. The
    platform probe and the fallback warning therefore run only on the first
    call; later calls return the cached value.

    Returns False on CUDA devices for which
    ``current_platform.has_device_capability(100)`` holds (the code notes
    B200 as incompatible). Otherwise returns True only if the
    ``xformers.ops`` module can actually be located.
    """
    global USE_XFORMERS_OPS
    if USE_XFORMERS_OPS is not None:
        return USE_XFORMERS_OPS

    if current_platform.is_cuda() and current_platform.has_device_capability(
            100):
        # Xformers FA is not compatible with B200
        USE_XFORMERS_OPS = False
    else:
        from importlib.util import find_spec

        try:
            # find_spec returns None (it does not raise) when the parent
            # package is importable but the submodule is missing, so the
            # result must be checked. A missing parent package raises
            # ModuleNotFoundError (an ImportError). A module registered in
            # sys.modules with no __spec__ raises ValueError.
            USE_XFORMERS_OPS = find_spec("xformers.ops") is not None
        except (ImportError, ValueError):
            USE_XFORMERS_OPS = False

    # the warning only needs to be shown once
    if not USE_XFORMERS_OPS:
        logger.warning("Xformers is not available, falling back.")

    return USE_XFORMERS_OPS
64
65


66
67
68
69
70
71
72
73
def check_upstream_fa_availability(dtype: torch.dtype):
    """Report whether the upstream ``flash_attn`` package can serve ``dtype``.

    Only fp16/bf16 on CUDA devices for which
    ``current_platform.has_device_capability(80)`` holds are considered. For
    those, the answer comes from ``transformers.utils.is_flash_attn_2_available``.
    Every other combination yields False.
    """
    if dtype not in (torch.float16, torch.bfloat16):
        return False
    if not current_platform.is_cuda():
        return False
    if not current_platform.has_device_capability(80):
        return False

    from transformers.utils import is_flash_attn_2_available
    return is_flash_attn_2_available()


74
class Attention(nn.Module, AttentionLayerBase):
    """Attention layer.

    This class takes query, key, and value tensors as input. The input tensors
    can either contain prompt tokens or generation tokens.
    The class does the following:

    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.

    The computation itself is delegated to ``self.impl``, an instance of the
    implementation class exposed by the selected attention backend.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        logits_soft_cap: Optional[float] = None,
        per_layer_sliding_window: Optional[int] = None,
        use_mla: bool = False,
        use_sparse: bool = False,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        attn_backend: Optional[type[AttentionBackend]] = None,
        **extra_impl_args,
    ) -> None:
        """
        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each attention head.
            scale: Softmax scale, forwarded to the backend implementation.
            num_kv_heads: Number of key/value heads. Defaults to
                ``num_heads`` and must divide it evenly.
            alibi_slopes: Forwarded to the backend implementation.
            cache_config: Supplies the KV-cache dtype, the block size, the
                ``calculate_kv_scales`` flag and the model-level sliding
                window. When None:
                - the dtype is "auto";
                - KV scales are not calculated;
                - the block size is 64 if both ``VLLM_USE_FLASH_ATTN_PA``
                  and ``VLLM_USE_FLASH_MLA`` are set, else 16.
            quant_config: Optional quantization config. Any quant method it
                returns that is not ``UnquantizedLinearMethod`` must be a
                ``BaseKVCacheMethod``; it creates its weights on this layer.
            logits_soft_cap: Forwarded to the backend implementation.
            per_layer_sliding_window: When given, overrides the model-level
                sliding window from ``cache_config``.
            use_mla: Forwarded to backend selection. Also makes ``forward``
                skip the per-head reshape of q/k/v.
            use_sparse: Forwarded to backend selection.
            prefix: Unique layer name. It is registered in the compilation
                config's ``static_forward_context``.
            attn_type: Forwarded to the backend implementation.
            kv_sharing_target_layer_name: Name of the layer whose KV cache
                this layer shares. Validated against the layers registered
                so far.
            attn_backend: Explicit backend class. When None, one is chosen
                via ``get_attn_backend``.
            **extra_impl_args: Extra keyword arguments forwarded to the
                backend implementation. A non-None ``sinks`` entry marks the
                layer as having attention sinks.

        Raises:
            ValueError: If a layer is already registered under ``prefix``.
            RuntimeError: If allocating the q/k/v range constants runs out
                of CUDA memory.
        """
        super().__init__()
        if per_layer_sliding_window is not None:
            # per-layer sliding window
            sliding_window = per_layer_sliding_window
        elif cache_config is not None:
            # model-level sliding window
            sliding_window = cache_config.sliding_window
        else:
            sliding_window = None

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            calculate_kv_scales = cache_config.calculate_kv_scales
        else:
            kv_cache_dtype = "auto"
            block_size = (64 if envs.VLLM_USE_FLASH_ATTN_PA
                          and envs.VLLM_USE_FLASH_MLA else 16)
            calculate_kv_scales = False
        if num_kv_heads is None:
            num_kv_heads = num_heads
        assert num_heads % num_kv_heads == 0, \
            f"num_heads ({num_heads}) is not " \
            f"divisible by num_kv_heads ({num_kv_heads})"

        # The default k/v_scale is set to 1.0. This is ignored
        # when kv-cache is not fp8, and should be used with
        # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
        # expect the pre-quantized k/v_scale to be loaded along
        # with the model weights.
        self.kv_cache_dtype = kv_cache_dtype
        self.calculate_kv_scales = calculate_kv_scales
        self._k_scale = torch.tensor(1.0, dtype=torch.float32)
        self._v_scale = torch.tensor(1.0, dtype=torch.float32)
        # FlashAttn doesn't support quantizing the kv-cache only
        # but requires q to be quantized as well.
        self._q_scale = torch.tensor(1.0, dtype=torch.float32)
        self._prob_scale = torch.tensor(1.0, dtype=torch.float32)

        # We also keep q/k/v_scale on host (cpu) memory for attention
        # backends that require the scales to be on host instead of on device.
        # e.g. Flashinfer
        self._q_scale_float = 1.0
        self._k_scale_float = 1.0
        self._v_scale_float = 1.0

        # The output scale on host memory. This should be the input scale of
        # the quant op after this attention layer.
        self._o_scale_float: Optional[float] = None

        self.use_mla = use_mla
        self.use_sparse = use_sparse
        self.num_heads = num_heads
        self.head_size = head_size
        self.num_kv_heads = num_kv_heads
        self.sliding_window = sliding_window
        self.has_sink = extra_impl_args.get("sinks") is not None

        quant_method = quant_config.get_quant_method(
            self, prefix=prefix) if quant_config else None
        if quant_method is not None and not isinstance(
                quant_method, UnquantizedLinearMethod):
            assert isinstance(quant_method, BaseKVCacheMethod)
            # TODO (mgoin): kv cache dtype should be specified in the FP8
            # checkpoint config and become the "auto" behavior
            # NOTE: the check below is disabled in this tree.
            # if self.kv_cache_dtype == "fp8_e5m2":
            #     raise ValueError("fp8_e5m2 kv-cache is not supported with "
            #                      "fp8 checkpoints.")
            # If quantization is enabled, we make "k_scale" and "v_scale"
            # parameters so that it can be loaded from the model checkpoint.
            # The k/v_scale will then be converted back to native float32
            # values after weight loading.
            self.quant_method = quant_method
            self.quant_method.create_weights(self)

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()
        if attn_backend is None:
            self.attn_backend = get_attn_backend(head_size,
                                                 dtype,
                                                 kv_cache_dtype,
                                                 block_size,
                                                 use_mla=use_mla,
                                                 has_sink=self.has_sink,
                                                 use_sparse=use_sparse)
        else:
            self.attn_backend = attn_backend

        impl_cls = self.attn_backend.get_impl_cls()
        self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                             alibi_slopes, sliding_window, kv_cache_dtype,
                             logits_soft_cap, attn_type,
                             kv_sharing_target_layer_name, **extra_impl_args)
        self.backend = backend_name_to_enum(self.attn_backend.get_name())
        self.dtype = dtype

        # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
        # torch.compile works by registering the attention as one giant
        # opaque custom op. For other platforms, we directly call them
        # and let torch.compile handle them.
        self.use_direct_call = not current_platform.opaque_attention_op()

        self.use_output = self.attn_backend.accept_output_buffer
        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self
        self.layer_name = prefix
        self.attn_type = attn_type

        if kv_sharing_target_layer_name is not None:
            validate_kv_sharing_target(
                prefix,
                kv_sharing_target_layer_name,
                compilation_config.static_forward_context,
            )
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        # use a placeholder kv cache tensor during init, which will be replaced
        # by bind_kv_cache. One entry per pipeline-parallel virtual engine; it
        # is indexed by the forward context's `virtual_engine` both in the
        # direct-call path of `forward` and in the unified_attention ops.
        self.kv_cache = [
            torch.tensor([]) for _ in range(get_current_vllm_config(
            ).parallel_config.pipeline_parallel_size)
        ]

        try:
            self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT,
                                        dtype=torch.float32)
            self.k_range = torch.tensor(envs.K_SCALE_CONSTANT,
                                        dtype=torch.float32)
            self.v_range = torch.tensor(envs.V_SCALE_CONSTANT,
                                        dtype=torch.float32)
        except torch.cuda.OutOfMemoryError as e:
            logger.error(
                "Failed to initialize attention q/k/v range constants: %s", e)
            if torch.cuda.is_available():
                logger.debug("CUDA device: %s", torch.cuda.current_device())
                logger.debug("Allocated: %.2f GiB",
                             torch.cuda.memory_allocated() / GiB_bytes)
                logger.debug("Reserved: %.2f GiB",
                             torch.cuda.memory_reserved() / GiB_bytes)
            raise RuntimeError(
                "Failed to initialize q/k/v range constants. "
                "This may be caused by insufficient memory to allocate "
                "kv cache.") from e

        # for attn backends supporting query quantization
        self.query_quant = None
        if self.kv_cache_dtype.startswith(
                "fp8") and self.attn_backend.supports_quant_query_input:
            self.query_quant = QuantFP8(static=True,
                                        group_shape=GroupShape.PER_TENSOR)

    def _get_metadata_and_kv_cache(self):
        """Return ``(attn_metadata, kv_cache)`` for this layer.

        Reads the current forward context. When its ``attn_metadata`` is a
        dict, the entry for ``self.layer_name`` is selected. The KV-cache
        tensor is picked by the context's ``virtual_engine`` index.
        """
        forward_context: ForwardContext = get_forward_context()
        attn_metadata = forward_context.attn_metadata
        if isinstance(attn_metadata, dict):
            attn_metadata = attn_metadata[self.layer_name]
        return attn_metadata, self.kv_cache[forward_context.virtual_engine]

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        # For some alternate attention backends like MLA the attention output
        # shape does not match the query shape, so we optionally let the model
        # definition specify the output tensor shape.
        output_shape: Optional[torch.Size] = None,
    ) -> torch.Tensor:
        """Run attention for one forward pass.

        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        Attention metadata (`attn_metadata`) is set using a context manager in
        the model runner's `execute_model` method. It is accessed via forward
        context using
        `vllm.forward_context.get_forward_context().attn_metadata`.

        Args:
            query, key, value: Input tensors. On the output-buffer path,
                unless ``use_mla`` is set, they are viewed as
                ``(-1, num_heads, head_size)`` (query) and
                ``(-1, num_kv_heads, head_size)`` (key/value).
            output_shape: Shape of the output buffer when the backend
                accepts one. Defaults to the query's shape.

        Returns:
            On the output-buffer path, the output viewed as
            ``(-1, output_shape[-1])``. Otherwise, whatever the backend
            implementation (or the ``unified_attention`` custom op) returns.
        """
        if self.calculate_kv_scales:
            attn_metadata = get_forward_context().attn_metadata
            if attn_metadata.enable_kv_scales_calculation:
                self.calc_kv_scales(query, key, value)

        # Capture the activation dtype BEFORE the optional query
        # quantization below. The output buffer must use this dtype, not the
        # fp8 dtype the query may be converted to.
        output_dtype = query.dtype
        if self.query_quant is not None:
            # quantizing with a simple torch operation enables
            # torch.compile to fuse this into previous ops
            # which reduces overheads during decoding.
            # Otherwise queries are quantized using custom ops
            # which causes decoding overheads
            assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}
            query, _ = self.query_quant(query, self._q_scale)

        if self.use_output:
            output_shape = (output_shape
                            if output_shape is not None else query.shape)
            # VLLM_USE_OPT_ZEROS allocates the buffer uninitialized
            # (torch.empty) instead of zero-filled.
            if envs.VLLM_USE_OPT_ZEROS:
                output = torch.empty(output_shape,
                                     dtype=output_dtype,
                                     device=query.device)
            else:
                output = torch.zeros(output_shape,
                                     dtype=output_dtype,
                                     device=query.device)
            hidden_size = output_shape[-1]
            # We skip reshaping query, key and value tensors for the MLA
            # backend since these tensors have different semantics and are
            # processed differently.
            if not self.use_mla:
                # Reshape the query, key, and value tensors.
                # NOTE(woosuk): We do this outside the custom op to minimize the
                # CPU overheads from the non-CUDA-graph regions.
                query = query.view(-1, self.num_heads, self.head_size)
                output = output.view(-1, self.num_heads, self.head_size)
                if key is not None:
                    key = key.view(-1, self.num_kv_heads, self.head_size)
                if value is not None:
                    value = value.view(-1, self.num_kv_heads, self.head_size)
            if self.use_direct_call:
                attn_metadata, self_kv_cache = (
                    self._get_metadata_and_kv_cache())
                self.impl.forward(self,
                                  query,
                                  key,
                                  value,
                                  self_kv_cache,
                                  attn_metadata,
                                  output=output)
            else:
                torch.ops.vllm.unified_attention_with_output(
                    query, key, value, output, self.layer_name)
            return output.view(-1, hidden_size)

        if self.use_direct_call:
            attn_metadata, self_kv_cache = self._get_metadata_and_kv_cache()
            return self.impl.forward(self, query, key, value, self_kv_cache,
                                     attn_metadata)
        return torch.ops.vllm.unified_attention(query, key, value,
                                                self.layer_name)

    def calc_kv_scales(self, query, key, value):
        """Derive q/k/v scales from the max absolute value of the inputs.

        Each scale is ``max(|x|) / range`` for the matching ``*_range``
        constant. Both the tensor copies and the host-float copies are
        updated. Calculation is then switched off, so it happens once.
        """
        self._q_scale.copy_(torch.abs(query).max() / self.q_range)
        self._k_scale.copy_(torch.abs(key).max() / self.k_range)
        self._v_scale.copy_(torch.abs(value).max() / self.v_range)
        self._q_scale_float = self._q_scale.item()
        self._k_scale_float = self._k_scale.item()
        self._v_scale_float = self._v_scale.item()
        # We only calculate the scales once
        self.calculate_kv_scales = False

    def extra_repr(self) -> str:
        """Summarize the impl's head geometry, scale and class for printing."""
        s = f"head_size={self.impl.head_size}"  # type: ignore
        s += f", num_heads={self.impl.num_heads}"  # type: ignore
        s += f", num_kv_heads={self.impl.num_kv_heads}"  # type: ignore
        s += f", scale={self.impl.scale}"  # type: ignore
        s += f", backend={self.impl.__class__.__name__}"
        return s

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        """Let the impl post-process weights, then fix FlashInfer sink dtype."""
        if hasattr(self.impl, "process_weights_after_loading"):
            self.impl.process_weights_after_loading(act_dtype)

        # FlashInfer requires attention sinks to be float32
        if (self.backend == _Backend.FLASHINFER
                and hasattr(self.impl, 'sinks')):
            from vllm.v1.attention.backends.flashinfer import FlashInferImpl
            assert isinstance(self.impl, FlashInferImpl)
            if (self.impl.sinks is not None
                    and self.impl.sinks.dtype != torch.float32):
                self.impl.sinks = self.impl.sinks.to(torch.float32)

    def get_attn_backend(self) -> type[AttentionBackend]:
        """Return the attention backend class selected for this layer."""
        return self.attn_backend

386

387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
class MultiHeadAttention(nn.Module):
    """Multi-headed attention without any cache, used for ViT.

    Key/value heads are repeated to match the query heads (MQA/GQA). The
    attention kernel is chosen once at construction time from
    ``get_vit_attn_backend`` and the current platform.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        if num_kv_heads is None:
            num_kv_heads = num_heads
        self.num_kv_heads = num_kv_heads

        assert self.num_heads % self.num_kv_heads == 0, (
            f"num_heads ({self.num_heads}) is not "
            f"divisible by num_kv_heads ({self.num_kv_heads})")
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        # While the model is being built, the default dtype is the model's
        # weight/activation dtype.
        model_dtype = torch.get_default_dtype()

        # Start from the generic ViT choice. If vLLM's own flash attention was
        # not picked, try upgrading to the upstream flash-attn package.
        selected = get_vit_attn_backend(head_size=head_size, dtype=model_dtype)
        use_upstream_fa = bool(selected != _Backend.FLASH_ATTN
                               and check_upstream_fa_availability(model_dtype))
        if use_upstream_fa:
            selected = _Backend.FLASH_ATTN

        if current_platform.is_xpu():
            # Only torch SDPA is used for ViT attention on XPU.
            self.attn_backend = _Backend.TORCH_SDPA
        else:
            if current_platform.is_rocm():
                allowed = {_Backend.FLASH_ATTN}
            else:
                allowed = {
                    _Backend.TORCH_SDPA,
                    _Backend.XFORMERS,
                    _Backend.PALLAS,
                    _Backend.ROCM_AITER_FA,
                    _Backend.FLASH_ATTN,
                }
            self.attn_backend = (selected
                                 if selected in allowed else
                                 _Backend.TORCH_SDPA)

        # xformers may have been chosen but be unusable here; use SDPA then.
        if (self.attn_backend == _Backend.XFORMERS
                and not check_xformers_availability()):
            self.attn_backend = _Backend.TORCH_SDPA

        if self.attn_backend == _Backend.FLASH_ATTN:
            # The upstream flash-attn package serves the upstream upgrade and
            # ROCm; every other case uses vLLM's bundled build.
            if use_upstream_fa or current_platform.is_rocm():
                from flash_attn import flash_attn_varlen_func
            else:
                from vllm.vllm_flash_attn import flash_attn_varlen_func
            self._flash_attn_varlen_func = flash_attn_varlen_func

        logger.info_once(
            f"MultiHeadAttention attn_backend: {self.attn_backend}, "
            f"use_upstream_fa: {use_upstream_fa}")

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """Attend ``query`` over ``key``/``value`` with the chosen backend.

        Input shape:
        (batch_size x seq_len x hidden_size) or
        (batch_size x seq_len x num_heads x head_size)

        The result is reshaped to ``(batch_size, q_len, -1)``.
        """
        batch, q_len = query.size()[:2]
        kv_len = key.size(1)

        q = query.view(batch, q_len, self.num_heads, self.head_size)
        k = key.view(batch, kv_len, self.num_kv_heads, self.head_size)
        v = value.view(batch, kv_len, self.num_kv_heads, self.head_size)

        repeats = self.num_queries_per_kv
        if repeats > 1:
            # MQA / GQA: duplicate each K/V head so every query head has one.
            k = torch.repeat_interleave(k, repeats, dim=2)
            v = torch.repeat_interleave(v, repeats, dim=2)

        backend = self.attn_backend
        if backend == _Backend.FLASH_ATTN:
            # All sequences share one length, so the cumulative offsets are
            # plain arithmetic progressions.
            cu_seqlens_q = torch.arange(0, (batch + 1) * q_len,
                                        step=q_len,
                                        dtype=torch.int32,
                                        device=q.device)
            cu_seqlens_k = torch.arange(0, (batch + 1) * kv_len,
                                        step=kv_len,
                                        dtype=torch.int32,
                                        device=k.device)
            out = self._flash_attn_varlen_func(
                q.flatten(0, 1),
                k.flatten(0, 1),
                v.flatten(0, 1),
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=q_len,
                max_seqlen_k=kv_len,
                softmax_scale=self.scale,
            )
        elif backend == _Backend.XFORMERS:
            from xformers import ops as xops

            out = xops.memory_efficient_attention_forward(q,
                                                          k,
                                                          v,
                                                          scale=self.scale)
        elif backend in (_Backend.TORCH_SDPA, _Backend.PALLAS):
            # Both kernels run on the (batch, heads, seq, head_size) layout.
            q_t, k_t, v_t = (t.transpose(1, 2) for t in (q, k, v))
            if backend == _Backend.TORCH_SDPA:
                out = F.scaled_dot_product_attention(q_t,
                                                     k_t,
                                                     v_t,
                                                     scale=self.scale)
            else:
                from torch_xla.experimental.custom_kernel import (
                    flash_attention)
                out = flash_attention(q_t, k_t, v_t, sm_scale=self.scale)
            out = out.transpose(1, 2)
        elif backend == _Backend.ROCM_AITER_FA:
            from aiter import flash_attn_varlen_func

            # ROCm Flash Attention expects (batch, seq, heads, head_dim)
            out = flash_attn_varlen_func(q, k, v, softmax_scale=self.scale)
        else:
            # ViT attention hasn't supported this backend yet
            raise NotImplementedError(
                f"ViT attention hasn't supported {self.attn_backend} "
                f"backend yet.")

        return out.reshape(batch, q_len, -1)
538
539


540
541
542
543
544
545
546
547
548
549
def wait_for_kv_layer_from_connector(layer_name: str):
    """Ask the v1 KV connector to wait for ``layer_name``'s KV load.

    This is a no-op when no v1 KV-transfer group is configured, or when the
    current forward context carries no attention metadata.
    """
    if not (has_kv_transfer_group() and is_v1_kv_transfer_group()):
        return

    connector = get_kv_transfer_group()
    metadata = get_forward_context().attn_metadata
    if metadata is None:
        return

    # The v1 connector path expects per-layer metadata keyed by layer name.
    assert isinstance(metadata, dict)
    connector.wait_for_layer_load(layer_name)


def maybe_save_kv_layer_to_connector(
    layer_name: str,
    kv_cache_layer: torch.Tensor,
):
    """Hand one layer's KV cache to the v1 KV connector for saving.

    Args:
        layer_name: Name of the attention layer. It is also the key into the
            per-layer attention-metadata dict.
        kv_cache_layer: The layer's KV-cache tensor for the current virtual
            engine. Callers pass ``layer.kv_cache[virtual_engine]``, a single
            tensor, so it is annotated as such.

    This is a no-op when no v1 KV-transfer group is configured, or when the
    current forward context carries no attention metadata.
    """
    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
        return

    connector = get_kv_transfer_group()

    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if attn_metadata is None:
        return
    # The v1 connector path expects per-layer metadata keyed by layer name.
    assert isinstance(attn_metadata, dict)
    connector.save_kv_layer(layer_name, kv_cache_layer,
                            attn_metadata[layer_name])
570
571


572
573
574
575
576
577
def unified_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Body of the ``vllm::unified_attention`` custom op.

    Steps:
    1. Wait on the KV connector for this layer.
    2. Resolve the layer object and its metadata from the forward context.
    3. Run the layer's backend impl and keep its result.
    4. Offer the layer's KV cache to the connector for saving.
    5. Return the impl's result.
    """
    wait_for_kv_layer_from_connector(layer_name)

    ctx = get_forward_context()
    metadata = ctx.attn_metadata
    if isinstance(metadata, dict):
        metadata = metadata[layer_name]
    layer = ctx.no_compile_layers[layer_name]
    kv_cache = layer.kv_cache[ctx.virtual_engine]

    result = layer.impl.forward(layer, query, key, value, kv_cache, metadata)

    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
    return result
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605


def unified_attention_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Fake impl of ``unified_attention``.

    Returns an uninitialized contiguous tensor with the query's shape, dtype
    and device. ``key``, ``value`` and ``layer_name`` are unused.
    """
    placeholder = torch.empty_like(query)
    return placeholder.contiguous()


# Expose `unified_attention` as `torch.ops.vllm.unified_attention` (called
# from Attention.forward). The op carries the cudagraph-unsafe tag when the
# running torch build provides it.
direct_register_custom_op(
    op_func=unified_attention,
    fake_impl=unified_attention_fake,
    op_name="unified_attention",
    tags=tag_cudagraph_unsafe,
)
608
609
610
611
612
613
614
615


def unified_attention_with_output(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: Optional[torch.Tensor] = None,
    output_block_scale: Optional[torch.Tensor] = None,
) -> None:
    """Body of the ``vllm::unified_attention_with_output`` custom op.

    Same flow as ``unified_attention``, except the backend impl writes its
    result into the caller-provided ``output`` buffer. ``output_scale`` and
    ``output_block_scale`` are forwarded to the impl unchanged. Nothing is
    returned.
    """
    wait_for_kv_layer_from_connector(layer_name)

    ctx = get_forward_context()
    metadata = ctx.attn_metadata
    if isinstance(metadata, dict):
        metadata = metadata[layer_name]
    layer = ctx.no_compile_layers[layer_name]
    kv_cache = layer.kv_cache[ctx.virtual_engine]

    layer.impl.forward(
        layer,
        query,
        key,
        value,
        kv_cache,
        metadata,
        output=output,
        output_scale=output_scale,
        output_block_scale=output_block_scale,
    )

    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
637

638
639
640
641
642
643
644

def unified_attention_with_output_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: Optional[torch.Tensor] = None,
    output_block_scale: Optional[torch.Tensor] = None,
) -> None:
    """Fake impl of ``unified_attention_with_output``: does nothing.

    The real op returns nothing and only writes into its mutated arguments,
    so there is no result to fabricate here.
    """


# Expose `unified_attention_with_output` as
# `torch.ops.vllm.unified_attention_with_output` (called from
# Attention.forward on the output-buffer path). `output` and
# `output_block_scale` are declared as mutated arguments.
direct_register_custom_op(
    op_func=unified_attention_with_output,
    fake_impl=unified_attention_with_output_fake,
    op_name="unified_attention_with_output",
    mutates_args=["output", "output_block_scale"],
    tags=tag_cudagraph_unsafe,
)