rocm_flash_attn.py 43.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Attention layer ROCm GPUs."""
4
import itertools
5
from dataclasses import dataclass
6
from functools import cache
7
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
8
9
10

import torch

11
import vllm.envs as envs
12
from vllm import _custom_ops as ops
13
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
14
                                              AttentionLayer,
15
                                              AttentionMetadata, AttentionType)
16
17
from vllm.attention.backends.utils import (CommonAttentionState,
                                           CommonMetadataBuilder)
18
19
from vllm.attention.ops.paged_attn import (PagedAttention,
                                           PagedAttentionMetadata)
20
from vllm.config import get_current_vllm_config
21
from vllm.logger import init_logger
22
from vllm.platforms import current_platform
23
from vllm.platforms.rocm import use_rocm_custom_paged_attention
24

25
26
27
if TYPE_CHECKING:
    from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata

28
logger = init_logger(__name__)
29
_PARTITION_SIZE_ROCM = 256
30

31

32
33
34
35
36
37
38
39
40
@cache
def is_rocm_aiter_paged_attn_enabled() -> bool:
    """Return True when the AITER paged-attention path should be used.

    Both `envs.VLLM_ROCM_USE_AITER_PAGED_ATTN` and `envs.VLLM_ROCM_USE_AITER`
    must be truthy. Because of `@cache`, the answer is computed once per
    process; later changes to these settings are not observed.
    """
    # Parenthesised instead of a trailing backslash (the previous version left
    # a dangling continuation); coerced to bool to honour the annotation.
    return bool(envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
                and envs.VLLM_ROCM_USE_AITER)


@cache
def _get_paged_attn_module() -> PagedAttention:
    """
    Pick and instantiate the PagedAttention helper from `attention/ops`.

    The instance serves as a shared helper for both `ROCmFlashAttentionImpl`
    and `ROCmFlashAttentionBackend`; thanks to `@cache` it is created only
    once per process.

    - AITER paged attention enabled  -> `AITERPagedAttention`.
    - otherwise                      -> the stock `PagedAttention`.
    """
    if not is_rocm_aiter_paged_attn_enabled():
        return PagedAttention()

    # Import AITERPagedAttention only when the flag is enabled.
    from vllm.attention.ops.rocm_aiter_paged_attn import (
        AITERPagedAttention)
    return AITERPagedAttention()


58
class ROCmFlashAttentionBackend(AttentionBackend):
    """Attention backend descriptor for ROCm GPUs.

    Names the implementation, metadata, builder and state classes of this
    backend, and delegates KV-cache layout and block movement to the helper
    returned by `_get_paged_attn_module()`.
    """

    accept_output_buffer: bool = True

    @staticmethod
    def get_name() -> str:
        """Registry name of this backend."""
        return "ROCM_FLASH"

    @staticmethod
    def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]:
        """Class implementing the attention forward pass."""
        return ROCmFlashAttentionImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        """Class holding per-batch attention metadata."""
        return ROCmFlashAttentionMetadata

    @staticmethod
    def get_builder_cls() -> Type["ROCmFlashAttentionMetadataBuilder"]:
        """Class that builds `ROCmFlashAttentionMetadata` objects."""
        return ROCmFlashAttentionMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Attention-state class (the shared common implementation)."""
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """KV-cache shape, as defined by the selected PagedAttention helper."""
        return _get_paged_attn_module().get_kv_cache_shape(
            num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Delegate block swapping to the selected PagedAttention helper."""
        _get_paged_attn_module().swap_blocks(src_kv_cache, dst_kv_cache,
                                             src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Delegate block copying to the selected PagedAttention helper."""
        _get_paged_attn_module().copy_blocks(kv_caches, src_to_dists)
108
109
110


@dataclass
class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
    """Metadata for ROCmFlashAttentionBackend.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, they should be stored in tensors. Those tensors have to be
    updated from the `CUDAGraphRunner.forward` API.
    """
    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if it is a decoding.
    seq_lens: Optional[List[int]]
    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int
    # Maximum sequence length among decode batch. 0 if there are prefill
    # requests only.
    max_decode_seq_len: int

    # Whether or not cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|

    # Maximum query length in the batch. None for decoding.
    max_query_len: Optional[int] = None

    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    query_start_loc: Optional[torch.Tensor] = None

    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor] = None

    # (batch_size,) A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor] = None

    # Max number of query tokens among requests in the batch.
    max_decode_query_len: Optional[int] = None

    # Lazily built results of `prefill_metadata` / `decode_metadata`.
    _cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None
    _cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None

    # Begin encoder attn & enc/dec cross-attn fields...

    # Encoder sequence lengths representation
    encoder_seq_lens: Optional[List[int]] = None
    encoder_seq_lens_tensor: Optional[torch.Tensor] = None

    # Maximum sequence length among encoder sequences
    max_encoder_seq_len: Optional[int] = None

    # Number of tokens input to encoder
    num_encoder_tokens: Optional[int] = None

    # Cross-attention memory-mapping data structures: slot mapping
    # and block tables
    cross_slot_mapping: Optional[torch.Tensor] = None
    cross_block_tables: Optional[torch.Tensor] = None

    @property
    def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
        """Metadata restricted to the prefill part of the batch.

        Returns None when the batch has no prefills. Built on first access
        and cached in `_cached_prefill_metadata`. Prefill requests occupy the
        first `num_prefills` sequences / `num_prefill_tokens` tokens, so
        per-sequence and per-token fields are sliced from the front.
        """
        if self.num_prefills == 0:
            return None

        if self._cached_prefill_metadata is not None:
            return self._cached_prefill_metadata

        assert self.seq_lens is not None
        assert self.seq_lens_tensor is not None
        assert self.block_tables is not None

        self._cached_prefill_metadata = ROCmFlashAttentionMetadata(
            num_prefills=self.num_prefills,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
            multi_modal_placeholder_index_maps=self.
            multi_modal_placeholder_index_maps,
            enable_kv_scales_calculation=self.enable_kv_scales_calculation,
            seq_lens=self.seq_lens[:self.num_prefills],
            seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_seq_len=0,
            query_start_loc=None if self.query_start_loc is None else
            self.query_start_loc[:self.num_prefills + 1],
            seq_start_loc=None if self.seq_start_loc is None else
            self.seq_start_loc[:self.num_prefills + 1],
            context_lens_tensor=None if self.context_lens_tensor is None else
            self.context_lens_tensor[:self.num_prefills],
            block_tables=self.block_tables[:self.num_prefills],
            use_cuda_graph=False,
            # Begin encoder & cross attn fields below...
            encoder_seq_lens=self.encoder_seq_lens,
            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
            max_encoder_seq_len=self.max_encoder_seq_len,
            cross_slot_mapping=self.cross_slot_mapping,
            cross_block_tables=self.cross_block_tables)
        return self._cached_prefill_metadata

    @property
    def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
        """Metadata restricted to the decode part of the batch.

        Returns None when there are no decode tokens. Built on first access
        and cached in `_cached_decode_metadata`. Decode requests follow the
        prefills, so per-sequence fields are sliced from `num_prefills` and
        per-token fields from `num_prefill_tokens`.
        """
        if self.num_decode_tokens == 0:
            return None

        if self._cached_decode_metadata is not None:
            return self._cached_decode_metadata
        assert self.block_tables is not None
        assert self.seq_lens_tensor is not None

        self._cached_decode_metadata = ROCmFlashAttentionMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=True,
            seq_lens=None,
            seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
            max_query_len=None,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            query_start_loc=None,
            seq_start_loc=None,
            context_lens_tensor=None,
            block_tables=self.block_tables[self.num_prefills:],
            use_cuda_graph=self.use_cuda_graph,
            # Begin encoder & cross attn fields below...
            encoder_seq_lens=self.encoder_seq_lens,
            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
            max_encoder_seq_len=self.max_encoder_seq_len,
            cross_slot_mapping=self.cross_slot_mapping,
            cross_block_tables=self.cross_block_tables)
        # NOTE: query_start_loc is not carried over to the decode metadata
        # (it is None above), so there is nothing to rebase here. If it is
        # ever propagated, slice it from `self.num_prefills` and subtract its
        # first element so it starts at 0, e.g. in
        # tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
        return self._cached_decode_metadata

    def advance_step(self,
                     model_input: "ModelInputForGPUWithSamplingMetadata",
                     sampled_token_ids: Optional[torch.Tensor],
                     block_size: int,
                     num_seqs: int,
                     num_queries: int,
                     turn_prefills_into_decodes: bool = False):
        """
        Update metadata in-place to advance one decode step.

        Args:
            model_input: model input whose `input_tokens` and
                `input_positions` are handed to `ops.advance_step_flashattn`
                (presumably updated in place there — confirm in the op).
            sampled_token_ids: tokens sampled in the previous step.
            block_size: KV-cache block size.
            num_seqs: number of sequences, possibly padded up to a captured
                cuda-graph batch size.
            num_queries: actual number of requests in the batch.
            turn_prefills_into_decodes: must be False; multi-step combined
                with chunked prefill is not supported by this backend.
        """

        assert not turn_prefills_into_decodes, (
            "Chunked prefill is not supported with rocm_flash_attn yet. "
            "turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill "
            "specific parameter.")

        # When using cudagraph, the num_seqs is padded to the next captured
        # batch size, but num_queries tracks the actual number of requests in
        # the batch. For --enforce-eager mode, num_seqs == num_queries
        if num_seqs != num_queries:
            assert num_seqs > num_queries
            assert self.use_cuda_graph

        assert self.num_prefills == 0
        assert self.num_prefill_tokens == 0
        assert self.num_decode_tokens == num_seqs
        assert self.slot_mapping.shape == (num_seqs, )

        assert self.seq_lens is not None
        assert len(self.seq_lens) == num_seqs
        assert self.seq_lens_tensor is not None
        assert self.seq_lens_tensor.shape == (num_seqs, )
        assert self.max_query_len == 1
        assert self.max_prefill_seq_len == 0
        assert self.max_decode_seq_len == max(self.seq_lens)

        assert self.query_start_loc is not None
        assert self.query_start_loc.shape == (num_queries + 1, )
        assert self.seq_start_loc is not None
        assert self.seq_start_loc.shape == (num_seqs + 1, )

        assert self.context_lens_tensor is not None
        assert self.context_lens_tensor.shape == (num_queries, )

        assert self.block_tables is not None
        assert self.block_tables.shape[0] == num_seqs

        # Update query lengths. Note that we update only queries and not seqs,
        # since tensors may be padded due to captured cuda graph batch size
        for i in range(num_queries):
            self.seq_lens[i] += 1
        self.max_decode_seq_len = max(self.seq_lens)

        ops.advance_step_flashattn(num_seqs=num_seqs,
                                   num_queries=num_queries,
                                   block_size=block_size,
                                   input_tokens=model_input.input_tokens,
                                   sampled_token_ids=sampled_token_ids,
                                   input_positions=model_input.input_positions,
                                   seq_lens=self.seq_lens_tensor,
                                   slot_mapping=self.slot_mapping,
                                   block_tables=self.block_tables)

326

327
328
329
330
331
332
class ROCmFlashAttentionMetadataBuilder(
        CommonMetadataBuilder[ROCmFlashAttentionMetadata]):
    """Metadata builder for the ROCm flash-attention backend.

    Adds no logic of its own on top of `CommonMetadataBuilder`; it only sets
    `_metadata_cls` to `ROCmFlashAttentionMetadata` (presumably the class the
    common builder instantiates — confirm in `CommonMetadataBuilder`).
    """

    _metadata_cls = ROCmFlashAttentionMetadata


333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
def _make_alibi_bias(alibi_slopes: torch.Tensor,
                     dtype: torch.dtype,
                     seq_lens: Optional[List[int]],
                     make_attn_mask: bool = True) -> List[torch.Tensor]:
    attn_biases = []
    if seq_lens:
        for seq_len in seq_lens:
            bias = torch.arange(seq_len, dtype=dtype)
            # NOTE(zhuohan): HF uses
            #     `bias = bias[None, :].repeat(seq_len, 1)`
            # here. We find that both biases give the same results, but
            # the bias below more accurately follows the original ALiBi
            # paper.
            bias = bias[None, :] - bias[:, None]

            num_heads = alibi_slopes.shape[0]
            bias = bias[None, :].repeat(
                (num_heads, 1, 1)).to(alibi_slopes.device)
            bias.mul_(alibi_slopes[:, None, None])
            if make_attn_mask:
                inf_mask = torch.empty(
                    (1, seq_len, seq_len),
                    dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1).to(
                        alibi_slopes.device)
                attn_biases.append((bias + inf_mask).to(dtype))
            else:
                attn_biases.append(bias.to(dtype))

    return attn_biases


364
365
366
367
368
369
370
371
372
373
374
375
def _seq_start_loc(seq_lens: List[int], like: torch.Tensor) -> torch.Tensor:
    """Return ``[0, l0, l0 + l1, ...]`` for *seq_lens* as a tensor placed on
    *like*'s device with *like*'s dtype."""
    return torch.tensor(list(itertools.accumulate(seq_lens, initial=0)),
                        device=like.device,
                        dtype=like.dtype)


def _get_seq_len_block_table_args(
    attn_metadata: ROCmFlashAttentionMetadata,
    attn_type: str,
) -> tuple:
    '''
    Select the sequence-length attributes of attn_metadata that the attention
    kernel needs; the choice depends on the type of attention operation.

    Decoder attn -> decoder self-attention fields (prefill lengths), causal
    Encoder/decoder cross-attn -> decoder sequence lengths for the query,
        encoder sequence lengths for the key, non-causal
    Encoder attn -> encoder sequence lengths for query and key, non-causal
    Encoder-only attn -> prefill sequence lengths for query and key, with
        bidirectional (non-causal) attention

    Arguments:

    * attn_metadata: Attention metadata structure associated with attention op
    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention, encoder-only

    Returns:

    A 6-tuple ``(query_start_loc, max_query_len, key_start_loc, max_key_len,
    seq_lens, causal_mask)``. The start-loc tensors hold cumulative offsets
    ``[0, l0, l0 + l1, ...]``; ``seq_lens`` is the per-sequence length list
    of the query side; ``causal_mask`` is the causal masking flag.

    Raises:

    * AttributeError: if attn_type is not one of the four types above.
    '''

    if attn_type == AttentionType.ENCODER:
        assert attn_metadata.encoder_seq_lens is not None
        assert attn_metadata.encoder_seq_lens_tensor is not None
        query_seq_start_loc = _seq_start_loc(
            attn_metadata.encoder_seq_lens,
            attn_metadata.encoder_seq_lens_tensor)
        causal_mask = False

        # No block tables associated with encoder attention
        return (query_seq_start_loc, attn_metadata.max_encoder_seq_len,
                query_seq_start_loc, attn_metadata.max_encoder_seq_len,
                attn_metadata.encoder_seq_lens, causal_mask)

    elif attn_type == AttentionType.ENCODER_ONLY:
        # For encoder-only models, we use the prefill sequence lengths
        assert attn_metadata.seq_lens is not None
        assert attn_metadata.seq_lens_tensor is not None
        query_seq_start_loc = _seq_start_loc(attn_metadata.seq_lens,
                                             attn_metadata.seq_lens_tensor)
        max_seq_len = attn_metadata.max_prefill_seq_len
        # Encoder-only models typically use bidirectional attention
        causal_mask = False

        return (query_seq_start_loc, max_seq_len, query_seq_start_loc,
                max_seq_len, attn_metadata.seq_lens, causal_mask)

    elif attn_type == AttentionType.DECODER:
        # Decoder self-attention. Always reports max_prefill_seq_len; this
        # presumably runs on prefill metadata only — confirm at call sites.
        assert attn_metadata.seq_lens is not None
        assert attn_metadata.seq_lens_tensor is not None
        query_seq_start_loc = _seq_start_loc(attn_metadata.seq_lens,
                                             attn_metadata.seq_lens_tensor)
        max_seq_len = attn_metadata.max_prefill_seq_len
        causal_mask = True

        return (query_seq_start_loc, max_seq_len, query_seq_start_loc,
                max_seq_len, attn_metadata.seq_lens, causal_mask)

    elif attn_type == AttentionType.ENCODER_DECODER:
        # Query offsets come from the decoder lengths, so take device/dtype
        # from the decoder tensor (previously swapped with the encoder one).
        assert attn_metadata.seq_lens is not None
        assert attn_metadata.seq_lens_tensor is not None
        query_start_loc = _seq_start_loc(attn_metadata.seq_lens,
                                         attn_metadata.seq_lens_tensor)

        # Key offsets come from the encoder lengths.
        assert attn_metadata.encoder_seq_lens is not None
        assert attn_metadata.encoder_seq_lens_tensor is not None
        key_seq_start_loc = _seq_start_loc(
            attn_metadata.encoder_seq_lens,
            attn_metadata.encoder_seq_lens_tensor)
        causal_mask = False

        # Enc/dec cross-attention KVs match encoder sequence length;
        # cross-attention utilizes special "cross" block tables
        return (query_start_loc, attn_metadata.max_prefill_seq_len,
                key_seq_start_loc, attn_metadata.max_encoder_seq_len,
                attn_metadata.seq_lens, causal_mask)
    else:
        raise AttributeError(f"Invalid attention type {str(attn_type)}")


460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
class ROCmFlashAttentionImpl(AttentionImpl):
    """
    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prompt_tokens -------------->|
    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|

    Otherwise, the layout is as follows:
    |<------------------ num_generation_tokens (M) ----------------->|
    |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.
475
476
477
478
479
480
481
482
483

    If chunked prefill is enabled, prefill tokens and decode tokens can be
    batched together in a flattened 1D query.

    |<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->|	
    |<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->|

    Currently, cuda graph is disabled for chunked prefill, meaning there's no
    padding between prefill and decode tokens.
484
485
486
487
488
489
490
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        use_irope: bool = False,
    ) -> None:
        """Configure the ROCm attention implementation for one layer.

        The prefill kernel is chosen as follows:
        - `envs.VLLM_USE_TRITON_FLASH_ATTN` truthy -> Triton FA
          (`self.triton_attn_func`);
        - otherwise, if the platform reports device capability 90 and
          `flash_attn` is importable -> CK FA (`self.fa_attn_func`);
        - otherwise -> naive SDPA (`self.sdpa_attn_func`).

        Raises:
            NotImplementedError: if KV sharing is requested.
            ValueError: if blocksparse attention is requested, the head size
                is not supported by the PagedAttention module, or logits soft
                capping is requested on the Triton or naive path.
        """
        if kv_sharing_target_layer_name is not None:
            raise NotImplementedError("KV sharing is not supported in V0.")
        if use_irope:
            # Previously this condition was warned about twice (once via
            # warning_once, once via warning); a single warning suffices.
            logger.warning_once(
                "Using irope in ROCm Flash Attention is not supported yet, it "
                "will fall back to global attention for long context.")
        if blocksparse_params is not None:
            raise ValueError(
                "ROCmFlashAttention does not support blocksparse attention.")

        # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
        self.logits_soft_cap = (0.0 if logits_soft_cap is None else
                                logits_soft_cap)
        self.attn_type = attn_type
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        # (-1, -1) encodes "no sliding window".
        self.sliding_window = ((sliding_window, sliding_window)
                               if sliding_window is not None else (-1, -1))
        self.kv_cache_dtype = kv_cache_dtype

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        self.paged_attn_module = _get_paged_attn_module()
        supported_head_sizes = self.paged_attn_module.get_supported_head_sizes(
        )
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

        self.use_naive_attn = False
        # NOTE: Allow for switching between Triton and CK. Defaulting to triton.
        self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
        if self.use_triton_flash_attn:
            if logits_soft_cap is not None:
                raise ValueError(
                    "ROCm Triton FlashAttention does not support attention"
                    " logits soft capping."
                    " please try using the ROCm CK "
                    "FA backend instead by setting the env var "
                    "`VLLM_USE_TRITON_FLASH_ATTN=0`")

            from vllm.attention.ops.triton_flash_attention import (  # noqa: F401
                triton_attention)
            self.triton_attn_func = triton_attention
            logger.debug("Using Triton FA in ROCmBackend")
            if self.sliding_window != (-1, -1):
                logger.warning("ROCm Triton FA does not currently support "
                               "sliding window attention. If using half "
                               "precision, please try using the ROCm CK "
                               "FA backend instead by setting the env var "
                               "`VLLM_USE_TRITON_FLASH_ATTN=0`")
        else:
            # if not using triton, navi3x/navi21/navi10 do not use flash-attn
            # either
            if not current_platform.has_device_capability(90):
                self.use_naive_attn = True
            else:
                try:
                    from flash_attn import flash_attn_varlen_func  # noqa: F401
                except ModuleNotFoundError:
                    self.use_naive_attn = True
                else:
                    self.fa_attn_func = flash_attn_varlen_func
                    logger.debug("Using CK FA in ROCmBackend")

            if self.use_naive_attn:
                if logits_soft_cap is not None:
                    raise ValueError(
                        "ROCm Naive FlashAttention does not support "
                        "attention logits soft capping.")

                self.sdpa_attn_func = _sdpa_attention
                logger.debug("Using naive (SDPA) attention in ROCmBackend")

        self.aiter_kv_scales_initialized = False
        # Look the config up once instead of calling the getter twice.
        vllm_config = get_current_vllm_config()
        self.force_fp8_attention = (
            vllm_config is not None
            and vllm_config.model_config.override_attention_dtype == "fp8")
592

593
594
595
596
597
598
599
600
    def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
        """Repeat every KV head `n_rep` times along dim 1.

        For a 3-D ``(tokens, n_kv_heads, head_dim)`` input this matches
        ``torch.repeat_interleave(x, dim=1, repeats=n_rep)``, implemented via
        a broadcast ``expand`` followed by a ``reshape``.
        """
        num_tokens, num_kv, dim = x.shape
        widened = x.unsqueeze(2).expand(num_tokens, num_kv, n_rep, dim)
        return widened.reshape(num_tokens, num_kv * n_rep, dim)

601
602
603
604
605
606
607
608
609
    def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
                                     group_shape: tuple[int, int]):
        if self.use_triton_flash_attn:
            return dtype == current_platform.fp8_dtype(
            ) and static and group_shape == (-1, -1)  # per-tensor

        # Only supported in the Triton backend
        return False

610
611
    def forward(
        self,
612
        layer: AttentionLayer,
613
614
615
616
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
617
        attn_metadata: ROCmFlashAttentionMetadata,
618
        output: Optional[torch.Tensor] = None,
619
        output_scale: Optional[torch.Tensor] = None,
620
621
622
    ) -> torch.Tensor:
        """Forward pass with FlashAttention and PagedAttention.

623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
        For decoder-only models: query, key and value must be non-None.

        For encoder/decoder models:
        * ROCmFlashAttentionImpl.forward() may be invoked for both self- and 
            cross-attention layers.
        * For self-attention: query, key and value must be non-None.
        * For cross-attention:
            * Query must be non-None
            * During prefill, key and value must be non-None; key and value
              get cached for use during decode.
            * During decode, key and value may be None, since:
              (1) key and value tensors were cached during prefill, and
              (2) cross-attention key and value tensors do not grow during
                  decode
        
        A note on how the attn_type (attention type enum) argument impacts
        attention forward() behavior:
    
            * DECODER: normal decoder-only behavior;
                use decoder self-attention block table
            * ENCODER: no KV caching; pass encoder sequence
                attributes (encoder_seq_lens/encoder_seq_lens_tensor/
                max_encoder_seq_len) to kernel, in lieu of decoder
                sequence attributes (seq_lens/seq_lens_tensor/max_seq_len)
            * ENCODER_DECODER: cross-attention behavior;
                use cross-attention block table for caching KVs derived
                from encoder hidden states; since KV sequence lengths
                will match encoder sequence lengths, pass encoder sequence
                attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/
                max_encoder_seq_len)
653
654
            * ENCODER_ONLY: bidirectional attention with no KV caching;
                use prefill sequence attributes
655

656
657
658
659
660
        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
661
662
                NOTE: kv_cache will be an empty tensor with shape [0]
                for profiling run.
663
            attn_metadata: Metadata for attention.
664
665
666
667
            attn_type: Select attention type, between encoder attention,
                       decoder self-attention, or encoder/decoder cross-
                       attention. Defaults to decoder self-attention,
                       which is the vLLM default generally
668
669
670
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
671
672
        assert output is not None, "Output tensor must be provided."

673
674
675
676
677
        if output_scale is not None and not self.use_triton_flash_attn:
            raise NotImplementedError(
                "fused output quantization only supported for Triton"
                " implementation in ROCMFlashAttentionImpl for now")

678
        query = query.view(-1, self.num_heads, self.head_size)
679
680
681
682
683
684
        if key is not None:
            assert value is not None
            key = key.view(-1, self.num_kv_heads, self.head_size)
            value = value.view(-1, self.num_kv_heads, self.head_size)
        else:
            assert value is None
685

686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
        paged_attn = self.paged_attn_module

        # Reshaping kv tensors is required for AITER paged attention kernel
        # because it works on a different tensor shape,
        # when the size of one element is one byte (int8/fp8 dtypes).
        # This reshaping is only required on the first forward call
        # and the kv cache must not be empty.
        if (is_rocm_aiter_paged_attn_enabled() and kv_cache.dtype.itemsize == 1
                and not self.aiter_kv_scales_initialized
                and kv_cache.shape != torch.Size([0])):
            num_blocks = kv_cache.shape[1]
            block_size = kv_cache.shape[2] // (self.num_kv_heads *
                                               self.head_size)
            k_scale = torch.empty((self.num_kv_heads, num_blocks * block_size),
                                  dtype=torch.float32,
                                  device=kv_cache.device)
            v_scale = torch.empty((self.num_kv_heads, num_blocks * block_size),
                                  dtype=torch.float32,
                                  device=kv_cache.device)
            self.aiter_kv_scales_initialized = True
            k_scale.fill_(layer._k_scale.item())
            v_scale.fill_(layer._v_scale.item())
            layer._k_scale = k_scale
            layer._v_scale = v_scale

711
712
713
714
715
        # Only update KV cache for decoder self-attention
        # and encoder-decoder cross-attention
        if self.attn_type not in [
                AttentionType.ENCODER, AttentionType.ENCODER_ONLY
        ] and kv_cache.numel() > 0:
716
            key_cache, value_cache = paged_attn.split_kv_cache(
717
718
                kv_cache, self.num_kv_heads, self.head_size)

719
720
721
722
723
            if key is not None and value is not None:
                # Reshape the input keys and values and store them in the
                # cache. If kv_cache is not provided, the new key and value
                # tensors are not cached. This happens during the initial
                # memory profiling run.
724
                paged_attn.write_to_paged_cache(
725
726
727
728
729
730
731
732
733
734
735
736
737
738
                    key,
                    value,
                    key_cache,
                    value_cache,
                    attn_metadata.slot_mapping
                    if self.attn_type != AttentionType.ENCODER_DECODER else
                    attn_metadata.cross_slot_mapping,
                    self.kv_cache_dtype,
                    layer._k_scale,
                    layer._v_scale,
                )

        if self.attn_type != AttentionType.ENCODER:
            num_prefill_tokens = attn_metadata.num_prefill_tokens
739
740
741
        elif self.attn_type == AttentionType.ENCODER_ONLY:
            # For encoder-only models, all tokens are processed in one go
            num_prefill_tokens = query.shape[0]
742
743
744
        else:
            assert attn_metadata.num_encoder_tokens is not None
            num_prefill_tokens = attn_metadata.num_encoder_tokens
745
746
747
748
749
750

        # Query for decode. KV is not needed because it is already cached.
        decode_query = query[num_prefill_tokens:]
        # QKV for prefill.
        query = query[:num_prefill_tokens]

751
752
753
754
        # For encoder-only and encoder models,
        # we process all tokens at once
        # For decoder and encoder-decoder,
        # we may need to limit key/value to prefill tokens
755
        if key is not None and value is not None \
756
757
            and self.attn_type not in [AttentionType.ENCODER_DECODER,
                                       AttentionType.ENCODER_ONLY]:
758
759
            key = key[:num_prefill_tokens]
            value = value[:num_prefill_tokens]
760
761

        if prefill_meta := attn_metadata.prefill_metadata:
762
            # Prompt run.
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
            # normal attention and DECODER
            if self.attn_type == AttentionType.DECODER and (
                    kv_cache.numel() == 0 or prefill_meta.block_tables is None
                    or prefill_meta.block_tables.numel() == 0):
                (query_seq_start_loc, query_max_seq_len, key_seq_start_loc,
                 key_max_seq_len, seq_lens,
                 causal_mask) = (prefill_meta.seq_start_loc,
                                 prefill_meta.max_prefill_seq_len,
                                 prefill_meta.seq_start_loc,
                                 prefill_meta.max_prefill_seq_len,
                                 attn_metadata.seq_lens, True)
            # prefix-enabled attention and ENCODER/ENCODER_DECODER
            else:
                (query_seq_start_loc, query_max_seq_len, key_seq_start_loc,
                 key_max_seq_len, seq_lens,
                 causal_mask) = _get_seq_len_block_table_args(
                     prefill_meta, self.attn_type)
            # Prompt run.
781
            if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0:
782
783
784
                # triton attention
                # When block_tables are not filled, it means q and k are the
                # prompt, and they have the same length.
785
                attn_masks = None
786
                if self.use_triton_flash_attn:
787
788
789
790
                    if self.alibi_slopes is not None:
                        attn_masks = _make_alibi_bias(
                            self.alibi_slopes,
                            query.dtype,
791
                            seq_lens,
792
                            make_attn_mask=causal_mask)  # type: ignore
793

794
795
                    use_fp8_scales = (layer._q_scale and layer._k_scale
                                      and layer._v_scale and layer._prob_scale
796
797
798
                                      and (self.kv_cache_dtype == "fp8"
                                           or self.force_fp8_attention))

799
                    full_scales = (
800
801
802
                        layer._q_scale.item(), layer._k_scale.item(),
                        layer._v_scale.item(),
                        layer._prob_scale.item()) if use_fp8_scales else None
803
                    self.triton_attn_func(
804
805
806
                        query,
                        key,
                        value,
807
                        output[:num_prefill_tokens],
808
809
810
811
812
                        query_seq_start_loc,
                        key_seq_start_loc,
                        query_max_seq_len,
                        key_max_seq_len,
                        causal_mask,
813
                        self.scale,
814
815
                        attn_masks[0][None]
                        if attn_masks is not None else None,
816
                        full_scales,
817
                        output_scale,
818
819
                    )
                elif self.use_naive_attn:
820
821
822
823
                    if self.num_kv_heads != self.num_heads:
                        # Interleave for MQA workaround.
                        key = self.repeat_kv(key, self.num_queries_per_kv)
                        value = self.repeat_kv(value, self.num_queries_per_kv)
824
825
826
827
828
                    if self.alibi_slopes is not None:
                        attn_masks = _make_alibi_bias(
                            self.alibi_slopes,
                            query.dtype,
                            attn_metadata.seq_lens,
829
                            make_attn_mask=causal_mask)  # type: ignore
830
831
832
833
                    query = query.movedim(0, query.dim() - 2)
                    key = key.movedim(0, key.dim() - 2)
                    value = value.movedim(0, value.dim() - 2)
                    # sdpa math backend attention
834
                    self.sdpa_attn_func(
835
836
837
                        query,
                        key,
                        value,
838
                        output[:num_prefill_tokens],
839
840
                        query_seq_start_loc,
                        num_prefill_tokens,
841
842
                        self.num_heads,
                        self.head_size,
843
                        self.scale,
844
                        attn_masks,
845
                    )
846
                else:
847
848
                    # upstream FA does not support an output arg, copy
                    output[:num_prefill_tokens] = self.fa_attn_func(
849
850
851
                        q=query,
                        k=key,
                        v=value,
852
853
                        cu_seqlens_q=query_seq_start_loc,
                        cu_seqlens_k=key_seq_start_loc,
854
                        max_seqlen_q=prefill_meta.max_prefill_seq_len,
855
                        max_seqlen_k=key_max_seq_len,
856
                        softmax_scale=self.scale,
857
                        causal=causal_mask,
858
859
                        window_size=self.sliding_window,
                        alibi_slopes=self.alibi_slopes,
860
                        softcap=self.logits_soft_cap,
861
                    )
862

863
            else:
864
865
866
                # prefix-enabled attention -
                # not applicable for encoder-only models
                if self.attn_type != AttentionType.ENCODER_ONLY:
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
                    output[:num_prefill_tokens] = paged_attn.forward_prefix(
                        query,
                        key,
                        value,
                        self.kv_cache_dtype,
                        key_cache,
                        value_cache,
                        prefill_meta.block_tables,
                        prefill_meta.query_start_loc,
                        prefill_meta.seq_lens_tensor,
                        prefill_meta.max_query_len,
                        self.alibi_slopes,
                        self.sliding_window[0],
                        layer._k_scale,
                        layer._v_scale,
                    )
883
884
885
        # Skip decode phase for encoder-only models
        if (decode_meta := attn_metadata.decode_metadata) and (
                self.attn_type != AttentionType.ENCODER_ONLY):
886
            # Decoding run.
887
888
889
890
            # Whether to use rocm custom paged attention or not
            num_seqs, num_heads, head_size = decode_query.shape
            block_size = value_cache.shape[3]
            gqa_ratio = num_heads // self.num_kv_heads
891
            use_custom = use_rocm_custom_paged_attention(
892
                decode_query.dtype, head_size, block_size, gqa_ratio,
893
894
                decode_meta.max_decode_seq_len, self.sliding_window,
                self.kv_cache_dtype, self.alibi_slopes)
895

896
            if use_custom:
897
898
899
900
                max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type
                               != AttentionType.ENCODER_DECODER else
                               decode_meta.max_encoder_seq_len)
                assert max_seq_len is not None
901
902
903
904
                max_num_partitions = (
                    (max_seq_len + _PARTITION_SIZE_ROCM - 1) //
                    _PARTITION_SIZE_ROCM)
                assert _PARTITION_SIZE_ROCM % block_size == 0
905
906
                tmp_output = torch.empty(
                    size=(num_seqs, num_heads, max_num_partitions, head_size),
907
                    dtype=query.dtype,
908
909
910
911
912
913
914
915
                    device=output.device,
                )
                exp_sums = torch.empty(
                    size=(num_seqs, num_heads, max_num_partitions),
                    dtype=torch.float32,
                    device=output.device,
                )
                max_logits = torch.empty_like(exp_sums)
916
917

                query_start_loc = None
918
                ops.paged_attention_rocm(
919
                    output[num_prefill_tokens:],
920
921
922
923
924
925
926
927
                    exp_sums,
                    max_logits,
                    tmp_output,
                    decode_query,
                    key_cache,
                    value_cache,
                    self.num_kv_heads,
                    self.scale,
928
929
930
931
932
933
                    decode_meta.block_tables
                    if self.attn_type != AttentionType.ENCODER_DECODER else
                    decode_meta.cross_block_tables,
                    decode_meta.seq_lens_tensor
                    if self.attn_type != AttentionType.ENCODER_DECODER else
                    decode_meta.encoder_seq_lens_tensor,
934
                    query_start_loc,
935
936
937
938
                    block_size,
                    max_seq_len,
                    self.alibi_slopes,
                    self.kv_cache_dtype,
939
940
                    layer._k_scale,
                    layer._v_scale,
941
                    output_scale,
942
943
                )
            else:
944
945
946
947
948
949
950
951
                # PagedAttention does not support fused quant, manually quantize
                if output_scale is None:
                    out_pa = output[num_prefill_tokens:]
                else:
                    out_pa = torch.empty_like(output[num_prefill_tokens:],
                                              dtype=query.dtype)

                out_pa[:] = paged_attn.forward_decode(
952
953
954
                    decode_query,
                    key_cache,
                    value_cache,
955
956
957
958
959
960
961
962
963
                    decode_meta.block_tables
                    if self.attn_type != AttentionType.ENCODER_DECODER else
                    decode_meta.cross_block_tables,
                    decode_meta.seq_lens_tensor
                    if self.attn_type != AttentionType.ENCODER_DECODER else
                    decode_meta.encoder_seq_lens_tensor,
                    decode_meta.max_decode_seq_len
                    if self.attn_type != AttentionType.ENCODER_DECODER else
                    decode_meta.max_encoder_seq_len,
964
965
966
967
                    self.kv_cache_dtype,
                    self.num_kv_heads,
                    self.scale,
                    self.alibi_slopes,
968
969
                    layer._k_scale,
                    layer._v_scale,
970
                )
971

972
973
974
975
976
977
978
979
                # Manually perform quantization
                if output_scale is not None:
                    out_uq = out_pa.view(-1, self.num_heads * self.head_size)
                    out_q = output.view(-1, self.num_heads * self.head_size)
                    ops.scaled_fp8_quant(out_uq,
                                         output_scale,
                                         output=out_q[num_prefill_tokens:])

980
        # Reshape the output tensor.
981
        return output.view(-1, self.num_heads * self.head_size)
982
983


984
def _sdpa_attention(
985
986
987
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
988
989
    output: torch.Tensor,
    seq_lens: torch.Tensor,
990
991
992
    num_tokens: int,
    num_heads: int,
    head_size: int,
993
    scale: float,
994
    attn_masks: Optional[List[torch.Tensor]] = None,
995
996
) -> torch.Tensor:
    start = 0
997
998
999
    assert output.shape == (num_tokens, num_heads, head_size)
    assert output.dtype == query.dtype
    assert output.device == query.device
1000

1001
    for i, seq_len in enumerate(seq_lens):
1002
        end = start + seq_len
cyyever's avatar
cyyever committed
1003
1004
        with torch.nn.attention.sdpa_kernel(
                torch.nn.attention.SDPBackend.MATH):
1005
1006
1007
1008
1009
            sub_out = torch.nn.functional.scaled_dot_product_attention(
                query[:, start:end, :],
                key[:, start:end, :],
                value[:, start:end, :],
                dropout_p=0.0,
1010
1011
                is_causal=attn_masks is None,
                attn_mask=attn_masks[i] if attn_masks else None,
1012
1013
1014
                scale=scale).movedim(query.dim() - 2, 0)
            output[start:end, :, :] = sub_out
            start = end
1015

1016
    return output