flash_attn.py 40.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Attention layer with FlashAttention."""
4

5
from dataclasses import dataclass
6
from typing import ClassVar
7

8
import numpy as np
9
10
import torch

11
12
13
14
from vllm.attention.backends.abstract import (
    AttentionBackend,
    AttentionImpl,
    AttentionType,
15
    MultipleOf,
16
17
    is_quantized_kv_cache,
)
18
from vllm.attention.layer import Attention
19
from vllm.attention.ops.common import cp_lse_ag_out_rs
20
from vllm.attention.ops.merge_attn_states import merge_attn_states
21
22
23
24
25
from vllm.attention.utils.fa_utils import (
    flash_attn_supports_fp8,
    get_flash_attn_version,
    is_flash_attn_varlen_func_available,
)
26
27

if is_flash_attn_varlen_func_available():
28
    from vllm.attention.utils.fa_utils import (
29
        flash_attn_supports_sinks,
30
31
32
33
        flash_attn_varlen_func,
        get_scheduler_metadata,
        reshape_and_cache_flash,
    )
34
from vllm.config import VllmConfig, get_current_vllm_config, get_layers_from_vllm_config
35
from vllm.config.cache import CacheDType
36
from vllm.distributed.parallel_state import get_dcp_group
37
from vllm.logger import init_logger
38
from vllm.model_executor.layers.batch_invariant import (
39
    vllm_is_batch_invariant,
40
)
41
from vllm.platforms.interface import DeviceCapability
42
from vllm.utils.math_utils import cdiv
43
44
45
46
from vllm.v1.attention.backends.utils import (
    AttentionCGSupport,
    AttentionMetadataBuilder,
    CommonAttentionMetadata,
47
    get_dcp_local_seq_lens,
48
49
    get_kv_cache_layout,
)
50
from vllm.v1.kv_cache_interface import AttentionSpec
51

52
53
logger = init_logger(__name__)

54
55

class FlashAttentionBackend(AttentionBackend):
    """Backend descriptor for the FlashAttention kernels.

    Exposes the implementation and metadata-builder classes together with
    capability queries (block sizes, head sizes, KV-cache dtype/layout,
    compute capability, ...) that can be answered without instantiating an
    attention layer.
    """

    accept_output_buffer: bool = True
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        """Return the kernel block sizes allowed under the current vLLM config.

        Returns:
            ``[16, 32, 64]`` for hybrid models whose mamba (SSM) cache dtype
            is ``"float32"``; ``[MultipleOf(16)]`` otherwise.
        """
        config = get_current_vllm_config()
        model_cfg = config.model_config
        cache_cfg = config.cache_config

        if not (model_cfg and model_cfg.is_hybrid):
            return [MultipleOf(16)]

        if (
            cache_cfg.mamba_ssm_cache_dtype == "float32"
            or cache_cfg.mamba_cache_dtype == "float32"
        ):
            # NOTE(tdoublep): while in principle FA supports MultipleOf(16),
            # these are the block sizes that do not suffer from the NaN
            # propagation problem described here:
            # https://github.com/Dao-AILab/flash-attention/issues/1974
            return [16, 32, 64]

        return [MultipleOf(16)]

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "FLASH_ATTN"

    @classmethod
    def supports_attn_type(cls, attn_type: str) -> bool:
        """Return True for decoder, encoder, encoder-only and
        encoder-decoder attention types."""
        supported_types = (
            AttentionType.DECODER,
            AttentionType.ENCODER,
            AttentionType.ENCODER_ONLY,
            AttentionType.ENCODER_DECODER,
        )
        return attn_type in supported_types

    @staticmethod
    def get_impl_cls() -> type["FlashAttentionImpl"]:
        """Return the attention implementation class of this backend."""
        return FlashAttentionImpl

    @staticmethod
    def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]:
        """Return the metadata builder class of this backend."""
        return FlashAttentionMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return the logical KV-cache shape.

        Returns:
            ``(2, num_blocks, block_size, num_kv_heads, head_size)``.

        Raises:
            ValueError: If ``block_size`` is not a multiple of 16.
        """
        if block_size % 16:
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_kv_cache_stride_order(
        include_num_layers_dimension: bool = False,
    ) -> tuple[int, ...]:
        """Return the permutation from `get_kv_cache_shape` to memory layout.

        `stride_order` indicates the permutation that gets us from
        `get_kv_cache_shape` to the actual memory layout we want.

        Args:
            include_num_layers_dimension: When True, return the 6-element
                permutation that also covers a num_layers dimension.

        Raises:
            ValueError: If the configured cache layout is neither ``"NHD"``
                nor ``"HND"``.
        """
        layout = get_kv_cache_layout()

        if layout == "NHD":
            if include_num_layers_dimension:
                # (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size)
                return (2, 0, 1, 3, 4, 5)
            return (0, 1, 2, 3, 4)

        if layout == "HND":
            if include_num_layers_dimension:
                # (num_blocks, num_kv_heads, num_layers, 2, block_size, head_size)
                return (2, 4, 0, 1, 3, 5)
            return (0, 1, 3, 2, 4)

        raise ValueError(f"Unknown cache layout format {layout}.")

    @staticmethod
    def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype:
        """Map an fp8 KV-cache dtype string to the torch dtype used for FA.

        Raises:
            ValueError: If ``kv_cache_dtype`` is not ``"fp8"`` or
                ``"fp8_e4m3"``.
        """
        if kv_cache_dtype not in ("fp8", "fp8_e4m3"):
            raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
        return torch.float8_e4m3fn

    @classmethod
    def supports_head_size(cls, head_size: int) -> bool:
        """Return True if ``head_size`` is a multiple of 8 and at most 256."""
        return head_size <= 256 and head_size % 8 == 0

    @classmethod
    def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool:
        """Return whether the given KV-cache dtype can be used.

        ``None`` and ``"auto"`` are accepted; any dtype starting with
        ``"fp8"`` is accepted only if `flash_attn_supports_fp8()` says so.
        """
        if kv_cache_dtype is None:
            return True
        if kv_cache_dtype.startswith("fp8"):
            return flash_attn_supports_fp8()
        return kv_cache_dtype == "auto"

    @classmethod
    def supports_sink(cls) -> bool:
        """Return whether attention sinks are available.

        `flash_attn_supports_sinks` is only imported when the varlen
        function is available, so that availability is checked first.
        """
        if is_flash_attn_varlen_func_available():
            return flash_attn_supports_sinks()
        return False

    @classmethod
    def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
        """Return True for compute capability 8.0 or newer."""
        minimum = DeviceCapability(8, 0)
        return capability >= minimum

    @classmethod
    def supports_combination(
        cls,
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: CacheDType | None,
        block_size: int,
        use_mla: bool,
        has_sink: bool,
        use_sparse: bool,
        device_capability: DeviceCapability,
    ) -> str | None:
        """Check a feature combination.

        Returns:
            A human-readable reason string when the combination is not
            supported (currently only: sinks on compute capability < 9.0),
            otherwise ``None``.
        """
        sink_unsupported = has_sink and device_capability < DeviceCapability(9, 0)
        return "sink not supported on compute capability < 9.0" if sink_unsupported else None

179
180
181
182
183
184
185
186
187
188
189

@dataclass
class FlashAttentionMetadata:
    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

190
    num_actual_tokens: int  # Number of tokens excluding padding.
191
192
193
    max_query_len: int
    query_start_loc: torch.Tensor
    max_seq_len: int
194
    seq_lens: torch.Tensor
195
196
    block_table: torch.Tensor
    slot_mapping: torch.Tensor
197
198
199
200

    # For cascade attention.
    use_cascade: bool
    common_prefix_len: int
201
202
203
    cu_prefix_query_lens: torch.Tensor | None
    prefix_kv_lens: torch.Tensor | None
    suffix_kv_lens: torch.Tensor | None
204

205
206
207
208
    # For GQA DCP
    max_dcp_context_kv_len: int | None = None
    dcp_context_kv_lens: torch.Tensor | None = None

209
    # Optional aot scheduling
210
211
    scheduler_metadata: torch.Tensor | None = None
    prefix_scheduler_metadata: torch.Tensor | None = None
212
    max_num_splits: int = 0
213

214
215
    causal: bool = True

216

217
def _get_sliding_window_configs(
    vllm_config: VllmConfig,
) -> set[tuple[int, int] | None]:
    """Collect the distinct sliding-window settings of all attention layers.

    Every `Attention` layer found in ``vllm_config`` must use a
    `FlashAttentionImpl`; this is asserted for each layer.

    Returns:
        The set of ``impl.sliding_window`` values across those layers.
    """
    configs: set[tuple[int, int] | None] = set()
    for attn_layer in get_layers_from_vllm_config(vllm_config, Attention).values():
        impl = attn_layer.impl
        assert isinstance(impl, FlashAttentionImpl)
        configs.add(impl.sliding_window)
    return configs


229
class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetadata]):
    """Builds `FlashAttentionMetadata` from the common attention metadata.

    Besides copying the per-step tensors, the builder optionally computes
    ahead-of-time (AOT) scheduler metadata when FlashAttention 3 is in use,
    and prepares the extra inputs needed for cascade attention and for
    decode context parallelism (DCP).
    """

    # FA3:
    # Supports full cudagraphs for all cases.
    #
    # FA2:
    # A graph captured with max_query_len=1 (which is what we capture by
    # default for num_tokens <= max_num_seqs when there is no spec-decode)
    # will not work for mixed prefill-decode (unlike FA3). This is due to the
    # special max_query_len=1 packed-GQA handling in FA2.
    # In summary: when running with spec-decode the graphs work for both
    # mixed prefill-decode and uniform-decode, but without spec-decode the
    # graphs do not work for mixed prefill-decode; sort of the inverse of
    # UNIFORM_SINGLE_TOKEN_DECODE.
    # There's probably a better way to describe this using
    # `AttentionCGSupport`, but for now just set it to `UNIFORM_BATCH` to get
    # us to drop down to FULL_AND_PIECEWISE.
    # TODO(luka, lucas): audit FA2 as part of:
    #  https://github.com/vllm-project/vllm/issues/22945
    _cudagraph_support = (
        AttentionCGSupport.UNIFORM_BATCH
        if get_flash_attn_version() != 3
        else AttentionCGSupport.ALWAYS
    )

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Cache the configuration values needed by `build`.

        Args:
            kv_cache_spec: Supplies the KV-cache dtype and block size.
            layer_names: Forwarded to the base-class constructor.
            vllm_config: Source of the model/parallel/cache/compilation/
                attention/scheduler configs read here.
            device: Forwarded to the base-class constructor.
        """
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)

        # Sub-configs.
        self.model_config = vllm_config.model_config
        self.parallel_config = vllm_config.parallel_config
        self.cache_config = vllm_config.cache_config
        self.compilation_config = vllm_config.compilation_config
        self.attention_config = vllm_config.attention_config

        # Attention geometry.
        parallel_cfg = self.parallel_config
        self.num_heads_q = self.model_config.get_num_attention_heads(parallel_cfg)
        self.num_heads_kv = self.model_config.get_num_kv_heads(parallel_cfg)
        self.kv_cache_dtype = kv_cache_spec.dtype
        self.headdim = self.model_config.get_head_size()
        self.block_size = kv_cache_spec.block_size

        self.max_num_splits = 0  # No upper bound on the number of splits.
        # AOT scheduling is only attempted with FlashAttention 3.
        self.aot_schedule = get_flash_attn_version() == 3

        try:
            from vllm.distributed.parallel_state import get_dcp_group

            dcp_group = get_dcp_group()
            self.dcp_world_size = dcp_group.world_size
            self.dcp_rank = dcp_group.rank_in_group
        except AssertionError:
            # DCP might not be initialized in testing
            self.dcp_world_size = 1
            self.dcp_rank = 0

        self.cp_kv_cache_interleave_size = parallel_cfg.cp_kv_cache_interleave_size

        compilation_cfg = self.compilation_config
        self.use_full_cuda_graph = compilation_cfg.cudagraph_mode.has_full_cudagraphs()
        self.max_cudagraph_size = compilation_cfg.max_cudagraph_capture_size

        if self.use_full_cuda_graph and self.aot_schedule:
            # Persistent buffer that `build` copies the scheduler metadata
            # into when full CUDA graphs are enabled.
            self.scheduler_metadata = torch.zeros(
                vllm_config.scheduler_config.max_num_seqs + 1,
                dtype=torch.int32,
                device=self.device,
            )
            # When using cuda graph, we need to set the upper bound of the
            # number of splits so that large enough intermediate buffers are
            # pre-allocated during capture.
            self.max_num_splits = (
                self.attention_config.flash_attn_max_num_splits_for_cuda_graph
            )

        # Sliding window size to be used with the AOT scheduler will be
        # populated on first build() call.
        self.aot_sliding_window: tuple[int, int] | None = None

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> FlashAttentionMetadata:
        """Build the FlashAttention metadata for one step.

        Args:
            common_prefix_len: Length of the prefix shared by all requests;
                a value > 0 enables cascade attention.
            common_attn_metadata: Backend-independent per-step metadata.
            fast_build: Disables AOT scheduling; used when there will be few
                iterations, i.e. spec-decode.

        Returns:
            The populated `FlashAttentionMetadata`.
        """
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        max_query_len = common_attn_metadata.max_query_len
        max_seq_len = common_attn_metadata.max_seq_len
        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens
        block_table_tensor = common_attn_metadata.block_table_tensor
        slot_mapping = common_attn_metadata.slot_mapping
        causal = common_attn_metadata.causal

        # the overhead of the aot schedule is not worth it for spec-decode
        aot_schedule = self.aot_schedule and not fast_build

        if self.aot_sliding_window is None:
            self.aot_sliding_window = (-1, -1)
            # For the AOT scheduler we need the sliding window value to be
            # constant for all layers. We have to populate this on the first
            # build() call so the layers are constructed (cannot populate
            # in __init__).
            # NOTE(review): if the very first build() call has
            # fast_build=True, the window configs are never inspected later
            # because aot_sliding_window is already set -- confirm intended.
            if aot_schedule:
                window_configs = _get_sliding_window_configs(self.vllm_config)
                if len(window_configs) > 1:
                    # Layers disagree on the window: give up on AOT
                    # scheduling for this and all later builds.
                    self.aot_schedule = False
                    aot_schedule = False
                elif len(window_configs) == 1:
                    (only_config,) = window_configs
                    if only_config is not None:
                        self.aot_sliding_window = only_config

        # NOTE(woosuk): Setting num_splits > 1 may increase the memory
        # usage, because the intermediate buffers of size [num_splits,
        # num_heads, num_tokens, head_size] are allocated. Therefore,
        # we only set num_splits when using cuda graphs.
        within_cudagraph_range = (
            self.use_full_cuda_graph and num_actual_tokens <= self.max_cudagraph_size
        )
        # 0 means use FA3's heuristics, not CG compatible
        max_num_splits = self.max_num_splits if within_cudagraph_range else 0

        if vllm_is_batch_invariant():
            max_num_splits = 1

        def _schedule(batch_size, cu_query_lens, max_query_len, seqlens, max_seq_len, causal):
            """Return AOT scheduler metadata, or None when AOT is disabled."""
            cache_dtype = self.cache_config.cache_dtype
            qkv_dtype = (
                FlashAttentionBackend.get_fp8_dtype_for_flashattn(cache_dtype)
                if cache_dtype.startswith("fp8")
                else self.kv_cache_dtype
            )
            if not aot_schedule:
                return None
            return get_scheduler_metadata(
                batch_size=batch_size,
                max_seqlen_q=max_query_len,
                max_seqlen_k=max_seq_len,
                num_heads_q=self.num_heads_q * self.dcp_world_size,
                num_heads_kv=self.num_heads_kv,
                headdim=self.headdim,
                cache_seqlens=seqlens,
                qkv_dtype=qkv_dtype,
                cu_seqlens_q=cu_query_lens,
                page_size=self.block_size,
                causal=causal,
                window_size=self.aot_sliding_window,
                num_splits=max_num_splits,
            )

        use_cascade = common_prefix_len > 0

        # Defaults for the DCP / cascade specific fields.
        max_dcp_context_kv_len = 0
        dcp_context_kv_lens = None
        cu_prefix_query_lens = None
        prefix_kv_lens = None
        suffix_kv_lens = None
        prefix_scheduler_metadata = None

        if self.dcp_world_size > 1:
            # Context length = sequence length minus this step's query length.
            query_lens = query_start_loc[1:] - query_start_loc[:-1]
            dcp_context_kv_lens = get_dcp_local_seq_lens(
                seq_lens - query_lens,
                self.dcp_world_size,
                self.dcp_rank,
                self.cp_kv_cache_interleave_size,
            )
            # After DCP distribution, the maximum number of tokens for any rank is
            # ceil(L / (N * I)) * I, where L is max_seq_len, N is dcp_world_size,
            # and I is cp_kv_cache_interleave_size.
            # This eliminates GPU->CPU sync while minimizing workspace over-allocation.
            interleave = self.cp_kv_cache_interleave_size
            num_partitions = self.dcp_world_size * interleave
            max_dcp_context_kv_len = (
                (max_seq_len + num_partitions - 1) // num_partitions
            ) * interleave

            scheduler_metadata = _schedule(
                batch_size=num_reqs,
                cu_query_lens=query_start_loc,
                max_query_len=max_query_len,
                seqlens=dcp_context_kv_lens,
                max_seq_len=max_dcp_context_kv_len,
                causal=False,
            )
        elif use_cascade:
            # The shared prefix is attended to as a single "request" that
            # spans all actual tokens.
            cu_prefix_query_lens = torch.tensor(
                [0, num_actual_tokens], dtype=torch.int32, device=self.device
            )
            prefix_kv_lens = torch.tensor(
                [common_prefix_len], dtype=torch.int32, device=self.device
            )
            # Use GPU tensor directly - no CPU sync needed
            suffix_kv_lens = seq_lens[:num_reqs] - common_prefix_len
            prefix_scheduler_metadata = _schedule(
                batch_size=1,
                cu_query_lens=cu_prefix_query_lens,
                max_query_len=num_actual_tokens,
                seqlens=prefix_kv_lens,
                max_seq_len=common_prefix_len,
                causal=False,
            )
            scheduler_metadata = _schedule(
                batch_size=num_reqs,
                cu_query_lens=query_start_loc,
                max_query_len=max_query_len,
                seqlens=suffix_kv_lens,
                max_seq_len=max_seq_len - common_prefix_len,
                causal=True,
            )
        else:
            scheduler_metadata = _schedule(
                batch_size=num_reqs,
                cu_query_lens=query_start_loc,
                max_query_len=max_query_len,
                seqlens=seq_lens,
                max_seq_len=max_seq_len,
                causal=causal,
            )

        # For FA3 + full cudagraph
        if self.use_full_cuda_graph and scheduler_metadata is not None:
            persistent = self.scheduler_metadata
            n = scheduler_metadata.shape[0]
            persistent[:n] = scheduler_metadata
            # NOTE(woosuk): We should zero out the rest of the scheduler
            # metadata to guarantee the correctness. Otherwise, some thread
            # blocks may use the invalid scheduler metadata and overwrite the
            # output buffer.
            persistent[n:] = 0
            scheduler_metadata = persistent[:n]

        return FlashAttentionMetadata(
            num_actual_tokens=num_actual_tokens,
            max_query_len=max_query_len,
            query_start_loc=query_start_loc,
            max_seq_len=max_seq_len,
            seq_lens=seq_lens,
            block_table=block_table_tensor,
            slot_mapping=slot_mapping,
            max_dcp_context_kv_len=max_dcp_context_kv_len,
            dcp_context_kv_lens=dcp_context_kv_lens,
            use_cascade=use_cascade,
            common_prefix_len=common_prefix_len,
            scheduler_metadata=scheduler_metadata,
            cu_prefix_query_lens=cu_prefix_query_lens,
            prefix_kv_lens=prefix_kv_lens,
            suffix_kv_lens=suffix_kv_lens,
            prefix_scheduler_metadata=prefix_scheduler_metadata,
            max_num_splits=max_num_splits,
            causal=causal,
        )

    def use_cascade_attention(self, *args, **kwargs) -> bool:
        """Forward all arguments to the module-level `use_cascade_attention`."""
        return use_cascade_attention(*args, **kwargs)

499

500
class FlashAttentionImpl(AttentionImpl):
501
502
    can_return_lse_for_decode: bool = True

503
504
505
506
507
508
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
        sinks: torch.Tensor | None = None,
    ) -> None:
        """Store the layer hyper-parameters and validate feature support.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each head.
            scale: Softmax scale; stored as a Python float.
            num_kv_heads: Number of key/value heads.
            alibi_slopes: Optional ALiBi slopes; converted to a float32
                tensor when given.
            sliding_window: Optional window size; converted to the
                ``(left, right)`` pair stored in ``self.sliding_window``.
            kv_cache_dtype: KV-cache dtype string.
            logits_soft_cap: Optional soft cap; ``None`` is stored as 0.
            attn_type: Attention type of the layer.
            kv_sharing_target_layer_name: Name of the layer whose KV cache
                is shared, if any.
            sinks: Optional sinks tensor; its first dimension must equal
                ``num_heads``.

        Raises:
            NotImplementedError: If the KV cache is quantized but
                `flash_attn_supports_fp8()` is False.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.alibi_slopes = (
            None
            if alibi_slopes is None
            else torch.tensor(alibi_slopes, dtype=torch.float32)
        )

        if sliding_window is None:
            # (-1, -1) is the value used when no window is configured.
            window = (-1, -1)
        else:
            left = sliding_window - 1
            # Encoder-only layers use the window on both sides; all other
            # attention types only on the left.
            right = left if attn_type == AttentionType.ENCODER_ONLY else 0
            window = (left, right)
        self.sliding_window = window

        self.kv_cache_dtype = kv_cache_dtype
        # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
        self.logits_soft_cap = 0 if logits_soft_cap is None else logits_soft_cap
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        self.attn_type = attn_type
        self.vllm_flash_attn_version = get_flash_attn_version()
        # Cache the batch invariant result for use in forward passes
        self.batch_invariant_enabled = vllm_is_batch_invariant()

        fp8_unsupported = (
            is_quantized_kv_cache(self.kv_cache_dtype)
            and not flash_attn_supports_fp8()
        )
        if fp8_unsupported:
            raise NotImplementedError(
                "FlashAttention does not support fp8 kv-cache on this device."
            )

        self.sinks = sinks
        if sinks is not None:
            assert flash_attn_supports_sinks(), (
                "Sinks are only supported in FlashAttention 3"
            )
            assert sinks.shape[0] == num_heads, (
                "Sinks must have the same number of heads as the number of "
                "heads in the layer"
            )

        self.supports_quant_query_input = True
560

561
562
    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention.

        Args:
            layer: Attention layer; its `_q_scale`, `_k_scale` and
                `_v_scale` attributes are read here.
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape =
                [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention.
            output: Pre-allocated output buffer; required.
            output_scale: Not supported; must be None.
            output_block_scale: Not supported; must be None.
        Returns:
            shape = [num_tokens, num_heads * head_size]

        Raises:
            NotImplementedError: If fused output quantization is requested.

        NOTE: FP8 quantization, flash-attn expect the size of
              {q,k,v}_descale to be (num_sequences, num_kv_heads).
              We use torch's .expand() to avoid duplicating values
        """
        assert output is not None, "Output tensor must be provided."

        if not (output_scale is None and output_block_scale is None):
            raise NotImplementedError(
                "fused output quantization is not yet supported for FlashAttentionImpl"
            )

        if attn_metadata is None:
            # Profiling run.
            return output.fill_(0)

        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
        # in this method. For example, `view` and `slice` (or `[:n]`) operations
        # are surprisingly slow even in the case they do not invoke any GPU ops.
        # Minimize the PyTorch ops in this method as much as possible.
        # Whenever making a change in this method, please benchmark the
        # performance to make sure it does not introduce any overhead.

        num_actual_tokens = attn_metadata.num_actual_tokens

        # Encoder attention is handled separately: it works directly on the
        # Q, K, V tensors without any KV caching.
        if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
            return self._forward_encoder_attention(
                query[:num_actual_tokens],
                key[:num_actual_tokens],
                value[:num_actual_tokens],
                output[:num_actual_tokens],
                attn_metadata,
                layer,
            )

        # For decoder and cross-attention, use KV cache as before
        key_cache, value_cache = kv_cache.unbind(0)

        # key and value may be None in the case of cross attention. They are
        # calculated once based on the output from the encoder and then cached
        # in KV cache.
        owns_kv_cache = self.kv_sharing_target_layer_name is None
        if owns_kv_cache and key is not None and value is not None:
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
            # not padded. However, we don't need to do key[:num_actual_tokens]
            # and value[:num_actual_tokens] because the reshape_and_cache_flash
            # op uses the slot_mapping's shape to determine the number of
            # actual tokens.
            reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

        if self.kv_cache_dtype.startswith("fp8"):
            # queries are quantized in the attention layer
            fp8_dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
                self.kv_cache_dtype
            )
            key_cache = key_cache.view(fp8_dtype)
            value_cache = value_cache.view(fp8_dtype)

        if attn_metadata.use_cascade:
            # Cascade attention (rare case).
            cascade_attention(
                output[:num_actual_tokens],
                query[:num_actual_tokens],
                key_cache,
                value_cache,
                cu_query_lens=attn_metadata.query_start_loc,
                max_query_len=attn_metadata.max_query_len,
                cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens,
                prefix_kv_lens=attn_metadata.prefix_kv_lens,
                suffix_kv_lens=attn_metadata.suffix_kv_lens,
                max_kv_len=attn_metadata.max_seq_len,
                softmax_scale=self.scale,
                alibi_slopes=self.alibi_slopes,
                sliding_window=self.sliding_window,
                logits_soft_cap=self.logits_soft_cap,
                block_table=attn_metadata.block_table,
                common_prefix_len=attn_metadata.common_prefix_len,
                max_num_splits=attn_metadata.max_num_splits,
                fa_version=self.vllm_flash_attn_version,
                prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
                suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
                q_descale=layer._q_scale,
                k_descale=layer._k_scale,
                v_descale=layer._v_scale,
                s_aux=self.sinks,
            )
            return output

        # Regular (non-cascade) attention.
        cu_seqlens_q = attn_metadata.query_start_loc
        # One descale row per sequence: (num_sequences, num_kv_heads).
        descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)
        q_descale = layer._q_scale.expand(descale_shape)
        k_descale = layer._k_scale.expand(descale_shape)
        v_descale = layer._v_scale.expand(descale_shape)

        if self.dcp_world_size > 1:
            self._forward_with_dcp(
                query[:num_actual_tokens],
                key[:num_actual_tokens],
                value[:num_actual_tokens],
                key_cache,
                value_cache,
                output[:num_actual_tokens],
                attn_metadata,
                q_descale=q_descale,
                k_descale=k_descale,
                v_descale=v_descale,
            )
            return output

        flash_attn_varlen_func(
            q=query[:num_actual_tokens],
            k=key_cache,
            v=value_cache,
            out=output[:num_actual_tokens],
            cu_seqlens_q=cu_seqlens_q,
            max_seqlen_q=attn_metadata.max_query_len,
            seqused_k=attn_metadata.seq_lens,
            max_seqlen_k=attn_metadata.max_seq_len,
            softmax_scale=self.scale,
            causal=attn_metadata.causal,
            alibi_slopes=self.alibi_slopes,
            window_size=self.sliding_window,
            block_table=attn_metadata.block_table,
            softcap=self.logits_soft_cap,
            scheduler_metadata=attn_metadata.scheduler_metadata,
            fa_version=self.vllm_flash_attn_version,
            q_descale=q_descale,
            k_descale=k_descale,
            v_descale=v_descale,
            num_splits=attn_metadata.max_num_splits,
            s_aux=self.sinks,
        )
        return output
740

741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
    def _forward_with_dcp(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        output: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        q_descale: torch.Tensor | None = None,
        k_descale: torch.Tensor | None = None,
        v_descale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Attention forward pass for the DCP path (uses ``get_dcp_group()``).

        Two FlashAttention passes are run and then merged into ``output``:

        1. Context pass: the query is all-gathered across the DCP group along
           dim 1 and attends non-causally to the paged ``key_cache`` /
           ``value_cache`` through the block table, with KV lengths taken
           from ``attn_metadata.dcp_context_kv_lens``. The result and its LSE
           are then passed through ``cp_lse_ag_out_rs`` with the DCP group
           (presumably an LSE-weighted cross-rank combination; confirm
           against that helper).
        2. Query pass: the local query attends to the ``key`` / ``value``
           tensors passed in, which share the query's cumulative sequence
           lengths; causality follows ``attn_metadata.causal``.

        Args:
            query: Query tensor; made contiguous before the all-gather.
            key: Keys used by the query pass.
            value: Values used by the query pass.
            key_cache: Paged key cache used by the context pass.
            value_cache: Paged value cache used by the context pass.
            output: Buffer that receives the merged attention result.
            attn_metadata: Metadata providing sequence offsets, block table,
                DCP context lengths and scheduler metadata.
            q_descale: Optional descale tensor forwarded to both passes.
            k_descale: Optional descale tensor forwarded to both passes.
            v_descale: Optional descale tensor forwarded to both passes.

        Returns:
            ``output``, after ``merge_attn_states`` has written into it.
        """
        cu_seqlens_q = attn_metadata.query_start_loc
        max_seqlen_q = attn_metadata.max_query_len
        block_table = attn_metadata.block_table
        dcp_group = get_dcp_group()

        query = query.contiguous()
        query_across_dcp = dcp_group.all_gather(query, dim=1)

        # NOTE(review): unlike the non-DCP call in forward(), neither pass
        # forwards `num_splits` or `s_aux` -- confirm this is intentional.
        context_attn_out, context_lse = flash_attn_varlen_func(
            q=query_across_dcp,
            k=key_cache,
            v=value_cache,
            out=None,
            cu_seqlens_q=cu_seqlens_q,
            max_seqlen_q=max_seqlen_q,
            seqused_k=attn_metadata.dcp_context_kv_lens,
            max_seqlen_k=attn_metadata.max_dcp_context_kv_len,
            softmax_scale=self.scale,
            causal=False,
            alibi_slopes=self.alibi_slopes,
            window_size=self.sliding_window,
            block_table=block_table,
            softcap=self.logits_soft_cap,
            return_softmax_lse=True,
            scheduler_metadata=attn_metadata.scheduler_metadata,
            fa_version=self.vllm_flash_attn_version,
            q_descale=q_descale,
            k_descale=k_descale,
            v_descale=v_descale,
        )
        # FA returns LSE in shape [ H, B ] but cp_lse_ag_out_rs wants [ B, H ]
        context_attn_out_cor, context_lse_cor = cp_lse_ag_out_rs(
            context_attn_out,
            context_lse.transpose(0, 1),
            dcp_group,
            return_lse=True,
        )
        # Back to the [ H, B ] layout so it lines up with query_lse below.
        context_lse_cor = context_lse_cor.transpose(0, 1).contiguous()

        query_attn_out, query_lse = flash_attn_varlen_func(
            q=query,
            k=key,
            v=value,
            out=None,
            cu_seqlens_q=cu_seqlens_q,
            max_seqlen_q=max_seqlen_q,
            cu_seqlens_k=cu_seqlens_q,
            max_seqlen_k=max_seqlen_q,
            softmax_scale=self.scale,
            causal=attn_metadata.causal,
            alibi_slopes=self.alibi_slopes,
            window_size=self.sliding_window,
            softcap=self.logits_soft_cap,
            return_softmax_lse=True,
            fa_version=self.vllm_flash_attn_version,
            q_descale=q_descale,
            k_descale=k_descale,
            v_descale=v_descale,
        )
        assert context_attn_out_cor.shape == query_attn_out.shape
        assert context_lse_cor.shape == query_lse.shape
        merge_attn_states(
            output,
            context_attn_out_cor,
            context_lse_cor,
            query_attn_out,
            query_lse,
        )
        # Honour the declared return type; previously this fell through and
        # implicitly returned None.
        return output

821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
    def _forward_encoder_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        output: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        layer: torch.nn.Module,
    ) -> torch.Tensor:
        """Run bidirectional encoder attention directly on Q/K/V (no KV cache).

        Args:
            query: shape = [num_encoder_tokens, num_heads, head_size]
            key: shape = [num_encoder_tokens, num_kv_heads, head_size]
            value: shape = [num_encoder_tokens, num_kv_heads, head_size]
            output: shape = [num_encoder_tokens, num_heads, head_size]
            attn_metadata: Encoder attention metadata
            layer: The attention layer; its ``_q_scale`` / ``_k_scale`` /
                ``_v_scale`` are forwarded as descale factors.

        Returns:
            ``output``, which the kernel fills in place via ``out=``.

        Raises:
            NotImplementedError: if ``self.kv_cache_dtype`` starts with
                ``"fp8"``.
        """
        # FP8 KV-cache dtypes are rejected up front on this path.
        if self.kv_cache_dtype.startswith("fp8"):
            raise NotImplementedError(
                "quantization is not supported for encoder attention"
            )

        # Queries and keys share one sequence layout here, so the same
        # offsets and maximum length serve both sides of the kernel call.
        seq_starts = attn_metadata.query_start_loc
        longest_seq = attn_metadata.max_query_len

        # One descale row per sequence, one column per KV head.
        num_seqs = seq_starts.shape[0] - 1  # type: ignore[union-attr]
        scale_shape = (num_seqs, self.num_kv_heads)
        q_scale = layer._q_scale.expand(scale_shape)
        k_scale = layer._k_scale.expand(scale_shape)
        v_scale = layer._v_scale.expand(scale_shape)

        if self.batch_invariant_enabled:
            splits = 1
        else:
            splits = 0

        flash_attn_varlen_func(
            q=query,
            k=key,
            v=value,
            out=output,
            cu_seqlens_q=seq_starts,
            cu_seqlens_k=seq_starts,
            max_seqlen_q=longest_seq,
            max_seqlen_k=longest_seq,
            softmax_scale=self.scale,
            causal=False,  # Encoder attention is bidirectional
            alibi_slopes=self.alibi_slopes,
            window_size=self.sliding_window,
            softcap=self.logits_soft_cap,
            fa_version=self.vllm_flash_attn_version,
            q_descale=q_scale,
            k_descale=k_scale,
            v_descale=v_scale,
            num_splits=splits,
        )

        return output

881
882
883
884
885
886
887
888

def use_cascade_attention(
    common_prefix_len: int,
    query_lens: np.ndarray,
    num_query_heads: int,
    num_kv_heads: int,
    use_alibi: bool,
    use_sliding_window: bool,
    use_local_attention: bool,
    num_sms: int,
    dcp_world_size: int,
) -> bool:
    """Decide whether to use cascade attention.

    This function 1) checks whether cascade attention is supported with the
    given configuration, and 2) heuristically decides whether using cascade
    attention can improve performance.

    Args:
        common_prefix_len: Number of tokens in the prefix shared by all
            requests.
        query_lens: Per-request query lengths; its length is the number of
            requests.
        num_query_heads: Number of query heads.
        num_kv_heads: Number of KV heads.
        use_alibi: Whether ALiBi is in use (unsupported by cascade).
        use_sliding_window: Whether sliding window is in use (unsupported).
        use_local_attention: Whether local attention is in use (unsupported).
        num_sms: SM count used by the wave/CTA performance model.
        dcp_world_size: DCP world size; cascade is disabled when it is > 1.

    Returns:
        True if cascade attention should be used, False otherwise.
    """
    # Arbitrary thresholds. TODO: Tune these thresholds.
    min_common_prefix_len = 256
    min_num_reqs = 8

    # Too short common prefix. Probably not worth using cascade attention.
    # NOTE(woosuk): This is the common case. We should return False as soon as
    # possible to avoid any unnecessary computation.
    if common_prefix_len < min_common_prefix_len:
        return False
    # Cascade attention is currently not supported with these variants.
    if use_alibi or use_sliding_window or use_local_attention:
        return False
    # Too few queries. Probably not worth using cascade attention.
    num_reqs = len(query_lens)
    if num_reqs < min_num_reqs:
        return False
    # disable cascade attention for DCP
    if dcp_world_size > 1:
        return False

    # Heuristics to decide whether using cascade attention is beneficial.
    # 1. When FlashDecoding is not used for normal attention, cascade attention
    #    is likely to be faster since it saves memory bandwidth.
    num_queries_per_kv = num_query_heads // num_kv_heads
    # The criteria for using FlashDecoding can be found in the following link:
    # https://github.com/vllm-project/flash-attention/blob/96266b1111111f3d11aabefaf3bacbab6a89d03c/csrc/flash_attn/flash_api.cpp#L535
    # The "no sliding window" and "no ALiBi" parts of that criteria are
    # already guaranteed by the early return above, so only the remaining
    # two conditions need to be evaluated here.
    use_flash_decoding = num_queries_per_kv > 1 and bool(np.all(query_lens == 1))
    if not use_flash_decoding:
        # Use cascade attention.
        return True

    # 2. When FlashDecoding is used for normal attention, it is not clear
    #    whether cascade attention is beneficial, because FlashDecoding can
    #    launch more CTAs than cascade attention.
    #    We use a simple performance model to compare the two methods.
    #    NOTE(woosuk): The performance model is very rough and may not be
    #    accurate.
    num_tokens = num_reqs
    # NOTE(woosuk): These are default tile sizes. flash-attn might use
    # different tile sizes (e.g., 64 or 256) depending on the configuration.
    q_tile_size = 128
    kv_tile_size = 128
    num_prefix_tiles = cdiv(common_prefix_len, kv_tile_size)

    cascade_ctas = num_query_heads * cdiv(num_tokens, q_tile_size)
    cascade_waves = cdiv(cascade_ctas, num_sms)
    cascade_time = cascade_waves * num_prefix_tiles

    flash_decoding_ctas = (
        num_reqs * num_kv_heads * cdiv(num_queries_per_kv, q_tile_size)
    )
    flash_decoding_ctas *= num_prefix_tiles
    flash_decoding_time = cdiv(flash_decoding_ctas, num_sms)

    # Use cascade attention if it is faster than FlashDecoding.
    return cascade_time < flash_decoding_time


def cascade_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    cu_query_lens: torch.Tensor,
    max_query_len: int,
    cu_prefix_query_lens: torch.Tensor,
968
969
    prefix_kv_lens: torch.Tensor,
    suffix_kv_lens: torch.Tensor,
970
971
    max_kv_len: int,
    softmax_scale: float,
972
    alibi_slopes: torch.Tensor | None,
973
    sliding_window: tuple[int, int],
974
975
976
    logits_soft_cap: float,
    block_table: torch.Tensor,
    common_prefix_len: int,
977
    max_num_splits: int,
978
    fa_version: int,
979
980
981
982
983
984
    prefix_scheduler_metadata: torch.Tensor | None = None,
    suffix_scheduler_metadata: torch.Tensor | None = None,
    q_descale: torch.Tensor | None = None,
    k_descale: torch.Tensor | None = None,
    v_descale: torch.Tensor | None = None,
    s_aux: torch.Tensor | None = None,
985
) -> torch.Tensor:
986
    assert alibi_slopes is None, "Cascade attention does not support ALiBi."
987
988
    # TODO: Support sliding window.
    assert sliding_window == (-1, -1), (
989
990
        "Cascade attention does not support sliding window."
    )
991
992
993
994
995
996

    num_tokens = query.shape[0]
    block_size = key_cache.shape[-3]
    assert common_prefix_len % block_size == 0
    num_common_kv_blocks = common_prefix_len // block_size
    assert num_common_kv_blocks > 0
997
    descale_shape = (cu_prefix_query_lens.shape[0] - 1, key_cache.shape[-2])
998
999
1000
1001
1002
1003
1004

    # Process shared prefix.
    prefix_output, prefix_lse = flash_attn_varlen_func(
        q=query,
        k=key_cache,
        v=value_cache,
        cu_seqlens_q=cu_prefix_query_lens,
1005
        seqused_k=prefix_kv_lens,
1006
1007
1008
1009
1010
1011
1012
1013
        max_seqlen_q=num_tokens,
        max_seqlen_k=common_prefix_len,
        softmax_scale=softmax_scale,
        causal=False,
        window_size=sliding_window,
        block_table=block_table[:1],
        softcap=logits_soft_cap,
        return_softmax_lse=True,
1014
        scheduler_metadata=prefix_scheduler_metadata,
1015
        fa_version=fa_version,
1016
1017
1018
        q_descale=q_descale.expand(descale_shape) if q_descale is not None else None,
        k_descale=k_descale.expand(descale_shape) if k_descale is not None else None,
        v_descale=v_descale.expand(descale_shape) if v_descale is not None else None,
1019
1020
1021
        # s_aux is incorporated into prefix_lse inside the GPU kernel,
        # enabling its effect during the final attention merge.
        s_aux=s_aux,
1022
        num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
1023
1024
    )

1025
1026
    descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])

1027
1028
1029
1030
1031
1032
    # Process suffix per query.
    suffix_output, suffix_lse = flash_attn_varlen_func(
        q=query,
        k=key_cache,
        v=value_cache,
        cu_seqlens_q=cu_query_lens,
1033
        seqused_k=suffix_kv_lens,
1034
1035
1036
1037
1038
1039
1040
1041
        max_seqlen_q=max_query_len,
        max_seqlen_k=max_kv_len - common_prefix_len,
        softmax_scale=softmax_scale,
        causal=True,
        window_size=sliding_window,
        block_table=block_table[:, num_common_kv_blocks:],
        softcap=logits_soft_cap,
        return_softmax_lse=True,
1042
        scheduler_metadata=suffix_scheduler_metadata,
1043
        fa_version=fa_version,
1044
1045
1046
        q_descale=q_descale.expand(descale_shape) if q_descale is not None else None,
        k_descale=k_descale.expand(descale_shape) if k_descale is not None else None,
        v_descale=v_descale.expand(descale_shape) if v_descale is not None else None,
1047
        num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
1048
1049
1050
    )

    # Merge prefix and suffix outputs, and store the result in output.
1051
    merge_attn_states(output, prefix_output, prefix_lse, suffix_output, suffix_lse)