"""Multi-head attention."""
from typing import List, Optional

import torch
import torch.nn as nn
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
                                         LowerTriangularMaskWithTensorBias)

from vllm._C import ops
from vllm._C import cache_ops
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.triton_kernel.prefix_prefill import (
    context_attention_fwd)
from vllm.utils import is_hip

_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512


class PagedAttention(nn.Module):
    """MHA/MQA/GQA layer with PagedAttention.

    This class takes query, key, and value tensors as input. The input tensors
    can either contain prompt tokens or generation tokens.
    The class does the following:

    1. Reshape and store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention using either
        xformers or the PagedAttention custom op.
    3. Return the output tensor.
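
    Example (illustrative values; scale is typically head_size ** -0.5):
        attn = PagedAttention(num_heads=32, head_size=128,
                              scale=128 ** -0.5, num_kv_heads=8)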
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.sliding_window = sliding_window
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        if self.head_size not in _SUPPORTED_HEAD_SIZES:
            raise ValueError(f"head_size ({self.head_size}) is not supported. "
                             f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: Optional[torch.Tensor],
        value_cache: Optional[torch.Tensor],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """PagedAttention forward pass.

        Args:
            query: shape = [batch_size, seq_len, num_heads * head_size]
            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for the inputs.
        Returns:
            shape = [batch_size, seq_len, num_heads * head_size]
        """
        batch_size, seq_len, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        # Reshape the keys and values and store them in the cache.
        # If key_cache and value_cache are not provided, the new key and value
        # vectors will not be cached. This happens during the initial memory
        # profiling run.
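        # Each slot_mapping entry is the flat index of a token's slot in the
        # paged cache, i.e. block_number * block_size + block_offset.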
        if key_cache is not None and value_cache is not None:
            cache_ops.reshape_and_cache(
                key,
                value,
                key_cache,
                value_cache,
                input_metadata.slot_mapping.flatten(),
                input_metadata.kv_cache_dtype,
            )

        if input_metadata.is_prompt:
            # Prompt run.
            if self.num_kv_heads != self.num_heads:
                # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
                # project the key and value tensors to the desired number of
                # heads.
                # TODO(woosuk): Use MQA/GQA kernels for higher performance.
                query = query.view(query.shape[0], self.num_kv_heads,
                                   self.num_queries_per_kv, query.shape[-1])
                key = key[:, :,
                          None, :].expand(key.shape[0], self.num_kv_heads,
                                          self.num_queries_per_kv,
                                          key.shape[-1])
                value = value[:, :, None, :].expand(value.shape[0],
                                                    self.num_kv_heads,
                                                    self.num_queries_per_kv,
                                                    value.shape[-1])
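                # For example, with 32 query heads and 8 KV heads the query
                # becomes [num_tokens, 8, 4, head_size], and key/value are
                # expanded (as views, no copy) to the same grouped layout.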
            # normal attention
            if (key_cache is None or value_cache is None
                    or input_metadata.block_tables.numel() == 0):
                # Set attention bias if not provided. This typically happens at
                # the very first attention layer of every iteration.
                # FIXME(woosuk): This is a hack.
                if input_metadata.attn_bias is None:
                    if self.alibi_slopes is None:
                        attn_bias = BlockDiagonalCausalMask.from_seqlens(
                            [seq_len] * batch_size)
                        if self.sliding_window is not None:
                            attn_bias = attn_bias.make_local_attention(
                                self.sliding_window)
                        input_metadata.attn_bias = attn_bias
                    else:
                        input_metadata.attn_bias = _make_alibi_bias(
                            self.alibi_slopes, self.num_kv_heads, batch_size,
                            seq_len, query.dtype)

                # TODO(woosuk): Too many view operations. Let's try to reduce
                # them in the future for code readability.
                if self.alibi_slopes is None:
                    query = query.unsqueeze(0)
                    key = key.unsqueeze(0)
                    value = value.unsqueeze(0)
                else:
                    query = query.unflatten(0, (batch_size, seq_len))
                    key = key.unflatten(0, (batch_size, seq_len))
                    value = value.unflatten(0, (batch_size, seq_len))

                # On ROCm (is_hip()), explicitly select xformers' FlashAttention
                # forward op; on CUDA, let xformers pick the op automatically.
                out = xops.memory_efficient_attention_forward(
                    query,
                    key,
                    value,
                    attn_bias=input_metadata.attn_bias,
                    p=0.0,
                    scale=self.scale,
                    op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if
                    (is_hip()) else None,
                )
                output = out.view_as(query)
            else:
                # prefix-enabled attention
                output = torch.empty_like(query)
                context_attention_fwd(
                    query,
                    key,
                    value,
                    output,
                    key_cache,
                    value_cache,
                    input_metadata.block_tables,  # [BS, max_block_per_request]
                    input_metadata.start_loc,
                    input_metadata.prompt_lens,
                    input_metadata.context_lens,
                    input_metadata.max_seq_len,
                    getattr(self, "alibi_slopes", None),
                )

        else:
            # Decoding run.
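            # Each sequence contributes a single new token per decoding step,
            # so `query` is [num_seqs, num_heads, head_size] after the view
            # above; keys/values are gathered from the cache via block tables.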
            output = _paged_attention(
                query,
                key_cache,
                value_cache,
                input_metadata,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
            )

        # Reshape the output tensor.
        return output.view(batch_size, seq_len, hidden_size)


def _make_alibi_bias(
    alibi_slopes: torch.Tensor,
    num_kv_heads: int,
    batch_size: int,
    seq_len: int,
    dtype: torch.dtype,
) -> LowerTriangularMaskWithTensorBias:
    bias = torch.arange(seq_len, dtype=dtype, device="cuda")
    # NOTE(zhuohan): HF uses
    #     `bias = bias[None, :].repeat(prompt_len, 1)`
    # here. We find that both biases give the same results, but
    # the bias below more accurately follows the original ALiBi
    # paper.
    bias = bias[None, :] - bias[:, None]
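    # For example, with seq_len = 3 this gives
    #     [[ 0,  1,  2],
    #      [-1,  0,  1],
    #      [-2, -1,  0]]
    # where entry (i, j) = j - i; the positive upper triangle is later removed
    # by the causal mask in LowerTriangularMaskWithTensorBias.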

    # When using custom attention bias, xformers requires the bias to
    # be sliced from a tensor whose length is a multiple of 8.
    padded_len = (seq_len + 7) // 8 * 8
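    # e.g. seq_len = 13 -> padded_len = 16; only the first seq_len columns of
    # the padded tensor below are filled with the bias.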
    num_heads = alibi_slopes.shape[0]
    bias = torch.empty(
        batch_size,
        num_heads,
        seq_len,
        padded_len,
        device=alibi_slopes.device,
        dtype=dtype,
    )[:, :, :, :seq_len].copy_(bias)
    bias.mul_(alibi_slopes[:, None, None])
    if num_heads != num_kv_heads:
        bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
    attn_bias = LowerTriangularMaskWithTensorBias(bias)
    return attn_bias


def _paged_attention(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    input_metadata: InputMetadata,
    num_kv_heads: int,
    scale: float,
    alibi_slopes: Optional[torch.Tensor],
) -> torch.Tensor:
    output = torch.empty_like(query)

    block_size = value_cache.shape[3]
    num_seqs, num_heads, head_size = query.shape
    max_num_partitions = (
        (input_metadata.max_context_len + _PARTITION_SIZE - 1) //
        _PARTITION_SIZE)
    # NOTE(woosuk): We use a simple heuristic to decide whether to use
    # PagedAttention V1 or V2. If the number of partitions is 1, we use
    # V1 to avoid the overhead of reduction. Also, if the number of
    # sequences or heads is large, we use V1 since there is enough work
    # to parallelize.
    # TODO(woosuk): Tune this heuristic.
    # For context len > 8192, use V2 kernel to avoid shared memory shortage.
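    # For example, max_context_len = 1024 gives ceil(1024 / 512) = 2
    # partitions, so V1 is chosen only when num_seqs * num_heads > 512;
    # otherwise V2 splits the work across partitions.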
    use_v1 = input_metadata.max_context_len <= 8192 and (
        max_num_partitions == 1 or num_seqs * num_heads > 512)
    if use_v1:
        # Run PagedAttention V1.
        ops.paged_attention_v1(
            output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            input_metadata.block_tables,
            input_metadata.context_lens,
            block_size,
            input_metadata.max_context_len,
            alibi_slopes,
            input_metadata.kv_cache_dtype,
        )
    else:
        # Run PagedAttention V2.
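        # V2 splits each sequence's context into partitions of _PARTITION_SIZE
        # tokens, writes per-partition results to tmp_output, and reduces them
        # with exp_sums and max_logits (a log-sum-exp style reduction).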
        assert _PARTITION_SIZE % block_size == 0
        tmp_output = torch.empty(
            size=(num_seqs, num_heads, max_num_partitions, head_size),
            dtype=output.dtype,
            device=output.device,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_heads, max_num_partitions),
            dtype=torch.float32,
            device=output.device,
        )
        max_logits = torch.empty_like(exp_sums)
        ops.paged_attention_v2(
            output,
            exp_sums,
            max_logits,
            tmp_output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            input_metadata.block_tables,
            input_metadata.context_lens,
            block_size,
            input_metadata.max_context_len,
            alibi_slopes,
            input_metadata.kv_cache_dtype,
        )
    return output