flash_attn.py 42.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
"""Attention layer with FlashAttention."""
from dataclasses import dataclass
5
from typing import Optional, Tuple
6

7
import numpy as np
8
9
import torch

10
import vllm.envs as envs
11
from vllm import _custom_ops as ops
12

13
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
14
15
                                              AttentionMetadata, AttentionType,
                                              is_quantized_kv_cache)
16
from vllm.attention.layer import Attention
17
from vllm.attention.ops.merge_attn_states import merge_attn_states
18
from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
19
                                           get_flash_attn_version,
20
21
                                           is_flash_attn_varlen_func_available)

zhuwenwen's avatar
zhuwenwen committed
22
from vllm.platforms import current_platform
23
if is_flash_attn_varlen_func_available():
zhuwenwen's avatar
zhuwenwen committed
24
25
26
27
28
    if not current_platform.is_rocm():
        from vllm.attention.utils.fa_utils import (flash_attn_varlen_func,
                                                get_scheduler_metadata,
                                                reshape_and_cache_flash)
    else:
29
30
        from vllm.attention.utils.fa_utils import (vllm_flash_attn_varlen_func,
                                               reshape_and_cache_cuda)
31

32
from vllm.config import VllmConfig, get_layers_from_vllm_config
zhuwenwen's avatar
zhuwenwen committed
33

34
from vllm.logger import init_logger
35
from vllm.utils import cdiv
36
37
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
                                              AttentionMetadataBuilder,
38
39
                                              CommonAttentionMetadata,
                                              get_kv_cache_layout)
40
from vllm.v1.kv_cache_interface import AttentionSpec
41

42
43
logger = init_logger(__name__)

44
45
46

class FlashAttentionBackend(AttentionBackend):
    """Backend descriptor for the FlashAttention attention implementation.

    Exposes the implementation, metadata and metadata-builder classes, plus
    helpers describing the KV-cache shape and stride order. The KV-cache
    helpers are chosen once, when the class body is executed, depending on
    whether `current_platform.is_rocm()` is true.
    """

    accept_output_buffer: bool = True
    supports_quant_query_input: bool = True

    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        """Return the tensor dtypes this backend declares support for."""
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return the attention head sizes this backend declares support for."""
        return [32, 64, 96, 128, 160, 192, 224, 256]

    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
        """Check that `head_size` is one of the supported head sizes.

        Raises:
            ValueError: if `head_size` is not in `get_supported_head_sizes()`.
        """
        supported_head_sizes = cls.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            # Derive a readable name from the class name, e.g.
            # "FlashAttentionBackend" -> "FlashAttention".
            attn_type = cls.__name__.removesuffix("Backend")
            raise ValueError(
                f"Head size {head_size} is not supported by {attn_type}. "
                f"Supported head sizes are: {supported_head_sizes}. "
                "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
                "FlexAttention backend which supports all head sizes.")

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "FLASH_ATTN"

    @staticmethod
    def get_impl_cls() -> type["FlashAttentionImpl"]:
        """Return the attention implementation class."""
        return FlashAttentionImpl

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        """Return the per-step attention metadata class."""
        return FlashAttentionMetadata

    @staticmethod
    def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]:
        """Return the class that builds the attention metadata."""
        return FlashAttentionMetadataBuilder

    if not current_platform.is_rocm():

        @staticmethod
        def get_kv_cache_shape(
            num_blocks: int,
            block_size: int,
            num_kv_heads: int,
            head_size: int,
            cache_dtype_str: str = "auto",
        ) -> tuple[int, ...]:
            """Return the shape of the combined key/value cache tensor.

            The leading dimension of size 2 holds the key cache and the
            value cache. `cache_dtype_str` is accepted but not used here.

            Raises:
                ValueError: if `block_size` is not a multiple of 16.
            """
            if block_size % 16 != 0:
                raise ValueError("Block size must be a multiple of 16.")
            return (2, num_blocks, block_size, num_kv_heads, head_size)

        @staticmethod
        def get_kv_cache_stride_order() -> tuple[int, ...]:
            """Return the dimension permutation for the KV-cache layout.

            Raises:
                ValueError: if the configured cache layout is neither
                    "NHD" nor "HND".
            """
            # `stride_order` indicates the permutation that gets
            # us from `get_kv_cache_shape` to the actual memory layout we want.
            cache_layout = get_kv_cache_layout()
            if cache_layout == "NHD":
                stride_order = (0, 1, 2, 3, 4)
            elif cache_layout == "HND":
                stride_order = (0, 1, 3, 2, 4)
            else:
                raise ValueError(f"Unknown cache layout format {cache_layout}.")
            return stride_order
    else:

        @staticmethod
        def get_kv_cache_shape(
            num_blocks: int,
            block_size: int,
            num_kv_heads: int,
            head_size: int,
            cache_dtype_str: str = "auto",
        ) -> tuple[tuple[int, ...], tuple[int, ...]]:
            """Return separate shapes for the key cache and the value cache.

            On ROCm the two caches are described by two shapes that differ in
            the order of their last two dimensions. `cache_dtype_str` is
            accepted but not used here.

            Raises:
                ValueError: if `block_size` is not a multiple of 16.
            """
            if block_size % 16 != 0:
                raise ValueError("Block size must be a multiple of 16.")
            return (
                (num_blocks, num_kv_heads, block_size, head_size),
                (num_blocks, num_kv_heads, head_size, block_size),
            )

        @staticmethod
        def get_kv_cache_stride_order(
        ) -> tuple[tuple[int, ...], tuple[int, ...]]:
            """Return the (key, value) dimension permutations for the layout.

            Raises:
                ValueError: if the configured cache layout is neither
                    "NHD" nor "HND".
            """
            # `stride_order` indicates the permutation that gets
            # us from `get_kv_cache_shape` to the actual memory layout we want.
            cache_layout = get_kv_cache_layout()
            if cache_layout == "NHD":
                key_stride_order = (0, 1, 2, 3)
                value_stride_order = (0, 1, 2, 3)
            elif cache_layout == "HND":
                key_stride_order = (0, 2, 1, 3)
                value_stride_order = (0, 2, 1, 3)
            else:
                raise ValueError(f"Unknown cache layout format {cache_layout}.")
            return key_stride_order, value_stride_order

    @staticmethod
    def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype:
        """Map an fp8 KV-cache dtype string to the matching torch dtype.

        Args:
            kv_cache_dtype: one of "fp8", "fp8_e4m3" or "fp8_e5m2".

        Raises:
            ValueError: if `kv_cache_dtype` is not one of the strings above.
        """
        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
            return torch.float8_e4m3fn
        # Compare for equality: `x in ("fp8_e5m2")` is a substring test on a
        # plain string (the parentheses alone do not make a tuple), so
        # fragments such as "e5" or "" were silently accepted as fp8_e5m2.
        elif kv_cache_dtype == "fp8_e5m2":
            return torch.float8_e5m2
        else:
            raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")

151
152
153
154
155
156
157
158
159
160
161

@dataclass
class FlashAttentionMetadata:
    """Per-step metadata consumed by the FlashAttention forward pass."""

    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    num_actual_tokens: int  # Number of tokens excluding padding.
    max_query_len: int
    # Passed to the attention kernel as the cumulative query sequence lengths.
    query_start_loc: torch.Tensor
    max_seq_len: int
    # Passed to the attention kernel as the number of used keys per sequence.
    seq_lens: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor

    # For cascade attention.
    use_cascade: bool
    common_prefix_len: int
    cu_prefix_query_lens: Optional[torch.Tensor]
    prefix_kv_lens: Optional[torch.Tensor]
    suffix_kv_lens: Optional[torch.Tensor]

    # Optional aot scheduling
    scheduler_metadata: Optional[torch.Tensor] = None
    prefix_scheduler_metadata: Optional[torch.Tensor] = None
    # Forwarded to the kernel as `num_splits`; 0 leaves the choice to the
    # kernel's own heuristics.
    max_num_splits: int = 0

    causal: bool = True

184

185
186
187
188
189
190
191
192
193
194
195
def _get_sliding_window_configs(
        vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]:
    """Collect the distinct sliding-window settings of all attention layers.

    Every `Attention` layer found in `vllm_config` must be backed by a
    `FlashAttentionImpl`; its `sliding_window` value is added to the result.
    """
    attn_layers = get_layers_from_vllm_config(vllm_config, Attention)
    window_configs: set[Optional[tuple[int, int]]] = set()
    for attn_layer in attn_layers.values():
        impl = attn_layer.impl
        assert isinstance(impl, FlashAttentionImpl)
        window_configs.add(impl.sliding_window)
    return window_configs


196
197
class FlashAttentionMetadataBuilder(
        AttentionMetadataBuilder[FlashAttentionMetadata]):
    """Builds `FlashAttentionMetadata` from `CommonAttentionMetadata`.

    When FlashAttention 3 is in use, the builder can also precompute
    scheduler metadata ahead of time (AOT scheduling) and, with full CUDA
    graphs, keeps it in a persistent buffer.
    """

    # FA3:
    # Supports full cudagraphs for all cases.
    #
    # FA2:
    # For FA2, a graph is captured with max_query_len=1, (which is what we
    # capture by default for num_tokens <= max_num_seqs when there is no
    # spec-decode) then these graphs will not work for mixed prefill-decode
    # (unlike FA3). This is due to special max_query_len=1 packed-GQA handling
    # in FA2.
    # In summary if we are running with spec decodes the graphs would
    # work for mixed prefill-decode and uniform-decode. But for non-spec decodes
    # the graphs would not work for mixed prefill-decode; sorta the inverse
    # of UNIFORM_SINGLE_TOKEN_DECODE.
    # There's probably a better way to describe this using `AttentionCGSupport`
    # but for now just set it to `UNIFORM_BATCH` to get us to drop down
    # to FULL_AND_PIECEWISE.
    # TODO(luka, lucas): audit FA2 as part of:
    #  https://github.com/vllm-project/vllm/issues/22945
    cudagraph_support = (AttentionCGSupport.ALWAYS
                         if get_flash_attn_version() == 3
                         or current_platform.is_rocm() else
                         AttentionCGSupport.UNIFORM_BATCH)

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        """Cache configuration values and set up AOT-scheduling state.

        Raises:
            ValueError: if full CUDA graphs and AOT scheduling are both
                enabled and the maximum capture size exceeds 992.
        """
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
        self.model_config = vllm_config.model_config
        self.parallel_config = vllm_config.parallel_config
        self.cache_config = vllm_config.cache_config
        self.compilation_config = vllm_config.compilation_config

        self.num_heads_q = self.model_config.get_num_attention_heads(
            self.parallel_config)
        self.num_heads_kv = self.model_config.get_num_kv_heads(
            self.parallel_config)
        self.kv_cache_dtype = kv_cache_spec.dtype
        self.headdim = self.model_config.get_head_size()
        self.block_size = kv_cache_spec.block_size

        self.max_num_splits = 0  # No upper bound on the number of splits.
        # AOT scheduling is only enabled for FlashAttention 3.
        self.aot_schedule = (get_flash_attn_version() == 3)

        self.use_full_cuda_graph = \
            self.compilation_config.cudagraph_mode.has_full_cudagraphs()
        self.max_cudagraph_size = self.compilation_config.max_capture_size

        if self.use_full_cuda_graph and self.aot_schedule:
            if self.max_cudagraph_size > 992:
                # This condition derives from FA3's internal heuristic.
                # TODO(woosuk): Support larger cudagraph sizes.
                raise ValueError(
                    "Capture size larger than 992 is not supported for "
                    "full cuda graph.")

            # Persistent buffer that build() copies scheduler metadata into.
            self.scheduler_metadata = torch.zeros(
                vllm_config.scheduler_config.max_num_seqs + 1,
                dtype=torch.int32,
                device=self.device,
            )
            # When using cuda graph, we need to set the upper bound of the
            # number of splits so that large enough intermediate buffers are
            # pre-allocated during capture.
            self.max_num_splits = (
                envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH)

        # Sliding window size to be used with the AOT scheduler will be
        # populated on first build() call.
        self.aot_sliding_window: Optional[tuple[int, int]] = None

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> FlashAttentionMetadata:
        """Build the attention metadata for one step.

        Args:
            common_prefix_len: length of the prefix shared by all requests;
                a value greater than 0 enables cascade attention.
            common_attn_metadata: backend-agnostic metadata for the batch.
            fast_build: disables AOT scheduling; used when there will be
                few iterations, i.e. spec-decode.

        Returns:
            The populated `FlashAttentionMetadata`.
        """
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        max_query_len = common_attn_metadata.max_query_len
        max_seq_len = common_attn_metadata.max_seq_len
        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens
        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
        block_table_tensor = common_attn_metadata.block_table_tensor
        slot_mapping = common_attn_metadata.slot_mapping
        causal = common_attn_metadata.causal

        # the overhead of the aot schedule is not worth it for spec-decode
        aot_schedule = self.aot_schedule and not fast_build

        # NOTE(review): if the very first build() call has fast_build=True,
        # the window stays (-1, -1) for all later calls because this branch
        # runs only once - confirm this is intended.
        if self.aot_sliding_window is None:
            self.aot_sliding_window = (-1, -1)
            # For the AOT scheduler we need the sliding window value to be
            # constant for all layers. We have to populate this on the first
            # build() call so the layers are constructed (cannot populate
            # in __init__).
            if aot_schedule:
                sliding_window_configs = _get_sliding_window_configs(
                    self.vllm_config)
                if len(sliding_window_configs) == 1:
                    sliding_window_config = sliding_window_configs.pop()
                    if sliding_window_config is not None:
                        self.aot_sliding_window = sliding_window_config
                elif len(sliding_window_configs) > 1:
                    # Layers disagree on the window: disable AOT scheduling
                    # for this and all later calls.
                    self.aot_schedule = False
                    aot_schedule = False

        max_num_splits = 0  # 0 means use FA3's heuristics, not CG compatible
        if self.use_full_cuda_graph and \
            num_actual_tokens <= self.max_cudagraph_size:
            # NOTE(woosuk): Setting num_splits > 1 may increase the memory
            # usage, because the intermediate buffers of size [num_splits,
            # num_heads, num_tokens, head_size] are allocated. Therefore,
            # we only set num_splits when using cuda graphs.
            max_num_splits = self.max_num_splits

        def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                     max_seq_len, causal):
            """Return AOT scheduler metadata, or None if AOT is disabled."""
            cache_dtype = self.cache_config.cache_dtype
            if cache_dtype.startswith("fp8"):
                qkv_dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
                    cache_dtype)
            else:
                qkv_dtype = self.kv_cache_dtype
            if aot_schedule:
                return get_scheduler_metadata(
                    batch_size=batch_size,
                    max_seqlen_q=max_query_len,
                    max_seqlen_k=max_seq_len,
                    num_heads_q=self.num_heads_q,
                    num_heads_kv=self.num_heads_kv,
                    headdim=self.headdim,
                    cache_seqlens=seqlens,
                    qkv_dtype=qkv_dtype,
                    cu_seqlens_q=cu_query_lens,
                    page_size=self.block_size,
                    causal=causal,
                    window_size=self.aot_sliding_window,
                    num_splits=max_num_splits,
                )
            return None

        use_cascade = common_prefix_len > 0

        if use_cascade:
            # The shared prefix is treated as a single sequence that all
            # query tokens attend to (non-causally).
            cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
                                                dtype=torch.int32,
                                                device=self.device)
            prefix_kv_lens = torch.tensor([common_prefix_len],
                                          dtype=torch.int32,
                                          device=self.device)
            suffix_kv_lens = (seq_lens_cpu[:num_reqs] - common_prefix_len).to(
                self.device, non_blocking=True)
            prefix_scheduler_metadata = schedule(
                batch_size=1,
                cu_query_lens=cu_prefix_query_lens,
                max_query_len=num_actual_tokens,
                seqlens=prefix_kv_lens,
                max_seq_len=common_prefix_len,
                causal=False)
            scheduler_metadata = schedule(batch_size=num_reqs,
                                          cu_query_lens=query_start_loc,
                                          max_query_len=max_query_len,
                                          seqlens=suffix_kv_lens,
                                          max_seq_len=max_seq_len -
                                          common_prefix_len,
                                          causal=True)
        else:
            cu_prefix_query_lens = None
            prefix_kv_lens = None
            suffix_kv_lens = None
            prefix_scheduler_metadata = None
            scheduler_metadata = schedule(batch_size=num_reqs,
                                          cu_query_lens=query_start_loc,
                                          max_query_len=max_query_len,
                                          seqlens=seq_lens,
                                          max_seq_len=max_seq_len,
                                          causal=causal)
        # For FA3 + full cudagraph
        if self.use_full_cuda_graph and scheduler_metadata is not None:
            n = scheduler_metadata.shape[0]
            self.scheduler_metadata[:n] = scheduler_metadata
            # NOTE(woosuk): We should zero out the rest of the scheduler
            # metadata to guarantee the correctness. Otherwise, some thread
            # blocks may use the invalid scheduler metadata and overwrite the
            # output buffer.
            self.scheduler_metadata[n:] = 0
            scheduler_metadata = self.scheduler_metadata[:n]

        attn_metadata = FlashAttentionMetadata(
            num_actual_tokens=num_actual_tokens,
            max_query_len=max_query_len,
            query_start_loc=query_start_loc,
            max_seq_len=max_seq_len,
            seq_lens=seq_lens,
            block_table=block_table_tensor,
            slot_mapping=slot_mapping,
            use_cascade=use_cascade,
            common_prefix_len=common_prefix_len,
            scheduler_metadata=scheduler_metadata,
            cu_prefix_query_lens=cu_prefix_query_lens,
            prefix_kv_lens=prefix_kv_lens,
            suffix_kv_lens=suffix_kv_lens,
            prefix_scheduler_metadata=prefix_scheduler_metadata,
            max_num_splits=max_num_splits,
            causal=causal)
        return attn_metadata

    def use_cascade_attention(self, *args, **kwargs) -> bool:
        """Forward all arguments to the module-level `use_cascade_attention`."""
        return use_cascade_attention(*args, **kwargs)

408

409
410
411
412
413
414
415
416
class FlashAttentionImpl(AttentionImpl):

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        logits_soft_cap: Optional[float] = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        sinks: Optional[torch.Tensor] = None,
    ) -> None:
        """Store the layer's attention configuration and validate it.

        Args:
            num_heads: number of query heads.
            head_size: size of each attention head; must be supported by
                `FlashAttentionBackend`.
            scale: softmax scale, stored as a float.
            num_kv_heads: number of key/value heads.
            alibi_slopes: optional ALiBi slopes, converted to a float32
                tensor when given.
            sliding_window: optional window size, converted to a
                (left, right) tuple; (-1, -1) when None.
            kv_cache_dtype: KV-cache dtype string.
            logits_soft_cap: optional soft cap; None is stored as 0.
            attn_type: the kind of attention this layer performs.
            kv_sharing_target_layer_name: optional name of the layer whose
                KV cache this layer shares.
            sinks: optional tensor whose first dimension must equal
                `num_heads`.

        Raises:
            ValueError: if `head_size` is not supported.
            NotImplementedError: if the KV cache is quantized and
                FlashAttention does not support fp8 on this device.
            AssertionError: if `sinks` is given but unsupported or has the
                wrong number of heads.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        if sliding_window is None:
            self.sliding_window = (-1, -1)
        elif attn_type == AttentionType.ENCODER_ONLY:
            # Encoder-only attention gets a symmetric window.
            self.sliding_window = (sliding_window - 1, sliding_window - 1)
        else:
            # Otherwise the window only extends to the left.
            self.sliding_window = (sliding_window - 1, 0)
        self.kv_cache_dtype = kv_cache_dtype
        if logits_soft_cap is None:
            # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
            logits_soft_cap = 0
        self.logits_soft_cap = logits_soft_cap
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        FlashAttentionBackend.validate_head_size(head_size)

        self.attn_type = attn_type
        self.vllm_flash_attn_version = get_flash_attn_version()
        if is_quantized_kv_cache(self.kv_cache_dtype) \
            and not flash_attn_supports_fp8():
            raise NotImplementedError(
                "FlashAttention does not support fp8 kv-cache on this device.")

        self.sinks = sinks
        if self.sinks is not None:
            # The FlashAttention 3 requirement is only enforced off ROCm.
            if not current_platform.is_rocm():
                assert self.vllm_flash_attn_version == 3, (
                    "Sinks are only supported in FlashAttention 3")
            assert self.sinks.shape[0] == num_heads, (
                "Sinks must have the same number of heads as the number of "
                "heads in the layer")

465
466
    def forward(
        self,
467
        layer: torch.nn.Module,
468
469
470
471
472
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
473
        output: Optional[torch.Tensor] = None,
474
        output_scale: Optional[torch.Tensor] = None,
475
        output_block_scale: Optional[torch.Tensor] = None,
476
477
478
479
    ) -> torch.Tensor:
        """Forward pass with FlashAttention.

        Args:
480
481
482
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
483
484
            kv_cache: shape =
                [2, num_blocks, block_size, num_kv_heads, head_size]
485
486
487
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
488
489
490
        NOTE: FP8 quantization, flash-attn expect the size of
              {q,k,v}_descale to be (num_sequences, num_kv_heads).
              We use torch's .expand() to avoid duplicating values
491
        """
492
493
        assert output is not None, "Output tensor must be provided."

494
        if output_scale is not None or output_block_scale is not None:
495
496
497
498
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for FlashAttentionImpl")

499
500
501
502
        if attn_metadata is None:
            # Profiling run.
            return output

503
504
        attn_type = self.attn_type

505
506
507
508
509
510
511
512
        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
        # in this method. For example, `view` and `slice` (or `[:n]`) operations
        # are surprisingly slow even in the case they do not invoke any GPU ops.
        # Minimize the PyTorch ops in this method as much as possible.
        # Whenever making a change in this method, please benchmark the
        # performance to make sure it does not introduce any overhead.
513

514
        num_actual_tokens = attn_metadata.num_actual_tokens
515
516

        # Handle encoder attention differently - no KV cache needed
517
        if attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
518
519
520
521
522
523
524
525
526
            # For encoder attention,
            # we use direct Q, K, V tensors without caching
            return self._forward_encoder_attention(query[:num_actual_tokens],
                                                   key[:num_actual_tokens],
                                                   value[:num_actual_tokens],
                                                   output[:num_actual_tokens],
                                                   attn_metadata, layer)

        # For decoder and cross-attention, use KV cache as before
527
528
529
530
        if not current_platform.is_rocm():
            key_cache, value_cache = kv_cache.unbind(0)
        else:
            key_cache, value_cache = kv_cache
531

532
533
534
535
536
        # key and value may be None in the case of cross attention. They are
        # calculated once based on the output from the encoder and then cached
        # in KV cache.
        if (self.kv_sharing_target_layer_name is None and key is not None
                and value is not None):
537
538
539
540
541
542
543
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
            # not padded. However, we don't need to do key[:num_actual_tokens]
            # and value[:num_actual_tokens] because the reshape_and_cache_flash
            # op uses the slot_mapping's shape to determine the number of
            # actual tokens.
544
545
546
547
548
549
550
551
552
553
554
555
            if not current_platform.is_rocm():
                reshape_and_cache_flash(
                    key,
                    value,
                    key_cache,
                    value_cache,
                    attn_metadata.slot_mapping,
                    self.kv_cache_dtype,
                    layer._k_scale,
                    layer._v_scale,
                )
            else:
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
                if envs.VLLM_USE_OPT_RESHAPE_AND_CACHE and key.dtype == value.dtype == torch.float16:
                    from lightop import reshape_and_cache_cuda
                    reshape_and_cache_cuda(
                        key, 
                        value,
                        key_cache, 
                        value_cache,
                        attn_metadata.slot_mapping,
                        self.kv_cache_dtype,
                        layer._k_scale, 
                        layer._v_scale
                    )
                else:
                    from vllm.attention.utils.fa_utils import reshape_and_cache_cuda
                    reshape_and_cache_cuda(
                        key,
                        value,
                        key_cache,
                        value_cache,
                        attn_metadata.slot_mapping,
                        self.kv_cache_dtype,
                        layer._k_scale,
                        layer._v_scale,
                    )
580

581
        if self.kv_cache_dtype.startswith("fp8"):
582
            # queries are quantized in the attention layer
583
584
585
586
            dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
                self.kv_cache_dtype)
            key_cache = key_cache.view(dtype)
            value_cache = value_cache.view(dtype)
587

588
589
590
591
592
593
594
        if not attn_metadata.use_cascade:
            cu_seqlens_q = attn_metadata.query_start_loc
            seqused_k = attn_metadata.seq_lens
            max_seqlen_q = attn_metadata.max_query_len
            max_seqlen_k = attn_metadata.max_seq_len
            block_table = attn_metadata.block_table
            scheduler_metadata = attn_metadata.scheduler_metadata
595

596
            descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)
597

zhuwenwen's avatar
zhuwenwen committed
598
            if not current_platform.is_rocm():
599
                 flash_attn_varlen_func(
zhuwenwen's avatar
zhuwenwen committed
600
601
602
603
604
605
606
607
608
                    q=query[:num_actual_tokens],
                    k=key_cache,
                    v=value_cache,
                    out=output[:num_actual_tokens],
                    cu_seqlens_q=cu_seqlens_q,
                    max_seqlen_q=max_seqlen_q,
                    seqused_k=seqused_k,
                    max_seqlen_k=max_seqlen_k,
                    softmax_scale=self.scale,
609
                    causal=attn_metadata.causal,
zhuwenwen's avatar
zhuwenwen committed
610
611
612
613
614
615
616
617
618
619
                    alibi_slopes=self.alibi_slopes,
                    window_size=self.sliding_window,
                    block_table=block_table,
                    softcap=self.logits_soft_cap,
                    scheduler_metadata=scheduler_metadata,
                    fa_version=self.vllm_flash_attn_version,
                    q_descale=layer._q_scale.expand(descale_shape),
                    k_descale=layer._k_scale.expand(descale_shape),
                    v_descale=layer._v_scale.expand(descale_shape),
                    num_splits=attn_metadata.max_num_splits,
620
                    s_aux=self.sinks,
zhuwenwen's avatar
zhuwenwen committed
621
622
                )
            else:
623
624
                if envs.VLLM_USE_PA_PRINT_PARAM:
                    print("PA SIZE:")
625
                    print(f"q.shape = {query[:num_actual_tokens].shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}")
626
627
                    print(f"cu_seqlens_q.shape = {cu_seqlens_q.shape}, max_seqlen_q = {max_seqlen_q}, seqused_k.shape = {seqused_k.shape}, max_seqlen_k = {max_seqlen_k}")
                    print(f"softmax_scale = {self.scale:.3f}, alibi_slopes = {self.alibi_slopes}, window_size = {self.sliding_window}, block_tables.shape = {block_table.shape}, softcap = {self.logits_soft_cap}, scheduler_metadata = {scheduler_metadata}")
zhuwenwen's avatar
zhuwenwen committed
628
629
630
631
632
633
634
635
636
637
                vllm_flash_attn_varlen_func(
                    q=query[:num_actual_tokens],
                    k=key_cache,
                    v=value_cache,
                    out=output[:num_actual_tokens],
                    cu_seqlens_q=cu_seqlens_q,
                    max_seqlen_q=max_seqlen_q,
                    seqused_k=seqused_k,
                    max_seqlen_k=max_seqlen_k,
                    softmax_scale=self.scale,
638
                    causal=attn_metadata.causal,
zhuwenwen's avatar
zhuwenwen committed
639
640
641
642
643
644
645
646
647
648
                    alibi_slopes=self.alibi_slopes,
                    window_size=self.sliding_window,
                    block_table=block_table,
                    softcap=self.logits_soft_cap,
                    scheduler_metadata=scheduler_metadata,
                    # fa_version=self.vllm_flash_attn_version,
                    # q_descale=layer._q_scale.expand(descale_shape),
                    # k_descale=layer._k_scale.expand(descale_shape),
                    # v_descale=layer._v_scale.expand(descale_shape),
                    # num_splits=attn_metadata.max_num_splits,
649
                    s_aux=self.sinks,
650
                    is_prefix_cache=True,
zhuwenwen's avatar
zhuwenwen committed
651
                )
652
            return output
653
654

        # Cascade attention (rare case).
zhuwenwen's avatar
zhuwenwen committed
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
        if not current_platform.is_rocm():
            cascade_attention(
                output[:num_actual_tokens],
                query[:num_actual_tokens],
                key_cache,
                value_cache,
                cu_query_lens=attn_metadata.query_start_loc,
                max_query_len=attn_metadata.max_query_len,
                cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens,
                prefix_kv_lens=attn_metadata.prefix_kv_lens,
                suffix_kv_lens=attn_metadata.suffix_kv_lens,
                max_kv_len=attn_metadata.max_seq_len,
                softmax_scale=self.scale,
                alibi_slopes=self.alibi_slopes,
                sliding_window=self.sliding_window,
                logits_soft_cap=self.logits_soft_cap,
                block_table=attn_metadata.block_table,
                common_prefix_len=attn_metadata.common_prefix_len,
                fa_version=self.vllm_flash_attn_version,
                prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
                suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
                q_descale=layer._q_scale,
                k_descale=layer._k_scale,
                v_descale=layer._v_scale,
            )
        else:
            cascade_attention(
                output[:num_actual_tokens],
                query[:num_actual_tokens],
                key_cache,
                value_cache,
                cu_query_lens=attn_metadata.query_start_loc,
                max_query_len=attn_metadata.max_query_len,
                cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens,
                prefix_kv_lens=attn_metadata.prefix_kv_lens,
                suffix_kv_lens=attn_metadata.suffix_kv_lens,
                max_kv_len=attn_metadata.max_seq_len,
                softmax_scale=self.scale,
                alibi_slopes=self.alibi_slopes,
                sliding_window=self.sliding_window,
                logits_soft_cap=self.logits_soft_cap,
                block_table=attn_metadata.block_table,
                common_prefix_len=attn_metadata.common_prefix_len,
                fa_version=2, #self.vllm_flash_attn_version,
                prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
                suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
                # q_descale=layer._q_scale,
                # k_descale=layer._k_scale,
                # v_descale=layer._v_scale,
            )
705
        return output
706

707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
    def _forward_encoder_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        output: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        layer: torch.nn.Module,
    ) -> torch.Tensor:
        """Forward pass for encoder attention without KV cache.

        The varlen attention kernel is invoked directly on the given Q/K/V
        tensors, with ``output`` passed as the kernel's ``out`` buffer.

        Args:
            query: shape = [num_encoder_tokens, num_heads, head_size]
            key: shape = [num_encoder_tokens, num_kv_heads, head_size]
            value: shape = [num_encoder_tokens, num_kv_heads, head_size]
            output: shape = [num_encoder_tokens, num_heads, head_size]
            attn_metadata: Encoder attention metadata. Its
                ``query_start_loc`` and ``max_query_len`` are used for both
                the query side and the key side.
            layer: The attention layer. Its ``_q_scale``/``_k_scale``/
                ``_v_scale`` are forwarded as descale factors on non-ROCm
                platforms only.

        Returns:
            The ``output`` tensor that was passed in.

        Raises:
            NotImplementedError: If ``self.kv_cache_dtype`` starts with
                "fp8".
        """
        # FP8 (quantized) KV-cache dtypes are rejected outright for encoder
        # attention; nothing below handles quantization.
        if self.kv_cache_dtype.startswith("fp8"):
            raise NotImplementedError(
                "quantization is not supported for encoder attention")

        # Use encoder-specific metadata for sequence information. Queries and
        # keys share the same cumulative lengths and the same maximum length.
        cu_seqlens = attn_metadata.query_start_loc
        max_seqlen = attn_metadata.max_query_len

        # Arguments that are identical for both platform-specific kernels.
        common_kwargs = dict(
            q=query,
            k=key,
            v=value,
            out=output,
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=max_seqlen,
            max_seqlen_k=max_seqlen,
            softmax_scale=self.scale,
            causal=False,  # Encoder attention is bidirectional
            alibi_slopes=self.alibi_slopes,
            window_size=self.sliding_window,
            softcap=self.logits_soft_cap,
        )

        # Call flash attention directly on Q, K, V tensors.
        if not current_platform.is_rocm():
            descale_shape = (
                cu_seqlens.shape[0] - 1,  # type: ignore[union-attr]
                self.num_kv_heads)
            flash_attn_varlen_func(
                **common_kwargs,
                fa_version=self.vllm_flash_attn_version,
                q_descale=layer._q_scale.expand(descale_shape),
                k_descale=layer._k_scale.expand(descale_shape),
                v_descale=layer._v_scale.expand(descale_shape),
            )
        else:
            # The ROCm entry point is called without fa_version and without
            # the q/k/v descale factors; it takes is_prefix_cache instead.
            vllm_flash_attn_varlen_func(
                **common_kwargs,
                is_prefix_cache=False,
            )

        return output

786
787
788
789
790
791
792
793

def use_cascade_attention(
    common_prefix_len: int,
    query_lens: np.ndarray,
    num_query_heads: int,
    num_kv_heads: int,
    use_alibi: bool,
    use_sliding_window: bool,
    use_local_attention: bool,
    num_sms: int,
) -> bool:
    """Decide whether to use cascade attention.

    This function 1) checks whether cascade attention is supported with the
    given configuration, and 2) heuristically decides whether using cascade
    attention can improve performance.

    Args:
        common_prefix_len: Number of prefix tokens shared by all requests.
        query_lens: Per-request query lengths; its length is taken as the
            number of requests.
        num_query_heads: Number of query heads.
        num_kv_heads: Number of KV heads.
        use_alibi: Whether ALiBi is in use (cascade is then disabled).
        use_sliding_window: Whether sliding-window attention is in use
            (cascade is then disabled).
        use_local_attention: Whether local attention is in use (cascade is
            then disabled).
        num_sms: Number of SMs, used by the rough wave-count model below.

    Returns:
        True if cascade attention should be used, False otherwise.
    """
    # Too short common prefix. Probably not worth using cascade attention.
    # We use an arbitrary threshold of 256 tokens. TODO: Tune this threshold.
    # NOTE(woosuk): This is the common case. We should return False as soon as
    # possible to avoid any unnecessary computation.
    if common_prefix_len < 256:
        return False
    # Cascade attention is currently not supported with these variants.
    if use_alibi or use_sliding_window or use_local_attention:
        return False
    # Too few queries. Probably not worth using cascade attention.
    # We use an arbitrary threshold of 8 queries. TODO: Tune this threshold.
    num_reqs = len(query_lens)
    if num_reqs < 8:
        return False

    # Heuristics to decide whether using cascade attention is beneficial.
    # 1. When FlashDecoding is not used for normal attention, cascade attention
    #    is likely to be faster since it saves memory bandwidth.
    num_queries_per_kv = num_query_heads // num_kv_heads
    # The criteria for using FlashDecoding can be found in the following link:
    # https://github.com/vllm-project/flash-attention/blob/96266b1111111f3d11aabefaf3bacbab6a89d03c/csrc/flash_attn/flash_api.cpp#L535
    # The "no sliding window" and "no ALiBi" parts of those criteria already
    # hold here because both variants returned False above, so only the
    # remaining two conditions need to be evaluated.
    use_flash_decoding = (num_queries_per_kv > 1
                          and bool(np.all(query_lens == 1)))
    if not use_flash_decoding:
        # Use cascade attention.
        return True

    # 2. When FlashDecoding is used for normal attention, it is not clear
    #    whether cascade attention is beneficial, because FlashDecoding can
    #    launch more CTAs than cascade attention.
    #    We use a simple performance model to compare the two methods.
    #    NOTE(woosuk): The performance model is very rough and may not be
    #    accurate.
    # Every query has length 1 on this path, so tokens == requests.
    num_tokens = num_reqs
    # NOTE(woosuk): These are default tile sizes. flash-attn might use
    # different tile sizes (e.g., 64 or 256) depending on the configuration.
    q_tile_size = 128
    kv_tile_size = 128
    num_prefix_tiles = cdiv(common_prefix_len, kv_tile_size)

    # Cascade: each CTA walks all prefix tiles; CTAs run in waves of num_sms.
    cascade_ctas = num_query_heads * cdiv(num_tokens, q_tile_size)
    cascade_waves = cdiv(cascade_ctas, num_sms)
    cascade_time = cascade_waves * num_prefix_tiles

    # FlashDecoding: one CTA per (request, KV head, q tile, prefix tile).
    flash_decoding_ctas = (num_reqs * num_kv_heads *
                           cdiv(num_queries_per_kv, q_tile_size))
    flash_decoding_ctas *= num_prefix_tiles
    flash_decoding_time = cdiv(flash_decoding_ctas, num_sms)

    # Use cascade attention if it is faster than FlashDecoding.
    return cascade_time < flash_decoding_time


def _cascade_varlen_pass(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    *,
    cu_seqlens_q: torch.Tensor,
    seqused_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    softmax_scale: float,
    causal: bool,
    sliding_window: tuple[int, int],
    block_table: torch.Tensor,
    logits_soft_cap: float,
    scheduler_metadata: Optional[torch.Tensor],
    fa_version: int,
    q_descale: Optional[torch.Tensor],
    k_descale: Optional[torch.Tensor],
    v_descale: Optional[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run one varlen attention pass and return ``(output, softmax_lse)``.

    Picks the platform-specific kernel entry point: on ROCm the call is made
    without ``fa_version`` and the descale factors and with
    ``is_prefix_cache=True``; elsewhere ``fa_version`` and the (expanded)
    descale factors are forwarded.
    """
    shared_kwargs = dict(
        q=query,
        k=key_cache,
        v=value_cache,
        cu_seqlens_q=cu_seqlens_q,
        seqused_k=seqused_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        softmax_scale=softmax_scale,
        causal=causal,
        window_size=sliding_window,
        block_table=block_table,
        softcap=logits_soft_cap,
        return_softmax_lse=True,
        scheduler_metadata=scheduler_metadata,
    )

    if current_platform.is_rocm():
        # The ROCm entry point yields a third value that is not needed here.
        attn_out, attn_lse, _ = vllm_flash_attn_varlen_func(
            **shared_kwargs,
            is_prefix_cache=True,
        )
        return attn_out, attn_lse

    # One descale entry per (sequence in cu_seqlens_q, KV head).
    descale_shape = (cu_seqlens_q.shape[0] - 1, key_cache.shape[-2])

    def _expanded(scale: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        return scale.expand(descale_shape) if scale is not None else None

    attn_out, attn_lse = flash_attn_varlen_func(
        **shared_kwargs,
        fa_version=fa_version,
        q_descale=_expanded(q_descale),
        k_descale=_expanded(k_descale),
        v_descale=_expanded(v_descale),
    )
    return attn_out, attn_lse


def cascade_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    cu_query_lens: torch.Tensor,
    max_query_len: int,
    cu_prefix_query_lens: torch.Tensor,
    prefix_kv_lens: torch.Tensor,
    suffix_kv_lens: torch.Tensor,
    max_kv_len: int,
    softmax_scale: float,
    alibi_slopes: Optional[torch.Tensor],
    sliding_window: tuple[int, int],
    logits_soft_cap: float,
    block_table: torch.Tensor,
    common_prefix_len: int,
    fa_version: int,
    prefix_scheduler_metadata: Optional[torch.Tensor] = None,
    suffix_scheduler_metadata: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Compute attention as a shared-prefix pass plus a per-query suffix pass.

    The shared prefix is attended non-causally using only the first row of
    ``block_table``; the remaining KV blocks are attended causally per query.
    The two partial results are combined with ``merge_attn_states`` and
    stored in ``output``.

    Args:
        output: Destination tensor for the merged attention result.
        query: Query tensor; its first dimension is the token count.
        key_cache: Paged key cache; ``shape[-3]`` is read as the block size
            and ``shape[-2]`` as the number of KV heads for descaling.
        value_cache: Paged value cache.
        cu_query_lens: Cumulative query lengths for the suffix pass.
        max_query_len: Maximum query length for the suffix pass.
        cu_prefix_query_lens: Cumulative query lengths for the prefix pass.
        prefix_kv_lens: KV lengths used in the prefix pass.
        suffix_kv_lens: KV lengths used in the suffix pass.
        max_kv_len: Maximum total KV length (prefix included).
        softmax_scale: Softmax scaling factor.
        alibi_slopes: Must be None.
        sliding_window: Must be ``(-1, -1)``.
        logits_soft_cap: Soft cap passed to the kernel as ``softcap``.
        block_table: Block table; columns before the common prefix blocks
            are skipped in the suffix pass.
        common_prefix_len: Length of the shared prefix; must be a positive
            multiple of the block size.
        fa_version: FlashAttention version (forwarded on non-ROCm only).
        prefix_scheduler_metadata: Scheduler metadata for the prefix pass.
        suffix_scheduler_metadata: Scheduler metadata for the suffix pass.
        q_descale: Optional query descale factor (non-ROCm only).
        k_descale: Optional key descale factor (non-ROCm only).
        v_descale: Optional value descale factor (non-ROCm only).

    Returns:
        ``output``, after the merged result has been stored in it.
    """
    assert alibi_slopes is None, ("Cascade attention does not support ALiBi.")
    # TODO: Support sliding window.
    assert sliding_window == (-1, -1), (
        "Cascade attention does not support sliding window.")

    num_tokens = query.shape[0]
    block_size = key_cache.shape[-3]
    assert common_prefix_len % block_size == 0
    num_common_kv_blocks = common_prefix_len // block_size
    assert num_common_kv_blocks > 0

    # Process shared prefix: every token attends to the whole prefix, so the
    # pass is non-causal and only the first block-table row is needed.
    prefix_output, prefix_lse = _cascade_varlen_pass(
        query,
        key_cache,
        value_cache,
        cu_seqlens_q=cu_prefix_query_lens,
        seqused_k=prefix_kv_lens,
        max_seqlen_q=num_tokens,
        max_seqlen_k=common_prefix_len,
        softmax_scale=softmax_scale,
        causal=False,
        sliding_window=sliding_window,
        block_table=block_table[:1],
        logits_soft_cap=logits_soft_cap,
        scheduler_metadata=prefix_scheduler_metadata,
        fa_version=fa_version,
        q_descale=q_descale,
        k_descale=k_descale,
        v_descale=v_descale,
    )

    # Process suffix per query: causal attention over the blocks that follow
    # the common prefix.
    suffix_output, suffix_lse = _cascade_varlen_pass(
        query,
        key_cache,
        value_cache,
        cu_seqlens_q=cu_query_lens,
        seqused_k=suffix_kv_lens,
        max_seqlen_q=max_query_len,
        max_seqlen_k=max_kv_len - common_prefix_len,
        softmax_scale=softmax_scale,
        causal=True,
        sliding_window=sliding_window,
        block_table=block_table[:, num_common_kv_blocks:],
        logits_soft_cap=logits_soft_cap,
        scheduler_metadata=suffix_scheduler_metadata,
        fa_version=fa_version,
        q_descale=q_descale,
        k_descale=k_descale,
        v_descale=v_descale,
    )

    # Merge prefix and suffix outputs, and store the result in output.
    merge_attn_states(output, prefix_output, prefix_lse, suffix_output,
                      suffix_lse)
    return output