layer.py 28.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Attention layer."""
4
from typing import List, Optional
5
6
7

import torch
import torch.nn as nn
8
import torch.nn.functional as F
9

10
import vllm.envs as envs
11
from vllm.attention import AttentionType
12
from vllm.attention.backends.abstract import AttentionBackend
13
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
14
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
15
from vllm.config import CacheConfig, get_current_vllm_config
16
17
18
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                          has_kv_transfer_group,
                                          is_v1_kv_transfer_group)
19
from vllm.forward_context import ForwardContext, get_forward_context
20
from vllm.logger import init_logger
21
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
22
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
23
24
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
25
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
26
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
27
28
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape)
29
from vllm.model_executor.models.vision import get_vit_attn_backend
30
from vllm.platforms import _Backend, current_platform
31
from vllm.utils import GiB_bytes, direct_register_custom_op
32

33
34
# Memoized result of check_xformers_availability(); None means "not probed".
USE_XFORMERS_OPS = None
logger = init_logger(__name__)

# `torch._C.Tag.cudagraph_unsafe` only exists in newer PyTorch releases; on
# older ones the custom ops registered in this module get no extra tags.
try:
    tag_cudagraph_unsafe = (torch._C.Tag.cudagraph_unsafe, )
except AttributeError:
    tag_cudagraph_unsafe = ()  # type: ignore[assignment]
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63


def check_xformers_availability():
    """Return whether the xformers ops can be used on this platform.

    The answer is memoized in the module-level ``USE_XFORMERS_OPS`` so the
    probe (and the fallback warning) happens only once per process.
    """
    global USE_XFORMERS_OPS
    if USE_XFORMERS_OPS is not None:
        return USE_XFORMERS_OPS

    if current_platform.is_cuda() and current_platform.has_device_capability(
            100):
        # Xformers FA is not compatible with B200
        USE_XFORMERS_OPS = False
    else:
        from importlib.util import find_spec

        try:
            # find_spec() returns None (it does not raise) when the
            # `xformers.ops` submodule is missing, so the result must be
            # checked. It raises ModuleNotFoundError (an ImportError) when
            # the parent `xformers` package is not installed, and ValueError
            # when the module is present but has no __spec__.
            USE_XFORMERS_OPS = find_spec("xformers.ops") is not None
        except (ImportError, ValueError):
            USE_XFORMERS_OPS = False

    # the warning only needs to be shown once
    if not USE_XFORMERS_OPS:
        logger.warning("Xformers is not available, falling back.")

    return USE_XFORMERS_OPS
64
65


66
67
68
69
70
71
72
73
def check_upstream_fa_availability(dtype: torch.dtype):
    """Report whether upstream flash-attn 2 (via transformers) is usable.

    Only fp16/bf16 on CUDA devices with compute capability >= 8.0 qualify;
    anything else short-circuits to ``False`` before transformers is touched.
    """
    eligible = (dtype in (torch.float16, torch.bfloat16)
                and current_platform.is_cuda()
                and current_platform.has_device_capability(80))
    if not eligible:
        return False
    from transformers.utils import is_flash_attn_2_available
    return is_flash_attn_2_available()


74
class Attention(nn.Module, AttentionLayerBase):
    """Attention layer.

    This class takes query, key, and value tensors as input. The input tensors
    can either contain prompt tokens or generation tokens.
    The class does the following:

    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        logits_soft_cap: Optional[float] = None,
        per_layer_sliding_window: Optional[int] = None,
        use_mla: bool = False,
        use_sparse: bool = False,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        attn_backend: Optional[type[AttentionBackend]] = None,
        **extra_impl_args,
    ) -> None:
        """
        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.
        """
        super().__init__()
        if per_layer_sliding_window is not None:
            # per-layer sliding window
            sliding_window = per_layer_sliding_window
        elif cache_config is not None:
            # model-level sliding window
            sliding_window = cache_config.sliding_window
        else:
            sliding_window = None

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            calculate_kv_scales = cache_config.calculate_kv_scales
        else:
            kv_cache_dtype = "auto"
            # Without a cache config, use 64-token blocks only when both
            # VLLM_USE_FLASH_ATTN_PA and VLLM_USE_FLASH_MLA are enabled.
            block_size = (64 if envs.VLLM_USE_FLASH_ATTN_PA
                          and envs.VLLM_USE_FLASH_MLA else 16)
            calculate_kv_scales = False

        if num_kv_heads is None:
            num_kv_heads = num_heads
        assert num_heads % num_kv_heads == 0, \
            f"num_heads ({num_heads}) is not " \
            f"divisible by num_kv_heads ({num_kv_heads})"

        # The default k/v_scale is set to 1.0. This is ignored
        # when kv-cache is not fp8, and should be used with
        # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
        # expect the pre-quantized k/v_scale to be loaded along
        # with the model weights.
        self.kv_cache_dtype = kv_cache_dtype
        self.calculate_kv_scales = calculate_kv_scales
        # For fp8 KV caches, forward() keeps asking calc_kv_scales() to
        # inspect the query range until scales have been computed once.
        self.check_fp8_overflow = self.kv_cache_dtype in {
            "fp8", "fp8_e4m3", "fp8_e5m2"
        }
        self._k_scale = torch.tensor(1.0, dtype=torch.float32)
        self._v_scale = torch.tensor(1.0, dtype=torch.float32)
        # FlashAttn doesn't support quantizing the kv-cache only
        # but requires q to be quantized as well.
        self._q_scale = torch.tensor(1.0, dtype=torch.float32)
        self._prob_scale = torch.tensor(1.0, dtype=torch.float32)

        # We also keep q/k/v_scale on host (cpu) memory for attention
        # backends that require the scales to be on host instead of on device.
        # e.g. Flashinfer
        self._q_scale_float = 1.0
        self._k_scale_float = 1.0
        self._v_scale_float = 1.0

        # The output scale on host memory. This should be the input scale of
        # the quant op after this attention layer.
        self._o_scale_float: Optional[float] = None

        self.use_mla = use_mla
        self.use_sparse = use_sparse
        self.num_heads = num_heads
        self.head_size = head_size
        self.num_kv_heads = num_kv_heads
        self.sliding_window = sliding_window
        self.has_sink = extra_impl_args.get("sinks") is not None

        quant_method = quant_config.get_quant_method(
            self, prefix=prefix) if quant_config else None
        if quant_method is not None and not isinstance(
                quant_method, UnquantizedLinearMethod):
            assert isinstance(quant_method, BaseKVCacheMethod)
            # TODO (mgoin): kv cache dtype should be specified in the FP8
            # checkpoint config and become the "auto" behavior
            # NOTE: this fork intentionally does not reject fp8_e5m2 KV
            # caches combined with fp8 checkpoints.

            # If quantization is enabled, we make "k_scale" and "v_scale"
            # parameters so that it can be loaded from the model checkpoint.
            # The k/v_scale will then be converted back to native float32
            # values after weight loading.
            self.quant_method = quant_method
            self.quant_method.create_weights(self)

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()
        if attn_backend is None:
            self.attn_backend = get_attn_backend(head_size,
                                                 dtype,
                                                 kv_cache_dtype,
                                                 block_size,
                                                 use_mla=use_mla,
                                                 has_sink=self.has_sink,
                                                 use_sparse=use_sparse)
        else:
            self.attn_backend = attn_backend

        impl_cls = self.attn_backend.get_impl_cls()
        self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                             alibi_slopes, sliding_window, kv_cache_dtype,
                             logits_soft_cap, attn_type,
                             kv_sharing_target_layer_name, **extra_impl_args)
        self.backend = backend_name_to_enum(self.attn_backend.get_name())
        self.dtype = dtype

        # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
        # torch.compile works by registering the attention as one giant
        # opaque custom op. For other platforms, we directly call them
        # and let torch.compile handle them.
        self.use_direct_call = not current_platform.opaque_attention_op()

        self.use_output = self.attn_backend.accept_output_buffer
        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self
        self.layer_name = prefix
        self.attn_type = attn_type

        if kv_sharing_target_layer_name is not None:
            validate_kv_sharing_target(
                prefix,
                kv_sharing_target_layer_name,
                compilation_config.static_forward_context,
            )
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        # use a placeholder kv cache tensor during init, which will be replaced
        # by bind_kv_cache
        # this variable will not be accessed if use_direct_call is True
        self.kv_cache = [
            torch.tensor([]) for _ in range(get_current_vllm_config(
            ).parallel_config.pipeline_parallel_size)
        ]

        try:
            self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT,
                                        dtype=torch.float32)
            self.k_range = torch.tensor(envs.K_SCALE_CONSTANT,
                                        dtype=torch.float32)
            self.v_range = torch.tensor(envs.V_SCALE_CONSTANT,
                                        dtype=torch.float32)
        except torch.cuda.OutOfMemoryError as e:
            logger.error(
                "Failed to initialize attention q/k/v range constants: %s", e)
            if torch.cuda.is_available():
                logger.debug("CUDA device: %s", torch.cuda.current_device())
                logger.debug("Allocated: %.2f GiB",
                             torch.cuda.memory_allocated() / GiB_bytes)
                logger.debug("Reserved: %.2f GiB",
                             torch.cuda.memory_reserved() / GiB_bytes)
            raise RuntimeError(
                "Failed to initialize q/k/v range constants. "
                "This may be caused by insufficient memory to allocate "
                "kv cache.") from e

        # for attn backends supporting query quantization
        self.query_quant = None
        # Query quantization is opt-in via VLLM_USE_QUERY_QUANT.
        # NOTE(review): "fp8_e5m2" also passes startswith("fp8") here, but
        # forward() asserts the dtype is "fp8"/"fp8_e4m3" — confirm that no
        # backend with supports_quant_query_input accepts fp8_e5m2.
        if envs.VLLM_USE_QUERY_QUANT:
            if self.kv_cache_dtype.startswith(
                    "fp8") and self.attn_backend.supports_quant_query_input:
                self.query_quant = QuantFP8(static=True,
                                            group_shape=GroupShape.PER_TENSOR)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        # For some alternate attention backends like MLA the attention output
        # shape does not match the query shape, so we optionally let the model
        # definition specify the output tensor shape.
        output_shape: Optional[torch.Size] = None,
    ) -> torch.Tensor:
        """
        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        Attention metadata (`attn_metadata`) is set using a context manager in
        the model runner's `execute_model` method. It is accessed via forward
        context using
        `vllm.forward_context.get_forward_context().attn_metadata`.
        """
        if self.calculate_kv_scales or self.check_fp8_overflow:
            # Routed through the registered custom op; the op itself re-checks
            # the flags at run time (see maybe_calc_kv_scales).
            torch.ops.vllm.maybe_calc_kv_scales(query, key, value,
                                                self.layer_name)

        output_dtype = query.dtype
        if self.query_quant is not None:
            # quantizing with a simple torch operation enables
            # torch.compile to fuse this into previous ops
            # which reduces overheads during decoding.
            # Otherwise queries are quantized using custom ops
            # which causes decoding overheads
            assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}
            query, _ = self.query_quant(query, self._q_scale)

        if self.use_output:
            output_shape = (output_shape
                            if output_shape is not None else query.shape)
            # VLLM_USE_OPT_ZEROS skips zero-filling the output buffer.
            if envs.VLLM_USE_OPT_ZEROS:
                output = torch.empty(output_shape,
                                     dtype=output_dtype,
                                     device=query.device)
            else:
                output = torch.zeros(output_shape,
                                     dtype=output_dtype,
                                     device=query.device)
            hidden_size = output_shape[-1]
            # We skip reshaping query, key and value tensors for the MLA
            # backend since these tensors have different semantics and are
            # processed differently.
            if not self.use_mla:
                # Reshape the query, key, and value tensors.
                # NOTE(woosuk): We do this outside the custom op to minimize the
                # CPU overheads from the non-CUDA-graph regions.
                query = query.view(-1, self.num_heads, self.head_size)
                output = output.view(-1, self.num_heads, self.head_size)
                if key is not None:
                    key = key.view(-1, self.num_kv_heads, self.head_size)
                if value is not None:
                    value = value.view(-1, self.num_kv_heads, self.head_size)
            if self.use_direct_call:
                forward_context: ForwardContext = get_forward_context()
                attn_metadata = forward_context.attn_metadata
                if isinstance(attn_metadata, dict):
                    attn_metadata = attn_metadata[self.layer_name]
                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                self.impl.forward(self,
                                  query,
                                  key,
                                  value,
                                  self_kv_cache,
                                  attn_metadata,
                                  output=output)
            else:
                torch.ops.vllm.unified_attention_with_output(
                    query, key, value, output, self.layer_name)
            return output.view(-1, hidden_size)
        else:
            if self.use_direct_call:
                forward_context = get_forward_context()
                attn_metadata = forward_context.attn_metadata
                if isinstance(attn_metadata, dict):
                    attn_metadata = attn_metadata[self.layer_name]
                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                return self.impl.forward(self, query, key, value,
                                         self_kv_cache, attn_metadata)
            else:
                return torch.ops.vllm.unified_attention(
                    query, key, value, self.layer_name)

    def calc_kv_scales(self, query, key, value):
        """Compute per-tensor q/k/v scales from the current activations.

        Scales are computed when ``calculate_kv_scales`` is set. Otherwise,
        while ``check_fp8_overflow`` is set, they are computed only when the
        query range looks unsafe:

        * ``fp8`` / ``fp8_e4m3``: ``max|q| > 200`` (overflow risk);
        * any other dtype: ``max|q| < 0.01`` (too small).

        Each scale is ``max|x| / x_range + bias``. ``bias`` is non-zero only
        for very small queries. After computing scales, both flags are
        cleared, so this runs at most once per layer.
        """
        if not (self.calculate_kv_scales or self.check_fp8_overflow):
            # Scales are already final. A caller holding a stale check (for
            # example a traced graph) may still call us; do nothing.
            return

        # One device->host sync instead of one per comparison below.
        q_abs_max = torch.abs(query).max()
        q_abs_max_val = q_abs_max.item()

        if not self.calculate_kv_scales:
            if self.kv_cache_dtype in {"fp8", "fp8_e4m3"}:
                # check fp8 overflow: only a large query range triggers a
                # recomputation. (Previously this path fell through to the
                # "too small" check and always returned, so overflow never
                # triggered a recomputation.)
                if q_abs_max_val <= 200:
                    return
            elif q_abs_max_val >= 0.01:
                # check fp8 too small
                return

        # add bias to avoid q values are too small (or zeros) and scales are
        # not correct
        bias = 0.0
        if q_abs_max_val < 0.01:
            bias = 0.1 if self.kv_cache_dtype == "fp8_e5m2" else 1.0
        self._q_scale.copy_(q_abs_max / self.q_range + bias)
        self._k_scale.copy_(torch.abs(key).max() / self.k_range + bias)
        self._v_scale.copy_(torch.abs(value).max() / self.v_range + bias)

        self._q_scale_float = self._q_scale.item()
        self._k_scale_float = self._k_scale.item()
        self._v_scale_float = self._v_scale.item()
        # We only calculate the scales once
        self.calculate_kv_scales = False
        self.check_fp8_overflow = False

    def extra_repr(self) -> str:
        s = f"head_size={self.impl.head_size}"  # type: ignore
        s += f", num_heads={self.impl.num_heads}"  # type: ignore
        s += f", num_kv_heads={self.impl.num_kv_heads}"  # type: ignore
        s += f", scale={self.impl.scale}"  # type: ignore
        s += f", backend={self.impl.__class__.__name__}"
        return s

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        if hasattr(self.impl, "process_weights_after_loading"):
            self.impl.process_weights_after_loading(act_dtype)

        # FlashInfer requires attention sinks to be float32
        if (self.backend == _Backend.FLASHINFER
                and hasattr(self.impl, 'sinks')):
            from vllm.v1.attention.backends.flashinfer import FlashInferImpl
            assert isinstance(self.impl, FlashInferImpl)
            if (self.impl.sinks is not None
                    and self.impl.sinks.dtype != torch.float32):
                self.impl.sinks = self.impl.sinks.to(torch.float32)

    def get_attn_backend(self) -> type[AttentionBackend]:
        return self.attn_backend

407

408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
class MultiHeadAttention(nn.Module):
    """Multi-headed attention without any cache, used for ViT."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = (num_kv_heads
                             if num_kv_heads is not None else num_heads)

        assert self.num_heads % self.num_kv_heads == 0, (
            f"num_heads ({self.num_heads}) is not "
            f"divisible by num_kv_heads ({self.num_kv_heads})")
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        # The default dtype during model construction is the model's
        # weight/activation dtype.
        dtype = torch.get_default_dtype()

        # Start from the platform's ViT choice; a non-native-FA choice is
        # promoted to upstream flash-attn when that package is usable.
        backend = get_vit_attn_backend(head_size=head_size, dtype=dtype)
        use_upstream_fa = bool(backend != _Backend.FLASH_ATTN
                               and check_upstream_fa_availability(dtype))
        if use_upstream_fa:
            backend = _Backend.FLASH_ATTN

        # Clamp the choice to what each platform implements; everything else
        # degrades to torch SDPA.
        if current_platform.is_xpu():
            # currently, only torch_sdpa is supported on rocm/xpu
            chosen = _Backend.TORCH_SDPA
        elif current_platform.is_rocm():
            chosen = (backend
                      if backend == _Backend.FLASH_ATTN else _Backend.TORCH_SDPA)
        else:
            allowed = {
                _Backend.TORCH_SDPA,
                _Backend.XFORMERS,
                _Backend.PALLAS,
                _Backend.ROCM_AITER_FA,
                _Backend.FLASH_ATTN,
            }
            chosen = backend if backend in allowed else _Backend.TORCH_SDPA
        self.attn_backend = chosen

        if (self.attn_backend == _Backend.XFORMERS
                and not check_xformers_availability()):
            self.attn_backend = _Backend.TORCH_SDPA

        if self.attn_backend == _Backend.FLASH_ATTN:
            # Upstream flash-attn is used when explicitly selected above and
            # always on ROCm; otherwise vLLM's bundled build is used.
            if use_upstream_fa or current_platform.is_rocm():
                from flash_attn import flash_attn_varlen_func
            else:
                from vllm.vllm_flash_attn import flash_attn_varlen_func
            self._flash_attn_varlen_func = flash_attn_varlen_func

        logger.info_once(
            f"MultiHeadAttention attn_backend: {self.attn_backend}, "
            f"use_upstream_fa: {use_upstream_fa}")

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """Input shape: 
        (batch_size x seq_len x hidden_size) or
        (batch_size x seq_len x num_heads x head_size)
        """
        batch, q_len = query.size()[:2]
        kv_len = key.size(1)

        query = query.view(batch, q_len, self.num_heads, self.head_size)
        key = key.view(batch, kv_len, self.num_kv_heads, self.head_size)
        value = value.view(batch, kv_len, self.num_kv_heads, self.head_size)

        # MQA / GQA: expand the kv heads so every query head has a partner.
        repeats = self.num_queries_per_kv
        if repeats > 1:
            key = key.repeat_interleave(repeats, dim=2)
            value = value.repeat_interleave(repeats, dim=2)

        backend = self.attn_backend
        if backend == _Backend.FLASH_ATTN:
            # Every sequence has the same length, so the cumulative offsets
            # form a simple arithmetic progression.
            cu_seqlens_q = torch.arange(0, (batch + 1) * q_len,
                                        step=q_len,
                                        dtype=torch.int32,
                                        device=query.device)
            cu_seqlens_k = torch.arange(0, (batch + 1) * kv_len,
                                        step=kv_len,
                                        dtype=torch.int32,
                                        device=key.device)
            out = self._flash_attn_varlen_func(
                query.flatten(0, 1),
                key.flatten(0, 1),
                value.flatten(0, 1),
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=q_len,
                max_seqlen_k=kv_len,
                softmax_scale=self.scale,
            )
        elif backend == _Backend.XFORMERS:
            from xformers import ops as xops

            out = xops.memory_efficient_attention_forward(query,
                                                          key,
                                                          value,
                                                          scale=self.scale)
        elif backend in (_Backend.TORCH_SDPA, _Backend.PALLAS):
            # Both kernels want (batch, heads, seq, head_dim).
            q_t, k_t, v_t = (t.transpose(1, 2) for t in (query, key, value))
            if backend == _Backend.TORCH_SDPA:
                out = F.scaled_dot_product_attention(q_t,
                                                     k_t,
                                                     v_t,
                                                     scale=self.scale)
            else:
                from torch_xla.experimental.custom_kernel import (
                    flash_attention)
                out = flash_attention(q_t, k_t, v_t, sm_scale=self.scale)
            out = out.transpose(1, 2)
        elif backend == _Backend.ROCM_AITER_FA:
            from aiter import flash_attn_varlen_func

            # ROCm Flash Attention expects (batch, seq, heads, head_dim)
            # NOTE(review): no cu_seqlens/max_seqlen are passed to this varlen
            # entry point; confirm against the aiter signature.
            out = flash_attn_varlen_func(query,
                                         key,
                                         value,
                                         softmax_scale=self.scale)
        else:
            # ViT attention hasn't supported this backend yet
            raise NotImplementedError(
                f"ViT attention hasn't supported {self.attn_backend} "
                f"backend yet.")

        return out.reshape(batch, q_len, -1)
559
560


561
562
563
564
565
566
567
568
569
570
def wait_for_kv_layer_from_connector(layer_name: str):
    """Block until the v1 KV connector has loaded the KV for ``layer_name``.

    This is a no-op when no v1 KV-transfer group is configured. It is also a
    no-op when the current forward context carries no attention metadata.
    """
    if not (has_kv_transfer_group() and is_v1_kv_transfer_group()):
        return

    kv_connector = get_kv_transfer_group()

    ctx: ForwardContext = get_forward_context()
    metadata = ctx.attn_metadata
    if metadata is None:
        return
    # With a connector active, metadata is expected to be a per-layer dict.
    assert isinstance(metadata, dict)
    kv_connector.wait_for_layer_load(layer_name)


def maybe_save_kv_layer_to_connector(
    layer_name: str,
    kv_cache_layer: List[torch.Tensor],
):
    """Pass ``layer_name``'s KV cache to the v1 KV connector for saving.

    This is a no-op without a v1 KV-transfer group, or when the forward
    context has no attention metadata. Otherwise the layer's own metadata
    entry is forwarded along with the cache.
    """
    if not (has_kv_transfer_group() and is_v1_kv_transfer_group()):
        return

    kv_connector = get_kv_transfer_group()

    ctx: ForwardContext = get_forward_context()
    metadata = ctx.attn_metadata
    if metadata is None:
        return
    assert isinstance(metadata, dict)
    layer_metadata = metadata[layer_name]
    kv_connector.save_kv_layer(layer_name, kv_cache_layer, layer_metadata)
591

xiabo's avatar
xiabo committed
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
def maybe_calc_kv_scales(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> None:
    """Eager body of ``torch.ops.vllm.maybe_calc_kv_scales``.

    Looks up the layer registered as ``layer_name`` in the current forward
    context and asks it to (re)compute its q/k/v scales. It does this only
    while the layer's ``calculate_kv_scales`` or ``check_fp8_overflow`` flag
    is still set.

    The flags are re-checked here, at run time, because the caller's Python
    check may have been evaluated only once while tracing (presumably under
    torch.compile). Without this guard, the op could keep recomputing scales
    after the layer finalized them.
    """
    forward_context: ForwardContext = get_forward_context()
    layer = forward_context.no_compile_layers[layer_name]

    # Both flags are cleared by calc_kv_scales() once scales are computed.
    # check_fp8_overflow is read defensively in case a layer without that
    # attribute ever invokes this op.
    if not (layer.calculate_kv_scales
            or getattr(layer, "check_fp8_overflow", False)):
        return

    layer.calc_kv_scales(query, key, value)


def maybe_calc_kv_scales_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> None:
    """Fake (meta) implementation: the op produces no output tensor."""
    return None


# Expose maybe_calc_kv_scales as torch.ops.vllm.maybe_calc_kv_scales.
# NOTE(review): query/key/value are declared as mutated even though the op
# only reads them; presumably this keeps the call from being treated as
# side-effect free during compilation — confirm before changing.
direct_register_custom_op(
    op_name="maybe_calc_kv_scales",
    fake_impl=maybe_calc_kv_scales_fake,
    mutates_args=["query", "key", "value"],
    op_func=maybe_calc_kv_scales,
)
622

623
624
625
626
627
628
def unified_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Run attention for ``layer_name`` and return a freshly produced output.

    Waits for any pending KV-connector load first. It resolves the layer, its
    metadata and its KV cache from the forward context, calls the backend
    implementation, and then hands the KV cache to the connector for saving.
    """
    wait_for_kv_layer_from_connector(layer_name)

    ctx: ForwardContext = get_forward_context()
    metadata = ctx.attn_metadata
    if isinstance(metadata, dict):
        metadata = metadata[layer_name]
    layer = ctx.no_compile_layers[layer_name]
    kv_cache = layer.kv_cache[ctx.virtual_engine]

    result = layer.impl.forward(layer, query, key, value, kv_cache, metadata)

    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
    return result
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656


def unified_attention_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """Fake implementation: a contiguous uninitialized tensor shaped like
    ``query``."""
    placeholder = torch.empty_like(query)
    return placeholder.contiguous()


# Expose unified_attention as torch.ops.vllm.unified_attention, tagged
# cudagraph-unsafe when the running PyTorch supports that tag.
direct_register_custom_op(
    op_name="unified_attention",
    tags=tag_cudagraph_unsafe,
    fake_impl=unified_attention_fake,
    op_func=unified_attention,
)
659
660
661
662
663
664
665
666


def unified_attention_with_output(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: Optional[torch.Tensor] = None,
    output_block_scale: Optional[torch.Tensor] = None,
) -> None:
    """Run attention for ``layer_name`` writing into the caller's ``output``.

    Like ``unified_attention``, but the backend writes into the preallocated
    ``output`` buffer. The optional output scales are forwarded to the
    backend.
    """
    wait_for_kv_layer_from_connector(layer_name)

    ctx: ForwardContext = get_forward_context()
    metadata = ctx.attn_metadata
    if isinstance(metadata, dict):
        metadata = metadata[layer_name]
    layer = ctx.no_compile_layers[layer_name]
    kv_cache = layer.kv_cache[ctx.virtual_engine]

    layer.impl.forward(
        layer,
        query,
        key,
        value,
        kv_cache,
        metadata,
        output=output,
        output_scale=output_scale,
        output_block_scale=output_block_scale,
    )

    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
688

689
690
691
692
693
694
695

def unified_attention_with_output_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: Optional[torch.Tensor] = None,
    output_block_scale: Optional[torch.Tensor] = None,
) -> None:
    """Fake implementation: results go into ``output`` in place, so there is
    nothing to return."""
    return None


# Expose unified_attention_with_output as
# torch.ops.vllm.unified_attention_with_output; it writes into `output` (and
# optionally `output_block_scale`), which are declared as mutated arguments.
direct_register_custom_op(
    op_name="unified_attention_with_output",
    tags=tag_cudagraph_unsafe,
    mutates_args=["output", "output_block_scale"],
    fake_impl=unified_attention_with_output_fake,
    op_func=unified_attention_with_output,
)