"""Paged-attention metadata and thin wrappers around the vLLM custom ops."""
from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch

from vllm import _custom_ops as ops
from vllm.triton_utils import HAS_TRITON
import vllm.envs as envs

if HAS_TRITON:
    # Only imported when HAS_TRITON is true; code that calls
    # `context_attention_fwd` therefore depends on that flag.
    from vllm.attention.ops.prefix_prefill import context_attention_fwd

# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512


@dataclass
class PagedAttentionMetadata:
    """Metadata for PagedAttention.

    The fields are declared in the order `seq_lens_tensor`,
    `max_decode_seq_len`, `block_tables`; positional construction relies on
    that order.
    """

    # (batch_size,). The length of sequences (entire tokens seen so far) per
    # sequence.
    seq_lens_tensor: Optional[torch.Tensor]

    # Maximum sequence length in the batch. 0 if it is prefill-only batch.
    max_decode_seq_len: int

    # (batch_size, max_blocks_per_seq).
    # Block addresses per sequence. (Seq id -> list of physical block)
    # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
    # in the kv cache. Each block can contain up to block_size tokens.
    # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
    # captured.
    block_tables: Optional[torch.Tensor]


class PagedAttention:
    """Static helpers for working with a paged KV cache.

    Every method is a `@staticmethod`; the class holds no state. The
    methods describe and split the combined KV cache tensor and forward to
    the custom ops in `ops` (and to `context_attention_fwd` for prefix
    prefill).
    """

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        """Return the list of supported attention head sizes."""
        return [64, 80, 96, 112, 120, 128, 192, 256]

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the shape of the combined KV cache tensor.

        The result is `(2, num_blocks, block_size * num_kv_heads *
        head_size)`. `split_kv_cache` reads index 0 of the leading axis as
        the key cache and index 1 as the value cache.
        """
        elements_per_block = block_size * num_kv_heads * head_size
        return (2, num_blocks, elements_per_block)

    @staticmethod
    def split_kv_cache(
        kv_cache: torch.Tensor,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Split a combined KV cache into key and value cache views.

        Args:
            kv_cache: Combined cache; `kv_cache[0]` is taken as the key
                cache and `kv_cache[1]` as the value cache, and
                `kv_cache.shape[1]` as the number of blocks.
            num_kv_heads: Number of KV heads.
            head_size: Size of each head.

        Returns:
            `(key_cache, value_cache)` where the key cache is viewed as
            `(num_blocks, num_kv_heads, head_size // x, -1, x)` and the
            value cache as `(num_blocks, num_kv_heads, head_size, -1)`.
            For a cache allocated with `get_kv_cache_shape` (and
            `head_size` divisible by `x`) the inferred `-1` dimension
            equals `block_size`.
        """
        # Number of cache elements that fit in 16 bytes.
        x = 16 // kv_cache.element_size()
        num_blocks = kv_cache.shape[1]

        key_cache = kv_cache[0].view(num_blocks, num_kv_heads,
                                     head_size // x, -1, x)
        value_cache = kv_cache[1].view(num_blocks, num_kv_heads, head_size,
                                       -1)
        return key_cache, value_cache

    @staticmethod
    def write_to_paged_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: float,
        v_scale: float,
    ) -> None:
        """Forward `key`/`value` to `ops.reshape_and_cache`.

        `slot_mapping` is flattened before it is handed to the op; all other
        arguments are passed through unchanged. Nothing is returned.
        """
        ops.reshape_and_cache(
            key,
            value,
            key_cache,
            value_cache,
            slot_mapping.flatten(),
            kv_cache_dtype,
            k_scale,
            v_scale,
        )

    @staticmethod
    def _print_common_decode_params(
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        num_kv_heads: int,
        scale: float,
        block_tables: torch.Tensor,
        seq_lens: torch.Tensor,
        block_size: int,
        max_seq_len: int,
    ) -> None:
        """Print the debug lines shared by the V1 and V2 decode paths."""
        print(f"query.shape = {query.shape}, "
              f"key_cache.shape = {key_cache.shape}, "
              f"value_cache.shape = {value_cache.shape}")
        print(f"num_kv_heads = {num_kv_heads}, scale = {scale:.3f}, "
              f"block_tables.shape = {block_tables.shape}, "
              f"seq_lens.shape = {seq_lens.shape}, "
              f"block_size = {block_size}, max_seq_len = {max_seq_len}")

    @staticmethod
    def forward_decode(
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        seq_lens: torch.Tensor,
        max_seq_len: int,
        kv_cache_dtype: str,
        num_kv_heads: int,
        scale: float,
        alibi_slopes: Optional[torch.Tensor],
        k_scale: float,
        v_scale: float,
        tp_rank: int = 0,
        blocksparse_local_blocks: int = 0,
        blocksparse_vert_stride: int = 0,
        blocksparse_block_size: int = 64,
        blocksparse_head_sliding_step: int = 0,
    ) -> torch.Tensor:
        """Run the paged-attention V1 or V2 op and return its output.

        Args:
            query: 3-D tensor unpacked as `(num_seqs, num_heads,
                head_size)`.
            key_cache: Key cache, passed through to the op.
            value_cache: Value cache; `value_cache.shape[3]` is used as the
                block size.
            block_tables: Block table tensor, passed through to the op.
            seq_lens: Sequence lengths tensor, passed through to the op.
            max_seq_len: Maximum sequence length; drives the V1/V2 choice
                and the number of partitions.
            kv_cache_dtype: KV cache dtype string, passed through.
            num_kv_heads: Number of KV heads, passed through.
            scale: Attention scale, passed through.
            alibi_slopes: Optional ALiBi slopes, passed through.
            k_scale: Key scale, passed through.
            v_scale: Value scale, passed through.
            tp_rank: Passed through to the op.
            blocksparse_local_blocks: Passed through to the op.
            blocksparse_vert_stride: When greater than 1, the blocksparse
                block-size check below is applied; passed through.
            blocksparse_block_size: Must be a positive multiple of the block
                size when `blocksparse_vert_stride > 1`; passed through.
            blocksparse_head_sliding_step: Passed through to the op.

        Returns:
            A tensor allocated with `torch.empty_like(query)` that is handed
            to the op as its output buffer.

        Raises:
            AssertionError: If the blocksparse block size is not a positive
                multiple of the block size, or (V2 path) if
                `_PARTITION_SIZE` is not a multiple of the block size.
        """
        block_size = value_cache.shape[3]

        if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
            # use blocksparse paged attention
            assert (blocksparse_block_size > 0
                    and blocksparse_block_size % block_size == 0), (
                        f"{blocksparse_block_size=} needs to be a multiple of "
                        f"{block_size=} used in block_tables.")

        output = torch.empty_like(query)
        num_seqs, num_heads, head_size = query.shape
        max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) //
                              _PARTITION_SIZE)
        # NOTE(woosuk): We use a simple heuristic to decide whether to use
        # PagedAttention V1 or V2. If the number of partitions is 1, we use
        # V1 to avoid the overhead of reduction. Also, if the number of
        # sequences or heads is large, we use V1 since there is enough work
        # to parallelize.
        # TODO(woosuk): Tune this heuristic.
        # For context len > 8192, use V2 kernel to avoid shared memory shortage.
        use_v1 = (max_seq_len <= 8192
                  and (max_num_partitions == 1 or num_seqs * num_heads > 512))

        # Positional arguments shared by every kernel variant, in the order
        # the ops expect them after the output/scratch buffers.
        shared_args = (
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            seq_lens,
            block_size,
            max_seq_len,
            alibi_slopes,
            kv_cache_dtype,
            k_scale,
            v_scale,
            tp_rank,
            blocksparse_local_blocks,
            blocksparse_vert_stride,
            blocksparse_block_size,
            blocksparse_head_sliding_step,
        )

        if use_v1:
            # Run PagedAttention V1.
            if envs.VLLM_USE_PA_PRINT_PARAM:
                print("PA V1 SIZE:")
                PagedAttention._print_common_decode_params(
                    query, key_cache, value_cache, num_kv_heads, scale,
                    block_tables, seq_lens, block_size, max_seq_len)

            kernel = (ops.paged_attention_v1_opt
                      if envs.VLLM_USE_OPT_OP else ops.paged_attention_v1)
            kernel(output, *shared_args)
        else:
            # Run PagedAttention V2.
            assert _PARTITION_SIZE % block_size == 0
            tmp_output = torch.empty(
                size=(num_seqs, num_heads, max_num_partitions, head_size),
                dtype=output.dtype,
                device=output.device,
            )
            exp_sums = torch.empty(
                size=(num_seqs, num_heads, max_num_partitions),
                dtype=torch.float32,
                device=output.device,
            )
            max_logits = torch.empty_like(exp_sums)

            if envs.VLLM_USE_PA_PRINT_PARAM:
                print("PA V2 SIZE:")
                print(f"exp_sums.shape = {exp_sums.shape}, "
                      f"max_logits.shape = {max_logits.shape}, "
                      f"tmp_output.shape = {tmp_output.shape}")
                PagedAttention._print_common_decode_params(
                    query, key_cache, value_cache, num_kv_heads, scale,
                    block_tables, seq_lens, block_size, max_seq_len)

            kernel = (ops.paged_attention_v2_opt
                      if envs.VLLM_USE_OPT_OP else ops.paged_attention_v2)
            kernel(output, exp_sums, max_logits, tmp_output, *shared_args)
        return output

    @staticmethod
    def forward_prefix(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache_dtype: str,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        query_start_loc: torch.Tensor,
        seq_lens_tensor: torch.Tensor,
        context_lens: torch.Tensor,
        max_query_len: int,
        alibi_slopes: Optional[torch.Tensor],
        sliding_window: Optional[int],
        k_scale: float,
        v_scale: float,
    ) -> torch.Tensor:
        """Call `context_attention_fwd` and return its output buffer.

        The output is allocated with `torch.empty_like(query)` and passed to
        `context_attention_fwd` together with the remaining arguments.
        `query_start_loc` is passed without its last element.

        NOTE(review): `context_attention_fwd` is imported at module level
        only when `HAS_TRITON` is true, so this method presumably requires
        Triton - confirm against callers.
        """
        output = torch.empty_like(query)
        context_attention_fwd(
            query,
            key,
            value,
            output,
            kv_cache_dtype,
            key_cache,
            value_cache,
            block_tables,
            # query_start_loc is (batch_size + 1,)
            query_start_loc[:-1],
            seq_lens_tensor,
            context_lens,
            max_query_len,
            k_scale,
            v_scale,
            alibi_slopes,
            sliding_window,
        )
        return output

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Call `ops.swap_blocks` for the key cache, then the value cache.

        Index 0 of each combined cache is the key cache and index 1 is the
        value cache; both use the same `src_to_dst` mapping.
        """
        for cache_index in (0, 1):
            ops.swap_blocks(src_kv_cache[cache_index],
                            dst_kv_cache[cache_index], src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Call `ops.copy_blocks` on the key and value caches of every entry.

        Each element of `kv_caches` is indexed with 0 for its key cache and
        1 for its value cache; the two resulting lists and `src_to_dists`
        are passed to the op.
        """
        key_caches = [kv_cache[0] for kv_cache in kv_caches]
        value_caches = [kv_cache[1] for kv_cache in kv_caches]
        ops.copy_blocks(key_caches, value_caches, src_to_dists)