flash_attn.py 42.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
"""Attention layer with FlashAttention."""
from dataclasses import dataclass
5
from typing import Optional, Tuple
6

7
import numpy as np
8
9
import torch

10
import vllm.envs as envs
11
from vllm import _custom_ops as ops
12

13
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
14
15
                                              AttentionMetadata, AttentionType,
                                              is_quantized_kv_cache)
16
from vllm.attention.layer import Attention
17
from vllm.attention.ops.merge_attn_states import merge_attn_states
18
from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
19
                                           get_flash_attn_version,
20
21
                                           is_flash_attn_varlen_func_available)

zhuwenwen's avatar
zhuwenwen committed
22
from vllm.platforms import current_platform
23
if is_flash_attn_varlen_func_available():
zhuwenwen's avatar
zhuwenwen committed
24
25
26
27
28
    if not current_platform.is_rocm():
        from vllm.attention.utils.fa_utils import (flash_attn_varlen_func,
                                                get_scheduler_metadata,
                                                reshape_and_cache_flash)
    else:
29
30
        from vllm.attention.utils.fa_utils import (vllm_flash_attn_varlen_func,
                                               reshape_and_cache_cuda)
31

32
from vllm.config import VllmConfig, get_layers_from_vllm_config
zhuwenwen's avatar
zhuwenwen committed
33

34
from vllm.logger import init_logger
35
from vllm.utils import cdiv
36
37
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
                                              AttentionMetadataBuilder,
38
39
                                              CommonAttentionMetadata,
                                              get_kv_cache_layout)
40
from vllm.v1.kv_cache_interface import AttentionSpec
41

42
43
logger = init_logger(__name__)

44
45
46

class FlashAttentionBackend(AttentionBackend):
    """Backend descriptor for the FlashAttention attention implementation.

    Exposes the implementation, metadata and metadata-builder classes, the
    supported dtypes / head sizes, and the KV-cache shape and stride order.
    The KV-cache helpers are selected once, at class-definition time,
    depending on whether the current platform is ROCm.
    """

    accept_output_buffer: bool = True
    supports_quant_query_input: bool = True

    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        """Return the activation dtypes this backend accepts."""
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return the attention head sizes this backend accepts."""
        return [32, 64, 96, 128, 160, 192, 224, 256]

    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
        """Raise if ``head_size`` is not one of the supported head sizes.

        Raises:
            ValueError: ``head_size`` is not in ``get_supported_head_sizes()``.
        """
        supported_head_sizes = cls.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            # "FlashAttentionBackend" -> "FlashAttention" for the message.
            attn_type = cls.__name__.removesuffix("Backend")
            raise ValueError(
                f"Head size {head_size} is not supported by {attn_type}. "
                f"Supported head sizes are: {supported_head_sizes}. "
                "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
                "FlexAttention backend which supports all head sizes.")

    @staticmethod
    def get_name() -> str:
        """Return the backend's registry name."""
        return "FLASH_ATTN"

    @staticmethod
    def get_impl_cls() -> type["FlashAttentionImpl"]:
        """Return the attention implementation class."""
        return FlashAttentionImpl

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        """Return the per-step attention metadata class."""
        return FlashAttentionMetadata

    @staticmethod
    def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]:
        """Return the class that builds ``FlashAttentionMetadata``."""
        return FlashAttentionMetadataBuilder

    if not current_platform.is_rocm():

        @staticmethod
        def get_kv_cache_shape(
            num_blocks: int,
            block_size: int,
            num_kv_heads: int,
            head_size: int,
            cache_dtype_str: str = "auto",
        ) -> tuple[int, ...]:
            """Return the shape of the combined key/value cache tensor.

            The leading dimension of size 2 is what ``forward`` unbinds into
            the key cache and the value cache.

            Raises:
                ValueError: ``block_size`` is not a multiple of 16.
            """
            if block_size % 16 != 0:
                raise ValueError("Block size must be a multiple of 16.")
            return (2, num_blocks, block_size, num_kv_heads, head_size)

        @staticmethod
        def get_kv_cache_stride_order() -> tuple[int, ...]:
            """Return the permutation from the logical shape to memory layout.

            Raises:
                ValueError: the configured cache layout is neither "NHD" nor
                    "HND".
            """
            # `stride_order` indicates the permutation that gets
            # us from `get_kv_cache_shape` to the actual memory layout we want.
            cache_layout = get_kv_cache_layout()
            if cache_layout == "NHD":
                stride_order = (0, 1, 2, 3, 4)
            elif cache_layout == "HND":
                stride_order = (0, 1, 3, 2, 4)
            else:
                raise ValueError(f"Unknown cache layout format {cache_layout}.")
            return stride_order
    else:

        @staticmethod
        def get_kv_cache_shape(
            num_blocks: int,
            block_size: int,
            num_kv_heads: int,
            head_size: int,
            cache_dtype_str: str = "auto",
        ) -> tuple[tuple[int, ...], tuple[int, ...]]:
            """Return two cache shapes for ROCm instead of one combined shape.

            The two shapes differ only in the order of the last two
            dimensions. NOTE(review): presumably the first is the key cache
            and the second the value cache, matching the
            ``key_cache, value_cache = kv_cache`` unpacking in ``forward`` --
            confirm against the allocator.

            Raises:
                ValueError: ``block_size`` is not a multiple of 16.
            """
            if block_size % 16 != 0:
                raise ValueError("Block size must be a multiple of 16.")
            return (
                (num_blocks, num_kv_heads, block_size, head_size),
                (num_blocks, num_kv_heads, head_size, block_size),
            )

        @staticmethod
        def get_kv_cache_stride_order() -> tuple[tuple[int, ...], tuple[int, ...]]:
            """Return the (key, value) stride-order permutations for ROCm.

            Raises:
                ValueError: the configured cache layout is neither "NHD" nor
                    "HND".
            """
            # `stride_order` indicates the permutation that gets
            # us from `get_kv_cache_shape` to the actual memory layout we want.
            cache_layout = get_kv_cache_layout()
            if cache_layout == "NHD":
                key_stride_order = (0, 1, 2, 3)
                value_stride_order = (0, 1, 2, 3)
            elif cache_layout == "HND":
                key_stride_order = (0, 2, 1, 3)
                value_stride_order = (0, 2, 1, 3)
            else:
                raise ValueError(f"Unknown cache layout format {cache_layout}.")
            return key_stride_order, value_stride_order

    @staticmethod
    def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype:
        """Map an FP8 KV-cache dtype string to the torch dtype used here.

        Only "fp8" and "fp8_e4m3" are recognized.

        Raises:
            ValueError: ``kv_cache_dtype`` is any other string.
        """
        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
            return torch.float8_e4m3fn
        else:
            raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")

149
150
151
152
153
154
155
156
157
158
159

@dataclass
class FlashAttentionMetadata:
    """Per-step inputs consumed by the FlashAttention forward pass."""

    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    num_actual_tokens: int  # Number of tokens excluding padding.
    max_query_len: int
    # Passed to the kernel as the cumulative query sequence lengths.
    query_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor

    # For cascade attention.
    use_cascade: bool
    common_prefix_len: int
    # The three tensors below are None when cascade attention is not used.
    cu_prefix_query_lens: Optional[torch.Tensor]
    prefix_kv_lens: Optional[torch.Tensor]
    suffix_kv_lens: Optional[torch.Tensor]

    # Optional aot scheduling
    scheduler_metadata: Optional[torch.Tensor] = None
    prefix_scheduler_metadata: Optional[torch.Tensor] = None
    # 0 leaves the number of splits to the kernel's own heuristics.
    max_num_splits: int = 0

    causal: bool = True

182

183
184
185
186
187
188
189
190
191
192
193
def _get_sliding_window_configs(
        vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]:
    """Collect the distinct sliding-window settings of the model's layers.

    Every ``Attention`` layer found in ``vllm_config`` must be backed by a
    ``FlashAttentionImpl``; its ``sliding_window`` value is added to the
    returned set.
    """
    windows: set[Optional[tuple[int, int]]] = set()
    attn_layers = get_layers_from_vllm_config(vllm_config, Attention)
    for attn_layer in attn_layers.values():
        impl = attn_layer.impl
        assert isinstance(impl, FlashAttentionImpl)
        windows.add(impl.sliding_window)
    return windows


194
195
class FlashAttentionMetadataBuilder(
        AttentionMetadataBuilder[FlashAttentionMetadata]):
    """Builds ``FlashAttentionMetadata`` from the common per-step metadata.

    When the FlashAttention version is 3, the builder can also precompute
    ("AOT schedule") the kernel's scheduler metadata. With full CUDA graphs
    that metadata is copied into a persistent buffer allocated in
    ``__init__``.
    """

    # FA3:
    # Supports full cudagraphs for all cases.
    #
    # FA2:
    # For FA2, if a graph is captured with max_query_len=1 (which is what we
    # capture by default for num_tokens <= max_num_seqs when there is no
    # spec-decode), then these graphs will not work for mixed prefill-decode
    # (unlike FA3). This is due to special max_query_len=1 packed-GQA handling
    # in FA2.
    # In summary if we are running with spec decodes the graphs would
    # work for mixed prefill-decode and uniform-decode. But for non-spec decodes
    # the graphs would not work for mixed prefill-decode; sorta the inverse
    # of UNIFORM_SINGLE_TOKEN_DECODE.
    # There's probably a better way to describe this using `AttentionCGSupport`
    # but for now just set it to `UNIFORM_BATCH` to get us to drop down
    # to FULL_AND_PIECEWISE.
    # TODO(luka, lucas): audit FA2 as part of:
    #  https://github.com/vllm-project/vllm/issues/22945
    cudagraph_support = (
        AttentionCGSupport.ALWAYS
        if get_flash_attn_version() == 3 or current_platform.is_rocm()
        else AttentionCGSupport.UNIFORM_BATCH)

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        """Cache config handles and set up AOT-scheduling / CUDA-graph state.

        Raises:
            ValueError: full CUDA graphs are combined with AOT scheduling and
                the maximum capture size exceeds 992.
        """
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
        self.model_config = vllm_config.model_config
        self.parallel_config = vllm_config.parallel_config
        self.cache_config = vllm_config.cache_config
        self.compilation_config = vllm_config.compilation_config

        self.num_heads_q = self.model_config.get_num_attention_heads(
            self.parallel_config)
        self.num_heads_kv = self.model_config.get_num_kv_heads(
            self.parallel_config)
        self.kv_cache_dtype = kv_cache_spec.dtype
        self.headdim = self.model_config.get_head_size()
        self.block_size = kv_cache_spec.block_size

        self.max_num_splits = 0  # No upper bound on the number of splits.
        # AOT scheduling is only enabled for FlashAttention version 3.
        self.aot_schedule = (get_flash_attn_version() == 3)

        self.use_full_cuda_graph = \
            self.compilation_config.cudagraph_mode.has_full_cudagraphs()
        self.max_cudagraph_size = self.compilation_config.max_capture_size

        if self.use_full_cuda_graph and self.aot_schedule:
            if self.max_cudagraph_size > 992:
                # This condition derives from FA3's internal heuristic.
                # TODO(woosuk): Support larger cudagraph sizes.
                raise ValueError(
                    "Capture size larger than 992 is not supported for "
                    "full cuda graph.")

            # Persistent buffer that build() copies scheduler metadata into.
            self.scheduler_metadata = torch.zeros(
                vllm_config.scheduler_config.max_num_seqs + 1,
                dtype=torch.int32,
                device=self.device,
            )
            # When using cuda graph, we need to set the upper bound of the
            # number of splits so that large enough intermediate buffers are
            # pre-allocated during capture.
            self.max_num_splits = (
                envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH)

        # Sliding window size to be used with the AOT scheduler will be
        # populated on the first build() call that uses AOT scheduling.
        self.aot_sliding_window: Optional[tuple[int, int]] = None

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> FlashAttentionMetadata:
        """Build the attention metadata for one step.

        Args:
            common_prefix_len: Length of the prefix shared by all requests;
                a value greater than 0 enables cascade attention.
            common_attn_metadata: Backend-independent per-step metadata.
            fast_build: Disables AOT scheduling; used when there will be few
                iterations, i.e. spec-decode.

        Returns:
            The populated ``FlashAttentionMetadata``.
        """
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        max_query_len = common_attn_metadata.max_query_len
        max_seq_len = common_attn_metadata.max_seq_len
        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens
        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
        block_table_tensor = common_attn_metadata.block_table_tensor
        slot_mapping = common_attn_metadata.slot_mapping
        causal = common_attn_metadata.causal

        # the overhead of the aot schedule is not worth it for spec-decode
        aot_schedule = self.aot_schedule and not fast_build

        # For the AOT scheduler we need the sliding window value to be
        # constant for all layers. We have to populate this lazily in build()
        # so the layers are constructed (cannot populate in __init__).
        # The value is resolved on the first build that actually uses AOT
        # scheduling: latching it during a fast_build call would leave the
        # default (-1, -1) in place for every later AOT-scheduled build.
        if aot_schedule and self.aot_sliding_window is None:
            self.aot_sliding_window = (-1, -1)
            sliding_window_configs = _get_sliding_window_configs(
                self.vllm_config)
            if len(sliding_window_configs) == 1:
                sliding_window_config = sliding_window_configs.pop()
                if sliding_window_config is not None:
                    self.aot_sliding_window = sliding_window_config
            elif len(sliding_window_configs) > 1:
                # Layers disagree on the window, so AOT scheduling is
                # disabled for this builder from now on.
                self.aot_schedule = False
                aot_schedule = False

        max_num_splits = 0  # 0 means use FA3's heuristics, not CG compatible
        if (self.use_full_cuda_graph
                and num_actual_tokens <= self.max_cudagraph_size):
            # NOTE(woosuk): Setting num_splits > 1 may increase the memory
            # usage, because the intermediate buffers of size [num_splits,
            # num_heads, num_tokens, head_size] are allocated. Therefore,
            # we only set num_splits when using cuda graphs.
            max_num_splits = self.max_num_splits

        def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                     max_seq_len, causal):
            """Return AOT scheduler metadata, or None if AOT is disabled."""
            cache_dtype = self.cache_config.cache_dtype
            if cache_dtype.startswith("fp8"):
                qkv_dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
                    cache_dtype)
            else:
                qkv_dtype = self.kv_cache_dtype
            if aot_schedule:
                return get_scheduler_metadata(
                    batch_size=batch_size,
                    max_seqlen_q=max_query_len,
                    max_seqlen_k=max_seq_len,
                    num_heads_q=self.num_heads_q,
                    num_heads_kv=self.num_heads_kv,
                    headdim=self.headdim,
                    cache_seqlens=seqlens,
                    qkv_dtype=qkv_dtype,
                    cu_seqlens_q=cu_query_lens,
                    page_size=self.block_size,
                    causal=causal,
                    window_size=self.aot_sliding_window,
                    num_splits=max_num_splits,
                )
            return None

        use_cascade = common_prefix_len > 0

        if use_cascade:
            # Prefix pass: all tokens form one query sequence attending to
            # the shared prefix (non-causal).
            cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
                                                dtype=torch.int32,
                                                device=self.device)
            prefix_kv_lens = torch.tensor([common_prefix_len],
                                          dtype=torch.int32,
                                          device=self.device)
            # Suffix pass: each request attends to what follows the prefix.
            suffix_kv_lens = (seq_lens_cpu[:num_reqs] - common_prefix_len).to(
                self.device, non_blocking=True)
            prefix_scheduler_metadata = schedule(
                batch_size=1,
                cu_query_lens=cu_prefix_query_lens,
                max_query_len=num_actual_tokens,
                seqlens=prefix_kv_lens,
                max_seq_len=common_prefix_len,
                causal=False)
            scheduler_metadata = schedule(batch_size=num_reqs,
                                          cu_query_lens=query_start_loc,
                                          max_query_len=max_query_len,
                                          seqlens=suffix_kv_lens,
                                          max_seq_len=max_seq_len -
                                          common_prefix_len,
                                          causal=True)
        else:
            cu_prefix_query_lens = None
            prefix_kv_lens = None
            suffix_kv_lens = None
            prefix_scheduler_metadata = None
            scheduler_metadata = schedule(batch_size=num_reqs,
                                          cu_query_lens=query_start_loc,
                                          max_query_len=max_query_len,
                                          seqlens=seq_lens,
                                          max_seq_len=max_seq_len,
                                          causal=causal)

        # For FA3 + full cudagraph
        if self.use_full_cuda_graph and scheduler_metadata is not None:
            n = scheduler_metadata.shape[0]
            self.scheduler_metadata[:n] = scheduler_metadata
            # NOTE(woosuk): We should zero out the rest of the scheduler
            # metadata to guarantee the correctness. Otherwise, some thread
            # blocks may use the invalid scheduler metadata and overwrite the
            # output buffer.
            self.scheduler_metadata[n:] = 0
            scheduler_metadata = self.scheduler_metadata[:n]

        attn_metadata = FlashAttentionMetadata(
            num_actual_tokens=num_actual_tokens,
            max_query_len=max_query_len,
            query_start_loc=query_start_loc,
            max_seq_len=max_seq_len,
            seq_lens=seq_lens,
            block_table=block_table_tensor,
            slot_mapping=slot_mapping,
            use_cascade=use_cascade,
            common_prefix_len=common_prefix_len,
            scheduler_metadata=scheduler_metadata,
            cu_prefix_query_lens=cu_prefix_query_lens,
            prefix_kv_lens=prefix_kv_lens,
            suffix_kv_lens=suffix_kv_lens,
            prefix_scheduler_metadata=prefix_scheduler_metadata,
            max_num_splits=max_num_splits,
            causal=causal)
        return attn_metadata

    def use_cascade_attention(self, *args, **kwargs) -> bool:
        """Forward to the module-level ``use_cascade_attention`` function."""
        return use_cascade_attention(*args, **kwargs)

406

407
408
409
410
411
412
413
414
class FlashAttentionImpl(AttentionImpl):

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        logits_soft_cap: Optional[float] = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        sinks: Optional[torch.Tensor] = None,
    ) -> None:
        """Store the layer's attention configuration and validate it.

        Args:
            num_heads: Number of query heads.
            head_size: Per-head dimension; must be a supported head size.
            scale: Softmax scale, stored as a Python float.
            num_kv_heads: Number of key/value heads.
            alibi_slopes: Optional ALiBi slopes; converted to a float32
                tensor when given.
            sliding_window: Optional window size; converted to a
                ``(left, right)`` pair, with ``(-1, -1)`` when it is None.
            kv_cache_dtype: KV-cache dtype string.
            logits_soft_cap: Optional soft cap; None is stored as 0.
            attn_type: Kind of attention this layer performs.
            kv_sharing_target_layer_name: Name of an earlier layer whose KV
                cache is shared, if any.
            sinks: Optional tensor whose first dimension must equal
                ``num_heads``; requires FlashAttention version 3.

        Raises:
            ValueError: ``head_size`` is not supported.
            NotImplementedError: a quantized KV cache is requested but
                FlashAttention does not support fp8 on this device.
            AssertionError: ``sinks`` is given with a FlashAttention version
                other than 3, or with a mismatching number of heads.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        if sliding_window is None:
            self.sliding_window = (-1, -1)
        elif attn_type == AttentionType.ENCODER_ONLY:
            # Encoder-only layers use a window on both sides.
            self.sliding_window = (sliding_window - 1, sliding_window - 1)
        else:
            # Other attention types only look backwards.
            self.sliding_window = (sliding_window - 1, 0)
        self.kv_cache_dtype = kv_cache_dtype
        if logits_soft_cap is None:
            # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
            logits_soft_cap = 0
        self.logits_soft_cap = logits_soft_cap
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        FlashAttentionBackend.validate_head_size(head_size)

        self.attn_type = attn_type
        self.vllm_flash_attn_version = get_flash_attn_version()
        if (is_quantized_kv_cache(self.kv_cache_dtype)
                and not flash_attn_supports_fp8()):
            raise NotImplementedError(
                "FlashAttention does not support fp8 kv-cache on this device.")

        self.sinks = sinks
        if self.sinks is not None:
            assert self.vllm_flash_attn_version == 3, (
                "Sinks are only supported in FlashAttention 3")
            assert self.sinks.shape[0] == num_heads, (
                "Sinks must have the same number of heads as the number of "
                "heads in the layer")

462
463
    def forward(
        self,
464
        layer: torch.nn.Module,
465
466
467
468
469
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
470
        output: Optional[torch.Tensor] = None,
471
        output_scale: Optional[torch.Tensor] = None,
472
        output_block_scale: Optional[torch.Tensor] = None,
473
474
475
476
    ) -> torch.Tensor:
        """Forward pass with FlashAttention.

        Args:
477
478
479
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
480
481
            kv_cache: shape =
                [2, num_blocks, block_size, num_kv_heads, head_size]
482
483
484
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
485
486
487
        NOTE: FP8 quantization, flash-attn expect the size of
              {q,k,v}_descale to be (num_sequences, num_kv_heads).
              We use torch's .expand() to avoid duplicating values
488
        """
489
490
        assert output is not None, "Output tensor must be provided."

491
        if output_scale is not None or output_block_scale is not None:
492
493
494
495
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for FlashAttentionImpl")

496
497
498
499
        if attn_metadata is None:
            # Profiling run.
            return output

500
501
        attn_type = self.attn_type

502
503
504
505
506
507
508
509
        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
        # in this method. For example, `view` and `slice` (or `[:n]`) operations
        # are surprisingly slow even in the case they do not invoke any GPU ops.
        # Minimize the PyTorch ops in this method as much as possible.
        # Whenever making a change in this method, please benchmark the
        # performance to make sure it does not introduce any overhead.
510

511
        num_actual_tokens = attn_metadata.num_actual_tokens
512
513

        # Handle encoder attention differently - no KV cache needed
514
        if attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
515
516
517
518
519
520
521
522
523
            # For encoder attention,
            # we use direct Q, K, V tensors without caching
            return self._forward_encoder_attention(query[:num_actual_tokens],
                                                   key[:num_actual_tokens],
                                                   value[:num_actual_tokens],
                                                   output[:num_actual_tokens],
                                                   attn_metadata, layer)

        # For decoder and cross-attention, use KV cache as before
524
525
526
527
        if not current_platform.is_rocm():
            key_cache, value_cache = kv_cache.unbind(0)
        else:
            key_cache, value_cache = kv_cache
528

529
530
531
532
533
        # key and value may be None in the case of cross attention. They are
        # calculated once based on the output from the encoder and then cached
        # in KV cache.
        if (self.kv_sharing_target_layer_name is None and key is not None
                and value is not None):
534
535
536
537
538
539
540
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
            # not padded. However, we don't need to do key[:num_actual_tokens]
            # and value[:num_actual_tokens] because the reshape_and_cache_flash
            # op uses the slot_mapping's shape to determine the number of
            # actual tokens.
541
542
543
544
545
546
547
548
549
550
551
552
            if not current_platform.is_rocm():
                reshape_and_cache_flash(
                    key,
                    value,
                    key_cache,
                    value_cache,
                    attn_metadata.slot_mapping,
                    self.kv_cache_dtype,
                    layer._k_scale,
                    layer._v_scale,
                )
            else:
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
                if envs.VLLM_USE_OPT_RESHAPE_AND_CACHE and key.dtype == value.dtype == torch.float16:
                    from lightop import reshape_and_cache_cuda
                    reshape_and_cache_cuda(
                        key, 
                        value,
                        key_cache, 
                        value_cache,
                        attn_metadata.slot_mapping,
                        self.kv_cache_dtype,
                        layer._k_scale, 
                        layer._v_scale
                    )
                else:
                    from vllm.attention.utils.fa_utils import reshape_and_cache_cuda
                    reshape_and_cache_cuda(
                        key,
                        value,
                        key_cache,
                        value_cache,
                        attn_metadata.slot_mapping,
                        self.kv_cache_dtype,
                        layer._k_scale,
                        layer._v_scale,
                    )
577

578
        if self.kv_cache_dtype.startswith("fp8"):
579
            # queries are quantized in the attention layer
580
581
582
583
            dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
                self.kv_cache_dtype)
            key_cache = key_cache.view(dtype)
            value_cache = value_cache.view(dtype)
584

585
586
587
588
589
590
591
        if not attn_metadata.use_cascade:
            cu_seqlens_q = attn_metadata.query_start_loc
            seqused_k = attn_metadata.seq_lens
            max_seqlen_q = attn_metadata.max_query_len
            max_seqlen_k = attn_metadata.max_seq_len
            block_table = attn_metadata.block_table
            scheduler_metadata = attn_metadata.scheduler_metadata
592

593
            descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)
594

zhuwenwen's avatar
zhuwenwen committed
595
            if not current_platform.is_rocm():
596
                 flash_attn_varlen_func(
zhuwenwen's avatar
zhuwenwen committed
597
598
599
600
601
602
603
604
605
                    q=query[:num_actual_tokens],
                    k=key_cache,
                    v=value_cache,
                    out=output[:num_actual_tokens],
                    cu_seqlens_q=cu_seqlens_q,
                    max_seqlen_q=max_seqlen_q,
                    seqused_k=seqused_k,
                    max_seqlen_k=max_seqlen_k,
                    softmax_scale=self.scale,
606
                    causal=attn_metadata.causal,
zhuwenwen's avatar
zhuwenwen committed
607
608
609
610
611
612
613
614
615
616
                    alibi_slopes=self.alibi_slopes,
                    window_size=self.sliding_window,
                    block_table=block_table,
                    softcap=self.logits_soft_cap,
                    scheduler_metadata=scheduler_metadata,
                    fa_version=self.vllm_flash_attn_version,
                    q_descale=layer._q_scale.expand(descale_shape),
                    k_descale=layer._k_scale.expand(descale_shape),
                    v_descale=layer._v_scale.expand(descale_shape),
                    num_splits=attn_metadata.max_num_splits,
617
                    s_aux=self.sinks,
zhuwenwen's avatar
zhuwenwen committed
618
619
                )
            else:
620
621
                if envs.VLLM_USE_PA_PRINT_PARAM:
                    print("PA SIZE:")
622
                    print(f"q.shape = {query[:num_actual_tokens].shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}")
623
624
                    print(f"cu_seqlens_q.shape = {cu_seqlens_q.shape}, max_seqlen_q = {max_seqlen_q}, seqused_k.shape = {seqused_k.shape}, max_seqlen_k = {max_seqlen_k}")
                    print(f"softmax_scale = {self.scale:.3f}, alibi_slopes = {self.alibi_slopes}, window_size = {self.sliding_window}, block_tables.shape = {block_table.shape}, softcap = {self.logits_soft_cap}, scheduler_metadata = {scheduler_metadata}")
zhuwenwen's avatar
zhuwenwen committed
625
626
627
628
629
630
631
632
633
634
                vllm_flash_attn_varlen_func(
                    q=query[:num_actual_tokens],
                    k=key_cache,
                    v=value_cache,
                    out=output[:num_actual_tokens],
                    cu_seqlens_q=cu_seqlens_q,
                    max_seqlen_q=max_seqlen_q,
                    seqused_k=seqused_k,
                    max_seqlen_k=max_seqlen_k,
                    softmax_scale=self.scale,
635
                    causal=attn_metadata.causal,
zhuwenwen's avatar
zhuwenwen committed
636
637
638
639
640
641
642
643
644
645
                    alibi_slopes=self.alibi_slopes,
                    window_size=self.sliding_window,
                    block_table=block_table,
                    softcap=self.logits_soft_cap,
                    scheduler_metadata=scheduler_metadata,
                    # fa_version=self.vllm_flash_attn_version,
                    # q_descale=layer._q_scale.expand(descale_shape),
                    # k_descale=layer._k_scale.expand(descale_shape),
                    # v_descale=layer._v_scale.expand(descale_shape),
                    # num_splits=attn_metadata.max_num_splits,
646
                    # s_aux=self.sinks,
647
                    is_prefix_cache=True,
zhuwenwen's avatar
zhuwenwen committed
648
                )
649
            return output
650
651

        # Cascade attention (rare case).
zhuwenwen's avatar
zhuwenwen committed
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
        if not current_platform.is_rocm():
            cascade_attention(
                output[:num_actual_tokens],
                query[:num_actual_tokens],
                key_cache,
                value_cache,
                cu_query_lens=attn_metadata.query_start_loc,
                max_query_len=attn_metadata.max_query_len,
                cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens,
                prefix_kv_lens=attn_metadata.prefix_kv_lens,
                suffix_kv_lens=attn_metadata.suffix_kv_lens,
                max_kv_len=attn_metadata.max_seq_len,
                softmax_scale=self.scale,
                alibi_slopes=self.alibi_slopes,
                sliding_window=self.sliding_window,
                logits_soft_cap=self.logits_soft_cap,
                block_table=attn_metadata.block_table,
                common_prefix_len=attn_metadata.common_prefix_len,
                fa_version=self.vllm_flash_attn_version,
                prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
                suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
                q_descale=layer._q_scale,
                k_descale=layer._k_scale,
                v_descale=layer._v_scale,
            )
        else:
            cascade_attention(
                output[:num_actual_tokens],
                query[:num_actual_tokens],
                key_cache,
                value_cache,
                cu_query_lens=attn_metadata.query_start_loc,
                max_query_len=attn_metadata.max_query_len,
                cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens,
                prefix_kv_lens=attn_metadata.prefix_kv_lens,
                suffix_kv_lens=attn_metadata.suffix_kv_lens,
                max_kv_len=attn_metadata.max_seq_len,
                softmax_scale=self.scale,
                alibi_slopes=self.alibi_slopes,
                sliding_window=self.sliding_window,
                logits_soft_cap=self.logits_soft_cap,
                block_table=attn_metadata.block_table,
                common_prefix_len=attn_metadata.common_prefix_len,
                fa_version=2, #self.vllm_flash_attn_version,
                prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
                suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
                # q_descale=layer._q_scale,
                # k_descale=layer._k_scale,
                # v_descale=layer._v_scale,
            )
702
        return output
703

704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
    def _forward_encoder_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        output: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        layer: torch.nn.Module,
    ) -> torch.Tensor:
        """Forward pass for encoder attention without KV cache.

        The varlen flash-attention kernel is invoked directly on the given
        Q/K/V tensors with ``causal=False``. ``output`` is handed to the
        kernel as its ``out`` buffer and is then returned.

        Args:
            query: shape = [num_encoder_tokens, num_heads, head_size]
            key: shape = [num_encoder_tokens, num_kv_heads, head_size]
            value: shape = [num_encoder_tokens, num_kv_heads, head_size]
            output: shape = [num_encoder_tokens, num_heads, head_size]
            attn_metadata: Encoder attention metadata; only
                ``query_start_loc`` and ``max_query_len`` are read here.
            layer: The attention layer. Its ``_q_scale``/``_k_scale``/
                ``_v_scale`` are forwarded as descale factors on non-ROCm
                platforms only.

        Returns:
            The ``output`` tensor that was passed in.

        Raises:
            NotImplementedError: If ``self.kv_cache_dtype`` starts with
                ``"fp8"``.
        """
        # FP8 KV-cache quantization is rejected up front: it is not
        # supported for encoder attention.
        if self.kv_cache_dtype.startswith("fp8"):
            raise NotImplementedError(
                "quantization is not supported for encoder attention")

        # Use encoder-specific metadata for sequence information. The key
        # side reuses the query-side values (same start offsets, same max
        # length).
        cu_seqlens_q = attn_metadata.query_start_loc
        cu_seqlens_k = attn_metadata.query_start_loc
        max_seqlen_q = attn_metadata.max_query_len
        max_seqlen_k = attn_metadata.max_query_len

        # (number of sequences, number of KV heads). Only the non-ROCm call
        # consumes this; it is still computed unconditionally so that a bad
        # ``query_start_loc`` fails here on every platform.
        descale_shape = (
            cu_seqlens_q.shape[0] - 1,  # type: ignore[union-attr]
            self.num_kv_heads)

        # Keyword arguments shared by both platform-specific kernels.
        common_kwargs = dict(
            q=query,
            k=key,
            v=value,
            out=output,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlen_k,
            softmax_scale=self.scale,
            causal=False,  # Encoder attention is bidirectional
            alibi_slopes=self.alibi_slopes,
            window_size=self.sliding_window,
            softcap=self.logits_soft_cap,
        )

        # Call flash attention directly on Q, K, V tensors.
        if not current_platform.is_rocm():
            flash_attn_varlen_func(
                **common_kwargs,
                fa_version=self.vllm_flash_attn_version,
                q_descale=layer._q_scale.expand(descale_shape),
                k_descale=layer._k_scale.expand(descale_shape),
                v_descale=layer._v_scale.expand(descale_shape),
            )
        else:
            # The ROCm wrapper is called without ``fa_version`` or descale
            # arguments and with ``is_prefix_cache=True``.
            vllm_flash_attn_varlen_func(
                **common_kwargs,
                is_prefix_cache=True,
            )

        return output

783
784
785
786
787
788
789
790

def use_cascade_attention(
    common_prefix_len: int,
    query_lens: np.ndarray,
    num_query_heads: int,
    num_kv_heads: int,
    use_alibi: bool,
    use_sliding_window: bool,
    use_local_attention: bool,
    num_sms: int,
) -> bool:
    """Decide whether to use cascade attention.

    This function 1) checks whether cascade attention is supported with the
    given configuration, and 2) heuristically decides whether using cascade
    attention can improve performance.

    Args:
        common_prefix_len: Length (in tokens) of the common prefix.
        query_lens: Per-request query lengths; its length is taken as the
            number of requests.
        num_query_heads: Number of query heads.
        num_kv_heads: Number of KV heads.
        use_alibi: Whether ALiBi is in use (cascade is then unsupported).
        use_sliding_window: Whether sliding-window attention is in use
            (cascade is then unsupported).
        use_local_attention: Whether local attention is in use (cascade is
            then unsupported).
        num_sms: Number of SMs, used to convert CTA counts into waves.

    Returns:
        True if cascade attention should be used, False otherwise.
    """
    # Too short common prefix. Probably not worth using cascade attention.
    # We use an arbitrary threshold of 256 tokens. TODO: Tune this threshold.
    # NOTE(woosuk): This is the common case. We should return False as soon as
    # possible to avoid any unnecessary computation.
    if common_prefix_len < 256:
        return False
    # Cascade attention is currently not supported with these variants.
    if use_alibi or use_sliding_window or use_local_attention:
        return False
    # Too few queries. Probably not worth using cascade attention.
    # We use an arbitrary threshold of 8 queries. TODO: Tune this threshold.
    num_reqs = len(query_lens)
    if num_reqs < 8:
        return False

    # Heuristics to decide whether using cascade attention is beneficial.
    # 1. When FlashDecoding is not used for normal attention, cascade attention
    #    is likely to be faster since it saves memory bandwidth.
    num_queries_per_kv = num_query_heads // num_kv_heads
    # The criteria for using FlashDecoding can be found in the following link:
    # https://github.com/vllm-project/flash-attention/blob/96266b1111111f3d11aabefaf3bacbab6a89d03c/csrc/flash_attn/flash_api.cpp#L535
    # NOTE: the sliding-window / ALiBi terms are always true at this point
    # (both cases returned above); they are kept so the expression mirrors
    # the referenced criteria.
    use_flash_decoding = (num_queries_per_kv > 1 and not use_sliding_window
                          and not use_alibi and np.all(query_lens == 1))
    if not use_flash_decoding:
        # Use cascade attention.
        return True

    # 2. When FlashDecoding is used for normal attention, it is not clear
    #    whether cascade attention is beneficial, because FlashDecoding can
    #    launch more CTAs than cascade attention.
    #    We use a simple performance model to compare the two methods.
    #    NOTE(woosuk): The performance model is very rough and may not be
    #    accurate.
    # Every query has length 1 on this path, so tokens == requests.
    num_tokens = num_reqs
    # NOTE(woosuk): These are default tile sizes. flash-attn might use
    # different tile sizes (e.g., 64 or 256) depending on the configuration.
    q_tile_size = 128
    kv_tile_size = 128
    num_prefix_tiles = cdiv(common_prefix_len, kv_tile_size)

    # Cascade: each wave of CTAs walks over all prefix tiles.
    cascade_ctas = num_query_heads * cdiv(num_tokens, q_tile_size)
    cascade_waves = cdiv(cascade_ctas, num_sms)
    cascade_time = cascade_waves * num_prefix_tiles

    # FlashDecoding: one CTA per (request, KV head, query tile, prefix tile).
    flash_decoding_ctas = (num_reqs * num_kv_heads *
                           cdiv(num_queries_per_kv, q_tile_size))
    flash_decoding_ctas *= num_prefix_tiles
    flash_decoding_time = cdiv(flash_decoding_ctas, num_sms)

    # Use cascade attention if it is faster than FlashDecoding.
    return cascade_time < flash_decoding_time


def cascade_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    cu_query_lens: torch.Tensor,
    max_query_len: int,
    cu_prefix_query_lens: torch.Tensor,
    prefix_kv_lens: torch.Tensor,
    suffix_kv_lens: torch.Tensor,
    max_kv_len: int,
    softmax_scale: float,
    alibi_slopes: Optional[torch.Tensor],
    sliding_window: tuple[int, int],
    logits_soft_cap: float,
    block_table: torch.Tensor,
    common_prefix_len: int,
    fa_version: int,
    prefix_scheduler_metadata: Optional[torch.Tensor] = None,
    suffix_scheduler_metadata: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run cascade attention: shared prefix pass + per-query suffix pass.

    Two varlen flash-attention calls are made, one non-causal over the
    common prefix and one causal over the remaining KV blocks, each
    returning its softmax LSE. The two partial results are then combined
    by ``merge_attn_states`` with ``output`` as its first argument.

    Args:
        output: Destination passed to ``merge_attn_states``.
        query: Query tensor; ``query.shape[0]`` is used as the token count.
        key_cache: Key cache; ``shape[-3]`` is read as the block size and
            ``shape[-2]`` as the number of KV heads.
        value_cache: Value cache.
        cu_query_lens: ``cu_seqlens_q`` for the suffix pass.
        max_query_len: ``max_seqlen_q`` for the suffix pass.
        cu_prefix_query_lens: ``cu_seqlens_q`` for the prefix pass.
        prefix_kv_lens: ``seqused_k`` for the prefix pass.
        suffix_kv_lens: ``seqused_k`` for the suffix pass.
        max_kv_len: Maximum KV length; the suffix pass uses
            ``max_kv_len - common_prefix_len``.
        softmax_scale: Softmax scaling factor.
        alibi_slopes: Must be ``None``.
        sliding_window: Must be ``(-1, -1)``.
        logits_soft_cap: Soft-cap value forwarded as ``softcap``.
        block_table: Block table; the prefix pass uses its first row, the
            suffix pass drops the leading common-prefix block columns.
        common_prefix_len: Length of the common prefix; must be a positive
            multiple of the block size.
        fa_version: FlashAttention version (forwarded on non-ROCm only).
        prefix_scheduler_metadata: Scheduler metadata for the prefix pass.
        suffix_scheduler_metadata: Scheduler metadata for the suffix pass.
        q_descale: Optional query descale (forwarded on non-ROCm only).
        k_descale: Optional key descale (forwarded on non-ROCm only).
        v_descale: Optional value descale (forwarded on non-ROCm only).

    Raises:
        AssertionError: If ALiBi slopes or a sliding window are given, or
            if ``common_prefix_len`` is not a positive multiple of the
            block size.

    NOTE(review): no ``return`` statement follows the merge in the visible
    body, although the signature is annotated ``-> torch.Tensor``; callers
    should read the result from ``output``.
    """
    assert alibi_slopes is None, ("Cascade attention does not support ALiBi.")
    # TODO: Support sliding window.
    assert sliding_window == (-1, -1), (
        "Cascade attention does not support sliding window.")

    num_tokens = query.shape[0]
    block_size = key_cache.shape[-3]
    assert common_prefix_len % block_size == 0
    num_common_kv_blocks = common_prefix_len // block_size
    assert num_common_kv_blocks > 0

    num_kv_heads = key_cache.shape[-2]
    on_rocm = current_platform.is_rocm()

    def _descale_kwargs(num_seqs: int) -> dict:
        """Expand each given descale to (num_seqs, num_kv_heads)."""
        descale_shape = (num_seqs, num_kv_heads)
        scales = (("q_descale", q_descale), ("k_descale", k_descale),
                  ("v_descale", v_descale))
        return {
            name: scale.expand(descale_shape) if scale is not None else None
            for name, scale in scales
        }

    # Process shared prefix: non-causal, using only the first row of the
    # block table.
    prefix_kwargs = dict(
        q=query,
        k=key_cache,
        v=value_cache,
        cu_seqlens_q=cu_prefix_query_lens,
        seqused_k=prefix_kv_lens,
        max_seqlen_q=num_tokens,
        max_seqlen_k=common_prefix_len,
        softmax_scale=softmax_scale,
        causal=False,
        window_size=sliding_window,
        block_table=block_table[:1],
        softcap=logits_soft_cap,
        return_softmax_lse=True,
        scheduler_metadata=prefix_scheduler_metadata,
    )
    if not on_rocm:
        prefix_output, prefix_lse = flash_attn_varlen_func(
            **prefix_kwargs,
            fa_version=fa_version,
            **_descale_kwargs(cu_prefix_query_lens.shape[0] - 1),
        )
    else:
        # The ROCm wrapper yields a third value, which is not needed here,
        # and is called without fa_version / descale arguments.
        prefix_output, prefix_lse, _ = vllm_flash_attn_varlen_func(
            **prefix_kwargs,
            is_prefix_cache=True,
        )

    # Process suffix per query: causal, skipping the block-table columns
    # that hold the common prefix.
    suffix_kwargs = dict(
        q=query,
        k=key_cache,
        v=value_cache,
        cu_seqlens_q=cu_query_lens,
        seqused_k=suffix_kv_lens,
        max_seqlen_q=max_query_len,
        max_seqlen_k=max_kv_len - common_prefix_len,
        softmax_scale=softmax_scale,
        causal=True,
        window_size=sliding_window,
        block_table=block_table[:, num_common_kv_blocks:],
        softcap=logits_soft_cap,
        return_softmax_lse=True,
        scheduler_metadata=suffix_scheduler_metadata,
    )
    if not on_rocm:
        suffix_output, suffix_lse = flash_attn_varlen_func(
            **suffix_kwargs,
            fa_version=fa_version,
            **_descale_kwargs(cu_query_lens.shape[0] - 1),
        )
    else:
        suffix_output, suffix_lse, _ = vllm_flash_attn_varlen_func(
            **suffix_kwargs,
            is_prefix_cache=True,
        )

    # Merge prefix and suffix outputs, and store the result in output.
    merge_attn_states(output, prefix_output, prefix_lse, suffix_output,
                      suffix_lse)