"""Multi-head attention."""
from typing import List, Optional

import torch
import torch.nn as nn
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
                                         LowerTriangularMaskWithTensorBias)

from vllm._C import ops
from vllm._C import cache_ops
from vllm.model_executor.input_metadata import InputMetadata
from vllm.utils import is_hip

_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512


class PagedAttention(nn.Module):
    """MHA/MQA/GQA layer with PagedAttention.

    This class takes query, key, and value tensors as input. The input tensors
    can either contain prompt tokens or generation tokens.
    The class does the following:

    1. Reshape and store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention using either
        xformers or the PagedAttention custom op.
    3. Return the output tensor.
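
    Example (illustrative sketch only; the KV cache tensors and the
    InputMetadata object are assumed to be built by the serving engine):

        attn = PagedAttention(num_heads=8, head_size=64, scale=64 ** -0.5)
        # query: [batch_size, seq_len, 8 * 64]; key/value use num_kv_heads.
        output = attn(query, key, value, key_cache, value_cache,
                      input_metadata)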
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.sliding_window = sliding_window
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
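        # The slopes are registered as a non-persistent buffer so they move
        # with the module across .to()/.cuda() calls without being saved in
        # the state dict.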

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
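        # For example, 32 query heads with 8 KV heads (GQA) gives
        # num_queries_per_kv = 4; with num_kv_heads == num_heads this is 1
        # and the layer is plain MHA.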

        if self.head_size not in _SUPPORTED_HEAD_SIZES:
            raise ValueError(f"head_size ({self.head_size}) is not supported. "
                             f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: Optional[torch.Tensor],
        value_cache: Optional[torch.Tensor],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """PagedAttention forward pass.

        Args:
            query: shape = [batch_size, seq_len, num_heads * head_size]
            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for the inputs.
        Returns:
            shape = [batch_size, seq_len, num_heads * head_size]
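
        Note: in the key cache layout, `x` is the number of key elements
        packed along the last dimension for vectorized access in the custom
        kernels (typically 16 bytes worth of elements, e.g. 8 for fp16).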
        """
        batch_size, seq_len, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)
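        # The batch and sequence dimensions are collapsed into a single token
        # dimension of size batch_size * seq_len; the tensors are reshaped
        # back as needed by the prompt and decode paths below.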

        # Reshape the keys and values and store them in the cache.
        # If key_cache and value_cache are not provided, the new key and value
        # vectors will not be cached. This happens during the initial memory
        # profiling run.
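        # Conceptually (ignoring the vectorized key packing), for each token i
        # with slot s = slot_mapping[i]:
        #     block, offset = divmod(s, block_size)
        #     key_cache[block, :, :, offset, :]  <- key[i]
        #     value_cache[block, :, :, offset]   <- value[i]
        # The custom kernel performs this scatter in a single pass.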
        if key_cache is not None and value_cache is not None:
            cache_ops.reshape_and_cache(
                key,
                value,
                key_cache,
                value_cache,
                input_metadata.slot_mapping.flatten(),
            )

        if input_metadata.is_prompt:
            # Prompt run.
            if self.num_kv_heads != self.num_heads:
                # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
                # project the key and value tensors to the desired number of
                # heads.
                # TODO(woosuk): Use MQA/GQA kernels for higher performance.
                query = query.view(query.shape[0], self.num_kv_heads,
                                   self.num_queries_per_kv, query.shape[-1])
                key = key[:, :,
                          None, :].expand(key.shape[0], self.num_kv_heads,
                                          self.num_queries_per_kv,
                                          key.shape[-1])
                value = value[:, :, None, :].expand(value.shape[0],
                                                    self.num_kv_heads,
                                                    self.num_queries_per_kv,
                                                    value.shape[-1])
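                # NOTE: the view()/expand() calls above do not copy data; each
                # KV head is broadcast as a view across its group of query
                # heads.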

            # Set attention bias if not provided. This typically happens at
            # the very first attention layer of every iteration.
            # FIXME(woosuk): This is a hack.
            if input_metadata.attn_bias is None:
                if self.alibi_slopes is None:
                    attn_bias = BlockDiagonalCausalMask.from_seqlens(
                        [seq_len] * batch_size)
                    if self.sliding_window is not None:
                        attn_bias = attn_bias.make_local_attention(
                            self.sliding_window)
                    input_metadata.attn_bias = attn_bias
                else:
                    input_metadata.attn_bias = _make_alibi_bias(
                        self.alibi_slopes, self.num_kv_heads, batch_size,
                        seq_len, query.dtype)

            # TODO(woosuk): Too many view operations. Let's try to reduce them
            # in the future for code readability.
            if self.alibi_slopes is None:
                query = query.unsqueeze(0)
                key = key.unsqueeze(0)
                value = value.unsqueeze(0)
            else:
                query = query.unflatten(0, (batch_size, seq_len))
                key = key.unflatten(0, (batch_size, seq_len))
                value = value.unflatten(0, (batch_size, seq_len))

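            # NOTE: `op` below explicitly selects the flash-attention forward
            # op on ROCm (is_hip()); with op=None on CUDA, xformers dispatches
            # to the best available backend automatically.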
            out = xops.memory_efficient_attention_forward(
                query,
                key,
                value,
                attn_bias=input_metadata.attn_bias,
                p=0.0,
                scale=self.scale,
                op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if
                (is_hip()) else None,
            )
            output = out.view_as(query)
        else:
            # Decoding run.
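            # Each sequence contributes a single new query token here; the
            # keys and values of its full context are read from the paged KV
            # cache via input_metadata.block_tables inside the custom kernel.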
            output = _paged_attention(
                query,
                key_cache,
                value_cache,
                input_metadata,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
            )

        # Reshape the output tensor.
        return output.view(batch_size, seq_len, hidden_size)


def _make_alibi_bias(
    alibi_slopes: torch.Tensor,
    num_kv_heads: int,
    batch_size: int,
    seq_len: int,
    dtype: torch.dtype,
) -> LowerTriangularMaskWithTensorBias:
    bias = torch.arange(seq_len, dtype=dtype, device="cuda")
    # NOTE(zhuohan): HF uses
    #     `bias = bias[None, :].repeat(prompt_len, 1)`
    # here. We find that both biases give the same results, but
    # the bias below more accurately follows the original ALiBi
    # paper.
    bias = bias[None, :] - bias[:, None]
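    # For example, with seq_len = 3 this yields the relative positions
    #     [[ 0,  1,  2],
    #      [-1,  0,  1],
    #      [-2, -1,  0]]
    # which are scaled per head by the ALiBi slopes below; the causal part of
    # LowerTriangularMaskWithTensorBias hides the positive (future) entries.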

    # When using custom attention bias, xformers requires the bias to
    # be sliced from a tensor whose length is a multiple of 8.
    padded_len = (seq_len + 7) // 8 * 8
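    # e.g. seq_len = 10 -> padded_len = 16; only the first seq_len columns of
    # the padded buffer are filled below.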
    num_heads = alibi_slopes.shape[0]
    bias = torch.empty(
        batch_size,
        num_heads,
        seq_len,
        padded_len,
        device=alibi_slopes.device,
        dtype=dtype,
    )[:, :, :, :seq_len].copy_(bias)
    bias.mul_(alibi_slopes[:, None, None])
    if num_heads != num_kv_heads:
        bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
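        # The head dimension is split into (num_kv_heads, num_queries_per_kv)
        # so the bias matches the 5-D grouped-query layout that forward()
        # passes to xformers.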
    attn_bias = LowerTriangularMaskWithTensorBias(bias)
    return attn_bias


def _paged_attention(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    input_metadata: InputMetadata,
    num_kv_heads: int,
    scale: float,
    alibi_slopes: Optional[torch.Tensor],
) -> torch.Tensor:
    output = torch.empty_like(query)

    block_size = value_cache.shape[3]
    num_seqs, num_heads, head_size = query.shape
    max_num_partitions = (
        (input_metadata.max_context_len + _PARTITION_SIZE - 1) //
        _PARTITION_SIZE)
    # NOTE(woosuk): We use a simple heuristic to decide whether to use
    # PagedAttention V1 or V2. If the number of partitions is 1, we use
    # V1 to avoid the overhead of reduction. Also, if the number of
    # sequences or heads is large, we use V1 since there is enough work
    # to parallelize.
    # TODO(woosuk): Tune this heuristic.
    # For context len > 8192, use V2 kernel to avoid shared memory shortage.
    use_v1 = input_metadata.max_context_len <= 8192 and (
        max_num_partitions == 1 or num_seqs * num_heads > 512)
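    # For example, with max_context_len = 1024 and _PARTITION_SIZE = 512,
    # max_num_partitions is 2, so V1 is only chosen when
    # num_seqs * num_heads > 512.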
    if use_v1:
        # Run PagedAttention V1.
        ops.paged_attention_v1(
            output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            input_metadata.block_tables,
            input_metadata.context_lens,
            block_size,
            input_metadata.max_context_len,
            alibi_slopes,
        )
    else:
        # Run PagedAttention V2.
        assert _PARTITION_SIZE % block_size == 0
        tmp_output = torch.empty(
            size=(num_seqs, num_heads, max_num_partitions, head_size),
            dtype=output.dtype,
            device=output.device,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_heads, max_num_partitions),
            dtype=torch.float32,
            device=output.device,
        )
        max_logits = torch.empty_like(exp_sums)
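        # V2 splits each sequence's context into _PARTITION_SIZE-token chunks,
        # computes partial attention per chunk into tmp_output (with exp_sums
        # and max_logits kept for softmax rescaling), and then reduces across
        # chunks into `output`.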
        ops.paged_attention_v2(
            output,
            exp_sums,
            max_logits,
            tmp_output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            input_metadata.block_tables,
            input_metadata.context_lens,
            block_size,
            input_metadata.max_context_len,
            alibi_slopes,
        )
    return output