flash_attn.py 33.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
"""Attention layer with FlashAttention."""
from dataclasses import dataclass
5
from typing import Optional
6

7
import numpy as np
8
9
import torch

10
from vllm import _custom_ops as ops
11
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
12
13
                                              AttentionMetadata, AttentionType,
                                              is_quantized_kv_cache)
14
from vllm.attention.layer import Attention
15
from vllm.attention.ops.merge_attn_states import merge_attn_states
16
from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
17
                                           get_flash_attn_version,
18
19
20
21
22
23
24
                                           is_flash_attn_varlen_func_available)

if is_flash_attn_varlen_func_available():
    from vllm.attention.utils.fa_utils import (flash_attn_varlen_func,
                                               get_scheduler_metadata,
                                               reshape_and_cache_flash)

25
from vllm.config import VllmConfig, get_layers_from_vllm_config
26
from vllm.logger import init_logger
27
from vllm.utils import cdiv
28
29
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
                                              AttentionMetadataBuilder,
30
31
                                              CommonAttentionMetadata,
                                              get_kv_cache_layout)
32
from vllm.v1.kv_cache_interface import AttentionSpec
33

34
35
logger = init_logger(__name__)

36
37
38
# NOTE(woosuk): This is an arbitrary number. Tune it if needed.
_DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH = 16

39
40
41

class FlashAttentionBackend(AttentionBackend):
    """Static description of the FlashAttention backend.

    Exposes the supported dtypes and head sizes, the implementation,
    metadata and metadata-builder classes, and the KV-cache shape and
    stride order used by this backend.
    """

    accept_output_buffer: bool = True

    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        """Return the dtypes this backend supports."""
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return the head sizes this backend supports."""
        return [32, 64, 96, 128, 160, 192, 224, 256]

    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
        """Check that ``head_size`` is supported by this backend.

        Raises:
            ValueError: If ``head_size`` is not one of the values returned
                by ``get_supported_head_sizes()``.
        """
        supported_head_sizes = cls.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            # "FlashAttentionBackend" -> "FlashAttention" for the message.
            attn_type = cls.__name__.removesuffix("Backend")
            raise ValueError(
                f"Head size {head_size} is not supported by {attn_type}. "
                f"Supported head sizes are: {supported_head_sizes}. "
                "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
                "FlexAttention backend which supports all head sizes.")

    @staticmethod
    def get_name() -> str:
        """Return the identifier of this backend."""
        return "FLASH_ATTN_VLLM_V1"

    @staticmethod
    def get_impl_cls() -> type["FlashAttentionImpl"]:
        """Return the attention implementation class."""
        return FlashAttentionImpl

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        """Return the attention metadata class."""
        return FlashAttentionMetadata

    @staticmethod
    def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]:
        """Return the metadata builder class."""
        return FlashAttentionMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        """Return the logical shape of the KV cache tensor.

        The leading dimension of size 2 separates the key cache from the
        value cache (``forward`` unbinds the cache along dimension 0).

        Raises:
            ValueError: If ``block_size`` is not a multiple of 16.
        """
        if block_size % 16 != 0:
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_kv_cache_stride_order() -> tuple[int, ...]:
        """Return the permutation from the logical shape to memory layout.

        Raises:
            ValueError: If the configured cache layout is neither "NHD" nor
                "HND".
        """
        # `stride_order` indicates the permutation that gets
        # us from `get_kv_cache_shape` to the actual memory layout we want.
        cache_layout = get_kv_cache_layout()
        if cache_layout == "NHD":
            stride_order = (0, 1, 2, 3, 4)
        elif cache_layout == "HND":
            # Swap the block_size and num_kv_heads dimensions.
            stride_order = (0, 1, 3, 2, 4)
        else:
            raise ValueError(f"Unknown cache layout format {cache_layout}.")
        return stride_order

    @staticmethod
    def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype:
        """Map an FP8 KV-cache dtype string to the torch dtype to use.

        Raises:
            ValueError: If ``kv_cache_dtype`` is not "fp8" or "fp8_e4m3".
        """
        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
            return torch.float8_e4m3fn
        else:
            raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")

110
111
112
113
114
115
116
117
118
119
120

@dataclass
class FlashAttentionMetadata:
    """Per-step metadata consumed by ``FlashAttentionImpl.forward``."""

    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    num_actual_tokens: int  # Number of tokens excluding padding.
    max_query_len: int
    # Passed to the kernel as the cumulative query sequence lengths.
    query_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor

    # For cascade attention.
    use_cascade: bool
    common_prefix_len: int
    cu_prefix_query_lens: Optional[torch.Tensor]
    prefix_kv_lens: Optional[torch.Tensor]
    suffix_kv_lens: Optional[torch.Tensor]

    # Optional aot scheduling
    scheduler_metadata: Optional[torch.Tensor] = None
    prefix_scheduler_metadata: Optional[torch.Tensor] = None
    # Forwarded to the kernel as `num_splits`.
    max_num_splits: int = 0

    causal: bool = True

143

144
145
146
147
148
149
150
151
152
153
154
def _get_sliding_window_configs(
        vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]:
    """Collect the distinct sliding window settings of the model's layers.

    Every `Attention` layer found in `vllm_config` must be backed by a
    `FlashAttentionImpl`; the `sliding_window` of each implementation is
    gathered into the returned set.
    """
    attn_layers = get_layers_from_vllm_config(vllm_config, Attention)
    window_configs: set[Optional[tuple[int, int]]] = set()
    for attn_layer in attn_layers.values():
        impl = attn_layer.impl
        assert isinstance(impl, FlashAttentionImpl)
        window_configs.add(impl.sliding_window)
    return window_configs


155
156
class FlashAttentionMetadataBuilder(
        AttentionMetadataBuilder[FlashAttentionMetadata]):
    """Builds `FlashAttentionMetadata` from the common attention metadata."""

    # FA3:
    # Supports full cudagraphs for all cases.
    #
    # FA2:
    # For FA2, a graph is captured with max_query_len=1, (which is what we
    # capture by default for num_tokens <= max_num_seqs when there is no
    # spec-decode) then these graphs will not work for mixed prefill-decode
    # (unlike FA3). This is due to special max_query_len=1 packed-GQA handling
    # in FA2.
    # In summary if we are running with spec decodes the graphs would
    # work for mixed prefill-decode and uniform-decode. But for non-spec decodes
    # the graphs would not work for mixed prefill-decode; sorta the inverse
    # of UNIFORM_SINGLE_TOKEN_DECODE.
    # There's probably a better way to describe this using `AttentionCGSupport`
    # but for now just set it to `UNIFORM_BATCH` to get us to drop down
    # to FULL_AND_PIECEWISE.
    # TODO(luka, lucas): audit FA2 as part of:
    #  https://github.com/vllm-project/vllm/issues/22945
    cudagraph_support = AttentionCGSupport.ALWAYS \
        if get_flash_attn_version() == 3 else AttentionCGSupport.UNIFORM_BATCH

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        """Cache the configuration needed to build metadata.

        Raises:
            ValueError: If full cuda graphs are used together with AOT
                scheduling and the maximum capture size exceeds 992.
        """
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.parallel_config = vllm_config.parallel_config
        self.cache_config = vllm_config.cache_config
        self.compilation_config = vllm_config.compilation_config
        self.device = device

        self.num_heads_q = self.model_config.get_num_attention_heads(
            self.parallel_config)
        self.num_heads_kv = self.model_config.get_num_kv_heads(
            self.parallel_config)
        self.kv_cache_dtype = kv_cache_spec.dtype
        self.headdim = self.model_config.get_head_size()
        self.block_size = kv_cache_spec.block_size

        self.max_num_splits = 0  # No upper bound on the number of splits.
        # AOT scheduling is only enabled with FlashAttention 3.
        self.aot_schedule = (get_flash_attn_version() == 3)

        self.use_full_cuda_graph = \
            self.compilation_config.cudagraph_mode.has_full_cudagraphs()

        if self.use_full_cuda_graph and self.aot_schedule:
            self.max_cudagraph_size = self.compilation_config.max_capture_size

            if self.max_cudagraph_size > 992:
                # This condition derives from FA3's internal heuristic.
                # TODO(woosuk): Support larger cudagraph sizes.
                raise ValueError(
                    "Capture size larger than 992 is not supported for "
                    "full cuda graph.")

            # Persistent buffer that build() copies each step's scheduler
            # metadata into.
            self.scheduler_metadata = torch.zeros(
                vllm_config.scheduler_config.max_num_seqs + 1,
                dtype=torch.int32,
                device=self.device,
            )
            # When using cuda graph, we need to set the upper bound of the
            # number of splits so that large enough intermediate buffers are
            # pre-allocated during capture.
            self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH

        # Sliding window size to be used with the AOT scheduler will be
        # populated on the first build() call that uses AOT scheduling.
        self.aot_sliding_window: Optional[tuple[int, int]] = None

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> FlashAttentionMetadata:
        """Build the FlashAttention metadata for one step.

        Args:
            common_prefix_len: Length of the prefix shared by the requests;
                a value greater than 0 enables cascade attention.
            common_attn_metadata: Backend-independent batch metadata.
            fast_build: Disables AOT scheduling; used when there will be
                few iterations, i.e. spec-decode.

        Returns:
            The populated `FlashAttentionMetadata`.
        """
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        max_query_len = common_attn_metadata.max_query_len
        max_seq_len = common_attn_metadata.max_seq_len
        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens
        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
        block_table_tensor = common_attn_metadata.block_table_tensor
        slot_mapping = common_attn_metadata.slot_mapping
        causal = common_attn_metadata.causal

        # the overhead of the aot schedule is not worth it for spec-decode
        aot_schedule = self.aot_schedule and not fast_build

        # For the AOT scheduler we need the sliding window value to be
        # constant for all layers. We have to populate this lazily in
        # build() so the layers are constructed (cannot populate in
        # __init__).
        # Only resolve it on a call where AOT scheduling is active: if the
        # default were pinned during a fast_build call, the layers' actual
        # window would never be inspected by later AOT-scheduled calls.
        if aot_schedule and self.aot_sliding_window is None:
            self.aot_sliding_window = (-1, -1)
            sliding_window_configs = _get_sliding_window_configs(
                self.vllm_config)
            if len(sliding_window_configs) == 1:
                sliding_window_config = sliding_window_configs.pop()
                if sliding_window_config is not None:
                    self.aot_sliding_window = sliding_window_config
            elif len(sliding_window_configs) > 1:
                # Layers disagree on the window; turn AOT scheduling off
                # for this and all later calls.
                self.aot_schedule = False
                aot_schedule = False

        def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                     max_seq_len, causal):
            """Return AOT scheduler metadata, or None if AOT is disabled."""
            cache_dtype = self.cache_config.cache_dtype
            if cache_dtype.startswith("fp8"):
                qkv_dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
                    cache_dtype)
            else:
                qkv_dtype = self.kv_cache_dtype
            if aot_schedule:
                return get_scheduler_metadata(
                    batch_size=batch_size,
                    max_seqlen_q=max_query_len,
                    max_seqlen_k=max_seq_len,
                    num_heads_q=self.num_heads_q,
                    num_heads_kv=self.num_heads_kv,
                    headdim=self.headdim,
                    cache_seqlens=seqlens,
                    qkv_dtype=qkv_dtype,
                    cu_seqlens_q=cu_query_lens,
                    page_size=self.block_size,
                    causal=causal,
                    window_size=self.aot_sliding_window,
                    num_splits=self.max_num_splits,
                )
            return None

        use_cascade = common_prefix_len > 0

        if use_cascade:
            # The shared prefix is scheduled as a single sequence that all
            # actual tokens attend to.
            cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
                                                dtype=torch.int32,
                                                device=self.device)
            prefix_kv_lens = torch.tensor([common_prefix_len],
                                          dtype=torch.int32,
                                          device=self.device)
            suffix_kv_lens = (seq_lens_cpu[:num_reqs] - common_prefix_len).to(
                self.device, non_blocking=True)
            prefix_scheduler_metadata = schedule(
                batch_size=1,
                cu_query_lens=cu_prefix_query_lens,
                max_query_len=num_actual_tokens,
                seqlens=prefix_kv_lens,
                max_seq_len=common_prefix_len,
                causal=False)
            scheduler_metadata = schedule(batch_size=num_reqs,
                                          cu_query_lens=query_start_loc,
                                          max_query_len=max_query_len,
                                          seqlens=suffix_kv_lens,
                                          max_seq_len=max_seq_len -
                                          common_prefix_len,
                                          causal=True)
        else:
            cu_prefix_query_lens = None
            prefix_kv_lens = None
            suffix_kv_lens = None
            prefix_scheduler_metadata = None
            scheduler_metadata = schedule(batch_size=num_reqs,
                                          cu_query_lens=query_start_loc,
                                          max_query_len=max_query_len,
                                          seqlens=seq_lens,
                                          max_seq_len=max_seq_len,
                                          causal=causal)

        # For FA3 + full cudagraph
        max_num_splits = 0
        if self.use_full_cuda_graph and scheduler_metadata is not None:
            # Copy into the persistent buffer allocated in __init__ and hand
            # out a view of that buffer instead of the fresh tensor.
            n = scheduler_metadata.shape[0]
            self.scheduler_metadata[:n] = scheduler_metadata
            # NOTE(woosuk): We should zero out the rest of the scheduler
            # metadata to guarantee the correctness. Otherwise, some thread
            # blocks may use the invalid scheduler metadata and overwrite the
            # output buffer.
            self.scheduler_metadata[n:] = 0
            scheduler_metadata = self.scheduler_metadata[:n]

            if num_actual_tokens <= self.max_cudagraph_size:
                # NOTE(woosuk): Setting num_splits > 1 may increase the memory
                # usage, because the intermediate buffers of size [num_splits,
                # num_heads, num_tokens, head_size] are allocated. Therefore,
                # we only set num_splits when using cuda graphs.
                max_num_splits = self.max_num_splits

        attn_metadata = FlashAttentionMetadata(
            num_actual_tokens=num_actual_tokens,
            max_query_len=max_query_len,
            query_start_loc=query_start_loc,
            max_seq_len=max_seq_len,
            seq_lens=seq_lens,
            block_table=block_table_tensor,
            slot_mapping=slot_mapping,
            use_cascade=use_cascade,
            common_prefix_len=common_prefix_len,
            scheduler_metadata=scheduler_metadata,
            cu_prefix_query_lens=cu_prefix_query_lens,
            prefix_kv_lens=prefix_kv_lens,
            suffix_kv_lens=suffix_kv_lens,
            prefix_scheduler_metadata=prefix_scheduler_metadata,
            max_num_splits=max_num_splits,
            causal=causal)
        return attn_metadata

    def use_cascade_attention(self, *args, **kwargs) -> bool:
        """Forward to the module-level `use_cascade_attention` heuristic."""
        return use_cascade_attention(*args, **kwargs)

367

368
369
370
371
372
373
374
375
class FlashAttentionImpl(AttentionImpl):
    """Attention implementation backed by `flash_attn_varlen_func`.

    Decoder attention reads keys and values from the paged KV cache
    (optionally through cascade attention); encoder attention runs directly
    on the Q, K, V tensors without a KV cache.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        logits_soft_cap: Optional[float] = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        sinks: Optional[torch.Tensor] = None,
    ) -> None:
        """Store the layer configuration and validate it.

        Raises:
            ValueError: If `head_size` is not supported by the backend.
            NotImplementedError: If the KV cache is quantized but
                FlashAttention does not support fp8 on this device.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        # The window is stored as a (left, right) pair; (-1, -1) means no
        # sliding window.
        if sliding_window is None:
            self.sliding_window = (-1, -1)
        elif attn_type == AttentionType.ENCODER_ONLY:
            self.sliding_window = (sliding_window - 1, sliding_window - 1)
        else:
            self.sliding_window = (sliding_window - 1, 0)
        self.kv_cache_dtype = kv_cache_dtype
        if logits_soft_cap is None:
            # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
            logits_soft_cap = 0
        self.logits_soft_cap = logits_soft_cap
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        FlashAttentionBackend.validate_head_size(head_size)

        self.attn_type = attn_type
        self.vllm_flash_attn_version = get_flash_attn_version()
        if is_quantized_kv_cache(self.kv_cache_dtype) \
            and not flash_attn_supports_fp8():
            raise NotImplementedError(
                "FlashAttention does not support fp8 kv-cache on this device.")

        self.sinks = sinks
        if self.sinks is not None:
            assert self.vllm_flash_attn_version == 3, (
                "Sinks are only supported in FlashAttention 3")
            assert self.sinks.shape[0] == num_heads, (
                "Sinks must have the same number of heads as the number of "
                "heads in the layer")

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
        output_block_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention.

        Args:
            layer: The attention layer; its `_q_scale`, `_k_scale` and
                `_v_scale` attributes are read here.
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape =
                [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention.
            output: Pre-allocated output tensor; must be provided.
            output_scale: Not supported; must be None.
            output_block_scale: Not supported; must be None.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        Raises:
            NotImplementedError: If `output_scale` or `output_block_scale`
                is given.
        NOTE: FP8 quantization, flash-attn expect the size of
              {q,k,v}_descale to be (num_sequences, num_kv_heads).
              We use torch's .expand() to avoid duplicating values
        """
        assert output is not None, "Output tensor must be provided."

        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for FlashAttentionImpl")

        if attn_metadata is None:
            # Profiling run.
            return output

        attn_type = self.attn_type

        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
        # in this method. For example, `view` and `slice` (or `[:n]`) operations
        # are surprisingly slow even in the case they do not invoke any GPU ops.
        # Minimize the PyTorch ops in this method as much as possible.
        # Whenever making a change in this method, please benchmark the
        # performance to make sure it does not introduce any overhead.

        num_actual_tokens = attn_metadata.num_actual_tokens

        # Handle encoder attention differently - no KV cache needed
        if attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
            # For encoder attention,
            # we use direct Q, K, V tensors without caching
            return self._forward_encoder_attention(query[:num_actual_tokens],
                                                   key[:num_actual_tokens],
                                                   value[:num_actual_tokens],
                                                   output[:num_actual_tokens],
                                                   attn_metadata, layer)

        # For decoder and cross-attention, use KV cache as before
        key_cache, value_cache = kv_cache.unbind(0)

        # key and value may be None in the case of cross attention. They are
        # calculated once based on the output from the encoder and then cached
        # in KV cache.
        if (self.kv_sharing_target_layer_name is None and key is not None
                and value is not None):
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
            # not padded. However, we don't need to do key[:num_actual_tokens]
            # and value[:num_actual_tokens] because the reshape_and_cache_flash
            # op uses the slot_mapping's shape to determine the number of
            # actual tokens.
            reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

        if self.kv_cache_dtype.startswith("fp8"):
            # Reinterpret the caches as fp8 and quantize the query to match.
            dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
                self.kv_cache_dtype)
            key_cache = key_cache.view(dtype)
            value_cache = value_cache.view(dtype)
            num_tokens, num_heads, head_size = query.shape
            query, _ = ops.scaled_fp8_quant(
                query.reshape(
                    (num_tokens, num_heads * head_size)).contiguous(),
                layer._q_scale)
            query = query.reshape((num_tokens, num_heads, head_size))

        if not attn_metadata.use_cascade:
            cu_seqlens_q = attn_metadata.query_start_loc
            seqused_k = attn_metadata.seq_lens
            max_seqlen_q = attn_metadata.max_query_len
            max_seqlen_k = attn_metadata.max_seq_len
            block_table = attn_metadata.block_table
            scheduler_metadata = attn_metadata.scheduler_metadata

            # (num_sequences, num_kv_heads), see the NOTE in the docstring.
            descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)

            flash_attn_varlen_func(
                q=query[:num_actual_tokens],
                k=key_cache,
                v=value_cache,
                out=output[:num_actual_tokens],
                cu_seqlens_q=cu_seqlens_q,
                max_seqlen_q=max_seqlen_q,
                seqused_k=seqused_k,
                max_seqlen_k=max_seqlen_k,
                softmax_scale=self.scale,
                causal=attn_metadata.causal,
                alibi_slopes=self.alibi_slopes,
                window_size=self.sliding_window,
                block_table=block_table,
                softcap=self.logits_soft_cap,
                scheduler_metadata=scheduler_metadata,
                fa_version=self.vllm_flash_attn_version,
                q_descale=layer._q_scale.expand(descale_shape),
                k_descale=layer._k_scale.expand(descale_shape),
                v_descale=layer._v_scale.expand(descale_shape),
                num_splits=attn_metadata.max_num_splits,
                s_aux=self.sinks,
            )
            return output

        # Cascade attention (rare case).
        cascade_attention(
            output[:num_actual_tokens],
            query[:num_actual_tokens],
            key_cache,
            value_cache,
            cu_query_lens=attn_metadata.query_start_loc,
            max_query_len=attn_metadata.max_query_len,
            cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens,
            prefix_kv_lens=attn_metadata.prefix_kv_lens,
            suffix_kv_lens=attn_metadata.suffix_kv_lens,
            max_kv_len=attn_metadata.max_seq_len,
            softmax_scale=self.scale,
            alibi_slopes=self.alibi_slopes,
            sliding_window=self.sliding_window,
            logits_soft_cap=self.logits_soft_cap,
            block_table=attn_metadata.block_table,
            common_prefix_len=attn_metadata.common_prefix_len,
            fa_version=self.vllm_flash_attn_version,
            prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
            suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
            q_descale=layer._q_scale,
            k_descale=layer._k_scale,
            v_descale=layer._v_scale,
        )
        return output

    def _forward_encoder_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        output: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        layer: torch.nn.Module,
    ) -> torch.Tensor:
        """Forward pass for encoder attention without KV cache.

        Args:
            query: shape = [num_encoder_tokens, num_heads, head_size]
            key: shape = [num_encoder_tokens, num_kv_heads, head_size]
            value: shape = [num_encoder_tokens, num_kv_heads, head_size]
            output: shape = [num_encoder_tokens, num_heads, head_size]
            attn_metadata: Encoder attention metadata
            layer: The attention layer
        Returns:
            The `output` tensor passed in.
        Raises:
            NotImplementedError: If the KV cache dtype is an fp8 variant.
        """
        # FP8 quantization is not supported on the encoder path.
        if self.kv_cache_dtype.startswith("fp8"):
            raise NotImplementedError(
                "quantization is not supported for encoder attention")

        # Use encoder-specific metadata for sequence information.
        # Keys use the same cumulative lengths and maximum length as the
        # queries.
        cu_seqlens_q = attn_metadata.query_start_loc
        cu_seqlens_k = attn_metadata.query_start_loc
        max_seqlen_q = attn_metadata.max_query_len
        max_seqlen_k = attn_metadata.max_query_len

        descale_shape = (
            cu_seqlens_q.shape[0] - 1,  # type: ignore[union-attr]
            self.num_kv_heads)

        # Call flash attention directly on Q, K, V tensors
        flash_attn_varlen_func(
            q=query,
            k=key,
            v=value,
            out=output,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlen_k,
            softmax_scale=self.scale,
            causal=False,  # Encoder attention is bidirectional
            alibi_slopes=self.alibi_slopes,
            window_size=self.sliding_window,
            softcap=self.logits_soft_cap,
            fa_version=self.vllm_flash_attn_version,
            q_descale=layer._q_scale.expand(descale_shape),
            k_descale=layer._k_scale.expand(descale_shape),
            v_descale=layer._v_scale.expand(descale_shape),
        )

        return output

641
642
643
644
645
646
647
648

def use_cascade_attention(
    common_prefix_len: int,
    query_lens: np.ndarray,
    num_query_heads: int,
    num_kv_heads: int,
    use_alibi: bool,
    use_sliding_window: bool,
    use_local_attention: bool,
    num_sms: int,
) -> bool:
    """Decide whether to use cascade attention.

    This function 1) checks whether cascade attention is supported with the
    given configuration, and 2) heuristically decides whether using cascade
    attention can improve performance.

    Args:
        common_prefix_len: Length of the prefix shared by the requests.
        query_lens: Query length of each request.
        num_query_heads: Number of query heads.
        num_kv_heads: Number of KV heads.
        use_alibi: Whether ALiBi is used.
        use_sliding_window: Whether sliding window attention is used.
        use_local_attention: Whether local attention is used.
        num_sms: Number of SMs, used by the rough performance model.

    Returns:
        True if cascade attention should be used, False otherwise.
    """
    # Too short common prefix. Probably not worth using cascade attention.
    # We use an arbitrary threshold of 256 tokens. TODO: Tune this threshold.
    # NOTE(woosuk): This is the common case. We should return False as soon as
    # possible to avoid any unnecessary computation.
    if common_prefix_len < 256:
        return False
    # Cascade attention is currently not supported with these variants.
    if use_alibi or use_sliding_window or use_local_attention:
        return False
    # Too few queries. Probably not worth using cascade attention.
    # We use an arbitrary threshold of 8 queries. TODO: Tune this threshold.
    num_reqs = len(query_lens)
    if num_reqs < 8:
        return False

    # Heuristics to decide whether using cascade attention is beneficial.
    # 1. When FlashDecoding is not used for normal attention, cascade attention
    #    is likely to be faster since it saves memory bandwidth.
    num_queries_per_kv = num_query_heads // num_kv_heads
    # The criteria for using FlashDecoding can be found in the following link:
    # https://github.com/vllm-project/flash-attention/blob/96266b1111111f3d11aabefaf3bacbab6a89d03c/csrc/flash_attn/flash_api.cpp#L535
    use_flash_decoding = (num_queries_per_kv > 1 and not use_sliding_window
                          and not use_alibi and np.all(query_lens == 1))
    if not use_flash_decoding:
        # Use cascade attention.
        return True

    # 2. When FlashDecoding is used for normal attention, it is not clear
    #    whether cascade attention is beneficial, because FlashDecoding can
    #    launch more CTAs than cascade attention.
    #    We use a simple performance model to compare the two methods.
    #    NOTE(woosuk): The performance model is very rough and may not be
    #    accurate.
    # All query lengths are 1 on this path, so there is one token per request.
    num_tokens = num_reqs
    # NOTE(woosuk): These are default tile sizes. flash-attn might use
    # different tile sizes (e.g., 64 or 256) depending on the configuration.
    q_tile_size = 128
    kv_tile_size = 128
    num_prefix_tiles = cdiv(common_prefix_len, kv_tile_size)

    cascade_ctas = num_query_heads * cdiv(num_tokens, q_tile_size)
    cascade_waves = cdiv(cascade_ctas, num_sms)
    cascade_time = cascade_waves * num_prefix_tiles

    flash_decoding_ctas = (num_reqs * num_kv_heads *
                           cdiv(num_queries_per_kv, q_tile_size))
    flash_decoding_ctas *= num_prefix_tiles
    flash_decoding_time = cdiv(flash_decoding_ctas, num_sms)

    # Use cascade attention if it is faster than FlashDecoding.
    return cascade_time < flash_decoding_time


def _expand_descale(
    descale: Optional[torch.Tensor],
    shape: tuple[int, int],
) -> Optional[torch.Tensor]:
    """Expand an optional descale tensor to ``shape``; ``None`` stays ``None``."""
    return descale.expand(shape) if descale is not None else None


def cascade_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    cu_query_lens: torch.Tensor,
    max_query_len: int,
    cu_prefix_query_lens: torch.Tensor,
    prefix_kv_lens: torch.Tensor,
    suffix_kv_lens: torch.Tensor,
    max_kv_len: int,
    softmax_scale: float,
    alibi_slopes: Optional[torch.Tensor],
    sliding_window: tuple[int, int],
    logits_soft_cap: float,
    block_table: torch.Tensor,
    common_prefix_len: int,
    fa_version: int,
    prefix_scheduler_metadata: Optional[torch.Tensor] = None,
    suffix_scheduler_metadata: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run attention in two passes (shared prefix + suffix) and merge them.

    The first pass attends every query token to the common prefix, using only
    the first row of ``block_table`` and ``causal=False``. The second pass
    attends each request to its remaining KV, using the block-table columns
    after the common-prefix blocks and ``causal=True``. Both passes return
    their softmax LSE, and ``merge_attn_states`` combines the two partial
    results into ``output``.

    Args:
        output: Destination tensor; the merged result is written into it.
        query: Query tensor; ``query.shape[0]`` is taken as the token count.
        key_cache: Paged key cache; ``shape[-3]`` is used as the block size.
        value_cache: Paged value cache.
        cu_query_lens: ``cu_seqlens_q`` for the suffix pass.
        max_query_len: ``max_seqlen_q`` for the suffix pass.
        cu_prefix_query_lens: ``cu_seqlens_q`` for the prefix pass.
        prefix_kv_lens: ``seqused_k`` for the prefix pass.
        suffix_kv_lens: ``seqused_k`` for the suffix pass.
        max_kv_len: Maximum total KV length; the suffix pass uses
            ``max_kv_len - common_prefix_len`` as ``max_seqlen_k``.
        softmax_scale: Softmax scaling factor forwarded to both passes.
        alibi_slopes: Must be ``None``; ALiBi is not supported here.
        sliding_window: Must be ``(-1, -1)``; sliding window is not supported.
        logits_soft_cap: Forwarded to both passes as ``softcap``.
        block_table: Block table; row 0 supplies the shared prefix blocks.
        common_prefix_len: Length of the shared prefix in tokens. Must be a
            positive multiple of the block size.
        fa_version: FlashAttention version forwarded to both passes.
        prefix_scheduler_metadata: Optional scheduler metadata for the prefix
            pass.
        suffix_scheduler_metadata: Optional scheduler metadata for the suffix
            pass.
        q_descale: Optional descale tensor, expanded per pass before use.
        k_descale: Optional descale tensor, expanded per pass before use.
        v_descale: Optional descale tensor, expanded per pass before use.

    Returns:
        ``output``, after the merged attention result has been stored in it.

    Raises:
        AssertionError: If ALiBi slopes or a sliding window are supplied, or
            if ``common_prefix_len`` is not a positive multiple of the block
            size.
    """
    assert alibi_slopes is None, ("Cascade attention does not support ALiBi.")
    # TODO: Support sliding window.
    assert sliding_window == (-1, -1), (
        "Cascade attention does not support sliding window.")

    num_tokens = query.shape[0]
    block_size = key_cache.shape[-3]
    assert common_prefix_len % block_size == 0
    num_common_kv_blocks = common_prefix_len // block_size
    assert num_common_kv_blocks > 0

    # Descale tensors are expanded to (num sequences in this pass,
    # key_cache.shape[-2]).
    # NOTE(review): key_cache.shape[-2] is presumably the number of KV heads;
    # confirm against the KV cache layout.
    descale_shape = (cu_prefix_query_lens.shape[0] - 1, key_cache.shape[-2])

    # Process shared prefix.
    prefix_output, prefix_lse = flash_attn_varlen_func(
        q=query,
        k=key_cache,
        v=value_cache,
        cu_seqlens_q=cu_prefix_query_lens,
        seqused_k=prefix_kv_lens,
        max_seqlen_q=num_tokens,
        max_seqlen_k=common_prefix_len,
        softmax_scale=softmax_scale,
        causal=False,
        window_size=sliding_window,
        block_table=block_table[:1],
        softcap=logits_soft_cap,
        return_softmax_lse=True,
        scheduler_metadata=prefix_scheduler_metadata,
        fa_version=fa_version,
        q_descale=_expand_descale(q_descale, descale_shape),
        k_descale=_expand_descale(k_descale, descale_shape),
        v_descale=_expand_descale(v_descale, descale_shape),
    )

    descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])

    # Process suffix per query.
    suffix_output, suffix_lse = flash_attn_varlen_func(
        q=query,
        k=key_cache,
        v=value_cache,
        cu_seqlens_q=cu_query_lens,
        seqused_k=suffix_kv_lens,
        max_seqlen_q=max_query_len,
        max_seqlen_k=max_kv_len - common_prefix_len,
        softmax_scale=softmax_scale,
        causal=True,
        window_size=sliding_window,
        # Skip the block-table columns that hold the shared prefix.
        block_table=block_table[:, num_common_kv_blocks:],
        softcap=logits_soft_cap,
        return_softmax_lse=True,
        scheduler_metadata=suffix_scheduler_metadata,
        fa_version=fa_version,
        q_descale=_expand_descale(q_descale, descale_shape),
        k_descale=_expand_descale(k_descale, descale_shape),
        v_descale=_expand_descale(v_descale, descale_shape),
    )

    # Merge prefix and suffix outputs, and store the result in output.
    merge_attn_states(output, prefix_output, prefix_lse, suffix_output,
                      suffix_lse)
    # Return the buffer so the declared return type holds; callers that
    # ignore the return value are unaffected.
    return output