"""Multi-head attention."""
from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
                                         LowerTriangularMaskWithTensorBias)

from vllm import attention_ops
from vllm import cache_ops
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.rotary_embedding import (
    DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding,
    RotaryEmbedding)

_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]


class PagedAttention(nn.Module):
    # pylint: disable=line-too-long
    """GPT-style multi-head PagedAttention.

    This class takes flattened 1D query, key, and value tensors as input. The
    input 1D tensors can either contain prompt tokens or generation tokens, in
    addition to paddings.

    If the input tensors contain prompt tokens, the layout is as follows:

    |<---------------------- num_valid_tokens ---------------------->|
    |<--------------- num_prompt_tokens -------------->|
    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--padding-->|

    Otherwise, the layout is as follows:

    |<------------------ num_valid_tokens ------------------->|
    |<------- num_generation_tokens (M) ------->|
    |<--generation_0-->|...|<--generation_M-1-->|<--padding-->|

    The prompts might have different lengths, while the generation tokens always
    have length 1. The paddings are appended to make the input length a multiple
    of 8, which is desirable for Tensor Cores.

    The class does the following:
    1. Perform multi_query_kv_attention for the prompts. This operation does
        not use the KV cache.
    2. Wait for the cache operations (e.g., swap, copy) to finish. The cache
        operations are issued by the cache engine before executing the forward
        pass of the model, and they are executed asynchronously.
    3. Reshape and store the input key and value tensors in the KV cache.
    4. Perform single_query_cached_kv_attention for the generation tokens.
        This operation reads the previous key and value tensors from the KV
        cache.
    5. Output a flattened 1D tensor.
    """

    def __init__(self,
                 num_heads: int,
                 head_size: int,
                 scale: float,
                 num_kv_heads: Optional[int] = None,
                 sliding_window: Optional[int] = None) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.sliding_window = sliding_window

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
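        # head_mapping maps each query head to the KV head whose cache it
        # reads. For example (illustrative values), with num_heads = 8 and
        # num_kv_heads = 2 it becomes [0, 0, 0, 0, 1, 1, 1, 1].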
        self.head_mapping = torch.repeat_interleave(
            torch.arange(self.num_kv_heads, dtype=torch.int32, device="cuda"),
            self.num_queries_per_kv)

        if self.head_size not in _SUPPORTED_HEAD_SIZES:
            raise ValueError(f"head_size ({self.head_size}) is not supported. "
                             f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")

    def set_attn_bias(
        self,
        input_metadata: InputMetadata,
        dtype: torch.dtype,
    ) -> None:
        del dtype  # Unused.
        if input_metadata.attn_bias:
            # Already set by a previous layer.
            return
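        # Build one block-diagonal causal mask covering every prompt in the
        # batch, so tokens of one prompt cannot attend to tokens of another.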
        prompt_lens = input_metadata.prompt_lens
        attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
        if self.sliding_window is not None:
            attn_bias = attn_bias.make_local_attention(self.sliding_window)
        input_metadata.attn_bias.append(attn_bias)

    def multi_query_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Normal attention for the prompt tokens.

        Args:
            output: shape = [num_prompt_tokens, num_heads, head_size]
            query: shape = [num_prompt_tokens, num_heads, head_size]
            key: shape = [num_prompt_tokens, num_kv_heads, head_size]
            value: shape = [num_prompt_tokens, num_kv_heads, head_size]
            input_metadata: metadata for paged attention.
        """

        if self.num_kv_heads != self.num_heads:
            # Project the key and value tensors to the desired number of heads.
            key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
            value = torch.repeat_interleave(value,
                                            self.num_queries_per_kv,
                                            dim=1)

        # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
        out = xops.memory_efficient_attention_forward(
            query.unsqueeze(0),
            key.unsqueeze(0),
            value.unsqueeze(0),
            attn_bias=input_metadata.attn_bias[0],
            p=0.0,
            scale=self.scale,
        )
        # TODO(woosuk): Unnecessary copy. Optimize.
        output.copy_(out.squeeze(0))
        return output

    def single_query_cached_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> None:
        """PagedAttention for the generation tokens.

        Args:
            output: shape = [num_generation_tokens, num_heads, head_size]
            query: shape = [num_generation_tokens, num_heads, head_size]
            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for paged attention.
        """
        block_size = value_cache.shape[3]
        attention_ops.single_query_cached_kv_attention(
            output,
            query,
            key_cache,
            value_cache,
            self.head_mapping,
            self.scale,
            input_metadata.block_tables,
            input_metadata.context_lens,
            block_size,
            input_metadata.max_context_len,
            None,  # alibi_slopes
        )

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: Optional[torch.Tensor],
        value_cache: Optional[torch.Tensor],
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """PagedAttention forward pass.

        NOTE: The query, key, and value tensors must be sliced from a qkv
        tensor of shape [num_tokens, 3 * num_heads * head_size].

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for paged attention.
            cache_event: event to wait for the cache operations to finish.

        Returns:
            shape = [num_tokens, num_heads * head_size]
        """

        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        # Pre-allocate the output tensor.
        output = torch.empty_like(query)

        # Compute the attention op for prompts.
        num_prompt_tokens = input_metadata.num_prompt_tokens
        if num_prompt_tokens > 0:
            # Prompt run.
            assert input_metadata.num_generation_tokens == 0
            self.set_attn_bias(input_metadata, dtype=query.dtype)
            self.multi_query_kv_attention(
                output[:num_prompt_tokens],
                query[:num_prompt_tokens],
                key[:num_prompt_tokens],
                value[:num_prompt_tokens],
                input_metadata,
            )

        # Wait until the cache op is done.
        if cache_event is not None:
            cache_event.wait()

        # Reshape the keys and values and store them in the cache.
        # When key_cache and value_cache are not provided, the new key
        # and value vectors will not be cached.
        num_valid_tokens = input_metadata.num_valid_tokens
        if (num_valid_tokens > 0 and key_cache is not None
                and value_cache is not None):
            # The stride is 3 because the key and value are sliced from qkv.
            key_to_cache = key[:num_valid_tokens]
            value_to_cache = value[:num_valid_tokens]
            slot_mapping = input_metadata.slot_mapping
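            # slot_mapping[i] is the flat slot in the paged KV cache where
            # token i's key/value are written (typically
            # block_number * block_size + block_offset).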
            if input_metadata.to_cache is not None:
                key_to_cache = key_to_cache[input_metadata.to_cache]
                value_to_cache = value_to_cache[input_metadata.to_cache]
                slot_mapping = slot_mapping[input_metadata.to_cache]

            cache_ops.reshape_and_cache(
                key_to_cache,
                value_to_cache,
                key_cache,
                value_cache,
                slot_mapping,
            )

        if input_metadata.num_generation_tokens > 0:
            # Decoding run.
            assert input_metadata.num_prompt_tokens == 0
            assert key_cache is not None and value_cache is not None, (
                "key_cache and value_cache must be provided when "
                "generating tokens.")
            # Compute the attention op for generation tokens.
            self.single_query_cached_kv_attention(
                output[num_prompt_tokens:num_valid_tokens],
                query[num_prompt_tokens:num_valid_tokens], key_cache,
                value_cache, input_metadata)

        # Reshape the output tensor.
        # NOTE(woosuk): The output tensor may include paddings.
        return output.view(-1, self.num_heads * self.head_size)


class PagedAttentionWithRoPE(PagedAttention):
    """PagedAttention with rotary positional embedding."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        rotary_dim: int,
        max_position: int = 8192,
        base: int = 10000,
        num_kv_heads: Optional[int] = None,
        is_neox_style: bool = True,
        rope_scaling: Optional[Dict[str, Any]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        super().__init__(num_heads,
                         head_size,
                         scale,
                         num_kv_heads,
                         sliding_window=sliding_window)
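        # rope_scaling follows the Hugging Face config convention,
        # e.g. {"type": "linear", "factor": 2.0} (illustrative values).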
        if rope_scaling is None:
            self.rotary_emb = RotaryEmbedding(head_size, rotary_dim,
                                              max_position, base,
                                              is_neox_style)
        else:
            scaling_type = rope_scaling["type"]
            scaling_factor = rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = LinearScalingRotaryEmbedding(
                    head_size, rotary_dim, max_position, base, is_neox_style,
                    scaling_factor)
            elif scaling_type == "dynamic":
                self.rotary_emb = DynamicNTKScalingRotaryEmbedding(
                    head_size, rotary_dim, max_position, base, is_neox_style,
                    scaling_factor)
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """ PagedAttention forward pass with rotary embedding.

        Args:
            positions: shape = [num_tokens]
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for paged attention.
            cache_event: event to wait for the cache operations to finish.

        Returns:
            shape = [num_tokens, num_heads * head_size]
        """

        # Apply rotary embedding to the query and key before passing them
        # to the attention op.
        query, key = self.rotary_emb(positions, query, key)
        return super().forward(
            query,
            key,
            value,
            key_cache,
            value_cache,
            input_metadata,
            cache_event,
        )


class PagedAttentionWithALiBi(PagedAttention):
    """PagedAttention with ALiBi attention bias."""

    def __init__(self,
                 num_heads: int,
                 head_size: int,
                 scale: float,
                 slopes: List[float],
                 num_kv_heads: Optional[int] = None) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads)
        assert len(slopes) == num_heads

        slopes = torch.tensor(slopes, dtype=torch.float32)
        self.register_buffer("alibi_slopes", slopes, persistent=False)

    def set_attn_bias(self, input_metadata: InputMetadata,
                      dtype: torch.dtype) -> None:
        if input_metadata.attn_bias:
            # Already set by a previous layer.
            return
        # Generates ALiBi mask for each prompt.
        for prompt_len in input_metadata.prompt_lens:
            bias = torch.arange(prompt_len, dtype=dtype)
            # Note(zhuohan): HF uses
            #     `bias = bias[None, :].repeat(prompt_len, 1)`
            # here. We find that both biases give the same results, but
            # the bias below more accurately follows the original ALiBi
            # paper.
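            # For prompt_len = 3 this yields [[0, 1, 2], [-1, 0, 1],
            # [-2, -1, 0]]; entry (i, j) is the relative distance j - i.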
            bias = bias[None, :] - bias[:, None]
            bias = bias.to(self.alibi_slopes.device)

            # When using custom attention bias, xformers requires the bias to
            # be sliced from a tensor whose length is a multiple of 8.
            padded_len = (prompt_len + 7) // 8 * 8
            bias = torch.empty(
                1,  # batch_size
                self.num_heads,
                prompt_len,
                padded_len,
                device=self.alibi_slopes.device,
                dtype=dtype,
            )[:, :, :, :prompt_len].copy_(bias)
            bias.mul_(self.alibi_slopes[:, None, None])
            attn_bias = LowerTriangularMaskWithTensorBias(bias)
            input_metadata.attn_bias.append(attn_bias)

    def multi_query_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Attention with ALiBi bias for the prompt tokens.

        Args:
            output: shape = [num_prompt_tokens, num_heads, head_size]
            query: shape = [num_prompt_tokens, num_heads, head_size]
            key: shape = [num_prompt_tokens, num_kv_heads, head_size]
            value: shape = [num_prompt_tokens, num_kv_heads, head_size]
            input_metadata: metadata for paged attention.
        """
        if self.num_kv_heads != self.num_heads:
            # Project the key and value tensors to the desired number of heads.
            key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
            value = torch.repeat_interleave(value,
                                            self.num_queries_per_kv,
                                            dim=1)

        # FIXME(woosuk): Because xformers does not support dynamic sequence
        # lengths with custom attention bias, we process each prompt one by
        # one. This is inefficient, especially when we have many short prompts.
        start = 0
        for i, prompt_len in enumerate(input_metadata.prompt_lens):
            end = start + prompt_len
            out = xops.memory_efficient_attention_forward(
                query[None, start:end],
                key[None, start:end],
                value[None, start:end],
                attn_bias=input_metadata.attn_bias[i],
                p=0.0,
                scale=self.scale,
            )
            # TODO(woosuk): Unnecessary copy. Optimize.
            output[start:end].copy_(out.squeeze(0))
            start += prompt_len
        return output

    def single_query_cached_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> None:
        """PagedAttention with ALiBi bias for the generation tokens.

        Args:
            output: shape = [num_generation_tokens, num_heads, head_size]
            query: shape = [num_generation_tokens, num_heads, head_size]
            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for paged attention.
        """
        block_size = value_cache.shape[3]
        attention_ops.single_query_cached_kv_attention(
            output,
            query,
            key_cache,
            value_cache,
            self.head_mapping,
            self.scale,
            input_metadata.block_tables,
            input_metadata.context_lens,
            block_size,
            input_metadata.max_context_len,
            self.alibi_slopes,
        )
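

# A minimal construction sketch (illustrative only; assumes a CUDA device is
# available, since PagedAttention allocates head_mapping on "cuda"):
#
#     attn = PagedAttentionWithRoPE(num_heads=32,
#                                   head_size=128,
#                                   scale=128 ** -0.5,
#                                   rotary_dim=128,
#                                   num_kv_heads=8)
#     # query, key, and value are flattened [num_tokens, hidden] slices of a
#     # fused QKV projection; key_cache and value_cache come from the cache
#     # engine:
#     # out = attn(positions, query, key, value, key_cache, value_cache,
#     #            input_metadata, cache_event)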