"R-package/man/predict.lgb.Booster.Rd" did not exist on "d2a6eb0e2f5ce873e647135d4b2e45c70d3fc994"
attention.py 16.8 KB
Newer Older
1
"""Multi-head attention."""
Woosuk Kwon's avatar
Woosuk Kwon committed
2
from typing import List, Optional
Woosuk Kwon's avatar
Woosuk Kwon committed
3
4
5

import torch
import torch.nn as nn
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
                                         LowerTriangularMaskWithTensorBias)

from vllm import attention_ops
from vllm import cache_ops
from vllm import pos_encoding_ops
from vllm.model_executor.input_metadata import InputMetadata

_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]


class PagedAttention(nn.Module):
    # pylint: disable=line-too-long
    """GPT-style multi-head PagedAttention.

    This class takes flattened 1D query, key, and value tensors as input. The
    input 1D tensors can either contain prompt tokens or generation tokens, in
    addition to paddings.

    If the input tensors contain prompt tokens, the layout is as follows:

    |<---------------------- num_valid_tokens ---------------------->|
    |<--------------- num_prompt_tokens -------------->|
    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--padding-->|

    Otherwise, the layout is as follows:

    |<------------------ num_valid_tokens ------------------->|
    |<------- num_generation_tokens (M) ------->|
    |<--generation_0-->|...|<--generation_M-1-->|<--padding-->|

    The prompts might have different lengths, while the generation tokens always
    have length 1. The paddings are appended to make the input length a multiple
    of 8, which is desirable for Tensor Cores.

    The class does the following:
    1. Perform multi_query_kv_attention for the prompts. This operation does
        not use the KV cache.
    2. Wait for the cache operations (e.g., swap, copy) to finish. The cache
        operations are issued by the cache engine before executing the forward
        pass of the model, and they are executed asynchronously.
    3. Reshape and store the input key and value tensors in the KV cache.
    4. Perform single_query_cached_kv_attention for the generation tokens.
        This operation reads the previous key and value tensors from the KV
        cache.
    5. Output a flattened 1D tensor.
    """

    def __init__(self,
                 num_heads: int,
                 head_size: int,
                 scale: float,
                 num_kv_heads: Optional[int] = None) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.attn_op = xops.fmha.cutlass.FwOp()
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
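        # Maps each query head to the KV head it shares under grouped-query
        # (or multi-query) attention; e.g. num_heads=8, num_kv_heads=2 gives
        # head_mapping = [0, 0, 0, 0, 1, 1, 1, 1].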
        self.head_mapping = torch.repeat_interleave(
            torch.arange(self.num_kv_heads, dtype=torch.int32, device="cuda"),
            self.num_queries_per_kv)

        if self.head_size not in _SUPPORTED_HEAD_SIZES:
            raise ValueError(f"head_size ({self.head_size}) is not supported. "
                             f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")

    def set_attn_bias(self, input_metadata: InputMetadata) -> None:
        if input_metadata.attn_bias:
            # Already set by a previous layer.
            return
        prompt_lens = input_metadata.prompt_lens
        attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
        input_metadata.attn_bias.append(attn_bias)

    def multi_query_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Normal attention for the prompt tokens.

        Args:
            output: shape = [num_prompt_tokens, num_heads, head_size]
            query: shape = [num_prompt_tokens, num_heads, head_size]
            key: shape = [num_prompt_tokens, num_kv_heads, head_size]
            value: shape = [num_prompt_tokens, num_kv_heads, head_size]
            input_metadata: metadata for paged attention.
        """

        if self.num_kv_heads != self.num_heads:
            # Project the key and value tensors to the desired number of heads.
            key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
            value = torch.repeat_interleave(value,
                                            self.num_queries_per_kv,
                                            dim=1)
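            # key and value now have shape
            # [num_prompt_tokens, num_heads, head_size], matching the query
            # head count expected by the xformers attention op below.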

        # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
        out = xops.memory_efficient_attention_forward(
            query.unsqueeze(0),
            key.unsqueeze(0),
            value.unsqueeze(0),
            attn_bias=input_metadata.attn_bias[0],
            p=0.0,
            scale=self.scale,
            op=self.attn_op,
        )
        # TODO(woosuk): Unnecessary copy. Optimize.
        output.copy_(out.squeeze(0))
        return output

    def single_query_cached_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> None:
        """PagedAttention for the generation tokens.

        Args:
            output: shape = [num_generation_tokens, num_heads, head_size]
            query: shape = [num_generation_tokens, num_heads, head_size]
            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for paged attention.
        """
        block_size = value_cache.shape[3]
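        # The kernel gathers each sequence's cached keys/values through
        # block_tables (per-sequence physical block indices) and context_lens,
        # so no contiguous key/value tensors are materialized here.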
        attention_ops.single_query_cached_kv_attention(
            output,
            query,
            key_cache,
            value_cache,
            self.head_mapping,
            self.scale,
            input_metadata.block_tables,
            input_metadata.context_lens,
            block_size,
            input_metadata.max_context_len,
            None,  # alibi_slopes
        )

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: Optional[torch.Tensor],
        value_cache: Optional[torch.Tensor],
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """PagedAttention forward pass.

        NOTE: The query, key, and value tensors must be sliced from a qkv
        tensor of shape [num_tokens, 3 * num_heads * head_size].

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for paged attention.
            cache_event: event to wait for the cache operations to finish.

        Returns:
            shape = [num_tokens, num_heads * head_size]
        """

        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        # Pre-allocate the output tensor.
        output = torch.empty_like(query)
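        # output: [num_tokens, num_heads, head_size]; the prompt and
        # generation slices below are written into it in place.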

        # Compute the attention op for prompts.
        num_prompt_tokens = input_metadata.num_prompt_tokens
        if num_prompt_tokens > 0:
            # Prompt run.
            assert input_metadata.num_generation_tokens == 0
            self.set_attn_bias(input_metadata)
            self.multi_query_kv_attention(
                output[:num_prompt_tokens],
                query[:num_prompt_tokens],
                key[:num_prompt_tokens],
                value[:num_prompt_tokens],
                input_metadata,
            )

        # Wait until the cache op is done.
        if cache_event is not None:
            cache_event.wait()

        # Reshape the keys and values and store them in the cache.
        # When key_cache and value_cache are not provided, the new key
        # and value vectors will not be cached.
        num_valid_tokens = input_metadata.num_valid_tokens
        if (num_valid_tokens > 0 and key_cache is not None
                and value_cache is not None):
            # The stride is 3 because the key and value are sliced from qkv.
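            # slot_mapping holds, for each token, the flat index of its cache
            # slot (block_number * block_size + block_offset) to write into.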
            cache_ops.reshape_and_cache(
                key[:num_valid_tokens],
                value[:num_valid_tokens],
                key_cache,
                value_cache,
                input_metadata.slot_mapping,
            )

        if input_metadata.num_generation_tokens > 0:
            # Decoding run.
            assert input_metadata.num_prompt_tokens == 0
            assert key_cache is not None and value_cache is not None, (
                "key_cache and value_cache must be provided when "
                "generating tokens.")
            # Compute the attention op for generation tokens.
            self.single_query_cached_kv_attention(
                output[num_prompt_tokens:num_valid_tokens],
                query[num_prompt_tokens:num_valid_tokens], key_cache,
                value_cache, input_metadata)

        # Reshape the output tensor.
        # NOTE(woosuk): The output tensor may include paddings.
        return output.view(-1, self.num_heads * self.head_size)


class PagedAttentionWithRoPE(PagedAttention):
    """PagedAttention with GPT-NeoX style rotary embedding."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        rotary_dim: int,
        max_position: int = 8192,
        base: int = 10000,
        num_kv_heads: Optional[int] = None,
    ) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads)

        # Create the cos and sin cache.
        inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
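        # Standard RoPE frequencies: inv_freq[k] = base^(-2k / rotary_dim);
        # freqs below is the outer product of positions and frequencies,
        # with shape [max_position, rotary_dim // 2].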
        t = torch.arange(max_position).float()
        freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)

        # FIXME(woosuk): This assumes that we configure the default dtype when
        # initializing the model.
        # TODO(woosuk): Make it more robust.
        torch_dtype = torch.get_default_dtype()
        cache = cache.to(torch_dtype)
        # Embedding size: [max_position, rotary_dim]
        self.register_buffer("cos_sin_cache", cache, persistent=False)

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """ PagedAttention forward pass with rotary embedding.

        Args:
            positions: shape = [num_tokens]
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for paged attention.
            cache_event: event to wait for the cache operations to finish.

        Returns:
            shape = [num_tokens, num_heads * head_size]
        """

        # Apply rotary embedding to the query and key before passing them
        # to the attention op.
        pos_encoding_ops.rotary_embedding_neox(
            positions,
            query,
            key,
            self.head_size,
            self.cos_sin_cache,
        )
        return super().forward(
            query,
            key,
            value,
            key_cache,
            value_cache,
            input_metadata,
            cache_event,
        )


class PagedAttentionWithALiBi(PagedAttention):
    """PagedAttention with ALiBi attention bias."""

    def __init__(self,
                 num_heads: int,
                 head_size: int,
                 scale: float,
                 slopes: List[float],
                 num_kv_heads: Optional[int] = None) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads)
        assert len(slopes) == num_heads

        slopes = torch.tensor(slopes, dtype=torch.float32)
        self.register_buffer("alibi_slopes", slopes, persistent=False)

    def set_attn_bias(self, input_metadata: InputMetadata) -> None:
        if input_metadata.attn_bias:
            # Already set by a previous layer.
            return
        # Generates ALiBi mask for each prompt.
        for prompt_len in input_metadata.prompt_lens:
            bias = torch.arange(prompt_len)
            # Note(zhuohan): HF uses
            #     `bias = bias[None, :].repeat(prompt_len, 1)`
            # here. We find that both biases give the same results, but
            # the bias below more accurately follows the original ALiBi
            # paper.
            bias = bias[None, :] - bias[:, None]
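            # bias[i, j] = j - i (query position i, key position j), e.g. for
            # prompt_len=3: [[0, 1, 2], [-1, 0, 1], [-2, -1, 0]].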
            bias = bias.to(self.alibi_slopes.device)

            # When using custom attention bias, xformers requires the bias to
            # be sliced from a tensor whose length is a multiple of 8.
            padded_len = (prompt_len + 7) // 8 * 8
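            # e.g. prompt_len=13 -> padded_len=16.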
            bias = torch.empty(
                self.num_heads,
                padded_len,
                padded_len,
                device=self.alibi_slopes.device,
            )[:, :prompt_len, :prompt_len].copy_(bias)
            bias.mul_(self.alibi_slopes[:, None, None])
            attn_bias = LowerTriangularMaskWithTensorBias(bias)
            input_metadata.attn_bias.append(attn_bias)

    def multi_query_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Attention with ALiBi bias for the prompt tokens.

        Args:
            output: shape = [num_prompt_tokens, num_heads, head_size]
            query: shape = [num_prompt_tokens, num_heads, head_size]
            key: shape = [num_prompt_tokens, num_kv_heads, head_size]
            value: shape = [num_prompt_tokens, num_kv_heads, head_size]
            input_metadata: metadata for paged attention.
        """
        if self.num_kv_heads != self.num_heads:
            # Project the key and value tensors to the desired number of heads.
            key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
            value = torch.repeat_interleave(value,
                                            self.num_queries_per_kv,
                                            dim=1)

        # FIXME(woosuk): Because xformers does not support dynamic sequence
        # lengths with custom attention bias, we process each prompt one by
        # one. This is inefficient, especially when we have many short prompts.
        start = 0
        for i, prompt_len in enumerate(input_metadata.prompt_lens):
            end = start + prompt_len
            out = xops.memory_efficient_attention_forward(
                query[None, start:end],
                key[None, start:end],
                value[None, start:end],
                attn_bias=input_metadata.attn_bias[i],
                p=0.0,
                scale=self.scale,
                op=self.attn_op,
            )
            # TODO(woosuk): Unnecessary copy. Optimize.
            output[start:end].copy_(out.squeeze(0))
            start += prompt_len
        return output

    def single_query_cached_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> None:
        """PagedAttention with ALiBi bias for the generation tokens.

        Args:
            output: shape = [num_generation_tokens, num_heads, head_size]
            query: shape = [num_generation_tokens, num_heads, head_size]
            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for paged attention.
        """
        block_size = value_cache.shape[3]
        attention_ops.single_query_cached_kv_attention(
            output,
            query,
            key_cache,
            value_cache,
            self.head_mapping,
            self.scale,
            input_metadata.block_tables,
            input_metadata.context_lens,
            block_size,
            input_metadata.max_context_len,
            self.alibi_slopes,
        )