# paged_attn.py 17.4 KB
# Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
from dataclasses import dataclass
4
from typing import List, Optional, Tuple
5
6
7

import torch

8
from vllm import _custom_ops as ops
9
from vllm.triton_utils import HAS_TRITON
10
import vllm.envs as envs
11
12
13

if HAS_TRITON:
    from vllm.attention.ops.prefix_prefill import context_attention_fwd
14
15
16

# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512

# Name of the current GPU, used to decide whether the `*_opt_tc`
# paged-attention kernels may be used. The device is only queried when one is
# visible, so that importing this module on a host without a GPU does not
# raise; in that case the name is empty and `support_tc` is False.
if torch.cuda.is_available():
    gpuname = torch.cuda.get_device_properties(
        torch.cuda.current_device()).name
else:
    gpuname = ""

# Only devices whose name starts with one of these prefixes support the
# `*_opt_tc` kernels.
support_tc = gpuname.startswith(('K100_AI', 'BW'))

# The `*_opt_tc` kernels are used only when both environment switches are set
# and the device supports them.
use_tc = envs.VLLM_USE_OPT_OP and envs.VLLM_USE_TC_PAGED_ATTN and support_tc
20
21
22
23

@dataclass
class PagedAttentionMetadata:
    """Metadata for PagedAttention."""

    # (batch_size,). The length of sequences (entire tokens seen so far) per
    # sequence.
    seq_lens_tensor: Optional[torch.Tensor]

    # Maximum sequence length in the batch. 0 if it is prefill-only batch.
    max_decode_seq_len: int

    # (batch_size, max_blocks_per_seq).
    # Block addresses per sequence. (Seq id -> list of physical block)
    # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
    # in the kv cache. Each block can contain up to block_size tokens.
    # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
    # captured.
    block_tables: Optional[torch.Tensor]


class PagedAttention:
    """Static helpers around the paged KV cache and paged-attention kernels.

    All methods are static: the class only groups cache-layout helpers
    (`get_kv_cache_shape`, `split_kv_cache`), cache maintenance wrappers
    (`write_to_paged_cache`, `swap_blocks`, `copy_blocks`) and the attention
    entry points (`forward_decode`, `forward_prefix`).
    """

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        """Return the attention head sizes this implementation accepts."""
        return [64, 80, 96, 112, 120, 128, 192, 256]

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the shape of the combined KV cache tensor.

        The leading dimension of size 2 holds the key plane (index 0) and the
        value plane (index 1); each block is stored flattened.
        """
        return (2, num_blocks, block_size * num_kv_heads * head_size)

    @staticmethod
    def split_kv_cache(
        kv_cache: torch.Tensor,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Split a combined KV cache into key and value views.

        Args:
            kv_cache: Tensor laid out as produced by `get_kv_cache_shape`.
            num_kv_heads: Number of KV heads.
            head_size: Size of each head.

        Returns:
            `(key_cache, value_cache)` where `key_cache` is viewed as
            `(num_blocks, num_kv_heads, head_size // x, -1, x)` and
            `value_cache` as `(num_blocks, num_kv_heads, head_size, -1)`.
            The inferred `-1` dimension is presumably the block size.
        """
        # x is the number of cache elements that fit in 16 bytes.
        x = 16 // kv_cache.element_size()
        num_blocks = kv_cache.shape[1]

        key_cache = kv_cache[0]
        key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x,
                                   -1, x)
        value_cache = kv_cache[1]
        value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
        return key_cache, value_cache

    @staticmethod
    def write_to_paged_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
    ) -> None:
        """Store `key`/`value` into the paged caches via `reshape_and_cache`.

        `slot_mapping` is flattened before being handed to the kernel.
        """
        ops.reshape_and_cache(
            key,
            value,
            key_cache,
            value_cache,
            slot_mapping.flatten(),
            kv_cache_dtype,
            k_scale,
            v_scale,
        )

    @staticmethod
    def _print_decode_params(
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        seq_lens: torch.Tensor,
        num_kv_heads: int,
        scale: float,
        block_size: int,
        max_seq_len: int,
    ) -> None:
        """Print the kernel parameters shared by all decode paths."""
        print(f"query.shape = {query.shape}, "
              f"key_cache.shape = {key_cache.shape}, "
              f"value_cache.shape = {value_cache.shape}")
        print(f"num_kv_heads = {num_kv_heads}, scale = {scale:.3f}, "
              f"block_tables.shape = {block_tables.shape}, "
              f"seq_lens.shape = {seq_lens.shape}, "
              f"block_size = {block_size}, max_seq_len = {max_seq_len}")

    @staticmethod
    def forward_decode(
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        seq_lens: torch.Tensor,
        max_seq_len: int,
        kv_cache_dtype: str,
        num_kv_heads: int,
        scale: float,
        alibi_slopes: Optional[torch.Tensor],
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
        tp_rank: int = 0,
        blocksparse_local_blocks: int = 0,
        blocksparse_vert_stride: int = 0,
        blocksparse_block_size: int = 64,
        blocksparse_head_sliding_step: int = 0,
        attn_masks: Optional[torch.Tensor] = None,
        attn_masks_stride: int = 0,
    ) -> torch.Tensor:
        """Run a paged-attention decode kernel and return its output.

        One kernel is chosen per call:

        * the `*_v1_opt_tc` kernels when `use_tc` is set and the head size is
          128;
        * otherwise a V1 kernel when `max_seq_len <= 8192` and either a single
          partition suffices or `num_seqs * num_heads > 512`;
        * otherwise a V2 kernel, which needs temporary reduction buffers.

        For V1/V2 the `_opt` variants are used when `envs.VLLM_USE_OPT_OP` is
        set, and the `_with_mask` variants whenever `attn_masks` is given.

        Args:
            query: Tensor of shape `(num_seqs, num_heads, head_size)`.
            key_cache: Paged key cache.
            value_cache: Paged value cache; dimension 3 is the block size.
            block_tables: Per-sequence block table passed to the kernel.
            seq_lens: Per-sequence lengths passed to the kernel.
            max_seq_len: Maximum sequence length; drives the partition count.
            kv_cache_dtype: KV cache dtype string passed to the kernel.
            num_kv_heads: Number of KV heads.
            scale: Attention scale factor.
            alibi_slopes: Optional ALiBi slopes.
            k_scale: Key scale passed to the kernel.
            v_scale: Value scale passed to the kernel.
            tp_rank: Passed through to the kernel.
            blocksparse_local_blocks: Passed through to the kernel.
            blocksparse_vert_stride: Values > 1 enable the blocksparse
                block-size check below.
            blocksparse_block_size: Must be a positive multiple of the cache
                block size when blocksparse attention is used.
            blocksparse_head_sliding_step: Passed through to the kernel.
            attn_masks: Optional masks; selects the `_with_mask` kernels.
            attn_masks_stride: Passed with `attn_masks` to those kernels.

        Returns:
            A new tensor with the same shape and dtype as `query`.
        """
        if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
            # use blocksparse paged attention
            block_size = value_cache.size(-1)
            assert (blocksparse_block_size > 0 and
                    blocksparse_block_size % block_size == 0), \
                (f"{blocksparse_block_size=} needs to be a multiple of "
                 f"{block_size=} used in block_tables.")

        output = torch.empty_like(query)
        block_size = value_cache.shape[3]
        num_seqs, num_heads, head_size = query.shape
        max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) //
                              _PARTITION_SIZE)

        has_mask = attn_masks is not None
        # Positional arguments shared, in this exact order, by every kernel
        # variant. They follow `output` (V1) or the reduction buffers (V2);
        # the mask arguments are appended only for the `_with_mask` kernels.
        kernel_args = (
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            seq_lens,
            block_size,
            max_seq_len,
            alibi_slopes,
            kv_cache_dtype,
            k_scale,
            v_scale,
            tp_rank,
            blocksparse_local_blocks,
            blocksparse_vert_stride,
            blocksparse_block_size,
            blocksparse_head_sliding_step,
        )
        if has_mask:
            kernel_args += (attn_masks, attn_masks_stride)

        # NOTE(review): this path is taken regardless of max_seq_len, unlike
        # the V1/V2 heuristic below -- confirm the TC kernel does not share
        # V1's shared-memory limit for long contexts.
        if use_tc and head_size == 128:
            if envs.VLLM_USE_PA_PRINT_PARAM:
                print("PA V1 SIZE:")
                PagedAttention._print_decode_params(
                    query, key_cache, value_cache, block_tables, seq_lens,
                    num_kv_heads, scale, block_size, max_seq_len)
            if has_mask:
                ops.paged_attention_v1_opt_tc_with_mask(output, *kernel_args)
            else:
                ops.paged_attention_v1_opt_tc(output, *kernel_args)
            return output

        # NOTE(woosuk): We use a simple heuristic to decide whether to use
        # PagedAttention V1 or V2. If the number of partitions is 1, we use
        # V1 to avoid the overhead of reduction. Also, if the number of
        # sequences or heads is large, we use V1 since there is enough work
        # to parallelize.
        # TODO(woosuk): Tune this heuristic.
        # For context len > 8192, use V2 kernel to avoid shared memory shortage.
        use_v1 = (max_seq_len <= 8192
                  and (max_num_partitions == 1 or num_seqs * num_heads > 512))

        if use_v1:
            # Run PagedAttention V1.
            if envs.VLLM_USE_PA_PRINT_PARAM:
                print("PA V1 SIZE:")
                PagedAttention._print_decode_params(
                    query, key_cache, value_cache, block_tables, seq_lens,
                    num_kv_heads, scale, block_size, max_seq_len)

            # Only the selected kernel attribute is looked up on `ops`.
            if envs.VLLM_USE_OPT_OP:
                if has_mask:
                    kernel = ops.paged_attention_v1_opt_with_mask
                else:
                    kernel = ops.paged_attention_v1_opt
            else:
                if has_mask:
                    kernel = ops.paged_attention_v1_with_mask
                else:
                    kernel = ops.paged_attention_v1
            kernel(output, *kernel_args)
            return output

        # Run PagedAttention V2.
        assert _PARTITION_SIZE % block_size == 0
        tmp_output = torch.empty(
            size=(num_seqs, num_heads, max_num_partitions, head_size),
            dtype=output.dtype,
            device=output.device,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_heads, max_num_partitions),
            dtype=torch.float32,
            device=output.device,
        )
        max_logits = torch.empty_like(exp_sums)

        if envs.VLLM_USE_PA_PRINT_PARAM:
            print("PA V2 SIZE:")
            print(f"exp_sums.shape = {exp_sums.shape}, "
                  f"max_logits.shape = {max_logits.shape}, "
                  f"tmp_output.shape = {tmp_output.shape}")
            PagedAttention._print_decode_params(
                query, key_cache, value_cache, block_tables, seq_lens,
                num_kv_heads, scale, block_size, max_seq_len)

        if envs.VLLM_USE_OPT_OP:
            if has_mask:
                kernel = ops.paged_attention_v2_opt_with_mask
            else:
                kernel = ops.paged_attention_v2_opt
        else:
            if has_mask:
                kernel = ops.paged_attention_v2_with_mask
            else:
                kernel = ops.paged_attention_v2
        kernel(output, exp_sums, max_logits, tmp_output, *kernel_args)
        return output

    @staticmethod
    def forward_prefix(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache_dtype: str,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        query_start_loc: torch.Tensor,
        seq_lens_tensor: torch.Tensor,
        max_query_len: int,
        alibi_slopes: Optional[torch.Tensor],
        sliding_window: Optional[int],
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
    ) -> torch.Tensor:
        """Run prefix-aware prefill attention through `context_attention_fwd`.

        A fresh buffer shaped like `query` is allocated, handed to
        `context_attention_fwd` as its output argument and returned.
        `max_seq_len` is always passed as `None`.

        `context_attention_fwd` is only imported when `HAS_TRITON` is true,
        so calling this method without Triton raises `NameError`.
        """
        output = torch.empty_like(query)
        max_seq_len = None
        context_attention_fwd(
            query,
            key,
            value,
            output,
            kv_cache_dtype,
            key_cache,
            value_cache,
            block_tables,
            # query_start_loc is (batch_size + 1,)
            query_start_loc,
            seq_lens_tensor,
            max_seq_len,
            max_query_len,
            k_scale,
            v_scale,
            alibi_slopes,
            sliding_window,
        )
        return output

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Swap blocks from the source to the destination KV cache.

        The key plane (index 0) and the value plane (index 1) are swapped
        with two separate `ops.swap_blocks` calls sharing the same mapping.
        """
        src_key_cache = src_kv_cache[0]
        dst_key_cache = dst_kv_cache[0]
        ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)

        src_value_cache = src_kv_cache[1]
        dst_value_cache = dst_kv_cache[1]
        ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Copy blocks within every KV cache in `kv_caches`.

        Each cache is split into its key plane (index 0) and value plane
        (index 1); both lists go to a single `ops.copy_blocks` call.
        """
        key_caches = [kv_cache[0] for kv_cache in kv_caches]
        value_caches = [kv_cache[1] for kv_cache in kv_caches]
        ops.copy_blocks(key_caches, value_caches, src_to_dists)