flash_attn.py 40.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
"""Attention layer with FlashAttention."""
from dataclasses import dataclass
zhuwenwen's avatar
zhuwenwen committed
5
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Tuple
6

7
import numpy as np
8
9
import torch

10
import vllm.envs as envs
11
from vllm import _custom_ops as ops
12
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
13
14
                                              AttentionMetadata, AttentionType,
                                              is_quantized_kv_cache)
15
from vllm.attention.layer import Attention
16
from vllm.attention.ops.merge_attn_states import merge_attn_states
17
from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
18
                                           get_flash_attn_version,
19
20
                                           is_flash_attn_varlen_func_available)

zhuwenwen's avatar
zhuwenwen committed
21
from vllm.platforms import current_platform
22
if is_flash_attn_varlen_func_available():
zhuwenwen's avatar
zhuwenwen committed
23
24
25
26
27
    if not current_platform.is_rocm():
        from vllm.attention.utils.fa_utils import (flash_attn_varlen_func,
                                                get_scheduler_metadata,
                                                reshape_and_cache_flash)
    else:
zhuwenwen's avatar
zhuwenwen committed
28
29
        from vllm.attention.utils.fa_utils import (vllm_flash_attn_varlen_func,
                                               reshape_and_cache_cuda)
30

31
from vllm.config import VllmConfig, get_layers_from_vllm_config
zhuwenwen's avatar
zhuwenwen committed
32

33
from vllm.logger import init_logger
34
from vllm.utils import cdiv
35
36
37
from vllm.v1.attention.backends.utils import (
    AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout,
    make_local_attention_virtual_batches)
38
39
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
40

41
42
43
if TYPE_CHECKING:
    from vllm.v1.worker.gpu_model_runner import GPUModelRunner

44
45
logger = init_logger(__name__)

46
47
48
# NOTE(woosuk): This is an arbitrary number. Tune it if needed.
# Upper bound on the number of attention splits that the metadata builder
# configures when full CUDA graph capture is enabled, so that intermediate
# buffers pre-allocated during capture are large enough.
_DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH = 16

49
50
51

class FlashAttentionBackend(AttentionBackend):
    """Backend descriptor for the V1 FlashAttention implementation.

    Exposes the implementation, metadata and metadata-builder classes and
    describes the KV-cache shape/stride order. The KV-cache layout helpers
    are selected once, at class-definition time, depending on whether the
    current platform is ROCm.
    """

    accept_output_buffer: bool = True

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return the head sizes accepted by `validate_head_size`."""
        return [32, 64, 96, 128, 160, 192, 224, 256]

    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
        """Raise if `head_size` is not one of the supported head sizes.

        Raises:
            ValueError: `head_size` is not in `get_supported_head_sizes()`.
        """
        supported_head_sizes = cls.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            # e.g. "FlashAttentionBackend" -> "FlashAttention"
            attn_type = cls.__name__.removesuffix("Backend")
            raise ValueError(
                f"Head size {head_size} is not supported by {attn_type}. "
                f"Supported head sizes are: {supported_head_sizes}. "
                "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
                "FlexAttention backend which supports all head sizes.")

    @staticmethod
    def get_name() -> str:
        """Return the registry name of this backend."""
        return "FLASH_ATTN_VLLM_V1"

    @staticmethod
    def get_impl_cls() -> type["FlashAttentionImpl"]:
        """Return the attention implementation class."""
        return FlashAttentionImpl

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        """Return the per-step attention metadata class."""
        return FlashAttentionMetadata

    @staticmethod
    def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]:
        """Return the class that builds the attention metadata."""
        return FlashAttentionMetadataBuilder

    if not current_platform.is_rocm():

        @staticmethod
        def get_kv_cache_shape(
            num_blocks: int,
            block_size: int,
            num_kv_heads: int,
            head_size: int,
        ) -> tuple[int, ...]:
            """Return the shape of the combined key/value cache tensor.

            Raises:
                ValueError: `block_size` is not a multiple of 16.
            """
            if block_size % 16 != 0:
                raise ValueError("Block size must be a multiple of 16.")
            # Leading dimension of size 2 holds the key and value caches.
            return (2, num_blocks, block_size, num_kv_heads, head_size)

        @staticmethod
        def get_kv_cache_stride_order() -> tuple[int, ...]:
            """Return the permutation from the logical shape to memory layout.

            Raises:
                ValueError: the configured cache layout is neither "NHD"
                    nor "HND".
            """
            # `stride_order` indicates the permutation that gets
            # us from `get_kv_cache_shape` to the actual memory layout we want.
            cache_layout = get_kv_cache_layout()
            if cache_layout == "NHD":
                stride_order = (0, 1, 2, 3, 4)
            elif cache_layout == "HND":
                stride_order = (0, 1, 3, 2, 4)
            else:
                raise ValueError(f"Unknown cache layout format {cache_layout}.")
            return stride_order
    else:

        @staticmethod
        def get_kv_cache_shape(
            num_blocks: int,
            block_size: int,
            num_kv_heads: int,
            head_size: int,
        ) -> tuple[tuple[int, ...], tuple[int, ...]]:
            """Return a pair of cache shapes used on ROCm.

            The two shapes differ only in the order of the last two
            dimensions (`block_size`/`head_size`). They presumably describe
            the key cache and the value cache respectively, matching the
            (key, value) order of `get_kv_cache_stride_order`.

            Raises:
                ValueError: `block_size` is not a multiple of 16.
            """
            if block_size % 16 != 0:
                raise ValueError("Block size must be a multiple of 16.")
            return (
                (num_blocks, num_kv_heads, block_size, head_size),
                (num_blocks, num_kv_heads, head_size, block_size),
            )

        @staticmethod
        def get_kv_cache_stride_order(
        ) -> tuple[tuple[int, ...], tuple[int, ...]]:
            """Return the (key, value) stride-order permutations on ROCm.

            Raises:
                ValueError: the configured cache layout is neither "NHD"
                    nor "HND".
            """
            # `stride_order` indicates the permutation that gets
            # us from `get_kv_cache_shape` to the actual memory layout we want.
            cache_layout = get_kv_cache_layout()
            if cache_layout == "NHD":
                key_stride_order = (0, 1, 2, 3)
                value_stride_order = (0, 1, 2, 3)
            elif cache_layout == "HND":
                key_stride_order = (0, 2, 1, 3)
                value_stride_order = (0, 2, 1, 3)
            else:
                raise ValueError(f"Unknown cache layout format {cache_layout}.")
            return key_stride_order, value_stride_order
        
139
140
141
142
143
144
145
146
147
148
149

@dataclass
class FlashAttentionMetadata:
    """Per-step metadata consumed by the FlashAttention implementation."""

    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    num_actual_tokens: int  # Number of tokens excluding padding.
    max_query_len: int  # Passed to the kernel as max_seqlen_q.
    query_start_loc: torch.Tensor  # Passed to the kernel as cu_seqlens_q.
    max_seq_len: int  # Passed to the kernel as max_seqlen_k.
    seq_lens: torch.Tensor  # Passed to the kernel as seqused_k.
    block_table: torch.Tensor
    slot_mapping: torch.Tensor  # Target cache slots for the new keys/values.

    # For cascade attention.
    use_cascade: bool
    common_prefix_len: int
    cu_prefix_query_lens: Optional[torch.Tensor]
    prefix_kv_lens: Optional[torch.Tensor]
    suffix_kv_lens: Optional[torch.Tensor]

    # Optional aot scheduling
    scheduler_metadata: Optional[torch.Tensor] = None
    prefix_scheduler_metadata: Optional[torch.Tensor] = None
    # 0 means no upper bound on the number of splits.
    max_num_splits: int = 0

    # for local attention
    @dataclass
    class LocalAttentionMetadata:
        """Virtual-batch metadata used when local attention is enabled."""

        local_query_start_loc: torch.Tensor
        local_seqused_k: torch.Tensor
        local_block_table: torch.Tensor
        local_max_query_len: int
        local_max_seq_len: int
        local_scheduler_metadata: Optional[torch.Tensor]

    local_attn_metadata: Optional[LocalAttentionMetadata] = None


183
184
185
186
187
188
189
190
191
192
193
def _get_sliding_window_configs(
        vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]:
    """Collect the distinct sliding-window settings of all attention layers.

    Every `Attention` layer found in `vllm_config` must be backed by a
    `FlashAttentionImpl`; its `sliding_window` attribute is added to the
    returned set.
    """
    attn_layers = get_layers_from_vllm_config(vllm_config, Attention)
    windows: set[Optional[tuple[int, int]]] = set()
    for attn_layer in attn_layers.values():
        impl = attn_layer.impl
        assert isinstance(impl, FlashAttentionImpl)
        windows.add(impl.sliding_window)
    return windows


194
195
class FlashAttentionMetadataBuilder(
        AttentionMetadataBuilder[FlashAttentionMetadata]):
    """Builds `FlashAttentionMetadata` for each model step.

    Also prepares the optional ahead-of-time (AOT) scheduler metadata and the
    persistent buffers needed when full CUDA graph capture is enabled.
    """

    # Evaluated once at class-definition time.
    full_cudagraph_supported: ClassVar[bool] = (
        get_flash_attn_version() == 3 or current_platform.is_rocm())

    def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
                 block_table: BlockTable):
        """Cache model/config values and set up full-CUDA-graph state.

        Raises:
            ValueError: full CUDA graph is requested but AOT scheduling is
                unavailable (non-ROCm only), no capture sizes are
                configured, or the largest capture size exceeds 992.
        """
        model_config = runner.model_config
        compilation_config = runner.vllm_config.compilation_config

        self.runner = runner
        self.num_heads_q = model_config.get_num_attention_heads(
            runner.parallel_config)
        self.num_heads_kv = model_config.get_num_kv_heads(
            runner.parallel_config)
        self.headdim = model_config.get_head_size()
        self.block_size = kv_cache_spec.block_size
        self.kv_cache_spec = kv_cache_spec
        self.block_table = block_table

        self.max_num_splits = 0  # No upper bound on the number of splits.
        self.aot_schedule = (get_flash_attn_version() == 3)
        self.use_full_cuda_graph = compilation_config.full_cuda_graph
        if self.use_full_cuda_graph:
            if not current_platform.is_rocm():
                if not self.aot_schedule:
                    raise ValueError(
                        "AoT scheduling is required for full cuda graph.")
            capture_sizes = compilation_config.cudagraph_capture_sizes
            if not capture_sizes:
                raise ValueError(
                    "cudagraph_capture_sizes should not be None when "
                    "full_cuda_graph is True.")
            self.max_cudagraph_size = max(capture_sizes)
            if self.max_cudagraph_size > 992:
                # This condition derives from FA3's internal heuristic.
                # TODO(woosuk): Support larger cudagraph sizes.
                raise ValueError(
                    "Capture size larger than 992 is not supported for "
                    "full cuda graph.")

            # Persistent buffer that build() copies the scheduler metadata
            # into when full CUDA graph is used.
            self.scheduler_metadata = torch.zeros(
                self.runner.max_num_reqs + 1,
                dtype=torch.int32,
                device=self.runner.device,
            )
            # When using cuda graph, we need to set the upper bound of the
            # number of splits so that large enough intermediate buffers are
            # pre-allocated during capture.
            self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH

        # Sliding window size to be used with the AOT scheduler will be
        # populated on first build() call.
        self.aot_sliding_window: Optional[tuple[int, int]] = None

    def build(
        self, common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata
    ) -> FlashAttentionMetadata:
        """Assemble the attention metadata for the current batch.

        Args:
            common_prefix_len: Length of the prefix shared by all requests;
                a value greater than 0 enables cascade attention.
            common_attn_metadata: Backend-independent batch metadata.

        Returns:
            The populated `FlashAttentionMetadata`.
        """
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        max_query_len = common_attn_metadata.max_query_len

        max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens
        block_table = self.block_table
        block_table_tensor = block_table.get_device_tensor()[:num_reqs]

        block_table.slot_mapping[:num_actual_tokens].copy_(
            block_table.slot_mapping_cpu[:num_actual_tokens],
            non_blocking=True)
        # Fill unused with -1. Needed for reshape_and_cache in full cuda graph
        # mode.
        block_table.slot_mapping[num_actual_tokens:].fill_(-1)

        slot_mapping = block_table.slot_mapping[:num_actual_tokens]

        if self.aot_sliding_window is None:
            self.aot_sliding_window = (-1, -1)
            # For the AOT scheduler we need the sliding window value to be
            # constant for all layers. We have to populate this on the first
            # build() call, once the layers are constructed (it cannot be
            # populated in __init__).
            if self.aot_schedule:
                sliding_window_configs = _get_sliding_window_configs(
                    self.runner.vllm_config)
                if len(sliding_window_configs) == 1:
                    sliding_window_config = sliding_window_configs.pop()
                    if sliding_window_config is not None:
                        self.aot_sliding_window = sliding_window_config
                elif len(sliding_window_configs) > 1:
                    # Layers disagree on the window, so AOT scheduling
                    # cannot be used.
                    self.aot_schedule = False

        def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                     max_seq_len, causal):
            """Return AOT scheduler metadata, or None if AOT is disabled."""
            if self.aot_schedule:
                return get_scheduler_metadata(
                    batch_size=batch_size,
                    max_seqlen_q=max_query_len,
                    max_seqlen_k=max_seq_len,
                    cache_seqlens=seqlens,
                    num_heads_q=self.num_heads_q,
                    num_heads_kv=self.num_heads_kv,
                    headdim=self.headdim,
                    page_size=self.block_size,
                    cu_seqlens_q=cu_query_lens,
                    causal=causal,
                    window_size=self.aot_sliding_window,
                    num_splits=self.max_num_splits,
                )
            return None

        # for local attention
        local_attn_metadata = None
        if self.runner.attention_chunk_size is not None:
            seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
                virt_block_table_tensor = make_local_attention_virtual_batches(
                    self.runner.attention_chunk_size,
                    self.runner.query_start_loc_np[:num_reqs + 1],
                    self.runner.seq_lens_np[:num_reqs],
                    block_table_tensor,
                    self.block_size,
                )
            local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
                self.runner.device, non_blocking=True)
            local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
                self.runner.device, non_blocking=True)
            local_max_query_len = seqlens_q_local_np.max()
            local_max_seq_len = virt_k_seqlens_np.max()
            local_scheduler_metadata = schedule(
                batch_size=local_query_start_loc.shape[0] - 1,
                cu_query_lens=local_query_start_loc,
                max_query_len=local_max_query_len,
                seqlens=local_seqused_k,
                max_seq_len=local_max_seq_len,
                causal=True)

            local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
                local_query_start_loc=local_query_start_loc,
                local_seqused_k=local_seqused_k,
                local_block_table=virt_block_table_tensor,
                local_max_query_len=local_max_query_len,
                local_max_seq_len=local_max_seq_len,
                local_scheduler_metadata=local_scheduler_metadata,
            )

        use_cascade = common_prefix_len > 0

        if use_cascade:
            # The shared prefix is treated as a single sequence attended to
            # by all query tokens (non-causal); the per-request suffixes are
            # scheduled separately (causal).
            cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
                                                dtype=torch.int32,
                                                device=self.runner.device)
            prefix_kv_lens = torch.tensor([common_prefix_len],
                                          dtype=torch.int32,
                                          device=self.runner.device)
            suffix_kv_lens = (self.runner.seq_lens_np[:num_reqs] -
                              common_prefix_len)
            suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(
                self.runner.device)
            prefix_scheduler_metadata = schedule(
                batch_size=1,
                cu_query_lens=cu_prefix_query_lens,
                max_query_len=num_actual_tokens,
                seqlens=prefix_kv_lens,
                max_seq_len=common_prefix_len,
                causal=False)
            scheduler_metadata = schedule(batch_size=num_reqs,
                                          cu_query_lens=query_start_loc,
                                          max_query_len=max_query_len,
                                          seqlens=suffix_kv_lens,
                                          max_seq_len=max_seq_len -
                                          common_prefix_len,
                                          causal=True)
        else:
            cu_prefix_query_lens = None
            prefix_kv_lens = None
            suffix_kv_lens = None
            prefix_scheduler_metadata = None
            scheduler_metadata = schedule(batch_size=num_reqs,
                                          cu_query_lens=query_start_loc,
                                          max_query_len=max_query_len,
                                          seqlens=seq_lens,
                                          max_seq_len=max_seq_len,
                                          causal=True)

        if not current_platform.is_rocm() and self.use_full_cuda_graph:
            assert scheduler_metadata is not None
            n = scheduler_metadata.shape[0]
            self.scheduler_metadata[:n] = scheduler_metadata
            # NOTE(woosuk): We should zero out the rest of the scheduler
            # metadata to guarantee the correctness. Otherwise, some thread
            # blocks may use the invalid scheduler metadata and overwrite the
            # output buffer.
            self.scheduler_metadata[n:] = 0
            scheduler_metadata = self.scheduler_metadata[:n]

        max_num_splits = 0
        if (self.use_full_cuda_graph
                and num_actual_tokens <= self.max_cudagraph_size):
            # NOTE(woosuk): Setting num_splits > 1 may increase the memory
            # usage, because the intermediate buffers of size [num_splits,
            # num_heads, num_tokens, head_size] are allocated. Therefore,
            # we only set num_splits when using cuda graphs.
            max_num_splits = self.max_num_splits

        attn_metadata = FlashAttentionMetadata(
            num_actual_tokens=num_actual_tokens,
            max_query_len=max_query_len,
            query_start_loc=query_start_loc,
            max_seq_len=max_seq_len,
            seq_lens=seq_lens,
            block_table=block_table_tensor,
            slot_mapping=slot_mapping,
            use_cascade=use_cascade,
            common_prefix_len=common_prefix_len,
            scheduler_metadata=scheduler_metadata,
            cu_prefix_query_lens=cu_prefix_query_lens,
            prefix_kv_lens=prefix_kv_lens,
            suffix_kv_lens=suffix_kv_lens,
            local_attn_metadata=local_attn_metadata,
            prefix_scheduler_metadata=prefix_scheduler_metadata,
            max_num_splits=max_num_splits,
        )
        return attn_metadata

    def can_run_in_cudagraph(
            self, common_attn_metadata: CommonAttentionMetadata) -> bool:
        """Report whether this batch can run inside a full CUDA graph."""
        # Full CUDA Graph always supported (FA2 support checked separately)
        return True

    def use_cascade_attention(self, *args, **kwargs) -> bool:
        """Forward to the module-level `use_cascade_attention` function."""
        return use_cascade_attention(*args, **kwargs)

443

444
445
446
447
448
449
450
451
class FlashAttentionImpl(AttentionImpl):

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        use_irope: bool = False,
    ) -> None:
        """Store the attention hyper-parameters and validate the configuration.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each head; must be a supported head size.
            scale: Softmax scale, stored as a float.
            num_kv_heads: Number of key/value heads.
            alibi_slopes: Optional ALiBi slopes, stored as a float32 tensor.
            sliding_window: Optional window size; stored as
                `(sliding_window - 1, 0)`, or `(-1, -1)` when None.
            kv_cache_dtype: KV-cache dtype string.
            blocksparse_params: Must be None; block-sparse is unsupported.
            logits_soft_cap: Optional soft cap; None is stored as 0.
            attn_type: Only `AttentionType.DECODER` is supported.
            kv_sharing_target_layer_name: If set, forward() skips writing
                keys/values into the cache for this layer.
            use_irope: Enables the local-attention path in forward() when
                local attention metadata is present.

        Raises:
            ValueError: `blocksparse_params` is given, or `head_size` is not
                supported.
            NotImplementedError: `attn_type` is not DECODER, or a quantized
                KV cache is requested but fp8 is unsupported on this device.
        """
        if blocksparse_params is not None:
            raise ValueError(
                "FlashAttention does not support block-sparse attention.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        if sliding_window is None:
            # (-1, -1) disables the sliding window.
            self.sliding_window = (-1, -1)
        else:
            self.sliding_window = (sliding_window - 1, 0)
        self.kv_cache_dtype = kv_cache_dtype
        if logits_soft_cap is None:
            # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
            logits_soft_cap = 0
        self.logits_soft_cap = logits_soft_cap
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        FlashAttentionBackend.validate_head_size(head_size)

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashAttentionImpl")
        self.use_irope = use_irope
        self.vllm_flash_attn_version = get_flash_attn_version()
        if is_quantized_kv_cache(self.kv_cache_dtype) \
            and not flash_attn_supports_fp8():
            raise NotImplementedError(
                "FlashAttention does not support fp8 kv-cache on this device.")
497

498
499
    def forward(
        self,
500
        layer: torch.nn.Module,
501
502
503
504
505
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
506
        output: Optional[torch.Tensor] = None,
507
        output_scale: Optional[torch.Tensor] = None,
508
509
510
511
    ) -> torch.Tensor:
        """Forward pass with FlashAttention.

        Args:
512
513
514
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
515
516
517
518
            kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
519
520
521
        NOTE: FP8 quantization, flash-attn expect the size of
              {q,k,v}_descale to be (num_sequences, num_kv_heads).
              We use torch's .expand() to avoid duplicating values
522
        """
523
524
        assert output is not None, "Output tensor must be provided."

525
526
527
528
529
        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for FlashAttentionImpl")

530
531
532
533
        if attn_metadata is None:
            # Profiling run.
            return output

534
535
536
537
538
539
540
541
        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
        # in this method. For example, `view` and `slice` (or `[:n]`) operations
        # are surprisingly slow even in the case they do not invoke any GPU ops.
        # Minimize the PyTorch ops in this method as much as possible.
        # Whenever making a change in this method, please benchmark the
        # performance to make sure it does not introduce any overhead.
542

543
        num_actual_tokens = attn_metadata.num_actual_tokens
zhuwenwen's avatar
zhuwenwen committed
544
545
546
547
        if not current_platform.is_rocm():
            key_cache, value_cache = kv_cache.unbind(0)
        else:
            key_cache, value_cache = kv_cache
548
549
550
551
552
553
554
555
556

        if self.kv_sharing_target_layer_name is None:
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
            # not padded. However, we don't need to do key[:num_actual_tokens]
            # and value[:num_actual_tokens] because the reshape_and_cache_flash
            # op uses the slot_mapping's shape to determine the number of
            # actual tokens.
zhuwenwen's avatar
zhuwenwen committed
557
558
559
560
561
562
563
564
565
566
567
568
            if not current_platform.is_rocm():
                reshape_and_cache_flash(
                    key,
                    value,
                    key_cache,
                    value_cache,
                    attn_metadata.slot_mapping,
                    self.kv_cache_dtype,
                    layer._k_scale,
                    layer._v_scale,
                )
            else:
569
                if envs.VLLM_USE_OPT_RESHAPE_AND_CACHE and key.dtype == value.dtype == "fp16":
wujl5's avatar
wujl5 committed
570
                    from lightop import reshape_and_cache_cuda
571
572
573
574
575
576
577
578
                    reshape_and_cache_cuda(
                        key, value,
                        key_cache, value_cache,
                        attn_metadata.slot_mapping,
                        self.kv_cache_dtype,
                        layer._k_scale, layer._v_scale
                    )
                else:
wujl5's avatar
wujl5 committed
579
                    from vllm.attention.utils.fa_utils import reshape_and_cache_cuda
580
581
582
583
584
585
586
587
588
589
                    reshape_and_cache_cuda(
                        key,
                        value,
                        key_cache,
                        value_cache,
                        attn_metadata.slot_mapping,
                        self.kv_cache_dtype,
                        layer._k_scale,
                        layer._v_scale,
                    )
590

591
592
593
594
595
596
597
598
599
        if self.kv_cache_dtype.startswith("fp8"):
            key_cache = key_cache.view(torch.float8_e4m3fn)
            value_cache = value_cache.view(torch.float8_e4m3fn)
            num_tokens, num_heads, head_size = query.shape
            query, _ = ops.scaled_fp8_quant(
                query.reshape(
                    (num_tokens, num_heads * head_size)).contiguous(),
                layer._q_scale)
            query = query.reshape((num_tokens, num_heads, head_size))
600
601

        # Compute attention and update output up to `num_actual_tokens`.
602
603
604
605
606
607
608
609
610
611
612
613
        use_local_attn = \
            (self.use_irope and attn_metadata.local_attn_metadata is not None)

        if not attn_metadata.use_cascade or use_local_attn:
            if use_local_attn:
                assert attn_metadata.local_attn_metadata is not None
                local_metadata = attn_metadata.local_attn_metadata
                cu_seqlens_q = local_metadata.local_query_start_loc
                seqused_k = local_metadata.local_seqused_k
                max_seqlen_q = local_metadata.local_max_query_len
                max_seqlen_k = local_metadata.local_max_seq_len
                block_table = local_metadata.local_block_table
614
                scheduler_metadata = local_metadata.local_scheduler_metadata
615
616
617
618
619
620
            else:
                cu_seqlens_q = attn_metadata.query_start_loc
                seqused_k = attn_metadata.seq_lens
                max_seqlen_q = attn_metadata.max_query_len
                max_seqlen_k = attn_metadata.max_seq_len
                block_table = attn_metadata.block_table
621
                scheduler_metadata = attn_metadata.scheduler_metadata
622
623
624

            descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])

zhuwenwen's avatar
zhuwenwen committed
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
            if not current_platform.is_rocm():
                flash_attn_varlen_func(
                    q=query[:num_actual_tokens],
                    k=key_cache,
                    v=value_cache,
                    out=output[:num_actual_tokens],
                    cu_seqlens_q=cu_seqlens_q,
                    max_seqlen_q=max_seqlen_q,
                    seqused_k=seqused_k,
                    max_seqlen_k=max_seqlen_k,
                    softmax_scale=self.scale,
                    causal=True,
                    alibi_slopes=self.alibi_slopes,
                    window_size=self.sliding_window,
                    block_table=block_table,
                    softcap=self.logits_soft_cap,
                    scheduler_metadata=scheduler_metadata,
                    fa_version=self.vllm_flash_attn_version,
                    q_descale=layer._q_scale.expand(descale_shape),
                    k_descale=layer._k_scale.expand(descale_shape),
                    v_descale=layer._v_scale.expand(descale_shape),
                    num_splits=attn_metadata.max_num_splits,
                )
            else:
649
650
                if envs.VLLM_USE_PA_PRINT_PARAM:
                    print("PA SIZE:")
zhuwenwen's avatar
zhuwenwen committed
651
                    print(f"q.shape = {query[:num_actual_tokens].shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}")
652
653
654
                    print(f"cu_seqlens_q.shape = {cu_seqlens_q.shape}, max_seqlen_q = {max_seqlen_q}, seqused_k.shape = {seqused_k.shape}, max_seqlen_k = {max_seqlen_k}")
                    print(f"softmax_scale = {self.scale:.3f}, alibi_slopes = {self.alibi_slopes}, window_size = {self.sliding_window}, block_tables.shape = {block_table.shape}, softcap = {self.logits_soft_cap}, scheduler_metadata = {scheduler_metadata}")
                       
zhuwenwen's avatar
zhuwenwen committed
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
                vllm_flash_attn_varlen_func(
                    q=query[:num_actual_tokens],
                    k=key_cache,
                    v=value_cache,
                    out=output[:num_actual_tokens],
                    cu_seqlens_q=cu_seqlens_q,
                    max_seqlen_q=max_seqlen_q,
                    seqused_k=seqused_k,
                    max_seqlen_k=max_seqlen_k,
                    softmax_scale=self.scale,
                    causal=True,
                    alibi_slopes=self.alibi_slopes,
                    window_size=self.sliding_window,
                    block_table=block_table,
                    softcap=self.logits_soft_cap,
                    scheduler_metadata=scheduler_metadata,
                    # fa_version=self.vllm_flash_attn_version,
                    # q_descale=layer._q_scale.expand(descale_shape),
                    # k_descale=layer._k_scale.expand(descale_shape),
                    # v_descale=layer._v_scale.expand(descale_shape),
                    # num_splits=attn_metadata.max_num_splits,
zhuwenwen's avatar
zhuwenwen committed
676
                    is_prefix_cache=True,
zhuwenwen's avatar
zhuwenwen committed
677
                )
678
            return output
679

680
681
        assert not use_local_attn, (
            "Cascade attention does not support local attention.")
682
        # Cascade attention (rare case).
zhuwenwen's avatar
zhuwenwen committed
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
        if not current_platform.is_rocm():
            cascade_attention(
                output[:num_actual_tokens],
                query[:num_actual_tokens],
                key_cache,
                value_cache,
                cu_query_lens=attn_metadata.query_start_loc,
                max_query_len=attn_metadata.max_query_len,
                cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens,
                prefix_kv_lens=attn_metadata.prefix_kv_lens,
                suffix_kv_lens=attn_metadata.suffix_kv_lens,
                max_kv_len=attn_metadata.max_seq_len,
                softmax_scale=self.scale,
                alibi_slopes=self.alibi_slopes,
                sliding_window=self.sliding_window,
                logits_soft_cap=self.logits_soft_cap,
                block_table=attn_metadata.block_table,
                common_prefix_len=attn_metadata.common_prefix_len,
                fa_version=self.vllm_flash_attn_version,
                prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
                suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
                q_descale=layer._q_scale,
                k_descale=layer._k_scale,
                v_descale=layer._v_scale,
            )
        else:
            cascade_attention(
                output[:num_actual_tokens],
                query[:num_actual_tokens],
                key_cache,
                value_cache,
                cu_query_lens=attn_metadata.query_start_loc,
                max_query_len=attn_metadata.max_query_len,
                cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens,
                prefix_kv_lens=attn_metadata.prefix_kv_lens,
                suffix_kv_lens=attn_metadata.suffix_kv_lens,
                max_kv_len=attn_metadata.max_seq_len,
                softmax_scale=self.scale,
                alibi_slopes=self.alibi_slopes,
                sliding_window=self.sliding_window,
                logits_soft_cap=self.logits_soft_cap,
                block_table=attn_metadata.block_table,
                common_prefix_len=attn_metadata.common_prefix_len,
                fa_version=2, #self.vllm_flash_attn_version,
                prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
                suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
                # q_descale=layer._q_scale,
                # k_descale=layer._k_scale,
                # v_descale=layer._v_scale,
            )
733
        return output
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811


def use_cascade_attention(
    common_prefix_len: int,
    query_lens: np.ndarray,
    num_query_heads: int,
    num_kv_heads: int,
    use_alibi: bool,
    use_sliding_window: bool,
    num_sms: int,
) -> bool:
    """Return whether cascade attention should be used for the batch.

    The decision is made in two stages:

    1. Eligibility: the shared prefix must be at least 256 tokens, neither
       ALiBi nor sliding window may be in use, and the batch must contain
       at least 8 requests. Both thresholds are arbitrary and untuned.
    2. Benefit: if regular attention would not go through FlashDecoding,
       cascade attention is chosen outright. Otherwise a coarse CTA/wave
       count model compares the two and the cheaper one wins.

    Args:
        common_prefix_len: Number of tokens shared by all requests.
        query_lens: Per-request query lengths; its length is the number of
            requests.
        num_query_heads: Number of query heads.
        num_kv_heads: Number of KV heads.
        use_alibi: Whether ALiBi is enabled.
        use_sliding_window: Whether sliding-window attention is enabled.
        num_sms: Number of SMs used to convert CTA counts into waves.

    Returns:
        True if cascade attention should be used, False otherwise.
    """
    min_prefix_len = 256  # TODO: Tune this threshold.
    min_num_reqs = 8  # TODO: Tune this threshold.

    # NOTE(woosuk): A short common prefix is the common case, so bail out
    # here first to avoid any unnecessary computation.
    if common_prefix_len < min_prefix_len:
        return False
    # Cascade attention is currently not supported with these variants.
    if use_alibi or use_sliding_window:
        return False
    # With only a handful of requests it is probably not worth it.
    batch_size = len(query_lens)
    if batch_size < min_num_reqs:
        return False

    # Heuristic 1: when FlashDecoding is not used for normal attention,
    # cascade attention is likely faster since it saves memory bandwidth.
    # The criteria for using FlashDecoding can be found in the following link:
    # https://github.com/vllm-project/flash-attention/blob/96266b1111111f3d11aabefaf3bacbab6a89d03c/csrc/flash_attn/flash_api.cpp#L535
    # ALiBi and sliding window are part of those criteria too, but both were
    # already ruled out above, so only the remaining two are checked here.
    heads_per_kv = num_query_heads // num_kv_heads
    flash_decoding_applies = heads_per_kv > 1 and np.all(query_lens == 1)
    if not flash_decoding_applies:
        return True

    # Heuristic 2: FlashDecoding can launch more CTAs than cascade
    # attention, so the winner is not obvious. Compare both with a simple
    # performance model.
    # NOTE(woosuk): The performance model is very rough and may not be
    # accurate. The tile sizes below are defaults; flash-attn might use
    # different ones (e.g., 64 or 256) depending on the configuration.
    q_tile = 128
    kv_tile = 128
    prefix_tiles = cdiv(common_prefix_len, kv_tile)

    # All queries have length 1 on this path, so tokens == requests.
    cascade_waves = cdiv(num_query_heads * cdiv(batch_size, q_tile), num_sms)
    cascade_cost = cascade_waves * prefix_tiles

    decoding_ctas = batch_size * num_kv_heads * cdiv(heads_per_kv, q_tile)
    decoding_cost = cdiv(decoding_ctas * prefix_tiles, num_sms)

    # Use cascade attention only if it is strictly faster than FlashDecoding.
    return cascade_cost < decoding_cost


def cascade_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    cu_query_lens: torch.Tensor,
    max_query_len: int,
    cu_prefix_query_lens: torch.Tensor,
    prefix_kv_lens: torch.Tensor,
    suffix_kv_lens: torch.Tensor,
    max_kv_len: int,
    softmax_scale: float,
    alibi_slopes: Optional[torch.Tensor],
    sliding_window: tuple[int, int],
    logits_soft_cap: float,
    block_table: torch.Tensor,
    common_prefix_len: int,
    fa_version: int,
    prefix_scheduler_metadata: Optional[torch.Tensor] = None,
    suffix_scheduler_metadata: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run cascade attention and write the merged result into ``output``.

    Attention is computed in two passes over the paged KV cache and the two
    partial results are combined with ``merge_attn_states``:

    1. Shared prefix: a non-causal pass over the first
       ``common_prefix_len`` tokens, using only the first row of
       ``block_table``.
    2. Per-query suffix: a causal pass over the remaining KV tokens of each
       request, using the block-table columns after the shared prefix.

    Args:
        output: Destination tensor; filled in place.
        query: Query tensor; ``query.shape[0]`` is used as the number of
            tokens.
        key_cache: Paged key cache; ``key_cache.shape[-3]`` is used as the
            block size.
        value_cache: Paged value cache.
        cu_query_lens: Cumulative query lengths for the suffix pass.
        max_query_len: Maximum query length for the suffix pass.
        cu_prefix_query_lens: Cumulative query lengths for the prefix pass.
        prefix_kv_lens: KV lengths used by the prefix pass.
        suffix_kv_lens: KV lengths used by the suffix pass.
        max_kv_len: Maximum total KV length; the suffix pass uses
            ``max_kv_len - common_prefix_len``.
        softmax_scale: Softmax scaling factor.
        alibi_slopes: Must be ``None`` (ALiBi is not supported).
        sliding_window: Must be ``(-1, -1)`` (sliding window is not
            supported).
        logits_soft_cap: Soft cap forwarded to the kernels.
        block_table: Block table of the batch.
        common_prefix_len: Shared prefix length in tokens; must be a
            positive multiple of the block size.
        fa_version: FlashAttention version (only forwarded on non-ROCm).
        prefix_scheduler_metadata: Scheduler metadata for the prefix pass.
        suffix_scheduler_metadata: Scheduler metadata for the suffix pass.
        q_descale: Optional query descale (only forwarded on non-ROCm).
        k_descale: Optional key descale (only forwarded on non-ROCm).
        v_descale: Optional value descale (only forwarded on non-ROCm).

    Returns:
        ``output``, after the merged attention result has been written to it.

    Raises:
        AssertionError: If ALiBi or a sliding window is requested, or if
            ``common_prefix_len`` is not a positive multiple of the block
            size.
    """
    assert alibi_slopes is None, ("Cascade attention does not support ALiBi.")
    # TODO: Support sliding window.
    assert sliding_window == (-1, -1), (
        "Cascade attention does not support sliding window.")

    num_tokens = query.shape[0]
    block_size = key_cache.shape[-3]
    assert common_prefix_len % block_size == 0
    num_common_kv_blocks = common_prefix_len // block_size
    assert num_common_kv_blocks > 0

    # Process shared prefix. Every query token attends to the whole prefix,
    # hence the non-causal pass over a single block-table row.
    prefix_output, prefix_lse = _cascade_varlen_attn(
        query=query,
        key_cache=key_cache,
        value_cache=value_cache,
        cu_seqlens_q=cu_prefix_query_lens,
        seqused_k=prefix_kv_lens,
        max_seqlen_q=num_tokens,
        max_seqlen_k=common_prefix_len,
        softmax_scale=softmax_scale,
        causal=False,
        window_size=sliding_window,
        block_table=block_table[:1],
        softcap=logits_soft_cap,
        scheduler_metadata=prefix_scheduler_metadata,
        fa_version=fa_version,
        q_descale=q_descale,
        k_descale=k_descale,
        v_descale=v_descale,
    )

    # Process suffix per query, skipping the blocks of the shared prefix.
    suffix_output, suffix_lse = _cascade_varlen_attn(
        query=query,
        key_cache=key_cache,
        value_cache=value_cache,
        cu_seqlens_q=cu_query_lens,
        seqused_k=suffix_kv_lens,
        max_seqlen_q=max_query_len,
        max_seqlen_k=max_kv_len - common_prefix_len,
        softmax_scale=softmax_scale,
        causal=True,
        window_size=sliding_window,
        block_table=block_table[:, num_common_kv_blocks:],
        softcap=logits_soft_cap,
        scheduler_metadata=suffix_scheduler_metadata,
        fa_version=fa_version,
        q_descale=q_descale,
        k_descale=k_descale,
        v_descale=v_descale,
    )

    # Merge prefix and suffix outputs, and store the result in output.
    merge_attn_states(output, prefix_output, prefix_lse, suffix_output,
                      suffix_lse)
    # The signature promises a tensor; hand back the buffer that was filled.
    return output


def _cascade_varlen_attn(
    *,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    seqused_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    softmax_scale: float,
    causal: bool,
    window_size: tuple[int, int],
    block_table: torch.Tensor,
    softcap: float,
    scheduler_metadata: Optional[torch.Tensor],
    fa_version: int,
    q_descale: Optional[torch.Tensor],
    k_descale: Optional[torch.Tensor],
    v_descale: Optional[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run one cascade pass and return ``(attention_output, softmax_lse)``.

    Dispatches to the platform-specific varlen kernel so that the prefix and
    suffix passes share a single call site per platform.
    """
    shared_kwargs = dict(
        q=query,
        k=key_cache,
        v=value_cache,
        cu_seqlens_q=cu_seqlens_q,
        seqused_k=seqused_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        softmax_scale=softmax_scale,
        causal=causal,
        window_size=window_size,
        block_table=block_table,
        softcap=softcap,
        return_softmax_lse=True,
        scheduler_metadata=scheduler_metadata,
    )

    if current_platform.is_rocm():
        # NOTE(review): this path does not forward fa_version or the
        # q/k/v descale tensors, so any descaling requested by the caller
        # is not applied here — confirm this is intended for ROCm.
        # The ROCm wrapper returns a third value that is not needed here.
        attn_out, softmax_lse, _ = vllm_flash_attn_varlen_func(
            **shared_kwargs,
            is_prefix_cache=True,
        )
        return attn_out, softmax_lse

    # Descale tensors are expanded to one entry per (sequence, KV head).
    descale_shape = (cu_seqlens_q.shape[0] - 1, key_cache.shape[-2])
    attn_out, softmax_lse = flash_attn_varlen_func(
        **shared_kwargs,
        fa_version=fa_version,
        q_descale=q_descale.expand(descale_shape)
        if q_descale is not None else None,
        k_descale=k_descale.expand(descale_shape)
        if k_descale is not None else None,
        v_descale=v_descale.expand(descale_shape)
        if v_descale is not None else None,
    )
    return attn_out, softmax_lse