"""Multi-head attention."""
from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
                                         LowerTriangularMaskWithTensorBias)

from vllm import attention_ops
from vllm import cache_ops
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.rotary_embedding import (
    DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding,
    RotaryEmbedding)

_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]


class PagedAttention(nn.Module):
    # pylint: disable=line-too-long
    """GPT-style multi-head PagedAttention.

    This class takes flattened 1D query, key, and value tensors as input. The
    input 1D tensors can either contain prompt tokens or generation tokens, in
    addition to paddings.

    If the input tensors contain prompt tokens, the layout is as follows:

    |<---------------------- num_valid_tokens ---------------------->|
    |<--------------- num_prompt_tokens -------------->|
    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--padding-->|

    Otherwise, the layout is as follows:

    |<------------------ num_valid_tokens ------------------->|
    |<------- num_generation_tokens (M) ------->|
    |<--generation_0-->|...|<--generation_M-1-->|<--padding-->|

    The prompts might have different lengths, while the generation tokens always
    have length 1. The paddings are appended to make the input length a multiple
    of 8, which is desirable for Tensor Cores.

    The class does the following:
    1. Perform multi_query_kv_attention for the prompts. This operation does
        not use the KV cache.
    2. Wait for the cache operations (e.g., swap, copy) to finish. The cache
        operations are issued by the cache engine before executing the forward
        pass of the model, and they are executed asynchronously.
    3. Reshape and store the input key and value tensors in the KV cache.
    4. Perform single_query_cached_kv_attention for the generation tokens.
        This operation reads the previous key and value tensors from the KV
        cache.
    5. Output a flattened 1D tensor.
    """

    def __init__(self,
                 num_heads: int,
                 head_size: int,
                 scale: float,
                 num_kv_heads: Optional[int] = None) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.head_mapping = torch.repeat_interleave(
            torch.arange(self.num_kv_heads, dtype=torch.int32, device="cuda"),
            self.num_queries_per_kv)
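        # head_mapping maps each query head to its KV head. Example (an
        # illustrative grouped-query setup, not from the source): num_heads=8,
        # num_kv_heads=2 -> head_mapping = [0, 0, 0, 0, 1, 1, 1, 1], i.e.
        # each group of 4 query heads shares one KV head.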

        if self.head_size not in _SUPPORTED_HEAD_SIZES:
            raise ValueError(f"head_size ({self.head_size}) is not supported. "
                             f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")

    def set_attn_bias(
        self,
        input_metadata: InputMetadata,
        dtype: torch.dtype,
    ) -> None:
        del dtype  # Unused.
        if input_metadata.attn_bias:
            # Already set by a previous layer.
            return
        prompt_lens = input_metadata.prompt_lens
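        # BlockDiagonalCausalMask builds one causal mask per prompt, so the
        # concatenated prompts attend only within their own boundaries.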
        attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
        input_metadata.attn_bias.append(attn_bias)

    def multi_query_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Normal attention for the prompt tokens.

        Args:
            output: shape = [num_prompt_tokens, num_heads, head_size]
            query: shape = [num_prompt_tokens, num_heads, head_size]
            key: shape = [num_prompt_tokens, num_kv_heads, head_size]
            value: shape = [num_prompt_tokens, num_kv_heads, head_size]
            input_metadata: metadata for paged attention.
        """

        if self.num_kv_heads != self.num_heads:
            # Project the key and value tensors to the desired number of heads.
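            # For example, with num_queries_per_kv = 4, each KV head is
            # repeated 4 times along dim 1 so that key and value match the
            # query layout [num_prompt_tokens, num_heads, head_size].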
            key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
            value = torch.repeat_interleave(value,
                                            self.num_queries_per_kv,
                                            dim=1)

        # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
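        # xformers expects [batch, seq_len, num_heads, head_size]; unsqueeze
        # adds a batch dimension of 1, and the block-diagonal causal mask
        # keeps the concatenated prompts from attending across boundaries.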
        out = xops.memory_efficient_attention_forward(
            query.unsqueeze(0),
            key.unsqueeze(0),
            value.unsqueeze(0),
            attn_bias=input_metadata.attn_bias[0],
            p=0.0,
            scale=self.scale,
        )
        # TODO(woosuk): Unnecessary copy. Optimize.
        output.copy_(out.squeeze(0))
        return output

    def single_query_cached_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> None:
        """PagedAttention for the generation tokens.

        Args:
            output: shape = [num_generation_tokens, num_heads, head_size]
            query: shape = [num_generation_tokens, num_heads, head_size]
            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for paged attention.
        """
        block_size = value_cache.shape[3]
        attention_ops.single_query_cached_kv_attention(
            output,
            query,
            key_cache,
            value_cache,
            self.head_mapping,
            self.scale,
            input_metadata.block_tables,
            input_metadata.context_lens,
            block_size,
            input_metadata.max_context_len,
            None,  # alibi_slopes
        )

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: Optional[torch.Tensor],
        value_cache: Optional[torch.Tensor],
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """PagedAttention forward pass.

        NOTE: The query, key, and value tensors must be sliced from a qkv
        tensor of shape [num_tokens, (num_heads + 2 * num_kv_heads) *
        head_size], which reduces to [num_tokens, 3 * num_heads * head_size]
        when num_kv_heads == num_heads.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for paged attention.
            cache_event: event to wait for the cache operations to finish.

        Returns:
            shape = [num_tokens, num_heads * head_size]
        """

        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        # Pre-allocate the output tensor.
        output = torch.empty_like(query)

        # Compute the attention op for prompts.
        num_prompt_tokens = input_metadata.num_prompt_tokens
        if num_prompt_tokens > 0:
            # Prompt run.
            assert input_metadata.num_generation_tokens == 0
            self.set_attn_bias(input_metadata, dtype=query.dtype)
            self.multi_query_kv_attention(
                output[:num_prompt_tokens],
                query[:num_prompt_tokens],
                key[:num_prompt_tokens],
                value[:num_prompt_tokens],
                input_metadata,
            )

        # Wait until the cache op is done.
        if cache_event is not None:
            cache_event.wait()

        # Reshape the keys and values and store them in the cache.
        # When key_cache and value_cache are not provided, the new key
        # and value vectors will not be cached.
        num_valid_tokens = input_metadata.num_valid_tokens
        if (num_valid_tokens > 0 and key_cache is not None
                and value_cache is not None):
            # The stride is 3 because the key and value are sliced from qkv.
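            # slot_mapping gives each token's flat slot in the paged cache,
            # i.e., block_number * block_size + block_offset.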
            cache_ops.reshape_and_cache(
                key[:num_valid_tokens],
                value[:num_valid_tokens],
                key_cache,
                value_cache,
                input_metadata.slot_mapping,
            )

        if input_metadata.num_generation_tokens > 0:
            # Decoding run.
            assert input_metadata.num_prompt_tokens == 0
            assert key_cache is not None and value_cache is not None, (
                "key_cache and value_cache must be provided when "
                "generating tokens.")
            # Compute the attention op for generation tokens.
            self.single_query_cached_kv_attention(
                output[num_prompt_tokens:num_valid_tokens],
                query[num_prompt_tokens:num_valid_tokens], key_cache,
                value_cache, input_metadata)

        # Reshape the output tensor.
        # NOTE(woosuk): The output tensor may include paddings.
        return output.view(-1, self.num_heads * self.head_size)
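
# A minimal usage sketch (illustrative only: constructing InputMetadata and
# the paged KV caches is elided, and all sizes are hypothetical):
#
#   attn = PagedAttention(num_heads=32, head_size=128, scale=128**-0.5)
#   out = attn(query, key, value, key_cache, value_cache, input_metadata,
#              cache_event=None)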


class PagedAttentionWithRoPE(PagedAttention):
    """PagedAttention with rotary positional embedding."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        rotary_dim: int,
        max_position: int = 8192,
        base: int = 10000,
        num_kv_heads: Optional[int] = None,
        is_neox_style: bool = True,
        rope_scaling: Optional[Dict[str, Any]] = None,
    ) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads)
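        # "linear" scaling compresses position indices by a constant factor,
        # while "dynamic" (NTK-aware) scaling rescales the rotary base as the
        # context grows; both stretch RoPE beyond its trained max_position.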
        if rope_scaling is None:
            self.rotary_emb = RotaryEmbedding(head_size, rotary_dim,
                                              max_position, base,
                                              is_neox_style)
        else:
            scaling_type = rope_scaling["type"]
            scaling_factor = rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = LinearScalingRotaryEmbedding(
                    head_size, rotary_dim, max_position, base, is_neox_style,
                    scaling_factor)
            elif scaling_type == "dynamic":
                self.rotary_emb = DynamicNTKScalingRotaryEmbedding(
                    head_size, rotary_dim, max_position, base, is_neox_style,
                    scaling_factor)
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """ PagedAttention forward pass with rotary embedding.

        Args:
            positions: shape = [num_tokens]
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for paged attention.
            cache_event: event to wait for the cache operations to finish.

        Returns:
            shape = [num_tokens, num_heads * head_size]
        """

        # Apply rotary embedding to the query and key before passing them
        # to the attention op.
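        # Rotating query and key by position-dependent angles makes the
        # attention logits depend only on relative token positions.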
        query, key = self.rotary_emb(positions, query, key)
        return super().forward(
            query,
            key,
            value,
            key_cache,
            value_cache,
            input_metadata,
            cache_event,
        )


class PagedAttentionWithALiBi(PagedAttention):
    """PagedAttention with ALiBi attention bias."""

    def __init__(self,
                 num_heads: int,
                 head_size: int,
                 scale: float,
                 slopes: List[float],
                 num_kv_heads: Optional[int] = None) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads)
        assert len(slopes) == num_heads

        slopes = torch.tensor(slopes, dtype=torch.float32)
        self.register_buffer("alibi_slopes", slopes, persistent=False)
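        # ALiBi biases the attention logits with a per-head linear penalty
        # proportional to the query-key distance; each head has its own slope.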

    def set_attn_bias(self, input_metadata: InputMetadata,
                      dtype: torch.dtype) -> None:
        if input_metadata.attn_bias:
            # Already set by a previous layer.
            return
        # Generates ALiBi mask for each prompt.
        for prompt_len in input_metadata.prompt_lens:
            bias = torch.arange(prompt_len, dtype=dtype)
            # Note(zhuohan): HF uses
            #     `bias = bias[None, :].repeat(prompt_len, 1)`
            # here. We find that both biases give the same results, but
            # the bias below more accurately follows the original ALiBi
            # paper.
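            # bias[None, :] - bias[:, None] is the matrix of relative
            # positions (j - i); e.g., for prompt_len = 3:
            #   [[ 0,  1,  2],
            #    [-1,  0,  1],
            #    [-2, -1,  0]]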
            bias = bias[None, :] - bias[:, None]
            bias = bias.to(self.alibi_slopes.device)

            # When using custom attention bias, xformers requires the bias to
            # be sliced from a tensor whose length is a multiple of 8.
            padded_len = (prompt_len + 7) // 8 * 8
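            # e.g., prompt_len = 13 -> padded_len = 16.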
            bias = torch.empty(
                1,  # batch_size
                self.num_heads,
                prompt_len,
                padded_len,
                device=self.alibi_slopes.device,
                dtype=dtype,
            )[:, :, :, :prompt_len].copy_(bias)
            bias.mul_(self.alibi_slopes[:, None, None])
            attn_bias = LowerTriangularMaskWithTensorBias(bias)
            input_metadata.attn_bias.append(attn_bias)

    def multi_query_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Attention with ALiBi bias for the prompt tokens.

        Args:
            output: shape = [num_prompt_tokens, num_heads, head_size]
            query: shape = [num_prompt_tokens, num_heads, head_size]
            key: shape = [num_prompt_tokens, num_kv_heads, head_size]
            value: shape = [num_prompt_tokens, num_kv_heads, head_size]
            input_metadata: metadata for paged attention.
        """
        if self.num_kv_heads != self.num_heads:
            # Project the key and value tensors to the desired number of heads.
            key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
            value = torch.repeat_interleave(value,
                                            self.num_queries_per_kv,
                                            dim=1)

        # FIXME(woosuk): Because xformers does not support dynamic sequence
        # lengths with custom attention bias, we process each prompt one by
        # one. This is inefficient, especially when we have many short prompts.
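        # For example, with prompt_lens = [2, 3], the loop runs attention on
        # query[0:2] and then query[2:5], each with its own ALiBi mask.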
        start = 0
        for i, prompt_len in enumerate(input_metadata.prompt_lens):
            end = start + prompt_len
            out = xops.memory_efficient_attention_forward(
                query[None, start:end],
                key[None, start:end],
                value[None, start:end],
                attn_bias=input_metadata.attn_bias[i],
                p=0.0,
                scale=self.scale,
            )
            # TODO(woosuk): Unnecessary copy. Optimize.
            output[start:end].copy_(out.squeeze(0))
            start += prompt_len
        return output

    def single_query_cached_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> None:
        """PagedAttention with ALiBi bias for the generation tokens.

        Args:
            output: shape = [num_generation_tokens, num_heads, head_size]
            query: shape = [num_generation_tokens, num_heads, head_size]
            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for paged attention.
        """
        block_size = value_cache.shape[3]
        attention_ops.single_query_cached_kv_attention(
            output,
            query,
            key_cache,
            value_cache,
            self.head_mapping,
            self.scale,
            input_metadata.block_tables,
            input_metadata.context_lens,
            block_size,
            input_metadata.max_context_len,
            self.alibi_slopes,
        )