flash_attn.py 59.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Attention layer with FlashAttention."""
4

5
import copy
6
from dataclasses import dataclass
7
from typing import ClassVar
8

9
import numpy as np
10
11
import torch

12
13
from vllm.attention.layer import Attention
from vllm.v1.attention.backend import (
14
15
16
    AttentionBackend,
    AttentionImpl,
    AttentionType,
17
    MultipleOf,
18
19
    is_quantized_kv_cache,
)
20
from vllm.v1.attention.backends.fa_utils import (
21
22
23
24
    flash_attn_supports_fp8,
    get_flash_attn_version,
    is_flash_attn_varlen_func_available,
)
25
26
from vllm.v1.attention.ops.common import cp_lse_ag_out_rs
from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
27

zhuwenwen's avatar
zhuwenwen committed
28
from vllm.platforms import current_platform
29
if is_flash_attn_varlen_func_available():
zhuwenwen's avatar
zhuwenwen committed
30
    if current_platform.is_rocm():
31
        from vllm.v1.attention.backends.fa_utils import (
32
            flash_attn_supports_sinks,
zhuwenwen's avatar
zhuwenwen committed
33
34
            vllm_flash_attn_varlen_func,
            reshape_and_cache_cuda,
35
        )
36
37
38
39
40
41
42
        from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
            triton_reshape_and_cache_flash,
        )
        try:
            from flash_attn import varlen_fwd_unified
        except Exception:
            varlen_fwd_unified = None
zhuwenwen's avatar
zhuwenwen committed
43
    else:
44
        from vllm.v1.attention.backends.fa_utils import (
45
            flash_attn_supports_sinks,
zhuwenwen's avatar
zhuwenwen committed
46
47
48
            flash_attn_varlen_func,
            get_scheduler_metadata,
            reshape_and_cache_flash,
49
        )
50

51
from vllm.config import VllmConfig, get_current_vllm_config, get_layers_from_vllm_config
52
from vllm.config.cache import CacheDType
53
from vllm.distributed.parallel_state import get_dcp_group
54
from vllm.logger import init_logger
55
from vllm.model_executor.layers.batch_invariant import (
56
    vllm_is_batch_invariant,
57
)
58
from vllm.platforms.interface import DeviceCapability
59
from vllm.utils.math_utils import cdiv
60
from vllm.v1.attention.backend import (
61
62
63
    AttentionCGSupport,
    AttentionMetadataBuilder,
    CommonAttentionMetadata,
64
65
)
from vllm.v1.attention.backends.utils import (
66
    get_dcp_local_seq_lens,
67
68
    get_kv_cache_layout,
)
69
from vllm.v1.kv_cache_interface import AttentionSpec
zhuwenwen's avatar
zhuwenwen committed
70
import vllm.envs as envs
71

72
73
logger = init_logger(__name__)

74
75

class FlashAttentionBackend(AttentionBackend):
    """FlashAttention attention backend.

    On ROCm the KV cache is described as two separate shapes (key, value);
    on other platforms it is one stacked tensor whose leading dimension is 2.
    """

    accept_output_buffer: bool = True
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        """Return the kernel block sizes this backend may be used with.

        Hybrid models that keep mamba state in float32 are restricted to a
        fixed list; otherwise any multiple of 16 is accepted.
        """
        vllm_config = get_current_vllm_config()
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        if (
            model_config
            and model_config.is_hybrid
            and (
                cache_config.mamba_ssm_cache_dtype == "float32"
                or cache_config.mamba_cache_dtype == "float32"
            )
        ):
            # NOTE(tdoublep): while in principle, FA supports
            # MultipleOf(16), these are the block sizes that do not
            # suffer from the NaN propagation problem described here:
            # https://github.com/Dao-AILab/flash-attention/issues/1974
            return [16, 32, 64]
        return [MultipleOf(16)]

    forward_includes_kv_cache_update: bool = False

    @staticmethod
    def get_name() -> str:
        return "FLASH_ATTN"

    @classmethod
    def supports_attn_type(cls, attn_type: str) -> bool:
        """FlashAttention supports all attention types."""
        return attn_type in (
            AttentionType.DECODER,
            AttentionType.ENCODER,
            AttentionType.ENCODER_ONLY,
            AttentionType.ENCODER_DECODER,
        )

    @staticmethod
    def get_impl_cls() -> type["FlashAttentionImpl"]:
        return FlashAttentionImpl

    @staticmethod
    def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]:
        return FlashAttentionMetadataBuilder

    @classmethod
    def supports_alibi_sqrt(cls) -> bool:
        return True

    @classmethod
    def supports_mm_prefix(cls) -> bool:
        return True

    @staticmethod
    def _use_rocm_unified_kv_layout(
        block_size: int | None = None,
        key_cache: torch.Tensor | None = None,
        value_cache: torch.Tensor | None = None,
    ) -> bool:
        """Decide whether the ROCm "unified" KV layout applies.

        The block size is taken from ``block_size`` if given. Otherwise it is
        read from ``key_cache.shape[1]`` when both caches are 4-D tensors of
        identical shape, and otherwise from the current vLLM cache config.
        Returns True only on ROCm, and only for block sizes that are
        multiples of 64 other than 64 itself. Any failure to determine the
        block size yields False.
        """
        if not current_platform.is_rocm():
            return False

        if block_size is None:
            if key_cache is not None and value_cache is not None:
                if key_cache.ndim != 4 or value_cache.ndim != 4:
                    return False
                if key_cache.shape != value_cache.shape:
                    return False
                block_size = key_cache.shape[1]
            else:
                try:
                    block_size = get_current_vllm_config().cache_config.block_size
                except Exception:
                    # Best effort: no usable config means "not unified".
                    return False

        return block_size is not None and block_size != 64 and block_size % 64 == 0

    if current_platform.is_rocm():

        @staticmethod
        def get_kv_cache_shape(
            num_blocks: int,
            block_size: int,
            num_kv_heads: int,
            head_size: int,
            cache_dtype_str: str = "auto",
        ) -> tuple[tuple[int, ...], tuple[int, ...]]:
            """Return the (key_shape, value_shape) pair for the ROCm KV cache.

            Raises:
                ValueError: if ``block_size`` is not a multiple of 16.
            """
            if block_size % 16 != 0:
                raise ValueError("Block size must be a multiple of 16.")

            if FlashAttentionBackend._use_rocm_unified_kv_layout(block_size):
                unified_shape = (num_blocks, block_size, num_kv_heads, head_size)
                return (unified_shape, unified_shape)

            # Default ROCm layout: value keeps head_size before block_size.
            return (
                (num_blocks, num_kv_heads, block_size, head_size),
                (num_blocks, num_kv_heads, head_size, block_size),
            )

        @staticmethod
        def get_kv_cache_stride_order(
            include_num_layers_dimension: bool = False,
        ) -> tuple[tuple[int, ...], tuple[int, ...]]:
            """Return (key_order, value_order) permutations for the KV cache.

            With ``include_num_layers_dimension`` the permutations are 5-D and
            presumably assume a num_layers dimension prepended at index 0 of
            each logical shape (the orders move it next to num_blocks).

            Raises:
                RuntimeError: unified layout requested with a non-NHD layout.
                ValueError: unknown cache layout string.
            """
            # `stride_order` indicates the permutation that gets
            # us from `get_kv_cache_shape` to the actual memory layout we want.
            cache_layout = get_kv_cache_layout()
            if FlashAttentionBackend._use_rocm_unified_kv_layout():
                if cache_layout != "NHD":
                    raise RuntimeError(
                        "ROCm unified KV layout currently supports NHD only."
                    )
                if include_num_layers_dimension:
                    # (num_blocks, num_layers, block_size, num_kv_heads, head_size)
                    return (1, 0, 2, 3, 4), (1, 0, 2, 3, 4)
                key_stride_order = (0, 1, 2, 3)
                value_stride_order = (0, 1, 2, 3)
            else:
                if cache_layout == "NHD" and include_num_layers_dimension:
                    # (num_blocks, num_layers, block_size, num_kv_heads, head_size)
                    # Key logical: (L, B, H, S, D); value logical: (L, B, H, D, S).
                    # The key order previously contained an out-of-range 5.
                    return (1, 0, 3, 2, 4), (1, 0, 4, 2, 3)
                elif cache_layout == "NHD":
                    key_stride_order = (0, 1, 2, 3)
                    value_stride_order = (0, 1, 2, 3)
                elif cache_layout == "HND" and include_num_layers_dimension:
                    # (num_blocks, num_kv_heads, num_layers, block_size, head_size)
                    return (1, 2, 0, 3, 4), (1, 2, 0, 4, 3)
                elif cache_layout == "HND":
                    key_stride_order = (0, 1, 2, 3)
                    value_stride_order = (0, 1, 3, 2)
                else:
                    raise ValueError(f"Unknown cache layout format {cache_layout}.")
            return key_stride_order, value_stride_order

    else:

        @staticmethod
        def get_kv_cache_shape(
            num_blocks: int,
            block_size: int,
            num_kv_heads: int,
            head_size: int,
            cache_dtype_str: str = "auto",
        ) -> tuple[int, ...]:
            """Return the stacked (2, num_blocks, block_size, heads, dim) shape.

            Raises:
                ValueError: if ``block_size`` is not a multiple of 16.
            """
            if block_size % 16 != 0:
                raise ValueError("Block size must be a multiple of 16.")
            return (2, num_blocks, block_size, num_kv_heads, head_size)

        @staticmethod
        def get_kv_cache_stride_order(
            include_num_layers_dimension: bool = False,
        ) -> tuple[int, ...]:
            """Return the permutation for the stacked KV cache.

            Raises:
                ValueError: unknown cache layout string.
            """
            # `stride_order` indicates the permutation that gets
            # us from `get_kv_cache_shape` to the actual memory layout we want.
            cache_layout = get_kv_cache_layout()
            if cache_layout == "NHD" and include_num_layers_dimension:
                # (num_blocks, num_layers, 2, block_size, num_kv_heads, head_size)
                return (2, 0, 1, 3, 4, 5)
            elif cache_layout == "NHD":
                stride_order = (0, 1, 2, 3, 4)
            elif cache_layout == "HND" and include_num_layers_dimension:
                # (num_blocks, num_kv_heads, num_layers, 2, block_size, head_size)
                return (2, 4, 0, 1, 3, 5)
            elif cache_layout == "HND":
                stride_order = (0, 1, 3, 2, 4)
            else:
                raise ValueError(f"Unknown cache layout format {cache_layout}.")
            return stride_order

    @staticmethod
    def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype:
        """Map an FP8 kv-cache dtype string to the torch dtype used by FA.

        Raises:
            ValueError: "fp8"/"fp8_e4m3" requested on a device whose arch is
                not gfx938, or an unrecognized dtype string.
        """
        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
            arch = torch.cuda.get_device_properties("cuda").gcnArchName.split(":")[0]
            if arch == "gfx938":
                return torch.float8_e4m3fn
            raise ValueError(f"{kv_cache_dtype} only supported on nmz")
        # Exact comparison: `x in ("fp8_e5m2")` is a substring test on a
        # bare str and would accept e.g. "e5m2" or "".
        if kv_cache_dtype == "fp8_e5m2":
            return torch.float8_e5m2
        raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")

    @classmethod
    def supports_head_size(cls, head_size: int) -> bool:
        """Head sizes must be multiples of 8 and at most 256."""
        return head_size % 8 == 0 and head_size <= 256

    @classmethod
    def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool:
        """Accept None/"auto"/"bfloat16"; fp8 variants depend on FA fp8 support."""
        if kv_cache_dtype is None:
            return True
        if kv_cache_dtype.startswith("fp8"):
            return flash_attn_supports_fp8()
        return kv_cache_dtype in ["auto", "bfloat16"]

    @classmethod
    def supports_sink(cls) -> bool:
        """Sinks require the varlen function and FA-side sink support."""
        if not is_flash_attn_varlen_func_available():
            return False
        return flash_attn_supports_sinks()

    @classmethod
    def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
        return capability >= DeviceCapability(8, 0)

    @classmethod
    def supports_combination(
        cls,
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: CacheDType | None,
        block_size: int | None,
        use_mla: bool,
        has_sink: bool,
        use_sparse: bool,
        device_capability: DeviceCapability,
    ) -> str | None:
        """Return a reason string if the combination is unsupported, else None."""
        if has_sink and device_capability < DeviceCapability(9, 0):
            return "sink not supported on compute capability < 9.0"
        return None

292
293
294
295
296
297
298
299
300
301
302

@dataclass
class FlashAttentionMetadata:
    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

303
    num_actual_tokens: int  # Number of tokens excluding padding.
304
305
306
    max_query_len: int
    query_start_loc: torch.Tensor
    max_seq_len: int
307
    seq_lens: torch.Tensor
308
309
    block_table: torch.Tensor
    slot_mapping: torch.Tensor
310
311
312
313

    # For cascade attention.
    use_cascade: bool
    common_prefix_len: int
314
315
316
    cu_prefix_query_lens: torch.Tensor | None
    prefix_kv_lens: torch.Tensor | None
    suffix_kv_lens: torch.Tensor | None
317

318
319
320
    # For GQA DCP
    max_dcp_context_kv_len: int | None = None
    dcp_context_kv_lens: torch.Tensor | None = None
321

322
    # Optional aot scheduling
323
324
    scheduler_metadata: torch.Tensor | None = None
    prefix_scheduler_metadata: torch.Tensor | None = None
325
    max_num_splits: int = 0
326

327
328
    mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None
    qq_bias: torch.Tensor | None = None
329
330
    causal: bool = True

331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
    @property
    def mm_prefix_range_tensor(self) -> torch.Tensor | None:
        if self.mm_prefix_range is None:
            return None

        num_seqs = self.seq_lens.shape[0]
        device = self.seq_lens.device
        range_lists = [
            self.mm_prefix_range.get(i, [(0, 0)]) or [(0, 0)]
            for i in range(num_seqs)
        ]

        if all(r == [(0, 0)] for r in range_lists):
            return None

        range_tensors = [
            torch.tensor(r, dtype=torch.int32, device=device).view(-1, 2)
            for r in range_lists
        ]

        return torch.nested.nested_tensor(
            range_tensors, layout=torch.jagged
        ).to_padded_tensor(0)

355

356
def _get_sliding_window_configs(
    vllm_config: VllmConfig,
) -> set[tuple[int, int] | None]:
    """Collect the distinct ``sliding_window`` values of all attention layers.

    Every ``Attention`` layer in the config must use ``FlashAttentionImpl``;
    this is checked with an assert.
    """
    attention_layers = get_layers_from_vllm_config(vllm_config, Attention)
    found: set[tuple[int, int] | None] = set()
    for impl in (layer.impl for layer in attention_layers.values()):
        assert isinstance(impl, FlashAttentionImpl)
        found.add(impl.sliding_window)
    return found


368
class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetadata]):
    """Build ``FlashAttentionMetadata`` for each scheduled batch.

    Also owns the optional FA3 ahead-of-time (AOT) scheduler state and, when
    full CUDA graphs and AOT scheduling are both enabled, a persistent
    scheduler-metadata buffer.
    """

    # FA3:
    # Supports full cudagraphs for all cases.
    #
    # FA2:
    # For FA2, a graph is captured with max_query_len=1, (which is what we
    # capture by default for num_tokens <= max_num_seqs when there is no
    # spec-decode) then these graphs will not work for mixed prefill-decode
    # (unlike FA3). This is due to special max_query_len=1 packed-GQA handling
    # in FA2.
    # In summary if we are running with spec decodes the graphs would
    # work for mixed prefill-decode and uniform-decode. But for non-spec decodes
    # the graphs would not work for mixed prefill-decode; sorta the inverse
    # of UNIFORM_SINGLE_TOKEN_DECODE.
    # There's probably a better way to describe this using `AttentionCGSupport`
    # but for now just set it to `UNIFORM_BATCH` to get us to drop down
    # to FULL_AND_PIECEWISE.
    # TODO(luka, lucas): audit FA2 as part of:
    #  https://github.com/vllm-project/vllm/issues/22945
    _cudagraph_support = (
        AttentionCGSupport.ALWAYS
        if get_flash_attn_version() == 3 or current_platform.is_rocm()
        else AttentionCGSupport.UNIFORM_BATCH
    )
    supports_update_block_table: bool = True

    @classmethod
    def get_cudagraph_support(
        cls,
        vllm_config: "VllmConfig",
        kv_cache_spec: "AttentionSpec",
    ) -> AttentionCGSupport:
        """Return the class-level CUDA graph support (arguments are unused)."""
        return cls._cudagraph_support

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
        self.model_config = vllm_config.model_config
        self.parallel_config = vllm_config.parallel_config
        self.cache_config = vllm_config.cache_config
        self.compilation_config = vllm_config.compilation_config
        self.attention_config = vllm_config.attention_config

        self.num_heads_q = self.model_config.get_num_attention_heads(
            self.parallel_config
        )
        self.num_heads_kv = self.model_config.get_num_kv_heads(self.parallel_config)
        self.kv_cache_dtype = kv_cache_spec.dtype
        self.headdim = self.model_config.get_head_size()
        self.block_size = kv_cache_spec.block_size

        self.max_num_splits = 0  # No upper bound on the number of splits.
        self.aot_schedule = get_flash_attn_version() == 3

        # `get_dcp_group` is imported at module level.
        try:
            dcp_group = get_dcp_group()
            self.dcp_world_size = dcp_group.world_size
            self.dcp_rank = dcp_group.rank_in_group
        except AssertionError:
            # DCP might not be initialized in testing
            self.dcp_world_size = 1
            self.dcp_rank = 0

        self.cp_kv_cache_interleave_size = (
            self.parallel_config.cp_kv_cache_interleave_size
        )

        self.use_full_cuda_graph = (
            self.compilation_config.cudagraph_mode.has_full_cudagraphs()
        )
        self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size

        if self.use_full_cuda_graph and self.aot_schedule:
            self.scheduler_metadata = torch.zeros(
                vllm_config.scheduler_config.max_num_seqs + 1,
                dtype=torch.int32,
                device=self.device,
            )
            # When using cuda graph, we need to set the upper bound of the
            # number of splits so that large enough intermediate buffers are
            # pre-allocated during capture.
            self.max_num_splits = (
                self.attention_config.flash_attn_max_num_splits_for_cuda_graph
            )

        # Sliding window used by the AOT scheduler. It is resolved lazily on
        # the first build() that actually performs AOT scheduling.
        self.aot_sliding_window: tuple[int, int] | None = None

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> FlashAttentionMetadata:
        """
        fast_build disables AOT scheduling, used when there will be few
        iterations i.e. spec-decode
        """
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        max_query_len = common_attn_metadata.max_query_len
        max_seq_len = common_attn_metadata.max_seq_len
        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens
        block_table_tensor = common_attn_metadata.block_table_tensor
        slot_mapping = common_attn_metadata.slot_mapping
        causal = common_attn_metadata.causal

        # the overhead of the aot schedule is not worth it for spec-decode
        aot_schedule = self.aot_schedule and not fast_build

        # The AOT scheduler needs one sliding-window value shared by all
        # layers. Layers are not constructed yet in __init__, so resolve it
        # here. Only do so on a build that actually uses AOT scheduling:
        # resolving during a fast_build call would pin the window to (-1, -1)
        # and skip disabling AOT for models with mixed windows.
        if aot_schedule and self.aot_sliding_window is None:
            self.aot_sliding_window = (-1, -1)
            sliding_window_configs = _get_sliding_window_configs(self.vllm_config)
            if len(sliding_window_configs) == 1:
                sliding_window_config = sliding_window_configs.pop()
                if sliding_window_config is not None:
                    self.aot_sliding_window = sliding_window_config
            elif len(sliding_window_configs) > 1:
                self.aot_schedule = False
                aot_schedule = False

        max_num_splits = 0  # 0 means use FA3's heuristics, not CG compatible
        if (
            self.use_full_cuda_graph
            and self.max_cudagraph_size is not None
            and num_actual_tokens <= self.max_cudagraph_size
        ):
            # NOTE(woosuk): Setting num_splits > 1 may increase the memory
            # usage, because the intermediate buffers of size [num_splits,
            # num_heads, num_tokens, head_size] are allocated. Therefore,
            # we only set num_splits when using cuda graphs.
            max_num_splits = self.max_num_splits

        if vllm_is_batch_invariant():
            max_num_splits = 1

        def schedule(
            batch_size, cu_query_lens, max_query_len, seqlens, max_seq_len, causal
        ):
            """Return AOT scheduler metadata, or None when AOT is disabled."""
            cache_dtype = self.cache_config.cache_dtype
            if cache_dtype.startswith("fp8"):
                qkv_dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
                    cache_dtype
                )
            else:
                qkv_dtype = self.kv_cache_dtype
            if aot_schedule:
                return get_scheduler_metadata(
                    batch_size=batch_size,
                    max_seqlen_q=max_query_len,
                    max_seqlen_k=max_seq_len,
                    num_heads_q=self.num_heads_q * self.dcp_world_size,
                    num_heads_kv=self.num_heads_kv,
                    headdim=self.headdim,
                    cache_seqlens=seqlens,
                    qkv_dtype=qkv_dtype,
                    cu_seqlens_q=cu_query_lens,
                    page_size=self.block_size,
                    causal=causal,
                    window_size=self.aot_sliding_window,
                    num_splits=max_num_splits,
                )
            return None

        use_cascade = common_prefix_len > 0
        max_dcp_context_kv_len = 0
        dcp_context_kv_lens = None

        cu_prefix_query_lens = None
        prefix_kv_lens = None
        suffix_kv_lens = None
        prefix_scheduler_metadata = None

        if self.dcp_world_size > 1:
            query_kv_lens = query_start_loc[1:] - query_start_loc[:-1]
            dcp_context_kv_lens = seq_lens - query_kv_lens

            dcp_context_kv_lens = get_dcp_local_seq_lens(
                dcp_context_kv_lens,
                self.dcp_world_size,
                self.dcp_rank,
                self.cp_kv_cache_interleave_size,
            )
            # After DCP distribution, the maximum number of tokens for any rank is
            # ceil(L / (N * I)) * I, where L is max_seq_len, N is dcp_world_size,
            # and I is cp_kv_cache_interleave_size.
            # This eliminates GPU->CPU sync while minimizing workspace over-allocation.
            num_partitions = self.dcp_world_size * self.cp_kv_cache_interleave_size
            max_dcp_context_kv_len = (
                (max_seq_len + num_partitions - 1) // num_partitions
            ) * self.cp_kv_cache_interleave_size

            scheduler_metadata = schedule(
                batch_size=num_reqs,
                cu_query_lens=query_start_loc,
                max_query_len=max_query_len,
                seqlens=dcp_context_kv_lens,
                max_seq_len=max_dcp_context_kv_len,
                causal=False,
            )
        elif use_cascade:
            cu_prefix_query_lens = torch.tensor(
                [0, num_actual_tokens], dtype=torch.int32, device=self.device
            )
            prefix_kv_lens = torch.tensor(
                [common_prefix_len], dtype=torch.int32, device=self.device
            )
            # Use GPU tensor directly - no CPU sync needed
            suffix_kv_lens = seq_lens[:num_reqs] - common_prefix_len
            prefix_scheduler_metadata = schedule(
                batch_size=1,
                cu_query_lens=cu_prefix_query_lens,
                max_query_len=num_actual_tokens,
                seqlens=prefix_kv_lens,
                max_seq_len=common_prefix_len,
                causal=False,
            )
            scheduler_metadata = schedule(
                batch_size=num_reqs,
                cu_query_lens=query_start_loc,
                max_query_len=max_query_len,
                seqlens=suffix_kv_lens,
                max_seq_len=max_seq_len - common_prefix_len,
                causal=True,
            )
        else:
            scheduler_metadata = schedule(
                batch_size=num_reqs,
                cu_query_lens=query_start_loc,
                max_query_len=max_query_len,
                seqlens=seq_lens,
                max_seq_len=max_seq_len,
                causal=causal,
            )

        # For FA3 + full cudagraph
        if self.use_full_cuda_graph and scheduler_metadata is not None:
            n = scheduler_metadata.shape[0]
            self.scheduler_metadata[:n] = scheduler_metadata
            # NOTE(woosuk): We should zero out the rest of the scheduler
            # metadata to guarantee the correctness. Otherwise, some thread
            # blocks may use the invalid scheduler metadata and overwrite the
            # output buffer.
            self.scheduler_metadata[n:] = 0
            scheduler_metadata = self.scheduler_metadata[:n]

        attn_metadata = FlashAttentionMetadata(
            num_actual_tokens=num_actual_tokens,
            max_query_len=max_query_len,
            query_start_loc=query_start_loc,
            max_seq_len=max_seq_len,
            seq_lens=seq_lens,
            block_table=block_table_tensor,
            slot_mapping=slot_mapping,
            max_dcp_context_kv_len=max_dcp_context_kv_len,
            dcp_context_kv_lens=dcp_context_kv_lens,
            use_cascade=use_cascade,
            common_prefix_len=common_prefix_len,
            scheduler_metadata=scheduler_metadata,
            cu_prefix_query_lens=cu_prefix_query_lens,
            prefix_kv_lens=prefix_kv_lens,
            suffix_kv_lens=suffix_kv_lens,
            prefix_scheduler_metadata=prefix_scheduler_metadata,
            max_num_splits=max_num_splits,
            causal=causal,
        )
        return attn_metadata

    def update_block_table(
        self,
        metadata: FlashAttentionMetadata,
        blk_table: torch.Tensor,
        slot_mapping: torch.Tensor,
    ) -> FlashAttentionMetadata:
        """Return a shallow copy of ``metadata`` with new block table/slots."""
        new_metadata = copy.copy(metadata)
        new_metadata.block_table = blk_table
        new_metadata.slot_mapping = slot_mapping
        return new_metadata

    def use_cascade_attention(self, *args, **kwargs) -> bool:
        """Delegate to the module-level ``use_cascade_attention`` helper."""
        return use_cascade_attention(*args, **kwargs)

662

663
class FlashAttentionImpl(AttentionImpl):
    """Attention implementation backed by FlashAttention varlen kernels.

    On ROCm the kernels are ``vllm_flash_attn_varlen_func`` or, when the
    unified KV-cache layout is in use, ``varlen_fwd_unified``; on other
    platforms ``flash_attn_varlen_func`` is used.
    """

    can_return_lse_for_decode: bool = True

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
        sinks: torch.Tensor | None = None,
        use_alibi_sqrt: bool = False,
    ) -> None:
        """Store the layer configuration and validate backend support.

        Raises:
            NotImplementedError: if a quantized KV cache is requested but the
                installed flash-attn does not support fp8.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes

        # Window is stored as (left, right); (-1, -1) means no sliding window.
        if sliding_window is None:
            self.sliding_window = (-1, -1)
        elif attn_type == AttentionType.ENCODER_ONLY:
            # Encoder-only attention uses a symmetric window.
            self.sliding_window = (sliding_window - 1, sliding_window - 1)
        else:
            self.sliding_window = (sliding_window - 1, 0)

        self.kv_cache_dtype = kv_cache_dtype
        if logits_soft_cap is None:
            # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
            logits_soft_cap = 0
        self.logits_soft_cap = logits_soft_cap
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        self.attn_type = attn_type
        self.vllm_flash_attn_version = get_flash_attn_version()
        self.use_alibi_sqrt = use_alibi_sqrt
        # Cache the batch invariant result for use in forward passes
        self.batch_invariant_enabled = vllm_is_batch_invariant()

        if is_quantized_kv_cache(self.kv_cache_dtype) and not flash_attn_supports_fp8():
            raise NotImplementedError(
                "FlashAttention does not support fp8 kv-cache on this device."
            )

        self.sinks = sinks
        if self.sinks is not None:
            # The FA3 capability check only applies off ROCm.
            if not current_platform.is_rocm():
                assert flash_attn_supports_sinks(), (
                    "Sinks are only supported in FlashAttention 3"
                )
            assert self.sinks.shape[0] == num_heads, (
                "Sinks must have the same number of heads as the number of "
                "heads in the layer"
            )

        self.supports_quant_query_input = True
        self.supports_per_head_quant_scales = (
            self.vllm_flash_attn_version >= 3
            if self.vllm_flash_attn_version is not None
            else False
        )

    def _get_unified_extras(
        self,
        attn_metadata: FlashAttentionMetadata,
    ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
        """Return ``(mm_prefix_range_tensor, qq_bias)`` from the metadata.

        These are the extra inputs consumed by ``varlen_fwd_unified``.
        """
        mm_prefix_range_tensor = attn_metadata.mm_prefix_range_tensor
        qq_bias = attn_metadata.qq_bias
        return mm_prefix_range_tensor, qq_bias

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention.

        Args:
            layer: The attention layer; its ``_q_scale``/``_k_scale``/
                ``_v_scale`` are used as descale factors.
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape =
                [2, num_blocks, block_size, num_kv_heads, head_size]
                (on ROCm it is unpacked directly into key/value caches).
            attn_metadata: Metadata for attention. ``None`` marks a
                profiling run, in which case ``output`` is zero-filled.
            output: Preallocated output buffer (required). The first
                ``num_actual_tokens`` rows are written in place.
            output_scale: Must be None (fused output quant unsupported).
            output_block_scale: Must be None (fused output quant unsupported).
        Returns:
            ``output``; shape = [num_tokens, num_heads * head_size]
        Raises:
            NotImplementedError: if ``output_scale`` or
                ``output_block_scale`` is provided.

        NOTE: FP8 quantization, flash-attn expect the size of
              {q,k,v}_descale to be (num_sequences, num_kv_heads).
              We use torch's .expand() to avoid duplicating values
        """
        assert output is not None, "Output tensor must be provided."
        assert self.vllm_flash_attn_version is not None, (
            "FlashAttention version not detected."
        )

        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported for FlashAttentionImpl"
            )

        if attn_metadata is None:
            # Profiling run.
            return output.fill_(0)

        attn_type = self.attn_type

        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
        # in this method. For example, `view` and `slice` (or `[:n]`) operations
        # are surprisingly slow even in the case they do not invoke any GPU ops.
        # Minimize the PyTorch ops in this method as much as possible.
        # Whenever making a change in this method, please benchmark the
        # performance to make sure it does not introduce any overhead.

        num_actual_tokens = attn_metadata.num_actual_tokens

        # Handle encoder attention differently - no KV cache needed
        if attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
            # For encoder attention,
            # we use direct Q, K, V tensors without caching
            return self._forward_encoder_attention(
                query[:num_actual_tokens],
                key[:num_actual_tokens],
                value[:num_actual_tokens],
                output[:num_actual_tokens],
                attn_metadata,
                layer,
            )

        # Queried once: this method is CPU-overhead sensitive (see above).
        is_rocm = current_platform.is_rocm()

        # For decoder and cross-attention, use KV cache as before
        if is_rocm:
            key_cache, value_cache = kv_cache
        else:
            key_cache, value_cache = kv_cache.unbind(0)

        if self.kv_cache_dtype.startswith("fp8"):
            # queries are quantized in the attention layer
            dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
                self.kv_cache_dtype
            )
            key_cache = key_cache.view(dtype)
            value_cache = value_cache.view(dtype)

        if not attn_metadata.use_cascade:
            cu_seqlens_q = attn_metadata.query_start_loc
            seqused_k = attn_metadata.seq_lens
            max_seqlen_q = attn_metadata.max_query_len
            max_seqlen_k = attn_metadata.max_seq_len
            block_table = attn_metadata.block_table
            scheduler_metadata = attn_metadata.scheduler_metadata

            if is_rocm:
                # ROCm kernels take the scalar k/v scales and no q scale.
                q_descale = None
                k_descale = layer._k_scale
                v_descale = layer._v_scale
            else:
                descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)
                q_descale = layer._q_scale.expand(descale_shape)
                k_descale = layer._k_scale.expand(descale_shape)
                v_descale = layer._v_scale.expand(descale_shape)

            if self.dcp_world_size > 1:
                self._forward_with_dcp(
                    query[:num_actual_tokens],
                    key[:num_actual_tokens],
                    value[:num_actual_tokens],
                    key_cache,
                    value_cache,
                    output[:num_actual_tokens],
                    attn_metadata,
                    q_descale=q_descale,
                    k_descale=k_descale,
                    v_descale=v_descale,
                )
                return output

            sliding_window_size = (
                list(self.sliding_window)
                if self.sliding_window is not None
                else None
            )

            if is_rocm:
                if envs.VLLM_USE_PA_PRINT_PARAM:
                    print("PA SIZE:")
                    print(f"q.shape = {query[:num_actual_tokens].shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}")
                    print(f"cu_seqlens_q.shape = {cu_seqlens_q.shape}, max_seqlen_q = {max_seqlen_q}, seqused_k.shape = {seqused_k.shape}, max_seqlen_k = {max_seqlen_k}")
                    print(f"softmax_scale = {self.scale:.3f}, alibi_slopes = {self.alibi_slopes}, window_size = {self.sliding_window}, block_tables.shape = {block_table.shape}, softcap = {self.logits_soft_cap}, scheduler_metadata = {scheduler_metadata}")

                use_unified_kv_layout = (
                    FlashAttentionBackend._use_rocm_unified_kv_layout(
                        key_cache=key_cache, value_cache=value_cache)
                )

                if use_unified_kv_layout:
                    # NOTE(review): varlen_fwd_unified is set to None at import
                    # time when flash_attn lacks it; confirm that
                    # _use_rocm_unified_kv_layout never selects this path then.
                    mm_prefix_range_tensor, qq_bias = self._get_unified_extras(
                        attn_metadata
                    )
                    varlen_fwd_unified(
                        q=query[:num_actual_tokens],
                        k=key_cache,
                        v=value_cache,
                        cu_seqlens_q=cu_seqlens_q,
                        seqused_k=seqused_k,
                        block_table=block_table,
                        max_seqlen_q=max_seqlen_q,
                        max_seqlen_k=max_seqlen_k,
                        softmax_scale=self.scale,
                        causal=attn_metadata.causal,
                        softcap=self.logits_soft_cap,
                        window_size=tuple(self.sliding_window),
                        alibi_slopes=self.alibi_slopes,
                        use_alibi_sqrt=self.use_alibi_sqrt,
                        qq_bias=qq_bias,
                        s_aux=self.sinks,
                        mm_prefix_range=mm_prefix_range_tensor,
                        return_softmax_lse=False,
                        out=output[:num_actual_tokens],
                    )
                else:
                    # NOTE: num_splits is intentionally not forwarded on ROCm.
                    vllm_flash_attn_varlen_func(
                        q=query[:num_actual_tokens],
                        k=key_cache,
                        v=value_cache,
                        out=output[:num_actual_tokens],
                        cu_seqlens_q=cu_seqlens_q,
                        max_seqlen_q=max_seqlen_q,
                        seqused_k=seqused_k,
                        max_seqlen_k=max_seqlen_k,
                        softmax_scale=self.scale,
                        causal=attn_metadata.causal,
                        alibi_slopes=self.alibi_slopes,
                        window_size=sliding_window_size,
                        block_table=block_table,
                        softcap=self.logits_soft_cap,
                        scheduler_metadata=scheduler_metadata,
                        fa_version=self.vllm_flash_attn_version,
                        q_descale=q_descale,
                        k_descale=k_descale,
                        v_descale=v_descale,
                        s_aux=self.sinks,
                        is_prefix_cache=True,
                    )
            else:
                flash_attn_varlen_func(
                    q=query[:num_actual_tokens],
                    k=key_cache,
                    v=value_cache,
                    out=output[:num_actual_tokens],
                    cu_seqlens_q=cu_seqlens_q,
                    max_seqlen_q=max_seqlen_q,
                    seqused_k=seqused_k,
                    max_seqlen_k=max_seqlen_k,
                    softmax_scale=self.scale,
                    causal=attn_metadata.causal,
                    alibi_slopes=self.alibi_slopes,
                    window_size=sliding_window_size,
                    block_table=block_table,
                    softcap=self.logits_soft_cap,
                    scheduler_metadata=scheduler_metadata,
                    fa_version=self.vllm_flash_attn_version,
                    q_descale=q_descale,
                    k_descale=k_descale,
                    v_descale=v_descale,
                    num_splits=attn_metadata.max_num_splits,
                    s_aux=self.sinks,
                )
            return output

        # Cascade attention (rare case).
        cascade_attention(
            output[:num_actual_tokens],
            query[:num_actual_tokens],
            key_cache,
            value_cache,
            cu_query_lens=attn_metadata.query_start_loc,
            max_query_len=attn_metadata.max_query_len,
            cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens,
            prefix_kv_lens=attn_metadata.prefix_kv_lens,
            suffix_kv_lens=attn_metadata.suffix_kv_lens,
            max_kv_len=attn_metadata.max_seq_len,
            softmax_scale=self.scale,
            alibi_slopes=self.alibi_slopes,
            sliding_window=self.sliding_window,
            logits_soft_cap=self.logits_soft_cap,
            block_table=attn_metadata.block_table,
            common_prefix_len=attn_metadata.common_prefix_len,
            max_num_splits=attn_metadata.max_num_splits,
            fa_version=self.vllm_flash_attn_version,
            prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
            suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
            q_descale=None if is_rocm else layer._q_scale,
            k_descale=layer._k_scale,
            v_descale=layer._v_scale,
            s_aux=self.sinks,
        )
        return output

    def do_kv_cache_update(
        self,
        layer: torch.nn.Module,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
    ) -> None:
        """Scatter ``key``/``value`` into the KV cache at ``slot_mapping``.

        This is a no-op in three cases:
        - encoder attention;
        - layers that share another layer's KV cache;
        - calls where ``key`` or ``value`` is None.
        """
        if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
            # For encoder attention,
            # we use direct Q, K, V tensors without caching
            return

        # key and value may be None in the case of cross attention. They are
        # calculated once based on the output from the encoder and then cached
        # in KV cache.
        if (
            self.kv_sharing_target_layer_name is not None
            or key is None
            or value is None
        ):
            return

        is_rocm = current_platform.is_rocm()
        if is_rocm:
            key_cache, value_cache = kv_cache
        else:
            key_cache, value_cache = kv_cache.unbind(0)

        # Reshape the input keys and values and store them in the cache.
        # NOTE(woosuk): Here, key and value are padded while slot_mapping is
        # not padded. However, we don't need to do key[:num_actual_tokens]
        # and value[:num_actual_tokens] because the reshape_and_cache_flash
        # op uses the slot_mapping's shape to determine the number of
        # actual tokens.
        if not is_rocm:
            reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
                slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )
            return

        if FlashAttentionBackend._use_rocm_unified_kv_layout(
            key_cache=key_cache,
            value_cache=value_cache,
        ):
            triton_reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
                slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )
            return

        # Both branches bind `reshape_and_cache_cuda` as a function-local
        # name. Because one branch imports it locally, the module-level
        # binding is not visible anywhere in this function.
        if envs.VLLM_USE_OPT_RESHAPE_AND_CACHE:
            from lightop import reshape_and_cache_cuda
        else:
            from vllm.v1.attention.backends.fa_utils import (
                reshape_and_cache_cuda,
            )
        reshape_and_cache_cuda(
            key,
            value,
            key_cache,
            value_cache,
            slot_mapping,
            self.kv_cache_dtype,
            layer._k_scale,
            layer._v_scale,
        )

    def _forward_with_dcp(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        output: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        q_descale: torch.Tensor | None = None,
        k_descale: torch.Tensor | None = None,
        v_descale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Attention under decode context parallelism (DCP).

        The computation has three steps:
        1. Run the all-gathered query against the paged KV cache (lengths
           from ``dcp_context_kv_lens``). The partial result is then
           reduced with ``cp_lse_ag_out_rs``.
        2. Run ``query`` against the new ``key``/``value`` tokens.
        3. Merge the two partial results into ``output`` via their LSEs.

        Returns:
            ``output``, written in place.
        """
        assert self.vllm_flash_attn_version is not None, (
            "FlashAttention version not detected."
        )

        cu_seqlens_q = attn_metadata.query_start_loc
        max_seqlen_q = attn_metadata.max_query_len
        block_table = attn_metadata.block_table

        query = query.contiguous()
        # Gather along dim 1 (the head dim per forward's documented shape).
        query_across_dcp = get_dcp_group().all_gather(query, dim=1)
        sliding_window_size = (
            list(self.sliding_window) if self.sliding_window is not None else None
        )
        is_rocm = current_platform.is_rocm()

        if is_rocm:
            # NOTE(review): cascade_attention unpacks three values from
            # vllm_flash_attn_varlen_func on ROCm; confirm the arity when
            # return_softmax_lse=True.
            context_attn_out, context_lse = vllm_flash_attn_varlen_func(
                q=query_across_dcp,
                k=key_cache,
                v=value_cache,
                out=None,
                cu_seqlens_q=cu_seqlens_q,
                max_seqlen_q=max_seqlen_q,
                seqused_k=attn_metadata.dcp_context_kv_lens,
                max_seqlen_k=attn_metadata.max_dcp_context_kv_len,
                softmax_scale=self.scale,
                causal=False,
                alibi_slopes=self.alibi_slopes,
                window_size=sliding_window_size,
                block_table=block_table,
                softcap=self.logits_soft_cap,
                return_softmax_lse=True,
                scheduler_metadata=attn_metadata.scheduler_metadata,
                fa_version=self.vllm_flash_attn_version,
                q_descale=q_descale,
                k_descale=k_descale,
                v_descale=v_descale,
                is_prefix_cache=True,
            )
        else:
            context_attn_out, context_lse = flash_attn_varlen_func(
                q=query_across_dcp,
                k=key_cache,
                v=value_cache,
                out=None,
                cu_seqlens_q=cu_seqlens_q,
                max_seqlen_q=max_seqlen_q,
                seqused_k=attn_metadata.dcp_context_kv_lens,
                max_seqlen_k=attn_metadata.max_dcp_context_kv_len,
                softmax_scale=self.scale,
                causal=False,
                alibi_slopes=self.alibi_slopes,
                window_size=sliding_window_size,
                block_table=block_table,
                softcap=self.logits_soft_cap,
                return_softmax_lse=True,
                scheduler_metadata=attn_metadata.scheduler_metadata,
                fa_version=self.vllm_flash_attn_version,
                q_descale=q_descale,
                k_descale=k_descale,
                v_descale=v_descale,
            )
        # FA returns LSE in shape [ H, B ] but cp_lse_ag_out_rs wants [ B, H ]
        context_attn_out_cor, context_lse_cor = cp_lse_ag_out_rs(
            context_attn_out,
            context_lse.transpose(0, 1),
            get_dcp_group(),
            return_lse=True,
        )
        context_lse_cor = context_lse_cor.transpose(0, 1).contiguous()

        if is_rocm:
            query_attn_out, query_lse = vllm_flash_attn_varlen_func(
                q=query,
                k=key,
                v=value,
                out=None,
                cu_seqlens_q=cu_seqlens_q,
                max_seqlen_q=max_seqlen_q,
                cu_seqlens_k=cu_seqlens_q,
                max_seqlen_k=max_seqlen_q,
                softmax_scale=self.scale,
                causal=attn_metadata.causal,
                alibi_slopes=self.alibi_slopes,
                window_size=sliding_window_size,
                softcap=self.logits_soft_cap,
                return_softmax_lse=True,
                fa_version=self.vllm_flash_attn_version,
                q_descale=q_descale,
                k_descale=k_descale,
                v_descale=v_descale,
            )
        else:
            query_attn_out, query_lse = flash_attn_varlen_func(
                q=query,
                k=key,
                v=value,
                out=None,
                cu_seqlens_q=cu_seqlens_q,
                max_seqlen_q=max_seqlen_q,
                cu_seqlens_k=cu_seqlens_q,
                max_seqlen_k=max_seqlen_q,
                softmax_scale=self.scale,
                causal=attn_metadata.causal,
                alibi_slopes=self.alibi_slopes,
                window_size=sliding_window_size,
                softcap=self.logits_soft_cap,
                return_softmax_lse=True,
                fa_version=self.vllm_flash_attn_version,
                q_descale=q_descale,
                k_descale=k_descale,
                v_descale=v_descale,
            )

        assert context_attn_out_cor.shape == query_attn_out.shape
        assert context_lse_cor.shape == query_lse.shape
        merge_attn_states(
            output,
            context_attn_out_cor,
            context_lse_cor,
            query_attn_out,
            query_lse,
        )
        return output

    def _forward_encoder_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        output: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        layer: torch.nn.Module,
    ) -> torch.Tensor:
        """Forward pass for encoder attention without KV cache.

        Args:
            query: shape = [num_encoder_tokens, num_heads, head_size]
            key: shape = [num_encoder_tokens, num_kv_heads, head_size]
            value: shape = [num_encoder_tokens, num_kv_heads, head_size]
            output: shape = [num_encoder_tokens, num_heads, head_size]
            attn_metadata: Encoder attention metadata
            layer: The attention layer
        Returns:
            ``output``, written in place.
        Raises:
            NotImplementedError: if the KV-cache dtype is fp8.
        """
        assert self.vllm_flash_attn_version is not None, (
            "FlashAttention version not detected."
        )

        # FP8 KV-cache dtypes are not supported on the encoder path.
        if self.kv_cache_dtype.startswith("fp8"):
            raise NotImplementedError(
                "quantization is not supported for encoder attention"
            )

        # Self-attention over the same tokens: K/V share Q's sequence layout.
        cu_seqlens_q = attn_metadata.query_start_loc
        cu_seqlens_k = attn_metadata.query_start_loc
        max_seqlen_q = attn_metadata.max_query_len
        max_seqlen_k = attn_metadata.max_query_len

        # Call flash attention directly on Q, K, V tensors
        sliding_window_size = (
            list(self.sliding_window) if self.sliding_window is not None else None
        )
        if current_platform.is_rocm():
            # ROCm takes the scalar k/v scales directly. It is not given
            # fa_version, expanded per-sequence descales, or num_splits.
            vllm_flash_attn_varlen_func(
                q=query,
                k=key,
                v=value,
                out=output,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_q,
                max_seqlen_k=max_seqlen_k,
                softmax_scale=self.scale,
                causal=False,  # Encoder attention is bidirectional
                alibi_slopes=self.alibi_slopes,
                window_size=sliding_window_size,
                softcap=self.logits_soft_cap,
                q_descale=None,
                k_descale=layer._k_scale,
                v_descale=layer._v_scale,
                is_prefix_cache=False,
            )
        else:
            descale_shape = (
                cu_seqlens_q.shape[0] - 1,  # type: ignore[union-attr]
                self.num_kv_heads,
            )
            flash_attn_varlen_func(
                q=query,
                k=key,
                v=value,
                out=output,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_q,
                max_seqlen_k=max_seqlen_k,
                softmax_scale=self.scale,
                causal=False,  # Encoder attention is bidirectional
                alibi_slopes=self.alibi_slopes,
                window_size=sliding_window_size,
                softcap=self.logits_soft_cap,
                fa_version=self.vllm_flash_attn_version,
                q_descale=layer._q_scale.expand(descale_shape),
                k_descale=layer._k_scale.expand(descale_shape),
                v_descale=layer._v_scale.expand(descale_shape),
                num_splits=1 if self.batch_invariant_enabled else 0,
            )

        return output


def use_cascade_attention(
    common_prefix_len: int,
    query_lens: np.ndarray,
    num_query_heads: int,
    num_kv_heads: int,
    use_alibi: bool,
    use_sliding_window: bool,
    use_local_attention: bool,
    num_sms: int,
    dcp_world_size: int,
) -> bool:
    """Decide whether to use cascade attention.

    This function 1) checks whether cascade attention is supported with the
    given configuration, and 2) heuristically decides whether using cascade
    attention can improve performance.

    Args:
        common_prefix_len: Length (in tokens) of the prefix shared by all
            requests in the batch.
        query_lens: Per-request query lengths; its length is the number of
            requests.
        num_query_heads: Number of query heads.
        num_kv_heads: Number of KV heads.
        use_alibi: Whether ALiBi is in use (unsupported by cascade).
        use_sliding_window: Whether sliding-window attention is in use
            (unsupported by cascade).
        use_local_attention: Whether local attention is in use (unsupported
            by cascade).
        num_sms: Number of streaming multiprocessors used by the cost model.
        dcp_world_size: Decode-context-parallel world size; cascade is
            disabled when it is greater than 1.

    Returns:
        True if cascade attention is supported and estimated to be
        beneficial, otherwise False.
    """
    # Too short common prefix. Probably not worth using cascade attention.
    # We use an arbitrary threshold of 256 tokens. TODO: Tune this threshold.
    # NOTE(woosuk): This is the common case. We should return False as soon as
    # possible to avoid any unnecessary computation.
    if common_prefix_len < 256:
        return False
    # Cascade attention is currently not supported with these variants.
    if use_alibi or use_sliding_window or use_local_attention:
        return False
    # Too few queries. Probably not worth using cascade attention.
    # We use an arbitrary threshold of 8 queries. TODO: Tune this threshold.
    num_reqs = len(query_lens)
    if num_reqs < 8:
        return False
    # disable cascade attention for DCP
    if dcp_world_size > 1:
        return False

    # Heuristics to decide whether using cascade attention is beneficial.
    # 1. When FlashDecoding is not used for normal attention, cascade attention
    #    is likely to be faster since it saves memory bandwidth.
    num_queries_per_kv = num_query_heads // num_kv_heads
    # The criteria for using FlashDecoding can be found in the following link:
    # https://github.com/vllm-project/flash-attention/blob/96266b1111111f3d11aabefaf3bacbab6a89d03c/csrc/flash_attn/flash_api.cpp#L535
    use_flash_decoding = (
        num_queries_per_kv > 1
        and not use_sliding_window
        and not use_alibi
        and np.all(query_lens == 1)
    )
    if not use_flash_decoding:
        # Use cascade attention.
        return True

    # 2. When FlashDecoding is used for normal attention, it is not clear
    #    whether cascade attention is beneficial, because FlashDecoding can
    #    launch more CTAs than cascade attention.
    #    We use a simple performance model to compare the two methods.
    #    NOTE(woosuk): The performance model is very rough and may not be
    #    accurate.
    num_tokens = num_reqs
    # NOTE(woosuk): These are default tile sizes. flash-attn might use
    # different tile sizes (e.g., 64 or 256) depending on the configuration.
    q_tile_size = 128
    kv_tile_size = 128
    num_prefix_tiles = cdiv(common_prefix_len, kv_tile_size)

    # Cascade: one pass over the shared prefix, tiled over all query tokens.
    cascade_ctas = num_query_heads * cdiv(num_tokens, q_tile_size)
    cascade_waves = cdiv(cascade_ctas, num_sms)
    cascade_time = cascade_waves * num_prefix_tiles

    # FlashDecoding: every request re-reads the prefix, split across CTAs.
    flash_decoding_ctas = (
        num_reqs * num_kv_heads * cdiv(num_queries_per_kv, q_tile_size)
    )
    flash_decoding_ctas *= num_prefix_tiles
    flash_decoding_time = cdiv(flash_decoding_ctas, num_sms)

    # Use cascade attention if it is faster than FlashDecoding.
    return cascade_time < flash_decoding_time


def _cascade_expand_descale(
    descale: torch.Tensor | None, shape: tuple[int, int]
) -> torch.Tensor | None:
    """Broadcast an optional descale tensor to ``shape``; ``None`` passes through."""
    return None if descale is None else descale.expand(shape)


def _cascade_varlen_attn(
    attn_kwargs: dict,
    *,
    num_seqs: int,
    num_kv_heads: int,
    q_descale: torch.Tensor | None,
    k_descale: torch.Tensor | None,
    v_descale: torch.Tensor | None,
    max_num_splits: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run one varlen FlashAttention pass for ``cascade_attention``.

    Args:
        attn_kwargs: Keyword arguments shared by the ROCm and CUDA kernels
            (q/k/v, sequence metadata, block table, softcap, etc.).
        num_seqs: Number of sequences in this pass; used only on CUDA to
            broadcast the descale tensors.
        num_kv_heads: ``key_cache.shape[-2]``; used only on CUDA to broadcast
            the descale tensors.
        q_descale, k_descale, v_descale: Optional descale tensors.
        max_num_splits: ``num_splits`` for the CUDA kernel unless
            batch-invariant mode is enabled. Not used on ROCm.

    Returns:
        ``(output, softmax_lse)`` from the kernel.
    """
    if current_platform.is_rocm():
        # On ROCm the descale tensors are passed as-is, the kernel's third
        # return value is discarded, and num_splits is not passed (the
        # corresponding argument was commented out in this port).
        # NOTE(review): both the prefix and the suffix pass set
        # is_prefix_cache=True here -- confirm that is intended for the suffix.
        out, lse, _ = vllm_flash_attn_varlen_func(
            **attn_kwargs,
            q_descale=q_descale,
            k_descale=k_descale,
            v_descale=v_descale,
            is_prefix_cache=True,
        )
        return out, lse

    # On CUDA the descale tensors are broadcast to (num_seqs, num_kv_heads).
    descale_shape = (num_seqs, num_kv_heads)
    out, lse = flash_attn_varlen_func(
        **attn_kwargs,
        q_descale=_cascade_expand_descale(q_descale, descale_shape),
        k_descale=_cascade_expand_descale(k_descale, descale_shape),
        v_descale=_cascade_expand_descale(v_descale, descale_shape),
        num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
    )
    return out, lse


def cascade_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    cu_query_lens: torch.Tensor,
    max_query_len: int,
    cu_prefix_query_lens: torch.Tensor,
    prefix_kv_lens: torch.Tensor,
    suffix_kv_lens: torch.Tensor,
    max_kv_len: int,
    softmax_scale: float,
    alibi_slopes: torch.Tensor | None,
    sliding_window: tuple[int, int],
    logits_soft_cap: float,
    block_table: torch.Tensor,
    common_prefix_len: int,
    max_num_splits: int,
    fa_version: int,
    prefix_scheduler_metadata: torch.Tensor | None = None,
    suffix_scheduler_metadata: torch.Tensor | None = None,
    q_descale: torch.Tensor | None = None,
    k_descale: torch.Tensor | None = None,
    v_descale: torch.Tensor | None = None,
    s_aux: torch.Tensor | None = None,
) -> None:
    """Compute attention for a batch whose requests share a common KV prefix.

    The work is split into two FlashAttention passes over the paged KV cache:

    1. Prefix pass: non-causal attention of ``query`` against the first
       ``common_prefix_len`` KV positions, addressed through the first row of
       ``block_table``.
    2. Suffix pass: causal attention of each request's queries against its
       remaining KV blocks (``block_table[:, num_common_kv_blocks:]``).

    The two partial results are combined using their log-sum-exp values by
    ``merge_attn_states`` and written into ``output`` in place. Nothing is
    returned.

    Args:
        output: Destination tensor, filled in place.
        query: Query tokens. ``query.shape[0]`` is the prefix pass's
            ``max_seqlen_q``.
        key_cache: Paged key cache. ``shape[-3]`` is read as the block size
            and ``shape[-2]`` is used when broadcasting descale tensors.
        value_cache: Paged value cache.
        cu_query_lens: Cumulative query lengths for the suffix pass.
        max_query_len: ``max_seqlen_q`` for the suffix pass.
        cu_prefix_query_lens: Cumulative query lengths for the prefix pass.
        prefix_kv_lens: ``seqused_k`` for the prefix pass.
        suffix_kv_lens: ``seqused_k`` for the suffix pass.
        max_kv_len: Max total KV length; the suffix pass uses
            ``max_kv_len - common_prefix_len``.
        softmax_scale: Forwarded to both passes.
        alibi_slopes: Must be ``None`` (ALiBi is unsupported).
        sliding_window: Must be ``(-1, -1)`` (sliding window is unsupported).
        logits_soft_cap: Forwarded to both passes as ``softcap``.
        block_table: Per-request block table.
        common_prefix_len: Shared prefix length in tokens; must be a positive
            multiple of the block size.
        max_num_splits: ``num_splits`` for the CUDA kernel unless
            batch-invariant mode is enabled; unused on ROCm.
        fa_version: FlashAttention version forwarded to both passes.
        prefix_scheduler_metadata: ``scheduler_metadata`` for the prefix pass.
        suffix_scheduler_metadata: ``scheduler_metadata`` for the suffix pass.
        q_descale, k_descale, v_descale: Optional descale tensors; broadcast
            per pass to ``(num_seqs, key_cache.shape[-2])`` on CUDA and passed
            as-is on ROCm.
        s_aux: Optional tensor forwarded to the prefix pass only.

    Raises:
        AssertionError: If ``alibi_slopes`` is set, ``sliding_window`` is not
            ``(-1, -1)``, or ``common_prefix_len`` is not a positive multiple
            of the block size.
    """
    assert alibi_slopes is None, "Cascade attention does not support ALiBi."
    # TODO: Support sliding window.
    assert sliding_window == (-1, -1), (
        "Cascade attention does not support sliding window."
    )

    num_tokens = query.shape[0]
    # NOTE(review): assumes key_cache is laid out as
    # (num_blocks, block_size, num_kv_heads, head_size) -- confirm.
    block_size = key_cache.shape[-3]
    num_kv_heads = key_cache.shape[-2]
    assert common_prefix_len % block_size == 0
    num_common_kv_blocks = common_prefix_len // block_size
    assert num_common_kv_blocks > 0

    descale_kwargs = dict(
        num_kv_heads=num_kv_heads,
        q_descale=q_descale,
        k_descale=k_descale,
        v_descale=v_descale,
        max_num_splits=max_num_splits,
    )

    # Process shared prefix (non-causal, first row of the block table).
    prefix_output, prefix_lse = _cascade_varlen_attn(
        dict(
            q=query,
            k=key_cache,
            v=value_cache,
            cu_seqlens_q=cu_prefix_query_lens,
            seqused_k=prefix_kv_lens,
            max_seqlen_q=num_tokens,
            max_seqlen_k=common_prefix_len,
            softmax_scale=softmax_scale,
            causal=False,
            window_size=list(sliding_window),
            block_table=block_table[:1],
            softcap=logits_soft_cap,
            return_softmax_lse=True,
            scheduler_metadata=prefix_scheduler_metadata,
            fa_version=fa_version,
            # s_aux is incorporated into prefix_lse inside the GPU kernel,
            # enabling its effect during the final attention merge.
            s_aux=s_aux,
        ),
        num_seqs=cu_prefix_query_lens.shape[0] - 1,
        **descale_kwargs,
    )

    # Process suffix per query (causal, blocks after the shared prefix).
    suffix_output, suffix_lse = _cascade_varlen_attn(
        dict(
            q=query,
            k=key_cache,
            v=value_cache,
            cu_seqlens_q=cu_query_lens,
            seqused_k=suffix_kv_lens,
            max_seqlen_q=max_query_len,
            max_seqlen_k=max_kv_len - common_prefix_len,
            softmax_scale=softmax_scale,
            causal=True,
            window_size=list(sliding_window),
            block_table=block_table[:, num_common_kv_blocks:],
            softcap=logits_soft_cap,
            return_softmax_lse=True,
            scheduler_metadata=suffix_scheduler_metadata,
            fa_version=fa_version,
        ),
        num_seqs=cu_query_lens.shape[0] - 1,
        **descale_kwargs,
    )

    # Merge prefix and suffix outputs, and store the result in output.
    merge_attn_states(output, prefix_output, prefix_lse, suffix_output, suffix_lse)