flash_attn.py 11 KB
Newer Older
1
2
3
4
5
6
7
"""Attention layer with Flash and PagedAttention.

NOTE(woosuk): At the moment, this file includes a lot of duplicated code from
the XFormers backend. The duplicated code will be removed once we use
flash-attn or flashinfer for all the attention operations.
"""
from dataclasses import dataclass
8
from typing import List, Optional, Tuple, Type
9
10

import torch
11
from vllm_flash_attn import flash_attn_varlen_func
12
13

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
14
15
                                              AttentionMetadata,
                                              AttentionMetadataPerStage)
16
17
from vllm.attention.ops.paged_attn import (PagedAttention,
                                           PagedAttentionMetadata)
18
19
20
21


class FlashAttentionBackend(AttentionBackend):
    """Attention backend that dispatches to `FlashAttentionImpl`.

    KV-cache shape computation and block swap/copy are delegated to
    `PagedAttention`, so this backend shares the paged KV-cache layout.
    """

    @staticmethod
    def get_name() -> str:
        """Return the registry name of this backend."""
        return "flash-attn"

    @staticmethod
    def get_impl_cls() -> Type["FlashAttentionImpl"]:
        """Return the attention implementation class for this backend."""
        return FlashAttentionImpl

    @staticmethod
    def make_metadata(*args, **kwargs) -> "FlashAttentionMetadata":
        """Build a `FlashAttentionMetadata`, forwarding all arguments."""
        return FlashAttentionMetadata(*args, **kwargs)

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache tensor shape (delegated to PagedAttention)."""
        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
                                                 num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Swap cache blocks between caches per the `src_to_dst` mapping."""
        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Copy cache blocks within `kv_caches` per the `src_to_dists` map."""
        PagedAttention.copy_blocks(kv_caches, src_to_dists)


@dataclass
class FlashAttentionMetadata(AttentionMetadataPerStage,
                             PagedAttentionMetadata):
    """Metadata for FlashAttentionBackend.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, they should be stored in tensors. The tensors have to be
    updated from the `CUDAGraphRunner.forward` API.
    """
    # Currently, input sequences can only contain all prompts
    # or all decoding. True if all sequences are prompts.
    is_prompt: bool
    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if it is a decoding.
    seq_lens: Optional[List[int]]
    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|

    # Maximum query length in the batch.
    max_query_len: Optional[int]
    # Maximum sequence length in the batch.
    max_seq_len: Optional[int]
    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    subquery_start_loc: Optional[torch.Tensor]
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor]
    # (batch_size,) A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor]

    # Whether cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool


class FlashAttentionImpl(AttentionImpl):
    """Attention implementation using flash-attn and PagedAttention.

    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prefill_tokens ----------------->|
    |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|

    Otherwise, the layout is as follows:
    |<----------------- num_decode_tokens ------------------>|
    |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.

    If chunked prefill is enabled, prefill tokens and decode tokens can be
    batched together in a flattened 1D query.

    |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
    |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|

    Currently, cuda graph is disabled for chunked prefill, meaning there's no
    padding between prefill and decode tokens.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
        kv_cache_dtype: str = "auto",
    ) -> None:
        """Store and validate the attention configuration.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each attention head.
            scale: Softmax scale passed to the attention kernels.
            num_kv_heads: Number of key/value heads; defaults to `num_heads`.
                Must evenly divide `num_heads`.
            alibi_slopes: Optional ALiBi slopes, converted to a float32
                tensor.
            sliding_window: Optional local-attention window size in tokens;
                None disables windowing.
            kv_cache_dtype: Cache dtype string forwarded to the paged-cache
                write and decode ops.

        Raises:
            ValueError: If `head_size` is not supported by PagedAttention.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes

        if sliding_window is None:
            # (-1, -1) disables local attention in flash-attn.
            self.sliding_window = (-1, -1)
        else:
            # flash-attn's window_size=(left, right) lets query i attend to
            # keys in [i - left, i + right] (inclusive), so (W, W) would admit
            # W + 1 past tokens. A window of W tokens that includes the query
            # itself is (W - 1, 0); `right` is irrelevant with causal=True.
            self.sliding_window = (sliding_window - 1, 0)
        # The prefix-prefill kernel keeps receiving the raw window size
        # (or -1 when disabled), exactly as before.
        self._prefix_sliding_window = (sliding_window
                                       if sliding_window is not None else -1)
        self.kv_cache_dtype = kv_cache_dtype

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        supported_head_sizes = PagedAttention.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata[FlashAttentionMetadata],
        kv_scale: float = 1.0,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention and PagedAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache: shape = [2, num_blocks,
                block_size * num_kv_heads * head_size]. May be None (e.g.
                during the initial memory profiling run), in which case the
                new keys/values are not cached.
            attn_metadata: Metadata for attention.
            kv_scale: Scale forwarded to the paged-cache write and decode ops.
        Returns:
            shape = [num_tokens, num_heads * head_size]

        Raises:
            ValueError: If decode tokens are present but `kv_cache` is None.
        """
        num_tokens, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if kv_cache is not None:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            PagedAttention.write_to_paged_cache(key, value, key_cache,
                                                value_cache,
                                                attn_metadata.slot_mapping,
                                                self.kv_cache_dtype, kv_scale)

        num_prefill_tokens = attn_metadata.num_prefill_tokens
        num_decode_tokens = attn_metadata.num_decode_tokens
        assert key.shape[0] == num_prefill_tokens + num_decode_tokens
        assert value.shape[0] == num_prefill_tokens + num_decode_tokens

        output = torch.empty_like(query)
        # Query for decode. KV is not needed because it is already cached.
        decode_query = query[num_prefill_tokens:]
        # QKV for prefill.
        query = query[:num_prefill_tokens]
        key = key[:num_prefill_tokens]
        value = value[:num_prefill_tokens]

        assert query.shape[0] == num_prefill_tokens
        assert decode_query.shape[0] == num_decode_tokens

        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
            if kv_cache is None or prefill_meta.block_tables.numel() == 0:
                # Normal attention.
                # When block_tables are not filled, it means q and k are the
                # prompt, and they have the same length.
                out = flash_attn_varlen_func(
                    q=query,
                    k=key,
                    v=value,
                    cu_seqlens_q=prefill_meta.seq_start_loc,
                    cu_seqlens_k=prefill_meta.seq_start_loc,
                    max_seqlen_q=prefill_meta.max_seq_len,
                    max_seqlen_k=prefill_meta.max_seq_len,
                    softmax_scale=self.scale,
                    causal=True,
                    window_size=self.sliding_window,
                    alibi_slopes=self.alibi_slopes,
                )
                assert output[:num_prefill_tokens].shape == out.shape
                output[:num_prefill_tokens] = out
            else:
                # Prefix-enabled attention (kv_cache is not None here, so
                # key_cache/value_cache are bound).
                # TODO(Hai) this triton kernel has regression issue (broke) to
                # deal with different data types between KV and FP8 KV cache,
                # to be addressed separately.
                output[:num_prefill_tokens] = PagedAttention.forward_prefix(
                    query,
                    key,
                    value,
                    key_cache,
                    value_cache,
                    prefill_meta.block_tables,
                    prefill_meta.subquery_start_loc,
                    prefill_meta.seq_lens_tensor,
                    prefill_meta.context_lens_tensor,
                    prefill_meta.max_query_len,
                    self.alibi_slopes,
                    self._prefix_sliding_window,
                )

        if decode_meta := attn_metadata.decode_metadata:
            # Decoding run. Decode reads K/V exclusively from the paged cache,
            # so a cache must have been provided.
            if kv_cache is None:
                raise ValueError(
                    "kv_cache is required to run decode attention.")
            output[num_prefill_tokens:] = PagedAttention.forward_decode(
                decode_query,
                key_cache,
                value_cache,
                decode_meta.block_tables,
                decode_meta.seq_lens_tensor,
                decode_meta.max_seq_len,
                self.kv_cache_dtype,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
                kv_scale,
            )

        # Reshape the output tensor.
        return output.view(num_tokens, hidden_size)