# SPDX-License-Identifier: Apache-2.0
"""Attention layer with FlashAttention."""
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional

import numpy as np
import torch

from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata, AttentionType,
                                              is_quantized_kv_cache)
from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
from vllm.fa_utils import get_flash_attn_version
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import cdiv

# These names are only used in string annotations, so they are imported for
# type checkers only.
if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput
    from vllm.v1.worker.gpu_input_batch import InputBatch
    from vllm.v1.worker.gpu_model_runner import GPUModelRunner

# The FlashAttention kernel entry point is only imported on CUDA platforms;
# on any other platform the name `flash_attn_varlen_func` stays undefined.
if current_platform.is_cuda():
    from vllm.vllm_flash_attn import flash_attn_varlen_func

logger = init_logger(__name__)


class FlashAttentionBackend(AttentionBackend):
    """Static descriptor that ties together the FlashAttention classes.

    All methods are static: they report the backend name, the supported
    head sizes, the KV-cache layout, and the implementation / metadata /
    builder classes defined in this module.
    """

    # FlashAttentionImpl.forward asserts that an output buffer is passed in.
    accept_output_buffer: bool = True

    @staticmethod
    def get_supported_head_sizes() -> list[int]:
        """Return the head sizes FlashAttentionImpl accepts."""
        return [32, 64, 96, 128, 160, 192, 224, 256]

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "FLASH_ATTN_VLLM_V1"

    @staticmethod
    def get_impl_cls() -> type["FlashAttentionImpl"]:
        """Return the attention implementation class."""
        return FlashAttentionImpl

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        """Return the per-step metadata class."""
        return FlashAttentionMetadata

    @staticmethod
    def get_builder_cls() -> type["FlashAttentionMetadataBuilder"]:
        """Return the class that builds the per-step metadata."""
        return FlashAttentionMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        """Return the KV-cache tensor shape for the given geometry.

        The leading dimension of size 2 holds the key cache and the value
        cache (FlashAttentionImpl.forward splits it with ``unbind(0)``).

        Raises:
            ValueError: If ``block_size`` is not a multiple of 16.
        """
        if block_size % 16 != 0:
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        """Forward to the module-level ``use_cascade_attention`` heuristic."""
        # Inside a method body this name resolves to the module-level
        # function, not to this static method, so there is no recursion.
        return use_cascade_attention(*args, **kwargs)


@dataclass
class FlashAttentionMetadata:
    """Per-step inputs consumed by FlashAttentionImpl.forward."""

    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    num_actual_tokens: int  # Number of tokens excluding padding.
    max_query_len: int  # Passed to the kernel as max_seqlen_q.
    # Passed to the kernel as cu_seqlens_q; the builder keeps the first
    # num_reqs + 1 entries of the runner's query_start_loc buffer.
    query_start_loc: torch.Tensor
    max_seq_len: int  # Maximum of the per-request sequence lengths.
    # Per-request sequence lengths (first num_reqs entries); passed to the
    # kernel as seqused_k.
    seq_lens: torch.Tensor
    # First num_reqs rows of the runner's device block table.
    block_table: torch.Tensor
    # First num_actual_tokens cache-slot entries, converted to int64 by the
    # builder.
    slot_mapping: torch.Tensor

    # For cascade attention.
    use_cascade: bool
    common_prefix_len: int
    # The three tensors below are None when use_cascade is False.
    # Built as [0, num_actual_tokens]: every token forms one query sequence
    # for the shared-prefix pass.
    cu_prefix_query_lens: Optional[torch.Tensor]
    # Built as [common_prefix_len].
    prefix_kv_lens: Optional[torch.Tensor]
    # Built as seq_lens - common_prefix_len, one entry per request.
    suffix_kv_lens: Optional[torch.Tensor]

    # For logging.
    num_input_tokens: int = 0  # Number of tokens including padding.


class FlashAttentionMetadataBuilder:
    """Builds FlashAttentionMetadata from the model runner's batch buffers."""

    def __init__(self, runner: "GPUModelRunner"):
        # The runner owns the CPU-side buffers and the device that build()
        # reads from.
        self.runner = runner

    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        """Report whether the batch was reordered; this builder never does."""
        return False

    def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
              common_prefix_len: int) -> "FlashAttentionMetadata":
        """Assemble the attention metadata for the current step.

        Args:
            num_reqs: Number of requests; selects the leading entries of the
                runner's per-request buffers.
            num_actual_tokens: Number of tokens excluding padding; selects
                the leading entries of the slot mapping.
            max_query_len: Stored unchanged in the metadata.
            common_prefix_len: Length of the prefix shared by all requests.
                Any value greater than 0 enables cascade attention.

        Returns:
            A populated FlashAttentionMetadata.
        """
        runner = self.runner
        device = runner.device

        # Convert the NumPy scalar to a plain int so that it matches the
        # `max_seq_len: int` field of the metadata.
        max_seq_len = int(runner.seq_lens_np[:num_reqs].max())
        query_start_loc = runner.query_start_loc_cpu[:num_reqs + 1].to(
            device, non_blocking=True)
        seq_lens = runner.seq_lens_cpu[:num_reqs].to(device,
                                                     non_blocking=True)
        block_table = (
            runner.input_batch.block_table.get_device_tensor()[:num_reqs])
        slot_mapping = runner.slot_mapping_cpu[:num_actual_tokens].to(
            device, non_blocking=True).long()

        use_cascade = common_prefix_len > 0
        if use_cascade:
            # TODO: Optimize.
            # Shared-prefix pass: all tokens form a single query sequence
            # that attends to one KV sequence of length common_prefix_len.
            cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
                                                dtype=torch.int32,
                                                device=device)
            prefix_kv_lens = torch.tensor([common_prefix_len],
                                          dtype=torch.int32,
                                          device=device)
            # Suffix pass: each request keeps only the KV entries that come
            # after the shared prefix.
            suffix_kv_lens = (runner.seq_lens_np[:num_reqs] -
                              common_prefix_len)
            suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(device)
        else:
            cu_prefix_query_lens = None
            prefix_kv_lens = None
            suffix_kv_lens = None

        attn_metadata = FlashAttentionMetadata(
            num_actual_tokens=num_actual_tokens,
            max_query_len=max_query_len,
            query_start_loc=query_start_loc,
            max_seq_len=max_seq_len,
            seq_lens=seq_lens,
            block_table=block_table,
            slot_mapping=slot_mapping,
            use_cascade=use_cascade,
            common_prefix_len=common_prefix_len,
            cu_prefix_query_lens=cu_prefix_query_lens,
            prefix_kv_lens=prefix_kv_lens,
            suffix_kv_lens=suffix_kv_lens,
        )
        return attn_metadata


class FlashAttentionImpl(AttentionImpl):
    """Decoder attention that runs on the FlashAttention varlen kernel."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> None:
        """Validate the configuration and store the kernel parameters.

        Args:
            num_heads: Number of query heads; must be divisible by
                ``num_kv_heads``.
            head_size: Size of each head; must be one of
                ``FlashAttentionBackend.get_supported_head_sizes()``.
            scale: Softmax scale, stored as a float.
            num_kv_heads: Number of key/value heads.
            alibi_slopes: Optional ALiBi slopes; stored as a float32 tensor.
            sliding_window: Optional window length; stored as the pair
                ``(sliding_window - 1, 0)``, or ``(-1, -1)`` when None.
            kv_cache_dtype: KV-cache dtype name.
            blocksparse_params: Must be None.
            logits_soft_cap: Optional soft cap; None is stored as 0.
            attn_type: Must be ``AttentionType.DECODER``.

        Raises:
            ValueError: If ``blocksparse_params`` is given or ``head_size``
                is unsupported.
            NotImplementedError: If ``is_quantized_kv_cache`` reports a
                quantized ``kv_cache_dtype``, or ``attn_type`` is not the
                decoder type.
        """
        if blocksparse_params is not None:
            raise ValueError(
                "FlashAttention does not support block-sparse attention.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        if sliding_window is None:
            # (-1, -1) is the "no sliding window" value; cascade_attention
            # asserts on exactly this pair.
            self.sliding_window = (-1, -1)
        else:
            self.sliding_window = (sliding_window - 1, 0)
        self.kv_cache_dtype = kv_cache_dtype
        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "FlashAttention V1 with FP8 KV cache not yet supported")
        if logits_soft_cap is None:
            # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
            logits_soft_cap = 0
        self.logits_soft_cap = logits_soft_cap

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
        if head_size not in support_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by FlashAttention. "
                f"Supported head sizes are: {support_head_sizes}. "
                "Set VLLM_USE_V1=0 to use another attention backend.")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashAttentionImpl")
        self.vllm_flash_attn_version = get_flash_attn_version()

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention.

        Args:
            layer: Module providing the ``_q_scale``, ``_k_scale`` and
                ``_v_scale`` attributes read below.
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention. ``None`` marks a
                profiling run, in which case ``output`` is returned
                untouched.
            output: Pre-allocated buffer that receives the result; required
                despite the ``None`` default.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        NOTE: FP8 quantization, flash-attn expect the size of
              {q,k,v}_descale to be (num_sequences, num_kv_heads).
              We use torch's .expand() to avoid duplicating values
        """
        assert output is not None, "Output tensor must be provided."

        if attn_metadata is None:
            # Profiling run.
            return output

        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
        # in this method. For example, `view` and `slice` (or `[:n]`) operations
        # are surprisingly slow even in the case they do not invoke any GPU ops.
        # Minimize the PyTorch ops in this method as much as possible.
        # Whenever making a change in this method, please benchmark the
        # performance to make sure it does not introduce any overhead.

        num_actual_tokens = attn_metadata.num_actual_tokens
        # Reshape the input keys and values and store them in the cache.
        # NOTE(woosuk): Here, key and value are padded while slot_mapping is
        # not padded. However, we don't need to do key[:num_actual_tokens] and
        # value[:num_actual_tokens] because the reshape_and_cache_flash op uses
        # the slot_mapping's shape to determine the number of actual tokens.
        key_cache, value_cache = kv_cache.unbind(0)
        torch.ops._C_cache_ops.reshape_and_cache_flash(
            key,
            value,
            key_cache,
            value_cache,
            attn_metadata.slot_mapping,
            self.kv_cache_dtype,
            layer._k_scale,
            layer._v_scale,
        )

        # (num_sequences, num_kv_heads): query_start_loc has one more entry
        # than there are sequences.
        descale_shape = (attn_metadata.query_start_loc.shape[0] - 1,
                         key.shape[1])
        # NOTE(review): __init__ raises for any dtype that
        # is_quantized_kv_cache() reports as quantized, so whether this fp8
        # branch is reachable depends on that helper -- confirm.
        if self.kv_cache_dtype.startswith("fp8"):
            key_cache = key_cache.view(torch.float8_e4m3fn)
            value_cache = value_cache.view(torch.float8_e4m3fn)
            num_tokens, num_heads, head_size = query.shape
            # Quantize on a flattened [num_tokens, num_heads * head_size]
            # view, then restore the 3-D layout.
            query, _ = ops.scaled_fp8_quant(
                query.reshape(
                    (num_tokens, num_heads * head_size)).contiguous(),
                layer._q_scale)
            query = query.reshape((num_tokens, num_heads, head_size))

        # Compute attention and update output up to `num_actual_tokens`.
        if not attn_metadata.use_cascade:
            # Regular attention (common case).
            flash_attn_varlen_func(
                q=query[:num_actual_tokens],
                k=key_cache,
                v=value_cache,
                out=output[:num_actual_tokens],
                cu_seqlens_q=attn_metadata.query_start_loc,
                max_seqlen_q=attn_metadata.max_query_len,
                seqused_k=attn_metadata.seq_lens,
                max_seqlen_k=attn_metadata.max_seq_len,
                softmax_scale=self.scale,
                causal=True,
                alibi_slopes=self.alibi_slopes,
                window_size=self.sliding_window,
                block_table=attn_metadata.block_table,
                softcap=self.logits_soft_cap,
                fa_version=self.vllm_flash_attn_version,
                q_descale=layer._q_scale.expand(descale_shape),
                k_descale=layer._k_scale.expand(descale_shape),
                v_descale=layer._v_scale.expand(descale_shape),
            )
            return output

        # Cascade attention (rare case).
        # The scales are passed unexpanded here; cascade_attention expands
        # them separately for its prefix and suffix passes.
        cascade_attention(
            output[:num_actual_tokens],
            query[:num_actual_tokens],
            key_cache,
            value_cache,
            cu_query_lens=attn_metadata.query_start_loc,
            max_query_len=attn_metadata.max_query_len,
            cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens,
            prefix_kv_lens=attn_metadata.prefix_kv_lens,
            suffix_kv_lens=attn_metadata.suffix_kv_lens,
            max_kv_len=attn_metadata.max_seq_len,
            softmax_scale=self.scale,
            alibi_slopes=self.alibi_slopes,
            sliding_window=self.sliding_window,
            logits_soft_cap=self.logits_soft_cap,
            block_table=attn_metadata.block_table,
            common_prefix_len=attn_metadata.common_prefix_len,
            fa_version=self.vllm_flash_attn_version,
            q_descale=layer._q_scale,
            k_descale=layer._k_scale,
            v_descale=layer._v_scale,
        )
        return output


def use_cascade_attention(
    common_prefix_len: int,
    query_lens: np.ndarray,
    num_query_heads: int,
    num_kv_heads: int,
    use_alibi: bool,
    use_sliding_window: bool,
    num_sms: int,
) -> bool:
    """Return True when cascade attention should be used for this batch.

    The decision has two stages: first reject configurations where cascade
    attention is unsupported or clearly not worthwhile, then compare a rough
    cost estimate of cascade attention against FlashDecoding.

    Args:
        common_prefix_len: Number of tokens shared by every request.
        query_lens: Query length of each request in the batch.
        num_query_heads: Number of query heads.
        num_kv_heads: Number of key/value heads.
        use_alibi: Whether ALiBi is enabled.
        use_sliding_window: Whether sliding-window attention is enabled.
        num_sms: Number of SMs used to turn CTA counts into waves.
    """
    # NOTE(woosuk): A short shared prefix is the common case, so bail out
    # before doing any other work. The 256-token threshold is arbitrary.
    # TODO: Tune this threshold.
    if common_prefix_len < 256:
        return False
    # These attention variants are not supported by cascade attention yet.
    if use_alibi or use_sliding_window:
        return False
    # With only a handful of requests the sharing is probably not worth it.
    # The threshold of 8 requests is arbitrary. TODO: Tune this threshold.
    batch_size = len(query_lens)
    if batch_size < 8:
        return False

    # Heuristic 1: if regular attention would NOT take the FlashDecoding
    # path, cascade attention likely wins because it saves memory bandwidth.
    # The FlashDecoding criteria come from:
    # https://github.com/vllm-project/flash-attention/blob/96266b1111111f3d11aabefaf3bacbab6a89d03c/csrc/flash_attn/flash_api.cpp#L535
    heads_per_kv = num_query_heads // num_kv_heads
    if (heads_per_kv <= 1 or use_sliding_window or use_alibi
            or not np.all(query_lens == 1)):
        return True

    # Heuristic 2: FlashDecoding is in play and may launch more CTAs than
    # cascade attention, so compare the two with a simple cost model.
    # NOTE(woosuk): The model is very rough and may not be accurate.
    # NOTE(woosuk): 128 is the default tile size; flash-attn might pick
    # another one (e.g., 64 or 256) depending on the configuration.
    tile_q = 128
    tile_kv = 128
    prefix_tiles = cdiv(common_prefix_len, tile_kv)

    # Cascade: every wave of CTAs walks over all prefix tiles.
    # On this path each request contributes exactly one query token.
    cascade_waves = cdiv(num_query_heads * cdiv(batch_size, tile_q), num_sms)
    cascade_cost = cascade_waves * prefix_tiles

    # FlashDecoding: one CTA per (request, kv head, query tile, prefix tile).
    decoding_ctas = (batch_size * num_kv_heads * cdiv(heads_per_kv, tile_q) *
                     prefix_tiles)
    decoding_cost = cdiv(decoding_ctas, num_sms)

    # Cascade attention is chosen only when it is estimated strictly faster.
    return cascade_cost < decoding_cost


def _expand_descale(
    scale: Optional[torch.Tensor],
    shape: tuple[int, int],
) -> Optional[torch.Tensor]:
    """Return ``scale`` expanded to ``shape``, or None when it is absent."""
    return scale.expand(shape) if scale is not None else None


def cascade_attention(
    output: torch.Tensor,
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    cu_query_lens: torch.Tensor,
    max_query_len: int,
    cu_prefix_query_lens: torch.Tensor,
    prefix_kv_lens: torch.Tensor,
    suffix_kv_lens: torch.Tensor,
    max_kv_len: int,
    softmax_scale: float,
    alibi_slopes: Optional[torch.Tensor],
    sliding_window: tuple[int, int],
    logits_soft_cap: float,
    block_table: torch.Tensor,
    common_prefix_len: int,
    fa_version: int,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Attend to a shared prefix once, then to per-request suffixes, and merge.

    Two kernel launches are made over the same query tokens: a non-causal
    pass over the shared prefix, and a causal pass over each request's
    remaining KV entries. The partial results are combined with
    ``merge_attn_states`` and written into ``output``.

    Args:
        output: Buffer that receives the merged attention result.
        query: Query tokens; the first dimension is the token count.
        key_cache: Key cache; the third-from-last dimension is read as the
            block size and the second-from-last as the number of KV heads.
        value_cache: Value cache.
        cu_query_lens: ``cu_seqlens_q`` for the suffix pass.
        max_query_len: ``max_seqlen_q`` for the suffix pass.
        cu_prefix_query_lens: ``cu_seqlens_q`` for the prefix pass.
        prefix_kv_lens: ``seqused_k`` for the prefix pass.
        suffix_kv_lens: ``seqused_k`` for the suffix pass.
        max_kv_len: Maximum full KV length; the suffix pass uses
            ``max_kv_len - common_prefix_len``.
        softmax_scale: Softmax scale.
        alibi_slopes: Must be None.
        sliding_window: Must be ``(-1, -1)``.
        logits_soft_cap: Passed to the kernel as ``softcap``.
        block_table: Block table; the prefix pass uses its first row, the
            suffix pass skips the columns holding the shared blocks.
        common_prefix_len: Shared prefix length; must be a positive multiple
            of the block size.
        fa_version: FlashAttention version forwarded to the kernel.
        q_descale: Optional descale tensor for the query.
        k_descale: Optional descale tensor for the keys.
        v_descale: Optional descale tensor for the values.

    Returns:
        ``output``, after the merged result has been written into it.
    """
    assert alibi_slopes is None, ("Cascade attention does not support ALiBi.")
    # TODO: Support sliding window.
    assert sliding_window == (-1, -1), (
        "Cascade attention does not support sliding window.")

    num_tokens = query.shape[0]
    block_size = key_cache.shape[-3]
    assert common_prefix_len % block_size == 0
    num_common_kv_blocks = common_prefix_len // block_size
    assert num_common_kv_blocks > 0
    num_kv_heads = key_cache.shape[-2]

    # Descale tensors are expanded to (number of sequences, num_kv_heads).
    # The two passes differ in their number of sequences, so each gets its
    # own shape.
    prefix_descale_shape = (cu_prefix_query_lens.shape[0] - 1, num_kv_heads)
    suffix_descale_shape = (cu_query_lens.shape[0] - 1, num_kv_heads)

    # Process shared prefix.
    # Non-causal: every query token attends to the whole shared prefix, which
    # is addressed through the first row of the block table.
    prefix_output, prefix_lse = flash_attn_varlen_func(
        q=query,
        k=key_cache,
        v=value_cache,
        cu_seqlens_q=cu_prefix_query_lens,
        seqused_k=prefix_kv_lens,
        max_seqlen_q=num_tokens,
        max_seqlen_k=common_prefix_len,
        softmax_scale=softmax_scale,
        causal=False,
        window_size=sliding_window,
        block_table=block_table[:1],
        softcap=logits_soft_cap,
        return_softmax_lse=True,
        fa_version=fa_version,
        q_descale=_expand_descale(q_descale, prefix_descale_shape),
        k_descale=_expand_descale(k_descale, prefix_descale_shape),
        v_descale=_expand_descale(v_descale, prefix_descale_shape),
    )

    # Process suffix per query.
    # Causal, over the blocks that follow the shared prefix.
    suffix_output, suffix_lse = flash_attn_varlen_func(
        q=query,
        k=key_cache,
        v=value_cache,
        cu_seqlens_q=cu_query_lens,
        seqused_k=suffix_kv_lens,
        max_seqlen_q=max_query_len,
        max_seqlen_k=max_kv_len - common_prefix_len,
        softmax_scale=softmax_scale,
        causal=True,
        window_size=sliding_window,
        block_table=block_table[:, num_common_kv_blocks:],
        softcap=logits_soft_cap,
        return_softmax_lse=True,
        fa_version=fa_version,
        q_descale=_expand_descale(q_descale, suffix_descale_shape),
        k_descale=_expand_descale(k_descale, suffix_descale_shape),
        v_descale=_expand_descale(v_descale, suffix_descale_shape),
    )

    # Merge prefix and suffix outputs, and store the result in output.
    merge_attn_states(output, prefix_output, prefix_lse, suffix_output,
                      suffix_lse)
    # Return the buffer so that the declared `-> torch.Tensor` holds.
    return output