"""Multi-head attention."""
from typing import List, Optional

import torch
import torch.nn as nn
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
                                         LowerTriangularMaskWithTensorBias)

from vllm._C import cache_ops
from vllm._C import ops
from vllm.model_executor.input_metadata import InputMetadata
from vllm.utils import is_hip


15
# Head sizes for which the custom paged-attention kernels have compiled
# specializations; `PagedAttention.__init__` rejects anything else.
_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]

# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512


class PagedAttention(nn.Module):
    """MHA/MQA/GQA layer with PagedAttention.

    This class takes query, key, and value tensors as input. The input tensors
    can either contain prompt tokens or generation tokens.
    The class does the following:

    1. Wait for the cache operations (e.g., swap, copy) to finish. The cache
        operations are issued by the cache engine before executing the forward
        pass of the model, and they are executed asynchronously.
    2. Reshape and store the input key and value tensors in the KV cache.
    3. Perform (multi-head/multi-query/grouped-query) attention using either
        xformers or the PagedAttention custom op.
    4. Return the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        """Configure the attention layer.

        Args:
            num_heads: number of query heads.
            head_size: per-head hidden dimension; must be one of
                _SUPPORTED_HEAD_SIZES.
            scale: softmax scaling factor (stored as float).
            num_kv_heads: number of key/value heads; defaults to num_heads
                (plain MHA). Must divide num_heads for MQA/GQA.
            alibi_slopes: per-head ALiBi slopes; None disables ALiBi.
            sliding_window: local-attention window size for prompt runs;
                None means full causal attention.
        """
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.sliding_window = sliding_window
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        # Buffer (not parameter) so it moves with .to()/.cuda();
        # persistent=False keeps it out of the state dict.
        self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)

        assert self.num_heads % self.num_kv_heads == 0
        # How many query heads share each KV head (1 for MHA).
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        if self.head_size not in _SUPPORTED_HEAD_SIZES:
            raise ValueError(f"head_size ({self.head_size}) is not supported. "
                             f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: Optional[torch.Tensor],
        value_cache: Optional[torch.Tensor],
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """PagedAttention forward pass.

        Args:
            query: shape = [batch_size, seq_len, num_heads * head_size]
            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for the inputs.
            cache_event: event to wait for the cache operations to finish.
        Returns:
            shape = [batch_size, seq_len, num_heads * head_size]
        """
        batch_size, seq_len, hidden_size = query.shape
        # Reshape the query, key, and value tensors: flatten batch and
        # sequence dims into one token dim, split hidden dim into heads.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)
        # One cache-slot index per token, matching the flattened layout.
        slot_mapping = input_metadata.slot_mapping.flatten()

        if cache_event is not None:
            cache_event.wait()

        # Reshape the keys and values and store them in the cache.
        # If key_cache and value_cache are not provided, the new key and value
        # vectors will not be cached. This happens during the initial memory
        # profiling run.
        if key_cache is not None and value_cache is not None:
            cache_ops.reshape_and_cache(
                key,
                value,
                key_cache,
                value_cache,
                slot_mapping,
            )

        if input_metadata.is_prompt:
            # Prompt run: attend over the in-flight tokens with xformers.
            if self.num_kv_heads != self.num_heads:
                # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
                # project the key and value tensors to the desired number of
                # heads.
                # TODO(woosuk): Use MQA/GQA kernels for higher performance.
                # expand() broadcasts KV heads across the query groups
                # without copying data.
                query = query.view(query.shape[0], self.num_kv_heads,
                                   self.num_queries_per_kv, query.shape[-1])
                key = key[:, :,
                          None, :].expand(key.shape[0], self.num_kv_heads,
                                          self.num_queries_per_kv,
                                          key.shape[-1])
                value = value[:, :, None, :].expand(value.shape[0],
                                                    self.num_kv_heads,
                                                    self.num_queries_per_kv,
                                                    value.shape[-1])

            # Set attention bias if not provided. This typically happens at
            # the first attention layer of every iteration; later layers
            # reuse the bias cached on input_metadata.
            # FIXME(woosuk): This is a hack.
            if input_metadata.attn_bias is None:
                if self.alibi_slopes is None:
                    attn_bias = BlockDiagonalCausalMask.from_seqlens(
                        [seq_len] * batch_size)
                    if self.sliding_window is not None:
                        attn_bias = attn_bias.make_local_attention(
                            self.sliding_window)
                    input_metadata.attn_bias = attn_bias
                else:
                    input_metadata.attn_bias = _make_alibi_bias(
                        self.alibi_slopes, self.num_kv_heads, batch_size,
                        seq_len, query.dtype)

            # TODO(woosuk): Too many view operations. Let's try to reduce them
            # in the future for code readability.
            # Without ALiBi, xformers takes one flattened "batch" of tokens;
            # with ALiBi, it needs explicit (batch, seq) dims to apply the
            # per-position bias.
            if self.alibi_slopes is None:
                query = query.unsqueeze(0)
                key = key.unsqueeze(0)
                value = value.unsqueeze(0)
            else:
                query = query.unflatten(0, (batch_size, seq_len))
                key = key.unflatten(0, (batch_size, seq_len))
                value = value.unflatten(0, (batch_size, seq_len))

            out = xops.memory_efficient_attention_forward(
                query,
                key,
                value,
                attn_bias=input_metadata.attn_bias,
                p=0.0,
                scale=self.scale,
                # NOTE(review): on ROCm, pin the flash-attention op
                # explicitly; elsewhere let xformers pick the backend.
                op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if
                (is_hip()) else None,
            )
            output = out.view_as(query)
        else:
            # Decoding run: one new token per sequence, reading past KV from
            # the paged cache via the custom kernel.
            output = _paged_attention(
                query,
                key_cache,
                value_cache,
                input_metadata,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
            )

        # Reshape the output tensor back to the caller's
        # [batch_size, seq_len, hidden_size] layout.
        return output.view(batch_size, seq_len, hidden_size)


def _make_alibi_bias(
    alibi_slopes: torch.Tensor,
    num_kv_heads: int,
    batch_size: int,
    seq_len: int,
    dtype: torch.dtype,
) -> LowerTriangularMaskWithTensorBias:
    """Build the ALiBi attention bias for a prompt run.

    Args:
        alibi_slopes: per-head slopes, shape = [num_heads].
        num_kv_heads: number of KV heads; if it differs from num_heads the
            head dim of the bias is grouped for GQA/MQA.
        batch_size: number of sequences in the batch.
        seq_len: prompt length shared by every sequence.
        dtype: dtype of the bias tensor (matches the query dtype).

    Returns:
        An xformers LowerTriangularMaskWithTensorBias wrapping the bias.
    """
    positions = torch.arange(seq_len, dtype=dtype, device="cuda")
    # NOTE(zhuohan): HF uses
    #     `bias = bias[None, :].repeat(prompt_len, 1)`
    # here. We find that both biases give the same results, but
    # the distance matrix below (entry [i, j] = j - i) more accurately
    # follows the original ALiBi paper.
    distances = positions[None, :] - positions[:, None]

    # When using custom attention bias, xformers requires the bias to
    # be sliced from a tensor whose length is a multiple of 8, so
    # allocate a padded buffer and fill only the leading seq_len columns.
    padded_len = ((seq_len + 7) // 8) * 8
    num_heads = alibi_slopes.shape[0]
    padded = torch.empty(
        batch_size,
        num_heads,
        seq_len,
        padded_len,
        device=alibi_slopes.device,
        dtype=dtype,
    )
    bias = padded[:, :, :, :seq_len].copy_(distances)
    # Scale each head's distances by its ALiBi slope, in place.
    bias.mul_(alibi_slopes[:, None, None])
    if num_heads != num_kv_heads:
        # Group the head dim as (kv_heads, queries_per_kv) for GQA/MQA.
        bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
    return LowerTriangularMaskWithTensorBias(bias)


def _paged_attention(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    input_metadata: InputMetadata,
    num_kv_heads: int,
    scale: float,
    alibi_slopes: Optional[torch.Tensor],
) -> torch.Tensor:
    """Run the PagedAttention decode kernel (V1 or V2).

    Args:
        query: shape = [num_seqs, num_heads, head_size].
        key_cache / value_cache: paged KV cache tensors.
        input_metadata: provides block_tables, context_lens and
            max_context_len.
        num_kv_heads: number of KV heads (MQA/GQA aware kernel argument).
        scale: softmax scaling factor.
        alibi_slopes: optional per-head ALiBi slopes.

    Returns:
        Attention output with the same shape as `query`.
    """
    output = torch.empty_like(query)
    num_seqs, num_heads, head_size = query.shape
    block_size = value_cache.shape[3]
    max_context_len = input_metadata.max_context_len
    max_num_partitions = (max_context_len + _PARTITION_SIZE -
                          1) // _PARTITION_SIZE

    # NOTE(woosuk): We use a simple heuristic to decide whether to use
    # PagedAttention V1 or V2. If the number of partitions is 1, we use
    # V1 to avoid the overhead of reduction. Also, if the number of
    # sequences or heads is large, we use V1 since there is enough work
    # to parallelize.
    # TODO(woosuk): Tune this heuristic.
    # For context len > 8192, use V2 kernel to avoid shared memory shortage.
    use_v1 = max_context_len <= 8192 and (max_num_partitions == 1
                                          or num_seqs * num_heads > 512)

    if use_v1:
        # Single-pass kernel: no partitioning, no reduction buffers.
        ops.paged_attention_v1(
            output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            input_metadata.block_tables,
            input_metadata.context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
        )
    else:
        # Two-pass kernel: per-partition results plus softmax statistics
        # are written to scratch buffers and reduced afterwards.
        assert _PARTITION_SIZE % block_size == 0
        partition_shape = (num_seqs, num_heads, max_num_partitions)
        exp_sums = torch.empty(
            size=partition_shape,
            dtype=torch.float32,
            device=output.device,
        )
        max_logits = torch.empty_like(exp_sums)
        tmp_output = torch.empty(
            size=(*partition_shape, head_size),
            dtype=output.dtype,
            device=output.device,
        )
        ops.paged_attention_v2(
            output,
            exp_sums,
            max_logits,
            tmp_output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            input_metadata.block_tables,
            input_metadata.context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
        )
    return output