rocm_flash_attn.py 38.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
"""Attention layer ROCm GPUs."""
3
import itertools
4
from dataclasses import dataclass
5
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
6
7
8

import torch

9
import vllm.envs as envs
10
from vllm import _custom_ops as ops
11
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
12
                                              AttentionLayer,
13
                                              AttentionMetadata, AttentionType)
14
15
from vllm.attention.backends.utils import (CommonAttentionState,
                                           CommonMetadataBuilder)
16
17
18
from vllm.attention.ops.paged_attn import (PagedAttention,
                                           PagedAttentionMetadata)
from vllm.logger import init_logger
19
from vllm.platforms import current_platform
20
from vllm.platforms.rocm import use_rocm_custom_paged_attention
21

22
23
24
if TYPE_CHECKING:
    from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata

25
logger = init_logger(__name__)
26
_PARTITION_SIZE_ROCM = 256
27

28
29
30

class ROCmFlashAttentionBackend(AttentionBackend):
    """Attention backend entry point for ROCm GPUs.

    Exposes the implementation, metadata, metadata-builder and state classes
    used by this backend. KV-cache shape computation and block swap/copy are
    forwarded unchanged to ``PagedAttention``.
    """

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "ROCM_FLASH"

    @staticmethod
    def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]:
        """Return the class that performs the attention computation."""
        return ROCmFlashAttentionImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        """Return the per-batch attention metadata class."""
        return ROCmFlashAttentionMetadata

    @staticmethod
    def get_builder_cls() -> Type["ROCmFlashAttentionMetadataBuilder"]:
        """Return the class that builds ``ROCmFlashAttentionMetadata``."""
        return ROCmFlashAttentionMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Return the (shared) attention state class."""
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache tensor shape; layout is owned by PagedAttention."""
        shape = PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                  num_kv_heads, head_size)
        return shape

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Move cache blocks between two KV caches via PagedAttention."""
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Copy cache blocks within the given KV caches via PagedAttention."""
        PagedAttention.copy_blocks(kv_caches, src_to_dists)


@dataclass
class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
    """Metadata for ROCmFlashAttentionBackend.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, it should be stored in tensor. The tensor has to be
    updated from `CUDAGraphRunner.forward` API.
    """
    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if it is a decoding.
    seq_lens: Optional[List[int]]
    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int
    # Maximum sequence length among decode batch. 0 if there are prefill
    # requests only.
    max_decode_seq_len: int

    # Whether or not if cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|

    # Maximum query length in the batch. None for decoding.
    max_query_len: Optional[int] = None

    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    query_start_loc: Optional[torch.Tensor] = None

    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor] = None

    # (batch_size,) A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor] = None

    # Max number of query tokens among request in the batch.
    max_decode_query_len: Optional[int] = None

    # Lazily-built views returned by `prefill_metadata` / `decode_metadata`.
    _cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None
    _cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None

    # Begin encoder attn & enc/dec cross-attn fields...

    # Encoder sequence lengths representation
    encoder_seq_lens: Optional[List[int]] = None
    encoder_seq_lens_tensor: Optional[torch.Tensor] = None

    # Maximum sequence length among encoder sequences
    max_encoder_seq_len: Optional[int] = None

    # Number of tokens input to encoder
    num_encoder_tokens: Optional[int] = None

    # Cross-attention memory-mapping data structures: slot mapping
    # and block tables
    cross_slot_mapping: Optional[torch.Tensor] = None
    cross_block_tables: Optional[torch.Tensor] = None

    @property
    def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
        """Metadata restricted to the prefill part of the batch.

        Returns None when the batch has no prefills. The result is built on
        first access and cached on this object. Prefill requests occupy the
        first `num_prefills` entries (and first `num_prefill_tokens` slots).
        """
        if self.num_prefills == 0:
            return None

        if self._cached_prefill_metadata is not None:
            return self._cached_prefill_metadata

        assert self.seq_lens is not None
        assert self.seq_lens_tensor is not None
        assert self.block_tables is not None

        self._cached_prefill_metadata = ROCmFlashAttentionMetadata(
            num_prefills=self.num_prefills,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
            multi_modal_placeholder_index_maps=self.
            multi_modal_placeholder_index_maps,
            enable_kv_scales_calculation=self.enable_kv_scales_calculation,
            seq_lens=self.seq_lens[:self.num_prefills],
            seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_seq_len=0,
            query_start_loc=None if self.query_start_loc is None else
            self.query_start_loc[:self.num_prefills + 1],
            seq_start_loc=None if self.seq_start_loc is None else
            self.seq_start_loc[:self.num_prefills + 1],
            context_lens_tensor=None if self.context_lens_tensor is None else
            self.context_lens_tensor[:self.num_prefills],
            block_tables=self.block_tables[:self.num_prefills],
            use_cuda_graph=False,
            # Begin encoder & cross attn fields below...
            encoder_seq_lens=self.encoder_seq_lens,
            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
            max_encoder_seq_len=self.max_encoder_seq_len,
            cross_slot_mapping=self.cross_slot_mapping,
            cross_block_tables=self.cross_block_tables)
        return self._cached_prefill_metadata

    @property
    def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
        """Metadata restricted to the decode part of the batch.

        Returns None when the batch has no decode tokens. The result is built
        on first access and cached on this object. Decode entries follow the
        prefill entries, so tensors are sliced from `num_prefills` (per
        sequence) or `num_prefill_tokens` (per token) onward.
        """
        if self.num_decode_tokens == 0:
            return None

        if self._cached_decode_metadata is not None:
            return self._cached_decode_metadata
        assert self.block_tables is not None
        assert self.seq_lens_tensor is not None

        self._cached_decode_metadata = ROCmFlashAttentionMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=True,
            seq_lens=None,
            seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
            max_query_len=None,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            query_start_loc=None,
            seq_start_loc=None,
            context_lens_tensor=None,
            block_tables=self.block_tables[self.num_prefills:],
            use_cuda_graph=self.use_cuda_graph,
            # Begin encoder & cross attn fields below...
            encoder_seq_lens=self.encoder_seq_lens,
            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
            max_encoder_seq_len=self.max_encoder_seq_len,
            cross_slot_mapping=self.cross_slot_mapping,
            cross_block_tables=self.cross_block_tables)
        # Batch may be composed of prefill|decodes, adjust query start indices
        # to refer to the start of decodes when the two are split apart.
        # E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
        # NOTE(review): query_start_loc is passed as None above, so this
        # branch is currently never taken; it is kept so the re-basing stays
        # correct if decode metadata starts carrying query_start_loc.
        if self._cached_decode_metadata.query_start_loc is not None:
            qs = self._cached_decode_metadata.query_start_loc
            self._cached_decode_metadata.query_start_loc = qs - qs[0]
        return self._cached_decode_metadata

    def advance_step(self,
                     model_input: "ModelInputForGPUWithSamplingMetadata",
                     sampled_token_ids: Optional[torch.Tensor],
                     block_size: int,
                     num_seqs: int,
                     num_queries: int,
                     turn_prefills_into_decodes: bool = False):
        """
        Update metadata in-place to advance one decode step.

        Only valid for pure-decode batches (asserted below). Increments the
        first `num_queries` entries of `seq_lens`, recomputes
        `max_decode_seq_len`, and delegates the tensor updates (tokens,
        positions, seq_lens_tensor, slot_mapping) to
        `ops.advance_step_flashattn`.
        """

        # The two literals are joined with an explicit space so the message
        # reads "...yet. turn_prefills_into_decodes ...".
        assert not turn_prefills_into_decodes, \
            ("Chunked prefill is not supported with rocm_flash_attn yet. "
             "turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill "
             "specific parameter.")

        # When using cudagraph, the num_seqs is padded to the next captured
        # batch sized, but num_queries tracks the actual number of requests in
        # the batch. For --enforce-eager mode, num_seqs == num_queries
        if num_seqs != num_queries:
            assert num_seqs > num_queries
            assert self.use_cuda_graph

        assert self.num_prefills == 0
        assert self.num_prefill_tokens == 0
        assert self.num_decode_tokens == num_seqs
        assert self.slot_mapping.shape == (num_seqs, )

        assert self.seq_lens is not None
        assert len(self.seq_lens) == num_seqs
        assert self.seq_lens_tensor is not None
        assert self.seq_lens_tensor.shape == (num_seqs, )
        assert self.max_query_len == 1
        assert self.max_prefill_seq_len == 0
        assert self.max_decode_seq_len == max(self.seq_lens)

        assert self.query_start_loc is not None
        assert self.query_start_loc.shape == (num_queries + 1, )
        assert self.seq_start_loc is not None
        assert self.seq_start_loc.shape == (num_seqs + 1, )

        assert self.context_lens_tensor is not None
        assert self.context_lens_tensor.shape == (num_queries, )

        assert self.block_tables is not None
        assert self.block_tables.shape[0] == num_seqs

        # Update query lengths. Note that we update only queries and not seqs,
        # since tensors may be padded due to captured cuda graph batch size
        for i in range(num_queries):
            self.seq_lens[i] += 1
        self.max_decode_seq_len = max(self.seq_lens)

        ops.advance_step_flashattn(num_seqs=num_seqs,
                                   num_queries=num_queries,
                                   block_size=block_size,
                                   input_tokens=model_input.input_tokens,
                                   sampled_token_ids=sampled_token_ids,
                                   input_positions=model_input.input_positions,
                                   seq_lens=self.seq_lens_tensor,
                                   slot_mapping=self.slot_mapping,
                                   block_tables=self.block_tables)

293

294
295
296
297
298
299
class ROCmFlashAttentionMetadataBuilder(
        CommonMetadataBuilder[ROCmFlashAttentionMetadata]):
    """Metadata builder for the ROCm flash-attention backend.

    All building logic is inherited from ``CommonMetadataBuilder``; this
    subclass only selects ``ROCmFlashAttentionMetadata`` as the metadata
    class via ``_metadata_cls``.
    """

    # Metadata type produced by the inherited builder logic.
    _metadata_cls = ROCmFlashAttentionMetadata


300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
def _make_alibi_bias(alibi_slopes: torch.Tensor,
                     dtype: torch.dtype,
                     seq_lens: Optional[List[int]],
                     make_attn_mask: bool = True) -> List[torch.Tensor]:
    attn_biases = []
    if seq_lens:
        for seq_len in seq_lens:
            bias = torch.arange(seq_len, dtype=dtype)
            # NOTE(zhuohan): HF uses
            #     `bias = bias[None, :].repeat(seq_len, 1)`
            # here. We find that both biases give the same results, but
            # the bias below more accurately follows the original ALiBi
            # paper.
            bias = bias[None, :] - bias[:, None]

            num_heads = alibi_slopes.shape[0]
            bias = bias[None, :].repeat(
                (num_heads, 1, 1)).to(alibi_slopes.device)
            bias.mul_(alibi_slopes[:, None, None])
            if make_attn_mask:
                inf_mask = torch.empty(
                    (1, seq_len, seq_len),
                    dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1).to(
                        alibi_slopes.device)
                attn_biases.append((bias + inf_mask).to(dtype))
            else:
                attn_biases.append(bias.to(dtype))

    return attn_biases


331
332
333
334
335
336
337
338
339
340
341
342
def _cumulative_start_loc(lens: List[int],
                          like: torch.Tensor) -> torch.Tensor:
    """Return ``[0, l0, l0 + l1, ...]`` for *lens* as a tensor placed on the
    device and with the dtype of *like*."""
    return torch.tensor(list(itertools.accumulate(lens, initial=0)),
                        device=like.device,
                        dtype=like.dtype)


def _get_seq_len_block_table_args(
    attn_metadata: ROCmFlashAttentionMetadata,
    attn_type: str,
) -> tuple:
    '''
    Select the sequence-length attributes for an attention call.

    Which attributes of attn_metadata are used depends on the type of
    attention operation:

    * Decoder attn -> decoder self-attention fields (seq_lens), causal
    * Encoder-only attn -> prefill sequence lengths (seq_lens),
      bidirectional
    * Encoder attn -> encoder sequence lengths fields, bidirectional
    * Encoder/decoder cross-attn -> decoder seq_lens for the query,
      encoder sequence lengths for the key, non-causal

    Arguments:

    * attn_metadata: Attention metadata structure associated with attention op
    * attn_type: encoder attention, decoder self-attention,
                 encoder/decoder cross-attention, encoder-only

    Returns a 6-tuple:

    * query_seq_start_loc: cumulative query start offsets ([0, ...])
    * query_max_seq_len: max query sequence length
    * key_seq_start_loc: cumulative key start offsets ([0, ...])
    * key_max_seq_len: max key sequence length
    * seq_lens: per-sequence lengths list used by the caller
    * causal_mask: whether causal masking applies

    Raises:

    * AttributeError: if attn_type is not a known attention type
    '''

    if attn_type == AttentionType.ENCODER:
        assert attn_metadata.encoder_seq_lens is not None
        assert attn_metadata.encoder_seq_lens_tensor is not None
        query_seq_start_loc = _cumulative_start_loc(
            attn_metadata.encoder_seq_lens,
            attn_metadata.encoder_seq_lens_tensor)
        # No block tables associated with encoder attention; bidirectional.
        return (query_seq_start_loc, attn_metadata.max_encoder_seq_len,
                query_seq_start_loc, attn_metadata.max_encoder_seq_len,
                attn_metadata.encoder_seq_lens, False)

    if attn_type in (AttentionType.ENCODER_ONLY, AttentionType.DECODER):
        # Both use the prefill sequence lengths for query and key; only
        # decoder self-attention is causal (encoder-only is bidirectional).
        assert attn_metadata.seq_lens is not None
        assert attn_metadata.seq_lens_tensor is not None
        query_seq_start_loc = _cumulative_start_loc(
            attn_metadata.seq_lens, attn_metadata.seq_lens_tensor)
        max_seq_len = attn_metadata.max_prefill_seq_len
        causal_mask = attn_type == AttentionType.DECODER
        return (query_seq_start_loc, max_seq_len, query_seq_start_loc,
                max_seq_len, attn_metadata.seq_lens, causal_mask)

    if attn_type == AttentionType.ENCODER_DECODER:
        assert attn_metadata.seq_lens is not None
        assert attn_metadata.seq_lens_tensor is not None
        assert attn_metadata.encoder_seq_lens is not None
        assert attn_metadata.encoder_seq_lens_tensor is not None
        # Each start-loc tensor follows the device/dtype of the lengths
        # tensor it is derived from.
        query_start_loc = _cumulative_start_loc(attn_metadata.seq_lens,
                                                attn_metadata.seq_lens_tensor)
        key_seq_start_loc = _cumulative_start_loc(
            attn_metadata.encoder_seq_lens,
            attn_metadata.encoder_seq_lens_tensor)
        # Enc/dec cross-attention KVs match encoder sequence length;
        # cross-attention utilizes special "cross" block tables
        return (query_start_loc, attn_metadata.max_prefill_seq_len,
                key_seq_start_loc, attn_metadata.max_encoder_seq_len,
                attn_metadata.seq_lens, False)

    raise AttributeError(f"Invalid attention type {str(attn_type)}")


427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
class ROCmFlashAttentionImpl(AttentionImpl):
    """
    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prompt_tokens -------------->|
    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|

    Otherwise, the layout is as follows:
    |<------------------ num_generation_tokens (M) ----------------->|
    |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.
442
443
444
445
446
447
448
449
450

    If chunked prefill is enabled, prefill tokens and decode tokens can be
    batched together in a flattened 1D query.

    |<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->|	
    |<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->|

    Currently, cuda graph is disabled for chunked prefill, meaning there's no
    padding between prefill and decode tokens.
451
452
453
454
455
456
457
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
        use_irope: bool = False,
    ) -> None:
        """Store layer configuration and select the attention kernel.

        Kernel selection (stored in ``self.attn_func``):

        * ``envs.VLLM_USE_TRITON_FLASH_ATTN`` set -> Triton
          ``triton_attention`` (rejects ``logits_soft_cap``; warns when a
          sliding window is requested).
        * otherwise, if ``current_platform.has_device_capability(90)`` and
          ``flash_attn`` is importable -> CK ``flash_attn_varlen_func``.
        * otherwise -> naive ``_sdpa_attention`` (rejects
          ``logits_soft_cap``); ``self.use_naive_attn`` is set to True.

        Raises:
            ValueError: for blocksparse attention, a head size unsupported
                by PagedAttention, or logits soft capping with the Triton or
                naive kernels.
        """
        if use_irope:
            # A single once-only warning for unsupported iRoPE.
            logger.warning_once(
                "Using irope in ROCm Flash Attention is not supported yet, it "
                "will fall back to global attention for long context.")
        if blocksparse_params is not None:
            raise ValueError(
                "ROCmFlashAttention does not support blocksparse attention.")

        # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
        self.logits_soft_cap = (0.0 if logits_soft_cap is None else
                                logits_soft_cap)
        self.attn_type = attn_type
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        # (left, right) window; (-1, -1) means no sliding window.
        self.sliding_window = ((sliding_window, sliding_window)
                               if sliding_window is not None else (-1, -1))
        self.kv_cache_dtype = kv_cache_dtype

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

        self.use_naive_attn = False
        # NOTE: Allow for switching between Triton and CK. Defaulting to triton.
        self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
        if self.use_triton_flash_attn:
            if logits_soft_cap is not None:
                raise ValueError(
                    "ROCm Triton FlashAttention does not support attention"
                    " logits soft capping."
                    " please try using the ROCm CK "
                    "FA backend instead by setting the env var "
                    "`VLLM_USE_TRITON_FLASH_ATTN=0`")

            from vllm.attention.ops.triton_flash_attention import (  # noqa: F401
                triton_attention)
            self.attn_func = triton_attention
            logger.debug("Using Triton FA in ROCmBackend")
            if self.sliding_window != (-1, -1):
                logger.warning("ROCm Triton FA does not currently support "
                               "sliding window attention. If using half "
                               "precision, please try using the ROCm CK "
                               "FA backend instead by setting the env var "
                               "`VLLM_USE_TRITON_FLASH_ATTN=0`")
        else:
            # if not using triton, navi3x/navi21/navi10 do not use flash-attn
            # either
            if not current_platform.has_device_capability(90):
                self.use_naive_attn = True
            else:
                try:
                    from flash_attn import flash_attn_varlen_func  # noqa: F401
                    self.attn_func = flash_attn_varlen_func
                    logger.debug("Using CK FA in ROCmBackend")
                except ModuleNotFoundError:
                    self.use_naive_attn = True

            if self.use_naive_attn:
                if logits_soft_cap is not None:
                    raise ValueError(
                        "ROCm Naive FlashAttention does not support "
                        "attention logits soft capping.")

                self.attn_func = _sdpa_attention
                logger.debug("Using naive (SDPA) attention in ROCmBackend")
547

548
549
550
551
552
553
554
555
556
557
    def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
        """Repeat every KV head ``n_rep`` times along dim 1.

        Equivalent to ``torch.repeat_interleave(x, dim=1, repeats=n_rep)``,
        implemented with ``expand`` followed by ``reshape``.

        Args:
            x: 3-D tensor unpacked as (tokens, n_kv_heads, head_dim).
            n_rep: number of copies of each head.

        Returns:
            Tensor of shape (tokens, n_kv_heads * n_rep, head_dim).
        """
        num_tokens, kv_heads, dim = x.shape
        # Insert a repeat axis after the head axis, broadcast it, then fold
        # it into the head axis so copies of one head stay adjacent.
        widened = x.unsqueeze(2).expand(num_tokens, kv_heads, n_rep, dim)
        return widened.reshape(num_tokens, kv_heads * n_rep, dim)

    def forward(
        self,
558
        layer: AttentionLayer,
559
560
561
562
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
563
        attn_metadata: ROCmFlashAttentionMetadata,
564
        output: Optional[torch.Tensor] = None,
565
566
567
    ) -> torch.Tensor:
        """Forward pass with FlashAttention and PagedAttention.

568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
        For decoder-only models: query, key and value must be non-None.

        For encoder/decoder models:
        * ROCmFlashAttentionImpl.forward() may be invoked for both self- and 
            cross-attention layers.
        * For self-attention: query, key and value must be non-None.
        * For cross-attention:
            * Query must be non-None
            * During prefill, key and value must be non-None; key and value
              get cached for use during decode.
            * During decode, key and value may be None, since:
              (1) key and value tensors were cached during prefill, and
              (2) cross-attention key and value tensors do not grow during
                  decode
        
        A note on how the attn_type (attention type enum) argument impacts
        attention forward() behavior:
    
            * DECODER: normal decoder-only behavior;
                use decoder self-attention block table
            * ENCODER: no KV caching; pass encoder sequence
                attributes (encoder_seq_lens/encoder_seq_lens_tensor/
                max_encoder_seq_len) to kernel, in lieu of decoder
                sequence attributes (seq_lens/seq_lens_tensor/max_seq_len)
            * ENCODER_DECODER: cross-attention behavior;
                use cross-attention block table for caching KVs derived
                from encoder hidden states; since KV sequence lengths
                will match encoder sequence lengths, pass encoder sequence
                attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/
                max_encoder_seq_len)
598
599
            * ENCODER_ONLY: bidirectional attention with no KV caching;
                use prefill sequence attributes
600

601
602
603
604
605
        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
606
607
                NOTE: kv_cache will be an empty tensor with shape [0]
                for profiling run.
608
            attn_metadata: Metadata for attention.
609
610
611
612
            attn_type: Select attention type, between encoder attention,
                       decoder self-attention, or encoder/decoder cross-
                       attention. Defaults to decoder self-attention,
                       which is the vLLM default generally
613
614
615
616
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        query = query.view(-1, self.num_heads, self.head_size)
617
618
619
620
621
622
        if key is not None:
            assert value is not None
            key = key.view(-1, self.num_kv_heads, self.head_size)
            value = value.view(-1, self.num_kv_heads, self.head_size)
        else:
            assert value is None
623

624
625
626
627
628
        # Only update KV cache for decoder self-attention
        # and encoder-decoder cross-attention
        if self.attn_type not in [
                AttentionType.ENCODER, AttentionType.ENCODER_ONLY
        ] and kv_cache.numel() > 0:
629
630
631
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
            if key is not None and value is not None:
                # Reshape the input keys and values and store them in the
                # cache. If kv_cache is not provided, the new key and value
                # tensors are not cached. This happens during the initial
                # memory profiling run.
                PagedAttention.write_to_paged_cache(
                    key,
                    value,
                    key_cache,
                    value_cache,
                    attn_metadata.slot_mapping
                    if self.attn_type != AttentionType.ENCODER_DECODER else
                    attn_metadata.cross_slot_mapping,
                    self.kv_cache_dtype,
                    layer._k_scale,
                    layer._v_scale,
                )

        if self.attn_type != AttentionType.ENCODER:
            num_prefill_tokens = attn_metadata.num_prefill_tokens
652
653
654
        elif self.attn_type == AttentionType.ENCODER_ONLY:
            # For encoder-only models, all tokens are processed in one go
            num_prefill_tokens = query.shape[0]
655
656
657
        else:
            assert attn_metadata.num_encoder_tokens is not None
            num_prefill_tokens = attn_metadata.num_encoder_tokens
658
659
660
661
662
663
664

        output = torch.empty_like(query)
        # Query for decode. KV is not needed because it is already cached.
        decode_query = query[num_prefill_tokens:]
        # QKV for prefill.
        query = query[:num_prefill_tokens]

665
666
667
668
        # For encoder-only and encoder models,
        # we process all tokens at once
        # For decoder and encoder-decoder,
        # we may need to limit key/value to prefill tokens
669
        if key is not None and value is not None \
670
671
            and self.attn_type not in [AttentionType.ENCODER_DECODER,
                                       AttentionType.ENCODER_ONLY]:
672
673
            key = key[:num_prefill_tokens]
            value = value[:num_prefill_tokens]
674
675

        if prefill_meta := attn_metadata.prefill_metadata:
676
            # Prompt run.
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
            # normal attention and DECODER
            if self.attn_type == AttentionType.DECODER and (
                    kv_cache.numel() == 0 or prefill_meta.block_tables is None
                    or prefill_meta.block_tables.numel() == 0):
                (query_seq_start_loc, query_max_seq_len, key_seq_start_loc,
                 key_max_seq_len, seq_lens,
                 causal_mask) = (prefill_meta.seq_start_loc,
                                 prefill_meta.max_prefill_seq_len,
                                 prefill_meta.seq_start_loc,
                                 prefill_meta.max_prefill_seq_len,
                                 attn_metadata.seq_lens, True)
            # prefix-enabled attention and ENCODER/ENCODER_DECODER
            else:
                (query_seq_start_loc, query_max_seq_len, key_seq_start_loc,
                 key_max_seq_len, seq_lens,
                 causal_mask) = _get_seq_len_block_table_args(
                     prefill_meta, self.attn_type)
            # Prompt run.
695
            if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0:
696
697
698
                # triton attention
                # When block_tables are not filled, it means q and k are the
                # prompt, and they have the same length.
699
                attn_masks = None
700
                if self.use_triton_flash_attn:
701
702
703
704
                    if self.alibi_slopes is not None:
                        attn_masks = _make_alibi_bias(
                            self.alibi_slopes,
                            query.dtype,
705
                            seq_lens,
706
                            make_attn_mask=causal_mask)  # type: ignore
707
708
709
710
711
                    out, _ = self.attn_func(
                        query,
                        key,
                        value,
                        None,
712
713
714
715
716
                        query_seq_start_loc,
                        key_seq_start_loc,
                        query_max_seq_len,
                        key_max_seq_len,
                        causal_mask,
717
                        self.scale,
718
719
                        attn_masks[0][None]
                        if attn_masks is not None else None,
720
721
                    )
                elif self.use_naive_attn:
722
723
724
725
                    if self.num_kv_heads != self.num_heads:
                        # Interleave for MQA workaround.
                        key = self.repeat_kv(key, self.num_queries_per_kv)
                        value = self.repeat_kv(value, self.num_queries_per_kv)
726
727
728
729
730
                    if self.alibi_slopes is not None:
                        attn_masks = _make_alibi_bias(
                            self.alibi_slopes,
                            query.dtype,
                            attn_metadata.seq_lens,
731
                            make_attn_mask=causal_mask)  # type: ignore
732
733
734
735
                    query = query.movedim(0, query.dim() - 2)
                    key = key.movedim(0, key.dim() - 2)
                    value = value.movedim(0, value.dim() - 2)
                    # sdpa math backend attention
736
737
738
739
                    out = self.attn_func(
                        query,
                        key,
                        value,
740
741
                        query_seq_start_loc,
                        num_prefill_tokens,
742
743
                        self.num_heads,
                        self.head_size,
744
                        self.scale,
745
                        attn_masks,
746
                    )
747
                else:
748
                    out = self.attn_func(
749
750
751
                        q=query,
                        k=key,
                        v=value,
752
753
                        cu_seqlens_q=query_seq_start_loc,
                        cu_seqlens_k=key_seq_start_loc,
754
                        max_seqlen_q=prefill_meta.max_prefill_seq_len,
755
                        max_seqlen_k=key_max_seq_len,
756
                        softmax_scale=self.scale,
757
                        causal=causal_mask,
758
759
                        window_size=self.sliding_window,
                        alibi_slopes=self.alibi_slopes,
760
                        softcap=self.logits_soft_cap,
761
                    )
762
763
764

                # common code for prefill
                assert output[:num_prefill_tokens].shape == out.shape
765
766
767
768
                if output.shape[0] > num_prefill_tokens:
                    output[:num_prefill_tokens] = out
                else:
                    output = out
769
            else:
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
                # prefix-enabled attention -
                # not applicable for encoder-only models
                if self.attn_type != AttentionType.ENCODER_ONLY:
                    output[:
                           num_prefill_tokens] = PagedAttention.forward_prefix(
                               query,
                               key,
                               value,
                               self.kv_cache_dtype,
                               key_cache,
                               value_cache,
                               prefill_meta.block_tables,
                               prefill_meta.query_start_loc,
                               prefill_meta.seq_lens_tensor,
                               prefill_meta.max_query_len,
                               self.alibi_slopes,
                               self.sliding_window[0],
                               layer._k_scale,
                               layer._v_scale,
                           )
        # Skip decode phase for encoder-only models
        if (decode_meta := attn_metadata.decode_metadata) and (
                self.attn_type != AttentionType.ENCODER_ONLY):
793
            # Decoding run.
794
795
796
797
            # Whether to use rocm custom paged attention or not
            num_seqs, num_heads, head_size = decode_query.shape
            block_size = value_cache.shape[3]
            gqa_ratio = num_heads // self.num_kv_heads
798
            use_custom = use_rocm_custom_paged_attention(
799
                decode_query.dtype, head_size, block_size, gqa_ratio,
800
                decode_meta.max_decode_seq_len, self.sliding_window)
801
            if use_custom:
802
803
804
805
                max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type
                               != AttentionType.ENCODER_DECODER else
                               decode_meta.max_encoder_seq_len)
                assert max_seq_len is not None
806
807
808
809
                max_num_partitions = (
                    (max_seq_len + _PARTITION_SIZE_ROCM - 1) //
                    _PARTITION_SIZE_ROCM)
                assert _PARTITION_SIZE_ROCM % block_size == 0
810
811
812
813
814
815
816
817
818
819
820
                tmp_output = torch.empty(
                    size=(num_seqs, num_heads, max_num_partitions, head_size),
                    dtype=output.dtype,
                    device=output.device,
                )
                exp_sums = torch.empty(
                    size=(num_seqs, num_heads, max_num_partitions),
                    dtype=torch.float32,
                    device=output.device,
                )
                max_logits = torch.empty_like(exp_sums)
821
822
823
824
                if num_prefill_tokens > 0:
                    out = output[num_prefill_tokens:]
                else:
                    out = output
825
826

                query_start_loc = None
827
                ops.paged_attention_rocm(
828
                    out,
829
830
831
832
833
834
835
836
                    exp_sums,
                    max_logits,
                    tmp_output,
                    decode_query,
                    key_cache,
                    value_cache,
                    self.num_kv_heads,
                    self.scale,
837
838
839
840
841
842
                    decode_meta.block_tables
                    if self.attn_type != AttentionType.ENCODER_DECODER else
                    decode_meta.cross_block_tables,
                    decode_meta.seq_lens_tensor
                    if self.attn_type != AttentionType.ENCODER_DECODER else
                    decode_meta.encoder_seq_lens_tensor,
843
                    query_start_loc,
844
845
846
847
                    block_size,
                    max_seq_len,
                    self.alibi_slopes,
                    self.kv_cache_dtype,
848
849
                    layer._k_scale,
                    layer._v_scale,
850
851
852
853
854
855
                )
            else:
                output[num_prefill_tokens:] = PagedAttention.forward_decode(
                    decode_query,
                    key_cache,
                    value_cache,
856
857
858
859
860
861
862
863
864
                    decode_meta.block_tables
                    if self.attn_type != AttentionType.ENCODER_DECODER else
                    decode_meta.cross_block_tables,
                    decode_meta.seq_lens_tensor
                    if self.attn_type != AttentionType.ENCODER_DECODER else
                    decode_meta.encoder_seq_lens_tensor,
                    decode_meta.max_decode_seq_len
                    if self.attn_type != AttentionType.ENCODER_DECODER else
                    decode_meta.max_encoder_seq_len,
865
866
867
868
                    self.kv_cache_dtype,
                    self.num_kv_heads,
                    self.scale,
                    self.alibi_slopes,
869
870
                    layer._k_scale,
                    layer._v_scale,
871
                )
872
873

        # Reshape the output tensor.
874
        return output.view(-1, self.num_heads * self.head_size)
875
876


877
def _sdpa_attention(
878
879
880
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
881
    seq_lens: List[int],
882
883
884
    num_tokens: int,
    num_heads: int,
    head_size: int,
885
    scale: float,
886
    attn_masks: Optional[List[torch.Tensor]] = None,
887
888
) -> torch.Tensor:
    start = 0
889
890
891
892
    output = torch.empty((num_tokens, num_heads, head_size),
                         dtype=query.dtype,
                         device=query.device)

893
    for i, seq_len in enumerate(seq_lens):
894
        end = start + seq_len
cyyever's avatar
cyyever committed
895
896
        with torch.nn.attention.sdpa_kernel(
                torch.nn.attention.SDPBackend.MATH):
897
898
899
900
901
            sub_out = torch.nn.functional.scaled_dot_product_attention(
                query[:, start:end, :],
                key[:, start:end, :],
                value[:, start:end, :],
                dropout_p=0.0,
902
903
                is_causal=attn_masks is None,
                attn_mask=attn_masks[i] if attn_masks else None,
904
905
906
                scale=scale).movedim(query.dim() - 2, 0)
            output[start:end, :, :] = sub_out
            start = end
907

908
    return output