"""Multi-head attention."""
from typing import List, Optional

import torch
import torch.nn as nn
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
                                         LowerTriangularMaskWithTensorBias)

from vllm._C import ops
from vllm._C import cache_ops
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.triton_kernel.prefix_prefill import (
    context_attention_fwd)
from vllm.utils import is_hip

_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512
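# The V2 kernel splits each sequence's context into partitions of this size,
# computes attention over each partition in parallel, and then reduces the
# partial results (see `_paged_attention` below).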


class PagedAttention(nn.Module):
    """MHA/MQA/GQA layer with PagedAttention.

    This class takes query, key, and value tensors as input. The input tensors
    can either contain prompt tokens or generation tokens.
    The class does the following:

    1. Reshape and store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention using either
        xformers or the PagedAttention custom op.
    3. Return the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.sliding_window = sliding_window
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
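        # Each KV head is shared by a group of query heads. For example, with
        # num_heads=32 and num_kv_heads=8, every KV head serves 4 query heads
        # (MHA: num_kv_heads == num_heads; MQA: num_kv_heads == 1).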

        if self.head_size not in _SUPPORTED_HEAD_SIZES:
            raise ValueError(f"head_size ({self.head_size}) is not supported. "
                             f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: Optional[torch.Tensor],
        value_cache: Optional[torch.Tensor],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """PagedAttention forward pass.

        Args:
            query: shape = [batch_size, seq_len, num_heads * head_size]
            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for the inputs.
        Returns:
            shape = [batch_size, seq_len, num_heads * head_size]
        """
        batch_size, seq_len, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        # Reshape the keys and values and store them in the cache.
        # If key_cache and value_cache are not provided, the new key and value
        # vectors will not be cached. This happens during the initial memory
        # profiling run.
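        # slot_mapping below gives each token's flat slot index in the paged
        # KV cache, i.e. block_number * block_size + block_offset.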
        if key_cache is not None and value_cache is not None:
            cache_ops.reshape_and_cache(
                key,
                value,
                key_cache,
                value_cache,
                input_metadata.slot_mapping.flatten(),
            )

        if input_metadata.is_prompt:
            # Prompt run.
            if self.num_kv_heads != self.num_heads:
                # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
                # project the key and value tensors to the desired number of
                # heads.
                # TODO(woosuk): Use MQA/GQA kernels for higher performance.
                query = query.view(query.shape[0], self.num_kv_heads,
                                   self.num_queries_per_kv, query.shape[-1])
                key = key[:, :,
                          None, :].expand(key.shape[0], self.num_kv_heads,
                                          self.num_queries_per_kv,
                                          key.shape[-1])
                value = value[:, :, None, :].expand(value.shape[0],
                                                    self.num_kv_heads,
                                                    self.num_queries_per_kv,
                                                    value.shape[-1])
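                # After the expansion above, query, key, and value all have
                # shape [num_tokens, num_kv_heads, num_queries_per_kv,
                # head_size], so xformers sees matching head counts.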
            # Normal attention.
            if (key_cache is None or value_cache is None
                    or input_metadata.block_tables.numel() == 0):
                # Set attention bias if not provided. This typically happens at
                # the very first attention layer of every iteration.
                # FIXME(woosuk): This is a hack.
                if input_metadata.attn_bias is None:
                    if self.alibi_slopes is None:
                        attn_bias = BlockDiagonalCausalMask.from_seqlens(
                            [seq_len] * batch_size)
                        if self.sliding_window is not None:
                            attn_bias = attn_bias.make_local_attention(
                                self.sliding_window)
                        input_metadata.attn_bias = attn_bias
                    else:
                        input_metadata.attn_bias = _make_alibi_bias(
                            self.alibi_slopes, self.num_kv_heads, batch_size,
                            seq_len, query.dtype)

                # TODO(woosuk): Too many view operations. Let's try to reduce
                # them in the future for code readability.
                if self.alibi_slopes is None:
                    query = query.unsqueeze(0)
                    key = key.unsqueeze(0)
                    value = value.unsqueeze(0)
                else:
                    query = query.unflatten(0, (batch_size, seq_len))
                    key = key.unflatten(0, (batch_size, seq_len))
                    value = value.unflatten(0, (batch_size, seq_len))

                out = xops.memory_efficient_attention_forward(
                    query,
                    key,
                    value,
                    attn_bias=input_metadata.attn_bias,
                    p=0.0,
                    scale=self.scale,
                    op=(xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0]
                        if is_hip() else None),
                )
                output = out.view_as(query)
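                # `out` retains the query's (4-D or 5-D) layout; forward()
                # flattens it back to [batch_size, seq_len, hidden_size] on
                # return.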
            else:
                # Prefix-enabled attention.
                output = torch.empty_like(query)
                context_attention_fwd(
                    query,
                    key,
                    value,
                    output,
                    key_cache,
                    value_cache,
                    input_metadata.block_tables,  # [BS, max_block_per_request]
                    input_metadata.start_loc,
                    input_metadata.prompt_lens,
                    input_metadata.context_lens,
                    input_metadata.max_seq_len,
                    getattr(self, "alibi_slopes", None),
                )

        else:
            # Decoding run.
            output = _paged_attention(
                query,
                key_cache,
                value_cache,
                input_metadata,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
            )

        # Reshape the output tensor.
        return output.view(batch_size, seq_len, hidden_size)


def _make_alibi_bias(
    alibi_slopes: torch.Tensor,
    num_kv_heads: int,
    batch_size: int,
    seq_len: int,
    dtype: torch.dtype,
) -> LowerTriangularMaskWithTensorBias:
    bias = torch.arange(seq_len, dtype=dtype, device="cuda")
    # NOTE(zhuohan): HF uses
    #     `bias = bias[None, :].repeat(prompt_len, 1)`
    # here. We find that both biases give the same results, but
    # the bias below more accurately follows the original ALiBi
    # paper.
    bias = bias[None, :] - bias[:, None]
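    # For example, seq_len=3 yields the relative-position matrix
    # [[ 0,  1,  2],
    #  [-1,  0,  1],
    #  [-2, -1,  0]].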

    # When using custom attention bias, xformers requires the bias to
    # be sliced from a tensor whose length is a multiple of 8.
    padded_len = (seq_len + 7) // 8 * 8
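    # e.g. seq_len=13 -> padded_len=16; only the first seq_len columns are
    # filled below, the extra columns merely satisfy the stride requirement.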
    num_heads = alibi_slopes.shape[0]
    bias = torch.empty(
        batch_size,
        num_heads,
        seq_len,
        padded_len,
        device=alibi_slopes.device,
        dtype=dtype,
    )[:, :, :, :seq_len].copy_(bias)
    bias.mul_(alibi_slopes[:, None, None])
    if num_heads != num_kv_heads:
        bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
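        # [batch, num_heads, seq_len, seq_len] is unflattened to
        # [batch, num_kv_heads, num_queries_per_kv, seq_len, seq_len] to match
        # the 5-D query layout used for MQA/GQA in the prompt path.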
    attn_bias = LowerTriangularMaskWithTensorBias(bias)
    return attn_bias


def _paged_attention(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    input_metadata: InputMetadata,
    num_kv_heads: int,
    scale: float,
    alibi_slopes: Optional[torch.Tensor],
) -> torch.Tensor:
    output = torch.empty_like(query)

    block_size = value_cache.shape[3]
    num_seqs, num_heads, head_size = query.shape
    max_num_partitions = (
        (input_metadata.max_context_len + _PARTITION_SIZE - 1) //
        _PARTITION_SIZE)
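    # Ceiling division: the number of _PARTITION_SIZE-token partitions needed
    # to cover the longest context in the batch.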
    # NOTE(woosuk): We use a simple heuristic to decide whether to use
    # PagedAttention V1 or V2. If the number of partitions is 1, we use
    # V1 to avoid the overhead of reduction. Also, if the number of
    # sequences or heads is large, we use V1 since there is enough work
    # to parallelize.
    # TODO(woosuk): Tune this heuristic.
    # For context len > 8192, use V2 kernel to avoid shared memory shortage.
    use_v1 = input_metadata.max_context_len <= 8192 and (
        max_num_partitions == 1 or num_seqs * num_heads > 512)
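    # For example, max_context_len=1024 gives 2 partitions, so V1 is used only
    # if num_seqs * num_heads > 512 already provides enough parallelism.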
    if use_v1:
        # Run PagedAttention V1.
        ops.paged_attention_v1(
            output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            input_metadata.block_tables,
            input_metadata.context_lens,
            block_size,
            input_metadata.max_context_len,
            alibi_slopes,
        )
    else:
        # Run PagedAttention V2.
        assert _PARTITION_SIZE % block_size == 0
        tmp_output = torch.empty(
            size=(num_seqs, num_heads, max_num_partitions, head_size),
            dtype=output.dtype,
            device=output.device,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_heads, max_num_partitions),
            dtype=torch.float32,
            device=output.device,
        )
        max_logits = torch.empty_like(exp_sums)
        ops.paged_attention_v2(
            output,
            exp_sums,
            max_logits,
            tmp_output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            input_metadata.block_tables,
            input_metadata.context_lens,
            block_size,
            input_metadata.max_context_len,
            alibi_slopes,
        )
    return output
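

# Usage sketch (illustrative only; the shapes follow the forward() docstring
# and the parameter values below are hypothetical):
#
#     attn = PagedAttention(num_heads=32, head_size=128, scale=128 ** -0.5,
#                           num_kv_heads=8)
#     # query: [batch_size, seq_len, num_heads * head_size]
#     # key/value: [batch_size, seq_len, num_kv_heads * head_size]
#     # key_cache/value_cache may be None during the initial profiling run.
#     output = attn(query, key, value, key_cache, value_cache, input_metadata)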