paged_attn.py 13.1 KB
Newer Older
1
from dataclasses import dataclass
2
from typing import List, Optional, Tuple
3
4
5

import torch

6
from vllm import _custom_ops as ops
7
from vllm.triton_utils import HAS_TRITON
8
import vllm.envs as envs
9
10
11

if HAS_TRITON:
    from vllm.attention.ops.prefix_prefill import context_attention_fwd
12
13
14
15
16
17
18
19

# Partition length used by PagedAttention V2: `forward_decode` splits a
# sequence of `max_seq_len` tokens into ceil(max_seq_len / _PARTITION_SIZE)
# partitions and asserts that this value is a multiple of the KV block size.
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512


@dataclass
class PagedAttentionMetadata:
    """Metadata for PagedAttention."""

    # (batch_size,). Length of each sequence, counting every token seen so
    # far for that sequence.
    seq_lens_tensor: Optional[torch.Tensor]

    # Maximum sequence length in the batch. 0 if it is a prefill-only batch.
    max_decode_seq_len: int

    # (batch_size, max_blocks_per_seq).
    # Block addresses per sequence (seq id -> list of physical blocks).
    # E.g., [0, 1, 2] means tokens are stored in the 0th, 1st, and 2nd blocks
    # of the kv cache. Each block can contain up to block_size tokens.
    # The 2nd dimension is padded up to max_blocks_per_seq if the batch is
    # cuda-graph captured.
    block_tables: Optional[torch.Tensor]


class PagedAttention:

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
zhuwenwen's avatar
zhuwenwen committed
38
        return [64, 80, 96, 112, 120, 128, 160, 192, 256]
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        return (2, num_blocks, block_size * num_kv_heads * head_size)

    @staticmethod
    def split_kv_cache(
        kv_cache: torch.Tensor,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        x = 16 // kv_cache.element_size()
        num_blocks = kv_cache.shape[1]

        key_cache = kv_cache[0]
        key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x,
                                   -1, x)
        value_cache = kv_cache[1]
        value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
        return key_cache, value_cache

    @staticmethod
    def write_to_paged_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: float,
        v_scale: float,
    ) -> None:
        """Write new key/value tensors into the paged cache.

        Thin wrapper around ``ops.reshape_and_cache``: every argument is
        forwarded unchanged except ``slot_mapping``, which is flattened to
        1-D first.

        Args:
            key: Keys to store.
            value: Values to store.
            key_cache: Key cache passed to the op as the destination.
            value_cache: Value cache passed to the op as the destination.
            slot_mapping: Slot indices; flattened before being passed on.
            kv_cache_dtype: KV cache dtype string forwarded to the op.
            k_scale: Key scaling factor forwarded to the op.
            v_scale: Value scaling factor forwarded to the op.
        """
        ops.reshape_and_cache(
            key,
            value,
            key_cache,
            value_cache,
            slot_mapping.flatten(),
            kv_cache_dtype,
            k_scale,
            v_scale,
        )

    @staticmethod
    def forward_decode(
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        seq_lens: torch.Tensor,
        max_seq_len: int,
        kv_cache_dtype: str,
        num_kv_heads: int,
        scale: float,
        alibi_slopes: Optional[torch.Tensor],
        k_scale: float,
        v_scale: float,
        tp_rank: int = 0,
        blocksparse_local_blocks: int = 0,
        blocksparse_vert_stride: int = 0,
        blocksparse_block_size: int = 64,
        blocksparse_head_sliding_step: int = 0,
    ) -> torch.Tensor:
        """Run decode-phase paged attention with the V1 or V2 kernel.

        A heuristic on ``max_seq_len``, the number of sequences and the
        number of heads picks between PagedAttention V1 and V2. Within each
        version, ``envs.VLLM_USE_OPT_OP`` and ``envs.VLLM_USE_TC_PAGED_ATTN``
        select the plain, ``_opt`` or ``_opt_tc`` kernel from ``ops``.

        Args:
            query: Unpacked as ``(num_seqs, num_heads, head_size)``.
            key_cache: Key cache forwarded to the kernel.
            value_cache: Value cache; the block size is read from
                ``value_cache.shape[3]``.
            block_tables: Block table forwarded to the kernel.
            seq_lens: Per-sequence lengths forwarded to the kernel.
            max_seq_len: Maximum sequence length; drives the V1/V2 choice
                and the number of V2 partitions.
            kv_cache_dtype: KV cache dtype string forwarded to the kernel.
            num_kv_heads: Number of KV heads.
            scale: Attention scale forwarded to the kernel.
            alibi_slopes: Optional ALiBi slopes forwarded to the kernel.
            k_scale: Key scaling factor forwarded to the kernel.
            v_scale: Value scaling factor forwarded to the kernel.
            tp_rank: Forwarded to the kernel.
            blocksparse_local_blocks: Forwarded to the kernel.
            blocksparse_vert_stride: When greater than 1, the blocksparse
                block size is validated against the cache block size.
            blocksparse_block_size: Must be a positive multiple of the cache
                block size when blocksparse attention is in use.
            blocksparse_head_sliding_step: Forwarded to the kernel.

        Returns:
            A tensor allocated with ``torch.empty_like(query)`` and passed to
            the kernel as its output buffer.

        Raises:
            AssertionError: If the blocksparse block size is invalid, or if
                V2 is selected and ``_PARTITION_SIZE`` is not a multiple of
                the cache block size.
        """
        if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
            # use blocksparse paged attention
            block_size = value_cache.size(-1)
            assert (blocksparse_block_size > 0 and
                    blocksparse_block_size % block_size == 0), \
                (f"{blocksparse_block_size=} needs to be a multiple of "
                 f"{block_size=} used in block_tables.")

        output = torch.empty_like(query)

        block_size = value_cache.shape[3]
        num_seqs, num_heads, head_size = query.shape
        max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) //
                              _PARTITION_SIZE)

        # NOTE(woosuk): We use a simple heuristic to decide whether to use
        # PagedAttention V1 or V2. If the number of partitions is 1, we use
        # V1 to avoid the overhead of reduction. Also, if the number of
        # sequences or heads is large, we use V1 since there is enough work
        # to parallelize.
        # TODO(woosuk): Tune this heuristic.
        # For context len > 8192, use V2 kernel to avoid shared memory shortage.
        if envs.VLLM_USE_TC_PAGED_ATTN:
            # NOTE(review): this branch uses a strict `< 8192` while the
            # default branch uses `<= 8192`; kept as-is, confirm it is
            # intentional.
            seq_len_limit = 1024 if num_kv_heads == num_heads else 600
            work_threshold = 1024 if num_kv_heads < num_heads else 512
            use_v1 = (max_seq_len < 8192
                      and (max_seq_len < seq_len_limit
                           or num_seqs * num_heads > work_threshold))
        else:
            use_v1 = (max_seq_len <= 8192
                      and (max_num_partitions == 1
                           or num_seqs * num_heads > 512))

        # Arguments shared, in this order, by every V1 and V2 kernel variant.
        # They follow the output buffer(s) in the kernel's argument list.
        kernel_args = (
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            seq_lens,
            block_size,
            max_seq_len,
            alibi_slopes,
            kv_cache_dtype,
            k_scale,
            v_scale,
            tp_rank,
            blocksparse_local_blocks,
            blocksparse_vert_stride,
            blocksparse_block_size,
            blocksparse_head_sliding_step,
        )

        if use_v1:
            # Run PagedAttention V1.
            if envs.VLLM_USE_PA_PRINT_PARAM:
                print("PA V1 SIZE:")
                print(f"query.shape = {query.shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}")
                print(f"num_kv_heads = {num_kv_heads}, scale = {scale:.3f}, block_tables.shape = {block_tables.shape}, seq_lens.shape = {seq_lens.shape}, block_size = {block_size}, max_seq_len = {max_seq_len}")

            # Only the selected kernel attribute is looked up on `ops`, so
            # builds lacking the optimized variants still work when the
            # corresponding flags are off.
            if not envs.VLLM_USE_OPT_OP:
                kernel = ops.paged_attention_v1
            elif envs.VLLM_USE_TC_PAGED_ATTN:
                kernel = ops.paged_attention_v1_opt_tc
            else:
                kernel = ops.paged_attention_v1_opt
            kernel(output, *kernel_args)
        else:
            # Run PagedAttention V2.
            assert _PARTITION_SIZE % block_size == 0
            # Per-partition scratch buffers handed to the V2 kernel.
            tmp_output = torch.empty(
                size=(num_seqs, num_heads, max_num_partitions, head_size),
                dtype=output.dtype,
                device=output.device,
            )
            exp_sums = torch.empty(
                size=(num_seqs, num_heads, max_num_partitions),
                dtype=torch.float32,
                device=output.device,
            )
            max_logits = torch.empty_like(exp_sums)

            if envs.VLLM_USE_PA_PRINT_PARAM:
                print("PA V2 SIZE:")
                print(f"exp_sums.shape = {exp_sums.shape}, max_logits.shape = {max_logits.shape}, tmp_output.shape = {tmp_output.shape}")
                print(f"query.shape = {query.shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}")
                print(f"num_kv_heads = {num_kv_heads}, scale = {scale:.3f}, block_tables.shape = {block_tables.shape}, seq_lens.shape = {seq_lens.shape}, block_size = {block_size}, max_seq_len = {max_seq_len}")

            if not envs.VLLM_USE_OPT_OP:
                kernel = ops.paged_attention_v2
            elif envs.VLLM_USE_TC_PAGED_ATTN:
                kernel = ops.paged_attention_v2_opt_tc
            else:
                kernel = ops.paged_attention_v2_opt
            kernel(output, exp_sums, max_logits, tmp_output, *kernel_args)
        return output

    @staticmethod
    def forward_prefix(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache_dtype: str,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        query_start_loc: torch.Tensor,
        seq_lens_tensor: torch.Tensor,
        context_lens: torch.Tensor,
        max_query_len: int,
        alibi_slopes: Optional[torch.Tensor],
        sliding_window: Optional[int],
        k_scale: float,
        v_scale: float,
    ) -> torch.Tensor:
        """Run prefix-aware prefill attention via ``context_attention_fwd``.

        ``context_attention_fwd`` is imported at module level only when
        ``HAS_TRITON`` is true, so this method needs Triton to be available.

        Args:
            query: Query tensor; the output is allocated with the same
                shape and dtype.
            key: Key tensor forwarded to the kernel.
            value: Value tensor forwarded to the kernel.
            kv_cache_dtype: KV cache dtype string forwarded to the kernel.
            key_cache: Key cache forwarded to the kernel.
            value_cache: Value cache forwarded to the kernel.
            block_tables: Block table forwarded to the kernel.
            query_start_loc: ``(batch_size + 1,)`` start offsets; the last
                entry is dropped before the kernel call.
            seq_lens_tensor: Per-sequence lengths forwarded to the kernel.
            context_lens: Per-sequence context lengths forwarded to the
                kernel.
            max_query_len: Maximum query length forwarded to the kernel.
            alibi_slopes: Optional ALiBi slopes forwarded to the kernel.
            sliding_window: Optional sliding-window size forwarded to the
                kernel.
            k_scale: Key scaling factor forwarded to the kernel.
            v_scale: Value scaling factor forwarded to the kernel.

        Returns:
            The tensor allocated with ``torch.empty_like(query)`` that was
            passed to the kernel as its output buffer.
        """
        output = torch.empty_like(query)
        context_attention_fwd(
            query,
            key,
            value,
            output,
            kv_cache_dtype,
            key_cache,
            value_cache,
            block_tables,
            # query_start_loc is (batch_size + 1,)
            query_start_loc[:-1],
            seq_lens_tensor,
            context_lens,
            max_query_len,
            k_scale,
            v_scale,
            alibi_slopes,
            sliding_window,
        )
        return output

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Swap blocks from a source KV cache into a destination KV cache.

        ``ops.swap_blocks`` is called twice with the same ``src_to_dst``
        mapping: first on the key caches (index 0 of each tensor), then on
        the value caches (index 1).

        Args:
            src_kv_cache: Source combined key/value cache.
            dst_kv_cache: Destination combined key/value cache.
            src_to_dst: Block mapping forwarded to ``ops.swap_blocks``.
        """
        # Index 0 holds the key cache and index 1 the value cache; keys are
        # processed first.
        for kv_index in (0, 1):
            ops.swap_blocks(src_kv_cache[kv_index], dst_kv_cache[kv_index],
                            src_to_dst)
360
361
362
363

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Copy blocks within each of the given KV caches.

        Each cache in ``kv_caches`` is split into its key part (index 0) and
        value part (index 1); the two resulting lists are handed to
        ``ops.copy_blocks`` together with the mapping.

        Args:
            kv_caches: Combined key/value cache tensors.
            src_to_dists: Block mapping forwarded to ``ops.copy_blocks``.
        """
        key_caches = [kv_cache[0] for kv_cache in kv_caches]
        value_caches = [kv_cache[1] for kv_cache in kv_caches]
        ops.copy_blocks(key_caches, value_caches, src_to_dists)