"""Multi-head attention."""
from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
                                         LowerTriangularMaskWithTensorBias)

from vllm import attention_ops
from vllm import cache_ops
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.rotary_embedding import (
    DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding,
    RotaryEmbedding)

_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512


class PagedAttention(nn.Module):
    # pylint: disable=line-too-long
    """GPT-style multi-head PagedAttention.

    This class takes flattened 1D query, key, and value tensors as input. The
    input 1D tensors can either contain prompt tokens or generation tokens, in
    addition to paddings.

    If the input tensors contain prompt tokens, the layout is as follows:

    |<---------------------- num_valid_tokens ---------------------->|
    |<--------------- num_prompt_tokens -------------->|
    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--padding-->|

    Otherwise, the layout is as follows:

    |<------------------ num_valid_tokens ------------------->|
    |<------- num_generation_tokens (M) ------->|
    |<--generation_0-->|...|<--generation_M-1-->|<--padding-->|

    The prompts might have different lengths, while the generation tokens always
    have length 1. The paddings are appended to make the input length a multiple
    of 8, which is desirable for Tensor Cores.

    The class does the following:
    1. Perform multi_query_kv_attention for the prompts. This operation does
        not use the KV cache.
    2. Wait for the cache operations (e.g., swap, copy) to finish. The cache
        operations are issued by the cache engine before executing the forward
        pass of the model, and they are executed asynchronously.
    3. Reshape and store the input key and value tensors in the KV cache.
    4. Perform single_query_cached_kv_attention for the generation tokens.
        This operation reads the previous key and value tensors from the KV
        cache.
    5. Output a flattened 1D tensor.
    """

    def __init__(self,
                 num_heads: int,
                 head_size: int,
                 scale: float,
                 num_kv_heads: Optional[int] = None,
                 sliding_window: Optional[int] = None) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.sliding_window = sliding_window

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.head_mapping = torch.repeat_interleave(
            torch.arange(self.num_kv_heads, dtype=torch.int32, device="cuda"),
            self.num_queries_per_kv)
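        # For example (illustrative values): with num_heads = 8 and
        # num_kv_heads = 2, num_queries_per_kv = 4 and head_mapping is
        # [0, 0, 0, 0, 1, 1, 1, 1], i.e. query heads 0-3 read KV head 0 and
        # query heads 4-7 read KV head 1 in the paged attention kernels.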

        if self.head_size not in _SUPPORTED_HEAD_SIZES:
            raise ValueError(f"head_size ({self.head_size}) is not supported. "
                             f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")

    def set_attn_bias(
        self,
        input_metadata: InputMetadata,
        dtype: torch.dtype,
    ) -> None:
        del dtype  # Unused.
        if input_metadata.attn_bias:
            # Already set by a previous layer.
            return
        prompt_lens = input_metadata.prompt_lens
        attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
        if self.sliding_window is not None:
            attn_bias = attn_bias.make_local_attention(self.sliding_window)
        input_metadata.attn_bias.append(attn_bias)

    def multi_query_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Normal attention for the prompt tokens.

        Args:
            output: shape = [num_prompt_tokens, num_heads, head_size]
            query: shape = [num_prompt_tokens, num_heads, head_size]
            key: shape = [num_prompt_tokens, num_kv_heads, head_size]
            value: shape = [num_prompt_tokens, num_kv_heads, head_size]
            input_metadata: metadata for paged attention.
        """

        if self.num_kv_heads != self.num_heads:
            # Project the key and value tensors to the desired number of heads.
            key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
            value = torch.repeat_interleave(value,
                                            self.num_queries_per_kv,
                                            dim=1)

        # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
        out = xops.memory_efficient_attention_forward(
            query.unsqueeze(0),
            key.unsqueeze(0),
            value.unsqueeze(0),
            attn_bias=input_metadata.attn_bias[0],
            p=0.0,
            scale=self.scale,
        )
        # TODO(woosuk): Unnecessary copy. Optimize.
        output.copy_(out.squeeze(0))
        return output

    def get_alibi_slopes(self) -> Optional[torch.Tensor]:
        """Returns the slopes for the alibi attention bias.

        Returns:
            slopes: shape = [num_heads]
        """
        return None

    def single_query_cached_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
        alibi_slopes: Optional[torch.Tensor],
    ) -> None:
        """PagedAttention for the generation tokens.

        Args:
            output: shape = [num_generation_tokens, num_heads, head_size]
            query: shape = [num_generation_tokens, num_heads, head_size]
            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for paged attention.
            alibi_slopes: shape = [num_heads]
        """
        block_size = value_cache.shape[3]
        num_seqs, num_heads, head_size = query.shape
        max_num_partitions = (
            (input_metadata.max_context_len + _PARTITION_SIZE - 1) //
            _PARTITION_SIZE)
        # NOTE(woosuk): We use a simple heuristic to decide whether to use
        # PagedAttention V1 or V2. If the number of partitions is 1, we use
        # V1 to avoid the overhead of reduction. Also, if the number of
        # sequences or heads is large, we use V1 since there is enough work
        # to parallelize.
        # TODO(woosuk): Tune this heuristic.
        use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512
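        # For example (illustrative numbers): with max_context_len = 1000 and
        # _PARTITION_SIZE = 512, max_num_partitions = 2, so V2 is picked only
        # when num_seqs * num_heads <= 512, i.e. when the batch and head
        # dimensions alone do not provide enough parallelism.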
        if use_v1:
            # Run PagedAttention V1.
            attention_ops.paged_attention_v1(
                output,
                query,
                key_cache,
                value_cache,
                self.head_mapping,
                self.scale,
                input_metadata.block_tables,
                input_metadata.context_lens,
                block_size,
                input_metadata.max_context_len,
                alibi_slopes,
            )
        else:
            # Run PagedAttention V2.
            assert _PARTITION_SIZE % block_size == 0
            tmp_output = torch.empty(
                size=(num_seqs, num_heads, max_num_partitions, head_size),
                dtype=output.dtype,
                device=output.device,
            )
            exp_sums = torch.empty(
                size=(num_seqs, num_heads, max_num_partitions),
                dtype=torch.float32,
                device=output.device,
            )
            max_logits = torch.empty_like(exp_sums)
            attention_ops.paged_attention_v2(
                output,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                key_cache,
                value_cache,
                self.head_mapping,
                self.scale,
                input_metadata.block_tables,
                input_metadata.context_lens,
                block_size,
                input_metadata.max_context_len,
                alibi_slopes,
            )

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: Optional[torch.Tensor],
        value_cache: Optional[torch.Tensor],
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """PagedAttention forward pass.

        NOTE: The query, key, and value tensors must be sliced from a qkv
        tensor of shape [num_tokens, 3 * num_heads * head_size].

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for paged attention.
            cache_event: event to wait for the cache operations to finish.

        Returns:
            shape = [num_tokens, num_heads * head_size]
        """

        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        # Pre-allocate the output tensor.
        output = torch.empty_like(query)

        # Compute the attention op for prompts.
        num_prompt_tokens = input_metadata.num_prompt_tokens
        if num_prompt_tokens > 0:
            # Prompt run.
            assert input_metadata.num_generation_tokens == 0
            self.set_attn_bias(input_metadata, dtype=query.dtype)
            self.multi_query_kv_attention(
                output[:num_prompt_tokens],
                query[:num_prompt_tokens],
                key[:num_prompt_tokens],
                value[:num_prompt_tokens],
                input_metadata,
            )

        # Wait until the cache op is done.
        if cache_event is not None:
            cache_event.wait()

        # Reshape the keys and values and store them in the cache.
        # When key_cache and value_cache are not provided, the new key
        # and value vectors will not be cached.
        num_valid_tokens = input_metadata.num_valid_tokens
        if (num_valid_tokens > 0 and key_cache is not None
                and value_cache is not None):
            # The stride is 3 because the key and value are sliced from qkv.
            key_to_cache = key[:num_valid_tokens]
            value_to_cache = value[:num_valid_tokens]
            slot_mapping = input_metadata.slot_mapping
            if input_metadata.to_cache is not None:
                key_to_cache = key_to_cache[input_metadata.to_cache]
                value_to_cache = value_to_cache[input_metadata.to_cache]
                slot_mapping = slot_mapping[input_metadata.to_cache]

            cache_ops.reshape_and_cache(
                key_to_cache,
                value_to_cache,
                key_cache,
                value_cache,
                slot_mapping,
            )

        if input_metadata.num_generation_tokens > 0:
            # Decoding run.
            assert input_metadata.num_prompt_tokens == 0
            assert key_cache is not None and value_cache is not None, (
                "key_cache and value_cache must be provided when "
                "generating tokens.")
            # Compute the attention op for generation tokens.
            self.single_query_cached_kv_attention(
                output[num_prompt_tokens:num_valid_tokens],
                query[num_prompt_tokens:num_valid_tokens], key_cache,
                value_cache, input_metadata, self.get_alibi_slopes())

        # Reshape the output tensor.
        # NOTE(woosuk): The output tensor may include paddings.
        return output.view(-1, self.num_heads * self.head_size)
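
    # A minimal usage sketch (illustrative, assuming an MHA layer where
    # num_kv_heads == num_heads); the variable names are hypothetical:
    #
    #     qkv = hidden_states @ w_qkv        # [num_tokens, 3 * num_heads * head_size]
    #     q, k, v = qkv.chunk(3, dim=-1)     # each [num_tokens, num_heads * head_size]
    #     out = attn(q, k, v, key_cache, value_cache, input_metadata, cache_event)
    #
    # Slicing q/k/v from one qkv buffer is what the "stride is 3" comment in
    # forward() above refers to.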


class PagedAttentionWithRoPE(PagedAttention):
    """PagedAttention with rotary positional embedding."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        rotary_dim: int,
        max_position: int = 8192,
        base: int = 10000,
        num_kv_heads: Optional[int] = None,
        is_neox_style: bool = True,
        rope_scaling: Optional[Dict[str, Any]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        super().__init__(num_heads,
                         head_size,
                         scale,
                         num_kv_heads,
                         sliding_window=sliding_window)
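
        # `rope_scaling` is expected to be a dict with "type" and "factor"
        # keys (in practice taken from the Hugging Face model config), e.g.
        # {"type": "linear", "factor": 2.0} or {"type": "dynamic", "factor": 2.0}
        # (factor values are illustrative); only these two types are handled
        # below.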
        if rope_scaling is None:
            self.rotary_emb = RotaryEmbedding(head_size, rotary_dim,
                                              max_position, base,
                                              is_neox_style)
        else:
            scaling_type = rope_scaling["type"]
            scaling_factor = rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = LinearScalingRotaryEmbedding(
                    head_size, rotary_dim, max_position, base, is_neox_style,
                    scaling_factor)
            elif scaling_type == "dynamic":
                self.rotary_emb = DynamicNTKScalingRotaryEmbedding(
                    head_size, rotary_dim, max_position, base, is_neox_style,
                    scaling_factor)
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """ PagedAttention forward pass with rotary embedding.

        Args:
            positions: shape = [num_tokens]
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for paged attention.
            cache_event: event to wait for the cache operations to finish.

        Returns:
            shape = [num_tokens, num_heads * head_size]
        """

        # Apply rotary embedding to the query and key before passing them
        # to the attention op.
        query, key = self.rotary_emb(positions, query, key)
        return super().forward(
            query,
            key,
            value,
            key_cache,
            value_cache,
            input_metadata,
            cache_event,
        )


class PagedAttentionWithALiBi(PagedAttention):
    """PagedAttention with ALiBi attention bias."""

    def __init__(self,
                 num_heads: int,
                 head_size: int,
                 scale: float,
                 slopes: List[float],
                 num_kv_heads: Optional[int] = None) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads)
        assert len(slopes) == num_heads

        slopes = torch.tensor(slopes, dtype=torch.float32)
        self.register_buffer("alibi_slopes", slopes, persistent=False)
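        # The slopes are computed by the calling model code; e.g. for 8 heads
        # the reference ALiBi slopes form the geometric sequence
        # 2**-1, 2**-2, ..., 2**-8 (illustrative, per the ALiBi paper).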

    def set_attn_bias(self, input_metadata: InputMetadata,
                      dtype: torch.dtype) -> None:
        if input_metadata.attn_bias:
            # Already set by a previous layer.
            return
        # Generates ALiBi mask for each prompt.
        for prompt_len in input_metadata.prompt_lens:
            bias = torch.arange(prompt_len, dtype=dtype)
            # NOTE(zhuohan): HF uses
            #     `bias = bias[None, :].repeat(prompt_len, 1)`
            # here. We find that both biases give the same results, but
            # the bias below more accurately follows the original ALiBi
            # paper.
            bias = bias[None, :] - bias[:, None]
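            # At this point bias[i, j] == j - i; e.g. for prompt_len = 4:
            #     [[ 0,  1,  2,  3],
            #      [-1,  0,  1,  2],
            #      [-2, -1,  0,  1],
            #      [-3, -2, -1,  0]]
            # Only the lower triangle (j <= i) survives the causal mask, giving
            # the non-positive relative distances that ALiBi scales per head.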
            bias = bias.to(self.alibi_slopes.device)

            # When using custom attention bias, xformers requires the bias to
            # be sliced from a tensor whose length is a multiple of 8.
            padded_len = (prompt_len + 7) // 8 * 8
            bias = torch.empty(
                1,  # batch_size
                self.num_heads,
                prompt_len,
                padded_len,
                device=self.alibi_slopes.device,
                dtype=dtype,
            )[:, :, :, :prompt_len].copy_(bias)
            bias.mul_(self.alibi_slopes[:, None, None])
            attn_bias = LowerTriangularMaskWithTensorBias(bias)
            input_metadata.attn_bias.append(attn_bias)

    def multi_query_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Attention with ALiBi bias for the prompt tokens.

        Args:
            output: shape = [num_prompt_tokens, num_heads, head_size]
            query: shape = [num_prompt_tokens, num_heads, head_size]
            key: shape = [num_prompt_tokens, num_kv_heads, head_size]
            value: shape = [num_prompt_tokens, num_kv_heads, head_size]
            input_metadata: metadata for paged attention.
        """
        if self.num_kv_heads != self.num_heads:
            # Project the key and value tensors to the desired number of heads.
            key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
            value = torch.repeat_interleave(value,
                                            self.num_queries_per_kv,
                                            dim=1)

        # FIXME(woosuk): Because xformers does not support dynamic sequence
        # lengths with custom attention bias, we process each prompt one by
        # one. This is inefficient, especially when we have many short prompts.
        start = 0
        for i, prompt_len in enumerate(input_metadata.prompt_lens):
            end = start + prompt_len
            out = xops.memory_efficient_attention_forward(
                query[None, start:end],
                key[None, start:end],
                value[None, start:end],
                attn_bias=input_metadata.attn_bias[i],
                p=0.0,
                scale=self.scale,
            )
            # TODO(woosuk): Unnecessary copy. Optimize.
            output[start:end].copy_(out.squeeze(0))
            start += prompt_len
        return output

    def get_alibi_slopes(self) -> Optional[torch.Tensor]:
        return self.alibi_slopes