"""Multi-head attention."""
from typing import List, Optional

import torch
import torch.nn as nn
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
                                         LowerTriangularMaskWithTensorBias)

from vllm import attention_ops
from vllm import cache_ops
from vllm import pos_encoding_ops
from vllm.model_executor.input_metadata import InputMetadata

_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]


class PagedAttention(nn.Module):
    # pylint: disable=line-too-long
    """GPT-style multi-head PagedAttention.

    This class takes flattened 1D query, key, and value tensors as input. The
    input 1D tensors can be split into three parts: the prompt tokens, the
    generation tokens, and the paddings.

    |<------------------------------------- num_valid_tokens ------------------------------------->|
    |<--------------- num_prompt_tokens -------------->|<------- num_generation_tokens (M) ------->|
    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--generation_0-->|...|<--generation_M-1-->|<--padding-->|

    The prompts might have different lengths, while the generation tokens always
    have length 1. The paddings are appended to make the input length a multiple
    of 8, which is desirable for Tensor Cores.
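    For example, with two prompts of lengths 4 and 2 and three generation
    sequences, there are 9 valid tokens and the input is padded to 16 tokens.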

    The class does the following:
    1. Perform multi_query_kv_attention for the prompts. This operation does
        not use the KV cache.
    2. Wait for the cache operations (e.g., swap, copy) to finish. The cache
        operations are issued by the cache engine before executing the forward
        pass of the model, and they are executed asynchronously.
    3. Reshape and store the input key and value tensors in the KV cache.
    4. Perform single_query_cached_kv_attention for the generation tokens.
        This operation reads the previous key and value tensors from the KV
        cache.
    5. Output a flattened 1D tensor.
    """

    def __init__(self,
                 num_heads: int,
                 head_size: int,
                 scale: float,
                 num_kv_heads: Optional[int] = None) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.attn_op = xops.fmha.cutlass.FwOp()
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.head_mapping = torch.repeat_interleave(
            torch.arange(self.num_kv_heads, dtype=torch.int32, device="cuda"),
            self.num_queries_per_kv)
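        # head_mapping maps each query head to its KV head, e.g., with
        # num_heads=8 and num_kv_heads=2 it is [0, 0, 0, 0, 1, 1, 1, 1].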

        if self.head_size not in _SUPPORTED_HEAD_SIZES:
            raise ValueError(f"head_size ({self.head_size}) is not supported. "
                             f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")

    def set_attn_bias(self, input_metadata: InputMetadata) -> None:
        if input_metadata.attn_bias:
            # Already set by a previous layer.
            return
        prompt_lens = input_metadata.prompt_lens
        attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
        input_metadata.attn_bias.append(attn_bias)

    def multi_query_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Normal attention for the prompt tokens.

        Args:
            output: shape = [num_prompt_tokens, num_heads, head_size]
            query: shape = [num_prompt_tokens, num_heads, head_size]
            key: shape = [num_prompt_tokens, num_kv_heads, head_size]
            value: shape = [num_prompt_tokens, num_kv_heads, head_size]
            input_metadata: metadata for paged attention.
        """

        if self.num_kv_heads != self.num_heads:
            # Project the key and value tensors to the desired number of heads.
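            # (MQA/GQA case: key and value go from
            # [num_prompt_tokens, num_kv_heads, head_size] to
            # [num_prompt_tokens, num_heads, head_size] so that their head
            # count matches the query's.)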
            key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
            value = torch.repeat_interleave(value,
                                            self.num_queries_per_kv,
                                            dim=1)

        # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
        out = xops.memory_efficient_attention_forward(
            query.unsqueeze(0),
            key.unsqueeze(0),
            value.unsqueeze(0),
            attn_bias=input_metadata.attn_bias[0],
            p=0.0,
            scale=self.scale,
            op=self.attn_op,
        )
        # TODO(woosuk): Unnecessary copy. Optimize.
        output.copy_(out.squeeze(0))
        return output

    def single_query_cached_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> None:
        """PagedAttention for the generation tokens.

        Args:
            output: shape = [num_generation_tokens, num_heads, head_size]
            query: shape = [num_generation_tokens, num_heads, head_size]
            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for paged attention.
        """
        block_size = value_cache.shape[3]
        attention_ops.single_query_cached_kv_attention(
            output,
            query,
            key_cache,
            value_cache,
            self.head_mapping,
            self.scale,
            input_metadata.block_tables,
            input_metadata.context_lens,
            block_size,
            input_metadata.max_context_len,
            None,  # alibi_slopes
        )

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: Optional[torch.Tensor],
        value_cache: Optional[torch.Tensor],
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """PagedAttention forward pass.

        NOTE: The query, key, and value tensors must be sliced from a qkv
        tensor of shape [num_tokens, 3 * num_heads * head_size] (or
        [num_tokens, (num_heads + 2 * num_kv_heads) * head_size] when
        num_kv_heads != num_heads).

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for paged attention.
            cache_event: event to wait for the cache operations to finish.

        Returns:
            shape = [num_tokens, num_heads * head_size]
        """

        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        # Pre-allocate the output tensor.
        output = torch.empty_like(query)

        # Compute the attention op for prompts.
        num_prompt_tokens = input_metadata.num_prompt_tokens
        if num_prompt_tokens > 0:
            self.set_attn_bias(input_metadata)
            self.multi_query_kv_attention(
                output[:num_prompt_tokens],
                query[:num_prompt_tokens],
                key[:num_prompt_tokens],
                value[:num_prompt_tokens],
                input_metadata,
            )

        # Wait until the cache op is done.
        if cache_event is not None:
            cache_event.wait()

        # Reshape the keys and values and store them in the cache.
        # When key_cache and value_cache are not provided, the new key
        # and value vectors will not be cached.
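        # (This can happen, for example, during the initial memory-profiling
        # run, before the KV cache blocks have been allocated.)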
        num_valid_tokens = input_metadata.num_valid_tokens
        if (num_valid_tokens > 0 and key_cache is not None
                and value_cache is not None):
            # The stride is 3 because the key and value are sliced from qkv.
            cache_ops.reshape_and_cache(
                key[:num_valid_tokens],
                value[:num_valid_tokens],
                key_cache,
                value_cache,
                input_metadata.slot_mapping,
            )

        if input_metadata.num_generation_tokens > 0:
            assert key_cache is not None and value_cache is not None, (
                "key_cache and value_cache must be provided when "
                "generating tokens.")
            # Compute the attention op for generation tokens.
            self.single_query_cached_kv_attention(
                output[num_prompt_tokens:num_valid_tokens],
                query[num_prompt_tokens:num_valid_tokens], key_cache,
                value_cache, input_metadata)

        # Reshape the output tensor.
        # NOTE(woosuk): The output tensor may include paddings.
        return output.view(-1, self.num_heads * self.head_size)


class PagedAttentionWithRoPE(PagedAttention):
    """PagedAttention with GPT-NeoX style rotary embedding."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        rotary_dim: int,
        max_position: int = 8192,
        base: int = 10000,
        num_kv_heads: Optional[int] = None,
    ) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads)

        # Create the cos and sin cache.
        inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
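        # inv_freq[i] = base ** (-2 * i / rotary_dim), the standard RoPE
        # inverse frequencies.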
        t = torch.arange(max_position).float()
        freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
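        # cache has shape [max_position, rotary_dim]; the first rotary_dim // 2
        # columns hold the cos values and the rest hold the sin values.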

        # FIXME(woosuk): This assumes that we configure the default dtype when
        # initializing the model.
        # TODO(woosuk): Make it more robust.
        torch_dtype = torch.get_default_dtype()
        cache = cache.to(torch_dtype)
        # Embedding size: [max_position, rotary_dim]
        self.register_buffer("cos_sin_cache", cache, persistent=False)

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """ PagedAttention forward pass with rotary embedding.

        Args:
            positions: shape = [num_tokens]
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for paged attention.
            cache_event: event to wait for the cache operations to finish.

        Returns:
            shape = [num_tokens, num_heads * head_size]
        """

        # Apply rotary embedding to the query and key before passing them
        # to the attention op.
        pos_encoding_ops.rotary_embedding_neox(
            positions,
            query,
            key,
            self.head_size,
            self.cos_sin_cache,
        )
        return super().forward(
            query,
            key,
            value,
            key_cache,
            value_cache,
            input_metadata,
            cache_event,
        )


class PagedAttentionWithALiBi(PagedAttention):
    """PagedAttention with ALiBi attention bias."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        slopes: List[float],
    ) -> None:
        super().__init__(num_heads, head_size, scale)
        assert len(slopes) == num_heads

        slopes = torch.tensor(slopes, dtype=torch.float32)
        self.register_buffer("alibi_slopes", slopes, persistent=False)

    def set_attn_bias(self, input_metadata: InputMetadata) -> None:
        if input_metadata.attn_bias:
            # Already set by a previous layer.
            return
        # Generates ALiBi mask for each prompt.
        for prompt_len in input_metadata.prompt_lens:
            bias = torch.arange(prompt_len)
            bias = bias[None, :] - bias[:, None]
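            # bias[i][j] = j - i, e.g., for prompt_len = 3:
            # [[ 0,  1,  2],
            #  [-1,  0,  1],
            #  [-2, -1,  0]]
            # Only the lower-triangular (causal) part is used once the mask
            # below is applied.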
            bias = bias.to(self.alibi_slopes.device)

            # When using custom attention bias, xformers requires the bias to
            # be sliced from a tensor whose length is a multiple of 8.
            padded_len = (prompt_len + 7) // 8 * 8
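            # e.g., prompt_len = 11 gives padded_len = 16.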
            bias = torch.empty(
                self.num_heads,
                padded_len,
                padded_len,
                device=self.alibi_slopes.device,
            )[:, :prompt_len, :prompt_len].copy_(bias)
            bias.mul_(self.alibi_slopes[:, None, None])
            attn_bias = LowerTriangularMaskWithTensorBias(bias)
            input_metadata.attn_bias.append(attn_bias)

    def multi_query_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Attention with ALiBi bias for the prompt tokens.

        Args:
            output: shape = [num_prompt_tokens, num_heads, head_size]
            query: shape = [num_prompt_tokens, num_heads, head_size]
            key: shape = [num_prompt_tokens, num_heads, head_size]
            value: shape = [num_prompt_tokens, num_heads, head_size]
            input_metadata: metadata for paged attention.
        """
        # FIXME(woosuk): Because xformers does not support dynamic sequence
        # lengths with custom attention bias, we process each prompt one by
        # one. This is inefficient, especially when we have many short prompts.
        start = 0
        for i, prompt_len in enumerate(input_metadata.prompt_lens):
            end = start + prompt_len
            out = xops.memory_efficient_attention_forward(
                query[None, start:end],
                key[None, start:end],
                value[None, start:end],
                attn_bias=input_metadata.attn_bias[i],
                p=0.0,
                scale=self.scale,
                op=self.attn_op,
            )
            # TODO(woosuk): Unnecessary copy. Optimize.
            output[start:end].copy_(out.squeeze(0))
            start += prompt_len
        return output

    def single_query_cached_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> None:
        """PagedAttention with ALiBi bias for the generation tokens.

        Args:
            output: shape = [num_generation_tokens, num_heads, head_size]
            query: shape = [num_generation_tokens, num_heads, head_size]
            key_cache: shape = [num_blocks, num_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_heads, head_size, block_size]
            input_metadata: metadata for paged attention.
        """
        block_size = value_cache.shape[3]
        attention_ops.single_query_cached_kv_attention(
            output,
            query,
            key_cache,
            value_cache,
            self.head_mapping,
            self.scale,
            input_metadata.block_tables,
            input_metadata.context_lens,
            block_size,
            input_metadata.max_context_len,
            self.alibi_slopes,
        )
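

# Illustrative usage (a sketch only; the numbers and variable names below are
# hypothetical): a model's attention layer typically constructs one of these
# modules once and calls it on every forward pass, e.g.
#
#     attn = PagedAttentionWithRoPE(num_heads=32, head_size=128,
#                                   scale=128 ** -0.5, rotary_dim=128)
#     out = attn(positions, query, key, value, key_cache, value_cache,
#                input_metadata, cache_event)
#
# where query, key, and value are slices of the fused QKV projection,
# key_cache / value_cache and input_metadata are produced by the cache engine
# and scheduler, and cache_event may be None when no cache operations are
# pending.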