"""Multi-head attention."""
from typing import List, Optional

import torch
import torch.nn as nn
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
                                         LowerTriangularMaskWithTensorBias)

from vllm._C import ops
from vllm._C import cache_ops
from vllm.model_executor.input_metadata import InputMetadata
from vllm.utils import is_hip

_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512


class PagedAttention(nn.Module):
    """MHA/MQA/GQA layer with PagedAttention.

    This class takes query, key, and value tensors as input. The input tensors
    can either contain prompt tokens or generation tokens.
    The class does the following:

    1. Wait for the cache operations (e.g., swap, copy) to finish. The cache
        operations are issued by the cache engine before executing the forward
        pass of the model, and they are executed asynchronously.
    2. Reshape and store the input key and value tensors in the KV cache.
    3. Perform (multi-head/multi-query/grouped-query) attention using either
        xformers or the PagedAttention custom op.
    4. Return the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.sliding_window = sliding_window
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.head_mapping = torch.repeat_interleave(
            torch.arange(self.num_kv_heads, dtype=torch.int32, device="cuda"),
            self.num_queries_per_kv)
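        # For example (illustrative values), with num_heads=8 and
        # num_kv_heads=2, head_mapping is [0, 0, 0, 0, 1, 1, 1, 1]: each group
        # of 4 query heads shares one KV head.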

        if self.head_size not in _SUPPORTED_HEAD_SIZES:
            raise ValueError(f"head_size ({self.head_size}) is not supported. "
                             f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: Optional[torch.Tensor],
        value_cache: Optional[torch.Tensor],
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """PagedAttention forward pass.

        Args:
            query: shape = [batch_size, seq_len, num_heads * head_size]
            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for the inputs.
            cache_event: event to wait for the cache operations to finish.

        Returns:
            shape = [batch_size, seq_len, num_heads * head_size]
        """
        batch_size, seq_len, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)
        slot_mapping = input_metadata.slot_mapping.flatten()
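        # slot_mapping holds, for every token, the flat index of its slot in
        # the paged KV cache (block number * block size + in-block offset).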

        if cache_event is not None:
            cache_event.wait()

        # Reshape the keys and values and store them in the cache.
        # If key_cache and value_cache are not provided, the new key and value
        # vectors will not be cached. This happens during the initial memory
        # profiling run.
        if key_cache is not None and value_cache is not None:
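            # The custom op scatters each token's key/value into the paged
            # cache at the position given by slot_mapping.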
            cache_ops.reshape_and_cache(
                key,
                value,
                key_cache,
                value_cache,
                slot_mapping,
            )

        if input_metadata.is_prompt:
            # Prompt run.
            if self.num_kv_heads != self.num_heads:
                # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
                # project the key and value tensors to the desired number of
                # heads.
                # TODO(woosuk): Use MQA/GQA kernels for higher performance.
                query = query.view(query.shape[0], self.num_kv_heads,
                                   self.num_queries_per_kv, query.shape[-1])
                key = key[:, :,
                          None, :].expand(key.shape[0], self.num_kv_heads,
                                          self.num_queries_per_kv,
                                          key.shape[-1])
                value = value[:, :, None, :].expand(value.shape[0],
                                                    self.num_kv_heads,
                                                    self.num_queries_per_kv,
                                                    value.shape[-1])
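                # After the expand, query/key/value all have shape
                # [num_tokens, num_kv_heads, num_queries_per_kv, head_size];
                # expand creates views, so no data is copied.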

            # Set attention bias if not provided. This typically happens at
            # the very first attention layer of every iteration.
            # FIXME(woosuk): This is a hack.
            if input_metadata.attn_bias is None:
                if self.alibi_slopes is None:
                    attn_bias = BlockDiagonalCausalMask.from_seqlens(
                        [seq_len] * batch_size)
                    if self.sliding_window is not None:
                        attn_bias = attn_bias.make_local_attention(
                            self.sliding_window)
                    input_metadata.attn_bias = attn_bias
                else:
                    input_metadata.attn_bias = _make_alibi_bias(
                        self.alibi_slopes, batch_size, seq_len, query.dtype)

            # TODO(woosuk): Too many view operations. Let's try to reduce them
            # in the future for code readability.
            if self.alibi_slopes is None:
                query = query.unsqueeze(0)
                key = key.unsqueeze(0)
                value = value.unsqueeze(0)
            else:
                query = query.unflatten(0, (batch_size, seq_len))
                key = key.unflatten(0, (batch_size, seq_len))
                value = value.unflatten(0, (batch_size, seq_len))

            out = xops.memory_efficient_attention_forward(
                query,
                key,
                value,
                attn_bias=input_metadata.attn_bias,
                p=0.0,
                scale=self.scale,
                op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if
                (is_hip()) else None,
            )
            output = out.view_as(query)
        else:
            # Decoding run.
            output = _paged_attention(
                query,
                key_cache,
                value_cache,
                input_metadata,
                self.head_mapping,
                self.scale,
                self.alibi_slopes,
            )

        # Reshape the output tensor.
        return output.view(batch_size, seq_len, hidden_size)


def _make_alibi_bias(
    alibi_slopes: torch.Tensor,
    batch_size: int,
    seq_len: int,
    dtype: torch.dtype,
) -> LowerTriangularMaskWithTensorBias:
    bias = torch.arange(seq_len, dtype=dtype)
    # NOTE(zhuohan): HF uses
    #     `bias = bias[None, :].repeat(prompt_len, 1)`
    # here. We find that both biases give the same results, but
    # the bias below more accurately follows the original ALiBi
    # paper.
    bias = bias[None, :] - bias[:, None]
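    # bias[i, j] = j - i, so together with the lower-triangular mask applied
    # below, each query position is penalized in proportion to its distance
    # from earlier key positions, scaled per head by its ALiBi slope.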
    bias = bias.to(alibi_slopes.device)

    # When using custom attention bias, xformers requires the bias to
    # be sliced from a tensor whose length is a multiple of 8.
    padded_len = (seq_len + 7) // 8 * 8
    bias = torch.empty(
        batch_size,
        alibi_slopes.shape[0],
        seq_len,
        padded_len,
        device=alibi_slopes.device,
        dtype=dtype,
    )[:, :, :, :seq_len].copy_(bias)
    bias.mul_(alibi_slopes[:, None, None])
    attn_bias = LowerTriangularMaskWithTensorBias(bias)
    return attn_bias


def _paged_attention(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    input_metadata: InputMetadata,
    head_mapping: torch.Tensor,
    scale: float,
    alibi_slopes: Optional[torch.Tensor],
) -> torch.Tensor:
    output = torch.empty_like(query)

    block_size = value_cache.shape[3]
    num_seqs, num_heads, head_size = query.shape
    max_num_partitions = (
        (input_metadata.max_context_len + _PARTITION_SIZE - 1) //
        _PARTITION_SIZE)
    # NOTE(woosuk): We use a simple heuristic to decide whether to use
    # PagedAttention V1 or V2. If the number of partitions is 1, we use
    # V1 to avoid the overhead of reduction. Also, if the number of
    # sequences or heads is large, we use V1 since there is enough work
    # to parallelize.
    # TODO(woosuk): Tune this heuristic.
    # For context len > 8192, use V2 kernel to avoid shared memory shortage.
    use_v1 = input_metadata.max_context_len <= 8192 and (
        max_num_partitions == 1 or num_seqs * num_heads > 512)
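    # e.g. (illustrative numbers) with max_context_len=1024 the context spans
    # 2 partitions of 512 tokens, so V2 is chosen unless num_seqs * num_heads
    # > 512 already provides enough parallelism for V1.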
    if use_v1:
        # Run PagedAttention V1.
        ops.paged_attention_v1(
            output,
            query,
            key_cache,
            value_cache,
            head_mapping,
            scale,
            input_metadata.block_tables,
            input_metadata.context_lens,
            block_size,
            input_metadata.max_context_len,
            alibi_slopes,
        )
    else:
        # Run PagedAttention V2.
        assert _PARTITION_SIZE % block_size == 0
        tmp_output = torch.empty(
            size=(num_seqs, num_heads, max_num_partitions, head_size),
            dtype=output.dtype,
            device=output.device,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_heads, max_num_partitions),
            dtype=torch.float32,
            device=output.device,
        )
        max_logits = torch.empty_like(exp_sums)
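        # The V2 kernel computes attention over each partition independently,
        # writing per-partition results into tmp_output along with exp_sums
        # and max_logits, then reduces across partitions into `output`.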
        ops.paged_attention_v2(
            output,
            exp_sums,
            max_logits,
            tmp_output,
            query,
            key_cache,
            value_cache,
            head_mapping,
            scale,
            input_metadata.block_tables,
            input_metadata.context_lens,
            block_size,
            input_metadata.max_context_len,
            alibi_slopes,
        )
    return output