rocm_flash_attn.py 23.6 KB
Newer Older
1
2
"""Attention layer for ROCm GPUs."""
from dataclasses import dataclass
3
from typing import Any, Dict, List, Optional, Tuple, Type
4
5
6

import torch

7
import vllm.envs as envs
8
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
9
                                              AttentionMetadata, AttentionType)
10
from vllm.attention.backends.utils import CommonMetadataBuilder
11
12
13
14
15
16
17
18
19
from vllm.attention.ops.paged_attn import (PagedAttention,
                                           PagedAttentionMetadata)
from vllm.logger import init_logger

logger = init_logger(__name__)


class ROCmFlashAttentionBackend(AttentionBackend):
    """Attention backend registration for ROCm GPUs.

    Exposes the implementation, metadata and metadata-builder classes of this
    module. KV-cache shape and block swap/copy operations are delegated to
    ``PagedAttention``.
    """

    @staticmethod
    def get_name() -> str:
        return "rocm-flash-attn"

    @staticmethod
    def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]:
        return ROCmFlashAttentionImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return ROCmFlashAttentionMetadata

    @staticmethod
    def get_builder_cls() -> Type["ROCmFlashAttentionMetadataBuilder"]:
        return ROCmFlashAttentionMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache tensor shape used by ``PagedAttention``."""
        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                 num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Swap KV-cache blocks between caches (delegates to PagedAttention)."""
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Copy KV-cache blocks within caches (delegates to PagedAttention)."""
        PagedAttention.copy_blocks(kv_caches, src_to_dists)


@dataclass
class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
    """Metadata for ROCmFlashAttentionBackend.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, it should be stored in tensor. The tensor has to be
    updated from `CUDAGraphRunner.forward` API.
    """
    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if it is a decoding.
    seq_lens: Optional[List[int]]
    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|

    # Maximum query length in the batch. None for decoding.
    max_query_len: Optional[int]
    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int
    # Maximum sequence length among decode batch. 0 if there are prefill
    # requests only.
    max_decode_seq_len: int
    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    query_start_loc: Optional[torch.Tensor]
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor]

    # Whether or not if cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    # (batch_size,) A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor]

    # Lazily-built sub-views returned by `prefill_metadata` / `decode_metadata`.
    _cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None
    _cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None

    @property
    def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
        """Metadata restricted to the prefill requests, or None if none.

        Prefill requests occupy the first ``num_prefills`` sequences and the
        first ``num_prefill_tokens`` slots, which is how every per-request
        field is sliced below. The result is built once and cached.
        """
        if self.num_prefills == 0:
            return None

        if self._cached_prefill_metadata is not None:
            return self._cached_prefill_metadata

        assert self.seq_lens is not None
        assert self.seq_lens_tensor is not None
        assert self.query_start_loc is not None
        assert self.context_lens_tensor is not None
        assert self.block_tables is not None
        assert self.seq_start_loc is not None

        self._cached_prefill_metadata = ROCmFlashAttentionMetadata(
            num_prefills=self.num_prefills,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
            seq_lens=self.seq_lens[:self.num_prefills],
            seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_seq_len=0,
            # Cumulative-offset tensors carry one extra leading entry.
            query_start_loc=self.query_start_loc[:self.num_prefills + 1],
            seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
            context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
            block_tables=self.block_tables[:self.num_prefills],
            use_cuda_graph=False,
        )
        return self._cached_prefill_metadata

    @property
    def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
        """Metadata restricted to the decode requests, or None if none.

        Decode requests follow the prefill requests, so per-request fields
        are sliced from ``num_prefills`` and slots from
        ``num_prefill_tokens``. The result is built once and cached.
        """
        if self.num_decode_tokens == 0:
            return None

        if self._cached_decode_metadata is not None:
            return self._cached_decode_metadata
        assert self.block_tables is not None
        assert self.seq_lens_tensor is not None

        self._cached_decode_metadata = ROCmFlashAttentionMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
            seq_lens=None,
            seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
            max_query_len=None,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            query_start_loc=None,
            seq_start_loc=None,
            context_lens_tensor=None,
            block_tables=self.block_tables[self.num_prefills:],
            use_cuda_graph=self.use_cuda_graph,
        )
        return self._cached_decode_metadata


class ROCmFlashAttentionMetadataBuilder(
        CommonMetadataBuilder[ROCmFlashAttentionMetadata]):
    """Builds ``ROCmFlashAttentionMetadata`` objects.

    All building logic is inherited from ``CommonMetadataBuilder``. This
    subclass only sets ``_metadata_cls`` (presumably the class the common
    builder instantiates; confirm against ``CommonMetadataBuilder``).
    """

    _metadata_cls = ROCmFlashAttentionMetadata


180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
def _make_alibi_bias(alibi_slopes: torch.Tensor,
                     dtype: torch.dtype,
                     seq_lens: Optional[List[int]],
                     make_attn_mask: bool = True) -> List[torch.Tensor]:
    attn_biases = []
    if seq_lens:
        for seq_len in seq_lens:
            bias = torch.arange(seq_len, dtype=dtype)
            # NOTE(zhuohan): HF uses
            #     `bias = bias[None, :].repeat(seq_len, 1)`
            # here. We find that both biases give the same results, but
            # the bias below more accurately follows the original ALiBi
            # paper.
            bias = bias[None, :] - bias[:, None]

            num_heads = alibi_slopes.shape[0]
            bias = bias[None, :].repeat(
                (num_heads, 1, 1)).to(alibi_slopes.device)
            bias.mul_(alibi_slopes[:, None, None])
            if make_attn_mask:
                inf_mask = torch.empty(
                    (1, seq_len, seq_len),
                    dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1).to(
                        alibi_slopes.device)
                attn_biases.append((bias + inf_mask).to(dtype))
            else:
                attn_biases.append(bias.to(dtype))

    return attn_biases


class ROCmFlashAttentionImpl(AttentionImpl):
    """Attention implementation for ROCm GPUs.

    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prompt_tokens -------------->|
    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|

    Otherwise, the layout is as follows:
    |<------------------ num_generation_tokens (M) ----------------->|
    |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.

    If chunked prefill is enabled, prefill tokens and decode tokens can be
    batched together in a flattened 1D query.

    |<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->|
    |<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->|

    Currently, cuda graph is disabled for chunked prefill, meaning there's no
    padding between prefill and decode tokens.

    ``__init__`` chooses one of three kernels for prefill without cached
    context, based on environment variables and the device:

    - Triton flash-attention, optionally auto-switching to CK
      flash-attention for shorter prompts.
    - CK flash-attention.
    - A naive per-sequence SDPA fallback.

    Prefill with cached context and decode use PagedAttention.
    """

    # In auto mode (VLLM_USE_FLASH_ATTN_AUTO), prefills whose maximum sequence
    # length exceeds this value use Triton FA; all others use CK FA.
    _AUTO_TRITON_SEQ_LEN_THRESHOLD = 8000

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
    ) -> None:
        if blocksparse_params is not None:
            raise ValueError(
                "ROCmFlashAttention does not support blocksparse attention.")
        if logits_soft_cap is not None:
            raise ValueError(
                "ROCmFlashAttention does not support attention logits soft "
                "capping.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        # (left, right) window; (-1, -1) means no sliding window.
        self.sliding_window = ((sliding_window, sliding_window)
                               if sliding_window is not None else (-1, -1))
        self.kv_cache_dtype = kv_cache_dtype

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

        self.use_naive_attn = False
        # NOTE: Allow for switching between Triton and CK. Defaulting to triton.
        self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
        # NOTE: Allow automatic switching between Triton and CK based on the
        # prefill sequence length (see _AUTO_TRITON_SEQ_LEN_THRESHOLD).
        self.use_flash_attn_auto = envs.VLLM_USE_FLASH_ATTN_AUTO
        if self.use_triton_flash_attn:
            if self.use_flash_attn_auto:
                from vllm.attention.ops.flash_attn_triton_mqa_gqa import (
                    flash_attn_varlen_func as triton_flash_attn_varlen_func)
                self.attn_func_triton = triton_flash_attn_varlen_func

                from flash_attn import (
                    flash_attn_varlen_func as ck_flash_attn_varlen_func)
                self.attn_func_ck = ck_flash_attn_varlen_func
                logger.debug(
                    "When SEQ_LEN > %d, Use Triton FA in ROCmBackend, "
                    "otherwise Use CK FA", self._AUTO_TRITON_SEQ_LEN_THRESHOLD)
            else:
                from vllm.attention.ops.flash_attn_triton_mqa_gqa import (
                    flash_attn_varlen_func)
                self.attn_func = flash_attn_varlen_func
                logger.debug("Using Triton FA in ROCmBackend")

            # Neither Triton call site in `forward` passes a window size, so
            # warn in both the plain and the auto mode.
            if self.sliding_window != (-1, -1):
                logger.warning("ROCm Triton FA does not currently support "
                               "sliding window attention. If using half "
                               "precision, please try using the ROCm CK "
                               "FA backend instead by setting the env var "
                               "`VLLM_USE_TRITON_FLASH_ATTN=0`")
        else:
            # if not using triton, navi3x/navi21/navi10 do not use flash-attn
            # either
            if torch.cuda.get_device_capability()[0] != 9:
                self.use_naive_attn = True
            else:
                try:
                    from flash_attn import flash_attn_varlen_func  # noqa: F401
                    self.attn_func = flash_attn_varlen_func
                    logger.debug("Using CK FA in ROCmBackend")
                except ModuleNotFoundError:
                    self.use_naive_attn = True

            if self.use_naive_attn:
                self.attn_func = _sdpa_attention
                logger.debug("Using naive attention in ROCmBackend")

        # Only the naive path turns ALiBi slopes into a bias for prefill
        # without cached context; the varlen flash kernels are not given them.
        if self.alibi_slopes is not None and not self.use_naive_attn:
            logger.warning(
                "ROCmFlashAttentionImpl: ALiBi slopes are not passed to the "
                "varlen flash-attention prefill kernels, so prefill without "
                "cached context ignores the ALiBi bias.")

    def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
        """torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
        tokens, n_kv_heads, head_dim = x.shape
        return (x[:, :,
                  None, :].expand(tokens, n_kv_heads, n_rep,
                                  head_dim).reshape(tokens, n_kv_heads * n_rep,
                                                    head_dim))

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: ROCmFlashAttentionMetadata,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention and PagedAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
            attn_metadata: Metadata for attention.
            k_scale: Passed to the KV-cache write and the decode kernel
                (presumably a KV-cache quantization scale; confirm).
            v_scale: Same as ``k_scale``, for values.
            attn_type: Only ``AttentionType.DECODER`` is supported.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        Raises:
            NotImplementedError: For any ``attn_type`` other than DECODER.
        """
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "ROCmFlashAttentionImpl")

        num_tokens, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if kv_cache is not None:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            PagedAttention.write_to_paged_cache(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                k_scale,
                v_scale,
            )

        num_prefill_tokens = attn_metadata.num_prefill_tokens
        num_decode_tokens = attn_metadata.num_decode_tokens
        assert key.shape[0] == num_prefill_tokens + num_decode_tokens
        assert value.shape[0] == num_prefill_tokens + num_decode_tokens

        output = torch.empty_like(query)
        # Query for decode. KV is not needed because it is already cached.
        decode_query = query[num_prefill_tokens:]
        # QKV for prefill.
        query = query[:num_prefill_tokens]
        key = key[:num_prefill_tokens]
        value = value[:num_prefill_tokens]

        assert query.shape[0] == num_prefill_tokens
        assert decode_query.shape[0] == num_decode_tokens

        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
            assert prefill_meta.seq_lens is not None
            if kv_cache is None or prefill_meta.block_tables.numel() == 0:
                # When block_tables are not filled, it means q and k are the
                # prompt, and they have the same length.
                out = self._prefill_attention(query, key, value, prefill_meta,
                                              num_prefill_tokens)
                assert output[:num_prefill_tokens].shape == out.shape
                output[:num_prefill_tokens] = out
            else:
                # prefix-enabled attention: block tables are populated, so
                # attend over the cached context through PagedAttention.
                output[:num_prefill_tokens] = PagedAttention.forward_prefix(
                    query,
                    key,
                    value,
                    key_cache,
                    value_cache,
                    prefill_meta.block_tables,
                    prefill_meta.query_start_loc,
                    prefill_meta.seq_lens_tensor,
                    prefill_meta.context_lens_tensor,
                    prefill_meta.max_query_len,
                    self.alibi_slopes,
                    self.sliding_window[0],
                )

        if decode_meta := attn_metadata.decode_metadata:
            # Decoding run.
            output[num_prefill_tokens:] = PagedAttention.forward_decode(
                decode_query,
                key_cache,
                value_cache,
                decode_meta.block_tables,
                decode_meta.seq_lens_tensor,
                decode_meta.max_decode_seq_len,
                self.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
                k_scale,
                v_scale,
            )

        # Reshape the output tensor.
        return output.view(num_tokens, hidden_size)

    def _prefill_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        prefill_meta: ROCmFlashAttentionMetadata,
        num_prefill_tokens: int,
    ) -> torch.Tensor:
        """Run causal self-attention over prefill tokens that have no cached context."""
        if self.use_triton_flash_attn:
            # ALiBi slopes are not forwarded to these kernels, so no bias is
            # built here; it would only consume memory without being used.
            if (self.use_flash_attn_auto and prefill_meta.max_prefill_seq_len
                    <= self._AUTO_TRITON_SEQ_LEN_THRESHOLD):
                return self._flash_varlen_prefill(self.attn_func_ck,
                                                  query,
                                                  key,
                                                  value,
                                                  prefill_meta,
                                                  triton_api=False)
            attn_func = (self.attn_func_triton
                         if self.use_flash_attn_auto else self.attn_func)
            return self._flash_varlen_prefill(attn_func,
                                              query,
                                              key,
                                              value,
                                              prefill_meta,
                                              triton_api=True)
        if self.use_naive_attn:
            return self._naive_prefill(query, key, value, prefill_meta,
                                       num_prefill_tokens)
        return self._flash_varlen_prefill(self.attn_func,
                                          query,
                                          key,
                                          value,
                                          prefill_meta,
                                          triton_api=False)

    def _flash_varlen_prefill(
        self,
        attn_func: Any,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        prefill_meta: ROCmFlashAttentionMetadata,
        triton_api: bool,
    ) -> torch.Tensor:
        """Call a varlen flash-attention kernel causally over the prefill tokens."""
        max_len = prefill_meta.max_prefill_seq_len
        # The Triton kernel names its max-length kwargs `max_seqlens_*`,
        # while the CK `flash_attn` package names them `max_seqlen_*`.
        if triton_api:
            seqlen_kwargs = dict(max_seqlens_q=max_len, max_seqlens_k=max_len)
        else:
            seqlen_kwargs = dict(max_seqlen_q=max_len, max_seqlen_k=max_len)
        return attn_func(
            q=query,
            k=key,
            v=value,
            cu_seqlens_q=prefill_meta.seq_start_loc,
            cu_seqlens_k=prefill_meta.seq_start_loc,
            softmax_scale=self.scale,
            causal=True,
            **seqlen_kwargs,
        )

    def _naive_prefill(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        prefill_meta: ROCmFlashAttentionMetadata,
        num_prefill_tokens: int,
    ) -> torch.Tensor:
        """Run the naive SDPA fallback over the prefill tokens, applying ALiBi if configured."""
        if self.num_kv_heads != self.num_heads:
            # Interleave for MQA workaround.
            key = self.repeat_kv(key, self.num_queries_per_kv)
            value = self.repeat_kv(value, self.num_queries_per_kv)

        attn_masks = None
        if self.alibi_slopes is not None:
            # Only prefill sequences are attended here, so build their biases only.
            attn_masks = _make_alibi_bias(
                self.alibi_slopes,
                query.dtype,
                prefill_meta.seq_lens,
                make_attn_mask=True)  # type: ignore

        # (tokens, heads, head_size) -> (heads, tokens, head_size)
        query = query.movedim(0, query.dim() - 2)
        key = key.movedim(0, key.dim() - 2)
        value = value.movedim(0, value.dim() - 2)
        # sdpa math backend attention. The output must cover exactly the
        # prefill tokens: it is written into output[:num_prefill_tokens].
        return self.attn_func(
            query,
            key,
            value,
            prefill_meta.seq_lens,
            num_prefill_tokens,
            self.num_heads,
            self.head_size,
            self.scale,
            attn_masks,
        )


543
def _sdpa_attention(
544
545
546
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
547
    seq_lens: List[int],
548
549
550
    num_tokens: int,
    num_heads: int,
    head_size: int,
551
    scale: float,
552
    attn_masks: Optional[List[torch.Tensor]] = None,
553
554
) -> torch.Tensor:
    start = 0
555
556
557
558
    output = torch.empty((num_tokens, num_heads, head_size),
                         dtype=query.dtype,
                         device=query.device)

559
    for i, seq_len in enumerate(seq_lens):
560
        end = start + seq_len
561
562
563
564
565
566
567
568
        with torch.backends.cuda.sdp_kernel(enable_math=True,
                                            enable_flash=False,
                                            enable_mem_efficient=False):
            sub_out = torch.nn.functional.scaled_dot_product_attention(
                query[:, start:end, :],
                key[:, start:end, :],
                value[:, start:end, :],
                dropout_p=0.0,
569
570
                is_causal=attn_masks is None,
                attn_mask=attn_masks[i] if attn_masks else None,
571
572
573
                scale=scale).movedim(query.dim() - 2, 0)
            output[start:end, :, :] = sub_out
            start = end
574

575
    return output