"""Attention layer with FlashAttention."""
from dataclasses import dataclass
from typing import List, Optional, Tuple, Type

import torch
from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache

from vllm._C import cache_ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata,
                                              AttentionMetadataPerStage)

# Head sizes accepted by FlashAttentionImpl; its constructor raises
# ValueError for any head size not listed here.
_SUPPORTED_HEAD_SIZES = [32, 64, 96, 128, 160, 192, 224, 256]


class FlashAttentionBackend(AttentionBackend):
    """Backend descriptor wiring FlashAttention into the attention layer.

    The KV cache used by this backend is a single tensor whose leading
    dimension has size 2: index 0 holds the key cache and index 1 holds the
    value cache.
    """

    @staticmethod
    def get_name() -> str:
        """Return the registry name of this backend."""
        return "flash-attn"

    @staticmethod
    def get_impl_cls() -> Type["FlashAttentionImpl"]:
        """Return the attention implementation class for this backend."""
        return FlashAttentionImpl

    @staticmethod
    def make_metadata(*args, **kwargs) -> "FlashAttentionMetadata":
        """Build a FlashAttentionMetadata, forwarding all arguments as-is."""
        return FlashAttentionMetadata(*args, **kwargs)

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache tensor shape for the given cache geometry.

        Raises:
            ValueError: if ``block_size`` is not a multiple of 16.
        """
        if block_size % 16:
            raise ValueError("Block size must be a multiple of 16.")
        # Leading 2 = [key plane, value plane].
        shape = (2, num_blocks, block_size, num_kv_heads, head_size)
        return shape

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Swap cache blocks from ``src_kv_cache`` into ``dst_kv_cache``.

        The key plane is processed first, then the value plane, each using
        the same ``src_to_dst`` block mapping.
        """
        for plane in (0, 1):  # 0 = keys, 1 = values
            cache_ops.swap_blocks(src_kv_cache[plane], dst_kv_cache[plane],
                                  src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Copy cache blocks within every layer's KV cache.

        Splits each per-layer cache into its key and value planes and hands
        both lists to a single ``cache_ops.copy_blocks`` call.
        """
        keys = [layer_cache[0] for layer_cache in kv_caches]
        values = [layer_cache[1] for layer_cache in kv_caches]
        cache_ops.copy_blocks(keys, values, src_to_dists)


@dataclass
class FlashAttentionMetadata(AttentionMetadataPerStage):
    """Per-stage attention metadata consumed by FlashAttentionBackend.

    NOTE: Plain Python objects stored in this dataclass are NOT refreshed
    when a captured CUDA graph is replayed. Any value that must change
    between replays has to be kept in a tensor, and that tensor has to be
    updated from the `CUDAGraphRunner.forward` API.
    """
    # True if every sequence in the batch is a prompt. Per the original
    # design, a batch holds either only prompts or only decodes.
    is_prompt: bool

    # (batch_size,). Length of each sequence, i.e. already-computed tokens
    # plus new tokens. None when decoding.
    seq_lens: Optional[List[int]]
    # The values of `seq_lens`, held in a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # NOTE(sang): How context_len, query_len and seq_len relate:
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ----------------------|
    #                                   |-- query_len ---|
    # That is, seq_len = context_len + query_len.

    # Largest query length in the batch.
    max_query_len: Optional[int]
    # Largest sequence length in the batch.
    max_seq_len: Optional[int]

    # (batch_size + 1,). Prefix sums of the subquery lengths, used to index
    # into the subquery. Example: subquery lengths [4, 6] -> [0, 4, 10].
    subquery_start_loc: Optional[torch.Tensor]
    # (batch_size + 1,). Prefix sums of the sequence lengths, used to index
    # into the sequence. Example: sequence lengths [4, 6] -> [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor]

    # (batch_size,). Context length of each sequence, i.e. the number of
    # tokens computed so far.
    context_lens_tensor: Optional[torch.Tensor]

    # Whether CUDA graph is enabled. CUDA graph is currently used for
    # decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    # (batch_size, max_blocks_per_seq). Physical KV-cache block ids for each
    # sequence. Example: [0, 1, 2] means the sequence's tokens live in the
    # 0th, 1st and 2nd cache blocks, each holding up to block_size tokens.
    # When CUDA-graph captured, the second dimension is padded up to
    # max_blocks_per_seq.
    block_tables: Optional[torch.Tensor]


class FlashAttentionImpl(AttentionImpl):
    """Attention implementation backed by FlashAttention kernels.

    If the input tensors contain prompt tokens, the layout is as follows:

    |<--------------- num_prefill_tokens ----------------->|
    |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|

    Otherwise, the layout is as follows:

    |<----------------- num_decode_tokens ------------------>|
    |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.

    If chunked prefill is enabled, prefill tokens and decode tokens can be
    batched together in a flattened 1D query.

    |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
    |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|

    Currently, cuda graph is disabled for chunked prefill, meaning there's no
    padding between prefill and decode tokens.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
        kv_cache_dtype: str = "auto",
    ) -> None:
        """Configure the attention layer.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each head; must be in _SUPPORTED_HEAD_SIZES.
            scale: Softmax scale passed to the kernels.
            num_kv_heads: Number of key/value heads; defaults to
                ``num_heads``. ``num_heads`` must be divisible by it.
            alibi_slopes: Optional ALiBi slopes, converted to a float32
                tensor.
            sliding_window: Must be None; sliding window is rejected.
            kv_cache_dtype: Forwarded to ``reshape_and_cache_flash``.

        Raises:
            ValueError: if ``sliding_window`` is set or ``head_size`` is not
                supported.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        # Stored as a (left, right) window pair for flash-attn's
        # `window_size` argument; (-1, -1) presumably means "no window" in
        # flash-attn's convention.
        self.sliding_window = ((sliding_window, sliding_window)
                               if sliding_window is not None else (-1, -1))
        self.kv_cache_dtype = kv_cache_dtype

        assert self.num_heads % self.num_kv_heads == 0
        # Number of query heads sharing each key/value head.
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        if sliding_window is not None:
            # NOTE(woosuk): flash-attn's sliding window does not work with
            # paged KV cache.
            raise ValueError(
                "Sliding window is not supported in FlashAttention.")
        if head_size not in _SUPPORTED_HEAD_SIZES:
            raise ValueError(
                f"Head size {head_size} is not supported by FlashAttention. "
                f"Supported head sizes are: {_SUPPORTED_HEAD_SIZES}.")

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Optional[torch.Tensor],
        attn_metadata: AttentionMetadata[FlashAttentionMetadata],
        kv_scale: float = 1.0,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache: shape = [2, num_blocks, block_size, num_kv_heads,
                head_size], or None. When None (this happens during the
                initial memory profiling run), the new keys/values are not
                cached and only prefill attention can be computed.
            attn_metadata: Metadata for attention.
            kv_scale: Must be 1.0; FP8 KV-cache scaling is not supported.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        Raises:
            ValueError: if the batch contains decode tokens but ``kv_cache``
                is None.
        """
        # NOTE(woosuk): FlashAttention does not support FP8 KV cache.
        assert kv_scale == 1.0, "kv_scale is not supported in FlashAttention."

        num_tokens, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        if kv_cache is not None:
            key_cache = kv_cache[0]
            value_cache = kv_cache[1]

            # Reshape the input keys and values and store them in the cache.
            # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory profiling run.
            cache_ops.reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping.flatten(),
                self.kv_cache_dtype,
            )

        num_prefill_tokens = attn_metadata.num_prefill_tokens
        num_decode_tokens = attn_metadata.num_decode_tokens
        assert key.shape[0] == num_prefill_tokens + num_decode_tokens
        assert value.shape[0] == num_prefill_tokens + num_decode_tokens

        output = torch.empty_like(query)
        # Query for decode. KV is not needed because it is already cached.
        decode_query = query[num_prefill_tokens:]
        # QKV for prefill.
        query = query[:num_prefill_tokens]
        key = key[:num_prefill_tokens]
        value = value[:num_prefill_tokens]

        assert query.shape[0] == num_prefill_tokens
        assert decode_query.shape[0] == num_decode_tokens

        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
            if (kv_cache is None or prefill_meta.block_tables is None
                    or prefill_meta.block_tables.numel() == 0):
                # Normal attention.
                # When block_tables are not filled, it means q and k are the
                # prompt, and they have the same length.
                out = flash_attn_varlen_func(
                    q=query,
                    k=key,
                    v=value,
                    cu_seqlens_q=prefill_meta.seq_start_loc,
                    cu_seqlens_k=prefill_meta.seq_start_loc,
                    max_seqlen_q=prefill_meta.max_seq_len,
                    max_seqlen_k=prefill_meta.max_seq_len,
                    softmax_scale=self.scale,
                    causal=True,
                    window_size=self.sliding_window,
                    alibi_slopes=self.alibi_slopes,
                )
                assert output[:num_prefill_tokens].shape == out.shape
                output[:num_prefill_tokens] = out
            else:
                # Prefix-enabled attention: queries cover only the new
                # tokens (subquery_start_loc), while keys/values are read
                # from the paged cache through block_table.
                output[:num_prefill_tokens] = flash_attn_varlen_func(
                    q=query,
                    k=key_cache,
                    v=value_cache,
                    cu_seqlens_q=prefill_meta.subquery_start_loc,
                    max_seqlen_q=prefill_meta.max_query_len,
                    cu_seqlens_k=prefill_meta.seq_start_loc,
                    max_seqlen_k=prefill_meta.max_seq_len,
                    softmax_scale=self.scale,
                    causal=True,
                    alibi_slopes=self.alibi_slopes,
                    block_table=prefill_meta.block_tables,
                )

        if decode_meta := attn_metadata.decode_metadata:
            # Decoding run.
            if kv_cache is None:
                # Without a cache, key_cache/value_cache were never bound
                # above; fail with a clear message instead of an
                # UnboundLocalError.
                raise ValueError(
                    "FlashAttention decoding requires a KV cache, but "
                    "kv_cache is None.")
            # Each decode token is a length-1 query: add a sequence dim of 1
            # for the kernel and remove it from the result.
            output[num_prefill_tokens:] = flash_attn_with_kvcache(
                decode_query.unsqueeze(1),
                key_cache,
                value_cache,
                block_table=decode_meta.block_tables,
                cache_seqlens=decode_meta.seq_lens_tensor,
                softmax_scale=self.scale,
                causal=True,
                alibi_slopes=self.alibi_slopes,
            ).squeeze(1)

        # Reshape the output tensor.
        return output.view(num_tokens, hidden_size)