"""Multi-head attention."""
from typing import Optional

import torch
import torch.nn as nn
from xformers import ops as xops

from cacheflow import attention_ops
from cacheflow import cache_ops
from cacheflow import pos_encoding_ops
from cacheflow.model_executor.input_metadata import InputMetadata

_SUPPORTED_HEAD_SIZES = [32, 64, 80, 96, 128, 160, 192, 256]


class GPTCacheFlowAttention(nn.Module):
    """GPT-style multi-head attention.

    This class takes flattened 1D query, key, and value tensors as input. The
    input 1D tensors can be split into three parts: the prompt tokens, the
    generation tokens, and the paddings.

    |<------------------------------------- num_valid_tokens ------------------------------------->|
    |<--------------- num_prompt_tokens -------------->|<------- num_generation_tokens (M) ------->|
    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--generation_0-->|...|<--generation_M-1-->|<--padding-->|

    The prompts might have different lengths, while the generation tokens always
    have length 1. The paddings are appended to make the input length a multiple
    of 8, which is desirable for Tensor Cores.

    The class does the following:
    1. Perform multi_query_kv_attention for the prompts. This operation does
        not use the KV cache.
    2. Wait for the cache operations (e.g., swap, copy) to finish. The cache
        operations are issued by the cache engine before executing the forward
        pass of the model, and they are executed asynchronously.
    3. Reshape and store the input key and value tensors in the KV cache.
    4. Perform single_query_cached_kv_attention for the generation tokens.
        This operation reads the previous key and value tensors from the KV
        cache.
    5. Output a flattened 1D tensor.
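
    For example, with two prompts of lengths 5 and 3 and two generation tokens,
    num_prompt_tokens = 8, num_generation_tokens = 2, and num_valid_tokens = 10;
    the input is then padded from 10 to 16 tokens (the next multiple of 8).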
    """

    def __init__(self, num_heads: int, head_size: int, scale: float) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
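        # xFormers' memory-efficient attention (CUTLASS forward-only op) is
        # used for the prompt phase in multi_query_kv_attention below.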
        self.attn_op = xops.fmha.cutlass.FwOp()

        if self.head_size not in _SUPPORTED_HEAD_SIZES:
            raise ValueError(f'head_size ({self.head_size}) is not supported by '
                             'the single_query_cached_kv_attention kernel. '
                             'Use one of the following head sizes: '
                             f'{_SUPPORTED_HEAD_SIZES}.')

    def multi_query_kv_attention(
        self,
        output: torch.Tensor,                   # [num_prompt_tokens, num_heads, head_size]
        query: torch.Tensor,                    # [num_prompt_tokens, num_heads, head_size]
        key: torch.Tensor,                      # [num_prompt_tokens, num_heads, head_size]
        value: torch.Tensor,                    # [num_prompt_tokens, num_heads, head_size]
        attn_bias: xops.AttentionBias,
    ) -> torch.Tensor:
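        # attn_bias comes from InputMetadata and is expected to be a causal
        # mask over the packed prompts (e.g. an xFormers block-diagonal causal
        # mask built from the prompt lengths), so each prompt only attends to
        # its own tokens even though all prompts share one batch-of-1 call.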
        # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
        out = xops.memory_efficient_attention_forward(
            query.unsqueeze(0),
            key.unsqueeze(0),
            value.unsqueeze(0),
            attn_bias=attn_bias,
            p=0.0,
            scale=self.scale,
            op=self.attn_op,
        )
        # TODO(woosuk): Unnecessary copy. Optimize.
        output.copy_(out.squeeze(0))
        return output

    def single_query_cached_kv_attention(
        self,
        output: torch.Tensor,           # [num_generation_tokens, num_heads, head_size]
        query: torch.Tensor,            # [num_generation_tokens, num_heads, head_size]
        key_cache: torch.Tensor,        # [num_blocks, num_heads, head_size/x, block_size, x]
        value_cache: torch.Tensor,      # [num_blocks, num_heads, head_size, block_size]
        input_metadata: InputMetadata,
    ) -> None:
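        # The key cache splits head_size into head_size/x chunks of x elements
        # (x is presumably sized for vectorized loads in the CUDA kernel).
        # block_tables maps each sequence to its physical cache blocks, and
        # context_lens gives the number of cached tokens per sequence.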
        block_size = value_cache.shape[3]
        attention_ops.single_query_cached_kv_attention(
            output,
            query,
            key_cache,
            value_cache,
            self.scale,
            input_metadata.block_tables,
            input_metadata.context_lens,
            block_size,
            input_metadata.max_context_len,
        )

    def forward(
        self,
        query: torch.Tensor,                    # [num_tokens, num_heads * head_size]
        key: torch.Tensor,                      # [num_tokens, num_heads * head_size]
        value: torch.Tensor,                    # [num_tokens, num_heads * head_size]
        key_cache: Optional[torch.Tensor],      # [num_blocks, num_heads, head_size/x, block_size, x]
        value_cache: Optional[torch.Tensor],    # [num_blocks, num_heads, head_size, block_size]
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:                          # [num_tokens, num_heads * head_size]
        # NOTE: The query, key, and value tensors must be sliced from a qkv
        # tensor of shape [num_tokens, 3 * num_heads * head_size].

        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_heads, self.head_size)
        value = value.view(-1, self.num_heads, self.head_size)

        # Pre-allocate the output tensor.
        output = torch.empty_like(query)

        # Compute the attention op for prompts.
        num_prompt_tokens = input_metadata.num_prompt_tokens
        if num_prompt_tokens > 0:
            self.multi_query_kv_attention(
                output[:num_prompt_tokens],
                query[:num_prompt_tokens],
                key[:num_prompt_tokens],
                value[:num_prompt_tokens],
                input_metadata.attn_bias,
            )

        # Wait until the cache op is done.
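        # The cache engine issues swap/copy ops asynchronously before the
        # forward pass; waiting on the event here ensures they have finished
        # before the KV cache is written to or read from below.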
        if cache_event is not None:
            cache_event.wait()

        # Reshape the keys and values and store them in the cache.
        # When key_cache and value_cache are not provided, the new key
        # and value vectors will not be cached.
        num_valid_tokens = input_metadata.num_valid_tokens
        if (num_valid_tokens > 0 and key_cache is not None
            and value_cache is not None):
            # The stride is 3 because the key and value are sliced from qkv.
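            # slot_mapping gives each token's destination slot in the paged
            # cache (presumably block_number * block_size + block_offset).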
            cache_ops.reshape_and_cache(
                key[:num_valid_tokens],
                value[:num_valid_tokens],
                key_cache,
                value_cache,
                input_metadata.slot_mapping,
            )

        if input_metadata.num_generation_tokens > 0:
            assert key_cache is not None and value_cache is not None, (
                "key_cache and value_cache must be provided when "
                "generating tokens."
            )
            # Compute the attention op for generation tokens.
            self.single_query_cached_kv_attention(
                output[num_prompt_tokens:num_valid_tokens],
                query[num_prompt_tokens:num_valid_tokens],
                key_cache,
                value_cache,
                input_metadata)

        # Reshape the output tensor.
        # NOTE(woosuk): The output tensor may include paddings.
        return output.view(-1, self.num_heads * self.head_size)


class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):
    """Attention with GPT-NeoX style rotary embedding."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        rotary_dim: int,
        max_position: int = 8192,
        base: int = 10000,
    ) -> None:
        super().__init__(num_heads, head_size, scale)

        # Create the cos and sin cache.
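        # inv_freq[i] = base^(-2i / rotary_dim) are the standard rotary (RoPE)
        # inverse frequencies; the outer product with positions t gives the
        # per-position rotation angles.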
        inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2) / rotary_dim))
        t = torch.arange(max_position).float()
        freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)

        # FIXME(woosuk): This assumes that we configure the default dtype when
        # initializing the model. Make it more robust.
        torch_dtype = torch.get_default_dtype()
        cache = cache.to(torch_dtype)
        # Embedding size: [max_position, rotary_dim]
        self.register_buffer("cos_sin_cache", cache, persistent=False)

    def forward(
        self,
        positions: torch.Tensor,                # [num_tokens]
        query: torch.Tensor,                    # [num_tokens, num_heads * head_size]
        key: torch.Tensor,                      # [num_tokens, num_heads * head_size]
        value: torch.Tensor,                    # [num_tokens, num_heads * head_size]
        key_cache: torch.Tensor,                # [num_blocks, num_heads, head_size/x, block_size, x]
        value_cache: torch.Tensor,              # [num_blocks, num_heads, head_size, block_size]
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:                          # [num_tokens, num_heads * head_size]
        # Apply rotary embedding to the query and key before passing them
        # to the attention op.
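        # NOTE: the kernel below updates query and key in place using the
        # precomputed cos/sin cache.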
        pos_encoding_ops.rotary_embedding_neox(
            positions,
            query,
            key,
            self.head_size,
            self.cos_sin_cache,
        )
        return super().forward(
            query,
            key,
            value,
            key_cache,
            value_cache,
            input_metadata,
            cache_event,
        )