flash_attn.py 40.2 KB
Newer Older
1
"""Attention layer with FlashAttention."""
2
from collections import defaultdict
3
from dataclasses import dataclass
4
from itertools import accumulate
5
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
6
7
8

import torch

9
from vllm import _custom_ops as ops
10
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
11
12
13
                                              AttentionMetadata,
                                              AttentionMetadataBuilder,
                                              AttentionType)
14
15
16
17
18
from vllm.attention.backends.utils import (
    PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping,
    compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
    get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
    is_all_encoder_attn_metadata_set, is_block_tables_empty)
19
from vllm.forward_context import get_forward_context
20
from vllm.multimodal import MultiModalPlaceholderMap
21
22
from vllm.utils import (async_tensor_h2d, direct_register_custom_op,
                        make_tensor_with_pad)
23
24

if TYPE_CHECKING:
25
26
    from vllm.worker.model_runner import (ModelInputForGPUBuilder,
                                          ModelInputForGPUWithSamplingMetadata)
27

28
29
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
                                  flash_attn_with_kvcache)
30
31


32
33
class FlashAttentionBackend(AttentionBackend):
    """Attention backend registered under the name "FLASH_ATTN".

    Wires together the FlashAttention implementation, its metadata and
    metadata-builder classes, and the KV-cache block helpers. The KV cache
    is a single tensor whose leading dimension of size 2 separates the key
    cache (index 0) from the value cache (index 1).
    """

    @staticmethod
    def get_name() -> str:
        return "FLASH_ATTN"

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        # Every multiple of 32 from 32 up to and including 256.
        return list(range(32, 257, 32))

    @staticmethod
    def get_impl_cls() -> Type["FlashAttentionImpl"]:
        return FlashAttentionImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return FlashAttentionMetadata

    @staticmethod
    def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]:
        return FlashAttentionMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache tensor shape.

        The shape is ``(2, num_blocks, block_size, num_kv_heads, head_size)``.

        Raises:
            ValueError: if ``block_size`` is not a multiple of 16.
        """
        if block_size % 16:
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Swap blocks between two KV caches, key cache first, then value."""
        for kv_index in (0, 1):  # 0 -> key cache, 1 -> value cache
            ops.swap_blocks(src_kv_cache[kv_index], dst_kv_cache[kv_index],
                            src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Copy blocks inside each KV cache according to ``src_to_dists``."""
        key_caches = [cache[0] for cache in kv_caches]
        value_caches = [cache[1] for cache in kv_caches]
        ops.copy_blocks(key_caches, value_caches, src_to_dists)
91
92
93


@dataclass
class FlashAttentionMetadata(AttentionMetadata):
    """Metadata for FlashAttentionBackend.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, it should be stored in tensor. The tensor has to be
    updated from `CUDAGraphRunner.forward` API.
    """
    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if it is a decoding.
    seq_lens: Optional[List[int]]
    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int
    # Maximum sequence length among decode batch. 0 if there are prefill
    # requests only.
    max_decode_seq_len: int
    # (batch_size,) A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor]

    # (batch_size, max_blocks_per_seq).
    # Block addresses per sequence. (Seq id -> list of physical block)
    # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
    # in the kv cache. Each block can contain up to block_size tokens.
    # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
    # captured.
    block_tables: Optional[torch.Tensor]

    # Whether CUDA graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    # Maximum query length in the batch.
    max_query_len: Optional[int] = None

    # Max number of query tokens among request in the batch.
    max_decode_query_len: Optional[int] = None

    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    query_start_loc: Optional[torch.Tensor] = None
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor] = None

    # Lazily built sub-views returned by `prefill_metadata` and
    # `decode_metadata`.
    _cached_prefill_metadata: Optional["FlashAttentionMetadata"] = None
    _cached_decode_metadata: Optional["FlashAttentionMetadata"] = None

    # Begin encoder attn & enc/dec cross-attn fields...

    # Encoder sequence lengths representation
    encoder_seq_lens: Optional[List[int]] = None
    encoder_seq_lens_tensor: Optional[torch.Tensor] = None
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    encoder_seq_start_loc: Optional[torch.Tensor] = None
    # Maximum sequence length among encoder sequences
    max_encoder_seq_len: Optional[int] = None
    # Number of tokens input to encoder
    num_encoder_tokens: Optional[int] = None

    # Cross-attention memory-mapping data structures: slot mapping
    # and block tables
    cross_slot_mapping: Optional[torch.Tensor] = None
    cross_block_tables: Optional[torch.Tensor] = None

    @property
    def is_all_encoder_attn_metadata_set(self):
        '''
        All attention metadata required for encoder attention is set.
        '''
        return is_all_encoder_attn_metadata_set(self)

    @property
    def is_all_cross_attn_metadata_set(self):
        '''
        All attention metadata required for enc/dec cross-attention is set.

        Superset of encoder attention required metadata.
        '''
        return is_all_cross_attn_metadata_set(self)

    @property
    def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]:
        """Return (and cache) this metadata restricted to the prefills.

        Returns None when the batch contains no prefills. The result is
        built once and cached on `_cached_prefill_metadata`.
        """
        if self.num_prefills == 0:
            return None

        if self._cached_prefill_metadata is not None:
            return self._cached_prefill_metadata

        assert ((self.seq_lens is not None)
                or (self.encoder_seq_lens is not None))
        assert ((self.seq_lens_tensor is not None)
                or (self.encoder_seq_lens_tensor is not None))

        # Compute some attn_metadata fields which default to None.
        # Prefills come first in the batch, so every field is a prefix slice.
        query_start_loc = (None if self.query_start_loc is None else
                           self.query_start_loc[:self.num_prefills + 1])
        slot_mapping = (None if self.slot_mapping is None else
                        self.slot_mapping[:self.num_prefill_tokens])
        seq_lens = (None if self.seq_lens is None else
                    self.seq_lens[:self.num_prefills])
        seq_lens_tensor = (None if self.seq_lens_tensor is None else
                           self.seq_lens_tensor[:self.num_prefills])
        seq_start_loc = (None if self.seq_start_loc is None else
                         self.seq_start_loc[:self.num_prefills + 1])
        context_lens_tensor = (None if self.context_lens_tensor is None else
                               self.context_lens_tensor[:self.num_prefills])
        block_tables = (None if self.block_tables is None else
                        self.block_tables[:self.num_prefills])

        self._cached_prefill_metadata = FlashAttentionMetadata(
            num_prefills=self.num_prefills,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=self.
            multi_modal_placeholder_index_maps,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_query_len=0,
            max_decode_seq_len=0,
            query_start_loc=query_start_loc,
            seq_start_loc=seq_start_loc,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=False,
            # Begin encoder & cross attn fields below...
            encoder_seq_lens=self.encoder_seq_lens,
            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
            encoder_seq_start_loc=self.encoder_seq_start_loc,
            max_encoder_seq_len=self.max_encoder_seq_len,
            cross_slot_mapping=self.cross_slot_mapping,
            cross_block_tables=self.cross_block_tables)
        return self._cached_prefill_metadata

    @property
    def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
        """Return (and cache) this metadata restricted to the decodes.

        Returns None when the batch contains no decode tokens. The result is
        built once and cached on `_cached_decode_metadata`.
        """
        if self.num_decode_tokens == 0:
            return None

        if self._cached_decode_metadata is not None:
            return self._cached_decode_metadata
        assert ((self.seq_lens_tensor is not None)
                or (self.encoder_seq_lens_tensor is not None))

        # Compute some attn_metadata fields which default to None.
        # Decodes follow the prefills, so every field is a suffix slice.
        slot_mapping = (None if self.slot_mapping is None else
                        self.slot_mapping[self.num_prefill_tokens:])
        seq_lens_tensor = (None if self.seq_lens_tensor is None else
                           self.seq_lens_tensor[self.num_prefills:])
        block_tables = (None if self.block_tables is None else
                        self.block_tables[self.num_prefills:])
        seq_start_loc = (None if self.seq_start_loc is None else
                         self.seq_start_loc[self.num_prefills:])

        # Slicing query_start_loc at num_prefills keeps the prefill token
        # offset: query_lens [3 | 2, 2] -> [0, 3, 5, 7] -> [3, 5, 7]. The
        # decode tokens start right after the prefill tokens (see the
        # slot_mapping slice above), so rebase the cumulative offsets to
        # start at 0 ([0, 2, 4]). Pure-decode batches already start at 0 and
        # keep the plain slice (a view of the original tensor).
        query_start_loc = None
        if self.query_start_loc is not None:
            query_start_loc = self.query_start_loc[self.num_prefills:]
            if self.num_prefills > 0:
                query_start_loc = (query_start_loc -
                                   self.query_start_loc[self.num_prefills])

        self._cached_decode_metadata = FlashAttentionMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=None,
            seq_lens=None,
            seq_lens_tensor=seq_lens_tensor,
            max_decode_query_len=self.max_decode_query_len,
            max_query_len=self.max_query_len,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            query_start_loc=query_start_loc,
            seq_start_loc=seq_start_loc,
            context_lens_tensor=None,
            block_tables=block_tables,
            use_cuda_graph=self.use_cuda_graph,
            # Begin encoder & cross attn fields below...
            encoder_seq_lens=self.encoder_seq_lens,
            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
            encoder_seq_start_loc=self.encoder_seq_start_loc,
            max_encoder_seq_len=self.max_encoder_seq_len,
            cross_slot_mapping=self.cross_slot_mapping,
            cross_block_tables=self.cross_block_tables)
        return self._cached_decode_metadata

    def advance_step(self,
                     model_input: "ModelInputForGPUWithSamplingMetadata",
                     sampled_token_ids: Optional[torch.Tensor],
                     block_size: int,
                     num_seqs: int,
                     num_queries: int,
                     turn_prefills_into_decodes: bool = False):
        """
        Update metadata in-place to advance one decode step.

        Args:
            model_input: Model input whose `input_tokens` and
                `input_positions` are passed to `ops.advance_step_flashattn`.
            sampled_token_ids: Tokens sampled in the previous step, passed
                to the same op.
            block_size: KV-cache block size, passed to the same op.
            num_seqs: Number of sequences, including cuda-graph padding.
            num_queries: Number of actual requests (<= num_seqs).
            turn_prefills_into_decodes: When True, all prefills in the batch
                are converted into decodes before advancing.
        """
        # When using cudagraph, the num_seqs is padded to the next captured
        # batch sized, but num_queries tracks the actual number of requests in
        # the batch. For --enforce-eager mode, num_seqs == num_queries
        if num_seqs != num_queries:
            assert num_seqs > num_queries
            assert self.use_cuda_graph

        if turn_prefills_into_decodes:
            # When Multi-Step is enabled with Chunked-Prefill, prefills and
            # decodes are scheduled together. In the first step, all the
            # prefills turn into decodes. This update reflects that
            # conversion.
            assert self.num_decode_tokens + self.num_prefills == num_seqs
            self.num_decode_tokens += self.num_prefills
            self.num_prefills = 0
            self.num_prefill_tokens = 0
            self.max_prefill_seq_len = 0
            self.max_query_len = 1

            self.slot_mapping = self.slot_mapping[:num_seqs]
        else:
            assert self.seq_lens is not None
            assert self.max_decode_seq_len == max(self.seq_lens)

        assert self.num_prefills == 0
        assert self.num_prefill_tokens == 0
        assert self.num_decode_tokens == num_seqs
        assert self.slot_mapping.shape == (num_seqs, )

        assert self.seq_lens is not None
        assert len(self.seq_lens) == num_seqs
        assert self.seq_lens_tensor is not None
        assert self.seq_lens_tensor.shape == (num_seqs, )
        assert self.max_query_len == 1
        assert self.max_prefill_seq_len == 0

        assert self.query_start_loc is not None
        assert self.query_start_loc.shape == (num_queries + 1, )
        assert self.seq_start_loc is not None
        assert self.seq_start_loc.shape == (num_seqs + 1, )

        assert self.context_lens_tensor is not None
        assert self.context_lens_tensor.shape == (num_queries, )

        assert self.block_tables is not None
        assert self.block_tables.shape[0] == num_seqs

        # Update query lengths. Note that we update only queries and not seqs,
        # since tensors may be padded due to captured cuda graph batch size
        for i in range(num_queries):
            self.seq_lens[i] += 1
        self.max_decode_seq_len = max(self.seq_lens)

        ops.advance_step_flashattn(num_seqs=num_seqs,
                                   num_queries=num_queries,
                                   block_size=block_size,
                                   input_tokens=model_input.input_tokens,
                                   sampled_token_ids=sampled_token_ids,
                                   input_positions=model_input.input_positions,
                                   seq_lens=self.seq_lens_tensor,
                                   slot_mapping=self.slot_mapping,
                                   block_tables=self.block_tables)
367

368

369
370
371
372
373
374
375
376
377
class FlashAttentionMetadataBuilder(
        AttentionMetadataBuilder[FlashAttentionMetadata]):
    """Accumulate per-sequence-group inputs and build FlashAttentionMetadata.

    `build` walks `input_builder.inter_data_list`, collects slot mappings,
    block tables and length bookkeeping on host-side lists, then moves them
    to the runner's device.
    """

    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
        self.slot_mapping: List[int] = []
        self.prefill_seq_lens: List[int] = []
        self.context_lens: List[int] = []
        self.block_tables: List[List[int]] = []
        self.curr_seq_lens: List[int] = []
        self.multimodal_placeholder_maps: Dict[
            str,
            MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.num_decode_tokens = 0
        self.has_prefix_cache_hit = False

        self.input_builder = input_builder
        self.runner = input_builder.runner

        self.sliding_window = input_builder.sliding_window
        self.block_size = input_builder.block_size

    def _add_seq_group(
            self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
            chunked_prefill_enabled: bool, prefix_cache_hit: bool):
        """Add a sequence group to the metadata. Specifically update/append
        1. context length.
        2. block table.
        3. slot mapping.
        """
        is_prompt = inter_data.is_prompt
        block_tables = inter_data.block_tables

        for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
             curr_sliding_window_block) in zip(
                 inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
                 inter_data.orig_seq_lens, inter_data.seq_lens,
                 inter_data.query_lens, inter_data.context_lens,
                 inter_data.curr_sliding_window_blocks):
            self.context_lens.append(context_len)

            if is_prompt:
                mm_maps = inter_data.multi_modal_placeholder_maps
                if mm_maps:
                    for modality, placeholders in mm_maps.items():
                        self.multimodal_placeholder_maps[modality].extend(
                            placeholders)

                self.num_prefills += 1
                self.num_prefill_tokens += token_len
                self.prefill_seq_lens.append(seq_len)
            else:
                self.num_decode_tokens += query_len
                self.curr_seq_lens.append(curr_seq_len)

            # Compute block table.
            # TODO(sang): Combine chunked prefill and prefix caching by
            # only allowing multiple of block_size chunk size.
            # NOTE: This only works for oooooooxxx style attention.
            block_table = []
            if prefix_cache_hit:
                # NOTE(woosuk): For flash-attn, the block table should
                # include the entries for the incoming prefill tokens.
                block_table = block_tables[seq_id]
            elif ((chunked_prefill_enabled or not is_prompt)
                  and block_tables is not None):
                if curr_sliding_window_block == 0:
                    block_table = block_tables[seq_id]
                else:
                    # Keep only the trailing blocks inside the window.
                    block_table = block_tables[seq_id][
                        -curr_sliding_window_block:]
            self.block_tables.append(block_table)

            # Compute slot mapping.
            is_profile_run = is_block_tables_empty(block_tables)
            start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
                                                       context_len,
                                                       self.sliding_window)
            compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
                                 seq_len, context_len, start_idx,
                                 self.block_size, inter_data.block_tables)

    def _get_graph_runner_block_tables(
            self, num_seqs: int,
            block_tables: List[List[int]]) -> torch.Tensor:
        """Copy block tables into the runner's persistent graph buffer.

        Empty block tables are skipped. Returns the first `num_seqs` rows of
        the buffer, moved to the runner's device.
        """
        # The shape of graph_block_tables is
        # [max batch size, max context len // block size].
        max_batch_size, max_blocks = self.runner.graph_block_tables.shape
        assert max_batch_size >= num_seqs

        graph_block_tables = self.runner.graph_block_tables[:num_seqs]
        for i, block_table in enumerate(block_tables):
            if block_table:
                num_blocks = len(block_table)
                if num_blocks <= max_blocks:
                    graph_block_tables[i, :num_blocks] = block_table
                else:
                    # It may be possible to have more blocks allocated due
                    # to lookahead slots of multi-step, however, they are
                    # not used anyway, so can be safely ignored.
                    graph_block_tables[
                        i, :max_blocks] = block_table[:max_blocks]

        return torch.from_numpy(graph_block_tables).to(
            device=self.runner.device, non_blocking=True)

    def build(self, seq_lens: List[int], query_lens: List[int],
              cuda_graph_pad_size: int, batch_size: int):
        """Build attention metadata with on-device tensors.

        Args:
            seq_lens: The maybe padded sequence lengths of the input sequences.
            query_lens: The query lengths of the input sequences.
            cuda_graph_pad_size: The padding size for cuda graph.
                                 -1 if cuda graph is not used.
            batch_size: The maybe padded batch size.
        """
        inter_data_list = self.input_builder.inter_data_list
        prefix_cache_hit = any(inter_data.prefix_cache_hit
                               for inter_data in inter_data_list)
        for inter_data in inter_data_list:
            self._add_seq_group(inter_data,
                                self.input_builder.chunked_prefill_enabled,
                                prefix_cache_hit)

        device = self.runner.device
        use_captured_graph = cuda_graph_pad_size != -1

        max_query_len = max(query_lens)
        decode_query_lens = query_lens[self.num_prefills:]
        if len(decode_query_lens) > 0:
            max_decode_query_len = max(decode_query_lens)
        else:
            max_decode_query_len = 1
        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        max_decode_seq_len = max(self.curr_seq_lens, default=0)
        num_decode_tokens = self.num_decode_tokens
        query_start_loc = list(accumulate(query_lens, initial=0))
        seq_start_loc = list(accumulate(seq_lens, initial=0))

        num_seqs = len(seq_lens)
        if use_captured_graph:
            self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
            # One empty block table per padded sequence. The original
            # `[] * n` evaluated to `[]` and silently appended nothing.
            self.block_tables.extend([[]] * cuda_graph_pad_size)
            num_decode_tokens = batch_size - self.num_prefill_tokens
            block_tables = self._get_graph_runner_block_tables(
                num_seqs, self.block_tables)
        else:
            block_tables = make_tensor_with_pad(
                self.block_tables,
                pad=0,
                dtype=torch.int,
                device=device,
            )
        assert max_query_len > 0, f"query_lens: {query_lens}"

        assert device is not None
        context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
                                               device, self.runner.pin_memory)
        seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
                                           self.runner.pin_memory)
        slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
                                               device, self.runner.pin_memory)
        query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
                                                  device,
                                                  self.runner.pin_memory)
        seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
                                                device, self.runner.pin_memory)
        placeholder_index_maps = {
            modality: placeholder_map.index_map()
            for modality, placeholder_map in
            self.multimodal_placeholder_maps.items()
        }

        return FlashAttentionMetadata(
            num_prefills=self.num_prefills,
            slot_mapping=slot_mapping_tensor,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            seq_lens=seq_lens,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_decode_query_len=max_decode_query_len,
            max_prefill_seq_len=max_prefill_seq_len,
            max_decode_seq_len=max_decode_seq_len,
            query_start_loc=query_start_loc_tensor,
            seq_start_loc=seq_start_loc_tensor,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=use_captured_graph,
        )


564
565
566
class FlashAttentionImpl(AttentionImpl):
    """Attention implementation that dispatches to FlashAttention kernels.

    If the input tensors contain prompt tokens, the layout is as follows:

    |<--------------- num_prefill_tokens ----------------->|
    |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|

    Otherwise, the layout is as follows:

    |<----------------- num_decode_tokens ------------------>|
    |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.

    If chunked prefill is enabled, prefill tokens and decode tokens can be
    batched together in a flattened 1D query.

    |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
    |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|

    Currently, cuda graph is disabled for chunked prefill, meaning there's no
    padding between prefill and decode tokens.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
    ) -> None:
        """Store the layer configuration in the form the kernels expect.

        Raises:
            ValueError: if block-sparse params are given, or if `head_size`
                is not one of the backend's supported head sizes.
        """
        if blocksparse_params is not None:
            raise ValueError(
                "FlashAttention does not support block-sparse attention.")

        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads

        # ALiBi slopes are kept as a float32 tensor (or None when unused).
        self.alibi_slopes = (None if alibi_slopes is None else torch.tensor(
            alibi_slopes, dtype=torch.float32))

        # Window is stored as a (left, right) pair; (-1, -1) when no
        # sliding window is configured.
        if sliding_window is None:
            self.sliding_window = (-1, -1)
        else:
            self.sliding_window = (sliding_window - 1, 0)

        self.kv_cache_dtype = kv_cache_dtype

        # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
        self.logits_soft_cap = (0 if logits_soft_cap is None else
                                logits_soft_cap)

        # Query heads must split evenly across KV heads (grouped attention).
        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        allowed_sizes = FlashAttentionBackend.get_supported_head_sizes()
        if head_size not in allowed_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by FlashAttention. "
                f"Supported head sizes are: {allowed_sizes}.")

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache: shape =
                [2, num_blocks, block_size, num_kv_heads, head_size]
                NOTE: kv_cache will be an empty tensor with shape [0]
                for profiling run.
            attn_metadata: Metadata for attention.
            k_scale: Must be 1.0; KV scaling is not supported.
            v_scale: Must be 1.0; KV scaling is not supported.
            attn_type: Decoder self-attention, encoder attention, or
                encoder/decoder cross-attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        Raises:
            AttributeError: if the metadata required by an encoder or
                cross-attention `attn_type` is not fully set.
        """
        # NOTE(woosuk): FlashAttention does not support FP8 KV cache.
        assert k_scale == 1.0 and v_scale == 1.0, (
            "key/v_scale is not supported in FlashAttention.")

        # Guard: non-decoder attention types need their extra metadata.
        if attn_type == AttentionType.ENCODER:
            if not attn_metadata.is_all_encoder_attn_metadata_set:
                raise AttributeError("Encoder attention requires setting "
                                     "encoder metadata attributes.")
        elif attn_type == AttentionType.ENCODER_DECODER:
            if not attn_metadata.is_all_cross_attn_metadata_set:
                raise AttributeError("Encoder/decoder cross-attention "
                                     "requires setting cross-attention "
                                     "metadata attributes.")

        return torch.ops.vllm.unified_flash_attention(
            query,
            key,
            value,
            self.num_heads,
            self.head_size,
            self.num_kv_heads,
            kv_cache,
            self.kv_cache_dtype,
            k_scale,
            v_scale,
            self.scale,
            attn_type.value,
            self.sliding_window,
            self.alibi_slopes,
            self.logits_soft_cap,
        )


688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
def _get_query_key_seq_metadata(
    attn_metadata,
    is_prompt: bool,
    attn_type: AttentionType,
) -> tuple:
    """
    Look up the varlen sequence bookkeeping for the query and key sides.

    The query and key of one attention call can come from different token
    streams (decoder tokens vs. encoder tokens), so each side gets its own
    start-location entry and maximum length. Which metadata attributes are
    used depends on ``attn_type`` and, for decoder-side queries, on whether
    this is a prompt (prefill) run or a decode run.

    Args:
        attn_metadata: The attention metadata object to read from.
        is_prompt (bool): ``True`` for a prompt (prefill) run, ``False`` for
            a decode run.
        attn_type (AttentionType): The kind of attention being computed.

    Returns:
        tuple: ``(q_start_loc, q_max_len, k_start_loc, k_max_len)``. The
        start-location entries are the metadata's ``seq_start_loc`` /
        ``encoder_seq_start_loc`` attributes (callers pass them to
        flash-attn as ``cu_seqlens_q`` / ``cu_seqlens_k``). The max-length
        entries are the matching ``max_*_seq_len`` attributes.

    Raises:
        AttributeError: If ``attn_type`` is not a supported attention type.
        AssertionError: For ``ENCODER_ONLY`` attention outside a prompt run.
    """
    if attn_type in (AttentionType.DECODER, AttentionType.ENCODER_DECODER):
        # In both cases the query side is the decoder input. Its longest
        # sequence depends on whether this is a prefill or a decode step.
        if is_prompt:
            q_max_len = attn_metadata.max_prefill_seq_len
        else:
            q_max_len = attn_metadata.max_decode_seq_len
        q_start_loc = attn_metadata.seq_start_loc
        if attn_type == AttentionType.DECODER:
            # Decoder self-attention: the key side mirrors the query side.
            return (q_start_loc, q_max_len, q_start_loc, q_max_len)
        # Cross-attention: keys are the precomputed encoder outputs.
        return (q_start_loc, q_max_len,
                attn_metadata.encoder_seq_start_loc,
                attn_metadata.max_encoder_seq_len)

    if attn_type == AttentionType.ENCODER:
        # Encoder self-attention: query and key are both the encoder input.
        enc_start_loc = attn_metadata.encoder_seq_start_loc
        enc_max_len = attn_metadata.max_encoder_seq_len
        return (enc_start_loc, enc_max_len, enc_start_loc, enc_max_len)

    if attn_type == AttentionType.ENCODER_ONLY:
        assert is_prompt, "Should not have decode for encoder only model."
        start_loc = attn_metadata.seq_start_loc
        max_len = attn_metadata.max_prefill_seq_len
        return (start_loc, max_len, start_loc, max_len)

    raise AttributeError(f"Invalid attention type {str(attn_type)}")


def _get_causal_option(attn_type: AttentionType) -> bool:
    """
    Decide whether attention of the given type should use a causal mask.

    Encoder, encoder-only and encoder/decoder (cross) attention are treated
    as non-causal. Every other attention type (e.g. ``DECODER``) is treated
    as causal.

    Args:
        attn_type (AttentionType): The type of attention being evaluated.

    Returns:
        bool: ``False`` for the three encoder-related types listed above,
        ``True`` otherwise.
    """
    non_causal_types = (
        AttentionType.ENCODER,
        AttentionType.ENCODER_ONLY,
        AttentionType.ENCODER_DECODER,
    )
    return attn_type not in non_causal_types


771
772
773
774
775
776
777
778
779
780
781
782
def unified_flash_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    num_heads: int,
    head_size: int,
    num_kv_heads: int,
    kv_cache: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    softmax_scale: float,
    attn_type_int_val: int,
    window_size: Optional[List[int]] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
    """Run FlashAttention over the current batch (custom-op implementation).

    The attention metadata is read from the forward context and must be a
    ``FlashAttentionMetadata``. Prefill tokens come first in the batch and
    decode tokens follow them. Each part is computed separately and the
    results are concatenated.

    Args:
        query: 2-D tensor ``[num_tokens, num_heads * head_size]``.
        key: Key tensor, viewed as ``[-1, num_kv_heads, head_size]``. May
            be ``None`` (together with ``value``), e.g. for cross-attention
            during decoding, when the cache already holds the encoder K/V.
        value: Value tensor, same layout and ``None``-ness as ``key``.
        num_heads: Number of query heads.
        head_size: Size of each attention head.
        num_kv_heads: Number of key/value heads.
        kv_cache: ``kv_cache[0]`` is the key cache and ``kv_cache[1]`` the
            value cache. It is an empty tensor during the memory profiling
            run. New keys/values are written into it in place.
        kv_cache_dtype: Cache dtype string forwarded to
            ``reshape_and_cache_flash``.
        k_scale: Key scale forwarded to ``reshape_and_cache_flash``.
        v_scale: Value scale forwarded to ``reshape_and_cache_flash``.
        softmax_scale: Scale applied to QK^T before the softmax.
        attn_type_int_val: Integer value of an ``AttentionType`` member.
        window_size: Sliding-window size forwarded to the kernels.
        alibi_slopes: Optional ALiBi slopes forwarded to the kernels.
        logits_soft_cap: Optional logits soft cap forwarded to the kernels.

    Returns:
        Tensor of shape ``[num_query_tokens, num_heads * head_size]``.

    Raises:
        AttributeError: If ``attn_type_int_val`` is not a valid
            ``AttentionType`` value.
    """
    # Convert integer attn_type to enum
    try:
        attn_type = AttentionType(attn_type_int_val)
    except ValueError as err:
        raise AttributeError(
            f"Invalid attention type {str(attn_type_int_val)}") from err

    current_metadata = get_forward_context()
    assert current_metadata is not None
    assert isinstance(current_metadata, FlashAttentionMetadata)
    attn_metadata: FlashAttentionMetadata = current_metadata

    num_tokens, hidden_size = query.shape

    # Reshape the query, key, and value tensors.
    query = query.view(-1, num_heads, head_size)
    if (key is not None) and (value is not None):
        key = key.view(-1, num_kv_heads, head_size)
        value = value.view(-1, num_kv_heads, head_size)

    if kv_cache.numel() > 0:
        key_cache = kv_cache[0]
        value_cache = kv_cache[1]
        # We skip updating the KV cache under two conditions:
        #  a. When the Attention Type is ENCODER. In this phase, we compute
        #     only the encoder attention without updating the cache.
        #  b. When both Key and Value are None. This occurs during
        #     cross-attention computation in the decoding phase, where the KV
        #     cache is already populated with the cross-attention tensor.
        #     Thus, we skip cache updates during this time.
        if (attn_type != AttentionType.ENCODER) and (key is not None) and (
                value is not None):
            if attn_type == AttentionType.ENCODER_DECODER:
                # Update cross-attention KV cache (prefill-only)
                updated_slot_mapping = attn_metadata.cross_slot_mapping
            else:
                # Update self-attention KV cache (prefill/decode)
                updated_slot_mapping = attn_metadata.slot_mapping

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            torch.ops._C_cache_ops.reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
                updated_slot_mapping.flatten(),  # type: ignore[union-attr]
                kv_cache_dtype,
                k_scale,
                v_scale,
            )

    (num_prefill_query_tokens, num_prefill_kv_tokens,
     num_decode_query_tokens) = \
        get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)
    # Prefill tokens precede decode tokens in the flattened batch.
    decode_query = query[num_prefill_query_tokens:]
    # QKV for prefill.
    query = query[:num_prefill_query_tokens]
    assert query.shape[0] == num_prefill_query_tokens
    assert decode_query.shape[0] == num_decode_query_tokens

    prefill_output: Optional[torch.Tensor] = None
    decode_output: Optional[torch.Tensor] = None
    if prefill_meta := attn_metadata.prefill_metadata:
        # Prompt run.
        if (kv_cache.numel() == 0 or prefill_meta.block_tables is None
                or prefill_meta.block_tables.numel() == 0):
            # Normal attention: there is no cached prefix to attend to (block
            # tables are not filled, or this is the profiling run), so attend
            # directly over the key/value tensors passed in. Query and key
            # lengths come from _get_query_key_seq_metadata and may differ
            # (e.g. cross-attention keys come from the encoder sequence).
            q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = \
                _get_query_key_seq_metadata(prefill_meta, True, attn_type)

            key = key[:num_prefill_kv_tokens]
            value = value[:num_prefill_kv_tokens]

            prefill_output = flash_attn_varlen_func(
                q=query,
                k=key,
                v=value,
                cu_seqlens_q=q_seq_start_loc,
                cu_seqlens_k=k_seq_start_loc,
                max_seqlen_q=q_seq_len,
                max_seqlen_k=k_seq_len,
                softmax_scale=softmax_scale,
                causal=_get_causal_option(attn_type),
                window_size=window_size,
                alibi_slopes=alibi_slopes,
                softcap=logits_soft_cap,
            )
        else:
            # prefix-enabled attention
            assert attn_type == AttentionType.DECODER, (
                "Only decoder-only models support prefix caching")
            assert prefill_meta.seq_lens is not None
            max_seq_len = max(prefill_meta.seq_lens)
            prefill_output = flash_attn_varlen_func(  # noqa
                q=query,
                k=key_cache,
                v=value_cache,
                cu_seqlens_q=prefill_meta.query_start_loc,
                max_seqlen_q=prefill_meta.max_query_len,
                cu_seqlens_k=prefill_meta.seq_start_loc,
                max_seqlen_k=max_seq_len,
                softmax_scale=softmax_scale,
                causal=True,
                window_size=window_size,
                alibi_slopes=alibi_slopes,
                block_table=prefill_meta.block_tables,
                softcap=logits_soft_cap,
            )

    if decode_meta := attn_metadata.decode_metadata:
        # Decoding run.
        # Use flash_attn_varlen_func kernel for speculative decoding
        # because different queries might have different lengths.
        assert decode_meta.max_decode_query_len is not None
        if decode_meta.max_decode_query_len > 1:
            assert attn_type == AttentionType.DECODER, (
                "Only decoder-only models support max_decode_query_len > 1")
            decode_output = flash_attn_varlen_func(
                q=decode_query,
                k=key_cache,
                v=value_cache,
                cu_seqlens_q=decode_meta.query_start_loc,
                max_seqlen_q=decode_meta.max_decode_query_len,
                cu_seqlens_k=decode_meta.seq_start_loc,
                max_seqlen_k=decode_meta.max_decode_seq_len,
                softmax_scale=softmax_scale,
                causal=True,
                window_size=window_size,
                alibi_slopes=alibi_slopes,
                softcap=logits_soft_cap,
                block_table=decode_meta.block_tables,
            )
        else:
            # Use flash_attn_with_kvcache for normal decoding.
            (
                seq_lens_arg,
                _,
                block_tables_arg,
            ) = get_seq_len_block_table_args(decode_meta, False, attn_type)
            # The kernel takes a per-sequence query-length axis; add it with
            # unsqueeze(1) and drop it again from the output.
            decode_output = flash_attn_with_kvcache(
                q=decode_query.unsqueeze(1),
                k_cache=key_cache,
                v_cache=value_cache,
                block_table=block_tables_arg,
                cache_seqlens=seq_lens_arg,
                softmax_scale=softmax_scale,
                causal=True,
                window_size=window_size,
                alibi_slopes=alibi_slopes,
                softcap=logits_soft_cap,
            ).squeeze(1)

    if prefill_output is None:
        assert decode_output is not None
        return decode_output.view(num_decode_query_tokens, hidden_size)
    if decode_output is None:
        assert prefill_output is not None
        return prefill_output.view(num_prefill_query_tokens, hidden_size)

    # Mixed batch (chunked prefill). Chunked prefill does not work with
    # speculative decoding, so the decode query length is 1 and the decode
    # output was already squeezed above. Both outputs are laid out like their
    # queries ([tokens, num_heads, head_size]), so concatenate them directly.
    # Do NOT squeeze dim 1 again: it is the head dimension, and with
    # num_heads == 1 it would be dropped and torch.cat would fail on
    # mismatched ranks.
    assert decode_meta is not None
    output = torch.cat([prefill_output, decode_output], dim=0)
    return output.view(num_tokens, hidden_size)


960
def unified_flash_attention_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    num_heads: int,
    head_size: int,
    num_kv_heads: int,
    kv_cache: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: float,
    v_scale: float,
    softmax_scale: float,
    attn_type_int_val: int,
    window_size: Optional[List[int]] = None,
    alibi_slopes: Optional[torch.Tensor] = None,
    logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
    """Fake implementation of ``unified_flash_attention``.

    Returns an uninitialized tensor created with ``torch.empty_like(query)``,
    so it has the same shape, dtype and device as ``query``. The real op
    returns ``[num_tokens, hidden_size]``, taken from the 2-D ``query``
    shape, so the shapes agree. All other parameters exist only to mirror
    the real op's signature and are ignored.
    """
    fake_output = torch.empty_like(query)
    return fake_output
978
979
980
981
982
983
984
985


# Register unified_flash_attention so that it is reachable as
# torch.ops.vllm.unified_flash_attention. kv_cache is declared as mutated
# because the op writes new keys/values into it via reshape_and_cache_flash.
# The fake implementation presumably serves shape/dtype inference when the
# op is traced (e.g. under torch.compile); confirm against
# direct_register_custom_op.
direct_register_custom_op(
    op_name="unified_flash_attention",
    op_func=unified_flash_attention,
    fake_impl=unified_flash_attention_fake,
    mutates_args=["kv_cache"],
)