rocm_flash_attn.py 28 KB
Newer Older
1
2
"""Attention layer ROCm GPUs."""
from dataclasses import dataclass
3
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
4
5
6

import torch

7
import vllm.envs as envs
8
from vllm import _custom_ops as ops
9
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
10
                                              AttentionLayer,
11
                                              AttentionMetadata, AttentionType)
12
13
from vllm.attention.backends.utils import (CommonAttentionState,
                                           CommonMetadataBuilder)
14
15
16
from vllm.attention.ops.paged_attn import (PagedAttention,
                                           PagedAttentionMetadata)
from vllm.logger import init_logger
17
from vllm.platforms import current_platform
18

19
20
21
if TYPE_CHECKING:
    from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata

22
23
logger = init_logger(__name__)

# Number of tokens per partition; used to derive how many partitions the
# custom ROCm paged-attention path needs for a given maximum sequence length.
_PARTITION_SIZE_ROCM = 512

# Architecture name reported by the device, queried once at import time.
_GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
# Navi parts are recognised by the "gfx1" substring in the architecture name.
_ON_NAVI = "gfx1" in _GPU_ARCH
# True when the architecture name contains one of the listed gfx9 identifiers.
_ON_MI250_MI300 = any(
    name in _GPU_ARCH for name in ("gfx90a", "gfx940", "gfx941", "gfx942"))
29

30
31
32

class ROCmFlashAttentionBackend(AttentionBackend):
    """Backend descriptor for ROCm flash attention.

    Every method is static. The class reports which implementation,
    metadata, builder and state classes belong to this backend and forwards
    KV-cache bookkeeping to ``PagedAttention``.
    """

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "ROCM_FLASH"

    @staticmethod
    def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]:
        """Return the attention implementation class."""
        return ROCmFlashAttentionImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        """Return the attention metadata class."""
        return ROCmFlashAttentionMetadata

    @staticmethod
    def get_builder_cls() -> Type["ROCmFlashAttentionMetadataBuilder"]:
        """Return the metadata builder class."""
        return ROCmFlashAttentionMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Return the attention state class."""
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache shape as computed by ``PagedAttention``."""
        cache_shape = PagedAttention.get_kv_cache_shape(
            num_blocks, block_size, num_kv_heads, head_size)
        return cache_shape

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Delegate the block swap to ``PagedAttention.swap_blocks``."""
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Delegate the block copy to ``PagedAttention.copy_blocks``."""
        PagedAttention.copy_blocks(kv_caches, src_to_dists)


@dataclass
class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
    """Metadata for FlashAttentionBackend.

    NOTE: Plain Python objects stored here are not refreshed when a cuda
    graph is replayed. Values that must change dynamically should be kept
    in a tensor, and that tensor has to be updated from the
    `CUDAGraphRunner.forward` API.
    """

    # (batch_size,). Length of each sequence, i.e. the already computed
    # tokens plus the new tokens. None if it is a decoding.
    seq_lens: Optional[List[int]]
    # Tensor form of seq_lens.
    seq_lens_tensor: Optional[torch.Tensor]

    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|

    # Longest query in the batch. None for decoding.
    max_query_len: Optional[int]
    # Longest sequence among the prefill requests; 0 when the batch holds
    # decoding requests only.
    max_prefill_seq_len: int
    # Longest sequence among the decode requests; 0 when the batch holds
    # prefill requests only.
    max_decode_seq_len: int
    # (batch_size + 1,). Cumulative subquery lengths of the sequences in the
    # batch, used to index into subquery. E.g. subquery lengths [4, 6] give
    # [0, 4, 10].
    query_start_loc: Optional[torch.Tensor]
    # (batch_size + 1,). Cumulative sequence lengths of the sequences in the
    # batch, used to index into sequence. E.g. sequence lengths [4, 6] give
    # [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor]

    # Whether cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    # (batch_size,) Context lengths, i.e. the number of tokens computed so
    # far for each sequence.
    context_lens_tensor: Optional[torch.Tensor]

    # Max number of query tokens among request in the batch.
    max_decode_query_len: Optional[int] = None

    # Lazily built views returned by the two properties below.
    _cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None
    _cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None

    @property
    def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
        """Return the prefill-only slice of this metadata, or None.

        None is returned when the batch contains no prefills. The slice is
        built on first access and cached on the instance afterwards.
        """
        if self.num_prefills == 0:
            return None

        cached = self._cached_prefill_metadata
        if cached is not None:
            return cached

        assert self.seq_lens is not None
        assert self.seq_lens_tensor is not None
        assert self.query_start_loc is not None
        assert self.context_lens_tensor is not None
        assert self.block_tables is not None
        assert self.seq_start_loc is not None

        num_prefills = self.num_prefills
        num_prefill_tokens = self.num_prefill_tokens

        # Prefill entries occupy the front of every per-sequence and
        # per-token container, so each field is a leading slice.
        prefill_view = ROCmFlashAttentionMetadata(
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=self.slot_mapping[:num_prefill_tokens],
            multi_modal_placeholder_index_maps=self.
            multi_modal_placeholder_index_maps,
            enable_kv_scales_calculation=self.enable_kv_scales_calculation,
            seq_lens=self.seq_lens[:num_prefills],
            seq_lens_tensor=self.seq_lens_tensor[:num_prefills],
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_seq_len=0,
            query_start_loc=self.query_start_loc[:num_prefills + 1],
            seq_start_loc=self.seq_start_loc[:num_prefills + 1],
            context_lens_tensor=self.context_lens_tensor[:num_prefills],
            block_tables=self.block_tables[:num_prefills],
            use_cuda_graph=False,
        )
        self._cached_prefill_metadata = prefill_view
        return prefill_view

    @property
    def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
        """Return the decode-only slice of this metadata, or None.

        None is returned when the batch contains no decode tokens. The slice
        is built on first access and cached on the instance afterwards.
        """
        if self.num_decode_tokens == 0:
            return None

        cached = self._cached_decode_metadata
        if cached is not None:
            return cached

        assert self.block_tables is not None
        assert self.seq_lens_tensor is not None

        num_prefills = self.num_prefills

        # Decode entries follow the prefill entries, so each field is a
        # trailing slice.
        decode_view = ROCmFlashAttentionMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=True,
            seq_lens=None,
            seq_lens_tensor=self.seq_lens_tensor[num_prefills:],
            max_query_len=None,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            query_start_loc=None,
            seq_start_loc=None,
            context_lens_tensor=None,
            block_tables=self.block_tables[num_prefills:],
            use_cuda_graph=self.use_cuda_graph,
        )
        self._cached_decode_metadata = decode_view

        # Batch may be composed of prefill|decodes, adjust query start indices
        # to refer to the start of decodes when the two are split apart.
        # E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
        # NOTE(review): query_start_loc is passed as None just above, so this
        # branch does not execute as the code is written today.
        if decode_view.query_start_loc is not None:
            qs = decode_view.query_start_loc
            decode_view.query_start_loc = qs - qs[0]
        return decode_view

    def advance_step(self,
                     model_input: "ModelInputForGPUWithSamplingMetadata",
                     sampled_token_ids: Optional[torch.Tensor],
                     block_size: int,
                     num_seqs: int,
                     num_queries: int,
                     turn_prefills_into_decodes: bool = False):
        """
        Update metadata in-place to advance one decode step.

        The Python-side ``seq_lens`` and ``max_decode_seq_len`` are bumped
        here; the tensors are handed to ``ops.advance_step_flashattn``.
        """

        assert not turn_prefills_into_decodes, \
            ("Chunked prefill is not supported with rocm_flash_attn yet."
             "turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill "
             "specific parameter.")

        # When using cudagraph, the num_seqs is padded to the next captured
        # batch sized, but num_queries tracks the actual number of requests in
        # the batch. For --enforce-eager mode, num_seqs == num_queries
        if num_seqs != num_queries:
            assert num_seqs > num_queries
            assert self.use_cuda_graph

        # This must be a decode-only batch.
        assert self.num_prefills == 0
        assert self.num_prefill_tokens == 0
        assert self.num_decode_tokens == num_seqs
        assert self.slot_mapping.shape == (num_seqs, )

        assert self.seq_lens is not None
        assert len(self.seq_lens) == num_seqs
        assert self.seq_lens_tensor is not None
        assert self.seq_lens_tensor.shape == (num_seqs, )
        assert self.max_query_len == 1
        assert self.max_prefill_seq_len == 0
        assert self.max_decode_seq_len == max(self.seq_lens)

        assert self.query_start_loc is not None
        assert self.query_start_loc.shape == (num_queries + 1, )
        assert self.seq_start_loc is not None
        assert self.seq_start_loc.shape == (num_seqs + 1, )

        assert self.context_lens_tensor is not None
        assert self.context_lens_tensor.shape == (num_queries, )

        assert self.block_tables is not None
        assert self.block_tables.shape[0] == num_seqs

        # Update query lengths. Note that we update only queries and not seqs,
        # since tensors may be padded due to captured cuda graph batch size
        for query_idx in range(num_queries):
            self.seq_lens[query_idx] += 1
        self.max_decode_seq_len = max(self.seq_lens)

        ops.advance_step_flashattn(num_seqs=num_seqs,
                                   num_queries=num_queries,
                                   block_size=block_size,
                                   input_tokens=model_input.input_tokens,
                                   sampled_token_ids=sampled_token_ids,
                                   input_positions=model_input.input_positions,
                                   seq_lens=self.seq_lens_tensor,
                                   slot_mapping=self.slot_mapping,
                                   block_tables=self.block_tables)

269

270
271
272
273
274
275
class ROCmFlashAttentionMetadataBuilder(
        CommonMetadataBuilder[ROCmFlashAttentionMetadata]):
    """Metadata builder for the ROCm flash-attention backend.

    The class body adds nothing beyond pointing ``_metadata_cls`` at
    ``ROCmFlashAttentionMetadata``; everything else is inherited from
    ``CommonMetadataBuilder``.
    """

    _metadata_cls = ROCmFlashAttentionMetadata


276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
def _make_alibi_bias(alibi_slopes: torch.Tensor,
                     dtype: torch.dtype,
                     seq_lens: Optional[List[int]],
                     make_attn_mask: bool = True) -> List[torch.Tensor]:
    attn_biases = []
    if seq_lens:
        for seq_len in seq_lens:
            bias = torch.arange(seq_len, dtype=dtype)
            # NOTE(zhuohan): HF uses
            #     `bias = bias[None, :].repeat(seq_len, 1)`
            # here. We find that both biases give the same results, but
            # the bias below more accurately follows the original ALiBi
            # paper.
            bias = bias[None, :] - bias[:, None]

            num_heads = alibi_slopes.shape[0]
            bias = bias[None, :].repeat(
                (num_heads, 1, 1)).to(alibi_slopes.device)
            bias.mul_(alibi_slopes[:, None, None])
            if make_attn_mask:
                inf_mask = torch.empty(
                    (1, seq_len, seq_len),
                    dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1).to(
                        alibi_slopes.device)
                attn_biases.append((bias + inf_mask).to(dtype))
            else:
                attn_biases.append(bias.to(dtype))

    return attn_biases


307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
class ROCmFlashAttentionImpl(AttentionImpl):
    """
    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prompt_tokens -------------->|
    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|

    Otherwise, the layout is as follows:
    |<------------------ num_generation_tokens (M) ----------------->|
    |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.

    If chunked prefill is enabled, prefill tokens and decode tokens can be
    batched together in a flattened 1D query.

    |<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->|
    |<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->|

    Currently, cuda graph is disabled for chunked prefill, meaning there's no
    padding between prefill and decode tokens.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        """Validate the configuration and pick the prefill attention kernel.

        The kernel stored in ``self.attn_func`` is Triton flash attention
        when ``envs.VLLM_USE_TRITON_FLASH_ATTN`` is set. Otherwise it is
        ``flash_attn_varlen_func`` when the platform reports the required
        device capability and ``flash_attn`` can be imported, and
        ``_sdpa_attention`` in every other case.

        Raises:
            ValueError: If ``blocksparse_params`` or ``logits_soft_cap`` is
                given, or ``head_size`` is not supported by PagedAttention.
            NotImplementedError: If ``attn_type`` is not
                ``AttentionType.DECODER``.
        """
        # Reject the features this backend does not provide.
        if blocksparse_params is not None:
            raise ValueError(
                "ROCmFlashAttention does not support blocksparse attention.")
        if logits_soft_cap is not None:
            raise ValueError(
                "ROCmFlashAttention does not support attention logits soft "
                "capping.")

        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads

        # Slopes are kept as a float32 tensor, or None when ALiBi is unused.
        self.alibi_slopes = (torch.tensor(alibi_slopes, dtype=torch.float32)
                             if alibi_slopes is not None else None)

        # (-1, -1) stands for "no sliding window".
        if sliding_window is None:
            self.sliding_window = (-1, -1)
        else:
            self.sliding_window = (sliding_window, sliding_window)
        self.kv_cache_dtype = kv_cache_dtype

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

        self.use_naive_attn = False
        # NOTE: Allow for switching between Triton and CK. Defaulting to triton.
        self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
        if self.use_triton_flash_attn:
            from vllm.attention.ops.triton_flash_attention import (  # noqa: F401
                triton_attention)
            self.attn_func = triton_attention
            logger.debug("Using Triton FA in ROCmBackend")
            if self.sliding_window != (-1, -1):
                logger.warning("ROCm Triton FA does not currently support "
                               "sliding window attention. If using half "
                               "precision, please try using the ROCm CK "
                               "FA backend instead by setting the env var "
                               "`VLLM_USE_TRITON_FLASH_ATTN=0`")
        else:
            # if not using triton, navi3x/navi21/navi10 do not use flash-attn
            # either
            if current_platform.has_device_capability(90):
                try:
                    from flash_attn import flash_attn_varlen_func  # noqa: F401
                except ModuleNotFoundError:
                    # CK flash attention is not installed; fall back.
                    self.use_naive_attn = True
                else:
                    self.attn_func = flash_attn_varlen_func
                    logger.debug("Using CK FA in ROCmBackend")
            else:
                self.use_naive_attn = True

            if self.use_naive_attn:
                self.attn_func = _sdpa_attention
                logger.debug("Using naive attention in ROCmBackend")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "ROCmFlashAttentionImpl")

410
411
412
413
414
415
416
417
418
419
    def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
        """torch.repeat_interleave(x, dim=1, repeats=n_rep)

        ``x`` is unpacked as ``(tokens, n_kv_heads, head_dim)`` and the
        result has shape ``(tokens, n_kv_heads * n_rep, head_dim)``.
        """
        num_tokens, kv_heads, head_dim = x.shape
        # Insert a repeat axis after the head axis, broadcast along it, then
        # fold it back into the head axis.
        widened = x.unsqueeze(2).expand(num_tokens, kv_heads, n_rep, head_dim)
        return widened.reshape(num_tokens, kv_heads * n_rep, head_dim)

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: ROCmFlashAttentionMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention and PagedAttention.

        Prefill tokens come first in the flattened inputs and decode tokens
        follow. New keys and values are written to the paged cache when one
        is provided, prefill tokens are processed by the kernel chosen in
        ``__init__`` (or by prefix attention when block tables are filled),
        and decode tokens are processed by paged attention.

        Args:
            layer: Attention layer providing ``_k_scale`` and ``_v_scale``.
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
                NOTE: kv_cache will be an empty tensor with shape [0]
                for profiling run.
            attn_metadata: Metadata for attention.
            output: Not read; a fresh result buffer is always allocated.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        # Reminder: Please update docs/source/features/compatibility_matrix.md
        # If the feature combo become valid
        num_tokens, hidden_size = query.shape

        # Split the flattened hidden dimension into (heads, head_size).
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        has_kv_cache = kv_cache.numel() > 0
        if has_kv_cache:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            # Store the new keys and values in the cache. Without a cache
            # (initial memory profiling run) they are simply not cached.
            PagedAttention.write_to_paged_cache(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

        num_prefill_tokens = attn_metadata.num_prefill_tokens
        num_decode_tokens = attn_metadata.num_decode_tokens
        total_tokens = num_prefill_tokens + num_decode_tokens
        assert key.shape[0] == total_tokens
        assert value.shape[0] == total_tokens

        output = torch.empty_like(query)

        # Decode only needs the query; its KV is already in the cache.
        decode_query = query[num_prefill_tokens:]
        # Prefill uses the leading slice of Q, K and V.
        query = query[:num_prefill_tokens]
        key = key[:num_prefill_tokens]
        value = value[:num_prefill_tokens]

        assert query.shape[0] == num_prefill_tokens
        assert decode_query.shape[0] == num_decode_tokens

        prefill_meta = attn_metadata.prefill_metadata
        if prefill_meta:
            # Prompt run.
            assert prefill_meta.seq_lens is not None
            if not has_kv_cache or prefill_meta.block_tables.numel() == 0:
                # When block_tables are not filled, it means q and k are the
                # prompt, and they have the same length.
                alibi_masks = None
                if self.use_triton_flash_attn:
                    # triton attention
                    if self.alibi_slopes is not None:
                        alibi_masks = _make_alibi_bias(
                            self.alibi_slopes,
                            query.dtype,
                            attn_metadata.seq_lens,
                            make_attn_mask=False)  # type: ignore
                    prefill_out, _ = self.attn_func(
                        query,
                        key,
                        value,
                        None,
                        prefill_meta.seq_start_loc,
                        prefill_meta.seq_start_loc,
                        prefill_meta.max_prefill_seq_len,
                        prefill_meta.max_prefill_seq_len,
                        True,
                        self.scale,
                        alibi_masks[0][None]
                        if alibi_masks is not None else None,
                    )
                elif self.use_naive_attn:
                    if self.num_kv_heads != self.num_heads:
                        # Interleave for MQA workaround.
                        key = self.repeat_kv(key, self.num_queries_per_kv)
                        value = self.repeat_kv(value, self.num_queries_per_kv)
                    if self.alibi_slopes is not None:
                        alibi_masks = _make_alibi_bias(
                            self.alibi_slopes,
                            query.dtype,
                            attn_metadata.seq_lens,
                            make_attn_mask=True)  # type: ignore
                    # Move the token axis next to the head-size axis.
                    query = query.movedim(0, query.dim() - 2)
                    key = key.movedim(0, key.dim() - 2)
                    value = value.movedim(0, value.dim() - 2)
                    # sdpa math backend attention
                    prefill_out = self.attn_func(
                        query,
                        key,
                        value,
                        prefill_meta.seq_lens,
                        num_tokens,
                        self.num_heads,
                        self.head_size,
                        self.scale,
                        alibi_masks,
                    )
                else:
                    prefill_out = self.attn_func(
                        q=query,
                        k=key,
                        v=value,
                        cu_seqlens_q=prefill_meta.seq_start_loc,
                        cu_seqlens_k=prefill_meta.seq_start_loc,
                        max_seqlen_q=prefill_meta.max_prefill_seq_len,
                        max_seqlen_k=prefill_meta.max_prefill_seq_len,
                        softmax_scale=self.scale,
                        causal=True,
                        window_size=self.sliding_window,
                        alibi_slopes=self.alibi_slopes,
                    )

                # common code for prefill
                assert output[:num_prefill_tokens].shape == prefill_out.shape
                output[:num_prefill_tokens] = prefill_out
            else:
                # prefix-enabled attention
                output[:num_prefill_tokens] = PagedAttention.forward_prefix(
                    query,
                    key,
                    value,
                    self.kv_cache_dtype,
                    key_cache,
                    value_cache,
                    prefill_meta.block_tables,
                    prefill_meta.query_start_loc,
                    prefill_meta.seq_lens_tensor,
                    prefill_meta.context_lens_tensor,
                    prefill_meta.max_query_len,
                    self.alibi_slopes,
                    self.sliding_window[0],
                    layer._k_scale,
                    layer._v_scale,
                )

        decode_meta = attn_metadata.decode_metadata
        if decode_meta:
            # Decoding run.
            # Whether to use rocm custom paged attention or not
            num_seqs, num_heads, head_size = decode_query.shape
            block_size = value_cache.shape[3]
            gqa_ratio = num_heads // self.num_kv_heads
            use_custom_kernel = _use_rocm_custom_paged_attention(
                decode_query.dtype, head_size, block_size, gqa_ratio,
                decode_meta.max_decode_seq_len)
            if use_custom_kernel:
                max_seq_len = decode_meta.max_decode_seq_len
                # Ceiling division: partitions needed for the longest
                # sequence.
                max_num_partitions = (
                    (max_seq_len + _PARTITION_SIZE_ROCM - 1) //
                    _PARTITION_SIZE_ROCM)
                assert _PARTITION_SIZE_ROCM % block_size == 0

                # Scratch buffers handed to the custom kernel.
                tmp_output = torch.empty(
                    size=(num_seqs, num_heads, max_num_partitions, head_size),
                    dtype=output.dtype,
                    device=output.device,
                )
                exp_sums = torch.empty(
                    size=(num_seqs, num_heads, max_num_partitions),
                    dtype=torch.float32,
                    device=output.device,
                )
                max_logits = torch.empty_like(exp_sums)
                ops.paged_attention_rocm(
                    output[num_prefill_tokens:],
                    exp_sums,
                    max_logits,
                    tmp_output,
                    decode_query,
                    key_cache,
                    value_cache,
                    self.num_kv_heads,
                    self.scale,
                    decode_meta.block_tables,
                    decode_meta.seq_lens_tensor,
                    block_size,
                    max_seq_len,
                    self.alibi_slopes,
                    self.kv_cache_dtype,
                    layer._k_scale,
                    layer._v_scale,
                )
            else:
                output[num_prefill_tokens:] = PagedAttention.forward_decode(
                    decode_query,
                    key_cache,
                    value_cache,
                    decode_meta.block_tables,
                    decode_meta.seq_lens_tensor,
                    decode_meta.max_decode_seq_len,
                    self.kv_cache_dtype,
                    self.num_kv_heads,
                    self.scale,
                    self.alibi_slopes,
                    layer._k_scale,
                    layer._v_scale,
                )

        # Reshape the output tensor.
        return output.view(num_tokens, hidden_size)


641
def _sdpa_attention(
642
643
644
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
645
    seq_lens: List[int],
646
647
648
    num_tokens: int,
    num_heads: int,
    head_size: int,
649
    scale: float,
650
    attn_masks: Optional[List[torch.Tensor]] = None,
651
652
) -> torch.Tensor:
    start = 0
653
654
655
656
    output = torch.empty((num_tokens, num_heads, head_size),
                         dtype=query.dtype,
                         device=query.device)

657
    for i, seq_len in enumerate(seq_lens):
658
        end = start + seq_len
659
660
661
662
663
664
665
666
        with torch.backends.cuda.sdp_kernel(enable_math=True,
                                            enable_flash=False,
                                            enable_mem_efficient=False):
            sub_out = torch.nn.functional.scaled_dot_product_attention(
                query[:, start:end, :],
                key[:, start:end, :],
                value[:, start:end, :],
                dropout_p=0.0,
667
668
                is_causal=attn_masks is None,
                attn_mask=attn_masks[i] if attn_masks else None,
669
670
671
                scale=scale).movedim(query.dim() - 2, 0)
            output[start:end, :, :] = sub_out
            start = end
672

673
    return output
674
675


676
677
678
def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
                                     block_size: int, gqa_ratio: int,
                                     max_seq_len: int) -> bool:
    """Tell whether the custom ROCm paged-attention path may be used.

    All of the following must hold: the module-level architecture flags
    report an MI250/MI300 device that is not Navi, the query dtype is half
    or bfloat16, the head size is 64 or 128, the block size is 16 or 32,
    the GQA ratio lies in [1, 16] and the maximum sequence length is at
    most 32768.
    """
    # rocm custom page attention not support on navi (gfx1*)
    if not _ON_MI250_MI300 or _ON_NAVI:
        return False
    if qtype not in (torch.half, torch.bfloat16):
        return False
    if head_size not in (64, 128):
        return False
    if block_size not in (16, 32):
        return False
    return 1 <= gqa_ratio <= 16 and max_seq_len <= 32768