# SPDX-License-Identifier: Apache-2.0
"""Attention layer for ROCm GPUs."""
from dataclasses import dataclass
from itertools import accumulate
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type

import torch

import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionLayer,
                                              AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import (CommonAttentionState,
                                           CommonMetadataBuilder)
from vllm.attention.ops.paged_attn import (PagedAttention,
                                           PagedAttentionMetadata)
from vllm.logger import init_logger
from vllm.platforms import current_platform
19

20
21
22
if TYPE_CHECKING:
    from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata

23
24
logger = init_logger(__name__)

# Partition size for the ROCm paged-attention path (presumably consumed by
# forward(), which is outside this view — TODO confirm).
_PARTITION_SIZE_ROCM = 512

# GCN architecture name of the default CUDA/HIP device. This is queried once,
# at import time, so importing this module needs a visible GPU device.
_GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
# Navi parts report "gfx1..." architecture names.
_ON_NAVI = "gfx1" in _GPU_ARCH
# MI250 ("gfx90a") or MI300 ("gfx942").
_ON_MI250_MI300 = "gfx90a" in _GPU_ARCH or "gfx942" in _GPU_ARCH
29

30
31
32

class ROCmFlashAttentionBackend(AttentionBackend):
    """Attention backend entry point for ROCm GPUs.

    Wires the abstract backend hooks to the ROCm implementation, metadata,
    metadata-builder and state classes. The KV-cache layout and block
    movement are delegated to PagedAttention.
    """

    @staticmethod
    def get_name() -> str:
        """Registry name of this backend."""
        return "ROCM_FLASH"

    @staticmethod
    def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]:
        """Class that performs the attention computation."""
        return ROCmFlashAttentionImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        """Per-batch metadata class consumed by the implementation."""
        return ROCmFlashAttentionMetadata

    @staticmethod
    def get_builder_cls() -> Type["ROCmFlashAttentionMetadataBuilder"]:
        """Class used to build the per-batch metadata."""
        return ROCmFlashAttentionMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Shared attention-state helper class."""
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """KV-cache tensor shape, as defined by PagedAttention."""
        shape = PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                  num_kv_heads, head_size)
        return shape

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Move cache blocks between caches via PagedAttention."""
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Copy cache blocks within the given caches via PagedAttention."""
        PagedAttention.copy_blocks(kv_caches, src_to_dists)


@dataclass
class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
    """Metadata for ROCmFlashAttentionBackend.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, they should be stored in a tensor. The tensor has to be
    updated from the `CUDAGraphRunner.forward` API.
    """
    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None for decode-only metadata.
    seq_lens: Optional[List[int]]
    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # Maximum sequence length among the prefill batch. 0 if there are
    # decoding requests only.
    max_prefill_seq_len: int
    # Maximum sequence length among the decode batch. 0 if there are prefill
    # requests only.
    max_decode_seq_len: int

    # Whether cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|

    # Maximum query length in the batch. None for decoding.
    max_query_len: Optional[int] = None
    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    query_start_loc: Optional[torch.Tensor] = None
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor] = None
    # (batch_size,) A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor] = None

    # Max number of query tokens among requests in the batch.
    max_decode_query_len: Optional[int] = None

    # Lazily built results of `prefill_metadata` / `decode_metadata`.
    _cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None
    _cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None

    # Begin encoder attn & enc/dec cross-attn fields...

    # Encoder sequence lengths representation
    encoder_seq_lens: Optional[List[int]] = None
    encoder_seq_lens_tensor: Optional[torch.Tensor] = None

    # Maximum sequence length among encoder sequences
    max_encoder_seq_len: Optional[int] = None

    # Number of tokens input to encoder
    num_encoder_tokens: Optional[int] = None

    # Cross-attention memory-mapping data structures: slot mapping
    # and block tables
    cross_slot_mapping: Optional[torch.Tensor] = None
    cross_block_tables: Optional[torch.Tensor] = None

    @property
    def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
        """Metadata restricted to the prefill part of the batch.

        Keeps the first `num_prefills` sequences and the first
        `num_prefill_tokens` slots. The result is cached after the first
        access. Returns None when the batch has no prefills.
        """
        if self.num_prefills == 0:
            return None

        if self._cached_prefill_metadata is not None:
            return self._cached_prefill_metadata

        assert self.seq_lens is not None
        assert self.seq_lens_tensor is not None
        assert self.block_tables is not None

        self._cached_prefill_metadata = ROCmFlashAttentionMetadata(
            num_prefills=self.num_prefills,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
            multi_modal_placeholder_index_maps=self.
            multi_modal_placeholder_index_maps,
            enable_kv_scales_calculation=self.enable_kv_scales_calculation,
            seq_lens=self.seq_lens[:self.num_prefills],
            seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_seq_len=0,
            # Start-location tensors have one more entry than sequences.
            query_start_loc=None if self.query_start_loc is None else
            self.query_start_loc[:self.num_prefills + 1],
            seq_start_loc=None if self.seq_start_loc is None else
            self.seq_start_loc[:self.num_prefills + 1],
            context_lens_tensor=None if self.context_lens_tensor is None else
            self.context_lens_tensor[:self.num_prefills],
            block_tables=self.block_tables[:self.num_prefills],
            use_cuda_graph=False,
            # Begin encoder & cross attn fields below...
            encoder_seq_lens=self.encoder_seq_lens,
            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
            max_encoder_seq_len=self.max_encoder_seq_len,
            cross_slot_mapping=self.cross_slot_mapping,
            cross_block_tables=self.cross_block_tables)
        return self._cached_prefill_metadata

    @property
    def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
        """Metadata restricted to the decode part of the batch.

        Keeps the sequences after the first `num_prefills` and the slots
        after the first `num_prefill_tokens`. The result is cached after the
        first access. Returns None when `num_decode_tokens` is 0.
        """
        if self.num_decode_tokens == 0:
            return None

        if self._cached_decode_metadata is not None:
            return self._cached_decode_metadata
        assert self.block_tables is not None
        assert self.seq_lens_tensor is not None

        self._cached_decode_metadata = ROCmFlashAttentionMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=True,
            seq_lens=None,
            seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
            max_query_len=None,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            query_start_loc=None,
            seq_start_loc=None,
            context_lens_tensor=None,
            block_tables=self.block_tables[self.num_prefills:],
            use_cuda_graph=self.use_cuda_graph,
            # Begin encoder & cross attn fields below...
            encoder_seq_lens=self.encoder_seq_lens,
            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
            max_encoder_seq_len=self.max_encoder_seq_len,
            cross_slot_mapping=self.cross_slot_mapping,
            cross_block_tables=self.cross_block_tables)
        # Batch may be composed of prefill|decodes, adjust query start indices
        # to refer to the start of decodes when the two are split apart.
        # E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
        # NOTE: query_start_loc is constructed as None above, so this is
        # currently a no-op. It is kept so the rebasing stays correct if a
        # decode query_start_loc is populated in the future.
        if self._cached_decode_metadata.query_start_loc is not None:
            qs = self._cached_decode_metadata.query_start_loc
            self._cached_decode_metadata.query_start_loc = qs - qs[0]
        return self._cached_decode_metadata

    def advance_step(self,
                     model_input: "ModelInputForGPUWithSamplingMetadata",
                     sampled_token_ids: Optional[torch.Tensor],
                     block_size: int,
                     num_seqs: int,
                     num_queries: int,
                     turn_prefills_into_decodes: bool = False):
        """Update metadata in-place to advance one decode step.

        Only pure-decode batches are accepted (checked by the assertions
        below). Python-side `seq_lens` / `max_decode_seq_len` are
        incremented here. The device tensors (input tokens/positions,
        `seq_lens_tensor`, `slot_mapping`) are updated by
        `ops.advance_step_flashattn`.

        Args:
            model_input: model input whose `input_tokens` and
                `input_positions` are passed to the kernel.
            sampled_token_ids: tokens sampled in the previous step.
            block_size: KV-cache block size.
            num_seqs: batch size, possibly padded up to a captured
                cuda-graph batch size.
            num_queries: actual number of requests (<= num_seqs).
            turn_prefills_into_decodes: unsupported here; must be False.
        """

        assert not turn_prefills_into_decodes, \
            ("Chunked prefill is not supported with rocm_flash_attn yet. "
             "turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill "
             "specific parameter.")

        # When using cudagraph, the num_seqs is padded to the next captured
        # batch size, but num_queries tracks the actual number of requests in
        # the batch. For --enforce-eager mode, num_seqs == num_queries
        if num_seqs != num_queries:
            assert num_seqs > num_queries
            assert self.use_cuda_graph

        assert self.num_prefills == 0
        assert self.num_prefill_tokens == 0
        assert self.num_decode_tokens == num_seqs
        assert self.slot_mapping.shape == (num_seqs, )

        assert self.seq_lens is not None
        assert len(self.seq_lens) == num_seqs
        assert self.seq_lens_tensor is not None
        assert self.seq_lens_tensor.shape == (num_seqs, )
        assert self.max_query_len == 1
        assert self.max_prefill_seq_len == 0
        assert self.max_decode_seq_len == max(self.seq_lens)

        assert self.query_start_loc is not None
        assert self.query_start_loc.shape == (num_queries + 1, )
        assert self.seq_start_loc is not None
        assert self.seq_start_loc.shape == (num_seqs + 1, )

        assert self.context_lens_tensor is not None
        assert self.context_lens_tensor.shape == (num_queries, )

        assert self.block_tables is not None
        assert self.block_tables.shape[0] == num_seqs

        # Update query lengths. Note that we update only queries and not seqs,
        # since tensors may be padded due to captured cuda graph batch size
        for i in range(num_queries):
            self.seq_lens[i] += 1
        self.max_decode_seq_len = max(self.seq_lens)

        ops.advance_step_flashattn(num_seqs=num_seqs,
                                   num_queries=num_queries,
                                   block_size=block_size,
                                   input_tokens=model_input.input_tokens,
                                   sampled_token_ids=sampled_token_ids,
                                   input_positions=model_input.input_positions,
                                   seq_lens=self.seq_lens_tensor,
                                   slot_mapping=self.slot_mapping,
                                   block_tables=self.block_tables)

295

296
297
298
299
300
301
class ROCmFlashAttentionMetadataBuilder(
        CommonMetadataBuilder[ROCmFlashAttentionMetadata]):
    """Metadata builder for the ROCm attention backend.

    All building logic is inherited from CommonMetadataBuilder. This
    subclass only sets `_metadata_cls`, presumably the class the base
    builder instantiates (confirm against CommonMetadataBuilder).
    """

    _metadata_cls = ROCmFlashAttentionMetadata


302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
def _make_alibi_bias(alibi_slopes: torch.Tensor,
                     dtype: torch.dtype,
                     seq_lens: Optional[List[int]],
                     make_attn_mask: bool = True) -> List[torch.Tensor]:
    attn_biases = []
    if seq_lens:
        for seq_len in seq_lens:
            bias = torch.arange(seq_len, dtype=dtype)
            # NOTE(zhuohan): HF uses
            #     `bias = bias[None, :].repeat(seq_len, 1)`
            # here. We find that both biases give the same results, but
            # the bias below more accurately follows the original ALiBi
            # paper.
            bias = bias[None, :] - bias[:, None]

            num_heads = alibi_slopes.shape[0]
            bias = bias[None, :].repeat(
                (num_heads, 1, 1)).to(alibi_slopes.device)
            bias.mul_(alibi_slopes[:, None, None])
            if make_attn_mask:
                inf_mask = torch.empty(
                    (1, seq_len, seq_len),
                    dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1).to(
                        alibi_slopes.device)
                attn_biases.append((bias + inf_mask).to(dtype))
            else:
                attn_biases.append(bias.to(dtype))

    return attn_biases


333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
def _cumulative_start_loc(lengths: List[int],
                          like: torch.Tensor) -> torch.Tensor:
    """Return ``[0, l0, l0+l1, ...]`` on ``like``'s device and dtype."""
    return torch.tensor(list(accumulate(lengths, initial=0)),
                        device=like.device,
                        dtype=like.dtype)


def _get_seq_len_block_table_args(
    attn_metadata: ROCmFlashAttentionMetadata,
    attn_type: str,
) -> tuple:
    '''
    Select the sequence-length arguments for the attention kernel.

    Which attributes are read from attn_metadata depends on the type of
    attention operation:

    * Decoder attn -> decoder self-attention fields (seq_lens)
    * Encoder/decoder cross-attn -> decoder lengths for the query, encoder
      lengths for the key
    * Encoder attn -> encoder sequence length fields

    Arguments:

    * attn_metadata: Attention metadata structure associated with attention op
    * attn_type: encoder attention, decoder self-attention,
                encoder/decoder cross-attention

    Returns:

    A 6-tuple ``(query_start_loc, query_max_seq_len, key_start_loc,
    key_max_seq_len, seq_lens, causal_mask)``, where the start-loc tensors
    are cumulative sequence-length offsets beginning at 0.

    Raises:

    * AttributeError: if attn_type is not a recognized attention type
    '''
    if attn_type == AttentionType.ENCODER:
        assert attn_metadata.encoder_seq_lens is not None
        assert attn_metadata.encoder_seq_lens_tensor is not None
        start_loc = _cumulative_start_loc(
            attn_metadata.encoder_seq_lens,
            attn_metadata.encoder_seq_lens_tensor)
        max_len = attn_metadata.max_encoder_seq_len
        # Encoder attention is not causal; no block tables are involved.
        return (start_loc, max_len, start_loc, max_len,
                attn_metadata.encoder_seq_lens, False)

    if attn_type == AttentionType.DECODER:
        # Decoder self-attention: query and key share the same layout.
        assert attn_metadata.seq_lens is not None
        assert attn_metadata.seq_lens_tensor is not None
        start_loc = _cumulative_start_loc(attn_metadata.seq_lens,
                                          attn_metadata.seq_lens_tensor)
        max_len = attn_metadata.max_prefill_seq_len
        return (start_loc, max_len, start_loc, max_len,
                attn_metadata.seq_lens, True)

    if attn_type == AttentionType.ENCODER_DECODER:
        assert attn_metadata.seq_lens is not None
        assert attn_metadata.seq_lens_tensor is not None
        assert attn_metadata.encoder_seq_lens is not None
        assert attn_metadata.encoder_seq_lens_tensor is not None
        # Each offsets tensor follows the device/dtype of the tensor that
        # describes the same sequences.
        query_start_loc = _cumulative_start_loc(
            attn_metadata.seq_lens, attn_metadata.seq_lens_tensor)
        key_start_loc = _cumulative_start_loc(
            attn_metadata.encoder_seq_lens,
            attn_metadata.encoder_seq_lens_tensor)
        # Enc/dec cross-attention KVs match encoder sequence length;
        # cross-attention utilizes special "cross" block tables
        return (query_start_loc, attn_metadata.max_prefill_seq_len,
                key_start_loc, attn_metadata.max_encoder_seq_len,
                attn_metadata.seq_lens, False)

    raise AttributeError(f"Invalid attention type {str(attn_type)}")


424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
class ROCmFlashAttentionImpl(AttentionImpl):
    """
    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prompt_tokens -------------->|
    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|

    Otherwise, the layout is as follows:
    |<------------------ num_generation_tokens (M) ----------------->|
    |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.
439
440
441
442
443
444
445
446
447

    If chunked prefill is enabled, prefill tokens and decode tokens can be
    batched together in a flattened 1D query.

    |<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->|	
    |<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->|

    Currently, cuda graph is disabled for chunked prefill, meaning there's no
    padding between prefill and decode tokens.
448
449
450
451
452
453
454
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        """Configure the ROCm attention implementation for one layer.

        The kernel is chosen from `envs.VLLM_USE_TRITON_FLASH_ATTN`: Triton
        FA when set. Otherwise CK flash-attn (`flash_attn_varlen_func`) is
        used when the device reports capability >= 90 and the package is
        importable; in all other cases the naive SDPA fallback is used.

        Raises:
            ValueError: if `blocksparse_params` is given; if `head_size` is
                not supported by PagedAttention; or if `logits_soft_cap` is
                given while the Triton or naive kernel is selected.
        """
        if blocksparse_params is not None:
            raise ValueError(
                "ROCmFlashAttention does not support blocksparse attention.")

        # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
        self.logits_soft_cap = (0.0 if logits_soft_cap is None else
                                logits_soft_cap)
        self.attn_type = attn_type
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        # (-1, -1) encodes "no sliding window".
        self.sliding_window = ((sliding_window, sliding_window)
                               if sliding_window is not None else (-1, -1))
        self.kv_cache_dtype = kv_cache_dtype

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

        self.use_naive_attn = False
        # NOTE: Allow for switching between Triton and CK. Defaulting to triton.
        self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
        if self.use_triton_flash_attn:
            if logits_soft_cap is not None:
                raise ValueError(
                    "ROCm Triton FlashAttention does not support attention "
                    "logits soft capping. Please try using the ROCm CK "
                    "FA backend instead by setting the env var "
                    "`VLLM_USE_TRITON_FLASH_ATTN=0`")

            from vllm.attention.ops.triton_flash_attention import (  # noqa: F401
                triton_attention)
            self.attn_func = triton_attention
            logger.debug("Using Triton FA in ROCmBackend")
            if self.sliding_window != (-1, -1):
                logger.warning("ROCm Triton FA does not currently support "
                               "sliding window attention. If using half "
                               "precision, please try using the ROCm CK "
                               "FA backend instead by setting the env var "
                               "`VLLM_USE_TRITON_FLASH_ATTN=0`")
        else:
            # if not using triton, navi3x/navi21/navi10 do not use flash-attn
            # either
            if not current_platform.has_device_capability(90):
                self.use_naive_attn = True
            else:
                try:
                    from flash_attn import flash_attn_varlen_func  # noqa: F401
                    self.attn_func = flash_attn_varlen_func
                    logger.debug("Using CK FA in ROCmBackend")
                except ModuleNotFoundError:
                    self.use_naive_attn = True

            if self.use_naive_attn:
                if logits_soft_cap is not None:
                    raise ValueError(
                        "ROCm Naive FlashAttention does not support "
                        "attention logits soft capping.")

                self.attn_func = _sdpa_attention
                logger.debug("Using naive (SDPA) attention in ROCmBackend")
    def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
        """Duplicate every KV head ``n_rep`` times along dimension 1.

        Same result as ``torch.repeat_interleave(x, dim=1, repeats=n_rep)``:
        an input of shape ``(tokens, n_kv_heads, head_dim)`` becomes
        ``(tokens, n_kv_heads * n_rep, head_dim)``, with the copies of each
        head placed next to each other.
        """
        tokens, n_kv_heads, head_dim = x.shape
        widened = x.unsqueeze(2).expand(tokens, n_kv_heads, n_rep, head_dim)
        return widened.reshape(tokens, n_kv_heads * n_rep, head_dim)

    def forward(
        self,
547
        layer: AttentionLayer,
548
549
550
551
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
552
        attn_metadata: ROCmFlashAttentionMetadata,
553
        output: Optional[torch.Tensor] = None,
554
555
556
    ) -> torch.Tensor:
        """Forward pass with FlashAttention and PagedAttention.

557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
        For decoder-only models: query, key and value must be non-None.

        For encoder/decoder models:
        * ROCmFlashAttentionImpl.forward() may be invoked for both self- and 
            cross-attention layers.
        * For self-attention: query, key and value must be non-None.
        * For cross-attention:
            * Query must be non-None
            * During prefill, key and value must be non-None; key and value
              get cached for use during decode.
            * During decode, key and value may be None, since:
              (1) key and value tensors were cached during prefill, and
              (2) cross-attention key and value tensors do not grow during
                  decode
        
        A note on how the attn_type (attention type enum) argument impacts
        attention forward() behavior:
    
            * DECODER: normal decoder-only behavior;
                use decoder self-attention block table
            * ENCODER: no KV caching; pass encoder sequence
                attributes (encoder_seq_lens/encoder_seq_lens_tensor/
                max_encoder_seq_len) to kernel, in lieu of decoder
                sequence attributes (seq_lens/seq_lens_tensor/max_seq_len)
            * ENCODER_DECODER: cross-attention behavior;
                use cross-attention block table for caching KVs derived
                from encoder hidden states; since KV sequence lengths
                will match encoder sequence lengths, pass encoder sequence
                attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/
                max_encoder_seq_len)

588
589
590
591
592
        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
593
594
                NOTE: kv_cache will be an empty tensor with shape [0]
                for profiling run.
595
            attn_metadata: Metadata for attention.
596
597
598
599
            attn_type: Select attention type, between encoder attention,
                       decoder self-attention, or encoder/decoder cross-
                       attention. Defaults to decoder self-attention,
                       which is the vLLM default generally
600
601
602
603
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        query = query.view(-1, self.num_heads, self.head_size)
604
605
606
607
608
609
        if key is not None:
            assert value is not None
            key = key.view(-1, self.num_kv_heads, self.head_size)
            value = value.view(-1, self.num_kv_heads, self.head_size)
        else:
            assert value is None
610

611
        if self.attn_type != AttentionType.ENCODER and kv_cache.numel() > 0:
612
613
614
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
            if key is not None and value is not None:
                # Reshape the input keys and values and store them in the
                # cache. If kv_cache is not provided, the new key and value
                # tensors are not cached. This happens during the initial
                # memory profiling run.
                PagedAttention.write_to_paged_cache(
                    key,
                    value,
                    key_cache,
                    value_cache,
                    attn_metadata.slot_mapping
                    if self.attn_type != AttentionType.ENCODER_DECODER else
                    attn_metadata.cross_slot_mapping,
                    self.kv_cache_dtype,
                    layer._k_scale,
                    layer._v_scale,
                )

        if self.attn_type != AttentionType.ENCODER:
            num_prefill_tokens = attn_metadata.num_prefill_tokens
        else:
            assert attn_metadata.num_encoder_tokens is not None
            num_prefill_tokens = attn_metadata.num_encoder_tokens
638
639
640
641
642
643
644

        output = torch.empty_like(query)
        # Query for decode. KV is not needed because it is already cached.
        decode_query = query[num_prefill_tokens:]
        # QKV for prefill.
        query = query[:num_prefill_tokens]

645
646
647
648
        if key is not None and value is not None \
            and self.attn_type != AttentionType.ENCODER_DECODER:
            key = key[:num_prefill_tokens]
            value = value[:num_prefill_tokens]
649
650

        if prefill_meta := attn_metadata.prefill_metadata:
651
            # Prompt run.
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
            # normal attention and DECODER
            if self.attn_type == AttentionType.DECODER and (
                    kv_cache.numel() == 0 or prefill_meta.block_tables is None
                    or prefill_meta.block_tables.numel() == 0):
                (query_seq_start_loc, query_max_seq_len, key_seq_start_loc,
                 key_max_seq_len, seq_lens,
                 causal_mask) = (prefill_meta.seq_start_loc,
                                 prefill_meta.max_prefill_seq_len,
                                 prefill_meta.seq_start_loc,
                                 prefill_meta.max_prefill_seq_len,
                                 attn_metadata.seq_lens, True)
            # prefix-enabled attention and ENCODER/ENCODER_DECODER
            else:
                (query_seq_start_loc, query_max_seq_len, key_seq_start_loc,
                 key_max_seq_len, seq_lens,
                 causal_mask) = _get_seq_len_block_table_args(
                     prefill_meta, self.attn_type)
            # Prompt run.
670
            if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0:
671
672
673
                # triton attention
                # When block_tables are not filled, it means q and k are the
                # prompt, and they have the same length.
674
                attn_masks = None
675
                if self.use_triton_flash_attn:
676
677
678
679
                    if self.alibi_slopes is not None:
                        attn_masks = _make_alibi_bias(
                            self.alibi_slopes,
                            query.dtype,
680
                            seq_lens,
681
                            make_attn_mask=False)  # type: ignore
682
683
684
685
686
                    out, _ = self.attn_func(
                        query,
                        key,
                        value,
                        None,
687
688
689
690
691
                        query_seq_start_loc,
                        key_seq_start_loc,
                        query_max_seq_len,
                        key_max_seq_len,
                        causal_mask,
692
                        self.scale,
693
694
                        attn_masks[0][None]
                        if attn_masks is not None else None,
695
696
                    )
                elif self.use_naive_attn:
697
698
699
700
                    if self.num_kv_heads != self.num_heads:
                        # Interleave for MQA workaround.
                        key = self.repeat_kv(key, self.num_queries_per_kv)
                        value = self.repeat_kv(value, self.num_queries_per_kv)
701
702
703
704
705
706
                    if self.alibi_slopes is not None:
                        attn_masks = _make_alibi_bias(
                            self.alibi_slopes,
                            query.dtype,
                            attn_metadata.seq_lens,
                            make_attn_mask=True)  # type: ignore
707
708
709
710
                    query = query.movedim(0, query.dim() - 2)
                    key = key.movedim(0, key.dim() - 2)
                    value = value.movedim(0, value.dim() - 2)
                    # sdpa math backend attention
711
712
713
714
                    out = self.attn_func(
                        query,
                        key,
                        value,
715
716
                        query_seq_start_loc,
                        num_prefill_tokens,
717
718
                        self.num_heads,
                        self.head_size,
719
                        self.scale,
720
                        attn_masks,
721
                    )
722
                else:
723
                    out = self.attn_func(
724
725
726
                        q=query,
                        k=key,
                        v=value,
727
728
                        cu_seqlens_q=query_seq_start_loc,
                        cu_seqlens_k=key_seq_start_loc,
729
                        max_seqlen_q=prefill_meta.max_prefill_seq_len,
730
                        max_seqlen_k=key_max_seq_len,
731
732
                        softmax_scale=self.scale,
                        causal=True,
733
734
                        window_size=self.sliding_window,
                        alibi_slopes=self.alibi_slopes,
735
                        softcap=self.logits_soft_cap,
736
                    )
737
738
739

                # common code for prefill
                assert output[:num_prefill_tokens].shape == out.shape
740
741
742
743
                if output.shape[0] > num_prefill_tokens:
                    output[:num_prefill_tokens] = out
                else:
                    output = out
744
745
            else:
                # prefix-enabled attention
746
                output[:num_prefill_tokens] = PagedAttention.forward_prefix(
747
748
749
                    query,
                    key,
                    value,
750
                    self.kv_cache_dtype,
751
752
                    key_cache,
                    value_cache,
753
                    prefill_meta.block_tables,
754
                    prefill_meta.query_start_loc,
755
756
                    prefill_meta.seq_lens_tensor,
                    prefill_meta.max_query_len,
757
                    self.alibi_slopes,
758
                    self.sliding_window[0],
759
760
                    layer._k_scale,
                    layer._v_scale,
761
                )
762
763

        if decode_meta := attn_metadata.decode_metadata:
764
            # Decoding run.
765
766
767
768
            # Whether to use rocm custom paged attention or not
            num_seqs, num_heads, head_size = decode_query.shape
            block_size = value_cache.shape[3]
            gqa_ratio = num_heads // self.num_kv_heads
769
770
771
            use_custom = _use_rocm_custom_paged_attention(
                decode_query.dtype, head_size, block_size, gqa_ratio,
                decode_meta.max_decode_seq_len)
772
            if use_custom:
773
774
775
776
                max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type
                               != AttentionType.ENCODER_DECODER else
                               decode_meta.max_encoder_seq_len)
                assert max_seq_len is not None
777
778
779
780
                max_num_partitions = (
                    (max_seq_len + _PARTITION_SIZE_ROCM - 1) //
                    _PARTITION_SIZE_ROCM)
                assert _PARTITION_SIZE_ROCM % block_size == 0
781
782
783
784
785
786
787
788
789
790
791
                tmp_output = torch.empty(
                    size=(num_seqs, num_heads, max_num_partitions, head_size),
                    dtype=output.dtype,
                    device=output.device,
                )
                exp_sums = torch.empty(
                    size=(num_seqs, num_heads, max_num_partitions),
                    dtype=torch.float32,
                    device=output.device,
                )
                max_logits = torch.empty_like(exp_sums)
792
793
794
795
                if num_prefill_tokens > 0:
                    out = output[num_prefill_tokens:]
                else:
                    out = output
796
                ops.paged_attention_rocm(
797
                    out,
798
799
800
801
802
803
804
805
                    exp_sums,
                    max_logits,
                    tmp_output,
                    decode_query,
                    key_cache,
                    value_cache,
                    self.num_kv_heads,
                    self.scale,
806
807
808
809
810
811
                    decode_meta.block_tables
                    if self.attn_type != AttentionType.ENCODER_DECODER else
                    decode_meta.cross_block_tables,
                    decode_meta.seq_lens_tensor
                    if self.attn_type != AttentionType.ENCODER_DECODER else
                    decode_meta.encoder_seq_lens_tensor,
812
813
814
815
                    block_size,
                    max_seq_len,
                    self.alibi_slopes,
                    self.kv_cache_dtype,
816
817
                    layer._k_scale,
                    layer._v_scale,
818
819
820
821
822
823
                )
            else:
                output[num_prefill_tokens:] = PagedAttention.forward_decode(
                    decode_query,
                    key_cache,
                    value_cache,
824
825
826
827
828
829
830
831
832
                    decode_meta.block_tables
                    if self.attn_type != AttentionType.ENCODER_DECODER else
                    decode_meta.cross_block_tables,
                    decode_meta.seq_lens_tensor
                    if self.attn_type != AttentionType.ENCODER_DECODER else
                    decode_meta.encoder_seq_lens_tensor,
                    decode_meta.max_decode_seq_len
                    if self.attn_type != AttentionType.ENCODER_DECODER else
                    decode_meta.max_encoder_seq_len,
833
834
835
836
                    self.kv_cache_dtype,
                    self.num_kv_heads,
                    self.scale,
                    self.alibi_slopes,
837
838
                    layer._k_scale,
                    layer._v_scale,
839
                )
840
841

        # Reshape the output tensor.
842
        return output.view(-1, self.num_heads * self.head_size)
843
844


845
def _sdpa_attention(
846
847
848
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
849
    seq_lens: List[int],
850
851
852
    num_tokens: int,
    num_heads: int,
    head_size: int,
853
    scale: float,
854
    attn_masks: Optional[List[torch.Tensor]] = None,
855
856
) -> torch.Tensor:
    start = 0
857
858
859
860
    output = torch.empty((num_tokens, num_heads, head_size),
                         dtype=query.dtype,
                         device=query.device)

861
    for i, seq_len in enumerate(seq_lens):
862
        end = start + seq_len
863
864
865
866
867
868
869
870
        with torch.backends.cuda.sdp_kernel(enable_math=True,
                                            enable_flash=False,
                                            enable_mem_efficient=False):
            sub_out = torch.nn.functional.scaled_dot_product_attention(
                query[:, start:end, :],
                key[:, start:end, :],
                value[:, start:end, :],
                dropout_p=0.0,
871
872
                is_causal=attn_masks is None,
                attn_mask=attn_masks[i] if attn_masks else None,
873
874
875
                scale=scale).movedim(query.dim() - 2, 0)
            output[start:end, :, :] = sub_out
            start = end
876

877
    return output
878
879


880
881
882
def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
                                     block_size: int, gqa_ratio: int,
                                     max_seq_len: int) -> bool:
    """Decide whether the ROCm custom paged-attention kernel may be used.

    All of the following must hold:
      * the device architecture matched gfx90a or gfx942
        (``_ON_MI250_MI300``) and did not match ``gfx1*`` (``_ON_NAVI``);
      * the query dtype is fp16 or bf16;
      * ``head_size`` is 64 or 128;
      * ``block_size`` is 16 or 32;
      * ``gqa_ratio`` lies in the inclusive range [1, 16];
      * ``max_seq_len`` is at most 32768.
    """
    # Hardware gate first: the custom kernel is unsupported on Navi (gfx1*).
    if not _ON_MI250_MI300 or _ON_NAVI:
        return False
    supported_dtype = qtype in (torch.half, torch.bfloat16)
    return (supported_dtype and head_size in (64, 128)
            and block_size in (16, 32) and 1 <= gqa_ratio <= 16
            and max_seq_len <= 32768)