"""Multi-head attention."""
from typing import List, Optional

import importlib.util
import torch
import torch.nn as nn
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
                                         LowerTriangularMaskWithTensorBias)

from vllm._C import ops
from vllm._C import cache_ops
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.triton_kernel.prefix_prefill import (
    context_attention_fwd)
from vllm.utils import is_hip

_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512


class PagedAttention(nn.Module):
    """MHA/MQA/GQA layer with PagedAttention.

    This class takes query, key, and value tensors as input. The input tensors
    can either contain prompt tokens or generation tokens.
    The class does the following:

    1. Reshape and store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention using either
        xformers or the PagedAttention custom op.
    3. Return the output tensor.
    """

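    # Hypothetical usage sketch (not part of the original module); argument
    # shapes follow forward()'s docstring below:
    #     attn = PagedAttention(num_heads=32, head_size=128, scale=128**-0.5)
    #     out = attn(query, key, value, key_cache, value_cache, input_metadata)
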
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.sliding_window = sliding_window
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        if self.head_size not in _SUPPORTED_HEAD_SIZES:
            raise ValueError(f"head_size ({self.head_size}) is not supported. "
                             f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")

        self.use_ref_attention = self.check_use_ref_attention()

    def check_use_ref_attention(self) -> bool:
        if not is_hip():
            return False
        # For ROCm, check whether flash attention is installed.
        # If it is not, fall back to the reference attention implementation
        # (use_ref_attention = True).
        return importlib.util.find_spec("flash_attn") is None

    def ref_masked_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
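        # Reference (unfused) causal attention, used on ROCm when flash-attn
        # is not installed: softmax(scale * Q @ K^T + causal_mask) @ V.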
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        seq_len, _, _ = query.shape
        attn_mask = torch.triu(torch.ones(seq_len,
                                          seq_len,
                                          dtype=query.dtype,
                                          device=query.device),
                               diagonal=1)
        attn_mask = attn_mask * torch.finfo(query.dtype).min

        attn_weights = self.scale * torch.einsum("qhd,khd->hqk", query,
                                                 key).float()
        attn_weights = attn_weights + attn_mask.float()
        attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
        out = torch.einsum("hqk,khd->qhd", attn_weights, value)
        return out

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: Optional[torch.Tensor],
        value_cache: Optional[torch.Tensor],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """PagedAttention forward pass.

        Args:
            query: shape = [batch_size, seq_len, num_heads * head_size]
            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for the inputs.
        Returns:
            shape = [batch_size, seq_len, num_heads * head_size]
        """
        batch_size, seq_len, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        # Reshape the keys and values and store them in the cache.
        # If key_cache and value_cache are not provided, the new key and value
        # vectors will not be cached. This happens during the initial memory
        # profiling run.
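        # slot_mapping gives, for each token, the flat slot index in the
        # paged KV cache that reshape_and_cache writes the new key/value to.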
        if key_cache is not None and value_cache is not None:
            cache_ops.reshape_and_cache(
                key,
                value,
                key_cache,
                value_cache,
                input_metadata.slot_mapping.flatten(),
                input_metadata.kv_cache_dtype,
            )

        if input_metadata.is_prompt:
            # normal attention
            if (key_cache is None or value_cache is None
                    or input_metadata.block_tables.numel() == 0):
                if self.num_kv_heads != self.num_heads:
                    # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
                    # project the key and value tensors to the desired number of
                    # heads.
                    # TODO(woosuk): Use MQA/GQA kernels for higher performance.
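                    # For example (hypothetical sizes), with num_heads=32 and
                    # num_kv_heads=8 the query below becomes
                    # [num_tokens, 8, 4, head_size], and key/value are
                    # expanded to the same layout so xformers sees plain MHA.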
                    query = query.view(query.shape[0], self.num_kv_heads,
                                       self.num_queries_per_kv,
                                       query.shape[-1])
                    key = key[:, :,
                              None, :].expand(key.shape[0], self.num_kv_heads,
                                              self.num_queries_per_kv,
                                              key.shape[-1])
                    value = value[:, :,
                                  None, :].expand(value.shape[0],
                                                  self.num_kv_heads,
                                                  self.num_queries_per_kv,
                                                  value.shape[-1])

                # Set the attention bias if it is not already provided. This
                # typically happens at the very first attention layer of every
                # iteration.
                # FIXME(woosuk): This is a hack.
                if input_metadata.attn_bias is None:
                    if self.alibi_slopes is None:
                        attn_bias = BlockDiagonalCausalMask.from_seqlens(
                            [seq_len] * batch_size)
                        if self.sliding_window is not None:
                            attn_bias = attn_bias.make_local_attention(
                                self.sliding_window)
                        input_metadata.attn_bias = attn_bias
                    else:
                        input_metadata.attn_bias = _make_alibi_bias(
                            self.alibi_slopes, self.num_kv_heads, batch_size,
                            seq_len, query.dtype)

                if self.use_ref_attention:
                    output = self.ref_masked_attention(
                        query,
                        key,
                        value,
                    )
                    # Using `view` here raises "RuntimeError: view size is
                    # not compatible with input tensor's size and stride (at
                    # least one dimension spans across two contiguous
                    # subspaces)", so use `reshape` instead.
                    return output.reshape(batch_size, seq_len, hidden_size)

                # TODO(woosuk): Too many view operations. Let's try to reduce
                # them in the future for code readability.
                if self.alibi_slopes is None:
                    query = query.unsqueeze(0)
                    key = key.unsqueeze(0)
                    value = value.unsqueeze(0)
                else:
                    query = query.unflatten(0, (batch_size, seq_len))
                    key = key.unflatten(0, (batch_size, seq_len))
                    value = value.unflatten(0, (batch_size, seq_len))

                out = xops.memory_efficient_attention_forward(
                    query,
                    key,
                    value,
                    attn_bias=input_metadata.attn_bias,
                    p=0.0,
                    scale=self.scale,
                    op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if
                    (is_hip()) else None,
                )
                output = out.view_as(query)
            else:
                # prefix-enabled attention
                output = torch.empty_like(query)
                context_attention_fwd(
                    query,
                    key,
                    value,
                    output,
                    key_cache,
                    value_cache,
                    input_metadata.block_tables,  # [BS, max_block_per_request]
                    input_metadata.start_loc,
                    input_metadata.prompt_lens,
                    input_metadata.context_lens,
                    input_metadata.max_seq_len,
                    getattr(self, "alibi_slopes", None),
                )

        else:
            # Decoding run.
            output = _paged_attention(
                query,
                key_cache,
                value_cache,
                input_metadata,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
            )

        # Reshape the output tensor.
        return output.view(batch_size, seq_len, hidden_size)


def _make_alibi_bias(
    alibi_slopes: torch.Tensor,
    num_kv_heads: int,
    batch_size: int,
    seq_len: int,
    dtype: torch.dtype,
) -> LowerTriangularMaskWithTensorBias:
    bias = torch.arange(seq_len, dtype=dtype)
    # NOTE(zhuohan): HF uses
    #     `bias = bias[None, :].repeat(prompt_len, 1)`
    # here. We find that both biases give the same results, but
    # the bias below more accurately follows the original ALiBi
    # paper.
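    # The subtraction below yields bias[i][j] = j - i; e.g. for a prompt of
    # length 4, row 2 is [-2, -1, 0, 1]. The lower-triangular mask applied at
    # the end of this function keeps only the past (non-positive) offsets.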
    bias = bias[None, :] - bias[:, None]

    # When using custom attention bias, xformers requires the bias to
    # be sliced from a tensor whose length is a multiple of 8.
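    # For example, seq_len=100 is padded up to padded_len=104.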
    padded_len = (seq_len + 7) // 8 * 8
    num_heads = alibi_slopes.shape[0]
    bias = torch.empty(
        batch_size,
        num_heads,
        seq_len,
        padded_len,
        device=alibi_slopes.device,
        dtype=dtype,
    )[:, :, :, :seq_len].copy_(bias)
    bias.mul_(alibi_slopes[:, None, None])
    if num_heads != num_kv_heads:
        bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
    attn_bias = LowerTriangularMaskWithTensorBias(bias)
    return attn_bias


def _paged_attention(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    input_metadata: InputMetadata,
    num_kv_heads: int,
    scale: float,
    alibi_slopes: Optional[torch.Tensor],
) -> torch.Tensor:
    output = torch.empty_like(query)

    block_size = value_cache.shape[3]
    num_seqs, num_heads, head_size = query.shape
    max_num_partitions = (
        (input_metadata.max_context_len + _PARTITION_SIZE - 1) //
        _PARTITION_SIZE)
    # NOTE(woosuk): We use a simple heuristic to decide whether to use
    # PagedAttention V1 or V2. If the number of partitions is 1, we use
    # V1 to avoid the overhead of reduction. Also, if the number of
    # sequences or heads is large, we use V1 since there is enough work
    # to parallelize.
    # TODO(woosuk): Tune this heuristic.
    # For context len > 8192, use V2 kernel to avoid shared memory shortage.
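    # For example, with max_context_len=1024 and _PARTITION_SIZE=512,
    # max_num_partitions is 2, so V1 is chosen only when
    # num_seqs * num_heads > 512.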
    use_v1 = input_metadata.max_context_len <= 8192 and (
        max_num_partitions == 1 or num_seqs * num_heads > 512)
    if use_v1:
        # Run PagedAttention V1.
        ops.paged_attention_v1(
            output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            input_metadata.block_tables,
            input_metadata.context_lens,
            block_size,
            input_metadata.max_context_len,
            alibi_slopes,
            input_metadata.kv_cache_dtype,
        )
    else:
        # Run PagedAttention V2.
        assert _PARTITION_SIZE % block_size == 0
        tmp_output = torch.empty(
            size=(num_seqs, num_heads, max_num_partitions, head_size),
            dtype=output.dtype,
            device=output.device,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_heads, max_num_partitions),
            dtype=torch.float32,
            device=output.device,
        )
        max_logits = torch.empty_like(exp_sums)
        ops.paged_attention_v2(
            output,
            exp_sums,
            max_logits,
            tmp_output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            input_metadata.block_tables,
            input_metadata.context_lens,
            block_size,
            input_metadata.max_context_len,
            alibi_slopes,
            input_metadata.kv_cache_dtype,
        )
    return output