"""Attention layer for ROCm GPUs."""
from dataclasses import dataclass
from typing import List, Optional, Tuple, Type

import torch

import vllm.envs as envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata,
                                              AttentionMetadataPerStage)
from vllm.attention.ops.paged_attn import (PagedAttention,
                                           PagedAttentionMetadata)
from vllm.logger import init_logger

# Module-level logger used for the backend-selection debug messages below.
logger = init_logger(__name__)


class ROCmFlashAttentionBackend(AttentionBackend):
    """Attention backend descriptor for ROCm.

    Exposes the backend name, the implementation and metadata classes, and
    KV-cache helpers. All KV-cache operations are delegated to
    `PagedAttention`.
    """

    @staticmethod
    def get_name() -> str:
        """Return the identifier of this backend."""
        return "rocm-flash-attn"

    @staticmethod
    def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]:
        """Return the attention implementation class of this backend."""
        return ROCmFlashAttentionImpl

    @staticmethod
    def make_metadata(*args, **kwargs) -> "ROCmFlashAttentionMetadata":
        """Build a `ROCmFlashAttentionMetadata` from the given arguments."""
        return ROCmFlashAttentionMetadata(*args, **kwargs)

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache shape as computed by `PagedAttention`."""
        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                 num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Delegate to `PagedAttention.swap_blocks` with the same arguments."""
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Delegate to `PagedAttention.copy_blocks` with the same arguments."""
        PagedAttention.copy_blocks(kv_caches, src_to_dists)


@dataclass
class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
                                 PagedAttentionMetadata):
    """Metadata for ROCmFlashAttentionBackend.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, it should be stored in tensor. The tensor has to be
    updated from `CUDAGraphRunner.forward` API.
    """
    # Currently, input sequences can only contain all prompts
    # or all decoding. True if all sequences are prompts.
    is_prompt: bool
    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if it is a decoding.
    seq_lens: Optional[List[int]]
    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|

    # Maximum query length in the batch.
    max_query_len: Optional[int]
    # Maximum sequence length in the batch.
    max_seq_len: Optional[int]
    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    subquery_start_loc: Optional[torch.Tensor]
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor]

    # Whether or not cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool
    # (batch_size,) A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor]


class ROCmFlashAttentionImpl(AttentionImpl):
    """
    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prompt_tokens -------------->|
    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|

    Otherwise, the layout is as follows:
    |<------------------ num_generation_tokens (M) ----------------->|
    |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.

    If chunked prefill is enabled, prefill tokens and decode tokens can be
    batched together in a flattened 1D query.

    |<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->|
    |<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->|

    Currently, cuda graph is disabled for chunked prefill, meaning there's no
    padding between prefill and decode tokens.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
        kv_cache_dtype: str = "auto",
    ) -> None:
        """Store the attention configuration and pick the prefill kernel.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each head; must be one of the sizes reported
                by `PagedAttention.get_supported_head_sizes()`.
            scale: Softmax scale, stored as a float.
            num_kv_heads: Number of key/value heads. Defaults to `num_heads`.
                Must divide `num_heads`.
            alibi_slopes: Optional ALiBi slopes, converted to a float32
                tensor when given.
            sliding_window: Optional window size, stored as the pair
                `(sliding_window, sliding_window)`, or `(-1, -1)` when None.
            kv_cache_dtype: KV-cache dtype string forwarded to PagedAttention.

        Raises:
            ValueError: If `head_size` is not supported by PagedAttention.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        self.sliding_window = ((sliding_window, sliding_window)
                               if sliding_window is not None else (-1, -1))
        self.kv_cache_dtype = kv_cache_dtype

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

        self.use_naive_attn = False
        # NOTE: Allow for switching between Triton and CK. Defaulting to triton.
        self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
        if self.use_triton_flash_attn:
            from vllm.attention.ops.triton_flash_attention import (  # noqa: F401
                triton_attention)
            self.attn_func = triton_attention
            logger.debug("Using Triton FA in ROCmBackend")
        else:
            # if not using triton, navi3x not use flash-attn either
            if torch.cuda.get_device_capability()[0] == 11:
                self.use_naive_attn = True
            else:
                try:
                    from flash_attn import flash_attn_varlen_func  # noqa: F401
                    self.attn_func = flash_attn_varlen_func
                    logger.debug("Using CK FA in ROCmBackend")
                except ModuleNotFoundError:
                    # flash_attn is not installed: fall back to the plain
                    # PyTorch implementation.
                    self.use_naive_attn = True

            if self.use_naive_attn:
                self.attn_func = _naive_attention
                logger.debug("Using naive attention in ROCmBackend")

    def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
        """torch.repeat_interleave(x, dim=1, repeats=n_rep)

        `x` is unpacked as (tokens, n_kv_heads, head_dim); the result has
        shape (tokens, n_kv_heads * n_rep, head_dim).
        """
        tokens, n_kv_heads, head_dim = x.shape
        return (x[:, :,
                  None, :].expand(tokens, n_kv_heads, n_rep,
                                  head_dim).reshape(tokens, n_kv_heads * n_rep,
                                                    head_dim))

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata[ROCmFlashAttentionMetadata],
        kv_scale: float = 1.0,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention and PagedAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
                May be None, in which case nothing is written to the cache
                and only the prefill path can run.
            attn_metadata: Metadata for attention.
            kv_scale: Forwarded unchanged to
                `PagedAttention.write_to_paged_cache` and
                `PagedAttention.forward_decode`.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        Raises:
            ValueError: If decode metadata is present but `kv_cache` is None.
        """
        num_tokens, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if kv_cache is not None:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            PagedAttention.write_to_paged_cache(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                kv_scale,
            )

        num_prefill_tokens = attn_metadata.num_prefill_tokens
        num_decode_tokens = attn_metadata.num_decode_tokens
        assert key.shape[0] == num_prefill_tokens + num_decode_tokens
        assert value.shape[0] == num_prefill_tokens + num_decode_tokens

        output = torch.empty_like(query)
        # Query for decode. KV is not needed because it is already cached.
        decode_query = query[num_prefill_tokens:]
        # QKV for prefill.
        query = query[:num_prefill_tokens]
        key = key[:num_prefill_tokens]
        value = value[:num_prefill_tokens]

        assert query.shape[0] == num_prefill_tokens
        assert decode_query.shape[0] == num_decode_tokens

        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
            assert prefill_meta.seq_lens is not None
            if kv_cache is None or prefill_meta.block_tables.numel() == 0:
                # triton attention
                # When block_tables are not filled, it means q and k are the
                # prompt, and they have the same length.
                if self.use_triton_flash_attn:
                    out, _ = self.attn_func(
                        query,
                        key,
                        value,
                        None,
                        prefill_meta.seq_start_loc,
                        prefill_meta.seq_start_loc,
                        prefill_meta.max_seq_len,
                        prefill_meta.max_seq_len,
                        True,
                        self.scale,
                    )
                elif self.use_naive_attn:
                    if self.num_kv_heads != self.num_heads:
                        # Interleave for MQA workaround.
                        key = self.repeat_kv(key, self.num_queries_per_kv)
                        value = self.repeat_kv(value, self.num_queries_per_kv)
                    out = self.attn_func(
                        query,
                        key,
                        value,
                        prefill_meta.seq_lens,
                        self.scale,
                    )
                else:
                    out = self.attn_func(
                        q=query,
                        k=key,
                        v=value,
                        cu_seqlens_q=prefill_meta.seq_start_loc,
                        cu_seqlens_k=prefill_meta.seq_start_loc,
                        max_seqlen_q=prefill_meta.max_seq_len,
                        max_seqlen_k=prefill_meta.max_seq_len,
                        softmax_scale=self.scale,
                        causal=True,
                    )

                # common code for prefill
                assert output[:num_prefill_tokens].shape == out.shape
                output[:num_prefill_tokens] = out
            else:
                # prefix-enabled attention
                output[:num_prefill_tokens] = PagedAttention.forward_prefix(
                    query,
                    key,
                    value,
                    key_cache,
                    value_cache,
                    prefill_meta.block_tables,
                    prefill_meta.subquery_start_loc,
                    prefill_meta.seq_lens_tensor,
                    prefill_meta.context_lens_tensor,
                    prefill_meta.max_query_len,
                    self.alibi_slopes,
                    self.sliding_window[0],
                )

        if decode_meta := attn_metadata.decode_metadata:
            # Decoding run.
            # key_cache/value_cache are only bound when a KV cache was given;
            # fail with a clear message instead of an UnboundLocalError.
            if kv_cache is None:
                raise ValueError(
                    "Decoding requires a KV cache, but kv_cache is None.")
            output[num_prefill_tokens:] = PagedAttention.forward_decode(
                decode_query,
                key_cache,
                value_cache,
                decode_meta.block_tables,
                decode_meta.seq_lens_tensor,
                decode_meta.max_seq_len,
                self.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
                kv_scale,
            )

        # Reshape the output tensor.
        return output.view(num_tokens, hidden_size)


def _naive_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    seq_lens: List[int],
    scale: float,
) -> torch.Tensor:
    """Run causal attention independently on each packed sequence.

    The tensors are sliced along dim 0 into consecutive segments of lengths
    `seq_lens`; each segment is passed to `_naive_masked_attention` and the
    result is written to the matching rows of the output.

    Args:
        query: Packed queries; sliced along dim 0.
        key: Packed keys; sliced with the same ranges as `query`.
        value: Packed values; sliced with the same ranges as `query`.
        seq_lens: Length of each consecutive sequence along dim 0.
        scale: Scale applied to the attention scores.

    Returns:
        A tensor allocated with `torch.empty_like(query)`. Rows past
        `sum(seq_lens)` are never written and stay uninitialized.
    """
    output = torch.empty_like(query)
    start = 0
    for seq_len in seq_lens:
        end = start + seq_len
        out = _naive_masked_attention(
            query[start:end],
            key[start:end],
            value[start:end],
            scale,
        )
        # TODO(woosuk): Unnecessary copy. Optimize.
        output[start:end].copy_(out)
        start = end

    return output


def _naive_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
) -> torch.Tensor:
369
    seq_len, head_size, head_dim = query.shape
370
371
372
373
374
375
376
377
378
379
380
    attn_mask = torch.triu(torch.ones(seq_len,
                                      seq_len,
                                      dtype=query.dtype,
                                      device=query.device),
                           diagonal=1)
    attn_mask = attn_mask * torch.finfo(query.dtype).min
    attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
    attn_weights = attn_weights + attn_mask.float()
    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
    out = torch.einsum("hqk,khd->qhd", attn_weights, value)
    return out