rocm_flash_attn.py 41.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Attention layer ROCm GPUs."""
4
import itertools
5
from dataclasses import dataclass
6
from functools import cache
7
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
8
9
10

import torch

11
import vllm.envs as envs
12
from vllm import _custom_ops as ops
13
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
14
                                              AttentionLayer,
15
                                              AttentionMetadata, AttentionType)
16
17
from vllm.attention.backends.utils import (CommonAttentionState,
                                           CommonMetadataBuilder)
18
19
20
from vllm.attention.ops.paged_attn import (PagedAttention,
                                           PagedAttentionMetadata)
from vllm.logger import init_logger
21
from vllm.platforms import current_platform
22
from vllm.platforms.rocm import use_rocm_custom_paged_attention
23

24
25
26
if TYPE_CHECKING:
    from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata

27
logger = init_logger(__name__)

# Module-level size constant for the ROCm attention path. Its consumer is not
# visible in this part of the file.
# NOTE(review): presumably the paged-attention partition size in tokens —
# confirm against its use in forward().
_PARTITION_SIZE_ROCM = 256
29

30

31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
@cache
def is_rocm_aiter_paged_attn_enabled() -> bool:
    """Report whether AITER paged attention should be used on ROCm.

    Both ``envs.VLLM_ROCM_USE_AITER_PAGED_ATTN`` and
    ``envs.VLLM_ROCM_USE_AITER`` must be truthy. The result is memoized by
    ``functools.cache``, so the flags are read only on the first call.
    """
    # Parenthesized rather than continued with trailing backslashes: a
    # backslash on the last line would silently join whatever line follows
    # (e.g. the next decorator) onto this expression.
    return (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
            and envs.VLLM_ROCM_USE_AITER)


@cache
def _get_paged_attn_module() -> PagedAttention:
    """
    Return the PagedAttention helper (from `attention/ops`) shared by
    `ROCmFlashAttentionImpl` and `ROCmFlashAttentionBackend`.

    Selection depends on `is_rocm_aiter_paged_attn_enabled()`:
    - enabled  -> an `AITERPagedAttention` instance;
    - disabled -> a plain `PagedAttention` instance.

    Memoized with `functools.cache`, so every caller gets the same object.
    """
    if not is_rocm_aiter_paged_attn_enabled():
        return PagedAttention()

    # Deferred import: the AITER module is only needed (and only imported)
    # when the AITER paged-attention flags are switched on.
    from vllm.attention.ops.rocm_aiter_paged_attn import AITERPagedAttention
    return AITERPagedAttention()


57
class ROCmFlashAttentionBackend(AttentionBackend):
    """Attention backend descriptor for ROCm GPUs.

    Names the implementation, metadata, builder and state classes used by
    the ROCm flash-attention path. KV-cache layout and block swap/copy are
    delegated to whichever PagedAttention helper `_get_paged_attn_module()`
    selects.
    """

    accept_output_buffer: bool = True

    @staticmethod
    def get_name() -> str:
        return "ROCM_FLASH"

    @staticmethod
    def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]:
        return ROCmFlashAttentionImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return ROCmFlashAttentionMetadata

    @staticmethod
    def get_builder_cls() -> Type["ROCmFlashAttentionMetadataBuilder"]:
        return ROCmFlashAttentionMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        # The layout is owned by the selected PagedAttention variant.
        return _get_paged_attn_module().get_kv_cache_shape(
            num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        _get_paged_attn_module().swap_blocks(src_kv_cache, dst_kv_cache,
                                             src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        _get_paged_attn_module().copy_blocks(kv_caches, src_to_dists)


@dataclass
class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
    """Metadata for ROCmFlashAttentionBackend.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, it should be stored in tensor. The tensor has to be
    updated from `CUDAGraphRunner.forward` API.
    """
    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if it is a decoding.
    seq_lens: Optional[List[int]]
    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int
    # Maximum sequence length among decode batch. 0 if there are prefill
    # requests only.
    max_decode_seq_len: int

    # Whether or not cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|

    # Maximum query length in the batch. None for decoding.
    max_query_len: Optional[int] = None

    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    query_start_loc: Optional[torch.Tensor] = None

    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor] = None

    # (batch_size,) A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor] = None

    # Max number of query tokens among request in the batch.
    max_decode_query_len: Optional[int] = None

    # Lazily built sub-views returned (and cached) by the `prefill_metadata`
    # and `decode_metadata` properties.
    _cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None
    _cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None

    # Begin encoder attn & enc/dec cross-attn fields...

    # Encoder sequence lengths representation
    encoder_seq_lens: Optional[List[int]] = None
    encoder_seq_lens_tensor: Optional[torch.Tensor] = None

    # Maximum sequence length among encoder sequences
    max_encoder_seq_len: Optional[int] = None

    # Number of tokens input to encoder
    num_encoder_tokens: Optional[int] = None

    # Cross-attention memory-mapping data structures: slot mapping
    # and block tables
    cross_slot_mapping: Optional[torch.Tensor] = None
    cross_block_tables: Optional[torch.Tensor] = None

    @property
    def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
        """Metadata restricted to the prefill part of the batch.

        Returns None when the batch has no prefills. Otherwise builds (on
        first access) and caches a new instance whose per-sequence fields are
        the leading `num_prefills` entries and whose `slot_mapping` is the
        leading `num_prefill_tokens` entries. Encoder / cross-attention
        fields are passed through unchanged.
        """
        if self.num_prefills == 0:
            return None

        if self._cached_prefill_metadata is not None:
            return self._cached_prefill_metadata

        assert self.seq_lens is not None
        assert self.seq_lens_tensor is not None
        assert self.block_tables is not None

        self._cached_prefill_metadata = ROCmFlashAttentionMetadata(
            num_prefills=self.num_prefills,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
            multi_modal_placeholder_index_maps=(
                self.multi_modal_placeholder_index_maps),
            enable_kv_scales_calculation=self.enable_kv_scales_calculation,
            seq_lens=self.seq_lens[:self.num_prefills],
            seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_seq_len=0,
            query_start_loc=None if self.query_start_loc is None else
            self.query_start_loc[:self.num_prefills + 1],
            seq_start_loc=None if self.seq_start_loc is None else
            self.seq_start_loc[:self.num_prefills + 1],
            context_lens_tensor=None if self.context_lens_tensor is None else
            self.context_lens_tensor[:self.num_prefills],
            block_tables=self.block_tables[:self.num_prefills],
            use_cuda_graph=False,
            # Begin encoder & cross attn fields below...
            encoder_seq_lens=self.encoder_seq_lens,
            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
            max_encoder_seq_len=self.max_encoder_seq_len,
            cross_slot_mapping=self.cross_slot_mapping,
            cross_block_tables=self.cross_block_tables)
        return self._cached_prefill_metadata

    @property
    def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
        """Metadata restricted to the decode part of the batch.

        Returns None when there are no decode tokens. Otherwise builds (on
        first access) and caches a new instance whose tensor fields are the
        entries after the prefill part. Python scalars such as
        `max_decode_seq_len` are copied at build time.
        """
        if self.num_decode_tokens == 0:
            return None

        if self._cached_decode_metadata is not None:
            return self._cached_decode_metadata
        assert self.block_tables is not None
        assert self.seq_lens_tensor is not None

        self._cached_decode_metadata = ROCmFlashAttentionMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=True,
            seq_lens=None,
            seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
            max_query_len=None,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            query_start_loc=None,
            seq_start_loc=None,
            context_lens_tensor=None,
            block_tables=self.block_tables[self.num_prefills:],
            use_cuda_graph=self.use_cuda_graph,
            # Begin encoder & cross attn fields below...
            encoder_seq_lens=self.encoder_seq_lens,
            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
            max_encoder_seq_len=self.max_encoder_seq_len,
            cross_slot_mapping=self.cross_slot_mapping,
            cross_block_tables=self.cross_block_tables)
        # Batch may be composed of prefill|decodes, adjust query start indices
        # to refer to the start of decodes when the two are split apart.
        # E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
        # NOTE(review): query_start_loc is constructed as None just above, so
        # this rebase is currently never taken; it only matters if the decode
        # view starts carrying a sliced query_start_loc.
        if self._cached_decode_metadata.query_start_loc is not None:
            qs = self._cached_decode_metadata.query_start_loc
            self._cached_decode_metadata.query_start_loc = qs - qs[0]
        return self._cached_decode_metadata

    def advance_step(self,
                     model_input: "ModelInputForGPUWithSamplingMetadata",
                     sampled_token_ids: Optional[torch.Tensor],
                     block_size: int,
                     num_seqs: int,
                     num_queries: int,
                     turn_prefills_into_decodes: bool = False):
        """
        Update metadata in-place to advance one decode step.

        Args:
            model_input: Provides `input_tokens` and `input_positions`, which
                are handed to `ops.advance_step_flashattn` together with this
                object's `seq_lens_tensor`, `slot_mapping` and `block_tables`.
            sampled_token_ids: Tokens sampled in the previous step, forwarded
                to the custom op.
            block_size: KV-cache block size, forwarded to the custom op.
            num_seqs: Batch size, possibly padded up to a captured cuda-graph
                batch size.
            num_queries: Number of real requests (`<= num_seqs`); only these
                entries of `seq_lens` are incremented.
            turn_prefills_into_decodes: Multi-Step + Chunked-Prefill flag;
                must be False for this backend.
        """

        assert not turn_prefills_into_decodes, \
            ("Chunked prefill is not supported with rocm_flash_attn yet. "
             "turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill "
             "specific parameter.")

        # When using cudagraph, the num_seqs is padded to the next captured
        # batch sized, but num_queries tracks the actual number of requests in
        # the batch. For --enforce-eager mode, num_seqs == num_queries
        if num_seqs != num_queries:
            assert num_seqs > num_queries
            assert self.use_cuda_graph

        assert self.num_prefills == 0
        assert self.num_prefill_tokens == 0
        assert self.num_decode_tokens == num_seqs
        assert self.slot_mapping.shape == (num_seqs, )

        assert self.seq_lens is not None
        assert len(self.seq_lens) == num_seqs
        assert self.seq_lens_tensor is not None
        assert self.seq_lens_tensor.shape == (num_seqs, )
        assert self.max_query_len == 1
        assert self.max_prefill_seq_len == 0
        assert self.max_decode_seq_len == max(self.seq_lens)

        assert self.query_start_loc is not None
        assert self.query_start_loc.shape == (num_queries + 1, )
        assert self.seq_start_loc is not None
        assert self.seq_start_loc.shape == (num_seqs + 1, )

        assert self.context_lens_tensor is not None
        assert self.context_lens_tensor.shape == (num_queries, )

        assert self.block_tables is not None
        assert self.block_tables.shape[0] == num_seqs

        # Update query lengths. Note that we update only queries and not seqs,
        # since tensors may be padded due to captured cuda graph batch size
        for i in range(num_queries):
            self.seq_lens[i] += 1
        self.max_decode_seq_len = max(self.seq_lens)
        # NOTE(review): a previously built `_cached_decode_metadata` copied
        # the old integer `max_decode_seq_len` and is not refreshed here;
        # confirm callers do not reuse that cached view across steps.

        # NOTE(review): the custom op is opaque from this file; presumably it
        # writes the next input tokens/positions and advances seq_lens_tensor
        # and slot_mapping in place — confirm against the kernel.
        ops.advance_step_flashattn(num_seqs=num_seqs,
                                   num_queries=num_queries,
                                   block_size=block_size,
                                   input_tokens=model_input.input_tokens,
                                   sampled_token_ids=sampled_token_ids,
                                   input_positions=model_input.input_positions,
                                   seq_lens=self.seq_lens_tensor,
                                   slot_mapping=self.slot_mapping,
                                   block_tables=self.block_tables)

325

326
327
328
329
330
331
class ROCmFlashAttentionMetadataBuilder(
        CommonMetadataBuilder[ROCmFlashAttentionMetadata]):
    """Metadata builder for the ROCm flash-attention backend.

    All logic is inherited from `CommonMetadataBuilder`. This subclass only
    sets `_metadata_cls` (presumably the class the base builder instantiates
    — confirm in `vllm.attention.backends.utils`).
    """

    _metadata_cls = ROCmFlashAttentionMetadata


332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
def _make_alibi_bias(alibi_slopes: torch.Tensor,
                     dtype: torch.dtype,
                     seq_lens: Optional[List[int]],
                     make_attn_mask: bool = True) -> List[torch.Tensor]:
    attn_biases = []
    if seq_lens:
        for seq_len in seq_lens:
            bias = torch.arange(seq_len, dtype=dtype)
            # NOTE(zhuohan): HF uses
            #     `bias = bias[None, :].repeat(seq_len, 1)`
            # here. We find that both biases give the same results, but
            # the bias below more accurately follows the original ALiBi
            # paper.
            bias = bias[None, :] - bias[:, None]

            num_heads = alibi_slopes.shape[0]
            bias = bias[None, :].repeat(
                (num_heads, 1, 1)).to(alibi_slopes.device)
            bias.mul_(alibi_slopes[:, None, None])
            if make_attn_mask:
                inf_mask = torch.empty(
                    (1, seq_len, seq_len),
                    dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1).to(
                        alibi_slopes.device)
                attn_biases.append((bias + inf_mask).to(dtype))
            else:
                attn_biases.append(bias.to(dtype))

    return attn_biases


363
364
365
366
367
368
369
370
371
372
373
374
def _cumulative_start_loc(seq_lens: List[int],
                          like: torch.Tensor) -> torch.Tensor:
    """Return ``[0, l0, l0 + l1, ...]`` on ``like``'s device and dtype."""
    return torch.tensor(list(itertools.accumulate(seq_lens, initial=0)),
                        device=like.device,
                        dtype=like.dtype)


def _get_seq_len_block_table_args(
    attn_metadata: ROCmFlashAttentionMetadata,
    attn_type: str,
) -> tuple:
    '''
    Select the sequence-length attributes the attention kernel needs. Which
    fields of attn_metadata are used depends on the attention type:

    Decoder attn -> decoder self-attention fields, causal
    Encoder/decoder cross-attn -> decoder lengths for queries, encoder
        lengths for keys, non-causal
    Encoder attn -> encoder sequence lengths for both, non-causal
    Encoder-only attn -> prefill sequence lengths for both, with
        bidirectional (non-causal) attention

    Arguments:

    * attn_metadata: Attention metadata structure associated with attention op
    * attn_type: encoder attention, decoder self-attention,
                encoder/decoder cross-attention, encoder-only

    Returns:

    A 6-tuple ``(query_start_loc, max_query_len, key_start_loc, max_key_len,
    seq_lens, causal_mask)``:

    * query_start_loc / key_start_loc: cumulative start offsets
      ``[0, l0, l0 + l1, ...]`` for the query / key sequences
    * max_query_len / max_key_len: the matching maximum sequence lengths
    * seq_lens: Python list of sequence lengths (encoder lengths for
      ENCODER, decoder/prefill lengths otherwise)
    * causal_mask: True only for decoder self-attention

    Raises:

    * AttributeError: if attn_type is not a recognized AttentionType
    '''

    if attn_type == AttentionType.ENCODER:
        assert attn_metadata.encoder_seq_lens is not None
        assert attn_metadata.encoder_seq_lens_tensor is not None
        query_seq_start_loc = _cumulative_start_loc(
            attn_metadata.encoder_seq_lens,
            attn_metadata.encoder_seq_lens_tensor)
        causal_mask = False

        # No block tables associated with encoder attention
        return (query_seq_start_loc, attn_metadata.max_encoder_seq_len,
                query_seq_start_loc, attn_metadata.max_encoder_seq_len,
                attn_metadata.encoder_seq_lens, causal_mask)

    elif attn_type == AttentionType.ENCODER_ONLY:
        # For encoder-only models, we use the prefill sequence lengths
        assert attn_metadata.seq_lens is not None
        assert attn_metadata.seq_lens_tensor is not None
        query_seq_start_loc = _cumulative_start_loc(
            attn_metadata.seq_lens, attn_metadata.seq_lens_tensor)
        max_seq_len = attn_metadata.max_prefill_seq_len
        # Encoder-only models typically use bidirectional attention
        causal_mask = False

        return (query_seq_start_loc, max_seq_len, query_seq_start_loc,
                max_seq_len, attn_metadata.seq_lens, causal_mask)

    elif attn_type == AttentionType.DECODER:
        # Decoder self-attention: queries and keys share the same (prefill)
        # sequence lengths and the same max length.
        assert attn_metadata.seq_lens is not None
        assert attn_metadata.seq_lens_tensor is not None
        query_seq_start_loc = _cumulative_start_loc(
            attn_metadata.seq_lens, attn_metadata.seq_lens_tensor)
        max_seq_len = attn_metadata.max_prefill_seq_len
        causal_mask = True

        return (query_seq_start_loc, max_seq_len, query_seq_start_loc,
                max_seq_len, attn_metadata.seq_lens, causal_mask)

    elif attn_type == AttentionType.ENCODER_DECODER:
        assert attn_metadata.seq_lens is not None
        assert attn_metadata.seq_lens_tensor is not None
        assert attn_metadata.encoder_seq_lens is not None
        assert attn_metadata.encoder_seq_lens_tensor is not None
        # Each offsets tensor is built on the device/dtype of its own lengths
        # tensor (queries: decoder lengths, keys: encoder lengths).
        query_start_loc = _cumulative_start_loc(attn_metadata.seq_lens,
                                                attn_metadata.seq_lens_tensor)
        key_seq_start_loc = _cumulative_start_loc(
            attn_metadata.encoder_seq_lens,
            attn_metadata.encoder_seq_lens_tensor)
        causal_mask = False

        # Enc/dec cross-attention KVs match encoder sequence length;
        # cross-attention utilizes special "cross" block tables
        return (query_start_loc, attn_metadata.max_prefill_seq_len,
                key_seq_start_loc, attn_metadata.max_encoder_seq_len,
                attn_metadata.seq_lens, causal_mask)
    else:
        raise AttributeError(f"Invalid attention type {str(attn_type)}")


459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
class ROCmFlashAttentionImpl(AttentionImpl):
    """
    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prompt_tokens -------------->|
    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|

    Otherwise, the layout is as follows:
    |<------------------ num_generation_tokens (M) ----------------->|
    |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.
474
475
476
477
478
479
480
481
482

    If chunked prefill is enabled, prefill tokens and decode tokens can be
    batched together in a flattened 1D query.

    |<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->|	
    |<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->|

    Currently, cuda graph is disabled for chunked prefill, meaning there's no
    padding between prefill and decode tokens.
483
484
485
486
487
488
489
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
        use_irope: bool = False,
    ) -> None:
        """Configure the ROCm attention implementation for one layer.

        Stores the head/scale/window configuration. It then picks the flash
        attention kernel based on `envs.VLLM_USE_TRITON_FLASH_ATTN` and the
        device:

        * Triton FA when the env flag is set;
        * otherwise CK FA (`flash_attn.flash_attn_varlen_func`) if the device
          has capability >= 90 and `flash_attn` is importable;
        * otherwise a naive SDPA fallback (`use_naive_attn = True`).

        Raises:
            ValueError: if blocksparse attention is requested, if `head_size`
                is not supported by the selected PagedAttention module, or if
                `logits_soft_cap` is given with the Triton or naive paths.
        """
        # irope is accepted for interface compatibility but not implemented.
        # Warn once here (previously two separate warnings were emitted).
        if use_irope:
            logger.warning_once(
                "Using irope in ROCm Flash Attention is not supported yet, it "
                "will fall back to global attention for long context.")
        if blocksparse_params is not None:
            raise ValueError(
                "ROCmFlashAttention does not support blocksparse attention.")

        if logits_soft_cap is None:
            # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
            self.logits_soft_cap = 0.0
        else:
            self.logits_soft_cap = logits_soft_cap
        self.attn_type = attn_type
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        # (-1, -1) encodes "no sliding window".
        self.sliding_window = ((sliding_window, sliding_window)
                               if sliding_window is not None else (-1, -1))
        self.kv_cache_dtype = kv_cache_dtype

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        self.paged_attn_module = _get_paged_attn_module()
        supported_head_sizes = self.paged_attn_module.get_supported_head_sizes(
        )

        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

        self.use_naive_attn = False
        # NOTE: Allow for switching between Triton and CK. Defaulting to triton.
        self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
        if self.use_triton_flash_attn:
            if logits_soft_cap is not None:
                raise ValueError(
                    "ROCm Triton FlashAttention does not support attention"
                    " logits soft capping."
                    " please try using the ROCm CK "
                    "FA backend instead by setting the env var "
                    "`VLLM_USE_TRITON_FLASH_ATTN=0`")

            from vllm.attention.ops.triton_flash_attention import (  # noqa: F401
                triton_attention)
            self.triton_attn_func = triton_attention
            logger.debug("Using Triton FA in ROCmBackend")
            if self.sliding_window != (-1, -1):
                logger.warning("ROCm Triton FA does not currently support "
                               "sliding window attention. If using half "
                               "precision, please try using the ROCm CK "
                               "FA backend instead by setting the env var "
                               "`VLLM_USE_TRITON_FLASH_ATTN=0`")
        else:
            # if not using triton, navi3x/navi21/navi10 do not use flash-attn
            # either
            if not current_platform.has_device_capability(90):
                self.use_naive_attn = True
            else:
                try:
                    from flash_attn import flash_attn_varlen_func  # noqa: F401
                    self.fa_attn_func = flash_attn_varlen_func
                    logger.debug("Using CK FA in ROCmBackend")
                except ModuleNotFoundError:
                    self.use_naive_attn = True

            if self.use_naive_attn:
                if logits_soft_cap is not None:
                    raise ValueError(
                        "ROCm Naive FlashAttention does not support "
                        "attention logits soft capping.")

                self.sdpa_attn_func = _sdpa_attention
                logger.debug("Using naive (SDPA) attention in ROCmBackend")

        # Set to True by forward() once it has replaced the scalar
        # layer._k_scale / layer._v_scale with per-slot scale tensors for the
        # AITER 1-byte (int8/fp8) KV-cache path.
        self.aiter_kv_scales_initialized = False

585
586
587
588
589
590
591
592
593
594
    def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
        """torch.repeat_interleave(x, dim=1, repeats=n_rep)

        Turns ``(tokens, n_kv_heads, head_dim)`` into
        ``(tokens, n_kv_heads * n_rep, head_dim)``, with the `n_rep` copies
        of each KV head placed next to each other.
        """
        num_tokens, kv_heads, dim = x.shape
        # Add a repeat axis right after the head axis, broadcast it without
        # copying, then fold it into the head axis.
        widened = x.unsqueeze(2).expand(num_tokens, kv_heads, n_rep, dim)
        return widened.reshape(num_tokens, kv_heads * n_rep, dim)

    def forward(
        self,
595
        layer: AttentionLayer,
596
597
598
599
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
600
        attn_metadata: ROCmFlashAttentionMetadata,
601
        output: Optional[torch.Tensor] = None,
602
603
604
    ) -> torch.Tensor:
        """Forward pass with FlashAttention and PagedAttention.

605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
        For decoder-only models: query, key and value must be non-None.

        For encoder/decoder models:
        * ROCmFlashAttentionImpl.forward() may be invoked for both self- and 
            cross-attention layers.
        * For self-attention: query, key and value must be non-None.
        * For cross-attention:
            * Query must be non-None
            * During prefill, key and value must be non-None; key and value
              get cached for use during decode.
            * During decode, key and value may be None, since:
              (1) key and value tensors were cached during prefill, and
              (2) cross-attention key and value tensors do not grow during
                  decode
        
        A note on how the attn_type (attention type enum) argument impacts
        attention forward() behavior:
    
            * DECODER: normal decoder-only behavior;
                use decoder self-attention block table
            * ENCODER: no KV caching; pass encoder sequence
                attributes (encoder_seq_lens/encoder_seq_lens_tensor/
                max_encoder_seq_len) to kernel, in lieu of decoder
                sequence attributes (seq_lens/seq_lens_tensor/max_seq_len)
            * ENCODER_DECODER: cross-attention behavior;
                use cross-attention block table for caching KVs derived
                from encoder hidden states; since KV sequence lengths
                will match encoder sequence lengths, pass encoder sequence
                attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/
                max_encoder_seq_len)
635
636
            * ENCODER_ONLY: bidirectional attention with no KV caching;
                use prefill sequence attributes
637

638
639
640
641
642
        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
643
644
                NOTE: kv_cache will be an empty tensor with shape [0]
                for profiling run.
645
            attn_metadata: Metadata for attention.
646
647
648
649
            attn_type: Select attention type, between encoder attention,
                       decoder self-attention, or encoder/decoder cross-
                       attention. Defaults to decoder self-attention,
                       which is the vLLM default generally
650
651
652
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
653
654
        assert output is not None, "Output tensor must be provided."

655
        query = query.view(-1, self.num_heads, self.head_size)
656
657
658
659
660
661
        if key is not None:
            assert value is not None
            key = key.view(-1, self.num_kv_heads, self.head_size)
            value = value.view(-1, self.num_kv_heads, self.head_size)
        else:
            assert value is None
662

663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
        paged_attn = self.paged_attn_module

        # Reshaping kv tensors is required for AITER paged attention kernel
        # because it works on a different tensor shape,
        # when the size of one element is one byte (int8/fp8 dtypes).
        # This reshaping is only required on the first forward call
        # and the kv cache must not be empty.
        if (is_rocm_aiter_paged_attn_enabled() and kv_cache.dtype.itemsize == 1
                and not self.aiter_kv_scales_initialized
                and kv_cache.shape != torch.Size([0])):
            num_blocks = kv_cache.shape[1]
            block_size = kv_cache.shape[2] // (self.num_kv_heads *
                                               self.head_size)
            k_scale = torch.empty((self.num_kv_heads, num_blocks * block_size),
                                  dtype=torch.float32,
                                  device=kv_cache.device)
            v_scale = torch.empty((self.num_kv_heads, num_blocks * block_size),
                                  dtype=torch.float32,
                                  device=kv_cache.device)
            self.aiter_kv_scales_initialized = True
            k_scale.fill_(layer._k_scale.item())
            v_scale.fill_(layer._v_scale.item())
            layer._k_scale = k_scale
            layer._v_scale = v_scale

688
689
690
691
692
        # Only update KV cache for decoder self-attention
        # and encoder-decoder cross-attention
        if self.attn_type not in [
                AttentionType.ENCODER, AttentionType.ENCODER_ONLY
        ] and kv_cache.numel() > 0:
693
            key_cache, value_cache = paged_attn.split_kv_cache(
694
695
                kv_cache, self.num_kv_heads, self.head_size)

696
697
698
699
700
            if key is not None and value is not None:
                # Reshape the input keys and values and store them in the
                # cache. If kv_cache is not provided, the new key and value
                # tensors are not cached. This happens during the initial
                # memory profiling run.
701
                paged_attn.write_to_paged_cache(
702
703
704
705
706
707
708
709
710
711
712
713
714
715
                    key,
                    value,
                    key_cache,
                    value_cache,
                    attn_metadata.slot_mapping
                    if self.attn_type != AttentionType.ENCODER_DECODER else
                    attn_metadata.cross_slot_mapping,
                    self.kv_cache_dtype,
                    layer._k_scale,
                    layer._v_scale,
                )

        if self.attn_type != AttentionType.ENCODER:
            num_prefill_tokens = attn_metadata.num_prefill_tokens
716
717
718
        elif self.attn_type == AttentionType.ENCODER_ONLY:
            # For encoder-only models, all tokens are processed in one go
            num_prefill_tokens = query.shape[0]
719
720
721
        else:
            assert attn_metadata.num_encoder_tokens is not None
            num_prefill_tokens = attn_metadata.num_encoder_tokens
722
723
724
725
726
727

        # Query for decode. KV is not needed because it is already cached.
        decode_query = query[num_prefill_tokens:]
        # QKV for prefill.
        query = query[:num_prefill_tokens]

728
729
730
731
        # For encoder-only and encoder models,
        # we process all tokens at once
        # For decoder and encoder-decoder,
        # we may need to limit key/value to prefill tokens
732
        if key is not None and value is not None \
733
734
            and self.attn_type not in [AttentionType.ENCODER_DECODER,
                                       AttentionType.ENCODER_ONLY]:
735
736
            key = key[:num_prefill_tokens]
            value = value[:num_prefill_tokens]
737
738

        if prefill_meta := attn_metadata.prefill_metadata:
739
            # Prompt run.
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
            # normal attention and DECODER
            if self.attn_type == AttentionType.DECODER and (
                    kv_cache.numel() == 0 or prefill_meta.block_tables is None
                    or prefill_meta.block_tables.numel() == 0):
                (query_seq_start_loc, query_max_seq_len, key_seq_start_loc,
                 key_max_seq_len, seq_lens,
                 causal_mask) = (prefill_meta.seq_start_loc,
                                 prefill_meta.max_prefill_seq_len,
                                 prefill_meta.seq_start_loc,
                                 prefill_meta.max_prefill_seq_len,
                                 attn_metadata.seq_lens, True)
            # prefix-enabled attention and ENCODER/ENCODER_DECODER
            else:
                (query_seq_start_loc, query_max_seq_len, key_seq_start_loc,
                 key_max_seq_len, seq_lens,
                 causal_mask) = _get_seq_len_block_table_args(
                     prefill_meta, self.attn_type)
            # Prompt run.
758
            if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0:
759
760
761
                # triton attention
                # When block_tables are not filled, it means q and k are the
                # prompt, and they have the same length.
762
                attn_masks = None
763
                if self.use_triton_flash_attn:
764
765
766
767
                    if self.alibi_slopes is not None:
                        attn_masks = _make_alibi_bias(
                            self.alibi_slopes,
                            query.dtype,
768
                            seq_lens,
769
                            make_attn_mask=causal_mask)  # type: ignore
770
771
772
773
                    use_fp8_scales = (layer._q_scale and layer._k_scale
                                      and layer._v_scale and layer._prob_scale
                                      and self.kv_cache_dtype == "fp8")
                    full_scales = (
774
775
776
                        layer._q_scale.item(), layer._k_scale.item(),
                        layer._v_scale.item(),
                        layer._prob_scale.item()) if use_fp8_scales else None
777
                    self.triton_attn_func(
778
779
780
                        query,
                        key,
                        value,
781
                        output[:num_prefill_tokens],
782
783
784
785
786
                        query_seq_start_loc,
                        key_seq_start_loc,
                        query_max_seq_len,
                        key_max_seq_len,
                        causal_mask,
787
                        self.scale,
788
789
                        attn_masks[0][None]
                        if attn_masks is not None else None,
790
                        full_scales,
791
792
                    )
                elif self.use_naive_attn:
793
794
795
796
                    if self.num_kv_heads != self.num_heads:
                        # Interleave for MQA workaround.
                        key = self.repeat_kv(key, self.num_queries_per_kv)
                        value = self.repeat_kv(value, self.num_queries_per_kv)
797
798
799
800
801
                    if self.alibi_slopes is not None:
                        attn_masks = _make_alibi_bias(
                            self.alibi_slopes,
                            query.dtype,
                            attn_metadata.seq_lens,
802
                            make_attn_mask=causal_mask)  # type: ignore
803
804
805
806
                    query = query.movedim(0, query.dim() - 2)
                    key = key.movedim(0, key.dim() - 2)
                    value = value.movedim(0, value.dim() - 2)
                    # sdpa math backend attention
807
                    self.sdpa_attn_func(
808
809
810
                        query,
                        key,
                        value,
811
                        output[:num_prefill_tokens],
812
813
                        query_seq_start_loc,
                        num_prefill_tokens,
814
815
                        self.num_heads,
                        self.head_size,
816
                        self.scale,
817
                        attn_masks,
818
                    )
819
                else:
820
821
                    # upstream FA does not support an output arg, copy
                    output[:num_prefill_tokens] = self.fa_attn_func(
822
823
824
                        q=query,
                        k=key,
                        v=value,
825
826
                        cu_seqlens_q=query_seq_start_loc,
                        cu_seqlens_k=key_seq_start_loc,
827
                        max_seqlen_q=prefill_meta.max_prefill_seq_len,
828
                        max_seqlen_k=key_max_seq_len,
829
                        softmax_scale=self.scale,
830
                        causal=causal_mask,
831
832
                        window_size=self.sliding_window,
                        alibi_slopes=self.alibi_slopes,
833
                        softcap=self.logits_soft_cap,
834
                    )
835

836
            else:
837
838
839
                # prefix-enabled attention -
                # not applicable for encoder-only models
                if self.attn_type != AttentionType.ENCODER_ONLY:
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
                    output[:num_prefill_tokens] = paged_attn.forward_prefix(
                        query,
                        key,
                        value,
                        self.kv_cache_dtype,
                        key_cache,
                        value_cache,
                        prefill_meta.block_tables,
                        prefill_meta.query_start_loc,
                        prefill_meta.seq_lens_tensor,
                        prefill_meta.max_query_len,
                        self.alibi_slopes,
                        self.sliding_window[0],
                        layer._k_scale,
                        layer._v_scale,
                    )
856
857
858
        # Skip decode phase for encoder-only models
        if (decode_meta := attn_metadata.decode_metadata) and (
                self.attn_type != AttentionType.ENCODER_ONLY):
859
            # Decoding run.
860
861
862
863
            # Whether to use rocm custom paged attention or not
            num_seqs, num_heads, head_size = decode_query.shape
            block_size = value_cache.shape[3]
            gqa_ratio = num_heads // self.num_kv_heads
864
            use_custom = use_rocm_custom_paged_attention(
865
                decode_query.dtype, head_size, block_size, gqa_ratio,
866
867
                decode_meta.max_decode_seq_len, self.sliding_window,
                self.kv_cache_dtype, self.alibi_slopes)
868
            if use_custom:
869
870
871
872
                max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type
                               != AttentionType.ENCODER_DECODER else
                               decode_meta.max_encoder_seq_len)
                assert max_seq_len is not None
873
874
875
876
                max_num_partitions = (
                    (max_seq_len + _PARTITION_SIZE_ROCM - 1) //
                    _PARTITION_SIZE_ROCM)
                assert _PARTITION_SIZE_ROCM % block_size == 0
877
878
879
880
881
882
883
884
885
886
887
                tmp_output = torch.empty(
                    size=(num_seqs, num_heads, max_num_partitions, head_size),
                    dtype=output.dtype,
                    device=output.device,
                )
                exp_sums = torch.empty(
                    size=(num_seqs, num_heads, max_num_partitions),
                    dtype=torch.float32,
                    device=output.device,
                )
                max_logits = torch.empty_like(exp_sums)
888
889

                query_start_loc = None
890
                ops.paged_attention_rocm(
891
                    output[num_prefill_tokens:],
892
893
894
895
896
897
898
899
                    exp_sums,
                    max_logits,
                    tmp_output,
                    decode_query,
                    key_cache,
                    value_cache,
                    self.num_kv_heads,
                    self.scale,
900
901
902
903
904
905
                    decode_meta.block_tables
                    if self.attn_type != AttentionType.ENCODER_DECODER else
                    decode_meta.cross_block_tables,
                    decode_meta.seq_lens_tensor
                    if self.attn_type != AttentionType.ENCODER_DECODER else
                    decode_meta.encoder_seq_lens_tensor,
906
                    query_start_loc,
907
908
909
910
                    block_size,
                    max_seq_len,
                    self.alibi_slopes,
                    self.kv_cache_dtype,
911
912
                    layer._k_scale,
                    layer._v_scale,
913
914
                )
            else:
915
                output[num_prefill_tokens:] = paged_attn.forward_decode(
916
917
918
                    decode_query,
                    key_cache,
                    value_cache,
919
920
921
922
923
924
925
926
927
                    decode_meta.block_tables
                    if self.attn_type != AttentionType.ENCODER_DECODER else
                    decode_meta.cross_block_tables,
                    decode_meta.seq_lens_tensor
                    if self.attn_type != AttentionType.ENCODER_DECODER else
                    decode_meta.encoder_seq_lens_tensor,
                    decode_meta.max_decode_seq_len
                    if self.attn_type != AttentionType.ENCODER_DECODER else
                    decode_meta.max_encoder_seq_len,
928
929
930
931
                    self.kv_cache_dtype,
                    self.num_kv_heads,
                    self.scale,
                    self.alibi_slopes,
932
933
                    layer._k_scale,
                    layer._v_scale,
934
                )
935
936

        # Reshape the output tensor.
937
        return output.view(-1, self.num_heads * self.head_size)
938
939


940
def _sdpa_attention(
941
942
943
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
944
945
    output: torch.Tensor,
    seq_lens: torch.Tensor,
946
947
948
    num_tokens: int,
    num_heads: int,
    head_size: int,
949
    scale: float,
950
    attn_masks: Optional[List[torch.Tensor]] = None,
951
952
) -> torch.Tensor:
    start = 0
953
954
955
    assert output.shape == (num_tokens, num_heads, head_size)
    assert output.dtype == query.dtype
    assert output.device == query.device
956

957
    for i, seq_len in enumerate(seq_lens):
958
        end = start + seq_len
cyyever's avatar
cyyever committed
959
960
        with torch.nn.attention.sdpa_kernel(
                torch.nn.attention.SDPBackend.MATH):
961
962
963
964
965
            sub_out = torch.nn.functional.scaled_dot_product_attention(
                query[:, start:end, :],
                key[:, start:end, :],
                value[:, start:end, :],
                dropout_p=0.0,
966
967
                is_causal=attn_masks is None,
                attn_mask=attn_masks[i] if attn_masks else None,
968
969
970
                scale=scale).movedim(query.dim() - 2, 0)
            output[start:end, :, :] = sub_out
            start = end
971

972
    return output