"""Multi-head attention."""
from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
                                         LowerTriangularMaskWithTensorBias)

from vllm import attention_ops
from vllm import cache_ops
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.rotary_embedding import get_rope

_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512
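# In the V2 kernel, each sequence's context is split into partitions of this
# many tokens; e.g. a context length of 1500 spans ceil(1500 / 512) = 3
# partitions, whose partial results are then reduced into the final output.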


class PagedAttention(nn.Module):
    """GPT-style multi-head PagedAttention.

    This class takes query, key, and value tensors as input. The input tensors
    can contain either prompt tokens or generation tokens, in addition to
    paddings.

    The class does the following:
    1. Perform multi_query_kv_attention for the prompts. This operation does
        not use the KV cache.
    2. Wait for the cache operations (e.g., swap, copy) to finish. The cache
        operations are issued by the cache engine before executing the forward
        pass of the model, and they are executed asynchronously.
    3. Reshape and store the input key and value tensors in the KV cache.
    4. Perform single_query_cached_kv_attention for the generation tokens.
        This operation reads the previous key and value tensors from the KV
        cache.
    5. Return the output tensor.
    """

    def __init__(self,
                 num_heads: int,
                 head_size: int,
                 scale: float,
                 num_kv_heads: Optional[int] = None,
                 sliding_window: Optional[int] = None) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.sliding_window = sliding_window

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.head_mapping = torch.repeat_interleave(
            torch.arange(self.num_kv_heads, dtype=torch.int32, device="cuda"),
            self.num_queries_per_kv)
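        # head_mapping maps each query head to its KV head; e.g. with
        # num_kv_heads=2 and num_queries_per_kv=4 it is [0, 0, 0, 0, 1, 1, 1, 1].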

        if self.head_size not in _SUPPORTED_HEAD_SIZES:
            raise ValueError(f"head_size ({self.head_size}) is not supported. "
                             f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")

    def set_attn_bias(
        self,
        input_metadata: InputMetadata,
        dtype: torch.dtype,
    ) -> None:
        del dtype  # Unused.
        if input_metadata.attn_bias is not None:
            # Already set by a previous layer.
            return
        prompt_lens = [input_metadata.max_prompt_len
                       ] * input_metadata.num_prompts
        attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
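        # The block-diagonal mask lets each token attend causally within its
        # own prompt only, so all prompts can be packed into one attention call.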
        if self.sliding_window is not None:
            attn_bias = attn_bias.make_local_attention(self.sliding_window)
        input_metadata.attn_bias = attn_bias

    def multi_query_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Normal attention for the prompt tokens.

        Args:
            output: shape = [num_prompt_tokens, num_heads, head_size]
            query: shape = [num_prompt_tokens, num_heads, head_size]
            key: shape = [num_prompt_tokens, num_kv_heads, head_size]
            value: shape = [num_prompt_tokens, num_kv_heads, head_size]
            input_metadata: metadata for paged attention.
        """
        if self.num_kv_heads != self.num_heads:
            # Project the key and value tensors to the desired number of heads.
            key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
            value = torch.repeat_interleave(value,
                                            self.num_queries_per_kv,
                                            dim=1)
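            # After the interleave, key/value match the query's head count,
            # e.g. num_kv_heads=8 with num_queries_per_kv=4 yields 32 heads.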

        # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
        out = xops.memory_efficient_attention_forward(
            query.unsqueeze(0),
            key.unsqueeze(0),
            value.unsqueeze(0),
            attn_bias=input_metadata.attn_bias,
            p=0.0,
            scale=self.scale,
        )
        # TODO(woosuk): Unnecessary copy. Optimize.
        output.copy_(out.squeeze(0))
        return output

    def get_alibi_slopes(self) -> Optional[torch.Tensor]:
        """Returns the slopes for the alibi attention bias.

        Returns:
            slopes: shape = [num_heads]
        """
        return None

    def single_query_cached_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
        alibi_slopes: Optional[torch.Tensor],
    ) -> None:
        """PagedAttention for the generation tokens.

        Args:
            output: shape = [num_generation_tokens, num_heads, head_size]
            query: shape = [num_generation_tokens, num_heads, head_size]
            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for paged attention.
            alibi_slopes: shape = [num_heads]
        """
        block_size = value_cache.shape[3]
        num_seqs, num_heads, head_size = query.shape
        max_num_partitions = (
            (input_metadata.max_context_len + _PARTITION_SIZE - 1) //
            _PARTITION_SIZE)
        # NOTE(woosuk): We use a simple heuristic to decide whether to use
        # PagedAttention V1 or V2. If the number of partitions is 1, we use
        # V1 to avoid the overhead of reduction. Also, if the number of
        # sequences or heads is large, we use V1 since there is enough work
        # to parallelize.
        # TODO(woosuk): Tune this heuristic.
        # For context len > 8192, use V2 kernel to avoid shared memory shortage.
        use_v1 = input_metadata.max_context_len <= 8192 and (
            max_num_partitions == 1 or num_seqs * num_heads > 512)
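        # E.g. max_context_len=512 gives one partition, so V1 is used;
        # max_context_len=1024 with 1 sequence and 32 heads gives 2 partitions
        # and num_seqs * num_heads = 32 <= 512, so V2 is used.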
        if use_v1:
            # Run PagedAttention V1.
            attention_ops.paged_attention_v1(
                output,
                query,
                key_cache,
                value_cache,
                self.head_mapping,
                self.scale,
                input_metadata.block_tables,
                input_metadata.context_lens,
                block_size,
                input_metadata.max_context_len,
                alibi_slopes,
            )
        else:
            # Run PagedAttention V2.
            assert _PARTITION_SIZE % block_size == 0
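            # Each partition then covers a whole number of cache blocks, so a
            # partition's reads never straddle a block boundary.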
            tmp_output = torch.empty(
                size=(num_seqs, num_heads, max_num_partitions, head_size),
                dtype=output.dtype,
                device=output.device,
            )
            exp_sums = torch.empty(
                size=(num_seqs, num_heads, max_num_partitions),
                dtype=torch.float32,
                device=output.device,
            )
            max_logits = torch.empty_like(exp_sums)
            attention_ops.paged_attention_v2(
                output,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                key_cache,
                value_cache,
                self.head_mapping,
                self.scale,
                input_metadata.block_tables,
                input_metadata.context_lens,
                block_size,
                input_metadata.max_context_len,
                alibi_slopes,
            )

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: Optional[torch.Tensor],
        value_cache: Optional[torch.Tensor],
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """PagedAttention forward pass.

        NOTE: The query, key, and value tensors must be sliced from a qkv
        tensor of shape [batch_size, seq_len, 3 * num_heads * head_size].

        Args:
            query: shape = [batch_size, seq_len, num_heads * head_size]
            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for paged attention.
            cache_event: event to wait for the cache operations to finish.

        Returns:
            shape = [batch_size, seq_len, num_heads * head_size]
        """
        batch_size, seq_len, _ = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        # Pre-allocate the output tensor.
        output = torch.empty_like(query)

        # Compute the attention op for prompts.
        num_prompt_tokens = input_metadata.num_prompt_tokens
        if num_prompt_tokens > 0:
            # Prompt run.
            assert input_metadata.num_generation_tokens == 0
            self.set_attn_bias(input_metadata, dtype=query.dtype)
            self.multi_query_kv_attention(
                output,
                query,
                key,
                value,
                input_metadata,
            )

        # Wait until the cache op is done.
        if cache_event is not None:
            cache_event.wait()

        # Reshape the keys and values and store them in the cache.
        # When key_cache and value_cache are not provided, the new key
        # and value vectors will not be cached.
        if key_cache is not None and value_cache is not None:
            key_to_cache = key
            value_to_cache = value
            slot_mapping = input_metadata.slot_mapping.view(-1)
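            # Each slot index encodes one token's position in the paged cache,
            # i.e. block_number * block_size + offset_within_block.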
            if input_metadata.to_cache is not None:
                key_to_cache = key_to_cache[input_metadata.to_cache]
                value_to_cache = value_to_cache[input_metadata.to_cache]
                slot_mapping = slot_mapping[input_metadata.to_cache]

            cache_ops.reshape_and_cache(
                key_to_cache,
                value_to_cache,
                key_cache,
                value_cache,
                slot_mapping,
            )

        if input_metadata.num_generation_tokens > 0:
            # Decoding run.
            assert input_metadata.num_prompt_tokens == 0
            assert key_cache is not None and value_cache is not None, (
                "key_cache and value_cache must be provided when "
                "generating tokens.")
            # Compute the attention op for generation tokens.
            self.single_query_cached_kv_attention(output, query, key_cache,
                                                  value_cache, input_metadata,
                                                  self.get_alibi_slopes())

        # Reshape the output tensor.
        # NOTE(woosuk): The output tensor may include paddings.
        return output.view(batch_size, seq_len,
                           self.num_heads * self.head_size)


class PagedAttentionWithRoPE(PagedAttention):
    """PagedAttention with rotary positional embedding."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        rotary_dim: int,
        max_position: int = 8192,
        base: int = 10000,
        num_kv_heads: Optional[int] = None,
        is_neox_style: bool = True,
        rope_scaling: Optional[Dict[str, Any]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        super().__init__(num_heads,
                         head_size,
                         scale,
                         num_kv_heads,
                         sliding_window=sliding_window)
        self.rotary_emb = get_rope(head_size, rotary_dim, max_position, base,
                                   is_neox_style, rope_scaling)
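        # The returned module rotates channel pairs in the first `rotary_dim`
        # dimensions of each head by position-dependent angles.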

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """ PagedAttention forward pass with rotary embedding.

        Args:
            positions: shape = [batch_size, seq_len]
            query: shape = [batch_size, seq_len, num_heads * head_size]
            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for paged attention.
            cache_event: event to wait for the cache operations to finish.

        Returns:
            shape = [batch_size, seq_len, num_heads * head_size]
        """

        # Apply rotary embedding to the query and key before passing them
        # to the attention op.
        query, key = self.rotary_emb(positions, query, key)
        return super().forward(
            query,
            key,
            value,
            key_cache,
            value_cache,
            input_metadata,
            cache_event,
        )


class PagedAttentionWithALiBi(PagedAttention):
    """PagedAttention with ALiBi attention bias."""

    def __init__(self,
                 num_heads: int,
                 head_size: int,
                 scale: float,
                 slopes: List[float],
                 num_kv_heads: Optional[int] = None) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads)
        assert len(slopes) == num_heads

        slopes = torch.tensor(slopes, dtype=torch.float32)
        self.register_buffer("alibi_slopes", slopes, persistent=False)
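        # One slope per head; e.g. for 8 heads the canonical ALiBi slopes are
        # 1/2, 1/4, ..., 1/256 (a geometric sequence), supplied by the caller.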

    def set_attn_bias(self, input_metadata: InputMetadata,
                      dtype: torch.dtype) -> None:
        if input_metadata.attn_bias is not None:
            # Already set by a previous layer.
            return
        # Generates ALiBi mask based on the max prompt length.
        max_prompt_len = input_metadata.max_prompt_len
        bias = torch.arange(max_prompt_len, dtype=dtype)
        # NOTE(zhuohan): HF uses
        #     `bias = bias[None, :].repeat(prompt_len, 1)`
        # here. We find that both biases give the same results, but
        # the bias below more accurately follows the original ALiBi
        # paper.
        bias = bias[None, :] - bias[:, None]
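        # For max_prompt_len=3 this gives the relative positions
        #   [[ 0,  1,  2],
        #    [-1,  0,  1],
        #    [-2, -1,  0]];
        # the upper triangle is removed by the causal mask applied below.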
        bias = bias.to(self.alibi_slopes.device)

        # When using custom attention bias, xformers requires the bias to
        # be sliced from a tensor whose length is a multiple of 8.
        padded_len = (max_prompt_len + 7) // 8 * 8
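        # E.g. max_prompt_len=100 pads to 104; only the first 100 columns are
        # filled and sliced back out below.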
        bias = torch.empty(
            input_metadata.num_prompts,
            self.num_heads,
            max_prompt_len,
            padded_len,
            device=self.alibi_slopes.device,
            dtype=dtype,
        )[:, :, :, :max_prompt_len].copy_(bias)
        bias.mul_(self.alibi_slopes[:, None, None])
        attn_bias = LowerTriangularMaskWithTensorBias(bias)
        input_metadata.attn_bias = attn_bias

    def multi_query_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Attention with ALiBi bias for the prompt tokens.

        Args:
            output: shape = [num_prompt_tokens, num_heads, head_size]
            query: shape = [num_prompt_tokens, num_heads, head_size]
            key: shape = [num_prompt_tokens, num_kv_heads, head_size]
            value: shape = [num_prompt_tokens, num_kv_heads, head_size]
            input_metadata: metadata for paged attention.
        """
        if self.num_kv_heads != self.num_heads:
            # Project the key and value tensors to the desired number of heads.
            key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
            value = torch.repeat_interleave(value,
                                            self.num_queries_per_kv,
                                            dim=1)
        batch_size = input_metadata.num_prompts
        seq_len = input_metadata.max_prompt_len

        out = xops.memory_efficient_attention_forward(
            query.view(batch_size, seq_len, self.num_heads, self.head_size),
            key.view(batch_size, seq_len, self.num_heads, self.head_size),
            value.view(batch_size, seq_len, self.num_heads, self.head_size),
            attn_bias=input_metadata.attn_bias,
            p=0.0,
            scale=self.scale,
        )
        # TODO(woosuk): Unnecessary copy. Optimize.
        output.copy_(out.view(-1, self.num_heads, self.head_size))
        return output

    def get_alibi_slopes(self) -> Optional[torch.Tensor]:
        return self.alibi_slopes