flash_attn.py 38.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
"""Attention layer with FlashAttention."""
from dataclasses import dataclass
5
from typing import Optional, Tuple
6

7
import numpy as np
8
9
import torch

10
import vllm.envs as envs
11
from vllm import _custom_ops as ops
12
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
13
14
                                              AttentionMetadata, AttentionType,
                                              is_quantized_kv_cache)
15
from vllm.attention.layer import Attention
16
from vllm.attention.ops.merge_attn_states import merge_attn_states
17
from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
18
                                           get_flash_attn_version,
19
20
                                           is_flash_attn_varlen_func_available)

zhuwenwen's avatar
zhuwenwen committed
21
from vllm.platforms import current_platform
22
if is_flash_attn_varlen_func_available():
zhuwenwen's avatar
zhuwenwen committed
23
24
25
26
27
    if not current_platform.is_rocm():
        from vllm.attention.utils.fa_utils import (flash_attn_varlen_func,
                                                get_scheduler_metadata,
                                                reshape_and_cache_flash)
    else:
28
29
        from vllm.attention.utils.fa_utils import (vllm_flash_attn_varlen_func,
                                               reshape_and_cache_cuda)
30

31
from vllm.config import VllmConfig, get_layers_from_vllm_config
zhuwenwen's avatar
zhuwenwen committed
32

33
from vllm.logger import init_logger
34
from vllm.utils import cdiv
35
36
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
                                              AttentionMetadataBuilder,
37
38
                                              CommonAttentionMetadata,
                                              get_kv_cache_layout)
39
from vllm.v1.kv_cache_interface import AttentionSpec
40

41
42
# Module-level logger, named after this module.
logger = init_logger(__name__)

43
44
45
# NOTE(woosuk): This is an arbitrary number. Tune it if needed.
# Used by the metadata builder as the upper bound on the number of splits
# when AOT scheduling is combined with full CUDA graphs.
_DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH = 16

46
47
48

class FlashAttentionBackend(AttentionBackend):
    """Backend descriptor for the FlashAttention implementation.

    Exposes the implementation, metadata and metadata-builder classes and
    describes the KV-cache layout. The KV-cache shape/stride helpers are
    defined differently depending on ``current_platform.is_rocm()``, which is
    evaluated once, when the class body is executed.
    """

    accept_output_buffer: bool = True

    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        """Return the dtypes this backend accepts."""
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return the head sizes this backend accepts."""
        return [32, 64, 96, 128, 160, 192, 224, 256]

    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
        """Check that ``head_size`` is one of the supported head sizes.

        Raises:
            ValueError: if ``head_size`` is not in
                ``get_supported_head_sizes()``.
        """
        supported_head_sizes = cls.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            # e.g. "FlashAttentionBackend" -> "FlashAttention"
            attn_type = cls.__name__.removesuffix("Backend")
            raise ValueError(
                f"Head size {head_size} is not supported by {attn_type}. "
                f"Supported head sizes are: {supported_head_sizes}. "
                "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
                "FlexAttention backend which supports all head sizes.")

    @staticmethod
    def get_name() -> str:
        """Return the identifier of this backend."""
        return "FLASH_ATTN_VLLM_V1"

    @staticmethod
    def get_impl_cls() -> type["FlashAttentionImpl"]:
        """Return the attention implementation class."""
        return FlashAttentionImpl

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        """Return the per-step metadata class."""
        return FlashAttentionMetadata

    @staticmethod
    def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]:
        """Return the metadata builder class."""
        return FlashAttentionMetadataBuilder

    if not current_platform.is_rocm():

        @staticmethod
        def get_kv_cache_shape(
            num_blocks: int,
            block_size: int,
            num_kv_heads: int,
            head_size: int,
        ) -> tuple[int, ...]:
            """Return the shape of the combined key/value cache tensor.

            Raises:
                ValueError: if ``block_size`` is not a multiple of 16.
            """
            if block_size % 16 != 0:
                raise ValueError("Block size must be a multiple of 16.")
            return (2, num_blocks, block_size, num_kv_heads, head_size)

        @staticmethod
        def get_kv_cache_stride_order() -> tuple[int, ...]:
            """Return the dimension permutation for the configured layout.

            Raises:
                ValueError: if the layout is neither "NHD" nor "HND".
            """
            # `stride_order` indicates the permutation that gets
            # us from `get_kv_cache_shape` to the actual memory layout we want.
            cache_layout = get_kv_cache_layout()
            if cache_layout == "NHD":
                stride_order = (0, 1, 2, 3, 4)
            elif cache_layout == "HND":
                stride_order = (0, 1, 3, 2, 4)
            else:
                raise ValueError(f"Unknown cache layout format {cache_layout}.")
            return stride_order
    else:

        @staticmethod
        def get_kv_cache_shape(
            num_blocks: int,
            block_size: int,
            num_kv_heads: int,
            head_size: int,
        ) -> tuple[tuple[int, ...], tuple[int, ...]]:
            """Return the (key cache shape, value cache shape) pair.

            On this platform keys and values get separate shapes; the value
            shape has its last two dimensions swapped relative to the key
            shape.

            Raises:
                ValueError: if ``block_size`` is not a multiple of 16.
            """
            if block_size % 16 != 0:
                raise ValueError("Block size must be a multiple of 16.")
            return (
                (num_blocks, num_kv_heads, block_size, head_size),
                (num_blocks, num_kv_heads, head_size, block_size),
            )

        @staticmethod
        def get_kv_cache_stride_order() -> tuple[tuple[int, ...], tuple[int, ...]]:
            """Return the (key, value) dimension permutations for the layout.

            Raises:
                ValueError: if the layout is neither "NHD" nor "HND".
            """
            # `stride_order` indicates the permutation that gets
            # us from `get_kv_cache_shape` to the actual memory layout we want.
            cache_layout = get_kv_cache_layout()
            if cache_layout == "NHD":
                key_stride_order = (0, 1, 2, 3)
                value_stride_order = (0, 1, 2, 3)
            elif cache_layout == "HND":
                key_stride_order = (0, 2, 1, 3)
                value_stride_order = (0, 2, 1, 3)
            else:
                raise ValueError(f"Unknown cache layout format {cache_layout}.")
            return key_stride_order, value_stride_order

    @staticmethod
    def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype:
        """Map an fp8 KV-cache dtype string to the matching torch dtype.

        Raises:
            ValueError: if ``kv_cache_dtype`` is not "fp8" or "fp8_e4m3".
        """
        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
            return torch.float8_e4m3fn
        else:
            raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")

147
148
149
150
151
152
153
154
155
156
157

@dataclass
class FlashAttentionMetadata:
    """Per-step attention metadata consumed by `FlashAttentionImpl.forward`."""

    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    num_actual_tokens: int  # Number of tokens excluding padding.
    max_query_len: int
    query_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor

    # For cascade attention.
    use_cascade: bool
    common_prefix_len: int
    cu_prefix_query_lens: Optional[torch.Tensor]
    prefix_kv_lens: Optional[torch.Tensor]
    suffix_kv_lens: Optional[torch.Tensor]

    # Optional aot scheduling
    scheduler_metadata: Optional[torch.Tensor] = None
    prefix_scheduler_metadata: Optional[torch.Tensor] = None
    # Forwarded as `num_splits` to the attention kernel call.
    max_num_splits: int = 0

    # Forwarded as `causal` to the attention kernel call.
    causal: bool = True

180

181
182
183
184
185
186
187
188
189
190
191
def _get_sliding_window_configs(
        vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]:
    """Collect the distinct sliding-window settings of all attention layers.

    Every `Attention` layer found in ``vllm_config`` must be backed by a
    `FlashAttentionImpl`; its ``sliding_window`` value is added to the
    returned set.
    """
    attn_layers = get_layers_from_vllm_config(vllm_config, Attention)
    window_configs: set[Optional[tuple[int, int]]] = set()
    for attn_layer in attn_layers.values():
        impl = attn_layer.impl
        assert isinstance(impl, FlashAttentionImpl)
        window_configs.add(impl.sliding_window)
    return window_configs


192
193
class FlashAttentionMetadataBuilder(
        AttentionMetadataBuilder[FlashAttentionMetadata]):
    """Builds `FlashAttentionMetadata` from `CommonAttentionMetadata`.

    When FlashAttention 3 is in use, the builder can additionally
    pre-compute ("AOT") scheduler metadata through `get_scheduler_metadata`.
    With full CUDA graphs, that metadata is copied into a persistent buffer
    allocated in `__init__`.
    """
    # FA3:
    # Supports full cudagraphs for all cases.
    #
    # FA2:
    # For FA2, a graph is captured with max_query_len=1, (which is what we
    # capture by default for num_tokens <= max_num_seqs when there is no
    # spec-decode) then these graphs will not work for mixed prefill-decode
    # (unlike FA3). This is due to special max_query_len=1 packed-GQA handling
    # in FA2.
    # In summary if we are running with spec decodes the graphs would
    # work for mixed prefill-decode and uniform-decode. But for non-spec decodes
    # the graphs would not work for mixed prefill-decode; sort of the inverse
    # of UNIFORM_SINGLE_TOKEN_DECODE.
    # There's probably a better way to describe this using `AttentionCGSupport`
    # but for now just set it to `UNIFORM_BATCH` to get us to drop down
    # to FULL_AND_PIECEWISE.
    # TODO(luka, lucas): audit FA2 as part of:
    #  https://github.com/vllm-project/vllm/issues/22945
    cudagraph_support = (
        AttentionCGSupport.ALWAYS
        if get_flash_attn_version() == 3 or current_platform.is_rocm()
        else AttentionCGSupport.UNIFORM_BATCH)

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        """Cache config handles and set up AOT-scheduling state.

        Args:
            kv_cache_spec: Supplies the KV-cache dtype and block size.
            layer_names: Not used by this builder.
            vllm_config: Source of the model, parallel, cache, compilation
                and scheduler configs.
            device: Device on which metadata tensors are created.

        Raises:
            ValueError: if full CUDA graphs are combined with AOT scheduling
                and the maximum capture size exceeds 992.
        """
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.parallel_config = vllm_config.parallel_config
        self.cache_config = vllm_config.cache_config
        self.compilation_config = vllm_config.compilation_config
        self.device = device

        self.num_heads_q = self.model_config.get_num_attention_heads(
            self.parallel_config)
        self.num_heads_kv = self.model_config.get_num_kv_heads(
            self.parallel_config)
        self.kv_cache_dtype = kv_cache_spec.dtype
        self.headdim = self.model_config.get_head_size()
        self.block_size = kv_cache_spec.block_size

        self.max_num_splits = 0  # No upper bound on the number of splits.
        # AOT scheduling is only enabled for FlashAttention 3.
        self.aot_schedule = (get_flash_attn_version() == 3)

        self.use_full_cuda_graph = \
            self.compilation_config.cudagraph_mode.has_full_cudagraphs()

        if self.use_full_cuda_graph and self.aot_schedule:
            self.max_cudagraph_size = self.compilation_config.max_capture_size

            if self.max_cudagraph_size > 992:
                # This condition derives from FA3's internal heuristic.
                # TODO(woosuk): Support larger cudagraph sizes.
                raise ValueError(
                    "Capture size larger than 992 is not supported for "
                    "full cuda graph.")

            # Persistent buffer that build() copies the scheduler metadata
            # into.
            self.scheduler_metadata = torch.zeros(
                vllm_config.scheduler_config.max_num_seqs + 1,
                dtype=torch.int32,
                device=self.device,
            )
            # When using cuda graph, we need to set the upper bound of the
            # number of splits so that large enough intermediate buffers are
            # pre-allocated during capture.
            self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH

        # Sliding window size to be used with the AOT scheduler will be
        # populated on first build() call.
        self.aot_sliding_window: Optional[tuple[int, int]] = None

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> FlashAttentionMetadata:
        """Build the attention metadata for one step.

        Args:
            common_prefix_len: Length of the prefix shared by all requests;
                a value greater than 0 enables cascade attention.
            common_attn_metadata: Backend-agnostic per-step metadata.
            fast_build: Disables AOT scheduling; used when there will be few
                iterations, i.e. spec-decode.

        Returns:
            The populated `FlashAttentionMetadata`.
        """
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        max_query_len = common_attn_metadata.max_query_len
        max_seq_len = common_attn_metadata.max_seq_len
        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens
        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
        block_table_tensor = common_attn_metadata.block_table_tensor
        slot_mapping = common_attn_metadata.slot_mapping
        causal = common_attn_metadata.causal

        # the overhead of the aot schedule is not worth it for spec-decode
        aot_schedule = self.aot_schedule and not fast_build

        if self.aot_sliding_window is None:
            self.aot_sliding_window = (-1, -1)
            # For the AOT scheduler we need the sliding window value to be
            # constant for all layers too. We have to populate this on the
            # first build() call so the layers are constructed (cannot
            # populate in __init__).
            # NOTE(review): if the first build() call has fast_build=True,
            # the window stays at (-1, -1) for all later calls, because this
            # branch only runs once. Confirm this is intended.
            if aot_schedule:
                sliding_window_configs = _get_sliding_window_configs(
                    self.vllm_config)
                if len(sliding_window_configs) == 1:
                    sliding_window_config = sliding_window_configs.pop()
                    if sliding_window_config is not None:
                        self.aot_sliding_window = sliding_window_config
                elif len(sliding_window_configs) > 1:
                    # Layers disagree on the window: disable AOT scheduling
                    # for this and all subsequent builds.
                    self.aot_schedule = False
                    aot_schedule = False

        def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                     max_seq_len, causal):
            """Return AOT scheduler metadata, or None if AOT is disabled."""
            cache_dtype = self.cache_config.cache_dtype
            if cache_dtype.startswith("fp8"):
                qkv_dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
                    cache_dtype)
            else:
                qkv_dtype = self.kv_cache_dtype
            if aot_schedule:
                return get_scheduler_metadata(
                    batch_size=batch_size,
                    max_seqlen_q=max_query_len,
                    max_seqlen_k=max_seq_len,
                    num_heads_q=self.num_heads_q,
                    num_heads_kv=self.num_heads_kv,
                    headdim=self.headdim,
                    cache_seqlens=seqlens,
                    qkv_dtype=qkv_dtype,
                    cu_seqlens_q=cu_query_lens,
                    page_size=self.block_size,
                    causal=causal,
                    window_size=self.aot_sliding_window,
                    num_splits=self.max_num_splits,
                )
            return None

        use_cascade = common_prefix_len > 0

        if use_cascade:
            # Prefix pass: all tokens form a single "request" attending to
            # the shared prefix.
            cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
                                                dtype=torch.int32,
                                                device=self.device)
            prefix_kv_lens = torch.tensor([common_prefix_len],
                                          dtype=torch.int32,
                                          device=self.device)
            # Suffix pass: per-request KV lengths with the prefix removed.
            suffix_kv_lens = (seq_lens_cpu[:num_reqs] - common_prefix_len).to(
                self.device, non_blocking=True)
            prefix_scheduler_metadata = schedule(
                batch_size=1,
                cu_query_lens=cu_prefix_query_lens,
                max_query_len=num_actual_tokens,
                seqlens=prefix_kv_lens,
                max_seq_len=common_prefix_len,
                causal=False)
            scheduler_metadata = schedule(batch_size=num_reqs,
                                          cu_query_lens=query_start_loc,
                                          max_query_len=max_query_len,
                                          seqlens=suffix_kv_lens,
                                          max_seq_len=max_seq_len -
                                          common_prefix_len,
                                          causal=True)
        else:
            cu_prefix_query_lens = None
            prefix_kv_lens = None
            suffix_kv_lens = None
            prefix_scheduler_metadata = None
            scheduler_metadata = schedule(batch_size=num_reqs,
                                          cu_query_lens=query_start_loc,
                                          max_query_len=max_query_len,
                                          seqlens=seq_lens,
                                          max_seq_len=max_seq_len,
                                          causal=causal)

        # For FA3 + full cudagraph
        max_num_splits = 0
        if self.use_full_cuda_graph and scheduler_metadata is not None:
            n = scheduler_metadata.shape[0]
            self.scheduler_metadata[:n] = scheduler_metadata
            # NOTE(woosuk): We should zero out the rest of the scheduler
            # metadata to guarantee the correctness. Otherwise, some thread
            # blocks may use the invalid scheduler metadata and overwrite the
            # output buffer.
            self.scheduler_metadata[n:] = 0
            scheduler_metadata = self.scheduler_metadata[:n]

            if num_actual_tokens <= self.max_cudagraph_size:
                # NOTE(woosuk): Setting num_splits > 1 may increase the memory
                # usage, because the intermediate buffers of size [num_splits,
                # num_heads, num_tokens, head_size] are allocated. Therefore,
                # we only set num_splits when using cuda graphs.
                max_num_splits = self.max_num_splits

        attn_metadata = FlashAttentionMetadata(
            num_actual_tokens=num_actual_tokens,
            max_query_len=max_query_len,
            query_start_loc=query_start_loc,
            max_seq_len=max_seq_len,
            seq_lens=seq_lens,
            block_table=block_table_tensor,
            slot_mapping=slot_mapping,
            use_cascade=use_cascade,
            common_prefix_len=common_prefix_len,
            scheduler_metadata=scheduler_metadata,
            cu_prefix_query_lens=cu_prefix_query_lens,
            prefix_kv_lens=prefix_kv_lens,
            suffix_kv_lens=suffix_kv_lens,
            prefix_scheduler_metadata=prefix_scheduler_metadata,
            max_num_splits=max_num_splits,
            causal=causal)
        return attn_metadata

    def use_cascade_attention(self, *args, **kwargs) -> bool:
        """Delegate to the module-level `use_cascade_attention` function."""
        return use_cascade_attention(*args, **kwargs)

421

422
423
424
425
426
427
428
429
class FlashAttentionImpl(AttentionImpl):

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        logits_soft_cap: Optional[float] = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        sinks: Optional[torch.Tensor] = None,
    ) -> None:
        """Store the layer configuration and validate it.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each head; validated via
                `FlashAttentionBackend.validate_head_size`.
            scale: Softmax scale, stored as a float.
            num_kv_heads: Number of key/value heads.
            alibi_slopes: Optional slopes, converted to a float32 tensor.
            sliding_window: Optional window size, converted to the 2-tuple
                later passed as `window_size`.
            kv_cache_dtype: KV-cache dtype string.
            logits_soft_cap: Optional soft cap; None is stored as 0.
            attn_type: Kind of attention this layer performs.
            kv_sharing_target_layer_name: If set, forward() does not write
                this layer's keys/values into the KV cache.
            sinks: Optional tensor whose first dimension must equal
                `num_heads`; requires FlashAttention 3.

        Raises:
            NotImplementedError: if the KV cache is quantized and
                `flash_attn_supports_fp8()` is false.
            AssertionError: if `sinks` is given without FlashAttention 3 or
                with a mismatching number of heads.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        if sliding_window is None:
            # (-1, -1) is the "no sliding window" value.
            self.sliding_window = (-1, -1)
        elif attn_type == AttentionType.ENCODER_ONLY:
            # Encoder-only layers get a symmetric window.
            self.sliding_window = (sliding_window - 1, sliding_window - 1)
        else:
            self.sliding_window = (sliding_window - 1, 0)
        self.kv_cache_dtype = kv_cache_dtype
        if logits_soft_cap is None:
            # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
            logits_soft_cap = 0
        self.logits_soft_cap = logits_soft_cap
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        FlashAttentionBackend.validate_head_size(head_size)

        self.attn_type = attn_type
        self.vllm_flash_attn_version = get_flash_attn_version()
        if is_quantized_kv_cache(self.kv_cache_dtype) \
            and not flash_attn_supports_fp8():
            raise NotImplementedError(
                "FlashAttention does not support fp8 kv-cache on this device.")

        self.sinks = sinks
        if self.sinks is not None:
            assert self.vllm_flash_attn_version == 3, (
                "Sinks are only supported in FlashAttention 3")
            assert self.sinks.shape[0] == num_heads, (
                "Sinks must have the same number of heads as the number of "
                "heads in the layer")

477
478
    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
        output_block_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention.

        Args:
            layer: Module providing the `_q_scale`, `_k_scale` and
                `_v_scale` attributes.
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: On non-ROCm platforms, shape =
                [2, num_blocks, block_size, num_kv_heads, head_size].
                On ROCm, a (key_cache, value_cache) pair.
            attn_metadata: Metadata for attention. None during a profiling
                run, in which case `output` is returned untouched.
            output: Pre-allocated output tensor; required.
            output_scale: Not supported; must be None.
            output_block_scale: Not supported; must be None.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        Raises:
            AssertionError: if `output` is None.
            NotImplementedError: if `output_scale` or `output_block_scale`
                is given.
        NOTE: FP8 quantization, flash-attn expect the size of
              {q,k,v}_descale to be (num_sequences, num_kv_heads).
              We use torch's .expand() to avoid duplicating values
        """
        assert output is not None, "Output tensor must be provided."

        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for FlashAttentionImpl")

        if attn_metadata is None:
            # Profiling run.
            return output

        attn_type = self.attn_type

        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
        # in this method. For example, `view` and `slice` (or `[:n]`) operations
        # are surprisingly slow even in the case they do not invoke any GPU ops.
        # Minimize the PyTorch ops in this method as much as possible.
        # Whenever making a change in this method, please benchmark the
        # performance to make sure it does not introduce any overhead.

        num_actual_tokens = attn_metadata.num_actual_tokens

        # Handle encoder attention differently - no KV cache needed
        if attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
            # For encoder attention,
            # we use direct Q, K, V tensors without caching
            return self._forward_encoder_attention(query[:num_actual_tokens],
                                                   key[:num_actual_tokens],
                                                   value[:num_actual_tokens],
                                                   output[:num_actual_tokens],
                                                   attn_metadata, layer)

        # For decoder and cross-attention, use KV cache as before
        if not current_platform.is_rocm():
            key_cache, value_cache = kv_cache.unbind(0)
        else:
            # On ROCm the cache arrives as a (key_cache, value_cache) pair.
            key_cache, value_cache = kv_cache

        # key and value may be None in the case of cross attention. They are
        # calculated once based on the output from the encoder and then cached
        # in KV cache.
        if (self.kv_sharing_target_layer_name is None and key is not None
                and value is not None):
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
            # not padded. However, we don't need to do key[:num_actual_tokens]
            # and value[:num_actual_tokens] because the reshape_and_cache_flash
            # op uses the slot_mapping's shape to determine the number of
            # actual tokens.
            if not current_platform.is_rocm():
                reshape_and_cache_flash(
                    key,
                    value,
                    key_cache,
                    value_cache,
                    attn_metadata.slot_mapping,
                    self.kv_cache_dtype,
                    layer._k_scale,
                    layer._v_scale,
                )
            else:
                reshape_and_cache_cuda(
                    key,
                    value,
                    key_cache,
                    value_cache,
                    attn_metadata.slot_mapping,
                    self.kv_cache_dtype,
                    layer._k_scale,
                    layer._v_scale,
                )

        if self.kv_cache_dtype.startswith("fp8"):
            # Reinterpret the caches as fp8 and quantize the query to match.
            dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
                self.kv_cache_dtype)
            key_cache = key_cache.view(dtype)
            value_cache = value_cache.view(dtype)
            num_tokens, num_heads, head_size = query.shape
            query, _ = ops.scaled_fp8_quant(
                query.reshape(
                    (num_tokens, num_heads * head_size)).contiguous(),
                layer._q_scale)
            query = query.reshape((num_tokens, num_heads, head_size))

        if not attn_metadata.use_cascade:
            cu_seqlens_q = attn_metadata.query_start_loc
            seqused_k = attn_metadata.seq_lens
            max_seqlen_q = attn_metadata.max_query_len
            max_seqlen_k = attn_metadata.max_seq_len
            block_table = attn_metadata.block_table
            scheduler_metadata = attn_metadata.scheduler_metadata

            if not current_platform.is_rocm():
                # (num_sequences, num_kv_heads); only this branch passes
                # descale factors, so compute the shape here.
                descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)

                flash_attn_varlen_func(
                    q=query[:num_actual_tokens],
                    k=key_cache,
                    v=value_cache,
                    out=output[:num_actual_tokens],
                    cu_seqlens_q=cu_seqlens_q,
                    max_seqlen_q=max_seqlen_q,
                    seqused_k=seqused_k,
                    max_seqlen_k=max_seqlen_k,
                    softmax_scale=self.scale,
                    causal=attn_metadata.causal,
                    alibi_slopes=self.alibi_slopes,
                    window_size=self.sliding_window,
                    block_table=block_table,
                    softcap=self.logits_soft_cap,
                    scheduler_metadata=scheduler_metadata,
                    fa_version=self.vllm_flash_attn_version,
                    q_descale=layer._q_scale.expand(descale_shape),
                    k_descale=layer._k_scale.expand(descale_shape),
                    v_descale=layer._v_scale.expand(descale_shape),
                    num_splits=attn_metadata.max_num_splits,
                    s_aux=self.sinks,
                )
            else:
                # Optional debug dump of the kernel arguments.
                if envs.VLLM_USE_PA_PRINT_PARAM:
                    print("PA SIZE:")
                    print(f"q.shape = {query[:num_actual_tokens].shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}")
                    print(f"cu_seqlens_q.shape = {cu_seqlens_q.shape}, max_seqlen_q = {max_seqlen_q}, seqused_k.shape = {seqused_k.shape}, max_seqlen_k = {max_seqlen_k}")
                    print(f"softmax_scale = {self.scale:.3f}, alibi_slopes = {self.alibi_slopes}, window_size = {self.sliding_window}, block_tables.shape = {block_table.shape}, softcap = {self.logits_soft_cap}, scheduler_metadata = {scheduler_metadata}")
                # NOTE(review): unlike the non-ROCm call, this call does not
                # pass fa_version, {q,k,v}_descale, num_splits or s_aux. With
                # an fp8 KV cache the query is quantized above but no descale
                # factors reach the kernel here - confirm the ROCm kernel
                # handles that case.
                vllm_flash_attn_varlen_func(
                    q=query[:num_actual_tokens],
                    k=key_cache,
                    v=value_cache,
                    out=output[:num_actual_tokens],
                    cu_seqlens_q=cu_seqlens_q,
                    max_seqlen_q=max_seqlen_q,
                    seqused_k=seqused_k,
                    max_seqlen_k=max_seqlen_k,
                    softmax_scale=self.scale,
                    causal=attn_metadata.causal,
                    alibi_slopes=self.alibi_slopes,
                    window_size=self.sliding_window,
                    block_table=block_table,
                    softcap=self.logits_soft_cap,
                    scheduler_metadata=scheduler_metadata,
                    is_prefix_cache=True,
                )
            return output

        # Cascade attention (rare case).
        cascade_attention(
            output[:num_actual_tokens],
            query[:num_actual_tokens],
            key_cache,
            value_cache,
            cu_query_lens=attn_metadata.query_start_loc,
            max_query_len=attn_metadata.max_query_len,
            cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens,
            prefix_kv_lens=attn_metadata.prefix_kv_lens,
            suffix_kv_lens=attn_metadata.suffix_kv_lens,
            max_kv_len=attn_metadata.max_seq_len,
            softmax_scale=self.scale,
            alibi_slopes=self.alibi_slopes,
            sliding_window=self.sliding_window,
            logits_soft_cap=self.logits_soft_cap,
            block_table=attn_metadata.block_table,
            common_prefix_len=attn_metadata.common_prefix_len,
            fa_version=self.vllm_flash_attn_version,
            prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
            suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
            q_descale=layer._q_scale,
            k_descale=layer._k_scale,
            v_descale=layer._v_scale,
        )
        return output
    def _forward_encoder_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        output: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        layer: torch.nn.Module,
    ) -> torch.Tensor:
        """Compute encoder self-attention directly on Q/K/V, without a KV cache.

        Args:
            query: shape = [num_encoder_tokens, num_heads, head_size]
            key: shape = [num_encoder_tokens, num_kv_heads, head_size]
            value: shape = [num_encoder_tokens, num_kv_heads, head_size]
            output: shape = [num_encoder_tokens, num_heads, head_size]
            attn_metadata: Encoder attention metadata. Its ``query_start_loc``
                and ``max_query_len`` are used for both the query side and
                the key side.
            layer: The attention layer; supplies the ``_q_scale``,
                ``_k_scale`` and ``_v_scale`` tensors.

        Returns:
            ``output``, which is passed to the kernel as its ``out`` buffer.

        Raises:
            NotImplementedError: If ``self.kv_cache_dtype`` starts with
                ``"fp8"``.
        """
        # FP8 KV-cache dtypes are rejected up front on this path.
        if self.kv_cache_dtype.startswith("fp8"):
            raise NotImplementedError(
                "quantization is not supported for encoder attention")

        # Queries and keys share the same sequence boundaries and the same
        # maximum length, so each metadata field is read once and reused.
        seq_boundaries = attn_metadata.query_start_loc
        longest_seq = attn_metadata.max_query_len

        # The scales are expanded to (number of sequences, number of KV heads).
        num_seqs = seq_boundaries.shape[0] - 1  # type: ignore[union-attr]
        scale_shape = (num_seqs, self.num_kv_heads)
        q_scale = layer._q_scale.expand(scale_shape)
        k_scale = layer._k_scale.expand(scale_shape)
        v_scale = layer._v_scale.expand(scale_shape)

        # Call flash attention directly on the Q, K, V tensors.
        flash_attn_varlen_func(
            q=query,
            k=key,
            v=value,
            out=output,
            cu_seqlens_q=seq_boundaries,
            cu_seqlens_k=seq_boundaries,
            max_seqlen_q=longest_seq,
            max_seqlen_k=longest_seq,
            softmax_scale=self.scale,
            causal=False,  # Encoder attention is bidirectional
            alibi_slopes=self.alibi_slopes,
            window_size=self.sliding_window,
            softcap=self.logits_soft_cap,
            fa_version=self.vllm_flash_attn_version,
            q_descale=q_scale,
            k_descale=k_scale,
            v_descale=v_scale,
        )

        return output


def use_cascade_attention(
    common_prefix_len: int,
    query_lens: np.ndarray,
    num_query_heads: int,
    num_kv_heads: int,
    use_alibi: bool,
    use_sliding_window: bool,
    use_local_attention: bool,
    num_sms: int,
) -> bool:
    """Decide whether to use cascade attention.

    This function 1) checks whether cascade attention is supported with the
    given configuration, and 2) heuristically decides whether using cascade
    attention can improve performance.

    Args:
        common_prefix_len: Length of the common prefix, in tokens. Values
            below 256 always yield ``False``.
        query_lens: Per-request query lengths; ``len(query_lens)`` is taken
            as the number of requests.
        num_query_heads: Number of query heads.
        num_kv_heads: Number of KV heads; ``num_query_heads // num_kv_heads``
            is the number of queries per KV head.
        use_alibi: Whether ALiBi is in use (unsupported -> ``False``).
        use_sliding_window: Whether sliding-window attention is in use
            (unsupported -> ``False``).
        use_local_attention: Whether local attention is in use
            (unsupported -> ``False``).
        num_sms: Number of SMs; used by the performance model as the number
            of CTAs that run per wave.

    Returns:
        ``True`` if cascade attention should be used, ``False`` otherwise.
    """
    # Too short common prefix. Probably not worth using cascade attention.
    # We use an arbitrary threshold of 256 tokens. TODO: Tune this threshold.
    # NOTE(woosuk): This is the common case. We should return False as soon as
    # possible to avoid any unnecessary computation.
    if common_prefix_len < 256:
        return False
    # Cascade attention is currently not supported with these variants.
    if use_alibi or use_sliding_window or use_local_attention:
        return False
    # Too few queries. Probably not worth using cascade attention.
    # We use an arbitrary threshold of 8 queries. TODO: Tune this threshold.
    num_reqs = len(query_lens)
    if num_reqs < 8:
        return False

    # Heuristics to decide whether using cascade attention is beneficial.
    # 1. When FlashDecoding is not used for normal attention, cascade attention
    #    is likely to be faster since it saves memory bandwidth.
    num_queries_per_kv = num_query_heads // num_kv_heads
    # The criteria for using FlashDecoding can be found in the following link:
    # https://github.com/vllm-project/flash-attention/blob/96266b1111111f3d11aabefaf3bacbab6a89d03c/csrc/flash_attn/flash_api.cpp#L535
    # NOTE: use_sliding_window and use_alibi are already known to be falsy
    # here (see the early return above); they are kept so the expression
    # mirrors the referenced criteria.
    use_flash_decoding = (num_queries_per_kv > 1 and not use_sliding_window
                          and not use_alibi and np.all(query_lens == 1))
    if not use_flash_decoding:
        # Use cascade attention.
        return True

    # 2. When FlashDecoding is used for normal attention, it is not clear
    #    whether cascade attention is beneficial, because FlashDecoding can
    #    launch more CTAs than cascade attention.
    #    We use a simple performance model to compare the two methods.
    #    NOTE(woosuk): The performance model is very rough and may not be
    #    accurate.
    # Every query length is 1 on this path, so tokens == requests.
    num_tokens = num_reqs
    # NOTE(woosuk): These are default tile sizes. flash-attn might use
    # different tile sizes (e.g., 64 or 256) depending on the configuration.
    q_tile_size = 128
    kv_tile_size = 128
    num_prefix_tiles = cdiv(common_prefix_len, kv_tile_size)

    # Cascade: each CTA walks every prefix tile, so the modelled time is
    # (number of waves) * (number of prefix tiles).
    cascade_ctas = num_query_heads * cdiv(num_tokens, q_tile_size)
    cascade_waves = cdiv(cascade_ctas, num_sms)
    cascade_time = cascade_waves * num_prefix_tiles

    # FlashDecoding: one CTA per prefix tile, so the modelled time is just
    # the number of waves.
    flash_decoding_ctas = (num_reqs * num_kv_heads *
                           cdiv(num_queries_per_kv, q_tile_size))
    flash_decoding_ctas *= num_prefix_tiles
    flash_decoding_time = cdiv(flash_decoding_ctas, num_sms)

    # Use cascade attention if it is faster than FlashDecoding.
    return cascade_time < flash_decoding_time


def cascade_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    cu_query_lens: torch.Tensor,
    max_query_len: int,
    cu_prefix_query_lens: torch.Tensor,
    prefix_kv_lens: torch.Tensor,
    suffix_kv_lens: torch.Tensor,
    max_kv_len: int,
    softmax_scale: float,
    alibi_slopes: Optional[torch.Tensor],
    sliding_window: tuple[int, int],
    logits_soft_cap: float,
    block_table: torch.Tensor,
    common_prefix_len: int,
    fa_version: int,
    prefix_scheduler_metadata: Optional[torch.Tensor] = None,
    suffix_scheduler_metadata: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Attend to a shared prefix and per-request suffixes, then merge.

    Two attention passes are run over the paged KV cache: a non-causal pass
    over the first ``common_prefix_len`` tokens, using only the first row of
    ``block_table``, and a causal pass over the remaining blocks of every
    request. The two partial results are combined with ``merge_attn_states``,
    which stores the result in ``output``.

    Args:
        output: Destination buffer for the merged attention output.
        query: Query tensor; ``query.shape[0]`` is the number of tokens.
        key_cache: Paged key cache. ``shape[-3]`` is read as the block size
            and ``shape[-2]`` is used as the second descale dimension
            (presumably the number of KV heads -- confirm against the cache
            layout).
        value_cache: Paged value cache.
        cu_query_lens: ``cu_seqlens_q`` for the suffix pass.
        max_query_len: ``max_seqlen_q`` for the suffix pass.
        cu_prefix_query_lens: ``cu_seqlens_q`` for the prefix pass.
        prefix_kv_lens: ``seqused_k`` for the prefix pass.
        suffix_kv_lens: ``seqused_k`` for the suffix pass.
        max_kv_len: Maximum KV length; the suffix pass uses
            ``max_kv_len - common_prefix_len``.
        softmax_scale: Softmax scaling factor.
        alibi_slopes: Must be ``None``.
        sliding_window: Must be ``(-1, -1)``.
        logits_soft_cap: Soft-cap value forwarded to the kernel.
        block_table: Block table; columns before the common-prefix blocks
            are dropped for the suffix pass.
        common_prefix_len: Length of the shared prefix. Must be a positive
            multiple of the block size.
        fa_version: FlashAttention version forwarded to the kernel.
        prefix_scheduler_metadata: Scheduler metadata for the prefix pass.
        suffix_scheduler_metadata: Scheduler metadata for the suffix pass.
        q_descale: Optional query scale, expanded per pass when given.
        k_descale: Optional key scale, expanded per pass when given.
        v_descale: Optional value scale, expanded per pass when given.

    Returns:
        ``output``, after the merged result has been stored in it.

    Raises:
        AssertionError: If ALiBi or a sliding window is requested, or if
            ``common_prefix_len`` is not a positive multiple of the block
            size.
        NotImplementedError: On ROCm, where neither kernel pass is run.
    """
    assert alibi_slopes is None, ("Cascade attention does not support ALiBi.")
    # TODO: Support sliding window.
    assert sliding_window == (-1, -1), (
        "Cascade attention does not support sliding window.")

    num_tokens = query.shape[0]
    block_size = key_cache.shape[-3]
    assert common_prefix_len % block_size == 0
    num_common_kv_blocks = common_prefix_len // block_size
    assert num_common_kv_blocks > 0

    # Both kernel passes are skipped on ROCm. Without them there is nothing
    # to merge, so fail with a clear message instead of an unbound-variable
    # error at the merge step.
    if current_platform.is_rocm():
        raise NotImplementedError(
            "Cascade attention is not supported on ROCm.")

    descale_cols = key_cache.shape[-2]

    def _expanded_descales(
            num_seqs: int) -> dict[str, Optional[torch.Tensor]]:
        """Expand each provided scale to (num_seqs, descale_cols)."""
        shape = (num_seqs, descale_cols)
        scales = (("q_descale", q_descale), ("k_descale", k_descale),
                  ("v_descale", v_descale))
        return {
            name: None if scale is None else scale.expand(shape)
            for name, scale in scales
        }

    # Process shared prefix.
    prefix_output, prefix_lse = flash_attn_varlen_func(
        q=query,
        k=key_cache,
        v=value_cache,
        cu_seqlens_q=cu_prefix_query_lens,
        seqused_k=prefix_kv_lens,
        max_seqlen_q=num_tokens,
        max_seqlen_k=common_prefix_len,
        softmax_scale=softmax_scale,
        causal=False,
        window_size=sliding_window,
        block_table=block_table[:1],
        softcap=logits_soft_cap,
        return_softmax_lse=True,
        scheduler_metadata=prefix_scheduler_metadata,
        fa_version=fa_version,
        **_expanded_descales(cu_prefix_query_lens.shape[0] - 1),
    )

    # Process suffix per query.
    suffix_output, suffix_lse = flash_attn_varlen_func(
        q=query,
        k=key_cache,
        v=value_cache,
        cu_seqlens_q=cu_query_lens,
        seqused_k=suffix_kv_lens,
        max_seqlen_q=max_query_len,
        max_seqlen_k=max_kv_len - common_prefix_len,
        softmax_scale=softmax_scale,
        causal=True,
        window_size=sliding_window,
        block_table=block_table[:, num_common_kv_blocks:],
        softcap=logits_soft_cap,
        return_softmax_lse=True,
        scheduler_metadata=suffix_scheduler_metadata,
        fa_version=fa_version,
        **_expanded_descales(cu_query_lens.shape[0] - 1),
    )

    # Merge prefix and suffix outputs, and store the result in output.
    merge_attn_states(output, prefix_output, prefix_lse, suffix_output,
                      suffix_lse)
    return output