"""Multi-head attention."""
from typing import List, Optional

import torch
import torch.nn as nn
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
                                         LowerTriangularMaskWithTensorBias)

from vllm import attention_ops
from vllm import cache_ops
from vllm import pos_encoding_ops
from vllm.model_executor.input_metadata import InputMetadata

_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]


class PagedAttention(nn.Module):
    # pylint: disable=line-too-long
    """GPT-style multi-head PagedAttention.

    This class takes flattened 1D query, key, and value tensors as input. The
    input 1D tensors can either contain prompt tokens or generation tokens, in
    addition to paddings.

    If the input tensors contain prompt tokens, the layout is as follows:

    |<---------------------- num_valid_tokens ---------------------->|
    |<--------------- num_prompt_tokens -------------->|
    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--padding-->|

    Otherwise, the layout is as follows:

    |<------------------ num_valid_tokens ------------------->|
    |<------- num_generation_tokens (M) ------->|
    |<--generation_0-->|...|<--generation_M-1-->|<--padding-->|

    The prompts might have different lengths, while the generation tokens always
    have length 1. The paddings are appended to make the input length a multiple
    of 8, which is desirable for Tensor Cores.

    The class does the following:
    1. Perform multi_query_kv_attention for the prompts. This operation does
        not use the KV cache.
    2. Wait for the cache operations (e.g., swap, copy) to finish. The cache
        operations are issued by the cache engine before executing the forward
        pass of the model, and they are executed asynchronously.
    3. Reshape and store the input key and value tensors in the KV cache.
    4. Perform single_query_cached_kv_attention for the generation tokens.
        This operation reads the previous key and value tensors from the KV
        cache.
    5. Output a flattened 1D tensor.
    """

    def __init__(self,
                 num_heads: int,
                 head_size: int,
                 scale: float,
                 num_kv_heads: Optional[int] = None) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
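        # head_mapping maps each query head to the KV head whose cache it
        # reads: with multi-query / grouped-query attention several query
        # heads share one KV head (the mapping is the identity when
        # num_kv_heads == num_heads).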
        self.head_mapping = torch.repeat_interleave(
            torch.arange(self.num_kv_heads, dtype=torch.int32, device="cuda"),
            self.num_queries_per_kv)

        if self.head_size not in _SUPPORTED_HEAD_SIZES:
            raise ValueError(f"head_size ({self.head_size}) is not supported. "
                             f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")

    def set_attn_bias(
        self,
        input_metadata: InputMetadata,
        dtype: torch.dtype,
    ) -> None:
        del dtype  # Unused.
        if input_metadata.attn_bias:
            # Already set by a previous layer.
            return
        prompt_lens = input_metadata.prompt_lens
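        # The prompts are packed into a single 1D sequence, so a block-diagonal
        # causal mask lets each prompt attend causally to its own tokens
        # without attending across prompt boundaries.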
        attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
        input_metadata.attn_bias.append(attn_bias)

    def multi_query_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Normal attention for the prompt tokens.

        Args:
            output: shape = [num_prompt_tokens, num_heads, head_size]
            query: shape = [num_prompt_tokens, num_heads, head_size]
            key: shape = [num_prompt_tokens, num_kv_heads, head_size]
            value: shape = [num_prompt_tokens, num_kv_heads, head_size]
            input_metadata: metadata for paged attention.
        """

        if self.num_kv_heads != self.num_heads:
            # Project the key and value tensors to the desired number of heads.
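            # repeat_interleave materializes num_queries_per_kv copies of each
            # KV head so the xformers kernel sees matching numbers of query and
            # key/value heads.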
            key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
            value = torch.repeat_interleave(value,
                                            self.num_queries_per_kv,
                                            dim=1)

        # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
        out = xops.memory_efficient_attention_forward(
            query.unsqueeze(0),
            key.unsqueeze(0),
            value.unsqueeze(0),
            attn_bias=input_metadata.attn_bias[0],
            p=0.0,
            scale=self.scale,
        )
        # TODO(woosuk): Unnecessary copy. Optimize.
        output.copy_(out.squeeze(0))
        return output

    def single_query_cached_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> None:
        """PagedAttention for the generation tokens.

        Args:
            output: shape = [num_generation_tokens, num_heads, head_size]
            query: shape = [num_generation_tokens, num_heads, head_size]
            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for paged attention.
        """
        block_size = value_cache.shape[3]
        attention_ops.single_query_cached_kv_attention(
            output,
            query,
            key_cache,
            value_cache,
            self.head_mapping,
            self.scale,
            input_metadata.block_tables,
            input_metadata.context_lens,
            block_size,
            input_metadata.max_context_len,
            None,  # alibi_slopes
        )

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: Optional[torch.Tensor],
        value_cache: Optional[torch.Tensor],
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """PagedAttention forward pass.

        NOTE: The query, key, and value tensors must be sliced from a qkv
        tensor of shape [num_tokens, 3 * num_heads * head_size].

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for paged attention.
            cache_event: event to wait for the cache operations to finish.

        Returns:
            shape = [num_tokens, num_heads * head_size]
        """

        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        # Pre-allocate the output tensor.
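        # It has one row per input token (including padding); the prompt and
        # generation branches below each write into their own slice of it.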
        output = torch.empty_like(query)

        # Compute the attention op for prompts.
        num_prompt_tokens = input_metadata.num_prompt_tokens
        if num_prompt_tokens > 0:
            # Prompt run.
            assert input_metadata.num_generation_tokens == 0
            self.set_attn_bias(input_metadata, dtype=query.dtype)
            self.multi_query_kv_attention(
                output[:num_prompt_tokens],
                query[:num_prompt_tokens],
                key[:num_prompt_tokens],
                value[:num_prompt_tokens],
                input_metadata,
            )

        # Wait until the cache op is done.
        if cache_event is not None:
            cache_event.wait()

        # Reshape the keys and values and store them in the cache.
        # When key_cache and value_cache are not provided, the new key
        # and value vectors will not be cached.
        num_valid_tokens = input_metadata.num_valid_tokens
        if (num_valid_tokens > 0 and key_cache is not None
                and value_cache is not None):
            # The stride is 3 because the key and value are sliced from qkv.
            cache_ops.reshape_and_cache(
                key[:num_valid_tokens],
                value[:num_valid_tokens],
                key_cache,
                value_cache,
                input_metadata.slot_mapping,
            )

        if input_metadata.num_generation_tokens > 0:
            # Decoding run.
            assert input_metadata.num_prompt_tokens == 0
            assert key_cache is not None and value_cache is not None, (
                "key_cache and value_cache must be provided when "
                "generating tokens.")
            # Compute the attention op for generation tokens.
            self.single_query_cached_kv_attention(
                output[num_prompt_tokens:num_valid_tokens],
                query[num_prompt_tokens:num_valid_tokens], key_cache,
                value_cache, input_metadata)

        # Reshape the output tensor.
        # NOTE(woosuk): The output tensor may include paddings.
        return output.view(-1, self.num_heads * self.head_size)


class PagedAttentionWithRoPE(PagedAttention):
    """PagedAttention with rotary embedding."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        rotary_dim: int,
        max_position: int = 8192,
        base: int = 10000,
        num_kv_heads: Optional[int] = None,
        is_neox_style: bool = True,
    ) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads)
        self.is_neox_style = is_neox_style

        # Create the cos and sin cache.
        # NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
        # However, we use `torch.arange(..., dtype=torch.float)` instead to
        # avoid numerical issues with large base values (e.g., 10000000).
        # This may cause a slight numerical difference between the HF
        # implementation and ours.
        # NOTE(woosuk): To exactly match the HF implementation, we need to
        # use CPU to compute the cache and then move it to GPU. However, we
        # create the cache on GPU for faster initialization. This may cause
        # a slight numerical difference between the HF implementation and ours.
        inv_freq = 1.0 / (base**(torch.arange(
            0, rotary_dim, 2, dtype=torch.float, device="cuda") / rotary_dim))
        t = torch.arange(max_position, dtype=torch.float, device="cuda")
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
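        # Row t of the cache stores the rotary_dim // 2 cosine values followed
        # by the rotary_dim // 2 sine values for position t.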

        # FIXME(woosuk): This assumes that we configure the default dtype when
        # initializing the model.
        torch_dtype = torch.get_default_dtype()
        cache = cache.to(torch_dtype)
        # Embedding size: [max_position, rotary_dim]
        self.register_buffer("cos_sin_cache", cache, persistent=False)

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """PagedAttention forward pass with rotary embedding.

        Args:
            positions: shape = [num_tokens]
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for paged attention.
            cache_event: event to wait for the cache operations to finish.

        Returns:
            shape = [num_tokens, num_heads * head_size]
        """

        # Apply rotary embedding to the query and key before passing them
        # to the attention op.
        pos_encoding_ops.rotary_embedding(
            positions,
            query,
            key,
            self.head_size,
            self.cos_sin_cache,
            self.is_neox_style,
        )
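        # The rotary embedding kernel updates query and key in place, so the
        # rotated tensors can be passed directly to the base class forward.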
        return super().forward(
            query,
            key,
            value,
            key_cache,
            value_cache,
            input_metadata,
            cache_event,
        )


class PagedAttentionWithALiBi(PagedAttention):
    """PagedAttention with ALiBi attention bias."""

    def __init__(self,
                 num_heads: int,
                 head_size: int,
                 scale: float,
                 slopes: List[float],
                 num_kv_heads: Optional[int] = None) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads)
        assert len(slopes) == num_heads

        slopes = torch.tensor(slopes, dtype=torch.float32)
        self.register_buffer("alibi_slopes", slopes, persistent=False)

    def set_attn_bias(self, input_metadata: InputMetadata,
                      dtype: torch.dtype) -> None:
        if input_metadata.attn_bias:
            # Already set by a previous layer.
            return
        # Generates ALiBi mask for each prompt.
        for prompt_len in input_metadata.prompt_lens:
            bias = torch.arange(prompt_len, dtype=dtype)
            # Note(zhuohan): HF uses
            #     `bias = bias[None, :].repeat(prompt_len, 1)`
            # here. We find that both biases give the same results, but
            # the bias below more accurately follows the original ALiBi
            # paper.
            bias = bias[None, :] - bias[:, None]
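            # bias[i, j] = j - i, the signed distance from query position i to
            # key position j; it is scaled by the per-head ALiBi slopes below.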
            bias = bias.to(self.alibi_slopes.device)

            # When using custom attention bias, xformers requires the bias to
            # be sliced from a tensor whose length is a multiple of 8.
            padded_len = (prompt_len + 7) // 8 * 8
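            # Round prompt_len up to the next multiple of 8 (e.g., 13 -> 16);
            # only the first prompt_len columns of the padded bias are used.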
            bias = torch.empty(
                1,  # batch_size
                self.num_heads,
                prompt_len,
                padded_len,
                device=self.alibi_slopes.device,
                dtype=dtype,
            )[:, :, :, :prompt_len].copy_(bias)
            bias.mul_(self.alibi_slopes[:, None, None])
            attn_bias = LowerTriangularMaskWithTensorBias(bias)
            input_metadata.attn_bias.append(attn_bias)

    def multi_query_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Attention with ALiBi bias for the prompt tokens.

        Args:
            output: shape = [num_prompt_tokens, num_heads, head_size]
            query: shape = [num_prompt_tokens, num_heads, head_size]
            key: shape = [num_prompt_tokens, num_kv_heads, head_size]
            value: shape = [num_prompt_tokens, num_kv_heads, head_size]
            input_metadata: metadata for paged attention.
        """
        if self.num_kv_heads != self.num_heads:
            # Project the key and value tensors to the desired number of heads.
            key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
            value = torch.repeat_interleave(value,
                                            self.num_queries_per_kv,
                                            dim=1)

        # FIXME(woosuk): Because xformers does not support dynamic sequence
        # lengths with custom attention bias, we process each prompt one by
        # one. This is inefficient, especially when we have many short prompts.
        start = 0
        for i, prompt_len in enumerate(input_metadata.prompt_lens):
            end = start + prompt_len
            out = xops.memory_efficient_attention_forward(
                query[None, start:end],
                key[None, start:end],
                value[None, start:end],
                attn_bias=input_metadata.attn_bias[i],
                p=0.0,
                scale=self.scale,
            )
            # TODO(woosuk): Unnecessary copy. Optimize.
            output[start:end].copy_(out.squeeze(0))
            start += prompt_len
        return output

    def single_query_cached_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> None:
        """PagedAttention with ALiBi bias for the generation tokens.

        Args:
            output: shape = [num_generation_tokens, num_heads, head_size]
            query: shape = [num_generation_tokens, num_heads, head_size]
            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for paged attention.
        """
        block_size = value_cache.shape[3]
        attention_ops.single_query_cached_kv_attention(
            output,
            query,
            key_cache,
            value_cache,
            self.head_mapping,
            self.scale,
            input_metadata.block_tables,
            input_metadata.context_lens,
            block_size,
            input_metadata.max_context_len,
            self.alibi_slopes,
        )