"""Multi-head attention."""
from typing import Optional

import torch
import torch.nn as nn
from xformers import ops as xops

from vllm import attention_ops
from vllm import cache_ops
from vllm import pos_encoding_ops
from vllm.model_executor.input_metadata import InputMetadata

_SUPPORTED_HEAD_SIZES = [64, 80, 96, 128]


class PagedAttention(nn.Module):
    # pylint: disable=line-too-long
    """GPT-style multi-head PagedAttention.

    This class takes flattened 1D query, key, and value tensors as input. The
    input 1D tensors can be split into three parts: the prompt tokens, the
    generation tokens, and the paddings.

    |<------------------------------------- num_valid_tokens ------------------------------------->|
    |<--------------- num_prompt_tokens -------------->|<------- num_generation_tokens (M) ------->|
    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--generation_0-->|...|<--generation_M-1-->|<--padding-->|

    The prompts might have different lengths, while the generation tokens always
    have length 1. The paddings are appended to make the input length a multiple
    of 8, which is desirable for Tensor Cores.

    The class does the following:
    1. Perform multi_query_kv_attention for the prompts. This operation does
        not use the KV cache.
    2. Wait for the cache operations (e.g., swap, copy) to finish. The cache
        operations are issued by the cache engine before executing the forward
        pass of the model, and they are executed asynchronously.
    3. Reshape and store the input key and value tensors in the KV cache.
    4. Perform single_query_cached_kv_attention for the generation tokens.
        This operation reads the previous key and value tensors from the KV
        cache.
    5. Output a flattened 1D tensor.
    """

    def __init__(self, num_heads: int, head_size: int, scale: float) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.attn_op = xops.fmha.cutlass.FwOp()

        if self.head_size not in _SUPPORTED_HEAD_SIZES:
            raise ValueError(f"head_size ({self.head_size}) is not supported. "
                             f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")

    def multi_query_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_bias: xops.AttentionBias,
    ) -> torch.Tensor:
        """Normal attention for the prompt tokens.

        Args:
            output: shape = [num_prompt_tokens, num_heads, head_size]
            query: shape = [num_prompt_tokens, num_heads, head_size]
            key: shape = [num_prompt_tokens, num_heads, head_size]
            value: shape = [num_prompt_tokens, num_heads, head_size]
        """
        # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
        out = xops.memory_efficient_attention_forward(
            query.unsqueeze(0),
            key.unsqueeze(0),
            value.unsqueeze(0),
            attn_bias=attn_bias,
            p=0.0,
            scale=self.scale,
            op=self.attn_op,
        )
        # TODO(woosuk): Unnecessary copy. Optimize.
        output.copy_(out.squeeze(0))
        return output

    def single_query_cached_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> None:
        """PagedAttention for the generation tokens.

        Args:
            output: shape = [num_generation_tokens, num_heads, head_size]
            query: shape = [num_generation_tokens, num_heads, head_size]
            key_cache: shape = [num_blocks, num_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_heads, head_size, block_size]
            input_metadata: metadata for paged attention.
        """
        block_size = value_cache.shape[3]
        attention_ops.single_query_cached_kv_attention(
            output,
            query,
            key_cache,
            value_cache,
            self.scale,
            input_metadata.block_tables,
            input_metadata.context_lens,
            block_size,
            input_metadata.max_context_len,
        )

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: Optional[torch.Tensor],
        value_cache: Optional[torch.Tensor],
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """PagedAttention forward pass.

        NOTE: The query, key, and value tensors must be sliced from a qkv
        tensor of shape [num_tokens, 3 * num_heads * head_size].

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_heads * head_size]
            value: shape = [num_tokens, num_heads * head_size]
            key_cache: shape = [num_blocks, num_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_heads, head_size, block_size]
            input_metadata: metadata for paged attention.
            cache_event: event to wait for the cache operations to finish.

        Returns:
            shape = [num_tokens, num_heads * head_size]
        """

        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_heads, self.head_size)
        value = value.view(-1, self.num_heads, self.head_size)

        # Pre-allocate the output tensor.
        output = torch.empty_like(query)

        # Compute the attention op for prompts.
        num_prompt_tokens = input_metadata.num_prompt_tokens
        if num_prompt_tokens > 0:
            self.multi_query_kv_attention(
                output[:num_prompt_tokens],
                query[:num_prompt_tokens],
                key[:num_prompt_tokens],
                value[:num_prompt_tokens],
                input_metadata.attn_bias,
            )

        # Wait until the cache op is done.
        if cache_event is not None:
            cache_event.wait()

        # Reshape the keys and values and store them in the cache.
        # When key_cache and value_cache are not provided, the new key
        # and value vectors will not be cached.
        num_valid_tokens = input_metadata.num_valid_tokens
        if (num_valid_tokens > 0 and key_cache is not None
                and value_cache is not None):
            # The stride is 3 because the key and value are sliced from qkv.
            cache_ops.reshape_and_cache(
                key[:num_valid_tokens],
                value[:num_valid_tokens],
                key_cache,
                value_cache,
                input_metadata.slot_mapping,
            )

        if input_metadata.num_generation_tokens > 0:
            assert key_cache is not None and value_cache is not None, (
                "key_cache and value_cache must be provided when "
                "generating tokens.")
            # Compute the attention op for generation tokens.
            self.single_query_cached_kv_attention(
                output[num_prompt_tokens:num_valid_tokens],
                query[num_prompt_tokens:num_valid_tokens],
                key_cache,
                value_cache,
                input_metadata,
            )

        # Reshape the output tensor.
        # NOTE(woosuk): The output tensor may include paddings.
        return output.view(-1, self.num_heads * self.head_size)


class PagedAttentionWithRoPE(PagedAttention):
    """PagedAttention with GPT-NeoX style rotary embedding."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        rotary_dim: int,
        max_position: int = 8192,
        base: int = 10000,
    ) -> None:
        super().__init__(num_heads, head_size, scale)

        # Create the cos and sin cache.
        inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
        t = torch.arange(max_position).float()
        freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
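        # At this point freqs[p, j] = p * base**(-2j / rotary_dim), i.e. the
        # GPT-NeoX rotation angle for position p and rotary pair j.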
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)

        # FIXME(woosuk): This assumes that we configure the default dtype when
        # initializing the model. Make it more robust.
        torch_dtype = torch.get_default_dtype()
        cache = cache.to(torch_dtype)
        # Embedding size: [max_position, rotary_dim]
        self.register_buffer("cos_sin_cache", cache, persistent=False)

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """ PagedAttention forward pass with rotary embedding.

        Args:
            positions: shape = [num_tokens]
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_heads * head_size]
            value: shape = [num_tokens, num_heads * head_size]
            key_cache: shape = [num_blocks, num_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_heads, head_size, block_size]
            input_metadata: metadata for paged attention.
            cache_event: event to wait for the cache operations to finish.

        Returns:
            shape = [num_tokens, num_heads * head_size]
        """

        # Apply rotary embedding to the query and key before passing them
        # to the attention op.
        pos_encoding_ops.rotary_embedding_neox(
            positions,
            query,
            key,
            self.head_size,
            self.cos_sin_cache,
        )
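        # NOTE: no return value is captured above, so the custom kernel is
        # relied upon to rotate `query` and `key` in place; the rotated
        # tensors then flow into the parent class's cache/attention logic.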
        return super().forward(
            query,
            key,
            value,
            key_cache,
            value_cache,
            input_metadata,
            cache_event,
        )