"""Attention layer for ROCm GPUs."""
from dataclasses import dataclass
3
from typing import Any, Dict, List, Optional, Tuple, Type
4
5
6

import torch

7
import vllm.envs as envs
8
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
9
                                              AttentionMetadata)
10
11
12
13
14
15
16
17
18
from vllm.attention.ops.paged_attn import (PagedAttention,
                                           PagedAttentionMetadata)
from vllm.logger import init_logger

logger = init_logger(__name__)


class ROCmFlashAttentionBackend(AttentionBackend):
    """Backend descriptor for attention on ROCm.

    Names the backend, exposes its implementation and metadata classes, and
    forwards KV-cache shape and block-management requests to PagedAttention.
    """

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "rocm-flash-attn"

    @staticmethod
    def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]:
        """Return the attention implementation class used by this backend."""
        return ROCmFlashAttentionImpl

    @staticmethod
    def make_metadata(*args, **kwargs) -> "ROCmFlashAttentionMetadata":
        """Construct a ROCmFlashAttentionMetadata from the given arguments."""
        metadata = ROCmFlashAttentionMetadata(*args, **kwargs)
        return metadata

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache shape as computed by PagedAttention."""
        cache_shape = PagedAttention.get_kv_cache_shape(
            num_blocks, block_size, num_kv_heads, head_size)
        return cache_shape

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Delegate a block swap between two KV caches to PagedAttention."""
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Delegate a block copy within the KV caches to PagedAttention."""
        PagedAttention.copy_blocks(kv_caches, src_to_dists)


@dataclass
class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
    """Per-batch attention metadata for the ROCm flash-attention backend.

    NOTE: Plain Python objects stored here are not refreshed when a cuda
    graph is replayed. Any value that must change dynamically has to live in
    a tensor, and that tensor has to be updated through the
    `CUDAGraphRunner.forward` API.

    Fields read below but not declared here (num_prefills,
    num_prefill_tokens, num_decode_tokens, slot_mapping, block_tables) are
    presumably provided by the base classes.
    """

    # (batch_size,). Length of every sequence, i.e. already-computed tokens
    # plus new tokens. None when the batch is decode-only.
    seq_lens: Optional[List[int]]
    # The same sequence lengths, stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|

    # Largest query length in the batch. None for decoding.
    max_query_len: Optional[int]
    # Largest sequence length among the prefill requests. 0 when the batch
    # holds decoding requests only.
    max_prefill_seq_len: int
    # Largest sequence length among the decode requests. 0 when the batch
    # holds prefill requests only.
    max_decode_seq_len: int
    # (batch_size + 1,). Cumulative subquery lengths of the sequences in the
    # batch, used to index into the subquery. E.g. subquery lengths [4, 6]
    # give [0, 4, 10].
    query_start_loc: Optional[torch.Tensor]
    # (batch_size + 1,). Cumulative sequence lengths of the sequences in the
    # batch, used to index into the sequence. E.g. sequence lengths [4, 6]
    # give [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor]

    # Whether cuda graph is enabled.
    # Cuda graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool
    # (batch_size,). Context lengths, i.e. the number of tokens computed so
    # far for each sequence.
    context_lens_tensor: Optional[torch.Tensor]

    # Memoized results of the two properties below.
    _cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None
    _cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None

    @property
    def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
        """Return metadata restricted to the prefill part of the batch.

        Returns None when the batch has no prefill requests. The view is
        built on first access and the same object is returned afterwards.
        """
        prefill_count = self.num_prefills
        if prefill_count == 0:
            return None

        cached = self._cached_prefill_metadata
        if cached is not None:
            return cached

        # Every prefill-side field must be populated before slicing.
        assert self.seq_lens is not None
        assert self.seq_lens_tensor is not None
        assert self.query_start_loc is not None
        assert self.context_lens_tensor is not None
        assert self.block_tables is not None
        assert self.seq_start_loc is not None

        prefill_tokens = self.num_prefill_tokens
        # The cumulative-offset tensors carry one more entry than requests.
        loc_stop = prefill_count + 1
        cached = ROCmFlashAttentionMetadata(
            num_prefills=prefill_count,
            num_prefill_tokens=prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=self.slot_mapping[:prefill_tokens],
            seq_lens=self.seq_lens[:prefill_count],
            seq_lens_tensor=self.seq_lens_tensor[:prefill_count],
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_seq_len=0,
            query_start_loc=self.query_start_loc[:loc_stop],
            seq_start_loc=self.seq_start_loc[:loc_stop],
            context_lens_tensor=self.context_lens_tensor[:prefill_count],
            block_tables=self.block_tables[:prefill_count],
            use_cuda_graph=False,
        )
        self._cached_prefill_metadata = cached
        return cached

    @property
    def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
        """Return metadata restricted to the decode part of the batch.

        Returns None when the batch has no decode tokens. The view is built
        on first access and the same object is returned afterwards.
        """
        decode_tokens = self.num_decode_tokens
        if decode_tokens == 0:
            return None

        cached = self._cached_decode_metadata
        if cached is not None:
            return cached

        assert self.block_tables is not None
        assert self.seq_lens_tensor is not None

        # Decode entries follow the prefill entries in every per-request and
        # per-token field, so each slice starts where the prefill part ends.
        prefill_count = self.num_prefills
        cached = ROCmFlashAttentionMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=decode_tokens,
            slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
            seq_lens=None,
            seq_lens_tensor=self.seq_lens_tensor[prefill_count:],
            max_query_len=None,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            query_start_loc=None,
            seq_start_loc=None,
            context_lens_tensor=None,
            block_tables=self.block_tables[prefill_count:],
            use_cuda_graph=self.use_cuda_graph,
        )
        self._cached_decode_metadata = cached
        return cached
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183


class ROCmFlashAttentionImpl(AttentionImpl):
    """Attention implementation for ROCm.

    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prompt_tokens -------------->|
    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|

    Otherwise, the layout is as follows:
    |<------------------ num_generation_tokens (M) ----------------->|
    |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.

    If chunked prefill is enabled, prefill tokens and decode tokens can be
    batched together in a flattened 1D query.

    |<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->|
    |<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->|

    Currently, cuda graph is disabled for chunked prefill, meaning there's no
    padding between prefill and decode tokens.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Store the attention configuration and select the prefill kernel.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each head; must be one of the sizes reported
                by PagedAttention.get_supported_head_sizes().
            scale: Softmax scaling factor handed to the attention kernels.
            num_kv_heads: Number of key/value heads; must divide num_heads.
            alibi_slopes: Optional ALiBi slopes, converted to a float32
                tensor when given.
            sliding_window: Optional window size, stored as a
                (window, window) pair, or (-1, -1) when None.
            kv_cache_dtype: KV-cache dtype string forwarded to the paged
                cache operations.
            blocksparse_params: Must be None; blocksparse attention is not
                supported by this backend.

        Raises:
            ValueError: If blocksparse_params is given or head_size is not
                supported.
        """
        # Raise explicitly instead of asserting: an assert is stripped under
        # `python -O` and would surface as AssertionError rather than the
        # ValueError this check is meant to report.
        if blocksparse_params is not None:
            raise ValueError(
                "ROCFlashAttention does not support blocksparse attention.")

        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        self.sliding_window = ((sliding_window, sliding_window)
                               if sliding_window is not None else (-1, -1))
        self.kv_cache_dtype = kv_cache_dtype

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

        self.use_naive_attn = False
        # NOTE: Allow for switching between Triton and CK. Defaulting to triton.
        self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
        if self.use_triton_flash_attn:
            from vllm.attention.ops.triton_flash_attention import (  # noqa: F401
                triton_attention)
            self.attn_func = triton_attention
            logger.debug("Using Triton FA in ROCmBackend")
        else:
            # if not using triton, navi3x/navi21/navi10 do not use flash-attn
            # either
            if torch.cuda.get_device_capability()[0] != 9:
                self.use_naive_attn = True
            else:
                try:
                    from flash_attn import flash_attn_varlen_func  # noqa: F401
                    self.attn_func = flash_attn_varlen_func
                    logger.debug("Using CK FA in ROCmBackend")
                except ModuleNotFoundError:
                    # CK flash-attn is not installed: fall back to SDPA.
                    self.use_naive_attn = True

            if self.use_naive_attn:
                self.attn_func = _sdpa_attention
                logger.debug("Using naive attention in ROCmBackend")

    def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
        """torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
        tokens, n_kv_heads, head_dim = x.shape
        return (x[:, :,
                  None, :].expand(tokens, n_kv_heads, n_rep,
                                  head_dim).reshape(tokens, n_kv_heads * n_rep,
                                                    head_dim))

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: ROCmFlashAttentionMetadata,
        kv_scale: float = 1.0,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention and PagedAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
            attn_metadata: Metadata for attention.
            kv_scale: Scaling factor forwarded to the paged cache write and
                to the paged decode kernel.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        num_tokens, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if kv_cache is not None:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            PagedAttention.write_to_paged_cache(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                kv_scale,
            )

        num_prefill_tokens = attn_metadata.num_prefill_tokens
        num_decode_tokens = attn_metadata.num_decode_tokens
        assert key.shape[0] == num_prefill_tokens + num_decode_tokens
        assert value.shape[0] == num_prefill_tokens + num_decode_tokens

        output = torch.empty_like(query)
        # Query for decode. KV is not needed because it is already cached.
        decode_query = query[num_prefill_tokens:]
        # QKV for prefill.
        query = query[:num_prefill_tokens]
        key = key[:num_prefill_tokens]
        value = value[:num_prefill_tokens]

        assert query.shape[0] == num_prefill_tokens
        assert decode_query.shape[0] == num_decode_tokens

        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
            assert prefill_meta.seq_lens is not None
            if kv_cache is None or prefill_meta.block_tables.numel() == 0:
                # triton attention
                # When block_tables are not filled, it means q and k are the
                # prompt, and they have the same length.
                if self.use_triton_flash_attn:
                    out, _ = self.attn_func(
                        query,
                        key,
                        value,
                        None,
                        prefill_meta.seq_start_loc,
                        prefill_meta.seq_start_loc,
                        prefill_meta.max_prefill_seq_len,
                        prefill_meta.max_prefill_seq_len,
                        True,
                        self.scale,
                    )
                elif self.use_naive_attn:
                    if self.num_kv_heads != self.num_heads:
                        # Interleave for MQA workaround.
                        key = self.repeat_kv(key, self.num_queries_per_kv)
                        value = self.repeat_kv(value, self.num_queries_per_kv)
                    query = query.movedim(0, query.dim() - 2)
                    key = key.movedim(0, key.dim() - 2)
                    value = value.movedim(0, value.dim() - 2)
                    # sdpa math backend attention
                    # Size the result by the prefill token count only: the
                    # tensors above are already sliced to the prefill part,
                    # and the shape check below compares against
                    # output[:num_prefill_tokens]. Using the full batch
                    # token count breaks that check whenever decode tokens
                    # share the batch.
                    out = self.attn_func(
                        query,
                        key,
                        value,
                        prefill_meta.seq_lens,
                        num_prefill_tokens,
                        self.num_heads,
                        self.head_size,
                        self.scale,
                    )
                else:
                    out = self.attn_func(
                        q=query,
                        k=key,
                        v=value,
                        cu_seqlens_q=prefill_meta.seq_start_loc,
                        cu_seqlens_k=prefill_meta.seq_start_loc,
                        max_seqlen_q=prefill_meta.max_prefill_seq_len,
                        max_seqlen_k=prefill_meta.max_prefill_seq_len,
                        softmax_scale=self.scale,
                        causal=True,
                    )

                # common code for prefill
                assert output[:num_prefill_tokens].shape == out.shape
                output[:num_prefill_tokens] = out
            else:
                # prefix-enabled attention
                output[:num_prefill_tokens] = PagedAttention.forward_prefix(
                    query,
                    key,
                    value,
                    key_cache,
                    value_cache,
                    prefill_meta.block_tables,
                    prefill_meta.query_start_loc,
                    prefill_meta.seq_lens_tensor,
                    prefill_meta.context_lens_tensor,
                    prefill_meta.max_query_len,
                    self.alibi_slopes,
                    self.sliding_window[0],
                )

        if decode_meta := attn_metadata.decode_metadata:
            # Decoding run.
            output[num_prefill_tokens:] = PagedAttention.forward_decode(
                decode_query,
                key_cache,
                value_cache,
                decode_meta.block_tables,
                decode_meta.seq_lens_tensor,
                decode_meta.max_decode_seq_len,
                self.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
                kv_scale,
            )

        # Reshape the output tensor.
        return output.view(num_tokens, hidden_size)


412
def _sdpa_attention(
413
414
415
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
416
    seq_lens: List[int],
417
418
419
    num_tokens: int,
    num_heads: int,
    head_size: int,
420
421
422
    scale: float,
) -> torch.Tensor:
    start = 0
423
424
425
426
427
    output = torch.empty((num_tokens, num_heads, head_size),
                         dtype=query.dtype,
                         device=query.device)

    for seq_len in seq_lens:
428
        end = start + seq_len
429
430
431
432
433
434
435
436
437
438
439
440
        with torch.backends.cuda.sdp_kernel(enable_math=True,
                                            enable_flash=False,
                                            enable_mem_efficient=False):
            sub_out = torch.nn.functional.scaled_dot_product_attention(
                query[:, start:end, :],
                key[:, start:end, :],
                value[:, start:end, :],
                dropout_p=0.0,
                is_causal=True,
                scale=scale).movedim(query.dim() - 2, 0)
            output[start:end, :, :] = sub_out
            start = end
441

442
    return output