"""Multi-head attention."""
from typing import List, Optional

import torch
import torch.nn as nn
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
                                         LowerTriangularMaskWithTensorBias)

from vllm import attention_ops
from vllm import cache_ops
from vllm import pos_encoding_ops
from vllm.model_executor.input_metadata import InputMetadata

_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128]


class PagedAttention(nn.Module):
    # pylint: disable=line-too-long
    """GPT-style multi-head PagedAttention.

    This class takes flattened 1D query, key, and value tensors as input. The
    input 1D tensors can be split into three parts: the prompt tokens, the
    generation tokens, and the padding.

    |<------------------------------------- num_valid_tokens ------------------------------------->|
    |<--------------- num_prompt_tokens -------------->|<------- num_generation_tokens (M) ------->|
    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--generation_0-->|...|<--generation_M-1-->|<--padding-->|

    The prompts may have different lengths, while the generation tokens always
    have length 1. Padding is appended to make the input length a multiple
    of 8, which is desirable for Tensor Cores.

    The class does the following:
    1. Perform multi_query_kv_attention for the prompts. This operation does
        not use the KV cache.
    2. Wait for the cache operations (e.g., swap, copy) to finish. The cache
        operations are issued by the cache engine before executing the forward
        pass of the model, and they are executed asynchronously.
    3. Reshape and store the input key and value tensors in the KV cache.
    4. Perform single_query_cached_kv_attention for the generation tokens.
        This operation reads the previous key and value tensors from the KV
        cache.
    5. Output a flattened 1D tensor.
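
    Example (illustrative numbers): with two prompts of lengths 3 and 4 plus
    two generation tokens, num_prompt_tokens = 7, num_generation_tokens = 2,
    and num_valid_tokens = 9; the flattened input is then padded to 16 tokens,
    the next multiple of 8.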
    """

    def __init__(self, num_heads: int, head_size: int, scale: float) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.attn_op = xops.fmha.cutlass.FwOp()

        if self.head_size not in _SUPPORTED_HEAD_SIZES:
            raise ValueError(f"head_size ({self.head_size}) is not supported. "
                             f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")

    def set_attn_bias(self, input_metadata: InputMetadata) -> None:
        if input_metadata.attn_bias:
            # Already set by a previous layer.
            return
        prompt_lens = input_metadata.prompt_lens
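        # A block-diagonal causal mask lets each prompt attend only to its own
        # earlier tokens, so the prompts packed into the flattened batch stay
        # isolated from one another.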
        attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
        input_metadata.attn_bias.append(attn_bias)

    def multi_query_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Normal attention for the prompt tokens.

        Args:
            output: shape = [num_prompt_tokens, num_heads, head_size]
            query: shape = [num_prompt_tokens, num_heads, head_size]
            key: shape = [num_prompt_tokens, num_heads, head_size]
            value: shape = [num_prompt_tokens, num_heads, head_size]
            input_metadata: metadata for paged attention.
        """
        # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
        out = xops.memory_efficient_attention_forward(
            query.unsqueeze(0),
            key.unsqueeze(0),
            value.unsqueeze(0),
            attn_bias=input_metadata.attn_bias[0],
            p=0.0,
            scale=self.scale,
            op=self.attn_op,
        )
        # TODO(woosuk): Unnecessary copy. Optimize.
        output.copy_(out.squeeze(0))
        return output

    def single_query_cached_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> None:
        """PagedAttention for the generation tokens.

        Args:
            output: shape = [num_generation_tokens, num_heads, head_size]
            query: shape = [num_generation_tokens, num_heads, head_size]
            key_cache: shape = [num_blocks, num_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_heads, head_size, block_size]
            input_metadata: metadata for paged attention.
        """
        block_size = value_cache.shape[3]
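        # block_tables maps each generation sequence to the physical cache
        # blocks holding its context; context_lens gives the number of cached
        # tokens per sequence, bounded by max_context_len.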
        attention_ops.single_query_cached_kv_attention(
            output,
            query,
            key_cache,
            value_cache,
            self.scale,
            input_metadata.block_tables,
            input_metadata.context_lens,
            block_size,
            input_metadata.max_context_len,
            None,  # alibi_slopes
        )

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: Optional[torch.Tensor],
        value_cache: Optional[torch.Tensor],
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """PagedAttention forward pass.

        NOTE: The query, key, and value tensors must be sliced from a qkv
        tensor of shape [num_tokens, 3 * num_heads * head_size].

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_heads * head_size]
            value: shape = [num_tokens, num_heads * head_size]
            key_cache: shape = [num_blocks, num_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_heads, head_size, block_size]
            input_metadata: metadata for paged attention.
            cache_event: event to wait for the cache operations to finish.

        Returns:
            shape = [num_tokens, num_heads * head_size]
        """

        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_heads, self.head_size)
        value = value.view(-1, self.num_heads, self.head_size)

        # Pre-allocate the output tensor.
        output = torch.empty_like(query)

        # Compute the attention op for prompts.
        num_prompt_tokens = input_metadata.num_prompt_tokens
        if num_prompt_tokens > 0:
            self.set_attn_bias(input_metadata)
            self.multi_query_kv_attention(
                output[:num_prompt_tokens],
                query[:num_prompt_tokens],
                key[:num_prompt_tokens],
                value[:num_prompt_tokens],
                input_metadata,
            )

        # Wait until the cache op is done.
        if cache_event is not None:
            cache_event.wait()

        # Reshape the keys and values and store them in the cache.
        # When key_cache and value_cache are not provided, the new key
        # and value vectors will not be cached.
        num_valid_tokens = input_metadata.num_valid_tokens
        if (num_valid_tokens > 0 and key_cache is not None
                and value_cache is not None):
            # The stride is 3 because the key and value are sliced from qkv.
            cache_ops.reshape_and_cache(
                key[:num_valid_tokens],
                value[:num_valid_tokens],
                key_cache,
                value_cache,
                input_metadata.slot_mapping,
            )

        if input_metadata.num_generation_tokens > 0:
            assert key_cache is not None and value_cache is not None, (
                "key_cache and value_cache must be provided when "
                "generating tokens.")
            # Compute the attention op for generation tokens.
            self.single_query_cached_kv_attention(
                output[num_prompt_tokens:num_valid_tokens],
                query[num_prompt_tokens:num_valid_tokens],
                key_cache,
                value_cache,
                input_metadata,
            )

        # Reshape the output tensor.
        # NOTE(woosuk): The output tensor may include paddings.
        return output.view(-1, self.num_heads * self.head_size)


class PagedAttentionWithRoPE(PagedAttention):
    """PagedAttention with GPT-NeoX style rotary embedding."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        rotary_dim: int,
        max_position: int = 8192,
        base: int = 10000,
    ) -> None:
        super().__init__(num_heads, head_size, scale)

        # Create the cos and sin cache.
        inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
        t = torch.arange(max_position).float()
        freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
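        # Each row t of the cache is [cos(t * inv_freq), sin(t * inv_freq)]:
        # the first rotary_dim // 2 columns hold cosines, the last
        # rotary_dim // 2 hold sines.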

        # FIXME(woosuk): This assumes that we configure the default dtype when
        # initializing the model.
        # TODO(woosuk): Make it more robust.
        torch_dtype = torch.get_default_dtype()
        cache = cache.to(torch_dtype)
        # Embedding size: [max_position, rotary_dim]
        self.register_buffer("cos_sin_cache", cache, persistent=False)

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """ PagedAttention forward pass with rotary embedding.

        Args:
            positions: shape = [num_tokens]
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_heads * head_size]
            value: shape = [num_tokens, num_heads * head_size]
            key_cache: shape = [num_blocks, num_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_heads, head_size, block_size]
            input_metadata: metadata for paged attention.
            cache_event: event to wait for the cache operations to finish.

        Returns:
            shape = [num_tokens, num_heads * head_size]
        """

        # Apply rotary embedding to the query and key before passing them
        # to the attention op. The rotary op updates query and key in place.
        pos_encoding_ops.rotary_embedding_neox(
            positions,
            query,
            key,
            self.head_size,
            self.cos_sin_cache,
        )
        return super().forward(
            query,
            key,
            value,
            key_cache,
            value_cache,
            input_metadata,
            cache_event,
        )


class PagedAttentionWithALiBi(PagedAttention):
    """PagedAttention with ALiBi attention bias."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        slopes: List[float],
    ) -> None:
        super().__init__(num_heads, head_size, scale)
        assert len(slopes) == num_heads

        slopes = torch.tensor(slopes, dtype=torch.float32)
        self.register_buffer("alibi_slopes", slopes, persistent=False)

    def set_attn_bias(self, input_metadata: InputMetadata) -> None:
        if input_metadata.attn_bias:
            # Already set by a previous layer.
            return
        # Generates ALiBi mask for each prompt.
        for prompt_len in input_metadata.prompt_lens:
            bias = torch.arange(prompt_len)
            bias = bias[None, :] - bias[:, None]
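            # bias[i, j] = j - i: the relative distance from query position i
            # to key position j. The per-head ALiBi slope is applied below.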
            bias = bias.to(self.alibi_slopes.device)

            # When using custom attention bias, xformers requires the bias to
            # be sliced from a tensor whose length is a multiple of 8.
            padded_len = (prompt_len + 7) // 8 * 8
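            # For example, prompt_len = 10 gives padded_len = 16.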
            bias = torch.empty(
                self.num_heads,
                padded_len,
                padded_len,
                device=self.alibi_slopes.device,
            )[:, :prompt_len, :prompt_len].copy_(bias)
            bias.mul_(self.alibi_slopes[:, None, None])
            attn_bias = LowerTriangularMaskWithTensorBias(bias)
            input_metadata.attn_bias.append(attn_bias)

    def multi_query_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Attention with ALiBi bias for the prompt tokens.

        Args:
            output: shape = [num_prompt_tokens, num_heads, head_size]
            query: shape = [num_prompt_tokens, num_heads, head_size]
            key: shape = [num_prompt_tokens, num_heads, head_size]
            value: shape = [num_prompt_tokens, num_heads, head_size]
            input_metadata: metadata for paged attention.
        """
        # FIXME(woosuk): Because xformers does not support dynamic sequence
        # lengths with custom attention bias, we process each prompt one by
        # one. This is inefficient, especially when we have many short prompts.
        start = 0
        for i, prompt_len in enumerate(input_metadata.prompt_lens):
            end = start + prompt_len
            out = xops.memory_efficient_attention_forward(
                query[None, start:end],
                key[None, start:end],
                value[None, start:end],
                attn_bias=input_metadata.attn_bias[i],
                p=0.0,
                scale=self.scale,
                op=self.attn_op,
            )
            # TODO(woosuk): Unnecessary copy. Optimize.
            output[start:end].copy_(out.squeeze(0))
            start += prompt_len
        return output

    def single_query_cached_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> None:
        """PagedAttention with ALiBi bias for the generation tokens.

        Args:
            output: shape = [num_generation_tokens, num_heads, head_size]
            query: shape = [num_generation_tokens, num_heads, head_size]
            key_cache: shape = [num_blocks, num_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_heads, head_size, block_size]
            input_metadata: metadata for paged attention.
        """
        block_size = value_cache.shape[3]
        attention_ops.single_query_cached_kv_attention(
            output,
            query,
            key_cache,
            value_cache,
            self.scale,
            input_metadata.block_tables,
            input_metadata.context_lens,
            block_size,
            input_metadata.max_context_len,
            self.alibi_slopes,
        )