"""Multi-head attention."""
from typing import List, Optional

import torch
import torch.nn as nn
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
                                         LowerTriangularMaskWithTensorBias)

from vllm import attention_ops
from vllm import cache_ops
from vllm import pos_encoding_ops
from vllm.model_executor.input_metadata import InputMetadata

_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]


class PagedAttention(nn.Module):
    # pylint: disable=line-too-long
    """GPT-style multi-head PagedAttention.

    This class takes flattened 1D query, key, and value tensors as input. The
    input 1D tensors can contain either prompt tokens or generation tokens, in
    addition to paddings.

    If the input tensors contain prompt tokens, the layout is as follows:

    |<---------------------- num_valid_tokens ---------------------->|
    |<--------------- num_prompt_tokens -------------->|
    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--padding-->|

    Otherwise, the layout is as follows:

    |<------------------ num_valid_tokens ------------------->|
    |<------- num_generation_tokens (M) ------->|
    |<--generation_0-->|...|<--generation_M-1-->|<--padding-->|

    The prompts might have different lengths, while the generation tokens always
    have length 1. The paddings are appended to make the input length a multiple
    of 8, which is desirable for Tensor Cores.

    The class does the following:
    1. Perform multi_query_kv_attention for the prompts. This operation does
        not use the KV cache.
    2. Wait for the cache operations (e.g., swap, copy) to finish. The cache
        operations are issued by the cache engine before executing the forward
        pass of the model, and they are executed asynchronously.
    3. Reshape and store the input key and value tensors in the KV cache.
    4. Perform single_query_cached_kv_attention for the generation tokens.
        This operation reads the previous key and value tensors from the KV
        cache.
    5. Output a flattened 1D tensor.
    """

    def __init__(self,
                 num_heads: int,
                 head_size: int,
                 scale: float,
                 num_kv_heads: Optional[int] = None) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
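        # Map each query head to its KV head so the kernel can share KV heads
        # across query heads (multi/grouped-query attention). For example,
        # num_heads=8 with num_kv_heads=2 yields [0, 0, 0, 0, 1, 1, 1, 1].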
        self.head_mapping = torch.repeat_interleave(
            torch.arange(self.num_kv_heads, dtype=torch.int32, device="cuda"),
            self.num_queries_per_kv)

        if self.head_size not in _SUPPORTED_HEAD_SIZES:
            raise ValueError(f"head_size ({self.head_size}) is not supported. "
                             f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")

    def set_attn_bias(self, input_metadata: InputMetadata) -> None:
        if input_metadata.attn_bias:
            # Already set by a previous layer.
            return
        prompt_lens = input_metadata.prompt_lens
        attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
        input_metadata.attn_bias.append(attn_bias)

    def multi_query_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Normal attention for the prompt tokens.

        Args:
            output: shape = [num_prompt_tokens, num_heads, head_size]
            query: shape = [num_prompt_tokens, num_heads, head_size]
            key: shape = [num_prompt_tokens, num_kv_heads, head_size]
            value: shape = [num_prompt_tokens, num_kv_heads, head_size]
            input_metadata: metadata for paged attention.
        """

        if self.num_kv_heads != self.num_heads:
            # Project the key and value tensors to the desired number of heads.
            key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
            value = torch.repeat_interleave(value,
                                            self.num_queries_per_kv,
                                            dim=1)

        # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
        out = xops.memory_efficient_attention_forward(
            query.unsqueeze(0),
            key.unsqueeze(0),
            value.unsqueeze(0),
            attn_bias=input_metadata.attn_bias[0],
            p=0.0,
            scale=self.scale,
        )
        # TODO(woosuk): Unnecessary copy. Optimize.
        output.copy_(out.squeeze(0))
        return output

    def single_query_cached_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> None:
        """PagedAttention for the generation tokens.

        Args:
            output: shape = [num_generation_tokens, num_heads, head_size]
            query: shape = [num_generation_tokens, num_heads, head_size]
            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for paged attention.
        """
        block_size = value_cache.shape[3]
        attention_ops.single_query_cached_kv_attention(
            output,
            query,
            key_cache,
            value_cache,
            self.head_mapping,
            self.scale,
            input_metadata.block_tables,
            input_metadata.context_lens,
            block_size,
            input_metadata.max_context_len,
            None,  # alibi_slopes
        )

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: Optional[torch.Tensor],
        value_cache: Optional[torch.Tensor],
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """PagedAttention forward pass.

        NOTE: The query, key, and value tensors must be sliced from a qkv
        tensor of shape [num_tokens, 3 * num_heads * head_size].

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for paged attention.
            cache_event: event to wait for the cache operations to finish.

        Returns:
            shape = [num_tokens, num_heads * head_size]
        """

        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        # Pre-allocate the output tensor.
        output = torch.empty_like(query)

        # Compute the attention op for prompts.
        num_prompt_tokens = input_metadata.num_prompt_tokens
        if num_prompt_tokens > 0:
            # Prompt run.
            assert input_metadata.num_generation_tokens == 0
            self.set_attn_bias(input_metadata)
            self.multi_query_kv_attention(
                output[:num_prompt_tokens],
                query[:num_prompt_tokens],
                key[:num_prompt_tokens],
                value[:num_prompt_tokens],
                input_metadata,
            )

        # Wait until the cache op is done.
        if cache_event is not None:
            cache_event.wait()

        # Reshape the keys and values and store them in the cache.
        # When key_cache and value_cache are not provided, the new key
        # and value vectors will not be cached.
        num_valid_tokens = input_metadata.num_valid_tokens
        if (num_valid_tokens > 0 and key_cache is not None
                and value_cache is not None):
            # The stride is 3 because the key and value are sliced from qkv.
            cache_ops.reshape_and_cache(
                key[:num_valid_tokens],
                value[:num_valid_tokens],
                key_cache,
                value_cache,
                input_metadata.slot_mapping,
            )

        if input_metadata.num_generation_tokens > 0:
            # Decoding run.
            assert input_metadata.num_prompt_tokens == 0
            assert key_cache is not None and value_cache is not None, (
                "key_cache and value_cache must be provided when "
                "generating tokens.")
            # Compute the attention op for generation tokens.
            self.single_query_cached_kv_attention(
                output[num_prompt_tokens:num_valid_tokens],
                query[num_prompt_tokens:num_valid_tokens], key_cache,
                value_cache, input_metadata)

        # Reshape the output tensor.
        # NOTE(woosuk): The output tensor may include paddings.
        return output.view(-1, self.num_heads * self.head_size)


class PagedAttentionWithRoPE(PagedAttention):
    """PagedAttention with rotary embedding."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        rotary_dim: int,
        max_position: int = 8192,
        base: int = 10000,
        num_kv_heads: Optional[int] = None,
        is_neox_style: bool = True,
    ) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads)
        self.is_neox_style = is_neox_style

        # Create the cos and sin cache.
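        # inv_freq[i] = 1 / base^(2i / rotary_dim) gives the per-dimension
        # rotation frequencies; their outer product with the positions t gives
        # the rotation angles, whose cos and sin are concatenated and cached
        # for every position up to max_position.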
        inv_freq = 1.0 / (base**(
            torch.arange(0, rotary_dim, 2, device="cuda") / rotary_dim))
        t = torch.arange(max_position, device="cuda").float()
        freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)

        # FIXME(woosuk): This assumes that we configure the default dtype when
        # initializing the model.
        # TODO(woosuk): Make it more robust.
        torch_dtype = torch.get_default_dtype()
        cache = cache.to(torch_dtype)
        # Embedding size: [max_position, rotary_dim]
        self.register_buffer("cos_sin_cache", cache, persistent=False)

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """ PagedAttention forward pass with rotary embedding.

        Args:
            positions: shape = [num_tokens]
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for paged attention.
            cache_event: event to wait for the cache operations to finish.

        Returns:
            shape = [num_tokens, num_heads * head_size]
        """

        # Apply rotary embedding to the query and key before passing them
        # to the attention op.
        pos_encoding_ops.rotary_embedding(
            positions,
            query,
            key,
            self.head_size,
            self.cos_sin_cache,
            self.is_neox_style,
        )
        return super().forward(
            query,
            key,
            value,
            key_cache,
            value_cache,
            input_metadata,
            cache_event,
        )


class PagedAttentionWithALiBi(PagedAttention):
    """PagedAttention with ALiBi attention bias."""

    def __init__(self,
                 num_heads: int,
                 head_size: int,
                 scale: float,
                 slopes: List[float],
                 num_kv_heads: Optional[int] = None) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads)
        assert len(slopes) == num_heads

        slopes = torch.tensor(slopes, dtype=torch.float32)
        self.register_buffer("alibi_slopes", slopes, persistent=False)

    def set_attn_bias(self, input_metadata: InputMetadata) -> None:
        if input_metadata.attn_bias:
            # Already set by a previous layer.
            return
        # Generate an ALiBi mask for each prompt.
        for prompt_len in input_metadata.prompt_lens:
            bias = torch.arange(prompt_len)
            # Note(zhuohan): HF uses
            #     `bias = bias[None, :].repeat(prompt_len, 1)`
            # here. We find that both biases give the same results, but
            # the bias below more accurately follows the original ALiBi
            # paper.
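            # The expression below yields bias[i, j] = j - i, the signed
            # relative distance from query position i to key position j.
            # For prompt_len = 3:
            #   [[ 0,  1,  2],
            #    [-1,  0,  1],
            #    [-2, -1,  0]]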
            bias = bias[None, :] - bias[:, None]
            bias = bias.to(self.alibi_slopes.device)

            # When using custom attention bias, xformers requires the bias to
            # be sliced from a tensor whose length is a multiple of 8.
            padded_len = (prompt_len + 7) // 8 * 8
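            # Allocate a [1, num_heads, prompt_len, padded_len] buffer and copy
            # the bias into its first prompt_len columns, so the bias tensor
            # ends up as a slice of an 8-aligned allocation.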
            bias = torch.empty(
                1,  # batch_size
                self.num_heads,
                prompt_len,
                padded_len,
                device=self.alibi_slopes.device,
            )[:, :, :, :prompt_len].copy_(bias)
            bias.mul_(self.alibi_slopes[:, None, None])
            attn_bias = LowerTriangularMaskWithTensorBias(bias)
            input_metadata.attn_bias.append(attn_bias)

    def multi_query_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Attention with ALiBi bias for the prompt tokens.

        Args:
            output: shape = [num_prompt_tokens, num_heads, head_size]
            query: shape = [num_prompt_tokens, num_heads, head_size]
            key: shape = [num_prompt_tokens, num_kv_heads, head_size]
            value: shape = [num_prompt_tokens, num_kv_heads, head_size]
            input_metadata: metadata for paged attention.
        """
        if self.num_kv_heads != self.num_heads:
            # Project the key and value tensors to the desired number of heads.
            key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
            value = torch.repeat_interleave(value,
                                            self.num_queries_per_kv,
                                            dim=1)

        # FIXME(woosuk): Because xformers does not support dynamic sequence
        # lengths with custom attention bias, we process each prompt one by
        # one. This is inefficient, especially when we have many short prompts.
        start = 0
        for i, prompt_len in enumerate(input_metadata.prompt_lens):
            end = start + prompt_len
            out = xops.memory_efficient_attention_forward(
                query[None, start:end],
                key[None, start:end],
                value[None, start:end],
                attn_bias=input_metadata.attn_bias[i],
                p=0.0,
                scale=self.scale,
            )
            # TODO(woosuk): Unnecessary copy. Optimize.
            output[start:end].copy_(out.squeeze(0))
            start += prompt_len
        return output

    def single_query_cached_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> None:
        """PagedAttention with ALiBi bias for the generation tokens.

        Args:
            output: shape = [num_generation_tokens, num_heads, head_size]
            query: shape = [num_generation_tokens, num_heads, head_size]
            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for paged attention.
        """
        block_size = value_cache.shape[3]
        attention_ops.single_query_cached_kv_attention(
            output,
            query,
            key_cache,
            value_cache,
            self.head_mapping,
            self.scale,
            input_metadata.block_tables,
            input_metadata.context_lens,
            block_size,
            input_metadata.max_context_len,
            self.alibi_slopes,
        )