# paged_attn.py
from dataclasses import dataclass
2
from typing import List, Optional, Tuple
3
4
5

import torch

6
from vllm import _custom_ops as ops
7
8
9
10
11
12
13
14
15
from vllm.attention.ops.prefix_prefill import context_attention_fwd

# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
# `PagedAttention.forward_decode` divides `max_seq_len` by this value (rounding
# up) to get the partition count, and asserts that it is a multiple of the
# KV-cache block size before running the V2 kernel.
_PARTITION_SIZE = 512


@dataclass
class PagedAttentionMetadata:
    """Metadata for PagedAttention.

    All three fields are required constructor arguments (none has a default)
    and are declared in the order: ``seq_lens_tensor``, ``max_seq_len``,
    ``block_tables``. Each of them may be ``None``.
    """
    # (batch_size,). The length of sequences (entire tokens seen so far) per
    # sequence.
    seq_lens_tensor: Optional[torch.Tensor]
    # Maximum sequence length in the batch.
    max_seq_len: Optional[int]
    # (batch_size, max_blocks_per_seq).
    # Block addresses per sequence. (Seq id -> list of physical block)
    # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
    # in the kv cache. Each block can contain up to block_size tokens.
    # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
    # captured.
    block_tables: Optional[torch.Tensor]


class PagedAttention:
    """Namespace of static helpers around the paged-attention custom ops.

    Every method is a ``@staticmethod``; the class holds no state. A KV cache
    is a single tensor whose leading dimension has size 2: index 0 is the key
    cache and index 1 is the value cache (see ``get_kv_cache_shape`` and
    ``split_kv_cache``).
    """

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        """Return the list of head sizes reported as supported."""
        return [64, 80, 96, 112, 128, 256]

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the shape of a combined key/value cache tensor.

        The result is ``(2, num_blocks, block_size * num_kv_heads *
        head_size)``: one flat row per block, for keys (index 0) and values
        (index 1).
        """
        return (2, num_blocks, block_size * num_kv_heads * head_size)

    @staticmethod
    def split_kv_cache(
        kv_cache: torch.Tensor,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Split a combined KV cache into key and value views.

        Args:
            kv_cache: Tensor indexed as ``kv_cache[0]`` (keys) and
                ``kv_cache[1]`` (values); ``kv_cache.shape[1]`` is taken as
                the number of blocks.
            num_kv_heads: Number of KV heads.
            head_size: Size of each head.

        Returns:
            ``(key_cache, value_cache)`` where ``key_cache`` is viewed as
            ``(num_blocks, num_kv_heads, head_size // x, -1, x)`` and
            ``value_cache`` as ``(num_blocks, num_kv_heads, head_size, -1)``.
            For a cache built with ``get_kv_cache_shape`` the inferred ``-1``
            dimension equals ``block_size``.
        """
        # Number of cache elements that fit in 16 bytes.
        # NOTE(review): presumably head_size is a multiple of x; this is not
        # checked here.
        x = 16 // kv_cache.element_size()
        num_blocks = kv_cache.shape[1]

        key_cache = kv_cache[0]
        key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x,
                                   -1, x)
        value_cache = kv_cache[1]
        value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
        return key_cache, value_cache

    @staticmethod
    def write_to_paged_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        kv_scale: float,
    ) -> None:
        """Store ``key``/``value`` into the paged caches.

        Thin wrapper over ``ops.reshape_and_cache``. ``slot_mapping`` is
        flattened before being passed on; all other arguments are forwarded
        unchanged. Returns nothing; the op is expected to update
        ``key_cache`` and ``value_cache`` in place.
        """
        ops.reshape_and_cache(
            key,
            value,
            key_cache,
            value_cache,
            slot_mapping.flatten(),
            kv_cache_dtype,
            kv_scale,
        )

    @staticmethod
    def forward_decode(
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        seq_lens: torch.Tensor,
        max_seq_len: int,
        kv_cache_dtype: str,
        num_kv_heads: int,
        scale: float,
        alibi_slopes: Optional[torch.Tensor],
        kv_scale: float,
    ) -> torch.Tensor:
        """Run paged attention for decoding and return the output tensor.

        Chooses between the V1 and V2 paged-attention ops based on
        ``max_seq_len`` and the amount of parallel work
        (``num_seqs * num_heads``).

        Args:
            query: Tensor unpacked as ``(num_seqs, num_heads, head_size)``.
            key_cache: Key cache, forwarded to the op.
            value_cache: Value cache; ``value_cache.shape[3]`` is used as the
                block size (matches the view made by ``split_kv_cache``).
            block_tables: Block table tensor, forwarded to the op.
            seq_lens: Per-sequence lengths, forwarded to the op.
            max_seq_len: Maximum sequence length; drives the partition count
                and the V1/V2 choice.
            kv_cache_dtype: KV cache dtype string, forwarded to the op.
            num_kv_heads: Number of KV heads, forwarded to the op.
            scale: Attention scale, forwarded to the op.
            alibi_slopes: Optional ALiBi slopes, forwarded to the op.
            kv_scale: KV scaling factor, forwarded to the op.

        Returns:
            A new tensor with the same shape and dtype as ``query``. It is
            allocated uninitialized here and is expected to be filled in by
            the selected op.
        """
        output = torch.empty_like(query)

        block_size = value_cache.shape[3]
        num_seqs, num_heads, head_size = query.shape
        # Ceiling division: number of _PARTITION_SIZE-token partitions needed
        # to cover the longest sequence.
        max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) //
                              _PARTITION_SIZE)
        # NOTE(woosuk): We use a simple heuristic to decide whether to use
        # PagedAttention V1 or V2. If the number of partitions is 1, we use
        # V1 to avoid the overhead of reduction. Also, if the number of
        # sequences or heads is large, we use V1 since there is enough work
        # to parallelize.
        # TODO(woosuk): Tune this heuristic.
        # For context len > 8192, use V2 kernel to avoid shared memory shortage.
        use_v1 = (max_seq_len <= 8192
                  and (max_num_partitions == 1 or num_seqs * num_heads > 512))
        if use_v1:
            # Run PagedAttention V1.
            ops.paged_attention_v1(
                output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                seq_lens,
                block_size,
                max_seq_len,
                alibi_slopes,
                kv_cache_dtype,
                kv_scale,
            )
        else:
            # Run PagedAttention V2.
            assert _PARTITION_SIZE % block_size == 0
            # Scratch buffers with one slot per partition; presumably the V2
            # op writes per-partition partial results here and reduces them
            # into `output`.
            tmp_output = torch.empty(
                size=(num_seqs, num_heads, max_num_partitions, head_size),
                dtype=output.dtype,
                device=output.device,
            )
            exp_sums = torch.empty(
                size=(num_seqs, num_heads, max_num_partitions),
                dtype=torch.float32,
                device=output.device,
            )
            max_logits = torch.empty_like(exp_sums)
            ops.paged_attention_v2(
                output,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                key_cache,
                value_cache,
                num_kv_heads,
                scale,
                block_tables,
                seq_lens,
                block_size,
                max_seq_len,
                alibi_slopes,
                kv_cache_dtype,
                kv_scale,
            )
        return output

    @staticmethod
    def forward_prefix(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        subquery_start_loc: torch.Tensor,
        seq_lens_tensor: torch.Tensor,
        context_lens: torch.Tensor,
        max_query_len: int,
        alibi_slopes: Optional[torch.Tensor],
        sliding_window: Optional[int],
    ) -> torch.Tensor:
        """Run prefix (context) attention and return the output tensor.

        Thin wrapper over ``context_attention_fwd``. Arguments are forwarded
        unchanged except ``subquery_start_loc``, whose last element is
        dropped before the call.

        Returns:
            A new tensor with the same shape and dtype as ``query``. It is
            allocated uninitialized here and is expected to be filled in by
            ``context_attention_fwd``.
        """
        output = torch.empty_like(query)
        context_attention_fwd(
            query,
            key,
            value,
            output,
            key_cache,
            value_cache,
            block_tables,
            # subquery_start_loc is (batch_size + 1,)
            subquery_start_loc[:-1],
            seq_lens_tensor,
            context_lens,
            max_query_len,
            alibi_slopes,
            sliding_window,
        )
        return output

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Swap blocks from ``src_kv_cache`` into ``dst_kv_cache``.

        Calls ``ops.swap_blocks`` twice with the same ``src_to_dst`` mapping:
        first on the key caches (index 0), then on the value caches
        (index 1).
        """
        src_key_cache = src_kv_cache[0]
        dst_key_cache = dst_kv_cache[0]
        ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)

        src_value_cache = src_kv_cache[1]
        dst_value_cache = dst_kv_cache[1]
        ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Copy blocks within each of the given KV caches.

        Splits every cache in ``kv_caches`` into its key part (index 0) and
        value part (index 1) and hands both lists, together with
        ``src_to_dists``, to ``ops.copy_blocks``.
        """
        key_caches = [kv_cache[0] for kv_cache in kv_caches]
        value_caches = [kv_cache[1] for kv_cache in kv_caches]
        ops.copy_blocks(key_caches, value_caches, src_to_dists)