flash_attn.py 36.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
"""Attention layer with FlashAttention."""
from dataclasses import dataclass
5
from typing import Optional
6

7
import numpy as np
8
9
import torch

10
import vllm.envs as envs
11
from vllm import _custom_ops as ops
12
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
13
14
                                              AttentionMetadata, AttentionType,
                                              is_quantized_kv_cache)
15
from vllm.attention.layer import Attention
16
from vllm.attention.ops.merge_attn_states import merge_attn_states
17
from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
18
                                           get_flash_attn_version,
19
20
                                           is_flash_attn_varlen_func_available)

zhuwenwen's avatar
zhuwenwen committed
21
from vllm.platforms import current_platform
22
if is_flash_attn_varlen_func_available():
zhuwenwen's avatar
zhuwenwen committed
23
24
25
26
27
28
29
    if not current_platform.is_rocm():
        from vllm.attention.utils.fa_utils import (flash_attn_varlen_func,
                                                get_scheduler_metadata,
                                                reshape_and_cache_flash)
    else:
        from vllm.attention.utils.fa_utils import (flash_attn_varlen_func,
                                               vllm_flash_attn_varlen_func,
30
31
                                               reshape_and_cache_flash)

32
from vllm.config import VllmConfig, get_layers_from_vllm_config
zhuwenwen's avatar
zhuwenwen committed
33

34
from vllm.logger import init_logger
35
from vllm.utils import cdiv
36
37
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
                                              AttentionMetadataBuilder,
38
39
                                              CommonAttentionMetadata,
                                              get_kv_cache_layout)
40
from vllm.v1.kv_cache_interface import AttentionSpec
41

42
43
logger = init_logger(__name__)

# Upper bound on FlashAttention's split count used only when full CUDA graphs
# and AOT scheduling are both enabled: FlashAttentionMetadataBuilder copies it
# into ``max_num_splits`` so large-enough intermediate buffers are allocated
# during capture.
# NOTE(woosuk): This is an arbitrary number. Tune it if needed.
_DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH = 16

47
48
49

class FlashAttentionBackend(AttentionBackend):
    """vLLM V1 attention backend built on FlashAttention.

    Exposes the implementation, metadata and metadata-builder classes and
    describes the paged KV-cache layout this backend uses.
    """

    accept_output_buffer: bool = True

    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        """Return the activation dtypes this backend accepts."""
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return the head sizes this backend accepts."""
        return [32, 64, 96, 128, 160, 192, 224, 256]

    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
        """Raise ``ValueError`` if ``head_size`` is not supported.

        The error message suggests switching to the FlexAttention backend.
        """
        supported_head_sizes = cls.get_supported_head_sizes()
        if head_size in supported_head_sizes:
            return
        attn_type = cls.__name__.removesuffix("Backend")
        raise ValueError(
            f"Head size {head_size} is not supported by {attn_type}. "
            f"Supported head sizes are: {supported_head_sizes}. "
            "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
            "FlexAttention backend which supports all head sizes.")

    @staticmethod
    def get_name() -> str:
        """Return the registry name of this backend."""
        return "FLASH_ATTN_VLLM_V1"

    @staticmethod
    def get_impl_cls() -> type["FlashAttentionImpl"]:
        """Return the per-layer attention implementation class."""
        return FlashAttentionImpl

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        """Return the per-step attention metadata class."""
        return FlashAttentionMetadata

    @staticmethod
    def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]:
        """Return the class that builds the per-step metadata."""
        return FlashAttentionMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        """Return the logical KV-cache shape.

        The leading dimension of size 2 holds the key cache and the value
        cache; ``FlashAttentionImpl.forward`` separates them with
        ``kv_cache.unbind(0)``.

        Raises:
            ValueError: if ``block_size`` is not a multiple of 16.
        """
        if block_size % 16 != 0:
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_kv_cache_stride_order() -> tuple[int, ...]:
        """Return the permutation from the logical to the physical layout.

        ``stride_order`` indicates the permutation that gets us from
        ``get_kv_cache_shape`` to the actual memory layout we want:
        "NHD" keeps the logical order, "HND" swaps the block-position and
        KV-head dimensions.

        Raises:
            ValueError: if ``get_kv_cache_layout()`` returns anything else.
        """
        cache_layout = get_kv_cache_layout()
        if cache_layout == "NHD":
            stride_order = (0, 1, 2, 3, 4)
        elif cache_layout == "HND":
            stride_order = (0, 1, 3, 2, 4)
        else:
            raise ValueError(f"Unknown cache layout format {cache_layout}.")
        return stride_order

    @staticmethod
    def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype:
        """Map an fp8 ``kv_cache_dtype`` string to the torch dtype used here.

        Only "fp8" and "fp8_e4m3" are accepted; both map to
        ``torch.float8_e4m3fn``.

        Raises:
            ValueError: for any other string.
        """
        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
            return torch.float8_e4m3fn
        else:
            raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")

118
119
120
121
122
123
124
125
126
127
128

@dataclass
class FlashAttentionMetadata:
    """Per-step metadata consumed by ``FlashAttentionImpl.forward``.

    Instances are produced by ``FlashAttentionMetadataBuilder.build``.
    """
    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    num_actual_tokens: int  # Number of tokens excluding padding.
    max_query_len: int
    # Passed to the kernel as ``cu_seqlens_q``.
    query_start_loc: torch.Tensor
    max_seq_len: int
    # Passed to the kernel as ``seqused_k``.
    seq_lens: torch.Tensor
    block_table: torch.Tensor
    # Destination slots used when writing new keys/values into the cache.
    slot_mapping: torch.Tensor

    # For cascade attention (only populated when use_cascade is True).
    use_cascade: bool
    common_prefix_len: int
    cu_prefix_query_lens: Optional[torch.Tensor]
    prefix_kv_lens: Optional[torch.Tensor]
    suffix_kv_lens: Optional[torch.Tensor]

    # Optional AOT scheduling (None when AOT scheduling is disabled).
    scheduler_metadata: Optional[torch.Tensor] = None
    prefix_scheduler_metadata: Optional[torch.Tensor] = None
    # Forwarded as ``num_splits``; the builder uses 0 to mean "no cap".
    max_num_splits: int = 0

    causal: bool = True

151

152
153
154
155
156
157
158
159
160
161
162
def _get_sliding_window_configs(
        vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]:
    """Collect the distinct sliding-window settings of all attention layers.

    Every ``Attention`` layer found in ``vllm_config`` must be backed by a
    ``FlashAttentionImpl`` (asserted). ``FlashAttentionImpl`` stores its
    window as a ``(left, right)`` tuple, with ``(-1, -1)`` when no window is
    configured; the result holds one entry per distinct tuple.
    """
    attn_layers = get_layers_from_vllm_config(vllm_config, Attention)
    impls = [attn_layer.impl for attn_layer in attn_layers.values()]
    for impl in impls:
        assert isinstance(impl, FlashAttentionImpl)
    return {impl.sliding_window for impl in impls}


163
164
class FlashAttentionMetadataBuilder(
        AttentionMetadataBuilder[FlashAttentionMetadata]):
    """Builds ``FlashAttentionMetadata`` for each batch.

    With FlashAttention 3 (``get_flash_attn_version() == 3``) the builder
    can precompute ahead-of-time (AOT) scheduler metadata through
    ``get_scheduler_metadata``. When full CUDA graphs are also enabled, that
    metadata is copied into a persistent, preallocated buffer.
    """

    # FA3:
    # Supports full cudagraphs for all cases.
    #
    # FA2:
    # For FA2, a graph is captured with max_query_len=1, (which is what we
    # capture by default for num_tokens <= max_num_seqs when there is no
    # spec-decode) then these graphs will not work for mixed prefill-decode
    # (unlike FA3). This is due to special max_query_len=1 packed-GQA handling
    # in FA2.
    # In summary if we are running with spec decodes the graphs would
    # work for mixed prefill-decode and uniform-decode. But for non-spec decodes
    # the graphs would not work for mixed prefill-decode; sorta the inverse
    # of UNIFORM_SINGLE_TOKEN_DECODE.
    # There's probably a better way to describe this using `AttentionCGSupport`
    # but for now just set it to `UNIFORM_BATCH` to get us to drop down
    # to FULL_AND_PIECEWISE.
    # NOTE(review): ROCm is also reported as ALWAYS here; confirm the ROCm
    # kernel really supports full cudagraphs for mixed prefill-decode.
    # TODO(luka, lucas): audit FA2 as part of:
    #  https://github.com/vllm-project/vllm/issues/22945
    cudagraph_support = (
        AttentionCGSupport.ALWAYS
        if (get_flash_attn_version() == 3 or current_platform.is_rocm())
        else AttentionCGSupport.UNIFORM_BATCH)

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        """Cache config values and, for FA3 + full CUDA graphs, preallocate
        the persistent scheduler-metadata buffer.

        Args:
            kv_cache_spec: Supplies the KV-cache ``dtype`` and ``block_size``.
            layer_names: Not used by this builder.
            vllm_config: Global vLLM configuration.
            device: Device on which metadata tensors are allocated.

        Raises:
            ValueError: if full CUDA graphs and AOT scheduling are both
                enabled and ``max_capture_size`` exceeds 992.
        """
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.parallel_config = vllm_config.parallel_config
        self.cache_config = vllm_config.cache_config
        self.compilation_config = vllm_config.compilation_config
        self.device = device

        self.num_heads_q = self.model_config.get_num_attention_heads(
            self.parallel_config)
        self.num_heads_kv = self.model_config.get_num_kv_heads(
            self.parallel_config)
        self.kv_cache_dtype = kv_cache_spec.dtype
        self.headdim = self.model_config.get_head_size()
        self.block_size = kv_cache_spec.block_size

        self.max_num_splits = 0  # No upper bound on the number of splits.
        # AOT scheduling uses FA3's get_scheduler_metadata.
        self.aot_schedule = (get_flash_attn_version() == 3)

        self.use_full_cuda_graph = \
            self.compilation_config.cudagraph_mode.has_full_cudagraphs()

        if self.use_full_cuda_graph and self.aot_schedule:
            self.max_cudagraph_size = self.compilation_config.max_capture_size

            if self.max_cudagraph_size > 992:
                # This condition derives from FA3's internal heuristic.
                # TODO(woosuk): Support larger cudagraph sizes.
                raise ValueError(
                    "Capture size larger than 992 is not supported for "
                    "full cuda graph.")

            self.scheduler_metadata = torch.zeros(
                vllm_config.scheduler_config.max_num_seqs + 1,
                dtype=torch.int32,
                device=self.device,
            )
            # When using cuda graph, we need to set the upper bound of the
            # number of splits so that large enough intermediate buffers are
            # pre-allocated during capture.
            self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH

        # Sliding window size to be used with the AOT scheduler will be
        # populated on the first build() call that performs AOT scheduling.
        self.aot_sliding_window: Optional[tuple[int, int]] = None

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> FlashAttentionMetadata:
        """Build the attention metadata for one batch.

        Args:
            common_prefix_len: Prefix length shared by all requests; a
                positive value switches to cascade attention.
            common_attn_metadata: Backend-agnostic batch metadata.
            fast_build: Disables AOT scheduling, used when there will be few
                iterations i.e. spec-decode.

        Returns:
            The ``FlashAttentionMetadata`` for this batch.
        """
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        max_query_len = common_attn_metadata.max_query_len
        max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens
        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
        block_table_tensor = common_attn_metadata.block_table_tensor
        slot_mapping = common_attn_metadata.slot_mapping
        causal = common_attn_metadata.causal

        # the overhead of the aot schedule is not worth it for spec-decode
        aot_schedule = self.aot_schedule and not fast_build

        # For the AOT scheduler we need the sliding window value to be
        # constant for all layers. We have to populate this lazily in build()
        # because the layers are not yet constructed in __init__. Gating on
        # ``aot_schedule`` ensures a fast_build call cannot pin the window to
        # (-1, -1) without inspecting the layers first.
        if aot_schedule and self.aot_sliding_window is None:
            self.aot_sliding_window = (-1, -1)
            sliding_window_configs = _get_sliding_window_configs(
                self.vllm_config)
            if len(sliding_window_configs) == 1:
                sliding_window_config = sliding_window_configs.pop()
                if sliding_window_config is not None:
                    self.aot_sliding_window = sliding_window_config
            elif len(sliding_window_configs) > 1:
                # Layers disagree on the window: one AOT schedule cannot
                # serve them all, so disable AOT scheduling permanently.
                self.aot_schedule = False
                aot_schedule = False

        def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                     max_seq_len, causal):
            # Returns FA3 AOT scheduler metadata, or None when AOT scheduling
            # is off. ``get_scheduler_metadata`` is only imported on non-ROCm
            # platforms, so its call must stay behind the ``aot_schedule``
            # check.
            cache_dtype = self.cache_config.cache_dtype
            if cache_dtype.startswith("fp8"):
                qkv_dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
                    cache_dtype)
            else:
                qkv_dtype = self.kv_cache_dtype
            if aot_schedule:
                return get_scheduler_metadata(
                    batch_size=batch_size,
                    max_seqlen_q=max_query_len,
                    max_seqlen_k=max_seq_len,
                    num_heads_q=self.num_heads_q,
                    num_heads_kv=self.num_heads_kv,
                    headdim=self.headdim,
                    cache_seqlens=seqlens,
                    qkv_dtype=qkv_dtype,
                    cu_seqlens_q=cu_query_lens,
                    page_size=self.block_size,
                    causal=causal,
                    window_size=self.aot_sliding_window,
                    num_splits=self.max_num_splits,
                )
            return None

        use_cascade = common_prefix_len > 0

        if use_cascade:
            # Prefix pass: all tokens attend to the shared prefix as a single
            # non-causal "request"; suffix pass: per-request causal attention
            # over the remaining KV.
            cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
                                                dtype=torch.int32,
                                                device=self.device)
            prefix_kv_lens = torch.tensor([common_prefix_len],
                                          dtype=torch.int32,
                                          device=self.device)
            suffix_kv_lens = (seq_lens_cpu[:num_reqs] - common_prefix_len).to(
                self.device, non_blocking=True)
            prefix_scheduler_metadata = schedule(
                batch_size=1,
                cu_query_lens=cu_prefix_query_lens,
                max_query_len=num_actual_tokens,
                seqlens=prefix_kv_lens,
                max_seq_len=common_prefix_len,
                causal=False)
            scheduler_metadata = schedule(batch_size=num_reqs,
                                          cu_query_lens=query_start_loc,
                                          max_query_len=max_query_len,
                                          seqlens=suffix_kv_lens,
                                          max_seq_len=max_seq_len -
                                          common_prefix_len,
                                          causal=True)
        else:
            cu_prefix_query_lens = None
            prefix_kv_lens = None
            suffix_kv_lens = None
            prefix_scheduler_metadata = None
            scheduler_metadata = schedule(batch_size=num_reqs,
                                          cu_query_lens=query_start_loc,
                                          max_query_len=max_query_len,
                                          seqlens=seq_lens,
                                          max_seq_len=max_seq_len,
                                          causal=causal)

        # For FA3 + full cudagraph
        max_num_splits = 0
        if self.use_full_cuda_graph and scheduler_metadata is not None:
            # Copy into the persistent buffer allocated in __init__ so the
            # captured graph keeps reading from the same memory.
            n = scheduler_metadata.shape[0]
            self.scheduler_metadata[:n] = scheduler_metadata
            # NOTE(woosuk): We should zero out the rest of the scheduler
            # metadata to guarantee the correctness. Otherwise, some thread
            # blocks may use the invalid scheduler metadata and overwrite the
            # output buffer.
            self.scheduler_metadata[n:] = 0
            scheduler_metadata = self.scheduler_metadata[:n]

            if num_actual_tokens <= self.max_cudagraph_size:
                # NOTE(woosuk): Setting num_splits > 1 may increase the memory
                # usage, because the intermediate buffers of size [num_splits,
                # num_heads, num_tokens, head_size] are allocated. Therefore,
                # we only set num_splits when using cuda graphs.
                max_num_splits = self.max_num_splits

        attn_metadata = FlashAttentionMetadata(
            num_actual_tokens=num_actual_tokens,
            max_query_len=max_query_len,
            query_start_loc=query_start_loc,
            max_seq_len=max_seq_len,
            seq_lens=seq_lens,
            block_table=block_table_tensor,
            slot_mapping=slot_mapping,
            use_cascade=use_cascade,
            common_prefix_len=common_prefix_len,
            scheduler_metadata=scheduler_metadata,
            cu_prefix_query_lens=cu_prefix_query_lens,
            prefix_kv_lens=prefix_kv_lens,
            suffix_kv_lens=suffix_kv_lens,
            prefix_scheduler_metadata=prefix_scheduler_metadata,
            max_num_splits=max_num_splits,
            causal=causal)
        return attn_metadata

    def use_cascade_attention(self, *args, **kwargs) -> bool:
        """Delegate to the module-level ``use_cascade_attention`` heuristic."""
        return use_cascade_attention(*args, **kwargs)

392

393
394
395
396
397
398
399
400
class FlashAttentionImpl(AttentionImpl):
    """FlashAttention implementation of a single attention layer.

    Handles decoder attention over the paged KV cache and encoder-only
    attention without a KV cache; other attention types are rejected in
    ``__init__``. On ROCm the non-cascade decoder path calls
    ``vllm_flash_attn_varlen_func`` instead of ``flash_attn_varlen_func``.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        logits_soft_cap: Optional[float] = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        sinks: Optional[torch.Tensor] = None,
    ) -> None:
        """Store the layer configuration and validate it.

        Raises:
            ValueError: if ``head_size`` is not supported by the backend.
            NotImplementedError: for attention types other than DECODER and
                ENCODER_ONLY, or for an fp8 KV cache when
                ``flash_attn_supports_fp8()`` is False.
            AssertionError: if ``sinks`` is given without FlashAttention 3 or
                with a first dimension different from ``num_heads``.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        # The window is stored as a (left, right) pair; (-1, -1) is used when
        # no window is configured. Encoder-only layers extend the window in
        # both directions, decoder layers only to the left.
        if sliding_window is None:
            self.sliding_window = (-1, -1)
        elif attn_type == AttentionType.ENCODER_ONLY:
            self.sliding_window = (sliding_window - 1, sliding_window - 1)
        else:
            self.sliding_window = (sliding_window - 1, 0)
        self.kv_cache_dtype = kv_cache_dtype
        if logits_soft_cap is None:
            # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
            logits_soft_cap = 0
        self.logits_soft_cap = logits_soft_cap
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        FlashAttentionBackend.validate_head_size(head_size)

        if attn_type not in [
                AttentionType.DECODER, AttentionType.ENCODER_ONLY
        ]:
            raise NotImplementedError("Encoder/decoder cross-attention "
                                      "is not implemented for "
                                      "FlashAttentionImpl")

        self.attn_type = attn_type
        self.vllm_flash_attn_version = get_flash_attn_version()

        if is_quantized_kv_cache(self.kv_cache_dtype) \
            and not flash_attn_supports_fp8():
            raise NotImplementedError(
                "FlashAttention does not support fp8 kv-cache on this device.")

        self.sinks = sinks
        if self.sinks is not None:
            assert self.vllm_flash_attn_version == 3, (
                "Sinks are only supported in FlashAttention 3")
            assert self.sinks.shape[0] == num_heads, (
                "Sinks must have the same number of heads as the number of "
                "heads in the layer")

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention.

        Args:
            layer: Owning attention layer; supplies ``_q_scale``,
                ``_k_scale`` and ``_v_scale``.
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention; ``None`` during the
                profiling run, in which case ``output`` is returned as is.
            output: Preallocated output buffer (required), written in place.
            output_scale: Must be ``None``; fused output quantization is not
                supported.
        Returns:
            ``output`` (shape = [num_tokens, num_heads * head_size]).

        NOTE: FP8 quantization, flash-attn expect the size of
              {q,k,v}_descale to be (num_sequences, num_kv_heads).
              We use torch's .expand() to avoid duplicating values
        """
        assert output is not None, "Output tensor must be provided."

        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for FlashAttentionImpl")

        if attn_metadata is None:
            # Profiling run.
            return output

        attn_type = self.attn_type

        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
        # in this method. For example, `view` and `slice` (or `[:n]`) operations
        # are surprisingly slow even in the case they do not invoke any GPU ops.
        # Minimize the PyTorch ops in this method as much as possible.
        # Whenever making a change in this method, please benchmark the
        # performance to make sure it does not introduce any overhead.

        num_actual_tokens = attn_metadata.num_actual_tokens

        # Handle encoder attention differently - no KV cache needed
        if attn_type in (AttentionType.ENCODER_ONLY, ):
            # For encoder attention,
            # we use direct Q, K, V tensors without caching
            return self._forward_encoder_attention(query[:num_actual_tokens],
                                                   key[:num_actual_tokens],
                                                   value[:num_actual_tokens],
                                                   output[:num_actual_tokens],
                                                   attn_metadata, layer)

        # For decoder and cross-attention, use KV cache as before
        key_cache, value_cache = kv_cache.unbind(0)

        if self.kv_sharing_target_layer_name is None:
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
            # not padded. However, we don't need to do key[:num_actual_tokens]
            # and value[:num_actual_tokens] because the reshape_and_cache_flash
            # op uses the slot_mapping's shape to determine the number of
            # actual tokens.
            reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

        if self.kv_cache_dtype.startswith("fp8"):
            # Reinterpret the cache as fp8 and quantize the query with the
            # layer's q scale so all three operands share the fp8 dtype.
            dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
                self.kv_cache_dtype)
            key_cache = key_cache.view(dtype)
            value_cache = value_cache.view(dtype)
            num_tokens, num_heads, head_size = query.shape
            query, _ = ops.scaled_fp8_quant(
                query.reshape(
                    (num_tokens, num_heads * head_size)).contiguous(),
                layer._q_scale)
            query = query.reshape((num_tokens, num_heads, head_size))

        if not attn_metadata.use_cascade:
            cu_seqlens_q = attn_metadata.query_start_loc
            seqused_k = attn_metadata.seq_lens
            max_seqlen_q = attn_metadata.max_query_len
            max_seqlen_k = attn_metadata.max_seq_len
            block_table = attn_metadata.block_table
            scheduler_metadata = attn_metadata.scheduler_metadata

            if not current_platform.is_rocm():
                # Only this branch passes descales, so only it needs the
                # (num_sequences, num_kv_heads) descale shape.
                descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
                flash_attn_varlen_func(
                    q=query[:num_actual_tokens],
                    k=key_cache,
                    v=value_cache,
                    out=output[:num_actual_tokens],
                    cu_seqlens_q=cu_seqlens_q,
                    max_seqlen_q=max_seqlen_q,
                    seqused_k=seqused_k,
                    max_seqlen_k=max_seqlen_k,
                    softmax_scale=self.scale,
                    causal=attn_metadata.causal,
                    alibi_slopes=self.alibi_slopes,
                    window_size=self.sliding_window,
                    block_table=block_table,
                    softcap=self.logits_soft_cap,
                    scheduler_metadata=scheduler_metadata,
                    fa_version=self.vllm_flash_attn_version,
                    q_descale=layer._q_scale.expand(descale_shape),
                    k_descale=layer._k_scale.expand(descale_shape),
                    v_descale=layer._v_scale.expand(descale_shape),
                    num_splits=attn_metadata.max_num_splits,
                    s_aux=self.sinks,
                )
            else:
                # Debug dump of the kernel arguments, enabled by
                # VLLM_USE_PA_PRINT_PARAM; prints on every forward() call
                # while enabled.
                if envs.VLLM_USE_PA_PRINT_PARAM:
                    print("PA SIZE:")
                    print(f"q.shape = {query[:num_actual_tokens].unsqueeze(1).shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}")
                    print(f"cu_seqlens_q.shape = {cu_seqlens_q.shape}, max_seqlen_q = {max_seqlen_q}, seqused_k.shape = {seqused_k.shape}, max_seqlen_k = {max_seqlen_k}")
                    print(f"softmax_scale = {self.scale:.3f}, alibi_slopes = {self.alibi_slopes}, window_size = {self.sliding_window}, block_tables.shape = {block_table.shape}, softcap = {self.logits_soft_cap}, scheduler_metadata = {scheduler_metadata}")
                # NOTE: unlike the non-ROCm call above, this kernel is not
                # given fa_version, q/k/v descales, num_splits or s_aux
                # (sinks); it takes is_prefix_cache instead.
                vllm_flash_attn_varlen_func(
                    q=query[:num_actual_tokens],
                    k=key_cache,
                    v=value_cache,
                    out=output[:num_actual_tokens],
                    cu_seqlens_q=cu_seqlens_q,
                    max_seqlen_q=max_seqlen_q,
                    seqused_k=seqused_k,
                    max_seqlen_k=max_seqlen_k,
                    softmax_scale=self.scale,
                    causal=attn_metadata.causal,
                    alibi_slopes=self.alibi_slopes,
                    window_size=self.sliding_window,
                    block_table=block_table,
                    softcap=self.logits_soft_cap,
                    scheduler_metadata=scheduler_metadata,
                    is_prefix_cache=False,
                )
            return output

        # Cascade attention (rare case).
        cascade_attention(
            output[:num_actual_tokens],
            query[:num_actual_tokens],
            key_cache,
            value_cache,
            cu_query_lens=attn_metadata.query_start_loc,
            max_query_len=attn_metadata.max_query_len,
            cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens,
            prefix_kv_lens=attn_metadata.prefix_kv_lens,
            suffix_kv_lens=attn_metadata.suffix_kv_lens,
            max_kv_len=attn_metadata.max_seq_len,
            softmax_scale=self.scale,
            alibi_slopes=self.alibi_slopes,
            sliding_window=self.sliding_window,
            logits_soft_cap=self.logits_soft_cap,
            block_table=attn_metadata.block_table,
            common_prefix_len=attn_metadata.common_prefix_len,
            fa_version=self.vllm_flash_attn_version,
            prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
            suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
            q_descale=layer._q_scale,
            k_descale=layer._k_scale,
            v_descale=layer._v_scale,
        )
        return output

    def _forward_encoder_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        output: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        layer: torch.nn.Module,
    ) -> torch.Tensor:
        """Forward pass for encoder attention without KV cache.

        Keys come from the same packed batch as queries, so
        ``query_start_loc`` / ``max_query_len`` serve for both sides.

        Args:
            query: shape = [num_encoder_tokens, num_heads, head_size]
            key: shape = [num_encoder_tokens, num_kv_heads, head_size]
            value: shape = [num_encoder_tokens, num_kv_heads, head_size]
            output: shape = [num_encoder_tokens, num_heads, head_size]
            attn_metadata: Encoder attention metadata
            layer: The attention layer
        Returns:
            ``output``, written in place.
        Raises:
            NotImplementedError: if the KV-cache dtype is fp8.
        """
        # For encoder attention, process FP8 quantization if needed
        if self.kv_cache_dtype.startswith("fp8"):
            raise NotImplementedError(
                "quantization is not supported for encoder attention")

        # Use encoder-specific metadata for sequence information
        cu_seqlens_q = attn_metadata.query_start_loc
        cu_seqlens_k = attn_metadata.query_start_loc
        max_seqlen_q = attn_metadata.max_query_len
        max_seqlen_k = attn_metadata.max_query_len

        descale_shape = (
            cu_seqlens_q.shape[0] - 1,  # type: ignore[union-attr]
            self.num_kv_heads)

        # Call flash attention directly on Q, K, V tensors
        flash_attn_varlen_func(
            q=query,
            k=key,
            v=value,
            out=output,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlen_k,
            softmax_scale=self.scale,
            causal=False,  # Encoder attention is bidirectional
            alibi_slopes=self.alibi_slopes,
            window_size=self.sliding_window,
            softcap=self.logits_soft_cap,
            fa_version=self.vllm_flash_attn_version,
            q_descale=layer._q_scale.expand(descale_shape),
            k_descale=layer._k_scale.expand(descale_shape),
            v_descale=layer._v_scale.expand(descale_shape),
        )

        return output


def use_cascade_attention(
    common_prefix_len: int,
    query_lens: np.ndarray,
    num_query_heads: int,
    num_kv_heads: int,
    use_alibi: bool,
    use_sliding_window: bool,
    use_local_attention: bool,
    num_sms: int,
) -> bool:
    """Decide whether to use cascade attention.

    This function 1) checks whether cascade attention is supported with the
    given configuration, and 2) heuristically decides whether using cascade
    attention can improve performance.

    Args:
        common_prefix_len: Number of tokens in the prefix shared by the
            requests in the batch.
        query_lens: Per-request query lengths. ``len(query_lens)`` is used as
            the number of requests.
        num_query_heads: Number of query heads.
        num_kv_heads: Number of key/value heads.
        use_alibi: Whether ALiBi is enabled (cascade attention is then
            rejected).
        use_sliding_window: Whether sliding-window attention is enabled
            (cascade attention is then rejected).
        use_local_attention: Whether local attention is enabled (cascade
            attention is then rejected).
        num_sms: Number of SMs on the device; used as the divisor that turns
            CTA counts into "waves" in the rough performance model below.

    Returns:
        True if cascade attention should be used, False otherwise.
    """
    # Too short common prefix. Probably not worth using cascade attention.
    # We use an arbitrary threshold of 256 tokens. TODO: Tune this threshold.
    # NOTE(woosuk): This is the common case. We should return False as soon as
    # possible to avoid any unnecessary computation.
    if common_prefix_len < 256:
        return False
    # Cascade attention is currently not supported with these variants.
    if use_alibi or use_sliding_window or use_local_attention:
        return False
    # Too few queries. Probably not worth using cascade attention.
    # We use an arbitrary threshold of 8 queries. TODO: Tune this threshold.
    num_reqs = len(query_lens)
    if num_reqs < 8:
        return False

    # Heuristics to decide whether using cascade attention is beneficial.
    # 1. When FlashDecoding is not used for normal attention, cascade attention
    #    is likely to be faster since it saves memory bandwidth.
    num_queries_per_kv = num_query_heads // num_kv_heads
    # The criteria for using FlashDecoding can be found in the following link:
    # https://github.com/vllm-project/flash-attention/blob/96266b1111111f3d11aabefaf3bacbab6a89d03c/csrc/flash_attn/flash_api.cpp#L535
    # (The sliding-window / ALiBi terms are already False at this point; they
    # are kept so the expression mirrors the upstream criteria verbatim.)
    use_flash_decoding = (num_queries_per_kv > 1 and not use_sliding_window
                          and not use_alibi and np.all(query_lens == 1))
    if not use_flash_decoding:
        # Use cascade attention.
        return True

    # 2. When FlashDecoding is used for normal attention, it is not clear
    #    whether cascade attention is beneficial, because FlashDecoding can
    #    launch more CTAs than cascade attention.
    #    We use a simple performance model to compare the two methods.
    #    NOTE(woosuk): The performance model is very rough and may not be
    #    accurate.
    # Every request is a single-token decode here, so tokens == requests.
    num_tokens = num_reqs
    # NOTE(woosuk): These are default tile sizes. flash-attn might use
    # different tile sizes (e.g., 64 or 256) depending on the configuration.
    q_tile_size = 128
    kv_tile_size = 128
    num_prefix_tiles = cdiv(common_prefix_len, kv_tile_size)

    # Cascade: one CTA per (query head, q tile); each wave walks every prefix
    # tile.
    cascade_ctas = num_query_heads * cdiv(num_tokens, q_tile_size)
    cascade_waves = cdiv(cascade_ctas, num_sms)
    cascade_time = cascade_waves * num_prefix_tiles

    # FlashDecoding: one CTA per (request, kv head, q tile, prefix tile).
    flash_decoding_ctas = (num_reqs * num_kv_heads *
                           cdiv(num_queries_per_kv, q_tile_size))
    flash_decoding_ctas *= num_prefix_tiles
    flash_decoding_time = cdiv(flash_decoding_ctas, num_sms)

    # Use cascade attention if it is faster than FlashDecoding.
    return cascade_time < flash_decoding_time


def cascade_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    cu_query_lens: torch.Tensor,
    max_query_len: int,
    cu_prefix_query_lens: torch.Tensor,
    prefix_kv_lens: torch.Tensor,
    suffix_kv_lens: torch.Tensor,
    max_kv_len: int,
    softmax_scale: float,
    alibi_slopes: Optional[torch.Tensor],
    sliding_window: tuple[int, int],
    logits_soft_cap: float,
    block_table: torch.Tensor,
    common_prefix_len: int,
    fa_version: int,
    prefix_scheduler_metadata: Optional[torch.Tensor] = None,
    suffix_scheduler_metadata: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Compute attention as a shared-prefix pass plus a per-request suffix pass.

    1. Prefix pass: non-causal attention of all queries against the first
       ``common_prefix_len`` KV tokens, addressed through ``block_table[:1]``.
    2. Suffix pass: causal attention of each request's queries against its
       remaining KV blocks, ``block_table[:, num_common_kv_blocks:]``.
    3. The two partial results are combined with ``merge_attn_states``, which
       writes the merged attention into ``output``.

    ``key_cache.shape[-3]`` is treated as the KV block size, and
    ``key_cache.shape[-2]`` is used as the second dim of the expanded
    descale tensors (presumably the number of KV heads — confirm against the
    cache layout).

    Raises:
        AssertionError: if ALiBi or a sliding window is requested, or if
            ``common_prefix_len`` is not a positive multiple of the block size.
        NotImplementedError: on ROCm, where no kernel call is wired up here.
    """
    assert alibi_slopes is None, ("Cascade attention does not support ALiBi.")
    # TODO: Support sliding window.
    assert sliding_window == (-1, -1), (
        "Cascade attention does not support sliding window.")

    num_tokens = query.shape[0]
    block_size = key_cache.shape[-3]
    assert common_prefix_len % block_size == 0
    num_common_kv_blocks = common_prefix_len // block_size
    assert num_common_kv_blocks > 0

    if current_platform.is_rocm():
        # Both kernel calls below are non-ROCm only; without this guard ROCm
        # reached merge_attn_states with prefix_output / suffix_output never
        # assigned and failed with an opaque UnboundLocalError.
        raise NotImplementedError(
            "Cascade attention is not supported on ROCm.")

    def _expand(scale: Optional[torch.Tensor],
                shape: tuple[int, int]) -> Optional[torch.Tensor]:
        # Broadcast an optional FP8 descale tensor to (num_seqs, heads).
        return scale.expand(shape) if scale is not None else None

    # Process shared prefix.
    prefix_descale_shape = (cu_prefix_query_lens.shape[0] - 1,
                            key_cache.shape[-2])
    prefix_output, prefix_lse = flash_attn_varlen_func(
        q=query,
        k=key_cache,
        v=value_cache,
        cu_seqlens_q=cu_prefix_query_lens,
        seqused_k=prefix_kv_lens,
        max_seqlen_q=num_tokens,
        max_seqlen_k=common_prefix_len,
        softmax_scale=softmax_scale,
        causal=False,
        window_size=sliding_window,
        block_table=block_table[:1],
        softcap=logits_soft_cap,
        return_softmax_lse=True,
        scheduler_metadata=prefix_scheduler_metadata,
        fa_version=fa_version,
        q_descale=_expand(q_descale, prefix_descale_shape),
        k_descale=_expand(k_descale, prefix_descale_shape),
        v_descale=_expand(v_descale, prefix_descale_shape),
    )

    # Process suffix per query.
    suffix_descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])
    suffix_output, suffix_lse = flash_attn_varlen_func(
        q=query,
        k=key_cache,
        v=value_cache,
        cu_seqlens_q=cu_query_lens,
        seqused_k=suffix_kv_lens,
        max_seqlen_q=max_query_len,
        max_seqlen_k=max_kv_len - common_prefix_len,
        softmax_scale=softmax_scale,
        causal=True,
        window_size=sliding_window,
        block_table=block_table[:, num_common_kv_blocks:],
        softcap=logits_soft_cap,
        return_softmax_lse=True,
        scheduler_metadata=suffix_scheduler_metadata,
        fa_version=fa_version,
        q_descale=_expand(q_descale, suffix_descale_shape),
        k_descale=_expand(k_descale, suffix_descale_shape),
        v_descale=_expand(v_descale, suffix_descale_shape),
    )

    # Merge prefix and suffix outputs, and store the result in output.
    merge_attn_states(output, prefix_output, prefix_lse, suffix_output,
                      suffix_lse)